use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use swarm_engine_core::agent::{
BatchDecisionRequest, DecisionResponse, Guidance, ManagementDecision, ManagerAgent, ManagerId,
TaskContext, WorkResult, WorkerAgent,
};
use swarm_engine_core::learn::LearnableSwarmBuilder;
use swarm_engine_core::orchestrator::SwarmConfig;
use swarm_engine_core::state::SwarmState;
use swarm_engine_core::types::ActionResult;
use swarm_engine_core::types::{SwarmTask, WorkerId};
use tempfile::TempDir;
/// Minimal worker stub for builder tests: every call to `think_and_act`
/// bumps a shared counter so the test body can verify the swarm actually
/// scheduled this worker.
struct TestWorker {
id: WorkerId,
name: String,
// Shared with the owning test via `Arc` so assertions can observe activity.
action_count: Arc<AtomicU64>,
}
impl TestWorker {
fn new(id: usize, action_count: Arc<AtomicU64>) -> Self {
Self {
id: WorkerId(id),
name: format!("worker_{}", id),
action_count,
}
}
}
impl WorkerAgent for TestWorker {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn id(&self) -> WorkerId {
        self.id
    }

    /// Records the invocation in the shared counter, then reports a
    /// successful dummy action with a fixed tiny duration.
    fn think_and_act(&self, _state: &SwarmState, _guidance: Option<&Guidance>) -> WorkResult {
        let _previous = self.action_count.fetch_add(1, Ordering::SeqCst);
        let outcome = ActionResult::success("test_action", Duration::from_micros(100));
        WorkResult::acted(outcome)
    }
}
/// Manager stub that issues no decision requests and always returns the
/// default management decision; exists only so swarms can be built.
struct TestManager {
name: String,
}
impl TestManager {
fn new() -> Self {
Self {
name: "test_manager".to_string(),
}
}
}
impl ManagerAgent for TestManager {
    fn id(&self) -> ManagerId {
        ManagerId(0)
    }

    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Produces an empty batch: this manager never queries its workers.
    fn prepare(&self, _context: &TaskContext) -> BatchDecisionRequest {
        let requests = Vec::new();
        BatchDecisionRequest {
            manager_id: ManagerId(0),
            requests,
        }
    }

