use std::{
cmp::{max, min, Ordering},
fmt::Display,
};
use crate::{
errors::Error,
smt_strings::{char_to_smt, MAX_CHAR},
};
/// An inclusive interval of character codes `[start, end]`.
///
/// Characters are identified by their `u32` code. The public constructors check
/// (in debug builds only) that `start <= end <= MAX_CHAR`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub struct CharSet {
/// First character code of the interval (inclusive).
start: u32,
/// Last character code of the interval (inclusive).
end: u32,
}
impl Display for CharSet {
    /// Formats the set as a single character, as the Greek capital sigma when the set
    /// is the full alphabet `[0, MAX_CHAR]`, or as `[first..last]` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.start, self.end) {
            // The singleton test comes first, exactly as in the if/else formulation.
            (lo, hi) if lo == hi => write!(f, "{}", char_to_smt(lo)),
            (0, hi) if hi == MAX_CHAR => write!(f, "\u{03a3}"),
            (lo, hi) => write!(f, "[{}..{}]", char_to_smt(lo), char_to_smt(hi)),
        }
    }
}
impl PartialOrd for CharSet {
    /// Partial order on intervals.
    ///
    /// Two equal sets compare as `Equal`. Otherwise a set is `Less` than another when it
    /// ends strictly before the other starts, and `Greater` when it starts strictly after
    /// the other ends. Distinct sets that overlap are not comparable (`None`).
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        if self == other {
            return Some(Ordering::Equal);
        }
        if self.end < other.start {
            return Some(Ordering::Less);
        }
        if self.start > other.end {
            return Some(Ordering::Greater);
        }
        None
    }
}
impl CharSet {
pub fn singleton(x: u32) -> CharSet {
debug_assert!(x <= MAX_CHAR);
CharSet { start: x, end: x }
}
pub fn range(x: u32, y: u32) -> CharSet {
debug_assert!(x <= y && y <= MAX_CHAR);
CharSet { start: x, end: y }
}
pub fn all_chars() -> CharSet {
CharSet {
start: 0,
end: MAX_CHAR,
}
}
pub fn contains(&self, x: u32) -> bool {
self.start <= x && x <= self.end
}
pub fn covers(&self, other: &CharSet) -> bool {
debug_assert!(other.start <= other.end);
self.start <= other.start && other.end <= self.end
}
pub fn is_before(&self, x: u32) -> bool {
self.end < x
}
pub fn is_after(&self, x: u32) -> bool {
x < self.start
}
pub fn size(&self) -> u32 {
self.end - self.start + 1
}
pub fn is_singleton(&self) -> bool {
self.start == self.end
}
pub fn is_alphabet(&self) -> bool {
self.start == 0 && self.end == MAX_CHAR
}
pub fn pick(&self) -> u32 {
self.start
}
pub fn inter(&self, other: &CharSet) -> Option<CharSet> {
let max_start = max(self.start, other.start);
let min_end = min(self.end, other.end);
if max_start <= min_end {
Some(Self::range(max_start, min_end))
} else {
None
}
}
pub fn inter_list(a: &[CharSet]) -> Option<CharSet> {
if a.is_empty() {
Some(Self::all_chars())
} else {
let mut result = a[0];
for s in &a[1..] {
match result.inter(s) {
None => return None,
Some(x) => result = x,
}
}
Some(result)
}
}
pub fn union(&self, other: &CharSet) -> Option<CharSet> {
let max_end = max(self.end, other.end);
if self.start == other.start || (self.start < other.start && self.end >= other.start - 1) {
Some(Self::range(self.start, max_end))
} else if other.start < self.start && other.end >= self.start - 1 {
Some(Self::range(other.start, max_end))
} else {
None
}
}
}
/// Outcome of comparing a character set with the intervals of a partition.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum CoverResult {
    /// The set is included in the interval with the given index.
    CoveredBy(usize),
    /// The set does not intersect any interval.
    DisjointFromAll,
    /// The set intersects an interval without being included in it.
    Overlaps,
}

