// resharp_algebra/solver.rs

/// A set of byte values (`0..=255`) stored as a 256-bit bitmap.
///
/// Byte `b` is a member of the set when bit `b % 64` of word `b / 64` is set.
/// The derived ordering is the lexicographic order of the four words.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSet(pub [u64; 4]);
4impl TSet {
5    #[inline]
6    pub const fn splat(v: u64) -> Self {
7        TSet([v, v, v, v])
8    }
9
10    pub fn from_bytes(bytes: &[u8]) -> Self {
11        let mut bits = [0u64; 4];
12        for &b in bytes {
13            bits[b as usize / 64] |= 1u64 << (b as usize % 64);
14        }
15        Self(bits)
16    }
17
18    #[inline(always)]
19    pub fn contains_byte(&self, b: u8) -> bool {
20        self.0[b as usize / 64] & (1u64 << (b as usize % 64)) != 0
21    }
22}
23
24impl std::ops::Index<usize> for TSet {
25    type Output = u64;
26    #[inline]
27    fn index(&self, i: usize) -> &u64 {
28        &self.0[i]
29    }
30}
31
32impl std::ops::IndexMut<usize> for TSet {
33    #[inline]
34    fn index_mut(&mut self, i: usize) -> &mut u64 {
35        &mut self.0[i]
36    }
37}
38
39impl std::ops::BitAnd for TSet {
40    type Output = TSet;
41    #[inline]
42    fn bitand(self, rhs: TSet) -> TSet {
43        TSet([
44            self.0[0] & rhs.0[0],
45            self.0[1] & rhs.0[1],
46            self.0[2] & rhs.0[2],
47            self.0[3] & rhs.0[3],
48        ])
49    }
50}
51
52impl std::ops::BitAnd for &TSet {
53    type Output = TSet;
54    #[inline]
55    fn bitand(self, rhs: &TSet) -> TSet {
56        TSet([
57            self.0[0] & rhs.0[0],
58            self.0[1] & rhs.0[1],
59            self.0[2] & rhs.0[2],
60            self.0[3] & rhs.0[3],
61        ])
62    }
63}
64
65impl std::ops::BitOr for TSet {
66    type Output = TSet;
67    #[inline]
68    fn bitor(self, rhs: TSet) -> TSet {
69        TSet([
70            self.0[0] | rhs.0[0],
71            self.0[1] | rhs.0[1],
72            self.0[2] | rhs.0[2],
73            self.0[3] | rhs.0[3],
74        ])
75    }
76}
77
78impl std::ops::Not for TSet {
79    type Output = TSet;
80    #[inline]
81    fn not(self) -> TSet {
82        TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
83    }
84}
85
86// &TSet ops used by Solver helper methods
87impl std::ops::BitAnd<TSet> for &TSet {
88    type Output = TSet;
89    #[inline]
90    fn bitand(self, rhs: TSet) -> TSet {
91        TSet([
92            self.0[0] & rhs.0[0],
93            self.0[1] & rhs.0[1],
94            self.0[2] & rhs.0[2],
95            self.0[3] & rhs.0[3],
96        ])
97    }
98}
99
100impl std::ops::BitOr<TSet> for &TSet {
101    type Output = TSet;
102    #[inline]
103    fn bitor(self, rhs: TSet) -> TSet {
104        TSet([
105            self.0[0] | rhs.0[0],
106            self.0[1] | rhs.0[1],
107            self.0[2] | rhs.0[2],
108            self.0[3] | rhs.0[3],
109        ])
110    }
111}
112
113const EMPTY: TSet = TSet::splat(u64::MIN);
114const FULL: TSet = TSet::splat(u64::MAX);
115
/// Handle to a set interned in a `Solver`.
///
/// The wrapped value is the index of the set in `Solver::array`.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Id of the empty set; `Solver::new` interns it first.
    pub const EMPTY: TSetId = TSetId(0);
    /// Id of the full set; `Solver::new` interns it second.
    pub const FULL: TSetId = TSetId(1);
}
use std::collections::{BTreeMap, BTreeSet};
125pub struct Solver {
126    cache: BTreeMap<TSet, TSetId>,
127    pub array: Vec<TSet>,
128}
129
130impl Solver {
131    pub fn new() -> Solver {
132        let mut inst = Self {
133            cache: BTreeMap::new(),
134            array: Vec::new(),
135        };
136        let _ = inst.init(Solver::empty()); // 0
137        let _ = inst.init(Solver::full()); // 1
138        inst
139    }
140
141    fn init(&mut self, inst: TSet) -> TSetId {
142        let new_id = TSetId(self.cache.len() as u32);
143        self.cache.insert(inst, new_id);
144        self.array.push(inst);
145        new_id
146    }
147
148    pub fn get_set(&self, set_id: TSetId) -> TSet {
149        self.array[set_id.0 as usize]
150    }
151
152    pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
153        &self.array[set_id.0 as usize]
154    }
155
156    pub fn get_id(&mut self, inst: TSet) -> TSetId {
157        match self.cache.get(&inst) {
158            Some(&id) => id,
159            None => self.init(inst),
160        }
161    }
162
163    pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
164        self.array[set_id.0 as usize][idx] & bit != 0
165    }
166
167    pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
168        let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
169        let mut rangestart: Option<u8> = None;
170        let mut prevchar: Option<u8> = None;
171        for i in 0..4 {
172            for j in 0..64 {
173                let nthbit = 1u64 << j;
174                if tset[i] & nthbit != 0 {
175                    let cc = (i * 64 + j) as u8;
176                    if rangestart.is_none() {
177                        rangestart = Some(cc);
178                        prevchar = Some(cc);
179                        continue;
180                    }
181
182                    if let Some(currstart) = rangestart {
183                        if let Some(currprev) = prevchar {
184                            if currprev as u8 == cc as u8 - 1 {
185                                prevchar = Some(cc);
186                                continue;
187                            } else {
188                                if currstart == currprev {
189                                    ranges.insert((currstart, currstart));
190                                } else {
191                                    ranges.insert((currstart, currprev));
192                                }
193                                rangestart = Some(cc);
194                                prevchar = Some(cc);
195                            }
196                        } else {
197                        }
198                    } else {
199                    }
200                }
201            }
202        }
203        if let Some(start) = rangestart {
204            if let Some(prevchar) = prevchar {
205                if prevchar as u8 == start as u8 {
206                    ranges.insert((start, start));
207                } else {
208                    ranges.insert((start, prevchar));
209                }
210            } else {
211                // single char
212                ranges.insert((start, start));
213            }
214        }
215        ranges
216    }
217
218    fn pp_byte(b: u8) -> String {
219        if cfg!(feature = "graphviz") {
220            match b as char {
221                // graphviz doesnt like \n so we use \ṅ
222                '\n' => return r"\ṅ".to_owned(),
223                '"' => return r"\u{201c}".to_owned(),
224                '\r' => return r"\r".to_owned(),
225                '\t' => return r"\t".to_owned(),
226                _ => {}
227            }
228        }
229        match b as char {
230            '\n' => r"\n".to_owned(),
231            '\r' => r"\r".to_owned(),
232            '\t' => r"\t".to_owned(),
233            ' ' => r" ".to_owned(),
234            '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
235            | '*' | '?' | '^' | '$' => r"\".to_owned() + &(b as char).to_string(),
236            c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
237            _ => format!("\\x{:02X}", b),
238        }
239    }
240
241    fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
242        let display_range = |c, c2| {
243            if c == c2 {
244                Self::pp_byte(c)
245            } else if c.abs_diff(c2) == 1 {
246                format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
247            } else {
248                format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
249            }
250        };
251
252        if ranges.len() == 0 {
253            return "\u{22a5}".to_owned();
254        }
255        if ranges.len() == 1 {
256            let (s, e) = ranges.iter().next().unwrap();
257            if s == e {
258                return Self::pp_byte(*s);
259            } else {
260                return format!(
261                    "{}",
262                    ranges
263                        .iter()
264                        .map(|(s, e)| display_range(*s, *e))
265                        .collect::<Vec<_>>()
266                        .join("")
267                );
268            }
269        }
270        if ranges.len() > 20 {
271            return "\u{03c6}".to_owned();
272        }
273        return format!(
274            "{}",
275            ranges
276                .iter()
277                .map(|(s, e)| display_range(*s, *e))
278                .collect::<Vec<_>>()
279                .join("")
280        );
281    }
282
283    pub fn pp_first(&self, tset: &TSet) -> char {
284        let tryn1 = |i: usize| {
285            for j in 0..32 {
286                let nthbit = 1u64 << j;
287                if tset[i] & nthbit != 0 {
288                    let cc = (i * 64 + j) as u8 as char;
289                    return Some(cc);
290                }
291            }
292            None
293        };
294        let tryn2 = |i: usize| {
295            for j in 33..64 {
296                let nthbit = 1u64 << j;
297                if tset[i] & nthbit != 0 {
298                    let cc = (i * 64 + j) as u8 as char;
299                    return Some(cc);
300                }
301            }
302            None
303        };
304        // readable ones first
305        tryn2(0)
306            .or_else(|| tryn2(1))
307            .or_else(|| tryn1(1))
308            .or_else(|| tryn1(2))
309            .or_else(|| tryn2(2))
310            .or_else(|| tryn1(3))
311            .or_else(|| tryn2(3))
312            .or_else(|| tryn1(0))
313            .unwrap_or('\u{22a5}')
314    }
315
316    pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
317        let tset = self.get_set(tset);
318        Self::pp_collect_ranges(&tset).into_iter().collect()
319    }
320
321    #[allow(unused)]
322    fn first_byte(tset: &TSet) -> u8 {
323        for i in 0..4 {
324            for j in 0..64 {
325                let nthbit = 1u64 << j;
326                if tset[i] & nthbit != 0 {
327                    let cc = (i * 64 + j) as u8;
328                    return cc;
329                }
330            }
331        }
332        return 0;
333    }
334
335    pub fn pp(&self, tset: TSetId) -> String {
336        if tset == TSetId::FULL {
337            return "_".to_owned();
338        }
339        if tset == TSetId::EMPTY {
340            return "\u{22a5}".to_owned();
341        }
342        let tset = self.get_set(tset);
343        let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
344        let rstart = ranges.first().unwrap().0;
345        let rend = ranges.last().unwrap().1;
346        if ranges.len() >= 2 && rstart == 0 && rend == 255 {
347            let not_id = Self::not(&tset);
348            let not_ranges = Self::pp_collect_ranges(&not_id);
349            if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
350                return r".".to_owned();
351            }
352            let content = Self::pp_content(&not_ranges);
353            return format!("[^{}]", content);
354        }
355        if ranges.len() == 0 {
356            return "\u{22a5}".to_owned();
357        }
358        if ranges.len() == 1 {
359            let (s, e) = ranges.iter().next().unwrap();
360            if s == e {
361                return Self::pp_byte(*s);
362            } else {
363                let content = Self::pp_content(&ranges);
364                return format!("[{}]", content);
365            }
366        }
367        let content = Self::pp_content(&ranges);
368        return format!("[{}]", content);
369    }
370}
371
372impl Solver {
373    #[inline]
374    pub fn full() -> TSet {
375        FULL
376    }
377
378    #[inline]
379    pub fn empty() -> TSet {
380        EMPTY
381    }
382
383    #[inline]
384    pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
385        self.get_id(self.get_set(set1) | self.get_set(set2))
386    }
387
388    #[inline]
389    pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
390        self.get_id(self.get_set(set1) & self.get_set(set2))
391    }
392
393    #[inline]
394    pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
395        self.get_id(!self.get_set(set_id))
396    }
397
398    #[inline]
399    pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
400        self.and_id(set1, set2) != TSetId::EMPTY
401    }
402    #[inline]
403    pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
404        self.and_id(set1, set2) == TSetId::EMPTY
405    }
406
407    pub fn byte_count(&self, set_id: TSetId) -> u32 {
408        let tset = self.get_set(set_id);
409        (0..4).map(|i| tset[i].count_ones()).sum()
410    }
411
412    pub fn collect_bytes(&self, set_id: TSetId) -> Vec<u8> {
413        let tset = self.get_set(set_id);
414        let mut bytes = Vec::new();
415        for i in 0..4 {
416            let mut bits = tset[i];
417            while bits != 0 {
418                let j = bits.trailing_zeros() as usize;
419                bytes.push((i * 64 + j) as u8);
420                bits &= bits - 1;
421            }
422        }
423        bytes
424    }
425
426    pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
427        let tset = self.get_set(set_id);
428        let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
429        if total != 1 {
430            return None;
431        }
432        for i in 0..4 {
433            if tset[i] != 0 {
434                return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
435            }
436        }
437        None
438    }
439
440    #[inline]
441    pub fn is_empty_id(&self, set1: TSetId) -> bool {
442        set1 == TSetId::EMPTY
443    }
444
445    #[inline]
446    pub fn is_full_id(&self, set1: TSetId) -> bool {
447        set1 == TSetId::FULL
448    }
449
450    #[inline]
451    pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
452        let not_large = self.not_id(large_id);
453        self.and_id(small_id, not_large) == TSetId::EMPTY
454    }
455
456    pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
457        let mut result = TSet::splat(u64::MIN);
458        let nthbit = 1u64 << byte % 64;
459        match byte {
460            0..=63 => {
461                result[0] = nthbit;
462            }
463            64..=127 => {
464                result[1] = nthbit;
465            }
466            128..=191 => {
467                result[2] = nthbit;
468            }
469            192..=255 => {
470                result[3] = nthbit;
471            }
472        }
473        self.get_id(result)
474    }
475
476    pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
477        let mut result = TSet::splat(u64::MIN);
478        for byte in start..=end {
479            let nthbit = 1u64 << byte % 64;
480            match byte {
481                0..=63 => {
482                    result[0] |= nthbit;
483                }
484                64..=127 => {
485                    result[1] |= nthbit;
486                }
487                128..=191 => {
488                    result[2] |= nthbit;
489                }
490                192..=255 => {
491                    result[3] |= nthbit;
492                }
493            }
494        }
495        self.get_id(result)
496    }
497
498    #[inline]
499    pub fn and(set1: &TSet, set2: &TSet) -> TSet {
500        *set1 & *set2
501    }
502
503    #[inline]
504    pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
505        *set1 & *set2 != Solver::empty()
506    }
507
508    #[inline]
509    pub fn or(set1: &TSet, set2: &TSet) -> TSet {
510        *set1 | *set2
511    }
512
513    #[inline]
514    pub fn not(set: &TSet) -> TSet {
515        !*set
516    }
517
518    #[inline]
519    pub fn is_full(set: &TSet) -> bool {
520        *set == Self::full()
521    }
522
523    #[inline]
524    pub fn is_empty(set: &TSet) -> bool {
525        *set == Solver::empty()
526    }
527
528    #[inline]
529    pub fn contains(large: &TSet, small: &TSet) -> bool {
530        Solver::empty() == (*small & !*large)
531    }
532
533    pub fn u8_to_set(byte: u8) -> TSet {
534        let mut result = TSet::splat(u64::MIN);
535        let nthbit = 1u64 << byte % 64;
536        match byte {
537            0..=63 => {
538                result[0] = nthbit;
539            }
540            64..=127 => {
541                result[1] = nthbit;
542            }
543            128..=191 => {
544                result[2] = nthbit;
545            }
546            192..=255 => {
547                result[3] = nthbit;
548            }
549        }
550        result
551    }
552
553    pub fn range_to_set(start: u8, end: u8) -> TSet {
554        let mut result = TSet::splat(u64::MIN);
555        for byte in start..=end {
556            let nthbit = 1u64 << byte % 64;
557            match byte {
558                0..=63 => {
559                    result[0] |= nthbit;
560                }
561                64..=127 => {
562                    result[1] |= nthbit;
563                }
564                128..=191 => {
565                    result[2] |= nthbit;
566                }
567                192..=255 => {
568                    result[3] |= nthbit;
569                }
570            }
571        }
572        result
573    }
574}