1use crate::algebra::{Term, TriplePattern};
8use anyhow::{anyhow, Result};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12
/// Runtime feedback accumulated by an [`AdaptiveStatsStore`].
#[derive(Debug, Clone, Default)]
pub struct RuntimeStats {
    /// Cardinality feedback per pattern id.
    pub pattern_stats: HashMap<String, PatternRuntimeStats>,
    /// Selectivity feedback per join id.
    pub join_stats: HashMap<String, JoinRuntimeStats>,
    /// Most recently recorded execution time per component id
    /// (each new recording overwrites the previous one).
    pub execution_times: HashMap<String, Duration>,
    /// Query counter. No `AdaptiveStatsStore` method modifies it.
    pub query_count: u64,
}

/// Estimated-versus-actual cardinality feedback for a single pattern.
#[derive(Debug, Clone, Default)]
pub struct PatternRuntimeStats {
    /// Saturating sum of the recorded estimates (compacted to its average
    /// once the history limit is exceeded).
    pub estimated_cardinality_sum: u64,
    /// Saturating sum of the recorded actual cardinalities (compacted to its
    /// average once the history limit is exceeded).
    pub actual_cardinality_sum: u64,
    /// `actual / estimated` ratio of the most recent sample
    /// (`1.0` when that sample's estimate was zero).
    pub estimation_error: f64,
    /// Number of samples folded into the sums since the last compaction.
    pub sample_count: u64,
    /// Exponentially smoothed `actual / estimated` ratio; multiplied into
    /// base estimates by `get_adjusted_cardinality`.
    pub correction_factor: f64,
}

/// Observed input/output cardinalities for a single join.
#[derive(Debug, Clone, Default)]
pub struct JoinRuntimeStats {
    /// Saturating sum of the recorded left-input cardinalities.
    pub left_cardinality_sum: u64,
    /// Saturating sum of the recorded right-input cardinalities.
    pub right_cardinality_sum: u64,
    /// Saturating sum of the recorded output cardinalities.
    pub output_cardinality_sum: u64,
    /// Exponentially smoothed `output / (left * right)` ratio.
    pub observed_selectivity: f64,
    /// Number of recorded executions.
    pub sample_count: u64,
}

/// Shared, lock-protected store of [`RuntimeStats`].
///
/// All methods are best-effort: if the lock is poisoned, writers silently
/// drop the sample and readers fall back to the caller-supplied base value.
pub struct AdaptiveStatsStore {
    stats: Arc<RwLock<RuntimeStats>>,
    /// Number of pattern samples kept before the sums are compacted.
    max_history: usize,
}

impl AdaptiveStatsStore {
    /// Weight of the previous smoothed value in the exponential moving average.
    const HISTORY_WEIGHT: f64 = 0.8;
    /// Weight of the newest sample in the exponential moving average.
    const SAMPLE_WEIGHT: f64 = 0.2;

    /// Creates an empty store that compacts a pattern's sums once more than
    /// `max_history` samples have been recorded for it.
    pub fn new(max_history: usize) -> Self {
        Self {
            stats: Arc::new(RwLock::new(RuntimeStats::default())),
            max_history,
        }
    }

    /// Blends `sample` into `previous` using the fixed 0.8 / 0.2 weights.
    fn smooth(previous: f64, sample: f64) -> f64 {
        Self::HISTORY_WEIGHT * previous + Self::SAMPLE_WEIGHT * sample
    }

    /// Records one execution of `pattern_id` with its estimated and actual
    /// cardinality and updates the smoothed correction factor.
    ///
    /// The running sums saturate at `u64::MAX` instead of overflowing.
    pub fn record_pattern_execution(&self, pattern_id: &str, estimated: u64, actual: u64) {
        let Ok(mut stats) = self.stats.write() else {
            return;
        };
        let entry = stats
            .pattern_stats
            .entry(pattern_id.to_string())
            .or_default();

        entry.sample_count = entry.sample_count.saturating_add(1);
        entry.estimated_cardinality_sum =
            entry.estimated_cardinality_sum.saturating_add(estimated);
        entry.actual_cardinality_sum = entry.actual_cardinality_sum.saturating_add(actual);

        // A zero estimate gives no usable ratio; treat it as "no correction".
        let ratio = if estimated > 0 {
            actual as f64 / estimated as f64
        } else {
            1.0
        };
        entry.estimation_error = ratio;
        entry.correction_factor = if entry.sample_count == 1 {
            ratio
        } else {
            Self::smooth(entry.correction_factor, ratio)
        };

        // Past the history limit, collapse the sums into a single averaged sample.
        let history_limit = u64::try_from(self.max_history).unwrap_or(u64::MAX);
        if entry.sample_count > history_limit {
            entry.estimated_cardinality_sum /= entry.sample_count;
            entry.actual_cardinality_sum /= entry.sample_count;
            entry.sample_count = 1;
        }
    }

    /// Records one execution of `join_id` and updates its smoothed selectivity
    /// (`output / (left * right)`, or `0.0` when either input was empty).
    ///
    /// The running sums saturate at `u64::MAX` instead of overflowing.
    pub fn record_join_execution(&self, join_id: &str, left: u64, right: u64, output: u64) {
        let Ok(mut stats) = self.stats.write() else {
            return;
        };
        let entry = stats.join_stats.entry(join_id.to_string()).or_default();

        entry.sample_count = entry.sample_count.saturating_add(1);
        entry.left_cardinality_sum = entry.left_cardinality_sum.saturating_add(left);
        entry.right_cardinality_sum = entry.right_cardinality_sum.saturating_add(right);
        entry.output_cardinality_sum = entry.output_cardinality_sum.saturating_add(output);

        let denominator = (left as f64) * (right as f64);
        let selectivity = if denominator > 0.0 {
            output as f64 / denominator
        } else {
            0.0
        };

        entry.observed_selectivity = if entry.sample_count == 1 {
            selectivity
        } else {
            Self::smooth(entry.observed_selectivity, selectivity)
        };
    }

