use crate::task_queue::TaskQueue;
use crate::types::{AgentRole, Task};
use argentor_core::{ArgentorError, ArgentorResult};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Request to create a new sub-task underneath an existing task in the queue.
///
/// Consumed by `SubAgentSpawner::spawn`, which validates it against the
/// spawner's depth and fan-out limits before enqueuing the resulting task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpawnRequest {
/// Description handed to `Task::new` for the new sub-task.
pub description: String,
/// Agent role handed to `Task::new` for the new sub-task.
pub role: AgentRole,
/// Id of the task that becomes the new sub-task's `parent_task`; it must
/// already be present in the queue or spawning fails.
pub parent_task_id: Uuid,
/// Ids passed to `Task::with_dependencies` for the new sub-task.
/// `#[serde(default)]` makes this an empty list when the field is absent
/// from the deserialized input.
#[serde(default)]
pub depends_on: Vec<Uuid>,
}
/// Creates child tasks in a shared `TaskQueue` while enforcing a maximum
/// nesting depth and a maximum number of direct children per parent task.
pub struct SubAgentSpawner {
// Largest `depth` a spawned task may have; a spawn that would exceed it is rejected.
max_depth: u32,
// Largest number of direct children a single parent task may have.
max_children_per_task: u32,
// Shared queue that parents are looked up in and new sub-tasks are added to.
queue: Arc<RwLock<TaskQueue>>,
}
impl SubAgentSpawner {
    /// Creates a spawner operating on `queue` with the default limits:
    /// a maximum depth of 3 and at most 5 children per parent task.
    #[must_use]
    pub fn new(queue: Arc<RwLock<TaskQueue>>) -> Self {
        Self {
            max_depth: 3,
            max_children_per_task: 5,
            queue,
        }
    }

    /// Overrides the maximum depth a spawned task may have (builder style).
    #[must_use]
    pub fn with_max_depth(mut self, depth: u32) -> Self {
        self.max_depth = depth;
        self
    }

    /// Overrides the maximum number of direct children per parent task
    /// (builder style).
    #[must_use]
    pub fn with_max_children(mut self, max: u32) -> Self {
        self.max_children_per_task = max;
        self
    }

    /// Validates `request` against the configured limits and, if it passes,
    /// adds the new sub-task to the queue.
    ///
    /// The new task gets `parent_task = Some(request.parent_task_id)` and a
    /// `depth` one greater than its parent's. Returns the id of the new task.
    ///
    /// Validation and insertion happen under a single write lock, so
    /// concurrent calls for the same parent cannot both pass the child-count
    /// check and push the parent over `max_children_per_task`.
    ///
    /// # Errors
    ///
    /// Returns `ArgentorError::Orchestrator` when:
    /// - the parent task is not in the queue,
    /// - the new task's depth would exceed the maximum depth, or
    /// - the parent already has the maximum number of children.
    pub async fn spawn(&self, request: SpawnRequest) -> ArgentorResult<Uuid> {
        let parent_id = request.parent_task_id;

        // Take the write lock up front: checking under a read lock and then
        // re-locking for the insert would leave a window in which another
        // spawn could add a child and invalidate the checks below.
        let mut queue = self.queue.write().await;

        let parent_depth = queue
            .get(parent_id)
            .ok_or_else(|| {
                ArgentorError::Orchestrator(format!("parent task {} not found", parent_id))
            })?
            .depth;

        // Same condition as `parent_depth + 1 > max_depth`, written so the
        // addition cannot overflow.
        if parent_depth >= self.max_depth {
            return Err(ArgentorError::Orchestrator(format!(
                "maximum spawn depth {} exceeded (parent depth is {})",
                self.max_depth, parent_depth
            )));
        }

        let children_count = queue
            .all_tasks()
            .iter()
            .filter(|t| t.parent_task == Some(parent_id))
            .count();
        if children_count >= self.max_children_per_task as usize {
            return Err(ArgentorError::Orchestrator(format!(
                "parent task {} already has {} children (max {})",
                parent_id, children_count, self.max_children_per_task
            )));
        }

        let mut task =
            Task::new(&request.description, request.role).with_dependencies(request.depends_on);
        task.parent_task = Some(parent_id);
        // Cannot overflow: the depth check above guarantees
        // `parent_depth < max_depth`.
        task.depth = parent_depth + 1;
        let task_id = task.id;

        queue.add(task);
        Ok(task_id)
    }

