use crate::error::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// The output produced for a task.
///
/// `raw` always holds the textual output; `json` can optionally carry a
/// structured form attached with `TaskOutput::with_json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskOutput {
/// Raw textual output.
pub raw: String,
/// Optional structured (JSON) form of the output.
pub json: Option<serde_json::Value>,
/// Identifier of the task this output belongs to.
pub task_id: String,
/// Name of the agent that produced the output, if recorded.
pub agent_name: Option<String>,
/// Execution duration in milliseconds, if recorded.
pub duration_ms: Option<u64>,
/// Number of tokens used, if recorded.
pub tokens_used: Option<u32>,
/// Arbitrary extra key/value data; an absent field deserializes as an empty map.
#[serde(default)]
pub metadata: HashMap<String, serde_json::Value>,
}
impl TaskOutput {
pub fn new(raw: impl Into<String>, task_id: impl Into<String>) -> Self {
Self {
raw: raw.into(),
json: None,
task_id: task_id.into(),
agent_name: None,
duration_ms: None,
tokens_used: None,
metadata: HashMap::new(),
}
}
pub fn with_json(mut self, json: serde_json::Value) -> Self {
self.json = Some(json);
self
}
pub fn with_agent(mut self, name: impl Into<String>) -> Self {
self.agent_name = Some(name.into());
self
}
pub fn with_duration(mut self, ms: u64) -> Self {
self.duration_ms = Some(ms);
self
}
pub fn with_tokens(mut self, tokens: u32) -> Self {
self.tokens_used = Some(tokens);
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.metadata.insert(key.into(), value);
self
}
pub fn as_str(&self) -> &str {
&self.raw
}
pub fn parse_json(&self) -> Result<serde_json::Value> {
serde_json::from_str(&self.raw)
.map_err(|e| Error::config(format!("Failed to parse output as JSON: {}", e)))
}
}
impl std::fmt::Display for TaskOutput {
    /// Writes the raw output text.
    ///
    /// Delegates to `str`'s `Display` so that formatter flags such as width,
    /// alignment and precision (`{:>20}`, `{:.80}`) apply to the text;
    /// `write!(f, "{}", ..)` would start a fresh format spec and drop them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self.raw.as_str(), f)
    }
}
/// Lifecycle state of a [`Task`], serialized in `snake_case` (e.g. `"not_started"`).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
    /// Initial state and the `Default`; assigned by `TaskBuilder::build`.
    #[default]
    NotStarted,
    /// Not assigned anywhere in this module; presumably set by the executor.
    InProgress,
    /// Assigned by `Task::set_result`.
    Completed,
    /// Assigned by `Task::set_failed`.
    Failed,
    /// Not assigned anywhere in this module; presumably set by the executor.
    Skipped,
}
/// Kind of a [`Task`], serialized in `snake_case`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskType {
    /// Ordinary task; the `Default`.
    #[default]
    Task,
    /// Selected by `TaskBuilder::decision`.
    Decision,
    /// Selected by `TaskBuilder::loop_task`.
    Loop,
}
/// Error-handling policy stored in [`TaskConfig`], serialized in `snake_case`.
///
/// NOTE(review): the code acting on this policy is outside this module; the
/// variant notes below are inferred from the names — confirm with the executor.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnError {
    /// The `Default` (also what `TaskConfig::default` uses); presumably halts execution.
    #[default]
    Stop,
    /// Presumably proceeds past the failed task.
    Continue,
    /// Presumably re-runs the task, bounded by `TaskConfig::max_retries`.
    Retry,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskConfig {
pub max_retries: u32,
pub retry_delay: f64,
pub on_error: OnError,
pub skip_on_failure: bool,
pub quality_check: bool,
pub async_execution: bool,
}
impl Default for TaskConfig {
fn default() -> Self {
Self {
max_retries: 3,
retry_delay: 0.0,
on_error: OnError::Stop,
skip_on_failure: false,
quality_check: true,
async_execution: false,
}
}
}
/// A unit of work: a description plus routing, retry and output settings.
///
/// Usually created through [`Task::new`], which returns a [`TaskBuilder`].
/// Fields with a natural default are `#[serde(default)]`, so the JSON produced
/// by `Task::to_dict` (which omits e.g. `retry_count` and `config`) can be
/// deserialized back into a `Task`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Unique identifier; `TaskBuilder::build` assigns a fresh UUID v4 string.
    pub id: String,
    /// Optional human-readable name.
    pub name: Option<String>,
    /// Task description; `{{key}}` placeholders are filled by `Task::substitute_variables`.
    pub description: String,
    /// Description of the expected output.
    pub expected_output: String,
    /// Current lifecycle state; `NotStarted` when absent during deserialization.
    #[serde(default)]
    pub status: TaskStatus,
    /// Kind of task; `TaskType::Task` when absent during deserialization.
    #[serde(default)]
    pub task_type: TaskType,
    /// Result of the last execution; never serialized, `None` after deserialization.
    #[serde(skip)]
    pub result: Option<TaskOutput>,
    /// References to tasks this one depends on.
    /// NOTE(review): names vs. ids is decided by the executor — confirm.
    #[serde(default)]
    pub depends_on: Vec<String>,
    /// References to follow-up tasks.
    #[serde(default)]
    pub next_tasks: Vec<String>,
    /// Map from a key to a list of task references.
    /// NOTE(review): presumably decision outcome -> tasks to route to; confirm.
    #[serde(default)]
    pub condition: HashMap<String, Vec<String>>,
    /// Retry and error-handling settings.
    #[serde(default)]
    pub config: TaskConfig,
    /// Optional output file path (consumed outside this module).
    pub output_file: Option<String>,
    /// Optional output variable name (consumed outside this module).
    pub output_variable: Option<String>,
    /// Task-level values for `{{key}}` placeholders in the description.
    #[serde(default)]
    pub variables: HashMap<String, serde_json::Value>,
    /// Number of retries recorded by `Task::increment_retry`; `0` when absent.
    #[serde(default)]
    pub retry_count: u32,
    /// Start-task flag; `false` when absent during deserialization.
    #[serde(default)]
    pub is_start: bool,
}
impl Task {
    /// Starts building a task with the given `description`.
    ///
    /// Returns a [`TaskBuilder`]; call `build()` on it to obtain the `Task`.
    pub fn new(description: impl Into<String>) -> TaskBuilder {
        TaskBuilder::new(description)
    }

