xs/
listener.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use iroh::endpoint::{RecvStream, SendStream};
6use iroh::{Endpoint, RelayMode, SecretKey, Watcher};
7use iroh_base::ticket::NodeTicket;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio::net::TcpListener;
10
11#[cfg(unix)]
12use tokio::net::UnixListener;
13#[cfg(unix)]
14#[cfg(test)]
15use tokio::net::UnixStream;
16
#[cfg(windows)]
mod win_uds_compat {
    //! Adapters that let the `win_uds` AF_UNIX socket types stand in for tokio's
    //! Unix-socket types on Windows.

    use std::io;
    use std::path::{Path, PathBuf};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt};
    use win_uds::net::{AsyncListener, AsyncStream};

    /// A connected Windows Unix-domain socket exposed through tokio's
    /// `AsyncRead`/`AsyncWrite` (via `tokio_util`'s futures compatibility layer).
    pub struct WinUnixStream(Compat<AsyncStream>);

    impl WinUnixStream {
        /// Connect to the Unix-domain socket at `path`.
        pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            AsyncStream::connect(path)
                .await
                .map(|raw| Self(raw.compat()))
        }
    }

    impl AsyncRead for WinUnixStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let inner = &mut self.0;
            Pin::new(inner).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for WinUnixStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let inner = &mut self.0;
            Pin::new(inner).poll_write(cx, buf)
        }

        fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let inner = &mut self.0;
            Pin::new(inner).poll_flush(cx)
        }

        fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let inner = &mut self.0;
            Pin::new(inner).poll_shutdown(cx)
        }
    }

    /// A listening Windows Unix-domain socket that remembers the path it was bound to,
    /// so it can report it from `local_addr`.
    pub struct WinUnixListener {
        inner: AsyncListener,
        path: PathBuf,
    }

    impl WinUnixListener {
        /// Bind a listener at `path`.
        pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            let path = path.as_ref().to_path_buf();
            let inner = AsyncListener::bind(&path)?;
            Ok(Self { inner, path })
        }

        /// Accept one connection. The unit in the tuple stands in for a peer address,
        /// which is discarded.
        pub async fn accept(&self) -> io::Result<(WinUnixStream, ())> {
            let (raw, _peer) = self.inner.accept().await?;
            Ok((WinUnixStream(raw.compat()), ()))
        }

        /// Return the path this listener was bound to.
        pub fn local_addr(&self) -> io::Result<PathBuf> {
            Ok(self.path.clone())
        }
    }
}
90
91#[cfg(windows)]
92use win_uds_compat::WinUnixListener as UnixListener;
93#[cfg(windows)]
94pub use win_uds_compat::WinUnixStream;
95
96#[cfg(test)]
97use tokio::net::TcpStream;
98
/// ALPN protocol identifier used by xs endpoints and clients when talking over iroh.
pub const ALPN: &[u8] = "XS/1.0".as_bytes();