    /// Returns the ids of every task in the queue whose `parent_task` is
    /// `parent_id` (direct children only). The result is empty when there
    /// are none.
    pub async fn children_of(&self, parent_id: Uuid) -> Vec<Uuid> {
        let queue = self.queue.read().await;
        queue
            .all_tasks()
            .iter()
            .filter(|t| t.parent_task == Some(parent_id))
            .map(|t| t.id)
            .collect()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::types::AgentRole;

    /// Builds a shared queue containing a single root task and returns the
    /// queue together with the root task's id.
    fn setup_queue_with_root() -> (Arc<RwLock<TaskQueue>>, Uuid) {
        let root = Task::new("Root task", AgentRole::Orchestrator);
        let root_id = root.id;
        let mut tq = TaskQueue::new();
        tq.add(root);
        (Arc::new(RwLock::new(tq)), root_id)
    }

    /// Shorthand for assembling a `SpawnRequest`.
    fn request(
        description: &str,
        role: AgentRole,
        parent_task_id: Uuid,
        depends_on: Vec<Uuid>,
    ) -> SpawnRequest {
        SpawnRequest {
            description: description.to_string(),
            role,
            parent_task_id,
            depends_on,
        }
    }

    /// A spawned child carries the requested description and role and is
    /// linked to its parent at depth 1.
    #[tokio::test]
    async fn test_spawn_sub_task_successfully() {
        let (queue, root_id) = setup_queue_with_root();
        let spawner = SubAgentSpawner::new(Arc::clone(&queue));

        let child_id = spawner
            .spawn(request(
                "Implement auth module",
                AgentRole::Coder,
                root_id,
                Vec::new(),
            ))
            .await
            .unwrap();

        let guard = queue.read().await;
        let child = guard.get(child_id).unwrap();
        assert_eq!(child.description, "Implement auth module");
        assert_eq!(child.assigned_to, AgentRole::Coder);
        assert_eq!(child.parent_task, Some(root_id));
        assert_eq!(child.depth, 1);
    }

    /// With a maximum depth of 2, spawning below a depth-2 task is rejected.
    #[tokio::test]
    async fn test_depth_limit_prevents_too_deep_spawning() {
        let (queue, root_id) = setup_queue_with_root();
        let spawner = SubAgentSpawner::new(Arc::clone(&queue)).with_max_depth(2);

        // Build the chain root -> level 1 -> level 2; both spawns must succeed.
        let mut parent_id = root_id;
        for (description, role) in [("Level 1", AgentRole::Spec), ("Level 2", AgentRole::Coder)] {
            parent_id = spawner
                .spawn(request(description, role, parent_id, Vec::new()))
                .await
                .unwrap();
        }

        let outcome = spawner
            .spawn(request(
                "Level 3 (too deep)",
                AgentRole::Tester,
                parent_id,
                Vec::new(),
            ))
            .await;

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("maximum spawn depth"),
            "unexpected error: {err_msg}"
        );
    }

    /// With a maximum of 2 children, a third child of the same parent is
    /// rejected.
    #[tokio::test]
    async fn test_children_limit_prevents_too_many_children() {
        let (queue, root_id) = setup_queue_with_root();
        let spawner = SubAgentSpawner::new(Arc::clone(&queue)).with_max_children(2);

        for (description, role) in [("Child 1", AgentRole::Coder), ("Child 2", AgentRole::Tester)] {
            spawner
                .spawn(request(description, role, root_id, Vec::new()))
                .await
                .unwrap();
        }

        let outcome = spawner
            .spawn(request(
                "Child 3 (over limit)",
                AgentRole::Reviewer,
                root_id,
                Vec::new(),
            ))
            .await;

        assert!(outcome.is_err());
        let err_msg = outcome.unwrap_err().to_string();
        assert!(
            err_msg.contains("already has 2 children"),
            "unexpected error: {err_msg}"
        );
    }

    /// `children_of` lists direct children only, not grandchildren.
    #[tokio::test]
    async fn test_children_of_returns_correct_children() {
        let (queue, root_id) = setup_queue_with_root();
        let spawner = SubAgentSpawner::new(Arc::clone(&queue));

        let c1 = spawner
            .spawn(request("Child A", AgentRole::Spec, root_id, Vec::new()))
            .await
            .unwrap();
        let c2 = spawner
            .spawn(request("Child B", AgentRole::Coder, root_id, Vec::new()))
            .await
            .unwrap();
        let _gc = spawner
            .spawn(request("Grandchild", AgentRole::Tester, c1, Vec::new()))
            .await
            .unwrap();

        let children = spawner.children_of(root_id).await;
        assert_eq!(children.len(), 2);
        assert!(children.contains(&c1));
        assert!(children.contains(&c2));
    }

    /// `depth` and `parent_task` are set correctly along a two-level chain.
    #[tokio::test]
    async fn test_parent_task_and_depth_fields_set_correctly() {
        let (queue, root_id) = setup_queue_with_root();
        let spawner = SubAgentSpawner::new(Arc::clone(&queue));

        let d1_id = spawner
            .spawn(request("Depth 1", AgentRole::Spec, root_id, Vec::new()))
            .await
            .unwrap();
        let d2_id = spawner
            .spawn(request("Depth 2", AgentRole::Coder, d1_id, Vec::new()))
            .await
            .unwrap();

        let guard = queue.read().await;

        let root = guard.get(root_id).unwrap();
        assert_eq!(root.depth, 0);
        assert_eq!(root.parent_task, None);

        let d1 = guard.get(d1_id).unwrap();
        assert_eq!(d1.depth, 1);
        assert_eq!(d1.parent_task, Some(root_id));

        let d2 = guard.get(d2_id).unwrap();
        assert_eq!(d2.depth, 2);
        assert_eq!(d2.parent_task, Some(d1_id));
    }

    /// Dependencies supplied in the request end up on the spawned task and
    /// drive its readiness.
    #[tokio::test]
    async fn test_spawn_with_dependencies() {
        let (queue, root_id) = setup_queue_with_root();
        let spawner = SubAgentSpawner::new(Arc::clone(&queue));

        let spec_id = spawner
            .spawn(request("Write spec", AgentRole::Spec, root_id, Vec::new()))
            .await
            .unwrap();
        let code_id = spawner
            .spawn(request(
                "Write code",
                AgentRole::Coder,
                root_id,
                vec![spec_id],
            ))
            .await
            .unwrap();
        let test_id = spawner
            .spawn(request(
                "Write tests",
                AgentRole::Tester,
                root_id,
                vec![spec_id, code_id],
            ))
            .await
            .unwrap();

        let guard = queue.read().await;
        let test_task = guard.get(test_id).unwrap();
        assert_eq!(test_task.dependencies.len(), 2);
        assert!(test_task.dependencies.contains(&spec_id));
        assert!(test_task.dependencies.contains(&code_id));
        assert_eq!(test_task.parent_task, Some(root_id));
        assert!(!test_task.is_ready(&[]));
        assert!(test_task.is_ready(&[spec_id, code_id]));
    }
}