#![cfg(feature = "std")]
use std::collections::BTreeMap;
use std::io::{ErrorKind, Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::mpsc::{self, Receiver, Sender, TryRecvError};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use purecrypto::rng::{CryptoRng, OsRng, RngCore};
use crate::auth::{Authenticator, ServerAuth, ServerStep};
use crate::channel::{
ChannelEvent, ChannelOpen, ChannelRequest, ConnectionState, SSH_EXTENDED_DATA_STDERR,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
};
use crate::error::{Error, Result};
use crate::format::Writer;
use crate::hostkey::HostKey;
use crate::transport::kex::{defaults, KexAlgorithms};
use crate::transport::rekey::{is_kex_msg, RekeyPolicy};
use crate::transport::{KexInit, KexRunner, PacketCodec, Role, VersionExchange};
/// Maximum length of one identification/banner line accepted from the peer
/// (presumably enforced by `read_peer_version`, defined outside this view).
const MAX_BANNER_LINE: usize = 1024;
/// Maximum number of lines tolerated before the peer's identification string
/// (presumably enforced by `read_peer_version`; confirm).
const MAX_BANNER_LINES: usize = 256;
/// Upper bound for buffered inbound bytes (NOTE(review): not referenced in the
/// visible part of this file; confirm where it is enforced).
const MAX_INBOX_BYTES: usize = 8 * 1024 * 1024;
/// Packets the initial key exchange may consume before it is abandoned.
const MAX_KEX_STEPS: usize = 32;
/// Packets the user-authentication phase may consume before it is abandoned.
const MAX_AUTH_STEPS: usize = 64;
/// Hard cap on connection-loop iterations for one connection. Idle 50 ms poll
/// timeouts count as iterations too.
const MAX_CONNECTION_STEPS: usize = 10_000_000;
/// NOTE(review): not referenced in the visible part of this file; confirm use.
const MAX_DRAIN_STEPS: usize = 1_000_000;
/// Capacity of the bounded handler-to-connection queue of each stream-backed
/// channel.
const SUBSYSTEM_EGRESS_BACKLOG: usize = 32;
/// RFC 4253 §11.1 `SSH_DISCONNECT_BY_APPLICATION`.
const SSH_DISCONNECT_BY_APPLICATION: u32 = 11;
/// RFC 4253 §11.1 `SSH_DISCONNECT_HOST_NOT_ALLOWED_TO_CONNECT`.
///
/// Must be `1`: code `9` is `SSH_DISCONNECT_HOST_KEY_NOT_VERIFIABLE` and would
/// make clients report a host-key problem on an authentication disconnect.
const SSH_DISCONNECT_HOST_NOT_ALLOWED: u32 = 1;
/// Fully buffered outcome of a command run by a [`CommandHandler`].
#[derive(Debug, Clone)]
pub struct ExecResult {
    /// Bytes the command wrote to standard output.
    pub stdout: Vec<u8>,
    /// Bytes the command wrote to standard error (presumably delivered as
    /// `SSH_EXTENDED_DATA_STDERR` extended data; the sender is outside this view).
    pub stderr: Vec<u8>,
    /// Exit status to report for the command.
    pub exit_status: u32,
}
/// Environment variables collected for one session channel.
///
/// Backed by a `BTreeMap`, so iteration yields variables in ascending key order.
#[derive(Debug, Default, Clone)]
pub struct SessionEnv {
    vars: BTreeMap<String, String>,
}
impl SessionEnv {
    /// Returns an environment with no variables.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `key` to `value`, returning the value it replaced (if any).
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) -> Option<String> {
        let name: String = key.into();
        let val: String = value.into();
        self.vars.insert(name, val)
    }

    /// Looks up the value of `key`.
    pub fn get(&self, key: &str) -> Option<&str> {
        self.vars.get(key).map(String::as_str)
    }

    /// Iterates over `(name, value)` pairs in ascending name order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.vars
            .iter()
            .map(|(name, val)| (name.as_str(), val.as_str()))
    }

    /// Number of variables set.
    pub fn len(&self) -> usize {
        self.vars.len()
    }

    /// `true` when no variable is set.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Runs a command to completion and returns its buffered output.
