1use std::{
7 convert::Infallible, error::Error as StdError, fmt::Display, future::Future, mem::size_of,
8 net::SocketAddr, num::ParseIntError, sync::Arc, time::Duration,
9};
10
11use hex::{FromHex, FromHexError};
12use http_body::Body;
13use http_body_util::{Empty, Full};
14use hyper::{
15 body::{Bytes, Incoming},
16 header::{CONTENT_TYPE, ORIGIN, SEC_WEBSOCKET_PROTOCOL},
17 http::HeaderValue,
18 service::service_fn,
19 HeaderMap, Method, Request, Response, StatusCode, Uri,
20};
21use tokio::net::TcpListener;
22use tokio_native_tls::TlsAcceptor;
23use tokio_tungstenite::MaybeTlsStream;
24use tracing::Instrument;
25use tracing_subscriber::{filter::LevelFilter, fmt::format::FmtSpan, EnvFilter};
26use tungstenite::handshake::server::create_response;
27
28#[macro_use]
29extern crate tracing;
30
31mod tls;
32pub use tls::*;
33
/// Routing ID: the first hex-encoded component of a `connect` path (3 bytes).
pub type RoutingId = [u8; 3];
/// Tunnel ID: the last hex-encoded component of both `new` and `connect`
/// paths (16 bytes).
pub type TunnelId = [u8; 16];

/// WebSocket subprotocol a client must request; it is also set on the
/// handshake response returned by [`Router::route`].
pub static CABLE_PROTOCOL: HeaderValue = HeaderValue::from_static("fido.cable");
/// Header written by [`CablePath::insert_routing_id_header`], carrying the
/// routing ID as upper-case hex.
pub const CABLE_ROUTING_ID_HEADER: &str = "X-caBLE-Routing-ID";

/// Path prefix for `new` requests: `/cable/new/<TUNNEL_ID>`.
pub const CABLE_NEW_PATH: &str = "/cable/new/";
/// Path prefix for `connect` requests:
/// `/cable/connect/<ROUTING_ID>/<TUNNEL_ID>`.
pub const CABLE_CONNECT_PATH: &str = "/cable/connect/";

// NOTE(review): the longest well-formed path (prefix + both IDs + one `/`
// separator) is one byte shorter than this limit; the purpose of the second
// spare byte is not evident from this file — confirm.
/// Longest request path that [`CablePath::try_from`] will even try to parse.
///
/// Computed as the `connect` prefix, plus both IDs hex-encoded (two
/// characters per byte), plus 2 bytes of headroom.
pub const MAX_URL_LENGTH: usize =
    CABLE_CONNECT_PATH.len() + ((size_of::<RoutingId>() + size_of::<TunnelId>()) * 2) + 2;

