use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use serde::de::DeserializeOwned;
use serde::Serialize;
use rustvello_proto::call::SerializedArguments;
use rustvello_proto::config::TaskConfig;
use rustvello_proto::identifiers::TaskId;
use crate::error::{RustvelloError, RustvelloResult};
/// A strongly-typed task whose parameters and result are exchanged as JSON.
///
/// Every `Task` is also a [`DynTask`] through the blanket impl in this module, which
/// builds a JSON object from the call's serialized arguments, deserializes it into
/// `Params`, invokes [`Task::run`], and serializes the returned `Result` to a JSON string.
pub trait Task: Send + Sync + 'static {
/// Input type, deserialized with `serde_json` from the assembled argument JSON.
type Params: Serialize + DeserializeOwned + Send + Sync + 'static;
/// Output type, serialized with `serde_json` after [`Task::run`] succeeds.
type Result: Serialize + DeserializeOwned + Send + Sync + 'static;
/// Identifier used as this task's key when registered via `TaskRegistry::register_typed`.
fn task_id(&self) -> &TaskId;
/// Configuration associated with this task.
fn config(&self) -> &TaskConfig;
/// Executes the task with already-deserialized parameters.
fn run(&self, params: Self::Params) -> RustvelloResult<Self::Result>;
}
/// Object-safe view of a task, used to store heterogeneous tasks in a [`TaskRegistry`].
///
/// Implemented in this module for every [`Task`] (blanket impl), for foreign-language
/// tasks via `ForeignTaskAdapter`, and for legacy closure-based [`TaskDefinition`]s via
/// `LegacyTaskAdapter`.
pub trait DynTask: Send + Sync {
/// Identifier of the task.
fn task_id(&self) -> &TaskId;
/// Configuration of the task.
fn config(&self) -> &TaskConfig;
/// Runs the task on serialized call arguments and returns the result as a string.
///
/// For typed tasks the returned string is the JSON-encoded `Task::Result`; the
/// foreign-task adapter always returns an error instead of executing.
fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String>;
}
pub fn serialized_args_to_json(
args: &SerializedArguments,
) -> RustvelloResult<std::borrow::Cow<'_, str>> {
use std::borrow::Cow;
if args.0.len() == 1 && args.0.contains_key("__args__") {
return Ok(Cow::Borrowed(&args.0["__args__"]));
}
use std::fmt::Write;
let mut buf = String::with_capacity(args.0.len() * 32 + 2);
buf.push('{');
for (i, (k, v)) in args.0.iter().enumerate() {
if i > 0 {
buf.push(',');
}
let escaped_key =
serde_json::to_string(k.as_str()).map_err(|e| RustvelloError::Serialization {
message: format!("failed to escape JSON key: {e}"),
})?;
serde_json::from_str::<serde_json::Value>(v).map_err(|e| {
RustvelloError::Serialization {
message: format!("invalid JSON value for key {k}: {e}"),
}
})?;
write!(buf, "{}:{}", escaped_key, v).map_err(|e| RustvelloError::Serialization {
message: format!("failed to build JSON: {e}"),
})?;
}
buf.push('}');
Ok(Cow::Owned(buf))
}
impl<T: Task> DynTask for T {
#[inline]
fn task_id(&self) -> &TaskId {
Task::task_id(self)
}
#[inline]
fn config(&self) -> &TaskConfig {
Task::config(self)
}
fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String> {
let json_str = serialized_args_to_json(args)?;
let params: T::Params =
serde_json::from_str(&json_str).map_err(|e| RustvelloError::Serialization {
message: e.to_string(),
})?;
let result = self.run(params)?;
serde_json::to_string(&result).map_err(|e| RustvelloError::Serialization {
message: e.to_string(),
})
}
}
impl fmt::Debug for dyn DynTask {
    /// Renders as `DynTask { task_id: ... }`; only the id is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id = self.task_id();
        let mut builder = f.debug_struct("DynTask");
        builder.field("task_id", id);
        builder.finish()
    }
}
pub trait CrossLanguageSafe: Serialize + DeserializeOwned {}
impl CrossLanguageSafe for String {}
impl CrossLanguageSafe for bool {}
impl CrossLanguageSafe for i32 {}
impl CrossLanguageSafe for i64 {}
impl CrossLanguageSafe for u32 {}
impl CrossLanguageSafe for u64 {}
impl CrossLanguageSafe for f32 {}
impl CrossLanguageSafe for f64 {}
impl<T: CrossLanguageSafe> CrossLanguageSafe for Vec<T> {}
impl<T: CrossLanguageSafe> CrossLanguageSafe for Option<T> {}
impl<K: CrossLanguageSafe + Ord, V: CrossLanguageSafe> CrossLanguageSafe
for std::collections::BTreeMap<K, V>
{
}
impl<K: CrossLanguageSafe + Eq + std::hash::Hash, V: CrossLanguageSafe> CrossLanguageSafe
for std::collections::HashMap<K, V>
{
}
/// A task implemented in another language and only declared on the Rust side.
///
/// Registering one with [`TaskRegistry::register_foreign`] records its id and config;
/// calling [`DynTask::execute`] on the registered entry in this process always fails.
pub trait ForeignTask: Send + Sync + 'static {
/// Parameter type, restricted to [`CrossLanguageSafe`] types.
type Params: CrossLanguageSafe + Send + Sync + 'static;
/// Result type, restricted to [`CrossLanguageSafe`] types.
type Result: CrossLanguageSafe + Send + Sync + 'static;
/// Identifier of the foreign task; `register_foreign` rejects ids where `is_foreign()` is false.
fn task_id(&self) -> TaskId;
/// Task configuration; defaults to `TaskConfig::default()`.
fn config(&self) -> TaskConfig {
TaskConfig::default()
}
}
/// `DynTask` wrapper around a [`ForeignTask`].
///
/// The id and config are captured once at construction so the `DynTask` accessors
/// can hand out references to them.
struct ForeignTaskAdapter<F: ForeignTask> {
/// The wrapped foreign task; retained but not accessed by any code in this module.
_inner: F,
/// Cached result of `ForeignTask::task_id`.
task_id: TaskId,
/// Cached result of `ForeignTask::config`.
config: TaskConfig,
}
impl<F: ForeignTask> ForeignTaskAdapter<F> {
    /// Wraps `task`, snapshotting its id and then its config so the `DynTask`
    /// accessors can return references without calling back into `task`.
    fn new(task: F) -> Self {
        // Struct-literal fields are evaluated in source order: id, config, then the move.
        Self {
            task_id: task.task_id(),
            config: task.config(),
            _inner: task,
        }
    }
}
impl<F: ForeignTask> DynTask for ForeignTaskAdapter<F> {
    fn task_id(&self) -> &TaskId {
        &self.task_id
    }

