use std::{io, marker::PhantomData, time::Duration};
use asynchronous_codec::{Decoder, Encoder, Framed};
use bytes::BytesMut;
use futures::prelude::*;
use libp2p_core::{
upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
Multiaddr,
};
use libp2p_identity::PeerId;
use libp2p_swarm::StreamProtocol;
use tracing::debug;
use web_time::Instant;
use crate::{
proto,
record::{self, Record},
};
/// Default stream protocol name (`/ipfs/kad/1.0.0`).
pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0");
/// Default `max_packet_size` of a new `ProtocolConfig` (16 KiB); it is handed to the
/// protobuf codec of every upgraded substream.
pub(crate) const DEFAULT_MAX_PACKET_SIZE: usize = 16 * 1024;
/// Default substream timeout of a new `ProtocolConfig`.
/// NOTE: despite the `_S` suffix this is a `Duration`, not a number of seconds.
const DEFAULT_SUBSTREAMS_TIMEOUT_S: Duration = Duration::from_secs(10);
/// Connection status attached to a peer entry.
///
/// Convertible to and from `proto::ConnectionType`. The explicit discriminants
/// match the numeric order of the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionType {
    /// Wire value `NOT_CONNECTED`.
    NotConnected = 0,
    /// Wire value `CONNECTED`.
    Connected = 1,
    /// Wire value `CAN_CONNECT`.
    CanConnect = 2,
    /// Wire value `CANNOT_CONNECT`.
    CannotConnect = 3,
}
impl From<proto::ConnectionType> for ConnectionType {
    /// Maps each protobuf connection state onto the corresponding Rust variant.
    fn from(raw: proto::ConnectionType) -> ConnectionType {
        match raw {
            proto::ConnectionType::NOT_CONNECTED => Self::NotConnected,
            proto::ConnectionType::CONNECTED => Self::Connected,
            proto::ConnectionType::CAN_CONNECT => Self::CanConnect,
            proto::ConnectionType::CANNOT_CONNECT => Self::CannotConnect,
        }
    }
}
impl From<ConnectionType> for proto::ConnectionType {
    /// Maps a connection state onto its protobuf enum value.
    fn from(val: ConnectionType) -> Self {
        match val {
            ConnectionType::NotConnected => proto::ConnectionType::NOT_CONNECTED,
            ConnectionType::Connected => proto::ConnectionType::CONNECTED,
            ConnectionType::CanConnect => proto::ConnectionType::CAN_CONNECT,
            ConnectionType::CannotConnect => proto::ConnectionType::CANNOT_CONNECT,
        }
    }
}
/// A peer entry as carried inside request and response messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KadPeer {
    /// Identity of the peer.
    pub node_id: PeerId,
    /// Addresses of the peer. When decoded from protobuf, each address carries
    /// a `/p2p/<node_id>` component.
    pub multiaddrs: Vec<Multiaddr>,
    /// Connection status reported alongside the peer.
    pub connection_ty: ConnectionType,
}
impl TryFrom<proto::Peer> for KadPeer {
type Error = io::Error;
fn try_from(peer: proto::Peer) -> Result<KadPeer, Self::Error> {
let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?;
let mut addrs = Vec::with_capacity(peer.addrs.len());
for addr in peer.addrs.into_iter() {
match Multiaddr::try_from(addr).map(|addr| addr.with_p2p(node_id)) {
Ok(Ok(a)) => addrs.push(a),
Ok(Err(a)) => {
debug!("Unable to parse multiaddr: {a} is not compatible with {node_id}")
}
Err(e) => debug!("Unable to parse multiaddr: {e}"),
};
}
Ok(KadPeer {
node_id,
multiaddrs: addrs,
connection_ty: peer.connection.into(),
})
}
}
impl From<KadPeer> for proto::Peer {
fn from(peer: KadPeer) -> Self {
proto::Peer {
id: peer.node_id.to_bytes(),
addrs: peer.multiaddrs.into_iter().map(|a| a.to_vec()).collect(),
connection: peer.connection_ty.into(),
}
}
}
/// Configuration of the substream protocol; it also serves as the inbound and
/// outbound upgrade that wraps a negotiated stream in the protobuf codec.
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
    /// Names returned by `protocol_info` during negotiation.
    protocol_names: Vec<StreamProtocol>,
    /// Passed to the protobuf codec of every upgraded substream.
    max_packet_size: usize,
    /// Timeout exposed through `substreams_timeout_s` (a `Duration`, despite the name).
    substreams_timeout_s: Duration,
}

