use russh::client::KeyboardInteractiveAuthResponse;
use russh::{
client::{Config, Handle, Handler, Msg},
Channel,
};
use russh_sftp::{client::SftpSession, protocol::OpenFlags};
use std::net::SocketAddr;
use std::sync::Arc;
use std::{fmt::Debug, path::Path};
use std::{io, path::PathBuf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::ToSocketAddrsWithHostname;
use crate::utils::buffer_pool::global;
/// Initial capacity, in bytes, of the stdout buffer allocated by `Client::execute`.
const SSH_CMD_BUFFER_SIZE: usize = 8 * 1024;
/// Buffer size reserved for SFTP transfers (judging by its name). It is not referenced
/// anywhere in this module, hence the `dead_code` allowance.
#[allow(dead_code)]
const SFTP_BUFFER_SIZE: usize = 64 * 1024;
/// Initial capacity, in bytes, of the stderr buffer allocated by `Client::execute`.
const SSH_RESPONSE_BUFFER_SIZE: usize = 1024;
/// Credentials used to authenticate an SSH session.
///
/// `Debug` is implemented by hand rather than derived so that secrets (passwords,
/// private-key text, passphrases, keyboard-interactive answers) are never written
/// to logs when an `AuthMethod` is formatted with `{:?}`.
#[derive(Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuthMethod {
    /// Password authentication.
    Password(String),
    /// A private key supplied as text (decoded with `russh::keys::decode_secret_key`),
    /// optionally protected by a passphrase.
    PrivateKey {
        key_data: String,
        key_pass: Option<String>,
    },
    /// A private key loaded from a file (with `russh::keys::load_secret_key`),
    /// optionally protected by a passphrase.
    PrivateKeyFile {
        key_file_path: PathBuf,
        key_pass: Option<String>,
    },
    /// A public key read from a file. The SSH agent found via the environment must
    /// hold the matching identity; the agent performs the signing.
    #[cfg(not(target_os = "windows"))]
    PublicKeyFile {
        key_file_path: PathBuf,
    },
    /// Try each identity offered by the SSH agent found via the environment.
    #[cfg(not(target_os = "windows"))]
    Agent,
    /// Keyboard-interactive authentication driven by scripted prompt answers.
    KeyboardInteractive(AuthKeyboardInteractive),
}

impl Debug for AuthMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Printed in place of every secret value.
        const REDACTED: &str = "<redacted>";
        match self {
            Self::Password(_) => f.debug_tuple("Password").field(&REDACTED).finish(),
            Self::PrivateKey { key_pass, .. } => f
                .debug_struct("PrivateKey")
                .field("key_data", &REDACTED)
                // Reveal only whether a passphrase is present, never its value.
                .field("key_pass", &key_pass.as_ref().map(|_| REDACTED))
                .finish(),
            Self::PrivateKeyFile {
                key_file_path,
                key_pass,
            } => f
                .debug_struct("PrivateKeyFile")
                .field("key_file_path", key_file_path)
                .field("key_pass", &key_pass.as_ref().map(|_| REDACTED))
                .finish(),
            #[cfg(not(target_os = "windows"))]
            Self::PublicKeyFile { key_file_path } => f
                .debug_struct("PublicKeyFile")
                .field("key_file_path", key_file_path)
                .finish(),
            #[cfg(not(target_os = "windows"))]
            Self::Agent => f.write_str("Agent"),
            // The scripted responses are usually secrets. Hide the whole payload rather
            // than depend on how `AuthKeyboardInteractive` formats itself.
            Self::KeyboardInteractive(_) => f
                .debug_tuple("KeyboardInteractive")
                .field(&REDACTED)
                .finish(),
        }
    }
}
/// One scripted answer for keyboard-interactive authentication.
///
/// `Debug` is implemented by hand so the `response` (which may be a password or a
/// one-time code) is never printed.
#[derive(Clone, PartialEq, Eq, Hash)]
struct PromptResponse {
    /// `true`: `prompt` must equal the server's prompt exactly.
    /// `false`: `prompt` only has to occur somewhere in it.
    exact: bool,
    /// Text used to recognise the server prompt this entry answers.
    prompt: String,
    /// Answer sent back to the server when the prompt matches.
    response: String,
}

