Skip to main content

iroh_ssh/
ssh.rs

1use crate::{Builder, Inner, IrohSsh, cli::SshOpts};
2use std::{ffi::OsString, io, path::Path, process::Stdio, time::Duration};
3
4use anyhow::bail;
5use ed25519_dalek::SECRET_KEY_LENGTH;
6use homedir::my_home;
7use regex::Regex;
8use std::sync::Arc;
9
10use iroh::{
11    Endpoint, EndpointId, RelayConfig, RelayUrl, SecretKey,
12    endpoint::{Connection, RelayMode},
13    protocol::{ProtocolHandler, Router},
14};
15use tokio::{
16    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
17    net::TcpStream,
18    process::{Child, Command},
19};
20
21impl Builder {
22    pub fn new() -> Self {
23        Self {
24            secret_key: SecretKey::generate(&mut rand::rng()).to_bytes(),
25            accept_incoming: false,
26            accept_port: None,
27            key_dir: None,
28            relay_urls: Vec::new(),
29            extra_relay_urls: Vec::new(),
30        }
31    }
32
33    pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
34        self.accept_incoming = accept_incoming;
35        self
36    }
37
38    pub fn accept_port(mut self, accept_port: u16) -> Self {
39        self.accept_port = Some(accept_port);
40        self
41    }
42
43    pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
44        self.secret_key = *secret_key;
45        self
46    }
47
48    pub fn relay_urls(mut self, urls: Vec<RelayUrl>) -> Self {
49        self.relay_urls = urls;
50        self
51    }
52
53    pub fn extra_relay_urls(mut self, urls: Vec<RelayUrl>) -> Self {
54        self.extra_relay_urls = urls;
55        self
56    }
57
58    pub fn key_dir(mut self, key_dir: Option<std::path::PathBuf>) -> Self {
59        self.key_dir = key_dir;
60        self
61    }
62
63    pub fn dot_ssh_integration(mut self, persist: bool, service: bool) -> Self {
64        tracing::info!(
65            "dot_ssh_integration: persist={}, service={}",
66            persist,
67            service
68        );
69
70        match dot_ssh(
71            &SecretKey::from_bytes(&self.secret_key),
72            persist,
73            service,
74            self.key_dir.as_deref(),
75        ) {
76            Ok(secret_key) => {
77                tracing::info!("dot_ssh_integration: Successfully loaded/created SSH keys");
78                self.secret_key = secret_key.to_bytes();
79            }
80            Err(e) => {
81                tracing::error!(
82                    "dot_ssh_integration: Failed to load/create SSH keys: {:#}",
83                    e
84                );
85                eprintln!("Warning: Failed to load/create persistent SSH keys: {e:#}");
86                eprintln!("Continuing with ephemeral keys...");
87            }
88        }
89        self
90    }
91
92    pub async fn build(&mut self) -> anyhow::Result<IrohSsh> {
93        // Iroh setup
94        let secret_key = SecretKey::from_bytes(&self.secret_key);
95        let mut builder = Endpoint::builder().secret_key(secret_key);
96
97        if !self.relay_urls.is_empty() {
98            let relay_map = self.relay_urls.iter().cloned().collect();
99            builder = builder.relay_mode(RelayMode::Custom(relay_map));
100        } else if !self.extra_relay_urls.is_empty() {
101            let relay_map = RelayMode::Default.relay_map();
102            for url in &self.extra_relay_urls {
103                relay_map.insert(url.clone(), Arc::new(RelayConfig::from(url.clone())));
104            }
105            builder = builder.relay_mode(RelayMode::Custom(relay_map));
106        }
107
108        let endpoint = builder.bind().await?;
109
110        let mut iroh_ssh = IrohSsh {
111            public_key: *endpoint.id().as_bytes(),
112            secret_key: self.secret_key,
113            inner: None,
114            ssh_port: self.accept_port.unwrap_or(22),
115        };
116
117        let router = if self.accept_incoming {
118            if is_ssh_server_available(iroh_ssh.ssh_port, Duration::from_secs(10)).await.is_err() {
119                eprintln!("SSH server not available on port {}, incoming connections will fail. Please ensure you have an SSH server installed and running on port {}.", iroh_ssh.ssh_port, iroh_ssh.ssh_port);
120                bail!("no ssh server available on specified port")
121            }
122            Router::builder(endpoint.clone()).accept(IrohSsh::ALPN(), iroh_ssh.clone())
123        } else {
124            Router::builder(endpoint.clone())
125        }
126        .spawn();
127
128        iroh_ssh.add_inner(endpoint, router);
129
130        Ok(iroh_ssh)
131    }
132}
133
134async fn is_ssh_server_available(port: u16, timeout: Duration) -> anyhow::Result<()> {
135    tokio::time::timeout(timeout, async {
136        let stream = TcpStream::connect(format!("localhost:{port}")).await?;
137        let mut reader = BufReader::new(stream);
138        let mut line_buf = String::new();
139        let regex = Regex::new(r"^SSH-\d+\.\d+-").expect("valid regex");
140
141        loop {
142            line_buf.clear();
143            if reader.read_line(&mut line_buf).await? == 0 {
144                bail!("SSH server closed the connection");
145            }
146            if regex.is_match(&line_buf) {
147                reader.shutdown().await.ok();
148                return Ok(());
149            }
150        }
151    })
152    .await?
153}
154
155impl Default for Builder {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl IrohSsh {
162    pub fn builder() -> Builder {
163        Builder::new()
164    }
165
166    #[allow(non_snake_case)]
167    pub fn ALPN() -> Vec<u8> {
168        b"/iroh/ssh".to_vec()
169    }
170
171    fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
172        self.inner = Some(Inner { endpoint, router });
173    }
174
175    pub async fn start_ssh(
176        &self,
177        target: String,
178        ssh_opts: SshOpts,
179        remote_cmd: Vec<OsString>,
180        relay_urls: &[String],
181        extra_relay_urls: &[String],
182    ) -> io::Result<Child> {
183        let c_exe = std::env::current_exe()?;
184        let mut cmd = build_ssh_command(
185            &c_exe,
186            target,
187            ssh_opts,
188            remote_cmd,
189            relay_urls,
190            extra_relay_urls,
191        );
192
193        let ssh_process = cmd
194            .stdin(Stdio::inherit())
195            .stdout(Stdio::inherit())
196            .stderr(Stdio::inherit())
197            .spawn()?;
198
199        Ok(ssh_process)
200    }
201
202    pub async fn connect_pubkey(&self, endpoint_id: EndpointId) -> anyhow::Result<()> {
203        let inner = self.inner.as_ref().expect("inner not set");
204        let conn = inner
205            .endpoint
206            .connect(endpoint_id, &IrohSsh::ALPN())
207            .await?;
208        let (mut iroh_send, mut iroh_recv) = conn.open_bi().await?;
209        let (mut local_read, mut local_write) = (tokio::io::stdin(), tokio::io::stdout());
210        let a_to_b = async move {
211            let res = tokio::io::copy(&mut local_read, &mut iroh_send).await;
212            iroh_send.finish().ok();
213            res
214        };
215        let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
216
217        let (_, _) = tokio::join!(a_to_b, b_to_a);
218        Ok(())
219    }
220
221    pub async fn connect_tcpip(&self, host_addr: &str) -> anyhow::Result<()> {
222        let conn = tokio::net::TcpStream::connect(host_addr).await?;
223        let (mut tcp_read, mut tcp_write) = conn.into_split();
224        let (mut local_read, mut local_write) = (tokio::io::stdin(), tokio::io::stdout());
225        let a_to_b = async move {
226            let res = tokio::io::copy(&mut local_read, &mut tcp_write).await;
227            tcp_write.shutdown().await.ok();
228            res
229        };
230        let b_to_a = async move { tokio::io::copy(&mut tcp_read, &mut local_write).await };
231
232        let (_, _) = tokio::join!(a_to_b, b_to_a);
233        Ok(())
234    }
235
236    pub fn endpoint_id(&self) -> EndpointId {
237        self.inner.as_ref().expect("inner not set").endpoint.id()
238    }
239}
240
241fn build_ssh_command(
242    iroh_ssh_exe: &Path,
243    target: String,
244    ssh_opts: SshOpts,
245    remote_cmd: Vec<OsString>,
246    relay_urls: &[String],
247    extra_relay_urls: &[String],
248) -> Command {
249    let mut cmd = Command::new("ssh");
250
251    let mut proxy_cmd = format!("{} proxy", iroh_ssh_exe.display());
252    for url in relay_urls {
253        proxy_cmd.push_str(&format!(" --relay-url {url}"));
254    }
255    for url in extra_relay_urls {
256        proxy_cmd.push_str(&format!(" --extra-relay-url {url}"));
257    }
258    proxy_cmd.push_str(" %h:%p");
259    cmd.arg("-o").arg(format!("ProxyCommand={proxy_cmd}"));
260
261    if let Some(p) = ssh_opts.port {
262        cmd.arg("-p").arg(p.to_string());
263    }
264    if let Some(u) = &ssh_opts.login_user {
265        cmd.arg("-l").arg(u);
266    }
267    if let Some(id) = &ssh_opts.identity_file {
268        cmd.arg("-i").arg(id);
269    }
270    for l in &ssh_opts.local_forward {
271        cmd.arg("-L").arg(l);
272    }
273    for r in &ssh_opts.remote_forward {
274        cmd.arg("-R").arg(r);
275    }
276    for o in &ssh_opts.options {
277        cmd.arg("-o").arg(o);
278    }
279    if ssh_opts.agent {
280        cmd.arg("-A");
281    }
282    if ssh_opts.no_agent {
283        cmd.arg("-a");
284    }
285    if ssh_opts.x11_trusted {
286        cmd.arg("-Y");
287    } else if ssh_opts.x11 {
288        cmd.arg("-X");
289    }
290    if ssh_opts.no_cmd {
291        cmd.arg("-N");
292    }
293    if ssh_opts.force_tty {
294        cmd.arg("-t");
295    }
296    if ssh_opts.no_tty {
297        cmd.arg("-T");
298    }
299    for _ in 0..ssh_opts.verbose {
300        cmd.arg("-v");
301    }
302    if ssh_opts.quiet {
303        cmd.arg("-q");
304    }
305
306    cmd.arg(target);
307
308    if !remote_cmd.is_empty() {
309        cmd.args(remote_cmd.iter());
310    }
311
312    cmd
313}
314
315impl ProtocolHandler for IrohSsh {
316    async fn accept(&self, connection: Connection) -> Result<(), iroh::protocol::AcceptError> {
317        let endpoint_id = connection.remote_id()?;
318
319        match connection.accept_bi().await {
320            Ok((mut iroh_send, mut iroh_recv)) => {
321                println!("Accepted bidirectional stream from {endpoint_id}");
322
323                match TcpStream::connect(format!("127.0.0.1:{}", self.ssh_port)).await {
324                    Ok(mut ssh_stream) => {
325                        println!("Connected to local SSH server on port {}", self.ssh_port);
326
327                        let (mut local_read, mut local_write) = ssh_stream.split();
328
329                        let a_to_b = async move {
330                            let res = tokio::io::copy(&mut local_read, &mut iroh_send).await;
331                            iroh_send.finish().ok();
332                            res
333                        };
334                        let b_to_a =
335                            async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
336
337                        let (_, _) = tokio::join!(a_to_b, b_to_a);
338                    }
339                    Err(e) => {
340                        println!("Failed to connect to SSH server: {e}");
341                    }
342                }
343            }
344            Err(e) => {
345                println!("Failed to accept bidirectional stream: {e}");
346            }
347        }
348
349        Ok(())
350    }
351}
352
353pub fn dot_ssh(
354    default_secret_key: &SecretKey,
355    persist: bool,
356    _service: bool,
357    key_dir: Option<&Path>,
358) -> anyhow::Result<SecretKey> {
359    tracing::info!(
360        "dot_ssh: Function called, persist={}, service={}, key_dir={:?}",
361        persist,
362        _service,
363        key_dir,
364    );
365
366    #[allow(unused_mut)]
367    let mut ssh_dir = if let Some(dir) = key_dir {
368        dir.to_path_buf()
369    } else {
370        let distro_home = my_home()?.ok_or_else(|| anyhow::anyhow!("home directory not found"))?;
371        distro_home.join(".ssh")
372    };
373
374    // Only apply service-specific overrides when no explicit key_dir is set
375    if key_dir.is_none() {
376        // For now linux services are installed as "sudo'er" so
377        // we need to use the root .ssh directory
378        #[cfg(target_os = "linux")]
379        if _service {
380            ssh_dir = std::path::PathBuf::from("/root/.ssh");
381        }
382
383        // Windows virtual service account profile location for NT SERVICE\iroh-ssh
384        #[cfg(target_os = "windows")]
385        if _service {
386            ssh_dir = std::path::PathBuf::from(crate::service::WindowsService::SERVICE_SSH_DIR);
387            tracing::info!("dot_ssh: Using service SSH dir: {}", ssh_dir.display());
388
389            // Ensure directory exists when running as service
390            if !ssh_dir.exists() {
391                tracing::info!("dot_ssh: Service SSH dir doesn't exist, creating it");
392                std::fs::create_dir_all(&ssh_dir)?;
393            }
394        }
395    }
396
397    let pub_key = ssh_dir.join("irohssh_ed25519.pub");
398    let priv_key = ssh_dir.join("irohssh_ed25519");
399
400    tracing::debug!("dot_ssh: ssh_dir exists = {}", ssh_dir.exists());
401    tracing::debug!("dot_ssh: pub_key path = {}", pub_key.display());
402    tracing::debug!("dot_ssh: priv_key path = {}", priv_key.display());
403
404    match (ssh_dir.exists(), persist) {
405        (false, false) => {
406            tracing::error!(
407                "dot_ssh: ssh_dir does not exist and persist=false: {}",
408                ssh_dir.display()
409            );
410            bail!(
411                "key directory {} does not exist, use --persist flag to create it",
412                ssh_dir.display()
413            )
414        }
415        (false, true) => {
416            tracing::info!("dot_ssh: Creating ssh_dir: {}", ssh_dir.display());
417            std::fs::create_dir_all(&ssh_dir)?;
418            println!("[INFO] created .ssh folder: {}", ssh_dir.display());
419            dot_ssh(default_secret_key, persist, _service, key_dir)
420        }
421        (true, true) => {
422            tracing::info!("dot_ssh: Branch (true, true) - directory exists, persist enabled");
423            tracing::debug!("dot_ssh: pub_key.exists() = {}", pub_key.exists());
424            tracing::debug!("dot_ssh: priv_key.exists() = {}", priv_key.exists());
425
426            // check pub and priv key already exists
427            if pub_key.exists() && priv_key.exists() {
428                tracing::info!("dot_ssh: Keys exist, reading them");
429                // read secret key
430                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
431                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
432                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
433                    Ok(SecretKey::from_bytes(&sk_bytes))
434                } else {
435                    bail!("failed to read secret key from {}", priv_key.display())
436                }
437            } else {
438                tracing::info!("dot_ssh: Keys don't exist, creating new keys");
439                tracing::debug!("dot_ssh: Writing to pub_key: {}", pub_key.display());
440                tracing::debug!("dot_ssh: Writing to priv_key: {}", priv_key.display());
441
442                let secret_key = default_secret_key.clone();
443                let public_key = secret_key.public();
444
445                match std::fs::write(&pub_key, z32::encode(public_key.as_bytes())) {
446                    Ok(_) => {
447                        tracing::info!("dot_ssh: Successfully wrote pub_key");
448                    }
449                    Err(e) => {
450                        tracing::error!(
451                            "dot_ssh: Failed to write pub_key: {} (error kind: {:?})",
452                            e,
453                            e.kind()
454                        );
455                        return Err(e.into());
456                    }
457                }
458
459                match std::fs::write(&priv_key, z32::encode(&secret_key.to_bytes())) {
460                    Ok(_) => {
461                        tracing::info!("dot_ssh: Successfully wrote priv_key");
462                    }
463                    Err(e) => {
464                        tracing::error!(
465                            "dot_ssh: Failed to write priv_key: {} (error kind: {:?})",
466                            e,
467                            e.kind()
468                        );
469                        return Err(e.into());
470                    }
471                }
472
473                Ok(secret_key)
474            }
475        }
476        (true, false) => {
477            // check pub and priv key already exists
478            if pub_key.exists() && priv_key.exists() {
479                // read secret key
480                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
481                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
482                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
483                    return Ok(SecretKey::from_bytes(&sk_bytes));
484                }
485            }
486            bail!(
487                "no iroh-ssh keys found in {}, use --persist flag to create it",
488                ssh_dir.display()
489            )
490        }
491    }
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    /// Renders every argument of `cmd` as a (lossily converted) `String`.
    fn args_of(cmd: &Command) -> Vec<String> {
        let mut rendered = Vec::new();
        for arg in cmd.as_std().get_args() {
            rendered.push(arg.to_string_lossy().into_owned());
        }
        rendered
    }

    /// Index of the first argument equal to `wanted`; panics with `missing` otherwise.
    fn index_of(args: &[String], wanted: &str, missing: &str) -> usize {
        match args.iter().position(|arg| arg == wanted) {
            Some(idx) => idx,
            None => panic!("{missing}"),
        }
    }

    #[test]
    fn login_user_flag_is_passed_to_ssh() {
        let cmd = build_ssh_command(
            Path::new("/usr/bin/iroh-ssh"),
            String::from("endpoint123"),
            SshOpts {
                login_user: Some(String::from("alice")),
                ..Default::default()
            },
            Vec::new(),
            &[],
            &[],
        );

        let args = args_of(&cmd);
        let flag = index_of(&args, "-l", "-l not found");
        assert_eq!(args[flag + 1], "alice");
    }

    #[test]
    fn rsync_invocation_parses_and_builds() {
        // `rsync -e iroh-ssh /local user@<id>:/remote` ends up running:
        //   iroh-ssh -l alice <endpoint_id> rsync --server -e.LsfxCIvu . /tmp/dest
        let argv = [
            "iroh-ssh",
            "-l",
            "alice",
            "endpoint123",
            "rsync",
            "--server",
            "-e.LsfxCIvu",
            ".",
            "/tmp/dest",
        ];
        let cli = crate::cli::Cli::try_parse_from(argv)
            .expect("CLI should accept rsync's invocation pattern");

        assert_eq!(cli.target.as_deref(), Some("endpoint123"));
        assert_eq!(cli.ssh.login_user.as_deref(), Some("alice"));

        let remote_cmd_raw = cli.remote_cmd.unwrap_or_default();
        let remote_cmd: Vec<String> = remote_cmd_raw
            .iter()
            .map(|part| part.to_string_lossy().into_owned())
            .collect();
        assert_eq!(
            remote_cmd,
            vec!["rsync", "--server", "-e.LsfxCIvu", ".", "/tmp/dest"]
        );

        let cmd = build_ssh_command(
            Path::new("/usr/bin/iroh-ssh"),
            cli.target.unwrap(),
            cli.ssh,
            remote_cmd_raw,
            &[],
            &[],
        );
        let args = args_of(&cmd);

        // `-l alice` has to come before the host argument...
        let host_pos = index_of(&args, "endpoint123", "host arg not found");
        let l_pos = index_of(&args, "-l", "-l not found");
        assert!(l_pos < host_pos);
        assert_eq!(args[l_pos + 1], "alice");

        // ...and the remote command has to follow it.
        assert_eq!(args[host_pos + 1], "rsync");
        assert_eq!(args[host_pos + 2], "--server");
        assert_eq!(args[host_pos + 3], "-e.LsfxCIvu");
    }
}
576        assert_eq!(args[host_pos + 1], "rsync");
577        assert_eq!(args[host_pos + 2], "--server");
578        assert_eq!(args[host_pos + 3], "-e.LsfxCIvu");
579    }
580}