Skip to main content

rift_discovery/
lib.rs

1//! Peer discovery helpers (mDNS + optional DHT).
2//!
3//! This module provides LAN discovery via mDNS and optional internet discovery
4//! via the DHT. It exposes async helpers to start advertisements and streams of
5//! discovered peers.
6
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::Arc;
9
10use mdns_sd::{ServiceDaemon, ServiceEvent, ServiceInfo};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::Stream;
14
15use rift_core::{ChannelId, PeerId};
16use rift_dht::{DhtConfig, DhtHandle, PeerEndpointInfo};
17
18/// mDNS service type used for LAN discovery.
19const SERVICE_TYPE: &str = "_rift._udp.local.";
20
21#[derive(Debug, Clone)]
22pub struct DiscoveryConfig {
23    /// Channel name used to compute the discovery key.
24    pub channel_name: String,
25    /// Optional channel password.
26    pub password: Option<String>,
27    /// Local peer id to advertise.
28    pub peer_id: PeerId,
29    /// Local UDP listen port.
30    pub listen_port: u16,
31}
32
33impl DiscoveryConfig {
34    /// Derive the channel id used for discovery filtering.
35    pub fn channel_id(&self) -> ChannelId {
36        ChannelId::from_channel(&self.channel_name, self.password.as_deref())
37    }
38}
39
40#[derive(Debug, Clone)]
41pub struct PeerInfo {
42    /// Peer id discovered on the network.
43    pub peer_id: PeerId,
44    /// Primary socket address for the peer.
45    pub addr: SocketAddr,
46}
47
48#[derive(Debug, thiserror::Error)]
49pub enum DiscoveryError {
50    /// mDNS library errors.
51    #[error("mdns error: {0}")]
52    Mdns(#[from] mdns_sd::Error),
53    /// Missing expected metadata in mDNS records.
54    #[error("missing peer info in mDNS record")]
55    MissingPeerInfo,
56    /// Peer id could not be parsed.
57    #[error("invalid peer id")]
58    InvalidPeerId,
59    /// DHT errors are wrapped as strings.
60    #[error("dht error: {0}")]
61    Dht(String),
62}
63
64#[derive(Debug, Clone)]
65pub enum DiscoveryMode {
66    /// Local LAN discovery only.
67    Lan,
68    /// Internet discovery via DHT.
69    Dht(DhtConfig),
70}
71
72/// Start a DHT instance for internet discovery.
73pub async fn start_dht(config: DhtConfig) -> Result<DhtHandle, DiscoveryError> {
74    DhtHandle::new(config)
75        .await
76        .map_err(|e| DiscoveryError::Dht(e.to_string()))
77}
78
79/// Announce a peer in the DHT for the given channel.
80pub async fn dht_announce(
81    handle: &DhtHandle,
82    channel_id: ChannelId,
83    info: PeerEndpointInfo,
84) -> Result<(), DiscoveryError> {
85    handle
86        .announce(channel_id, info)
87        .await
88        .map_err(|e| DiscoveryError::Dht(e.to_string()))
89}
90
91/// Lookup peers in the DHT for the given channel.
92pub async fn dht_lookup(
93    handle: &DhtHandle,
94    channel_id: ChannelId,
95) -> Result<Vec<PeerEndpointInfo>, DiscoveryError> {
96    handle
97        .lookup(channel_id)
98        .await
99        .map_err(|e| DiscoveryError::Dht(e.to_string()))
100}
101
102/// Keeps the mDNS daemon and service registration alive.
103pub struct MdnsHandle {
104    _daemon: Arc<ServiceDaemon>,
105    _service: ServiceInfo,
106}
107
108impl MdnsHandle {
109    /// Construct a handle from daemon + service info.
110    pub fn new(daemon: Arc<ServiceDaemon>, service: ServiceInfo) -> Self {
111        Self {
112            _daemon: daemon,
113            _service: service,
114        }
115    }
116}
117
118/// Publish this peer's presence on the LAN via mDNS.
119pub fn start_mdns_advertisement(config: DiscoveryConfig) -> Result<MdnsHandle, DiscoveryError> {
120    let daemon = Arc::new(ServiceDaemon::new()?);
121    let channel_id = config.channel_id();
122    let channel_hex = hex::encode(channel_id.0);
123    let peer_hex = hex::encode(config.peer_id.0);
124
125    let instance_name = format!("rift-{}", &peer_hex[..8]);
126    let host_name = format!("{}.local.", instance_name);
127
128    let props = [("channel", channel_hex.as_str()), ("peer", peer_hex.as_str())];
129    let addrs = local_ipv4_addrs()
130        .unwrap_or_else(|_| vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]);
131    let service = ServiceInfo::new(
132        SERVICE_TYPE,
133        &instance_name,
134        &host_name,
135        addrs.as_slice(),
136        config.listen_port,
137        &props[..],
138    )?;
139    daemon.register(service.clone())?;
140
141    Ok(MdnsHandle::new(daemon, service))
142}
143
144/// Start browsing for peers in the same channel on the LAN.
145pub fn discover_peers(
146    config: DiscoveryConfig,
147) -> Result<impl Stream<Item = PeerInfo>, DiscoveryError> {
148    let daemon = ServiceDaemon::new()?;
149    let channel_hex = hex::encode(config.channel_id().0);
150    let (tx, rx) = mpsc::channel(64);
151
152    let receiver = daemon.browse(SERVICE_TYPE)?;
153    std::thread::spawn(move || {
154        for event in receiver {
155            if let ServiceEvent::ServiceResolved(info) = event {
156                if let Some(peer) = peer_info_from_service(&info, &channel_hex) {
157                    let _ = tx.blocking_send(peer);
158                }
159            }
160        }
161    });
162
163    Ok(MdnsStream {
164        _daemon: daemon,
165        inner: ReceiverStream::new(rx),
166    })
167}
168
169/// Extract peer metadata from an mDNS service record.
170fn peer_info_from_service(info: &ServiceInfo, channel_hex: &str) -> Option<PeerInfo> {
171    let channel = info.get_property_val_str("channel")?;
172    if channel != channel_hex {
173        return None;
174    }
175    let peer_hex = info.get_property_val_str("peer")?;
176    let peer_bytes = hex::decode(peer_hex).ok()?;
177    if peer_bytes.len() != 32 {
178        return None;
179    }
180    let mut peer_id = [0u8; 32];
181    peer_id.copy_from_slice(&peer_bytes);
182
183    let port = info.get_port();
184    let addr = info
185        .get_addresses()
186        .iter()
187        .find_map(|addr| {
188            let sock = SocketAddr::new(*addr, port);
189            Some(sock)
190        })?;
191
192    Some(PeerInfo {
193        peer_id: PeerId(peer_id),
194        addr,
195    })
196}
197
198/// Stream wrapper that keeps the mDNS daemon alive.
199struct MdnsStream {
200    _daemon: ServiceDaemon,
201    inner: ReceiverStream<PeerInfo>,
202}
203
204impl Stream for MdnsStream {
205    type Item = PeerInfo;
206
207    /// Delegate polling to the underlying receiver stream.
208    fn poll_next(
209        mut self: std::pin::Pin<&mut Self>,
210        cx: &mut std::task::Context<'_>,
211    ) -> std::task::Poll<Option<Self::Item>> {
212        std::pin::Pin::new(&mut self.inner).poll_next(cx)
213    }
214}
215
216/// Enumerate local IPv4 addresses for mDNS advertisement.
217pub fn local_ipv4_addrs() -> Result<Vec<IpAddr>, DiscoveryError> {
218    let mut addrs = Vec::new();
219    let interfaces = if_addrs::get_if_addrs()
220        .map_err(|e| DiscoveryError::Mdns(mdns_sd::Error::Msg(e.to_string())))?;
221    for iface in interfaces {
222        if let IpAddr::V4(ip) = iface.ip() {
223            if !ip.is_unspecified() {
224                addrs.push(IpAddr::V4(ip));
225            }
226        }
227    }
228    if addrs.is_empty() {
229        addrs.push(IpAddr::V4(Ipv4Addr::LOCALHOST));
230    }
231    Ok(addrs)
232}
233
234#[cfg(test)]
235mod tests {
236    use super::*;
237
238    #[test]
239    fn channel_id_deterministic() {
240        let config1 = DiscoveryConfig {
241            channel_name: "test-channel".to_string(),
242            password: None,
243            peer_id: PeerId([0u8; 32]),
244            listen_port: 9000,
245        };
246
247        let config2 = DiscoveryConfig {
248            channel_name: "test-channel".to_string(),
249            password: None,
250            peer_id: PeerId([1u8; 32]), // Different peer id
251            listen_port: 9001,           // Different port
252        };
253
254        // Same channel name without password should produce same channel ID
255        assert_eq!(config1.channel_id(), config2.channel_id());
256    }
257
258    #[test]
259    fn channel_id_with_password() {
260        let config_no_pass = DiscoveryConfig {
261            channel_name: "test-channel".to_string(),
262            password: None,
263            peer_id: PeerId([0u8; 32]),
264            listen_port: 9000,
265        };
266
267        let config_with_pass = DiscoveryConfig {
268            channel_name: "test-channel".to_string(),
269            password: Some("secret".to_string()),
270            peer_id: PeerId([0u8; 32]),
271            listen_port: 9000,
272        };
273
274        // Different passwords should produce different channel IDs
275        assert_ne!(config_no_pass.channel_id(), config_with_pass.channel_id());
276    }
277
278    #[test]
279    fn channel_id_different_names() {
280        let config1 = DiscoveryConfig {
281            channel_name: "channel-a".to_string(),
282            password: None,
283            peer_id: PeerId([0u8; 32]),
284            listen_port: 9000,
285        };
286
287        let config2 = DiscoveryConfig {
288            channel_name: "channel-b".to_string(),
289            password: None,
290            peer_id: PeerId([0u8; 32]),
291            listen_port: 9000,
292        };
293
294        // Different channel names should produce different channel IDs
295        assert_ne!(config1.channel_id(), config2.channel_id());
296    }
297
298    #[test]
299    fn local_addrs_returns_something() {
300        let addrs = local_ipv4_addrs().unwrap();
301        // Should always return at least one address (even if just localhost)
302        assert!(!addrs.is_empty());
303    }
304
305    #[test]
306    fn local_addrs_are_ipv4() {
307        let addrs = local_ipv4_addrs().unwrap();
308        for addr in addrs {
309            assert!(matches!(addr, IpAddr::V4(_)));
310        }
311    }
312
313    #[test]
314    fn peer_info_construction() {
315        let peer = PeerInfo {
316            peer_id: PeerId([42u8; 32]),
317            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 9000),
318        };
319
320        assert_eq!(peer.peer_id.0, [42u8; 32]);
321        assert_eq!(peer.addr.port(), 9000);
322    }
323
324    #[test]
325    fn discovery_error_display() {
326        let err = DiscoveryError::MissingPeerInfo;
327        assert_eq!(format!("{}", err), "missing peer info in mDNS record");
328
329        let err = DiscoveryError::InvalidPeerId;
330        assert_eq!(format!("{}", err), "invalid peer id");
331    }
332}