1use crate::model::*;
7use crate::query::algebra::{Expression, TermPattern};
8use crate::OxirsError;
9use smallvec::SmallVec;
10use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
13#[derive(Debug, Clone)]
15pub struct BindingSet {
16 pub variables: SmallVec<[Variable; 8]>,
18 pub bindings: Vec<TermBinding>,
20 pub constraints: Vec<Constraint>,
22 var_index: HashMap<Variable, usize>,
24}
25
26#[derive(Debug, Clone)]
28pub struct TermBinding {
29 pub variable: Variable,
31 pub term: Term,
33 pub metadata: BindingMetadata,
35}
36
/// Bookkeeping stored with every [`TermBinding`].
///
/// The derived `Default` yields `0`, `0.0` and `false` for the three fields.
/// None of the fields is interpreted anywhere in this file.
#[derive(Clone, Debug, Default)]
pub struct BindingMetadata {
    /// Numeric id of a source pattern (presumably the pattern that produced
    /// the binding — confirm against callers).
    pub source_pattern_id: usize,
    /// Confidence score attached to the binding; no range is enforced here.
    pub confidence: f64,
    /// Flag marking the binding as fixed; not consulted in this file.
    pub is_fixed: bool,
}
47
48#[derive(Debug, Clone)]
50pub enum Constraint {
51 TypeConstraint {
53 variable: Variable,
54 allowed_types: HashSet<TermType>,
55 },
56 ValueConstraint {
58 variable: Variable,
59 constraint: ValueConstraintType,
60 },
61 RelationshipConstraint {
63 left: Variable,
64 right: Variable,
65 relation: RelationType,
66 },
67 FilterConstraint { expression: Expression },
69}
70
/// Coarse classification of a term, used by [`Constraint::TypeConstraint`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TermType {
    /// A named node (IRI).
    NamedNode,
    /// A blank node.
    BlankNode,
    /// Any literal, irrespective of its datatype.
    Literal,
    /// A literal with a numeric datatype.
    NumericLiteral,
    /// A literal with a string datatype.
    StringLiteral,
    /// A literal with the boolean datatype.
    BooleanLiteral,
    /// A literal with the dateTime datatype.
    DateTimeLiteral,
}
82
83#[derive(Debug, Clone)]
85pub enum ValueConstraintType {
86 NumericRange { min: f64, max: f64 },
88 StringPattern(regex::Regex),
90 OneOf(HashSet<Term>),
92 NoneOf(HashSet<Term>),
94}
95
/// Binary relations usable in [`Constraint::RelationshipConstraint`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RelationType {
    /// `left == right`
    Equal,
    /// `left != right`
    NotEqual,
    /// `left < right`
    LessThan,
    /// `left <= right`
    LessThanOrEqual,
    /// `left > right`
    GreaterThan,
    /// `left >= right`
    GreaterThanOrEqual,
}
106
107impl BindingSet {
108 pub fn new() -> Self {
110 Self {
111 variables: SmallVec::new(),
112 bindings: Vec::new(),
113 constraints: Vec::new(),
114 var_index: HashMap::new(),
115 }
116 }
117
118 pub fn with_variables(vars: Vec<Variable>) -> Self {
120 let mut var_index = HashMap::new();
121 for (idx, var) in vars.iter().enumerate() {
122 var_index.insert(var.clone(), idx);
123 }
124
125 Self {
126 variables: vars.into(),
127 bindings: Vec::new(),
128 constraints: Vec::new(),
129 var_index,
130 }
131 }
132
133 pub fn add_variable(&mut self, var: Variable) -> usize {
135 if let Some(&idx) = self.var_index.get(&var) {
136 idx
137 } else {
138 let idx = self.variables.len();
139 self.variables.push(var.clone());
140 self.var_index.insert(var, idx);
141 idx
142 }
143 }
144
145 pub fn bind(
147 &mut self,
148 var: Variable,
149 term: Term,
150 metadata: BindingMetadata,
151 ) -> Result<(), OxirsError> {
152 if !self.var_index.contains_key(&var) {
154 self.add_variable(var.clone());
155 }
156
157 self.validate_binding(&var, &term)?;
159
160 self.bindings.retain(|b| b.variable != var);
162
163 self.bindings.push(TermBinding {
165 variable: var,
166 term,
167 metadata,
168 });
169
170 Ok(())
171 }
172
173 pub fn get(&self, var: &Variable) -> Option<&Term> {
175 self.bindings
176 .iter()
177 .find(|b| &b.variable == var)
178 .map(|b| &b.term)
179 }
180
181 pub fn is_bound(&self, var: &Variable) -> bool {
183 self.bindings.iter().any(|b| &b.variable == var)
184 }
185
186 pub fn unbound_variables(&self) -> Vec<&Variable> {
188 self.variables
189 .iter()
190 .filter(|v| !self.is_bound(v))
191 .collect()
192 }
193
194 pub fn add_constraint(&mut self, constraint: Constraint) {
196 self.constraints.push(constraint);
197 }
198
199 fn validate_binding(&self, var: &Variable, term: &Term) -> Result<(), OxirsError> {
201 for constraint in &self.constraints {
202 match constraint {
203 Constraint::TypeConstraint {
204 variable,
205 allowed_types,
206 } if variable == var && !self.check_type_constraint(term, allowed_types) => {
207 return Err(OxirsError::Query(format!(
208 "Type constraint violation for variable {var}"
209 )));
210 }
211 Constraint::ValueConstraint {
212 variable,
213 constraint,
214 } if variable == var && !self.check_value_constraint(term, constraint) => {
215 return Err(OxirsError::Query(format!(
216 "Value constraint violation for variable {var}"
217 )));
218 }
219 _ => {} }
221 }
222 Ok(())
223 }
224
225 fn check_type_constraint(&self, term: &Term, allowed_types: &HashSet<TermType>) -> bool {
227 let term_type = match term {
228 Term::NamedNode(_) => TermType::NamedNode,
229 Term::BlankNode(_) => TermType::BlankNode,
230 Term::Literal(lit) => {
231 let datatype = lit.datatype();
232 let datatype_str = datatype.as_str();
233 if datatype_str == "http://www.w3.org/2001/XMLSchema#integer"
234 || datatype_str == "http://www.w3.org/2001/XMLSchema#decimal"
235 || datatype_str == "http://www.w3.org/2001/XMLSchema#double"
236 {
237 TermType::NumericLiteral
238 } else if datatype_str == "http://www.w3.org/2001/XMLSchema#boolean" {
239 TermType::BooleanLiteral
240 } else if datatype_str == "http://www.w3.org/2001/XMLSchema#dateTime" {
241 TermType::DateTimeLiteral
242 } else if datatype_str == "http://www.w3.org/2001/XMLSchema#string"
243 || datatype_str == "http://www.w3.org/1999/02/22-rdf-syntax-ns#langString"
244 {
245 TermType::StringLiteral
246 } else {
247 TermType::Literal
248 }
249 }
250 _ => return false, };
252
253 if allowed_types.contains(&term_type) {
255 return true;
256 }
257
258 match term_type {
260 TermType::NumericLiteral
261 | TermType::StringLiteral
262 | TermType::BooleanLiteral
263 | TermType::DateTimeLiteral => allowed_types.contains(&TermType::Literal),
264 _ => false,
265 }
266 }
267
268 fn check_value_constraint(&self, term: &Term, constraint: &ValueConstraintType) -> bool {
270 match constraint {
271 ValueConstraintType::NumericRange { min, max } => {
272 if let Term::Literal(lit) = term {
273 if let Ok(value) = lit.value().parse::<f64>() {
274 return value >= *min && value <= *max;
275 }
276 }
277 false
278 }
279 ValueConstraintType::StringPattern(regex) => {
280 if let Term::Literal(lit) = term {
281 return regex.is_match(lit.value());
282 }
283 false
284 }
285 ValueConstraintType::OneOf(allowed) => allowed.contains(term),
286 ValueConstraintType::NoneOf(forbidden) => !forbidden.contains(term),
287 }
288 }
289
290 pub fn to_map(&self) -> HashMap<Variable, Term> {
292 self.bindings
293 .iter()
294 .map(|b| (b.variable.clone(), b.term.clone()))
295 .collect()
296 }
297
298 pub fn merge(&mut self, other: &BindingSet) -> Result<(), OxirsError> {
300 for var in &other.variables {
302 self.add_variable(var.clone());
303 }
304
305 for binding in &other.bindings {
307 if let Some(existing) = self.get(&binding.variable) {
309 if existing != &binding.term {
310 return Err(OxirsError::Query(format!(
311 "Conflicting bindings for variable {}",
312 binding.variable
313 )));
314 }
315 } else {
316 self.bindings.push(binding.clone());
317 }
318 }
319
320 self.constraints.extend(other.constraints.clone());
322
323 Ok(())
324 }
325
326 pub fn apply_to_pattern(&self, pattern: &TermPattern) -> TermPattern {
328 match pattern {
329 TermPattern::Variable(var) => {
330 if let Some(term) = self.get(var) {
331 match term {
332 Term::NamedNode(n) => TermPattern::NamedNode(n.clone()),
333 Term::BlankNode(b) => TermPattern::BlankNode(b.clone()),
334 Term::Literal(l) => TermPattern::Literal(l.clone()),
335 _ => pattern.clone(), }
337 } else {
338 pattern.clone()
339 }
340 }
341 _ => pattern.clone(),
342 }
343 }
344}
345
346pub struct BindingOptimizer {
348 binding_cache: HashMap<String, Arc<BindingSet>>,
350 stats: BindingStats,
352}
353
/// Counters kept by [`BindingOptimizer`]; all start at zero.
#[derive(Default, Debug)]
struct BindingStats {
    /// Number of bindings created.
    bindings_created: usize,
    /// Number of requests answered from the cache.
    cache_hits: usize,
    /// Number of requests that had to build a new binding set.
    cache_misses: usize,
    /// Number of constraint violations observed.
    constraint_violations: usize,
}
366
367impl BindingOptimizer {
368 pub fn new() -> Self {
370 Self {
371 binding_cache: HashMap::new(),
372 stats: BindingStats::default(),
373 }
374 }
375
376 pub fn optimize_bindings(
378 &mut self,
379 variables: Vec<Variable>,
380 constraints: Vec<Constraint>,
381 ) -> Arc<BindingSet> {
382 let cache_key = self.create_cache_key(&variables, &constraints);
384
385 if let Some(cached) = self.binding_cache.get(&cache_key) {
387 self.stats.cache_hits += 1;
388 return Arc::clone(cached);
389 }
390
391 self.stats.cache_misses += 1;
392
393 let mut binding_set = BindingSet::with_variables(variables);
395 for constraint in constraints {
396 binding_set.add_constraint(constraint);
397 }
398
399 self.propagate_constraints(&mut binding_set);
401
402 let arc_set = Arc::new(binding_set);
404 self.binding_cache.insert(cache_key, Arc::clone(&arc_set));
405 arc_set
406 }
407
408 fn create_cache_key(&self, variables: &[Variable], constraints: &[Constraint]) -> String {
410 let mut key = String::new();
411 for var in variables {
412 key.push_str(var.as_str());
413 key.push(',');
414 }
415 key.push('|');
416 key.push_str(&format!("{}", constraints.len()));
418 key
419 }
420
421 fn propagate_constraints(&self, binding_set: &mut BindingSet) {
423 let mut constraint_graph: HashMap<Variable, Vec<usize>> = HashMap::new();
425
426 for (idx, constraint) in binding_set.constraints.iter().enumerate() {
427 match constraint {
428 Constraint::TypeConstraint { variable, .. }
429 | Constraint::ValueConstraint { variable, .. } => {
430 constraint_graph
431 .entry(variable.clone())
432 .or_default()
433 .push(idx);
434 }
435 Constraint::RelationshipConstraint { left, right, .. } => {
436 constraint_graph.entry(left.clone()).or_default().push(idx);
437 constraint_graph.entry(right.clone()).or_default().push(idx);
438 }
439 _ => {}
440 }
441 }
442
443 self.propagate_equality_constraints(binding_set, constraint_graph);
445 }
446
447 fn propagate_equality_constraints(
449 &self,
450 binding_set: &mut BindingSet,
451 constraint_graph: HashMap<Variable, Vec<usize>>,
452 ) {
453 let mut equiv_classes: HashMap<Variable, Variable> = HashMap::new();
455
456 for constraint in &binding_set.constraints {
457 if let Constraint::RelationshipConstraint {
458 left,
459 right,
460 relation: RelationType::Equal,
461 } = constraint
462 {
463 let left_root = self.find_root(left, &equiv_classes);
465 let right_root = self.find_root(right, &equiv_classes);
466 if left_root != right_root {
467 equiv_classes.insert(left_root, right_root.clone());
468 }
469 }
470 }
471
472 for (var, root) in &equiv_classes {
474 if var != root {
475 if let Some(root_constraints) = constraint_graph.get(root) {
477 for &_constraint in root_constraints {
478 }
480 }
481 }
482 }
483 }
484
485 fn find_root<'a>(
487 &self,
488 var: &'a Variable,
489 equiv_classes: &'a HashMap<Variable, Variable>,
490 ) -> Variable {
491 let mut current = var.clone();
492 while let Some(parent) = equiv_classes.get(¤t) {
493 if parent == ¤t {
494 break;
495 }
496 current = parent.clone();
497 }
498 current
499 }
500
501 pub fn stats(&self) -> String {
503 format!(
504 "Bindings created: {}, Cache hits: {}, Cache misses: {}, Violations: {}",
505 self.stats.bindings_created,
506 self.stats.cache_hits,
507 self.stats.cache_misses,
508 self.stats.constraint_violations
509 )
510 }
511}
512
513pub struct BindingIterator {
515 base_bindings: Vec<HashMap<Variable, Term>>,
517 variables: Vec<Variable>,
519 possible_values: HashMap<Variable, Vec<Term>>,
521 current_position: Vec<usize>,
523 constraints: Vec<Constraint>,
525}
526
527impl BindingIterator {
528 pub fn new(
530 base_bindings: Vec<HashMap<Variable, Term>>,
531 variables: Vec<Variable>,
532 possible_values: HashMap<Variable, Vec<Term>>,
533 constraints: Vec<Constraint>,
534 ) -> Self {
535 let current_position = vec![0; variables.len()];
536 Self {
537 base_bindings,
538 variables,
539 possible_values,
540 current_position,
541 constraints,
542 }
543 }
544
545 pub fn next_valid(&mut self) -> Option<HashMap<Variable, Term>> {
547 while let Some(binding) = self.next_combination() {
548 if self.validate_binding(&binding) {
549 return Some(binding);
550 }
551 }
552 None
553 }
554
555 fn next_combination(&mut self) -> Option<HashMap<Variable, Term>> {
557 if self.base_bindings.is_empty() {
558 return None;
559 }
560
561 if self.current_position.iter().all(|&p| p == 0) && !self.current_position.is_empty() {
563 } else if self.current_position.is_empty() {
565 return None;
566 }
567
568 let mut result = self.base_bindings[0].clone();
570 for (i, var) in self.variables.iter().enumerate() {
571 if let Some(values) = self.possible_values.get(var) {
572 if self.current_position[i] < values.len() {
573 result.insert(var.clone(), values[self.current_position[i]].clone());
574 }
575 }
576 }
577
578 self.increment_position();
580
581 Some(result)
582 }
583
584 fn increment_position(&mut self) {
586 for i in (0..self.current_position.len()).rev() {
587 if let Some(values) = self.possible_values.get(&self.variables[i]) {
588 if self.current_position[i] + 1 < values.len() {
589 self.current_position[i] += 1;
590 return;
591 } else {
592 self.current_position[i] = 0;
593 }
594 }
595 }
596 self.current_position.clear();
598 }
599
600 fn validate_binding(&self, binding: &HashMap<Variable, Term>) -> bool {
602 for constraint in &self.constraints {
603 if let Constraint::RelationshipConstraint {
604 left,
605 right,
606 relation,
607 } = constraint
608 {
609 if let (Some(left_val), Some(right_val)) = (binding.get(left), binding.get(right)) {
610 if !self.check_relation(left_val, right_val, *relation) {
611 return false;
612 }
613 }
614 }
615 }
617 true
618 }
619
620 fn check_relation(&self, left: &Term, right: &Term, relation: RelationType) -> bool {
622 match relation {
623 RelationType::Equal => left == right,
624 RelationType::NotEqual => left != right,
625 _ => {
626 if let (Term::Literal(l_lit), Term::Literal(r_lit)) = (left, right) {
628 if let (Ok(l_val), Ok(r_val)) =
629 (l_lit.value().parse::<f64>(), r_lit.value().parse::<f64>())
630 {
631 match relation {
632 RelationType::LessThan => l_val < r_val,
633 RelationType::LessThanOrEqual => l_val <= r_val,
634 RelationType::GreaterThan => l_val > r_val,
635 RelationType::GreaterThanOrEqual => l_val >= r_val,
636 _ => false,
637 }
638 } else {
639 false
640 }
641 } else {
642 false
643 }
644 }
645 }
646 }
647}
648
649impl Default for BindingSet {
650 fn default() -> Self {
651 Self::new()
652 }
653}
654
655impl Default for BindingOptimizer {
656 fn default() -> Self {
657 Self::new()
658 }
659}
660
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable from a name known to be valid.
    fn var(name: &str) -> Variable {
        Variable::new(name).expect("valid variable name")
    }

    /// Builds a named-node term from an IRI known to be valid.
    fn iri_term(iri: &str) -> Term {
        Term::NamedNode(NamedNode::new(iri).expect("valid IRI"))
    }

    /// Binding one of two registered variables leaves the other unbound.
    #[test]
    fn test_binding_set_basic() {
        let x = var("x");
        let y = var("y");

        let mut set = BindingSet::new();
        set.add_variable(x.clone());
        set.add_variable(y.clone());

        let alice = iri_term("http://example.org/alice");
        set.bind(x.clone(), alice.clone(), BindingMetadata::default())
            .expect("operation should succeed");

        assert_eq!(set.get(&x), Some(&alice));
        assert_eq!(set.get(&y), None);

        let unbound = set.unbound_variables();
        assert_eq!(unbound.len(), 1);
        assert_eq!(unbound[0], &y);
    }

    /// A literal-only type constraint rejects IRIs and accepts literals.
    #[test]
    fn test_type_constraints() {
        let x = var("x");
        let allowed_types: HashSet<TermType> = vec![TermType::Literal, TermType::NumericLiteral]
            .into_iter()
            .collect();

        let mut set = BindingSet::new();
        set.add_constraint(Constraint::TypeConstraint {
            variable: x.clone(),
            allowed_types,
        });

        let rejected = set.bind(
            x.clone(),
            iri_term("http://example.org/thing"),
            BindingMetadata::default(),
        );
        assert!(rejected.is_err());

        let accepted = set.bind(
            x.clone(),
            Term::Literal(Literal::new("test")),
            BindingMetadata::default(),
        );
        assert!(accepted.is_ok());
    }

    /// A numeric range accepts values inside it and rejects values outside.
    #[test]
    fn test_value_constraints() {
        let age = var("age");

        let mut set = BindingSet::new();
        set.add_constraint(Constraint::ValueConstraint {
            variable: age.clone(),
            constraint: ValueConstraintType::NumericRange {
                min: 0.0,
                max: 150.0,
            },
        });

        let in_range = Term::Literal(Literal::new("25"));
        assert!(set
            .bind(age.clone(), in_range, BindingMetadata::default())
            .is_ok());

        let out_of_range = Term::Literal(Literal::new("200"));
        assert!(set
            .bind(age.clone(), out_of_range, BindingMetadata::default())
            .is_err());
    }

    /// Merging two disjoint sets yields the bindings of both.
    #[test]
    fn test_binding_merge() {
        let x = var("x");
        let y = var("y");
        let x_term = iri_term("http://example.org/x");
        let y_term = iri_term("http://example.org/y");

        let mut target = BindingSet::new();
        target
            .bind(x.clone(), x_term.clone(), BindingMetadata::default())
            .expect("operation should succeed");

        let mut source = BindingSet::new();
        source
            .bind(y.clone(), y_term.clone(), BindingMetadata::default())
            .expect("operation should succeed");

        target.merge(&source).expect("merge should succeed");

        assert_eq!(target.get(&x), Some(&x_term));
        assert_eq!(target.get(&y), Some(&y_term));
    }

    /// A second identical request is served from the cache.
    #[test]
    fn test_binding_optimizer() {
        let mut optimizer = BindingOptimizer::new();

        let request_vars = vec![var("x"), var("y")];
        let request_constraints: Vec<Constraint> = Vec::new();

        let _first =
            optimizer.optimize_bindings(request_vars.clone(), request_constraints.clone());
        let _second = optimizer.optimize_bindings(request_vars, request_constraints);

        let report = optimizer.stats();
        assert!(report.contains("Cache hits: 1"));
    }
}