use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::context::ActionContext;
use crate::error::ActionError;
/// A type-erased, asynchronous action handler as stored in [`ActionRegistry`].
///
/// Handlers take the raw request body as [`Bytes`] and produce raw response bytes. This lets
/// handlers with different input/output types live side by side behind
/// `Arc<dyn ActionHandler>`. The `Send + Sync + 'static` bounds allow a handler to be shared
/// through an `Arc` across threads.
pub trait ActionHandler: Send + Sync + 'static {
    /// Runs the action for a single request.
    ///
    /// `context` is the per-request [`ActionContext`] and `body` is the undecoded request
    /// payload. The returned boxed `Send` future resolves to the encoded response body, or to
    /// an [`ActionError`] describing why the action failed.
    fn invoke(
        &self,
        context: ActionContext,
        body: Bytes,
    ) -> Pin<Box<dyn Future<Output = Result<Bytes, ActionError>> + Send>>;
}
/// Adapter that turns a typed async function `Fn(ActionContext, I) -> Fut` into an
/// [`ActionHandler`]. The request body is decoded from JSON into `I`, and the function's
/// `O` output is encoded back to JSON.
struct JsonActionHandler<F, I, O, Fut> {
    /// The wrapped handler function supplied to `ActionRegistry::register`.
    f: F,
    // `I`, `O` and `Fut` appear only in trait bounds, so this marker ties them to the struct
    // without storing values of those types. A `fn(..)` pointer type is always `Send + Sync`,
    // so the marker does not affect whether the adapter is `Send`/`Sync`; that depends only
    // on `F`.
    _phantom: PhantomData<fn(I) -> (O, Fut)>,
}
impl<F, I, O, Fut> ActionHandler for JsonActionHandler<F, I, O, Fut>
where
    F: Fn(ActionContext, I) -> Fut + Send + Sync + 'static,
    I: DeserializeOwned + Send + 'static,
    O: Serialize + Send + 'static,
    Fut: Future<Output = Result<O, ActionError>> + Send + 'static,
{
    /// Decodes `body` as JSON into `I`, calls the wrapped function, and JSON-encodes its output.
    ///
    /// - If the body fails to decode, the call short-circuits with [`ActionError::BadRequest`]
    ///   and the function is never called.
    /// - If encoding the output fails, the result is [`ActionError::Internal`].
    /// - Errors returned by the function itself are passed through unchanged.
    fn invoke(
        &self,
        context: ActionContext,
        body: Bytes,
    ) -> Pin<Box<dyn Future<Output = Result<Bytes, ActionError>> + Send>> {
        // Decode eagerly so malformed input is rejected before the handler is touched.
        let decoded = serde_json::from_slice::<I>(&body);
        let input = match decoded {
            Ok(input) => input,
            Err(err) => {
                let rejection: Result<Bytes, ActionError> =
                    Err(ActionError::BadRequest(err.to_string()));
                return Box::pin(std::future::ready(rejection));
            }
        };

        // The function is called right away; only awaiting its future is deferred.
        let pending = (self.f)(context, input);
        Box::pin(async move {
            let value = pending.await?;
            match serde_json::to_vec(&value) {
                Ok(buf) => Ok(Bytes::from(buf)),
                Err(err) => Err(ActionError::Internal(err.to_string())),
            }
        })
    }
}
/// Maps action names to their type-erased handlers.
///
/// Create one with [`ActionRegistry::new`] (or `Default`), chain
/// [`ActionRegistry::register`] calls to add actions, and look handlers up with
/// [`ActionRegistry::get`].
pub struct ActionRegistry {
    /// Handlers keyed by the `'static` name they were registered under.
    pub(crate) map: HashMap<&'static str, Arc<dyn ActionHandler>>,
}
impl ActionRegistry {
    /// Creates an empty registry with no actions registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Registers `f` as the handler for the action called `name` and returns the updated
    /// registry.
    ///
    /// The request body is decoded from JSON into `I` before `f` is called. The `O` that `f`
    /// resolves to is encoded back to JSON. A body that fails to decode yields
    /// `ActionError::BadRequest`; an output that fails to encode yields
    /// `ActionError::Internal`.
    ///
    /// If `name` is already registered, the new handler replaces the previous one.
    ///
    /// This method consumes `self` so registrations can be chained. The returned registry
    /// must be kept, otherwise the registration is lost.
    #[must_use = "`register` returns the updated registry; dropping it discards the registration"]
    pub fn register<I, O, F, Fut>(mut self, name: &'static str, f: F) -> Self
    where
        I: DeserializeOwned + Send + 'static,
        O: Serialize + Send + 'static,
        F: Fn(ActionContext, I) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<O, ActionError>> + Send + 'static,
    {
        let handler: Arc<dyn ActionHandler> = Arc::new(JsonActionHandler {
            f,
            _phantom: PhantomData,
        });
        // `HashMap::insert` overwrites any existing entry, so the latest registration wins.
        self.map.insert(name, handler);
        self
    }

    /// Looks up the handler registered under `name`.
    ///
    /// On a hit, returns a clone of the stored `Arc` (a reference-count bump, not a copy of
    /// the handler), so the caller can keep it independently of the registry. Returns `None`
    /// if no action has that name.
    pub fn get(&self, name: &str) -> Option<Arc<dyn ActionHandler>> {
        self.map.get(name).cloned()
    }
}
impl Default for ActionRegistry {
fn default() -> Self {
Self::new()
}
}