rust-oxide-realtime 0.2.0

Reusable realtime transport primitives for Axum servers and Rust WebSocket clients.
Documentation
use std::sync::Arc;

use axum::{
    Router,
    extract::{
        Query, State,
        ws::{WebSocketUpgrade, rejection::WebSocketUpgradeRejection},
    },
    http::{HeaderMap, StatusCode, header},
    response::{IntoResponse, Response},
    routing::get,
};
use serde::Deserialize;

use super::{RealtimeError, SocketAppState};

/// Configuration for the realtime socket route built by [`router_with_options`].
#[derive(Debug, Clone)]
pub struct RealtimeRouteOptions {
    /// URL path the socket handler is mounted on.
    pub path: &'static str,
    /// Whether a `token` query parameter is accepted as a fallback credential.
    pub allow_query_token: bool,
    /// When `true`, a malformed `Authorization` header (one that does not carry a
    /// usable bearer token) is rejected instead of falling back to the query token.
    pub strict_header_precedence: bool,
}

impl Default for RealtimeRouteOptions {
    /// Mounts the route at `/realtime/socket`, allows query tokens, and enables
    /// strict header precedence.
    fn default() -> Self {
        let path = "/realtime/socket";
        RealtimeRouteOptions {
            strict_header_precedence: true,
            allow_query_token: true,
            path,
        }
    }
}

/// Shared state handed to `socket_handler` for every request on the route.
///
/// `Clone` is required by axum's `Router::with_state`. The derived impl bumps the
/// `Arc` refcount and clones the options, matching the previous hand-written impl.
#[derive(Clone)]
struct SocketRouteState {
    /// Shared realtime server state; the handler reads its `handle` and `verifier`.
    socket_server_handle: Arc<SocketAppState>,
    /// Route configuration consulted during token extraction.
    options: RealtimeRouteOptions,
}

/// Query-string parameters accepted on the socket route.
#[derive(Default, Debug, Deserialize)]
struct SocketQuery {
    /// Optional access token; only consulted when query tokens are allowed.
    token: Option<String>,
}

/// Failure modes of the realtime socket route, each rendered as a plain-text HTTP response.
#[derive(Debug)]
enum RealtimeHttpError {
    /// No usable token in the `Authorization` header or the query string.
    MissingToken,
    /// The `Authorization` header was rejected under strict header precedence.
    InvalidToken,
    /// The request is not a valid WebSocket upgrade.
    UpgradeRequired,
    /// The realtime handle reports that it is not enabled.
    RealtimeDisabled,
    /// The token verifier rejected the token.
    VerifyFailed(RealtimeError),
}

