use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use swarm_engine_core::agent::{
ActionCandidate, BatchDecisionRequest, BatchInvokeResult, BatchInvoker, ContextTarget,
DecisionResponse, GlobalContext, Guidance, ManagementDecision, ManagementStrategy,
ManagerAgent, ManagerId, ResolvedContext, TaskContext, WorkResult, WorkerAgent,
WorkerDecisionRequest,
};
use swarm_engine_core::extensions::Extensions;
use swarm_engine_core::orchestrator::{OrchestratorBuilder, SwarmConfig};
use swarm_engine_core::state::SwarmState;
use swarm_engine_core::types::{ActionResult, SwarmTask, WorkerId};
/// Minimal worker that records every `think_and_act` call in a shared counter
/// and always reports a successful, empty action.
struct TestWorker {
    id: WorkerId,
    name: String,
    /// Counter bumped once per `think_and_act`; tests may hand the same `Arc`
    /// to several workers to observe their combined action total.
    action_count: Arc<AtomicU64>,
}

impl TestWorker {
    /// Creates worker `id`, named `TestWorker_<id>`, reporting into `action_count`.
    fn new(id: usize, action_count: Arc<AtomicU64>) -> Self {
        let name = format!("TestWorker_{}", id);
        Self {
            id: WorkerId(id),
            name,
            action_count,
        }
    }
}

impl WorkerAgent for TestWorker {
    fn id(&self) -> WorkerId {
        self.id
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Ignores state and guidance, counts the call, and returns a 100µs success.
    fn think_and_act(&self, _state: &SwarmState, _guidance: Option<&Guidance>) -> WorkResult {
        self.action_count.fetch_add(1, Ordering::SeqCst);
        let outcome = ActionResult::success("", Duration::from_micros(100));
        WorkResult::acted(outcome)
    }
}
/// Manager stub that counts how often it is asked to `prepare` a batch and
/// never issues any worker decision requests.
struct TestManager {
    name: String,
    /// Incremented on every `prepare` call.
    observe_count: Arc<AtomicU64>,
}

impl TestManager {
    /// Creates the stub named `TestManager`, reporting into `observe_count`.
    fn new(observe_count: Arc<AtomicU64>) -> Self {
        Self {
            observe_count,
            name: String::from("TestManager"),
        }
    }
}

impl ManagerAgent for TestManager {
    /// Counts the call and returns an empty request batch tagged with manager 0.
    fn prepare(&self, _context: &TaskContext) -> BatchDecisionRequest {
        self.observe_count.fetch_add(1, Ordering::SeqCst);
        let requests = Vec::new();
        BatchDecisionRequest {
            manager_id: self.id(),
            requests,
        }
    }

    /// Discards all responses and returns the default (empty) decision.
    fn finalize(
        &self,
        _context: &TaskContext,
        _responses: Vec<(WorkerId, DecisionResponse)>,
    ) -> ManagementDecision {
        Default::default()
    }

