spark-channel 0.0.1

A generic channel listener for Spark
Documentation
//! # Spark Listener
//!
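//! This crate provides a generic, channel-based mechanism for passing messages between modules: a dispatcher that sends requests and commands, and a server loop that feeds them to a handler. It is not tied to any specific module trait; instead, it offers generic building blocks for implementing such traits in an event-driven architecture.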

use core::any::Any;
use std::{fmt::Debug, marker::PhantomData};

use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};

use crate::callback::CallbackWrapper;

/// A trait for handling module requests and commands.
///
/// An implementor is driven by [`run_module_server`], which calls exactly one of
/// these methods for every `Request` or `Command` message it receives. The
/// `Error` type parameter defaults to `eyre::Error`.
// NOTE(review): `run_module_server` and the dispatcher require
// `Error: std::error::Error`, which `eyre::Report` does not implement (per the
// eyre docs) — confirm the `eyre::Error` default is actually usable with them.
#[async_trait]
pub trait SparkGenericModuleHandler<Message, Response, Error = eyre::Error> {
    /// Handle a request, potentially returning an error.
    ///
    /// When driven by [`run_module_server`], the returned `Result` (success or
    /// error) is sent back to the requester through the request's callback.
    async fn handle_request(&mut self, request: Message) -> Result<Response, Error>;

    /// Handle a command, potentially returning an error.
    ///
    /// Commands get no reply: when driven by [`run_module_server`], an error is
    /// only logged.
    async fn handle_command(&mut self, command: Message) -> Result<(), Error>;
}

/// Trait for cancellation of execution.
///
/// Required of the payload carried by [`SparkGenericModuleMessage::Shutdown`]:
/// [`run_module_server`] calls [`cancel_execution`](Self::cancel_execution) on it
/// when that message arrives and then stops processing messages.
pub trait SparkChannelCancellationTrait {
    /// Cancels the execution of a running task.
    fn cancel_execution(&self);
}

/// The message type for the module dispatcher.
///
/// `Request` and `Command` values are produced by
/// [`SparkGenericModuleDispatcher`]; every variant is consumed by
/// [`run_module_server`].
pub enum SparkGenericModuleMessage<
    Message,
    Response,
    CancellationMessage: SparkChannelCancellationTrait,
    Error,
> {
    /// A request that expects a reply: the message together with a callback that
    /// receives the handler's `Result<Response, Error>`.
    Request(CallbackWrapper<Message, Result<Response, Error>>),

    /// A fire-and-forget command; no reply is sent back.
    Command(Message),

    /// Asks the server to stop: `run_module_server` calls `cancel_execution` on
    /// the payload and then leaves its receive loop.
    Shutdown(CancellationMessage),
}

/// Trait for converting between result types.
///
/// `Self` only selects which conversion to use: `into_result` is an associated
/// function (it takes no receiver) that turns a `Result<Success, Error>` into
/// `Self::Output`. [`SparkGenericModuleDispatcher`] uses it so callers can
/// choose the result type its methods return.
pub trait IntoResult<Success, Error> {
    /// The output type.
    type Output;

    /// Convert a result to the output type.
    fn into_result(result: Result<Success, Error>) -> Self::Output;
}

/// Converts any displayable error into an ad-hoc `eyre::Report`.
///
/// Only the error's `Display` text is kept; the original error value is dropped
/// rather than attached as a source.
impl<T, E> IntoResult<T, E> for eyre::Result<T>
where
    E: std::fmt::Display + Send + Sync + 'static,
{
    type Output = eyre::Result<T>;

    fn into_result(result: Result<T, E>) -> Self::Output {
        match result {
            Ok(value) => Ok(value),
            Err(source) => Err(eyre::eyre!("{}", source)),
        }
    }
}

/// A simple string-backed error type usable as the dispatcher's `Error` parameter.
///
/// The wrapped `String` is the error message. Users may substitute their own
/// error type instead, as long as it meets the bounds required by
/// [`SparkGenericModuleDispatcher`]'s methods (for example `std::error::Error`
/// and conversions from string messages).
#[derive(Debug, Clone)]
pub struct SparkChannelError(pub String);

/// Writes the wrapped message verbatim; width/fill flags from the caller's
/// format spec are not applied.
impl std::fmt::Display for SparkChannelError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.0)
    }
}

// Marker impl on top of the `Debug` and `Display` impls. All default methods are
// kept, so no underlying `source()` error is reported.
impl std::error::Error for SparkChannelError {}

impl From<&'static str> for SparkChannelError {
    fn from(s: &'static str) -> Self {
        SparkChannelError(s.to_string())
    }
}

impl From<String> for SparkChannelError {
    fn from(s: String) -> Self {
        SparkChannelError(s)
    }
}

/// Identity conversion: a `Result<T, SparkChannelError>` already is the desired
/// output, so it is returned unchanged.
impl<T> IntoResult<T, SparkChannelError> for Result<T, SparkChannelError> {
    type Output = Self;

    fn into_result(result: Self) -> Self::Output {
        result
    }
}

/// A dispatcher for module types.
///
/// A cloneable handle that enqueues requests and commands on an `mpsc` channel.
/// It is intended to pair with [`run_module_server`], which consumes the
/// matching `mpsc::Receiver`.
pub struct SparkGenericModuleDispatcher<
    Message,
    Response,
    CancellationMessage: SparkChannelCancellationTrait,
    Error,
> {
    /// The sender for the module dispatcher.
    pub sender:
        mpsc::Sender<SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>>,
    /// Phantom data to mark the Error type
    _error_marker: PhantomData<Error>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would require
// `Message`, `Response`, `CancellationMessage` and `Error` to all be `Clone`,
// although cloning only needs the `mpsc::Sender` (cloneable for any payload)
// and a fresh zero-sized `PhantomData`.
impl<Message, Response, CancellationMessage, Error> Clone
    for SparkGenericModuleDispatcher<Message, Response, CancellationMessage, Error>
where
    CancellationMessage: SparkChannelCancellationTrait,
{
    fn clone(&self) -> Self {
        Self {
            sender: self.sender.clone(),
            _error_marker: PhantomData,
        }
    }
}

impl<Message, Response, CancellationMessage, Error>
    SparkGenericModuleDispatcher<Message, Response, CancellationMessage, Error>
where
    Error: std::error::Error + Send + Sync + 'static,
    CancellationMessage: SparkChannelCancellationTrait,
{
    /// Creates a new module dispatcher around `sender`.
    ///
    /// `sender` should be the sending half of the channel whose receiver is
    /// driven by [`run_module_server`] (or an equivalent receive loop).
    #[must_use]
    pub const fn new(
        sender: mpsc::Sender<
            SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>,
        >,
    ) -> Self {
        Self {
            sender,
            _error_marker: PhantomData,
        }
    }

    /// Sends a request, waits for its reply and converts the outcome with `R`.
    ///
    /// The request is paired with a fresh `oneshot` sender in a
    /// [`CallbackWrapper`], and the matching receiver is awaited for the
    /// handler's `Result<Response, Error>`. A successful response is viewed
    /// through `AsRef<dyn Any + Send>`, downcast to `Resp` and cloned.
    ///
    /// Every failure becomes an `Error` built from a message:
    /// - `"Failed to send request"`: the request could not be enqueued;
    /// - `"Failed to receive response"`: the reply channel closed without a value;
    /// - `"Invalid response type"`: the response does not downcast to `Resp`;
    /// - `"Error handling request: <Debug of handler error>"`: the handler failed.
    async fn internal_request<Req, Resp, R>(&self, request: Req) -> R::Output
    where
        Req: Into<Message> + Send + 'static,
        Resp: 'static + Send + Clone,
        Response: AsRef<dyn Any + Send>,
        R: IntoResult<Resp, Error>,
        Error: Debug + Send + Sync + From<String> + From<&'static str>,
    {
        let (callback_tx, callback_rx) = oneshot::channel();

        let wrapper = CallbackWrapper {
            message: request.into(),
            sender: callback_tx,
        };

        // `mpsc::Sender::send` only fails once the receiving half is closed.
        if self
            .sender
            .send(SparkGenericModuleMessage::Request(wrapper))
            .await
            .is_err()
        {
            return R::into_result(Err(Error::from("Failed to send request")));
        }

        // Receiving fails when the callback sender is dropped without a reply,
        // e.g. the server stopped before handling this request.
        let result = match callback_rx.await {
            Ok(result) => result,
            Err(_) => return R::into_result(Err(Error::from("Failed to receive response"))),
        };

        match result {
            Ok(response) => match response.as_ref().downcast_ref::<Resp>() {
                Some(value) => R::into_result(Ok(value.clone())),
                None => R::into_result(Err(Error::from("Invalid response type"))),
            },
            Err(err) => {
                // The handler's error is flattened into a new `Error` built from
                // its `Debug` text.
                // NOTE(review): this discards the original error value; consider
                // propagating `err` directly if callers need to inspect it.
                let error_message = format!("Error handling request: {:?}", err);
                tracing::debug!("Received error on handler callback: {}", error_message);
                R::into_result(Err(Error::from(error_message)))
            },
        }
    }

    /// Sends a request and waits for the module's typed response.
    ///
    /// # Errors
    ///
    /// Returns an `Error` if the request cannot be enqueued, if no reply
    /// arrives, if the reply does not downcast to `Response`, or if the handler
    /// itself failed. In the last case the handler's error is re-wrapped as
    /// `"Error handling request: {:?}"`.
    pub async fn request<Req>(&self, request: Req) -> Result<Response, Error>
    where
        Req: Into<Message> + Send + 'static,
        Response: 'static + Send + Clone,
        Response: AsRef<dyn Any + Send>,
        Result<Response, Error>: IntoResult<Response, Error, Output = Result<Response, Error>>,
        Error: Debug + Send + Sync + From<String> + From<&'static str>,
    {
        // Delegate to `internal_request` with `Resp = Response`. The
        // `IntoResult` bound above pins the conversion's output to
        // `Result<Response, Error>`.
        self.internal_request::<_, Response, Result<_, Error>>(request)
            .await
    }

    /// Sends a command without waiting for it to be handled (fire and forget).
    ///
    /// Only the enqueueing is awaited; the handler's outcome is never reported
    /// back. `R` selects the returned result type and usually has to be named
    /// explicitly, e.g. `send_command::<_, Result<(), SparkChannelError>>(cmd)`
    /// when `Error` is `SparkChannelError`.
    ///
    /// # Errors
    ///
    /// Yields `Error::from("Failed to send command")`, converted through `R`,
    /// when the message cannot be enqueued.
    pub async fn send_command<C, R>(&self, command: C) -> R::Output
    where
        C: Into<Message> + Send + 'static,
        R: IntoResult<(), Error>,
        Error: From<&'static str>,
    {
        match self
            .sender
            .send(SparkGenericModuleMessage::Command(command.into()))
            .await
        {
            Ok(()) => R::into_result(Ok(())),
            Err(_) => R::into_result(Err(Error::from("Failed to send command"))),
        }
    }
}

/// Runs the module server.
pub async fn run_module_server<Message, Response, CancellationToken, Error, H>(
    mut handler: H,
    mut receiver: mpsc::Receiver<
        SparkGenericModuleMessage<Message, Response, CancellationToken, Error>,
    >,
) where
    Message: Send + 'static,
    Response: Debug + Send + 'static,
    CancellationToken: SparkChannelCancellationTrait + Send + 'static,
    Error: std::error::Error + Send + Sync + 'static,
    H: SparkGenericModuleHandler<Message, Response, Error> + Send,
{
    while let Some(message) = receiver.recv().await {
        match message {
            SparkGenericModuleMessage::Request(wrapper) => {
                let (request, callback) = wrapper.inner_owned();

                // Call the handler and get the Result
                let result = handler.handle_request(request).await;
                if let Err(err) = &result {
                    tracing::error!("Error handling request: {:?}", err);
                }

                if let Err(err) = callback.send(result) {
                    tracing::error!("Failed to send result to callback: {:?}", err);
                }
            },
            SparkGenericModuleMessage::Command(command) => {
                let result = handler.handle_command(command).await;
                if let Err(err) = &result {
                    tracing::error!("Error handling command: {:?}", err);
                }
            },
            SparkGenericModuleMessage::Shutdown(cancellation_token) => {
                cancellation_token.cancel_execution();
                break;
            },
        }
    }
}