// iroh_ping/lib.rs

1use std::{
2    sync::Arc,
3    time::{Duration, Instant},
4};
5
6use iroh::{
7    endpoint::Connection,
8    protocol::{AcceptError, ProtocolHandler},
9    Endpoint, EndpointAddr,
10};
11use iroh_metrics::{Counter, MetricsGroup};
12
/// ALPN identifier of the ping protocol.
///
/// ALPN (application-layer protocol negotiation) values are exchanged during the connection
/// handshake; a connection is only established when both sides offer this exact byte string.
pub const ALPN: &[u8] = "iroh/ping/0".as_bytes();
18
19/// Ping is our protocol struct.
20///
21/// We'll implement [`ProtocolHandler`] on this struct so we can use it with
22/// an [`iroh::protocol::Router`].
23/// It's also fine to keep state in this struct for use across many incoming
24/// connections, in this case we'll keep metrics about the amount of pings we
25/// sent or received.
26#[derive(Debug, Clone)]
27pub struct Ping {
28    metrics: Arc<Metrics>,
29}
30
31impl Default for Ping {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl Ping {
38    /// Creates new ping state.
39    pub fn new() -> Self {
40        Self {
41            metrics: Arc::new(Metrics::default()),
42        }
43    }
44
45    /// Returns a handle to ping metrics.
46    pub fn metrics(&self) -> &Arc<Metrics> {
47        &self.metrics
48    }
49
50    /// Sends a ping on the provided endpoint to a given node address.
51    pub async fn ping(&self, endpoint: &Endpoint, addr: EndpointAddr) -> anyhow::Result<Duration> {
52        // Open a connection to the accepting node
53        let conn = endpoint.connect(addr, ALPN).await?;
54
55        // Open a bidirectional QUIC stream
56        let (mut send, mut recv) = conn.open_bi().await?;
57
58        let start = Instant::now();
59        // Send some data to be pinged
60        send.write_all(b"PING").await?;
61
62        // Signal the end of data for this particular stream
63        send.finish()?;
64
65        // read the response, which must be PONG as bytes
66        let response = recv.read_to_end(4).await?;
67        assert_eq!(&response, b"PONG");
68
69        let ping = start.elapsed();
70
71        // at this point we've successfully pinged, mark the metric
72        self.metrics.pings_sent.inc();
73
74        // Explicitly close the whole connection, as we're the last ones to receive data
75        // and know there's nothing else more to do in the connection.
76        conn.close(0u32.into(), b"bye!");
77
78        Ok(ping)
79    }
80}
81
82impl ProtocolHandler for Ping {
83    /// The `accept` method is called for each incoming connection for our ALPN.
84    ///
85    /// The returned future runs on a newly spawned tokio task, so it can run as long as
86    /// the connection lasts.
87    async fn accept(&self, connection: Connection) -> n0_error::Result<(), AcceptError> {
88        let metrics = self.metrics.clone();
89
90        // We can get the remote's node id from the connection.
91        let node_id = connection.remote_id();
92        println!("accepted connection from {node_id}");
93
94        // Our protocol is a simple request-response protocol, so we expect the
95        // connecting peer to open a single bi-directional stream.
96        let (mut send, mut recv) = connection.accept_bi().await?;
97
98        let req = recv.read_to_end(4).await.map_err(AcceptError::from_err)?;
99        assert_eq!(&req, b"PING");
100
101        // increment count of pings we've received
102        metrics.pings_recv.inc();
103
104        // send back "PONG" bytes
105        send.write_all(b"PONG")
106            .await
107            .map_err(AcceptError::from_err)?;
108
109        // By calling `finish` on the send stream we signal that we will not send anything
110        // further, which makes the receive stream on the other end terminate.
111        send.finish()?;
112
113        // Wait until the remote closes the connection, which it does once it
114        // received the response.
115        connection.closed().await;
116
117        Ok(())
118    }
119}
120
121/// Enum of metrics for the module
122#[derive(Debug, Default, MetricsGroup)]
123#[metrics(name = "ping")]
124pub struct Metrics {
125    /// count of valid ping messages sent
126    pub pings_sent: Counter,
127    /// count of valid ping messages received
128    pub pings_recv: Counter,
129}
130
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use iroh::{protocol::Router, Endpoint};

    use super::*;

    /// Serves [`ALPN`] from a router, pings it twice from a separate endpoint, and checks
    /// that the server and client counters each advance by one per round trip.
    #[tokio::test]
    async fn test_ping() -> Result<()> {
        // Server side: a router dispatching ALPN connections to a `Ping` handler.
        let server_endpoint = Endpoint::builder().bind().await?;
        let server_ping = Ping::new();
        let server_metrics = server_ping.metrics().clone();
        let server_router = Router::builder(server_endpoint)
            .accept(ALPN, server_ping)
            .spawn();
        let server_addr = server_router.endpoint().addr();

        // Client side: a plain endpoint used for outgoing connections only.
        let client_endpoint = Endpoint::builder().bind().await?;
        let client_ping = Ping::new();
        let client_metrics = client_ping.metrics().clone();

        for expected in 1..=2 {
            let rtt = client_ping
                .ping(&client_endpoint, server_addr.clone())
                .await?;
            println!("ping response: {rtt:?}");
            assert_eq!(server_metrics.pings_recv.get(), expected);
            assert_eq!(client_metrics.pings_sent.get(), expected);
        }

        client_endpoint.close().await;
        server_router.shutdown().await?;

        Ok(())
    }
}