    fn id(&self) -> ManagerId {
        ManagerId(0)
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Deterministic `BatchInvoker` stand-in: counts invocations and answers every
/// worker request in the batch with the same canned `DecisionResponse`.
struct MockBatchInvoker {
    /// Incremented once per `invoke` call (not per request in the batch).
    invoke_count: Arc<AtomicU64>,
    /// Tool name placed in every response; defaults to `"MockAction"`.
    response_tool: String,
}

impl MockBatchInvoker {
    /// Creates an invoker that reports into `invoke_count` and answers with `"MockAction"`.
    fn new(invoke_count: Arc<AtomicU64>) -> Self {
        Self {
            invoke_count,
            response_tool: "MockAction".to_string(),
        }
    }

    /// Overrides the tool name returned in every response.
    // Builder hook kept for tests that need a non-default tool name. None of the
    // tests visible here call it, so the dead-code lint is silenced on purpose
    // rather than deleting the helper.
    #[allow(dead_code)]
    fn with_response_tool(mut self, tool: impl Into<String>) -> Self {
        self.response_tool = tool.into();
        self
    }
}

impl BatchInvoker for MockBatchInvoker {
    /// Counts the call, then maps each request to an `Ok` response whose target
    /// is `target_<worker index>`, with fixed reasoning and 0.9 confidence.
    fn invoke(&self, request: BatchDecisionRequest, _extensions: &Extensions) -> BatchInvokeResult {
        self.invoke_count.fetch_add(1, Ordering::SeqCst);
        request
            .requests
            .iter()
            .map(|req| {
                let response = DecisionResponse {
                    tool: self.response_tool.clone(),
                    target: format!("target_{}", req.worker_id.0),
                    args: std::collections::HashMap::new(),
                    reasoning: Some("Mock reasoning".to_string()),
                    confidence: 0.9,
                    prompt: None,
                    raw_response: None,
                };
                (req.worker_id, Ok(response))
            })
            .collect()
    }

    fn name(&self) -> &str {
        "MockBatchInvoker"
    }
}
/// Manager that asks the first `worker_count` workers (in the order returned by
/// `TaskContext::worker_ids`) to choose between two fixed candidate actions,
/// then turns each response into a textual hint for that worker.
struct TestManagerWithRequests {
    name: String,
    /// Upper bound on how many workers receive a decision request per round.
    worker_count: usize,
}

impl TestManagerWithRequests {
    fn new(worker_count: usize) -> Self {
        Self {
            worker_count,
            name: String::from("TestManagerWithRequests"),
        }
    }
}

impl ManagerAgent for TestManagerWithRequests {
    /// Builds one decision request per selected worker. Each request offers the
    /// candidates `Action1` and `Action2` in a context targeted at that worker.
    fn prepare(&self, context: &TaskContext) -> BatchDecisionRequest {
        let mut requests: Vec<WorkerDecisionRequest> = Vec::new();
        for worker_id in context.worker_ids().into_iter().take(self.worker_count) {
            let global = GlobalContext::new(context.tick);
            let first = ActionCandidate {
                name: "Action1".to_string(),
                description: "Test action 1".to_string(),
                params: vec![],
                example: None,
            };
            let second = ActionCandidate {
                name: "Action2".to_string(),
                description: "Test action 2".to_string(),
                params: vec![],
                example: None,
            };
            let resolved = ResolvedContext::new(global, ContextTarget::Worker(worker_id))
                .with_candidates(vec![first, second]);
            requests.push(WorkerDecisionRequest {
                worker_id,
                query: format!("What should worker {} do?", worker_id.0),
                context: resolved,
                lora: None,
            });
        }
        BatchDecisionRequest {
            manager_id: ManagerId(0),
            requests,
        }
    }

    /// Converts every `(worker, response)` pair into a `Do <tool> on <target>`
    /// hint. If a worker appears twice, its later response wins.
    fn finalize(
        &self,
        _context: &TaskContext,
        responses: Vec<(WorkerId, DecisionResponse)>,
    ) -> ManagementDecision {
        let guidances: std::collections::HashMap<_, _> = responses
            .into_iter()
            .map(|(worker_id, response)| {
                let hint = format!("Do {} on {}", response.tool, response.target);
                (worker_id, Guidance::hint(hint))
            })
            .collect();
        ManagementDecision {
            guidances,
            strategy_update: None,
            async_tasks: vec![],
        }
    }

    fn id(&self) -> ManagerId {
        ManagerId(0)
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Three counting workers run for five ticks. The test expects every worker to
/// act once per tick, giving 3 × 5 = 15 recorded actions.
#[test]
fn test_orchestrator_basic_run() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let action_count = Arc::new(AtomicU64::new(0));
    let worker = |index| TestWorker::new(index, Arc::clone(&action_count));

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(worker(0))
        .add_worker(worker(1))
        .add_worker(worker(2))
        .max_ticks(5)
        .tick_duration(Duration::from_millis(1))
        .build(runtime.handle().clone());

    let outcome = orchestrator.run();

    assert!(outcome.completed);
    assert_eq!(outcome.total_ticks, 5);
    assert_eq!(action_count.load(Ordering::SeqCst), 15);
}
/// One worker plus a counting manager, with a fixed management interval of 5,
/// over a 20-tick run. The manager must be asked to `prepare` at least 3 times.
#[test]
fn test_orchestrator_with_manager() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let action_count = Arc::new(AtomicU64::new(0));
    let observe_count = Arc::new(AtomicU64::new(0));
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 20,
        management_strategy: ManagementStrategy::FixedInterval { interval: 5 },
    };

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(TestWorker::new(0, Arc::clone(&action_count)))
        .manager(TestManager::new(Arc::clone(&observe_count)))
        .config(config)
        .build(runtime.handle().clone());
    let outcome = orchestrator.run();

    assert!(outcome.completed);
    assert_eq!(outcome.total_ticks, 20);
    let prepared = observe_count.load(Ordering::SeqCst);
    assert!(prepared >= 3);
}
/// After a three-tick single-worker run, the orchestrator's state must show the
/// final tick, a one-worker roster, and at least three recorded visits.
#[test]
fn test_orchestrator_state_access() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let handle = runtime.handle().clone();
    let action_count = Arc::new(AtomicU64::new(0));
    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(TestWorker::new(0, Arc::clone(&action_count)))
        .max_ticks(3)
        .tick_duration(Duration::from_millis(1))
        .build(handle);
    orchestrator.run();