///
/// Shared across connection threads (`Send + Sync`).
pub trait CommandHandler: Send + Sync {
    /// Executes `command` for the authenticated `user` with the session's `env`.
    fn handle(&self, user: &str, env: &SessionEnv, command: &str) -> ExecResult;
}
/// Connection-loop bookkeeping for one shell channel.
struct ShellRuntime {
    // PTY parameters held until the shell is spawned (presumably filled by a
    // `pty-req` handled outside this view; confirm).
    pending_pty: Option<PtySpec>,
    // The live shell; cleared once exit has been reported to the client.
    session: Option<Box<dyn ShellSession>>,
    // Exit status observed via `try_exit` but not yet reported.
    exited: Option<ShellExitStatus>,
    // Set once exit-status/signal, EOF and CLOSE have been sent.
    exit_sent: bool,
    // Shell output the peer's window could not yet accept.
    pending_stdout: Vec<u8>,
}
impl ShellRuntime {
    /// A runtime with no shell, no pending PTY and nothing buffered.
    fn new() -> Self {
        Self {
            session: None,
            pending_pty: None,
            exited: None,
            pending_stdout: Vec::default(),
            exit_sent: false,
        }
    }
}
/// Terminal parameters requested by the client (presumably the fields of an
/// RFC 4254 `pty-req`; the parser is outside this view).
#[derive(Debug, Clone)]
pub struct PtySpec {
    /// Terminal type name (the `TERM` value).
    pub term: String,
    /// Width in character columns.
    pub cols: u32,
    /// Height in character rows.
    pub rows: u32,
    /// Width in pixels.
    pub px_w: u32,
    /// Height in pixels.
    pub px_h: u32,
    /// Encoded terminal modes, passed through unparsed.
    pub modes: Vec<u8>,
}
/// How a shell terminated; reported to the client by `finalize_exited_shells`.
#[derive(Debug, Clone)]
pub enum ShellExitStatus {
    /// Normal exit, sent as an `exit-status` request with this code.
    Exited(u32),
    /// Killed by a signal, sent as an `exit-signal` request (with an empty
    /// language tag).
    Signalled {
        name: String,
        core_dumped: bool,
        message: String,
    },
}
/// Spawns interactive shells for session channels.
pub trait ShellHandler: Send + Sync {
    /// Starts a shell for `user` with the session's `env`, attached to a PTY
    /// when `pty` is `Some`. The caller is outside this view.
    fn spawn(
        &self,
        user: &str,
        env: &SessionEnv,
        pty: Option<PtySpec>,
    ) -> Result<Box<dyn ShellSession>>;
}
/// A running shell as driven by the connection loop.
///
/// Every method is called synchronously from the connection loop, so an
/// implementation that blocks stalls the whole connection.
pub trait ShellSession: Send {
    /// Reads available output into `buf`. `Ok(0)` ends this shell's drain pass;
    /// an error aborts the connection.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize>;
    /// Writes client input. `Ok(0)` is retried. After five zero-length writes
    /// for one data packet the rest of that packet is dropped. An error aborts
    /// the connection.
    fn write(&mut self, buf: &[u8]) -> Result<usize>;
    /// Called when the client sends EOF on the channel; errors are ignored.
    fn close_stdin(&mut self) -> Result<()>;
    /// Applies a new terminal size (caller outside this view; presumably a
    /// `window-change` request).
    fn resize(&mut self, cols: u32, rows: u32, px_w: u32, px_h: u32) -> Result<()>;
    /// Non-blocking exit check. A `Some` status is reported once all buffered
    /// output has been delivered.
    fn try_exit(&mut self) -> Option<ShellExitStatus>;
}
pub use crate::stream::{ChannelEgress, ChannelStream};
/// Connection-loop side of a channel bridged to a handler through a `ChannelStream`.
///
/// In the visible code this backs direct-tcpip channels and channels opened
/// for forwarded-tcpip, auth-agent and X11 requests. It is presumably also
/// used for `subsystem` requests handled elsewhere.
struct SubsystemRuntime {
    // Peer -> handler: `Some(bytes)` for channel data, `None` after peer EOF.
    ingress_tx: Sender<Option<Vec<u8>>>,
    // Handler -> peer: data, EOF and CLOSE requests.
    egress_rx: Receiver<ChannelEgress>,
    // Handler output the peer's window could not yet accept.
    pending_data: Vec<u8>,
    // Handler asked for EOF; sent once `pending_data` is empty.
    pending_eof: bool,
    // Handler asked for CLOSE (or went away); sent once `pending_data` is empty.
    pending_close: bool,
    eof_sent: bool,
    close_sent: bool,
}
/// Hook run with the authenticated user name before the connection phase;
/// returning an error closes the connection.
pub type SessionOpenCallback = Arc<dyn Fn(&str) -> Result<()> + Send + Sync>;
/// Serves named subsystems (for example `sftp`) over a byte stream.
pub trait SubsystemHandler: Send + Sync {
    /// Serves subsystem `name` for `user` over `stream`. The caller is outside
    /// this view.
    fn handle(&self, user: &str, env: &SessionEnv, name: &str, stream: ChannelStream)
    -> Result<()>;
}
/// Streaming alternative to [`CommandHandler`] for selected commands.
pub trait ExecStreamHandler: Send + Sync {
    /// Whether this handler wants to run `command`. Presumably consulted
    /// before falling back to `CommandHandler`; the dispatch is outside this view.
    fn claims(&self, command: &str) -> bool;
    /// Runs `command` for `user`, talking to the client through `stream`.
    fn run(&self, user: &str, env: &SessionEnv, command: &str, stream: ChannelStream)
    -> Result<()>;
}
/// Parameters of a client `direct-tcpip` channel open (local port forwarding).
#[derive(Debug, Clone, Copy)]
pub struct DirectTcpipRequest<'a> {
    /// Host the client wants to reach.
    pub dest_host: &'a str,
    /// Port the client wants to reach.
    pub dest_port: u32,
    /// Originator address as reported by the client.
    pub orig_host: &'a str,
    /// Originator port as reported by the client.
    pub orig_port: u32,
}
/// Services accepted `direct-tcpip` channels.
pub trait DirectTcpipHandler: Send + Sync {
    /// Invoked on a dedicated thread per channel; the returned `Result` is
    /// ignored. Returning (dropping `stream`) leads the connection loop to
    /// close the channel.
    fn handle(
        &self,
        user: &str,
        request: DirectTcpipRequest<'_>,
        stream: ChannelStream,
    ) -> Result<()>;
}
/// Handle given to a [`TcpipForwardHandler`] for opening `forwarded-tcpip`
/// channels back to the client.
///
/// Requests are queued to the connection loop, so clones may be moved to
/// other threads.
#[derive(Clone)]
pub struct ForwardContext {
    req_tx: Sender<ForwardOpenRequest>,
}
/// Queued request to open one `forwarded-tcpip` channel.
pub(crate) struct ForwardOpenRequest {
    // Address/port the remote forward is bound to (sent as the channel's
    // "connected" address/port).
    bound_address: String,
    bound_port: u32,
    // Originator of the accepted TCP connection.
    orig_address: String,
    orig_port: u32,
    // Receives the stream once the peer confirms, or an error.
    reply: std::sync::mpsc::SyncSender<Result<ChannelStream>>,
}
impl ForwardContext {
pub(crate) fn new(req_tx: Sender<ForwardOpenRequest>) -> Self {
Self { req_tx }
}
#[doc(hidden)]
pub fn for_test_no_opens() -> Self {
let (tx, _rx) = mpsc::channel();
drop(_rx);
Self { req_tx: tx }
}
pub fn open_forwarded_tcpip(
&self,
bound_address: &str,
bound_port: u16,
orig_address: &str,
orig_port: u16,
) -> Result<ChannelStream> {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
self.req_tx
.send(ForwardOpenRequest {
bound_address: bound_address.to_string(),
bound_port: bound_port as u32,
orig_address: orig_address.to_string(),
orig_port: orig_port as u32,
reply: tx,
})
.map_err(|_| Error::Protocol("forwarded-tcpip: connection closed"))?;
rx.recv()
.map_err(|_| Error::Protocol("forwarded-tcpip: reply dropped"))?
}
}
/// Implements remote port forwarding (`tcpip-forward` global requests).
pub trait TcpipForwardHandler: Send + Sync {
    /// Starts listening on `bind_address:bind_port` for `user` and returns the
    /// port actually bound. An error makes the global request fail. Incoming
    /// connections are sent to the client through `ctx`.
    fn bind(
        &self,
        user: &str,
        bind_address: &str,
        bind_port: u16,
        ctx: ForwardContext,
    ) -> Result<u16>;
    /// Stops a listener. Also called during connection teardown for every
    /// binding still owned; errors are ignored there.
    fn unbind(&self, user: &str, bind_address: &str, bind_port: u16) -> Result<()>;
}
/// Handle given to an [`AgentForwardHandler`] for opening
/// `auth-agent@openssh.com` channels back to the client.
#[derive(Clone)]
pub struct AgentForwardContext {
    req_tx: Sender<AgentOpenRequest>,
}
/// Queued request to open one auth-agent channel.
pub(crate) struct AgentOpenRequest {
    // Receives the stream once the peer confirms, or an error.
    reply: std::sync::mpsc::SyncSender<Result<ChannelStream>>,
}
impl AgentForwardContext {
pub(crate) fn new(req_tx: Sender<AgentOpenRequest>) -> Self {
Self { req_tx }
}
#[doc(hidden)]
pub fn for_test_no_opens() -> Self {
let (tx, _rx) = mpsc::channel();
drop(_rx);
Self { req_tx: tx }
}
pub fn open_auth_agent(&self) -> Result<ChannelStream> {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
self.req_tx
.send(AgentOpenRequest { reply: tx })
.map_err(|_| Error::Protocol("auth-agent: connection closed"))?;
rx.recv()
.map_err(|_| Error::Protocol("auth-agent: reply dropped"))?
}
}
/// Result of setting up agent forwarding for a session.
pub struct AgentForwardHandle {
    /// Path of the local agent socket (presumably exported to the session as
    /// `SSH_AUTH_SOCK`; confirm with the caller).
    pub auth_sock_path: std::path::PathBuf,
    /// Opaque guard kept alive while forwarding is active. It is dropped when
    /// its channel closes or the connection is torn down.
    pub stopper: Box<dyn core::any::Any + Send + Sync>,
}
/// Sets up SSH agent forwarding (caller outside this view; presumably on an
/// `auth-agent-req@openssh.com` channel request).
pub trait AgentForwardHandler: Send + Sync {
    /// Prepares forwarding for `user`; agent connections are opened via `ctx`.
    fn setup(&self, user: &str, ctx: AgentForwardContext) -> Result<AgentForwardHandle>;
}
/// Handle given to an [`X11ForwardHandler`] for opening `x11` channels back to
/// the client.
#[derive(Clone)]
pub struct X11ForwardContext {
    req_tx: Sender<X11OpenRequest>,
}
/// Queued request to open one `x11` channel.
pub(crate) struct X11OpenRequest {
    /// Originator address of the local X11 connection.
    pub orig_host: String,
    /// Originator port of the local X11 connection.
    pub orig_port: u32,
    /// Receives the stream once the peer confirms, or an error.
    pub reply: std::sync::mpsc::SyncSender<Result<ChannelStream>>,
}
impl X11ForwardContext {
pub(crate) fn new(req_tx: Sender<X11OpenRequest>) -> Self {
Self { req_tx }
}
#[doc(hidden)]
pub fn for_test_no_opens() -> Self {
let (tx, _rx) = mpsc::channel();
drop(_rx);
Self { req_tx: tx }
}
pub fn open_x11(&self, orig_host: String, orig_port: u32) -> Result<ChannelStream> {
let (tx, rx) = std::sync::mpsc::sync_channel(1);
self.req_tx
.send(X11OpenRequest {
orig_host,
orig_port,
reply: tx,
})
.map_err(|_| Error::Protocol("x11: connection closed"))?;
rx.recv()
.map_err(|_| Error::Protocol("x11: reply dropped"))?
}
}
/// Result of setting up X11 forwarding for a session.
pub struct X11ForwardHandle {
    /// Display string for the session (presumably exported as `DISPLAY`; confirm).
    pub display_env: String,
    /// X11 display number that was allocated.
    pub display_number: u16,
    /// Opaque guard kept alive while forwarding is active. It is dropped when
    /// its channel closes or the connection is torn down.
    pub stopper: Box<dyn core::any::Any + Send + Sync>,
}
/// Sets up X11 forwarding (caller outside this view; presumably on an
/// `x11-req` channel request, whose fields these parameters mirror).
pub trait X11ForwardHandler: Send + Sync {
    /// Prepares forwarding for `user`; X11 connections are opened via `ctx`.
    fn setup(
        &self,
        user: &str,
        single_connection: bool,
        auth_protocol: &str,
        auth_cookie: &str,
        screen: u32,
        ctx: X11ForwardContext,
    ) -> Result<X11ForwardHandle>;
}
/// Server configuration, shared behind an `Arc` by every connection.
pub struct Config {
    /// Host keys offered in key exchange; `Server::bind` rejects an empty list.
    pub host_keys: Vec<Box<dyn HostKey + Send + Sync>>,
    /// Builds a fresh `Authenticator` for each connection.
    pub authenticator: Arc<dyn AuthenticatorFactory>,
    /// Authentication method names handed to `ServerAuth`.
    pub allowed_auth_methods: Vec<&'static str>,
    /// Handler for buffered command execution.
    pub command_handler: Arc<dyn CommandHandler>,
    /// Optional streaming command handler.
    pub exec_stream_handler: Option<Arc<dyn ExecStreamHandler>>,
    /// Optional interactive-shell handler.
    pub shell_handler: Option<Arc<dyn ShellHandler>>,
    /// Optional subsystem handler.
    pub subsystem_handler: Option<Arc<dyn SubsystemHandler>>,
    /// Enables `direct-tcpip` channels when set; otherwise they are rejected.
    pub direct_tcpip_handler: Option<Arc<dyn DirectTcpipHandler>>,
    /// Enables remote forwarding (`tcpip-forward`) when set.
    pub tcpip_forward_handler: Option<Arc<dyn TcpipForwardHandler>>,
    /// Optional agent-forwarding handler.
    pub agent_forward_handler: Option<Arc<dyn AgentForwardHandler>>,
    /// Optional X11-forwarding handler.
    pub x11_forward_handler: Option<Arc<dyn X11ForwardHandler>>,
    /// Run after authentication with the user name; an error ends the connection.
    pub on_session_open: Option<SessionOpenCallback>,
    /// Decides when the server starts a rekey.
    pub rekey_policy: RekeyPolicy,
}
impl Config {
pub fn new(
host_keys: Vec<Box<dyn HostKey + Send + Sync>>,
authenticator: Arc<dyn AuthenticatorFactory>,
allowed_auth_methods: Vec<&'static str>,
command_handler: Arc<dyn CommandHandler>,
) -> Self {
Self {
host_keys,
authenticator,
allowed_auth_methods,
command_handler,
exec_stream_handler: None,
shell_handler: None,
subsystem_handler: None,
direct_tcpip_handler: None,
tcpip_forward_handler: None,
agent_forward_handler: None,
x11_forward_handler: None,
on_session_open: None,
rekey_policy: RekeyPolicy::default(),
}
}
pub fn with_shell(mut self, handler: Arc<dyn ShellHandler>) -> Self {
self.shell_handler = Some(handler);
self
}
pub fn with_subsystem(mut self, handler: Arc<dyn SubsystemHandler>) -> Self {
self.subsystem_handler = Some(handler);
self
}
pub fn with_exec_stream_handler(mut self, handler: Arc<dyn ExecStreamHandler>) -> Self {
self.exec_stream_handler = Some(handler);
self
}
pub fn with_direct_tcpip(mut self, handler: Arc<dyn DirectTcpipHandler>) -> Self {
self.direct_tcpip_handler = Some(handler);
self
}
pub fn with_tcpip_forward(mut self, handler: Arc<dyn TcpipForwardHandler>) -> Self {
self.tcpip_forward_handler = Some(handler);
self
}
pub fn with_agent_forward(mut self, handler: Arc<dyn AgentForwardHandler>) -> Self {
self.agent_forward_handler = Some(handler);
self
}
pub fn with_x11_forward(mut self, handler: Arc<dyn X11ForwardHandler>) -> Self {
self.x11_forward_handler = Some(handler);
self
}
pub fn on_session_open<F>(mut self, f: F) -> Self
where
F: Fn(&str) -> Result<()> + Send + Sync + 'static,
{
self.on_session_open = Some(Arc::new(f));
self
}
}
/// Produces a fresh [`Authenticator`] for each connection (called once per
/// connection by `do_server_auth`).
pub trait AuthenticatorFactory: Send + Sync {
    fn build(&self) -> Box<dyn Authenticator>;
}
/// Any thread-safe closure returning a boxed authenticator is a factory.
impl<F> AuthenticatorFactory for F
where
    F: Fn() -> Box<dyn Authenticator> + Send + Sync,
{
    /// Calls the closure.
    fn build(&self) -> Box<dyn Authenticator> {
        let make = self;
        make()
    }
}
/// A TCP listener paired with the shared configuration for every connection it accepts.
pub struct Server {
    listener: TcpListener,
    cfg: Arc<Config>,
}
impl Server {
pub fn bind<A: ToSocketAddrs>(addr: A, cfg: Config) -> Result<Self> {
if cfg.host_keys.is_empty() {
return Err(Error::Protocol("server: no host keys configured"));
}
let listener = TcpListener::bind(addr)?;
Ok(Self {
listener,
cfg: Arc::new(cfg),
})
}
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.listener.local_addr()?)
}
pub fn accept_one(&mut self) -> Result<()> {
let (stream, _peer) = self.listener.accept()?;
handle_session(stream, self.cfg.clone())
}
pub fn serve(&mut self) -> Result<()> {
loop {
let (stream, _peer) = self.listener.accept()?;
let cfg = self.cfg.clone();
thread::spawn(move || {
let _ = handle_session(stream, cfg);
});
}
}
}
/// Serves one already-accepted connection to completion on the calling thread.
///
/// For callers that run their own accept loop. `Server::serve` calls this on a
/// fresh thread for each connection.
pub fn handle_session(stream: TcpStream, cfg: Arc<Config>) -> Result<()> {
    handle_connection_inner(stream, cfg)
}
/// Runs one SSH connection end to end.
///
/// Phases, in order:
/// 1. version exchange;
/// 2. initial key exchange;
/// 3. user authentication;
/// 4. the optional `on_session_open` hook;
/// 5. the connection (channel) phase.
///
/// After the connection phase returns (successfully or not), a best-effort
/// disconnect with `SSH_DISCONNECT_BY_APPLICATION` is sent and the phase's
/// result is returned. Errors in earlier phases, or from the hook, return
/// immediately without that final disconnect.
fn handle_connection_inner(mut stream: TcpStream, cfg: Arc<Config>) -> Result<()> {
    stream.set_nodelay(true)?;
    let mut codec = PacketCodec::new();
    // Inbound bytes read but not yet consumed by `read_one_packet`
    // (presumably; the helper is outside this view).
    let mut inbox: Vec<u8> = Vec::new();
    let mut rng = OsRng;
    // Our identification string (V_S) and, below, the peer's (V_C). Both are
    // passed to the KEX runner for every (re)key exchange.
    let v_s = crate::transport::version::LOCAL_VERSION.as_bytes().to_vec();
    stream.write_all(&VersionExchange::outgoing_bytes())?;
    let v_c = read_peer_version(&mut stream)?;
    let (mut runner, session_id) = do_server_kex(
        &mut stream,
        &mut codec,
        &mut rng,
        &mut inbox,
        &cfg,
        &v_c,
        &v_s,
    )?;
    // Reference point for the time-based rekey policy.
    let mut last_kex = Instant::now();
    let user = do_server_auth(
        &mut stream,
        &mut codec,
        &mut rng,
        &mut inbox,
        &cfg,
        session_id,
    )?;
    if let Some(hook) = cfg.on_session_open.clone() {
        hook(&user)?;
    }
    // Compression is only switched on after authentication (presumably
    // "delayed" compression as in zlib@openssh.com; confirm in PacketCodec).
    codec.activate_compress();
    let rekey_policy = cfg.rekey_policy;
    let r = do_connection_phase(
        &mut stream,
        &mut codec,
        &mut rng,
        &mut inbox,
        &cfg,
        &user,
        &mut runner,
        &v_c,
        &v_s,
        &mut last_kex,
        &rekey_policy,
    );
    // Best effort: the socket may already be dead if `r` is an I/O error.
    let _ = send_disconnect(
        &mut stream,
        &mut codec,
        &mut rng,
        SSH_DISCONNECT_BY_APPLICATION,
        "closing session",
    );
    r
}
fn do_server_kex<R: RngCore + CryptoRng>(
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
inbox: &mut Vec<u8>,
cfg: &Config,
v_c: &[u8],
v_s: &[u8],
) -> Result<(KexRunner, Vec<u8>)> {
let advert = build_server_kexinit(rng, &cfg.host_keys);
let mut runner = KexRunner::new(Role::Server, advert);
let initial = runner.start(rng)?;
for p in initial.outbound {
write_payload(stream, codec, rng, &p)?;
}
drive_server_kex(stream, codec, rng, inbox, &mut runner, cfg, v_c, v_s)?;
let sid = runner
.session_id()
.ok_or(Error::Protocol("kex: missing session id"))?
.to_vec();
Ok((runner, sid))
}
#[allow(clippy::too_many_arguments)]
fn drive_server_kex<R: RngCore + CryptoRng>(
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
inbox: &mut Vec<u8>,
runner: &mut KexRunner,
cfg: &Config,
v_c: &[u8],
v_s: &[u8],
) -> Result<()> {
let mut steps = 0usize;
let mut selected_host_key: Option<&(dyn HostKey + Send + Sync)> = None;
loop {
steps += 1;
if steps > MAX_KEX_STEPS {
return Err(Error::Protocol("kex: too many steps"));
}
let payload = read_one_packet(stream, codec, inbox)?;
if selected_host_key.is_none() {
if let Some(neg) = runner.negotiated() {
selected_host_key = pick_host_key(&cfg.host_keys, &neg.host_key);
if selected_host_key.is_none() {
return Err(Error::Protocol("kex: no host key for negotiated algorithm"));
}
}
}
let hk_ref: Option<&dyn HostKey> = selected_host_key.map(|k| k as &dyn HostKey);
let adv = runner.on_packet(rng, codec, &payload, hk_ref, None, v_c, v_s)?;
for p in adv.outbound {
write_payload(stream, codec, rng, &p)?;
}
if adv.completed {
return Ok(());
}
}
}
/// Runs the user-authentication service and returns the authenticated user name.
///
/// Uses a fresh authenticator from `cfg.authenticator` and the configured
/// method list.
///
/// # Errors
/// * `Error::AuthFailed` after sending a disconnect when the authenticator
///   gives up.
/// * `Error::Protocol` if more than `MAX_AUTH_STEPS` packets are exchanged.
/// * Transport errors are propagated.
fn do_server_auth<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    inbox: &mut Vec<u8>,
    cfg: &Config,
    session_id: Vec<u8>,
) -> Result<String> {
    let mut auth = ServerAuth::new(
        session_id,
        cfg.allowed_auth_methods.clone(),
        cfg.authenticator.build(),
    );
    for _step in 1..=MAX_AUTH_STEPS {
        let packet = read_one_packet(stream, codec, inbox)?;
        let step = auth.on_packet(&packet)?;
        match step {
            ServerStep::Send(reply) => {
                write_payload(stream, codec, rng, &reply)?;
            }
            ServerStep::Authenticated { payload, user } => {
                write_payload(stream, codec, rng, &payload)?;
                return Ok(user);
            }
            ServerStep::Disconnect(reason) => {
                // Best effort; the failure is reported either way.
                let _ =
                    send_disconnect(stream, codec, rng, SSH_DISCONNECT_HOST_NOT_ALLOWED, reason);
                return Err(Error::AuthFailed);
            }
        }
    }
    Err(Error::Protocol("auth: too many steps"))
}
/// Per-connection state for remote port forwarding.
struct ForwardConn {
    // Cloned into every `ForwardContext` handed to the forward handler.
    req_tx: Sender<ForwardOpenRequest>,
    // Drained by `drain_pending` from the connection loop.
    req_rx: Receiver<ForwardOpenRequest>,
    // Channels we opened, keyed by local id, awaiting the peer's confirmation.
    pending_opens: BTreeMap<u32, std::sync::mpsc::SyncSender<Result<ChannelStream>>>,
    // Successful binds, unbound on teardown.
    owned_bindings: Vec<(String, u16)>,
}
impl ForwardConn {
fn new() -> Self {
let (req_tx, req_rx) = std::sync::mpsc::channel();
Self {
req_tx,
req_rx,
pending_opens: BTreeMap::new(),
owned_bindings: Vec::new(),
}
}
fn drain_pending<R: RngCore + CryptoRng>(
&mut self,
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
conn: &mut ConnectionState,
) -> Result<()> {
loop {
match self.req_rx.try_recv() {
Ok(req) => {
let kind = ChannelOpen::ForwardedTcpip {
dest_host: req.bound_address.clone(),
dest_port: req.bound_port,
orig_host: req.orig_address.clone(),
orig_port: req.orig_port,
};
let (local_id, payload) = conn.open(kind)?;
write_payload(stream, codec, rng, &payload)?;
self.pending_opens.insert(local_id, req.reply);
}
Err(TryRecvError::Empty) => return Ok(()),
Err(TryRecvError::Disconnected) => return Ok(()),
}
}
}
}
/// Per-connection state for agent forwarding.
struct AgentForwardConn {
    // Cloned into each `AgentForwardContext` (outside this view).
    req_tx: Sender<AgentOpenRequest>,
    req_rx: Receiver<AgentOpenRequest>,
    // Agent channels we opened, awaiting the peer's confirmation.
    pending_opens: BTreeMap<u32, std::sync::mpsc::SyncSender<Result<ChannelStream>>>,
    // Active forwarding handles keyed by channel. Removed on CLOSE and cleared
    // on teardown; while non-empty, the loop keeps polling.
    active: BTreeMap<u32, AgentForwardHandle>,
}
impl AgentForwardConn {
fn new() -> Self {
let (req_tx, req_rx) = std::sync::mpsc::channel();
Self {
req_tx,
req_rx,
pending_opens: BTreeMap::new(),
active: BTreeMap::new(),
}
}
fn drain_pending<R: RngCore + CryptoRng>(
&mut self,
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
conn: &mut ConnectionState,
) -> Result<()> {
loop {
match self.req_rx.try_recv() {
Ok(req) => {
let (local_id, payload) = conn.open(ChannelOpen::AuthAgent)?;
write_payload(stream, codec, rng, &payload)?;
self.pending_opens.insert(local_id, req.reply);
}
Err(TryRecvError::Empty) => return Ok(()),
Err(TryRecvError::Disconnected) => return Ok(()),
}
}
}
}
/// Per-connection state for X11 forwarding.
struct X11ForwardConn {
    // Cloned into each `X11ForwardContext` (outside this view).
    req_tx: Sender<X11OpenRequest>,
    req_rx: Receiver<X11OpenRequest>,
    // X11 channels we opened, awaiting the peer's confirmation.
    pending_opens: BTreeMap<u32, std::sync::mpsc::SyncSender<Result<ChannelStream>>>,
    // Active forwarding handles keyed by channel. Removed on CLOSE and cleared
    // on teardown; while non-empty, the loop keeps polling.
    active: BTreeMap<u32, X11ForwardHandle>,
}
impl X11ForwardConn {
fn new() -> Self {
let (req_tx, req_rx) = std::sync::mpsc::channel();
Self {
req_tx,
req_rx,
pending_opens: BTreeMap::new(),
active: BTreeMap::new(),
}
}
fn drain_pending<R: RngCore + CryptoRng>(
&mut self,
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
conn: &mut ConnectionState,
) -> Result<()> {
loop {
match self.req_rx.try_recv() {
Ok(req) => {
let kind = ChannelOpen::X11 {
orig_host: req.orig_host.clone(),
orig_port: req.orig_port,
};
let (local_id, payload) = conn.open(kind)?;
write_payload(stream, codec, rng, &payload)?;
self.pending_opens.insert(local_id, req.reply);
}
Err(TryRecvError::Empty) => return Ok(()),
Err(TryRecvError::Disconnected) => return Ok(()),
}
}
}
}
// Transport state is passed piecemeal rather than bundled, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
/// Runs the connection phase, then tears down everything it created.
///
/// Per-connection state is owned here rather than in `do_connection_loop`, so
/// cleanup runs whether the loop returned `Ok` or an error:
/// * every remote-forward binding still owned is unbound through the
///   `TcpipForwardHandler` (errors ignored);
/// * every forwarded-tcpip / auth-agent / X11 open still waiting for the peer
///   is failed with a "connection torn down" error;
/// * active agent and X11 forwarding handles are dropped.
///
/// Returns the loop's result unchanged.
fn do_connection_phase<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    inbox: &mut Vec<u8>,
    cfg: &Config,
    user: &str,
    runner: &mut KexRunner,
    v_c: &[u8],
    v_s: &[u8],
    last_kex: &mut Instant,
    rekey_policy: &RekeyPolicy,
) -> Result<()> {
    let mut conn = ConnectionState::new();
    // Becomes true once a session channel has been opened; gates loop exit.
    let mut any_channel_opened = false;
    let mut steps = 0usize;
    // Non-KEX packets received while a key exchange was in progress.
    let mut deferred: Vec<Vec<u8>> = Vec::new();
    let mut shells: BTreeMap<u32, ShellRuntime> = BTreeMap::new();
    let mut subsystems: BTreeMap<u32, SubsystemRuntime> = BTreeMap::new();
    let mut envs: BTreeMap<u32, SessionEnv> = BTreeMap::new();
    let mut forward = ForwardConn::new();
    let mut agent_forward = AgentForwardConn::new();
    let mut x11_forward = X11ForwardConn::new();
    // Whether the socket currently has the 50 ms polling read timeout set.
    let mut polling_active = false;
    let result = do_connection_loop(
        stream,
        codec,
        rng,
        inbox,
        cfg,
        user,
        runner,
        v_c,
        v_s,
        last_kex,
        rekey_policy,
        &mut conn,
        &mut any_channel_opened,
        &mut steps,
        &mut deferred,
        &mut shells,
        &mut subsystems,
        &mut envs,
        &mut forward,
        &mut agent_forward,
        &mut x11_forward,
        &mut polling_active,
    );
    // Teardown, regardless of `result`.
    if let Some(handler) = cfg.tcpip_forward_handler.clone() {
        for (addr, port) in forward.owned_bindings.drain(..) {
            let _ = handler.unbind(user, &addr, port);
        }
    }
    // Wake every thread blocked in `open_*` waiting for a peer answer.
    for (_id, reply) in forward.pending_opens.drain_filter_compat() {
        let _ = reply.send(Err(Error::Protocol(
            "forwarded-tcpip: connection torn down",
        )));
    }
    agent_forward.active.clear();
    for (_id, reply) in agent_forward.pending_opens.drain_filter_compat() {
        let _ = reply.send(Err(Error::Protocol("auth-agent: connection torn down")));
    }
    x11_forward.active.clear();
    for (_id, reply) in x11_forward.pending_opens.drain_filter_compat() {
        let _ = reply.send(Err(Error::Protocol("x11: connection torn down")));
    }
    result
}
/// Removes every entry from a map and yields them in ascending key order.
///
/// Stable stand-in for an always-true `drain_filter`/`extract_if`; the map is
/// empty afterwards.
trait DrainFilterCompat<K, V> {
    fn drain_filter_compat(&mut self) -> alloc::vec::IntoIter<(K, V)>;
}
impl<K: Ord, V> DrainFilterCompat<K, V> for BTreeMap<K, V> {
    fn drain_filter_compat(&mut self) -> alloc::vec::IntoIter<(K, V)> {
        // Move the whole map out (leaving an empty one behind) instead of
        // cloning every key and removing entries one at a time.
        core::mem::take(self)
            .into_iter()
            .collect::<Vec<_>>()
            .into_iter()
    }
}
// All connection state is owned by `do_connection_phase` and threaded through
// explicitly, hence the lint allowance.
#[allow(clippy::too_many_arguments)]
/// Main event loop of the connection phase.
///
/// Each iteration (at most `MAX_CONNECTION_STEPS` in total):
/// 1. If no key exchange is running, replays one packet that was deferred
///    during a KEX and starts over.
/// 2. Switches the socket between blocking reads and 50 ms polling reads. It
///    polls while any shell, stream-backed channel, remote-forward binding, or
///    agent/X11 forwarding needs periodic servicing.
/// 3. While polling and not mid-KEX, it:
///    * pumps shell and stream output;
///    * reports shell exits;
///    * services queued forwarded/agent/X11 open requests.
/// 4. Returns `Ok(())` when all of these hold:
///    * a session channel has been opened;
///    * every channel is fully closed;
///    * nothing is deferred;
///    * no forwarding is active.
/// 5. Starts a server-initiated rekey when `rekey_policy` asks for one.
/// 6. Reads the next packet and routes it:
///    * KEX messages go to `runner`, answering a peer-initiated KEXINIT (msg
///      20) with our own first;
///    * other packets are deferred while a KEX is running, otherwise passed to
///      `dispatch_app_packet`.
///
/// NOTE(review): `deferred` is unbounded and consumed with `Vec::remove(0)`
/// (O(n) each). A peer that keeps sending channel traffic during a rekey can
/// grow it without limit. Consider a byte cap (e.g. `MAX_INBOX_BYTES`) and a
/// `VecDeque`.
fn do_connection_loop<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    inbox: &mut Vec<u8>,
    cfg: &Config,
    user: &str,
    runner: &mut KexRunner,
    v_c: &[u8],
    v_s: &[u8],
    last_kex: &mut Instant,
    rekey_policy: &RekeyPolicy,
    conn: &mut ConnectionState,
    any_channel_opened: &mut bool,
    steps: &mut usize,
    deferred: &mut Vec<Vec<u8>>,
    shells: &mut BTreeMap<u32, ShellRuntime>,
    subsystems: &mut BTreeMap<u32, SubsystemRuntime>,
    envs: &mut BTreeMap<u32, SessionEnv>,
    forward: &mut ForwardConn,
    agent_forward: &mut AgentForwardConn,
    x11_forward: &mut X11ForwardConn,
    polling_active: &mut bool,
) -> Result<()> {
    loop {
        *steps += 1;
        if *steps > MAX_CONNECTION_STEPS {
            return Err(Error::Protocol("connection: step cap exceeded"));
        }
        // Replay traffic held back during a key exchange, oldest first and one
        // packet per iteration, before reading anything new.
        if !runner.is_kexing() && !deferred.is_empty() {
            let payload = deferred.remove(0);
            dispatch_app_packet(
                stream,
                codec,
                rng,
                inbox,
                conn,
                cfg,
                user,
                &payload,
                any_channel_opened,
                shells,
                subsystems,
                envs,
                forward,
                agent_forward,
                x11_forward,
            )?;
            continue;
        }
        // Anything that produces output or open requests locally forces
        // polling reads so it gets serviced even when the peer is quiet.
        let any_shell_alive = shells.values().any(|rt| rt.session.is_some());
        let any_subsystem_alive = !subsystems.is_empty();
        let any_forward_alive = !forward.owned_bindings.is_empty();
        let any_agent_fwd_alive = !agent_forward.active.is_empty();
        let any_x11_fwd_alive = !x11_forward.active.is_empty();
        let want_polling = any_shell_alive
            || any_subsystem_alive
            || any_forward_alive
            || any_agent_fwd_alive
            || any_x11_fwd_alive;
        // Change the socket timeout only on transitions; failures are ignored.
        if want_polling && !*polling_active {
            let _ = stream.set_read_timeout(Some(Duration::from_millis(50)));
            *polling_active = true;
        } else if !want_polling && *polling_active {
            let _ = stream.set_read_timeout(None);
            *polling_active = false;
        }
        // Channel traffic may not be sent mid-KEX, so local producers wait.
        if *polling_active && !runner.is_kexing() {
            drain_shells(stream, codec, rng, conn, shells)?;
            finalize_exited_shells(stream, codec, rng, conn, shells)?;
            drain_subsystems(stream, codec, rng, conn, subsystems)?;
            forward.drain_pending(stream, codec, rng, conn)?;
            agent_forward.drain_pending(stream, codec, rng, conn)?;
            x11_forward.drain_pending(stream, codec, rng, conn)?;
        }
        // Normal end of the connection phase.
        if *any_channel_opened
            && !conn.channels().any(|c| !c.is_fully_closed())
            && deferred.is_empty()
            && !any_forward_alive
            && !any_agent_fwd_alive
            && !any_x11_fwd_alive
        {
            return Ok(());
        }
        // Server-initiated rekey.
        if !runner.is_kexing() && rekey_policy.should_rekey(codec, *last_kex, Instant::now()) {
            let advert = build_server_kexinit(rng, &cfg.host_keys);
            let adv = runner.restart(rng, advert)?;
            for p in adv.outbound {
                write_payload(stream, codec, rng, &p)?;
            }
        }
        let payload = if *polling_active {
            match read_one_packet_maybe_timeout(stream, codec, inbox)? {
                Some(p) => p,
                // Poll timeout: go around again to service local producers.
                None => continue,
            }
        } else {
            read_one_packet(stream, codec, inbox)?
        };
        let msg = payload.first().copied().unwrap_or(0);
        if is_kex_msg(msg) {
            // Peer-initiated rekey (SSH_MSG_KEXINIT = 20): send our own
            // KEXINIT before processing theirs.
            if msg == 20 && !runner.is_kexing() {
                let advert = build_server_kexinit(rng, &cfg.host_keys);
                let adv = runner.restart(rng, advert)?;
                for p in adv.outbound {
                    write_payload(stream, codec, rng, &p)?;
                }
            }
            // Host key for the negotiated algorithm, once negotiation is known.
            // Unlike the initial KEX, a missing key is not an error here.
            let hk_ref: Option<&dyn HostKey> = match runner.negotiated() {
                Some(neg) => {
                    pick_host_key(&cfg.host_keys, &neg.host_key).map(|k| k as &dyn HostKey)
                }
                None => None,
            };
            let adv = runner.on_packet(rng, codec, &payload, hk_ref, None, v_c, v_s)?;
            for p in adv.outbound {
                write_payload(stream, codec, rng, &p)?;
            }
            // The time-based rekey policy measures from the last completed KEX.
            if adv.completed {
                *last_kex = Instant::now();
            }
            continue;
        }
        // Non-KEX traffic must wait until the key exchange finishes.
        if runner.is_kexing() {
            deferred.push(payload);
            continue;
        }
        dispatch_app_packet(
            stream,
            codec,
            rng,
            inbox,
            conn,
            cfg,
            user,
            &payload,
            any_channel_opened,
            shells,
            subsystems,
            envs,
            forward,
            agent_forward,
            x11_forward,
        )?;
    }
}
/// Pumps pending output from every live shell into its channel, honoring the
/// peer's flow-control window.
///
/// For each shell that still has a session:
/// 1. Re-sends output parked in `pending_stdout` by an earlier pass.
/// 2. If nothing is parked any more, reads up to `READ_BUDGET` bytes from the
///    session and forwards them. Reading stops early as soon as the window
///    fills and output has to be parked again. This is the same backpressure
///    `drain_subsystems` applies, so a shell that writes faster than the client
///    consumes is throttled instead of growing `pending_stdout` without bound.
/// 3. Polls `try_exit` once and records the status for `finalize_exited_shells`.
///
/// # Errors
/// Propagates session read errors and channel/transport errors.
fn drain_shells<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    conn: &mut ConnectionState,
    shells: &mut BTreeMap<u32, ShellRuntime>,
) -> Result<()> {
    // Upper bound on fresh output pulled from one shell per pass.
    const READ_BUDGET: usize = 64 * 1024;
    let mut buf = [0u8; 8 * 1024];
    let channels: Vec<u32> = shells.keys().copied().collect();
    for ch in channels {
        let Some(rt) = shells.get_mut(&ch) else {
            continue;
        };
        if rt.session.is_none() {
            continue;
        }
        if !rt.pending_stdout.is_empty() {
            let leftover = core::mem::take(&mut rt.pending_stdout);
            emit_channel_data(stream, codec, rng, conn, ch, &leftover, rt)?;
        }
        let mut pulled = 0usize;
        // Only pull more output while nothing is parked; once the window is
        // exhausted, further reads would only grow the backlog.
        while pulled < READ_BUDGET && rt.pending_stdout.is_empty() {
            let Some(sess) = rt.session.as_mut() else {
                break;
            };
            let n = sess.read(&mut buf)?;
            if n == 0 {
                break;
            }
            pulled += n;
            emit_channel_data(stream, codec, rng, conn, ch, &buf[..n], rt)?;
        }
        if rt.exited.is_none() {
            if let Some(sess) = rt.session.as_mut() {
                rt.exited = sess.try_exit();
            }
        }
    }
    Ok(())
}
/// Sends `bytes` on a shell `channel` as CHANNEL_DATA packets, as far as the
/// peer's window allows.
///
/// The first time `send_data` accepts nothing, the unsent remainder is appended
/// to `rt.pending_stdout` for a later pass.
fn emit_channel_data<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    conn: &mut ConnectionState,
    channel: u32,
    bytes: &[u8],
    rt: &mut ShellRuntime,
) -> Result<()> {
    let mut rest = bytes;
    while !rest.is_empty() {
        let (packet, consumed) = conn.send_data(channel, rest)?;
        if consumed == 0 {
            rt.pending_stdout.extend_from_slice(rest);
            break;
        }
        write_payload(stream, codec, rng, &packet)?;
        rest = rest.get(consumed..).unwrap_or(&[]);
    }
    Ok(())
}
/// Reports the exit of every shell whose output has been fully delivered.
///
/// For each such shell this sends, in order:
/// * `exit-status` or `exit-signal` (without requesting a reply);
/// * EOF;
/// * CLOSE.
///
/// It then marks the shell finalized and drops its session. Shells that still
/// have parked output wait for a later pass.
fn finalize_exited_shells<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    conn: &mut ConnectionState,
    shells: &mut BTreeMap<u32, ShellRuntime>,
) -> Result<()> {
    for (&id, rt) in shells.iter_mut() {
        // Never report twice, and never ahead of buffered output.
        let ready = !rt.exit_sent && rt.pending_stdout.is_empty();
        if !ready {
            continue;
        }
        let Some(status) = rt.exited.take() else {
            continue;
        };
        let exit_req = match status {
            ShellExitStatus::Exited(code) => ChannelRequest::ExitStatus { code },
            ShellExitStatus::Signalled {
                name,
                core_dumped,
                message,
            } => ChannelRequest::ExitSignal {
                name,
                core_dumped,
                message,
                language: String::new(),
            },
        };
        let request_pkt = conn.send_request(id, exit_req, false)?;
        write_payload(stream, codec, rng, &request_pkt)?;
        let eof_pkt = conn.send_eof(id)?;
        write_payload(stream, codec, rng, &eof_pkt)?;
        let close_pkt = conn.send_close(id)?;
        write_payload(stream, codec, rng, &close_pkt)?;
        rt.exit_sent = true;
        rt.session = None;
    }
    Ok(())
}
/// Pumps handler output for every stream-backed channel into the connection.
///
/// For each channel that has not sent CLOSE yet:
/// 1. Retries output parked in `pending_data`. If some still cannot be sent,
///    the channel is skipped for this pass (backpressure).
/// 2. Drains the handler's egress queue until one of these happens:
///    * the queue is empty;
///    * the window fills;
///    * the handler asks for EOF or CLOSE.
///    A dropped handler (disconnected queue) counts as CLOSE.
/// 3. Once nothing is parked, sends a requested EOF and/or CLOSE. CLOSE is
///    always preceded by EOF, and each is sent at most once.
fn drain_subsystems<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    conn: &mut ConnectionState,
    subsystems: &mut BTreeMap<u32, SubsystemRuntime>,
) -> Result<()> {
    let channels: Vec<u32> = subsystems.keys().copied().collect();
    for ch in channels {
        let Some(rt) = subsystems.get_mut(&ch) else {
            continue;
        };
        if rt.close_sent {
            continue;
        }
        if !rt.pending_data.is_empty() {
            let leftover = core::mem::take(&mut rt.pending_data);
            emit_subsystem_data(stream, codec, rng, conn, ch, &leftover, rt)?;
            // Window still full: leave the handler's queue alone for now.
            if !rt.pending_data.is_empty() {
                continue;
            }
        }
        loop {
            // Stop pulling as soon as output had to be parked.
            if !rt.pending_data.is_empty() {
                break;
            }
            match rt.egress_rx.try_recv() {
                Ok(ChannelEgress::Data(bytes)) => {
                    emit_subsystem_data(stream, codec, rng, conn, ch, &bytes, rt)?;
                }
                Ok(ChannelEgress::Eof) => {
                    rt.pending_eof = true;
                    break;
                }
                Ok(ChannelEgress::Close) => {
                    rt.pending_close = true;
                    break;
                }
                Err(TryRecvError::Empty) => break,
                // The handler dropped its stream: close the channel.
                Err(TryRecvError::Disconnected) => {
                    rt.pending_close = true;
                    break;
                }
            }
        }
        // EOF/CLOSE must not overtake data that is still parked.
        if rt.pending_data.is_empty() {
            if rt.pending_eof && !rt.eof_sent {
                let p = conn.send_eof(ch)?;
                write_payload(stream, codec, rng, &p)?;
                rt.eof_sent = true;
            }
            if rt.pending_close && !rt.close_sent {
                if !rt.eof_sent {
                    let p = conn.send_eof(ch)?;
                    write_payload(stream, codec, rng, &p)?;
                    rt.eof_sent = true;
                }
                let p = conn.send_close(ch)?;
                write_payload(stream, codec, rng, &p)?;
                rt.close_sent = true;
            }
        }
    }
    Ok(())
}
/// Sends `bytes` on a stream-backed `channel` as CHANNEL_DATA packets, as far
/// as the peer's window allows.
///
/// The first time `send_data` accepts nothing, the unsent remainder is appended
/// to `rt.pending_data` for a later pass.
fn emit_subsystem_data<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    conn: &mut ConnectionState,
    channel: u32,
    bytes: &[u8],
    rt: &mut SubsystemRuntime,
) -> Result<()> {
    let mut rest = bytes;
    while !rest.is_empty() {
        let (packet, consumed) = conn.send_data(channel, rest)?;
        if consumed == 0 {
            rt.pending_data.extend_from_slice(rest);
            break;
        }
        write_payload(stream, codec, rng, &packet)?;
        rest = rest.get(consumed..).unwrap_or(&[]);
    }
    Ok(())
}
fn read_one_packet_maybe_timeout(
stream: &mut TcpStream,
codec: &mut PacketCodec,
inbox: &mut Vec<u8>,
) -> Result<Option<Vec<u8>>> {
match read_one_packet(stream, codec, inbox) {
Ok(p) => Ok(Some(p)),
Err(Error::Io(e))
if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut =>
{
Ok(None)
}
Err(e) => Err(e),
}
}
#[allow(clippy::too_many_arguments)]
fn dispatch_app_packet<R: RngCore + CryptoRng>(
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
inbox: &mut Vec<u8>,
conn: &mut ConnectionState,
cfg: &Config,
user: &str,
payload: &[u8],
any_channel_opened: &mut bool,
shells: &mut BTreeMap<u32, ShellRuntime>,
subsystems: &mut BTreeMap<u32, SubsystemRuntime>,
envs: &mut BTreeMap<u32, SessionEnv>,
forward: &mut ForwardConn,
agent_forward: &mut AgentForwardConn,
x11_forward: &mut X11ForwardConn,
) -> Result<()> {
let ev = conn.on_packet(payload)?;
match ev {
ChannelEvent::OpenConfirmed { channel } => {
if let Some(reply) = forward.pending_opens.remove(&channel) {
let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
let (egress_tx, egress_rx) =
mpsc::sync_channel::<ChannelEgress>(SUBSYSTEM_EGRESS_BACKLOG);
let cs = ChannelStream::new(ingress_rx, egress_tx);
subsystems.insert(
channel,
SubsystemRuntime {
ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
let _ = reply.send(Ok(cs));
} else if let Some(reply) = agent_forward.pending_opens.remove(&channel) {
let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
let (egress_tx, egress_rx) =
mpsc::sync_channel::<ChannelEgress>(SUBSYSTEM_EGRESS_BACKLOG);
let cs = ChannelStream::new(ingress_rx, egress_tx);
subsystems.insert(
channel,
SubsystemRuntime {
ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
let _ = reply.send(Ok(cs));
} else if let Some(reply) = x11_forward.pending_opens.remove(&channel) {
let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
let (egress_tx, egress_rx) =
mpsc::sync_channel::<ChannelEgress>(SUBSYSTEM_EGRESS_BACKLOG);
let cs = ChannelStream::new(ingress_rx, egress_tx);
subsystems.insert(
channel,
SubsystemRuntime {
ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
let _ = reply.send(Ok(cs));
}
}
ChannelEvent::OpenFailed {
channel,
reason: _reason,
description: _description,
} => {
if let Some(reply) = forward.pending_opens.remove(&channel) {
let _ = reply.send(Err(Error::Protocol(
"forwarded-tcpip: open rejected by peer",
)));
} else if let Some(reply) = agent_forward.pending_opens.remove(&channel) {
let _ = reply.send(Err(Error::Protocol("auth-agent: open rejected by peer")));
} else if let Some(reply) = x11_forward.pending_opens.remove(&channel) {
let _ = reply.send(Err(Error::Protocol("x11: open rejected by peer")));
}
}
ChannelEvent::OpenRequest { channel, kind } => match kind {
ChannelOpen::Session => {
*any_channel_opened = true;
let p = conn.accept_open(channel)?;
write_payload(stream, codec, rng, &p)?;
envs.insert(channel, SessionEnv::new());
}
ChannelOpen::DirectTcpip {
dest_host,
dest_port,
orig_host,
orig_port,
} => {
if let Some(handler) = cfg.direct_tcpip_handler.clone() {
let p = conn.accept_open(channel)?;
write_payload(stream, codec, rng, &p)?;
let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
let (egress_tx, egress_rx) =
mpsc::sync_channel::<ChannelEgress>(SUBSYSTEM_EGRESS_BACKLOG);
let cs = ChannelStream::new(ingress_rx, egress_tx);
let user_owned = user.to_string();
thread::spawn(move || {
let req = DirectTcpipRequest {
dest_host: &dest_host,
dest_port,
orig_host: &orig_host,
orig_port,
};
let _ = handler.handle(&user_owned, req, cs);
});
subsystems.insert(
channel,
SubsystemRuntime {
ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
} else {
let p = conn.reject_open(
channel,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
"direct-tcpip not enabled",
"",
)?;
write_payload(stream, codec, rng, &p)?;
}
}
_ => {
let p = conn.reject_open(
channel,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
"channel type not supported",
"",
)?;
write_payload(stream, codec, rng, &p)?;
}
},
ChannelEvent::Request {
channel,
request,
want_reply,
} => {
handle_channel_request(
stream,
codec,
rng,
inbox,
conn,
cfg,
user,
channel,
request,
want_reply,
shells,
subsystems,
envs,
agent_forward,
x11_forward,
)?;
}
ChannelEvent::Data { channel, data } => {
if let Some(rt) = shells.get_mut(&channel) {
if let Some(sess) = rt.session.as_mut() {
let mut off = 0usize;
let mut retries = 0u32;
while off < data.len() {
let n = sess.write(&data[off..])?;
if n == 0 {
retries += 1;
if retries > 4 {
break;
}
continue;
}
off += n;
}
}
}
if let Some(rt) = subsystems.get_mut(&channel) {
let _ = rt.ingress_tx.send(Some(data.clone()));
}
if let Some(adj) = conn.replenish_window(channel, data.len() as u32)? {
write_payload(stream, codec, rng, &adj)?;
}
}
ChannelEvent::ExtendedData { channel, data, .. } => {
if let Some(adj) = conn.replenish_window(channel, data.len() as u32)? {
write_payload(stream, codec, rng, &adj)?;
}
}
ChannelEvent::Eof { channel } => {
if let Some(rt) = shells.get_mut(&channel) {
if let Some(sess) = rt.session.as_mut() {
let _ = sess.close_stdin();
}
}
if let Some(rt) = subsystems.get_mut(&channel) {
let _ = rt.ingress_tx.send(None);
}
}
ChannelEvent::Close { channel } => {
if let Some(ch) = conn.channel(channel) {
if !ch.local_closed {
let p = conn.send_close(channel)?;
write_payload(stream, codec, rng, &p)?;
}
}
shells.remove(&channel);
subsystems.remove(&channel);
envs.remove(&channel);
agent_forward.active.remove(&channel);
x11_forward.active.remove(&channel);
}
ChannelEvent::WindowAdjust { .. } => {}
ChannelEvent::GlobalRequest {
request,
want_reply,
} => {
handle_global_request(
stream, codec, rng, conn, cfg, user, request, want_reply, forward,
)?;
}
_ => {}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn handle_global_request<R: RngCore + CryptoRng>(
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
conn: &mut ConnectionState,
cfg: &Config,
user: &str,
request: crate::channel::GlobalRequest,
want_reply: bool,
forward: &mut ForwardConn,
) -> Result<()> {
use crate::channel::GlobalRequest;
use crate::format::Writer;
match request {
GlobalRequest::TcpipForward {
bind_address,
bind_port,
} => {
let bound = if bind_port > u16::MAX as u32 {
None
} else if let Some(handler) = cfg.tcpip_forward_handler.clone() {
let ctx = ForwardContext::new(forward.req_tx.clone());
handler
.bind(user, &bind_address, bind_port as u16, ctx)
.ok()
} else {
None
};
if !want_reply {
if let Some(port) = bound {
forward.owned_bindings.push((bind_address, port));
}
return Ok(());
}
match bound {
Some(port) => {
forward.owned_bindings.push((bind_address, port));
let tail = if bind_port == 0 {
let mut w = Writer::new();
w.write_u32(port as u32);
w.into_vec()
} else {
Vec::new()
};
let p = conn.send_global_success(&tail);
write_payload(stream, codec, rng, &p)?;
}
None => {
let p = conn.send_global_failure();
write_payload(stream, codec, rng, &p)?;
}
}
}
GlobalRequest::CancelTcpipForward {
bind_address,
bind_port,
} => {
let ok = if bind_port > u16::MAX as u32 {
false
} else if let Some(handler) = cfg.tcpip_forward_handler.clone() {
let r = handler
.unbind(user, &bind_address, bind_port as u16)
.is_ok();
if r {
forward
.owned_bindings
.retain(|(a, p)| !(a == &bind_address && *p == bind_port as u16));
}
r
} else {
false
};
if !want_reply {
return Ok(());
}
let p = if ok {
conn.send_global_success(&[])
} else {
conn.send_global_failure()
};
write_payload(stream, codec, rng, &p)?;
}
GlobalRequest::Keepalive | GlobalRequest::Other { .. } => {
if want_reply {
let p = conn.send_global_failure();
write_payload(stream, codec, rng, &p)?;
}
}
}
Ok(())
}
/// Handles one `SSH_MSG_CHANNEL_REQUEST` addressed to `channel` (RFC 4254 §6).
///
/// * `exec`: if `cfg.exec_stream_handler` claims the command, it runs on a
///   worker thread wired up through a [`ChannelStream`]. Otherwise
///   `cfg.command_handler` runs inline on this thread. Its buffered stdout and
///   stderr, the exit status, EOF and CLOSE are then all sent before this
///   function returns.
/// * `pty-req` / `shell`: recorded / spawned through `cfg.shell_handler`.
/// * `window-change`: forwarded to the running shell session, if any.
/// * `env`: stored in the channel's [`SessionEnv`], up to a per-channel cap.
/// * `subsystem`: handed to `cfg.subsystem_handler` on a worker thread.
/// * `auth-agent-req` / `x11-req`: set up through their handlers. On success
///   `SSH_AUTH_SOCK` / `DISPLAY` is added to the channel environment.
/// * Any other request is refused.
///
/// Every branch sends exactly one SUCCESS or FAILURE when `want_reply` is set,
/// as RFC 4254 §5.4 requires.
///
/// # Errors
/// Propagates connection-state and socket-write failures, including those from
/// draining inline `exec` output.
#[allow(clippy::too_many_arguments)]
fn handle_channel_request<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    inbox: &mut Vec<u8>,
    conn: &mut ConnectionState,
    cfg: &Config,
    user: &str,
    channel: u32,
    request: ChannelRequest,
    want_reply: bool,
    shells: &mut BTreeMap<u32, ShellRuntime>,
    subsystems: &mut BTreeMap<u32, SubsystemRuntime>,
    envs: &mut BTreeMap<u32, SessionEnv>,
    agent_forward: &mut AgentForwardConn,
    x11_forward: &mut X11ForwardConn,
) -> Result<()> {
    // Maximum number of distinct `env` names a client may set per channel.
    // The names and values come from the untrusted peer and live as long as
    // the channel. OpenSSH's sshd refuses `env` requests past a similar limit.
    const MAX_SESSION_ENV_VARS: usize = 128;
    let empty_env = SessionEnv::new();
    match request {
        ChannelRequest::Exec { command } => {
            if let Some(handler) = cfg.exec_stream_handler.clone() {
                if handler.claims(&command) {
                    let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
                    let (egress_tx, egress_rx) =
                        mpsc::sync_channel::<ChannelEgress>(SUBSYSTEM_EGRESS_BACKLOG);
                    let cs = ChannelStream::new(ingress_rx, egress_tx);
                    let user_owned = user.to_string();
                    let env_snapshot = envs.get(&channel).cloned().unwrap_or_default();
                    // `command` is moved into the worker; this branch always
                    // returns, so the inline path below never sees it moved.
                    thread::spawn(move || {
                        let _ = handler.run(&user_owned, &env_snapshot, &command, cs);
                    });
                    subsystems.insert(
                        channel,
                        SubsystemRuntime {
                            ingress_tx,
                            egress_rx,
                            pending_data: Vec::new(),
                            pending_eof: false,
                            pending_close: false,
                            eof_sent: false,
                            close_sent: false,
                        },
                    );
                    if want_reply {
                        let p = conn.send_request_success(channel)?;
                        write_payload(stream, codec, rng, &p)?;
                    }
                    return Ok(());
                }
            }
            let env_ref = envs.get(&channel).unwrap_or(&empty_env);
            let result = cfg.command_handler.handle(user, env_ref, &command);
            if want_reply {
                let p = conn.send_request_success(channel)?;
                write_payload(stream, codec, rng, &p)?;
            }
            drain_send(
                stream,
                codec,
                rng,
                inbox,
                conn,
                channel,
                &result.stdout,
                None,
            )?;
            drain_send(
                stream,
                codec,
                rng,
                inbox,
                conn,
                channel,
                &result.stderr,
                Some(SSH_EXTENDED_DATA_STDERR),
            )?;
            let p = conn.send_request(
                channel,
                ChannelRequest::ExitStatus {
                    code: result.exit_status,
                },
                false,
            )?;
            write_payload(stream, codec, rng, &p)?;
            let p = conn.send_eof(channel)?;
            write_payload(stream, codec, rng, &p)?;
            let p = conn.send_close(channel)?;
            write_payload(stream, codec, rng, &p)?;
        }
        ChannelRequest::PtyReq {
            term,
            cols,
            rows,
            px_w,
            px_h,
            modes,
        } => {
            if cfg.shell_handler.is_some() {
                // The PTY is only remembered here; it is consumed by `shell`.
                let rt = shells.entry(channel).or_insert_with(ShellRuntime::new);
                rt.pending_pty = Some(PtySpec {
                    term,
                    cols,
                    rows,
                    px_w,
                    px_h,
                    modes,
                });
                if want_reply {
                    let p = conn.send_request_success(channel)?;
                    write_payload(stream, codec, rng, &p)?;
                }
            } else if want_reply {
                let p = conn.send_request_failure(channel)?;
                write_payload(stream, codec, rng, &p)?;
            }
        }
        ChannelRequest::Shell => {
            if let Some(handler) = cfg.shell_handler.clone() {
                let rt = shells.entry(channel).or_insert_with(ShellRuntime::new);
                let pty = rt.pending_pty.take();
                let env_ref = envs.get(&channel).unwrap_or(&empty_env);
                match handler.spawn(user, env_ref, pty) {
                    Ok(sess) => {
                        rt.session = Some(sess);
                        if want_reply {
                            let p = conn.send_request_success(channel)?;
                            write_payload(stream, codec, rng, &p)?;
                        }
                    }
                    Err(_) => {
                        shells.remove(&channel);
                        if want_reply {
                            let p = conn.send_request_failure(channel)?;
                            write_payload(stream, codec, rng, &p)?;
                        }
                    }
                }
            } else if want_reply {
                let p = conn.send_request_failure(channel)?;
                write_payload(stream, codec, rng, &p)?;
            }
        }
        ChannelRequest::WindowChange {
            cols,
            rows,
            px_w,
            px_h,
        } => {
            let resized = shells
                .get_mut(&channel)
                .and_then(|rt| rt.session.as_mut())
                .map(|sess| sess.resize(cols, rows, px_w, px_h).is_ok())
                .unwrap_or(false);
            // RFC 4254 §6.7 sends window-change with want-reply FALSE. But §5.4
            // requires an answer to any request that does set it; staying silent
            // would leave such a peer waiting for a reply that never comes.
            if want_reply {
                let p = if resized {
                    conn.send_request_success(channel)?
                } else {
                    conn.send_request_failure(channel)?
                };
                write_payload(stream, codec, rng, &p)?;
            }
        }
        ChannelRequest::Env { name, value } => {
            let env = envs.entry(channel).or_default();
            // Overwriting an existing name is always allowed. A new name is
            // only accepted while the table is below the cap.
            let accepted = env.len() < MAX_SESSION_ENV_VARS || env.get(&name).is_some();
            if accepted {
                env.insert(name, value);
            }
            if want_reply {
                let p = if accepted {
                    conn.send_request_success(channel)?
                } else {
                    conn.send_request_failure(channel)?
                };
                write_payload(stream, codec, rng, &p)?;
            }
        }
        ChannelRequest::Subsystem { name } => {
            if let Some(handler) = cfg.subsystem_handler.clone() {
                let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
                let (egress_tx, egress_rx) =
                    mpsc::sync_channel::<ChannelEgress>(SUBSYSTEM_EGRESS_BACKLOG);
                let cs = ChannelStream::new(ingress_rx, egress_tx);
                let user_owned = user.to_string();
                let env_snapshot = envs.get(&channel).cloned().unwrap_or_default();
                thread::spawn(move || {
                    let _ = handler.handle(&user_owned, &env_snapshot, &name, cs);
                });
                subsystems.insert(
                    channel,
                    SubsystemRuntime {
                        ingress_tx,
                        egress_rx,
                        pending_data: Vec::new(),
                        pending_eof: false,
                        pending_close: false,
                        eof_sent: false,
                        close_sent: false,
                    },
                );
                if want_reply {
                    let p = conn.send_request_success(channel)?;
                    write_payload(stream, codec, rng, &p)?;
                }
            } else if want_reply {
                let p = conn.send_request_failure(channel)?;
                write_payload(stream, codec, rng, &p)?;
            }
        }
        ChannelRequest::AuthAgentReq => {
            if let Some(handler) = cfg.agent_forward_handler.clone() {
                let ctx = AgentForwardContext::new(agent_forward.req_tx.clone());
                match handler.setup(user, ctx) {
                    Ok(handle) => {
                        let path_str = handle.auth_sock_path.to_string_lossy().into_owned();
                        envs.entry(channel)
                            .or_default()
                            .insert("SSH_AUTH_SOCK".to_string(), path_str);
                        agent_forward.active.insert(channel, handle);
                        if want_reply {
                            let p = conn.send_request_success(channel)?;
                            write_payload(stream, codec, rng, &p)?;
                        }
                    }
                    Err(_) => {
                        if want_reply {
                            let p = conn.send_request_failure(channel)?;
                            write_payload(stream, codec, rng, &p)?;
                        }
                    }
                }
            } else if want_reply {
                let p = conn.send_request_failure(channel)?;
                write_payload(stream, codec, rng, &p)?;
            }
        }
        ChannelRequest::X11Req {
            single_connection,
            auth_protocol,
            auth_cookie,
            screen,
        } => {
            if let Some(handler) = cfg.x11_forward_handler.clone() {
                let ctx = X11ForwardContext::new(x11_forward.req_tx.clone());
                match handler.setup(
                    user,
                    single_connection,
                    &auth_protocol,
                    &auth_cookie,
                    screen,
                    ctx,
                ) {
                    Ok(handle) => {
                        envs.entry(channel)
                            .or_default()
                            .insert("DISPLAY".to_string(), handle.display_env.clone());
                        x11_forward.active.insert(channel, handle);
                        if want_reply {
                            let p = conn.send_request_success(channel)?;
                            write_payload(stream, codec, rng, &p)?;
                        }
                    }
                    Err(_) => {
                        if want_reply {
                            let p = conn.send_request_failure(channel)?;
                            write_payload(stream, codec, rng, &p)?;
                        }
                    }
                }
            } else if want_reply {
                let p = conn.send_request_failure(channel)?;
                write_payload(stream, codec, rng, &p)?;
            }
        }
        _ => {
            if want_reply {
                let p = conn.send_request_failure(channel)?;
                write_payload(stream, codec, rng, &p)?;
            }
        }
    }
    Ok(())
}
/// Sends all of `data` on `channel`. It goes as ordinary channel data, or as
/// extended data of type `code` when `extended` is `Some(code)`.
///
/// Whenever `conn` accepts zero bytes (the peer's window is exhausted), this
/// blocks reading one packet from `stream` and feeding it to `conn`, then
/// tries again.
///
/// NOTE(review): events produced while waiting are discarded unless they are a
/// CLOSE for `channel`. For example, an open request or a want-reply request
/// for another channel read here gets no response from this function. Confirm
/// that is acceptable for every caller.
///
/// # Errors
/// * `Error::BadChannelState` if the peer closes `channel` while data remains.
/// * `Error::Protocol("drain_send did not converge")` after `MAX_DRAIN_STEPS`
///   loop rounds.
/// * Any connection-state, codec or I/O error.
#[allow(clippy::too_many_arguments)]
fn drain_send<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    inbox: &mut Vec<u8>,
    conn: &mut ConnectionState,
    channel: u32,
    data: &[u8],
    extended: Option<u32>,
) -> Result<()> {
    let mut remaining = data;
    let mut rounds = 0usize;
    while !remaining.is_empty() {
        rounds += 1;
        if rounds > MAX_DRAIN_STEPS {
            return Err(Error::Protocol("drain_send did not converge"));
        }
        let (payload, taken) = match extended {
            Some(code) => conn.send_extended_data(channel, code, remaining)?,
            None => conn.send_data(channel, remaining)?,
        };
        if taken == 0 {
            // No window: pump one inbound packet so `conn` can observe a
            // WINDOW_ADJUST (or the channel's CLOSE), then retry.
            let pkt = read_one_packet(stream, codec, inbox)?;
            let ev = conn.on_packet(&pkt)?;
            if matches!(ev, ChannelEvent::Close { channel: c } if c == channel) {
                return Err(Error::BadChannelState);
            }
            continue;
        }
        write_payload(stream, codec, rng, &payload)?;
        remaining = &remaining[taken..];
    }
    Ok(())
}
/// Chooses which loaded host key serves the negotiated algorithm `name`.
///
/// A key whose `algorithm()` equals `name` exactly wins. Otherwise, if `name`
/// is one of the RSA names (`ssh-rsa`, `rsa-sha2-256`, `rsa-sha2-512`), the
/// first key reporting any RSA name is used, since those names share the same
/// key type. Returns `None` when nothing fits.
fn pick_host_key<'a>(
    keys: &'a [Box<dyn HostKey + Send + Sync>],
    name: &str,
) -> Option<&'a (dyn HostKey + Send + Sync)> {
    fn is_rsa(alg: &str) -> bool {
        alg == "ssh-rsa" || alg == "rsa-sha2-256" || alg == "rsa-sha2-512"
    }
    let chosen = keys
        .iter()
        .find(|k| k.algorithm() == name)
        .or_else(|| {
            if is_rsa(name) {
                keys.iter().find(|k| is_rsa(&k.algorithm()))
            } else {
                None
            }
        })?;
    Some(chosen.as_ref())
}
/// Builds the server's `SSH_MSG_KEXINIT` with a fresh 16-byte random cookie.
///
/// The host-key list keeps, in `defaults::HOST_KEY` order, each algorithm that
/// some loaded key reports exactly. The `rsa-sha2-256` / `rsa-sha2-512`
/// entries are also kept whenever any RSA key (`ssh-rsa`, `rsa-sha2-256` or
/// `rsa-sha2-512`) is loaded. If that leaves the list empty, `ssh-ed25519` is
/// advertised so the list is never empty.
/// NOTE(review): in that fallback case, negotiation could settle on a key that
/// `pick_host_key` cannot find.
///
/// Every other algorithm list is taken from `defaults`; no language tags.
fn build_server_kexinit<R: RngCore>(
    rng: &mut R,
    host_keys: &[Box<dyn HostKey + Send + Sync>],
) -> KexInit {
    fn is_rsa(alg: &str) -> bool {
        alg == "ssh-rsa" || alg == "rsa-sha2-256" || alg == "rsa-sha2-512"
    }
    let have_rsa_key = host_keys.iter().any(|k| is_rsa(&k.algorithm()));
    let mut have: Vec<&'static str> = defaults::HOST_KEY
        .iter()
        .copied()
        .filter(|&alg| {
            host_keys.iter().any(|k| k.algorithm() == alg)
                || (have_rsa_key && (alg == "rsa-sha2-256" || alg == "rsa-sha2-512"))
        })
        .collect();
    if have.is_empty() {
        have.push("ssh-ed25519");
    }

    let mut cookie = [0u8; 16];
    rng.fill_bytes(&mut cookie);

    KexInit::from_algorithms(
        &KexAlgorithms {
            kex: defaults::KEX,
            server_host_key: &have,
            ciphers_c2s: defaults::CIPHERS,
            ciphers_s2c: defaults::CIPHERS,
            macs_c2s: defaults::MACS,
            macs_s2c: defaults::MACS,
            comp_c2s: defaults::COMP,
            comp_s2c: defaults::COMP,
            lang_c2s: &[],
            lang_s2c: &[],
        },
        cookie,
    )
}
/// Reads the peer's SSH identification string (RFC 4253 §4.2).
///
/// Up to `MAX_BANNER_LINES` lines of at most `MAX_BANNER_LINE` bytes are read.
/// Lines that do not start with `SSH-` are skipped. The first `SSH-` line is
/// validated by `VersionExchange::parse_remote`, and its bytes are returned.
///
/// Lines are read through `read_line`, one byte per `read` call, so nothing
/// past the banner's newline is taken from the socket. Presumably this is
/// because binary packets follow immediately; confirm before adding buffering.
///
/// # Errors
/// * `Error::Protocol("peer banner too long")` if no `SSH-` line appears.
/// * Any error from `read_line` or `parse_remote`.
fn read_peer_version(stream: &mut TcpStream) -> Result<Vec<u8>> {
    let mut line = Vec::new();
    let mut lines_left = MAX_BANNER_LINES;
    while lines_left > 0 {
        lines_left -= 1;
        line.clear();
        read_line(stream, &mut line, MAX_BANNER_LINE)?;
        if !line.starts_with(b"SSH-") {
            continue;
        }
        let version = VersionExchange::parse_remote(&line)?;
        return Ok(version.into_bytes());
    }
    Err(Error::Protocol("peer banner too long"))
}
fn read_line<S: Read>(stream: &mut S, buf: &mut Vec<u8>, max_len: usize) -> Result<()> {
let mut byte = [0u8; 1];
loop {
let n = stream.read(&mut byte)?;
if n == 0 {
return Err(Error::Protocol("connection closed before newline"));
}
buf.push(byte[0]);
if byte[0] == b'\n' {
return Ok(());
}
if buf.len() >= max_len {
return Err(Error::Protocol("banner line too long"));
}
}
}
fn read_one_packet(
stream: &mut TcpStream,
codec: &mut PacketCodec,
inbox: &mut Vec<u8>,
) -> Result<Vec<u8>> {
loop {
let payload = read_one_raw_packet(stream, codec, inbox)?;
match payload.first().copied() {
Some(1) => return Err(Error::Protocol("peer sent SSH_MSG_DISCONNECT")),
Some(2) | Some(3) | Some(4) => continue,
_ => return Ok(payload),
}
}
}
fn read_one_raw_packet(
stream: &mut TcpStream,
codec: &mut PacketCodec,
inbox: &mut Vec<u8>,
) -> Result<Vec<u8>> {
loop {
if let Some((payload, consumed)) = codec.decode(inbox)? {
inbox.drain(..consumed);
return Ok(payload);
}
let mut tmp = [0u8; 16 * 1024];
let n = stream.read(&mut tmp)?;
if n == 0 {
return Err(Error::Protocol("connection closed"));
}
inbox.extend_from_slice(&tmp[..n]);
if inbox.len() > MAX_INBOX_BYTES {
return Err(Error::Protocol("inbound buffer too large"));
}
}
}
fn write_payload<R: RngCore + CryptoRng>(
stream: &mut TcpStream,
codec: &mut PacketCodec,
rng: &mut R,
payload: &[u8],
) -> Result<()> {
let frame = codec.encode(payload, rng)?;
stream.write_all(&frame)?;
Ok(())
}
/// Sends an `SSH_MSG_DISCONNECT` packet (RFC 4253 §11.1).
///
/// The payload is the message byte 1, the `reason` code, the human-readable
/// `description`, and an empty language tag.
///
/// # Errors
/// Propagates failures from `write_payload`.
fn send_disconnect<R: RngCore + CryptoRng>(
    stream: &mut TcpStream,
    codec: &mut PacketCodec,
    rng: &mut R,
    reason: u32,
    description: &str,
) -> Result<()> {
    const SSH_MSG_DISCONNECT: u8 = 1;
    let message = {
        let mut w = Writer::new();
        w.write_u8(SSH_MSG_DISCONNECT);
        w.write_u32(reason);
        w.write_string(description.as_bytes());
        // Language tag: left empty.
        w.write_string(b"");
        w.into_vec()
    };
    write_payload(stream, codec, rng, &message)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::{AuthAttempt, AuthDecision, Authenticator};
use crate::client::{Client, Config as ClientConfig, HostKeyPolicy};
use crate::hostkey::Ed25519HostKey;
use std::sync::Mutex;
use std::time::Duration;
struct OneKeyAuth {
allowed_user: String,
allowed_blob: Vec<u8>,
}
impl Authenticator for OneKeyAuth {
fn evaluate(&mut self, attempt: AuthAttempt) -> AuthDecision {
match attempt {
AuthAttempt::PublicKey {
user,
public_blob,
probe_only,
verified,
..
} => {
if user != self.allowed_user {
return AuthDecision::Reject;
}
if public_blob != self.allowed_blob {
return AuthDecision::Reject;
}
if probe_only {
return AuthDecision::Accept;
}
if !verified {
return AuthDecision::Reject;
}
AuthDecision::Accept
}
_ => AuthDecision::Reject,
}
}
}
struct StaticHandler {
out: Vec<u8>,
}
impl CommandHandler for StaticHandler {
fn handle(&self, _user: &str, _env: &SessionEnv, _command: &str) -> ExecResult {
ExecResult {
stdout: self.out.clone(),
stderr: Vec::new(),
exit_status: 0,
}
}
}
fn fresh_seed() -> [u8; 32] {
let mut s = [0u8; 32];
OsRng.fill_bytes(&mut s);
s
}
struct MemoryShellState {
stdout: Vec<u8>,
stdin: Vec<u8>,
closed_stdin: bool,
pty: Option<PtySpec>,
resizes: Vec<(u32, u32, u32, u32)>,
exit_on_stdin_close: Option<ShellExitStatus>,
exit_now: Option<ShellExitStatus>,
user: String,
}
#[derive(Clone)]
struct MemoryShell {
inner: Arc<Mutex<MemoryShellState>>,
}
impl MemoryShell {
fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(MemoryShellState {
stdout: Vec::new(),
stdin: Vec::new(),
closed_stdin: false,
pty: None,
resizes: Vec::new(),
exit_on_stdin_close: None,
exit_now: None,
user: String::new(),
})),
}
}
fn push_stdout(&self, bytes: &[u8]) {
self.inner.lock().unwrap().stdout.extend_from_slice(bytes);
}
fn arm_exit_on_stdin_close(&self, status: ShellExitStatus) {
self.inner.lock().unwrap().exit_on_stdin_close = Some(status);
}
}
struct MemoryShellHandler {
shell: MemoryShell,
}
impl ShellHandler for MemoryShellHandler {
fn spawn(
&self,
user: &str,
_env: &SessionEnv,
pty: Option<PtySpec>,
) -> Result<Box<dyn ShellSession>> {
{
let mut st = self.shell.inner.lock().unwrap();
st.pty = pty;
st.user = user.to_string();
}
Ok(Box::new(MemoryShellSession {
inner: self.shell.inner.clone(),
}))
}
}
struct MemoryShellSession {
inner: Arc<Mutex<MemoryShellState>>,
}
impl ShellSession for MemoryShellSession {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut st = self.inner.lock().unwrap();
if st.stdout.is_empty() {
return Ok(0);
}
let n = core::cmp::min(buf.len(), st.stdout.len());
buf[..n].copy_from_slice(&st.stdout[..n]);
st.stdout.drain(..n);
Ok(n)
}
fn write(&mut self, data: &[u8]) -> Result<usize> {
self.inner.lock().unwrap().stdin.extend_from_slice(data);
Ok(data.len())
}
fn close_stdin(&mut self) -> Result<()> {
self.inner.lock().unwrap().closed_stdin = true;
Ok(())
}
fn resize(&mut self, cols: u32, rows: u32, px_w: u32, px_h: u32) -> Result<()> {
self.inner
.lock()
.unwrap()
.resizes
.push((cols, rows, px_w, px_h));
Ok(())
}
fn try_exit(&mut self) -> Option<ShellExitStatus> {
let mut st = self.inner.lock().unwrap();
if let Some(s) = st.exit_now.take() {
return Some(s);
}
if st.closed_stdin && st.stdout.is_empty() {
if let Some(s) = st.exit_on_stdin_close.take() {
return Some(s);
}
}
None
}
}
#[test]
fn loopback_shell_with_pty_and_stdin() {
let host_seed = fresh_seed();
let client_seed = fresh_seed();
let host_key: Box<dyn HostKey + Send + Sync> =
Box::new(Ed25519HostKey::from_seed(host_seed));
let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
let allowed_blob = client_hk_for_auth.public_blob();
let user = "shell-test-user".to_string();
let allowed_user_for_factory = user.clone();
let allowed_blob_clone = allowed_blob.clone();
let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
Box::new(OneKeyAuth {
allowed_user: allowed_user_for_factory.clone(),
allowed_blob: allowed_blob_clone.clone(),
})
});
let memshell = MemoryShell::new();
memshell.push_stdout(b"hello from memshell\n");
memshell.arm_exit_on_stdin_close(ShellExitStatus::Exited(0));
let cfg = Config::new(
vec![host_key],
factory,
vec!["publickey"],
Arc::new(StaticHandler {
out: b"unused-exec\n".to_vec(),
}),
)
.with_shell(Arc::new(MemoryShellHandler {
shell: memshell.clone(),
}));
let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind");
let addr = server.local_addr().expect("local_addr");
let server_done = Arc::new(Mutex::new(false));
let sd = server_done.clone();
let server_thread = thread::spawn(move || {
let r = server.accept_one();
*sd.lock().unwrap() = true;
r
});
let mut client = Client::connect(
addr,
ClientConfig {
host_key_policy: HostKeyPolicy::AcceptAny,
timeout: Some(Duration::from_secs(10)),
},
)
.expect("client connect");
let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
client
.authenticate_publickey(&user, client_hk)
.expect("authenticate");
let out = client
.shell_with_stdin("xterm-256color", 132, 43, b"echo back\n")
.expect("shell_with_stdin");
assert_eq!(out.stdout, b"hello from memshell\n");
assert_eq!(out.exit_status, Some(0));
assert_eq!(out.exit_signal, None);
let st = memshell.inner.lock().unwrap();
let pty = st.pty.as_ref().expect("pty-req captured");
assert_eq!(pty.term, "xterm-256color");
assert_eq!(pty.cols, 132);
assert_eq!(pty.rows, 43);
assert_eq!(st.stdin, b"echo back\n");
assert!(st.closed_stdin, "EOF should reach the backend");
assert_eq!(st.user, user);
drop(st);
drop(client);
let start = std::time::Instant::now();
while !*server_done.lock().unwrap() {
if start.elapsed() > Duration::from_secs(10) {
panic!("server thread did not finish in time");
}
thread::sleep(Duration::from_millis(20));
}
let _ = server_thread.join();
}
#[test]
fn loopback_exec_roundtrip() {
let host_seed = fresh_seed();
let client_seed = fresh_seed();
let host_key: Box<dyn HostKey + Send + Sync> =
Box::new(Ed25519HostKey::from_seed(host_seed));
let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
let allowed_blob = client_hk_for_auth.public_blob();
let user = "ssh-test-user".to_string();
let allowed_user_for_factory = user.clone();
let allowed_blob_clone = allowed_blob.clone();
let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
Box::new(OneKeyAuth {
allowed_user: allowed_user_for_factory.clone(),
allowed_blob: allowed_blob_clone.clone(),
})
});
let cfg = Config::new(
vec![host_key],
factory,
vec!["publickey"],
Arc::new(StaticHandler {
out: b"loopback-test\n".to_vec(),
}),
);
let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind");
let addr = server.local_addr().expect("local_addr");
let server_done = Arc::new(Mutex::new(false));
let sd = server_done.clone();
let server_thread = thread::spawn(move || {
let r = server.accept_one();
*sd.lock().unwrap() = true;
r
});
let mut client = Client::connect(
addr,
ClientConfig {
host_key_policy: HostKeyPolicy::AcceptAny,
timeout: Some(Duration::from_secs(10)),
},
)
.expect("client connect");
let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
client
.authenticate_publickey(&user, client_hk)
.expect("authenticate");
let out = client.exec("ignored").expect("exec");
assert_eq!(out.stdout, b"loopback-test\n");
assert_eq!(out.exit_status, Some(0));
drop(client);
let start = std::time::Instant::now();
while !*server_done.lock().unwrap() {
if start.elapsed() > Duration::from_secs(10) {
panic!("server thread did not finish in time");
}
thread::sleep(Duration::from_millis(20));
}
let _ = server_thread.join();
}
#[test]
fn loopback_forces_rekeys_with_tiny_policy() {
let host_seed = fresh_seed();
let client_seed = fresh_seed();
let host_key: Box<dyn HostKey + Send + Sync> =
Box::new(Ed25519HostKey::from_seed(host_seed));
let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
let allowed_blob = client_hk_for_auth.public_blob();
let user = "ssh-test-user".to_string();
let allowed_user_for_factory = user.clone();
let allowed_blob_clone = allowed_blob.clone();
let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
Box::new(OneKeyAuth {
allowed_user: allowed_user_for_factory.clone(),
allowed_blob: allowed_blob_clone.clone(),
})
});
let payload: Vec<u8> = (0..16_384).map(|i| (i & 0xff) as u8).collect();
let mut cfg = Config::new(
vec![host_key],
factory,
vec!["publickey"],
Arc::new(StaticHandler {
out: payload.clone(),
}),
);
cfg.rekey_policy = RekeyPolicy {
max_bytes: 1024,
max_duration: Duration::from_secs(60 * 60),
max_seq: 1u32 << 31,
};
let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind");
let addr = server.local_addr().expect("local_addr");
let server_done = Arc::new(Mutex::new(false));
let sd = server_done.clone();
let server_thread = thread::spawn(move || {
let r = server.accept_one();
*sd.lock().unwrap() = true;
r
});
let mut client = Client::connect(
addr,
ClientConfig {
host_key_policy: HostKeyPolicy::AcceptAny,
timeout: Some(Duration::from_secs(10)),
},
)
.expect("client connect");
let session_id_before = client.session_id().to_vec();
let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
client
.authenticate_publickey(&user, client_hk)
.expect("authenticate");
let out = client.exec("ignored").expect("exec");
assert_eq!(out.stdout, payload);
assert_eq!(out.exit_status, Some(0));
assert_eq!(client.session_id(), session_id_before.as_slice());
drop(client);
let start = std::time::Instant::now();
while !*server_done.lock().unwrap() {
if start.elapsed() > Duration::from_secs(10) {
panic!("server thread did not finish in time");
}
thread::sleep(Duration::from_millis(20));
}
let _ = server_thread.join();
}
#[test]
fn server_kexinit_negotiation_uses_role_server() {
let mut rng = OsRng;
let host_keys: Vec<Box<dyn HostKey + Send + Sync>> =
vec![Box::new(Ed25519HostKey::from_seed(fresh_seed()))];
let advert = build_server_kexinit(&mut rng, &host_keys);
let mut runner = KexRunner::new(Role::Server, advert.clone());
let mut cookie = [0u8; 16];
rng.fill_bytes(&mut cookie);
let client_init = {
let algs = KexAlgorithms {
kex: &["curve25519-sha256"],
server_host_key: &["ssh-ed25519"],
ciphers_c2s: &["chacha20-poly1305@openssh.com"],
ciphers_s2c: &["chacha20-poly1305@openssh.com"],
macs_c2s: &["hmac-sha2-256"],
macs_s2c: &["hmac-sha2-256"],
comp_c2s: &["none"],
comp_s2c: &["none"],
lang_c2s: &[],
lang_s2c: &[],
};
KexInit::from_algorithms(&algs, cookie)
};
let _ = runner.start(&mut rng).expect("server start");
let mut codec = PacketCodec::new();
let adv = runner
.on_packet(
&mut rng,
&mut codec,
&client_init.encode(),
None,
None,
b"SSH-2.0-test-client",
b"SSH-2.0-test-server",
)
.expect("server processes client kexinit");
assert!(!adv.completed);
let neg = runner.negotiated().expect("negotiated");
assert_eq!(neg.kex, "curve25519-sha256");
assert_eq!(neg.host_key, "ssh-ed25519");
}
struct EchoUpperSubsystem {
captured_name: Arc<Mutex<Option<String>>>,
captured_user: Arc<Mutex<Option<String>>>,
}
impl SubsystemHandler for EchoUpperSubsystem {
fn handle(
&self,
user: &str,
_env: &SessionEnv,
name: &str,
mut stream: ChannelStream,
) -> Result<()> {
*self.captured_name.lock().unwrap() = Some(name.to_string());
*self.captured_user.lock().unwrap() = Some(user.to_string());
let mut acc = Vec::new();
let mut tmp = [0u8; 256];
loop {
match std::io::Read::read(&mut stream, &mut tmp) {
Ok(0) => break, Ok(n) => acc.extend_from_slice(&tmp[..n]),
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
std::thread::sleep(Duration::from_millis(5));
continue;
}
Err(_) => break,
}
}
for b in acc.iter_mut() {
b.make_ascii_uppercase();
}
std::io::Write::write_all(&mut stream, &acc).ok();
Ok(())
}
}
#[test]
fn loopback_subsystem_roundtrip() {
let host_seed = fresh_seed();
let client_seed = fresh_seed();
let host_key: Box<dyn HostKey + Send + Sync> =
Box::new(Ed25519HostKey::from_seed(host_seed));
let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
let allowed_blob = client_hk_for_auth.public_blob();
let user = "subsys-test-user".to_string();
let allowed_user_for_factory = user.clone();
let allowed_blob_clone = allowed_blob.clone();
let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
Box::new(OneKeyAuth {
allowed_user: allowed_user_for_factory.clone(),
allowed_blob: allowed_blob_clone.clone(),
})
});
let captured_name = Arc::new(Mutex::new(None));
let captured_user = Arc::new(Mutex::new(None));
let sub = EchoUpperSubsystem {
captured_name: captured_name.clone(),
captured_user: captured_user.clone(),
};
let cfg = Config::new(
vec![host_key],
factory,
vec!["publickey"],
Arc::new(StaticHandler {
out: b"unused-exec\n".to_vec(),
}),
)
.with_subsystem(Arc::new(sub));
let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind");
let addr = server.local_addr().expect("local_addr");
let server_done = Arc::new(Mutex::new(false));
let sd = server_done.clone();
let server_thread = thread::spawn(move || {
let r = server.accept_one();
*sd.lock().unwrap() = true;
r
});
let mut client = Client::connect(
addr,
ClientConfig {
host_key_policy: HostKeyPolicy::AcceptAny,
timeout: Some(Duration::from_secs(10)),
},
)
.expect("client connect");
let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
client
.authenticate_publickey(&user, client_hk)
.expect("authenticate");
let body = b"hello, subsystem world".to_vec();
let resp = client
.subsystem_once("echo", &body)
.expect("subsystem_once");
assert_eq!(resp, b"HELLO, SUBSYSTEM WORLD".to_vec());
assert_eq!(
captured_name.lock().unwrap().as_deref(),
Some("echo"),
"subsystem name reached the handler",
);
assert_eq!(
captured_user.lock().unwrap().as_deref(),
Some(user.as_str()),
"authenticated user reached the handler",
);
drop(client);
let start = std::time::Instant::now();
while !*server_done.lock().unwrap() {
if start.elapsed() > Duration::from_secs(10) {
panic!("server thread did not finish in time");
}
thread::sleep(Duration::from_millis(20));
}
let _ = server_thread.join();
}
#[test]
fn loopback_subsystem_unconfigured_refused() {
let host_seed = fresh_seed();
let client_seed = fresh_seed();
let host_key: Box<dyn HostKey + Send + Sync> =
Box::new(Ed25519HostKey::from_seed(host_seed));
let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
let allowed_blob = client_hk_for_auth.public_blob();
let user = "subsys-reject-user".to_string();
let allowed_user_for_factory = user.clone();
let allowed_blob_clone = allowed_blob.clone();
let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
Box::new(OneKeyAuth {
allowed_user: allowed_user_for_factory.clone(),
allowed_blob: allowed_blob_clone.clone(),
})
});
let cfg = Config::new(
vec![host_key],
factory,
vec!["publickey"],
Arc::new(StaticHandler {
out: b"unused-exec\n".to_vec(),
}),
);
let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind");
let addr = server.local_addr().expect("local_addr");
let server_done = Arc::new(Mutex::new(false));
let sd = server_done.clone();
let server_thread = thread::spawn(move || {
let r = server.accept_one();
*sd.lock().unwrap() = true;
r
});
let mut client = Client::connect(
addr,
ClientConfig {
host_key_policy: HostKeyPolicy::AcceptAny,
timeout: Some(Duration::from_secs(10)),
},
)
.expect("client connect");
let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
client
.authenticate_publickey(&user, client_hk)
.expect("authenticate");
let err = client
.subsystem_once("sftp", b"")
.expect_err("expected rejection");
match err {
Error::Protocol(_) => {}
other => panic!("expected Error::Protocol, got {:?}", other),
}
drop(client);
let start = std::time::Instant::now();
while !*server_done.lock().unwrap() {
if start.elapsed() > Duration::from_secs(10) {
panic!("server thread did not finish in time");
}
thread::sleep(Duration::from_millis(20));
}
let _ = server_thread.join();
}
struct SftpSubsystem {
cwd: std::path::PathBuf,
root: std::path::PathBuf,
}
impl SubsystemHandler for SftpSubsystem {
fn handle(
&self,
_user: &str,
_env: &SessionEnv,
name: &str,
stream: ChannelStream,
) -> Result<()> {
if name != "sftp" {
return Ok(());
}
let opts =
crate::sftp::SftpServerOptions::new(self.cwd.clone()).with_root(self.root.clone());
let mut sess = crate::sftp::SftpServerSession::new(opts);
let _ = sess.run(stream);
Ok(())
}
}
struct SftpTempDir(std::path::PathBuf);
impl SftpTempDir {
fn new(tag: &str) -> Self {
let dir = std::env::temp_dir().join(format!(
"puressh-server-sftp-{}-{}-{}",
tag,
std::process::id(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos(),
));
std::fs::create_dir_all(&dir).unwrap();
Self(dir)
}
fn path(&self) -> &std::path::Path {
&self.0
}
}
impl Drop for SftpTempDir {
fn drop(&mut self) {
let _ = std::fs::remove_dir_all(&self.0);
}
}
/// Loopback SFTP round trip. After public-key authentication the client:
/// - writes a file under a temporary server root and reads it back;
/// - finds it via `readdir`;
/// - removes it and checks that a subsequent `stat` reports `NoSuchFile`.
#[test]
fn loopback_sftp_client_roundtrip() {
    let tmp = SftpTempDir::new("roundtrip");
    let root = tmp.path().to_path_buf();
    let host_seed = fresh_seed();
    let client_seed = fresh_seed();
    let host_key: Box<dyn HostKey + Send + Sync> =
        Box::new(Ed25519HostKey::from_seed(host_seed));
    let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
    let allowed_blob = client_hk_for_auth.public_blob();
    let user = "sftp-test-user".to_string();
    let allowed_user_for_factory = user.clone();
    // The factory owns its copies; each authenticator it builds gets clones.
    let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
        Box::new(OneKeyAuth {
            allowed_user: allowed_user_for_factory.clone(),
            allowed_blob: allowed_blob.clone(),
        })
    });
    let sub = SftpSubsystem {
        cwd: root.clone(),
        root: root.clone(),
    };
    let cfg = Config::new(
        vec![host_key],
        factory,
        vec!["publickey"],
        Arc::new(StaticHandler {
            out: b"unused-exec\n".to_vec(),
        }),
    )
    .with_subsystem(Arc::new(sub));
    let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind");
    let addr = server.local_addr().expect("local_addr");
    // Serve exactly one connection on a background thread.
    let server_thread = thread::spawn(move || server.accept_one());
    let mut client = Client::connect(
        addr,
        ClientConfig {
            host_key_policy: HostKeyPolicy::AcceptAny,
            timeout: Some(Duration::from_secs(10)),
        },
    )
    .expect("client connect");
    let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
    client
        .authenticate_publickey(&user, client_hk)
        .expect("authenticate");
    {
        // Scoped so the SFTP session ends before `client` is dropped below
        // (presumably it borrows the client; TODO confirm).
        #[allow(deprecated)]
        let mut sftp = client.sftp().expect("sftp handshake");
        assert!(sftp.server_version() >= 3);
        let cwd = sftp.realpath(b".").expect("realpath .");
        assert_eq!(cwd.as_slice(), root.as_os_str().as_encoded_bytes());
        let target = root.join("hello.txt");
        let body = b"hello from sftp\n".to_vec();
        let handle = sftp
            .open(
                target.as_os_str().as_encoded_bytes(),
                crate::sftp::FXF_WRITE | crate::sftp::FXF_CREAT | crate::sftp::FXF_TRUNC,
                crate::sftp::Attrs::default(),
            )
            .expect("open for write");
        sftp.write(&handle, 0, &body).expect("write");
        sftp.close(&handle).expect("close write handle");
        let handle = sftp
            .open(
                target.as_os_str().as_encoded_bytes(),
                crate::sftp::FXF_READ,
                crate::sftp::Attrs::default(),
            )
            .expect("open for read");
        let got = sftp.read(&handle, 0, 1024).expect("read");
        assert_eq!(got, body);
        sftp.close(&handle).expect("close read handle");
        let dh = sftp
            .opendir(root.as_os_str().as_encoded_bytes())
            .expect("opendir");
        // `readdir` yields batches until it returns `None`.
        let mut all_names = Vec::<Vec<u8>>::new();
        while let Some(batch) = sftp.readdir(&dh).expect("readdir") {
            all_names.extend(batch.into_iter().map(|e| e.filename));
        }
        sftp.close(&dh).expect("close dir");
        assert!(
            all_names.iter().any(|n| n == b"hello.txt"),
            "readdir saw the new file: {:?}",
            all_names
                .iter()
                .map(|n| String::from_utf8_lossy(n).into_owned())
                .collect::<Vec<_>>(),
        );
        sftp.remove(target.as_os_str().as_encoded_bytes())
            .expect("remove");
        let err = sftp
            .stat(target.as_os_str().as_encoded_bytes())
            .expect_err("stat after remove");
        match err {
            crate::sftp::SftpError::Status {
                code: crate::sftp::FxpStatus::NoSuchFile,
                ..
            } => {}
            other => panic!("expected NoSuchFile, got {:?}", other),
        }
    }
    drop(client);
    // Poll the join handle rather than a flag set by the thread. That way a
    // panic on the server side fails fast through `join` instead of showing up
    // as a misleading timeout.
    let start = std::time::Instant::now();
    while !server_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("server thread did not finish in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    // The session's own `Result` is deliberately ignored. Ending the
    // connection by dropping the client may presumably be reported as an
    // error (TODO confirm).
    let _ = server_thread.join().expect("server thread panicked");
}
/// A `direct-tcpip` channel opened through a server that has a direct-tcpip
/// handler carries bytes to a local echo server and back.
#[test]
fn loopback_direct_tcpip_round_trip() {
    use std::io::{Read as _, Write as _};
    use std::net::TcpListener;
    // One-shot echo server: echoes until EOF or error on its single connection.
    let echo_listener = TcpListener::bind("127.0.0.1:0").expect("bind echo");
    let echo_addr = echo_listener.local_addr().expect("echo addr");
    let echo_thread = thread::spawn(move || {
        if let Ok((mut s, _)) = echo_listener.accept() {
            let mut buf = [0u8; 1024];
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(n) => {
                        if s.write_all(&buf[..n]).is_err() {
                            break;
                        }
                    }
                }
            }
        }
    });
    let host_seed = fresh_seed();
    let client_seed = fresh_seed();
    let host_key: Box<dyn HostKey + Send + Sync> =
        Box::new(Ed25519HostKey::from_seed(host_seed));
    let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
    let allowed_blob = client_hk_for_auth.public_blob();
    let user = "direct-tcpip-user".to_string();
    let allowed_user_for_factory = user.clone();
    let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
        Box::new(OneKeyAuth {
            allowed_user: allowed_user_for_factory.clone(),
            allowed_blob: allowed_blob.clone(),
        })
    });
    let cfg = Config::new(
        vec![host_key],
        factory,
        vec!["publickey"],
        Arc::new(StaticHandler {
            out: b"unused-exec\n".to_vec(),
        }),
    )
    .with_direct_tcpip(Arc::new(
        crate::forwarding::direct::DefaultDirectTcpipHandler::new(),
    ));
    let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind ssh");
    let ssh_addr = server.local_addr().expect("ssh addr");
    // Serve exactly one connection on a background thread.
    let server_thread = thread::spawn(move || server.accept_one());
    let mut client = Client::connect(
        ssh_addr,
        ClientConfig {
            host_key_policy: HostKeyPolicy::AcceptAny,
            timeout: Some(Duration::from_secs(10)),
        },
    )
    .expect("client connect");
    let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
    client
        .authenticate_publickey(&user, client_hk)
        .expect("authenticate");
    {
        // Scoped so the channel stream is released before `client` is dropped.
        #[allow(deprecated)]
        let mut s = client
            .open_direct_tcpip(
                &echo_addr.ip().to_string(),
                echo_addr.port(),
                "127.0.0.1",
                0,
            )
            .expect("open direct-tcpip");
        s.write_all(b"ping").expect("write");
        let mut got = [0u8; 4];
        s.read_exact(&mut got).expect("read echo");
        assert_eq!(&got, b"ping");
    }
    drop(client);
    // Poll the join handle rather than a flag set by the thread. That way a
    // server-side panic fails fast through `join` instead of timing out.
    let start = std::time::Instant::now();
    while !server_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("server thread did not finish in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    // The session's own `Result` is deliberately ignored; presumably the
    // client's disconnect may be reported as an error (TODO confirm).
    let _ = server_thread.join().expect("server thread panicked");
    let _ = echo_thread.join();
}
/// Without a direct-tcpip handler configured, the server refuses
/// `direct-tcpip` channel opens and the client reports `Error::Protocol`.
#[test]
#[allow(deprecated)]
fn loopback_direct_tcpip_unconfigured_refused() {
    let host_seed = fresh_seed();
    let client_seed = fresh_seed();
    let host_key: Box<dyn HostKey + Send + Sync> =
        Box::new(Ed25519HostKey::from_seed(host_seed));
    let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
    let allowed_blob = client_hk_for_auth.public_blob();
    let user = "direct-tcpip-reject-user".to_string();
    let allowed_user_for_factory = user.clone();
    let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
        Box::new(OneKeyAuth {
            allowed_user: allowed_user_for_factory.clone(),
            allowed_blob: allowed_blob.clone(),
        })
    });
    let cfg = Config::new(
        vec![host_key],
        factory,
        vec!["publickey"],
        Arc::new(StaticHandler {
            out: b"unused\n".to_vec(),
        }),
    );
    let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind");
    let addr = server.local_addr().expect("addr");
    // Serve exactly one connection on a background thread.
    let server_thread = thread::spawn(move || server.accept_one());
    let mut client = Client::connect(
        addr,
        ClientConfig {
            host_key_policy: HostKeyPolicy::AcceptAny,
            timeout: Some(Duration::from_secs(10)),
        },
    )
    .expect("client connect");
    let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
    client
        .authenticate_publickey(&user, client_hk)
        .expect("authenticate");
    match client.open_direct_tcpip("127.0.0.1", 1, "127.0.0.1", 0) {
        Ok(_) => panic!("expected direct-tcpip open to be refused"),
        Err(Error::Protocol(_)) => {}
        Err(other) => panic!("expected Error::Protocol, got {:?}", other),
    }
    drop(client);
    // Poll the join handle rather than a flag set by the thread. That way a
    // server-side panic fails fast through `join` instead of timing out.
    let start = std::time::Instant::now();
    while !server_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("server thread did not finish in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    // The session's own `Result` is deliberately ignored; presumably the
    // client's disconnect may be reported as an error (TODO confirm).
    let _ = server_thread.join().expect("server thread panicked");
}
/// Remote (`tcpip-forward`) forwarding round trip. The client asks the server
/// to listen on an ephemeral port, then serves `forwarded-tcpip` channels with
/// a callback that upper-cases everything it reads. A plain TCP connection to
/// the bound port must see its bytes come back upper-cased, and the callback
/// must observe the expected origin metadata.
#[test]
fn loopback_tcpip_forward_round_trip() {
    use crate::client::{ClientHandlers, ForwardedTcpipOrigin};
    use std::io::{Read as _, Write as _};
    use std::net::TcpStream;
    use std::sync::atomic::Ordering;
    let host_seed = fresh_seed();
    let client_seed = fresh_seed();
    let host_key: Box<dyn HostKey + Send + Sync> =
        Box::new(Ed25519HostKey::from_seed(host_seed));
    let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
    let allowed_blob = client_hk_for_auth.public_blob();
    let user = "tcpip-forward-user".to_string();
    let allowed_user_for_factory = user.clone();
    let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
        Box::new(OneKeyAuth {
            allowed_user: allowed_user_for_factory.clone(),
            allowed_blob: allowed_blob.clone(),
        })
    });
    let cfg = Config::new(
        vec![host_key],
        factory,
        vec!["publickey"],
        Arc::new(StaticHandler {
            out: b"unused-exec\n".to_vec(),
        }),
    )
    .with_tcpip_forward(Arc::new(
        crate::forwarding::reverse::DefaultTcpipForwardHandler::new(),
    ));
    let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind ssh");
    let ssh_addr = server.local_addr().expect("ssh addr");
    // Serve exactly one connection on a background thread.
    let server_thread = thread::spawn(move || server.accept_one());
    let mut client = Client::connect(
        ssh_addr,
        ClientConfig {
            host_key_policy: HostKeyPolicy::AcceptAny,
            timeout: Some(Duration::from_secs(10)),
        },
    )
    .expect("client connect");
    let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
    client
        .authenticate_publickey(&user, client_hk)
        .expect("authenticate");
    let bound_port = client
        .request_tcpip_forward("127.0.0.1", 0)
        .expect("request_tcpip_forward");
    assert!(bound_port > 0);
    // The callback records the origin it was handed, reads the channel to
    // EOF, and echoes the data back upper-cased.
    let origin_seen: Arc<Mutex<Option<ForwardedTcpipOrigin>>> = Arc::new(Mutex::new(None));
    let origin_clone = origin_seen.clone();
    let cb: Arc<crate::client::ForwardedTcpipCallback> =
        Arc::new(move |origin: ForwardedTcpipOrigin, mut s: ChannelStream| {
            *origin_clone.lock().unwrap() = Some(origin);
            let mut acc = Vec::new();
            let mut tmp = [0u8; 256];
            loop {
                match Read::read(&mut s, &mut tmp) {
                    Ok(0) => break,
                    Ok(n) => acc.extend_from_slice(&tmp[..n]),
                    Err(_) => break,
                }
            }
            acc.make_ascii_uppercase();
            let _ = Write::write_all(&mut s, &acc);
        });
    let handlers = ClientHandlers::new().with_forwarded_tcpip(cb);
    let stop = handlers.stop.clone();
    // `serve` consumes the client's attention; hand the client back when it
    // returns so it can be dropped (and the connection closed) explicitly.
    let serve_thread = thread::spawn(move || -> std::result::Result<Client, Error> {
        client.serve(handlers)?;
        Ok(client)
    });
    let mut s = TcpStream::connect(("127.0.0.1", bound_port)).expect("dial forwarded port");
    s.write_all(b"hello").expect("write");
    // Half-close so the callback's read loop sees EOF and replies.
    s.shutdown(std::net::Shutdown::Write)
        .expect("shutdown write");
    let mut got = Vec::new();
    s.read_to_end(&mut got).expect("read echo");
    assert_eq!(got, b"HELLO");
    drop(s);
    thread::sleep(Duration::from_millis(100));
    stop.store(true, Ordering::SeqCst);
    let start = std::time::Instant::now();
    while !serve_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("serve loop did not stop in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    let client_back = serve_thread
        .join()
        .expect("serve join")
        .expect("serve result");
    let captured = origin_seen.lock().unwrap().clone().expect("origin latched");
    assert_eq!(captured.bound_address, "127.0.0.1");
    assert_eq!(captured.bound_port, bound_port);
    assert!(captured.orig_port > 0);
    drop(client_back);
    // Poll the join handle rather than a flag set by the thread. That way a
    // server-side panic fails fast through `join` instead of timing out.
    let start = std::time::Instant::now();
    while !server_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("server thread did not finish in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    // The session's own `Result` is deliberately ignored; presumably the
    // client's disconnect may be reported as an error (TODO confirm).
    let _ = server_thread.join().expect("server thread panicked");
}
/// Without a tcpip-forward handler configured, the server rejects the
/// `tcpip-forward` global request and the client reports `Error::Protocol`.
#[test]
fn loopback_tcpip_forward_unconfigured_refused() {
    let host_seed = fresh_seed();
    let client_seed = fresh_seed();
    let host_key: Box<dyn HostKey + Send + Sync> =
        Box::new(Ed25519HostKey::from_seed(host_seed));
    let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
    let allowed_blob = client_hk_for_auth.public_blob();
    let user = "tcpip-forward-reject-user".to_string();
    let allowed_user_for_factory = user.clone();
    let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
        Box::new(OneKeyAuth {
            allowed_user: allowed_user_for_factory.clone(),
            allowed_blob: allowed_blob.clone(),
        })
    });
    let cfg = Config::new(
        vec![host_key],
        factory,
        vec!["publickey"],
        Arc::new(StaticHandler {
            out: b"unused\n".to_vec(),
        }),
    );
    let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind ssh");
    let ssh_addr = server.local_addr().expect("ssh addr");
    // Serve exactly one connection on a background thread.
    let server_thread = thread::spawn(move || server.accept_one());
    let mut client = Client::connect(
        ssh_addr,
        ClientConfig {
            host_key_policy: HostKeyPolicy::AcceptAny,
            timeout: Some(Duration::from_secs(10)),
        },
    )
    .expect("client connect");
    let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
    client
        .authenticate_publickey(&user, client_hk)
        .expect("authenticate");
    match client.request_tcpip_forward("127.0.0.1", 0) {
        Ok(_) => panic!("expected tcpip-forward to be refused"),
        Err(Error::Protocol(_)) => {}
        Err(other) => panic!("expected Error::Protocol, got {:?}", other),
    }
    drop(client);
    // Poll the join handle rather than a flag set by the thread. That way a
    // server-side panic fails fast through `join` instead of timing out.
    let start = std::time::Instant::now();
    while !server_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("server thread did not finish in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    // The session's own `Result` is deliberately ignored; presumably the
    // client's disconnect may be reported as an error (TODO confirm).
    let _ = server_thread.join().expect("server thread panicked");
}
/// A `direct-tcpip` channel can be opened through the `ServeContext` obtained
/// from `ClientHandlers::with_serve_context` while the client's `serve` loop
/// runs on another thread. Bytes round-trip through a local echo server.
#[test]
fn loopback_serve_context_direct_tcpip_round_trip() {
    use crate::client::ClientHandlers;
    use std::io::{Read as _, Write as _};
    use std::net::TcpListener;
    use std::sync::atomic::Ordering;
    // One-shot echo server: echoes until EOF or error on its single connection.
    let echo_listener = TcpListener::bind("127.0.0.1:0").expect("bind echo");
    let echo_addr = echo_listener.local_addr().expect("echo addr");
    let echo_thread = thread::spawn(move || {
        if let Ok((mut s, _)) = echo_listener.accept() {
            let mut buf = [0u8; 1024];
            loop {
                match s.read(&mut buf) {
                    Ok(0) | Err(_) => break,
                    Ok(n) => {
                        if s.write_all(&buf[..n]).is_err() {
                            break;
                        }
                    }
                }
            }
        }
    });
    let host_seed = fresh_seed();
    let client_seed = fresh_seed();
    let host_key: Box<dyn HostKey + Send + Sync> =
        Box::new(Ed25519HostKey::from_seed(host_seed));
    let client_hk_for_auth = Ed25519HostKey::from_seed(client_seed);
    let allowed_blob = client_hk_for_auth.public_blob();
    let user = "serve-ctx-user".to_string();
    let allowed_user_for_factory = user.clone();
    let factory: Arc<dyn AuthenticatorFactory> = Arc::new(move || -> Box<dyn Authenticator> {
        Box::new(OneKeyAuth {
            allowed_user: allowed_user_for_factory.clone(),
            allowed_blob: allowed_blob.clone(),
        })
    });
    let cfg = Config::new(
        vec![host_key],
        factory,
        vec!["publickey"],
        Arc::new(StaticHandler {
            out: b"unused\n".to_vec(),
        }),
    )
    .with_direct_tcpip(Arc::new(
        crate::forwarding::direct::DefaultDirectTcpipHandler::new(),
    ));
    let mut server = Server::bind("127.0.0.1:0", cfg).expect("bind ssh");
    let ssh_addr = server.local_addr().expect("ssh addr");
    // Serve exactly one connection on a background thread.
    let server_thread = thread::spawn(move || server.accept_one());
    let mut client = Client::connect(
        ssh_addr,
        ClientConfig {
            host_key_policy: HostKeyPolicy::AcceptAny,
            timeout: Some(Duration::from_secs(10)),
        },
    )
    .expect("client connect");
    let client_hk: Box<dyn HostKey + Send> = Box::new(Ed25519HostKey::from_seed(client_seed));
    client
        .authenticate_publickey(&user, client_hk)
        .expect("authenticate");
    let (handlers, ctx) = ClientHandlers::new().with_serve_context();
    let stop = handlers.stop.clone();
    // The serve loop owns the client; it is handed back when `serve` returns.
    let serve_thread = thread::spawn(move || -> std::result::Result<Client, Error> {
        client.serve(handlers)?;
        Ok(client)
    });
    let mut s = ctx
        .open_direct_tcpip(
            &echo_addr.ip().to_string(),
            echo_addr.port(),
            "127.0.0.1",
            0,
        )
        .expect("open_direct_tcpip via ServeContext");
    s.write_all(b"ping").expect("write ping");
    let mut got = [0u8; 4];
    s.read_exact(&mut got).expect("read echo");
    assert_eq!(&got, b"ping");
    drop(s);
    thread::sleep(Duration::from_millis(100));
    drop(ctx);
    stop.store(true, Ordering::SeqCst);
    let start = std::time::Instant::now();
    while !serve_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("serve loop did not stop in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    let client_back = serve_thread
        .join()
        .expect("serve join")
        .expect("serve result");
    drop(client_back);
    // Poll the join handle rather than a flag set by the thread. That way a
    // server-side panic fails fast through `join` instead of timing out.
    let start = std::time::Instant::now();
    while !server_thread.is_finished() {
        if start.elapsed() > Duration::from_secs(10) {
            panic!("server thread did not finish in time");
        }
        thread::sleep(Duration::from_millis(20));
    }
    // The session's own `Result` is deliberately ignored; presumably the
    // client's disconnect may be reported as an error (TODO confirm).
    let _ = server_thread.join().expect("server thread panicked");
    let _ = echo_thread.join();
}
}