use super::CommandResult;
use crate::cli::CliContext;
use colored::Colorize;
use scirs2_core::ndarray_ext::{Array1, Array2};
use std::collections::HashMap;
/// Numeric features extracted from raw SPARQL query text via lightweight
/// keyword/character heuristics (no parsing). All fields are `f64` so the
/// struct maps directly onto the model's feature vector (see `to_array`).
#[derive(Debug, Clone)]
pub struct QueryFeatures {
    /// Approximate triple-pattern count ('.'/';' separators, min 1 with WHERE).
    pub triple_patterns: f64,
    /// Number of OPTIONAL keywords.
    pub optional_count: f64,
    /// Number of UNION keywords.
    pub union_count: f64,
    /// Number of FILTER keywords.
    pub filter_count: f64,
    /// Number of "ORDER BY" occurrences.
    pub order_by_count: f64,
    /// Number of "GROUP BY" occurrences.
    pub group_by_count: f64,
    /// Heuristic selectivity score in {0.2, 0.5, 0.6, 0.9}; higher = more selective.
    pub selectivity: f64,
    /// 1.0 when the query contains LIMIT, else 0.0.
    pub has_limit: f64,
    /// 1.0 when the query contains DISTINCT, else 0.0.
    pub has_distinct: f64,
    /// SELECT keywords beyond the outermost one.
    pub subquery_count: f64,
    /// Property-path score from '/', '+', '*' characters, capped at 100.
    pub property_path_complexity: f64,
    /// Total COUNT(/SUM(/AVG(/MAX(/MIN( occurrences.
    pub aggregation_count: f64,
}
impl QueryFeatures {
    /// Build a feature set from raw query text by counting keywords and
    /// separator characters — a cheap approximation that avoids a full
    /// SPARQL parse.
    pub fn extract_from_query(query: &str) -> Self {
        let upper = query.to_uppercase();
        // Counter over the upper-cased text (keywords are given upper-cased).
        let count_kw = |kw: &str| upper.matches(kw).count() as f64;

        // Triple patterns are estimated from '.'/';' separators; a WHERE
        // clause without any separators still implies a single pattern.
        let dots = query.matches('.').count();
        let semis = query.matches(';').count();
        let has_where = upper.contains("WHERE");
        let triple_patterns = if has_where && dots == 0 && semis == 0 {
            1.0
        } else {
            (dots + semis).max(usize::from(has_where)) as f64
        };

        // Every SELECT beyond the outermost one is counted as a subquery.
        let subqueries = upper.matches("SELECT").count().saturating_sub(1) as f64;

        // Boolean flags encoded as 0.0/1.0 so they fit the f64 vector.
        let has_limit = if upper.contains("LIMIT") { 1.0 } else { 0.0 };
        let has_distinct = if upper.contains("DISTINCT") { 1.0 } else { 0.0 };

        let aggregations = ["COUNT(", "SUM(", "AVG(", "MAX(", "MIN("]
            .iter()
            .copied()
            .map(count_kw)
            .sum::<f64>();

        // Concrete URIs and FILTERs both narrow the result set; the pair
        // selects a coarse selectivity score.
        let has_uris = query.contains("http://") || query.contains("https://");
        let has_filters = upper.contains("FILTER");
        let selectivity = match (has_uris, has_filters) {
            (true, true) => 0.9,
            (true, false) => 0.6,
            (false, true) => 0.5,
            (false, false) => 0.2,
        };

        // Property paths: '/' and '+' score 1, '*' scores 2, scaled by 10
        // and capped at 100 when the struct is assembled below.
        let path_score = (query.matches('/').count()
            + query.matches('+').count()
            + query.matches('*').count() * 2) as f64
            * 10.0;

        Self {
            triple_patterns,
            optional_count: count_kw("OPTIONAL"),
            union_count: count_kw("UNION"),
            filter_count: count_kw("FILTER"),
            order_by_count: count_kw("ORDER BY"),
            group_by_count: count_kw("GROUP BY"),
            selectivity,
            has_limit,
            has_distinct,
            subquery_count: subqueries,
            property_path_complexity: path_score.min(100.0),
            aggregation_count: aggregations,
        }
    }

