use russh::client::{Config, Handle, Handler};
use std::net::SocketAddr;
use std::sync::Arc;
use std::{fmt::Debug, io};
use super::authentication::{AuthMethod, ServerCheckMethod};
use super::ToSocketAddrsWithHostname;
/// An SSH client connection built on a russh [`Handle`].
///
/// Cloning a `Client` clones the inner `Arc`s, so clones share the same
/// underlying connection.
#[derive(Clone)]
pub struct Client {
    /// Shared handle to the russh session; used by `disconnect` and `is_closed`.
    pub(super) connection_handle: Arc<Handle<ClientHandler>>,
    /// Username the connection was (or is claimed to have been) opened as.
    pub(super) username: String,
    /// Socket address the connection was established to.
    pub(super) address: SocketAddr,
    /// Publicly exposed session handle. The constructors in this file set it to
    /// the same `Arc` as `connection_handle`.
    // NOTE(review): silences `private_interfaces` for `ClientHandler` appearing in
    // this public field's type; `ClientHandler` is declared `pub` in this file, so
    // confirm whether the allow is still required given its effective visibility.
    #[allow(private_interfaces)]
    pub session: Arc<Handle<ClientHandler>>,
}
impl Client {
    /// Opens an SSH connection to `addr` using russh's default [`Config`], then
    /// authenticates as `username` using `auth`.
    ///
    /// This is [`Client::connect_with_config`] with `Config::default()`.
    ///
    /// # Errors
    /// Same as [`Client::connect_with_config`].
    pub async fn connect(
        addr: impl ToSocketAddrsWithHostname,
        username: &str,
        auth: AuthMethod,
        server_check: ServerCheckMethod,
    ) -> Result<Self, super::Error> {
        Self::connect_with_config(addr, username, auth, server_check, Config::default()).await
    }

    /// Opens an SSH connection to `addr` with the given russh `config`, then
    /// authenticates as `username` using `auth`.
    ///
    /// `addr` is resolved to socket addresses, which are tried in order; the
    /// first successful connection is used. The server's host key is checked by
    /// a [`ClientHandler`] configured with `server_check`.
    ///
    /// # Errors
    /// - `Error::AddressInvalid` if resolution fails or yields no addresses.
    /// - The error from the last attempt if connecting to every address fails.
    /// - Any error returned by `authentication::authenticate`.
    pub async fn connect_with_config(
        addr: impl ToSocketAddrsWithHostname,
        username: &str,
        auth: AuthMethod,
        server_check: ServerCheckMethod,
        config: Config,
    ) -> Result<Self, super::Error> {
        let config = Arc::new(config);
        // NOTE(review): depending on the ToSocketAddrsWithHostname impl, this may
        // perform blocking DNS resolution inside an async fn — confirm.
        let socket_addrs = addr
            .to_socket_addrs()
            .map_err(super::Error::AddressInvalid)?;
        // The hostname does not depend on which resolved address is tried, so
        // compute it once instead of once per connection attempt.
        let hostname = addr.hostname();
        // Only surfaces if the resolver produced no addresses at all; otherwise it
        // is overwritten by the outcome of each attempt.
        let mut connect_res = Err(super::Error::AddressInvalid(io::Error::new(
            io::ErrorKind::InvalidInput,
            "could not resolve to any addresses",
        )));
        for socket_addr in socket_addrs {
            let handler = ClientHandler {
                hostname: hostname.clone(),
                host: socket_addr,
                server_check: server_check.clone(),
            };
            match russh::client::connect(Arc::clone(&config), socket_addr, handler).await {
                Ok(handle) => {
                    connect_res = Ok((socket_addr, handle));
                    break;
                }
                // Remember the most recent failure and move on to the next address.
                Err(e) => connect_res = Err(e),
            }
        }
        let (address, mut handle) = connect_res?;
        let username = username.to_string();
        super::authentication::authenticate(&mut handle, &username, auth).await?;
        let connection_handle = Arc::new(handle);
        Ok(Self {
            connection_handle: Arc::clone(&connection_handle),
            username,
            address,
            session: connection_handle,
        })
    }

