sockudo 1.4.0

A simple, fast, and secure WebSocket server for real-time applications.
// src/adapter/handler/validation.rs
use super::types::*;
use super::ConnectionHandler;
use crate::error::{Error, Result};
use crate::app::config::App;
use crate::channel::ChannelType;
use crate::protocol::constants::*;
use crate::utils;
use serde_json::Value;

impl ConnectionHandler {
    pub async fn validate_subscription_request(
        &self,
        app_config: &App,
        request: &SubscriptionRequest,
    ) -> Result<()> {
        if !app_config.enabled {
            return Err(Error::ApplicationDisabled);
        }

        // Validate channel name
        crate::utils::validate_channel_name(app_config, &request.channel).await?;

        // Check if authentication is required and provided
        let requires_auth = request.channel.starts_with("presence-")
            || request.channel.starts_with("private-");

        if requires_auth && request.auth.is_none() {
            return Err(Error::AuthError(
                "Authentication signature required for this channel".into()
            ));
        }

        Ok(())
    }

    pub async fn validate_presence_subscription(
        &self,
        app_config: &App,
        request: &SubscriptionRequest,
    ) -> Result<()> {
        let channel_data = request.channel_data.as_ref()
            .ok_or_else(|| Error::InvalidMessageFormat(
                "Missing channel_data for presence channel".into()
            ))?;

        // Parse and validate user info size
        let user_info_payload: Value = serde_json::from_str(channel_data)
            .map_err(|_| Error::InvalidMessageFormat(
                "Invalid channel_data JSON for presence".into()
            ))?;

        let user_info = user_info_payload.get("user_info").cloned().unwrap_or_default();
        let user_info_size_kb = utils::data_to_bytes_flexible(vec![user_info]) / 1024;

        if let Some(max_size) = app_config.max_presence_member_size_in_kb {
            if user_info_size_kb > max_size as usize {
                return Err(Error::ChannelError(format!(
                    "Presence member data size ({}KB) exceeds limit ({}KB)",
                    user_info_size_kb, max_size
                )));
            }
        }

        // Check member count limit
        if let Some(max_members) = app_config.max_presence_members_per_channel {
            let current_count = self.get_channel_member_count(app_config, &request.channel).await?;
            if current_count >= max_members as usize {
                return Err(Error::OverCapacity);
            }
        }

        Ok(())
    }

    pub async fn validate_client_event(
        &self,
        app_config: &App,
        request: &ClientEventRequest,
    ) -> Result<()> {
        // Check if client events are enabled
        if !app_config.enable_client_messages {
            return Err(Error::ClientEventError(
                "Client events are not enabled for this app".into()
            ));
        }

        // Validate event name
        if !request.event.starts_with(CLIENT_EVENT_PREFIX) {
            return Err(Error::InvalidEventName(
                "Client events must start with 'client-'".into()
            ));
        }

        // Validate event name length
        let max_event_len = app_config.max_event_name_length
            .unwrap_or(DEFAULT_EVENT_NAME_MAX_LENGTH as u32);
        if request.event.len() > max_event_len as usize {
            return Err(Error::InvalidEventName(format!(
                "Event name exceeds maximum length of {}", max_event_len
            )));
        }

        // Validate channel name length
        let max_channel_len = app_config.max_channel_name_length
            .unwrap_or(DEFAULT_CHANNEL_NAME_MAX_LENGTH as u32);
        if request.channel.len() > max_channel_len as usize {
            return Err(Error::InvalidChannelName(format!(
                "Channel name exceeds maximum length of {}", max_channel_len
            )));
        }

        // Validate channel type
        let channel_type = ChannelType::from_name(&request.channel);
        if !matches!(channel_type, ChannelType::Private | ChannelType::Presence) {
            return Err(Error::ClientEventError(
                "Client events can only be sent to private or presence channels".into()
            ));
        }

        // Validate payload size
        if let Some(max_payload_kb) = app_config.max_event_payload_in_kb {
            let payload_size = utils::data_to_bytes_flexible(vec![request.data.clone()]);
            if payload_size > (max_payload_kb as usize * 1024) {
                return Err(Error::ClientEventError(format!(
                    "Event payload size ({} bytes) exceeds limit ({}KB)",
                    payload_size, max_payload_kb
                )));
            }
        }

        Ok(())
    }
}