Skip to main content

blit_ssh/
lib.rs

1//! Embedded SSH client for blit.
2//!
3//! Provides connection pooling, ssh-agent authentication, `~/.ssh/config`
4//! parsing, and `direct-streamlocal` channel forwarding for connecting to
5//! remote blit-servers without shelling out to the system `ssh` binary.
6
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12use russh::client;
13use russh::keys::{self, PrivateKeyWithHashAlg, agent};
14
15// ── Error ──────────────────────────────────────────────────────────────
16
17#[derive(Debug, thiserror::Error)]
18pub enum Error {
19    #[error("ssh: {0}")]
20    Russh(#[from] russh::Error),
21    #[error("ssh key: {0}")]
22    Keys(#[from] keys::Error),
23    #[error("ssh: {0}")]
24    Io(#[from] std::io::Error),
25    #[error("ssh: {0}")]
26    Other(String),
27}
28
29// ── Shell scripts run on the remote ────────────────────────────────────
30
/// Remote command that prints the blit socket path to use.
///
/// Candidates are tried in order: `$BLIT_SOCK`, `$TMPDIR/blit.sock`,
/// `/tmp/blit-<user>.sock`, `/run/blit/<user>.sock`,
/// `$XDG_RUNTIME_DIR/blit.sock`. Apart from `$BLIT_SOCK`, a candidate is
/// only chosen if a socket already exists there; otherwise `/tmp/blit.sock`
/// is printed.
///
/// The script is wrapped in `sh -c '…'` so it is evaluated by a POSIX
/// shell even when the remote login shell is fish or another non-POSIX shell.
const SOCK_SEARCH: &str = concat!(
    r#"sh -c '"#,
    r#"if [ -n "$BLIT_SOCK" ]; then S="$BLIT_SOCK"; "#,
    r#"elif [ -n "$TMPDIR" ] && [ -S "$TMPDIR/blit.sock" ]; then S="$TMPDIR/blit.sock"; "#,
    r#"elif [ -S "/tmp/blit-$(id -un).sock" ]; then S="/tmp/blit-$(id -un).sock"; "#,
    r#"elif [ -S "/run/blit/$(id -un).sock" ]; then S="/run/blit/$(id -un).sock"; "#,
    r#"elif [ -n "$XDG_RUNTIME_DIR" ] && [ -S "$XDG_RUNTIME_DIR/blit.sock" ]; then S="$XDG_RUNTIME_DIR/blit.sock"; "#,
    r#"else S=/tmp/blit.sock; fi; "#,
    r#"echo "$S"'"#
);
36
/// Escape `s` so it can be placed between double quotes in a POSIX shell.
///
/// Backslash-escapes the four characters that remain special inside double
/// quotes: `\`, `$`, `` ` `` and `"`. Single quotes are left untouched;
/// they are not special inside double quotes.
fn dq_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '$' | '`' | '"') {
            out.push('\\');
        }
        out.push(ch);
    }
    out
}

/// Make `s` safe to embed inside a single-quoted shell word.
///
/// Every `'` becomes `'\''`: close the quote, emit an escaped literal quote,
/// and reopen the quote. All other characters are literal inside single
/// quotes, so nothing else needs escaping.
fn sq_escape(s: &str) -> String {
    s.replace('\'', r"'\''")
}

