use std::path::PathBuf;
use swarm_engine_core::agent::{DefaultBatchManagerAgent, ManagerId};
use swarm_engine_core::config::PathResolver;
use swarm_engine_core::exploration::{AdaptiveLlmOperatorProvider, ReviewPolicy};
use swarm_engine_eval::config::DependencyProviderKind;
use swarm_engine_eval::prelude::{EvalReport, EvalRunner};
use swarm_engine_eval::reporter::{JsonReporter, Reporter};
use swarm_engine_eval::scenario::{EvalScenario, LlmProvider};
use swarm_engine_llm::{create_llm_invoker, LlmStrategyAdvisor, OllamaDecider};
#[allow(clippy::too_many_arguments)]
pub fn cmd_eval(
scenario_path: PathBuf,
runs: usize,
seed: u64,
output: Option<PathBuf>,
output_dir: Option<PathBuf>,
no_log: bool,
verbose: bool,
variant: Option<String>,
list_variants: bool,
learning: bool,
provider_kind: &str,
) {
if verbose {
use tracing_subscriber::EnvFilter;
let filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new("swarm_engine_llm=info,swarm_engine_core=info"));
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(true)
.compact()
.init();
}
if !scenario_path.exists() {
eprintln!("Scenario file not found: {}", scenario_path.display());
std::process::exit(1);
}
let content = match std::fs::read_to_string(&scenario_path) {
Ok(c) => c,
Err(e) => {
eprintln!("Failed to read scenario file: {}", e);
std::process::exit(1);
}
};
let base_scenario: EvalScenario = match toml::from_str(&content) {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to parse scenario TOML: {}", e);
std::process::exit(1);
}
};
if list_variants {
println!("Available variants for '{}':", base_scenario.meta.name);
let names = base_scenario.variant_names();
if names.is_empty() {
println!(" (no variants defined)");
} else {
for name in names {
if let Some(v) = base_scenario.variants.iter().find(|v| v.name == name) {
if v.description.is_empty() {
println!(" - {}", name);
} else {
println!(" - {}: {}", name, v.description);
}
}
}
}
return;
}
let scenario = if let Some(ref variant_name) = variant {
match base_scenario.with_variant(variant_name) {
Some(s) => s,
None => {
eprintln!("Variant '{}' not found.", variant_name);
eprintln!("Available variants: {:?}", base_scenario.variant_names());
std::process::exit(1);
}
}
} else {
base_scenario
};
println!("=== SwarmEngine Eval ===");
println!("Scenario: {}", scenario_path.display());
if let Some(ref v) = variant {
println!("Variant: {}", v);
}
println!("Runs: {}", runs);
println!("Seed: {}", seed);
println!();
println!("=== Scenario ===");
println!("Name: {}", scenario.meta.name);
println!("ID: {}", scenario.meta.id);
println!("Version: {}", scenario.meta.version);
if !scenario.meta.description.is_empty() {
println!("Description: {}", scenario.meta.description);
}
if !scenario.meta.tags.is_empty() {
println!("Tags: {}", scenario.meta.tags.join(", "));
}
println!();
println!("=== Task ===");
println!("Goal: {}", scenario.task.goal);
if let Some(expected) = &scenario.task.expected {
println!("Expected: {}", expected);
}
println!();
println!("=== Config ===");
let worker_count: usize = scenario.agents.workers.iter().map(|w| w.count).sum();
println!("Workers: {}", worker_count);
println!("Max ticks: {}", scenario.app_config.max_ticks);
println!("LLM: {:?} ({})", scenario.llm.provider, scenario.llm.model);
if let Some(ref lora) = scenario.llm.lora {
println!("LoRA: id={}, scale={:.2}", lora.id, lora.scale);
}
println!("Actions: {:?}", scenario.actions.action_names());
println!();
let rt = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
let handle = rt.handle().clone();
if scenario.llm.provider.requires_endpoint() {
let endpoint = scenario
.llm
.endpoint
.clone()
.or_else(|| scenario.llm.provider.default_endpoint().map(String::from))
.unwrap_or_else(|| "http://localhost:8080".to_string());
println!("Checking LLM server health...");
let health_url = match scenario.llm.provider {
LlmProvider::Ollama => format!("{}/api/tags", endpoint),
LlmProvider::LlamaCppServer => format!("{}/health", endpoint),
_ => format!("{}/health", endpoint),
};
let health_result = rt.block_on(async {
let client = reqwest::Client::new();
client
.get(&health_url)
.timeout(std::time::Duration::from_secs(5))
.send()
.await
});
match health_result {
Ok(resp) if resp.status().is_success() => {
println!(" \x1b[32m✓ LLM server is healthy ({})\x1b[0m", endpoint);
}
Ok(resp) => {
eprintln!(
"\x1b[31m✗ LLM server returned error: {} ({})\x1b[0m",
resp.status(),
health_url
);
eprintln!(" Hint: Start the server with 'swarm llama start -m <model>'");
std::process::exit(1);
}
Err(e) => {
eprintln!("\x1b[31m✗ LLM server is not responding: {}\x1b[0m", e);
eprintln!(" Endpoint: {}", health_url);
eprintln!(" Hint: Start the server with 'swarm llama start -m <model>'");
std::process::exit(1);
}
}
println!();
}
let runner = match scenario.llm.provider {
LlmProvider::Ollama => {
use swarm_engine_core::exploration::AdaptiveOperatorProvider;
use swarm_engine_core::learn::LearningStore;
let llm_config = scenario
.llm
.to_ollama_config(scenario.batch_processor.max_concurrency);
let llm_config_for_advisor = llm_config.clone();
let handle_for_advisor = handle.clone();
let offline_model = if learning {
let learning_path = PathResolver::user_data_dir().join("learning");
LearningStore::new(&learning_path)
.ok()
.and_then(|store| store.load_offline_model(&scenario.meta.name).ok())
} else {
None
};
let adaptive_provider = if let Some(ref model) = offline_model {
println!(
"Applying offline model to provider: ucb1_c={:.3}, maturity={}, strategy={}",
model.parameters.ucb1_c,
model.strategy_config.maturity_threshold,
model.strategy_config.initial_strategy
);
AdaptiveOperatorProvider::default()
.with_ucb1_c(model.parameters.ucb1_c)
.with_maturity_threshold(model.strategy_config.maturity_threshold)
.with_error_rate_threshold(model.strategy_config.error_rate_threshold)
} else {
AdaptiveOperatorProvider::default()
};
let ucb1_c = offline_model
.as_ref()
.map(|m| m.parameters.ucb1_c)
.unwrap_or(std::f64::consts::SQRT_2);
EvalRunner::new(scenario, rt.handle().clone())
.with_runs(runs)
.with_seed(seed)
.with_verbose(verbose)
.with_exploration(true)
.with_manager_factory(|| Box::new(DefaultBatchManagerAgent::new(ManagerId(0))))
.with_batch_invoker_factory(move || {
let decider = OllamaDecider::new(llm_config.clone());
Box::new(create_llm_invoker(decider, handle.clone()))
})
.with_operator_provider_factory(move || {
let decider =
std::sync::Arc::new(OllamaDecider::new(llm_config_for_advisor.clone()));
let advisor = LlmStrategyAdvisor::new(decider, handle_for_advisor.clone());
Box::new(
AdaptiveLlmOperatorProvider::new(Box::new(advisor))
.with_adaptive(adaptive_provider.clone())
.with_ucb1_c(ucb1_c)
.with_policy(ReviewPolicy::default()),
)
})
}
LlmProvider::Mistral => {
eprintln!("Error: Mistral provider is not supported in this version");
eprintln!("Use 'ollama' or 'llama-server' instead");
std::process::exit(1);
}
#[cfg(feature = "llama-cpp")]
LlmProvider::LlamaCpp => {
use swarm_engine_llm::{LlamaCppConfig, LlamaCppDecider};
let llama_config = if scenario.llm.is_gguf() {
LlamaCppConfig::from_hf(
&scenario.llm.model,
scenario.llm.gguf_files.first().cloned().unwrap_or_default(),
)
} else {
LlamaCppConfig::from_local(&scenario.llm.model)
}
.with_max_tokens(scenario.llm.max_tokens.unwrap_or(256))
.with_temperature(scenario.llm.temperature)
.with_context_size(scenario.llm.num_ctx.unwrap_or(4096) as u32);
let decider =
LlamaCppDecider::new(llama_config).expect("Failed to create LlamaCppDecider");
EvalRunner::new(scenario, rt.handle().clone())
.with_runs(runs)
.with_seed(seed)
.with_verbose(verbose)
.with_manager_factory(|| Box::new(DefaultBatchManagerAgent::new(ManagerId(0))))
.with_batch_invoker_factory(move || {
Box::new(create_llm_invoker(decider.clone(), handle.clone()))
})
}
#[cfg(not(feature = "llama-cpp"))]
LlmProvider::LlamaCpp => {
eprintln!("Error: LlamaCpp provider requires --features llama-cpp");
eprintln!("Build with: cargo build --features llama-cpp");
std::process::exit(1);
}
LlmProvider::LlamaCppServer => {
use swarm_engine_core::exploration::AdaptiveOperatorProvider;
use swarm_engine_core::learn::LearningStore;
use swarm_engine_llm::{ChatTemplate, LlamaCppServerConfig, LlamaCppServerDecider};
let endpoint = scenario
.llm
.endpoint
.clone()
.unwrap_or_else(|| "http://localhost:8080".to_string());
let server_config = LlamaCppServerConfig::new(endpoint)
.with_model_name(&scenario.llm.model)
.with_max_tokens(scenario.llm.max_tokens.unwrap_or(256))
.with_temperature(scenario.llm.temperature)
.with_chat_template(ChatTemplate::Lfm2);
let decider = LlamaCppServerDecider::new(server_config.clone())
.expect("Failed to create LlamaCppServerDecider");
let server_config_for_advisor = server_config;
let handle_for_advisor = handle.clone();
let offline_model = if learning {
let learning_path = PathResolver::user_data_dir().join("learning");
LearningStore::new(&learning_path)
.ok()
.and_then(|store| store.load_offline_model(&scenario.meta.name).ok())
} else {
None
};
let adaptive_provider = if let Some(ref model) = offline_model {
println!(
"Applying offline model to provider: ucb1_c={:.3}, maturity={}, strategy={}",
model.parameters.ucb1_c,
model.strategy_config.maturity_threshold,
model.strategy_config.initial_strategy
);
AdaptiveOperatorProvider::default()
.with_ucb1_c(model.parameters.ucb1_c)
.with_maturity_threshold(model.strategy_config.maturity_threshold)
.with_error_rate_threshold(model.strategy_config.error_rate_threshold)
} else {
AdaptiveOperatorProvider::default()
};
let ucb1_c = offline_model
.as_ref()
.map(|m| m.parameters.ucb1_c)
.unwrap_or(std::f64::consts::SQRT_2);
EvalRunner::new(scenario, rt.handle().clone())
.with_runs(runs)
.with_seed(seed)
.with_verbose(verbose)
.with_exploration(true)
.with_manager_factory(|| Box::new(DefaultBatchManagerAgent::new(ManagerId(0))))
.with_batch_invoker_factory(move || {
Box::new(create_llm_invoker(decider.clone(), handle.clone()))
})
.with_operator_provider_factory(move || {
let decider = std::sync::Arc::new(
LlamaCppServerDecider::new(server_config_for_advisor.clone())
.expect("Failed to create LlamaCppServerDecider for advisor"),
);
let advisor = LlmStrategyAdvisor::new(decider, handle_for_advisor.clone());
Box::new(
AdaptiveLlmOperatorProvider::new(Box::new(advisor))
.with_adaptive(adaptive_provider.clone())
.with_ucb1_c(ucb1_c)
.with_policy(ReviewPolicy::default()),
)
})
}
other => {
eprintln!("Error: Provider {:?} is not yet supported", other);
eprintln!(
"Supported providers: ollama, llama-server, llama-cpp (--features llama-cpp)"
);
std::process::exit(1);
}
};
let session_dir = if !no_log {
let sessions_dir = output_dir
.clone()
.unwrap_or_else(|| PathResolver::user_data_dir().join("eval").join("sessions"));
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let scenario_name = scenario_path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown");
let dir = sessions_dir.join(format!("{}_{}", scenario_name, timestamp));
if let Err(e) = std::fs::create_dir_all(&dir) {
eprintln!("Failed to create session directory: {}", e);
None
} else {
Some(dir)
}
} else {
None
};
let runner = if learning {
let learning_path = PathResolver::user_data_dir().join("learning");
println!("Learning enabled: {}", learning_path.display());
runner.with_learning_store(&learning_path)
} else {
runner
};
let provider_kind_enum = match provider_kind {
"learned" => DependencyProviderKind::Learned,
_ => DependencyProviderKind::Smart, };
let runner = runner.with_dependency_provider_kind(provider_kind_enum);
println!(
"Dependency provider: {:?}",
provider_kind_enum
);
println!("Running evaluation...");
let report: EvalReport = match runner.run() {
Ok(r) => r,
Err(e) => {
eprintln!("Evaluation failed: {}", e);
std::process::exit(1);
}
};
println!();
println!("=== Results ===");
println!(
"Success rate: {:.1}%",
report.aggregated.success_rate * 100.0
);
println!("Pass@1: {:.1}%", report.aggregated.pass_at_1 * 100.0);
println!(
"Avg ticks: {:.1}",
report.aggregated.statistics.total_ticks.mean
);
println!(
"Throughput: {:.1} actions/sec (effective: {:.1})",
report.aggregated.statistics.raw_throughput_per_sec.mean,
report
.aggregated
.statistics
.effective_throughput_per_sec
.mean
);
let total_invocations = report.aggregated.statistics.total_llm_invocations;
let total_errors = report.aggregated.statistics.total_llm_errors;
if total_invocations > 0 {
let error_rate = total_errors as f64 / total_invocations as f64 * 100.0;
if total_errors > 0 {
println!(
"\x1b[31m⚠ LLM: {}/{} calls failed ({:.1}% error rate)\x1b[0m",
total_errors, total_invocations, error_rate
);
} else {
println!(
"\x1b[32m✓ LLM: {} calls, 0 errors\x1b[0m",
total_invocations
);
}
}
if let Some(output_path) = output {
let reporter = JsonReporter::new();
match reporter.generate(&report) {
Ok(json) => match std::fs::write(&output_path, json) {
Ok(_) => println!("\nReport written to: {}", output_path.display()),
Err(e) => eprintln!("\nFailed to write report: {}", e),
},
Err(e) => eprintln!("\nFailed to generate report: {}", e),
}
}
if let Some(ref session_dir) = session_dir {
let reporter = JsonReporter::new();
match reporter.generate(&report) {
Ok(json) => {
let report_path = session_dir.join("report.json");
match std::fs::write(&report_path, &json) {
Ok(_) => println!("\nSession saved to: {}", session_dir.display()),
Err(e) => eprintln!("\nFailed to write report: {}", e),
}
for (i, run) in report.runs.iter().enumerate() {
let run_path = session_dir.join(format!("run_{:03}.json", i + 1));
if let Ok(run_json) = serde_json::to_string_pretty(run) {
let _ = std::fs::write(&run_path, run_json);
}
}
}
Err(e) => eprintln!("\nFailed to generate report: {}", e),
}
}
}