impl ProtocolConfig {
    /// Creates a configuration that advertises only `protocol_name`, using
    /// `DEFAULT_MAX_PACKET_SIZE` and `DEFAULT_SUBSTREAMS_TIMEOUT_S`.
    pub fn new(protocol_name: StreamProtocol) -> Self {
        Self {
            protocol_names: vec![protocol_name],
            max_packet_size: DEFAULT_MAX_PACKET_SIZE,
            substreams_timeout_s: DEFAULT_SUBSTREAMS_TIMEOUT_S,
        }
    }

    /// Returns the configured protocol names.
    pub fn protocol_names(&self) -> &[StreamProtocol] {
        self.protocol_names.as_slice()
    }

    /// Sets the `max_packet_size` handed to the protobuf codec on future upgrades.
    pub fn set_max_packet_size(&mut self, size: usize) {
        self.max_packet_size = size;
    }

    /// Sets the substream timeout.
    pub fn set_substreams_timeout(&mut self, timeout: Duration) {
        self.substreams_timeout_s = timeout;
    }

    /// Returns the substream timeout.
    pub fn substreams_timeout_s(&self) -> Duration {
        self.substreams_timeout_s
    }
}
impl UpgradeInfo for ProtocolConfig {
    type Info = StreamProtocol;
    type InfoIter = std::vec::IntoIter<Self::Info>;

    /// Yields a copy of every configured protocol name, in configuration order.
    fn protocol_info(&self) -> Self::InfoIter {
        let names: Vec<StreamProtocol> = self.protocol_names.to_vec();
        names.into_iter()
    }
}
/// Framing codec for the substream.
///
/// It wraps a protobuf codec for `proto::Message`, converting outgoing `A`
/// values into messages and incoming messages into `B` values.
pub struct Codec<A, B> {
    codec: quick_protobuf_codec::Codec<proto::Message>,
    __phantom: PhantomData<(A, B)>,
}

impl<A, B> Codec<A, B> {
    /// Creates a codec whose inner protobuf codec is configured with `max_packet_size`.
    fn new(max_packet_size: usize) -> Self {
        let codec = quick_protobuf_codec::Codec::new(max_packet_size);
        Self {
            codec,
            __phantom: PhantomData,
        }
    }
}
impl<A: Into<proto::Message>, B> Encoder for Codec<A, B> {
    type Error = io::Error;
    type Item<'a> = A;

    /// Converts `item` into a protobuf message and appends its framed encoding
    /// to `dst`. Codec errors are converted into `io::Error`.
    fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let message: proto::Message = item.into();
        self.codec.encode(message, dst).map_err(io::Error::from)
    }
}
impl<A, B: TryFrom<proto::Message, Error = io::Error>> Decoder for Codec<A, B> {
    type Error = io::Error;
    type Item = B;

