scalesocket 0.2.4

A collaborative websocket server and autoscaler
use {
    prometheus_client::encoding::text::encode,
    prometheus_client::registry::Registry,
    serde_json::json,
    std::path::PathBuf,
    warp::ws::Ws,
    warp::{self, Filter, Rejection, Reply, http::Response},
};

use crate::{
    cli::Config,
    envvars::Env,
    metrics::Metrics,
    types::{Event, EventTx, RoomID, ShutdownRx},
    utils::warpext::{self, handle_rejection},
};

const RESERVED_ROOMS: &[&str] = &[
    "api",
    "metrics",
    "health",
    "static",
    "upload",
    "robots.txt",
    "favicon.ico",
];

pub fn handle(
    tx: EventTx,
    config: Config,
    shutdown_rx: ShutdownRx,
    metrics: Metrics,
    registry: Option<Registry>,
) -> impl futures::Future<Output = ()> {
    let shutdown_rx = async {
        shutdown_rx.await.ok();
    };

    warp::serve(
        socket(tx, config.rooms)
            .or(health())
            .or(openmetrics(registry, config.metrics))
            .or(rooms_api(metrics.clone(), config.api))
            .or(metadata_api(metrics, config.api))
            .or(files(config.static_dir))
            .recover(handle_rejection),
    )
    .bind_with_graceful_shutdown(config.addr, shutdown_rx)
    .1
}

pub fn socket(
    tx: EventTx,
    allowed_rooms: Option<Vec<String>>,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
    warpext::path::param_matches::<RoomID>(Some(RESERVED_ROOMS), allowed_rooms)
        .and(warp::path::end())
        .and(warp::ws())
        .and(warpext::env())
        .map(move |room: RoomID, websocket: Ws, mut env: Env| {
            let tx = tx.clone();
            env.set_room(&room);
            websocket.on_upgrade(move |ws| {
                let ws = Box::new(ws);
                let event = Event::Connect { env, room, ws };
                tx.send(event).expect("Failed to send Connect event");
                futures::future::ready(())
            })
        })
}

pub fn health() -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
    warp::path("health")
        .and(warp::path::end())
        .and(warp::get())
        .map(|| warp::reply::json(&json!({"status" : "ok"})))
}

pub fn openmetrics(
    registry: Option<Registry>,
    enabled: bool,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
    let registry = std::sync::Arc::new(registry);

    warpext::enable_if(enabled)
        .and(warp::path("metrics"))
        .and(warp::path::end())
        .and(warp::get())
        .map(move || {
            // Encode metrics
            let mut buffer = String::new();
            let res = match *registry {
                Some(ref registry) => encode(&mut buffer, registry),
                None => unreachable!("Registry should always be Some"),
            };

            let encoded = match res.is_ok() {
                true => Some(buffer),
                false => None,
            };

            let builder = Response::builder().header(
                "content-type",
                "application/openmetrics-text; version=1.0.0; charset=utf-8",
            );
            match encoded {
                Some(data) => builder.body(data),
                None => builder.status(500).body(String::default()),
            }
        })
}

pub fn rooms_api(
    metrics: Metrics,
    enabled: bool,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
    warpext::enable_if(enabled)
        .and(warp::path!("api" / "rooms"))
        .and(warp::path::end())
        .and(warp::get())
        .map(move || warp::reply::json(&metrics.get_rooms()))
}

pub fn metadata_api(
    metrics: Metrics,
    enabled: bool,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
    let metric = warp::path::param::<String>()
        .map(Some)
        .or_else(|_| async { Ok::<(Option<String>,), std::convert::Infallible>((None,)) });

    warpext::enable_if(enabled)
        .and(warp::path!("api" / RoomID / ..))
        .and(metric)
        .and(warp::path::end())
        .and(warp::get())
        .map(
            move |room: RoomID, metric: Option<String>| match metric.as_deref() {
                Some("connections") => warp::reply::json(&metrics.get_room_connections(room)),
                Some("metadata") => warp::reply::json(&metrics.get_room_metadata(&room)),
                _ => warp::reply::json(&metrics.get_room(room)),
            },
        )
}

