use futures::channel::oneshot;
use prost::Message;
use std::error::Error;
use std::io;
use std::{convert::TryFrom, time::Duration, time::Instant};
use async_trait::async_trait;
use libp2prs_core::upgrade::UpgradeInfo;
use libp2prs_core::{Multiaddr, PeerId, ProtocolId};
use libp2prs_swarm::connection::Connection;
use libp2prs_swarm::protocol_handler::{IProtocolHandler, Notifiee, ProtocolHandler};
use libp2prs_swarm::substream::{Substream, SubstreamView};
use libp2prs_swarm::Control as SwarmControl;
use libp2prs_traits::{ReadEx, WriteEx};
use crate::kad::KadPoster;
use crate::query::QueryStats;
use crate::record::{self, Record};
use crate::{dht_proto as proto, KadError, ProviderRecord};
/// Protocol identifier used by default when negotiating Kademlia substreams.
pub const DEFAULT_PROTO_NAME: &[u8] = b"/ipfs/kad/1.0.0";
/// Default upper bound, in bytes, on a single Kademlia packet (16 KiB).
pub const DEFAULT_MAX_PACKET_SIZE: usize = 16_384;
/// Default value of `KademliaProtocolConfig::max_reuse_count`.
pub const DEFAULT_MAX_REUSE_TRIES: usize = 3;
/// Connection status of a peer as carried in Kademlia `Peer` messages.
///
/// Converted to and from `proto::message::ConnectionType` variant-for-variant.
/// NOTE(review): the explicit discriminants presumably mirror the protobuf enum values — confirm
/// against the `.proto` definition before relying on them.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum KadConnectionType {
    NotConnected = 0,
    Connected = 1,
    CanConnect = 2,
    CannotConnect = 3,
}
impl From<proto::message::ConnectionType> for KadConnectionType {
    /// Maps the wire-level connection type onto the local enum, one variant to one variant.
    fn from(raw: proto::message::ConnectionType) -> Self {
        match raw {
            proto::message::ConnectionType::NotConnected => Self::NotConnected,
            proto::message::ConnectionType::Connected => Self::Connected,
            proto::message::ConnectionType::CanConnect => Self::CanConnect,
            proto::message::ConnectionType::CannotConnect => Self::CannotConnect,
        }
    }
}
// Implemented as `From` rather than a hand-written `Into`: the standard blanket impl still
// provides `KadConnectionType: Into<proto::message::ConnectionType>` for existing callers.
impl From<KadConnectionType> for proto::message::ConnectionType {
    /// Maps the local connection status onto its wire-level counterpart, variant for variant.
    fn from(ty: KadConnectionType) -> Self {
        match ty {
            KadConnectionType::NotConnected => Self::NotConnected,
            KadConnectionType::Connected => Self::Connected,
            KadConnectionType::CanConnect => Self::CanConnect,
            KadConnectionType::CannotConnect => Self::CannotConnect,
        }
    }
}
/// A peer entry as exchanged in Kademlia messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KadPeer {
    /// Identity of the peer.
    pub node_id: PeerId,
    /// Addresses advertised for the peer.
    pub multiaddrs: Vec<Multiaddr>,
    /// Connection status reported for the peer.
    pub connection_ty: KadConnectionType,
}
impl TryFrom<proto::message::Peer> for KadPeer {
    type Error = io::Error;

