kizzasi_logic/
constraint_propagation.rs

1//! Constraint Propagation for Discrete CSP
2//!
3//! This module implements constraint propagation algorithms for discrete
4//! Constraint Satisfaction Problems (CSP):
5//! - AC-3 (Arc Consistency 3) algorithm
6//! - Domain reduction and filtering
7//! - Forward checking
8//! - Backtracking search with constraint propagation
9//!
10//! # Use Cases
11//!
12//! - Scheduling problems
13//! - Resource allocation
14//! - Configuration problems
15//! - Graph coloring
16//! - Sudoku and puzzle solving
17
18use crate::error::{LogicError, LogicResult};
19use std::collections::{HashMap, HashSet, VecDeque};
20
/// Domain of possible values for a variable: the set of `i32` candidates
/// the variable may still take.
pub type Domain = HashSet<i32>;
23
/// Variable identifier.
///
/// [`CSP`] uses it as the position of the variable's entry in its list of
/// domains.
pub type VarId = usize;
26
27/// Discrete constraint between variables
28#[derive(Debug, Clone)]
29pub enum DiscreteConstraint {
30    /// Binary constraint: R(x, y)
31    Binary {
32        /// First variable
33        var1: VarId,
34        /// Second variable
35        var2: VarId,
36        /// Relation: set of allowed (var1, var2) pairs
37        relation: HashSet<(i32, i32)>,
38    },
39
40    /// All-different constraint
41    AllDifferent {
42        /// Variables that must have different values
43        variables: Vec<VarId>,
44    },
45
46    /// Sum constraint: Σ x_i = target
47    Sum {
48        /// Variables to sum
49        variables: Vec<VarId>,
50        /// Target sum
51        target: i32,
52    },
53
54    /// Less-than constraint: x < y
55    LessThan {
56        /// First variable
57        var1: VarId,
58        /// Second variable
59        var2: VarId,
60    },
61
62    /// Greater-than constraint: x > y
63    GreaterThan {
64        /// First variable
65        var1: VarId,
66        /// Second variable
67        var2: VarId,
68    },
69}
70
71impl DiscreteConstraint {
72    /// Get all variables involved in this constraint
73    pub fn variables(&self) -> Vec<VarId> {
74        match self {
75            Self::Binary { var1, var2, .. } => vec![*var1, *var2],
76            Self::AllDifferent { variables } => variables.clone(),
77            Self::Sum { variables, .. } => variables.clone(),
78            Self::LessThan { var1, var2 } => vec![*var1, *var2],
79            Self::GreaterThan { var1, var2 } => vec![*var1, *var2],
80        }
81    }
82
83    /// Check if constraint is binary (involves exactly 2 variables)
84    pub fn is_binary(&self) -> bool {
85        matches!(
86            self,
87            Self::Binary { .. } | Self::LessThan { .. } | Self::GreaterThan { .. }
88        )
89    }
90
91    /// Check if assignment satisfies constraint
92    pub fn is_satisfied(&self, assignment: &HashMap<VarId, i32>) -> bool {
93        match self {
94            Self::Binary {
95                var1,
96                var2,
97                relation,
98            } => {
99                if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
100                    relation.contains(&(v1, v2))
101                } else {
102                    true // Not fully assigned yet
103                }
104            }
105            Self::AllDifferent { variables } => {
106                let values: Vec<i32> = variables
107                    .iter()
108                    .filter_map(|v| assignment.get(v))
109                    .copied()
110                    .collect();
111
112                let unique: HashSet<_> = values.iter().collect();
113                values.len() == unique.len()
114            }
115            Self::Sum { variables, target } => {
116                if variables.iter().all(|v| assignment.contains_key(v)) {
117                    let sum: i32 = variables.iter().filter_map(|v| assignment.get(v)).sum();
118                    sum == *target
119                } else {
120                    true // Not fully assigned yet
121                }
122            }
123            Self::LessThan { var1, var2 } => {
124                if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
125                    v1 < v2
126                } else {
127                    true
128                }
129            }
130            Self::GreaterThan { var1, var2 } => {
131                if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
132                    v1 > v2
133                } else {
134                    true
135                }
136            }
137        }
138    }
139}
140
141/// Constraint Satisfaction Problem
142pub struct CSP {
143    /// Number of variables
144    num_variables: usize,
145    /// Domain for each variable
146    domains: Vec<Domain>,
147    /// Constraints
148    constraints: Vec<DiscreteConstraint>,
149}
150
151impl CSP {
152    /// Create a new CSP
153    pub fn new(num_variables: usize, initial_domains: Vec<Domain>) -> LogicResult<Self> {
154        if initial_domains.len() != num_variables {
155            return Err(LogicError::InvalidInput(
156                "Domain count must match variable count".to_string(),
157            ));
158        }
159
160        Ok(Self {
161            num_variables,
162            domains: initial_domains,
163            constraints: Vec::new(),
164        })
165    }
166
167    /// Add a constraint
168    pub fn add_constraint(&mut self, constraint: DiscreteConstraint) {
169        self.constraints.push(constraint);
170    }
171
172    /// Get domain of a variable
173    pub fn domain(&self, var: VarId) -> Option<&Domain> {
174        self.domains.get(var)
175    }
176
177    /// Get all constraints involving a variable
178    pub fn constraints_for_variable(&self, var: VarId) -> Vec<&DiscreteConstraint> {
179        self.constraints
180            .iter()
181            .filter(|c| c.variables().contains(&var))
182            .collect()
183    }
184
185    /// Check if assignment is complete
186    pub fn is_complete(&self, assignment: &HashMap<VarId, i32>) -> bool {
187        assignment.len() == self.num_variables
188    }
189
190    /// Check if assignment satisfies all constraints
191    pub fn is_consistent(&self, assignment: &HashMap<VarId, i32>) -> bool {
192        self.constraints.iter().all(|c| c.is_satisfied(assignment))
193    }
194}
195
196/// AC-3 Algorithm for Arc Consistency
197pub struct AC3 {
198    /// CSP instance
199    csp: CSP,
200}
201
202impl AC3 {
203    /// Create a new AC-3 solver
204    pub fn new(csp: CSP) -> Self {
205        Self { csp }
206    }
207
208    /// Enforce arc consistency
209    ///
210    /// Returns true if CSP is consistent, false if inconsistency detected
211    pub fn enforce_arc_consistency(&mut self) -> bool {
212        // Build queue of arcs to check
213        let mut queue: VecDeque<(VarId, VarId)> = VecDeque::new();
214
215        // Add all binary constraint arcs
216        for constraint in &self.csp.constraints {
217            if let DiscreteConstraint::Binary { var1, var2, .. }
218            | DiscreteConstraint::LessThan { var1, var2 }
219            | DiscreteConstraint::GreaterThan { var1, var2 } = constraint
220            {
221                queue.push_back((*var1, *var2));
222                queue.push_back((*var2, *var1));
223            }
224        }
225
226        // Process arcs
227        while let Some((xi, xj)) = queue.pop_front() {
228            if self.revise(xi, xj) {
229                if self.csp.domains[xi].is_empty() {
230                    return false; // Inconsistency detected
231                }
232
233                // Add all arcs (xk, xi) where xk is a neighbor of xi
234                for constraint in &self.csp.constraints.clone() {
235                    let vars = constraint.variables();
236                    if vars.contains(&xi) && vars.len() == 2 {
237                        for &xk in &vars {
238                            if xk != xi && xk != xj {
239                                queue.push_back((xk, xi));
240                            }
241                        }
242                    }
243                }
244            }
245        }
246
247        true
248    }
249
250    /// Revise domain of xi with respect to xj
251    ///
252    /// Returns true if domain of xi was revised
253    fn revise(&mut self, xi: VarId, xj: VarId) -> bool {
254        let mut revised = false;
255
256        // Find constraint between xi and xj
257        let constraint = self
258            .csp
259            .constraints
260            .iter()
261            .find(|c| {
262                let vars = c.variables();
263                vars.len() == 2 && vars.contains(&xi) && vars.contains(&xj)
264            })
265            .cloned();
266
267        if let Some(constraint) = constraint {
268            let domain_j = self.csp.domains[xj].clone();
269            let mut new_domain_i = HashSet::new();
270
271            for &vi in &self.csp.domains[xi] {
272                // Check if there exists vj in domain_j that satisfies constraint
273                let mut has_support = false;
274
275                for &vj in &domain_j {
276                    let mut assignment = HashMap::new();
277                    assignment.insert(xi, vi);
278                    assignment.insert(xj, vj);
279
280                    if constraint.is_satisfied(&assignment) {
281                        has_support = true;
282                        break;
283                    }
284                }
285
286                if has_support {
287                    new_domain_i.insert(vi);
288                } else {
289                    revised = true;
290                }
291            }
292
293            self.csp.domains[xi] = new_domain_i;
294        }
295
296        revised
297    }
298
299    /// Get the CSP after arc consistency
300    pub fn csp(self) -> CSP {
301        self.csp
302    }
303
304    /// Get reference to CSP
305    pub fn csp_ref(&self) -> &CSP {
306        &self.csp
307    }
308}
309
310/// Backtracking search with constraint propagation
311pub struct BacktrackingSearch {
312    /// CSP instance
313    csp: CSP,
314    /// Use forward checking
315    use_forward_checking: bool,
316    /// Solutions found
317    solutions: Vec<HashMap<VarId, i32>>,
318    /// Maximum solutions to find
319    max_solutions: usize,
320}
321
322impl BacktrackingSearch {
323    /// Create a new backtracking search
324    pub fn new(csp: CSP) -> Self {
325        Self {
326            csp,
327            use_forward_checking: true,
328            solutions: Vec::new(),
329            max_solutions: 1,
330        }
331    }
332
333    /// Enable or disable forward checking
334    pub fn with_forward_checking(mut self, enabled: bool) -> Self {
335        self.use_forward_checking = enabled;
336        self
337    }
338
339    /// Set maximum solutions to find
340    pub fn with_max_solutions(mut self, max: usize) -> Self {
341        self.max_solutions = max;
342        self
343    }
344
345    /// Solve the CSP
346    pub fn solve(&mut self) -> Vec<HashMap<VarId, i32>> {
347        let assignment = HashMap::new();
348        self.backtrack(assignment);
349        self.solutions.clone()
350    }
351
352    /// Backtracking recursive search
353    fn backtrack(&mut self, assignment: HashMap<VarId, i32>) -> bool {
354        if self.solutions.len() >= self.max_solutions {
355            return true;
356        }
357
358        if self.csp.is_complete(&assignment) {
359            if self.csp.is_consistent(&assignment) {
360                self.solutions.push(assignment.clone());
361                return self.solutions.len() >= self.max_solutions;
362            }
363            return false;
364        }
365
366        // Select unassigned variable (MRV heuristic)
367        let var = self.select_unassigned_variable(&assignment);
368
369        // Order domain values
370        let values = self.order_domain_values(var, &assignment);
371
372        for value in values {
373            let mut new_assignment = assignment.clone();
374            new_assignment.insert(var, value);
375
376            if self.is_consistent_with_assignment(&new_assignment) {
377                if self.use_forward_checking {
378                    // Forward checking would go here
379                    // Simplified: just continue with backtracking
380                }
381
382                if self.backtrack(new_assignment) {
383                    return true;
384                }
385            }
386        }
387
388        false
389    }
390
391    /// Select unassigned variable (Minimum Remaining Values heuristic)
392    fn select_unassigned_variable(&self, assignment: &HashMap<VarId, i32>) -> VarId {
393        let mut best_var = 0;
394        let mut min_domain_size = usize::MAX;
395
396        for var in 0..self.csp.num_variables {
397            if !assignment.contains_key(&var) {
398                let domain_size = self.csp.domains[var].len();
399                if domain_size < min_domain_size {
400                    min_domain_size = domain_size;
401                    best_var = var;
402                }
403            }
404        }
405
406        best_var
407    }
408
409    /// Order domain values (Least Constraining Value heuristic)
410    fn order_domain_values(&self, var: VarId, _assignment: &HashMap<VarId, i32>) -> Vec<i32> {
411        let mut values: Vec<i32> = self.csp.domains[var].iter().copied().collect();
412        values.sort(); // Simplified: just sort numerically
413        values
414    }
415
416    /// Check if assignment is consistent with all constraints
417    fn is_consistent_with_assignment(&self, assignment: &HashMap<VarId, i32>) -> bool {
418        self.csp
419            .constraints
420            .iter()
421            .all(|c| c.is_satisfied(assignment))
422    }
423}
424
425/// Forward checking: maintain arc consistency during search
426pub struct ForwardChecker {
427    /// Original domains
428    domains: Vec<Domain>,
429}
430
431impl ForwardChecker {
432    /// Create a new forward checker
433    pub fn new(domains: Vec<Domain>) -> Self {
434        Self { domains }
435    }
436
437    /// Prune domains based on assignment
438    pub fn prune(&mut self, var: VarId, value: i32, constraints: &[DiscreteConstraint]) -> bool {
439        // For each constraint involving var
440        for constraint in constraints {
441            if !constraint.variables().contains(&var) {
442                continue;
443            }
444
445            // Remove inconsistent values from neighboring variables
446            let vars = constraint.variables();
447            for &neighbor in &vars {
448                if neighbor == var {
449                    continue;
450                }
451
452                let mut new_domain = HashSet::new();
453                for &v in &self.domains[neighbor] {
454                    let mut assignment = HashMap::new();
455                    assignment.insert(var, value);
456                    assignment.insert(neighbor, v);
457
458                    if constraint.is_satisfied(&assignment) {
459                        new_domain.insert(v);
460                    }
461                }
462
463                if new_domain.is_empty() {
464                    return false; // Domain wipeout
465                }
466
467                self.domains[neighbor] = new_domain;
468            }
469        }
470
471        true
472    }
473
474    /// Restore domains
475    pub fn restore(&mut self, saved_domains: &[Domain]) {
476        self.domains = saved_domains.to_vec();
477    }
478
479    /// Get current domains
480    pub fn domains(&self) -> &[Domain] {
481        &self.domains
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a domain from a list of values.
    fn domain_of(values: &[i32]) -> Domain {
        values.iter().copied().collect()
    }

    #[test]
    fn test_binary_constraint() {
        let mut relation = HashSet::new();
        relation.insert((1, 2));
        relation.insert((2, 3));

        let constraint = DiscreteConstraint::Binary {
            var1: 0,
            var2: 1,
            relation,
        };

        let mut assignment = HashMap::new();
        assignment.insert(0, 1);
        assignment.insert(1, 2);

        assert!(constraint.is_satisfied(&assignment));

        assignment.insert(1, 3);
        assert!(!constraint.is_satisfied(&assignment));
    }

    #[test]
    fn test_all_different_constraint() {
        let constraint = DiscreteConstraint::AllDifferent {
            variables: vec![0, 1, 2],
        };

        let mut assignment = HashMap::new();
        assignment.insert(0, 1);
        assignment.insert(1, 2);
        assignment.insert(2, 3);

        assert!(constraint.is_satisfied(&assignment));

        assignment.insert(2, 1); // Same as var 0
        assert!(!constraint.is_satisfied(&assignment));
    }

    #[test]
    fn test_less_than_constraint() {
        let constraint = DiscreteConstraint::LessThan { var1: 0, var2: 1 };

        let mut assignment = HashMap::new();
        assignment.insert(0, 5);
        assignment.insert(1, 10);

        assert!(constraint.is_satisfied(&assignment));

        assignment.insert(1, 3);
        assert!(!constraint.is_satisfied(&assignment));
    }

    #[test]
    fn test_csp_creation() {
        let csp = CSP::new(2, vec![domain_of(&[1, 2, 3]), domain_of(&[2, 3, 4])]).unwrap();

        assert_eq!(csp.num_variables, 2);
        assert_eq!(csp.domains.len(), 2);

        // A domain count that differs from the variable count is rejected.
        assert!(CSP::new(3, vec![domain_of(&[1]), domain_of(&[2])]).is_err());
    }

    #[test]
    fn test_ac3_simple() {
        let mut csp = CSP::new(2, vec![domain_of(&[1, 2, 3]), domain_of(&[1, 2, 3])]).unwrap();

        // Add constraint: var0 < var1
        csp.add_constraint(DiscreteConstraint::LessThan { var1: 0, var2: 1 });

        let mut ac3 = AC3::new(csp);
        let consistent = ac3.enforce_arc_consistency();

        assert!(consistent);

        // var0 = 3 has no larger partner and var1 = 1 has no smaller one,
        // so exactly those two values must be pruned.
        let csp_result = ac3.csp();
        assert_eq!(csp_result.domains[0], domain_of(&[1, 2]));
        assert_eq!(csp_result.domains[1], domain_of(&[2, 3]));
    }

    #[test]
    fn test_backtracking_search() {
        let mut csp = CSP::new(2, vec![domain_of(&[1, 2]), domain_of(&[1, 2])]).unwrap();

        // All different constraint
        csp.add_constraint(DiscreteConstraint::AllDifferent {
            variables: vec![0, 1],
        });

        let mut search = BacktrackingSearch::new(csp).with_max_solutions(2);
        let solutions = search.solve();

        // Exactly 2 solutions exist: (1,2) and (2,1)
        assert_eq!(solutions.len(), 2);
        assert_ne!(solutions[0], solutions[1]);

        for solution in &solutions {
            assert_eq!(solution.len(), 2);
            assert_ne!(solution.get(&0), solution.get(&1));
        }
    }

    #[test]
    fn test_forward_checker() {
        let mut checker = ForwardChecker::new(vec![domain_of(&[1, 2, 3]), domain_of(&[1, 2, 3])]);

        let constraints = vec![DiscreteConstraint::AllDifferent {
            variables: vec![0, 1],
        }];

        // Assign var0 = 1
        let success = checker.prune(0, 1, &constraints);
        assert!(success);

        // Domain of var1 should not contain 1
        assert_eq!(checker.domains()[1], domain_of(&[2, 3]));
    }

    #[test]
    fn test_sum_constraint() {
        let constraint = DiscreteConstraint::Sum {
            variables: vec![0, 1, 2],
            target: 6,
        };

        let mut assignment = HashMap::new();
        assignment.insert(0, 1);
        assignment.insert(1, 2);

        // Not fully assigned yet: treated as satisfied
        assert!(constraint.is_satisfied(&assignment));

        assignment.insert(2, 3);
        assert!(constraint.is_satisfied(&assignment)); // 1 + 2 + 3 = 6

        assignment.insert(2, 4);
        assert!(!constraint.is_satisfied(&assignment)); // 1 + 2 + 4 = 7
    }
}