use std::future::Future;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use russh::client::{self, Msg};
use russh::{ChannelMsg, ChannelStream};
use tokio::sync::Mutex;
use super::config;
use crate::error::SshError;
/// Captured result of a remote command run to completion.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandOutput {
    /// Everything the command wrote to the channel's main data stream.
    pub stdout: Vec<u8>,
    /// Everything the command wrote to the channel's stderr extended-data stream.
    pub stderr: Vec<u8>,
    /// Whether the command is considered to have completed successfully.
    pub success: bool,
}
/// A server-initiated `forwarded-tcpip` channel, delivered through the `forwarded_tx`
/// sender passed to `Session::connect`.
pub struct IncomingForward {
    /// The `connected_port` the server reported for this forwarded channel.
    pub remote_port: u16,
    /// The opened channel carrying the forwarded connection's data.
    pub channel: russh::Channel<russh::client::Msg>,
}
/// russh client callback handler; one is created for each connection `Session::connect` makes.
struct ClientHandler {
    /// Destination for server-initiated forwarded channels; when `None`, such channels are
    /// dropped.
    forwarded_tx: Option<tokio::sync::mpsc::UnboundedSender<IncomingForward>>,
}
impl client::Handler for ClientHandler {
    type Error = russh::Error;

    /// Accepts the server's host key unconditionally.
    ///
    /// SECURITY NOTE(review): no known_hosts comparison is performed, so any host key is
    /// trusted and the connection is not protected against man-in-the-middle attacks.
    async fn check_server_key(
        &mut self,
        _server_public_key: &russh::keys::PublicKey,
    ) -> Result<bool, Self::Error> {
        Ok(true)
    }

