Skip to main content

provenant/license_detection/index/
dictionary.rs

1//! Token string to integer ID mapping.
2//!
3//! TokenDictionary maps token strings to unique integer IDs. This enables
4//! efficient token-based matching and indexing.
5
6use std::collections::HashMap;
7
8use rkyv::Archive;
9use serde::{Deserialize, Serialize};
10
11#[derive(
12    Debug,
13    Clone,
14    Copy,
15    PartialEq,
16    Eq,
17    Hash,
18    PartialOrd,
19    Ord,
20    Serialize,
21    Deserialize,
22    Archive,
23    rkyv::Serialize,
24    rkyv::Deserialize,
25)]
26#[rkyv(derive(Hash, Eq, PartialEq, PartialOrd, Ord))]
27pub struct TokenId(u16);
28
29impl TokenId {
30    pub const fn new(raw: u16) -> Self {
31        Self(raw)
32    }
33
34    pub const fn raw(self) -> u16 {
35        self.0
36    }
37
38    pub const fn as_usize(self) -> usize {
39        self.0 as usize
40    }
41
42    pub const fn to_le_bytes(self) -> [u8; 2] {
43        self.0.to_le_bytes()
44    }
45}
46
/// Test-only shorthand for building a [`TokenId`] from a raw value.
#[cfg(test)]
pub const fn tid(raw: u16) -> TokenId {
    TokenId(raw)
}
51
52impl From<u16> for TokenId {
53    fn from(value: u16) -> Self {
54        Self(value)
55    }
56}
57
58impl From<TokenId> for u16 {
59    fn from(value: TokenId) -> Self {
60        value.0
61    }
62}
63
64impl PartialEq<u16> for TokenId {
65    fn eq(&self, other: &u16) -> bool {
66        self.0 == *other
67    }
68}
69
70impl PartialOrd<u16> for TokenId {
71    fn partial_cmp(&self, other: &u16) -> Option<std::cmp::Ordering> {
72        self.0.partial_cmp(other)
73    }
74}
75
76impl PartialEq<TokenId> for u16 {
77    fn eq(&self, other: &TokenId) -> bool {
78        *self == other.0
79    }
80}
81
82impl PartialOrd<TokenId> for u16 {
83    fn partial_cmp(&self, other: &TokenId) -> Option<std::cmp::Ordering> {
84        self.partial_cmp(&other.0)
85    }
86}
87
88#[derive(
89    Debug,
90    Clone,
91    Copy,
92    PartialEq,
93    Eq,
94    Hash,
95    Serialize,
96    Deserialize,
97    Archive,
98    rkyv::Serialize,
99    rkyv::Deserialize,
100)]
101pub enum TokenKind {
102    Legalese,
103    Regular,
104}
105
106#[derive(
107    Debug,
108    Clone,
109    Copy,
110    PartialEq,
111    Eq,
112    Hash,
113    Serialize,
114    Deserialize,
115    Archive,
116    rkyv::Serialize,
117    rkyv::Deserialize,
118)]
119pub struct KnownToken {
120    pub id: TokenId,
121    pub kind: TokenKind,
122    pub is_digit_only: bool,
123    pub is_short_or_digit: bool,
124}
125
126#[derive(
127    Debug,
128    Clone,
129    Copy,
130    PartialEq,
131    Eq,
132    Hash,
133    Serialize,
134    Deserialize,
135    Archive,
136    rkyv::Serialize,
137    rkyv::Deserialize,
138)]
139pub enum QueryToken {
140    Known(KnownToken),
141    Unknown,
142    Stopword,
143}
144
145#[derive(
146    Debug, Clone, Copy, Serialize, Deserialize, Archive, rkyv::Serialize, rkyv::Deserialize,
147)]
148pub struct TokenMetadata {
149    pub kind: TokenKind,
150    pub is_digit_only: bool,
151    pub is_short_or_digit: bool,
152}
153
154/// Token dictionary mapping token strings to unique integer IDs.
155///
156/// Token IDs are assigned as follows:
157/// - IDs 0 to len_legalese-1: Reserved for legalese tokens (high-value words)
158/// - IDs len_legalese and above: Assigned to other tokens as encountered
159///
160/// The `len_legalese` delimiter allows the matching engine to distinguish
161/// between high-value (legalese) tokens and regular tokens.
162///
163/// Based on the Python ScanCode Toolkit implementation at:
164/// reference/scancode-toolkit/src/licensedcode/index.py
165#[derive(Debug, Clone, Archive, rkyv::Serialize, rkyv::Deserialize)]
166pub struct TokenDictionary {
167    /// Mapping from token string to token ID
168    tokens_to_ids: HashMap<String, TokenId>,
169
170    token_metadata: Vec<Option<TokenMetadata>>,
171
172    /// Number of legalese tokens (lower IDs = higher value)
173    len_legalese: usize,
174
175    /// Next token ID to assign (for non-legalese tokens)
176    next_id: TokenId,
177}
178
179impl TokenDictionary {
180    const DEFAULT_METADATA: TokenMetadata = TokenMetadata {
181        kind: TokenKind::Regular,
182        is_digit_only: false,
183        is_short_or_digit: false,
184    };
185
186    /// Create a new token dictionary initialized with legalese tokens.
187    ///
188    /// This follows the Python ScanCode TorchToolkit pattern where the dictionary
189    /// starts with pre-defined legalese words that get low IDs (high value).
190    ///
191    /// # Arguments
192    /// * `legalese_entries` - Slice of (word, token_id) pairs for legalese words
193    ///
194    /// # Returns
195    /// A new TokenDictionary instance with legalese tokens pre-populated
196    pub fn new_with_legalese(legalese_entries: &[(&str, u16)]) -> Self {
197        let mut tokens_to_ids = HashMap::new();
198        let max_existing_id = legalese_entries
199            .iter()
200            .map(|(_, token_id)| *token_id as usize)
201            .max()
202            .unwrap_or(0);
203        let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
204
205        for (word, token_id) in legalese_entries {
206            let id = TokenId::from(*token_id);
207            tokens_to_ids.insert(word.to_string(), id);
208            token_metadata[id.as_usize()] = Some(TokenMetadata {
209                kind: TokenKind::Legalese,
210                is_digit_only: word.chars().all(|c| c.is_ascii_digit()),
211                is_short_or_digit: word.len() == 1 || word.chars().all(|c| c.is_ascii_digit()),
212            });
213        }
214
215        let len_legalese = legalese_entries.len();
216        let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
217
218        Self {
219            tokens_to_ids,
220            token_metadata,
221            len_legalese,
222            next_id,
223        }
224    }
225
226    /// Create a new empty token dictionary (for testing).
227    ///
228    /// # Arguments
229    /// * `legalese_count` - Number of reserved legalese token IDs
230    ///
231    /// # Returns
232    /// A new TokenDictionary instance
233    pub fn new(legalese_count: usize) -> Self {
234        Self {
235            tokens_to_ids: HashMap::new(),
236            token_metadata: Vec::new(),
237            len_legalese: legalese_count,
238            next_id: TokenId::new(legalese_count as u16),
239        }
240    }
241
242    fn metadata_for(&self, id: TokenId) -> TokenMetadata {
243        self.token_metadata
244            .get(id.as_usize())
245            .and_then(|meta| *meta)
246            .unwrap_or(Self::DEFAULT_METADATA)
247    }
248
249    fn build_known_token(&self, id: TokenId) -> KnownToken {
250        let metadata = self.metadata_for(id);
251        KnownToken {
252            id,
253            kind: metadata.kind,
254            is_digit_only: metadata.is_digit_only,
255            is_short_or_digit: metadata.is_short_or_digit,
256        }
257    }
258
259    fn insert_metadata(&mut self, id: TokenId, kind: TokenKind, token: &str) {
260        let raw = id.as_usize();
261        if self.token_metadata.len() <= raw {
262            self.token_metadata.resize(raw + 1, None);
263        }
264        self.token_metadata[raw] = Some(TokenMetadata {
265            kind,
266            is_digit_only: token.chars().all(|c| c.is_ascii_digit()),
267            is_short_or_digit: token.len() == 1 || token.chars().all(|c| c.is_ascii_digit()),
268        });
269    }
270
271    pub fn intern(&mut self, token: &str) -> KnownToken {
272        if let Some(&id) = self.tokens_to_ids.get(token) {
273            return self.build_known_token(id);
274        }
275
276        let id = self.next_id;
277        self.next_id = TokenId::new(self.next_id.raw() + 1);
278        self.tokens_to_ids.insert(token.to_string(), id);
279        self.insert_metadata(id, TokenKind::Regular, token);
280        self.build_known_token(id)
281    }
282
283    pub fn lookup(&self, token: &str) -> Option<KnownToken> {
284        self.tokens_to_ids
285            .get(token)
286            .copied()
287            .map(|id| self.build_known_token(id))
288    }
289
290    pub fn classify_query_token(&self, token: &str) -> QueryToken {
291        self.lookup(token)
292            .map_or(QueryToken::Unknown, QueryToken::Known)
293    }
294
295    pub fn token_kind(&self, token_id: TokenId) -> TokenKind {
296        self.metadata_for(token_id).kind
297    }
298
299    pub fn is_digit_only_token(&self, token_id: TokenId) -> bool {
300        self.metadata_for(token_id).is_digit_only
301    }
302
303    #[cfg(test)]
304    pub fn get_or_assign(&mut self, token: &str) -> TokenId {
305        self.intern(token).id
306    }
307
308    /// Get the token ID for a token string if it exists.
309    ///
310    /// # Arguments
311    /// * `token` - The token string
312    ///
313    /// # Returns
314    /// Some(token_id) if the token exists, None otherwise
315    pub fn get_token_id(&self, token: &str) -> Option<TokenId> {
316        self.lookup(token).map(|token| token.id)
317    }
318
319    /// Get the token ID (alias for backward compatibility).
320    #[inline]
321    pub fn get(&self, token: &str) -> Option<TokenId> {
322        self.get_token_id(token)
323    }
324
325    /// Get the number of legalese tokens.
326    pub const fn legalese_count(&self) -> usize {
327        self.len_legalese
328    }
329
330    /// Get an iterator over all token string and ID pairs.
331    #[cfg(test)]
332    pub fn tokens_to_ids(&self) -> impl Iterator<Item = (&String, &TokenId)> {
333        self.tokens_to_ids.iter()
334    }
335
336    /// Get the number of tokens in the dictionary.
337    // This method will be used by the embedded index roundtrip tests in upcoming phases.
338    #[allow(dead_code)]
339    pub fn tokens_to_ids_len(&self) -> usize {
340        self.tokens_to_ids.len()
341    }
342}
343
344impl Default for TokenDictionary {
345    fn default() -> Self {
346        Self::new(0)
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dictionary_new() {
        let dict = TokenDictionary::new(10);
        assert_eq!(dict.legalese_count(), 10);
        assert_eq!(dict.tokens_to_ids.len(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_new_with_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[
            ("license", 0),
            ("copyright", 1),
            ("permission", 2),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);
        assert!(!dict.tokens_to_ids.is_empty());

        // Check that legalese tokens are registered
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("copyright"), Some(tid(1)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(2)));

        // Check that new tokens get IDs starting after legalese
        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, tid(3));
    }

    #[test]
    fn test_new_with_legalese_sorted() {
        let mut dict = TokenDictionary::new_with_legalese(&[
            ("copyright", 5),
            ("license", 0),
            ("permission", 10),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);

        // Check legalese IDs are correct regardless of input order
        assert_eq!(dict.get_token_id("copyright"), Some(tid(5)));
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(10)));

        // Next ID should advance past the highest explicit legalese token ID.
        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, tid(11));
    }

    #[test]
    fn test_get_or_assign_new_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("world");

        // Should assign IDs starting at legalese_count (5)
        assert_eq!(id1, tid(5));
        assert_eq!(id2, tid(6));
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_or_assign_existing_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("hello");

        // Should return the same ID for the same token
        assert_eq!(id1, id2);
        assert_eq!(dict.tokens_to_ids.len(), 1);
    }

    #[test]
    fn test_get_or_assign_with_preexisting_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[("license", 0)]);

        // Legalese tokens should already exist
        let id = dict.get_or_assign("license");
        assert_eq!(id, tid(0));
        assert_eq!(dict.tokens_to_ids.len(), 1);

        // New tokens should get IDs after legalese
        let new_id = dict.get_or_assign("new");
        assert_eq!(new_id, tid(1));
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_existing_token() {
        let mut dict = TokenDictionary::new(5);

        dict.get_or_assign("hello");
        assert_eq!(dict.get_token_id("hello"), Some(tid(5)));
    }

    #[test]
    fn test_get_nonexistent_token() {
        let dict = TokenDictionary::new(5);
        assert_eq!(dict.get_token_id("hello"), None);
    }

    #[test]
    fn test_legalese_range() {
        let mut dict = TokenDictionary::new(10);
        let legalese_count = dict.legalese_count() as u16;
        assert_eq!(legalese_count, 10);

        // IDs 0-9 are reserved for legalese, so the first interned token gets 10.
        let first = dict.get_or_assign("first");
        assert_eq!(first, tid(10));
        assert!(first >= legalese_count);
        assert_eq!(dict.token_kind(first), TokenKind::Regular);
    }

    #[test]
    fn test_legalese_range_with_actual_legalese() {
        let mut dict =
            TokenDictionary::new_with_legalese(&[("license", 0), ("copyright", 1)]);

        // Legalese tokens should have IDs in the legalese range
        let license_id = dict.get_token_id("license").expect("license is legalese");
        let copyright_id = dict.get_token_id("copyright").expect("copyright is legalese");
        assert!(license_id < dict.legalese_count() as u16);
        assert!(copyright_id < dict.legalese_count() as u16);
        assert_eq!(dict.token_kind(license_id), TokenKind::Legalese);
        assert_eq!(dict.token_kind(copyright_id), TokenKind::Legalese);

        // Regular tokens should not be legalese
        let regular_id = dict.get_or_assign("regular");
        assert!(regular_id >= dict.legalese_count() as u16);
        assert_eq!(dict.token_kind(regular_id), TokenKind::Regular);
    }

    #[test]
    fn test_token_metadata_flags() {
        let mut dict = TokenDictionary::new_with_legalese(&[("license", 0)]);

        let license = dict.lookup("license").expect("legalese token should be present");
        assert_eq!(license.kind, TokenKind::Legalese);
        assert!(!license.is_digit_only);
        assert!(!license.is_short_or_digit);

        let digits = dict.intern("2024");
        assert_eq!(digits.kind, TokenKind::Regular);
        assert!(digits.is_digit_only);
        assert!(digits.is_short_or_digit);
        assert!(dict.is_digit_only_token(digits.id));

        let single = dict.intern("a");
        assert!(!single.is_digit_only);
        assert!(single.is_short_or_digit);

        assert_eq!(dict.classify_query_token("absent"), QueryToken::Unknown);
        assert_eq!(dict.classify_query_token("a"), QueryToken::Known(single));
    }

    #[test]
    fn test_token_dictionary_default() {
        let dict = TokenDictionary::default();
        assert_eq!(dict.legalese_count(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_get_alias() {
        let mut dict = TokenDictionary::new(5);
        dict.get_or_assign("hello");

        // get() should be an alias for get_token_id()
        assert_eq!(dict.get("hello"), dict.get_token_id("hello"));
    }

    #[test]
    fn test_with_actual_legalese_module() {
        use crate::license_detection::rules::legalese;

        let legalese_words = legalese::get_legalese_words();
        assert!(!legalese_words.is_empty(), "Should have legalese words");

        let mut dict = TokenDictionary::new_with_legalese(&legalese_words);

        // Verify dictionary has the right structure
        assert_eq!(dict.legalese_count(), legalese_words.len());
        assert_eq!(dict.tokens_to_ids.len(), legalese_words.len());

        // Verify some legalese words are correctly registered
        let license_id = dict.get_token_id("license");
        assert!(license_id.is_some(), "License should be in dictionary");
        let license_id = license_id.unwrap();
        assert!(
            license_id < dict.legalese_count() as u16,
            "License should be a legalese token"
        );
        assert_eq!(dict.token_kind(license_id), TokenKind::Legalese);

        // Note: Standalone "copyright" is NOT in the Python reference dictionary
        // Only compound words like "copyrighted", "copyrights" are present
        let copyrighted_id = dict.get_token_id("copyrighted");
        assert!(
            copyrighted_id.is_some(),
            "Copyrighted should be in dictionary"
        );
        assert!(
            copyrighted_id.unwrap() < dict.legalese_count() as u16,
            "Copyrighted should be a legalese token"
        );

        // New tokens should get IDs after legalese and be classified as regular
        let hello_id = dict.get_or_assign("hello");
        assert!(hello_id >= dict.legalese_count() as u16);
        assert_eq!(dict.token_kind(hello_id), TokenKind::Regular);
    }
}