#[cfg(feature = "ssh")]
use std::collections::VecDeque;
#[cfg(feature = "ssh")]
use std::io;
#[cfg(feature = "ssh")]
use std::pin::Pin;
#[cfg(feature = "ssh")]
use std::task::{Context, Poll};
use std::time::Duration;
#[cfg(feature = "ssh")]
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
#[cfg(feature = "ssh")]
use crate::error::SshError;
use crate::types::Dimensions;
/// Kind of SSH channel requested through a channel configuration.
///
/// The variant names appear to follow the channel types of the SSH connection
/// protocol (`session`, `direct-tcpip`, `forwarded-tcpip`, `x11`) — confirm
/// against the code that actually opens channels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelType {
    /// Session channel.
    Session,
    /// Direct TCP/IP channel.
    DirectTcpIp,
    /// Forwarded TCP/IP channel.
    ForwardedTcpIp,
    /// X11 channel.
    X11,
}
/// A request that can be made on a channel.
///
/// Each variant carries the data needed for that request; nothing in this
/// definition performs any I/O by itself.
#[derive(Debug, Clone)]
pub enum ChannelRequest {
    /// Ask for a pseudo-terminal.
    Pty {
        /// Terminal type name (for example the value of `TERM`).
        term: String,
        /// Terminal size.
        dimensions: Dimensions,
        /// Raw terminal modes — presumably already in their wire encoding;
        /// TODO confirm with the code that consumes this request.
        modes: Vec<u8>,
    },
    /// Start a shell.
    Shell,
    /// Execute a command.
    Exec {
        /// Command line to execute.
        command: String,
    },
    /// Start a named subsystem.
    Subsystem {
        /// Subsystem name.
        name: String,
    },
    /// Report a change of the terminal size.
    WindowChange {
        /// New terminal size.
        dimensions: Dimensions,
    },
    /// Deliver a signal.
    Signal {
        /// Signal name.
        signal: String,
    },
    /// Set an environment variable.
    Env {
        /// Variable name.
        name: String,
        /// Variable value.
        value: String,
    },
}
/// Lifecycle state tracked for a channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelState {
    /// The channel has been created but is not open yet.
    Opening,
    /// The channel is open.
    Open,
    /// End-of-file has been signalled on the channel.
    Eof,
    /// The channel is closed.
    Closed,
}
/// Settings describing how a channel should be set up.
///
/// Build one with [`ChannelConfig::new`] / [`Default::default`] and adjust it
/// with the chainable setters.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// Kind of channel to open.
    pub channel_type: ChannelType,
    /// Timeout applied to reads.
    pub read_timeout: Duration,
    /// Timeout applied to writes.
    pub write_timeout: Duration,
    /// Buffer size in bytes.
    pub buffer_size: usize,
    /// Whether a pseudo-terminal should be requested.
    pub pty: bool,
    /// Terminal type name used for the PTY request.
    pub term: String,
    /// Terminal size used for the PTY request.
    pub dimensions: Dimensions,
}
impl Default for ChannelConfig {
    /// Returns the default configuration: a session channel with a PTY,
    /// `xterm-256color` as terminal type, 30-second read and write timeouts,
    /// a 32 KiB buffer and the default [`Dimensions`].
    fn default() -> Self {
        let timeout = Duration::from_secs(30);
        Self {
            channel_type: ChannelType::Session,
            read_timeout: timeout,
            write_timeout: timeout,
            buffer_size: 32768,
            pty: true,
            term: String::from("xterm-256color"),
            dimensions: Dimensions::default(),
        }
    }
}
impl ChannelConfig {
    /// Creates a configuration with the default settings.
    ///
    /// Equivalent to [`ChannelConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the terminal type name and returns the updated configuration.
    #[must_use]
    pub fn term(self, term: impl Into<String>) -> Self {
        Self {
            term: term.into(),
            ..self
        }
    }

    /// Sets the terminal size (columns, then rows) and returns the updated
    /// configuration.
    #[must_use]
    pub const fn dimensions(mut self, cols: u16, rows: u16) -> Self {
        let size = Dimensions { cols, rows };
        self.dimensions = size;
        self
    }

