use std::collections::HashMap;
use crate::core::BackendKind;
use crate::types::SearchQuery;
use super::analyzer::{QueryAnalyzer, QueryFeature};
use super::config::{BackendEntry, CompositeConfig, CostConfig};
/// Outcome of estimating what a single query would cost on a single backend.
///
/// Produced by `CostEstimator::estimate`.
#[derive(Debug, Clone)]
pub struct QueryCost {
/// Weighted overall cost; this is the value backends are ranked by
/// (lower is cheaper).
pub total: f64,
/// Estimated latency in milliseconds.
pub estimated_latency_ms: u64,
/// Estimated number of results. `CostEstimator::estimate` currently always
/// reports `EstimatedCount::Unknown` here.
pub estimated_results: EstimatedCount,
/// How much the estimate should be trusted; the estimator lowers it for
/// complex queries and when no benchmark data is attached.
pub confidence: f64,
/// The individual components the estimate was built from.
pub breakdown: CostBreakdown,
}
/// Estimated number of results for a query, at varying levels of certainty.
#[derive(Debug, Clone)]
pub enum EstimatedCount {
    /// The count is known exactly.
    Exact(u64),
    /// A single approximate figure.
    Approximate(u64),
    /// The count is expected to fall between two bounds.
    Range {
        /// Lower bound of the estimate.
        min: u64,
        /// Upper bound of the estimate.
        max: u64,
    },
    /// No estimate is available.
    Unknown,
}

impl EstimatedCount {
    /// Value reported by [`expected`](Self::expected) when nothing is known.
    const UNKNOWN_FALLBACK: u64 = 100;

    /// Collapses the estimate into a single representative number.
    ///
    /// * `Exact` / `Approximate` return the wrapped value.
    /// * `Range` returns the midpoint of the two bounds, rounded down. The
    ///   bounds may be given in either order.
    /// * `Unknown` returns a fixed fallback of `100`.
    pub fn expected(&self) -> u64 {
        match *self {
            Self::Exact(n) | Self::Approximate(n) => n,
            Self::Range { min, max } => {
                // Add in u128 so that bounds near `u64::MAX` cannot overflow.
                let midpoint = (u128::from(min) + u128::from(max)) / 2;
                // The mean of two u64 values always fits back into a u64.
                midpoint as u64
            }
            Self::Unknown => Self::UNKNOWN_FALLBACK,
        }
    }
}
/// Itemised components of a cost estimate.
///
/// `CostBreakdown::total` adds every component together without weighting.
#[derive(Debug, Clone, Default)]
pub struct CostBreakdown {
/// Base cost of the backend the query would run on.
pub base: f64,
/// Cost attributed to each query feature detected in the query.
pub feature_costs: HashMap<QueryFeature, f64>,
/// Cost attributed to the expected result volume.
pub volume_cost: f64,
/// Cost attributed to latency. `CostEstimator::estimate` fills this with the
/// estimated latency in milliseconds multiplied by 0.01.
pub latency_cost: f64,
/// Cost attributed to resource usage. `CostEstimator::estimate` currently
/// always sets this to 0.0.
pub resource_cost: f64,
}
impl CostBreakdown {
    /// Returns the unweighted sum of every component in the breakdown.
    pub fn total(&self) -> f64 {
        // Destructure exhaustively so a newly added field cannot be silently
        // left out of the sum.
        let Self {
            base,
            feature_costs,
            volume_cost,
            latency_cost,
            resource_cost,
        } = self;
        let feature_sum: f64 = feature_costs.values().sum();
        base + feature_sum + volume_cost + latency_cost + resource_cost
    }
}
/// Estimates how expensive a search query would be on a given backend.
pub struct CostEstimator {
// Base costs, feature multipliers and weights used by the cost formula.
config: CostConfig,
// Extracts the features and complexity score of a query.
analyzer: QueryAnalyzer,
// Optional benchmark data; in this file its presence only affects the
// reported confidence.
benchmarks: Option<BenchmarkResults>,
}
impl CostEstimator {
    /// Creates an estimator using `CostConfig::default()` and no benchmark data.
    pub fn with_defaults() -> Self {
        Self::new(CostConfig::default())
    }

    /// Creates an estimator driven by `config`, with no benchmark data attached.
    pub fn new(config: CostConfig) -> Self {
        Self {
            config,
            analyzer: QueryAnalyzer::new(),
            benchmarks: None,
        }
    }

    /// Attaches benchmark results, replacing any previously attached set.
    ///
    /// Within this type the presence of benchmarks only raises the reported
    /// confidence; the measurements themselves are not fed into the cost.
    pub fn with_benchmarks(mut self, benchmarks: BenchmarkResults) -> Self {
        self.benchmarks = Some(benchmarks);
        self
    }

