use core::any::Any;
use std::{fmt::Debug, marker::PhantomData};
use async_trait::async_trait;
use tokio::sync::{mpsc, oneshot};
use crate::callback::CallbackWrapper;
/// Async handler for the messages served by [`run_module_server`].
///
/// Declared with `#[async_trait]`, so implementations are expected to carry the same attribute.
/// `Error` defaults to [`eyre::Error`] when it is not specified.
#[async_trait]
pub trait SparkGenericModuleHandler<Message, Response, Error = eyre::Error> {
/// Handles a request message. [`run_module_server`] logs an `Err` and then sends the whole
/// result back through the request's callback.
async fn handle_request(&mut self, request: Message) -> Result<Response, Error>;
/// Handles a command message. [`run_module_server`] only logs an `Err`; nothing is returned
/// to the sender.
async fn handle_command(&mut self, command: Message) -> Result<(), Error>;
}
/// Hook used to signal cancellation when a module is shut down.
///
/// [`run_module_server`] calls `cancel_execution` on the payload of a
/// [`SparkGenericModuleMessage::Shutdown`] message and then stops its receive loop.
pub trait SparkChannelCancellationTrait {
/// Signals cancellation. The concrete effect is defined by the implementor.
fn cancel_execution(&self);
}
/// Message accepted by a module's server loop ([`run_module_server`]).
pub enum SparkGenericModuleMessage<
Message,
Response,
CancellationMessage: SparkChannelCancellationTrait,
Error,
> {
/// Request whose `Result<Response, Error>` is sent back through the wrapped callback.
Request(CallbackWrapper<Message, Result<Response, Error>>),
/// Fire-and-forget command; the handler's result is not returned to the sender.
Command(Message),
/// Shutdown request; its payload's `cancel_execution` is called before the loop exits.
Shutdown(CancellationMessage),
}
/// Converts a `Result<Success, Error>` into an implementor-chosen output type.
///
/// The dispatcher's request and command methods are generic over this trait, so the caller
/// picks the shape of the returned value. For example, the `eyre::Result` implementation turns
/// the error into an `eyre::Report`.
pub trait IntoResult<Success, Error> {
/// Type produced by [`IntoResult::into_result`].
type Output;
/// Performs the conversion.
fn into_result(result: Result<Success, Error>) -> Self::Output;
}
/// Lets any displayable error be folded into an [`eyre::Result`].
///
/// The incoming error is rendered through its `Display` implementation and wrapped in a fresh
/// [`eyre::Report`]. Its concrete type and any source chain are not carried over.
impl<T, E> IntoResult<T, E> for eyre::Result<T>
where
    E: std::fmt::Display + Send + Sync + 'static,
{
    type Output = eyre::Result<T>;

    fn into_result(result: Result<T, E>) -> Self::Output {
        match result {
            Ok(value) => Ok(value),
            Err(cause) => Err(eyre::eyre!("{}", cause)),
        }
    }
}
/// String-backed error type used by the Spark channel plumbing.
///
/// `Display` prints exactly the wrapped message. Equality compares the messages, which lets
/// callers `assert_eq!` on `Result<_, SparkChannelError>` values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparkChannelError(pub String);

impl std::fmt::Display for SparkChannelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Writes the raw message; formatter flags such as width are intentionally not applied.
        f.write_str(&self.0)
    }
}

impl std::error::Error for SparkChannelError {}

// Accepts any borrowed string, not only `'static` literals. This still satisfies
// `From<&'static str>` bounds used elsewhere.
impl From<&str> for SparkChannelError {
    fn from(s: &str) -> Self {
        SparkChannelError(s.to_string())
    }
}

impl From<String> for SparkChannelError {
    fn from(s: String) -> Self {
        SparkChannelError(s)
    }
}
/// Identity conversion: a `Result` that already uses [`SparkChannelError`] is returned as-is.
impl<T> IntoResult<T, SparkChannelError> for Result<T, SparkChannelError> {
    type Output = Self;

    fn into_result(outcome: Self) -> Self::Output {
        outcome
    }
}
/// Client-side handle that sends requests and commands into a module's message channel.
///
/// The receiving half of the channel can be served by [`run_module_server`].
///
/// `Clone` is implemented by hand rather than derived. Cloning only duplicates the
/// [`mpsc::Sender`], so it must not require `Message`, `Response`, `CancellationMessage` or
/// `Error` to be `Clone`, which `#[derive(Clone)]` would demand.
pub struct SparkGenericModuleDispatcher<
    Message,
    Response,
    CancellationMessage: SparkChannelCancellationTrait,
    Error,
> {
    /// Sending half of the module's message channel.
    pub sender:
        mpsc::Sender<SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>>,
    // `Error` already appears in `sender`'s type; this marker carries no data.
    _error_marker: PhantomData<Error>,
}

impl<Message, Response, CancellationMessage, Error> Clone
    for SparkGenericModuleDispatcher<Message, Response, CancellationMessage, Error>
where
    CancellationMessage: SparkChannelCancellationTrait,
{
    fn clone(&self) -> Self {
        Self {
            sender: self.sender.clone(),
            _error_marker: PhantomData,
        }
    }
}
impl<Message, Response, CancellationMessage, Error>
    SparkGenericModuleDispatcher<Message, Response, CancellationMessage, Error>
