use std::path::{Path, PathBuf};
use crate::error::Result;
/// Connection parameters for one SSH session.
///
/// `Debug` is implemented by hand further down in this file instead of being
/// derived on the struct.
#[derive(Clone)]
pub struct SshConfig {
/// Remote host name or address; paired with `port` when connecting and also
/// handed to the host-key check as the host identifier.
pub host: String,
/// Remote TCP port. `SshConfig::new` initialises this to 22.
pub port: u16,
/// Login name presented during authentication.
pub user: String,
/// Credential used to authenticate `user`.
pub auth: SshAuth,
/// Policy that decides whether the server's host key is trusted.
pub host_key: HostKeyVerification,
}
impl SshConfig {
    /// Creates a configuration that targets the standard SSH port (22).
    ///
    /// `host` and `user` accept anything convertible into a `String`.
    /// `port` is a public field, so callers needing a different port can
    /// assign it on the returned value.
    pub fn new(
        host: impl Into<String>,
        user: impl Into<String>,
        auth: SshAuth,
        host_key: HostKeyVerification,
    ) -> Self {
        const DEFAULT_PORT: u16 = 22;

        let host: String = host.into();
        let user: String = user.into();
        Self {
            host,
            port: DEFAULT_PORT,
            user,
            auth,
            host_key,
        }
    }
}
/// Credential used to authenticate an SSH session.
///
/// `Debug` is hand-written for this type (see the impl below in this file) so
/// that the password and passphrase are never printed.
#[derive(Clone)]
pub enum SshAuth {
/// Password authentication; the payload is the password itself.
Password(String),
/// Public-key authentication using a private key loaded from disk.
Key {
/// Location of the private key file.
path: PathBuf,
/// Optional passphrase handed to the key loader together with `path`.
passphrase: Option<String>,
},
}
impl std::fmt::Debug for SshAuth {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SshAuth::Password(_) => f.write_str("Password(***)"),
SshAuth::Key { path, .. } => write!(f, "Key {{ path: {path:?}, passphrase: *** }}"),
}
}
}
impl std::fmt::Debug for SshConfig {
    /// Formats the configuration as a regular debug struct.
    ///
    /// The `auth` field goes through `SshAuth`'s own `Debug` impl, so secrets
    /// stay masked in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring makes the compiler flag this impl if a field is added.
        let Self {
            host,
            port,
            user,
            auth,
            host_key,
        } = self;

