use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use russh::ChannelMsg;
use russh::keys::key::PrivateKeyWithHashAlg;
use tokio::sync::Mutex;
use tracing::{debug, warn};
use crate::error::{PsrpError, Result};
use crate::transport::PsrpTransport;
/// Connection parameters consumed by [`SshPsrpTransport::connect`].
///
/// Build with struct-update syntax over [`Default`], e.g.
/// `SshConfig { host: "srv".into(), ..SshConfig::default() }`.
#[derive(Clone, Debug)]
pub struct SshConfig {
    /// Server host name or address; combined with `port` as `"{host}:{port}"`.
    pub host: String,
    /// TCP port of the SSH server (default `22`).
    pub port: u16,
    /// Account name passed to the authentication request.
    pub username: String,
    /// Authentication method; see [`SshAuth`].
    pub auth: SshAuth,
    /// SSH subsystem requested on the session channel (default `"powershell"`).
    pub subsystem: String,
    /// Timeout applied to the `russh::client::connect` step only (default 30 s);
    /// authentication and channel setup are not covered by it.
    pub connect_timeout: Duration,
}
impl Default for SshConfig {
    /// Port 22, the `"powershell"` subsystem, a 30-second connect timeout and
    /// [`SshAuth::Agent`]. `host` and `username` are left empty for the caller
    /// to set.
    ///
    /// Note: `SshPsrpTransport::connect` currently rejects `SshAuth::Agent`, so
    /// callers must also override `auth`.
    fn default() -> Self {
        let connect_timeout = Duration::from_secs(30);
        Self {
            host: String::new(),
            port: 22,
            username: String::new(),
            auth: SshAuth::Agent,
            subsystem: String::from("powershell"),
            connect_timeout,
        }
    }
}
/// How the client authenticates to the SSH server.
///
/// `Debug` is implemented by hand so that passwords and key passphrases are
/// never written to logs or panic messages. A derived `Debug` would print them
/// in plain text, including through `SshConfig`'s own `Debug`.
#[derive(Clone)]
pub enum SshAuth {
    /// Password authentication with the given secret.
    Password(String),
    /// Public-key authentication with a private key loaded from `path`.
    PrivateKey {
        /// Location of the private key file.
        path: PathBuf,
        /// Passphrase protecting the key, if any.
        passphrase: Option<String>,
    },
    /// Authentication via an SSH agent (rejected by `connect` for now).
    Agent,
}

