1use ipfrs_core::Result;
10use ipfrs_tensorlogic::{KnowledgeBase, Predicate, Term};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub enum QueryPattern {
17 Exact(Predicate),
19 Pattern {
21 name: Option<String>,
22 args: Vec<TermPattern>,
23 },
24 Variable(String),
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
30pub enum TermPattern {
31 Exact(Term),
33 Wildcard,
35 Variable(String),
37 TypeConstraint(TermType),
39}
40
41#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
43pub enum TermType {
44 Var,
45 Const,
46 Fun,
47 Ref,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52pub enum BooleanQuery {
53 And(Vec<Query>),
55 Or(Vec<Query>),
57 Not(Box<Query>),
59 Atom(Query),
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
65pub enum FilterExpr {
66 Equals(String, String),
68 NotEquals(String, String),
70 Regex(String, String),
72 IsType(String, TermType),
74 And(Vec<FilterExpr>),
76 Or(Vec<FilterExpr>),
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
82pub struct Query {
83 pub select: Vec<String>,
85 pub patterns: Vec<QueryPattern>,
87 pub filters: Vec<FilterExpr>,
89 pub limit: Option<usize>,
91 pub offset: Option<usize>,
93}
94
95impl Query {
96 pub fn new() -> Self {
98 Self {
99 select: Vec::new(),
100 patterns: Vec::new(),
101 filters: Vec::new(),
102 limit: None,
103 offset: None,
104 }
105 }
106
107 pub fn select(mut self, var: impl Into<String>) -> Self {
109 self.select.push(var.into());
110 self
111 }
112
113 pub fn where_pattern(mut self, pattern: QueryPattern) -> Self {
115 self.patterns.push(pattern);
116 self
117 }
118
119 pub fn filter(mut self, expr: FilterExpr) -> Self {
121 self.filters.push(expr);
122 self
123 }
124
125 pub fn limit(mut self, n: usize) -> Self {
127 self.limit = Some(n);
128 self
129 }
130
131 pub fn offset(mut self, n: usize) -> Self {
133 self.offset = Some(n);
134 self
135 }
136}
137
138impl Default for Query {
139 fn default() -> Self {
140 Self::new()
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct QueryResult {
147 pub bindings: Vec<HashMap<String, Term>>,
149 pub stats: QueryStats,
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct QueryStats {
156 pub patterns_evaluated: usize,
158 pub intermediate_results: usize,
160 pub final_results: usize,
162 pub execution_time_ms: u64,
164}
165
166pub struct QueryExecutor {
168 kb: KnowledgeBase,
170 optimize: bool,
172}
173
174impl QueryExecutor {
175 pub fn new(kb: KnowledgeBase) -> Self {
177 Self { kb, optimize: true }
178 }
179
180 pub fn set_optimization(&mut self, enabled: bool) {
182 self.optimize = enabled;
183 }
184
185 pub fn execute(&self, mut query: Query) -> Result<QueryResult> {
187 let start = std::time::Instant::now();
188
189 if self.optimize {
191 query = self.optimize_query(query)?;
192 }
193
194 let mut bindings = vec![HashMap::new()];
196 let mut patterns_evaluated = 0;
197 let mut intermediate_results = 0;
198
199 for pattern in &query.patterns {
200 let new_bindings = self.match_pattern(pattern, &bindings)?;
201 intermediate_results += new_bindings.len();
202 bindings = new_bindings;
203 patterns_evaluated += 1;
204 }
205
206 bindings = self.apply_filters(&query.filters, bindings)?;
208
209 bindings = self.project_variables(&query.select, bindings);
211
212 if let Some(offset) = query.offset {
214 bindings = bindings.into_iter().skip(offset).collect();
215 }
216 if let Some(limit) = query.limit {
217 bindings.truncate(limit);
218 }
219
220 let execution_time_ms = start.elapsed().as_millis() as u64;
221 let final_results = bindings.len();
222
223 Ok(QueryResult {
224 bindings,
225 stats: QueryStats {
226 patterns_evaluated,
227 intermediate_results,
228 final_results,
229 execution_time_ms,
230 },
231 })
232 }
233
234 fn optimize_query(&self, mut query: Query) -> Result<Query> {
236 query.patterns = self.reorder_patterns(query.patterns)?;
238
239 Ok(query)
243 }
244
245 fn reorder_patterns(&self, patterns: Vec<QueryPattern>) -> Result<Vec<QueryPattern>> {
247 let mut scored: Vec<(QueryPattern, usize)> = patterns
248 .into_iter()
249 .map(|p| {
250 let selectivity = self.estimate_selectivity(&p);
251 (p, selectivity)
252 })
253 .collect();
254
255 scored.sort_by_key(|(_, s)| *s);
257
258 Ok(scored.into_iter().map(|(p, _)| p).collect())
259 }
260
261 fn estimate_selectivity(&self, pattern: &QueryPattern) -> usize {
263 match pattern {
264 QueryPattern::Exact(pred) => {
265 if self.kb.facts.contains(pred) {
267 1
268 } else {
269 0
270 }
271 }
272 QueryPattern::Pattern { name, args } => {
273 let mut count = 0;
275 for fact in &self.kb.facts {
276 if let Some(n) = name {
277 if &fact.name != n {
278 continue;
279 }
280 }
281 if args.len() != fact.args.len() {
282 continue;
283 }
284 if args
285 .iter()
286 .zip(&fact.args)
287 .all(|(p, t)| self.term_matches(p, t))
288 {
289 count += 1;
290 }
291 }
292 count
293 }
294 QueryPattern::Variable(_) => self.kb.facts.len(), }
296 }
297
298 fn match_pattern(
300 &self,
301 pattern: &QueryPattern,
302 current_bindings: &[HashMap<String, Term>],
303 ) -> Result<Vec<HashMap<String, Term>>> {
304 let mut new_bindings = Vec::new();
305
306 for binding in current_bindings {
307 match pattern {
308 QueryPattern::Exact(pred) => {
309 if self.kb.facts.contains(pred) {
311 new_bindings.push(binding.clone());
312 }
313 }
314 QueryPattern::Pattern { name, args } => {
315 for fact in &self.kb.facts {
317 if let Some(n) = name {
318 if &fact.name != n {
319 continue;
320 }
321 }
322 if args.len() != fact.args.len() {
323 continue;
324 }
325
326 let mut new_binding = binding.clone();
328 let mut matches = true;
329
330 for (pattern_arg, fact_arg) in args.iter().zip(&fact.args) {
331 if !self.match_term_pattern(pattern_arg, fact_arg, &mut new_binding) {
332 matches = false;
333 break;
334 }
335 }
336
337 if matches {
338 new_bindings.push(new_binding);
339 }
340 }
341 }
342 QueryPattern::Variable(var) => {
343 for fact in &self.kb.facts {
345 let mut new_binding = binding.clone();
346 new_binding.insert(var.clone(), Term::Var(fact.name.clone()));
348 new_bindings.push(new_binding);
349 }
350 }
351 }
352 }
353
354 Ok(new_bindings)
355 }
356
357 fn match_term_pattern(
359 &self,
360 pattern: &TermPattern,
361 term: &Term,
362 binding: &mut HashMap<String, Term>,
363 ) -> bool {
364 match pattern {
365 TermPattern::Exact(ref expected) => term == expected,
366 TermPattern::Wildcard => true,
367 TermPattern::Variable(var) => {
368 if let Some(bound_term) = binding.get(var) {
370 bound_term == term
371 } else {
372 binding.insert(var.clone(), term.clone());
374 true
375 }
376 }
377 TermPattern::TypeConstraint(typ) => self.check_term_type(term, *typ),
378 }
379 }
380
381 fn check_term_type(&self, term: &Term, typ: TermType) -> bool {
383 matches!(
384 (term, typ),
385 (Term::Var(_), TermType::Var)
386 | (Term::Const(_), TermType::Const)
387 | (Term::Fun(_, _), TermType::Fun)
388 | (Term::Ref(_), TermType::Ref)
389 )
390 }
391
392 fn term_matches(&self, pattern: &TermPattern, term: &Term) -> bool {
394 match pattern {
395 TermPattern::Exact(ref expected) => term == expected,
396 TermPattern::Wildcard => true,
397 TermPattern::Variable(_) => true,
398 TermPattern::TypeConstraint(typ) => self.check_term_type(term, *typ),
399 }
400 }
401
402 fn apply_filters(
404 &self,
405 filters: &[FilterExpr],
406 bindings: Vec<HashMap<String, Term>>,
407 ) -> Result<Vec<HashMap<String, Term>>> {
408 let mut result = bindings;
409
410 for filter in filters {
411 result.retain(|binding| self.evaluate_filter(filter, binding));
412 }
413
414 Ok(result)
415 }
416
417 fn evaluate_filter(&self, filter: &FilterExpr, binding: &HashMap<String, Term>) -> bool {
419 match filter {
420 FilterExpr::Equals(var1, var2) => {
421 let t1 = binding.get(var1);
422 let t2 = binding.get(var2);
423 t1.is_some() && t2.is_some() && t1 == t2
424 }
425 FilterExpr::NotEquals(var1, var2) => {
426 let t1 = binding.get(var1);
427 let t2 = binding.get(var2);
428 t1.is_some() && t2.is_some() && t1 != t2
429 }
430 FilterExpr::Regex(var, pattern) => {
431 if let Some(term) = binding.get(var) {
432 let term_str = format!("{:?}", term);
433 term_str.contains(pattern)
434 } else {
435 false
436 }
437 }
438 FilterExpr::IsType(var, typ) => {
439 if let Some(term) = binding.get(var) {
440 self.check_term_type(term, *typ)
441 } else {
442 false
443 }
444 }
445 FilterExpr::And(exprs) => exprs.iter().all(|e| self.evaluate_filter(e, binding)),
446 FilterExpr::Or(exprs) => exprs.iter().any(|e| self.evaluate_filter(e, binding)),
447 }
448 }
449
450 fn project_variables(
452 &self,
453 vars: &[String],
454 bindings: Vec<HashMap<String, Term>>,
455 ) -> Vec<HashMap<String, Term>> {
456 if vars.is_empty() {
457 return bindings;
459 }
460
461 bindings
462 .into_iter()
463 .map(|binding| {
464 vars.iter()
465 .filter_map(|v| binding.get(v).map(|t| (v.clone(), t.clone())))
466 .collect()
467 })
468 .collect()
469 }
470
471 pub fn execute_boolean(&self, query: &BooleanQuery) -> Result<QueryResult> {
473 match query {
474 BooleanQuery::And(queries) => {
475 let mut results: Option<Vec<HashMap<String, Term>>> = None;
477
478 for q in queries {
479 let result = self.execute(q.clone())?;
480
481 if let Some(existing) = results {
482 let new_set: HashSet<_> = result
484 .bindings
485 .into_iter()
486 .map(|b| format!("{:?}", b))
487 .collect();
488 results = Some(
489 existing
490 .into_iter()
491 .filter(|b| new_set.contains(&format!("{:?}", b)))
492 .collect(),
493 );
494 } else {
495 results = Some(result.bindings);
496 }
497 }
498
499 let final_results = results.as_ref().map(|r| r.len()).unwrap_or(0);
500 Ok(QueryResult {
501 bindings: results.unwrap_or_default(),
502 stats: QueryStats {
503 patterns_evaluated: queries.len(),
504 intermediate_results: 0,
505 final_results,
506 execution_time_ms: 0,
507 },
508 })
509 }
510 BooleanQuery::Or(queries) => {
511 let mut all_bindings = Vec::new();
513 let mut seen = HashSet::new();
514
515 for q in queries {
516 let result = self.execute(q.clone())?;
517
518 for binding in result.bindings {
519 let key = format!("{:?}", binding);
520 if seen.insert(key) {
521 all_bindings.push(binding);
522 }
523 }
524 }
525
526 Ok(QueryResult {
527 bindings: all_bindings.clone(),
528 stats: QueryStats {
529 patterns_evaluated: queries.len(),
530 intermediate_results: 0,
531 final_results: all_bindings.len(),
532 execution_time_ms: 0,
533 },
534 })
535 }
536 BooleanQuery::Not(query) => {
537 let all_result = self.execute(Query::new())?;
539 let excluded_result = self.execute(query.as_ref().clone())?;
540
541 let excluded_set: HashSet<_> = excluded_result
542 .bindings
543 .into_iter()
544 .map(|b| format!("{:?}", b))
545 .collect();
546
547 let filtered: Vec<_> = all_result
548 .bindings
549 .into_iter()
550 .filter(|b| !excluded_set.contains(&format!("{:?}", b)))
551 .collect();
552
553 Ok(QueryResult {
554 bindings: filtered.clone(),
555 stats: QueryStats {
556 patterns_evaluated: 1,
557 intermediate_results: 0,
558 final_results: filtered.len(),
559 execution_time_ms: 0,
560 },
561 })
562 }
563 BooleanQuery::Atom(query) => self.execute(query.clone()),
564 }
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use ipfrs_tensorlogic::Constant;

    /// Builds a string-constant term.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// Builds a variable argument pattern.
    fn var(name: &str) -> TermPattern {
        TermPattern::Variable(name.to_string())
    }

    /// Builds a named pattern over the given argument patterns.
    fn named(name: &str, args: Vec<TermPattern>) -> QueryPattern {
        QueryPattern::Pattern {
            name: Some(name.to_string()),
            args,
        }
    }

    /// Builds an executor over a knowledge base holding `facts`.
    fn executor_over(facts: Vec<Predicate>) -> QueryExecutor {
        let mut kb = KnowledgeBase::new();
        for fact in facts {
            kb.add_fact(fact);
        }
        QueryExecutor::new(kb)
    }

    /// The builder records selections, patterns and the limit.
    #[test]
    fn test_query_builder() {
        let query = Query::new()
            .select("X")
            .select("Y")
            .where_pattern(named("parent", vec![var("X"), var("Y")]))
            .limit(10);

        assert_eq!(query.select.len(), 2);
        assert_eq!(query.patterns.len(), 1);
        assert_eq!(query.limit, Some(10));
    }

    /// A wildcard pattern finds the stored fact.
    #[test]
    fn test_query_executor() {
        let executor = executor_over(vec![Predicate::new(
            "parent".to_string(),
            vec![text("Alice"), text("Bob")],
        )]);

        let query = Query::new().where_pattern(named(
            "parent",
            vec![TermPattern::Wildcard, TermPattern::Wildcard],
        ));

        let result = executor.execute(query).unwrap();
        assert!(!result.bindings.is_empty());
    }

    /// Variables in a pattern are bound and returned when selected.
    #[test]
    fn test_pattern_matching() {
        let executor = executor_over(vec![Predicate::new(
            "parent".to_string(),
            vec![text("Alice"), text("Bob")],
        )]);

        let query = Query::new()
            .select("X")
            .select("Y")
            .where_pattern(named("parent", vec![var("X"), var("Y")]));

        let result = executor.execute(query).unwrap();
        assert_eq!(result.bindings.len(), 1);

        let first = &result.bindings[0];
        assert!(first.contains_key("X"));
        assert!(first.contains_key("Y"));
    }

    /// A type filter keeps bindings whose term has the requested kind.
    #[test]
    fn test_filter_expr() {
        let executor = executor_over(vec![
            Predicate::new("person".to_string(), vec![text("Alice")]),
            Predicate::new("person".to_string(), vec![text("Bob")]),
        ]);

        let query = Query::new()
            .select("X")
            .where_pattern(named("person", vec![var("X")]))
            .filter(FilterExpr::IsType("X".to_string(), TermType::Const));

        let result = executor.execute(query).unwrap();
        assert_eq!(result.bindings.len(), 2);
    }
}