use async_trait::async_trait;
use russh::client::AuthResult;
use russh::keys::{self, PrivateKeyWithHashAlg};
use russh::*;
use std::borrow::Cow;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::error::TransportError;
use crate::transport::Transport;
/// NETCONF-over-SSH transport backed by a russh client session.
///
/// Built by [`SshTransport::connect`]; all reads and writes go through the
/// session channel on which the `netconf` subsystem was requested.
pub struct SshTransport {
    /// Client session handle; `close` uses it to disconnect the session.
    handle: Arc<Mutex<client::Handle<SshHandler>>>,
    /// The `netconf` subsystem channel plus bytes received but not yet read.
    channel: Arc<Mutex<ChannelStream>>,
}
/// Credentials used to authenticate the SSH session.
///
/// `Debug` is implemented by hand so that passwords and key passphrases are
/// never written to logs or panic messages.
#[derive(Clone)]
pub enum SshAuth {
    /// Password authentication.
    Password(String),
    /// Public-key authentication with a private key read from `path`,
    /// optionally protected by `passphrase`.
    KeyFile {
        path: String,
        passphrase: Option<String>,
    },
    /// Public-key authentication using identities from the SSH agent located
    /// through the environment.
    Agent,
}

impl std::fmt::Debug for SshAuth {
    // Formats the variant with every secret replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            SshAuth::Password(_) => f.debug_tuple("Password").field(&REDACTED).finish(),
            SshAuth::KeyFile { path, passphrase } => f
                .debug_struct("KeyFile")
                .field("path", path)
                // Reveal only whether a passphrase is set, never its value.
                .field("passphrase", &passphrase.as_ref().map(|_| REDACTED))
                .finish(),
            SshAuth::Agent => f.write_str("Agent"),
        }
    }
}

/// Policy for deciding whether to trust the server's host key.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum HostKeyVerification {
    /// Trust any host key (logs a warning). This is the default; it is unsafe
    /// against man-in-the-middle attacks.
    #[default]
    AcceptAll,
    /// Trust only a key whose SHA-256 fingerprint equals this value, given
    /// with or without the `SHA256:` prefix.
    Fingerprint(String),
    /// Refuse every host key.
    RejectAll,
}

/// Everything needed to open a NETCONF-over-SSH session.
#[derive(Clone, Debug)]
pub struct SshConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Remote user to authenticate as.
    pub username: String,
    /// Authentication method and credentials.
    pub auth: SshAuth,
    /// Host-key trust policy.
    pub host_key_verification: HostKeyVerification,
}
/// russh client callback handler; its only customisation is host-key trust.
struct SshHandler {
    /// Policy applied in `check_server_key`.
    host_key_verification: HostKeyVerification,
}

impl client::Handler for SshHandler {
    type Error = russh::Error;