    /// Stores `duration` as the latest execution time of `component_id`,
    /// replacing any earlier value.
    pub fn record_execution_time(&self, component_id: &str, duration: Duration) {
        let Ok(mut stats) = self.stats.write() else {
            return;
        };
        stats
            .execution_times
            .insert(component_id.to_string(), duration);
    }

    /// Scales `base_estimate` by the pattern's correction factor.
    ///
    /// Returns `base_estimate` unchanged when the pattern has no samples or
    /// the lock is poisoned; otherwise the rounded product, never below 1.
    pub fn get_adjusted_cardinality(&self, pattern_id: &str, base_estimate: u64) -> u64 {
        let Ok(stats) = self.stats.read() else {
            return base_estimate;
        };
        match stats.pattern_stats.get(pattern_id) {
            Some(entry) if entry.sample_count > 0 => {
                let scaled = (base_estimate as f64 * entry.correction_factor).round() as u64;
                scaled.max(1)
            }
            _ => base_estimate,
        }
    }

    /// Blends `base_selectivity` with the join's observed selectivity.
    ///
    /// The observed value gets weight `sample_count / 10`, capped at 0.8; the
    /// result is clamped to `[0.0001, 1.0]`. Returns `base_selectivity`
    /// unchanged when the join has no samples or the lock is poisoned.
    pub fn get_adjusted_selectivity(&self, join_id: &str, base_selectivity: f64) -> f64 {
        let Ok(stats) = self.stats.read() else {
            return base_selectivity;
        };
        match stats.join_stats.get(join_id) {
            Some(entry) if entry.sample_count > 0 => {
                let observed_weight = (entry.sample_count as f64 / 10.0).min(0.8);
                let base_weight = 1.0 - observed_weight;
                (base_weight * base_selectivity + observed_weight * entry.observed_selectivity)
                    .clamp(0.0001, 1.0)
            }
            _ => base_selectivity,
        }
    }

    /// Returns a copy of the current statistics, or `None` if the lock is poisoned.
    pub fn snapshot(&self) -> Option<RuntimeStats> {
        self.stats
            .read()
            .ok()
            .map(|guard| RuntimeStats::clone(&guard))
    }
}
188
/// Physical join strategies the optimizer can choose between.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinAlgorithm {
    /// Hash join.
    Hash,
    /// Nested-loop join.
    NestedLoop,
    /// Merge join.
    Merge,
}
199
/// One position (subject, predicate or object) of a triple pattern, reduced
/// to its kind and string form.
#[derive(Debug, Clone)]
pub enum PatternTerm {
    /// A variable, identified by its name.
    Variable(String),
    /// An IRI, stored as its string form.
    Iri(String),
    /// A literal, stored as its value string.
    Literal(String),
    /// A blank node, stored as its label.
    BlankNode(String),
}

impl PatternTerm {
    /// Returns `true` when this term is a variable.
    pub fn is_variable(&self) -> bool {
        self.variable_name().is_some()
    }