    /// Decodes a wire-level peer: its id, every multiaddr, and the connection type.
    ///
    /// # Errors
    /// Returns an `InvalidData` error if the peer id, any address, or the connection type
    /// cannot be parsed. A single bad address rejects the whole peer.
    fn try_from(peer: proto::message::Peer) -> Result<KadPeer, Self::Error> {
        let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?;

        let multiaddrs = peer
            .addrs
            .into_iter()
            .map(|raw| Multiaddr::try_from(raw).map_err(invalid_data))
            .collect::<Result<Vec<_>, _>>()?;

        let wire_ty = proto::message::ConnectionType::from_i32(peer.connection)
            .ok_or_else(|| invalid_data("unknown connection type"))?;

        Ok(KadPeer {
            node_id,
            multiaddrs,
            connection_ty: wire_ty.into(),
        })
    }
}
// Implemented as `From` rather than a hand-written `Into`; the blanket impl keeps
// `KadPeer: Into<proto::message::Peer>` (e.g. `.map(KadPeer::into)`) working.
impl From<KadPeer> for proto::message::Peer {
    /// Encodes a peer for the wire: raw id bytes, raw multiaddr bytes and the
    /// connection type as its protobuf integer value.
    fn from(peer: KadPeer) -> Self {
        let connection: proto::message::ConnectionType = peer.connection_ty.into();
        proto::message::Peer {
            id: peer.node_id.to_bytes(),
            addrs: peer.multiaddrs.into_iter().map(|a| a.to_vec()).collect(),
            connection: connection as i32,
        }
    }
}
// Implemented as `From` rather than a hand-written `Into`; `KadPeer: Into<PeerId>` is still
// available through the standard blanket impl.
impl From<KadPeer> for PeerId {
    /// Extracts the peer id, discarding the addresses and connection status.
    fn from(peer: KadPeer) -> Self {
        peer.node_id
    }
}
/// Tunables for the Kademlia wire protocol.
#[derive(Debug, Clone)]
pub struct KademliaProtocolConfig {
    /// Protocol id negotiated when opening Kademlia substreams.
    protocol_name: ProtocolId,
    /// Maximum size, in bytes, accepted by `read_one` for a single packet.
    max_packet_size: usize,
    /// Bound checked by `KadMessenger::reuse` against its reuse counter.
    max_reuse_count: usize,
}
impl KademliaProtocolConfig {
    /// Returns the protocol id negotiated for Kademlia substreams.
    pub fn protocol_name(&self) -> &ProtocolId {
        let Self { protocol_name, .. } = self;
        protocol_name
    }

    /// Replaces the protocol id negotiated for Kademlia substreams.
    pub fn set_protocol_name(&mut self, name: ProtocolId) {
        let Self { protocol_name, .. } = self;
        *protocol_name = name;
    }

    /// Sets the maximum accepted packet size, in bytes.
    pub fn set_max_packet_size(&mut self, size: usize) {
        let Self { max_packet_size, .. } = self;
        *max_packet_size = size;
    }
}
impl Default for KademliaProtocolConfig {
fn default() -> Self {
KademliaProtocolConfig {
protocol_name: DEFAULT_PROTO_NAME.into(),
max_packet_size: DEFAULT_MAX_PACKET_SIZE,
max_reuse_count: DEFAULT_MAX_REUSE_TRIES,
}
}
}
/// Swarm-side protocol handler for inbound Kademlia substreams.
///
/// Decoded requests and connection notifications are forwarded to the Kad main loop
/// through `poster`.
#[derive(Debug, Clone)]
pub struct KadProtocolHandler {
    /// Wire-protocol settings (protocol name, packet size limit).
    config: KademliaProtocolConfig,
    /// Set to `false` by `new`.
    /// NOTE(review): not read anywhere in the visible part of this file — confirm it is used.
    allow_listening: bool,
    /// Set to 10 seconds by `new`.
    /// NOTE(review): not read anywhere in the visible part of this file — confirm it is used.
    idle_timeout: Duration,
    /// Channel to the Kad main loop.
    poster: KadPoster,
}
impl KadProtocolHandler {
    /// Creates a handler that serves `config` and reports to the main loop via `poster`.
    ///
    /// Listening starts disabled and the idle timeout is 10 seconds.
    pub(crate) fn new(config: KademliaProtocolConfig, poster: KadPoster) -> Self {
        let idle_timeout = Duration::from_secs(10);
        Self {
            config,
            allow_listening: false,
            idle_timeout,
            poster,
        }
    }

    /// Returns the protocol id this handler answers to.
    pub fn protocol_name(&self) -> &ProtocolId {
        self.config.protocol_name()
    }

    /// Sets the maximum accepted inbound packet size, in bytes.
    pub fn set_max_packet_size(&mut self, size: usize) {
        self.config.set_max_packet_size(size)
    }
}
impl UpgradeInfo for KadProtocolHandler {
    type Info = ProtocolId;

