1use std::collections::{BTreeMap, HashMap};
7
8use rkyv::Archive;
9use rkyv::Archived;
10use serde::{Deserialize, Serialize};
11
12#[derive(
13 Debug,
14 Clone,
15 Copy,
16 PartialEq,
17 Eq,
18 Hash,
19 PartialOrd,
20 Ord,
21 Serialize,
22 Deserialize,
23 Archive,
24 rkyv::Serialize,
25 rkyv::Deserialize,
26)]
27#[rkyv(derive(Hash, Eq, PartialEq, PartialOrd, Ord))]
28pub struct TokenId(u16);
29
30impl TokenId {
31 pub const fn new(raw: u16) -> Self {
32 Self(raw)
33 }
34
35 pub const fn raw(self) -> u16 {
36 self.0
37 }
38
39 pub const fn as_usize(self) -> usize {
40 self.0 as usize
41 }
42
43 pub const fn to_le_bytes(self) -> [u8; 2] {
44 self.0.to_le_bytes()
45 }
46}
47
/// Test shorthand: builds a [`TokenId`] from a raw `u16`.
#[cfg(test)]
pub const fn tid(raw: u16) -> TokenId {
    TokenId(raw)
}
52
53impl From<u16> for TokenId {
54 fn from(value: u16) -> Self {
55 Self(value)
56 }
57}
58
59impl From<TokenId> for u16 {
60 fn from(value: TokenId) -> Self {
61 value.0
62 }
63}
64
65impl PartialEq<u16> for TokenId {
66 fn eq(&self, other: &u16) -> bool {
67 self.0 == *other
68 }
69}
70
71impl PartialOrd<u16> for TokenId {
72 fn partial_cmp(&self, other: &u16) -> Option<std::cmp::Ordering> {
73 self.0.partial_cmp(other)
74 }
75}
76
77impl PartialEq<TokenId> for u16 {
78 fn eq(&self, other: &TokenId) -> bool {
79 *self == other.0
80 }
81}
82
83impl PartialOrd<TokenId> for u16 {
84 fn partial_cmp(&self, other: &TokenId) -> Option<std::cmp::Ordering> {
85 self.partial_cmp(&other.0)
86 }
87}
88
89#[derive(
90 Debug,
91 Clone,
92 Copy,
93 PartialEq,
94 Eq,
95 Hash,
96 Serialize,
97 Deserialize,
98 Archive,
99 rkyv::Serialize,
100 rkyv::Deserialize,
101)]
102pub enum TokenKind {
103 Legalese,
104 Regular,
105}
106
107#[derive(
108 Debug,
109 Clone,
110 Copy,
111 PartialEq,
112 Eq,
113 Hash,
114 Serialize,
115 Deserialize,
116 Archive,
117 rkyv::Serialize,
118 rkyv::Deserialize,
119)]
120pub struct KnownToken {
121 pub id: TokenId,
122 pub kind: TokenKind,
123 pub is_digit_only: bool,
124 pub is_short_or_digit: bool,
125}
126
127#[derive(
128 Debug,
129 Clone,
130 Copy,
131 PartialEq,
132 Eq,
133 Hash,
134 Serialize,
135 Deserialize,
136 Archive,
137 rkyv::Serialize,
138 rkyv::Deserialize,
139)]
140pub enum QueryToken {
141 Known(KnownToken),
142 Unknown,
143 Stopword,
144}
145
146#[derive(
147 Debug, Clone, Copy, Serialize, Deserialize, Archive, rkyv::Serialize, rkyv::Deserialize,
148)]
149pub struct TokenMetadata {
150 pub kind: TokenKind,
151 pub is_digit_only: bool,
152 pub is_short_or_digit: bool,
153}
154
155#[derive(Debug, Clone, Archive, rkyv::Serialize, rkyv::Deserialize)]
167pub struct TokenDictionary {
168 tokens_to_ids: HashMap<String, TokenId>,
170
171 token_metadata: Vec<Option<TokenMetadata>>,
172
173 len_legalese: usize,
175
176 next_id: TokenId,
178}
179
180impl TokenDictionary {
181 const DEFAULT_METADATA: TokenMetadata = TokenMetadata {
182 kind: TokenKind::Regular,
183 is_digit_only: false,
184 is_short_or_digit: false,
185 };
186
187 pub fn new_with_legalese(legalese: &Archived<BTreeMap<String, u16>>) -> Self {
200 let mut tokens_to_ids = HashMap::new();
201 let max_existing_id = legalese
202 .iter()
203 .map(|(_, id)| id.to_native() as usize)
204 .max()
205 .unwrap_or(0);
206 let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
207
208 for (word, id) in legalese.iter() {
209 let native_id = TokenId::new(id.to_native());
210 tokens_to_ids.insert(word.to_string(), native_id);
211 token_metadata[native_id.as_usize()] = Some(TokenMetadata {
212 kind: TokenKind::Legalese,
213 is_digit_only: word.chars().all(|c: char| c.is_ascii_digit()),
214 is_short_or_digit: word.len() == 1
215 || word.chars().all(|c: char| c.is_ascii_digit()),
216 });
217 }
218
219 let len_legalese = legalese.len();
220 let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
221
222 Self {
223 tokens_to_ids,
224 token_metadata,
225 len_legalese,
226 next_id,
227 }
228 }
229
230 pub fn new_with_legalese_pairs(legalese_entries: &[(&str, u16)]) -> Self {
234 let mut tokens_to_ids = HashMap::new();
235 let max_existing_id = legalese_entries
236 .iter()
237 .map(|(_, token_id)| *token_id as usize)
238 .max()
239 .unwrap_or(0);
240 let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
241
242 for (word, token_id) in legalese_entries {
243 let id = TokenId::from(*token_id);
244 tokens_to_ids.insert(word.to_string(), id);
245 token_metadata[id.as_usize()] = Some(TokenMetadata {
246 kind: TokenKind::Legalese,
247 is_digit_only: word.chars().all(|c| c.is_ascii_digit()),
248 is_short_or_digit: word.len() == 1 || word.chars().all(|c| c.is_ascii_digit()),
249 });
250 }
251
252 let len_legalese = legalese_entries.len();
253 let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
254
255 Self {
256 tokens_to_ids,
257 token_metadata,
258 len_legalese,
259 next_id,
260 }
261 }
262
263 pub fn new(legalese_count: usize) -> Self {
271 Self {
272 tokens_to_ids: HashMap::new(),
273 token_metadata: Vec::new(),
274 len_legalese: legalese_count,
275 next_id: TokenId::new(legalese_count as u16),
276 }
277 }
278
279 fn metadata_for(&self, id: TokenId) -> TokenMetadata {
280 self.token_metadata
281 .get(id.as_usize())
282 .and_then(|meta| *meta)
283 .unwrap_or(Self::DEFAULT_METADATA)
284 }
285
286 fn build_known_token(&self, id: TokenId) -> KnownToken {
287 let metadata = self.metadata_for(id);
288 KnownToken {
289 id,
290 kind: metadata.kind,
291 is_digit_only: metadata.is_digit_only,
292 is_short_or_digit: metadata.is_short_or_digit,
293 }
294 }
295
296 fn insert_metadata(&mut self, id: TokenId, kind: TokenKind, token: &str) {
297 let raw = id.as_usize();
298 if self.token_metadata.len() <= raw {
299 self.token_metadata.resize(raw + 1, None);
300 }
301 self.token_metadata[raw] = Some(TokenMetadata {
302 kind,
303 is_digit_only: token.chars().all(|c| c.is_ascii_digit()),
304 is_short_or_digit: token.len() == 1 || token.chars().all(|c| c.is_ascii_digit()),
305 });
306 }
307
308 pub fn intern(&mut self, token: &str) -> KnownToken {
309 if let Some(&id) = self.tokens_to_ids.get(token) {
310 return self.build_known_token(id);
311 }
312
313 let id = self.next_id;
314 self.next_id = TokenId::new(self.next_id.raw() + 1);
315 self.tokens_to_ids.insert(token.to_string(), id);
316 self.insert_metadata(id, TokenKind::Regular, token);
317 self.build_known_token(id)
318 }
319
320 pub fn lookup(&self, token: &str) -> Option<KnownToken> {
321 self.tokens_to_ids
322 .get(token)
323 .copied()
324 .map(|id| self.build_known_token(id))
325 }
326
327 pub fn classify_query_token(&self, token: &str) -> QueryToken {
328 self.lookup(token)
329 .map_or(QueryToken::Unknown, QueryToken::Known)
330 }
331
332 pub fn token_kind(&self, token_id: TokenId) -> TokenKind {
333 self.metadata_for(token_id).kind
334 }
335
336 pub fn is_digit_only_token(&self, token_id: TokenId) -> bool {
337 self.metadata_for(token_id).is_digit_only
338 }
339
340 #[cfg(test)]
341 pub fn get_or_assign(&mut self, token: &str) -> TokenId {
342 self.intern(token).id
343 }
344
345 pub fn get_token_id(&self, token: &str) -> Option<TokenId> {
353 self.lookup(token).map(|token| token.id)
354 }
355
356 #[inline]
358 pub fn get(&self, token: &str) -> Option<TokenId> {
359 self.get_token_id(token)
360 }
361
362 pub const fn legalese_count(&self) -> usize {
364 self.len_legalese
365 }
366
367 #[cfg(test)]
369 pub fn tokens_to_ids(&self) -> impl Iterator<Item = (&String, &TokenId)> {
370 self.tokens_to_ids.iter()
371 }
372
373 #[allow(dead_code)]
376 pub fn tokens_to_ids_len(&self) -> usize {
377 self.tokens_to_ids.len()
378 }
379}
380
381impl Default for TokenDictionary {
382 fn default() -> Self {
383 Self::new(0)
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dictionary_new() {
        let dict = TokenDictionary::new(10);
        assert_eq!(dict.legalese_count(), 10);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_new_with_legalese() {
        let mut dict = TokenDictionary::new_with_legalese_pairs(&[
            ("license", 0),
            ("copyright", 1),
            ("permission", 2),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);

        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("copyright"), Some(tid(1)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(2)));

        // Regular ids continue right after the densely packed legalese ids.
        assert_eq!(dict.get_or_assign("test"), 3);
    }

    #[test]
    fn test_new_with_legalese_sorted() {
        // Sparse, unsorted ids: regular ids must start after the highest legalese id.
        let mut dict = TokenDictionary::new_with_legalese_pairs(&[
            ("copyright", 5),
            ("license", 0),
            ("permission", 10),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);

        assert_eq!(dict.get_token_id("copyright"), Some(tid(5)));
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(10)));

        assert_eq!(dict.get_or_assign("test"), tid(11));
    }

    #[test]
    fn test_get_or_assign_new_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("world");

        assert_eq!(id1, 5);
        assert_eq!(id2, 6);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_or_assign_existing_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("hello");

        assert_eq!(id1, id2);
        assert_eq!(dict.tokens_to_ids.len(), 1);
    }

    #[test]
    fn test_get_or_assign_with_preexisting_legalese() {
        let mut dict = TokenDictionary::new_with_legalese_pairs(&[("license", 0)]);

        assert_eq!(dict.get_or_assign("license"), 0);
        assert_eq!(dict.tokens_to_ids.len(), 1);

        assert_eq!(dict.get_or_assign("new"), 1);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_existing_token() {
        let mut dict = TokenDictionary::new(5);

        dict.get_or_assign("hello");
        assert_eq!(dict.get_token_id("hello"), Some(tid(5)));
    }

    #[test]
    fn test_get_nonexistent_token() {
        let dict = TokenDictionary::new(5);
        assert_eq!(dict.get_token_id("hello"), None);
    }

    #[test]
    fn test_legalese_range() {
        let dict = TokenDictionary::new(10);

        // Ids below the legalese count are in the legalese range...
        assert!(0 < dict.legalese_count() as u16);
        assert!(5 < dict.legalese_count() as u16);
        assert!(9 < dict.legalese_count() as u16);

        // ...and ids at or above it are not.
        assert!(10 >= dict.legalese_count() as u16);
        assert!(100 >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_legalese_range_with_actual_legalese() {
        let mut dict =
            TokenDictionary::new_with_legalese_pairs(&[("license", 0), ("copyright", 1)]);

        assert!(dict.get_token_id("license").unwrap() < dict.legalese_count() as u16);
        assert!(dict.get_token_id("copyright").unwrap() < dict.legalese_count() as u16);

        let regular_id = dict.get_or_assign("regular");
        assert!(regular_id >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_token_dictionary_default() {
        let dict = TokenDictionary::default();
        assert_eq!(dict.legalese_count(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_get_alias() {
        let mut dict = TokenDictionary::new(5);
        dict.get_or_assign("hello");

        assert_eq!(dict.get("hello"), dict.get_token_id("hello"));
    }

    #[test]
    fn test_token_metadata_flags() {
        let mut dict = TokenDictionary::new_with_legalese_pairs(&[("7", 0)]);

        let legalese_digit = dict.lookup("7").expect("legalese token should be present");
        assert_eq!(legalese_digit.kind, TokenKind::Legalese);
        assert!(legalese_digit.is_digit_only);
        assert!(legalese_digit.is_short_or_digit);

        let digits = dict.intern("123");
        assert_eq!(digits.kind, TokenKind::Regular);
        assert!(digits.is_digit_only);
        assert!(digits.is_short_or_digit);

        let single = dict.intern("a");
        assert!(!single.is_digit_only);
        assert!(single.is_short_or_digit);

        let word = dict.intern("ab");
        assert!(!word.is_digit_only);
        assert!(!word.is_short_or_digit);

        assert!(dict.is_digit_only_token(digits.id));
        assert!(!dict.is_digit_only_token(word.id));
    }

    #[test]
    fn test_classify_query_token() {
        let mut dict = TokenDictionary::new(0);
        let known = dict.intern("hello");

        assert_eq!(dict.classify_query_token("hello"), QueryToken::Known(known));
        assert_eq!(dict.classify_query_token("missing"), QueryToken::Unknown);
    }

    #[test]
    fn test_with_actual_legalese_module() {
        use crate::license_detection::rules::legalese;

        let legalese = legalese::archived_legalese();
        assert!(!legalese.is_empty(), "Should have legalese words");

        let mut dict = TokenDictionary::new_with_legalese(legalese);

        assert_eq!(dict.legalese_count(), legalese.len());
        assert_eq!(dict.tokens_to_ids.len(), legalese.len());

        let license_id = dict
            .get_token_id("license")
            .expect("License should be in dictionary");
        assert!(
            license_id < dict.legalese_count() as u16,
            "License should be a legalese token"
        );
        assert_eq!(dict.token_kind(license_id), TokenKind::Legalese);

        let copyrighted_id = dict
            .get_token_id("copyrighted")
            .expect("Copyrighted should be in dictionary");
        assert!(
            copyrighted_id < dict.legalese_count() as u16,
            "Copyrighted should be a legalese token"
        );

        let hello_id = dict.get_or_assign("hello");
        assert!(hello_id >= dict.legalese_count() as u16);
        assert_eq!(dict.token_kind(hello_id), TokenKind::Regular);
    }
}