use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use russh::client::{self, Handle};
use russh::ChannelMsg;
use tokio::time;
use crate::ssh::client::Host;
struct MetricsHandler {
host: String,
port: u16,
}
#[async_trait]
impl client::Handler for MetricsHandler {
type Error = russh::Error;
async fn check_server_key(
&mut self,
server_public_key: &russh::keys::key::PublicKey,
) -> Result<bool, Self::Error> {
match russh::keys::check_known_hosts(&self.host, self.port, server_public_key) {
Ok(true) => Ok(true), Ok(false) => Ok(true), Err(russh::keys::Error::KeyChanged { .. }) => {
tracing::warn!(
host = %self.host,
port = self.port,
"Server key mismatch in known_hosts — possible MITM attack, refusing connection"
);
Ok(false)
}
Err(e) => {
tracing::warn!(error = %e, "known_hosts check failed; accepting key");
Ok(true)
}
}
}
}
/// An authenticated SSH connection to a single [`Host`].
///
/// Cloning is cheap. All clones share the same underlying russh handle, so
/// commands issued through any clone run over the same connection.
#[derive(Clone)]
pub struct SshSession {
    handle: Arc<Handle<MetricsHandler>>,
}

impl SshSession {
    /// Connects to `host`, performs the SSH handshake and authenticates.
    ///
    /// Connecting plus the handshake is bounded by a 10 s timeout.
    /// Authentication is delegated to `authenticate`.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the connection times out;
    /// - the handshake fails, including a host key rejected by the handler;
    /// - authentication errors or no method succeeds.
    ///
    /// In the last case an explicit SSH disconnect is sent before returning.
    pub async fn connect(host: &Host) -> anyhow::Result<Self> {
        let config = Arc::new(client::Config {
            inactivity_timeout: Some(Duration::from_secs(30)),
            keepalive_interval: Some(Duration::from_secs(15)),
            keepalive_max: 3,
            ..Default::default()
        });
        let addr = format!("{}:{}", host.hostname, host.port);
        let handler = MetricsHandler {
            host: host.hostname.clone(),
            port: host.port,
        };
        let mut handle = time::timeout(
            Duration::from_secs(10),
            client::connect(config, addr, handler),
        )
        .await
        .map_err(|_| anyhow!("SSH connection timed out (10 s)"))?
        .context("SSH connection failed")?;

        match authenticate(&mut handle, host).await {
            Ok(true) => Ok(Self {
                handle: Arc::new(handle),
            }),
            other => {
                // Send an explicit disconnect rather than only dropping the
                // handle. Best-effort: a failure here must not mask the
                // authentication error reported below.
                let _ = handle
                    .disconnect(russh::Disconnect::ByApplication, "", "en")
                    .await;
                match other {
                    Err(e) => Err(e),
                    Ok(_) => Err(anyhow!("SSH authentication failed for {}", host.name)),
                }
            }
        }
    }

    /// Runs `cmd` on a fresh session channel and returns its collected stdout.
    ///
    /// Reading the output is bounded by a 30 s timeout. The command's exit
    /// status is not checked, and stderr is not included in the result.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - the channel cannot be opened;
    /// - the exec request fails;
    /// - reading output fails or times out.
    pub async fn run_command(&self, cmd: &str) -> anyhow::Result<String> {
        let mut channel = self
            .handle
            .channel_open_session()
            .await
            .context("open SSH channel")?;
        channel.exec(true, cmd).await.context("exec SSH command")?;
        let output = time::timeout(Duration::from_secs(30), collect_output(&mut channel))
            .await
            .map_err(|_| anyhow!("command timed out (30 s): {}", cmd))?
            .context("read command output")?;
        Ok(output)
    }

    /// Opens a new session channel, requests the `sftp` subsystem on it, and
    /// returns the channel as a byte stream.
    ///
    /// # Errors
    /// Returns an error if the channel cannot be opened or the subsystem
    /// request fails.
    pub async fn open_sftp_channel(
        &self,
    ) -> anyhow::Result<russh::ChannelStream<russh::client::Msg>> {
        let channel = self
            .handle
            .channel_open_session()
            .await
            .context("open SFTP session channel")?;
        channel
            .request_subsystem(true, "sftp")
            .await
            .context("request SFTP subsystem")?;
        Ok(channel.into_stream())
    }

