1use crate::error::{LogicError, LogicResult};
19use std::collections::{HashMap, HashSet, VecDeque};
20
/// The set of candidate values a single variable may still take.
pub type Domain = std::collections::HashSet<i32>;

/// Index of a variable; ids run from `0` up to (excluding) the number of
/// variables in the problem.
pub type VarId = usize;
26
27#[derive(Debug, Clone)]
29pub enum DiscreteConstraint {
30 Binary {
32 var1: VarId,
34 var2: VarId,
36 relation: HashSet<(i32, i32)>,
38 },
39
40 AllDifferent {
42 variables: Vec<VarId>,
44 },
45
46 Sum {
48 variables: Vec<VarId>,
50 target: i32,
52 },
53
54 LessThan {
56 var1: VarId,
58 var2: VarId,
60 },
61
62 GreaterThan {
64 var1: VarId,
66 var2: VarId,
68 },
69}
70
71impl DiscreteConstraint {
72 pub fn variables(&self) -> Vec<VarId> {
74 match self {
75 Self::Binary { var1, var2, .. } => vec![*var1, *var2],
76 Self::AllDifferent { variables } => variables.clone(),
77 Self::Sum { variables, .. } => variables.clone(),
78 Self::LessThan { var1, var2 } => vec![*var1, *var2],
79 Self::GreaterThan { var1, var2 } => vec![*var1, *var2],
80 }
81 }
82
83 pub fn is_binary(&self) -> bool {
85 matches!(
86 self,
87 Self::Binary { .. } | Self::LessThan { .. } | Self::GreaterThan { .. }
88 )
89 }
90
91 pub fn is_satisfied(&self, assignment: &HashMap<VarId, i32>) -> bool {
93 match self {
94 Self::Binary {
95 var1,
96 var2,
97 relation,
98 } => {
99 if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
100 relation.contains(&(v1, v2))
101 } else {
102 true }
104 }
105 Self::AllDifferent { variables } => {
106 let values: Vec<i32> = variables
107 .iter()
108 .filter_map(|v| assignment.get(v))
109 .copied()
110 .collect();
111
112 let unique: HashSet<_> = values.iter().collect();
113 values.len() == unique.len()
114 }
115 Self::Sum { variables, target } => {
116 if variables.iter().all(|v| assignment.contains_key(v)) {
117 let sum: i32 = variables.iter().filter_map(|v| assignment.get(v)).sum();
118 sum == *target
119 } else {
120 true }
122 }
123 Self::LessThan { var1, var2 } => {
124 if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
125 v1 < v2
126 } else {
127 true
128 }
129 }
130 Self::GreaterThan { var1, var2 } => {
131 if let (Some(&v1), Some(&v2)) = (assignment.get(var1), assignment.get(var2)) {
132 v1 > v2
133 } else {
134 true
135 }
136 }
137 }
138 }
139}
140
141pub struct CSP {
143 num_variables: usize,
145 domains: Vec<Domain>,
147 constraints: Vec<DiscreteConstraint>,
149}
150
151impl CSP {
152 pub fn new(num_variables: usize, initial_domains: Vec<Domain>) -> LogicResult<Self> {
154 if initial_domains.len() != num_variables {
155 return Err(LogicError::InvalidInput(
156 "Domain count must match variable count".to_string(),
157 ));
158 }
159
160 Ok(Self {
161 num_variables,
162 domains: initial_domains,
163 constraints: Vec::new(),
164 })
165 }
166
167 pub fn add_constraint(&mut self, constraint: DiscreteConstraint) {
169 self.constraints.push(constraint);
170 }
171
172 pub fn domain(&self, var: VarId) -> Option<&Domain> {
174 self.domains.get(var)
175 }
176
177 pub fn constraints_for_variable(&self, var: VarId) -> Vec<&DiscreteConstraint> {
179 self.constraints
180 .iter()
181 .filter(|c| c.variables().contains(&var))
182 .collect()
183 }
184
185 pub fn is_complete(&self, assignment: &HashMap<VarId, i32>) -> bool {
187 assignment.len() == self.num_variables
188 }
189
190 pub fn is_consistent(&self, assignment: &HashMap<VarId, i32>) -> bool {
192 self.constraints.iter().all(|c| c.is_satisfied(assignment))
193 }
194}
195
196pub struct AC3 {
198 csp: CSP,
200}
201
202impl AC3 {
203 pub fn new(csp: CSP) -> Self {
205 Self { csp }
206 }
207
208 pub fn enforce_arc_consistency(&mut self) -> bool {
212 let mut queue: VecDeque<(VarId, VarId)> = VecDeque::new();
214
215 for constraint in &self.csp.constraints {
217 if let DiscreteConstraint::Binary { var1, var2, .. }
218 | DiscreteConstraint::LessThan { var1, var2 }
219 | DiscreteConstraint::GreaterThan { var1, var2 } = constraint
220 {
221 queue.push_back((*var1, *var2));
222 queue.push_back((*var2, *var1));
223 }
224 }
225
226 while let Some((xi, xj)) = queue.pop_front() {
228 if self.revise(xi, xj) {
229 if self.csp.domains[xi].is_empty() {
230 return false; }
232
233 for constraint in &self.csp.constraints.clone() {
235 let vars = constraint.variables();
236 if vars.contains(&xi) && vars.len() == 2 {
237 for &xk in &vars {
238 if xk != xi && xk != xj {
239 queue.push_back((xk, xi));
240 }
241 }
242 }
243 }
244 }
245 }
246
247 true
248 }
249
250 fn revise(&mut self, xi: VarId, xj: VarId) -> bool {
254 let mut revised = false;
255
256 let constraint = self
258 .csp
259 .constraints
260 .iter()
261 .find(|c| {
262 let vars = c.variables();
263 vars.len() == 2 && vars.contains(&xi) && vars.contains(&xj)
264 })
265 .cloned();
266
267 if let Some(constraint) = constraint {
268 let domain_j = self.csp.domains[xj].clone();
269 let mut new_domain_i = HashSet::new();
270
271 for &vi in &self.csp.domains[xi] {
272 let mut has_support = false;
274
275 for &vj in &domain_j {
276 let mut assignment = HashMap::new();
277 assignment.insert(xi, vi);
278 assignment.insert(xj, vj);
279
280 if constraint.is_satisfied(&assignment) {
281 has_support = true;
282 break;
283 }
284 }
285
286 if has_support {
287 new_domain_i.insert(vi);
288 } else {
289 revised = true;
290 }
291 }
292
293 self.csp.domains[xi] = new_domain_i;
294 }
295
296 revised
297 }
298
299 pub fn csp(self) -> CSP {
301 self.csp
302 }
303
304 pub fn csp_ref(&self) -> &CSP {
306 &self.csp
307 }
308}
309
310pub struct BacktrackingSearch {
312 csp: CSP,
314 use_forward_checking: bool,
316 solutions: Vec<HashMap<VarId, i32>>,
318 max_solutions: usize,
320}
321
322impl BacktrackingSearch {
323 pub fn new(csp: CSP) -> Self {
325 Self {
326 csp,
327 use_forward_checking: true,
328 solutions: Vec::new(),
329 max_solutions: 1,
330 }
331 }
332
333 pub fn with_forward_checking(mut self, enabled: bool) -> Self {
335 self.use_forward_checking = enabled;
336 self
337 }
338
339 pub fn with_max_solutions(mut self, max: usize) -> Self {
341 self.max_solutions = max;
342 self
343 }
344
345 pub fn solve(&mut self) -> Vec<HashMap<VarId, i32>> {
347 let assignment = HashMap::new();
348 self.backtrack(assignment);
349 self.solutions.clone()
350 }
351
352 fn backtrack(&mut self, assignment: HashMap<VarId, i32>) -> bool {
354 if self.solutions.len() >= self.max_solutions {
355 return true;
356 }
357
358 if self.csp.is_complete(&assignment) {
359 if self.csp.is_consistent(&assignment) {
360 self.solutions.push(assignment.clone());
361 return self.solutions.len() >= self.max_solutions;
362 }
363 return false;
364 }
365
366 let var = self.select_unassigned_variable(&assignment);
368
369 let values = self.order_domain_values(var, &assignment);
371
372 for value in values {
373 let mut new_assignment = assignment.clone();
374 new_assignment.insert(var, value);
375
376 if self.is_consistent_with_assignment(&new_assignment) {
377 if self.use_forward_checking {
378 }
381
382 if self.backtrack(new_assignment) {
383 return true;
384 }
385 }
386 }
387
388 false
389 }
390
391 fn select_unassigned_variable(&self, assignment: &HashMap<VarId, i32>) -> VarId {
393 let mut best_var = 0;
394 let mut min_domain_size = usize::MAX;
395
396 for var in 0..self.csp.num_variables {
397 if !assignment.contains_key(&var) {
398 let domain_size = self.csp.domains[var].len();
399 if domain_size < min_domain_size {
400 min_domain_size = domain_size;
401 best_var = var;
402 }
403 }
404 }
405
406 best_var
407 }
408
409 fn order_domain_values(&self, var: VarId, _assignment: &HashMap<VarId, i32>) -> Vec<i32> {
411 let mut values: Vec<i32> = self.csp.domains[var].iter().copied().collect();
412 values.sort(); values
414 }
415
416 fn is_consistent_with_assignment(&self, assignment: &HashMap<VarId, i32>) -> bool {
418 self.csp
419 .constraints
420 .iter()
421 .all(|c| c.is_satisfied(assignment))
422 }
423}
424
425pub struct ForwardChecker {
427 domains: Vec<Domain>,
429}
430
431impl ForwardChecker {
432 pub fn new(domains: Vec<Domain>) -> Self {
434 Self { domains }
435 }
436
437 pub fn prune(&mut self, var: VarId, value: i32, constraints: &[DiscreteConstraint]) -> bool {
439 for constraint in constraints {
441 if !constraint.variables().contains(&var) {
442 continue;
443 }
444
445 let vars = constraint.variables();
447 for &neighbor in &vars {
448 if neighbor == var {
449 continue;
450 }
451
452 let mut new_domain = HashSet::new();
453 for &v in &self.domains[neighbor] {
454 let mut assignment = HashMap::new();
455 assignment.insert(var, value);
456 assignment.insert(neighbor, v);
457
458 if constraint.is_satisfied(&assignment) {
459 new_domain.insert(v);
460 }
461 }
462
463 if new_domain.is_empty() {
464 return false; }
466
467 self.domains[neighbor] = new_domain;
468 }
469 }
470
471 true
472 }
473
474 pub fn restore(&mut self, saved_domains: &[Domain]) {
476 self.domains = saved_domains.to_vec();
477 }
478
479 pub fn domains(&self) -> &[Domain] {
481 &self.domains
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a domain from a list of values.
    fn domain(values: &[i32]) -> Domain {
        values.iter().copied().collect()
    }

    /// Builds an assignment from `(variable, value)` pairs.
    fn assign(pairs: &[(VarId, i32)]) -> HashMap<VarId, i32> {
        pairs.iter().copied().collect()
    }

    /// Two variables over `{1, 2}` that must take different values.
    fn two_variable_all_different() -> CSP {
        let mut csp = CSP::new(2, vec![domain(&[1, 2]), domain(&[1, 2])]).unwrap();
        csp.add_constraint(DiscreteConstraint::AllDifferent {
            variables: vec![0, 1],
        });
        csp
    }

    #[test]
    fn test_binary_constraint() {
        let relation: HashSet<(i32, i32)> = [(1, 2), (2, 3)].iter().copied().collect();
        let constraint = DiscreteConstraint::Binary {
            var1: 0,
            var2: 1,
            relation,
        };

        assert!(constraint.is_satisfied(&assign(&[(0, 1), (1, 2)])));
        assert!(!constraint.is_satisfied(&assign(&[(0, 1), (1, 3)])));
        // A constraint with an unassigned variable is not (yet) violated.
        assert!(constraint.is_satisfied(&assign(&[(0, 7)])));
    }

    #[test]
    fn test_all_different_constraint() {
        let constraint = DiscreteConstraint::AllDifferent {
            variables: vec![0, 1, 2],
        };

        assert!(constraint.is_satisfied(&assign(&[(0, 1), (1, 2), (2, 3)])));
        assert!(!constraint.is_satisfied(&assign(&[(0, 1), (1, 2), (2, 1)])));
    }

    #[test]
    fn test_less_than_constraint() {
        let constraint = DiscreteConstraint::LessThan { var1: 0, var2: 1 };

        assert!(constraint.is_satisfied(&assign(&[(0, 5), (1, 10)])));
        assert!(!constraint.is_satisfied(&assign(&[(0, 5), (1, 3)])));
    }

    #[test]
    fn test_greater_than_constraint() {
        let constraint = DiscreteConstraint::GreaterThan { var1: 0, var2: 1 };

        assert!(constraint.is_satisfied(&assign(&[(0, 10), (1, 5)])));
        assert!(!constraint.is_satisfied(&assign(&[(0, 5), (1, 5)])));
    }

    #[test]
    fn test_sum_constraint() {
        let constraint = DiscreteConstraint::Sum {
            variables: vec![0, 1, 2],
            target: 6,
        };

        assert!(constraint.is_satisfied(&assign(&[(0, 1), (1, 2), (2, 3)])));
        assert!(!constraint.is_satisfied(&assign(&[(0, 1), (1, 2), (2, 4)])));
        // Not decidable until every variable is assigned.
        assert!(constraint.is_satisfied(&assign(&[(0, 100)])));
    }

    #[test]
    fn test_csp_creation() {
        let csp = CSP::new(2, vec![domain(&[1, 2, 3]), domain(&[2, 3, 4])]).unwrap();

        assert_eq!(csp.num_variables, 2);
        assert_eq!(csp.domains.len(), 2);

        assert!(CSP::new(2, vec![domain(&[1])]).is_err());
    }

    #[test]
    fn test_csp_queries() {
        let csp = two_variable_all_different();

        assert_eq!(csp.domain(0), Some(&domain(&[1, 2])));
        assert!(csp.domain(2).is_none());
        assert_eq!(csp.constraints_for_variable(1).len(), 1);
        assert!(csp.constraints_for_variable(2).is_empty());
        assert!(!csp.is_complete(&assign(&[(0, 1)])));
        assert!(csp.is_complete(&assign(&[(0, 1), (1, 2)])));
        assert!(csp.is_consistent(&assign(&[(0, 1), (1, 2)])));
        assert!(!csp.is_consistent(&assign(&[(0, 1), (1, 1)])));
    }

    #[test]
    fn test_ac3_prunes_unsupported_values() {
        let mut csp = CSP::new(2, vec![domain(&[1, 2, 3, 4, 5]), domain(&[1, 2, 3])]).unwrap();
        csp.add_constraint(DiscreteConstraint::LessThan { var1: 0, var2: 1 });

        let mut ac3 = AC3::new(csp);
        assert!(ac3.enforce_arc_consistency());

        // x0 < x1 with x1 <= 3 forces x0 <= 2; x0 >= 1 forces x1 >= 2.
        let csp_result = ac3.csp();
        assert_eq!(csp_result.domains[0], domain(&[1, 2]));
        assert_eq!(csp_result.domains[1], domain(&[2, 3]));
    }

    #[test]
    fn test_ac3_detects_inconsistency() {
        let mut csp = CSP::new(2, vec![domain(&[5]), domain(&[1])]).unwrap();
        csp.add_constraint(DiscreteConstraint::LessThan { var1: 0, var2: 1 });

        assert!(!AC3::new(csp).enforce_arc_consistency());
    }

    #[test]
    fn test_backtracking_search() {
        let mut search = BacktrackingSearch::new(two_variable_all_different()).with_max_solutions(2);
        let solutions = search.solve();

        // Exactly two assignments satisfy the constraint.
        assert_eq!(solutions.len(), 2);
        assert!(solutions.contains(&assign(&[(0, 1), (1, 2)])));
        assert!(solutions.contains(&assign(&[(0, 2), (1, 1)])));
    }

    #[test]
    fn test_backtracking_without_forward_checking() {
        let mut search = BacktrackingSearch::new(two_variable_all_different())
            .with_forward_checking(false)
            .with_max_solutions(10);

        assert_eq!(search.solve().len(), 2);
    }

    #[test]
    fn test_backtracking_stops_at_max_solutions() {
        let mut search = BacktrackingSearch::new(two_variable_all_different());

        assert_eq!(search.solve(), vec![assign(&[(0, 1), (1, 2)])]);
    }

    #[test]
    fn test_backtracking_unsatisfiable() {
        let mut csp = CSP::new(2, vec![domain(&[1]), domain(&[1])]).unwrap();
        csp.add_constraint(DiscreteConstraint::AllDifferent {
            variables: vec![0, 1],
        });

        assert!(BacktrackingSearch::new(csp).solve().is_empty());
    }

    #[test]
    fn test_forward_checker() {
        let mut checker = ForwardChecker::new(vec![domain(&[1, 2, 3]), domain(&[1, 2, 3])]);
        let constraints = [DiscreteConstraint::AllDifferent {
            variables: vec![0, 1],
        }];

        assert!(checker.prune(0, 1, &constraints));
        assert_eq!(checker.domains()[1], domain(&[2, 3]));
        // The assigned variable's own domain is not touched.
        assert_eq!(checker.domains()[0], domain(&[1, 2, 3]));
    }

    #[test]
    fn test_forward_checker_wipeout_and_restore() {
        let saved = vec![domain(&[1]), domain(&[1])];
        let mut checker = ForwardChecker::new(saved.clone());
        let constraints = [DiscreteConstraint::AllDifferent {
            variables: vec![0, 1],
        }];

        assert!(!checker.prune(0, 1, &constraints));
        checker.restore(&saved);
        assert_eq!(checker.domains(), saved.as_slice());
    }
}