sockudo-core 4.6.0

Core traits, types, error handling, and configuration for Sockudo
Documentation
use crate::app::App;
use crate::error::{Error, Result};
use crate::websocket::ConnectionCapabilities;
use ahash::AHashMap;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use serde::Deserialize;
use std::time::{SystemTime, UNIX_EPOCH};

/// Largest capability token accepted for validation, in bytes (8 KiB).
pub const MAX_CAPABILITY_TOKEN_BYTES: usize = 8_192;
/// Upper bound on the number of pattern entries in a capability map.
pub const MAX_CAPABILITY_PATTERNS: usize = 100;
/// Longest accepted `x-sockudo-client-id` claim value, in bytes.
pub const MAX_CLIENT_ID_BYTES: usize = 128;
/// Longest permitted distance between `iat` and `exp`, in seconds (24 hours).
pub const MAX_TOKEN_LIFETIME_SECONDS: i64 = 24 * 3_600;
/// Tolerance, in seconds, applied to the time-based claim checks.
pub const TOKEN_CLOCK_SKEW_SECONDS: i64 = 30;

/// Result of successfully validating a capability token.
///
/// Built by `validate_claims` from the claims of a token whose signature has
/// already been checked.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenAuthContext {
    /// Value of the token's `x-sockudo-client-id` claim (1..=128 bytes).
    pub client_id: String,
    /// Capabilities parsed from the token's `x-sockudo-capability` claim.
    pub capabilities: ConnectionCapabilities,
    /// Value of the token's `jti` claim; never empty.
    pub jti: String,
    /// Value of the token's `exp` claim, copied unchanged.
    pub exp: i64,
}

/// Claim set deserialized from a capability token's payload.
///
/// Every field except `nbf` must be present for deserialization to succeed.
#[derive(Debug, Deserialize)]
struct CapabilityTokenClaims {
    /// Capability map carried as a string; `parse_capability_map` parses it as
    /// a JSON object mapping patterns to lists of operation names.
    #[serde(rename = "x-sockudo-capability")]
    capability: String,
    /// Client identifier claim.
    #[serde(rename = "x-sockudo-client-id")]
    client_id: String,
    /// Issued-at claim, compared against the current time in seconds.
    iat: i64,
    /// Expiry claim, compared against the current time in seconds.
    exp: i64,
    /// Optional not-before claim; `None` when the token omits it.
    #[serde(default)]
    nbf: Option<i64>,
    /// Token identifier claim.
    jti: String,
}

pub fn validate_capability_token(token: &str, app: &App) -> Result<TokenAuthContext> {
    if token.len() > MAX_CAPABILITY_TOKEN_BYTES {
        return Err(Error::Auth("capability token exceeds 8 KiB".to_string()));
    }

    let header = decode_header(token)
        .map_err(|_| Error::Auth("capability token has an invalid header".to_string()))?;
    if header.alg != Algorithm::HS256 {
        return Err(Error::Auth("capability token must use HS256".to_string()));
    }
    if header.kid.as_deref() != Some(app.key.as_str()) {
        return Err(Error::Auth(
            "capability token kid does not match app key".to_string(),
        ));
    }

    let mut validation = Validation::new(Algorithm::HS256);
    validation.leeway = TOKEN_CLOCK_SKEW_SECONDS as u64;
    validation.validate_exp = true;
    validation.validate_nbf = true;
    validation.set_required_spec_claims(&["exp"]);

    let decoded = decode::<CapabilityTokenClaims>(
        token,
        &DecodingKey::from_secret(app.secret.as_bytes()),
        &validation,
    )
    .map_err(|_| Error::Auth("capability token signature or claims are invalid".to_string()))?;

    let now = now_seconds()?;
    validate_claims(decoded.claims, now)
}

