Skip to main content

resharp_algebra/
solver.rs

/// A set of bytes stored as a 256-bit bitmap: byte `b` is a member when bit
/// `b % 64` of word `b / 64` is set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TSet(pub [u64; 4]);

impl TSet {
    /// Builds a set whose four 64-bit words all equal `v`.
    #[inline]
    pub const fn splat(v: u64) -> Self {
        Self([v; 4])
    }

    /// Builds the set of every byte that occurs in `bytes`; repeated bytes
    /// are harmless.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        bytes.iter().fold(Self([0; 4]), |mut acc, &byte| {
            // `byte >> 6` selects the word, `byte & 63` the bit inside it.
            acc.0[usize::from(byte >> 6)] |= 1u64 << (byte & 63);
            acc
        })
    }

    /// Returns `true` when `b` is a member of the set.
    #[inline(always)]
    pub fn contains_byte(&self, b: u8) -> bool {
        let word = self.0[usize::from(b >> 6)];
        (word >> (b & 63)) & 1 == 1
    }
}
23
24impl std::ops::Index<usize> for TSet {
25    type Output = u64;
26    #[inline]
27    fn index(&self, i: usize) -> &u64 {
28        &self.0[i]
29    }
30}
31
32impl std::ops::IndexMut<usize> for TSet {
33    #[inline]
34    fn index_mut(&mut self, i: usize) -> &mut u64 {
35        &mut self.0[i]
36    }
37}
38
39impl std::ops::BitAnd for TSet {
40    type Output = TSet;
41    #[inline]
42    fn bitand(self, rhs: TSet) -> TSet {
43        TSet([
44            self.0[0] & rhs.0[0],
45            self.0[1] & rhs.0[1],
46            self.0[2] & rhs.0[2],
47            self.0[3] & rhs.0[3],
48        ])
49    }
50}
51
52impl std::ops::BitAnd for &TSet {
53    type Output = TSet;
54    #[inline]
55    fn bitand(self, rhs: &TSet) -> TSet {
56        TSet([
57            self.0[0] & rhs.0[0],
58            self.0[1] & rhs.0[1],
59            self.0[2] & rhs.0[2],
60            self.0[3] & rhs.0[3],
61        ])
62    }
63}
64
65impl std::ops::BitOr for TSet {
66    type Output = TSet;
67    #[inline]
68    fn bitor(self, rhs: TSet) -> TSet {
69        TSet([
70            self.0[0] | rhs.0[0],
71            self.0[1] | rhs.0[1],
72            self.0[2] | rhs.0[2],
73            self.0[3] | rhs.0[3],
74        ])
75    }
76}
77
78impl std::ops::Not for TSet {
79    type Output = TSet;
80    #[inline]
81    fn not(self) -> TSet {
82        TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
83    }
84}
85
86// &TSet ops used by Solver helper methods
87impl std::ops::BitAnd<TSet> for &TSet {
88    type Output = TSet;
89    #[inline]
90    fn bitand(self, rhs: TSet) -> TSet {
91        TSet([
92            self.0[0] & rhs.0[0],
93            self.0[1] & rhs.0[1],
94            self.0[2] & rhs.0[2],
95            self.0[3] & rhs.0[3],
96        ])
97    }
98}
99
100impl std::ops::BitOr<TSet> for &TSet {
101    type Output = TSet;
102    #[inline]
103    fn bitor(self, rhs: TSet) -> TSet {
104        TSet([
105            self.0[0] | rhs.0[0],
106            self.0[1] | rhs.0[1],
107            self.0[2] | rhs.0[2],
108            self.0[3] | rhs.0[3],
109        ])
110    }
111}
112
113const EMPTY: TSet = TSet::splat(u64::MIN);
114const FULL: TSet = TSet::splat(u64::MAX);
115
/// Handle to a [`TSet`] interned in a `Solver`; the wrapped value is the
/// set's index in the solver's table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Id of the empty set, which `Solver::new` interns first.
    pub const EMPTY: Self = Self(0);
    /// Id of the full set, which `Solver::new` interns second.
    pub const FULL: Self = Self(1);
}
122
123use rustc_hash::FxHashMap;
124use std::collections::BTreeSet;
125
126pub struct Solver {
127    cache: FxHashMap<TSet, TSetId>,
128    pub array: Vec<TSet>,
129}
130
131impl Solver {
132    pub fn new() -> Solver {
133        let mut inst = Self {
134            cache: FxHashMap::default(),
135            array: Vec::new(),
136        };
137        let _ = inst.init(Solver::empty()); // 0
138        let _ = inst.init(Solver::full()); // 1
139        inst
140    }
141
142    fn init(&mut self, inst: TSet) -> TSetId {
143        let new_id = TSetId(self.cache.len() as u32);
144        self.cache.insert(inst, new_id);
145        self.array.push(inst);
146        new_id
147    }
148
149    pub fn get_set(&self, set_id: TSetId) -> TSet {
150        self.array[set_id.0 as usize]
151    }
152
153    pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
154        &self.array[set_id.0 as usize]
155    }
156
157    pub fn get_id(&mut self, inst: TSet) -> TSetId {
158        match self.cache.get(&inst) {
159            Some(&id) => id,
160            None => self.init(inst),
161        }
162    }
163
164    pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
165        self.array[set_id.0 as usize][idx] & bit != 0
166    }
167
168    pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
169        let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
170        let mut rangestart: Option<u8> = None;
171        let mut prevchar: Option<u8> = None;
172        for i in 0..4 {
173            for j in 0..64 {
174                let nthbit = 1u64 << j;
175                if tset[i] & nthbit != 0 {
176                    let cc = (i * 64 + j) as u8;
177                    if rangestart.is_none() {
178                        rangestart = Some(cc);
179                        prevchar = Some(cc);
180                        continue;
181                    }
182
183                    if let Some(currstart) = rangestart {
184                        if let Some(currprev) = prevchar {
185                            if currprev as u8 == cc as u8 - 1 {
186                                prevchar = Some(cc);
187                                continue;
188                            } else {
189                                if currstart == currprev {
190                                    ranges.insert((currstart, currstart));
191                                } else {
192                                    ranges.insert((currstart, currprev));
193                                }
194                                rangestart = Some(cc);
195                                prevchar = Some(cc);
196                            }
197                        } else {
198                        }
199                    } else {
200                    }
201                }
202            }
203        }
204        if let Some(start) = rangestart {
205            if let Some(prevchar) = prevchar {
206                if prevchar as u8 == start as u8 {
207                    ranges.insert((start, start));
208                } else {
209                    ranges.insert((start, prevchar));
210                }
211            } else {
212                // single char
213                ranges.insert((start, start));
214            }
215        }
216        ranges
217    }
218
219    fn pp_byte(b: u8) -> String {
220        if cfg!(feature = "graphviz") {
221            match b as char {
222                // graphviz doesnt like \n so we use \ṅ
223                '\n' => return r"\ṅ".to_owned(),
224                '"' => return r"\u{201c}".to_owned(),
225                '\r' => return r"\r".to_owned(),
226                '\t' => return r"\t".to_owned(),
227                _ => {}
228            }
229        }
230        match b as char {
231            '\n' => r"\n".to_owned(),
232            '\r' => r"\r".to_owned(),
233            '\t' => r"\t".to_owned(),
234            ' ' => r" ".to_owned(),
235            '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
236            | '*' | '?' | '^' | '$' => r"\".to_owned() + &(b as char).to_string(),
237            c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
238            _ => format!("\\x{:02X}", b),
239        }
240    }
241
242    fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
243        let display_range = |c, c2| {
244            if c == c2 {
245                Self::pp_byte(c)
246            } else if c.abs_diff(c2) == 1 {
247                format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
248            } else {
249                format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
250            }
251        };
252
253        if ranges.len() == 0 {
254            return "\u{22a5}".to_owned();
255        }
256        if ranges.len() == 1 {
257            let (s, e) = ranges.iter().next().unwrap();
258            if s == e {
259                return Self::pp_byte(*s);
260            } else {
261                return format!(
262                    "{}",
263                    ranges
264                        .iter()
265                        .map(|(s, e)| display_range(*s, *e))
266                        .collect::<Vec<_>>()
267                        .join("")
268                );
269            }
270        }
271        if ranges.len() > 20 {
272            return "\u{03c6}".to_owned();
273        }
274        return format!(
275            "{}",
276            ranges
277                .iter()
278                .map(|(s, e)| display_range(*s, *e))
279                .collect::<Vec<_>>()
280                .join("")
281        );
282    }
283
284    pub fn pp_first(&self, tset: &TSet) -> char {
285        let tryn1 = |i: usize| {
286            for j in 0..32 {
287                let nthbit = 1u64 << j;
288                if tset[i] & nthbit != 0 {
289                    let cc = (i * 64 + j) as u8 as char;
290                    return Some(cc);
291                }
292            }
293            None
294        };
295        let tryn2 = |i: usize| {
296            for j in 33..64 {
297                let nthbit = 1u64 << j;
298                if tset[i] & nthbit != 0 {
299                    let cc = (i * 64 + j) as u8 as char;
300                    return Some(cc);
301                }
302            }
303            None
304        };
305        // readable ones first
306        tryn2(0)
307            .or_else(|| tryn2(1))
308            .or_else(|| tryn1(1))
309            .or_else(|| tryn1(2))
310            .or_else(|| tryn2(2))
311            .or_else(|| tryn1(3))
312            .or_else(|| tryn2(3))
313            .or_else(|| tryn1(0))
314            .unwrap_or('\u{22a5}')
315    }
316
317    pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
318        let tset = self.get_set(tset);
319        Self::pp_collect_ranges(&tset).into_iter().collect()
320    }
321
322    #[allow(unused)]
323    fn first_byte(tset: &TSet) -> u8 {
324        for i in 0..4 {
325            for j in 0..64 {
326                let nthbit = 1u64 << j;
327                if tset[i] & nthbit != 0 {
328                    let cc = (i * 64 + j) as u8;
329                    return cc;
330                }
331            }
332        }
333        return 0;
334    }
335
336    pub fn pp(&self, tset: TSetId) -> String {
337        if tset == TSetId::FULL {
338            return "_".to_owned();
339        }
340        if tset == TSetId::EMPTY {
341            return "\u{22a5}".to_owned();
342        }
343        let tset = self.get_set(tset);
344        let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
345        let rstart = ranges.first().unwrap().0;
346        let rend = ranges.last().unwrap().1;
347        if ranges.len() >= 2 && rstart == 0 && rend == 255 {
348            let not_id = Self::not(&tset);
349            let not_ranges = Self::pp_collect_ranges(&not_id);
350            if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
351                return r".".to_owned();
352            }
353            let content = Self::pp_content(&not_ranges);
354            return format!("[^{}]", content);
355        }
356        if ranges.len() == 0 {
357            return "\u{22a5}".to_owned();
358        }
359        if ranges.len() == 1 {
360            let (s, e) = ranges.iter().next().unwrap();
361            if s == e {
362                return Self::pp_byte(*s);
363            } else {
364                let content = Self::pp_content(&ranges);
365                return format!("[{}]", content);
366            }
367        }
368        let content = Self::pp_content(&ranges);
369        return format!("[{}]", content);
370    }
371}
372
373impl Solver {
374    #[inline]
375    pub fn full() -> TSet {
376        FULL
377    }
378
379    #[inline]
380    pub fn empty() -> TSet {
381        EMPTY
382    }
383
384    #[inline]
385    pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
386        self.get_id(self.get_set(set1) | self.get_set(set2))
387    }
388
389    #[inline]
390    pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
391        self.get_id(self.get_set(set1) & self.get_set(set2))
392    }
393
394    #[inline]
395    pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
396        self.get_id(!self.get_set(set_id))
397    }
398
399    #[inline]
400    pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
401        self.and_id(set1, set2) != TSetId::EMPTY
402    }
403    #[inline]
404    pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
405        self.and_id(set1, set2) == TSetId::EMPTY
406    }
407
408    pub fn byte_count(&self, set_id: TSetId) -> u32 {
409        let tset = self.get_set(set_id);
410        (0..4).map(|i| tset[i].count_ones()).sum()
411    }
412
413    pub fn collect_bytes(&self, set_id: TSetId) -> Vec<u8> {
414        let tset = self.get_set(set_id);
415        let mut bytes = Vec::new();
416        for i in 0..4 {
417            let mut bits = tset[i];
418            while bits != 0 {
419                let j = bits.trailing_zeros() as usize;
420                bytes.push((i * 64 + j) as u8);
421                bits &= bits - 1;
422            }
423        }
424        bytes
425    }
426
427    pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
428        let tset = self.get_set(set_id);
429        let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
430        if total != 1 {
431            return None;
432        }
433        for i in 0..4 {
434            if tset[i] != 0 {
435                return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
436            }
437        }
438        None
439    }
440
441    #[inline]
442    pub fn is_empty_id(&self, set1: TSetId) -> bool {
443        set1 == TSetId::EMPTY
444    }
445
446    #[inline]
447    pub fn is_full_id(&self, set1: TSetId) -> bool {
448        set1 == TSetId::FULL
449    }
450
451    #[inline]
452    pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
453        let not_large = self.not_id(large_id);
454        self.and_id(small_id, not_large) == TSetId::EMPTY
455    }
456
457    pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
458        let mut result = TSet::splat(u64::MIN);
459        let nthbit = 1u64 << byte % 64;
460        match byte {
461            0..=63 => {
462                result[0] = nthbit;
463            }
464            64..=127 => {
465                result[1] = nthbit;
466            }
467            128..=191 => {
468                result[2] = nthbit;
469            }
470            192..=255 => {
471                result[3] = nthbit;
472            }
473        }
474        self.get_id(result)
475    }
476
477    pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
478        let mut result = TSet::splat(u64::MIN);
479        for byte in start..=end {
480            let nthbit = 1u64 << byte % 64;
481            match byte {
482                0..=63 => {
483                    result[0] |= nthbit;
484                }
485                64..=127 => {
486                    result[1] |= nthbit;
487                }
488                128..=191 => {
489                    result[2] |= nthbit;
490                }
491                192..=255 => {
492                    result[3] |= nthbit;
493                }
494            }
495        }
496        self.get_id(result)
497    }
498
499    #[inline]
500    pub fn and(set1: &TSet, set2: &TSet) -> TSet {
501        *set1 & *set2
502    }
503
504    #[inline]
505    pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
506        *set1 & *set2 != Solver::empty()
507    }
508
509    #[inline]
510    pub fn or(set1: &TSet, set2: &TSet) -> TSet {
511        *set1 | *set2
512    }
513
514    #[inline]
515    pub fn not(set: &TSet) -> TSet {
516        !*set
517    }
518
519    #[inline]
520    pub fn is_full(set: &TSet) -> bool {
521        *set == Self::full()
522    }
523
524    #[inline]
525    pub fn is_empty(set: &TSet) -> bool {
526        *set == Solver::empty()
527    }
528
529    #[inline]
530    pub fn contains(large: &TSet, small: &TSet) -> bool {
531        Solver::empty() == (*small & !*large)
532    }
533
534    pub fn u8_to_set(byte: u8) -> TSet {
535        let mut result = TSet::splat(u64::MIN);
536        let nthbit = 1u64 << byte % 64;
537        match byte {
538            0..=63 => {
539                result[0] = nthbit;
540            }
541            64..=127 => {
542                result[1] = nthbit;
543            }
544            128..=191 => {
545                result[2] = nthbit;
546            }
547            192..=255 => {
548                result[3] = nthbit;
549            }
550        }
551        result
552    }
553
554    pub fn range_to_set(start: u8, end: u8) -> TSet {
555        let mut result = TSet::splat(u64::MIN);
556        for byte in start..=end {
557            let nthbit = 1u64 << byte % 64;
558            match byte {
559                0..=63 => {
560                    result[0] |= nthbit;
561                }
562                64..=127 => {
563                    result[1] |= nthbit;
564                }
565                128..=191 => {
566                    result[2] |= nthbit;
567                }
568                192..=255 => {
569                    result[3] |= nthbit;
570                }
571            }
572        }
573        result
574    }
575}