/// Icon served at `/favicon.ico` with `Content-Type: image/vnd.microsoft.icon`.
const FAVICON: &[u8] = include_bytes!("favicon.ico");
/// HTML page served at `/` with `Content-Type: text/html`.
const INDEX: &str = include_str!("index.html");
48
49pub fn parse_hex<T>(i: &str) -> Result<T, FromHexError>
53where
54 T: FromHex<Error = FromHexError>,
55{
56 FromHex::from_hex(i)
57}
58
/// Parses `i` as a whole, non-negative number of seconds.
///
/// # Errors
///
/// Returns [`ParseIntError`] if `i` is not a valid `u64`.
pub fn parse_duration_secs(i: &str) -> Result<Duration, ParseIntError> {
    let secs: u64 = i.parse()?;
    Ok(Duration::from_secs(secs))
}
65
/// A parsed caBLE tunnel request path.
///
/// Built from a request path with [`CablePath::try_from`] and rendered back
/// to a path string (upper-case hex) by its [`Display`] impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CablePath {
    /// Which path form this is (`new` or `connect`).
    pub method: CableMethod,
    /// Routing ID from a `connect` path. [`CablePath::new`] sets it to all
    /// zeroes because `new` paths do not carry one.
    pub routing_id: RoutingId,
    /// Tunnel ID; present in both path forms.
    pub tunnel_id: TunnelId,
}
76
/// The path form a [`CablePath`] was parsed from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CableMethod {
    /// `/cable/new/<TUNNEL_ID>`.
    New,
    /// `/cable/connect/<ROUTING_ID>/<TUNNEL_ID>`.
    Connect,
}
86
87impl CablePath {
88 pub fn new(tunnel_id: TunnelId) -> Self {
89 Self {
90 method: CableMethod::New,
91 routing_id: [0; size_of::<RoutingId>()],
92 tunnel_id,
93 }
94 }
95
96 pub fn connect(routing_id: RoutingId, tunnel_id: TunnelId) -> Self {
97 Self {
98 method: CableMethod::Connect,
99 routing_id,
100 tunnel_id,
101 }
102 }
103
104 pub fn insert_routing_id_header(&self, headers: &mut HeaderMap) {
106 headers.insert(
107 CABLE_ROUTING_ID_HEADER,
108 HeaderValue::from_str(&hex::encode_upper(self.routing_id)).unwrap(),
109 );
110 }
111}
112
113impl Display for CablePath {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 match self.method {
116 CableMethod::New => {
117 write!(f, "{}{}", CABLE_NEW_PATH, hex::encode_upper(self.tunnel_id))
118 }
119 CableMethod::Connect => write!(
120 f,
121 "{}{}/{}",
122 CABLE_CONNECT_PATH,
123 hex::encode_upper(self.routing_id),
124 hex::encode_upper(self.tunnel_id)
125 ),
126 }
127 }
128}
129
130impl TryFrom<&str> for CablePath {
131 type Error = ();
132 fn try_from(path: &str) -> Result<Self, Self::Error> {
133 if path.len() > MAX_URL_LENGTH {
134 error!("path too long: {} > {MAX_URL_LENGTH} bytes", path.len());
135 return Err(());
136 } else if let Some(path) = path.strip_prefix(CABLE_NEW_PATH) {
137 let mut tunnel_id: TunnelId = [0; size_of::<TunnelId>()];
138 if hex::decode_to_slice(path, &mut tunnel_id).is_ok() {
139 return Ok(Self::new(tunnel_id));
140 }
141 error!("invalid new path: {path}");
142 } else if let Some(path) = path.strip_prefix(CABLE_CONNECT_PATH) {
143 let mut routing_id: RoutingId = [0; size_of::<RoutingId>()];
144 let mut tunnel_id: TunnelId = [0; size_of::<TunnelId>()];
145
146 let mut splitter = path.split('/');
147
148 if splitter
149 .next()
150 .and_then(|c| hex::decode_to_slice(c, &mut routing_id).ok())
151 .is_none()
152 {
153 error!("invalid routing_id in connect path: {path}");
154 return Err(());
155 }
156
157 if splitter
158 .next()
159 .and_then(|c| hex::decode_to_slice(c, &mut tunnel_id).ok())
160 .is_none()
161 {
162 error!("invalid tunnel_id in connect path: {path}");
163 return Err(());
164 }
165
166 if splitter.next().is_some() {
167 error!("unexpected extra token in connect path: {path}");
168 return Err(());
169 }
170
171 return Ok(Self::connect(routing_id, tunnel_id));
172 } else {
173 error!("unknown path: {path}")
174 }
175
176 Err(())
177 }
178}
179
/// Outcome of [`Router::route`] for one HTTP request.
#[derive(Debug)]
pub enum Router {
    /// The request is an acceptable caBLE WebSocket handshake.
    ///
    /// Holds the handshake response to send back (with the `fido.cable`
    /// subprotocol set and an empty body) and the parsed tunnel path.
    Websocket(Response<Full<Bytes>>, CablePath),

    /// A finished response to send as-is: an embedded static asset, or an
    /// error status with an empty body.
    Static(Response<Full<Bytes>>),

