ipfrs_semantic/
kb_query.rs

1//! Knowledge Base Query Language
2//!
3//! This module provides a SPARQL-like query language for semantic knowledge bases:
4//! - Triple pattern matching for graph queries
5//! - Pattern matching for logic terms with wildcards
6//! - Query optimization (join order, filter pushdown)
7//! - Complex boolean queries (AND/OR/NOT)
8
9use ipfrs_core::Result;
10use ipfrs_tensorlogic::{KnowledgeBase, Predicate, Term};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13
14/// Query pattern for matching predicates
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub enum QueryPattern {
17    /// Exact predicate match
18    Exact(Predicate),
19    /// Wildcard pattern (name, args with wildcards)
20    Pattern {
21        name: Option<String>,
22        args: Vec<TermPattern>,
23    },
24    /// Variable binding
25    Variable(String),
26}
27
28/// Pattern for matching terms
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30pub enum TermPattern {
31    /// Exact term match
32    Exact(Term),
33    /// Wildcard (matches any term)
34    Wildcard,
35    /// Variable (binds to matched term)
36    Variable(String),
37    /// Type constraint (e.g., must be constant)
38    TypeConstraint(TermType),
39}
40
41/// Term type for type constraints
42#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
43pub enum TermType {
44    Var,
45    Const,
46    Fun,
47    Ref,
48}
49
50/// Boolean query operators
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub enum BooleanQuery {
53    /// Conjunction (AND)
54    And(Vec<Query>),
55    /// Disjunction (OR)
56    Or(Vec<Query>),
57    /// Negation (NOT)
58    Not(Box<Query>),
59    /// Atomic query
60    Atom(Query),
61}
62
63/// Query filter expressions
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65pub enum FilterExpr {
66    /// Equality comparison
67    Equals(String, String),
68    /// Inequality
69    NotEquals(String, String),
70    /// Regex match on variable
71    Regex(String, String),
72    /// Type check
73    IsType(String, TermType),
74    /// Conjunction of filters
75    And(Vec<FilterExpr>),
76    /// Disjunction of filters
77    Or(Vec<FilterExpr>),
78}
79
80/// A query for the knowledge base
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
82pub struct Query {
83    /// SELECT clause - variables to return
84    pub select: Vec<String>,
85    /// WHERE clause - patterns to match
86    pub patterns: Vec<QueryPattern>,
87    /// FILTER clause - filter expressions
88    pub filters: Vec<FilterExpr>,
89    /// LIMIT - maximum results
90    pub limit: Option<usize>,
91    /// OFFSET - skip first N results
92    pub offset: Option<usize>,
93}
94
95impl Query {
96    /// Create a new query
97    pub fn new() -> Self {
98        Self {
99            select: Vec::new(),
100            patterns: Vec::new(),
101            filters: Vec::new(),
102            limit: None,
103            offset: None,
104        }
105    }
106
107    /// Add a SELECT variable
108    pub fn select(mut self, var: impl Into<String>) -> Self {
109        self.select.push(var.into());
110        self
111    }
112
113    /// Add a WHERE pattern
114    pub fn where_pattern(mut self, pattern: QueryPattern) -> Self {
115        self.patterns.push(pattern);
116        self
117    }
118
119    /// Add a FILTER expression
120    pub fn filter(mut self, expr: FilterExpr) -> Self {
121        self.filters.push(expr);
122        self
123    }
124
125    /// Set LIMIT
126    pub fn limit(mut self, n: usize) -> Self {
127        self.limit = Some(n);
128        self
129    }
130
131    /// Set OFFSET
132    pub fn offset(mut self, n: usize) -> Self {
133        self.offset = Some(n);
134        self
135    }
136}
137
138impl Default for Query {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144/// Query execution result
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct QueryResult {
147    /// Variable bindings
148    pub bindings: Vec<HashMap<String, Term>>,
149    /// Query statistics
150    pub stats: QueryStats,
151}
152
153/// Query execution statistics
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct QueryStats {
156    /// Number of patterns evaluated
157    pub patterns_evaluated: usize,
158    /// Number of intermediate results
159    pub intermediate_results: usize,
160    /// Number of final results
161    pub final_results: usize,
162    /// Execution time in milliseconds
163    pub execution_time_ms: u64,
164}
165
166/// Query executor with optimization
167pub struct QueryExecutor {
168    /// Knowledge base to query
169    kb: KnowledgeBase,
170    /// Whether to enable query optimization
171    optimize: bool,
172}
173
174impl QueryExecutor {
175    /// Create a new query executor
176    pub fn new(kb: KnowledgeBase) -> Self {
177        Self { kb, optimize: true }
178    }
179
180    /// Enable or disable query optimization
181    pub fn set_optimization(&mut self, enabled: bool) {
182        self.optimize = enabled;
183    }
184
185    /// Execute a query
186    pub fn execute(&self, mut query: Query) -> Result<QueryResult> {
187        let start = std::time::Instant::now();
188
189        // Optimize query if enabled
190        if self.optimize {
191            query = self.optimize_query(query)?;
192        }
193
194        // Execute query patterns
195        let mut bindings = vec![HashMap::new()];
196        let mut patterns_evaluated = 0;
197        let mut intermediate_results = 0;
198
199        for pattern in &query.patterns {
200            let new_bindings = self.match_pattern(pattern, &bindings)?;
201            intermediate_results += new_bindings.len();
202            bindings = new_bindings;
203            patterns_evaluated += 1;
204        }
205
206        // Apply filters
207        bindings = self.apply_filters(&query.filters, bindings)?;
208
209        // Apply projection (SELECT clause)
210        bindings = self.project_variables(&query.select, bindings);
211
212        // Apply OFFSET and LIMIT
213        if let Some(offset) = query.offset {
214            bindings = bindings.into_iter().skip(offset).collect();
215        }
216        if let Some(limit) = query.limit {
217            bindings.truncate(limit);
218        }
219
220        let execution_time_ms = start.elapsed().as_millis() as u64;
221        let final_results = bindings.len();
222
223        Ok(QueryResult {
224            bindings,
225            stats: QueryStats {
226                patterns_evaluated,
227                intermediate_results,
228                final_results,
229                execution_time_ms,
230            },
231        })
232    }
233
234    /// Optimize query (join reordering, filter pushdown)
235    fn optimize_query(&self, mut query: Query) -> Result<Query> {
236        // Reorder patterns by selectivity (most selective first)
237        query.patterns = self.reorder_patterns(query.patterns)?;
238
239        // Push filters down (apply as early as possible)
240        // For now, filters are applied after all patterns
241
242        Ok(query)
243    }
244
245    /// Reorder patterns by selectivity
246    fn reorder_patterns(&self, patterns: Vec<QueryPattern>) -> Result<Vec<QueryPattern>> {
247        let mut scored: Vec<(QueryPattern, usize)> = patterns
248            .into_iter()
249            .map(|p| {
250                let selectivity = self.estimate_selectivity(&p);
251                (p, selectivity)
252            })
253            .collect();
254
255        // Sort by selectivity (ascending - most selective first)
256        scored.sort_by_key(|(_, s)| *s);
257
258        Ok(scored.into_iter().map(|(p, _)| p).collect())
259    }
260
261    /// Estimate selectivity of a pattern (number of matches)
262    fn estimate_selectivity(&self, pattern: &QueryPattern) -> usize {
263        match pattern {
264            QueryPattern::Exact(pred) => {
265                // Exact match - check if exists in facts
266                if self.kb.facts.contains(pred) {
267                    1
268                } else {
269                    0
270                }
271            }
272            QueryPattern::Pattern { name, args } => {
273                // Pattern match - count matching facts
274                let mut count = 0;
275                for fact in &self.kb.facts {
276                    if let Some(n) = name {
277                        if &fact.name != n {
278                            continue;
279                        }
280                    }
281                    if args.len() != fact.args.len() {
282                        continue;
283                    }
284                    if args
285                        .iter()
286                        .zip(&fact.args)
287                        .all(|(p, t)| self.term_matches(p, t))
288                    {
289                        count += 1;
290                    }
291                }
292                count
293            }
294            QueryPattern::Variable(_) => self.kb.facts.len(), // Matches all
295        }
296    }
297
298    /// Match a pattern against current bindings
299    fn match_pattern(
300        &self,
301        pattern: &QueryPattern,
302        current_bindings: &[HashMap<String, Term>],
303    ) -> Result<Vec<HashMap<String, Term>>> {
304        let mut new_bindings = Vec::new();
305
306        for binding in current_bindings {
307            match pattern {
308                QueryPattern::Exact(pred) => {
309                    // Check if predicate exists in facts
310                    if self.kb.facts.contains(pred) {
311                        new_bindings.push(binding.clone());
312                    }
313                }
314                QueryPattern::Pattern { name, args } => {
315                    // Match against all facts
316                    for fact in &self.kb.facts {
317                        if let Some(n) = name {
318                            if &fact.name != n {
319                                continue;
320                            }
321                        }
322                        if args.len() != fact.args.len() {
323                            continue;
324                        }
325
326                        // Try to match all arguments
327                        let mut new_binding = binding.clone();
328                        let mut matches = true;
329
330                        for (pattern_arg, fact_arg) in args.iter().zip(&fact.args) {
331                            if !self.match_term_pattern(pattern_arg, fact_arg, &mut new_binding) {
332                                matches = false;
333                                break;
334                            }
335                        }
336
337                        if matches {
338                            new_bindings.push(new_binding);
339                        }
340                    }
341                }
342                QueryPattern::Variable(var) => {
343                    // Bind variable to all facts
344                    for fact in &self.kb.facts {
345                        let mut new_binding = binding.clone();
346                        // Convert predicate to term representation (simplified)
347                        new_binding.insert(var.clone(), Term::Var(fact.name.clone()));
348                        new_bindings.push(new_binding);
349                    }
350                }
351            }
352        }
353
354        Ok(new_bindings)
355    }
356
357    /// Match a term pattern against a term
358    fn match_term_pattern(
359        &self,
360        pattern: &TermPattern,
361        term: &Term,
362        binding: &mut HashMap<String, Term>,
363    ) -> bool {
364        match pattern {
365            TermPattern::Exact(ref expected) => term == expected,
366            TermPattern::Wildcard => true,
367            TermPattern::Variable(var) => {
368                // Check if variable already bound
369                if let Some(bound_term) = binding.get(var) {
370                    bound_term == term
371                } else {
372                    // Bind variable
373                    binding.insert(var.clone(), term.clone());
374                    true
375                }
376            }
377            TermPattern::TypeConstraint(typ) => self.check_term_type(term, *typ),
378        }
379    }
380
381    /// Check if term matches type constraint
382    fn check_term_type(&self, term: &Term, typ: TermType) -> bool {
383        matches!(
384            (term, typ),
385            (Term::Var(_), TermType::Var)
386                | (Term::Const(_), TermType::Const)
387                | (Term::Fun(_, _), TermType::Fun)
388                | (Term::Ref(_), TermType::Ref)
389        )
390    }
391
392    /// Check if term matches pattern
393    fn term_matches(&self, pattern: &TermPattern, term: &Term) -> bool {
394        match pattern {
395            TermPattern::Exact(ref expected) => term == expected,
396            TermPattern::Wildcard => true,
397            TermPattern::Variable(_) => true,
398            TermPattern::TypeConstraint(typ) => self.check_term_type(term, *typ),
399        }
400    }
401
402    /// Apply filter expressions to bindings
403    fn apply_filters(
404        &self,
405        filters: &[FilterExpr],
406        bindings: Vec<HashMap<String, Term>>,
407    ) -> Result<Vec<HashMap<String, Term>>> {
408        let mut result = bindings;
409
410        for filter in filters {
411            result.retain(|binding| self.evaluate_filter(filter, binding));
412        }
413
414        Ok(result)
415    }
416
417    /// Evaluate a filter expression
418    fn evaluate_filter(&self, filter: &FilterExpr, binding: &HashMap<String, Term>) -> bool {
419        match filter {
420            FilterExpr::Equals(var1, var2) => {
421                let t1 = binding.get(var1);
422                let t2 = binding.get(var2);
423                t1.is_some() && t2.is_some() && t1 == t2
424            }
425            FilterExpr::NotEquals(var1, var2) => {
426                let t1 = binding.get(var1);
427                let t2 = binding.get(var2);
428                t1.is_some() && t2.is_some() && t1 != t2
429            }
430            FilterExpr::Regex(var, pattern) => {
431                if let Some(term) = binding.get(var) {
432                    let term_str = format!("{:?}", term);
433                    term_str.contains(pattern)
434                } else {
435                    false
436                }
437            }
438            FilterExpr::IsType(var, typ) => {
439                if let Some(term) = binding.get(var) {
440                    self.check_term_type(term, *typ)
441                } else {
442                    false
443                }
444            }
445            FilterExpr::And(exprs) => exprs.iter().all(|e| self.evaluate_filter(e, binding)),
446            FilterExpr::Or(exprs) => exprs.iter().any(|e| self.evaluate_filter(e, binding)),
447        }
448    }
449
450    /// Project variables (SELECT clause)
451    fn project_variables(
452        &self,
453        vars: &[String],
454        bindings: Vec<HashMap<String, Term>>,
455    ) -> Vec<HashMap<String, Term>> {
456        if vars.is_empty() {
457            // No projection, return all
458            return bindings;
459        }
460
461        bindings
462            .into_iter()
463            .map(|binding| {
464                vars.iter()
465                    .filter_map(|v| binding.get(v).map(|t| (v.clone(), t.clone())))
466                    .collect()
467            })
468            .collect()
469    }
470
471    /// Execute a boolean query
472    pub fn execute_boolean(&self, query: &BooleanQuery) -> Result<QueryResult> {
473        match query {
474            BooleanQuery::And(queries) => {
475                // Execute all queries and intersect results
476                let mut results: Option<Vec<HashMap<String, Term>>> = None;
477
478                for q in queries {
479                    let result = self.execute(q.clone())?;
480
481                    if let Some(existing) = results {
482                        // Intersect
483                        let new_set: HashSet<_> = result
484                            .bindings
485                            .into_iter()
486                            .map(|b| format!("{:?}", b))
487                            .collect();
488                        results = Some(
489                            existing
490                                .into_iter()
491                                .filter(|b| new_set.contains(&format!("{:?}", b)))
492                                .collect(),
493                        );
494                    } else {
495                        results = Some(result.bindings);
496                    }
497                }
498
499                let final_results = results.as_ref().map(|r| r.len()).unwrap_or(0);
500                Ok(QueryResult {
501                    bindings: results.unwrap_or_default(),
502                    stats: QueryStats {
503                        patterns_evaluated: queries.len(),
504                        intermediate_results: 0,
505                        final_results,
506                        execution_time_ms: 0,
507                    },
508                })
509            }
510            BooleanQuery::Or(queries) => {
511                // Execute all queries and union results
512                let mut all_bindings = Vec::new();
513                let mut seen = HashSet::new();
514
515                for q in queries {
516                    let result = self.execute(q.clone())?;
517
518                    for binding in result.bindings {
519                        let key = format!("{:?}", binding);
520                        if seen.insert(key) {
521                            all_bindings.push(binding);
522                        }
523                    }
524                }
525
526                Ok(QueryResult {
527                    bindings: all_bindings.clone(),
528                    stats: QueryStats {
529                        patterns_evaluated: queries.len(),
530                        intermediate_results: 0,
531                        final_results: all_bindings.len(),
532                        execution_time_ms: 0,
533                    },
534                })
535            }
536            BooleanQuery::Not(query) => {
537                // Get all possible bindings, then subtract query results
538                let all_result = self.execute(Query::new())?;
539                let excluded_result = self.execute(query.as_ref().clone())?;
540
541                let excluded_set: HashSet<_> = excluded_result
542                    .bindings
543                    .into_iter()
544                    .map(|b| format!("{:?}", b))
545                    .collect();
546
547                let filtered: Vec<_> = all_result
548                    .bindings
549                    .into_iter()
550                    .filter(|b| !excluded_set.contains(&format!("{:?}", b)))
551                    .collect();
552
553                Ok(QueryResult {
554                    bindings: filtered.clone(),
555                    stats: QueryStats {
556                        patterns_evaluated: 1,
557                        intermediate_results: 0,
558                        final_results: filtered.len(),
559                        execution_time_ms: 0,
560                    },
561                })
562            }
563            BooleanQuery::Atom(query) => self.execute(query.clone()),
564        }
565    }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use ipfrs_tensorlogic::Constant;

    /// Wrap a string literal as a constant term.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// Shorthand for a variable term pattern.
    fn var(name: &str) -> TermPattern {
        TermPattern::Variable(name.to_string())
    }

    /// Shorthand for a pattern on a named predicate.
    fn named(name: &str, args: Vec<TermPattern>) -> QueryPattern {
        QueryPattern::Pattern {
            name: Some(name.to_string()),
            args,
        }
    }

    /// Executor over a fresh knowledge base holding exactly `facts`.
    fn executor_with(facts: Vec<Predicate>) -> QueryExecutor {
        let mut kb = KnowledgeBase::new();
        for fact in facts {
            kb.add_fact(fact);
        }
        QueryExecutor::new(kb)
    }

    /// The single fact `parent(Alice, Bob)`.
    fn parent_fact() -> Predicate {
        Predicate::new("parent".to_string(), vec![text("Alice"), text("Bob")])
    }

    #[test]
    fn test_query_builder() {
        let query = Query::new()
            .select("X")
            .select("Y")
            .where_pattern(named("parent", vec![var("X"), var("Y")]))
            .limit(10);

        assert_eq!(query.select.len(), 2);
        assert_eq!(query.patterns.len(), 1);
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_query_executor() {
        let executor = executor_with(vec![parent_fact()]);

        // Every `parent` fact, whatever its arguments are.
        let query = Query::new().where_pattern(named(
            "parent",
            vec![TermPattern::Wildcard, TermPattern::Wildcard],
        ));

        let result = executor.execute(query).unwrap();
        assert!(!result.bindings.is_empty());
    }

    #[test]
    fn test_pattern_matching() {
        let executor = executor_with(vec![parent_fact()]);

        // Both arguments are captured into variables.
        let query = Query::new()
            .select("X")
            .select("Y")
            .where_pattern(named("parent", vec![var("X"), var("Y")]));

        let result = executor.execute(query).unwrap();
        assert_eq!(result.bindings.len(), 1);
        assert!(result.bindings[0].contains_key("X"));
        assert!(result.bindings[0].contains_key("Y"));
    }

    #[test]
    fn test_filter_expr() {
        let executor = executor_with(vec![
            Predicate::new("person".to_string(), vec![text("Alice")]),
            Predicate::new("person".to_string(), vec![text("Bob")]),
        ]);

        // Both persons are constants, so the type filter keeps both rows.
        let query = Query::new()
            .select("X")
            .where_pattern(named("person", vec![var("X")]))
            .filter(FilterExpr::IsType("X".to_string(), TermType::Const));

        let result = executor.execute(query).unwrap();
        assert_eq!(result.bindings.len(), 2);
    }
}