    /// Disables the pseudo-terminal and returns the updated configuration.
    #[must_use]
    pub const fn no_pty(mut self) -> Self {
        self.pty = false;
        self
    }

    /// Sets the buffer size in bytes and returns the updated configuration.
    #[must_use]
    pub const fn buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size;
        self
    }
}
/// A russh client channel wrapped together with its configuration, a tracked
/// lifecycle state and a buffer of received bytes that have not been read yet.
#[cfg(feature = "ssh")]
pub struct SshChannelStream {
    /// Underlying russh channel.
    channel: russh::Channel<russh::client::Msg>,
    /// Configuration this stream was created with.
    config: ChannelConfig,
    /// Last known lifecycle state.
    state: ChannelState,
    /// Received bytes not yet handed to a reader.
    read_buffer: VecDeque<u8>,
    /// Exit status reported on the channel, once one has been seen.
    exit_status: Option<u32>,
    /// Whether an EOF message has been received.
    eof_received: bool,
}
#[cfg(feature = "ssh")]
impl std::fmt::Debug for SshChannelStream {
    /// Formats the stream for diagnostics.
    ///
    /// The underlying channel is left out, and the read buffer is summarised
    /// by its length rather than dumped byte by byte.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffered = self.read_buffer.len();
        let mut out = f.debug_struct("SshChannelStream");
        out.field("config", &self.config);
        out.field("state", &self.state);
        out.field("read_buffer_len", &buffered);
        out.field("exit_status", &self.exit_status);
        out.field("eof_received", &self.eof_received);
        out.finish()
    }
}
#[cfg(feature = "ssh")]
impl SshChannelStream {
    /// Wraps an existing russh channel.
    ///
    /// The stream starts in [`ChannelState::Open`] with an empty read buffer,
    /// no exit status and no EOF recorded.
    #[must_use]
    pub fn new(channel: russh::Channel<russh::client::Msg>, config: ChannelConfig) -> Self {
        Self {
            channel,
            config,
            state: ChannelState::Open,
            // NOTE(review): the initial capacity is fixed and does not consult
            // `config.buffer_size` — confirm whether that is intended.
            read_buffer: VecDeque::with_capacity(32768),
            exit_status: None,
            eof_received: false,
        }
    }

    /// Builds the error returned by the request wrappers below; the reason
    /// reads `"<action> failed: <err>"`.
    fn channel_error(action: &str, err: impl std::fmt::Display) -> crate::error::ExpectError {
        crate::error::ExpectError::Ssh(SshError::Channel {
            reason: format!("{action} failed: {err}"),
        })
    }

    /// Returns the configuration this stream was created with.
    #[must_use]
    pub const fn config(&self) -> &ChannelConfig {
        &self.config
    }

    /// Returns the last known lifecycle state.
    #[must_use]
    pub const fn state(&self) -> ChannelState {
        self.state
    }

    /// Returns `true` only while the state is exactly [`ChannelState::Open`].
    #[must_use]
    pub fn is_open(&self) -> bool {
        self.state == ChannelState::Open
    }

    /// Returns `true` once an EOF message has been received.
    #[must_use]
    pub const fn is_eof(&self) -> bool {
        self.eof_received
    }

    /// Returns the exit status reported on the channel, if one has been seen.
    #[must_use]
    pub const fn exit_status(&self) -> Option<u32> {
        self.exit_status
    }

    /// Requests a pseudo-terminal using the configured terminal type and size.
    ///
    /// # Errors
    ///
    /// Returns an SSH channel error if russh rejects the request.
    pub async fn request_pty(&mut self) -> crate::error::Result<()> {
        // NOTE(review): the leading `false` is presumably russh's `want_reply`
        // flag, the two zeros the pixel dimensions and `&[]` the terminal
        // modes — confirm against the russh version in use.
        self.channel
            .request_pty(
                false,
                &self.config.term,
                self.config.dimensions.cols.into(),
                self.config.dimensions.rows.into(),
                0,
                0,
                &[],
            )
            .await
            .map_err(|e| Self::channel_error("PTY request", e))
    }

    /// Requests a shell on the channel.
    ///
    /// # Errors
    ///
    /// Returns an SSH channel error if russh rejects the request.
    pub async fn request_shell(&mut self) -> crate::error::Result<()> {
        self.channel
            .request_shell(false)
            .await
            .map_err(|e| Self::channel_error("Shell request", e))
    }