    /// Ignores every response and yields the default decision.
    fn finalize(
        &self,
        _context: &TaskContext,
        _responses: Vec<(WorkerId, DecisionResponse)>,
    ) -> ManagementDecision {
        ManagementDecision::default()
    }
}
/// Builds the single-threaded Tokio runtime used by these synchronous tests.
fn make_runtime() -> tokio::runtime::Runtime {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all();
    builder.build().unwrap()
}
#[test]
fn test_build_without_learning() {
    // With learning left off, a plain worker + manager swarm should
    // build, run to completion, and exercise the worker at least once.
    let runtime = make_runtime();
    let actions = Arc::new(AtomicU64::new(0));
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 5,
        ..Default::default()
    };
    let built = LearnableSwarmBuilder::new(runtime.handle().clone())
        .swarm_config(config)
        .add_worker(Box::new(TestWorker::new(0, Arc::clone(&actions))))
        .add_manager(Box::new(TestManager::new()))
        .build();
    assert!(built.is_ok(), "Build should succeed without learning");
    let mut swarm = built.unwrap();
    assert!(!swarm.is_learning_enabled());
    let outcome = swarm.run();
    assert!(outcome.completed);
    assert!(
        actions.load(Ordering::SeqCst) > 0,
        "Worker should have acted"
    );
}
#[test]
fn test_build_with_learning_requires_scenario() {
    // Enabling learning without calling `.scenario(...)` must be rejected
    // at build time with an error that names the missing requirement.
    let rt = make_runtime();
    let result = LearnableSwarmBuilder::new(rt.handle().clone())
        .with_learning(true)
        .add_worker(Box::new(TestWorker::new(0, Arc::new(AtomicU64::new(0)))))
        .build();
    // The exhaustive match both rejects Ok and inspects the error message,
    // replacing the previous redundant `assert!(result.is_err())` pre-check
    // that duplicated the match's own Ok-branch panic.
    match result {
        Err(err) => assert!(
            err.to_string().contains("scenario is required"),
            "Error should mention scenario requirement: {}",
            err
        ),
        Ok(_) => panic!("Expected error but got Ok"),
    }
}
#[test]
fn test_build_with_learning_enabled() {
    // Full learning setup: scenario + data dir + learning flag should
    // produce a swarm that reports learning as enabled and runs cleanly.
    let runtime = make_runtime();
    let data_dir = TempDir::new().unwrap();
    let actions = Arc::new(AtomicU64::new(0));
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 3,
        ..Default::default()
    };
    let built = LearnableSwarmBuilder::new(runtime.handle().clone())
        .scenario("test_scenario")
        .data_dir(data_dir.path())
        .with_learning(true)
        .swarm_config(config)
        .add_worker(Box::new(TestWorker::new(0, Arc::clone(&actions))))
        .add_manager(Box::new(TestManager::new()))
        .build();
    assert!(built.is_ok(), "Build should succeed with learning enabled");
    let mut swarm = built.unwrap();
    assert!(swarm.is_learning_enabled());
    assert_eq!(swarm.config().scenario, "test_scenario");
    let outcome = swarm.run();
    assert!(outcome.completed);
    // Best-effort shutdown signal if the swarm still holds its sender.
    if let Some(shutdown) = swarm.take_shutdown_tx() {
        let _ = shutdown.try_send(());
    }
}
#[test]
fn test_run_with_task() {
    // Running an explicit task should complete and drive the worker.
    let runtime = make_runtime();
    let actions = Arc::new(AtomicU64::new(0));
    let config = SwarmConfig {
        tick_duration: Duration::from_millis(1),
        max_ticks: 10,
        ..Default::default()
    };
    let mut swarm = LearnableSwarmBuilder::new(runtime.handle().clone())
        .swarm_config(config)
        .add_worker(Box::new(TestWorker::new(0, Arc::clone(&actions))))
        .add_manager(Box::new(TestManager::new()))
        .build()
        .unwrap();
    let task = SwarmTask::new("Test task goal");
    let outcome = swarm.run_task(task);
    assert!(outcome.is_ok());
    let run_result = outcome.unwrap();
    assert!(run_result.completed);
    assert!(actions.load(Ordering::SeqCst) > 0);
}
#[test]
fn test_offline_model_ref_before_build() {
    // Before `build()` is called, the builder should expose no offline model.
    let runtime = make_runtime();
    let counter = Arc::new(AtomicU64::new(0));
    let builder = LearnableSwarmBuilder::new(runtime.handle().clone())
        .scenario("test")
        .add_worker(Box::new(TestWorker::new(0, counter)));
    assert!(builder.offline_model_ref().is_none());
}
#[test]
// NOTE(review): this test asserts nothing about `result` — it only checks
// that `build()` with an unwritable learning-store path does not panic.
// Whether the error surfaces here or is deferred to a later call is not
// pinned down by this test; confirm the intended contract and consider
// adding an explicit assertion.
fn test_deferred_error_on_invalid_store_path() {
let rt = make_runtime();
let result = LearnableSwarmBuilder::new(rt.handle().clone())
.scenario("test")
.with_learning_store_path("/nonexistent/path/that/should/fail/learning")
.add_worker(Box::new(TestWorker::new(0, Arc::new(AtomicU64::new(0)))))
.build();
// Deliberately discarded: only the absence of a panic is exercised.
let _ = result;
}
#[test]
// Verifies that a `ScenarioProfile` supplied to the builder is translated
// into the swarm's scenario name and an offline model whose exploration
// and strategy parameters mirror the profile's learned values.
fn test_build_with_scenario_profile() {
use swarm_engine_core::learn::{
LearnedExploration, LearnedStrategy, ScenarioProfile, ScenarioSource,
};
let rt = make_runtime();
let action_count = Arc::new(AtomicU64::new(0));
// Profile with both learned components populated. The positional args to
// LearnedExploration appear to be (ucb1_c, learning_weight, <third>) based
// on the assertions below — TODO confirm the third parameter's meaning.
let mut profile =
ScenarioProfile::new("troubleshooting", ScenarioSource::from_path("/test.toml"));
profile.exploration = Some(LearnedExploration::new(2.5, 0.4, 1.2));
profile.strategy = Some(LearnedStrategy {
initial_strategy: "greedy".to_string(),
maturity_threshold: 10,
error_rate_threshold: 0.3,
confidence: 0.8,
session_count: 5,
updated_at: 0,
});
// Short run so the test stays fast; ticks themselves are not asserted here.
let config = SwarmConfig {
tick_duration: Duration::from_millis(1),
max_ticks: 3,
..Default::default()
};
let result = LearnableSwarmBuilder::new(rt.handle().clone())
.with_scenario_profile(&profile)
.swarm_config(config)
.add_worker(Box::new(TestWorker::new(0, action_count.clone())))
.add_manager(Box::new(TestManager::new()))
.build();
assert!(
result.is_ok(),
"Build with scenario profile should succeed: {:?}",
result.err()
);
let swarm = result.unwrap();
// The profile's name becomes the swarm's scenario.
assert_eq!(swarm.config().scenario, "troubleshooting");
// The profile should also materialize as an offline model...
let offline_model = swarm.offline_model();
assert!(offline_model.is_some());
let model = offline_model.unwrap();
// ...carrying the exploration parameters and strategy verbatim.
// Exact float equality is intentional: the values are stored, not computed.
assert_eq!(model.parameters.ucb1_c, 2.5);
assert_eq!(model.parameters.learning_weight, 0.4);
assert_eq!(model.strategy_config.initial_strategy, "greedy");
}