    /// Returns the variable's name, or `None` for any non-variable term.
    pub fn variable_name(&self) -> Option<&str> {
        if let PatternTerm::Variable(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }
}
223
224#[derive(Debug, Clone)]
226pub struct TriplePatternInfo {
227 pub id: String,
229 pub subject: PatternTerm,
230 pub predicate: PatternTerm,
231 pub object: PatternTerm,
232 pub estimated_cardinality: u64,
234 pub bound_variables: Vec<String>,
236 pub original_pattern: Option<TriplePattern>,
238}
239
240impl TriplePatternInfo {
241 pub fn from_triple_pattern(pattern: &TriplePattern, estimated_cardinality: u64) -> Self {
243 let subject = term_to_pattern_term(&pattern.subject);
244 let predicate = term_to_pattern_term(&pattern.predicate);
245 let object = term_to_pattern_term(&pattern.object);
246
247 let mut bound_variables = Vec::new();
248 if let PatternTerm::Variable(ref v) = subject {
249 bound_variables.push(v.clone());
250 }
251 if let PatternTerm::Variable(ref v) = predicate {
252 bound_variables.push(v.clone());
253 }
254 if let PatternTerm::Variable(ref v) = object {
255 bound_variables.push(v.clone());
256 }
257
258 let id = build_pattern_fingerprint(&subject, &predicate, &object);
260
261 Self {
262 id,
263 subject,
264 predicate,
265 object,
266 estimated_cardinality,
267 bound_variables,
268 original_pattern: Some(pattern.clone()),
269 }
270 }
271
272 pub fn bound_positions(&self) -> usize {
274 let mut count = 0;
275 if !self.subject.is_variable() {
276 count += 1;
277 }
278 if !self.predicate.is_variable() {
279 count += 1;
280 }
281 if !self.object.is_variable() {
282 count += 1;
283 }
284 count
285 }
286}
287
288fn term_to_pattern_term(term: &Term) -> PatternTerm {
289 match term {
290 Term::Variable(v) => PatternTerm::Variable(v.name().to_string()),
291 Term::Iri(iri) => PatternTerm::Iri(iri.as_str().to_string()),
292 Term::Literal(lit) => PatternTerm::Literal(lit.value.clone()),
293 Term::BlankNode(bn) => PatternTerm::BlankNode(bn.as_str().to_string()),
294 _ => PatternTerm::Iri(format!("{term}")),
296 }
297}
298
299fn build_pattern_fingerprint(
300 subject: &PatternTerm,
301 predicate: &PatternTerm,
302 object: &PatternTerm,
303) -> String {
304 let s = match subject {
305 PatternTerm::Variable(_) => "?".to_string(),
306 PatternTerm::Iri(v) => v.clone(),
307 PatternTerm::Literal(v) => format!("\"{v}\""),
308 PatternTerm::BlankNode(v) => format!("_:{v}"),
309 };
310 let p = match predicate {
311 PatternTerm::Variable(_) => "?".to_string(),
312 PatternTerm::Iri(v) => v.clone(),
313 PatternTerm::Literal(v) => format!("\"{v}\""),
314 PatternTerm::BlankNode(v) => format!("_:{v}"),
315 };
316 let o = match object {
317 PatternTerm::Variable(_) => "?".to_string(),
318 PatternTerm::Iri(v) => v.clone(),
319 PatternTerm::Literal(v) => format!("\"{v}\""),
320 PatternTerm::BlankNode(v) => format!("_:{v}"),
321 };
322 format!("{s} {p} {o}")
323}
324
325#[derive(Debug, Clone)]
327#[allow(clippy::large_enum_variant)]
328pub enum JoinPlanNode {
329 TriplePatternScan { info: TriplePatternInfo },
331 HashJoin {
333 left: Box<JoinPlanNode>,
334 right: Box<JoinPlanNode>,
335 join_vars: Vec<String>,
336 estimated_output: u64,
337 },
338 NestedLoopJoin {
340 outer: Box<JoinPlanNode>,
341 inner: Box<JoinPlanNode>,
342 join_vars: Vec<String>,
343 estimated_output: u64,
344 },
345 MergeJoin {
347 left: Box<JoinPlanNode>,
348 right: Box<JoinPlanNode>,
349 join_vars: Vec<String>,
350 sort_key: Vec<String>,
351 estimated_output: u64,
352 },
353}
354
355impl JoinPlanNode {
356 pub fn estimated_cardinality(&self) -> u64 {
358 match self {
359 JoinPlanNode::TriplePatternScan { info } => info.estimated_cardinality,
360 JoinPlanNode::HashJoin {
361 estimated_output, ..
362 } => *estimated_output,
363 JoinPlanNode::NestedLoopJoin {
364 estimated_output, ..
365 } => *estimated_output,
366 JoinPlanNode::MergeJoin {
367 estimated_output, ..
368 } => *estimated_output,
369 }
370 }
371
372 pub fn output_variables(&self) -> Vec<String> {
374 match self {
375 JoinPlanNode::TriplePatternScan { info } => info.bound_variables.clone(),
376 JoinPlanNode::HashJoin { left, right, .. } => {
377 merge_variable_sets(left.output_variables(), right.output_variables())
378 }
379 JoinPlanNode::NestedLoopJoin { outer, inner, .. } => {
380 merge_variable_sets(outer.output_variables(), inner.output_variables())
381 }
382 JoinPlanNode::MergeJoin { left, right, .. } => {
383 merge_variable_sets(left.output_variables(), right.output_variables())
384 }
385 }
386 }
387}
388
/// Appends to `left` every name from `right` that `left` does not already
/// contain, keeping first-seen order, and returns the combined list.
fn merge_variable_sets(left: Vec<String>, right: Vec<String>) -> Vec<String> {
    right.into_iter().fold(left, |mut merged, name| {
        if !merged.contains(&name) {
            merged.push(name);
        }
        merged
    })
}
397
398pub struct AdaptiveJoinOrderOptimizer {
400 stats_store: Arc<AdaptiveStatsStore>,
402 max_patterns_for_dp: usize,
404 default_selectivity: f64,
406}
407
408impl AdaptiveJoinOrderOptimizer {
409 pub fn new(stats_store: Arc<AdaptiveStatsStore>) -> Self {
411 Self {
412 stats_store,
413 max_patterns_for_dp: 8,
414 default_selectivity: 0.1,
415 }
416 }
417
418 pub fn with_dp_threshold(mut self, threshold: usize) -> Self {
420 self.max_patterns_for_dp = threshold;
421 self
422 }
423
424 pub fn optimize(&self, patterns: Vec<TriplePatternInfo>) -> Result<JoinPlanNode> {
426 if patterns.is_empty() {
427 return Err(anyhow!("Cannot optimize empty pattern list"));
428 }
429 if patterns.len() == 1 {
430 return Ok(JoinPlanNode::TriplePatternScan {
431 info: patterns.into_iter().next().expect("checked len == 1"),
432 });
433 }
434
435 let adjusted = self.apply_cardinality_corrections(patterns);
437
438 if adjusted.len() <= self.max_patterns_for_dp {
439 self.dp_optimize(&adjusted)
440 } else {
441 self.greedy_optimize(&adjusted)
442 }
443 }
444
445 fn apply_cardinality_corrections(
447 &self,
448 patterns: Vec<TriplePatternInfo>,
449 ) -> Vec<TriplePatternInfo> {
450 patterns
451 .into_iter()
452 .map(|mut p| {
453 let adjusted = self
454 .stats_store
455 .get_adjusted_cardinality(&p.id, p.estimated_cardinality);
456 p.estimated_cardinality = adjusted;
457 p
458 })
459 .collect()
460 }
461
462 fn dp_optimize(&self, patterns: &[TriplePatternInfo]) -> Result<JoinPlanNode> {
467 let n = patterns.len();
468 let total_masks = 1usize << n;
471 let mut dp: Vec<Option<(f64, JoinPlanNode)>> = vec![None; total_masks];
472
473 for (i, pattern) in patterns.iter().enumerate() {
475 let mask = 1usize << i;
476 let plan = JoinPlanNode::TriplePatternScan {
477 info: pattern.clone(),
478 };
479 let cost = self.scan_cost(pattern);
480 dp[mask] = Some((cost, plan));
481 }
482
483 for mask in 1..total_masks {
485 let bit_count = mask.count_ones() as usize;
487 if bit_count < 2 {
488 continue;
489 }
490
491 let mut best: Option<(f64, JoinPlanNode)> = None;
492
493 let mut left_mask = (mask - 1) & mask;
495 while left_mask > 0 {
496 let right_mask = mask ^ left_mask;
497 if right_mask == 0 {
498 left_mask = (left_mask - 1) & mask;
499 continue;
500 }
501
502 if left_mask >= right_mask {
504 left_mask = (left_mask - 1) & mask;
505 continue;
506 }
507
508 let (Some((left_cost, ref left_plan)), Some((right_cost, ref right_plan))) =
509 (&dp[left_mask], &dp[right_mask])
510 else {
511 left_mask = (left_mask - 1) & mask;
512 continue;
513 };
514
515 let left_vars = left_plan.output_variables();
516 let right_vars = right_plan.output_variables();
517 let join_vars = Self::find_join_variables_sets(&left_vars, &right_vars);
518
519 let join_id = format!("{left_mask}x{right_mask}");
521 let selectivity = if join_vars.is_empty() {
522 1.0 } else {
524 self.stats_store
525 .get_adjusted_selectivity(&join_id, self.default_selectivity)
526 };
527
528 let left_card = left_plan.estimated_cardinality();
529 let right_card = right_plan.estimated_cardinality();
530 let output_card =
531 ((left_card as f64 * right_card as f64 * selectivity).round() as u64).max(1);
532
533 let algorithm = Self::select_join_algorithm(left_card, right_card, &join_vars);
534 let join_cost =
535 self.join_cost(left_cost + right_cost, left_card, right_card, &algorithm);
536 let total_cost = left_cost + right_cost + join_cost;
537
538 if best.is_none() || total_cost < best.as_ref().map(|(c, _)| *c).unwrap_or(f64::MAX)
539 {
540 let plan = self.build_join_plan(
541 left_plan.clone(),
542 right_plan.clone(),
543 join_vars,
544 output_card,
545 algorithm,
546 );
547 best = Some((total_cost, plan));
548 }
549
550 left_mask = (left_mask - 1) & mask;
551 }
552
553 if best.is_some() {
554 dp[mask] = best;
555 }
556 }
557
558 let full_mask = total_masks - 1;
559 dp[full_mask]
560 .take()
561 .map(|(_, plan)| plan)
562 .ok_or_else(|| anyhow!("DP optimizer failed to find a valid plan"))
563 }
564
565 fn greedy_optimize(&self, patterns: &[TriplePatternInfo]) -> Result<JoinPlanNode> {
569 if patterns.is_empty() {
570 return Err(anyhow!("Cannot optimize empty pattern list"));
571 }
572
573 let mut remaining: Vec<TriplePatternInfo> = patterns.to_vec();
575 remaining.sort_by_key(|p| p.estimated_cardinality);
576
577 let first = remaining.remove(0);
579 let mut current_plan = JoinPlanNode::TriplePatternScan { info: first };
580
581 while !remaining.is_empty() {
582 let mut best_idx = 0;
584 let mut best_cost = f64::MAX;
585
586 let current_vars = current_plan.output_variables();
587 let current_card = current_plan.estimated_cardinality();
588
589 for (idx, candidate) in remaining.iter().enumerate() {
590 let join_vars =
591 Self::find_join_variables_sets(¤t_vars, &candidate.bound_variables);
592 let join_id = format!("g_{idx}_{}", candidate.id);
593 let selectivity = self
594 .stats_store
595 .get_adjusted_selectivity(&join_id, self.default_selectivity);
596
597 let algorithm = Self::select_join_algorithm(
598 current_card,
599 candidate.estimated_cardinality,
600 &join_vars,
601 );
602 let cost = self.join_cost(
603 0.0,
604 current_card,
605 candidate.estimated_cardinality,
606 &algorithm,
607 );
608
609 let adjusted_cost = if join_vars.is_empty() {
611 cost * 1000.0
612 } else {
613 cost * (1.0 + (1.0 - selectivity))
614 };
615
616 if adjusted_cost < best_cost {
617 best_cost = adjusted_cost;
618 best_idx = idx;
619 }
620 }
621
622 let next = remaining.remove(best_idx);
623 let join_vars = Self::find_join_variables_sets(¤t_vars, &next.bound_variables);
624 let selectivity = self.stats_store.get_adjusted_selectivity(
625 &format!("g_{best_idx}_{}", next.id),
626 self.default_selectivity,
627 );
628 let next_card = next.estimated_cardinality;
629 let output_card =
630 ((current_card as f64 * next_card as f64 * selectivity).round() as u64).max(1);
631 let algorithm = Self::select_join_algorithm(current_card, next_card, &join_vars);
632 let right_plan = JoinPlanNode::TriplePatternScan { info: next };
633
634 current_plan =
635 self.build_join_plan(current_plan, right_plan, join_vars, output_card, algorithm);
636 }
637
638 Ok(current_plan)
639 }
640
641 fn scan_cost(&self, pattern: &TriplePatternInfo) -> f64 {
643 let base = pattern.estimated_cardinality as f64;
646 let bound_factor = match pattern.bound_positions() {
647 0 => 1.0, 1 => 0.3, 2 => 0.05, _ => 0.01, };
652 base * bound_factor
653 }
654
655 fn join_cost(
657 &self,
658 children_cost: f64,
659 left_card: u64,
660 right_card: u64,
661 algorithm: &JoinAlgorithm,
662 ) -> f64 {
663 let l = left_card as f64;
664 let r = right_card as f64;
665 match algorithm {
666 JoinAlgorithm::Hash => {
667 children_cost + r + l
669 }
670 JoinAlgorithm::NestedLoop => {
671 children_cost + l * r
673 }
674 JoinAlgorithm::Merge => {
675 children_cost + l * l.max(1.0).ln() + r * r.max(1.0).ln() + l + r
677 }
678 }
679 }
680
681 fn find_join_variables_sets(left: &[String], right: &[String]) -> Vec<String> {
683 left.iter().filter(|v| right.contains(v)).cloned().collect()
684 }
685
686 pub fn select_join_algorithm(
688 left_card: u64,
689 right_card: u64,
690 join_vars: &[String],
691 ) -> JoinAlgorithm {
692 if join_vars.is_empty() {
693 if left_card.min(right_card) < 100 {
695 return JoinAlgorithm::NestedLoop;
696 }
697 return JoinAlgorithm::Hash;
698 }
699
700 let smaller = left_card.min(right_card);
701 let larger = left_card.max(right_card);
702
703 if smaller < 1000 {
704 JoinAlgorithm::Hash
706 } else if smaller > 50_000 && larger > 50_000 {
707 JoinAlgorithm::Merge
709 } else {
710 JoinAlgorithm::Hash
711 }
712 }
713
714 fn build_join_plan(
716 &self,
717 left: JoinPlanNode,
718 right: JoinPlanNode,
719 join_vars: Vec<String>,
720 estimated_output: u64,
721 algorithm: JoinAlgorithm,
722 ) -> JoinPlanNode {
723 match algorithm {
724 JoinAlgorithm::Hash => JoinPlanNode::HashJoin {
725 left: Box::new(left),
726 right: Box::new(right),
727 join_vars,
728 estimated_output,
729 },
730 JoinAlgorithm::NestedLoop => {
731 JoinPlanNode::NestedLoopJoin {
733 outer: Box::new(left),
734 inner: Box::new(right),
735 join_vars,
736 estimated_output,
737 }
738 }
739 JoinAlgorithm::Merge => {
740 let sort_key = join_vars.clone();
741 JoinPlanNode::MergeJoin {
742 left: Box::new(left),
743 right: Box::new(right),
744 join_vars,
745 sort_key,
746 estimated_output,
747 }
748 }
749 }
750 }
751}
752
753pub struct PlanTimer {
755 component_id: String,
756 start: Instant,
757 stats_store: Arc<AdaptiveStatsStore>,
758}
759
760impl PlanTimer {
761 pub fn start(component_id: impl Into<String>, stats_store: Arc<AdaptiveStatsStore>) -> Self {
763 Self {
764 component_id: component_id.into(),
765 start: Instant::now(),
766 stats_store,
767 }
768 }
769}
770
771impl Drop for PlanTimer {
772 fn drop(&mut self) {
773 let elapsed = self.start.elapsed();
774 self.stats_store
775 .record_execution_time(&self.component_id, elapsed);
776 }
777}
778
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::{Term, TriplePattern};
    use oxirs_core::model::{NamedNode, Variable as CoreVariable};

