peernet/
lib.rs

1use anyhow::{anyhow, Result};
2use futures::StreamExt;
3use iroh::defaults::prod::default_relay_map;
4use iroh::discovery::dns::{DnsDiscovery, N0_DNS_NODE_ORIGIN_PROD};
5use iroh::discovery::pkarr::{
6    PkarrPublisher, PkarrResolver, DEFAULT_REPUBLISH_INTERVAL, N0_DNS_PKARR_RELAY_PROD,
7};
8use iroh::discovery::{ConcurrentDiscovery, Discovery};
9use iroh::endpoint::Endpoint as MagicEndpoint;
10use iroh::NodeAddr;
11use iroh::{PublicKey, SecretKey};
12use iroh::{RelayMap, RelayMode};
13use std::time::Duration;
14
15mod protocol;
16
17pub use crate::protocol::{
18    NotificationHandler, Protocol, ProtocolHandler, ProtocolHandlerBuilder, RequestHandler,
19    Subscription, SubscriptionHandler,
20};
21pub use iroh::endpoint::{Connection, RecvStream, SendStream};
22pub type PeerId = PublicKey;
23
/// Selects which discovery service an endpoint uses to resolve other peers.
///
/// Publishing is not affected by this choice: `Endpoint::new` always
/// registers a pkarr publisher next to the resolver selected here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResolverMode {
    /// Resolve peers through the configured pkarr relay.
    Pkarr,
    /// Resolve peers through DNS, using the configured DNS origin.
    Dns,
}
29
30pub struct EndpointBuilder {
31    alpn: Vec<u8>,
32    handler: Option<ProtocolHandler>,
33    secret: Option<[u8; 32]>,
34    relay_map: Option<RelayMap>,
35    // discovery fields
36    resolver_mode: ResolverMode,
37    dns_origin: String,
38    pkarr_relay: String,
39    publish_ttl: Duration,
40    republish_interval: Duration,
41}
42
43impl EndpointBuilder {
44    pub fn new(alpn: Vec<u8>) -> Self {
45        Self {
46            alpn,
47            handler: None,
48            secret: None,
49            relay_map: Some(default_relay_map()),
50            resolver_mode: ResolverMode::Pkarr,
51            dns_origin: N0_DNS_NODE_ORIGIN_PROD.into(),
52            pkarr_relay: N0_DNS_PKARR_RELAY_PROD.into(),
53            publish_ttl: DEFAULT_REPUBLISH_INTERVAL * 4,
54            republish_interval: DEFAULT_REPUBLISH_INTERVAL,
55        }
56    }
57
58    pub fn secret(&mut self, secret: [u8; 32]) -> &mut Self {
59        self.secret = Some(secret);
60        self
61    }
62
63    pub fn handler(&mut self, handler: ProtocolHandler) -> &mut Self {
64        self.handler = Some(handler);
65        self
66    }
67
68    pub fn relay_map(&mut self, relay_map: Option<RelayMap>) -> &mut Self {
69        self.relay_map = relay_map;
70        self
71    }
72
73    pub fn resolver_mode(&mut self, mode: ResolverMode) -> &mut Self {
74        self.resolver_mode = mode;
75        self
76    }
77
78    pub fn dns_origin(&mut self, dns_origin: impl Into<String>) -> &mut Self {
79        self.dns_origin = dns_origin.into();
80        self
81    }
82
83    pub fn pkarr_relay(&mut self, pkarr_relay: impl Into<String>) -> &mut Self {
84        self.pkarr_relay = pkarr_relay.into();
85        self
86    }
87
88    pub fn republish_interval(&mut self, interval: Duration) -> &mut Self {
89        self.republish_interval = interval;
90        self
91    }
92
93    pub fn publish_ttl(&mut self, ttl: Duration) -> &mut Self {
94        self.publish_ttl = ttl;
95        self
96    }
97
98    pub async fn build(self) -> Result<Endpoint> {
99        let secret = self.secret.unwrap_or_else(|| {
100            let mut secret = [0; 32];
101            getrandom::getrandom(&mut secret).unwrap();
102            secret
103        });
104        Endpoint::new(
105            SecretKey::from(secret),
106            self.alpn,
107            self.relay_map,
108            self.handler,
109            self.resolver_mode,
110            self.pkarr_relay,
111            self.dns_origin,
112            self.publish_ttl,
113            self.republish_interval,
114        )
115        .await
116    }
117}
118
119#[derive(Clone)]
120pub struct Endpoint {
121    alpn: Vec<u8>,
122    endpoint: MagicEndpoint,
123}
124
125impl Endpoint {
126    pub fn builder(alpn: Vec<u8>) -> EndpointBuilder {
127        EndpointBuilder::new(alpn)
128    }
129
130    async fn new(
131        secret: SecretKey,
132        alpn: Vec<u8>,
133        relay_map: Option<RelayMap>,
134        handler: Option<ProtocolHandler>,
135        resolver_mode: ResolverMode,
136        pkarr_relay: String,
137        dns_origin: String,
138        ttl: Duration,
139        republish_interval: Duration,
140    ) -> Result<Self> {
141        let publisher = Box::new(PkarrPublisher::with_options(
142            secret.clone(),
143            pkarr_relay.parse()?,
144            ttl.as_secs().try_into()?,
145            republish_interval,
146        ));
147        let resolver = match resolver_mode {
148            ResolverMode::Pkarr => {
149                Box::new(PkarrResolver::new(pkarr_relay.parse()?)) as Box<dyn Discovery>
150            }
151            ResolverMode::Dns => Box::new(DnsDiscovery::new(dns_origin)),
152        };
153        let discovery = ConcurrentDiscovery::from_services(vec![publisher, resolver]);
154        let builder = MagicEndpoint::builder()
155            .secret_key(secret)
156            .alpns(vec![alpn.clone()])
157            .discovery(Box::new(discovery));
158        let builder = if let Some(relay_map) = relay_map {
159            builder.relay_mode(RelayMode::Custom(relay_map))
160        } else {
161            builder.relay_mode(RelayMode::Disabled)
162        };
163        let endpoint = builder.bind().await?;
164        if let Some(handler) = handler {
165            tokio::spawn(server(endpoint.clone(), handler));
166        }
167        Ok(Self { alpn, endpoint })
168    }
169
170    pub fn peer_id(&self) -> PeerId {
171        self.endpoint.node_id()
172    }
173
174    pub async fn addr(&self) -> Result<NodeAddr> {
175        Ok(self.endpoint.node_addr().await?)
176    }
177
178    pub fn add_address(&self, address: NodeAddr) -> Result<()> {
179        self.endpoint.add_node_addr(address)?;
180        Ok(())
181    }
182
183    pub async fn resolve(&self, peer_id: PeerId) -> Result<NodeAddr> {
184        Ok(self
185            .endpoint
186            .discovery()
187            .ok_or(anyhow!("no descovery"))?
188            .resolve(self.endpoint.clone(), peer_id)
189            .ok_or(anyhow!("no item resolved"))?
190            .next()
191            .await
192            .ok_or(anyhow!("no item discovered"))??
193            .to_node_addr())
194    }
195
196    pub async fn connect(&self, peer_id: PeerId) -> Result<Connection> {
197        Ok(self.endpoint.connect(peer_id, &self.alpn).await?)
198    }
199
200    pub async fn notify<P: Protocol>(&self, peer_id: PeerId, msg: &P::Request) -> Result<()> {
201        let mut conn = self.connect(peer_id).await?;
202        crate::protocol::notify::<P>(&mut conn, msg).await
203    }
204
205    pub async fn request<P: Protocol>(
206        &self,
207        peer_id: PeerId,
208        msg: &P::Request,
209    ) -> Result<P::Response> {
210        let mut conn = self.connect(peer_id).await?;
211        crate::protocol::request_response::<P>(&mut conn, msg).await
212    }
213
214    pub async fn subscribe<P: Protocol>(
215        &self,
216        peer_id: PeerId,
217        msg: &P::Request,
218    ) -> Result<Subscription<P::Response>> {
219        let mut conn = self.connect(peer_id).await?;
220        crate::protocol::subscribe::<P>(&mut conn, msg).await
221    }
222}
223
224async fn server(endpoint: MagicEndpoint, handler: ProtocolHandler) {
225    loop {
226        let Some(conn) = endpoint.accept().await else {
227            tracing::info!("socket closed");
228            break;
229        };
230        let accept_conn = move || async {
231            let conn = conn.await?;
232            let node_id = conn.remote_node_id()?;
233            Result::<_, anyhow::Error>::Ok((node_id, conn))
234        };
235        match accept_conn().await {
236            Ok((peer_id, conn)) => {
237                handler.handle(peer_id, conn);
238            }
239            Err(err) => {
240                tracing::error!("{err}");
241            }
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use futures::channel::{mpsc, oneshot};
    use futures::SinkExt;
    use serde::{Deserialize, Serialize};
    use std::sync::Mutex;
    use std::time::Duration;

    const ALPN: &[u8] = b"/analog/tss/1";

    /// Pause between two polls while waiting for addresses to appear.
    const RETRY_DELAY: Duration = Duration::from_secs(1);

    /// Upper bound for the polling helpers, so a failed publish fails the
    /// test instead of hanging the test run.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(60);

    /// Request message of the test protocol.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Ping(u16);

    /// Response message of the test protocol; echoes the request value.
    #[derive(Debug, Deserialize, Serialize)]
    pub struct Pong(u16);

    /// Test protocol that answers every `Ping` with a `Pong` of the same value.
    pub struct PingPong;

    impl Protocol for PingPong {
        const ID: u16 = 0;
        const REQ_BUF: usize = 1024;
        const RES_BUF: usize = 1024;
        type Request = Ping;
        type Response = Pong;
    }

    impl RequestHandler<Self> for PingPong {
        fn request(
            &self,
            _peer_id: PeerId,
            request: <Self as Protocol>::Request,
            response: oneshot::Sender<<Self as Protocol>::Response>,
        ) -> Result<()> {
            response
                .send(Pong(request.0))
                .map_err(|_| anyhow::anyhow!("response channel closed"))?;
            Ok(())
        }
    }

    impl SubscriptionHandler<Self> for PingPong {
        fn subscribe(
            &self,
            _peer_id: PeerId,
            request: <Self as Protocol>::Request,
            mut response: mpsc::Sender<<Self as Protocol>::Response>,
        ) -> Result<()> {
            // Sends exactly two responses, then drops the sender.
            tokio::spawn(async move {
                response.send(Pong(request.0)).await.unwrap();
                response.send(Pong(request.0)).await.unwrap();
            });
            Ok(())
        }
    }

    /// Notification handler that forwards the first received `Ping` to a
    /// oneshot channel; a second notification panics.
    pub struct Pinger(Mutex<Option<oneshot::Sender<Ping>>>);

    impl Pinger {
        pub fn new() -> (Self, oneshot::Receiver<Ping>) {
            let (tx, rx) = oneshot::channel();
            (Self(Mutex::new(Some(tx))), rx)
        }
    }

    impl NotificationHandler<PingPong> for Pinger {
        fn notify(&self, _peer_id: PeerId, request: Ping) -> Result<()> {
            self.0
                .lock()
                .unwrap()
                .take()
                .unwrap()
                .send(request)
                .unwrap();
            Ok(())
        }
    }

    /// Polls until the endpoint reports at least one direct address.
    ///
    /// Fails if that does not happen within `WAIT_TIMEOUT`.
    async fn wait_for_addr(endpoint: &Endpoint) -> Result<NodeAddr> {
        let poll = async {
            loop {
                let addr = endpoint.addr().await?;
                if !addr.direct_addresses.is_empty() {
                    return Result::<_, anyhow::Error>::Ok(addr);
                }
                tracing::info!("waiting for publish");
                tokio::time::sleep(RETRY_DELAY).await;
            }
        };
        tokio::time::timeout(WAIT_TIMEOUT, poll)
            .await
            .map_err(|_| anyhow!("timed out waiting for direct addresses"))?
    }

    /// Polls until resolving the endpoint's own peer id yields the same
    /// address the endpoint reports for itself.
    ///
    /// Fails if that does not happen within `WAIT_TIMEOUT`.
    async fn wait_for_resolve(endpoint: &Endpoint) -> Result<NodeAddr> {
        let poll = async {
            loop {
                let addr = wait_for_addr(endpoint).await?;
                match endpoint.resolve(endpoint.peer_id()).await {
                    Ok(resolved_addr) if resolved_addr == addr => {
                        return Result::<_, anyhow::Error>::Ok(addr);
                    }
                    // Not resolvable yet, or the published record is stale.
                    _ => {
                        tracing::info!("waiting for publish");
                        tokio::time::sleep(RETRY_DELAY).await;
                    }
                }
            }
        };
        tokio::time::timeout(WAIT_TIMEOUT, poll)
            .await
            .map_err(|_| anyhow!("timed out waiting for the published address to resolve"))?
    }

    #[tokio::test]
    async fn pkarr() -> Result<()> {
        env_logger::try_init().ok();

        let mut builder = Endpoint::builder(ALPN.to_vec());
        builder.relay_map(None);
        let e1 = builder.build().await?;
        wait_for_resolve(&e1).await?;
        Ok(())
    }

    #[tokio::test]
    async fn notify() -> Result<()> {
        env_logger::try_init().ok();

        let (pinger, rx) = Pinger::new();
        let mut builder = ProtocolHandler::builder();
        builder.register_notification_handler(pinger);
        let handler = builder.build();

        let mut builder = Endpoint::builder(ALPN.to_vec());
        builder.handler(handler);
        let e1 = builder.build().await?;
        let p1 = e1.peer_id();

        let builder = Endpoint::builder(ALPN.to_vec());
        let e2 = builder.build().await?;

        let a1 = wait_for_addr(&e1).await?;

        e2.add_address(a1)?;
        e2.notify::<PingPong>(p1, &Ping(42)).await?;
        let ping = rx.await?;
        assert_eq!(ping.0, 42);
        Ok(())
    }

    #[tokio::test]
    async fn request_response() -> Result<()> {
        env_logger::try_init().ok();

        let mut builder = ProtocolHandler::builder();
        builder.register_request_handler(PingPong);
        let handler = builder.build();

        let mut builder = Endpoint::builder(ALPN.to_vec());
        builder.handler(handler);
        let e1 = builder.build().await?;
        let p1 = e1.peer_id();

        let builder = Endpoint::builder(ALPN.to_vec());
        let e2 = builder.build().await?;

        let a1 = wait_for_addr(&e1).await?;

        e2.add_address(a1)?;
        let pong = e2.request::<PingPong>(p1, &Ping(42)).await?;
        assert_eq!(pong.0, 42);
        Ok(())
    }

    #[tokio::test]
    async fn subscribe() -> Result<()> {
        env_logger::try_init().ok();

        let mut builder = ProtocolHandler::builder();
        builder.register_subscription_handler(PingPong);
        let handler = builder.build();

        let mut builder = Endpoint::builder(ALPN.to_vec());
        builder.handler(handler);
        let e1 = builder.build().await?;
        let p1 = e1.peer_id();

        let builder = Endpoint::builder(ALPN.to_vec());
        let e2 = builder.build().await?;

        let a1 = wait_for_addr(&e1).await?;

        e2.add_address(a1)?;
        let mut subscription = e2.subscribe::<PingPong>(p1, &Ping(42)).await?;
        let mut received = 0;
        while let Some(pong) = subscription.next().await? {
            assert_eq!(pong.0, 42);
            received += 1;
        }
        // The handler sends exactly two responses; without this check the
        // test would also pass if the subscription yielded nothing.
        assert_eq!(received, 2);
        Ok(())
    }
}