ac_library/twosat.rs
1//! A 2-SAT Solver.
2use crate::internal_scc;
3
4/// A 2-SAT Solver.
5///
6/// For variables $x_0, x_1, \ldots, x_{N - 1}$ and clauses with from
7///
8/// \\[
9/// (x_i = f) \lor (x_j = g)
10/// \\]
11///
12/// it decides whether there is a truth assignment that satisfies all clauses.
13///
14/// # Example
15///
16/// ```
17/// #![allow(clippy::many_single_char_names)]
18///
19/// use ac_library::TwoSat;
20/// use proconio::{input, marker::Bytes, source::once::OnceSource};
21///
22/// input! {
23/// from OnceSource::from(
24/// "3\n\
25/// 3\n\
26/// a b\n\
27/// !b c\n\
28/// !a !a\n",
29/// ),
30/// n: usize,
31/// pqs: [(Bytes, Bytes)],
32/// }
33///
34/// let mut twosat = TwoSat::new(n);
35///
36/// for (p, q) in pqs {
37/// fn parse(s: &[u8]) -> (usize, bool) {
38/// match *s {
39/// [c] => ((c - b'a').into(), true),
40/// [b'!', c] => ((c - b'a').into(), false),
41/// _ => unreachable!(),
42/// }
43/// }
44/// let ((i, f), (j, g)) = (parse(&p), parse(&q));
45/// twosat.add_clause(i, f, j, g);
46/// }
47///
48/// assert!(twosat.satisfiable());
49/// assert_eq!(twosat.answer(), [false, true, true]);
50/// ```
51#[derive(Clone, Debug)]
52pub struct TwoSat {
53 n: usize,
54 scc: internal_scc::SccGraph,
55 answer: Vec<bool>,
56}
57impl TwoSat {
58 /// Creates a new `TwoSat` of `n` variables and 0 clauses.
59 ///
60 /// # Constraints
61 ///
62 /// - $0 \leq n \leq 10^8$
63 ///
64 /// # Complexity
65 ///
66 /// - $O(n)$
67 pub fn new(n: usize) -> Self {
68 TwoSat {
69 n,
70 answer: vec![false; n],
71 scc: internal_scc::SccGraph::new(2 * n),
72 }
73 }
74 /// Adds a clause $(x_i = f) \lor (x_j = g)$.
75 ///
76 /// # Constraints
77 ///
78 /// - $0 \leq i < n$
79 /// - $0 \leq j < n$
80 ///
81 /// # Panics
82 ///
83 /// Panics if the above constraints are not satisfied.
84 ///
85 /// # Complexity
86 ///
87 /// - $O(1)$ amortized
88 pub fn add_clause(&mut self, i: usize, f: bool, j: usize, g: bool) {
89 assert!(i < self.n && j < self.n);
90 self.scc.add_edge(2 * i + !f as usize, 2 * j + g as usize);
91 self.scc.add_edge(2 * j + !g as usize, 2 * i + f as usize);
92 }
93 /// Returns whether there is a truth assignment that satisfies all clauses.
94 ///
95 /// # Complexity
96 ///
97 /// - $O(n + m)$ where $m$ is the number of added clauses
98 pub fn satisfiable(&mut self) -> bool {
99 let id = self.scc.scc_ids().1;
100 for i in 0..self.n {
101 if id[2 * i] == id[2 * i + 1] {
102 return false;
103 }
104 self.answer[i] = id[2 * i] < id[2 * i + 1];
105 }
106 true
107 }
108 /// Returns a truth assignment that satisfies all clauses **of the last call of [`satisfiable`]**.
109 ///
110 /// # Constraints
111 ///
112 /// - [`satisfiable`] is called after adding all clauses and it has returned `true`.
113 ///
114 /// # Complexity
115 ///
116 /// - $O(n)$
117 ///
118 /// [`satisfiable`]: #method.satisfiable
119 pub fn answer(&self) -> &[bool] {
120 &self.answer
121 }
122}
123
#[cfg(test)]
mod tests {
    #![allow(clippy::many_single_char_names)]
    use super::*;

