use std::collections::{HashMap, HashSet, VecDeque};
use std::time::Duration;
use serde::{Deserialize, Serialize};
use super::coordination::AgentId;
use super::types::{LoopConfig, LoopSummary};
/// Total resource ceiling for an entire supervisor run, shared across all
/// spawned agents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceBudget {
    // Maximum loop iterations summed over all agents.
    pub total_iterations: u32,
    // Maximum tool invocations summed over all agents.
    pub total_tool_calls: u32,
    // Maximum tokens summed over all agents.
    pub total_tokens: u32,
}
impl Default for ResourceBudget {
    /// Defaults: 100 iterations, 500 tool calls, 131_072 (128K) tokens.
    fn default() -> Self {
        Self {
            total_iterations: 100,
            total_tool_calls: 500,
            total_tokens: 131_072,
        }
    }
}
/// Coarse difficulty rating for a subtask, used to weight budget allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Complexity {
    Low,
    Medium,
    High,
}
impl Complexity {
    /// Relative share of the budget this complexity level should receive
    /// (High gets four times the resources of Low).
    fn weight(self) -> f64 {
        match self {
            Complexity::Low => 1.0,
            Complexity::Medium => 2.0,
            Complexity::High => 4.0,
        }
    }
}
/// A unit of work the supervisor can assign to a single agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Subtask {
    // Unique id; other subtasks reference it in `depends_on`.
    pub id: String,
    // Natural-language goal handed to the executing agent.
    pub objective: String,
    // Ids of subtasks that must complete before this one becomes ready.
    pub depends_on: Vec<String>,
    // Capability tags; matched by substring during expertise-based reroutes.
    pub capabilities: Vec<String>,
    // Drives the proportional budget slice this subtask receives.
    pub complexity: Complexity,
}
/// How the top-level task is broken into subtasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DecompositionStrategy {
    /// No decomposition: run the whole task as one agent.
    SingleAgent,
    /// Use a subtask list supplied by the client verbatim.
    ClientProvided {
        subtasks: Vec<Subtask>,
    },
}
/// How ready subtasks are dispatched to agents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RoutingStrategy {
    Parallel,
    DependencyAware,
}
/// How much context sibling agents share with each other.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SharedContextMode {
    None,
    SummarySharing,
    FullSharing,
}
/// Top-level configuration for a supervisor run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SupervisorConfig {
    pub resource_budget: ResourceBudget,
    pub decomposition: DecompositionStrategy,
    pub routing: RoutingStrategy,
    pub shared_context_mode: SharedContextMode,
    // Upper bound on simultaneously running agents.
    pub max_concurrent_agents: u32,
    // NOTE(review): presumably retries per subtask before giving up —
    // confirm against the supervisor loop that consumes this.
    pub max_retries: u32,
    // Consecutive same-type failures that trip the circuit breaker.
    pub circuit_breaker_threshold: u32,
}
impl Default for SupervisorConfig {
fn default() -> Self {
Self {
resource_budget: ResourceBudget::default(),
decomposition: DecompositionStrategy::SingleAgent,
routing: RoutingStrategy::DependencyAware,
shared_context_mode: SharedContextMode::SummarySharing,
max_concurrent_agents: 3,
max_retries: 2,
circuit_breaker_threshold: 3,
}
}
}
/// Events emitted by the supervisor as the run progresses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisorEvent {
    /// A new agent was started for a subtask.
    AgentSpawned {
        agent_id: AgentId,
        subtask_id: String,
    },
    /// An agent finished its subtask, producing a loop summary.
    AgentCompleted {
        agent_id: AgentId,
        subtask_id: String,
        summary: LoopSummary,
    },
    /// A subtask was handed from one agent to another.
    Rerouted {
        from_agent: AgentId,
        to_agent: AgentId,
        subtask_id: String,
        reason: RerouteReason,
    },
    /// A supervisor-level error; `recoverable` indicates whether the run
    /// can continue.
    SupervisorError {
        message: String,
        recoverable: bool,
    },
    /// The run finished; carries the final roll-up.
    SupervisorCompleted {
        summary: SupervisorSummary,
    },
}
/// Why a subtask was rerouted to a different agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RerouteReason {
    /// The agent made no progress after this many attempts.
    AgentStuck {
        attempts: u32,
    },
    /// The agent yielded, suggesting expertise better suited to the work.
    AgentYielded {
        suggested_expertise: Vec<String>,
    },
    /// The underlying engine errored this many times.
    EngineError {
        retries: u32,
    },
}
/// Terminal outcome of a supervisor run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SupervisorTermination {
    AllComplete,
    /// Some subtasks completed, some did not.
    PartialComplete {
        completed: u32,
        failed: u32,
    },
    /// Nothing completed.
    Failed {
        reason: String,
    },
    /// The shared budget ran out before the work finished.
    ResourceExhausted,
}
/// Final status of one subtask.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubtaskStatus {
    Completed,
    Failed {
        reason: String,
    },
    /// Some progress was made but the subtask did not finish.
    Partial {
        progress: String,
    },
    /// Never attempted (e.g. a dependency failed).
    Skipped {
        reason: String,
    },
}
/// Outcome record for a single subtask.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubtaskResult {
    pub subtask_id: String,
    pub status: SubtaskStatus,
    // Loop summary from the agent that ran the subtask, when one ran.
    pub summary: Option<LoopSummary>,
    // Agent that produced this result, when one was assigned.
    pub agent_id: Option<AgentId>,
}
/// Final report for an entire supervisor run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SupervisorSummary {
    pub termination: SupervisorTermination,
    pub subtask_results: Vec<SubtaskResult>,
    pub total_agents_spawned: u32,
    pub total_iterations: u32,
    pub total_tool_calls: u32,
    pub total_tokens: u32,
    pub wall_time: Duration,
}
/// A bundle of resource counts; used both for amounts consumed and for
/// amounts remaining (see `BudgetAllocator::remaining`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ResourceConsumption {
    pub iterations: u32,
    pub tool_calls: u32,
    pub tokens: u32,
}
/// Tracks consumption against a [`ResourceBudget`] and carves per-subtask
/// allocations out of whatever remains.
#[derive(Debug)]
pub struct BudgetAllocator {
    budget: ResourceBudget,
    consumed: ResourceConsumption,
}
impl BudgetAllocator {
    /// Creates an allocator with nothing consumed yet.
    pub fn new(budget: ResourceBudget) -> Self {
        Self {
            budget,
            consumed: ResourceConsumption::default(),
        }
    }
    /// Budget left after all recorded consumption, saturating at zero.
    pub fn remaining(&self) -> ResourceConsumption {
        ResourceConsumption {
            iterations: self
                .budget
                .total_iterations
                .saturating_sub(self.consumed.iterations),
            tool_calls: self
                .budget
                .total_tool_calls
                .saturating_sub(self.consumed.tool_calls),
            tokens: self
                .budget
                .total_tokens
                .saturating_sub(self.consumed.tokens),
        }
    }
    /// Adds the given consumption to the running totals (saturating).
    pub fn record_consumption(&mut self, consumption: &ResourceConsumption) {
        self.consumed.iterations = self
            .consumed
            .iterations
            .saturating_add(consumption.iterations);
        self.consumed.tool_calls = self
            .consumed
            .tool_calls
            .saturating_add(consumption.tool_calls);
        self.consumed.tokens = self.consumed.tokens.saturating_add(consumption.tokens);
    }
    /// Everything consumed so far.
    pub fn total_consumed(&self) -> &ResourceConsumption {
        &self.consumed
    }
    /// Allocates a slice of the remaining budget proportional to the
    /// subtask's complexity weight relative to `total_weight` (a
    /// non-positive `total_weight` grants the full remainder).
    ///
    /// Uses `floor` rather than `round`: with rounding, fractions that sum
    /// to 1 can each round up, so the allocations for a batch of subtasks
    /// could exceed the remaining budget. Flooring guarantees the sum stays
    /// within it. Each limit is still clamped to at least 1 so a spawned
    /// agent never receives a dead config (this clamp is the only way the
    /// total can exceed the remainder, when the budget is nearly exhausted).
    pub fn allocate(&self, subtask: &Subtask, total_weight: f64) -> LoopConfig {
        let remaining = self.remaining();
        let fraction = if total_weight > 0.0 {
            subtask.complexity.weight() / total_weight
        } else {
            1.0
        };
        // Shared helper for the three resource kinds; floor keeps the batch
        // total within budget, max(1.0) keeps every allocation usable.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let portion = |amount: u32| (f64::from(amount) * fraction).floor().max(1.0) as u32;
        LoopConfig {
            max_iterations: portion(remaining.iterations),
            max_tool_calls: portion(remaining.tool_calls),
            max_tokens: portion(remaining.tokens),
            ..LoopConfig::default()
        }
    }
}
/// Splits the remaining budget evenly across `running_count` agents.
///
/// Each resource is divided equally; the first `total % running_count`
/// agents receive one extra unit so nothing is left unallocated. Returns an
/// empty vector when no agents are running.
pub fn rebalance_budget(
    remaining_iterations: u32,
    remaining_calls: u32,
    remaining_tokens: u32,
    running_count: u32,
) -> Vec<LoopConfig> {
    if running_count == 0 {
        return Vec::new();
    }
    // Even share for `slot`, plus one unit while distributing the remainder.
    let share = |total: u32, slot: u32| total / running_count + u32::from(slot < total % running_count);
    (0..running_count)
        .map(|slot| LoopConfig {
            max_iterations: share(remaining_iterations, slot),
            max_tool_calls: share(remaining_calls, slot),
            max_tokens: share(remaining_tokens, slot),
            ..LoopConfig::default()
        })
        .collect()
}
/// Tracks dependency state for a set of subtasks: which are completed,
/// running, or failed, and which are ready to dispatch.
#[derive(Debug)]
pub struct DependencyResolver {
    subtasks: HashMap<String, Subtask>,
    completed: HashSet<String>,
    running: HashSet<String>,
    failed: HashSet<String>,
}
impl DependencyResolver {
    /// Builds a resolver from a subtask list, indexing subtasks by id.
    /// Duplicate ids silently keep the last occurrence.
    pub fn new(subtasks: Vec<Subtask>) -> Self {
        let map: HashMap<String, Subtask> =
            subtasks.into_iter().map(|s| (s.id.clone(), s)).collect();
        Self {
            subtasks: map,
            completed: HashSet::new(),
            running: HashSet::new(),
            failed: HashSet::new(),
        }
    }
    /// Ids of subtasks eligible to start now: not already running, completed,
    /// or failed, with every dependency completed.
    ///
    /// The list is sorted so scheduling order is deterministic; previously it
    /// followed `HashMap` iteration order, which varies between runs.
    pub fn ready(&self) -> Vec<String> {
        let mut eligible: Vec<String> = self
            .subtasks
            .values()
            .filter(|s| {
                !self.completed.contains(&s.id)
                    && !self.running.contains(&s.id)
                    && !self.failed.contains(&s.id)
                    && s.depends_on.iter().all(|dep| self.completed.contains(dep))
            })
            .map(|s| s.id.clone())
            .collect();
        eligible.sort_unstable();
        eligible
    }
    /// Marks a subtask as dispatched to an agent.
    pub fn mark_running(&mut self, id: &str) {
        self.running.insert(id.to_string());
    }
    /// Marks a subtask as successfully finished, unblocking dependents.
    pub fn mark_completed(&mut self, id: &str) {
        self.running.remove(id);
        self.completed.insert(id.to_string());
    }
    /// Marks a subtask as failed; its dependents can never become ready.
    pub fn mark_failed(&mut self, id: &str) {
        self.running.remove(id);
        self.failed.insert(id.to_string());
    }
    /// True once every subtask has either completed or failed.
    pub fn is_done(&self) -> bool {
        self.subtasks
            .keys()
            .all(|id| self.completed.contains(id) || self.failed.contains(id))
    }
    pub fn get(&self, id: &str) -> Option<&Subtask> {
        self.subtasks.get(id)
    }
    /// Sum of complexity weights for the given ids; unknown ids are ignored.
    pub fn total_weight(&self, ids: &[String]) -> f64 {
        ids.iter()
            .filter_map(|id| self.subtasks.get(id))
            .map(|s| s.complexity.weight())
            .sum()
    }
    pub fn completed_count(&self) -> usize {
        self.completed.len()
    }
    pub fn failed_count(&self) -> usize {
        self.failed.len()
    }
    pub fn total(&self) -> usize {
        self.subtasks.len()
    }
    /// Validates the dependency graph: every referenced dependency must
    /// exist, and the graph must be acyclic.
    pub fn validate(&self) -> Result<(), SupervisorError> {
        for subtask in self.subtasks.values() {
            for dep in &subtask.depends_on {
                if !self.subtasks.contains_key(dep) {
                    return Err(SupervisorError::InvalidDependency {
                        subtask: subtask.id.clone(),
                        missing_dep: dep.clone(),
                    });
                }
            }
        }
        let mut visited = HashSet::new();
        let mut in_stack = HashSet::new();
        for id in self.subtasks.keys() {
            if !visited.contains(id) && self.has_cycle(id, &mut visited, &mut in_stack) {
                return Err(SupervisorError::CyclicDependency {
                    subtask: id.clone(),
                });
            }
        }
        Ok(())
    }
    /// DFS cycle detection; `in_stack` holds the current recursion path.
    /// Recursion depth is bounded by the longest dependency chain.
    fn has_cycle(
        &self,
        id: &str,
        visited: &mut HashSet<String>,
        in_stack: &mut HashSet<String>,
    ) -> bool {
        visited.insert(id.to_string());
        in_stack.insert(id.to_string());
        if let Some(subtask) = self.subtasks.get(id) {
            for dep in &subtask.depends_on {
                if !visited.contains(dep) {
                    if self.has_cycle(dep, visited, in_stack) {
                        return true;
                    }
                } else if in_stack.contains(dep) {
                    return true;
                }
            }
        }
        in_stack.remove(id);
        false
    }
}
/// Errors surfaced by supervisor bookkeeping (graph validation and the
/// circuit breaker).
#[derive(Debug, Clone, thiserror::Error)]
pub enum SupervisorError {
    /// A subtask references a dependency id that is not in the plan.
    #[error("subtask '{subtask}' depends on '{missing_dep}' which does not exist")]
    InvalidDependency {
        subtask: String,
        missing_dep: String,
    },
    /// The dependency graph contains a cycle.
    #[error("cyclic dependency detected involving subtask '{subtask}'")]
    CyclicDependency {
        subtask: String,
    },
    /// Too many consecutive failures of the same type.
    #[error("circuit breaker triggered after {consecutive} consecutive {failure_type} failures")]
    CircuitBreakerTriggered {
        consecutive: u32,
        failure_type: String,
    },
}
/// Caps the number of simultaneously active agents; overflow subtask ids are
/// queued and handed back one at a time as slots free up.
#[derive(Debug)]
pub struct ConcurrencyLimiter {
    max_concurrent: u32,
    active: u32,
    max_observed: u32,
    queue: VecDeque<String>,
}
impl ConcurrencyLimiter {
    /// Creates a limiter allowing at most `max_concurrent` active slots.
    pub fn new(max_concurrent: u32) -> Self {
        Self {
            max_concurrent,
            active: 0,
            max_observed: 0,
            queue: VecDeque::new(),
        }
    }
    /// Claims a slot if one is free; returns `false` when at capacity.
    pub fn try_acquire(&mut self) -> bool {
        if self.active >= self.max_concurrent {
            return false;
        }
        self.active += 1;
        if self.active > self.max_observed {
            self.max_observed = self.active;
        }
        true
    }
    /// Frees a slot and pops the next queued subtask id, if any.
    /// NOTE(review): this does not re-acquire a slot for the popped id —
    /// presumably the caller does; confirm at the dispatch site.
    pub fn release(&mut self) -> Option<String> {
        self.active = self.active.saturating_sub(1);
        self.queue.pop_front()
    }
    /// Queues a subtask id to be dispatched when a slot frees up.
    pub fn enqueue(&mut self, subtask_id: String) {
        self.queue.push_back(subtask_id);
    }
    /// Currently occupied slots.
    pub fn active_count(&self) -> u32 {
        self.active
    }
    /// High-water mark of simultaneously occupied slots.
    pub fn max_observed(&self) -> u32 {
        self.max_observed
    }
    /// Subtasks waiting for a slot.
    pub fn queued_count(&self) -> usize {
        self.queue.len()
    }
}
/// Tracks spawned agents and whether each has been resolved (completed or
/// rerouted), so leaked "zombie" agents can be detected.
#[derive(Debug)]
pub struct LifecycleTracker {
    // agent id -> subtask id it was spawned for
    spawned: HashMap<AgentId, String>,
    resolved: HashSet<AgentId>,
    total_spawned: u32,
}
impl LifecycleTracker {
    pub fn new() -> Self {
        Self {
            spawned: HashMap::new(),
            resolved: HashSet::new(),
            total_spawned: 0,
        }
    }
    /// Records that `agent_id` was spawned to work on `subtask_id`.
    pub fn record_spawn(&mut self, agent_id: AgentId, subtask_id: String) {
        self.total_spawned += 1;
        self.spawned.insert(agent_id, subtask_id);
    }
    /// Marks an agent as finished.
    pub fn record_completion(&mut self, agent_id: &str) {
        self.resolved.insert(agent_id.to_string());
    }
    /// Marks the source agent of a reroute as resolved.
    pub fn record_reroute(&mut self, from_agent: &str) {
        self.resolved.insert(from_agent.to_string());
    }
    /// Agents that were spawned but never resolved.
    pub fn zombies(&self) -> Vec<AgentId> {
        let mut unresolved = Vec::new();
        for id in self.spawned.keys() {
            if !self.resolved.contains(id) {
                unresolved.push(id.clone());
            }
        }
        unresolved
    }
    /// True when every spawned agent has been resolved.
    pub fn all_resolved(&self) -> bool {
        self.zombies().is_empty()
    }
    /// Total spawn events recorded (counts respawns of the same agent id).
    pub fn total_spawned(&self) -> u32 {
        self.total_spawned
    }
}
impl Default for LifecycleTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Stateless helper that matches requested expertise to candidate subtasks.
pub struct RerouteResolver;
impl RerouteResolver {
    /// Picks the candidate subtask whose capabilities best overlap the
    /// requested expertise.
    ///
    /// A capability counts as a match when it contains any requested term as
    /// a substring. Candidates in `exclude_ids` or with zero matches are
    /// skipped; ties keep the earliest candidate in the slice. Returns
    /// `None` when nothing matches.
    pub fn find_match(
        requested_expertise: &[String],
        available_subtasks: &[Subtask],
        exclude_ids: &HashSet<String>,
    ) -> Option<String> {
        let mut best: Option<(String, usize)> = None;
        for candidate in available_subtasks {
            if exclude_ids.contains(&candidate.id) {
                continue;
            }
            let score = candidate
                .capabilities
                .iter()
                .filter(|cap| {
                    requested_expertise
                        .iter()
                        .any(|want| cap.contains(want.as_str()))
                })
                .count();
            if score == 0 {
                continue;
            }
            // Strict improvement only, so earlier candidates win ties.
            let improves = best.as_ref().map_or(true, |(_, top)| score > *top);
            if improves {
                best = Some((candidate.id.clone(), score));
            }
        }
        best.map(|(id, _)| id)
    }
}
/// Categories of agent failure tracked by the circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FailureType {
    EngineError,
    AgentStuck,
    AgentYielded,
    ToolError,
    Timeout,
}
impl std::fmt::Display for FailureType {
    /// Renders the snake_case label used in circuit-breaker error messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::EngineError => "engine_error",
            Self::AgentStuck => "agent_stuck",
            Self::AgentYielded => "agent_yielded",
            Self::ToolError => "tool_error",
            Self::Timeout => "timeout",
        };
        f.write_str(label)
    }
}
/// Opens after a configured number of consecutive failures of the same
/// [`FailureType`]; a success or a different failure type resets the streak.
#[derive(Debug)]
pub struct CircuitBreaker {
    threshold: u32,
    consecutive_count: u32,
    last_failure_type: Option<FailureType>,
    is_open: bool,
}
impl CircuitBreaker {
    /// Creates a closed breaker that opens at `threshold` consecutive
    /// same-type failures.
    pub fn new(threshold: u32) -> Self {
        Self {
            threshold,
            consecutive_count: 0,
            last_failure_type: None,
            is_open: false,
        }
    }
    /// Records one failure. A failure of a different type restarts the
    /// streak at 1. Once the streak reaches the threshold, the breaker opens
    /// and an error is returned.
    pub fn record_failure(&mut self, failure_type: FailureType) -> Result<(), SupervisorError> {
        match self.last_failure_type {
            Some(previous) if previous == failure_type => self.consecutive_count += 1,
            _ => {
                self.last_failure_type = Some(failure_type);
                self.consecutive_count = 1;
            }
        }
        if self.consecutive_count < self.threshold {
            return Ok(());
        }
        self.is_open = true;
        Err(SupervisorError::CircuitBreakerTriggered {
            consecutive: self.consecutive_count,
            failure_type: failure_type.to_string(),
        })
    }
    /// Clears the failure streak; does not close an already-open breaker.
    pub fn record_success(&mut self) {
        self.consecutive_count = 0;
        self.last_failure_type = None;
    }
    pub fn is_open(&self) -> bool {
        self.is_open
    }
    /// Closes the breaker and clears all failure state.
    pub fn reset(&mut self) {
        self.is_open = false;
        self.consecutive_count = 0;
        self.last_failure_type = None;
    }
    pub fn consecutive_failures(&self) -> u32 {
        self.consecutive_count
    }
}
#[derive(Debug)]
pub struct ResultAggregator {
results: HashMap<String, SubtaskResult>,
}
impl ResultAggregator {
pub fn new() -> Self {
Self {
results: HashMap::new(),
}
}
pub fn record_result(&mut self, result: SubtaskResult) {
self.results.insert(result.subtask_id.clone(), result);
}
pub fn record_skipped(&mut self, subtask_id: &str, reason: &str) {
self.results.insert(
subtask_id.to_string(),
SubtaskResult {
subtask_id: subtask_id.to_string(),
status: SubtaskStatus::Skipped {
reason: reason.to_string(),
},
summary: None,
agent_id: None,
},
);
}
pub fn get(&self, subtask_id: &str) -> Option<&SubtaskResult> {
self.results.get(subtask_id)
}
pub fn all_results(&self) -> Vec<SubtaskResult> {
self.results.values().cloned().collect()
}
pub fn completed_count(&self) -> usize {
self.results
.values()
.filter(|r| matches!(r.status, SubtaskStatus::Completed))
.count()
}
pub fn failed_count(&self) -> usize {
self.results
.values()
.filter(|r| matches!(r.status, SubtaskStatus::Failed { .. }))
.count()
}
pub fn skipped_count(&self) -> usize {
self.results
.values()
.filter(|r| matches!(r.status, SubtaskStatus::Skipped { .. }))
.count()
}
pub fn build_summary(
&self,
total_spawned: u32,
consumed: &ResourceConsumption,
wall_time: Duration,
total_subtasks: usize,
) -> SupervisorSummary {
let completed = self.completed_count();
let failed = self.failed_count();
#[allow(clippy::cast_possible_truncation)]
let termination = if completed == total_subtasks {
SupervisorTermination::AllComplete
} else if completed == 0 && failed > 0 {
SupervisorTermination::Failed {
reason: format!("all {failed} subtasks failed"),
}
} else {
SupervisorTermination::PartialComplete {
completed: completed as u32,
failed: failed as u32,
}
};
SupervisorSummary {
termination,
subtask_results: self.all_results(),
total_agents_spawned: total_spawned,
total_iterations: consumed.iterations,
total_tool_calls: consumed.tool_calls,
total_tokens: consumed.tokens,
wall_time,
}
}
}
impl Default for ResultAggregator {
fn default() -> Self {
Self::new()
}
}
/// Wellbeing signal reported for a single agent, from best to worst.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AgentWellbeingState {
    Healthy,
    Cautious,
    Concerned,
    Distressed,
}
/// Counts of agents in each wellbeing state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WellbeingAggregate {
    pub agents_total: usize,
    pub agents_healthy: usize,
    pub agents_cautious: usize,
    pub agents_concerned: usize,
    pub agents_distressed: usize,
}
/// Run-level reaction chosen from the aggregate wellbeing picture.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SupervisorWellbeingAction {
    Continue,
    PauseAndReplan,
    EscalateToClient,
}
/// Reaction to an individual agent's wellbeing state.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum WellbeingResponse {
    Continue,
    Pause,
    Reassign,
}
/// A wellbeing response targeted at a specific agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentWellbeingAction {
    pub agent_id: AgentId,
    pub response: WellbeingResponse,
}
/// Tallies per-agent wellbeing states into aggregate counts.
pub fn compute_aggregate_wellbeing(states: &[AgentWellbeingState]) -> WellbeingAggregate {
    let (mut healthy, mut cautious, mut concerned, mut distressed) = (0, 0, 0, 0);
    for state in states {
        match state {
            AgentWellbeingState::Healthy => healthy += 1,
            AgentWellbeingState::Cautious => cautious += 1,
            AgentWellbeingState::Concerned => concerned += 1,
            AgentWellbeingState::Distressed => distressed += 1,
        }
    }
    WellbeingAggregate {
        agents_total: states.len(),
        agents_healthy: healthy,
        agents_cautious: cautious,
        agents_concerned: concerned,
        agents_distressed: distressed,
    }
}
/// Chooses a run-level reaction from aggregate wellbeing: any distressed
/// agent escalates to the client; a strict majority of concerned agents
/// triggers a replan; otherwise the run continues.
pub fn supervisor_level_response(aggregate: &WellbeingAggregate) -> SupervisorWellbeingAction {
    if aggregate.agents_total == 0 {
        return SupervisorWellbeingAction::Continue;
    }
    if aggregate.agents_distressed > 0 {
        return SupervisorWellbeingAction::EscalateToClient;
    }
    // Integer form of `concerned / total > 0.5`, avoiding float division.
    if aggregate.agents_concerned * 2 > aggregate.agents_total {
        SupervisorWellbeingAction::PauseAndReplan
    } else {
        SupervisorWellbeingAction::Continue
    }
}
/// Maps each agent's wellbeing state to an action: concerned agents are
/// paused, distressed agents are reassigned, and healthy or cautious agents
/// need no action (and so are omitted from the result).
pub fn supervisor_wellbeing_response(
    agent_states: &[(AgentId, AgentWellbeingState)],
) -> Vec<AgentWellbeingAction> {
    let mut actions = Vec::new();
    for (agent_id, state) in agent_states {
        let response = match state {
            AgentWellbeingState::Healthy | AgentWellbeingState::Cautious => continue,
            AgentWellbeingState::Concerned => WellbeingResponse::Pause,
            AgentWellbeingState::Distressed => WellbeingResponse::Reassign,
        };
        actions.push(AgentWellbeingAction {
            agent_id: agent_id.clone(),
            response,
        });
    }
    actions
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
fn subtask(id: &str, deps: Vec<&str>, complexity: Complexity) -> Subtask {
Subtask {
id: id.to_string(),
objective: format!("Do {id}"),
depends_on: deps.into_iter().map(String::from).collect(),
capabilities: vec![],
complexity,
}
}
fn subtask_with_capabilities(id: &str, deps: Vec<&str>, capabilities: Vec<&str>) -> Subtask {
Subtask {
id: id.to_string(),
objective: format!("Do {id}"),
depends_on: deps.into_iter().map(String::from).collect(),
capabilities: capabilities.into_iter().map(String::from).collect(),
complexity: Complexity::Medium,
}
}
#[test]
fn test_allocate_budget_proportional_to_complexity() {
let budget = ResourceBudget {
total_iterations: 100,
total_tool_calls: 500,
total_tokens: 100_000,
};
let allocator = BudgetAllocator::new(budget);
let low = subtask("low", vec![], Complexity::Low);
let high = subtask("high", vec![], Complexity::High);
let total_weight = Complexity::Low.weight() + Complexity::High.weight();
let low_config = allocator.allocate(&low, total_weight);
let high_config = allocator.allocate(&high, total_weight);
assert!(high_config.max_iterations > low_config.max_iterations);
assert!(high_config.max_tool_calls > low_config.max_tool_calls);
assert!(high_config.max_tokens > low_config.max_tokens);
}
#[test]
fn test_total_allocation_within_budget() {
let budget = ResourceBudget {
total_iterations: 100,
total_tool_calls: 500,
total_tokens: 100_000,
};
let allocator = BudgetAllocator::new(budget.clone());
let subtasks = vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec![], Complexity::Medium),
subtask("C", vec![], Complexity::High),
];
let ids: Vec<String> = subtasks.iter().map(|s| s.id.clone()).collect();
let resolver = DependencyResolver::new(subtasks.clone());
let total_weight = resolver.total_weight(&ids);
let configs: Vec<LoopConfig> = subtasks
.iter()
.map(|s| allocator.allocate(s, total_weight))
.collect();
let total_iters: u32 = configs.iter().map(|c| c.max_iterations).sum();
let total_calls: u32 = configs.iter().map(|c| c.max_tool_calls).sum();
let total_tokens: u32 = configs.iter().map(|c| c.max_tokens).sum();
assert!(total_iters <= budget.total_iterations);
assert!(total_calls <= budget.total_tool_calls);
assert!(total_tokens <= budget.total_tokens);
}
#[test]
fn test_record_consumption_updates_remaining() {
let budget = ResourceBudget {
total_iterations: 100,
total_tool_calls: 500,
total_tokens: 50_000,
};
let mut allocator = BudgetAllocator::new(budget);
allocator.record_consumption(&ResourceConsumption {
iterations: 30,
tool_calls: 100,
tokens: 10_000,
});
let remaining = allocator.remaining();
assert_eq!(remaining.iterations, 70);
assert_eq!(remaining.tool_calls, 400);
assert_eq!(remaining.tokens, 40_000);
}
#[test]
fn test_rebalance_within_remaining() {
let configs = rebalance_budget(100, 500, 50_000, 4);
assert_eq!(configs.len(), 4);
let total_iters: u32 = configs.iter().map(|c| c.max_iterations).sum();
let total_calls: u32 = configs.iter().map(|c| c.max_tool_calls).sum();
let total_tokens: u32 = configs.iter().map(|c| c.max_tokens).sum();
assert!(total_iters <= 100);
assert!(total_calls <= 500);
assert!(total_tokens <= 50_000);
}
#[test]
fn test_rebalance_zero_agents_returns_empty() {
let configs = rebalance_budget(100, 500, 50_000, 0);
assert!(configs.is_empty());
}
#[test]
fn test_independent_subtasks_all_ready() {
let resolver = DependencyResolver::new(vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec![], Complexity::Low),
subtask("C", vec![], Complexity::Low),
]);
let ready = resolver.ready();
assert_eq!(ready.len(), 3);
}
#[test]
fn test_dependent_subtask_waits() {
let resolver = DependencyResolver::new(vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec!["A"], Complexity::Low),
subtask("C", vec!["A", "B"], Complexity::Low),
]);
let ready = resolver.ready();
assert_eq!(ready.len(), 1);
assert_eq!(ready[0], "A");
}
#[test]
fn test_completing_dep_unblocks_dependent() {
let mut resolver = DependencyResolver::new(vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec!["A"], Complexity::Low),
subtask("C", vec!["B"], Complexity::Low),
]);
resolver.mark_running("A");
assert!(resolver.ready().is_empty());
resolver.mark_completed("A");
let ready = resolver.ready();
assert_eq!(ready.len(), 1);
assert_eq!(ready[0], "B");
}
#[test]
fn test_chain_runs_sequentially() {
let mut resolver = DependencyResolver::new(vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec!["A"], Complexity::Low),
subtask("C", vec!["B"], Complexity::Low),
]);
assert_eq!(resolver.ready(), vec!["A"]);
resolver.mark_running("A");
resolver.mark_completed("A");
assert_eq!(resolver.ready(), vec!["B"]);
resolver.mark_running("B");
resolver.mark_completed("B");
assert_eq!(resolver.ready(), vec!["C"]);
resolver.mark_running("C");
resolver.mark_completed("C");
assert!(resolver.is_done());
}
#[test]
fn test_diamond_dependency() {
let mut resolver = DependencyResolver::new(vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec!["A"], Complexity::Low),
subtask("C", vec!["A"], Complexity::Low),
subtask("D", vec!["B", "C"], Complexity::Low),
]);
assert_eq!(resolver.ready().len(), 1);
resolver.mark_running("A");
resolver.mark_completed("A");
let ready = resolver.ready();
assert_eq!(ready.len(), 2);
assert!(ready.contains(&"B".to_string()));
assert!(ready.contains(&"C".to_string()));
resolver.mark_running("B");
resolver.mark_running("C");
resolver.mark_completed("B");
assert!(resolver.ready().is_empty());
resolver.mark_completed("C");
assert_eq!(resolver.ready(), vec!["D"]);
}
#[test]
fn test_validate_missing_dependency() {
let resolver = DependencyResolver::new(vec![subtask("A", vec!["Z"], Complexity::Low)]);
let result = resolver.validate();
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
SupervisorError::InvalidDependency { .. }
));
}
#[test]
fn test_validate_cyclic_dependency() {
let resolver = DependencyResolver::new(vec![
subtask("A", vec!["B"], Complexity::Low),
subtask("B", vec!["A"], Complexity::Low),
]);
let result = resolver.validate();
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
SupervisorError::CyclicDependency { .. }
));
}
#[test]
fn test_validate_valid_dag() {
let resolver = DependencyResolver::new(vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec!["A"], Complexity::Low),
subtask("C", vec!["A"], Complexity::Low),
subtask("D", vec!["B", "C"], Complexity::Low),
]);
assert!(resolver.validate().is_ok());
}
#[test]
fn test_failed_subtask_not_ready() {
let mut resolver = DependencyResolver::new(vec![
subtask("A", vec![], Complexity::Low),
subtask("B", vec![], Complexity::Low),
]);
resolver.mark_failed("A");
let ready = resolver.ready();
assert_eq!(ready.len(), 1);
assert_eq!(ready[0], "B");
}
#[test]
fn test_concurrency_limit_enforced() {
let mut limiter = ConcurrencyLimiter::new(2);
assert!(limiter.try_acquire());
assert!(limiter.try_acquire());
assert!(!limiter.try_acquire()); assert_eq!(limiter.active_count(), 2);
assert_eq!(limiter.max_observed(), 2);
}
#[test]
fn test_release_opens_slot() {
let mut limiter = ConcurrencyLimiter::new(1);
assert!(limiter.try_acquire());
assert!(!limiter.try_acquire());
limiter.release();
assert!(limiter.try_acquire());
}
#[test]
fn test_queued_dispatched_on_release() {
let mut limiter = ConcurrencyLimiter::new(1);
assert!(limiter.try_acquire());
limiter.enqueue("task_B".to_string());
limiter.enqueue("task_C".to_string());
assert_eq!(limiter.queued_count(), 2);
let next = limiter.release();
assert_eq!(next, Some("task_B".to_string()));
assert_eq!(limiter.queued_count(), 1);
}
#[test]
fn test_no_zombies_when_all_resolved() {
let mut tracker = LifecycleTracker::new();
tracker.record_spawn("agent_1".to_string(), "task_A".to_string());
tracker.record_spawn("agent_2".to_string(), "task_B".to_string());
assert_eq!(tracker.zombies().len(), 2);
tracker.record_completion("agent_1");
tracker.record_completion("agent_2");
assert!(tracker.all_resolved());
assert!(tracker.zombies().is_empty());
}
#[test]
fn test_reroute_resolves_original_agent() {
let mut tracker = LifecycleTracker::new();
tracker.record_spawn("agent_1".to_string(), "task_A".to_string());
tracker.record_reroute("agent_1");
tracker.record_spawn("agent_2".to_string(), "task_A".to_string());
tracker.record_completion("agent_2");
assert!(tracker.all_resolved());
assert_eq!(tracker.total_spawned(), 2);
}
#[test]
fn test_zombie_detected() {
let mut tracker = LifecycleTracker::new();
tracker.record_spawn("agent_1".to_string(), "task_A".to_string());
tracker.record_spawn("agent_2".to_string(), "task_B".to_string());
tracker.record_completion("agent_1");
let zombies = tracker.zombies();
assert_eq!(zombies.len(), 1);
assert_eq!(zombies[0], "agent_2");
}
#[test]
fn test_reroute_matches_expertise() {
let subtasks = vec![
subtask_with_capabilities("research", vec![], vec!["general"]),
subtask_with_capabilities("implement", vec![], vec!["rust", "database"]),
];
let exclude = HashSet::from(["research".to_string()]);
let result = RerouteResolver::find_match(&["database".to_string()], &subtasks, &exclude);
assert_eq!(result, Some("implement".to_string()));
}
#[test]
fn test_reroute_no_match_returns_none() {
let subtasks = vec![subtask_with_capabilities(
"research",
vec![],
vec!["general"],
)];
let exclude = HashSet::from(["research".to_string()]);
let result = RerouteResolver::find_match(&["database".to_string()], &subtasks, &exclude);
assert!(result.is_none());
}
#[test]
fn test_reroute_selects_best_match() {
let subtasks = vec![
subtask_with_capabilities("a", vec![], vec!["python"]),
subtask_with_capabilities("b", vec![], vec!["rust"]),
subtask_with_capabilities("c", vec![], vec!["rust", "database", "api"]),
];
let exclude = HashSet::new();
let result = RerouteResolver::find_match(
&["rust".to_string(), "database".to_string()],
&subtasks,
&exclude,
);
assert_eq!(result, Some("c".to_string()));
}
#[test]
fn test_circuit_breaker_triggers_at_threshold() {
let mut cb = CircuitBreaker::new(3);
assert!(cb.record_failure(FailureType::EngineError).is_ok());
assert!(cb.record_failure(FailureType::EngineError).is_ok());
assert!(cb.record_failure(FailureType::EngineError).is_err()); assert!(cb.is_open());
}
#[test]
fn test_circuit_breaker_resets_on_different_type() {
let mut cb = CircuitBreaker::new(3);
assert!(cb.record_failure(FailureType::EngineError).is_ok());
assert!(cb.record_failure(FailureType::EngineError).is_ok());
assert!(cb.record_failure(FailureType::AgentStuck).is_ok());
assert!(!cb.is_open());
assert_eq!(cb.consecutive_failures(), 1);
}
#[test]
fn test_circuit_breaker_success_resets_count() {
let mut cb = CircuitBreaker::new(3);
assert!(cb.record_failure(FailureType::EngineError).is_ok());
assert!(cb.record_failure(FailureType::EngineError).is_ok());
cb.record_success();
assert_eq!(cb.consecutive_failures(), 0);
assert!(cb.record_failure(FailureType::EngineError).is_ok());
assert!(cb.record_failure(FailureType::EngineError).is_ok());
assert!(cb.record_failure(FailureType::EngineError).is_err());
}
#[test]
fn test_circuit_breaker_reset() {
let mut cb = CircuitBreaker::new(2);
assert!(cb.record_failure(FailureType::Timeout).is_ok());
assert!(cb.record_failure(FailureType::Timeout).is_err());
assert!(cb.is_open());
cb.reset();
assert!(!cb.is_open());
assert_eq!(cb.consecutive_failures(), 0);
}
#[test]
fn test_aggregation_collects_all_results() {
let mut agg = ResultAggregator::new();
agg.record_result(SubtaskResult {
subtask_id: "A".to_string(),
status: SubtaskStatus::Completed,
summary: None,
agent_id: Some("agent_1".to_string()),
});
agg.record_result(SubtaskResult {
subtask_id: "B".to_string(),
status: SubtaskStatus::Completed,
summary: None,
agent_id: Some("agent_2".to_string()),
});
agg.record_result(SubtaskResult {
subtask_id: "C".to_string(),
status: SubtaskStatus::Failed {
reason: "timeout".to_string(),
},
summary: None,
agent_id: Some("agent_3".to_string()),
});
assert_eq!(agg.completed_count(), 2);
assert_eq!(agg.failed_count(), 1);
assert_eq!(agg.all_results().len(), 3);
}
#[test]
fn test_completed_results_preserved_in_summary() {
let mut agg = ResultAggregator::new();
agg.record_result(SubtaskResult {
subtask_id: "A".to_string(),
status: SubtaskStatus::Completed,
summary: None,
agent_id: Some("agent_1".to_string()),
});
agg.record_result(SubtaskResult {
subtask_id: "B".to_string(),
status: SubtaskStatus::Failed {
reason: "error".to_string(),
},
summary: None,
agent_id: Some("agent_2".to_string()),
});
let summary = agg.build_summary(
2,
&ResourceConsumption {
iterations: 10,
tool_calls: 30,
tokens: 5000,
},
Duration::from_secs(60),
2,
);
assert!(matches!(
summary.termination,
SupervisorTermination::PartialComplete {
completed: 1,
failed: 1
}
));
assert_eq!(summary.subtask_results.len(), 2);
let a_result = summary.subtask_results.iter().find(|r| r.subtask_id == "A");
assert!(a_result.is_some());
assert!(matches!(
a_result.map(|r| &r.status),
Some(SubtaskStatus::Completed)
));
}
#[test]
fn test_all_complete_termination() {
let mut agg = ResultAggregator::new();
for id in &["A", "B", "C"] {
agg.record_result(SubtaskResult {
subtask_id: id.to_string(),
status: SubtaskStatus::Completed,
summary: None,
agent_id: None,
});
}
let summary = agg.build_summary(
3,
&ResourceConsumption::default(),
Duration::from_secs(30),
3,
);
assert!(matches!(
summary.termination,
SupervisorTermination::AllComplete
));
}
#[test]
fn test_all_failed_termination() {
let mut agg = ResultAggregator::new();
agg.record_result(SubtaskResult {
subtask_id: "A".to_string(),
status: SubtaskStatus::Failed {
reason: "err".to_string(),
},
summary: None,
agent_id: None,
});
let summary = agg.build_summary(
1,
&ResourceConsumption::default(),
Duration::from_secs(10),
1,
);
assert!(matches!(
summary.termination,
SupervisorTermination::Failed { .. }
));
}
#[test]
fn test_skipped_subtask_recorded() {
let mut agg = ResultAggregator::new();
agg.record_skipped("C", "dependency A failed");
let result = agg.get("C");
assert!(result.is_some());
assert!(matches!(
result.map(|r| &r.status),
Some(SubtaskStatus::Skipped { .. })
));
}
#[test]
fn test_majority_failure_partial_result_preserved() {
    let mut agg = ResultAggregator::new();
    // One success ("D", handled by agent_4) amid three failures.
    agg.record_result(SubtaskResult {
        subtask_id: "D".to_string(),
        status: SubtaskStatus::Completed,
        summary: None,
        agent_id: Some("agent_4".to_string()),
    });
    for name in ["A", "B", "C"] {
        agg.record_result(SubtaskResult {
            subtask_id: name.to_string(),
            status: SubtaskStatus::Failed {
                reason: "fail".to_string(),
            },
            summary: None,
            agent_id: None,
        });
    }
    let summary =
        agg.build_summary(4, &ResourceConsumption::default(), Duration::from_secs(30), 4);
    // Termination reflects the exact completed/failed split.
    assert!(matches!(
        summary.termination,
        SupervisorTermination::PartialComplete {
            completed: 1,
            failed: 3
        }
    ));
    // The lone successful result must survive aggregation despite majority failure.
    let preserved = summary
        .subtask_results
        .iter()
        .any(|r| r.subtask_id == "D" && matches!(r.status, SubtaskStatus::Completed));
    assert!(preserved);
}
#[test]
fn test_wellbeing_counts_sum_to_total() {
    // One agent in each of the four wellbeing states.
    let states = [
        AgentWellbeingState::Healthy,
        AgentWellbeingState::Cautious,
        AgentWellbeingState::Concerned,
        AgentWellbeingState::Distressed,
    ];
    let agg = compute_aggregate_wellbeing(&states);
    assert_eq!(agg.agents_total, 4);
    // Per-state buckets must partition the total — no agent double-counted or dropped.
    let bucket_sum =
        agg.agents_healthy + agg.agents_cautious + agg.agents_concerned + agg.agents_distressed;
    assert_eq!(bucket_sum, agg.agents_total);
}
#[test]
fn test_distressed_child_paused_not_punished() {
    let states = vec![
        ("agent_1".to_string(), AgentWellbeingState::Distressed),
        ("agent_2".to_string(), AgentWellbeingState::Healthy),
        ("agent_3".to_string(), AgentWellbeingState::Healthy),
    ];
    let actions = supervisor_wellbeing_response(&states);
    // The distressed agent receives a supportive Reassign action...
    match actions.iter().find(|a| a.agent_id == "agent_1") {
        Some(action) => assert!(matches!(action.response, WellbeingResponse::Reassign)),
        None => panic!("distressed agent should receive an action"),
    }
    // ...while healthy agents are left entirely alone.
    for healthy in ["agent_2", "agent_3"] {
        assert!(actions.iter().all(|a| a.agent_id != healthy));
    }
}
#[test]
fn test_majority_concerned_triggers_replan() {
    // Three of four agents concerned — a majority, but none distressed.
    let mut states = vec![AgentWellbeingState::Concerned; 3];
    states.push(AgentWellbeingState::Healthy);
    let agg = compute_aggregate_wellbeing(&states);
    // A concerned majority should make the supervisor pause and replan.
    assert_eq!(
        supervisor_level_response(&agg),
        SupervisorWellbeingAction::PauseAndReplan
    );
}
#[test]
fn test_all_concerned_with_distressed_escalates() {
    // No healthy agents at all, and one distressed: the situation is beyond
    // local remediation.
    let states = [
        AgentWellbeingState::Concerned,
        AgentWellbeingState::Distressed,
        AgentWellbeingState::Concerned,
    ];
    let agg = compute_aggregate_wellbeing(&states);
    assert_eq!(
        supervisor_level_response(&agg),
        SupervisorWellbeingAction::EscalateToClient
    );
}
#[test]
fn test_all_healthy_continues() {
    // A healthy/cautious mix is business as usual: no intervention needed.
    let states = [
        AgentWellbeingState::Healthy,
        AgentWellbeingState::Healthy,
        AgentWellbeingState::Cautious,
    ];
    let action = supervisor_level_response(&compute_aggregate_wellbeing(&states));
    assert_eq!(action, SupervisorWellbeingAction::Continue);
}
#[test]
fn test_empty_agents_continues() {
    // With no child agents there is nothing to react to: keep going.
    let action = supervisor_level_response(&compute_aggregate_wellbeing(&[]));
    assert_eq!(action, SupervisorWellbeingAction::Continue);
}
#[test]
fn test_concerned_agent_paused() {
    let states = vec![
        ("agent_1".to_string(), AgentWellbeingState::Concerned),
        ("agent_2".to_string(), AgentWellbeingState::Healthy),
    ];
    let actions = supervisor_wellbeing_response(&states);
    // Exactly one action: pause the concerned agent, ignore the healthy one.
    assert_eq!(actions.len(), 1);
    let action = &actions[0];
    assert_eq!(action.agent_id, "agent_1");
    assert_eq!(action.response, WellbeingResponse::Pause);
}
mod proptest_supervisor {
    use super::*;

