1#![expect(
4 clippy::upper_case_acronyms,
5 reason = "Python naming for exposed types"
6)]
7
8use std::sync::PoisonError;
9
10use pyo3::{create_exception, exceptions::PyException, prelude::*};
11
12struct ErrWrapper(PyErr);
14
15type Result<R = (), E = ErrWrapper> = std::result::Result<R, E>;
18
19impl From<::pindakaas::Unsatisfiable> for ErrWrapper {
21 fn from(_: ::pindakaas::Unsatisfiable) -> Self {
22 Self(Unsatisfiable::new_err(
23 "The given constraint was found to be Unsatisfiable during encoding",
24 ))
25 }
26}
27
28impl<T> From<PoisonError<T>> for ErrWrapper {
30 fn from(e: PoisonError<T>) -> Self {
31 Self(PyException::new_err(e.to_string()))
32 }
33}
34
35impl From<PyErr> for ErrWrapper {
37 fn from(err: PyErr) -> Self {
38 ErrWrapper(err)
39 }
40}
41
42impl From<ErrWrapper> for PyErr {
44 fn from(err: ErrWrapper) -> Self {
45 err.0
46 }
47}
48
// Python exception class `InvalidEncoder` (module `pindakaas`, base class `Exception`).
// The string literal is the class docstring shown to Python users.
create_exception! {
    pindakaas,
    InvalidEncoder,
    PyException,
    "Raised when the chosen encoder does not support the constraint (e.g. when the `PairwiseEncoder` encoder for AMO constraints is used to encode a PB constraint)."
}
// Python exception class `Unsatisfiable` (module `pindakaas`, base class `Exception`).
// The Rust error `::pindakaas::Unsatisfiable` is converted into this exception by
// `From<::pindakaas::Unsatisfiable> for ErrWrapper`.
create_exception! {
    pindakaas,
    Unsatisfiable,
    PyException,
    "Raised when the given constraint is found to be Unsatisfiable during encoding."
}
61
62#[pymodule]
63mod pindakaas {
64 use std::{
65 fmt::{self, Display},
66 num::NonZeroI32,
67 sync::Mutex,
68 };
69
70 use itertools::Itertools;
71 use pindakaas::{
72 bool_linear::{
73 AdderEncoder, BoolLinAggregator, BoolLinExp as BaseBoolLinExp, BoolLinVariant,
74 BoolLinear as BaseBoolLinCon, Comparator, LinearEncoder, NormalizedBoolLinear,
75 SwcEncoder, TotalizerEncoder,
76 },
77 cardinality::{Cardinality, SortingNetworkEncoder},
78 cardinality_one::{BitwiseEncoder, CardinalityOne, LadderEncoder, PairwiseEncoder},
79 propositional_logic::{Formula as BaseFormula, TseitinEncoder},
80 BoolVal, ClauseDatabase, ClauseDatabaseTools, Cnf, Encoder as EncoderTrait, Lit as BaseLit,
81 VarRange as BaseVarRange, Wcnf,
82 };
83 use pyo3::{exceptions::PyValueError, prelude::*, types::PyIterator};
84
85 #[pymodule_export]
86 use crate::InvalidEncoder;
87 use crate::Result;
88 #[pymodule_export]
89 use crate::Unsatisfiable;
90
    /// Right-hand operand accepted by the arithmetic operators of `BoolLinExp` and `Lit`.
    ///
    /// `FromPyObject` tries the variants in declaration order. `Bool` is listed before `Int`
    /// so that a Python `bool` (a subclass of `int`) is matched as `Bool`; do not reorder.
    #[derive(FromPyObject)]
    enum BoolLinArg {
        Bool(bool),
        BoolLin(BoolLinExp),
        Int(i64),
        Lit(Lit),
    }

    /// Python-visible linear constraint.
    ///
    /// It is built by the comparison operators of `BoolLinExp` from an expression, a
    /// `Comparator` and an `i64` constant.
    #[pyclass]
    #[derive(Clone, Debug)]
    struct BoolLinCon(BaseBoolLinCon);

    /// Python-visible linear expression over Boolean literals.
    #[pyclass]
    #[derive(Clone, Debug)]
    struct BoolLinExp(BaseBoolLinExp);

    /// Python-visible CNF formula backed by pindakaas's `Cnf`.
    #[pyclass]
    #[derive(Clone, Debug, Default)]
    struct CNFInner(Cnf);

    /// Constraint accepted by the `add_encoding` methods and `_wrap_encode_constraint`.
    #[derive(FromPyObject)]
    enum ConstraintArg {
        BoolLin(BoolLinCon),
        Formula(Formula),
    }

    /// Encoding method selectable from Python.
    ///
    /// Acceptance depends on the constraint kind (see `encode_constraint` and the
    /// `EncoderTrait` impls for `LinEncoderWrapper`). An unsupported combination raises
    /// `InvalidEncoder`.
    #[expect(non_camel_case_types, reason = "match python naming convention")]
    #[pyclass(eq, eq_int)]
    #[derive(Clone, Copy, Debug, PartialEq)]
    enum Encoder {
        /// Accepted for `Cardinality`, `CardinalityOne` and general linear constraints.
        /// It is the default for `Cardinality` and general linear constraints.
        ADDER,
        /// Accepted only for `CardinalityOne` constraints, for which it is the default.
        BITWISE,
        /// NOTE(review): no dispatch arm in this file handles this variant, so selecting it
        /// currently always raises `InvalidEncoder`.
        DECISION_DIAGRAM,
        /// Accepted only for `CardinalityOne` constraints.
        LADDER,
        /// Accepted only for `CardinalityOne` constraints.
        PAIRWISE,
        /// Accepted for `Cardinality`, `CardinalityOne` and general linear constraints.
        SORTED_WEIGHT_COUNTER,
        /// Accepted for `Cardinality` and `CardinalityOne` constraints.
        SORTING_NETWORK,
        /// Accepted for `Cardinality`, `CardinalityOne` and general linear constraints.
        TOTALIZER,
        /// The only (and default) encoder for `Formula`; rejected for linear constraints.
        TSEITIN,
    }

    /// Python-visible propositional formula over `BoolVal` atoms.
    #[pyclass]
    #[derive(Clone, Debug)]
    struct Formula(BaseFormula<BoolVal>);

    /// Right-hand operand accepted by the logical operators of `Formula` and `Lit`.
    ///
    /// `Const` is tried first, so Python `True`/`False` become constant atoms.
    #[derive(FromPyObject)]
    enum FormulaArg {
        Const(bool),
        Formula(Formula),
        Lit(Lit),
    }

    /// Adapter that dispatches each linear-constraint variant to the encoder chosen in Python.
    struct LinEncoderWrapper {
        /// Requested encoder; `None` selects the per-constraint-kind default.
        method: Option<Encoder>,
        /// `InvalidEncoder` error recorded when `method` does not support a variant.
        ///
        /// `EncoderTrait::encode` can only return `pindakaas::Unsatisfiable`, and it takes
        /// `&self`, so the error is stored here (interior mutability) and picked up afterwards
        /// by `encode_constraint`.
        error_message: Mutex<Option<PyErr>>,
    }

    /// Python-visible Boolean literal (a possibly negated variable).
    #[pyclass]
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Lit(BaseLit);

    /// Python-visible range of variables.
    ///
    /// It is its own iterator: iterating it advances this object (see `__iter__`/`__next__`).
    #[pyclass]
    #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
    struct VarRange(BaseVarRange);

    /// Python-visible weighted CNF formula backed by pindakaas's `Wcnf`.
    #[pyclass]
    #[derive(Clone, Debug, Default)]
    struct WCNFInner(Wcnf);
202
203 #[pyfunction]
204 fn _wrap_encode_constraint(
205 obj: &Bound<'_, PyAny>,
206 con: ConstraintArg,
207 enc: Option<Encoder>,
208 conditions: Vec<Lit>,
209 ) -> Result {
210 struct PyDbWrapper<'a>(&'a Bound<'a, PyAny>);
211 impl ClauseDatabase for PyDbWrapper<'_> {
212 fn add_clause_from_slice(
213 &mut self,
214 clause: &[BaseLit],
215 ) -> Result<(), pindakaas::Unsatisfiable> {
216 let clause_vec = clause.iter().map(|&l| Lit(l)).collect_vec();
217 let res = self.0.call_method1("add_clause", (clause_vec,));
218 match res {
219 Err(e) if e.is_instance_of::<Unsatisfiable>(self.0.py()) => {
220 Err(pindakaas::Unsatisfiable)
221 }
222 Err(e) => {
223 panic!("unexpected error in add_clause implementation: {}", e)
224 }
225 Ok(_) if clause.is_empty() => Err(pindakaas::Unsatisfiable),
229 Ok(_) => Ok(()),
230 }
231 }
232
233 fn new_var_range(&mut self, len: usize) -> BaseVarRange {
234 let tup = self
235 .0
236 .call_method1("new_var_range", (len,))
237 .expect("unexpected error in new_var_range implementation");
238 let (start, end): (Lit, Lit) = tup
239 .extract()
240 .expect("new_var_range did not return a tuple of two literals");
241 BaseVarRange::new(start.0.var(), end.0.var())
242 }
243 }
244
245 encode_constraint(&mut PyDbWrapper(obj), con, enc, conditions)
246 }
247
248 fn encode_constraint<Db>(
254 db: &mut Db,
255 con: ConstraintArg,
256 enc: Option<Encoder>,
257 conditions: Vec<Lit>,
258 ) -> Result
259 where
260 Db: ClauseDatabase + ?Sized,
261 {
262 let invalid_enc = |con_ty, enc| {
263 Err(InvalidEncoder::new_err(format!(
264 "Unable to encode object of type `{con_ty}' using {enc:?}"
265 ))
266 .into())
267 };
268 let conditions: Vec<_> = conditions.into_iter().map(|l| l.0).collect();
269
270 match con {
271 ConstraintArg::BoolLin(lin) => {
272 let encoder = LinEncoderWrapper::new(enc);
273 let encoder = LinearEncoder::new(encoder, BoolLinAggregator::default());
274 encoder.encode_implied(db, &conditions, &lin.0)?;
275 let err = encoder
276 .variant_encoder()
277 .error_message
278 .lock()
279 .unwrap()
280 .take();
281 if let Some(err) = err {
282 return Err(err.into());
283 }
284 }
285 ConstraintArg::Formula(f) => match enc.unwrap_or(Encoder::TSEITIN) {
286 Encoder::TSEITIN => TseitinEncoder.encode_implied(db, &conditions, &f.0)?,
287 _ => {
288 return invalid_enc("Formula", enc.unwrap());
289 }
290 },
291 };
292 Ok(())
293 }
294
295 impl BoolLinArg {
296 fn as_bool_lin_exp(&self) -> BoolLinExp {
297 match self {
298 &BoolLinArg::Bool(b) => BoolLinExp(b.into()),
299 BoolLinArg::BoolLin(exp) => exp.clone(),
300 &BoolLinArg::Int(i) => BoolLinExp(i.into()),
301 &BoolLinArg::Lit(l) => BoolLinExp(l.0.into()),
302 }
303 }
304 }
305
306 #[pymethods]
307 impl BoolLinCon {
308 fn __str__(&self) -> String {
309 self.0.to_string()
310 }
311 }
312
313 #[pymethods]
314 impl BoolLinExp {
315 fn __add__(&self, other: BoolLinArg) -> Self {
316 let mut res = self.clone();
317 res.__iadd__(other);
318 res
319 }
320
321 fn __eq__(&self, other: i64) -> BoolLinCon {
322 BoolLinCon(BaseBoolLinCon::new(
323 self.0.clone(),
324 Comparator::Equal,
325 other,
326 ))
327 }
328
329 fn __ge__(&self, other: i64) -> BoolLinCon {
330 BoolLinCon(BaseBoolLinCon::new(
331 self.0.clone(),
332 Comparator::GreaterEq,
333 other,
334 ))
335 }
336
337 fn __gt__(&self, other: i64) -> BoolLinCon {
338 self.__ge__(other + 1)
339 }
340
341 fn __iadd__(&mut self, other: BoolLinArg) {
342 self.0 += other.as_bool_lin_exp().0;
343 }
344
345 fn __imul__(&mut self, other: i64) {
346 self.0 *= other;
347 }
348
349 fn __isub__(&mut self, other: BoolLinArg) {
350 self.0 -= other.as_bool_lin_exp().0;
351 }
352
353 fn __le__(&self, other: i64) -> BoolLinCon {
354 BoolLinCon(BaseBoolLinCon::new(
355 self.0.clone(),
356 Comparator::LessEq,
357 other,
358 ))
359 }
360
361 fn __lt__(&self, other: i64) -> BoolLinCon {
362 self.__le__(other - 1)
363 }
364
365 fn __mul__(&self, other: i64) -> Self {
366 let mut res = self.clone();
367 res.__imul__(other);
368 res
369 }
370
371 fn __neg__(&self) -> Self {
372 Self(-self.0.clone())
373 }
374
375 fn __radd__(&self, other: BoolLinArg) -> Self {
376 self.__add__(other)
377 }
378
379 fn __rmul__(&self, other: i64) -> Self {
380 self.__mul__(other)
381 }
382
383 fn __str__(&self) -> String {
384 self.0.to_string()
385 }
386
387 fn __sub__(&self, other: BoolLinArg) -> Self {
388 let mut res = self.clone();
389 res.__isub__(other);
390 res
391 }
392 }
393
394 #[pymethods]
395 impl CNFInner {
396 fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
397 let clause: Vec<Lit> = clause
398 .into_iter()
399 .map(|any| any.and_then(|lit| lit.extract::<Lit>()))
400 .try_collect()?;
401 self.0.add_clause(clause.into_iter().map(|lit| lit.0))?;
402 Ok(())
403 }
404
405 fn add_encoding(
406 &mut self,
407 con: ConstraintArg,
408 enc: Option<Encoder>,
409 conditions: Vec<Lit>,
410 ) -> Result {
411 encode_constraint(&mut self.0, con, enc, conditions)
412 }
413
414 fn clauses(&self) -> Vec<Vec<Lit>> {
415 self.0
418 .iter()
419 .map(|c| c.iter().map(|&lit| Lit(lit)).collect())
420 .collect()
421 }
422
423 #[new]
424 fn new() -> Self {
425 Self(Default::default())
426 }
427
428 fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
429 let range = self.0.new_var_range(num_vars);
430 Ok(VarRange(range))
431 }
432
433 fn to_dimacs(&self) -> String {
434 self.0.to_string()
435 }
436
437 fn variables(&self) -> VarRange {
438 VarRange(self.0.variables())
439 }
440 }
441
442 #[pymethods]
443 impl Formula {
444 fn __and__(&self, other: FormulaArg) -> Self {
445 Self(self.0.clone() & other.as_formula())
446 }
447
448 fn __eq__(&self, other: FormulaArg) -> Self {
449 use BaseFormula::*;
450
451 Formula(Equiv(vec![self.0.clone(), other.as_formula()]))
452 }
453
454 fn __ge__(&self, other: FormulaArg) -> Self {
455 use BaseFormula::*;
456
457 Self(Implies(other.as_formula().into(), self.0.clone().into()))
458 }
459
460 fn __gt__(&self, other: FormulaArg) -> Self {
461 Self(self.0.clone() & !other.as_formula())
462 }
463
464 fn __invert__(&self) -> Self {
465 Self(!self.0.clone())
466 }
467
468 fn __le__(&self, other: FormulaArg) -> Self {
469 use BaseFormula::*;
470
471 Self(Implies(self.0.clone().into(), other.as_formula().into()))
472 }
473
474 fn __lt__(&self, other: FormulaArg) -> Self {
475 Self(!self.0.clone() & other.as_formula())
476 }
477
478 fn __ne__(&self, other: FormulaArg) -> Self {
479 self.__xor__(other)
480 }
481
482 fn __or__(&self, other: FormulaArg) -> Self {
483 Formula(self.0.clone() | other.as_formula())
484 }
485
486 fn __rand__(&self, other: FormulaArg) -> Self {
487 self.__and__(other)
488 }
489
490 fn __ror__(&self, other: FormulaArg) -> Self {
491 self.__or__(other)
492 }
493
494 fn __rxor__(&self, other: FormulaArg) -> Self {
495 self.__xor__(other)
496 }
497
498 fn __str__(&self) -> String {
499 self.0.to_string()
500 }
501
502 fn __xor__(&self, other: FormulaArg) -> Self {
503 Formula(self.0.clone() ^ other.as_formula())
504 }
505 }
506
507 impl FormulaArg {
508 fn as_formula(&self) -> BaseFormula<BoolVal> {
511 use BaseFormula::*;
512
513 match self {
514 FormulaArg::Const(b) => Atom(BoolVal::Const(*b)),
515 FormulaArg::Formula(formula) => formula.0.clone(),
516 FormulaArg::Lit(lit) => lit.as_formula(),
517 }
518 }
519 }
520
521 impl LinEncoderWrapper {
522 fn new(method: Option<Encoder>) -> Self {
523 Self {
524 method,
525 error_message: Mutex::new(None),
526 }
527 }
528
529 fn set_err(&self, con_ty: &str, enc: Encoder) {
530 let _ = self
531 .error_message
532 .lock()
533 .unwrap()
534 .replace(InvalidEncoder::new_err(format!(
535 "Unable to encode object of type `{con_ty}' using {enc:?}"
536 )));
537 }
538 }
539
540 impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, BoolLinVariant> for LinEncoderWrapper {
541 fn encode(
542 &self,
543 db: &mut Db,
544 con: &BoolLinVariant,
545 ) -> Result<(), pindakaas::Unsatisfiable> {
546 match con {
547 BoolLinVariant::Linear(lin) => self.encode(db, lin),
548 BoolLinVariant::Cardinality(card) => self.encode(db, card),
549 BoolLinVariant::CardinalityOne(card1) => self.encode(db, card1),
550 BoolLinVariant::Trivial => Ok(()),
551 }
552 }
553 }
554
555 impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, Cardinality> for LinEncoderWrapper {
556 fn encode(&self, db: &mut Db, con: &Cardinality) -> Result<(), pindakaas::Unsatisfiable> {
557 match self.method.unwrap_or(Encoder::ADDER) {
558 Encoder::SORTING_NETWORK => SortingNetworkEncoder::default().encode(db, con),
559 Encoder::ADDER => AdderEncoder::default().encode(db, con),
560 Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, con),
561 Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, con),
562 enc => {
563 self.set_err("Cardinality", enc);
564 Ok(())
565 }
566 }
567 }
568 }
569
570 impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, CardinalityOne> for LinEncoderWrapper {
571 fn encode(
572 &self,
573 db: &mut Db,
574 con: &CardinalityOne,
575 ) -> Result<(), pindakaas::Unsatisfiable> {
576 match self.method.unwrap_or(Encoder::BITWISE) {
577 Encoder::BITWISE => BitwiseEncoder::default().encode(db, con),
578 Encoder::ADDER => AdderEncoder::default().encode(db, con),
579 Encoder::LADDER => LadderEncoder::default().encode(db, con),
580 Encoder::PAIRWISE => PairwiseEncoder::default().encode(db, con),
581 Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, con),
582 Encoder::SORTING_NETWORK => SortingNetworkEncoder::default().encode(db, con),
583 Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, con),
584 enc => {
585 self.set_err("CardinalityOne", enc);
586 Ok(())
587 }
588 }
589 }
590 }
591
592 impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, NormalizedBoolLinear> for LinEncoderWrapper {
593 fn encode(
594 &self,
595 db: &mut Db,
596 con: &NormalizedBoolLinear,
597 ) -> Result<(), pindakaas::Unsatisfiable> {
598 match self.method.unwrap_or(Encoder::ADDER) {
599 Encoder::ADDER => AdderEncoder::default().encode(db, con),
600 Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, con),
601 Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, con),
602 enc => {
603 self.set_err("BoolLinear", enc);
604 Ok(())
605 }
606 }
607 }
608 }
609
610 impl Lit {
611 fn as_bool_lin_exp(&self) -> BoolLinExp {
612 BoolLinExp(self.0.into())
613 }
614
615 fn as_formula(&self) -> BaseFormula<BoolVal> {
616 BaseFormula::Atom(self.0.into())
617 }
618 }
619
620 #[pymethods]
621 impl Lit {
622 fn __add__(&self, other: BoolLinArg) -> BoolLinExp {
623 self.as_bool_lin_exp().__add__(other)
624 }
625
626 fn __and__(&self, other: FormulaArg) -> Formula {
627 Formula(self.as_formula()).__and__(other)
628 }
629
630 fn __eq__(&self, other: FormulaArg) -> Formula {
631 Formula(self.as_formula()).__eq__(other)
632 }
633
634 fn __ge__(&self, other: FormulaArg) -> Formula {
635 Formula(self.as_formula()).__ge__(other)
636 }
637
638 fn __gt__(&self, other: FormulaArg) -> Formula {
639 Formula(self.as_formula()).__gt__(other)
640 }
641
642 fn __int__(&self) -> i32 {
643 self.0.into()
644 }
645
646 fn __invert__(&self) -> Self {
647 Self(!self.0)
648 }
649
650 fn __le__(&self, other: FormulaArg) -> Formula {
651 Formula(self.as_formula()).__le__(other)
652 }
653
654 fn __lt__(&self, other: FormulaArg) -> Formula {
655 Formula(self.as_formula()).__lt__(other)
656 }
657
658 fn __mul__(&self, other: i64) -> BoolLinExp {
659 self.as_bool_lin_exp().__mul__(other)
660 }
661
662 fn __ne__(&self, other: FormulaArg) -> Formula {
663 Formula(self.as_formula()).__ne__(other)
664 }
665
666 fn __or__(&self, other: FormulaArg) -> Formula {
667 Formula(self.as_formula()).__or__(other)
668 }
669
670 fn __radd__(&self, other: BoolLinArg) -> BoolLinExp {
671 self.__add__(other)
672 }
673
674 fn __rand__(&self, other: FormulaArg) -> Formula {
675 Formula(self.as_formula()).__and__(other)
676 }
677
678 fn __rmul__(&self, other: i64) -> BoolLinExp {
679 self.__mul__(other)
680 }
681
682 fn __ror__(&self, other: FormulaArg) -> Formula {
683 self.__or__(other)
684 }
685
686 fn __rxor__(&self, other: FormulaArg) -> Formula {
687 self.__xor__(other)
688 }
689
690 fn __str__(&self) -> String {
691 self.0.to_string()
692 }
693
694 fn __sub__(&self, other: BoolLinArg) -> BoolLinExp {
695 self.as_bool_lin_exp().__sub__(other)
696 }
697
698 fn __xor__(&self, other: FormulaArg) -> Formula {
699 Formula(self.as_formula()).__xor__(other)
700 }
701
702 #[staticmethod]
703 fn from_raw(value: NonZeroI32) -> Self {
704 Self(BaseLit::from_raw(value))
705 }
706
707 fn is_negated(&self) -> bool {
709 self.0.is_negated()
710 }
711
712 fn var(&self) -> Self {
714 Self(self.0.var().into())
715 }
716 }
717
718 impl Display for Lit {
719 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
720 self.0.fmt(f)
721 }
722 }
723
724 #[pymethods]
725 impl VarRange {
726 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
727 slf
728 }
729
730 fn __len__(&self) -> usize {
731 self.0.len()
732 }
733
734 fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<Lit> {
735 slf.0.next().map(|lit| Lit(lit.into()))
736 }
737
738 fn end(&self) -> Lit {
740 Lit(self.0.end().into())
741 }
742
743 #[new]
744 fn new(start: Lit, end: Lit) -> PyResult<Self> {
747 if start.is_negated() || end.is_negated() {
748 return Err(PyValueError::new_err(
749 "`start' and `end' must be positive literals (directly representing variables)",
750 ));
751 }
752 Ok(Self(BaseVarRange::new(start.0.var(), end.0.var())))
753 }
754
755 fn start(&self) -> Lit {
757 Lit(self.0.start().into())
758 }
759 }
760
761 #[pymethods]
762 impl WCNFInner {
763 fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
764 let clause: Vec<Lit> = clause
765 .into_iter()
766 .map(|any| any.and_then(|lit| lit.extract::<Lit>()))
767 .try_collect()?;
768 self.0.add_clause(clause.into_iter().map(|lit| lit.0))?;
769 Ok(())
770 }
771
772 fn add_encoding(
773 &mut self,
774 con: ConstraintArg,
775 enc: Option<Encoder>,
776 conditions: Vec<Lit>,
777 ) -> Result {
778 encode_constraint(&mut self.0, con, enc, conditions)
779 }
780
781 fn add_weighted_clause(&mut self, clause: Bound<'_, PyIterator>, weight: i64) -> Result {
782 let clause: Vec<Lit> = clause
783 .into_iter()
784 .map(|any| any.and_then(|lit| lit.extract::<Lit>()))
785 .try_collect()?;
786 self.0
787 .add_weighted_clause(clause.into_iter().map(|lit| lit.0), weight)?;
788 Ok(())
789 }
790
791 fn clauses(&self) -> Vec<Vec<Lit>> {
792 self.0
795 .iter()
796 .filter(|(_, w)| w.is_none())
797 .map(|(c, _)| c.iter().map(|&lit| Lit(lit)).collect_vec())
798 .collect()
799 }
800
801 #[new]
802 fn new() -> Self {
803 Self(Default::default())
804 }
805
806 fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
807 let range = self.0.new_var_range(num_vars);
808 Ok(VarRange(range))
809 }
810
811 fn to_dimacs(&self) -> String {
812 self.0.to_string()
813 }
814
815 fn variables(&self) -> VarRange {
816 VarRange(self.0.variables())
817 }
818
819 fn weighted_clauses(&self) -> Vec<(Option<i64>, Vec<Lit>)> {
820 self.0
823 .iter()
824 .map(|(c, &w)| (w, (c.iter().map(|&lit| Lit(lit)).collect())))
825 .collect()
826 }
827 }
828
829 #[pymodule]
830 mod solver {
        // Generates the Python-visible methods of a solver result class `$name`.
        //
        // `$name` must be a newtype whose field `.0` is a `SolverResultImpl`, and whose owner's
        // field `.0` is the matching `SolverImpl`. `$owner` and `$solver` are accepted but not
        // referenced in the expansion.
        macro_rules! py_solver_result {
            ($name:ident, $owner:ident, $solver:ty) => {
                #[pymethods]
                impl $name {
                    // Context-manager entry: hands back the result object itself.
                    fn __enter__(slf: Py<Self>) -> Py<Self> {
                        slf
                    }

                    // Context-manager exit: deactivates the result and returns the solver to its
                    // owner (see `SolverResultImpl::exit`). The closure selects the owner's
                    // `SolverImpl` field. The returned `false` means exceptions are not
                    // suppressed.
                    fn __exit__(
                        &mut self,
                        py: Python<'_>,
                        _exc_type: Option<&Bound<'_, PyAny>>,
                        _exc: Option<&Bound<'_, PyAny>>,
                        _traceback: Option<&Bound<'_, PyAny>>,
                    ) -> PyResult<bool> {
                        self.0.exit(py, |owner| &mut owner.0)
                    }

                    // Failed-assumption query; raises `RuntimeError` once the result is inactive.
                    fn failed(&self, lit: Lit) -> PyResult<Option<bool>> {
                        self.0.failed(lit)
                    }

                    // Python property `status`; raises `RuntimeError` once the result is inactive.
                    #[getter]
                    fn status(&self) -> PyResult<Status> {
                        self.0.status()
                    }

                    // Model value of `lit` (`None` unless satisfied); raises `RuntimeError` once the
                    // result is inactive.
                    fn value(&self, lit: Lit) -> PyResult<Option<bool>> {
                        self.0.value(lit)
                    }
                }
            };
        }
864
865 use std::{
866 mem::transmute,
867 time::{Duration, SystemTime},
868 };
869
870 use itertools::Itertools;
871 use pindakaas::{
872 solver::{
873 cadical::Cadical, kissat::Kissat, Assumptions, FailedAssumptions, SolveResult,
874 Solver, TermSignal, TerminateCallback,
875 },
876 ClauseDatabase, ClauseDatabaseTools, Lit as BaseLit, Valuation,
877 };
878 use pyo3::{
879 exceptions::{PyNotImplementedError, PyRuntimeError},
880 prelude::*,
881 pyclass::boolean_struct::False,
882 types::{PyAny, PyIterator},
883 PyClass,
884 };
885
886 use crate::{
887 pindakaas::{encode_constraint, ConstraintArg, Encoder, Lit, VarRange},
888 Result,
889 };
890
        /// Raised when the solver is used while a result object holds it.
        const CHECKED_OUT_ERROR: &str = "solver is currently checked out by an active result";
        /// Raised when a result is queried after `__exit__` deactivated it.
        const INACTIVE_RESULT_ERROR: &str = "solver result is no longer active";
        /// Raised by `exit` if the owner already holds a solver again.
        const RESTORED_ERROR: &str = "solver was already restored to its owner";

        /// Python-visible CaDiCaL solver.
        ///
        /// `unsendable`: PyO3 only allows access from the thread that created the object.
        #[pyclass(unsendable)]
        #[derive(Debug)]
        struct CaDiCaLInner(SolverImpl<Cadical>);

        /// Result of a CaDiCaL solve. It holds the solver until `__exit__` hands it back.
        #[pyclass(unsendable)]
        struct CaDiCaLResult(SolverResultImpl<CaDiCaLInner, Cadical>);

        /// Python-visible Kissat solver (no assumption support in `solve_assuming`).
        #[pyclass(unsendable)]
        #[derive(Debug)]
        struct KissatInner(SolverImpl<Kissat>);

        /// Result of a Kissat solve. It holds the solver until `__exit__` hands it back.
        #[pyclass(unsendable)]
        struct KissatResult(SolverResultImpl<KissatInner, Kissat>);

        /// Slot owning a solver.
        ///
        /// `solver` is `None` while a result object has checked the solver out.
        #[derive(Debug)]
        struct SolverImpl<S> {
            solver: Option<S>,
        }

        /// State of a finished solve, together with the checked-out solver.
        ///
        /// IMPORTANT: `result` is declared before `solver`, so it is dropped first. The state
        /// in `result` was created from a borrow of the solver whose lifetime was erased to
        /// `'static` (see `from_solver`/`from_assumptions_solver`). Do not reorder these fields.
        struct SolverResultImpl<Owner, S> {
            /// Python object whose `SolverImpl` the solver is returned to on `exit`.
            owner: Py<Owner>,
            /// `None` once the result has been deactivated by `exit`.
            result: Option<SolverResultState>,
            /// `None` once the solver has been handed back to `owner`.
            solver: Option<S>,
            /// Whether the solve used assumptions; when false, `failed` always reports `None`.
            supports_assumptions: bool,
        }

        /// Outcome of a solve, with lifetime-erased access to solver-owned data.
        enum SolverResultState {
            /// Model of a satisfiable solve.
            Satisfied(Box<dyn Valuation + 'static>),
            /// Failed-assumption query of an unsatisfiable solve (`None` when unavailable).
            Unsatisfiable(Box<dyn Fn(BaseLit) -> Option<bool> + 'static>),
            /// The solver stopped without an answer (e.g. via the terminate callback).
            Unknown,
        }

        /// Python-visible status of a solve result.
        #[pyclass(eq, eq_int)]
        #[derive(Clone, Copy, Debug, PartialEq)]
        enum Status {
            SATISFIED,
            UNSATISFIABLE,
            UNKNOWN,
        }
941
942 #[pymodule_init]
944 fn init(module: &Bound<'_, PyModule>) -> PyResult<()> {
945 module
946 .py()
947 .import("sys")?
948 .getattr("modules")?
949 .set_item("pindakaas.pindakaas.solver", module)
950 }
951
952 #[pymethods]
953 impl CaDiCaLInner {
954 fn _set_option(&mut self, name: &str, value: i32) -> PyResult<()> {
955 self.0.solver_mut()?.set_option(name, value);
956 Ok(())
957 }
958
959 fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
960 self.0.add_clause(clause)
961 }
962
963 fn add_encoding(
964 &mut self,
965 con: ConstraintArg,
966 enc: Option<Encoder>,
967 conditions: Vec<Lit>,
968 ) -> Result {
969 self.0.add_encoding(con, enc, conditions)
970 }
971
972 #[new]
973 fn new() -> Self {
974 Self(SolverImpl::default())
975 }
976
977 fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
978 self.0.new_var_range(num_vars)
979 }
980
981 fn set_time_limit(&mut self, limit: Option<Duration>) -> Result {
982 self.0.set_time_limit(limit)
983 }
984
985 fn solve_assuming(
986 slf: Py<Self>,
987 py: Python<'_>,
988 assumptions: Vec<Lit>,
989 ) -> Result<Py<CaDiCaLResult>> {
990 let mut inner = slf.bind(py).borrow_mut();
991 let solver = inner.0.take()?;
992 Ok(Py::new(
993 py,
994 CaDiCaLResult(SolverResultImpl::from_assumptions_solver(
995 slf.clone_ref(py),
996 solver,
997 &assumptions,
998 )),
999 )?)
1000 }
1001 }
1002
1003 #[pymethods]
1004 impl KissatInner {
1005 fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
1006 self.0.add_clause(clause)
1007 }
1008
1009 fn add_encoding(
1010 &mut self,
1011 con: ConstraintArg,
1012 enc: Option<Encoder>,
1013 conditions: Vec<Lit>,
1014 ) -> Result {
1015 self.0.add_encoding(con, enc, conditions)
1016 }
1017
1018 #[new]
1019 fn new() -> Self {
1020 Self(SolverImpl::default())
1021 }
1022
1023 fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
1024 self.0.new_var_range(num_vars)
1025 }
1026
1027 fn set_time_limit(&mut self, limit: Option<Duration>) -> Result {
1028 self.0.set_time_limit(limit)
1029 }
1030
1031 fn solve_assuming(
1032 slf: Py<Self>,
1033 py: Python<'_>,
1034 assumptions: Vec<Lit>,
1035 ) -> Result<Py<KissatResult>> {
1036 if !assumptions.is_empty() {
1037 return Err(PyNotImplementedError::new_err(
1038 "solver does not support assumptions",
1039 )
1040 .into());
1041 }
1042 let mut inner = slf.bind(py).borrow_mut();
1043 let solver = inner.0.take()?;
1044 Ok(Py::new(
1045 py,
1046 KissatResult(SolverResultImpl::from_solver(slf.clone_ref(py), solver)),
1047 )?)
1048 }
1049 }
1050
1051 impl<S> SolverImpl<S> {
1052 fn solver_mut(&mut self) -> PyResult<&mut S> {
1053 self.solver
1054 .as_mut()
1055 .ok_or_else(|| PyRuntimeError::new_err(CHECKED_OUT_ERROR))
1056 }
1057
1058 fn take(&mut self) -> PyResult<S> {
1059 self.solver
1060 .take()
1061 .ok_or_else(|| PyRuntimeError::new_err(CHECKED_OUT_ERROR))
1062 }
1063 }
1064
1065 impl<S: ClauseDatabase> SolverImpl<S> {
1066 fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
1067 let clause: Vec<Lit> = clause
1068 .into_iter()
1069 .map(|any| any.and_then(|lit| lit.extract::<Lit>()))
1070 .try_collect()?;
1071 self.solver_mut()?
1072 .add_clause(clause.into_iter().map(|lit| lit.0))?;
1073 Ok(())
1074 }
1075
1076 fn add_encoding(
1077 &mut self,
1078 con: ConstraintArg,
1079 enc: Option<Encoder>,
1080 conditions: Vec<Lit>,
1081 ) -> Result {
1082 encode_constraint(self.solver_mut()?, con, enc, conditions)
1083 }
1084
1085 fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
1086 Ok(VarRange(self.solver_mut()?.new_var_range(num_vars)))
1087 }
1088 }
1089
1090 impl<S: TerminateCallback> SolverImpl<S> {
1091 fn set_time_limit(&mut self, limit: Option<Duration>) -> Result {
1092 self.solver_mut()?.set_terminate_callback(limit.map(|dur| {
1093 let deadline = SystemTime::now() + dur;
1094 move || {
1095 if SystemTime::now() > deadline {
1096 TermSignal::Terminate
1097 } else {
1098 TermSignal::Continue
1099 }
1100 }
1101 }));
1102 Ok(())
1103 }
1104 }
1105
1106 impl<S: Default> Default for SolverImpl<S> {
1107 fn default() -> Self {
1108 Self {
1109 solver: Some(S::default()),
1110 }
1111 }
1112 }
1113
        impl<Owner: PyClass<Frozen = False>, S> SolverResultImpl<Owner, S> {
            /// Deactivates this result and hands the solver back to its owner.
            ///
            /// `slot` selects the owner's `SolverImpl` field. If the solver was already handed
            /// back, a repeated call only deactivates and returns `Ok(false)`.
            ///
            /// Returns `Ok(false)`, so used as `__exit__` it does not suppress exceptions.
            /// Errors with `RuntimeError(RESTORED_ERROR)` if the owner already holds a solver; in
            /// that case the solver taken from this result is dropped.
            /// NOTE(review): `borrow_mut` panics if the owner is currently borrowed.
            fn exit(
                &mut self,
                py: Python<'_>,
                slot: fn(&mut Owner) -> &mut SolverImpl<S>,
            ) -> PyResult<bool> {
                // Drop the result state first. It may hold lifetime-erased borrows of the
                // solver, which must end before the solver is moved back to the owner.
                self.result = None;
                if let Some(solver) = self.solver.take() {
                    let mut owner = self.owner.bind(py).borrow_mut();
                    // Explicit `deref_mut` turns the `PyRefMut` into the `&mut Owner` that `slot` expects.
                    let inner = slot(std::ops::DerefMut::deref_mut(&mut owner));
                    if inner.solver.is_some() {
                        return Err(PyRuntimeError::new_err(RESTORED_ERROR));
                    }
                    inner.solver = Some(solver);
                }
                Ok(false)
            }
        }