/// Fixed bytes that open every xs iroh stream.
/// The connecting side must send this handshake; the listening side must consume and verify it.
pub const HANDSHAKE: [u8; 5] = [b'x', b's', b'.', b'.', b'!'];
105
/// Returns `true` when `s` starts like a Windows drive-absolute path: an ASCII letter,
/// a colon, then a back- or forward slash (e.g. `C:\data` or `d:/tmp`).
fn is_windows_path(s: &str) -> bool {
    match s.as_bytes() {
        [drive, b':', b'\\' | b'/', ..] => drive.is_ascii_alphabetic(),
        _ => false,
    }
}
114
115/// Get the secret key or generate a new one.
116/// Uses IROH_SECRET environment variable if available, otherwise generates a new one.
117fn get_or_create_secret() -> io::Result<SecretKey> {
118    match std::env::var("IROH_SECRET") {
119        Ok(secret) => {
120            use std::str::FromStr;
121            SecretKey::from_str(&secret).map_err(|e| {
122                io::Error::new(
123                    io::ErrorKind::InvalidData,
124                    format!("Invalid secret key: {e}"),
125                )
126            })
127        }
128        Err(_) => {
129            let key = SecretKey::generate(rand::rngs::OsRng);
130            tracing::info!(
131                "Generated new secret key: {}",
132                data_encoding::HEXLOWER.encode(&key.to_bytes())
133            );
134            Ok(key)
135        }
136    }
137}
138
139pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
140
141impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
142
143pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
144
145pub struct IrohStream {
146    send_stream: SendStream,
147    recv_stream: RecvStream,
148}
149
150impl IrohStream {
151    pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self {
152        Self {
153            send_stream,
154            recv_stream,
155        }
156    }
157}
158
159impl Drop for IrohStream {
160    fn drop(&mut self) {
161        // Send reset/stop signals to the other side
162        self.send_stream.reset(0u8.into()).ok();
163        self.recv_stream.stop(0u8.into()).ok();
164
165        tracing::debug!("IrohStream dropped with cleanup");
166    }
167}
168
169impl AsyncRead for IrohStream {
170    fn poll_read(
171        self: Pin<&mut Self>,
172        cx: &mut Context<'_>,
173        buf: &mut tokio::io::ReadBuf<'_>,
174    ) -> Poll<io::Result<()>> {
175        let this = self.get_mut();
176        match Pin::new(&mut this.recv_stream).poll_read(cx, buf) {
177            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
178            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
179            Poll::Pending => Poll::Pending,
180        }
181    }
182}
183
184impl AsyncWrite for IrohStream {
185    fn poll_write(
186        self: Pin<&mut Self>,
187        cx: &mut Context<'_>,
188        buf: &[u8],
189    ) -> Poll<io::Result<usize>> {
190        let this = self.get_mut();
191        match Pin::new(&mut this.send_stream).poll_write(cx, buf) {
192            Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
193            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
194            Poll::Pending => Poll::Pending,
195        }
196    }
197
198    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199        let this = self.get_mut();
200        match Pin::new(&mut this.send_stream).poll_flush(cx) {
201            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
202            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
203            Poll::Pending => Poll::Pending,
204        }
205    }
206
207    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208        let this = self.get_mut();
209        match Pin::new(&mut this.send_stream).poll_shutdown(cx) {
210            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
211            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
212            Poll::Pending => Poll::Pending,
213        }
214    }
215}
216
217pub enum Listener {
218    Tcp(TcpListener),
219    Unix(UnixListener),
220    Iroh(Endpoint, String), // Endpoint and ticket
221}
222
223impl Listener {
224    pub async fn accept(
225        &mut self,
226    ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
227        match self {
228            Listener::Tcp(listener) => {
229                let (stream, addr) = listener.accept().await?;
230                Ok((Box::new(stream), Some(addr)))
231            }
232            Listener::Unix(listener) => {
233                let (stream, _) = listener.accept().await?;
234                Ok((Box::new(stream), None))
235            }
236            Listener::Iroh(endpoint, _) => {
237                // Accept incoming connections
238                let incoming = endpoint.accept().await.ok_or_else(|| {
239                    tracing::error!("No incoming iroh connection available");
240                    io::Error::other("No incoming connection")
241                })?;
242
243                let conn = incoming.await.map_err(|e| {
244                    tracing::error!("Failed to accept iroh connection: {}", e);
245                    io::Error::other(format!("Connection failed: {e}"))
246                })?;
247
248                let remote_node_id = "unknown"; // We'll use a placeholder for now
249                tracing::info!("Got iroh connection from {}", remote_node_id);
250
251                // Wait for the first incoming bidirectional stream
252                let (send_stream, mut recv_stream) = conn.accept_bi().await.map_err(|e| {
253                    tracing::error!(
254                        "Failed to accept bidirectional stream from {}: {}",
255                        remote_node_id,
256                        e
257                    );
258                    io::Error::other(format!("Failed to accept stream: {e}"))
259                })?;
260
261                tracing::debug!("Accepted bidirectional stream from {}", remote_node_id);
262
263                // Read and verify the handshake
264                let mut handshake_buf = [0u8; HANDSHAKE.len()];
265                #[allow(unused_imports)]
266                use tokio::io::AsyncReadExt;
267                recv_stream
268                    .read_exact(&mut handshake_buf)
269                    .await
270                    .map_err(|e| {
271                        tracing::error!("Failed to read handshake from {}: {}", remote_node_id, e);
272                        io::Error::other(format!("Failed to read handshake: {e}"))
273                    })?;
274
275                if handshake_buf != HANDSHAKE {
276                    tracing::error!(
277                        "Invalid handshake received from {}: expected {:?}, got {:?}",
278                        remote_node_id,
279                        HANDSHAKE,
280                        handshake_buf
281                    );
282                    return Err(io::Error::new(
283                        io::ErrorKind::InvalidData,
284                        format!("Invalid handshake from {remote_node_id}"),
285                    ));
286                }
287
288                tracing::info!("Handshake verified successfully from {}", remote_node_id);
289
290                let stream = IrohStream::new(send_stream, recv_stream);
291                Ok((Box::new(stream), None))
292            }
293        }
294    }
295
296    pub async fn bind(addr: &str) -> io::Result<Self> {
297        if addr.starts_with("iroh://") {
298            tracing::info!("Binding iroh endpoint");
299
300            let secret_key = get_or_create_secret()?;
301            let endpoint = Endpoint::builder()
302                .alpns(vec![ALPN.to_vec()])
303                .relay_mode(RelayMode::Default)
304                .secret_key(secret_key)
305                .bind()
306                .await
307                .map_err(|e| {
308                    tracing::error!("Failed to bind iroh endpoint: {}", e);
309                    io::Error::other(format!("Failed to bind endpoint: {e}"))
310                })?;
311
312            tracing::debug!("Iroh endpoint bound successfully");
313
314            // Wait for the endpoint to be fully ready before creating ticket
315            endpoint.home_relay().initialized().await;
316            let node_addr = endpoint.node_addr().initialized().await;
317
318            // Create a proper NodeTicket
319            let ticket = NodeTicket::new(node_addr.clone()).to_string();
320
321            tracing::info!("Iroh endpoint ready with node ID: {}", node_addr.node_id);
322            tracing::info!("Iroh ticket: {}", ticket);
323
324            Ok(Listener::Iroh(endpoint, ticket))
325        } else if addr.starts_with('/') || addr.starts_with('.') || is_windows_path(addr) {
326            // attempt to remove the socket unconditionally
327            let _ = std::fs::remove_file(addr);
328            let listener = UnixListener::bind(addr)?;
329            Ok(Listener::Unix(listener))
330        } else {
331            let mut addr = addr.to_owned();
332            if addr.starts_with(':') {
333                addr = format!("127.0.0.1{addr}");
334            };
335            let listener = TcpListener::bind(addr).await?;
336            Ok(Listener::Tcp(listener))
337        }
338    }
339
340    pub fn get_ticket(&self) -> Option<&str> {
341        match self {
342            Listener::Iroh(_, ticket) => Some(ticket),
343            _ => None,
344        }
345    }
346
347    #[cfg(test)]
348    pub async fn connect(&self) -> io::Result<AsyncReadWriteBox> {
349        match self {
350            Listener::Tcp(listener) => {
351                let stream = TcpStream::connect(listener.local_addr()?).await?;
352                Ok(Box::new(stream))
353            }
354            Listener::Unix(listener) => {
355                #[cfg(unix)]
356                {
357                    let stream =
358                        UnixStream::connect(listener.local_addr()?.as_pathname().unwrap()).await?;
359                    Ok(Box::new(stream))
360                }
361                #[cfg(windows)]
362                {
363                    let path = listener.local_addr()?;
364                    let stream = WinUnixStream::connect(&path).await?;
365                    Ok(Box::new(stream))
366                }
367            }
368            Listener::Iroh(_, ticket) => {
369                let secret_key = get_or_create_secret()?;
370
371                // Create a client endpoint
372                let client_endpoint = Endpoint::builder()
373                    .alpns(vec![])
374                    .relay_mode(RelayMode::Default)
375                    .secret_key(secret_key)
376                    .bind()
377                    .await
378                    .map_err(io::Error::other)?;
379
380                // Parse ticket to get node address
381                let node_ticket: NodeTicket = ticket
382                    .parse()
383                    .map_err(|e| io::Error::other(format!("Invalid ticket: {}", e)))?;
384                let node_addr = node_ticket.node_addr().clone();
385
386                // Connect to the server
387                let conn = client_endpoint
388                    .connect(node_addr, ALPN)
389                    .await
390                    .map_err(io::Error::other)?;
391
392                // Open bidirectional stream
393                let (mut send_stream, recv_stream) =
394                    conn.open_bi().await.map_err(io::Error::other)?;
395
396                // Send handshake
397                #[allow(unused_imports)]
398                use tokio::io::AsyncWriteExt;
399                send_stream
400                    .write_all(&HANDSHAKE)
401                    .await
402                    .map_err(io::Error::other)?;
403
404                let stream = IrohStream::new(send_stream, recv_stream);
405                Ok(Box::new(stream))
406            }
407        }
408    }
409}
410
411impl std::fmt::Display for Listener {
412    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        match self {
414            Listener::Tcp(listener) => {
415                let addr = listener.local_addr().unwrap();
416                write!(f, "{}:{}", addr.ip(), addr.port())
417            }
418            Listener::Unix(listener) => {
419                #[cfg(unix)]
420                {
421                    let addr = listener.local_addr().unwrap();
422                    let path = addr.as_pathname().unwrap();
423                    write!(f, "{}", path.display())
424                }
425                #[cfg(windows)]
426                {
427                    let path = listener.local_addr().unwrap();
428                    write!(f, "{}", path.display())
429                }
430            }
431            Listener::Iroh(_, ticket) => {
432                write!(f, "iroh://{ticket}")
433            }
434        }
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Bind `addr` and connect a client. Send a message from the accepted server side,
    /// then check that the client reads exactly that message before end-of-stream.
    async fn exercise_listener(addr: &str) {
        const MESSAGE: &[u8; 18] = b"Hello from server!";

        let mut listener = Listener::bind(addr).await.unwrap();
        let mut client = listener.connect().await.unwrap();

        let mut server_side = listener.accept().await.unwrap().0;
        server_side.write_all(MESSAGE).await.unwrap();
        // Dropping the server side is what signals end-of-stream to the client.
        drop(server_side);

        let mut received = Vec::new();
        client.read_to_end(&mut received).await.unwrap();
        assert_eq!(MESSAGE.to_vec(), received);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener(":0").await;
    }

    #[tokio::test]
    async fn test_bind_unix() {
        // `dir` must stay alive for the whole test so the socket directory exists.
        let dir = tempfile::tempdir().unwrap();
        let socket = dir.path().join("test.sock");
        exercise_listener(socket.to_str().unwrap()).await;
    }

    #[tokio::test]
    #[ignore] // Needs network access (relay/discovery), so it is skipped by default.
    async fn test_bind_iroh() {
        exercise_listener("iroh://").await;
    }
}