1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use iroh::endpoint::{RecvStream, SendStream};
6use iroh::{Endpoint, RelayMode, SecretKey, Watcher};
7use iroh_base::ticket::NodeTicket;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio::net::TcpListener;
10
11#[cfg(unix)]
12use tokio::net::UnixListener;
13#[cfg(unix)]
14#[cfg(test)]
15use tokio::net::UnixStream;
16
#[cfg(windows)]
mod win_uds_compat {
    //! Shims that give `win_uds` AF_UNIX sockets the same shape as tokio's `UnixListener` /
    //! `UnixStream`, so the rest of the file can treat Unix and Windows alike.

    use std::io;
    use std::path::{Path, PathBuf};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt};
    use win_uds::net::{AsyncListener, AsyncStream};

    /// AF_UNIX stream adapted from the futures-io traits to tokio's `AsyncRead`/`AsyncWrite`.
    pub struct WinUnixStream(Compat<AsyncStream>);

    impl WinUnixStream {
        /// Opens a connection to the AF_UNIX socket at `path`.
        pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            let stream: AsyncStream = AsyncStream::connect(path).await?;
            Ok(Self::wrap(stream))
        }

        /// Adapts a raw `win_uds` stream to the tokio I/O traits.
        fn wrap(stream: AsyncStream) -> Self {
            Self(stream.compat())
        }
    }

    impl AsyncRead for WinUnixStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            let inner = &mut self.get_mut().0;
            Pin::new(inner).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for WinUnixStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let inner = &mut self.get_mut().0;
            Pin::new(inner).poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let inner = &mut self.get_mut().0;
            Pin::new(inner).poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let inner = &mut self.get_mut().0;
            Pin::new(inner).poll_shutdown(cx)
        }
    }

    /// AF_UNIX listener that also remembers the path it was bound to, which `local_addr`
    /// hands back.
    pub struct WinUnixListener {
        inner: AsyncListener,
        path: PathBuf,
    }

    impl WinUnixListener {
        /// Binds a listener to `path`.
        pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            let path_buf = path.as_ref().to_path_buf();
            let inner = AsyncListener::bind(path)?;
            Ok(Self {
                inner,
                path: path_buf,
            })
        }

        /// Waits for the next client. The `()` fills the peer-address slot so call sites can
        /// destructure the result the same way on every platform.
        pub async fn accept(&self) -> io::Result<(WinUnixStream, ())> {
            let (stream, _addr): (AsyncStream, _) = self.inner.accept().await?;
            Ok((WinUnixStream::wrap(stream), ()))
        }

        /// Returns the path that was passed to [`WinUnixListener::bind`].
        pub fn local_addr(&self) -> io::Result<PathBuf> {
            Ok(self.path.clone())
        }
    }
}
90
91#[cfg(windows)]
92use win_uds_compat::WinUnixListener as UnixListener;
93#[cfg(windows)]
94pub use win_uds_compat::WinUnixStream;
95
96#[cfg(test)]
97use tokio::net::TcpStream;
98
/// ALPN protocol identifier advertised by the iroh listener and requested by its clients.
pub const ALPN: &[u8] = b"XS/1.0";

