use serde::{Deserialize, Serialize};
/// Top-level pilot configuration, aggregating the mode, budget, intervention,
/// prefilter and prune sub-configurations.
///
/// Construct it with [`PilotConfig::default`] (balanced), [`PilotConfig::with_mode`],
/// or one of the presets [`PilotConfig::high_quality`], [`PilotConfig::low_cost`]
/// and [`PilotConfig::algorithm_only`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PilotConfig {
/// Operating mode; [`PilotMode::AlgorithmOnly`] is the only mode for which
/// [`PilotMode::uses_llm`] returns `false`.
pub mode: PilotMode,
/// Token and call limits (see [`BudgetConfig`]).
pub budget: BudgetConfig,
/// Thresholds consulted when deciding whether to intervene (see [`InterventionConfig`]).
pub intervention: InterventionConfig,
/// `true` in the default and `high_quality` presets, `false` in `low_cost`.
/// NOTE(review): presumably enables guidance at the start of a search — confirm with the consumer.
pub guide_at_start: bool,
/// `true` in every built-in preset.
/// NOTE(review): presumably enables guidance when backtracking — confirm with the consumer.
pub guide_at_backtrack: bool,
/// Optional prompt-template path; `None` in every built-in preset.
/// NOTE(review): presumably `None` means a built-in template is used — confirm.
pub prompt_template_path: Option<String>,
/// Candidate prefiltering settings (see [`PrefilterConfig`]).
pub prefilter: PrefilterConfig,
/// Candidate pruning settings (see [`PruneConfig`]).
pub prune: PruneConfig,
}
impl Default for PilotConfig {
    /// Balanced configuration.
    ///
    /// Uses [`PilotMode::Balanced`] and the `Default` of every sub-config. Guidance
    /// is enabled both at start and at backtrack, and no custom prompt template is set.
    fn default() -> Self {
        let guide = true;
        Self {
            mode: PilotMode::Balanced,
            guide_at_start: guide,
            guide_at_backtrack: guide,
            prompt_template_path: None,
            budget: Default::default(),
            intervention: Default::default(),
            prefilter: Default::default(),
            prune: Default::default(),
        }
    }
}
impl PilotConfig {
    /// Default configuration with only the [`PilotMode`] replaced by `mode`.
    pub fn with_mode(mode: PilotMode) -> Self {
        Self {
            mode,
            ..Self::default()
        }
    }

    /// Quality-first preset.
    ///
    /// Details:
    /// - Mode: [`PilotMode::Aggressive`].
    /// - Budget: 5000 tokens and 10 calls per query, 1000 tokens per call,
    ///   3 calls per level, with `hard_limit` off.
    /// - Intervention: fork threshold 2, score-gap 0.2, low-score 0.4,
    ///   up to 3 interventions per level.
    /// - Prefilter: above 20 candidates.
    /// - Prune: above 25 candidates with `min_keep` 5.
    pub fn high_quality() -> Self {
        let budget = BudgetConfig {
            max_tokens_per_query: 5_000,
            max_tokens_per_call: 1_000,
            max_calls_per_query: 10,
            max_calls_per_level: 3,
            hard_limit: false,
        };
        let intervention = InterventionConfig {
            fork_threshold: 2,
            score_gap_threshold: 0.2,
            low_score_threshold: 0.4,
            max_interventions_per_level: 3,
        };
        let prefilter = PrefilterConfig {
            enabled: true,
            threshold: 20,
            max_to_pilot: 20,
        };
        let prune = PruneConfig {
            enabled: true,
            threshold: 25,
            min_keep: 5,
        };
        Self {
            mode: PilotMode::Aggressive,
            budget,
            intervention,
            guide_at_start: true,
            guide_at_backtrack: true,
            prompt_template_path: None,
            prefilter,
            prune,
        }
    }