    fn config(&self) -> &TaskConfig {
        &self.config
    }

    /// Always fails with `RustvelloError::Configuration`: a foreign task must be
    /// processed by a worker for its own language, never in this process.
    fn execute(&self, _args: &SerializedArguments) -> RustvelloResult<String> {
        let id = &self.task_id;
        let language = id.language();
        let message = format!(
            "foreign task {id} cannot be executed locally — must be processed by a {language} worker"
        );
        Err(RustvelloError::Configuration { message })
    }
}
/// Type-erased body of a legacy task: takes the call arguments as one string and
/// returns the result string.
///
/// When invoked through the registry, the input is `args.0` serialized with
/// `serde_json` (presumably a JSON object mapping argument names to their
/// individually encoded values — confirm against `SerializedArguments`).
pub type TaskFn = Arc<dyn Fn(String) -> RustvelloResult<String> + Send + Sync>;
/// A legacy, closure-based task registration (see [`TaskRegistry::register`]).
pub struct TaskDefinition {
/// Identifier the task is registered under.
pub task_id: TaskId,
/// Task configuration.
pub config: TaskConfig,
/// Callable invoked to execute the task.
pub func: TaskFn,
}
impl TaskDefinition {
pub fn new(task_id: TaskId, config: TaskConfig, func: TaskFn) -> Self {
Self {
task_id,
config,
func,
}
}
}
impl fmt::Debug for TaskDefinition {
    /// Shows `task_id` and `config`.
    ///
    /// `func` is an `Arc<dyn Fn ...>` with no `Debug` impl, so it is omitted;
    /// `finish_non_exhaustive` marks the omission with a trailing `..` instead of
    /// silently presenting a partial struct as complete.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TaskDefinition")
            .field("task_id", &self.task_id)
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}
/// Adapts a legacy [`TaskDefinition`] to the [`DynTask`] interface.
struct LegacyTaskAdapter {
/// Same `Arc` that `TaskRegistry::register` stores in the legacy lookup table.
definition: Arc<TaskDefinition>,
}
impl DynTask for LegacyTaskAdapter {
    fn task_id(&self) -> &TaskId {
        &self.definition.as_ref().task_id
    }