    /// The request was for `/debug`; the caller decides how to respond.
    Debug,
}
192
193impl Router {
194 pub fn route(req: &Request<()>, origin: Option<&str>) -> Self {
196 if req.method() != Method::GET {
197 error!("method {} not allowed", req.method());
198 let response = Response::builder()
199 .status(StatusCode::METHOD_NOT_ALLOWED)
200 .header("Allow", "GET")
201 .body(Default::default())
202 .unwrap();
203 return Self::Static(response);
204 }
205
206 let path = match req.uri().path() {
207 "/" => {
208 return Self::Static(
209 Response::builder()
210 .status(StatusCode::OK)
211 .header(CONTENT_TYPE, "text/html")
212 .body(Bytes::from(INDEX).into())
213 .unwrap(),
214 )
215 }
216 "/favicon.ico" => {
217 return Self::Static(
218 Response::builder()
219 .status(StatusCode::OK)
220 .header(CONTENT_TYPE, "image/vnd.microsoft.icon")
221 .body(Bytes::from(FAVICON).into())
222 .unwrap(),
223 );
224 }
225 "/debug" => return Self::Debug,
226 path => match CablePath::try_from(path) {
227 Err(()) => {
228 return Self::Static(empty_response(StatusCode::NOT_FOUND));
229 }
230 Ok(p) => p,
231 },
232 };
233
234 let mut res = match create_response(req) {
235 Ok(r) => r,
236 Err(e) => {
237 error!("bad request for WebSocket: {e}");
238 return Self::Static(empty_response(StatusCode::BAD_REQUEST));
239 }
240 };
241
242 if !req
245 .headers()
246 .get(SEC_WEBSOCKET_PROTOCOL)
247 .map(|v| v == CABLE_PROTOCOL)
248 .unwrap_or_default()
249 {
250 error!("unsupported or missing WebSocket protocol");
251 return Self::Static(empty_response(StatusCode::BAD_REQUEST));
252 }
253
254 if let Some(origin) = origin {
256 if !req
257 .headers()
258 .get(ORIGIN)
259 .and_then(|v| v.to_str().ok())
260 .and_then(|v| v.parse::<Uri>().ok())
261 .map(|v| {
262 v.host()
263 .map(|o| o.eq_ignore_ascii_case(origin))
264 .unwrap_or_default()
265 })
266 .unwrap_or_default()
267 {
268 error!("incorrect or missing Origin header");
269 return Self::Static(empty_response(StatusCode::FORBIDDEN));
270 }
271 }
272
273 res.headers_mut()
275 .append(SEC_WEBSOCKET_PROTOCOL, CABLE_PROTOCOL.to_owned());
276 let res = res.map(|_| Default::default());
277
278 Router::Websocket(res, path)
279 }
280
281 #[cfg(test)]
282 pub(self) fn static_response(self) -> Option<Response<Full<Bytes>>> {
283 match self {
284 Self::Static(r) => Some(r),
285 _ => None,
286 }
287 }
288}
289
290pub fn empty_response<E: Into<StatusCode>, T: Default>(status: E) -> Response<T> {
292 Response::builder()
293 .status(status)
294 .body(Default::default())
295 .unwrap()
296}
297
298pub fn copy_request_empty_body<T>(r: &Request<T>) -> Request<Empty<Bytes>> {
300 let mut o = Request::builder().method(r.method()).uri(r.uri());
301 {
302 let headers = o.headers_mut().unwrap();
303 headers.extend(r.headers().to_owned());
304 }
305
306 o.body(Default::default()).unwrap()
307}
308
309pub fn copy_response_empty_body<T>(r: &Response<T>) -> Response<Empty<Bytes>> {
311 let mut o = Response::builder().status(r.status());
312 {
313 let headers = o.headers_mut().unwrap();
314 headers.extend(r.headers().to_owned());
315 }
316
317 o.body(Default::default()).unwrap()
318}
319
320pub async fn run_server<F, R, ResBody, T>(
322 bind_address: SocketAddr,
323 tls_acceptor: Option<TlsAcceptor>,
324 server_state: T,
325 mut request_handler: F,
326) -> Result<(), Box<dyn StdError>>
327where
328 F: FnMut(Arc<T>, SocketAddr, Request<Incoming>) -> R + Copy + Send + Sync + 'static,
329 R: Future<Output = Result<Response<ResBody>, Infallible>> + Send,
330 ResBody: Body + Send + 'static,
331 <ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
332 <ResBody as Body>::Data: Send,
333 T: Send + Sync + 'static,
334{
335 let server_state = Arc::new(server_state);
336 let tcp = TcpListener::bind(&bind_address).await?;
337 let tls_acceptor = tls_acceptor.map(Arc::new);
338
339 loop {
340 let (stream, remote_addr) = match tcp.accept().await {
341 Ok(o) => o,
342 Err(e) => {
343 error!("tcp.accept: {e}");
344 continue;
345 }
346 };
347 let server_state = server_state.clone();
348 let service =
349 service_fn(move |req| request_handler(server_state.clone(), remote_addr, req));
350 let tls_acceptor = tls_acceptor.clone();
351
352 let span = info_span!("handle_connection", addr = remote_addr.to_string());
353 tokio::task::spawn(
354 async move {
355 let stream = match tls_acceptor {
356 None => MaybeTlsStream::Plain(stream),
357 Some(tls_acceptor) => match tls_acceptor.accept(stream).await {
358 Ok(o) => MaybeTlsStream::NativeTls(o),
359 Err(e) => {
360 error!("tls_acceptor.accept: {e}");
361 return;
362 }
363 },
364 };
365
366 let conn =
367 hyper::server::conn::http1::Builder::new().serve_connection(stream, service);
368 let conn = conn.with_upgrades();
369
370 if let Err(e) = conn.await {
371 error!("connection error: {e}");
372 }
373 }
374 .instrument(span),
375 );
376 }
377}
378
379pub fn setup_logging() {
381 tracing_subscriber::fmt()
382 .with_env_filter(
383 EnvFilter::builder()
384 .with_default_directive(LevelFilter::INFO.into())
385 .from_env_lossy(),
386 )
387 .with_span_events(FmtSpan::CLOSE | FmtSpan::NEW)
388 .with_thread_ids(true)
389 .compact()
392 .init();
393}
394
#[cfg(test)]
mod tests {
    use http_body_util::BodyExt;
    use tungstenite::client::IntoClientRequest;

