1use std::collections::HashMap;
8use std::net::SocketAddr;
9
10use anyhow::Result;
11use libp2p::core::upgrade;
12use libp2p::identify::{Behaviour as Identify, Config as IdentifyConfig, Event as IdentifyEvent};
13use libp2p::kad::{
14 store::MemoryStore, Behaviour as Kademlia, Event as KademliaEvent, GetProvidersOk,
15 GetRecordOk, PutRecordOk, QueryId, QueryResult, Quorum, Record, RecordKey,
16};
17use libp2p::multiaddr::Protocol;
18use libp2p::noise;
19use libp2p::swarm::{NetworkBehaviour, Swarm, SwarmEvent};
20use libp2p::{tcp, yamux, Multiaddr, PeerId, Transport};
21use rift_core::{ChannelId, PeerId as RiftPeerId};
22use rift_metrics as metrics;
23use tracing::debug;
24use serde::{Deserialize, Serialize};
25use tokio::sync::{mpsc, oneshot};
26use futures::StreamExt;
27
/// Settings used to start a DHT node.
#[derive(Clone, Debug)]
pub struct DhtConfig {
    /// Socket addresses of peers that are dialed once at startup.
    pub bootstrap_nodes: Vec<SocketAddr>,
    /// Local socket address the node listens on (converted to an
    /// `/ip4|ip6/../tcp/..` multiaddr before use).
    pub listen_addr: SocketAddr,
}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct PeerEndpointInfo {
38 pub peer_id: RiftPeerId,
40 pub addrs: Vec<SocketAddr>,
42}
43
44#[derive(Debug, thiserror::Error)]
45pub enum DhtError {
46 #[error("transport error: {0}")]
48 Transport(String),
49 #[error("dht error: {0}")]
51 Dht(String),
52}
53
54#[derive(Clone)]
55pub struct DhtHandle {
56 cmd_tx: mpsc::Sender<Command>,
58}
59
60enum Command {
61 Announce {
63 key: ChannelId,
64 info: PeerEndpointInfo,
65 resp: oneshot::Sender<Result<(), DhtError>>,
66 },
67 Lookup {
69 key: ChannelId,
70 resp: oneshot::Sender<Result<Vec<PeerEndpointInfo>, DhtError>>,
71 },
72}
73
74#[derive(NetworkBehaviour)]
76struct Behaviour {
77 kademlia: Kademlia<MemoryStore>,
78 identify: Identify,
79}
80
81struct LookupState {
83 channel: ChannelId,
84 pending: usize,
85 results: Vec<PeerEndpointInfo>,
86 resp: oneshot::Sender<Result<Vec<PeerEndpointInfo>, DhtError>>,
87}
88
/// Tags a Kademlia query id with the lookup it belongs to and the phase it serves.
enum LookupKind {
    /// Query that discovers which peers provide the channel.
    Providers {
        /// Identifier of the owning lookup.
        lookup_id: u64,
    },
    /// Query that fetches the endpoint record of a single provider.
    Record {
        /// Identifier of the owning lookup.
        lookup_id: u64,
    },
}
94
95impl DhtHandle {
96 pub async fn new(config: DhtConfig) -> Result<DhtHandle, DhtError> {
98 let local_key = libp2p::identity::Keypair::generate_ed25519();
99 let local_peer_id = PeerId::from(local_key.public());
100 let transport = tcp::tokio::Transport::new(tcp::Config::default().nodelay(true))
101 .upgrade(upgrade::Version::V1)
102 .authenticate(noise::Config::new(&local_key).map_err(|e| DhtError::Transport(e.to_string()))?)
103 .multiplex(yamux::Config::default())
104 .boxed();
105
106 let store = MemoryStore::new(local_peer_id);
107 let mut kademlia = Kademlia::new(local_peer_id, store);
108 kademlia.set_mode(Some(libp2p::kad::Mode::Server));
109
110 let identify = Identify::new(IdentifyConfig::new(
111 "rift-dht/1.0.0".to_string(),
112 local_key.public(),
113 ));
114
115 let behaviour = Behaviour { kademlia, identify };
116 let mut swarm = Swarm::new(
117 transport,
118 behaviour,
119 local_peer_id,
120 libp2p::swarm::Config::with_tokio_executor(),
121 );
122
123 let listen_addr = socket_to_multiaddr(config.listen_addr);
124 swarm
125 .listen_on(listen_addr)
126 .map_err(|e| DhtError::Transport(e.to_string()))?;
127
128 let (cmd_tx, mut cmd_rx) = mpsc::channel(64);
129 let mut pending_put: HashMap<QueryId, oneshot::Sender<Result<(), DhtError>>> = HashMap::new();
130 let mut pending_lookup: HashMap<QueryId, LookupKind> = HashMap::new();
131 let mut lookups: HashMap<u64, LookupState> = HashMap::new();
132 let mut next_lookup_id = 1u64;
133
134 for addr in config.bootstrap_nodes {
135 let multi = socket_to_multiaddr(addr);
136 let _ = swarm.dial(multi);
137 }
138
139 tokio::spawn(async move {
141 loop {
142 tokio::select! {
143 Some(cmd) = cmd_rx.recv() => match cmd {
144 Command::Announce { key, info, resp } => {
145 let channel_key = channel_key(key);
146 let record_key = peer_record_key(key, info.peer_id);
147 let value = bincode::serialize(&info).unwrap_or_default();
148 let record = Record { key: record_key, value, publisher: None, expires: None };
149 let qid = swarm.behaviour_mut().kademlia.put_record(record, Quorum::One);
150 if let Ok(qid) = qid {
151 pending_put.insert(qid, resp);
152 } else {
153 let _ = resp.send(Err(DhtError::Dht("put record failed".to_string())));
154 }
155 let _ = swarm.behaviour_mut().kademlia.start_providing(channel_key);
156 }
157 Command::Lookup { key, resp } => {
158 let lookup_id = next_lookup_id;
159 next_lookup_id += 1;
160 let qid = swarm.behaviour_mut().kademlia.get_providers(channel_key(key));
161 pending_lookup.insert(qid, LookupKind::Providers { lookup_id });
162 lookups.insert(lookup_id, LookupState { channel: key, pending: 0, results: Vec::new(), resp });
163 }
164 },
165 event = swarm.select_next_some() => match event {
166 SwarmEvent::Behaviour(BehaviourEvent::Identify(IdentifyEvent::Received { peer_id, info, .. })) => {
167 for addr in info.listen_addrs {
168 swarm.behaviour_mut().kademlia.add_address(&peer_id, addr);
169 }
170 let _ = swarm.behaviour_mut().kademlia.bootstrap();
171 }
172 SwarmEvent::Behaviour(BehaviourEvent::Kademlia(event)) => {
173 if let KademliaEvent::OutboundQueryProgressed { id, result, .. } = event {
174 match result {
175 QueryResult::PutRecord(Ok(PutRecordOk { .. })) => {
176 if let Some(resp) = pending_put.remove(&id) {
177 let _ = resp.send(Ok(()));
178 }
179 }
180 QueryResult::PutRecord(Err(err)) => {
181 if let Some(resp) = pending_put.remove(&id) {
182 let _ = resp.send(Err(DhtError::Dht(err.to_string())));
183 }
184 }
185 QueryResult::GetProviders(Ok(GetProvidersOk::FoundProviders { providers, .. })) => {
186 if let Some(LookupKind::Providers { lookup_id }) = pending_lookup.remove(&id) {
187 if let Some(state) = lookups.get_mut(&lookup_id) {
188 if providers.is_empty() {
189 let state = lookups.remove(&lookup_id).unwrap();
190 let _ = state.resp.send(Ok(state.results));
191 } else {
192 state.pending = providers.len();
193 for provider in providers {
194 let record_key = peer_record_key_from_peer(state.channel, provider);
195 let qid = swarm.behaviour_mut().kademlia.get_record(record_key);
196 pending_lookup.insert(qid, LookupKind::Record { lookup_id });
197 }
198 }
199 }
200 }
201 }
202 QueryResult::GetProviders(Ok(GetProvidersOk::FinishedWithNoAdditionalRecord { .. })) => {
203 if let Some(LookupKind::Providers { lookup_id }) = pending_lookup.remove(&id) {
204 if let Some(state) = lookups.remove(&lookup_id) {
205 let _ = state.resp.send(Ok(state.results));
206 }
207 }
208 }
209 QueryResult::GetProviders(Err(err)) => {
210 if let Some(LookupKind::Providers { lookup_id }) = pending_lookup.remove(&id) {
211 if let Some(state) = lookups.remove(&lookup_id) {
212 let _ = state.resp.send(Err(DhtError::Dht(err.to_string())));
213 }
214 }
215 }
216 QueryResult::GetRecord(Ok(GetRecordOk::FoundRecord(record))) => {
217 if let Some(LookupKind::Record { lookup_id }) = pending_lookup.remove(&id) {
218 if let Some(state) = lookups.get_mut(&lookup_id) {
219 if let Ok(info) = bincode::deserialize::<PeerEndpointInfo>(&record.record.value) {
220 state.results.push(info);
221 }
222 if state.pending > 0 {
223 state.pending -= 1;
224 }
225 if state.pending == 0 {
226 let state = lookups.remove(&lookup_id).unwrap();
227 let _ = state.resp.send(Ok(state.results));
228 }
229 }
230 }
231 }
232 QueryResult::GetRecord(Ok(GetRecordOk::FinishedWithNoAdditionalRecord { .. })) => {
233 if let Some(LookupKind::Record { lookup_id }) = pending_lookup.remove(&id) {
234 if let Some(state) = lookups.get_mut(&lookup_id) {
235 if state.pending > 0 {
236 state.pending -= 1;
237 }
238 if state.pending == 0 {
239 let state = lookups.remove(&lookup_id).unwrap();
240 let _ = state.resp.send(Ok(state.results));
241 }
242 }
243 }
244 }
245 QueryResult::GetRecord(Err(err)) => {
246 if let Some(LookupKind::Record { lookup_id }) = pending_lookup.remove(&id) {
247 if let Some(state) = lookups.get_mut(&lookup_id) {
248 if state.pending > 0 {
249 state.pending -= 1;
250 }
251 if state.pending == 0 {
252 let state = lookups.remove(&lookup_id).unwrap();
253 let _ = state.resp.send(Err(DhtError::Dht(err.to_string())));
254 }
255 }
256 }
257 }
258 _ => {}
259 }
260 }
261 }
262 SwarmEvent::NewListenAddr { .. } => {}
263 _ => {}
264 }
265 }
266 }
267 });
268
269 metrics::inc_counter("rift_dht_started", &[]);
270 Ok(DhtHandle { cmd_tx })
271 }
272
273 pub async fn announce(&self, key: ChannelId, info: PeerEndpointInfo) -> Result<(), DhtError> {
275 metrics::inc_counter("rift_dht_announce", &[]);
276 debug!(channel = %key.to_hex(), "dht announce");
277 let (tx, rx) = oneshot::channel();
278 let cmd = Command::Announce { key, info, resp: tx };
279 let _ = self.cmd_tx.send(cmd).await;
280 rx.await.unwrap_or(Err(DhtError::Dht("announce failed".to_string())))
281 }
282
283 pub async fn lookup(&self, key: ChannelId) -> Result<Vec<PeerEndpointInfo>, DhtError> {
285 metrics::inc_counter("rift_dht_lookup", &[]);
286 debug!(channel = %key.to_hex(), "dht lookup");
287 let (tx, rx) = oneshot::channel();
288 let cmd = Command::Lookup { key, resp: tx };
289 let _ = self.cmd_tx.send(cmd).await;
290 rx.await.unwrap_or(Err(DhtError::Dht("lookup failed".to_string())))
291 }
292}
293
294fn socket_to_multiaddr(addr: SocketAddr) -> Multiaddr {
296 match addr {
297 SocketAddr::V4(v4) => Multiaddr::empty()
298 .with(Protocol::Ip4(*v4.ip()))
299 .with(Protocol::Tcp(v4.port())),
300 SocketAddr::V6(v6) => Multiaddr::empty()
301 .with(Protocol::Ip6(*v6.ip()))
302 .with(Protocol::Tcp(v6.port())),
303 }
304}
305
306fn channel_key(channel: ChannelId) -> RecordKey {
308 RecordKey::new(&channel.0)
309}
310
311fn peer_record_key(channel: ChannelId, peer_id: RiftPeerId) -> RecordKey {
313 let mut bytes = Vec::with_capacity(64);
314 bytes.extend_from_slice(&channel.0);
315 bytes.extend_from_slice(&peer_id.0);
316 RecordKey::new(&bytes)
317}
318
319fn peer_record_key_from_peer(channel: ChannelId, peer_id: PeerId) -> RecordKey {
321 let mut bytes = Vec::with_capacity(64);
322 bytes.extend_from_slice(&channel.0);
323 bytes.extend_from_slice(peer_id.to_bytes().as_ref());
324 RecordKey::new(&bytes)
325}
326
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers and plain data types of this module.

    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Creates a libp2p peer id from a freshly generated ed25519 key.
    fn fresh_libp2p_peer() -> PeerId {
        PeerId::from(libp2p::identity::Keypair::generate_ed25519().public())
    }

    #[test]
    fn peer_endpoint_info_serialization_roundtrip() {
        let info = PeerEndpointInfo {
            peer_id: RiftPeerId([42u8; 32]),
            addrs: vec![
                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9000),
                SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 9001),
            ],
        };

        let serialized = bincode::serialize(&info).unwrap();
        let deserialized: PeerEndpointInfo = bincode::deserialize(&serialized).unwrap();

        assert_eq!(info.peer_id.0, deserialized.peer_id.0);
        assert_eq!(info.addrs, deserialized.addrs);
    }

    #[test]
    fn peer_endpoint_info_empty_addrs() {
        let info = PeerEndpointInfo {
            peer_id: RiftPeerId([0u8; 32]),
            addrs: vec![],
        };

        let serialized = bincode::serialize(&info).unwrap();
        let deserialized: PeerEndpointInfo = bincode::deserialize(&serialized).unwrap();

        assert_eq!(info.addrs.len(), 0);
        assert_eq!(deserialized.addrs.len(), 0);
        assert_eq!(info.peer_id.0, deserialized.peer_id.0);
    }

    #[test]
    fn dht_config_construction() {
        let config = DhtConfig {
            bootstrap_nodes: vec![SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)),
                4001,
            )],
            listen_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
        };

        assert_eq!(config.bootstrap_nodes.len(), 1);
        assert_eq!(config.listen_addr.port(), 0);
    }

    #[test]
    fn dht_error_display() {
        let err = DhtError::Transport("connection refused".to_string());
        assert!(format!("{}", err).contains("transport error"));

        let err = DhtError::Dht("no providers".to_string());
        assert!(format!("{}", err).contains("dht error"));
    }

    #[test]
    fn socket_to_multiaddr_ipv4() {
        let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 4001);
        let multi = socket_to_multiaddr(addr);
        let expected = "/ip4/127.0.0.1/tcp/4001".parse::<Multiaddr>().unwrap();
        assert_eq!(multi, expected);
    }

    #[test]
    fn socket_to_multiaddr_ipv6() {
        let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 4001);
        let multi = socket_to_multiaddr(addr);
        let expected = "/ip6/::1/tcp/4001".parse::<Multiaddr>().unwrap();
        assert_eq!(multi, expected);
    }

    #[test]
    fn channel_key_deterministic() {
        let channel = ChannelId([42u8; 32]);
        let key1 = channel_key(channel);
        let key2 = channel_key(channel);
        assert_eq!(key1, key2);
    }

    #[test]
    fn channel_key_different_channels() {
        let channel1 = ChannelId([1u8; 32]);
        let channel2 = ChannelId([2u8; 32]);
        let key1 = channel_key(channel1);
        let key2 = channel_key(channel2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn peer_record_key_deterministic() {
        let channel = ChannelId([42u8; 32]);
        let peer = RiftPeerId([7u8; 32]);
        let key1 = peer_record_key(channel, peer);
        let key2 = peer_record_key(channel, peer);
        assert_eq!(key1, key2);
    }

    #[test]
    fn peer_record_key_different_peers() {
        let channel = ChannelId([42u8; 32]);
        let peer1 = RiftPeerId([1u8; 32]);
        let peer2 = RiftPeerId([2u8; 32]);
        let key1 = peer_record_key(channel, peer1);
        let key2 = peer_record_key(channel, peer2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn peer_record_key_different_channels() {
        let channel1 = ChannelId([1u8; 32]);
        let channel2 = ChannelId([2u8; 32]);
        let peer = RiftPeerId([42u8; 32]);
        let key1 = peer_record_key(channel1, peer);
        let key2 = peer_record_key(channel2, peer);
        assert_ne!(key1, key2);
    }

    #[test]
    fn peer_record_key_from_peer_deterministic() {
        let channel = ChannelId([42u8; 32]);
        let peer = fresh_libp2p_peer();
        let key1 = peer_record_key_from_peer(channel, peer);
        let key2 = peer_record_key_from_peer(channel, peer);
        assert_eq!(key1, key2);
    }

    #[test]
    fn peer_record_key_from_peer_different_peers() {
        let channel = ChannelId([42u8; 32]);
        let key1 = peer_record_key_from_peer(channel, fresh_libp2p_peer());
        let key2 = peer_record_key_from_peer(channel, fresh_libp2p_peer());
        assert_ne!(key1, key2);
    }

    #[test]
    fn peer_record_key_from_peer_is_channel_then_peer_bytes() {
        let channel = ChannelId([9u8; 32]);
        let peer = fresh_libp2p_peer();
        let key = peer_record_key_from_peer(channel, peer).to_vec();

        let mut expected = channel.0.to_vec();
        expected.extend_from_slice(&peer.to_bytes());
        assert_eq!(key, expected);
    }
}