    /// Estimates the cost of running `query` on `backend`.
    ///
    /// The returned `total` is
    /// `base * weights.latency + sum(feature costs) + volume * weights.resource_usage`,
    /// where:
    ///
    /// * `base` is the configured base cost for the backend kind (1.0 if absent),
    /// * each feature cost is `base * multiplier` (multiplier 1.0 if absent),
    /// * `volume` is `base * (1 - specificity) * 2`.
    ///
    /// Note that `breakdown.latency_cost` is reported but is not part of
    /// `total`, so `breakdown.total()` generally differs from `total`.
    /// `estimated_results` is always `EstimatedCount::Unknown`.
    pub fn estimate(&self, query: &SearchQuery, backend: &BackendEntry) -> QueryCost {
        let analysis = self.analyzer.analyze(query);

        // Backend kinds without a configured base cost are treated as 1.0.
        let base_cost = self
            .config
            .base_costs
            .get(&backend.kind)
            .copied()
            .unwrap_or(1.0);

        // Features without a configured multiplier are charged the plain base cost.
        let feature_costs: HashMap<_, _> = analysis
            .features
            .iter()
            .map(|feature| {
                let multiplier = self
                    .config
                    .feature_multipliers
                    .get(feature)
                    .copied()
                    .unwrap_or(1.0);
                (*feature, base_cost * multiplier)
            })
            .collect();

        // Less specific queries are expected to touch more data.
        let specificity = self.estimate_specificity(query);
        let volume_cost = base_cost * (1.0 - specificity) * 2.0;

        let estimated_latency_ms = self.estimate_latency(&backend.kind, &analysis);

        let total = base_cost * self.config.weights.latency
            + feature_costs.values().sum::<f64>()
            + volume_cost * self.config.weights.resource_usage;

        let breakdown = CostBreakdown {
            base: base_cost,
            feature_costs,
            volume_cost,
            latency_cost: estimated_latency_ms as f64 * 0.01,
            resource_cost: 0.0,
        };

        QueryCost {
            total,
            estimated_latency_ms,
            estimated_results: EstimatedCount::Unknown,
            confidence: self.estimate_confidence(&analysis),
            breakdown,
        }
    }

    /// Estimates `query` against every enabled backend in `config`.
    ///
    /// Returns a map from backend id to its estimate; disabled backends are
    /// omitted.
    pub fn estimate_all(
        &self,
        query: &SearchQuery,
        config: &CompositeConfig,
    ) -> HashMap<String, QueryCost> {
        config
            .backends
            .iter()
            .filter(|backend| backend.enabled)
            .map(|backend| (backend.id.clone(), self.estimate(query, backend)))
            .collect()
    }