    /// Returns the task's identifier.
    pub fn id(&self) -> &str {
        &self.id
    }

    /// Returns the task's name, falling back to its description when unnamed.
    pub fn display_name(&self) -> &str {
        self.name.as_deref().unwrap_or(&self.description)
    }

    /// Returns `true` if the status is [`TaskStatus::Completed`].
    pub fn is_completed(&self) -> bool {
        matches!(self.status, TaskStatus::Completed)
    }

    /// Returns `true` if the status is [`TaskStatus::Failed`].
    pub fn is_failed(&self) -> bool {
        matches!(self.status, TaskStatus::Failed)
    }

    /// Returns `true` while fewer than `config.max_retries` retries have been recorded.
    pub fn can_retry(&self) -> bool {
        self.retry_count < self.config.max_retries
    }

    /// Records one more retry attempt.
    ///
    /// Saturates at `u32::MAX`. A plain `+= 1` would panic in debug builds and
    /// wrap to 0 in release builds, which would make `can_retry` true again.
    pub fn increment_retry(&mut self) {
        self.retry_count = self.retry_count.saturating_add(1);
    }

    /// Stores `output` as the result and marks the task [`TaskStatus::Completed`].
    pub fn set_result(&mut self, output: TaskOutput) {
        self.result = Some(output);
        self.status = TaskStatus::Completed;
    }

    /// Marks the task [`TaskStatus::Failed`] and stores `"Error: {error}"` as its result text.
    pub fn set_failed(&mut self, error: &str) {
        self.status = TaskStatus::Failed;
        self.result = Some(TaskOutput::new(format!("Error: {}", error), &self.id));
    }

    /// Returns the raw text of the stored result, if any.
    pub fn result_str(&self) -> Option<&str> {
        self.result.as_ref().map(|r| r.raw.as_str())
    }

    /// Returns the description with `{{key}}` placeholders filled in.
    ///
    /// The substitution runs in two passes:
    /// 1. Values from `context` are applied first.
    /// 2. The task's own `variables` are applied next. JSON strings are inserted
    ///    without quotes; other JSON values use their JSON text (e.g. `3`, `true`).
    ///
    /// So a key present in both sources takes its value from `context`, and a
    /// placeholder introduced by a `context` value can still be filled from
    /// `variables`. Placeholders with no matching key are left unchanged.
    ///
    /// Each pass is a single left-to-right scan that never rescans inserted
    /// text. This keeps the result independent of `HashMap`'s unspecified
    /// iteration order even when a value itself contains `{{...}}`.
    pub fn substitute_variables(&self, context: &HashMap<String, String>) -> String {
        let with_context =
            Self::replace_placeholders(&self.description, |key| context.get(key).cloned());
        Self::replace_placeholders(&with_context, |key| {
            self.variables.get(key).map(|value| match value {
                serde_json::Value::String(s) => s.clone(),
                other => other.to_string(),
            })
        })
    }

