provenant/license_detection/index/
dictionary.rs1use std::collections::HashMap;
7
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
11pub struct TokenId(u16);
12
13impl TokenId {
14 pub const fn new(raw: u16) -> Self {
15 Self(raw)
16 }
17
18 pub const fn raw(self) -> u16 {
19 self.0
20 }
21
22 pub const fn as_usize(self) -> usize {
23 self.0 as usize
24 }
25
26 pub const fn to_le_bytes(self) -> [u8; 2] {
27 self.0.to_le_bytes()
28 }
29}
30
/// Test-only shorthand for building a [`TokenId`] from a raw `u16`.
#[cfg(test)]
pub const fn tid(raw: u16) -> TokenId {
    TokenId(raw)
}
35
36impl From<u16> for TokenId {
37 fn from(value: u16) -> Self {
38 Self(value)
39 }
40}
41
42impl From<TokenId> for u16 {
43 fn from(value: TokenId) -> Self {
44 value.0
45 }
46}
47
48impl PartialEq<u16> for TokenId {
49 fn eq(&self, other: &u16) -> bool {
50 self.0 == *other
51 }
52}
53
54impl PartialOrd<u16> for TokenId {
55 fn partial_cmp(&self, other: &u16) -> Option<std::cmp::Ordering> {
56 self.0.partial_cmp(other)
57 }
58}
59
60impl PartialEq<TokenId> for u16 {
61 fn eq(&self, other: &TokenId) -> bool {
62 *self == other.0
63 }
64}
65
66impl PartialOrd<TokenId> for u16 {
67 fn partial_cmp(&self, other: &TokenId) -> Option<std::cmp::Ordering> {
68 self.partial_cmp(&other.0)
69 }
70}
71
/// Category of a token id held by a [`TokenDictionary`].
///
/// NOTE: variant order is significant for the derived traits and for any
/// positional (non-self-describing) serde encoding; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TokenKind {
    /// Token registered from the legalese word list when the dictionary was built.
    Legalese,
    /// Any other token (assigned on first interning). Also the kind reported
    /// for ids that have no recorded metadata.
    Regular,
}
77
/// A token found in a [`TokenDictionary`], with its id and the flags recorded
/// for that id.
///
/// NOTE: field order is significant for positional serde encodings; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct KnownToken {
    /// Id assigned by the dictionary.
    pub id: TokenId,
    /// Whether the token is legalese or regular.
    pub kind: TokenKind,
    /// Whether the token is made only of ASCII digits.
    pub is_digit_only: bool,
    /// Whether the token is very short (one character) or digit-only; see
    /// [`TokenDictionary`] for the exact rule used when the token was registered.
    pub is_short_or_digit: bool,
}
85
/// Classification of a single token seen in a query.
///
/// NOTE: variant order is significant for positional serde encodings; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum QueryToken {
    /// Token present in the dictionary.
    Known(KnownToken),
    /// Token absent from the dictionary.
    Unknown,
    /// Token treated as a stopword. `TokenDictionary::classify_query_token`
    /// never produces this variant; presumably it is assigned by the query
    /// tokenizer before dictionary lookup — TODO confirm.
    Stopword,
}
92
/// Flags stored per token id by [`TokenDictionary`] and copied into every
/// [`KnownToken`] built for that id.
///
/// NOTE: field order is significant for positional serde encodings; do not reorder.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TokenMetadata {
    /// Whether the token is legalese or regular.
    pub kind: TokenKind,
    /// Whether the token is made only of ASCII digits.
    pub is_digit_only: bool,
    /// Whether the token is very short (one character) or digit-only.
    pub is_short_or_digit: bool,
}
99
/// Interning table mapping token strings to [`TokenId`]s, with per-id
/// [`TokenMetadata`].
///
/// Legalese tokens are registered up front (`new_with_legalese`) or only have
/// their id range reserved (`new`); every other token receives the next free id
/// the first time it is interned.
#[derive(Debug, Clone)]
pub struct TokenDictionary {
    /// Token text -> assigned id.
    tokens_to_ids: HashMap<String, TokenId>,

    /// Metadata indexed by raw id; `None` for ids that were never given a word
    /// (e.g. gaps between pre-assigned legalese ids).
    token_metadata: Vec<Option<TokenMetadata>>,

    /// Number of legalese tokens this dictionary was created with.
    len_legalese: usize,

