#![expect(
clippy::upper_case_acronyms,
reason = "Python naming for exposed types"
)]
use std::sync::PoisonError;
use pyo3::{create_exception, exceptions::PyException, prelude::*};
// Python exception types of the `pindakaas` extension module. Both are re-exported from the
// `#[pymodule] mod pindakaas` below through `#[pymodule_export]`.
//
// `InvalidEncoder` is raised by `encode_constraint` when the requested encoder cannot handle
// the (aggregated) constraint type.
create_exception!(
    pindakaas,
    InvalidEncoder,
    PyException,
    "Raised when the chosen encoder does not support the constraint (e.g. when the `PairwiseEncoder` encoder for AMO constraints is used to encode a PB constraint)."
);
// `Unsatisfiable` is produced from `::pindakaas::Unsatisfiable` via the `ErrWrapper` conversion.
create_exception!(
    pindakaas,
    Unsatisfiable,
    PyException,
    "Raised when the given constraint is found to be Unsatisfiable during encoding."
);
/// Crate-local result type: `Result` alone means `Result<(), ErrWrapper>`, so `?` can turn
/// every error type that has a `From` impl below into a Python exception.
type Result<R = (), E = ErrWrapper> = std::result::Result<R, E>;

/// Newtype over [`PyErr`] that lets this crate define its own `From` conversions for foreign
/// error types (the orphan rule forbids `impl From<ForeignErr> for PyErr` in this crate).
struct ErrWrapper(PyErr);

/// Unwrap back into the underlying Python exception (used when returning to PyO3).
impl From<ErrWrapper> for PyErr {
    fn from(ErrWrapper(inner): ErrWrapper) -> Self {
        inner
    }
}

/// Pass Python errors (e.g. failed extractions) through unchanged.
impl From<PyErr> for ErrWrapper {
    fn from(inner: PyErr) -> Self {
        Self(inner)
    }
}

/// A poisoned lock becomes a generic Python `Exception` carrying the poison message.
impl<T> From<PoisonError<T>> for ErrWrapper {
    fn from(poisoned: PoisonError<T>) -> Self {
        let message = poisoned.to_string();
        ErrWrapper(PyException::new_err(message))
    }
}

