// minisat 0.4.1
//
// MiniSat Rust interface. Solves a Boolean satisfiability problem given in conjunctive normal form.
use super::{Bool, Solver, Model, ModelEq, ModelOrd, ModelValue};
use std::iter::once;

/// An unsigned integer whose binary digits are `Bool`s.
///
/// Digits are stored least significant first: the digit at index `i`
/// carries weight `2^i`. The width is simply the length of the vector.
#[derive(Debug,Clone)]
pub struct Binary(Vec<Bool>);

impl Binary {
    /// Allocates an unknown number wide enough to hold `size`.
    ///
    /// The width equals the number of binary digits of `size` (zero digits
    /// when `size` is 0). Every digit is a literal freshly obtained from
    /// `solver.new_lit()`.
    pub fn new(solver: &mut Solver, size: usize) -> Binary {
        // Count how many halvings it takes for `size` to reach zero.
        let mut width = 0;
        let mut rest = size;
        while rest != 0 {
            rest >>= 1;
            width += 1;
        }

        let mut digits = Vec::with_capacity(width);
        for _ in 0..width {
            digits.push(solver.new_lit());
        }
        Binary(digits)
    }

    /// Builds the constant `x`, least significant digit first.
    ///
    /// No leading zero digits are stored, so `0` becomes the empty number.
    pub fn constant(x: usize) -> Binary {
        let mut digits = Vec::new();
        let mut rest = x;
        while rest != 0 {
            let is_set = (rest & 1) == 1;
            digits.push(is_set.into());
            rest >>= 1;
        }
        Binary(digits)
    }

    /// Returns the digit-wise complement: every stored digit is negated and
    /// the width stays the same.
    pub fn invert(&self) -> Binary {
        let mut flipped = Vec::with_capacity(self.0.len());
        for &digit in &self.0 {
            flipped.push(!digit);
        }
        Binary(flipped)
    }

    /// Wraps the given digits as a number; the first item yielded by `xs`
    /// becomes the digit at index 0.
    pub fn from_list<I: IntoIterator<Item = Bool>>(xs: I) -> Binary {
        let digits: Vec<Bool> = xs.into_iter().collect();
        Binary(digits)
    }

    /// Sums the digits of `self` as if each were a one-digit number.
    ///
    /// Every digit is handed to `add_bits` with weight index 0, so the
    /// result encodes how many digits are set.
    pub fn count(&self, solver: &mut Solver) -> Binary {
        let units = self.0.iter().map(|&digit| (0, digit)).collect();
        Binary::add_bits(solver, units)
    }

    /// Returns the sum of `self` and `other`, built through `add_list`.
    pub fn add(&self, solver: &mut Solver, other: &Binary) -> Binary {
        let operands = [self, other];
        Binary::add_list(solver, operands.iter().cloned())
    }

    /// Returns the sum of all numbers in `xs`.
    ///
    /// Each digit is tagged with its index inside its own number and the
    /// combined list is passed to `add_bits`.
    pub fn add_list<'a, I: IntoIterator<Item = &'a Binary>>(solver: &mut Solver, xs: I) -> Binary {
        let mut weighted = Vec::new();
        for number in xs {
            for (position, &digit) in number.0.iter().enumerate() {
                weighted.push((position, digit));
            }
        }
        Binary::add_bits(solver, weighted)
    }

    pub fn add_bits(solver: &mut Solver, xs: Vec<(usize, Bool)>) -> Binary {
        #[derive(Debug)]
        struct Item {
            bit: usize,
            val: Bool,
        };
        let mut xs = xs.into_iter()
            .map(|(bit, val)| {
                Item {
                    bit: bit,
                    val: val,
                }
            })
            .collect::<Vec<_>>();
        xs.sort_by_key(|&Item { bit, .. }| bit);
        let mut out = Vec::new();
        let mut i = 0;

        while xs.len() > 0 {
            if i < xs[0].bit {
                out.push(false.into());
                i = i + 1;
            } else if xs.len() >= 3 && xs[0].bit == xs[1].bit && xs[1].bit == xs[2].bit {
                let Item { bit, val: a_val } = xs.remove(0);
                let Item { bit: _, val: b_val } = xs.remove(0);
                let Item { bit: _, val: c_val } = xs.remove(0);
                let (v, c) = Binary::full_add(solver, a_val, b_val, c_val);
                xs.push(Item { bit: bit, val: v });
                xs.push(Item {
                    bit: bit + 1,
                    val: c,
                });
                xs.sort_by_key(|&Item { bit, .. }| bit);
                i = bit;
            } else if xs.len() >= 2 && xs[0].bit == xs[1].bit {
                let Item { bit, val: a_val } = xs.remove(0);
                let Item { bit: _, val: b_val } = xs.remove(0);
                let (v, c) = Binary::full_add(solver, a_val, b_val, false.into());
                out.push(v);
                xs.push(Item {
                    bit: bit + 1,
                    val: c,
                });
                xs.sort_by_key(|&Item { bit, .. }| bit);
                i = bit + 1;
            } else {
                let Item { bit, val } = xs.remove(0);
                out.push(val);
                i = bit + 1;
            }
        }

        Binary(out)
    }

