1use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11
12use russh::client;
13use russh::keys::{self, PrivateKeyWithHashAlg, agent};
14
/// Errors returned by this module's SSH connection helpers.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the `russh` SSH client.
    #[error("ssh: {0}")]
    Russh(#[from] russh::Error),
    /// An error reported by `russh::keys`.
    #[error("ssh key: {0}")]
    Keys(#[from] keys::Error),
    /// A standard I/O error.
    #[error("ssh: {0}")]
    Io(#[from] std::io::Error),
    /// Any other failure, described by the contained message.
    #[error("ssh: {0}")]
    Other(String),
}
28
/// Remote shell command that prints the path of the blit server socket.
///
/// Picks the first of: `$BLIT_SOCK` (used as-is, even if nothing exists there yet),
/// then `$TMPDIR/blit.sock`, `/tmp/blit-<user>.sock`, `/run/blit/<user>.sock` and
/// `$XDG_RUNTIME_DIR/blit.sock` (each only if it already is a socket), and
/// otherwise falls back to `/tmp/blit.sock`. The output is a single line.
const SOCK_SEARCH: &str = r#"sh -c 'if [ -n "$BLIT_SOCK" ]; then S="$BLIT_SOCK"; elif [ -n "$TMPDIR" ] && [ -S "$TMPDIR/blit.sock" ]; then S="$TMPDIR/blit.sock"; elif [ -S "/tmp/blit-$(id -un).sock" ]; then S="/tmp/blit-$(id -un).sock"; elif [ -S "/run/blit/$(id -un).sock" ]; then S="/run/blit/$(id -un).sock"; elif [ -n "$XDG_RUNTIME_DIR" ] && [ -S "$XDG_RUNTIME_DIR/blit.sock" ]; then S="$XDG_RUNTIME_DIR/blit.sock"; else S=/tmp/blit.sock; fi; echo "$S"'"#;
36
/// Escape `s` so it can be placed between double quotes in a POSIX shell.
///
/// Backslash, `$`, backtick and `"` are the only characters that stay special
/// inside double quotes; each one is prefixed with a backslash. Every other
/// character (including single quotes and newlines) is copied unchanged.
fn dq_escape(s: &str) -> String {
    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, c| {
            if matches!(c, '\\' | '$' | '`' | '"') {
                escaped.push('\\');
            }
            escaped.push(c);
            escaped
        })
}
52
53fn install_and_start_script(socket_path: &str) -> String {
61 let escaped = dq_escape(socket_path);
62 format!(
63 "sh -c 'export PATH=\"$HOME/.local/bin:$PATH\"; \
64 if ! command -v blit >/dev/null 2>&1 && ! command -v blit-server >/dev/null 2>&1; then \
65 if command -v curl >/dev/null 2>&1; then BLIT_INSTALL_DIR=\"$HOME/.local/bin\" curl -sf https://install.blit.sh | sh >&2; \
66 elif command -v wget >/dev/null 2>&1; then BLIT_INSTALL_DIR=\"$HOME/.local/bin\" wget -qO- https://install.blit.sh | sh >&2; fi; \
67 fi; \
68 S=\"{escaped}\"; \
69 if [ -S \"$S\" ]; then \
70 if command -v nc >/dev/null 2>&1; then nc -z -U \"$S\" 2>/dev/null || rm -f \"$S\"; \
71 elif command -v socat >/dev/null 2>&1; then socat /dev/null \"UNIX-CONNECT:$S\" 2>/dev/null || rm -f \"$S\"; fi; \
72 fi; \
73 if ! [ -S \"$S\" ]; then \
74 if command -v blit >/dev/null 2>&1; then nohup blit server </dev/null >/dev/null 2>&1 & \
75 elif command -v blit-server >/dev/null 2>&1; then nohup blit-server </dev/null >/dev/null 2>&1 & fi; \
76 fi; \
77 echo ok'"
78 )
79}
80
/// Settings read from `~/.ssh/config` for one host by `resolve_ssh_config`.
/// `None` / an empty list means the config did not set that option.
#[derive(Default)]
struct ResolvedConfig {
    /// `HostName`: the real host name or address to connect to.
    hostname: Option<String>,
    /// `User`: remote login name.
    user: Option<String>,
    /// `Port`: remote SSH port.
    port: Option<u16>,
    /// `IdentityFile` entries in config order, with a leading `~/` expanded.
    identity_files: Vec<PathBuf>,
    /// `ProxyJump`.
    // Parsed but not read by the connection code visible in this module, hence the
    // allow. NOTE(review): a configured ProxyJump therefore seems to be ignored
    // when connecting — confirm.
    #[allow(dead_code)]
    proxy_jump: Option<String>,
}
93
94fn resolve_ssh_config(host: &str) -> ResolvedConfig {
97 let path = match home_dir() {
98 Some(h) => h.join(".ssh").join("config"),
99 None => return ResolvedConfig::default(),
100 };
101 let text = match std::fs::read_to_string(&path) {
102 Ok(t) => t,
103 Err(_) => return ResolvedConfig::default(),
104 };
105
106 let mut result = ResolvedConfig::default();
107 let mut in_matching_block = false;
108 let mut in_global = true; for line in text.lines() {
111 let line = line.trim();
112 if line.is_empty() || line.starts_with('#') {
113 continue;
114 }
115 let (key, value) = match line.split_once(|c: char| c.is_ascii_whitespace() || c == '=') {
116 Some((k, v)) => (k.trim(), v.trim().trim_start_matches('=')),
117 None => continue,
118 };
119 let value = value.trim();
120 if key.eq_ignore_ascii_case("Host") {
121 in_global = false;
122 in_matching_block = value
123 .split_whitespace()
124 .any(|pattern| host_matches(pattern, host));
125 continue;
126 }
127 if !in_matching_block && !in_global {
128 continue;
129 }
130 if key.eq_ignore_ascii_case("Hostname") && result.hostname.is_none() {
131 result.hostname = Some(value.to_string());
132 } else if key.eq_ignore_ascii_case("User") && result.user.is_none() {
133 result.user = Some(value.to_string());
134 } else if key.eq_ignore_ascii_case("Port") && result.port.is_none() {
135 result.port = value.parse().ok();
136 } else if key.eq_ignore_ascii_case("IdentityFile") {
137 let expanded = expand_tilde(value);
138 result.identity_files.push(PathBuf::from(expanded));
139 } else if key.eq_ignore_ascii_case("ProxyJump") && result.proxy_jump.is_none() {
140 result.proxy_jump = Some(value.to_string());
141 }
142 }
143 result
144}
145
146fn host_matches(pattern: &str, host: &str) -> bool {
148 let mut p = pattern.chars().peekable();
149 let mut h = host.chars().peekable();
150 host_matches_inner(&mut p, &mut h)
151}
152
/// Recursive worker for `host_matches`: glob-match what is left of the pattern
/// iterator `p` against what is left of the host iterator `h`.
///
/// Both iterators are advanced as characters are consumed. On `*`, the rest of
/// the pattern is tried against every remaining suffix of the host, using clones
/// so that each attempt starts from the same position.
fn host_matches_inner(
    p: &mut std::iter::Peekable<std::str::Chars>,
    h: &mut std::iter::Peekable<std::str::Chars>,
) -> bool {
    loop {
        let Some(&wanted) = p.peek() else {
            // Pattern used up: it is a match only if the host is used up too.
            return h.peek().is_none();
        };
        p.next();

        if wanted == '*' {
            if p.peek().is_none() {
                // A trailing star swallows whatever is left of the host.
                return true;
            }
            loop {
                if host_matches_inner(&mut p.clone(), &mut h.clone()) {
                    return true;
                }
                if h.next().is_none() {
                    return false;
                }
            }
        }

        let got = h.next();
        let ok = if wanted == '?' {
            got.is_some()
        } else {
            got == Some(wanted)
        };
        if !ok {
            return false;
        }
    }
}
193
194fn expand_tilde(path: &str) -> String {
195 if let Some(rest) = path.strip_prefix("~/")
196 && let Some(home) = home_dir()
197 {
198 return format!("{}/{rest}", home.display());
199 }
200 path.to_string()
201}
202
/// `russh` client handler that checks server host keys against
/// `~/.ssh/known_hosts`.
struct SshHandler {
    /// Host name used for the known_hosts lookup and for newly recorded entries.
    host: String,
    /// Port used for the lookup; non-22 ports are recorded as `[host]:port`.
    port: u16,
}
209
210impl client::Handler for SshHandler {
211 type Error = Error;
212
213 async fn check_server_key(
214 &mut self,
215 server_public_key: &keys::PublicKey,
216 ) -> Result<bool, Self::Error> {
217 let known_hosts_path = match home_dir() {
218 Some(h) => h.join(".ssh").join("known_hosts"),
219 None => return Ok(true), };
221 if !known_hosts_path.exists() {
222 if let Some(parent) = known_hosts_path.parent() {
225 let _ = std::fs::create_dir_all(parent);
226 }
227 append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
228 return Ok(true);
229 }
230 match keys::check_known_hosts_path(
231 &self.host,
232 self.port,
233 server_public_key,
234 &known_hosts_path,
235 ) {
236 Ok(true) => Ok(true),
237 Ok(false) => {
238 append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
240 Ok(true)
241 }
242 Err(keys::Error::KeyChanged { .. }) => Err(Error::Other(format!(
243 "host key for {}:{} has changed! \
244 This could indicate a man-in-the-middle attack. \
245 Remove the old key from ~/.ssh/known_hosts to continue.",
246 self.host, self.port
247 ))),
248 Err(_) => {
249 append_known_host(&known_hosts_path, &self.host, self.port, server_public_key);
251 Ok(true)
252 }
253 }
254 }
255}
256
257fn append_known_host(path: &Path, host: &str, port: u16, key: &keys::PublicKey) {
258 use keys::PublicKeyBase64;
259 let host_entry = if port == 22 {
260 host.to_string()
261 } else {
262 format!("[{host}]:{port}")
263 };
264 let algo = key.algorithm().to_string();
265 let b64 = key.public_key_base64();
266 let line = format!("{host_entry} {algo} {b64}\n");
267 let _ = std::fs::OpenOptions::new()
268 .create(true)
269 .append(true)
270 .open(path)
271 .and_then(|mut f| {
272 use std::io::Write;
273 f.write_all(line.as_bytes())
274 });
275}
276
/// A cloneable handle to a shared cache of authenticated SSH connections.
///
/// Clones share the same cache, which lives behind an `Arc`. Connections are
/// keyed by `user@host:port`.
#[derive(Clone)]
pub struct SshPool {
    inner: Arc<PoolInner>,
}
286
/// State shared by every clone of an `SshPool`.
struct PoolInner {
    /// Cached connections keyed by `user@host:port`. `SshPool::connect` holds this
    /// async lock for its whole duration, so connects through one pool are
    /// serialized.
    connections: Mutex<HashMap<String, CachedConnection>>,
}
291
/// One pooled SSH connection.
struct CachedConnection {
    /// The authenticated client handle; replaced by a new one once it reports closed.
    handle: client::Handle<SshHandler>,
    /// Remote blit socket path remembered for this connection, if any.
    remote_socket: Option<String>,
}
297
298impl Default for SshPool {
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
304impl SshPool {
305 pub fn new() -> Self {
306 Self {
307 inner: Arc::new(PoolInner {
308 connections: Mutex::new(HashMap::new()),
309 }),
310 }
311 }
312
313 pub async fn connect(
322 &self,
323 host: &str,
324 user: Option<&str>,
325 remote_socket: Option<&str>,
326 ) -> Result<tokio::io::DuplexStream, Error> {
327 let config = resolve_ssh_config(host);
328 let effective_host = config.hostname.as_deref().unwrap_or(host);
329 let effective_user = user
330 .map(String::from)
331 .or(config.user.clone())
332 .unwrap_or_else(current_username);
333 let effective_port = config.port.unwrap_or(22);
334
335 let key = format!("{effective_user}@{effective_host}:{effective_port}");
336
337 let mut conns = self.inner.connections.lock().await;
338
339 let need_new = match conns.get(&key) {
341 Some(cached) => cached.handle.is_closed(),
342 None => true,
343 };
344
345 if need_new {
346 let handle =
347 establish_connection(effective_host, effective_port, &effective_user, &config)
348 .await?;
349 conns.insert(
350 key.clone(),
351 CachedConnection {
352 handle,
353 remote_socket: None,
354 },
355 );
356 }
357
358 let cached = conns.get_mut(&key).unwrap();
359
360 let socket_path = if let Some(explicit) = remote_socket {
362 explicit.to_string()
363 } else if let Some(ref cached_path) = cached.remote_socket {
364 cached_path.clone()
365 } else {
366 let path = exec_command(&cached.handle, SOCK_SEARCH).await?;
367 let path = path.trim().to_string();
368 if path.is_empty() {
369 return Err(Error::Other(
370 "could not determine remote blit socket path".into(),
371 ));
372 }
373 cached.remote_socket = Some(path.clone());
374 path
375 };
376
377 let channel = match cached
379 .handle
380 .channel_open_direct_streamlocal(&socket_path)
381 .await
382 {
383 Ok(ch) => ch,
384 Err(_first_err) => {
385 let _ = exec_command(&cached.handle, &install_and_start_script(&socket_path)).await;
387 let mut last_err = _first_err;
390 for attempt in 0..10 {
391 tokio::time::sleep(std::time::Duration::from_millis(100 * (attempt + 1))).await;
392 match cached
393 .handle
394 .channel_open_direct_streamlocal(&socket_path)
395 .await
396 {
397 Ok(ch) => return Ok(bridge_channel(ch)),
398 Err(e) => last_err = e,
399 }
400 }
401 return Err(Error::Other(format!(
402 "failed to connect to {socket_path} after install: {last_err}"
403 )));
404 }
405 };
406
407 Ok(bridge_channel(channel))
408 }
409}
410
411fn bridge_channel(channel: russh::Channel<russh::client::Msg>) -> tokio::io::DuplexStream {
414 let stream = channel.into_stream();
415 let (client, server) = tokio::io::duplex(64 * 1024);
416 tokio::spawn(async move {
417 let (mut sr, mut sw) = tokio::io::split(server);
418 let (mut cr, mut cw) = tokio::io::split(stream);
419 tokio::select! {
420 _ = tokio::io::copy(&mut cr, &mut sw) => {}
421 _ = tokio::io::copy(&mut sr, &mut cw) => {}
422 }
423 });
424 client
425}
426
427async fn establish_connection(
430 host: &str,
431 port: u16,
432 user: &str,
433 config: &ResolvedConfig,
434) -> Result<client::Handle<SshHandler>, Error> {
435 let ssh_config = client::Config {
436 ..Default::default()
437 };
438
439 let handler = SshHandler {
440 host: host.to_string(),
441 port,
442 };
443
444 let mut handle = client::connect(Arc::new(ssh_config), (host, port), handler).await?;
445
446 if try_agent_auth(&mut handle, user).await {
448 return Ok(handle);
449 }
450
451 if try_key_file_auth(&mut handle, user, config).await? {
453 return Ok(handle);
454 }
455
456 Err(Error::Other(format!(
457 "authentication failed for {user}@{host}:{port} \
458 (tried ssh-agent and key files)"
459 )))
460}
461
462#[cfg(unix)]
464async fn try_agent_auth(handle: &mut client::Handle<SshHandler>, user: &str) -> bool {
465 let agent_path = match std::env::var("SSH_AUTH_SOCK") {
466 Ok(p) if !p.is_empty() => p,
467 _ => return false,
468 };
469 let stream = match tokio::net::UnixStream::connect(&agent_path).await {
470 Ok(s) => s,
471 Err(e) => {
472 log::debug!("ssh-agent connect failed: {e}");
473 return false;
474 }
475 };
476 let mut agent = agent::client::AgentClient::connect(stream);
477 let identities = match agent.request_identities().await {
478 Ok(ids) => ids,
479 Err(e) => {
480 log::debug!("ssh-agent request_identities failed: {e}");
481 return false;
482 }
483 };
484 for identity in &identities {
485 let public_key = identity.public_key().into_owned();
486 match handle
487 .authenticate_publickey_with(user, public_key, None, &mut agent)
488 .await
489 {
490 Ok(russh::client::AuthResult::Success) => return true,
491 Ok(_) => continue,
492 Err(e) => {
493 log::debug!("ssh-agent auth attempt failed: {e}");
494 continue;
495 }
496 }
497 }
498 false
499}
500
/// Agent authentication is only implemented for Unix-domain `SSH_AUTH_SOCK`
/// sockets. On other platforms it always reports failure, so the caller moves on
/// to key files.
#[cfg(not(unix))]
async fn try_agent_auth(_session: &mut client::Handle<SshHandler>, _login: &str) -> bool {
    false
}
506
507async fn try_key_file_auth(
509 handle: &mut client::Handle<SshHandler>,
510 user: &str,
511 config: &ResolvedConfig,
512) -> Result<bool, Error> {
513 let home = match home_dir() {
514 Some(h) => h,
515 None => return Ok(false),
516 };
517
518 let mut candidates: Vec<PathBuf> = config.identity_files.clone();
520 for default in &["id_ed25519", "id_ecdsa", "id_rsa"] {
521 let p = home.join(".ssh").join(default);
522 if !candidates.contains(&p) {
523 candidates.push(p);
524 }
525 }
526
527 for path in &candidates {
528 if !path.exists() {
529 continue;
530 }
531 let key = match keys::load_secret_key(path, None) {
532 Ok(k) => k,
533 Err(e) => {
534 log::debug!("could not load {}: {e}", path.display());
535 continue;
536 }
537 };
538
539 let hash_alg = handle.best_supported_rsa_hash().await.ok().flatten();
541 let key_with_hash = PrivateKeyWithHashAlg::new(Arc::new(key), hash_alg.flatten());
542
543 match handle.authenticate_publickey(user, key_with_hash).await {
544 Ok(russh::client::AuthResult::Success) => return Ok(true),
545 Ok(_) => continue,
546 Err(e) => {
547 log::debug!("key auth failed for {}: {e}", path.display());
548 continue;
549 }
550 }
551 }
552 Ok(false)
553}
554
555async fn exec_command(handle: &client::Handle<SshHandler>, cmd: &str) -> Result<String, Error> {
559 let mut channel = handle.channel_open_session().await?;
560 channel.exec(true, cmd.as_bytes()).await?;
561
562 let mut output = Vec::new();
563 while let Some(msg) = channel.wait().await {
564 match msg {
565 russh::ChannelMsg::Data { data } => output.extend_from_slice(&data),
566 russh::ChannelMsg::Eof | russh::ChannelMsg::Close => break,
567 _ => continue,
568 }
569 }
570 Ok(String::from_utf8_lossy(&output).into_owned())
571}
572
/// The current user's home directory, or `None` if it cannot be determined.
///
/// Delegates to `std::env::home_dir` (no longer deprecated since Rust 1.87):
/// * Unix: `$HOME`, falling back to the password database when `$HOME` is unset.
/// * Windows: the user profile directory.
/// * It also compiles on every other target, where the old per-cfg version had
///   an empty body.
///
/// An empty result is treated as unknown, so callers never build paths such as
/// `.ssh/config` relative to the current working directory.
fn home_dir() -> Option<PathBuf> {
    std::env::home_dir().filter(|dir| !dir.as_os_str().is_empty())
}
585
/// Local login name, used when neither the caller nor `~/.ssh/config` names a user.
///
/// * Unix: `$USER`, then `$LOGNAME` (the POSIX login-name variable), skipping unset
///   or empty values, with `"root"` as the last resort.
/// * Windows: `%USERNAME%`, falling back to `"user"`.
/// * Other targets: `"user"`.
fn current_username() -> String {
    #[cfg(unix)]
    let (vars, fallback): (&[&str], &str) = (&["USER", "LOGNAME"], "root");
    #[cfg(windows)]
    let (vars, fallback): (&[&str], &str) = (&["USERNAME"], "user");
    #[cfg(not(any(unix, windows)))]
    let (vars, fallback): (&[&str], &str) = (&[], "user");

    vars.iter()
        .filter_map(|name| std::env::var(name).ok())
        .find(|value| !value.is_empty())
        .unwrap_or_else(|| fallback.to_string())
}
596
/// Split an `[user@]host[:socket]` target into `(user, host, socket)`.
///
/// The target ends at the first `:` outside `[...]`, so the socket path may
/// itself contain `@` or `:`. The user is everything before the last `@` of the
/// target. A bracketed host such as `[::1]` is returned without its brackets.
/// An empty path after the colon (`host:`) counts as no socket, and the colon is
/// not kept in the host.
///
/// Examples: `alice@box:/tmp/b.sock` becomes `(Some("alice"), "box",
/// Some("/tmp/b.sock"))`, and `box:/run/blit@x.sock` becomes `(None, "box",
/// Some("/run/blit@x.sock"))`.
pub fn parse_ssh_uri(s: &str) -> (Option<String>, String, Option<String>) {
    let (target, socket) = match socket_separator(s) {
        Some(colon) => {
            let path = &s[colon + 1..];
            (&s[..colon], (!path.is_empty()).then(|| path.to_string()))
        }
        None => (s, None),
    };
    let (user, host) = match target.rfind('@') {
        Some(at) => (Some(target[..at].to_string()), &target[at + 1..]),
        None => (None, target),
    };
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    (user, host.to_string(), socket)
}

/// Byte index of the `:` that separates the target from the socket path,
/// ignoring colons inside a bracketed IPv6 literal.
fn socket_separator(s: &str) -> Option<usize> {
    let mut in_brackets = false;
    s.char_indices().find_map(|(i, c)| {
        match c {
            '[' => in_brackets = true,
            ']' => in_brackets = false,
            ':' if !in_brackets => return Some(i),
            _ => {}
        }
        None
    })
}