use crate::{opt_cstring_to_cstr, opt_str_to_cstring, Error, SessionHolder, SshResult};
use libssh_rs_sys as sys;
use std::convert::TryInto;
use std::ffi::{CStr, CString};
use std::os::raw::c_int;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::Duration;
/// Handle to a single libssh channel that belongs to a shared session.
///
/// The raw `ssh_channel` is released with `ssh_channel_free` when this value
/// is dropped.
pub struct Channel {
/// Session state shared behind a mutex; `lock_session` locks it around the
/// libssh calls made by the methods of this type.
pub(crate) sess: Arc<Mutex<SessionHolder>>,
/// Raw libssh channel handle.
pub(crate) chan_inner: sys::ssh_channel,
/// Callback table whose address is passed to `ssh_set_channel_callbacks` in
/// `Channel::new`; it is boxed so that address does not change when the
/// `Channel` itself is moved.
_callbacks: Box<sys::ssh_channel_callbacks_struct>,
/// State that the registered callbacks reach through the `userdata`
/// pointer of `_callbacks`; boxed for the same reason.
callback_state: Box<CallbackState>,
}
// NOTE(review): `Channel` stores a raw `ssh_channel` pointer, so `Send` has to
// be asserted by hand. This presumably relies on libssh calls being made with
// the session mutex held — confirm.
unsafe impl Send for Channel {}
impl Drop for Channel {
    /// Unregisters the channel callbacks and frees the libssh channel.
    ///
    /// Both steps run while the session mutex is held, so they cannot overlap
    /// with another thread that is using the same session through its own
    /// lock. A poisoned mutex is recovered instead of unwrapped: panicking in
    /// `drop` while already unwinding aborts the process, and bailing out
    /// would leak the channel.
    fn drop(&mut self) {
        let _sess = self
            .sess
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // SAFETY: `chan_inner` is the handle this value was built with and it
        // is freed only here. `_callbacks` is the table that was registered in
        // `Channel::new`; it is removed before the boxes are dropped.
        unsafe {
            sys::ssh_remove_channel_callbacks(self.chan_inner, self._callbacks.as_mut());
            sys::ssh_channel_free(self.chan_inner);
        }
    }
}
/// State shared with the libssh channel callbacks through their `userdata`
/// pointer.
struct CallbackState {
/// Most recent exit-signal report stored by `handle_exit_signal`; `None`
/// until that callback has run.
signal_state: Mutex<Option<SignalState>>,
}
/// Exit-signal details captured by the channel's exit-signal callback and
/// returned from `Channel::get_exit_signal`.
#[derive(Clone, Debug)]
pub struct SignalState {
/// The callback's `signal` argument, or `None` if that pointer was null.
pub signal_name: Option<String>,
/// `true` when the callback's `core_dumped` argument was non-zero.
pub core_dumped: bool,
/// The callback's `errmsg` argument, or `None` if that pointer was null.
pub error_message: Option<String>,
/// The callback's `lang` argument, or `None` if that pointer was null.
pub language: Option<String>,
}
/// Copies a C string into an owned `String`, yielding `None` for a null
/// pointer.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD. A non-null `cstr`
/// is passed to `CStr::from_ptr`, so it must point to a NUL-terminated string.
fn cstr_to_opt_string(cstr: *const ::std::os::raw::c_char) -> Option<String> {
    if cstr.is_null() {
        None
    } else {
        // SAFETY: the pointer was checked for null above; the caller is
        // responsible for it referring to a NUL-terminated string.
        let borrowed = unsafe { CStr::from_ptr(cstr) };
        Some(borrowed.to_string_lossy().into_owned())
    }
}
/// libssh exit-signal callback: records the reported signal in the
/// `CallbackState` that `userdata` points to.
///
/// The function must not panic, because it is called from C code. A null
/// `userdata` is therefore ignored, and a poisoned mutex is recovered instead
/// of unwrapped.
///
/// # Safety
///
/// `userdata` must be null or point to a live `CallbackState`. Each string
/// argument must be null or point to a NUL-terminated string.
unsafe extern "C" fn handle_exit_signal(
    _session: sys::ssh_session,
    _channel: sys::ssh_channel,
    signal: *const ::std::os::raw::c_char,
    core_dumped: ::std::os::raw::c_int,
    errmsg: *const ::std::os::raw::c_char,
    lang: *const ::std::os::raw::c_char,
    userdata: *mut ::std::os::raw::c_void,
) {
    if userdata.is_null() {
        return;
    }
    let callback_state: &CallbackState = &*(userdata as *const CallbackState);
    let report = SignalState {
        signal_name: cstr_to_opt_string(signal),
        core_dumped: core_dumped != 0,
        error_message: cstr_to_opt_string(errmsg),
        language: cstr_to_opt_string(lang),
    };
    let mut slot = callback_state
        .signal_state
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *slot = Some(report);
}
impl Channel {
pub fn accept_x11(&self, timeout: std::time::Duration) -> Option<Self> {
let (_sess, chan) = self.lock_session();
let timeout = timeout.as_millis();
let chan = unsafe { sys::ssh_channel_accept_x11(chan, timeout.try_into().unwrap()) };
if chan.is_null() {
None
} else {
Some(Self::new(&self.sess, chan))
}
}
/// Wraps a raw libssh channel and registers the exit-signal callback on it.
///
/// The callback table and the state it points at are both boxed and stored
/// in the returned `Channel`, so the addresses given to libssh here remain
/// the same for as long as the `Channel` exists.
pub(crate) fn new(sess: &Arc<Mutex<SessionHolder>>, chan: sys::ssh_channel) -> Self {
    let state = Box::new(CallbackState {
        signal_state: Mutex::new(None),
    });
    // Only the exit-signal hook is populated; every other slot stays empty.
    let table = Box::new(sys::ssh_channel_callbacks_struct {
        size: std::mem::size_of::<sys::ssh_channel_callbacks_struct>(),
        userdata: &*state as *const CallbackState as *mut _,
        channel_data_function: None,
        channel_eof_function: None,
        channel_close_function: None,
        channel_signal_function: None,
        channel_exit_status_function: None,
        channel_exit_signal_function: Some(handle_exit_signal),
        channel_pty_request_function: None,
        channel_shell_request_function: None,
        channel_auth_agent_req_function: None,
        channel_x11_req_function: None,
        channel_pty_window_change_function: None,
        channel_exec_request_function: None,
        channel_env_request_function: None,
        channel_subsystem_request_function: None,
        channel_write_wontblock_function: None,
        channel_open_response_function: None,
        channel_request_response_function: None,
    });
    // SAFETY: `table` is heap allocated and is moved into the returned value
    // below, so the pointer registered here stays valid after this returns.
    unsafe {
        sys::ssh_set_channel_callbacks(
            chan,
            &*table as *const sys::ssh_channel_callbacks_struct as *mut _,
        )
    };
    Self {
        sess: Arc::clone(sess),
        chan_inner: chan,
        callback_state: state,
        _callbacks: table,
    }
}
/// Locks the shared session and returns the guard together with this
/// channel's raw handle.
///
/// Panics if the session mutex is poisoned.
fn lock_session(&self) -> (MutexGuard<'_, SessionHolder>, sys::ssh_channel) {
    let guard = self.sess.lock().unwrap();
    (guard, self.chan_inner)
}
/// Closes the channel by calling `ssh_channel_close`.
///
/// The session lock is held for the call; the returned status code is handed
/// to `basic_status` together with a description of the failure.
pub fn close(&self) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_close(channel) };
    session.basic_status(status, "error closing channel")
}
/// Returns the value of `ssh_channel_get_exit_status`, or `None` when libssh
/// reports `-1`.
pub fn get_exit_status(&self) -> Option<c_int> {
    let (_guard, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    match unsafe { sys::ssh_channel_get_exit_status(channel) } {
        -1 => None,
        code => Some(code),
    }
}
/// Returns a copy of the exit-signal report stored by the exit-signal
/// callback, or `None` if nothing has been stored.
///
/// Panics if the callback-state mutex is poisoned.
pub fn get_exit_signal(&self) -> Option<SignalState> {
    let recorded = self.callback_state.signal_state.lock().unwrap();
    recorded.clone()
}
/// Returns `true` when `ssh_channel_is_closed` reports a non-zero value.
pub fn is_closed(&self) -> bool {
    let (_guard, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let flag = unsafe { sys::ssh_channel_is_closed(channel) };
    flag != 0
}
/// Returns `true` when `ssh_channel_is_eof` reports a non-zero value.
pub fn is_eof(&self) -> bool {
    let (_guard, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let flag = unsafe { sys::ssh_channel_is_eof(channel) };
    flag != 0
}
/// Calls `ssh_channel_send_eof` on this channel.
///
/// The returned status code is handed to `basic_status` together with a
/// description of the failure.
pub fn send_eof(&self) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_send_eof(channel) };
    session.basic_status(status, "ssh_channel_send_eof failed")
}
/// Returns `true` when `ssh_channel_is_open` reports a non-zero value.
pub fn is_open(&self) -> bool {
    let (_guard, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let flag = unsafe { sys::ssh_channel_is_open(channel) };
    flag != 0
}
/// Calls `ssh_channel_open_auth_agent` on this channel.
///
/// The returned status code is handed to `basic_status` together with a
/// description of the failure.
pub fn open_auth_agent(&self) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_open_auth_agent(channel) };
    session.basic_status(status, "ssh_channel_open_auth_agent failed")
}
/// Calls `ssh_channel_request_auth_agent` on this channel.
///
/// The returned status code is handed to `basic_status` together with a
/// description of the failure.
pub fn request_auth_agent(&self) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_request_auth_agent(channel) };
    session.basic_status(status, "ssh_channel_request_auth_agent failed")
}
/// Passes `name` and `value` to `ssh_channel_request_env` as C strings.
///
/// # Errors
///
/// Fails if either string cannot be converted by `CString::new`, or if
/// `basic_status` rejects the status code returned by libssh.
pub fn request_env(&self, name: &str, value: &str) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    let c_name = CString::new(name)?;
    let c_value = CString::new(value)?;
    // SAFETY: both C strings outlive the call; `channel` is this channel's
    // handle and the session lock is held.
    let status =
        unsafe { sys::ssh_channel_request_env(channel, c_name.as_ptr(), c_value.as_ptr()) };
    session.basic_status(status, "ssh_channel_request_env failed")
}
/// Calls `ssh_channel_request_shell` on this channel.
///
/// The returned status code is handed to `basic_status` together with a
/// description of the failure.
pub fn request_shell(&self) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_request_shell(channel) };
    session.basic_status(status, "ssh_channel_request_shell failed")
}
/// Passes `command` to `ssh_channel_request_exec` as a C string.
///
/// # Errors
///
/// Fails if `command` cannot be converted by `CString::new`, or if
/// `basic_status` rejects the status code returned by libssh.
pub fn request_exec(&self, command: &str) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    let c_command = CString::new(command)?;
    // SAFETY: `c_command` outlives the call; `channel` is this channel's
    // handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_request_exec(channel, c_command.as_ptr()) };
    session.basic_status(status, "ssh_channel_request_exec failed")
}
/// Passes `subsys` to `ssh_channel_request_subsystem` as a C string.
///
/// # Errors
///
/// Fails if `subsys` cannot be converted by `CString::new`, or if
/// `basic_status` rejects the status code returned by libssh.
pub fn request_subsystem(&self, subsys: &str) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    let c_subsys = CString::new(subsys)?;
    // SAFETY: `c_subsys` outlives the call; `channel` is this channel's
    // handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_request_subsystem(channel, c_subsys.as_ptr()) };
    session.basic_status(status, "ssh_channel_request_subsystem failed")
}
pub fn request_pty(&self, term: &str, columns: u32, rows: u32) -> SshResult<()> {
let (sess, chan) = self.lock_session();
let term = CString::new(term)?;
let res = unsafe {
sys::ssh_channel_request_pty_size(
chan,
term.as_ptr(),
columns.try_into().unwrap(),
rows.try_into().unwrap(),
)
};
sess.basic_status(res, "ssh_channel_request_pty_size failed")
}
pub fn change_pty_size(&self, columns: u32, rows: u32) -> SshResult<()> {
let (sess, chan) = self.lock_session();
let res = unsafe {
sys::ssh_channel_change_pty_size(
chan,
columns.try_into().unwrap(),
rows.try_into().unwrap(),
)
};
sess.basic_status(res, "ssh_channel_change_pty_size failed")
}
pub fn request_send_break(&self, length: Duration) -> SshResult<()> {
let (sess, chan) = self.lock_session();
let res = unsafe { sys::ssh_channel_request_send_break(chan, length.as_millis() as _) };
sess.basic_status(res, "ssh_channel_request_send_break failed")
}
/// Passes `signal` to `ssh_channel_request_send_signal` as a C string.
///
/// # Errors
///
/// Fails if `signal` cannot be converted by `CString::new`, or if
/// `basic_status` rejects the status code returned by libssh.
pub fn request_send_signal(&self, signal: &str) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    let c_signal = CString::new(signal)?;
    // SAFETY: `c_signal` outlives the call; `channel` is this channel's
    // handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_request_send_signal(channel, c_signal.as_ptr()) };
    session.basic_status(status, "ssh_channel_request_send_signal failed")
}
/// Calls `ssh_channel_open_forward` with the given remote and source
/// endpoints.
///
/// # Errors
///
/// Fails if either host string cannot be converted by `CString::new`, or if
/// `basic_status` rejects the status code returned by libssh.
pub fn open_forward(
    &self,
    remote_host: &str,
    remote_port: u16,
    source_host: &str,
    source_port: u16,
) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    let c_remote = CString::new(remote_host)?;
    let c_source = CString::new(source_host)?;
    // Widening a `u16` port to `i32` cannot lose information.
    let remote_port = i32::from(remote_port);
    let source_port = i32::from(source_port);
    // SAFETY: both C strings outlive the call; `channel` is this channel's
    // handle and the session lock is held.
    let status = unsafe {
        sys::ssh_channel_open_forward(
            channel,
            c_remote.as_ptr(),
            remote_port,
            c_source.as_ptr(),
            source_port,
        )
    };
    session.basic_status(status, "ssh_channel_open_forward failed")
}
/// Calls `ssh_channel_open_forward_unix` with the given remote path and
/// source endpoint.
///
/// # Errors
///
/// Fails if either string cannot be converted by `CString::new`, or if
/// `basic_status` rejects the status code returned by libssh.
pub fn open_forward_unix(
    &self,
    remote_path: &str,
    source_host: &str,
    source_port: u16,
) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    let c_path = CString::new(remote_path)?;
    let c_source = CString::new(source_host)?;
    // Widening a `u16` port to `i32` cannot lose information.
    let source_port = i32::from(source_port);
    // SAFETY: both C strings outlive the call; `channel` is this channel's
    // handle and the session lock is held.
    let status = unsafe {
        sys::ssh_channel_open_forward_unix(
            channel,
            c_path.as_ptr(),
            c_source.as_ptr(),
            source_port,
        )
    };
    session.basic_status(status, "ssh_channel_open_forward_unix failed")
}
/// Calls `ssh_channel_request_x11` with the given options.
///
/// `protocol` and `cookie` are converted with the crate helpers
/// `opt_str_to_cstring` / `opt_cstring_to_cstr` before being passed on.
///
/// # Errors
///
/// Fails if `basic_status` rejects the status code returned by libssh. The
/// failure text now names `ssh_channel_request_x11`; it previously named
/// `ssh_channel_open_forward`, a copy-paste slip.
pub fn request_x11(
    &self,
    single_connection: bool,
    protocol: Option<&str>,
    cookie: Option<&str>,
    screen_number: c_int,
) -> SshResult<()> {
    let (sess, chan) = self.lock_session();
    let protocol = opt_str_to_cstring(protocol);
    let cookie = opt_str_to_cstring(cookie);
    // SAFETY: `protocol` and `cookie` outlive the call; `chan` is this
    // channel's handle and the session lock is held.
    let res = unsafe {
        sys::ssh_channel_request_x11(
            chan,
            if single_connection { 1 } else { 0 },
            opt_cstring_to_cstr(&protocol),
            opt_cstring_to_cstr(&cookie),
            screen_number,
        )
    };
    sess.basic_status(res, "ssh_channel_request_x11 failed")
}
/// Calls `ssh_channel_open_session` on this channel.
///
/// The returned status code is handed to `basic_status` together with a
/// description of the failure.
pub fn open_session(&self) -> SshResult<()> {
    let (session, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let status = unsafe { sys::ssh_channel_open_session(channel) };
    session.basic_status(status, "ssh_channel_open_session failed")
}
pub fn poll_timeout(
&self,
is_stderr: bool,
timeout: Option<Duration>,
) -> SshResult<PollStatus> {
let (sess, chan) = self.lock_session();
let timeout = match timeout {
Some(t) => t.as_millis() as c_int,
None => -1,
};
let res =
unsafe { sys::ssh_channel_poll_timeout(chan, if is_stderr { 1 } else { 0 }, timeout) };
match res {
sys::SSH_ERROR => {
if let Some(err) = sess.last_error() {
Err(err)
} else {
Err(Error::fatal("ssh_channel_poll failed"))
}
}
sys::SSH_EOF => Ok(PollStatus::EndOfFile),
n if n >= 0 => Ok(PollStatus::AvailableBytes(n as u32)),
n => Err(Error::Fatal(format!(
"ssh_channel_poll returned unexpected {} value",
n
))),
}
}
pub fn read_timeout(
&self,
buf: &mut [u8],
is_stderr: bool,
timeout: Option<Duration>,
) -> SshResult<usize> {
let (sess, chan) = self.lock_session();
let timeout = match timeout {
Some(t) => t.as_millis() as c_int,
None => -1,
};
let res = unsafe {
sys::ssh_channel_read_timeout(
chan,
buf.as_mut_ptr() as _,
buf.len() as u32,
if is_stderr { 1 } else { 0 },
timeout,
)
};
match res {
sys::SSH_ERROR => {
if let Some(err) = sess.last_error() {
Err(err)
} else {
Err(Error::fatal("ssh_channel_read_timeout failed"))
}
}
sys::SSH_AGAIN => Err(Error::TryAgain),
n if n < 0 => Err(Error::Fatal(format!(
"ssh_channel_read_timeout returned unexpected {} value",
n
))),
0 if !sess.is_blocking() => Err(Error::TryAgain),
n => Ok(n as usize),
}
}
pub fn read_nonblocking(&self, buf: &mut [u8], is_stderr: bool) -> SshResult<usize> {
let (sess, chan) = self.lock_session();
let res = unsafe {
sys::ssh_channel_read_nonblocking(
chan,
buf.as_mut_ptr() as _,
buf.len() as u32,
if is_stderr { 1 } else { 0 },
)
};
match res {
sys::SSH_ERROR => {
if let Some(err) = sess.last_error() {
Err(err)
} else {
Err(Error::fatal("ssh_channel_read_timeout failed"))
}
}
sys::SSH_EOF => Ok(0 as usize),
n if n < 0 => Err(Error::Fatal(format!(
"ssh_channel_read_timeout returned unexpected value: {n}"
))),
n => Ok(n as usize),
}
}
/// Returns the value of `ssh_channel_window_size` as a `usize`.
///
/// Panics if that value cannot be represented as a `usize`.
pub fn window_size(&self) -> usize {
    let (_guard, channel) = self.lock_session();
    // SAFETY: `channel` is this channel's handle and the session lock is held.
    let window = unsafe { sys::ssh_channel_window_size(channel) };
    window.try_into().unwrap()
}
/// Shared body of the `std::io::Read` adapters: calls `read_timeout` with no
/// timeout and converts the crate error into `std::io::Error`.
fn read_impl(&self, buf: &mut [u8], is_stderr: bool) -> std::io::Result<usize> {
    self.read_timeout(buf, is_stderr, None)
        .map_err(std::io::Error::from)
}
fn write_impl(&self, buf: &[u8], is_stderr: bool) -> SshResult<usize> {
let (sess, chan) = self.lock_session();
let res = unsafe {
(if is_stderr {
sys::ssh_channel_write_stderr
} else {
sys::ssh_channel_write
})(chan, buf.as_ptr() as _, buf.len() as _)
};
match res {
sys::SSH_ERROR => {
if let Some(err) = sess.last_error() {
Err(err)
} else {
Err(Error::fatal("ssh_channel_read_timeout failed"))
}
}
sys::SSH_AGAIN => Err(Error::TryAgain),
n if n < 0 => Err(Error::Fatal(format!(
"ssh_channel_read_timeout returned unexpected {} value",
n
))),
n => Ok(n as usize),
}
}
/// Returns a `std::io::Read` adapter borrowing this channel; its reads go
/// through `read_impl` with `is_stderr` set to `false`.
pub fn stdout(&self) -> impl std::io::Read + '_ {
    let chan = self;
    ChannelStdout { chan }
}
/// Returns a `std::io::Read` adapter borrowing this channel; its reads go
/// through `read_impl` with `is_stderr` set to `true`.
pub fn stderr(&self) -> impl std::io::Read + '_ {
    let chan = self;
    ChannelStderr { chan }
}
/// Returns a `std::io::Write` adapter borrowing this channel; its writes go
/// through `write_impl` with `is_stderr` set to `false`.
pub fn stdin(&self) -> impl std::io::Write + '_ {
    let chan = self;
    ChannelStdin { chan }
}
}
/// `std::io::Write` adapter returned by `Channel::stdin`.
struct ChannelStdin<'a> {
/// Channel that receives the written bytes.
chan: &'a Channel,
}
impl<'a> std::io::Write for ChannelStdin<'a> {
    /// Locks the session and calls its `blocking_flush` with no timeout,
    /// converting any error into `std::io::Error`.
    ///
    /// Panics if the session mutex is poisoned.
    fn flush(&mut self) -> std::io::Result<()> {
        self.chan
            .sess
            .lock()
            .unwrap()
            .blocking_flush(None)
            .map_err(std::io::Error::from)
    }

    /// Forwards to `Channel::write_impl` with `is_stderr` set to `false`,
    /// converting any error into `std::io::Error`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.chan
            .write_impl(buf, false)
            .map_err(std::io::Error::from)
    }
}
/// `std::io::Read` adapter returned by `Channel::stdout`.
struct ChannelStdout<'a> {
/// Channel the bytes are read from.
chan: &'a Channel,
}
impl<'a> std::io::Read for ChannelStdout<'a> {
    /// Forwards to `Channel::read_impl` with `is_stderr` set to `false`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let is_stderr = false;
        self.chan.read_impl(buf, is_stderr)
    }
}
/// `std::io::Read` adapter returned by `Channel::stderr`.
struct ChannelStderr<'a> {
/// Channel the bytes are read from.
chan: &'a Channel,
}
impl<'a> std::io::Read for ChannelStderr<'a> {
    /// Forwards to `Channel::read_impl` with `is_stderr` set to `true`.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let is_stderr = true;
        self.chan.read_impl(buf, is_stderr)
    }
}
/// Successful outcome of `Channel::poll_timeout`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollStatus {
/// The poll returned this non-negative count.
AvailableBytes(u32),
/// The poll returned `SSH_EOF`.
EndOfFile,
}