1132
        impl<Owner, S: Solver> SolverResultImpl<Owner, S> {
            /// Runs a plain solve and captures the outcome together with the solver.
            ///
            /// The result reports `supports_assumptions = false`, and an unsatisfiable outcome
            /// answers every failed-assumption query with `None`.
            fn from_solver(owner: Py<Owner>, mut solver: S) -> Self {
                let result = match solver.solve() {
                    SolveResult::Satisfied(sol) => {
                        let sol: Box<dyn Valuation + '_> = Box::new(sol);
                        // SAFETY (NOTE(review), please confirm): erases the borrow of `solver`
                        // held by `sol` so it can be stored next to the solver. Soundness relies
                        // on two things:
                        // (1) `result` is always dropped before `solver`: the field order of
                        //     `SolverResultImpl`, plus `self.result = None` in `exit`.
                        // (2) `sol` does not point into the `S` value itself, because `solver`
                        //     is moved (into `Self::new`, then into a Python object) after this
                        //     point.
                        // (2) depends on the solver implementation and cannot be checked here.
                        let sol: Box<dyn Valuation + 'static> = unsafe { transmute(sol) };
                        SolverResultState::Satisfied(sol)
                    }
                    SolveResult::Unsatisfiable(_) => {
                        SolverResultState::Unsatisfiable(Box::new(|_| None))
                    }
                    SolveResult::Unknown => SolverResultState::Unknown,
                };
                Self::new(owner, result, solver, false)
            }
        }

        impl<Owner, S: Assumptions> SolverResultImpl<Owner, S> {
            /// Solves under `assumptions` and captures the outcome together with the solver.
            ///
            /// The result reports `supports_assumptions = true`. An unsatisfiable outcome answers
            /// failed-assumption queries through the solver's `FailedAssumptions::fail`.
            fn from_assumptions_solver(
                owner: Py<Owner>,
                mut solver: S,
                assumptions: &[Lit],
            ) -> Self {
                let result = match solver.solve_assuming(assumptions.iter().map(|lit| lit.0)) {
                    SolveResult::Satisfied(sol) => {
                        let sol: Box<dyn Valuation + '_> = Box::new(sol);
                        // SAFETY (NOTE(review)): same lifetime-erasure argument and caveats as
                        // in `from_solver`.
                        let sol: Box<dyn Valuation + 'static> = unsafe { transmute(sol) };
                        SolverResultState::Satisfied(sol)
                    }
                    SolveResult::Unsatisfiable(fail) => {
                        let fail: Box<dyn FailedAssumptions + '_> = Box::new(fail);
                        // SAFETY (NOTE(review)): same lifetime-erasure argument and caveats as
                        // in `from_solver`; the closure below keeps `fail` alive inside `result`.
                        let fail: Box<dyn FailedAssumptions + 'static> = unsafe { transmute(fail) };
                        let fail = move |lit: BaseLit| Some(fail.fail(lit));
                        SolverResultState::Unsatisfiable(Box::new(fail))
                    }
                    SolveResult::Unknown => SolverResultState::Unknown,
                };
                Self::new(owner, result, solver, true)
            }
        }
