use async_trait::async_trait;
use russh::client::{Config, Handle, Handler};
use std::io::{self, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
/// Credentials used to authenticate the user once the SSH transport is up
/// (passed to `Client::connect` / `Client::connect_with_config`).
///
/// NOTE(review): `Debug` is derived, so `{:?}` output (e.g. in logs) contains the
/// password, private key text and passphrases in clear text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum AuthMethod {
/// Password authentication.
Password(String),
/// Public-key authentication with the private key supplied inline.
PrivateKey {
/// Private key text, decoded with `russh_keys::decode_secret_key`.
key_data: String,
/// Passphrase used to decrypt `key_data`, if the key is encrypted.
key_pass: Option<String>,
},
/// Public-key authentication with the private key read from a file.
PrivateKeyFile {
/// Path of the private key file, loaded with `russh_keys::load_secret_key`.
key_file_name: String,
/// Passphrase used to decrypt the key file, if the key is encrypted.
key_pass: Option<String>,
},
}
/// How the server's host key is verified while connecting.
///
/// The known_hosts variants look the server up by the IP address and port of the
/// resolved socket address, not by any host name given to `Client::connect`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ServerCheckMethod {
/// Accept any host key; the server's identity is not verified at all.
NoCheck,
/// Accept only this key, given as base64 (parsed with `russh_keys::parse_public_key_base64`).
PublicKey(String),
/// Accept only the key stored in this public key file (loaded with `russh_keys::load_public_key`).
PublicKeyFile(String),
/// Check against the known_hosts file that `russh_keys::check_known_hosts` uses by default.
DefaultKnownHostsFile,
/// Check against the known_hosts file at this path.
KnownHostsFile(String),
}
impl AuthMethod {
    /// Creates [`AuthMethod::Password`] from a borrowed password.
    pub fn with_password(password: &str) -> Self {
        let owned = String::from(password);
        Self::Password(owned)
    }

    /// Creates [`AuthMethod::PrivateKey`] from inline private key text and an optional
    /// passphrase for decrypting it.
    pub fn with_key(key: &str, passphrase: Option<&str>) -> Self {
        let key_pass = passphrase.map(String::from);
        let key_data = key.to_owned();
        Self::PrivateKey { key_data, key_pass }
    }

    /// Creates [`AuthMethod::PrivateKeyFile`] from the path of a private key file and an
    /// optional passphrase for decrypting it.
    pub fn with_key_file(key_file_name: &str, passphrase: Option<&str>) -> Self {
        let key_pass = passphrase.map(String::from);
        let key_file_name = key_file_name.to_owned();
        Self::PrivateKeyFile {
            key_file_name,
            key_pass,
        }
    }
}
impl ServerCheckMethod {
    /// Creates [`ServerCheckMethod::PublicKey`] from a base64-encoded public key.
    pub fn with_public_key(key: &str) -> Self {
        let encoded = key.to_owned();
        Self::PublicKey(encoded)
    }

    /// Creates [`ServerCheckMethod::PublicKeyFile`] from the path of a public key file.
    pub fn with_public_key_file(key_file_name: &str) -> Self {
        let path = String::from(key_file_name);
        Self::PublicKeyFile(path)
    }