    /// Decodes one framed protobuf message from `src` and converts it into `B`.
    ///
    /// Returns `Ok(None)` when the inner codec yields no complete message.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        match self.codec.decode(src)? {
            Some(message) => B::try_from(message).map(Some),
            None => Ok(None),
        }
    }
}
/// Framed substream produced by the `InboundUpgrade` impl: encodes `KadResponseMsg`
/// values and decodes `KadRequestMsg` values.
pub(crate) type KadInStreamSink<S> = Framed<S, Codec<KadResponseMsg, KadRequestMsg>>;
/// Framed substream produced by the `OutboundUpgrade` impl: encodes `KadRequestMsg`
/// values and decodes `KadResponseMsg` values.
pub(crate) type KadOutStreamSink<S> = Framed<S, Codec<KadRequestMsg, KadResponseMsg>>;
impl<C> InboundUpgrade<C> for ProtocolConfig
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    type Output = KadInStreamSink<C>;
    type Future = future::Ready<Result<Self::Output, io::Error>>;
    type Error = io::Error;

    /// Wraps the inbound substream in a framed codec that decodes requests and
    /// encodes responses. The returned future is already resolved and never fails.
    fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
        let framed = Framed::new(incoming, Codec::new(self.max_packet_size));
        future::ready(Ok(framed))
    }
}
impl<C> OutboundUpgrade<C> for ProtocolConfig
where
    C: AsyncRead + AsyncWrite + Unpin,
{
    type Output = KadOutStreamSink<C>;
    type Future = future::Ready<Result<Self::Output, io::Error>>;
    type Error = io::Error;

    /// Wraps the outbound substream in a framed codec that encodes requests and
    /// decodes responses. The returned future is already resolved and never fails.
    fn upgrade_outbound(self, outgoing: C, _: Self::Info) -> Self::Future {
        let framed = Framed::new(outgoing, Codec::new(self.max_packet_size));
        future::ready(Ok(framed))
    }
}
/// A request sent over the substream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadRequestMsg {
    /// Encoded as a `PING` message with no payload.
    Ping,
    /// Encoded as `FIND_NODE`; the key is kept as raw bytes rather than a `record::Key`.
    FindNode { key: Vec<u8> },
    /// Encoded as `GET_PROVIDERS` for `key`.
    GetProviders { key: record::Key },
    /// Encoded as `ADD_PROVIDER`, carrying `provider` as the single provider peer.
    AddProvider {
        key: record::Key,
        provider: KadPeer,
    },
    /// Encoded as `GET_VALUE` for `key`.
    GetValue { key: record::Key },
    /// Encoded as `PUT_VALUE`, carrying the full record.
    PutValue { record: Record },
}
/// A response sent over the substream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KadResponseMsg {
    /// Answer to a ping; encoded as a `PING` message.
    Pong,
    /// Encoded as `FIND_NODE` with the listed closer peers.
    FindNode { closer_peers: Vec<KadPeer> },
    /// Encoded as `GET_PROVIDERS` with closer peers and provider peers.
    GetProviders {
        closer_peers: Vec<KadPeer>,
        provider_peers: Vec<KadPeer>,
    },
    /// Encoded as `GET_VALUE`; `record` is absent when the message carries none.
    GetValue {
        record: Option<Record>,
        closer_peers: Vec<KadPeer>,
    },
    /// Encoded as `PUT_VALUE` with `key` and a record holding `value`.
    PutValue { key: record::Key, value: Vec<u8> },
}
impl From<KadRequestMsg> for proto::Message {
    /// Encodes a request; delegates to `req_msg_to_proto`.
    fn from(request: KadRequestMsg) -> Self {
        req_msg_to_proto(request)
    }
}

