pub mod namespace;
use crate::error::{RegistrationError, ValidationError};
use once_cell::sync::Lazy;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
pub use cloacina_workflow::{Task, TaskState};
pub use namespace::{parse_namespace, TaskNamespace};
/// Colour of a task during the cycle-detecting depth-first search performed by
/// [`TaskRegistry::validate_dependencies`].
#[derive(Clone, Copy, PartialEq, Eq)]
enum VisitState {
    /// The task is on the current DFS path; reaching it again closes a cycle.
    InProgress,
    /// The task and everything reachable from it have been fully explored.
    Done,
}

/// A registry of task instances keyed by their fully-qualified [`TaskNamespace`].
///
/// Tasks are stored behind [`Arc`], so [`TaskRegistry::get_task`] hands out cheap
/// shared handles. The registry can check its dependency graph for missing or
/// circular dependencies and compute a dependency-respecting execution order.
pub struct TaskRegistry {
    tasks: HashMap<TaskNamespace, Arc<dyn Task>>,
}

impl TaskRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tasks: HashMap::new(),
        }
    }

    /// Registers an owned `task` under `namespace`.
    ///
    /// This wraps the task in an [`Arc`] and then behaves exactly like
    /// [`TaskRegistry::register_arc`].
    ///
    /// # Errors
    ///
    /// * [`RegistrationError::InvalidTaskId`] if `namespace.task_id` is empty.
    /// * [`RegistrationError::DuplicateTaskId`] if `namespace` is already registered.
    pub fn register<T: Task + 'static>(
        &mut self,
        namespace: TaskNamespace,
        task: T,
    ) -> Result<(), RegistrationError> {
        // Single validation path shared with `register_arc`, so the two cannot drift apart.
        self.register_arc(namespace, Arc::new(task))
    }

    /// Registers an already shared `task` under `namespace`.
    ///
    /// # Errors
    ///
    /// * [`RegistrationError::InvalidTaskId`] if `namespace.task_id` is empty.
    /// * [`RegistrationError::DuplicateTaskId`] if `namespace` is already registered;
    ///   the previously registered task is left in place.
    pub fn register_arc(
        &mut self,
        namespace: TaskNamespace,
        task: Arc<dyn Task>,
    ) -> Result<(), RegistrationError> {
        use std::collections::hash_map::Entry;

        if namespace.task_id.is_empty() {
            return Err(RegistrationError::InvalidTaskId {
                message: "Task ID cannot be empty".to_string(),
            });
        }
        // One hash lookup serves both the duplicate check and the insertion.
        match self.tasks.entry(namespace) {
            Entry::Occupied(existing) => Err(RegistrationError::DuplicateTaskId {
                id: existing.key().to_string(),
            }),
            Entry::Vacant(slot) => {
                slot.insert(task);
                Ok(())
            }
        }
    }

    /// Returns a shared handle to the task registered under `namespace`, if any.
    pub fn get_task(&self, namespace: &TaskNamespace) -> Option<Arc<dyn Task>> {
        self.tasks.get(namespace).cloned()
    }

    /// Returns the namespaces of all registered tasks, in unspecified order.
    pub fn task_ids(&self) -> Vec<TaskNamespace> {
        self.tasks.keys().cloned().collect()
    }

    /// Returns the number of registered tasks.
    pub fn task_count(&self) -> usize {
        self.tasks.len()
    }

    /// Checks that every declared dependency refers to a registered task and that
    /// the dependency graph contains no cycles.
    ///
    /// Tasks are visited in hash-map order. When several problems exist, which one
    /// is reported is therefore unspecified.
    ///
    /// # Errors
    ///
    /// * [`ValidationError::MissingDependencyOld`] for an unregistered dependency.
    /// * [`ValidationError::CircularDependency`] whose `cycle` lists the task IDs that
    ///   form the cycle, e.g. `"a -> b -> a"`.
    pub fn validate_dependencies(&self) -> Result<(), ValidationError> {
        for (namespace, task) in &self.tasks {
            if let Some(missing) = task
                .dependencies()
                .iter()
                .find(|dependency| !self.tasks.contains_key(*dependency))
            {
                return Err(ValidationError::MissingDependencyOld {
                    task_id: namespace.to_string(),
                    dependency: missing.to_string(),
                });
            }
        }

        let mut states = HashMap::new();
        let mut path = Vec::new();
        for namespace in self.tasks.keys() {
            if !states.contains_key(namespace) {
                self.check_cycles(namespace, &mut states, &mut path)
                    .map_err(|cycle| ValidationError::CircularDependency { cycle })?;
            }
        }
        Ok(())
    }

    /// Depth-first search from `namespace` that fails on the first cycle it reaches.
    ///
    /// `states` records the colour of every task seen so far. `path` is the current
    /// DFS stack; it is used to report only the tasks that actually form the cycle.
    fn check_cycles(
        &self,
        namespace: &TaskNamespace,
        states: &mut HashMap<TaskNamespace, VisitState>,
        path: &mut Vec<TaskNamespace>,
    ) -> Result<(), String> {
        states.insert(namespace.clone(), VisitState::InProgress);
        path.push(namespace.clone());
        if let Some(task) = self.tasks.get(namespace) {
            for dependency in task.dependencies() {
                match states.get(dependency).copied() {
                    None => self.check_cycles(dependency, states, path)?,
                    Some(VisitState::InProgress) => {
                        return Err(Self::describe_cycle(path.as_slice(), dependency));
                    }
                    Some(VisitState::Done) => {}
                }
            }
        }
        path.pop();
        if let Some(state) = states.get_mut(namespace) {
            *state = VisitState::Done;
        }
        Ok(())
    }

    /// Formats the cycle that closes when `repeated` is reached again, e.g. `"a -> b -> a"`.
    fn describe_cycle(path: &[TaskNamespace], repeated: &TaskNamespace) -> String {
        // An `InProgress` task is always on the DFS path; fall back to the whole path
        // rather than panicking if that invariant were ever broken.
        let start = path
            .iter()
            .position(|candidate| candidate == repeated)
            .unwrap_or(0);
        path[start..]
            .iter()
            .chain(std::iter::once(repeated))
            .map(|member| member.task_id.to_string())
            .collect::<Vec<_>>()
            .join(" -> ")
    }

    /// Returns all tasks ordered so that each one appears after every task it depends on.
    ///
    /// Uses Kahn's algorithm. The relative order of tasks with no dependency
    /// relationship is unspecified.
    ///
    /// # Errors
    ///
    /// * Any error from [`TaskRegistry::validate_dependencies`].
    /// * [`ValidationError::InvalidGraph`] if not every task could be ordered. This is
    ///   not expected once validation has passed and is kept as a safety net.
    pub fn topological_sort(&self) -> Result<Vec<TaskNamespace>, ValidationError> {
        self.validate_dependencies()?;

        // Edges point from a dependency to the tasks waiting on it. The maps borrow keys
        // from `self.tasks`, so no namespace is cloned until it is emitted.
        let mut in_degree: HashMap<&TaskNamespace, usize> =
            self.tasks.keys().map(|namespace| (namespace, 0)).collect();
        let mut dependents: HashMap<&TaskNamespace, Vec<&TaskNamespace>> =
            self.tasks.keys().map(|namespace| (namespace, Vec::new())).collect();
        for (namespace, task) in &self.tasks {
            for dependency in task.dependencies() {
                if let Some(waiting) = dependents.get_mut(dependency) {
                    waiting.push(namespace);
                    *in_degree.entry(namespace).or_insert(0) += 1;
                }
            }
        }

        let mut ready: Vec<&TaskNamespace> = in_degree
            .iter()
            .filter(|&(_, &degree)| degree == 0)
            .map(|(&namespace, _)| namespace)
            .collect();
        let mut order = Vec::with_capacity(self.tasks.len());
        while let Some(current) = ready.pop() {
            order.push(current.clone());
            if let Some(waiting) = dependents.get(current) {
                for &dependent in waiting {
                    if let Some(degree) = in_degree.get_mut(dependent) {
                        *degree -= 1;
                        if *degree == 0 {
                            ready.push(dependent);
                        }
                    }
                }
            }
        }

        if order.len() != self.tasks.len() {
            return Err(ValidationError::InvalidGraph {
                message: "Graph contains cycles".to_string(),
            });
        }
        Ok(order)
    }
}
impl Default for TaskRegistry {
fn default() -> Self {
Self::new()
}
}
/// Boxed factory that [`get_task`] invokes on every lookup to obtain a task.
type TaskConstructor = Box<dyn Fn() -> Arc<dyn Task> + Send + Sync>;
/// Shared, lock-protected map from task namespace to its constructor.
type GlobalTaskRegistry = Arc<RwLock<HashMap<TaskNamespace, TaskConstructor>>>;
/// Process-wide constructor registry, created empty on first access.
static GLOBAL_TASK_REGISTRY: Lazy<GlobalTaskRegistry> =
    Lazy::new(|| Arc::new(RwLock::new(HashMap::default())));