    /// Decides synchronously whether to trust `server_public_key` under the
    /// configured policy.
    ///
    /// A `Fingerprint` policy compares against the key's SHA-256 fingerprint,
    /// accepting the expected value either with or without the `SHA256:` prefix.
    fn check_server_key(
        &mut self,
        server_public_key: &keys::PublicKey,
    ) -> impl std::future::Future<Output = Result<bool, Self::Error>> + Send {
        let trusted = match &self.host_key_verification {
            HostKeyVerification::RejectAll => {
                tracing::error!("SSH host key rejected (RejectAll policy)");
                false
            }
            HostKeyVerification::AcceptAll => {
                tracing::warn!(
                    "accepting SSH host key without verification — \
                     set host_key_verification() for production use"
                );
                true
            }
            HostKeyVerification::Fingerprint(expected) => {
                let actual = server_public_key
                    .fingerprint(keys::HashAlg::Sha256)
                    .to_string();
                let without_prefix = actual.strip_prefix("SHA256:");
                if actual == *expected || without_prefix == Some(expected.as_str()) {
                    tracing::debug!("SSH host key fingerprint verified");
                    true
                } else {
                    tracing::error!(
                        expected = %expected,
                        actual = %actual,
                        "SSH host key fingerprint mismatch — possible MITM attack"
                    );
                    false
                }
            }
        };
        // No I/O is needed, so hand back an already-completed future.
        std::future::ready(Ok::<bool, Self::Error>(trusted))
    }
}
/// An open SSH session channel paired with data already received from it.
struct ChannelStream {
    /// Bytes received from the channel but not yet returned to a reader.
    read_buffer: Vec<u8>,
    /// The session channel on which the `netconf` subsystem was requested.
    channel: Channel<client::Msg>,
}
impl SshTransport {
pub async fn connect(config: SshConfig) -> Result<Self, TransportError> {
let preferred = Preferred {
kex: Cow::Borrowed(&[
kex::CURVE25519,
kex::CURVE25519_PRE_RFC_8731,
kex::ECDH_SHA2_NISTP256,
kex::ECDH_SHA2_NISTP384,
kex::ECDH_SHA2_NISTP521,
kex::DH_G16_SHA512,
kex::DH_G14_SHA256,
kex::EXTENSION_SUPPORT_AS_CLIENT,
kex::EXTENSION_OPENSSH_STRICT_KEX_AS_CLIENT,
]),
..Preferred::default()
};
let russh_config = client::Config {
preferred,
..Default::default()
};
let handler = SshHandler {
host_key_verification: config.host_key_verification.clone(),
};
let mut handle = client::connect(Arc::new(russh_config), (&*config.host, config.port), handler)
.await
.map_err(|e| TransportError::Connect(format!("SSH connect to {}:{} failed: {e}", config.host, config.port)))?;
let auth_result = match &config.auth {
SshAuth::Password(password) => {
handle
.authenticate_password(&config.username, password)
.await
.map_err(|e| TransportError::Auth(format!("password auth failed: {e}")))?
}
SshAuth::KeyFile { path, passphrase } => {
let key_path = Path::new(path);
let key_pair = if let Some(pass) = passphrase {
keys::decode_secret_key(&std::fs::read_to_string(key_path)
.map_err(|e| {
tracing::debug!(path, %e, "failed to read key file");
TransportError::Auth("failed to read SSH key file".to_string())
})?, Some(pass))
.map_err(|e| {
tracing::debug!(%e, "failed to decode key");
TransportError::Auth("failed to decode SSH key".to_string())
})?
} else {
keys::decode_secret_key(&std::fs::read_to_string(key_path)
.map_err(|e| {
tracing::debug!(path, %e, "failed to read key file");
TransportError::Auth("failed to read SSH key file".to_string())
})?, None)
.map_err(|e| {
tracing::debug!(%e, "failed to decode key");
TransportError::Auth("failed to decode SSH key".to_string())
})?
};
let hash_alg = handle.best_supported_rsa_hash().await
.unwrap_or(None)
.flatten();
let key_with_hash = PrivateKeyWithHashAlg::new(
Arc::new(key_pair),
hash_alg,
);
handle
.authenticate_publickey(&config.username, key_with_hash)
.await
.map_err(|e| TransportError::Auth(format!("key auth failed: {e}")))?
}
SshAuth::Agent => {
let mut agent = keys::agent::client::AgentClient::connect_env()
.await
.map_err(|e| TransportError::Auth(format!("SSH agent connect failed: {e}")))?;
let identities = agent
.request_identities()
.await
.map_err(|e| TransportError::Auth(format!("SSH agent identities failed: {e}")))?;
let mut auth_success = false;
for public_key in identities {
match handle
.authenticate_publickey_with(&config.username, public_key, None, &mut agent)
.await
{
Ok(AuthResult::Success) => {
auth_success = true;
break;
}
_ => continue,
}
}
if auth_success {
AuthResult::Success
} else {
AuthResult::Failure {
remaining_methods: russh::MethodSet::empty(),
partial_success: false,
}
}
}
};
if !matches!(auth_result, AuthResult::Success) {
return Err(TransportError::Auth(format!(
"authentication failed for user '{}'",
config.username
)));
}
let mut channel = handle
.channel_open_session()
.await
.map_err(|e| TransportError::Channel(format!("failed to open SSH channel: {e}")))?;
channel
.request_subsystem(true, "netconf")
.await
.map_err(|e| TransportError::Channel(format!("failed to request netconf subsystem: {e}")))?;
loop {
match channel.wait().await {
Some(ChannelMsg::Success) => break,
Some(ChannelMsg::Failure) => {
return Err(TransportError::Channel(
"server rejected netconf subsystem request".to_string(),
));
}
Some(ChannelMsg::WindowAdjusted { .. }) => {
continue;
}
Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) | None => {
return Err(TransportError::ChannelClosed(
"channel closed before subsystem confirmation".to_string(),
));
}
Some(_other) => {
continue;
}
}
}
let channel_stream = ChannelStream {
channel,
read_buffer: Vec::new(),
};
Ok(Self {
channel: Arc::new(Mutex::new(channel_stream)),
handle: Arc::new(Mutex::new(handle)),
})
}
}
#[async_trait]
impl Transport for SshTransport {
async fn write_all(&mut self, data: &[u8]) -> Result<(), TransportError> {
let channel = self.channel.lock().await;
channel
.channel
.data(data)
.await
.map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
Ok(())
}
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TransportError> {
let mut channel = self.channel.lock().await;
if !channel.read_buffer.is_empty() {
let to_read = std::cmp::min(buf.len(), channel.read_buffer.len());
buf[..to_read].copy_from_slice(&channel.read_buffer[..to_read]);
channel.read_buffer.drain(..to_read);
return Ok(to_read);
}
loop {
match channel.channel.wait().await {
Some(ChannelMsg::Data { data: channel_data }) => {
let bytes = &channel_data[..];
let to_copy = std::cmp::min(buf.len(), bytes.len());
buf[..to_copy].copy_from_slice(&bytes[..to_copy]);
if bytes.len() > to_copy {
channel.read_buffer.extend_from_slice(&bytes[to_copy..]);
}
return Ok(to_copy);
}
Some(ChannelMsg::Eof) | Some(ChannelMsg::Close) | None => {
return Ok(0);
}
Some(_other) => {
continue;
}
}
}
}
async fn close(&mut self) -> Result<(), TransportError> {
let channel = self.channel.lock().await;
channel
.channel
.eof()
.await
.map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
let handle = self.handle.lock().await;
handle
.disconnect(Disconnect::ByApplication, "closing session", "en")
.await
.map_err(|e| TransportError::Io(std::io::Error::other(e.to_string())))?;
Ok(())
}
}