    /// Wraps an existing russh handle without performing any connection or
    /// authentication step. `username` and `address` are stored as given.
    ///
    /// Both `connection_handle` and `session` share the supplied `Arc`.
    pub fn from_handle_and_address(
        handle: Arc<Handle<ClientHandler>>,
        username: String,
        address: SocketAddr,
    ) -> Self {
        Self {
            connection_handle: Arc::clone(&handle),
            username,
            address,
            session: handle,
        }
    }

    /// Returns the username this client was created with.
    pub fn get_connection_username(&self) -> &String {
        &self.username
    }

    /// Returns the socket address this client was created with.
    pub fn get_connection_address(&self) -> &SocketAddr {
        &self.address
    }

    /// Sends an SSH disconnect with reason `ByApplication` and empty
    /// description/language fields.
    ///
    /// # Errors
    /// `Error::SshError` wrapping the russh error if the disconnect fails.
    pub async fn disconnect(&self) -> Result<(), super::Error> {
        self.connection_handle
            .disconnect(russh::Disconnect::ByApplication, "", "")
            .await
            .map_err(super::Error::SshError)
    }

    /// Reports whether the underlying russh handle is closed.
    pub fn is_closed(&self) -> bool {
        self.connection_handle.is_closed()
    }

    /// Remote (server-side) port forwarding — not implemented yet.
    ///
    /// # Errors
    /// Always logs a warning and returns `Error::PortForwardingNotSupported`.
    pub async fn request_port_forward(
        &self,
        _bind_address: String,
        _bind_port: u32,
    ) -> Result<u32, super::Error> {
        tracing::warn!("Remote port forwarding request not yet implemented - TODO");
        Err(super::Error::PortForwardingNotSupported)
    }

    /// Cancels a remote port forward — not implemented yet.
    ///
    /// # Errors
    /// Always logs a warning and returns `Error::PortForwardingNotSupported`.
    pub async fn cancel_port_forward(
        &self,
        _bind_address: String,
        _bind_port: u32,
    ) -> Result<(), super::Error> {
        tracing::warn!("Cancel remote port forwarding not yet implemented - TODO");
        Err(super::Error::PortForwardingNotSupported)
    }
}
impl Debug for Client {
    /// Shows the username and address. The connection handle is rendered as a
    /// fixed placeholder string rather than its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Client");
        builder.field("username", &self.username);
        builder.field("address", &self.address);
        builder.field("connection_handle", &"Handle<ClientHandler>");
        builder.finish()
    }
}
/// russh client-side [`Handler`] that verifies the server's host key using the
/// configured [`ServerCheckMethod`].
#[derive(Debug, Clone)]
pub struct ClientHandler {
    /// Hostname passed to the known_hosts checks.
    hostname: String,
    /// Resolved address being connected to; its port is passed to the
    /// known_hosts checks.
    host: SocketAddr,
    /// How the server's host key should be verified.
    server_check: ServerCheckMethod,
}
impl ClientHandler {
    /// Creates a handler that verifies the host key of `hostname`, reached at
    /// `host`, according to `server_check`.
    pub fn new(hostname: String, host: SocketAddr, server_check: ServerCheckMethod) -> Self {
        Self {
            server_check,
            host,
            hostname,
        }
    }
}
impl Handler for ClientHandler {
type Error = super::Error;
async fn check_server_key(
&mut self,
server_public_key: &russh::keys::PublicKey,
) -> Result<bool, Self::Error> {
match &self.server_check {
ServerCheckMethod::NoCheck => Ok(true),
ServerCheckMethod::PublicKey(key) => {
let pk = russh::keys::parse_public_key_base64(key)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(pk == *server_public_key)
}
ServerCheckMethod::PublicKeyFile(key_file_name) => {
let pk = russh::keys::load_public_key(key_file_name)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(pk == *server_public_key)
}
ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
let result = russh::keys::check_known_hosts_path(
&self.hostname,
self.host.port(),
server_public_key,
known_hosts_path,
)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(result)
}
ServerCheckMethod::DefaultKnownHostsFile => {
let result = russh::keys::check_known_hosts(
&self.hostname,
self.host.port(),
server_public_key,
)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(result)
}
}
}
}