iroh_ssh/
ssh.rs

1use crate::{Builder, Inner, IrohSsh};
2use std::process::Stdio;
3
4use anyhow::bail;
5use ed25519_dalek::SECRET_KEY_LENGTH;
6use homedir::my_home;
7use iroh::{
8    endpoint::Connection, protocol::{ProtocolHandler, Router}, Endpoint, NodeId, SecretKey, Watcher
9};
10use tokio::{
11    net::{TcpListener, TcpStream},
12    process::{Child, Command},
13};
14
15impl Builder {
16    pub fn new() -> Self {
17        Self {
18            secret_key: SecretKey::generate(rand::rngs::OsRng).to_bytes(),
19            accept_incoming: false,
20            accept_port: None,
21        }
22    }
23
24    pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
25        self.accept_incoming = accept_incoming;
26        self
27    }
28
29    pub fn accept_port(mut self, accept_port: u16) -> Self {
30        self.accept_port = Some(accept_port);
31        self
32    }
33
34    pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
35        self.secret_key = *secret_key;
36        self
37    }
38
39    pub fn dot_ssh_integration(mut self, persist: bool) -> Self {
40        if let Ok(secret_key) = dot_ssh(&SecretKey::from_bytes(&self.secret_key), persist) {
41            self.secret_key = secret_key.to_bytes();
42        }
43        self
44    }
45
46    pub async fn build(self: &mut Self) -> anyhow::Result<IrohSsh> {
47        // Iroh setup
48        let secret_key = SecretKey::from_bytes(&self.secret_key);
49        let endpoint = Endpoint::builder()
50            .secret_key(secret_key)
51            .discovery_n0()
52            .bind()
53            .await?;
54
55        wait_for_relay(&endpoint).await?;
56
57        let mut iroh_ssh = IrohSsh {
58            public_key: endpoint.node_id().as_bytes().clone(),
59            secret_key: self.secret_key,
60            inner: None,
61            ssh_port: self.accept_port.unwrap_or(22),
62        };
63
64        let router = if self.accept_incoming {
65            Router::builder(endpoint.clone()).accept(&IrohSsh::ALPN(), iroh_ssh.clone())
66        } else {
67            Router::builder(endpoint.clone())
68        }
69        .spawn();
70
71        iroh_ssh.add_inner(endpoint, router);
72
73        Ok(iroh_ssh)
74    }
75}
76
77impl IrohSsh {
78    pub fn new() -> Builder {
79        Builder::new()
80    }
81
82    #[allow(non_snake_case)]
83    pub fn ALPN() -> Vec<u8> {
84        format!("/iroh/ssh").into_bytes()
85    }
86
87    fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
88        self.inner = Some(Inner { endpoint, router });
89    }
90
91    pub async fn connect(&self, ssh_user: &str, node_id: NodeId) -> anyhow::Result<Child> {
92        let inner = self.inner.as_ref().expect("inner not set");
93        let conn = inner.endpoint.connect(node_id, &IrohSsh::ALPN()).await?;
94        let listener = TcpListener::bind("127.0.0.1:0").await?;
95        let port = listener.local_addr()?.port();
96
97        tokio::spawn(async move {
98            loop {
99                match listener.accept().await {
100                    Ok((mut stream, _)) => match conn.open_bi().await {
101                        Ok((mut iroh_send, mut iroh_recv)) => {
102                            tokio::spawn(async move {
103                                let (mut local_read, mut local_write) = stream.split();
104                                let a_to_b = async move {
105                                    tokio::io::copy(&mut local_read, &mut iroh_send).await
106                                };
107                                let b_to_a = async move {
108                                    tokio::io::copy(&mut iroh_recv, &mut local_write).await
109                                };
110
111                                tokio::select! {
112                                    result = a_to_b => {
113                                        let _ = result;
114                                    },
115                                    result = b_to_a => {
116                                        let _ = result;
117                                    },
118                                };
119                            });
120                        }
121                        Err(_) => break,
122                    },
123                    Err(_) => break,
124                }
125            }
126        });
127        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
128        let ssh_process = Command::new("ssh")
129            .arg("-tt") // Force pseudo-terminal allocation
130            .arg(format!("{}@127.0.0.1", ssh_user))
131            .arg("-p")
132            .arg(port.to_string())
133            .arg("-o")
134            .arg("StrictHostKeyChecking=no")
135            .arg("-o")
136            .arg("UserKnownHostsFile=/dev/null")
137            .arg("-o")
138            .arg("LogLevel=ERROR") // Reduce SSH debug output
139            .stdin(Stdio::inherit())
140            .stdout(Stdio::inherit())
141            .stderr(Stdio::inherit())
142            .spawn()?;
143
144        Ok(ssh_process)
145    }
146
147    pub fn node_id(&self) -> NodeId {
148        self.inner
149            .as_ref()
150            .expect("inner not set")
151            .endpoint
152            .node_id()
153    }
154}
155
156impl ProtocolHandler for IrohSsh {
157    async fn accept(
158        &self,
159        connection: Connection,
160    ) -> Result<(), iroh::protocol::AcceptError> {
161
162        let alpn = connection.alpn().ok_or_else(|| iroh::protocol::AcceptError::NotAllowed {  })?;
163        if alpn != IrohSsh::ALPN() {
164            return Err(iroh::protocol::AcceptError::NotAllowed {})
165        }
166
167        let node_id = connection.remote_node_id()?;
168        println!("{}: {node_id} connected", String::from_utf8_lossy(&alpn));
169    
170        match connection.accept_bi().await {
171            Ok((mut iroh_send, mut iroh_recv)) => {
172                println!("Accepted bidirectional stream from {}", node_id);
173
174                match TcpStream::connect(format!("127.0.0.1:{}", self.ssh_port)).await {
175                    Ok(mut ssh_stream) => {
176                        println!("Connected to local SSH server on port {}", self.ssh_port);
177
178                        let (mut local_read, mut local_write) = ssh_stream.split();
179
180                        let a_to_b = async move {
181                            tokio::io::copy(&mut local_read, &mut iroh_send).await
182                        };
183                        let b_to_a = async move {
184                            tokio::io::copy(&mut iroh_recv, &mut local_write).await
185                        };
186
187                        tokio::select! {
188                            result = a_to_b => {
189                                println!("SSH->Iroh stream ended: {:?}", result);
190                            },
191                            result = b_to_a => {
192                                println!("Iroh->SSH stream ended: {:?}", result);
193                            },
194                        };
195                    }
196                    Err(e) => {
197                        println!("Failed to connect to SSH server: {}", e);
198                    }
199                }
200            }
201            Err(e) => {
202                println!("Failed to accept bidirectional stream: {}", e);
203            }
204        }
205
206        Ok(())
207        
208    }
209}
210
211pub fn dot_ssh(default_secret_key: &SecretKey, persist: bool) -> anyhow::Result<SecretKey> {
212    let distro_home = my_home()?.ok_or_else(|| anyhow::anyhow!("home directory not found"))?;
213    #[allow(unused_mut)]
214    let mut ssh_dir = distro_home.join(".ssh");
215
216    // For now linux services are installed as "sudo'er" so
217    // we need to use the root .ssh directory
218    #[cfg(target_os = "linux")]
219    if !ssh_dir.join("irohssh_ed25519.pub").exists() {
220        ssh_dir = std::path::PathBuf::from("/root/.ssh");
221        println!("[INFO] using linux service ssh_dir: {}", ssh_dir.display());
222    }
223
224    // Weird windows System service profile location:
225    // "C:\WINDOWS\system32\config\systemprofile\.ssh"
226    #[cfg(target_os = "windows")]
227    if !ssh_dir.join("irohssh_ed25519.pub").exists() {
228        ssh_dir = std::path::PathBuf::from(r#"C:\WINDOWS\system32\config\systemprofile\.ssh"#);
229        println!("[INFO] using windows service ssh_dir: {}", ssh_dir.display());
230    }
231
232    let pub_key = ssh_dir.join("irohssh_ed25519.pub");
233    let priv_key = ssh_dir.join("irohssh_ed25519");
234
235    match (ssh_dir.exists(), persist) {
236        (false, false) => {
237            bail!("no .ssh folder found in {}, use --persist flag to create it", distro_home.display())
238        }
239        (false, true) => {
240            std::fs::create_dir_all(&ssh_dir)?;
241            println!("[INFO] created .ssh folder: {}", ssh_dir.display());
242            dot_ssh(default_secret_key, persist)
243        }
244        (true, true) => {
245            // check pub and priv key already exists
246            if pub_key.exists() && priv_key.exists() {
247                // read secret key
248                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
249                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
250                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
251                    Ok(SecretKey::from_bytes(&sk_bytes))
252                } else {
253                    bail!("failed to read secret key from {}", priv_key.display())
254                }
255            } else {
256                let key = default_secret_key.clone();
257                let secret_key = key.secret();
258                let public_key = key.public();
259
260                std::fs::write(pub_key, z32::encode(public_key.as_bytes()))?;
261                std::fs::write(priv_key, z32::encode(secret_key.as_bytes()))?;
262                Ok(key)
263            }
264        }
265        (true, false) => {
266            // check pub and priv key already exists
267            if pub_key.exists() && priv_key.exists() {
268                // read secret key
269                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
270                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
271                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
272                    return Ok(SecretKey::from_bytes(&sk_bytes));
273                }
274            }
275            bail!("no iroh-ssh keys found in {}, use --persist flag to create it", ssh_dir.display())
276        }
277    }
278}
279
280async fn wait_for_relay(endpoint: &Endpoint) -> anyhow::Result<()> {
281    while endpoint.home_relay().initialized().await.is_err(){
282        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
283    }
284    Ok(())
285}