use crate::codec::{Codec, sealed};
use crate::error::{BoxError, CodecError};
use crate::loop_result::LoopResult;
use crate::task::{
BranchOutputs, BytesFuture, CoreTask, TaskMetadata, UntypedCoreTask,
to_heterogeneous_join_task_arc,
};
use bytes::Bytes;
use std::collections::HashMap;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
/// Type-erased constructor stored for each registered task id.
///
/// Every call yields a new [`UntypedCoreTask`] wrapper. The `Send + Sync`
/// bounds are what allow a registry holding these factories to itself be
/// `Send`/`Sync`.
pub type TaskFactory = Box<dyn Fn() -> UntypedCoreTask + Send + Sync + 'static>;

/// One registry slot: how to build the task, plus its descriptive metadata.
pub struct TaskEntry {
    /// Builds a fresh byte-level wrapper each time the task is requested.
    factory: TaskFactory,
    /// Metadata supplied at registration (or replaced via `set_metadata`).
    metadata: TaskMetadata,
}
/// Maps task ids to type-erased task factories and their [`TaskMetadata`].
///
/// Tasks registered through the typed `register*` helpers are wrapped so
/// they accept and produce raw [`Bytes`], decoding and encoding through the
/// codec supplied at registration time.
pub struct TaskRegistry {
    /// Task id -> factory + metadata.
    tasks: HashMap<String, TaskEntry>,
}

