1use crate::model::*;
7use crate::query::algebra::{Expression, TermPattern};
8use crate::OxirsError;
9use smallvec::SmallVec;
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
/// A set of variable-to-term bindings together with the constraints those bindings
/// are checked against.
#[derive(Debug, Clone)]
pub struct BindingSet {
    /// Declared variables, in declaration order.
    pub variables: SmallVec<[Variable; 8]>,
    /// Current bindings; `bind` keeps at most one entry per variable.
    pub bindings: Vec<TermBinding>,
    /// Constraints attached to this set; `bind` checks the type and value constraints.
    pub constraints: Vec<Constraint>,
    /// Position of each declared variable in `variables`.
    // NOTE(review): `variables` is public, so direct mutation by callers can desync this index.
    var_index: HashMap<Variable, usize>,
}
25
/// A single variable-to-term binding plus metadata about it.
#[derive(Debug, Clone)]
pub struct TermBinding {
    /// The bound variable.
    pub variable: Variable,
    /// The term the variable is bound to.
    pub term: Term,
    /// Metadata attached to this binding (see [`BindingMetadata`]).
    pub metadata: BindingMetadata,
}
36
/// Metadata carried alongside a [`TermBinding`].
///
/// The derived `Default` yields `source_pattern_id = 0`, `confidence = 0.0` and
/// `is_fixed = false`. None of these fields is read by the code in this module.
#[derive(Debug, Clone, Default)]
pub struct BindingMetadata {
    /// Identifier of the pattern that produced the binding (presumably an index — confirm with callers).
    pub source_pattern_id: usize,
    /// Confidence score for the binding; NOTE(review): expected range not enforced here.
    pub confidence: f64,
    /// Flag marking the binding as fixed; its semantics are defined by callers.
    pub is_fixed: bool,
}
47
/// A restriction on the terms that variables may be bound to.
#[derive(Debug, Clone)]
pub enum Constraint {
    /// `variable` must be bound to a term whose [`TermType`] is in `allowed_types`.
    TypeConstraint {
        variable: Variable,
        allowed_types: HashSet<TermType>,
    },
    /// The term bound to `variable` must satisfy `constraint`.
    ValueConstraint {
        variable: Variable,
        constraint: ValueConstraintType,
    },
    /// The terms bound to `left` and `right` must stand in `relation` to each other.
    RelationshipConstraint {
        left: Variable,
        right: Variable,
        relation: RelationType,
    },
    /// An arbitrary filter expression; not evaluated by the code in this module.
    FilterConstraint { expression: Expression },
}
70
/// Coarse classification of RDF terms used by [`Constraint::TypeConstraint`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TermType {
    /// An IRI.
    NamedNode,
    /// A blank node.
    BlankNode,
    /// A literal that fits no more specific category below; as an *allowed* type it
    /// also admits every specific literal category.
    Literal,
    /// A literal with a numeric XSD datatype.
    NumericLiteral,
    /// An `xsd:string` or `rdf:langString` literal.
    StringLiteral,
    /// An `xsd:boolean` literal.
    BooleanLiteral,
    /// An `xsd:dateTime` literal.
    DateTimeLiteral,
}
82
/// Value-level checks used by [`Constraint::ValueConstraint`].
#[derive(Debug, Clone)]
pub enum ValueConstraintType {
    /// The literal's lexical value, parsed as `f64`, must lie in `[min, max]` (inclusive).
    NumericRange { min: f64, max: f64 },
    /// The regex must match somewhere in the literal's lexical value.
    StringPattern(regex::Regex),
    /// The term must be one of these terms.
    OneOf(HashSet<Term>),
    /// The term must not be any of these terms.
    NoneOf(HashSet<Term>),
}
95
/// Relation between two variables' terms in [`Constraint::RelationshipConstraint`].
///
/// `Equal`/`NotEqual` compare whole terms; the ordering variants compare the numeric
/// values parsed from two literals (as evaluated by `BindingIterator`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelationType {
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
}
106
107impl BindingSet {
108 pub fn new() -> Self {
110 Self {
111 variables: SmallVec::new(),
112 bindings: Vec::new(),
113 constraints: Vec::new(),
114 var_index: HashMap::new(),
115 }
116 }
117
118 pub fn with_variables(vars: Vec<Variable>) -> Self {
120 let mut var_index = HashMap::new();
121 for (idx, var) in vars.iter().enumerate() {
122 var_index.insert(var.clone(), idx);
123 }
124
125 Self {
126 variables: vars.into(),
127 bindings: Vec::new(),
128 constraints: Vec::new(),
129 var_index,
130 }
131 }
132
133 pub fn add_variable(&mut self, var: Variable) -> usize {
135 if let Some(&idx) = self.var_index.get(&var) {
136 idx
137 } else {
138 let idx = self.variables.len();
139 self.variables.push(var.clone());
140 self.var_index.insert(var, idx);
141 idx
142 }
143 }
144
145 pub fn bind(
147 &mut self,
148 var: Variable,
149 term: Term,
150 metadata: BindingMetadata,
151 ) -> Result<(), OxirsError> {
152 if !self.var_index.contains_key(&var) {
154 self.add_variable(var.clone());
155 }
156
157 self.validate_binding(&var, &term)?;
159
160 self.bindings.retain(|b| b.variable != var);
162
163 self.bindings.push(TermBinding {
165 variable: var,
166 term,
167 metadata,
168 });
169
170 Ok(())
171 }
172
173 pub fn get(&self, var: &Variable) -> Option<&Term> {
175 self.bindings
176 .iter()
177 .find(|b| &b.variable == var)
178 .map(|b| &b.term)
179 }
180
181 pub fn is_bound(&self, var: &Variable) -> bool {
183 self.bindings.iter().any(|b| &b.variable == var)
184 }
185
186 pub fn unbound_variables(&self) -> Vec<&Variable> {
188 self.variables
189 .iter()
190 .filter(|v| !self.is_bound(v))
191 .collect()
192 }
193
194 pub fn add_constraint(&mut self, constraint: Constraint) {
196 self.constraints.push(constraint);
197 }
198
199 fn validate_binding(&self, var: &Variable, term: &Term) -> Result<(), OxirsError> {
201 for constraint in &self.constraints {
202 match constraint {
203 Constraint::TypeConstraint {
204 variable,
205 allowed_types,
206 } => {
207 if variable == var && !self.check_type_constraint(term, allowed_types) {
208 return Err(OxirsError::Query(format!(
209 "Type constraint violation for variable {var}"
210 )));
211 }
212 }
213 Constraint::ValueConstraint {
214 variable,
215 constraint,
216 } => {
217 if variable == var && !self.check_value_constraint(term, constraint) {
218 return Err(OxirsError::Query(format!(
219 "Value constraint violation for variable {var}"
220 )));
221 }
222 }
223 _ => {} }
225 }
226 Ok(())
227 }
228
229 fn check_type_constraint(&self, term: &Term, allowed_types: &HashSet<TermType>) -> bool {
231 let term_type = match term {
232 Term::NamedNode(_) => TermType::NamedNode,
233 Term::BlankNode(_) => TermType::BlankNode,
234 Term::Literal(lit) => {
235 let datatype = lit.datatype();
236 let datatype_str = datatype.as_str();
237 if datatype_str == "http://www.w3.org/2001/XMLSchema#integer"
238 || datatype_str == "http://www.w3.org/2001/XMLSchema#decimal"
239 || datatype_str == "http://www.w3.org/2001/XMLSchema#double"
240 {
241 TermType::NumericLiteral
242 } else if datatype_str == "http://www.w3.org/2001/XMLSchema#boolean" {
243 TermType::BooleanLiteral
244 } else if datatype_str == "http://www.w3.org/2001/XMLSchema#dateTime" {
245 TermType::DateTimeLiteral
246 } else if datatype_str == "http://www.w3.org/2001/XMLSchema#string"
247 || datatype_str == "http://www.w3.org/1999/02/22-rdf-syntax-ns#langString"
248 {
249 TermType::StringLiteral
250 } else {
251 TermType::Literal
252 }
253 }
254 _ => return false, };
256
257 if allowed_types.contains(&term_type) {
259 return true;
260 }
261
262 match term_type {
264 TermType::NumericLiteral
265 | TermType::StringLiteral
266 | TermType::BooleanLiteral
267 | TermType::DateTimeLiteral => allowed_types.contains(&TermType::Literal),
268 _ => false,
269 }
270 }
271
272 fn check_value_constraint(&self, term: &Term, constraint: &ValueConstraintType) -> bool {
274 match constraint {
275 ValueConstraintType::NumericRange { min, max } => {
276 if let Term::Literal(lit) = term {
277 if let Ok(value) = lit.value().parse::<f64>() {
278 return value >= *min && value <= *max;
279 }
280 }
281 false
282 }
283 ValueConstraintType::StringPattern(regex) => {
284 if let Term::Literal(lit) = term {
285 return regex.is_match(lit.value());
286 }
287 false
288 }
289 ValueConstraintType::OneOf(allowed) => allowed.contains(term),
290 ValueConstraintType::NoneOf(forbidden) => !forbidden.contains(term),
291 }
292 }
293
294 pub fn to_map(&self) -> HashMap<Variable, Term> {
296 self.bindings
297 .iter()
298 .map(|b| (b.variable.clone(), b.term.clone()))
299 .collect()
300 }
301
302 pub fn merge(&mut self, other: &BindingSet) -> Result<(), OxirsError> {
304 for var in &other.variables {
306 self.add_variable(var.clone());
307 }
308
309 for binding in &other.bindings {
311 if let Some(existing) = self.get(&binding.variable) {
313 if existing != &binding.term {
314 return Err(OxirsError::Query(format!(
315 "Conflicting bindings for variable {}",
316 binding.variable
317 )));
318 }
319 } else {
320 self.bindings.push(binding.clone());
321 }
322 }
323
324 self.constraints.extend(other.constraints.clone());
326
327 Ok(())
328 }
329
330 pub fn apply_to_pattern(&self, pattern: &TermPattern) -> TermPattern {
332 match pattern {
333 TermPattern::Variable(var) => {
334 if let Some(term) = self.get(var) {
335 match term {
336 Term::NamedNode(n) => TermPattern::NamedNode(n.clone()),
337 Term::BlankNode(b) => TermPattern::BlankNode(b.clone()),
338 Term::Literal(l) => TermPattern::Literal(l.clone()),
339 _ => pattern.clone(), }
341 } else {
342 pattern.clone()
343 }
344 }
345 _ => pattern.clone(),
346 }
347 }
348}
349
/// Builds constraint-annotated [`BindingSet`]s and caches them by a key derived from
/// the requested variables and constraints.
pub struct BindingOptimizer {
    /// Previously built binding sets, shared via `Arc` with every caller that hits the cache.
    binding_cache: HashMap<String, Arc<BindingSet>>,
    /// Counters reported by `stats()`.
    stats: BindingStats,
}
357
/// Activity counters for [`BindingOptimizer`], rendered by `BindingOptimizer::stats`.
#[derive(Debug, Default)]
struct BindingStats {
    /// Count of newly built binding sets. NOTE(review): verify the optimizer increments it.
    bindings_created: usize,
    /// Requests served from the cache.
    cache_hits: usize,
    /// Requests that required building a new binding set.
    cache_misses: usize,
    /// Count of constraint violations. NOTE(review): nothing in this module increments it.
    constraint_violations: usize,
}
370
371impl BindingOptimizer {
372 pub fn new() -> Self {
374 Self {
375 binding_cache: HashMap::new(),
376 stats: BindingStats::default(),
377 }
378 }
379
380 pub fn optimize_bindings(
382 &mut self,
383 variables: Vec<Variable>,
384 constraints: Vec<Constraint>,
385 ) -> Arc<BindingSet> {
386 let cache_key = self.create_cache_key(&variables, &constraints);
388
389 if let Some(cached) = self.binding_cache.get(&cache_key) {
391 self.stats.cache_hits += 1;
392 return Arc::clone(cached);
393 }
394
395 self.stats.cache_misses += 1;
396
397 let mut binding_set = BindingSet::with_variables(variables);
399 for constraint in constraints {
400 binding_set.add_constraint(constraint);
401 }
402
403 self.propagate_constraints(&mut binding_set);
405
406 let arc_set = Arc::new(binding_set);
408 self.binding_cache.insert(cache_key, Arc::clone(&arc_set));
409 arc_set
410 }
411
412 fn create_cache_key(&self, variables: &[Variable], constraints: &[Constraint]) -> String {
414 let mut key = String::new();
415 for var in variables {
416 key.push_str(var.as_str());
417 key.push(',');
418 }
419 key.push('|');
420 key.push_str(&format!("{}", constraints.len()));
422 key
423 }
424
425 fn propagate_constraints(&self, binding_set: &mut BindingSet) {
427 let mut constraint_graph: HashMap<Variable, Vec<usize>> = HashMap::new();
429
430 for (idx, constraint) in binding_set.constraints.iter().enumerate() {
431 match constraint {
432 Constraint::TypeConstraint { variable, .. }
433 | Constraint::ValueConstraint { variable, .. } => {
434 constraint_graph
435 .entry(variable.clone())
436 .or_default()
437 .push(idx);
438 }
439 Constraint::RelationshipConstraint { left, right, .. } => {
440 constraint_graph.entry(left.clone()).or_default().push(idx);
441 constraint_graph.entry(right.clone()).or_default().push(idx);
442 }
443 _ => {}
444 }
445 }
446
447 self.propagate_equality_constraints(binding_set, constraint_graph);
449 }
450
451 fn propagate_equality_constraints(
453 &self,
454 binding_set: &mut BindingSet,
455 constraint_graph: HashMap<Variable, Vec<usize>>,
456 ) {
457 let mut equiv_classes: HashMap<Variable, Variable> = HashMap::new();
459
460 for constraint in &binding_set.constraints {
461 if let Constraint::RelationshipConstraint {
462 left,
463 right,
464 relation: RelationType::Equal,
465 } = constraint
466 {
467 let left_root = self.find_root(left, &equiv_classes);
469 let right_root = self.find_root(right, &equiv_classes);
470 if left_root != right_root {
471 equiv_classes.insert(left_root, right_root.clone());
472 }
473 }
474 }
475
476 for (var, root) in &equiv_classes {
478 if var != root {
479 if let Some(root_constraints) = constraint_graph.get(root) {
481 for &_constraint in root_constraints {
482 }
484 }
485 }
486 }
487 }
488
489 fn find_root<'a>(
491 &self,
492 var: &'a Variable,
493 equiv_classes: &'a HashMap<Variable, Variable>,
494 ) -> Variable {
495 let mut current = var.clone();
496 while let Some(parent) = equiv_classes.get(¤t) {
497 if parent == ¤t {
498 break;
499 }
500 current = parent.clone();
501 }
502 current
503 }
504
505 pub fn stats(&self) -> String {
507 format!(
508 "Bindings created: {}, Cache hits: {}, Cache misses: {}, Violations: {}",
509 self.stats.bindings_created,
510 self.stats.cache_hits,
511 self.stats.cache_misses,
512 self.stats.constraint_violations
513 )
514 }
515}
516
517pub struct BindingIterator {
519 base_bindings: Vec<HashMap<Variable, Term>>,
521 variables: Vec<Variable>,
523 possible_values: HashMap<Variable, Vec<Term>>,
525 current_position: Vec<usize>,
527 constraints: Vec<Constraint>,
529}
530
531impl BindingIterator {
532 pub fn new(
534 base_bindings: Vec<HashMap<Variable, Term>>,
535 variables: Vec<Variable>,
536 possible_values: HashMap<Variable, Vec<Term>>,
537 constraints: Vec<Constraint>,
538 ) -> Self {
539 let current_position = vec![0; variables.len()];
540 Self {
541 base_bindings,
542 variables,
543 possible_values,
544 current_position,
545 constraints,
546 }
547 }
548
549 pub fn next_valid(&mut self) -> Option<HashMap<Variable, Term>> {
551 while let Some(binding) = self.next_combination() {
552 if self.validate_binding(&binding) {
553 return Some(binding);
554 }
555 }
556 None
557 }
558
559 fn next_combination(&mut self) -> Option<HashMap<Variable, Term>> {
561 if self.base_bindings.is_empty() {
562 return None;
563 }
564
565 if self.current_position.iter().all(|&p| p == 0) && !self.current_position.is_empty() {
567 } else if self.current_position.is_empty() {
569 return None;
570 }
571
572 let mut result = self.base_bindings[0].clone();
574 for (i, var) in self.variables.iter().enumerate() {
575 if let Some(values) = self.possible_values.get(var) {
576 if self.current_position[i] < values.len() {
577 result.insert(var.clone(), values[self.current_position[i]].clone());
578 }
579 }
580 }
581
582 self.increment_position();
584
585 Some(result)
586 }
587
588 fn increment_position(&mut self) {
590 for i in (0..self.current_position.len()).rev() {
591 if let Some(values) = self.possible_values.get(&self.variables[i]) {
592 if self.current_position[i] + 1 < values.len() {
593 self.current_position[i] += 1;
594 return;
595 } else {
596 self.current_position[i] = 0;
597 }
598 }
599 }
600 self.current_position.clear();
602 }
603
604 fn validate_binding(&self, binding: &HashMap<Variable, Term>) -> bool {
606 for constraint in &self.constraints {
607 if let Constraint::RelationshipConstraint {
608 left,
609 right,
610 relation,
611 } = constraint
612 {
613 if let (Some(left_val), Some(right_val)) = (binding.get(left), binding.get(right)) {
614 if !self.check_relation(left_val, right_val, *relation) {
615 return false;
616 }
617 }
618 }
619 }
621 true
622 }
623
624 fn check_relation(&self, left: &Term, right: &Term, relation: RelationType) -> bool {
626 match relation {
627 RelationType::Equal => left == right,
628 RelationType::NotEqual => left != right,
629 _ => {
630 if let (Term::Literal(l_lit), Term::Literal(r_lit)) = (left, right) {
632 if let (Ok(l_val), Ok(r_val)) =
633 (l_lit.value().parse::<f64>(), r_lit.value().parse::<f64>())
634 {
635 match relation {
636 RelationType::LessThan => l_val < r_val,
637 RelationType::LessThanOrEqual => l_val <= r_val,
638 RelationType::GreaterThan => l_val > r_val,
639 RelationType::GreaterThanOrEqual => l_val >= r_val,
640 _ => false,
641 }
642 } else {
643 false
644 }
645 } else {
646 false
647 }
648 }
649 }
650 }
651}
652
653impl Default for BindingSet {
654 fn default() -> Self {
655 Self::new()
656 }
657}
658
659impl Default for BindingOptimizer {
660 fn default() -> Self {
661 Self::new()
662 }
663}
664
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable from a name known to be valid.
    fn variable(name: &str) -> Variable {
        Variable::new(name).unwrap()
    }

    /// Builds an IRI term from a value known to be valid.
    fn iri(value: &str) -> Term {
        Term::NamedNode(NamedNode::new(value).unwrap())
    }

    #[test]
    fn test_binding_set_basic() {
        let x = variable("x");
        let y = variable("y");
        let alice = iri("http://example.org/alice");

        let mut set = BindingSet::new();
        for declared in [&x, &y] {
            set.add_variable(declared.clone());
        }
        set.bind(x.clone(), alice.clone(), BindingMetadata::default())
            .unwrap();

        assert_eq!(set.get(&x), Some(&alice));
        assert_eq!(set.get(&y), None);

        // Only ?y remains unbound.
        let unbound = set.unbound_variables();
        assert_eq!(unbound.len(), 1);
        assert_eq!(unbound[0], &y);
    }

    #[test]
    fn test_type_constraints() {
        let x = variable("x");
        let allowed: HashSet<TermType> = vec![TermType::Literal, TermType::NumericLiteral]
            .into_iter()
            .collect();

        let mut set = BindingSet::new();
        set.add_constraint(Constraint::TypeConstraint {
            variable: x.clone(),
            allowed_types: allowed,
        });

        let iri_outcome = set.bind(
            x.clone(),
            iri("http://example.org/thing"),
            BindingMetadata::default(),
        );
        assert!(iri_outcome.is_err());

        let literal_outcome = set.bind(
            x.clone(),
            Term::Literal(Literal::new("test")),
            BindingMetadata::default(),
        );
        assert!(literal_outcome.is_ok());
    }

    #[test]
    fn test_value_constraints() {
        let age = variable("age");

        let mut set = BindingSet::new();
        set.add_constraint(Constraint::ValueConstraint {
            variable: age.clone(),
            constraint: ValueConstraintType::NumericRange {
                min: 0.0,
                max: 150.0,
            },
        });

        let in_range = Term::Literal(Literal::new("25"));
        assert!(set
            .bind(age.clone(), in_range, BindingMetadata::default())
            .is_ok());

        let out_of_range = Term::Literal(Literal::new("200"));
        assert!(set
            .bind(age.clone(), out_of_range, BindingMetadata::default())
            .is_err());
    }

    #[test]
    fn test_binding_merge() {
        let x = variable("x");
        let y = variable("y");
        let x_term = iri("http://example.org/x");
        let y_term = iri("http://example.org/y");

        let mut left = BindingSet::new();
        left.bind(x.clone(), x_term.clone(), BindingMetadata::default())
            .unwrap();

        let mut right = BindingSet::new();
        right
            .bind(y.clone(), y_term.clone(), BindingMetadata::default())
            .unwrap();

        left.merge(&right).unwrap();

        assert_eq!(left.get(&x), Some(&x_term));
        assert_eq!(left.get(&y), Some(&y_term));
    }

    #[test]
    fn test_binding_optimizer() {
        let mut optimizer = BindingOptimizer::new();
        let vars = vec![variable("x"), variable("y")];
        let no_constraints: Vec<Constraint> = Vec::new();

        // The second, identical request must be answered from the cache.
        let _first = optimizer.optimize_bindings(vars.clone(), no_constraints.clone());
        let _second = optimizer.optimize_bindings(vars, no_constraints);

        assert!(optimizer.stats().contains("Cache hits: 1"));
    }
}