use super::*;
use crate::core::{Check, ConstraintResult, ConstraintStatus, Level};
use crate::prelude::*;
use async_trait::async_trait;
use datafusion::prelude::SessionContext;
use std::sync::Arc;
/// Test double for `crate::core::Constraint`.
///
/// Its trait implementation in this file returns a clone of `result` from
/// `evaluate` and `name` from `name()`, so each test decides up front what
/// the constraint reports.
#[derive(Debug, Clone)]
struct MockConstraint {
// Reported by `name()`; the analyzer tests set it to labels such as
// "completeness" or "min".
name: String,
// Canned outcome handed back (cloned) by every `evaluate` call.
result: ConstraintResult,
// SQL text supplied by the tests. Nothing visible in this file reads it;
// the leading underscore keeps the unused-field lint quiet.
_sql: String,
}
#[async_trait]
impl crate::core::Constraint for MockConstraint {
    /// Hands back a clone of the canned `result`; the session context is
    /// never consulted.
    async fn evaluate(&self, _ctx: &SessionContext) -> Result<ConstraintResult> {
        let canned = self.result.clone();
        Ok(canned)
    }

    /// Returns the name this mock was constructed with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the same fixed description for every mock instance.
    fn description(&self) -> Option<&str> {
        let text = "Mock constraint for testing";
        Some(text)
    }
}
#[cfg(test)]
mod optimizer_tests {
    use super::*;

    /// Smoke test: a `QueryOptimizer` can be built and exposes its four
    /// component fields.
    ///
    /// Every pointer comparison checks a field against itself, so the
    /// assertions themselves cannot fail; the test only guards construction
    /// and field access.
    #[tokio::test]
    async fn test_optimizer_creation() {
        let optimizer = QueryOptimizer::new();

        let analyzer = &optimizer.analyzer;
        assert!(std::ptr::eq(analyzer, analyzer));

        let combiner = &optimizer.combiner;
        assert!(std::ptr::eq(combiner, combiner));

        let executor = &optimizer.executor;
        assert!(std::ptr::eq(executor, executor));

        let stats_cache = &optimizer.stats_cache;
        assert!(std::ptr::eq(stats_cache, stats_cache));
    }

    /// Two constraints registered on one check must be extracted in
    /// registration order, each keyed as `<check name>.<constraint name>`.
    #[tokio::test]
    async fn test_extract_constraints() {
        let optimizer = QueryOptimizer::new();

        // Both mocks report success and differ only in the metric value.
        let succeeded = |metric| ConstraintResult {
            status: ConstraintStatus::Success,
            metric: Some(metric),
            message: None,
        };

        let first = MockConstraint {
            name: String::from("test_constraint_1"),
            result: succeeded(1.0),
            _sql: String::from("SELECT COUNT(*) FROM data"),
        };
        let second = MockConstraint {
            name: String::from("test_constraint_2"),
            result: succeeded(0.95),
            _sql: String::from("SELECT COUNT(*), COUNT(column) FROM data"),
        };

        let checks = vec![Check::builder("test_check")
            .level(Level::Error)
            .constraint(first)
            .constraint(second)
            .build()];

        let extracted = optimizer.extract_constraints(&checks);

        let expected_names = [
            "test_check.test_constraint_1",
            "test_check.test_constraint_2",
        ];
        assert_eq!(extracted.len(), expected_names.len());
        for (position, expected) in expected_names.iter().enumerate() {
            assert_eq!(extracted[position].0, *expected);
        }
    }
}
#[cfg(test)]
mod analyzer_tests {
    use super::*;
    use crate::optimizer::analyzer::AggregationType;

    /// Builds a shared mock that reports success with a metric of `1.0`.
    fn passing_mock(name: &str, sql: &str) -> Arc<MockConstraint> {
        Arc::new(MockConstraint {
            name: name.to_owned(),
            result: ConstraintResult {
                status: ConstraintStatus::Success,
                metric: Some(1.0),
                message: None,
            },
            _sql: sql.to_owned(),
        })
    }

