Skip to main content

tailscale/ssh/
channel_server.rs

1use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc};
2
3use russh::{
4    Channel, ChannelId, Pty, Sig,
5    server::{Auth, Handle, Msg, Session},
6};
7use tokio::{
8    sync::{mpsc, mpsc::UnboundedSender},
9    task::JoinSet,
10};
11
12use crate::{
13    Device,
14    ssh::{SshAccept, TailnetServer},
15};
16
/// Message delivered to a channel's handler task: the id of the channel the event belongs to,
/// paired with the event itself.
type Request = (ChannelId, ChannelEvent);
18
/// Handler for a channel session.
///
/// [`ChannelServer`] builds one handler per opened session channel, inside that channel's own
/// task, and then passes it the channel's [`ChannelEvent`]s one at a time.
pub trait ChannelHandler: Sized {
    /// Error this handler produces.
    ///
    /// Must be convertible into a [`std::io::Error`]; [`ChannelServer`] performs that conversion
    /// before logging a failure.
    type Error: Into<std::io::Error> + std::error::Error;

    /// Construct a new per-channel handler.
    ///
    /// * `handle` — handle to the tokio runtime the channel's task is running on.
    /// * `channel_id` — id of the channel this handler serves.
    /// * `session` — russh server handle for the connection that owns the channel.
    /// * `dev` — the shared [`Device`] the connection arrived on.
    ///
    /// `accept` is the [`SshAccept`] produced by the single fail-closed authorization decision in
    /// [`auth_none`][russh::server::Handler::auth_none]; in particular its
    /// [`local_user`][SshAccept::local_user] is the policy-mapped identity the session must run as.
    /// Handlers MUST NOT re-evaluate policy or substitute a different user — the accepted identity
    /// is the sole authorization source.
    ///
    /// # Errors
    ///
    /// If construction fails, [`ChannelServer`] logs the error and closes the channel; no events
    /// are delivered.
    fn new(
        handle: tokio::runtime::Handle,
        channel_id: ChannelId,
        session: Handle,
        dev: Arc<Device>,
        accept: &SshAccept,
    ) -> Result<Self, Self::Error>;