    /// Sends an SSH disconnect (reason `ByApplication`) and ignores any error.
    ///
    /// Clones of this session share the same handle, so they are affected too.
    pub async fn disconnect(self) {
        let _ = self
            .handle
            .disconnect(russh::Disconnect::ByApplication, "", "en")
            .await;
    }
}
/// Tries each supported authentication method in turn, stopping at the first success.
///
/// Methods are tried in this order:
/// 1. the SSH agent (Unix only);
/// 2. the host's configured `identity_file`;
/// 3. the default key files under `~/.ssh` that exist;
/// 4. the host's password, if configured.
///
/// A failing attempt is not fatal. Its error is logged at debug level and the
/// next method is tried. A default key path that is the same file as the
/// configured identity is skipped. That key was just offered, and every rejected
/// offer counts toward the server's authentication-attempt limit.
///
/// Returns `Ok(true)` once authenticated, `Ok(false)` if every method failed.
async fn authenticate(handle: &mut Handle<MetricsHandler>, host: &Host) -> anyhow::Result<bool> {
    let user: &str = &host.user;

    #[cfg(unix)]
    {
        if auth_attempt_succeeded(try_agent_auth(handle, user).await, "agent") {
            return Ok(true);
        }
    }

    // Expanded path of the explicit identity, so the default scan can skip it.
    let mut explicit_key: Option<String> = None;
    if let Some(key_path) = &host.identity_file {
        let path = expand_tilde(key_path);
        if auth_attempt_succeeded(try_key_auth(handle, user, &path).await, "identity_file") {
            return Ok(true);
        }
        explicit_key = Some(path);
    }

    for key_path in default_key_paths() {
        if !key_path.exists() {
            continue;
        }
        if explicit_key.as_deref().map(std::path::Path::new) == Some(key_path.as_path()) {
            continue;
        }
        let path_str = key_path.to_string_lossy().into_owned();
        if auth_attempt_succeeded(try_key_auth(handle, user, &path_str).await, "default_key") {
            return Ok(true);
        }
    }

    if let Some(password) = &host.password {
        if auth_attempt_succeeded(try_password_auth(handle, user, password).await, "password") {
            tracing::info!(
                host = %host.name,
                "Connected via password authentication — consider setting up SSH key"
            );
            return Ok(true);
        }
    }
    Ok(false)
}