    /// Requests execution of `command` on the channel.
    ///
    /// # Errors
    ///
    /// Returns an SSH channel error if russh rejects the request.
    pub async fn exec(&mut self, command: &str) -> crate::error::Result<()> {
        self.channel
            .exec(false, command)
            .await
            .map_err(|e| Self::channel_error("Exec request", e))
    }

    /// Sends a window-change request with the given size in columns and rows.
    ///
    /// The two remaining arguments are passed as `0` (presumably the pixel
    /// dimensions — confirm against russh).
    ///
    /// # Errors
    ///
    /// Returns an SSH channel error if russh rejects the request.
    pub async fn window_change(&mut self, cols: u16, rows: u16) -> crate::error::Result<()> {
        self.channel
            .window_change(cols.into(), rows.into(), 0, 0)
            .await
            .map_err(|e| Self::channel_error("Window change", e))
    }

    /// Sends `data` on the channel.
    ///
    /// # Errors
    ///
    /// Returns an SSH channel error if the send fails.
    pub async fn send_data(&mut self, data: &[u8]) -> crate::error::Result<()> {
        self.channel
            .data(data)
            .await
            .map_err(|e| Self::channel_error("Data send", e))
    }

    /// Sends EOF on the channel. The tracked state is not changed.
    ///
    /// # Errors
    ///
    /// Returns an SSH channel error if the send fails.
    pub async fn send_eof(&mut self) -> crate::error::Result<()> {
        self.channel
            .eof()
            .await
            .map_err(|e| Self::channel_error("EOF send", e))
    }

    /// Closes the channel.
    ///
    /// The state becomes [`ChannelState::Closed`] before the close is sent,
    /// so it is `Closed` even when an error is returned.
    ///
    /// # Errors
    ///
    /// Returns an SSH channel error if the close fails.
    pub async fn close(&mut self) -> crate::error::Result<()> {
        self.state = ChannelState::Closed;
        self.channel
            .close()
            .await
            .map_err(|e| Self::channel_error("Channel close", e))
    }

    /// Waits for the next channel message, records its effect on this stream
    /// and hands the message back to the caller.
    ///
    /// * `Data`, and `ExtendedData` with `ext == 1`, are appended to the read
    ///   buffer; other extended data is not buffered.
    /// * `ExitStatus` is stored and becomes visible through
    ///   [`exit_status`](Self::exit_status).
    /// * `Eof` sets the EOF flag and moves the state to [`ChannelState::Eof`].
    /// * `Close` moves the state to [`ChannelState::Closed`].
    ///
    /// Returns `None` when the underlying channel yields no further message.
    /// In that case the state is set to [`ChannelState::Closed`], matching
    /// what `poll_read` does for the same condition.
    pub async fn wait(&mut self) -> Option<russh::ChannelMsg> {
        let Some(msg) = self.channel.wait().await else {
            // Without this the stream would keep reporting itself as open
            // after the channel has ended.
            self.state = ChannelState::Closed;
            return None;
        };
        match &msg {
            russh::ChannelMsg::Data { data } => {
                self.read_buffer.extend(data.as_ref());
            }
            russh::ChannelMsg::ExtendedData { data, ext } => {
                if *ext == 1 {
                    self.read_buffer.extend(data.as_ref());
                }
            }
            russh::ChannelMsg::ExitStatus { exit_status } => {
                self.exit_status = Some(*exit_status);
            }
            russh::ChannelMsg::Eof => {
                self.eof_received = true;
                self.state = ChannelState::Eof;
            }
            russh::ChannelMsg::Close => {
                self.state = ChannelState::Closed;
            }
            _ => {}
        }
        Some(msg)
    }

    /// Moves up to `buf.len()` bytes from the front of the read buffer into
    /// `buf` and returns how many bytes were copied.
    ///
    /// Never waits: only bytes that are already buffered are returned.
    pub fn read_buffered(&mut self, buf: &mut [u8]) -> usize {
        let len = std::cmp::min(buf.len(), self.read_buffer.len());
        for (slot, byte) in buf.iter_mut().zip(self.read_buffer.drain(..len)) {
            *slot = byte;
        }
        len
    }

