iroh_ssh/
ssh.rs

1use crate::{Builder, Inner, IrohSsh, api::ClientOptions};
2use std::process::Stdio;
3
4use anyhow::bail;
5use ed25519_dalek::SECRET_KEY_LENGTH;
6use homedir::my_home;
7use iroh::{
8    Endpoint, NodeId, SecretKey, Watcher,
9    endpoint::Connection,
10    protocol::{ProtocolHandler, Router},
11};
12use tokio::{
13    net::{TcpListener, TcpStream},
14    process::{Child, Command},
15};
16
17impl Builder {
18    pub fn new() -> Self {
19        Self {
20            secret_key: SecretKey::generate(rand::rngs::OsRng).to_bytes(),
21            accept_incoming: false,
22            accept_port: None,
23        }
24    }
25
26    pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
27        self.accept_incoming = accept_incoming;
28        self
29    }
30
31    pub fn accept_port(mut self, accept_port: u16) -> Self {
32        self.accept_port = Some(accept_port);
33        self
34    }
35
36    pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
37        self.secret_key = *secret_key;
38        self
39    }
40
41    pub fn dot_ssh_integration(mut self, persist: bool, service: bool) -> Self {
42        if let Ok(secret_key) = dot_ssh(&SecretKey::from_bytes(&self.secret_key), persist, service)
43        {
44            self.secret_key = secret_key.to_bytes();
45        }
46        self
47    }
48
49    pub async fn build(&mut self) -> anyhow::Result<IrohSsh> {
50        // Iroh setup
51        let secret_key = SecretKey::from_bytes(&self.secret_key);
52        let endpoint = Endpoint::builder()
53            .secret_key(secret_key)
54            .discovery_n0()
55            .bind()
56            .await?;
57
58        let _ = endpoint.home_relay().initialized().await?;
59
60        let mut iroh_ssh = IrohSsh {
61            public_key: *endpoint.node_id().as_bytes(),
62            secret_key: self.secret_key,
63            inner: None,
64            ssh_port: self.accept_port.unwrap_or(22),
65        };
66
67        let router = if self.accept_incoming {
68            Router::builder(endpoint.clone()).accept(IrohSsh::ALPN(), iroh_ssh.clone())
69        } else {
70            Router::builder(endpoint.clone())
71        }
72        .spawn();
73
74        iroh_ssh.add_inner(endpoint, router);
75
76        Ok(iroh_ssh)
77    }
78}
79
80impl Default for Builder {
81    fn default() -> Self {
82        Self::new()
83    }
84}
85
86impl IrohSsh {
87    pub fn builder() -> Builder {
88        Builder::new()
89    }
90
91    #[allow(non_snake_case)]
92    pub fn ALPN() -> Vec<u8> {
93        b"/iroh/ssh".to_vec()
94    }
95
96    fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
97        self.inner = Some(Inner { endpoint, router });
98    }
99
100    pub async fn connect(
101        &self,
102        ssh_user: &str,
103        node_id: NodeId,
104        client_options: ClientOptions,
105        execute_command: Vec<String>,
106    ) -> anyhow::Result<Child> {
107        let inner = self.inner.as_ref().expect("inner not set");
108        let conn = inner.endpoint.connect(node_id, &IrohSsh::ALPN()).await?;
109        let listener = TcpListener::bind("127.0.0.1:0").await?;
110        let port = listener.local_addr()?.port();
111
112        tokio::spawn(async move {
113            loop {
114                let _ = handle_next(&listener, &conn).await;
115            }
116
117            async fn handle_next(
118                listener: &TcpListener,
119                conn: &Connection,
120            ) -> Result<(), std::io::Error> {
121                let (mut stream, _) = listener.accept().await?;
122                let (mut iroh_send, mut iroh_recv) = conn.open_bi().await?;
123                tokio::spawn(async move {
124                    let (mut local_read, mut local_write) = stream.split();
125                    let a_to_b =
126                        async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
127                    let b_to_a =
128                        async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
129
130                    tokio::select! {
131                        result = a_to_b => {
132                            let _ = result;
133                        },
134                        result = b_to_a => {
135                            let _ = result;
136                        },
137                    };
138                });
139                Ok(())
140            }
141        });
142        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
143        let mut cmd = &mut Command::new("ssh");
144        cmd = cmd
145            .arg("-tt") // Force pseudo-terminal allocation
146            .arg(format!("{ssh_user}@127.0.0.1"))
147            .arg("-p")
148            .arg(port.to_string())
149            .arg("-o")
150            .arg("StrictHostKeyChecking=no")
151            .arg("-o")
152            .arg("UserKnownHostsFile=/dev/null")
153            .arg("-o")
154            .arg("LogLevel=ERROR"); // Reduce SSH debug output
155
156        if let Some(identity_file) = client_options.identity_file {
157            cmd.arg("-i").arg(identity_file);
158        }
159        if let Some(local_forward) = client_options.local_forward {
160            cmd.arg("-L").arg(local_forward);
161        }
162        if let Some(remote_forward) = client_options.remote_forward {
163            cmd.arg("-R").arg(remote_forward);
164        }
165        cmd.arg(execute_command.join(" "));
166
167        let ssh_process = cmd
168            .stdin(Stdio::inherit())
169            .stdout(Stdio::inherit())
170            .stderr(Stdio::inherit())
171            .spawn()?;
172
173        Ok(ssh_process)
174    }
175
176    pub fn node_id(&self) -> NodeId {
177        self.inner
178            .as_ref()
179            .expect("inner not set")
180            .endpoint
181            .node_id()
182    }
183}
184
185impl ProtocolHandler for IrohSsh {
186    async fn accept(&self, connection: Connection) -> Result<(), iroh::protocol::AcceptError> {
187        let alpn = connection
188            .alpn()
189            .ok_or_else(|| iroh::protocol::AcceptError::NotAllowed {})?;
190        if alpn != IrohSsh::ALPN() {
191            return Err(iroh::protocol::AcceptError::NotAllowed {});
192        }
193
194        let node_id = connection.remote_node_id()?;
195        println!("{}: {node_id} connected", String::from_utf8_lossy(&alpn));
196
197        match connection.accept_bi().await {
198            Ok((mut iroh_send, mut iroh_recv)) => {
199                println!("Accepted bidirectional stream from {node_id}");
200
201                match TcpStream::connect(format!("127.0.0.1:{}", self.ssh_port)).await {
202                    Ok(mut ssh_stream) => {
203                        println!("Connected to local SSH server on port {}", self.ssh_port);
204
205                        let (mut local_read, mut local_write) = ssh_stream.split();
206
207                        let a_to_b =
208                            async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
209                        let b_to_a =
210                            async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
211
212                        tokio::select! {
213                            result = a_to_b => {
214                                println!("SSH->Iroh stream ended: {result:?}");
215                            },
216                            result = b_to_a => {
217                                println!("Iroh->SSH stream ended: {result:?}");
218                            },
219                        };
220                    }
221                    Err(e) => {
222                        println!("Failed to connect to SSH server: {e}");
223                    }
224                }
225            }
226            Err(e) => {
227                println!("Failed to accept bidirectional stream: {e}");
228            }
229        }
230
231        Ok(())
232    }
233}
234
235pub fn dot_ssh(
236    default_secret_key: &SecretKey,
237    persist: bool,
238    _service: bool,
239) -> anyhow::Result<SecretKey> {
240    let distro_home = my_home()?.ok_or_else(|| anyhow::anyhow!("home directory not found"))?;
241    #[allow(unused_mut)]
242    let mut ssh_dir = distro_home.join(".ssh");
243
244    // For now linux services are installed as "sudo'er" so
245    // we need to use the root .ssh directory
246    #[cfg(target_os = "linux")]
247    if _service {
248        ssh_dir = std::path::PathBuf::from("/root/.ssh");
249    }
250
251    // Weird windows System service profile location:
252    // "C:\WINDOWS\system32\config\systemprofile\.ssh"
253    #[cfg(target_os = "windows")]
254    if _service {
255        ssh_dir = std::path::PathBuf::from(r#"C:\WINDOWS\system32\config\systemprofile\.ssh"#);
256    }
257
258    let pub_key = ssh_dir.join("irohssh_ed25519.pub");
259    let priv_key = ssh_dir.join("irohssh_ed25519");
260
261    match (ssh_dir.exists(), persist) {
262        (false, false) => {
263            bail!(
264                "no .ssh folder found in {}, use --persist flag to create it",
265                distro_home.display()
266            )
267        }
268        (false, true) => {
269            std::fs::create_dir_all(&ssh_dir)?;
270            println!("[INFO] created .ssh folder: {}", ssh_dir.display());
271            dot_ssh(default_secret_key, persist, _service)
272        }
273        (true, true) => {
274            // check pub and priv key already exists
275            if pub_key.exists() && priv_key.exists() {
276                // read secret key
277                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
278                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
279                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
280                    Ok(SecretKey::from_bytes(&sk_bytes))
281                } else {
282                    bail!("failed to read secret key from {}", priv_key.display())
283                }
284            } else {
285                let key = default_secret_key.clone();
286                let secret_key = key.secret();
287                let public_key = key.public();
288
289                std::fs::write(pub_key, z32::encode(public_key.as_bytes()))?;
290                std::fs::write(priv_key, z32::encode(secret_key.as_bytes()))?;
291                Ok(key)
292            }
293        }
294        (true, false) => {
295            // check pub and priv key already exists
296            if pub_key.exists() && priv_key.exists() {
297                // read secret key
298                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
299                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
300                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
301                    return Ok(SecretKey::from_bytes(&sk_bytes));
302                }
303            }
304            bail!(
305                "no iroh-ssh keys found in {}, use --persist flag to create it",
306                ssh_dir.display()
307            )
308        }
309    }
310}