//! rsub 0.1.0
//!
//! A high-performance message broker with QUIC transport and pub/sub messaging patterns.
//!
//! Documentation
use anyhow::{anyhow, Result};
use bytes::Bytes;
use quinn::{Endpoint, VarInt};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::vec;
use tokio::sync::RwLock;
use tracing::{debug, error, info, trace};

use super::config::{Config, ServerConfig, TlsConfig};
use super::connection::ClientConnection;
use crate::common::auth::{AuthConfig, AuthType};
use crate::common::message::{Message, MessageType};
use crate::common::stream::{RecvStream, SendStream};
use crate::common::tls::create_server_config;

/// Represents the QUIC server for RSUB.
///
/// Every registry below sits behind an `Arc<RwLock<..>>`, so the derived
/// `Clone` produces another handle to the same registries rather than a copy
/// of their contents; `accept_connections` relies on this when it clones the
/// server into a spawned task.
#[derive(Clone)]
pub struct RsubServer {
    /// QUIC endpoint that incoming connections are accepted from.
    endpoint: Endpoint,
    /// Connected clients, keyed by the client id (the peer's remote address
    /// rendered as a string).
    connections: Arc<RwLock<HashMap<String, Arc<ClientConnection>>>>,
    /// Send streams by key. Within this impl the map is only pruned (by
    /// client-id key prefix in `cleanup_client`); nothing visible here inserts
    /// into it.
    send_streams: Arc<RwLock<HashMap<String, Arc<SendStream>>>>,
    /// Receive streams by key. Same usage as `send_streams`: only pruned in
    /// `cleanup_client`.
    recv_streams: Arc<RwLock<HashMap<String, Arc<RecvStream>>>>,
    /// Topic name -> send streams of the clients subscribed to that topic.
    topics: Arc<RwLock<HashMap<String, Vec<Arc<SendStream>>>>>,
    /// Shared server configuration; its `auth` section validates client
    /// credentials in `handle_client`.
    config: Arc<Config>,
}

impl RsubServer {
    /// Boots the RSUB server described by `config` and serves clients until the
    /// endpoint stops yielding connections.
    ///
    /// # Errors
    ///
    /// Fails when the configured host does not parse as an IP address, when the
    /// endpoint cannot be created, or when the accept loop returns an error.
    pub async fn start(config: Arc<Config>) -> Result<()> {
        // Parse the listen address first so a bad host is reported before any
        // socket is bound.
        let listen_addr = SocketAddr::new(config.server.host.parse()?, config.server.port);
        let (endpoint, _certificate) = Self::create_endpoint(&config.server, &config.tls)?;

        info!("Server listening on {}", listen_addr);
        trace!("Server configuration: {:?}", config.server);

        Self::new(endpoint, config).accept_connections().await
    }