impl Display for CoverResult {
    /// Writes the variant name, followed by the interval index for `CoveredBy`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            CoverResult::CoveredBy(index) => write!(f, "CoveredBy({})", index),
            CoverResult::DisjointFromAll => f.write_str("DisjointFromAll"),
            CoverResult::Overlaps => f.write_str("Overlaps"),
        }
    }
}
/// Identifier of a class in a character partition.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum ClassId {
    /// The class formed by the interval with the given index.
    Interval(usize),
    /// The class of all characters that belong to no interval.
    Complement,
}

impl Display for ClassId {
    /// Writes `Interval(i)` or `Complement`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ClassId::Interval(index) => write!(f, "Interval({})", index),
            ClassId::Complement => f.write_str("Complement"),
        }
    }
}
/// A list of character intervals together with a witness for the characters outside them.
///
/// `try_from_iter` sorts the intervals by start and rejects overlapping ones, and `push`
/// checks (in debug builds) that each new interval starts after the last one ends, so the
/// intervals are expected to be pairwise disjoint and in increasing order.
///
/// The default value has no interval and a complement witness equal to `0`.
#[derive(Debug, PartialEq, Eq, Default, Clone)]
pub struct CharPartition {
/// The intervals of the partition.
list: Vec<CharSet>,
/// A character that belongs to no interval, or a value larger than `MAX_CHAR`
/// when the complement is empty (see `empty_complement`).
comp_witness: u32,
}
impl CharPartition {
    /// Number of intervals in the partition (the complementary class is not counted).
    pub fn len(&self) -> usize {
        self.list.len()
    }

    /// Returns `true` when the partition has no interval.
    pub fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Creates a partition with no interval. Its complement witness is `0`.
    pub fn new() -> Self {
        CharPartition::default()
    }

    /// Creates the partition whose only interval is `c`.
    pub fn from_set(c: &CharSet) -> Self {
        // Pick a character outside `c`: 0 if `c` does not start at 0, otherwise the
        // character that follows `c` (larger than MAX_CHAR when `c` ends at MAX_CHAR).
        let witness = if c.start > 0 { 0 } else { c.end + 1 };
        CharPartition {
            list: vec![*c],
            comp_witness: witness,
        }
    }

    /// Builds a partition from the sets produced by `iter`, in any order.
    ///
    /// The sets are sorted by their first character.
    ///
    /// # Errors
    ///
    /// Returns `Error::NonDisjointCharSets` when two of the sets overlap.
    pub fn try_from_iter(iter: impl Iterator<Item = CharSet>) -> Result<Self, Error> {
        let mut v: Vec<CharSet> = iter.collect();
        // The witness starts at 0 and is moved past every interval that would contain it.
        let mut comp_witness = 0;
        if !v.is_empty() {
            v.sort_by_key(|c| c.start);
            let mut prev = &v[0];
            if prev.start <= comp_witness {
                comp_witness = prev.end + 1;
            }
            for c in &v[1..] {
                // Sets are sorted by start, so an overlap shows up between neighbors.
                if c.start <= prev.end {
                    return Err(Error::NonDisjointCharSets);
                }
                if c.start <= comp_witness {
                    comp_witness = c.end + 1;
                }
                prev = c;
            }
        }
        Ok(CharPartition {
            list: v,
            comp_witness,
        })
    }

    /// Builds a partition from a slice of sets. See [`CharPartition::try_from_iter`].
    ///
    /// # Errors
    ///
    /// Returns `Error::NonDisjointCharSets` when two of the sets overlap.
    pub fn try_from_list(a: &[CharSet]) -> Result<Self, Error> {
        Self::try_from_iter(a.iter().copied())
    }

    /// Appends the interval `[start, end]` at the end of the partition.
    ///
    /// In debug builds, panics unless `start <= end <= MAX_CHAR` and `start` is larger
    /// than the end of the last interval already in the partition.
    pub fn push(&mut self, start: u32, end: u32) {
        debug_assert!(start <= end && end <= MAX_CHAR);
        debug_assert!(self.list.is_empty() || start > self.list.last().unwrap().end);
        self.list.push(CharSet { start, end });
        // Keep the witness outside of all intervals.
        if start <= self.comp_witness {
            self.comp_witness = end + 1;
        }
    }

