use russh::client::Msg;
use russh::Channel;
use std::io;
use std::net::SocketAddr;
use tokio::io::AsyncWriteExt;
use super::connection::Client;
use super::ToSocketAddrsWithHostname;
/// Starting capacity, in bytes, of the buffer `Client::execute` uses to collect a
/// command's stdout. The buffer still grows past this if the output is larger.
const SSH_CMD_BUFFER_SIZE: usize = 8 * 1024;
/// Starting capacity, in bytes, of the buffer `Client::execute` uses to collect a
/// command's stderr. The buffer still grows past this if the output is larger.
const SSH_RESPONSE_BUFFER_SIZE: usize = 1024;
/// Captured outcome of a remote command run via `Client::execute`.
///
/// When produced by `Client::execute`, both output streams have been decoded
/// lossily as UTF-8 (invalid byte sequences become `U+FFFD`).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CommandExecutedResult {
    /// Data the command wrote to standard output.
    pub stdout: String,
    /// Data the command wrote to standard error.
    pub stderr: String,
    /// Exit status reported by the server for the command.
    pub exit_status: u32,
}
impl Client {
pub async fn get_channel(&self) -> Result<Channel<Msg>, super::Error> {
self.connection_handle
.channel_open_session()
.await
.map_err(super::Error::SshError)
}
pub async fn open_direct_tcpip_channel<
T: ToSocketAddrsWithHostname,
S: Into<Option<SocketAddr>>,
>(
&self,
target: T,
src: S,
) -> Result<Channel<Msg>, super::Error> {
let targets = target
.to_socket_addrs()
.map_err(super::Error::AddressInvalid)?;
let src = src
.into()
.map(|src| (src.ip().to_string(), src.port().into()))
.unwrap_or_else(|| ("127.0.0.1".to_string(), 22));
let mut connect_err = super::Error::AddressInvalid(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any addresses",
));
for target in targets {
match self
.connection_handle
.channel_open_direct_tcpip(
target.ip().to_string(),
target.port().into(),
src.0.clone(),
src.1,
)
.await
{
Ok(channel) => return Ok(channel),
Err(err) => connect_err = super::Error::SshError(err),
}
}
Err(connect_err)
}
pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, super::Error> {
let sanitized_command = crate::utils::sanitize_command(command)
.map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?;
let mut stdout_buffer = Vec::with_capacity(SSH_CMD_BUFFER_SIZE);
let mut stderr_buffer = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE);
let mut channel = self.connection_handle.channel_open_session().await?;
channel.exec(true, sanitized_command.as_str()).await?;
let mut result: Option<u32> = None;
while let Some(msg) = channel.wait().await {
match msg {
russh::ChannelMsg::Data { ref data } => {
stdout_buffer.write_all(data).await.unwrap()
}
russh::ChannelMsg::ExtendedData { ref data, ext } => {
if ext == 1 {
stderr_buffer.write_all(data).await.unwrap()
}
}
russh::ChannelMsg::ExitStatus { exit_status } => result = Some(exit_status),
_ => {}
}
}
if let Some(result) = result {
Ok(CommandExecutedResult {
stdout: String::from_utf8_lossy(&stdout_buffer).to_string(),
stderr: String::from_utf8_lossy(&stderr_buffer).to_string(),
exit_status: result,
})
} else {
Err(super::Error::CommandDidntExit)
}
}
pub async fn request_interactive_shell(
&self,
_term_type: &str,
_width: u32,
_height: u32,
) -> Result<Channel<Msg>, super::Error> {
let channel = self.connection_handle.channel_open_session().await?;
Ok(channel)
}
pub async fn resize_pty(
&self,
channel: &mut Channel<Msg>,
width: u32,
height: u32,
) -> Result<(), super::Error> {
channel
.window_change(width, height, 0, 0)
.await
.map_err(super::Error::SshError)
}
}