#[allow(unused_imports)]
use {
crate::error::{Error, Result, TrapBug},
log::{debug, error, info, log, trace, warn},
};
use core::{hash::Hash, mem::discriminant, task::Waker};
use crate::*;
use channel::{ChanData, ChanNum};
use channel::{CliSessionExit, CliSessionOpener};
use encrypt::KeyState;
use event::{CliEvent, CliEventId, Event, ServEvent, ServEventId};
use traffic::{TrafIn, TrafOut};
use conn::{CliServ, Conn, DispatchEvent, Dispatched};
/// `Runner` specialised to the server role (`server::Server`).
pub(crate) type ServRunner<'a> = Runner<'a, server::Server>;
/// `Runner` specialised to the client role (`client::Client`).
pub(crate) type CliRunner<'a> = Runner<'a, client::Client>;
/// State for driving a single connection, generic over the client or server role `CS`.
///
/// The input and output traffic buffers are borrowed from the caller for `'a`.
pub struct Runner<'a, CS: conn::CliServ> {
// Connection state for the `CS` role.
conn: Conn<CS>,
// Incoming traffic, built over the `inbuf` slice passed to `new()`.
traf_in: TrafIn<'a>,
// Outgoing traffic, built over the `outbuf` slice passed to `new()`.
traf_out: TrafOut<'a>,
// Key state; `new()` initialises it with `KeyState::new_cleartext()`.
keys: KeyState,
// Taken and woken by `wake()` when output is pending.
output_waker: Option<Waker>,
// Taken and woken by `wake()` when input is ready.
input_waker: Option<Waker>,
// Set to `true` by `close_input()`; nothing in the visible code resets it.
closed_input: bool,
// Event most recently dispatched by `progress()`; the `resume_*` methods take it.
resume_event: DispatchEvent,
// Follow-up event stored by `resume_clipubkey()` / `resume_agentsign()` and
// dispatched by the next call to `progress()`.
extra_resume_event: DispatchEvent,
}
impl<CS: CliServ> core::fmt::Debug for Runner<'_, CS> {
    /// Prints the key state and both wakers; the remaining fields are elided
    /// and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("Runner");
        builder.field("keys", &self.keys);
        builder.field("output_waker", &self.output_waker);
        builder.field("input_waker", &self.input_waker);
        builder.finish_non_exhaustive()
    }
}
impl<'a> Runner<'a, client::Client> {
pub fn new_client(
inbuf: &'a mut [u8],
outbuf: &'a mut [u8],
) -> Runner<'a, client::Client> {
Self::new(inbuf, outbuf)
}
/// Forwards a break request of `length` for channel `chan` to the channel
/// layer, using a sender over the outgoing traffic.
pub fn term_break(&mut self, chan: &ChanHandle, length: u32) -> Result<()> {
    let mut sender = self.traf_out.sender(&mut self.keys);
    let chan_num = chan.0;
    self.conn.channels.term_break(chan_num, length, &mut sender)
}
/// Decodes the session-exit information from the current input payload.
///
/// Fails through `trap()` when no payload is currently held.
pub(crate) fn fetch_cli_session_exit(&mut self) -> Result<CliSessionExit<'_>> {
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    self.conn.fetch_cli_session_exit(packet_bytes)
}
/// Decodes the banner from the current input payload.
///
/// Fails through `trap()` when no payload is currently held.
pub(crate) fn fetch_cli_banner(&mut self) -> Result<event::Banner<'_>> {
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    self.conn.fetch_cli_banner(packet_bytes)
}
/// Builds a `CliSessionOpener` pairing the channel `ch` with a sender over
/// the outgoing traffic.
///
/// Returns the channel lookup error if `ch` cannot be fetched.
pub(crate) fn cli_session_opener(
    &mut self,
    ch: ChanNum,
) -> Result<CliSessionOpener<'_, 'a>> {
    let channel = self.conn.channels.get(ch)?;
    let sender = self.traf_out.sender(&mut self.keys);
    Ok(CliSessionOpener { ch: channel, s: sender })
}
/// Answers a pending client `Username` event with `username`.
///
/// The input payload is marked done only after the auth layer accepted the
/// response.
pub(crate) fn resume_cliusername(&mut self, username: &str) -> Result<()> {
    let expected = DispatchEvent::CliEvent(CliEventId::Username);
    self.resume(&expected);
    let mut sender = self.traf_out.sender(&mut self.keys);
    let (auth, _) = self.conn.mut_cliauth()?;
    auth.resume_username(&mut sender, username)?;
    self.traf_in.done_payload();
    Ok(())
}
/// Answers a pending client `Password` event with an optional password.
///
/// Errors from `mut_cliauth()` or `resume_password()` are propagated.
pub(crate) fn resume_clipassword(
&mut self,
password: Option<&str>,
) -> Result<()> {
self.resume(&DispatchEvent::CliEvent(CliEventId::Password));
// NOTE(review): the payload is marked done *before* the fallible steps below,
// whereas `resume_cliusername()` does it after them — confirm this ordering
// difference is intentional.
self.traf_in.done_payload();
let mut s = self.traf_out.sender(&mut self.keys);
let (cliauth, ctx) = self.conn.mut_cliauth()?;
cliauth.resume_password(&mut s, password, ctx)?;
// NOTE(review): this only fires in debug builds, and only when
// `resume_password()` returned Ok for a `None` password — presumably that
// call is expected to fail in that case; confirm.
debug_assert!(password.is_some(), "no password");
Ok(())
}
/// Answers a pending client `Pubkey` event with an optional signing key.
///
/// Whatever event the auth layer returns is stored as the extra resume
/// event; the input payload is marked done only when that event is "none".
pub(crate) fn resume_clipubkey(&mut self, key: Option<SignKey>) -> Result<()> {
    let expected = DispatchEvent::CliEvent(CliEventId::Pubkey);
    self.resume(&expected);
    let mut sender = self.traf_out.sender(&mut self.keys);
    let (auth, ctx) = self.conn.mut_cliauth()?;
    let follow_up = auth.resume_pubkey(&mut sender, key, ctx)?;
    self.extra_resume_event = follow_up;
    let payload_finished = self.extra_resume_event.is_none();
    if payload_finished {
        self.traf_in.done_payload();
    }
    Ok(())
}
/// Returns the key associated with the pending `AgentSign` event.
///
/// In debug builds, asserts that an `AgentSign` event is the one pending.
pub(crate) fn fetch_agentsign_key(&self) -> Result<&SignKey> {
    let expected = DispatchEvent::CliEvent(CliEventId::AgentSign);
    self.check_resume(&expected);
    self.conn.cliauth()?.fetch_agentsign_key()
}
/// Returns the message to be signed for the pending `AgentSign` event.
///
/// In debug builds, asserts that an `AgentSign` event is the one pending.
pub(crate) fn fetch_agentsign_msg(&self) -> Result<AuthSigMsg<'_>> {
    let expected = DispatchEvent::CliEvent(CliEventId::AgentSign);
    self.check_resume(&expected);
    self.conn.fetch_agentsign_msg()
}
/// Answers a pending client `AgentSign` event with an optional signature.
///
/// Whatever event the auth layer returns is stored as the extra resume
/// event; the input payload is marked done only when that event is "none".
pub(crate) fn resume_agentsign(&mut self, sig: Option<&OwnedSig>) -> Result<()> {
    let expected = DispatchEvent::CliEvent(CliEventId::AgentSign);
    self.resume(&expected);
    let (auth, ctx) = self.conn.mut_cliauth()?;
    let mut sender = self.traf_out.sender(&mut self.keys);
    let follow_up = auth.resume_agentsign(sig, ctx, &mut sender)?;
    self.extra_resume_event = follow_up;
    let payload_finished = self.extra_resume_event.is_none();
    if payload_finished {
        self.traf_in.done_payload();
    }
    Ok(())
}
/// Answers a pending client `Hostkey` event with the accept/reject decision.
///
/// The current payload is handed to the connection together with the
/// decision, then marked done on success.
pub(crate) fn resume_checkhostkey(&mut self, accept: bool) -> Result<()> {
    let expected = DispatchEvent::CliEvent(CliEventId::Hostkey);
    self.resume(&expected);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    let mut sender = self.traf_out.sender(&mut self.keys);
    self.conn.resume_checkhostkey(packet_bytes, &mut sender, accept)?;
    self.traf_in.done_payload();
    Ok(())
}
/// Decodes the public key for the pending `Hostkey` event from the current
/// payload.
///
/// In debug builds, asserts that a `Hostkey` event is the one pending.
pub(crate) fn fetch_checkhostkey(&self) -> Result<PubKey<'_>> {
    let expected = DispatchEvent::CliEvent(CliEventId::Hostkey);
    self.check_resume(&expected);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    self.conn.fetch_checkhostkey(packet_bytes)
}
}
impl<'a> Runner<'a, server::Server> {
pub fn new_server(
inbuf: &'a mut [u8],
outbuf: &'a mut [u8],
) -> Runner<'a, server::Server> {
Self::new(inbuf, outbuf)
}
/// Answers a pending server `Hostkeys` event with the supplied host keys.
///
/// The current payload is handed to the connection together with `keys`,
/// then marked done on success.
pub(crate) fn resume_servhostkeys(&mut self, keys: &[&SignKey]) -> Result<()> {
    let expected = DispatchEvent::ServEvent(ServEventId::Hostkeys);
    self.resume(&expected);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    let mut sender = self.traf_out.sender(&mut self.keys);
    self.conn.resume_servhostkeys(packet_bytes, &mut sender, keys)?;
    self.traf_in.done_payload();
    Ok(())
}
/// Returns the username recorded in the server auth state.
///
/// Fails through `trap()` when no username has been stored.
pub(crate) fn fetch_servusername(&self) -> Result<TextString<'_>> {
    let server = self.conn.server()?;
    let name = server.auth.username.as_ref().trap()?;
    Ok(TextString(name.as_slice()))
}
/// Decodes the password for the pending `PasswordAuth` event from the
/// current payload.
///
/// In debug builds, asserts that a `PasswordAuth` event is the one pending.
pub(crate) fn fetch_servpassword(&self) -> Result<TextString<'_>> {
    let expected = DispatchEvent::ServEvent(ServEventId::PasswordAuth);
    self.check_resume(&expected);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    self.conn.fetch_servpassword(packet_bytes)
}
/// Decodes the public key for the pending `PubkeyAuth` event from the
/// current payload.
///
/// The debug-build check compares enum discriminants only, so the
/// `real_sig` value used to build the expected event is not significant.
pub(crate) fn fetch_servpubkey(&self) -> Result<PubKey<'_>> {
    let expected =
        DispatchEvent::ServEvent(ServEventId::PubkeyAuth { real_sig: false });
    self.check_resume(&expected);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    self.conn.fetch_servpubkey(packet_bytes)
}
/// Answers a pending server authentication event with an allow/deny decision.
///
/// Accepts any of the `PasswordAuth`, `PubkeyAuth` or `FirstAuth` events as
/// the pending one (checked in debug builds only).
pub(crate) fn resume_servauth(&mut self, allow: bool) -> Result<()> {
let prev_event = self.resume_event.take();
// The payload is zeroized before anything else, including the debug check below.
// presumably because auth payloads may carry credentials — confirm.
self.traf_in.zeroize_payload();
debug_assert!(
matches!(
prev_event,
DispatchEvent::ServEvent(ServEventId::PasswordAuth)
) || matches!(
prev_event,
DispatchEvent::ServEvent(ServEventId::PubkeyAuth { .. })
) || matches!(
prev_event,
DispatchEvent::ServEvent(ServEventId::FirstAuth)
)
);
let mut s = self.traf_out.sender(&mut self.keys);
self.conn.resume_servauth(allow, &mut s)
}
/// Answers a pending server `PubkeyAuth` event via the connection's
/// `resume_servauth_pkok()`.
///
/// The payload is marked done whether or not the connection call succeeded;
/// its result is returned afterwards.
pub(crate) fn resume_servauth_pkok(&mut self) -> Result<()> {
    let expected =
        DispatchEvent::ServEvent(ServEventId::PubkeyAuth { real_sig: false });
    self.resume(&expected);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    let mut sender = self.traf_out.sender(&mut self.keys);
    let outcome = self.conn.resume_servauth_pkok(packet_bytes, &mut sender);
    self.traf_in.done_payload();
    outcome
}
/// Passes the password/pubkey method flags through to the connection.
pub(crate) fn set_auth_methods(
    &mut self,
    password: bool,
    pubkey: bool,
) -> Result<()> {
    let conn = &mut self.conn;
    conn.set_auth_methods(password, pubkey)
}
/// Returns the `(password, pubkey)` method flags from the server auth state.
pub(crate) fn get_auth_methods(&self) -> Result<(bool, bool)> {
    let server = self.conn.server()?;
    let password = server.auth.method_password;
    let pubkey = server.auth.method_pubkey;
    Ok((password, pubkey))
}
}
impl<'a, CS: CliServ> Runner<'a, CS> {
pub fn new(inbuf: &'a mut [u8], outbuf: &'a mut [u8]) -> Runner<'a, CS> {
Runner {
conn: Conn::new(),
traf_in: TrafIn::new(inbuf),
traf_out: TrafOut::new(outbuf),
keys: KeyState::new_cleartext(),
output_waker: None,
input_waker: None,
closed_input: false,
resume_event: DispatchEvent::None,
extra_resume_event: DispatchEvent::None,
}
}
/// Drives the connection forward and returns the next event for the caller.
///
/// Returns `BadUsage` if the previously returned event required a response
/// that was never given. When input has been closed and no payload is
/// available, a `Defunct` event for the current role is returned.
pub fn progress(&mut self) -> Result<Event<'_, 'a>> {
// Taking the event clears it; a resumed event has already been taken by `resume()`.
let prev = self.resume_event.take();
if prev.needs_resume() {
debug!("No response provided to {:?} event", prev);
return error::BadUsage.fail();
}
// A follow-up event left by a resume_* call is dispatched before anything else.
let ex = self.extra_resume_event.take();
if ex.is_some() {
self.resume_event = ex.clone();
return CS::dispatch_into_event(self, ex);
}
// The previous event did not need a resume, so its payload can be released now.
if prev.is_event() {
self.traf_in.done_payload();
}
let mut disp = Dispatched::default();
// Process an incoming payload if one is available.
if let Some((payload, seq)) = self.traf_in.payload() {
let mut s = self.traf_out.sender(&mut self.keys);
disp = self.conn.handle_payload(payload, seq, &mut s)?;
match disp.event {
DispatchEvent::Data(data_in) => {
// Channel data: record it for reading and wake the reader. The payload
// is not marked done here.
let (num, dt) = self.traf_in.set_read_channel_data(data_in)?;
self.channel_wake_read(num, dt);
disp.event = DispatchEvent::None
}
DispatchEvent::CliEvent(_) | DispatchEvent::ServEvent(_) => {
// Application events keep the payload in place; the fetch_*/resume_*
// methods read it via `traf_in.payload()`.
}
DispatchEvent::None => {
self.traf_in.done_payload()
}
DispatchEvent::KexDone => {
// Wake channel writers, release the payload and clear the event.
self.channel_wake_write();
self.traf_in.done_payload();
disp.event = DispatchEvent::None;
}
// `Progressed` is not an expected result of `handle_payload()`.
DispatchEvent::Progressed => return Err(Error::bug()),
}
} else if self.closed_input {
if CS::is_client() {
return Ok(Event::Cli(CliEvent::Defunct));
} else {
return Ok(Event::Serv(ServEvent::Defunct));
}
}
// With no event pending for the caller, let the connection make its own progress.
if disp.event.is_none() {
let mut s = self.traf_out.sender(&mut self.keys);
disp = self.conn.progress(&mut s)?;
trace!("prog disp {disp:?}");
match disp.event {
DispatchEvent::CliEvent(_)
| DispatchEvent::ServEvent(_)
| DispatchEvent::None
| DispatchEvent::Progressed => (),
// `Data` and `KexDone` are not expected results of `conn.progress()`.
DispatchEvent::Data(_) | DispatchEvent::KexDone => {
return Err(Error::bug())
}
}
}
self.wake();
// Remember what was dispatched so the next call (or a resume_*) can check it.
self.resume_event = disp.event.clone();
CS::dispatch_into_event(self, disp.event)
}
/// Decodes the current input payload into a packet.
///
/// Returns `Ok(None)` when no payload is held.
pub(crate) fn packet(&self) -> Result<Option<packets::Packet<'_>>> {
    match self.traf_in.payload() {
        Some((packet_bytes, _)) => self.conn.packet(packet_bytes).map(Some),
        None => Ok(None),
    }
}
/// Feeds received bytes into the runner, returning what the traffic layer
/// reports as consumed.
///
/// Returns `SessionEOF` once input has been closed, and `Ok(0)` while the
/// runner is not ready for input.
pub fn input(&mut self, buf: &[u8]) -> Result<usize, Error> {
    if self.closed_input {
        return error::SessionEOF.fail();
    }
    if self.is_input_ready() {
        self.traf_in.input(&mut self.keys, &mut self.conn.remote_version, buf)
    } else {
        Ok(0)
    }
}
/// Reports whether `input()` should be called.
///
/// True when the connection's initial output has been sent and the traffic
/// layer is ready, or when input has been closed.
pub fn is_input_ready(&self) -> bool {
    let traffic_ready = self.conn.initial_sent() && self.traf_in.is_input_ready();
    traffic_ready || self.closed_input
}
/// Registers `waker` as the input waker via `set_waker()`.
pub fn set_input_waker(&mut self, waker: &Waker) {
    let slot = &mut self.input_waker;
    set_waker(slot, waker)
}
/// Marks the input side as closed.
///
/// After this, `input()` returns `SessionEOF` and `is_input_ready()` is true.
pub fn close_input(&mut self) {
trace!("close_input");
self.closed_input = true;
}
/// Copies pending output into `buf` and consumes it, returning the number of
/// bytes written.
///
/// At most `buf.len()` bytes are copied. `buf` may be larger than the
/// pending output; the bytes past the returned length are left untouched.
pub fn output(&mut self, buf: &mut [u8]) -> usize {
    let out = self.output_buf();
    let l = out.len().min(buf.len());
    // `copy_from_slice` panics unless both slices have equal length, so the
    // destination must be trimmed to `l` as well.
    buf[..l].copy_from_slice(&out[..l]);
    self.consume_output(l);
    l
}
/// Returns the traffic layer's pending output buffer.
pub fn output_buf(&mut self) -> &[u8] {
    let outgoing = &mut self.traf_out;
    outgoing.output_buf()
}
/// Tells the traffic layer that `l` bytes of output have been consumed.
///
/// Once no output remains pending, channel writers and the runner's own
/// wakers are woken.
pub fn consume_output(&mut self, l: usize) {
    trace!("consume_output {l}");
    self.traf_out.consume_output(l);
    if self.traf_out.is_output_pending() {
        return;
    }
    self.channel_wake_write();
    self.wake();
}
/// Reports whether the traffic layer has output pending.
pub fn is_output_pending(&self) -> bool {
    let outgoing = &self.traf_out;
    outgoing.is_output_pending()
}
/// Registers `waker` as the output waker via `set_waker()`.
pub fn set_output_waker(&mut self, waker: &Waker) {
    trace!("set_output_waker");
    let slot = &mut self.output_waker;
    set_waker(slot, waker);
}
/// Closes the outgoing traffic side and wakes any registered wakers.
pub fn close_output(&mut self) {
    // The message previously read "close_input" (copy-paste), which made the
    // two close paths indistinguishable in trace logs.
    trace!("close_output");
    self.traf_out.close();
    self.wake();
}
/// Opens a new session channel and queues the open packet for sending.
///
/// Returns a handle to the new channel.
pub fn open_client_session(&mut self) -> Result<ChanHandle> {
    trace!("open_client_session");
    let open_type = packets::ChannelOpenType::Session;
    let (chan_num, open_packet) = self.conn.channels.open(open_type)?;
    self.traf_out.send_packet(open_packet, &mut self.keys)?;
    self.wake();
    Ok(ChanHandle(chan_num))
}
/// Queues channel data from `buf` for sending, returning how many bytes were
/// accepted.
///
/// Returns `ChannelEOF` when output is closed or `write_channel_ready()`
/// reports `None`. Returns `Ok(0)` for an empty `buf` or when no space is
/// currently available.
pub fn write_channel(
    &mut self,
    chan: &ChanHandle,
    dt: ChanData,
    buf: &[u8],
) -> Result<usize> {
    if self.traf_out.closed() {
        return error::ChannelEOF.fail();
    }
    if buf.is_empty() {
        return Ok(0);
    }
    let available = match self.write_channel_ready(chan, dt)? {
        None => return Err(Error::ChannelEOF),
        Some(0) => return Ok(0),
        Some(space) => space,
    };
    // Never send more than the caller supplied.
    let count = available.min(buf.len());
    let data_packet = self.conn.channels.send_data(chan.0, dt, &buf[..count])?;
    trace!("send_packet ch {:?} dt {:?} {}", chan.0, dt, count);
    self.traf_out.send_packet(data_packet, &mut self.keys)?;
    self.wake();
    Ok(count)
}
/// Reads channel data of type `dt` into `buf`, returning the number of bytes
/// read.
///
/// Returns `ChannelEOF` when input is closed or the channel is at EOF, and
/// propagates the error from `dt.validate_receive()`.
pub fn read_channel(
    &mut self,
    chan: &ChanHandle,
    dt: ChanData,
    buf: &mut [u8],
) -> Result<usize> {
    if self.closed_input {
        return error::ChannelEOF.fail();
    }
    dt.validate_receive(CS::is_client())?;
    if self.is_channel_eof(chan) {
        return error::ChannelEOF.fail();
    }
    let (count, finished) = self.traf_in.read_channel(chan.0, dt, buf);
    // A completion value from the traffic layer is passed on to the channel layer.
    if let Some(consumed) = finished {
        self.finished_read_channel(chan, consumed)?;
    }
    Ok(count)
}
/// Reads channel data of whichever type is available into `buf`.
///
/// Returns the number of bytes read together with the data type reported by
/// the traffic layer.
pub fn read_channel_either(
    &mut self,
    chan: &ChanHandle,
    buf: &mut [u8],
) -> Result<(usize, ChanData)> {
    let (count, finished, kind) = self.traf_in.read_channel_either(chan.0, buf);
    // A completion value from the traffic layer is passed on to the channel layer.
    if let Some(consumed) = finished {
        self.finished_read_channel(chan, consumed)?;
    }
    Ok((count, kind))
}
/// Discards pending read data for `chan` and reports the discarded amount to
/// the channel layer.
pub fn discard_read_channel(&mut self, chan: &ChanHandle) -> Result<()> {
    let discarded = self.traf_in.discard_read_channel(chan.0);
    self.finished_read_channel(chan, discarded)
}
/// Notifies the channel layer that `len` bytes were read from `chan`, then
/// wakes any registered wakers.
fn finished_read_channel(
    &mut self,
    chan: &ChanHandle,
    len: usize,
) -> Result<()> {
    let mut sender = self.traf_out.sender(&mut self.keys);
    self.conn.channels.finished_read(chan.0, len, &mut sender)?;
    self.wake();
    Ok(())
}
/// Returns the traffic layer's report of which channel, if any, has data
/// ready to read.
pub fn read_channel_ready(&self) -> Option<(ChanNum, ChanData, usize)> {
    let incoming = &self.traf_in;
    incoming.read_channel_ready()
}
/// True when the channel has received EOF or the runner's input is closed.
pub fn is_channel_eof(&self, chan: &ChanHandle) -> bool {
    let received_eof = self.conn.channels.have_recv_eof(chan.0);
    received_eof || self.closed_input
}
/// True when the channel is closed or the runner's input is closed.
pub fn is_channel_closed(&self, chan: &ChanHandle) -> bool {
    let channel_closed = self.conn.channels.is_closed(chan.0);
    channel_closed || self.closed_input
}
/// Returns how many bytes of type `dt` may currently be written to `chan`.
///
/// `Ok(None)` is returned when output is closed, and `Ok(Some(0))` while key
/// exchange is not idle. Otherwise the result is the channel layer's
/// allowance capped by the outgoing buffer space left after the packet
/// header offset.
pub fn write_channel_ready(
    &self,
    chan: &ChanHandle,
    dt: ChanData,
) -> Result<Option<usize>> {
    if self.traf_out.closed() {
        return Ok(None);
    }
    dt.validate_send(CS::is_client())?;
    if !self.conn.kex_is_idle() {
        return Ok(Some(0));
    }
    let header_len = dt.packet_offset();
    let payload_space =
        self.traf_out.send_allowed(&self.keys).saturating_sub(header_len);
    let channel_allowance = self.conn.channels.send_allowed(chan.0);
    let r = Ok(channel_allowance.map(|window| window.min(payload_space)));
    trace!("ready_channel_send {chan:?} -> {r:?}");
    r
}
/// Asks the channel layer whether sending data of type `dt` on `chan` is valid.
pub fn is_write_channel_valid(&self, chan: &ChanHandle, dt: ChanData) -> bool {
    let channels = &self.conn.channels;
    channels.valid_send(chan.0, dt)
}
pub fn channel_done(&mut self, chan: ChanHandle) -> Result<()> {
self.conn.channels.done(chan.0)?;
self.traf_in.discard_read_channel(chan.0);
self.wake();
Ok(())
}
/// Registers `waker` as the read waker for data type `dt` on channel `ch`.
pub fn set_channel_read_waker(
    &mut self,
    ch: &ChanHandle,
    dt: ChanData,
    waker: &Waker,
) {
    let channel = self.conn.channels.by_handle_mut(ch);
    channel.set_read_waker(dt, CS::is_client(), waker)
}
/// Registers `waker` as the write waker for data type `dt` on channel `ch`.
pub fn set_channel_write_waker(
    &mut self,
    ch: &ChanHandle,
    dt: ChanData,
    waker: &Waker,
) {
    let channel = self.conn.channels.by_handle_mut(ch);
    channel.set_write_waker(dt, CS::is_client(), waker)
}
/// Asks the channel layer to wake the reader of `dt` data on channel `num`.
fn channel_wake_read(&mut self, num: ChanNum, dt: ChanData) {
    let channels = &mut self.conn.channels;
    channels.wake_read(num, dt, CS::is_client())
}
/// Asks the channel layer to wake channel writers.
fn channel_wake_write(&mut self) {
    let channels = &mut self.conn.channels;
    channels.wake_write(CS::is_client())
}
pub fn term_window_change(
&mut self,
chan: &ChanHandle,
winch: &packets::WinChange,
) -> Result<()> {
if CS::is_client() {
let mut s = self.traf_out.sender(&mut self.keys);
self.conn.channels.term_window_change(chan.0, winch, &mut s)
} else {
trace!("winch as server");
Err(error::BadUsage.build())
}
}
/// Wakes the registered input and/or output wakers when appropriate.
///
/// The input waker is taken and woken when input is ready; the output waker
/// is taken and woken when output is pending.
fn wake(&mut self) {
    trace!("wake");
    if !self.is_input_ready() {
        trace!("no input ready");
    } else {
        trace!("wake ready_input, waker {:?}", self.input_waker);
        if let Some(reader) = self.input_waker.take() {
            trace!("wake input waker");
            reader.wake()
        }
    }
    if !self.is_output_pending() {
        trace!("no output pending")
    } else {
        match self.output_waker.take() {
            Some(writer) => {
                trace!("wake output waker");
                writer.wake()
            }
            None => {
                trace!("no waker");
            }
        }
    }
}
/// Debug-build check that `compare` is the same kind of event as `expect`.
///
/// Only enum discriminants are compared, so field values inside the event
/// ids are ignored. Mixed or non-event pairs fail the debug assertion.
fn check_resume_inner(&self, expect: &DispatchEvent, compare: &DispatchEvent) {
    match (expect, compare) {
        (DispatchEvent::CliEvent(wanted), DispatchEvent::CliEvent(pending)) => {
            debug_assert_eq!(
                discriminant(pending),
                discriminant(wanted),
                "Expected response to pending {expect:?} event"
            )
        }
        (DispatchEvent::ServEvent(wanted), DispatchEvent::ServEvent(pending)) => {
            debug_assert_eq!(
                discriminant(pending),
                discriminant(wanted),
                "Expected response to pending {expect:?} event"
            )
        }
        _ => debug_assert!(false),
    }
}
/// Takes the pending resume event and checks (debug builds) that it matches
/// `expect`.
fn resume(&mut self, expect: &DispatchEvent) {
    let pending = self.resume_event.take();
    self.check_resume_inner(expect, &pending)
}
/// Checks (debug builds) that the pending resume event matches `expect`,
/// leaving it in place.
fn check_resume(&self, expect: &DispatchEvent) {
    let pending = &self.resume_event;
    self.check_resume_inner(expect, pending)
}
/// Answers a pending server `OpenSession` event for channel `num`.
///
/// `failure` is forwarded unchanged to the channel layer's `resume_open()`.
pub(crate) fn resume_chanopen(
    &mut self,
    num: ChanNum,
    failure: Option<ChanFail>,
) -> Result<()> {
    let expected = DispatchEvent::ServEvent(ServEventId::OpenSession { num });
    self.resume(&expected);
    self.traf_in.done_payload();
    let mut sender = self.traf_out.sender(&mut self.keys);
    self.conn.channels.resume_open(num, failure, &mut sender)
}
/// Debug-build check that `prev_event` is one of the server channel-request
/// events: shell, exec, subsystem, pty or environment.
fn check_chanreq(prev_event: &DispatchEvent) {
debug_assert!(matches!(
prev_event,
DispatchEvent::ServEvent(ServEventId::SessionShell { .. })
| DispatchEvent::ServEvent(ServEventId::SessionExec { .. })
| DispatchEvent::ServEvent(ServEventId::SessionSubsystem { .. })
| DispatchEvent::ServEvent(ServEventId::SessionPty { .. })
| DispatchEvent::ServEvent(ServEventId::Environment { .. })
));
}
/// Answers a pending server channel-request event with success or failure.
///
/// The current payload is decoded into a packet and handed to the channel
/// layer. The payload is marked done whether or not that call succeeded, and
/// its result is returned afterwards.
pub(crate) fn resume_chanreq(&mut self, success: bool) -> Result<()> {
    let prev_event = self.resume_event.take();
    trace!("resume chanreq {prev_event:?} {success}");
    Self::check_chanreq(&prev_event);
    let mut sender = self.traf_out.sender(&mut self.keys);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    let request = self.conn.packet(packet_bytes)?;
    let outcome = self.conn.channels.resume_chanreq(&request, success, &mut sender);
    self.traf_in.done_payload();
    outcome
}
/// Extracts the command from the current channel-request payload.
///
/// In debug builds, asserts that a channel-request event is pending.
pub(crate) fn fetch_servcommand(&self) -> Result<TextString<'_>> {
    Self::check_chanreq(&self.resume_event);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    let request = self.conn.packet(packet_bytes)?;
    self.conn.channels.fetch_servcommand(&request)
}
/// Extracts the environment variable name from the current channel-request
/// payload.
///
/// In debug builds, asserts that a channel-request event is pending.
pub(crate) fn fetch_env_name(&self) -> Result<TextString<'_>> {
    Self::check_chanreq(&self.resume_event);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    let request = self.conn.packet(packet_bytes)?;
    self.conn.channels.fetch_env_name(&request)
}
/// Extracts the environment variable value from the current channel-request
/// payload.
///
/// In debug builds, asserts that a channel-request event is pending.
pub(crate) fn fetch_env_value(&self) -> Result<TextString<'_>> {
    Self::check_chanreq(&self.resume_event);
    let (packet_bytes, _) = self.traf_in.payload().trap()?;
    let request = self.conn.packet(packet_bytes)?;
    self.conn.channels.fetch_env_value(&request)
}
}
/// Stores `new_waker` in `store_waker`.
///
/// If a waker is already stored and it would wake the same task as
/// `new_waker`, it is kept and nothing else happens. Otherwise any stored
/// waker is woken before being replaced by a clone of `new_waker`.
pub(crate) fn set_waker(store_waker: &mut Option<Waker>, new_waker: &Waker) {
    match store_waker.take() {
        Some(current) if current.will_wake(new_waker) => {
            // Equivalent waker already registered: put it back untouched.
            *store_waker = Some(current);
        }
        displaced => {
            // Wake whoever was waiting before so they are not left stranded.
            if let Some(stale) = displaced {
                stale.wake()
            }
            *store_waker = Some(new_waker.clone())
        }
    }
}
/// Handle to a channel, wrapping its `ChanNum`.
///
/// `Clone`/`Copy` are not derived; `Runner::channel_done()` takes the handle
/// by value.
#[derive(PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct ChanHandle(pub(crate) ChanNum);
impl ChanHandle {
pub fn num(&self) -> ChanNum {
self.0
}
}
impl core::fmt::Debug for ChanHandle {
    /// Formats the handle as `ChanHandle(<num>)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let number = self.num();
        f.write_fmt(format_args!("ChanHandle({})", number))
    }
}
#[cfg(test)]
mod tests {
// No unit tests are defined in this module as shown.
}