    let snapshot = orchestrator.state();
    let shared = &snapshot.shared;
    assert_eq!(shared.tick, 3);
    assert_eq!(snapshot.workers.len(), 1);
    assert!(shared.stats.total_visits() >= 3);
}
/// Two always-successful workers run for ten ticks. The shared stats must record
/// one visit per worker per tick (20), all successes and no failures.
#[test]
fn test_orchestrator_metrics() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let action_count = Arc::new(AtomicU64::new(0));
    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(TestWorker::new(0, Arc::clone(&action_count)))
        .add_worker(TestWorker::new(1, Arc::clone(&action_count)))
        .max_ticks(10)
        .tick_duration(Duration::from_millis(1))
        .build(runtime.handle().clone());
    let result = orchestrator.run();
    let state = orchestrator.state();
    let stats = &state.shared.stats;

    assert_eq!(result.total_ticks, 10);
    // 2 workers x 10 ticks; TestWorker only ever returns ActionResult::success.
    assert_eq!(stats.total_visits(), 20);
    assert_eq!(stats.total_successes(), 20);
    assert_eq!(stats.total_failures(), 0);
}
/// Runs a single worker for 100 ticks and checks that the run completes all of them.
// NOTE(review): despite its name, this test never requests termination. It only
// verifies an uninterrupted 100-tick run. Either drive the orchestrator's
// termination path here (if it exposes one) or rename the test.
#[test]
fn test_orchestrator_request_terminate() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let handle = runtime.handle().clone();
    let action_count = Arc::new(AtomicU64::new(0));
    let worker = TestWorker::new(0, Arc::clone(&action_count));

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(worker)
        .max_ticks(100)
        .tick_duration(Duration::from_millis(1))
        .build(handle);
    let outcome = orchestrator.run();

    assert!(outcome.completed);
    assert_eq!(outcome.total_ticks, 100);
}
/// Two workers, a manager that issues decision requests for both, and a mock
/// invoker. Over ten ticks with a management interval of 5, the invoker must be
/// called at least once.
#[test]
fn test_orchestrator_with_batch_invoker() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let action_count = Arc::new(AtomicU64::new(0));
    let invoke_count = Arc::new(AtomicU64::new(0));
    let invoker = MockBatchInvoker::new(Arc::clone(&invoke_count));
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 10,
        management_strategy: ManagementStrategy::FixedInterval { interval: 5 },
    };

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(TestWorker::new(0, Arc::clone(&action_count)))
        .add_worker(TestWorker::new(1, Arc::clone(&action_count)))
        .manager(TestManagerWithRequests::new(2))
        .batch_invoker(invoker)
        .config(config)
        .build(runtime.handle().clone());

    let outcome = orchestrator.run();
    assert!(outcome.completed);
    assert_eq!(outcome.total_ticks, 10);

    let invokes = invoke_count.load(Ordering::SeqCst);
    assert!(
        invokes >= 1,
        "BatchInvoker should be called at least once, got {}",
        invokes
    );
}
/// End-to-end guidance flow. With a request-issuing manager and a mock invoker
/// (management interval 2, ten ticks), the worker must receive guidance at
/// least once.
#[test]
fn test_batch_invoker_generates_responses() {
    /// Worker that counts the ticks on which it was handed guidance.
    struct GuidanceAwareWorker {
        id: WorkerId,
        received_guidance_count: Arc<AtomicU64>,
    }

    impl WorkerAgent for GuidanceAwareWorker {
        fn id(&self) -> WorkerId {
            self.id
        }

        fn name(&self) -> &str {
            "GuidanceAwareWorker"
        }

        fn think_and_act(&self, _state: &SwarmState, guidance: Option<&Guidance>) -> WorkResult {
            let guided = guidance.is_some();
            if guided {
                self.received_guidance_count.fetch_add(1, Ordering::SeqCst);
            }
            WorkResult::acted(ActionResult::success("", Duration::from_micros(100)))
        }
    }

    let runtime = tokio::runtime::Runtime::new().unwrap();
    let invoke_count = Arc::new(AtomicU64::new(0));
    let guidance_count = Arc::new(AtomicU64::new(0));
    let worker = GuidanceAwareWorker {
        id: WorkerId(0),
        received_guidance_count: Arc::clone(&guidance_count),
    };
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 10,
        management_strategy: ManagementStrategy::FixedInterval { interval: 2 },
    };

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(worker)
        .manager(TestManagerWithRequests::new(1))
        .batch_invoker(MockBatchInvoker::new(Arc::clone(&invoke_count)))
        .config(config)
        .build(runtime.handle().clone());
    orchestrator.run();

    let invokes = invoke_count.load(Ordering::SeqCst);
    assert!(invokes > 0, "BatchInvoker should be called");

    let guidances = guidance_count.load(Ordering::SeqCst);
    assert!(
        guidances > 0,
        "Worker should receive guidance, got {}",
        guidances
    );
}
/// Manager stub with a configurable id. It counts `prepare` calls and never
/// issues worker requests; used to check that several managers coexist.
struct TestManagerWithId {
    manager_id: ManagerId,
    name: String,
    /// Incremented on every `prepare` call.
    observe_count: Arc<AtomicU64>,
}