    /// Returns the bounds `(start, end)` of the `i`-th interval.
    ///
    /// When `i` is out of range, both components are `MAX_CHAR + 1`.
    pub fn get(&self, i: usize) -> (u32, u32) {
        if i < self.list.len() {
            let r = &self.list[i];
            (r.start, r.end)
        } else {
            (MAX_CHAR + 1, MAX_CHAR + 1)
        }
    }

    /// Returns the `i`-th interval.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of range.
    pub fn interval(&self, i: usize) -> CharSet {
        self.list[i]
    }

    /// Start of the `i`-th interval, or `MAX_CHAR + 1` when `i` is out of range.
    pub fn start(&self, i: usize) -> u32 {
        if i < self.len() {
            self.list[i].start
        } else {
            MAX_CHAR + 1
        }
    }

    /// End of the interval at `index`, or `MAX_CHAR + 1` when `index` is out of range.
    pub fn end(&self, index: usize) -> u32 {
        if index < self.len() {
            self.list[index].end
        } else {
            MAX_CHAR + 1
        }
    }

    /// Returns a representative of the `i`-th interval (its first character).
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of range.
    pub fn pick(&self, i: usize) -> u32 {
        self.list[i].start
    }

    /// Returns `true` when the complement witness is larger than `MAX_CHAR`, which is how
    /// this type records that no character is left outside the intervals.
    pub fn empty_complement(&self) -> bool {
        self.comp_witness > MAX_CHAR
    }

    /// Returns the complement witness.
    ///
    /// The result is larger than `MAX_CHAR` when the complement is empty.
    pub fn pick_complement(&self) -> u32 {
        self.comp_witness
    }

    /// Returns `true` when `cid` denotes a class of this partition: an interval index in
    /// range, or the complement when it is not empty.
    pub fn valid_class_id(&self, cid: ClassId) -> bool {
        use ClassId::*;
        match cid {
            Interval(i) => i < self.len(),
            Complement => !self.empty_complement(),
        }
    }

    /// Number of classes: one per interval, plus one for the complement when it is not empty.
    pub fn num_classes(&self) -> usize {
        let n = self.len();
        if self.empty_complement() {
            n
        } else {
            n + 1
        }
    }

    /// Returns a representative character of the class `cid`.
    ///
    /// # Panics
    ///
    /// Panics if `cid` is `Interval(i)` with `i` out of range, or if `cid` is `Complement`
    /// and the complement is empty.
    pub fn pick_in_class(&self, cid: ClassId) -> u32 {
        use ClassId::*;
        match cid {
            Interval(i) => self.pick(i),
            Complement => {
                assert!(!self.empty_complement());
                self.pick_complement()
            }
        }
    }

    /// Iterates over the intervals of the partition, in order.
    pub fn ranges(&self) -> impl Iterator<Item = &CharSet> {
        self.list.iter()
    }

