1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use iroh::endpoint::{RecvStream, SendStream};
6use iroh::{Endpoint, RelayMode, SecretKey, Watcher};
7use iroh_base::ticket::NodeTicket;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio::net::TcpListener;
10
11#[cfg(unix)]
12use tokio::net::UnixListener;
13#[cfg(unix)]
14#[cfg(test)]
15use tokio::net::UnixStream;
16
#[cfg(windows)]
mod win_uds_compat {
    //! Adapters exposing `win_uds`' futures-based Unix-domain sockets through tokio's
    //! `AsyncRead`/`AsyncWrite` traits, shaped like the subset of tokio's
    //! `UnixListener`/`UnixStream` API that the rest of this file relies on.

    use std::io;
    use std::pin::Pin;
    use std::task::{Context, Poll};
    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_util::compat::FuturesAsyncReadCompatExt;
    use win_uds::net::{AsyncListener, AsyncStream};

    /// A connected Windows AF_UNIX stream, bridged to tokio I/O via `tokio_util::compat`.
    pub struct WinUnixStream(tokio_util::compat::Compat<AsyncStream>);

    impl WinUnixStream {
        /// Connects to the socket bound at `path`.
        pub async fn connect<P: AsRef<std::path::Path>>(path: P) -> io::Result<Self> {
            AsyncStream::connect(path)
                .await
                .map(|raw| Self(raw.compat()))
        }
    }

    impl AsyncRead for WinUnixStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let compat = &mut self.get_mut().0;
            Pin::new(compat).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for WinUnixStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let compat = &mut self.get_mut().0;
            Pin::new(compat).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let compat = &mut self.get_mut().0;
            Pin::new(compat).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let compat = &mut self.get_mut().0;
            Pin::new(compat).poll_shutdown(cx)
        }
    }

    /// A Windows AF_UNIX listener that remembers the path it was bound to, so
    /// `local_addr` can report it.
    pub struct WinUnixListener {
        inner: AsyncListener,
        path: std::path::PathBuf,
    }

    impl WinUnixListener {
        /// Binds a listening socket at `path`.
        pub fn bind<P: AsRef<std::path::Path>>(path: P) -> io::Result<Self> {
            let bound_path = path.as_ref().to_path_buf();
            let inner = AsyncListener::bind(path)?;
            Ok(Self {
                inner,
                path: bound_path,
            })
        }

        /// Accepts the next connection.
        ///
        /// The second tuple element is `()` so call sites can destructure the result
        /// exactly like tokio's `UnixListener::accept`.
        pub async fn accept(&self) -> io::Result<(WinUnixStream, ())> {
            let (raw, _peer) = self.inner.accept().await?;
            Ok((WinUnixStream(raw.compat()), ()))
        }

        /// Returns the path this listener was bound to.
        pub fn local_addr(&self) -> io::Result<std::path::PathBuf> {
            Ok(self.path.clone())
        }
    }
}
90
91#[cfg(windows)]
92use win_uds_compat::WinUnixListener as UnixListener;
93#[cfg(windows)]
94pub use win_uds_compat::WinUnixStream;
95
96#[cfg(test)]
97use tokio::net::TcpStream;
98
/// ALPN protocol identifier advertised by the iroh listener and requested by clients.
pub const ALPN: &[u8] = "XS/1.0".as_bytes();