    fn config(&self) -> &TaskConfig {
        &self.definition.as_ref().config
    }

    /// Serializes the whole argument map (`args.0`) to one JSON string with
    /// `serde_json` and passes it to the legacy callable, returning its result as-is.
    fn execute(&self, args: &SerializedArguments) -> RustvelloResult<String> {
        serde_json::to_string(&args.0)
            .map_err(|e| RustvelloError::Serialization {
                message: e.to_string(),
            })
            .and_then(|payload| (self.definition.func)(payload))
    }
}
/// Registry of tasks keyed by [`TaskId`].
///
/// Every registered task (typed, foreign or legacy) is stored in `tasks` as a
/// [`DynTask`]; legacy registrations are additionally kept in `legacy_tasks` so
/// [`TaskRegistry::get`] can return the original [`TaskDefinition`].
#[derive(Default)]
pub struct TaskRegistry {
/// All tasks by id; duplicate detection for every registration method checks this map.
tasks: HashMap<TaskId, Arc<dyn DynTask>>,
/// Legacy definitions only; each key is inserted alongside the same key in `tasks`.
legacy_tasks: HashMap<TaskId, Arc<TaskDefinition>>,
}
impl TaskRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts the task produced by `build` under `task_id`, rejecting duplicates.
    ///
    /// Uses the `HashMap` entry API so the id is hashed and looked up only once, and
    /// `build` runs only when the id is free, so nothing is constructed for a duplicate.
    ///
    /// # Errors
    ///
    /// Returns [`RustvelloError::Configuration`] if `task_id` is already registered.
    fn insert_unique<D: DynTask + 'static>(
        &mut self,
        task_id: TaskId,
        build: impl FnOnce() -> D,
    ) -> RustvelloResult<()> {
        use std::collections::hash_map::Entry;

        match self.tasks.entry(task_id) {
            Entry::Occupied(slot) => Err(RustvelloError::Configuration {
                message: format!("task already registered: {}", slot.key()),
            }),
            Entry::Vacant(slot) => {
                slot.insert(Arc::new(build()));
                Ok(())
            }
        }
    }

    /// Registers a strongly-typed [`Task`] under its own `task_id()`.
    ///
    /// # Errors
    ///
    /// Returns [`RustvelloError::Configuration`] if any task (typed, foreign or
    /// legacy) is already registered under the same id.
    pub fn register_typed<T: Task>(&mut self, task: T) -> RustvelloResult<()> {
        let task_id = task.task_id().clone();
        self.insert_unique(task_id, || task)
    }

    /// Registers a [`ForeignTask`] declaration.
    ///
    /// The entry can be looked up like any other task, but executing it locally fails.
    ///
    /// # Errors
    ///
    /// Returns [`RustvelloError::Configuration`] if the task id is not foreign
    /// (`is_foreign()` is false) or if the id is already registered.
    pub fn register_foreign<F: ForeignTask>(&mut self, task: F) -> RustvelloResult<()> {
        let task_id = task.task_id();
        if !task_id.is_foreign() {
            return Err(RustvelloError::Configuration {
                message: format!(
                    "ForeignTask must have a non-empty language, got: {}",
                    task_id
                ),
            });
        }
        self.insert_unique(task_id, || ForeignTaskAdapter::new(task))
    }

    /// Registers a legacy closure-based [`TaskDefinition`].
    ///
    /// The definition is reachable both through [`TaskRegistry::get_dyn`] (wrapped in
    /// an adapter) and through [`TaskRegistry::get`].
    ///
    /// # Errors
    ///
    /// Returns [`RustvelloError::Configuration`] if the id is already registered.
    pub fn register(&mut self, definition: TaskDefinition) -> RustvelloResult<()> {
        let task_id = definition.task_id.clone();
        let def = Arc::new(definition);
        let shared = Arc::clone(&def);
        self.insert_unique(task_id.clone(), move || LegacyTaskAdapter { definition: shared })?;
        // Only reached when the id was free, so both maps stay in sync.
        self.legacy_tasks.insert(task_id, def);
        Ok(())
    }

    /// Returns the type-erased handle for any registered task.
    pub fn get_dyn(&self, task_id: &TaskId) -> Option<Arc<dyn DynTask>> {
        self.tasks.get(task_id).cloned()
    }

    /// Returns the legacy [`TaskDefinition`] for `task_id`.
    ///
    /// Only tasks added via [`TaskRegistry::register`] are visible here; typed and
    /// foreign tasks yield `None` (use [`TaskRegistry::get_dyn`] for those).
    pub fn get(&self, task_id: &TaskId) -> Option<Arc<TaskDefinition>> {
        self.legacy_tasks.get(task_id).cloned()
    }

    /// Returns `true` if any task is registered under `task_id`.
    pub fn contains(&self, task_id: &TaskId) -> bool {
        self.tasks.contains_key(task_id)
    }

    /// Returns the ids of all registered tasks, in unspecified (`HashMap`) order.
    pub fn task_ids(&self) -> Vec<&TaskId> {
        self.tasks.keys().collect()
    }

    /// Number of registered tasks of all kinds.
    pub fn len(&self) -> usize {
        self.tasks.len()
    }

    /// Returns `true` if no task is registered.
    pub fn is_empty(&self) -> bool {
        self.tasks.is_empty()
    }
}
impl fmt::Debug for TaskRegistry {
    /// Lists the registered task ids in unspecified (`HashMap` iteration) order.
    ///
    /// The task objects and the legacy lookup table are not shown;
    /// `finish_non_exhaustive` marks that omission with a trailing `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TaskRegistry")
            .field("tasks", &self.tasks.keys().collect::<Vec<_>>())
            .finish_non_exhaustive()
    }
}
/// A named group of tasks that can register itself into a [`TaskRegistry`].
pub trait TaskModule: Send + Sync {
/// Name of this module.
fn name(&self) -> &str;
/// Registers this module's tasks into `registry`.
///
/// NOTE(review): implementations presumably propagate errors from the registry's
/// `register*` methods (e.g. duplicate task ids) — confirm per implementation.
fn register(&self, registry: &mut TaskRegistry) -> RustvelloResult<()>;
}
#[cfg(test)]
mod tests {
    use super::*;