pub fn files(
    path: Option<PathBuf>,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone {
    warpext::enable_if(path.is_some()).and(warp::fs::dir(path.unwrap_or_default()))
}

#[cfg(test)]
mod tests {
    use prometheus_client::metrics::counter::Counter;
    use prometheus_client::metrics::family::Family;
    use serde_json::{self, Value};
    use warp::http::StatusCode;
    use warp::test::{RequestBuilder, request};

    use super::*;

    fn ws_request(path: &'static str) -> RequestBuilder {
        request()
            .method("GET")
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Key", "SGVsbG8sIHdvcmxkIQ==")
            .header("Sec-WebSocket-Version", 13)
            .path(path)
    }

    fn browser_request(path: &'static str) -> RequestBuilder {
        request()
            .method("GET")
            .header("Connection", "keep-alive")
            .path(path)
    }

    #[tokio::test]
    async fn health_returns_ok() {
        let api = health();

        let resp = request().method("GET").path("/health").reply(&api).await;

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(resp.body(), "{\"status\":\"ok\"}");
    }

    #[tokio::test]
    async fn socket_rejects_reserved_room() {
        let (tx, _) = tokio::sync::mpsc::unbounded_channel::<Event>();
        let api = socket(tx, None).recover(handle_rejection);

        let ws = |path| ws_request(path).reply(&api);

        assert_eq!(ws("/ok").await.status(), StatusCode::SWITCHING_PROTOCOLS);
        assert_eq!(ws("/api").await.status(), StatusCode::BAD_REQUEST);
        assert_eq!(ws("/api/rooms").await.status(), StatusCode::BAD_REQUEST);
        assert_eq!(ws("/metrics").await.status(), StatusCode::BAD_REQUEST);
        assert_eq!(ws("/health").await.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn socket_detects_browser_request() {
        let (tx, _) = tokio::sync::mpsc::unbounded_channel::<Event>();
        let api = socket(tx, None).recover(handle_rejection);

        let resp = browser_request("/room").reply(&api).await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        let resp = ws_request("/room").reply(&api).await;
        assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS);
    }

    #[tokio::test]
    async fn socket_accepts_allowed_room() {
        let (tx, _) = tokio::sync::mpsc::unbounded_channel::<Event>();
        let allowlist = vec!["allowed".to_string()];
        let api = socket(tx, Some(allowlist)).recover(handle_rejection);

        let ws = |path| ws_request(path).reply(&api);

        assert_eq!(
            ws("/allowed").await.status(),
            StatusCode::SWITCHING_PROTOCOLS
        );
        assert_eq!(ws("/other").await.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn metrics_returns_metrics() {
        let mut registry = <Registry>::default();
        registry.register(
            "example_metric",
            "Example description",
            Family::<Vec<(String, String)>, Counter>::default(),
        );
        let api = openmetrics(Some(registry), true);

        let resp = request().method("GET").path("/metrics").reply(&api).await;

        assert!(resp.status().is_success());
        assert_eq!(
            resp.body(),
            "# HELP example_metric Example description.\n# TYPE example_metric counter\n# EOF\n"
        );
    }

    #[tokio::test]
    async fn room_api_returns_all_rooms() {
        let metrics = Metrics::new(&mut None, true);
        metrics.inc_ws_connections("foo");
        metrics.inc_ws_connections("bar");
        metrics.set_metadata("foo", json!({"_meta": true, "bar": 123}));

        let api = rooms_api(metrics, true);

        let resp = request()
            .method("GET")
            .path("/api/rooms/")
            .reply(&api)
            .await;

        assert!(resp.status().is_success());
        let body: Vec<Value> = serde_json::from_slice(resp.body()).unwrap();
        assert!(
            body.contains(
                &json!({"name": "foo", "connections": 1, "metadata": json!({"bar": 123})})
            )
        );
        assert!(body.contains(&json!({"name": "bar", "connections": 1, "metadata": null})));
    }

    #[tokio::test]
    async fn metadata_api_returns_room_metadata() {
        let metrics = Metrics::new(&mut None, true);
        metrics.set_metadata("foo", json!({"_meta": true, "bar": 123}));

        let api = metadata_api(metrics, true);

        let resp = request().method("GET").path("/api/foo/").reply(&api).await;

        assert!(resp.status().is_success());
        let body: Value = serde_json::from_slice(resp.body()).unwrap();
        assert_eq!(body["metadata"], json!({"bar": 123}));
    }

    #[tokio::test]
    async fn metadata_api_returns_room_metric() {
        let metrics = Metrics::new(&mut None, true);
        metrics.inc_ws_connections("foo");
        metrics.inc_ws_connections("bar");

        let api = metadata_api(metrics, true);

        let resp = request()
            .method("GET")
            .path("/api/foo/connections")
            .reply(&api)
            .await;

        assert!(resp.status().is_success());
        assert_eq!(resp.body(), "1");
    }
}