impl Default for TaskRegistry {
    /// Same as [`TaskRegistry::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for TaskRegistry {
    /// Shows the registered task ids, sorted so the output is deterministic
    /// despite `HashMap` iteration order. Factories are opaque closures and
    /// metadata is omitted, hence `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut ids: Vec<&str> = self.tasks.keys().map(String::as_str).collect();
        ids.sort_unstable();
        f.debug_struct("TaskRegistry")
            .field("task_ids", &ids)
            .finish_non_exhaustive()
    }
}
impl TaskRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tasks: HashMap::new(),
        }
    }

    /// Moves every entry of `other` into `self`.
    ///
    /// On an id collision the entry already present in `self` is kept and the
    /// one from `other` is dropped. Note this differs from the `register*`
    /// methods, which replace an existing entry with the same id.
    pub fn merge(&mut self, other: Self) {
        for (id, entry) in other.tasks {
            self.tasks.entry(id).or_insert(entry);
        }
    }

    /// Registers a typed [`CoreTask`] under `id` with default metadata.
    ///
    /// `codec` decodes incoming bytes into `T::Input` and encodes `T::Output`
    /// back to bytes. Any task already registered under `id` is replaced.
    pub fn register<T, C>(&mut self, id: &str, codec: Arc<C>, task: T)
    where
        T: CoreTask + 'static,
        T::Input: Send + 'static,
        T::Output: Send + 'static,
        T::Future: Send + 'static,
        C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
    {
        self.register_with_metadata(id, codec, task, TaskMetadata::default());
    }

    /// Like [`TaskRegistry::register`], but with explicit `metadata`.
    pub fn register_with_metadata<T, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        task: T,
        metadata: TaskMetadata,
    ) where
        T: CoreTask + 'static,
        T::Input: Send + 'static,
        T::Output: Send + 'static,
        T::Future: Send + 'static,
        C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
    {
        self.register_task_arc(id, codec, Arc::new(task), metadata);
    }

    /// Builds `T` from `deps` and registers it under `T::task_id()` with
    /// `T::metadata()`.
    ///
    /// # Errors
    ///
    /// Returns the list produced by `T::verify_deps(deps)` when it is
    /// non-empty; in that case nothing is constructed or registered.
    pub fn register_from_deps<T, C>(
        &mut self,
        codec: Arc<C>,
        deps: &crate::deps::Deps,
    ) -> Result<(), Vec<crate::deps::MissingDep>>
    where
        T: crate::deps::DepsInjectable,
        T::Input: Send + 'static,
        T::Output: Send + 'static,
        T::Future: Send + 'static,
        C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
    {
        let missing = T::verify_deps(deps);
        if !missing.is_empty() {
            return Err(missing);
        }
        let task = T::from_deps(deps);
        self.register_with_metadata(T::task_id(), codec, task, T::metadata());
        Ok(())
    }

    /// Registers an already-shared task. Every wrapper the factory produces
    /// holds a clone of the same `Arc<T>`, so the task itself is never copied.
    pub(crate) fn register_task_arc<T, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        task: Arc<T>,
        metadata: TaskMetadata,
    ) where
        T: CoreTask + 'static,
        T::Input: Send + 'static,
        T::Output: Send + 'static,
        T::Future: Send + 'static,
        C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
    {
        let task_id = id.to_string();
        let factory = Box::new(move || -> UntypedCoreTask {
            Box::new(TaskWrapper {
                task: Arc::clone(&task),
                codec: Arc::clone(&codec),
                task_id: task_id.clone(),
            })
        });
        self.insert_entry(id, factory, metadata);
    }

    /// Registers an async function `func: Fn(I) -> Fut` under `id` with
    /// default metadata. `codec` decodes the input bytes into `I` and encodes
    /// the `O` result. Any task already registered under `id` is replaced.
    pub fn register_fn<I, O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: F)
    where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        I: Send + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
    {
        self.register_fn_with_metadata(id, codec, func, TaskMetadata::default());
    }

    /// Like [`TaskRegistry::register_fn`], but with explicit `metadata`.
    pub fn register_fn_with_metadata<I, O, F, Fut, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        func: F,
        metadata: TaskMetadata,
    ) where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        I: Send + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
    {
        self.register_fn_arc(id, codec, Arc::new(func), metadata);
    }

    /// Registers an already-shared async function; each produced wrapper
    /// clones the same `Arc<F>`.
    pub(crate) fn register_fn_arc<I, O, F, Fut, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        func: Arc<F>,
        metadata: TaskMetadata,
    ) where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        I: Send + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
    {
        let task_id = id.to_string();
        let factory = Box::new(move || -> UntypedCoreTask {
            Box::new(FnTaskWrapper {
                func: Arc::clone(&func),
                codec: Arc::clone(&codec),
                task_id: task_id.clone(),
                _phantom: PhantomData,
            })
        });
        self.insert_entry(id, factory, metadata);
    }

    /// Registers a task whose output is a [`LoopResult<O>`].
    ///
    /// Wrapping is delegated to `crate::task::wrap_core_loop_task`.
    /// NOTE(review): presumably this emits the same loop envelope as
    /// `register_loop_fn_arc` — confirm against that helper.
    pub(crate) fn register_loop_task_arc<T, O, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        task: Arc<T>,
        metadata: TaskMetadata,
    ) where
        T: CoreTask<Output = LoopResult<O>> + 'static,
        T::Input: Send + 'static,
        O: Send + 'static,
        T::Future: Send + 'static,
        C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<O> + 'static,
    {
        use crate::task::wrap_core_loop_task;
        let task_id = id.to_string();
        let factory = Box::new(move || -> UntypedCoreTask {
            wrap_core_loop_task(&task_id, Arc::clone(&task), Arc::clone(&codec))
        });
        self.insert_entry(id, factory, metadata);
    }

    /// Registers an async function returning a [`LoopResult<O>`]. The wrapper
    /// encodes the inner `O` with `codec` and packs it together with the loop
    /// decision via `crate::codec::encode_loop_envelope`.
    pub(crate) fn register_loop_fn_arc<I, O, F, Fut, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        func: Arc<F>,
        metadata: TaskMetadata,
    ) where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        I: Send + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<LoopResult<O>, BoxError>> + Send + 'static,
        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
    {
        let task_id = id.to_string();
        let factory = Box::new(move || -> UntypedCoreTask {
            Box::new(LoopFnTaskWrapper {
                func: Arc::clone(&func),
                codec: Arc::clone(&codec),
                task_id: task_id.clone(),
                _phantom: PhantomData,
            })
        });
        self.insert_entry(id, factory, metadata);
    }

    /// Registers a join function that receives the branch outputs as
    /// [`BranchOutputs<C>`], with default metadata. Any task already
    /// registered under `id` is replaced.
    pub fn register_join<O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: F)
    where
        F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
        C: Codec
            + sealed::EncodeValue<O>
            + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
            + Send
            + Sync
            + 'static,
    {
        self.register_join_with_metadata(id, codec, func, TaskMetadata::default());
    }

    /// Like [`TaskRegistry::register_join`], but with explicit `metadata`.
    pub fn register_join_with_metadata<O, F, Fut, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        func: F,
        metadata: TaskMetadata,
    ) where
        F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
        C: Codec
            + sealed::EncodeValue<O>
            + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
            + Send
            + Sync
            + 'static,
    {
        self.register_arc_join(id, codec, Arc::new(func), metadata);
    }

    /// Registers an already-shared join function; wrapping is delegated to
    /// `crate::task::to_heterogeneous_join_task_arc`.
    pub(crate) fn register_arc_join<O, F, Fut, C>(
        &mut self,
        id: &str,
        codec: Arc<C>,
        func: Arc<F>,
        metadata: TaskMetadata,
    ) where
        F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
        O: Send + 'static,
        Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
        C: Codec
            + sealed::EncodeValue<O>
            + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
            + Send
            + Sync
            + 'static,
    {
        let task_id = id.to_string();
        let factory = Box::new(move || -> UntypedCoreTask {
            to_heterogeneous_join_task_arc(&task_id, Arc::clone(&func), Arc::clone(&codec))
        });
        self.insert_entry(id, factory, metadata);
    }

    /// Stores `factory` and `metadata` under `id`, replacing any previous
    /// entry for that id. Shared by every `register*` path.
    fn insert_entry(&mut self, id: &str, factory: TaskFactory, metadata: TaskMetadata) {
        self.tasks
            .insert(id.to_string(), TaskEntry { factory, metadata });
    }

    /// Builds a task instance for `id` by invoking its factory.
    ///
    /// Returns `None` if `id` is not registered. Each call yields a new
    /// wrapper; the underlying task/function is shared through an `Arc`.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<UntypedCoreTask> {
        self.tasks.get(id).map(|entry| (entry.factory)())
    }

    /// Returns the metadata registered for `id`, if any.
    #[must_use]
    pub fn get_metadata(&self, id: &str) -> Option<&TaskMetadata> {
        self.tasks.get(id).map(|entry| &entry.metadata)
    }

    /// Combines [`TaskRegistry::get`] and [`TaskRegistry::get_metadata`]
    /// in a single lookup.
    #[must_use]
    pub fn get_with_metadata(&self, id: &str) -> Option<(UntypedCoreTask, &TaskMetadata)> {
        self.tasks
            .get(id)
            .map(|entry| ((entry.factory)(), &entry.metadata))
    }

    /// Replaces the metadata of an existing task.
    ///
    /// Returns `true` if `id` was registered, or `false` (and changes
    /// nothing) otherwise.
    pub fn set_metadata(&mut self, id: &str, metadata: TaskMetadata) -> bool {
        if let Some(entry) = self.tasks.get_mut(id) {
            entry.metadata = metadata;
            true
        } else {
            false
        }
    }

    /// Returns whether a task is registered under `id`.
    #[must_use]
    pub fn contains(&self, id: &str) -> bool {
        self.tasks.contains_key(id)
    }

    /// Number of registered tasks.
    #[must_use]
    pub fn len(&self) -> usize {
        self.tasks.len()
    }

    /// Returns `true` if no tasks are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.tasks.is_empty()
    }

    /// Iterates over registered task ids in unspecified (`HashMap`) order.
    pub fn task_ids(&self) -> impl Iterator<Item = &str> {
        self.tasks.keys().map(String::as_str)
    }

    /// Starts a [`RegistryBuilder`] that hands `codec` to every task it
    /// registers. The builder does nothing until `build` is called, so
    /// dropping it is almost certainly a mistake.
    #[must_use]
    pub fn with_codec<C>(codec: Arc<C>) -> RegistryBuilder<C>
    where
        C: Codec,
    {
        RegistryBuilder {
            codec,
            registry: TaskRegistry::new(),
        }
    }
}
pub struct RegistryBuilder<C> {
codec: Arc<C>,
registry: TaskRegistry,
}
impl<C: Codec> RegistryBuilder<C> {
#[must_use]
pub fn register<T>(mut self, id: &str, task: T) -> Self
where
T: CoreTask + 'static,
T::Input: Send + 'static,
T::Output: Send + 'static,
T::Future: Send + 'static,
C: sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
{
self.registry.register(id, Arc::clone(&self.codec), task);
self
}
#[must_use]
pub fn register_with_metadata<T>(mut self, id: &str, task: T, metadata: TaskMetadata) -> Self
where
T: CoreTask + 'static,
T::Input: Send + 'static,
T::Output: Send + 'static,
T::Future: Send + 'static,
C: sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output> + 'static,
{
self.registry
.register_with_metadata(id, Arc::clone(&self.codec), task, metadata);
self
}
#[must_use]
pub fn register_fn<I, O, F, Fut>(mut self, id: &str, func: F) -> Self
where
F: Fn(I) -> Fut + Send + Sync + 'static,
I: Send + 'static,
O: Send + 'static,
Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
C: sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
{
self.registry.register_fn(id, Arc::clone(&self.codec), func);
self
}
#[must_use]
pub fn register_fn_with_metadata<I, O, F, Fut>(
mut self,
id: &str,
func: F,
metadata: TaskMetadata,
) -> Self
where
F: Fn(I) -> Fut + Send + Sync + 'static,
I: Send + 'static,
O: Send + 'static,
Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
C: sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
{
self.registry
.register_fn_with_metadata(id, Arc::clone(&self.codec), func, metadata);
self
}
#[must_use]
pub fn register_join<O, F, Fut>(mut self, id: &str, func: F) -> Self
where
F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
O: Send + 'static,
Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
C: sealed::EncodeValue<O>
+ sealed::DecodeValue<crate::branch_results::NamedBranchResults>
+ Send
+ Sync
+ 'static,
{
self.registry
.register_join(id, Arc::clone(&self.codec), func);
self
}
#[must_use]
pub fn register_join_with_metadata<O, F, Fut>(
mut self,
id: &str,
func: F,
metadata: TaskMetadata,
) -> Self
where
F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
O: Send + 'static,
Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
C: sealed::EncodeValue<O>
+ sealed::DecodeValue<crate::branch_results::NamedBranchResults>
+ Send
+ Sync
+ 'static,
{
self.registry
.register_join_with_metadata(id, Arc::clone(&self.codec), func, metadata);
self
}
#[must_use]
pub fn build(self) -> TaskRegistry {
self.registry
}
}
/// Adapts an async function `Fn(I) -> Fut` into a byte-level [`CoreTask`]:
/// decode input with `codec`, await `func`, encode the result.
struct FnTaskWrapper<F, I, O, C> {
    /// User function; every wrapper built by the same factory shares it.
    func: Arc<F>,
    /// Codec used for both decoding input and encoding output.
    codec: Arc<C>,
    /// Registered id, attached to any codec error.
    task_id: String,
    /// Records `I`/`O` without storing them; a `fn(I) -> O` marker is always
    /// `Send + Sync`, whatever `I` and `O` are.
    _phantom: PhantomData<fn(I) -> O>,
}

