Skip to main content

pindakaas/
lib.rs

//! This crate implements the internal `pindakaas.pindakaas` Python module,
//! which provides bindings for the `pindakaas` Rust crate.
3#![expect(
4	clippy::upper_case_acronyms,
5	reason = "Python naming for exposed types"
6)]
7
8use std::sync::PoisonError;
9
10use pyo3::{create_exception, exceptions::PyException, prelude::*};
11
// Newtype around `PyErr`: the orphan rule forbids implementing
// `From<pindakaas::Unsatisfiable> for PyErr` in this crate, so foreign errors
// are converted into this local wrapper instead, which then converts into
// `PyErr` (see the `From` impls below).
struct ErrWrapper(PyErr);
14
// Use this `Result` instead of `PyResult` so that `?` can convert Rust errors
// (via the `From` impls for `ErrWrapper`) and return them as Python
// exceptions. Defaults to `Result<(), ErrWrapper>`.
type Result<R = (), E = ErrWrapper> = std::result::Result<R, E>;
18
19// Allow `pindakaas::Unsatisfiable` to become a wrapped Unsatisfiable exception
20impl From<::pindakaas::Unsatisfiable> for ErrWrapper {
21	fn from(_: ::pindakaas::Unsatisfiable) -> Self {
22		Self(Unsatisfiable::new_err(
23			"The given constraint was found to be Unsatisfiable during encoding",
24		))
25	}
26}
27
28// Allow `pindakaas::Unsatisfiable` to become a wrapped Unsatisfiable exception
29impl<T> From<PoisonError<T>> for ErrWrapper {
30	fn from(e: PoisonError<T>) -> Self {
31		Self(PyException::new_err(e.to_string()))
32	}
33}
34
35// Allow other `PyErr`s to become a wrapped exception
36impl From<PyErr> for ErrWrapper {
37	fn from(err: PyErr) -> Self {
38		ErrWrapper(err)
39	}
40}
41
42// Allow ErrWrapper to become PyErr
43impl From<ErrWrapper> for PyErr {
44	fn from(err: ErrWrapper) -> Self {
45		err.0
46	}
47}
48
49create_exception! {
50	pindakaas,
51	InvalidEncoder,
52	PyException,
53	"Raised when the chosen encoder does not support the constraint (e.g. when the `PairwiseEncoder` encoder for AMO constraints is used to encode a PB constraint)."
54}
55create_exception! {
56	pindakaas,
57	Unsatisfiable,
58	PyException,
59	"Raised when the given constraint is found to be Unsatisfiable during encoding."
60}
61
62#[pymodule]
63mod pindakaas {
64	use std::{
65		fmt::{self, Display},
66		num::NonZeroI32,
67		sync::Mutex,
68	};
69
70	use itertools::Itertools;
71	use pindakaas::{
72		bool_linear::{
73			AdderEncoder, BoolLinAggregator, BoolLinExp as BaseBoolLinExp, BoolLinVariant,
74			BoolLinear as BaseBoolLinCon, Comparator, LinearEncoder, NormalizedBoolLinear,
75			SwcEncoder, TotalizerEncoder,
76		},
77		cardinality::{Cardinality, SortingNetworkEncoder},
78		cardinality_one::{BitwiseEncoder, CardinalityOne, LadderEncoder, PairwiseEncoder},
79		propositional_logic::{Formula as BaseFormula, TseitinEncoder},
80		BoolVal, ClauseDatabase, ClauseDatabaseTools, Cnf, Encoder as EncoderTrait, Lit as BaseLit,
81		VarRange as BaseVarRange, Wcnf,
82	};
83	use pyo3::{exceptions::PyValueError, prelude::*, types::PyIterator};
84
85	#[pymodule_export]
86	use crate::InvalidEncoder;
87	use crate::Result;
88	#[pymodule_export]
89	use crate::Unsatisfiable;
90
	#[derive(FromPyObject)]
	/// Argument capture for types that can become :class:`BoolLinExp`.
	///
	/// `#[derive(FromPyObject)]` tries the variants in declaration order, so the
	/// order below is significant. `Bool` is listed before `Int`, presumably so
	/// that Python `bool` values (a subclass of `int`) are captured as Booleans
	/// rather than as integers. Keep this order when adding variants.
	enum BoolLinArg {
		Bool(bool),
		BoolLin(BoolLinExp),
		Int(i64),
		Lit(Lit),
	}
99
	// Produced by the comparison operators of `BoolLinExp` and encoded through
	// `encode_constraint` (as `ConstraintArg::BoolLin`).
	#[pyclass]
	#[derive(Clone, Debug)]
	/// A Boolean linear constraint, also known as a pseudo-Boolean constraint.
	struct BoolLinCon(BaseBoolLinCon);
104
	// Thin wrapper over the Rust `BoolLinExp`; its Python arithmetic and
	// comparison operators live in the `#[pymethods] impl BoolLinExp` block.
	#[pyclass]
	#[derive(Clone, Debug)]
	/// A Boolean linear expression, also known as a pseudo-Boolean expression.
	///
	/// Using operators `<`, `<=`, `==`, `>=`, and `>` with a `int` right hand
	/// side, the expression can be turned into a :class:`BoolLinCon`.
	struct BoolLinExp(BaseBoolLinExp);
112
	// Wraps the Rust `Cnf` clause database; its Python methods live in the
	// `#[pymethods] impl CNFInner` block.
	#[pyclass]
	#[derive(Clone, Debug, Default)]
	/// The internal representation of a CNF formula.
	struct CNFInner(Cnf);
117
	#[derive(FromPyObject)]
	/// Argument capture for types that represent constraint that can be encoded
	/// into a CNF formula.
	///
	/// The derived `FromPyObject` tries the variants in declaration order.
	enum ConstraintArg {
		/// A Boolean linear constraint
		BoolLin(BoolLinCon),
		/// A propositional formula to be enforced.
		Formula(Formula),
	}
127
128	#[expect(non_camel_case_types, reason = "match python naming convention")]
129	#[pyclass(eq, eq_int)]
130	#[derive(Clone, Copy, Debug, PartialEq)]
131	/// Method used to encode a constraint.
132	///
133	/// Warning: Not all encoders can be used to encode each type of constraint.
134	/// If an invalid encoder is selected, then an :class:`InvalidEncoder`
135	/// exception will be raised.
136	enum Encoder {
137		// TODO These doc-strings do not show up, upstream issue: https://github.com/PyO3/pyo3/issues/5197
138		/// Use :class:`pindakaas::bool_linear::AdderEncoder`, which is able to
139		/// encode all Boolean linear constraints.
140		ADDER,
141		/// Use :class:`pindakaas::cardinality_one::BitwiseEncoder`, which is
142		/// able to encode all Boolean cardinality one constraints.
143		BITWISE,
144		/// Use :class:`pindakaas::bool_linear::BddEncoder`, which is able to
145		/// encode all Boolean linear constraints.
146		DECISION_DIAGRAM,
147		/// Use :class:`pindakaas::cardinality_one::LadderEncoder`, which is
148		/// able to encode all Boolean cardinality one constraints.
149		LADDER,
150		/// Use :class:`pindakaas::cardinality_one::PairwiseEncoder`, which is
151		/// able to encode all Boolean cardinality one constraints.
152		PAIRWISE,
153		/// Use :class:`pindakaas::bool_linear::SwcEncoder`, which is able to
154		/// encode all Boolean linear constraints.
155		SORTED_WEIGHT_COUNTER,
156		/// Use :class:`pindakaas::cardinality::SwcEncoder`, which is able to
157		/// encode all Boolean cardinality constraints.
158		SORTING_NETWORK,
159		/// Use :class:`pindakaas::bool_linear::TotalizerEncoder`, which is able
160		/// to encode all Boolean linear constraints.
161		TOTALIZER,
162		/// Use :class:`pindakaas::propositional_logic::TseitinEncoder`, which
163		/// is able to encode propositional logic formulas.
164		TSEITIN,
165	}
166
	// Wraps a `BaseFormula` whose atoms are `BoolVal`s (literals or Boolean
	// constants, see `FormulaArg::as_formula`).
	#[pyclass]
	#[derive(Clone, Debug)]
	/// A propositional logic formula.
	struct Formula(BaseFormula<BoolVal>);
171
	#[derive(FromPyObject)]
	/// Argument capture for types that can become :class:`Formula`.
	///
	/// The derived `FromPyObject` tries the variants in declaration order:
	/// a Python `bool` constant, then a :class:`Formula`, then a :class:`Lit`.
	enum FormulaArg {
		Const(bool),
		Formula(Formula),
		Lit(Lit),
	}
179
	/// Adapter implementing the pindakaas `Encoder` trait for each Boolean
	/// linear constraint variant, dispatching on the user-chosen :class:`Encoder`.
	///
	/// The `Encoder` trait can only report `Unsatisfiable`, so an unsupported
	/// method choice is recorded in `error_message` (by `set_err`) and raised by
	/// `encode_constraint` once encoding returns.
	struct LinEncoderWrapper {
		/// Method chosen by the user (`None` selects a per-constraint default).
		method: Option<Encoder>,
		/// Pending :class:`InvalidEncoder` error for an invalid method choice.
		///
		/// Interior mutability is required because the encoder methods only
		/// receive `&self`.
		error_message: Mutex<Option<PyErr>>,
	}
186
	// Wraps a Rust `Lit`. Its Python operators build either formulas (`&`, `|`,
	// `^`, `~`, comparisons) or linear expressions (`+`, `-`, `*`).
	#[pyclass]
	#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
	/// A Boolean literal, representing a Boolean variable or its negation.
	struct Lit(BaseLit);
191
	// Also acts as its own Python iterator: `__iter__` returns the object itself
	// and `__next__` advances the wrapped range, so iterating advances (consumes)
	// this object.
	#[pyclass]
	#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
	/// Representation of a continuous range of variables.
	struct VarRange(BaseVarRange);
196
	// Wraps the Rust `Wcnf`. Clauses added without a weight are reported by
	// `clauses()` (presumably hard clauses); `weighted_clauses()` reports all
	// clauses together with their optional weight.
	#[pyclass]
	#[derive(Clone, Debug, Default)]
	/// The internal representation of a CNF formula where clauses have optional
	/// associated weights.
	struct WCNFInner(Wcnf);
202
203	#[pyfunction]
204	fn _wrap_encode_constraint(
205		obj: &Bound<'_, PyAny>,
206		con: ConstraintArg,
207		enc: Option<Encoder>,
208		conditions: Vec<Lit>,
209	) -> Result {
210		struct PyDbWrapper<'a>(&'a Bound<'a, PyAny>);
211		impl ClauseDatabase for PyDbWrapper<'_> {
212			fn add_clause_from_slice(
213				&mut self,
214				clause: &[BaseLit],
215			) -> Result<(), pindakaas::Unsatisfiable> {
216				let clause_vec = clause.iter().map(|&l| Lit(l)).collect_vec();
217				let res = self.0.call_method1("add_clause", (clause_vec,));
218				match res {
219					Err(e) if e.is_instance_of::<Unsatisfiable>(self.0.py()) => {
220						Err(pindakaas::Unsatisfiable)
221					}
222					Err(e) => {
223						panic!("unexpected error in add_clause implementation: {}", e)
224					}
225					// We would have expected the user implementation to raise `Unsatisfiable`, but
226					// is did not. Since encodings depend on this behaviour, we return the error
227					// instead.
228					Ok(_) if clause.is_empty() => Err(pindakaas::Unsatisfiable),
229					Ok(_) => Ok(()),
230				}
231			}
232
233			fn new_var_range(&mut self, len: usize) -> BaseVarRange {
234				let tup = self
235					.0
236					.call_method1("new_var_range", (len,))
237					.expect("unexpected error in new_var_range implementation");
238				let (start, end): (Lit, Lit) = tup
239					.extract()
240					.expect("new_var_range did not return a tuple of two literals");
241				BaseVarRange::new(start.0.var(), end.0.var())
242			}
243		}
244
245		encode_constraint(&mut PyDbWrapper(obj), con, enc, conditions)
246	}
247
248	/// Internal function to help with the encoding of a constraint given an
249	/// optional encoder.
250	///
251	/// If conditions are provided, the constraint is encoded to only hold if
252	/// all conditions are true.
253	fn encode_constraint<Db>(
254		db: &mut Db,
255		con: ConstraintArg,
256		enc: Option<Encoder>,
257		conditions: Vec<Lit>,
258	) -> Result
259	where
260		Db: ClauseDatabase + ?Sized,
261	{
262		let invalid_enc = |con_ty, enc| {
263			Err(InvalidEncoder::new_err(format!(
264				"Unable to encode object of type `{con_ty}' using {enc:?}"
265			))
266			.into())
267		};
268		let conditions: Vec<_> = conditions.into_iter().map(|l| l.0).collect();
269
270		match con {
271			ConstraintArg::BoolLin(lin) => {
272				let encoder = LinEncoderWrapper::new(enc);
273				let encoder = LinearEncoder::new(encoder, BoolLinAggregator::default());
274				encoder.encode_implied(db, &conditions, &lin.0)?;
275				let err = encoder
276					.variant_encoder()
277					.error_message
278					.lock()
279					.unwrap()
280					.take();
281				if let Some(err) = err {
282					return Err(err.into());
283				}
284			}
285			ConstraintArg::Formula(f) => match enc.unwrap_or(Encoder::TSEITIN) {
286				Encoder::TSEITIN => TseitinEncoder.encode_implied(db, &conditions, &f.0)?,
287				_ => {
288					return invalid_enc("Formula", enc.unwrap());
289				}
290			},
291		};
292		Ok(())
293	}
294
295	impl BoolLinArg {
296		fn as_bool_lin_exp(&self) -> BoolLinExp {
297			match self {
298				&BoolLinArg::Bool(b) => BoolLinExp(b.into()),
299				BoolLinArg::BoolLin(exp) => exp.clone(),
300				&BoolLinArg::Int(i) => BoolLinExp(i.into()),
301				&BoolLinArg::Lit(l) => BoolLinExp(l.0.into()),
302			}
303		}
304	}
305
306	#[pymethods]
307	impl BoolLinCon {
308		fn __str__(&self) -> String {
309			self.0.to_string()
310		}
311	}
312
313	#[pymethods]
314	impl BoolLinExp {
315		fn __add__(&self, other: BoolLinArg) -> Self {
316			let mut res = self.clone();
317			res.__iadd__(other);
318			res
319		}
320
321		fn __eq__(&self, other: i64) -> BoolLinCon {
322			BoolLinCon(BaseBoolLinCon::new(
323				self.0.clone(),
324				Comparator::Equal,
325				other,
326			))
327		}
328
329		fn __ge__(&self, other: i64) -> BoolLinCon {
330			BoolLinCon(BaseBoolLinCon::new(
331				self.0.clone(),
332				Comparator::GreaterEq,
333				other,
334			))
335		}
336
337		fn __gt__(&self, other: i64) -> BoolLinCon {
338			self.__ge__(other + 1)
339		}
340
341		fn __iadd__(&mut self, other: BoolLinArg) {
342			self.0 += other.as_bool_lin_exp().0;
343		}
344
345		fn __imul__(&mut self, other: i64) {
346			self.0 *= other;
347		}
348
349		fn __isub__(&mut self, other: BoolLinArg) {
350			self.0 -= other.as_bool_lin_exp().0;
351		}
352
353		fn __le__(&self, other: i64) -> BoolLinCon {
354			BoolLinCon(BaseBoolLinCon::new(
355				self.0.clone(),
356				Comparator::LessEq,
357				other,
358			))
359		}
360
361		fn __lt__(&self, other: i64) -> BoolLinCon {
362			self.__le__(other - 1)
363		}
364
365		fn __mul__(&self, other: i64) -> Self {
366			let mut res = self.clone();
367			res.__imul__(other);
368			res
369		}
370
371		fn __neg__(&self) -> Self {
372			Self(-self.0.clone())
373		}
374
375		fn __radd__(&self, other: BoolLinArg) -> Self {
376			self.__add__(other)
377		}
378
379		fn __rmul__(&self, other: i64) -> Self {
380			self.__mul__(other)
381		}
382
383		fn __str__(&self) -> String {
384			self.0.to_string()
385		}
386
387		fn __sub__(&self, other: BoolLinArg) -> Self {
388			let mut res = self.clone();
389			res.__isub__(other);
390			res
391		}
392	}
393
394	#[pymethods]
395	impl CNFInner {
396		fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
397			let clause: Vec<Lit> = clause
398				.into_iter()
399				.map(|any| any.and_then(|lit| lit.extract::<Lit>()))
400				.try_collect()?;
401			self.0.add_clause(clause.into_iter().map(|lit| lit.0))?;
402			Ok(())
403		}
404
405		fn add_encoding(
406			&mut self,
407			con: ConstraintArg,
408			enc: Option<Encoder>,
409			conditions: Vec<Lit>,
410		) -> Result {
411			encode_constraint(&mut self.0, con, enc, conditions)
412		}
413
414		fn clauses(&self) -> Vec<Vec<Lit>> {
415			// TODO: It would be great if this could be converted to be lazy, but it
416			// seems a little tricky. This should probably be okay for now.
417			self.0
418				.iter()
419				.map(|c| c.iter().map(|&lit| Lit(lit)).collect())
420				.collect()
421		}
422
423		#[new]
424		fn new() -> Self {
425			Self(Default::default())
426		}
427
428		fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
429			let range = self.0.new_var_range(num_vars);
430			Ok(VarRange(range))
431		}
432
433		fn to_dimacs(&self) -> String {
434			self.0.to_string()
435		}
436
437		fn variables(&self) -> VarRange {
438			VarRange(self.0.variables())
439		}
440	}
441
442	#[pymethods]
443	impl Formula {
444		fn __and__(&self, other: FormulaArg) -> Self {
445			Self(self.0.clone() & other.as_formula())
446		}
447
448		fn __eq__(&self, other: FormulaArg) -> Self {
449			use BaseFormula::*;
450
451			Formula(Equiv(vec![self.0.clone(), other.as_formula()]))
452		}
453
454		fn __ge__(&self, other: FormulaArg) -> Self {
455			use BaseFormula::*;
456
457			Self(Implies(other.as_formula().into(), self.0.clone().into()))
458		}
459
460		fn __gt__(&self, other: FormulaArg) -> Self {
461			Self(self.0.clone() & !other.as_formula())
462		}
463
464		fn __invert__(&self) -> Self {
465			Self(!self.0.clone())
466		}
467
468		fn __le__(&self, other: FormulaArg) -> Self {
469			use BaseFormula::*;
470
471			Self(Implies(self.0.clone().into(), other.as_formula().into()))
472		}
473
474		fn __lt__(&self, other: FormulaArg) -> Self {
475			Self(!self.0.clone() & other.as_formula())
476		}
477
478		fn __ne__(&self, other: FormulaArg) -> Self {
479			self.__xor__(other)
480		}
481
482		fn __or__(&self, other: FormulaArg) -> Self {
483			Formula(self.0.clone() | other.as_formula())
484		}
485
486		fn __rand__(&self, other: FormulaArg) -> Self {
487			self.__and__(other)
488		}
489
490		fn __ror__(&self, other: FormulaArg) -> Self {
491			self.__or__(other)
492		}
493
494		fn __rxor__(&self, other: FormulaArg) -> Self {
495			self.__xor__(other)
496		}
497
498		fn __str__(&self) -> String {
499			self.0.to_string()
500		}
501
502		fn __xor__(&self, other: FormulaArg) -> Self {
503			Formula(self.0.clone() ^ other.as_formula())
504		}
505	}
506
507	impl FormulaArg {
508		/// Internal method used to convert the :class:`FormulaArg` into a
509		/// :class:`BaseFormula<BoolVal>`.
510		fn as_formula(&self) -> BaseFormula<BoolVal> {
511			use BaseFormula::*;
512
513			match self {
514				FormulaArg::Const(b) => Atom(BoolVal::Const(*b)),
515				FormulaArg::Formula(formula) => formula.0.clone(),
516				FormulaArg::Lit(lit) => lit.as_formula(),
517			}
518		}
519	}
520
521	impl LinEncoderWrapper {
522		fn new(method: Option<Encoder>) -> Self {
523			Self {
524				method,
525				error_message: Mutex::new(None),
526			}
527		}
528
529		fn set_err(&self, con_ty: &str, enc: Encoder) {
530			let _ = self
531				.error_message
532				.lock()
533				.unwrap()
534				.replace(InvalidEncoder::new_err(format!(
535					"Unable to encode object of type `{con_ty}' using {enc:?}"
536				)));
537		}
538	}
539
540	impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, BoolLinVariant> for LinEncoderWrapper {
541		fn encode(
542			&self,
543			db: &mut Db,
544			con: &BoolLinVariant,
545		) -> Result<(), pindakaas::Unsatisfiable> {
546			match con {
547				BoolLinVariant::Linear(lin) => self.encode(db, lin),
548				BoolLinVariant::Cardinality(card) => self.encode(db, card),
549				BoolLinVariant::CardinalityOne(card1) => self.encode(db, card1),
550				BoolLinVariant::Trivial => Ok(()),
551			}
552		}
553	}
554
555	impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, Cardinality> for LinEncoderWrapper {
556		fn encode(&self, db: &mut Db, con: &Cardinality) -> Result<(), pindakaas::Unsatisfiable> {
557			match self.method.unwrap_or(Encoder::ADDER) {
558				Encoder::SORTING_NETWORK => SortingNetworkEncoder::default().encode(db, con),
559				Encoder::ADDER => AdderEncoder::default().encode(db, con),
560				Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, con),
561				Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, con),
562				enc => {
563					self.set_err("Cardinality", enc);
564					Ok(())
565				}
566			}
567		}
568	}
569
570	impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, CardinalityOne> for LinEncoderWrapper {
571		fn encode(
572			&self,
573			db: &mut Db,
574			con: &CardinalityOne,
575		) -> Result<(), pindakaas::Unsatisfiable> {
576			match self.method.unwrap_or(Encoder::BITWISE) {
577				Encoder::BITWISE => BitwiseEncoder::default().encode(db, con),
578				Encoder::ADDER => AdderEncoder::default().encode(db, con),
579				Encoder::LADDER => LadderEncoder::default().encode(db, con),
580				Encoder::PAIRWISE => PairwiseEncoder::default().encode(db, con),
581				Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, con),
582				Encoder::SORTING_NETWORK => SortingNetworkEncoder::default().encode(db, con),
583				Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, con),
584				enc => {
585					self.set_err("CardinalityOne", enc);
586					Ok(())
587				}
588			}
589		}
590	}
591
592	impl<Db: ClauseDatabase + ?Sized> EncoderTrait<Db, NormalizedBoolLinear> for LinEncoderWrapper {
593		fn encode(
594			&self,
595			db: &mut Db,
596			con: &NormalizedBoolLinear,
597		) -> Result<(), pindakaas::Unsatisfiable> {
598			match self.method.unwrap_or(Encoder::ADDER) {
599				Encoder::ADDER => AdderEncoder::default().encode(db, con),
600				Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, con),
601				Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, con),
602				enc => {
603					self.set_err("BoolLinear", enc);
604					Ok(())
605				}
606			}
607		}
608	}
609
610	impl Lit {
611		fn as_bool_lin_exp(&self) -> BoolLinExp {
612			BoolLinExp(self.0.into())
613		}
614
615		fn as_formula(&self) -> BaseFormula<BoolVal> {
616			BaseFormula::Atom(self.0.into())
617		}
618	}
619
620	#[pymethods]
621	impl Lit {
622		fn __add__(&self, other: BoolLinArg) -> BoolLinExp {
623			self.as_bool_lin_exp().__add__(other)
624		}
625
626		fn __and__(&self, other: FormulaArg) -> Formula {
627			Formula(self.as_formula()).__and__(other)
628		}
629
630		fn __eq__(&self, other: FormulaArg) -> Formula {
631			Formula(self.as_formula()).__eq__(other)
632		}
633
634		fn __ge__(&self, other: FormulaArg) -> Formula {
635			Formula(self.as_formula()).__ge__(other)
636		}
637
638		fn __gt__(&self, other: FormulaArg) -> Formula {
639			Formula(self.as_formula()).__gt__(other)
640		}
641
642		fn __int__(&self) -> i32 {
643			self.0.into()
644		}
645
646		fn __invert__(&self) -> Self {
647			Self(!self.0)
648		}
649
650		fn __le__(&self, other: FormulaArg) -> Formula {
651			Formula(self.as_formula()).__le__(other)
652		}
653
654		fn __lt__(&self, other: FormulaArg) -> Formula {
655			Formula(self.as_formula()).__lt__(other)
656		}
657
658		fn __mul__(&self, other: i64) -> BoolLinExp {
659			self.as_bool_lin_exp().__mul__(other)
660		}
661
662		fn __ne__(&self, other: FormulaArg) -> Formula {
663			Formula(self.as_formula()).__ne__(other)
664		}
665
666		fn __or__(&self, other: FormulaArg) -> Formula {
667			Formula(self.as_formula()).__or__(other)
668		}
669
670		fn __radd__(&self, other: BoolLinArg) -> BoolLinExp {
671			self.__add__(other)
672		}
673
674		fn __rand__(&self, other: FormulaArg) -> Formula {
675			Formula(self.as_formula()).__and__(other)
676		}
677
678		fn __rmul__(&self, other: i64) -> BoolLinExp {
679			self.__mul__(other)
680		}
681
682		fn __ror__(&self, other: FormulaArg) -> Formula {
683			self.__or__(other)
684		}
685
686		fn __rxor__(&self, other: FormulaArg) -> Formula {
687			self.__xor__(other)
688		}
689
690		fn __str__(&self) -> String {
691			self.0.to_string()
692		}
693
694		fn __sub__(&self, other: BoolLinArg) -> BoolLinExp {
695			self.as_bool_lin_exp().__sub__(other)
696		}
697
698		fn __xor__(&self, other: FormulaArg) -> Formula {
699			Formula(self.as_formula()).__xor__(other)
700		}
701
702		#[staticmethod]
703		fn from_raw(value: NonZeroI32) -> Self {
704			Self(BaseLit::from_raw(value))
705		}
706
707		/// Return whether the variable is negated
708		fn is_negated(&self) -> bool {
709			self.0.is_negated()
710		}
711
712		/// Return the literal's variable
713		fn var(&self) -> Self {
714			Self(self.0.var().into())
715		}
716	}
717
718	impl Display for Lit {
719		fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
720			self.0.fmt(f)
721		}
722	}
723
724	#[pymethods]
725	impl VarRange {
726		fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
727			slf
728		}
729
730		fn __len__(&self) -> usize {
731			self.0.len()
732		}
733
734		fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<Lit> {
735			slf.0.next().map(|lit| Lit(lit.into()))
736		}
737
738		/// Returns the final variable included in the range.
739		fn end(&self) -> Lit {
740			Lit(self.0.end().into())
741		}
742
743		#[new]
744		/// Create a new variable range that includes all variables between
745		/// `start` and `end` (inclusive).
746		fn new(start: Lit, end: Lit) -> PyResult<Self> {
747			if start.is_negated() || end.is_negated() {
748				return Err(PyValueError::new_err(
749					"`start' and `end' must be positive literals (directly representing variables)",
750				));
751			}
752			Ok(Self(BaseVarRange::new(start.0.var(), end.0.var())))
753		}
754
755		/// Returns the first variable included in the range.
756		fn start(&self) -> Lit {
757			Lit(self.0.start().into())
758		}
759	}
760
761	#[pymethods]
762	impl WCNFInner {
763		fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
764			let clause: Vec<Lit> = clause
765				.into_iter()
766				.map(|any| any.and_then(|lit| lit.extract::<Lit>()))
767				.try_collect()?;
768			self.0.add_clause(clause.into_iter().map(|lit| lit.0))?;
769			Ok(())
770		}
771
772		fn add_encoding(
773			&mut self,
774			con: ConstraintArg,
775			enc: Option<Encoder>,
776			conditions: Vec<Lit>,
777		) -> Result {
778			encode_constraint(&mut self.0, con, enc, conditions)
779		}
780
781		fn add_weighted_clause(&mut self, clause: Bound<'_, PyIterator>, weight: i64) -> Result {
782			let clause: Vec<Lit> = clause
783				.into_iter()
784				.map(|any| any.and_then(|lit| lit.extract::<Lit>()))
785				.try_collect()?;
786			self.0
787				.add_weighted_clause(clause.into_iter().map(|lit| lit.0), weight)?;
788			Ok(())
789		}
790
791		fn clauses(&self) -> Vec<Vec<Lit>> {
792			// TODO: It would be great if this could be converted to be lazy, but it
793			// seems a little tricky. This should probably be okay for now.
794			self.0
795				.iter()
796				.filter(|(_, w)| w.is_none())
797				.map(|(c, _)| c.iter().map(|&lit| Lit(lit)).collect_vec())
798				.collect()
799		}
800
801		#[new]
802		fn new() -> Self {
803			Self(Default::default())
804		}
805
806		fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
807			let range = self.0.new_var_range(num_vars);
808			Ok(VarRange(range))
809		}
810
811		fn to_dimacs(&self) -> String {
812			self.0.to_string()
813		}
814
815		fn variables(&self) -> VarRange {
816			VarRange(self.0.variables())
817		}
818
819		fn weighted_clauses(&self) -> Vec<(Option<i64>, Vec<Lit>)> {
820			// TODO: It would be great if this could be converted to be lazy, but it
821			// seems a little tricky. This should probably be okay for now.
822			self.0
823				.iter()
824				.map(|(c, &w)| (w, (c.iter().map(|&lit| Lit(lit)).collect())))
825				.collect()
826		}
827	}
828
829	#[pymodule]
830	mod solver {
		// Generates the Python-facing `#[pymethods]` of a solver result class
		// `$name`, delegating every method to its inner `SolverResultImpl`
		// (field `.0`). Instances act as context managers: `__enter__` returns the
		// object itself and `__exit__` hands the solver back to its owner.
		// NOTE(review): `$owner` and `$solver` are not referenced in this body.
		macro_rules! py_solver_result {
			($name:ident, $owner:ident, $solver:ty) => {
				#[pymethods]
				impl $name {
					fn __enter__(slf: Py<Self>) -> Py<Self> {
						slf
					}

					// Restores the solver into the owner's `.0` field; never
					// suppresses an exception raised inside the `with` block.
					fn __exit__(
						&mut self,
						py: Python<'_>,
						_exc_type: Option<&Bound<'_, PyAny>>,
						_exc: Option<&Bound<'_, PyAny>>,
						_traceback: Option<&Bound<'_, PyAny>>,
					) -> PyResult<bool> {
						self.0.exit(py, |owner| &mut owner.0)
					}

					fn failed(&self, lit: Lit) -> PyResult<Option<bool>> {
						self.0.failed(lit)
					}

					#[getter]
					fn status(&self) -> PyResult<Status> {
						self.0.status()
					}

					fn value(&self, lit: Lit) -> PyResult<Option<bool>> {
						self.0.value(lit)
					}
				}
			};
		}
864
865		use std::{
866			mem::transmute,
867			time::{Duration, SystemTime},
868		};
869
870		use itertools::Itertools;
871		use pindakaas::{
872			solver::{
873				cadical::Cadical, kissat::Kissat, Assumptions, FailedAssumptions, SolveResult,
874				Solver, TermSignal, TerminateCallback,
875			},
876			ClauseDatabase, ClauseDatabaseTools, Lit as BaseLit, Valuation,
877		};
878		use pyo3::{
879			exceptions::{PyNotImplementedError, PyRuntimeError},
880			prelude::*,
881			pyclass::boolean_struct::False,
882			types::{PyAny, PyIterator},
883			PyClass,
884		};
885
886		use crate::{
887			pindakaas::{encode_constraint, ConstraintArg, Encoder, Lit, VarRange},
888			Result,
889		};
890
		/// Error message for operations on an owner whose solver is currently held
		/// by an active result object.
		const CHECKED_OUT_ERROR: &str = "solver is currently checked out by an active result";
		/// Error message presumably used when a result is queried after its context
		/// has been exited (its usage is not visible in this chunk).
		const INACTIVE_RESULT_ERROR: &str = "solver result is no longer active";
		/// Error message used by `exit` when the owner already holds a solver.
		const RESTORED_ERROR: &str = "solver was already restored to its owner";
894
		// Python-facing CaDiCaL solver. The solver is moved out to a
		// `CaDiCaLResult` by `solve_assuming` and moved back by its `exit`.
		#[pyclass(unsendable)]
		#[derive(Debug)]
		struct CaDiCaLInner(SolverImpl<Cadical>);
898
		// Result of `CaDiCaLInner.solve_assuming`; owns the checked-out solver
		// until exited. Its Python methods presumably come from
		// `py_solver_result!` (invocation not visible here).
		#[pyclass(unsendable)]
		struct CaDiCaLResult(SolverResultImpl<CaDiCaLInner, Cadical>);
901
		// Python-facing Kissat solver (no assumption support). The solver is moved
		// out to a `KissatResult` by `solve_assuming` and moved back by its `exit`.
		#[pyclass(unsendable)]
		#[derive(Debug)]
		struct KissatInner(SolverImpl<Kissat>);
905
		// Result of `KissatInner.solve_assuming`; owns the checked-out solver
		// until exited. Its Python methods presumably come from
		// `py_solver_result!` (invocation not visible here).
		#[pyclass(unsendable)]
		struct KissatResult(SolverResultImpl<KissatInner, Kissat>);
908
		/// Shared state of the `*Inner` solver classes.
		///
		/// `solver` is `None` while the solver is checked out by an active result
		/// object; operations needing it then fail with `CHECKED_OUT_ERROR`.
		#[derive(Debug)]
		struct SolverImpl<S> {
			solver: Option<S>,
		}
913
		/// Shared state of the solver result classes.
		///
		/// While active it owns the checked-out solver; `exit` hands it back to
		/// `owner`.
		struct SolverResultImpl<Owner, S> {
			/// The Python object the solver was checked out from.
			owner: Py<Owner>,
			/// Outcome of the solve call; cleared by `exit` before the solver is
			/// restored.
			// NOTE(review): this may hold a lifetime-erased borrow of `solver` (see
			// the `transmute` in `from_solver`). Rust drops struct fields in
			// declaration order, so keeping `result` above `solver` drops it first.
			// Do not reorder these fields.
			result: Option<SolverResultState>,
			/// The checked-out solver; `None` once it has been restored to `owner`.
			solver: Option<S>,
			/// Whether the solve call was made through the assumptions-capable
			/// constructor (`from_solver` passes `false`).
			supports_assumptions: bool,
		}
920
		/// Outcome of a single solve call, held by a result object.
		///
		/// The `'static` bounds are obtained via lifetime-erasing `transmute`s in
		/// the constructors (see `from_solver`); the boxed values must not outlive
		/// the checked-out solver.
		enum SolverResultState {
			/// A satisfying valuation for the current solve call.
			Satisfied(Box<dyn Valuation + 'static>),
			/// Failed assumptions for the current solve call.
			Unsatisfiable(Box<dyn Fn(BaseLit) -> Option<bool> + 'static>),
			/// The solver terminated without a definitive result.
			Unknown,
		}
929
		// Exposed through the `status` getter generated by `py_solver_result!`.
		#[pyclass(eq, eq_int)]
		#[derive(Clone, Copy, Debug, PartialEq)]
		/// The resulting status of solving a problem.
		enum Status {
			/// A solution was found.
			SATISFIED,
			/// No solution exists for the given problem.
			UNSATISFIABLE,
			/// The solving process was interrupted before a result was found.
			UNKNOWN,
		}
941
942		/// Hack: workaround for https://github.com/PyO3/pyo3/issues/759
943		#[pymodule_init]
944		fn init(module: &Bound<'_, PyModule>) -> PyResult<()> {
945			module
946				.py()
947				.import("sys")?
948				.getattr("modules")?
949				.set_item("pindakaas.pindakaas.solver", module)
950		}
951
952		#[pymethods]
953		impl CaDiCaLInner {
954			fn _set_option(&mut self, name: &str, value: i32) -> PyResult<()> {
955				self.0.solver_mut()?.set_option(name, value);
956				Ok(())
957			}
958
959			fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
960				self.0.add_clause(clause)
961			}
962
963			fn add_encoding(
964				&mut self,
965				con: ConstraintArg,
966				enc: Option<Encoder>,
967				conditions: Vec<Lit>,
968			) -> Result {
969				self.0.add_encoding(con, enc, conditions)
970			}
971
972			#[new]
973			fn new() -> Self {
974				Self(SolverImpl::default())
975			}
976
977			fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
978				self.0.new_var_range(num_vars)
979			}
980
981			fn set_time_limit(&mut self, limit: Option<Duration>) -> Result {
982				self.0.set_time_limit(limit)
983			}
984
985			fn solve_assuming(
986				slf: Py<Self>,
987				py: Python<'_>,
988				assumptions: Vec<Lit>,
989			) -> Result<Py<CaDiCaLResult>> {
990				let mut inner = slf.bind(py).borrow_mut();
991				let solver = inner.0.take()?;
992				Ok(Py::new(
993					py,
994					CaDiCaLResult(SolverResultImpl::from_assumptions_solver(
995						slf.clone_ref(py),
996						solver,
997						&assumptions,
998					)),
999				)?)
1000			}
1001		}
1002
1003		#[pymethods]
1004		impl KissatInner {
1005			fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
1006				self.0.add_clause(clause)
1007			}
1008
1009			fn add_encoding(
1010				&mut self,
1011				con: ConstraintArg,
1012				enc: Option<Encoder>,
1013				conditions: Vec<Lit>,
1014			) -> Result {
1015				self.0.add_encoding(con, enc, conditions)
1016			}
1017
1018			#[new]
1019			fn new() -> Self {
1020				Self(SolverImpl::default())
1021			}
1022
1023			fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
1024				self.0.new_var_range(num_vars)
1025			}
1026
1027			fn set_time_limit(&mut self, limit: Option<Duration>) -> Result {
1028				self.0.set_time_limit(limit)
1029			}
1030
1031			fn solve_assuming(
1032				slf: Py<Self>,
1033				py: Python<'_>,
1034				assumptions: Vec<Lit>,
1035			) -> Result<Py<KissatResult>> {
1036				if !assumptions.is_empty() {
1037					return Err(PyNotImplementedError::new_err(
1038						"solver does not support assumptions",
1039					)
1040					.into());
1041				}
1042				let mut inner = slf.bind(py).borrow_mut();
1043				let solver = inner.0.take()?;
1044				Ok(Py::new(
1045					py,
1046					KissatResult(SolverResultImpl::from_solver(slf.clone_ref(py), solver)),
1047				)?)
1048			}
1049		}
1050
1051		impl<S> SolverImpl<S> {
1052			fn solver_mut(&mut self) -> PyResult<&mut S> {
1053				self.solver
1054					.as_mut()
1055					.ok_or_else(|| PyRuntimeError::new_err(CHECKED_OUT_ERROR))
1056			}
1057
1058			fn take(&mut self) -> PyResult<S> {
1059				self.solver
1060					.take()
1061					.ok_or_else(|| PyRuntimeError::new_err(CHECKED_OUT_ERROR))
1062			}
1063		}
1064
1065		impl<S: ClauseDatabase> SolverImpl<S> {
1066			fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
1067				let clause: Vec<Lit> = clause
1068					.into_iter()
1069					.map(|any| any.and_then(|lit| lit.extract::<Lit>()))
1070					.try_collect()?;
1071				self.solver_mut()?
1072					.add_clause(clause.into_iter().map(|lit| lit.0))?;
1073				Ok(())
1074			}
1075
1076			fn add_encoding(
1077				&mut self,
1078				con: ConstraintArg,
1079				enc: Option<Encoder>,
1080				conditions: Vec<Lit>,
1081			) -> Result {
1082				encode_constraint(self.solver_mut()?, con, enc, conditions)
1083			}
1084
1085			fn new_var_range(&mut self, num_vars: usize) -> PyResult<VarRange> {
1086				Ok(VarRange(self.solver_mut()?.new_var_range(num_vars)))
1087			}
1088		}
1089
1090		impl<S: TerminateCallback> SolverImpl<S> {
1091			fn set_time_limit(&mut self, limit: Option<Duration>) -> Result {
1092				self.solver_mut()?.set_terminate_callback(limit.map(|dur| {
1093					let deadline = SystemTime::now() + dur;
1094					move || {
1095						if SystemTime::now() > deadline {
1096							TermSignal::Terminate
1097						} else {
1098							TermSignal::Continue
1099						}
1100					}
1101				}));
1102				Ok(())
1103			}
1104		}
1105
1106		impl<S: Default> Default for SolverImpl<S> {
1107			fn default() -> Self {
1108				Self {
1109					solver: Some(S::default()),
1110				}
1111			}
1112		}
1113
1114		impl<Owner: PyClass<Frozen = False>, S> SolverResultImpl<Owner, S> {
1115			fn exit(
1116				&mut self,
1117				py: Python<'_>,
1118				slot: fn(&mut Owner) -> &mut SolverImpl<S>,
1119			) -> PyResult<bool> {
1120				self.result = None;
1121				if let Some(solver) = self.solver.take() {
1122					let mut owner = self.owner.bind(py).borrow_mut();
1123					let inner = slot(std::ops::DerefMut::deref_mut(&mut owner));
1124					if inner.solver.is_some() {
1125						return Err(PyRuntimeError::new_err(RESTORED_ERROR));
1126					}
1127					inner.solver = Some(solver);
1128				}
1129				Ok(false)
1130			}
1131		}
1132
1133		impl<Owner, S: Solver> SolverResultImpl<Owner, S> {
1134			fn from_solver(owner: Py<Owner>, mut solver: S) -> Self {
1135				let result = match solver.solve() {
1136					SolveResult::Satisfied(sol) => {
1137						let sol: Box<dyn Valuation + '_> = Box::new(sol);
1138						// SAFETY: The returned valuation is tied to the checked-out
1139						// solver and is dropped before solver access is restored.
1140						let sol: Box<dyn Valuation + 'static> = unsafe { transmute(sol) };
1141						SolverResultState::Satisfied(sol)
1142					}
1143					SolveResult::Unsatisfiable(_) => {
1144						SolverResultState::Unsatisfiable(Box::new(|_| None))
1145					}
1146					SolveResult::Unknown => SolverResultState::Unknown,
1147				};
1148				Self::new(owner, result, solver, false)
1149			}
1150		}
1151
1152		impl<Owner, S: Assumptions> SolverResultImpl<Owner, S> {
1153			fn from_assumptions_solver(
1154				owner: Py<Owner>,
1155				mut solver: S,
1156				assumptions: &[Lit],
1157			) -> Self {
1158				let result = match solver.solve_assuming(assumptions.iter().map(|lit| lit.0)) {
1159					SolveResult::Satisfied(sol) => {
1160						let sol: Box<dyn Valuation + '_> = Box::new(sol);
1161						// SAFETY: The returned valuation is only valid while the solver
1162						// state remains alive and unchanged. The corresponding result
1163						// object owns the checked-out solver and drops this boxed value
1164						// before restoring solver access.
1165						let sol: Box<dyn Valuation + 'static> = unsafe { transmute(sol) };
1166						SolverResultState::Satisfied(sol)
1167					}
1168					SolveResult::Unsatisfiable(fail) => {
1169						let fail: Box<dyn FailedAssumptions + '_> = Box::new(fail);
1170						// SAFETY: Same reasoning as above for the failed-assumptions
1171						// object.
1172						let fail: Box<dyn FailedAssumptions + 'static> = unsafe { transmute(fail) };
1173						let fail = move |lit: BaseLit| Some(fail.fail(lit));
1174						SolverResultState::Unsatisfiable(Box::new(fail))
1175					}
1176					SolveResult::Unknown => SolverResultState::Unknown,
1177				};
1178				Self::new(owner, result, solver, true)
1179			}
1180		}
1181
1182		impl<Owner, S> SolverResultImpl<Owner, S> {
1183			fn failed(&self, lit: Lit) -> PyResult<Option<bool>> {
1184				let Some(result) = self.result.as_ref() else {
1185					return Err(PyRuntimeError::new_err(INACTIVE_RESULT_ERROR));
1186				};
1187				if !self.supports_assumptions {
1188					return Ok(None);
1189				}
1190				Ok(match result {
1191					SolverResultState::Unsatisfiable(fail) => fail(lit.0),
1192					_ => None,
1193				})
1194			}
1195
1196			fn new(
1197				owner: Py<Owner>,
1198				result: SolverResultState,
1199				solver: S,
1200				supports_assumptions: bool,
1201			) -> Self {
1202				Self {
1203					owner,
1204					result: Some(result),
1205					solver: Some(solver),
1206					supports_assumptions,
1207				}
1208			}
1209
1210			fn status(&self) -> PyResult<Status> {
1211				let Some(result) = self.result.as_ref() else {
1212					return Err(PyRuntimeError::new_err(INACTIVE_RESULT_ERROR));
1213				};
1214				Ok(match result {
1215					SolverResultState::Satisfied(_) => Status::SATISFIED,
1216					SolverResultState::Unsatisfiable(_) => Status::UNSATISFIABLE,
1217					SolverResultState::Unknown => Status::UNKNOWN,
1218				})
1219			}
1220
1221			fn value(&self, lit: Lit) -> PyResult<Option<bool>> {
1222				let Some(result) = self.result.as_ref() else {
1223					return Err(PyRuntimeError::new_err(INACTIVE_RESULT_ERROR));
1224				};
1225				Ok(match result {
1226					SolverResultState::Satisfied(sol) => Some(sol.value(lit.0)),
1227					_ => None,
1228				})
1229			}
1230		}
1231
1232		py_solver_result!(CaDiCaLResult, CaDiCaLInner, Cadical);
1233
1234		py_solver_result!(KissatResult, KissatInner, Kissat);
1235	}
1236}