use std::sync::Arc;
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{Mutex, Notify};
use crate::{ChannelId, ChannelOpenFailure, Error, Pty, Sig};
pub mod io;
mod channel_ref;
pub use channel_ref::ChannelRef;
mod channel_stream;
pub use channel_stream::ChannelStream;
/// A message exchanged with a channel.
///
/// The request helpers on [`ChannelWriteHalf`] (and on [`Channel`], which forwards to it)
/// build these values and send them, tagged with the channel id, to the session.
/// Messages delivered to the channel are read back with [`ChannelReadHalf::wait`].
///
/// NOTE(review): the variants appear to mirror SSH connection-protocol (RFC 4254)
/// channel messages and requests. The wire encoding lives outside this file, so confirm
/// there.
#[derive(Debug)]
#[non_exhaustive]
pub enum ChannelMsg {
    /// Channel opened.
    ///
    /// NOTE(review): presumably carries the confirmed channel id and the peer's packet and
    /// window limits. Not constructed in this file.
    Open {
        id: ChannelId,
        max_packet_size: u32,
        window_size: u32,
    },
    /// A chunk of regular channel data.
    Data {
        data: Bytes,
    },
    /// A chunk of extended channel data. `ext` identifies the extended stream.
    ///
    /// NOTE(review): in SSH, code 1 is stderr. Confirm how callers use it.
    ExtendedData {
        data: Bytes,
        ext: u32,
    },
    /// End of data. Sent by [`ChannelWriteHalf::eof`].
    Eof,
    /// Channel close. Sent by [`ChannelWriteHalf::close`].
    Close,
    /// Pseudo-terminal request. Built by [`ChannelWriteHalf::request_pty`].
    RequestPty {
        want_reply: bool,
        term: String,
        col_width: u32,
        row_height: u32,
        pix_width: u32,
        pix_height: u32,
        terminal_modes: Vec<(Pty, u32)>,
    },
    /// Shell request. Built by [`ChannelWriteHalf::request_shell`].
    RequestShell {
        want_reply: bool,
    },
    /// Command execution request. Built by [`ChannelWriteHalf::exec`].
    /// `command` is kept as raw bytes.
    Exec {
        want_reply: bool,
        command: Vec<u8>,
    },
    /// Signal delivery. Built by [`ChannelWriteHalf::signal`].
    Signal {
        signal: Sig,
    },
    /// Subsystem request. Built by [`ChannelWriteHalf::request_subsystem`].
    RequestSubsystem {
        want_reply: bool,
        name: String,
    },
    /// X11 forwarding request. Built by [`ChannelWriteHalf::request_x11`].
    RequestX11 {
        want_reply: bool,
        single_connection: bool,
        x11_authentication_protocol: String,
        x11_authentication_cookie: String,
        x11_screen_number: u32,
    },
    /// Environment variable request. Built by [`ChannelWriteHalf::set_env`].
    SetEnv {
        want_reply: bool,
        variable_name: String,
        variable_value: String,
    },
    /// Terminal size change. Built by [`ChannelWriteHalf::window_change`].
    WindowChange {
        col_width: u32,
        row_height: u32,
        pix_width: u32,
        pix_height: u32,
    },
    /// Agent forwarding request. Built by [`ChannelWriteHalf::agent_forward`].
    AgentForward {
        want_reply: bool,
    },
    /// Flow-control capability notice.
    ///
    /// NOTE(review): not constructed in this file.
    XonXoff {
        client_can_do: bool,
    },
    /// Exit status report. Sent by [`ChannelWriteHalf::exit_status`].
    ExitStatus {
        exit_status: u32,
    },
    /// Report that the command ended because of a signal.
    ///
    /// NOTE(review): not constructed in this file.
    ExitSignal {
        signal_name: Sig,
        core_dumped: bool,
        error_message: String,
        lang_tag: String,
    },
    /// The send window now has size `new_size`.
    ///
    /// NOTE(review): presumably emitted after a peer window adjustment. Not constructed here.
    WindowAdjusted {
        new_size: u32,
    },
    /// Positive reply to a request.
    ///
    /// NOTE(review): presumably only for requests sent with `want_reply = true`.
    Success,
    /// Negative reply to a request.
    ///
    /// NOTE(review): presumably only for requests sent with `want_reply = true`.
    Failure,
    /// Opening the channel failed, with the reason given by [`ChannelOpenFailure`].
    OpenFailure(ChannelOpenFailure),
}
/// Shared, mutable view of a channel's remaining send window.
///
/// Cloning is cheap and all clones share the same `value` and `notifier`.
/// [`Channel::new`] gives one clone to the [`ChannelWriteHalf`] and another to the
/// returned [`ChannelRef`].
#[derive(Clone, Debug)]
pub(crate) struct WindowSizeRef {
    /// Current window size.
    ///
    /// NOTE(review): assumed to be in bytes, since it is compared with `max_packet_size`
    /// in `writable_packet_size`. Confirm.
    value: Arc<Mutex<u32>>,
    /// Signalled after every `update`.
    ///
    /// Handed out by `subscribe`. Presumably awaited by `io::ChannelTx` while the window
    /// is exhausted.
    notifier: Arc<Notify>,
}
impl WindowSizeRef {
    /// Creates a window tracker starting at `initial`, with a fresh notifier.
    pub(crate) fn new(initial: u32) -> Self {
        let notifier = Arc::new(Notify::new());
        Self {
            value: Arc::new(Mutex::new(initial)),
            notifier,
        }
    }
    /// Overwrites the window size with `value`, then signals the notifier.
    ///
    /// NOTE(review): `Notify::notify_one` wakes at most one waiting task. If no task is
    /// waiting, it stores a single permit instead. Several writers created via
    /// `make_writer`/`make_writer_ext` may block on this same notifier, but only one is
    /// woken per update. Confirm that `io::ChannelTx` tolerates this before relying on
    /// concurrent writers.
    pub(crate) async fn update(&self, value: u32) {
        // The lock guard is a temporary that is dropped at the end of this statement.
        // The new value is therefore stored before any waiter is woken.
        *self.value.lock().await = value;
        self.notifier.notify_one();
    }
    /// Returns another handle to the shared notifier, so a writer can wait for window
    /// updates.
    pub(crate) fn subscribe(&self) -> Arc<Notify> {
        Arc::clone(&self.notifier)
    }
}
/// The receiving half of a channel, which yields the messages delivered for it.
///
/// Obtained from [`Channel::split`]. The matching sender is held by the channel's
/// [`ChannelRef`].
pub struct ChannelReadHalf {
    /// Queue of messages delivered to this channel.
    pub(crate) receiver: Receiver<ChannelMsg>,
}