impl Debug for PromptResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PromptResponse")
            .field("exact", &self.exact)
            .field("prompt", &self.prompt)
            .field("response", &"<redacted>")
            .finish()
    }
}
/// Scripted answers for SSH keyboard-interactive authentication.
///
/// Built with `AuthKeyboardInteractive::new()` plus the `with_*` builder methods.
/// During authentication each server prompt is answered with the first entry of
/// `responses` that matches it, and that entry is then removed.
///
/// NOTE(review): the derived `Debug` includes `responses`, whose answers may be
/// secrets. Check how `PromptResponse` formats itself before logging this type.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
#[non_exhaustive]
pub struct AuthKeyboardInteractive {
// Optional submethods string passed to `authenticate_keyboard_interactive_start`.
submethods: Option<String>,
// Prompt/answer pairs, searched in insertion order; each is consumed once used.
responses: Vec<PromptResponse>,
}
/// How the server's host key is verified during the SSH handshake.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ServerCheckMethod {
// Accept any host key without verification.
NoCheck,
// Base64-encoded public key (parsed with `russh::keys::parse_public_key_base64`)
// that the server's key must equal.
PublicKey(String),
// Path to a public key file (loaded with `russh::keys::load_public_key`) that the
// server's key must equal.
PublicKeyFile(String),
// Look the host up with `russh::keys::check_known_hosts`, i.e. russh's default
// known_hosts location (presumably ~/.ssh/known_hosts — confirm).
DefaultKnownHostsFile,
// Look the host up in the known_hosts file at this path.
KnownHostsFile(String),
}
impl AuthMethod {
    /// Builds [`AuthMethod::Password`] from a borrowed password.
    pub fn with_password(password: &str) -> Self {
        Self::Password(String::from(password))
    }

    /// Builds [`AuthMethod::PrivateKey`] from private-key text and an optional passphrase.
    pub fn with_key(key: &str, passphrase: Option<&str>) -> Self {
        let key_pass = passphrase.map(String::from);
        Self::PrivateKey {
            key_data: String::from(key),
            key_pass,
        }
    }

    /// Builds [`AuthMethod::PrivateKeyFile`] from a path to a private key file and an
    /// optional passphrase.
    pub fn with_key_file<T: AsRef<Path>>(key_file_path: T, passphrase: Option<&str>) -> Self {
        let key_file_path = key_file_path.as_ref().to_owned();
        let key_pass = passphrase.map(String::from);
        Self::PrivateKeyFile {
            key_file_path,
            key_pass,
        }
    }

    /// Builds [`AuthMethod::PublicKeyFile`]. The matching private key must be held by
    /// the SSH agent.
    #[cfg(not(target_os = "windows"))]
    pub fn with_public_key_file<T: AsRef<Path>>(key_file_path: T) -> Self {
        let key_file_path = key_file_path.as_ref().to_owned();
        Self::PublicKeyFile { key_file_path }
    }

    /// Builds [`AuthMethod::Agent`]: authenticate with whatever the SSH agent offers.
    #[cfg(not(target_os = "windows"))]
    pub fn with_agent() -> Self {
        Self::Agent
    }

    /// Wraps a keyboard-interactive script in [`AuthMethod::KeyboardInteractive`].
    pub const fn with_keyboard_interactive(auth: AuthKeyboardInteractive) -> Self {
        Self::KeyboardInteractive(auth)
    }
}
impl AuthKeyboardInteractive {
    /// Creates an empty script with no submethods and no responses.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the submethods string sent when keyboard-interactive auth starts.
    pub fn with_submethods(self, submethods: impl Into<String>) -> Self {
        Self {
            submethods: Some(submethods.into()),
            ..self
        }
    }

    /// Adds an answer for any server prompt that *contains* `prompt`.
    pub fn with_response(self, prompt: impl Into<String>, response: impl Into<String>) -> Self {
        self.push_response(false, prompt.into(), response.into())
    }

    /// Adds an answer for a server prompt that is *exactly* `prompt`.
    pub fn with_response_exact(
        self,
        prompt: impl Into<String>,
        response: impl Into<String>,
    ) -> Self {
        self.push_response(true, prompt.into(), response.into())
    }

    /// Appends one prompt/answer pair and hands the builder back.
    fn push_response(mut self, exact: bool, prompt: String, response: String) -> Self {
        self.responses.push(PromptResponse {
            exact,
            prompt,
            response,
        });
        self
    }
}
impl PromptResponse {
    /// Reports whether this entry answers `received_prompt`.
    ///
    /// Exact entries require string equality; all others match when their `prompt`
    /// occurs anywhere inside the received prompt.
    fn matches(&self, received_prompt: &str) -> bool {
        if self.exact {
            return received_prompt == self.prompt;
        }
        received_prompt.contains(self.prompt.as_str())
    }
}
impl From<AuthKeyboardInteractive> for AuthMethod {
fn from(value: AuthKeyboardInteractive) -> Self {
Self::with_keyboard_interactive(value)
}
}
impl ServerCheckMethod {
    /// Accept only the server key equal to this base64-encoded public key.
    pub fn with_public_key(key: &str) -> Self {
        Self::PublicKey(key.to_owned())
    }

