use crate::optimizer::cost_model::{
CostEstimate, CostModel, IndexFamily, IndexParameters, WorkloadProfile,
};
use crate::optimizer::query_stats::{QueryObservation, QueryStats};
use serde::{Deserialize, Serialize};
use std::collections::BTreeSet;
/// Tunables for [`OptimizerDispatcher`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DispatcherConfig {
/// Observed recall strictly below this value makes
/// `OptimizerDispatcher::should_fallback` return `true` (provided the plan has a fallback).
/// Default: `0.85`.
pub recall_fallback_threshold: f32,
/// Intended cap on the number of fallback families per plan. Default: `1`.
/// NOTE(review): confirm this cap is enforced wherever fallback chains are built or consumed.
pub max_fallbacks: usize,
/// Number of observations recorded via `OptimizerDispatcher::record_observation`
/// between automatic cost-model weight refreshes. Default: `64`; `0` refreshes on
/// every observation.
pub weight_refresh_interval: u64,
/// Families eligible for dispatch. An empty set (the default) means every family
/// returned by `IndexFamily::all()` is eligible.
pub enabled_families: BTreeSet<IndexFamily>,
}
impl Default for DispatcherConfig {
    /// Default tunables: fallback threshold `0.85`, `max_fallbacks` of `1`, a weight
    /// refresh every 64 observations, and an empty family set (which the dispatcher
    /// treats as "every family enabled").
    fn default() -> Self {
        let enabled_families = BTreeSet::new();
        Self {
            enabled_families,
            weight_refresh_interval: 64,
            max_fallbacks: 1,
            recall_fallback_threshold: 0.85,
        }
    }
}
/// Result of [`OptimizerDispatcher::pick_plan`]: the primary index family to use
/// plus an ordered chain of alternative estimates.
#[derive(Debug, Clone, PartialEq)]
pub struct DispatchPlan {
/// Family chosen to serve the workload.
pub primary: IndexFamily,
/// Cost-model estimate of the primary family's cost.
pub primary_cost: f64,
/// Cost-model estimate of the primary family's recall.
pub primary_recall: f32,
/// Alternative estimates (never including the primary), in ascending estimated cost order.
pub fallbacks: Vec<CostEstimate>,
/// Copy of the workload this plan was computed for.
pub workload: WorkloadProfile,
}
impl DispatchPlan {
    /// Reports whether at least one fallback estimate is attached to this plan.
    pub fn has_fallback(&self) -> bool {
        self.fallbacks.first().is_some()
    }