    /// Iterates over the class ids: every interval, then the complement if it is not empty.
    pub fn class_ids(&self) -> ClassIdIterator<'_> {
        ClassIdIterator {
            partition: self,
            counter: 0,
        }
    }

    /// Iterates over one representative per class: the first character of every interval,
    /// then the complement witness if the complement is not empty.
    pub fn picks(&self) -> PickIterator<'_> {
        PickIterator {
            partition: self,
            counter: 0,
        }
    }

    /// Returns the class that contains the character `x`: the interval that contains it,
    /// or `ClassId::Complement` when no interval does.
    pub fn class_of_char(&self, x: u32) -> ClassId {
        // Dichotomic search over the sorted, disjoint intervals.
        #[allow(clippy::many_single_char_names)]
        fn binary_search(p: &[CharSet], x: u32) -> ClassId {
            let mut i = 0;
            let mut j = p.len();
            while i < j {
                let h = i + (j - i) / 2;
                if p[h].contains(x) {
                    return ClassId::Interval(h);
                }
                if p[h].is_before(x) {
                    i = h + 1;
                } else {
                    j = h;
                }
            }
            ClassId::Complement
        }
        binary_search(&self.list, x)
    }

    /// Compares `set` with the intervals of the partition.
    ///
    /// Returns `CoveredBy(i)` when `set` is included in the `i`-th interval,
    /// `DisjointFromAll` when `set` intersects no interval, and `Overlaps` otherwise.
    ///
    /// In debug builds, panics unless `set.start <= set.end <= MAX_CHAR`.
    pub fn interval_cover(&self, set: &CharSet) -> CoverResult {
        // Returns the largest index `i` such that `p[i].start <= x`, or 0 when there is
        // no such index (in particular when `p` is empty).
        #[allow(clippy::many_single_char_names)]
        fn binary_search(p: &[CharSet], x: u32) -> usize {
            let mut i = 0;
            let mut j = p.len();
            while i + 1 < j {
                let h = i + (j - i) / 2;
                if p[h].start <= x {
                    i = h
                } else {
                    j = h
                }
            }
            i
        }
        let a = set.start;
        let b = set.end;
        debug_assert!(a <= b && b <= MAX_CHAR);
        let i = binary_search(&self.list, a);
        // For an empty partition, `get(0)` is (MAX_CHAR + 1, MAX_CHAR + 1) and the first
        // branch below reports `DisjointFromAll`.
        let (a_i, b_i) = self.get(i);
        if a < a_i {
            // `set` starts before the first interval.
            debug_assert!(i == 0);
            if b < a_i {
                CoverResult::DisjointFromAll
            } else {
                CoverResult::Overlaps
            }
        } else if a <= b_i {
            // `set` starts inside interval `i`.
            if b <= b_i {
                CoverResult::CoveredBy(i)
            } else {
                CoverResult::Overlaps
            }
        } else {
            // `set` starts in the gap after interval `i`. It is disjoint from every
            // interval exactly when it ends before the next interval *starts*;
            // `start(i + 1)` is MAX_CHAR + 1 when there is no next interval.
            let next_ai = self.start(i + 1);
            if b < next_ai {
                CoverResult::DisjointFromAll
            } else {
                CoverResult::Overlaps
            }
        }
    }

    /// Returns the class that contains all the characters of `s`.
    ///
    /// # Errors
    ///
    /// Returns `Error::AmbiguousCharSet` when `s` overlaps an interval without being
    /// included in it.
    pub fn class_of_set(&self, s: &CharSet) -> Result<ClassId, Error> {
        use ClassId::*;
        use CoverResult::*;
        match self.interval_cover(s) {
            CoveredBy(i) => Ok(Interval(i)),
            DisjointFromAll => Ok(Complement),
            Overlaps => Err(Error::AmbiguousCharSet),
        }
    }

    /// Returns `true` when `c` is either included in one interval or disjoint from all
    /// intervals of the partition.
    pub fn good_char_set(&self, c: &CharSet) -> bool {
        !matches!(self.interval_cover(c), CoverResult::Overlaps)
    }
}
impl Display for CharPartition {
    /// Writes the intervals between braces, each one followed by a space.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("{ ")?;
        self.ranges().try_for_each(|r| write!(f, "{} ", r))?;
        f.write_str("}")
    }
}
/// Iterator over the class ids of a partition, as returned by `CharPartition::class_ids`.
#[derive(Debug)]
pub struct ClassIdIterator<'a> {
/// Partition whose classes are enumerated.
partition: &'a CharPartition,
/// Number of calls to `next` made so far.
counter: usize,
}
impl<'a> Iterator for ClassIdIterator<'a> {
    type Item = ClassId;