    /// Builds the 2-SAT instance of AtCoder Library Practice Contest H
    /// (https://atcoder.jp/contests/practice2/tasks/practice2_h): flag `i` is
    /// placed at `x[i]` when variable `i` is `true` and at `y[i]` otherwise, and
    /// every pair of flags must end up at least `d` apart.
    fn build_alpc_h(d: i32, x: &[i32], y: &[i32]) -> TwoSat {
        assert_eq!(x.len(), y.len());
        let n = x.len();
        let mut t = TwoSat::new(n);
        for i in 0..n {
            for j in i + 1..n {
                if (x[i] - x[j]).abs() < d {
                    t.add_clause(i, false, j, false);
                }
                if (x[i] - y[j]).abs() < d {
                    t.add_clause(i, false, j, true);
                }
                if (y[i] - x[j]).abs() < d {
                    t.add_clause(i, true, j, false);
                }
                if (y[i] - y[j]).abs() < d {
                    t.add_clause(i, true, j, true);
                }
            }
        }
        t
    }

    /// Flag positions chosen by `answer`: `x[i]` when `true`, `y[i]` otherwise.
    fn placements(answer: &[bool], x: &[i32], y: &[i32]) -> Vec<i32> {
        answer
            .iter()
            .zip(x.iter().zip(y))
            .map(|(&v, (&xi, &yi))| if v { xi } else { yi })
            .collect()
    }

    /// Smallest gap between any two positions (`i32::MAX` for fewer than two).
    fn min_distance(mut positions: Vec<i32>) -> i32 {
        positions.sort_unstable();
        positions
            .windows(2)
            .map(|w| w[1] - w[0])
            .min()
            .unwrap_or(i32::MAX)
    }

    /// Whether every clause `(i, f, j, g)` holds under `answer`.
    fn satisfies(answer: &[bool], clauses: &[(usize, bool, usize, bool)]) -> bool {
        clauses
            .iter()
            .all(|&(i, f, j, g)| answer[i] == f || answer[j] == g)
    }

    #[test]
    fn solve_alpc_h_sample1() {
        let d = 2;
        let x = [1, 2, 0i32];
        let y = [4, 5, 6];

        let mut t = build_alpc_h(d, &x, &y);
        assert!(t.satisfiable());

        // Check the min distance between flags.
        assert!(min_distance(placements(t.answer(), &x, &y)) >= d);
    }

    #[test]
    fn solve_alpc_h_sample2() {
        let d = 3;
        let x = [1, 2, 0i32];
        let y = [4, 5, 6];

        let mut t = build_alpc_h(d, &x, &y);
        assert!(!t.satisfiable());
    }

    #[test]
    fn empty_instance_is_satisfiable() {
        let mut t = TwoSat::new(0);
        assert!(t.satisfiable());
        assert!(t.answer().is_empty());
    }

    #[test]
    fn unique_solution_is_found() {
        // (a ∨ b) ∧ (¬b ∨ c) ∧ (¬a ∨ ¬a): forces a = false, then b = true, then c = true.
        let clauses = [(0, true, 1, true), (1, false, 2, true), (0, false, 0, false)];
        let mut t = TwoSat::new(3);
        for &(i, f, j, g) in &clauses {
            t.add_clause(i, f, j, g);
        }
        assert!(t.satisfiable());
        assert!(satisfies(t.answer(), &clauses));
        assert_eq!(t.answer(), [false, true, true]);
    }

    #[test]
    fn contradictory_unit_clauses_are_unsatisfiable() {
        let mut t = TwoSat::new(1);
        t.add_clause(0, true, 0, true);
        t.add_clause(0, false, 0, false);
        assert!(!t.satisfiable());
    }

    #[test]
    #[should_panic]
    fn add_clause_rejects_out_of_range_variable() {
        let mut t = TwoSat::new(2);
        t.add_clause(0, true, 2, false);
    }
}