    /// Replaces each `{{key}}` in `template` for which `lookup` returns a value,
    /// in one left-to-right pass. Inserted text is never rescanned.
    fn replace_placeholders<F>(template: &str, lookup: F) -> String
    where
        F: Fn(&str) -> Option<String>,
    {
        let mut out = String::with_capacity(template.len());
        let mut rest = template;
        while let Some(open) = rest.find("{{") {
            let after_open = &rest[open + 2..];
            let close = match after_open.find("}}") {
                Some(close) => close,
                // No closing `}}` further on, so no later placeholder can match either.
                None => break,
            };
            match lookup(&after_open[..close]) {
                Some(value) => {
                    out.push_str(&rest[..open]);
                    out.push_str(&value);
                    rest = &after_open[close + 2..];
                }
                None => {
                    // Unknown key: keep one `{` and retry one byte later, so that
                    // e.g. `{{{a}}}` still matches the placeholder `{{a}}`.
                    out.push_str(&rest[..=open]);
                    rest = &rest[open + 1..];
                }
            }
        }
        out.push_str(rest);
        out
    }

    /// Returns a JSON object with the task's identifying and routing fields.
    ///
    /// Omits `result`, `config`, `variables`, `output_file`, `output_variable`
    /// and `retry_count`.
    pub fn to_dict(&self) -> serde_json::Value {
        serde_json::json!({
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "expected_output": self.expected_output,
            "status": self.status,
            "task_type": self.task_type,
            "depends_on": self.depends_on,
            "next_tasks": self.next_tasks,
            "condition": self.condition,
            "is_start": self.is_start,
        })
    }
}
/// Fluent, consuming builder for [`Task`], obtained from [`Task::new`].
///
/// Every setter takes and returns the builder by value, so discarding a
/// setter's result loses that setting — hence `#[must_use]`.
#[derive(Debug, Clone)]
#[must_use = "a TaskBuilder has no effect until `build()` is called"]
pub struct TaskBuilder {
    description: String,
    name: Option<String>,
    expected_output: String,
    depends_on: Vec<String>,
    next_tasks: Vec<String>,
    condition: HashMap<String, Vec<String>>,
    config: TaskConfig,
    output_file: Option<String>,
    output_variable: Option<String>,
    variables: HashMap<String, serde_json::Value>,
    task_type: TaskType,
    is_start: bool,
}
impl TaskBuilder {
pub fn new(description: impl Into<String>) -> Self {
Self {
description: description.into(),
name: None,
expected_output: "Complete the task successfully".to_string(),
depends_on: Vec::new(),
next_tasks: Vec::new(),
condition: HashMap::new(),
config: TaskConfig::default(),
output_file: None,
output_variable: None,
variables: HashMap::new(),
task_type: TaskType::Task,
is_start: false,
}
}
pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn expected_output(mut self, output: impl Into<String>) -> Self {
self.expected_output = output.into();
self
}
pub fn depends_on(mut self, task: impl Into<String>) -> Self {
self.depends_on.push(task.into());
self
}
pub fn next_task(mut self, task: impl Into<String>) -> Self {
self.next_tasks.push(task.into());
self
}
pub fn task_type(mut self, task_type: TaskType) -> Self {
self.task_type = task_type;
self
}
pub fn decision(mut self) -> Self {
self.task_type = TaskType::Decision;
self
}
pub fn loop_task(mut self) -> Self {
self.task_type = TaskType::Loop;
self
}
pub fn output_file(mut self, path: impl Into<String>) -> Self {
self.output_file = Some(path.into());
self
}
pub fn output_variable(mut self, name: impl Into<String>) -> Self {
self.output_variable = Some(name.into());
self
}
pub fn variable(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
self.variables.insert(key.into(), value);
self
}
pub fn max_retries(mut self, retries: u32) -> Self {
self.config.max_retries = retries;
self
}
pub fn on_error(mut self, behavior: OnError) -> Self {
self.config.on_error = behavior;
self
}
pub fn is_start(mut self, is_start: bool) -> Self {
self.is_start = is_start;
self
}
pub fn build(self) -> Task {
Task {
id: uuid::Uuid::new_v4().to_string(),
name: self.name,
description: self.description,
expected_output: self.expected_output,
status: TaskStatus::NotStarted,
task_type: self.task_type,
result: None,
depends_on: self.depends_on,
next_tasks: self.next_tasks,
condition: self.condition,
config: self.config,
output_file: self.output_file,
output_variable: self.output_variable,
variables: self.variables,
retry_count: 0,
is_start: self.is_start,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_creation() {
        let task = Task::new("Research AI trends")
            .name("research_task")
            .expected_output("A summary of AI trends")
            .build();
        assert_eq!(task.description, "Research AI trends");
        assert_eq!(task.name, Some("research_task".to_string()));
        assert_eq!(task.status, TaskStatus::NotStarted);
    }

    #[test]
    fn test_task_output() {
        let output = TaskOutput::new("Hello world", "task-1")
            .with_agent("my-agent")
            .with_duration(100);
        assert_eq!(output.raw, "Hello world");
        assert_eq!(output.agent_name, Some("my-agent".to_string()));
        assert_eq!(output.duration_ms, Some(100));
    }

    #[test]
    fn test_task_output_display_and_json() {
        let output = TaskOutput::new("plain text", "task-1");
        assert_eq!(output.to_string(), "plain text");
        assert_eq!(output.as_str(), "plain text");

        let json_output = TaskOutput::new(r#"{"a": 1}"#, "task-2");
        assert_eq!(json_output.parse_json().ok(), Some(serde_json::json!({"a": 1})));
        assert!(output.parse_json().is_err());
    }

    #[test]
    fn test_task_dependencies() {
        let task = Task::new("Analyze data")
            .depends_on("collect_data")
            .depends_on("clean_data")
            .build();
        assert_eq!(task.depends_on.len(), 2);
        assert!(task.depends_on.contains(&"collect_data".to_string()));
    }

    #[test]
    fn test_variable_substitution() {
        let task = Task::new("Research {{topic}} trends")
            .variable("topic", serde_json::json!("AI"))
            .build();
        let context = HashMap::new();
        let result = task.substitute_variables(&context);
        assert_eq!(result, "Research AI trends");
    }

    #[test]
    fn test_context_overrides_task_variables() {
        let task = Task::new("Hello {{who}}")
            .variable("who", serde_json::json!("variable"))
            .build();
        let mut context = HashMap::new();
        context.insert("who".to_string(), "context".to_string());
        assert_eq!(task.substitute_variables(&context), "Hello context");
    }

    #[test]
    fn test_non_string_variables_and_unknown_placeholders() {
        let task = Task::new("n={{n}} ok={{ok}} {{missing}}")
            .variable("n", serde_json::json!(3))
            .variable("ok", serde_json::json!(true))
            .build();
        assert_eq!(
            task.substitute_variables(&HashMap::new()),
            "n=3 ok=true {{missing}}"
        );
    }

    #[test]
    fn test_task_status() {
        let mut task = Task::new("Test task").build();
        assert!(!task.is_completed());
        assert!(!task.is_failed());
        task.set_result(TaskOutput::new("Done", &task.id));
        assert!(task.is_completed());
        let mut task2 = Task::new("Test task 2").build();
        task2.set_failed("Something went wrong");
        assert!(task2.is_failed());
    }

    #[test]
    fn test_display_name_and_failure_result() {
        let unnamed = Task::new("Describe me").build();
        assert_eq!(unnamed.display_name(), "Describe me");
        let named = Task::new("Describe me").name("short").build();
        assert_eq!(named.display_name(), "short");

        let mut failed = Task::new("Will fail").build();
        failed.set_failed("boom");
        assert_eq!(failed.result_str(), Some("Error: boom"));
        assert_eq!(
            failed.result.as_ref().map(|r| r.task_id.as_str()),
            Some(failed.id())
        );
    }

    #[test]
    fn test_retry_logic() {
        let mut task = Task::new("Retryable task").max_retries(3).build();
        assert!(task.can_retry());
        task.increment_retry();
        assert!(task.can_retry());
        task.increment_retry();
        assert!(task.can_retry());
        task.increment_retry();
        assert!(!task.can_retry());
    }
}