Skip to main content

provenant/license_detection/
token_multiset.rs

1use std::collections::HashMap;
2use std::ops::Deref;
3
4use rkyv::Archive;
5
6use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
7
8/// A multiset of token IDs stored as token -> occurrence count.
9#[derive(Clone, Debug, PartialEq, Eq, Default, Archive, rkyv::Serialize, rkyv::Deserialize)]
10pub struct TokenMultiset(HashMap<TokenId, usize>);
11
12impl TokenMultiset {
13    /// Create a TokenMultiset from a sequence of token IDs.
14    pub fn from_token_ids(token_ids: &[TokenId]) -> Self {
15        let mut counts = HashMap::new();
16
17        for &tid in token_ids {
18            *counts.entry(tid).or_insert(0) += 1;
19        }
20
21        Self(counts)
22    }
23
24    /// Total number of token occurrences in the multiset.
25    pub fn total_count(&self) -> usize {
26        self.0.values().sum()
27    }
28
29    /// Get a subset containing only high-value (legalese) tokens.
30    pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
31        self.0
32            .iter()
33            .filter(|(tid, _)| dictionary.token_kind(**tid) == TokenKind::Legalese)
34            .map(|(&tid, &count)| (tid, count))
35            .collect()
36    }
37
38    /// Materialize the multiset intersection with another TokenMultiset.
39    pub fn intersection(&self, other: &Self) -> Self {
40        let (smaller, larger) = if self.0.len() < other.0.len() {
41            (&self.0, &other.0)
42        } else {
43            (&other.0, &self.0)
44        };
45
46        smaller
47            .iter()
48            .filter_map(|(&tid, &count)| {
49                larger
50                    .get(&tid)
51                    .map(|&other_count| (tid, count.min(other_count)))
52            })
53            .collect()
54    }
55}
56
57impl Deref for TokenMultiset {
58    type Target = HashMap<TokenId, usize>;
59
60    fn deref(&self) -> &Self::Target {
61        &self.0
62    }
63}
64
65impl FromIterator<(TokenId, usize)> for TokenMultiset {
66    fn from_iter<T: IntoIterator<Item = (TokenId, usize)>>(iter: T) -> Self {
67        Self(iter.into_iter().collect())
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::{TokenDictionary, tid};

    #[test]
    fn test_from_token_ids() {
        let token_ids = vec![tid(1), tid(2), tid(3), tid(2), tid(4), tid(1), tid(1)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);

        assert_eq!(multiset.get(&tid(1)), Some(&3));
        assert_eq!(multiset.get(&tid(2)), Some(&2));
        assert_eq!(multiset.get(&tid(3)), Some(&1));
        assert_eq!(multiset.get(&tid(4)), Some(&1));
    }

    #[test]
    fn test_from_token_ids_empty() {
        let multiset = TokenMultiset::from_token_ids(&[]);

        assert!(multiset.is_empty());
        assert_eq!(multiset.total_count(), 0);
    }

    #[test]
    fn test_total_count() {
        let token_ids = vec![tid(1), tid(2), tid(3), tid(2), tid(1), tid(1)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);

        assert_eq!(multiset.total_count(), 6);
    }

    #[test]
    fn test_high_subset() {
        let token_ids = vec![tid(1), tid(1), tid(2), tid(5), tid(10)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);
        let dict = TokenDictionary::new_with_legalese(&[("one", 1), ("two", 2)]);

        let high_multiset = multiset.high_subset(&dict);

        assert_eq!(high_multiset.len(), 2);
        assert_eq!(high_multiset.get(&tid(1)), Some(&2));
        assert_eq!(high_multiset.get(&tid(2)), Some(&1));
        assert!(!high_multiset.contains_key(&tid(5)));
        assert!(!high_multiset.contains_key(&tid(10)));
    }

    #[test]
    fn test_high_subset_without_legalese_is_empty() {
        let multiset = TokenMultiset::from_token_ids(&[tid(5), tid(10), tid(10)]);
        let dict = TokenDictionary::new_with_legalese(&[("one", 1), ("two", 2)]);

        assert!(multiset.high_subset(&dict).is_empty());
    }

    #[test]
    fn test_intersection() {
        let left = TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(2), tid(3)]);
        let right = TokenMultiset::from_token_ids(&[tid(1), tid(2), tid(2), tid(4)]);

        let intersection = left.intersection(&right);

        assert_eq!(intersection.get(&tid(1)), Some(&1));
        assert_eq!(intersection.get(&tid(2)), Some(&1));
        assert!(!intersection.contains_key(&tid(3)));
        assert!(!intersection.contains_key(&tid(4)));
        assert_eq!(intersection.total_count(), 2);
    }

    #[test]
    fn test_intersection_is_symmetric_for_unequal_sizes() {
        // `left` has fewer distinct tokens than `right`, so the two calls
        // below take opposite arms of the size comparison in `intersection`.
        let left = TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(1), tid(2)]);
        let right = TokenMultiset::from_token_ids(&[tid(1), tid(2), tid(2), tid(3), tid(4)]);

        let forward = left.intersection(&right);
        let backward = right.intersection(&left);

        assert_eq!(forward, backward);
        assert_eq!(forward.get(&tid(1)), Some(&1));
        assert_eq!(forward.get(&tid(2)), Some(&1));
        assert_eq!(forward.len(), 2);
    }

    #[test]
    fn test_intersection_disjoint_and_empty() {
        let left = TokenMultiset::from_token_ids(&[tid(1), tid(2)]);
        let right = TokenMultiset::from_token_ids(&[tid(3), tid(4)]);
        let empty = TokenMultiset::default();

        assert!(left.intersection(&right).is_empty());
        assert!(left.intersection(&empty).is_empty());
        assert!(empty.intersection(&left).is_empty());
    }
}