use crate::{TaskBackend, TaskExecutionError, TaskId, TaskStatus};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// An ordered list of task ids that is worked through one task at a time.
///
/// The chain keeps a cursor (`current_index`) into `task_ids` plus an overall
/// [`ChainStatus`]. It is serializable, so the cursor and status travel with it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskChain {
/// Identifier of the chain itself, generated by `TaskId::new()` in `TaskChain::new`.
id: TaskId,
/// Human-readable name; included in the error message produced by `execute`.
name: String,
/// Ids of the tasks in the chain, in the order they were added.
task_ids: Vec<TaskId>,
/// Index into `task_ids` of the task the chain is currently positioned on.
current_index: usize,
/// Overall state of the chain.
status: ChainStatus,
}
/// Overall state of a [`TaskChain`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChainStatus {
/// Initial state assigned by `TaskChain::new`.
Pending,
/// Set at the start of `TaskChain::execute`; stays in place while the
/// current task has not reached a final result.
Running,
/// Set by `TaskChain::execute` once there is no current task left to check.
Completed,
/// Set by `TaskChain::execute` when the backend reports a task as failed.
Failed,
}
impl TaskChain {
    /// Creates an empty chain called `name` with a freshly generated id.
    ///
    /// The new chain is [`ChainStatus::Pending`] and its cursor is at index 0.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            id: TaskId::new(),
            name: name.into(),
            task_ids: Vec::new(),
            current_index: 0,
            status: ChainStatus::Pending,
        }
    }

    /// Returns the identifier of the chain itself.
    pub fn id(&self) -> TaskId {
        self.id
    }

    /// Returns the chain's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Appends `task_id` to the end of the chain.
    pub fn add_task(&mut self, task_id: TaskId) {
        self.task_ids.push(task_id);
    }

    /// Returns how many tasks the chain contains, regardless of the cursor position.
    pub fn task_count(&self) -> usize {
        self.task_ids.len()
    }

    /// Returns the id of the task under the cursor, or `None` when the cursor
    /// is past the last task (including when the chain is empty).
    pub fn current_task(&self) -> Option<TaskId> {
        self.task_ids.get(self.current_index).copied()
    }

    /// Moves the cursor to the next task.
    ///
    /// Returns `true` if the cursor now points at a task and `false` if the
    /// chain is exhausted.
    ///
    /// This method never moves the cursor further than one past the last task.
    /// Calling it on an empty or already exhausted chain is therefore a no-op:
    /// tasks appended afterwards with [`TaskChain::add_task`] remain reachable
    /// through [`TaskChain::current_task`], and repeated calls cannot overflow
    /// the index.
    pub fn advance(&mut self) -> bool {
        // Only step while the cursor is on a real task; an unconditional
        // increment would drift past the end and skip tasks added later.
        if self.current_index < self.task_ids.len() {
            self.current_index += 1;
        }
        self.current_index < self.task_ids.len()
    }

    /// Returns the chain's current status.
    pub fn status(&self) -> ChainStatus {
        self.status
    }

    /// Overwrites the chain's status with `status`.
    pub fn set_status(&mut self, status: ChainStatus) {
        self.status = status;
    }

    /// Returns `true` when the chain is in a terminal state, i.e. either
    /// [`ChainStatus::Completed`] or [`ChainStatus::Failed`].
    pub fn is_complete(&self) -> bool {
        matches!(self.status, ChainStatus::Completed | ChainStatus::Failed)
    }

    /// Checks the chain's progress against `backend`, starting at the current task.
    ///
    /// The status is first set to [`ChainStatus::Running`] unconditionally.
    /// Then, for the task under the cursor, the backend is asked for its status:
    ///
    /// * `Success` — the cursor advances and the next task is checked.
    /// * `Failure` — the chain becomes [`ChainStatus::Failed`] and an error is returned.
    /// * `Pending`, `Running` or `Retry` — the method returns `Ok(())` with the
    ///   chain still `Running` and the cursor unchanged, so a later call picks
    ///   up at the same task.
    ///
    /// When no task is left under the cursor (which is immediately the case
    /// for an empty chain) the chain becomes [`ChainStatus::Completed`].
    ///
    /// # Errors
    ///
    /// Returns `TaskExecutionError::ExecutionFailed` when a task is reported
    /// as failed, and propagates any error returned by `backend.get_status`
    /// (in that case the status is left as `Running`).
    pub async fn execute(
        &mut self,
        backend: Arc<dyn TaskBackend>,
    ) -> Result<(), TaskExecutionError> {
        self.set_status(ChainStatus::Running);

        while let Some(task_id) = self.current_task() {
            match backend.get_status(task_id).await? {
                TaskStatus::Success => {
                    // When this exhausts the chain, `current_task` yields
                    // `None` and the loop falls through to the completion below.
                    self.advance();
                }
                TaskStatus::Failure => {
                    self.set_status(ChainStatus::Failed);
                    return Err(TaskExecutionError::ExecutionFailed(format!(
                        "Task {} in chain {} failed",
                        task_id, self.name
                    )));
                }
                TaskStatus::Pending | TaskStatus::Running | TaskStatus::Retry => {
                    // The current task has no final result yet; stop here and
                    // leave the cursor in place.
                    return Ok(());
                }
            }
        }

        self.set_status(ChainStatus::Completed);
        Ok(())
    }
}
/// Builder that assembles a [`TaskChain`] through chained, by-value method calls.
pub struct TaskChainBuilder {
/// The chain under construction; returned by `build`.
chain: TaskChain,
}
impl TaskChainBuilder {
    /// Starts a builder around a new, empty [`TaskChain`] called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let chain = TaskChain::new(name);
        Self { chain }
    }

    /// Appends a single task id and returns the builder for further chaining.
    pub fn add_task(self, task_id: TaskId) -> Self {
        let Self { mut chain } = self;
        chain.add_task(task_id);
        Self { chain }
    }

    /// Appends every id in `task_ids`, preserving their order, and returns the builder.
    pub fn add_tasks(mut self, task_ids: Vec<TaskId>) -> Self {
        // Route each id through `TaskChain::add_task` so there is a single
        // place where tasks enter the chain.
        task_ids
            .into_iter()
            .for_each(|id| self.chain.add_task(id));
        self
    }

    /// Consumes the builder and returns the assembled chain.
    pub fn build(self) -> TaskChain {
        let Self { chain } = self;
        chain
    }
}
// Unit tests for `TaskChain` and `TaskChainBuilder`.
#[cfg(test)]
mod tests {
use super::*;
use crate::{DummyBackend, Task, TaskPriority};
// Minimal `Task` implementation; it only exists so that something can be
// enqueued on the backend in `test_chain_execution`.
struct TestTask {
id: TaskId,
}
impl Task for TestTask {
fn id(&self) -> TaskId {
self.id
}
fn name(&self) -> &str {
"test"
}
fn priority(&self) -> TaskPriority {
TaskPriority::default()
}
}
// A new chain keeps its name, has no tasks and starts out `Pending`.
#[test]
fn test_chain_creation() {
let chain = TaskChain::new("test-chain");
assert_eq!(chain.name(), "test-chain");
assert_eq!(chain.task_count(), 0);
assert_eq!(chain.status(), ChainStatus::Pending);
}
// The first task added becomes the current task.
#[test]
fn test_chain_add_task() {
let mut chain = TaskChain::new("test");
let task_id = TaskId::new();
chain.add_task(task_id);
assert_eq!(chain.task_count(), 1);
assert_eq!(chain.current_task(), Some(task_id));
}
// With two tasks, the first `advance` lands on the second task (true) and
// the second `advance` runs off the end (false).
#[test]
fn test_chain_advance() {
let mut chain = TaskChain::new("test");
chain.add_task(TaskId::new());
chain.add_task(TaskId::new());
assert!(chain.advance());
assert!(!chain.advance());
}
// The builder forwards the name and every individually added task.
#[test]
fn test_chain_builder() {
let task1 = TaskId::new();
let task2 = TaskId::new();
let chain = TaskChainBuilder::new("builder-test")
.add_task(task1)
.add_task(task2)
.build();
assert_eq!(chain.task_count(), 2);
assert_eq!(chain.name(), "builder-test");
}
// `add_tasks` adds every id of the given vector.
#[test]
fn test_chain_builder_multiple() {
let tasks = vec![TaskId::new(), TaskId::new(), TaskId::new()];
let chain = TaskChainBuilder::new("batch").add_tasks(tasks).build();
assert_eq!(chain.task_count(), 3);
}
// `set_status` is reflected by `status`, and `Completed` counts as complete.
#[test]
fn test_chain_status() {
let mut chain = TaskChain::new("test");
assert_eq!(chain.status(), ChainStatus::Pending);
chain.set_status(ChainStatus::Running);
assert_eq!(chain.status(), ChainStatus::Running);
chain.set_status(ChainStatus::Completed);
assert!(chain.is_complete());
}
// Enqueues two tasks on a `DummyBackend` and expects `execute` to finish the
// chain as `Completed`.
// NOTE(review): this presumably relies on `DummyBackend::get_status` reporting
// enqueued tasks as successful; confirm against `DummyBackend`.
#[tokio::test]
async fn test_chain_execution() {
let backend = Arc::new(DummyBackend::new());
let mut chain = TaskChain::new("test-execution");
let task1 = Box::new(TestTask { id: TaskId::new() });
let task2 = Box::new(TestTask { id: TaskId::new() });
let id1 = backend.enqueue(task1).await.unwrap();
let id2 = backend.enqueue(task2).await.unwrap();
chain.add_task(id1);
chain.add_task(id2);
chain.execute(backend).await.unwrap();
assert_eq!(chain.status(), ChainStatus::Completed);
}
}