    /// Id the next newly interned token will receive.
    next_id: TokenId,
}
124
125impl TokenDictionary {
126 const DEFAULT_METADATA: TokenMetadata = TokenMetadata {
127 kind: TokenKind::Regular,
128 is_digit_only: false,
129 is_short_or_digit: false,
130 };
131
132 pub fn new_with_legalese(legalese_entries: &[(&str, u16)]) -> Self {
143 let mut tokens_to_ids = HashMap::new();
144 let max_existing_id = legalese_entries
145 .iter()
146 .map(|(_, token_id)| *token_id as usize)
147 .max()
148 .unwrap_or(0);
149 let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
150
151 for (word, token_id) in legalese_entries {
152 let id = TokenId::from(*token_id);
153 tokens_to_ids.insert(word.to_string(), id);
154 token_metadata[id.as_usize()] = Some(TokenMetadata {
155 kind: TokenKind::Legalese,
156 is_digit_only: word.chars().all(|c| c.is_ascii_digit()),
157 is_short_or_digit: word.len() == 1 || word.chars().all(|c| c.is_ascii_digit()),
158 });
159 }
160
161 let len_legalese = legalese_entries.len();
162 let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
163
164 Self {
165 tokens_to_ids,
166 token_metadata,
167 len_legalese,
168 next_id,
169 }
170 }
171
172 pub fn new(legalese_count: usize) -> Self {
180 Self {
181 tokens_to_ids: HashMap::new(),
182 token_metadata: Vec::new(),
183 len_legalese: legalese_count,
184 next_id: TokenId::new(legalese_count as u16),
185 }
186 }
187
188 fn metadata_for(&self, id: TokenId) -> TokenMetadata {
189 self.token_metadata
190 .get(id.as_usize())
191 .and_then(|meta| *meta)
192 .unwrap_or(Self::DEFAULT_METADATA)
193 }
194
195 fn build_known_token(&self, id: TokenId) -> KnownToken {
196 let metadata = self.metadata_for(id);
197 KnownToken {
198 id,
199 kind: metadata.kind,
200 is_digit_only: metadata.is_digit_only,
201 is_short_or_digit: metadata.is_short_or_digit,
202 }
203 }
204
205 fn insert_metadata(&mut self, id: TokenId, kind: TokenKind, token: &str) {
206 let raw = id.as_usize();
207 if self.token_metadata.len() <= raw {
208 self.token_metadata.resize(raw + 1, None);
209 }
210 self.token_metadata[raw] = Some(TokenMetadata {
211 kind,
212 is_digit_only: token.chars().all(|c| c.is_ascii_digit()),
213 is_short_or_digit: token.len() == 1 || token.chars().all(|c| c.is_ascii_digit()),
214 });
215 }
216
217 pub fn intern(&mut self, token: &str) -> KnownToken {
218 if let Some(&id) = self.tokens_to_ids.get(token) {
219 return self.build_known_token(id);
220 }
221
222 let id = self.next_id;
223 self.next_id = TokenId::new(self.next_id.raw() + 1);
224 self.tokens_to_ids.insert(token.to_string(), id);
225 self.insert_metadata(id, TokenKind::Regular, token);
226 self.build_known_token(id)
227 }
228
229 pub fn lookup(&self, token: &str) -> Option<KnownToken> {
230 self.tokens_to_ids
231 .get(token)
232 .copied()
233 .map(|id| self.build_known_token(id))
234 }
235
236 pub fn classify_query_token(&self, token: &str) -> QueryToken {
237 self.lookup(token)
238 .map_or(QueryToken::Unknown, QueryToken::Known)
239 }
240
241 pub fn token_kind(&self, token_id: TokenId) -> TokenKind {
242 self.metadata_for(token_id).kind
243 }
244
245 pub fn is_digit_only_token(&self, token_id: TokenId) -> bool {
246 self.metadata_for(token_id).is_digit_only
247 }
248
249 #[cfg(test)]
250 pub fn get_or_assign(&mut self, token: &str) -> TokenId {
251 self.intern(token).id
252 }
253
254 pub fn get_token_id(&self, token: &str) -> Option<TokenId> {
262 self.lookup(token).map(|token| token.id)
263 }
264
265 #[inline]
267 pub fn get(&self, token: &str) -> Option<TokenId> {
268 self.get_token_id(token)
269 }
270
271 pub const fn legalese_count(&self) -> usize {
273 self.len_legalese
274 }
275
276 #[cfg(test)]
278 pub fn tokens_to_ids(&self) -> impl Iterator<Item = (&String, &TokenId)> {
279 self.tokens_to_ids.iter()
280 }
281
282 #[allow(dead_code)]
285 pub fn tokens_to_ids_len(&self) -> usize {
286 self.tokens_to_ids.len()
287 }
288}
289
290impl Default for TokenDictionary {
291 fn default() -> Self {
292 Self::new(0)
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dictionary_new() {
        let dict = TokenDictionary::new(10);
        assert_eq!(dict.legalese_count(), 10);
        assert_eq!(dict.tokens_to_ids.len(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_new_with_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[
            ("license", 0u16),
            ("copyright", 1u16),
            ("permission", 2u16),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);
        assert!(!dict.tokens_to_ids.is_empty());

        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("copyright"), Some(tid(1)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(2)));

        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, 3);
    }

    #[test]
    fn test_new_with_legalese_sorted() {
        let mut dict = TokenDictionary::new_with_legalese(&[
            ("copyright", 5u16),
            ("license", 0u16),
            ("permission", 10u16),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);

        assert_eq!(dict.get_token_id("copyright"), Some(tid(5)));
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(10)));

        // New ids start one past the highest pre-assigned id.
        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, tid(11));
    }

    #[test]
    fn test_get_or_assign_new_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("world");

        assert_eq!(id1, 5);
        assert_eq!(id2, 6);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_or_assign_existing_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("hello");

        assert_eq!(id1, id2);
        assert_eq!(dict.tokens_to_ids.len(), 1);
    }

    #[test]
    fn test_get_or_assign_with_preexisting_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[("license", 0u16)]);

        let id = dict.get_or_assign("license");
        assert_eq!(id, 0);
        assert_eq!(dict.tokens_to_ids.len(), 1);

        let new_id = dict.get_or_assign("new");
        assert_eq!(new_id, 1);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_existing_token() {
        let mut dict = TokenDictionary::new(5);

        dict.get_or_assign("hello");
        assert_eq!(dict.get_token_id("hello"), Some(tid(5)));
    }

    #[test]
    fn test_get_nonexistent_token() {
        let dict = TokenDictionary::new(5);
        assert_eq!(dict.get_token_id("hello"), None);
    }

    #[test]
    fn test_legalese_range() {
        let dict = TokenDictionary::new(10);

        assert!(0 < dict.legalese_count() as u16);
        assert!(5 < dict.legalese_count() as u16);
        assert!(9 < dict.legalese_count() as u16);

        assert!(10 >= dict.legalese_count() as u16);
        assert!(100 >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_legalese_range_with_actual_legalese() {
        let mut dict =
            TokenDictionary::new_with_legalese(&[("license", 0u16), ("copyright", 1u16)]);

        assert!(dict.get_token_id("license").unwrap() < dict.legalese_count() as u16);
        assert!(dict.get_token_id("copyright").unwrap() < dict.legalese_count() as u16);

        let regular_id = dict.get_or_assign("regular");
        assert!(regular_id >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_token_dictionary_default() {
        let dict = TokenDictionary::default();
        assert_eq!(dict.legalese_count(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_get_alias() {
        let mut dict = TokenDictionary::new(5);
        dict.get_or_assign("hello");

        assert_eq!(dict.get("hello"), dict.get_token_id("hello"));
    }

    #[test]
    fn test_with_actual_legalese_module() {
        use crate::license_detection::rules::legalese;

        let legalese_words = legalese::get_legalese_words();
        assert!(!legalese_words.is_empty(), "Should have legalese words");

        let mut dict = TokenDictionary::new_with_legalese(&legalese_words);

        assert_eq!(dict.legalese_count(), legalese_words.len());
        assert_eq!(dict.tokens_to_ids.len(), legalese_words.len());

        let license_id = dict.get_token_id("license");
        assert!(license_id.is_some(), "License should be in dictionary");
        assert!(
            license_id.unwrap() < dict.legalese_count() as u16,
            "License should be a legalese token"
        );

        let copyrighted_id = dict.get_token_id("copyrighted");
        assert!(
            copyrighted_id.is_some(),
            "Copyrighted should be in dictionary"
        );
        assert!(
            copyrighted_id.unwrap() < dict.legalese_count() as u16,
            "Copyrighted should be a legalese token"
        );

        let hello_id = dict.get_or_assign("hello");
        assert!(hello_id >= dict.legalese_count() as u16);
    }
}