sockudo-adapter 4.7.0

Connection adapters and horizontal scaling for Sockudo
use super::ConnectionHandler;
use sockudo_core::app::App;
use sockudo_core::capability_token::{TokenAuthContext, validate_capability_token};
use sockudo_core::error::{Error, Result};
use sockudo_core::websocket::SocketId;
use sockudo_protocol::ProtocolVersion;
use sockudo_protocol::messages::{MessageData, PusherMessage};
use sockudo_protocol::protocol_version::{CANONICAL_AUTH_SUCCESS, CANONICAL_TOKEN_EXPIRED};
use sonic_rs::prelude::*;
use sonic_rs::{Value, json};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::time::{Instant, sleep_until};
use tracing::{debug, warn};

/// Code carried in the `token_expired` event when a capability token reaches its `exp`.
pub const TOKEN_EXPIRED_CODE: u32 = 40142;
/// Code carried in the `token_expired` event when a capability token is revoked.
pub const TOKEN_REVOKED_CODE: u32 = 40160;
/// Seconds between sending the expiry notice and force-closing a connection that has not
/// switched to a different token in the meantime.
pub const TOKEN_GRACE_CLOSE_SECONDS: u64 = 30;

/// Parameters for revoking capability tokens.
///
/// At least one of `jti` or `client_id` must be set for a revocation to be accepted.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RevocationRequest {
    /// Revoke the single token with this token id.
    pub jti: Option<String>,
    /// Revoke every token issued to this client id.
    pub client_id: Option<String>,
    /// Unix timestamp (seconds) until which the revocation marker is kept; used only when
    /// `ttl_seconds` is not set.
    pub expires_at: Option<i64>,
    /// Lifetime of the revocation marker in seconds; takes precedence over `expires_at`.
    pub ttl_seconds: Option<u64>,
    /// Free-form reason stored with the revocation marker (defaults to `"revoked"`).
    pub reason: Option<String>,
}

/// Outcome of a revocation request.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RevocationResult {
    /// `true` when a revocation marker was written for the requested `jti`.
    pub revoked_jti: bool,
    /// `true` when a revocation marker was written for the requested `client_id`.
    pub revoked_client_id: bool,
    /// Number of matching connections that were successfully closed.
    pub closed_connections: usize,
}

/// Cache key marking a single capability token (identified by `jti`) of `app_id` as revoked.
fn jti_revocation_key(app_id: &str, jti: &str) -> String {
    ["sockudo:capability_revocation:", app_id, ":jti:", jti].concat()
}

/// Cache key marking every capability token of `client_id` within `app_id` as revoked.
fn client_revocation_key(app_id: &str, client_id: &str) -> String {
    ["sockudo:capability_revocation:", app_id, ":client:", client_id].concat()
}

/// Lifetime (in seconds, never below 1) of the revocation marker for `request`.
///
/// An explicit `ttl_seconds` wins; otherwise the marker lives until `expires_at`
/// (relative to `now`, in unix seconds); with neither set it defaults to one day.
fn ttl_from_request(request: &RevocationRequest, now: i64) -> u64 {
    const DEFAULT_TTL_SECONDS: u64 = 24 * 60 * 60;

    match (request.ttl_seconds, request.expires_at) {
        (Some(explicit_ttl), _) => explicit_ttl.max(1),
        // A deadline already in the past still yields a 1-second marker.
        (None, Some(deadline)) => deadline.saturating_sub(now).max(1) as u64,
        (None, None) => DEFAULT_TTL_SECONDS,
    }
}

fn now_seconds() -> Result<i64> {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|duration| duration.as_secs() as i64)
        .map_err(|_| Error::Internal("system clock is before unix epoch".to_string()))
}

impl ConnectionHandler {
    pub(crate) async fn validate_connection_token(
        &self,
        app_config: &App,
        token: &str,
    ) -> Result<TokenAuthContext> {
        let context = validate_capability_token(token, app_config)?;
        if self
            .cache_manager
            .has(&jti_revocation_key(&app_config.id, &context.jti))
            .await?
            || self
                .cache_manager
                .has(&client_revocation_key(&app_config.id, &context.client_id))
                .await?
        {
            return Err(Error::Auth("capability token has been revoked".to_string()));
        }

        Ok(context)
    }