    /// Creates [`ServerCheckMethod::KnownHostsFile`] from the path of a known_hosts file.
    pub fn with_known_hosts_file(known_hosts_file: &str) -> Self {
        Self::KnownHostsFile(known_hosts_file.into())
    }
}
/// An SSH connection that has completed the handshake and user authentication;
/// commands are run on it with `Client::execute`.
pub struct Client {
/// russh handle used to open channels on the connection.
connection_handle: Handle<ClientHandler>,
/// User name the connection was authenticated as.
username: String,
/// The resolved address the connection was actually established to.
address: SocketAddr,
}
impl Client {
pub async fn connect(
addr: impl ToSocketAddrs,
username: &str,
auth: AuthMethod,
server_check: ServerCheckMethod,
) -> Result<Self, crate::Error> {
Self::connect_with_config(addr, username, auth, server_check, Config::default()).await
}
pub async fn connect_with_config(
addr: impl ToSocketAddrs,
username: &str,
auth: AuthMethod,
server_check: ServerCheckMethod,
config: Config,
) -> Result<Self, crate::Error> {
let config = Arc::new(config);
let addrs = match addr.to_socket_addrs() {
Ok(addrs) => addrs,
Err(e) => return Err(crate::Error::AddressInvalid(e)),
};
let mut connect_res = Err(crate::Error::AddressInvalid(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
)));
for addr in addrs {
let handler = ClientHandler {
host: addr,
server_check: server_check.clone(),
};
match russh::client::connect(config.clone(), addr, handler).await {
Ok(h) => {
connect_res = Ok((addr, h));
break;
}
Err(e) => connect_res = Err(e),
}
}
let (address, mut handle) = connect_res?;
let username = username.to_string();
Self::authenticate(&mut handle, &username, auth).await?;
Ok(Self {
connection_handle: handle,
username,
address,
})
}
async fn authenticate(
handle: &mut Handle<ClientHandler>,
username: &String,
auth: AuthMethod,
) -> Result<(), crate::Error> {
match auth {
AuthMethod::Password(password) => {
let is_authentificated = handle.authenticate_password(username, password).await?;
if is_authentificated {
Ok(())
} else {
Err(crate::Error::PasswordWrong)
}
}
AuthMethod::PrivateKey { key_data, key_pass } => {
let cprivk =
match russh_keys::decode_secret_key(key_data.as_str(), key_pass.as_deref()) {
Ok(kp) => kp,
Err(e) => return Err(crate::Error::KeyInvalid(e)),
};
let is_authentificated = handle
.authenticate_publickey(username, Arc::new(cprivk))
.await?;
if is_authentificated {
Ok(())
} else {
Err(crate::Error::KeyAuthFailed)
}
}
AuthMethod::PrivateKeyFile {
key_file_name,
key_pass,
} => {
let cprivk = match russh_keys::load_secret_key(key_file_name, key_pass.as_deref()) {
Ok(kp) => kp,
Err(e) => return Err(crate::Error::KeyInvalid(e)),
};
let is_authentificated = handle
.authenticate_publickey(username, Arc::new(cprivk))
.await?;
if is_authentificated {
Ok(())
} else {
Err(crate::Error::KeyAuthFailed)
}
}
}
}
pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, crate::Error> {
let mut receive_buffer = vec![];
let mut channel = self.connection_handle.channel_open_session().await?;
channel.exec(true, command).await?;
let mut result: Option<u32> = None;
while let Some(msg) = channel.wait().await {
match msg {
russh::ChannelMsg::Data { ref data } => receive_buffer.write_all(data).unwrap(),
russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
_ => {}
}
}
if result.is_some() {
Ok(CommandExecutedResult {
output: String::from_utf8_lossy(&receive_buffer).to_string(),
exit_status: result.unwrap(),
})
} else {
Err(crate::Error::CommandDidntExit)
}
}
pub fn get_connection_username(&self) -> &String {
&self.username
}
pub fn get_connection_address(&self) -> &SocketAddr {
&self.address
}
pub async fn disconnect(&self) -> Result<(), russh::Error> {
match self
.connection_handle
.disconnect(russh::Disconnect::ByApplication, "", "")
.await
{
Ok(()) => Ok(()),
Err(e) => Err(e),
}
}
}
/// Outcome of a command run with `Client::execute`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandExecutedResult {
/// What the command wrote to its standard output (stderr is not captured),
/// decoded as UTF-8 with invalid byte sequences replaced.
pub output: String,
/// Exit status reported by the server for the command.
pub exit_status: u32,
}
/// russh client handler carrying what `check_server_key` needs to verify the host key.
/// `Client::connect_with_config` creates one per address it attempts.
#[derive(Clone)]
struct ClientHandler {
/// Address being connected to; its IP and port are used for known_hosts lookups.
host: SocketAddr,
/// Host key verification policy for this connection.
server_check: ServerCheckMethod,
}
#[async_trait]
impl Handler for ClientHandler {
type Error = crate::Error;
async fn check_server_key(
self,
server_public_key: &russh_keys::key::PublicKey,
) -> Result<(Self, bool), Self::Error> {
match &self.server_check {
ServerCheckMethod::NoCheck => Ok((self, true)),
ServerCheckMethod::PublicKey(key) => {
let pk = russh_keys::parse_public_key_base64(key)
.map_err(|_| crate::Error::ServerCheckFailed)?;
Ok((self, pk == *server_public_key))
}
ServerCheckMethod::PublicKeyFile(key_file_name) => {
let pk = russh_keys::load_public_key(key_file_name)
.map_err(|_| crate::Error::ServerCheckFailed)?;
Ok((self, pk == *server_public_key))
}
ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
let result = russh_keys::check_known_hosts_path(
&self.host.ip().to_string(),
self.host.port(),
server_public_key,
known_hosts_path,
)
.map_err(|_| crate::Error::ServerCheckFailed)?;
Ok((self, result))
}
ServerCheckMethod::DefaultKnownHostsFile => {
let result = russh_keys::check_known_hosts(
&self.host.ip().to_string(),
self.host.port(),
server_public_key,
)
.map_err(|_| crate::Error::ServerCheckFailed)?;
Ok((self, result))
}
}
}
}
#[cfg(test)]
mod tests {
    // Integration tests: they talk to a real SSH server whose address and credentials are
    // taken from the ASYNC_SSH2_TEST_* environment variables listed in `env` below.
    use core::time;
    use crate::client::*;

