Skip to main content

provenant/license_detection/
token_multiset.rs

1use std::collections::HashMap;
2use std::ops::Deref;
3
4use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
5
6/// A multiset of token IDs stored as token -> occurrence count.
7#[derive(Clone, Debug, PartialEq, Eq, Default)]
8pub struct TokenMultiset(HashMap<TokenId, usize>);
9
10impl TokenMultiset {
11    /// Create a TokenMultiset from a sequence of token IDs.
12    pub fn from_token_ids(token_ids: &[TokenId]) -> Self {
13        let mut counts = HashMap::new();
14
15        for &tid in token_ids {
16            *counts.entry(tid).or_insert(0) += 1;
17        }
18
19        Self(counts)
20    }
21
22    /// Total number of token occurrences in the multiset.
23    pub fn total_count(&self) -> usize {
24        self.0.values().sum()
25    }
26
27    /// Get a subset containing only high-value (legalese) tokens.
28    pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
29        self.0
30            .iter()
31            .filter(|(tid, _)| dictionary.token_kind(**tid) == TokenKind::Legalese)
32            .map(|(&tid, &count)| (tid, count))
33            .collect()
34    }
35
36    /// Materialize the multiset intersection with another TokenMultiset.
37    pub fn intersection(&self, other: &Self) -> Self {
38        let (smaller, larger) = if self.0.len() < other.0.len() {
39            (&self.0, &other.0)
40        } else {
41            (&other.0, &self.0)
42        };
43
44        smaller
45            .iter()
46            .filter_map(|(&tid, &count)| {
47                larger
48                    .get(&tid)
49                    .map(|&other_count| (tid, count.min(other_count)))
50            })
51            .collect()
52    }
53}
54
55impl Deref for TokenMultiset {
56    type Target = HashMap<TokenId, usize>;
57
58    fn deref(&self) -> &Self::Target {
59        &self.0
60    }
61}
62
63impl FromIterator<(TokenId, usize)> for TokenMultiset {
64    fn from_iter<T: IntoIterator<Item = (TokenId, usize)>>(iter: T) -> Self {
65        Self(iter.into_iter().collect())
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::{TokenDictionary, tid};

    #[test]
    fn test_from_token_ids() {
        let token_ids = vec![tid(1), tid(2), tid(3), tid(2), tid(4), tid(1), tid(1)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);

        assert_eq!(multiset.get(&tid(1)), Some(&3));
        assert_eq!(multiset.get(&tid(2)), Some(&2));
        assert_eq!(multiset.get(&tid(3)), Some(&1));
        assert_eq!(multiset.get(&tid(4)), Some(&1));
    }

    #[test]
    fn test_from_token_ids_empty() {
        let multiset = TokenMultiset::from_token_ids(&[]);

        assert!(multiset.is_empty());
        assert_eq!(multiset.total_count(), 0);
        assert_eq!(multiset, TokenMultiset::default());
    }

    #[test]
    fn test_total_count() {
        let token_ids = vec![tid(1), tid(2), tid(3), tid(2), tid(1), tid(1)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);

        assert_eq!(multiset.total_count(), 6);
    }

    #[test]
    fn test_high_subset() {
        let token_ids = vec![tid(1), tid(1), tid(2), tid(5), tid(10)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);
        let dict = TokenDictionary::new_with_legalese(&[("one", 1), ("two", 2)]);

        let high_multiset = multiset.high_subset(&dict);

        assert_eq!(high_multiset.len(), 2);
        assert_eq!(high_multiset.get(&tid(1)), Some(&2));
        assert_eq!(high_multiset.get(&tid(2)), Some(&1));
        assert!(!high_multiset.contains_key(&tid(5)));
        assert!(!high_multiset.contains_key(&tid(10)));
    }

    #[test]
    fn test_intersection() {
        let left = TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(2), tid(3)]);
        let right = TokenMultiset::from_token_ids(&[tid(1), tid(2), tid(2), tid(4)]);

        let intersection = left.intersection(&right);

        assert_eq!(intersection.get(&tid(1)), Some(&1));
        assert_eq!(intersection.get(&tid(2)), Some(&1));
        assert!(!intersection.contains_key(&tid(3)));
        assert!(!intersection.contains_key(&tid(4)));
        assert_eq!(intersection.total_count(), 2);
    }

    #[test]
    fn test_intersection_is_symmetric() {
        // Operands with different numbers of distinct tokens exercise both
        // branches of the smaller/larger selection in `intersection`.
        let left = TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(2), tid(3)]);
        let right = TokenMultiset::from_token_ids(&[tid(1), tid(2), tid(2)]);

        let forward = left.intersection(&right);
        let backward = right.intersection(&left);

        assert_eq!(forward, backward);
        assert_eq!(forward.get(&tid(1)), Some(&1));
        assert_eq!(forward.get(&tid(2)), Some(&1));
        assert_eq!(forward.len(), 2);
    }

    #[test]
    fn test_intersection_with_empty() {
        let left = TokenMultiset::from_token_ids(&[tid(1), tid(2)]);
        let empty = TokenMultiset::default();

        assert!(left.intersection(&empty).is_empty());
        assert!(empty.intersection(&left).is_empty());
    }
}