xs/
listener.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use iroh::endpoint::{RecvStream, SendStream};
6use iroh::{Endpoint, RelayMode, SecretKey, Watcher};
7use iroh_base::ticket::NodeTicket;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio::net::{TcpListener, UnixListener};
10#[cfg(test)]
11use tokio::net::{TcpStream, UnixStream};
12
/// The ALPN protocol identifier for the xs protocol.
///
/// Iroh endpoints created in this file advertise this value, and clients pass it when
/// connecting to a listener.
pub const ALPN: &[u8] = b"XS/1.0";

/// The handshake bytes exchanged on a freshly opened iroh stream.
///
/// The connecting side must send this handshake; the listening side must consume (and
/// verify) it before treating the stream as application data.
pub const HANDSHAKE: [u8; 5] = *b"xs..!";
19
20/// Get the secret key or generate a new one.
21/// Uses IROH_SECRET environment variable if available, otherwise generates a new one.
22fn get_or_create_secret() -> io::Result<SecretKey> {
23    match std::env::var("IROH_SECRET") {
24        Ok(secret) => {
25            use std::str::FromStr;
26            SecretKey::from_str(&secret).map_err(|e| {
27                io::Error::new(
28                    io::ErrorKind::InvalidData,
29                    format!("Invalid secret key: {e}"),
30                )
31            })
32        }
33        Err(_) => {
34            let key = SecretKey::generate(rand::rngs::OsRng);
35            tracing::info!(
36                "Generated new secret key: {}",
37                data_encoding::HEXLOWER.encode(&key.to_bytes())
38            );
39            Ok(key)
40        }
41    }
42}
43
44pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
45
46impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
47
48pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
49
50pub struct IrohStream {
51    send_stream: SendStream,
52    recv_stream: RecvStream,
53}
54
55impl IrohStream {
56    pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self {
57        Self {
58            send_stream,
59            recv_stream,
60        }
61    }
62}
63
64impl Drop for IrohStream {
65    fn drop(&mut self) {
66        // Send reset/stop signals to the other side
67        self.send_stream.reset(0u8.into()).ok();
68        self.recv_stream.stop(0u8.into()).ok();
69
70        tracing::debug!("IrohStream dropped with cleanup");
71    }
72}
73
74impl AsyncRead for IrohStream {
75    fn poll_read(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78        buf: &mut tokio::io::ReadBuf<'_>,
79    ) -> Poll<io::Result<()>> {
80        let this = self.get_mut();
81        match Pin::new(&mut this.recv_stream).poll_read(cx, buf) {
82            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
83            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
84            Poll::Pending => Poll::Pending,
85        }
86    }
87}
88
89impl AsyncWrite for IrohStream {
90    fn poll_write(
91        self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        buf: &[u8],
94    ) -> Poll<io::Result<usize>> {
95        let this = self.get_mut();
96        match Pin::new(&mut this.send_stream).poll_write(cx, buf) {
97            Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
98            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
99            Poll::Pending => Poll::Pending,
100        }
101    }
102
103    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
104        let this = self.get_mut();
105        match Pin::new(&mut this.send_stream).poll_flush(cx) {
106            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
107            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
108            Poll::Pending => Poll::Pending,
109        }
110    }
111
112    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113        let this = self.get_mut();
114        match Pin::new(&mut this.send_stream).poll_shutdown(cx) {
115            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
116            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
117            Poll::Pending => Poll::Pending,
118        }
119    }
120}
121
/// A listening socket over one of the transports this file supports.
pub enum Listener {
    /// A bound TCP listener.
    Tcp(TcpListener),
    /// A bound Unix domain socket listener.
    Unix(UnixListener),
    /// An iroh endpoint together with the string form of its node ticket.
    Iroh(Endpoint, String),
}
127
128impl Listener {
129    pub async fn accept(
130        &mut self,
131    ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
132        match self {
133            Listener::Tcp(listener) => {
134                let (stream, addr) = listener.accept().await?;
135                Ok((Box::new(stream), Some(addr)))
136            }
137            Listener::Unix(listener) => {
138                let (stream, _) = listener.accept().await?;
139                Ok((Box::new(stream), None))
140            }
141            Listener::Iroh(endpoint, _) => {
142                // Accept incoming connections
143                let incoming = endpoint.accept().await.ok_or_else(|| {
144                    tracing::error!("No incoming iroh connection available");
145                    io::Error::other("No incoming connection")
146                })?;
147
148                let conn = incoming.await.map_err(|e| {
149                    tracing::error!("Failed to accept iroh connection: {}", e);
150                    io::Error::other(format!("Connection failed: {e}"))
151                })?;
152
153                let remote_node_id = "unknown"; // We'll use a placeholder for now
154                tracing::info!("Got iroh connection from {}", remote_node_id);
155
156                // Wait for the first incoming bidirectional stream
157                let (send_stream, mut recv_stream) = conn.accept_bi().await.map_err(|e| {
158                    tracing::error!(
159                        "Failed to accept bidirectional stream from {}: {}",
160                        remote_node_id,
161                        e
162                    );
163                    io::Error::other(format!("Failed to accept stream: {e}"))
164                })?;
165
166                tracing::debug!("Accepted bidirectional stream from {}", remote_node_id);
167
168                // Read and verify the handshake
169                let mut handshake_buf = [0u8; HANDSHAKE.len()];
170                #[allow(unused_imports)]
171                use tokio::io::AsyncReadExt;
172                recv_stream
173                    .read_exact(&mut handshake_buf)
174                    .await
175                    .map_err(|e| {
176                        tracing::error!("Failed to read handshake from {}: {}", remote_node_id, e);
177                        io::Error::other(format!("Failed to read handshake: {e}"))
178                    })?;
179
180                if handshake_buf != HANDSHAKE {
181                    tracing::error!(
182                        "Invalid handshake received from {}: expected {:?}, got {:?}",
183                        remote_node_id,
184                        HANDSHAKE,
185                        handshake_buf
186                    );
187                    return Err(io::Error::new(
188                        io::ErrorKind::InvalidData,
189                        format!("Invalid handshake from {remote_node_id}"),
190                    ));
191                }
192
193                tracing::info!("Handshake verified successfully from {}", remote_node_id);
194
195                let stream = IrohStream::new(send_stream, recv_stream);
196                Ok((Box::new(stream), None))
197            }
198        }
199    }
200
201    pub async fn bind(addr: &str) -> io::Result<Self> {
202        if addr.starts_with("iroh://") {
203            tracing::info!("Binding iroh endpoint");
204
205            let secret_key = get_or_create_secret()?;
206            let endpoint = Endpoint::builder()
207                .alpns(vec![ALPN.to_vec()])
208                .relay_mode(RelayMode::Default)
209                .secret_key(secret_key)
210                .bind()
211                .await
212                .map_err(|e| {
213                    tracing::error!("Failed to bind iroh endpoint: {}", e);
214                    io::Error::other(format!("Failed to bind endpoint: {e}"))
215                })?;
216
217            tracing::debug!("Iroh endpoint bound successfully");
218
219            // Wait for the endpoint to be fully ready before creating ticket
220            endpoint.home_relay().initialized().await;
221            let node_addr = endpoint.node_addr().initialized().await;
222
223            // Create a proper NodeTicket
224            let ticket = NodeTicket::new(node_addr.clone()).to_string();
225
226            tracing::info!("Iroh endpoint ready with node ID: {}", node_addr.node_id);
227            tracing::info!("Iroh ticket: {}", ticket);
228
229            Ok(Listener::Iroh(endpoint, ticket))
230        } else if addr.starts_with('/') || addr.starts_with('.') {
231            // attempt to remove the socket unconditionally
232            let _ = std::fs::remove_file(addr);
233            let listener = UnixListener::bind(addr)?;
234            Ok(Listener::Unix(listener))
235        } else {
236            let mut addr = addr.to_owned();
237            if addr.starts_with(':') {
238                addr = format!("127.0.0.1{addr}");
239            };
240            let listener = TcpListener::bind(addr).await?;
241            Ok(Listener::Tcp(listener))
242        }
243    }
244
245    pub fn get_ticket(&self) -> Option<&str> {
246        match self {
247            Listener::Iroh(_, ticket) => Some(ticket),
248            _ => None,
249        }
250    }
251
252    #[cfg(test)]
253    pub async fn connect(&self) -> io::Result<AsyncReadWriteBox> {
254        match self {
255            Listener::Tcp(listener) => {
256                let stream = TcpStream::connect(listener.local_addr()?).await?;
257                Ok(Box::new(stream))
258            }
259            Listener::Unix(listener) => {
260                let stream =
261                    UnixStream::connect(listener.local_addr()?.as_pathname().unwrap()).await?;
262                Ok(Box::new(stream))
263            }
264            Listener::Iroh(_, ticket) => {
265                let secret_key = get_or_create_secret()?;
266
267                // Create a client endpoint
268                let client_endpoint = Endpoint::builder()
269                    .alpns(vec![])
270                    .relay_mode(RelayMode::Default)
271                    .secret_key(secret_key)
272                    .bind()
273                    .await
274                    .map_err(io::Error::other)?;
275
276                // Parse ticket to get node address
277                let node_ticket: NodeTicket = ticket
278                    .parse()
279                    .map_err(|e| io::Error::other(format!("Invalid ticket: {}", e)))?;
280                let node_addr = node_ticket.node_addr().clone();
281
282                // Connect to the server
283                let conn = client_endpoint
284                    .connect(node_addr, ALPN)
285                    .await
286                    .map_err(io::Error::other)?;
287
288                // Open bidirectional stream
289                let (mut send_stream, recv_stream) =
290                    conn.open_bi().await.map_err(io::Error::other)?;
291
292                // Send handshake
293                #[allow(unused_imports)]
294                use tokio::io::AsyncWriteExt;
295                send_stream
296                    .write_all(&HANDSHAKE)
297                    .await
298                    .map_err(io::Error::other)?;
299
300                let stream = IrohStream::new(send_stream, recv_stream);
301                Ok(Box::new(stream))
302            }
303        }
304    }
305}
306
307impl std::fmt::Display for Listener {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        match self {
310            Listener::Tcp(listener) => {
311                let addr = listener.local_addr().unwrap();
312                write!(f, "{}:{}", addr.ip(), addr.port())
313            }
314            Listener::Unix(listener) => {
315                let addr = listener.local_addr().unwrap();
316                let path = addr.as_pathname().unwrap();
317                write!(f, "{}", path.display())
318            }
319            Listener::Iroh(_, ticket) => {
320                write!(f, "iroh://{ticket}")
321            }
322        }
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Binds a listener on `addr`, connects a client to it, sends a fixed payload from the
    /// accepted side, and checks that the client reads exactly that payload until EOF.
    async fn exercise_listener(addr: &str) {
        let payload = b"Hello from server!";

        let mut listener = Listener::bind(addr).await.unwrap();
        let mut client_side = listener.connect().await.unwrap();
        let (mut server_side, _) = listener.accept().await.unwrap();

        server_side.write_all(payload).await.unwrap();
        // Dropping the accepted side is what lets the client's read reach end of stream.
        drop(server_side);

        let mut received = Vec::new();
        client_side.read_to_end(&mut received).await.unwrap();
        assert_eq!(payload.to_vec(), received);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener(":0").await;
    }

    #[tokio::test]
    async fn test_bind_unix() {
        let socket_dir = tempfile::tempdir().unwrap();
        let socket_path = socket_dir.path().join("test.sock");
        exercise_listener(socket_path.to_str().unwrap()).await;
    }

    #[tokio::test]
    #[ignore] // Skip by default due to network requirements
    async fn test_bind_iroh() {
        // This test may take longer due to network setup
        exercise_listener("iroh://").await;
    }
}