use std::collections::HashMap;
/// Coarse classification of a SQL statement.
///
/// Used as the key for the per-operation cost multipliers held by
/// [`QueryCostEstimator`]; `QueryCostEstimator::detect_operation` produces
/// these values from the leading keyword(s) of the statement text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// Statement starting with `SELECT` or `WITH` that carries no
    /// vector-search marker.
    Select,
    /// Statement starting with `INSERT`.
    Insert,
    /// Statement starting with `UPDATE`.
    Update,
    /// Statement starting with `DELETE`.
    Delete,
    /// Schema change: statement starting with `CREATE`, `ALTER`, `DROP` or
    /// `TRUNCATE`.
    Ddl,
    /// Never returned by `detect_operation`; its multiplier is applied by
    /// `estimate_cost` on top of the running cost when a statement looks
    /// like an unfiltered scan.
    FullTableScan,
    /// `SELECT`/`WITH` statement containing one of the vector-search markers
    /// (`VECTOR_SEARCH`, `<->`, `COSINE`, `L2_DISTANCE`, `EMBEDDING`).
    VectorSearch,
    /// Statement starting with `BEGIN`, `COMMIT`, `ROLLBACK`, `SAVEPOINT`,
    /// `START TRANSACTION` or `END`.
    TransactionControl,
    /// Maintenance statement starting with `ANALYZE`, `VACUUM`, `REINDEX` or
    /// `CLUSTER`.
    Administrative,
    /// Fallback for anything that matches none of the above.
    Unknown,
}
impl Default for OperationType {
fn default() -> Self {
Self::Unknown
}
}
/// Renders the operation as a short upper-case label (e.g. `SELECT`,
/// `TXN_CTRL`).
impl std::fmt::Display for OperationType {
    /// Writes the label for this operation type.
    ///
    /// The label goes through [`std::fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags such as `{:<10}` are honored when the
    /// value is printed in tabular output. With no flags the output is the
    /// bare label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            OperationType::Select => "SELECT",
            OperationType::Insert => "INSERT",
            OperationType::Update => "UPDATE",
            OperationType::Delete => "DELETE",
            OperationType::Ddl => "DDL",
            OperationType::FullTableScan => "FULL_SCAN",
            OperationType::VectorSearch => "VECTOR",
            OperationType::TransactionControl => "TXN_CTRL",
            OperationType::Administrative => "ADMIN",
            OperationType::Unknown => "UNKNOWN",
        };
        f.pad(label)
    }
}
/// Text-based heuristic that assigns an integer cost to a SQL statement.
///
/// The estimate is derived purely from the statement text: the leading
/// keyword selects a multiplier, and substring matches add to the result.
#[derive(Debug, Clone)]
pub struct QueryCostEstimator {
    /// Baseline that the per-operation multiplier is applied to.
    base_cost: u32,
    /// Multiplier for each [`OperationType`].
    operation_costs: HashMap<OperationType, f32>,
    /// `(pattern, cost)` pairs; `cost` is added when the upper-cased query
    /// contains `pattern`. Patterns are stored upper-cased by
    /// `with_pattern_cost`.
    pattern_costs: Vec<(String, u32)>,
    /// Whether `estimate_cost` applies the full-table-scan multiplier.
    detect_full_scans: bool,
    /// Upper-case substrings that each add `1` to the estimate when present.
    expensive_keywords: Vec<String>,
}
impl QueryCostEstimator {
pub fn new() -> Self {
let mut operation_costs = HashMap::new();
operation_costs.insert(OperationType::Select, 1.0);
operation_costs.insert(OperationType::Insert, 2.0);
operation_costs.insert(OperationType::Update, 3.0);
operation_costs.insert(OperationType::Delete, 3.0);
operation_costs.insert(OperationType::Ddl, 10.0);
operation_costs.insert(OperationType::FullTableScan, 5.0);
operation_costs.insert(OperationType::VectorSearch, 3.0);
operation_costs.insert(OperationType::TransactionControl, 0.5);
operation_costs.insert(OperationType::Administrative, 5.0);
operation_costs.insert(OperationType::Unknown, 1.0);
Self {
base_cost: 1,
operation_costs,
pattern_costs: Vec::new(),
detect_full_scans: true,
expensive_keywords: vec![
"CROSS JOIN".to_string(),
"CARTESIAN".to_string(),
"FULL OUTER JOIN".to_string(),
"ORDER BY".to_string(),
"GROUP BY".to_string(),
"DISTINCT".to_string(),
"UNION".to_string(),
],
}
}
pub fn with_base_cost(mut self, cost: u32) -> Self {
self.base_cost = cost;
self
}
pub fn with_operation_cost(mut self, op: OperationType, cost: f32) -> Self {
self.operation_costs.insert(op, cost);
self
}
pub fn with_pattern_cost(mut self, pattern: impl Into<String>, cost: u32) -> Self {
self.pattern_costs.push((pattern.into().to_uppercase(), cost));
self
}
pub fn with_full_scan_detection(mut self, enabled: bool) -> Self {
self.detect_full_scans = enabled;
self
}
pub fn estimate_cost(&self, query: &str) -> u32 {
let upper = query.to_uppercase();
let op_type = self.detect_operation(&upper);
let multiplier = self.operation_costs.get(&op_type).copied().unwrap_or(1.0);
let mut cost = (self.base_cost as f32 * multiplier) as u32;
for keyword in &self.expensive_keywords {
if upper.contains(keyword) {
cost += 1;
}
}
for (pattern, pattern_cost) in &self.pattern_costs {
if upper.contains(pattern) {
cost += pattern_cost;
}
}
if self.detect_full_scans && self.is_likely_full_scan(&upper) {
let scan_multiplier = self
.operation_costs
.get(&OperationType::FullTableScan)
.copied()
.unwrap_or(5.0);
cost = (cost as f32 * scan_multiplier) as u32;
}
cost.max(1)
}
#[cfg(feature = "lag-routing")]
pub fn estimate_write_cost_sync_mode(&self, sync_mode: crate::lag::SyncMode) -> u32 {
use crate::lag::SyncMode;
match sync_mode {
SyncMode::Sync => 5, SyncMode::SemiSync => 3, SyncMode::Async => 1, SyncMode::Unknown => 2,
}
}
pub fn detect_operation(&self, query: &str) -> OperationType {
let upper = query.trim().to_uppercase();
if upper.starts_with("BEGIN")
|| upper.starts_with("COMMIT")
|| upper.starts_with("ROLLBACK")
|| upper.starts_with("SAVEPOINT")
|| upper.starts_with("START TRANSACTION")
|| upper.starts_with("END")
{
return OperationType::TransactionControl;
}
if upper.starts_with("CREATE")
|| upper.starts_with("ALTER")
|| upper.starts_with("DROP")
|| upper.starts_with("TRUNCATE")
{
return OperationType::Ddl;
}
if upper.starts_with("ANALYZE")
|| upper.starts_with("VACUUM")
|| upper.starts_with("REINDEX")
|| upper.starts_with("CLUSTER")
{
return OperationType::Administrative;
}
if upper.starts_with("SELECT") || upper.starts_with("WITH") {
if upper.contains("VECTOR_SEARCH")
|| upper.contains("<->")
|| upper.contains("COSINE")
|| upper.contains("L2_DISTANCE")
|| upper.contains("EMBEDDING")
{
return OperationType::VectorSearch;
}
return OperationType::Select;
}
if upper.starts_with("INSERT") {
return OperationType::Insert;
}
if upper.starts_with("UPDATE") {
return OperationType::Update;
}
if upper.starts_with("DELETE") {
return OperationType::Delete;
}
OperationType::Unknown
}
fn is_likely_full_scan(&self, upper: &str) -> bool {
if upper.starts_with("SELECT") || upper.contains(" SELECT ") {
if !upper.contains("WHERE") {
if upper.contains("FROM") {
return true;
}
}
}
if (upper.starts_with("DELETE") || upper.starts_with("UPDATE")) && !upper.contains("WHERE")
{
return true;
}
false
}
pub fn extract_cost_hint(&self, query: &str) -> Option<u32> {
if let Some(start) = query.find("/*helios:cost=") {
let after_prefix = &query[start + 14..];
if let Some(end) = after_prefix.find("*/") {
let cost_str = &after_prefix[..end];
return cost_str.trim().parse().ok();
}
}
if let Some(start) = query.find("/*cost:") {
let after_prefix = &query[start + 7..];
if let Some(end) = after_prefix.find("*/") {
let cost_str = &after_prefix[..end];
return cost_str.trim().parse().ok();
}
}
None
}
pub fn estimate_cost_with_hint(&self, query: &str) -> u32 {
if let Some(hint_cost) = self.extract_cost_hint(query) {
return hint_cost;
}
self.estimate_cost(query)
}
pub fn get_operation_multiplier(&self, op: OperationType) -> f32 {
self.operation_costs.get(&op).copied().unwrap_or(1.0)
}
pub fn set_base_cost(&mut self, cost: u32) {
self.base_cost = cost;
}
pub fn base_cost(&self) -> u32 {
self.base_cost
}
}
impl Default for QueryCostEstimator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every query in `queries` is classified as `expected` by a
    /// default-configured estimator.
    fn assert_detects(expected: OperationType, queries: &[&str]) {
        let estimator = QueryCostEstimator::new();
        for query in queries {
            assert_eq!(
                estimator.detect_operation(query),
                expected,
                "unexpected classification for {:?}",
                query
            );
        }
    }

    /// Estimates `query` with a default-configured estimator.
    fn default_cost(query: &str) -> u32 {
        QueryCostEstimator::new().estimate_cost(query)
    }

    // --- operation detection -------------------------------------------

    #[test]
    fn test_detect_select() {
        assert_detects(
            OperationType::Select,
            &["SELECT * FROM users", "select id from users where id = 1"],
        );
    }

    #[test]
    fn test_detect_insert() {
        assert_detects(
            OperationType::Insert,
            &["INSERT INTO users (name) VALUES ('test')"],
        );
    }

    #[test]
    fn test_detect_update() {
        assert_detects(
            OperationType::Update,
            &["UPDATE users SET name = 'test' WHERE id = 1"],
        );
    }

    #[test]
    fn test_detect_delete() {
        assert_detects(OperationType::Delete, &["DELETE FROM users WHERE id = 1"]);
    }

    #[test]
    fn test_detect_ddl() {
        assert_detects(
            OperationType::Ddl,
            &[
                "CREATE TABLE test (id INT)",
                "ALTER TABLE users ADD COLUMN age INT",
                "DROP TABLE test",
            ],
        );
    }

    #[test]
    fn test_detect_transaction_control() {
        assert_detects(
            OperationType::TransactionControl,
            &["BEGIN", "COMMIT", "ROLLBACK", "START TRANSACTION"],
        );
    }

    #[test]
    fn test_detect_vector_search() {
        assert_detects(
            OperationType::VectorSearch,
            &[
                "SELECT * FROM docs ORDER BY embedding <-> '[1,2,3]'",
                "SELECT vector_search(embedding, query)",
            ],
        );
    }

    #[test]
    fn test_detect_administrative() {
        assert_detects(
            OperationType::Administrative,
            &["ANALYZE users", "VACUUM FULL users"],
        );
    }

    // --- cost estimation -----------------------------------------------

    #[test]
    fn test_estimate_cost_by_type() {
        let select_cost = default_cost("SELECT id FROM users WHERE id = 1");
        let insert_cost = default_cost("INSERT INTO users (name) VALUES ('test')");
        let update_cost = default_cost("UPDATE users SET name = 'test' WHERE id = 1");

        // Writes are expected to rank above reads, updates above inserts.
        assert!(insert_cost > select_cost);
        assert!(update_cost > insert_cost);
    }

    #[test]
    fn test_full_scan_detection() {
        let scan_cost = default_cost("SELECT * FROM users");
        let indexed_cost = default_cost("SELECT * FROM users WHERE id = 1");
        assert!(scan_cost > indexed_cost);
    }

    #[test]
    fn test_expensive_keywords() {
        let simple_cost = default_cost("SELECT id FROM users WHERE id = 1");
        let complex_cost =
            default_cost("SELECT COUNT(*) FROM users GROUP BY status ORDER BY status");
        assert!(complex_cost > simple_cost);
    }

    // --- cost hints ----------------------------------------------------

    #[test]
    fn test_extract_cost_hint() {
        let estimator = QueryCostEstimator::new();
        let cases = [
            ("/*helios:cost=10*/ SELECT * FROM users", Some(10)),
            ("/*cost:5*/ SELECT * FROM users", Some(5)),
            ("SELECT * FROM users", None),
        ];
        for (query, expected) in cases.iter() {
            assert_eq!(estimator.extract_cost_hint(query), *expected);
        }
    }

    #[test]
    fn test_estimate_cost_with_hint() {
        let estimator = QueryCostEstimator::new();

        // An explicit hint wins over the heuristic.
        let hint_cost =
            estimator.estimate_cost_with_hint("/*helios:cost=100*/ SELECT * FROM users");
        assert_eq!(hint_cost, 100);

        // Without a hint the heuristic estimate is used.
        let estimated_cost =
            estimator.estimate_cost_with_hint("SELECT * FROM users WHERE id = 1");
        assert!(estimated_cost < 100);
    }

    // --- customization -------------------------------------------------

    #[test]
    fn test_custom_operation_cost() {
        let estimator =
            QueryCostEstimator::new().with_operation_cost(OperationType::Select, 5.0);
        assert_eq!(
            estimator.estimate_cost("SELECT id FROM users WHERE id = 1"),
            5
        );
    }

    #[test]
    fn test_custom_pattern_cost() {
        let estimator = QueryCostEstimator::new().with_pattern_cost("EXPENSIVE_TABLE", 20);
        let cost = estimator.estimate_cost("SELECT * FROM EXPENSIVE_TABLE WHERE id = 1");
        assert!(cost > 20);
    }

    #[test]
    fn test_minimum_cost() {
        // Even a zero multiplier must not produce a zero cost.
        let estimator = QueryCostEstimator::new()
            .with_operation_cost(OperationType::TransactionControl, 0.0);
        assert!(estimator.estimate_cost("BEGIN") >= 1);
    }

    // --- statement shapes ----------------------------------------------

    #[test]
    fn test_with_query() {
        let estimator = QueryCostEstimator::new();
        let op = estimator.detect_operation(
            "WITH cte AS (SELECT * FROM users) SELECT * FROM cte WHERE id = 1",
        );
        assert_eq!(op, OperationType::Select);
    }

    #[test]
    fn test_delete_without_where() {
        let cost_without_where = default_cost("DELETE FROM temp_table");
        let cost_with_where = default_cost("DELETE FROM temp_table WHERE id = 1");
        assert!(cost_without_where > cost_with_where);
    }
}