sbd_server/
lib.rs

1//! Sbd server library.
2#![deny(missing_docs)]
3
/// Largest websocket message, in bytes, that the server will accept.
/// This limit is defined by the sbd spec.
const MAX_MSG_BYTES: i32 = 20 * 1_000;
6
7use std::collections::HashMap;
8use std::io::{Error, Result};
9use std::net::{IpAddr, Ipv6Addr};
10use std::sync::{Arc, Mutex};
11
12mod config;
13pub use config::*;
14
15mod maybe_tls;
16pub use maybe_tls::*;
17
18mod ip_deny;
19mod ip_rate;
20pub use ip_rate::*;
21
22mod cslot;
23pub use cslot::*;
24
25mod cmd;
26
27/// Websocket backend abstraction.
28pub mod ws {
29    /// Payload.
30    pub enum Payload {
31        /// Vec.
32        Vec(Vec<u8>),
33
34        /// BytesMut.
35        BytesMut(bytes::BytesMut),
36    }
37
38    impl std::ops::Deref for Payload {
39        type Target = [u8];
40
41        #[inline(always)]
42        fn deref(&self) -> &Self::Target {
43            match self {
44                Payload::Vec(v) => v.as_slice(),
45                Payload::BytesMut(b) => b.as_ref(),
46            }
47        }
48    }
49
50    impl Payload {
51        /// Mutable payload.
52        #[inline(always)]
53        pub fn to_mut(&mut self) -> &mut [u8] {
54            match self {
55                Payload::Vec(ref mut owned) => owned,
56                Payload::BytesMut(b) => b.as_mut(),
57            }
58        }
59    }
60
61    use futures::future::BoxFuture;
62
63    /// Websocket trait.
64    pub trait SbdWebsocket: Send + Sync + 'static {
65        /// Receive from the websocket.
66        fn recv(&self) -> BoxFuture<'static, std::io::Result<Payload>>;
67
68        /// Send to the websocket.
69        fn send(
70            &self,
71            payload: Payload,
72        ) -> BoxFuture<'static, std::io::Result<()>>;
73
74        /// Close the websocket.
75        fn close(&self) -> BoxFuture<'static, ()>;
76    }
77}
78
79pub use ws::{Payload, SbdWebsocket};
80
81/// Public key.
82#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
83pub struct PubKey(pub Arc<[u8; 32]>);
84
85impl PubKey {
86    /// Verify a signature with this pub key.
87    pub fn verify(&self, sig: &[u8; 64], data: &[u8]) -> bool {
88        use ed25519_dalek::Verifier;
89        if let Ok(k) = ed25519_dalek::VerifyingKey::from_bytes(&self.0) {
90            k.verify(data, &ed25519_dalek::Signature::from_bytes(sig))
91                .is_ok()
92        } else {
93            false
94        }
95    }
96}
97
98/// SbdServer.
99pub struct SbdServer {
100    task_list: Vec<tokio::task::JoinHandle<()>>,
101    bind_addrs: Vec<std::net::SocketAddr>,
102
103    // this should be the only non-weak cslot so the others are dropped
104    // if this top-level server instance is ever dropped.
105    _cslot: CSlot,
106}
107
108impl Drop for SbdServer {
109    fn drop(&mut self) {
110        for task in self.task_list.iter() {
111            task.abort();
112        }
113    }
114}
115
/// Convert an IP address to an IPv6 address.
///
/// IPv4 addresses become their IPv4-mapped IPv6 form (`::ffff:a.b.c.d`);
/// IPv6 addresses are returned unchanged.
pub fn to_canonical_ip(ip: IpAddr) -> Arc<Ipv6Addr> {
    let v6: Ipv6Addr = match ip {
        IpAddr::V6(v6) => v6,
        IpAddr::V4(v4) => v4.to_ipv6_mapped(),
    };
    Arc::new(v6)
}
123
124/// If the check passes, the canonical IP is returned, otherwise None and the connection should be
125/// dropped.
126pub async fn preflight_ip_check(
127    config: &Config,
128    ip_rate: &IpRate,
129    addr: std::net::SocketAddr,
130) -> Option<Arc<Ipv6Addr>> {
131    let raw_ip = to_canonical_ip(addr.ip());
132
133    let use_trusted_ip = config.trusted_ip_header.is_some();
134
135    if !use_trusted_ip {
136        // Do this check BEFORE handshake to avoid extra
137        // server process when capable.
138        // If we *are* behind a reverse proxy, we assume
139        // some amount of DDoS mitigation is happening there
140        // and thus we can accept a little more process overhead
141        if ip_rate.is_blocked(&raw_ip).await {
142            return None;
143        }
144
145        // Also precheck our rate limit, using up one byte
146        if !ip_rate.is_ok(&raw_ip, 1).await {
147            return None;
148        }
149    }
150
151    Some(raw_ip)
152}
153
154/// Handle an upgraded websocket connection.
155pub async fn handle_upgraded(
156    config: Arc<Config>,
157    ip_rate: Arc<IpRate>,
158    weak_cslot: WeakCSlot,
159    ws: Arc<impl SbdWebsocket>,
160    pub_key: PubKey,
161    calc_ip: Arc<Ipv6Addr>,
162    maybe_auth: Option<(Option<Arc<str>>, AuthTokenTracker)>,
163) {
164    let use_trusted_ip = config.trusted_ip_header.is_some();
165
166    // illegal pub key
167    if &pub_key.0[..28] == cmd::CMD_PREFIX {
168        return;
169    }
170
171    if use_trusted_ip {
172        // if using a trusted ip, check block here.
173        // see note above before the handshakes.
174        if ip_rate.is_blocked(&calc_ip).await {
175            return;
176        }
177
178        // Also precheck our rate limit, using up one byte
179        if !ip_rate.is_ok(&calc_ip, 1).await {
180            return;
181        }
182    }
183
184    if let Some(cslot) = weak_cslot.upgrade() {
185        cslot
186            .insert(&config, calc_ip, pub_key, ws, maybe_auth)
187            .await;
188    }
189}
190
191/// Handle the /authenticate request for access token
192async fn handle_auth(
193    axum::extract::State(app_state): axum::extract::State<AppState>,
194    body: bytes::Bytes,
195) -> axum::response::Response {
196    use AuthenticateTokenError::*;
197
198    // process the actual authentication
199    match process_authenticate_token(
200        &app_state.config,
201        &app_state.token_tracker,
202        body,
203    )
204    .await
205    {
206        Ok(token) => axum::response::IntoResponse::into_response(axum::Json(
207            serde_json::json!({
208                "authToken": *token,
209            }),
210        )),
211        Err(Unauthorized) => {
212            tracing::debug!("/authenticate: UNAUTHORIZED");
213            axum::response::IntoResponse::into_response((
214                axum::http::StatusCode::UNAUTHORIZED,
215                "Unauthorized",
216            ))
217        }
218        Err(HookServerError(err)) => {
219            tracing::debug!(?err, "/authenticate: BAD_GATEWAY");
220            axum::response::IntoResponse::into_response((
221                axum::http::StatusCode::BAD_GATEWAY,
222                format!("BAD_GATEWAY: {err:?}"),
223            ))
224        }
225        Err(OtherError(err)) => {
226            tracing::warn!(?err, "/authenticate: INTERNAL_SERVER_ERROR");
227            axum::response::IntoResponse::into_response((
228                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
229                format!("INTERNAL_SERVER_ERROR: {err:?}"),
230            ))
231        }
232    }
233}
234
/// Authenticate token error type.
///
/// Returned by [`process_authenticate_token`]. Derives `Debug` so callers
/// can log it or use it with `unwrap_err` / `{:?}` formatting.
#[derive(Debug)]
pub enum AuthenticateTokenError {
    /// The token is invalid (the hook server answered with http 401).
    Unauthorized,
    /// We had an error talking to the hook server, or it sent back a
    /// response we could not use.
    HookServerError(Error),
    /// We had an internal error.
    OtherError(Error),
}
244
245/// Handle receiving a PUT "/authenticate" rest api request.
246pub async fn process_authenticate_token(
247    config: &Config,
248    token_tracker: &AuthTokenTracker,
249    auth_material: bytes::Bytes,
250) -> std::result::Result<Arc<str>, AuthenticateTokenError> {
251    use AuthenticateTokenError::*;
252
253    let token: Arc<str> = if let Some(url) = &config.authentication_hook_server
254    {
255        // if a hook server is configured, forward the call to it
256
257        let url = url.clone();
258        let token = tokio::task::spawn_blocking(move || {
259            ureq::put(&url)
260                .header("Content-Type", "application/octet-stream")
261                .send(&auth_material[..])
262                .map_err(|err| match err {
263                    ureq::Error::StatusCode(401) => Unauthorized,
264                    oth => HookServerError(Error::other(oth)),
265                })?
266                .into_body()
267                .read_to_string()
268                .map_err(Error::other)
269                // this is a HookServerError, not an OtherError, because
270                // it is the hook server that either failed to send a full
271                // response, or sent back non-utf8 bytes, etc...
272                .map_err(HookServerError)
273        })
274        .await
275        .map_err(|_| OtherError(Error::other("tokio task died")))??;
276
277        #[derive(serde::Deserialize)]
278        #[serde(rename_all = "camelCase")]
279        struct Token {
280            auth_token: String,
281        }
282
283        let token: Token = serde_json::from_str(&token)
284            .map_err(|err| OtherError(Error::other(err)))?;
285
286        token.auth_token
287    } else {
288        // If no hook server is configured, fallback to gen random token
289
290        use base64::prelude::*;
291        use rand::Rng;
292
293        let mut bytes = [0; 32];
294        rand::thread_rng().fill(&mut bytes);
295        BASE64_URL_SAFE_NO_PAD.encode(&bytes[..])
296    }
297    .into();
298
299    // register the token with our authentication token tracker
300    token_tracker.register_token(token.clone());
301
302    Ok(token)
303}
304
305/// Implement the ability to use axum websockets as our websocket backend.
306#[derive(Clone)]
307struct WebsocketImpl {
308    write: Arc<
309        tokio::sync::Mutex<
310            futures::stream::SplitSink<
311                axum::extract::ws::WebSocket,
312                axum::extract::ws::Message,
313            >,
314        >,
315    >,
316    read: Arc<
317        tokio::sync::Mutex<
318            futures::stream::SplitStream<axum::extract::ws::WebSocket>,
319        >,
320    >,
321}
322
323impl SbdWebsocket for WebsocketImpl {
324    fn recv(&self) -> futures::future::BoxFuture<'static, Result<Payload>> {
325        let this = self.clone();
326        Box::pin(async move {
327            let mut read = this.read.lock().await;
328            use futures::stream::StreamExt;
329            loop {
330                match read.next().await {
331                    None => return Err(Error::other("closed")),
332                    Some(r) => {
333                        let msg = r.map_err(Error::other)?;
334                        match msg {
335                            axum::extract::ws::Message::Text(s) => {
336                                return Ok(Payload::Vec(s.as_bytes().to_vec()))
337                            }
338                            axum::extract::ws::Message::Binary(v) => {
339                                return Ok(Payload::Vec(v[..].to_vec()))
340                            }
341                            axum::extract::ws::Message::Ping(_)
342                            | axum::extract::ws::Message::Pong(_) => (),
343                            axum::extract::ws::Message::Close(_) => {
344                                return Err(Error::other("closed"))
345                            }
346                        }
347                    }
348                }
349            }
350        })
351    }
352
353    fn send(
354        &self,
355        payload: Payload,
356    ) -> futures::future::BoxFuture<'static, Result<()>> {
357        use futures::SinkExt;
358        let this = self.clone();
359        Box::pin(async move {
360            let mut write = this.write.lock().await;
361            let v = match payload {
362                Payload::Vec(v) => v,
363                Payload::BytesMut(b) => b.to_vec(),
364            };
365            write
366                .send(axum::extract::ws::Message::Binary(
367                    bytes::Bytes::copy_from_slice(&v),
368                ))
369                .await
370                .map_err(Error::other)?;
371            write.flush().await.map_err(Error::other)?;
372            Ok(())
373        })
374    }
375
376    fn close(&self) -> futures::future::BoxFuture<'static, ()> {
377        use futures::SinkExt;
378        let this = self.clone();
379        Box::pin(async move {
380            let _ = this.write.lock().await.close().await;
381        })
382    }
383}
384
385impl WebsocketImpl {
386    fn new(ws: axum::extract::ws::WebSocket) -> Self {
387        use futures::StreamExt;
388        let (tx, rx) = ws.split();
389        Self {
390            write: Arc::new(tokio::sync::Mutex::new(tx)),
391            read: Arc::new(tokio::sync::Mutex::new(rx)),
392        }
393    }
394}
395
396/// Handle the http upgrade request for a websocket connection.
397async fn handle_ws(
398    axum::extract::Path(pub_key): axum::extract::Path<String>,
399    headers: axum::http::HeaderMap,
400    ws: axum::extract::WebSocketUpgrade,
401    axum::extract::ConnectInfo(addr): axum::extract::ConnectInfo<
402        std::net::SocketAddr,
403    >,
404    axum::extract::State(app_state): axum::extract::State<AppState>,
405) -> impl axum::response::IntoResponse {
406    use axum::response::IntoResponse;
407    use base64::Engine;
408
409    // first check for auth tokens
410    let token: Option<Arc<str>> = headers
411        .get("Authorization")
412        .and_then(|t| t.to_str().ok().map(<Arc<str>>::from));
413
414    let maybe_auth = Some((token.clone(), app_state.token_tracker.clone()));
415
416    // compare any passed tokens with our token authentication mechanism
417    if !app_state
418        .token_tracker
419        .check_is_token_valid(&app_state.config, token)
420    {
421        return axum::response::IntoResponse::into_response((
422            axum::http::StatusCode::UNAUTHORIZED,
423            "Unauthorized",
424        ));
425    }
426
427    // get the primary key this user is claiming
428    let pk = match base64::prelude::BASE64_URL_SAFE_NO_PAD.decode(pub_key) {
429        Ok(pk) if pk.len() == 32 => {
430            let mut sized_pk = [0; 32];
431            sized_pk.copy_from_slice(&pk);
432            PubKey(Arc::new(sized_pk))
433        }
434        _ => return axum::http::StatusCode::BAD_REQUEST.into_response(),
435    };
436
437    let mut calc_ip = to_canonical_ip(addr.ip());
438
439    // if we're using a trusted ip, parse that out of the header
440    if let Some(trusted_ip_header) = &app_state.config.trusted_ip_header {
441        if let Some(header) =
442            headers.get(trusted_ip_header).and_then(|h| h.to_str().ok())
443        {
444            if let Ok(ip) = header.parse::<IpAddr>() {
445                calc_ip = to_canonical_ip(ip);
446            }
447        }
448    }
449
450    // do the actual websocket upgrade
451    ws.max_message_size(MAX_MSG_BYTES as usize).on_upgrade(
452        move |socket| async move {
453            handle_upgraded(
454                app_state.config.clone(),
455                app_state.ip_rate.clone(),
456                app_state.cslot.clone(),
457                Arc::new(WebsocketImpl::new(socket)),
458                pk,
459                calc_ip,
460                maybe_auth,
461            )
462            .await;
463        },
464    )
465}
466
467/// Utility for managing auth tokens.
468#[derive(Clone, Default)]
469pub struct AuthTokenTracker {
470    token_map: Arc<Mutex<HashMap<Arc<str>, std::time::Instant>>>,
471}
472
473impl AuthTokenTracker {
474    /// Register a token as valid.
475    pub fn register_token(&self, token: Arc<str>) {
476        self.token_map
477            .lock()
478            .unwrap()
479            .insert(token, std::time::Instant::now());
480    }
481
482    /// Check that a token is valid.
483    /// If so, mark it as recently used so it doesn't time out.
484    /// The "token" parameter should be direct from the http header
485    /// i.e. with the "Barer" include, like "Bearer base64".
486    /// This should be called with None as the token if no Authenticate
487    /// header was specified.
488    pub fn check_is_token_valid(
489        &self,
490        config: &Config,
491        token: Option<Arc<str>>,
492    ) -> bool {
493        let token: Arc<str> = if let Some(token) = token {
494            // If the client supplied a token, always validate it,
495            // even if no hook server was specified in the config.
496            if !token.starts_with("Bearer ") {
497                return false;
498            }
499            token.trim_start_matches("Bearer ").into()
500        } else if config.authentication_hook_server.is_none() {
501            // If the client did not supply a token, and we have no
502            // hook server configured, allow the request.
503            return true;
504        } else {
505            // We have no token, but one is required. Unauthorized.
506            return false;
507        };
508
509        let mut lock = self.token_map.lock().unwrap();
510
511        let idle_dur = config.idle_dur();
512
513        lock.retain(|_t, e| e.elapsed() < idle_dur);
514
515        if let std::collections::hash_map::Entry::Occupied(mut e) =
516            lock.entry(token)
517        {
518            e.insert(std::time::Instant::now());
519            true
520        } else {
521            false
522        }
523    }
524}
525
526#[derive(Clone)]
527struct AppState {
528    config: Arc<Config>,
529    token_tracker: AuthTokenTracker,
530    ip_rate: Arc<IpRate>,
531    cslot: WeakCSlot,
532}
533
534impl AppState {
535    pub fn new(
536        config: Arc<Config>,
537        ip_rate: Arc<IpRate>,
538        cslot: WeakCSlot,
539    ) -> Self {
540        Self {
541            config,
542            token_tracker: AuthTokenTracker::default(),
543            ip_rate,
544            cslot,
545        }
546    }
547}
548
549impl SbdServer {
550    /// Construct a new running sbd server with the provided config.
551    pub async fn new(config: Arc<Config>) -> Result<Self> {
552        let tls_config = if let (Some(cert), Some(pk)) =
553            (&config.cert_pem_file, &config.priv_key_pem_file)
554        {
555            Some(Arc::new(TlsConfig::new(cert, pk).await?))
556        } else {
557            None
558        };
559
560        let mut task_list = Vec::new();
561        let mut bind_addrs = Vec::new();
562
563        let ip_rate = Arc::new(IpRate::new(config.clone()));
564        task_list.push(spawn_prune_task(ip_rate.clone()));
565
566        let cslot = CSlot::new(config.clone(), ip_rate.clone());
567        let weak_cslot = cslot.weak();
568
569        // setup the axum router
570        let app: axum::Router<()> = axum::Router::new()
571            .route("/authenticate", axum::routing::put(handle_auth))
572            .route("/{pub_key}", axum::routing::any(handle_ws))
573            .layer(axum::extract::DefaultBodyLimit::max(1024))
574            .with_state(AppState::new(
575                config.clone(),
576                ip_rate.clone(),
577                weak_cslot.clone(),
578            ));
579
580        let app =
581            app.into_make_service_with_connect_info::<std::net::SocketAddr>();
582
583        let mut found_port_zero: Option<u16> = None;
584
585        // bind to configured bindings
586        for bind in config.bind.iter() {
587            let mut a: std::net::SocketAddr =
588                bind.parse().map_err(Error::other)?;
589            if let Some(found_port_zero) = &found_port_zero {
590                if a.port() == 0 {
591                    a.set_port(*found_port_zero);
592                }
593            }
594
595            let h = axum_server::Handle::new();
596
597            if let Some(tls_config) = &tls_config {
598                let tls_config =
599                    axum_server::tls_rustls::RustlsConfig::from_config(
600                        tls_config.config(),
601                    );
602                let server = axum_server::bind_rustls(a, tls_config)
603                    .handle(h.clone())
604                    .serve(app.clone());
605                task_list.push(tokio::task::spawn(async move {
606                    if let Err(err) = server.await {
607                        tracing::error!(?err);
608                    }
609                }));
610            } else {
611                let server =
612                    axum_server::bind(a).handle(h.clone()).serve(app.clone());
613                task_list.push(tokio::task::spawn(async move {
614                    if let Err(err) = server.await {
615                        tracing::error!(?err);
616                    }
617                }));
618            }
619
620            if let Some(addr) = h.listening().await {
621                if found_port_zero.is_none() && a.port() == 0 {
622                    found_port_zero = Some(addr.port());
623                }
624                bind_addrs.push(addr);
625            }
626        }
627
628        Ok(Self {
629            task_list,
630            bind_addrs,
631            _cslot: cslot,
632        })
633    }
634
635    /// Get the list of addresses bound locally.
636    pub fn bind_addrs(&self) -> &[std::net::SocketAddr] {
637        self.bind_addrs.as_slice()
638    }
639}
640
641#[cfg(test)]
642mod test;