use russh::client::{Config, Handle, Handler};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt::Debug, io};
use super::authentication::{AuthMethod, ServerCheckMethod};
/// Default keepalive interval, in seconds, used by [`SshConnectionConfig::default`].
pub const DEFAULT_KEEPALIVE_INTERVAL: u64 = 60;
/// Default value for [`SshConnectionConfig::keepalive_max`], used by
/// [`SshConnectionConfig::default`].
pub const DEFAULT_KEEPALIVE_MAX: usize = 3;

/// Keepalive settings applied when establishing an SSH connection.
#[derive(Debug, Clone)]
pub struct SshConnectionConfig {
    /// Seconds between keepalives; `None` disables them.
    pub keepalive_interval: Option<u64>,
    /// Keepalive limit handed to russh, also used as the TCP keepalive retry count.
    pub keepalive_max: usize,
}

impl Default for SshConnectionConfig {
    /// Keepalives every [`DEFAULT_KEEPALIVE_INTERVAL`] seconds with a limit of
    /// [`DEFAULT_KEEPALIVE_MAX`].
    fn default() -> Self {
        let keepalive_interval = Some(DEFAULT_KEEPALIVE_INTERVAL);
        let keepalive_max = DEFAULT_KEEPALIVE_MAX;
        SshConnectionConfig {
            keepalive_interval,
            keepalive_max,
        }
    }
}
impl SshConnectionConfig {
    /// Creates a configuration holding the crate defaults; equivalent to
    /// [`SshConnectionConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the keepalive interval in seconds.
    ///
    /// `None` disables keepalives, and `Some(0)` is treated the same way by
    /// [`Self::to_russh_config`] and [`Self::to_tcp_keepalive`].
    #[must_use]
    pub fn with_keepalive_interval(mut self, interval: Option<u64>) -> Self {
        self.keepalive_interval = interval;
        self
    }

    /// Sets the keepalive limit passed to russh. It is also used as the TCP
    /// keepalive retry count, clamped to at least 1.
    #[must_use]
    pub fn with_keepalive_max(mut self, max: usize) -> Self {
        self.keepalive_max = max;
        self
    }

    /// Keepalive interval actually in effect, with `Some(0)` normalized to `None`.
    ///
    /// A zero interval cannot work:
    /// - russh would re-arm its keepalive timer with a zero-length sleep and hit
    ///   `keepalive_max` almost immediately;
    /// - Linux rejects a zero `TCP_KEEPIDLE`.
    ///
    /// Treating it as "disabled" matches OpenSSH's `ServerAliveInterval 0`.
    fn effective_keepalive_interval(&self) -> Option<u64> {
        self.keepalive_interval.filter(|&secs| secs > 0)
    }

    /// Builds the russh client [`Config`] for these settings.
    ///
    /// With keepalives enabled, no inactivity timeout is set. With them
    /// disabled, a one-hour inactivity timeout is used instead. All other
    /// fields keep russh's defaults.
    pub fn to_russh_config(&self) -> Config {
        let keepalive_interval = self.effective_keepalive_interval();
        let inactivity_timeout = if keepalive_interval.is_some() {
            None
        } else {
            Some(Duration::from_secs(3600))
        };
        Config {
            keepalive_interval: keepalive_interval.map(Duration::from_secs),
            keepalive_max: self.keepalive_max,
            inactivity_timeout,
            ..Default::default()
        }
    }