    /// Strategy: uniformly pick one of the three complexity tiers.
    fn arb_complexity() -> impl Strategy<Value = Complexity> {
        prop_oneof![
            Just(Complexity::Low),
            Just(Complexity::Medium),
            Just(Complexity::High),
        ]
    }

    /// Strategy: a dependency-free subtask with the given id and an
    /// arbitrary complexity.
    fn arb_subtask(id: String) -> impl Strategy<Value = Subtask> {
        arb_complexity().prop_map(move |complexity| Subtask {
            id: id.clone(),
            objective: format!("task {}", id),
            depends_on: vec![],
            capabilities: vec![],
            complexity,
        })
    }

    /// Strategy: a budget with each dimension drawn from a non-degenerate
    /// range (lower bounds keep per-subtask shares above zero).
    fn arb_resource_budget() -> impl Strategy<Value = ResourceBudget> {
        (10u32..1000, 10u32..5000, 1000u32..200_000).prop_map(|(iters, calls, tokens)| {
            ResourceBudget {
                total_iterations: iters,
                total_tool_calls: calls,
                total_tokens: tokens,
            }
        })
    }

    /// Strategy: uniformly pick one of the failure classifications.
    fn arb_failure_type() -> impl Strategy<Value = FailureType> {
        prop_oneof![
            Just(FailureType::EngineError),
            Just(FailureType::AgentStuck),
            Just(FailureType::AgentYielded),
            Just(FailureType::ToolError),
            Just(FailureType::Timeout),
        ]
    }