    /// Advertises exactly one protocol: the configured Kademlia protocol name.
    fn protocol_info(&self) -> Vec<Self::Info> {
        let name: ProtocolId = self.config.protocol_name().to_owned();
        vec![name]
    }
}
// Every notification is forwarded to the Kad main loop on a best-effort basis: the result of
// `unbounded_post` is deliberately discarded.
impl Notifiee for KadProtocolHandler {
    /// Reports the remote peer of a newly established connection.
    fn connected(&mut self, conn: &mut Connection) {
        let event = ProtocolEvent::PeerConnected(conn.remote_peer());
        let _ = self.poster.unbounded_post(event);
    }

    /// Reports the remote peer of a closed connection.
    fn disconnected(&mut self, conn: &mut Connection) {
        let event = ProtocolEvent::PeerDisconnected(conn.remote_peer());
        let _ = self.poster.unbounded_post(event);
    }

    /// Reports that `peer_id` has been identified.
    fn identified(&mut self, peer_id: PeerId) {
        let event = ProtocolEvent::PeerIdentified(peer_id);
        let _ = self.poster.unbounded_post(event);
    }

    /// Reports a change of the addresses passed in by the swarm.
    fn address_changed(&mut self, addrs: Vec<Multiaddr>) {
        let event = ProtocolEvent::AddressChanged(addrs);
        let _ = self.poster.unbounded_post(event);
    }
}
#[async_trait]
impl ProtocolHandler for KadProtocolHandler {
    /// Serves one inbound Kademlia substream.
    ///
    /// Each loop iteration does the following:
    /// - reads one packet (bounded by `max_packet_size`);
    /// - decodes it into a `KadRequestMsg`;
    /// - hands it to the Kad main loop together with a oneshot reply channel, and waits for the reply;
    /// - writes the response back if the reply carries one.
    ///
    /// The loop has no normal exit; it ends when any of these steps fails, including a
    /// dropped reply channel.
    async fn handle(&mut self, mut stream: Substream, _info: <Self as UpgradeInfo>::Info) -> Result<(), Box<dyn Error>> {
        let source = stream.remote_peer();
        log::trace!("Kad Handler opened for remote {:?}", source);

        loop {
            let raw = stream.read_one(self.config.max_packet_size).await?;
            let decoded = proto::Message::decode(&raw[..]).map_err(|_| KadError::Decode)?;
            log::trace!("Kad handler recv : {:?}", decoded);
            let request = proto_to_req_msg(decoded)?;

            let (reply, reply_rx) = oneshot::channel();
            self.poster
                .post(ProtocolEvent::KadRequest { request, source, reply })
                .await?;

            // `None` means the main loop has nothing to send back for this request.
            let response = match reply_rx.await?? {
                Some(response) => response,
                None => continue,
            };

            let message = resp_msg_to_proto(response);
            let mut out = Vec::with_capacity(message.encoded_len());
            message.encode(&mut out).expect("Vec<u8> provides capacity as needed");
            stream.write_one(&out).await?;
        }
    }

    /// Clones this handler into a boxed trait object for the swarm.
    fn box_clone(&self) -> IProtocolHandler {
        let copy = self.clone();
        Box::new(copy)
    }
}
/// Outbound side of a Kademlia substream to a single peer: writes requests and reads responses.
pub(crate) struct KadMessenger {
    /// Substream opened to `peer` with the configured protocol name.
    pub(crate) stream: Substream,
    /// Wire-protocol settings (packet size limit, reuse bound).
    pub(crate) config: KademliaProtocolConfig,
    /// Remote peer this messenger talks to.
    peer: PeerId,
    /// Number of times `reuse()` has been called on this messenger.
    reuse: usize,
}
/// Debug-printable snapshot of a `KadMessenger`, produced by `KadMessenger::to_view`.
#[derive(Debug)]
pub struct KadMessengerView {
    /// Remote peer of the messenger.
    pub peer: PeerId,
    /// View of the underlying substream.
    pub stream: SubstreamView,
    /// Reuse counter at the time the view was taken.
    pub reuse: usize,
}
impl KadMessenger {
    /// Opens a new Kademlia substream to `peer`, without routing, using `config`'s protocol name.
    ///
    /// # Errors
    /// Propagates the swarm's error if the substream cannot be opened.
    pub(crate) async fn build(mut swarm: SwarmControl, peer: PeerId, config: KademliaProtocolConfig) -> Result<Self, KadError> {
        let stream = swarm.new_stream_no_routing(peer, vec![config.protocol_name().to_owned()]).await?;
        Ok(Self {
            stream,
            config,
            peer,
            reuse: 0,
        })
    }

