1use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc};
2
3use russh::{
4 Channel, ChannelId, Pty, Sig,
5 server::{Auth, Handle, Msg, Session},
6};
7use tokio::{
8 sync::{mpsc, mpsc::UnboundedSender},
9 task::JoinSet,
10};
11
12use crate::{
13 Device,
14 ssh::{SshAccept, TailnetServer},
15};
16
/// Message queued for a per-channel handler task: the id of the channel the
/// event belongs to, paired with the event itself.
type Request = (ChannelId, ChannelEvent);
18
/// Per-channel logic plugged into [`ChannelServer`].
///
/// The server constructs one handler for every session channel it accepts and
/// then feeds it that channel's [`ChannelEvent`]s one at a time.
pub trait ChannelHandler: Sized {
    /// Error produced by the handler.
    ///
    /// The server converts it into a [`std::io::Error`] before logging it.
    type Error: Into<std::io::Error> + std::error::Error;

    /// Builds the handler for a newly opened channel.
    ///
    /// * `handle` - handle to the Tokio runtime the channel task is running on.
    /// * `channel_id` - id of the channel this handler serves.
    /// * `session` - russh session handle for the connection owning the channel.
    /// * `dev` - device the connection belongs to.
    /// * `accept` - the accepted authorization decision for the connection.
    ///
    /// # Errors
    ///
    /// When this returns an error the server logs it and closes the channel.
    fn new(
        handle: tokio::runtime::Handle,
        channel_id: ChannelId,
        session: Handle,
        dev: Arc<Device>,
        accept: &SshAccept,
    ) -> Result<Self, Self::Error>;