    /// One adder cell: returns `(sum, carry)` for the inputs `a`, `b`, `c`.
    ///
    /// The sum comes from `solver.xor_literal` over the three inputs, the
    /// carry from `at_least_two`.
    fn full_add(solver: &mut Solver, a: Bool, b: Bool, c: Bool) -> (Bool, Bool) {
        let inputs = once(a).chain(once(b)).chain(once(c));
        let sum = solver.xor_literal(inputs);
        let carry = Binary::at_least_two(solver, a, b, c);
        (sum, carry)
    }

    /// Returns a `Bool` that holds when at least two of `a`, `b`, `c` hold.
    ///
    /// Cheap cases are tried first, in this order: a constant-true input,
    /// a constant-false input, two identical inputs, two complementary
    /// inputs. Only when none applies is a fresh literal introduced and
    /// tied to the inputs with six clauses.
    fn at_least_two(solver: &mut Solver, a: Bool, b: Bool, c: Bool) -> Bool {
        let yes: Bool = true.into();
        let no: Bool = false.into();

        // One input is constant true: any one of the other two suffices.
        if a == yes {
            return solver.or_literal(once(b).chain(once(c)));
        }
        if b == yes {
            return solver.or_literal(once(a).chain(once(c)));
        }
        if c == yes {
            return solver.or_literal(once(a).chain(once(b)));
        }

        // One input is constant false: both of the other two are needed.
        if a == no {
            return solver.and_literal(once(b).chain(once(c)));
        }
        if b == no {
            return solver.and_literal(once(a).chain(once(c)));
        }
        if c == no {
            return solver.and_literal(once(a).chain(once(b)));
        }

        // Two identical inputs decide the outcome on their own.
        if a == b {
            return a;
        }
        if b == c {
            return b;
        }
        if a == c {
            return c;
        }

        // Two complementary inputs cancel out; the third one decides.
        if a == !b {
            return c;
        }
        if b == !c {
            return a;
        }
        if a == !c {
            return b;
        }

        // General case: define a fresh literal `out` by clauses.
        let out = solver.new_lit();
        let pairs = [(a, b), (a, c), (b, c)];
        // Any two true inputs force `out`.
        for &(x, y) in pairs.iter() {
            solver.add_clause(once(!x).chain(once(!y)).chain(once(out)));
        }
        // Any two false inputs forbid `out`.
        for &(x, y) in pairs.iter() {
            solver.add_clause(once(x).chain(once(y)).chain(once(!out)));
        }
        out
    }

    /// Returns the largest value this number can take, `2^width - 1`,
    /// where `width` is the number of stored digits.
    ///
    /// The result saturates at the maximum `usize` when the number has
    /// more digits than `usize` has bits, instead of overflowing.
    pub fn bound(&self) -> usize {
        // Each digit doubles the range and adds one: 0, 1, 3, 7, ...
        // Saturating steps keep wide numbers from overflowing.
        self.0
            .iter()
            .fold(0usize, |max, _| max.saturating_mul(2).saturating_add(1))
    }

    /// Multiplies by a single digit: every digit of `self` is combined
    /// with `y` through `solver.and_literal`. The width is unchanged.
    pub fn mul_digit(&self, solver: &mut Solver, y: Bool) -> Binary {
        let mut product = Vec::with_capacity(self.0.len());
        for &digit in &self.0 {
            product.push(solver.and_literal(once(digit).chain(once(y))));
        }
        Binary(product)
    }

    /// Returns the product of `self` and `other`.
    ///
    /// Every pair of digits (index `i` in `self`, index `j` in `other`) is
    /// combined with `solver.and_literal` into a partial product of weight
    /// index `i + j`; `add_bits` then sums all partial products.
    pub fn mul(&self, solver: &mut Solver, other: &Binary) -> Binary {
        let mut partials = Vec::new();
        for (i, &x) in self.0.iter().enumerate() {
            for (j, &y) in other.0.iter().enumerate() {
                let both = solver.and_literal(once(x).chain(once(y)));
                partials.push((i + j, both));
            }
        }
        Binary::add_bits(solver, partials)
    }
}