    use super::*;

    #[test]
    fn parse_urls() {
        // Upper-case hex round-trips in both directions.
        assert_eq!(
            CablePath::new(*b"hello, webauthn!"),
            CablePath::try_from("/cable/new/68656C6C6F2C20776562617574686E21").unwrap()
        );
        assert_eq!(
            CablePath::connect(*b"abc", *b"hello, webauthn!"),
            CablePath::try_from("/cable/connect/616263/68656C6C6F2C20776562617574686E21").unwrap()
        );

        assert_eq!(
            "/cable/new/68656C6C6F2C20776562617574686E21",
            CablePath::new(*b"hello, webauthn!").to_string(),
        );
        assert_eq!(
            "/cable/connect/616263/68656C6C6F2C20776562617574686E21",
            CablePath::connect(*b"abc", *b"hello, webauthn!").to_string(),
        );

        // Lower-case hex is accepted on input...
        assert_eq!(
            CablePath::new(*b"hello, webauthn!"),
            CablePath::try_from("/cable/new/68656c6c6f2c20776562617574686e21").unwrap()
        );
        assert_eq!(
            CablePath::connect(*b"abc", *b"hello, webauthn!"),
            CablePath::try_from("/cable/connect/616263/68656c6c6f2c20776562617574686e21").unwrap()
        );

        // ...and normalised to upper-case on output.
        assert_eq!(
            "/cable/new/68656C6C6F2C20776562617574686E21",
            CablePath::try_from("/cable/new/68656c6c6f2c20776562617574686e21")
                .unwrap()
                .to_string()
        );
        assert_eq!(
            "/cable/connect/616263/68656C6C6F2C20776562617574686E21",
            CablePath::try_from("/cable/connect/616263/68656c6c6f2c20776562617574686e21")
                .unwrap()
                .to_string()
        );

        assert!(CablePath::try_from("/").is_err());

        // Malformed `new` paths.
        assert!(CablePath::try_from("/cable/new/").is_err());
        assert!(CablePath::try_from("/cable/new/not_hex_digits_here_but_32_chars").is_err());
        assert!(CablePath::try_from("/cable/new/C0FFEE").is_err());
        assert!(CablePath::try_from("/cable/new/C0FFEEC0FFEEC0FFEEC0FFEEC0FFEEC0FFEE").is_err());
        assert!(CablePath::try_from("/cable/new/68656C6C6F2C20776562617574686E21/").is_err());
        assert!(CablePath::try_from("/cable/new/../new/68656C6C6F2C20776562617574686E21").is_err());

        // Malformed `connect` paths.
        assert!(
            CablePath::try_from("/cable/connect/C0FFEE/not_hex_digits_here_but_32_chars").is_err()
        );
        assert!(CablePath::try_from("/cable/connect/C0FFEE/COFFEE").is_err());
        assert!(CablePath::try_from("/cable/connect/C0/FFEE").is_err());
        assert!(CablePath::try_from("/cable/connect/C0/68656C6C6F2C20776562617574686E21").is_err());
        assert!(
            CablePath::try_from("/cable/connect/C0F/68656C6C6F2C20776562617574686E21").is_err()
        );
        assert!(
            CablePath::try_from("/cable/connect/C0FFEECO/68656C6C6F2C20776562617574686E21")
                .is_err()
        );
        assert!(
            CablePath::try_from("/cable/connect/C0FFEE/68656C6C6F2C20776562617574686E21/1234")
                .is_err()
        );
        assert!(
            CablePath::try_from("/cable/connect/C0FFEE/68656C6C6F2C20776562617574686E21/").is_err()
        );

        // Relative and traversal paths.
        assert!(CablePath::try_from("cable/new/68656C6C6F2C20776562617574686E21").is_err());
        assert!(CablePath::try_from("../cable/new/68656C6C6F2C20776562617574686E21").is_err());
        assert!(CablePath::try_from("/../cable/new/68656C6C6F2C20776562617574686E21").is_err());
        assert!(CablePath::try_from("../../../etc/passwd").is_err());

        // Very long input.
        assert!(CablePath::try_from(include_str!("lib.rs")).is_err());
    }