    /// Yields `Interval(0)`, `Interval(1)`, ... for every interval of the partition,
    /// then `Complement` if the complement is not empty, then `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let i = self.counter;
        self.counter += 1;
        if i < self.partition.len() {
            Some(ClassId::Interval(i))
        } else if i == self.partition.len() && !self.partition.empty_complement() {
            Some(ClassId::Complement)
        } else {
            None
        }
    }

    /// Exact number of class ids that remain to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let mut total = self.partition.len();
        if !self.partition.empty_complement() {
            total += 1;
        }
        // `size_hint` must bound the *remaining* items, so discount what was already
        // consumed. The counter keeps growing after exhaustion, hence the saturation.
        let remaining = total.saturating_sub(self.counter);
        (remaining, Some(remaining))
    }
}
/// Iterator over one representative character per class of a partition, as returned by
/// `CharPartition::picks`.
#[derive(Debug)]
pub struct PickIterator<'a> {
/// Partition whose classes are enumerated.
partition: &'a CharPartition,
/// Number of calls to `next` made so far.
counter: usize,
}
impl<'a> Iterator for PickIterator<'a> {
type Item = u32;
fn next(&mut self) -> Option<u32> {
let i = self.counter;
self.counter += 1;
if i < self.partition.len() {
Some(self.partition.pick(i))
} else if i == self.partition.len() && !self.partition.empty_complement() {
Some(self.partition.pick_complement())
} else {
None
}
}
}
#[allow(clippy::many_single_char_names)]
pub fn merge_partitions(p1: &CharPartition, p2: &CharPartition) -> CharPartition {
fn next_interval(p: &CharPartition, i: usize) -> (usize, u32, u32) {
let (x, y) = p.get(i);
(i + 1, x, y)
}
let mut triple1 = next_interval(p1, 0);
let mut triple2 = next_interval(p2, 0);
let mut result = CharPartition::new();
while triple1.2 <= MAX_CHAR || triple2.2 <= MAX_CHAR {
let (i, a, b) = triple1;
let (j, c, d) = triple2;
if b < c {
result.push(a, b);
triple1 = next_interval(p1, i);
} else if d < a {
result.push(c, d);
triple2 = next_interval(p2, j);
} else if c < a {
result.push(c, a - 1);
triple2.1 = a; } else if a < c {
result.push(a, c - 1);
triple1.1 = c; } else if b < d {
result.push(a, b);
triple1 = next_interval(p1, i);
triple2.1 = b + 1;
} else if d < b {
result.push(c, d);
triple1.1 = d + 1;
triple2 = next_interval(p2, j);
} else {
result.push(a, b);
triple1 = next_interval(p1, i);
triple2 = next_interval(p2, j);
}
}
result
}
/// Merges all the partitions produced by `list`, one after the other, starting from the
/// empty partition.
pub fn merge_partition_list<'a>(list: impl Iterator<Item = &'a CharPartition>) -> CharPartition {
    list.fold(CharPartition::new(), |merged, p| merge_partitions(&merged, p))
}
// Unit tests for character sets, partitions, and partition merging.
#[cfg(test)]
mod test {
use super::*;
// Checks the invariants of a partition: every interval has start <= end, every interval
// starts strictly after the previous one ends, and the complement witness is not inside
// any interval.
fn good_partition(p: &CharPartition) -> bool {
// MAX_CHAR + 1 means "no previous interval yet".
let mut prev_end = MAX_CHAR + 1;
for s in p.ranges() {
if s.start > s.end {
return false;
}
if prev_end <= MAX_CHAR && s.start <= prev_end {
return false;
}
if s.start <= p.comp_witness && p.comp_witness <= s.end {
return false;
}
prev_end = s.end;
}
true
}
// Partition { [0-9] Z [f-q] }.
fn example1() -> CharPartition {
let mut p = CharPartition::new();
p.push('0' as u32, '9' as u32);
p.push('Z' as u32, 'Z' as u32);
p.push('f' as u32, 'q' as u32);
p
}
// Partition { 0 [A-G] [H-M] [W-Z] [a-n] [q-r] }; [A-G] and [H-M] are adjacent.
fn example2() -> CharPartition {
let mut p = CharPartition::new();
p.push('0' as u32, '0' as u32);
p.push('A' as u32, 'G' as u32);
p.push('H' as u32, 'M' as u32);
p.push('W' as u32, 'Z' as u32);
p.push('a' as u32, 'n' as u32);
p.push('q' as u32, 'r' as u32);
p
}
// Construction, complement witnesses, equality, and display of simple partitions.
#[test]
fn test_simple() {
let p1 = CharPartition::new();
let p2 = example1();
let p3 = example1();
let p4 = example2();
let p5 = CharPartition::from_set(&CharSet::all_chars());
assert!(good_partition(&p1));
assert!(good_partition(&p2));
assert!(good_partition(&p4));
assert!(good_partition(&p5));
// None of p1, p2, p4 contains character 0, so 0 is their complement witness.
assert!(!p1.empty_complement());
assert!(p1.pick_complement() == 0);
assert!(!p2.empty_complement());
assert!(p2.pick_complement() == 0);
assert!(!p4.empty_complement());
assert!(p4.pick_complement() == 0);
// p5 covers the whole alphabet.
assert!(p5.empty_complement());
assert_eq!(&p2, &p3);
assert_ne!(&p2, &p4);
assert_ne!(&p1, &p2);
assert_ne!(&p1, &p4);
assert_ne!(&p1, &p5);
println!("Empty partition: {}", &p1);
println!("Example1: {}", &p2);
println!("Example2: {}", &p4);
println!("All chars: {}", &p5);
}
// try_from_list sorts disjoint sets and rejects overlapping ones.
#[test]
fn test_from_list() {
let v = [
CharSet::range(120, 400),
CharSet::range(0, 10),
CharSet::range(1000, 2000),
];
match CharPartition::try_from_list(&v) {
Ok(p) => {
println!("From list succeeded: {}", &p);
assert_eq!(p.len(), 3);
assert_eq!(p.get(0), (0, 10));
assert_eq!(p.get(1), (120, 400));
assert_eq!(p.get(2), (1000, 2000));
assert!(good_partition(&p));
}
Err(e) => panic!("Partition::try_from_list failed with error {}", e),
}
// [100, 200] overlaps [120, 400], so the construction must fail.
let w = [
CharSet::range(120, 400),
CharSet::range(1000, 2000),
CharSet::range(0, 10),
CharSet::range(100, 200),
];
match CharPartition::try_from_list(&w) {
Ok(_) => panic!("Partition::try_from_list should have failed"),
Err(e) => println!(
"Partition::try_from_list failed with error {} as expected",
e
),
}
}
// class_of_char on the empty partition, both examples, and the full alphabet.
#[test]
fn test_search() {
use super::ClassId::*;
let p = CharPartition::new();
assert_eq!(p.class_of_char('a' as u32), Complement);
assert_eq!(p.class_of_char(0), Complement);
assert_eq!(p.class_of_char(MAX_CHAR), Complement);
let p2 = example1();
assert_eq!(p2.class_of_char(10), Complement);
assert_eq!(p2.class_of_char('0' as u32), Interval(0));
assert_eq!(p2.class_of_char('5' as u32), Interval(0));
assert_eq!(p2.class_of_char('9' as u32), Interval(0));
assert_eq!(p2.class_of_char('A' as u32), Complement);
assert_eq!(p2.class_of_char('Z' as u32), Interval(1));
assert_eq!(p2.class_of_char('e' as u32), Complement);
assert_eq!(p2.class_of_char('g' as u32), Interval(2));
assert_eq!(p2.class_of_char('z' as u32), Complement);
let p3 = example2();
assert_eq!(p3.class_of_char(10), Complement);
assert_eq!(p3.class_of_char('0' as u32), Interval(0));
assert_eq!(p3.class_of_char('5' as u32), Complement);
assert_eq!(p3.class_of_char('9' as u32), Complement);
assert_eq!(p3.class_of_char('A' as u32), Interval(1));
assert_eq!(p3.class_of_char('F' as u32), Interval(1));
assert_eq!(p3.class_of_char('G' as u32), Interval(1));
assert_eq!(p3.class_of_char('H' as u32), Interval(2));
assert_eq!(p3.class_of_char('L' as u32), Interval(2));
assert_eq!(p3.class_of_char('O' as u32), Complement);
assert_eq!(p3.class_of_char('W' as u32), Interval(3));
assert_eq!(p3.class_of_char('Z' as u32), Interval(3));
assert_eq!(p3.class_of_char('^' as u32), Complement);
assert_eq!(p3.class_of_char('e' as u32), Interval(4));
assert_eq!(p3.class_of_char('g' as u32), Interval(4));
assert_eq!(p3.class_of_char('p' as u32), Complement);
assert_eq!(p3.class_of_char('q' as u32), Interval(5));
assert_eq!(p3.class_of_char('r' as u32), Interval(5));
assert_eq!(p3.class_of_char('s' as u32), Complement);
assert_eq!(p3.class_of_char('z' as u32), Complement);
let p4 = CharPartition::from_set(&CharSet::all_chars());
assert_eq!(p4.class_of_char(0), Interval(0));
assert_eq!(p4.class_of_char(MAX_CHAR), Interval(0));
}
// Merging keeps the invariants; the empty partition is neutral and merging a
// partition with itself gives the same partition.
#[test]
fn test_merge() {
let v = vec![CharPartition::new(), example1(), example2()];
for p in &v {
for q in &v {
let m = merge_partitions(p, q);
println!("Merge({}, {}) = {}", p, q, &m);
assert!(good_partition(&m));
if p.is_empty() {
assert_eq!(&m, q);
}
if q.is_empty() {
assert_eq!(&m, p);
}
if p == q {
assert_eq!(&m, p);
}
}
}
}
// Intersection of intervals, including disjoint and single-character results.
#[test]
fn test_inter() {
let a = CharSet::singleton(0);
let b = CharSet::range(1, 20);
let c = CharSet::range(30, 60);
let d = CharSet::range(0, 30);
assert_eq!(a.inter(&a), Some(a));
assert_eq!(a.inter(&b), None);
assert_eq!(a.inter(&c), None);
assert_eq!(a.inter(&d), Some(a));
assert_eq!(b.inter(&a), None);
assert_eq!(b.inter(&b), Some(b));
assert_eq!(b.inter(&c), None);
assert_eq!(b.inter(&d), Some(b));
assert_eq!(c.inter(&d), Some(CharSet::singleton(30)));
}
// Union of intervals: adjacent or overlapping intervals merge, separated ones do not.
#[test]
fn test_union() {
let a = CharSet::singleton(0);
let b = CharSet::range(1, 20);
let c = CharSet::range(30, 60);
let d = CharSet::range(0, 30);
assert_eq!(a.union(&a), Some(a));
assert_eq!(a.union(&b), Some(CharSet::range(0, 20)));
assert_eq!(a.union(&c), None);
assert_eq!(a.union(&d), Some(d));
assert_eq!(b.union(&a), Some(CharSet::range(0, 20)));
assert_eq!(b.union(&b), Some(b));
assert_eq!(b.union(&c), None);
assert_eq!(b.union(&d), Some(d));
assert_eq!(c.union(&d), Some(CharSet::range(0, 60)));
}
// Cross-checks interval_cover against direct definitions of "covered", "disjoint",
// and "overlaps" for several sets and partitions.
// NOTE(review): none of the sets below starts in a gap between two intervals and ends
// strictly inside the following interval, so that path of interval_cover is not
// exercised here; consider adding such a set (for example ['a'..'g'] against example1).
#[test]
fn test_cover() {
let v = vec![CharPartition::new(), example1(), example2()];
let i = vec![
CharSet::singleton('a' as u32),
CharSet::range('0' as u32, '9' as u32),
CharSet::range('a' as u32, 'z' as u32),
CharSet::range('h' as u32, 'q' as u32),
CharSet::singleton('f' as u32),
CharSet::singleton('q' as u32),
CharSet::all_chars(),
CharSet::singleton(0),
CharSet::singleton(MAX_CHAR),
CharSet::range(0, 'z' as u32),
CharSet::range('z' as u32, MAX_CHAR),
];
// `test` is included in the i-th interval of `p`.
fn check_covered(p: &CharPartition, test: &CharSet, i: usize) -> bool {
i < p.len() && p.start(i) <= test.start && test.end <= p.end(i)
}
// `test` intersects no interval of `p`.
fn check_disjoint(p: &[CharSet], test: &CharSet) -> bool {
p.iter()
.all(|set| test.end < set.start || set.end < test.start)
}
// Some interval of `p` has a boundary strictly inside `test`.
fn check_overlap(p: &[CharSet], test: &CharSet) -> bool {
p.iter().any(|set| {
(test.start < set.start && set.start <= test.end)
|| (test.start <= set.end && set.end < test.end)
})
}
for p in &v {
println!("Partition: {}", p);
for set in &i {
let c = p.interval_cover(set);
println!("Cover for {} = {}", set, c);
match c {
CoverResult::CoveredBy(i) => {
assert!(check_covered(p, set, i))
}
CoverResult::DisjointFromAll => {
assert!(check_disjoint(&p.list, set))
}
CoverResult::Overlaps => {
assert!(check_overlap(&p.list, set))
}
}
}
println!();
}
}
}