#![cfg(feature = "std")]
use std::collections::{BTreeMap, VecDeque};
use std::io::{Read, Write};
use std::sync::{Arc, Condvar, Mutex};
use std::time::Duration;
use crate::channel::{ChannelEvent, ChannelOpen, ChannelRequest};
use crate::client::{io_err, Client};
use crate::error::{Error, Result};
use crate::sftp::{SftpClient, SftpError};
/// Upper bound on one `Condvar::wait_timeout` while another thread is marked
/// as pumping packets; after each timeout the waiter loops and re-checks its
/// channel queue.
const WAIT_TIMEOUT: Duration = Duration::from_millis(500);
fn lock_or_poison<'a>(m: &'a Mutex<Inner>) -> Result<std::sync::MutexGuard<'a, Inner>> {
m.lock()
.map_err(|_| Error::Protocol("SharedClient mutex poisoned"))
}
/// `std::io` counterpart of `lock_or_poison`, for use inside `Read`/`Write`.
///
/// A poisoned mutex becomes an `io::Error` of kind `BrokenPipe`.
fn lock_or_poison_io<'a>(
    m: &'a Mutex<Inner>,
) -> std::io::Result<std::sync::MutexGuard<'a, Inner>> {
    let Ok(guard) = m.lock() else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            "SharedClient mutex poisoned",
        ));
    };
    Ok(guard)
}
/// Result alias used by the SFTP passthrough methods on `SftpSession`.
type SftpResult<T> = core::result::Result<T, SftpError>;
/// Maximum number of packets read while waiting for a channel-open
/// confirmation or a channel-request reply before giving up with a
/// "did not converge" protocol error.
const MAX_OPEN_ITER: usize = 1_000_000;
/// Inbound state for one channel: filled by `stash_event` as packets are
/// dispatched, drained by the channel's reader.
#[derive(Default)]
struct ChannelQueue {
    /// Set when the peer sends EOF, and also when it sends CLOSE.
    remote_eof: bool,
    /// Set when the peer sends CLOSE.
    remote_close: bool,
    /// Bytes from `ChannelEvent::Data`, oldest first.
    data: VecDeque<u8>,
    /// Bytes from `ChannelEvent::ExtendedData` (any data-type code), oldest first.
    stderr: VecDeque<u8>,
}
/// All state guarded by the single `SharedClient` mutex.
struct Inner {
    /// The underlying connection; every packet read/write in this module goes
    /// through it.
    client: Client,
    /// Buffered inbound bytes and EOF/CLOSE flags, keyed by local channel id.
    queues: BTreeMap<u32, ChannelQueue>,
    /// Per-channel condition variables, created lazily, used to wake threads
    /// waiting on that channel.
    notifiers: BTreeMap<u32, Arc<Condvar>>,
    /// True while a read/write is inside `pump_one_step` on behalf of all
    /// channels.
    pumping: bool,
}
/// Cloneable handle to one SSH `Client`.
///
/// Every clone shares the same state behind one mutex, so channels opened
/// from any clone are multiplexed over the same connection.
#[derive(Clone)]
pub struct SharedClient {
    inner: Arc<Mutex<Inner>>,
}
impl From<Client> for SharedClient {
fn from(client: Client) -> Self {
Self {
inner: Arc::new(Mutex::new(Inner {
client,
queues: BTreeMap::new(),
notifiers: BTreeMap::new(),
pumping: false,
})),
}
}
}
impl SharedClient {
pub fn sftp(&self) -> Result<SftpSession> {
let local_id = {
let mut g = lock_or_poison(&self.inner)?;
let id = open_session_under_lock(&mut g, "sftp")?;
send_request_and_await(
&mut g,
id,
ChannelRequest::Subsystem {
name: "sftp".into(),
},
"sftp: subsystem",
)?;
id
};
let stream = OwnedChannelStream {
shared: self.clone(),
channel: local_id,
local_close_sent: false,
};
match SftpClient::new(stream) {
Ok(c) => Ok(SftpSession {
_shared: self.clone(),
inner: c,
}),
Err(e) => {
Err(Error::Protocol(match e {
SftpError::Protocol(s) => s,
_ => "sftp: handshake failed",
}))
}
}
}
pub fn exec_stream(&self, command: &str) -> Result<OwnedChannelStream> {
let local_id = {
let mut g = lock_or_poison(&self.inner)?;
let id = open_session_under_lock(&mut g, "exec")?;
send_request_and_await(
&mut g,
id,
ChannelRequest::Exec {
command: command.into(),
},
"exec: command",
)?;
id
};
Ok(OwnedChannelStream {
shared: self.clone(),
channel: local_id,
local_close_sent: false,
})
}
pub fn shell(&self, term: &str, cols: u32, rows: u32) -> Result<OwnedChannelStream> {
let local_id = {
let mut g = lock_or_poison(&self.inner)?;
let id = open_session_under_lock(&mut g, "shell")?;
send_request_and_await(
&mut g,
id,
ChannelRequest::PtyReq {
term: term.into(),
cols,
rows,
px_w: 0,
px_h: 0,
modes: Vec::new(),
},
"shell: pty-req",
)?;
send_request_and_await(&mut g, id, ChannelRequest::Shell, "shell: shell-req")?;
id
};
Ok(OwnedChannelStream {
shared: self.clone(),
channel: local_id,
local_close_sent: false,
})
}
pub fn open_direct_tcpip(
&self,
dest_host: &str,
dest_port: u16,
orig_host: &str,
orig_port: u16,
) -> Result<OwnedChannelStream> {
let local_id = {
let mut g = lock_or_poison(&self.inner)?;
open_direct_tcpip_under_lock(&mut g, dest_host, dest_port, orig_host, orig_port)?
};
Ok(OwnedChannelStream {
shared: self.clone(),
channel: local_id,
local_close_sent: false,
})
}
#[cfg_attr(not(feature = "ffi"), allow(dead_code))]
pub(crate) fn with_client<R>(&self, f: impl FnOnce(&mut Client) -> R) -> R {
let mut g = self
.inner
.lock()
.expect("SharedClient mutex poisoned (with_client)");
f(&mut g.client)
}
#[allow(dead_code)]
pub(crate) fn try_with_client<R>(&self, f: impl FnOnce(&mut Client) -> R) -> Result<R> {
let mut g = lock_or_poison(&self.inner)?;
Ok(f(&mut g.client))
}
}
/// Opens a `session` channel while the caller holds the shared lock.
///
/// Sends the open request, then reads packets until the peer confirms or
/// rejects this channel. Events for other channels are dispatched to their
/// queues. Once confirmed, it gives the client a chance to send auth-agent
/// and X11 requests, makes sure a queue entry exists, and returns the local
/// channel id.
///
/// `what` selects the error-message prefix.
fn open_session_under_lock(g: &mut Inner, what: &'static str) -> Result<u32> {
    let (local_id, request) = g.client.conn.open(ChannelOpen::Session)?;
    g.client.write_payload(&request)?;

    let mut confirmed = false;
    for _ in 0..MAX_OPEN_ITER {
        let packet = g.client.read_one_packet()?;
        let event = g.client.conn.on_packet(&packet)?;
        match event {
            ChannelEvent::OpenConfirmed { channel } if channel == local_id => {
                confirmed = true;
                break;
            }
            ChannelEvent::OpenFailed { channel, .. } if channel == local_id => {
                return Err(Error::Protocol(open_failed_msg(what)));
            }
            unrelated => dispatch_event(g, unrelated),
        }
    }
    if !confirmed {
        return Err(Error::Protocol(open_loop_msg(what)));
    }

    g.client.maybe_send_auth_agent_req(local_id)?;
    g.client.maybe_send_x11_req(local_id)?;
    g.queues.entry(local_id).or_default();
    Ok(local_id)
}
fn open_direct_tcpip_under_lock(
g: &mut Inner,
dest_host: &str,
dest_port: u16,
orig_host: &str,
orig_port: u16,
) -> Result<u32> {
let (local_id, open_payload) = g.client.conn.open(ChannelOpen::DirectTcpip {
dest_host: dest_host.to_string(),
dest_port: dest_port as u32,
orig_host: orig_host.to_string(),
orig_port: orig_port as u32,
})?;
g.client.write_payload(&open_payload)?;
let mut iter_guard = 0usize;
loop {
iter_guard += 1;
if iter_guard > MAX_OPEN_ITER {
return Err(Error::Protocol(open_loop_msg("direct-tcpip")));
}
let payload = g.client.read_one_packet()?;
let ev = g.client.conn.on_packet(&payload)?;
match ev {
ChannelEvent::OpenConfirmed { channel } if channel == local_id => break,
ChannelEvent::OpenFailed { channel, .. } if channel == local_id => {
return Err(Error::Protocol(open_failed_msg("direct-tcpip")));
}
other => dispatch_event(&mut *g, other),
}
}
g.queues.entry(local_id).or_default();
Ok(local_id)
}
fn send_request_and_await(
g: &mut Inner,
local_id: u32,
req: ChannelRequest,
what: &'static str,
) -> Result<()> {
let payload = g.client.conn.send_request(local_id, req, true)?;
g.client.write_payload(&payload)?;
let mut iter_guard = 0usize;
loop {
iter_guard += 1;
if iter_guard > MAX_OPEN_ITER {
return Err(Error::Protocol(reply_loop_msg(what)));
}
let payload = g.client.read_one_packet()?;
let ev = g.client.conn.on_packet(&payload)?;
match ev {
ChannelEvent::Success { channel } if channel == local_id => return Ok(()),
ChannelEvent::Failure { channel } if channel == local_id => {
return Err(Error::Protocol(reply_failed_msg(what)));
}
other => dispatch_event(&mut *g, other),
}
}
}
/// Error message for an open loop that exceeded `MAX_OPEN_ITER`.
///
/// `what` is the channel kind; unknown kinds get a generic `channel:` message.
fn open_loop_msg(what: &'static str) -> &'static str {
    const KNOWN: [(&str, &str); 4] = [
        ("sftp", "sftp: open loop did not converge"),
        ("exec", "exec: open loop did not converge"),
        ("shell", "shell: open loop did not converge"),
        ("direct-tcpip", "direct-tcpip: open loop did not converge"),
    ];
    KNOWN
        .iter()
        .find(|entry| entry.0 == what)
        .map_or("channel: open loop did not converge", |entry| entry.1)
}
/// Error message for a channel open that the peer rejected.
///
/// `what` is the channel kind; unknown kinds get a generic `channel:` message.
fn open_failed_msg(what: &'static str) -> &'static str {
    const KNOWN: [(&str, &str); 4] = [
        ("sftp", "sftp: channel open failed"),
        ("exec", "exec: channel open failed"),
        ("shell", "shell: channel open failed"),
        ("direct-tcpip", "direct-tcpip: open failed"),
    ];
    KNOWN
        .iter()
        .find(|entry| entry.0 == what)
        .map_or("channel: open failed", |entry| entry.1)
}
/// Error message for a request-reply loop that exceeded `MAX_OPEN_ITER`.
///
/// `what` is the request label; unknown labels get a generic `channel:` message.
fn reply_loop_msg(what: &'static str) -> &'static str {
    const KNOWN: [(&str, &str); 4] = [
        ("sftp: subsystem", "sftp: subsystem-reply loop did not converge"),
        ("exec: command", "exec: command-reply loop did not converge"),
        ("shell: pty-req", "shell: pty-req-reply loop did not converge"),
        ("shell: shell-req", "shell: shell-req-reply loop did not converge"),
    ];
    KNOWN
        .iter()
        .find(|entry| entry.0 == what)
        .map_or("channel: request-reply loop did not converge", |entry| entry.1)
}
/// Error message for a channel request that the peer refused.
///
/// `what` is the request label; unknown labels get a generic `channel:` message.
fn reply_failed_msg(what: &'static str) -> &'static str {
    const KNOWN: [(&str, &str); 4] = [
        ("sftp: subsystem", "sftp: subsystem request denied"),
        ("exec: command", "exec: command request denied"),
        ("shell: pty-req", "shell: pty-req denied"),
        ("shell: shell-req", "shell: shell-req denied"),
    ];
    KNOWN
        .iter()
        .find(|entry| entry.0 == what)
        .map_or("channel: request denied", |entry| entry.1)
}
/// Returns the condition variable for `channel`, creating it on first use.
///
/// Every call for the same channel returns a clone of the same `Arc`.
fn notifier_for(g: &mut Inner, channel: u32) -> Arc<Condvar> {
    let slot: &Arc<Condvar> = g.notifiers.entry(channel).or_default();
    Arc::clone(slot)
}
/// Records `ev` in its channel queue and wakes every waiter on that channel.
///
/// Only data, extended-data, EOF, and CLOSE events trigger a wake-up; other
/// events are handed to `stash_event`, which ignores them.
fn dispatch_event(g: &mut Inner, ev: ChannelEvent) {
    let wake = match ev {
        ChannelEvent::Data { channel, .. }
        | ChannelEvent::ExtendedData { channel, .. }
        | ChannelEvent::Eof { channel }
        | ChannelEvent::Close { channel } => Some(channel),
        _ => None,
    };
    stash_event(&mut g.queues, ev);
    if let Some(cv) = wake.and_then(|ch| g.notifiers.get(&ch)) {
        cv.notify_all();
    }
}
/// Applies one channel event to the per-channel queues.
///
/// - Data is appended to the `data` queue.
/// - Extended data is appended to `stderr`, regardless of its type code.
/// - EOF sets `remote_eof`; CLOSE sets both flags.
/// - A queue is created on demand for the target channel.
/// - Any other event is ignored.
fn stash_event(queues: &mut BTreeMap<u32, ChannelQueue>, ev: ChannelEvent) {
    // Fetches (or lazily creates) the queue for `channel`.
    fn slot(queues: &mut BTreeMap<u32, ChannelQueue>, channel: u32) -> &mut ChannelQueue {
        queues.entry(channel).or_default()
    }

    match ev {
        ChannelEvent::Data { channel, data } => slot(queues, channel).data.extend(data),
        ChannelEvent::ExtendedData { channel, data, .. } => {
            slot(queues, channel).stderr.extend(data)
        }
        ChannelEvent::Eof { channel } => slot(queues, channel).remote_eof = true,
        ChannelEvent::Close { channel } => {
            let queue = slot(queues, channel);
            queue.remote_close = true;
            queue.remote_eof = true;
        }
        _ => {}
    }
}
/// Credits `n` consumed bytes back to `channel`'s receive window.
///
/// If the connection layer produces a window-adjust packet, it is written
/// out. Does nothing when `n` is zero. Errors are converted with `io_err`.
fn replenish_under_lock(g: &mut Inner, channel: u32, n: u32) -> std::io::Result<()> {
    if n != 0 {
        let adjust = g.client.conn.replenish_window(channel, n).map_err(io_err)?;
        if let Some(packet) = adjust {
            g.client.write_payload(&packet).map_err(io_err)?;
        }
    }
    Ok(())
}
/// One SSH channel exposed through `Read`/`Write`.
///
/// The connection is shared with other streams via a `SharedClient`.
/// Dropping the stream sends EOF and CLOSE for the channel, then pumps a
/// bounded number of packets waiting for the peer's CLOSE.
pub struct OwnedChannelStream {
    shared: SharedClient,
    /// Local channel id; the key into `Inner::queues` and `Inner::notifiers`.
    channel: u32,
    /// Whether EOF/CLOSE have already been sent for this channel.
    local_close_sent: bool,
}
/// Selects which inbound queue of a channel a read drains.
#[derive(Clone, Copy)]
enum Stream {
    /// Regular channel data (`ChannelQueue::data`).
    Data,
    /// Extended data (`ChannelQueue::stderr`).
    Stderr,
}
impl OwnedChannelStream {
    /// Moves up to `buf.len()` bytes from the front of the selected queue into
    /// `buf`, returning how many were moved.
    fn drain_into(queue: &mut ChannelQueue, stream: Stream, buf: &mut [u8]) -> usize {
        let src = match stream {
            Stream::Data => &mut queue.data,
            Stream::Stderr => &mut queue.stderr,
        };
        let n = core::cmp::min(buf.len(), src.len());
        for (slot, byte) in buf.iter_mut().zip(src.drain(..n)) {
            *slot = byte;
        }
        n
    }

    /// Shared implementation of `Read::read` and `read_stderr`.
    ///
    /// Returns, in order of precedence:
    /// - `Ok(0)` for an empty `buf`.
    /// - Queued bytes for `stream` if any exist; the receive window is
    ///   re-credited for them.
    /// - `Ok(0)` once the peer has sent EOF or CLOSE and the queue is empty.
    ///
    /// Otherwise it pumps one packet (or waits on the channel's condvar if
    /// another thread is marked as pumping) and retries.
    ///
    /// # Errors
    ///
    /// Transport and window-adjust failures are returned as `io::Error`. When
    /// an error is returned, no bytes have been taken from the queue.
    fn read_stream(&mut self, stream: Stream, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        let mut g = lock_or_poison_io(&self.shared.inner)?;
        loop {
            let (pending, remote_eof) = {
                let queue = g.queues.entry(self.channel).or_default();
                let pending = match stream {
                    Stream::Data => queue.data.len(),
                    Stream::Stderr => queue.stderr.len(),
                };
                (pending, queue.remote_eof)
            };
            if pending > 0 {
                // Cap at u32::MAX so the window credit below is exact instead
                // of a truncating cast.
                let n = buf.len().min(pending).min(u32::MAX as usize);
                // Re-credit the window *before* consuming the bytes. If this
                // fails the queue is untouched, which honours `io::Read`'s rule
                // that an `Err` means no bytes were read. Both steps run under
                // the same lock, so other users of the SharedClient cannot
                // observe the ordering.
                replenish_under_lock(&mut g, self.channel, n as u32)?;
                let queue = g.queues.entry(self.channel).or_default();
                return Ok(Self::drain_into(queue, stream, &mut buf[..n]));
            }
            if remote_eof {
                return Ok(0);
            }
            if !g.pumping {
                g.pumping = true;
                let res = Self::pump_one_step(&mut g);
                g.pumping = false;
                // Let one waiter per channel re-check its queue.
                for cv in g.notifiers.values() {
                    cv.notify_one();
                }
                res?;
            } else {
                // NOTE(review): `pumping` is only true while the pumping thread
                // still holds this guard (it is reset before the guard can be
                // released), so this branch appears unreachable — confirm.
                let cv = notifier_for(&mut g, self.channel);
                g = cv
                    .wait_timeout(g, WAIT_TIMEOUT)
                    .map_err(|_| {
                        std::io::Error::new(
                            std::io::ErrorKind::BrokenPipe,
                            "SharedClient mutex poisoned",
                        )
                    })?
                    .0;
            }
        }
    }

    /// Reads from the channel's extended-data (stderr) queue.
    ///
    /// Semantics are otherwise the same as `Read::read`.
    pub fn read_stderr(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.read_stream(Stream::Stderr, buf)
    }

    /// Reads one packet from the transport, lets the connection layer parse
    /// it, and dispatches the resulting event to its channel queue.
    ///
    /// The caller holds the shared lock for the whole read (whether
    /// `read_one_packet` blocks is up to `Client` — presumably yes).
    fn pump_one_step(g: &mut Inner) -> std::io::Result<()> {
        let payload = g.client.read_one_packet().map_err(io_err)?;
        let ev = g.client.conn.on_packet(&payload).map_err(io_err)?;
        dispatch_event(g, ev);
        Ok(())
    }
}
impl Read for OwnedChannelStream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.read_stream(Stream::Data, buf)
}
}
impl Write for OwnedChannelStream {
    /// Sends as much of `buf` as `send_data` accepts in one packet.
    ///
    /// The limit is presumably the remote window / packet size. The packet is
    /// written immediately and the accepted byte count returned.
    ///
    /// If nothing can be sent yet, this pumps inbound packets (or waits on the
    /// channel's condvar) and retries. It fails with `BrokenPipe` once the peer
    /// has closed the channel.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        let channel = self.channel;
        let mut guard = lock_or_poison_io(&self.shared.inner)?;
        loop {
            let (packet, accepted) = guard
                .client
                .conn
                .send_data(channel, buf)
                .map_err(io_err)?;
            if accepted > 0 {
                guard.client.write_payload(&packet).map_err(io_err)?;
                break Ok(accepted);
            }

            let peer_closed = guard.queues.entry(channel).or_default().remote_close;
            if peer_closed {
                break Err(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "channel closed by peer mid-write",
                ));
            }

            if guard.pumping {
                let cv = notifier_for(&mut guard, channel);
                guard = cv
                    .wait_timeout(guard, WAIT_TIMEOUT)
                    .map_err(|_| {
                        std::io::Error::new(
                            std::io::ErrorKind::BrokenPipe,
                            "SharedClient mutex poisoned",
                        )
                    })?
                    .0;
            } else {
                guard.pumping = true;
                let outcome = Self::pump_one_step(&mut guard);
                guard.pumping = false;
                guard.notifiers.values().for_each(|cv| cv.notify_one());
                outcome?;
            }
        }
    }

    /// No-op. This stream keeps no buffer of its own: each successful `write`
    /// has already handed its packet to `Client::write_payload`.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
impl Drop for OwnedChannelStream {
    /// Best-effort channel teardown.
    ///
    /// 1. Sends EOF and CLOSE, if not already sent.
    /// 2. Pumps up to `MAX_DRAIN` packets until the peer's CLOSE has been seen.
    /// 3. Forgets the channel's queue and notifier.
    ///
    /// All errors are ignored, since there is no way to report them from
    /// `Drop`. Pumping reads from the transport while holding the lock, so if
    /// the peer is slow to answer this may block (bounded by `MAX_DRAIN`
    /// packets).
    fn drop(&mut self) {
        // Poisoned mutex: skip teardown rather than panic inside Drop.
        let Ok(mut g) = self.shared.inner.lock() else {
            return;
        };
        if !self.local_close_sent {
            // EOF first, then CLOSE; write failures are deliberately ignored.
            if let Ok(p) = g.client.conn.send_eof(self.channel) {
                let _ = g.client.write_payload(&p);
            }
            if let Ok(p) = g.client.conn.send_close(self.channel) {
                let _ = g.client.write_payload(&p);
            }
            self.local_close_sent = true;
        }
        // Bound on packets pumped while waiting for the peer's CLOSE.
        const MAX_DRAIN: usize = 128;
        for _ in 0..MAX_DRAIN {
            let already_closed = g
                .queues
                .get(&self.channel)
                .map(|q| q.remote_close)
                .unwrap_or(false);
            if already_closed {
                break;
            }
            // A transport error ends the drain early.
            if Self::pump_one_step(&mut g).is_err() {
                break;
            }
        }
        // NOTE(review): a packet for this channel that arrives after these
        // removals would recreate an empty queue entry via `stash_event`'s
        // `entry().or_default()`.
        g.queues.remove(&self.channel);
        g.notifiers.remove(&self.channel);
    }
}
/// An SFTP client running over a channel of a `SharedClient`.
pub struct SftpSession {
    /// Extra handle on the shared client.
    ///
    /// NOTE(review): the wrapped stream already holds its own `SharedClient`
    /// clone (and `SftpClient` presumably owns the stream), so this looks
    /// redundant — confirm before removing.
    _shared: SharedClient,
    inner: SftpClient<OwnedChannelStream>,
}
// Thin pass-through wrappers: every method forwards its arguments unchanged to
// the method of the same name on the inner `SftpClient` and returns its result.
impl SftpSession {
    /// Forwards to the inner client's `server_version`
    /// (presumably the version agreed during the handshake).
    pub fn server_version(&self) -> u32 {
        self.inner.server_version()
    }
    /// Forwards to the inner client's `extensions`
    /// (presumably name/data pairs from the server's version message).
    pub fn extensions(&self) -> &[(Vec<u8>, Vec<u8>)] {
        self.inner.extensions()
    }
    /// Forwards to the inner client's `open`; the returned bytes are
    /// presumably the remote file handle.
    pub fn open(
        &mut self,
        path: &[u8],
        pflags: u32,
        attrs: crate::sftp::Attrs,
    ) -> SftpResult<Vec<u8>> {
        self.inner.open(path, pflags, attrs)
    }
    /// Forwards to the inner client's `close`.
    pub fn close(&mut self, handle: &[u8]) -> SftpResult<()> {
        self.inner.close(handle)
    }
    /// Forwards to the inner client's `read`.
    pub fn read(&mut self, handle: &[u8], offset: u64, len: u32) -> SftpResult<Vec<u8>> {
        self.inner.read(handle, offset, len)
    }
    /// Forwards to the inner client's `write`.
    pub fn write(&mut self, handle: &[u8], offset: u64, data: &[u8]) -> SftpResult<()> {
        self.inner.write(handle, offset, data)
    }
    /// Forwards to the inner client's `stat`.
    pub fn stat(&mut self, path: &[u8]) -> SftpResult<crate::sftp::Attrs> {
        self.inner.stat(path)
    }
    /// Forwards to the inner client's `lstat`.
    pub fn lstat(&mut self, path: &[u8]) -> SftpResult<crate::sftp::Attrs> {
        self.inner.lstat(path)
    }
    /// Forwards to the inner client's `fstat`.
    pub fn fstat(&mut self, handle: &[u8]) -> SftpResult<crate::sftp::Attrs> {
        self.inner.fstat(handle)
    }
    /// Forwards to the inner client's `setstat`.
    pub fn setstat(&mut self, path: &[u8], attrs: crate::sftp::Attrs) -> SftpResult<()> {
        self.inner.setstat(path, attrs)
    }
    /// Forwards to the inner client's `fsetstat`.
    pub fn fsetstat(&mut self, handle: &[u8], attrs: crate::sftp::Attrs) -> SftpResult<()> {
        self.inner.fsetstat(handle, attrs)
    }
    /// Forwards to the inner client's `opendir`.
    pub fn opendir(&mut self, path: &[u8]) -> SftpResult<Vec<u8>> {
        self.inner.opendir(path)
    }
    /// Forwards to the inner client's `readdir`; `None` presumably signals
    /// the end of the listing.
    pub fn readdir(&mut self, handle: &[u8]) -> SftpResult<Option<Vec<crate::sftp::NameEntry>>> {
        self.inner.readdir(handle)
    }
    /// Forwards to the inner client's `mkdir`.
    pub fn mkdir(&mut self, path: &[u8], attrs: crate::sftp::Attrs) -> SftpResult<()> {
        self.inner.mkdir(path, attrs)
    }
    /// Forwards to the inner client's `rmdir`.
    pub fn rmdir(&mut self, path: &[u8]) -> SftpResult<()> {
        self.inner.rmdir(path)
    }
    /// Forwards to the inner client's `remove`.
    pub fn remove(&mut self, path: &[u8]) -> SftpResult<()> {
        self.inner.remove(path)
    }
    /// Forwards to the inner client's `rename`.
    pub fn rename(&mut self, oldpath: &[u8], newpath: &[u8]) -> SftpResult<()> {
        self.inner.rename(oldpath, newpath)
    }
    /// Forwards to the inner client's `symlink`.
    pub fn symlink(&mut self, target_path: &[u8], link_path: &[u8]) -> SftpResult<()> {
        self.inner.symlink(target_path, link_path)
    }
    /// Forwards to the inner client's `readlink`.
    pub fn readlink(&mut self, path: &[u8]) -> SftpResult<Vec<u8>> {
        self.inner.readlink(path)
    }
    /// Forwards to the inner client's `realpath`.
    pub fn realpath(&mut self, path: &[u8]) -> SftpResult<Vec<u8>> {
        self.inner.realpath(path)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the lock-free pieces: queue draining, event stashing, and
    // notifier-map identity. Nothing here touches a real `Client`.
    use super::*;
    // A fresh queue has no buffered bytes and no EOF/CLOSE flags.
    #[test]
    fn channel_queue_default_is_empty() {
        let q = ChannelQueue::default();
        assert!(q.data.is_empty());
        assert!(q.stderr.is_empty());
        assert!(!q.remote_eof);
        assert!(!q.remote_close);
    }
    // A buffer shorter than the queue takes a prefix and leaves the rest queued.
    #[test]
    fn drain_into_partial() {
        let mut q = ChannelQueue::default();
        q.data.extend(b"hello".iter().copied());
        let mut buf = [0u8; 3];
        let n = OwnedChannelStream::drain_into(&mut q, Stream::Data, &mut buf);
        assert_eq!(n, 3);
        assert_eq!(&buf, b"hel");
        assert_eq!(q.data.iter().copied().collect::<Vec<_>>(), b"lo");
    }
    // A buffer longer than the queue empties it and reports only the bytes copied.
    #[test]
    fn drain_into_overflow() {
        let mut q = ChannelQueue::default();
        q.data.extend(b"hi".iter().copied());
        let mut buf = [0u8; 8];
        let n = OwnedChannelStream::drain_into(&mut q, Stream::Data, &mut buf);
        assert_eq!(n, 2);
        assert_eq!(&buf[..2], b"hi");
        assert!(q.data.is_empty());
    }
    // Draining stderr leaves the data queue untouched.
    #[test]
    fn drain_into_stderr() {
        let mut q = ChannelQueue::default();
        q.stderr.extend(b"err".iter().copied());
        q.data.extend(b"std".iter().copied());
        let mut buf = [0u8; 8];
        let n = OwnedChannelStream::drain_into(&mut q, Stream::Stderr, &mut buf);
        assert_eq!(n, 3);
        assert_eq!(&buf[..3], b"err");
        assert!(q.stderr.is_empty());
        assert_eq!(q.data.iter().copied().collect::<Vec<_>>(), b"std");
    }
    // Data events append in arrival order and are keyed by their own channel.
    #[test]
    fn stash_event_data_appends_to_right_channel() {
        let mut queues: BTreeMap<u32, ChannelQueue> = BTreeMap::new();
        stash_event(
            &mut queues,
            ChannelEvent::Data {
                channel: 7,
                data: b"abc".to_vec(),
            },
        );
        stash_event(
            &mut queues,
            ChannelEvent::Data {
                channel: 7,
                data: b"def".to_vec(),
            },
        );
        stash_event(
            &mut queues,
            ChannelEvent::Data {
                channel: 9,
                data: b"x".to_vec(),
            },
        );
        assert_eq!(
            queues[&7].data.iter().copied().collect::<Vec<_>>(),
            b"abcdef"
        );
        assert_eq!(queues[&9].data.iter().copied().collect::<Vec<_>>(), b"x");
    }
    // EOF sets only `remote_eof`; CLOSE sets both flags.
    #[test]
    fn stash_event_eof_and_close_set_flags() {
        let mut queues: BTreeMap<u32, ChannelQueue> = BTreeMap::new();
        stash_event(&mut queues, ChannelEvent::Eof { channel: 3 });
        assert!(queues[&3].remote_eof);
        assert!(!queues[&3].remote_close);
        stash_event(&mut queues, ChannelEvent::Close { channel: 3 });
        assert!(queues[&3].remote_eof);
        assert!(queues[&3].remote_close);
    }
    // Mirrors `notifier_for`'s entry logic: same channel gives the same Arc,
    // different channels give different Arcs.
    #[test]
    fn notifier_map_round_trips_arc_identity() {
        let mut notifiers: BTreeMap<u32, Arc<Condvar>> = BTreeMap::new();
        let cv1 = notifiers
            .entry(7)
            .or_insert_with(|| Arc::new(Condvar::new()))
            .clone();
        let cv2 = notifiers
            .entry(7)
            .or_insert_with(|| Arc::new(Condvar::new()))
            .clone();
        assert!(Arc::ptr_eq(&cv1, &cv2));
        let cv3 = notifiers
            .entry(9)
            .or_insert_with(|| Arc::new(Condvar::new()))
            .clone();
        assert!(!Arc::ptr_eq(&cv1, &cv3));
    }
    // Events that carry no queued state must not create queue entries.
    #[test]
    fn stash_event_ignores_irrelevant() {
        let mut queues: BTreeMap<u32, ChannelQueue> = BTreeMap::new();
        stash_event(&mut queues, ChannelEvent::OpenConfirmed { channel: 1 });
        stash_event(
            &mut queues,
            ChannelEvent::OpenFailed {
                channel: 1,
                reason: 0,
                description: String::new(),
            },
        );
        stash_event(
            &mut queues,
            ChannelEvent::WindowAdjust {
                channel: 1,
                added: 100,
            },
        );
        assert!(queues.is_empty());
    }
}