/// Build the remote command that installs blit if it is missing, then
/// starts blit-server detached from the session when no live socket exists
/// at `socket_path`.
///
/// The command has the form `sh -c '<script>'`, so the remote parses it
/// twice:
/// 1. The user's login shell (possibly fish or another non-POSIX shell)
///    only sees a single-quoted argument.
/// 2. `sh` then runs the POSIX script.
///
/// The socket path is therefore escaped twice. [`dq_escape`] protects it
/// inside the `S="…"` assignment that `sh` evaluates. [`sq_escape`] then
/// ensures a `'` in the path cannot terminate the outer single-quoted
/// argument early.
///
/// `BLIT_INSTALL_DIR` is placed on the `sh` side of the download pipeline.
/// A prefix assignment only affects the command it precedes, and the
/// installer runs in `sh`, not in `curl`/`wget`.
///
/// The script prints `ok` on stdout; installer output goes to stderr.
fn install_and_start_script(socket_path: &str) -> String {
    // Inner layer first (sh's double quotes), then the outer single-quoted layer.
    let escaped = sq_escape(&dq_escape(socket_path));
    format!(
        "sh -c 'export PATH=\"$HOME/.local/bin:$PATH\"; \
         if ! command -v blit >/dev/null 2>&1 && ! command -v blit-server >/dev/null 2>&1; then \
           if command -v curl >/dev/null 2>&1; then curl -sf https://install.blit.sh | BLIT_INSTALL_DIR=\"$HOME/.local/bin\" sh >&2; \
           elif command -v wget >/dev/null 2>&1; then wget -qO- https://install.blit.sh | BLIT_INSTALL_DIR=\"$HOME/.local/bin\" sh >&2; fi; \
         fi; \
         S=\"{escaped}\"; \
         if [ -S \"$S\" ]; then \
           if command -v nc >/dev/null 2>&1; then nc -z -U \"$S\" 2>/dev/null || rm -f \"$S\"; \
           elif command -v socat >/dev/null 2>&1; then socat /dev/null \"UNIX-CONNECT:$S\" 2>/dev/null || rm -f \"$S\"; fi; \
         fi; \
         if ! [ -S \"$S\" ]; then \
           if command -v blit >/dev/null 2>&1; then nohup blit server </dev/null >/dev/null 2>&1 & \
           elif command -v blit-server >/dev/null 2>&1; then nohup blit-server </dev/null >/dev/null 2>&1 & fi; \
         fi; \
         echo ok'"
    )
}
80
81// ── SSH config resolution ──────────────────────────────────────────────
82
/// Per-host settings extracted from `~/.ssh/config`.
///
/// A field stays `None`/empty when the config does not set it for the host.
#[derive(Default)]
struct ResolvedConfig {
    /// Real host name to connect to (`Hostname`).
    hostname: Option<String>,
    /// Login user (`User`).
    user: Option<String>,
    /// TCP port (`Port`).
    port: Option<u16>,
    /// Private key files to try, in config order (`IdentityFile`).
    identity_files: Vec<PathBuf>,
    /// Jump host (`ProxyJump`). Parsed, but not yet used when establishing
    /// connections; the allowance keeps the field without a lint warning.
    #[allow(dead_code)]
    proxy_jump: Option<String>,
}
93
94/// Minimal `~/.ssh/config` parser. Supports Host (with `*`/`?` globs),
95/// Hostname, User, Port, IdentityFile, and ProxyJump.
96fn resolve_ssh_config(host: &str) -> ResolvedConfig {
97    let path = match home_dir() {
98        Some(h) => h.join(".ssh").join("config"),
99        None => return ResolvedConfig::default(),
100    };
101    let text = match std::fs::read_to_string(&path) {
102        Ok(t) => t,
103        Err(_) => return ResolvedConfig::default(),
104    };
105
106    let mut result = ResolvedConfig::default();
107    let mut in_matching_block = false;
108    let mut in_global = true; // before the first Host line
109
110    for line in text.lines() {
111        let line = line.trim();
112        if line.is_empty() || line.starts_with('#') {
113            continue;
114        }
115        let (key, value) = match line.split_once(|c: char| c.is_ascii_whitespace() || c == '=') {
116            Some((k, v)) => (k.trim(), v.trim().trim_start_matches('=')),
117            None => continue,
118        };
119        let value = value.trim();
120        if key.eq_ignore_ascii_case("Host") {
121            in_global = false;
122            in_matching_block = value
123                .split_whitespace()
124                .any(|pattern| host_matches(pattern, host));
125            continue;
126        }
127        if !in_matching_block && !in_global {
128            continue;
129        }
130        if key.eq_ignore_ascii_case("Hostname") && result.hostname.is_none() {
131            result.hostname = Some(value.to_string());
132        } else if key.eq_ignore_ascii_case("User") && result.user.is_none() {
133            result.user = Some(value.to_string());
134        } else if key.eq_ignore_ascii_case("Port") && result.port.is_none() {
135            result.port = value.parse().ok();
136        } else if key.eq_ignore_ascii_case("IdentityFile") {
137            let expanded = expand_tilde(value);
138            result.identity_files.push(PathBuf::from(expanded));
139        } else if key.eq_ignore_ascii_case("ProxyJump") && result.proxy_jump.is_none() {
140            result.proxy_jump = Some(value.to_string());
141        }
142    }
143    result
144}
145
/// Glob-match `host` against an ssh_config host `pattern`.
///
/// `*` matches any run of characters (including none). `?` matches exactly
/// one character. Every other character must match literally
/// (case-sensitively). The whole host must be consumed.
fn host_matches(pattern: &str, host: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let txt: Vec<char> = host.chars().collect();
    let (mut pi, mut ti) = (0usize, 0usize);
    // Index of the most recent `*` in `pat`, and the text position where the
    // remainder of the pattern will next be retried after that `*`.
    let mut backtrack: Option<(usize, usize)> = None;

    while ti < txt.len() {
        match pat.get(pi).copied() {
            Some('*') => {
                backtrack = Some((pi, ti));
                pi += 1;
            }
            Some(c) if c == '?' || c == txt[ti] => {
                pi += 1;
                ti += 1;
            }
            _ => match backtrack {
                // Let the last `*` swallow one more character and retry.
                Some((star, start)) => {
                    backtrack = Some((star, start + 1));
                    pi = star + 1;
                    ti = start + 1;
                }
                None => return false,
            },
        }
    }

    // The host is exhausted; only `*`s may remain in the pattern.
    pat[pi..].iter().all(|&c| c == '*')
}
193
194fn expand_tilde(path: &str) -> String {
195    if let Some(rest) = path.strip_prefix("~/")
196        && let Some(home) = home_dir()
197    {
198        return format!("{}/{rest}", home.display());
199    }
200    path.to_string()
201}
202
203// ── Handler ────────────────────────────────────────────────────────────
204
205struct SshHandler {
206    host: String,
207    port: u16,
208}
209
210impl client::Handler for SshHandler {
211    type Error = Error;
212
213    async fn check_server_key(
214        &mut self,
215        server_public_key: &keys::PublicKey,
216    ) -> Result<bool, Self::Error> {
217        let known_hosts_path = match home_dir() {
218            Some(h) => h.join(".ssh").join("known_hosts"),
219            None => return Ok(true), // No home dir — accept
220        };
221        if !known_hosts_path.exists() {
222            // No known_hosts file — accept-new behaviour: create file and
223            // record the key.
224            if let Some(parent) = known_hosts_path.parent() {
225                let _ = std::fs::create_dir_all(parent);
226            }
227            append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
228            return Ok(true);
229        }
230        match keys::check_known_hosts_path(
231            &self.host,
232            self.port,
233            server_public_key,
234            &known_hosts_path,
235        ) {
236            Ok(true) => Ok(true),
237            Ok(false) => {
238                // Key not in file — accept-new: append and accept.
239                append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
240                Ok(true)
241            }
242            Err(keys::Error::KeyChanged { .. }) => Err(Error::Other(format!(
243                "host key for {}:{} has changed! \
244                     This could indicate a man-in-the-middle attack. \
245                     Remove the old key from ~/.ssh/known_hosts to continue.",
246                self.host, self.port
247            ))),
248            Err(_) => {
249                // Other errors (parse failure, etc.) — accept-new.
250                append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
251                Ok(true)
252            }
253        }
254    }
255}
256
257fn append_known_host(path: &Path, host: &str, port: u16, key: &keys::PublicKey) {
258    use keys::PublicKeyBase64;
259    let host_entry = if port == 22 {
260        host.to_string()
261    } else {
262        format!("[{host}]:{port}")
263    };
264    let algo = key.algorithm().to_string();
265    let b64 = key.public_key_base64();
266    let line = format!("{host_entry} {algo} {b64}\n");
267    let _ = std::fs::OpenOptions::new()
268        .create(true)
269        .append(true)
270        .open(path)
271        .and_then(|mut f| {
272            use std::io::Write;
273            f.write_all(line.as_bytes())
274        });
275}
276
277// ── SSH Pool ───────────────────────────────────────────────────────────
278
279/// SSH connection pool. Maintains persistent SSH connections and opens
280/// channels on demand. Multiple channels share a single TCP+SSH connection
281/// per host. Thread-safe and cheaply cloneable via `Arc`.
282#[derive(Clone)]
283pub struct SshPool {
284    inner: Arc<PoolInner>,
285}
286
287struct PoolInner {
288    /// Cached connections keyed by `"user@host:port"`.
289    connections: Mutex<HashMap<String, CachedConnection>>,
290}
291
292struct CachedConnection {
293    handle: client::Handle<SshHandler>,
294    /// Resolved remote blit socket path (cached after first resolution).
295    remote_socket: Option<String>,
296}
297
298impl Default for SshPool {
299    fn default() -> Self {
300        Self::new()
301    }
302}
303
304impl SshPool {
305    pub fn new() -> Self {
306        Self {
307            inner: Arc::new(PoolInner {
308                connections: Mutex::new(HashMap::new()),
309            }),
310        }
311    }
312
313    /// Open a `direct-streamlocal` channel to a remote blit-server.
314    ///
315    /// - Resolves `~/.ssh/config` for the target host.
316    /// - Reuses an existing SSH connection if available.
317    /// - Authenticates via ssh-agent, then falls back to key files.
318    /// - If `remote_socket` is `None`, discovers the socket path on the remote.
319    /// - Auto-starts blit-server on the remote if needed.
320    /// - Returns a bidirectional `DuplexStream` connected to the remote socket.
321    pub async fn connect(
322        &self,
323        host: &str,
324        user: Option<&str>,
325        remote_socket: Option<&str>,
326    ) -> Result<tokio::io::DuplexStream, Error> {
327        let config = resolve_ssh_config(host);
328        let effective_host = config.hostname.as_deref().unwrap_or(host);
329        let effective_user = user
330            .map(String::from)
331            .or(config.user.clone())
332            .unwrap_or_else(current_username);
333        let effective_port = config.port.unwrap_or(22);
334
335        let key = format!("{effective_user}@{effective_host}:{effective_port}");
336
337        let mut conns = self.inner.connections.lock().await;
338
339        // Try reusing an existing connection.
340        let need_new = match conns.get(&key) {
341            Some(cached) => cached.handle.is_closed(),
342            None => true,
343        };
344
345        if need_new {
346            let handle =
347                establish_connection(effective_host, effective_port, &effective_user, &config)
348                    .await?;
349            conns.insert(
350                key.clone(),
351                CachedConnection {
352                    handle,
353                    remote_socket: None,
354                },
355            );
356        }
357
358        let cached = conns.get_mut(&key).unwrap();
359
360        // Resolve remote socket path if not cached and not explicitly provided.
361        let socket_path = if let Some(explicit) = remote_socket {
362            explicit.to_string()
363        } else if let Some(ref cached_path) = cached.remote_socket {
364            cached_path.clone()
365        } else {
366            let path = exec_command(&cached.handle, SOCK_SEARCH).await?;
367            let path = path.trim().to_string();
368            if path.is_empty() {
369                return Err(Error::Other(
370                    "could not determine remote blit socket path".into(),
371                ));
372            }
373            cached.remote_socket = Some(path.clone());
374            path
375        };
376
377        // Try to open the channel. If it fails, install + start and retry.
378        let channel = match cached
379            .handle
380            .channel_open_direct_streamlocal(&socket_path)
381            .await
382        {
383            Ok(ch) => ch,
384            Err(_first_err) => {
385                // Install blit if missing and (re)start the server.
386                let _ = exec_command(&cached.handle, &install_and_start_script(&socket_path)).await;
387                // Retry with back-off: the server needs a moment to create
388                // the socket after starting.
389                let mut last_err = _first_err;
390                for attempt in 0..10 {
391                    tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt + 1))).await;
392                    match cached
393                        .handle
394                        .channel_open_direct_streamlocal(&socket_path)
395                        .await
396                    {
397                        Ok(ch) => return Ok(bridge_channel(ch)),
398                        Err(e) => last_err = e,
399                    }
400                }
401                return Err(Error::Other(format!(
402                    "failed to connect to {socket_path} after install: {last_err}"
403                )));
404            }
405        };
406
407        Ok(bridge_channel(channel))
408    }
409}
410
411/// Bridge an SSH channel to a `DuplexStream` so callers get a standard
412/// tokio type with no russh types leaking.
413fn bridge_channel(channel: russh::Channel<russh::client::Msg>) -> tokio::io::DuplexStream {
414    let stream = channel.into_stream();
415    let (client, server) = tokio::io::duplex(64 * 1024);
416    tokio::spawn(async move {
417        let (mut sr, mut sw) = tokio::io::split(server);
418        let (mut cr, mut cw) = tokio::io::split(stream);
419        tokio::select! {
420            _ = tokio::io::copy(&mut cr, &mut sw) => {}
421            _ = tokio::io::copy(&mut sr, &mut cw) => {}
422        }
423    });
424    client
425}
426
427// ── Connection + Authentication ────────────────────────────────────────
428
429async fn establish_connection(
430    host: &str,
431    port: u16,
432    user: &str,
433    config: &ResolvedConfig,
434) -> Result<client::Handle<SshHandler>, Error> {
435    let ssh_config = client::Config {
436        ..Default::default()
437    };
438
439    let handler = SshHandler {
440        host: host.to_string(),
441        port,
442    };
443
444    let mut handle = client::connect(Arc::new(ssh_config), (host, port), handler).await?;
445
446    // Try ssh-agent first.
447    if try_agent_auth(&mut handle, user).await {
448        return Ok(handle);
449    }
450
451    // Fall back to key files.
452    if try_key_file_auth(&mut handle, user, config).await? {
453        return Ok(handle);
454    }
455
456    Err(Error::Other(format!(
457        "authentication failed for {user}@{host}:{port} \
458         (tried ssh-agent and key files)"
459    )))
460}
461
462/// Try authenticating via ssh-agent. Returns true on success.
463#[cfg(unix)]
464async fn try_agent_auth(handle: &mut client::Handle<SshHandler>, user: &str) -> bool {
465    let agent_path = match std::env::var("SSH_AUTH_SOCK") {
466        Ok(p) if !p.is_empty() => p,
467        _ => return false,
468    };
469    let stream = match tokio::net::UnixStream::connect(&agent_path).await {
470        Ok(s) => s,
471        Err(e) => {
472            log::debug!("ssh-agent connect failed: {e}");
473            return false;
474        }
475    };
476    let mut agent = agent::client::AgentClient::connect(stream);
477    let identities = match agent.request_identities().await {
478        Ok(ids) => ids,
479        Err(e) => {
480            log::debug!("ssh-agent request_identities failed: {e}");
481            return false;
482        }
483    };
484    for identity in &identities {
485        let public_key = identity.public_key().into_owned();
486        match handle
487            .authenticate_publickey_with(user, public_key, None, &mut agent)
488            .await
489        {
490            Ok(russh::client::AuthResult::Success) => return true,
491            Ok(_) => continue,
492            Err(e) => {
493                log::debug!("ssh-agent auth attempt failed: {e}");
494                continue;
495            }
496        }
497    }
498    false
499}
500
/// ssh-agent authentication placeholder for non-Unix targets.
///
/// Agent support is not implemented on these platforms yet. This always
/// reports failure, so the caller proceeds straight to key files.
#[cfg(not(unix))]
async fn try_agent_auth(_session: &mut client::Handle<SshHandler>, _login: &str) -> bool {
    false
}
506
507/// Try authenticating with key files. Returns true on success.
508async fn try_key_file_auth(
509    handle: &mut client::Handle<SshHandler>,
510    user: &str,
511    config: &ResolvedConfig,
512) -> Result<bool, Error> {
513    let home = match home_dir() {
514        Some(h) => h,
515        None => return Ok(false),
516    };
517
518    // Collect candidate key paths: explicit from config + defaults.
519    let mut candidates: Vec<PathBuf> = config.identity_files.clone();
520    for default in &["id_ed25519", "id_ecdsa", "id_rsa"] {
521        let p = home.join(".ssh").join(default);
522        if !candidates.contains(&p) {
523            candidates.push(p);
524        }
525    }
526
527    for path in &candidates {
528        if !path.exists() {
529            continue;
530        }
531        let key = match keys::load_secret_key(path, None) {
532            Ok(k) => k,
533            Err(e) => {
534                log::debug!("could not load {}: {e}", path.display());
535                continue;
536            }
537        };
538
539        // Determine the best RSA hash algorithm if applicable.
540        let hash_alg = handle.best_supported_rsa_hash().await.ok().flatten();
541        let key_with_hash = PrivateKeyWithHashAlg::new(Arc::new(key), hash_alg.flatten());
542
543        match handle.authenticate_publickey(user, key_with_hash).await {
544            Ok(russh::client::AuthResult::Success) => return Ok(true),
545            Ok(_) => continue,
546            Err(e) => {
547                log::debug!("key auth failed for {}: {e}", path.display());
548                continue;
549            }
550        }
551    }
552    Ok(false)
553}
554
555// ── Remote command execution ───────────────────────────────────────────
556
557/// Execute a command on the remote and return its stdout.
558async fn exec_command(handle: &client::Handle<SshHandler>, cmd: &str) -> Result<String, Error> {
559    let mut channel = handle.channel_open_session().await?;
560    channel.exec(true, cmd.as_bytes()).await?;
561
562    let mut output = Vec::new();
563    while let Some(msg) = channel.wait().await {
564        match msg {
565            russh::ChannelMsg::Data { data } => output.extend_from_slice(&data),
566            russh::ChannelMsg::Eof | russh::ChannelMsg::Close => break,
567            _ => continue,
568        }
569    }
570    Ok(String::from_utf8_lossy(&output).into_owned())
571}
572
573// ── Helpers ────────────────────────────────────────────────────────────
574
/// Return the current user's home directory, taken from the environment.
///
/// Reads `USERPROFILE` on Windows and `HOME` on every other target. The
/// original version only covered `unix`/`windows`, so other targets got an
/// empty, non-compiling body. Unset or empty values yield `None`, so
/// callers never build paths relative to the current directory. Non-UTF-8
/// values are kept (`var_os`).
fn home_dir() -> Option<PathBuf> {
    #[cfg(windows)]
    let var = "USERPROFILE";
    #[cfg(not(windows))]
    let var = "HOME";

    std::env::var_os(var)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}