    /// Cost-saving preset.
    ///
    /// Details:
    /// - Mode: [`PilotMode::Conservative`].
    /// - Budget: 500 tokens and 2 calls per query, 200 tokens per call,
    ///   1 call per level, with `hard_limit` on.
    /// - Intervention: fork threshold 5, score-gap 0.1, low-score 0.2,
    ///   1 intervention per level.
    /// - Guidance: disabled at start, still enabled at backtrack.
    /// - Prefilter: above 8 candidates.
    /// - Prune: above 12 candidates with `min_keep` 2.
    pub fn low_cost() -> Self {
        let budget = BudgetConfig {
            max_tokens_per_query: 500,
            max_tokens_per_call: 200,
            max_calls_per_query: 2,
            max_calls_per_level: 1,
            hard_limit: true,
        };
        let intervention = InterventionConfig {
            fork_threshold: 5,
            score_gap_threshold: 0.1,
            low_score_threshold: 0.2,
            max_interventions_per_level: 1,
        };
        let prefilter = PrefilterConfig {
            enabled: true,
            threshold: 8,
            max_to_pilot: 8,
        };
        let prune = PruneConfig {
            enabled: true,
            threshold: 12,
            min_keep: 2,
        };
        Self {
            mode: PilotMode::Conservative,
            budget,
            intervention,
            guide_at_start: false,
            guide_at_backtrack: true,
            prompt_template_path: None,
            prefilter,
            prune,
        }
    }

    /// [`PilotMode::AlgorithmOnly`] preset with prefiltering and pruning disabled.
    ///
    /// Budget, intervention, guidance flags and prompt template come from
    /// [`PilotConfig::default`].
    pub fn algorithm_only() -> Self {
        let prefilter = PrefilterConfig {
            enabled: false,
            threshold: 15,
            max_to_pilot: 15,
        };
        let prune = PruneConfig {
            enabled: false,
            threshold: 20,
            min_keep: 3,
        };
        Self {
            mode: PilotMode::AlgorithmOnly,
            prefilter,
            prune,
            ..Self::default()
        }
    }
}
/// Operating mode of the pilot. `Balanced` is the `Default` variant.
///
/// Each mode maps to a multiplier via [`PilotMode::fork_threshold_multiplier`].
/// NOTE(review): presumably this scales `InterventionConfig::fork_threshold` — confirm at the call site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum PilotMode {
/// Uses the LLM; fork-threshold multiplier 0.5.
Aggressive,
/// Default mode; uses the LLM; fork-threshold multiplier 1.0.
#[default]
Balanced,
/// Uses the LLM; fork-threshold multiplier 2.0.
Conservative,
/// The only mode for which [`PilotMode::uses_llm`] returns `false`;
/// fork-threshold multiplier is `f32::MAX`.
AlgorithmOnly,
}
impl PilotMode {
pub fn uses_llm(&self) -> bool {
!matches!(self, PilotMode::AlgorithmOnly)
}
pub fn fork_threshold_multiplier(&self) -> f32 {
match self {
PilotMode::Aggressive => 0.5, PilotMode::Balanced => 1.0,
PilotMode::Conservative => 2.0, PilotMode::AlgorithmOnly => f32::MAX,
}
}
}
/// Token and call limits for the pilot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetConfig {
/// Per-query token budget. `is_within_budget` is `true` while usage is strictly
/// below it; `remaining_tokens` saturates at zero.
pub max_tokens_per_query: usize,
/// NOTE(review): presumably caps tokens for a single call — not read in the code shown here.
pub max_tokens_per_call: usize,
/// NOTE(review): presumably caps the number of calls per query — not read in the code shown here.
pub max_calls_per_query: usize,
/// NOTE(review): presumably caps calls per tree level — not read in the code shown here.
pub max_calls_per_level: usize,
/// `true` by default and in `low_cost`, `false` in `high_quality`.
/// NOTE(review): presumably selects strict vs. advisory enforcement — confirm with the consumer.
pub hard_limit: bool,
}
impl Default for BudgetConfig {
fn default() -> Self {
Self {
max_tokens_per_query: 2000,
max_tokens_per_call: 500,
max_calls_per_query: 5,
max_calls_per_level: 2,
hard_limit: true,
}
}
}
impl BudgetConfig {
    /// Returns `true` while at least one token of the per-query budget remains,
    /// i.e. while `used < max_tokens_per_query`.
    pub fn is_within_budget(&self, used: usize) -> bool {
        self.remaining_tokens(used) > 0
    }