    fn new(endpoint: Endpoint, config: Arc<Config>) -> Self {
        trace!("Creating new RsubServer instance");
        RsubServer {
            endpoint,
            connections: Arc::new(RwLock::new(HashMap::new())),
            send_streams: Arc::new(RwLock::new(HashMap::new())),
            recv_streams: Arc::new(RwLock::new(HashMap::new())),
            topics: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    /// Creates the QUIC server endpoint for the host and port in `server_config`.
    ///
    /// Returns the endpoint together with the byte vector that
    /// `create_server_config` yields next to the QUIC configuration
    /// (presumably the server certificate).
    ///
    /// # Errors
    ///
    /// Fails when the TLS/QUIC configuration cannot be built, when the host
    /// does not parse as an IP address, or when the endpoint cannot be created.
    fn create_endpoint(
        server_config: &ServerConfig,
        tls_config: &TlsConfig,
    ) -> Result<(Endpoint, Vec<u8>)> {
        trace!("Creating server endpoint with TLS configuration");
        let (quic_config, certificate) = create_server_config(tls_config)?;
        let bind_addr = SocketAddr::new(server_config.host.parse()?, server_config.port);

        Ok((Endpoint::server(quic_config, bind_addr)?, certificate))
    }

    async fn accept_connections(&self) -> Result<()> {
        while let Some(conn) = self.endpoint.accept().await {
            let connection = conn.await?;
            let client_id = connection.remote_address().to_string();
            trace!("New connection attempt from {}", client_id);
            let client_connection = ClientConnection::new(connection).await?;

            self.connections
                .write()
                .await
                .insert(client_id.clone(), Arc::new(client_connection));

            let server = self.clone();
            let client_id = client_id.clone();

            tokio::spawn(async move {
                if let Err(e) = server.handle_client(client_id.clone()).await {
                    error!("Error handling client {}: {}", client_id, e);
                }
            });
        }
        Ok(())
    }

    async fn handle_client(&self, client_id: String) -> Result<()> {
        trace!("Starting to handle client {}", client_id);
        let connection = self
            .connections
            .read()
            .await
            .get(&client_id)
            .cloned()
            .ok_or_else(|| anyhow!("Client connection not found"))?;

        // Open bi-directional stream for authentication
        let (send, recv) = connection.connection.accept_bi().await?;
        trace!("Opened bi-directional stream for client {}", client_id);

        let send_stream = SendStream::new("auth_send".to_string(), send);
        let recv_stream = RecvStream::new("auth_recv".to_string(), recv);

        // Receive auth message
        let auth_message = recv_stream.recv_message().await?;
        trace!("Received auth message from client {}", client_id);
        if auth_message.message_type != MessageType::Auth {
            return Err(anyhow!("Expected auth message"));
        }

        let auth_config: AuthConfig = serde_json::from_slice(&auth_message.data)?;
        trace!("Parsed auth config from client {}", client_id);

        let is_authenticated = match auth_config.auth_type {
            AuthType::Basic(basic_auth) => self
                .config
                .auth
                .validate(format!("{}:{}", basic_auth.username, basic_auth.password).as_str()),
            AuthType::Bearer(bearer_auth) => self.config.auth.validate(bearer_auth.token.as_str()),
            AuthType::None => true,
        };

        if !is_authenticated {
            trace!("Authentication failed for client {}", client_id);
            let response = Message::new(MessageType::Response, vec!["auth".to_string()], vec![0]);
            send_stream.write_message(response).await?;
            connection
                .connection
                .close(VarInt::from_u32(0), b"Invalid credentials");
            return Err(anyhow!("Invalid credentials"));
        }

        trace!("Client {} successfully authenticated", client_id);
        // Send success response
        let response = Message::new(MessageType::Response, vec!["auth".to_string()], vec![200]);
        send_stream.write_message(response).await?;

        // Handle messages
        while let Ok(message) = recv_stream.recv_message().await {
            trace!(
                "Received message from client {}: {:?}",
                client_id,
                message.message_type
            );
            match message.message_type {
                MessageType::Subscribe => {
                    for topic in message.topics {
                        self.topics
                            .write()
                            .await
                            .entry(topic.clone())
                            .or_default()
                            .push(Arc::new(send_stream.clone()));
                        debug!("Client {} subscribed to topic {}", client_id, topic);
                        trace!(
                            "Added subscription for client {} to topic {}",
                            client_id,
                            topic
                        );
                    }
                }
                MessageType::Message => {
                    for topic in message.topics.clone() {
                        trace!(
                            "Processing message for topic {} from client {}",
                            topic,
                            client_id
                        );
                        let message_clone = message.clone();
                        if let Some(subscribers) = self.topics.read().await.get(&topic) {
                            for subscriber in subscribers {
                                trace!("Forwarding message to subscriber {}", subscriber.id());
                                subscriber.write_message(message_clone.clone()).await?;
                            }
                        }
                    }
                }
                _ => {
                    trace!(
                        "Ignoring unsupported message type from client {}",
                        client_id
                    );
                }
            }
        }

        trace!("Client {} connection ended, starting cleanup", client_id);
        self.cleanup_client(&client_id).await;
        Ok(())
    }

    /// Releases everything the server tracks for `client_id`: its connection
    /// entry, its send/receive stream entries, and its topic subscriptions.
    /// Topics left without subscribers are dropped.
    ///
    /// Streams are attributed to a client by id prefix. Client ids are
    /// `ip:port` strings, so a bare `starts_with` would also match a different
    /// client whose port merely extends this one's digits (`…:5000` vs
    /// `…:50001`). The prefix therefore has to end at a non-digit boundary.
    /// NOTE(review): this assumes no stream id appends a digit directly after
    /// the client id; confirm against wherever stream ids are created.
    async fn cleanup_client(&self, client_id: &str) {
        trace!("Starting cleanup for client {}", client_id);

        let belongs_to_client = |id: &str| match id.strip_prefix(client_id) {
            Some(rest) => !rest.starts_with(|c: char| c.is_ascii_digit()),
            None => false,
        };

        self.connections.write().await.remove(client_id);
        self.send_streams
            .write()
            .await
            .retain(|key, _| !belongs_to_client(key));
        self.recv_streams
            .write()
            .await
            .retain(|key, _| !belongs_to_client(key));

        // Remove the client's streams from every topic and drop topics that
        // end up empty, in a single pass.
        let mut topics = self.topics.write().await;
        topics.retain(|_, streams| {
            streams.retain(|stream| !belongs_to_client(&stream.id()));
            !streams.is_empty()
        });

        debug!("Cleaned up resources for client {}", client_id);
        trace!("Completed cleanup for client {}", client_id);
    }
}