    /// Processes a single event received on the channel.
    ///
    /// # Errors
    ///
    /// When the returned future resolves to an error the server logs it,
    /// closes the channel and stops delivering further events.
    fn handle_event(
        &mut self,
        event: &ChannelEvent,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
45
/// State of one SSH client connection whose session channels are served by
/// handlers of type `H`.
pub struct ChannelServer<H> {
    /// Bookkeeping for every channel currently open on this connection.
    channel_state: HashMap<ChannelId, ChannelState>,
    /// Address of the connected client.
    remote: SocketAddr,
    /// Device used to authorize the connection; also handed to each handler.
    dev: Arc<Device>,
    /// Decision recorded once authorization accepts the connection; channel
    /// opens are refused while this is `None`.
    accepted: Option<SshAccept>,
    /// Marker tying the server to its handler type without storing a handler.
    _handler: PhantomSend<H>,
}
72
/// Zero-sized marker recording the handler type `H` without storing a value
/// of it. It wraps `fn() -> H` rather than `H`, and a function pointer is
/// `Send` and `Sync` whatever `H` is, so the marker never restricts the
/// auto traits of the containing struct.
struct PhantomSend<H>(PhantomData<fn() -> H>);
74
/// Largest number of channels a single connection may have open at once.
const MAX_CHANNELS_PER_CONN: usize = 16;

/// Reports whether a connection that already has `open_channels` channels
/// open has reached [`MAX_CHANNELS_PER_CONN`] and must refuse another one.
///
/// The bound is inclusive: exactly `MAX_CHANNELS_PER_CONN` open channels
/// already counts as being at the cap.
fn at_channel_cap(open_channels: usize) -> bool {
    MAX_CHANNELS_PER_CONN <= open_channels
}
88
/// Error reported when a channel id has no entry in the connection's channel
/// table.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
#[error("no such channel")]
struct NoChannel;
92
/// Server-side bookkeeping for one open channel.
struct ChannelState {
    /// Id of the channel this state belongs to.
    channel: ChannelId,
    /// Sending half of the queue drained by the channel's handler task.
    tx: UnboundedSender<Request>,
    /// Owns the handler task. It is never read; it is held so the task lives
    /// exactly as long as this state (dropping a `JoinSet` aborts its tasks).
    _joinset: JoinSet<()>,
}
99
100impl ChannelState {
101 fn send(&self, event: ChannelEvent) {
102 if self.tx.send((self.channel, event)).is_err() {
103 tracing::error!(channel = %self.channel, "failed to send event");
104 }
105 }
106}
107
108impl<H> ChannelServer<H> {
109 fn get_channel(
110 &mut self,
111 id: ChannelId,
112 ) -> Result<&mut ChannelState, Box<dyn std::error::Error + Send + Sync + 'static>> {
113 self.channel_state.get_mut(&id).ok_or(Box::new(NoChannel))
114 }
115}
116
117impl<H> TailnetServer for ChannelServer<H> {
118 fn new_client(dev: Arc<Device>, addr: SocketAddr) -> Self {
119 Self {
120 channel_state: Default::default(),
121 dev,
122 remote: addr,
123 accepted: None,
124 _handler: PhantomSend(PhantomData),
125 }
126 }
127}
128
/// Event delivered to a [`ChannelHandler`] for the channel it serves.
#[derive(Debug, Clone)]
pub enum ChannelEvent {
    /// Bytes the client sent on the channel.
    Data(Vec<u8>),
    /// Terminal dimensions, sent for both the initial pty request and later
    /// window-change requests.
    Resize {
        /// Terminal width, taken from the request's column count.
        width: u16,
        /// Terminal height, taken from the request's row count.
        height: u16,
    },
    /// Signal the client asked to have delivered.
    Signal(Sig),
    /// The channel was closed.
    Close,
    /// The client signalled end-of-file on the channel.
    Eof,
}
148
149impl<H> russh::server::Handler for ChannelServer<H>
150where
151 H: ChannelHandler + Send,
152 H::Error: Send,
153{
154 type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
155
156 #[tracing::instrument(skip_all, fields(user = %user, remote = ?self.remote))]
157 async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
158 match self.dev.authorize_ssh(self.remote, user).await {
161 Ok(crate::ssh::SshDecision::Accept(accept)) => {
162 tracing::debug!(
163 local_user = %accept.local_user,
164 "ssh: policy accepted connection"
165 );
166 self.accepted = Some(accept);
170 Ok(Auth::Accept)
171 }
172 Ok(crate::ssh::SshDecision::Deny(reason)) => {
173 tracing::warn!(?reason, "ssh: policy denied connection");
174 Ok(Auth::reject())
175 }
176 Err(e) => {
177 tracing::error!(error = %e, "ssh: authorization failed; rejecting");
178 Ok(Auth::reject())
179 }
180 }
181 }
182
183 async fn channel_open_session(
184 &mut self,
185 channel: Channel<Msg>,
186 session: &mut Session,
187 ) -> Result<bool, Self::Error> {
188 tracing::debug!(channel = ?channel.id(), "new session");
189
190 let Some(accept) = self.accepted.clone() else {
194 tracing::error!(
195 channel = ?channel.id(),
196 "ssh: channel open with no accepted identity; refusing"
197 );
198 return Ok(false);
199 };
200
201 if at_channel_cap(self.channel_state.len()) {
205 tracing::warn!(
206 channel = ?channel.id(),
207 cap = MAX_CHANNELS_PER_CONN,
208 "ssh: per-connection channel cap reached; refusing new channel"
209 );
210 return Ok(false);
211 }
212
213 let (tx, mut rx) = mpsc::unbounded_channel::<Request>();
214 let mut joinset = JoinSet::new();
215
216 let (channel_id, session_handle) = (channel.id(), session.handle());
217 let dev = self.dev.clone();
218
219 joinset.spawn(async move {
220 let rt = tokio::runtime::Handle::current();
221
222 let mut handler = match H::new(rt, channel_id, session_handle.clone(), dev, &accept) {
223 Ok(handler) => handler,
224 Err(e) => {
225 let e = e.into();
226 tracing::error!(error = %e, %channel_id, "spawning channel handler");
227
228 if session_handle.close(channel_id).await.is_err() {
229 tracing::error!("failed closing channel after handler init error");
230 };
231
232 return;
233 }
234 };
235
236 while let Some((_channel, evt)) = rx.recv().await {
237 let result = handler.handle_event(&evt).await;
238
239 if let Err(e) = result {
240 let e = e.into();
241 tracing::error!(error = %e, %channel_id, ?evt, "handling event");
242
243 if session_handle.close(channel_id).await.is_err() {
244 tracing::error!("failed closing channel after event handler error");
245 };
246
247 break;
248 }
249 }
250
251 tracing::debug!(?channel_id, "closed");
252 });
253
254 self.channel_state.insert(
255 channel.id(),
256 ChannelState {
257 channel: channel.id(),
258 tx,
259 _joinset: joinset,
260 },
261 );
262
263 session.channel_success(channel.id())?;
264
265 Ok(true)
266 }
267
268 async fn channel_close(
269 &mut self,
270 channel: ChannelId,
271 session: &mut Session,
272 ) -> Result<(), Self::Error> {
273 tracing::trace!(?channel, "session closed");
274
275 self.get_channel(channel)?.send(ChannelEvent::Close);
276 self.channel_state.remove(&channel);
277
278 session.channel_success(channel)?;
279
280 Ok(())
281 }
282
283 async fn signal(
284 &mut self,
285 channel: ChannelId,
286 signal: Sig,
287 session: &mut Session,
288 ) -> Result<(), Self::Error> {
289 self.get_channel(channel)?
290 .send(ChannelEvent::Signal(signal));
291 session.channel_success(channel)?;
292
293 Ok(())
294 }
295
296 async fn data(
297 &mut self,
298 channel: ChannelId,
299 data: &[u8],
300 session: &mut Session,
301 ) -> Result<(), Self::Error> {
302 self.get_channel(channel)?
303 .send(ChannelEvent::Data(data.into()));
304
305 session.channel_success(channel)?;
306
307 Ok(())
308 }
309
310 async fn channel_eof(
311 &mut self,
312 channel: ChannelId,
313 session: &mut Session,
314 ) -> Result<(), Self::Error> {
315 self.get_channel(channel)?.send(ChannelEvent::Eof);
316 session.channel_success(channel)?;
317
318 Ok(())
319 }
320
321 async fn window_change_request(
322 &mut self,
323 channel: ChannelId,
324 col_width: u32,
325 row_height: u32,
326 _: u32,
327 _: u32,
328 session: &mut Session,
329 ) -> Result<(), Self::Error> {
330 self.get_channel(channel)?.send(ChannelEvent::Resize {
331 width: col_width as _,
332 height: row_height as _,
333 });
334
335 session.channel_success(channel)?;
336
337 Ok(())
338 }
339
340 async fn pty_request(
341 &mut self,
342 channel: ChannelId,
343 _: &str,
344 col_width: u32,
345 row_height: u32,
346 _: u32,
347 _: u32,
348 _: &[(Pty, u32)],
349 session: &mut Session,
350 ) -> Result<(), Self::Error> {
351 self.get_channel(channel)?.send(ChannelEvent::Resize {
352 width: col_width as _,
353 height: row_height as _,
354 });
355
356 session.channel_success(channel)?;
357
358 Ok(())
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::{MAX_CHANNELS_PER_CONN, at_channel_cap};

    /// The cap is inclusive: one below the limit is still allowed, while the
    /// limit itself and anything above it are refused. The literal values
    /// double as a guard against the constant changing unnoticed.
    #[test]
    fn channel_cap_boundary_is_inclusive() {
        let under_cap = [MAX_CHANNELS_PER_CONN - 1, 15];
        let at_or_over_cap = [MAX_CHANNELS_PER_CONN, 16, 17];

        for open in under_cap {
            assert!(!at_channel_cap(open));
        }
        for open in at_or_over_cap {
            assert!(at_channel_cap(open));
        }

        assert_eq!(MAX_CHANNELS_PER_CONN, 16);
    }
}