    /// Flatten the features into the 12-element vector consumed by the
    /// model; property-path complexity is normalised to [0, 1].
    pub fn to_array(&self) -> Array1<f64> {
        let values = vec![
            self.triple_patterns,
            self.optional_count,
            self.union_count,
            self.filter_count,
            self.order_by_count,
            self.group_by_count,
            self.selectivity,
            self.has_limit,
            self.has_distinct,
            self.subquery_count,
            self.property_path_complexity / 100.0,
            self.aggregation_count,
        ];
        Array1::from(values)
    }

    /// Human-readable feature names, index-aligned with `to_array`.
    pub fn feature_names() -> Vec<&'static str> {
        [
            "Triple Patterns",
            "OPTIONAL Clauses",
            "UNION Clauses",
            "FILTER Expressions",
            "ORDER BY",
            "GROUP BY",
            "Selectivity",
            "Has LIMIT",
            "Has DISTINCT",
            "Subqueries",
            "Property Path Complexity",
            "Aggregations",
        ]
        .to_vec()
    }
}
/// Result of a performance prediction for a single query.
#[derive(Debug, Clone)]
pub struct PerformancePrediction {
    /// Point estimate of execution time in milliseconds.
    pub predicted_time_ms: f64,
    /// Lower bound of the confidence band, in ms (floored at 0.1).
    pub confidence_lower: f64,
    /// Upper bound of the confidence band, in ms.
    pub confidence_upper: f64,
    /// Confidence in the prediction, in [0, 1].
    pub confidence: f64,
    /// Coarse speed bucket derived from the predicted time.
    pub category: PerformanceCategory,
    /// (feature name, contribution in ms) pairs, sorted by absolute impact.
    pub contributing_factors: Vec<(String, f64)>,
}
/// Coarse speed bucket for a predicted execution time
/// (see `from_time_ms` for the thresholds).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PerformanceCategory {
    /// Under 100 ms.
    Fast,
    /// 100 ms up to 1 s.
    Medium,
    /// 1 s up to 10 s.
    Slow,
    /// 10 s or more.
    VerySlow,
}
impl PerformanceCategory {
    /// Bucket a predicted time (ms) using decade thresholds: 100ms, 1s, 10s.
    /// Non-finite inputs fall through to `VerySlow`.
    fn from_time_ms(time_ms: f64) -> Self {
        match time_ms {
            t if t < 100.0 => Self::Fast,
            t if t < 1000.0 => Self::Medium,
            t if t < 10000.0 => Self::Slow,
            _ => Self::VerySlow,
        }
    }

    /// Emoji shown next to the estimate in terminal output.
    fn emoji(&self) -> &str {
        match self {
            Self::Fast => "🚀",
            Self::Medium => "⚡",
            Self::Slow => "🐌",
            Self::VerySlow => "🐢",
        }
    }

    /// ANSI-colored label for terminal output (severity-coded colors).
    fn colored_label(&self) -> String {
        match self {
            Self::Fast => "FAST".green().bold().to_string(),
            Self::Medium => "MEDIUM".yellow().bold().to_string(),
            Self::Slow => "SLOW".red().to_string(),
            Self::VerySlow => "VERY SLOW".red().bold().to_string(),
        }
    }
}
/// Predicts SPARQL query execution time, either from a trained
/// correlation-based linear model or from built-in heuristics when
/// no model has been fitted yet.
pub struct QueryPerformancePredictor {
    /// Per-feature weights; `None` until `train` succeeds.
    coefficients: Option<Array1<f64>>,
    /// Model intercept (mean observed time); `None` until trained.
    intercept: Option<f64>,
    /// Observed (features, execution time in ms) samples.
    training_data: Vec<(QueryFeatures, f64)>,
    /// Absolute coefficient magnitude keyed by feature name.
    feature_importance: HashMap<String, f64>,
    /// CLI output helper for info/warn/success messages.
    ctx: CliContext,
}
impl QueryPerformancePredictor {
    /// Create an untrained predictor; predictions fall back to the
    /// heuristic model until `train` succeeds.
    pub fn new() -> Self {
        Self {
            coefficients: None,
            intercept: None,
            training_data: Vec::new(),
            feature_importance: HashMap::new(),
            ctx: CliContext::new(),
        }
    }

