unicode_intervals/
query.rs

1use crate::{
2    categories,
3    categories::UnicodeCategorySet,
4    constants::{ALL_CATEGORIES, MAX_CODEPOINT},
5    intervals, Interval, UnicodeVersion,
6};
7use core::cmp::{max, min};
8use std::borrow::Cow;
9
10/// Non-generic query implementation to reduce the amount of generated code.
11#[must_use]
12pub fn query(
13    version: UnicodeVersion,
14    include_categories: Option<UnicodeCategorySet>,
15    exclude_categories: UnicodeCategorySet,
16    include_characters: &str,
17    exclude_characters: &str,
18    min_codepoint: u32,
19    max_codepoint: u32,
20) -> Vec<Interval> {
21    let categories = categories::merge(include_categories, exclude_categories);
22
23    let include_intervals = intervals::from_str(include_characters);
24    let exclude_intervals = intervals::from_str(exclude_characters);
25
26    let full = intervals_for_set(version, categories);
27    // Depending on the codepoint range, it could be less work to do
28    let mut intervals = match (min_codepoint, max_codepoint) {
29        // Full range, no need to filter
30        (0, MAX_CODEPOINT) => full.to_vec(),
31        // Only check for the left bound
32        (0, _) => {
33            let mut intervals = vec![];
34            for (left, right) in full.iter().copied() {
35                if left <= max_codepoint {
36                    intervals.push((max(left, min_codepoint), min(right, max_codepoint)));
37                }
38            }
39            intervals
40        }
41        // Only check for the right bound
42        (_, MAX_CODEPOINT) => {
43            let mut intervals = vec![];
44            for (left, right) in full.iter().copied() {
45                if right >= min_codepoint {
46                    intervals.push((max(left, min_codepoint), min(right, max_codepoint)));
47                }
48            }
49            intervals
50        }
51        // Check for both bounds
52        _ => {
53            let mut intervals = vec![];
54            for (left, right) in full.iter().copied() {
55                if left <= max_codepoint && right >= min_codepoint {
56                    intervals.push((max(left, min_codepoint), min(right, max_codepoint)));
57                }
58            }
59            intervals
60        }
61    };
62    // Include intervals
63    if intervals.is_empty() {
64        intervals = include_intervals;
65    } else if !include_intervals.is_empty() {
66        intervals.extend_from_slice(&include_intervals);
67        intervals::merge(&mut intervals);
68    } else {
69        intervals::merge(&mut intervals);
70    }
71    // Exclude intervals
72    intervals::subtract(intervals, exclude_intervals.as_slice())
73}
74
75/// Get intervals for the given `CategorySet`.
76/// The final intervals are merged and sorted.
77#[inline]
78#[must_use]
79pub fn intervals_for_set(
80    version: UnicodeVersion,
81    categories: UnicodeCategorySet,
82) -> Cow<'static, [Interval]> {
83    match categories.into_value() {
84        0 => Cow::Borrowed(&[]),
85        ALL_CATEGORIES => Cow::Borrowed(&[(0, MAX_CODEPOINT)]),
86        value => {
87            if categories.len() == 1 {
88                let category_idx = value.trailing_zeros() as usize;
89                Cow::Borrowed(version.table()[category_idx])
90            } else {
91                // Pre-allocate space for intervals from all categories
92                let size: usize = categories
93                    .iter()
94                    .map(|c| version.table()[c as usize].len())
95                    .sum();
96                let mut intervals = Vec::with_capacity(size);
97                for category in categories.iter() {
98                    intervals.extend_from_slice(version.table()[category as usize]);
99                }
100                Cow::Owned(intervals)
101            }
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::UnicodeCategory;
    use test_case::test_case;

    // Covers each lookup path of `intervals_for_set`: the empty set, the full
    // set, a single category, and a union of two categories.
    #[test_case(UnicodeCategorySet::new(), &[])]
    #[test_case(UnicodeCategorySet::all(), &[(0, MAX_CODEPOINT)])]
    #[test_case(UnicodeCategory::Zl.into(), &[(8232, 8232)])]
    #[test_case(UnicodeCategory::Zl | UnicodeCategory::Cs, &[(8232, 8232), (55296, 57343)])]
    fn test_intervals_for_set(set: UnicodeCategorySet, want: &[Interval]) {
        let got = intervals_for_set(UnicodeVersion::V15_0_0, set);
        // Compare as plain slices, whether the result is borrowed or owned.
        assert_eq!(&*got, want);
    }
}