impl TestManagerWithId {
    /// Creates manager `id`, named `TestManager_<id>`, reporting into `observe_count`.
    fn new(id: usize, observe_count: Arc<AtomicU64>) -> Self {
        let manager_id = ManagerId(id);
        let name = format!("TestManager_{}", id);
        Self {
            manager_id,
            name,
            observe_count,
        }
    }
}

impl ManagerAgent for TestManagerWithId {
    /// Counts the call and returns an empty request batch tagged with this manager's id.
    fn prepare(&self, _context: &TaskContext) -> BatchDecisionRequest {
        self.observe_count.fetch_add(1, Ordering::SeqCst);
        BatchDecisionRequest {
            requests: Vec::new(),
            manager_id: self.id(),
        }
    }

    /// Discards all responses and returns the default (empty) decision.
    fn finalize(
        &self,
        _context: &TaskContext,
        _responses: Vec<(WorkerId, DecisionResponse)>,
    ) -> ManagementDecision {
        Default::default()
    }

    fn id(&self) -> ManagerId {
        self.manager_id
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Two managers with distinct ids share one orchestrator. Each must be asked to
/// `prepare` at least once, and both must be called equally often.
#[test]
fn test_orchestrator_with_multiple_managers() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let action_count = Arc::new(AtomicU64::new(0));
    let (manager1_count, manager2_count) =
        (Arc::new(AtomicU64::new(0)), Arc::new(AtomicU64::new(0)));
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 10,
        management_strategy: ManagementStrategy::FixedInterval { interval: 5 },
    };

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(TestWorker::new(0, Arc::clone(&action_count)))
        .manager(TestManagerWithId::new(0, Arc::clone(&manager1_count)))
        .manager(TestManagerWithId::new(1, Arc::clone(&manager2_count)))
        .config(config)
        .build(runtime.handle().clone());

