// rsub 0.1.0
//
// A high-performance message broker with QUIC transport and pub/sub messaging patterns.
// Documentation
use anyhow::{anyhow, Context, Result};
use bytes::Bytes;
use quinn::{Connection, Endpoint};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, error, info, trace};

use crate::common::auth::AuthConfig;
use crate::common::message::{Message, MessageType};
use crate::common::stream::{RecvStream, SendStream};
use crate::common::tls::create_client_config;

/// A boxed handler invoked with each dispatched [`Message`]; the `Send + Sync` bounds let it be
/// called from the spawned message loop.
type MessageCallback = Box<dyn Fn(Message) + Send + Sync + 'static>;
/// Shared registry of handlers, grouped by the message type they were registered for.
type MessageCallbacks = Arc<RwLock<HashMap<MessageType, Vec<MessageCallback>>>>;
/// Sending half of the channel used to hand a response to the task waiting for it.
type ResponseSender = mpsc::Sender<Message>;
/// Shared map of pending response waiters, keyed by tag. The message loop looks entries up using
/// the first topic of an incoming `Response` message.
type ResponseReceivers = Arc<RwLock<HashMap<String, ResponseSender>>>;

/// A client for the RSub pub/sub system
///
/// Holds one QUIC connection and the two halves of the single bi-directional stream opened on
/// it. Incoming messages are read by a background task that dispatches them to registered
/// callbacks and to pending response waiters.
pub struct RsubClient {
    /// The QUIC connection to the server.
    connection: Connection,
    /// Authentication configuration, serialized to JSON and sent during authentication.
    auth: AuthConfig,
    /// Outgoing half of the bi-directional stream; every message the client sends goes here.
    send_stream: Arc<SendStream>,
    /// Incoming half of the bi-directional stream; read by the background message loop.
    recv_stream: Arc<RecvStream>,
    /// Handlers registered through `on_message`, grouped by message type.
    callbacks: MessageCallbacks,
    /// Waiters registered through `wait_for_response`, keyed by tag.
    response_receivers: ResponseReceivers,
}