    /// Builds an algebra variable term; panics if the name is rejected.
    fn make_var(name: &str) -> Term {
        let variable = CoreVariable::new(name).unwrap();
        Term::Variable(variable)
    }

    /// Builds an algebra IRI term without validating the IRI.
    fn make_iri(iri: &str) -> Term {
        let node = NamedNode::new_unchecked(iri);
        Term::Iri(node)
    }

    /// Shorthand for a variable pattern term.
    fn var(name: &str) -> PatternTerm {
        PatternTerm::Variable(name.to_string())
    }

    /// Shorthand for an IRI pattern term.
    fn iri(value: &str) -> PatternTerm {
        PatternTerm::Iri(value.to_string())
    }

    /// Returns `true` when `plan` is one of the three join node kinds.
    fn is_join(plan: &JoinPlanNode) -> bool {
        matches!(
            plan,
            JoinPlanNode::HashJoin { .. }
                | JoinPlanNode::NestedLoopJoin { .. }
                | JoinPlanNode::MergeJoin { .. }
        )
    }

    /// Creates an optimizer backed by a fresh store with a history limit of 100.
    fn fresh_optimizer() -> AdaptiveJoinOrderOptimizer {
        AdaptiveJoinOrderOptimizer::new(Arc::new(AdaptiveStatsStore::new(100)))
    }