where
    Error: std::error::Error + Send + Sync + 'static,
    CancellationMessage: SparkChannelCancellationTrait,
{
    /// Creates a dispatcher that sends every message through `sender`.
    ///
    /// The matching receiver can be passed to [`run_module_server`].
    #[must_use]
    pub const fn new(
        sender: mpsc::Sender<
            SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>,
        >,
    ) -> Self {
        Self {
            sender,
            _error_marker: PhantomData,
        }
    }

    /// Performs a request/response exchange and converts its outcome through `R`.
    ///
    /// All error cases are produced by `round_trip`. The resulting `Result<Resp, Error>` is
    /// handed to [`IntoResult::into_result`] exactly once.
    async fn internal_request<Req, Resp, R>(&self, request: Req) -> R::Output
    where
        Req: Into<Message> + Send + 'static,
        Resp: 'static + Send + Clone,
        Response: AsRef<dyn Any + Send>,
        R: IntoResult<Resp, Error>,
        Error: Debug + Send + Sync + From<String> + From<&'static str>,
    {
        R::into_result(self.round_trip::<Req, Resp>(request).await)
    }

    /// Sends `request` as a [`SparkGenericModuleMessage::Request`] and waits for the reply.
    ///
    /// A successful response is viewed through `AsRef<dyn Any + Send>`, downcast to `Resp`,
    /// and cloned out.
    ///
    /// # Errors
    ///
    /// Every failure is reported as an `Error` built from a message string:
    /// - `"Failed to send request"` when the channel is closed;
    /// - `"Failed to receive response"` when the reply sender is dropped without answering;
    /// - `"Invalid response type"` when `response.as_ref()` is not a `Resp`;
    /// - `"Error handling request: {err:?}"` when the handler returned `Err(err)`. Only the
    ///   `Debug` text of the original error survives.
    async fn round_trip<Req, Resp>(&self, request: Req) -> Result<Resp, Error>
    where
        Req: Into<Message> + Send + 'static,
        Resp: 'static + Send + Clone,
        Response: AsRef<dyn Any + Send>,
        Error: Debug + Send + Sync + From<String> + From<&'static str>,
    {
        let (callback_tx, callback_rx) = oneshot::channel();
        let wrapper = CallbackWrapper {
            message: request.into(),
            sender: callback_tx,
        };
        self.sender
            .send(SparkGenericModuleMessage::Request(wrapper))
            .await
            .map_err(|_| Error::from("Failed to send request"))?;
        let handler_result = callback_rx
            .await
            .map_err(|_| Error::from("Failed to receive response"))?;
        match handler_result {
            Ok(response) => response
                .as_ref()
                .downcast_ref::<Resp>()
                .cloned()
                .ok_or_else(|| Error::from("Invalid response type")),
            Err(err) => {
                let error_message = format!("Error handling request: {:?}", err);
                tracing::debug!("Received error on handler callback: {}", error_message);
                Err(Error::from(error_message))
            },
        }
    }

    /// Sends `request` to the module and returns the handler's `Response`.
    ///
    /// NOTE(review): the response is routed through the `AsRef<dyn Any + Send>` downcast with
    /// `Resp = Response`. It therefore only succeeds when `as_ref()` yields the response value
    /// itself, not e.g. the contents of a wrapper type. Confirm this holds for the `Response`
    /// types in use.
    ///
    /// # Errors
    ///
    /// Returns an `Error` in any of these cases:
    /// - the channel is closed (`"Failed to send request"`);
    /// - the reply is never sent (`"Failed to receive response"`);
    /// - the downcast fails (`"Invalid response type"`);
    /// - the handler reports an error (`"Error handling request: …"`).
    pub async fn request<Req>(&self, request: Req) -> Result<Response, Error>
    where
        Req: Into<Message> + Send + 'static,
        Response: 'static + Send + Clone,
        Response: AsRef<dyn Any + Send>,
        Result<Response, Error>: IntoResult<Response, Error, Output = Result<Response, Error>>,
        Error: Debug + Send + Sync + From<String> + From<&'static str>,
    {
        self.internal_request::<_, Response, Result<_, Error>>(request)
            .await
    }

    /// Enqueues `command` as a [`SparkGenericModuleMessage::Command`] and converts the
    /// outcome through `R`.
    ///
    /// This only waits for the message to be accepted by the channel, not for it to be handled.
    ///
    /// # Errors
    ///
    /// Yields `Error::from("Failed to send command")` (converted by `R`) when the channel is
    /// closed. Errors returned by the handler are not reported back here.
    pub async fn send_command<C, R>(&self, command: C) -> R::Output
    where
        C: Into<Message> + Send + 'static,
        R: IntoResult<(), Error>,
        Error: From<&'static str>,
    {
        let outcome = self
            .sender
            .send(SparkGenericModuleMessage::Command(command.into()))
            .await
            .map_err(|_| Error::from("Failed to send command"));
        R::into_result(outcome)
    }
}
pub async fn run_module_server<Message, Response, CancellationToken, Error, H>(
mut handler: H,
mut receiver: mpsc::Receiver<
SparkGenericModuleMessage<Message, Response, CancellationToken, Error>,
>,
) where
Message: Send + 'static,
Response: Debug + Send + 'static,
CancellationToken: SparkChannelCancellationTrait + Send + 'static,
Error: std::error::Error + Send + Sync + 'static,
H: SparkGenericModuleHandler<Message, Response, Error> + Send,
{
while let Some(message) = receiver.recv().await {
match message {
SparkGenericModuleMessage::Request(wrapper) => {
let (request, callback) = wrapper.inner_owned();
let result = handler.handle_request(request).await;
if let Err(err) = &result {
tracing::error!("Error handling request: {:?}", err);
}
if let Err(err) = callback.send(result) {
tracing::error!("Failed to send result to callback: {:?}", err);
}
},
SparkGenericModuleMessage::Command(command) => {
let result = handler.handle_command(command).await;
if let Err(err) = &result {
tracing::error!("Error handling command: {:?}", err);
}
},
SparkGenericModuleMessage::Shutdown(cancellation_token) => {
cancellation_token.cancel_execution();
break;
},
}
}
}