/// A set of byte values stored as a 256-bit bitmap.
///
/// Word `i` of the inner array covers bytes `64 * i ..= 64 * i + 63`; bit `j`
/// of that word is set when byte `64 * i + j` is a member.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSet(pub [u64; 4]);

impl TSet {
    /// Builds a set whose four words all equal `v`.
    #[inline]
    pub const fn splat(v: u64) -> Self {
        Self([v; 4])
    }

    /// Builds the set containing exactly the bytes listed in `bytes`.
    ///
    /// Repeated bytes are harmless; an empty slice yields the empty set.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let mut set = Self::splat(0);
        for byte in bytes.iter().copied() {
            // High two bits select the word, low six bits select the bit.
            let word = usize::from(byte >> 6);
            set.0[word] |= 1u64 << (byte & 63);
        }
        set
    }

    /// Returns `true` when byte `b` is a member of the set.
    #[inline(always)]
    pub fn contains_byte(&self, b: u8) -> bool {
        let word = self.0[usize::from(b >> 6)];
        (word >> (b & 63)) & 1 == 1
    }
}
23
24impl std::ops::Index<usize> for TSet {
25 type Output = u64;
26 #[inline]
27 fn index(&self, i: usize) -> &u64 {
28 &self.0[i]
29 }
30}
31
32impl std::ops::IndexMut<usize> for TSet {
33 #[inline]
34 fn index_mut(&mut self, i: usize) -> &mut u64 {
35 &mut self.0[i]
36 }
37}
38
39impl std::ops::BitAnd for TSet {
40 type Output = TSet;
41 #[inline]
42 fn bitand(self, rhs: TSet) -> TSet {
43 TSet([
44 self.0[0] & rhs.0[0],
45 self.0[1] & rhs.0[1],
46 self.0[2] & rhs.0[2],
47 self.0[3] & rhs.0[3],
48 ])
49 }
50}
51
52impl std::ops::BitAnd for &TSet {
53 type Output = TSet;
54 #[inline]
55 fn bitand(self, rhs: &TSet) -> TSet {
56 TSet([
57 self.0[0] & rhs.0[0],
58 self.0[1] & rhs.0[1],
59 self.0[2] & rhs.0[2],
60 self.0[3] & rhs.0[3],
61 ])
62 }
63}
64
65impl std::ops::BitOr for TSet {
66 type Output = TSet;
67 #[inline]
68 fn bitor(self, rhs: TSet) -> TSet {
69 TSet([
70 self.0[0] | rhs.0[0],
71 self.0[1] | rhs.0[1],
72 self.0[2] | rhs.0[2],
73 self.0[3] | rhs.0[3],
74 ])
75 }
76}
77
78impl std::ops::Not for TSet {
79 type Output = TSet;
80 #[inline]
81 fn not(self) -> TSet {
82 TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
83 }
84}
85
86impl std::ops::BitAnd<TSet> for &TSet {
88 type Output = TSet;
89 #[inline]
90 fn bitand(self, rhs: TSet) -> TSet {
91 TSet([
92 self.0[0] & rhs.0[0],
93 self.0[1] & rhs.0[1],
94 self.0[2] & rhs.0[2],
95 self.0[3] & rhs.0[3],
96 ])
97 }
98}
99
100impl std::ops::BitOr<TSet> for &TSet {
101 type Output = TSet;
102 #[inline]
103 fn bitor(self, rhs: TSet) -> TSet {
104 TSet([
105 self.0[0] | rhs.0[0],
106 self.0[1] | rhs.0[1],
107 self.0[2] | rhs.0[2],
108 self.0[3] | rhs.0[3],
109 ])
110 }
111}
112
113const EMPTY: TSet = TSet::splat(u64::MIN);
114const FULL: TSet = TSet::splat(u64::MAX);
115
/// Handle for a [`TSet`] registered in a [`Solver`].
///
/// The wrapped value is the position of the set in `Solver::array`.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Id of the empty set, which `Solver::new` registers first.
    pub const EMPTY: TSetId = Self(0);
    /// Id of the full set, which `Solver::new` registers second.
    pub const FULL: TSetId = Self(1);
}
122
123use rustc_hash::FxHashMap;
124use std::collections::BTreeSet;
125
/// Interning table that hands out one [`TSetId`] per distinct [`TSet`].
pub struct Solver {
    /// Maps every registered set to the id it was assigned.
    cache: FxHashMap<TSet, TSetId>,
    /// Registered sets, indexed by the numeric value of their [`TSetId`].
    pub array: Vec<TSet>,
}
130
131impl Default for Solver {
132 fn default() -> Self {
133 Self::new()
134 }
135}
136
137impl Solver {
138 pub fn new() -> Solver {
139 let mut inst = Self {
140 cache: FxHashMap::default(),
141 array: Vec::new(),
142 };
143 let _ = inst.init(Solver::empty()); let _ = inst.init(Solver::full()); inst
146 }
147
148 fn init(&mut self, inst: TSet) -> TSetId {
149 let new_id = TSetId(self.cache.len() as u32);
150 self.cache.insert(inst, new_id);
151 self.array.push(inst);
152 new_id
153 }
154
155 pub fn get_set(&self, set_id: TSetId) -> TSet {
156 self.array[set_id.0 as usize]
157 }
158
159 pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
160 &self.array[set_id.0 as usize]
161 }
162
163 pub fn get_id(&mut self, inst: TSet) -> TSetId {
164 match self.cache.get(&inst) {
165 Some(&id) => id,
166 None => self.init(inst),
167 }
168 }
169
170 pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
171 self.array[set_id.0 as usize][idx] & bit != 0
172 }
173
174 pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
175 let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
176 let mut rangestart: Option<u8> = None;
177 let mut prevchar: Option<u8> = None;
178 for i in 0..4 {
179 for j in 0..64 {
180 let nthbit = 1u64 << j;
181 if tset[i] & nthbit != 0 {
182 let cc = (i * 64 + j) as u8;
183 if rangestart.is_none() {
184 rangestart = Some(cc);
185 prevchar = Some(cc);
186 continue;
187 }
188
189 if let (Some(currstart), Some(currprev)) = (rangestart, prevchar) {
190 if currprev == cc - 1 {
191 prevchar = Some(cc);
192 continue;
193 }
194 ranges.insert((currstart, currprev));
195 rangestart = Some(cc);
196 prevchar = Some(cc);
197 }
198 }
199 }
200 }
201 if let (Some(start), Some(end)) = (rangestart, prevchar) {
202 ranges.insert((start, end));
203 }
204 ranges
205 }
206
207 fn pp_byte(b: u8) -> String {
208 if cfg!(feature = "graphviz") {
209 match b as char {
210 '\n' => return r"\ṅ".to_owned(),
212 '"' => return r"\u{201c}".to_owned(),
213 '\r' => return r"\r".to_owned(),
214 '\t' => return r"\t".to_owned(),
215 _ => {}
216 }
217 }
218 match b as char {
219 '\n' => r"\n".to_owned(),
220 '\r' => r"\r".to_owned(),
221 '\t' => r"\t".to_owned(),
222 ' ' => r" ".to_owned(),
223 '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
224 | '*' | '?' | '^' | '$' => r"\".to_owned() + &(b as char).to_string(),
225 c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
226 _ => format!("\\x{:02X}", b),
227 }
228 }
229
230 fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
231 let display_range = |c, c2| {
232 if c == c2 {
233 Self::pp_byte(c)
234 } else if c.abs_diff(c2) == 1 {
235 format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
236 } else {
237 format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
238 }
239 };
240
241 if ranges.is_empty() {
242 return "\u{22a5}".to_owned();
243 }
244 if ranges.len() == 1 {
245 let (s, e) = ranges.iter().next().unwrap();
246 if s == e {
247 return Self::pp_byte(*s);
248 } else {
249 return ranges
250 .iter()
251 .map(|(s, e)| display_range(*s, *e))
252 .collect::<Vec<_>>()
253 .join("")
254 .to_string();
255 }
256 }
257 if ranges.len() > 20 {
258 return "\u{03c6}".to_owned();
259 }
260 ranges
261 .iter()
262 .map(|(s, e)| display_range(*s, *e))
263 .collect::<Vec<_>>()
264 .join("")
265 .to_string()
266 }
267
268 pub fn pp_first(&self, tset: &TSet) -> char {
269 let tryn1 = |i: usize| {
270 for j in 0..32 {
271 let nthbit = 1u64 << j;
272 if tset[i] & nthbit != 0 {
273 let cc = (i * 64 + j) as u8 as char;
274 return Some(cc);
275 }
276 }
277 None
278 };
279 let tryn2 = |i: usize| {
280 for j in 33..64 {
281 let nthbit = 1u64 << j;
282 if tset[i] & nthbit != 0 {
283 let cc = (i * 64 + j) as u8 as char;
284 return Some(cc);
285 }
286 }
287 None
288 };
289 tryn2(0)
291 .or_else(|| tryn2(1))
292 .or_else(|| tryn1(1))
293 .or_else(|| tryn1(2))
294 .or_else(|| tryn2(2))
295 .or_else(|| tryn1(3))
296 .or_else(|| tryn2(3))
297 .or_else(|| tryn1(0))
298 .unwrap_or('\u{22a5}')
299 }
300
301 pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
302 let tset = self.get_set(tset);
303 Self::pp_collect_ranges(&tset).into_iter().collect()
304 }
305
306 #[allow(unused)]
307 fn first_byte(tset: &TSet) -> u8 {
308 for i in 0..4 {
309 for j in 0..64 {
310 let nthbit = 1u64 << j;
311 if tset[i] & nthbit != 0 {
312 let cc = (i * 64 + j) as u8;
313 return cc;
314 }
315 }
316 }
317 0
318 }
319
320 pub fn pp(&self, tset: TSetId) -> String {
321 if tset == TSetId::FULL {
322 return "_".to_owned();
323 }
324 if tset == TSetId::EMPTY {
325 return "\u{22a5}".to_owned();
326 }
327 let tset = self.get_set(tset);
328 let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
329 let rstart = ranges.first().unwrap().0;
330 let rend = ranges.last().unwrap().1;
331 if ranges.len() >= 2 && rstart == 0 && rend == 255 {
332 let not_id = Self::not(&tset);
333 let not_ranges = Self::pp_collect_ranges(¬_id);
334 if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
335 return r".".to_owned();
336 }
337 let content = Self::pp_content(¬_ranges);
338 return format!("[^{}]", content);
339 }
340 if ranges.is_empty() {
341 return "\u{22a5}".to_owned();
342 }
343 if ranges.len() == 1 {
344 let (s, e) = ranges.iter().next().unwrap();
345 if s == e {
346 return Self::pp_byte(*s);
347 } else {
348 let content = Self::pp_content(&ranges);
349 return format!("[{}]", content);
350 }
351 }
352 let content = Self::pp_content(&ranges);
353 format!("[{}]", content)
354 }
355}
356
357impl Solver {
358 #[inline]
359 pub fn full() -> TSet {
360 FULL
361 }
362
363 #[inline]
364 pub fn empty() -> TSet {
365 EMPTY
366 }
367
368 #[inline]
369 pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
370 self.get_id(self.get_set(set1) | self.get_set(set2))
371 }
372
373 #[inline]
374 pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
375 self.get_id(self.get_set(set1) & self.get_set(set2))
376 }
377
378 #[inline]
379 pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
380 self.get_id(!self.get_set(set_id))
381 }
382
383 #[inline]
384 pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
385 self.and_id(set1, set2) != TSetId::EMPTY
386 }
387 #[inline]
388 pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
389 self.and_id(set1, set2) == TSetId::EMPTY
390 }
391
392 pub fn byte_count(&self, set_id: TSetId) -> u32 {
393 let tset = self.get_set(set_id);
394 (0..4).map(|i| tset[i].count_ones()).sum()
395 }
396
397 pub fn collect_bytes(&self, set_id: TSetId) -> Vec<u8> {
398 let tset = self.get_set(set_id);
399 let mut bytes = Vec::new();
400 for i in 0..4 {
401 let mut bits = tset[i];
402 while bits != 0 {
403 let j = bits.trailing_zeros() as usize;
404 bytes.push((i * 64 + j) as u8);
405 bits &= bits - 1;
406 }
407 }
408 bytes
409 }
410
411 pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
412 let tset = self.get_set(set_id);
413 let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
414 if total != 1 {
415 return None;
416 }
417 for i in 0..4 {
418 if tset[i] != 0 {
419 return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
420 }
421 }
422 None
423 }
424
425 #[inline]
426 pub fn is_empty_id(&self, set1: TSetId) -> bool {
427 set1 == TSetId::EMPTY
428 }
429
430 #[inline]
431 pub fn is_full_id(&self, set1: TSetId) -> bool {
432 set1 == TSetId::FULL
433 }
434
435 #[inline]
436 pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
437 let not_large = self.not_id(large_id);
438 self.and_id(small_id, not_large) == TSetId::EMPTY
439 }
440
441 pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
442 let mut result = TSet::splat(u64::MIN);
443 let nthbit = 1u64 << (byte % 64);
444 match byte {
445 0..=63 => {
446 result[0] = nthbit;
447 }
448 64..=127 => {
449 result[1] = nthbit;
450 }
451 128..=191 => {
452 result[2] = nthbit;
453 }
454 192..=255 => {
455 result[3] = nthbit;
456 }
457 }
458 self.get_id(result)
459 }
460
461 pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
462 let mut result = TSet::splat(u64::MIN);
463 for byte in start..=end {
464 let nthbit = 1u64 << (byte % 64);
465 match byte {
466 0..=63 => {
467 result[0] |= nthbit;
468 }
469 64..=127 => {
470 result[1] |= nthbit;
471 }
472 128..=191 => {
473 result[2] |= nthbit;
474 }
475 192..=255 => {
476 result[3] |= nthbit;
477 }
478 }
479 }
480 self.get_id(result)
481 }
482
483 #[inline]
484 pub fn and(set1: &TSet, set2: &TSet) -> TSet {
485 *set1 & *set2
486 }
487
488 #[inline]
489 pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
490 *set1 & *set2 != Solver::empty()
491 }
492
493 #[inline]
494 pub fn or(set1: &TSet, set2: &TSet) -> TSet {
495 *set1 | *set2
496 }
497
498 #[inline]
499 pub fn not(set: &TSet) -> TSet {
500 !*set
501 }
502
503 #[inline]
504 pub fn is_full(set: &TSet) -> bool {
505 *set == Self::full()
506 }
507
508 #[inline]
509 pub fn is_empty(set: &TSet) -> bool {
510 *set == Solver::empty()
511 }
512
513 #[inline]
514 pub fn contains(large: &TSet, small: &TSet) -> bool {
515 Solver::empty() == (*small & !*large)
516 }
517
518 pub fn u8_to_set(byte: u8) -> TSet {
519 let mut result = TSet::splat(u64::MIN);
520 let nthbit = 1u64 << (byte % 64);
521 match byte {
522 0..=63 => {
523 result[0] = nthbit;
524 }
525 64..=127 => {
526 result[1] = nthbit;
527 }
528 128..=191 => {
529 result[2] = nthbit;
530 }
531 192..=255 => {
532 result[3] = nthbit;
533 }
534 }
535 result
536 }
537
538 pub fn range_to_set(start: u8, end: u8) -> TSet {
539 let mut result = TSet::splat(u64::MIN);
540 for byte in start..=end {
541 let nthbit = 1u64 << (byte % 64);
542 match byte {
543 0..=63 => {
544 result[0] |= nthbit;
545 }
546 64..=127 => {
547 result[1] |= nthbit;
548 }
549 128..=191 => {
550 result[2] |= nthbit;
551 }
552 192..=255 => {
553 result[3] |= nthbit;
554 }
555 }
556 }
557 result
558 }
559}