    pub(crate) async fn apply_connection_token(
        &self,
        socket_id: &SocketId,
        app_config: &App,
        context: TokenAuthContext,
        is_refresh: bool,
    ) -> Result<()> {
        let Some(connection) = self
            .connection_manager
            .get_connection(socket_id, &app_config.id)
            .await
        else {
            return Err(Error::ConnectionNotFound);
        };

        if connection.protocol_version != ProtocolVersion::V2 {
            return Err(Error::Auth(
                "capability tokens require protocol V2".to_string(),
            ));
        }

        if is_refresh {
            if let Some(existing) = connection.get_token_auth_context().await {
                if existing.client_id != context.client_id {
                    return Err(Error::Auth(
                        "capability token refresh cannot change client_id".to_string(),
                    ));
                }
            } else if let Some(existing_user_id) = connection.get_user_id().await
                && existing_user_id != context.client_id
            {
                return Err(Error::Auth(
                    "capability token refresh cannot change authenticated identity".to_string(),
                ));
            }
        }

        connection.set_token_auth_context(context.clone()).await;
        self.connection_manager.add_user(connection.clone()).await?;
        self.schedule_token_expiry(socket_id, app_config, &context);

        Ok(())
    }

    pub(crate) async fn handle_auth_token_refresh(
        &self,
        socket_id: &SocketId,
        app_config: &App,
        message: &PusherMessage,
    ) -> Result<()> {
        let Some(connection) = self
            .connection_manager
            .get_connection(socket_id, &app_config.id)
            .await
        else {
            return Err(Error::ConnectionNotFound);
        };
        if connection.protocol_version != ProtocolVersion::V2 {
            return Err(Error::Auth(
                "sockudo:auth is only supported on protocol V2".to_string(),
            ));
        }

        let token = token_from_message(message)?;
        let context = self.validate_connection_token(app_config, &token).await?;
        self.apply_connection_token(socket_id, app_config, context.clone(), true)
            .await?;

        let response = PusherMessage {
            event: Some(ProtocolVersion::V2.wire_event(CANONICAL_AUTH_SUCCESS)),
            data: Some(MessageData::from(json!({
                "client_id": context.client_id,
                "jti": context.jti,
                "exp": context.exp,
            }))),
            channel: None,
            name: None,
            user_id: None,
            tags: None,
            sequence: None,
            conflation_key: None,
            message_id: None,
            stream_id: None,
            serial: None,
            idempotency_key: None,
            extras: None,
            delta_sequence: None,
            delta_conflation_key: None,
        };
        self.send_message_to_socket(&app_config.id, socket_id, response)
            .await
    }

    fn schedule_token_expiry(
        &self,
        socket_id: &SocketId,
        app_config: &App,
        context: &TokenAuthContext,
    ) {
        let handler = self.clone();
        let socket_id = *socket_id;
        let app_config = app_config.clone();
        let jti = context.jti.clone();
        let exp = context.exp;
        tokio::spawn(async move {
            let now = match now_seconds() {
                Ok(now) => now,
                Err(error) => {
                    warn!(error = %error, "failed to schedule token expiry");
                    return;
                }
            };
            let delay = exp.saturating_sub(now) as u64;
            sleep_until(Instant::now() + Duration::from_secs(delay)).await;

            if !handler
                .connection_still_uses_token(&app_config.id, &socket_id, &jti)
                .await
            {
                return;
            }
            if let Err(error) = handler
                .send_token_expired(&app_config.id, &socket_id, TOKEN_EXPIRED_CODE, "expired")
                .await
            {
                warn!(%socket_id, error = %error, "failed to send token_expired");
            }

            tokio::time::sleep(Duration::from_secs(TOKEN_GRACE_CLOSE_SECONDS)).await;
            if handler
                .connection_still_uses_token(&app_config.id, &socket_id, &jti)
                .await
                && let Err(error) = handler
                    .close_connection(&socket_id, &app_config, 4009, "capability token expired")
                    .await
            {
                warn!(%socket_id, error = %error, "failed to close expired token connection");
            }
        });
    }