/// Reads a `Binary` back from a model as a plain `usize`.
impl<'a> ModelValue<'a> for Binary {
    type T = usize;

    /// Adds up `2^i` for every digit index `i` whose digit is true in `m`.
    fn value(&self, m: &Model) -> usize {
        let mut total = 0;
        for (position, digit) in self.0.iter().enumerate() {
            if m.value(digit) {
                total += (2 as usize).pow(position as u32);
            }
        }
        total
    }
}


/// Ordering constraints between two `Binary` numbers.
impl ModelOrd for Binary {
    /// Delegates to the `&[Bool]` implementation of `assert_less_or`.
    ///
    /// Both operands are first padded with constant-false digits to a
    /// common width and reversed, so the slices start with the most
    /// significant digit. `prefix` and `inclusive` are passed through
    /// unchanged.
    fn assert_less_or(solver: &mut Solver,
                      prefix: Vec<Bool>,
                      inclusive: bool,
                      a: &Binary,
                      b: &Binary) {
        /// Pads `bits` with false digits up to `width`, then reverses it.
        fn msb_first(bits: &[Bool], width: usize) -> Vec<Bool> {
            let mut padded = bits.to_vec();
            padded.resize(width, false.into());
            padded.reverse();
            padded
        }

        let width = a.0.len().max(b.0.len());
        let a_msb = msb_first(a.0.as_slice(), width);
        let b_msb = msb_first(b.0.as_slice(), width);

        <&[Bool]>::assert_less_or(solver,
                                  prefix,
                                  inclusive,
                                  &a_msb.as_slice(),
                                  &b_msb.as_slice());
    }
}



impl ModelEq for Binary {
    /// Constrains `a` and `b` to be equal digit by digit, each constraint
    /// being issued through `Bool::assert_equal_or` with a copy of `prefix`.
    ///
    /// The shorter number is treated as if extended with constant-false
    /// digits up to the width of the longer one.
    fn assert_equal_or(solver: &mut Solver, prefix: Vec<Bool>, a: &Binary, b: &Binary) {
        let zero: Bool = false.into();
        let width = a.0.len().max(b.0.len());
        for position in 0..width {
            let lhs = a.0.get(position).unwrap_or(&zero);
            let rhs = b.0.get(position).unwrap_or(&zero);
            Bool::assert_equal_or(solver, prefix.clone(), lhs, rhs);
        }
    }

    /// Constrains `a` and `b` to differ in at least one digit, unless some
    /// literal of `prefix` holds.
    ///
    /// The shorter number is treated as if extended with constant-false
    /// digits. To avoid repeating `prefix` for every digit, a chain of
    /// helper literals is used. With `q_i` read as "the difference lies at
    /// a digit above `i`", the emitted constraints are
    ///
    /// * digit 0: `prefix ∨ q_0 ∨ (a_0 ≠ b_0)`
    /// * digit i: `¬q_(i-1) ∨ q_i ∨ (a_i ≠ b_i)`
    /// * finally: `¬q_last`
    ///
    /// so together they are satisfiable exactly when `prefix` holds or
    /// some digit differs. When both numbers have no digits they are equal,
    /// and `prefix` itself is asserted as a clause.
    ///
    /// NOTE(review): this relies on `Bool::assert_not_equal_or(s, p, x, y)`
    /// asserting `p ∨ (x ≠ y)`, as its name suggests — confirm against the
    /// `Bool` implementation.
    fn assert_not_equal_or(solver: &mut Solver, mut prefix: Vec<Bool>, a: &Binary, b: &Binary) {
        let zero: Bool = false.into();
        let width = a.0.len().max(b.0.len());

        for i in 0..width {
            let ai = a.0.get(i).cloned().unwrap_or(zero);
            let bi = b.0.get(i).cloned().unwrap_or(zero);

            // Escape hatch for this digit: "a later digit differs".
            let later = solver.new_lit();
            prefix.push(later);
            Bool::assert_not_equal_or(solver, prefix, &ai, &bi);

            // The next digit only has to differ if this hatch was used.
            prefix = vec![!later];
        }

        // After the last digit there is no later digit left to differ
        // (or, with zero digits, only `prefix` can make the claim true).
        solver.add_clause(prefix);
    }
}