    /// Family of the `idx`-th fallback (0-based, in `fallbacks` order), or `None`
    /// when `idx` is past the end of the chain.
    pub fn fallback_at(&self, idx: usize) -> Option<IndexFamily> {
        let estimate = self.fallbacks.get(idx)?;
        Some(estimate.family)
    }
}
/// Errors produced by [`OptimizerDispatcher::pick_plan`].
#[derive(Debug, thiserror::Error)]
pub enum DispatchError {
/// No candidate family's estimated recall reaches the requested target.
///
/// NOTE(review): `pick_plan` in this file never constructs this variant; when the
/// target is unreachable it logs a warning and returns a best-effort plan instead.
#[error(
"no index family meets requested_recall={requested:.3}; best estimate was {best_recall:.3} \
from {best_family:?}"
)]
NoFamilyMeetsRecall {
/// Recall target that was requested.
requested: f32,
/// Best estimated recall that was available.
best_recall: f32,
/// Family that produced `best_recall`.
best_family: IndexFamily,
},
/// The resolved set of enabled families was empty, so there was nothing to choose from.
#[error("no index families enabled in dispatcher configuration")]
NoFamiliesEnabled,
}
/// Chooses an index family per workload using a [`CostModel`], and feeds recorded
/// [`QueryObservation`]s back into the cost model's weights via [`QueryStats`].
pub struct OptimizerDispatcher {
/// Estimates cost and recall for each index family.
cost_model: CostModel,
/// Accumulated query observations; source of recomputed cost-model weights.
stats: QueryStats,
/// Dispatcher tunables.
config: DispatcherConfig,
/// Observations recorded through `record_observation` since the last weight
/// refresh; reset to 0 whenever weights are refreshed.
observations_since_refresh: u64,
}
impl Default for OptimizerDispatcher {
    /// Dispatcher built from the default cost model, default statistics store and
    /// [`DispatcherConfig::default`].
    fn default() -> Self {
        let cost_model = CostModel::default();
        let stats = QueryStats::default();
        let config = DispatcherConfig::default();
        Self::new(cost_model, stats, config)
    }
}
impl OptimizerDispatcher {
pub fn new(cost_model: CostModel, stats: QueryStats, config: DispatcherConfig) -> Self {
Self {
cost_model,
stats,
config,
observations_since_refresh: 0,
}
}
pub fn cost_model(&self) -> &CostModel {
&self.cost_model
}
pub fn stats(&self) -> &QueryStats {
&self.stats
}
pub fn config(&self) -> &DispatcherConfig {
&self.config
}
pub fn cost_model_mut(&mut self) -> &mut CostModel {
&mut self.cost_model
}
pub fn stats_mut(&mut self) -> &mut QueryStats {
&mut self.stats
}
pub fn pick_plan(&self, workload: &WorkloadProfile) -> Result<DispatchPlan, DispatchError> {
let enabled = self.enabled_families();
if enabled.is_empty() {
return Err(DispatchError::NoFamiliesEnabled);
}
let mut estimates: Vec<CostEstimate> = enabled
.iter()
.map(|fam| self.cost_model.estimate(*fam, workload))
.collect();
estimates.sort_by(|a, b| {
a.cost
.partial_cmp(&b.cost)
.unwrap_or(std::cmp::Ordering::Equal)
});
let recall_target = workload.requested_recall;
let (meets, below): (Vec<_>, Vec<_>) = estimates
.iter()
.cloned()
.partition(|e| e.recall >= recall_target);
let primary_estimate = if let Some(first) = meets.first() {
first.clone()
} else {
let best = below
.iter()
.max_by(|a, b| {
a.recall
.partial_cmp(&b.recall)
.unwrap_or(std::cmp::Ordering::Equal)
})
.ok_or(DispatchError::NoFamiliesEnabled)?
.clone();
tracing::warn!(
"OptimizerDispatcher: no family meets requested_recall={:.3}; \
best is {:?} with recall={:.3}",
recall_target,
best.family,
best.recall
);
best
};
let fallbacks: Vec<CostEstimate> = if !meets.is_empty() {
meets
.into_iter()
.filter(|e| e.family != primary_estimate.family)
.collect()
} else {
estimates
.into_iter()
.filter(|e| e.family != primary_estimate.family)
.collect()
};
Ok(DispatchPlan {
primary: primary_estimate.family,
primary_cost: primary_estimate.cost,
primary_recall: primary_estimate.recall,
fallbacks,
workload: workload.clone(),
})
}
pub fn should_fallback(&self, plan: &DispatchPlan, observed_recall: f32) -> bool {
plan.has_fallback() && observed_recall < self.config.recall_fallback_threshold
}
pub fn record_observation(&mut self, observation: QueryObservation) -> bool {
self.stats.record(observation);
self.observations_since_refresh += 1;
if self.observations_since_refresh >= self.config.weight_refresh_interval {
let new_weights = self.stats.recommended_weights(self.cost_model.weights());
*self.cost_model.weights_mut() = new_weights;
self.observations_since_refresh = 0;
true
} else {
false
}
}
pub fn force_refresh_weights(&mut self) {
let new_weights = self.stats.recommended_weights(self.cost_model.weights());
*self.cost_model.weights_mut() = new_weights;
self.observations_since_refresh = 0;
}
fn enabled_families(&self) -> Vec<IndexFamily> {
let universe = IndexFamily::all();
if self.config.enabled_families.is_empty() {
universe.to_vec()
} else {
universe
.into_iter()
.filter(|f| self.config.enabled_families.contains(f))
.collect()
}
}
}
pub fn dispatcher_with_families(families: &[IndexFamily]) -> OptimizerDispatcher {
let cfg = DispatcherConfig {
enabled_families: families.iter().copied().collect(),
..Default::default()
};
OptimizerDispatcher::new(CostModel::default(), QueryStats::default(), cfg)
}
/// Builds a dispatcher whose cost model uses the given index `parameters` (with
/// default weights); statistics and configuration use their defaults.
pub fn dispatcher_with_parameters(parameters: IndexParameters) -> OptimizerDispatcher {
    OptimizerDispatcher::new(
        CostModel::new(parameters, Default::default()),
        QueryStats::default(),
        DispatcherConfig::default(),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for `WorkloadProfile::new(n, dim, recall)`.
    fn workload(n: usize, dim: usize, recall: f32) -> WorkloadProfile {
        WorkloadProfile::new(n, dim, recall)
    }

    #[test]
    fn dispatcher_picks_lowest_cost_meeting_recall() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(100_000, 128, 0.9))
            .expect("plan must exist");
        assert!(
            plan.primary_recall >= 0.9,
            "primary recall must meet target"
        );
        // Fallbacks are other qualifying families, so none may be cheaper than the primary.
        assert!(
            plan.fallbacks.iter().all(|f| f.cost >= plan.primary_cost),
            "primary must be the cheapest family meeting the target"
        );
    }

    #[test]
    fn dispatcher_provides_fallback_chain() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(100_000, 128, 0.85))
            .expect("plan must exist");
        assert!(plan.has_fallback(), "fallback chain should be non-empty");
    }

    #[test]
    fn dispatcher_handles_unmet_recall_with_warning() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(10_000, 128, 0.999))
            .expect("dispatcher returns best-effort plan");
        assert!(plan.primary_recall < 0.999);
        assert!(
            plan.fallbacks.iter().all(|f| f.recall <= plan.primary_recall),
            "best-effort primary must have the highest estimated recall"
        );
    }

    #[test]
    fn enabled_families_filter_respected() {
        let dispatcher = dispatcher_with_families(&[IndexFamily::Lsh, IndexFamily::Pq]);
        let plan = dispatcher
            .pick_plan(&workload(10_000, 128, 0.7))
            .expect("plan must exist");
        assert!(matches!(plan.primary, IndexFamily::Lsh | IndexFamily::Pq));
    }

    #[test]
    fn emptied_enabled_set_means_all_families() {
        let mut dispatcher = OptimizerDispatcher::default();
        // Populate and then empty the set: an emptied set must behave like a
        // never-populated one, i.e. "every family enabled", not "no families".
        dispatcher.config.enabled_families.insert(IndexFamily::Hnsw);
        dispatcher
            .config
            .enabled_families
            .remove(&IndexFamily::Hnsw);
        assert!(dispatcher.config.enabled_families.is_empty());
        let plan = dispatcher.pick_plan(&workload(1_000, 8, 0.5));
        assert!(plan.is_ok(), "an empty enabled set must not yield NoFamiliesEnabled");
    }

    #[test]
    fn should_fallback_triggers_when_observed_below_threshold() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(100_000, 128, 0.85))
            .expect("plan");
        assert!(dispatcher.should_fallback(&plan, 0.5));
        assert!(!dispatcher.should_fallback(&plan, 0.95));
    }

    #[test]
    fn record_observation_refreshes_weights_at_interval() {
        let mut dispatcher = OptimizerDispatcher::default();
        dispatcher.config.weight_refresh_interval = 3;
        for _ in 0..2 {
            let refreshed = dispatcher.record_observation(QueryObservation::new(
                IndexFamily::Hnsw,
                true,
                100.0,
                Some(0.92),
                50.0,
            ));
            assert!(!refreshed);
        }
        let refreshed = dispatcher.record_observation(QueryObservation::new(
            IndexFamily::Hnsw,
            true,
            100.0,
            Some(0.92),
            50.0,
        ));
        assert!(refreshed, "refresh should trigger on 3rd observation");
        let w = dispatcher.cost_model().weights().get(IndexFamily::Hnsw);
        assert!((w - 2.0).abs() < 1e-6);
    }

    #[test]
    fn force_refresh_weights_immediately() {
        let mut dispatcher = OptimizerDispatcher::default();
        dispatcher.stats.record(QueryObservation::new(
            IndexFamily::Pq,
            true,
            300.0,
            None,
            150.0,
        ));
        dispatcher.force_refresh_weights();
        let w = dispatcher.cost_model().weights().get(IndexFamily::Pq);
        assert!((w - 2.0).abs() < 1e-6);
    }

    #[test]
    fn dispatcher_with_parameters_uses_overrides() {
        let params = IndexParameters {
            hnsw_ef: 200,
            ..Default::default()
        };
        let dispatcher = dispatcher_with_parameters(params);
        let cost_high = dispatcher
            .cost_model()
            .estimate(IndexFamily::Hnsw, &workload(100_000, 128, 0.9));
        let dispatcher_default = OptimizerDispatcher::default();
        let cost_low = dispatcher_default
            .cost_model()
            .estimate(IndexFamily::Hnsw, &workload(100_000, 128, 0.9));
        assert!(cost_high.cost > cost_low.cost);
    }
}