iroh_ssh/
ssh.rs

1use crate::{Builder, Inner, IrohSsh};
2use std::{path::PathBuf, process::Stdio};
3
4use anyhow::bail;
5use ed25519_dalek::SECRET_KEY_LENGTH;
6use homedir::my_home;
7use iroh::{
8    endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeId, SecretKey, Watcher
9};
10use tokio::{
11    net::{TcpListener, TcpStream},
12    process::{Child, Command},
13};
14
15impl Builder {
16    pub fn new() -> Self {
17        Self {
18            secret_key: SecretKey::generate(rand::rngs::OsRng).to_bytes(),
19            accept_incoming: false,
20            accept_port: None,
21        }
22    }
23
24    pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
25        self.accept_incoming = accept_incoming;
26        self
27    }
28
29    pub fn accept_port(mut self, accept_port: u16) -> Self {
30        self.accept_port = Some(accept_port);
31        self
32    }
33
34    pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
35        self.secret_key = *secret_key;
36        self
37    }
38
39    pub fn dot_ssh_integration(mut self, persist: bool, service: bool) -> Self {
40        if let Ok(secret_key) = dot_ssh(&SecretKey::from_bytes(&self.secret_key), persist, service) {
41            self.secret_key = secret_key.to_bytes();
42        }
43        self
44    }
45
46    pub async fn build(self: &mut Self) -> anyhow::Result<IrohSsh> {
47        // Iroh setup
48        let secret_key = SecretKey::from_bytes(&self.secret_key);
49        let endpoint = Endpoint::builder()
50            .secret_key(secret_key)
51            .discovery_n0()
52            .bind()
53            .await?;
54
55        wait_for_relay(&endpoint).await?;
56
57        let mut iroh_ssh = IrohSsh {
58            public_key: endpoint.node_id().as_bytes().clone(),
59            secret_key: self.secret_key,
60            inner: None,
61            ssh_port: self.accept_port.unwrap_or(22),
62        };
63
64        let router = if self.accept_incoming {
65            Router::builder(endpoint.clone()).accept(&IrohSsh::ALPN(), iroh_ssh.clone())
66        } else {
67            Router::builder(endpoint.clone())
68        }
69        .spawn();
70
71        iroh_ssh.add_inner(endpoint, router);
72
73        Ok(iroh_ssh)
74    }
75}
76
77impl IrohSsh {
78    pub fn new() -> Builder {
79        Builder::new()
80    }
81
82    #[allow(non_snake_case)]
83    pub fn ALPN() -> Vec<u8> {
84        format!("/iroh/ssh").into_bytes()
85    }
86
87    fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
88        self.inner = Some(Inner { endpoint, router });
89    }
90
91    pub async fn connect(&self, ssh_user: &str, node_id: NodeId, identity_file: Option<PathBuf>) -> anyhow::Result<Child> {
92        let inner = self.inner.as_ref().expect("inner not set");
93        let conn = inner.endpoint.connect(node_id, &IrohSsh::ALPN()).await?;
94        let listener = TcpListener::bind("127.0.0.1:0").await?;
95        let port = listener.local_addr()?.port();
96
97        tokio::spawn(async move {
98            loop {
99                match listener.accept().await {
100                    Ok((mut stream, _)) => match conn.open_bi().await {
101                        Ok((mut iroh_send, mut iroh_recv)) => {
102                            tokio::spawn(async move {
103                                let (mut local_read, mut local_write) = stream.split();
104                                let a_to_b = async move {
105                                    tokio::io::copy(&mut local_read, &mut iroh_send).await
106                                };
107                                let b_to_a = async move {
108                                    tokio::io::copy(&mut iroh_recv, &mut local_write).await
109                                };
110
111                                tokio::select! {
112                                    result = a_to_b => {
113                                        let _ = result;
114                                    },
115                                    result = b_to_a => {
116                                        let _ = result;
117                                    },
118                                };
119                            });
120                        }
121                        Err(_) => break,
122                    },
123                    Err(_) => break,
124                }
125            }
126        });
127        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
128        let mut cmd = &mut Command::new("ssh");
129        cmd = cmd.arg("-tt") // Force pseudo-terminal allocation
130            .arg(format!("{}@127.0.0.1", ssh_user))
131            .arg("-p")
132            .arg(port.to_string())
133            .arg("-o")
134            .arg("StrictHostKeyChecking=no")
135            .arg("-o")
136            .arg("UserKnownHostsFile=/dev/null")
137            .arg("-o")
138            .arg("LogLevel=ERROR"); // Reduce SSH debug output
139
140        if let Some(identity_file) = identity_file {
141            cmd.arg("-i").arg(identity_file);
142        }
143
144        let ssh_process = cmd.stdin(Stdio::inherit())
145            .stdout(Stdio::inherit())
146            .stderr(Stdio::inherit())
147            .spawn()?;
148
149        Ok(ssh_process)
150    }
151
152    pub fn node_id(&self) -> NodeId {
153        self.inner
154            .as_ref()
155            .expect("inner not set")
156            .endpoint
157            .node_id()
158    }
159}
160
161impl ProtocolHandler for IrohSsh {
162    async fn accept(
163        &self,
164        connection: Connection,
165    ) -> Result<(), iroh::protocol::AcceptError> {
166
167        let alpn = connection.alpn().ok_or_else(|| iroh::protocol::AcceptError::NotAllowed {  })?;
168        if alpn != IrohSsh::ALPN() {
169            return Err(iroh::protocol::AcceptError::NotAllowed {})
170        }
171
172        let node_id = connection.remote_node_id()?;
173        println!("{}: {node_id} connected", String::from_utf8_lossy(&alpn));
174    
175        match connection.accept_bi().await {
176            Ok((mut iroh_send, mut iroh_recv)) => {
177                println!("Accepted bidirectional stream from {}", node_id);
178
179                match TcpStream::connect(format!("127.0.0.1:{}", self.ssh_port)).await {
180                    Ok(mut ssh_stream) => {
181                        println!("Connected to local SSH server on port {}", self.ssh_port);
182
183                        let (mut local_read, mut local_write) = ssh_stream.split();
184
185                        let a_to_b = async move {
186                            tokio::io::copy(&mut local_read, &mut iroh_send).await
187                        };
188                        let b_to_a = async move {
189                            tokio::io::copy(&mut iroh_recv, &mut local_write).await
190                        };
191
192                        tokio::select! {
193                            result = a_to_b => {
194                                println!("SSH->Iroh stream ended: {:?}", result);
195                            },
196                            result = b_to_a => {
197                                println!("Iroh->SSH stream ended: {:?}", result);
198                            },
199                        };
200                    }
201                    Err(e) => {
202                        println!("Failed to connect to SSH server: {}", e);
203                    }
204                }
205            }
206            Err(e) => {
207                println!("Failed to accept bidirectional stream: {}", e);
208            }
209        }
210
211        Ok(())
212        
213    }
214}
215
216
217
218pub fn dot_ssh(default_secret_key: &SecretKey, persist: bool, service: bool) -> anyhow::Result<SecretKey> {
219    let distro_home = my_home()?.ok_or_else(|| anyhow::anyhow!("home directory not found"))?;
220    #[allow(unused_mut)]
221    let mut ssh_dir = distro_home.join(".ssh");
222
223    // For now linux services are installed as "sudo'er" so
224    // we need to use the root .ssh directory
225    #[cfg(target_os = "linux")]
226    if service {
227        ssh_dir = std::path::PathBuf::from("/root/.ssh");
228    }
229
230    // Weird windows System service profile location:
231    // "C:\WINDOWS\system32\config\systemprofile\.ssh"
232    #[cfg(target_os = "windows")]
233    if service {
234        ssh_dir = std::path::PathBuf::from(r#"C:\WINDOWS\system32\config\systemprofile\.ssh"#);
235    }
236
237    let pub_key = ssh_dir.join("irohssh_ed25519.pub");
238    let priv_key = ssh_dir.join("irohssh_ed25519");
239
240    match (ssh_dir.exists(), persist) {
241        (false, false) => {
242            bail!("no .ssh folder found in {}, use --persist flag to create it", distro_home.display())
243        }
244        (false, true) => {
245            std::fs::create_dir_all(&ssh_dir)?;
246            println!("[INFO] created .ssh folder: {}", ssh_dir.display());
247            dot_ssh(default_secret_key, persist, service)
248        }
249        (true, true) => {
250            // check pub and priv key already exists
251            if pub_key.exists() && priv_key.exists() {
252                // read secret key
253                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
254                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
255                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
256                    Ok(SecretKey::from_bytes(&sk_bytes))
257                } else {
258                    bail!("failed to read secret key from {}", priv_key.display())
259                }
260            } else {
261                let key = default_secret_key.clone();
262                let secret_key = key.secret();
263                let public_key = key.public();
264
265                std::fs::write(pub_key, z32::encode(public_key.as_bytes()))?;
266                std::fs::write(priv_key, z32::encode(secret_key.as_bytes()))?;
267                Ok(key)
268            }
269        }
270        (true, false) => {
271            // check pub and priv key already exists
272            if pub_key.exists() && priv_key.exists() {
273                // read secret key
274                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
275                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
276                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
277                    return Ok(SecretKey::from_bytes(&sk_bytes));
278                }
279            }
280            bail!("no iroh-ssh keys found in {}, use --persist flag to create it", ssh_dir.display())
281        }
282    }
283}
284
285async fn wait_for_relay(endpoint: &Endpoint) -> anyhow::Result<()> {
286    while endpoint.home_relay().initialized().await.is_err(){
287        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
288    }
289    Ok(())
290}