1#![allow(dead_code)]
7
8use crate::indexing::IndexStats;
9use crate::model::Variable;
10use crate::query::algebra::{
11 AlgebraTriplePattern, GraphPattern, Query as AlgebraQuery, QueryForm, TermPattern,
12};
13use crate::query::plan::{ExecutionPlan, QueryPlanner};
14use crate::OxirsError;
15use std::collections::{HashMap, VecDeque};
16use std::sync::{Arc, RwLock};
17use std::time::{Duration, Instant};
18
/// Cost model consulted by `AIQueryOptimizer` when ranking candidate plans.
///
/// Every field is held behind an `Arc`, so the derived `Clone` yields another
/// handle onto the same history and parameters, not an independent copy.
#[derive(Debug, Clone)]
pub struct CostModel {
    /// Log of past executions; read by `get_history_factor` and appended to
    /// by `update_learning_model`.
    execution_history: Arc<RwLock<QueryHistory>>,
    /// Learned cost parameters; read by `estimate_cost`.
    learned_parameters: Arc<RwLock<LearnedParameters>>,
    /// Index statistics handed in at construction. Not read by any code
    /// visible in this part of the file.
    index_stats: Arc<IndexStats>,
}
29
/// Bounded FIFO log of `(query shape, observed metrics)` pairs.
///
/// Note: the derived `Default` sets `max_size` to `0`, under which
/// `add_execution` evicts every entry right after inserting it. Use
/// `QueryHistory::new()` to get the intended capacity.
#[derive(Debug, Default)]
struct QueryHistory {
    /// Recorded executions, oldest at the front.
    patterns: VecDeque<(QueryPattern, ExecutionMetrics)>,
    /// Maximum number of entries kept before the oldest are dropped.
    max_size: usize,
}
38
/// Parameters the cost model is meant to learn over time.
///
/// The derived `Default` starts with all three maps empty.
#[derive(Debug, Default)]
struct LearnedParameters {
    /// Per-predicate multiplier applied to the base scan cost in
    /// `estimate_plan_cost`; keyed by the predicate's string form.
    scan_costs: HashMap<String, f64>,
    /// Selectivity per join shape. Not read by any code visible here.
    join_selectivities: HashMap<JoinPattern, f64>,
    /// Selectivity per filter, keyed by string. Not read by any code visible here.
    filter_selectivities: HashMap<String, f64>,
}
49
/// Structural fingerprint of a query, produced by `extract_query_pattern`
/// and used as the key for history lookups.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct QueryPattern {
    /// Number of triple patterns in the top-level basic graph pattern.
    num_patterns: usize,
    /// String form of every named-node predicate, in pattern order.
    predicates: Vec<String>,
    /// Join kind detected for each pair of triple patterns that shares a variable.
    join_types: Vec<JoinType>,
    /// Whether `has_filter` found a filter in the WHERE clause.
    has_filter: bool,
}
62
/// Key type for `LearnedParameters::join_selectivities`.
///
/// No code visible in this part of the file constructs or inspects it.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct JoinPattern {
    /// Number of variables — presumably those shared by the join; TODO confirm.
    num_vars: usize,
    /// Term kinds involved — presumably one label per joined position; TODO confirm.
    term_types: Vec<String>,
}
71
/// Positions in which two triple patterns share a variable, as classified
/// by `get_join_type`.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
enum JoinType {
    /// Same variable in both subject positions.
    SubjectSubject,
    /// A variable in one pattern's subject position matches one in the
    /// other pattern's object position.
    SubjectObject,
    /// Same variable in both object positions.
    ObjectObject,
    /// Same variable in both predicate positions.
    PredicatePredicate,
}
80
/// Measurements attached to one history entry.
#[derive(Debug, Clone)]
struct ExecutionMetrics {
    /// Execution duration; `get_history_factor` compares it against 100 ms.
    execution_time: Duration,
    /// Number of results produced. Not read by any code visible here.
    result_count: usize,
    /// Memory used — presumably in bytes; TODO confirm. Not read by any code visible here.
    memory_used: usize,
    /// CPU utilisation as a percentage. Not read by any code visible here.
    cpu_percent: f32,
}
93
/// Query optimizer that ranks candidate execution plans with a history-aware
/// cost model and annotates the winner with hardware execution hints.
pub struct AIQueryOptimizer {
    /// Planner that produces the baseline plan for every query.
    base_planner: QueryPlanner,
    /// Cost model used to score candidate plans.
    cost_model: CostModel,
    /// Result cache. Not consulted by any code visible here
    /// (`check_predictive_cache` is a stub).
    query_cache: Arc<RwLock<QueryCache>>,
    /// Hardware description captured once at construction.
    hardware_info: HardwareInfo,
}
105
/// Storage for cached query results and the access log kept alongside them.
///
/// Note: the derived `Default` sets `max_size` to `0`; `QueryCache::new()`
/// sets it to 1000.
#[derive(Debug, Default)]
struct QueryCache {
    /// Cached results keyed by a string — presumably a query hash; TODO confirm.
    cache: HashMap<String, CachedResult>,
    /// Log of cache accesses.
    access_patterns: VecDeque<AccessPattern>,
    /// Capacity limit. No eviction code is visible in this part of the file.
    max_size: usize,
}
116
/// One entry of `QueryCache::cache`.
#[derive(Debug, Clone)]
struct CachedResult {
    /// Cached payload as raw bytes; the encoding is not defined in the visible code.
    data: Vec<u8>,
    /// When the entry was stored.
    cached_at: Instant,
    /// How many times the entry has been accessed.
    access_count: usize,
    /// When the entry was last accessed.
    last_accessed: Instant,
}
129
/// One record of `QueryCache::access_patterns`.
#[derive(Debug, Clone)]
struct AccessPattern {
    /// Hash identifying the accessed query.
    query_hash: String,
    /// When the access happened.
    accessed_at: Instant,
    /// Session that issued the access.
    session_id: String,
}
140
/// Description of the host hardware, filled in by `HardwareInfo::detect`.
#[derive(Debug, Clone)]
struct HardwareInfo {
    /// Available parallelism reported by the standard library (1 if unknown).
    cpu_cores: usize,
    /// Memory size in bytes; `detect` currently hard-codes 8 GiB.
    memory_bytes: usize,
    /// CPU feature flags.
    cpu_features: CpuFeatures,
    /// Whether a GPU is available; `detect` currently hard-codes `false`.
    gpu_available: bool,
}
153
/// CPU capabilities recorded by `HardwareInfo::detect`.
#[derive(Debug, Clone)]
struct CpuFeatures {
    /// Set from `cfg!(target_feature = "sse2")`, i.e. decided at compile time.
    has_simd: bool,
    /// Set from `cfg!(target_feature = "avx2")`, i.e. decided at compile time.
    has_avx2: bool,
    /// Cache line size in bytes; `detect` hard-codes 64.
    cache_line_size: usize,
}
164
165impl AIQueryOptimizer {
166 pub fn new(index_stats: Arc<IndexStats>) -> Self {
168 Self {
169 base_planner: QueryPlanner::new(),
170 cost_model: CostModel::new(index_stats),
171 query_cache: Arc::new(RwLock::new(QueryCache::new())),
172 hardware_info: HardwareInfo::detect(),
173 }
174 }
175
176 pub fn optimize_query(&self, query: &AlgebraQuery) -> Result<OptimizedPlan, OxirsError> {
178 let pattern = self.extract_query_pattern(query)?;
180
181 if let Some(cached) = self.check_predictive_cache(&pattern) {
183 return Ok(cached);
184 }
185
186 let candidates = self.generate_candidate_plans(query)?;
188
189 let mut best_plan = None;
191 let mut best_cost = f64::MAX;
192
193 for candidate in candidates {
194 let cost = self.estimate_cost(&candidate, &pattern)?;
195 if cost < best_cost {
196 best_cost = cost;
197 best_plan = Some(candidate);
198 }
199 }
200
201 let plan = best_plan
202 .ok_or_else(|| OxirsError::Query("No valid execution plan found".to_string()))?;
203
204 let optimized = self.apply_hardware_optimizations(plan)?;
206
207 self.update_learning_model(&pattern, &optimized);
209
210 Ok(optimized)
211 }
212
213 fn extract_query_pattern(&self, query: &AlgebraQuery) -> Result<QueryPattern, OxirsError> {
215 match &query.form {
216 QueryForm::Select { where_clause, .. } => {
217 let (num_patterns, predicates, join_types) =
218 self.analyze_graph_pattern(where_clause)?;
219
220 Ok(QueryPattern {
221 num_patterns,
222 predicates,
223 join_types,
224 has_filter: self.has_filter(where_clause),
225 })
226 }
227 _ => Err(OxirsError::Query("Unsupported query form".to_string())),
228 }
229 }
230
231 fn analyze_graph_pattern(
233 &self,
234 pattern: &GraphPattern,
235 ) -> Result<(usize, Vec<String>, Vec<JoinType>), OxirsError> {
236 match pattern {
237 GraphPattern::Bgp(patterns) => {
238 let num_patterns = patterns.len();
239 let mut predicates = Vec::new();
240 let mut join_types = Vec::new();
241
242 for triple in patterns {
244 if let TermPattern::NamedNode(pred) = &triple.predicate {
245 predicates.push(pred.as_str().to_string());
246 }
247 }
248
249 for i in 0..patterns.len() {
251 for j in (i + 1)..patterns.len() {
252 if let Some(join_type) = self.get_join_type(&patterns[i], &patterns[j]) {
253 join_types.push(join_type);
254 }
255 }
256 }
257
258 Ok((num_patterns, predicates, join_types))
259 }
260 _ => Ok((0, Vec::new(), Vec::new())),
261 }
262 }
263
264 fn get_join_type(
266 &self,
267 left: &AlgebraTriplePattern,
268 right: &AlgebraTriplePattern,
269 ) -> Option<JoinType> {
270 if self.patterns_match(&left.subject, &right.subject) {
272 return Some(JoinType::SubjectSubject);
273 }
274
275 if self.patterns_match(&left.subject, &right.object) {
277 return Some(JoinType::SubjectObject);
278 }
279
280 if self.patterns_match(&left.object, &right.object) {
282 return Some(JoinType::ObjectObject);
283 }
284
285 if self.patterns_match(&left.predicate, &right.predicate) {
287 return Some(JoinType::PredicatePredicate);
288 }
289
290 None
291 }
292
293 fn patterns_match(&self, left: &TermPattern, right: &TermPattern) -> bool {
295 match (left, right) {
296 (TermPattern::Variable(v1), TermPattern::Variable(v2)) => v1 == v2,
297 _ => false,
298 }
299 }
300
301 #[allow(clippy::only_used_in_recursion)]
303 fn has_filter(&self, pattern: &GraphPattern) -> bool {
304 match pattern {
305 GraphPattern::Filter { .. } => true,
306 GraphPattern::Bgp(_) => false,
307 GraphPattern::Union(left, right) => self.has_filter(left) || self.has_filter(right),
308 _ => false,
309 }
310 }
311
312 fn generate_candidate_plans(
314 &self,
315 query: &AlgebraQuery,
316 ) -> Result<Vec<ExecutionPlan>, OxirsError> {
317 let mut candidates = Vec::new();
318
319 let basic_plan = self.base_planner.plan_query(query)?;
321 candidates.push(basic_plan.clone());
322
323 if let QueryForm::Select {
325 where_clause: GraphPattern::Bgp(patterns),
326 ..
327 } = &query.form
328 {
329 let join_orders = self.generate_join_orders(patterns);
331 for order in join_orders {
332 if let Ok(plan) = self.create_plan_with_order(patterns, &order) {
333 candidates.push(plan);
334 }
335 }
336 }
337
338 candidates.extend(self.generate_index_plans(query)?);
340
341 Ok(candidates)
342 }
343
344 fn generate_join_orders(&self, patterns: &[AlgebraTriplePattern]) -> Vec<Vec<usize>> {
346 let mut orders = Vec::new();
347
348 orders.push((0..patterns.len()).collect());
350
351 let mut selective_order: Vec<usize> = (0..patterns.len()).collect();
353 selective_order.sort_by_key(|&i| self.estimate_selectivity(&patterns[i]));
354 orders.push(selective_order);
355
356 orders.truncate(5);
358 orders
359 }
360
361 fn estimate_selectivity(&self, pattern: &AlgebraTriplePattern) -> i64 {
363 let mut score = 0;
365
366 if !matches!(pattern.subject, TermPattern::Variable(_)) {
368 score -= 1000;
369 }
370 if !matches!(pattern.predicate, TermPattern::Variable(_)) {
371 score -= 100;
372 }
373 if !matches!(pattern.object, TermPattern::Variable(_)) {
374 score -= 1000;
375 }
376
377 score
378 }
379
380 fn create_plan_with_order(
382 &self,
383 patterns: &[AlgebraTriplePattern],
384 order: &[usize],
385 ) -> Result<ExecutionPlan, OxirsError> {
386 if order.is_empty() {
387 return Err(OxirsError::Query("Empty join order".to_string()));
388 }
389
390 let mut plan = ExecutionPlan::TripleScan {
391 pattern: crate::query::plan::convert_algebra_triple_pattern(&patterns[order[0]]),
392 };
393
394 for &idx in &order[1..] {
395 let right_plan = ExecutionPlan::TripleScan {
396 pattern: crate::query::plan::convert_algebra_triple_pattern(&patterns[idx]),
397 };
398
399 plan = ExecutionPlan::HashJoin {
400 left: Box::new(plan),
401 right: Box::new(right_plan),
402 join_vars: Vec::new(), };
404 }
405
406 Ok(plan)
407 }
408
    /// Returns index-specific candidate plans for a query.
    ///
    /// Currently a stub: the query is ignored and the list is always empty.
    fn generate_index_plans(
        &self,
        _query: &AlgebraQuery,
    ) -> Result<Vec<ExecutionPlan>, OxirsError> {
        Ok(Vec::new())
    }
417
418 fn estimate_cost(
420 &self,
421 plan: &ExecutionPlan,
422 pattern: &QueryPattern,
423 ) -> Result<f64, OxirsError> {
424 let params = self
425 .cost_model
426 .learned_parameters
427 .read()
428 .map_err(|e| OxirsError::Query(format!("Failed to read parameters: {e}")))?;
429
430 let base_cost = self.estimate_plan_cost(plan, ¶ms)?;
431
432 let history_factor = self.get_history_factor(pattern);
434
435 Ok(base_cost * history_factor)
436 }
437
438 #[allow(clippy::only_used_in_recursion)]
440 fn estimate_plan_cost(
441 &self,
442 plan: &ExecutionPlan,
443 params: &LearnedParameters,
444 ) -> Result<f64, OxirsError> {
445 match plan {
446 ExecutionPlan::TripleScan { pattern } => {
447 let mut cost = 100.0;
449
450 if let Some(crate::model::pattern::PredicatePattern::NamedNode(pred)) =
452 &pattern.predicate
453 {
454 if let Some(&pred_cost) = params.scan_costs.get(pred.as_str()) {
455 cost *= pred_cost;
456 }
457 }
458
459 Ok(cost)
460 }
461 ExecutionPlan::HashJoin { left, right, .. } => {
462 let left_cost = self.estimate_plan_cost(left, params)?;
463 let right_cost = self.estimate_plan_cost(right, params)?;
464
465 Ok(left_cost + right_cost + (left_cost * right_cost * 0.01))
467 }
468 ExecutionPlan::Filter { input, .. } => {
469 let input_cost = self.estimate_plan_cost(input, params)?;
470 Ok(input_cost * 0.5)
472 }
473 _ => Ok(1000.0), }
475 }
476
477 fn get_history_factor(&self, pattern: &QueryPattern) -> f64 {
479 if let Ok(history) = self.cost_model.execution_history.read() {
481 for (hist_pattern, metrics) in history.patterns.iter() {
482 if self.patterns_similar(pattern, hist_pattern) {
483 return if metrics.execution_time.as_millis() < 100 {
485 0.8 } else {
487 1.2 };
489 }
490 }
491 }
492 1.0 }
494
495 fn patterns_similar(&self, a: &QueryPattern, b: &QueryPattern) -> bool {
497 a.num_patterns == b.num_patterns
498 && a.has_filter == b.has_filter
499 && a.predicates.len() == b.predicates.len()
500 }
501
    /// Looks up a ready-made plan for a query shape.
    ///
    /// Currently a stub: the pattern is ignored and the lookup always misses.
    fn check_predictive_cache(&self, _pattern: &QueryPattern) -> Option<OptimizedPlan> {
        None
    }
507
508 fn apply_hardware_optimizations(
510 &self,
511 plan: ExecutionPlan,
512 ) -> Result<OptimizedPlan, OxirsError> {
513 let mut optimized = OptimizedPlan {
514 base_plan: plan,
515 parallelism_level: 1,
516 use_simd: false,
517 use_gpu: false,
518 memory_budget: 0,
519 };
520
521 optimized.parallelism_level = self.calculate_optimal_parallelism();
523
524 optimized.use_simd = self.hardware_info.cpu_features.has_simd;
526
527 optimized.use_gpu =
529 self.hardware_info.gpu_available && self.should_use_gpu(&optimized.base_plan);
530
531 optimized.memory_budget = self.calculate_memory_budget();
533
534 Ok(optimized)
535 }
536
537 fn calculate_optimal_parallelism(&self) -> usize {
539 (self.hardware_info.cpu_cores as f32 * 0.75) as usize
541 }
542
    /// Decides whether a plan should run on the GPU.
    ///
    /// Currently a stub: the plan is ignored and the answer is always `false`.
    fn should_use_gpu(&self, _plan: &ExecutionPlan) -> bool {
        false
    }
548
    /// Returns the memory budget for a plan: half of the recorded
    /// `memory_bytes`.
    fn calculate_memory_budget(&self) -> usize {
        self.hardware_info.memory_bytes / 2
    }
554
555 fn update_learning_model(&self, pattern: &QueryPattern, _plan: &OptimizedPlan) {
557 if let Ok(mut history) = self.cost_model.execution_history.write() {
559 let metrics = ExecutionMetrics {
560 execution_time: Duration::from_millis(50), result_count: 100, memory_used: 1024 * 1024, cpu_percent: 25.0, };
565
566 history.add_execution(pattern.clone(), metrics);
567 }
568 }
569}
570
/// An execution plan together with the hardware hints chosen for running it.
#[derive(Debug)]
pub struct OptimizedPlan {
    /// The plan selected by the optimizer.
    pub base_plan: ExecutionPlan,
    /// Number of parallel workers to use.
    pub parallelism_level: usize,
    /// Whether SIMD execution should be used.
    pub use_simd: bool,
    /// Whether GPU execution should be used.
    pub use_gpu: bool,
    /// Memory budget; the optimizer sets it to half of `HardwareInfo::memory_bytes`.
    pub memory_budget: usize,
}
585
586impl CostModel {
587 fn new(index_stats: Arc<IndexStats>) -> Self {
588 Self {
589 execution_history: Arc::new(RwLock::new(QueryHistory::new())),
590 learned_parameters: Arc::new(RwLock::new(LearnedParameters::default())),
591 index_stats,
592 }
593 }
594}
595
596impl QueryHistory {
597 fn new() -> Self {
598 Self {
599 patterns: VecDeque::new(),
600 max_size: 10000,
601 }
602 }
603
604 fn add_execution(&mut self, pattern: QueryPattern, metrics: ExecutionMetrics) {
605 self.patterns.push_back((pattern, metrics));
606
607 while self.patterns.len() > self.max_size {
609 self.patterns.pop_front();
610 }
611 }
612}
613
614impl QueryCache {
615 fn new() -> Self {
616 Self {
617 cache: HashMap::new(),
618 access_patterns: VecDeque::new(),
619 max_size: 1000,
620 }
621 }
622}
623
624impl HardwareInfo {
625 fn detect() -> Self {
626 Self {
627 cpu_cores: std::thread::available_parallelism()
628 .map(|p| p.get())
629 .unwrap_or(1),
630 memory_bytes: 8 * 1024 * 1024 * 1024, cpu_features: CpuFeatures {
632 has_simd: cfg!(target_feature = "sse2"),
633 has_avx2: cfg!(target_feature = "avx2"),
634 cache_line_size: 64,
635 },
636 gpu_available: false, }
638 }
639}
640
/// Optimizer for batches of queries, built on top of `AIQueryOptimizer`.
pub struct MultiQueryOptimizer {
    /// Optimizer applied to each query of a batch individually.
    single_optimizer: AIQueryOptimizer,
    /// Plans for shared sub-expressions, keyed by string. Not read or
    /// written by any code visible in this part of the file.
    subexpression_cache: Arc<RwLock<HashMap<String, ExecutionPlan>>>,
}
648
649impl MultiQueryOptimizer {
650 pub fn new(index_stats: Arc<IndexStats>) -> Self {
652 Self {
653 single_optimizer: AIQueryOptimizer::new(index_stats),
654 subexpression_cache: Arc::new(RwLock::new(HashMap::new())),
655 }
656 }
657
658 pub fn optimize_batch(
660 &self,
661 queries: &[AlgebraQuery],
662 ) -> Result<Vec<OptimizedPlan>, OxirsError> {
663 let common_subs = self.detect_common_subexpressions(queries)?;
665
666 let mut optimized_plans = Vec::new();
668
669 for query in queries {
670 let mut plan = self.single_optimizer.optimize_query(query)?;
671
672 plan = self.reuse_common_subexpressions(plan, &common_subs)?;
674
675 optimized_plans.push(plan);
676 }
677
678 Ok(optimized_plans)
679 }
680
681 fn detect_common_subexpressions(
683 &self,
684 queries: &[AlgebraQuery],
685 ) -> Result<HashMap<String, ExecutionPlan>, OxirsError> {
686 let mut common_subs = HashMap::new();
687
688 let mut pattern_counts = HashMap::new();
690
691 for query in queries {
692 self.count_patterns(query, &mut pattern_counts)?;
693 }
694
695 for (pattern_key, count) in pattern_counts {
697 if count > 1 {
698 common_subs.insert(
701 pattern_key,
702 ExecutionPlan::TripleScan {
703 pattern: crate::model::pattern::TriplePattern::new(
704 Some(crate::model::pattern::SubjectPattern::Variable(
705 Variable::new("?s").unwrap(),
706 )),
707 Some(crate::model::pattern::PredicatePattern::Variable(
708 Variable::new("?p").unwrap(),
709 )),
710 Some(crate::model::pattern::ObjectPattern::Variable(
711 Variable::new("?o").unwrap(),
712 )),
713 ),
714 },
715 );
716 }
717 }
718
719 Ok(common_subs)
720 }
721
722 fn count_patterns(
724 &self,
725 query: &AlgebraQuery,
726 counts: &mut HashMap<String, usize>,
727 ) -> Result<(), OxirsError> {
728 if let QueryForm::Select { where_clause, .. } = &query.form {
729 self.count_graph_patterns(where_clause, counts)?;
730 }
731 Ok(())
732 }
733
734 fn count_graph_patterns(
736 &self,
737 pattern: &GraphPattern,
738 counts: &mut HashMap<String, usize>,
739 ) -> Result<(), OxirsError> {
740 if let GraphPattern::Bgp(patterns) = pattern {
741 for triple in patterns {
742 let key = format!("{triple}"); *counts.entry(key).or_insert(0) += 1;
744 }
745 }
746 Ok(())
747 }
748
    /// Rewrites a plan so that it reuses shared sub-expressions.
    ///
    /// Currently a stub: `_common` is ignored and the plan is returned
    /// unchanged.
    fn reuse_common_subexpressions(
        &self,
        plan: OptimizedPlan,
        _common: &HashMap<String, ExecutionPlan>,
    ) -> Result<OptimizedPlan, OxirsError> {
        Ok(plan)
    }
758}
759
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh statistics handle used to construct the types under test.
    fn fresh_stats() -> Arc<IndexStats> {
        Arc::new(IndexStats::new())
    }

    /// A newly built optimizer reports at least one CPU core.
    #[test]
    fn test_ai_optimizer_creation() {
        let optimizer = AIQueryOptimizer::new(fresh_stats());

        assert!(optimizer.hardware_info.cpu_cores > 0);
    }

    /// A newly built cost model starts with an empty execution history.
    #[test]
    fn test_cost_model() {
        let model = CostModel::new(fresh_stats());

        let history = model.execution_history.read().unwrap();
        assert!(history.patterns.is_empty());
    }

    /// Hardware detection yields positive core and memory figures and the
    /// fixed 64-byte cache line size.
    #[test]
    fn test_hardware_detection() {
        let HardwareInfo {
            cpu_cores,
            memory_bytes,
            cpu_features,
            ..
        } = HardwareInfo::detect();

        assert!(cpu_cores > 0);
        assert!(memory_bytes > 0);
        assert_eq!(cpu_features.cache_line_size, 64);
    }
}