cable_tunnel_server_common/
lib.rs

1//! Common components for webauthn-rs' caBLE tunnel server.
2//!
3//! **Important**: this library is an internal implementation detail of
4//! webauthn-rs' caBLE tunnel server, and has no guarantees of API stability
5//! whatsoever. It is **not** intended for use outside of that context.
6use std::{
7    convert::Infallible, error::Error as StdError, fmt::Display, future::Future, mem::size_of,
8    net::SocketAddr, num::ParseIntError, sync::Arc, time::Duration,
9};
10
11use hex::{FromHex, FromHexError};
12use http_body::Body;
13use http_body_util::{Empty, Full};
14use hyper::{
15    body::{Bytes, Incoming},
16    header::{CONTENT_TYPE, ORIGIN, SEC_WEBSOCKET_PROTOCOL},
17    http::HeaderValue,
18    service::service_fn,
19    HeaderMap, Method, Request, Response, StatusCode, Uri,
20};
21use tokio::net::TcpListener;
22use tokio_native_tls::TlsAcceptor;
23use tokio_tungstenite::MaybeTlsStream;
24use tracing::Instrument;
25use tracing_subscriber::{filter::LevelFilter, fmt::format::FmtSpan, EnvFilter};
26use tungstenite::handshake::server::create_response;
27
28#[macro_use]
29extern crate tracing;
30
31mod tls;
32pub use tls::*;
33
34pub type RoutingId = [u8; 3];
35pub type TunnelId = [u8; 16];
36
37pub static CABLE_PROTOCOL: HeaderValue = HeaderValue::from_static("fido.cable");
38pub const CABLE_ROUTING_ID_HEADER: &str = "X-caBLE-Routing-ID";
39
40pub const CABLE_NEW_PATH: &str = "/cable/new/";
41pub const CABLE_CONNECT_PATH: &str = "/cable/connect/";
42
43pub const MAX_URL_LENGTH: usize =
44    CABLE_CONNECT_PATH.len() + ((size_of::<RoutingId>() + size_of::<TunnelId>()) * 2) + 2;
45
46const FAVICON: &[u8] = include_bytes!("favicon.ico");
47const INDEX: &str = include_str!("index.html");
48
49/// Parses a Base-16 encoded string.
50///
51/// This function is intended for use as a `clap` `value_parser`.
52pub fn parse_hex<T>(i: &str) -> Result<T, FromHexError>
53where
54    T: FromHex<Error = FromHexError>,
55{
56    FromHex::from_hex(i)
57}
58
/// Interprets `i` as a whole, unsigned number of seconds and converts it into
/// a [`Duration`].
///
/// Meant to be plugged into `clap` as a `value_parser`.
///
/// # Errors
///
/// Returns [`ParseIntError`] if `i` is not a valid `u64`.
pub fn parse_duration_secs(i: &str) -> Result<Duration, ParseIntError> {
    let secs: u64 = i.parse()?;
    Ok(Duration::from_secs(secs))
}
65
66/// Path for caBLE WebSocket tunnel protocol.
67#[derive(Debug, Clone, Copy, PartialEq, Eq)]
68pub struct CablePath {
69    /// The method for the tunnel.
70    pub method: CableMethod,
71    /// The routing ID of the tunnel.
72    pub routing_id: RoutingId,
73    /// The tunnel ID of the tunnel.
74    pub tunnel_id: TunnelId,
75}
76
/// Kind of caBLE WebSocket tunnel request, selected by the path prefix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CableMethod {
    /// `/cable/new/…`: the authenticator asks for a new tunnel, which still
    /// has to be allocated a routing ID.
    New,
    /// `/cable/connect/…`: the initiator asks to join an existing tunnel.
    Connect,
}
86
87impl CablePath {
88    pub fn new(tunnel_id: TunnelId) -> Self {
89        Self {
90            method: CableMethod::New,
91            routing_id: [0; size_of::<RoutingId>()],
92            tunnel_id,
93        }
94    }
95
96    pub fn connect(routing_id: RoutingId, tunnel_id: TunnelId) -> Self {
97        Self {
98            method: CableMethod::Connect,
99            routing_id,
100            tunnel_id,
101        }
102    }
103
104    /// Inserts the caBLE routing ID header into a HTTP response.
105    pub fn insert_routing_id_header(&self, headers: &mut HeaderMap) {
106        headers.insert(
107            CABLE_ROUTING_ID_HEADER,
108            HeaderValue::from_str(&hex::encode_upper(self.routing_id)).unwrap(),
109        );
110    }
111}
112
113impl Display for CablePath {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self.method {
116            CableMethod::New => {
117                write!(f, "{}{}", CABLE_NEW_PATH, hex::encode_upper(self.tunnel_id))
118            }
119            CableMethod::Connect => write!(
120                f,
121                "{}{}/{}",
122                CABLE_CONNECT_PATH,
123                hex::encode_upper(self.routing_id),
124                hex::encode_upper(self.tunnel_id)
125            ),
126        }
127    }
128}
129
130impl TryFrom<&str> for CablePath {
131    type Error = ();
132    fn try_from(path: &str) -> Result<Self, Self::Error> {
133        if path.len() > MAX_URL_LENGTH {
134            error!("path too long: {} > {MAX_URL_LENGTH} bytes", path.len());
135            return Err(());
136        } else if let Some(path) = path.strip_prefix(CABLE_NEW_PATH) {
137            let mut tunnel_id: TunnelId = [0; size_of::<TunnelId>()];
138            if hex::decode_to_slice(path, &mut tunnel_id).is_ok() {
139                return Ok(Self::new(tunnel_id));
140            }
141            error!("invalid new path: {path}");
142        } else if let Some(path) = path.strip_prefix(CABLE_CONNECT_PATH) {
143            let mut routing_id: RoutingId = [0; size_of::<RoutingId>()];
144            let mut tunnel_id: TunnelId = [0; size_of::<TunnelId>()];
145
146            let mut splitter = path.split('/');
147
148            if splitter
149                .next()
150                .and_then(|c| hex::decode_to_slice(c, &mut routing_id).ok())
151                .is_none()
152            {
153                error!("invalid routing_id in connect path: {path}");
154                return Err(());
155            }
156
157            if splitter
158                .next()
159                .and_then(|c| hex::decode_to_slice(c, &mut tunnel_id).ok())
160                .is_none()
161            {
162                error!("invalid tunnel_id in connect path: {path}");
163                return Err(());
164            }
165
166            if splitter.next().is_some() {
167                error!("unexpected extra token in connect path: {path}");
168                return Err(());
169            }
170
171            return Ok(Self::connect(routing_id, tunnel_id));
172        } else {
173            error!("unknown path: {path}")
174        }
175
176        Err(())
177    }
178}
179
180/// HTTP router for a caBLE WebSocket tunnel server.
181#[derive(Debug)]
182pub enum Router {
183    /// The web server should handle the request as a caBLE WebSocket
184    /// connection.
185    Websocket(Response<Full<Bytes>>, CablePath),
186    /// The web server should return a static response. This may be an error
187    /// message.
188    Static(Response<Full<Bytes>>),
189
190    Debug,
191}
192
193impl Router {
194    /// Routes an incoming HTTP request.
195    pub fn route(req: &Request<()>, origin: Option<&str>) -> Self {
196        if req.method() != Method::GET {
197            error!("method {} not allowed", req.method());
198            let response = Response::builder()
199                .status(StatusCode::METHOD_NOT_ALLOWED)
200                .header("Allow", "GET")
201                .body(Default::default())
202                .unwrap();
203            return Self::Static(response);
204        }
205
206        let path = match req.uri().path() {
207            "/" => {
208                return Self::Static(
209                    Response::builder()
210                        .status(StatusCode::OK)
211                        .header(CONTENT_TYPE, "text/html")
212                        .body(Bytes::from(INDEX).into())
213                        .unwrap(),
214                )
215            }
216            "/favicon.ico" => {
217                return Self::Static(
218                    Response::builder()
219                        .status(StatusCode::OK)
220                        .header(CONTENT_TYPE, "image/vnd.microsoft.icon")
221                        .body(Bytes::from(FAVICON).into())
222                        .unwrap(),
223                );
224            }
225            "/debug" => return Self::Debug,
226            path => match CablePath::try_from(path) {
227                Err(()) => {
228                    return Self::Static(empty_response(StatusCode::NOT_FOUND));
229                }
230                Ok(p) => p,
231            },
232        };
233
234        let mut res = match create_response(req) {
235            Ok(r) => r,
236            Err(e) => {
237                error!("bad request for WebSocket: {e}");
238                return Self::Static(empty_response(StatusCode::BAD_REQUEST));
239            }
240        };
241
242        // At this point, we have something that looks like a WebSocket on the
243        // other side. We should check the parameters selected etc.
244        if !req
245            .headers()
246            .get(SEC_WEBSOCKET_PROTOCOL)
247            .map(|v| v == CABLE_PROTOCOL)
248            .unwrap_or_default()
249        {
250            error!("unsupported or missing WebSocket protocol");
251            return Self::Static(empty_response(StatusCode::BAD_REQUEST));
252        }
253
254        // Check the origin header
255        if let Some(origin) = origin {
256            if !req
257                .headers()
258                .get(ORIGIN)
259                .and_then(|v| v.to_str().ok())
260                .and_then(|v| v.parse::<Uri>().ok())
261                .map(|v| {
262                    v.host()
263                        .map(|o| o.eq_ignore_ascii_case(origin))
264                        .unwrap_or_default()
265                })
266                .unwrap_or_default()
267            {
268                error!("incorrect or missing Origin header");
269                return Self::Static(empty_response(StatusCode::FORBIDDEN));
270            }
271        }
272
273        // We have the correct protocol, include in the response
274        res.headers_mut()
275            .append(SEC_WEBSOCKET_PROTOCOL, CABLE_PROTOCOL.to_owned());
276        let res = res.map(|_| Default::default());
277
278        Router::Websocket(res, path)
279    }
280
281    #[cfg(test)]
282    pub(self) fn static_response(self) -> Option<Response<Full<Bytes>>> {
283        match self {
284            Self::Static(r) => Some(r),
285            _ => None,
286        }
287    }
288}
289
290/// Make a [Response] with a given [StatusCode] and empty body.
291pub fn empty_response<E: Into<StatusCode>, T: Default>(status: E) -> Response<T> {
292    Response::builder()
293        .status(status)
294        .body(Default::default())
295        .unwrap()
296}
297
298/// Create a copy of an existing HTTP [Request], discarding the body.
299pub fn copy_request_empty_body<T>(r: &Request<T>) -> Request<Empty<Bytes>> {
300    let mut o = Request::builder().method(r.method()).uri(r.uri());
301    {
302        let headers = o.headers_mut().unwrap();
303        headers.extend(r.headers().to_owned());
304    }
305
306    o.body(Default::default()).unwrap()
307}
308
309/// Create a copy of an existing HTTP [Response], discarding the body.
310pub fn copy_response_empty_body<T>(r: &Response<T>) -> Response<Empty<Bytes>> {
311    let mut o = Response::builder().status(r.status());
312    {
313        let headers = o.headers_mut().unwrap();
314        headers.extend(r.headers().to_owned());
315    }
316
317    o.body(Default::default()).unwrap()
318}
319
/// Run a HTTP server for the caBLE WebSocket tunnel.
///
/// Binds a TCP listener to `bind_address`, then accepts connections forever.
/// Each connection is served on its own spawned task by hyper's HTTP/1 server
/// with upgrades enabled.
///
/// * `tls_acceptor`: when `Some`, a TLS handshake is performed on every
///   accepted TCP stream before HTTP is served; otherwise connections are
///   served as plain TCP.
/// * `server_state`: moved into an [`Arc`] shared by all requests.
/// * `request_handler`: called for every request with the shared state, the
///   peer's socket address, and the request.
///
/// # Errors
///
/// Returns an error only if binding the listener fails. After that the loop
/// never exits. Failures to accept a connection, to complete a TLS handshake,
/// or to serve a connection are logged and do not stop the server.
pub async fn run_server<F, R, ResBody, T>(
    bind_address: SocketAddr,
    tls_acceptor: Option<TlsAcceptor>,
    server_state: T,
    mut request_handler: F,
) -> Result<(), Box<dyn StdError>>
where
    F: FnMut(Arc<T>, SocketAddr, Request<Incoming>) -> R + Copy + Send + Sync + 'static,
    R: Future<Output = Result<Response<ResBody>, Infallible>> + Send,
    ResBody: Body + Send + 'static,
    <ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
    <ResBody as Body>::Data: Send,
    T: Send + Sync + 'static,
{
    // Wrapped once so that every connection task can share the state.
    let server_state = Arc::new(server_state);
    let tcp = TcpListener::bind(&bind_address).await?;
    let tls_acceptor = tls_acceptor.map(Arc::new);

    loop {
        let (stream, remote_addr) = match tcp.accept().await {
            Ok(o) => o,
            Err(e) => {
                // Accept errors are logged and skipped rather than ending the
                // server. NOTE(review): there is no back-off here, so a
                // persistent error (e.g. file descriptor exhaustion) would make
                // this loop retry and log continuously.
                error!("tcp.accept: {e}");
                continue;
            }
        };
        let server_state = server_state.clone();
        // `request_handler` is `Copy`, so this connection's service captures
        // its own copy of it, along with this connection's peer address.
        let service =
            service_fn(move |req| request_handler(server_state.clone(), remote_addr, req));
        let tls_acceptor = tls_acceptor.clone();

        // The connection task runs inside this span, tagged with the peer
        // address.
        let span = info_span!("handle_connection", addr = remote_addr.to_string());
        tokio::task::spawn(
            async move {
                let stream = match tls_acceptor {
                    None => MaybeTlsStream::Plain(stream),
                    Some(tls_acceptor) => match tls_acceptor.accept(stream).await {
                        Ok(o) => MaybeTlsStream::NativeTls(o),
                        Err(e) => {
                            // Drop only this connection; the server keeps running.
                            error!("tls_acceptor.accept: {e}");
                            return;
                        }
                    },
                };

                let conn =
                    hyper::server::conn::http1::Builder::new().serve_connection(stream, service);
                // Enables HTTP upgrades, so a handler's upgrade response (e.g.
                // a WebSocket handshake) can take over the connection.
                let conn = conn.with_upgrades();

                if let Err(e) = conn.await {
                    error!("connection error: {e}");
                }
            }
            .instrument(span),
        );
    }
}
378
379/// Sets up logging for cable-tunnel-server binaries.
380pub fn setup_logging() {
381    tracing_subscriber::fmt()
382        .with_env_filter(
383            EnvFilter::builder()
384                .with_default_directive(LevelFilter::INFO.into())
385                .from_env_lossy(),
386        )
387        .with_span_events(FmtSpan::CLOSE | FmtSpan::NEW)
388        .with_thread_ids(true)
389        // .with_file(true)
390        // .with_line_number(true)
391        .compact()
392        .init();
393}
394
#[cfg(test)]
mod tests {
    use http_body_util::BodyExt;
    use tungstenite::client::IntoClientRequest;