    /// Builds a `TriplePatternInfo` directly, with a Debug-formatted id and
    /// no original algebra pattern.
    fn pattern_info(
        subject: PatternTerm,
        predicate: PatternTerm,
        object: PatternTerm,
        cardinality: u64,
    ) -> TriplePatternInfo {
        let id = format!("{:?}-{:?}-{:?}", subject, predicate, object);
        let mut bound_variables: Vec<String> = Vec::new();
        for term in [&subject, &predicate, &object].iter() {
            if let Some(name) = term.variable_name() {
                bound_variables.push(name.to_string());
            }
        }
        TriplePatternInfo {
            id,
            subject,
            predicate,
            object,
            estimated_cardinality: cardinality,
            bound_variables,
            original_pattern: None,
        }
    }

    #[test]
    fn test_adaptive_stats_store_record_and_adjust() {
        let store = AdaptiveStatsStore::new(100);
        store.record_pattern_execution("pat1", 1000, 500);

        // A single sample with ratio 0.5 halves the base estimate.
        assert_eq!(
            store.get_adjusted_cardinality("pat1", 1000),
            500,
            "Adjusted cardinality should reflect correction factor"
        );
    }

    #[test]
    fn test_adaptive_stats_store_unknown_pattern_returns_base() {
        let store = AdaptiveStatsStore::new(100);
        assert_eq!(
            store.get_adjusted_cardinality("unknown_pat", 500),
            500,
            "Unknown pattern should return base estimate unchanged"
        );
    }

