1use std::{
2 sync::Arc,
3 time::{Duration, Instant},
4};
5
6use iroh::{
7 endpoint::Connection,
8 protocol::{AcceptError, ProtocolHandler},
9 Endpoint, EndpointAddr,
10};
11use iroh_metrics::{Counter, MetricsGroup};
12
/// ALPN identifier of the ping protocol.
///
/// [`Ping::ping`] passes this value when connecting to a remote endpoint, and the
/// accepting side registers its handler under the same identifier.
pub const ALPN: &[u8] = "iroh/ping/0".as_bytes();
18
19#[derive(Debug, Clone)]
27pub struct Ping {
28 metrics: Arc<Metrics>,
29}
30
31impl Default for Ping {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl Ping {
38 pub fn new() -> Self {
40 Self {
41 metrics: Arc::new(Metrics::default()),
42 }
43 }
44
45 pub fn metrics(&self) -> &Arc<Metrics> {
47 &self.metrics
48 }
49
50 pub async fn ping(&self, endpoint: &Endpoint, addr: EndpointAddr) -> anyhow::Result<Duration> {
52 let conn = endpoint.connect(addr, ALPN).await?;
54
55 let (mut send, mut recv) = conn.open_bi().await?;
57
58 let start = Instant::now();
59 send.write_all(b"PING").await?;
61
62 send.finish()?;
64
65 let response = recv.read_to_end(4).await?;
67 assert_eq!(&response, b"PONG");
68
69 let ping = start.elapsed();
70
71 self.metrics.pings_sent.inc();
73
74 conn.close(0u32.into(), b"bye!");
77
78 Ok(ping)
79 }
80}
81
82impl ProtocolHandler for Ping {
83 async fn accept(&self, connection: Connection) -> n0_error::Result<(), AcceptError> {
88 let metrics = self.metrics.clone();
89
90 let node_id = connection.remote_id();
92 println!("accepted connection from {node_id}");
93
94 let (mut send, mut recv) = connection.accept_bi().await?;
97
98 let req = recv.read_to_end(4).await.map_err(AcceptError::from_err)?;
99 assert_eq!(&req, b"PING");
100
101 metrics.pings_recv.inc();
103
104 send.write_all(b"PONG")
106 .await
107 .map_err(AcceptError::from_err)?;
108
109 send.finish()?;
112
113 connection.closed().await;
116
117 Ok(())
118 }
119}
120
121#[derive(Debug, Default, MetricsGroup)]
123#[metrics(name = "ping")]
124pub struct Metrics {
125 pub pings_sent: Counter,
127 pub pings_recv: Counter,
129}
130
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use iroh::{protocol::Router, Endpoint};

    use super::*;

    /// Pings an in-process server twice and checks that the server's receive
    /// counter and the client's send counter advance by one per round trip.
    #[tokio::test]
    async fn test_ping() -> Result<()> {
        // Accepting side: a router serving the ping protocol under `ALPN`.
        let server_endpoint = Endpoint::builder().bind().await?;
        let server_ping = Ping::new();
        let server_metrics = Arc::clone(server_ping.metrics());
        let server_router = Router::builder(server_endpoint)
            .accept(ALPN, server_ping)
            .spawn();
        let server_addr = server_router.endpoint().addr();

        // Connecting side: a plain endpoint plus its own `Ping` for metrics.
        let client_endpoint = Endpoint::builder().bind().await?;
        let client_ping = Ping::new();
        let client_metrics = Arc::clone(client_ping.metrics());

        // Two consecutive round trips; both counters must match the round number.
        for expected in 1..=2 {
            let res = client_ping
                .ping(&client_endpoint, server_addr.clone())
                .await?;
            println!("ping response: {res:?}");
            assert_eq!(server_metrics.pings_recv.get(), expected);
            assert_eq!(client_metrics.pings_sent.get(), expected);
        }

        client_endpoint.close().await;
        server_router.shutdown().await?;

        Ok(())
    }
}