use crate::dht::{
errors::DHTError,
rpc::{DHTMessage, DHTMessageProcessor, DHTRecord, DHTRequest, DHTResponse},
swarm::{build_swarm, DHTEvent, DHTSwarm, DHTSwarmEvent},
DHTConfig, RecordValidator,
};
use libp2p::{
core::transport::ListenerId,
futures::StreamExt,
identify::Event as IdentifyEvent,
kad::{
self,
kbucket::{Distance, NodeStatus},
record::{store::RecordStore, Key},
KademliaEvent, PeerRecord, QueryResult, Quorum, Record,
},
swarm::{
dial_opts::{DialOpts, PeerCondition},
SwarmEvent,
},
Multiaddr, PeerId,
};
use std::fmt;
use std::{collections::HashMap, time::Duration};
use tokio;
/// Drives the DHT swarm: pulls [`DHTMessage`] requests from a
/// [`DHTMessageProcessor`], issues the corresponding Kademlia operations, and
/// answers each request once its query completes.
pub struct DHTProcessor<V>
where
    V: RecordValidator + 'static,
{
    /// Configuration used to build the swarm and to size the periodic
    /// bootstrap / peer-dialing timers.
    config: DHTConfig,
    /// Identity of the local node; used to avoid adding ourselves as a peer.
    peer_id: PeerId,
    /// Source of incoming requests (read via `pull_message`).
    processor: DHTMessageProcessor,
    /// The libp2p swarm built by `build_swarm`.
    swarm: DHTSwarm,
    /// In-flight Kademlia queries, keyed by query id, with the message that
    /// is waiting for the query's result.
    requests: HashMap<kad::QueryId, DHTMessage>,
    /// Listener ids returned by `listen_on`, keyed by the requested address.
    listeners: HashMap<Multiaddr, ListenerId>,
    /// Range of the k-bucket the last dial came from; that bucket is skipped
    /// on the next `dial_next_peer` call.
    kad_last_range: Option<(Distance, Distance)>,
    /// Optional validator for record values; when `None`, every value is
    /// treated as valid.
    validator: Option<V>,
}
/// Tracks a freshly started Kademlia query so its result can later be routed
/// back to `$message`, or answers `$message` immediately with the error when
/// the query could not be started.
///
/// `$result` must be a `Result<kad::QueryId, E>` with `E: Into<DHTError>`.
macro_rules! store_request {
    ($self:expr, $message:expr, $result:expr) => {{
        match $result.map_err(Into::<DHTError>::into) {
            Err(start_error) => {
                $message.respond(Err(start_error));
            }
            Ok(pending_query) => {
                $self.requests.insert(pending_query, $message);
            }
        }
    }};
}
impl<V> DHTProcessor<V>
where
V: RecordValidator + 'static,
{
pub(crate) fn spawn(
keypair: &libp2p::identity::Keypair,
peer_id: PeerId,
validator: Option<V>,
config: DHTConfig,
processor: DHTMessageProcessor,
) -> Result<tokio::task::JoinHandle<Result<(), DHTError>>, DHTError> {
let swarm = build_swarm(keypair, &peer_id, &config)?;
let mut node = DHTProcessor {
peer_id,
config,
processor,
swarm,
requests: HashMap::default(),
listeners: HashMap::default(),
kad_last_range: None,
validator,
};
Ok(tokio::spawn(async move { node.process().await }))
}
async fn process(&mut self) -> Result<(), DHTError> {
let mut bootstrap_tick =
tokio::time::interval(Duration::from_secs(self.config.bootstrap_interval));
let mut peer_dialing_tick =
tokio::time::interval(Duration::from_secs(self.config.peer_dialing_interval));
loop {
tokio::select! {
message = self.processor.pull_message() => {
match message {
Some(m) => self.process_message(m).await,
None => {
error!("DHT processing loop unexpectedly closed.");
break
},
}
}
event = self.swarm.select_next_some() => {
self.process_swarm_event(event).await
}
_ = bootstrap_tick.tick() => self.execute_bootstrap()?,
_ = peer_dialing_tick.tick() => self.dial_next_peer(),
}
}
Ok(())
}
async fn process_message(&mut self, message: DHTMessage) {
dht_event_trace(self, &message);
match message.request {
DHTRequest::AddPeers { ref peers } => {
let result = self.add_peers(peers).await.map(|_| DHTResponse::Success);
message.respond(result);
}
DHTRequest::StartListening { ref address } => {
let result = self.start_listening(address).map(|_| DHTResponse::Success);
message.respond(result);
}
DHTRequest::StopListening { ref address } => {
let result = self.stop_listening(address).map(|_| DHTResponse::Success);
message.respond(result);
}
DHTRequest::GetAddresses { external } => {
let listeners: Vec<Multiaddr> = if external {
self.swarm
.external_addresses()
.map(|addr_record| addr_record.addr.to_owned())
.collect::<Vec<Multiaddr>>()
} else {
self.swarm
.listeners()
.map(|addr| addr.to_owned())
.collect::<Vec<Multiaddr>>()
};
message.respond(Ok(DHTResponse::GetAddresses(listeners)));
}
DHTRequest::Bootstrap => {
message.respond(self.execute_bootstrap().map(|_| DHTResponse::Success));
}
DHTRequest::GetProviders { ref key } => {
store_request!(
self,
message,
Ok::<kad::QueryId, DHTError>(
self.swarm.behaviour_mut().kad.get_providers(Key::new(key))
)
);
}
DHTRequest::GetNetworkInfo => {
let info = self.swarm.network_info();
message.respond(Ok(DHTResponse::GetNetworkInfo(info.into())));
}
DHTRequest::StartProviding { ref key } => {
store_request!(
self,
message,
self.swarm
.behaviour_mut()
.kad
.start_providing(Key::new(key))
);
}
DHTRequest::GetRecord { ref key } => {
store_request!(
self,
message,
Ok::<kad::QueryId, DHTError>(
self.swarm
.behaviour_mut()
.kad
.get_record(Key::new(key), Quorum::One)
)
);
}
DHTRequest::PutRecord { ref key, ref value } => {
let value_owned = value.to_owned();
if self.validate(value).await {
let record = Record {
key: Key::new(key),
value: value_owned,
publisher: None,
expires: None,
};
store_request!(
self,
message,
self.swarm
.behaviour_mut()
.kad
.put_record(record, Quorum::One)
);
} else {
message.respond(Err(DHTError::ValidationError(value_owned)));
}
}
};
}
async fn process_swarm_event(&mut self, event: DHTSwarmEvent) {
dht_event_trace(self, &event);
match event {
SwarmEvent::Behaviour(DHTEvent::Kademlia(e)) => self.process_kad_event(e).await,
SwarmEvent::Behaviour(DHTEvent::Identify(e)) => self.process_identify_event(e),
SwarmEvent::NewListenAddr { address: _, .. } => {}
SwarmEvent::ConnectionEstablished { peer_id: _, .. } => {}
SwarmEvent::ConnectionClosed {
peer_id: _,
cause: _,
..
} => {}
SwarmEvent::IncomingConnection {
local_addr: _,
send_back_addr: _,
} => {}
SwarmEvent::IncomingConnectionError {
local_addr: _,
send_back_addr: _,
error: _,
} => {}
SwarmEvent::OutgoingConnectionError {
peer_id: _,
error: _,
} => {}
SwarmEvent::BannedPeer { peer_id: _, .. } => {}
SwarmEvent::ExpiredListenAddr {
listener_id: _,
address: _,
} => {}
SwarmEvent::ListenerClosed {
listener_id: _,
addresses: _,
reason: _,
} => {}
SwarmEvent::ListenerError {
listener_id: _,
error: _,
} => {}
SwarmEvent::Dialing(_) => {}
}
}
async fn process_kad_event(&mut self, event: KademliaEvent) {
match event {
KademliaEvent::OutboundQueryCompleted { id, result, .. } => match result {
QueryResult::GetRecord(Ok(ok)) => {
for PeerRecord {
record: Record { key, value, .. },
..
} in ok.records
{
if let Some(message) = self.requests.remove(&id) {
let is_valid = self.validate(&value).await;
message.respond(Ok(DHTResponse::GetRecord(DHTRecord {
key: key.to_vec(),
value: if is_valid { Some(value) } else { None },
})));
};
}
}
QueryResult::GetRecord(Err(e)) => {
if let Some(message) = self.requests.remove(&id) {
match e {
kad::GetRecordError::NotFound { key, .. } => {
message.respond(Ok(DHTResponse::GetRecord(DHTRecord {
key: key.to_vec(),
value: None,
})))
}
e => message.respond(Err(DHTError::from(e))),
};
}
}
QueryResult::PutRecord(Ok(kad::PutRecordOk { key })) => {
if let Some(message) = self.requests.remove(&id) {
message.respond(Ok(DHTResponse::PutRecord { key: key.to_vec() }));
}
}
QueryResult::PutRecord(Err(e)) => {
match e {
kad::PutRecordError::Timeout {
ref key,
quorum: _,
success: _,
}
| kad::PutRecordError::QuorumFailed {
ref key,
quorum: _,
success: _,
} => {
let record = self.swarm.behaviour_mut().kad.store_mut().get(key);
trace!("Has internal record? {:?}", record);
}
}
if let Some(message) = self.requests.remove(&id) {
message.respond(Err(DHTError::from(e)));
}
}
QueryResult::StartProviding(Ok(kad::AddProviderOk { key })) => {
if let Some(message) = self.requests.remove(&id) {
message.respond(Ok(DHTResponse::StartProviding { key: key.to_vec() }));
}
}
QueryResult::StartProviding(Err(e)) => {
if let Some(message) = self.requests.remove(&id) {
message.respond(Err(DHTError::from(e)));
}
}
QueryResult::GetProviders(Ok(kad::GetProvidersOk {
providers,
key,
closest_peers: _,
})) => {
if let Some(message) = self.requests.remove(&id) {
message.respond(Ok(DHTResponse::GetProviders {
providers: providers.into_iter().collect(),
key: key.to_vec(),
}));
}
}
QueryResult::GetProviders(Err(e)) => {
if let Some(message) = self.requests.remove(&id) {
message.respond(Err(DHTError::from(e)));
}
}
QueryResult::Bootstrap(Ok(kad::BootstrapOk {
peer: _,
num_remaining: _,
})) => {}
QueryResult::Bootstrap(Err(kad::BootstrapError::Timeout {
peer: _,
num_remaining: _,
})) => {}
_ => {}
},
KademliaEvent::InboundRequest { request } => match request {
kad::InboundRequest::FindNode {
num_closer_peers: _,
} => {}
kad::InboundRequest::GetProvider {
num_closer_peers: _,
num_provider_peers: _,
} => {}
kad::InboundRequest::AddProvider { record: _ } => {}
kad::InboundRequest::GetRecord {
num_closer_peers: _,
present_locally: _,
} => {}
kad::InboundRequest::PutRecord { source, record, .. } => match record {
Some(rec) => {
if self.validate(&rec.value).await {
if let Err(e) =
self.swarm.behaviour_mut().kad.store_mut().put(rec.clone())
{
warn!(
"InboundRequest::PutRecord write failed: {:?} {:?}, {}",
rec, source, e
);
}
} else {
warn!(
"InboundRequest::PutRecord validation failed: {:?} {:?}",
rec, source
);
}
}
None => warn!("InboundRequest::PutRecord failed; empty record"),
},
},
KademliaEvent::RoutingUpdated {
peer: _,
is_new_peer: _,
addresses: _,
..
} => {}
KademliaEvent::UnroutablePeer { peer: _ } => {}
KademliaEvent::RoutablePeer {
peer: _,
address: _,
} => {}
KademliaEvent::PendingRoutablePeer {
peer: _,
address: _,
} => {}
}
}
fn process_identify_event(&mut self, event: IdentifyEvent) {
if let IdentifyEvent::Received { peer_id, info } = event {
if info
.protocols
.iter()
.any(|p| p.as_bytes() == kad::protocol::DEFAULT_PROTO_NAME)
{
for addr in &info.listen_addrs {
self.swarm
.behaviour_mut()
.kad
.add_address(&peer_id, addr.clone());
}
}
}
}
fn dial_next_peer(&mut self) {
let mut to_dial = None;
for kbucket in self.swarm.behaviour_mut().kad.kbuckets() {
if let Some(range) = self.kad_last_range {
if kbucket.range() == range {
continue;
}
}
for entry in kbucket.iter() {
if entry.status == NodeStatus::Disconnected {
let peer_id = entry.node.key.preimage();
let dial_opts = DialOpts::peer_id(*peer_id)
.condition(PeerCondition::Disconnected)
.addresses(entry.node.value.clone().into_vec())
.extend_addresses_through_behaviour()
.build();
to_dial = Some((dial_opts, kbucket.range()));
break;
}
}
}
if let Some((dial_opts, range)) = to_dial {
if let Err(e) = self.swarm.dial(dial_opts) {
warn!("failed to dial: {:?}", e);
}
self.kad_last_range = Some(range);
}
}
fn start_listening(&mut self, address: &libp2p::Multiaddr) -> Result<(), DHTError> {
dht_event_trace(self, &format!("Start listening on {}", address));
let listener_id = self.swarm.listen_on(address.to_owned())?;
if let Some(previous_id) = self.listeners.insert(address.to_owned(), listener_id) {
assert!(self.swarm.remove_listener(previous_id));
}
Ok(())
}
fn stop_listening(&mut self, address: &libp2p::Multiaddr) -> Result<(), DHTError> {
dht_event_trace(self, &format!("Stop listening on {}", address));
if let Some(listener_id) = self.listeners.get(address) {
assert!(self.swarm.remove_listener(listener_id.to_owned()));
}
Ok(())
}
async fn add_peers(&mut self, peers: &[libp2p::Multiaddr]) -> Result<(), DHTError> {
for multiaddress in peers {
let mut addr = multiaddress.to_owned();
if let Some(libp2p::multiaddr::Protocol::P2p(p2p_hash)) = addr.pop() {
let peer_id = PeerId::from_multihash(p2p_hash).unwrap();
if peer_id != self.peer_id {
self.swarm.behaviour_mut().kad.add_address(&peer_id, addr);
}
}
}
Ok(())
}
fn execute_bootstrap(&mut self) -> Result<(), DHTError> {
dht_event_trace(self, &"Execute bootstrap");
match self.swarm.behaviour_mut().kad.bootstrap() {
Ok(_) => Ok(()),
Err(_) => {
Ok(())
}
}
}
async fn validate(&mut self, data: &[u8]) -> bool {
if let Some(v) = self.validator.as_mut() {
v.validate(data).await
} else {
true
}
}
}
impl<V: RecordValidator + 'static> fmt::Debug for DHTProcessor<V> {
    /// Shows only the local peer id and configuration. Swarm state, pending
    /// requests and the validator are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DHTNode");
        builder.field("peer_id", &self.peer_id);
        builder.field("config", &self.config);
        builder.finish()
    }
}
// The bounds here must match the struct definition exactly (Rust rejects
// Drop impls that are more specialized than the type).
impl<V: RecordValidator + 'static> Drop for DHTProcessor<V> {
    /// No explicit teardown; fields are dropped in declaration order.
    ///
    /// Note: because this impl exists, `DHTProcessor` cannot be destructured
    /// by moving its fields out.
    fn drop(&mut self) {
        // Intentionally empty.
    }
}
/// Emits a `trace!` record of `data`, tagged with a short excerpt of the local
/// peer id.
///
/// The excerpt is characters `8..14` of the base58-encoded peer id, or the
/// literal `"INVALID PEER ID"` when that slice is unavailable.
fn dht_event_trace<V: RecordValidator, T: std::fmt::Debug>(processor: &DHTProcessor<V>, data: &T) {
    let encoded_id = processor.peer_id.to_base58();
    let short_tag = match encoded_id.get(8..14) {
        Some(excerpt) => excerpt,
        None => "INVALID PEER ID",
    };
    trace!("\nFrom ..{:#?}..\n{:#?}", short_tag, data);
}