ipfrs_tensorlogic/
optimizer.rs

//! Query optimization for TensorLogic
//!
//! Optimizes logical queries by:
//! - Reordering predicates in rule bodies for better performance
//! - Selecting join orders by estimated selectivity
//! - Estimating predicate selectivity
//! - Cost-based query planning
//! - Cardinality estimation
//! - Statistics tracking
//! - Materialized views for common queries
11
12use crate::ir::{KnowledgeBase, Predicate, Rule, Term};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::time::{Duration, SystemTime};
16
17/// Statistics for a single predicate
18#[derive(Debug, Clone, Serialize, Deserialize, Default)]
19pub struct PredicateStats {
20    /// Number of facts for this predicate
21    pub fact_count: usize,
22    /// Number of rules with this predicate as head
23    pub rule_count: usize,
24    /// Average arity (number of arguments)
25    pub avg_arity: f64,
26    /// Estimated cardinality after filtering
27    pub estimated_cardinality: f64,
28    /// Selectivity (0.0 = highly selective, 1.0 = not selective)
29    pub selectivity: f64,
30}
31
32impl PredicateStats {
33    /// Create new stats
34    pub fn new(fact_count: usize, rule_count: usize, avg_arity: f64) -> Self {
35        Self {
36            fact_count,
37            rule_count,
38            avg_arity,
39            estimated_cardinality: fact_count as f64,
40            selectivity: 1.0,
41        }
42    }
43
44    /// Compute selectivity based on total facts
45    #[inline]
46    pub fn compute_selectivity(&mut self, total_facts: usize) {
47        if total_facts == 0 {
48            self.selectivity = 1.0;
49        } else {
50            self.selectivity = self.fact_count as f64 / total_facts as f64;
51        }
52    }
53}
54
55/// Query plan node representing a single operation
56#[derive(Debug, Clone)]
57pub enum PlanNode {
58    /// Scan a predicate (fact lookup)
59    Scan {
60        predicate: String,
61        bound_vars: Vec<String>,
62        estimated_rows: f64,
63    },
64    /// Join two plans
65    Join {
66        left: Box<PlanNode>,
67        right: Box<PlanNode>,
68        join_vars: Vec<String>,
69        estimated_rows: f64,
70    },
71    /// Filter results
72    Filter {
73        input: Box<PlanNode>,
74        condition: Predicate,
75        estimated_rows: f64,
76    },
77}
78
79impl PlanNode {
80    /// Get estimated row count
81    #[inline]
82    pub fn estimated_rows(&self) -> f64 {
83        match self {
84            PlanNode::Scan { estimated_rows, .. } => *estimated_rows,
85            PlanNode::Join { estimated_rows, .. } => *estimated_rows,
86            PlanNode::Filter { estimated_rows, .. } => *estimated_rows,
87        }
88    }
89
90    /// Compute cost of this plan
91    pub fn cost(&self) -> f64 {
92        match self {
93            PlanNode::Scan { estimated_rows, .. } => *estimated_rows,
94            PlanNode::Join {
95                left,
96                right,
97                estimated_rows,
98                ..
99            } => left.cost() + right.cost() + *estimated_rows,
100            PlanNode::Filter {
101                input,
102                estimated_rows,
103                ..
104            } => input.cost() + *estimated_rows * 0.1,
105        }
106    }
107}
108
109/// Query plan for a goal
110#[derive(Debug, Clone)]
111pub struct QueryPlan {
112    /// Root of the plan tree
113    pub root: PlanNode,
114    /// Total estimated cost
115    pub estimated_cost: f64,
116    /// Estimated result cardinality
117    pub estimated_rows: f64,
118    /// Variables that will be bound
119    pub output_vars: Vec<String>,
120}
121
122impl QueryPlan {
123    /// Create a new query plan
124    pub fn new(root: PlanNode) -> Self {
125        let estimated_cost = root.cost();
126        let estimated_rows = root.estimated_rows();
127        Self {
128            root,
129            estimated_cost,
130            estimated_rows,
131            output_vars: Vec::new(),
132        }
133    }
134
135    /// Create with output variables
136    pub fn with_vars(root: PlanNode, output_vars: Vec<String>) -> Self {
137        let estimated_cost = root.cost();
138        let estimated_rows = root.estimated_rows();
139        Self {
140            root,
141            estimated_cost,
142            estimated_rows,
143            output_vars,
144        }
145    }
146}
147
148/// Query optimizer for TensorLogic
149pub struct QueryOptimizer {
150    /// Statistics about predicates
151    predicate_stats: HashMap<String, PredicateStats>,
152    /// Total facts in knowledge base
153    total_facts: usize,
154    /// Selectivity cache (for backwards compatibility)
155    selectivity_cache: HashMap<String, f64>,
156}
157
158impl QueryOptimizer {
159    /// Create a new query optimizer
160    #[inline]
161    pub fn new() -> Self {
162        Self {
163            predicate_stats: HashMap::new(),
164            total_facts: 0,
165            selectivity_cache: HashMap::new(),
166        }
167    }
168
169    /// Create a query plan for a conjunction of goals
170    pub fn plan_query(&self, goals: &[Predicate], kb: &KnowledgeBase) -> QueryPlan {
171        if goals.is_empty() {
172            return QueryPlan::new(PlanNode::Scan {
173                predicate: "empty".to_string(),
174                bound_vars: Vec::new(),
175                estimated_rows: 0.0,
176            });
177        }
178
179        if goals.len() == 1 {
180            return self.plan_single_goal(&goals[0], kb);
181        }
182
183        // Order goals by selectivity
184        let ordered = self.optimize_goal(goals.to_vec(), kb);
185
186        // Build join plan
187        let mut current_plan = self.plan_single_goal(&ordered[0], kb);
188
189        for goal in ordered.iter().skip(1) {
190            let right_plan = self.plan_single_goal(goal, kb);
191
192            // Find join variables
193            let join_vars = self.find_join_vars(&current_plan, &right_plan, goal);
194
195            // Estimate join cardinality
196            let estimated_rows = self.estimate_join_cardinality(
197                current_plan.estimated_rows,
198                right_plan.estimated_rows,
199                &join_vars,
200            );
201
202            current_plan = QueryPlan::new(PlanNode::Join {
203                left: Box::new(current_plan.root),
204                right: Box::new(right_plan.root),
205                join_vars,
206                estimated_rows,
207            });
208        }
209
210        current_plan
211    }
212
213    /// Plan a single goal
214    fn plan_single_goal(&self, goal: &Predicate, kb: &KnowledgeBase) -> QueryPlan {
215        let fact_count = kb.get_predicates(&goal.name).len();
216        let groundness = self.compute_groundness(goal);
217
218        // Estimate rows based on fact count and groundness
219        let estimated_rows = if groundness >= 1.0 {
220            1.0 // Fully ground query returns at most 1 result
221        } else {
222            fact_count as f64 * (1.0 - groundness + 0.1)
223        };
224
225        let bound_vars: Vec<String> = goal
226            .args
227            .iter()
228            .filter_map(|t| {
229                if let Term::Var(v) = t {
230                    Some(v.clone())
231                } else {
232                    None
233                }
234            })
235            .collect();
236
237        QueryPlan::with_vars(
238            PlanNode::Scan {
239                predicate: goal.name.clone(),
240                bound_vars: bound_vars.clone(),
241                estimated_rows,
242            },
243            bound_vars,
244        )
245    }
246
247    /// Find variables that join two plans
248    fn find_join_vars(
249        &self,
250        left: &QueryPlan,
251        _right: &QueryPlan,
252        right_goal: &Predicate,
253    ) -> Vec<String> {
254        let mut join_vars = Vec::new();
255        for var in &left.output_vars {
256            for arg in &right_goal.args {
257                if let Term::Var(v) = arg {
258                    if v == var {
259                        join_vars.push(var.clone());
260                    }
261                }
262            }
263        }
264        join_vars
265    }
266
267    /// Estimate cardinality of a join
268    fn estimate_join_cardinality(
269        &self,
270        left_rows: f64,
271        right_rows: f64,
272        join_vars: &[String],
273    ) -> f64 {
274        if join_vars.is_empty() {
275            // Cross product
276            left_rows * right_rows
277        } else {
278            // Estimated selectivity based on join
279            let selectivity = 0.1_f64.powi(join_vars.len() as i32);
280            (left_rows * right_rows * selectivity).max(1.0)
281        }
282    }
283
284    /// Get predicate statistics
285    #[inline]
286    pub fn get_stats(&self, predicate_name: &str) -> Option<&PredicateStats> {
287        self.predicate_stats.get(predicate_name)
288    }
289
290    /// Get all statistics
291    #[inline]
292    pub fn all_stats(&self) -> &HashMap<String, PredicateStats> {
293        &self.predicate_stats
294    }
295
296    /// Estimate cardinality for a predicate
297    pub fn estimate_cardinality(&self, predicate: &Predicate, kb: &KnowledgeBase) -> f64 {
298        let fact_count = kb.get_predicates(&predicate.name).len() as f64;
299        let groundness = self.compute_groundness(predicate);
300
301        // More ground args = lower cardinality
302        fact_count * (1.0 - groundness + 0.1)
303    }
304
305    /// Optimize a rule by reordering its body predicates
306    ///
307    /// Reorders predicates to put more selective ones first,
308    /// reducing the intermediate result set size
309    pub fn optimize_rule(&self, rule: &Rule, kb: &KnowledgeBase) -> Rule {
310        if rule.body.is_empty() {
311            return rule.clone();
312        }
313
314        let body = rule.body.clone();
315
316        // Compute selectivity scores for each predicate
317        let mut scores: Vec<(usize, f64)> = body
318            .iter()
319            .enumerate()
320            .map(|(i, pred)| (i, self.estimate_selectivity(pred, kb)))
321            .collect();
322
323        // Sort by selectivity (most selective first)
324        scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
325
326        // Reorder body predicates
327        let optimized_body: Vec<Predicate> = scores.iter().map(|(i, _)| body[*i].clone()).collect();
328
329        Rule::new(rule.head.clone(), optimized_body)
330    }
331
332    /// Estimate selectivity of a predicate
333    ///
334    /// Returns a value where LOWER is more selective (should execute first)
335    /// Higher values indicate less selective predicates
336    fn estimate_selectivity(&self, predicate: &Predicate, kb: &KnowledgeBase) -> f64 {
337        // Check cache first
338        if let Some(&selectivity) = self.selectivity_cache.get(&predicate.name) {
339            return selectivity;
340        }
341
342        // Count facts with this predicate name
343        let fact_count = kb.get_predicates(&predicate.name).len();
344
345        // Estimate based on groundness and fact count
346        let groundness = self.compute_groundness(predicate);
347
348        // More facts = less selective (higher score)
349        // More ground = more selective (lower score)
350        let fact_factor = if fact_count == 0 {
351            100.0 // Unknown predicates assumed least selective
352        } else {
353            fact_count as f64
354        };
355
356        // Combine: less ground + more facts = higher (less selective) score
357        fact_factor * (1.0 - groundness + 0.1)
358    }
359
360    /// Compute how "ground" a predicate is (0.0 = all variables, 1.0 = all constants)
361    #[inline]
362    fn compute_groundness(&self, predicate: &Predicate) -> f64 {
363        if predicate.args.is_empty() {
364            return 1.0;
365        }
366
367        let ground_count = predicate.args.iter().filter(|t| t.is_ground()).count();
368        ground_count as f64 / predicate.args.len() as f64
369    }
370
371    /// Update selectivity statistics from a knowledge base
372    pub fn update_statistics(&mut self, kb: &KnowledgeBase) {
373        // Clear old stats
374        self.selectivity_cache.clear();
375        self.predicate_stats.clear();
376        self.total_facts = kb.facts.len();
377
378        // Count facts by predicate name
379        let mut fact_counts: HashMap<String, usize> = HashMap::new();
380        let mut arity_sums: HashMap<String, usize> = HashMap::new();
381
382        for fact in &kb.facts {
383            *fact_counts.entry(fact.name.clone()).or_insert(0) += 1;
384            *arity_sums.entry(fact.name.clone()).or_insert(0) += fact.args.len();
385        }
386
387        // Count rules by head predicate
388        let mut rule_counts: HashMap<String, usize> = HashMap::new();
389        for rule in &kb.rules {
390            *rule_counts.entry(rule.head.name.clone()).or_insert(0) += 1;
391        }
392
393        let total_facts = kb.facts.len() as f64;
394        if total_facts == 0.0 {
395            return;
396        }
397
398        // Build predicate stats
399        let all_predicates: std::collections::HashSet<_> = fact_counts
400            .keys()
401            .chain(rule_counts.keys())
402            .cloned()
403            .collect();
404
405        for name in all_predicates {
406            let fact_count = *fact_counts.get(&name).unwrap_or(&0);
407            let rule_count = *rule_counts.get(&name).unwrap_or(&0);
408            let arity_sum = *arity_sums.get(&name).unwrap_or(&0);
409            let avg_arity = if fact_count > 0 {
410                arity_sum as f64 / fact_count as f64
411            } else {
412                0.0
413            };
414
415            let mut stats = PredicateStats::new(fact_count, rule_count, avg_arity);
416            stats.compute_selectivity(self.total_facts);
417
418            // Also update selectivity cache for backwards compatibility
419            self.selectivity_cache
420                .insert(name.clone(), stats.selectivity);
421            self.predicate_stats.insert(name, stats);
422        }
423    }
424
425    /// Get total fact count
426    #[inline]
427    pub fn total_facts(&self) -> usize {
428        self.total_facts
429    }
430
431    /// Optimize a query goal
432    ///
433    /// For complex goals with multiple predicates, reorder them optimally
434    pub fn optimize_goal(&self, goals: Vec<Predicate>, kb: &KnowledgeBase) -> Vec<Predicate> {
435        if goals.len() <= 1 {
436            return goals;
437        }
438
439        let mut scored: Vec<(Predicate, f64)> = goals
440            .into_iter()
441            .map(|p| {
442                let score = self.estimate_selectivity(&p, kb);
443                (p, score)
444            })
445            .collect();
446
447        // Sort by selectivity (most selective first)
448        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
449
450        scored.into_iter().map(|(p, _)| p).collect()
451    }
452
453    /// Get optimization recommendations for a knowledge base
454    pub fn get_recommendations(&self, kb: &KnowledgeBase) -> Vec<OptimizationRecommendation> {
455        let mut recommendations = Vec::new();
456
457        for rule in &kb.rules {
458            if rule.body.len() > 1 {
459                let optimized = self.optimize_rule(rule, kb);
460
461                // Check if order changed
462                let changed = rule
463                    .body
464                    .iter()
465                    .zip(optimized.body.iter())
466                    .any(|(a, b)| a.name != b.name);
467
468                if changed {
469                    recommendations.push(OptimizationRecommendation {
470                        rule_head: rule.head.name.clone(),
471                        original_order: rule.body.iter().map(|p| p.name.clone()).collect(),
472                        optimized_order: optimized.body.iter().map(|p| p.name.clone()).collect(),
473                        estimated_improvement: 0.5, // Simplified estimate
474                    });
475                }
476            }
477        }
478
479        recommendations
480    }
481}
482
483impl Default for QueryOptimizer {
484    fn default() -> Self {
485        Self::new()
486    }
487}
488
/// Suggested reordering of one rule's body predicates.
#[derive(Debug, Clone)]
pub struct OptimizationRecommendation {
    /// Predicate name of the rule's head.
    pub rule_head: String,
    /// Body predicate names in their current order.
    pub original_order: Vec<String>,
    /// Body predicate names in the suggested order.
    pub optimized_order: Vec<String>,
    /// Estimated improvement factor (higher is better).
    pub estimated_improvement: f64,
}
501
502/// Materialized view metadata
503#[derive(Debug, Clone, Serialize, Deserialize)]
504pub struct MaterializedView {
505    /// Unique view name
506    pub name: String,
507    /// Query pattern that defines this view
508    pub query: Vec<Predicate>,
509    /// Precomputed results
510    pub results: Vec<Vec<Term>>,
511    /// Time when the view was created/refreshed
512    pub last_refresh: SystemTime,
513    /// Time-to-live before refresh needed
514    pub ttl: Option<Duration>,
515    /// Statistics about view usage
516    pub access_count: usize,
517    /// Total cost saved by using this view
518    pub total_cost_saved: f64,
519}
520
521impl MaterializedView {
522    /// Create a new materialized view
523    pub fn new(name: String, query: Vec<Predicate>) -> Self {
524        Self {
525            name,
526            query,
527            results: Vec::new(),
528            last_refresh: SystemTime::now(),
529            ttl: None,
530            access_count: 0,
531            total_cost_saved: 0.0,
532        }
533    }
534
535    /// Create with TTL
536    pub fn with_ttl(name: String, query: Vec<Predicate>, ttl: Duration) -> Self {
537        Self {
538            name,
539            query,
540            results: Vec::new(),
541            last_refresh: SystemTime::now(),
542            ttl: Some(ttl),
543            access_count: 0,
544            total_cost_saved: 0.0,
545        }
546    }
547
548    /// Check if view needs refresh based on TTL
549    pub fn needs_refresh(&self) -> bool {
550        if let Some(ttl) = self.ttl {
551            if let Ok(elapsed) = self.last_refresh.elapsed() {
552                return elapsed > ttl;
553            }
554        }
555        false
556    }
557
558    /// Refresh the view with new results
559    pub fn refresh(&mut self, results: Vec<Vec<Term>>) {
560        self.results = results;
561        self.last_refresh = SystemTime::now();
562    }
563
564    /// Record a view access
565    #[inline]
566    pub fn record_access(&mut self, cost_saved: f64) {
567        self.access_count += 1;
568        self.total_cost_saved += cost_saved;
569    }
570
571    /// Check if query matches this view
572    pub fn matches_query(&self, query: &[Predicate]) -> bool {
573        if self.query.len() != query.len() {
574            return false;
575        }
576
577        self.query
578            .iter()
579            .zip(query.iter())
580            .all(|(a, b)| a.name == b.name && a.args.len() == b.args.len())
581    }
582}
583
584/// Materialized view manager
585pub struct MaterializedViewManager {
586    /// All materialized views
587    views: HashMap<String, MaterializedView>,
588    /// Max number of views to maintain
589    max_views: usize,
590    /// Minimum access count to keep a view
591    min_access_threshold: usize,
592}
593
594impl MaterializedViewManager {
595    /// Create a new view manager
596    pub fn new(max_views: usize) -> Self {
597        Self {
598            views: HashMap::new(),
599            max_views,
600            min_access_threshold: 5,
601        }
602    }
603
604    /// Create a materialized view
605    pub fn create_view(
606        &mut self,
607        name: String,
608        query: Vec<Predicate>,
609        ttl: Option<Duration>,
610    ) -> Result<(), String> {
611        if self.views.contains_key(&name) {
612            return Err(format!("View '{}' already exists", name));
613        }
614
615        // Enforce max views limit
616        if self.views.len() >= self.max_views {
617            self.evict_least_useful_view();
618        }
619
620        let view = if let Some(ttl) = ttl {
621            MaterializedView::with_ttl(name.clone(), query, ttl)
622        } else {
623            MaterializedView::new(name.clone(), query)
624        };
625
626        self.views.insert(name, view);
627        Ok(())
628    }
629
630    /// Drop a materialized view
631    pub fn drop_view(&mut self, name: &str) -> Result<(), String> {
632        if self.views.remove(name).is_none() {
633            return Err(format!("View '{}' does not exist", name));
634        }
635        Ok(())
636    }
637
638    /// Refresh a view with new results
639    pub fn refresh_view(&mut self, name: &str, results: Vec<Vec<Term>>) -> Result<(), String> {
640        let view = self
641            .views
642            .get_mut(name)
643            .ok_or_else(|| format!("View '{}' does not exist", name))?;
644        view.refresh(results);
645        Ok(())
646    }
647
648    /// Find a view that matches the query
649    pub fn find_matching_view(&mut self, query: &[Predicate]) -> Option<&mut MaterializedView> {
650        self.views
651            .values_mut()
652            .find(|view| view.matches_query(query))
653    }
654
655    /// Get a view by name
656    #[inline]
657    pub fn get_view(&self, name: &str) -> Option<&MaterializedView> {
658        self.views.get(name)
659    }
660
661    /// Get a mutable view by name
662    #[inline]
663    pub fn get_view_mut(&mut self, name: &str) -> Option<&mut MaterializedView> {
664        self.views.get_mut(name)
665    }
666
667    /// Get all views
668    #[inline]
669    pub fn all_views(&self) -> &HashMap<String, MaterializedView> {
670        &self.views
671    }
672
673    /// Evict the least useful view
674    fn evict_least_useful_view(&mut self) {
675        if self.views.is_empty() {
676            return;
677        }
678
679        // Find view with lowest utility score
680        let mut min_score = f64::INFINITY;
681        let mut evict_name: Option<String> = None;
682
683        for (name, view) in &self.views {
684            // Utility score: cost saved per access
685            let score = if view.access_count > 0 {
686                view.total_cost_saved / view.access_count as f64
687            } else {
688                0.0
689            };
690
691            if score < min_score {
692                min_score = score;
693                evict_name = Some(name.clone());
694            }
695        }
696
697        if let Some(name) = evict_name {
698            self.views.remove(&name);
699        }
700    }
701
702    /// Clean up stale views (based on TTL and access count)
703    pub fn cleanup_stale_views(&mut self) {
704        let to_remove: Vec<String> = self
705            .views
706            .iter()
707            .filter(|(_, view)| {
708                view.needs_refresh() || view.access_count < self.min_access_threshold
709            })
710            .map(|(name, _)| name.clone())
711            .collect();
712
713        for name in to_remove {
714            self.views.remove(&name);
715        }
716    }
717
718    /// Get view usage statistics
719    pub fn get_statistics(&self) -> ViewStatistics {
720        let total_views = self.views.len();
721        let total_accesses: usize = self.views.values().map(|v| v.access_count).sum();
722        let total_cost_saved: f64 = self.views.values().map(|v| v.total_cost_saved).sum();
723
724        let avg_access_count = if total_views > 0 {
725            total_accesses as f64 / total_views as f64
726        } else {
727            0.0
728        };
729
730        ViewStatistics {
731            total_views,
732            total_accesses,
733            total_cost_saved,
734            avg_access_count,
735        }
736    }
737
738    /// Set minimum access threshold for view retention
739    #[inline]
740    pub fn set_min_access_threshold(&mut self, threshold: usize) {
741        self.min_access_threshold = threshold;
742    }
743}
744
745impl Default for MaterializedViewManager {
746    fn default() -> Self {
747        Self::new(100)
748    }
749}
750
/// Aggregate usage figures across all managed materialized views.
#[derive(Debug, Clone)]
pub struct ViewStatistics {
    /// Number of views currently held.
    pub total_views: usize,
    /// Sum of access counts over all views.
    pub total_accesses: usize,
    /// Sum of recorded cost savings over all views.
    pub total_cost_saved: f64,
    /// `total_accesses / total_views`, or `0.0` when there are no views.
    pub avg_access_count: f64,
}
763
764#[cfg(test)]
765mod tests {
766    use super::*;
767    use crate::ir::{Constant, Term};
768
769    #[test]
770    fn test_groundness() {
771        let optimizer = QueryOptimizer::new();
772
773        // All constants
774        let pred1 = Predicate::new(
775            "test".to_string(),
776            vec![
777                Term::Const(Constant::String("a".to_string())),
778                Term::Const(Constant::String("b".to_string())),
779            ],
780        );
781        assert_eq!(optimizer.compute_groundness(&pred1), 1.0);
782
783        // All variables
784        let pred2 = Predicate::new(
785            "test".to_string(),
786            vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
787        );
788        assert_eq!(optimizer.compute_groundness(&pred2), 0.0);
789
790        // Mixed
791        let pred3 = Predicate::new(
792            "test".to_string(),
793            vec![
794                Term::Const(Constant::String("a".to_string())),
795                Term::Var("Y".to_string()),
796            ],
797        );
798        assert_eq!(optimizer.compute_groundness(&pred3), 0.5);
799    }
800
801    #[test]
802    fn test_optimize_rule() {
803        let optimizer = QueryOptimizer::new();
804        let mut kb = KnowledgeBase::new();
805
806        // Add some facts to influence selectivity
807        kb.add_fact(Predicate::new(
808            "rare".to_string(),
809            vec![
810                Term::Const(Constant::String("a".to_string())),
811                Term::Const(Constant::String("b".to_string())),
812            ],
813        ));
814
815        for i in 0..100 {
816            kb.add_fact(Predicate::new(
817                "common".to_string(),
818                vec![
819                    Term::Const(Constant::Int(i)),
820                    Term::Const(Constant::Int(i + 1)),
821                ],
822            ));
823        }
824
825        // Rule with predicates in suboptimal order
826        let rule = Rule::new(
827            Predicate::new("result".to_string(), vec![Term::Var("X".to_string())]),
828            vec![
829                Predicate::new(
830                    "common".to_string(),
831                    vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
832                ),
833                Predicate::new(
834                    "rare".to_string(),
835                    vec![Term::Var("Y".to_string()), Term::Var("Z".to_string())],
836                ),
837            ],
838        );
839
840        let optimized = optimizer.optimize_rule(&rule, &kb);
841
842        // The optimizer should put 'rare' before 'common' since it's more selective
843        assert_eq!(optimized.body[0].name, "rare");
844        assert_eq!(optimized.body[1].name, "common");
845    }
846
847    #[test]
848    fn test_update_statistics() {
849        let mut optimizer = QueryOptimizer::new();
850        let mut kb = KnowledgeBase::new();
851
852        // Add facts
853        for i in 0..10 {
854            kb.add_fact(Predicate::new(
855                "parent".to_string(),
856                vec![
857                    Term::Const(Constant::Int(i)),
858                    Term::Const(Constant::Int(i + 1)),
859                ],
860            ));
861        }
862
863        for i in 0..5 {
864            kb.add_fact(Predicate::new(
865                "child".to_string(),
866                vec![
867                    Term::Const(Constant::Int(i)),
868                    Term::Const(Constant::Int(i + 1)),
869                ],
870            ));
871        }
872
873        optimizer.update_statistics(&kb);
874
875        // Check stats
876        assert_eq!(optimizer.total_facts(), 15);
877
878        let parent_stats = optimizer.get_stats("parent").unwrap();
879        assert_eq!(parent_stats.fact_count, 10);
880        assert!((parent_stats.selectivity - (10.0 / 15.0)).abs() < 0.001);
881
882        let child_stats = optimizer.get_stats("child").unwrap();
883        assert_eq!(child_stats.fact_count, 5);
884        assert!((child_stats.selectivity - (5.0 / 15.0)).abs() < 0.001);
885    }
886
887    #[test]
888    fn test_query_plan_single() {
889        let optimizer = QueryOptimizer::new();
890        let mut kb = KnowledgeBase::new();
891
892        for i in 0..100 {
893            kb.add_fact(Predicate::new(
894                "test".to_string(),
895                vec![
896                    Term::Const(Constant::Int(i)),
897                    Term::Const(Constant::Int(i * 2)),
898                ],
899            ));
900        }
901
902        let goal = Predicate::new(
903            "test".to_string(),
904            vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
905        );
906
907        let plan = optimizer.plan_query(&[goal], &kb);
908
909        // Should be a scan node
910        matches!(plan.root, PlanNode::Scan { .. });
911        assert!(plan.estimated_rows > 0.0);
912    }
913
914    #[test]
915    fn test_query_plan_join() {
916        let optimizer = QueryOptimizer::new();
917        let mut kb = KnowledgeBase::new();
918
919        for i in 0..10 {
920            kb.add_fact(Predicate::new(
921                "parent".to_string(),
922                vec![
923                    Term::Const(Constant::String(format!("p{}", i))),
924                    Term::Const(Constant::String(format!("c{}", i))),
925                ],
926            ));
927            kb.add_fact(Predicate::new(
928                "likes".to_string(),
929                vec![
930                    Term::Const(Constant::String(format!("c{}", i))),
931                    Term::Const(Constant::String("pizza".to_string())),
932                ],
933            ));
934        }
935
936        let goals = vec![
937            Predicate::new(
938                "parent".to_string(),
939                vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
940            ),
941            Predicate::new(
942                "likes".to_string(),
943                vec![
944                    Term::Var("Y".to_string()),
945                    Term::Const(Constant::String("pizza".to_string())),
946                ],
947            ),
948        ];
949
950        let plan = optimizer.plan_query(&goals, &kb);
951
952        // Should have a join node
953        assert!(plan.estimated_cost > 0.0);
954    }
955
956    #[test]
957    fn test_predicate_stats() {
958        let mut stats = PredicateStats::new(100, 5, 2.5);
959        assert_eq!(stats.fact_count, 100);
960        assert_eq!(stats.rule_count, 5);
961        assert!((stats.avg_arity - 2.5).abs() < 0.001);
962
963        stats.compute_selectivity(1000);
964        assert!((stats.selectivity - 0.1).abs() < 0.001);
965    }
966
967    #[test]
968    fn test_plan_node_cost() {
969        let scan = PlanNode::Scan {
970            predicate: "test".to_string(),
971            bound_vars: vec!["X".to_string()],
972            estimated_rows: 100.0,
973        };
974
975        assert!((scan.cost() - 100.0).abs() < 0.001);
976        assert!((scan.estimated_rows() - 100.0).abs() < 0.001);
977
978        let join = PlanNode::Join {
979            left: Box::new(scan.clone()),
980            right: Box::new(PlanNode::Scan {
981                predicate: "other".to_string(),
982                bound_vars: vec!["Y".to_string()],
983                estimated_rows: 50.0,
984            }),
985            join_vars: vec!["X".to_string()],
986            estimated_rows: 10.0,
987        };
988
989        // Cost should include both scans plus join result
990        assert!(join.cost() > 150.0);
991    }
992
993    #[test]
994    fn test_materialized_view_basic() {
995        let query = vec![Predicate::new(
996            "parent".to_string(),
997            vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
998        )];
999
1000        let mut view = MaterializedView::new("parent_view".to_string(), query.clone());
1001        assert_eq!(view.name, "parent_view");
1002        assert_eq!(view.query.len(), 1);
1003        assert_eq!(view.results.len(), 0);
1004        assert_eq!(view.access_count, 0);
1005
1006        // Add results
1007        let results = vec![
1008            vec![
1009                Term::Const(Constant::String("alice".to_string())),
1010                Term::Const(Constant::String("bob".to_string())),
1011            ],
1012            vec![
1013                Term::Const(Constant::String("bob".to_string())),
1014                Term::Const(Constant::String("charlie".to_string())),
1015            ],
1016        ];
1017        view.refresh(results.clone());
1018        assert_eq!(view.results.len(), 2);
1019
1020        // Record access
1021        view.record_access(10.0);
1022        assert_eq!(view.access_count, 1);
1023        assert!((view.total_cost_saved - 10.0).abs() < 0.001);
1024    }
1025
1026    #[test]
1027    fn test_materialized_view_ttl() {
1028        use std::thread;
1029
1030        let query = vec![Predicate::new(
1031            "test".to_string(),
1032            vec![Term::Var("X".to_string())],
1033        )];
1034
1035        let ttl = Duration::from_millis(10);
1036        let view = MaterializedView::with_ttl("test_view".to_string(), query, ttl);
1037
1038        assert!(!view.needs_refresh());
1039
1040        // Wait for TTL to expire
1041        thread::sleep(Duration::from_millis(20));
1042        assert!(view.needs_refresh());
1043    }
1044
1045    #[test]
1046    fn test_materialized_view_matches_query() {
1047        let query1 = vec![
1048            Predicate::new(
1049                "parent".to_string(),
1050                vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
1051            ),
1052            Predicate::new(
1053                "likes".to_string(),
1054                vec![Term::Var("Y".to_string()), Term::Var("Z".to_string())],
1055            ),
1056        ];
1057
1058        let view = MaterializedView::new("view1".to_string(), query1.clone());
1059
1060        // Same query should match
1061        assert!(view.matches_query(&query1));
1062
1063        // Different query should not match
1064        let query2 = vec![Predicate::new(
1065            "parent".to_string(),
1066            vec![Term::Var("A".to_string()), Term::Var("B".to_string())],
1067        )];
1068        assert!(!view.matches_query(&query2));
1069    }
1070
1071    #[test]
1072    fn test_view_manager_create_drop() {
1073        let mut manager = MaterializedViewManager::new(10);
1074
1075        let query = vec![Predicate::new(
1076            "test".to_string(),
1077            vec![Term::Var("X".to_string())],
1078        )];
1079
1080        // Create view
1081        assert!(manager
1082            .create_view("view1".to_string(), query.clone(), None)
1083            .is_ok());
1084        assert_eq!(manager.all_views().len(), 1);
1085
1086        // Create duplicate should fail
1087        assert!(manager
1088            .create_view("view1".to_string(), query, None)
1089            .is_err());
1090
1091        // Drop view
1092        assert!(manager.drop_view("view1").is_ok());
1093        assert_eq!(manager.all_views().len(), 0);
1094
1095        // Drop non-existent view should fail
1096        assert!(manager.drop_view("view1").is_err());
1097    }
1098
1099    #[test]
1100    fn test_view_manager_refresh() {
1101        let mut manager = MaterializedViewManager::new(10);
1102
1103        let query = vec![Predicate::new(
1104            "test".to_string(),
1105            vec![Term::Var("X".to_string())],
1106        )];
1107
1108        manager
1109            .create_view("view1".to_string(), query, None)
1110            .unwrap();
1111
1112        let results = vec![vec![Term::Const(Constant::Int(1))]];
1113
1114        assert!(manager.refresh_view("view1", results.clone()).is_ok());
1115
1116        let view = manager.get_view("view1").unwrap();
1117        assert_eq!(view.results.len(), 1);
1118    }
1119
1120    #[test]
1121    fn test_view_manager_find_matching() {
1122        let mut manager = MaterializedViewManager::new(10);
1123
1124        let query1 = vec![Predicate::new(
1125            "parent".to_string(),
1126            vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
1127        )];
1128
1129        manager
1130            .create_view("parent_view".to_string(), query1.clone(), None)
1131            .unwrap();
1132
1133        // Should find matching view
1134        let found = manager.find_matching_view(&query1);
1135        assert!(found.is_some());
1136        assert_eq!(found.unwrap().name, "parent_view");
1137
1138        // Should not find non-matching view
1139        let query2 = vec![Predicate::new(
1140            "likes".to_string(),
1141            vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
1142        )];
1143        let not_found = manager.find_matching_view(&query2);
1144        assert!(not_found.is_none());
1145    }
1146
1147    #[test]
1148    fn test_view_manager_eviction() {
1149        let mut manager = MaterializedViewManager::new(3);
1150
1151        // Create 3 views
1152        for i in 0..3 {
1153            let query = vec![Predicate::new(
1154                format!("pred{}", i),
1155                vec![Term::Var("X".to_string())],
1156            )];
1157            manager
1158                .create_view(format!("view{}", i), query, None)
1159                .unwrap();
1160        }
1161
1162        assert_eq!(manager.all_views().len(), 3);
1163
1164        // Record different access counts
1165        if let Some(view) = manager.get_view_mut("view0") {
1166            view.record_access(100.0);
1167        }
1168        if let Some(view) = manager.get_view_mut("view1") {
1169            view.record_access(50.0);
1170        }
1171        // view2 has no accesses
1172
1173        // Create one more view - should evict view2 (lowest utility)
1174        let query = vec![Predicate::new(
1175            "pred3".to_string(),
1176            vec![Term::Var("X".to_string())],
1177        )];
1178        manager
1179            .create_view("view3".to_string(), query, None)
1180            .unwrap();
1181
1182        assert_eq!(manager.all_views().len(), 3);
1183        assert!(manager.get_view("view2").is_none()); // view2 should be evicted
1184        assert!(manager.get_view("view0").is_some());
1185        assert!(manager.get_view("view1").is_some());
1186        assert!(manager.get_view("view3").is_some());
1187    }
1188
1189    #[test]
1190    fn test_view_manager_cleanup_stale() {
1191        use std::thread;
1192
1193        let mut manager = MaterializedViewManager::new(10);
1194        manager.set_min_access_threshold(5);
1195
1196        // Create view with TTL
1197        let query1 = vec![Predicate::new(
1198            "test1".to_string(),
1199            vec![Term::Var("X".to_string())],
1200        )];
1201        manager
1202            .create_view("view1".to_string(), query1, Some(Duration::from_millis(10)))
1203            .unwrap();
1204
1205        // Create view with low access count
1206        let query2 = vec![Predicate::new(
1207            "test2".to_string(),
1208            vec![Term::Var("X".to_string())],
1209        )];
1210        manager
1211            .create_view("view2".to_string(), query2, None)
1212            .unwrap();
1213
1214        // Create view with high access count
1215        let query3 = vec![Predicate::new(
1216            "test3".to_string(),
1217            vec![Term::Var("X".to_string())],
1218        )];
1219        manager
1220            .create_view("view3".to_string(), query3, None)
1221            .unwrap();
1222
1223        if let Some(view) = manager.get_view_mut("view3") {
1224            for _ in 0..10 {
1225                view.record_access(1.0);
1226            }
1227        }
1228
1229        // Wait for TTL to expire
1230        thread::sleep(Duration::from_millis(20));
1231
1232        manager.cleanup_stale_views();
1233
1234        // view1 should be removed (TTL expired)
1235        // view2 should be removed (low access count)
1236        // view3 should remain (high access count)
1237        assert!(manager.get_view("view1").is_none());
1238        assert!(manager.get_view("view2").is_none());
1239        assert!(manager.get_view("view3").is_some());
1240    }
1241
1242    #[test]
1243    fn test_view_statistics() {
1244        let mut manager = MaterializedViewManager::new(10);
1245
1246        // Create views with different access patterns
1247        for i in 0..3 {
1248            let query = vec![Predicate::new(
1249                format!("pred{}", i),
1250                vec![Term::Var("X".to_string())],
1251            )];
1252            manager
1253                .create_view(format!("view{}", i), query, None)
1254                .unwrap();
1255
1256            if let Some(view) = manager.get_view_mut(&format!("view{}", i)) {
1257                for _ in 0..((i + 1) * 5) {
1258                    view.record_access(10.0);
1259                }
1260            }
1261        }
1262
1263        let stats = manager.get_statistics();
1264        assert_eq!(stats.total_views, 3);
1265        assert_eq!(stats.total_accesses, 30); // 5 + 10 + 15
1266        assert!((stats.total_cost_saved - 300.0).abs() < 0.001); // 30 * 10.0
1267        assert!((stats.avg_access_count - 10.0).abs() < 0.001); // 30 / 3
1268    }
1269}