use std::path::{Path, PathBuf};
use std::time::Duration;
use clap::Subcommand;
use swarm_engine_core::config::PathResolver;
// Subcommands of `learn`.
//
// NOTE(review): plain `//` comments are used on purpose here — `///` doc
// comments on clap-derived items become user-visible help/about text and
// would therefore change CLI behavior.
#[derive(Subcommand)]
pub enum LearnAction {
    // Full 3-step workflow: bootstrap runs -> offline learning -> release runs.
    Auto {
        // Path to the scenario TOML file.
        scenario: PathBuf,
        // Number of bootstrap evaluation runs (step 1).
        #[arg(long, default_value = "10")]
        bootstrap_runs: usize,
        // Number of release evaluation runs (step 3).
        #[arg(long, default_value = "10")]
        release_runs: usize,
        // Scenario variant used during the bootstrap phase.
        #[arg(long, default_value = "with_graph")]
        bootstrap_variant: String,
        // Stop after a given phase ("bootstrap" is the recognized value).
        #[arg(long)]
        stop_at: Option<String>,
        // Skip step 1 and go straight to offline learning.
        #[arg(long)]
        skip_bootstrap: bool,
        // Learning data directory; defaults to <user data dir>/learning.
        #[arg(long)]
        learning_dir: Option<PathBuf>,
        // Verbose per-run output.
        #[arg(short = 'v', long)]
        verbose: bool,
    },
    // One-shot offline analysis of stored sessions for a scenario key.
    Once {
        scenario: String,
        // Maximum number of sessions to analyze.
        #[arg(short = 'n', long, default_value = "20")]
        sessions: usize,
        #[arg(long)]
        learning_dir: Option<PathBuf>,
    },
    // Generate DPO training data (currently disabled at runtime — see cmd_learn_dpo).
    Dpo {
        scenario: String,
        #[arg(long)]
        learning_dir: Option<PathBuf>,
        // Output JSONL file; prints to stdout when omitted.
        #[arg(short = 'o', long)]
        output: Option<PathBuf>,
        // Minimum quality gap between chosen/rejected episodes (fraction).
        #[arg(long, default_value = "0.1")]
        min_gap: f64,
        // Also emit SFT samples from successful episodes.
        #[arg(long)]
        include_sft: bool,
    },
    // Watch the learning directory and re-learn when files change.
    Watch {
        #[arg(long)]
        learning_dir: Option<PathBuf>,
        #[arg(short = 'n', long, default_value = "20")]
        sessions: usize,
        // Debounce window for filesystem events, in seconds.
        #[arg(long, default_value = "30")]
        debounce_secs: u64,
    },
    // Long-running daemon that re-learns on a trigger schedule.
    Daemon {
        scenario: String,
        #[arg(long)]
        learning_dir: Option<PathBuf>,
        // Trigger spec: "always", "never", "count:N", "time:Nh", "time:Nm"
        // (parsed by parse_trigger_spec).
        #[arg(long, default_value = "count:100")]
        trigger: String,
        // Processor mode, parsed via FromStr (e.g. "offline").
        #[arg(long, default_value = "offline")]
        mode: String,
        // Apply newly learned models automatically.
        #[arg(long)]
        auto_apply: bool,
        #[arg(long, default_value = "10")]
        check_interval_secs: u64,
    },
}
pub fn cmd_learn(action: LearnAction) {
match action {
LearnAction::Auto {
scenario,
bootstrap_runs,
release_runs,
bootstrap_variant,
stop_at,
skip_bootstrap,
learning_dir,
verbose,
} => {
let learning_path =
learning_dir.unwrap_or_else(|| PathResolver::user_data_dir().join("learning"));
cmd_learn_auto(
&scenario,
bootstrap_runs,
release_runs,
&bootstrap_variant,
stop_at.as_deref(),
skip_bootstrap,
&learning_path,
verbose,
);
}
LearnAction::Once {
scenario,
sessions,
learning_dir,
} => {
let learning_path =
learning_dir.unwrap_or_else(|| PathResolver::user_data_dir().join("learning"));
cmd_learn_once(&scenario, sessions, &learning_path);
}
LearnAction::Watch {
learning_dir,
sessions,
debounce_secs,
} => {
let learning_path =
learning_dir.unwrap_or_else(|| PathResolver::user_data_dir().join("learning"));
cmd_learn_watch(learning_path, sessions, debounce_secs);
}
LearnAction::Daemon {
scenario,
learning_dir,
trigger,
mode,
auto_apply,
check_interval_secs,
} => {
let learning_path =
learning_dir.unwrap_or_else(|| PathResolver::user_data_dir().join("learning"));
cmd_learn_daemon(
&scenario,
learning_path,
&trigger,
&mode,
auto_apply,
check_interval_secs,
);
}
LearnAction::Dpo {
scenario,
learning_dir,
output,
min_gap,
include_sft,
} => {
let learning_path =
learning_dir.unwrap_or_else(|| PathResolver::user_data_dir().join("learning"));
cmd_learn_dpo(
&scenario,
&learning_path,
output.as_deref(),
min_gap,
include_sft,
);
}
}
}
fn cmd_learn_once(scenario: &str, sessions: usize, learning_path: &PathBuf) {
use swarm_engine_core::learn::daemon::{Processor, ProcessorConfig, ProcessorMode};
use swarm_engine_core::learn::store::InMemoryEpisodeStore;
use swarm_engine_core::learn::LearningStore;
println!("=== Offline Learning (via Processor) ===");
println!("Scenario: {}", scenario);
println!("Sessions to analyze: {}", sessions);
println!("Learning data: {}", learning_path.display());
println!();
let store = match LearningStore::new(learning_path) {
Ok(s) => s,
Err(e) => {
eprintln!("Error: Failed to open learning store: {}", e);
return;
}
};
match store.list_sessions(scenario) {
Ok(session_ids) => {
println!("Found {} sessions", session_ids.len());
if session_ids.is_empty() {
println!("No sessions found. Run eval with --learning first.");
return;
}
}
Err(e) => {
eprintln!("Error: Failed to list sessions: {}", e);
return;
}
}
let existing_action_order = store
.load_offline_model(scenario)
.ok()
.and_then(|m| m.action_order);
let config = ProcessorConfig::new(scenario)
.mode(ProcessorMode::OfflineOnly)
.max_sessions(sessions);
let processor = Processor::new(config).with_learning_store(store.clone());
let rt = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
let episode_store = InMemoryEpisodeStore::new();
println!("\nRunning offline analysis via Processor...");
match rt.block_on(processor.run(&episode_store)) {
Ok(result) => {
if let Some(mut model) = result.offline_model().cloned() {
if let Some(action_order) = existing_action_order.clone() {
model.action_order = Some(action_order.clone());
println!("\n=== Action Order (from DependencyGraph inference) ===");
println!(" discover: {:?}", action_order.discover);
println!(" not_discover: {:?}", action_order.not_discover);
println!(" action_set_hash: {}", action_order.action_set_hash);
println!(" source: {:?}", action_order.source);
}
if let Err(e) = store.save_offline_model(scenario, &model) {
eprintln!("Warning: Failed to save updated model: {}", e);
}
println!("\n=== Offline Model Generated ===");
println!("Analyzed sessions: {}", model.analyzed_sessions);
println!();
println!("Parameters:");
println!(" ucb1_c: {:.3}", model.parameters.ucb1_c);
println!(" learning_weight: {:.3}", model.parameters.learning_weight);
println!(" ngram_weight: {:.3}", model.parameters.ngram_weight);
println!();
println!("Strategy Config:");
println!(
" maturity_threshold: {}",
model.strategy_config.maturity_threshold
);
println!(
" error_rate_threshold: {:.3}",
model.strategy_config.error_rate_threshold
);
println!(
" initial_strategy: {}",
model.strategy_config.initial_strategy
);
println!();
println!("Recommended Paths:");
for (i, path) in model.recommended_paths.iter().enumerate() {
println!(
" {}. {:?} (success: {:.1}%, obs: {})",
i + 1,
path.actions,
path.success_rate * 100.0,
path.observations
);
}
println!();
if model.action_order.is_some() {
println!("Action Order: available (from DependencyGraph inference)");
} else {
println!("Action Order: none (run eval with LLM DependencyGraph first)");
}
println!();
println!(
"Model saved to: {}/scenarios/{}/offline_model.json",
learning_path.display(),
scenario
);
} else {
eprintln!("Error: No offline model in result");
}
}
Err(e) => {
eprintln!("Error: Processor failed: {}", e);
}
}
}
/// Generate DPO (and optionally SFT) training data from stored
/// DependencyGraph episodes.
///
/// Currently short-circuits with an error message: the DependencyGraph
/// learning data moved into RecordStore, so the JSONL-based pipeline
/// below is dead code kept for reference (note the early `return` and
/// the `#[allow(unreachable_code)]` block).
#[allow(unused_variables)]
fn cmd_learn_dpo(
    scenario: &str,
    learning_path: &Path,
    output: Option<&Path>,
    min_gap: f64,
    include_sft: bool,
) {
    use std::fs::File;
    use std::io::{BufRead, BufReader, Write};
    use swarm_engine_core::learn::{
        DependencyGraphLearnModel, DependencyGraphRecord, DpoConfig, DpoLearnModel, Episode,
        TrainingData,
    };
    println!("=== DPO Training Data Generation ===");
    println!("Scenario: {}", scenario);
    println!("Learning data: {}", learning_path.display());
    println!("Min quality gap: {:.1}%", min_gap * 100.0);
    println!("Include SFT: {}", include_sft);
    println!();
    // Disabled: episode data now lives in RecordStore, not the JSONL file
    // this command reads. Everything after this return is unreachable.
    eprintln!("Error: DPO command is temporarily disabled.");
    eprintln!("The DependencyGraph learning data is now stored in RecordStore.");
    eprintln!("Please use 'learn once' to extract action_order from the OfflineModel.");
    return;
    // --- reference implementation of the old JSONL-based pipeline ---
    #[allow(unreachable_code)]
    {
        let episodes_path = learning_path
            .join("scenarios")
            .join(scenario)
            .join("dep_graph_episodes.jsonl");
        if !episodes_path.exists() {
            eprintln!(
                "Error: Episodes file not found: {}",
                episodes_path.display()
            );
            eprintln!("Run 'swarm-engine eval' first to generate episode data.");
            return;
        }
        let file = match File::open(&episodes_path) {
            Ok(f) => f,
            Err(e) => {
                eprintln!("Error: Failed to open episodes file: {}", e);
                return;
            }
        };
        // Parse one Episode per non-empty JSONL line; bad lines are
        // warned about and skipped rather than aborting the whole run.
        let reader = BufReader::new(file);
        let mut episodes: Vec<Episode> = Vec::new();
        for (i, line) in reader.lines().enumerate() {
            match line {
                Ok(l) if !l.trim().is_empty() => match serde_json::from_str::<Episode>(&l) {
                    Ok(ep) => episodes.push(ep),
                    Err(e) => {
                        eprintln!("Warning: Failed to parse line {}: {}", i + 1, e);
                    }
                },
                Err(e) => {
                    eprintln!("Warning: Failed to read line {}: {}", i + 1, e);
                }
                // Blank lines fall through here and are ignored.
                _ => {}
            }
        }
        println!("Loaded {} episodes", episodes.len());
        let success_count = episodes.iter().filter(|e| e.outcome.is_success()).count();
        let failure_count = episodes.iter().filter(|e| e.outcome.is_failure()).count();
        println!(" Success: {}", success_count);
        println!(" Failure: {}", failure_count);
        println!();
        // DPO pairs a preferred (success) sample against a rejected
        // (failure) one, so both outcome classes must be present.
        if success_count == 0 || failure_count == 0 {
            eprintln!("Error: Need both success and failure episodes for DPO.");
            eprintln!("Run more evaluations to collect both outcomes.");
            return;
        }
        let dpo_config = DpoConfig {
            min_quality_gap: min_gap,
            ..Default::default()
        };
        let dpo_model = DpoLearnModel::new(DependencyGraphLearnModel::extractor())
            .with_system_prompt(DependencyGraphLearnModel::default_system_prompt())
            .with_config(dpo_config);
        let pairs = dpo_model.build_pairs(&episodes);
        println!("Generated {} DPO pairs", pairs.len());
        if pairs.is_empty() {
            eprintln!("Warning: No DPO pairs generated.");
            eprintln!("Try lowering --min-gap threshold or collect more episodes.");
            eprintln!("Note: Episodes must share the same group_id for DPO pairing.");
            return;
        }
        let mut training_data: Vec<TrainingData> = dpo_model.convert_pairs(&pairs);
        println!(" DPO samples: {}", training_data.len());
        // Optionally append SFT samples built from the first
        // DependencyGraphRecord of each successful episode.
        if include_sft {
            let sft_data: Vec<TrainingData> = episodes
                .iter()
                .filter(|ep| ep.outcome.is_success())
                .filter_map(|ep| {
                    ep.context.iter::<DependencyGraphRecord>().next().map(|r| {
                        let scenario_name = ep.metadata.scenario_name.as_deref().unwrap_or("unknown");
                        TrainingData::sft(
                            DependencyGraphLearnModel::default_system_prompt(),
                            &r.prompt,
                            &r.response,
                        )
                        .with_scenario(scenario_name)
                    })
                })
                .collect();
            println!(" SFT samples: {}", sft_data.len());
            training_data.extend(sft_data);
        }
        println!();
        println!("Total training samples: {}", training_data.len());
        // Write JSONL to the output file when given, otherwise to stdout.
        if let Some(path) = output {
            match File::create(path) {
                Ok(mut file) => {
                    for data in &training_data {
                        let json = serde_json::to_string(data).unwrap();
                        if let Err(e) = writeln!(file, "{}", json) {
                            eprintln!("Error: Failed to write: {}", e);
                            return;
                        }
                    }
                    println!("Output written to: {}", path.display());
                }
                Err(e) => {
                    eprintln!("Error: Failed to create output file: {}", e);
                }
            }
        } else {
            println!();
            println!("--- Training Data (JSONL) ---");
            for data in &training_data {
                let json = serde_json::to_string(data).unwrap();
                println!("{}", json);
            }
        }
    }
}
/// Watch the learning data directory and re-run offline analysis whenever
/// its files change, debounced by `debounce_secs`.
///
/// Exits the process when the directory is missing or the watcher cannot
/// be created; otherwise blocks until the pipeline stops (e.g. Ctrl+C).
fn cmd_learn_watch(learning_path: PathBuf, sessions: usize, debounce_secs: u64) {
    use swarm_engine_core::prelude::{
        DebounceTransform, LearningSink, LocalFileWatcherSource, Pipeline,
    };
    println!("=== Learning Daemon ===");
    println!("Learning data: {}", learning_path.display());
    println!("Sessions to analyze: {}", sessions);
    println!("Debounce: {}s", debounce_secs);
    println!();
    let watched = learning_path.clone();
    if !watched.exists() {
        eprintln!(
            "Error: Learning directory does not exist: {}",
            watched.display()
        );
        eprintln!("Run 'swarm-engine eval --learning' first to generate data.");
        std::process::exit(1);
    }
    println!("Watching: {}", watched.display());
    println!("Press Ctrl+C to stop");
    println!();
    let runtime = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
    runtime.block_on(async {
        // Source: raw filesystem events, polled once per second.
        let source = LocalFileWatcherSource::new(&watched, Duration::from_secs(1))
            .unwrap_or_else(|e| {
                eprintln!("Error: Failed to create file watcher: {}", e);
                std::process::exit(1);
            });
        // Transform collapses event bursts; sink re-runs the learner.
        let debounce = DebounceTransform::new(Duration::from_secs(debounce_secs));
        let sink = LearningSink::new(learning_path, sessions);
        let mut pipeline = Pipeline::new(source, debounce, sink);
        if let Err(e) = pipeline.run().await {
            eprintln!("Pipeline error: {}", e);
        }
    });
}
/// Start the learning daemon for `scenario`: polls every
/// `check_interval_secs` seconds and retrains according to `trigger_spec`
/// (see `parse_trigger_spec`) in the given processor mode.
///
/// Exits the process on invalid trigger/mode specs or daemon creation
/// failure; otherwise blocks until the daemon stops (e.g. Ctrl+C).
fn cmd_learn_daemon(
    scenario: &str,
    learning_path: PathBuf,
    trigger_spec: &str,
    mode_spec: &str,
    auto_apply: bool,
    check_interval_secs: u64,
) {
    use swarm_engine_core::learn::{DaemonConfig, LearningDaemon, ProcessorMode};
    println!("=== Learning Daemon ===");
    println!("Scenario: {}", scenario);
    println!("Learning data: {}", learning_path.display());
    println!("Trigger: {}", trigger_spec);
    println!("Mode: {}", mode_spec);
    println!("Auto-apply: {}", auto_apply);
    println!("Check interval: {}s", check_interval_secs);
    println!();
    // Validate both user-supplied specs up front; bail out on either.
    let trigger = parse_trigger_spec(trigger_spec).unwrap_or_else(|e| {
        eprintln!("Error: Invalid trigger spec: {}", e);
        std::process::exit(1);
    });
    let processor_mode = mode_spec.parse::<ProcessorMode>().unwrap_or_else(|e| {
        eprintln!("Error: Invalid mode: {}", e);
        std::process::exit(1);
    });
    let config = DaemonConfig::new(scenario)
        .data_dir(&learning_path)
        .check_interval(Duration::from_secs(check_interval_secs))
        .processor_mode(processor_mode)
        .auto_apply(auto_apply);
    let mut daemon = LearningDaemon::new(config, trigger).unwrap_or_else(|e| {
        eprintln!("Error: Failed to create daemon: {}", e);
        std::process::exit(1);
    });
    println!("Daemon started. Press Ctrl+C to stop.");
    println!();
    let rt = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
    rt.block_on(async {
        if let Err(e) = daemon.run().await {
            eprintln!("Daemon error: {}", e);
        }
    });
}
/// Parse a trigger spec string (case-insensitive) into a trigger:
/// "always", "never", "count:N" (every N episodes), "time:Nh" or
/// "time:Nm" (every N hours / minutes). Returns a human-readable error
/// message for anything else.
fn parse_trigger_spec(
    spec: &str,
) -> Result<std::sync::Arc<dyn swarm_engine_core::learn::TrainTrigger>, String> {
    use swarm_engine_core::learn::TriggerBuilder;
    let normalized = spec.to_lowercase();
    match normalized.as_str() {
        "always" => return Ok(TriggerBuilder::always()),
        "never" => return Ok(TriggerBuilder::never()),
        _ => {}
    }
    if let Some(count_str) = normalized.strip_prefix("count:") {
        let n = count_str
            .parse::<usize>()
            .map_err(|_| format!("Invalid count: {}", count_str))?;
        Ok(TriggerBuilder::every_n_episodes(n))
    } else if let Some(time_str) = normalized.strip_prefix("time:") {
        // Hours take priority; a spec can only end in one suffix anyway.
        match (time_str.strip_suffix('h'), time_str.strip_suffix('m')) {
            (Some(hours_str), _) => {
                let hours = hours_str
                    .parse::<u64>()
                    .map_err(|_| format!("Invalid hours: {}", hours_str))?;
                Ok(TriggerBuilder::every_minutes(hours * 60))
            }
            (None, Some(mins_str)) => {
                let mins = mins_str
                    .parse::<u64>()
                    .map_err(|_| format!("Invalid minutes: {}", mins_str))?;
                Ok(TriggerBuilder::every_minutes(mins))
            }
            (None, None) => Err(format!("Invalid time spec: {} (use Nh or Nm)", time_str)),
        }
    } else {
        Err(format!("Unknown trigger spec: {}", normalized))
    }
}
/// Run the full 3-step learning workflow:
///   1. Bootstrap: run the scenario `bootstrap_runs` times with the
///      graph-enabled `bootstrap_variant` to collect learning sessions.
///   2. Offline learning: analyze the collected sessions and persist an
///      offline model via the `Processor`.
///   3. Release: run `release_runs` times on the base scenario, which
///      picks up the learned model.
///
/// `stop_at == Some("bootstrap")` returns after step 1; `skip_bootstrap`
/// jumps straight to step 2. Exits the process on setup failures
/// (missing scenario file, bad TOML, unknown variant, store errors).
#[allow(clippy::too_many_arguments)]
fn cmd_learn_auto(
    scenario_path: &Path,
    bootstrap_runs: usize,
    release_runs: usize,
    bootstrap_variant: &str,
    stop_at: Option<&str>,
    skip_bootstrap: bool,
    learning_path: &Path,
    verbose: bool,
) {
    use swarm_engine_core::learn::daemon::{Processor, ProcessorConfig, ProcessorMode};
    use swarm_engine_core::learn::store::InMemoryEpisodeStore;
    use swarm_engine_core::learn::{LearningPhase, LearningStore, SessionGroup};
    use swarm_engine_eval::scenario::EvalScenario;
    println!("=== Learn Auto: 3-Step Workflow ===");
    println!("Scenario: {}", scenario_path.display());
    println!("Bootstrap runs: {}", bootstrap_runs);
    println!("Release runs: {}", release_runs);
    println!("Bootstrap variant: {}", bootstrap_variant);
    if let Some(stop) = stop_at {
        println!("Stop at: {}", stop);
    }
    println!("Learning data: {}", learning_path.display());
    println!();
    if !scenario_path.exists() {
        eprintln!("Scenario file not found: {}", scenario_path.display());
        std::process::exit(1);
    }
    let content = match std::fs::read_to_string(scenario_path) {
        Ok(c) => c,
        Err(e) => {
            eprintln!("Failed to read scenario file: {}", e);
            std::process::exit(1);
        }
    };
    let base_scenario: EvalScenario = match toml::from_str(&content) {
        Ok(s) => s,
        Err(e) => {
            eprintln!("Failed to parse scenario TOML: {}", e);
            std::process::exit(1);
        }
    };
    let variant_names = base_scenario.variant_names();
    if !variant_names.contains(&bootstrap_variant) {
        eprintln!("Variant '{}' not found in scenario.", bootstrap_variant);
        eprintln!("Available variants: {:?}", variant_names);
        std::process::exit(1);
    }
    let rt = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
    let handle = rt.handle().clone();
    let store = match LearningStore::new(learning_path) {
        Ok(s) => s,
        Err(e) => {
            eprintln!("Failed to create LearningStore: {}", e);
            std::process::exit(1);
        }
    };
    let scenario_key = base_scenario.meta.id.learning_key();
    // Captured for the final report. BUG FIX: the report previously
    // computed `bootstrap_runs * 100.0 / bootstrap_runs`, which always
    // printed 100%; we now record the actual bootstrap success rate.
    let mut bootstrap_success_rate: Option<f64> = None;
    if !skip_bootstrap {
        println!("╔══════════════════════════════════════════════════════════════╗");
        println!(
            "║ [Step 1/3] Bootstrap (variant: {}) ",
            bootstrap_variant
        );
        println!("╚══════════════════════════════════════════════════════════════╝");
        println!();
        let bootstrap_scenario = base_scenario.with_variant(bootstrap_variant).unwrap();
        let mut bootstrap_group =
            SessionGroup::new(LearningPhase::Bootstrap, &scenario_key, bootstrap_runs)
                .with_variant(bootstrap_variant);
        // Bootstrap seeds start at 42 — reproducible, and distinct from
        // the release phase (whose seeds start at 1000).
        run_learning_sessions(
            &bootstrap_scenario,
            &mut bootstrap_group,
            bootstrap_runs,
            42,
            &handle,
            verbose,
            learning_path,
        );
        bootstrap_group.mark_completed();
        bootstrap_success_rate = Some(bootstrap_group.success_rate());
        println!();
        println!(
            " Summary: {}/{} success ({:.0}%)",
            bootstrap_group.metadata.success_count,
            bootstrap_runs,
            bootstrap_group.success_rate() * 100.0
        );
        println!();
        if stop_at == Some("bootstrap") {
            println!("Stopped at bootstrap (--stop-at=bootstrap)");
            return;
        }
    } else {
        println!("Skipping bootstrap phase (--skip-bootstrap)");
        println!();
    }
    println!("╔══════════════════════════════════════════════════════════════╗");
    println!("║ [Step 2/3] Offline Learning ║");
    println!("╚══════════════════════════════════════════════════════════════╝");
    println!();
    let config = ProcessorConfig::new(&scenario_key)
        .mode(ProcessorMode::OfflineOnly)
        .max_sessions(bootstrap_runs);
    let processor = Processor::new(config).with_learning_store(store.clone());
    let episode_store = InMemoryEpisodeStore::new();
    println!(" Analyzing {} sessions...", bootstrap_runs);
    match rt.block_on(processor.run(&episode_store)) {
        Ok(result) => {
            if let Some(model) = result.offline_model() {
                println!(" \x1b[32m✓ Offline model generated\x1b[0m");
                println!(
                    " Parameters: ucb1_c={:.3}, learning_weight={:.3}",
                    model.parameters.ucb1_c, model.parameters.learning_weight
                );
                println!(
                    " Strategy: {} (maturity={})",
                    model.strategy_config.initial_strategy,
                    model.strategy_config.maturity_threshold
                );
                if let Some(ref action_order) = model.action_order {
                    println!(" Action order: discover={:?}", action_order.discover);
                    println!(
                        " not_discover={:?}",
                        action_order.not_discover
                    );
                }
            } else {
                eprintln!(" \x1b[31m✗ No offline model generated\x1b[0m");
            }
        }
        Err(e) => {
            eprintln!(" \x1b[31m✗ Processor failed: {}\x1b[0m", e);
        }
    }
    println!();
    println!("╔══════════════════════════════════════════════════════════════╗");
    println!("║ [Step 3/3] Release (no graph, using learned model) ║");
    println!("╚══════════════════════════════════════════════════════════════╝");
    println!();
    let release_scenario = base_scenario.clone();
    let mut release_group = SessionGroup::new(LearningPhase::Release, &scenario_key, release_runs);
    run_learning_sessions(
        &release_scenario,
        &mut release_group,
        release_runs,
        1000,
        &handle,
        verbose,
        learning_path,
    );
    release_group.mark_completed();
    println!();
    println!(
        " Summary: {}/{} success ({:.0}%)",
        release_group.metadata.success_count,
        release_runs,
        release_group.success_rate() * 100.0
    );
    println!();
    println!("╔══════════════════════════════════════════════════════════════╗");
    println!("║ Final Report ║");
    println!("╚══════════════════════════════════════════════════════════════╝");
    println!();
    if let Some(rate) = bootstrap_success_rate {
        println!(
            " Bootstrap: {:.0}% success",
            rate * 100.0
        );
    }
    println!(
        " Release: {:.0}% success",
        release_group.success_rate() * 100.0
    );
    println!();
    println!(" Learning data saved to: {}", learning_path.display());
}