    use super::*;

    #[test]
    fn parse_urls() {
        // Parse valid paths in upper case
        assert_eq!(
            CablePath::new(*b"hello, webauthn!"),
            CablePath::try_from("/cable/new/68656C6C6F2C20776562617574686E21").unwrap()
        );
        assert_eq!(
            CablePath::connect(*b"abc", *b"hello, webauthn!"),
            CablePath::try_from("/cable/connect/616263/68656C6C6F2C20776562617574686E21").unwrap()
        );

        // Converting to string should always return upper-case paths
        assert_eq!(
            "/cable/new/68656C6C6F2C20776562617574686E21",
            CablePath::new(*b"hello, webauthn!").to_string(),
        );
        assert_eq!(
            "/cable/connect/616263/68656C6C6F2C20776562617574686E21",
            CablePath::connect(*b"abc", *b"hello, webauthn!").to_string(),
        );

        // Parse valid paths in lower case
        assert_eq!(
            CablePath::new(*b"hello, webauthn!"),
            CablePath::try_from("/cable/new/68656c6c6f2c20776562617574686e21").unwrap()
        );
        assert_eq!(
            CablePath::connect(*b"abc", *b"hello, webauthn!"),
            CablePath::try_from("/cable/connect/616263/68656c6c6f2c20776562617574686e21").unwrap()
        );

        // Parsing lower-case paths should return strings in upper case.
        assert_eq!(
            "/cable/new/68656C6C6F2C20776562617574686E21",
            CablePath::try_from("/cable/new/68656c6c6f2c20776562617574686e21")
                .unwrap()
                .to_string()
        );
        assert_eq!(
            "/cable/connect/616263/68656C6C6F2C20776562617574686E21",
            CablePath::try_from("/cable/connect/616263/68656c6c6f2c20776562617574686e21")
                .unwrap()
                .to_string()
        );

        // Invalid paths
        assert!(CablePath::try_from("/").is_err());

        assert!(CablePath::try_from("/cable/new/").is_err());
        assert!(CablePath::try_from("/cable/new/not_hex_digits_here_but_32_chars").is_err());
        assert!(CablePath::try_from("/cable/new/C0FFEE").is_err());
        assert!(CablePath::try_from("/cable/new/C0FFEEC0FFEEC0FFEEC0FFEEC0FFEEC0FFEE").is_err());
        assert!(CablePath::try_from("/cable/new/68656C6C6F2C20776562617574686E21/").is_err());
        assert!(CablePath::try_from("/cable/new/../new/68656C6C6F2C20776562617574686E21").is_err());

        assert!(
            CablePath::try_from("/cable/connect/C0FFEE/not_hex_digits_here_but_32_chars").is_err()
        );
        assert!(CablePath::try_from("/cable/connect/C0FFEE/COFFEE").is_err());
        assert!(CablePath::try_from("/cable/connect/C0/FFEE").is_err());
        assert!(CablePath::try_from("/cable/connect/C0/68656C6C6F2C20776562617574686E21").is_err());
        assert!(
            CablePath::try_from("/cable/connect/C0F/68656C6C6F2C20776562617574686E21").is_err()
        );
        assert!(
            CablePath::try_from("/cable/connect/C0FFEECO/68656C6C6F2C20776562617574686E21")
                .is_err()
        );
        assert!(
            CablePath::try_from("/cable/connect/C0FFEE/68656C6C6F2C20776562617574686E21/1234")
                .is_err()
        );
        assert!(
            CablePath::try_from("/cable/connect/C0FFEE/68656C6C6F2C20776562617574686E21/").is_err()
        );

        // other nonsense
        assert!(CablePath::try_from("cable/new/68656C6C6F2C20776562617574686E21").is_err());
        assert!(CablePath::try_from("../cable/new/68656C6C6F2C20776562617574686E21").is_err());
        assert!(CablePath::try_from("/../cable/new/68656C6C6F2C20776562617574686E21").is_err());
        assert!(CablePath::try_from("../../../etc/passwd").is_err());

        // Should be rejected by length limits
        assert!(CablePath::try_from(include_str!("lib.rs")).is_err());
    }