    /// Builds OS-level TCP keepalive settings mirroring the SSH ones.
    ///
    /// Returns `None` when keepalives are disabled.
    /// - Idle time before the first probe is `keepalive_interval` seconds.
    /// - Probes then repeat every half interval, and at least every 1 s.
    /// - On the targets in the `cfg` list below, the retry count is
    ///   `keepalive_max` (at least 1). Elsewhere no retry count is set.
    pub fn to_tcp_keepalive(&self) -> Option<socket2::TcpKeepalive> {
        let interval = self.effective_keepalive_interval()?;
        let probe_interval = (interval / 2).max(1);
        let ka = socket2::TcpKeepalive::new()
            .with_time(Duration::from_secs(interval))
            .with_interval(Duration::from_secs(probe_interval));
        #[cfg(any(
            target_os = "linux",
            target_os = "macos",
            target_os = "freebsd",
            target_os = "netbsd",
            target_os = "tvos",
            target_os = "watchos",
            target_os = "ios",
        ))]
        let ka = {
            // Saturate rather than truncate. A bare `as u32` maps e.g. 2^32 to 0,
            // which the kernel rejects as a retry count.
            let retries = u32::try_from(self.keepalive_max.max(1)).unwrap_or(u32::MAX);
            ka.with_retries(retries)
        };
        Some(ka)
    }
}
use super::ToSocketAddrsWithHostname;
/// An SSH client session together with the identity it was opened with.
///
/// `connection_handle` and `session` are set from the same `Arc` by the
/// constructors in this file.
#[derive(Clone)]
pub struct Client {
/// Handle used internally (e.g. by `disconnect` and `is_closed`).
pub(super) connection_handle: Arc<Handle<ClientHandler>>,
/// Username supplied when the client was created.
pub(super) username: String,
/// Socket address the connection was made to.
pub(super) address: SocketAddr,
/// Publicly accessible handle to the same session as `connection_handle`.
// NOTE(review): presumably allowed because `ClientHandler`'s effective
// visibility is narrower than this field's — confirm before removing.
#[allow(private_interfaces)]
pub session: Arc<Handle<ClientHandler>>,
}
impl Client {
pub async fn connect(
addr: impl ToSocketAddrsWithHostname,
username: &str,
auth: AuthMethod,
server_check: ServerCheckMethod,
) -> Result<Self, super::Error> {
Self::connect_with_ssh_config(
addr,
username,
auth,
server_check,
&SshConnectionConfig::default(),
)
.await
}
pub async fn connect_with_ssh_config(
addr: impl ToSocketAddrsWithHostname,
username: &str,
auth: AuthMethod,
server_check: ServerCheckMethod,
ssh_config: &SshConnectionConfig,
) -> Result<Self, super::Error> {
let config = ssh_config.to_russh_config();
let tcp_keepalive = ssh_config.to_tcp_keepalive();
Self::connect_with_config_inner(
addr,
username,
auth,
server_check,
config,
tcp_keepalive.as_ref(),
)
.await
}
pub async fn connect_with_config(
addr: impl ToSocketAddrsWithHostname,
username: &str,
auth: AuthMethod,
server_check: ServerCheckMethod,
config: Config,
) -> Result<Self, super::Error> {
Self::connect_with_config_inner(addr, username, auth, server_check, config, None).await
}
async fn connect_with_config_inner(
addr: impl ToSocketAddrsWithHostname,
username: &str,
auth: AuthMethod,
server_check: ServerCheckMethod,
config: Config,
tcp_keepalive: Option<&socket2::TcpKeepalive>,
) -> Result<Self, super::Error> {
let config = Arc::new(config);
let socket_addrs = addr
.to_socket_addrs()
.map_err(super::Error::AddressInvalid)?;
let mut connect_res: Result<
(SocketAddr, russh::client::Handle<ClientHandler>),
super::Error,
> = Err(super::Error::AddressInvalid(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
)));
for socket_addr in socket_addrs {
let handler = ClientHandler {
hostname: addr.hostname(),
host: socket_addr,
server_check: server_check.clone(),
};
let stream = match tokio::net::TcpStream::connect(socket_addr).await {
Ok(s) => s,
Err(e) => {
connect_res = Err(super::Error::IoError(e));
continue;
}
};
if let Some(ka) = tcp_keepalive {
let sock_ref = socket2::SockRef::from(&stream);
if let Err(e) = sock_ref.set_tcp_keepalive(ka) {
tracing::debug!(
"Failed to set TCP keepalive on socket to {}: {}",
socket_addr,
e
);
}
}
match russh::client::connect_stream(config.clone(), stream, handler).await {
Ok(h) => {
connect_res = Ok((socket_addr, h));
break;
}
Err(e) => connect_res = Err(e),
}
}
let (address, mut handle) = connect_res?;
let username = username.to_string();
super::authentication::authenticate(&mut handle, &username, auth).await?;
let connection_handle = Arc::new(handle);
Ok(Self {
connection_handle: connection_handle.clone(),
username,
address,
session: connection_handle,
})
}
pub fn from_handle_and_address(
handle: Arc<Handle<ClientHandler>>,
username: String,
address: SocketAddr,
) -> Self {
Self {
connection_handle: handle.clone(),
username,
address,
session: handle,
}
}
pub fn get_connection_username(&self) -> &String {
&self.username
}
pub fn get_connection_address(&self) -> &SocketAddr {
&self.address
}
pub async fn disconnect(&self) -> Result<(), super::Error> {
self.connection_handle
.disconnect(russh::Disconnect::ByApplication, "", "")
.await
.map_err(super::Error::SshError)
}
pub fn is_closed(&self) -> bool {
self.connection_handle.is_closed()
}
pub async fn request_port_forward(
&self,
_bind_address: String,
_bind_port: u32,
) -> Result<u32, super::Error> {
tracing::warn!("Remote port forwarding request not yet implemented - TODO");
Err(super::Error::PortForwardingNotSupported)
}
pub async fn cancel_port_forward(
&self,
_bind_address: String,
_bind_port: u32,
) -> Result<(), super::Error> {
tracing::warn!("Cancel remote port forwarding not yet implemented - TODO");
Err(super::Error::PortForwardingNotSupported)
}
}
impl Debug for Client {
    // Prints username and address. The session handle is shown only as a
    // fixed type-name placeholder string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Client");
        builder.field("username", &self.username);
        builder.field("address", &self.address);
        builder.field("connection_handle", &"Handle<ClientHandler>");
        builder.finish()
    }
}
/// russh client handler that decides whether to trust the server's host key.
#[derive(Debug, Clone)]
pub struct ClientHandler {
    /// Hostname used for known_hosts lookups.
    hostname: String,
    /// Address being connected to. In this file only its port is read, for
    /// known_hosts lookups.
    host: SocketAddr,
    /// Host-key verification policy.
    server_check: ServerCheckMethod,
}