impl std::fmt::Debug for SshAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Password(_) => f.debug_tuple("Password").field(&REDACTED).finish(),
            Self::PrivateKey { path, passphrase } => f
                .debug_struct("PrivateKey")
                .field("path", path)
                // Show whether a passphrase is set, but never its value.
                .field("passphrase", &passphrase.as_ref().map(|_| REDACTED))
                .finish(),
            Self::Agent => f.write_str("Agent"),
        }
    }
}
struct ClientHandler;
#[async_trait]
impl russh::client::Handler for ClientHandler {
type Error = russh::Error;
async fn check_server_key(
&mut self,
_server_public_key: &ssh_key::PublicKey,
) -> std::result::Result<bool, Self::Error> {
Ok(true)
}
}
/// PSRP transport that talks to a PowerShell subsystem over one SSH session
/// channel.
///
/// Create it with [`SshPsrpTransport::connect`] and shut it down with
/// `close_shell`. Dropping it without closing only logs a warning.
pub struct SshPsrpTransport {
    /// Session channel running the subsystem; locked for each send/receive.
    channel: Arc<Mutex<russh::Channel<russh::client::Msg>>>,
    /// Underlying SSH connection, kept so it can be disconnected on close.
    handle: Arc<Mutex<russh::client::Handle<ClientHandler>>>,
    /// Set once `close_shell` has been called; makes closing idempotent.
    closed: bool,
}
impl std::fmt::Debug for SshPsrpTransport {
    /// Shows only the `closed` flag; the channel and session handle are
    /// deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SshPsrpTransport");
        out.field("closed", &self.closed);
        out.finish()
    }
}
impl SshPsrpTransport {
    /// Connect to `config.host:config.port`, authenticate as `config.username`,
    /// and start `config.subsystem` on a new session channel.
    ///
    /// Only the `russh::client::connect` step is bounded by
    /// `config.connect_timeout`; authentication and channel setup are not.
    ///
    /// # Errors
    ///
    /// Returns a protocol [`PsrpError`] if:
    /// - `config.auth` is [`SshAuth::Agent`] (not implemented yet; rejected
    ///   before any network I/O),
    /// - the connection times out or fails,
    /// - the private key cannot be loaded or prepared,
    /// - the authentication request errors, or the server rejects the
    ///   credentials (the session is then disconnected, best effort),
    /// - opening the session channel or requesting the subsystem fails.
    pub async fn connect(config: SshConfig) -> Result<Self> {
        const AGENT_UNSUPPORTED: &str =
            "SSH agent auth not yet implemented — use Password or PrivateKey";

        // Reject unsupported auth up front instead of after a full TCP connect
        // and key exchange that would be thrown away.
        if matches!(config.auth, SshAuth::Agent) {
            return Err(PsrpError::protocol(AGENT_UNSUPPORTED));
        }

        let ssh_config = russh::client::Config::default();
        let addr = format!("{}:{}", config.host, config.port);
        debug!(%addr, "SSH: connecting");
        let mut handle = tokio::time::timeout(
            config.connect_timeout,
            russh::client::connect(Arc::new(ssh_config), &addr, ClientHandler),
        )
        .await
        .map_err(|_| PsrpError::protocol(format!("SSH connect timeout to {addr}")))?
        .map_err(|e| PsrpError::protocol(format!("SSH connect: {e}")))?;

        let authenticated = match &config.auth {
            SshAuth::Password(pw) => handle
                .authenticate_password(&config.username, pw)
                .await
                .map_err(|e| PsrpError::protocol(format!("SSH password auth: {e}")))?,
            SshAuth::PrivateKey { path, passphrase } => {
                let private_key = russh_keys::load_secret_key(path, passphrase.as_deref())
                    .map_err(|e| PsrpError::protocol(format!("SSH key load: {e}")))?;
                let key = PrivateKeyWithHashAlg::new(Arc::new(private_key), None)
                    .map_err(|e| PsrpError::protocol(format!("SSH key prep: {e}")))?;
                handle
                    .authenticate_publickey(&config.username, key)
                    .await
                    .map_err(|e| PsrpError::protocol(format!("SSH pubkey auth: {e}")))?
            }
            // Already rejected above. This arm keeps the match exhaustive without
            // resorting to a panic.
            SshAuth::Agent => return Err(PsrpError::protocol(AGENT_UNSUPPORTED)),
        };
        if !authenticated {
            // Tell the server we are leaving rather than silently dropping the
            // session. A failure here is irrelevant next to the auth error.
            let _ = handle
                .disconnect(russh::Disconnect::ByApplication, "psrp-rs auth failed", "en")
                .await;
            return Err(PsrpError::protocol("SSH authentication failed"));
        }
        debug!("SSH: authenticated");

        let channel = handle
            .channel_open_session()
            .await
            .map_err(|e| PsrpError::protocol(format!("SSH channel open: {e}")))?;
        channel
            .request_subsystem(true, &config.subsystem)
            .await
            .map_err(|e| PsrpError::protocol(format!("SSH subsystem request: {e}")))?;
        debug!(subsystem = %config.subsystem, "SSH: subsystem opened");

        Ok(Self {
            channel: Arc::new(Mutex::new(channel)),
            handle: Arc::new(Mutex::new(handle)),
            closed: false,
        })
    }
}
#[async_trait]
impl PsrpTransport for SshPsrpTransport {
    /// Write one PSRP fragment to the subsystem as SSH channel data.
    async fn send_fragment(&self, bytes: &[u8]) -> Result<()> {
        let chan = self.channel.lock().await;
        if let Err(e) = chan.data(bytes).await {
            return Err(PsrpError::protocol(format!("SSH send: {e}")));
        }
        Ok(())
    }