    fn dummy_fn() -> TaskFn {
        Arc::new(|_| Ok("ok".to_string()))
    }

    /// Builds `SerializedArguments` from `(name, json_encoded_value)` pairs.
    fn json_args(pairs: &[(&str, &str)]) -> SerializedArguments {
        let mut args = SerializedArguments::default();
        for (name, value) in pairs {
            args.0.insert((*name).into(), (*value).into());
        }
        args
    }

    #[test]
    fn registry_new_is_empty() {
        let reg = TaskRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
    }

    #[test]
    fn register_and_get() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("mod", "func");
        reg.register(TaskDefinition::new(
            tid.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();
        assert_eq!(reg.len(), 1);
        assert!(!reg.is_empty());
        assert!(reg.contains(&tid));
        assert!(reg.get(&tid).is_some());
        assert_eq!(reg.get(&tid).unwrap().task_id, tid);
    }

    #[test]
    fn register_duplicate_errors() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("mod", "func");
        reg.register(TaskDefinition::new(
            tid.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();
        let result = reg.register(TaskDefinition::new(tid, TaskConfig::default(), dummy_fn()));
        assert!(result.is_err());
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let reg = TaskRegistry::new();
        let tid = TaskId::new("no", "such");
        assert!(!reg.contains(&tid));
        assert!(reg.get(&tid).is_none());
    }

    #[test]
    fn task_ids_lists_all() {
        let mut reg = TaskRegistry::new();
        let t1 = TaskId::new("mod", "a");
        let t2 = TaskId::new("mod", "b");
        reg.register(TaskDefinition::new(
            t1.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();
        reg.register(TaskDefinition::new(
            t2.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();
        let ids = reg.task_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&&t1));
        assert!(ids.contains(&&t2));
    }

    #[test]
    fn task_definition_debug() {
        let def = TaskDefinition::new(
            TaskId::new("mod", "func"),
            TaskConfig::default(),
            dummy_fn(),
        );
        let debug_str = format!("{:?}", def);
        assert!(debug_str.contains("mod"));
        assert!(debug_str.contains("func"));
    }

    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestParams {
        value: String,
    }

    impl CrossLanguageSafe for TestParams {}

    struct TestForeignTask;

    impl ForeignTask for TestForeignTask {
        type Params = TestParams;
        type Result = String;
        fn task_id(&self) -> TaskId {
            TaskId::foreign("python", "analytics.tasks", "train_model")
        }
    }

    #[test]
    fn register_foreign_task() {
        let mut reg = TaskRegistry::new();
        reg.register_foreign(TestForeignTask).unwrap();
        let tid = TaskId::foreign("python", "analytics.tasks", "train_model");
        assert!(reg.contains(&tid));
        assert_eq!(reg.len(), 1);
        let dyn_task = reg.get_dyn(&tid).unwrap();
        assert_eq!(dyn_task.task_id(), &tid);
        assert!(dyn_task.task_id().is_foreign());
    }

    #[test]
    fn foreign_task_execute_returns_error() {
        let mut reg = TaskRegistry::new();
        reg.register_foreign(TestForeignTask).unwrap();
        let tid = TaskId::foreign("python", "analytics.tasks", "train_model");
        let dyn_task = reg.get_dyn(&tid).unwrap();
        let args = SerializedArguments::default();
        let result = dyn_task.execute(&args);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("foreign task"));
        assert!(err_msg.contains("python"));
    }

    #[test]
    fn register_foreign_duplicate_errors() {
        let mut reg = TaskRegistry::new();
        reg.register_foreign(TestForeignTask).unwrap();
        let result = reg.register_foreign(TestForeignTask);
        assert!(result.is_err());
    }

    #[test]
    fn cross_language_safe_primitives() {
        fn assert_cls<T: CrossLanguageSafe>() {}
        assert_cls::<String>();
        assert_cls::<bool>();
        assert_cls::<i32>();
        assert_cls::<i64>();
        assert_cls::<u32>();
        assert_cls::<u64>();
        assert_cls::<f32>();
        assert_cls::<f64>();
        assert_cls::<Vec<String>>();
        assert_cls::<Option<i64>>();
        assert_cls::<std::collections::BTreeMap<String, i64>>();
        assert_cls::<std::collections::HashMap<String, String>>();
    }

    // ---- serialized_args_to_json ----

    #[test]
    fn args_to_json_empty_is_empty_object() {
        let args = SerializedArguments::default();
        assert_eq!(serialized_args_to_json(&args).unwrap(), "{}");
    }

    #[test]
    fn args_to_json_passes_through_single_args_payload() {
        let args = json_args(&[("__args__", "[1,2,3]")]);
        let json = serialized_args_to_json(&args).unwrap();
        assert!(matches!(&json, std::borrow::Cow::Borrowed(_)));
        assert_eq!(json, "[1,2,3]");
    }

    #[test]
    fn args_to_json_builds_object_from_named_args() {
        let args = json_args(&[("a", "1"), ("b", "\"x\""), ("c", "{\"k\":[true,null]}")]);
        let json = serialized_args_to_json(&args).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            parsed,
            serde_json::json!({ "a": 1, "b": "x", "c": { "k": [true, null] } })
        );
    }

    #[test]
    fn args_to_json_escapes_keys() {
        let args = json_args(&[("quo\"te", "0")]);
        let json = serialized_args_to_json(&args).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed["quo\"te"], serde_json::json!(0));
    }

    #[test]
    fn args_to_json_rejects_invalid_value() {
        let args = json_args(&[("a", "not json")]);
        assert!(serialized_args_to_json(&args).is_err());
    }

    // ---- typed tasks ----

    #[derive(serde::Serialize, serde::Deserialize)]
    struct AddParams {
        a: i64,
        b: i64,
    }

    struct AddTask {
        id: TaskId,
        config: TaskConfig,
    }

    impl AddTask {
        fn new(id: TaskId) -> Self {
            Self {
                id,
                config: TaskConfig::default(),
            }
        }
    }

    impl Task for AddTask {
        type Params = AddParams;
        type Result = i64;
        fn task_id(&self) -> &TaskId {
            &self.id
        }
        fn config(&self) -> &TaskConfig {
            &self.config
        }
        fn run(&self, params: AddParams) -> RustvelloResult<i64> {
            Ok(params.a + params.b)
        }
    }

    #[test]
    fn typed_task_executes_via_dyn_handle() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("math", "add");
        reg.register_typed(AddTask::new(tid.clone())).unwrap();
        let task = reg.get_dyn(&tid).unwrap();
        assert_eq!(task.task_id(), &tid);
        let out = task.execute(&json_args(&[("a", "2"), ("b", "3")])).unwrap();
        assert_eq!(out, "5");
        // Typed tasks are not exposed through the legacy accessor.
        assert!(reg.get(&tid).is_none());
    }