    /// Reads a required test environment variable, panicking with the list of all variables
    /// the suite needs if it is missing.
    fn env(name: &str) -> String {
        std::env::var(name).expect(
            "Failed to get env var needed for test, make sure to set the following env vars:
ASYNC_SSH2_TEST_HOST_USER
ASYNC_SSH2_TEST_HOST_PW
ASYNC_SSH2_TEST_HOST_IP
ASYNC_SSH2_TEST_HOST_PORT
ASYNC_SSH2_TEST_CLIENT_PROT_PRIV
ASYNC_SSH2_TEST_CLIENT_PRIV
ASYNC_SSH2_TEST_CLIENT_PROT_PASS
ASYNC_SSH2_TEST_SERVER_PUB
",
        )
    }

    /// Socket address of the test host, from `ASYNC_SSH2_TEST_HOST_IP` and
    /// `ASYNC_SSH2_TEST_HOST_PORT`.
    fn test_address() -> SocketAddr {
        // Build it from parsed parts instead of parsing a formatted "ip:port" string, which
        // fails for IPv6 literals (those would need brackets).
        let ip: std::net::IpAddr = env("ASYNC_SSH2_TEST_HOST_IP")
            .parse()
            .expect("ASYNC_SSH2_TEST_HOST_IP must be an IP address");
        let port: u16 = env("ASYNC_SSH2_TEST_HOST_PORT")
            .parse()
            .expect("ASYNC_SSH2_TEST_HOST_PORT must be a port number");
        SocketAddr::new(ip, port)
    }

    /// Connects to the test host with password authentication and no host key check.
    async fn establish_test_host_connection() -> Client {
        Client::connect(
            (
                env("ASYNC_SSH2_TEST_HOST_IP"),
                env("ASYNC_SSH2_TEST_HOST_PORT").parse().unwrap(),
            ),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Connection/Authentification failed")
    }

    #[tokio::test]
    async fn connect_with_password() {
        let client = establish_test_host_connection().await;
        assert_eq!(
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            client.get_connection_username(),
        );
        assert_eq!(test_address(), *client.get_connection_address());
    }

    #[tokio::test]
    async fn execute_command_result() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!!").await.unwrap();
        assert_eq!("test!!!\n", output.output);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn unicode_output() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo To thḙ moon! 🚀").await.unwrap();
        assert_eq!("To thḙ moon! 🚀\n", output.output);
        assert_eq!(0, output.exit_status);
    }

    #[tokio::test]
    async fn execute_command_status() {
        let client = establish_test_host_connection().await;
        let output = client.execute("exit 42").await.unwrap();
        assert_eq!(42, output.exit_status);
    }