    /// Record one observed (features, execution time in ms) sample.
    pub fn add_training_data(&mut self, features: QueryFeatures, execution_time_ms: f64) {
        self.training_data.push((features, execution_time_ms));
    }

    /// Fit a simple linear model from the accumulated samples.
    ///
    /// Note this is not a least-squares regression: each coefficient is the
    /// Pearson correlation of that feature with execution time, scaled by a
    /// fixed factor, and the intercept is the mean observed time.
    ///
    /// # Errors
    /// Returns an error when no training data is available or the feature
    /// matrix cannot be assembled.
    pub fn train(&mut self) -> Result<(), String> {
        if self.training_data.is_empty() {
            return Err("No training data available".to_string());
        }
        if self.training_data.len() < 10 {
            self.ctx.warn(&format!(
                "Limited training data ({} samples) - predictions may be less accurate",
                self.training_data.len()
            ));
        }
        let n_samples = self.training_data.len();
        // Derive the width from the canonical feature list so it cannot
        // drift out of sync with QueryFeatures::to_array.
        let feature_names = QueryFeatures::feature_names();
        let n_features = feature_names.len();
        let mut x_data = Vec::with_capacity(n_samples * n_features);
        let mut y_data = Vec::with_capacity(n_samples);
        for (features, time) in &self.training_data {
            x_data.extend(features.to_array().iter().copied());
            y_data.push(*time);
        }
        let x = Array2::from_shape_vec((n_samples, n_features), x_data)
            .map_err(|e| format!("Failed to create feature matrix: {}", e))?;
        let y = Array1::from_vec(y_data);
        let y_mean = y.mean().ok_or("Failed to calculate y mean")?;
        let mut coefficients = Array1::zeros(n_features);
        for i in 0..n_features {
            // Correlation-based coefficient: the sign gives the direction of
            // influence, the fixed scale turns the unitless correlation into
            // a millisecond weight.
            let correlation = calculate_correlation(&x.column(i).to_owned(), &y);
            coefficients[i] = correlation * 100.0;
        }
        self.coefficients = Some(coefficients);
        self.intercept = Some(y_mean);
        // Feature importance is the absolute coefficient magnitude.
        self.feature_importance.clear();
        if let Some(coefficients) = self.coefficients.as_ref() {
            for (i, name) in feature_names.iter().enumerate() {
                self.feature_importance
                    .insert(name.to_string(), coefficients[i].abs());
            }
        }
        self.ctx.success(&format!(
            "Model trained successfully on {} samples",
            n_samples
        ));
        Ok(())
    }

