1use crate::optimizer::cost_model::{
24 CostEstimate, CostModel, IndexFamily, IndexParameters, WorkloadProfile,
25};
26use crate::optimizer::query_stats::{QueryObservation, QueryStats};
27use serde::{Deserialize, Serialize};
28use std::collections::BTreeSet;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct DispatcherConfig {
33 pub recall_fallback_threshold: f32,
36 pub max_fallbacks: usize,
38 pub weight_refresh_interval: u64,
41 pub enabled_families: BTreeSet<IndexFamily>,
44}
45
46impl Default for DispatcherConfig {
47 fn default() -> Self {
48 Self {
49 recall_fallback_threshold: 0.85,
50 max_fallbacks: 1,
51 weight_refresh_interval: 64,
52 enabled_families: BTreeSet::new(), }
54 }
55}
56
57#[derive(Debug, Clone, PartialEq)]
59pub struct DispatchPlan {
60 pub primary: IndexFamily,
62 pub primary_cost: f64,
64 pub primary_recall: f32,
66 pub fallbacks: Vec<CostEstimate>,
68 pub workload: WorkloadProfile,
70}
71
72impl DispatchPlan {
73 pub fn has_fallback(&self) -> bool {
75 !self.fallbacks.is_empty()
76 }
77
78 pub fn fallback_at(&self, idx: usize) -> Option<IndexFamily> {
80 self.fallbacks.get(idx).map(|e| e.family)
81 }
82}
83
84#[derive(Debug, thiserror::Error)]
86pub enum DispatchError {
87 #[error(
90 "no index family meets requested_recall={requested:.3}; best estimate was {best_recall:.3} \
91 from {best_family:?}"
92 )]
93 NoFamilyMeetsRecall {
94 requested: f32,
96 best_recall: f32,
98 best_family: IndexFamily,
100 },
101 #[error("no index families enabled in dispatcher configuration")]
104 NoFamiliesEnabled,
105}
106
107pub struct OptimizerDispatcher {
109 cost_model: CostModel,
110 stats: QueryStats,
111 config: DispatcherConfig,
112 observations_since_refresh: u64,
113}
114
115impl Default for OptimizerDispatcher {
116 fn default() -> Self {
117 Self::new(
118 CostModel::default(),
119 QueryStats::default(),
120 DispatcherConfig::default(),
121 )
122 }
123}
124
125impl OptimizerDispatcher {
126 pub fn new(cost_model: CostModel, stats: QueryStats, config: DispatcherConfig) -> Self {
128 Self {
129 cost_model,
130 stats,
131 config,
132 observations_since_refresh: 0,
133 }
134 }
135
136 pub fn cost_model(&self) -> &CostModel {
138 &self.cost_model
139 }
140
141 pub fn stats(&self) -> &QueryStats {
143 &self.stats
144 }
145
146 pub fn config(&self) -> &DispatcherConfig {
148 &self.config
149 }
150
151 pub fn cost_model_mut(&mut self) -> &mut CostModel {
153 &mut self.cost_model
154 }
155
156 pub fn stats_mut(&mut self) -> &mut QueryStats {
158 &mut self.stats
159 }
160
161 pub fn pick_plan(&self, workload: &WorkloadProfile) -> Result<DispatchPlan, DispatchError> {
163 let enabled = self.enabled_families();
164 if enabled.is_empty() {
165 return Err(DispatchError::NoFamiliesEnabled);
166 }
167
168 let mut estimates: Vec<CostEstimate> = enabled
170 .iter()
171 .map(|fam| self.cost_model.estimate(*fam, workload))
172 .collect();
173
174 estimates.sort_by(|a, b| {
176 a.cost
177 .partial_cmp(&b.cost)
178 .unwrap_or(std::cmp::Ordering::Equal)
179 });
180
181 let recall_target = workload.requested_recall;
183 let (meets, below): (Vec<_>, Vec<_>) = estimates
184 .iter()
185 .cloned()
186 .partition(|e| e.recall >= recall_target);
187
188 let primary_estimate = if let Some(first) = meets.first() {
189 first.clone()
190 } else {
191 let best = below
195 .iter()
196 .max_by(|a, b| {
197 a.recall
198 .partial_cmp(&b.recall)
199 .unwrap_or(std::cmp::Ordering::Equal)
200 })
201 .ok_or(DispatchError::NoFamiliesEnabled)?
202 .clone();
203 tracing::warn!(
204 "OptimizerDispatcher: no family meets requested_recall={:.3}; \
205 best is {:?} with recall={:.3}",
206 recall_target,
207 best.family,
208 best.recall
209 );
210 best
211 };
212
213 let fallbacks: Vec<CostEstimate> = if !meets.is_empty() {
217 meets
218 .into_iter()
219 .filter(|e| e.family != primary_estimate.family)
220 .collect()
221 } else {
222 estimates
223 .into_iter()
224 .filter(|e| e.family != primary_estimate.family)
225 .collect()
226 };
227
228 Ok(DispatchPlan {
229 primary: primary_estimate.family,
230 primary_cost: primary_estimate.cost,
231 primary_recall: primary_estimate.recall,
232 fallbacks,
233 workload: workload.clone(),
234 })
235 }
236
237 pub fn should_fallback(&self, plan: &DispatchPlan, observed_recall: f32) -> bool {
239 plan.has_fallback() && observed_recall < self.config.recall_fallback_threshold
240 }
241
242 pub fn record_observation(&mut self, observation: QueryObservation) -> bool {
246 self.stats.record(observation);
247 self.observations_since_refresh += 1;
248
249 if self.observations_since_refresh >= self.config.weight_refresh_interval {
250 let new_weights = self.stats.recommended_weights(self.cost_model.weights());
251 *self.cost_model.weights_mut() = new_weights;
252 self.observations_since_refresh = 0;
253 true
254 } else {
255 false
256 }
257 }
258
259 pub fn force_refresh_weights(&mut self) {
261 let new_weights = self.stats.recommended_weights(self.cost_model.weights());
262 *self.cost_model.weights_mut() = new_weights;
263 self.observations_since_refresh = 0;
264 }
265
266 fn enabled_families(&self) -> Vec<IndexFamily> {
269 let universe = IndexFamily::all();
270 if self.config.enabled_families.is_empty() {
271 universe.to_vec()
272 } else {
273 universe
274 .into_iter()
275 .filter(|f| self.config.enabled_families.contains(f))
276 .collect()
277 }
278 }
279}
280
281pub fn dispatcher_with_families(families: &[IndexFamily]) -> OptimizerDispatcher {
283 let cfg = DispatcherConfig {
284 enabled_families: families.iter().copied().collect(),
285 ..Default::default()
286 };
287 OptimizerDispatcher::new(CostModel::default(), QueryStats::default(), cfg)
288}
289
290pub fn dispatcher_with_parameters(parameters: IndexParameters) -> OptimizerDispatcher {
292 let cost_model = CostModel::new(parameters, Default::default());
293 OptimizerDispatcher::new(
294 cost_model,
295 QueryStats::default(),
296 DispatcherConfig::default(),
297 )
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a workload of `n` vectors of dimension `dim` requesting `recall`.
    fn workload(n: usize, dim: usize, recall: f32) -> WorkloadProfile {
        WorkloadProfile::new(n, dim, recall)
    }

    #[test]
    fn dispatcher_picks_lowest_cost_meeting_recall() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(100_000, 128, 0.9))
            .expect("plan must exist");
        assert!(
            plan.primary_recall >= 0.9,
            "primary recall must meet target"
        );
        // The fallbacks are the other qualifying candidates, so none may undercut the
        // primary on cost.
        assert!(
            plan.fallbacks.iter().all(|e| e.cost >= plan.primary_cost),
            "no fallback may be cheaper than the primary"
        );
    }

    #[test]
    fn dispatcher_provides_fallback_chain() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(100_000, 128, 0.85))
            .expect("plan must exist");
        assert!(plan.has_fallback(), "fallback chain should be non-empty");
    }

    #[test]
    fn dispatcher_handles_unmet_recall_with_warning() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(10_000, 128, 0.999))
            .expect("dispatcher returns best-effort plan");
        assert!(plan.primary_recall < 0.999);
    }

    #[test]
    fn enabled_families_filter_respected() {
        let dispatcher = dispatcher_with_families(&[IndexFamily::Lsh, IndexFamily::Pq]);
        let plan = dispatcher
            .pick_plan(&workload(10_000, 128, 0.7))
            .expect("plan must exist");
        assert!(matches!(plan.primary, IndexFamily::Lsh | IndexFamily::Pq));
    }

    #[test]
    fn empty_enabled_set_falls_back_to_all_families() {
        let mut dispatcher = OptimizerDispatcher::default();
        // Insert and remove a family so the set is explicitly emptied; an empty set means
        // "all families enabled", so planning must still succeed.
        dispatcher.config.enabled_families.insert(IndexFamily::Hnsw);
        dispatcher
            .config
            .enabled_families
            .remove(&IndexFamily::Hnsw);
        let plan = dispatcher.pick_plan(&workload(1_000, 8, 0.5));
        assert!(plan.is_ok());
    }

    #[test]
    fn should_fallback_triggers_when_observed_below_threshold() {
        let dispatcher = OptimizerDispatcher::default();
        let plan = dispatcher
            .pick_plan(&workload(100_000, 128, 0.85))
            .expect("plan");
        assert!(dispatcher.should_fallback(&plan, 0.5));
        assert!(!dispatcher.should_fallback(&plan, 0.95));
    }

    #[test]
    fn record_observation_refreshes_weights_at_interval() {
        let mut dispatcher = OptimizerDispatcher::default();
        dispatcher.config.weight_refresh_interval = 3;
        for _ in 0..2 {
            let refreshed = dispatcher.record_observation(QueryObservation::new(
                IndexFamily::Hnsw,
                true,
                100.0,
                Some(0.92),
                50.0,
            ));
            assert!(!refreshed);
        }
        let refreshed = dispatcher.record_observation(QueryObservation::new(
            IndexFamily::Hnsw,
            true,
            100.0,
            Some(0.92),
            50.0,
        ));
        assert!(refreshed, "refresh should trigger on 3rd observation");

        let w = dispatcher.cost_model().weights().get(IndexFamily::Hnsw);
        assert!((w - 2.0).abs() < 1e-6);
    }

    #[test]
    fn force_refresh_weights_immediately() {
        let mut dispatcher = OptimizerDispatcher::default();
        dispatcher.stats.record(QueryObservation::new(
            IndexFamily::Pq,
            true,
            300.0,
            None,
            150.0,
        ));
        dispatcher.force_refresh_weights();
        let w = dispatcher.cost_model().weights().get(IndexFamily::Pq);
        assert!((w - 2.0).abs() < 1e-6);
    }

    #[test]
    fn dispatcher_with_parameters_uses_overrides() {
        let params = IndexParameters {
            hnsw_ef: 200,
            ..Default::default()
        };
        let dispatcher = dispatcher_with_parameters(params);
        let cost_high = dispatcher
            .cost_model()
            .estimate(IndexFamily::Hnsw, &workload(100_000, 128, 0.9));
        let dispatcher_default = OptimizerDispatcher::default();
        let cost_low = dispatcher_default
            .cost_model()
            .estimate(IndexFamily::Hnsw, &workload(100_000, 128, 0.9));
        assert!(cost_high.cost > cost_low.cost);
    }
}