use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc};
use russh::{
Channel, ChannelId, Pty, Sig,
server::{Auth, Handle, Msg, Session},
};
use tokio::{
sync::{mpsc, mpsc::UnboundedSender},
task::JoinSet,
};
use crate::{
Device,
ssh::{SshAccept, TailnetServer},
};
/// Message delivered to a channel's worker task: the id of the channel the
/// event belongs to, paired with the event itself.
type Request = (ChannelId, ChannelEvent);
/// Per-channel event consumer driven by `ChannelServer`.
///
/// `ChannelServer` builds one handler for every session channel it accepts and
/// then feeds it that channel's `ChannelEvent`s, one at a time, from a
/// dedicated task.
pub trait ChannelHandler: Sized {
/// Error produced by the handler. `ChannelServer` converts it with `into()`
/// before logging it and asking the session to close the channel.
type Error: Into<std::io::Error> + std::error::Error;
/// Builds the handler for a newly opened channel.
///
/// `ChannelServer` calls this from inside the channel's worker task and
/// passes: the handle of the runtime that task is running on, the channel's
/// id, a handle to the owning SSH session, the server's `Device`, and the
/// `SshAccept` recorded when the connection was authorized.
///
/// Returning `Err` makes `ChannelServer` log the error and close the channel
/// without delivering any events.
fn new(
handle: tokio::runtime::Handle,
channel_id: ChannelId,
session: Handle,
dev: Arc<Device>,
accept: &SshAccept,
) -> Result<Self, Self::Error>;
/// Processes a single event for this channel.
///
/// `ChannelServer` awaits the returned future before delivering the next
/// event. On `Err` it logs the error, asks the session to close the channel,
/// and stops delivering events to this handler.
fn handle_event(
&mut self,
event: &ChannelEvent,
) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
/// Per-connection SSH server state that routes channel traffic to `H` handlers.
///
/// `H` is only used as a type: handlers are constructed inside the per-channel
/// worker tasks, never stored in this struct.
pub struct ChannelServer<H> {
/// State of every currently registered session channel, keyed by channel id.
channel_state: HashMap<ChannelId, ChannelState>,
/// Address passed to `new_client`; presented to the authorization check and
/// recorded on the `auth_none` tracing span.
remote: SocketAddr,
/// Device used to authorize the connection; a clone is handed to each handler.
dev: Arc<Device>,
/// Identity accepted by `auth_none`; `None` until authorization succeeds.
accepted: Option<SshAccept>,
/// Marker tying the server to its handler type without storing a handler.
_handler: PhantomSend<H>,
}
/// Zero-sized marker that associates a type with `H` without owning an `H`.
///
/// It wraps `PhantomData<fn() -> H>`: function pointer types are `Send` and
/// `Sync` whatever their return type is, so this marker never restricts the
/// auto traits of the struct that holds it.
struct PhantomSend<H>(PhantomData<fn() -> H>);
/// Maximum number of session channels a single connection may have registered
/// at the same time.
const MAX_CHANNELS_PER_CONN: usize = 16;

/// Reports whether a connection that already has `open_channels` channels has
/// hit the per-connection limit.
///
/// The limit is inclusive: exactly [`MAX_CHANNELS_PER_CONN`] open channels
/// already counts as "at the cap", so no further channel may be added.
fn at_channel_cap(open_channels: usize) -> bool {
    MAX_CHANNELS_PER_CONN <= open_channels
}
/// Error reported when a callback names a channel id that has no entry in the
/// server's channel table.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
#[error("no such channel")]
struct NoChannel;
/// Book-keeping for one registered channel: the sending half of its event
/// queue and the task that consumes that queue.
struct ChannelState {
/// Id of the channel; attached to every event sent through `tx`.
channel: ChannelId,
/// Sender feeding the channel's worker task.
tx: UnboundedSender<Request>,
/// Owns the worker task. Held only for its drop behaviour: tokio aborts the
/// tasks of a `JoinSet` when the set is dropped, so dropping this state
/// cancels the worker.
_joinset: JoinSet<()>,
}
impl ChannelState {
fn send(&self, event: ChannelEvent) {
if self.tx.send((self.channel, event)).is_err() {
tracing::error!(channel = %self.channel, "failed to send event");
}
}
}
impl<H> ChannelServer<H> {
    /// Looks up the state registered for channel `id`.
    ///
    /// # Errors
    ///
    /// Returns a boxed [`NoChannel`] when no channel with that id is
    /// registered on this connection.
    fn get_channel(
        &mut self,
        id: ChannelId,
    ) -> Result<&mut ChannelState, Box<dyn std::error::Error + Send + Sync + 'static>> {
        match self.channel_state.get_mut(&id) {
            Some(state) => Ok(state),
            None => Err(Box::new(NoChannel)),
        }
    }
}
impl<H> TailnetServer for ChannelServer<H> {
fn new_client(dev: Arc<Device>, addr: SocketAddr) -> Self {
Self {
channel_state: Default::default(),
dev,
remote: addr,
accepted: None,
_handler: PhantomSend(PhantomData),
}
}
}
/// Events forwarded from the SSH session callbacks to a channel's
/// `ChannelHandler`.
#[derive(Debug, Clone)]
pub enum ChannelEvent {
/// Bytes received on the channel, copied into an owned buffer.
Data(Vec<u8>),
/// Terminal dimensions; emitted for both PTY requests and window-change
/// requests.
Resize {
/// Width in character columns.
width: u16,
/// Height in character rows.
height: u16,
},
/// Signal sent by the client.
Signal(Sig),
/// The channel was closed.
Close,
/// End-of-file was received on the channel.
Eof,
}
impl<H> russh::server::Handler for ChannelServer<H>
where
H: ChannelHandler + Send,
H::Error: Send,
{
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
#[tracing::instrument(skip_all, fields(user = %user, remote = ?self.remote))]
async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
match self.dev.authorize_ssh(self.remote, user).await {
Ok(crate::ssh::SshDecision::Accept(accept)) => {
tracing::debug!(
local_user = %accept.local_user,
"ssh: policy accepted connection"
);
self.accepted = Some(accept);
Ok(Auth::Accept)
}
Ok(crate::ssh::SshDecision::Deny(reason)) => {
tracing::warn!(?reason, "ssh: policy denied connection");
Ok(Auth::reject())
}
Err(e) => {
tracing::error!(error = %e, "ssh: authorization failed; rejecting");
Ok(Auth::reject())
}
}
}
async fn channel_open_session(
&mut self,
channel: Channel<Msg>,
session: &mut Session,
) -> Result<bool, Self::Error> {
tracing::debug!(channel = ?channel.id(), "new session");
let Some(accept) = self.accepted.clone() else {
tracing::error!(
channel = ?channel.id(),
"ssh: channel open with no accepted identity; refusing"
);
return Ok(false);
};
if at_channel_cap(self.channel_state.len()) {
tracing::warn!(
channel = ?channel.id(),
cap = MAX_CHANNELS_PER_CONN,
"ssh: per-connection channel cap reached; refusing new channel"
);
return Ok(false);
}
let (tx, mut rx) = mpsc::unbounded_channel::<Request>();
let mut joinset = JoinSet::new();
let (channel_id, session_handle) = (channel.id(), session.handle());
let dev = self.dev.clone();
joinset.spawn(async move {
let rt = tokio::runtime::Handle::current();
let mut handler = match H::new(rt, channel_id, session_handle.clone(), dev, &accept) {
Ok(handler) => handler,
Err(e) => {
let e = e.into();
tracing::error!(error = %e, %channel_id, "spawning channel handler");
if session_handle.close(channel_id).await.is_err() {
tracing::error!("failed closing channel after handler init error");
};
return;
}
};
while let Some((_channel, evt)) = rx.recv().await {
let result = handler.handle_event(&evt).await;
if let Err(e) = result {
let e = e.into();
tracing::error!(error = %e, %channel_id, ?evt, "handling event");
if session_handle.close(channel_id).await.is_err() {
tracing::error!("failed closing channel after event handler error");
};
break;
}
}
tracing::debug!(?channel_id, "closed");
});
self.channel_state.insert(
channel.id(),
ChannelState {
channel: channel.id(),
tx,
_joinset: joinset,
},
);
session.channel_success(channel.id())?;
Ok(true)
}
async fn channel_close(
&mut self,
channel: ChannelId,
session: &mut Session,
) -> Result<(), Self::Error> {
tracing::trace!(?channel, "session closed");
self.get_channel(channel)?.send(ChannelEvent::Close);
self.channel_state.remove(&channel);
session.channel_success(channel)?;
Ok(())
}
/// Forwards a client signal to the channel's handler as
/// `ChannelEvent::Signal`, then calls `channel_success` on the session.
///
/// # Errors
///
/// Returns `NoChannel` if the channel is not registered, or the error from
/// `channel_success`.
async fn signal(
    &mut self,
    channel: ChannelId,
    signal: Sig,
    session: &mut Session,
) -> Result<(), Self::Error> {
    let state = self.get_channel(channel)?;
    state.send(ChannelEvent::Signal(signal));
    session.channel_success(channel)?;
    Ok(())
}
/// Forwards bytes received on `channel` to its handler.
///
/// The borrowed slice is copied into an owned buffer because the event
/// outlives this callback on the worker's queue. `channel_success` is then
/// called on the session.
///
/// # Errors
///
/// Returns `NoChannel` if the channel is not registered, or the error from
/// `channel_success`.
async fn data(
    &mut self,
    channel: ChannelId,
    data: &[u8],
    session: &mut Session,
) -> Result<(), Self::Error> {
    let payload = data.to_vec();
    let state = self.get_channel(channel)?;
    state.send(ChannelEvent::Data(payload));
    session.channel_success(channel)?;
    Ok(())
}
/// Forwards end-of-file on `channel` to its handler as `ChannelEvent::Eof`,
/// then calls `channel_success` on the session.
///
/// # Errors
///
/// Returns `NoChannel` if the channel is not registered, or the error from
/// `channel_success`.
async fn channel_eof(
    &mut self,
    channel: ChannelId,
    session: &mut Session,
) -> Result<(), Self::Error> {
    let state = self.get_channel(channel)?;
    state.send(ChannelEvent::Eof);
    session.channel_success(channel)?;
    Ok(())
}
/// Forwards a terminal resize to the channel's handler as
/// `ChannelEvent::Resize`, then calls `channel_success` on the session.
///
/// `col_width` and `row_height` are the new size in character cells; the two
/// unnamed `u32` arguments are ignored. The event carries `u16` dimensions,
/// so values above `u16::MAX` are clamped to `u16::MAX` instead of being
/// wrapped: a client-supplied 65536 must not turn into a zero-width terminal.
///
/// # Errors
///
/// Returns `NoChannel` if the channel is not registered, or the error from
/// `channel_success`.
async fn window_change_request(
    &mut self,
    channel: ChannelId,
    col_width: u32,
    row_height: u32,
    _: u32,
    _: u32,
    session: &mut Session,
) -> Result<(), Self::Error> {
    // An `as u16` cast would keep only the low 16 bits; saturate instead.
    let width = u16::try_from(col_width).unwrap_or(u16::MAX);
    let height = u16::try_from(row_height).unwrap_or(u16::MAX);
    self.get_channel(channel)?
        .send(ChannelEvent::Resize { width, height });
    session.channel_success(channel)?;
    Ok(())
}
/// Handles a PTY request by reporting the requested terminal size to the
/// channel's handler as `ChannelEvent::Resize`, then calls `channel_success`
/// on the session.
///
/// Only `col_width` and `row_height` (character cells) are used; the
/// unnamed string, the two unnamed `u32` values and the `Pty` mode list are
/// ignored. The event carries `u16` dimensions, so values above `u16::MAX`
/// are clamped to `u16::MAX` instead of being wrapped: a client-supplied
/// 65536 must not turn into a zero-width terminal.
///
/// # Errors
///
/// Returns `NoChannel` if the channel is not registered, or the error from
/// `channel_success`.
async fn pty_request(
    &mut self,
    channel: ChannelId,
    _: &str,
    col_width: u32,
    row_height: u32,
    _: u32,
    _: u32,
    _: &[(Pty, u32)],
    session: &mut Session,
) -> Result<(), Self::Error> {
    // An `as u16` cast would keep only the low 16 bits; saturate instead.
    let width = u16::try_from(col_width).unwrap_or(u16::MAX);
    let height = u16::try_from(row_height).unwrap_or(u16::MAX);
    self.get_channel(channel)?
        .send(ChannelEvent::Resize { width, height });
    session.channel_success(channel)?;
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::{MAX_CHANNELS_PER_CONN, at_channel_cap};

    /// The cap is inclusive: one channel below the limit is still allowed,
    /// while exactly the limit (and anything above it) is refused. The limit
    /// itself is pinned to 16.
    #[test]
    fn channel_cap_boundary_is_inclusive() {
        for below_cap in [MAX_CHANNELS_PER_CONN - 1, 15] {
            assert!(!at_channel_cap(below_cap));
        }
        for at_or_over_cap in [MAX_CHANNELS_PER_CONN, 16, 17] {
            assert!(at_channel_cap(at_or_over_cap));
        }
        assert_eq!(MAX_CHANNELS_PER_CONN, 16);
    }
}