Skip to main content

oxirs_arq/optimizer/
adaptive.rs

1//! Adaptive Join Ordering with Runtime Feedback
2//!
3//! This module implements an adaptive query optimizer that learns from runtime
4//! execution statistics to continuously improve join ordering decisions.
5//! It uses a combination of cost-based optimization and runtime feedback.
6
7use crate::algebra::{Term, TriplePattern};
8use anyhow::{anyhow, Result};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12
13/// Runtime statistics collected during query execution for feedback loops
14#[derive(Debug, Clone, Default)]
15pub struct RuntimeStats {
16    /// Estimated vs actual cardinality per pattern fingerprint
17    pub pattern_stats: HashMap<String, PatternRuntimeStats>,
18    /// Join selectivity statistics per join fingerprint
19    pub join_stats: HashMap<String, JoinRuntimeStats>,
20    /// Execution times per plan component
21    pub execution_times: HashMap<String, Duration>,
22    /// Total number of queries optimized
23    pub query_count: u64,
24}
25
/// Cardinality feedback accumulated for one triple-pattern fingerprint.
#[derive(Clone, Debug, Default)]
pub struct PatternRuntimeStats {
    /// Running sum of estimated cardinalities; collapsed to the per-sample
    /// average whenever the store trims this entry's history.
    pub estimated_cardinality_sum: u64,
    /// Running sum of actual cardinalities; collapsed to the per-sample
    /// average whenever the store trims this entry's history.
    pub actual_cardinality_sum: u64,
    /// `actual / estimated` ratio of the most recent execution.
    pub estimation_error: f64,
    /// Number of samples folded into the sums since the last trim.
    pub sample_count: u64,
    /// Smoothed `actual / estimated` ratio used to scale future estimates.
    pub correction_factor: f64,
}
40
/// Selectivity feedback accumulated for one join fingerprint.
#[derive(Clone, Debug, Default)]
pub struct JoinRuntimeStats {
    /// Running sum of left-input cardinalities.
    pub left_cardinality_sum: u64,
    /// Running sum of right-input cardinalities.
    pub right_cardinality_sum: u64,
    /// Running sum of output cardinalities.
    pub output_cardinality_sum: u64,
    /// Smoothed selectivity, where each sample is `output / (left * right)`.
    pub observed_selectivity: f64,
    /// Number of executions recorded for this join.
    pub sample_count: u64,
}
55
56/// Thread-safe store for adaptive runtime statistics
57pub struct AdaptiveStatsStore {
58    stats: Arc<RwLock<RuntimeStats>>,
59    max_history: usize,
60}
61
62impl AdaptiveStatsStore {
63    /// Create a new adaptive statistics store
64    pub fn new(max_history: usize) -> Self {
65        Self {
66            stats: Arc::new(RwLock::new(RuntimeStats::default())),
67            max_history,
68        }
69    }
70
71    /// Record actual vs estimated cardinality for a pattern
72    pub fn record_pattern_execution(&self, pattern_id: &str, estimated: u64, actual: u64) {
73        let Ok(mut stats) = self.stats.write() else {
74            return;
75        };
76        let entry = stats
77            .pattern_stats
78            .entry(pattern_id.to_string())
79            .or_default();
80
81        entry.sample_count += 1;
82        entry.estimated_cardinality_sum += estimated;
83        entry.actual_cardinality_sum += actual;
84
85        let ratio = if estimated > 0 {
86            actual as f64 / estimated as f64
87        } else {
88            1.0
89        };
90        entry.estimation_error = ratio;
91
92        // Exponential moving average of correction factor (alpha = 0.2)
93        if entry.sample_count == 1 {
94            entry.correction_factor = ratio;
95        } else {
96            entry.correction_factor = 0.8 * entry.correction_factor + 0.2 * ratio;
97        }
98
99        // Trim history if needed by resetting sums when they exceed max_history
100        if entry.sample_count > self.max_history as u64 {
101            let avg_est = entry.estimated_cardinality_sum / entry.sample_count;
102            let avg_act = entry.actual_cardinality_sum / entry.sample_count;
103            entry.estimated_cardinality_sum = avg_est;
104            entry.actual_cardinality_sum = avg_act;
105            entry.sample_count = 1;
106        }
107    }
108
109    /// Record actual join execution statistics
110    pub fn record_join_execution(&self, join_id: &str, left: u64, right: u64, output: u64) {
111        let Ok(mut stats) = self.stats.write() else {
112            return;
113        };
114        let entry = stats.join_stats.entry(join_id.to_string()).or_default();
115
116        entry.sample_count += 1;
117        entry.left_cardinality_sum += left;
118        entry.right_cardinality_sum += right;
119        entry.output_cardinality_sum += output;
120
121        let denominator = (left as f64) * (right as f64);
122        let selectivity = if denominator > 0.0 {
123            output as f64 / denominator
124        } else {
125            0.0
126        };
127
128        // Exponential moving average
129        if entry.sample_count == 1 {
130            entry.observed_selectivity = selectivity;
131        } else {
132            entry.observed_selectivity = 0.8 * entry.observed_selectivity + 0.2 * selectivity;
133        }
134    }
135
136    /// Record execution time for a plan component
137    pub fn record_execution_time(&self, component_id: &str, duration: Duration) {
138        let Ok(mut stats) = self.stats.write() else {
139            return;
140        };
141        stats
142            .execution_times
143            .insert(component_id.to_string(), duration);
144    }
145
146    /// Get adjusted cardinality estimate incorporating runtime feedback
147    pub fn get_adjusted_cardinality(&self, pattern_id: &str, base_estimate: u64) -> u64 {
148        let Ok(stats) = self.stats.read() else {
149            return base_estimate;
150        };
151        let Some(entry) = stats.pattern_stats.get(pattern_id) else {
152            return base_estimate;
153        };
154
155        if entry.sample_count == 0 {
156            return base_estimate;
157        }
158
159        let adjusted = (base_estimate as f64 * entry.correction_factor).round() as u64;
160        adjusted.max(1)
161    }
162
163    /// Get adjusted selectivity incorporating runtime feedback
164    pub fn get_adjusted_selectivity(&self, join_id: &str, base_selectivity: f64) -> f64 {
165        let Ok(stats) = self.stats.read() else {
166            return base_selectivity;
167        };
168        let Some(entry) = stats.join_stats.get(join_id) else {
169            return base_selectivity;
170        };
171
172        if entry.sample_count == 0 {
173            return base_selectivity;
174        }
175
176        // Blend base estimate with observed selectivity (weight toward observed as samples grow)
177        let observed_weight = (entry.sample_count as f64 / 10.0).min(0.8);
178        let base_weight = 1.0 - observed_weight;
179        (base_weight * base_selectivity + observed_weight * entry.observed_selectivity)
180            .clamp(0.0001, 1.0)
181    }
182
183    /// Snapshot a read-only view of the current statistics
184    pub fn snapshot(&self) -> Option<RuntimeStats> {
185        self.stats.read().ok().map(|s| s.clone())
186    }
187}
188
/// Physical join strategy chosen for a join node.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum JoinAlgorithm {
    /// Hash join: suited to cases where one input is small enough to hold in memory.
    Hash,
    /// Nested-loop join: suited to a tiny outer input with index support on the inner.
    NestedLoop,
    /// Sort-merge join: suited to large inputs ordered on shared sort keys.
    Merge,
}
199
/// A term occupying the subject, predicate, or object slot of a triple pattern.
#[derive(Clone, Debug)]
pub enum PatternTerm {
    /// A query variable, stored by name.
    Variable(String),
    /// An IRI, stored as a string.
    Iri(String),
    /// A literal, stored as a string.
    Literal(String),
    /// A blank node, stored by label.
    BlankNode(String),
}