/// Execute `runs` evaluation runs of `scenario`, recording each outcome
/// in `group`. Seeds are `seed_base + i` for reproducibility; session ids
/// are the current Unix timestamp in seconds ("error" on runner failure).
/// Factored out of cmd_learn_auto, where the bootstrap and release loops
/// were duplicated.
fn run_learning_sessions(
    scenario: &swarm_engine_eval::scenario::EvalScenario,
    group: &mut swarm_engine_core::learn::SessionGroup,
    runs: usize,
    seed_base: u64,
    handle: &tokio::runtime::Handle,
    verbose: bool,
    learning_path: &Path,
) {
    for i in 0..runs {
        let seed = seed_base + i as u64;
        print!(" Run {}/{}: ", i + 1, runs);
        std::io::Write::flush(&mut std::io::stdout()).ok();
        let runner = build_eval_runner(
            scenario.clone(),
            handle.clone(),
            1,
            seed,
            verbose,
            Some(learning_path),
            None,
        );
        match runner.run() {
            Ok(report) => {
                // A run counts as success only when every session in the
                // report succeeded.
                let success = report.aggregated.success_rate >= 1.0;
                let ticks = report.aggregated.statistics.total_ticks.mean as u64;
                group.add_session(
                    swarm_engine_core::learn::SessionId(format!(
                        "{}",
                        std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap()
                            .as_secs()
                    )),
                    success,
                );
                if success {
                    println!("\x1b[32m✓ Success ({} ticks)\x1b[0m", ticks);
                } else {
                    println!("\x1b[31m✗ Failure ({} ticks)\x1b[0m", ticks);
                }
            }
            Err(e) => {
                println!("\x1b[31m✗ Error: {}\x1b[0m", e);
                group.add_session(
                    swarm_engine_core::learn::SessionId("error".to_string()),
                    false,
                );
            }
        }
    }
}
/// Build an `EvalRunner` for one evaluation, wiring the LLM decider,
/// batch manager, and adaptive operator provider for the scenario's
/// configured provider (Ollama or llama.cpp server). Exits the process
/// for unsupported providers.
///
/// When `learning_path` is given, a previously saved offline model (if
/// any) tunes the adaptive provider and the UCB1 exploration constant,
/// and the runner is attached to the learning store.
fn build_eval_runner(
    scenario: swarm_engine_eval::scenario::EvalScenario,
    handle: tokio::runtime::Handle,
    runs: usize,
    seed: u64,
    verbose: bool,
    learning_path: Option<&Path>,
    trace_subscriber: Option<std::sync::Arc<dyn swarm_engine_core::events::TraceSubscriber>>,
) -> swarm_engine_eval::prelude::EvalRunner {
    use swarm_engine_core::agent::{DefaultBatchManagerAgent, ManagerId};
    use swarm_engine_core::exploration::{AdaptiveLlmOperatorProvider, ReviewPolicy};
    use swarm_engine_eval::prelude::EvalRunner;
    use swarm_engine_eval::scenario::LlmProvider;
    use swarm_engine_llm::{
        create_llm_invoker, ChatTemplate, LlamaCppServerConfig, LlamaCppServerDecider,
        LlmStrategyAdvisor, OllamaDecider,
    };
    let runner = match scenario.llm.provider {
        LlmProvider::Ollama => {
            let llm_config = scenario
                .llm
                .to_ollama_config(scenario.batch_processor.max_concurrency);
            let llm_config_for_advisor = llm_config.clone();
            let handle_for_advisor = handle.clone();
            // Offline-model lookup shared with the llama.cpp branch
            // (previously duplicated verbatim in both).
            let (adaptive_provider, ucb1_c) =
                adaptive_settings(learning_path, &scenario.meta.id.learning_key());
            EvalRunner::new(scenario, handle.clone())
                .with_runs(runs)
                .with_seed(seed)
                .with_verbose(verbose)
                .with_exploration(true)
                .with_manager_factory(|| Box::new(DefaultBatchManagerAgent::new(ManagerId(0))))
                .with_batch_invoker_factory(move || {
                    let decider = OllamaDecider::new(llm_config.clone());
                    Box::new(create_llm_invoker(decider, handle.clone()))
                })
                .with_operator_provider_factory(move || {
                    let decider =
                        std::sync::Arc::new(OllamaDecider::new(llm_config_for_advisor.clone()));
                    let advisor = LlmStrategyAdvisor::new(decider, handle_for_advisor.clone());
                    Box::new(
                        AdaptiveLlmOperatorProvider::new(Box::new(advisor))
                            .with_adaptive(adaptive_provider.clone())
                            .with_ucb1_c(ucb1_c)
                            .with_policy(ReviewPolicy::default()),
                    )
                })
        }
        LlmProvider::LlamaCppServer => {
            let endpoint = scenario
                .llm
                .endpoint
                .clone()
                .unwrap_or_else(|| "http://localhost:8080".to_string());
            let server_config = LlamaCppServerConfig::new(endpoint)
                .with_model_name(&scenario.llm.model)
                .with_max_tokens(scenario.llm.max_tokens.unwrap_or(256))
                .with_temperature(scenario.llm.temperature)
                .with_chat_template(ChatTemplate::Lfm2);
            let decider = LlamaCppServerDecider::new(server_config.clone())
                .expect("Failed to create LlamaCppServerDecider");
            let server_config_for_advisor = server_config;
            let handle_for_advisor = handle.clone();
            let (adaptive_provider, ucb1_c) =
                adaptive_settings(learning_path, &scenario.meta.id.learning_key());
            EvalRunner::new(scenario, handle.clone())
                .with_runs(runs)
                .with_seed(seed)
                .with_verbose(verbose)
                .with_exploration(true)
                .with_manager_factory(|| Box::new(DefaultBatchManagerAgent::new(ManagerId(0))))
                .with_batch_invoker_factory(move || {
                    Box::new(create_llm_invoker(decider.clone(), handle.clone()))
                })
                .with_operator_provider_factory(move || {
                    let decider = std::sync::Arc::new(
                        LlamaCppServerDecider::new(server_config_for_advisor.clone())
                            .expect("Failed to create LlamaCppServerDecider for advisor"),
                    );
                    let advisor = LlmStrategyAdvisor::new(decider, handle_for_advisor.clone());
                    Box::new(
                        AdaptiveLlmOperatorProvider::new(Box::new(advisor))
                            .with_adaptive(adaptive_provider.clone())
                            .with_ucb1_c(ucb1_c)
                            .with_policy(ReviewPolicy::default()),
                    )
                })
        }
        other => {
            eprintln!(
                "Error: Provider {:?} is not supported for learn auto",
                other
            );
            std::process::exit(1);
        }
    };
    // Optional attachments, applied after provider-specific wiring.
    let runner = if let Some(path) = learning_path {
        runner.with_learning_store(path)
    } else {
        runner
    };
    if let Some(subscriber) = trace_subscriber {
        runner.with_trace_subscriber(subscriber)
    } else {
        runner
    }
}