    /// Returns the enabled backend with the lowest estimated total cost.
    ///
    /// Returns `None` when `backends` contains no enabled entry. When several
    /// backends tie, the one appearing first in `backends` wins. A backend
    /// whose cost evaluates to NaN is only returned if every enabled backend's
    /// cost is NaN.
    pub fn cheapest_backend<'a>(
        &self,
        query: &SearchQuery,
        backends: &'a [BackendEntry],
    ) -> Option<&'a BackendEntry> {
        backends
            .iter()
            .filter(|backend| backend.enabled)
            .map(|backend| (backend, self.estimate(query, backend).total))
            .min_by(|(_, lhs), (_, rhs)| Self::cmp_totals(*lhs, *rhs))
            .map(|(backend, _)| backend)
    }

    /// Orders two cost totals ascending, with NaN sorted after every number.
    ///
    /// Plain `partial_cmp` treats NaN as incomparable; mapping that to
    /// `Equal` would let a NaN cost tie with (and beat) any real cost.
    fn cmp_totals(lhs: f64, rhs: f64) -> std::cmp::Ordering {
        lhs.is_nan()
            .cmp(&rhs.is_nan())
            .then_with(|| lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Scores how narrowly `query` selects results; the result is capped at 1.0.
    ///
    /// Each parameter adds to a running score: 0.9 for `_id`, 0.7 for
    /// `identifier`, 0.1 for anything else. A parameter carrying more than one
    /// value then scales the score accumulated so far by 0.8.
    fn estimate_specificity(&self, query: &SearchQuery) -> f64 {
        let mut specificity: f64 = 0.0;
        for param in &query.parameters {
            specificity += match param.name.as_str() {
                "_id" => 0.9,
                "identifier" => 0.7,
                _ => 0.1,
            };
            // Multiple values widen the match, so the score is reduced.
            if param.values.len() > 1 {
                specificity *= 0.8;
            }
        }
        specificity.min(1.0)
    }

    /// Estimates latency in milliseconds as a per-backend base figure plus a
    /// fixed surcharge for each detected query feature.
    fn estimate_latency(
        &self,
        backend_kind: &BackendKind,
        analysis: &super::analyzer::QueryAnalysis,
    ) -> u64 {
        let base_latency: u64 = match backend_kind {
            BackendKind::Sqlite => 1,
            BackendKind::Postgres => 5,
            BackendKind::Elasticsearch => 10,
            BackendKind::Neo4j => 15,
            BackendKind::S3 => 50,
            // Any other backend kind gets a middle-of-the-road default.
            _ => 10,
        };

        let feature_latency: u64 = analysis
            .features
            .iter()
            .map(|feature| match feature {
                QueryFeature::ChainedSearch => 20,
                QueryFeature::ReverseChaining => 25,
                QueryFeature::FullTextSearch => 15,
                QueryFeature::TerminologySearch => 30,
                QueryFeature::Include | QueryFeature::Revinclude => 10,
                // Remaining features add no latency surcharge.
                _ => 0,
            })
            .sum();

        base_latency + feature_latency
    }

    /// Rates how trustworthy an estimate is.
    ///
    /// Starts at 0.8, is scaled by 0.8 when the query's complexity score
    /// exceeds 5, and by 0.7 when no benchmark data is attached.
    fn estimate_confidence(&self, analysis: &super::analyzer::QueryAnalysis) -> f64 {
        let mut confidence = 0.8;
        if analysis.complexity_score > 5 {
            confidence *= 0.8;
        }
        if self.benchmarks.is_none() {
            confidence *= 0.7;
        }
        confidence
    }
}
impl Default for CostEstimator {
fn default() -> Self {
Self::with_defaults()
}
}
/// Collection of benchmark measurements, one per backend kind and operation.
#[derive(Debug, Clone, Default)]
pub struct BenchmarkResults {
/// Measurements keyed by `(backend kind, benchmarked operation)`.
pub operations: HashMap<(BackendKind, BenchmarkOperation), BenchmarkMeasurement>,
}
/// Kind of operation a benchmark measurement refers to.
///
/// Used together with a `BackendKind` as the key into
/// `BenchmarkResults::operations`.
// NOTE(review): the numeric suffix on the `ChainedSearch*` variants presumably
// denotes chain depth - confirm against the benchmark suite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BenchmarkOperation {
IdLookup,
StringSearch,
TokenSearch,
DateSearch,
ChainedSearch1,
ChainedSearch2,
ChainedSearch3,
FullTextSearch,
TerminologyExpand,
IncludeResolve,
RevincludeResolve,
}
/// Summary statistics for one benchmarked operation.
// NOTE(review): the `_us` suffix suggests microseconds, which matches
// `BenchmarkResults::cost_multiplier` dividing `mean_us` by 1000 - confirm.
#[derive(Debug, Clone)]
pub struct BenchmarkMeasurement {
/// Mean duration of the operation.
pub mean_us: f64,
/// Standard deviation of the duration.
pub std_dev_us: f64,
/// Number of iterations the benchmark ran.
pub iterations: u64,
/// Measured throughput; the unit is not established in this file.
pub throughput: f64,
}
impl BenchmarkResults {
pub fn new() -> Self {
Self::default()
}
pub fn add(
&mut self,
backend: BackendKind,
operation: BenchmarkOperation,
measurement: BenchmarkMeasurement,
) {
self.operations.insert((backend, operation), measurement);
}
pub fn cost_multiplier(
&self,
backend: BackendKind,
operation: BenchmarkOperation,
) -> Option<f64> {
self.operations
.get(&(backend, operation))
.map(|m| m.mean_us / 1000.0) }
}
/// Side-by-side comparison of cost estimates for several backends.
#[derive(Debug)]
pub struct CostComparison {
/// `(backend id, estimate)` pairs, sorted by ascending total cost.
pub options: Vec<(String, QueryCost)>,
/// Id of the cheapest option; empty when there are no options.
pub recommended: String,
/// Percentage saved by the cheapest option relative to the most expensive one.
pub savings_percent: f64,
}
impl CostComparison {
    /// Builds a comparison from per-backend estimates keyed by backend id.
    ///
    /// Options are sorted by ascending `total`; ties are broken by backend id
    /// so the result does not depend on `HashMap` iteration order. Estimates
    /// whose total is NaN are sorted last and are ignored when determining the
    /// most expensive option.
    ///
    /// `recommended` is the id of the first (cheapest) option, or an empty
    /// string when `estimates` is empty. `savings_percent` is
    /// `(worst - best) / worst * 100`, or 0.0 when the worst cost is not
    /// positive.
    pub fn from_estimates(estimates: HashMap<String, QueryCost>) -> Self {
        let mut options: Vec<(String, QueryCost)> = estimates.into_iter().collect();
        options.sort_by(|(id_a, cost_a), (id_b, cost_b)| {
            Self::cmp_totals(cost_a.total, cost_b.total).then_with(|| id_a.cmp(id_b))
        });

        let recommended = options
            .first()
            .map(|(id, _)| id.clone())
            .unwrap_or_default();

        let best_cost = options.first().map(|(_, cost)| cost.total).unwrap_or(1.0);
        // NaN entries sit at the tail, so walk backwards to the last real number.
        let worst_cost = options
            .iter()
            .rev()
            .map(|(_, cost)| cost.total)
            .find(|total| !total.is_nan())
            .unwrap_or(best_cost);

        let savings_percent = if worst_cost > 0.0 {
            ((worst_cost - best_cost) / worst_cost) * 100.0
        } else {
            0.0
        };

        Self {
            options,
            recommended,
            savings_percent,
        }
    }