    /// Returns a debug-printable snapshot of this messenger.
    pub(crate) fn to_view(&self) -> KadMessengerView {
        KadMessengerView {
            peer: self.peer,
            stream: self.stream.to_view(),
            reuse: self.reuse,
        }
    }

    /// Returns the remote peer this messenger talks to.
    pub(crate) fn get_peer_id(&self) -> &PeerId {
        &self.peer
    }

    /// Bumps the reuse counter and reports whether it is still below `max_reuse_count`.
    pub(crate) fn reuse(&mut self) -> bool {
        self.reuse += 1;
        self.reuse < self.config.max_reuse_count
    }

    /// Encodes `request` and writes it to the substream without waiting for a reply.
    async fn send_message(&mut self, request: KadRequestMsg) -> Result<(), KadError> {
        let proto_struct = req_msg_to_proto(request);
        let mut buf = Vec::with_capacity(proto_struct.encoded_len());
        proto_struct.encode(&mut buf).expect("Vec<u8> provides capacity as needed");
        self.stream.write_one(&buf).await?;
        Ok(())
    }

    /// Sends `request` and reads exactly one response packet back.
    ///
    /// # Errors
    /// I/O failures, undecodable packets and malformed responses are all returned as `KadError`.
    async fn send_request(&mut self, request: KadRequestMsg) -> Result<KadResponseMsg, KadError> {
        // Reuse the one-way send path instead of duplicating the encode/write logic.
        self.send_message(request).await?;
        let packet = self.stream.read_one(self.config.max_packet_size).await?;
        let response = proto::Message::decode(&packet[..]).map_err(|_| KadError::Decode)?;
        log::trace!("Kad messenger recv : {:?}", response);
        let response = proto_to_resp_msg(response)?;
        Ok(response)
    }

    /// Asks the peer for the peers it knows closest to `key`.
    pub(crate) async fn send_find_node(&mut self, key: record::Key) -> Result<Vec<KadPeer>, KadError> {
        let req = KadRequestMsg::FindNode { key };
        let rsp = self.send_request(req).await?;
        match rsp {
            KadResponseMsg::FindNode { closer_peers } => Ok(closer_peers),
            _ => Err(KadError::UnexpectedMessage("wrong message type received when FindNode")),
        }
    }

    /// Asks the peer for providers of `key`; returns `(closer_peers, provider_peers)`.
    pub(crate) async fn send_get_providers(&mut self, key: record::Key) -> Result<(Vec<KadPeer>, Vec<KadPeer>), KadError> {
        let req = KadRequestMsg::GetProviders { key };
        let rsp = self.send_request(req).await?;
        match rsp {
            KadResponseMsg::GetProviders {
                closer_peers,
                provider_peers,
            } => Ok((closer_peers, provider_peers)),
            _ => Err(KadError::UnexpectedMessage("wrong message type received when GetProviders")),
        }
    }

    /// Announces `provider_record`'s provider, reachable at `addresses`, to the peer.
    ///
    /// This is a one-way message: no response is read.
    pub(crate) async fn send_add_provider(
        &mut self,
        provider_record: ProviderRecord,
        addresses: Vec<Multiaddr>,
    ) -> Result<(), KadError> {
        let provider = KadPeer {
            node_id: provider_record.provider,
            multiaddrs: addresses,
            connection_ty: KadConnectionType::Connected,
        };
        let req = KadRequestMsg::AddProvider {
            key: provider_record.key,
            provider,
        };
        self.send_message(req).await
    }

    /// Asks the peer for the record stored under `key`; returns `(closer_peers, record)`.
    pub(crate) async fn send_get_value(&mut self, key: record::Key) -> Result<(Vec<KadPeer>, Option<Record>), KadError> {
        let req = KadRequestMsg::GetValue { key };
        let rsp = self.send_request(req).await?;
        match rsp {
            KadResponseMsg::GetValue { record, closer_peers } => Ok((closer_peers, record)),
            _ => Err(KadError::UnexpectedMessage("wrong message type received when GetValue")),
        }
    }

