use russh::client::Msg;
use russh::Channel;
use russh::CryptoVec;
use std::io;
use std::net::SocketAddr;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
use super::connection::Client;
use super::ToSocketAddrsWithHostname;
/// Capacity passed to `channel(..)` when `CommandOutputBuffer::new` creates the
/// bounded queue of [`CommandOutput`] events.
const OUTPUT_EVENTS_CHANNEL_SIZE: usize = 100;
/// One chunk of output from a remote command, tagged with the stream it belongs to.
#[derive(Debug, Clone)]
pub enum CommandOutput {
/// Payload of a `russh::ChannelMsg::Data` message.
StdOut(CryptoVec),
/// Payload of a `russh::ChannelMsg::ExtendedData` message whose `ext` is `1`.
StdErr(CryptoVec),
}
/// Pairs the sending half of an output channel with the task that accumulates
/// everything sent through it.
pub(crate) struct CommandOutputBuffer {
/// Sending half of the channel drained by `receiver_task`.
pub(crate) sender: Sender<CommandOutput>,
/// Spawned task that resolves to the collected `(stdout, stderr)` bytes once
/// its receiver yields `None`.
pub(crate) receiver_task: JoinHandle<(Vec<u8>, Vec<u8>)>,
}
impl CommandOutputBuffer {
    /// Creates the output channel and spawns the task that drains it.
    ///
    /// The spawned task appends every [`CommandOutput::StdOut`] chunk to one
    /// byte buffer and every [`CommandOutput::StdErr`] chunk to another, in
    /// arrival order. It finishes when `recv` yields `None` and resolves to
    /// `(stdout, stderr)`.
    ///
    /// This calls `tokio::task::spawn`, so it has to run inside a Tokio runtime.
    pub(crate) fn new() -> Self {
        let (sender, mut receiver) = channel::<CommandOutput>(OUTPUT_EVENTS_CHANNEL_SIZE);
        let receiver_task = tokio::task::spawn(async move {
            let mut stdout: Vec<u8> = Vec::with_capacity(1024);
            let mut stderr: Vec<u8> = Vec::with_capacity(256);
            while let Some(output) = receiver.recv().await {
                // `extend_from_slice` reserves space itself with amortized
                // growth. Note that `Vec::reserve` counts from `len()`, not
                // from `capacity()`, so hand-rolled capacity arithmetic is
                // easy to get wrong and buys nothing here.
                match output {
                    CommandOutput::StdOut(chunk) => stdout.extend_from_slice(&chunk),
                    CommandOutput::StdErr(chunk) => stderr.extend_from_slice(&chunk),
                }
            }
            (stdout, stderr)
        });
        Self {
            sender,
            receiver_task,
        }
    }
}
/// Collected outcome of a command, as returned by `Client::execute`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandExecutedResult {
/// Collected standard-output bytes, converted with `String::from_utf8_lossy`.
pub stdout: String,
/// Collected standard-error bytes, converted with `String::from_utf8_lossy`.
pub stderr: String,
/// Exit status taken from the channel's `ExitStatus` message.
pub exit_status: u32,
}
impl Client {
pub async fn get_channel(&self) -> Result<Channel<Msg>, super::Error> {
self.connection_handle
.channel_open_session()
.await
.map_err(super::Error::SshError)
}
pub async fn open_direct_tcpip_channel<
T: ToSocketAddrsWithHostname,
S: Into<Option<SocketAddr>>,
>(
&self,
target: T,
src: S,
) -> Result<Channel<Msg>, super::Error> {
let targets = target
.to_socket_addrs()
.map_err(super::Error::AddressInvalid)?;
let src = src
.into()
.map(|src| (src.ip().to_string(), src.port().into()))
.unwrap_or_else(|| ("127.0.0.1".to_string(), 22));
let mut connect_err = super::Error::AddressInvalid(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
));
for target in targets {
match self
.connection_handle
.channel_open_direct_tcpip(
target.ip().to_string(),
target.port().into(),
src.0.clone(),
src.1,
)
.await
{
Ok(channel) => return Ok(channel),
Err(err) => connect_err = super::Error::SshError(err),
}
}
Err(connect_err)
}
pub async fn execute_streaming(
&self,
command: &str,
sender: Sender<CommandOutput>,
) -> Result<u32, super::Error> {
let sanitized_command = crate::utils::sanitize_command(command)
.map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?;
let mut channel = self.connection_handle.channel_open_session().await?;
channel.exec(true, sanitized_command.as_str()).await?;
let mut result: Option<u32> = None;
while let Some(msg) = channel.wait().await {
match msg {
russh::ChannelMsg::Data { ref data } => {
match sender.try_send(CommandOutput::StdOut(data.clone())) {
Ok(_) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(output)) => {
tracing::trace!("Channel full, applying backpressure for stdout");
if sender.send(output).await.is_err() {
tracing::debug!("Receiver dropped, stopping stdout processing");
break;
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Channel closed, stopping stdout processing");
break;
}
}
}
russh::ChannelMsg::ExtendedData { ref data, ext } => {
if ext == 1 {
match sender.try_send(CommandOutput::StdErr(data.clone())) {
Ok(_) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(output)) => {
tracing::trace!("Channel full, applying backpressure for stderr");
if sender.send(output).await.is_err() {
tracing::debug!("Receiver dropped, stopping stderr processing");
break;
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Channel closed, stopping stderr processing");
break;
}
}
}
}
russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
_ => {}
}
}
drop(sender);
if let Some(result) = result {
Ok(result)
} else {
Err(super::Error::CommandDidntExit)
}
}
pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, super::Error> {
let output_buffer = CommandOutputBuffer::new();
let sender = output_buffer.sender.clone();
let exit_status = self.execute_streaming(command, sender).await?;
drop(output_buffer.sender);
let (stdout_bytes, stderr_bytes) = output_buffer.receiver_task.await.map_err(|e| {
super::Error::JoinError(e)
})?;
Ok(CommandExecutedResult {
stdout: String::from_utf8_lossy(&stdout_bytes).to_string(),
stderr: String::from_utf8_lossy(&stderr_bytes).to_string(),
exit_status,
})
}
/// Opens a session channel, allocates a pseudo-terminal on it and starts the
/// remote user's shell.
///
/// # Parameters
///
/// * `term_type` - terminal type announced to the server (for example
///   `"xterm"`).
/// * `width` / `height` - initial terminal size in character columns and rows.
///
/// Pixel dimensions are sent as `0` and no terminal modes are set.
///
/// # Returns
///
/// The channel on which the shell is running. Read from it with
/// `Channel::wait` and write with the channel's data methods.
///
/// # Errors
///
/// Returns the error produced by `?` if opening the channel, the PTY request
/// or the shell request fails.
pub async fn request_interactive_shell(
    &self,
    term_type: &str,
    width: u32,
    height: u32,
) -> Result<Channel<Msg>, super::Error> {
    // `mut` keeps this compiling on russh releases whose channel request
    // methods take `&mut self`; on releases where they take `&self` it would
    // otherwise trigger an unused-mut warning.
    #[allow(unused_mut)]
    let mut channel = self.connection_handle.channel_open_session().await?;
    channel
        .request_pty(true, term_type, width, height, 0, 0, &[])
        .await?;
    channel.request_shell(true).await?;
    Ok(channel)
}
pub async fn resize_pty(
&self,
channel: &mut Channel<Msg>,
width: u32,
height: u32,
) -> Result<(), super::Error> {
channel
.window_change(width, height, 0, 0)
.await
.map_err(super::Error::SshError)
}
}