    async fn connection_still_uses_token(
        &self,
        app_id: &str,
        socket_id: &SocketId,
        jti: &str,
    ) -> bool {
        let Some(connection) = self
            .connection_manager
            .get_connection(socket_id, app_id)
            .await
        else {
            return false;
        };
        connection
            .get_token_auth_context()
            .await
            .is_some_and(|context| context.jti == jti)
    }

    pub(crate) async fn send_token_expired(
        &self,
        app_id: &str,
        socket_id: &SocketId,
        code: u32,
        reason: &str,
    ) -> Result<()> {
        let message = PusherMessage {
            event: Some(ProtocolVersion::V2.wire_event(CANONICAL_TOKEN_EXPIRED)),
            data: Some(MessageData::from(json!({
                "code": code,
                "reason": reason,
            }))),
            channel: None,
            name: None,
            user_id: None,
            tags: None,
            sequence: None,
            conflation_key: None,
            message_id: None,
            stream_id: None,
            serial: None,
            idempotency_key: None,
            extras: None,
            delta_sequence: None,
            delta_conflation_key: None,
        };
        self.send_message_to_socket(app_id, socket_id, message)
            .await
    }

    pub async fn revoke_capability_tokens(
        &self,
        app_config: &App,
        request: RevocationRequest,
    ) -> Result<RevocationResult> {
        if request.jti.is_none() && request.client_id.is_none() {
            return Err(Error::Auth(
                "revocation requires jti or client_id".to_string(),
            ));
        }
        let now = now_seconds()?;
        let ttl = ttl_from_request(&request, now);
        let reason = request.reason.as_deref().unwrap_or("revoked");
        let mut result = RevocationResult::default();

        if let Some(jti) = request.jti.as_deref() {
            self.cache_manager
                .set(&jti_revocation_key(&app_config.id, jti), reason, ttl)
                .await?;
            result.revoked_jti = true;
        }
        if let Some(client_id) = request.client_id.as_deref() {
            self.cache_manager
                .set(
                    &client_revocation_key(&app_config.id, client_id),
                    reason,
                    ttl,
                )
                .await?;
            result.revoked_client_id = true;
        }

        let sockets = self
            .connection_manager
            .get_all_connections(&app_config.id)
            .await?;
        for socket_id in sockets {
            let Some(connection) = self
                .connection_manager
                .get_connection(&socket_id, &app_config.id)
                .await
            else {
                continue;
            };
            let Some(context) = connection.get_token_auth_context().await else {
                continue;
            };
            let jti_matches = request.jti.as_deref() == Some(context.jti.as_str());
            let client_matches = request.client_id.as_deref() == Some(context.client_id.as_str());
            if !jti_matches && !client_matches {
                continue;
            }

            debug!(%socket_id, "closing revoked capability-token connection");
            let _ = self
                .send_token_expired(&app_config.id, &socket_id, TOKEN_REVOKED_CODE, "revoked")
                .await;
            self.close_connection(&socket_id, app_config, 4009, "capability token revoked")
                .await?;
            result.closed_connections += 1;
        }

        Ok(result)
    }
}

fn token_from_message(message: &PusherMessage) -> Result<String> {
    let Some(data) = &message.data else {
        return Err(Error::InvalidMessageFormat(
            "sockudo:auth requires data.token".to_string(),
        ));
    };

    let value = match data {
        MessageData::Json(value) => value.clone(),
        MessageData::String(text) => sonic_rs::from_str::<Value>(text).map_err(|_| {
            Error::InvalidMessageFormat("sockudo:auth data must be JSON".to_string())
        })?,
        MessageData::Structured { .. } => {
            return Err(Error::InvalidMessageFormat(
                "sockudo:auth data must contain token".to_string(),
            ));
        }
    };

    value
        .get("token")
        .and_then(Value::as_str)
        .map(ToOwned::to_owned)
        .ok_or_else(|| Error::InvalidMessageFormat("sockudo:auth requires data.token".to_string()))
}