provenant/license_detection/
token_multiset.rs1use std::collections::HashMap;
5use std::ops::Deref;
6
7use rkyv::Archive;
8
9use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
10
/// A multiset (bag) of token ids: each distinct [`TokenId`] is mapped to a count.
///
/// When built with [`TokenMultiset::from_token_ids`], the count is the number of times the
/// token occurs in the input. Values collected through the `FromIterator` impl are stored as
/// given.
///
/// The type dereferences to the underlying `HashMap<TokenId, usize>`, so read-only map queries
/// such as `get`, `len`, `is_empty` and `contains_key` can be called on it directly.
#[derive(Clone, Debug, PartialEq, Eq, Default, Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TokenMultiset(HashMap<TokenId, usize>);
14
15impl TokenMultiset {
16 pub fn from_token_ids(token_ids: &[TokenId]) -> Self {
18 let mut counts = HashMap::new();
19
20 for &tid in token_ids {
21 *counts.entry(tid).or_insert(0) += 1;
22 }
23
24 Self(counts)
25 }
26
27 pub fn total_count(&self) -> usize {
29 self.0.values().sum()
30 }
31
32 pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
34 self.0
35 .iter()
36 .filter(|(tid, _)| dictionary.token_kind(**tid) == TokenKind::Legalese)
37 .map(|(&tid, &count)| (tid, count))
38 .collect()
39 }
40
41 pub fn intersection(&self, other: &Self) -> Self {
43 let (smaller, larger) = if self.0.len() < other.0.len() {
44 (&self.0, &other.0)
45 } else {
46 (&other.0, &self.0)
47 };
48
49 smaller
50 .iter()
51 .filter_map(|(&tid, &count)| {
52 larger
53 .get(&tid)
54 .map(|&other_count| (tid, count.min(other_count)))
55 })
56 .collect()
57 }
58}
59
60impl Deref for TokenMultiset {
61 type Target = HashMap<TokenId, usize>;
62
63 fn deref(&self) -> &Self::Target {
64 &self.0
65 }
66}
67
68impl FromIterator<(TokenId, usize)> for TokenMultiset {
69 fn from_iter<T: IntoIterator<Item = (TokenId, usize)>>(iter: T) -> Self {
70 Self(iter.into_iter().collect())
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::{TokenDictionary, tid};

    #[test]
    fn test_from_token_ids() {
        let token_ids = vec![tid(1), tid(2), tid(3), tid(2), tid(4), tid(1), tid(1)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);

        assert_eq!(multiset.get(&tid(1)), Some(&3));
        assert_eq!(multiset.get(&tid(2)), Some(&2));
        assert_eq!(multiset.get(&tid(3)), Some(&1));
        assert_eq!(multiset.get(&tid(4)), Some(&1));
    }

    #[test]
    fn test_from_token_ids_empty() {
        let multiset = TokenMultiset::from_token_ids(&[]);

        assert!(multiset.is_empty());
        assert_eq!(multiset.total_count(), 0);
        assert_eq!(multiset, TokenMultiset::default());
    }

    #[test]
    fn test_total_count() {
        let token_ids = vec![tid(1), tid(2), tid(3), tid(2), tid(1), tid(1)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);

        assert_eq!(multiset.total_count(), 6);
    }

    #[test]
    fn test_high_subset() {
        let token_ids = vec![tid(1), tid(1), tid(2), tid(5), tid(10)];
        let multiset = TokenMultiset::from_token_ids(&token_ids);
        let dict = TokenDictionary::new_with_legalese_pairs(&[("one", 1), ("two", 2)]);

        let high_multiset = multiset.high_subset(&dict);

        assert_eq!(high_multiset.len(), 2);
        assert_eq!(high_multiset.get(&tid(1)), Some(&2));
        assert_eq!(high_multiset.get(&tid(2)), Some(&1));
        assert!(!high_multiset.contains_key(&tid(5)));
        assert!(!high_multiset.contains_key(&tid(10)));
    }

    #[test]
    fn test_intersection() {
        let left = TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(2), tid(3)]);
        let right = TokenMultiset::from_token_ids(&[tid(1), tid(2), tid(2), tid(4)]);

        let intersection = left.intersection(&right);

        assert_eq!(intersection.get(&tid(1)), Some(&1));
        assert_eq!(intersection.get(&tid(2)), Some(&1));
        assert!(!intersection.contains_key(&tid(3)));
        assert!(!intersection.contains_key(&tid(4)));
        assert_eq!(intersection.total_count(), 2);
    }

    #[test]
    fn test_intersection_is_symmetric_for_unequal_sizes() {
        // `test_intersection` uses maps with the same number of distinct tokens, so only one
        // branch of the smaller/larger selection runs there. Unequal sizes here exercise both
        // branches, one per call direction.
        let small = TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(1), tid(2)]);
        let large = TokenMultiset::from_token_ids(&[tid(1), tid(2), tid(2), tid(3), tid(4)]);

        let forward = small.intersection(&large);
        let backward = large.intersection(&small);

        assert_eq!(forward, backward);
        assert_eq!(forward.len(), 2);
        assert_eq!(forward.get(&tid(1)), Some(&1));
        assert_eq!(forward.get(&tid(2)), Some(&1));
    }

    #[test]
    fn test_intersection_with_empty() {
        let populated = TokenMultiset::from_token_ids(&[tid(1), tid(2)]);
        let empty = TokenMultiset::default();

        assert!(populated.intersection(&empty).is_empty());
        assert!(empty.intersection(&populated).is_empty());
    }

    #[test]
    fn test_from_iterator_keeps_given_counts() {
        let multiset: TokenMultiset = [(tid(7), 4), (tid(8), 1)].iter().copied().collect();

        assert_eq!(multiset.get(&tid(7)), Some(&4));
        assert_eq!(multiset.get(&tid(8)), Some(&1));
        assert_eq!(multiset.total_count(), 5);
    }
}