    /// Wait for the next non-empty chunk of channel data.
    ///
    /// Extended data of type 1 is logged at debug level as stderr and skipped.
    /// Empty data packets and all other channel messages are skipped too. An
    /// empty `Vec` means end of stream: EOF was received or `wait` returned
    /// `None`.
    async fn recv_chunk(&mut self) -> Result<Vec<u8>> {
        let mut chan = self.channel.lock().await;
        while let Some(msg) = chan.wait().await {
            match msg {
                ChannelMsg::Data { data } => {
                    let chunk = data.to_vec();
                    if !chunk.is_empty() {
                        return Ok(chunk);
                    }
                }
                ChannelMsg::ExtendedData { data, ext: 1 } => {
                    let text = String::from_utf8_lossy(&data);
                    debug!(stderr = %text, "SSH stderr");
                }
                ChannelMsg::Eof => return Ok(Vec::new()),
                _ => {}
            }
        }
        Ok(Vec::new())
    }

    /// Not supported over SSH: logs a warning and succeeds without doing
    /// anything.
    async fn signal_stop(&self) -> Result<()> {
        warn!("signal_stop on SSH transport is a no-op; close the channel to abort");
        Ok(())
    }

    /// Send EOF and close on the channel (best effort; errors ignored), then
    /// disconnect the SSH session.
    ///
    /// The transport is marked closed before any I/O. Later calls therefore
    /// return `Ok(())` immediately, even if this one's disconnect failed.
    async fn close_shell(&mut self) -> Result<()> {
        let already_closed = std::mem::replace(&mut self.closed, true);
        if already_closed {
            return Ok(());
        }

        let chan = self.channel.lock().await;
        let _ = chan.eof().await;
        let _ = chan.close().await;
        drop(chan);

        self.handle
            .lock()
            .await
            .disconnect(russh::Disconnect::ByApplication, "psrp-rs close", "en")
            .await
            .map_err(|e| PsrpError::protocol(format!("SSH disconnect: {e}")))?;
        debug!("SSH: session closed");
        Ok(())
    }
}
impl Drop for SshPsrpTransport {
    /// `drop` is synchronous and cannot run the async shutdown. It only warns
    /// when `close_shell` was never called.
    fn drop(&mut self) {
        if self.closed {
            return;
        }
        warn!("SshPsrpTransport dropped without close — SSH session may leak");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ssh_config_defaults() {
        let cfg = SshConfig::default();
        assert_eq!(cfg.port, 22);
        assert_eq!(cfg.subsystem, "powershell");
        assert_eq!(cfg.connect_timeout, Duration::from_secs(30));
        assert!(cfg.host.is_empty());
        assert!(cfg.username.is_empty());
        assert!(matches!(cfg.auth, SshAuth::Agent));
    }

    /// Constructs each variant and checks that its payload survives `Clone`.
    /// The previous version built values but asserted nothing, so it could
    /// never fail.
    #[test]
    fn ssh_auth_variants() {
        let pw = SshAuth::Password("secret".into());
        assert!(matches!(pw.clone(), SshAuth::Password(ref s) if s == "secret"));

        let key = SshAuth::PrivateKey {
            path: PathBuf::from("/home/user/.ssh/id_rsa"),
            passphrase: None,
        };
        match key.clone() {
            SshAuth::PrivateKey { path, passphrase } => {
                assert_eq!(path, PathBuf::from("/home/user/.ssh/id_rsa"));
                assert!(passphrase.is_none());
            }
            other => panic!("expected PrivateKey, got {other:?}"),
        }

        let agent = SshAuth::Agent;
        assert!(matches!(agent.clone(), SshAuth::Agent));
    }

    #[test]
    fn debug_format() {
        let s = format!(
            "{:?}",
            SshConfig {
                host: "test".into(),
                ..SshConfig::default()
            }
        );
        assert!(s.contains("test"));
    }
}