        let mut out = f.debug_struct("SshConfig");
        out.field("host", host);
        out.field("port", port);
        out.field("user", user);
        out.field("auth", auth);
        out.field("host_key", host_key);
        out.finish()
    }
}
/// Policy used to decide whether a server's host key is trusted.
#[derive(Debug, Clone)]
pub enum HostKeyVerification {
/// Trust only a key whose fingerprint string is exactly equal to this value.
Pinned(String),
/// Trust-on-first-use backed by the file at this path: each line holds a
/// host and a fingerprint separated by whitespace. A host with no entry is
/// accepted and recorded; a host with an entry must present the recorded
/// fingerprint.
KnownHosts(PathBuf),
/// Accept every host key without any verification.
AcceptAny,
}
/// Applies `policy` to the fingerprint a server presented.
///
/// Returns `Ok(true)` when the key is trusted and `Ok(false)` when it is
/// rejected. Only the `KnownHosts` policy can fail, because it is the only one
/// that touches the filesystem.
// Called only from the `ssh`-feature module and from tests, so it is unused
// in builds without that feature.
#[allow(dead_code)]
pub(crate) fn verify_fingerprint(
    policy: &HostKeyVerification,
    host: &str,
    fingerprint: &str,
) -> Result<bool> {
    let trusted = match policy {
        HostKeyVerification::KnownHosts(file) => {
            return verify_known_hosts(file, host, fingerprint);
        }
        HostKeyVerification::Pinned(pin) => pin == fingerprint,
        HostKeyVerification::AcceptAny => true,
    };
    Ok(trusted)
}
/// Trust-on-first-use check against a simple known-hosts file.
///
/// The file holds one `host fingerprint` pair per line, separated by
/// whitespace. The first line whose host field equals `host` decides the
/// outcome: `Ok(true)` if its fingerprint equals `fingerprint`, `Ok(false)`
/// otherwise. If no line matches, the pair is appended to the file (creating
/// it if necessary) and `Ok(true)` is returned.
///
/// # Errors
///
/// * `host` or `fingerprint` is empty or contains whitespace. Such a value
///   could never be matched again by the line parser, so the host would stay
///   unpinned and every future key would be accepted.
/// * The host has no entry and the file contains invalid UTF-8. A damaged file
///   may have lost this host's pin, so first-use trust is refused.
/// * Any I/O error other than the file not existing.
// Reached only through `verify_fingerprint`, which is itself unused without
// the `ssh` feature.
#[allow(dead_code)]
fn verify_known_hosts(path: &Path, host: &str, fingerprint: &str) -> Result<bool> {
    use std::io::{Error as IoError, ErrorKind, Write};

    // Both values are written as whitespace-separated fields and parsed back
    // with `split_whitespace`, so each must be exactly one non-empty token.
    let is_token = |s: &str| !s.is_empty() && !s.contains(char::is_whitespace);
    if !is_token(host) || !is_token(fingerprint) {
        return Err(IoError::new(
            ErrorKind::InvalidInput,
            "known_hosts: host and fingerprint must be non-empty and contain no whitespace",
        )
        .into());
    }

    let raw = match std::fs::read(path) {
        Ok(bytes) => bytes,
        // A missing file is simply an empty trust store.
        Err(e) if e.kind() == ErrorKind::NotFound => Vec::new(),
        Err(e) => return Err(e.into()),
    };
    let damaged = std::str::from_utf8(&raw).is_err();
    // Lossy decoding keeps intact lines usable even when others are damaged.
    let content = String::from_utf8_lossy(&raw);

    let pinned = content.lines().find_map(|line| {
        let mut fields = line.split_whitespace();
        match (fields.next(), fields.next()) {
            (Some(h), Some(fp)) if h == host => Some(fp),
            _ => None,
        }
    });
    if let Some(fp) = pinned {
        return Ok(fp == fingerprint);
    }

    // Unknown host. Do not extend trust if the store is damaged: the entry
    // for this very host may be among the unreadable bytes.
    if damaged {
        return Err(IoError::new(
            ErrorKind::InvalidData,
            "known_hosts: file contains invalid UTF-8; refusing trust-on-first-use",
        )
        .into());
    }

    // If the file does not end with a newline (e.g. after a manual edit),
    // start a fresh line so the new entry is not glued onto the previous one.
    let lead = if content.is_empty() || content.ends_with('\n') {
        ""
    } else {
        "\n"
    };
    let mut file = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open(path)?;
    file.write_all(format!("{lead}{host} {fingerprint}\n").as_bytes())?;
    Ok(true)
}
#[cfg(feature = "ssh")]
mod imp {
use std::sync::mpsc as std_mpsc;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Duration;
use russh::client;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc as tokio_mpsc;
use super::{verify_fingerprint, HostKeyVerification, SshAuth, SshConfig};
use crate::error::{Error, Result};
use crate::transport::Transport;
/// Capacity, in messages, of both bounded byte channels created by
/// `SshTransport::connect` (the outbound tokio channel and the inbound std
/// channel).
const CHANNEL_CAP: usize = 64;
/// Blocking transport over an SSH channel that is serviced by a dedicated
/// background thread.
///
/// Every field is an `Option` so that `Drop` can release the channel ends
/// before joining the thread.
pub struct SshTransport {
/// Sending half for outbound bytes; the background I/O loop receives them.
write_tx: Option<tokio_mpsc::Sender<Vec<u8>>>,
/// Receiving half for inbound byte chunks produced by the background I/O loop.
read_rx: Option<std_mpsc::Receiver<Vec<u8>>>,
/// Handle of the background thread that runs the I/O loop.
thread: Option<JoinHandle<()>>,
}
impl SshTransport {
    /// Opens an SSH session described by `cfg` and returns a transport for it.
    ///
    /// A dedicated thread is spawned that builds a current-thread tokio
    /// runtime and runs the I/O loop on it. This call blocks until that thread
    /// reports whether session setup succeeded.
    ///
    /// # Errors
    ///
    /// Returns the setup error reported by the thread (runtime creation or
    /// session establishment), or a transport error if the thread exits
    /// without reporting anything.
    pub fn connect(cfg: SshConfig) -> Result<Self> {
        let (write_tx, write_rx) = tokio_mpsc::channel::<Vec<u8>>(CHANNEL_CAP);
        let (read_tx, read_rx) = std_mpsc::sync_channel::<Vec<u8>>(CHANNEL_CAP);
        // One-shot style channel carrying the outcome of session setup.
        let (ready_tx, ready_rx) = std_mpsc::channel::<Result<()>>();

        let worker = move || {
            let built = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build();
            match built {
                Ok(rt) => rt.block_on(io_loop(cfg, write_rx, read_tx, ready_tx)),
                Err(e) => {
                    let _ = ready_tx.send(Err(Error::Transport(format!("runtime: {e}"))));
                }
            }
        };
        let thread = std::thread::spawn(worker);

        let startup = ready_rx.recv();
        match startup {
            // The sender was dropped without a report.
            Err(_) => Err(Error::Transport("ssh thread died during connect".into())),
            Ok(Err(e)) => {
                // Setup failed: reap the thread before surfacing its error.
                let _ = thread.join();
                Err(e)
            }
            Ok(Ok(())) => Ok(SshTransport {
                write_tx: Some(write_tx),
                read_rx: Some(read_rx),
                thread: Some(thread),
            }),
        }
    }
}
impl Transport for SshTransport {
fn write_all(&mut self, bytes: &[u8]) -> Result<()> {
let tx = self
.write_tx
.as_ref()
.ok_or_else(|| Error::Transport("ssh session closed".into()))?;
tx.blocking_send(bytes.to_vec())
.map_err(|_| Error::Transport("ssh session closed".into()))
}
fn recv_timeout(&self, dur: Duration) -> Option<Vec<u8>> {
self.read_rx.as_ref()?.recv_timeout(dur).ok()
}
}
impl Drop for SshTransport {
    /// Shuts the background thread down and waits for it to finish.
    fn drop(&mut self) {
        // Release both channel ends first. With the sender gone, the I/O
        // loop's receive yields `None`; with the receiver gone, its send
        // fails. Either makes the loop exit, so the join below can return.
        drop(self.write_tx.take());
        drop(self.read_rx.take());

        if let Some(worker) = self.thread.take() {
            let _ = worker.join();
        }
    }
}
/// Client-side event handler whose job in this file is to vet the server's
/// host key.
struct Handler {
/// Host-key policy to apply to the key the server presents.
policy: HostKeyVerification,
/// Host identifier passed to the policy check together with the fingerprint.
host: String,
}
impl client::Handler for Handler {
    type Error = russh::Error;