    /// Handle an event from the channel.
    ///
    /// The returned future must be `Send` because it is awaited inside a spawned task.
    ///
    /// # Errors
    ///
    /// If an event fails, [`ChannelServer`] logs the error, closes the channel, and stops
    /// delivering further events to this handler.
    fn handle_event(
        &mut self,
        event: &ChannelEvent,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
45
/// Implementation of [`russh::server::Handler`] which provides per-channel session
/// handlers using a parametric [`ChannelHandler`].
///
/// Primary motivation is to support custom console or TUI sessions over tailnet SSH
/// connections.
///
/// One `ChannelServer` exists per client connection; it tracks the channels that connection has
/// open and forwards each channel's traffic to that channel's handler task.
///
/// # Authentication and authorization
///
/// Incoming connections are gated by the control-pushed Tailscale SSH policy: [`auth_none`]
/// resolves the source IP to a known tailnet peer and evaluates the policy via
/// [`Device::authorize_ssh`][crate::Device::authorize_ssh] (fail-closed — an unknown peer, an
/// absent policy, or a non-matching policy all reject). The `ssh` policy block's accept/reject
/// rules, principal matching, and SSH-user mapping are honored; advanced features (recording,
/// `holdAndDelegate`, per-session capability enforcement) are not yet implemented.
///
/// [`auth_none`]: russh::server::Handler::auth_none
pub struct ChannelServer<H> {
    /// State for every channel currently open on this connection, keyed by channel id. Its size
    /// is what the per-connection channel cap is checked against.
    channel_state: HashMap<ChannelId, ChannelState>,
    /// Address the client connected from; passed to the authorization check and attached to the
    /// `auth_none` tracing span.
    remote: SocketAddr,
    /// Shared device used to authorize the connection and handed to every channel handler.
    dev: Arc<Device>,
    /// The accepted identity from the single [`auth_none`][russh::server::Handler::auth_none]
    /// authorization decision, stashed so per-channel handlers run as the policy-mapped user.
    /// `None` until a successful `auth_none`; a channel open with `None` here fails closed.
    accepted: Option<SshAccept>,
    /// Zero-sized marker tying this server to its handler type `H`; no `H` value is stored here.
    _handler: PhantomSend<H>,
}
72
/// Zero-sized marker that records the handler type `H` without owning a value of it.
///
/// It wraps `PhantomData<fn() -> H>` rather than `PhantomData<H>`: function-pointer types are
/// `Send` and `Sync` whatever their return type, so the marker does not make those auto traits
/// depend on `H`.
struct PhantomSend<H>(PhantomData<fn() -> H>);
74
/// Upper bound on the number of channels one SSH connection may have open at the same time.
///
/// Every channel gets its own session handler (for example a login shell), so this limits how
/// many handlers/processes an authorized-but-hostile peer can make a single connection spawn.
/// Real SSH clients open one session per connection, occasionally a handful, so the limit leaves
/// plenty of headroom for legitimate use.
const MAX_CHANNELS_PER_CONN: usize = 16;

/// Reports whether a connection that already has `open_channels` channels open has hit the
/// per-connection cap, i.e. whether the next channel open must be refused.
///
/// The boundary is inclusive: exactly [`MAX_CHANNELS_PER_CONN`] open channels already counts as
/// "at the cap". This is kept as a pure predicate, separate from
/// [`ChannelServer::channel_open_session`], so the fork-bomb guard's edge can be unit-tested
/// without standing up a live russh [`Session`].
fn at_channel_cap(open_channels: usize) -> bool {
    MAX_CHANNELS_PER_CONN <= open_channels
}
88
/// Error produced when a request names a channel id that has no state registered on this
/// connection.
#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)]
#[error("no such channel")]
struct NoChannel;
92
/// State of a channel in [`ChannelServer`].
struct ChannelState {
    /// Id of the channel this state belongs to; attached to every event that is sent.
    channel: ChannelId,
    /// Sending half of the queue feeding this channel's handler task.
    tx: UnboundedSender<Request>,
    /// Owns the channel's handler task. It is never polled; it is held only for ownership.
    /// A tokio `JoinSet` aborts its tasks when dropped, so dropping this state also stops the
    /// handler task.
    _joinset: JoinSet<()>,
}
99
100impl ChannelState {
101    fn send(&self, event: ChannelEvent) {
102        if self.tx.send((self.channel, event)).is_err() {
103            tracing::error!(channel = %self.channel, "failed to send event");
104        }
105    }
106}
107
108impl<H> ChannelServer<H> {
109    fn get_channel(
110        &mut self,
111        id: ChannelId,
112    ) -> Result<&mut ChannelState, Box<dyn std::error::Error + Send + Sync + 'static>> {
113        self.channel_state.get_mut(&id).ok_or(Box::new(NoChannel))
114    }
115}
116
117impl<H> TailnetServer for ChannelServer<H> {
118    fn new_client(dev: Arc<Device>, addr: SocketAddr) -> Self {
119        Self {
120            channel_state: Default::default(),
121            dev,
122            remote: addr,
123            accepted: None,
124            _handler: PhantomSend(PhantomData),
125        }
126    }
127}
128
/// An event that may be generated by a channel connected to a [`ChannelServer`].
///
/// Events are handed to the channel's [`ChannelHandler`] through
/// [`handle_event`][ChannelHandler::handle_event].
#[derive(Debug, Clone)]
pub enum ChannelEvent {
    /// Data was received over the channel. Holds an owned copy of the received bytes.
    Data(Vec<u8>),
    /// A resize event occurred.
    ///
    /// Emitted for window-change requests and also for the initial PTY request, which carries
    /// the terminal's starting size. Both dimensions come from the request's `u32` column/row
    /// values, narrowed to `u16`.
    Resize {
        /// The new width of the tty.
        width: u16,
        /// The new height of the tty.
        height: u16,
    },
    /// A signal was sent over the channel.
    Signal(Sig),
    /// The channel was closed.
    Close,
    /// The channel received EOF.
    Eof,
}
148
149impl<H> russh::server::Handler for ChannelServer<H>
150where
151    H: ChannelHandler + Send,
152    H::Error: Send,
153{
154    type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
155
156    #[tracing::instrument(skip_all, fields(user = %user, remote = ?self.remote))]
157    async fn auth_none(&mut self, user: &str) -> Result<Auth, Self::Error> {
158        // Enforce the control-pushed Tailscale SSH policy. Fail-closed: an unknown source, an
159        // absent policy, a non-matching policy, or any lookup error all reject the connection.
160        match self.dev.authorize_ssh(self.remote, user).await {
161            Ok(crate::ssh::SshDecision::Accept(accept)) => {
162                tracing::debug!(
163                    local_user = %accept.local_user,
164                    "ssh: policy accepted connection"
165                );
166                // Stash the accepted identity so the per-channel handler runs as the
167                // policy-mapped local user. This is the single fail-closed authorization point;
168                // the handler never re-evaluates policy.
169                self.accepted = Some(accept);
170                Ok(Auth::Accept)
171            }
172            Ok(crate::ssh::SshDecision::Deny(reason)) => {
173                tracing::warn!(?reason, "ssh: policy denied connection");
174                Ok(Auth::reject())
175            }
176            Err(e) => {
177                tracing::error!(error = %e, "ssh: authorization failed; rejecting");
178                Ok(Auth::reject())
179            }
180        }
181    }
182
183    async fn channel_open_session(
184        &mut self,
185        channel: Channel<Msg>,
186        session: &mut Session,
187    ) -> Result<bool, Self::Error> {
188        tracing::debug!(channel = ?channel.id(), "new session");
189
190        // Fail closed: a channel open must be preceded by a successful `auth_none` that stashed
191        // the accepted identity. If it is somehow absent, refuse to open the channel rather than
192        // run a handler with no authorized user.
193        let Some(accept) = self.accepted.clone() else {
194            tracing::error!(
195                channel = ?channel.id(),
196                "ssh: channel open with no accepted identity; refusing"
197            );
198            return Ok(false);
199        };
200
201        // Bound the number of concurrent channels (each opens a session/handler — e.g. a login
202        // shell). Without this an authorized-but-hostile peer could open unbounded channels on one
203        // connection and fork-bomb the host with session handlers. Past the cap, refuse new channels.
204        if at_channel_cap(self.channel_state.len()) {
205            tracing::warn!(
206                channel = ?channel.id(),
207                cap = MAX_CHANNELS_PER_CONN,
208                "ssh: per-connection channel cap reached; refusing new channel"
209            );
210            return Ok(false);
211        }
212
213        let (tx, mut rx) = mpsc::unbounded_channel::<Request>();
214        let mut joinset = JoinSet::new();
215
216        let (channel_id, session_handle) = (channel.id(), session.handle());
217        let dev = self.dev.clone();
218
219        joinset.spawn(async move {
220            let rt = tokio::runtime::Handle::current();
221
222            let mut handler = match H::new(rt, channel_id, session_handle.clone(), dev, &accept) {
223                Ok(handler) => handler,
224                Err(e) => {
225                    let e = e.into();
226                    tracing::error!(error = %e, %channel_id, "spawning channel handler");
227
228                    if session_handle.close(channel_id).await.is_err() {
229                        tracing::error!("failed closing channel after handler init error");
230                    };
231
232                    return;
233                }
234            };
235
236            while let Some((_channel, evt)) = rx.recv().await {
237                let result = handler.handle_event(&evt).await;
238
239                if let Err(e) = result {
240                    let e = e.into();
241                    tracing::error!(error = %e, %channel_id, ?evt, "handling event");
242
243                    if session_handle.close(channel_id).await.is_err() {
244                        tracing::error!("failed closing channel after event handler error");
245                    };
246
247                    break;
248                }
249            }
250
251            tracing::debug!(?channel_id, "closed");
252        });
253
254        self.channel_state.insert(
255            channel.id(),
256            ChannelState {
257                channel: channel.id(),
258                tx,
259                _joinset: joinset,
260            },
261        );
262
263        session.channel_success(channel.id())?;
264
265        Ok(true)
266    }
267
268    async fn channel_close(
269        &mut self,
270        channel: ChannelId,
271        session: &mut Session,
272    ) -> Result<(), Self::Error> {
273        tracing::trace!(?channel, "session closed");
274
275        self.get_channel(channel)?.send(ChannelEvent::Close);
276        self.channel_state.remove(&channel);
277
278        session.channel_success(channel)?;
279
280        Ok(())
281    }
282
283    async fn signal(
284        &mut self,
285        channel: ChannelId,
286        signal: Sig,
287        session: &mut Session,
288    ) -> Result<(), Self::Error> {
289        self.get_channel(channel)?
290            .send(ChannelEvent::Signal(signal));
291        session.channel_success(channel)?;
292
293        Ok(())
294    }
295
296    async fn data(
297        &mut self,
298        channel: ChannelId,
299        data: &[u8],
300        session: &mut Session,
301    ) -> Result<(), Self::Error> {
302        self.get_channel(channel)?
303            .send(ChannelEvent::Data(data.into()));
304
305        session.channel_success(channel)?;
306
307        Ok(())
308    }
309
310    async fn channel_eof(
311        &mut self,
312        channel: ChannelId,
313        session: &mut Session,
314    ) -> Result<(), Self::Error> {
315        self.get_channel(channel)?.send(ChannelEvent::Eof);
316        session.channel_success(channel)?;
317
318        Ok(())
319    }
320
321    async fn window_change_request(
322        &mut self,
323        channel: ChannelId,
324        col_width: u32,
325        row_height: u32,
326        _: u32,
327        _: u32,
328        session: &mut Session,
329    ) -> Result<(), Self::Error> {
330        self.get_channel(channel)?.send(ChannelEvent::Resize {
331            width: col_width as _,
332            height: row_height as _,
333        });
334
335        session.channel_success(channel)?;
336
337        Ok(())
338    }
339
340    async fn pty_request(
341        &mut self,
342        channel: ChannelId,
343        _: &str,
344        col_width: u32,
345        row_height: u32,
346        _: u32,
347        _: u32,
348        _: &[(Pty, u32)],
349        session: &mut Session,
350    ) -> Result<(), Self::Error> {
351        self.get_channel(channel)?.send(ChannelEvent::Resize {
352            width: col_width as _,
353            height: row_height as _,
354        });
355
356        session.channel_success(channel)?;
357
358        Ok(())
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::{MAX_CHANNELS_PER_CONN, at_channel_cap};

    /// The fork-bomb guard must refuse once `MAX_CHANNELS_PER_CONN` channels are open and allow
    /// anything below that. The table pins the exact edge: if `>=` were weakened to `>`, a
    /// connection already holding `MAX_CHANNELS_PER_CONN` channels could open one more, and the
    /// rows at exactly the cap would fail.
    #[test]
    fn channel_cap_boundary_is_inclusive() {
        // (channels already open, whether the next open must be refused)
        let cases = [
            // Just under the cap: the next open is still allowed.
            (MAX_CHANNELS_PER_CONN - 1, false),
            (15, false),
            // Exactly at the cap: refuse the open that would make it 17.
            (MAX_CHANNELS_PER_CONN, true),
            (16, true),
            // Over the cap (defensive): still refused.
            (17, true),
        ];

        for (open_channels, refused) in cases {
            assert_eq!(
                at_channel_cap(open_channels),
                refused,
                "open_channels = {open_channels}"
            );
        }

        // The constant itself stays at its documented value.
        assert_eq!(MAX_CHANNELS_PER_CONN, 16);
    }
}