Skip to main content

provenant/license_detection/
token_set.rs

1// SPDX-FileCopyrightText: Provenant contributors
2// SPDX-License-Identifier: Apache-2.0
3
4use smallvec::SmallVec;
5use std::cmp::Ordering;
6use std::ops::Deref;
7
8use rkyv::Archive;
9
10use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
11
/// A set of token IDs stored as a sorted `SmallVec`.
///
/// Invariant: elements are always sorted in ascending order and deduplicated.
/// The merge-style intersection methods rely on this ordering to walk two
/// sets in a single linear pass.
///
/// Construct via `TokenSet::from_token_ids()`, `TokenSet::from_u16_iter()`,
/// or `.collect()` from an iterator of `u16`; each of these sorts and
/// deduplicates its input, so callers may pass IDs in any order.
///
/// Up to 64 IDs are stored in the inline buffer; larger sets spill to the heap.
#[derive(Clone, Debug, PartialEq, Eq, Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TokenSet(SmallVec<[u16; 64]>);
19
20impl TokenSet {
21    /// Create a TokenSet from an iterator of u16 token IDs.
22    /// Sorts and deduplicates the input.
23    pub fn from_u16_iter<I: IntoIterator<Item = u16>>(iter: I) -> Self {
24        let mut inner: SmallVec<[u16; 64]> = iter.into_iter().collect();
25        inner.sort_unstable();
26        inner.dedup();
27        Self(inner)
28    }
29
30    /// Create a TokenSet from an iterator of TokenId values.
31    pub fn from_token_ids<I: IntoIterator<Item = TokenId>>(iter: I) -> Self {
32        Self::from_u16_iter(iter.into_iter().map(|tid| tid.raw()))
33    }
34
35    /// Create an empty TokenSet.
36    pub fn new() -> Self {
37        Self(SmallVec::new())
38    }
39
40    /// Number of tokens in the set.
41    pub fn len(&self) -> usize {
42        self.0.len()
43    }
44
45    /// Is the set empty?
46    pub fn is_empty(&self) -> bool {
47        self.0.is_empty()
48    }
49
50    /// Return true if the set contains the given token ID.
51    pub fn contains_token_id(&self, token_id: TokenId) -> bool {
52        self.0.contains(&token_id.raw())
53    }
54
55    /// Get the subset containing only high-value (legalese) tokens.
56    pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
57        Self::from_u16_iter(
58            self.iter()
59                .filter(|&tid| dictionary.token_kind(TokenId::new(tid)) == TokenKind::Legalese),
60        )
61    }
62
63    /// Count intersection with another TokenSet (no allocation).
64    pub fn intersection_count(&self, other: &TokenSet) -> usize {
65        let (mut i, mut j, mut count) = (0, 0, 0);
66        while i < self.0.len() && j < other.0.len() {
67            match self.0[i].cmp(&other.0[j]) {
68                Ordering::Less => i += 1,
69                Ordering::Greater => j += 1,
70                Ordering::Equal => {
71                    count += 1;
72                    i += 1;
73                    j += 1;
74                }
75            }
76        }
77        count
78    }
79
80    /// Materialize intersection with another TokenSet.
81    pub fn intersection(&self, other: &TokenSet) -> TokenSet {
82        let mut result = SmallVec::new();
83        let (mut i, mut j) = (0, 0);
84        while i < self.0.len() && j < other.0.len() {
85            match self.0[i].cmp(&other.0[j]) {
86                Ordering::Less => i += 1,
87                Ordering::Greater => j += 1,
88                Ordering::Equal => {
89                    result.push(self.0[i]);
90                    i += 1;
91                    j += 1;
92                }
93            }
94        }
95        Self(result)
96    }
97
98    /// Iterate over the sorted token IDs.
99    pub fn iter(&self) -> impl Iterator<Item = u16> + '_ {
100        self.0.iter().copied()
101    }
102}
103
104impl Deref for TokenSet {
105    type Target = [u16];
106    fn deref(&self) -> &Self::Target {
107        &self.0
108    }
109}
110
111impl Default for TokenSet {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117impl std::iter::FromIterator<u16> for TokenSet {
118    fn from_iter<T: IntoIterator<Item = u16>>(iter: T) -> Self {
119        Self::from_u16_iter(iter)
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::tid;

    /// Unsorted input with a duplicate must come out sorted and deduplicated,
    /// and every supplied ID must be reported as present.
    #[test]
    fn test_from_token_ids() {
        let ids = [tid(4), tid(2), tid(4), tid(1)];
        let set = TokenSet::from_token_ids(ids);

        let stored: Vec<u16> = set.iter().collect();
        assert_eq!(stored, vec![1, 2, 4]);

        assert!(set.contains_token_id(tid(1)));
        assert!(set.contains_token_id(tid(2)));
        assert!(set.contains_token_id(tid(4)));
    }

    /// The dictionary is built from legalese pairs for IDs 1 and 2 only, so
    /// the high subset of {1, 2, 5, 10} is expected to be {1, 2}.
    #[test]
    fn test_high_subset() {
        let dict = TokenDictionary::new_with_legalese_pairs(&[("one", 1), ("two", 2)]);
        let set = TokenSet::from_u16_iter([1, 2, 5, 10]);

        let legalese_only: Vec<u16> = set.high_subset(&dict).iter().collect();

        assert_eq!(legalese_only, vec![1, 2]);
    }
}