    #[test]
    fn test_adaptive_stats_store_join_selectivity() {
        let store = AdaptiveStatsStore::new(100);
        // Observed selectivity: 50 / (100 * 200) = 0.0025.
        store.record_join_execution("j1", 100, 200, 50);

        let blended = store.get_adjusted_selectivity("j1", 0.1);
        assert!(
            blended < 0.1,
            "Adjusted selectivity should be reduced toward observed value"
        );
        assert!(blended > 0.0, "Adjusted selectivity must remain positive");
    }

    #[test]
    fn test_single_pattern_optimization() {
        let optimizer = fresh_optimizer();
        let only = pattern_info(var("s"), iri("http://example.org/type"), var("o"), 500);

        let plan = optimizer.optimize(vec![only]).unwrap();
        assert!(matches!(plan, JoinPlanNode::TriplePatternScan { .. }));
    }

    #[test]
    fn test_two_pattern_dp_optimization() {
        let optimizer = fresh_optimizer();
        let typed = pattern_info(
            var("s"),
            iri("http://example.org/type"),
            iri("http://example.org/Person"),
            50,
        );
        let named = pattern_info(
            var("s"),
            iri("http://xmlns.com/foaf/0.1/name"),
            var("name"),
            10000,
        );

        let plan = optimizer.optimize(vec![typed, named]).unwrap();
        assert!(is_join(&plan), "Should produce a join plan");
    }

    #[test]
    fn test_greedy_optimization_for_large_pattern_sets() {
        // Six patterns with a DP threshold of three forces the greedy planner.
        let optimizer = fresh_optimizer().with_dp_threshold(3);

        let mut patterns: Vec<TriplePatternInfo> = Vec::new();
        for i in 0..6 {
            patterns.push(pattern_info(
                PatternTerm::Variable(format!("s{i}")),
                PatternTerm::Iri(format!("http://example.org/p{i}")),
                PatternTerm::Variable(format!("o{i}")),
                (i + 1) as u64 * 100,
            ));
        }

        let plan = optimizer.optimize(patterns).unwrap();
        assert!(
            !matches!(plan, JoinPlanNode::TriplePatternScan { .. }),
            "Multiple patterns should produce a join plan"
        );
    }