impl From<KadResponseMsg> for proto::Message {
    /// Encodes a response; delegates to `resp_msg_to_proto`.
    fn from(response: KadResponseMsg) -> Self {
        resp_msg_to_proto(response)
    }
}
impl TryFrom<proto::Message> for KadRequestMsg {
type Error = io::Error;
fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
proto_to_req_msg(message)
}
}
impl TryFrom<proto::Message> for KadResponseMsg {
type Error = io::Error;
fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
proto_to_resp_msg(message)
}
}
/// Encodes an outgoing request as a protobuf message.
///
/// Every variant sets `type_pb` to its `MessageType`. All variants except
/// `Ping` and `PutValue` also set `clusterLevelRaw` to 10. Fields a variant
/// does not use keep their `proto::Message::default()` values.
fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
    // Value written into `clusterLevelRaw` by the key-based query requests.
    let cluster_level = 10;
    // Each arm moves `base` at most once, because only one arm runs.
    let base = proto::Message::default();

    match kad_msg {
        KadRequestMsg::Ping => proto::Message {
            type_pb: proto::MessageType::PING,
            ..base
        },
        KadRequestMsg::FindNode { key } => proto::Message {
            type_pb: proto::MessageType::FIND_NODE,
            key,
            clusterLevelRaw: cluster_level,
            ..base
        },
        KadRequestMsg::GetProviders { key } => proto::Message {
            type_pb: proto::MessageType::GET_PROVIDERS,
            key: key.to_vec(),
            clusterLevelRaw: cluster_level,
            ..base
        },
        KadRequestMsg::AddProvider { key, provider } => proto::Message {
            type_pb: proto::MessageType::ADD_PROVIDER,
            clusterLevelRaw: cluster_level,
            key: key.to_vec(),
            providerPeers: vec![proto::Peer::from(provider)],
            ..base
        },
        KadRequestMsg::GetValue { key } => proto::Message {
            type_pb: proto::MessageType::GET_VALUE,
            clusterLevelRaw: cluster_level,
            key: key.to_vec(),
            ..base
        },
        KadRequestMsg::PutValue { record } => {
            // Copy the key out before `record` is consumed by the conversion.
            let key = record.key.to_vec();
            proto::Message {
                type_pb: proto::MessageType::PUT_VALUE,
                key,
                record: Some(record_to_proto(record)),
                ..base
            }
        }
    }
}
/// Encodes an outgoing response as a protobuf message.
///
/// Every variant sets `type_pb` to its `MessageType`. The peer-bearing
/// variants (`FindNode`, `GetProviders`, `GetValue`) also set
/// `clusterLevelRaw` to 9. Unused fields keep their `proto::Message::default()`
/// values.
fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
    // Value written into `clusterLevelRaw` by the peer-bearing responses.
    let cluster_level = 9;
    // Each arm moves `base` at most once, because only one arm runs.
    let base = proto::Message::default();

    match kad_msg {
        KadResponseMsg::Pong => proto::Message {
            type_pb: proto::MessageType::PING,
            ..base
        },
        KadResponseMsg::FindNode { closer_peers } => proto::Message {
            type_pb: proto::MessageType::FIND_NODE,
            clusterLevelRaw: cluster_level,
            closerPeers: closer_peers.into_iter().map(proto::Peer::from).collect(),
            ..base
        },
        KadResponseMsg::GetProviders {
            closer_peers,
            provider_peers,
        } => proto::Message {
            type_pb: proto::MessageType::GET_PROVIDERS,
            clusterLevelRaw: cluster_level,
            closerPeers: closer_peers.into_iter().map(proto::Peer::from).collect(),
            providerPeers: provider_peers.into_iter().map(proto::Peer::from).collect(),
            ..base
        },
        KadResponseMsg::GetValue {
            record,
            closer_peers,
        } => proto::Message {
            type_pb: proto::MessageType::GET_VALUE,
            clusterLevelRaw: cluster_level,
            closerPeers: closer_peers.into_iter().map(proto::Peer::from).collect(),
            record: record.map(record_to_proto),
            ..base
        },
        KadResponseMsg::PutValue { key, value } => {
            // The key is sent twice: on the message itself and inside the record.
            let key_bytes = key.to_vec();
            proto::Message {
                type_pb: proto::MessageType::PUT_VALUE,
                key: key_bytes.clone(),
                record: Some(proto::Record {
                    key: key_bytes,
                    value,
                    ..proto::Record::default()
                }),
                ..base
            }
        }
    }
}
/// Decodes an incoming protobuf message as a request.
///
/// # Errors
///
/// Returns an error if a `PUT_VALUE` record fails `record_from_proto`, or if an
/// `ADD_PROVIDER` message contains no provider peer that decodes successfully.
fn proto_to_req_msg(message: proto::Message) -> Result<KadRequestMsg, io::Error> {
    let request = match message.type_pb {
        proto::MessageType::PING => KadRequestMsg::Ping,
        proto::MessageType::PUT_VALUE => {
            // NOTE(review): a PUT_VALUE request without a record is decoded from
            // `proto::Record::default()` instead of being rejected, whereas
            // `proto_to_resp_msg` treats a missing record as invalid data.
            let record = record_from_proto(message.record.unwrap_or_default())?;
            KadRequestMsg::PutValue { record }
        }
        proto::MessageType::GET_VALUE => KadRequestMsg::GetValue {
            key: record::Key::from(message.key),
        },
        proto::MessageType::FIND_NODE => KadRequestMsg::FindNode { key: message.key },
        proto::MessageType::GET_PROVIDERS => KadRequestMsg::GetProviders {
            key: record::Key::from(message.key),
        },
        proto::MessageType::ADD_PROVIDER => {
            // Only the first provider entry that decodes cleanly is used.
            let provider = message
                .providerPeers
                .into_iter()
                .find_map(|peer| KadPeer::try_from(peer).ok())
                .ok_or_else(|| invalid_data("AddProvider message with no valid peer."))?;
            KadRequestMsg::AddProvider {
                key: record::Key::from(message.key),
                provider,
            }
        }
    };
    Ok(request)
}
fn proto_to_resp_msg(message: proto::Message) -> Result<KadResponseMsg, io::Error> {
match message.type_pb {
proto::MessageType::PING => Ok(KadResponseMsg::Pong),
proto::MessageType::GET_VALUE => {
let record = if let Some(r) = message.record {
Some(record_from_proto(r)?)
} else {
None
};
let closer_peers = message
.closerPeers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
Ok(KadResponseMsg::GetValue {
record,
closer_peers,
})
}
proto::MessageType::FIND_NODE => {
let closer_peers = message
.closerPeers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
Ok(KadResponseMsg::FindNode { closer_peers })
}
proto::MessageType::GET_PROVIDERS => {
let closer_peers = message
.closerPeers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
let provider_peers = message
.providerPeers
.into_iter()
.filter_map(|peer| KadPeer::try_from(peer).ok())
.collect();
Ok(KadResponseMsg::GetProviders {
closer_peers,
provider_peers,
})
}
proto::MessageType::PUT_VALUE => {
let key = record::Key::from(message.key);
let rec = message
.record
.ok_or_else(|| invalid_data("received PutValue message with no record"))?;
Ok(KadResponseMsg::PutValue {
key,
value: rec.value,
})
}
proto::MessageType::ADD_PROVIDER => {
Err(invalid_data("received an unexpected AddProvider message"))
}
}
}
/// Converts a protobuf record into a [`Record`].
///
/// An empty `publisher` field means "no publisher". A `ttl` of 0 means the
/// record does not expire; otherwise the expiry is `ttl` seconds from now.
///
/// # Errors
///
/// Returns an `InvalidData` error if a non-empty `publisher` is not a valid peer ID.
fn record_from_proto(record: proto::Record) -> Result<Record, io::Error> {
    let publisher = match record.publisher.as_slice() {
        [] => None,
        bytes => Some(
            PeerId::from_bytes(bytes).map_err(|_| invalid_data("Invalid publisher peer ID."))?,
        ),
    };

    let expires =
        (record.ttl > 0).then(|| Instant::now() + Duration::from_secs(u64::from(record.ttl)));

    Ok(Record {
        key: record::Key::from(record.key),
        value: record.value,
        publisher,
        expires,
    })
}
/// Converts a [`Record`] into its protobuf representation.
///
/// The absolute expiry is turned into a relative TTL in whole seconds.
/// `record_from_proto` decodes a TTL of 0 as "does not expire", so:
/// - a record without an expiry is encoded with TTL 0;
/// - a record with an expiry always gets a TTL of at least 1. This covers
///   already-expired records and records with less than one second left,
///   which would otherwise truncate to 0;
/// - a TTL too large for `u32` saturates at `u32::MAX` instead of wrapping.
fn record_to_proto(record: Record) -> proto::Record {
    let ttl = record
        .expires
        .map(|expires| {
            let now = Instant::now();
            let remaining_secs = if expires > now {
                (expires - now).as_secs()
            } else {
                0
            };
            u32::try_from(remaining_secs).unwrap_or(u32::MAX).max(1)
        })
        .unwrap_or(0);

    proto::Record {
        key: record.key.to_vec(),
        value: record.value,
        publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(),
        ttl,
        timeReceived: String::new(),
    }
}
/// Wraps `e` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
///
/// Used as the single error constructor for malformed wire data in this module.
fn invalid_data<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, e)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Address shared by the tests; it has no `/p2p` component of its own.
    fn base_addr() -> Multiaddr {
        "/ip6/2001:db8::/tcp/1234".parse::<Multiaddr>().unwrap()
    }

    /// Builds a `CAN_CONNECT` protobuf peer entry for `peer_id` with raw `addrs`.
    fn peer_entry(peer_id: PeerId, addrs: Vec<Vec<u8>>) -> proto::Peer {
        proto::Peer {
            id: peer_id.to_bytes(),
            addrs,
            connection: proto::ConnectionType::CAN_CONNECT,
        }
    }

    #[test]
    fn append_p2p() {
        let peer_id = PeerId::random();
        let addr = base_addr();

        let decoded = KadPeer::try_from(peer_entry(peer_id, vec![addr.to_vec()])).unwrap();

        let expected = addr.with_p2p(peer_id).unwrap();
        assert_eq!(decoded.multiaddrs, vec![expected])
    }

    #[test]
    fn skip_invalid_multiaddr() {
        let peer_id = PeerId::random();
        let other_peer_id = PeerId::random();
        assert_ne!(peer_id, other_peer_id);

        let valid = base_addr().with_p2p(peer_id).unwrap();
        let foreign = base_addr().with_p2p(other_peer_id).unwrap();
        let garbage: Vec<u8> = vec![255; 8];
        assert!(Multiaddr::try_from(garbage.clone()).is_err());

        let payload = peer_entry(peer_id, vec![valid.to_vec(), foreign.to_vec(), garbage]);
        let decoded = KadPeer::try_from(payload).unwrap();

        assert_eq!(decoded.multiaddrs, vec![valid])
    }
}