Skip to main content

blit_ssh/
lib.rs

1//! Embedded SSH client for blit.
2//!
3//! Provides connection pooling, ssh-agent authentication, `~/.ssh/config`
4//! parsing, and `direct-streamlocal` channel forwarding for connecting to
5//! remote blit-servers without shelling out to the system `ssh` binary.
6
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12use russh::client;
13#[cfg(unix)]
14use russh::keys::agent;
15use russh::keys::{self, PrivateKeyWithHashAlg};
16
17// ── Error ──────────────────────────────────────────────────────────────
18
19#[derive(Debug, thiserror::Error)]
20pub enum Error {
21    #[error("ssh: {0}")]
22    Russh(#[from] russh::Error),
23    #[error("ssh key: {0}")]
24    Keys(#[from] keys::Error),
25    #[error("ssh: {0}")]
26    Io(#[from] std::io::Error),
27    #[error("ssh: {0}")]
28    Other(String),
29}
30
31// ── Shell scripts run on the remote ────────────────────────────────────
32
/// Remote command that prints the blit socket path to use.
///
/// Candidates, in order: `$BLIT_SOCK` (used even if no socket exists there
/// yet), then the first existing socket among `$TMPDIR/blit.sock`,
/// `/tmp/blit-<user>.sock`, `/run/blit/<user>.sock` and
/// `$XDG_RUNTIME_DIR/blit.sock`, falling back to `/tmp/blit.sock`.
///
/// The POSIX script runs under an explicit `sh -c` so it behaves the same
/// when the remote login shell is fish or another non-POSIX shell.
const SOCK_SEARCH: &str = concat!(
    r#"sh -c '"#,
    r#"if [ -n "$BLIT_SOCK" ]; then S="$BLIT_SOCK"; "#,
    r#"elif [ -n "$TMPDIR" ] && [ -S "$TMPDIR/blit.sock" ]; then S="$TMPDIR/blit.sock"; "#,
    r#"elif [ -S "/tmp/blit-$(id -un).sock" ]; then S="/tmp/blit-$(id -un).sock"; "#,
    r#"elif [ -S "/run/blit/$(id -un).sock" ]; then S="/run/blit/$(id -un).sock"; "#,
    r#"elif [ -n "$XDG_RUNTIME_DIR" ] && [ -S "$XDG_RUNTIME_DIR/blit.sock" ]; then S="$XDG_RUNTIME_DIR/blit.sock"; "#,
    r#"else S=/tmp/blit.sock; fi; "#,
    r#"echo "$S"'"#,
);
38
/// Escape `s` so it can be embedded between double quotes in a POSIX shell.
///
/// Inside double quotes only backslash, `$`, backtick and `"` keep a special
/// meaning, so each of those gets a backslash prefix; every other character
/// is copied through unchanged. Single quotes are NOT escaped — a caller
/// that also nests the result inside single quotes must handle `'` itself.
fn dq_escape(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut acc, c| {
        if matches!(c, '\\' | '$' | '`' | '"') {
            acc.push('\\');
        }
        acc.push(c);
        acc
    })
}
54
55/// Install blit on the remote if missing, then start blit-server and
56/// detach it from the session.
57///
58/// Wrapped in `sh -c` so the POSIX script runs correctly even when the
59/// remote user's login shell is fish or another non-POSIX shell.  The
60/// socket path is double-quote-escaped to avoid single-quote nesting
61/// issues inside the outer `sh -c '…'` wrapper.
62fn install_and_start_script(socket_path: &str) -> String {
63    let escaped = dq_escape(socket_path);
64    format!(
65        "sh -c 'export PATH=\"$HOME/.local/bin:$PATH\"; \
66         if ! command -v blit >/dev/null 2>&1 && ! command -v blit-server >/dev/null 2>&1; then \
67           if command -v curl >/dev/null 2>&1; then BLIT_PREFIX=\"$HOME/.local\" curl -sf https://install.blit.sh | sh >&2; \
68           elif command -v wget >/dev/null 2>&1; then BLIT_PREFIX=\"$HOME/.local\" wget -qO- https://install.blit.sh | sh >&2; fi; \
69         fi; \
70         S=\"{escaped}\"; \
71         if [ -S \"$S\" ]; then \
72           if command -v nc >/dev/null 2>&1; then nc -z -U \"$S\" 2>/dev/null || rm -f \"$S\"; \
73           elif command -v socat >/dev/null 2>&1; then socat /dev/null \"UNIX-CONNECT:$S\" 2>/dev/null || rm -f \"$S\"; fi; \
74         fi; \
75         if ! [ -S \"$S\" ]; then \
76           if command -v blit >/dev/null 2>&1; then nohup blit server </dev/null >/dev/null 2>&1 & \
77           elif command -v blit-server >/dev/null 2>&1; then nohup blit-server </dev/null >/dev/null 2>&1 & fi; \
78         fi; \
79         echo ok'"
80    )
81}
82
83// ── SSH config resolution ──────────────────────────────────────────────
84
/// Settings resolved for a single host from `~/.ssh/config`.
///
/// Every field is optional: `None` (or an empty list) means the config did
/// not set it, and the caller applies its own default.
#[derive(Default)]
struct ResolvedConfig {
    /// `Hostname`: the real host to connect to, when it differs from the alias.
    hostname: Option<String>,
    /// `User`: the login name.
    user: Option<String>,
    /// `Port`: the TCP port (stays unset if the configured value did not parse).
    port: Option<u16>,
    /// `IdentityFile` entries in config order, with a leading `~/` expanded.
    identity_files: Vec<PathBuf>,
    /// `ProxyJump`: the raw value as written in the config.
    proxy_jump: Option<String>,
}
94
95/// Minimal `~/.ssh/config` parser. Supports Host (with `*`/`?` globs),
96/// Hostname, User, Port, IdentityFile, and ProxyJump.
97fn resolve_ssh_config(host: &str) -> ResolvedConfig {
98    let path = match home_dir() {
99        Some(h) => h.join(".ssh").join("config"),
100        None => return ResolvedConfig::default(),
101    };
102    let text = match std::fs::read_to_string(&path) {
103        Ok(t) => t,
104        Err(_) => return ResolvedConfig::default(),
105    };
106
107    let mut result = ResolvedConfig::default();
108    let mut in_matching_block = false;
109    let mut in_global = true; // before the first Host line
110
111    for line in text.lines() {
112        let line = line.trim();
113        if line.is_empty() || line.starts_with('#') {
114            continue;
115        }
116        let (key, value) = match line.split_once(|c: char| c.is_ascii_whitespace() || c == '=') {
117            Some((k, v)) => (k.trim(), v.trim().trim_start_matches('=')),
118            None => continue,
119        };
120        let value = value.trim();
121        if key.eq_ignore_ascii_case("Host") {
122            in_global = false;
123            in_matching_block = value
124                .split_whitespace()
125                .any(|pattern| host_matches(pattern, host));
126            continue;
127        }
128        if !in_matching_block && !in_global {
129            continue;
130        }
131        if key.eq_ignore_ascii_case("Hostname") && result.hostname.is_none() {
132            result.hostname = Some(value.to_string());
133        } else if key.eq_ignore_ascii_case("User") && result.user.is_none() {
134            result.user = Some(value.to_string());
135        } else if key.eq_ignore_ascii_case("Port") && result.port.is_none() {
136            result.port = value.parse().ok();
137        } else if key.eq_ignore_ascii_case("IdentityFile") {
138            let expanded = expand_tilde(value);
139            result.identity_files.push(PathBuf::from(expanded));
140        } else if key.eq_ignore_ascii_case("ProxyJump") && result.proxy_jump.is_none() {
141            result.proxy_jump = Some(value.to_string());
142        }
143    }
144    result
145}
146
/// Glob-match `host` against `pattern`.
///
/// `*` matches any run of characters (including none) and `?` matches
/// exactly one character. Matching is case-sensitive and must cover the
/// whole of `host`.
fn host_matches(pattern: &str, host: &str) -> bool {
    let mut p = pattern.chars().peekable();
    let mut h = host.chars().peekable();
    host_matches_inner(&mut p, &mut h)
}

/// Glob-match the remaining characters of `p` (pattern) against the
/// remaining characters of `h` (host), consuming both iterators.
///
/// Uses the linear-backtracking wildcard algorithm: only the most recent `*`
/// is ever revisited. The cost is therefore O(|pattern| × |host|), instead of
/// the exponential blow-up of recursively trying every split for every `*`.
fn host_matches_inner(
    p: &mut std::iter::Peekable<std::str::Chars>,
    h: &mut std::iter::Peekable<std::str::Chars>,
) -> bool {
    let pat: Vec<char> = p.collect();
    let text: Vec<char> = h.collect();
    let (mut pi, mut ti) = (0usize, 0usize);
    // After a `*`: (pattern index just past it, last host index it absorbed up to).
    let mut backtrack: Option<(usize, usize)> = None;

    while ti < text.len() {
        match pat.get(pi).copied() {
            Some('*') => {
                pi += 1;
                backtrack = Some((pi, ti));
            }
            Some(c) if c == '?' || c == text[ti] => {
                pi += 1;
                ti += 1;
            }
            _ => match backtrack {
                // Let the most recent `*` absorb one more host character and retry.
                Some((star_pi, star_ti)) => {
                    pi = star_pi;
                    ti = star_ti + 1;
                    backtrack = Some((star_pi, star_ti + 1));
                }
                None => return false,
            },
        }
    }
    // Host exhausted: whatever is left of the pattern must be all `*`.
    pat[pi..].iter().all(|&c| c == '*')
}
194
195fn expand_tilde(path: &str) -> String {
196    if let Some(rest) = path.strip_prefix("~/")
197        && let Some(home) = home_dir()
198    {
199        return format!("{}/{rest}", home.display());
200    }
201    path.to_string()
202}
203
204// ── Handler ────────────────────────────────────────────────────────────
205
206struct SshHandler {
207    host: String,
208    port: u16,
209}
210
211impl client::Handler for SshHandler {
212    type Error = Error;
213
214    async fn check_server_key(
215        &mut self,
216        server_public_key: &keys::PublicKey,
217    ) -> Result<bool, Self::Error> {
218        let known_hosts_path = match home_dir() {
219            Some(h) => h.join(".ssh").join("known_hosts"),
220            None => return Ok(true), // No home dir — accept
221        };
222        if !known_hosts_path.exists() {
223            // No known_hosts file — accept-new behaviour: create file and
224            // record the key.
225            if let Some(parent) = known_hosts_path.parent() {
226                let _ = std::fs::create_dir_all(parent);
227            }
228            append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
229            return Ok(true);
230        }
231        match keys::check_known_hosts_path(
232            &self.host,
233            self.port,
234            server_public_key,
235            &known_hosts_path,
236        ) {
237            Ok(true) => Ok(true),
238            Ok(false) => {
239                // Key not in file — accept-new: append and accept.
240                append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
241                Ok(true)
242            }
243            Err(keys::Error::KeyChanged { .. }) => Err(Error::Other(format!(
244                "host key for {}:{} has changed! \
245                     This could indicate a man-in-the-middle attack. \
246                     Remove the old key from ~/.ssh/known_hosts to continue.",
247                self.host, self.port
248            ))),
249            Err(_) => {
250                // Other errors (parse failure, etc.) — accept-new.
251                append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
252                Ok(true)
253            }
254        }
255    }
256}
257
258fn append_known_host(path: &Path, host: &str, port: u16, key: &keys::PublicKey) {
259    use keys::PublicKeyBase64;
260    let host_entry = if port == 22 {
261        host.to_string()
262    } else {
263        format!("[{host}]:{port}")
264    };
265    let algo = key.algorithm().to_string();
266    let b64 = key.public_key_base64();
267    let line = format!("{host_entry} {algo} {b64}\n");
268    let _ = std::fs::OpenOptions::new()
269        .create(true)
270        .append(true)
271        .open(path)
272        .and_then(|mut f| {
273            use std::io::Write;
274            f.write_all(line.as_bytes())
275        });
276}
277
278// ── SSH Pool ───────────────────────────────────────────────────────────
279
280/// SSH connection pool. Maintains persistent SSH connections and opens
281/// channels on demand. Multiple channels share a single TCP+SSH connection
282/// per host. Thread-safe and cheaply cloneable via `Arc`.
283#[derive(Clone)]
284pub struct SshPool {
285    inner: Arc<PoolInner>,
286}
287
288struct PoolInner {
289    /// Cached connections keyed by `"user@host:port"`.
290    connections: Mutex<HashMap<String, CachedConnection>>,
291}
292
293struct CachedConnection {
294    handle: client::Handle<SshHandler>,
295    /// Resolved remote blit socket path (cached after first resolution).
296    remote_socket: Option<String>,
297}
298
299impl Default for SshPool {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
305impl SshPool {
306    pub fn new() -> Self {
307        Self {
308            inner: Arc::new(PoolInner {
309                connections: Mutex::new(HashMap::new()),
310            }),
311        }
312    }
313
314    /// Open a `direct-streamlocal` channel to a remote blit-server.
315    ///
316    /// - Resolves `~/.ssh/config` for the target host.
317    /// - Reuses an existing SSH connection if available.
318    /// - Authenticates via ssh-agent, then falls back to key files.
319    /// - If `remote_socket` is `None`, discovers the socket path on the remote.
320    /// - Auto-starts blit-server on the remote if needed.
321    /// - Returns a bidirectional `DuplexStream` connected to the remote socket.
322    pub async fn connect(
323        &self,
324        host: &str,
325        user: Option<&str>,
326        remote_socket: Option<&str>,
327    ) -> Result<tokio::io::DuplexStream, Error> {
328        let config = resolve_ssh_config(host);
329        let effective_host = config.hostname.as_deref().unwrap_or(host);
330        let effective_user = user
331            .map(String::from)
332            .or(config.user.clone())
333            .unwrap_or_else(current_username);
334        let effective_port = config.port.unwrap_or(22);
335
336        let key = format!("{effective_user}@{effective_host}:{effective_port}");
337
338        // Phase 1: check if we need a new SSH connection.
339        // Drop the lock before doing any network I/O so that connections to
340        // *other* hosts can proceed concurrently.
341        let mut conns = self.inner.connections.lock().await;
342        let need_new = match conns.get(&key) {
343            Some(cached) => cached.handle.is_closed(),
344            None => true,
345        };
346
347        if need_new {
348            // Release the lock while establishing the TCP + SSH connection —
349            // this can take seconds (DNS, handshake, auth).
350            drop(conns);
351            let handle =
352                establish_connection(effective_host, effective_port, &effective_user, &config)
353                    .await?;
354            conns = self.inner.connections.lock().await;
355            // Another task may have raced us for the same key — prefer the
356            // existing live connection to avoid duplicates.
357            let still_need = match conns.get(&key) {
358                Some(cached) => cached.handle.is_closed(),
359                None => true,
360            };
361            if still_need {
362                conns.insert(
363                    key.clone(),
364                    CachedConnection {
365                        handle,
366                        remote_socket: None,
367                    },
368                );
369            }
370        }
371
372        let cached = conns.get_mut(&key).unwrap();
373
374        // Resolve remote socket path if not cached and not explicitly provided.
375        let socket_path = if let Some(explicit) = remote_socket {
376            explicit.to_string()
377        } else if let Some(ref cached_path) = cached.remote_socket {
378            cached_path.clone()
379        } else {
380            let path = exec_command(&cached.handle, SOCK_SEARCH).await?;
381            let path = path.trim().to_string();
382            if path.is_empty() {
383                return Err(Error::Other(
384                    "could not determine remote blit socket path".into(),
385                ));
386            }
387            cached.remote_socket = Some(path.clone());
388            path
389        };
390
391        // Try to open the channel. If it fails, install + start and retry.
392        let channel = match cached
393            .handle
394            .channel_open_direct_streamlocal(&socket_path)
395            .await
396        {
397            Ok(ch) => ch,
398            Err(_first_err) => {
399                // Install blit if missing and (re)start the server.
400                let _ = exec_command(&cached.handle, &install_and_start_script(&socket_path)).await;
401                // Retry with back-off: the server needs a moment to create
402                // the socket after starting.
403                let mut last_err = _first_err;
404                for attempt in 0..10 {
405                    tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt + 1))).await;
406                    match cached
407                        .handle
408                        .channel_open_direct_streamlocal(&socket_path)
409                        .await
410                    {
411                        Ok(ch) => return Ok(bridge_channel(ch)),
412                        Err(e) => last_err = e,
413                    }
414                }
415                return Err(Error::Other(format!(
416                    "failed to connect to {socket_path} after install: {last_err}"
417                )));
418            }
419        };
420
421        Ok(bridge_channel(channel))
422    }
423}
424
425/// Bridge an SSH channel to a `DuplexStream` so callers get a standard
426/// tokio type with no russh types leaking.
427fn bridge_channel(channel: russh::Channel<russh::client::Msg>) -> tokio::io::DuplexStream {
428    let stream = channel.into_stream();
429    let (client, server) = tokio::io::duplex(64 * 1024);
430    tokio::spawn(async move {
431        let (mut sr, mut sw) = tokio::io::split(server);
432        let (mut cr, mut cw) = tokio::io::split(stream);
433        tokio::select! {
434            _ = tokio::io::copy(&mut cr, &mut sw) => {}
435            _ = tokio::io::copy(&mut sr, &mut cw) => {}
436        }
437    });
438    client
439}
440
441// ── Connection + Authentication ────────────────────────────────────────
442
443async fn establish_connection(
444    host: &str,
445    port: u16,
446    user: &str,
447    config: &ResolvedConfig,
448) -> Result<client::Handle<SshHandler>, Error> {
449    let ssh_config = client::Config {
450        // Detect dead connections behind NATs/firewalls instead of hanging
451        // indefinitely.  The SSH transport will send a keepalive packet
452        // every 15 s and give up after 3 consecutive misses (~45 s).
453        keepalive_interval: Some(std::time::Duration::from_secs(15)),
454        keepalive_max: 3,
455        ..Default::default()
456    };
457
458    let handler = SshHandler {
459        host: host.to_string(),
460        port,
461    };
462
463    let mut handle = client::connect(Arc::new(ssh_config), (host, port), handler).await?;
464
465    // Try ssh-agent first.
466    if try_agent_auth(&mut handle, user).await {
467        return Ok(handle);
468    }
469
470    // Fall back to key files.
471    if try_key_file_auth(&mut handle, user, config).await? {
472        return Ok(handle);
473    }
474
475    Err(Error::Other(format!(
476        "authentication failed for {user}@{host}:{port} \
477         (tried ssh-agent and key files)"
478    )))
479}
480
481/// Try authenticating via ssh-agent. Returns true on success.
482#[cfg(unix)]
483async fn try_agent_auth(handle: &mut client::Handle<SshHandler>, user: &str) -> bool {
484    let agent_path = match std::env::var("SSH_AUTH_SOCK") {
485        Ok(p) if !p.is_empty() => p,
486        _ => return false,
487    };
488    let stream = match tokio::net::UnixStream::connect(&agent_path).await {
489        Ok(s) => s,
490        Err(e) => {
491            log::debug!("ssh-agent connect failed: {e}");
492            return false;
493        }
494    };
495    let mut agent = agent::client::AgentClient::connect(stream);
496    let identities = match agent.request_identities().await {
497        Ok(ids) => ids,
498        Err(e) => {
499            log::debug!("ssh-agent request_identities failed: {e}");
500            return false;
501        }
502    };
503    for identity in &identities {
504        let public_key = identity.public_key().into_owned();
505        match handle
506            .authenticate_publickey_with(user, public_key, None, &mut agent)
507            .await
508        {
509            Ok(russh::client::AuthResult::Success) => return true,
510            Ok(_) => continue,
511            Err(e) => {
512                log::debug!("ssh-agent auth attempt failed: {e}");
513                continue;
514            }
515        }
516    }
517    false
518}
519
/// Agent-authentication stand-in for non-Unix targets.
///
/// ssh-agent support is not implemented there yet, so this always reports
/// failure, and the caller moves on to key-file authentication.
#[cfg(not(unix))]
async fn try_agent_auth(_handle: &mut client::Handle<SshHandler>, _user: &str) -> bool {
    false
}
525
526/// Try authenticating with key files. Returns true on success.
527async fn try_key_file_auth(
528    handle: &mut client::Handle<SshHandler>,
529    user: &str,
530    config: &ResolvedConfig,
531) -> Result<bool, Error> {
532    let home = match home_dir() {
533        Some(h) => h,
534        None => return Ok(false),
535    };
536
537    // Collect candidate key paths: explicit from config + defaults.
538    let mut candidates: Vec<PathBuf> = config.identity_files.clone();
539    for default in &["id_ed25519", "id_ecdsa", "id_rsa"] {
540        let p = home.join(".ssh").join(default);
541        if !candidates.contains(&p) {
542            candidates.push(p);
543        }
544    }
545
546    for path in &candidates {
547        if !path.exists() {
548            continue;
549        }
550        let key = match keys::load_secret_key(path, None) {
551            Ok(k) => k,
552            Err(e) => {
553                log::debug!("could not load {}: {e}", path.display());
554                continue;
555            }
556        };
557
558        // Determine the best RSA hash algorithm if applicable.
559        let hash_alg = handle.best_supported_rsa_hash().await.ok().flatten();
560        let key_with_hash = PrivateKeyWithHashAlg::new(Arc::new(key), hash_alg.flatten());
561
562        match handle.authenticate_publickey(user, key_with_hash).await {
563            Ok(russh::client::AuthResult::Success) => return Ok(true),
564            Ok(_) => continue,
565            Err(e) => {
566                log::debug!("key auth failed for {}: {e}", path.display());
567                continue;
568            }
569        }
570    }
571    Ok(false)
572}
573
574// ── Remote command execution ───────────────────────────────────────────
575
576/// Execute a command on the remote and return its stdout.
577async fn exec_command(handle: &client::Handle<SshHandler>, cmd: &str) -> Result<String, Error> {
578    let mut channel = handle.channel_open_session().await?;
579    channel.exec(true, cmd.as_bytes()).await?;
580
581    let mut output = Vec::new();
582    while let Some(msg) = channel.wait().await {
583        match msg {
584            russh::ChannelMsg::Data { data } => output.extend_from_slice(&data),
585            russh::ChannelMsg::Eof | russh::ChannelMsg::Close => break,
586            _ => continue,
587        }
588    }
589    Ok(String::from_utf8_lossy(&output).into_owned())
590}
591
592// ── Helpers ────────────────────────────────────────────────────────────
593
/// The current user's home directory: `HOME` on Unix, `USERPROFILE` on Windows.
///
/// Returns `None` when the variable is unset or empty, and on other platforms.
/// An empty value would otherwise turn every `~/.ssh/...` lookup into a path
/// relative to the working directory.
fn home_dir() -> Option<PathBuf> {
    #[cfg(unix)]
    let value = std::env::var_os("HOME");
    #[cfg(windows)]
    let value = std::env::var_os("USERPROFILE");
    #[cfg(not(any(unix, windows)))]
    let value: Option<std::ffi::OsString> = None;

    value.filter(|v| !v.is_empty()).map(PathBuf::from)
}
604
/// Local login name, used when neither the caller nor `~/.ssh/config`
/// specifies a remote user.
///
/// On Unix this is `USER`, then `LOGNAME`, falling back to `"root"`. On
/// Windows it is `USERNAME`, falling back to `"user"`. Unset or empty
/// variables are skipped.
fn current_username() -> String {
    #[cfg(unix)]
    let (candidates, fallback): (&[&str], &str) = (&["USER", "LOGNAME"], "root");
    #[cfg(windows)]
    let (candidates, fallback): (&[&str], &str) = (&["USERNAME"], "user");
    #[cfg(not(any(unix, windows)))]
    let (candidates, fallback): (&[&str], &str) = (&[], "user");

    candidates
        .iter()
        .copied()
        .filter_map(|name| std::env::var(name).ok())
        .find(|value| !value.is_empty())
        .unwrap_or_else(|| fallback.to_string())
}
615
/// Parse an SSH URI of the form `[user@]host[:socket]`.
///
/// Returns `(user, host, socket)`.
///
/// - The host part ends at the first `:` that is not inside `[...]`, so the
///   socket path may itself contain `@` or `:`.
/// - The user is everything before the last `@` of the host part.
/// - Bracketed IPv6 hosts such as `[::1]` are returned without brackets.
/// - An empty user (`@host`) or an empty socket (trailing `:`) is reported as
///   `None`.
pub fn parse_ssh_uri(s: &str) -> (Option<String>, String, Option<String>) {
    let mut in_brackets = false;
    let colon = s.char_indices().find_map(|(i, c)| {
        match c {
            '[' => in_brackets = true,
            ']' => in_brackets = false,
            ':' if !in_brackets => return Some(i),
            _ => {}
        }
        None
    });

    let (authority, socket) = match colon {
        Some(i) => (&s[..i], Some(&s[i + 1..]).filter(|p| !p.is_empty())),
        None => (s, None),
    };
    let (user, host) = match authority.rfind('@') {
        Some(at) => (
            Some(&authority[..at]).filter(|u| !u.is_empty()),
            &authority[at + 1..],
        ),
        None => (None, authority),
    };
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);

    (
        user.map(str::to_string),
        host.to_string(),
        socket.map(str::to_string),
    )
}