    /// Analyzes the same constraint list twice on one `QueryAnalyzer`.
    ///
    /// Both runs must yield exactly one entry; only the first run's contents
    /// (name, table, aggregations) are inspected.
    #[test]
    fn test_analyzer_cache() {
        let mut analyzer = QueryAnalyzer::new();

        let mock = passing_mock("completeness", "SELECT COUNT(*) FROM data");
        // The concrete handle stays alive next to the trait-object handle for
        // the whole test.
        let as_dyn: Arc<dyn crate::core::Constraint> = mock.clone();
        let constraints = vec![(String::from("test1"), as_dyn)];

        let analysis1 = analyzer.analyze(&constraints).unwrap();
        assert_eq!(analysis1.len(), 1);

        let analysis2 = analyzer.analyze(&constraints).unwrap();
        assert_eq!(analysis2.len(), 1);

        let first = &analysis1[0];
        assert_eq!(first.name, "test1");
        assert_eq!(first.table_name, "data");
        assert!(first
            .aggregations
            .iter()
            .any(|agg| *agg == AggregationType::Count));
    }

    /// Table-driven check that each constraint name maps to the expected
    /// aggregation list.
    ///
    /// The mock's `name` is the only thing that varies between cases; its SQL
    /// text is left empty.
    #[test]
    fn test_aggregation_type_detection() {
        let analyzer = QueryAnalyzer::new();

        let expectations = vec![
            ("completeness", vec![AggregationType::Count]),
            (
                "uniqueness",
                vec![AggregationType::Count, AggregationType::CountDistinct],
            ),
            ("min", vec![AggregationType::Min]),
            ("max", vec![AggregationType::Max]),
            ("mean", vec![AggregationType::Avg]),
            ("sum", vec![AggregationType::Sum]),
            ("standard_deviation", vec![AggregationType::StdDev]),
        ];

        expectations.into_iter().for_each(|(kind, wanted)| {
            let subject: Arc<dyn crate::core::Constraint> = passing_mock(kind, "");
            let analysis = analyzer
                .analyze_constraint(String::from("test"), subject)
                .unwrap();
            assert_eq!(analysis.aggregations, wanted);
        });
    }
}
#[cfg(test)]
mod combiner_tests {
    use super::*;
    use crate::optimizer::analyzer::ConstraintAnalysis;

    /// Builds a `ConstraintAnalysis` fixture for `table` with a single `Count`
    /// aggregation, no columns and no predicates.
    ///
    /// `name` labels both the analysis and its backing mock; `combinable` is
    /// stored as `is_combinable`.
    fn create_test_analysis(name: &str, table: &str, combinable: bool) -> ConstraintAnalysis {
        let mock = MockConstraint {
            name: name.to_owned(),
            result: ConstraintResult {
                status: ConstraintStatus::Success,
                metric: Some(1.0),
                message: None,
            },
            _sql: String::new(),
        };

        ConstraintAnalysis {
            name: name.to_owned(),
            constraint: Arc::new(mock),
            table_name: table.to_owned(),
            aggregations: vec![crate::optimizer::analyzer::AggregationType::Count],
            columns: Vec::new(),
            has_predicates: false,
            is_combinable: combinable,
        }
    }

    /// Analyses spanning two tables must produce at least two groups.
    #[test]
    fn test_group_by_table() {
        let combiner = QueryCombiner::new();

        let layout = [("c1", "data"), ("c2", "data"), ("c3", "other_table")];
        let analyses: Vec<ConstraintAnalysis> = layout
            .iter()
            .map(|&(name, table)| create_test_analysis(name, table, true))
            .collect();

        let groups = combiner.group_constraints(analyses).unwrap();
        assert!(groups.len() >= 2);
    }

    /// With a maximum group size of 2, five analyses on one table must be
    /// split into at least three groups, none holding more than two
    /// constraints.
    #[test]
    fn test_max_group_size() {
        let mut combiner = QueryCombiner::new();
        combiner.set_max_group_size(2);

        let mut analyses = Vec::with_capacity(5);
        for index in 0..5 {
            let label = format!("c{index}");
            analyses.push(create_test_analysis(&label, "data", true));
        }

        let groups = combiner.group_constraints(analyses).unwrap();
        assert!(groups.len() >= 3);
        assert!(groups.iter().all(|group| group.constraints.len() <= 2));
    }

