1use ant_libp2p_core as libp2p_core;
30use ant_libp2p_swarm as libp2p_swarm;
31
32use std::{io, iter, marker::PhantomData, time::Duration};
33
34use asynchronous_codec::{Decoder, Encoder, Framed};
35use bytes::BytesMut;
36use futures::prelude::*;
37use libp2p_core::{
38 upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
39 Multiaddr,
40};
41use libp2p_identity::PeerId;
42use libp2p_swarm::StreamProtocol;
43use tracing::debug;
44use web_time::Instant;
45
46use crate::{
47 proto,
48 record::{self, Record},
49};
50
/// Protocol name advertised by a `ProtocolConfig` created through its `Default` impl.
pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0");
/// Default `max_packet_size` of a `ProtocolConfig` (16 * 1024; presumably bytes — the value is
/// passed unchanged to `quick_protobuf_codec::Codec::new` when a stream is upgraded).
pub(crate) const DEFAULT_MAX_PACKET_SIZE: usize = 16 * 1024;
/// Connection status attached to a `KadPeer`.
///
/// Each variant has an explicit discriminant and a counterpart of the same meaning in
/// `proto::ConnectionType`, to and from which this type converts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionType {
    /// Counterpart of `proto::ConnectionType::NOT_CONNECTED`.
    NotConnected = 0,
    /// Counterpart of `proto::ConnectionType::CONNECTED`.
    Connected = 1,
    /// Counterpart of `proto::ConnectionType::CAN_CONNECT`.
    CanConnect = 2,
    /// Counterpart of `proto::ConnectionType::CANNOT_CONNECT`.
    CannotConnect = 3,
}
67
68impl From<proto::ConnectionType> for ConnectionType {
69 fn from(raw: proto::ConnectionType) -> ConnectionType {
70 use proto::ConnectionType::*;
71 match raw {
72 NOT_CONNECTED => ConnectionType::NotConnected,
73 CONNECTED => ConnectionType::Connected,
74 CAN_CONNECT => ConnectionType::CanConnect,
75 CANNOT_CONNECT => ConnectionType::CannotConnect,
76 }
77 }
78}
79
80impl From<ConnectionType> for proto::ConnectionType {
81 fn from(val: ConnectionType) -> Self {
82 use proto::ConnectionType::*;
83 match val {
84 ConnectionType::NotConnected => NOT_CONNECTED,
85 ConnectionType::Connected => CONNECTED,
86 ConnectionType::CanConnect => CAN_CONNECT,
87 ConnectionType::CannotConnect => CANNOT_CONNECT,
88 }
89 }
90}
91
92#[derive(Debug, Clone, PartialEq, Eq)]
94pub struct KadPeer {
95 pub node_id: PeerId,
97 pub multiaddrs: Vec<Multiaddr>,
99 pub connection_ty: ConnectionType,
101}
102
103impl TryFrom<proto::Peer> for KadPeer {
105 type Error = io::Error;
106
107 fn try_from(peer: proto::Peer) -> Result<KadPeer, Self::Error> {
108 let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?;
111
112 let mut addrs = Vec::with_capacity(peer.addrs.len());
113 for addr in peer.addrs.into_iter() {
114 match Multiaddr::try_from(addr).map(|addr| addr.with_p2p(node_id)) {
115 Ok(Ok(a)) => addrs.push(a),
116 Ok(Err(a)) => {
117 debug!("Unable to parse multiaddr: {a} is not compatible with {node_id}")
118 }
119 Err(e) => debug!("Unable to parse multiaddr: {e}"),
120 };
121 }
122
123 Ok(KadPeer {
124 node_id,
125 multiaddrs: addrs,
126 connection_ty: peer.connection.into(),
127 })
128 }
129}
130
131impl From<KadPeer> for proto::Peer {
132 fn from(peer: KadPeer) -> Self {
133 proto::Peer {
134 id: peer.node_id.to_bytes(),
135 addrs: peer.multiaddrs.into_iter().map(|a| a.to_vec()).collect(),
136 connection: peer.connection_ty.into(),
137 }
138 }
139}
140
141#[derive(Debug, Clone)]
147pub struct ProtocolConfig {
148 protocol_names: Vec<StreamProtocol>,
149 max_packet_size: usize,
151}
152
153impl ProtocolConfig {
154 pub fn new(protocol_name: StreamProtocol) -> Self {
156 ProtocolConfig {
157 protocol_names: vec![protocol_name],
158 max_packet_size: DEFAULT_MAX_PACKET_SIZE,
159 }
160 }
161
162 #[deprecated(note = "Use `ProtocolConfig::new` instead")]
164 #[allow(clippy::should_implement_trait)]
165 pub fn default() -> Self {
166 Default::default()
167 }
168
169 pub fn protocol_names(&self) -> &[StreamProtocol] {
171 &self.protocol_names
172 }
173
174 #[deprecated(note = "Use `ProtocolConfig::new` instead")]
177 pub fn set_protocol_names(&mut self, names: Vec<StreamProtocol>) {
178 self.protocol_names = names;
179 }
180
181 pub fn set_max_packet_size(&mut self, size: usize) {
183 self.max_packet_size = size;
184 }
185}
186
187impl Default for ProtocolConfig {
188 fn default() -> Self {
192 ProtocolConfig {
193 protocol_names: iter::once(DEFAULT_PROTO_NAME).collect(),
194 max_packet_size: DEFAULT_MAX_PACKET_SIZE,
195 }
196 }
197}
198
199impl UpgradeInfo for ProtocolConfig {
200 type Info = StreamProtocol;
201 type InfoIter = std::vec::IntoIter<Self::Info>;
202
203 fn protocol_info(&self) -> Self::InfoIter {
204 self.protocol_names.clone().into_iter()
205 }
206}
207
208pub struct Codec<A, B> {
210 codec: quick_protobuf_codec::Codec<proto::Message>,
211 __phantom: PhantomData<(A, B)>,
212}
213impl<A, B> Codec<A, B> {
214 fn new(max_packet_size: usize) -> Self {
215 Codec {
216 codec: quick_protobuf_codec::Codec::new(max_packet_size),
217 __phantom: PhantomData,
218 }
219 }
220}
221
222impl<A: Into<proto::Message>, B> Encoder for Codec<A, B> {
223 type Error = io::Error;
224 type Item<'a> = A;
225
226 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
227 Ok(self.codec.encode(item.into(), dst)?)
228 }
229}
230impl<A, B: TryFrom<proto::Message, Error = io::Error>> Decoder for Codec<A, B> {
231 type Error = io::Error;
232 type Item = B;
233
234 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
235 self.codec.decode(src)?.map(B::try_from).transpose()
236 }
237}
238
/// Framed stream produced by the inbound upgrade: its codec encodes `KadResponseMsg` and
/// decodes `KadRequestMsg`.
pub(crate) type KadInStreamSink<S> = Framed<S, Codec<KadResponseMsg, KadRequestMsg>>;
/// Framed stream produced by the outbound upgrade: its codec encodes `KadRequestMsg` and
/// decodes `KadResponseMsg`.
pub(crate) type KadOutStreamSink<S> = Framed<S, Codec<KadRequestMsg, KadResponseMsg>>;
243
244impl<C> InboundUpgrade<C> for ProtocolConfig
245where
246 C: AsyncRead + AsyncWrite + Unpin,
247{
248 type Output = KadInStreamSink<C>;
249 type Future = future::Ready<Result<Self::Output, io::Error>>;
250 type Error = io::Error;
251
252 fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
253 let codec = Codec::new(self.max_packet_size);
254
255 future::ok(Framed::new(incoming, codec))
256 }
257}
258
259impl<C> OutboundUpgrade<C> for ProtocolConfig
260where
261 C: AsyncRead + AsyncWrite + Unpin,
262{
263 type Output = KadOutStreamSink<C>;
264 type Future = future::Ready<Result<Self::Output, io::Error>>;
265 type Error = io::Error;
266
267 fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future {
268 let codec = Codec::new(self.max_packet_size);
269
270 future::ok(Framed::new(incoming, codec))
271 }
272}
273
274#[derive(Debug, Clone, PartialEq, Eq)]
276pub enum KadRequestMsg {
277 Ping,
279
280 FindNode {
283 key: Vec<u8>,
285 },
286
287 GetProviders {
290 key: record::Key,
292 },
293
294 AddProvider {
296 key: record::Key,
298 provider: KadPeer,
300 },
301
302 GetValue {
304 key: record::Key,
306 },
307
308 PutValue { record: Record },
310}
311
312#[derive(Debug, Clone, PartialEq, Eq)]
314pub enum KadResponseMsg {
315 Pong,
317
318 FindNode {
320 closer_peers: Vec<KadPeer>,
322 },
323
324 GetProviders {
326 closer_peers: Vec<KadPeer>,
328 provider_peers: Vec<KadPeer>,
330 },
331
332 GetValue {
334 record: Option<Record>,
336 closer_peers: Vec<KadPeer>,
338 },
339
340 PutValue {
342 key: record::Key,
344 value: Vec<u8>,
346 },
347}
348
349impl From<KadRequestMsg> for proto::Message {
350 fn from(kad_msg: KadRequestMsg) -> Self {
351 req_msg_to_proto(kad_msg)
352 }
353}
354impl From<KadResponseMsg> for proto::Message {
355 fn from(kad_msg: KadResponseMsg) -> Self {
356 resp_msg_to_proto(kad_msg)
357 }
358}
359impl TryFrom<proto::Message> for KadRequestMsg {
360 type Error = io::Error;
361
362 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
363 proto_to_req_msg(message)
364 }
365}
366impl TryFrom<proto::Message> for KadResponseMsg {
367 type Error = io::Error;
368
369 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
370 proto_to_resp_msg(message)
371 }
372}
373
374fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
376 match kad_msg {
377 KadRequestMsg::Ping => proto::Message {
378 type_pb: proto::MessageType::PING,
379 ..proto::Message::default()
380 },
381 KadRequestMsg::FindNode { key } => proto::Message {
382 type_pb: proto::MessageType::FIND_NODE,
383 key,
384 clusterLevelRaw: 10,
385 ..proto::Message::default()
386 },
387 KadRequestMsg::GetProviders { key } => proto::Message {
388 type_pb: proto::MessageType::GET_PROVIDERS,
389 key: key.to_vec(),
390 clusterLevelRaw: 10,
391 ..proto::Message::default()
392 },
393 KadRequestMsg::AddProvider { key, provider } => proto::Message {
394 type_pb: proto::MessageType::ADD_PROVIDER,
395 clusterLevelRaw: 10,
396 key: key.to_vec(),
397 providerPeers: vec![provider.into()],
398 ..proto::Message::default()
399 },
400 KadRequestMsg::GetValue { key } => proto::Message {
401 type_pb: proto::MessageType::GET_VALUE,
402 clusterLevelRaw: 10,
403 key: key.to_vec(),
404 ..proto::Message::default()
405 },
406 KadRequestMsg::PutValue { record } => proto::Message {
407 type_pb: proto::MessageType::PUT_VALUE,
408 key: record.key.to_vec(),
409 record: Some(record_to_proto(record)),
410 ..proto::Message::default()
411 },
412 }
413}
414
415fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
417 match kad_msg {
418 KadResponseMsg::Pong => proto::Message {
419 type_pb: proto::MessageType::PING,
420 ..proto::Message::default()
421 },
422 KadResponseMsg::FindNode { closer_peers } => proto::Message {
423 type_pb: proto::MessageType::FIND_NODE,
424 clusterLevelRaw: 9,
425 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
426 ..proto::Message::default()
427 },
428 KadResponseMsg::GetProviders {
429 closer_peers,
430 provider_peers,
431 } => proto::Message {
432 type_pb: proto::MessageType::GET_PROVIDERS,
433 clusterLevelRaw: 9,
434 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
435 providerPeers: provider_peers.into_iter().map(KadPeer::into).collect(),
436 ..proto::Message::default()
437 },
438 KadResponseMsg::GetValue {
439 record,
440 closer_peers,
441 } => proto::Message {
442 type_pb: proto::MessageType::GET_VALUE,
443 clusterLevelRaw: 9,
444 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
445 record: record.map(record_to_proto),
446 ..proto::Message::default()
447 },
448 KadResponseMsg::PutValue { key, value } => proto::Message {
449 type_pb: proto::MessageType::PUT_VALUE,
450 key: key.to_vec(),
451 record: Some(proto::Record {
452 key: key.to_vec(),
453 value,
454 ..proto::Record::default()
455 }),
456 ..proto::Message::default()
457 },
458 }
459}
460
461fn proto_to_req_msg(message: proto::Message) -> Result<KadRequestMsg, io::Error> {
465 match message.type_pb {
466 proto::MessageType::PING => Ok(KadRequestMsg::Ping),
467 proto::MessageType::PUT_VALUE => {
468 let record = record_from_proto(message.record.unwrap_or_default())?;
469 Ok(KadRequestMsg::PutValue { record })
470 }
471 proto::MessageType::GET_VALUE => Ok(KadRequestMsg::GetValue {
472 key: record::Key::from(message.key),
473 }),
474 proto::MessageType::FIND_NODE => Ok(KadRequestMsg::FindNode { key: message.key }),
475 proto::MessageType::GET_PROVIDERS => Ok(KadRequestMsg::GetProviders {
476 key: record::Key::from(message.key),
477 }),
478 proto::MessageType::ADD_PROVIDER => {
479 let provider = message
483 .providerPeers
484 .into_iter()
485 .find_map(|peer| KadPeer::try_from(peer).ok());
486
487 if let Some(provider) = provider {
488 let key = record::Key::from(message.key);
489 Ok(KadRequestMsg::AddProvider { key, provider })
490 } else {
491 Err(invalid_data("AddProvider message with no valid peer."))
492 }
493 }
494 }
495}
496
497fn proto_to_resp_msg(message: proto::Message) -> Result<KadResponseMsg, io::Error> {
501 match message.type_pb {
502 proto::MessageType::PING => Ok(KadResponseMsg::Pong),
503 proto::MessageType::GET_VALUE => {
504 let record = if let Some(r) = message.record {
505 Some(record_from_proto(r)?)
506 } else {
507 None
508 };
509
510 let closer_peers = message
511 .closerPeers
512 .into_iter()
513 .filter_map(|peer| KadPeer::try_from(peer).ok())
514 .collect();
515
516 Ok(KadResponseMsg::GetValue {
517 record,
518 closer_peers,
519 })
520 }
521
522 proto::MessageType::FIND_NODE => {
523 let closer_peers = message
524 .closerPeers
525 .into_iter()
526 .filter_map(|peer| KadPeer::try_from(peer).ok())
527 .collect();
528
529 Ok(KadResponseMsg::FindNode { closer_peers })
530 }
531
532 proto::MessageType::GET_PROVIDERS => {
533 let closer_peers = message
534 .closerPeers
535 .into_iter()
536 .filter_map(|peer| KadPeer::try_from(peer).ok())
537 .collect();
538
539 let provider_peers = message
540 .providerPeers
541 .into_iter()
542 .filter_map(|peer| KadPeer::try_from(peer).ok())
543 .collect();
544
545 Ok(KadResponseMsg::GetProviders {
546 closer_peers,
547 provider_peers,
548 })
549 }
550
551 proto::MessageType::PUT_VALUE => {
552 let key = record::Key::from(message.key);
553 let rec = message
554 .record
555 .ok_or_else(|| invalid_data("received PutValue message with no record"))?;
556
557 Ok(KadResponseMsg::PutValue {
558 key,
559 value: rec.value,
560 })
561 }
562
563 proto::MessageType::ADD_PROVIDER => {
564 Err(invalid_data("received an unexpected AddProvider message"))
565 }
566 }
567}
568
569fn record_from_proto(record: proto::Record) -> Result<Record, io::Error> {
570 let key = record::Key::from(record.key);
571 let value = record.value;
572
573 let publisher = if !record.publisher.is_empty() {
574 PeerId::from_bytes(&record.publisher)
575 .map(Some)
576 .map_err(|_| invalid_data("Invalid publisher peer ID."))?
577 } else {
578 None
579 };
580
581 let expires = if record.ttl > 0 {
582 Some(Instant::now() + Duration::from_secs(record.ttl as u64))
583 } else {
584 None
585 };
586
587 Ok(Record {
588 key,
589 value,
590 publisher,
591 expires,
592 })
593}
594
595fn record_to_proto(record: Record) -> proto::Record {
596 proto::Record {
597 key: record.key.to_vec(),
598 value: record.value,
599 publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(),
600 ttl: record
601 .expires
602 .map(|t| {
603 let now = Instant::now();
604 if t > now {
605 (t - now).as_secs() as u32
606 } else {
607 1 }
609 })
610 .unwrap_or(0),
611 timeReceived: String::new(),
612 }
613}
614
/// Wraps `e` in an `io::Error` of kind `io::ErrorKind::InvalidData`.
fn invalid_data<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let source: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::InvalidData, source)
}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// An address without a peer-id component gets the peer's id appended on decode.
    #[test]
    fn append_p2p() {
        let id = PeerId::random();
        let bare_addr: Multiaddr = "/ip6/2001:db8::/tcp/1234".parse().unwrap();

        let wire_peer = proto::Peer {
            id: id.to_bytes(),
            addrs: vec![bare_addr.to_vec()],
            connection: proto::ConnectionType::CAN_CONNECT,
        };

        let decoded = KadPeer::try_from(wire_peer).unwrap();
        let expected = vec![bare_addr.with_p2p(id).unwrap()];

        assert_eq!(decoded.multiaddrs, expected);
    }

    /// Addresses naming a different peer, or that are not multiaddrs at all, are dropped
    /// while the valid one is kept.
    #[test]
    fn skip_invalid_multiaddr() {
        let id = PeerId::random();
        let bare_addr: Multiaddr = "/ip6/2001:db8::/tcp/1234".parse().unwrap();

        let good_addr = bare_addr.clone().with_p2p(id).unwrap();

        // Same address, but tagged with some other peer's id.
        let foreign_id = PeerId::random();
        assert_ne!(id, foreign_id);
        let addr_for_other_peer = bare_addr.with_p2p(foreign_id).unwrap();

        // Bytes that do not parse as a multiaddr.
        let garbage = vec![255; 8];
        assert!(Multiaddr::try_from(garbage.clone()).is_err());

        let wire_peer = proto::Peer {
            id: id.to_bytes(),
            addrs: vec![good_addr.to_vec(), addr_for_other_peer.to_vec(), garbage],
            connection: proto::ConnectionType::CAN_CONNECT,
        };

        let decoded = KadPeer::try_from(wire_peer).unwrap();

        assert_eq!(decoded.multiaddrs, vec![good_addr]);
    }
}