impl<F, I, O, Fut, C> CoreTask for FnTaskWrapper<F, I, O, C>
where
    F: Fn(I) -> Fut + Send + Sync + 'static,
    I: Send + 'static,
    O: Send + 'static,
    Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
{
    type Input = Bytes;
    type Output = Bytes;
    type Future = BytesFuture;

    /// Decodes `input` as `I`, runs the function, and encodes its output.
    /// Codec failures are reported as [`CodecError`] tagged with the task id;
    /// errors from the function itself are propagated unchanged.
    fn run(&self, input: Bytes) -> Self::Future {
        let codec = Arc::clone(&self.codec);
        let func = Arc::clone(&self.func);
        let task_id = self.task_id.clone();
        BytesFuture::new(async move {
            let arg = match codec.decode::<I>(input) {
                Ok(value) => value,
                Err(source) => {
                    let err: BoxError = Box::new(CodecError::DecodeFailed {
                        task_id: crate::TaskId::from(task_id.as_str()),
                        expected_type: std::any::type_name::<I>(),
                        source,
                    });
                    return Err(err);
                }
            };
            let result = func(arg).await?;
            match codec.encode(&result) {
                Ok(bytes) => Ok(bytes),
                Err(source) => {
                    let err: BoxError = Box::new(CodecError::EncodeFailed {
                        task_id: crate::TaskId::from(task_id.as_str()),
                        source,
                    });
                    Err(err)
                }
            }
        })
    }
}
/// Adapts an async function returning [`LoopResult<O>`] into a byte-level
/// [`CoreTask`] whose output is a loop envelope (decision + encoded `O`).
struct LoopFnTaskWrapper<F, I, O, C> {
    /// User function; every wrapper built by the same factory shares it.
    func: Arc<F>,
    /// Codec used to decode `I` and encode the inner `O`.
    codec: Arc<C>,
    /// Registered id, attached to any codec error.
    task_id: String,
    /// Type marker only; `fn(I) -> O` is always `Send + Sync`.
    _phantom: PhantomData<fn(I) -> O>,
}