/// Unsatisfiability detected by `pindakaas` becomes this module's `Unsatisfiable` exception.
impl From<::pindakaas::Unsatisfiable> for ErrWrapper {
    fn from(_unsat: ::pindakaas::Unsatisfiable) -> Self {
        ErrWrapper(Unsatisfiable::new_err(
            "The given constraint was found to be Unsatisfiable during encoding",
        ))
    }
}
#[pymodule]
mod pindakaas {
use std::fmt::{self, Display};
use pindakaas::{
bool_linear::{
AdderEncoder, BoolLinAggregator, BoolLinExp as BaseBoolLinExp, BoolLinVariant,
BoolLinear as BaseBoolLinCon, Comparator, SwcEncoder, TotalizerEncoder,
},
cardinality::SortingNetworkEncoder,
cardinality_one::{BitwiseEncoder, LadderEncoder, PairwiseEncoder},
propositional_logic::{Formula as BaseFormula, TseitinEncoder},
BoolVal, ClauseDatabase, ClauseDatabaseTools, Cnf, Encoder as _, Lit as BaseLit, Wcnf,
};
use pyo3::{prelude::*, types::PyIterator};
#[pymodule_export]
use crate::InvalidEncoder;
use crate::Result;
#[pymodule_export]
use crate::Unsatisfiable;
/// Right-hand operand accepted by the `BoolLinExp`/`Lit` arithmetic methods
/// (`+`, `-`, `+=`, `-=`).
///
/// NOTE: `#[derive(FromPyObject)]` tries the variants in declaration order and keeps the
/// first one that extracts successfully, so the order below is significant. `Bool` is tried
/// before `Int`, presumably so that Python `True`/`False` are not read as integers; do not
/// reorder without checking.
#[derive(FromPyObject)]
enum BoolLinArg {
Bool(bool),
BoolLin(BoolLinExp),
Int(i64),
Lit(Lit),
}
/// Python handle for a pindakaas boolean linear constraint, produced by the comparison
/// operators (`==`, `<=`, `<`, `>=`, `>`) on `BoolLinExp`.
#[pyclass]
#[derive(Clone, Debug)]
struct BoolLinCon(BaseBoolLinCon);
/// Python handle for a pindakaas boolean linear expression; supports `+`, `-`, `*`, unary
/// `-`, and comparisons against integers.
#[pyclass]
#[derive(Clone, Debug)]
struct BoolLinExp(BaseBoolLinExp);
/// Python handle wrapping a pindakaas `Cnf` clause database.
#[pyclass]
#[derive(Clone, Debug, Default)]
struct CNFInner(Cnf);
/// Constraint accepted by the `add_encoding` methods: a boolean linear constraint or a
/// propositional formula. `FromPyObject` tries the variants in declaration order.
#[derive(FromPyObject)]
enum ConstraintArg {
BoolLin(BoolLinCon),
Formula(Formula),
}
/// Encoder selection for `add_encoding`; passing `None` there selects the default encoder for
/// the constraint type (see `encode_constraint`).
///
/// NOTE: with `#[pyclass(eq, eq_int)]` each variant also compares equal to its implicit
/// discriminant (declaration position, starting at 0), so reordering or inserting variants
/// changes the integer values Python code may compare against.
#[expect(non_camel_case_types, reason = "match python naming convention")]
#[pyclass(eq, eq_int)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Encoder {
ADDER,
BITWISE,
// NOTE(review): no branch of `encode_constraint` in this file accepts DECISION_DIAGRAM, so
// selecting it currently always raises `InvalidEncoder`.
DECISION_DIAGRAM,
LADDER,
PAIRWISE,
SORTED_WEIGHT_COUNTER,
SORTING_NETWORK,
TOTALIZER,
TSEITIN,
}
/// Python handle for a propositional formula over pindakaas `BoolVal` atoms, built with the
/// overloaded `&`, `|`, `^`, `~` and comparison operators.
#[pyclass]
#[derive(Clone, Debug)]
struct Formula(BaseFormula<BoolVal>);
/// Operand accepted by the `Formula`/`Lit` logical operators: a Python bool constant, a
/// formula, or a literal. `FromPyObject` tries the variants in declaration order.
#[derive(FromPyObject)]
enum FormulaArg {
Const(bool),
Formula(Formula),
Lit(Lit),
}
/// Python handle for a single pindakaas literal (a possibly negated variable; see
/// `is_negated`/`var`/`~`).
#[pyclass]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Lit(BaseLit);
/// Python handle wrapping a pindakaas `Wcnf`, which accepts both hard clauses and weighted
/// clauses (`add_weighted_clause`).
#[pyclass]
#[derive(Clone, Debug, Default)]
struct WCNFInner(Wcnf);
fn encode_constraint_with_conditions<Db: ClauseDatabase>(
db: &mut Db,
con: ConstraintArg,
enc: Option<Encoder>,
conditions: Vec<Lit>,
) -> Result {
if conditions.is_empty() {
encode_constraint(db, con, enc)
} else {
encode_constraint(
&mut db.with_conditions(conditions.into_iter().map(|l| l.0).collect()),
con,
enc,
)
}
}
fn encode_constraint<Db: ClauseDatabase>(
db: &mut Db,
con: ConstraintArg,
enc: Option<Encoder>,
) -> Result {
let invalid_enc = |con_ty, enc| {
Err(InvalidEncoder::new_err(format!(
"Unable to encode object of type `{con_ty}' using {enc:?}"
))
.into())
};
match con {
ConstraintArg::BoolLin(lin) => {
let aggregated = BoolLinAggregator::default().aggregate(db, &lin.0)?;
match aggregated {
BoolLinVariant::Cardinality(c) => match enc.unwrap_or(Encoder::SORTING_NETWORK)
{
Encoder::SORTING_NETWORK => SortingNetworkEncoder::default().encode(db, &c),
Encoder::ADDER => AdderEncoder::default().encode(db, &c),
Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, &c),
Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, &c),
_ => return invalid_enc("Cardinality", enc.unwrap()),
},
BoolLinVariant::CardinalityOne(c) => match enc.unwrap_or(Encoder::BITWISE) {
Encoder::BITWISE => BitwiseEncoder::default().encode(db, &c),
Encoder::ADDER => AdderEncoder::default().encode(db, &c),
Encoder::LADDER => LadderEncoder::default().encode(db, &c),
Encoder::PAIRWISE => PairwiseEncoder::default().encode(db, &c),
Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, &c),
Encoder::SORTING_NETWORK => SortingNetworkEncoder::default().encode(db, &c),
Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, &c),
_ => return invalid_enc("CardinalityOne", enc.unwrap()),
},
BoolLinVariant::Linear(lin) => match enc.unwrap_or(Encoder::TOTALIZER) {
Encoder::TOTALIZER => TotalizerEncoder::default().encode(db, &lin),
Encoder::ADDER => AdderEncoder::default().encode(db, &lin),
Encoder::SORTED_WEIGHT_COUNTER => SwcEncoder::default().encode(db, &lin),
_ => return invalid_enc("BoolLinear", enc.unwrap()),
},
BoolLinVariant::Trivial => return Ok(()),
}?;
}
ConstraintArg::Formula(f) => match enc.unwrap_or(Encoder::TSEITIN) {
Encoder::TSEITIN => TseitinEncoder.encode(db, &f.0)?,
_ => {
return invalid_enc("Formula", enc.unwrap());
}
},
};
Ok(())
}
impl BoolLinArg {
    /// Lift the operand into a `BoolLinExp`.
    ///
    /// Constants and literals are converted into single-term expressions; an existing
    /// expression is cloned.
    fn as_bool_lin_exp(&self) -> BoolLinExp {
        match self {
            Self::BoolLin(exp) => exp.clone(),
            Self::Bool(b) => BoolLinExp((*b).into()),
            Self::Int(i) => BoolLinExp((*i).into()),
            Self::Lit(l) => BoolLinExp(l.0.into()),
        }
    }
}
#[pymethods]
impl BoolLinCon {
fn __str__(&self) -> String {
self.0.to_string()
}
}
#[pymethods]
impl BoolLinExp {
fn __add__(&self, other: BoolLinArg) -> Self {
let mut res = self.clone();
res.__iadd__(other);
res
}
fn __radd__(&self, other: BoolLinArg) -> Self {
self.__add__(other)
}
fn __eq__(&self, other: i64) -> BoolLinCon {
BoolLinCon(BaseBoolLinCon::new(
self.0.clone(),
Comparator::Equal,
other,
))
}
fn __ge__(&self, other: i64) -> BoolLinCon {
BoolLinCon(BaseBoolLinCon::new(
self.0.clone(),
Comparator::GreaterEq,
other,
))
}
fn __gt__(&self, other: i64) -> BoolLinCon {
self.__ge__(other + 1)
}
fn __iadd__(&mut self, other: BoolLinArg) {
self.0 += other.as_bool_lin_exp().0;
}
fn __imul__(&mut self, other: i64) {
self.0 *= other;
}
fn __isub__(&mut self, other: BoolLinArg) {
self.0 -= other.as_bool_lin_exp().0;
}
fn __le__(&self, other: i64) -> BoolLinCon {
BoolLinCon(BaseBoolLinCon::new(
self.0.clone(),
Comparator::LessEq,
other,
))
}
fn __lt__(&self, other: i64) -> BoolLinCon {
self.__le__(other - 1)
}
fn __mul__(&self, other: i64) -> Self {
let mut res = self.clone();
res.__imul__(other);
res
}
fn __rmul__(&self, other: i64) -> Self {
self.__mul__(other)
}
fn __neg__(&self) -> Self {
Self(-self.0.clone())
}
fn __str__(&self) -> String {
self.0.to_string()
}
fn __sub__(&self, other: BoolLinArg) -> Self {
let mut res = self.clone();
res.__isub__(other);
res
}
}
use itertools::Itertools;
#[pymethods]
impl CNFInner {
fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
let clause: Vec<Lit> = clause
.into_iter()
.map(|any| any.and_then(|lit| lit.extract::<Lit>()))
.try_collect()?;
self.0.add_clause(clause.into_iter().map(|lit| lit.0))?;
Ok(())
}
fn add_encoding(
&mut self,
con: ConstraintArg,
enc: Option<Encoder>,
conditions: Vec<Lit>,
) -> Result {
encode_constraint_with_conditions(&mut self.0, con, enc, conditions)
}
#[new]
fn new() -> Self {
Self(Default::default())
}
fn new_vars(&mut self, num_vars: usize) -> Vec<Lit> {
self.0
.new_var_range(num_vars)
.into_iter()
.map(|lit| Lit(lit.into()))
.collect()
}
fn to_dimacs(&self) -> String {
self.0.to_string()
}
}
#[pymethods]
impl Formula {
fn __and__(&self, other: FormulaArg) -> Self {
Self(self.0.clone() & other.as_formula())
}
fn __rand__(&self, other: FormulaArg) -> Self {
self.__and__(other)
}
fn __eq__(&self, other: FormulaArg) -> Self {
use BaseFormula::*;
Formula(Equiv(vec![self.0.clone(), other.as_formula()]))
}
fn __ge__(&self, other: FormulaArg) -> Self {
use BaseFormula::*;
Self(Implies(other.as_formula().into(), self.0.clone().into()))
}
fn __gt__(&self, other: FormulaArg) -> Self {
Self(self.0.clone() & !other.as_formula())
}
fn __le__(&self, other: FormulaArg) -> Self {
use BaseFormula::*;
Self(Implies(self.0.clone().into(), other.as_formula().into()))
}
fn __lt__(&self, other: FormulaArg) -> Self {
Self(!self.0.clone() & other.as_formula())
}
fn __invert__(&self) -> Self {
Self(!self.0.clone())
}
fn __ne__(&self, other: FormulaArg) -> Self {
self.__xor__(other)
}
fn __or__(&self, other: FormulaArg) -> Self {
Formula(self.0.clone() | other.as_formula())
}
fn __ror__(&self, other: FormulaArg) -> Self {
self.__or__(other)
}
fn __str__(&self) -> String {
self.0.to_string()
}
fn __xor__(&self, other: FormulaArg) -> Self {
Formula(self.0.clone() ^ other.as_formula())
}
fn __rxor__(&self, other: FormulaArg) -> Self {
self.__xor__(other)
}
}
impl FormulaArg {
fn as_formula(&self) -> BaseFormula<BoolVal> {
use BaseFormula::*;
match self {
FormulaArg::Const(b) => Atom(BoolVal::Const(*b)),
FormulaArg::Formula(formula) => formula.0.clone(),
FormulaArg::Lit(lit) => lit.as_formula(),
}
}
}
impl Lit {
fn as_bool_lin_exp(&self) -> BoolLinExp {
BoolLinExp(self.0.into())
}
fn as_formula(&self) -> BaseFormula<BoolVal> {
BaseFormula::Atom(self.0.into())
}
}
#[pymethods]
impl Lit {
    // ---- Pseudo-Boolean arithmetic: the literal is lifted into a `BoolLinExp` first. ----