/// Fixed preamble a client writes first on a freshly opened iroh stream; the listener
/// rejects any stream whose first bytes differ from these.
pub const HANDSHAKE: [u8; 5] = [b'x', b's', b'.', b'.', b'!'];
105
/// Reports whether `s` begins like an absolute Windows drive path such as `C:\` or `d:/`.
///
/// Only the first three bytes are inspected: an ASCII letter, a colon, then either
/// separator. Drive-relative forms like `C:foo` and a bare `C:` are not matched.
fn is_windows_path(s: &str) -> bool {
    matches!(
        s.as_bytes(),
        [drive, b':', b'\\' | b'/', ..] if drive.is_ascii_alphabetic()
    )
}
114
115fn get_or_create_secret() -> io::Result<SecretKey> {
118 match std::env::var("IROH_SECRET") {
119 Ok(secret) => {
120 use std::str::FromStr;
121 SecretKey::from_str(&secret).map_err(|e| {
122 io::Error::new(
123 io::ErrorKind::InvalidData,
124 format!("Invalid secret key: {e}"),
125 )
126 })
127 }
128 Err(_) => {
129 let key = SecretKey::generate(rand::rngs::OsRng);
130 tracing::info!(
131 "Generated new secret key: {}",
132 data_encoding::HEXLOWER.encode(&key.to_bytes())
133 );
134 Ok(key)
135 }
136 }
137}
138
139pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
140
141impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
142
143pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
144
145pub struct IrohStream {
146 send_stream: SendStream,
147 recv_stream: RecvStream,
148}
149
150impl IrohStream {
151 pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self {
152 Self {
153 send_stream,
154 recv_stream,
155 }
156 }
157}
158
159impl Drop for IrohStream {
160 fn drop(&mut self) {
161 self.send_stream.reset(0u8.into()).ok();
163 self.recv_stream.stop(0u8.into()).ok();
164
165 tracing::debug!("IrohStream dropped with cleanup");
166 }
167}
168
169impl AsyncRead for IrohStream {
170 fn poll_read(
171 self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 buf: &mut tokio::io::ReadBuf<'_>,
174 ) -> Poll<io::Result<()>> {
175 let this = self.get_mut();
176 match Pin::new(&mut this.recv_stream).poll_read(cx, buf) {
177 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
178 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
179 Poll::Pending => Poll::Pending,
180 }
181 }
182}
183
184impl AsyncWrite for IrohStream {
185 fn poll_write(
186 self: Pin<&mut Self>,
187 cx: &mut Context<'_>,
188 buf: &[u8],
189 ) -> Poll<io::Result<usize>> {
190 let this = self.get_mut();
191 match Pin::new(&mut this.send_stream).poll_write(cx, buf) {
192 Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
193 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
194 Poll::Pending => Poll::Pending,
195 }
196 }
197
198 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199 let this = self.get_mut();
200 match Pin::new(&mut this.send_stream).poll_flush(cx) {
201 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
202 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
203 Poll::Pending => Poll::Pending,
204 }
205 }
206
207 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208 let this = self.get_mut();
209 match Pin::new(&mut this.send_stream).poll_shutdown(cx) {
210 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
211 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
212 Poll::Pending => Poll::Pending,
213 }
214 }
215}
216
217pub enum Listener {
218 Tcp(TcpListener),
219 Unix(UnixListener),
220 Iroh(Endpoint, String), }
222
223impl Listener {
224 pub async fn accept(
225 &mut self,
226 ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
227 match self {
228 Listener::Tcp(listener) => {
229 let (stream, addr) = listener.accept().await?;
230 Ok((Box::new(stream), Some(addr)))
231 }
232 Listener::Unix(listener) => {
233 let (stream, _) = listener.accept().await?;
234 Ok((Box::new(stream), None))
235 }
236 Listener::Iroh(endpoint, _) => {
237 let incoming = endpoint.accept().await.ok_or_else(|| {
239 tracing::error!("No incoming iroh connection available");
240 io::Error::other("No incoming connection")
241 })?;
242
243 let conn = incoming.await.map_err(|e| {
244 tracing::error!("Failed to accept iroh connection: {}", e);
245 io::Error::other(format!("Connection failed: {e}"))
246 })?;
247
248 let remote_node_id = "unknown"; tracing::info!("Got iroh connection from {}", remote_node_id);
250
251 let (send_stream, mut recv_stream) = conn.accept_bi().await.map_err(|e| {
253 tracing::error!(
254 "Failed to accept bidirectional stream from {}: {}",
255 remote_node_id,
256 e
257 );
258 io::Error::other(format!("Failed to accept stream: {e}"))
259 })?;
260
261 tracing::debug!("Accepted bidirectional stream from {}", remote_node_id);
262
263 let mut handshake_buf = [0u8; HANDSHAKE.len()];
265 #[allow(unused_imports)]
266 use tokio::io::AsyncReadExt;
267 recv_stream
268 .read_exact(&mut handshake_buf)
269 .await
270 .map_err(|e| {
271 tracing::error!("Failed to read handshake from {}: {}", remote_node_id, e);
272 io::Error::other(format!("Failed to read handshake: {e}"))
273 })?;
274
275 if handshake_buf != HANDSHAKE {
276 tracing::error!(
277 "Invalid handshake received from {}: expected {:?}, got {:?}",
278 remote_node_id,
279 HANDSHAKE,
280 handshake_buf
281 );
282 return Err(io::Error::new(
283 io::ErrorKind::InvalidData,
284 format!("Invalid handshake from {remote_node_id}"),
285 ));
286 }
287
288 tracing::info!("Handshake verified successfully from {}", remote_node_id);
289
290 let stream = IrohStream::new(send_stream, recv_stream);
291 Ok((Box::new(stream), None))
292 }
293 }
294 }
295
296 pub async fn bind(addr: &str) -> io::Result<Self> {
297 if addr.starts_with("iroh://") {
298 tracing::info!("Binding iroh endpoint");
299
300 let secret_key = get_or_create_secret()?;
301 let endpoint = Endpoint::builder()
302 .alpns(vec![ALPN.to_vec()])
303 .relay_mode(RelayMode::Default)
304 .secret_key(secret_key)
305 .bind()
306 .await
307 .map_err(|e| {
308 tracing::error!("Failed to bind iroh endpoint: {}", e);
309 io::Error::other(format!("Failed to bind endpoint: {e}"))
310 })?;
311
312 tracing::debug!("Iroh endpoint bound successfully");
313
314 endpoint.home_relay().initialized().await;
316 let node_addr = endpoint.node_addr().initialized().await;
317
318 let ticket = NodeTicket::new(node_addr.clone()).to_string();
320
321 tracing::info!("Iroh endpoint ready with node ID: {}", node_addr.node_id);
322 tracing::info!("Iroh ticket: {}", ticket);
323
324 Ok(Listener::Iroh(endpoint, ticket))
325 } else if addr.starts_with('/') || addr.starts_with('.') || is_windows_path(addr) {
326 let _ = std::fs::remove_file(addr);
328 let listener = UnixListener::bind(addr)?;
329 Ok(Listener::Unix(listener))
330 } else {
331 let mut addr = addr.to_owned();
332 if addr.starts_with(':') {
333 addr = format!("127.0.0.1{addr}");
334 };
335 let listener = TcpListener::bind(addr).await?;
336 Ok(Listener::Tcp(listener))
337 }
338 }
339
340 pub fn get_ticket(&self) -> Option<&str> {
341 match self {
342 Listener::Iroh(_, ticket) => Some(ticket),
343 _ => None,
344 }
345 }
346
347 #[cfg(test)]
348 pub async fn connect(&self) -> io::Result<AsyncReadWriteBox> {
349 match self {
350 Listener::Tcp(listener) => {
351 let stream = TcpStream::connect(listener.local_addr()?).await?;
352 Ok(Box::new(stream))
353 }
354 Listener::Unix(listener) => {
355 #[cfg(unix)]
356 {
357 let stream =
358 UnixStream::connect(listener.local_addr()?.as_pathname().unwrap()).await?;
359 Ok(Box::new(stream))
360 }
361 #[cfg(windows)]
362 {
363 let path = listener.local_addr()?;
364 let stream = WinUnixStream::connect(&path).await?;
365 Ok(Box::new(stream))
366 }
367 }
368 Listener::Iroh(_, ticket) => {
369 let secret_key = get_or_create_secret()?;
370
371 let client_endpoint = Endpoint::builder()
373 .alpns(vec![])
374 .relay_mode(RelayMode::Default)
375 .secret_key(secret_key)
376 .bind()
377 .await
378 .map_err(io::Error::other)?;
379
380 let node_ticket: NodeTicket = ticket
382 .parse()
383 .map_err(|e| io::Error::other(format!("Invalid ticket: {}", e)))?;
384 let node_addr = node_ticket.node_addr().clone();
385
386 let conn = client_endpoint
388 .connect(node_addr, ALPN)
389 .await
390 .map_err(io::Error::other)?;
391
392 let (mut send_stream, recv_stream) =
394 conn.open_bi().await.map_err(io::Error::other)?;
395
396 #[allow(unused_imports)]
398 use tokio::io::AsyncWriteExt;
399 send_stream
400 .write_all(&HANDSHAKE)
401 .await
402 .map_err(io::Error::other)?;
403
404 let stream = IrohStream::new(send_stream, recv_stream);
405 Ok(Box::new(stream))
406 }
407 }
408 }
409}
410
411impl std::fmt::Display for Listener {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 match self {
414 Listener::Tcp(listener) => {
415 let addr = listener.local_addr().unwrap();
416 write!(f, "{}:{}", addr.ip(), addr.port())
417 }
418 Listener::Unix(listener) => {
419 #[cfg(unix)]
420 {
421 let addr = listener.local_addr().unwrap();
422 let path = addr.as_pathname().unwrap();
423 write!(f, "{}", path.display())
424 }
425 #[cfg(windows)]
426 {
427 let path = listener.local_addr().unwrap();
428 write!(f, "{}", path.display())
429 }
430 }
431 Listener::Iroh(_, ticket) => {
432 write!(f, "iroh://{ticket}")
433 }
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Round-trip check for one transport:
    /// - binds `addr` and connects a client;
    /// - writes a message from the accepted server side, then drops that side;
    /// - asserts the client reads exactly that message up to EOF.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr).await.unwrap();
        let mut client = listener.connect().await.unwrap();

        let (mut server_side, _peer) = listener.accept().await.unwrap();
        let expected: &[u8] = b"Hello from server!";
        server_side.write_all(expected).await.unwrap();
        // Dropping the server side closes it, which lets `read_to_end` below terminate.
        drop(server_side);

        let mut received = Vec::new();
        client.read_to_end(&mut received).await.unwrap();
        assert_eq!(expected.to_vec(), received);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener(":0").await;
    }

    #[tokio::test]
    async fn test_bind_unix() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");
        exercise_listener(socket_path.to_str().unwrap()).await;
    }

    // NOTE(review): ignored by default, presumably because binding requires reaching iroh
    // relays over the network — confirm.
    #[tokio::test]
    #[ignore]
    async fn test_bind_iroh() {
        exercise_listener("iroh://").await;
    }
}