    /// Returns `true` if the read buffer holds at least one byte.
    #[must_use]
    pub fn has_data(&self) -> bool {
        !self.read_buffer.is_empty()
    }

    /// Returns the number of bytes currently held in the read buffer.
    #[must_use]
    pub fn available(&self) -> usize {
        self.read_buffer.len()
    }
}
#[cfg(feature = "ssh")]
impl AsyncRead for SshChannelStream {
    /// Reads channel data into `buf`.
    ///
    /// Bytes already held in the read buffer are served first. Otherwise the
    /// channel is polled for messages until one of the following happens:
    ///
    /// * a non-empty `Data` message, or `ExtendedData` with `ext == 1`,
    ///   arrives: as much as fits is copied into `buf` and the remainder is
    ///   kept in the read buffer;
    /// * `Eof` or `Close` arrives, or the channel yields no further message:
    ///   the state is updated and the call returns with nothing written,
    ///   which readers interpret as end-of-stream;
    /// * the channel has nothing ready: `Poll::Pending` is returned.
    ///
    /// Messages that carry no readable bytes (exit status, extended data on
    /// other streams, empty data packets, anything else) are consumed and
    /// skipped. Returning `Ready` for them without filling `buf` would be
    /// indistinguishable from end-of-stream and make readers stop early.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();

        // Serve previously buffered bytes before touching the channel.
        if !this.read_buffer.is_empty() {
            let len = std::cmp::min(buf.remaining(), this.read_buffer.len());
            // The deque may wrap around, so its contents come as two slices.
            let (front, back) = this.read_buffer.as_slices();
            let from_front = std::cmp::min(len, front.len());
            buf.put_slice(&front[..from_front]);
            buf.put_slice(&back[..len - from_front]);
            drop(this.read_buffer.drain(..len));
            return Poll::Ready(Ok(()));
        }

        // Nothing buffered and the remote side is done: report end-of-stream.
        if this.eof_received || this.state == ChannelState::Closed {
            return Poll::Ready(Ok(()));
        }

        loop {
            let wait_future = this.channel.wait();
            tokio::pin!(wait_future);
            let msg = match wait_future.poll(cx) {
                Poll::Ready(Some(msg)) => msg,
                Poll::Ready(None) => {
                    this.state = ChannelState::Closed;
                    return Poll::Ready(Ok(()));
                }
                Poll::Pending => return Poll::Pending,
            };

            match msg {
                russh::ChannelMsg::Data { data }
                | russh::ChannelMsg::ExtendedData { data, ext: 1 } => {
                    if data.is_empty() {
                        // An empty packet is not end-of-stream; keep going.
                        continue;
                    }
                    let len = std::cmp::min(buf.remaining(), data.len());
                    buf.put_slice(&data[..len]);
                    // Keep whatever did not fit for the next read.
                    this.read_buffer.extend(&data[len..]);
                    return Poll::Ready(Ok(()));
                }
                russh::ChannelMsg::Eof => {
                    this.eof_received = true;
                    this.state = ChannelState::Eof;
                    return Poll::Ready(Ok(()));
                }
                russh::ChannelMsg::Close => {
                    this.state = ChannelState::Closed;
                    return Poll::Ready(Ok(()));
                }
                russh::ChannelMsg::ExitStatus { exit_status } => {
                    this.exit_status = Some(exit_status);
                }
                // Extended data on other streams and every other message
                // carry nothing for the reader.
                _ => {}
            }
        }
    }
}
#[cfg(feature = "ssh")]
impl AsyncWrite for SshChannelStream {
    /// Sends `buf` on the channel and reports the whole buffer as written.
    ///
    /// Fails with `BrokenPipe` when the tracked state is
    /// [`ChannelState::Closed`].
    ///
    /// NOTE(review): a fresh `data` future is created on every poll and is
    /// dropped when it returns `Pending` — confirm that russh's
    /// `Channel::data` tolerates being dropped and restarted like this.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = self.get_mut();
        if matches!(stream.state, ChannelState::Closed) {
            let closed = io::Error::new(io::ErrorKind::BrokenPipe, "Channel is closed");
            return Poll::Ready(Err(closed));
        }