    #[test]
    fn typed_task_missing_param_is_error() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("math", "add");
        reg.register_typed(AddTask::new(tid.clone())).unwrap();
        let task = reg.get_dyn(&tid).unwrap();
        assert!(task.execute(&json_args(&[("a", "2")])).is_err());
    }

    #[test]
    fn register_typed_duplicate_errors() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("math", "add");
        reg.register_typed(AddTask::new(tid.clone())).unwrap();
        assert!(reg.register_typed(AddTask::new(tid)).is_err());
        assert_eq!(reg.len(), 1);
    }

    #[test]
    fn legacy_and_typed_tasks_share_id_namespace() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("mod", "func");
        reg.register(TaskDefinition::new(
            tid.clone(),
            TaskConfig::default(),
            dummy_fn(),
        ))
        .unwrap();
        assert!(reg.register_typed(AddTask::new(tid)).is_err());
        assert_eq!(reg.len(), 1);
    }

    // ---- legacy tasks ----

    #[test]
    fn legacy_task_receives_argument_map_as_json() {
        let mut reg = TaskRegistry::new();
        let tid = TaskId::new("mod", "echo");
        let echo: TaskFn =
            Arc::new(|payload: String| -> RustvelloResult<String> { Ok(payload) });
        reg.register(TaskDefinition::new(tid.clone(), TaskConfig::default(), echo))
            .unwrap();
        let out = reg
            .get_dyn(&tid)
            .unwrap()
            .execute(&json_args(&[("x", "1")]))
            .unwrap();
        assert_eq!(out, r#"{"x":"1"}"#);
    }
}