Skip to main content

rift_dht/
lib.rs

1//! Libp2p-based DHT integration for internet discovery.
2//!
3//! This module maintains a lightweight Kademlia DHT used to announce and
4//! discover peer endpoints keyed by channel id. It exposes a simple async
5//! handle (`DhtHandle`) to announce or lookup peers.
6
7use std::collections::HashMap;
8use std::net::SocketAddr;
9
10use anyhow::Result;
11use libp2p::core::upgrade;
12use libp2p::identify::{Behaviour as Identify, Config as IdentifyConfig, Event as IdentifyEvent};
13use libp2p::kad::{
14    store::MemoryStore, Behaviour as Kademlia, Event as KademliaEvent, GetProvidersOk,
15    GetRecordOk, PutRecordOk, QueryId, QueryResult, Quorum, Record, RecordKey,
16};
17use libp2p::multiaddr::Protocol;
18use libp2p::noise;
19use libp2p::swarm::{NetworkBehaviour, Swarm, SwarmEvent};
20use libp2p::{tcp, yamux, Multiaddr, PeerId, Transport};
21use rift_core::{ChannelId, PeerId as RiftPeerId};
22use rift_metrics as metrics;
23use tracing::debug;
24use serde::{Deserialize, Serialize};
25use tokio::sync::{mpsc, oneshot};
26use futures::StreamExt;
27
/// Configuration for the DHT node started by [`DhtHandle::new`].
///
/// Both fields are plain socket addresses; they are converted to TCP
/// multiaddrs before being handed to libp2p.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DhtConfig {
    /// Socket addresses of bootstrap nodes that are dialed once at startup.
    pub bootstrap_nodes: Vec<SocketAddr>,
    /// Local socket address the DHT swarm listens on.
    pub listen_addr: SocketAddr,
}
35
/// Endpoint information for one peer, as exchanged through the DHT.
///
/// The value is encoded with `bincode` when it is announced and decoded with
/// `bincode` when a lookup retrieves it, so the field layout is part of the
/// stored record format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerEndpointInfo {
    /// Rift peer id.
    pub peer_id: RiftPeerId,
    /// Known socket addresses for the peer.
    pub addrs: Vec<SocketAddr>,
}
43
/// Errors reported by [`DhtHandle`] operations.
///
/// Both variants carry the underlying error rendered as a string.
#[derive(Debug, thiserror::Error)]
pub enum DhtError {
    /// Transport/IO errors from the libp2p stack.
    ///
    /// Returned by `DhtHandle::new` when the noise configuration cannot be
    /// built or the listen address cannot be bound.
    #[error("transport error: {0}")]
    Transport(String),
    /// DHT query errors.
    ///
    /// Also used when the background task cannot answer an announce or
    /// lookup request.
    #[error("dht error: {0}")]
    Dht(String),
}
53
/// Cloneable handle used to talk to the background DHT swarm task.
///
/// Every clone shares the same command channel, so all clones address the
/// same swarm.
#[derive(Clone)]
pub struct DhtHandle {
    /// Command channel for interacting with the background swarm task.
    cmd_tx: mpsc::Sender<Command>,
}
59
/// Requests sent from a `DhtHandle` to the background swarm task.
///
/// Each variant carries a oneshot sender on which the task reports the
/// outcome of the request.
enum Command {
    /// Announce peer info for a channel and become a provider.
    Announce {
        /// Channel the peer is announced for.
        key: ChannelId,
        /// Endpoint info stored as the record value.
        info: PeerEndpointInfo,
        /// Reply channel for the announce outcome.
        resp: oneshot::Sender<Result<(), DhtError>>,
    },
    /// Lookup peers providing a channel.
    Lookup {
        /// Channel to look up.
        key: ChannelId,
        /// Reply channel for the collected endpoint infos or an error.
        resp: oneshot::Sender<Result<Vec<PeerEndpointInfo>, DhtError>>,
    },
}
73
/// Combined network behaviours used by the libp2p swarm.
///
/// The `NetworkBehaviour` derive generates the `BehaviourEvent` enum that the
/// swarm loop matches on.
#[derive(NetworkBehaviour)]
struct Behaviour {
    /// Kademlia DHT backed by an in-memory record store.
    kademlia: Kademlia<MemoryStore>,
    /// Identify protocol; listen addresses it reports are added to Kademlia.
    identify: Identify,
}
80
/// Tracks in-flight lookups to aggregate multiple record results.
struct LookupState {
    /// Channel being looked up; used to derive the record key per provider.
    channel: ChannelId,
    /// Number of per-provider record queries that have not been answered yet.
    pending: usize,
    /// Endpoint infos decoded from the records received so far.
    results: Vec<PeerEndpointInfo>,
    /// Reply channel, completed once when the lookup finishes or fails.
    resp: oneshot::Sender<Result<Vec<PeerEndpointInfo>, DhtError>>,
}
88
/// Internal classifier to connect query ids with lookup stages.
///
/// `lookup_id` is the key of the owning entry in the task's lookup table.
enum LookupKind {
    /// The query is the provider search that starts a lookup.
    Providers { lookup_id: u64 },
    /// The query fetches the record of a single provider.
    Record { lookup_id: u64 },
}
94
95impl DhtHandle {
96    /// Spawn the DHT swarm and return a handle for interaction.
97    pub async fn new(config: DhtConfig) -> Result<DhtHandle, DhtError> {
98        let local_key = libp2p::identity::Keypair::generate_ed25519();
99        let local_peer_id = PeerId::from(local_key.public());
100        let transport = tcp::tokio::Transport::new(tcp::Config::default().nodelay(true))
101            .upgrade(upgrade::Version::V1)
102            .authenticate(noise::Config::new(&local_key).map_err(|e| DhtError::Transport(e.to_string()))?)
103            .multiplex(yamux::Config::default())
104            .boxed();
105
106        let store = MemoryStore::new(local_peer_id);
107        let mut kademlia = Kademlia::new(local_peer_id, store);
108        kademlia.set_mode(Some(libp2p::kad::Mode::Server));
109
110        let identify = Identify::new(IdentifyConfig::new(
111            "rift-dht/1.0.0".to_string(),
112            local_key.public(),
113        ));
114
115        let behaviour = Behaviour { kademlia, identify };
116        let mut swarm = Swarm::new(
117            transport,
118            behaviour,
119            local_peer_id,
120            libp2p::swarm::Config::with_tokio_executor(),
121        );
122
123        let listen_addr = socket_to_multiaddr(config.listen_addr);
124        swarm
125            .listen_on(listen_addr)
126            .map_err(|e| DhtError::Transport(e.to_string()))?;
127
128        let (cmd_tx, mut cmd_rx) = mpsc::channel(64);
129        let mut pending_put: HashMap<QueryId, oneshot::Sender<Result<(), DhtError>>> = HashMap::new();
130        let mut pending_lookup: HashMap<QueryId, LookupKind> = HashMap::new();
131        let mut lookups: HashMap<u64, LookupState> = HashMap::new();
132        let mut next_lookup_id = 1u64;
133
134        for addr in config.bootstrap_nodes {
135            let multi = socket_to_multiaddr(addr);
136            let _ = swarm.dial(multi);
137        }
138
139        // Background task: drive swarm and answer requests.
140        tokio::spawn(async move {
141            loop {
142                tokio::select! {
143                    Some(cmd) = cmd_rx.recv() => match cmd {
144                        Command::Announce { key, info, resp } => {
145                            let channel_key = channel_key(key);
146                            let record_key = peer_record_key(key, info.peer_id);
147                            let value = bincode::serialize(&info).unwrap_or_default();
148                            let record = Record { key: record_key, value, publisher: None, expires: None };
149                            let qid = swarm.behaviour_mut().kademlia.put_record(record, Quorum::One);
150                            if let Ok(qid) = qid {
151                                pending_put.insert(qid, resp);
152                            } else {
153                                let _ = resp.send(Err(DhtError::Dht("put record failed".to_string())));
154                            }
155                            let _ = swarm.behaviour_mut().kademlia.start_providing(channel_key);
156                        }
157                        Command::Lookup { key, resp } => {
158                            let lookup_id = next_lookup_id;
159                            next_lookup_id += 1;
160                            let qid = swarm.behaviour_mut().kademlia.get_providers(channel_key(key));
161                            pending_lookup.insert(qid, LookupKind::Providers { lookup_id });
162                            lookups.insert(lookup_id, LookupState { channel: key, pending: 0, results: Vec::new(), resp });
163                        }
164                    },
165                    event = swarm.select_next_some() => match event {
166                        SwarmEvent::Behaviour(BehaviourEvent::Identify(IdentifyEvent::Received { peer_id, info, .. })) => {
167                            for addr in info.listen_addrs {
168                                swarm.behaviour_mut().kademlia.add_address(&peer_id, addr);
169                            }
170                            let _ = swarm.behaviour_mut().kademlia.bootstrap();
171                        }
172                        SwarmEvent::Behaviour(BehaviourEvent::Kademlia(event)) => {
173                            if let KademliaEvent::OutboundQueryProgressed { id, result, .. } = event {
174                                match result {
175                                    QueryResult::PutRecord(Ok(PutRecordOk { .. })) => {
176                                        if let Some(resp) = pending_put.remove(&id) {
177                                            let _ = resp.send(Ok(()));
178                                        }
179                                    }
180                                    QueryResult::PutRecord(Err(err)) => {
181                                        if let Some(resp) = pending_put.remove(&id) {
182                                            let _ = resp.send(Err(DhtError::Dht(err.to_string())));
183                                        }
184                                    }
185                                    QueryResult::GetProviders(Ok(GetProvidersOk::FoundProviders { providers, .. })) => {
186                                        if let Some(LookupKind::Providers { lookup_id }) = pending_lookup.remove(&id) {
187                                            if let Some(state) = lookups.get_mut(&lookup_id) {
188                                                if providers.is_empty() {
189                                                    let state = lookups.remove(&lookup_id).unwrap();
190                                                    let _ = state.resp.send(Ok(state.results));
191                                                } else {
192                                                    state.pending = providers.len();
193                                                    for provider in providers {
194                                                        let record_key = peer_record_key_from_peer(state.channel, provider);
195                                                        let qid = swarm.behaviour_mut().kademlia.get_record(record_key);
196                                                        pending_lookup.insert(qid, LookupKind::Record { lookup_id });
197                                                    }
198                                                }
199                                            }
200                                        }
201                                    }
202                                    QueryResult::GetProviders(Ok(GetProvidersOk::FinishedWithNoAdditionalRecord { .. })) => {
203                                        if let Some(LookupKind::Providers { lookup_id }) = pending_lookup.remove(&id) {
204                                            if let Some(state) = lookups.remove(&lookup_id) {
205                                                let _ = state.resp.send(Ok(state.results));
206                                            }
207                                        }
208                                    }
209                                    QueryResult::GetProviders(Err(err)) => {
210                                        if let Some(LookupKind::Providers { lookup_id }) = pending_lookup.remove(&id) {
211                                            if let Some(state) = lookups.remove(&lookup_id) {
212                                                let _ = state.resp.send(Err(DhtError::Dht(err.to_string())));
213                                            }
214                                        }
215                                    }
216                                    QueryResult::GetRecord(Ok(GetRecordOk::FoundRecord(record))) => {
217                                        if let Some(LookupKind::Record { lookup_id }) = pending_lookup.remove(&id) {
218                                            if let Some(state) = lookups.get_mut(&lookup_id) {
219                                                if let Ok(info) = bincode::deserialize::<PeerEndpointInfo>(&record.record.value) {
220                                                    state.results.push(info);
221                                                }
222                                                if state.pending > 0 {
223                                                    state.pending -= 1;
224                                                }
225                                                if state.pending == 0 {
226                                                    let state = lookups.remove(&lookup_id).unwrap();
227                                                    let _ = state.resp.send(Ok(state.results));
228                                                }
229                                            }
230                                        }
231                                    }
232                                    QueryResult::GetRecord(Ok(GetRecordOk::FinishedWithNoAdditionalRecord { .. })) => {
233                                        if let Some(LookupKind::Record { lookup_id }) = pending_lookup.remove(&id) {
234                                            if let Some(state) = lookups.get_mut(&lookup_id) {
235                                                if state.pending > 0 {
236                                                    state.pending -= 1;
237                                                }
238                                                if state.pending == 0 {
239                                                    let state = lookups.remove(&lookup_id).unwrap();
240                                                    let _ = state.resp.send(Ok(state.results));
241                                                }
242                                            }
243                                        }
244                                    }
245                                    QueryResult::GetRecord(Err(err)) => {
246                                        if let Some(LookupKind::Record { lookup_id }) = pending_lookup.remove(&id) {
247                                            if let Some(state) = lookups.get_mut(&lookup_id) {
248                                                if state.pending > 0 {
249                                                    state.pending -= 1;
250                                                }
251                                                if state.pending == 0 {
252                                                    let state = lookups.remove(&lookup_id).unwrap();
253                                                    let _ = state.resp.send(Err(DhtError::Dht(err.to_string())));
254                                                }
255                                            }
256                                        }
257                                    }
258                                    _ => {}
259                                }
260                            }
261                        }
262                        SwarmEvent::NewListenAddr { .. } => {}
263                        _ => {}
264                    }
265                }
266            }
267        });
268
269        metrics::inc_counter("rift_dht_started", &[]);
270        Ok(DhtHandle { cmd_tx })
271    }
272
273    /// Announce the current peer for a channel.
274    pub async fn announce(&self, key: ChannelId, info: PeerEndpointInfo) -> Result<(), DhtError> {
275        metrics::inc_counter("rift_dht_announce", &[]);
276        debug!(channel = %key.to_hex(), "dht announce");
277        let (tx, rx) = oneshot::channel();
278        let cmd = Command::Announce { key, info, resp: tx };
279        let _ = self.cmd_tx.send(cmd).await;
280        rx.await.unwrap_or(Err(DhtError::Dht("announce failed".to_string())))
281    }
282
283    /// Lookup all peers advertising a given channel.
284    pub async fn lookup(&self, key: ChannelId) -> Result<Vec<PeerEndpointInfo>, DhtError> {
285        metrics::inc_counter("rift_dht_lookup", &[]);
286        debug!(channel = %key.to_hex(), "dht lookup");
287        let (tx, rx) = oneshot::channel();
288        let cmd = Command::Lookup { key, resp: tx };
289        let _ = self.cmd_tx.send(cmd).await;
290        rx.await.unwrap_or(Err(DhtError::Dht("lookup failed".to_string())))
291    }
292}
293
294/// Convert a socket address to a multiaddr used by libp2p.
295fn socket_to_multiaddr(addr: SocketAddr) -> Multiaddr {
296    match addr {
297        SocketAddr::V4(v4) => Multiaddr::empty()
298            .with(Protocol::Ip4(*v4.ip()))
299            .with(Protocol::Tcp(v4.port())),
300        SocketAddr::V6(v6) => Multiaddr::empty()
301            .with(Protocol::Ip6(*v6.ip()))
302            .with(Protocol::Tcp(v6.port())),
303    }
304}
305
306/// Record key used to advertise a channel.
307fn channel_key(channel: ChannelId) -> RecordKey {
308    RecordKey::new(&channel.0)
309}
310
311/// Record key used to store a peer record within a channel.
312fn peer_record_key(channel: ChannelId, peer_id: RiftPeerId) -> RecordKey {
313    let mut bytes = Vec::with_capacity(64);
314    bytes.extend_from_slice(&channel.0);
315    bytes.extend_from_slice(&peer_id.0);
316    RecordKey::new(&bytes)
317}
318
319/// Record key used when we only have a libp2p peer id.
320fn peer_record_key_from_peer(channel: ChannelId, peer_id: PeerId) -> RecordKey {
321    let mut bytes = Vec::with_capacity(64);
322    bytes.extend_from_slice(&channel.0);
323    bytes.extend_from_slice(peer_id.to_bytes().as_ref());
324    RecordKey::new(&bytes)
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Shorthand for an IPv4 socket address built from octets and a port.
    fn v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port)
    }

    /// Encodes and decodes an endpoint info through bincode.
    fn roundtrip(info: &PeerEndpointInfo) -> PeerEndpointInfo {
        let bytes = bincode::serialize(info).unwrap();
        bincode::deserialize(&bytes).unwrap()
    }

    // --- PeerEndpointInfo -------------------------------------------------

    #[test]
    fn peer_endpoint_info_serialization_roundtrip() {
        let original = PeerEndpointInfo {
            peer_id: RiftPeerId([42u8; 32]),
            addrs: vec![v4(192, 168, 1, 1, 9000), v4(10, 0, 0, 1, 9001)],
        };

        let decoded = roundtrip(&original);

        assert_eq!(original.peer_id.0, decoded.peer_id.0);
        assert_eq!(original.addrs, decoded.addrs);
    }

    #[test]
    fn peer_endpoint_info_empty_addrs() {
        let original = PeerEndpointInfo {
            peer_id: RiftPeerId([0u8; 32]),
            addrs: Vec::new(),
        };

        let decoded = roundtrip(&original);

        assert_eq!(original.addrs.len(), 0);
        assert_eq!(decoded.addrs.len(), 0);
    }

    // --- DhtConfig / DhtError --------------------------------------------

    #[test]
    fn dht_config_construction() {
        let bootstrap_nodes = vec![v4(1, 2, 3, 4, 4001)];
        let listen_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
        let config = DhtConfig { bootstrap_nodes, listen_addr };

        assert_eq!(config.bootstrap_nodes.len(), 1);
        assert_eq!(config.listen_addr.port(), 0);
    }

    #[test]
    fn dht_error_display() {
        let transport = DhtError::Transport("connection refused".to_string());
        assert!(transport.to_string().contains("transport error"));

        let dht = DhtError::Dht("no providers".to_string());
        assert!(dht.to_string().contains("dht error"));
    }

    // --- socket_to_multiaddr ---------------------------------------------

    #[test]
    fn socket_to_multiaddr_ipv4() {
        let expected: Multiaddr = "/ip4/127.0.0.1/tcp/4001".parse().unwrap();
        assert_eq!(socket_to_multiaddr(v4(127, 0, 0, 1, 4001)), expected);
    }

    #[test]
    fn socket_to_multiaddr_ipv6() {
        let localhost = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 4001);
        let expected: Multiaddr = "/ip6/::1/tcp/4001".parse().unwrap();
        assert_eq!(socket_to_multiaddr(localhost), expected);
    }

    // --- channel_key ------------------------------------------------------

    #[test]
    fn channel_key_deterministic() {
        let channel = ChannelId([42u8; 32]);
        assert_eq!(channel_key(channel), channel_key(channel));
    }

    #[test]
    fn channel_key_different_channels() {
        let first = channel_key(ChannelId([1u8; 32]));
        let second = channel_key(ChannelId([2u8; 32]));
        assert_ne!(first, second);
    }

    // --- peer_record_key --------------------------------------------------

    #[test]
    fn peer_record_key_deterministic() {
        let channel = ChannelId([42u8; 32]);
        let peer = RiftPeerId([7u8; 32]);
        assert_eq!(peer_record_key(channel, peer), peer_record_key(channel, peer));
    }

    #[test]
    fn peer_record_key_different_peers() {
        let channel = ChannelId([42u8; 32]);
        let first = peer_record_key(channel, RiftPeerId([1u8; 32]));
        let second = peer_record_key(channel, RiftPeerId([2u8; 32]));
        assert_ne!(first, second);
    }

    #[test]
    fn peer_record_key_different_channels() {
        let peer = RiftPeerId([42u8; 32]);
        let first = peer_record_key(ChannelId([1u8; 32]), peer);
        let second = peer_record_key(ChannelId([2u8; 32]), peer);
        assert_ne!(first, second);
    }
}