1use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::Arc;
9
10use mdns_sd::{ServiceDaemon, ServiceEvent, ServiceInfo};
11use tokio::sync::mpsc;
12use tokio_stream::wrappers::ReceiverStream;
13use tokio_stream::Stream;
14
15use rift_core::{ChannelId, PeerId};
16use rift_dht::{DhtConfig, DhtHandle, PeerEndpointInfo};
17
/// DNS-SD service type under which rift peers register (`start_mdns_advertisement`)
/// and browse (`discover_peers`) on the local network.
///
/// NOTE(review): the `_udp` label suggests the advertised `listen_port` is a UDP
/// port — confirm against the transport layer.
const SERVICE_TYPE: &str = "_rift._udp.local.";
20
/// Parameters for advertising this node and discovering peers over mDNS.
#[derive(Debug, Clone)]
pub struct DiscoveryConfig {
    /// Channel name; together with `password` it determines the [`ChannelId`].
    pub channel_name: String,
    /// Optional channel password, also fed into the [`ChannelId`] derivation.
    pub password: Option<String>,
    /// This node's identity. It is advertised hex-encoded in the `peer` TXT
    /// property, and its first 8 hex characters form the mDNS instance name.
    pub peer_id: PeerId,
    /// Port published in the mDNS service record for this node.
    pub listen_port: u16,
}
32
33impl DiscoveryConfig {
34 pub fn channel_id(&self) -> ChannelId {
36 ChannelId::from_channel(&self.channel_name, self.password.as_deref())
37 }
38}
39
/// A peer found on the local network, as yielded by `discover_peers`.
#[derive(Debug, Clone)]
pub struct PeerInfo {
    /// The peer's identity, decoded from the 32-byte hex `peer` TXT property.
    pub peer_id: PeerId,
    /// One of the peer's advertised addresses, paired with its advertised port.
    pub addr: SocketAddr,
}
47
/// Errors produced by LAN (mDNS) and DHT discovery.
#[derive(Debug, thiserror::Error)]
pub enum DiscoveryError {
    /// Failure reported by `mdns_sd`. `local_ipv4_addrs` also reports
    /// interface-enumeration failures through this variant, as `Error::Msg`.
    #[error("mdns error: {0}")]
    Mdns(#[from] mdns_sd::Error),
    /// An mDNS record lacked the expected peer properties.
    /// NOTE(review): not constructed in this module's visible code
    /// (`peer_info_from_service` returns `None` instead) — confirm it is used elsewhere.
    #[error("missing peer info in mDNS record")]
    MissingPeerInfo,
    /// A peer id could not be decoded.
    /// NOTE(review): not constructed in this module's visible code — confirm.
    #[error("invalid peer id")]
    InvalidPeerId,
    /// Failure from the DHT layer, carried as that error's `Display` text.
    #[error("dht error: {0}")]
    Dht(String),
}
63
/// Selects how peers are discovered.
///
/// NOTE(review): this enum is not matched anywhere in the visible code; the
/// variant descriptions below are inferred from names — confirm with callers.
#[derive(Debug, Clone)]
pub enum DiscoveryMode {
    /// Local-network discovery, presumably via the mDNS functions in this module.
    Lan,
    /// DHT-based discovery using the given [`DhtConfig`] (see `start_dht`).
    Dht(DhtConfig),
}
71
72pub async fn start_dht(config: DhtConfig) -> Result<DhtHandle, DiscoveryError> {
74 DhtHandle::new(config)
75 .await
76 .map_err(|e| DiscoveryError::Dht(e.to_string()))
77}
78
79pub async fn dht_announce(
81 handle: &DhtHandle,
82 channel_id: ChannelId,
83 info: PeerEndpointInfo,
84) -> Result<(), DiscoveryError> {
85 handle
86 .announce(channel_id, info)
87 .await
88 .map_err(|e| DiscoveryError::Dht(e.to_string()))
89}
90
91pub async fn dht_lookup(
93 handle: &DhtHandle,
94 channel_id: ChannelId,
95) -> Result<Vec<PeerEndpointInfo>, DiscoveryError> {
96 handle
97 .lookup(channel_id)
98 .await
99 .map_err(|e| DiscoveryError::Dht(e.to_string()))
100}
101
102pub struct MdnsHandle {
104 _daemon: Arc<ServiceDaemon>,
105 _service: ServiceInfo,
106}
107
108impl MdnsHandle {
109 pub fn new(daemon: Arc<ServiceDaemon>, service: ServiceInfo) -> Self {
111 Self {
112 _daemon: daemon,
113 _service: service,
114 }
115 }
116}
117
118pub fn start_mdns_advertisement(config: DiscoveryConfig) -> Result<MdnsHandle, DiscoveryError> {
120 let daemon = Arc::new(ServiceDaemon::new()?);
121 let channel_id = config.channel_id();
122 let channel_hex = hex::encode(channel_id.0);
123 let peer_hex = hex::encode(config.peer_id.0);
124
125 let instance_name = format!("rift-{}", &peer_hex[..8]);
126 let host_name = format!("{}.local.", instance_name);
127
128 let props = [("channel", channel_hex.as_str()), ("peer", peer_hex.as_str())];
129 let addrs = local_ipv4_addrs()
130 .unwrap_or_else(|_| vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]);
131 let service = ServiceInfo::new(
132 SERVICE_TYPE,
133 &instance_name,
134 &host_name,
135 addrs.as_slice(),
136 config.listen_port,
137 &props[..],
138 )?;
139 daemon.register(service.clone())?;
140
141 Ok(MdnsHandle::new(daemon, service))
142}
143
144pub fn discover_peers(
146 config: DiscoveryConfig,
147) -> Result<impl Stream<Item = PeerInfo>, DiscoveryError> {
148 let daemon = ServiceDaemon::new()?;
149 let channel_hex = hex::encode(config.channel_id().0);
150 let (tx, rx) = mpsc::channel(64);
151
152 let receiver = daemon.browse(SERVICE_TYPE)?;
153 std::thread::spawn(move || {
154 for event in receiver {
155 if let ServiceEvent::ServiceResolved(info) = event {
156 if let Some(peer) = peer_info_from_service(&info, &channel_hex) {
157 let _ = tx.blocking_send(peer);
158 }
159 }
160 }
161 });
162
163 Ok(MdnsStream {
164 _daemon: daemon,
165 inner: ReceiverStream::new(rx),
166 })
167}
168
169fn peer_info_from_service(info: &ServiceInfo, channel_hex: &str) -> Option<PeerInfo> {
171 let channel = info.get_property_val_str("channel")?;
172 if channel != channel_hex {
173 return None;
174 }
175 let peer_hex = info.get_property_val_str("peer")?;
176 let peer_bytes = hex::decode(peer_hex).ok()?;
177 if peer_bytes.len() != 32 {
178 return None;
179 }
180 let mut peer_id = [0u8; 32];
181 peer_id.copy_from_slice(&peer_bytes);
182
183 let port = info.get_port();
184 let addr = info
185 .get_addresses()
186 .iter()
187 .find_map(|addr| {
188 let sock = SocketAddr::new(*addr, port);
189 Some(sock)
190 })?;
191
192 Some(PeerInfo {
193 peer_id: PeerId(peer_id),
194 addr,
195 })
196}
197
198struct MdnsStream {
200 _daemon: ServiceDaemon,
201 inner: ReceiverStream<PeerInfo>,
202}
203
204impl Stream for MdnsStream {
205 type Item = PeerInfo;
206
207 fn poll_next(
209 mut self: std::pin::Pin<&mut Self>,
210 cx: &mut std::task::Context<'_>,
211 ) -> std::task::Poll<Option<Self::Item>> {
212 std::pin::Pin::new(&mut self.inner).poll_next(cx)
213 }
214}
215
216pub fn local_ipv4_addrs() -> Result<Vec<IpAddr>, DiscoveryError> {
218 let mut addrs = Vec::new();
219 let interfaces = if_addrs::get_if_addrs()
220 .map_err(|e| DiscoveryError::Mdns(mdns_sd::Error::Msg(e.to_string())))?;
221 for iface in interfaces {
222 if let IpAddr::V4(ip) = iface.ip() {
223 if !ip.is_unspecified() {
224 addrs.push(IpAddr::V4(ip));
225 }
226 }
227 }
228 if addrs.is_empty() {
229 addrs.push(IpAddr::V4(Ipv4Addr::LOCALHOST));
230 }
231 Ok(addrs)
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config whose peer id is `peer_byte` repeated 32 times.
    fn config(channel: &str, password: Option<&str>, peer_byte: u8, port: u16) -> DiscoveryConfig {
        DiscoveryConfig {
            channel_name: channel.to_string(),
            password: password.map(String::from),
            peer_id: PeerId([peer_byte; 32]),
            listen_port: port,
        }
    }

    #[test]
    fn channel_id_deterministic() {
        // Only name and password feed the channel id; peer id and port must not.
        let first = config("test-channel", None, 0, 9000);
        let second = config("test-channel", None, 1, 9001);
        assert_eq!(first.channel_id(), second.channel_id());
    }

    #[test]
    fn channel_id_with_password() {
        let open = config("test-channel", None, 0, 9000);
        let locked = config("test-channel", Some("secret"), 0, 9000);
        assert_ne!(open.channel_id(), locked.channel_id());
    }

    #[test]
    fn channel_id_different_names() {
        let a = config("channel-a", None, 0, 9000);
        let b = config("channel-b", None, 0, 9000);
        assert_ne!(a.channel_id(), b.channel_id());
    }

    #[test]
    fn local_addrs_returns_something() {
        let addrs = local_ipv4_addrs().unwrap();
        assert!(!addrs.is_empty());
    }

    #[test]
    fn local_addrs_are_ipv4() {
        let addrs = local_ipv4_addrs().unwrap();
        assert!(addrs.iter().all(|addr| matches!(addr, IpAddr::V4(_))));
    }

    #[test]
    fn peer_info_construction() {
        let ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let peer = PeerInfo {
            peer_id: PeerId([42u8; 32]),
            addr: SocketAddr::new(ip, 9000),
        };

        assert_eq!(peer.peer_id.0, [42u8; 32]);
        assert_eq!(peer.addr.port(), 9000);
    }

    #[test]
    fn discovery_error_display() {
        assert_eq!(
            DiscoveryError::MissingPeerInfo.to_string(),
            "missing peer info in mDNS record"
        );
        assert_eq!(DiscoveryError::InvalidPeerId.to_string(), "invalid peer id");
    }
}