use crate::{CommandOutput, Error, Result};
use ssh2::Session;
use std::io::Read;
use std::net::TcpStream;
use std::path::Path;
/// Credentials used to authenticate an SSH session.
///
/// `Debug` is implemented by hand rather than derived so that the password and
/// key passphrase are never written to logs or panic messages: both are
/// rendered as `"<redacted>"`. Usernames and the key path are still shown.
#[derive(Clone)]
pub enum AuthMethod {
    /// Authenticate with a username and password.
    Password {
        /// Login name presented to the server.
        username: String,
        /// Password for `username`. Never included in `Debug` output.
        password: String,
    },
    /// Authenticate with a private key stored in a file.
    PublicKey {
        /// Login name presented to the server.
        username: String,
        /// Filesystem path of the private key file.
        private_key_path: String,
        /// Optional passphrase for the private key. Never included in `Debug` output.
        passphrase: Option<String>,
    },
}

impl std::fmt::Debug for AuthMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed instead of any secret value.
        const REDACTED: &str = "<redacted>";
        match self {
            AuthMethod::Password { username, .. } => f
                .debug_struct("Password")
                .field("username", username)
                .field("password", &REDACTED)
                .finish(),
            AuthMethod::PublicKey {
                username,
                private_key_path,
                passphrase,
            } => f
                .debug_struct("PublicKey")
                .field("username", username)
                .field("private_key_path", private_key_path)
                // Keep the Some/None distinction visible without revealing the value.
                .field("passphrase", &passphrase.as_ref().map(|_| REDACTED))
                .finish(),
        }
    }
}
/// SSH client bound to a single remote endpoint.
///
/// Stores the target address, the credentials to authenticate with (once
/// supplied), and the session once one has been established.
pub struct SshClient {
    /// Remote host name or address.
    host: String,
    /// Remote TCP port.
    port: u16,
    /// Credentials to use when connecting; `None` until one is supplied.
    auth: Option<AuthMethod>,
    /// The established session; `None` while the client is not connected.
    session: Option<Session>,
}
/// Diagnostic formatting for [`SshClient`].
///
/// The session itself is not printed; it is summarised by a `connected` flag.
/// `auth` is rendered through `AuthMethod`'s own `Debug` implementation, so
/// whatever that implementation exposes appears here as well.
impl std::fmt::Debug for SshClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let connected = self.is_connected();
        let mut out = f.debug_struct("SshClient");
        out.field("host", &self.host);
        out.field("port", &self.port);
        out.field("auth", &self.auth);
        out.field("connected", &connected);
        out.finish()
    }
}
impl SshClient {
pub fn new(host: impl Into<String>, port: u16) -> Result<Self> {
let host = host.into();
if host.is_empty() {
return Err(Error::InvalidConfig("Host cannot be empty".to_string()));
}
if port == 0 {
return Err(Error::InvalidConfig("Port cannot be 0".to_string()));
}
Ok(Self {
host,
port,
auth: None,
session: None,
})
}
pub fn with_auth(mut self, auth: AuthMethod) -> Self {
self.auth = Some(auth);
self
}
pub fn connect(mut self) -> Result<Self> {
let auth = self
.auth
.as_ref()
.ok_or_else(|| Error::InvalidConfig("No authentication method set".to_string()))?;
let tcp = TcpStream::connect(format!("{}:{}", self.host, self.port)).map_err(|e| {
Error::ConnectionFailed {
host: self.host.clone(),
port: self.port,
source: e,
}
})?;
let mut session = Session::new()?;
session.set_tcp_stream(tcp);
session.handshake()?;
match auth {
AuthMethod::Password { username, password } => {
session.userauth_password(username, password).map_err(|e| {
Error::AuthenticationFailed {
username: username.clone(),
reason: e.to_string(),
}
})?;
}
AuthMethod::PublicKey {
username,
private_key_path,
passphrase,
} => {
let key_path = Path::new(private_key_path);
if !key_path.exists() {
return Err(Error::PrivateKeyNotFound {
path: private_key_path.clone(),
});
}
session
.userauth_pubkey_file(username, None, key_path, passphrase.as_deref())
.map_err(|e| Error::AuthenticationFailed {
username: username.clone(),
reason: e.to_string(),
})?;
}
}
if !session.authenticated() {
let username = match auth {
AuthMethod::Password { username, .. } => username.clone(),
AuthMethod::PublicKey { username, .. } => username.clone(),
};
return Err(Error::AuthenticationFailed {
username,
reason: "Authentication completed but session is not authenticated".to_string(),
});
}
self.session = Some(session);
Ok(self)
}
pub fn execute(&mut self, command: &str) -> Result<CommandOutput> {
let session = self.session.as_ref().ok_or(Error::NotConnected)?;
let mut channel = session
.channel_session()
.map_err(|e| Error::ChannelFailed(e.to_string()))?;
channel
.exec(command)
.map_err(|e| Error::ExecutionFailed(e.to_string()))?;
let mut stdout = String::new();
channel
.read_to_string(&mut stdout)
.map_err(|e| Error::ExecutionFailed(format!("Failed to read stdout: {}", e)))?;
let mut stderr = String::new();
channel
.stderr()
.read_to_string(&mut stderr)
.map_err(|e| Error::ExecutionFailed(format!("Failed to read stderr: {}", e)))?;
channel
.wait_close()
.map_err(|e| Error::ExecutionFailed(format!("Failed to close channel: {}", e)))?;
let exit_status = channel.exit_status()?;
Ok(CommandOutput::new(stdout, stderr, exit_status))
}
pub fn execute_batch(&mut self, commands: &[&str]) -> Result<Vec<CommandOutput>> {
commands.iter().map(|cmd| self.execute(cmd)).collect()
}
pub fn host(&self) -> &str {
&self.host
}
pub fn port(&self) -> u16 {
self.port
}
pub fn is_connected(&self) -> bool {
self.session.is_some()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client for a placeholder endpoint; panics if construction fails.
    fn example_client() -> SshClient {
        SshClient::new("example.com", 22).unwrap()
    }

    /// A freshly built client reports its address and is not connected.
    #[test]
    fn test_new_client() {
        let fresh = example_client();
        assert_eq!(fresh.host(), "example.com");
        assert_eq!(fresh.port(), 22);
        assert!(!fresh.is_connected());
    }

    /// An empty host is rejected as invalid configuration.
    #[test]
    fn test_empty_host() {
        let outcome = SshClient::new("", 22);
        assert!(matches!(outcome, Err(Error::InvalidConfig(_))));
    }

    /// Port zero is rejected as invalid configuration.
    #[test]
    fn test_zero_port() {
        let outcome = SshClient::new("example.com", 0);
        assert!(matches!(outcome, Err(Error::InvalidConfig(_))));
    }

    /// `with_auth` stores the supplied credentials.
    #[test]
    fn test_with_auth() {
        let credentials = AuthMethod::Password {
            username: "user".to_string(),
            password: "pass".to_string(),
        };
        let configured = example_client().with_auth(credentials);
        assert!(configured.auth.is_some());
    }

    /// Connecting without credentials fails before any network activity.
    #[test]
    fn test_connect_without_auth() {
        let outcome = example_client().connect();
        assert!(matches!(outcome, Err(Error::InvalidConfig(_))));
    }
}