Skip to main content

resharp_algebra/
solver.rs

/// A set of byte values stored as a 256-bit bitmap.
///
/// Byte `b` is a member exactly when bit `b % 64` of word `b / 64` is set.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSet(pub [u64; 4]);

impl TSet {
    /// Builds a set whose four words all equal `v`.
    #[inline]
    pub const fn splat(v: u64) -> Self {
        TSet([v; 4])
    }

    /// Builds the set containing every byte listed in `bytes`.
    ///
    /// Repeated bytes are harmless; an empty slice yields the empty set.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        bytes.iter().fold(Self::splat(0), |mut acc, &b| {
            let pos = usize::from(b);
            acc.0[pos >> 6] |= 1u64 << (pos & 63);
            acc
        })
    }

    /// Returns `true` when byte `b` is a member of the set.
    #[inline(always)]
    pub fn contains_byte(&self, b: u8) -> bool {
        let pos = usize::from(b);
        (self.0[pos >> 6] >> (pos & 63)) & 1 == 1
    }
}
23
24impl std::ops::Index<usize> for TSet {
25    type Output = u64;
26    #[inline]
27    fn index(&self, i: usize) -> &u64 {
28        &self.0[i]
29    }
30}
31
32impl std::ops::IndexMut<usize> for TSet {
33    #[inline]
34    fn index_mut(&mut self, i: usize) -> &mut u64 {
35        &mut self.0[i]
36    }
37}
38
39impl std::ops::BitAnd for TSet {
40    type Output = TSet;
41    #[inline]
42    fn bitand(self, rhs: TSet) -> TSet {
43        TSet([
44            self.0[0] & rhs.0[0],
45            self.0[1] & rhs.0[1],
46            self.0[2] & rhs.0[2],
47            self.0[3] & rhs.0[3],
48        ])
49    }
50}
51
52impl std::ops::BitAnd for &TSet {
53    type Output = TSet;
54    #[inline]
55    fn bitand(self, rhs: &TSet) -> TSet {
56        TSet([
57            self.0[0] & rhs.0[0],
58            self.0[1] & rhs.0[1],
59            self.0[2] & rhs.0[2],
60            self.0[3] & rhs.0[3],
61        ])
62    }
63}
64
65impl std::ops::BitOr for TSet {
66    type Output = TSet;
67    #[inline]
68    fn bitor(self, rhs: TSet) -> TSet {
69        TSet([
70            self.0[0] | rhs.0[0],
71            self.0[1] | rhs.0[1],
72            self.0[2] | rhs.0[2],
73            self.0[3] | rhs.0[3],
74        ])
75    }
76}
77
78impl std::ops::Not for TSet {
79    type Output = TSet;
80    #[inline]
81    fn not(self) -> TSet {
82        TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
83    }
84}
85
86// &TSet ops used by Solver helper methods
87impl std::ops::BitAnd<TSet> for &TSet {
88    type Output = TSet;
89    #[inline]
90    fn bitand(self, rhs: TSet) -> TSet {
91        TSet([
92            self.0[0] & rhs.0[0],
93            self.0[1] & rhs.0[1],
94            self.0[2] & rhs.0[2],
95            self.0[3] & rhs.0[3],
96        ])
97    }
98}
99
100impl std::ops::BitOr<TSet> for &TSet {
101    type Output = TSet;
102    #[inline]
103    fn bitor(self, rhs: TSet) -> TSet {
104        TSet([
105            self.0[0] | rhs.0[0],
106            self.0[1] | rhs.0[1],
107            self.0[2] | rhs.0[2],
108            self.0[3] | rhs.0[3],
109        ])
110    }
111}
112
113const EMPTY: TSet = TSet::splat(u64::MIN);
114const FULL: TSet = TSet::splat(u64::MAX);
115
/// Numeric handle for a set registered in a `Solver`; the wrapped value is the
/// set's index in the solver's `array`.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Id of the empty set (`Solver::new` registers it first).
    pub const EMPTY: TSetId = Self(0);
    /// Id of the full set (`Solver::new` registers it second).
    pub const FULL: TSetId = Self(1);
}
122
123use rustc_hash::FxHashMap;
124use std::collections::BTreeSet;
125
126pub struct Solver {
127    cache: FxHashMap<TSet, TSetId>,
128    pub array: Vec<TSet>,
129}
130
131impl Default for Solver {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl Solver {
138    pub fn new() -> Solver {
139        let mut inst = Self {
140            cache: FxHashMap::default(),
141            array: Vec::new(),
142        };
143        let _ = inst.init(Solver::empty()); // 0
144        let _ = inst.init(Solver::full()); // 1
145        inst
146    }
147
148    fn init(&mut self, inst: TSet) -> TSetId {
149        let new_id = TSetId(self.cache.len() as u32);
150        self.cache.insert(inst, new_id);
151        self.array.push(inst);
152        new_id
153    }
154
155    pub fn get_set(&self, set_id: TSetId) -> TSet {
156        self.array[set_id.0 as usize]
157    }
158
159    pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
160        &self.array[set_id.0 as usize]
161    }
162
163    pub fn get_id(&mut self, inst: TSet) -> TSetId {
164        match self.cache.get(&inst) {
165            Some(&id) => id,
166            None => self.init(inst),
167        }
168    }
169
170    pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
171        self.array[set_id.0 as usize][idx] & bit != 0
172    }
173
174    pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
175        let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
176        let mut rangestart: Option<u8> = None;
177        let mut prevchar: Option<u8> = None;
178        for i in 0..4 {
179            for j in 0..64 {
180                let nthbit = 1u64 << j;
181                if tset[i] & nthbit != 0 {
182                    let cc = (i * 64 + j) as u8;
183                    if rangestart.is_none() {
184                        rangestart = Some(cc);
185                        prevchar = Some(cc);
186                        continue;
187                    }
188
189                    if let (Some(currstart), Some(currprev)) = (rangestart, prevchar) {
190                        if currprev == cc - 1 {
191                            prevchar = Some(cc);
192                            continue;
193                        }
194                        ranges.insert((currstart, currprev));
195                        rangestart = Some(cc);
196                        prevchar = Some(cc);
197                    }
198                }
199            }
200        }
201        if let (Some(start), Some(end)) = (rangestart, prevchar) {
202            ranges.insert((start, end));
203        }
204        ranges
205    }
206
207    fn pp_byte(b: u8) -> String {
208        if cfg!(feature = "graphviz") {
209            match b as char {
210                // graphviz doesnt like \n so we use \ṅ
211                '\n' => return r"\ṅ".to_owned(),
212                '"' => return r"\u{201c}".to_owned(),
213                '\r' => return r"\r".to_owned(),
214                '\t' => return r"\t".to_owned(),
215                _ => {}
216            }
217        }
218        match b as char {
219            '\n' => r"\n".to_owned(),
220            '\r' => r"\r".to_owned(),
221            '\t' => r"\t".to_owned(),
222            ' ' => r" ".to_owned(),
223            '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
224            | '*' | '?' | '^' | '$' => r"\".to_owned() + &(b as char).to_string(),
225            c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
226            _ => format!("\\x{:02X}", b),
227        }
228    }
229
230    fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
231        let display_range = |c, c2| {
232            if c == c2 {
233                Self::pp_byte(c)
234            } else if c.abs_diff(c2) == 1 {
235                format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
236            } else {
237                format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
238            }
239        };
240
241        if ranges.is_empty() {
242            return "\u{22a5}".to_owned();
243        }
244        if ranges.len() == 1 {
245            let (s, e) = ranges.iter().next().unwrap();
246            if s == e {
247                return Self::pp_byte(*s);
248            } else {
249                return ranges
250                        .iter()
251                        .map(|(s, e)| display_range(*s, *e))
252                        .collect::<Vec<_>>()
253                        .join("").to_string();
254            }
255        }
256        if ranges.len() > 20 {
257            return "\u{03c6}".to_owned();
258        }
259        ranges
260                .iter()
261                .map(|(s, e)| display_range(*s, *e))
262                .collect::<Vec<_>>()
263                .join("").to_string()
264    }
265
266    pub fn pp_first(&self, tset: &TSet) -> char {
267        let tryn1 = |i: usize| {
268            for j in 0..32 {
269                let nthbit = 1u64 << j;
270                if tset[i] & nthbit != 0 {
271                    let cc = (i * 64 + j) as u8 as char;
272                    return Some(cc);
273                }
274            }
275            None
276        };
277        let tryn2 = |i: usize| {
278            for j in 33..64 {
279                let nthbit = 1u64 << j;
280                if tset[i] & nthbit != 0 {
281                    let cc = (i * 64 + j) as u8 as char;
282                    return Some(cc);
283                }
284            }
285            None
286        };
287        // readable ones first
288        tryn2(0)
289            .or_else(|| tryn2(1))
290            .or_else(|| tryn1(1))
291            .or_else(|| tryn1(2))
292            .or_else(|| tryn2(2))
293            .or_else(|| tryn1(3))
294            .or_else(|| tryn2(3))
295            .or_else(|| tryn1(0))
296            .unwrap_or('\u{22a5}')
297    }
298
299    pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
300        let tset = self.get_set(tset);
301        Self::pp_collect_ranges(&tset).into_iter().collect()
302    }
303
304    #[allow(unused)]
305    fn first_byte(tset: &TSet) -> u8 {
306        for i in 0..4 {
307            for j in 0..64 {
308                let nthbit = 1u64 << j;
309                if tset[i] & nthbit != 0 {
310                    let cc = (i * 64 + j) as u8;
311                    return cc;
312                }
313            }
314        }
315        0
316    }
317
318    pub fn pp(&self, tset: TSetId) -> String {
319        if tset == TSetId::FULL {
320            return "_".to_owned();
321        }
322        if tset == TSetId::EMPTY {
323            return "\u{22a5}".to_owned();
324        }
325        let tset = self.get_set(tset);
326        let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
327        let rstart = ranges.first().unwrap().0;
328        let rend = ranges.last().unwrap().1;
329        if ranges.len() >= 2 && rstart == 0 && rend == 255 {
330            let not_id = Self::not(&tset);
331            let not_ranges = Self::pp_collect_ranges(&not_id);
332            if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
333                return r".".to_owned();
334            }
335            let content = Self::pp_content(&not_ranges);
336            return format!("[^{}]", content);
337        }
338        if ranges.is_empty() {
339            return "\u{22a5}".to_owned();
340        }
341        if ranges.len() == 1 {
342            let (s, e) = ranges.iter().next().unwrap();
343            if s == e {
344                return Self::pp_byte(*s);
345            } else {
346                let content = Self::pp_content(&ranges);
347                return format!("[{}]", content);
348            }
349        }
350        let content = Self::pp_content(&ranges);
351        format!("[{}]", content)
352    }
353}
354
355impl Solver {
356    #[inline]
357    pub fn full() -> TSet {
358        FULL
359    }
360
361    #[inline]
362    pub fn empty() -> TSet {
363        EMPTY
364    }
365
366    #[inline]
367    pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
368        self.get_id(self.get_set(set1) | self.get_set(set2))
369    }
370
371    #[inline]
372    pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
373        self.get_id(self.get_set(set1) & self.get_set(set2))
374    }
375
376    #[inline]
377    pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
378        self.get_id(!self.get_set(set_id))
379    }
380
381    #[inline]
382    pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
383        self.and_id(set1, set2) != TSetId::EMPTY
384    }
385    #[inline]
386    pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
387        self.and_id(set1, set2) == TSetId::EMPTY
388    }
389
390    pub fn byte_count(&self, set_id: TSetId) -> u32 {
391        let tset = self.get_set(set_id);
392        (0..4).map(|i| tset[i].count_ones()).sum()
393    }
394
395    pub fn collect_bytes(&self, set_id: TSetId) -> Vec<u8> {
396        let tset = self.get_set(set_id);
397        let mut bytes = Vec::new();
398        for i in 0..4 {
399            let mut bits = tset[i];
400            while bits != 0 {
401                let j = bits.trailing_zeros() as usize;
402                bytes.push((i * 64 + j) as u8);
403                bits &= bits - 1;
404            }
405        }
406        bytes
407    }
408
409    pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
410        let tset = self.get_set(set_id);
411        let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
412        if total != 1 {
413            return None;
414        }
415        for i in 0..4 {
416            if tset[i] != 0 {
417                return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
418            }
419        }
420        None
421    }
422
423    #[inline]
424    pub fn is_empty_id(&self, set1: TSetId) -> bool {
425        set1 == TSetId::EMPTY
426    }
427
428    #[inline]
429    pub fn is_full_id(&self, set1: TSetId) -> bool {
430        set1 == TSetId::FULL
431    }
432
433    #[inline]
434    pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
435        let not_large = self.not_id(large_id);
436        self.and_id(small_id, not_large) == TSetId::EMPTY
437    }
438
439    pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
440        let mut result = TSet::splat(u64::MIN);
441        let nthbit = 1u64 << (byte % 64);
442        match byte {
443            0..=63 => {
444                result[0] = nthbit;
445            }
446            64..=127 => {
447                result[1] = nthbit;
448            }
449            128..=191 => {
450                result[2] = nthbit;
451            }
452            192..=255 => {
453                result[3] = nthbit;
454            }
455        }
456        self.get_id(result)
457    }
458
459    pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
460        let mut result = TSet::splat(u64::MIN);
461        for byte in start..=end {
462            let nthbit = 1u64 << (byte % 64);
463            match byte {
464                0..=63 => {
465                    result[0] |= nthbit;
466                }
467                64..=127 => {
468                    result[1] |= nthbit;
469                }
470                128..=191 => {
471                    result[2] |= nthbit;
472                }
473                192..=255 => {
474                    result[3] |= nthbit;
475                }
476            }
477        }
478        self.get_id(result)
479    }
480
481    #[inline]
482    pub fn and(set1: &TSet, set2: &TSet) -> TSet {
483        *set1 & *set2
484    }
485
486    #[inline]
487    pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
488        *set1 & *set2 != Solver::empty()
489    }
490
491    #[inline]
492    pub fn or(set1: &TSet, set2: &TSet) -> TSet {
493        *set1 | *set2
494    }
495
496    #[inline]
497    pub fn not(set: &TSet) -> TSet {
498        !*set
499    }
500
501    #[inline]
502    pub fn is_full(set: &TSet) -> bool {
503        *set == Self::full()
504    }
505
506    #[inline]
507    pub fn is_empty(set: &TSet) -> bool {
508        *set == Solver::empty()
509    }
510
511    #[inline]
512    pub fn contains(large: &TSet, small: &TSet) -> bool {
513        Solver::empty() == (*small & !*large)
514    }
515
516    pub fn u8_to_set(byte: u8) -> TSet {
517        let mut result = TSet::splat(u64::MIN);
518        let nthbit = 1u64 << (byte % 64);
519        match byte {
520            0..=63 => {
521                result[0] = nthbit;
522            }
523            64..=127 => {
524                result[1] = nthbit;
525            }
526            128..=191 => {
527                result[2] = nthbit;
528            }
529            192..=255 => {
530                result[3] = nthbit;
531            }
532        }
533        result
534    }
535
536    pub fn range_to_set(start: u8, end: u8) -> TSet {
537        let mut result = TSet::splat(u64::MIN);
538        for byte in start..=end {
539            let nthbit = 1u64 << (byte % 64);
540            match byte {
541                0..=63 => {
542                    result[0] |= nthbit;
543                }
544                64..=127 => {
545                    result[1] |= nthbit;
546                }
547                128..=191 => {
548                    result[2] |= nthbit;
549                }
550                192..=255 => {
551                    result[3] |= nthbit;
552                }
553            }
554        }
555        result
556    }
557}