impl std::fmt::Debug for ChannelReadHalf {
    /// Prints only the type name. The receiver is deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ChannelReadHalf");
        out.finish()
    }
}
impl ChannelReadHalf {
pub async fn wait(&mut self) -> Option<ChannelMsg> {
self.receiver.recv().await
}
pub fn make_reader(&mut self) -> impl AsyncRead + '_ {
self.make_reader_ext(None)
}
pub fn make_reader_ext(&mut self, ext: Option<u32>) -> impl AsyncRead + '_ {
io::ChannelRx::new(self, ext)
}
}
/// The sending half of a channel, which issues requests and writes data to the peer.
///
/// Obtained from [`Channel::split`]. `S` is the session's outgoing message type. Every
/// [`ChannelMsg`] is converted into `S` together with this channel's id before it is queued.
// The type parameter was previously called `Send`, which shadowed the `std::marker::Send`
// trait inside this definition. It is named `S` to match the impl blocks.
pub struct ChannelWriteHalf<S: From<(ChannelId, ChannelMsg)>> {
    /// Id of the channel this half writes to.
    pub(crate) id: ChannelId,
    /// Queue of outgoing messages. NOTE(review): presumably drained by the session loop.
    pub(crate) sender: Sender<S>,
    /// Upper bound for a single packet. It is combined with the window in
    /// `writable_packet_size`.
    pub(crate) max_packet_size: u32,
    /// Remaining send window, shared with the channel's [`ChannelRef`].
    pub(crate) window_size: WindowSizeRef,
}
impl<S: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for ChannelWriteHalf<S> {
    /// Shows only the channel id. The sender and the window state are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ChannelWriteHalf");
        out.field("id", &self.id);
        out.finish()
    }
}
/// Requests and data transfer for the sending half of a channel.
///
/// Every request method queues one [`ChannelMsg`]. Each fails with `Error::SendError`
/// when the message cannot be queued because the receiving end of the session queue has
/// been dropped.
impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> ChannelWriteHalf<S> {
    /// Returns how many bytes fit in one packet right now.
    ///
    /// This is the smaller of the maximum packet size and the remaining window.
    pub async fn writable_packet_size(&self) -> usize {
        let window = *self.window_size.value.lock().await;
        std::cmp::min(self.max_packet_size, window) as usize
    }

    /// Returns the id of this channel.
    pub fn id(&self) -> ChannelId {
        let Self { id, .. } = self;
        *id
    }

    /// Requests a pseudo-terminal of type `term` for this channel.
    ///
    /// The four size arguments give its dimensions in characters and in pixels.
    /// `terminal_modes` lists the requested terminal modes.
    // The parameters mirror the fields of `ChannelMsg::RequestPty` one-to-one, hence the
    // lint allowance.
    #[allow(clippy::too_many_arguments)]
    pub async fn request_pty(
        &self,
        want_reply: bool,
        term: &str,
        col_width: u32,
        row_height: u32,
        pix_width: u32,
        pix_height: u32,
        terminal_modes: &[(Pty, u32)],
    ) -> Result<(), Error> {
        let request = ChannelMsg::RequestPty {
            want_reply,
            term: term.to_owned(),
            col_width,
            row_height,
            pix_width,
            pix_height,
            terminal_modes: terminal_modes.to_owned(),
        };
        self.send_msg(request).await
    }

    /// Requests an interactive shell on this channel.
    pub async fn request_shell(&self, want_reply: bool) -> Result<(), Error> {
        let request = ChannelMsg::RequestShell { want_reply };
        self.send_msg(request).await
    }

    /// Requests execution of `command`, which is sent as raw bytes.
    pub async fn exec<A: Into<Vec<u8>>>(&self, want_reply: bool, command: A) -> Result<(), Error> {
        let request = ChannelMsg::Exec {
            want_reply,
            command: command.into(),
        };
        self.send_msg(request).await
    }

    /// Sends `signal` to the process attached to this channel.
    pub async fn signal(&self, signal: Sig) -> Result<(), Error> {
        let request = ChannelMsg::Signal { signal };
        self.send_msg(request).await
    }

    /// Requests the subsystem called `name`.
    pub async fn request_subsystem<A: Into<String>>(
        &self,
        want_reply: bool,
        name: A,
    ) -> Result<(), Error> {
        let request = ChannelMsg::RequestSubsystem {
            want_reply,
            name: name.into(),
        };
        self.send_msg(request).await
    }

    /// Requests X11 forwarding with the given authentication protocol, cookie and
    /// screen number.
    pub async fn request_x11<A: Into<String>, B: Into<String>>(
        &self,
        want_reply: bool,
        single_connection: bool,
        x11_authentication_protocol: A,
        x11_authentication_cookie: B,
        x11_screen_number: u32,
    ) -> Result<(), Error> {
        let request = ChannelMsg::RequestX11 {
            want_reply,
            single_connection,
            x11_authentication_protocol: x11_authentication_protocol.into(),
            x11_authentication_cookie: x11_authentication_cookie.into(),
            x11_screen_number,
        };
        self.send_msg(request).await
    }

    /// Requests that the environment variable `variable_name` be set to `variable_value`.
    pub async fn set_env<A: Into<String>, B: Into<String>>(
        &self,
        want_reply: bool,
        variable_name: A,
        variable_value: B,
    ) -> Result<(), Error> {
        let request = ChannelMsg::SetEnv {
            want_reply,
            variable_name: variable_name.into(),
            variable_value: variable_value.into(),
        };
        self.send_msg(request).await
    }

    /// Reports a new terminal size, in characters and in pixels.
    pub async fn window_change(
        &self,
        col_width: u32,
        row_height: u32,
        pix_width: u32,
        pix_height: u32,
    ) -> Result<(), Error> {
        let request = ChannelMsg::WindowChange {
            col_width,
            row_height,
            pix_width,
            pix_height,
        };
        self.send_msg(request).await
    }

    /// Requests agent forwarding on this channel.
    pub async fn agent_forward(&self, want_reply: bool) -> Result<(), Error> {
        let request = ChannelMsg::AgentForward { want_reply };
        self.send_msg(request).await
    }

    /// Copies everything from `data` into the regular data stream until `data` hits EOF.
    ///
    /// # Errors
    /// Propagates I/O errors from reading `data` or from writing to the channel.
    pub async fn data<R: tokio::io::AsyncRead + Unpin>(&self, data: R) -> Result<(), Error> {
        self.send_data(None, data).await
    }

    /// Copies everything from `data` into the extended data stream `ext` until `data`
    /// hits EOF.
    ///
    /// # Errors
    /// Propagates I/O errors from reading `data` or from writing to the channel.
    pub async fn extended_data<R: tokio::io::AsyncRead + Unpin>(
        &self,
        ext: u32,
        data: R,
    ) -> Result<(), Error> {
        self.send_data(Some(ext), data).await
    }

    /// Shared body of `data` and `extended_data`.
    ///
    /// Streams `data` through a freshly created writer for stream `ext`. The writer is
    /// dropped when this returns.
    async fn send_data<R: tokio::io::AsyncRead + Unpin>(
        &self,
        ext: Option<u32>,
        mut data: R,
    ) -> Result<(), Error> {
        let mut writer = self.make_writer_ext(ext);
        let _copied: u64 = tokio::io::copy(&mut data, &mut writer).await?;
        Ok(())
    }

    /// Signals that no more data will be sent on this channel.
    pub async fn eof(&self) -> Result<(), Error> {
        self.send_msg(ChannelMsg::Eof).await
    }

    /// Reports `exit_status` as the exit status for this channel.
    pub async fn exit_status(&self, exit_status: u32) -> Result<(), Error> {
        let report = ChannelMsg::ExitStatus { exit_status };
        self.send_msg(report).await
    }

    /// Closes this channel.
    pub async fn close(&self) -> Result<(), Error> {
        self.send_msg(ChannelMsg::Close).await
    }

    /// Tags `msg` with this channel's id, converts it into `S`, and queues it.
    ///
    /// # Errors
    /// Returns `Error::SendError` if the receiving side of the queue is gone.
    async fn send_msg(&self, msg: ChannelMsg) -> Result<(), Error> {
        let outgoing = S::from((self.id, msg));
        match self.sender.send(outgoing).await {
            Ok(()) => Ok(()),
            Err(_) => Err(Error::SendError),
        }
    }

    /// Returns an [`AsyncWrite`] into this channel's regular data stream.
    ///
    /// This is shorthand for `make_writer_ext(None)`.
    pub fn make_writer(&self) -> impl AsyncWrite + 'static {
        self.make_writer_ext(None)
    }

    /// Returns an [`AsyncWrite`] into the data stream selected by `ext`.
    ///
    /// `None` selects regular data and `Some(code)` selects an extended stream. The writer
    /// owns clones of the sender and of the shared window state, so it does not borrow
    /// `self`.
    pub fn make_writer_ext(&self, ext: Option<u32>) -> impl AsyncWrite + 'static {
        let window = &self.window_size;
        io::ChannelTx::new(
            self.sender.clone(),
            self.id,
            Arc::clone(&window.value),
            window.subscribe(),
            self.max_packet_size,
            ext,
        )
    }
}
/// A bidirectional channel handle made of a [`ChannelReadHalf`] and a [`ChannelWriteHalf`].
///
/// Use [`Channel::split`] to work with the halves independently. Use
/// [`Channel::into_stream`] to get a combined read/write stream.
// The type parameter was previously called `Send`, which shadowed the `std::marker::Send`
// trait inside this definition. It is named `S` to match the impl blocks.
pub struct Channel<S: From<(ChannelId, ChannelMsg)>> {
    /// Incoming side: messages delivered to this channel.
    pub(crate) read_half: ChannelReadHalf,
    /// Outgoing side: requests and data sent from this channel.
    pub(crate) write_half: ChannelWriteHalf<S>,
}
impl<S: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<S> {
    /// Shows only the channel id, taken from the write half.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Channel");
        out.field("id", &self.write_half.id);
        out.finish()
    }
}
/// Construction, splitting, and convenience forwarding for [`Channel`].
///
/// Request and data methods forward unchanged to the inner [`ChannelWriteHalf`]. Reading
/// forwards to the inner [`ChannelReadHalf`].
impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
    /// Builds a channel plus the [`ChannelRef`] used to feed it.
    ///
    /// The channel's write half sends through `sender`. Incoming messages travel through
    /// a new bounded queue of capacity `channel_buffer_size`, whose sender goes into the
    /// returned [`ChannelRef`]. Both sides share one window tracker initialised to
    /// `window_size`.
    ///
    /// # Panics
    /// Panics if `channel_buffer_size` is 0, because `tokio::sync::mpsc::channel` rejects
    /// a zero capacity.
    pub(crate) fn new(
        id: ChannelId,
        sender: Sender<S>,
        max_packet_size: u32,
        window_size: u32,
        channel_buffer_size: usize,
    ) -> (Self, ChannelRef) {
        let (incoming_tx, incoming_rx) = tokio::sync::mpsc::channel(channel_buffer_size);
        let shared_window = WindowSizeRef::new(window_size);

        let channel = Self {
            read_half: ChannelReadHalf {
                receiver: incoming_rx,
            },
            write_half: ChannelWriteHalf {
                id,
                sender,
                max_packet_size,
                window_size: shared_window.clone(),
            },
        };
        let channel_ref = ChannelRef {
            sender: incoming_tx,
            window_size: shared_window,
        };
        (channel, channel_ref)
    }

    /// Returns how many bytes fit in one packet right now.
    ///
    /// See [`ChannelWriteHalf::writable_packet_size`].
    pub async fn writable_packet_size(&self) -> usize {
        self.write_half.writable_packet_size().await
    }

    /// Returns the id of this channel.
    pub fn id(&self) -> ChannelId {
        self.write_half.id()
    }

    /// Consumes the channel and returns its read and write halves.
    pub fn split(self) -> (ChannelReadHalf, ChannelWriteHalf<S>) {
        let Self {
            read_half,
            write_half,
        } = self;
        (read_half, write_half)
    }

    /// Requests a pseudo-terminal. See [`ChannelWriteHalf::request_pty`].
    // The parameters mirror `ChannelWriteHalf::request_pty` one-to-one, hence the lint
    // allowance.
    #[allow(clippy::too_many_arguments)]
    pub async fn request_pty(
        &self,
        want_reply: bool,
        term: &str,
        col_width: u32,
        row_height: u32,
        pix_width: u32,
        pix_height: u32,
        terminal_modes: &[(Pty, u32)],
    ) -> Result<(), Error> {
        let half = &self.write_half;
        half.request_pty(
            want_reply,
            term,
            col_width,
            row_height,
            pix_width,
            pix_height,
            terminal_modes,
        )
        .await
    }

    /// Requests an interactive shell. See [`ChannelWriteHalf::request_shell`].
    pub async fn request_shell(&self, want_reply: bool) -> Result<(), Error> {
        self.write_half.request_shell(want_reply).await
    }

    /// Requests execution of `command`. See [`ChannelWriteHalf::exec`].
    pub async fn exec<A: Into<Vec<u8>>>(&self, want_reply: bool, command: A) -> Result<(), Error> {
        self.write_half.exec(want_reply, command).await
    }

    /// Sends `signal`. See [`ChannelWriteHalf::signal`].
    pub async fn signal(&self, signal: Sig) -> Result<(), Error> {
        self.write_half.signal(signal).await
    }

    /// Requests the subsystem `name`. See [`ChannelWriteHalf::request_subsystem`].
    pub async fn request_subsystem<A: Into<String>>(
        &self,
        want_reply: bool,
        name: A,
    ) -> Result<(), Error> {
        self.write_half.request_subsystem(want_reply, name).await
    }

    /// Requests X11 forwarding. See [`ChannelWriteHalf::request_x11`].
    pub async fn request_x11<A: Into<String>, B: Into<String>>(
        &self,
        want_reply: bool,
        single_connection: bool,
        x11_authentication_protocol: A,
        x11_authentication_cookie: B,
        x11_screen_number: u32,
    ) -> Result<(), Error> {
        let half = &self.write_half;
        half.request_x11(
            want_reply,
            single_connection,
            x11_authentication_protocol,
            x11_authentication_cookie,
            x11_screen_number,
        )
        .await
    }

    /// Requests setting an environment variable. See [`ChannelWriteHalf::set_env`].
    pub async fn set_env<A: Into<String>, B: Into<String>>(
        &self,
        want_reply: bool,
        variable_name: A,
        variable_value: B,
    ) -> Result<(), Error> {
        let half = &self.write_half;
        half.set_env(want_reply, variable_name, variable_value).await
    }

    /// Reports a new terminal size. See [`ChannelWriteHalf::window_change`].
    pub async fn window_change(
        &self,
        col_width: u32,
        row_height: u32,
        pix_width: u32,
        pix_height: u32,
    ) -> Result<(), Error> {
        let half = &self.write_half;
        half.window_change(col_width, row_height, pix_width, pix_height)
            .await
    }

    /// Requests agent forwarding. See [`ChannelWriteHalf::agent_forward`].
    pub async fn agent_forward(&self, want_reply: bool) -> Result<(), Error> {
        self.write_half.agent_forward(want_reply).await
    }

    /// Streams `data` into the regular data stream. See [`ChannelWriteHalf::data`].
    pub async fn data<R: tokio::io::AsyncRead + Unpin>(&self, data: R) -> Result<(), Error> {
        self.write_half.data(data).await
    }

    /// Streams `data` into extended stream `ext`.
    /// See [`ChannelWriteHalf::extended_data`].
    pub async fn extended_data<R: tokio::io::AsyncRead + Unpin>(
        &self,
        ext: u32,
        data: R,
    ) -> Result<(), Error> {
        self.write_half.extended_data(ext, data).await
    }

    /// Signals end of data. See [`ChannelWriteHalf::eof`].
    pub async fn eof(&self) -> Result<(), Error> {
        self.write_half.eof().await
    }

    /// Reports an exit status. See [`ChannelWriteHalf::exit_status`].
    pub async fn exit_status(&self, exit_status: u32) -> Result<(), Error> {
        self.write_half.exit_status(exit_status).await
    }

    /// Closes the channel. See [`ChannelWriteHalf::close`].
    pub async fn close(&self) -> Result<(), Error> {
        self.write_half.close().await
    }

    /// Waits for the next incoming message. See [`ChannelReadHalf::wait`].
    pub async fn wait(&mut self) -> Option<ChannelMsg> {
        self.read_half.wait().await
    }

    /// Converts the channel into a combined read/write [`ChannelStream`].
    ///
    /// The writer uses clones of the sender and window state for the regular data stream.
    /// The reader takes ownership of the whole channel, wrapped in
    /// `io::ChannelCloseOnDrop`.
    pub fn into_stream(self) -> ChannelStream<S> {
        // The writer is built before `self` is moved into the reader.
        let writer = io::ChannelTx::new(
            self.write_half.sender.clone(),
            self.write_half.id,
            Arc::clone(&self.write_half.window_size.value),
            self.write_half.window_size.subscribe(),
            self.write_half.max_packet_size,
            None,
        );
        let reader = io::ChannelRx::new(io::ChannelCloseOnDrop(self), None);
        ChannelStream::new(writer, reader)
    }

    /// Returns an [`AsyncRead`] over the regular data stream.
    /// See [`ChannelReadHalf::make_reader`].
    pub fn make_reader(&mut self) -> impl AsyncRead + '_ {
        self.read_half.make_reader()
    }

    /// Returns an [`AsyncRead`] over the stream selected by `ext`.
    /// See [`ChannelReadHalf::make_reader_ext`].
    pub fn make_reader_ext(&mut self, ext: Option<u32>) -> impl AsyncRead + '_ {
        self.read_half.make_reader_ext(ext)
    }

    /// Returns an [`AsyncWrite`] into the regular data stream.
    /// See [`ChannelWriteHalf::make_writer`].
    pub fn make_writer(&self) -> impl AsyncWrite + 'static {
        self.write_half.make_writer()
    }

    /// Returns an [`AsyncWrite`] into the stream selected by `ext`.
    /// See [`ChannelWriteHalf::make_writer_ext`].
    pub fn make_writer_ext(&self, ext: Option<u32>) -> impl AsyncWrite + 'static {
        self.write_half.make_writer_ext(ext)
    }
}