/// Applies the Sockudo-specific claim rules and builds the [`TokenAuthContext`].
///
/// `claims` are expected to come from a token whose signature was already
/// verified; `now` is the current time in seconds since the Unix epoch.
///
/// # Errors
/// Returns `Error::Auth` when:
/// - `x-sockudo-client-id` is empty or longer than [`MAX_CLIENT_ID_BYTES`];
/// - `jti` is empty;
/// - `iat` lies more than [`TOKEN_CLOCK_SKEW_SECONDS`] in the future;
/// - `exp` is at or before `now - TOKEN_CLOCK_SKEW_SECONDS`;
/// - `exp - iat` exceeds [`MAX_TOKEN_LIFETIME_SECONDS`];
/// - `nbf` is present and lies more than the clock skew in the future;
/// - the capability map fails `parse_capability_map`.
fn validate_claims(claims: CapabilityTokenClaims, now: i64) -> Result<TokenAuthContext> {
    if claims.client_id.is_empty() || claims.client_id.len() > MAX_CLIENT_ID_BYTES {
        return Err(Error::Auth(
            "x-sockudo-client-id must be 1..=128 bytes".to_string(),
        ));
    }
    if claims.jti.is_empty() {
        return Err(Error::Auth("jti is required".to_string()));
    }

    // All time arithmetic saturates: `iat`, `exp` and `nbf` are arbitrary
    // token-supplied i64 values, and plain `+`/`-` would panic when overflow
    // checks are on or wrap around when they are off.
    let latest_accepted = now.saturating_add(TOKEN_CLOCK_SKEW_SECONDS);
    let expiry_cutoff = now.saturating_sub(TOKEN_CLOCK_SKEW_SECONDS);

    if claims.iat > latest_accepted {
        return Err(Error::Auth(
            "capability token iat is too far in the future".to_string(),
        ));
    }
    if claims.exp <= expiry_cutoff {
        return Err(Error::Auth("capability token is expired".to_string()));
    }
    // A hugely negative `iat` used to overflow `exp - iat` into a negative
    // value and slip past this cap; saturating keeps the lifetime at i64::MAX
    // so it is rejected.
    if claims.exp.saturating_sub(claims.iat) > MAX_TOKEN_LIFETIME_SECONDS {
        return Err(Error::Auth(
            "capability token lifetime exceeds 24 hours".to_string(),
        ));
    }
    if claims.nbf.is_some_and(|nbf| nbf > latest_accepted) {
        return Err(Error::Auth(
            "capability token is not active yet".to_string(),
        ));
    }

    let capabilities = parse_capability_map(&claims.capability)?;
    Ok(TokenAuthContext {
        client_id: claims.client_id,
        capabilities,
        jti: claims.jti,
        exp: claims.exp,
    })
}

/// Parses the `x-sockudo-capability` claim into [`ConnectionCapabilities`].
///
/// `input` must be a JSON object whose keys are non-empty patterns and whose
/// values are arrays of operation names. Each pattern is appended to the list
/// for every operation it names: `subscribe`, `publish`, `history` or
/// `presence`. All four lists are always `Some`, even when empty.
///
/// # Errors
/// Returns `Error::Auth` when the input is not such a JSON object, when it has
/// more than [`MAX_CAPABILITY_PATTERNS`] entries, when a pattern is empty, or
/// when an operation name is not one of the four supported ones.
fn parse_capability_map(input: &str) -> Result<ConnectionCapabilities> {
    let grants: AHashMap<String, Vec<String>> = match sonic_rs::from_str(input) {
        Ok(map) => map,
        Err(_) => {
            return Err(Error::Auth(
                "x-sockudo-capability must be a JSON object string".to_string(),
            ));
        }
    };
    if grants.len() > MAX_CAPABILITY_PATTERNS {
        return Err(Error::Auth(
            "x-sockudo-capability exceeds 100 patterns".to_string(),
        ));
    }

    // Collect into plain vectors and wrap them in `Some` once at the end.
    let mut subscribe = Vec::new();
    let mut publish = Vec::new();
    let mut history = Vec::new();
    let mut presence = Vec::new();

    for (pattern, operations) in grants {
        if pattern.is_empty() {
            return Err(Error::Auth(
                "capability patterns must not be empty".to_string(),
            ));
        }
        for operation in operations {
            let bucket = match operation.as_str() {
                "subscribe" => &mut subscribe,
                "publish" => &mut publish,
                "history" => &mut history,
                "presence" => &mut presence,
                _ => {
                    return Err(Error::Auth(format!(
                        "unsupported capability operation '{operation}'"
                    )));
                }
            };
            bucket.push(pattern.clone());
        }
    }

    Ok(ConnectionCapabilities {
        subscribe: Some(subscribe),
        publish: Some(publish),
        history: Some(history),
        presence: Some(presence),
        ..Default::default()
    })
}

fn now_seconds() -> Result<i64> {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|duration| duration.as_secs() as i64)
        .map_err(|_| Error::Internal("system clock is before unix epoch".to_string()))
}