/// Registers `constructor` as the process-wide factory for tasks in `namespace`.
///
/// A constructor already registered for the same namespace is replaced. The
/// replacement is logged distinctly from a first-time registration, so a silent
/// overwrite is visible in debug logs.
pub fn register_task_constructor<F>(namespace: TaskNamespace, constructor: F)
where
    F: Fn() -> Arc<dyn Task> + Send + Sync + 'static,
{
    // The write guard is a temporary of this statement. It is released here, before
    // any logging happens, which keeps the exclusive-lock window as short as possible.
    let previous = GLOBAL_TASK_REGISTRY
        .write()
        .insert(namespace.clone(), Box::new(constructor));
    if previous.is_some() {
        tracing::debug!(
            "Replaced existing task constructor for namespace: {}",
            namespace
        );
    } else {
        tracing::debug!(
            "Successfully registered task constructor for namespace: {}",
            namespace
        );
    }
}
pub fn global_task_registry() -> GlobalTaskRegistry {
GLOBAL_TASK_REGISTRY.clone()
}
/// Builds a task for `namespace` using its registered constructor.
///
/// Returns `None` when no constructor is registered for `namespace`.
///
/// The constructor is invoked while the registry's read lock is held. It must
/// therefore not try to take this registry's write lock (for example by calling
/// [`register_task_constructor`]), or the call will deadlock.
pub fn get_task(namespace: &TaskNamespace) -> Option<Arc<dyn Task>> {
    let guard = GLOBAL_TASK_REGISTRY.read();
    let constructor = guard.get(namespace)?;
    Some(constructor())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::TaskError;
    use crate::init_test_logging;
    use crate::Context;
    use async_trait::async_trait;
    use chrono::Utc;

    /// Minimal `Task` implementation with configurable dependencies and fingerprint.
    struct TestTask {
        id: String,
        dependencies: Vec<TaskNamespace>,
        fingerprint: Option<String>,
    }

    impl TestTask {
        fn new(id: &str, dependencies: Vec<TaskNamespace>) -> Self {
            Self {
                id: id.to_string(),
                dependencies,
                fingerprint: None,
            }
        }

        fn with_fingerprint(mut self, fingerprint: &str) -> Self {
            self.fingerprint = Some(fingerprint.to_string());
            self
        }
    }

    #[async_trait]
    impl Task for TestTask {
        async fn execute(
            &self,
            context: Context<serde_json::Value>,
        ) -> Result<Context<serde_json::Value>, TaskError> {
            Ok(context)
        }

        fn id(&self) -> &str {
            &self.id
        }

        fn dependencies(&self) -> &[TaskNamespace] {
            &self.dependencies
        }

        fn code_fingerprint(&self) -> Option<String> {
            self.fingerprint.clone()
        }
    }

    #[test]
    fn test_task_state() {
        init_test_logging();
        let pending = TaskState::Pending;
        assert!(pending.is_pending());
        assert!(!pending.is_running());
        assert!(!pending.is_completed());
        assert!(!pending.is_failed());
        let running = TaskState::Running {
            start_time: Utc::now(),
        };
        assert!(running.is_running());
        assert!(!running.is_pending());
        assert!(!running.is_failed());
        let completed = TaskState::Completed {
            completion_time: Utc::now(),
        };
        assert!(completed.is_completed());
        // Previously this re-checked `running`; the completed state is what belongs here.
        assert!(!completed.is_failed());
        let failed = TaskState::Failed {
            error: "test error".to_string(),
            failure_time: Utc::now(),
        };
        assert!(failed.is_failed());
        assert!(!failed.is_completed());
    }

    #[test]
    fn test_task_registry_basic() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let ns1 = TaskNamespace::new("public", "embedded", "test_workflow", "task1");
        let ns2 = TaskNamespace::new("public", "embedded", "test_workflow", "task2");
        let task1 = TestTask::new("task1", vec![]);
        let task2 = TestTask::new("task2", vec![ns1.clone()]);
        assert!(registry.register(ns1.clone(), task1).is_ok());
        assert!(registry.register(ns2.clone(), task2).is_ok());
        assert!(registry.get_task(&ns1).is_some());
        assert!(registry.get_task(&ns2).is_some());
    }

    #[test]
    fn test_task_registry_duplicate_id() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let ns1 = TaskNamespace::new("public", "embedded", "test_workflow", "task1");
        let task1 = TestTask::new("task1", vec![]);
        let task1_duplicate = TestTask::new("task1", vec![]);
        assert!(registry.register(ns1.clone(), task1).is_ok());
        assert!(matches!(
            registry.register(ns1, task1_duplicate),
            Err(RegistrationError::DuplicateTaskId { .. })
        ));
    }

    #[test]
    fn test_register_rejects_empty_task_id() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let empty = TaskNamespace::new("public", "embedded", "test_workflow", "");
        assert!(matches!(
            registry.register(empty.clone(), TestTask::new("", vec![])),
            Err(RegistrationError::InvalidTaskId { .. })
        ));
        assert!(matches!(
            registry.register_arc(empty, Arc::new(TestTask::new("", vec![]))),
            Err(RegistrationError::InvalidTaskId { .. })
        ));
        assert_eq!(registry.task_count(), 0);
    }

    #[test]
    fn test_register_arc_duplicate_keeps_original() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let ns1 = TaskNamespace::new("public", "embedded", "test_workflow", "task1");
        let original: Arc<dyn Task> =
            Arc::new(TestTask::new("task1", vec![]).with_fingerprint("first"));
        registry.register_arc(ns1.clone(), original).unwrap();
        let replacement: Arc<dyn Task> =
            Arc::new(TestTask::new("task1", vec![]).with_fingerprint("second"));
        assert!(matches!(
            registry.register_arc(ns1.clone(), replacement),
            Err(RegistrationError::DuplicateTaskId { .. })
        ));
        let stored = registry.get_task(&ns1).expect("task1 should be registered");
        assert_eq!(stored.code_fingerprint(), Some("first".to_string()));
        assert_eq!(registry.task_count(), 1);
        assert!(registry.task_ids() == vec![ns1]);
    }

    #[test]
    fn test_dependency_validation() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let ns1 = TaskNamespace::new("public", "embedded", "test_workflow", "task1");
        let ns2 = TaskNamespace::new("public", "embedded", "test_workflow", "task2");
        let ns3 = TaskNamespace::new("public", "embedded", "test_workflow", "task3");
        let nonexistent_ns =
            TaskNamespace::new("public", "embedded", "test_workflow", "nonexistent");
        let task1 = TestTask::new("task1", vec![]);
        let task2 = TestTask::new("task2", vec![ns1.clone()]);
        let task3 = TestTask::new("task3", vec![nonexistent_ns]);
        registry.register(ns1, task1).unwrap();
        registry.register(ns2, task2).unwrap();
        registry.register(ns3, task3).unwrap();
        assert!(matches!(
            registry.validate_dependencies(),
            Err(ValidationError::MissingDependencyOld { .. })
        ));
        // `topological_sort` validates first, so it must surface the same error.
        assert!(matches!(
            registry.topological_sort(),
            Err(ValidationError::MissingDependencyOld { .. })
        ));
    }

    #[test]
    fn test_circular_dependency_detection() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let ns1 = TaskNamespace::new("public", "embedded", "test_workflow", "task1");
        let ns2 = TaskNamespace::new("public", "embedded", "test_workflow", "task2");
        let task1 = TestTask::new("task1", vec![ns2.clone()]);
        let task2 = TestTask::new("task2", vec![ns1.clone()]);
        registry.register(ns1, task1).unwrap();
        registry.register(ns2, task2).unwrap();
        assert!(matches!(
            registry.validate_dependencies(),
            Err(ValidationError::CircularDependency { .. })
        ));
    }

    #[test]
    fn test_self_dependency_is_circular() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let ns1 = TaskNamespace::new("public", "embedded", "test_workflow", "task1");
        registry
            .register(ns1.clone(), TestTask::new("task1", vec![ns1.clone()]))
            .unwrap();
        assert!(matches!(
            registry.validate_dependencies(),
            Err(ValidationError::CircularDependency { .. })
        ));
        assert!(matches!(
            registry.topological_sort(),
            Err(ValidationError::CircularDependency { .. })
        ));
    }

    #[test]
    fn test_topological_sort() {
        init_test_logging();
        let mut registry = TaskRegistry::new();
        let ns1 = TaskNamespace::new("public", "embedded", "test_workflow", "task1");
        let ns2 = TaskNamespace::new("public", "embedded", "test_workflow", "task2");
        let ns3 = TaskNamespace::new("public", "embedded", "test_workflow", "task3");
        let task1 = TestTask::new("task1", vec![]);
        let task2 = TestTask::new("task2", vec![ns1.clone()]);
        let task3 = TestTask::new("task3", vec![ns1.clone(), ns2.clone()]);
        registry.register(ns1.clone(), task1).unwrap();
        registry.register(ns2.clone(), task2).unwrap();
        registry.register(ns3.clone(), task3).unwrap();
        let sorted = registry.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);
        let pos1 = sorted.iter().position(|x| x.task_id == "task1").unwrap();
        let pos2 = sorted.iter().position(|x| x.task_id == "task2").unwrap();
        let pos3 = sorted.iter().position(|x| x.task_id == "task3").unwrap();
        assert!(pos1 < pos2);
        assert!(pos1 < pos3);
        assert!(pos2 < pos3);
    }

    #[test]
    fn test_global_constructor_registry_round_trip() {
        init_test_logging();
        // A namespace unique to this test avoids interference through the shared static.
        let ns = TaskNamespace::new("public", "embedded", "registry_tests", "global_ctor_task");
        register_task_constructor(ns.clone(), || -> Arc<dyn Task> {
            Arc::new(TestTask::new("global_ctor_task", vec![]))
        });
        let task = get_task(&ns).expect("constructor should be registered");
        assert_eq!(task.id(), "global_ctor_task");
        assert!(global_task_registry().read().contains_key(&ns));
    }

    #[test]
    fn test_code_fingerprint_none_by_default() {
        init_test_logging();
        let task = TestTask::new("test", vec![]);
        assert_eq!(task.code_fingerprint(), None);
    }

    #[test]
    fn test_code_fingerprint_when_provided() {
        init_test_logging();
        let task = TestTask::new("test", vec![]).with_fingerprint("abc123def456");
        assert_eq!(task.code_fingerprint(), Some("abc123def456".to_string()));
    }
}