    /// Decides whether to trust the server's public key.
    ///
    /// The key's SHA-256 fingerprint is rendered as a string and checked
    /// against the configured policy. A policy error is treated as a
    /// rejection, so the check fails closed.
    async fn check_server_key(
        &mut self,
        server_public_key: &russh::keys::ssh_key::PublicKey,
    ) -> std::result::Result<bool, Self::Error> {
        let algorithm = russh::keys::ssh_key::HashAlg::Sha256;
        let presented = server_public_key.fingerprint(algorithm).to_string();
        let verdict = verify_fingerprint(&self.policy, &self.host, &presented);
        // Only an explicit acceptance counts; errors and refusals both reject.
        Ok(matches!(verdict, Ok(true)))
    }
}
/// Connects, authenticates, and starts `/bin/sh` on a PTY-backed session
/// channel.
///
/// Steps, in order: connect to `cfg.host:cfg.port` with a default client
/// configuration and a host-key-checking handler; authenticate with the
/// configured password or private key; open a session channel; request an
/// `xterm-256color` PTY; execute `/bin/sh`.
///
/// # Errors
///
/// Every failing step is mapped to `Error::Transport` with a short prefix
/// naming the step. A completed but unsuccessful authentication is reported
/// as "ssh authentication failed".
async fn establish(
    cfg: &SshConfig,
) -> Result<(client::Handle<Handler>, russh::Channel<client::Msg>)> {
    let handler = Handler {
        policy: cfg.host_key.clone(),
        host: cfg.host.clone(),
    };
    let client_config = Arc::new(client::Config::default());
    let address = (cfg.host.as_str(), cfg.port);
    let mut session = client::connect(client_config, address, handler)
        .await
        .map_err(|e| Error::Transport(format!("ssh connect: {e}")))?;

    let auth_outcome = match &cfg.auth {
        SshAuth::Key { path, passphrase } => {
            let secret = russh::keys::load_secret_key(path, passphrase.as_deref())
                .map_err(|e| Error::Transport(format!("load key: {e}")))?;
            // Best effort: a failed query simply yields no hash preference.
            let rsa_hash = session
                .best_supported_rsa_hash()
                .await
                .ok()
                .flatten()
                .flatten();
            let signer = russh::keys::PrivateKeyWithHashAlg::new(Arc::new(secret), rsa_hash);
            session
                .authenticate_publickey(cfg.user.clone(), signer)
                .await
                .map_err(|e| Error::Transport(format!("ssh auth: {e}")))?
        }
        SshAuth::Password(password) => session
            .authenticate_password(cfg.user.clone(), password.clone())
            .await
            .map_err(|e| Error::Transport(format!("ssh auth: {e}")))?,
    };
    if !auth_outcome.success() {
        return Err(Error::Transport("ssh authentication failed".into()));
    }

    let shell = session
        .channel_open_session()
        .await
        .map_err(|e| Error::Transport(format!("open channel: {e}")))?;
    shell
        .request_pty(false, "xterm-256color", 120, 40, 0, 0, &[])
        .await
        .map_err(|e| Error::Transport(format!("request pty: {e}")))?;
    shell
        .exec(false, "/bin/sh")
        .await
        .map_err(|e| Error::Transport(format!("start shell: {e}")))?;

    Ok((session, shell))
}
/// Body of the background thread: sets the session up, reports the outcome on
/// `ready_tx`, then shuttles bytes until either side goes away.
///
/// * Bytes read from the channel are copied and sent on `read_tx`.
/// * Buffers received on `write_rx` are written to the channel and flushed.
///
/// The loop ends on end-of-stream, a read or write error, or when either
/// channel's other end has been dropped.
async fn io_loop(
    cfg: SshConfig,
    mut write_rx: tokio_mpsc::Receiver<Vec<u8>>,
    read_tx: std_mpsc::SyncSender<Vec<u8>>,
    ready_tx: std_mpsc::Sender<Result<()>>,
) {
    let established = establish(&cfg).await;
    let (handle, channel) = match established {
        Err(e) => {
            let _ = ready_tx.send(Err(e));
            return;
        }
        Ok(parts) => parts,
    };
    let _ = ready_tx.send(Ok(()));

    // Held until this function returns.
    // NOTE(review): presumably dropping the handle would tear the session
    // down — confirm against russh.
    let _keep = handle;

    let (mut rd, mut wr) = tokio::io::split(channel.into_stream());
    let mut buf = [0u8; 8192];
    loop {
        tokio::select! {
            incoming = rd.read(&mut buf) => {
                let n = match incoming {
                    Ok(n) if n > 0 => n,
                    // End of stream or read error.
                    _ => break,
                };
                // NOTE(review): this is a blocking send on a bounded std
                // channel from inside the runtime; it stalls the loop while
                // the consumer is not draining.
                if read_tx.send(buf[..n].to_vec()).is_err() {
                    break;
                }
            }
            outgoing = write_rx.recv() => {
                // `None` means every sender has been dropped.
                let Some(bytes) = outgoing else { break };
                if wr.write_all(&bytes).await.is_err() {
                    break;
                }
                let _ = wr.flush().await;
            }
        }
    }
}
}
#[cfg(feature = "ssh")]
pub use imp::SshTransport;
#[cfg(test)]
mod tests {
use super::*;
/// A pinned policy accepts the pinned fingerprint and nothing else.
#[test]
fn pinned_matches_only_exact() {
    let policy = HostKeyVerification::Pinned(String::from("SHA256:abc"));
    let accepts = |fp: &str| verify_fingerprint(&policy, "h", fp).unwrap();

    assert!(accepts("SHA256:abc"));
    assert!(!accepts("SHA256:evil"));
}
/// First contact is trusted and recorded, the same key is accepted again, a
/// changed key is rejected, and a second host gets its own first-use entry.
#[test]
fn known_hosts_tofu_then_pins_and_detects_change() {
    let file = std::env::temp_dir().join(format!("execkit_kh_test_{}", std::process::id()));
    // Start from a clean slate in case an earlier run left the file behind.
    let _ = std::fs::remove_file(&file);

    let policy = HostKeyVerification::KnownHosts(file.clone());
    let check = |host: &str, fp: &str| verify_fingerprint(&policy, host, fp).unwrap();

    assert!(check("prod-1", "SHA256:good"));
    assert!(check("prod-1", "SHA256:good"));
    assert!(!check("prod-1", "SHA256:evil"));
    assert!(check("prod-2", "SHA256:other"));

    let _ = std::fs::remove_file(&file);
}
/// Debug output must never contain a secret: not the password, not a key
/// passphrase, and not when the credential is printed via `SshConfig`.
#[test]
fn auth_debug_redacts_secrets() {
    let password = SshAuth::Password("hunter2".into());
    assert!(!format!("{password:?}").contains("hunter2"));

    let key = SshAuth::Key {
        path: "id_test_key".into(),
        passphrase: Some("s3cr3t-phrase".into()),
    };
    let shown = format!("{key:?}");
    assert!(!shown.contains("s3cr3t-phrase"));
    // The path is not a secret and stays visible for diagnostics.
    assert!(shown.contains("id_test_key"));

    // The containing config delegates to the same redacting impl.
    let cfg = SshConfig::new(
        "example-host",
        "user",
        SshAuth::Password("hunter2".into()),
        HostKeyVerification::AcceptAny,
    );
    assert!(!format!("{cfg:?}").contains("hunter2"));
    assert!(!format!("{cfg:#?}").contains("hunter2"));
}
/// A known-hosts file containing invalid UTF-8 must not cause a mismatching
/// key for an already-pinned host to be accepted.
#[test]
fn known_hosts_corrupt_file_fails_closed() {
    let file = std::env::temp_dir().join(format!("execkit_kh_corrupt_{}", std::process::id()));

    // One valid pin followed by a line of invalid UTF-8.
    let mut contents = Vec::from(&b"prod-1 SHA256:GOODKEY\n"[..]);
    contents.extend_from_slice(b"\xff\xfe bad\n");
    std::fs::write(&file, &contents).unwrap();

    let policy = HostKeyVerification::KnownHosts(file.clone());
    let outcome = verify_fingerprint(&policy, "prod-1", "SHA256:ATTACKER");
    // Clean up before asserting so a failure does not leak the temp file.
    let _ = std::fs::remove_file(&file);

    assert!(
        !matches!(outcome, Ok(true)),
        "SEC-2: corrupt known_hosts silently accepted attacker key (TOFU bypass)"
    );
}
/// With no known-hosts file on disk, the very first connection is trusted.
#[test]
fn known_hosts_absent_file_tofu_preserved() {
    let file = std::env::temp_dir().join(format!("execkit_kh_absent_{}", std::process::id()));
    // Make sure the file really is absent.
    let _ = std::fs::remove_file(&file);

    let policy = HostKeyVerification::KnownHosts(file.clone());
    let accepted = verify_fingerprint(&policy, "new-host", "SHA256:firstkey").unwrap();
    assert!(
        accepted,
        "TOFU must accept first-ever connection when known_hosts is absent"
    );

    let _ = std::fs::remove_file(&file);
}
}