// Unit tests for capability-token validation: tokens are signed locally with
// `jsonwebtoken::encode` and then passed through `validate_capability_token`.
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{EncodingKey, Header, encode};
    use serde::Serialize;

    // Serializable mirror of `CapabilityTokenClaims`, used to mint test tokens.
    #[derive(Serialize)]
    struct TestClaims<'a> {
        #[serde(rename = "x-sockudo-capability")]
        capability: &'a str,
        #[serde(rename = "x-sockudo-client-id")]
        client_id: &'a str,
        iat: i64,
        exp: i64,
        // Omitted from the payload entirely when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        nbf: Option<i64>,
        jti: &'a str,
    }

    // Test application with key "app-key" and secret "app-secret".
    fn app() -> App {
        App::from_policy(
            "app".to_string(),
            "app-key".to_string(),
            "app-secret".to_string(),
            true,
            Default::default(),
        )
    }

    // Signs `claims` with `secret` using `alg`, putting `kid` in the header.
    fn token_with(claims: TestClaims<'_>, alg: Algorithm, kid: &str, secret: &str) -> String {
        let mut header = Header::new(alg);
        header.kid = Some(kid.to_string());
        encode(
            &header,
            &claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .unwrap()
    }

    // Baseline claims: issued at `now`, valid for one hour, no `nbf`.
    fn claims(now: i64, capability: &str) -> TestClaims<'_> {
        TestClaims {
            capability,
            client_id: "client-1",
            iat: now,
            exp: now + 3600,
            nbf: None,
            jti: "jti-1",
        }
    }

    #[test]
    fn validates_hs256_token_and_builds_capabilities() {
        let now = now_seconds().unwrap();
        let token = token_with(
            claims(
                now,
                r#"{"room:*":["subscribe","history"],"private-room":["publish"],"presence-room":["presence"]}"#,
            ),
            Algorithm::HS256,
            "app-key",
            "app-secret",
        );

        let context = validate_capability_token(&token, &app()).unwrap();
        assert_eq!(context.client_id, "client-1");
        assert!(context.capabilities.allows_subscribe("room:1"));
        assert!(context.capabilities.allows_history("room:1"));
        assert!(context.capabilities.allows_publish("private-room"));
        // NOTE(review): "presence-room" is granted only `presence` above, yet
        // this asserts `allows_subscribe`. Whether that holds depends on
        // `ConnectionCapabilities::allows_subscribe`, which is not visible
        // here — confirm a presence-specific check was not intended.
        assert!(context.capabilities.allows_subscribe("presence-room"));
        assert!(!context.capabilities.allows_publish("room:1"));
    }

    #[test]
    fn rejects_wrong_kid() {
        let now = now_seconds().unwrap();
        // Correct secret, but the header `kid` differs from the app key.
        let token = token_with(
            claims(now, r#"{"*":["subscribe"]}"#),
            Algorithm::HS256,
            "other-key",
            "app-secret",
        );

        assert!(validate_capability_token(&token, &app()).is_err());
    }

    #[test]
    fn rejects_non_hs256_algorithm() {
        let now = now_seconds().unwrap();
        // Correct key and secret, but signed with HS384.
        let token = token_with(
            claims(now, r#"{"*":["subscribe"]}"#),
            Algorithm::HS384,
            "app-key",
            "app-secret",
        );

        assert!(validate_capability_token(&token, &app()).is_err());
    }

    #[test]
    fn rejects_unsupported_operation() {
        let now = now_seconds().unwrap();
        // "admin" is not one of subscribe/publish/history/presence.
        let token = token_with(
            claims(now, r#"{"*":["subscribe","admin"]}"#),
            Algorithm::HS256,
            "app-key",
            "app-secret",
        );

        assert!(validate_capability_token(&token, &app()).is_err());
    }

    #[test]
    fn rejects_missing_jti() {
        let now = now_seconds().unwrap();
        // The claim is present but empty, which `validate_claims` rejects.
        let token = token_with(
            TestClaims {
                jti: "",
                ..claims(now, r#"{"*":["subscribe"]}"#)
            },
            Algorithm::HS256,
            "app-key",
            "app-secret",
        );

        assert!(validate_capability_token(&token, &app()).is_err());
    }

    #[test]
    fn rejects_lifetime_over_24_hours() {
        let now = now_seconds().unwrap();
        let mut test_claims = claims(now, r#"{"*":["subscribe"]}"#);
        // One second past the maximum allowed `exp - iat`.
        test_claims.exp = now + MAX_TOKEN_LIFETIME_SECONDS + 1;
        let token = token_with(test_claims, Algorithm::HS256, "app-key", "app-secret");

        assert!(validate_capability_token(&token, &app()).is_err());
    }

    #[test]
    fn wildcard_prefix_does_not_match_bare_namespace() {
        let capabilities = parse_capability_map(r#"{"ns:*":["subscribe"]}"#).unwrap();

        // The expectations below pin the pattern matching implemented by
        // `ConnectionCapabilities` (defined outside this file).
        assert!(!capabilities.allows_subscribe("ns"));
        assert!(capabilities.allows_subscribe("ns:one"));
        assert!(!capabilities.allows_subscribe("NS:one"));
    }
}