    /// Stores `record` on the peer; returns the `(key, value)` echoed back in the response.
    pub(crate) async fn send_put_value(&mut self, record: Record) -> Result<(record::Key, Vec<u8>), KadError> {
        let req = KadRequestMsg::PutValue { record };
        let rsp = self.send_request(req).await?;
        match rsp {
            KadResponseMsg::PutValue { key, value } => Ok((key, value)),
            _ => Err(KadError::UnexpectedMessage("wrong message type received when PutValue")),
        }
    }
}
/// A Kademlia request, independent of its protobuf encoding.
///
/// Converted to the wire form by `req_msg_to_proto` and parsed back by `proto_to_req_msg`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadRequestMsg {
    /// Liveness check.
    Ping,
    /// Asks for peers close to `key`.
    FindNode {
        key: record::Key,
    },
    /// Asks for providers of `key`.
    GetProviders {
        key: record::Key,
    },
    /// Announces `provider` as a provider of `key`.
    AddProvider {
        key: record::Key,
        provider: KadPeer,
    },
    /// Asks for the record stored under `key`.
    GetValue {
        key: record::Key,
    },
    /// Asks the remote to store `record`.
    PutValue { record: Record },
}
/// A Kademlia response, independent of its protobuf encoding.
///
/// Converted to the wire form by `resp_msg_to_proto` and parsed back by `proto_to_resp_msg`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadResponseMsg {
    /// Answer to `KadRequestMsg::Ping` (encoded with the `Ping` message type).
    Pong,
    /// Peers closer to the requested key.
    FindNode {
        closer_peers: Vec<KadPeer>,
    },
    /// Closer peers plus the known providers of the requested key.
    GetProviders {
        closer_peers: Vec<KadPeer>,
        provider_peers: Vec<KadPeer>,
    },
    /// The record for the requested key, if any, plus closer peers.
    GetValue {
        record: Option<Record>,
        closer_peers: Vec<KadPeer>,
    },
    /// Echo of the stored key and value.
    PutValue {
        key: record::Key,
        value: Vec<u8>,
    },
}
/// Encodes a `KadRequestMsg` into its protobuf message.
///
/// Key-bearing requests (FindNode, GetProviders, AddProvider, GetValue) carry
/// `cluster_level_raw = 10`. Ping and PutValue leave it at its default.
fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
    use proto::message::MessageType;

    let mut msg = proto::Message::default();
    let kind = match kad_msg {
        KadRequestMsg::Ping => MessageType::Ping,
        KadRequestMsg::FindNode { key } => {
            msg.key = key.to_vec();
            msg.cluster_level_raw = 10;
            MessageType::FindNode
        }
        KadRequestMsg::GetProviders { key } => {
            msg.key = key.to_vec();
            msg.cluster_level_raw = 10;
            MessageType::GetProviders
        }
        KadRequestMsg::AddProvider { key, provider } => {
            msg.key = key.to_vec();
            msg.cluster_level_raw = 10;
            msg.provider_peers = vec![provider.into()];
            MessageType::AddProvider
        }
        KadRequestMsg::GetValue { key } => {
            msg.key = key.to_vec();
            msg.cluster_level_raw = 10;
            MessageType::GetValue
        }
        KadRequestMsg::PutValue { record } => {
            msg.record = Some(record_to_proto(record));
            MessageType::PutValue
        }
    };
    msg.r#type = kind as i32;
    msg
}
fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
match kad_msg {
KadResponseMsg::Pong => proto::Message {
r#type: proto::message::MessageType::Ping as i32,
..proto::Message::default()
},
KadResponseMsg::FindNode { closer_peers } => proto::Message {
r#type: proto::message::MessageType::FindNode as i32,
cluster_level_raw: 9,
closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(),
..proto::Message::default()
},
KadResponseMsg::GetProviders {
closer_peers,
provider_peers,
} => proto::Message {
r#type: proto::message::MessageType::GetProviders as i32,
cluster_level_raw: 9,
closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(),
provider_peers: provider_peers.into_iter().map(KadPeer::into).collect(),
..proto::Message::default()
},
KadResponseMsg::GetValue { record, closer_peers } => proto::Message {
r#type: proto::message::MessageType::GetValue as i32,
cluster_level_raw: 9,
closer_peers: closer_peers.into_iter().map(KadPeer::into).collect(),
record: record.map(record_to_proto),
..proto::Message::default()
},
KadResponseMsg::PutValue { key, value } => proto::Message {
r#type: proto::message::MessageType::PutValue as i32,
key: key.to_vec(),
record: Some(proto::Record {
key: key.to_vec(),
value,
..proto::Record::default()
}),
..proto::Message::default()
},
}
}
/// Parses an inbound protobuf message into a `KadRequestMsg`.
///
/// # Errors
/// Returns `InvalidData` in three cases:
/// - the message type is unknown;
/// - a PutValue record cannot be decoded (a missing record is treated as a default record);
/// - an AddProvider message contains no parseable provider peer.
fn proto_to_req_msg(message: proto::Message) -> Result<KadRequestMsg, io::Error> {
    use proto::message::MessageType;

    let raw_type = message.r#type;
    let msg_type = match MessageType::from_i32(raw_type) {
        Some(kind) => kind,
        None => return Err(invalid_data(format!("unknown message type: {}", raw_type))),
    };

    let request = match msg_type {
        MessageType::Ping => KadRequestMsg::Ping,
        MessageType::PutValue => KadRequestMsg::PutValue {
            record: record_from_proto(message.record.unwrap_or_default())?,
        },
        MessageType::GetValue => KadRequestMsg::GetValue {
            key: record::Key::from(message.key),
        },
        MessageType::FindNode => KadRequestMsg::FindNode {
            key: record::Key::from(message.key),
        },
        MessageType::GetProviders => KadRequestMsg::GetProviders {
            key: record::Key::from(message.key),
        },
        MessageType::AddProvider => {
            // The first provider entry that parses wins; malformed ones are skipped.
            let provider = message
                .provider_peers
                .into_iter()
                .find_map(|peer| KadPeer::try_from(peer).ok())
                .ok_or_else(|| invalid_data("AddProvider message with no valid peer."))?;
            let key = record::Key::from(message.key);
            KadRequestMsg::AddProvider { key, provider }
        }
    };
    Ok(request)
}
fn proto_to_resp_msg(message: proto::Message) -> Result<KadResponseMsg, io::Error> {
let msg_type = proto::message::MessageType::from_i32(message.r#type)
.ok_or_else(|| invalid_data(format!("unknown message type: {}", message.r#type)))?;
match msg_type {
proto::message::MessageType::Ping => Ok(KadResponseMsg::Pong),
proto::message::MessageType::GetValue => {
let record = if let Some(r) = message.record {
Some(record_from_proto(r)?)
} else {
None
};
let closer_peers = message
.closer_peers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
Ok(KadResponseMsg::GetValue { record, closer_peers })
}
proto::message::MessageType::FindNode => {
let closer_peers = message
.closer_peers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
Ok(KadResponseMsg::FindNode { closer_peers })
}
proto::message::MessageType::GetProviders => {
let closer_peers = message
.closer_peers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
let provider_peers = message
.provider_peers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
Ok(KadResponseMsg::GetProviders {
closer_peers,
provider_peers,
})
}
proto::message::MessageType::PutValue => {
let key = record::Key::from(message.key);
let rec = message
.record
.ok_or_else(|| invalid_data("received PutValue message with no record"))?;
Ok(KadResponseMsg::PutValue { key, value: rec.value })
}
proto::message::MessageType::AddProvider => Err(invalid_data("received an unexpected AddProvider message")),
}
}
/// Decodes a protobuf record into a local `Record`.
///
/// An empty publisher field means "no publisher". A positive `ttl` is turned into an absolute
/// expiry measured from now; `ttl == 0` means the record does not expire.
///
/// # Errors
/// Returns `InvalidData` if a non-empty publisher is not a valid peer id.
fn record_from_proto(record: proto::Record) -> Result<Record, io::Error> {
    let key = record::Key::from(record.key);
    let value = record.value;

    let publisher = if record.publisher.is_empty() {
        None
    } else {
        let id = PeerId::from_bytes(&record.publisher).map_err(|_| invalid_data("Invalid publisher peer ID."))?;
        Some(id)
    };

    let expires = if record.ttl > 0 {
        let lifetime = Duration::from_secs(record.ttl as u64);
        Some(Instant::now() + lifetime)
    } else {
        None
    };

    Ok(Record {
        key,
        value,
        publisher,
        expires,
    })
}
fn record_to_proto(record: Record) -> proto::Record {
proto::Record {
key: record.key.to_vec(),
value: record.value,
publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(),
ttl: record
.expires
.map(|t| {
let now = Instant::now();
if t > now {
(t - now).as_secs() as u32
} else {
1
}
})
.unwrap_or(0),
time_received: String::new(),
}
}
/// Wraps `e` in an `io::Error` of kind `InvalidData`.
///
/// Accepts anything convertible into a boxed error, such as `&str`, `String` or a concrete
/// error type.
fn invalid_data<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
    let boxed: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::InvalidData, boxed)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `KadConnectionType` survives a round trip through its protobuf counterpart.
    #[test]
    fn connection_type_round_trip() {
        let all = [
            KadConnectionType::NotConnected,
            KadConnectionType::Connected,
            KadConnectionType::CanConnect,
            KadConnectionType::CannotConnect,
        ];
        for ty in all.iter().copied() {
            let raw: proto::message::ConnectionType = ty.into();
            assert_eq!(KadConnectionType::from(raw), ty);
        }
    }

    /// `invalid_data` yields an `InvalidData` error that carries the given message.
    #[test]
    fn invalid_data_sets_kind_and_message() {
        let err = invalid_data("boom");
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "boom");
    }

    /// A record without publisher or expiry is unchanged by an encode/decode round trip.
    #[test]
    fn record_without_expiry_round_trips() {
        let original = Record {
            key: record::Key::from(b"key".to_vec()),
            value: b"value".to_vec(),
            publisher: None,
            expires: None,
        };
        let decoded = record_from_proto(record_to_proto(original.clone())).expect("valid record");
        assert_eq!(decoded, original);
    }

    /// An unknown message type is rejected instead of being mapped to some request.
    #[test]
    fn unknown_message_type_is_rejected() {
        let msg = proto::Message {
            r#type: 1000,
            ..proto::Message::default()
        };
        let err = proto_to_req_msg(msg).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }
}
/// Stage marker carried by `ProtocolEvent::Refresh`.
///
/// NOTE(review): from the names, this presumably tracks the progress of a routing-table refresh
/// (start → self-query finished → completed) — confirm against the Kad main loop.
#[derive(Debug)]
pub enum RefreshStage {
    /// Optional reply channel for a `Result<(), KadError>` outcome.
    /// Presumably used to notify whoever requested the refresh — verify against the caller.
    Start(Option<oneshot::Sender<Result<(), KadError>>>),
    /// Same optional reply channel, handed on to the next stage.
    SelfQueryDone(Option<oneshot::Sender<Result<(), KadError>>>),
    Completed,
}
/// Events delivered to the Kad main loop through `KadPoster`.
///
/// `KadProtocolHandler` posts the connection notifications and `KadRequest`. The remaining
/// variants are produced elsewhere (not visible in this file).
#[derive(Debug)]
pub(crate) enum ProtocolEvent {
    Refresh(RefreshStage),
    /// A connection to this peer was established.
    PeerConnected(PeerId),
    /// A connection to this peer was closed.
    PeerDisconnected(PeerId),
    /// The swarm reported that this peer was identified.
    PeerIdentified(PeerId),
    /// The swarm reported a new list of addresses.
    AddressChanged(Vec<Multiaddr>),
    /// NOTE(review): meaning of the `bool` is not visible here — confirm with the producer.
    KadPeerFound(PeerId, bool),
    KadPeerStopped(PeerId),
    IterativeQueryCompleted(QueryStats),
    IterativeQueryTimeout,
    ProviderCleanupTimer,
    RefreshTimer,
    /// An inbound request decoded by `KadProtocolHandler::handle`.
    KadRequest {
        /// The decoded request.
        request: KadRequestMsg,
        /// Peer that sent the request.
        source: PeerId,
        /// Reply channel. `Ok(Some(_))` is written back to the remote, `Ok(None)` sends nothing,
        /// and `Err(_)` (or dropping the sender) terminates the inbound substream handler.
        reply: oneshot::Sender<Result<Option<KadResponseMsg>, KadError>>,
    },
}