1181
1182 impl<Owner, S> SolverResultImpl<Owner, S> {
1183 fn failed(&self, lit: Lit) -> PyResult<Option<bool>> {
1184 let Some(result) = self.result.as_ref() else {
1185 return Err(PyRuntimeError::new_err(INACTIVE_RESULT_ERROR));
1186 };
1187 if !self.supports_assumptions {
1188 return Ok(None);
1189 }
1190 Ok(match result {
1191 SolverResultState::Unsatisfiable(fail) => fail(lit.0),
1192 _ => None,
1193 })
1194 }
1195
1196 fn new(
1197 owner: Py<Owner>,
1198 result: SolverResultState,
1199 solver: S,
1200 supports_assumptions: bool,
1201 ) -> Self {
1202 Self {
1203 owner,
1204 result: Some(result),
1205 solver: Some(solver),
1206 supports_assumptions,
1207 }
1208 }
1209
1210 fn status(&self) -> PyResult<Status> {
1211 let Some(result) = self.result.as_ref() else {
1212 return Err(PyRuntimeError::new_err(INACTIVE_RESULT_ERROR));
1213 };
1214 Ok(match result {
1215 SolverResultState::Satisfied(_) => Status::SATISFIED,
1216 SolverResultState::Unsatisfiable(_) => Status::UNSATISFIABLE,
1217 SolverResultState::Unknown => Status::UNKNOWN,
1218 })
1219 }
1220
1221 fn value(&self, lit: Lit) -> PyResult<Option<bool>> {
1222 let Some(result) = self.result.as_ref() else {
1223 return Err(PyRuntimeError::new_err(INACTIVE_RESULT_ERROR));
1224 };
1225 Ok(match result {
1226 SolverResultState::Satisfied(sol) => Some(sol.value(lit.0)),
1227 _ => None,
1228 })
1229 }
1230 }
1231
        // Python methods (`__enter__`, `__exit__`, `failed`, `status`, `value`) for the CaDiCaL result.
        py_solver_result!(CaDiCaLResult, CaDiCaLInner, Cadical);

        // Python methods (`__enter__`, `__exit__`, `failed`, `status`, `value`) for the Kissat result.
        py_solver_result!(KissatResult, KissatInner, Kissat);
1235 }
1236}