    /// Orders two cost totals ascending, with NaN sorted after every number.
    ///
    /// `slice::sort_by` requires a total order; mapping an incomparable
    /// `partial_cmp` to `Equal` does not provide one once a NaN is present.
    fn cmp_totals(lhs: f64, rhs: f64) -> std::cmp::Ordering {
        lhs.is_nan()
            .cmp(&rhs.is_nan())
            .then_with(|| lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal))
    }
}
// Unit tests for the cost estimation types in this module.
#[cfg(test)]
mod tests {
use super::*;
// The default configuration must carry a base cost for SQLite.
#[test]
fn test_cost_estimator_default() {
let estimator = CostEstimator::with_defaults();
assert!(
estimator
.config
.base_costs
.contains_key(&BackendKind::Sqlite)
);
}
// `expected()` returns the wrapped value, the range midpoint, or the
// fallback of 100 for `Unknown`.
#[test]
fn test_estimated_count_expected() {
assert_eq!(EstimatedCount::Exact(50).expected(), 50);
assert_eq!(EstimatedCount::Approximate(100).expected(), 100);
assert_eq!(EstimatedCount::Range { min: 10, max: 30 }.expected(), 20);
assert_eq!(EstimatedCount::Unknown.expected(), 100);
}
// `total()` is the plain sum of all components: 1.0 + 0.2 + 0.5 + 0.2 + 0.1.
#[test]
fn test_cost_breakdown_total() {
let mut breakdown = CostBreakdown {
base: 1.0,
feature_costs: HashMap::new(),
volume_cost: 0.5,
latency_cost: 0.2,
resource_cost: 0.1,
};
breakdown
.feature_costs
.insert(QueryFeature::BasicSearch, 0.2);
assert!((breakdown.total() - 2.0).abs() < 0.01);
}
// A parameterless query on SQLite yields a positive cost and confidence.
#[test]
fn test_estimate_simple_query() {
let estimator = CostEstimator::with_defaults();
let backend = BackendEntry::new(
"test",
super::super::config::BackendRole::Primary,
BackendKind::Sqlite,
);
let query = SearchQuery::new("Patient");
let cost = estimator.estimate(&query, &backend);
assert!(cost.total > 0.0);
assert!(cost.confidence > 0.0);
}
// A stored mean of 100.0 gives a multiplier of 100.0 / 1000.0 = 0.1.
#[test]
fn test_benchmark_results() {
let mut results = BenchmarkResults::new();
results.add(
BackendKind::Sqlite,
BenchmarkOperation::IdLookup,
BenchmarkMeasurement {
mean_us: 100.0,
std_dev_us: 10.0,
iterations: 1000,
throughput: 10000.0,
},
);
let multiplier = results
.cost_multiplier(BackendKind::Sqlite, BenchmarkOperation::IdLookup)
.unwrap();
assert!((multiplier - 0.1).abs() < 0.01);
}
// The cheaper option is recommended; savings are (2.0 - 1.0) / 2.0 = 50%.
#[test]
fn test_cost_comparison() {
let mut estimates = HashMap::new();
estimates.insert(
"fast".to_string(),
QueryCost {
total: 1.0,
estimated_latency_ms: 10,
estimated_results: EstimatedCount::Unknown,
confidence: 0.8,
breakdown: CostBreakdown::default(),
},
);
estimates.insert(
"slow".to_string(),
QueryCost {
total: 2.0,
estimated_latency_ms: 20,
estimated_results: EstimatedCount::Unknown,
confidence: 0.8,
breakdown: CostBreakdown::default(),
},
);
let comparison = CostComparison::from_estimates(estimates);
assert_eq!(comparison.recommended, "fast");
assert!((comparison.savings_percent - 50.0).abs() < 0.01);
}
}