impl PatternTerm {
    /// Whether this term is a variable, i.e. an unbound position.
    pub fn is_variable(&self) -> bool {
        self.variable_name().is_some()
    }

    /// The variable's name, or `None` for any non-variable term.
    pub fn variable_name(&self) -> Option<&str> {
        if let PatternTerm::Variable(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }
}
223
224/// Full information about a triple pattern for optimization purposes
225#[derive(Debug, Clone)]
226pub struct TriplePatternInfo {
227    /// Unique identifier (fingerprint) for this pattern
228    pub id: String,
229    pub subject: PatternTerm,
230    pub predicate: PatternTerm,
231    pub object: PatternTerm,
232    /// Estimated number of matching triples
233    pub estimated_cardinality: u64,
234    /// Variable names bound by this pattern
235    pub bound_variables: Vec<String>,
236    /// Reference to the original TriplePattern (for re-use downstream)
237    pub original_pattern: Option<TriplePattern>,
238}
239
240impl TriplePatternInfo {
241    /// Construct a TriplePatternInfo from an algebra TriplePattern and cardinality estimate
242    pub fn from_triple_pattern(pattern: &TriplePattern, estimated_cardinality: u64) -> Self {
243        let subject = term_to_pattern_term(&pattern.subject);
244        let predicate = term_to_pattern_term(&pattern.predicate);
245        let object = term_to_pattern_term(&pattern.object);
246
247        let mut bound_variables = Vec::new();
248        if let PatternTerm::Variable(ref v) = subject {
249            bound_variables.push(v.clone());
250        }
251        if let PatternTerm::Variable(ref v) = predicate {
252            bound_variables.push(v.clone());
253        }
254        if let PatternTerm::Variable(ref v) = object {
255            bound_variables.push(v.clone());
256        }
257
258        // Build a stable pattern fingerprint
259        let id = build_pattern_fingerprint(&subject, &predicate, &object);
260
261        Self {
262            id,
263            subject,
264            predicate,
265            object,
266            estimated_cardinality,
267            bound_variables,
268            original_pattern: Some(pattern.clone()),
269        }
270    }
271
272    /// Number of unbound (variable) positions - lower = more selective
273    pub fn bound_positions(&self) -> usize {
274        let mut count = 0;
275        if !self.subject.is_variable() {
276            count += 1;
277        }
278        if !self.predicate.is_variable() {
279            count += 1;
280        }
281        if !self.object.is_variable() {
282            count += 1;
283        }
284        count
285    }
286}
287
288fn term_to_pattern_term(term: &Term) -> PatternTerm {
289    match term {
290        Term::Variable(v) => PatternTerm::Variable(v.name().to_string()),
291        Term::Iri(iri) => PatternTerm::Iri(iri.as_str().to_string()),
292        Term::Literal(lit) => PatternTerm::Literal(lit.value.clone()),
293        Term::BlankNode(bn) => PatternTerm::BlankNode(bn.as_str().to_string()),
294        // Treat quoted triples and property paths as opaque IRIs for ordering purposes
295        _ => PatternTerm::Iri(format!("{term}")),
296    }
297}
298
299fn build_pattern_fingerprint(
300    subject: &PatternTerm,
301    predicate: &PatternTerm,
302    object: &PatternTerm,
303) -> String {
304    let s = match subject {
305        PatternTerm::Variable(_) => "?".to_string(),
306        PatternTerm::Iri(v) => v.clone(),
307        PatternTerm::Literal(v) => format!("\"{v}\""),
308        PatternTerm::BlankNode(v) => format!("_:{v}"),
309    };
310    let p = match predicate {
311        PatternTerm::Variable(_) => "?".to_string(),
312        PatternTerm::Iri(v) => v.clone(),
313        PatternTerm::Literal(v) => format!("\"{v}\""),
314        PatternTerm::BlankNode(v) => format!("_:{v}"),
315    };
316    let o = match object {
317        PatternTerm::Variable(_) => "?".to_string(),
318        PatternTerm::Iri(v) => v.clone(),
319        PatternTerm::Literal(v) => format!("\"{v}\""),
320        PatternTerm::BlankNode(v) => format!("_:{v}"),
321    };
322    format!("{s} {p} {o}")
323}
324
325/// A join plan node - the output of the adaptive optimizer
326#[derive(Debug, Clone)]
327#[allow(clippy::large_enum_variant)]
328pub enum JoinPlanNode {
329    /// A leaf scan of a single triple pattern
330    TriplePatternScan { info: TriplePatternInfo },
331    /// Hash join of two sub-plans
332    HashJoin {
333        left: Box<JoinPlanNode>,
334        right: Box<JoinPlanNode>,
335        join_vars: Vec<String>,
336        estimated_output: u64,
337    },
338    /// Nested-loop join (outer drives inner)
339    NestedLoopJoin {
340        outer: Box<JoinPlanNode>,
341        inner: Box<JoinPlanNode>,
342        join_vars: Vec<String>,
343        estimated_output: u64,
344    },
345    /// Sort-merge join of two ordered sub-plans
346    MergeJoin {
347        left: Box<JoinPlanNode>,
348        right: Box<JoinPlanNode>,
349        join_vars: Vec<String>,
350        sort_key: Vec<String>,
351        estimated_output: u64,
352    },
353}
354
355impl JoinPlanNode {
356    /// Estimated output cardinality of this node
357    pub fn estimated_cardinality(&self) -> u64 {
358        match self {
359            JoinPlanNode::TriplePatternScan { info } => info.estimated_cardinality,
360            JoinPlanNode::HashJoin {
361                estimated_output, ..
362            } => *estimated_output,
363            JoinPlanNode::NestedLoopJoin {
364                estimated_output, ..
365            } => *estimated_output,
366            JoinPlanNode::MergeJoin {
367                estimated_output, ..
368            } => *estimated_output,
369        }
370    }
371
372    /// Collect all variable names produced by this node
373    pub fn output_variables(&self) -> Vec<String> {
374        match self {
375            JoinPlanNode::TriplePatternScan { info } => info.bound_variables.clone(),
376            JoinPlanNode::HashJoin { left, right, .. } => {
377                merge_variable_sets(left.output_variables(), right.output_variables())
378            }
379            JoinPlanNode::NestedLoopJoin { outer, inner, .. } => {
380                merge_variable_sets(outer.output_variables(), inner.output_variables())
381            }
382            JoinPlanNode::MergeJoin { left, right, .. } => {
383                merge_variable_sets(left.output_variables(), right.output_variables())
384            }
385        }
386    }
387}
388
/// Appends to `left` every name from `right` that `left` does not yet contain.
///
/// The order of `left` is kept and new names follow in the order they appear
/// in `right`. Repeats inside `right` are added only once; repeats already
/// present in `left` are left as they are.
fn merge_variable_sets(mut left: Vec<String>, right: Vec<String>) -> Vec<String> {
    for candidate in right {
        let already_present = left.iter().any(|existing| *existing == candidate);
        if !already_present {
            left.push(candidate);
        }
    }
    left
}
397
398/// Adaptive join order optimizer using dynamic programming with runtime feedback
399pub struct AdaptiveJoinOrderOptimizer {
400    /// Shared statistics store for feedback
401    stats_store: Arc<AdaptiveStatsStore>,
402    /// Above this threshold, fall back to greedy heuristics
403    max_patterns_for_dp: usize,
404    /// Default join selectivity when no history is available
405    default_selectivity: f64,
406}
407
408impl AdaptiveJoinOrderOptimizer {
409    /// Create a new optimizer with a reference to the shared stats store
410    pub fn new(stats_store: Arc<AdaptiveStatsStore>) -> Self {
411        Self {
412            stats_store,
413            max_patterns_for_dp: 8,
414            default_selectivity: 0.1,
415        }
416    }
417
418    /// Override DP threshold
419    pub fn with_dp_threshold(mut self, threshold: usize) -> Self {
420        self.max_patterns_for_dp = threshold;
421        self
422    }
423
424    /// Main optimization entry point
425    pub fn optimize(&self, patterns: Vec<TriplePatternInfo>) -> Result<JoinPlanNode> {
426        if patterns.is_empty() {
427            return Err(anyhow!("Cannot optimize empty pattern list"));
428        }
429        if patterns.len() == 1 {
430            return Ok(JoinPlanNode::TriplePatternScan {
431                info: patterns.into_iter().next().expect("checked len == 1"),
432            });
433        }
434
435        // Apply adaptive cardinality corrections
436        let adjusted = self.apply_cardinality_corrections(patterns);
437
438        if adjusted.len() <= self.max_patterns_for_dp {
439            self.dp_optimize(&adjusted)
440        } else {
441            self.greedy_optimize(&adjusted)
442        }
443    }
444
445    /// Apply cardinality corrections from the stats store to all patterns
446    fn apply_cardinality_corrections(
447        &self,
448        patterns: Vec<TriplePatternInfo>,
449    ) -> Vec<TriplePatternInfo> {
450        patterns
451            .into_iter()
452            .map(|mut p| {
453                let adjusted = self
454                    .stats_store
455                    .get_adjusted_cardinality(&p.id, p.estimated_cardinality);
456                p.estimated_cardinality = adjusted;
457                p
458            })
459            .collect()
460    }
461
462    /// Dynamic programming optimizer for small numbers of patterns
463    ///
464    /// Uses the classic Selinger-style DP approach: enumerate all subsets,
465    /// find optimal plan for each, then combine bottom-up.
466    fn dp_optimize(&self, patterns: &[TriplePatternInfo]) -> Result<JoinPlanNode> {
467        let n = patterns.len();
468        // dp[mask] = best (cost, plan) for the subset encoded by `mask`
469        // mask is a bitmask of pattern indices
470        let total_masks = 1usize << n;
471        let mut dp: Vec<Option<(f64, JoinPlanNode)>> = vec![None; total_masks];
472
473        // Initialize single-pattern entries
474        for (i, pattern) in patterns.iter().enumerate() {
475            let mask = 1usize << i;
476            let plan = JoinPlanNode::TriplePatternScan {
477                info: pattern.clone(),
478            };
479            let cost = self.scan_cost(pattern);
480            dp[mask] = Some((cost, plan));
481        }
482
483        // Fill DP table bottom-up (increasing subset size)
484        for mask in 1..total_masks {
485            // Skip single-bit masks (already initialized) and zero
486            let bit_count = mask.count_ones() as usize;
487            if bit_count < 2 {
488                continue;
489            }
490
491            let mut best: Option<(f64, JoinPlanNode)> = None;
492
493            // Enumerate all proper subsets of mask as the left side
494            let mut left_mask = (mask - 1) & mask;
495            while left_mask > 0 {
496                let right_mask = mask ^ left_mask;
497                if right_mask == 0 {
498                    left_mask = (left_mask - 1) & mask;
499                    continue;
500                }
501
502                // Avoid duplicate pairs (left, right) == (right, left) by requiring left < right
503                if left_mask >= right_mask {
504                    left_mask = (left_mask - 1) & mask;
505                    continue;
506                }
507
508                let (Some((left_cost, ref left_plan)), Some((right_cost, ref right_plan))) =
509                    (&dp[left_mask], &dp[right_mask])
510                else {
511                    left_mask = (left_mask - 1) & mask;
512                    continue;
513                };
514
515                let left_vars = left_plan.output_variables();
516                let right_vars = right_plan.output_variables();
517                let join_vars = Self::find_join_variables_sets(&left_vars, &right_vars);
518
519                // No shared variables means a cross product - penalize heavily
520                let join_id = format!("{left_mask}x{right_mask}");
521                let selectivity = if join_vars.is_empty() {
522                    1.0 // cross product
523                } else {
524                    self.stats_store
525                        .get_adjusted_selectivity(&join_id, self.default_selectivity)
526                };
527
528                let left_card = left_plan.estimated_cardinality();
529                let right_card = right_plan.estimated_cardinality();
530                let output_card =
531                    ((left_card as f64 * right_card as f64 * selectivity).round() as u64).max(1);
532
533                let algorithm = Self::select_join_algorithm(left_card, right_card, &join_vars);
534                let join_cost =
535                    self.join_cost(left_cost + right_cost, left_card, right_card, &algorithm);
536                let total_cost = left_cost + right_cost + join_cost;
537
538                if best.is_none() || total_cost < best.as_ref().map(|(c, _)| *c).unwrap_or(f64::MAX)
539                {
540                    let plan = self.build_join_plan(
541                        left_plan.clone(),
542                        right_plan.clone(),
543                        join_vars,
544                        output_card,
545                        algorithm,
546                    );
547                    best = Some((total_cost, plan));
548                }
549
550                left_mask = (left_mask - 1) & mask;
551            }
552
553            if best.is_some() {
554                dp[mask] = best;
555            }
556        }
557
558        let full_mask = total_masks - 1;
559        dp[full_mask]
560            .take()
561            .map(|(_, plan)| plan)
562            .ok_or_else(|| anyhow!("DP optimizer failed to find a valid plan"))
563    }
564
565    /// Greedy optimizer for large numbers of patterns
566    ///
567    /// Repeatedly picks the cheapest next pattern to join with the current running plan.
568    fn greedy_optimize(&self, patterns: &[TriplePatternInfo]) -> Result<JoinPlanNode> {
569        if patterns.is_empty() {
570            return Err(anyhow!("Cannot optimize empty pattern list"));
571        }
572
573        // Sort patterns by estimated cardinality ascending (most selective first)
574        let mut remaining: Vec<TriplePatternInfo> = patterns.to_vec();
575        remaining.sort_by_key(|p| p.estimated_cardinality);
576
577        // Start with the most selective pattern
578        let first = remaining.remove(0);
579        let mut current_plan = JoinPlanNode::TriplePatternScan { info: first };
580
581        while !remaining.is_empty() {
582            // Find the next best pattern to join
583            let mut best_idx = 0;
584            let mut best_cost = f64::MAX;
585
586            let current_vars = current_plan.output_variables();
587            let current_card = current_plan.estimated_cardinality();
588
589            for (idx, candidate) in remaining.iter().enumerate() {
590                let join_vars =
591                    Self::find_join_variables_sets(&current_vars, &candidate.bound_variables);
592                let join_id = format!("g_{idx}_{}", candidate.id);
593                let selectivity = self
594                    .stats_store
595                    .get_adjusted_selectivity(&join_id, self.default_selectivity);
596
597                let algorithm = Self::select_join_algorithm(
598                    current_card,
599                    candidate.estimated_cardinality,
600                    &join_vars,
601                );
602                let cost = self.join_cost(
603                    0.0,
604                    current_card,
605                    candidate.estimated_cardinality,
606                    &algorithm,
607                );
608
609                // Prefer patterns that share variables (avoid cross products)
610                let adjusted_cost = if join_vars.is_empty() {
611                    cost * 1000.0
612                } else {
613                    cost * (1.0 + (1.0 - selectivity))
614                };
615
616                if adjusted_cost < best_cost {
617                    best_cost = adjusted_cost;
618                    best_idx = idx;
619                }
620            }
621
622            let next = remaining.remove(best_idx);
623            let join_vars = Self::find_join_variables_sets(&current_vars, &next.bound_variables);
624            let selectivity = self.stats_store.get_adjusted_selectivity(
625                &format!("g_{best_idx}_{}", next.id),
626                self.default_selectivity,
627            );
628            let next_card = next.estimated_cardinality;
629            let output_card =
630                ((current_card as f64 * next_card as f64 * selectivity).round() as u64).max(1);
631            let algorithm = Self::select_join_algorithm(current_card, next_card, &join_vars);
632            let right_plan = JoinPlanNode::TriplePatternScan { info: next };
633
634            current_plan =
635                self.build_join_plan(current_plan, right_plan, join_vars, output_card, algorithm);
636        }
637
638        Ok(current_plan)
639    }
640
641    /// Estimate cost of scanning a single triple pattern
642    fn scan_cost(&self, pattern: &TriplePatternInfo) -> f64 {
643        // Base cost: proportional to estimated cardinality
644        // Bound positions reduce cost due to index selectivity
645        let base = pattern.estimated_cardinality as f64;
646        let bound_factor = match pattern.bound_positions() {
647            0 => 1.0,  // full scan
648            1 => 0.3,  // one index lookup
649            2 => 0.05, // two-component lookup
650            _ => 0.01, // fully bound - cheap existence check
651        };
652        base * bound_factor
653    }
654
655    /// Estimate cost of a join given child costs and sizes
656    fn join_cost(
657        &self,
658        children_cost: f64,
659        left_card: u64,
660        right_card: u64,
661        algorithm: &JoinAlgorithm,
662    ) -> f64 {
663        let l = left_card as f64;
664        let r = right_card as f64;
665        match algorithm {
666            JoinAlgorithm::Hash => {
667                // Build: O(r), Probe: O(l)
668                children_cost + r + l
669            }
670            JoinAlgorithm::NestedLoop => {
671                // O(l * r) - expensive, only good when outer is tiny
672                children_cost + l * r
673            }
674            JoinAlgorithm::Merge => {
675                // Sort: O(n log n) each, Merge: O(n+m)
676                children_cost + l * l.max(1.0).ln() + r * r.max(1.0).ln() + l + r
677            }
678        }
679    }
680
681    /// Find shared variables between two variable sets
682    fn find_join_variables_sets(left: &[String], right: &[String]) -> Vec<String> {
683        left.iter().filter(|v| right.contains(v)).cloned().collect()
684    }
685
686    /// Select best join algorithm based on estimated cardinalities
687    pub fn select_join_algorithm(
688        left_card: u64,
689        right_card: u64,
690        join_vars: &[String],
691    ) -> JoinAlgorithm {
692        if join_vars.is_empty() {
693            // Cross product - nested loop is simplest for tiny outer
694            if left_card.min(right_card) < 100 {
695                return JoinAlgorithm::NestedLoop;
696            }
697            return JoinAlgorithm::Hash;
698        }
699
700        let smaller = left_card.min(right_card);
701        let larger = left_card.max(right_card);
702
703        if smaller < 1000 {
704            // Small build side - hash join is ideal
705            JoinAlgorithm::Hash
706        } else if smaller > 50_000 && larger > 50_000 {
707            // Both sides large - merge join amortizes sort cost
708            JoinAlgorithm::Merge
709        } else {
710            JoinAlgorithm::Hash
711        }
712    }
713
714    /// Build a JoinPlanNode from two sub-plans
715    fn build_join_plan(
716        &self,
717        left: JoinPlanNode,
718        right: JoinPlanNode,
719        join_vars: Vec<String>,
720        estimated_output: u64,
721        algorithm: JoinAlgorithm,
722    ) -> JoinPlanNode {
723        match algorithm {
724            JoinAlgorithm::Hash => JoinPlanNode::HashJoin {
725                left: Box::new(left),
726                right: Box::new(right),
727                join_vars,
728                estimated_output,
729            },
730            JoinAlgorithm::NestedLoop => {
731                // Ensure the smaller side is outer
732                JoinPlanNode::NestedLoopJoin {
733                    outer: Box::new(left),
734                    inner: Box::new(right),
735                    join_vars,
736                    estimated_output,
737                }
738            }
739            JoinAlgorithm::Merge => {
740                let sort_key = join_vars.clone();
741                JoinPlanNode::MergeJoin {
742                    left: Box::new(left),
743                    right: Box::new(right),
744                    join_vars,
745                    sort_key,
746                    estimated_output,
747                }
748            }
749        }
750    }
751}
752
753/// Execution timer for recording plan execution times
754pub struct PlanTimer {
755    component_id: String,
756    start: Instant,
757    stats_store: Arc<AdaptiveStatsStore>,
758}
759
760impl PlanTimer {
761    /// Start timing a plan component
762    pub fn start(component_id: impl Into<String>, stats_store: Arc<AdaptiveStatsStore>) -> Self {
763        Self {
764            component_id: component_id.into(),
765            start: Instant::now(),
766            stats_store,
767        }
768    }
769}
770
771impl Drop for PlanTimer {
772    fn drop(&mut self) {
773        let elapsed = self.start.elapsed();
774        self.stats_store
775            .record_execution_time(&self.component_id, elapsed);
776    }
777}
778
779#[cfg(test)]
780mod tests {
781    use super::*;
782    use crate::algebra::{Term, TriplePattern};
783    use oxirs_core::model::{NamedNode, Variable as CoreVariable};
784
785    fn make_var(name: &str) -> Term {
786        Term::Variable(CoreVariable::new(name).unwrap())
787    }
788
789    fn make_iri(iri: &str) -> Term {
790        Term::Iri(NamedNode::new_unchecked(iri))
791    }
792
793    fn pattern_info(
794        subject: PatternTerm,
795        predicate: PatternTerm,
796        object: PatternTerm,
797        cardinality: u64,
798    ) -> TriplePatternInfo {
799        let bound_variables: Vec<String> = [&subject, &predicate, &object]
800            .iter()
801            .filter_map(|t| t.variable_name().map(|s| s.to_string()))
802            .collect();
803        let id = format!("{:?}-{:?}-{:?}", subject, predicate, object);
804        TriplePatternInfo {
805            id,
806            subject,
807            predicate,
808            object,
809            estimated_cardinality: cardinality,
810            bound_variables,
811            original_pattern: None,
812        }
813    }
814
815    #[test]
816    fn test_adaptive_stats_store_record_and_adjust() {
817        let store = AdaptiveStatsStore::new(100);
818        store.record_pattern_execution("pat1", 1000, 500);
819
820        // Correction factor should be 0.5 after one sample
821        let adjusted = store.get_adjusted_cardinality("pat1", 1000);
822        assert_eq!(
823            adjusted, 500,
824            "Adjusted cardinality should reflect correction factor"
825        );
826    }
827
828    #[test]
829    fn test_adaptive_stats_store_unknown_pattern_returns_base() {
830        let store = AdaptiveStatsStore::new(100);
831        let adjusted = store.get_adjusted_cardinality("unknown_pat", 500);
832        assert_eq!(
833            adjusted, 500,
834            "Unknown pattern should return base estimate unchanged"
835        );
836    }
837
838    #[test]
839    fn test_adaptive_stats_store_join_selectivity() {
840        let store = AdaptiveStatsStore::new(100);
841        // actual output = 50, left = 100, right = 200 => selectivity = 50/20000 = 0.0025
842        store.record_join_execution("j1", 100, 200, 50);
843
844        let adjusted = store.get_adjusted_selectivity("j1", 0.1);
845        // Should blend base (0.1) with observed (0.0025)
846        assert!(
847            adjusted < 0.1,
848            "Adjusted selectivity should be reduced toward observed value"
849        );
850        assert!(adjusted > 0.0, "Adjusted selectivity must remain positive");
851    }
852
853    #[test]
854    fn test_single_pattern_optimization() {
855        let store = Arc::new(AdaptiveStatsStore::new(100));
856        let optimizer = AdaptiveJoinOrderOptimizer::new(store);
857
858        let patterns = vec![pattern_info(
859            PatternTerm::Variable("s".to_string()),
860            PatternTerm::Iri("http://example.org/type".to_string()),
861            PatternTerm::Variable("o".to_string()),
862            500,
863        )];
864
865        let plan = optimizer.optimize(patterns).unwrap();
866        assert!(matches!(plan, JoinPlanNode::TriplePatternScan { .. }));
867    }
868
869    #[test]
870    fn test_two_pattern_dp_optimization() {
871        let store = Arc::new(AdaptiveStatsStore::new(100));
872        let optimizer = AdaptiveJoinOrderOptimizer::new(store);
873
874        let patterns = vec![
875            pattern_info(
876                PatternTerm::Variable("s".to_string()),
877                PatternTerm::Iri("http://example.org/type".to_string()),
878                PatternTerm::Iri("http://example.org/Person".to_string()),
879                50,
880            ),
881            pattern_info(
882                PatternTerm::Variable("s".to_string()),
883                PatternTerm::Iri("http://xmlns.com/foaf/0.1/name".to_string()),
884                PatternTerm::Variable("name".to_string()),
885                10000,
886            ),
887        ];
888
889        let plan = optimizer.optimize(patterns).unwrap();
890        // Should produce a join node
891        assert!(
892            matches!(
893                plan,
894                JoinPlanNode::HashJoin { .. }
895                    | JoinPlanNode::NestedLoopJoin { .. }
896                    | JoinPlanNode::MergeJoin { .. }
897            ),
898            "Should produce a join plan"
899        );
900    }
901
902    #[test]
903    fn test_greedy_optimization_for_large_pattern_sets() {
904        let store = Arc::new(AdaptiveStatsStore::new(100));
905        let optimizer = AdaptiveJoinOrderOptimizer::new(store).with_dp_threshold(3);
906
907        let patterns: Vec<TriplePatternInfo> = (0..6)
908            .map(|i| {
909                pattern_info(
910                    PatternTerm::Variable(format!("s{i}")),
911                    PatternTerm::Iri(format!("http://example.org/p{i}")),
912                    PatternTerm::Variable(format!("o{i}")),
913                    (i + 1) as u64 * 100,
914                )
915            })
916            .collect();
917
918        let plan = optimizer.optimize(patterns).unwrap();
919        // Should produce some kind of join
920        assert!(
921            !matches!(plan, JoinPlanNode::TriplePatternScan { .. }),
922            "Multiple patterns should produce a join plan"
923        );
924    }
925
926    #[test]
927    fn test_empty_patterns_returns_error() {
928        let store = Arc::new(AdaptiveStatsStore::new(100));
929        let optimizer = AdaptiveJoinOrderOptimizer::new(store);
930        assert!(optimizer.optimize(vec![]).is_err());
931    }
932
933    #[test]
934    fn test_join_algorithm_selection() {
935        // Small build side -> hash join
936        let alg =
937            AdaptiveJoinOrderOptimizer::select_join_algorithm(100, 1_000_000, &["x".to_string()]);
938        assert_eq!(alg, JoinAlgorithm::Hash);
939
940        // Both large -> merge join
941        let alg =
942            AdaptiveJoinOrderOptimizer::select_join_algorithm(100_000, 200_000, &["x".to_string()]);
943        assert_eq!(alg, JoinAlgorithm::Merge);
944    }
945
946    #[test]
947    fn test_from_triple_pattern() {
948        let pattern = TriplePattern::new(
949            make_var("s"),
950            make_iri("http://example.org/p"),
951            make_var("o"),
952        );
953        let info = TriplePatternInfo::from_triple_pattern(&pattern, 100);
954        assert_eq!(info.estimated_cardinality, 100);
955        assert!(info.bound_variables.contains(&"s".to_string()));
956        assert!(info.bound_variables.contains(&"o".to_string()));
957        assert_eq!(info.bound_positions(), 1); // predicate is bound
958    }
959
960    #[test]
961    fn test_cardinality_correction_with_multiple_samples() {
962        let store = AdaptiveStatsStore::new(100);
963        // Repeat with consistent underestimation (estimated 100, actual 200)
964        for _ in 0..5 {
965            store.record_pattern_execution("pat2", 100, 200);
966        }
967        let adjusted = store.get_adjusted_cardinality("pat2", 100);
968        // Correction factor should converge toward 2.0
969        assert!(adjusted > 100, "Cardinality should be adjusted upward");
970    }
971
972    #[test]
973    fn test_plan_timer_records_duration() {
974        let store = Arc::new(AdaptiveStatsStore::new(100));
975        {
976            let _timer = PlanTimer::start("test_component", Arc::clone(&store));
977            std::thread::sleep(std::time::Duration::from_millis(5));
978        }
979        let snapshot = store.snapshot().unwrap();
980        assert!(
981            snapshot.execution_times.contains_key("test_component"),
982            "Timer should record execution time on drop"
983        );
984    }
985
986    #[test]
987    fn test_output_variables_propagation() {
988        let store = Arc::new(AdaptiveStatsStore::new(100));
989        let optimizer = AdaptiveJoinOrderOptimizer::new(store);
990
991        let patterns = vec![
992            pattern_info(
993                PatternTerm::Variable("s".to_string()),
994                PatternTerm::Iri("http://example.org/type".to_string()),
995                PatternTerm::Variable("type".to_string()),
996                100,
997            ),
998            pattern_info(
999                PatternTerm::Variable("s".to_string()),
1000                PatternTerm::Iri("http://example.org/name".to_string()),
1001                PatternTerm::Variable("name".to_string()),
1002                500,
1003            ),
1004        ];
1005
1006        let plan = optimizer.optimize(patterns).unwrap();
1007        let vars = plan.output_variables();
1008        // Both ?s and ?type and ?name should appear
1009        assert!(vars.contains(&"s".to_string()), "Plan should expose ?s");
1010        assert!(
1011            vars.contains(&"name".to_string()),
1012            "Plan should expose ?name"
1013        );
1014    }
1015}
1016
1017#[cfg(test)]
1018mod extended_tests {
1019    use super::*;
1020    use crate::algebra::{Term, TriplePattern};
1021    use oxirs_core::model::{NamedNode, Variable as CoreVariable};
1022
1023    fn make_var(name: &str) -> Term {
1024        Term::Variable(CoreVariable::new(name).unwrap())
1025    }
1026
1027    fn make_iri(iri: &str) -> Term {
1028        Term::Iri(NamedNode::new_unchecked(iri))
1029    }
1030
1031    fn p_info(
1032        subject: PatternTerm,
1033        predicate: PatternTerm,
1034        object: PatternTerm,
1035        cardinality: u64,
1036    ) -> TriplePatternInfo {
1037        let bound_variables: Vec<String> = [&subject, &predicate, &object]
1038            .iter()
1039            .filter_map(|t| t.variable_name().map(|s| s.to_string()))
1040            .collect();
1041        let id = format!("{:?}-{:?}-{:?}", subject, predicate, object);
1042        TriplePatternInfo {
1043            id,
1044            subject,
1045            predicate,
1046            object,
1047            estimated_cardinality: cardinality,
1048            bound_variables,
1049            original_pattern: None,
1050        }
1051    }
1052
1053    // --- AdaptiveStatsStore tests ---
1054
1055    #[test]
1056    fn test_stats_snapshot_contains_recorded_pattern() {
1057        let store = AdaptiveStatsStore::new(50);
1058        store.record_pattern_execution("snap_pat", 200, 400);
1059
1060        let snapshot = store.snapshot().unwrap();
1061        assert!(snapshot.pattern_stats.contains_key("snap_pat"));
1062        let entry = &snapshot.pattern_stats["snap_pat"];
1063        assert_eq!(entry.sample_count, 1);
1064        assert_eq!(entry.actual_cardinality_sum, 400);
1065    }
1066
1067    #[test]
1068    fn test_stats_snapshot_contains_recorded_join() {
1069        let store = AdaptiveStatsStore::new(50);
1070        store.record_join_execution("j_snap", 1000, 500, 25);
1071
1072        let snapshot = store.snapshot().unwrap();
1073        assert!(snapshot.join_stats.contains_key("j_snap"));
1074        let entry = &snapshot.join_stats["j_snap"];
1075        assert_eq!(entry.sample_count, 1);
1076        assert_eq!(entry.output_cardinality_sum, 25);
1077    }
1078
1079    #[test]
1080    fn test_correction_factor_clamped_above_zero() {
1081        let store = AdaptiveStatsStore::new(50);
1082        // Extreme overestimation (estimated 1 million, actual 1)
1083        store.record_pattern_execution("extreme_over", 1_000_000, 1);
1084        let adjusted = store.get_adjusted_cardinality("extreme_over", 1_000_000);
1085        assert!(adjusted >= 1, "Adjusted cardinality must be at least 1");
1086    }
1087
1088    #[test]
1089    fn test_multiple_patterns_tracked_independently() {
1090        let store = AdaptiveStatsStore::new(50);
1091        store.record_pattern_execution("pat_a", 100, 50);
1092        store.record_pattern_execution("pat_b", 100, 300);
1093
1094        let adj_a = store.get_adjusted_cardinality("pat_a", 100);
1095        let adj_b = store.get_adjusted_cardinality("pat_b", 100);
1096        assert!(
1097            adj_a < adj_b,
1098            "pat_a (undercount) should produce lower estimate than pat_b (overcount)"
1099        );
1100    }
1101
1102    #[test]
1103    fn test_execution_time_recorded_via_snapshot() {
1104        let store = AdaptiveStatsStore::new(50);
1105        store.record_execution_time("component_x", std::time::Duration::from_millis(42));
1106        let snapshot = store.snapshot().unwrap();
1107        assert!(snapshot.execution_times.contains_key("component_x"));
1108        assert_eq!(
1109            snapshot.execution_times["component_x"],
1110            std::time::Duration::from_millis(42)
1111        );
1112    }
1113
1114    #[test]
1115    fn test_join_selectivity_unknown_join_returns_base() {
1116        let store = AdaptiveStatsStore::new(50);
1117        let base = 0.05;
1118        let adj = store.get_adjusted_selectivity("no_such_join", base);
1119        assert!(
1120            (adj - base).abs() < 1e-9,
1121            "Unknown join should return base selectivity unchanged"
1122        );
1123    }
1124
1125    #[test]
1126    fn test_join_selectivity_clamps_to_valid_range() {
1127        let store = AdaptiveStatsStore::new(50);
1128        // Very low actual output => very low observed selectivity
1129        for _ in 0..20 {
1130            store.record_join_execution("tiny_sel", 1_000_000, 1_000_000, 1);
1131        }
1132        let adj = store.get_adjusted_selectivity("tiny_sel", 0.5);
1133        assert!(adj > 0.0, "Selectivity must remain positive");
1134        assert!(adj <= 1.0, "Selectivity must not exceed 1.0");
1135    }
1136
1137    // --- PatternTerm tests ---
1138
1139    #[test]
1140    fn test_pattern_term_iri_is_not_variable() {
1141        let term = PatternTerm::Iri("http://example.org/foo".to_string());
1142        assert!(!term.is_variable());
1143        assert!(term.variable_name().is_none());
1144    }
1145
1146    #[test]
1147    fn test_pattern_term_literal_is_not_variable() {
1148        let term = PatternTerm::Literal("hello".to_string());
1149        assert!(!term.is_variable());
1150        assert!(term.variable_name().is_none());
1151    }
1152
1153    #[test]
1154    fn test_pattern_term_blank_node_is_not_variable() {
1155        let term = PatternTerm::BlankNode("b1".to_string());
1156        assert!(!term.is_variable());
1157        assert!(term.variable_name().is_none());
1158    }
1159
1160    #[test]
1161    fn test_triple_pattern_info_bound_positions_fully_bound() {
1162        let info = p_info(
1163            PatternTerm::Iri("http://s".to_string()),
1164            PatternTerm::Iri("http://p".to_string()),
1165            PatternTerm::Literal("val".to_string()),
1166            10,
1167        );
1168        assert_eq!(info.bound_positions(), 3, "All positions are bound");
1169    }
1170
1171    #[test]
1172    fn test_triple_pattern_info_bound_positions_no_variables() {
1173        let info = p_info(
1174            PatternTerm::Variable("s".to_string()),
1175            PatternTerm::Variable("p".to_string()),
1176            PatternTerm::Variable("o".to_string()),
1177            100,
1178        );
1179        assert_eq!(
1180            info.bound_positions(),
1181            0,
1182            "No positions are bound when all are variables"
1183        );
1184    }
1185
1186    #[test]
1187    fn test_from_triple_pattern_literal_object() {
1188        let pattern = TriplePattern::new(
1189            make_var("s"),
1190            make_iri("http://example.org/p"),
1191            make_iri("http://example.org/o"),
1192        );
1193        let info = TriplePatternInfo::from_triple_pattern(&pattern, 42);
1194        assert_eq!(info.estimated_cardinality, 42);
1195        // subject is variable
1196        assert!(info.bound_variables.contains(&"s".to_string()));
1197    }
1198
1199    // --- JoinPlanNode tests ---
1200
1201    #[test]
1202    fn test_join_plan_node_hash_join_estimated_cardinality() {
1203        let left = JoinPlanNode::TriplePatternScan {
1204            info: p_info(
1205                PatternTerm::Variable("s".to_string()),
1206                PatternTerm::Iri("http://p".to_string()),
1207                PatternTerm::Variable("o".to_string()),
1208                100,
1209            ),
1210        };
1211        let right = JoinPlanNode::TriplePatternScan {
1212            info: p_info(
1213                PatternTerm::Variable("s".to_string()),
1214                PatternTerm::Iri("http://q".to_string()),
1215                PatternTerm::Variable("x".to_string()),
1216                200,
1217            ),
1218        };
1219        let node = JoinPlanNode::HashJoin {
1220            left: Box::new(left),
1221            right: Box::new(right),
1222            join_vars: vec!["s".to_string()],
1223            estimated_output: 50,
1224        };
1225        assert_eq!(node.estimated_cardinality(), 50);
1226    }
1227
1228    #[test]
1229    fn test_join_plan_nested_loop_output_variables() {
1230        let outer = JoinPlanNode::TriplePatternScan {
1231            info: p_info(
1232                PatternTerm::Variable("s".to_string()),
1233                PatternTerm::Iri("http://p".to_string()),
1234                PatternTerm::Variable("o".to_string()),
1235                100,
1236            ),
1237        };
1238        let inner = JoinPlanNode::TriplePatternScan {
1239            info: p_info(
1240                PatternTerm::Variable("o".to_string()),
1241                PatternTerm::Iri("http://q".to_string()),
1242                PatternTerm::Variable("z".to_string()),
1243                50,
1244            ),
1245        };
1246        let node = JoinPlanNode::NestedLoopJoin {
1247            outer: Box::new(outer),
1248            inner: Box::new(inner),
1249            join_vars: vec!["o".to_string()],
1250            estimated_output: 30,
1251        };
1252        let vars = node.output_variables();
1253        assert!(vars.contains(&"s".to_string()), "Should contain s");
1254        assert!(vars.contains(&"o".to_string()), "Should contain o");
1255        assert!(vars.contains(&"z".to_string()), "Should contain z");
1256    }
1257
1258    // --- AdaptiveJoinOrderOptimizer tests ---
1259
1260    #[test]
1261    fn test_optimizer_selects_lower_cardinality_pattern_first() {
1262        let store = Arc::new(AdaptiveStatsStore::new(50));
1263        let optimizer = AdaptiveJoinOrderOptimizer::new(Arc::clone(&store));
1264
1265        let patterns = vec![
1266            p_info(
1267                PatternTerm::Variable("s".to_string()),
1268                PatternTerm::Iri("http://rare".to_string()),
1269                PatternTerm::Variable("o1".to_string()),
1270                5, // low cardinality - should be leaf
1271            ),
1272            p_info(
1273                PatternTerm::Variable("s".to_string()),
1274                PatternTerm::Iri("http://common".to_string()),
1275                PatternTerm::Variable("o2".to_string()),
1276                50_000, // high cardinality
1277            ),
1278        ];
1279
1280        let plan = optimizer.optimize(patterns).unwrap();
1281        // The lower-cardinality side should be in the plan somewhere
1282        // We verify the plan contains a join (not just a scan)
1283        assert!(
1284            matches!(
1285                plan,
1286                JoinPlanNode::HashJoin { .. }
1287                    | JoinPlanNode::NestedLoopJoin { .. }
1288                    | JoinPlanNode::MergeJoin { .. }
1289            ),
1290            "Two patterns should produce a join plan"
1291        );
1292    }
1293
1294    #[test]
1295    fn test_optimizer_dp_threshold_boundary() {
1296        // At exactly the DP threshold, optimizer uses DP
1297        let store = Arc::new(AdaptiveStatsStore::new(50));
1298        let optimizer = AdaptiveJoinOrderOptimizer::new(Arc::clone(&store)).with_dp_threshold(4);
1299
1300        let patterns: Vec<TriplePatternInfo> = (0..4)
1301            .map(|i| {
1302                p_info(
1303                    PatternTerm::Variable(format!("s{i}")),
1304                    PatternTerm::Iri(format!("http://p{i}")),
1305                    PatternTerm::Variable(format!("o{i}")),
1306                    (i + 1) as u64 * 50,
1307                )
1308            })
1309            .collect();
1310
1311        let result = optimizer.optimize(patterns);
1312        assert!(
1313            result.is_ok(),
1314            "DP optimization at threshold should succeed"
1315        );
1316    }
1317
1318    #[test]
1319    fn test_optimizer_uses_runtime_feedback_for_ordering() {
1320        let store = Arc::new(AdaptiveStatsStore::new(50));
1321        // Record that "pat_heavy" actually has far more rows than estimated
1322        store.record_pattern_execution("? http://heavy ?", 10, 100_000);
1323
1324        let optimizer = AdaptiveJoinOrderOptimizer::new(Arc::clone(&store));
1325        let patterns = vec![
1326            p_info(
1327                PatternTerm::Variable("s".to_string()),
1328                PatternTerm::Iri("http://heavy".to_string()),
1329                PatternTerm::Variable("o".to_string()),
1330                10, // low estimate, but runtime says 100_000
1331            ),
1332            p_info(
1333                PatternTerm::Variable("s".to_string()),
1334                PatternTerm::Iri("http://light".to_string()),
1335                PatternTerm::Variable("x".to_string()),
1336                500,
1337            ),
1338        ];
1339        let result = optimizer.optimize(patterns);
1340        assert!(
1341            result.is_ok(),
1342            "Optimizer should succeed with runtime feedback"
1343        );
1344    }
1345}