        let send = stream.channel.data(buf);
        tokio::pin!(send);
        send.poll(cx).map(|outcome| match outcome {
            Ok(()) => Ok(buf.len()),
            Err(e) => Err(io::Error::other(format!("SSH write error: {e}"))),
        })
    }

    /// Always ready: this wrapper keeps no outgoing data of its own to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Sends EOF on the channel; on success the tracked state becomes
    /// [`ChannelState::Eof`].
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = self.get_mut();
        let eof = stream.channel.eof();
        tokio::pin!(eof);
        let outcome = match eof.poll(cx) {
            Poll::Ready(outcome) => outcome,
            Poll::Pending => return Poll::Pending,
        };
        Poll::Ready(match outcome {
            Ok(()) => {
                stream.state = ChannelState::Eof;
                Ok(())
            }
            Err(e) => Err(io::Error::other(format!("SSH shutdown error: {e}"))),
        })
    }
}
/// Stand-in channel used when the `ssh` feature is disabled.
///
/// It only tracks configuration, lifecycle state and an exit status; no
/// network connection is involved.
#[cfg(not(feature = "ssh"))]
#[derive(Debug)]
pub struct SshChannel {
    /// Configuration the channel was created with.
    config: ChannelConfig,
    /// Current lifecycle state.
    state: ChannelState,
    /// Exit status, once one has been set.
    exit_status: Option<i32>,
}
#[cfg(not(feature = "ssh"))]
impl SshChannel {
    /// Creates a channel in the [`ChannelState::Opening`] state with no exit
    /// status.
    #[must_use]
    pub const fn new(config: ChannelConfig) -> Self {
        let state = ChannelState::Opening;
        Self {
            config,
            state,
            exit_status: None,
        }
    }

    /// Returns the configuration the channel was created with.
    #[must_use]
    pub const fn config(&self) -> &ChannelConfig {
        &self.config
    }

    /// Returns the current lifecycle state.
    #[must_use]
    pub const fn state(&self) -> ChannelState {
        self.state
    }

    /// Returns `true` while the channel is in [`ChannelState::Open`].
    #[must_use]
    pub fn is_open(&self) -> bool {
        matches!(self.state, ChannelState::Open)
    }

    /// Returns `true` once the channel is in [`ChannelState::Closed`].
    #[must_use]
    pub fn is_closed(&self) -> bool {
        matches!(self.state, ChannelState::Closed)
    }

    /// Returns the exit status, if one has been set.
    #[must_use]
    pub const fn exit_status(&self) -> Option<i32> {
        self.exit_status
    }

    /// Marks the channel as open.
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok(())`.
    pub fn open(&mut self) -> crate::error::Result<()> {
        self.state = ChannelState::Open;
        Ok(())
    }

    /// Marks the channel as closed.
    pub fn close(&mut self) {
        self.state = ChannelState::Closed;
    }

    /// Records `status` as the channel's exit status.
    pub fn set_exit_status(&mut self, status: i32) {
        self.exit_status = Some(status);
    }
}
/// With the `ssh` feature enabled, `SshChannel` is an alias for
/// [`SshChannelStream`]; without it, a separate `SshChannel` struct is
/// compiled instead.
#[cfg(feature = "ssh")]
pub type SshChannel = SshChannelStream;
#[cfg(test)]
mod tests {
    use super::*;

    /// The chainable setters store the values they are given.
    #[test]
    fn channel_config() {
        let built = ChannelConfig::new();
        let built = built.term("vt100");
        let built = built.dimensions(120, 40);
        let built = built.buffer_size(65536);

        assert_eq!(built.term, "vt100");
        assert_eq!(built.dimensions.cols, 120);
        assert_eq!(built.buffer_size, 65536);
    }

    /// The stand-in channel moves from `Opening` to open to closed.
    #[cfg(not(feature = "ssh"))]
    #[test]
    fn channel_state() {
        let mut stub = SshChannel::new(ChannelConfig::default());
        assert_eq!(stub.state(), ChannelState::Opening);

        stub.open().unwrap();
        assert!(stub.is_open());

        stub.close();
        assert!(stub.is_closed());
    }
}