/// Load the saved offline model for `scenario_key` (when a learning path
/// is provided) and derive the adaptive operator provider plus the UCB1
/// exploration constant from it. Falls back to the default provider and
/// sqrt(2) when no model can be loaded.
fn adaptive_settings(
    learning_path: Option<&Path>,
    scenario_key: &str,
) -> (swarm_engine_core::exploration::AdaptiveOperatorProvider, f64) {
    use swarm_engine_core::exploration::AdaptiveOperatorProvider;
    use swarm_engine_core::learn::LearningStore;
    let offline_model = learning_path.and_then(|path| {
        LearningStore::new(path)
            .ok()
            .and_then(|store| store.load_offline_model(scenario_key).ok())
    });
    let provider = if let Some(ref model) = offline_model {
        AdaptiveOperatorProvider::default()
            .with_ucb1_c(model.parameters.ucb1_c)
            .with_maturity_threshold(model.strategy_config.maturity_threshold)
            .with_error_rate_threshold(model.strategy_config.error_rate_threshold)
    } else {
        AdaptiveOperatorProvider::default()
    };
    let ucb1_c = offline_model
        .as_ref()
        .map(|m| m.parameters.ucb1_c)
        .unwrap_or(std::f64::consts::SQRT_2);
    (provider, ucb1_c)
}