impl RsubClient {
    /// Establishes a new connection to an RSub server and authenticates over it.
    ///
    /// The client binds an ephemeral local UDP port in the same address family as `addr`,
    /// performs the QUIC handshake, opens one bi-directional stream, spawns the background
    /// message loop and finally runs the authentication exchange.
    ///
    /// The TLS server name is fixed to `"localhost"`, so the server certificate is expected to
    /// be valid for that name.
    ///
    /// # Arguments
    /// * `addr` - The server address to connect to
    /// * `ca_path` - Path to the CA certificate file
    /// * `auth` - Authentication configuration
    ///
    /// # Returns
    /// A new RsubClient instance if successful
    ///
    /// # Errors
    /// Returns an error if the TLS configuration cannot be built, the local endpoint cannot be
    /// created, the QUIC connection or stream cannot be established, or authentication fails.
    /// When authentication fails the connection is closed before the error is returned.
    pub async fn connect(addr: SocketAddr, ca_path: &str, auth: AuthConfig) -> Result<Self> {
        trace!("Creating new client connection to {}", addr);
        // Initialize TLS config and endpoint
        let client_config = create_client_config(ca_path).context("Failed to create TLS config")?;

        // Bind in the server's address family. An IPv6 wildcard socket is only able to reach
        // IPv4 peers on platforms where dual-stack sockets are enabled by default, so always
        // binding `[::]:0` is not portable.
        let bind_addr = if addr.is_ipv4() {
            SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0))
        } else {
            SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0))
        };
        let mut endpoint = Endpoint::client(bind_addr).context("Failed to create endpoint")?;
        endpoint.set_default_client_config(client_config);

        trace!("Attempting to establish QUIC connection");
        // Establish QUIC connection
        let connection = endpoint
            .connect(addr, "localhost")
            .context("Failed to start connection attempt")?
            .await
            .context("Failed to establish connection")?;
        trace!("QUIC connection established successfully");

        // Open all required streams
        trace!("Opening bi-directional streams");
        let (send, recv) = connection
            .open_bi()
            .await
            .context("Failed to open authentication stream")?;

        let send_stream = SendStream::new("auth_send".to_string(), send);
        let recv_stream = RecvStream::new("auth_recv".to_string(), recv);
        trace!("Bi-directional streams opened successfully");

        // Create client instance
        trace!("Creating new RsubClient instance");
        let client = Self {
            connection,
            auth,
            send_stream: Arc::new(send_stream),
            recv_stream: Arc::new(recv_stream),
            callbacks: Arc::new(RwLock::new(HashMap::new())),
            response_receivers: Arc::new(RwLock::new(HashMap::new())),
        };

        // Start message handling loop (must run before authenticating so the reply is read)
        trace!("Starting message handling loop");
        client.start_message_loop();

        // Authenticate
        trace!("Beginning authentication process");
        if let Err(err) = client.authenticate().await {
            // The spawned message loop still holds the receive stream. Close explicitly so its
            // next read fails and the task exits instead of lingering on a connection that no
            // caller owns any more.
            client
                .connection
                .close(0u32.into(), b"Authentication failed");
            return Err(err);
        }

        Ok(client)
    }

    /// Starts the message handling loop
    fn start_message_loop(&self) {
        let recv_stream = self.recv_stream.clone();
        let callbacks = self.callbacks.clone();
        let response_receivers = self.response_receivers.clone();

        tokio::spawn(async move {
            trace!("Message handling loop started");
            loop {
                match recv_stream.recv_message().await {
                    Ok(message) => {
                        trace!("Received message: {:?}", message);
                        // Handle callbacks for message type
                        if let Some(handlers) = callbacks.read().await.get(&message.message_type) {
                            trace!(
                                "Found {} handlers for message type {:?}",
                                handlers.len(),
                                message.message_type
                            );
                            for handler in handlers {
                                trace!("Executing message handler");
                                handler(message.clone());
                            }
                        }

                        // Handle response receivers
                        if message.message_type == MessageType::Response {
                            if let Some(topic) = message.topics.first() {
                                trace!("Processing response for topic: {}", topic);
                                if let Some(sender) = response_receivers.write().await.remove(topic)
                                {
                                    trace!("Sending response to waiting receiver");
                                    let _ = sender.send(message).await;
                                }
                            }
                        }
                    }
                    Err(e) => {
                        error!("Error reading message: {}", e);
                        trace!("Breaking message loop due to error");
                        break;
                    }
                }
            }
        });
    }

    /// Registers a callback for a specific message type
    pub async fn on_message(
        &self,
        msg_type: MessageType,
        callback: impl Fn(Message) + Send + Sync + 'static,
    ) {
        trace!("Registering callback for message type {:?}", msg_type);
        self.callbacks
            .write()
            .await
            .entry(msg_type)
            .or_default()
            .push(Box::new(callback));
    }

    /// Registers a waiter under `tag` and suspends until a response is delivered to it.
    ///
    /// A previous waiter registered under the same tag is replaced. There is no timeout: the
    /// future stays pending until a message is delivered or the registered sender is dropped.
    ///
    /// # Errors
    /// Returns an error if the channel closes without delivering a message.
    pub async fn wait_for_response(&self, tag: &str) -> Result<Message> {
        trace!("Creating response receiver for tag: {}", tag);
        let (sender, mut receiver) = mpsc::channel(1);
        {
            let mut pending = self.response_receivers.write().await;
            pending.insert(tag.to_string(), sender);
        }

        trace!("Waiting for response with tag: {}", tag);
        let delivered = receiver.recv().await;
        delivered.context("Failed to receive response")
    }

    /// Authenticates the client with the server
    ///
    /// # Returns
    /// * `Ok(())` if authentication is successful
    /// * `Err` if there was an error sending the authentication message
    async fn authenticate(&self) -> Result<()> {
        trace!("Serializing authentication config");
        let auth_bytes =
            serde_json::to_vec(&self.auth).context("Failed to serialize auth config")?;

        trace!("Sending authentication message");
        self.send_stream
            .write_message(Message::new(MessageType::Auth, Vec::new(), auth_bytes))
            .await
            .context("Failed to send auth frame")?;

        // Wait for auth response
        trace!("Waiting for authentication response");
        let response = self.wait_for_response("auth").await?;
        if response.data.is_empty() {
            trace!("Authentication failed: empty response data");
            return Err(anyhow!("Failed to authenticate"));
        }

        if response.data[0] != 200 {
            trace!("Authentication failed: status code {}", response.data[0]);
            return Err(anyhow!("Failed to authenticate"));
        }

        info!("Successfully authenticated with server");
        trace!("Authentication process completed successfully");
        Ok(())
    }

    /// Registers `callback` for incoming messages and asks the server to subscribe to `topics`.
    ///
    /// The callback is registered for the `Message` type as a whole before the subscription
    /// request is sent. It is not filtered by topic on the client side.
    ///
    /// # Arguments
    /// * `topics` - List of topics to subscribe to
    /// * `callback` - Function to handle received messages
    ///
    /// # Errors
    /// Returns an error if the subscription request cannot be written to the stream.
    pub async fn subscribe(
        &self,
        topics: Vec<String>,
        callback: impl Fn(Message) + Send + Sync + 'static,
    ) -> Result<()> {
        // Handler first, so nothing delivered right after the request goes unhandled.
        trace!("Registering callback for subscription messages");
        self.on_message(MessageType::Message, callback).await;

        // Then the request itself.
        trace!("Sending subscription request for topics: {:?}", topics);
        let request = Message::new(MessageType::Subscribe, topics, Vec::new());
        let outcome = self.send_stream.write_message(request).await;
        outcome.context("Failed to send subscription")?;

        Ok(())
    }

    /// Sends `data` to the server as a `Message` addressed to `topics`.
    ///
    /// # Arguments
    /// * `topics` - List of topics to publish to
    /// * `data` - Message payload
    ///
    /// # Errors
    /// Returns an error if the message cannot be written to the stream.
    pub async fn publish(&self, topics: Vec<String>, data: Bytes) -> Result<()> {
        trace!("Publishing message to topics: {:?}", topics);
        let payload = data.to_vec();
        let outgoing = Message::new(MessageType::Message, topics, payload);
        let outcome = self.send_stream.write_message(outgoing).await;
        outcome.context("Failed to send message")?;

        Ok(())
    }

    /// Closes the QUIC connection with error code `0` and a fixed reason string.
    ///
    /// This only issues the close request on the connection handle; it does not wait for
    /// anything afterwards and currently always returns `Ok(())`.
    pub async fn close(&self) -> Result<()> {
        trace!("Closing client connection");
        let error_code = 0u32.into();
        let reason = b"Client closed connection";
        self.connection.close(error_code, reason);
        Ok(())
    }

    /// Returns a reference to the underlying connection
    pub fn connection(&self) -> &Connection {
        &self.connection
    }

    /// Reports whether the connection has a close reason recorded, i.e. is closed.
    pub async fn is_closed(&self) -> bool {
        let close_reason = self.connection.close_reason();
        let is_closed = close_reason.is_some();
        trace!("Connection closed status: {}", is_closed);
        is_closed
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::auth::{AuthType, BasicAuth};
    use tokio::runtime::Runtime;

    // NOTE: this is an end-to-end test. It expects an RSub server listening on `TEST_ADDR`, the
    // CA certificate file `TEST_CERT` to be readable, and the credentials below to be accepted.
    const TEST_ADDR: &str = "127.0.0.1:4222";
    const TEST_CERT: &str = "cert2.der";
    const TEST_TOPIC: &str = "topic1";

    /// Connects, subscribes to a topic, publishes one message to it and closes the client.
    #[test]
    fn test_client_connection() {
        // `init()` panics when a global subscriber is already installed (for example by another
        // test in the same binary). `try_init()` reports that case as an error, which is safe
        // to ignore here.
        let _ = tracing_subscriber::fmt()
            .with_max_level(tracing::Level::TRACE)
            .try_init();

        let runtime = Runtime::new().unwrap();

        runtime.block_on(async {
            // Set up test client
            let client = RsubClient::connect(
                TEST_ADDR.parse().unwrap(),
                TEST_CERT,
                AuthConfig::new(AuthType::Basic(BasicAuth {
                    username: "admin".to_string(),
                    password: "secret".to_string(),
                })),
            )
            .await
            .expect("Failed to connect");

            client
                .subscribe(vec![TEST_TOPIC.to_string()], |message| {
                    debug!("Received message: {:?}", message);
                })
                .await
                .expect("Failed to subscribe");

            client
                .publish(vec![TEST_TOPIC.to_string()], Bytes::from("Hello RSub"))
                .await
                .expect("Failed to publish");

            // Leave time for the published message to be delivered back to the subscription.
            tokio::time::sleep(std::time::Duration::from_secs(15)).await;

            client.close().await.expect("Failed to close client");
        });
    }
}