/// Bytes a client must send first on a newly opened iroh bidirectional stream; the listener
/// rejects the stream unless exactly these bytes arrive.
pub const HANDSHAKE: [u8; 5] = [b'x', b's', b'.', b'.', b'!'];
105
/// Reports whether `s` starts like an absolute Windows drive path: an ASCII drive letter,
/// a colon, then a `\` or `/` separator (e.g. `C:\sock` or `d:/tmp/sock`).
fn is_windows_path(s: &str) -> bool {
    match s.as_bytes() {
        [drive, b':', b'\\' | b'/', ..] => drive.is_ascii_alphabetic(),
        _ => false,
    }
}
114
115fn get_or_create_secret() -> io::Result<SecretKey> {
118 match std::env::var("IROH_SECRET") {
119 Ok(secret) => {
120 use std::str::FromStr;
121 SecretKey::from_str(&secret).map_err(|e| {
122 io::Error::new(
123 io::ErrorKind::InvalidData,
124 format!("Invalid secret key: {e}"),
125 )
126 })
127 }
128 Err(_) => {
129 let key = SecretKey::generate(rand::rngs::OsRng);
130 tracing::info!(
131 "Generated new secret key: {}",
132 data_encoding::HEXLOWER.encode(&key.to_bytes())
133 );
134 Ok(key)
135 }
136 }
137}
138
139pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
140
141impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
142
143pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
144
145pub struct IrohStream {
146 send_stream: SendStream,
147 recv_stream: RecvStream,
148}
149
150impl IrohStream {
151 pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self {
152 Self {
153 send_stream,
154 recv_stream,
155 }
156 }
157}
158
159impl Drop for IrohStream {
160 fn drop(&mut self) {
161 self.send_stream.reset(0u8.into()).ok();
163 self.recv_stream.stop(0u8.into()).ok();
164
165 tracing::debug!("IrohStream dropped with cleanup");
166 }
167}
168
169impl AsyncRead for IrohStream {
170 fn poll_read(
171 self: Pin<&mut Self>,
172 cx: &mut Context<'_>,
173 buf: &mut tokio::io::ReadBuf<'_>,
174 ) -> Poll<io::Result<()>> {
175 let this = self.get_mut();
176 match Pin::new(&mut this.recv_stream).poll_read(cx, buf) {
177 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
178 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
179 Poll::Pending => Poll::Pending,
180 }
181 }
182}
183
184impl AsyncWrite for IrohStream {
185 fn poll_write(
186 self: Pin<&mut Self>,
187 cx: &mut Context<'_>,
188 buf: &[u8],
189 ) -> Poll<io::Result<usize>> {
190 let this = self.get_mut();
191 match Pin::new(&mut this.send_stream).poll_write(cx, buf) {
192 Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
193 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
194 Poll::Pending => Poll::Pending,
195 }
196 }
197
198 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199 let this = self.get_mut();
200 match Pin::new(&mut this.send_stream).poll_flush(cx) {
201 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
202 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
203 Poll::Pending => Poll::Pending,
204 }
205 }
206
207 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208 let this = self.get_mut();
209 match Pin::new(&mut this.send_stream).poll_shutdown(cx) {
210 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
211 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
212 Poll::Pending => Poll::Pending,
213 }
214 }
215}
216
217pub enum Listener {
218 Tcp(TcpListener),
219 Unix(UnixListener),
220 Iroh(Endpoint, String), }
222
223impl Listener {
224 pub async fn accept(
225 &mut self,
226 ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
227 match self {
228 Listener::Tcp(listener) => {
229 let (stream, addr) = listener.accept().await?;
230 Ok((Box::new(stream), Some(addr)))
231 }
232 Listener::Unix(listener) => {
233 let (stream, _) = listener.accept().await?;
234 Ok((Box::new(stream), None))
235 }
236 Listener::Iroh(endpoint, _) => {
237 let incoming = endpoint.accept().await.ok_or_else(|| {
239 tracing::error!("No incoming iroh connection available");
240 io::Error::other("No incoming connection")
241 })?;
242
243 let conn = incoming.await.map_err(|e| {
244 tracing::error!("Failed to accept iroh connection: {}", e);
245 io::Error::other(format!("Connection failed: {e}"))
246 })?;
247
248 let remote_node_id = "unknown"; tracing::info!("Got iroh connection from {}", remote_node_id);
250
251 let (send_stream, mut recv_stream) = conn.accept_bi().await.map_err(|e| {
253 tracing::error!(
254 "Failed to accept bidirectional stream from {}: {}",
255 remote_node_id,
256 e
257 );
258 io::Error::other(format!("Failed to accept stream: {e}"))
259 })?;
260
261 tracing::debug!("Accepted bidirectional stream from {}", remote_node_id);
262
263 let mut handshake_buf = [0u8; HANDSHAKE.len()];
265 #[allow(unused_imports)]
266 use tokio::io::AsyncReadExt;
267 recv_stream
268 .read_exact(&mut handshake_buf)
269 .await
270 .map_err(|e| {
271 tracing::error!("Failed to read handshake from {}: {}", remote_node_id, e);
272 io::Error::other(format!("Failed to read handshake: {e}"))
273 })?;
274
275 if handshake_buf != HANDSHAKE {
276 tracing::error!(
277 "Invalid handshake received from {}: expected {:?}, got {:?}",
278 remote_node_id,
279 HANDSHAKE,
280 handshake_buf
281 );
282 return Err(io::Error::new(
283 io::ErrorKind::InvalidData,
284 format!("Invalid handshake from {remote_node_id}"),
285 ));
286 }
287
288 tracing::info!("Handshake verified successfully from {}", remote_node_id);
289
290 let stream = IrohStream::new(send_stream, recv_stream);
291 Ok((Box::new(stream), None))
292 }
293 }
294 }
295
296 pub async fn bind(addr: &str) -> io::Result<Self> {
297 if addr.starts_with("iroh://") {
298 tracing::info!("Binding iroh endpoint");
299
300 let secret_key = get_or_create_secret()?;
301 let endpoint = Endpoint::builder()
302 .alpns(vec![ALPN.to_vec()])
303 .relay_mode(RelayMode::Default)
304 .secret_key(secret_key)
305 .bind()
306 .await
307 .map_err(|e| {
308 tracing::error!("Failed to bind iroh endpoint: {}", e);
309 io::Error::other(format!("Failed to bind endpoint: {e}"))
310 })?;
311
312 tracing::debug!("Iroh endpoint bound successfully");
313
314 let relay_wait = std::time::Duration::from_secs(5);
318 if tokio::time::timeout(relay_wait, endpoint.home_relay().initialized())
319 .await
320 .is_err()
321 {
322 tracing::warn!(
323 "No iroh home relay after {relay_wait:?}; \
324 issuing a direct-addresses-only ticket"
325 );
326 }
327 let node_addr = endpoint.node_addr().initialized().await;
328
329 let ticket = NodeTicket::new(node_addr.clone()).to_string();
331
332 tracing::info!("Iroh endpoint ready with node ID: {}", node_addr.node_id);
333 tracing::info!("Iroh ticket: {}", ticket);
334
335 Ok(Listener::Iroh(endpoint, ticket))
336 } else if addr.starts_with('/') || addr.starts_with('.') || is_windows_path(addr) {
337 let _ = std::fs::remove_file(addr);
339 let listener = UnixListener::bind(addr)?;
340 Ok(Listener::Unix(listener))
341 } else {
342 let mut addr = addr.to_owned();
343 if addr.starts_with(':') {
344 addr = format!("127.0.0.1{addr}");
345 };
346 let listener = TcpListener::bind(addr).await?;
347 Ok(Listener::Tcp(listener))
348 }
349 }
350
351 pub fn get_ticket(&self) -> Option<&str> {
352 match self {
353 Listener::Iroh(_, ticket) => Some(ticket),
354 _ => None,
355 }
356 }
357
358 #[cfg(test)]
359 pub async fn connect(&self) -> io::Result<AsyncReadWriteBox> {
360 match self {
361 Listener::Tcp(listener) => {
362 let stream = TcpStream::connect(listener.local_addr()?).await?;
363 Ok(Box::new(stream))
364 }
365 Listener::Unix(listener) => {
366 #[cfg(unix)]
367 {
368 let stream =
369 UnixStream::connect(listener.local_addr()?.as_pathname().unwrap()).await?;
370 Ok(Box::new(stream))
371 }
372 #[cfg(windows)]
373 {
374 let path = listener.local_addr()?;
375 let stream = WinUnixStream::connect(&path).await?;
376 Ok(Box::new(stream))
377 }
378 }
379 Listener::Iroh(_, ticket) => {
380 let secret_key = get_or_create_secret()?;
381
382 let client_endpoint = Endpoint::builder()
384 .alpns(vec![])
385 .relay_mode(RelayMode::Default)
386 .secret_key(secret_key)
387 .bind()
388 .await
389 .map_err(io::Error::other)?;
390
391 let node_ticket: NodeTicket = ticket
393 .parse()
394 .map_err(|e| io::Error::other(format!("Invalid ticket: {}", e)))?;
395 let node_addr = node_ticket.node_addr().clone();
396
397 let conn = client_endpoint
399 .connect(node_addr, ALPN)
400 .await
401 .map_err(io::Error::other)?;
402
403 let (mut send_stream, recv_stream) =
405 conn.open_bi().await.map_err(io::Error::other)?;
406
407 #[allow(unused_imports)]
409 use tokio::io::AsyncWriteExt;
410 send_stream
411 .write_all(&HANDSHAKE)
412 .await
413 .map_err(io::Error::other)?;
414
415 let stream = IrohStream::new(send_stream, recv_stream);
416 Ok(Box::new(stream))
417 }
418 }
419 }
420}
421
422impl std::fmt::Display for Listener {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 match self {
425 Listener::Tcp(listener) => {
426 let addr = listener.local_addr().unwrap();
427 write!(f, "{}:{}", addr.ip(), addr.port())
428 }
429 Listener::Unix(listener) => {
430 #[cfg(unix)]
431 {
432 let addr = listener.local_addr().unwrap();
433 let path = addr.as_pathname().unwrap();
434 write!(f, "{}", path.display())
435 }
436 #[cfg(windows)]
437 {
438 let path = listener.local_addr().unwrap();
439 write!(f, "{}", path.display())
440 }
441 }
442 Listener::Iroh(_, ticket) => {
443 write!(f, "iroh://{ticket}")
444 }
445 }
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;

    /// Binds `addr` and connects a client. The server side writes a message and is dropped;
    /// the client must then read exactly that message followed by EOF.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr).await.unwrap();
        let mut client = listener.connect().await.unwrap();

        let (mut serve, _) = listener.accept().await.unwrap();
        let want = b"Hello from server!";
        serve.write_all(want).await.unwrap();
        drop(serve);

        let mut got = Vec::new();
        client.read_to_end(&mut got).await.unwrap();
        assert_eq!(want.to_vec(), got);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener(":0").await;
    }

    /// The TCP `Display` form must be a valid socket address that round-trips to the bound
    /// address.
    #[tokio::test]
    async fn test_display_tcp_round_trips() {
        let listener = Listener::bind(":0").await.unwrap();
        let Listener::Tcp(tcp) = &listener else {
            panic!("\":0\" should bind a TCP listener");
        };
        let shown: std::net::SocketAddr = listener.to_string().parse().unwrap();
        assert_eq!(shown, tcp.local_addr().unwrap());
    }

    #[tokio::test]
    async fn test_bind_unix() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.sock");
        let path = path.to_str().unwrap();
        exercise_listener(path).await;
    }

    #[tokio::test]
    #[ignore = "iroh round-trip uses RelayMode::Default and needs network access"]
    async fn test_bind_iroh() {
        exercise_listener("iroh://").await;
    }
}