use crate::spec_ai_collective::types::{CollectiveError, Domain, ExecutionId, InstanceId, Result, WorkflowId};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// How a [`WorkflowStage`] is meant to be carried out.
///
/// Variant names are serialized in snake_case (e.g. `map_reduce`). Nothing in
/// this module interprets the variant payloads; they are stored as given.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum StageType {
    /// Plain single step; produced by `WorkflowStage::sequential`.
    Sequential,
    /// Fan-out step; `min_agents` presumably is the minimum number of agents
    /// to involve (not enforced in this module).
    Parallel { min_agents: usize },
    /// Step split into `chunks` pieces (not enforced in this module).
    MapReduce { chunks: usize },
    /// Agreement-based step; `min_agreement` presumably is a fraction in
    /// `0.0..=1.0` — TODO confirm with the consumer of this type.
    Consensus { min_agreement: f32 },
    /// Step guarded by `condition`; the condition syntax is not defined here.
    ConditionalBranch { condition: String },
}
/// Lifecycle state of one stage inside a [`WorkflowExecution`].
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum StageState {
    /// Initial state assigned by `StageExecution::new`.
    Pending,
    /// Not assigned anywhere in this module; readiness is computed on demand
    /// by `WorkflowExecution::ready_stages` instead.
    Ready,
    /// Set by `WorkflowEngine::start_stage`.
    Running,
    /// Set by `WorkflowEngine::complete_stage`.
    Completed,
    /// Set by `WorkflowEngine::fail_stage`, carrying the failure reason.
    Failed { reason: String },
    /// Not assigned in this module; counts as finished for `is_complete`.
    Skipped,
}
/// Static definition of one step in a [`Workflow`].
///
/// Stages refer to each other by `stage_id` through `dependencies`; see
/// [`Workflow::validate`] for the checks applied to that graph.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkflowStage {
    /// Identifier that other stages use in their `dependencies`.
    pub stage_id: String,
    /// Human-readable name.
    pub name: String,
    /// Human-readable description of the work.
    pub description: String,
    /// How the stage is meant to be carried out.
    pub stage_type: StageType,
    /// Capabilities an agent needs for this stage (not checked in this module).
    pub required_capabilities: Vec<Domain>,
    /// `stage_id`s that must be `Completed` before this stage becomes ready.
    pub dependencies: Vec<String>,
    /// Time budget for the stage; stored here but not enforced by this module.
    pub timeout: Duration,
    /// Free-form stage-specific configuration (the constructors use `{}`).
    pub config: serde_json::Value,
}
impl WorkflowStage {
pub fn sequential(
stage_id: impl Into<String>,
name: impl Into<String>,
description: impl Into<String>,
) -> Self {
Self {
stage_id: stage_id.into(),
name: name.into(),
description: description.into(),
stage_type: StageType::Sequential,
required_capabilities: Vec::new(),
dependencies: Vec::new(),
timeout: Duration::minutes(30),
config: serde_json::json!({}),
}
}
pub fn parallel(
stage_id: impl Into<String>,
name: impl Into<String>,
description: impl Into<String>,
min_agents: usize,
) -> Self {
Self {
stage_id: stage_id.into(),
name: name.into(),
description: description.into(),
stage_type: StageType::Parallel { min_agents },
required_capabilities: Vec::new(),
dependencies: Vec::new(),
timeout: Duration::minutes(30),
config: serde_json::json!({}),
}
}
pub fn map_reduce(
stage_id: impl Into<String>,
name: impl Into<String>,
description: impl Into<String>,
chunks: usize,
) -> Self {
Self {
stage_id: stage_id.into(),
name: name.into(),
description: description.into(),
stage_type: StageType::MapReduce { chunks },
required_capabilities: Vec::new(),
dependencies: Vec::new(),
timeout: Duration::minutes(60),
config: serde_json::json!({}),
}
}
pub fn consensus(
stage_id: impl Into<String>,
name: impl Into<String>,
description: impl Into<String>,
min_agreement: f32,
) -> Self {
Self {
stage_id: stage_id.into(),
name: name.into(),
description: description.into(),
stage_type: StageType::Consensus { min_agreement },
required_capabilities: Vec::new(),
dependencies: Vec::new(),
timeout: Duration::hours(1),
config: serde_json::json!({}),
}
}
pub fn with_capabilities(mut self, capabilities: Vec<String>) -> Self {
self.required_capabilities = capabilities;
self
}
pub fn with_dependencies(mut self, dependencies: Vec<String>) -> Self {
self.dependencies = dependencies;
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_config(mut self, config: serde_json::Value) -> Self {
self.config = config;
self
}
}
/// Lifecycle state shared by [`Workflow`] definitions and their executions.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum WorkflowState {
    /// State of a freshly built `Workflow` (set by `Workflow::new`).
    Draft,
    /// Initial state of a `WorkflowExecution`.
    Running,
    /// Set once every stage of an execution is completed or skipped.
    Completed,
    /// Set by `WorkflowEngine::fail_stage`, carrying the failure reason.
    Failed { reason: String },
    /// Not assigned anywhere in this module.
    Cancelled,
    /// Not assigned anywhere in this module.
    Paused,
}
/// A named graph of [`WorkflowStage`]s plus the JSON input it runs with.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Workflow {
    /// Unique id (a random UUID v4 string when built via `Workflow::new`).
    pub workflow_id: WorkflowId,
    /// Human-readable name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Stages in declaration order; `ready_stages` reports them in this order.
    pub stages: Vec<WorkflowStage>,
    /// Definition state; `Workflow::new` sets `Draft`.
    pub state: WorkflowState,
    /// Instance that authored the workflow.
    pub created_by: InstanceId,
    /// Creation time in UTC.
    pub created_at: DateTime<Utc>,
    /// Free-form workflow input (`{}` unless set with `with_input`).
    pub input: serde_json::Value,
}
impl Workflow {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
created_by: InstanceId,
) -> Self {
Self {
workflow_id: uuid::Uuid::new_v4().to_string(),
name: name.into(),
description: description.into(),
stages: Vec::new(),
state: WorkflowState::Draft,
created_by,
created_at: Utc::now(),
input: serde_json::json!({}),
}
}
pub fn add_stage(mut self, stage: WorkflowStage) -> Self {
self.stages.push(stage);
self
}
pub fn with_input(mut self, input: serde_json::Value) -> Self {
self.input = input;
self
}
pub fn validate(&self) -> Result<()> {
let stage_ids: std::collections::HashSet<_> =
self.stages.iter().map(|s| s.stage_id.as_str()).collect();
for stage in &self.stages {
for dep in &stage.dependencies {
if !stage_ids.contains(dep.as_str()) {
return Err(CollectiveError::WorkflowExecutionFailed(format!(
"Stage {} depends on unknown stage {}",
stage.stage_id, dep
)));
}
}
}
for stage in &self.stages {
if stage.dependencies.contains(&stage.stage_id) {
return Err(CollectiveError::WorkflowExecutionFailed(format!(
"Stage {} has a self-dependency",
stage.stage_id
)));
}
}
Ok(())
}
}
/// Runtime record of one stage within a [`WorkflowExecution`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StageExecution {
    /// Id of the [`WorkflowStage`] this record tracks.
    pub stage_id: String,
    /// Current lifecycle state (starts as `Pending`).
    pub state: StageState,
    /// Agents handed the stage by `WorkflowEngine::start_stage`.
    pub assigned_agents: Vec<InstanceId>,
    /// Per-agent partial results recorded by `record_stage_result`.
    pub results: HashMap<InstanceId, serde_json::Value>,
    /// When the stage was started, if it has been.
    pub started_at: Option<DateTime<Utc>>,
    /// When the stage completed or failed, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Failure reason recorded by `fail_stage`.
    pub error: Option<String>,
}
impl StageExecution {
    /// Builds a fresh `Pending` record for `stage_id` with no agents, no
    /// results, no timestamps and no error.
    pub fn new(stage_id: String) -> Self {
        let state = StageState::Pending;
        Self {
            stage_id,
            state,
            assigned_agents: Vec::default(),
            results: HashMap::default(),
            started_at: None,
            completed_at: None,
            error: None,
        }
    }
}
/// One run of a registered [`Workflow`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct WorkflowExecution {
    /// Id of the workflow being executed.
    pub workflow_id: WorkflowId,
    /// Unique id of this run (a random UUID v4 string).
    pub execution_id: ExecutionId,
    /// Per-stage runtime records keyed by `stage_id`.
    pub stages: HashMap<String, StageExecution>,
    /// Final result of each completed stage, keyed by `stage_id`.
    pub results: HashMap<String, serde_json::Value>,
    /// When the run was created.
    pub started_at: DateTime<Utc>,
    /// When the run completed or failed, if it has.
    pub completed_at: Option<DateTime<Utc>>,
    /// Overall run state.
    pub state: WorkflowState,
}
impl WorkflowExecution {
pub fn new(workflow: &Workflow) -> Self {
let mut stages = HashMap::new();
for stage in &workflow.stages {
stages.insert(
stage.stage_id.clone(),
StageExecution::new(stage.stage_id.clone()),
);
}
Self {
workflow_id: workflow.workflow_id.clone(),
execution_id: uuid::Uuid::new_v4().to_string(),
stages,
results: HashMap::new(),
started_at: Utc::now(),
completed_at: None,
state: WorkflowState::Running,
}
}
pub fn ready_stages<'a>(&self, workflow: &'a Workflow) -> Vec<&'a str> {
let mut ready = Vec::new();
for stage in &workflow.stages {
if let Some(execution) = self.stages.get(&stage.stage_id) {
if execution.state != StageState::Pending {
continue;
}
let deps_completed = stage.dependencies.iter().all(|dep| {
self.stages
.get(dep)
.map(|s| s.state == StageState::Completed)
.unwrap_or(false)
});
if deps_completed {
ready.push(stage.stage_id.as_str());
}
}
}
ready
}
pub fn is_complete(&self) -> bool {
self.stages
.values()
.all(|s| matches!(s.state, StageState::Completed | StageState::Skipped))
}
pub fn has_failed(&self) -> bool {
self.stages
.values()
.any(|s| matches!(s.state, StageState::Failed { .. }))
}
}
/// In-memory registry of workflows and their executions.
///
/// All state lives in plain `HashMap`s and mutation goes through `&mut self`.
/// There is no internal locking.
#[derive(Debug)]
pub struct WorkflowEngine {
    /// Id of the instance that owns this engine.
    instance_id: InstanceId,
    /// Registered workflow definitions keyed by workflow id.
    workflows: HashMap<WorkflowId, Workflow>,
    /// Executions keyed by execution id; finished ones remain until
    /// `cleanup_completed` removes them.
    executions: HashMap<ExecutionId, WorkflowExecution>,
    /// Limit checked by `start_execution`.
    max_concurrent: usize,
}
impl WorkflowEngine {
pub fn new(instance_id: InstanceId) -> Self {
Self {
instance_id,
workflows: HashMap::new(),
executions: HashMap::new(),
max_concurrent: 5,
}
}
pub fn instance_id(&self) -> &str {
&self.instance_id
}
pub fn set_max_concurrent(&mut self, max: usize) {
self.max_concurrent = max;
}
pub fn register_workflow(&mut self, workflow: Workflow) -> Result<WorkflowId> {
workflow.validate()?;
let workflow_id = workflow.workflow_id.clone();
self.workflows.insert(workflow_id.clone(), workflow);
Ok(workflow_id)
}
pub fn get_workflow(&self, workflow_id: &str) -> Option<&Workflow> {
self.workflows.get(workflow_id)
}
pub fn start_execution(&mut self, workflow_id: &str) -> Result<ExecutionId> {
if self.executions.len() >= self.max_concurrent {
return Err(CollectiveError::WorkflowExecutionFailed(
"Maximum concurrent workflows reached".to_string(),
));
}
let workflow = self
.workflows
.get(workflow_id)
.ok_or_else(|| CollectiveError::WorkflowNotFound(workflow_id.to_string()))?;
let execution = WorkflowExecution::new(workflow);
let execution_id = execution.execution_id.clone();
self.executions.insert(execution_id.clone(), execution);
Ok(execution_id)
}
pub fn get_execution(&self, execution_id: &str) -> Option<&WorkflowExecution> {
self.executions.get(execution_id)
}
pub fn get_execution_mut(&mut self, execution_id: &str) -> Option<&mut WorkflowExecution> {
self.executions.get_mut(execution_id)
}
pub fn start_stage(
&mut self,
execution_id: &str,
stage_id: &str,
agents: Vec<InstanceId>,
) -> Result<()> {
let execution = self
.executions
.get_mut(execution_id)
.ok_or_else(|| CollectiveError::WorkflowNotFound(execution_id.to_string()))?;
if let Some(stage) = execution.stages.get_mut(stage_id) {
stage.state = StageState::Running;
stage.assigned_agents = agents;
stage.started_at = Some(Utc::now());
}
Ok(())
}
pub fn record_stage_result(
&mut self,
execution_id: &str,
stage_id: &str,
agent_id: InstanceId,
result: serde_json::Value,
) -> Result<()> {
let execution = self
.executions
.get_mut(execution_id)
.ok_or_else(|| CollectiveError::WorkflowNotFound(execution_id.to_string()))?;
if let Some(stage) = execution.stages.get_mut(stage_id) {
stage.results.insert(agent_id, result);
}
Ok(())
}
pub fn complete_stage(
&mut self,
execution_id: &str,
stage_id: &str,
final_result: serde_json::Value,
) -> Result<()> {
let execution = self
.executions
.get_mut(execution_id)
.ok_or_else(|| CollectiveError::WorkflowNotFound(execution_id.to_string()))?;
if let Some(stage) = execution.stages.get_mut(stage_id) {
stage.state = StageState::Completed;
stage.completed_at = Some(Utc::now());
}
execution.results.insert(stage_id.to_string(), final_result);
if execution.is_complete() {
execution.state = WorkflowState::Completed;
execution.completed_at = Some(Utc::now());
}
Ok(())
}
pub fn fail_stage(&mut self, execution_id: &str, stage_id: &str, reason: String) -> Result<()> {
let execution = self
.executions
.get_mut(execution_id)
.ok_or_else(|| CollectiveError::WorkflowNotFound(execution_id.to_string()))?;
if let Some(stage) = execution.stages.get_mut(stage_id) {
stage.state = StageState::Failed {
reason: reason.clone(),
};
stage.error = Some(reason.clone());
stage.completed_at = Some(Utc::now());
}
execution.state = WorkflowState::Failed { reason };
execution.completed_at = Some(Utc::now());
Ok(())
}
pub fn get_ready_stages(&self, execution_id: &str) -> Result<Vec<String>> {
let execution = self
.executions
.get(execution_id)
.ok_or_else(|| CollectiveError::WorkflowNotFound(execution_id.to_string()))?;
let workflow = self
.workflows
.get(&execution.workflow_id)
.ok_or_else(|| CollectiveError::WorkflowNotFound(execution.workflow_id.clone()))?;
Ok(execution
.ready_stages(workflow)
.into_iter()
.map(String::from)
.collect())
}
pub fn active_executions(&self) -> Vec<&WorkflowExecution> {
self.executions
.values()
.filter(|e| e.state == WorkflowState::Running)
.collect()
}
pub fn cleanup_completed(&mut self, max_age: Duration) -> usize {
let cutoff = Utc::now() - max_age;
let before = self.executions.len();
self.executions
.retain(|_, e| e.completed_at.map(|t| t > cutoff).unwrap_or(true));
before - self.executions.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two sequential stages where `s2` depends on `s1`.
    fn two_stage_workflow() -> Workflow {
        Workflow::new("test", "Test", "agent-1".to_string())
            .add_stage(WorkflowStage::sequential("s1", "Stage 1", "First"))
            .add_stage(
                WorkflowStage::sequential("s2", "Stage 2", "Second")
                    .with_dependencies(vec!["s1".to_string()]),
            )
    }

    #[test]
    fn test_workflow_creation() {
        let workflow = Workflow::new("test-workflow", "A test workflow", "agent-1".to_string())
            .add_stage(WorkflowStage::sequential(
                "stage-1",
                "First Stage",
                "Do first thing",
            ))
            .add_stage(
                WorkflowStage::parallel("stage-2", "Second Stage", "Do in parallel", 2)
                    .with_dependencies(vec!["stage-1".to_string()]),
            );
        assert_eq!(workflow.stages.len(), 2);
        assert!(workflow.validate().is_ok());
    }

    #[test]
    fn test_workflow_execution() {
        let mut engine = WorkflowEngine::new("agent-1".to_string());
        let workflow_id = engine.register_workflow(two_stage_workflow()).unwrap();
        let execution_id = engine.start_execution(&workflow_id).unwrap();
        let ready = engine.get_ready_stages(&execution_id).unwrap();
        assert_eq!(ready, vec!["s1"]);
        engine
            .start_stage(&execution_id, "s1", vec!["agent-1".to_string()])
            .unwrap();
        engine
            .complete_stage(&execution_id, "s1", serde_json::json!({"done": true}))
            .unwrap();
        let ready = engine.get_ready_stages(&execution_id).unwrap();
        assert_eq!(ready, vec!["s2"]);
    }

    #[test]
    fn test_validate_rejects_unknown_dependency() {
        let workflow = Workflow::new("bad", "Bad", "agent-1".to_string()).add_stage(
            WorkflowStage::sequential("s1", "Stage 1", "First")
                .with_dependencies(vec!["missing".to_string()]),
        );
        assert!(workflow.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_self_dependency() {
        let workflow = Workflow::new("bad", "Bad", "agent-1".to_string()).add_stage(
            WorkflowStage::sequential("s1", "Stage 1", "First")
                .with_dependencies(vec!["s1".to_string()]),
        );
        assert!(workflow.validate().is_err());
    }

    #[test]
    fn test_register_rejects_invalid_workflow() {
        let mut engine = WorkflowEngine::new("agent-1".to_string());
        let workflow = Workflow::new("bad", "Bad", "agent-1".to_string()).add_stage(
            WorkflowStage::sequential("s1", "Stage 1", "First")
                .with_dependencies(vec!["missing".to_string()]),
        );
        assert!(engine.register_workflow(workflow).is_err());
    }

    #[test]
    fn test_execution_completes_after_all_stages() {
        let mut engine = WorkflowEngine::new("agent-1".to_string());
        let workflow_id = engine.register_workflow(two_stage_workflow()).unwrap();
        let execution_id = engine.start_execution(&workflow_id).unwrap();
        for stage in ["s1", "s2"] {
            engine
                .start_stage(&execution_id, stage, vec!["agent-1".to_string()])
                .unwrap();
            engine
                .complete_stage(&execution_id, stage, serde_json::json!({ "stage": stage }))
                .unwrap();
        }
        let execution = engine.get_execution(&execution_id).unwrap();
        assert_eq!(execution.state, WorkflowState::Completed);
        assert!(execution.completed_at.is_some());
        assert_eq!(execution.results.len(), 2);
        assert!(engine.get_ready_stages(&execution_id).unwrap().is_empty());
        assert!(engine.active_executions().is_empty());
    }

    #[test]
    fn test_fail_stage_marks_execution_failed() {
        let mut engine = WorkflowEngine::new("agent-1".to_string());
        let workflow_id = engine.register_workflow(two_stage_workflow()).unwrap();
        let execution_id = engine.start_execution(&workflow_id).unwrap();
        engine
            .start_stage(&execution_id, "s1", vec!["agent-1".to_string()])
            .unwrap();
        engine
            .fail_stage(&execution_id, "s1", "boom".to_string())
            .unwrap();
        let execution = engine.get_execution(&execution_id).unwrap();
        assert!(execution.has_failed());
        assert_eq!(
            execution.state,
            WorkflowState::Failed {
                reason: "boom".to_string()
            }
        );
        assert_eq!(execution.stages["s1"].error.as_deref(), Some("boom"));
    }
}