use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use base64::Engine;
use super::handler::{SshHandler, SshOutput, SshTarget};
/// Quotes `s` so a POSIX shell treats it as exactly one literal word.
///
/// The text is wrapped in single quotes. Nothing can be escaped inside a
/// single-quoted string, so each embedded `'` is emitted as `'\''`: close
/// the quote, add a backslash-escaped quote, then reopen the quote.
fn shell_escape(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str("'\\''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
/// russh client callback handler passed to `russh::client::connect` for every
/// session that `RusshHandler::connect` opens.
///
/// SECURITY NOTE(review): `check_server_key` trusts every server host key
/// unconditionally. The server's identity is never verified, so connections
/// are exposed to man-in-the-middle interception. Consider validating against
/// a known_hosts file or a pinned key/fingerprint carried on the target.
struct ClientHandler;
impl russh::client::Handler for ClientHandler {
type Error = russh::Error;
/// Invoked by russh to decide whether to trust the server's public key.
/// Always answers `Ok(true)` (accept); see the security note on the type.
async fn check_server_key(
&mut self,
_server_public_key: &russh::keys::PublicKey,
) -> Result<bool, Self::Error> {
Ok(true)
}
}
pub struct RusshHandler {
timeout: Duration,
}
impl RusshHandler {
pub fn new(timeout: Duration) -> Self {
Self { timeout }
}
async fn connect(
&self,
target: &SshTarget,
) -> std::result::Result<russh::client::Handle<ClientHandler>, String> {
let config = russh::client::Config {
inactivity_timeout: Some(self.timeout),
..<_>::default()
};
let addr = (target.host.as_str(), target.port);
let mut session = russh::client::connect(Arc::new(config), addr, ClientHandler)
.await
.map_err(|e| format!("connection failed: {e}"))?;
if let Some(ref key_pem) = target.private_key {
let key_pair = russh::keys::PrivateKey::from_openssh(key_pem.as_bytes())
.map_err(|e| format!("invalid private key: {e}"))?;
let auth = session
.authenticate_publickey(
&target.user,
russh::keys::PrivateKeyWithHashAlg::new(
Arc::new(key_pair),
session
.best_supported_rsa_hash()
.await
.ok()
.flatten()
.flatten(),
),
)
.await
.map_err(|e| format!("publickey auth failed: {e}"))?;
if !auth.success() {
return Err("publickey authentication rejected".to_string());
}
} else if let Some(ref password) = target.password {
let auth = session
.authenticate_password(&target.user, password)
.await
.map_err(|e| format!("password auth failed: {e}"))?;
if !auth.success() {
return Err("password authentication rejected".to_string());
}
} else {
let auth = session
.authenticate_none(&target.user)
.await
.map_err(|e| format!("auth failed: {e}"))?;
if !auth.success() {
return Err("ssh: authentication failed (server requires credentials)".to_string());
}
}
Ok(session)
}
}
#[async_trait]
impl SshHandler for RusshHandler {
async fn exec(
&self,
target: &SshTarget,
command: &str,
) -> std::result::Result<SshOutput, String> {
let session = self.connect(target).await?;
let mut channel = session
.channel_open_session()
.await
.map_err(|e| format!("channel open failed: {e}"))?;
channel
.exec(true, command)
.await
.map_err(|e| format!("exec failed: {e}"))?;
let mut stdout = Vec::new();
let mut stderr = Vec::new();
let mut exit_code: Option<u32> = None;
loop {
let Some(msg) = channel.wait().await else {
break;
};
match msg {
russh::ChannelMsg::Data { ref data } => {
stdout.extend_from_slice(data);
}
russh::ChannelMsg::ExtendedData { ref data, ext } => {
if ext == 1 {
stderr.extend_from_slice(data);
}
}
russh::ChannelMsg::ExitStatus { exit_status } => {
exit_code = Some(exit_status);
}
_ => {}
}
}
let _ = session
.disconnect(russh::Disconnect::ByApplication, "", "")
.await;
Ok(SshOutput {
stdout: String::from_utf8_lossy(&stdout).into_owned(),
stderr: String::from_utf8_lossy(&stderr).into_owned(),
exit_code: exit_code.unwrap_or(0) as i32,
})
}
async fn shell(&self, target: &SshTarget) -> std::result::Result<SshOutput, String> {
let session = self.connect(target).await?;
let mut channel = session
.channel_open_session()
.await
.map_err(|e| format!("channel open failed: {e}"))?;
channel
.request_pty(false, "xterm", 80, 24, 0, 0, &[])
.await
.map_err(|e| format!("pty request failed: {e}"))?;
channel
.request_shell(true)
.await
.map_err(|e| format!("shell request failed: {e}"))?;
let mut stdout = Vec::new();
let mut stderr = Vec::new();
let mut exit_code: Option<u32> = None;
loop {
let Some(msg) = channel.wait().await else {
break;
};
match msg {
russh::ChannelMsg::Data { ref data } => {
stdout.extend_from_slice(data);
}
russh::ChannelMsg::ExtendedData { ref data, ext } => {
if ext == 1 {
stderr.extend_from_slice(data);
}
}
russh::ChannelMsg::ExitStatus { exit_status } => {
exit_code = Some(exit_status);
}
_ => {}
}
}
let _ = session
.disconnect(russh::Disconnect::ByApplication, "", "")
.await;
Ok(SshOutput {
stdout: String::from_utf8_lossy(&stdout).into_owned(),
stderr: String::from_utf8_lossy(&stderr).into_owned(),
exit_code: exit_code.unwrap_or(0) as i32,
})
}
async fn upload(
&self,
target: &SshTarget,
remote_path: &str,
content: &[u8],
mode: u32,
) -> std::result::Result<(), String> {
let b64 = base64::engine::general_purpose::STANDARD.encode(content);
let escaped_path = shell_escape(remote_path);
let cmd = format!(
"echo '{}' | base64 -d > {} && chmod {:o} {}",
b64, escaped_path, mode, escaped_path
);
let result = self.exec(target, &cmd).await?;
if result.exit_code != 0 {
return Err(format!(
"upload failed (exit {}): {}",
result.exit_code, result.stderr
));
}
Ok(())
}
async fn download(
&self,
target: &SshTarget,
remote_path: &str,
) -> std::result::Result<Vec<u8>, String> {
let cmd = format!("base64 < {}", shell_escape(remote_path));
let result = self.exec(target, &cmd).await?;
if result.exit_code != 0 {
return Err(format!(
"download failed (exit {}): {}",
result.exit_code, result.stderr
));
}
let decoded = base64::engine::general_purpose::STANDARD
.decode(result.stdout.trim())
.map_err(|e| format!("base64 decode failed: {e}"))?;
Ok(decoded)
}
}