/// Collapses the outcome of one best-effort auth attempt into a bool.
///
/// Errors count as "not authenticated". They are logged at debug level rather
/// than discarded silently.
fn auth_attempt_succeeded(result: anyhow::Result<bool>, method: &str) -> bool {
    result.unwrap_or_else(|e| {
        tracing::debug!(
            method = method,
            error = %format!("{e:#}"),
            "SSH authentication attempt errored"
        );
        false
    })
}
/// Candidate private-key files under `~/.ssh`, in the order they should be tried.
///
/// Returns an empty list when the home directory cannot be determined. The
/// returned paths are not checked for existence.
fn default_key_paths() -> Vec<std::path::PathBuf> {
    const KEY_FILE_NAMES: [&str; 6] = [
        "id_ed25519",
        "id_rsa",
        "id_ecdsa",
        "id_ecdsa_sk",
        "id_ed25519_sk",
        "id_dsa",
    ];
    match dirs::home_dir() {
        Some(home_dir) => {
            let ssh_dir = home_dir.join(".ssh");
            KEY_FILE_NAMES
                .iter()
                .map(|file_name| ssh_dir.join(file_name))
                .collect()
        }
        None => Vec::new(),
    }
}
/// Attempts public-key authentication as `user` with the private key at `key_path`.
///
/// The key file is read and parsed on the blocking thread pool. No passphrase is
/// supplied, so passphrase-protected keys presumably fail to load.
///
/// Returns whether the server accepted the key.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the blocking task fails;
/// - the key cannot be loaded;
/// - the authentication request itself fails.
async fn try_key_auth(
    handle: &mut Handle<MetricsHandler>,
    user: &str,
    key_path: &str,
) -> anyhow::Result<bool> {
    let owned_path = key_path.to_owned();
    let loaded = tokio::task::spawn_blocking(move || {
        russh::keys::load_secret_key(&owned_path, None)
            .with_context(|| format!("load key from {owned_path}"))
    })
    .await
    .context("spawn_blocking panicked")?;
    let secret_key = loaded?;

    handle
        .authenticate_publickey(user, Arc::new(secret_key))
        .await
        .context("authenticate_publickey")
}
/// Attempts authentication as `user` with each identity held by the SSH agent.
///
/// The agent is located with `AgentClient::connect_env`. Identities are offered
/// one at a time. An identity that is rejected, or whose attempt errors, is
/// skipped.
///
/// Returns `Ok(true)` on the first accepted identity and `Ok(false)` if none is
/// accepted.
///
/// # Errors
/// Returns an error if the agent cannot be reached or its identities cannot be
/// listed.
#[cfg(unix)]
async fn try_agent_auth(handle: &mut Handle<MetricsHandler>, user: &str) -> anyhow::Result<bool> {
    use russh::keys::agent::client::AgentClient;

    let mut agent_client = AgentClient::connect_env()
        .await
        .context("connect to SSH agent")?;
    let offered_keys = agent_client
        .request_identities()
        .await
        .context("request agent identities")?;

    for public_key in offered_keys {
        // authenticate_future takes the agent by value and hands it back.
        let (returned_client, outcome) = handle
            .authenticate_future(user, public_key, agent_client)
            .await;
        agent_client = returned_client;
        if matches!(outcome, Ok(true)) {
            return Ok(true);
        }
    }
    Ok(false)
}
/// Attempts password authentication as `user`.
///
/// Returns whether the server accepted the password.
///
/// # Errors
/// Returns an error if the authentication request itself fails.
async fn try_password_auth(
    handle: &mut Handle<MetricsHandler>,
    user: &str,
    password: &str,
) -> anyhow::Result<bool> {
    handle
        .authenticate_password(user, password)
        .await
        .context("authenticate with password")
}
/// Reads a channel's stdout until the server signals EOF or closes the channel.
///
/// Stderr (`ExtendedData`) is discarded. The collected bytes are decoded as
/// UTF-8 lossily. Line endings are normalised to `\n`, and every line, including
/// the last, ends with `\n`. Empty output yields an empty string.
///
/// The exit-status message is deliberately *not* treated as the end of output.
/// A server may send it before the remaining stdout data, so stopping there
/// could truncate the result.
async fn collect_output(
    channel: &mut russh::Channel<russh::client::Msg>,
) -> anyhow::Result<String> {
    let mut buf = Vec::new();
    while let Some(msg) = channel.wait().await {
        match msg {
            ChannelMsg::Data { ref data } => buf.extend_from_slice(data),
            // Stderr is intentionally not part of the returned output.
            ChannelMsg::ExtendedData { .. } => {}
            // No more stdout can arrive after EOF; Close ends the channel.
            ChannelMsg::Eof | ChannelMsg::Close => break,
            // ExitStatus and anything else: keep reading.
            _ => {}
        }
    }
    let raw = String::from_utf8_lossy(&buf);
    let normalised: String = raw.lines().flat_map(|l| [l, "\n"]).collect();
    Ok(normalised)
}
/// Expands a leading `~` to the user's home directory.
///
/// Only two forms are expanded: exactly `~`, or a path starting with `~/`.
/// Any other form, including `~user/...`, is returned unchanged. The input is
/// also returned unchanged when the home directory cannot be determined.
fn expand_tilde(path: &str) -> String {
    let home_relative = path == "~" || path.starts_with("~/");
    if !home_relative {
        return path.to_owned();
    }
    dirs::home_dir()
        .map(|home| path.replacen('~', &home.to_string_lossy(), 1))
        .unwrap_or_else(|| path.to_owned())
}