Skip to main content

resharp_algebra/
solver.rs

/// A set of byte values stored as a 256-bit vector (four `u64` words).
///
/// Byte `b` is a member when bit `b % 64` of word `b / 64` is set.
/// An earlier version used `portable_simd`; the plain-array form is expected
/// (not verified here) to be auto-vectorized by the compiler as well.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSet(pub [u64; 4]);

impl TSet {
    /// Builds a set whose four words all equal `v`.
    #[inline]
    pub const fn splat(v: u64) -> Self {
        TSet([v; 4])
    }
}
12
13impl std::ops::Index<usize> for TSet {
14    type Output = u64;
15    #[inline]
16    fn index(&self, i: usize) -> &u64 {
17        &self.0[i]
18    }
19}
20
21impl std::ops::IndexMut<usize> for TSet {
22    #[inline]
23    fn index_mut(&mut self, i: usize) -> &mut u64 {
24        &mut self.0[i]
25    }
26}
27
28impl std::ops::BitAnd for TSet {
29    type Output = TSet;
30    #[inline]
31    fn bitand(self, rhs: TSet) -> TSet {
32        TSet([
33            self.0[0] & rhs.0[0],
34            self.0[1] & rhs.0[1],
35            self.0[2] & rhs.0[2],
36            self.0[3] & rhs.0[3],
37        ])
38    }
39}
40
41impl std::ops::BitAnd for &TSet {
42    type Output = TSet;
43    #[inline]
44    fn bitand(self, rhs: &TSet) -> TSet {
45        TSet([
46            self.0[0] & rhs.0[0],
47            self.0[1] & rhs.0[1],
48            self.0[2] & rhs.0[2],
49            self.0[3] & rhs.0[3],
50        ])
51    }
52}
53
54impl std::ops::BitOr for TSet {
55    type Output = TSet;
56    #[inline]
57    fn bitor(self, rhs: TSet) -> TSet {
58        TSet([
59            self.0[0] | rhs.0[0],
60            self.0[1] | rhs.0[1],
61            self.0[2] | rhs.0[2],
62            self.0[3] | rhs.0[3],
63        ])
64    }
65}
66
67impl std::ops::Not for TSet {
68    type Output = TSet;
69    #[inline]
70    fn not(self) -> TSet {
71        TSet([!self.0[0], !self.0[1], !self.0[2], !self.0[3]])
72    }
73}
74
75// &TSet ops used by Solver helper methods
76impl std::ops::BitAnd<TSet> for &TSet {
77    type Output = TSet;
78    #[inline]
79    fn bitand(self, rhs: TSet) -> TSet {
80        TSet([
81            self.0[0] & rhs.0[0],
82            self.0[1] & rhs.0[1],
83            self.0[2] & rhs.0[2],
84            self.0[3] & rhs.0[3],
85        ])
86    }
87}
88
89impl std::ops::BitOr<TSet> for &TSet {
90    type Output = TSet;
91    #[inline]
92    fn bitor(self, rhs: TSet) -> TSet {
93        TSet([
94            self.0[0] | rhs.0[0],
95            self.0[1] | rhs.0[1],
96            self.0[2] | rhs.0[2],
97            self.0[3] | rhs.0[3],
98        ])
99    }
100}
101
102const EMPTY: TSet = TSet::splat(u64::MIN);
103const FULL: TSet = TSet::splat(u64::MAX);
104
/// Handle identifying a set by a `u32` number.
#[derive(Clone, Copy, PartialEq, Hash, Eq, Debug, PartialOrd, Ord)]
pub struct TSetId(pub u32);