    /// Hands a server-initiated `forwarded-tcpip` channel to the registered receiver.
    ///
    /// The channel is dropped when no sender is configured, when the reported port does
    /// not fit in a `u16`, or when the receiving side has already gone away. None of
    /// these cases is reported as an error to russh.
    async fn server_channel_open_forwarded_tcpip(
        &mut self,
        channel: russh::Channel<russh::client::Msg>,
        _connected_address: &str,
        connected_port: u32,
        _originator_address: &str,
        _originator_port: u32,
        _session: &mut russh::client::Session,
    ) -> Result<(), Self::Error> {
        let Some(sender) = self.forwarded_tx.as_ref() else {
            return Ok(());
        };
        let Ok(remote_port) = u16::try_from(connected_port) else {
            return Ok(());
        };
        // A closed receiver only means nobody wants the forward any more.
        let _ = sender.send(IncomingForward { remote_port, channel });
        Ok(())
    }
}
/// An authenticated SSH client connection, optionally tunnelled through a jump host.
///
/// Clones share the same underlying connection handle (it sits behind an `Arc`).
#[derive(Clone)]
pub struct Session {
    // Connection handle shared by all clones; each request locks it while the request is
    // being issued.
    handle: Arc<Mutex<client::Handle<ClientHandler>>>,
    // Jump-host session whose `direct-tcpip` channel carries this connection's transport.
    // It is never read and is kept only to keep that session alive. Struct fields drop in
    // declaration order, so `handle` is released before this.
    _jump_session: Option<Box<Session>>,
}
impl Session {
pub fn connect(
destination: &str,
forwarded_tx: Option<tokio::sync::mpsc::UnboundedSender<IncomingForward>>,
) -> Pin<Box<dyn Future<Output = Result<Self, SshError>> + Send + '_>> {
Box::pin(async move {
let (explicit_user, host) = config::parse_destination(destination);
let cfg = config::resolve_host_config(&host);
let user = explicit_user
.or(cfg.user)
.unwrap_or_else(|| std::env::var("USER").unwrap_or_else(|_| "root".into()));
let resolved_host = cfg.hostname.unwrap_or_else(|| host.to_string());
let resolved_port = cfg.port.unwrap_or(22);
let (mut handle, jump_session) = if let Some(ref jump_dest) = cfg.proxy_jump {
let jump = Session::connect(jump_dest, None).await?;
let target_host = resolved_host.clone();
let target_port = resolved_port as u32;
let channel = jump
.handle
.lock()
.await
.channel_open_direct_tcpip(target_host, target_port, "127.0.0.1", 0)
.await
.map_err(|e| SshError::Connection {
destination: destination.to_string(),
source: e,
})?;
let tunnel = channel.into_stream();
let config = Arc::new(client::Config::default());
let handle = client::connect_stream(
config,
tunnel,
ClientHandler {
forwarded_tx: forwarded_tx.clone(),
},
)
.await
.map_err(|e| SshError::Connection {
destination: destination.to_string(),
source: e,
})?;
(handle, Some(Box::new(jump)))
} else {
let addr = format!("{resolved_host}:{resolved_port}");
let stream = tokio::net::TcpStream::connect(&addr)
.await
.map_err(|e| SshError::Config(format!("failed to connect to {addr}: {e}")))?;
let config = Arc::new(client::Config::default());
let handle = client::connect_stream(
config,
stream,
ClientHandler {
forwarded_tx: forwarded_tx.clone(),
},
)
.await
.map_err(|e| SshError::Connection {
destination: destination.to_string(),
source: e,
})?;
(handle, None)
};
if !authenticate(&mut handle, &user, &cfg.identity_files).await? {
return Err(SshError::Auth {
destination: destination.to_string(),
message: format!(
"all authentication methods failed for {user}@{resolved_host}"
),
});
}
Ok(Self {
handle: Arc::new(Mutex::new(handle)),
_jump_session: jump_session,
})
})
}
pub async fn open_direct_tcpip(
&self,
host: &str,
port: u16,
) -> Result<ChannelStream<Msg>, SshError> {
let channel = self
.handle
.lock()
.await
.channel_open_direct_tcpip(host.to_string(), port as u32, "127.0.0.1", 0)
.await
.map_err(SshError::Remote)?;
Ok(channel.into_stream())
}
pub async fn exec(&self, command: &str) -> Result<CommandOutput, SshError> {
let mut channel = self
.handle
.lock()
.await
.channel_open_session()
.await
.map_err(SshError::Remote)?;
channel
.exec(true, command)
.await
.map_err(SshError::Remote)?;
collect_channel_output(&mut channel).await
}
pub async fn exec_streaming(&self, command: &str) -> Result<ChannelStream<Msg>, SshError> {
let channel = self
.handle
.lock()
.await
.channel_open_session()
.await
.map_err(SshError::Remote)?;
channel
.exec(true, command)
.await
.map_err(SshError::Remote)?;
Ok(channel.into_stream())
}
pub async fn exec_with_stdin(
&self,
command: &str,
data: &[u8],
) -> Result<CommandOutput, SshError> {
let mut channel = self
.handle
.lock()
.await
.channel_open_session()
.await
.map_err(SshError::Remote)?;
channel
.exec(true, command)
.await
.map_err(SshError::Remote)?;
channel.data(data).await.map_err(SshError::Remote)?;
channel.eof().await.map_err(SshError::Remote)?;
collect_channel_output(&mut channel).await
}
pub async fn tcpip_forward(&self, port: u16) -> Result<u16, SshError> {
self.handle
.lock()
.await
.tcpip_forward("127.0.0.1", port as u32)
.await
.map(|p| p as u16)
.map_err(SshError::Remote)
}
pub async fn cancel_tcpip_forward(&self, port: u16) -> Result<(), SshError> {
self.handle
.lock()
.await
.cancel_tcpip_forward("127.0.0.1", port as u32)
.await
.map_err(SshError::Remote)
}
}
/// Drains `channel` until the server closes it, collecting stdout, stderr and the exit result.
///
/// `success` is `false` in two cases:
/// - the server reported a non-zero exit status;
/// - the server reported that the command was terminated by a signal.
///
/// If neither is reported, the command is treated as successful.
async fn collect_channel_output(
    channel: &mut russh::Channel<Msg>,
) -> Result<CommandOutput, SshError> {
    let mut stdout = Vec::new();
    let mut stderr = Vec::new();
    let mut exit_status = 0u32;
    // A command killed by a signal sends `exit-signal` instead of `exit-status`
    // (RFC 4254 §6.10). Without tracking it, such a command would look successful.
    let mut killed_by_signal = false;
    while let Some(msg) = channel.wait().await {
        match msg {
            ChannelMsg::Data { data } => stdout.extend_from_slice(&data),
            // Extended-data type 1 is SSH_EXTENDED_DATA_STDERR.
            ChannelMsg::ExtendedData { data, ext: 1 } => stderr.extend_from_slice(&data),
            ChannelMsg::ExitStatus { exit_status: code } => exit_status = code,
            ChannelMsg::ExitSignal { .. } => killed_by_signal = true,
            _ => {}
        }
    }
    Ok(CommandOutput {
        stdout,
        stderr,
        success: exit_status == 0 && !killed_by_signal,
    })
}
async fn authenticate(
handle: &mut client::Handle<ClientHandler>,
user: &str,
identity_files: &[PathBuf],
) -> Result<bool, SshError> {
let rsa_hash = handle
.best_supported_rsa_hash()
.await
.ok()
.flatten()
.flatten();
if let Ok(mut agent) = russh::keys::agent::client::AgentClient::connect_env().await {
if let Ok(identities) = agent.request_identities().await {
for key in identities {
match handle
.authenticate_publickey_with(
user,
key.public_key().into_owned(),
rsa_hash,
&mut agent,
)
.await
{
Ok(res) if res.success() => return Ok(true),
_ => continue,
}
}
}
}
for path in identity_files {
if try_key_file(handle, user, rsa_hash, path).await? {
return Ok(true);
}
}
let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
let default_keys = [
PathBuf::from(format!("{home}/.ssh/id_ed25519")),
PathBuf::from(format!("{home}/.ssh/id_rsa")),
PathBuf::from(format!("{home}/.ssh/id_ecdsa")),
];
for path in &default_keys {
if try_key_file(handle, user, rsa_hash, path).await? {
return Ok(true);
}
}
Ok(false)
}
/// Attempts public-key authentication with the private key stored at `path`.
///
/// Returns `Ok(false)` if:
/// - the file does not exist;
/// - the key cannot be loaded (no passphrase is supplied here);
/// - the authentication attempt errors;
/// - the server rejects the key.
///
/// It never returns `Err`.
async fn try_key_file(
    handle: &mut client::Handle<ClientHandler>,
    user: &str,
    rsa_hash: Option<russh::keys::HashAlg>,
    path: &Path,
) -> Result<bool, SshError> {
    if !path.exists() {
        return Ok(false);
    }
    let Ok(secret) = russh::keys::load_secret_key(path, None) else {
        return Ok(false);
    };
    let signer = russh::keys::PrivateKeyWithHashAlg::new(Arc::new(secret), rsa_hash);
    let accepted = handle
        .authenticate_publickey(user, signer)
        .await
        .map(|outcome| outcome.success())
        .unwrap_or(false);
    Ok(accepted)
}