    let outcome = orchestrator.run();
    assert!(outcome.completed);
    assert_eq!(outcome.total_ticks, 10);

    let (m1_calls, m2_calls) = (
        manager1_count.load(Ordering::SeqCst),
        manager2_count.load(Ordering::SeqCst),
    );
    assert!(
        m1_calls >= 1,
        "Manager 1 should be called at least once, got {}",
        m1_calls
    );
    assert!(
        m2_calls >= 1,
        "Manager 2 should be called at least once, got {}",
        m2_calls
    );
    assert_eq!(
        m1_calls, m2_calls,
        "Both managers should be called same number of times"
    );
}
/// Manager that records the values it finds under the `"task"` and
/// `"task_target_path"` keys of the `TaskContext`, so a test can inspect them.
/// It never issues worker requests.
struct TaskAwareManager {
    name: String,
    captured_goal: Arc<std::sync::Mutex<Option<String>>>,
    captured_target: Arc<std::sync::Mutex<Option<String>>>,
}

impl TaskAwareManager {
    /// Creates the manager, writing captures into the two shared slots.
    fn new(
        captured_goal: Arc<std::sync::Mutex<Option<String>>>,
        captured_target: Arc<std::sync::Mutex<Option<String>>>,
    ) -> Self {
        let name = String::from("TaskAwareManager");
        Self {
            name,
            captured_goal,
            captured_target,
        }
    }
}

impl ManagerAgent for TaskAwareManager {
    /// Copies each key's value, when present, into its slot (overwriting any
    /// earlier capture) and returns an empty batch for manager 0.
    fn prepare(&self, context: &TaskContext) -> BatchDecisionRequest {
        if let Some(goal) = context.get_str("task") {
            let mut slot = self.captured_goal.lock().unwrap();
            *slot = Some(goal.to_string());
        }
        if let Some(target) = context.get_str("task_target_path") {
            let mut slot = self.captured_target.lock().unwrap();
            *slot = Some(target.to_string());
        }
        let requests = Vec::new();
        BatchDecisionRequest {
            manager_id: self.id(),
            requests,
        }
    }

    /// Discards all responses and returns the default (empty) decision.
    fn finalize(
        &self,
        _context: &TaskContext,
        _responses: Vec<(WorkerId, DecisionResponse)>,
    ) -> ManagementDecision {
        Default::default()
    }

