// resharp_algebra/solver.rs

/// A set of byte values stored as a 256-bit bitmap.
///
/// Byte `b` is a member when bit `b % 64` of word `b / 64` is set.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSet(pub [u64; 4]);

impl TSet {
    /// Builds a set whose four words all equal `v`.
    #[inline]
    pub const fn splat(v: u64) -> Self {
        Self([v; 4])
    }

    /// Builds the set containing exactly the bytes in `bytes`.
    ///
    /// Duplicate bytes are harmless; an empty slice yields the empty set.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        bytes.iter().fold(Self::splat(0), |mut acc, &byte| {
            // High two bits select the word, low six bits select the bit.
            acc.0[usize::from(byte >> 6)] |= 1u64 << (byte & 63);
            acc
        })
    }

    /// Returns `true` when `b` is a member of the set.
    #[inline(always)]
    pub fn contains_byte(&self, b: u8) -> bool {
        let word = self.0[usize::from(b >> 6)];
        (word >> (b & 63)) & 1 == 1
    }
}
23
24impl std::ops::Index<usize> for TSet {
25    type Output = u64;
26    #[inline]
27    fn index(&self, i: usize) -> &u64 {
28        &self.0[i]
29    }
30}
31
32impl std::ops::IndexMut<usize> for TSet {
33    #[inline]
34    fn index_mut(&mut self, i: usize) -> &mut u64 {
35        &mut self.0[i]
36    }
37}
38
39impl std::ops::BitAnd for TSet {
40    type Output = TSet;
41    #[inline]
42    fn bitand(self, rhs: TSet) -> TSet {
43        TSet([
44            self.0[0] & rhs.0[0],
45            self.0[1] & rhs.0[1],
46            self.0[2] & rhs.0[2],
47            self.0[3] & rhs.0[3],
48        ])
49    }
50}
51
52impl std::ops::BitAnd for &TSet {
53    type Output = TSet;
54    #[inline]
55    fn bitand(self, rhs: &TSet) -> TSet {
56        TSet([
57            self.0[0] & rhs.0[0],
58            self.0[1] & rhs.0[1],
59            self.0[2] & rhs.0[2],
60            self.0[3] & rhs.0[3],
61        ])
62    }
63}
64
65impl std::ops::BitOr for TSet {
66    type Output = TSet;
67    #[inline]
68    fn bitor(self, rhs: TSet) -> TSet {
69        TSet([
70            self.0[0] | rhs.0[0],
71            self.0[1] | rhs.0[1],
72            self.0[2] | rhs.0[2],
73            self.0[3] | rhs.0[3],
74        ])
75    }
76}
77
78impl std::ops::Not for TSet {
79    type Output = TSet;
80    #[inline]
81    fn not(self) -> TSet {
82        TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
83    }
84}
85
86// &TSet ops used by Solver helper methods
87impl std::ops::BitAnd<TSet> for &TSet {
88    type Output = TSet;
89    #[inline]
90    fn bitand(self, rhs: TSet) -> TSet {
91        TSet([
92            self.0[0] & rhs.0[0],
93            self.0[1] & rhs.0[1],
94            self.0[2] & rhs.0[2],
95            self.0[3] & rhs.0[3],
96        ])
97    }
98}
99
100impl std::ops::BitOr<TSet> for &TSet {
101    type Output = TSet;
102    #[inline]
103    fn bitor(self, rhs: TSet) -> TSet {
104        TSet([
105            self.0[0] | rhs.0[0],
106            self.0[1] | rhs.0[1],
107            self.0[2] | rhs.0[2],
108            self.0[3] | rhs.0[3],
109        ])
110    }
111}
112
113const EMPTY: TSet = TSet::splat(u64::MIN);
114const FULL: TSet = TSet::splat(u64::MAX);
115
/// Numeric handle for a byte set; wraps a `u32`.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Handle with the value 0.
    pub const EMPTY: Self = Self(0);
    /// Handle with the value 1.
    pub const FULL: Self = Self(1);
}
122
123use rustc_hash::FxHashMap;
124use std::collections::BTreeSet;
125
126pub struct Solver {
127    cache: FxHashMap<TSet, TSetId>,
128    pub array: Vec<TSet>,
129}
130
131impl Default for Solver {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl Solver {
138    pub fn new() -> Solver {
139        let mut inst = Self {
140            cache: FxHashMap::default(),
141            array: Vec::new(),
142        };
143        let _ = inst.init(Solver::empty()); // 0
144        let _ = inst.init(Solver::full()); // 1
145        inst
146    }
147
148    fn init(&mut self, inst: TSet) -> TSetId {
149        let new_id = TSetId(self.cache.len() as u32);
150        self.cache.insert(inst, new_id);
151        self.array.push(inst);
152        new_id
153    }
154
155    pub fn get_set(&self, set_id: TSetId) -> TSet {
156        self.array[set_id.0 as usize]
157    }
158
159    pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
160        &self.array[set_id.0 as usize]
161    }
162
163    pub fn get_id(&mut self, inst: TSet) -> TSetId {
164        match self.cache.get(&inst) {
165            Some(&id) => id,
166            None => self.init(inst),
167        }
168    }
169
170    pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
171        self.array[set_id.0 as usize][idx] & bit != 0
172    }
173
174    pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
175        let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
176        let mut rangestart: Option<u8> = None;
177        let mut prevchar: Option<u8> = None;
178        for i in 0..4 {
179            for j in 0..64 {
180                let nthbit = 1u64 << j;
181                if tset[i] & nthbit != 0 {
182                    let cc = (i * 64 + j) as u8;
183                    if rangestart.is_none() {
184                        rangestart = Some(cc);
185                        prevchar = Some(cc);
186                        continue;
187                    }
188
189                    if let (Some(currstart), Some(currprev)) = (rangestart, prevchar) {
190                        if currprev == cc - 1 {
191                            prevchar = Some(cc);
192                            continue;
193                        }
194                        ranges.insert((currstart, currprev));
195                        rangestart = Some(cc);
196                        prevchar = Some(cc);
197                    }
198                }
199            }
200        }
201        if let (Some(start), Some(end)) = (rangestart, prevchar) {
202            ranges.insert((start, end));
203        }
204        ranges
205    }
206
207    fn pp_byte(b: u8) -> String {
208        if cfg!(feature = "graphviz") {
209            match b as char {
210                // graphviz doesnt like \n so we use \ṅ
211                '\n' => return r"\ṅ".to_owned(),
212                '"' => return r"\u{201c}".to_owned(),
213                '\r' => return r"\r".to_owned(),
214                '\t' => return r"\t".to_owned(),
215                _ => {}
216            }
217        }
218        match b as char {
219            '\n' => r"\n".to_owned(),
220            '\r' => r"\r".to_owned(),
221            '\t' => r"\t".to_owned(),
222            ' ' => r" ".to_owned(),
223            '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
224            | '*' | '?' | '^' | '$' => r"\".to_owned() + &(b as char).to_string(),
225            c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
226            _ => format!("\\x{:02X}", b),
227        }
228    }
229
230    fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
231        let display_range = |c, c2| {
232            if c == c2 {
233                Self::pp_byte(c)
234            } else if c.abs_diff(c2) == 1 {
235                format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
236            } else {
237                format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
238            }
239        };
240
241        if ranges.is_empty() {
242            return "\u{22a5}".to_owned();
243        }
244        if ranges.len() == 1 {
245            let (s, e) = ranges.iter().next().unwrap();
246            if s == e {
247                return Self::pp_byte(*s);
248            } else {
249                return ranges
250                    .iter()
251                    .map(|(s, e)| display_range(*s, *e))
252                    .collect::<Vec<_>>()
253                    .join("")
254                    .to_string();
255            }
256        }
257        if ranges.len() > 20 {
258            return "\u{03c6}".to_owned();
259        }
260        ranges
261            .iter()
262            .map(|(s, e)| display_range(*s, *e))
263            .collect::<Vec<_>>()
264            .join("")
265            .to_string()
266    }
267
268    pub fn pp_first(&self, tset: &TSet) -> char {
269        let tryn1 = |i: usize| {
270            for j in 0..32 {
271                let nthbit = 1u64 << j;
272                if tset[i] & nthbit != 0 {
273                    let cc = (i * 64 + j) as u8 as char;
274                    return Some(cc);
275                }
276            }
277            None
278        };
279        let tryn2 = |i: usize| {
280            for j in 33..64 {
281                let nthbit = 1u64 << j;
282                if tset[i] & nthbit != 0 {
283                    let cc = (i * 64 + j) as u8 as char;
284                    return Some(cc);
285                }
286            }
287            None
288        };
289        // readable ones first
290        tryn2(0)
291            .or_else(|| tryn2(1))
292            .or_else(|| tryn1(1))
293            .or_else(|| tryn1(2))
294            .or_else(|| tryn2(2))
295            .or_else(|| tryn1(3))
296            .or_else(|| tryn2(3))
297            .or_else(|| tryn1(0))
298            .unwrap_or('\u{22a5}')
299    }
300
301    pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
302        let tset = self.get_set(tset);
303        Self::pp_collect_ranges(&tset).into_iter().collect()
304    }
305
306    #[allow(unused)]
307    fn first_byte(tset: &TSet) -> u8 {
308        for i in 0..4 {
309            for j in 0..64 {
310                let nthbit = 1u64 << j;
311                if tset[i] & nthbit != 0 {
312                    let cc = (i * 64 + j) as u8;
313                    return cc;
314                }
315            }
316        }
317        0
318    }
319
320    pub fn pp(&self, tset: TSetId) -> String {
321        if tset == TSetId::FULL {
322            return "_".to_owned();
323        }
324        if tset == TSetId::EMPTY {
325            return "\u{22a5}".to_owned();
326        }
327        let tset = self.get_set(tset);
328        let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
329        let rstart = ranges.first().unwrap().0;
330        let rend = ranges.last().unwrap().1;
331        if ranges.len() >= 2 && rstart == 0 && rend == 255 {
332            let not_id = Self::not(&tset);
333            let not_ranges = Self::pp_collect_ranges(&not_id);
334            if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
335                return r".".to_owned();
336            }
337            let content = Self::pp_content(&not_ranges);
338            return format!("[^{}]", content);
339        }
340        if ranges.is_empty() {
341            return "\u{22a5}".to_owned();
342        }
343        if ranges.len() == 1 {
344            let (s, e) = ranges.iter().next().unwrap();
345            if s == e {
346                return Self::pp_byte(*s);
347            } else {
348                let content = Self::pp_content(&ranges);
349                return format!("[{}]", content);
350            }
351        }
352        let content = Self::pp_content(&ranges);
353        format!("[{}]", content)
354    }
355}
356
357impl Solver {
358    #[inline]
359    pub fn full() -> TSet {
360        FULL
361    }
362
363    #[inline]
364    pub fn empty() -> TSet {
365        EMPTY
366    }
367
368    #[inline]
369    pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
370        self.get_id(self.get_set(set1) | self.get_set(set2))
371    }
372
373    #[inline]
374    pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
375        self.get_id(self.get_set(set1) & self.get_set(set2))
376    }
377
378    #[inline]
379    pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
380        self.get_id(!self.get_set(set_id))
381    }
382
383    #[inline]
384    pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
385        self.and_id(set1, set2) != TSetId::EMPTY
386    }
387    #[inline]
388    pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
389        self.and_id(set1, set2) == TSetId::EMPTY
390    }
391
392    pub fn byte_count(&self, set_id: TSetId) -> u32 {
393        let tset = self.get_set(set_id);
394        (0..4).map(|i| tset[i].count_ones()).sum()
395    }
396
397    pub fn collect_bytes(&self, set_id: TSetId) -> Vec<u8> {
398        let tset = self.get_set(set_id);
399        let mut bytes = Vec::new();
400        for i in 0..4 {
401            let mut bits = tset[i];
402            while bits != 0 {
403                let j = bits.trailing_zeros() as usize;
404                bytes.push((i * 64 + j) as u8);
405                bits &= bits - 1;
406            }
407        }
408        bytes
409    }
410
411    pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
412        let tset = self.get_set(set_id);
413        let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
414        if total != 1 {
415            return None;
416        }
417        for i in 0..4 {
418            if tset[i] != 0 {
419                return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
420            }
421        }
422        None
423    }
424
425    #[inline]
426    pub fn is_empty_id(&self, set1: TSetId) -> bool {
427        set1 == TSetId::EMPTY
428    }
429
430    #[inline]
431    pub fn is_full_id(&self, set1: TSetId) -> bool {
432        set1 == TSetId::FULL
433    }
434
435    #[inline]
436    pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
437        let not_large = self.not_id(large_id);
438        self.and_id(small_id, not_large) == TSetId::EMPTY
439    }
440
441    pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
442        let mut result = TSet::splat(u64::MIN);
443        let nthbit = 1u64 << (byte % 64);
444        match byte {
445            0..=63 => {
446                result[0] = nthbit;
447            }
448            64..=127 => {
449                result[1] = nthbit;
450            }
451            128..=191 => {
452                result[2] = nthbit;
453            }
454            192..=255 => {
455                result[3] = nthbit;
456            }
457        }
458        self.get_id(result)
459    }
460
461    pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
462        let mut result = TSet::splat(u64::MIN);
463        for byte in start..=end {
464            let nthbit = 1u64 << (byte % 64);
465            match byte {
466                0..=63 => {
467                    result[0] |= nthbit;
468                }
469                64..=127 => {
470                    result[1] |= nthbit;
471                }
472                128..=191 => {
473                    result[2] |= nthbit;
474                }
475                192..=255 => {
476                    result[3] |= nthbit;
477                }
478            }
479        }
480        self.get_id(result)
481    }
482
483    #[inline]
484    pub fn and(set1: &TSet, set2: &TSet) -> TSet {
485        *set1 & *set2
486    }
487
488    #[inline]
489    pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
490        *set1 & *set2 != Solver::empty()
491    }
492
493    #[inline]
494    pub fn or(set1: &TSet, set2: &TSet) -> TSet {
495        *set1 | *set2
496    }
497
498    #[inline]
499    pub fn not(set: &TSet) -> TSet {
500        !*set
501    }
502
503    #[inline]
504    pub fn is_full(set: &TSet) -> bool {
505        *set == Self::full()
506    }
507
508    #[inline]
509    pub fn is_empty(set: &TSet) -> bool {
510        *set == Solver::empty()
511    }
512
513    #[inline]
514    pub fn contains(large: &TSet, small: &TSet) -> bool {
515        Solver::empty() == (*small & !*large)
516    }
517
518    pub fn u8_to_set(byte: u8) -> TSet {
519        let mut result = TSet::splat(u64::MIN);
520        let nthbit = 1u64 << (byte % 64);
521        match byte {
522            0..=63 => {
523                result[0] = nthbit;
524            }
525            64..=127 => {
526                result[1] = nthbit;
527            }
528            128..=191 => {
529                result[2] = nthbit;
530            }
531            192..=255 => {
532                result[3] = nthbit;
533            }
534        }
535        result
536    }
537
538    pub fn range_to_set(start: u8, end: u8) -> TSet {
539        let mut result = TSet::splat(u64::MIN);
540        for byte in start..=end {
541            let nthbit = 1u64 << (byte % 64);
542            match byte {
543                0..=63 => {
544                    result[0] |= nthbit;
545                }
546                64..=127 => {
547                    result[1] |= nthbit;
548                }
549                128..=191 => {
550                    result[2] |= nthbit;
551                }
552                192..=255 => {
553                    result[3] |= nthbit;
554                }
555            }
556        }
557        result
558    }
559}