    /// `self + other` as a boolean linear expression.
    fn __add__(&self, other: BoolLinArg) -> BoolLinExp {
        let exp = self.as_bool_lin_exp();
        exp.__add__(other)
    }
    /// `other + self`; addition commutes.
    fn __radd__(&self, other: BoolLinArg) -> BoolLinExp {
        self.as_bool_lin_exp().__add__(other)
    }
    /// `self - other` as a boolean linear expression.
    fn __sub__(&self, other: BoolLinArg) -> BoolLinExp {
        let exp = self.as_bool_lin_exp();
        exp.__sub__(other)
    }
    /// `self * other` as a boolean linear expression.
    fn __mul__(&self, other: i64) -> BoolLinExp {
        let exp = self.as_bool_lin_exp();
        exp.__mul__(other)
    }
    /// `other * self`; scalar multiplication commutes.
    fn __rmul__(&self, other: i64) -> BoolLinExp {
        self.as_bool_lin_exp().__mul__(other)
    }

    // ---- Propositional logic: the literal is lifted into an atomic `Formula` first. ----

    /// `self & other`.
    fn __and__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__and__(other)
    }
    /// `other & self`; conjunction commutes.
    fn __rand__(&self, other: FormulaArg) -> Formula {
        self.__and__(other)
    }
    /// `self | other`.
    fn __or__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__or__(other)
    }
    /// `other | self`; disjunction commutes.
    fn __ror__(&self, other: FormulaArg) -> Formula {
        Formula(self.as_formula()).__or__(other)
    }
    /// `self ^ other`.
    fn __xor__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__xor__(other)
    }
    /// `other ^ self`; exclusive or commutes.
    fn __rxor__(&self, other: FormulaArg) -> Formula {
        Formula(self.as_formula()).__xor__(other)
    }
    /// `~self`: the negated literal (still a `Lit`, not a `Formula`).
    fn __invert__(&self) -> Self {
        let Self(lit) = *self;
        Self(!lit)
    }

    // ---- Comparisons build formulas, with the same meaning as on `Formula`. ----

    /// `self == other`: equivalence formula.
    fn __eq__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__eq__(other)
    }
    /// `self != other`: exclusive-or formula.
    fn __ne__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__ne__(other)
    }
    /// `self >= other`: `other` implies `self`.
    fn __ge__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__ge__(other)
    }
    /// `self > other`: `self` and not `other`.
    fn __gt__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__gt__(other)
    }
    /// `self <= other`: `self` implies `other`.
    fn __le__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__le__(other)
    }
    /// `self < other`: not `self` and `other`.
    fn __lt__(&self, other: FormulaArg) -> Formula {
        let lhs = Formula(self.as_formula());
        lhs.__lt__(other)
    }

    // ---- Conversions and inspection ----

    /// Integer value of the literal, as converted by pindakaas.
    fn __int__(&self) -> i32 {
        self.0.into()
    }
    /// Human-readable rendering of the literal.
    fn __str__(&self) -> String {
        self.0.to_string()
    }
    /// Whether this literal is the negation of its variable.
    fn is_negated(&self) -> bool {
        self.0.is_negated()
    }
    /// The literal's underlying variable, converted back into a `Lit`.
    fn var(&self) -> Self {
        Self(self.0.var().into())
    }
}
impl Display for Lit {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
#[pymethods]
impl WCNFInner {
fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
let clause: Vec<Lit> = clause
.into_iter()
.map(|any| any.and_then(|lit| lit.extract::<Lit>()))
.try_collect()?;
self.0.add_clause(clause.into_iter().map(|lit| lit.0))?;
Ok(())
}
fn add_encoding(
&mut self,
con: ConstraintArg,
enc: Option<Encoder>,
conditions: Vec<Lit>,
) -> Result {
encode_constraint_with_conditions(&mut self.0, con, enc, conditions)
}
fn add_weighted_clause(&mut self, clause: Bound<'_, PyIterator>, weight: i64) -> Result {
let clause: Vec<Lit> = clause
.into_iter()
.map(|any| any.and_then(|lit| lit.extract::<Lit>()))
.try_collect()?;
self.0
.add_weighted_clause(clause.into_iter().map(|lit| lit.0), weight)?;
Ok(())
}
#[new]
fn new() -> Self {
Self(Default::default())
}
fn new_vars(&mut self, num_vars: usize) -> Vec<Lit> {
self.0
.new_var_range(num_vars)
.into_iter()
.map(|lit| Lit(lit.into()))
.collect()
}
fn to_dimacs(&self) -> String {
self.0.to_string()
}
}
#[pymodule]
mod solver {
    use std::{
        collections::HashMap,
        sync::Mutex,
        time::{Duration, Instant},
    };
    use itertools::Itertools;
    use pindakaas::{
        solver::{
            cadical::Cadical, FailedAssumtions, SlvTermSignal, SolveAssuming, SolveResult,
            TermCallback,
        },
        ClauseDatabase, ClauseDatabaseTools, Valuation,
    };
    use pyo3::{prelude::*, types::PyIterator};
    use super::{encode_constraint_with_conditions, Result};
    use crate::pindakaas::{ConstraintArg, Encoder, Lit};

