/// A set of byte values stored as a 256-bit bitmap in four `u64` words.
///
/// Byte `b` is a member when bit `b % 64` of word `b / 64` is set.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSet(pub [u64; 4]);

impl TSet {
    /// Builds a set whose four words all equal `v`.
    #[inline]
    pub const fn splat(v: u64) -> Self {
        TSet([v; 4])
    }

    /// Builds the set containing exactly the bytes listed in `bytes`.
    ///
    /// Repeated bytes are harmless; an empty slice yields the all-zero set.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        bytes.iter().fold(TSet([0; 4]), |mut acc, &byte| {
            // High two bits select the word, low six bits select the bit.
            let (word, bit) = (usize::from(byte >> 6), byte & 63);
            acc.0[word] |= 1u64 << bit;
            acc
        })
    }

    /// Returns `true` when byte `b` is a member of the set.
    #[inline(always)]
    pub fn contains_byte(&self, b: u8) -> bool {
        let word = self.0[usize::from(b >> 6)];
        (word >> (b & 63)) & 1 == 1
    }
}
23
24impl std::ops::Index<usize> for TSet {
25 type Output = u64;
26 #[inline]
27 fn index(&self, i: usize) -> &u64 {
28 &self.0[i]
29 }
30}
31
32impl std::ops::IndexMut<usize> for TSet {
33 #[inline]
34 fn index_mut(&mut self, i: usize) -> &mut u64 {
35 &mut self.0[i]
36 }
37}
38
39impl std::ops::BitAnd for TSet {
40 type Output = TSet;
41 #[inline]
42 fn bitand(self, rhs: TSet) -> TSet {
43 TSet([
44 self.0[0] & rhs.0[0],
45 self.0[1] & rhs.0[1],
46 self.0[2] & rhs.0[2],
47 self.0[3] & rhs.0[3],
48 ])
49 }
50}
51
52impl std::ops::BitAnd for &TSet {
53 type Output = TSet;
54 #[inline]
55 fn bitand(self, rhs: &TSet) -> TSet {
56 TSet([
57 self.0[0] & rhs.0[0],
58 self.0[1] & rhs.0[1],
59 self.0[2] & rhs.0[2],
60 self.0[3] & rhs.0[3],
61 ])
62 }
63}
64
65impl std::ops::BitOr for TSet {
66 type Output = TSet;
67 #[inline]
68 fn bitor(self, rhs: TSet) -> TSet {
69 TSet([
70 self.0[0] | rhs.0[0],
71 self.0[1] | rhs.0[1],
72 self.0[2] | rhs.0[2],
73 self.0[3] | rhs.0[3],
74 ])
75 }
76}
77
78impl std::ops::Not for TSet {
79 type Output = TSet;
80 #[inline]
81 fn not(self) -> TSet {
82 TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
83 }
84}
85
86impl std::ops::BitAnd<TSet> for &TSet {
88 type Output = TSet;
89 #[inline]
90 fn bitand(self, rhs: TSet) -> TSet {
91 TSet([
92 self.0[0] & rhs.0[0],
93 self.0[1] & rhs.0[1],
94 self.0[2] & rhs.0[2],
95 self.0[3] & rhs.0[3],
96 ])
97 }
98}
99
100impl std::ops::BitOr<TSet> for &TSet {
101 type Output = TSet;
102 #[inline]
103 fn bitor(self, rhs: TSet) -> TSet {
104 TSet([
105 self.0[0] | rhs.0[0],
106 self.0[1] | rhs.0[1],
107 self.0[2] | rhs.0[2],
108 self.0[3] | rhs.0[3],
109 ])
110 }
111}
112
113const EMPTY: TSet = TSet::splat(u64::MIN);
114const FULL: TSet = TSet::splat(u64::MAX);
115
/// Identifier of a set, wrapping a `u32`.
///
/// `Solver` uses the inner value as an index into its `array` of sets.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Id 0; `Solver::new` registers the empty set first, so it receives this id.
    pub const EMPTY: TSetId = Self(0);
    /// Id 1; `Solver::new` registers the full set second, so it receives this id.
    pub const FULL: TSetId = Self(1);
}
122
123use rustc_hash::FxHashMap;
124use std::collections::BTreeSet;
125
/// Interns `TSet` values: each distinct set stored through the solver is
/// assigned a `TSetId`, and the id can be turned back into the set.
pub struct Solver {
    /// Maps every set stored so far to the id it was assigned.
    cache: FxHashMap<TSet, TSetId>,
    /// The stored sets; a `TSetId`'s inner value is used as the index here.
    pub array: Vec<TSet>,
}
130
131impl Default for Solver {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl Solver {
138 pub fn new() -> Solver {
139 let mut inst = Self {
140 cache: FxHashMap::default(),
141 array: Vec::new(),
142 };
143 let _ = inst.init(Solver::empty()); let _ = inst.init(Solver::full()); inst
146 }
147
148 fn init(&mut self, inst: TSet) -> TSetId {
149 let new_id = TSetId(self.cache.len() as u32);
150 self.cache.insert(inst, new_id);
151 self.array.push(inst);
152 new_id
153 }
154
155 pub fn get_set(&self, set_id: TSetId) -> TSet {
156 self.array[set_id.0 as usize]
157 }
158
159 pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
160 &self.array[set_id.0 as usize]
161 }
162
163 pub fn get_id(&mut self, inst: TSet) -> TSetId {
164 match self.cache.get(&inst) {
165 Some(&id) => id,
166 None => self.init(inst),
167 }
168 }
169
170 pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
171 self.array[set_id.0 as usize][idx] & bit != 0
172 }
173
174 pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
175 let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
176 let mut rangestart: Option<u8> = None;
177 let mut prevchar: Option<u8> = None;
178 for i in 0..4 {
179 for j in 0..64 {
180 let nthbit = 1u64 << j;
181 if tset[i] & nthbit != 0 {
182 let cc = (i * 64 + j) as u8;
183 if rangestart.is_none() {
184 rangestart = Some(cc);
185 prevchar = Some(cc);
186 continue;
187 }
188
189 if let (Some(currstart), Some(currprev)) = (rangestart, prevchar) {
190 if currprev == cc - 1 {
191 prevchar = Some(cc);
192 continue;
193 }
194 ranges.insert((currstart, currprev));
195 rangestart = Some(cc);
196 prevchar = Some(cc);
197 }
198 }
199 }
200 }
201 if let (Some(start), Some(end)) = (rangestart, prevchar) {
202 ranges.insert((start, end));
203 }
204 ranges
205 }
206
/// Renders a single byte for pretty-printing.
///
/// Newline, carriage return and tab use backslash escapes, a fixed list of
/// punctuation characters is prefixed with a backslash, other ASCII
/// punctuation and alphanumerics print as themselves, and everything else
/// is written as `\xNN` (upper-case hex). With the `graphviz` feature
/// enabled, `\n`, `"`, `\r` and `\t` use the feature-specific spellings
/// below instead.
fn pp_byte(b: u8) -> String {
    let ch = b as char;
    if cfg!(feature = "graphviz") {
        // NOTE(review): these are raw strings, so `\u{201c}` is emitted as
        // those literal characters rather than U+201C — confirm intended.
        let special = match ch {
            '\n' => Some(r"\ṅ"),
            '"' => Some(r"\u{201c}"),
            '\r' => Some(r"\r"),
            '\t' => Some(r"\t"),
            _ => None,
        };
        if let Some(text) = special {
            return text.to_owned();
        }
    }
    match ch {
        '\n' => r"\n".to_owned(),
        '\r' => r"\r".to_owned(),
        '\t' => r"\t".to_owned(),
        ' ' => r" ".to_owned(),
        '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
        | '*' | '?' | '^' | '$' => format!("\\{}", ch),
        c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
        _ => format!("\\x{:02X}", b),
    }
}
229
230 fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
231 let display_range = |c, c2| {
232 if c == c2 {
233 Self::pp_byte(c)
234 } else if c.abs_diff(c2) == 1 {
235 format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
236 } else {
237 format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
238 }
239 };
240
241 if ranges.is_empty() {
242 return "\u{22a5}".to_owned();
243 }
244 if ranges.len() == 1 {
245 let (s, e) = ranges.iter().next().unwrap();
246 if s == e {
247 return Self::pp_byte(*s);
248 } else {
249 return ranges
250 .iter()
251 .map(|(s, e)| display_range(*s, *e))
252 .collect::<Vec<_>>()
253 .join("").to_string();
254 }
255 }
256 if ranges.len() > 20 {
257 return "\u{03c6}".to_owned();
258 }
259 ranges
260 .iter()
261 .map(|(s, e)| display_range(*s, *e))
262 .collect::<Vec<_>>()
263 .join("").to_string()
264 }
265
266 pub fn pp_first(&self, tset: &TSet) -> char {
267 let tryn1 = |i: usize| {
268 for j in 0..32 {
269 let nthbit = 1u64 << j;
270 if tset[i] & nthbit != 0 {
271 let cc = (i * 64 + j) as u8 as char;
272 return Some(cc);
273 }
274 }
275 None
276 };
277 let tryn2 = |i: usize| {
278 for j in 33..64 {
279 let nthbit = 1u64 << j;
280 if tset[i] & nthbit != 0 {
281 let cc = (i * 64 + j) as u8 as char;
282 return Some(cc);
283 }
284 }
285 None
286 };
287 tryn2(0)
289 .or_else(|| tryn2(1))
290 .or_else(|| tryn1(1))
291 .or_else(|| tryn1(2))
292 .or_else(|| tryn2(2))
293 .or_else(|| tryn1(3))
294 .or_else(|| tryn2(3))
295 .or_else(|| tryn1(0))
296 .unwrap_or('\u{22a5}')
297 }
298
299 pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
300 let tset = self.get_set(tset);
301 Self::pp_collect_ranges(&tset).into_iter().collect()
302 }
303
304 #[allow(unused)]
305 fn first_byte(tset: &TSet) -> u8 {
306 for i in 0..4 {
307 for j in 0..64 {
308 let nthbit = 1u64 << j;
309 if tset[i] & nthbit != 0 {
310 let cc = (i * 64 + j) as u8;
311 return cc;
312 }
313 }
314 }
315 0
316 }
317
318 pub fn pp(&self, tset: TSetId) -> String {
319 if tset == TSetId::FULL {
320 return "_".to_owned();
321 }
322 if tset == TSetId::EMPTY {
323 return "\u{22a5}".to_owned();
324 }
325 let tset = self.get_set(tset);
326 let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
327 let rstart = ranges.first().unwrap().0;
328 let rend = ranges.last().unwrap().1;
329 if ranges.len() >= 2 && rstart == 0 && rend == 255 {
330 let not_id = Self::not(&tset);
331 let not_ranges = Self::pp_collect_ranges(¬_id);
332 if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
333 return r".".to_owned();
334 }
335 let content = Self::pp_content(¬_ranges);
336 return format!("[^{}]", content);
337 }
338 if ranges.is_empty() {
339 return "\u{22a5}".to_owned();
340 }
341 if ranges.len() == 1 {
342 let (s, e) = ranges.iter().next().unwrap();
343 if s == e {
344 return Self::pp_byte(*s);
345 } else {
346 let content = Self::pp_content(&ranges);
347 return format!("[{}]", content);
348 }
349 }
350 let content = Self::pp_content(&ranges);
351 format!("[{}]", content)
352 }
353}
354
355impl Solver {
356 #[inline]
357 pub fn full() -> TSet {
358 FULL
359 }
360
361 #[inline]
362 pub fn empty() -> TSet {
363 EMPTY
364 }
365
366 #[inline]
367 pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
368 self.get_id(self.get_set(set1) | self.get_set(set2))
369 }
370
371 #[inline]
372 pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
373 self.get_id(self.get_set(set1) & self.get_set(set2))
374 }
375
376 #[inline]
377 pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
378 self.get_id(!self.get_set(set_id))
379 }
380
381 #[inline]
382 pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
383 self.and_id(set1, set2) != TSetId::EMPTY
384 }
385 #[inline]
386 pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
387 self.and_id(set1, set2) == TSetId::EMPTY
388 }
389
390 pub fn byte_count(&self, set_id: TSetId) -> u32 {
391 let tset = self.get_set(set_id);
392 (0..4).map(|i| tset[i].count_ones()).sum()
393 }
394
395 pub fn collect_bytes(&self, set_id: TSetId) -> Vec<u8> {
396 let tset = self.get_set(set_id);
397 let mut bytes = Vec::new();
398 for i in 0..4 {
399 let mut bits = tset[i];
400 while bits != 0 {
401 let j = bits.trailing_zeros() as usize;
402 bytes.push((i * 64 + j) as u8);
403 bits &= bits - 1;
404 }
405 }
406 bytes
407 }
408
409 pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
410 let tset = self.get_set(set_id);
411 let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
412 if total != 1 {
413 return None;
414 }
415 for i in 0..4 {
416 if tset[i] != 0 {
417 return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
418 }
419 }
420 None
421 }
422
423 #[inline]
424 pub fn is_empty_id(&self, set1: TSetId) -> bool {
425 set1 == TSetId::EMPTY
426 }
427
428 #[inline]
429 pub fn is_full_id(&self, set1: TSetId) -> bool {
430 set1 == TSetId::FULL
431 }
432
433 #[inline]
434 pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
435 let not_large = self.not_id(large_id);
436 self.and_id(small_id, not_large) == TSetId::EMPTY
437 }
438
439 pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
440 let mut result = TSet::splat(u64::MIN);
441 let nthbit = 1u64 << (byte % 64);
442 match byte {
443 0..=63 => {
444 result[0] = nthbit;
445 }
446 64..=127 => {
447 result[1] = nthbit;
448 }
449 128..=191 => {
450 result[2] = nthbit;
451 }
452 192..=255 => {
453 result[3] = nthbit;
454 }
455 }
456 self.get_id(result)
457 }
458
459 pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
460 let mut result = TSet::splat(u64::MIN);
461 for byte in start..=end {
462 let nthbit = 1u64 << (byte % 64);
463 match byte {
464 0..=63 => {
465 result[0] |= nthbit;
466 }
467 64..=127 => {
468 result[1] |= nthbit;
469 }
470 128..=191 => {
471 result[2] |= nthbit;
472 }
473 192..=255 => {
474 result[3] |= nthbit;
475 }
476 }
477 }
478 self.get_id(result)
479 }
480
481 #[inline]
482 pub fn and(set1: &TSet, set2: &TSet) -> TSet {
483 *set1 & *set2
484 }
485
486 #[inline]
487 pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
488 *set1 & *set2 != Solver::empty()
489 }
490
491 #[inline]
492 pub fn or(set1: &TSet, set2: &TSet) -> TSet {
493 *set1 | *set2
494 }
495
496 #[inline]
497 pub fn not(set: &TSet) -> TSet {
498 !*set
499 }
500
501 #[inline]
502 pub fn is_full(set: &TSet) -> bool {
503 *set == Self::full()
504 }
505
506 #[inline]
507 pub fn is_empty(set: &TSet) -> bool {
508 *set == Solver::empty()
509 }
510
511 #[inline]
512 pub fn contains(large: &TSet, small: &TSet) -> bool {
513 Solver::empty() == (*small & !*large)
514 }
515
516 pub fn u8_to_set(byte: u8) -> TSet {
517 let mut result = TSet::splat(u64::MIN);
518 let nthbit = 1u64 << (byte % 64);
519 match byte {
520 0..=63 => {
521 result[0] = nthbit;
522 }
523 64..=127 => {
524 result[1] = nthbit;
525 }
526 128..=191 => {
527 result[2] = nthbit;
528 }
529 192..=255 => {
530 result[3] = nthbit;
531 }
532 }
533 result
534 }
535
536 pub fn range_to_set(start: u8, end: u8) -> TSet {
537 let mut result = TSet::splat(u64::MIN);
538 for byte in start..=end {
539 let nthbit = 1u64 << (byte % 64);
540 match byte {
541 0..=63 => {
542 result[0] |= nthbit;
543 }
544 64..=127 => {
545 result[1] |= nthbit;
546 }
547 128..=191 => {
548 result[2] |= nthbit;
549 }
550 192..=255 => {
551 result[3] |= nthbit;
552 }
553 }
554 }
555 result
556 }
557}