    /// Strategy: uniformly pick one of the four wellbeing states.
    fn arb_wellbeing_state() -> impl Strategy<Value = AgentWellbeingState> {
        prop_oneof![
            Just(AgentWellbeingState::Healthy),
            Just(AgentWellbeingState::Cautious),
            Just(AgentWellbeingState::Concerned),
            Just(AgentWellbeingState::Distressed),
        ]
    }

    proptest! {
        /// Summed per-subtask allocations never exceed the budget by more
        /// than one unit per subtask (integer-division rounding slack).
        #[test]
        fn prop_total_allocation_within_budget(
            budget in arb_resource_budget(),
            count in 1u32..10,
        ) {
            let allocator = BudgetAllocator::new(budget.clone());
            // Subtasks cycle through the complexity tiers so the weighted
            // split is exercised with mixed weights.
            let subtasks: Vec<Subtask> = (0..count)
                .map(|i| Subtask {
                    id: format!("task_{i}"),
                    objective: format!("do {i}"),
                    depends_on: vec![],
                    capabilities: vec![],
                    complexity: [Complexity::Low, Complexity::Medium, Complexity::High]
                        [i as usize % 3],
                })
                .collect();
            let ids: Vec<String> = subtasks.iter().map(|s| s.id.clone()).collect();
            let resolver = DependencyResolver::new(subtasks.clone());
            let total_weight = resolver.total_weight(&ids);
            let configs: Vec<LoopConfig> = subtasks
                .iter()
                .map(|s| allocator.allocate(s, total_weight))
                .collect();
            let total_iters: u32 = configs.iter().map(|c| c.max_iterations).sum();
            let total_calls: u32 = configs.iter().map(|c| c.max_tool_calls).sum();
            let total_tokens: u32 = configs.iter().map(|c| c.max_tokens).sum();
            // The `+ count` term tolerates at most one unit of rounding per
            // subtask; the messages state the full bound so a failure report
            // matches the condition actually being checked.
            prop_assert!(total_iters <= budget.total_iterations + count,
                "iterations {total_iters} > budget {} + slack {count}", budget.total_iterations);
            prop_assert!(total_calls <= budget.total_tool_calls + count,
                "calls {total_calls} > budget {} + slack {count}", budget.total_tool_calls);
            prop_assert!(total_tokens <= budget.total_tokens + count,
                "tokens {total_tokens} > budget {} + slack {count}", budget.total_tokens);
        }

        /// Rebalanced allocations for running agents stay strictly within
        /// the remaining (not original) budget.
        #[test]
        fn prop_rebalance_within_remaining(
            remaining_iters in 1u32..1000,
            remaining_calls in 1u32..5000,
            remaining_tokens in 1u32..200_000,
            running_count in 1u32..5,
        ) {
            let configs = rebalance_budget(
                remaining_iters,
                remaining_calls,
                remaining_tokens,
                running_count,
            );
            let total_iters: u32 = configs.iter().map(|c| c.max_iterations).sum();
            let total_calls: u32 = configs.iter().map(|c| c.max_tool_calls).sum();
            let total_tokens: u32 = configs.iter().map(|c| c.max_tokens).sum();
            prop_assert!(total_iters <= remaining_iters);
            prop_assert!(total_calls <= remaining_calls);
            prop_assert!(total_tokens <= remaining_tokens);
        }

        /// The observed high-water mark of concurrent acquisitions never
        /// exceeds the configured limit, across interleaved acquire/release.
        #[test]
        fn prop_concurrency_limit_respected(
            max_concurrent in 1u32..5,
            events in 1u32..20,
        ) {
            let mut limiter = ConcurrencyLimiter::new(max_concurrent);
            for _ in 0..events {
                if limiter.try_acquire() {
                    // Release whenever more than one permit is held, so the
                    // loop interleaves acquires and releases.
                    if limiter.active_count() > 1 {
                        limiter.release();
                    }
                }
            }
            prop_assert!(limiter.max_observed() <= max_concurrent);
        }

        /// The circuit breaker trips before the failure count reaches the
        /// threshold, for every failure classification.
        #[test]
        fn prop_circuit_breaker_bounded(
            threshold in 2u32..5,
            failure_type in arb_failure_type(),
        ) {
            let mut cb = CircuitBreaker::new(threshold);
            let mut count = 0u32;
            // Feed identical failures until the breaker opens.
            loop {
                match cb.record_failure(failure_type.clone()) {
                    Ok(()) => count += 1,
                    Err(_) => break,
                }
            }
            prop_assert!(count < threshold,
                "circuit breaker allowed {} failures before tripping (threshold={})",
                count, threshold);
        }

        /// Aggregated wellbeing buckets always partition the agent total,
        /// for any mix of states (including the empty set).
        #[test]
        fn prop_wellbeing_counts_sum(
            states in prop::collection::vec(arb_wellbeing_state(), 0..20),
        ) {
            let agg = compute_aggregate_wellbeing(&states);
            let sum = agg.agents_healthy
                + agg.agents_cautious
                + agg.agents_concerned
                + agg.agents_distressed;
            prop_assert_eq!(sum, agg.agents_total);
            prop_assert_eq!(agg.agents_total, states.len());
        }

        /// A distressed agent always receives a supportive action (pause or
        /// reassignment), regardless of fleet size.
        #[test]
        fn prop_distressed_child_not_punished(
            agent_count in 1usize..10,
        ) {
            let mut states: Vec<(AgentId, AgentWellbeingState)> = (0..agent_count)
                .map(|i| (format!("agent_{i}"), AgentWellbeingState::Healthy))
                .collect();
            // Mark exactly one agent distressed.
            states[0].1 = AgentWellbeingState::Distressed;
            let actions = supervisor_wellbeing_response(&states);
            let distressed_action = actions.iter().find(|a| a.agent_id == states[0].0);
            prop_assert!(distressed_action.is_some(),
                "distressed agent should have an action");
            let action = distressed_action.expect("checked above");
            prop_assert!(
                matches!(action.response, WellbeingResponse::Pause | WellbeingResponse::Reassign),
                "distressed agent should be paused or reassigned, got {:?}",
                action.response
            );
        }
    }
}
}