impl ClientHandler {
    /// Builds a handler for a single connection attempt.
    pub fn new(hostname: String, host: SocketAddr, server_check: ServerCheckMethod) -> Self {
        ClientHandler {
            server_check,
            host,
            hostname,
        }
    }
}
impl Handler for ClientHandler {
type Error = super::Error;
async fn check_server_key(
&mut self,
server_public_key: &russh::keys::PublicKey,
) -> Result<bool, Self::Error> {
match &self.server_check {
ServerCheckMethod::NoCheck => Ok(true),
ServerCheckMethod::PublicKey(key) => {
let pk = russh::keys::parse_public_key_base64(key)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(pk == *server_public_key)
}
ServerCheckMethod::PublicKeyFile(key_file_name) => {
let pk = russh::keys::load_public_key(key_file_name)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(pk == *server_public_key)
}
ServerCheckMethod::KnownHostsFile(known_hosts_path) => {
let result = russh::keys::check_known_hosts_path(
&self.hostname,
self.host.port(),
server_public_key,
known_hosts_path,
)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(result)
}
ServerCheckMethod::DefaultKnownHostsFile => {
let result = russh::keys::check_known_hosts(
&self.hostname,
self.host.port(),
server_public_key,
)
.map_err(|_| super::Error::ServerCheckFailed)?;
Ok(result)
}
}
}
}