1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use iroh::endpoint::{RecvStream, SendStream};
6use iroh::{Endpoint, RelayMode, SecretKey, Watcher};
7use iroh_base::ticket::NodeTicket;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio::net::{TcpListener, UnixListener};
10#[cfg(test)]
11use tokio::net::{TcpStream, UnixStream};
12
/// ALPN protocol identifier advertised and required on iroh connections.
pub const ALPN: &[u8] = b"XS/1.0";

/// Bytes a client must send first on the bidirectional stream it opens.
/// `Listener::accept` rejects the stream with `InvalidData` if they differ.
pub const HANDSHAKE: [u8; 5] = *b"xs..!";
19
20fn get_or_create_secret() -> io::Result<SecretKey> {
23 match std::env::var("IROH_SECRET") {
24 Ok(secret) => {
25 use std::str::FromStr;
26 SecretKey::from_str(&secret).map_err(|e| {
27 io::Error::new(
28 io::ErrorKind::InvalidData,
29 format!("Invalid secret key: {e}"),
30 )
31 })
32 }
33 Err(_) => {
34 let key = SecretKey::generate(rand::rngs::OsRng);
35 tracing::info!(
36 "Generated new secret key: {}",
37 data_encoding::HEXLOWER.encode(&key.to_bytes())
38 );
39 Ok(key)
40 }
41 }
42}
43
44pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
45
46impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
47
48pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
49
50pub struct IrohStream {
51 send_stream: SendStream,
52 recv_stream: RecvStream,
53}
54
55impl IrohStream {
56 pub fn new(send_stream: SendStream, recv_stream: RecvStream) -> Self {
57 Self {
58 send_stream,
59 recv_stream,
60 }
61 }
62}
63
64impl Drop for IrohStream {
65 fn drop(&mut self) {
66 self.send_stream.reset(0u8.into()).ok();
68 self.recv_stream.stop(0u8.into()).ok();
69
70 tracing::debug!("IrohStream dropped with cleanup");
71 }
72}
73
74impl AsyncRead for IrohStream {
75 fn poll_read(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 buf: &mut tokio::io::ReadBuf<'_>,
79 ) -> Poll<io::Result<()>> {
80 let this = self.get_mut();
81 match Pin::new(&mut this.recv_stream).poll_read(cx, buf) {
82 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
83 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
84 Poll::Pending => Poll::Pending,
85 }
86 }
87}
88
89impl AsyncWrite for IrohStream {
90 fn poll_write(
91 self: Pin<&mut Self>,
92 cx: &mut Context<'_>,
93 buf: &[u8],
94 ) -> Poll<io::Result<usize>> {
95 let this = self.get_mut();
96 match Pin::new(&mut this.send_stream).poll_write(cx, buf) {
97 Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
98 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
99 Poll::Pending => Poll::Pending,
100 }
101 }
102
103 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
104 let this = self.get_mut();
105 match Pin::new(&mut this.send_stream).poll_flush(cx) {
106 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
107 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
108 Poll::Pending => Poll::Pending,
109 }
110 }
111
112 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113 let this = self.get_mut();
114 match Pin::new(&mut this.send_stream).poll_shutdown(cx) {
115 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
116 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::other(e))),
117 Poll::Pending => Poll::Pending,
118 }
119 }
120}
121
122pub enum Listener {
123 Tcp(TcpListener),
124 Unix(UnixListener),
125 Iroh(Endpoint, String), }
127
128impl Listener {
129 pub async fn accept(
130 &mut self,
131 ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
132 match self {
133 Listener::Tcp(listener) => {
134 let (stream, addr) = listener.accept().await?;
135 Ok((Box::new(stream), Some(addr)))
136 }
137 Listener::Unix(listener) => {
138 let (stream, _) = listener.accept().await?;
139 Ok((Box::new(stream), None))
140 }
141 Listener::Iroh(endpoint, _) => {
142 let incoming = endpoint.accept().await.ok_or_else(|| {
144 tracing::error!("No incoming iroh connection available");
145 io::Error::other("No incoming connection")
146 })?;
147
148 let conn = incoming.await.map_err(|e| {
149 tracing::error!("Failed to accept iroh connection: {}", e);
150 io::Error::other(format!("Connection failed: {e}"))
151 })?;
152
153 let remote_node_id = "unknown"; tracing::info!("Got iroh connection from {}", remote_node_id);
155
156 let (send_stream, mut recv_stream) = conn.accept_bi().await.map_err(|e| {
158 tracing::error!(
159 "Failed to accept bidirectional stream from {}: {}",
160 remote_node_id,
161 e
162 );
163 io::Error::other(format!("Failed to accept stream: {e}"))
164 })?;
165
166 tracing::debug!("Accepted bidirectional stream from {}", remote_node_id);
167
168 let mut handshake_buf = [0u8; HANDSHAKE.len()];
170 #[allow(unused_imports)]
171 use tokio::io::AsyncReadExt;
172 recv_stream
173 .read_exact(&mut handshake_buf)
174 .await
175 .map_err(|e| {
176 tracing::error!("Failed to read handshake from {}: {}", remote_node_id, e);
177 io::Error::other(format!("Failed to read handshake: {e}"))
178 })?;
179
180 if handshake_buf != HANDSHAKE {
181 tracing::error!(
182 "Invalid handshake received from {}: expected {:?}, got {:?}",
183 remote_node_id,
184 HANDSHAKE,
185 handshake_buf
186 );
187 return Err(io::Error::new(
188 io::ErrorKind::InvalidData,
189 format!("Invalid handshake from {remote_node_id}"),
190 ));
191 }
192
193 tracing::info!("Handshake verified successfully from {}", remote_node_id);
194
195 let stream = IrohStream::new(send_stream, recv_stream);
196 Ok((Box::new(stream), None))
197 }
198 }
199 }
200
201 pub async fn bind(addr: &str) -> io::Result<Self> {
202 if addr.starts_with("iroh://") {
203 tracing::info!("Binding iroh endpoint");
204
205 let secret_key = get_or_create_secret()?;
206 let endpoint = Endpoint::builder()
207 .alpns(vec![ALPN.to_vec()])
208 .relay_mode(RelayMode::Default)
209 .secret_key(secret_key)
210 .bind()
211 .await
212 .map_err(|e| {
213 tracing::error!("Failed to bind iroh endpoint: {}", e);
214 io::Error::other(format!("Failed to bind endpoint: {e}"))
215 })?;
216
217 tracing::debug!("Iroh endpoint bound successfully");
218
219 endpoint.home_relay().initialized().await;
221 let node_addr = endpoint.node_addr().initialized().await;
222
223 let ticket = NodeTicket::new(node_addr.clone()).to_string();
225
226 tracing::info!("Iroh endpoint ready with node ID: {}", node_addr.node_id);
227 tracing::info!("Iroh ticket: {}", ticket);
228
229 Ok(Listener::Iroh(endpoint, ticket))
230 } else if addr.starts_with('/') || addr.starts_with('.') {
231 let _ = std::fs::remove_file(addr);
233 let listener = UnixListener::bind(addr)?;
234 Ok(Listener::Unix(listener))
235 } else {
236 let mut addr = addr.to_owned();
237 if addr.starts_with(':') {
238 addr = format!("127.0.0.1{addr}");
239 };
240 let listener = TcpListener::bind(addr).await?;
241 Ok(Listener::Tcp(listener))
242 }
243 }
244
245 pub fn get_ticket(&self) -> Option<&str> {
246 match self {
247 Listener::Iroh(_, ticket) => Some(ticket),
248 _ => None,
249 }
250 }
251
252 #[cfg(test)]
253 pub async fn connect(&self) -> io::Result<AsyncReadWriteBox> {
254 match self {
255 Listener::Tcp(listener) => {
256 let stream = TcpStream::connect(listener.local_addr()?).await?;
257 Ok(Box::new(stream))
258 }
259 Listener::Unix(listener) => {
260 let stream =
261 UnixStream::connect(listener.local_addr()?.as_pathname().unwrap()).await?;
262 Ok(Box::new(stream))
263 }
264 Listener::Iroh(_, ticket) => {
265 let secret_key = get_or_create_secret()?;
266
267 let client_endpoint = Endpoint::builder()
269 .alpns(vec![])
270 .relay_mode(RelayMode::Default)
271 .secret_key(secret_key)
272 .bind()
273 .await
274 .map_err(io::Error::other)?;
275
276 let node_ticket: NodeTicket = ticket
278 .parse()
279 .map_err(|e| io::Error::other(format!("Invalid ticket: {}", e)))?;
280 let node_addr = node_ticket.node_addr().clone();
281
282 let conn = client_endpoint
284 .connect(node_addr, ALPN)
285 .await
286 .map_err(io::Error::other)?;
287
288 let (mut send_stream, recv_stream) =
290 conn.open_bi().await.map_err(io::Error::other)?;
291
292 #[allow(unused_imports)]
294 use tokio::io::AsyncWriteExt;
295 send_stream
296 .write_all(&HANDSHAKE)
297 .await
298 .map_err(io::Error::other)?;
299
300 let stream = IrohStream::new(send_stream, recv_stream);
301 Ok(Box::new(stream))
302 }
303 }
304 }
305}
306
307impl std::fmt::Display for Listener {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 match self {
310 Listener::Tcp(listener) => {
311 let addr = listener.local_addr().unwrap();
312 write!(f, "{}:{}", addr.ip(), addr.port())
313 }
314 Listener::Unix(listener) => {
315 let addr = listener.local_addr().unwrap();
316 let path = addr.as_pathname().unwrap();
317 write!(f, "{}", path.display())
318 }
319 Listener::Iroh(_, ticket) => {
320 write!(f, "iroh://{ticket}")
321 }
322 }
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Binds `addr` and connects a client. The server writes a greeting and
    /// drops its end, and the client must read back exactly that greeting
    /// before end-of-stream.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr).await.unwrap();
        let mut client = listener.connect().await.unwrap();

        let (mut server_side, _) = listener.accept().await.unwrap();
        let greeting: &[u8] = b"Hello from server!";
        server_side.write_all(greeting).await.unwrap();
        // Closing the server end is what lets `read_to_end` below terminate.
        drop(server_side);

        let mut received = Vec::new();
        client.read_to_end(&mut received).await.unwrap();
        assert_eq!(received, greeting);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener(":0").await;
    }

    #[tokio::test]
    async fn test_bind_unix() {
        // `dir` must outlive the listener so the socket path stays valid.
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join("test.sock");
        exercise_listener(socket_path.to_str().unwrap()).await;
    }

    // Ignored by default: binding with the default relay mode waits for a
    // home relay, which presumably needs network access.
    #[tokio::test]
    #[ignore]
    async fn test_bind_iroh() {
        exercise_listener("iroh://").await;
    }
}