impl TSetId {
    /// Id reserved for the empty set.
    pub const EMPTY: TSetId = Self(0);
    /// Id reserved for the full set.
    pub const FULL: TSetId = Self(1);
}
111
112use std::collections::{BTreeMap, BTreeSet};
113
114pub struct Solver {
115    cache: BTreeMap<TSet, TSetId>,
116    pub array: Vec<TSet>,
117}
118
119impl Solver {
120    pub fn new() -> Solver {
121        let mut inst = Self {
122            cache: BTreeMap::new(),
123            array: Vec::new(),
124        };
125        let _ = inst.init(Solver::empty()); // 0
126        let _ = inst.init(Solver::full()); // 1
127        inst
128    }
129
130    fn init(&mut self, inst: TSet) -> TSetId {
131        let new_id = TSetId(self.cache.len() as u32);
132        self.cache.insert(inst, new_id);
133        self.array.push(inst);
134        new_id
135    }
136
137    pub fn get_set(&self, set_id: TSetId) -> TSet {
138        self.array[set_id.0 as usize]
139    }
140
141    pub fn get_set_ref(&self, set_id: TSetId) -> &TSet {
142        &self.array[set_id.0 as usize]
143    }
144
145    pub fn get_id(&mut self, inst: TSet) -> TSetId {
146        match self.cache.get(&inst) {
147            Some(&id) => id,
148            None => self.init(inst),
149        }
150    }
151
152    pub fn has_bit_set(&mut self, set_id: TSetId, idx: usize, bit: u64) -> bool {
153        self.array[set_id.0 as usize][idx] & bit != 0
154    }
155
156    pub fn pp_collect_ranges(tset: &TSet) -> BTreeSet<(u8, u8)> {
157        let mut ranges: BTreeSet<(u8, u8)> = BTreeSet::new();
158        let mut rangestart: Option<u8> = None;
159        let mut prevchar: Option<u8> = None;
160        for i in 0..4 {
161            for j in 0..64 {
162                let nthbit = 1u64 << j;
163                if tset[i] & nthbit != 0 {
164                    let cc = (i * 64 + j) as u8;
165                    if rangestart.is_none() {
166                        rangestart = Some(cc);
167                        prevchar = Some(cc);
168                        continue;
169                    }
170
171                    if let Some(currstart) = rangestart {
172                        if let Some(currprev) = prevchar {
173                            if currprev as u8 == cc as u8 - 1 {
174                                prevchar = Some(cc);
175                                continue;
176                            } else {
177                                if currstart == currprev {
178                                    ranges.insert((currstart, currstart));
179                                } else {
180                                    ranges.insert((currstart, currprev));
181                                }
182                                rangestart = Some(cc);
183                                prevchar = Some(cc);
184                            }
185                        } else {
186                        }
187                    } else {
188                    }
189                }
190            }
191        }
192        if let Some(start) = rangestart {
193            if let Some(prevchar) = prevchar {
194                if prevchar as u8 == start as u8 {
195                    ranges.insert((start, start));
196                } else {
197                    ranges.insert((start, prevchar));
198                }
199            } else {
200                // single char
201                ranges.insert((start, start));
202            }
203        }
204        ranges
205    }
206
207    fn pp_byte(b: u8) -> String {
208        if cfg!(feature = "graphviz") {
209            match b as char {
210                // graphviz doesnt like \n so we use \ṅ
211                '\n' => return r"\ṅ".to_owned(),
212                '"' => return r"\u{201c}".to_owned(),
213                '\r' => return r"\r".to_owned(),
214                '\t' => return r"\t".to_owned(),
215                _ => {}
216            }
217        }
218        match b as char {
219            '\n' => r"\n".to_owned(),
220            '\r' => r"\r".to_owned(),
221            '\t' => r"\t".to_owned(),
222            ' ' => r" ".to_owned(),
223            '_' | '.' | '+' | '-' | '\\' | '&' | '|' | '~' | '{' | '}' | '[' | ']' | '(' | ')'
224            | '*' | '?' | '^' | '$' => r"\".to_owned() + &(b as char).to_string(),
225            c if c.is_ascii_punctuation() || c.is_ascii_alphanumeric() => c.to_string(),
226            _ => format!("\\x{:02X}", b),
227        }
228    }
229
230    fn pp_content(ranges: &BTreeSet<(u8, u8)>) -> String {
231        let display_range = |c, c2| {
232            if c == c2 {
233                Self::pp_byte(c)
234            } else if c.abs_diff(c2) == 1 {
235                format!("{}{}", Self::pp_byte(c), Self::pp_byte(c2))
236            } else {
237                format!("{}-{}", Self::pp_byte(c), Self::pp_byte(c2))
238            }
239        };
240
241        if ranges.len() == 0 {
242            return "\u{22a5}".to_owned();
243        }
244        if ranges.len() == 1 {
245            let (s, e) = ranges.iter().next().unwrap();
246            if s == e {
247                return Self::pp_byte(*s);
248            } else {
249                return format!(
250                    "{}",
251                    ranges
252                        .iter()
253                        .map(|(s, e)| display_range(*s, *e))
254                        .collect::<Vec<_>>()
255                        .join("")
256                );
257            }
258        }
259        if ranges.len() > 20 {
260            return "\u{03c6}".to_owned();
261        }
262        return format!(
263            "{}",
264            ranges
265                .iter()
266                .map(|(s, e)| display_range(*s, *e))
267                .collect::<Vec<_>>()
268                .join("")
269        );
270    }
271
272    /// pretty print a character from the set
273    pub fn pp_first(&self, tset: &TSet) -> char {
274        let tryn1 = |i: usize| {
275            for j in 0..32 {
276                let nthbit = 1u64 << j;
277                if tset[i] & nthbit != 0 {
278                    let cc = (i * 64 + j) as u8 as char;
279                    return Some(cc);
280                }
281            }
282            None
283        };
284        let tryn2 = |i: usize| {
285            for j in 33..64 {
286                let nthbit = 1u64 << j;
287                if tset[i] & nthbit != 0 {
288                    let cc = (i * 64 + j) as u8 as char;
289                    return Some(cc);
290                }
291            }
292            None
293        };
294        // readable ones first
295        tryn2(0)
296            .or_else(|| tryn2(1))
297            .or_else(|| tryn1(1))
298            .or_else(|| tryn1(2))
299            .or_else(|| tryn2(2))
300            .or_else(|| tryn1(3))
301            .or_else(|| tryn2(3))
302            .or_else(|| tryn1(0))
303            .unwrap_or('\u{22a5}')
304    }
305
306    pub fn byte_ranges(&self, tset: TSetId) -> Vec<(u8, u8)> {
307        let tset = self.get_set(tset);
308        Self::pp_collect_ranges(&tset).into_iter().collect()
309    }
310
311    #[allow(unused)]
312    fn first_byte(tset: &TSet) -> u8 {
313        for i in 0..4 {
314            for j in 0..64 {
315                let nthbit = 1u64 << j;
316                if tset[i] & nthbit != 0 {
317                    let cc = (i * 64 + j) as u8;
318                    return cc;
319                }
320            }
321        }
322        return 0;
323    }
324
325    pub fn pp(&self, tset: TSetId) -> String {
326        if tset == TSetId::FULL {
327            return "_".to_owned();
328        }
329        if tset == TSetId::EMPTY {
330            return "\u{22a5}".to_owned();
331        }
332        let tset = self.get_set(tset);
333        let ranges: BTreeSet<(u8, u8)> = Self::pp_collect_ranges(&tset);
334        let rstart = ranges.first().unwrap().0;
335        let rend = ranges.last().unwrap().1;
336        if ranges.len() >= 2 && rstart == 0 && rend == 255 {
337            let not_id = Self::not(&tset);
338            let not_ranges = Self::pp_collect_ranges(&not_id);
339            if not_ranges.len() == 1 && not_ranges.iter().next() == Some(&(10, 10)) {
340                return r".".to_owned();
341            }
342            let content = Self::pp_content(&not_ranges);
343            return format!("[^{}]", content);
344        }
345        if ranges.len() == 0 {
346            return "\u{22a5}".to_owned();
347        }
348        if ranges.len() == 1 {
349            let (s, e) = ranges.iter().next().unwrap();
350            if s == e {
351                return Self::pp_byte(*s);
352            } else {
353                let content = Self::pp_content(&ranges);
354                return format!("[{}]", content);
355            }
356        }
357        let content = Self::pp_content(&ranges);
358        return format!("[{}]", content);
359    }
360}
361
362impl Solver {
363    #[inline]
364    pub fn full() -> TSet {
365        FULL
366    }
367
368    #[inline]
369    pub fn empty() -> TSet {
370        EMPTY
371    }
372
373    #[inline]
374    pub fn or_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
375        self.get_id(self.get_set(set1) | self.get_set(set2))
376    }
377
378    #[inline]
379    pub fn and_id(&mut self, set1: TSetId, set2: TSetId) -> TSetId {
380        self.get_id(self.get_set(set1) & self.get_set(set2))
381    }
382
383    #[inline]
384    pub fn not_id(&mut self, set_id: TSetId) -> TSetId {
385        self.get_id(!self.get_set(set_id))
386    }
387
388    #[inline]
389    pub fn is_sat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
390        self.and_id(set1, set2) != TSetId::EMPTY
391    }
392    #[inline]
393    pub fn unsat_id(&mut self, set1: TSetId, set2: TSetId) -> bool {
394        self.and_id(set1, set2) == TSetId::EMPTY
395    }
396
397    /// if the set contains exactly one byte, return it
398    pub fn single_byte(&self, set_id: TSetId) -> Option<u8> {
399        let tset = self.get_set(set_id);
400        let total: u32 = (0..4).map(|i| tset[i].count_ones()).sum();
401        if total != 1 {
402            return None;
403        }
404        for i in 0..4 {
405            if tset[i] != 0 {
406                return Some((i * 64 + tset[i].trailing_zeros() as usize) as u8);
407            }
408        }
409        None
410    }
411
412    #[inline]
413    pub fn is_empty_id(&self, set1: TSetId) -> bool {
414        set1 == TSetId::EMPTY
415    }
416
417    #[inline]
418    pub fn is_full_id(&self, set1: TSetId) -> bool {
419        set1 == TSetId::FULL
420    }
421
422    #[inline]
423    pub fn contains_id(&mut self, large_id: TSetId, small_id: TSetId) -> bool {
424        let not_large = self.not_id(large_id);
425        self.and_id(small_id, not_large) == TSetId::EMPTY
426    }
427
428    pub fn u8_to_set_id(&mut self, byte: u8) -> TSetId {
429        let mut result = TSet::splat(u64::MIN);
430        let nthbit = 1u64 << byte % 64;
431        match byte {
432            0..=63 => {
433                result[0] = nthbit;
434            }
435            64..=127 => {
436                result[1] = nthbit;
437            }
438            128..=191 => {
439                result[2] = nthbit;
440            }
441            192..=255 => {
442                result[3] = nthbit;
443            }
444        }
445        self.get_id(result)
446    }
447
448    pub fn range_to_set_id(&mut self, start: u8, end: u8) -> TSetId {
449        let mut result = TSet::splat(u64::MIN);
450        for byte in start..=end {
451            let nthbit = 1u64 << byte % 64;
452            match byte {
453                0..=63 => {
454                    result[0] |= nthbit;
455                }
456                64..=127 => {
457                    result[1] |= nthbit;
458                }
459                128..=191 => {
460                    result[2] |= nthbit;
461                }
462                192..=255 => {
463                    result[3] |= nthbit;
464                }
465            }
466        }
467        self.get_id(result)
468    }
469
470    #[inline]
471    pub fn and(set1: &TSet, set2: &TSet) -> TSet {
472        *set1 & *set2
473    }
474
475    #[inline]
476    pub fn is_sat(set1: &TSet, set2: &TSet) -> bool {
477        *set1 & *set2 != Solver::empty()
478    }
479
480    #[inline]
481    pub fn or(set1: &TSet, set2: &TSet) -> TSet {
482        *set1 | *set2
483    }
484
485    #[inline]
486    pub fn not(set: &TSet) -> TSet {
487        !*set
488    }
489
490    #[inline]
491    pub fn is_full(set: &TSet) -> bool {
492        *set == Self::full()
493    }
494
495    #[inline]
496    pub fn is_empty(set: &TSet) -> bool {
497        *set == Solver::empty()
498    }
499
500    #[inline]
501    pub fn contains(large: &TSet, small: &TSet) -> bool {
502        Solver::empty() == (*small & !*large)
503    }
504
505    pub fn u8_to_set(byte: u8) -> TSet {
506        let mut result = TSet::splat(u64::MIN);
507        let nthbit = 1u64 << byte % 64;
508        match byte {
509            0..=63 => {
510                result[0] = nthbit;
511            }
512            64..=127 => {
513                result[1] = nthbit;
514            }
515            128..=191 => {
516                result[2] = nthbit;
517            }
518            192..=255 => {
519                result[3] = nthbit;
520            }
521        }
522        result
523    }
524
525    pub fn range_to_set(start: u8, end: u8) -> TSet {
526        let mut result = TSet::splat(u64::MIN);
527        for byte in start..=end {
528            let nthbit = 1u64 << byte % 64;
529            match byte {
530                0..=63 => {
531                    result[0] |= nthbit;
532                }
533                64..=127 => {
534                    result[1] |= nthbit;
535                }
536                128..=191 => {
537                    result[2] |= nthbit;
538                }
539                192..=255 => {
540                    result[3] |= nthbit;
541                }
542            }
543        }
544        result
545    }
546}