impl<F, I, O, Fut, C> CoreTask for LoopFnTaskWrapper<F, I, O, C>
where
    F: Fn(I) -> Fut + Send + Sync + 'static,
    I: Send + 'static,
    O: Send + 'static,
    Fut: Future<Output = Result<LoopResult<O>, BoxError>> + Send + 'static,
    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
{
    type Input = Bytes;
    type Output = Bytes;
    type Future = BytesFuture;

    /// Decodes `input` as `I`, runs the function, splits the [`LoopResult`]
    /// into its decision and inner value, encodes the inner value, and wraps
    /// both with `crate::codec::encode_loop_envelope`.
    fn run(&self, input: Bytes) -> Self::Future {
        let codec = Arc::clone(&self.codec);
        let func = Arc::clone(&self.func);
        let task_id = self.task_id.clone();
        BytesFuture::new(async move {
            let arg = match codec.decode::<I>(input) {
                Ok(value) => value,
                Err(source) => {
                    let err: BoxError = Box::new(CodecError::DecodeFailed {
                        task_id: crate::TaskId::from(task_id.as_str()),
                        expected_type: std::any::type_name::<I>(),
                        source,
                    });
                    return Err(err);
                }
            };
            let (decision, payload) = func(arg).await?.into_decision();
            let payload_bytes = match codec.encode(&payload) {
                Ok(bytes) => bytes,
                Err(source) => {
                    let err: BoxError = Box::new(CodecError::EncodeFailed {
                        task_id: crate::TaskId::from(task_id.as_str()),
                        source,
                    });
                    return Err(err);
                }
            };
            Ok(crate::codec::encode_loop_envelope(decision, &payload_bytes))
        })
    }
}
/// Adapts a typed [`CoreTask`] into a byte-level one: decode `T::Input` with
/// `codec`, run the shared task, encode `T::Output`.
struct TaskWrapper<T, C> {
    /// Wrapped task; every wrapper built by the same factory shares it.
    task: Arc<T>,
    /// Codec used for both decoding input and encoding output.
    codec: Arc<C>,
    /// Registered id, attached to any codec error.
    task_id: String,
}

