1use std::collections::HashMap;
28
29use crate::core::BackendKind;
30use crate::types::SearchQuery;
31
32use super::analyzer::{QueryAnalyzer, QueryFeature};
33use super::config::{BackendEntry, CompositeConfig, CostConfig};
34
35#[derive(Debug, Clone)]
37pub struct QueryCost {
38 pub total: f64,
40
41 pub estimated_latency_ms: u64,
43
44 pub estimated_results: EstimatedCount,
46
47 pub confidence: f64,
49
50 pub breakdown: CostBreakdown,
52}
53
/// Estimated number of results for a query, at varying levels of precision.
#[derive(Debug, Clone)]
pub enum EstimatedCount {
    /// A single count, reported by [`expected`](Self::expected) as-is.
    Exact(u64),
    /// A single approximate count, reported by `expected` as-is.
    Approximate(u64),
    /// A count known only to lie between two bounds.
    Range {
        /// Lower bound of the range.
        min: u64,
        /// Upper bound of the range.
        max: u64,
    },
    /// Nothing is known about the result count.
    Unknown,
}

impl EstimatedCount {
    /// Value reported by [`expected`](Self::expected) for `Unknown`.
    const UNKNOWN_FALLBACK: u64 = 100;

    /// Collapses the estimate into a single representative number.
    ///
    /// * `Exact` / `Approximate` — the wrapped value.
    /// * `Range` — the midpoint of `min` and `max`, rounded down.
    /// * `Unknown` — a fixed fallback of 100.
    ///
    /// The midpoint is computed without overflow, so bounds near `u64::MAX`
    /// are handled correctly.
    pub fn expected(&self) -> u64 {
        match *self {
            Self::Exact(count) | Self::Approximate(count) => count,
            Self::Range { min, max } => {
                // Widen before adding: `min + max` can exceed `u64::MAX`.
                let sum = u128::from(min) + u128::from(max);
                // The average of two u64 values always fits in a u64 again,
                // so this narrowing cast is lossless.
                (sum / 2) as u64
            }
            Self::Unknown => Self::UNKNOWN_FALLBACK,
        }
    }
}
83
84#[derive(Debug, Clone, Default)]
86pub struct CostBreakdown {
87 pub base: f64,
89
90 pub feature_costs: HashMap<QueryFeature, f64>,
92
93 pub volume_cost: f64,
95
96 pub latency_cost: f64,
98
99 pub resource_cost: f64,
101}
102
103impl CostBreakdown {
104 pub fn total(&self) -> f64 {
106 self.base
107 + self.feature_costs.values().sum::<f64>()
108 + self.volume_cost
109 + self.latency_cost
110 + self.resource_cost
111 }
112}
113
114pub struct CostEstimator {
116 config: CostConfig,
118
119 analyzer: QueryAnalyzer,
121
122 benchmarks: Option<BenchmarkResults>,
124}
125
126impl CostEstimator {
127 pub fn with_defaults() -> Self {
129 Self {
130 config: CostConfig::default(),
131 analyzer: QueryAnalyzer::new(),
132 benchmarks: None,
133 }
134 }
135
136 pub fn new(config: CostConfig) -> Self {
138 Self {
139 config,
140 analyzer: QueryAnalyzer::new(),
141 benchmarks: None,
142 }
143 }
144
145 pub fn with_benchmarks(mut self, benchmarks: BenchmarkResults) -> Self {
147 self.benchmarks = Some(benchmarks);
148 self
149 }
150
151 pub fn estimate(&self, query: &SearchQuery, backend: &BackendEntry) -> QueryCost {
153 let analysis = self.analyzer.analyze(query);
154
155 let base_cost = self
157 .config
158 .base_costs
159 .get(&backend.kind)
160 .copied()
161 .unwrap_or(1.0);
162
163 let mut feature_costs = HashMap::new();
165 for feature in &analysis.features {
166 let multiplier = self
167 .config
168 .feature_multipliers
169 .get(feature)
170 .copied()
171 .unwrap_or(1.0);
172
173 feature_costs.insert(*feature, base_cost * multiplier);
174 }
175
176 let specificity = self.estimate_specificity(query);
178 let volume_cost = base_cost * (1.0 - specificity) * 2.0;
179
180 let estimated_latency_ms = self.estimate_latency(&backend.kind, &analysis);
182
183 let total = base_cost * self.config.weights.latency
185 + feature_costs.values().sum::<f64>()
186 + volume_cost * self.config.weights.resource_usage;
187
188 let breakdown = CostBreakdown {
189 base: base_cost,
190 feature_costs,
191 volume_cost,
192 latency_cost: estimated_latency_ms as f64 * 0.01,
193 resource_cost: 0.0,
194 };
195
196 QueryCost {
197 total,
198 estimated_latency_ms,
199 estimated_results: EstimatedCount::Unknown,
200 confidence: self.estimate_confidence(&analysis),
201 breakdown,
202 }
203 }
204
205 pub fn estimate_all(
207 &self,
208 query: &SearchQuery,
209 config: &CompositeConfig,
210 ) -> HashMap<String, QueryCost> {
211 config
212 .backends
213 .iter()
214 .filter(|b| b.enabled)
215 .map(|backend| (backend.id.clone(), self.estimate(query, backend)))
216 .collect()
217 }
218
219 pub fn cheapest_backend<'a>(
221 &self,
222 query: &SearchQuery,
223 backends: &'a [BackendEntry],
224 ) -> Option<&'a BackendEntry> {
225 backends
226 .iter()
227 .filter(|b| b.enabled)
228 .map(|b| (b, self.estimate(query, b)))
229 .min_by(|(_, cost_a), (_, cost_b)| {
230 cost_a
231 .total
232 .partial_cmp(&cost_b.total)
233 .unwrap_or(std::cmp::Ordering::Equal)
234 })
235 .map(|(backend, _)| backend)
236 }
237
238 fn estimate_specificity(&self, query: &SearchQuery) -> f64 {
240 let mut specificity: f64 = 0.0;
241
242 for param in &query.parameters {
243 match param.name.as_str() {
245 "_id" => specificity += 0.9,
246 "identifier" => specificity += 0.7,
247 _ => specificity += 0.1,
248 }
249
250 if param.values.len() > 1 {
252 specificity *= 0.8;
253 }
254 }
255
256 specificity.min(1.0)
258 }
259
260 fn estimate_latency(
262 &self,
263 backend_kind: &BackendKind,
264 analysis: &super::analyzer::QueryAnalysis,
265 ) -> u64 {
266 let base_latency = match backend_kind {
268 BackendKind::Sqlite => 1,
269 BackendKind::Postgres => 5,
270 BackendKind::Elasticsearch => 10,
271 BackendKind::Neo4j => 15,
272 BackendKind::S3 => 50,
273 _ => 10,
274 };
275
276 let feature_latency: u64 = analysis
278 .features
279 .iter()
280 .map(|f| match f {
281 QueryFeature::ChainedSearch => 20,
282 QueryFeature::ReverseChaining => 25,
283 QueryFeature::FullTextSearch => 15,
284 QueryFeature::TerminologySearch => 30,
285 QueryFeature::Include | QueryFeature::Revinclude => 10,
286 _ => 0,
287 })
288 .sum();
289
290 base_latency + feature_latency
291 }
292
293 fn estimate_confidence(&self, analysis: &super::analyzer::QueryAnalysis) -> f64 {
295 let mut confidence = 0.8;
296
297 if analysis.complexity_score > 5 {
299 confidence *= 0.8;
300 }
301
302 if self.benchmarks.is_none() {
304 confidence *= 0.7;
305 }
306
307 confidence
308 }
309}
310
311impl Default for CostEstimator {
312 fn default() -> Self {
313 Self::with_defaults()
314 }
315}
316
317#[derive(Debug, Clone, Default)]
319pub struct BenchmarkResults {
320 pub operations: HashMap<(BackendKind, BenchmarkOperation), BenchmarkMeasurement>,
322}
323
/// Kinds of operation a benchmark measurement can be recorded for.
///
/// Used together with a backend kind as the key of `BenchmarkResults::operations`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BenchmarkOperation {
    /// Lookup by id.
    IdLookup,
    /// String search.
    StringSearch,
    /// Token search.
    TokenSearch,
    /// Date search.
    DateSearch,
    /// Chained search, variant 1 (presumably chain depth 1 — confirm).
    ChainedSearch1,
    /// Chained search, variant 2 (presumably chain depth 2 — confirm).
    ChainedSearch2,
    /// Chained search, variant 3 (presumably chain depth 3 — confirm).
    ChainedSearch3,
    /// Full-text search.
    FullTextSearch,
    /// Terminology expansion.
    TerminologyExpand,
    /// Include resolution.
    IncludeResolve,
    /// Revinclude resolution.
    RevincludeResolve,
}
350
/// Statistics recorded for one benchmarked operation.
#[derive(Clone, Debug)]
pub struct BenchmarkMeasurement {
    /// Mean duration, in microseconds.
    pub mean_us: f64,