    /// Accept only the server key equal to the public key stored in this file.
    pub fn with_public_key_file(key_file_name: &str) -> Self {
        Self::PublicKeyFile(key_file_name.to_owned())
    }

    /// Verify the server key against the known_hosts file at this path.
    pub fn with_known_hosts_file(known_hosts_file: &str) -> Self {
        Self::KnownHostsFile(known_hosts_file.to_owned())
    }
}
/// An SSH connection that has completed authentication (when built via `connect*`).
///
/// Clones share the same underlying russh connection: both handle fields are `Arc`s.
#[derive(Clone)]
pub struct Client {
// Handle used by this type's methods to open channels and to disconnect.
connection_handle: Arc<Handle<ClientHandler>>,
// User name the connection authenticated as.
username: String,
// Socket address the connection was actually established to.
address: SocketAddr,
// Public alias of `connection_handle`: both constructors store the same `Arc` here,
// exposing the raw russh handle to callers. NOTE(review): `private_interfaces` is
// allowed presumably because `ClientHandler` is less visible than `Client` outside
// this module — confirm.
#[allow(private_interfaces)]
pub session: Arc<Handle<ClientHandler>>,
}
impl Client {
    /// Connects to `addr` with russh's default [`Config`] and authenticates as `username`.
    ///
    /// Equivalent to [`Client::connect_with_config`] with `Config::default()`.
    ///
    /// # Errors
    ///
    /// Same as [`Client::connect_with_config`].
    pub async fn connect(
        addr: impl ToSocketAddrsWithHostname,
        username: &str,
        auth: AuthMethod,
        server_check: ServerCheckMethod,
    ) -> Result<Self, super::Error> {
        Self::connect_with_config(addr, username, auth, server_check, Config::default()).await
    }

    /// Connects to `addr` using a caller-supplied russh [`Config`] and authenticates as
    /// `username` with `auth`.
    ///
    /// `addr` is resolved to socket addresses, which are tried in order. The first one
    /// whose handshake succeeds is used; the handshake includes host-key verification
    /// per `server_check`. Authentication is attempted only on that connection.
    ///
    /// # Errors
    ///
    /// * [`super::Error::AddressInvalid`] if resolution fails or yields no addresses.
    /// * The error of the last failed attempt if no address could be connected to.
    /// * Any authentication error for the chosen [`AuthMethod`].
    pub async fn connect_with_config(
        addr: impl ToSocketAddrsWithHostname,
        username: &str,
        auth: AuthMethod,
        server_check: ServerCheckMethod,
        config: Config,
    ) -> Result<Self, super::Error> {
        let config = Arc::new(config);
        let socket_addrs = addr
            .to_socket_addrs()
            .map_err(super::Error::AddressInvalid)?;

        // This error surfaces only if resolution produced no addresses at all. Each
        // failed attempt below replaces it with the real connection error.
        let mut connect_res = Err(super::Error::AddressInvalid(io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any addresses",
        )));
        for socket_addr in socket_addrs {
            let handler = ClientHandler {
                hostname: addr.hostname(),
                host: socket_addr,
                server_check: server_check.clone(),
            };
            match russh::client::connect(Arc::clone(&config), socket_addr, handler).await {
                Ok(handle) => {
                    connect_res = Ok((socket_addr, handle));
                    break;
                }
                Err(e) => connect_res = Err(e),
            }
        }

