Skip to main content

iroh_ssh/
ssh.rs

1use crate::{Builder, Inner, IrohSsh, cli::SshOpts};
2use std::{ffi::OsString, process::Stdio};
3
4use anyhow::bail;
5use ed25519_dalek::SECRET_KEY_LENGTH;
6use homedir::my_home;
7use std::sync::Arc;
8
9use iroh::{
10    RelayConfig,
11    endpoint::{Connection, RelayMode}, protocol::{ProtocolHandler, Router}, Endpoint, EndpointId, RelayUrl, SecretKey
12};
13use tokio::{
14    net::TcpStream,
15    process::{Child, Command},
16};
17
18impl Builder {
19    pub fn new() -> Self {
20        Self {
21            secret_key: SecretKey::generate(&mut rand::rng()).to_bytes(),
22            accept_incoming: false,
23            accept_port: None,
24            relay_urls: Vec::new(),
25            extra_relay_urls: Vec::new(),
26        }
27    }
28
29    pub fn accept_incoming(mut self, accept_incoming: bool) -> Self {
30        self.accept_incoming = accept_incoming;
31        self
32    }
33
34    pub fn accept_port(mut self, accept_port: u16) -> Self {
35        self.accept_port = Some(accept_port);
36        self
37    }
38
39    pub fn secret_key(mut self, secret_key: &[u8; SECRET_KEY_LENGTH]) -> Self {
40        self.secret_key = *secret_key;
41        self
42    }
43
44    pub fn relay_urls(mut self, urls: Vec<RelayUrl>) -> Self {
45        self.relay_urls = urls;
46        self
47    }
48
49    pub fn extra_relay_urls(mut self, urls: Vec<RelayUrl>) -> Self {
50        self.extra_relay_urls = urls;
51        self
52    }
53
54    pub fn dot_ssh_integration(mut self, persist: bool, service: bool) -> Self {
55        tracing::info!(
56            "dot_ssh_integration: persist={}, service={}",
57            persist,
58            service
59        );
60
61        match dot_ssh(&SecretKey::from_bytes(&self.secret_key), persist, service) {
62            Ok(secret_key) => {
63                tracing::info!("dot_ssh_integration: Successfully loaded/created SSH keys");
64                self.secret_key = secret_key.to_bytes();
65            }
66            Err(e) => {
67                tracing::error!(
68                    "dot_ssh_integration: Failed to load/create SSH keys: {:#}",
69                    e
70                );
71                eprintln!(
72                    "Warning: Failed to load/create persistent SSH keys: {e:#}"
73                );
74                eprintln!("Continuing with ephemeral keys...");
75            }
76        }
77        self
78    }
79
80    pub async fn build(&mut self) -> anyhow::Result<IrohSsh> {
81        // Iroh setup
82        let secret_key = SecretKey::from_bytes(&self.secret_key);
83        let mut builder = Endpoint::builder().secret_key(secret_key);
84
85        if !self.relay_urls.is_empty() {
86            let relay_map = self.relay_urls.iter().cloned().collect();
87            builder = builder.relay_mode(RelayMode::Custom(relay_map));
88        } else if !self.extra_relay_urls.is_empty() {
89            let relay_map = RelayMode::Default.relay_map();
90            for url in &self.extra_relay_urls {
91                relay_map.insert(url.clone(), Arc::new(RelayConfig::from(url.clone())));
92            }
93            builder = builder.relay_mode(RelayMode::Custom(relay_map));
94        }
95
96        let endpoint = builder.bind().await?;
97
98        let mut iroh_ssh = IrohSsh {
99            public_key: *endpoint.id().as_bytes(),
100            secret_key: self.secret_key,
101            inner: None,
102            ssh_port: self.accept_port.unwrap_or(22),
103        };
104
105        let router = if self.accept_incoming {
106            Router::builder(endpoint.clone()).accept(IrohSsh::ALPN(), iroh_ssh.clone())
107        } else {
108            Router::builder(endpoint.clone())
109        }
110        .spawn();
111
112        iroh_ssh.add_inner(endpoint, router);
113
114        Ok(iroh_ssh)
115    }
116}
117
118impl Default for Builder {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124impl IrohSsh {
125    pub fn builder() -> Builder {
126        Builder::new()
127    }
128
129    #[allow(non_snake_case)]
130    pub fn ALPN() -> Vec<u8> {
131        b"/iroh/ssh".to_vec()
132    }
133
134    fn add_inner(&mut self, endpoint: Endpoint, router: Router) {
135        self.inner = Some(Inner { endpoint, router });
136    }
137
138    pub async fn start_ssh(
139        &self,
140        target: String,
141        ssh_opts: SshOpts,
142        remote_cmd: Vec<OsString>,
143        relay_urls: &[String],
144        extra_relay_urls: &[String],
145    ) -> anyhow::Result<Child> {
146        let c_exe = std::env::current_exe()?;
147        let cmd = &mut Command::new("ssh");
148
149        let mut proxy_cmd = format!("{} proxy", c_exe.display());
150        for url in relay_urls {
151            proxy_cmd.push_str(&format!(" --relay-url {url}"));
152        }
153        for url in extra_relay_urls {
154            proxy_cmd.push_str(&format!(" --extra-relay-url {url}"));
155        }
156        proxy_cmd.push_str(" %h");
157        cmd.arg("-o")
158            .arg(format!("ProxyCommand={proxy_cmd}"));
159
160        if let Some(p) = ssh_opts.port {
161            cmd.arg("-p").arg(p.to_string());
162        }
163        if let Some(id) = &ssh_opts.identity_file {
164            cmd.arg("-i").arg(id);
165        }
166        for l in &ssh_opts.local_forward {
167            cmd.arg("-L").arg(l);
168        }
169        for r in &ssh_opts.remote_forward {
170            cmd.arg("-R").arg(r);
171        }
172        for o in &ssh_opts.options {
173            cmd.arg("-o").arg(o);
174        }
175        if ssh_opts.agent {
176            cmd.arg("-A");
177        }
178        if ssh_opts.no_agent {
179            cmd.arg("-a");
180        }
181        if ssh_opts.x11_trusted {
182            cmd.arg("-Y");
183        } else if ssh_opts.x11 {
184            cmd.arg("-X");
185        }
186        if ssh_opts.no_cmd {
187            cmd.arg("-N");
188        }
189        if ssh_opts.force_tty {
190            cmd.arg("-t");
191        }
192        if ssh_opts.no_tty {
193            cmd.arg("-T");
194        }
195        for _ in 0..ssh_opts.verbose {
196            cmd.arg("-v");
197        }
198        if ssh_opts.quiet {
199            cmd.arg("-q");
200        }
201
202        cmd.arg(target);
203
204        if !remote_cmd.is_empty() {
205            cmd.args(remote_cmd.iter());
206        }
207
208        let ssh_process = cmd
209            .stdin(Stdio::inherit())
210            .stdout(Stdio::inherit())
211            .stderr(Stdio::inherit())
212            .spawn()?;
213
214        Ok(ssh_process)
215    }
216
217    pub async fn connect(&self, endpoint_id: EndpointId) -> anyhow::Result<()> {
218        let inner = self.inner.as_ref().expect("inner not set");
219        let conn = inner.endpoint.connect(endpoint_id, &IrohSsh::ALPN()).await?;
220        let (mut iroh_send, mut iroh_recv) = conn.open_bi().await?;
221        let (mut local_read, mut local_write) = (tokio::io::stdin(), tokio::io::stdout());
222        let a_to_b = async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
223        let b_to_a = async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
224
225        tokio::select! {
226            result = a_to_b => {
227                let _ = result;
228            },
229            result = b_to_a => {
230                let _ = result;
231            },
232        };
233        Ok(())
234    }
235
236    pub fn endpoint_id(&self) -> EndpointId {
237        self.inner
238            .as_ref()
239            .expect("inner not set")
240            .endpoint
241            .id()
242    }
243}
244
245impl ProtocolHandler for IrohSsh {
246    async fn accept(&self, connection: Connection) -> Result<(), iroh::protocol::AcceptError> {
247        let endpoint_id = connection.remote_id()?;
248
249        match connection.accept_bi().await {
250            Ok((mut iroh_send, mut iroh_recv)) => {
251                println!("Accepted bidirectional stream from {endpoint_id}");
252
253                match TcpStream::connect(format!("127.0.0.1:{}", self.ssh_port)).await {
254                    Ok(mut ssh_stream) => {
255                        println!("Connected to local SSH server on port {}", self.ssh_port);
256
257                        let (mut local_read, mut local_write) = ssh_stream.split();
258
259                        let a_to_b =
260                            async move { tokio::io::copy(&mut local_read, &mut iroh_send).await };
261                        let b_to_a =
262                            async move { tokio::io::copy(&mut iroh_recv, &mut local_write).await };
263
264                        tokio::select! {
265                            result = a_to_b => {
266                                println!("SSH->Iroh stream ended: {result:?}");
267                            },
268                            result = b_to_a => {
269                                println!("Iroh->SSH stream ended: {result:?}");
270                            },
271                        };
272                    }
273                    Err(e) => {
274                        println!("Failed to connect to SSH server: {e}");
275                    }
276                }
277            }
278            Err(e) => {
279                println!("Failed to accept bidirectional stream: {e}");
280            }
281        }
282
283        Ok(())
284    }
285}
286
287pub fn dot_ssh(
288    default_secret_key: &SecretKey,
289    persist: bool,
290    _service: bool,
291) -> anyhow::Result<SecretKey> {
292    tracing::info!(
293        "dot_ssh: Function called, persist={}, service={}",
294        persist,
295        _service
296    );
297
298    let distro_home = my_home()?.ok_or_else(|| anyhow::anyhow!("home directory not found"))?;
299    #[allow(unused_mut)]
300    let mut ssh_dir = distro_home.join(".ssh");
301
302    // For now linux services are installed as "sudo'er" so
303    // we need to use the root .ssh directory
304    #[cfg(target_os = "linux")]
305    if _service {
306        ssh_dir = std::path::PathBuf::from("/root/.ssh");
307    }
308
309    // Windows virtual service account profile location for NT SERVICE\iroh-ssh
310    #[cfg(target_os = "windows")]
311    if _service {
312        ssh_dir = std::path::PathBuf::from(crate::service::WindowsService::SERVICE_SSH_DIR);
313        tracing::info!("dot_ssh: Using service SSH dir: {}", ssh_dir.display());
314
315        // Ensure directory exists when running as service
316        if !ssh_dir.exists() {
317            tracing::info!("dot_ssh: Service SSH dir doesn't exist, creating it");
318            std::fs::create_dir_all(&ssh_dir)?;
319        }
320    }
321
322    let pub_key = ssh_dir.join("irohssh_ed25519.pub");
323    let priv_key = ssh_dir.join("irohssh_ed25519");
324
325    tracing::debug!("dot_ssh: ssh_dir exists = {}", ssh_dir.exists());
326    tracing::debug!("dot_ssh: pub_key path = {}", pub_key.display());
327    tracing::debug!("dot_ssh: priv_key path = {}", priv_key.display());
328
329    match (ssh_dir.exists(), persist) {
330        (false, false) => {
331            tracing::error!(
332                "dot_ssh: ssh_dir does not exist and persist=false: {}",
333                ssh_dir.display()
334            );
335            bail!(
336                "no .ssh folder found in {}, use --persist flag to create it",
337                distro_home.display()
338            )
339        }
340        (false, true) => {
341            tracing::info!("dot_ssh: Creating ssh_dir: {}", ssh_dir.display());
342            std::fs::create_dir_all(&ssh_dir)?;
343            println!("[INFO] created .ssh folder: {}", ssh_dir.display());
344            dot_ssh(default_secret_key, persist, _service)
345        }
346        (true, true) => {
347            tracing::info!("dot_ssh: Branch (true, true) - directory exists, persist enabled");
348            tracing::debug!("dot_ssh: pub_key.exists() = {}", pub_key.exists());
349            tracing::debug!("dot_ssh: priv_key.exists() = {}", priv_key.exists());
350
351            // check pub and priv key already exists
352            if pub_key.exists() && priv_key.exists() {
353                tracing::info!("dot_ssh: Keys exist, reading them");
354                // read secret key
355                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
356                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
357                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
358                    Ok(SecretKey::from_bytes(&sk_bytes))
359                } else {
360                    bail!("failed to read secret key from {}", priv_key.display())
361                }
362            } else {
363                tracing::info!("dot_ssh: Keys don't exist, creating new keys");
364                tracing::debug!("dot_ssh: Writing to pub_key: {}", pub_key.display());
365                tracing::debug!("dot_ssh: Writing to priv_key: {}", priv_key.display());
366
367                let secret_key = default_secret_key.clone();
368                let public_key = secret_key.public();
369
370                match std::fs::write(&pub_key, z32::encode(public_key.as_bytes())) {
371                    Ok(_) => {
372                        tracing::info!("dot_ssh: Successfully wrote pub_key");
373                    }
374                    Err(e) => {
375                        tracing::error!(
376                            "dot_ssh: Failed to write pub_key: {} (error kind: {:?})",
377                            e,
378                            e.kind()
379                        );
380                        return Err(e.into());
381                    }
382                }
383
384                match std::fs::write(&priv_key, z32::encode(&secret_key.to_bytes())) {
385                    Ok(_) => {
386                        tracing::info!("dot_ssh: Successfully wrote priv_key");
387                    }
388                    Err(e) => {
389                        tracing::error!(
390                            "dot_ssh: Failed to write priv_key: {} (error kind: {:?})",
391                            e,
392                            e.kind()
393                        );
394                        return Err(e.into());
395                    }
396                }
397
398                Ok(secret_key)
399            }
400        }
401        (true, false) => {
402            // check pub and priv key already exists
403            if pub_key.exists() && priv_key.exists() {
404                // read secret key
405                if let Ok(secret_key) = std::fs::read(priv_key.clone()) {
406                    let mut sk_bytes = [0u8; SECRET_KEY_LENGTH];
407                    sk_bytes.copy_from_slice(z32::decode(secret_key.as_slice())?.as_slice());
408                    return Ok(SecretKey::from_bytes(&sk_bytes));
409                }
410            }
411            bail!(
412                "no iroh-ssh keys found in {}, use --persist flag to create it",
413                ssh_dir.display()
414            )
415        }
416    }
417}