impl RealtimeHttpError {
    /// HTTP status for this error. Verifier failures map each [`RealtimeError`]
    /// variant onto its matching status.
    fn status(&self) -> StatusCode {
        match self {
            Self::MissingToken
            | Self::InvalidToken
            | Self::VerifyFailed(RealtimeError::Unauthorized(_)) => StatusCode::UNAUTHORIZED,
            Self::UpgradeRequired | Self::VerifyFailed(RealtimeError::BadRequest(_)) => {
                StatusCode::BAD_REQUEST
            }
            Self::RealtimeDisabled | Self::VerifyFailed(RealtimeError::NotFound(_)) => {
                StatusCode::NOT_FOUND
            }
            Self::VerifyFailed(RealtimeError::Forbidden(_)) => StatusCode::FORBIDDEN,
            Self::VerifyFailed(RealtimeError::Internal(_)) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    /// Response body text. Verifier failures forward the verifier's own message.
    fn message(&self) -> String {
        let text = match self {
            Self::VerifyFailed(err) => return err.message().to_string(),
            Self::MissingToken => {
                "Missing access token (use Authorization Bearer or token query param)"
            }
            Self::InvalidToken => "Missing/invalid Authorization header",
            Self::UpgradeRequired => "WebSocket upgrade required",
            Self::RealtimeDisabled => "Realtime is disabled",
        };
        String::from(text)
    }
}

impl IntoResponse for RealtimeHttpError {
    fn into_response(self) -> Response {
        let status = self.status();
        let body = self.message();
        (status, body).into_response()
    }
}

/// Builds the realtime socket router using [`RealtimeRouteOptions::default`].
pub fn router(socket_server_handle: Arc<SocketAppState>) -> Router {
    let options = RealtimeRouteOptions::default();
    router_with_options(socket_server_handle, options)
}

/// Builds a router that serves `socket_handler` for `GET` requests at `options.path`.
///
/// The handler state bundles the shared server handle with `options`. The
/// options are later consulted when extracting the access token.
pub fn router_with_options(
    socket_server_handle: Arc<SocketAppState>,
    options: RealtimeRouteOptions,
) -> Router {
    let route_path = options.path;
    let state = SocketRouteState {
        socket_server_handle,
        options,
    };
    Router::new()
        .route(route_path, get(socket_handler))
        .with_state(state)
}

async fn socket_handler(
    State(handler_state): State<SocketRouteState>,
    upgrade: Result<WebSocketUpgrade, WebSocketUpgradeRejection>,
    headers: HeaderMap,
    Query(query): Query<SocketQuery>,
) -> Response {
    let realtime = handler_state.socket_server_handle.handle.clone();

    if !realtime.is_enabled() {
        return RealtimeHttpError::RealtimeDisabled.into_response();
    }

    let upgrade = match upgrade {
        Ok(upgrade) => upgrade,
        Err(_) => return RealtimeHttpError::UpgradeRequired.into_response(),
    };

    let token = match extract_access_token(&headers, &query, &handler_state.options) {
        Ok(token) => token,
        Err(err) => return err.into_response(),
    };

    let auth = match handler_state
        .socket_server_handle
        .verifier
        .verify_token(&token)
        .await
    {
        Ok(auth) => auth,
        Err(err) => return RealtimeHttpError::VerifyFailed(err).into_response(),
    };

    upgrade
        .max_message_size(realtime.max_message_bytes())
        .max_frame_size(realtime.max_message_bytes())
        .on_upgrade(move |socket| async move {
            realtime.serve_socket(socket, auth).await;
        })
        .into_response()
}

/// Extracts the access token for a socket connection.
///
/// Lookup order:
/// 1. The `Authorization` header, which must use the `Bearer` scheme followed by
///    a non-empty token. The scheme name is matched case-insensitively, per
///    RFC 7235 §2.1.
/// 2. The `token` query parameter, but only when `options.allow_query_token` is
///    set. Surrounding whitespace is trimmed and a blank value counts as absent.
///
/// When `options.strict_header_precedence` is set, an `Authorization` header
/// that is present but unusable is rejected instead of falling back to the
/// query. Unusable means a wrong scheme, an empty token, or a value that is not
/// visible ASCII.
///
/// # Errors
///
/// - [`RealtimeHttpError::InvalidToken`] for an unusable header under strict precedence.
/// - [`RealtimeHttpError::MissingToken`] when no usable token is found.
fn extract_access_token(
    headers: &HeaderMap,
    query: &SocketQuery,
    options: &RealtimeRouteOptions,
) -> Result<String, RealtimeHttpError> {
    if let Some(raw_header) = headers.get(header::AUTHORIZATION) {
        // A non-visible-ASCII value cannot carry a valid bearer token. Treat it
        // as a malformed header rather than pretending no header was sent.
        let header_token = raw_header.to_str().ok().and_then(parse_bearer_token);

        if let Some(token) = header_token {
            return Ok(token.to_string());
        }

        if options.strict_header_precedence {
            return Err(RealtimeHttpError::InvalidToken);
        }
    }

    query
        .token
        .as_deref()
        .filter(|_| options.allow_query_token)
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(str::to_string)
        .ok_or(RealtimeHttpError::MissingToken)
}

/// Returns the trimmed credentials of a `Bearer <token>` authorization value.
/// Returns `None` if the scheme is not `Bearer` (case-insensitive) or the
/// credentials are empty.
fn parse_bearer_token(value: &str) -> Option<&str> {
    let (scheme, credentials) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = credentials.trim();
    (!token.is_empty()).then_some(token)
}

#[cfg(test)]
mod tests {
    use axum::http::header;

    use super::*;

    /// Builds a header map carrying a single `Authorization` value.
    fn headers_with_authorization(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::AUTHORIZATION,
            value.parse().expect("valid header"),
        );
        headers
    }

    /// Builds a query carrying the given `token` parameter.
    fn query_with_token(token: &str) -> SocketQuery {
        SocketQuery {
            token: Some(token.to_string()),
        }
    }

    #[test]
    fn extract_access_token_prefers_authorization_header() {
        let headers = headers_with_authorization("Bearer header-token");
        let query = query_with_token("query-token");

        let token = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect("token should parse");
        assert_eq!(token, "header-token");
    }

    #[test]
    fn extract_access_token_falls_back_to_query_token() {
        let headers = HeaderMap::new();
        let query = query_with_token("query-token");

        let token = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect("token should parse");
        assert_eq!(token, "query-token");
    }

    #[test]
    fn extract_access_token_rejects_invalid_header_when_strict() {
        let headers = headers_with_authorization("Token abc");
        let query = query_with_token("query-token");

        let err = extract_access_token(&headers, &query, &RealtimeRouteOptions::default())
            .expect_err("invalid header should fail");
        assert!(matches!(err, RealtimeHttpError::InvalidToken));
    }

    #[test]
    fn extract_access_token_falls_back_to_query_when_not_strict() {
        let headers = headers_with_authorization("Token abc");
        let query = query_with_token("query-token");
        let options = RealtimeRouteOptions {
            strict_header_precedence: false,
            ..RealtimeRouteOptions::default()
        };

        let token =
            extract_access_token(&headers, &query, &options).expect("query token should be used");
        assert_eq!(token, "query-token");
    }

    #[test]
    fn extract_access_token_ignores_query_when_disabled() {
        let headers = HeaderMap::new();
        let query = query_with_token("query-token");
        let options = RealtimeRouteOptions {
            allow_query_token: false,
            ..RealtimeRouteOptions::default()
        };

        let err = extract_access_token(&headers, &query, &options)
            .expect_err("query token must be ignored when disabled");
        assert!(matches!(err, RealtimeHttpError::MissingToken));
    }

    #[test]
    fn extract_access_token_rejects_blank_tokens() {
        let headers = headers_with_authorization("Bearer    ");
        let err = extract_access_token(
            &headers,
            &query_with_token("query-token"),
            &RealtimeRouteOptions::default(),
        )
        .expect_err("blank bearer token should fail under strict precedence");
        assert!(matches!(err, RealtimeHttpError::InvalidToken));

        let err = extract_access_token(
            &HeaderMap::new(),
            &query_with_token("   "),
            &RealtimeRouteOptions::default(),
        )
        .expect_err("blank query token should be treated as missing");
        assert!(matches!(err, RealtimeHttpError::MissingToken));
    }
}