    /// Python handle for a CaDiCaL SAT solver.
    ///
    /// The solver sits behind a `Mutex`. Every method locks it with `?`, so a poisoned lock is
    /// reported to Python as an exception rather than a panic.
    #[pyclass]
    #[derive(Debug, Default)]
    struct CaDiCaLInner(Mutex<Cadical>);

    /// Outcome of `CaDiCaLInner.solve_assuming`.
    #[pyclass(eq, eq_int)]
    #[derive(Clone, Copy, Debug, PartialEq)]
    enum Status {
        SATISFIED,
        UNSATISFIABLE,
        UNKNOWN,
    }

    /// Build a termination callback that returns `Terminate` once `dur` has elapsed since this
    /// function was called, and `Continue` before that.
    ///
    /// The deadline uses the monotonic `Instant` clock, so adjustments to the system wall clock
    /// cannot cut the limit short or extend it (`SystemTime` offers no such guarantee). If
    /// `now + dur` is not representable (e.g. an enormous `timedelta`), the deadline is treated
    /// as never reached instead of panicking on overflow.
    fn dur_term_fn(dur: Duration) -> impl Fn() -> SlvTermSignal + 'static {
        let deadline = Instant::now().checked_add(dur);
        move || match deadline {
            Some(deadline) if Instant::now() > deadline => SlvTermSignal::Terminate,
            _ => SlvTermSignal::Continue,
        }
    }

    /// Register this submodule in `sys.modules` under `pindakaas.pindakaas.solver`.
    ///
    /// This is presumably so that the dotted import path works, since nested PyO3 modules are
    /// not added to `sys.modules` automatically. The GIL is already held here, so the `Python`
    /// token is taken from `m` rather than re-acquired.
    #[pymodule_init]
    fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
        m.py()
            .import("sys")?
            .getattr("modules")?
            .set_item("pindakaas.pindakaas.solver", m)
    }

    #[pymethods]
    impl CaDiCaLInner {
        /// Add one clause. `clause` is a Python iterator of `Lit`s; the first item that fails
        /// to extract aborts the call with that error.
        fn add_clause(&mut self, clause: Bound<'_, PyIterator>) -> Result {
            let clause: Vec<Lit> = clause
                .into_iter()
                .map(|any| any.and_then(|lit| lit.extract::<Lit>()))
                .try_collect()?;
            let mut guard = self.0.lock()?;
            guard.add_clause(clause.into_iter().map(|lit| lit.0))?;
            Ok(())
        }
        /// Encode `con` (optionally with encoder `enc`, under `conditions`) into the solver.
        ///
        /// A poisoned lock is raised as a Python exception, consistent with the other methods
        /// (this previously used `lock().unwrap()` and panicked).
        fn add_encoding(
            &mut self,
            con: ConstraintArg,
            enc: Option<Encoder>,
            conditions: Vec<Lit>,
        ) -> Result {
            let mut guard = self.0.lock()?;
            encode_constraint_with_conditions(&mut *guard, con, enc, conditions)
        }
        /// Create a fresh solver.
        #[new]
        fn new() -> Self {
            Self(Default::default())
        }
        /// Allocate `num_vars` fresh variables and return them as positive literals.
        fn new_vars(&mut self, num_vars: usize) -> Result<Vec<Lit>> {
            let mut guard = self.0.lock()?;
            Ok(guard
                .new_var_range(num_vars)
                .into_iter()
                .map(|lit| Lit(lit.into()))
                .collect())
        }
        /// Install (`Some`) a wall-time limit for subsequent solves, measured from this call.
        ///
        /// `None` is passed through to `set_terminate_callback` as `None`, presumably clearing
        /// any previous limit; confirm against the pindakaas docs.
        fn set_time_limit(&mut self, limit: Option<Duration>) -> Result {
            let mut guard = self.0.lock()?;
            guard.set_terminate_callback(limit.map(dur_term_fn));
            Ok(())
        }
        /// Solve under `assumptions`.
        ///
        /// Returns the status and a map keyed by the integer form of variables or literals:
        /// - `SATISFIED`: every emitted variable mapped to its value in the solution.
        /// - `UNSATISFIABLE`: every assumption literal mapped to whether it is reported as
        ///   failed (`fail.fail(lit)`).
        /// - `UNKNOWN` (e.g. the time limit was hit): an empty map.
        fn solve_assuming(
            &self,
            assumptions: Vec<Lit>,
        ) -> Result<(Status, HashMap<i32, bool>)> {
            let mut guard = self.0.lock()?;
            // Snapshot the variables before solving; they are queried in the SAT branch.
            let vars = guard.emitted_vars();
            let outcome = match guard.solve_assuming(assumptions.iter().map(|&lit| lit.0)) {
                SolveResult::Satisfied(sol) => (
                    Status::SATISFIED,
                    vars.into_iter()
                        .map(|var| (var.into(), sol.value(var.into())))
                        .collect(),
                ),
                SolveResult::Unsatisfiable(fail) => (
                    Status::UNSATISFIABLE,
                    assumptions
                        .iter()
                        .map(|&lit| (lit.0.into(), fail.fail(lit.0)))
                        .collect(),
                ),
                SolveResult::Unknown => (Status::UNKNOWN, HashMap::new()),
            };
            Ok(outcome)
        }
    }
}
}