    async fn check_index_response(response: Response<Full<Bytes>>) {
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!(
            "text/html",
            HeaderValue::to_str(response.headers().get(CONTENT_TYPE).unwrap()).unwrap()
        );
        assert!(response.body().size_hint().exact().unwrap() > 16);
        assert_eq!(
            INDEX,
            response
                .into_body()
                .frame()
                .await
                .unwrap()
                .unwrap()
                .into_data()
                .unwrap()
        );
    }

    async fn check_favicon_response(response: Response<Full<Bytes>>) {
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!(
            "image/vnd.microsoft.icon",
            HeaderValue::to_str(response.headers().get(CONTENT_TYPE).unwrap()).unwrap()
        );
        assert!(response.body().size_hint().exact().unwrap() > 16);
        assert_eq!(
            FAVICON,
            response
                .into_body()
                .frame()
                .await
                .unwrap()
                .unwrap()
                .into_data()
                .unwrap()
        );
    }

    fn check_error_response(response: Response<Full<Bytes>>, expected_status: StatusCode) {
        assert_eq!(expected_status, response.status());
        assert_eq!(0, response.body().size_hint().exact().unwrap());
    }

    /// Asserts that `response` is a caBLE WebSocket handshake response.
    ///
    /// Always returns `true` (it panics otherwise), so it can be used inside a
    /// `matches!` guard.
    fn check_websocket_response(response: &Response<Full<Bytes>>) -> bool {
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, response.status());
        assert_eq!(0, response.body().size_hint().exact().unwrap());
        assert_eq!(
            "fido.cable",
            HeaderValue::to_str(response.headers().get(SEC_WEBSOCKET_PROTOCOL).unwrap()).unwrap()
        );
        true
    }

    const TEST_ORIGINS: [Option<&str>; 3] =
        [None, Some("cable.example.com"), Some("cable.example.net")];

    #[tokio::test]
    async fn static_router() -> Result<(), Box<dyn StdError>> {
        // Index handler, without origin
        let request = Request::get("https://cable.example.com/").body(())?;

        for router_origin in TEST_ORIGINS {
            check_index_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }

        // Index handler, with origin
        let request = Request::get("https://cable.example.com/")
            .header(ORIGIN, "cable.example.com")
            .body(())?;

        for router_origin in TEST_ORIGINS {
            check_index_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }

        // Favicon handler, without origin
        let request = Request::get("https://cable.example.com/favicon.ico").body(())?;

        for router_origin in TEST_ORIGINS {
            check_favicon_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }

        // Favicon handler, with origin
        let request = Request::get("https://cable.example.com/favicon.ico")
            .header(ORIGIN, "cable.example.com")
            .body(())?;

        for router_origin in TEST_ORIGINS {
            check_favicon_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }
        Ok(())
    }

    #[test]
    fn debug_router() -> Result<(), Box<dyn StdError>> {
        // Debug handler, without origin
        let request = Request::get("https://cable.example.com/debug").body(())?;

        for router_origin in TEST_ORIGINS {
            assert!(matches!(
                Router::route(&request, router_origin),
                Router::Debug
            ));
        }

        // Debug handler, with origin
        let request = Request::get("https://cable.example.com/debug")
            .header(ORIGIN, "cable.example.com")
            .body(())?;

        for router_origin in TEST_ORIGINS {
            assert!(matches!(
                Router::route(&request, router_origin),
                Router::Debug
            ));
        }

        Ok(())
    }

    #[test]
    fn websocket_router() -> Result<(), Box<dyn StdError>> {
        let expected_path = CablePath::new(*b"hello, webauthn!");

        // Make WebSocket request missing caBLE headers
        let request = "wss://cable.example.com/cable/new/68656C6C6F2C20776562617574686E21"
            .into_client_request()?;
        check_error_response(
            Router::route(&request, None)
                .static_response()
                .ok_or("expected static response")?,
            StatusCode::BAD_REQUEST,
        );

        // With caBLE headers, but no Origin
        let mut request = "wss://cable.example.com/cable/new/68656C6C6F2C20776562617574686E21"
            .into_client_request()?;
        request
            .headers_mut()
            .insert(SEC_WEBSOCKET_PROTOCOL, CABLE_PROTOCOL.to_owned());
        // `matches!` only evaluates to a bool; it must be asserted to check anything.
        assert!(matches!(
            Router::route(&request, None),
            Router::Websocket(r, p) if p == expected_path && check_websocket_response(&r)
        ));
        check_error_response(
            Router::route(&request, Some("cable.example.com"))
                .static_response()
                .ok_or("expected static response")?,
            StatusCode::FORBIDDEN,
        );

        // With caBLE headers and Origin
        let mut request = "wss://cable.example.com/cable/new/68656C6C6F2C20776562617574686E21"
            .into_client_request()?;
        let headers = request.headers_mut();
        headers.insert(SEC_WEBSOCKET_PROTOCOL, CABLE_PROTOCOL.to_owned());
        headers.insert(ORIGIN, HeaderValue::from_static("cable.example.com"));

        for router_origin in [None, Some("cable.example.com")] {
            assert!(matches!(
                Router::route(&request, router_origin),
                Router::Websocket(r, p) if p == expected_path && check_websocket_response(&r)
            ));
        }

        // Origin header present but not matching the router's expected origin
        check_error_response(
            Router::route(&request, Some("cable.example.net"))
                .static_response()
                .ok_or("expected static response")?,
            StatusCode::FORBIDDEN,
        );

        Ok(())
    }
}