use bytes::Bytes;
use russh::client::Msg;
use russh::Channel;
use std::io;
use std::net::SocketAddr;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
use super::connection::Client;
use super::ToSocketAddrsWithHostname;
use crate::security::{contains_sudo_failure, contains_sudo_prompt, SudoPassword};
/// Capacity of the bounded mpsc channel created by `CommandOutputBuffer::new`.
const OUTPUT_EVENTS_CHANNEL_SIZE: usize = 100;
/// Upper bound, in bytes (64 KiB), on the rolling window of recent output that
/// `Client::execute_with_sudo` scans for sudo prompts and failure messages.
const MAX_SUDO_PROMPT_BUFFER_SIZE: usize = 64 * 1024;
/// Maximum number of times `Client::execute_with_sudo` sends the sudo password
/// during a single command.
const MAX_SUDO_PASSWORD_SENDS: u32 = 10;
/// One event produced while a remote command runs.
///
/// Events are sent over a `tokio::sync::mpsc` channel by the `Client::execute*`
/// methods and can be collected with `CommandOutputBuffer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandOutput {
    /// A chunk of the command's standard output.
    StdOut(Bytes),
    /// A chunk of the command's standard error (SSH extended data type 1).
    StdErr(Bytes),
    /// An exit status for the command.
    ExitCode(u32),
}
/// Collects streamed `CommandOutput` events into in-memory stdout/stderr buffers.
///
/// Events sent through `sender` are drained by `receiver_task`, which completes
/// with `(stdout, stderr)` once every clone of `sender` has been dropped.
pub(crate) struct CommandOutputBuffer {
// Producer side of the output channel; clone it and hand it to a streaming call.
pub(crate) sender: Sender<CommandOutput>,
// Background collector; awaiting it yields the accumulated (stdout, stderr) bytes.
pub(crate) receiver_task: JoinHandle<(Vec<u8>, Vec<u8>)>,
}
impl CommandOutputBuffer {
    /// Creates a bounded output channel and spawns a task that drains it.
    ///
    /// `StdOut` and `StdErr` payloads are appended to separate byte buffers;
    /// `ExitCode` events are ignored. The spawned task finishes once every clone
    /// of the returned `sender` has been dropped and yields `(stdout, stderr)`.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (it uses `tokio::task::spawn`).
    pub(crate) fn new() -> Self {
        let (sender, mut receiver): (Sender<CommandOutput>, Receiver<CommandOutput>) =
            channel(OUTPUT_EVENTS_CHANNEL_SIZE);
        let receiver_task = tokio::task::spawn(async move {
            let mut stdout: Vec<u8> = Vec::with_capacity(1024);
            let mut stderr: Vec<u8> = Vec::with_capacity(256);
            while let Some(output) = receiver.recv().await {
                match output {
                    // `extend_from_slice` already grows the Vec geometrically
                    // (amortized O(1) per byte), so no manual reservation is needed.
                    CommandOutput::StdOut(buffer) => stdout.extend_from_slice(&buffer),
                    CommandOutput::StdErr(buffer) => stderr.extend_from_slice(&buffer),
                    // The exit status is returned separately by the caller.
                    CommandOutput::ExitCode(_) => {}
                }
            }
            (stdout, stderr)
        });
        Self {
            sender,
            receiver_task,
        }
    }
}
/// Fully collected outcome of a remote command run through `Client::execute`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CommandExecutedResult {
    /// Everything the command wrote to standard output, decoded lossily as UTF-8.
    pub stdout: String,
    /// Everything the command wrote to standard error, decoded lossily as UTF-8.
    pub stderr: String,
    /// Exit status reported by the remote side.
    pub exit_status: u32,
}
impl Client {
pub async fn get_channel(&self) -> Result<Channel<Msg>, super::Error> {
self.connection_handle
.channel_open_session()
.await
.map_err(super::Error::SshError)
}
pub async fn open_direct_tcpip_channel<
T: ToSocketAddrsWithHostname,
S: Into<Option<SocketAddr>>,
>(
&self,
target: T,
src: S,
) -> Result<Channel<Msg>, super::Error> {
let targets = target
.to_socket_addrs()
.map_err(super::Error::AddressInvalid)?;
let src = src
.into()
.map(|src| (src.ip().to_string(), src.port().into()))
.unwrap_or_else(|| ("127.0.0.1".to_string(), 22));
let mut connect_err = super::Error::AddressInvalid(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
));
for target in targets {
match self
.connection_handle
.channel_open_direct_tcpip(
target.ip().to_string(),
target.port().into(),
src.0.clone(),
src.1,
)
.await
{
Ok(channel) => return Ok(channel),
Err(err) => connect_err = super::Error::SshError(err),
}
}
Err(connect_err)
}
pub async fn execute_streaming(
&self,
command: &str,
sender: Sender<CommandOutput>,
) -> Result<u32, super::Error> {
let sanitized_command = crate::utils::sanitize_command(command)
.map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?;
let mut channel = self.connection_handle.channel_open_session().await?;
channel.exec(true, sanitized_command.as_str()).await?;
let mut result: Option<u32> = None;
while let Some(msg) = channel.wait().await {
match msg {
russh::ChannelMsg::Data { ref data } => {
match sender.try_send(CommandOutput::StdOut(data.clone())) {
Ok(_) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(output)) => {
tracing::trace!("Channel full, applying backpressure for stdout");
if sender.send(output).await.is_err() {
tracing::debug!("Receiver dropped, stopping stdout processing");
break;
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Channel closed, stopping stdout processing");
break;
}
}
}
russh::ChannelMsg::ExtendedData { ref data, ext } => {
if ext == 1 {
match sender.try_send(CommandOutput::StdErr(data.clone())) {
Ok(_) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(output)) => {
tracing::trace!("Channel full, applying backpressure for stderr");
if sender.send(output).await.is_err() {
tracing::debug!("Receiver dropped, stopping stderr processing");
break;
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Channel closed, stopping stderr processing");
break;
}
}
}
}
russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
_ => {}
}
}
drop(sender);
if let Some(result) = result {
Ok(result)
} else {
Err(super::Error::CommandDidntExit)
}
}
pub async fn execute_with_sudo(
&self,
command: &str,
sender: Sender<CommandOutput>,
sudo_password: &SudoPassword,
) -> Result<u32, super::Error> {
let sanitized_command = crate::utils::sanitize_command(command)
.map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?;
let mut channel = self.connection_handle.channel_open_session().await?;
channel
.request_pty(
true, "xterm", 80, 24, 0, 0, &[], )
.await?;
channel.exec(true, sanitized_command.as_str()).await?;
let mut result: Option<u32> = None;
let mut password_send_count: u32 = 0;
let mut accumulated_output = String::new();
while let Some(msg) = channel.wait().await {
match msg {
russh::ChannelMsg::Data { ref data } => {
let text = String::from_utf8_lossy(data);
accumulated_output.push_str(&text);
if accumulated_output.len() > MAX_SUDO_PROMPT_BUFFER_SIZE {
let truncate_at = accumulated_output.len() - MAX_SUDO_PROMPT_BUFFER_SIZE;
accumulated_output = accumulated_output[truncate_at..].to_string();
tracing::debug!(
"Sudo prompt buffer exceeded limit, truncated to {} bytes",
MAX_SUDO_PROMPT_BUFFER_SIZE
);
}
match sender.try_send(CommandOutput::StdOut(data.clone())) {
Ok(_) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(output)) => {
tracing::trace!("Channel full, applying backpressure for stdout");
if sender.send(output).await.is_err() {
tracing::debug!("Receiver dropped, stopping stdout processing");
break;
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Channel closed, stopping stdout processing");
break;
}
}
if password_send_count < MAX_SUDO_PASSWORD_SENDS
&& contains_sudo_prompt(&accumulated_output)
{
password_send_count += 1;
tracing::debug!(
"Sudo prompt detected, sending password (attempt {}/{})",
password_send_count,
MAX_SUDO_PASSWORD_SENDS
);
let password_data = sudo_password.with_newline();
if let Err(e) = channel.data(&password_data[..]).await {
tracing::error!("Failed to send sudo password: {}", e);
return Err(super::Error::SshError(e));
}
accumulated_output.clear();
}
if password_send_count > 0 && contains_sudo_failure(&accumulated_output) {
tracing::debug!(
"Sudo authentication failed after {} attempt(s), closing channel",
password_send_count
);
let error_msg = format!(
"\n[bssh] Sudo authentication failed after {} attempt(s). \
Please verify your sudo password is correct.\n",
password_send_count
);
let _ = sender
.send(CommandOutput::StdErr(Bytes::from(error_msg.into_bytes())))
.await;
let _ = sender.send(CommandOutput::ExitCode(1)).await;
let _ = channel.eof().await;
let _ = channel.close().await;
drop(sender);
return Ok(1);
}
}
russh::ChannelMsg::ExtendedData { ref data, ext } => {
if ext == 1 {
let text = String::from_utf8_lossy(data);
accumulated_output.push_str(&text);
if accumulated_output.len() > MAX_SUDO_PROMPT_BUFFER_SIZE {
let truncate_at =
accumulated_output.len() - MAX_SUDO_PROMPT_BUFFER_SIZE;
accumulated_output = accumulated_output[truncate_at..].to_string();
tracing::debug!(
"Sudo prompt buffer exceeded limit (stderr), truncated to {} bytes",
MAX_SUDO_PROMPT_BUFFER_SIZE
);
}
match sender.try_send(CommandOutput::StdErr(data.clone())) {
Ok(_) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(output)) => {
tracing::trace!("Channel full, applying backpressure for stderr");
if sender.send(output).await.is_err() {
tracing::debug!("Receiver dropped, stopping stderr processing");
break;
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
tracing::debug!("Channel closed, stopping stderr processing");
break;
}
}
if password_send_count < MAX_SUDO_PASSWORD_SENDS
&& contains_sudo_prompt(&accumulated_output)
{
password_send_count += 1;
tracing::debug!(
"Sudo prompt detected on stderr, sending password (attempt {}/{})",
password_send_count,
MAX_SUDO_PASSWORD_SENDS
);
let password_data = sudo_password.with_newline();
if let Err(e) = channel.data(&password_data[..]).await {
tracing::error!("Failed to send sudo password: {}", e);
return Err(super::Error::SshError(e));
}
accumulated_output.clear();
}
if password_send_count > 0 && contains_sudo_failure(&accumulated_output) {
tracing::debug!(
"Sudo authentication failed on stderr after {} attempt(s), closing channel",
password_send_count
);
let error_msg = format!(
"\n[bssh] Sudo authentication failed after {} attempt(s). \
Please verify your sudo password is correct.\n",
password_send_count
);
let _ = sender
.send(CommandOutput::StdErr(Bytes::from(error_msg.into_bytes())))
.await;
let _ = sender.send(CommandOutput::ExitCode(1)).await;
let _ = channel.eof().await;
let _ = channel.close().await;
drop(sender);
return Ok(1);
}
}
}
russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
_ => {}
}
}
drop(sender);
if let Some(result) = result {
Ok(result)
} else {
Err(super::Error::CommandDidntExit)
}
}
pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, super::Error> {
let output_buffer = CommandOutputBuffer::new();
let sender = output_buffer.sender.clone();
let exit_status = self.execute_streaming(command, sender).await?;
drop(output_buffer.sender);
let (stdout_bytes, stderr_bytes) = output_buffer.receiver_task.await.map_err(|e| {
super::Error::JoinError(e)
})?;
Ok(CommandExecutedResult {
stdout: String::from_utf8_lossy(&stdout_bytes).to_string(),
stderr: String::from_utf8_lossy(&stderr_bytes).to_string(),
exit_status,
})
}
pub async fn request_interactive_shell(
&self,
_term_type: &str,
_width: u32,
_height: u32,
) -> Result<Channel<Msg>, super::Error> {
let channel = self.connection_handle.channel_open_session().await?;
Ok(channel)
}
pub async fn resize_pty(
&self,
channel: &mut Channel<Msg>,
width: u32,
height: u32,
) -> Result<(), super::Error> {
channel
.window_change(width, height, 0, 0)
.await
.map_err(super::Error::SshError)
}
}