    /// Tokens left in the per-query budget after `used` have been spent.
    /// The result is clamped at zero.
    pub fn remaining_tokens(&self, used: usize) -> usize {
        let cap = self.max_tokens_per_query;
        cap.saturating_sub(used)
    }
}
/// Thresholds that decide when the pilot should intervene.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterventionConfig {
/// `should_intervene_at_fork` is `true` when the candidate count is strictly greater than this.
pub fork_threshold: usize,
/// `scores_are_close` is `true` when the score spread (max - min) is strictly below this.
pub score_gap_threshold: f32,
/// `is_low_confidence` is `true` when the best score is strictly below this.
pub low_score_threshold: f32,
/// NOTE(review): presumably caps interventions per tree level — not read in the code shown here.
pub max_interventions_per_level: usize,
}
impl Default for InterventionConfig {
fn default() -> Self {
Self {
fork_threshold: 3,
score_gap_threshold: 0.15,
low_score_threshold: 0.3,
max_interventions_per_level: 2,
}
}
}
impl InterventionConfig {
    /// Returns `true` when a fork offers strictly more than `fork_threshold`
    /// candidates.
    pub fn should_intervene_at_fork(&self, candidate_count: usize) -> bool {
        candidate_count > self.fork_threshold
    }

    /// Returns `true` when the spread of `scores` (`max - min`) is strictly
    /// below `score_gap_threshold`.
    ///
    /// Fewer than two scores are never considered close. NaN entries are
    /// effectively skipped, because `f32::min`/`f32::max` return the non-NaN
    /// operand. If every entry is NaN, the spread is `-inf` and this returns
    /// `true`.
    pub fn scores_are_close(&self, scores: &[f32]) -> bool {
        if scores.len() < 2 {
            return false;
        }
        // Seed with the identities of min/max (+inf / -inf). Seeding with
        // 1.0 / 0.0 clamped the extremes into [0, 1], so [1.5, 1.6] reported a
        // 0.6 spread and negative scores were measured against 0.0.
        let (min_score, max_score) = scores
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &s| {
                (lo.min(s), hi.max(s))
            });
        (max_score - min_score) < self.score_gap_threshold
    }

    /// Returns `true` when `best_score` is strictly below `low_score_threshold`.
    pub fn is_low_confidence(&self, best_score: f32) -> bool {
        best_score < self.low_score_threshold
    }
}
/// Settings for prefiltering large candidate sets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefilterConfig {
/// `should_prefilter` is `true` when the candidate count is strictly greater than this.
pub threshold: usize,
/// NOTE(review): presumably the cap on candidates handed to the pilot after prefiltering — not read in the code shown here.
pub max_to_pilot: usize,
/// When `false`, `should_prefilter` always returns `false`.
pub enabled: bool,
}
impl Default for PrefilterConfig {
fn default() -> Self {
Self {
threshold: 15,
max_to_pilot: 15,
enabled: true,
}
}
}
impl PrefilterConfig {
    /// Returns `true` when prefiltering is enabled and `candidate_count`
    /// strictly exceeds `threshold`.
    pub fn should_prefilter(&self, candidate_count: usize) -> bool {
        if !self.enabled {
            return false;
        }
        candidate_count > self.threshold
    }
}
/// Settings for pruning large candidate sets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruneConfig {
/// When `false`, `should_prune` always returns `false`.
pub enabled: bool,
/// `should_prune` is `true` when the candidate count is strictly greater than this.
pub threshold: usize,
/// NOTE(review): presumably the minimum number of candidates kept after pruning — not read in the code shown here.
pub min_keep: usize,
}
impl Default for PruneConfig {
fn default() -> Self {
Self {
enabled: true,
threshold: 20,
min_keep: 3,
}
}
}
impl PruneConfig {
    /// Returns `true` when pruning is enabled and `candidate_count` strictly
    /// exceeds `threshold`.
    pub fn should_prune(&self, candidate_count: usize) -> bool {
        let over_threshold = self.threshold < candidate_count;
        over_threshold && self.enabled
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pilot_mode_uses_llm() {
        let llm_modes = [
            PilotMode::Aggressive,
            PilotMode::Balanced,
            PilotMode::Conservative,
        ];
        for mode in llm_modes {
            assert!(mode.uses_llm(), "{mode:?} should use the LLM");
        }
        assert!(!PilotMode::AlgorithmOnly.uses_llm());
    }

    #[test]
    fn test_budget_config() {
        let budget = BudgetConfig::default();
        assert!(budget.is_within_budget(1000));
        assert!(!budget.is_within_budget(3000));
        assert_eq!(budget.remaining_tokens(1500), 500);
    }

    #[test]
    fn test_intervention_config() {
        let intervention = InterventionConfig::default();

        // Fork detection.
        assert!(!intervention.should_intervene_at_fork(2));
        assert!(intervention.should_intervene_at_fork(4));

        // Score spread.
        assert!(intervention.scores_are_close(&[0.5, 0.55, 0.52]));
        assert!(!intervention.scores_are_close(&[0.3, 0.8]));

        // Confidence.
        assert!(intervention.is_low_confidence(0.2));
        assert!(!intervention.is_low_confidence(0.5));
    }

    #[test]
    fn test_pilot_config_presets() {
        let expectations = [
            (PilotConfig::high_quality(), PilotMode::Aggressive, 20),
            (PilotConfig::low_cost(), PilotMode::Conservative, 8),
        ];
        for (config, mode, threshold) in expectations {
            assert_eq!(config.mode, mode);
            assert!(config.prefilter.enabled);
            assert_eq!(config.prefilter.threshold, threshold);
        }

        let algo = PilotConfig::algorithm_only();
        assert_eq!(algo.mode, PilotMode::AlgorithmOnly);
        assert!(!algo.prefilter.enabled);
    }

    #[test]
    fn test_prefilter_config_default() {
        let PrefilterConfig {
            threshold,
            max_to_pilot,
            enabled,
        } = PrefilterConfig::default();
        assert!(enabled);
        assert_eq!(threshold, 15);
        assert_eq!(max_to_pilot, 15);
    }

    #[test]
    fn test_prefilter_should_prefilter() {
        let cfg = PrefilterConfig::default();
        for (count, expected) in [(15, false), (10, false), (16, true)] {
            assert_eq!(cfg.should_prefilter(count), expected, "count = {count}");
        }

        let disabled = PrefilterConfig {
            enabled: false,
            ..Default::default()
        };
        assert!(!disabled.should_prefilter(100));
    }

    #[test]
    fn test_prune_config_default() {
        let PruneConfig {
            enabled,
            threshold,
            min_keep,
        } = PruneConfig::default();
        assert!(enabled);
        assert_eq!(threshold, 20);
        assert_eq!(min_keep, 3);
    }

    #[test]
    fn test_prune_should_prune() {
        let cfg = PruneConfig::default();
        for (count, expected) in [(20, false), (15, false), (21, true)] {
            assert_eq!(cfg.should_prune(count), expected, "count = {count}");
        }

        let disabled = PruneConfig {
            enabled: false,
            ..Default::default()
        };
        assert!(!disabled.should_prune(100));
    }

    #[test]
    fn test_pilot_config_presets_prune() {
        let expectations = [
            (PilotConfig::high_quality(), 25),
            (PilotConfig::low_cost(), 12),
        ];
        for (config, threshold) in expectations {
            assert!(config.prune.enabled);
            assert_eq!(config.prune.threshold, threshold);
        }

        assert!(!PilotConfig::algorithm_only().prune.enabled);
    }
}