use anyhow::Result;
use serde::Deserialize;
use ssh2::Session;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
use std::path::{Path, PathBuf};
use std::time::Duration;
/// Credentials an [`SshClient`] uses when it authenticates a new session.
///
/// In serialized form the variant names are lowercase: `agent`, `password`, `pubkey`.
#[derive(Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SshAuth {
    /// Authenticate through the running SSH agent.
    Agent,
    /// Authenticate with the given password.
    Password(String),
    /// Authenticate with the private key file at the given path (no passphrase).
    Pubkey(PathBuf),
}
/// Captured result of a remote command run through [`SshClient::execute`].
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so callers can log, copy and compare results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SshOutput {
    /// Exit status reported by the remote channel once it closed.
    pub exit_status: i32,
    /// Raw bytes the command wrote to its standard error stream.
    pub stderr: Vec<u8>,
    /// Raw bytes the command wrote to its standard output stream.
    pub stdout: Vec<u8>,
}
/// Blocking SSH client holding at most one authenticated [`Session`].
///
/// The session is reused for command execution and SCP transfers until it is dropped.
pub struct SshClient {
    /// Remote login name.
    user: String,
    /// Remote host address.
    addr: SocketAddr,
    /// Credentials used whenever a new session is established.
    auth: SshAuth,
    /// Timeout in milliseconds for connecting and for blocking session calls; `0` means none.
    timeout: u64,
    /// Currently authenticated session, or `None` while disconnected.
    session: Option<Session>,
}
impl SshClient {
pub fn from(user: impl Into<String>, addr: impl Into<SocketAddr>) -> Self {
Self {
addr: addr.into(),
auth: SshAuth::Agent,
session: None,
timeout: 0,
user: user.into(),
}
}
pub fn try_from(user: impl Into<String>, addr: impl ToSocketAddrs) -> Result<Self> {
if let Some(addr) = addr.to_socket_addrs()?.next() {
Ok(Self {
addr,
auth: SshAuth::Agent,
session: None,
timeout: 0,
user: user.into(),
})
} else {
Err(anyhow::anyhow!("Socket address conversion failed"))
}
}
pub fn set_auth_agent(&mut self) -> &mut Self {
self.auth = SshAuth::Agent;
self
}
pub fn set_auth_password(&mut self, password: impl Into<String>) -> &mut Self {
self.auth = SshAuth::Password(password.into());
self
}
pub fn set_auth_pubkey(&mut self, path: impl Into<PathBuf>) -> &mut Self {
self.auth = SshAuth::Pubkey(path.into());
self
}
pub fn set_timeout(&mut self, timeout_ms: u64) -> &mut Self {
self.timeout = timeout_ms;
self
}
pub fn get_addr(&self) -> SocketAddr {
self.addr
}
pub fn get_auth(&self) -> &SshAuth {
&self.auth
}
pub fn get_timeout(&self) -> u64 {
self.timeout
}
pub fn get_user(&self) -> &str {
&self.user
}
pub fn is_connected(&self) -> bool {
self.session.is_some()
}
pub fn execute(&mut self, command: &str) -> Result<SshOutput> {
if self.session.is_none() {
self.connect()?;
}
let session = self.session.as_ref().unwrap();
let mut channel = session.channel_session()?;
let mut stderr_stream = channel.stderr();
channel.exec(command)?;
let mut stdout = Vec::new();
channel.read_to_end(&mut stdout)?;
let mut stderr = Vec::new();
stderr_stream.read_to_end(&mut stderr)?;
channel.wait_close()?;
let exit_status = channel.exit_status()?;
Ok(SshOutput {
exit_status,
stdout,
stderr,
})
}
pub fn scp_download<P: AsRef<Path>>(&mut self, remote_path: P, local_path: P) -> Result<()> {
if self.session.is_none() {
self.connect()?;
}
let session = self.session.as_ref().unwrap();
let (mut channel, _) = session.scp_recv(remote_path.as_ref())?;
let mut buffer = Vec::new();
channel.read_to_end(&mut buffer)?;
std::fs::write(local_path, &buffer)?;
channel.send_eof()?;
channel.wait_eof()?;
channel.close()?;
channel.wait_close()?;
Ok(())
}
pub fn scp_upload<P: AsRef<Path>>(&mut self, local_path: P, remote_path: P) -> Result<()> {
if self.session.is_none() {
self.connect()?;
}
let session = self.session.as_ref().unwrap();
let buffer = std::fs::read(local_path)?;
let size = buffer.len() as u64;
let mut channel = session.scp_send(remote_path.as_ref(), 0o644, size, None)?;
channel.write_all(&buffer)?;
channel.send_eof()?;
channel.wait_eof()?;
channel.close()?;
channel.wait_close()?;
Ok(())
}
pub fn connect(&mut self) -> Result<&mut Self> {
let mut session = Session::new()?;
let tcp_stream = if self.timeout == 0 {
TcpStream::connect(&self.addr)?
} else {
session.set_timeout(self.timeout as u32);
TcpStream::connect_timeout(&self.addr, Duration::from_millis(self.timeout))?
};
session.set_tcp_stream(tcp_stream);
session.handshake()?;
match &self.auth {
SshAuth::Agent => session.userauth_agent(&self.user)?,
SshAuth::Password(password) => session.userauth_password(&self.user, password)?,
SshAuth::Pubkey(path) => session.userauth_pubkey_file(&self.user, None, path, None)?,
}
if !session.authenticated() {
return Err(anyhow::anyhow!("Authentication failed"));
}
self.session = Some(session);
Ok(self)
}
pub fn disconnect(&mut self) -> &mut Self {
self.session = None;
self
}
}