Skip to main content

provenant/license_detection/
token_multiset.rs

1// SPDX-FileCopyrightText: Provenant contributors
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::ops::Deref;
6
7use rkyv::Archive;
8
9use crate::license_detection::index::dictionary::{TokenDictionary, TokenId, TokenKind};
10
/// A multiset of token IDs stored as token -> occurrence count.
///
/// This is a newtype around `HashMap<TokenId, usize>`: each key is a token ID
/// and the associated value is the number of times that token occurs.
/// The derives make the type cloneable, comparable, default-constructible
/// (an empty map), and usable with rkyv's archive/serialize/deserialize traits.
#[derive(Clone, Debug, PartialEq, Eq, Default, Archive, rkyv::Serialize, rkyv::Deserialize)]
pub struct TokenMultiset(HashMap<TokenId, usize>);
14
15impl TokenMultiset {
16    /// Create a TokenMultiset from a sequence of token IDs.
17    pub fn from_token_ids(token_ids: &[TokenId]) -> Self {
18        let mut counts = HashMap::new();
19
20        for &tid in token_ids {
21            *counts.entry(tid).or_insert(0) += 1;
22        }
23
24        Self(counts)
25    }
26
27    /// Total number of token occurrences in the multiset.
28    pub fn total_count(&self) -> usize {
29        self.0.values().sum()
30    }
31
32    /// Get a subset containing only high-value (legalese) tokens.
33    pub fn high_subset(&self, dictionary: &TokenDictionary) -> Self {
34        self.0
35            .iter()
36            .filter(|(tid, _)| dictionary.token_kind(**tid) == TokenKind::Legalese)
37            .map(|(&tid, &count)| (tid, count))
38            .collect()
39    }
40
41    /// Materialize the multiset intersection with another TokenMultiset.
42    pub fn intersection(&self, other: &Self) -> Self {
43        let (smaller, larger) = if self.0.len() < other.0.len() {
44            (&self.0, &other.0)
45        } else {
46            (&other.0, &self.0)
47        };
48
49        smaller
50            .iter()
51            .filter_map(|(&tid, &count)| {
52                larger
53                    .get(&tid)
54                    .map(|&other_count| (tid, count.min(other_count)))
55            })
56            .collect()
57    }
58}
59
60impl Deref for TokenMultiset {
61    type Target = HashMap<TokenId, usize>;
62
63    fn deref(&self) -> &Self::Target {
64        &self.0
65    }
66}
67
68impl FromIterator<(TokenId, usize)> for TokenMultiset {
69    fn from_iter<T: IntoIterator<Item = (TokenId, usize)>>(iter: T) -> Self {
70        Self(iter.into_iter().collect())
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license_detection::index::dictionary::{TokenDictionary, tid};

    /// Every distinct ID is tallied with the number of times it was supplied.
    #[test]
    fn test_from_token_ids() {
        let ids = [tid(1), tid(2), tid(3), tid(2), tid(4), tid(1), tid(1)];
        let counted = TokenMultiset::from_token_ids(&ids);

        let expectations = [(tid(1), 3), (tid(2), 2), (tid(3), 1), (tid(4), 1)];
        for (token, expected) in expectations {
            assert_eq!(counted.get(&token), Some(&expected));
        }
    }

    /// The total is the number of supplied IDs, duplicates included.
    #[test]
    fn test_total_count() {
        let ids = [tid(1), tid(2), tid(3), tid(2), tid(1), tid(1)];

        assert_eq!(TokenMultiset::from_token_ids(&ids).total_count(), 6);
    }

    /// Only tokens the dictionary marks as legalese survive, counts intact.
    #[test]
    fn test_high_subset() {
        let dict = TokenDictionary::new_with_legalese_pairs(&[("one", 1), ("two", 2)]);
        let everything =
            TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(2), tid(5), tid(10)]);

        let legalese_only = everything.high_subset(&dict);

        assert_eq!(legalese_only.len(), 2);
        for (token, expected) in [(tid(1), 2), (tid(2), 1)] {
            assert_eq!(legalese_only.get(&token), Some(&expected));
        }
        for dropped in [tid(5), tid(10)] {
            assert!(!legalese_only.contains_key(&dropped));
        }
    }

    /// Shared tokens keep the smaller count; tokens on one side only vanish.
    #[test]
    fn test_intersection() {
        let lhs = TokenMultiset::from_token_ids(&[tid(1), tid(1), tid(2), tid(3)]);
        let rhs = TokenMultiset::from_token_ids(&[tid(1), tid(2), tid(2), tid(4)]);

        let common = lhs.intersection(&rhs);

        for shared in [tid(1), tid(2)] {
            assert_eq!(common.get(&shared), Some(&1));
        }
        for one_sided in [tid(3), tid(4)] {
            assert!(!common.contains_key(&one_sided));
        }
        assert_eq!(common.total_count(), 2);
    }
}