    /// Standard deviation of the duration, in microseconds.
    pub std_dev_us: f64,

    /// Number of iterations the benchmark ran.
    pub iterations: u64,

    /// Measured throughput (presumably operations per second — confirm).
    pub throughput: f64,
}
366
367impl BenchmarkResults {
368 pub fn new() -> Self {
370 Self::default()
371 }
372
373 pub fn add(
375 &mut self,
376 backend: BackendKind,
377 operation: BenchmarkOperation,
378 measurement: BenchmarkMeasurement,
379 ) {
380 self.operations.insert((backend, operation), measurement);
381 }
382
383 pub fn cost_multiplier(
385 &self,
386 backend: BackendKind,
387 operation: BenchmarkOperation,
388 ) -> Option<f64> {
389 self.operations
390 .get(&(backend, operation))
391 .map(|m| m.mean_us / 1000.0) }
393}
394
395#[derive(Debug)]
397pub struct CostComparison {
398 pub options: Vec<(String, QueryCost)>,
400
401 pub recommended: String,
403
404 pub savings_percent: f64,
406}
407
408impl CostComparison {
409 pub fn from_estimates(estimates: HashMap<String, QueryCost>) -> Self {
411 let mut options: Vec<_> = estimates.into_iter().collect();
412 options.sort_by(|a, b| {
413 a.1.total
414 .partial_cmp(&b.1.total)
415 .unwrap_or(std::cmp::Ordering::Equal)
416 });
417
418 let recommended = options
419 .first()
420 .map(|(id, _)| id.clone())
421 .unwrap_or_default();
422
423 let best_cost = options.first().map(|(_, c)| c.total).unwrap_or(1.0);
424 let worst_cost = options.last().map(|(_, c)| c.total).unwrap_or(1.0);
425 let savings_percent = if worst_cost > 0.0 {
426 ((worst_cost - best_cost) / worst_cost) * 100.0
427 } else {
428 0.0
429 };
430
431 Self {
432 options,
433 recommended,
434 savings_percent,
435 }
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `QueryCost` with the given total and latency; every other
    /// field gets the neutral value used throughout these tests.
    fn sample_cost(total: f64, estimated_latency_ms: u64) -> QueryCost {
        QueryCost {
            total,
            estimated_latency_ms,
            estimated_results: EstimatedCount::Unknown,
            confidence: 0.8,
            breakdown: CostBreakdown::default(),
        }
    }

    /// The default configuration must carry a base cost for SQLite.
    #[test]
    fn test_cost_estimator_default() {
        let estimator = CostEstimator::with_defaults();
        let has_sqlite_cost = estimator
            .config
            .base_costs
            .contains_key(&BackendKind::Sqlite);

        assert!(has_sqlite_cost);
    }

    /// `expected` returns the documented representative value per variant.
    #[test]
    fn test_estimated_count_expected() {
        let cases = [
            (EstimatedCount::Exact(50), 50),
            (EstimatedCount::Approximate(100), 100),
            (EstimatedCount::Range { min: 10, max: 30 }, 20),
            (EstimatedCount::Unknown, 100),
        ];

        for (estimate, expected) in cases.iter() {
            assert_eq!(estimate.expected(), *expected);
        }
    }

    /// `CostBreakdown::total` adds up every component, feature costs included.
    #[test]
    fn test_cost_breakdown_total() {
        let mut feature_costs = HashMap::new();
        feature_costs.insert(QueryFeature::BasicSearch, 0.2);

        let breakdown = CostBreakdown {
            base: 1.0,
            feature_costs,
            volume_cost: 0.5,
            latency_cost: 0.2,
            resource_cost: 0.1,
        };

        let difference = (breakdown.total() - 2.0).abs();
        assert!(difference < 0.01);
    }

    /// A parameterless query still yields a positive cost and confidence.
    #[test]
    fn test_estimate_simple_query() {
        let backend = BackendEntry::new(
            "test",
            super::super::config::BackendRole::Primary,
            BackendKind::Sqlite,
        );
        let query = SearchQuery::new("Patient");

        let cost = CostEstimator::with_defaults().estimate(&query, &backend);

        assert!(cost.total > 0.0);
        assert!(cost.confidence > 0.0);
    }

    /// A stored measurement is exposed as its mean converted to milliseconds.
    #[test]
    fn test_benchmark_results() {
        let measurement = BenchmarkMeasurement {
            mean_us: 100.0,
            std_dev_us: 10.0,
            iterations: 1000,
            throughput: 10000.0,
        };

        let mut results = BenchmarkResults::new();
        results.add(
            BackendKind::Sqlite,
            BenchmarkOperation::IdLookup,
            measurement,
        );

        let multiplier = results
            .cost_multiplier(BackendKind::Sqlite, BenchmarkOperation::IdLookup)
            .unwrap();
        assert!((multiplier - 0.1).abs() < 0.01);
    }

    /// The cheaper option is recommended and the savings are relative to the
    /// most expensive one.
    #[test]
    fn test_cost_comparison() {
        let mut estimates = HashMap::new();
        estimates.insert("fast".to_string(), sample_cost(1.0, 10));
        estimates.insert("slow".to_string(), sample_cost(2.0, 20));

        let comparison = CostComparison::from_estimates(estimates);

        assert_eq!(comparison.recommended, "fast");
        assert!((comparison.savings_percent - 50.0).abs() < 0.01);
    }
}