// resharp_algebra/solver.rs

/// A set of byte values stored as a 256-bit bitmap split across four `u64` words.
///
/// Byte `b` is a member exactly when bit `b % 64` of word `b / 64` is set.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct TSet(pub [u64; 4]);

impl TSet {
    /// Builds a set whose four words all equal `v`.
    #[inline]
    pub const fn splat(v: u64) -> Self {
        TSet([v; 4])
    }

    /// Builds the set containing exactly the bytes listed in `bytes`.
    ///
    /// Repeated bytes are harmless; an empty slice yields the empty set.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let words = bytes.iter().fold([0u64; 4], |mut acc, &byte| {
            let pos = usize::from(byte);
            acc[pos >> 6] |= 1u64 << (pos & 63);
            acc
        });
        Self(words)
    }

    /// Returns `true` when byte `b` is a member of the set.
    #[inline(always)]
    pub fn contains_byte(&self, b: u8) -> bool {
        let pos = usize::from(b);
        (self.0[pos >> 6] >> (pos & 63)) & 1 == 1
    }
}
24
25impl std::ops::Index<usize> for TSet {
26    type Output = u64;
27    #[inline]
28    fn index(&self, i: usize) -> &u64 {
29        &self.0[i]
30    }
31}
32
33impl std::ops::IndexMut<usize> for TSet {
34    #[inline]
35    fn index_mut(&mut self, i: usize) -> &mut u64 {
36        &mut self.0[i]
37    }
38}
39
40impl std::ops::BitAnd for TSet {
41    type Output = TSet;
42    #[inline]
43    fn bitand(self, rhs: TSet) -> TSet {
44        TSet([
45            self.0[0] & rhs.0[0],
46            self.0[1] & rhs.0[1],
47            self.0[2] & rhs.0[2],
48            self.0[3] & rhs.0[3],
49        ])
50    }
51}
52
53impl std::ops::BitAnd for &TSet {
54    type Output = TSet;
55    #[inline]
56    fn bitand(self, rhs: &TSet) -> TSet {
57        TSet([
58            self.0[0] & rhs.0[0],
59            self.0[1] & rhs.0[1],
60            self.0[2] & rhs.0[2],
61            self.0[3] & rhs.0[3],
62        ])
63    }
64}
65
66impl std::ops::BitOr for TSet {
67    type Output = TSet;
68    #[inline]
69    fn bitor(self, rhs: TSet) -> TSet {
70        TSet([
71            self.0[0] | rhs.0[0],
72            self.0[1] | rhs.0[1],
73            self.0[2] | rhs.0[2],
74            self.0[3] | rhs.0[3],
75        ])
76    }
77}
78
79impl std::ops::Not for TSet {
80    type Output = TSet;
81    #[inline]
82    fn not(self) -> TSet {
83        TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
84    }
85}
86
87// &TSet ops used by Solver helper methods
88impl std::ops::BitAnd<TSet> for &TSet {
89    type Output = TSet;
90    #[inline]
91    fn bitand(self, rhs: TSet) -> TSet {
92        TSet([
93            self.0[0] & rhs.0[0],
94            self.0[1] & rhs.0[1],
95            self.0[2] & rhs.0[2],
96            self.0[3] & rhs.0[3],
97        ])
98    }
99}
100
101impl std::ops::BitOr<TSet> for &TSet {
102    type Output = TSet;
103    #[inline]
104    fn bitor(self, rhs: TSet) -> TSet {
105        TSet([
106            self.0[0] | rhs.0[0],
107            self.0[1] | rhs.0[1],
108            self.0[2] | rhs.0[2],
109            self.0[3] | rhs.0[3],
110        ])
111    }
112}
113
114const EMPTY: TSet = TSet::splat(u64::MIN);
115const FULL: TSet = TSet::splat(u64::MAX);
116
/// Compact `u32` handle for a byte set; `Solver` uses it as an index into its
/// table of interned sets.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Id of the empty set (the first set `Solver::new` registers).
    pub const EMPTY: Self = Self(0);
    /// Id of the full set (the second set `Solver::new` registers).
    pub const FULL: Self = Self(1);
}
123
124use rustc_hash::FxHashMap;
125use std::collections::BTreeSet;
126
127pub struct Solver {
128    cache: FxHashMap<TSet, TSetId>,
129    pub array: Vec<TSet>,
130}
131
132impl Default for Solver {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138impl Solver {
139    pub fn new() -> Solver {
140        let mut inst = Self {
141            cache: FxHashMap::default(),
142            array: Vec::new(),
143        };
144        let _ = inst.init(Solver::empty()); // 0
145        let _ = inst.init(Solver::full()); // 1
146        inst
147    }
148
149    fn init(&mut self, inst: TSet) -> TSetId {
150        let new_id = TSetId(self.cache.len() as u32);
151        self.cache.insert(inst, new_id);
152        self.array.push(inst);
153        new_id
154    }
155
156    pub fn get_set(&self, set_id: TSetId) -> TSet {
157        self.array[set_id.0 as usize]
158    }
159
160    pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
161        &self.array[set_id.0 as usize]
162    }
163
164    pub fn get_id(&mut self, inst: TSet) -> TSetId {
165        match self.cache.get(&inst) {
166            Some(&id) => id,
167            None => self.init(inst),
168        }
169    }
170
171    pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
172        self.array[set_id.0 as usize][idx] & bit != 0
173    }
174
175    pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
176        let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
177        let mut rangestart: Option<u8> = None;
178        let mut prevchar: Option<u8> = None;
179        for i in 0..4 {
180            for j in 0..64 {
181                let nthbit = 1u64 << j;
182                if tset[i] & nthbit != 0 {
183                    let cc = (i * 64 + j) as u8;
184                    if rangestart.is_none() {
185                        rangestart = Some(cc);
186                        prevchar = Some(cc);
187                        continue;
188                    }
189
190                    if let (Some(currstart), Some(currprev)) = (rangestart, prevchar) {
191                        if currprev == cc - 1 {
192                            prevchar = Some(cc);
193                            continue;
194                        }
195                        ranges.insert((currstart, currprev));
196                        rangestart = Some(cc);
197                        prevchar = Some(cc);
198                    }
199                }
200            }
201        }
202        if let (Some(start), Some(end)) = (rangestart, prevchar) {
203            ranges.insert((start, end));
204        }
205        ranges
206    }
207
208    fn pp_byte(b: u8) -> String {
209        if cfg!(feature = "graphviz") {
210            match b as char {
211                // graphviz doesnt like \n so we use \ṅ
212                '\n' => return r"\ṅ".to_owned(),
213                '"' => return r"\u{201c}".to_owned(),
214                '\r' => return r"\r".to_owned(),
215                '\t' => return r"\t".to_owned(),
216                _ => {}
217            }
218        }
219        match b as char {
220            '\n' => r"\n".to_owned(),
221            '\r' => r"\r".to_owned(),
222            '\t' => r"\t".to_owned(),
223            ' ' => r" ".to_owned(),
224            '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
225            | '*' | '?' | '^' | '$' => r"\".to_owned() + &(b as char).to_string(),
226            c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
227            _ => format!("\\x{:02X}", b),
228        }
229    }
230
231    fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
232        let display_range = |c, c2| {
233            if c == c2 {
234                Self::pp_byte(c)
235            } else if c.abs_diff(c2) == 1 {
236                format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
237            } else {
238                format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
239            }
240        };
241
242        if ranges.is_empty() {
243            return "\u{22a5}".to_owned();
244        }
245        if ranges.len() == 1 {
246            let (s, e) = ranges.iter().next().unwrap();
247            if s == e {
248                return Self::pp_byte(*s);
249            } else {
250                return ranges
251                    .iter()
252                    .map(|(s, e)| display_range(*s, *e))
253                    .collect::<Vec<_>>()
254                    .join("")
255                    .to_string();
256            }
257        }
258        if ranges.len() > 20 {
259            return "\u{03c6}".to_owned();
260        }
261        ranges
262            .iter()
263            .map(|(s, e)| display_range(*s, *e))
264            .collect::<Vec<_>>()
265            .join("")
266            .to_string()
267    }
268
269    pub fn pp_first(&self, tset: &TSet) -> char {
270        let tryn1 = |i: usize| {
271            for j in 0..32 {
272                let nthbit = 1u64 << j;
273                if tset[i] & nthbit != 0 {
274                    let cc = (i * 64 + j) as u8 as char;
275                    return Some(cc);
276                }
277            }
278            None
279        };
280        let tryn2 = |i: usize| {
281            for j in 33..64 {
282                let nthbit = 1u64 << j;
283                if tset[i] & nthbit != 0 {
284                    let cc = (i * 64 + j) as u8 as char;
285                    return Some(cc);
286                }
287            }
288            None
289        };
290        // readable ones first
291        tryn2(0)
292            .or_else(|| tryn2(1))
293            .or_else(|| tryn1(1))
294            .or_else(|| tryn1(2))
295            .or_else(|| tryn2(2))
296            .or_else(|| tryn1(3))
297            .or_else(|| tryn2(3))
298            .or_else(|| tryn1(0))
299            .unwrap_or('\u{22a5}')
300    }
301
302    pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
303        let tset = self.get_set(tset);
304        Self::pp_collect_ranges(&tset).into_iter().collect()
305    }
306
307    #[allow(unused)]
308    fn first_byte(tset: &TSet) -> u8 {
309        for i in 0..4 {
310            for j in 0..64 {
311                let nthbit = 1u64 << j;
312                if tset[i] & nthbit != 0 {
313                    let cc = (i * 64 + j) as u8;
314                    return cc;
315                }
316            }
317        }
318        0
319    }
320
321    pub fn pp(&self, tset: TSetId) -> String {
322        if tset == TSetId::FULL {
323            return "_".to_owned();
324        }
325        if tset == TSetId::EMPTY {
326            return "\u{22a5}".to_owned();
327        }
328        let tset = self.get_set(tset);
329        let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
330        let rstart = ranges.first().unwrap().0;
331        let rend = ranges.last().unwrap().1;
332        if ranges.len() >= 2 && rstart == 0 && rend == 255 {
333            let not_id = Self::not(&tset);
334            let not_ranges = Self::pp_collect_ranges(&not_id);
335            if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
336                return r".".to_owned();
337            }
338            let content = Self::pp_content(&not_ranges);
339            return format!("[^{}]", content);
340        }
341        if ranges.is_empty() {
342            return "\u{22a5}".to_owned();
343        }
344        if ranges.len() == 1 {
345            let (s, e) = ranges.iter().next().unwrap();
346            if s == e {
347                return Self::pp_byte(*s);
348            } else {
349                let content = Self::pp_content(&ranges);
350                return format!("[{}]", content);
351            }
352        }
353        let content = Self::pp_content(&ranges);
354        format!("[{}]", content)
355    }
356}
357
358impl Solver {
359    #[inline]
360    pub fn full() -> TSet {
361        FULL
362    }
363
364    #[inline]
365    pub fn empty() -> TSet {
366        EMPTY
367    }
368
369    #[inline]
370    pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
371        self.get_id(self.get_set(set1) | self.get_set(set2))
372    }
373
374    #[inline]
375    pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
376        self.get_id(self.get_set(set1) & self.get_set(set2))
377    }
378
379    #[inline]
380    pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
381        self.get_id(!self.get_set(set_id))
382    }
383
384    #[inline]
385    pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
386        self.and_id(set1, set2) != TSetId::EMPTY
387    }
388    #[inline]
389    pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
390        self.and_id(set1, set2) == TSetId::EMPTY
391    }
392
393    pub fn byte_count(&self, set_id: TSetId) -> u32 {
394        let tset = self.get_set(set_id);
395        (0..4).map(|i| tset[i].count_ones()).sum()
396    }
397
398    pub fn collect_bytes(&self, set_id: TSetId) -> Vec<u8> {
399        let tset = self.get_set(set_id);
400        let mut bytes = Vec::new();
401        for i in 0..4 {
402            let mut bits = tset[i];
403            while bits != 0 {
404                let j = bits.trailing_zeros() as usize;
405                bytes.push((i * 64 + j) as u8);
406                bits &= bits - 1;
407            }
408        }
409        bytes
410    }
411
412    pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
413        let tset = self.get_set(set_id);
414        let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
415        if total != 1 {
416            return None;
417        }
418        for i in 0..4 {
419            if tset[i] != 0 {
420                return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
421            }
422        }
423        None
424    }
425
426    #[inline]
427    pub fn is_empty_id(&self, set1: TSetId) -> bool {
428        set1 == TSetId::EMPTY
429    }
430
431    #[inline]
432    pub fn is_full_id(&self, set1: TSetId) -> bool {
433        set1 == TSetId::FULL
434    }
435
436    #[inline]
437    pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
438        let not_large = self.not_id(large_id);
439        self.and_id(small_id, not_large) == TSetId::EMPTY
440    }
441
442    pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
443        let mut result = TSet::splat(u64::MIN);
444        let nthbit = 1u64 << (byte % 64);
445        match byte {
446            0..=63 => {
447                result[0] = nthbit;
448            }
449            64..=127 => {
450                result[1] = nthbit;
451            }
452            128..=191 => {
453                result[2] = nthbit;
454            }
455            192..=255 => {
456                result[3] = nthbit;
457            }
458        }
459        self.get_id(result)
460    }
461
462    pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
463        let mut result = TSet::splat(u64::MIN);
464        for byte in start..=end {
465            let nthbit = 1u64 << (byte % 64);
466            match byte {
467                0..=63 => {
468                    result[0] |= nthbit;
469                }
470                64..=127 => {
471                    result[1] |= nthbit;
472                }
473                128..=191 => {
474                    result[2] |= nthbit;
475                }
476                192..=255 => {
477                    result[3] |= nthbit;
478                }
479            }
480        }
481        self.get_id(result)
482    }
483
484    #[inline]
485    pub fn and(set1: &TSet, set2: &TSet) -> TSet {
486        *set1 & *set2
487    }
488
489    #[inline]
490    pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
491        *set1 & *set2 != Solver::empty()
492    }
493
494    #[inline]
495    pub fn or(set1: &TSet, set2: &TSet) -> TSet {
496        *set1 | *set2
497    }
498
499    #[inline]
500    pub fn not(set: &TSet) -> TSet {
501        !*set
502    }
503
504    #[inline]
505    pub fn is_full(set: &TSet) -> bool {
506        *set == Self::full()
507    }
508
509    #[inline]
510    pub fn is_empty(set: &TSet) -> bool {
511        *set == Solver::empty()
512    }
513
514    #[inline]
515    pub fn contains(large: &TSet, small: &TSet) -> bool {
516        Solver::empty() == (*small & !*large)
517    }
518
519    pub fn u8_to_set(byte: u8) -> TSet {
520        let mut result = TSet::splat(u64::MIN);
521        let nthbit = 1u64 << (byte % 64);
522        match byte {
523            0..=63 => {
524                result[0] = nthbit;
525            }
526            64..=127 => {
527                result[1] = nthbit;
528            }
529            128..=191 => {
530                result[2] = nthbit;
531            }
532            192..=255 => {
533                result[3] = nthbit;
534            }
535        }
536        result
537    }
538
539    pub fn range_to_set(start: u8, end: u8) -> TSet {
540        let mut result = TSet::splat(u64::MIN);
541        for byte in start..=end {
542            let nthbit = 1u64 << (byte % 64);
543            match byte {
544                0..=63 => {
545                    result[0] |= nthbit;
546                }
547                64..=127 => {
548                    result[1] |= nthbit;
549                }
550                128..=191 => {
551                    result[2] |= nthbit;
552                }
553                192..=255 => {
554                    result[3] |= nthbit;
555                }
556            }
557        }
558        result
559    }
560}