impl<T, C> CoreTask for TaskWrapper<T, C>
where
    T: CoreTask + Send + Sync + 'static,
    T::Input: Send + 'static,
    T::Output: Send + 'static,
    T::Future: Send + 'static,
    C: Codec + sealed::DecodeValue<T::Input> + sealed::EncodeValue<T::Output>,
{
    type Input = Bytes;
    type Output = Bytes;
    type Future = BytesFuture;

    /// Decodes `input` as `T::Input`, runs the inner task, and encodes its
    /// output. Codec failures become [`CodecError`] tagged with the task id;
    /// the inner task's own errors propagate through `?`.
    fn run(&self, input: Bytes) -> Self::Future {
        let codec = Arc::clone(&self.codec);
        let task = Arc::clone(&self.task);
        let task_id = self.task_id.clone();
        BytesFuture::new(async move {
            let typed_input = match codec.decode::<T::Input>(input) {
                Ok(value) => value,
                Err(source) => {
                    let err: BoxError = Box::new(CodecError::DecodeFailed {
                        task_id: crate::TaskId::from(task_id.as_str()),
                        expected_type: std::any::type_name::<T::Input>(),
                        source,
                    });
                    return Err(err);
                }
            };
            let typed_output = task.run(typed_input).await?;
            match codec.encode(&typed_output) {
                Ok(bytes) => Ok(bytes),
                Err(source) => {
                    let err: BoxError = Box::new(CodecError::EncodeFailed {
                        task_id: crate::TaskId::from(task_id.as_str()),
                        source,
                    });
                    Err(err)
                }
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{Decoder, Encoder};

    /// Minimal codec for registry tests: every `u32` encodes to the fixed
    /// bytes `b"encoded"`, and any bytes decode to `42`.
    struct DummyCodec;
    impl Encoder for DummyCodec {}
    impl Decoder for DummyCodec {}
    impl sealed::EncodeValue<u32> for DummyCodec {
        fn encode_value(&self, _: &u32) -> Result<Bytes, BoxError> {
            Ok(Bytes::from_static(b"encoded"))
        }
    }
    impl sealed::DecodeValue<u32> for DummyCodec {
        fn decode_value(&self, _: Bytes) -> Result<u32, BoxError> {
            Ok(42)
        }
    }

    #[test]
    fn test_registry_register() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);
        registry.register_fn("double", codec, |input: u32| async move { Ok(input * 2) });
        assert!(registry.contains("double"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_get() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);
        registry.register_fn("double", codec, |input: u32| async move { Ok(input * 2) });
        let task = registry.get("double");
        assert!(task.is_some());
        let missing = registry.get("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_registry_task_ids() {
        let mut registry = TaskRegistry::new();
        let codec = Arc::new(DummyCodec);
        registry.register_fn("task_a", Arc::clone(&codec), |i: u32| async move { Ok(i) });
        registry.register_fn("task_b", Arc::clone(&codec), |i: u32| async move { Ok(i) });
        registry.register_fn("task_c", codec, |i: u32| async move { Ok(i) });
        let mut ids: Vec<_> = registry.task_ids().collect();
        // Ids come back in HashMap order; sort for a deterministic comparison.
        ids.sort_unstable();
        assert_eq!(ids, vec!["task_a", "task_b", "task_c"]);
    }

    #[test]
    fn test_registry_merge_unions_ids() {
        let codec = Arc::new(DummyCodec);
        let mut first = TaskRegistry::new();
        first.register_fn("shared", Arc::clone(&codec), |i: u32| async move { Ok(i) });
        let mut second = TaskRegistry::new();
        second.register_fn("shared", Arc::clone(&codec), |i: u32| async move { Ok(i) });
        second.register_fn("extra", codec, |i: u32| async move { Ok(i) });
        first.merge(second);
        assert_eq!(first.len(), 2);
        assert!(first.contains("shared"));
        assert!(first.contains("extra"));
    }

    #[test]
    fn test_set_metadata_on_missing_id_is_noop() {
        let mut registry = TaskRegistry::new();
        assert!(!registry.set_metadata("missing", TaskMetadata::default()));
        assert!(registry.get_metadata("missing").is_none());
        assert!(registry.is_empty());
    }
}