iroh_ssh/
ssh.rs

/// Unwraps a `Result`-valued expression inside a loop body.
///
/// Evaluates to the `Ok` payload. On `Err` the error value is discarded and the
/// enclosing loop moves on to its next iteration via `continue`, so this macro may
/// only be used inside a `loop`, `while`, or `for` body.
macro_rules! ok_or_continue {
    ($expr:expr) => {
        if let Ok(unwrapped) = $expr {
            unwrapped
        } else {
            continue
        }
    };
}
9
10use crate::{Builder, Inner, IrohSsh};
11use std::{pin::Pin, process::Stdio};
12
13use ed25519_dalek::SECRET_KEY_LENGTH;
14use iroh::{
15    Endpoint, NodeId, SecretKey,
16    endpoint::Connection,
17    protocol::{ProtocolHandler, Router},
18};
19use tokio::{
20    net::{TcpListener, TcpStream},
21    process::{Child, Command},
22};
23
24impl Builder {
25    pub fn new() -> Self {
26        Self {
27            secret_key: SecretKey::generate(rand::rngs::OsRng).to_bytes(),
28            accept_incoming: false,
29            accept_port: None,
30        }
31    }
32
33    pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
34        self.accept_incoming = accept_incoming;
35        self
36    }
37
38    pub fn accept_port(mut self, accept_port: u16) -> Self {
39        self.accept_port = Some(accept_port);
40        self
41    }
42
43    pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
44        self.secret_key = *secret_key;
45        self
46    }
47
48    pub async fn build(self: &mut Self) -> anyhow::Result<IrohSsh> {
49        // Iroh setup
50        let secret_key = SecretKey::from_bytes(&self.secret_key);
51        let endpoint = Endpoint::builder()
52            .secret_key(secret_key)
53            .discovery_n0()
54            .bind()
55            .await?;
56
57        let mut iroh_ssh = IrohSsh {
58            public_key: endpoint.node_id().as_bytes().clone(),
59            secret_key: self.secret_key,
60            inner: None,
61        };
62
63        let router = if self.accept_incoming {
64            Router::builder(endpoint.clone()).accept(&IrohSsh::ALPN(), iroh_ssh.clone())
65        } else {
66            Router::builder(endpoint.clone())
67        }
68        .spawn();
69
70        iroh_ssh.add_inner(endpoint, router);
71
72        if self.accept_incoming && self.accept_port.is_some() {
73            tokio::spawn({
74                let iroh_ssh = iroh_ssh.clone();
75                let accept_port = self.accept_port.expect("accept_port not set");
76                async move {
77                    iroh_ssh._spawn(accept_port).await.expect("spawn failed");
78                }
79            });
80        }
81
82        Ok(iroh_ssh)
83    }
84}
85
86impl IrohSsh {
87    pub fn new() -> Builder {
88        Builder::new()
89    }
90
91    #[allow(non_snake_case)]
92    pub fn ALPN() -> Vec<u8> {
93        format!("/iroh/ssh").into_bytes()
94    }
95
96    fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
97        self.inner = Some(Inner { endpoint, router });
98    }
99
100    pub async fn connect(&self, ssh_user: &str, node_id: NodeId) -> anyhow::Result<Child> {
101        let inner = self.inner.as_ref().expect("inner not set");
102        let conn = inner.endpoint.connect(node_id, &IrohSsh::ALPN()).await?;
103        let listener = TcpListener::bind("127.0.0.1:0").await?;
104        let port = listener.local_addr()?.port();
105
106        tokio::spawn(async move {
107            loop {
108                match listener.accept().await {
109                    Ok((mut stream, _)) => match conn.open_bi().await {
110                        Ok((mut iroh_send, mut iroh_recv)) => {
111                            tokio::spawn(async move {
112                                let (mut local_read, mut local_write) = stream.split();
113                                let a_to_b = async move {
114                                    tokio::io::copy(&mut local_read, &mut iroh_send).await
115                                };
116                                let b_to_a = async move {
117                                    tokio::io::copy(&mut iroh_recv, &mut local_write).await
118                                };
119
120                                tokio::select! {
121                                    result = a_to_b => {
122                                        let _ = result;
123                                    },
124                                    result = b_to_a => {
125                                        let _ = result;
126                                    },
127                                };
128                            });
129                        }
130                        Err(_) => break,
131                    },
132                    Err(_) => break,
133                }
134            }
135        });
136        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
137        let ssh_process = Command::new("ssh")
138            .arg("-tt") // Force pseudo-terminal allocation
139            .arg(format!("{}@127.0.0.1", ssh_user))
140            .arg("-p")
141            .arg(port.to_string())
142            .arg("-o")
143            .arg("StrictHostKeyChecking=no")
144            .arg("-o")
145            .arg("UserKnownHostsFile=/dev/null")
146            .arg("-o")
147            .arg("LogLevel=ERROR") // Reduce SSH debug output
148            .stdin(Stdio::inherit())
149            .stdout(Stdio::inherit())
150            .stderr(Stdio::inherit())
151            .spawn()?;
152
153        Ok(ssh_process)
154    }
155
156    pub fn node_id(&self) -> NodeId {
157        self.inner
158            .as_ref()
159            .expect("inner not set")
160            .endpoint
161            .node_id()
162    }
163
164    async fn _spawn(self, port: u16) -> anyhow::Result<()> {
165        println!("Server listening for iroh connections...");
166
167        while let Some(incoming) = self
168            .inner
169            .clone()
170            .expect("inner not set")
171            .endpoint
172            .accept()
173            .await
174        {
175            let mut connecting = match incoming.accept() {
176                Ok(connecting) => connecting,
177                Err(err) => {
178                    println!("Incoming connection failure: {err:#}");
179                    continue;
180                }
181            };
182
183            let alpn = ok_or_continue!(connecting.alpn().await);
184            let conn = ok_or_continue!(connecting.await);
185            let node_id = ok_or_continue!(conn.remote_node_id());
186
187            println!("{}: {node_id} connected", String::from_utf8_lossy(&alpn));
188
189            tokio::spawn(async move {
190                match conn.accept_bi().await {
191                    Ok((mut iroh_send, mut iroh_recv)) => {
192                        println!("Accepted bidirectional stream from {}", node_id);
193
194                        match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
195                            Ok(mut ssh_stream) => {
196                                println!("Connected to local SSH server on port {}", port);
197
198                                let (mut local_read, mut local_write) = ssh_stream.split();
199
200                                let a_to_b = async move {
201                                    tokio::io::copy(&mut local_read, &mut iroh_send).await
202                                };
203                                let b_to_a = async move {
204                                    tokio::io::copy(&mut iroh_recv, &mut local_write).await
205                                };
206
207                                tokio::select! {
208                                    result = a_to_b => {
209                                        println!("SSH->Iroh stream ended: {:?}", result);
210                                    },
211                                    result = b_to_a => {
212                                        println!("Iroh->SSH stream ended: {:?}", result);
213                                    },
214                                };
215                            }
216                            Err(e) => {
217                                println!("Failed to connect to SSH server: {}", e);
218                            }
219                        }
220                    }
221                    Err(e) => {
222                        println!("Failed to accept bidirectional stream: {}", e);
223                    }
224                }
225            });
226        }
227        Ok(())
228    }
229}
230
231impl ProtocolHandler for IrohSsh {
232    fn accept(
233        &self,
234        conn: Connection,
235    ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
236        let iroh_ssh = self.clone();
237
238        Box::pin(async move {
239            iroh_ssh.accept(conn).await?;
240            Ok(())
241        })
242    }
243}