#![cfg(feature = "std")]
use std::collections::BTreeMap;
use std::io::{ErrorKind, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{self, Receiver, Sender, TryRecvError};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use purecrypto::hash::{Digest, Sha256};
use purecrypto::rng::{OsRng, RngCore};
use crate::auth::{ClientAuth, ClientCredential, ClientStep};
use crate::channel::{
ChannelEvent, ChannelOpen, ChannelRequest, ConnectionState, SSH_EXTENDED_DATA_STDERR,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
};
use crate::error::{Error, Result};
use crate::hostkey::{host_key_verify_by_name, HostKey, HostKeyVerify};
use crate::known_hosts::{KnownHosts, LookupResult};
use crate::sftp::SftpClient;
pub use crate::stream::{ChannelEgress, ChannelStream};
use crate::transport::kex::{defaults, KexAlgorithms};
use crate::transport::rekey::{is_kex_msg, RekeyPolicy};
use crate::transport::{KexInit, KexRunner, PacketCodec, Role, VersionExchange};
/// Byte limit handed to `read_line` for each line read while waiting for the
/// peer's identification string.
const MAX_BANNER_LINE: usize = 1024;
/// Number of lines `read_peer_version` reads before giving up on finding a
/// line that starts with `SSH-`.
const MAX_BANNER_LINES: usize = 256;
/// 8 MiB. Not referenced in the visible part of this file; presumably bounds
/// the receive buffer (`inbox`) — TODO confirm against `read_one_raw_packet`.
const MAX_INBOX_BYTES: usize = 8 * 1024 * 1024;
/// 64 MiB cap on the output accumulated by `exec`, `shell_with_stdin` and
/// `subsystem_once`.
const MAX_EXEC_OUTPUT: usize = 64 * 1024 * 1024;
/// Maximum number of packets `drive_kex_to_completion` reads before failing.
const MAX_KEX_STEPS: usize = 32;
/// Maximum number of packets `authenticate` processes before failing.
const MAX_AUTH_STEPS: usize = 64;
/// Iteration cap shared by the channel open / request / drain loops.
const MAX_EXEC_ITER: usize = 1_000_000;
/// Not referenced in the visible part of this file; presumably the queue depth
/// of per-channel egress channels created by the serve helpers — TODO confirm.
const SERVE_EGRESS_BACKLOG: usize = 32;
/// Iteration cap for the `serve` loop.
const MAX_SERVE_STEPS: usize = 100_000_000;
/// Socket read timeout installed while `serve` runs, so the loop wakes up
/// regularly even when the peer is silent.
const SERVE_POLL_INTERVAL: Duration = Duration::from_millis(50);
/// Message number that makes `dispatch_kex_packet` build a host-key verifier.
const SSH_MSG_KEX_ECDH_REPLY: u8 = 31;
/// How the client decides whether to trust the server's host key.
///
/// The policy is stored on the `Client` and passed to `build_verifier` when
/// the key-exchange reply is processed.
pub enum HostKeyPolicy {
/// NOTE(review): presumably accepts every host key without verification,
/// which offers no protection against man-in-the-middle attacks — confirm
/// against `build_verifier`.
AcceptAny,
/// Carries 32 bytes; presumably the expected SHA-256 fingerprint of the
/// host key — TODO confirm against `build_verifier`.
AcceptFingerprint([u8; 32]),
/// Consults a known-hosts store as configured by [`KnownHostsPolicy`].
KnownHosts(KnownHostsPolicy),
}
/// Settings for the [`HostKeyPolicy::KnownHosts`] variant.
///
/// The code that interprets these fields is outside the visible part of this
/// file, so the per-field notes below are hedged.
pub struct KnownHostsPolicy {
/// Shared, lockable known-hosts database.
pub store: Arc<Mutex<KnownHosts>>,
/// Optional path; presumably where the store is persisted after a new key
/// is accepted — TODO confirm.
pub save_path: Option<PathBuf>,
/// Presumably selects hashed host names for newly added entries — TODO confirm.
pub hash_new: bool,
/// What to do when the host is not present in the store.
pub on_unknown: TofuAction,
}
/// Callback type used by [`TofuAction::Prompt`]; returns a `bool` decision.
///
/// NOTE(review): the arguments are presumably host, port, key algorithm name
/// and raw key bytes, with `true` meaning "accept" — confirm at the call site.
pub type TofuPromptFn = dyn Fn(&str, u16, &str, &[u8]) -> bool + Send + Sync;
/// Action selected by [`KnownHostsPolicy::on_unknown`].
///
/// NOTE(review): the variant names suggest trust-on-first-use handling for
/// hosts missing from the store; the implementing code is not visible here.
pub enum TofuAction {
/// Presumably refuses the connection — TODO confirm.
Reject,
/// Presumably trusts the presented key — TODO confirm.
Accept,
/// Delegates the decision to a shared [`TofuPromptFn`] callback.
Prompt(Arc<TofuPromptFn>),
}
/// Connection options consumed by `Client::connect` and
/// `Client::connect_to_host`.
pub struct Config {
/// Host-key trust policy installed on the client before key exchange.
pub host_key_policy: HostKeyPolicy,
/// When `Some`, applied as both the read and the write timeout of the TCP
/// stream. It is not applied to the TCP connect itself.
pub timeout: Option<Duration>,
}
impl Default for Config {
fn default() -> Self {
Self {
host_key_policy: HostKeyPolicy::AcceptAny,
timeout: None,
}
}
}
/// Everything collected from a one-shot `exec` or `shell_with_stdin` run.
pub struct ExecOutput {
/// Bytes received as ordinary channel data, plus extended data whose type
/// code is not the stderr code.
pub stdout: Vec<u8>,
/// Bytes received as extended data with the stderr type code.
pub stderr: Vec<u8>,
/// Code from an exit-status channel request, if one was received.
pub exit_status: Option<u32>,
/// Signal name from an exit-signal channel request, if one was received.
pub exit_signal: Option<String>,
}
/// Address information handed to a [`ForwardedTcpipCallback`].
///
/// NOTE(review): presumably mirrors the fields of a `forwarded-tcpip` channel
/// open (listening address/port and originator address/port) — confirm against
/// the dispatch code, which is outside the visible part of this file.
#[derive(Debug, Clone)]
pub struct ForwardedTcpipOrigin {
/// Address the forwarded connection was bound to.
pub bound_address: String,
/// Port the forwarded connection was bound to.
pub bound_port: u16,
/// Address of the connection's originator.
pub orig_address: String,
/// Port of the connection's originator.
pub orig_port: u16,
}
/// Callback receiving the origin description and the stream of a forwarded
/// TCP/IP channel. Must be `Send + Sync + 'static`.
pub type ForwardedTcpipCallback =
dyn Fn(ForwardedTcpipOrigin, ChannelStream) + Send + Sync + 'static;
/// Callback receiving the stream of an agent-forwarding channel.
pub type AuthAgentCallback = dyn Fn(ChannelStream) + Send + Sync + 'static;
/// Callback receiving the stream of an X11-forwarding channel.
pub type X11Callback = dyn Fn(ChannelStream) + Send + Sync + 'static;
/// Callbacks and control handles consumed by `Client::serve`.
pub struct ClientHandlers {
/// Optional handler for forwarded TCP/IP channels.
pub on_forwarded_tcpip: Option<Arc<ForwardedTcpipCallback>>,
/// Optional handler for agent-forwarding channels.
pub on_auth_agent: Option<Arc<AuthAgentCallback>>,
/// Optional handler for X11 channels.
pub on_x11: Option<Arc<X11Callback>>,
/// Stop flag polled by `serve`; the loop exits once it is set and no
/// runtimes or pending opens remain.
pub stop: Arc<AtomicBool>,
// Receiving end of the command channel created by `with_serve_context`;
// `None` until that method is called.
cmd_rx: Option<Receiver<ServeCommand>>,
}
impl Default for ClientHandlers {
fn default() -> Self {
Self::new()
}
}
impl ClientHandlers {
    /// Creates a handler set with no callbacks, a cleared stop flag and no
    /// command channel.
    pub fn new() -> Self {
        let stop = Arc::new(AtomicBool::new(false));
        Self {
            on_forwarded_tcpip: None,
            on_auth_agent: None,
            on_x11: None,
            stop,
            cmd_rx: None,
        }
    }

    /// Installs the forwarded TCP/IP callback and returns the updated set.
    pub fn with_forwarded_tcpip(self, cb: Arc<ForwardedTcpipCallback>) -> Self {
        let mut handlers = self;
        handlers.on_forwarded_tcpip = Some(cb);
        handlers
    }

    /// Installs the agent-forwarding callback and returns the updated set.
    pub fn with_auth_agent(self, cb: Arc<AuthAgentCallback>) -> Self {
        let mut handlers = self;
        handlers.on_auth_agent = Some(cb);
        handlers
    }

    /// Installs the X11 callback and returns the updated set.
    pub fn with_x11(self, cb: Arc<X11Callback>) -> Self {
        let mut handlers = self;
        handlers.on_x11 = Some(cb);
        handlers
    }

    /// Creates a fresh command channel, keeps the receiving end in the handler
    /// set and returns the sending end wrapped in a [`ServeContext`].
    ///
    /// Calling this again replaces the stored receiver, so a previously
    /// returned context loses its connection to the handler set.
    pub fn with_serve_context(self) -> (Self, ServeContext) {
        let mut handlers = self;
        let (cmd_tx, cmd_rx) = mpsc::channel();
        handlers.cmd_rx = Some(cmd_rx);
        (handlers, ServeContext { cmd_tx })
    }
}
/// Commands that other threads can queue for the `serve` loop through a
/// [`ServeContext`].
pub enum ServeCommand {
/// Ask the serve loop to open a direct TCP/IP channel.
OpenDirectTcpip {
/// Destination host to connect to.
dest_host: String,
/// Destination port to connect to.
dest_port: u16,
/// Originator host reported in the open request.
orig_host: String,
/// Originator port reported in the open request.
orig_port: u16,
/// One-shot channel on which the serve loop delivers the outcome.
reply: mpsc::SyncSender<Result<ChannelStream>>,
},
}
/// Cloneable handle for submitting [`ServeCommand`]s to a running `serve`
/// loop. Obtained from `ClientHandlers::with_serve_context`.
#[derive(Clone)]
pub struct ServeContext {
// Sending end of the command channel whose receiver lives in `ClientHandlers`.
cmd_tx: Sender<ServeCommand>,
}
impl ServeContext {
pub fn open_direct_tcpip(
&self,
dest_host: &str,
dest_port: u16,
orig_host: &str,
orig_port: u16,
) -> Result<ChannelStream> {
let (reply_tx, reply_rx) = mpsc::sync_channel::<Result<ChannelStream>>(1);
self.cmd_tx
.send(ServeCommand::OpenDirectTcpip {
dest_host: dest_host.to_string(),
dest_port,
orig_host: orig_host.to_string(),
orig_port,
reply: reply_tx,
})
.map_err(|_| Error::Protocol("serve loop terminated"))?;
reply_rx
.recv()
.map_err(|_| Error::Protocol("serve loop terminated"))?
}
}
// Bookkeeping for an outbound channel open issued by the serve loop that has
// not been answered yet. When `serve` exits, every remaining entry receives
// an error through `reply`.
//
// NOTE(review): the code that fills and consumes the other fields is outside
// the visible part of this file; the field notes are assumptions.
struct PendingOutboundOpen {
// Presumably the user-facing stream handed out on confirmation — TODO confirm.
stream: Option<ChannelStream>,
// Presumably feeds inbound channel data to that stream — TODO confirm.
ingress_tx: Sender<Option<Vec<u8>>>,
// Presumably receives outbound data from that stream — TODO confirm.
egress_rx: Option<Receiver<ChannelEgress>>,
// Where the outcome of the open is reported to the requester.
reply: mpsc::SyncSender<Result<ChannelStream>>,
}
// Per-channel state kept by the serve loop, keyed by channel id in `serve`.
// Entries are removed once `close_sent` is true.
//
// NOTE(review): the helpers that drive these fields are outside the visible
// part of this file; the field notes are assumptions.
struct ServeRuntime {
// Presumably delivers inbound channel data to the user-facing stream — TODO confirm.
ingress_tx: Sender<Option<Vec<u8>>>,
// Presumably receives outbound events from the user-facing stream — TODO confirm.
egress_rx: Receiver<ChannelEgress>,
// Presumably outbound bytes waiting for window space — TODO confirm.
pending_data: Vec<u8>,
// Presumably an EOF waiting to be sent after `pending_data` — TODO confirm.
pending_eof: bool,
// Presumably a close waiting to be sent — TODO confirm.
pending_close: bool,
// Presumably set once EOF has been sent — TODO confirm.
eof_sent: bool,
// Set once the close has been sent; `serve` drops the entry afterwards.
close_sent: bool,
}
/// Blocking SSH client bound to a single TCP connection.
pub struct Client {
// Underlying TCP connection.
stream: TcpStream,
// Packet encoder/decoder for the connection.
codec: PacketCodec,
// Channel-layer state: opens channels, builds requests, interprets packets.
pub(crate) conn: ConnectionState,
// Session identifier copied from the key-exchange runner after the first exchange.
session_id: Vec<u8>,
// Received bytes not yet decoded into packets.
inbox: Vec<u8>,
// Randomness source for key exchange.
rng: OsRng,
// Key-exchange state machine; replaced on the initial exchange and restarted on rekey.
runner: KexRunner,
// Local identification string, passed to the key-exchange runner.
v_c: Vec<u8>,
// Peer identification string, passed to the key-exchange runner.
v_s: Vec<u8>,
// Host-key trust policy used when building the verifier.
host_key_policy: HostKeyPolicy,
// Time of the most recent completed key exchange; input to the rekey policy.
last_kex: Instant,
// Decides when a rekey should be started.
rekey_policy: RekeyPolicy,
// Non-kex payloads received while a key exchange was in progress, replayed in order afterwards.
deferred: Vec<Vec<u8>>,
// Host name given to `connect_to_host`; empty when created through `connect`.
target_host: String,
// Port given to `connect_to_host`; 0 when created through `connect`.
target_port: u16,
// When true, session helpers send an agent-forwarding request after opening a channel.
request_auth_agent: bool,
// When set, session helpers send an X11-forwarding request with these arguments.
request_x11: Option<X11ReqArgs>,
}
// Stored arguments for the X11-forwarding request sent by `maybe_send_x11_req`;
// the fields map one-to-one onto `ChannelRequest::X11Req`.
#[derive(Clone)]
struct X11ReqArgs {
single_connection: bool,
auth_protocol: String,
auth_cookie: String,
screen: u32,
}
impl Client {
pub fn connect<A: ToSocketAddrs>(addr: A, cfg: Config) -> Result<Self> {
let stream = TcpStream::connect(addr)?;
if let Some(t) = cfg.timeout {
stream.set_read_timeout(Some(t))?;
stream.set_write_timeout(Some(t))?;
}
stream.set_nodelay(true)?;
let mut rng = OsRng;
let placeholder_advert = build_default_kexinit(&mut rng);
let mut me = Self {
stream,
codec: PacketCodec::new(),
conn: ConnectionState::new(),
session_id: Vec::new(),
inbox: Vec::new(),
rng,
runner: KexRunner::new(Role::Client, placeholder_advert),
v_c: Vec::new(),
v_s: Vec::new(),
host_key_policy: HostKeyPolicy::AcceptAny,
last_kex: Instant::now(),
rekey_policy: RekeyPolicy::default(),
deferred: Vec::new(),
target_host: String::new(),
target_port: 0,
request_auth_agent: false,
request_x11: None,
};
me.host_key_policy = cfg.host_key_policy;
me.do_version_and_kex()?;
Ok(me)
}
pub fn connect_to_host(host: &str, port: u16, cfg: Config) -> Result<Self> {
let stream = TcpStream::connect((host, port))?;
if let Some(t) = cfg.timeout {
stream.set_read_timeout(Some(t))?;
stream.set_write_timeout(Some(t))?;
}
stream.set_nodelay(true)?;
let mut rng = OsRng;
let placeholder_advert = build_default_kexinit(&mut rng);
let mut me = Self {
stream,
codec: PacketCodec::new(),
conn: ConnectionState::new(),
session_id: Vec::new(),
inbox: Vec::new(),
rng,
runner: KexRunner::new(Role::Client, placeholder_advert),
v_c: Vec::new(),
v_s: Vec::new(),
host_key_policy: HostKeyPolicy::AcceptAny,
last_kex: Instant::now(),
rekey_policy: RekeyPolicy::default(),
deferred: Vec::new(),
target_host: host.to_string(),
target_port: port,
request_auth_agent: false,
request_x11: None,
};
me.host_key_policy = cfg.host_key_policy;
me.do_version_and_kex()?;
Ok(me)
}
/// Authenticates `user` by offering `credentials` to the authentication
/// state machine and exchanging packets until it succeeds or fails.
///
/// On success the codec's `activate_compress` is called before returning.
///
/// # Errors
///
/// Returns `Error::AuthFailed` when the state machine reports failure, a
/// protocol error when `MAX_AUTH_STEPS` packets pass without a verdict, and
/// propagates transport errors.
pub fn authenticate(&mut self, user: &str, credentials: Vec<ClientCredential>) -> Result<()> {
    let mut auth = ClientAuth::new(user, self.session_id.clone());
    credentials.into_iter().for_each(|cred| {
        auth.add_credential(cred);
    });
    let opening = auth.start();
    self.write_payload(&opening)?;
    let mut remaining = MAX_AUTH_STEPS;
    while remaining > 0 {
        remaining -= 1;
        let packet = self.read_one_packet()?;
        match auth.on_packet(&packet)? {
            ClientStep::Send(reply) => self.write_payload(&reply)?,
            ClientStep::Success => {
                self.codec.activate_compress();
                return Ok(());
            }
            ClientStep::Failed { .. } => return Err(Error::AuthFailed),
            // Banners and idle steps need no response; keep reading.
            ClientStep::Banner { .. } | ClientStep::Idle => {}
        }
    }
    Err(Error::Protocol("auth: too many steps without termination"))
}
/// Authenticates `user` with a single password credential.
///
/// # Errors
///
/// Same as `authenticate`.
pub fn authenticate_password(&mut self, user: &str, password: &str) -> Result<()> {
    let only_password = vec![ClientCredential::Password(password.into())];
    self.authenticate(user, only_password)
}
/// Returns the session identifier recorded after the initial key exchange.
pub fn session_id(&self) -> &[u8] {
    self.session_id.as_slice()
}
/// Authenticates `user` with a single public-key credential.
///
/// # Errors
///
/// Same as `authenticate`.
pub fn authenticate_publickey(
    &mut self,
    user: &str,
    key: Box<dyn HostKey + Send>,
) -> Result<()> {
    let only_key = vec![ClientCredential::PublicKey(key)];
    self.authenticate(user, only_key)
}
/// Opens a session channel, waits for the peer to confirm it, then sends an
/// agent-forwarding request on it and returns the local channel id.
///
/// Packets that do not concern the new channel are ignored while waiting.
///
/// # Errors
///
/// Fails when the open is refused, when `MAX_EXEC_ITER` packets pass without
/// an answer, or on transport errors.
pub fn open_session_for_agent_forward(&mut self) -> Result<u32> {
    let (channel_id, open_msg) = self.conn.open(ChannelOpen::Session)?;
    self.write_payload(&open_msg)?;
    let mut confirmed = false;
    for _ in 0..MAX_EXEC_ITER {
        let packet = self.read_one_packet()?;
        match self.conn.on_packet(&packet)? {
            ChannelEvent::OpenConfirmed { channel } if channel == channel_id => {
                confirmed = true;
                break;
            }
            ChannelEvent::OpenFailed { channel, .. } if channel == channel_id => {
                return Err(Error::Protocol("agent-forward: channel open failed"));
            }
            _ => {}
        }
    }
    if !confirmed {
        return Err(Error::Protocol("agent-forward: open loop did not converge"));
    }
    let request = self
        .conn
        .send_request(channel_id, ChannelRequest::AuthAgentReq, false)?;
    self.write_payload(&request)?;
    Ok(channel_id)
}
/// Opens a session channel, waits for the peer to confirm it, then sends an
/// X11-forwarding request built from the arguments and returns the local
/// channel id.
///
/// Packets that do not concern the new channel are ignored while waiting.
///
/// # Errors
///
/// Fails when the open is refused, when `MAX_EXEC_ITER` packets pass without
/// an answer, or on transport errors.
pub fn open_session_for_x11_forward(
    &mut self,
    single_connection: bool,
    auth_protocol: &str,
    auth_cookie: &str,
    screen: u32,
) -> Result<u32> {
    let (channel_id, open_msg) = self.conn.open(ChannelOpen::Session)?;
    self.write_payload(&open_msg)?;
    let mut confirmed = false;
    for _ in 0..MAX_EXEC_ITER {
        let packet = self.read_one_packet()?;
        match self.conn.on_packet(&packet)? {
            ChannelEvent::OpenConfirmed { channel } if channel == channel_id => {
                confirmed = true;
                break;
            }
            ChannelEvent::OpenFailed { channel, .. } if channel == channel_id => {
                return Err(Error::Protocol("x11-forward: channel open failed"));
            }
            _ => {}
        }
    }
    if !confirmed {
        return Err(Error::Protocol("x11-forward: open loop did not converge"));
    }
    let x11_request = ChannelRequest::X11Req {
        single_connection,
        auth_protocol: auth_protocol.to_string(),
        auth_cookie: auth_cookie.to_string(),
        screen,
    };
    let request = self.conn.send_request(channel_id, x11_request, false)?;
    self.write_payload(&request)?;
    Ok(channel_id)
}
/// Sends a close for `channel`.
///
/// If the channel layer cannot build the close message, the call is treated
/// as a no-op and returns `Ok(())`; only the write itself can fail.
pub fn close_session(&mut self, channel: u32) -> Result<()> {
    if let Ok(close_msg) = self.conn.send_close(channel) {
        self.write_payload(&close_msg)?;
    }
    Ok(())
}
/// Enables or disables the agent-forwarding request that
/// `maybe_send_auth_agent_req` sends on newly opened session channels.
pub fn set_request_auth_agent_forwarding(&mut self, on: bool) {
self.request_auth_agent = on;
}
/// Sends an agent-forwarding request on `channel` when agent forwarding has
/// been enabled on this client; otherwise does nothing.
///
/// # Errors
///
/// Propagates failures from building or writing the request.
pub(crate) fn maybe_send_auth_agent_req(&mut self, channel: u32) -> Result<()> {
    if !self.request_auth_agent {
        return Ok(());
    }
    let request = self
        .conn
        .send_request(channel, ChannelRequest::AuthAgentReq, false)?;
    self.write_payload(&request)
}
/// Stores or clears the X11-forwarding arguments that `maybe_send_x11_req`
/// sends on newly opened session channels.
///
/// The tuple is `(single_connection, auth_protocol, auth_cookie, screen)`;
/// `None` disables the request.
pub fn set_request_x11_forwarding(&mut self, args: Option<(bool, String, String, u32)>) {
    self.request_x11 = match args {
        Some((single_connection, auth_protocol, auth_cookie, screen)) => Some(X11ReqArgs {
            single_connection,
            auth_protocol,
            auth_cookie,
            screen,
        }),
        None => None,
    };
}
pub(crate) fn maybe_send_x11_req(&mut self, channel: u32) -> Result<()> {
if let Some(args) = self.request_x11.clone() {
let p = self.conn.send_request(
channel,
ChannelRequest::X11Req {
single_connection: args.single_connection,
auth_protocol: args.auth_protocol,
auth_cookie: args.auth_cookie,
screen: args.screen,
},
false,
)?;
self.write_payload(&p)?;
}
Ok(())
}
/// Runs `command` on a fresh session channel and collects its output until
/// both sides have closed the channel.
///
/// Steps: open a session channel, optionally send agent/X11 forwarding
/// requests, send the exec request and wait for its reply, then drain channel
/// events into an [`ExecOutput`].
///
/// # Errors
///
/// Returns a protocol error when the channel open or the exec request is
/// refused, when the collected output would exceed `MAX_EXEC_OUTPUT`, or when
/// any of the loops exceeds `MAX_EXEC_ITER` iterations. Transport errors are
/// propagated.
pub fn exec(&mut self, command: &str) -> Result<ExecOutput> {
let (local_id, open_payload) = self.conn.open(ChannelOpen::Session)?;
self.write_payload(&open_payload)?;
// Phase 1: wait for the peer to confirm or refuse the channel open.
let mut opened = false;
let mut iter_guard = 0usize;
while !opened {
iter_guard += 1;
if iter_guard > MAX_EXEC_ITER {
return Err(Error::Protocol("exec: open loop did not converge"));
}
let payload = self.read_one_packet()?;
match self.conn.on_packet(&payload)? {
ChannelEvent::OpenConfirmed { channel } if channel == local_id => {
opened = true;
}
ChannelEvent::OpenFailed { channel, .. } if channel == local_id => {
return Err(Error::Protocol("channel open failed"));
}
// Events for other channels are ignored here.
_ => {}
}
}
// Forwarding requests are sent only when enabled on the client.
self.maybe_send_auth_agent_req(local_id)?;
self.maybe_send_x11_req(local_id)?;
// Phase 2: send the exec request and wait for success or failure.
let exec_req = self.conn.send_request(
local_id,
ChannelRequest::Exec {
command: command.into(),
},
true,
)?;
self.write_payload(&exec_req)?;
let mut exec_accepted = false;
iter_guard = 0;
while !exec_accepted {
iter_guard += 1;
if iter_guard > MAX_EXEC_ITER {
return Err(Error::Protocol("exec: request loop did not converge"));
}
let payload = self.read_one_packet()?;
match self.conn.on_packet(&payload)? {
ChannelEvent::Success { channel } if channel == local_id => exec_accepted = true,
ChannelEvent::Failure { channel } if channel == local_id => {
return Err(Error::Protocol("exec request denied"));
}
// NOTE(review): any other event, including channel data that arrives
// before the reply, is dropped here without window replenishment.
_ => {}
}
}
// Phase 3: drain output until both close messages have been exchanged.
let mut out = ExecOutput {
stdout: Vec::new(),
stderr: Vec::new(),
exit_status: None,
exit_signal: None,
};
let mut local_eof_sent = false;
let mut local_close_sent = false;
let mut remote_close_seen = false;
for _ in 0..MAX_EXEC_ITER {
if remote_close_seen && local_close_sent {
break;
}
let payload = self.read_one_packet()?;
let ev = self.conn.on_packet(&payload)?;
match ev {
ChannelEvent::Data { channel, data } if channel == local_id => {
// The cap applies to stdout and stderr combined.
if out.stdout.len() + out.stderr.len() + data.len() > MAX_EXEC_OUTPUT {
return Err(Error::Protocol("exec output too large"));
}
let n = data.len() as u32;
out.stdout.extend_from_slice(&data);
// Return window credit to the peer when the channel layer asks for it.
if let Some(adj) = self.conn.replenish_window(local_id, n)? {
self.write_payload(&adj)?;
}
}
ChannelEvent::ExtendedData {
channel,
code,
data,
} if channel == local_id => {
if out.stdout.len() + out.stderr.len() + data.len() > MAX_EXEC_OUTPUT {
return Err(Error::Protocol("exec output too large"));
}
let n = data.len() as u32;
// Only the stderr type code goes to `stderr`; other codes land in `stdout`.
if code == SSH_EXTENDED_DATA_STDERR {
out.stderr.extend_from_slice(&data);
} else {
out.stdout.extend_from_slice(&data);
}
if let Some(adj) = self.conn.replenish_window(local_id, n)? {
self.write_payload(&adj)?;
}
}
ChannelEvent::Request {
channel,
request,
want_reply,
} if channel == local_id => {
match request {
ChannelRequest::ExitStatus { code } => out.exit_status = Some(code),
ChannelRequest::ExitSignal { name, .. } => out.exit_signal = Some(name),
_ => {}
}
// Every peer request that wants a reply is answered with a failure.
if want_reply {
let p = self.conn.send_request_failure(local_id)?;
self.write_payload(&p)?;
}
}
// Answer the peer's EOF with our own EOF, once.
ChannelEvent::Eof { channel } if channel == local_id && !local_eof_sent => {
let p = self.conn.send_eof(local_id)?;
self.write_payload(&p)?;
local_eof_sent = true;
}
// Answer the peer's close with our own close, once.
ChannelEvent::Close { channel } if channel == local_id => {
remote_close_seen = true;
if !local_close_sent {
let p = self.conn.send_close(local_id)?;
self.write_payload(&p)?;
local_close_sent = true;
}
}
ChannelEvent::WindowAdjust { .. } => {}
_ => {}
}
}
// Reaching the iteration cap without a completed close handshake is an error.
if !(remote_close_seen && local_close_sent) {
return Err(Error::Protocol("exec: drain loop exceeded iteration cap"));
}
Ok(out)
}
/// Starts an interactive shell with a pseudo-terminal on a fresh session
/// channel, feeds it `stdin`, sends EOF, and collects output until both sides
/// have closed the channel.
///
/// `term`, `cols` and `rows` are placed in the pty request; pixel dimensions
/// are sent as 0 and the mode list is empty.
///
/// # Errors
///
/// Returns a protocol error when the channel open or a request is refused,
/// when the peer closes the channel while stdin is still being sent, when the
/// collected output would exceed `MAX_EXEC_OUTPUT`, or when any loop exceeds
/// `MAX_EXEC_ITER` iterations. Transport errors are propagated.
pub fn shell_with_stdin(
&mut self,
term: &str,
cols: u32,
rows: u32,
stdin: &[u8],
) -> Result<ExecOutput> {
let (local_id, open_payload) = self.conn.open(ChannelOpen::Session)?;
self.write_payload(&open_payload)?;
// Phase 1: wait for the peer to confirm or refuse the channel open.
let mut opened = false;
let mut iter_guard = 0usize;
while !opened {
iter_guard += 1;
if iter_guard > MAX_EXEC_ITER {
return Err(Error::Protocol("shell: open loop did not converge"));
}
let payload = self.read_one_packet()?;
match self.conn.on_packet(&payload)? {
ChannelEvent::OpenConfirmed { channel } if channel == local_id => {
opened = true;
}
ChannelEvent::OpenFailed { channel, .. } if channel == local_id => {
return Err(Error::Protocol("channel open failed"));
}
_ => {}
}
}
// Forwarding requests are sent only when enabled on the client.
self.maybe_send_auth_agent_req(local_id)?;
self.maybe_send_x11_req(local_id)?;
// Phase 2: request a pty, then the shell, waiting for each reply.
let pty_req = self.conn.send_request(
local_id,
ChannelRequest::PtyReq {
term: term.into(),
cols,
rows,
px_w: 0,
px_h: 0,
modes: Vec::new(),
},
true,
)?;
self.write_payload(&pty_req)?;
self.await_request_reply(local_id, "pty-req")?;
let shell_req = self
.conn
.send_request(local_id, ChannelRequest::Shell, true)?;
self.write_payload(&shell_req)?;
self.await_request_reply(local_id, "shell")?;
// Phase 3: push stdin, waiting for more window whenever nothing can be sent.
if !stdin.is_empty() {
let mut off = 0usize;
iter_guard = 0;
while off < stdin.len() {
iter_guard += 1;
if iter_guard > MAX_EXEC_ITER {
return Err(Error::Protocol("shell: stdin drain loop did not converge"));
}
let (payload, taken) = self.conn.send_data(local_id, &stdin[off..])?;
// `taken == 0` means nothing could be sent; read a packet and retry.
if taken == 0 {
let pkt = self.read_one_packet()?;
match self.conn.on_packet(&pkt)? {
ChannelEvent::WindowAdjust { channel, .. } if channel == local_id => {}
ChannelEvent::Close { channel } if channel == local_id => {
return Err(Error::Protocol(
"shell: peer closed channel before stdin drain",
));
}
// NOTE(review): shell output that arrives while waiting for window
// space is dropped here and is not credited back to the peer.
_ => {}
}
continue;
}
self.write_payload(&payload)?;
off += taken;
}
}
// Signal end of input before draining output.
let eof = self.conn.send_eof(local_id)?;
self.write_payload(&eof)?;
// Phase 4: drain output until both close messages have been exchanged.
let mut out = ExecOutput {
stdout: Vec::new(),
stderr: Vec::new(),
exit_status: None,
exit_signal: None,
};
let mut local_close_sent = false;
let mut remote_close_seen = false;
for _ in 0..MAX_EXEC_ITER {
if remote_close_seen && local_close_sent {
break;
}
let payload = self.read_one_packet()?;
let ev = self.conn.on_packet(&payload)?;
match ev {
ChannelEvent::Data { channel, data } if channel == local_id => {
// The cap applies to stdout and stderr combined.
if out.stdout.len() + out.stderr.len() + data.len() > MAX_EXEC_OUTPUT {
return Err(Error::Protocol("shell output too large"));
}
let n = data.len() as u32;
out.stdout.extend_from_slice(&data);
if let Some(adj) = self.conn.replenish_window(local_id, n)? {
self.write_payload(&adj)?;
}
}
ChannelEvent::ExtendedData {
channel,
code,
data,
} if channel == local_id => {
if out.stdout.len() + out.stderr.len() + data.len() > MAX_EXEC_OUTPUT {
return Err(Error::Protocol("shell output too large"));
}
let n = data.len() as u32;
// Only the stderr type code goes to `stderr`; other codes land in `stdout`.
if code == SSH_EXTENDED_DATA_STDERR {
out.stderr.extend_from_slice(&data);
} else {
out.stdout.extend_from_slice(&data);
}
if let Some(adj) = self.conn.replenish_window(local_id, n)? {
self.write_payload(&adj)?;
}
}
ChannelEvent::Request {
channel,
request,
want_reply,
} if channel == local_id => {
match request {
ChannelRequest::ExitStatus { code } => out.exit_status = Some(code),
ChannelRequest::ExitSignal { name, .. } => out.exit_signal = Some(name),
_ => {}
}
// Every peer request that wants a reply is answered with a failure.
if want_reply {
let p = self.conn.send_request_failure(local_id)?;
self.write_payload(&p)?;
}
}
// Our EOF was already sent after stdin, so the peer's EOF needs no answer.
ChannelEvent::Eof { channel } if channel == local_id => {}
ChannelEvent::Close { channel } if channel == local_id => {
remote_close_seen = true;
if !local_close_sent {
let p = self.conn.send_close(local_id)?;
self.write_payload(&p)?;
local_close_sent = true;
}
}
ChannelEvent::WindowAdjust { .. } => {}
_ => {}
}
}
// Reaching the iteration cap without a completed close handshake is an error.
if !(remote_close_seen && local_close_sent) {
return Err(Error::Protocol("shell: drain loop exceeded iteration cap"));
}
Ok(out)
}
/// Starts subsystem `name` on a fresh session channel, feeds it `stdin`,
/// sends EOF, and returns the ordinary channel data received until both sides
/// have closed the channel.
///
/// Extended data is discarded (its window credit is still returned). Unlike
/// `exec`, no agent or X11 forwarding requests are sent.
///
/// # Errors
///
/// Returns a protocol error when the channel open or the subsystem request is
/// refused, when the peer closes the channel while stdin is still being sent,
/// when the collected output would exceed `MAX_EXEC_OUTPUT`, or when any loop
/// exceeds `MAX_EXEC_ITER` iterations. Transport errors are propagated.
pub fn subsystem_once(&mut self, name: &str, stdin: &[u8]) -> Result<Vec<u8>> {
let (local_id, open_payload) = self.conn.open(ChannelOpen::Session)?;
self.write_payload(&open_payload)?;
// Phase 1: wait for the peer to confirm or refuse the channel open.
let mut opened = false;
let mut iter_guard = 0usize;
while !opened {
iter_guard += 1;
if iter_guard > MAX_EXEC_ITER {
return Err(Error::Protocol("subsystem: open loop did not converge"));
}
let payload = self.read_one_packet()?;
match self.conn.on_packet(&payload)? {
ChannelEvent::OpenConfirmed { channel } if channel == local_id => opened = true,
ChannelEvent::OpenFailed { channel, .. } if channel == local_id => {
return Err(Error::Protocol("channel open failed"));
}
_ => {}
}
}
// Phase 2: request the subsystem and wait for the reply.
let sub_req = self.conn.send_request(
local_id,
ChannelRequest::Subsystem { name: name.into() },
true,
)?;
self.write_payload(&sub_req)?;
self.await_request_reply(local_id, "subsystem")?;
// Phase 3: push stdin, waiting for more window whenever nothing can be sent.
if !stdin.is_empty() {
let mut off = 0usize;
iter_guard = 0;
while off < stdin.len() {
iter_guard += 1;
if iter_guard > MAX_EXEC_ITER {
return Err(Error::Protocol(
"subsystem: stdin drain loop did not converge",
));
}
let (payload, taken) = self.conn.send_data(local_id, &stdin[off..])?;
// `taken == 0` means nothing could be sent; read a packet and retry.
if taken == 0 {
let pkt = self.read_one_packet()?;
match self.conn.on_packet(&pkt)? {
ChannelEvent::WindowAdjust { channel, .. } if channel == local_id => {}
ChannelEvent::Close { channel } if channel == local_id => {
return Err(Error::Protocol(
"subsystem: peer closed channel before stdin drain",
));
}
// NOTE(review): subsystem output that arrives while waiting for
// window space is dropped here and is not credited back to the peer.
_ => {}
}
continue;
}
self.write_payload(&payload)?;
off += taken;
}
}
// Signal end of input before draining output.
let eof = self.conn.send_eof(local_id)?;
self.write_payload(&eof)?;
// Phase 4: drain output until both close messages have been exchanged.
let mut out = Vec::<u8>::new();
let mut local_close_sent = false;
let mut remote_close_seen = false;
for _ in 0..MAX_EXEC_ITER {
if remote_close_seen && local_close_sent {
break;
}
let payload = self.read_one_packet()?;
let ev = self.conn.on_packet(&payload)?;
match ev {
ChannelEvent::Data { channel, data } if channel == local_id => {
if out.len() + data.len() > MAX_EXEC_OUTPUT {
return Err(Error::Protocol("subsystem output too large"));
}
let n = data.len() as u32;
out.extend_from_slice(&data);
if let Some(adj) = self.conn.replenish_window(local_id, n)? {
self.write_payload(&adj)?;
}
}
// Extended data is not kept, but its window credit is still returned.
ChannelEvent::ExtendedData {
channel,
code: _,
data,
} if channel == local_id => {
let n = data.len() as u32;
if let Some(adj) = self.conn.replenish_window(local_id, n)? {
self.write_payload(&adj)?;
}
}
ChannelEvent::Eof { channel } if channel == local_id => {}
ChannelEvent::Close { channel } if channel == local_id => {
remote_close_seen = true;
if !local_close_sent {
let p = self.conn.send_close(local_id)?;
self.write_payload(&p)?;
local_close_sent = true;
}
}
ChannelEvent::WindowAdjust { .. } => {}
_ => {}
}
}
// Reaching the iteration cap without a completed close handshake is an error.
if !(remote_close_seen && local_close_sent) {
return Err(Error::Protocol(
"subsystem: drain loop exceeded iteration cap",
));
}
Ok(out)
}
#[cfg_attr(
feature = "multichannel",
deprecated(
since = "0.0.2",
note = "Use SharedClient::sftp instead; the borrow-based API \
prevents multiple concurrent channels on one connection."
)
)]
pub fn sftp(&mut self) -> Result<SftpClient<ClientChannelStream<'_>>> {
let (local_id, open_payload) = self.conn.open(ChannelOpen::Session)?;
self.write_payload(&open_payload)?;
let mut opened = false;
let mut iter_guard = 0usize;
while !opened {
iter_guard += 1;
if iter_guard > MAX_EXEC_ITER {
return Err(Error::Protocol("sftp: open loop did not converge"));
}
let payload = self.read_one_packet()?;
match self.conn.on_packet(&payload)? {
ChannelEvent::OpenConfirmed { channel } if channel == local_id => opened = true,
ChannelEvent::OpenFailed { channel, .. } if channel == local_id => {
return Err(Error::Protocol("channel open failed"));
}
_ => {}
}
}
self.maybe_send_auth_agent_req(local_id)?;
self.maybe_send_x11_req(local_id)?;
let sub_req = self.conn.send_request(
local_id,
ChannelRequest::Subsystem {
name: "sftp".into(),
},
true,
)?;
self.write_payload(&sub_req)?;
self.await_request_reply(local_id, "subsystem")?;
let stream = ClientChannelStream {
client: self,
channel: local_id,
read_buf: Vec::new(),
stderr_buf: Vec::new(),
remote_eof: false,
local_close_sent: false,
};
match SftpClient::new(stream) {
Ok(c) => Ok(c),
Err(e) => Err(Error::Protocol(match e {
crate::sftp::SftpError::Protocol(s) => s,
_ => "sftp: handshake failed",
})),
}
}
/// Opens a session channel, starts `command` on it and returns a stream over
/// the channel that borrows this connection for its lifetime.
///
/// Agent/X11 forwarding requests are sent first when enabled on the client.
///
/// # Errors
///
/// Fails when the channel open or exec request is refused, when the open loop
/// exceeds `MAX_EXEC_ITER` packets, or on transport errors.
#[cfg_attr(
    feature = "multichannel",
    deprecated(
        since = "0.0.2",
        note = "Use SharedClient::exec_stream for multi-channel support."
    )
)]
pub fn exec_stream(&mut self, command: &str) -> Result<ClientChannelStream<'_>> {
    let (channel_id, open_msg) = self.conn.open(ChannelOpen::Session)?;
    self.write_payload(&open_msg)?;
    let mut confirmed = false;
    for _ in 0..MAX_EXEC_ITER {
        let packet = self.read_one_packet()?;
        match self.conn.on_packet(&packet)? {
            ChannelEvent::OpenConfirmed { channel } if channel == channel_id => {
                confirmed = true;
                break;
            }
            ChannelEvent::OpenFailed { channel, .. } if channel == channel_id => {
                return Err(Error::Protocol("exec_stream: channel open failed"));
            }
            _ => {}
        }
    }
    if !confirmed {
        return Err(Error::Protocol("exec_stream: open loop did not converge"));
    }
    self.maybe_send_auth_agent_req(channel_id)?;
    self.maybe_send_x11_req(channel_id)?;
    let exec = ChannelRequest::Exec {
        command: command.into(),
    };
    let request = self.conn.send_request(channel_id, exec, true)?;
    self.write_payload(&request)?;
    self.await_request_reply(channel_id, "exec")?;
    let stream = ClientChannelStream {
        client: self,
        channel: channel_id,
        read_buf: Vec::new(),
        stderr_buf: Vec::new(),
        remote_eof: false,
        local_close_sent: false,
    };
    Ok(stream)
}
/// Opens a direct TCP/IP channel to `dest_host:dest_port`, reporting
/// `orig_host:orig_port` as the originator, and returns a stream over the
/// channel that borrows this connection for its lifetime.
///
/// # Errors
///
/// Fails when the open is refused, when the open loop exceeds
/// `MAX_EXEC_ITER` packets, or on transport errors.
#[cfg_attr(
    feature = "multichannel",
    deprecated(
        since = "0.0.2",
        note = "Use SharedClient::open_direct_tcpip for multi-channel support."
    )
)]
pub fn open_direct_tcpip(
    &mut self,
    dest_host: &str,
    dest_port: u16,
    orig_host: &str,
    orig_port: u16,
) -> Result<ClientChannelStream<'_>> {
    let open_request = ChannelOpen::DirectTcpip {
        dest_host: dest_host.to_string(),
        dest_port: u32::from(dest_port),
        orig_host: orig_host.to_string(),
        orig_port: u32::from(orig_port),
    };
    let (channel_id, open_msg) = self.conn.open(open_request)?;
    self.write_payload(&open_msg)?;
    let mut confirmed = false;
    for _ in 0..MAX_EXEC_ITER {
        let packet = self.read_one_packet()?;
        match self.conn.on_packet(&packet)? {
            ChannelEvent::OpenConfirmed { channel } if channel == channel_id => {
                confirmed = true;
                break;
            }
            ChannelEvent::OpenFailed { channel, .. } if channel == channel_id => {
                return Err(Error::Protocol("direct-tcpip: open failed"));
            }
            _ => {}
        }
    }
    if !confirmed {
        return Err(Error::Protocol("direct-tcpip: open loop did not converge"));
    }
    let stream = ClientChannelStream {
        client: self,
        channel: channel_id,
        read_buf: Vec::new(),
        stderr_buf: Vec::new(),
        remote_eof: false,
        local_close_sent: false,
    };
    Ok(stream)
}
/// Uploads each path in `sources` to `remote_dest` by running the command
/// built by `build_scp_to_cmd` on the server and driving the SCP sender over
/// the resulting stream.
///
/// When the transfer fails and the server wrote anything to stderr, that text
/// is printed to the local stderr before the error is returned.
///
/// # Errors
///
/// Fails when the command cannot be built or started, or when the SCP
/// handshake or any `send_path` call fails.
pub fn scp_send_to(
    &mut self,
    sources: &[&std::path::Path],
    remote_dest: &str,
    opts: crate::scp::ScpSendOptions,
) -> Result<()> {
    let remote_cmd = build_scp_to_cmd(remote_dest, &opts)?;
    #[allow(deprecated)]
    let mut stream = self.exec_stream(&remote_cmd)?;
    // Run the transfer in a closure so the stream's mutable borrow ends before
    // stderr is collected below.
    let outcome = (|| -> Result<()> {
        let mut sender = crate::scp::Sender::new(&mut stream)
            .map_err(|e| scp_proto(e, "scp_send_to: handshake"))?;
        for src in sources {
            sender
                .send_path(src, &opts)
                .map_err(|e| scp_proto(e, "scp_send_to: send_path"))?;
        }
        Ok(())
    })();
    let remote_stderr = stream.take_stderr();
    if let Err(err) = outcome {
        if !remote_stderr.is_empty() {
            let msg = String::from_utf8_lossy(&remote_stderr).trim().to_string();
            eprintln!("scp_send_to: remote stderr: {}", msg);
        }
        return Err(err);
    }
    Ok(())
}
/// Downloads `remote_source` into `local_dest` by running the command built by
/// `build_scp_from_cmd` on the server and driving the SCP receiver over the
/// resulting stream.
///
/// Unless the options already say so or the transfer is recursive, the
/// destination is treated as a file whenever it is not an existing directory.
/// The remote command is built before this adjustment is applied.
///
/// When the transfer fails and the server wrote anything to stderr, that text
/// is printed to the local stderr before the error is returned.
///
/// # Errors
///
/// Fails when the command cannot be built or started, or when the SCP
/// handshake or receive loop fails.
pub fn scp_recv_from(
    &mut self,
    remote_source: &str,
    local_dest: &std::path::Path,
    mut opts: crate::scp::ScpRecvOptions,
) -> Result<()> {
    let remote_cmd = build_scp_from_cmd(remote_source, &opts)?;
    if !opts.target_is_file && !opts.recursive {
        // A missing or unreadable destination counts as "not a directory".
        let dest_is_dir = std::fs::metadata(local_dest)
            .map(|md| md.is_dir())
            .unwrap_or(false);
        if !dest_is_dir {
            opts.target_is_file = true;
        }
    }
    #[allow(deprecated)]
    let mut stream = self.exec_stream(&remote_cmd)?;
    // Run the transfer in a closure so the stream's mutable borrow ends before
    // stderr is collected below.
    let outcome = (|| -> Result<()> {
        let mut receiver = crate::scp::Receiver::new(&mut stream, local_dest, opts)
            .map_err(|e| scp_proto(e, "scp_recv_from: handshake"))?;
        receiver
            .run()
            .map_err(|e| scp_proto(e, "scp_recv_from: run"))?;
        Ok(())
    })();
    let remote_stderr = stream.take_stderr();
    if let Err(err) = outcome {
        if !remote_stderr.is_empty() {
            let msg = String::from_utf8_lossy(&remote_stderr).trim().to_string();
            eprintln!("scp_recv_from: remote stderr: {}", msg);
        }
        return Err(err);
    }
    Ok(())
}
/// Asks the server to listen on `bind_address:bind_port` and forward incoming
/// connections, returning the port actually in use.
///
/// When `bind_port` is non-zero it is returned unchanged after the server
/// accepts. When it is 0, the assigned port is read from the reply data.
///
/// # Errors
///
/// Fails when the request is denied, when the reply for a zero port lacks the
/// assigned port or carries a value above `u16::MAX`, or on transport errors.
pub fn request_tcpip_forward(&mut self, bind_address: &str, bind_port: u16) -> Result<u16> {
    use crate::channel::GlobalRequest;
    let forward = GlobalRequest::TcpipForward {
        bind_address: bind_address.to_string(),
        bind_port: u32::from(bind_port),
    };
    let request = self.conn.send_global_request(forward, true);
    self.write_payload(&request)?;
    let reply = self.await_global_reply("tcpip-forward")?;
    if bind_port != 0 {
        return Ok(bind_port);
    }
    let mut reader = crate::format::Reader::new(&reply);
    let assigned = reader
        .read_u32()
        .map_err(|_| Error::Protocol("tcpip-forward: server omitted assigned-port tail"))?;
    if assigned > u32::from(u16::MAX) {
        return Err(Error::Protocol(
            "tcpip-forward: server returned out-of-range port",
        ));
    }
    // Range was checked above, so the narrowing cast cannot truncate.
    Ok(assigned as u16)
}
/// Asks the server to stop forwarding connections for
/// `bind_address:bind_port` and waits for its reply.
///
/// # Errors
///
/// Fails when the request is denied or on transport errors.
pub fn cancel_tcpip_forward(&mut self, bind_address: &str, bind_port: u16) -> Result<()> {
    use crate::channel::GlobalRequest;
    let cancel = GlobalRequest::CancelTcpipForward {
        bind_address: bind_address.to_string(),
        bind_port: u32::from(bind_port),
    };
    let request = self.conn.send_global_request(cancel, true);
    self.write_payload(&request)?;
    // The reply carries no information the caller needs.
    self.await_global_reply("cancel-tcpip-forward").map(|_| ())
}
/// Runs the connection's event loop until a stop is requested or an error
/// occurs.
///
/// Each iteration, in order:
/// 1. replays one deferred packet, if any, when no key exchange is running;
/// 2. drains queued [`ServeCommand`]s (only when a command channel exists);
/// 3. services the per-channel runtimes and drops those that sent a close;
/// 4. returns `Ok(())` if `handlers.stop` is set and nothing is outstanding;
/// 5. starts a rekey when the rekey policy asks for one;
/// 6. reads one packet (a read timeout simply restarts the loop) and
///    dispatches it.
///
/// While the loop runs, the socket read timeout is set to
/// `SERVE_POLL_INTERVAL`. On exit the read timeout that was in effect before
/// the call is restored, so a timeout configured at connect time keeps
/// applying to later blocking calls.
///
/// On exit, every still-pending outbound open is answered with a
/// "serve loop terminated" error and all runtimes are dropped.
///
/// # Errors
///
/// Returns the first error raised by a helper or by the transport, or a
/// protocol error when `MAX_SERVE_STEPS` iterations are exceeded.
pub fn serve(&mut self, handlers: ClientHandlers) -> Result<()> {
    let mut runtimes: BTreeMap<u32, ServeRuntime> = BTreeMap::new();
    let mut pending_opens: BTreeMap<u32, PendingOutboundOpen> = BTreeMap::new();
    // Remember the caller's read timeout so it can be restored on exit. If it
    // cannot be queried, fall back to "no timeout".
    let saved_read_timeout = self.stream.read_timeout().ok().flatten();
    // Best effort: without the poll timeout the loop still works, it just
    // reacts to commands and the stop flag only when packets arrive.
    let _ = self.stream.set_read_timeout(Some(SERVE_POLL_INTERVAL));
    let mut steps = 0usize;
    let result = loop {
        steps += 1;
        if steps > MAX_SERVE_STEPS {
            break Err(Error::Protocol("serve: step cap exceeded"));
        }
        // Replay packets that were set aside during a key exchange, one per
        // iteration, before touching anything else.
        if !self.runner.is_kexing() && !self.deferred.is_empty() {
            let payload = self.deferred.remove(0);
            if let Err(e) = serve_dispatch_packet(
                self,
                &handlers,
                &mut runtimes,
                &mut pending_opens,
                &payload,
            ) {
                break Err(e);
            }
            continue;
        }
        // Commands and channel traffic are held back while a key exchange is
        // in progress.
        if !self.runner.is_kexing() {
            if let Some(rx) = handlers.cmd_rx.as_ref() {
                if let Err(e) = serve_drain_commands(self, rx, &mut pending_opens) {
                    break Err(e);
                }
            }
        }
        if !self.runner.is_kexing() {
            if let Err(e) = serve_drain_runtimes(self, &mut runtimes) {
                break Err(e);
            }
            runtimes.retain(|_, rt| !rt.close_sent);
        }
        if handlers.stop.load(Ordering::SeqCst)
            && runtimes.is_empty()
            && pending_opens.is_empty()
        {
            break Ok(());
        }
        if !self.runner.is_kexing()
            && self
                .rekey_policy
                .should_rekey(&self.codec, self.last_kex, Instant::now())
        {
            if let Err(e) = self.initiate_rekey() {
                break Err(e);
            }
        }
        let payload = match self.read_one_packet_maybe_timeout() {
            Ok(Some(p)) => p,
            // Poll interval elapsed with nothing to read: run the loop again.
            Ok(None) => continue,
            Err(e) => break Err(e),
        };
        if let Err(e) =
            serve_dispatch_packet(self, &handlers, &mut runtimes, &mut pending_opens, &payload)
        {
            break Err(e);
        }
    };
    // Unblock everyone still waiting for an outbound open.
    let stale_opens = core::mem::take(&mut pending_opens);
    for (_ch, po) in stale_opens {
        let _ = po.reply.send(Err(Error::Protocol("serve loop terminated")));
    }
    runtimes.clear();
    // Put back the timeout that was active before the loop started instead of
    // unconditionally clearing it.
    let _ = self.stream.set_read_timeout(saved_read_timeout);
    result
}
fn read_one_packet_maybe_timeout(&mut self) -> Result<Option<Vec<u8>>> {
match self.read_one_packet() {
Ok(p) => Ok(Some(p)),
Err(Error::Io(e))
if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut =>
{
Ok(None)
}
Err(e) => Err(e),
}
}
fn await_global_reply(&mut self, what: &'static str) -> Result<Vec<u8>> {
for _ in 0..MAX_EXEC_ITER {
let payload = self.read_one_packet()?;
match self.conn.on_packet(&payload)? {
ChannelEvent::GlobalSuccess { data } => return Ok(data),
ChannelEvent::GlobalFailure => {
let _ = what; return Err(Error::Protocol("global request denied"));
}
_ => {}
}
}
Err(Error::Protocol(
"global request: reply loop did not converge",
))
}
pub(crate) fn await_request_reply(&mut self, channel: u32, what: &'static str) -> Result<()> {
for _ in 0..MAX_EXEC_ITER {
let payload = self.read_one_packet()?;
match self.conn.on_packet(&payload)? {
ChannelEvent::Success { channel: c } if c == channel => return Ok(()),
ChannelEvent::Failure { channel: c } if c == channel => {
let _ = what; return Err(Error::Protocol("shell: channel request denied"));
}
_ => {}
}
}
Err(Error::Protocol(
"shell: request-reply loop did not converge",
))
}
fn do_version_and_kex(&mut self) -> Result<()> {
let v_c = crate::transport::version::LOCAL_VERSION.as_bytes().to_vec();
self.stream.write_all(&VersionExchange::outgoing_bytes())?;
let v_s = self.read_peer_version()?;
self.v_c = v_c;
self.v_s = v_s;
let advert = build_default_kexinit(&mut self.rng);
self.runner = KexRunner::new(Role::Client, advert);
let initial = self.runner.start(&mut self.rng)?;
for p in initial.outbound {
self.write_payload(&p)?;
}
self.drive_kex_to_completion()?;
self.session_id = self
.runner
.session_id()
.ok_or(Error::Protocol("kex: missing session id"))?
.to_vec();
self.last_kex = Instant::now();
Ok(())
}
fn drive_kex_to_completion(&mut self) -> Result<()> {
let mut steps = 0usize;
loop {
steps += 1;
if steps > MAX_KEX_STEPS {
return Err(Error::Protocol("kex: too many steps"));
}
let payload = self.read_one_raw_kex_packet()?;
let b = *payload.first().ok_or(Error::Format("empty payload"))?;
if is_kex_msg(b) {
self.dispatch_kex_packet(&payload)?;
if self.runner.is_completed() {
return Ok(());
}
} else {
self.deferred.push(payload);
}
}
}
/// Feeds one key-exchange payload to the runner and writes out any replies.
///
/// When the payload is an `SSH_MSG_KEX_ECDH_REPLY`, a host-key verifier is
/// first built from the configured host-key policy and handed to the runner
/// alongside the packet; for every other message no verifier is supplied.
///
/// # Errors
/// Fails on an empty payload, when the host key is rejected while building
/// the verifier, when the runner rejects the packet, or when a reply cannot
/// be written.
fn dispatch_kex_packet(&mut self, payload: &[u8]) -> Result<()> {
    let msg = *payload.first().ok_or(Error::Format("empty kex payload"))?;

    // The boxed verifier must outlive the borrow passed to the runner below.
    let verifier_box: Option<Box<dyn HostKeyVerify>> = if msg == SSH_MSG_KEX_ECDH_REPLY {
        Some(build_verifier(
            payload,
            &self.host_key_policy,
            &self.runner,
            &self.target_host,
            self.target_port,
        )?)
    } else {
        None
    };
    let verifier: Option<&dyn HostKeyVerify> = verifier_box.as_deref();

    // `runner`, `rng`, `codec`, `v_c` and `v_s` are distinct fields, so they
    // can be borrowed side by side; no need to clone the version strings on
    // every key-exchange packet.
    let adv = self.runner.on_packet(
        &mut self.rng,
        &mut self.codec,
        payload,
        None,
        verifier,
        &self.v_c,
        &self.v_s,
    )?;
    for p in adv.outbound {
        self.write_payload(&p)?;
    }
    Ok(())
}
fn read_one_raw_kex_packet(&mut self) -> Result<Vec<u8>> {
loop {
let payload = self.read_one_raw_packet()?;
match payload.first().copied() {
Some(1) => return Err(Error::Protocol("peer sent SSH_MSG_DISCONNECT")),
Some(2) | Some(3) | Some(4) => continue,
_ => return Ok(payload),
}
}
}
/// Reads lines from the peer until one begins with `SSH-`, then parses it.
///
/// Lines that do not start with `SSH-` are discarded. Each line is capped at
/// `MAX_BANNER_LINE` bytes and at most `MAX_BANNER_LINES` lines are read.
///
/// # Errors
/// Propagates read and parse failures, and reports a protocol error when no
/// version line shows up within the line budget.
fn read_peer_version(&mut self) -> Result<Vec<u8>> {
    let mut line = Vec::new();
    let mut lines_left = MAX_BANNER_LINES;
    while lines_left > 0 {
        lines_left -= 1;
        line.clear();
        read_line(&mut self.stream, &mut line, MAX_BANNER_LINE)?;
        if !line.starts_with(b"SSH-") {
            continue;
        }
        let parsed = VersionExchange::parse_remote(&line)?;
        return Ok(parsed.into_bytes());
    }
    Err(Error::Protocol("peer banner too long"))
}
/// Returns the next payload that is not consumed by the transport layer.
///
/// Payloads parked in `self.deferred` are handed out first (oldest first)
/// whenever no key exchange is in progress. Otherwise packets are read from
/// the wire: message type 1 becomes a protocol error, types 2/3/4 are
/// skipped, key-exchange messages are processed internally (running a full
/// re-key to completion), and everything else is returned. Non-kex payloads
/// that arrive while a key exchange is in progress are deferred instead of
/// returned.
pub(crate) fn read_one_packet(&mut self) -> Result<Vec<u8>> {
loop {
// Drain anything that was parked during an earlier key exchange.
if !self.runner.is_kexing() && !self.deferred.is_empty() {
return Ok(self.deferred.remove(0));
}
// Locally initiated re-key when the policy says it is due.
if !self.runner.is_kexing()
&& self
.rekey_policy
.should_rekey(&self.codec, self.last_kex, Instant::now())
{
self.initiate_rekey()?;
}
let payload = self.read_one_raw_packet()?;
match payload.first().copied() {
Some(1) => return Err(Error::Protocol("peer sent SSH_MSG_DISCONNECT")),
Some(2) | Some(3) | Some(4) => continue,
Some(b) if is_kex_msg(b) => {
// Message 20 is SSH_MSG_KEXINIT in RFC 4253: if it arrives while we
// are idle, start our side of the exchange before handing the
// peer's packet to the runner.
if b == 20 && !self.runner.is_kexing() {
self.initiate_rekey()?;
}
self.dispatch_kex_packet(&payload)?;
if !self.runner.is_completed() {
self.drive_kex_to_completion()?;
}
// Restart the re-key clock from the exchange that just finished.
self.last_kex = Instant::now();
continue;
}
_ => {
// Application traffic during a key exchange is held back until the
// exchange is over.
if self.runner.is_kexing() {
self.deferred.push(payload);
continue;
}
return Ok(payload);
}
}
}
}
/// Restarts the key-exchange runner with a fresh default advertisement and
/// sends whatever packets the restart produces.
fn initiate_rekey(&mut self) -> Result<()> {
    let fresh_advert = build_default_kexinit(&mut self.rng);
    let step = self.runner.restart(&mut self.rng, fresh_advert)?;
    step.outbound
        .into_iter()
        .try_for_each(|pkt| self.write_payload(&pkt))
}
/// Decodes and returns the next packet payload, reading from the stream as
/// needed.
///
/// Bytes accumulate in `self.inbox` until the codec can decode a complete
/// packet; the consumed prefix is then removed from the inbox.
///
/// # Errors
/// Fails on decode or I/O errors, when the stream reports end-of-file, or
/// when the inbox grows beyond `MAX_INBOX_BYTES` without yielding a packet.
fn read_one_raw_packet(&mut self) -> Result<Vec<u8>> {
    // Scratch space for socket reads; only the filled prefix is ever used.
    let mut chunk = [0u8; 16 * 1024];
    loop {
        match self.codec.decode(&self.inbox)? {
            Some((payload, consumed)) => {
                self.inbox.drain(..consumed);
                return Ok(payload);
            }
            None => {}
        }
        let got = self.stream.read(&mut chunk)?;
        if got == 0 {
            return Err(Error::Protocol("connection closed"));
        }
        self.inbox.extend_from_slice(&chunk[..got]);
        if self.inbox.len() > MAX_INBOX_BYTES {
            return Err(Error::Protocol("inbound buffer too large"));
        }
    }
}
pub(crate) fn write_payload(&mut self, payload: &[u8]) -> Result<()> {
let frame = self.codec.encode(payload, &mut self.rng)?;
self.stream.write_all(&frame)?;
Ok(())
}
}
/// `Read`/`Write` adapter over a single channel of a mutably borrowed
/// [`Client`].
///
/// Reading pumps packets from the connection until channel data is
/// available; writing sends channel data, pumping packets whenever nothing
/// can be sent yet.
pub struct ClientChannelStream<'a> {
/// Connection that carries the channel; exclusively borrowed for `'a`.
client: &'a mut Client,
/// Channel id used for every send and for filtering inbound events.
channel: u32,
/// Channel data received but not yet handed out by `read`.
read_buf: Vec<u8>,
/// Extended data received so far; drained by `take_stderr`.
stderr_buf: Vec<u8>,
/// Set once the peer has sent EOF or CLOSE for this channel.
remote_eof: bool,
/// Set once our own CLOSE has been sent, so it is not sent twice.
local_close_sent: bool,
}
impl ClientChannelStream<'_> {
pub fn take_stderr(&mut self) -> Vec<u8> {
core::mem::take(&mut self.stderr_buf)
}
fn pump_one(&mut self) -> std::io::Result<()> {
let payload = self.client.read_one_packet().map_err(io_err)?;
let ev = self.client.conn.on_packet(&payload).map_err(io_err)?;
match ev {
ChannelEvent::Data { channel, data } if channel == self.channel => {
let n = data.len() as u32;
self.read_buf.extend_from_slice(&data);
if let Some(adj) = self
.client
.conn
.replenish_window(self.channel, n)
.map_err(io_err)?
{
self.client.write_payload(&adj).map_err(io_err)?;
}
}
ChannelEvent::ExtendedData {
channel,
code: _,
data,
} if channel == self.channel => {
let n = data.len() as u32;
self.stderr_buf.extend_from_slice(&data);
if let Some(adj) = self
.client
.conn
.replenish_window(self.channel, n)
.map_err(io_err)?
{
self.client.write_payload(&adj).map_err(io_err)?;
}
}
ChannelEvent::Eof { channel } if channel == self.channel => {
self.remote_eof = true;
}
ChannelEvent::Close { channel } if channel == self.channel => {
self.remote_eof = true;
if !self.local_close_sent {
let p = self.client.conn.send_close(self.channel).map_err(io_err)?;
self.client.write_payload(&p).map_err(io_err)?;
self.local_close_sent = true;
}
}
_ => {}
}
Ok(())
}
}
impl Read for ClientChannelStream<'_> {
    /// Copies buffered channel data into `buf`, pumping packets until some
    /// data is available or the remote side has finished.
    ///
    /// Returns `Ok(0)` for an empty `buf`, and also once the remote side is
    /// finished and no buffered data remains.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        loop {
            if !self.read_buf.is_empty() {
                break;
            }
            if self.remote_eof {
                return Ok(0);
            }
            self.pump_one()?;
        }
        let count = buf.len().min(self.read_buf.len());
        let (dst, _) = buf.split_at_mut(count);
        dst.copy_from_slice(&self.read_buf[..count]);
        self.read_buf.drain(..count);
        Ok(count)
    }
}
impl Write for ClientChannelStream<'_> {
    /// Sends as much of `buf` as the channel accepts in one data packet.
    ///
    /// While nothing can be sent, inbound packets are pumped and the send is
    /// retried. If the remote side has finished in that situation, the write
    /// fails with `BrokenPipe`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        loop {
            let (packet, accepted) = self
                .client
                .conn
                .send_data(self.channel, buf)
                .map_err(io_err)?;
            if accepted == 0 {
                if self.remote_eof {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "channel closed by peer mid-write",
                    ));
                }
                // Nothing was accepted: process inbound traffic and retry.
                self.pump_one()?;
                continue;
            }
            self.client.write_payload(&packet).map_err(io_err)?;
            return Ok(accepted);
        }
    }

    /// Nothing is buffered on the write side, so this always succeeds.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