    fn id(&self) -> ManagerId {
        ManagerId(0)
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// `run_task` must make the task visible to managers. The test expects the goal
/// under `"task"` and the context's `target_path` under `"task_target_path"`.
#[test]
fn test_run_task_sets_swarm_task_in_extensions() {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let action_count = Arc::new(AtomicU64::new(0));
    let captured_goal = Arc::new(std::sync::Mutex::new(None));
    let captured_target = Arc::new(std::sync::Mutex::new(None));
    let manager = TaskAwareManager::new(Arc::clone(&captured_goal), Arc::clone(&captured_target));
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 10,
        management_strategy: ManagementStrategy::FixedInterval { interval: 5 },
    };

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(TestWorker::new(0, Arc::clone(&action_count)))
        .manager(manager)
        .config(config)
        .build(runtime.handle().clone());

    let task_context = serde_json::json!({
        "target_path": "/path/to/codebase",
        "max_depth": 5
    });
    let task = SwarmTask::new("Find the authentication handler").with_context(task_context);
    let outcome = orchestrator
        .run_task(task)
        .expect("run_task should succeed");
    assert!(outcome.completed);

    assert_eq!(
        captured_goal.lock().unwrap().as_deref(),
        Some("Find the authentication handler")
    );
    assert_eq!(
        captured_target.lock().unwrap().as_deref(),
        Some("/path/to/codebase")
    );
}
/// Workers must be able to read the `SwarmTask` through the shared state's
/// extensions during `run_task`. The worker keeps the first goal it observes.
#[test]
fn test_run_task_swarm_task_accessible_by_worker() {
    /// Worker that records the goal of the first `SwarmTask` it sees.
    struct TaskAwareWorker {
        id: WorkerId,
        captured_goal: Arc<std::sync::Mutex<Option<String>>>,
    }

    impl WorkerAgent for TaskAwareWorker {
        fn id(&self) -> WorkerId {
            self.id
        }

        fn name(&self) -> &str {
            "TaskAwareWorker"
        }

        fn think_and_act(&self, state: &SwarmState, _guidance: Option<&Guidance>) -> WorkResult {
            if let Some(task) = state.shared.extensions.get::<SwarmTask>() {
                // Fill the slot only if it is still empty; later ticks leave it alone.
                self.captured_goal
                    .lock()
                    .unwrap()
                    .get_or_insert_with(|| task.goal.clone());
            }
            WorkResult::acted(ActionResult::success("", Duration::from_micros(100)))
        }
    }

    let runtime = tokio::runtime::Runtime::new().unwrap();
    let captured_goal = Arc::new(std::sync::Mutex::new(None));
    let worker = TaskAwareWorker {
        id: WorkerId(0),
        captured_goal: Arc::clone(&captured_goal),
    };

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(worker)
        .max_ticks(5)
        .tick_duration(Duration::from_millis(1))
        .build(runtime.handle().clone());

    orchestrator
        .run_task(SwarmTask::new("Worker task goal"))
        .expect("run_task should succeed");

    assert_eq!(
        captured_goal.lock().unwrap().as_deref(),
        Some("Worker task goal")
    );
}
/// Typed lookups on the task context must decode non-string JSON values: a
/// string array into `Vec<String>` and an integer into `i32`.
#[test]
fn test_swarm_task_context_complex_types() {
    /// Shared, optionally-filled capture slot.
    type Slot<T> = Arc<std::sync::Mutex<Option<T>>>;

    /// Worker that copies the `tags` and `count` context values into its slots
    /// whenever they decode successfully.
    struct ContextReadingWorker {
        id: WorkerId,
        captured_tags: Slot<Vec<String>>,
        captured_count: Slot<i32>,
    }

    impl WorkerAgent for ContextReadingWorker {
        fn id(&self) -> WorkerId {
            self.id
        }

        fn name(&self) -> &str {
            "ContextReadingWorker"
        }

        fn think_and_act(&self, state: &SwarmState, _guidance: Option<&Guidance>) -> WorkResult {
            if let Some(task) = state.shared.extensions.get::<SwarmTask>() {
                if let Some(tags) = task.get::<Vec<String>>("tags") {
                    let mut slot = self.captured_tags.lock().unwrap();
                    *slot = Some(tags);
                }
                if let Some(count) = task.get::<i32>("count") {
                    let mut slot = self.captured_count.lock().unwrap();
                    *slot = Some(count);
                }
            }
            WorkResult::acted(ActionResult::success("", Duration::from_micros(100)))
        }
    }

    let runtime = tokio::runtime::Runtime::new().unwrap();
    let captured_tags: Slot<Vec<String>> = Arc::new(std::sync::Mutex::new(None));
    let captured_count: Slot<i32> = Arc::new(std::sync::Mutex::new(None));
    let worker = ContextReadingWorker {
        id: WorkerId(0),
        captured_tags: Arc::clone(&captured_tags),
        captured_count: Arc::clone(&captured_count),
    };

    let mut orchestrator = OrchestratorBuilder::new()
        .add_worker(worker)
        .max_ticks(3)
        .tick_duration(Duration::from_millis(1))
        .build(runtime.handle().clone());

    let task_context = serde_json::json!({
        "tags": ["alpha", "beta", "gamma"],
        "count": 42,
        "nested": {
            "key": "value"
        }
    });
    let task = SwarmTask::new("Complex context test").with_context(task_context);
    orchestrator
        .run_task(task)
        .expect("run_task should succeed");

    let expected_tags: Vec<String> = ["alpha", "beta", "gamma"]
        .iter()
        .map(|tag| tag.to_string())
        .collect();
    assert_eq!(captured_tags.lock().unwrap().as_ref(), Some(&expected_tags));
    assert_eq!(*captured_count.lock().unwrap(), Some(42));
}