    /// A constraint flagged as not combinable ("c2") must end up in a group
    /// of its own.
    #[test]
    fn test_non_combinable_constraints() {
        let combiner = QueryCombiner::new();

        let analyses = vec![
            create_test_analysis("c1", "data", true),
            create_test_analysis("c2", "data", false),
            create_test_analysis("c3", "data", true),
        ];

        let groups = combiner.group_constraints(analyses).unwrap();
        let isolated = groups
            .iter()
            .any(|group| group.constraints.len() == 1 && group.constraints[0].name == "c2");
        assert!(isolated);
    }
}
#[cfg(test)]
mod executor_tests {
    use super::*;

    /// A new executor starts with pushdown enabled, and
    /// `set_pushdown_enabled` must be reflected in the `enable_pushdown`
    /// field when switching off and back on.
    #[test]
    fn test_predicate_pushdown_toggle() {
        let mut executor = OptimizedExecutor::new();
        assert!(executor.enable_pushdown);

        // Switch off first, then back on, checking the field after each call.
        for desired in [false, true] {
            executor.set_pushdown_enabled(desired);
            assert_eq!(executor.enable_pushdown, desired);
        }
    }
}
#[cfg(test)]
mod stats_cache_tests {
    use super::*;
    use std::thread::sleep;
    use std::time::Duration;

    /// A stored value is returned by `get`, storing under the same key again
    /// replaces it, and an unknown key yields `None`.
    #[test]
    fn test_cache_operations() {
        let mut cache = StatsCache::new();
        let key = "key1";

        for value in [42.0, 100.0] {
            cache.set(key.to_string(), value);
            assert_eq!(cache.get(key), Some(value));
        }

        assert_eq!(cache.get("non_existent"), None);
    }

    /// An entry is readable right after insertion and gone once the 50 ms
    /// configured duration has elapsed.
    ///
    /// Relies on a real 60 ms sleep.
    #[test]
    fn test_cache_expiration() {
        let ttl = Duration::from_millis(50);
        let mut cache = StatsCache::with_config(ttl, 10);

        cache.set(String::from("key1"), 42.0);
        assert_eq!(cache.get("key1"), Some(42.0));

        sleep(ttl + Duration::from_millis(10));
        assert_eq!(cache.get("key1"), None);
    }

    /// With the second `with_config` argument set to 2, inserting a third key
    /// must leave the size at 2, with "key1" (the first key inserted) no
    /// longer retrievable.
    #[test]
    fn test_cache_eviction() {
        let mut cache = StatsCache::with_config(Duration::from_secs(60), 2);

        for (key, value) in [("key1", 1.0), ("key2", 2.0)] {
            cache.set(key.to_string(), value);
        }
        assert_eq!(cache.size(), 2);

        cache.set(String::from("key3"), 3.0);
        assert_eq!(cache.size(), 2);

        // Lookups happen in insertion order, exactly one `get` per key.
        for (key, expected) in [("key1", None), ("key2", Some(2.0)), ("key3", Some(3.0))] {
            assert_eq!(cache.get(key), expected);
        }
    }

    /// `stats()` must report two active entries before the 100 ms duration
    /// elapses and two expired entries afterwards, with the total unchanged.
    ///
    /// Relies on a real 110 ms sleep.
    #[test]
    fn test_cache_stats() {
        let ttl = Duration::from_millis(100);
        let mut cache = StatsCache::with_config(ttl, 10);

        cache.set(String::from("key1"), 1.0);
        cache.set(String::from("key2"), 2.0);

        let before = cache.stats();
        assert_eq!(
            (
                before.total_entries,
                before.active_entries,
                before.expired_entries
            ),
            (2, 2, 0)
        );

        sleep(ttl + Duration::from_millis(10));

        let after = cache.stats();
        assert_eq!(
            (
                after.total_entries,
                after.active_entries,
                after.expired_entries
            ),
            (2, 0, 2)
        );
    }

    /// `remove_expired` must empty the cache once both entries have outlived
    /// the 50 ms configured duration.
    ///
    /// Relies on a real 60 ms sleep.
    #[test]
    fn test_remove_expired() {
        let ttl = Duration::from_millis(50);
        let mut cache = StatsCache::with_config(ttl, 10);

        for (key, value) in [("key1", 1.0), ("key2", 2.0)] {
            cache.set(key.to_string(), value);
        }
        assert_eq!(cache.size(), 2);

        sleep(ttl + Duration::from_millis(10));
        cache.remove_expired();
        assert_eq!(cache.size(), 0);
    }
}