// iroh_ssh/ssh.rs

/// Unwraps a `Result`-valued expression: evaluates to the `Ok` payload, or executes
/// `continue` on the enclosing loop when the value is an `Err` (the error is discarded).
macro_rules! ok_or_continue {
    ($result:expr) => {
        if let Ok(inner) = $result {
            inner
        } else {
            continue
        }
    };
}
9
10use crate::{Builder, Inner, IrohSsh};
11use std::{pin::Pin, process::Stdio};
12
13use ed25519_dalek::SECRET_KEY_LENGTH;
14use iroh::{
15    Endpoint, NodeId, SecretKey,
16    endpoint::Connection,
17    protocol::{ProtocolHandler, Router},
18};
19use tokio::{
20    net::{TcpListener, TcpStream},
21    process::{Child, Command},
22};
23use homedir::my_home;
24
25
26impl Builder {
27    pub fn new() -> Self {
28        Self {
29            secret_key: SecretKey::generate(rand::rngs::OsRng).to_bytes(),
30            accept_incoming: false,
31            accept_port: None,
32        }
33    }
34
35    pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
36        self.accept_incoming = accept_incoming;
37        self
38    }
39
40    pub fn accept_port(mut self, accept_port: u16) -> Self {
41        self.accept_port = Some(accept_port);
42        self
43    }
44
45    pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
46        self.secret_key = *secret_key;
47        self
48    }
49
50    pub fn dot_ssh_integration(mut self) -> Self {
51        if let Ok(secret_key) = dot_ssh(&SecretKey::from_bytes(&self.secret_key)) {
52            self.secret_key = secret_key.to_bytes();
53        }
54        self
55    }
56
57    pub async fn build(self: &mut Self) -> anyhow::Result<IrohSsh> {
58        // Iroh setup
59        let secret_key = SecretKey::from_bytes(&self.secret_key);
60        let endpoint = Endpoint::builder()
61            .secret_key(secret_key)
62            .discovery_n0()
63            .bind()
64            .await?;
65        
66        wait_for_relay(&endpoint).await?;
67
68        let mut iroh_ssh = IrohSsh {
69            public_key: endpoint.node_id().as_bytes().clone(),
70            secret_key: self.secret_key,
71            inner: None,
72        };
73
74        let router = if self.accept_incoming {
75            Router::builder(endpoint.clone()).accept(&IrohSsh::ALPN(), iroh_ssh.clone())
76        } else {
77            Router::builder(endpoint.clone())
78        }
79        .spawn();
80
81        iroh_ssh.add_inner(endpoint, router);
82
83        if self.accept_incoming && self.accept_port.is_some() {
84            tokio::spawn({
85                let iroh_ssh = iroh_ssh.clone();
86                let accept_port = self.accept_port.expect("accept_port not set");
87                async move {
88                    iroh_ssh._spawn(accept_port).await.expect("spawn failed");
89                }
90            });
91        }
92
93        Ok(iroh_ssh)
94    }
95}
96
97impl IrohSsh {
98    pub fn new() -> Builder {
99        Builder::new()
100    }
101
102    #[allow(non_snake_case)]
103    pub fn ALPN() -> Vec<u8> {
104        format!("/iroh/ssh").into_bytes()
105    }
106
107    fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
108        self.inner = Some(Inner { endpoint, router });
109    }
110
111    pub async fn connect(&self, ssh_user: &str, node_id: NodeId) -> anyhow::Result<Child> {
112        let inner = self.inner.as_ref().expect("inner not set");
113        let conn = inner.endpoint.connect(node_id, &IrohSsh::ALPN()).await?;
114        let listener = TcpListener::bind("127.0.0.1:0").await?;
115        let port = listener.local_addr()?.port();
116
117        tokio::spawn(async move {
118            loop {
119                match listener.accept().await {
120                    Ok((mut stream, _)) => match conn.open_bi().await {
121                        Ok((mut iroh_send, mut iroh_recv)) => {
122                            tokio::spawn(async move {
123                                let (mut local_read, mut local_write) = stream.split();
124                                let a_to_b = async move {
125                                    tokio::io::copy(&mut local_read, &mut iroh_send).await
126                                };
127                                let b_to_a = async move {
128                                    tokio::io::copy(&mut iroh_recv, &mut local_write).await
129                                };
130
131                                tokio::select! {
132                                    result = a_to_b => {
133                                        let _ = result;
134                                    },
135                                    result = b_to_a => {
136                                        let _ = result;
137                                    },
138                                };
139                            });
140                        }
141                        Err(_) => break,
142                    },
143                    Err(_) => break,
144                }
145            }
146        });
147        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
148        let ssh_process = Command::new("ssh")
149            .arg("-tt") // Force pseudo-terminal allocation
150            .arg(format!("{}@127.0.0.1", ssh_user))
151            .arg("-p")
152            .arg(port.to_string())
153            .arg("-o")
154            .arg("StrictHostKeyChecking=no")
155            .arg("-o")
156            .arg("UserKnownHostsFile=/dev/null")
157            .arg("-o")
158            .arg("LogLevel=ERROR") // Reduce SSH debug output
159            .stdin(Stdio::inherit())
160            .stdout(Stdio::inherit())
161            .stderr(Stdio::inherit())
162            .spawn()?;
163
164        Ok(ssh_process)
165    }
166
167    pub fn node_id(&self) -> NodeId {
168        self.inner
169            .as_ref()
170            .expect("inner not set")
171            .endpoint
172            .node_id()
173    }
174
175    async fn _spawn(self, port: u16) -> anyhow::Result<()> {
176        println!("Server listening for iroh connections...");
177
178        while let Some(incoming) = self
179            .inner
180            .clone()
181            .expect("inner not set")
182            .endpoint
183            .accept()
184            .await
185        {
186            let mut connecting = match incoming.accept() {
187                Ok(connecting) => connecting,
188                Err(err) => {
189                    println!("Incoming connection failure: {err:#}");
190                    continue;
191                }
192            };
193
194            let alpn = ok_or_continue!(connecting.alpn().await);
195            let conn = ok_or_continue!(connecting.await);
196            let node_id = ok_or_continue!(conn.remote_node_id());
197
198            println!("{}: {node_id} connected", String::from_utf8_lossy(&alpn));
199
200            tokio::spawn(async move {
201                match conn.accept_bi().await {
202                    Ok((mut iroh_send, mut iroh_recv)) => {
203                        println!("Accepted bidirectional stream from {}", node_id);
204
205                        match TcpStream::connect(format!("127.0.0.1:{}", port)).await {
206                            Ok(mut ssh_stream) => {
207                                println!("Connected to local SSH server on port {}", port);
208
209                                let (mut local_read, mut local_write) = ssh_stream.split();
210
211                                let a_to_b = async move {
212                                    tokio::io::copy(&mut local_read, &mut iroh_send).await
213                                };
214                                let b_to_a = async move {
215                                    tokio::io::copy(&mut iroh_recv, &mut local_write).await
216                                };
217
218                                tokio::select! {
219                                    result = a_to_b => {
220                                        println!("SSH->Iroh stream ended: {:?}", result);
221                                    },
222                                    result = b_to_a => {
223                                        println!("Iroh->SSH stream ended: {:?}", result);
224                                    },
225                                };
226                            }
227                            Err(e) => {
228                                println!("Failed to connect to SSH server: {}", e);
229                            }
230                        }
231                    }
232                    Err(e) => {
233                        println!("Failed to accept bidirectional stream: {}", e);
234                    }
235                }
236            });
237        }
238        Ok(())
239    }
240}
241
242impl ProtocolHandler for IrohSsh {
243    fn accept(
244        &self,
245        conn: Connection,
246    ) -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>> {
247        let iroh_ssh = self.clone();
248
249        Box::pin(async move {
250            iroh_ssh.accept(conn).await?;
251            Ok(())
252        })
253    }
254}
255
256pub fn dot_ssh(default_secret_key: &SecretKey) -> anyhow::Result<SecretKey> {
257    let distro_home = my_home()?.ok_or_else(|| anyhow::anyhow!("home directory not found"))?;
258    let ssh_dir = distro_home.join(".ssh");
259    let pub_key = ssh_dir.join("irohssh_ed25519.pub");
260    let priv_key = ssh_dir.join("irohssh_ed25519");
261
262    // check pub and priv key already exists
263    if pub_key.exists() && priv_key.exists() {
264        // read secret key
265        if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
266            let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
267            sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
268            return Ok(SecretKey::from_bytes(&sk_bytes));
269        }
270    }
271    
272    let key = default_secret_key.clone();
273    let secret_key = key.secret();
274    let public_key = key.public();
275
276    std::fs::write(pub_key, z32::encode(public_key.as_bytes()))?;
277    std::fs::write(priv_key, z32::encode(secret_key.as_bytes()))?;
278
279    Ok(key)
280}
281    
282async fn wait_for_relay(endpoint: &Endpoint) -> anyhow::Result<()> {
283    while endpoint.home_relay().get().is_err() {
284        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
285    }
286    Ok(())
287}