impl Drop for ClientChannelStream<'_> {
    /// Best-effort channel teardown.
    ///
    /// Sends EOF and CLOSE if our CLOSE has not gone out yet (errors are
    /// ignored), then pumps up to `MAX_DRAIN` packets waiting for the remote
    /// side to finish, stopping early on the first error.
    fn drop(&mut self) {
        const MAX_DRAIN: usize = 128;

        if !self.local_close_sent {
            let id = self.channel;
            if let Ok(eof) = self.client.conn.send_eof(id) {
                let _ = self.client.write_payload(&eof);
            }
            if let Ok(close) = self.client.conn.send_close(id) {
                let _ = self.client.write_payload(&close);
            }
            self.local_close_sent = true;
        }

        let mut budget = MAX_DRAIN;
        while budget > 0 && !self.remote_eof {
            budget -= 1;
            if self.pump_one().is_err() {
                break;
            }
        }
    }
}
pub(crate) fn io_err(e: Error) -> std::io::Error {
match e {
Error::Io(io) => io,
other => std::io::Error::other(format!("{:?}", other)),
}
}
fn build_scp_to_cmd(remote_dest: &str, opts: &crate::scp::ScpSendOptions) -> Result<String> {
let quoted = single_quote_for_remote(remote_dest)?;
let mut s = String::from("scp -t");
if opts.recursive {
s.push_str(" -r");
}
if opts.preserve_times {
s.push_str(" -p");
}
s.push_str(" -- ");
s.push_str("ed);
Ok(s)
}
fn build_scp_from_cmd(remote_source: &str, opts: &crate::scp::ScpRecvOptions) -> Result<String> {
let quoted = single_quote_for_remote(remote_source)?;
let mut s = String::from("scp -f");
if opts.recursive {
s.push_str(" -r");
}
if opts.preserve_times {
s.push_str(" -p");
}
s.push_str(" -- ");
s.push_str("ed);
Ok(s)
}
fn single_quote_for_remote(p: &str) -> Result<String> {
if p.contains('\'') {
return Err(Error::Protocol("scp: remote path contains single quote"));
}
if p.contains('\n') {
return Err(Error::Protocol("scp: remote path contains newline"));
}
if p.contains('\0') {
return Err(Error::Protocol("scp: remote path contains NUL"));
}
if p.starts_with('-') {
return Err(Error::Protocol("scp: remote path starts with '-'"));
}
let mut q = String::with_capacity(p.len() + 2);
q.push('\'');
q.push_str(p);
q.push('\'');
Ok(q)
}
fn scp_proto(e: crate::scp::ScpError, _stage: &'static str) -> Error {
match e {
crate::scp::ScpError::Io(io) => Error::Io(io),
crate::scp::ScpError::Remote(_) => Error::Protocol("scp: remote fatal frame"),
crate::scp::ScpError::Warning(_) => Error::Protocol("scp: remote warning frame"),
crate::scp::ScpError::BadHeader(_) => Error::Protocol("scp: malformed header"),
crate::scp::ScpError::BadName(_) => Error::Protocol("scp: invalid name"),
crate::scp::ScpError::PathEscape => Error::Protocol("scp: path escapes base"),
crate::scp::ScpError::Unexpected(_) => Error::Protocol("scp: unexpected protocol state"),
}
}
/// Builds a KEXINIT advertisement from the crate's default algorithm lists.
///
/// The same cipher, MAC and compression lists are offered in both
/// directions, no languages are offered, and the 16-byte cookie is filled
/// from `rng`.
fn build_default_kexinit<R>(rng: &mut R) -> KexInit
where
    R: RngCore,
{
    let mut cookie = [0u8; 16];
    rng.fill_bytes(&mut cookie);

    let offered = KexAlgorithms {
        kex: defaults::KEX,
        server_host_key: defaults::HOST_KEY,
        ciphers_c2s: defaults::CIPHERS,
        ciphers_s2c: defaults::CIPHERS,
        macs_c2s: defaults::MACS,
        macs_s2c: defaults::MACS,
        comp_c2s: defaults::COMP,
        comp_s2c: defaults::COMP,
        lang_c2s: &[],
        lang_s2c: &[],
    };
    KexInit::from_algorithms(&offered, cookie)
}
fn build_verifier(
reply_payload: &[u8],
policy: &HostKeyPolicy,
runner: &KexRunner,
target_host: &str,
target_port: u16,
) -> Result<Box<dyn HostKeyVerify>> {
if reply_payload.len() < 5 {
return Err(Error::Format("kex-ecdh-reply too short"));
}
let k_s_len = u32::from_be_bytes([
reply_payload[1],
reply_payload[2],
reply_payload[3],
reply_payload[4],
]) as usize;
if reply_payload.len() < 5 + k_s_len {
return Err(Error::Format("kex-ecdh-reply truncated"));
}
let k_s = &reply_payload[5..5 + k_s_len];
let neg = runner
.negotiated()
.ok_or(Error::Protocol("kex: no negotiated algorithms"))?;
match policy {
HostKeyPolicy::AcceptAny => {}
HostKeyPolicy::AcceptFingerprint(fp) => {
let digest = Sha256::digest(k_s);
if digest.as_ref() != fp {
return Err(Error::HostKeyRejected);
}
}
HostKeyPolicy::KnownHosts(kh) => {
if target_host.is_empty() || target_port == 0 {
} else {
let mut store = kh.store.lock().map_err(|_| Error::HostKeyRejected)?;
let lookup = store.lookup(target_host, target_port, &neg.host_key, k_s);
match lookup {
LookupResult::Match => {}
LookupResult::Mismatch { .. } => {
return Err(Error::HostKeyRejected);
}
LookupResult::Unknown => {
let accept = match &kh.on_unknown {
TofuAction::Reject => false,
TofuAction::Accept => true,
TofuAction::Prompt(cb) => {
drop(store);
let ok = cb(target_host, target_port, &neg.host_key, k_s);
store = kh.store.lock().map_err(|_| Error::HostKeyRejected)?;
ok
}
};
if !accept {
return Err(Error::HostKeyRejected);
}
store.add(target_host, target_port, &neg.host_key, k_s, kh.hash_new);
if let Some(path) = &kh.save_path {
store.save(path).map_err(Error::from)?;
}
}
}
}
}
}
host_key_verify_by_name(&neg.host_key, k_s)
}
/// Processes every command currently queued on `cmd_rx` without blocking.
///
/// For each direct-tcpip request a channel open is sent to the peer and a
/// pending-open record (stream, ingress sender, egress receiver and reply
/// sender) is stored under the new local channel id. The pass ends as soon
/// as the queue is empty or its sender side has gone away.
///
/// # Errors
/// Propagates failures from opening the channel or writing the open packet.
fn serve_drain_commands(
    client: &mut Client,
    cmd_rx: &Receiver<ServeCommand>,
    pending_opens: &mut BTreeMap<u32, PendingOutboundOpen>,
) -> Result<()> {
    // `try_recv` fails with Empty or Disconnected; either one ends the pass.
    while let Ok(cmd) = cmd_rx.try_recv() {
        match cmd {
            ServeCommand::OpenDirectTcpip {
                dest_host,
                dest_port,
                orig_host,
                orig_port,
                reply,
            } => {
                let request = ChannelOpen::DirectTcpip {
                    dest_host,
                    dest_port: dest_port as u32,
                    orig_host,
                    orig_port: orig_port as u32,
                };
                let (local_id, open_payload) = client.conn.open(request)?;
                client.write_payload(&open_payload)?;

                let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
                let (egress_tx, egress_rx) =
                    mpsc::sync_channel::<ChannelEgress>(SERVE_EGRESS_BACKLOG);
                let pending = PendingOutboundOpen {
                    stream: Some(ChannelStream::new(ingress_rx, egress_tx)),
                    ingress_tx,
                    egress_rx: Some(egress_rx),
                    reply,
                };
                pending_opens.insert(local_id, pending);
            }
        }
    }
    Ok(())
}
/// Moves application output from every live runtime onto the wire.
///
/// For each channel whose CLOSE has not been sent: backlog left over from an
/// earlier pass is retried first; then queued egress items are taken
/// (without blocking) until the queue is empty, an EOF/Close marker is seen,
/// or a send leaves bytes in the backlog. Once no backlog remains, a pending
/// EOF and a pending CLOSE are sent, with EOF always preceding CLOSE. A
/// disconnected egress queue is treated as a close request.
fn serve_drain_runtimes(
    client: &mut Client,
    runtimes: &mut BTreeMap<u32, ServeRuntime>,
) -> Result<()> {
    for (&ch, rt) in runtimes.iter_mut() {
        if rt.close_sent {
            continue;
        }

        // Retry bytes that could not be sent on an earlier pass.
        if !rt.pending_data.is_empty() {
            let backlog = core::mem::take(&mut rt.pending_data);
            emit_serve_data(client, ch, &backlog, rt)?;
            if !rt.pending_data.is_empty() {
                continue;
            }
        }

        // Pull fresh output until something makes us stop.
        while rt.pending_data.is_empty() {
            match rt.egress_rx.try_recv() {
                Ok(ChannelEgress::Data(bytes)) => emit_serve_data(client, ch, &bytes, rt)?,
                Ok(ChannelEgress::Eof) => {
                    rt.pending_eof = true;
                    break;
                }
                Ok(ChannelEgress::Close) | Err(TryRecvError::Disconnected) => {
                    rt.pending_close = true;
                    break;
                }
                Err(TryRecvError::Empty) => break,
            }
        }

        // EOF/CLOSE must not overtake data that is still waiting.
        if !rt.pending_data.is_empty() {
            continue;
        }
        if rt.pending_eof && !rt.eof_sent {
            let eof = client.conn.send_eof(ch)?;
            client.write_payload(&eof)?;
            rt.eof_sent = true;
        }
        if rt.pending_close && !rt.close_sent {
            if !rt.eof_sent {
                let eof = client.conn.send_eof(ch)?;
                client.write_payload(&eof)?;
                rt.eof_sent = true;
            }
            let close = client.conn.send_close(ch)?;
            client.write_payload(&close)?;
            rt.close_sent = true;
        }
    }
    Ok(())
}
/// Sends `bytes` on `channel` in as many data packets as are accepted.
///
/// When a send attempt accepts nothing, the unsent remainder is appended to
/// `rt.pending_data` and the function returns successfully so the caller can
/// retry later.
fn emit_serve_data(
    client: &mut Client,
    channel: u32,
    bytes: &[u8],
    rt: &mut ServeRuntime,
) -> Result<()> {
    let mut rest = bytes;
    while !rest.is_empty() {
        let (packet, accepted) = client.conn.send_data(channel, rest)?;
        if accepted == 0 {
            rt.pending_data.extend_from_slice(rest);
            return Ok(());
        }
        client.write_payload(&packet)?;
        // An over-long count simply ends the loop, as the offset test did.
        rest = rest.get(accepted..).unwrap_or(&[]);
    }
    Ok(())
}
fn serve_dispatch_packet(
client: &mut Client,
handlers: &ClientHandlers,
runtimes: &mut BTreeMap<u32, ServeRuntime>,
pending_opens: &mut BTreeMap<u32, PendingOutboundOpen>,
payload: &[u8],
) -> Result<()> {
let ev = client.conn.on_packet(payload)?;
match ev {
ChannelEvent::OpenConfirmed { channel } => {
if let Some(mut po) = pending_opens.remove(&channel) {
let stream = po
.stream
.take()
.ok_or(Error::Protocol("pending open: stream taken twice"))?;
let egress_rx = po
.egress_rx
.take()
.ok_or(Error::Protocol("pending open: egress taken twice"))?;
if po.reply.send(Ok(stream)).is_err() {
let p = client.conn.send_close(channel)?;
client.write_payload(&p)?;
return Ok(());
}
runtimes.insert(
channel,
ServeRuntime {
ingress_tx: po.ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
}
}
ChannelEvent::OpenFailed { channel, .. } => {
if let Some(po) = pending_opens.remove(&channel) {
let _ = po
.reply
.send(Err(Error::Protocol("direct-tcpip: open failed")));
}
}
ChannelEvent::OpenRequest { channel, kind } => match kind {
ChannelOpen::ForwardedTcpip {
dest_host,
dest_port,
orig_host,
orig_port,
} => {
if let Some(cb) = handlers.on_forwarded_tcpip.clone() {
let p = client.conn.accept_open(channel)?;
client.write_payload(&p)?;
let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
let (egress_tx, egress_rx) =
mpsc::sync_channel::<ChannelEgress>(SERVE_EGRESS_BACKLOG);
let cs = ChannelStream::new(ingress_rx, egress_tx);
let origin = ForwardedTcpipOrigin {
bound_address: dest_host,
bound_port: clamp_u16(dest_port),
orig_address: orig_host,
orig_port: clamp_u16(orig_port),
};
thread::spawn(move || {
cb(origin, cs);
});
runtimes.insert(
channel,
ServeRuntime {
ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
} else {
let p = client.conn.reject_open(
channel,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
"forwarded-tcpip not enabled",
"",
)?;
client.write_payload(&p)?;
}
}
ChannelOpen::AuthAgent => {
if let Some(cb) = handlers.on_auth_agent.clone() {
let p = client.conn.accept_open(channel)?;
client.write_payload(&p)?;
let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
let (egress_tx, egress_rx) =
mpsc::sync_channel::<ChannelEgress>(SERVE_EGRESS_BACKLOG);
let cs = ChannelStream::new(ingress_rx, egress_tx);
thread::spawn(move || {
cb(cs);
});
runtimes.insert(
channel,
ServeRuntime {
ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
} else {
let p = client.conn.reject_open(
channel,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
"auth-agent not enabled",
"",
)?;
client.write_payload(&p)?;
}
}
ChannelOpen::X11 {
orig_host: _,
orig_port: _,
} => {
if let Some(cb) = handlers.on_x11.clone() {
let p = client.conn.accept_open(channel)?;
client.write_payload(&p)?;
let (ingress_tx, ingress_rx) = mpsc::channel::<Option<Vec<u8>>>();
let (egress_tx, egress_rx) =
mpsc::sync_channel::<ChannelEgress>(SERVE_EGRESS_BACKLOG);
let cs = ChannelStream::new(ingress_rx, egress_tx);
thread::spawn(move || {
cb(cs);
});
runtimes.insert(
channel,
ServeRuntime {
ingress_tx,
egress_rx,
pending_data: Vec::new(),
pending_eof: false,
pending_close: false,
eof_sent: false,
close_sent: false,
},
);
} else {
let p = client.conn.reject_open(
channel,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
"x11 not enabled",
"",
)?;
client.write_payload(&p)?;
}
}
_ => {
let p = client.conn.reject_open(
channel,
SSH_OPEN_ADMINISTRATIVELY_PROHIBITED,
"channel type not supported",
"",
)?;
client.write_payload(&p)?;
}
},
ChannelEvent::Data { channel, data } => {
if let Some(rt) = runtimes.get_mut(&channel) {
let _ = rt.ingress_tx.send(Some(data.clone()));
}
if let Some(adj) = client.conn.replenish_window(channel, data.len() as u32)? {
client.write_payload(&adj)?;
}
}
ChannelEvent::ExtendedData { channel, data, .. } => {
if let Some(adj) = client.conn.replenish_window(channel, data.len() as u32)? {
client.write_payload(&adj)?;
}
}
ChannelEvent::Eof { channel } => {
if let Some(rt) = runtimes.get_mut(&channel) {
let _ = rt.ingress_tx.send(None);
}
}
ChannelEvent::Close { channel } => {
if let Some(ch) = client.conn.channel(channel) {
if !ch.local_closed {
let p = client.conn.send_close(channel)?;
client.write_payload(&p)?;
}
}
runtimes.remove(&channel);
}
_ => {}
}
Ok(())
}
/// Narrows `v` to `u16`, saturating at `u16::MAX` for values that do not fit.
fn clamp_u16(v: u32) -> u16 {
    u16::try_from(v).unwrap_or(u16::MAX)
}
/// Reads bytes from `stream` into `buf` up to and including the next `\n`.
///
/// Bytes are read one at a time, so nothing beyond the newline is consumed
/// from the stream. `buf` is appended to, not cleared.
///
/// # Errors
/// Fails when the stream ends before a newline is seen, when `buf` reaches
/// `max_len` bytes without ending in a newline, or on an I/O error.
fn read_line<S: Read>(stream: &mut S, buf: &mut Vec<u8>, max_len: usize) -> Result<()> {
    let mut cell = [0u8; 1];
    loop {
        if stream.read(&mut cell)? == 0 {
            return Err(Error::Protocol("connection closed before newline"));
        }
        let [b] = cell;
        buf.push(b);
        match b {
            b'\n' => return Ok(()),
            _ if buf.len() >= max_len => {
                return Err(Error::Protocol("banner line too long"));
            }
            _ => {}
        }
    }
}
// Unit tests for the line reader and config defaults, plus loopback
// handshake tests that run against a minimal in-process server.
#[cfg(test)]
mod tests {
use super::*;
use crate::hostkey::Ed25519HostKey;
use crate::transport::version::LOCAL_VERSION;
use std::io::{Cursor, Read, Write};
use std::net::TcpListener;
use std::thread;
// A line with no newline must be rejected once it reaches the length cap.
#[test]
fn read_line_caps_length() {
let mut buf = Vec::new();
let mut src = Cursor::new(vec![b'A'; 4096]);
let err = read_line(&mut src, &mut buf, 1024);
assert!(matches!(err, Err(Error::Protocol(_))));
}
// The terminating "\r\n" is kept in the returned buffer.
#[test]
fn read_line_returns_at_newline() {
let mut buf = Vec::new();
let mut src = Cursor::new(b"hello\r\n".to_vec());
read_line(&mut src, &mut buf, 1024).unwrap();
assert_eq!(buf, b"hello\r\n");
}
// The default config accepts any host key and sets no timeout.
#[test]
fn config_default_is_accept_any() {
let cfg = Config::default();
assert!(matches!(cfg.host_key_policy, HostKeyPolicy::AcceptAny));
assert!(cfg.timeout.is_none());
}
// Compile-time check that `ExecOutput` can be built from these fields.
#[test]
fn exec_output_constructible() {
let _ = ExecOutput {
stdout: Vec::new(),
stderr: Vec::new(),
exit_status: Some(0),
exit_signal: None,
};
}
// Spawns a one-shot server thread: accepts a single connection, exchanges
// version banners, runs the server side of the key exchange with an Ed25519
// host key derived from `host_key_seed`, and returns the session id.
fn run_server(
listener: TcpListener,
host_key_seed: [u8; 32],
) -> thread::JoinHandle<std::result::Result<Vec<u8>, String>> {
thread::spawn(move || -> std::result::Result<Vec<u8>, String> {
let (mut s, _) = listener.accept().map_err(|e| e.to_string())?;
let server_hk = Ed25519HostKey::from_seed(host_key_seed);
s.write_all(&VersionExchange::outgoing_bytes())
.map_err(|e| e.to_string())?;
let mut line = Vec::new();
// Client banner: a single line that must start with "SSH-".
let v_c: Vec<u8> = {
read_line(&mut s, &mut line, 1024).map_err(|e| format!("{e:?}"))?;
if !line.starts_with(b"SSH-") {
return Err("client did not send SSH banner".into());
}
let parsed = VersionExchange::parse_remote(&line).map_err(|e| format!("{e:?}"))?;
parsed.into_bytes()
};
let v_s = LOCAL_VERSION.as_bytes().to_vec();
let mut codec = PacketCodec::new();
let advert = build_default_kexinit(&mut OsRng);
let mut runner = KexRunner::new(Role::Server, advert);
let mut inbox: Vec<u8> = Vec::new();
let mut rng = OsRng;
let initial = runner.start(&mut rng).map_err(|e| format!("{e:?}"))?;
for p in initial.outbound {
let frame = codec.encode(&p, &mut rng).map_err(|e| format!("{e:?}"))?;
s.write_all(&frame).map_err(|e| e.to_string())?;
}
// Feed packets to the runner until it reports completion, bounded by
// the same step limit the client uses.
let mut steps = 0;
loop {
steps += 1;
if steps > MAX_KEX_STEPS {
return Err("server kex did not converge".into());
}
let payload = read_one_packet_local(&mut s, &mut codec, &mut inbox)
.map_err(|e| format!("{e:?}"))?;
let adv = runner
.on_packet(
&mut rng,
&mut codec,
&payload,
Some(&server_hk),
None,
&v_c,
&v_s,
)
.map_err(|e| format!("{e:?}"))?;
for p in adv.outbound {
let frame = codec.encode(&p, &mut rng).map_err(|e| format!("{e:?}"))?;
s.write_all(&frame).map_err(|e| e.to_string())?;
}
if adv.completed {
break;
}
}
let sid = runner.session_id().unwrap().to_vec();
Ok(sid)
})
}
// Test-side packet reader: accumulates socket bytes in `inbox` until the
// codec can decode one full packet.
fn read_one_packet_local(
s: &mut TcpStream,
codec: &mut PacketCodec,
inbox: &mut Vec<u8>,
) -> Result<Vec<u8>> {
loop {
if let Some((payload, consumed)) = codec.decode(inbox)? {
inbox.drain(..consumed);
return Ok(payload);
}
let mut tmp = [0u8; 4096];
let n = s.read(&mut tmp)?;
if n == 0 {
return Err(Error::Protocol("connection closed"));
}
inbox.extend_from_slice(&tmp[..n]);
if inbox.len() > MAX_INBOX_BYTES {
return Err(Error::Protocol("inbound buffer too large"));
}
}
}
// Client and server must derive the same non-empty session id.
#[test]
fn handshake_over_real_loopback_socket() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let mut seed = [0u8; 32];
OsRng.fill_bytes(&mut seed);
let server = run_server(listener, seed);
let client = Client::connect(addr, Config::default()).expect("client connect");
let server_sid = server.join().unwrap().expect("server handshake");
assert_eq!(client.session_id, server_sid);
assert!(!client.session_id.is_empty());
}
// Pinning a fingerprint that cannot match must abort the connection with
// `HostKeyRejected`.
#[test]
fn fingerprint_mismatch_rejected() {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let mut seed = [0u8; 32];
OsRng.fill_bytes(&mut seed);
let server = run_server(listener, seed);
let cfg = Config {
host_key_policy: HostKeyPolicy::AcceptFingerprint([0xffu8; 32]),
timeout: None,
};
let err = Client::connect(addr, cfg).err().expect("must fail");
assert!(matches!(err, Error::HostKeyRejected));
// The server thread's result is irrelevant here; just wait for it to end.
let _ = server.join();
}
}