    #[test]
    fn test_empty_patterns_returns_error() {
        let outcome = fresh_optimizer().optimize(vec![]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_join_algorithm_selection() {
        let shared = ["x".to_string()];

        // Small side under 1000 rows: hash join.
        let small_side =
            AdaptiveJoinOrderOptimizer::select_join_algorithm(100, 1_000_000, &shared);
        assert_eq!(small_side, JoinAlgorithm::Hash);

        // Both sides above 50 000 rows: merge join.
        let both_large =
            AdaptiveJoinOrderOptimizer::select_join_algorithm(100_000, 200_000, &shared);
        assert_eq!(both_large, JoinAlgorithm::Merge);
    }

    #[test]
    fn test_from_triple_pattern() {
        let pattern = TriplePattern::new(
            make_var("s"),
            make_iri("http://example.org/p"),
            make_var("o"),
        );
        let info = TriplePatternInfo::from_triple_pattern(&pattern, 100);

        assert_eq!(info.estimated_cardinality, 100);
        assert!(info.bound_variables.contains(&"s".to_string()));
        assert!(info.bound_variables.contains(&"o".to_string()));
        // Only the predicate is a constant.
        assert_eq!(info.bound_positions(), 1);
    }

    #[test]
    fn test_cardinality_correction_with_multiple_samples() {
        let store = AdaptiveStatsStore::new(100);
        (0..5).for_each(|_| store.record_pattern_execution("pat2", 100, 200));

        let adjusted = store.get_adjusted_cardinality("pat2", 100);
        assert!(adjusted > 100, "Cardinality should be adjusted upward");
    }

    #[test]
    fn test_plan_timer_records_duration() {
        let store = Arc::new(AdaptiveStatsStore::new(100));

        let timer = PlanTimer::start("test_component", Arc::clone(&store));
        std::thread::sleep(std::time::Duration::from_millis(5));
        // Dropping the guard is what records the measurement.
        drop(timer);

        let snapshot = store.snapshot().unwrap();
        assert!(
            snapshot.execution_times.contains_key("test_component"),
            "Timer should record execution time on drop"
        );
    }

    #[test]
    fn test_output_variables_propagation() {
        let optimizer = fresh_optimizer();
        let typed = pattern_info(var("s"), iri("http://example.org/type"), var("type"), 100);
        let named = pattern_info(var("s"), iri("http://example.org/name"), var("name"), 500);

        let plan = optimizer.optimize(vec![typed, named]).unwrap();
        let vars = plan.output_variables();

        assert!(vars.contains(&"s".to_string()), "Plan should expose ?s");
        assert!(
            vars.contains(&"name".to_string()),
            "Plan should expose ?name"
        );
    }
}
1016
#[cfg(test)]
mod extended_tests {
    use super::*;
    use crate::algebra::{Term, TriplePattern};
    use oxirs_core::model::{NamedNode, Variable as CoreVariable};

    /// Builds an algebra variable term; panics if the name is rejected.
    fn make_var(name: &str) -> Term {
        Term::Variable(CoreVariable::new(name).unwrap())
    }

    /// Builds an algebra IRI term without validating the IRI.
    fn make_iri(iri: &str) -> Term {
        Term::Iri(NamedNode::new_unchecked(iri))
    }

    /// Builds a `TriplePatternInfo` directly, with a Debug-formatted id and
    /// no original algebra pattern.
    fn p_info(
        subject: PatternTerm,
        predicate: PatternTerm,
        object: PatternTerm,
        cardinality: u64,
    ) -> TriplePatternInfo {
        let bound_variables: Vec<String> = [&subject, &predicate, &object]
            .iter()
            .filter_map(|t| t.variable_name().map(|s| s.to_string()))
            .collect();
        let id = format!("{:?}-{:?}-{:?}", subject, predicate, object);
        TriplePatternInfo {
            id,
            subject,
            predicate,
            object,
            estimated_cardinality: cardinality,
            bound_variables,
            original_pattern: None,
        }
    }

    // ---- AdaptiveStatsStore ----

    #[test]
    fn test_stats_snapshot_contains_recorded_pattern() {
        let store = AdaptiveStatsStore::new(50);
        store.record_pattern_execution("snap_pat", 200, 400);

        let snapshot = store.snapshot().unwrap();
        assert!(snapshot.pattern_stats.contains_key("snap_pat"));
        let entry = &snapshot.pattern_stats["snap_pat"];
        assert_eq!(entry.sample_count, 1);
        assert_eq!(entry.actual_cardinality_sum, 400);
    }

    #[test]
    fn test_stats_snapshot_contains_recorded_join() {
        let store = AdaptiveStatsStore::new(50);
        store.record_join_execution("j_snap", 1000, 500, 25);

        let snapshot = store.snapshot().unwrap();
        assert!(snapshot.join_stats.contains_key("j_snap"));
        let entry = &snapshot.join_stats["j_snap"];
        assert_eq!(entry.sample_count, 1);
        assert_eq!(entry.output_cardinality_sum, 25);
    }

    #[test]
    fn test_correction_factor_clamped_above_zero() {
        let store = AdaptiveStatsStore::new(50);
        // A wildly over-estimated pattern must still yield at least one row.
        store.record_pattern_execution("extreme_over", 1_000_000, 1);
        let adjusted = store.get_adjusted_cardinality("extreme_over", 1_000_000);
        assert!(adjusted >= 1, "Adjusted cardinality must be at least 1");
    }

    #[test]
    fn test_multiple_patterns_tracked_independently() {
        let store = AdaptiveStatsStore::new(50);
        store.record_pattern_execution("pat_a", 100, 50);
        store.record_pattern_execution("pat_b", 100, 300);

        let adj_a = store.get_adjusted_cardinality("pat_a", 100);
        let adj_b = store.get_adjusted_cardinality("pat_b", 100);
        assert!(
            adj_a < adj_b,
            "pat_a (undercount) should produce lower estimate than pat_b (overcount)"
        );
    }

    #[test]
    fn test_execution_time_recorded_via_snapshot() {
        let store = AdaptiveStatsStore::new(50);
        store.record_execution_time("component_x", std::time::Duration::from_millis(42));
        let snapshot = store.snapshot().unwrap();
        assert!(snapshot.execution_times.contains_key("component_x"));
        assert_eq!(
            snapshot.execution_times["component_x"],
            std::time::Duration::from_millis(42)
        );
    }

    #[test]
    fn test_join_selectivity_unknown_join_returns_base() {
        let store = AdaptiveStatsStore::new(50);
        let base = 0.05;
        let adj = store.get_adjusted_selectivity("no_such_join", base);
        assert!(
            (adj - base).abs() < 1e-9,
            "Unknown join should return base selectivity unchanged"
        );
    }

    #[test]
    fn test_join_selectivity_clamps_to_valid_range() {
        let store = AdaptiveStatsStore::new(50);
        // Observed selectivity of 1e-12 is far below the lower clamp.
        for _ in 0..20 {
            store.record_join_execution("tiny_sel", 1_000_000, 1_000_000, 1);
        }
        let adj = store.get_adjusted_selectivity("tiny_sel", 0.5);
        assert!(adj > 0.0, "Selectivity must remain positive");
        assert!(adj <= 1.0, "Selectivity must not exceed 1.0");
    }

    // ---- PatternTerm / TriplePatternInfo ----

    #[test]
    fn test_pattern_term_iri_is_not_variable() {
        let term = PatternTerm::Iri("http://example.org/foo".to_string());
        assert!(!term.is_variable());
        assert!(term.variable_name().is_none());
    }

    #[test]
    fn test_pattern_term_literal_is_not_variable() {
        let term = PatternTerm::Literal("hello".to_string());
        assert!(!term.is_variable());
        assert!(term.variable_name().is_none());
    }

    #[test]
    fn test_pattern_term_blank_node_is_not_variable() {
        let term = PatternTerm::BlankNode("b1".to_string());
        assert!(!term.is_variable());
        assert!(term.variable_name().is_none());
    }

    #[test]
    fn test_triple_pattern_info_bound_positions_fully_bound() {
        let info = p_info(
            PatternTerm::Iri("http://s".to_string()),
            PatternTerm::Iri("http://p".to_string()),
            PatternTerm::Literal("val".to_string()),
            10,
        );
        assert_eq!(info.bound_positions(), 3, "All positions are bound");
    }

    #[test]
    fn test_triple_pattern_info_bound_positions_no_variables() {
        let info = p_info(
            PatternTerm::Variable("s".to_string()),
            PatternTerm::Variable("p".to_string()),
            PatternTerm::Variable("o".to_string()),
            100,
        );
        assert_eq!(
            info.bound_positions(),
            0,
            "No positions are bound when all are variables"
        );
    }

    #[test]
    fn test_from_triple_pattern_literal_object() {
        // NOTE(review): despite the test name, the object here is an IRI, not
        // a literal; the assertions only rely on it being a constant.
        let pattern = TriplePattern::new(
            make_var("s"),
            make_iri("http://example.org/p"),
            make_iri("http://example.org/o"),
        );
        let info = TriplePatternInfo::from_triple_pattern(&pattern, 42);
        assert_eq!(info.estimated_cardinality, 42);
        assert!(info.bound_variables.contains(&"s".to_string()));
        assert_eq!(
            info.bound_variables,
            vec!["s".to_string()],
            "Only the subject is a variable"
        );
        assert_eq!(
            info.bound_positions(),
            2,
            "Predicate and object are constants"
        );
    }

    // ---- JoinPlanNode ----

    #[test]
    fn test_join_plan_node_hash_join_estimated_cardinality() {
        let left = JoinPlanNode::TriplePatternScan {
            info: p_info(
                PatternTerm::Variable("s".to_string()),
                PatternTerm::Iri("http://p".to_string()),
                PatternTerm::Variable("o".to_string()),
                100,
            ),
        };
        let right = JoinPlanNode::TriplePatternScan {
            info: p_info(
                PatternTerm::Variable("s".to_string()),
                PatternTerm::Iri("http://q".to_string()),
                PatternTerm::Variable("x".to_string()),
                200,
            ),
        };
        let node = JoinPlanNode::HashJoin {
            left: Box::new(left),
            right: Box::new(right),
            join_vars: vec!["s".to_string()],
            estimated_output: 50,
        };
        assert_eq!(node.estimated_cardinality(), 50);
    }

    #[test]
    fn test_join_plan_nested_loop_output_variables() {
        let outer = JoinPlanNode::TriplePatternScan {
            info: p_info(
                PatternTerm::Variable("s".to_string()),
                PatternTerm::Iri("http://p".to_string()),
                PatternTerm::Variable("o".to_string()),
                100,
            ),
        };
        let inner = JoinPlanNode::TriplePatternScan {
            info: p_info(
                PatternTerm::Variable("o".to_string()),
                PatternTerm::Iri("http://q".to_string()),
                PatternTerm::Variable("z".to_string()),
                50,
            ),
        };
        let node = JoinPlanNode::NestedLoopJoin {
            outer: Box::new(outer),
            inner: Box::new(inner),
            join_vars: vec!["o".to_string()],
            estimated_output: 30,
        };
        let vars = node.output_variables();
        assert!(vars.contains(&"s".to_string()), "Should contain s");
        assert!(vars.contains(&"o".to_string()), "Should contain o");
        assert!(vars.contains(&"z".to_string()), "Should contain z");
    }

    // ---- AdaptiveJoinOrderOptimizer ----

    #[test]
    fn test_optimizer_selects_lower_cardinality_pattern_first() {
        let store = Arc::new(AdaptiveStatsStore::new(50));
        let optimizer = AdaptiveJoinOrderOptimizer::new(Arc::clone(&store));

        let patterns = vec![
            p_info(
                PatternTerm::Variable("s".to_string()),
                PatternTerm::Iri("http://rare".to_string()),
                PatternTerm::Variable("o1".to_string()),
                5,
            ),
            p_info(
                PatternTerm::Variable("s".to_string()),
                PatternTerm::Iri("http://common".to_string()),
                PatternTerm::Variable("o2".to_string()),
                50_000,
            ),
        ];

        let plan = optimizer.optimize(patterns).unwrap();
        assert!(
            matches!(
                plan,
                JoinPlanNode::HashJoin { .. }
                    | JoinPlanNode::NestedLoopJoin { .. }
                    | JoinPlanNode::MergeJoin { .. }
            ),
            "Two patterns should produce a join plan"
        );
    }

    #[test]
    fn test_optimizer_dp_threshold_boundary() {
        let store = Arc::new(AdaptiveStatsStore::new(50));
        // Exactly as many patterns as the threshold: the DP planner is used.
        let optimizer = AdaptiveJoinOrderOptimizer::new(Arc::clone(&store)).with_dp_threshold(4);

        let patterns: Vec<TriplePatternInfo> = (0..4)
            .map(|i| {
                p_info(
                    PatternTerm::Variable(format!("s{i}")),
                    PatternTerm::Iri(format!("http://p{i}")),
                    PatternTerm::Variable(format!("o{i}")),
                    (i + 1) as u64 * 50,
                )
            })
            .collect();

        let result = optimizer.optimize(patterns);
        assert!(
            result.is_ok(),
            "DP optimization at threshold should succeed"
        );
    }

    #[test]
    fn test_optimizer_uses_runtime_feedback_for_ordering() {
        let store = Arc::new(AdaptiveStatsStore::new(50));

        let heavy = p_info(
            PatternTerm::Variable("s".to_string()),
            PatternTerm::Iri("http://heavy".to_string()),
            PatternTerm::Variable("o".to_string()),
            10,
        );
        let light = p_info(
            PatternTerm::Variable("s".to_string()),
            PatternTerm::Iri("http://light".to_string()),
            PatternTerm::Variable("x".to_string()),
            500,
        );

        // Feedback must be recorded under the id the optimizer will look up,
        // otherwise the correction is never applied to this pattern.
        store.record_pattern_execution(&heavy.id, 10, 100_000);

        let optimizer = AdaptiveJoinOrderOptimizer::new(Arc::clone(&store));
        let result = optimizer.optimize(vec![heavy, light]);
        assert!(
            result.is_ok(),
            "Optimizer should succeed with runtime feedback"
        );

        // Uncorrected, the join estimate would be 10 * 500 * 0.1 = 500.
        let plan = result.unwrap();
        assert!(
            plan.estimated_cardinality() > 500,
            "Runtime feedback should raise the join's estimated output"
        );
    }
}