    #[tokio::test]
    async fn execute_multiple_commands() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo test!!!").await.unwrap().output;
        assert_eq!("test!!!\n", output);
        let output = client.execute("echo Hello World").await.unwrap().output;
        assert_eq!("Hello World\n", output);
    }

    #[tokio::test]
    async fn stderr_redirection() {
        let client = establish_test_host_connection().await;
        let output = client.execute("echo foo >/dev/null").await.unwrap();
        assert_eq!("", output.output);
        let output = client.execute("echo foo >>/dev/stderr").await.unwrap();
        assert_eq!("", output.output);
        let output = client.execute("2>&1 echo foo >>/dev/stderr").await.unwrap();
        assert_eq!("foo\n", output.output);
    }

    #[tokio::test]
    async fn sequential_commands() {
        let client = establish_test_host_connection().await;
        for i in 0..1000 {
            // Async sleep: `std::thread::sleep` would block the test runtime's thread, so no
            // other task (including the SSH session) could make progress while sleeping.
            tokio::time::sleep(time::Duration::from_millis(200)).await;
            let res = client
                .execute(&format!("echo {i}"))
                .await
                .unwrap_or_else(|e| panic!("Execution failed in iteration {i}: {e:?}"));
            assert_eq!(format!("{i}\n"), res.output);
        }
    }

    #[tokio::test]
    async fn execute_multiple_context() {
        let client = establish_test_host_connection().await;
        let output = client
            .execute("export VARIABLE=42; echo $VARIABLE")
            .await
            .unwrap()
            .output;
        assert_eq!("42\n", output);
        // Each `execute` runs in its own channel, so exported variables do not persist.
        let output = client.execute("echo $VARIABLE").await.unwrap().output;
        assert_eq!("\n", output);
    }

    #[tokio::test]
    async fn connect_second_address() {
        let client = Client::connect(
            &[SocketAddr::from(([127, 0, 0, 1], 23)), test_address()][..],
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Resolution to second address failed");
        assert_eq!(test_address(), *client.get_connection_address());
    }

    #[tokio::test]
    async fn connect_with_wrong_password() {
        let error = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password("hopefully the wrong password"),
            ServerCheckMethod::NoCheck,
        )
        .await
        .err()
        .expect("Client connected with wrong password");
        assert!(
            matches!(error, crate::Error::PasswordWrong),
            "Wrong error type: {error:?}"
        );
    }

    #[tokio::test]
    async fn invalid_address() {
        let no_client = Client::connect(
            "this is definitely not an address",
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password("hopefully the wrong password"),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    async fn connect_to_wrong_port() {
        let no_client = Client::connect(
            (env("ASYNC_SSH2_TEST_HOST_IP"), 23),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    #[ignore = "This times out only after 20 seconds"]
    async fn connect_to_wrong_host() {
        let no_client = Client::connect(
            "172.16.0.6:22",
            "xxx",
            AuthMethod::with_password("xxx"),
            ServerCheckMethod::NoCheck,
        )
        .await;
        assert!(no_client.is_err());
    }

    #[tokio::test]
    async fn auth_key_file() {
        let _client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key_file(&env("ASYNC_SSH2_TEST_CLIENT_PRIV"), None),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Authentication with key file failed");
    }

    #[tokio::test]
    async fn auth_key_file_with_passphrase() {
        let _client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key_file(
                &env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV"),
                Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS")),
            ),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Authentication with passphrase-protected key file failed");
    }

    #[tokio::test]
    async fn auth_key_str() {
        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PRIV")).unwrap();
        let _client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key(key.as_str(), None),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Authentication with inline key failed");
    }

    #[tokio::test]
    async fn auth_key_str_with_passphrase() {
        let key = std::fs::read_to_string(env("ASYNC_SSH2_TEST_CLIENT_PROT_PRIV")).unwrap();
        let _client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_key(key.as_str(), Some(&env("ASYNC_SSH2_TEST_CLIENT_PROT_PASS"))),
            ServerCheckMethod::NoCheck,
        )
        .await
        .expect("Authentication with passphrase-protected inline key failed");
    }

    #[tokio::test]
    async fn server_check_file() {
        let _client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_public_key_file(&env("ASYNC_SSH2_TEST_SERVER_PUB")),
        )
        .await
        .expect("Server check against public key file failed");
    }

    #[tokio::test]
    async fn server_check_str() {
        let line = std::fs::read_to_string(env("ASYNC_SSH2_TEST_SERVER_PUB")).unwrap();
        // A .pub file is "<type> <base64> [comment]"; accept a bare base64 key as well.
        let mut split = line.split_whitespace();
        let key = match (split.next(), split.next()) {
            (Some(_), Some(k)) => k,
            (Some(k), None) => k,
            _ => panic!("Failed to parse pub key file"),
        };
        let _client = Client::connect(
            test_address(),
            &env("ASYNC_SSH2_TEST_HOST_USER"),
            AuthMethod::with_password(&env("ASYNC_SSH2_TEST_HOST_PW")),
            ServerCheckMethod::with_public_key(key),
        )
        .await
        .expect("Server check against inline public key failed");
    }
}