585
/// Return the local login name, used as the default SSH user.
///
/// On Windows this reads `USERNAME` (falling back to `"user"`). On every
/// other target it reads `USER`, then the POSIX `LOGNAME` (falling back to
/// `"root"`). Empty values are skipped. The original version had no branch
/// for targets that are neither `unix` nor `windows`.
fn current_username() -> String {
    #[cfg(windows)]
    let (vars, fallback): (&[&str], &str) = (&["USERNAME"], "user");
    #[cfg(not(windows))]
    let (vars, fallback): (&[&str], &str) = (&["USER", "LOGNAME"], "root");

    vars.iter()
        .copied()
        .filter_map(|name| std::env::var(name).ok())
        .find(|value| !value.is_empty())
        .unwrap_or_else(|| fallback.to_string())
}
596
/// Parse an SSH target of the form `[user@]host[:/socket]`.
///
/// Returns `(user, host, socket)`:
/// - The socket separator is the first `:` after the first `@` (or anywhere
///   when there is no `@`). A trailing `:` with nothing after it is dropped
///   and yields no socket.
/// - The user is everything before the last `@` of the host part. An empty
///   user (`@host`) yields `None`.
pub fn parse_ssh_uri(s: &str) -> (Option<String>, String, Option<String>) {
    // A ':' inside the user part is not a socket separator.
    let search_from = s.find('@').map_or(0, |at| at + 1);
    let (host_part, socket) = match s[search_from..].find(':') {
        Some(rel) => {
            let colon = search_from + rel;
            let path = &s[colon + 1..];
            (&s[..colon], (!path.is_empty()).then(|| path.to_string()))
        }
        None => (s, None),
    };
    let (user, host) = match host_part.rsplit_once('@') {
        Some((user, host)) => ((!user.is_empty()).then(|| user.to_string()), host.to_string()),
        None => (None, host_part.to_string()),
    };
    (user, host, socket)
}