        let (address, mut handle) = connect_res?;
        let username = username.to_string();
        Self::authenticate(&mut handle, &username, auth).await?;
        Ok(Self::from_handle_and_address(
            Arc::new(handle),
            username,
            address,
        ))
    }

    /// Wraps an existing russh handle in a [`Client`].
    ///
    /// No authentication or connectivity check is performed here. The caller is
    /// responsible for the state of `handle`.
    pub fn from_handle_and_address(
        handle: Arc<Handle<ClientHandler>>,
        username: String,
        address: SocketAddr,
    ) -> Self {
        Self {
            connection_handle: Arc::clone(&handle),
            username,
            address,
            session: handle,
        }
    }

    /// Runs the authentication exchange for `auth` on an already-connected `handle`.
    ///
    /// # Errors
    ///
    /// Returns a method-specific error when the server rejects the credentials, when
    /// key material cannot be decoded/loaded, or when the SSH agent cannot be reached.
    async fn authenticate(
        handle: &mut Handle<ClientHandler>,
        username: &str,
        auth: AuthMethod,
    ) -> Result<(), super::Error> {
        match auth {
            AuthMethod::Password(password) => {
                let auth_result = handle.authenticate_password(username, password).await?;
                if !auth_result.success() {
                    return Err(super::Error::PasswordWrong);
                }
            }
            AuthMethod::PrivateKey { key_data, key_pass } => {
                let private_key =
                    russh::keys::decode_secret_key(key_data.as_str(), key_pass.as_deref())
                        .map_err(super::Error::KeyInvalid)?;
                let hash_alg = handle.best_supported_rsa_hash().await?.flatten();
                let auth_result = handle
                    .authenticate_publickey(
                        username,
                        russh::keys::PrivateKeyWithHashAlg::new(Arc::new(private_key), hash_alg),
                    )
                    .await?;
                if !auth_result.success() {
                    return Err(super::Error::KeyAuthFailed);
                }
            }
            AuthMethod::PrivateKeyFile {
                key_file_path,
                key_pass,
            } => {
                let private_key = russh::keys::load_secret_key(key_file_path, key_pass.as_deref())
                    .map_err(super::Error::KeyInvalid)?;
                let hash_alg = handle.best_supported_rsa_hash().await?.flatten();
                let auth_result = handle
                    .authenticate_publickey(
                        username,
                        russh::keys::PrivateKeyWithHashAlg::new(Arc::new(private_key), hash_alg),
                    )
                    .await?;
                if !auth_result.success() {
                    return Err(super::Error::KeyAuthFailed);
                }
            }
            #[cfg(not(target_os = "windows"))]
            AuthMethod::PublicKeyFile { key_file_path } => {
                let public_key = russh::keys::load_public_key(key_file_path)
                    .map_err(super::Error::KeyInvalid)?;
                // The private half lives in the agent. An unreachable agent (e.g.
                // SSH_AUTH_SOCK unset) is an ordinary runtime condition, so report it
                // like the `Agent` method does instead of panicking.
                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
                    .await
                    .map_err(|_| super::Error::AgentConnectionFailed)?;
                let agent_holds_key = agent
                    .request_identities()
                    .await
                    .map_err(super::Error::KeyInvalid)?
                    .iter()
                    .any(|identity| *identity == public_key);
                if !agent_holds_key {
                    return Err(super::Error::KeyAuthFailed);
                }
                let hash_alg = handle.best_supported_rsa_hash().await?.flatten();
                let auth_result = handle
                    .authenticate_publickey_with(username, public_key, hash_alg, &mut agent)
                    .await?;
                if !auth_result.success() {
                    return Err(super::Error::KeyAuthFailed);
                }
            }
            #[cfg(not(target_os = "windows"))]
            AuthMethod::Agent => {
                let mut agent = russh::keys::agent::client::AgentClient::connect_env()
                    .await
                    .map_err(|_| super::Error::AgentConnectionFailed)?;
                let identities = agent
                    .request_identities()
                    .await
                    .map_err(|_| super::Error::AgentRequestIdentitiesFailed)?;
                if identities.is_empty() {
                    return Err(super::Error::AgentNoIdentities);
                }
                // The preferred RSA hash depends on the connection, not on the
                // identity, so query it once rather than once per attempt.
                let hash_alg = handle.best_supported_rsa_hash().await?.flatten();
                let mut authenticated = false;
                for identity in identities {
                    // A failed or rejected attempt is not fatal; move on to the next
                    // identity the agent offers.
                    let attempt = handle
                        .authenticate_publickey_with(username, identity, hash_alg, &mut agent)
                        .await;
                    if attempt.is_ok_and(|auth_result| auth_result.success()) {
                        authenticated = true;
                        break;
                    }
                }
                if !authenticated {
                    return Err(super::Error::AgentAuthenticationFailed);
                }
            }
            AuthMethod::KeyboardInteractive(mut kbd) => {
                let mut reply = handle
                    .authenticate_keyboard_interactive_start(username, kbd.submethods)
                    .await?;
                loop {
                    let prompts = match reply {
                        KeyboardInteractiveAuthResponse::Success => break,
                        KeyboardInteractiveAuthResponse::Failure { .. } => {
                            return Err(super::Error::KeyboardInteractiveAuthFailed);
                        }
                        KeyboardInteractiveAuthResponse::InfoRequest { prompts, .. } => prompts,
                    };
                    let mut answers = Vec::with_capacity(prompts.len());
                    for prompt in prompts {
                        let Some(pos) = kbd
                            .responses
                            .iter()
                            .position(|pr| pr.matches(&prompt.prompt))
                        else {
                            return Err(super::Error::KeyboardInteractiveNoResponseForPrompt(
                                prompt.prompt,
                            ));
                        };
                        // A scripted answer is consumed once used, so the same prompt in
                        // a later round is answered by the next matching entry.
                        answers.push(kbd.responses.remove(pos).response);
                    }
                    reply = handle
                        .authenticate_keyboard_interactive_respond(answers)
                        .await?;
                }
            }
        };
        Ok(())
    }

    /// Opens a new session channel on this connection.
    ///
    /// # Errors
    ///
    /// [`super::Error::SshError`] if the server refuses or the connection fails.
    pub async fn get_channel(&self) -> Result<Channel<Msg>, super::Error> {
        self.connection_handle
            .channel_open_session()
            .await
            .map_err(super::Error::SshError)
    }

    /// Opens a `direct-tcpip` channel (SSH local port forwarding) to `target`.
    ///
    /// Each address `target` resolves to is tried in order, and the first channel the
    /// server accepts is returned. `src` is the originator address reported to the
    /// server; it defaults to `127.0.0.1:22` when `None`.
    ///
    /// # Errors
    ///
    /// * [`super::Error::AddressInvalid`] if `target` cannot be resolved or resolves to
    ///   nothing.
    /// * [`super::Error::SshError`] from the last attempt if every address was refused.
    pub async fn open_direct_tcpip_channel<
        T: ToSocketAddrsWithHostname,
        S: Into<Option<SocketAddr>>,
    >(
        &self,
        target: T,
        src: S,
    ) -> Result<Channel<Msg>, super::Error> {
        let targets = target
            .to_socket_addrs()
            .map_err(super::Error::AddressInvalid)?;
        let (src_host, src_port) = match src.into() {
            Some(src) => (src.ip().to_string(), u32::from(src.port())),
            None => ("127.0.0.1".to_string(), 22),
        };

        let mut connect_err = super::Error::AddressInvalid(io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any addresses",
        ));
        for target in targets {
            match self
                .connection_handle
                .channel_open_direct_tcpip(
                    target.ip().to_string(),
                    u32::from(target.port()),
                    src_host.clone(),
                    src_port,
                )
                .await
            {
                Ok(channel) => return Ok(channel),
                Err(err) => connect_err = super::Error::SshError(err),
            }
        }
        Err(connect_err)
    }

    /// Opens a fresh session channel and starts an SFTP session over it.
    async fn open_sftp(&self) -> Result<SftpSession, super::Error> {
        let channel = self.get_channel().await?;
        channel.request_subsystem(true, "sftp").await?;
        let sftp = SftpSession::new(channel.into_stream()).await?;
        Ok(sftp)
    }

    /// Uploads the local file `src_file_path` to `dest_file_path` on the server over SFTP.
    ///
    /// The remote file is created, or truncated if it exists. The whole local file is
    /// read into memory before it is sent.
    ///
    /// # Errors
    ///
    /// SSH/SFTP errors from opening the session or remote file, and
    /// [`super::Error::IoError`] for local read or remote write failures.
    pub async fn upload_file<T: AsRef<Path>, U: Into<String>>(
        &self,
        src_file_path: T,
        dest_file_path: U,
    ) -> Result<(), super::Error> {
        let sftp = self.open_sftp().await?;
        let file_contents = tokio::fs::read(src_file_path)
            .await
            .map_err(super::Error::IoError)?;
        let mut file = sftp
            .open_with_flags(
                dest_file_path,
                OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE | OpenFlags::READ,
            )
            .await?;
        file.write_all(&file_contents)
            .await
            .map_err(super::Error::IoError)?;
        file.flush().await.map_err(super::Error::IoError)?;
        file.shutdown().await.map_err(super::Error::IoError)?;
        Ok(())
    }

    /// Downloads `remote_file_path` over SFTP into `local_file_path`.
    ///
    /// The local file is created, or truncated if it exists. The whole remote file is
    /// buffered in memory before it is written.
    ///
    /// # Errors
    ///
    /// SSH/SFTP errors from opening the session or remote file, and I/O errors from
    /// reading it or writing the local file.
    pub async fn download_file<T: AsRef<Path>, U: Into<String>>(
        &self,
        remote_file_path: U,
        local_file_path: T,
    ) -> Result<(), super::Error> {
        let sftp = self.open_sftp().await?;
        let mut remote_file = sftp
            .open_with_flags(remote_file_path, OpenFlags::READ)
            .await?;
        let mut pooled_buffer = global::get_large_buffer();
        remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?;
        let mut local_file = tokio::fs::File::create(local_file_path.as_ref())
            .await
            .map_err(super::Error::IoError)?;
        // Write straight from the pooled buffer. Cloning it first would double the
        // peak memory used for large files.
        local_file
            .write_all(pooled_buffer.as_vec())
            .await
            .map_err(super::Error::IoError)?;
        local_file.flush().await.map_err(super::Error::IoError)?;
        Ok(())
    }

    /// Recursively uploads the contents of `local_dir_path` into `remote_dir_path`.
    ///
    /// Remote directories are created best-effort: a creation error (such as "already
    /// exists") is ignored, and any real problem surfaces when files are written into
    /// the directory. Entries that are neither directories nor regular files are skipped.
    ///
    /// # Errors
    ///
    /// * [`super::Error::IoError`] with `NotFound` if `local_dir_path` is not a directory.
    /// * SSH/SFTP and I/O errors from the transfer.
    pub async fn upload_dir<T: AsRef<Path>, U: Into<String>>(
        &self,
        local_dir_path: T,
        remote_dir_path: U,
    ) -> Result<(), super::Error> {
        let local_dir = local_dir_path.as_ref();
        let remote_dir = remote_dir_path.into();
        // Non-blocking equivalent of `Path::is_dir` (follows symlinks, and any error
        // counts as "not a directory").
        let is_dir = tokio::fs::metadata(local_dir)
            .await
            .is_ok_and(|meta| meta.is_dir());
        if !is_dir {
            return Err(super::Error::IoError(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!("Local directory does not exist: {local_dir:?}"),
            )));
        }
        let sftp = self.open_sftp().await?;
        // Best-effort: the remote root may already exist.
        let _ = sftp.create_dir(&remote_dir).await;
        self.upload_dir_recursive(&sftp, local_dir, &remote_dir)
            .await?;
        Ok(())
    }

    /// Uploads one directory level, then recurses into subdirectories.
    ///
    /// Returns a boxed future because an `async fn` cannot call itself recursively.
    // `self` is only threaded through the recursion; it stays a method for symmetry
    // with the public API.
    #[allow(clippy::only_used_in_recursion)]
    fn upload_dir_recursive<'a>(
        &'a self,
        sftp: &'a SftpSession,
        local_dir: &'a Path,
        remote_dir: &'a str,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), super::Error>> + Send + 'a>>
    {
        Box::pin(async move {
            let mut entries = tokio::fs::read_dir(local_dir)
                .await
                .map_err(super::Error::IoError)?;
            while let Some(entry) = entries.next_entry().await.map_err(super::Error::IoError)? {
                let local_path = entry.path();
                let remote_path =
                    format!("{remote_dir}/{}", entry.file_name().to_string_lossy());
                let metadata = entry.metadata().await.map_err(super::Error::IoError)?;
                if metadata.is_dir() {
                    // Best-effort, like the root: the directory may already exist.
                    let _ = sftp.create_dir(&remote_path).await;
                    self.upload_dir_recursive(sftp, &local_path, &remote_path)
                        .await?;
                } else if metadata.is_file() {
                    let file_contents = tokio::fs::read(&local_path)
                        .await
                        .map_err(super::Error::IoError)?;
                    let mut remote_file = sftp
                        .open_with_flags(
                            &remote_path,
                            OpenFlags::CREATE | OpenFlags::TRUNCATE | OpenFlags::WRITE,
                        )
                        .await?;
                    remote_file
                        .write_all(&file_contents)
                        .await
                        .map_err(super::Error::IoError)?;
                    remote_file.flush().await.map_err(super::Error::IoError)?;
                    remote_file
                        .shutdown()
                        .await
                        .map_err(super::Error::IoError)?;
                }
            }
            Ok(())
        })
    }

    /// Recursively downloads `remote_dir_path` into `local_dir_path`.
    ///
    /// The local directory tree is created as needed. Remote entries that are neither
    /// directories nor regular files are skipped.
    ///
    /// # Errors
    ///
    /// * SSH/SFTP and I/O errors from the transfer.
    /// * [`super::Error::IoError`] with `InvalidData` if the server reports an entry
    ///   name that is not a single plain path component. Such a name could otherwise
    ///   be used to write outside `local_dir_path`.
    pub async fn download_dir<T: AsRef<Path>, U: Into<String>>(
        &self,
        remote_dir_path: U,
        local_dir_path: T,
    ) -> Result<(), super::Error> {
        let local_dir = local_dir_path.as_ref();
        let remote_dir = remote_dir_path.into();
        let sftp = self.open_sftp().await?;
        tokio::fs::create_dir_all(local_dir)
            .await
            .map_err(super::Error::IoError)?;
        self.download_dir_recursive(&sftp, &remote_dir, local_dir)
            .await?;
        Ok(())
    }

    /// Downloads one remote directory level, then recurses into subdirectories.
    ///
    /// Returns a boxed future because an `async fn` cannot call itself recursively.
    // `self` is only threaded through the recursion; it stays a method for symmetry
    // with the public API.
    #[allow(clippy::only_used_in_recursion)]
    fn download_dir_recursive<'a>(
        &'a self,
        sftp: &'a SftpSession,
        remote_dir: &'a str,
        local_dir: &'a Path,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), super::Error>> + Send + 'a>>
    {
        Box::pin(async move {
            let entries = sftp.read_dir(remote_dir).await?;
            for entry in entries {
                let name = entry.file_name();
                let metadata = entry.metadata();
                if name == "." || name == ".." {
                    continue;
                }
                // `name` comes from the server. Accept only a single ordinary path
                // component. Absolute paths, `..`, drive prefixes or embedded separators
                // would make `local_dir.join(name)` point outside `local_dir`.
                let is_plain_name = {
                    let mut parts = Path::new(&name).components();
                    matches!(
                        (parts.next(), parts.next()),
                        (Some(std::path::Component::Normal(_)), None)
                    )
                };
                if !is_plain_name {
                    return Err(super::Error::IoError(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("Refusing unsafe remote file name: {name:?}"),
                    )));
                }
                let remote_path = format!("{remote_dir}/{name}");
                let local_path = local_dir.join(&name);
                if metadata.file_type().is_dir() {
                    tokio::fs::create_dir_all(&local_path)
                        .await
                        .map_err(super::Error::IoError)?;
                    self.download_dir_recursive(sftp, &remote_path, &local_path)
                        .await?;
                } else if metadata.file_type().is_file() {
                    let mut remote_file =
                        sftp.open_with_flags(&remote_path, OpenFlags::READ).await?;
                    let mut pooled_buffer = global::get_large_buffer();
                    remote_file.read_to_end(pooled_buffer.as_mut_vec()).await?;
                    // Pass a borrow; cloning the buffer first would be a wasted copy.
                    tokio::fs::write(&local_path, pooled_buffer.as_vec())
                        .await
                        .map_err(super::Error::IoError)?;
                }
            }
            Ok(())
        })
    }

    /// Runs `command` on the server and collects its output and exit status.
    ///
    /// The command first passes through `crate::utils::sanitize_command`. Output is
    /// decoded as UTF-8 lossily.
    ///
    /// # Errors
    ///
    /// * [`super::Error::CommandValidationFailed`] if sanitisation rejects `command`.
    /// * SSH errors from opening the channel or starting the command.
    /// * [`super::Error::CommandDidntExit`] if the channel closed without an exit status.
    pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, super::Error> {
        let sanitized_command = crate::utils::sanitize_command(command)
            .map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?;
        let mut stdout_buffer: Vec<u8> = Vec::with_capacity(SSH_CMD_BUFFER_SIZE);
        let mut stderr_buffer: Vec<u8> = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE);
        let mut channel = self.connection_handle.channel_open_session().await?;
        channel.exec(true, sanitized_command.as_str()).await?;

        let mut exit_status: Option<u32> = None;
        while let Some(msg) = channel.wait().await {
            match msg {
                russh::ChannelMsg::Data { ref data } => stdout_buffer.extend_from_slice(data),
                // Extended-data type 1 is SSH_EXTENDED_DATA_STDERR (RFC 4254 §5.2).
                // Any other extended stream is ignored.
                russh::ChannelMsg::ExtendedData { ref data, ext: 1 } => {
                    stderr_buffer.extend_from_slice(data)
                }
                russh::ChannelMsg::ExitStatus {
                    exit_status: status,
                } => exit_status = Some(status),
                _ => {}
            }
        }

        let exit_status = exit_status.ok_or(super::Error::CommandDidntExit)?;
        Ok(CommandExecutedResult {
            stdout: String::from_utf8_lossy(&stdout_buffer).into_owned(),
            stderr: String::from_utf8_lossy(&stderr_buffer).into_owned(),
            exit_status,
        })
    }

    /// Opens a session channel, requests a PTY of `width` x `height` character cells
    /// with terminal type `term_type`, and starts an interactive shell on it.
    ///
    /// Neither request asks the server for a reply.
    pub async fn request_interactive_shell(
        &self,
        term_type: &str,
        width: u32,
        height: u32,
    ) -> Result<Channel<Msg>, super::Error> {
        let channel = self.connection_handle.channel_open_session().await?;
        channel
            .request_pty(
                false,
                term_type,
                width,
                height,
                0,   // pixel width: unspecified
                0,   // pixel height: unspecified
                &[], // no explicit terminal modes
            )
            .await?;
        channel.request_shell(false).await?;
        Ok(channel)
    }

    /// Tells the server that the PTY on `channel` is now `width` x `height` character
    /// cells (pixel dimensions left unspecified).
    pub async fn resize_pty(
        &self,
        channel: &mut Channel<Msg>,
        width: u32,
        height: u32,
    ) -> Result<(), super::Error> {
        channel
            .window_change(width, height, 0, 0)
            .await
            .map_err(super::Error::SshError)
    }

    /// User name this connection authenticated as.
    pub fn get_connection_username(&self) -> &String {
        &self.username
    }

    /// Socket address this connection was established to.
    pub fn get_connection_address(&self) -> &SocketAddr {
        &self.address
    }

    /// Sends an SSH disconnect (`ByApplication`) with empty description and language tag.
    pub async fn disconnect(&self) -> Result<(), super::Error> {
        self.connection_handle
            .disconnect(russh::Disconnect::ByApplication, "", "")
            .await
            .map_err(super::Error::SshError)
    }

    /// Reports whether the underlying russh handle considers the connection closed.
    pub fn is_closed(&self) -> bool {
        self.connection_handle.is_closed()
    }
}
impl Debug for Client {
    /// Shows who the connection is for and where it points. The opaque russh handle
    /// is rendered only as its type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Client");
        out.field("username", &self.username);
        out.field("address", &self.address);
        out.field("connection_handle", &"Handle<ClientHandler>");
        out.finish()
    }
}
/// Outcome of `Client::execute`: the command's captured output and its exit status.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandExecutedResult {
/// Data received on the channel's main stream, decoded as UTF-8 lossily.
pub stdout: String,
/// Data received as SSH extended data of type 1 (stderr), decoded as UTF-8 lossily.
pub stderr: String,
/// Exit status reported by the server for the command.
pub exit_status: u32,
}
/// russh client-side event handler; used to verify the server's host key.
#[derive(Debug, Clone)]
pub struct ClientHandler {
// Host name used for known_hosts lookups.
hostname: String,
// Address being connected to; its port is used for known_hosts lookups.
host: SocketAddr,
// Policy deciding whether the server's host key is accepted.
server_check: ServerCheckMethod,
}
impl ClientHandler {
pub fn new(hostname: String, host: SocketAddr, server_check: ServerCheckMethod) -> Self {
Self {
hostname,
host,
server_check,
}
}
}
impl Handler for ClientHandler {
type Error = super::Error;
async fn check_server_key(
&mut self,
server_public_key: &russh::keys::PublicKey,
) -> Result<bool, Self::Error> {
match &self.server_check {
ServerCheckMethod::NoCheck => Ok(true),
ServerCheckMethod::PublicKey(key) => {
let pk = russh::keys::parse_public_key_base64(key)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(pk == *server_public_key)
}
ServerCheckMethod::PublicKeyFile(key_file_name) => {
let pk = russh::keys::load_public_key(key_file_name)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(pk == *server_public_key)
}
ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
let result = russh::keys::check_known_hosts_path(
&self.hostname,
self.host.port(),
server_public_key,
known_hosts_path,
)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(result)
}
ServerCheckMethod::DefaultKnownHostsFile => {
let result = russh::keys::check_known_hosts(
&self.hostname,
self.host.port(),
server_public_key,
)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(result)
}
}
}
}