    /// Predict execution time for `query`, using the trained linear model
    /// when available and the heuristic model otherwise.
    pub fn predict(&mut self, query: &str) -> Result<PerformancePrediction, String> {
        // Both Options are set together by `train`; fall back to heuristics
        // whenever the model has not been fitted. Cloning the small
        // coefficient vector sidesteps a simultaneous &/&mut self borrow.
        let (coefficients, intercept) = match (self.coefficients.clone(), self.intercept) {
            (Some(c), Some(i)) => (c, i),
            _ => return self.predict_heuristic(query),
        };
        let features = QueryFeatures::extract_from_query(query);
        let feature_array = features.to_array();
        // Linear combination, floored at 1ms to avoid nonsensical estimates.
        let mut predicted_time_ms = intercept;
        for (coef, &feature) in coefficients.iter().zip(feature_array.iter()) {
            predicted_time_ms += coef * feature;
        }
        predicted_time_ms = predicted_time_ms.max(1.0);
        // Fixed ±30% band stands in for a real confidence interval.
        let confidence_margin = predicted_time_ms * 0.3;
        let confidence_lower = (predicted_time_ms - confidence_margin).max(0.1);
        let confidence_upper = predicted_time_ms + confidence_margin;
        // Confidence grows with the amount of training data seen.
        let confidence = match self.training_data.len() {
            n if n > 100 => 0.9,
            n if n > 50 => 0.75,
            n if n > 10 => 0.6,
            _ => 0.4,
        };
        let category = PerformanceCategory::from_time_ms(predicted_time_ms);
        // Report features whose weighted contribution is non-negligible,
        // sorted by absolute impact (descending).
        let mut contributing_factors = Vec::new();
        for (i, name) in QueryFeatures::feature_names().iter().enumerate() {
            let contribution = coefficients[i] * feature_array[i];
            if contribution.abs() > 0.1 {
                contributing_factors.push((name.to_string(), contribution));
            }
        }
        contributing_factors.sort_by(|a, b| {
            b.1.abs()
                .partial_cmp(&a.1.abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(PerformancePrediction {
            predicted_time_ms,
            confidence_lower,
            confidence_upper,
            confidence,
            category,
            contributing_factors,
        })
    }

    /// Fallback estimator used when no model has been trained: a weighted
    /// sum of feature counts with hand-tuned per-construct costs (ms).
    fn predict_heuristic(&mut self, query: &str) -> Result<PerformancePrediction, String> {
        let features = QueryFeatures::extract_from_query(query);
        // Base cost plus per-construct penalties.
        let mut base_time = 10.0;
        base_time += features.triple_patterns * 5.0;
        base_time += features.optional_count * 15.0;
        base_time += features.union_count * 12.0;
        base_time += features.filter_count * 8.0;
        base_time += features.order_by_count * 20.0;
        base_time += features.group_by_count * 25.0;
        base_time += features.subquery_count * 30.0;
        base_time += features.property_path_complexity * 0.5;
        base_time += features.aggregation_count * 10.0;
        // Low selectivity inflates the estimate (multiplier in [1.1, 1.8]).
        base_time *= 2.0 - features.selectivity;
        // LIMIT trims work; DISTINCT adds deduplication overhead.
        if features.has_limit > 0.5 {
            base_time *= 0.7;
        }
        if features.has_distinct > 0.5 {
            base_time *= 1.3;
        }
        let predicted_time_ms = base_time;
        // Wider ±50% band than the trained model, reflecting lower trust.
        let confidence_margin = predicted_time_ms * 0.5;
        let confidence_lower = (predicted_time_ms - confidence_margin).max(0.1);
        let confidence_upper = predicted_time_ms + confidence_margin;
        let category = PerformanceCategory::from_time_ms(predicted_time_ms);
        // Only the three dominant heuristic terms are surfaced as factors,
        // and only when they contribute more than 5ms.
        let mut contributing_factors = vec![
            (
                "Triple Patterns".to_string(),
                features.triple_patterns * 5.0,
            ),
            (
                "OPTIONAL Clauses".to_string(),
                features.optional_count * 15.0,
            ),
            ("Subqueries".to_string(), features.subquery_count * 30.0),
        ];
        contributing_factors.retain(|(_, contrib)| contrib.abs() > 5.0);
        contributing_factors.sort_by(|a, b| {
            b.1.abs()
                .partial_cmp(&a.1.abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        Ok(PerformancePrediction {
            predicted_time_ms,
            confidence_lower,
            confidence_upper,
            confidence: 0.5,
            category,
            contributing_factors,
        })
    }

    /// Pretty-print a prediction: query analysis, time estimate with
    /// confidence band, top contributing factors, and — for queries
    /// predicted Slow/VerySlow — actionable recommendations.
    pub fn display_prediction(
        &self,
        prediction: &PerformancePrediction,
        query: &str,
    ) -> CommandResult {
        self.ctx.info("\n🔮 Query Performance Prediction:\n");
        println!("{}", "Query Analysis:".bold().underline());
        let features = QueryFeatures::extract_from_query(query);
        println!(" Triple Patterns: {}", features.triple_patterns);
        println!(" Complexity Indicators:");
        if features.optional_count > 0.0 {
            println!(" - {} OPTIONAL clauses", features.optional_count);
        }
        if features.union_count > 0.0 {
            println!(" - {} UNION clauses", features.union_count);
        }
        if features.filter_count > 0.0 {
            println!(" - {} FILTER expressions", features.filter_count);
        }
        if features.subquery_count > 0.0 {
            println!(" - {} subqueries", features.subquery_count);
        }
        println!();
        println!("{}", "Performance Prediction:".bold().underline());
        println!(
            " {} Estimated Execution Time: {}{:.2}ms{}",
            prediction.category.emoji(),
            prediction.category.colored_label(),
            prediction.predicted_time_ms,
            "".normal()
        );
        println!(
            " 95% Confidence Interval: [{:.2}ms - {:.2}ms]",
            prediction.confidence_lower, prediction.confidence_upper
        );
        println!(
            " Prediction Confidence: {:.0}%",
            prediction.confidence * 100.0
        );
        println!();
        if !prediction.contributing_factors.is_empty() {
            println!("{}", "Top Contributing Factors:".bold());
            for (i, (factor, contribution)) in
                prediction.contributing_factors.iter().take(5).enumerate()
            {
                println!(" {}. {} ({:+.1}ms)", i + 1, factor, contribution);
            }
            println!();
        }
        // Recommendations only appear for queries predicted to be slow.
        if matches!(
            prediction.category,
            PerformanceCategory::Slow | PerformanceCategory::VerySlow
        ) {
            println!("{}", "⚠️ Performance Recommendations:".yellow().bold());
            if features.has_limit < 0.5 {
                println!(" • Add LIMIT clause to reduce result set size");
            }
            if features.optional_count > 3.0 {
                println!(" • Consider reducing OPTIONAL clauses or restructuring query");
            }
            if features.selectivity < 0.5 {
                println!(" • Add more specific filters to improve selectivity");
            }
            if features.subquery_count > 2.0 {
                println!(" • Consider flattening nested subqueries");
            }
            println!();
        }
        if prediction.confidence < 0.6 {
            self.ctx.info(&format!(
                "ℹ️ Note: Prediction confidence is {}. Consider training model with more data.",
                if prediction.confidence < 0.5 {
                    "low"
                } else {
                    "moderate"
                }
            ));
        }
        Ok(())
    }
}
impl Default for QueryPerformancePredictor {
fn default() -> Self {
Self::new()
}
}
/// Pearson correlation coefficient between two equal-length vectors.
///
/// Returns 0.0 for mismatched lengths, empty input, or when either vector
/// is (numerically) constant, where the coefficient is undefined.
fn calculate_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
    if x.len() != y.len() || x.is_empty() {
        return 0.0;
    }
    let x_mean = x.mean().unwrap_or(0.0);
    let y_mean = y.mean().unwrap_or(0.0);
    let mut numerator = 0.0;
    let mut x_var = 0.0;
    let mut y_var = 0.0;
    // Single pass over paired deviations; zip iteration avoids per-element
    // bounds checks that the indexed loop paid.
    for (&xi, &yi) in x.iter().zip(y.iter()) {
        let x_diff = xi - x_mean;
        let y_diff = yi - y_mean;
        numerator += x_diff * y_diff;
        x_var += x_diff * x_diff;
        y_var += y_diff * y_diff;
    }
    // Guard against division by (near-)zero variance.
    if x_var < 1e-10 || y_var < 1e-10 {
        return 0.0;
    }
    numerator / (x_var * y_var).sqrt()
}
/// CLI entry point: predict the performance of `query` and print the result.
///
/// `train_data` optionally names a training-data file; loading it is not yet
/// implemented, so the heuristic model is always used here.
pub async fn predict_query_performance_cmd(
    query: String,
    train_data: Option<String>,
) -> CommandResult {
    let ctx = CliContext::new();
    ctx.info("🔮 Predicting query performance...\n");
    let mut engine = QueryPerformancePredictor::new();
    if let Some(path) = train_data {
        ctx.info(&format!("Loading training data from: {}", path));
        ctx.warn("Training data loading not yet implemented - using heuristic model");
    }
    let forecast = engine.predict(&query)?;
    engine.display_prediction(&forecast, &query)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal query yields at least one triple pattern and sets the
    // LIMIT flag, with no OPTIONAL clauses counted.
    #[test]
    fn test_feature_extraction_simple() {
        let query = "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 10";
        let features = QueryFeatures::extract_from_query(query);
        assert!(features.triple_patterns > 0.0);
        assert_eq!(features.optional_count, 0.0);
        assert_eq!(features.has_limit, 1.0);
    }

    // Extraction picks up OPTIONAL, FILTER, ORDER BY, DISTINCT and LIMIT
    // markers from a multi-line query.
    #[test]
    fn test_feature_extraction_complex() {
        let query = r#"
            SELECT DISTINCT ?person ?name
            WHERE {
                ?person foaf:name ?name .
                OPTIONAL { ?person foaf:age ?age }
                FILTER(?age > 18)
            }
            ORDER BY ?name
            LIMIT 100
        "#;
        let features = QueryFeatures::extract_from_query(query);
        assert!(features.triple_patterns > 0.0);
        assert_eq!(features.optional_count, 1.0);
        assert_eq!(features.filter_count, 1.0);
        assert_eq!(features.order_by_count, 1.0);
        assert_eq!(features.has_distinct, 1.0);
        assert_eq!(features.has_limit, 1.0);
    }

    // The flattened feature vector always has 12 entries.
    #[test]
    fn test_feature_to_array() {
        let query = "SELECT ?s WHERE { ?s ?p ?o }";
        let features = QueryFeatures::extract_from_query(query);
        let array = features.to_array();
        assert_eq!(array.len(), 12);
    }

    // Threshold boundaries: <100ms Fast, <1s Medium, <10s Slow, else VerySlow.
    #[test]
    fn test_performance_category() {
        assert_eq!(
            PerformanceCategory::from_time_ms(50.0),
            PerformanceCategory::Fast
        );
        assert_eq!(
            PerformanceCategory::from_time_ms(500.0),
            PerformanceCategory::Medium
        );
        assert_eq!(
            PerformanceCategory::from_time_ms(5000.0),
            PerformanceCategory::Slow
        );
        assert_eq!(
            PerformanceCategory::from_time_ms(15000.0),
            PerformanceCategory::VerySlow
        );
    }

    // A fresh predictor is untrained and holds no samples.
    #[test]
    fn test_predictor_creation() {
        let predictor = QueryPerformancePredictor::new();
        assert!(predictor.coefficients.is_none());
        assert!(predictor.training_data.is_empty());
    }

    // Untrained predictor falls back to heuristics and produces a positive
    // estimate bracketed by its confidence bounds.
    #[test]
    fn test_heuristic_prediction_simple() {
        let mut predictor = QueryPerformancePredictor::new();
        let query = "SELECT ?s WHERE { ?s ?p ?o } LIMIT 10";
        let result = predictor.predict(query);
        assert!(result.is_ok());
        let prediction = result.unwrap();
        assert!(prediction.predicted_time_ms > 0.0);
        assert!(prediction.confidence_lower > 0.0);
        assert!(prediction.confidence_upper > prediction.predicted_time_ms);
    }

    // A query with many costly constructs should estimate well above the
    // 10ms heuristic base cost.
    #[test]
    fn test_heuristic_prediction_complex() {
        let mut predictor = QueryPerformancePredictor::new();
        let query = r#"
            SELECT DISTINCT ?x
            WHERE {
                ?x ?p1 ?y .
                OPTIONAL { ?y ?p2 ?z }
                UNION { ?x ?p3 ?w }
                FILTER(?x != ?y)
            }
            ORDER BY ?x
        "#;
        let result = predictor.predict(query);
        assert!(result.is_ok());
        let prediction = result.unwrap();
        assert!(prediction.predicted_time_ms > 50.0);
    }

    // Samples accumulate in training_data.
    #[test]
    fn test_training_data_addition() {
        let mut predictor = QueryPerformancePredictor::new();
        let features = QueryFeatures::extract_from_query("SELECT ?s WHERE { ?s ?p ?o }");
        predictor.add_training_data(features, 25.0);
        assert_eq!(predictor.training_data.len(), 1);
    }

    // Training with zero samples is an error, not a panic.
    #[test]
    fn test_model_training_insufficient_data() {
        let mut predictor = QueryPerformancePredictor::new();
        let result = predictor.train();
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("No training data"));
    }

    // Perfectly linearly related vectors correlate at ~1.0.
    #[test]
    fn test_correlation_calculation() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);
        let corr = calculate_correlation(&x, &y);
        assert!((corr - 1.0).abs() < 0.01);
    }

    // The interval brackets the point estimate and confidence is in [0, 1].
    #[test]
    fn test_prediction_confidence_intervals() {
        let mut predictor = QueryPerformancePredictor::new();
        let query = "SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 100";
        let prediction = predictor.predict(query).unwrap();
        assert!(prediction.confidence_lower < prediction.predicted_time_ms);
        assert!(prediction.confidence_upper > prediction.predicted_time_ms);
        assert!(prediction.confidence >= 0.0 && prediction.confidence <= 1.0);
    }
}