    async fn check_index_response(response: Response<Full<Bytes>>) {
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!(
            "text/html",
            HeaderValue::to_str(response.headers().get(CONTENT_TYPE).unwrap()).unwrap()
        );
        assert!(response.body().size_hint().exact().unwrap() > 16);
        assert_eq!(
            INDEX,
            response
                .into_body()
                .frame()
                .await
                .unwrap()
                .unwrap()
                .into_data()
                .unwrap()
        );
    }

    async fn check_favicon_response(response: Response<Full<Bytes>>) {
        assert_eq!(StatusCode::OK, response.status());
        assert_eq!(
            "image/vnd.microsoft.icon",
            HeaderValue::to_str(response.headers().get(CONTENT_TYPE).unwrap()).unwrap()
        );
        assert!(response.body().size_hint().exact().unwrap() > 16);
        assert_eq!(
            FAVICON,
            response
                .into_body()
                .frame()
                .await
                .unwrap()
                .unwrap()
                .into_data()
                .unwrap()
        );
    }

    fn check_error_response(response: Response<Full<Bytes>>, expected_status: StatusCode) {
        assert_eq!(expected_status, response.status());
        assert_eq!(0, response.body().size_hint().exact().unwrap());
    }

    /// Asserts `response` is a valid caBLE handshake; returns `true` so it can
    /// be used inside a `matches!` guard.
    fn check_websocket_response(response: &Response<Full<Bytes>>) -> bool {
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, response.status());
        assert_eq!(0, response.body().size_hint().exact().unwrap());
        assert_eq!(
            "fido.cable",
            HeaderValue::to_str(response.headers().get(SEC_WEBSOCKET_PROTOCOL).unwrap()).unwrap()
        );
        true
    }

    const TEST_ORIGINS: [Option<&str>; 3] =
        [None, Some("cable.example.com"), Some("cable.example.net")];

    #[tokio::test]
    async fn static_router() -> Result<(), Box<dyn StdError>> {
        // Static assets are served regardless of origin policy.
        let request = Request::get("https://cable.example.com/").body(())?;

        for router_origin in TEST_ORIGINS {
            check_index_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }

        let request = Request::get("https://cable.example.com/")
            .header(ORIGIN, "cable.example.com")
            .body(())?;

        for router_origin in TEST_ORIGINS {
            check_index_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }

        let request = Request::get("https://cable.example.com/favicon.ico").body(())?;

        for router_origin in TEST_ORIGINS {
            check_favicon_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }

        let request = Request::get("https://cable.example.com/favicon.ico")
            .header(ORIGIN, "cable.example.com")
            .body(())?;

        for router_origin in TEST_ORIGINS {
            check_favicon_response(
                Router::route(&request, router_origin)
                    .static_response()
                    .ok_or("expected static response")?,
            )
            .await;
        }
        Ok(())
    }

    #[test]
    fn debug_router() -> Result<(), Box<dyn StdError>> {
        let request = Request::get("https://cable.example.com/debug").body(())?;

        for router_origin in TEST_ORIGINS {
            assert!(matches!(
                Router::route(&request, router_origin),
                Router::Debug
            ));
        }

        let request = Request::get("https://cable.example.com/debug")
            .header(ORIGIN, "cable.example.com")
            .body(())?;

        for router_origin in TEST_ORIGINS {
            assert!(matches!(
                Router::route(&request, router_origin),
                Router::Debug
            ));
        }

        Ok(())
    }

    #[test]
    fn websocket_router() -> Result<(), Box<dyn StdError>> {
        const URL: &str = "wss://cable.example.com/cable/new/68656C6C6F2C20776562617574686E21";
        let expected_path = CablePath::new(*b"hello, webauthn!");

        // No subprotocol offered: rejected.
        let request = URL.into_client_request()?;
        check_error_response(
            Router::route(&request, None)
                .static_response()
                .ok_or("expected static response")?,
            StatusCode::BAD_REQUEST,
        );

        // Subprotocol offered, no Origin header: accepted only when the router
        // does not enforce an origin.
        let mut request = URL.into_client_request()?;
        request
            .headers_mut()
            .insert(SEC_WEBSOCKET_PROTOCOL, CABLE_PROTOCOL.to_owned());
        assert!(matches!(
            Router::route(&request, None),
            Router::Websocket(r, p) if p == expected_path && check_websocket_response(&r)
        ));
        check_error_response(
            Router::route(&request, Some("cable.example.com"))
                .static_response()
                .ok_or("expected static response")?,
            StatusCode::FORBIDDEN,
        );

        // Bare-host Origin header.
        let mut request = URL.into_client_request()?;
        let headers = request.headers_mut();
        headers.insert(SEC_WEBSOCKET_PROTOCOL, CABLE_PROTOCOL.to_owned());
        headers.insert(ORIGIN, HeaderValue::from_static("cable.example.com"));

        for router_origin in [None, Some("cable.example.com")] {
            assert!(matches!(
                Router::route(&request, router_origin),
                Router::Websocket(r, p) if p == expected_path && check_websocket_response(&r)
            ));
        }
        // The Origin header names a different host than the router expects.
        check_error_response(
            Router::route(&request, Some("cable.example.net"))
                .static_response()
                .ok_or("expected static response")?,
            StatusCode::FORBIDDEN,
        );

        // Fully serialised Origin; the host comparison ignores ASCII case.
        request.headers_mut().insert(
            ORIGIN,
            HeaderValue::from_static("https://CABLE.example.com"),
        );
        assert!(matches!(
            Router::route(&request, Some("cable.example.com")),
            Router::Websocket(r, p) if p == expected_path && check_websocket_response(&r)
        ));

        // A host that merely starts with the expected one must not match.
        request.headers_mut().insert(
            ORIGIN,
            HeaderValue::from_static("https://cable.example.com.example.net"),
        );
        check_error_response(
            Router::route(&request, Some("cable.example.com"))
                .static_response()
                .ok_or("expected static response")?,
            StatusCode::FORBIDDEN,
        );

        Ok(())
    }
}