ac_library/
twosat.rs

1//! A 2-SAT Solver.
2use crate::internal_scc;
3
4/// A 2-SAT Solver.
5///
6/// For variables $x_0, x_1, \ldots, x_{N - 1}$ and clauses with from
7///
8/// \\[
9///   (x_i = f) \lor (x_j = g)
10/// \\]
11///
12/// it decides whether there is a truth assignment that satisfies all clauses.
13///
14/// # Example
15///
16/// ```
17/// #![allow(clippy::many_single_char_names)]
18///
19/// use ac_library::TwoSat;
20/// use proconio::{input, marker::Bytes, source::once::OnceSource};
21///
22/// input! {
23///     from OnceSource::from(
24///         "3\n\
25///          3\n\
26///          a b\n\
27///          !b c\n\
28///          !a !a\n",
29///     ),
30///     n: usize,
31///     pqs: [(Bytes, Bytes)],
32/// }
33///
34/// let mut twosat = TwoSat::new(n);
35///
36/// for (p, q) in pqs {
37///     fn parse(s: &[u8]) -> (usize, bool) {
38///         match *s {
39///             [c] => ((c - b'a').into(), true),
40///             [b'!', c] => ((c - b'a').into(), false),
41///             _ => unreachable!(),
42///         }
43///     }
44///     let ((i, f), (j, g)) = (parse(&p), parse(&q));
45///     twosat.add_clause(i, f, j, g);
46/// }
47///
48/// assert!(twosat.satisfiable());
49/// assert_eq!(twosat.answer(), [false, true, true]);
50/// ```
51#[derive(Clone, Debug)]
52pub struct TwoSat {
53    n: usize,
54    scc: internal_scc::SccGraph,
55    answer: Vec<bool>,
56}
57impl TwoSat {
58    /// Creates a new `TwoSat` of `n` variables and 0 clauses.
59    ///
60    /// # Constraints
61    ///
62    /// - $0 \leq n \leq 10^8$
63    ///
64    /// # Complexity
65    ///
66    /// - $O(n)$
67    pub fn new(n: usize) -> Self {
68        TwoSat {
69            n,
70            answer: vec![false; n],
71            scc: internal_scc::SccGraph::new(2 * n),
72        }
73    }
74    /// Adds a clause $(x_i = f) \lor (x_j = g)$.
75    ///
76    /// # Constraints
77    ///
78    /// - $0 \leq i < n$
79    /// - $0 \leq j < n$
80    ///
81    /// # Panics
82    ///
83    /// Panics if the above constraints are not satisfied.
84    ///
85    /// # Complexity
86    ///
87    /// - $O(1)$ amortized
88    pub fn add_clause(&mut self, i: usize, f: bool, j: usize, g: bool) {
89        assert!(i < self.n && j < self.n);
90        self.scc.add_edge(2 * i + !f as usize, 2 * j + g as usize);
91        self.scc.add_edge(2 * j + !g as usize, 2 * i + f as usize);
92    }
93    /// Returns whether there is a truth assignment that satisfies all clauses.
94    ///
95    /// # Complexity
96    ///
97    /// - $O(n + m)$ where $m$ is the number of added clauses
98    pub fn satisfiable(&mut self) -> bool {
99        let id = self.scc.scc_ids().1;
100        for i in 0..self.n {
101            if id[2 * i] == id[2 * i + 1] {
102                return false;
103            }
104            self.answer[i] = id[2 * i] < id[2 * i + 1];
105        }
106        true
107    }
108    /// Returns a truth assignment that satisfies all clauses **of the last call of [`satisfiable`]**.
109    ///
110    /// # Constraints
111    ///
112    /// - [`satisfiable`] is called after adding all clauses and it has returned `true`.
113    ///
114    /// # Complexity
115    ///
116    /// - $O(n)$
117    ///
118    /// [`satisfiable`]: #method.satisfiable
119    pub fn answer(&self) -> &[bool] {
120        &self.answer
121    }
122}
123
#[cfg(test)]
mod tests {
    #![allow(clippy::many_single_char_names)]
    use super::*;

    /// Builds the 2-SAT instance of AtCoder practice2 H: variable `i` being true
    /// places flag `i` at `x[i]`, false places it at `y[i]`, and every pair of
    /// placements closer than `d` is forbidden by a clause.
    fn build_flags_instance(x: &[i32], y: &[i32], d: i32) -> TwoSat {
        let n = x.len();
        let mut t = TwoSat::new(n);
        for i in 0..n {
            for j in i + 1..n {
                if (x[i] - x[j]).abs() < d {
                    t.add_clause(i, false, j, false);
                }
                if (x[i] - y[j]).abs() < d {
                    t.add_clause(i, false, j, true);
                }
                if (y[i] - x[j]).abs() < d {
                    t.add_clause(i, true, j, false);
                }
                if (y[i] - y[j]).abs() < d {
                    t.add_clause(i, true, j, true);
                }
            }
        }
        t
    }

    #[test]
    fn solve_alpc_h_sample1() {
        // https://atcoder.jp/contests/practice2/tasks/practice2_h
        let (x, y, d) = ([1, 2, 0i32], [4, 5, 6], 2);
        let mut t = build_flags_instance(&x, &y, d);
        assert!(t.satisfiable());

        let mut res: Vec<i32> = t
            .answer()
            .iter()
            .zip(x.iter().zip(y.iter()))
            .map(|(&v, (&xi, &yi))| if v { xi } else { yi })
            .collect();

        // Check the min distance between flags.
        res.sort_unstable();
        assert!(res.windows(2).all(|w| w[1] - w[0] >= d));
    }

    #[test]
    fn solve_alpc_h_sample2() {
        // https://atcoder.jp/contests/practice2/tasks/practice2_h
        let (x, y, d) = ([1, 2, 0i32], [4, 5, 6], 3);
        let mut t = build_flags_instance(&x, &y, d);
        assert!(!t.satisfiable());
    }

    #[test]
    fn no_variables_is_satisfiable() {
        let mut t = TwoSat::new(0);
        assert!(t.satisfiable());
        assert!(t.answer().is_empty());
    }

    #[test]
    #[should_panic]
    fn add_clause_out_of_range_panics() {
        TwoSat::new(1).add_clause(0, true, 1, true);
    }
}