1use std::collections::{BTreeMap, HashMap};
10
11use rkyv::Archive;
12use rkyv::Archived;
13use serde::{Deserialize, Serialize};
14
15#[derive(
16 Debug,
17 Clone,
18 Copy,
19 PartialEq,
20 Eq,
21 Hash,
22 PartialOrd,
23 Ord,
24 Serialize,
25 Deserialize,
26 Archive,
27 rkyv::Serialize,
28 rkyv::Deserialize,
29)]
30#[rkyv(derive(Hash, Eq, PartialEq, PartialOrd, Ord))]
31pub struct TokenId(u16);
32
33impl TokenId {
34 pub const fn new(raw: u16) -> Self {
35 Self(raw)
36 }
37
38 pub const fn raw(self) -> u16 {
39 self.0
40 }
41
42 pub const fn as_usize(self) -> usize {
43 self.0 as usize
44 }
45
46 pub const fn to_le_bytes(self) -> [u8; 2] {
47 self.0.to_le_bytes()
48 }
49}
50
/// Test-only shorthand for building a [`TokenId`] from a raw value.
#[cfg(test)]
pub const fn tid(raw: u16) -> TokenId {
    TokenId(raw)
}
55
56impl From<u16> for TokenId {
57 fn from(value: u16) -> Self {
58 Self(value)
59 }
60}
61
62impl From<TokenId> for u16 {
63 fn from(value: TokenId) -> Self {
64 value.0
65 }
66}
67
68impl PartialEq<u16> for TokenId {
69 fn eq(&self, other: &u16) -> bool {
70 self.0 == *other
71 }
72}
73
74impl PartialOrd<u16> for TokenId {
75 fn partial_cmp(&self, other: &u16) -> Option<std::cmp::Ordering> {
76 self.0.partial_cmp(other)
77 }
78}
79
80impl PartialEq<TokenId> for u16 {
81 fn eq(&self, other: &TokenId) -> bool {
82 *self == other.0
83 }
84}
85
86impl PartialOrd<TokenId> for u16 {
87 fn partial_cmp(&self, other: &TokenId) -> Option<std::cmp::Ordering> {
88 self.partial_cmp(&other.0)
89 }
90}
91
92#[derive(
93 Debug,
94 Clone,
95 Copy,
96 PartialEq,
97 Eq,
98 Hash,
99 Serialize,
100 Deserialize,
101 Archive,
102 rkyv::Serialize,
103 rkyv::Deserialize,
104)]
105pub enum TokenKind {
106 Legalese,
107 Regular,
108}
109
110#[derive(
111 Debug,
112 Clone,
113 Copy,
114 PartialEq,
115 Eq,
116 Hash,
117 Serialize,
118 Deserialize,
119 Archive,
120 rkyv::Serialize,
121 rkyv::Deserialize,
122)]
123pub struct KnownToken {
124 pub id: TokenId,
125 pub kind: TokenKind,
126 pub is_digit_only: bool,
127 pub is_short_or_digit: bool,
128}
129
130#[derive(
131 Debug,
132 Clone,
133 Copy,
134 PartialEq,
135 Eq,
136 Hash,
137 Serialize,
138 Deserialize,
139 Archive,
140 rkyv::Serialize,
141 rkyv::Deserialize,
142)]
143pub enum QueryToken {
144 Known(KnownToken),
145 Unknown,
146 Stopword,
147}
148
149#[derive(
150 Debug, Clone, Copy, Serialize, Deserialize, Archive, rkyv::Serialize, rkyv::Deserialize,
151)]
152pub struct TokenMetadata {
153 pub kind: TokenKind,
154 pub is_digit_only: bool,
155 pub is_short_or_digit: bool,
156}
157
158#[derive(Debug, Clone, Archive, rkyv::Serialize, rkyv::Deserialize)]
170pub struct TokenDictionary {
171 tokens_to_ids: HashMap<String, TokenId>,
173
174 token_metadata: Vec<Option<TokenMetadata>>,
175
176 len_legalese: usize,
178
179 next_id: TokenId,
181}
182
183impl TokenDictionary {
184 const DEFAULT_METADATA: TokenMetadata = TokenMetadata {
185 kind: TokenKind::Regular,
186 is_digit_only: false,
187 is_short_or_digit: false,
188 };
189
190 pub fn new_with_legalese(legalese: &Archived<BTreeMap<String, u16>>) -> Self {
203 let mut tokens_to_ids = HashMap::new();
204 let max_existing_id = legalese
205 .iter()
206 .map(|(_, id)| id.to_native() as usize)
207 .max()
208 .unwrap_or(0);
209 let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
210
211 for (word, id) in legalese.iter() {
212 let native_id = TokenId::new(id.to_native());
213 tokens_to_ids.insert(word.to_string(), native_id);
214 token_metadata[native_id.as_usize()] = Some(TokenMetadata {
215 kind: TokenKind::Legalese,
216 is_digit_only: word.chars().all(|c: char| c.is_ascii_digit()),
217 is_short_or_digit: word.len() == 1
218 || word.chars().all(|c: char| c.is_ascii_digit()),
219 });
220 }
221
222 let len_legalese = legalese.len();
223 let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
224
225 Self {
226 tokens_to_ids,
227 token_metadata,
228 len_legalese,
229 next_id,
230 }
231 }
232
233 pub fn new_with_legalese_pairs(legalese_entries: &[(&str, u16)]) -> Self {
237 let mut tokens_to_ids = HashMap::new();
238 let max_existing_id = legalese_entries
239 .iter()
240 .map(|(_, token_id)| *token_id as usize)
241 .max()
242 .unwrap_or(0);
243 let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
244
245 for (word, token_id) in legalese_entries {
246 let id = TokenId::from(*token_id);
247 tokens_to_ids.insert(word.to_string(), id);
248 token_metadata[id.as_usize()] = Some(TokenMetadata {
249 kind: TokenKind::Legalese,
250 is_digit_only: word.chars().all(|c| c.is_ascii_digit()),
251 is_short_or_digit: word.len() == 1 || word.chars().all(|c| c.is_ascii_digit()),
252 });
253 }
254
255 let len_legalese = legalese_entries.len();
256 let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
257
258 Self {
259 tokens_to_ids,
260 token_metadata,
261 len_legalese,
262 next_id,
263 }
264 }
265
266 pub fn new(legalese_count: usize) -> Self {
274 Self {
275 tokens_to_ids: HashMap::new(),
276 token_metadata: Vec::new(),
277 len_legalese: legalese_count,
278 next_id: TokenId::new(legalese_count as u16),
279 }
280 }
281
282 fn metadata_for(&self, id: TokenId) -> TokenMetadata {
283 self.token_metadata
284 .get(id.as_usize())
285 .and_then(|meta| *meta)
286 .unwrap_or(Self::DEFAULT_METADATA)
287 }
288
289 fn build_known_token(&self, id: TokenId) -> KnownToken {
290 let metadata = self.metadata_for(id);
291 KnownToken {
292 id,
293 kind: metadata.kind,
294 is_digit_only: metadata.is_digit_only,
295 is_short_or_digit: metadata.is_short_or_digit,
296 }
297 }
298
299 fn insert_metadata(&mut self, id: TokenId, kind: TokenKind, token: &str) {
300 let raw = id.as_usize();
301 if self.token_metadata.len() <= raw {
302 self.token_metadata.resize(raw + 1, None);
303 }
304 self.token_metadata[raw] = Some(TokenMetadata {
305 kind,
306 is_digit_only: token.chars().all(|c| c.is_ascii_digit()),
307 is_short_or_digit: token.len() == 1 || token.chars().all(|c| c.is_ascii_digit()),
308 });
309 }
310
311 pub fn intern(&mut self, token: &str) -> KnownToken {
312 if let Some(&id) = self.tokens_to_ids.get(token) {
313 return self.build_known_token(id);
314 }
315
316 let id = self.next_id;
317 self.next_id = TokenId::new(self.next_id.raw() + 1);
318 self.tokens_to_ids.insert(token.to_string(), id);
319 self.insert_metadata(id, TokenKind::Regular, token);
320 self.build_known_token(id)
321 }
322
323 pub fn lookup(&self, token: &str) -> Option<KnownToken> {
324 self.tokens_to_ids
325 .get(token)
326 .copied()
327 .map(|id| self.build_known_token(id))
328 }
329
330 pub fn classify_query_token(&self, token: &str) -> QueryToken {
331 self.lookup(token)
332 .map_or(QueryToken::Unknown, QueryToken::Known)
333 }
334
335 pub fn token_kind(&self, token_id: TokenId) -> TokenKind {
336 self.metadata_for(token_id).kind
337 }
338
339 pub fn is_digit_only_token(&self, token_id: TokenId) -> bool {
340 self.metadata_for(token_id).is_digit_only
341 }
342
343 #[cfg(test)]
344 pub fn get_or_assign(&mut self, token: &str) -> TokenId {
345 self.intern(token).id
346 }
347
348 pub fn get_token_id(&self, token: &str) -> Option<TokenId> {
356 self.lookup(token).map(|token| token.id)
357 }
358
359 #[inline]
361 pub fn get(&self, token: &str) -> Option<TokenId> {
362 self.get_token_id(token)
363 }
364
365 pub const fn legalese_count(&self) -> usize {
367 self.len_legalese
368 }
369
370 #[cfg(test)]
372 pub fn tokens_to_ids(&self) -> impl Iterator<Item = (&String, &TokenId)> {
373 self.tokens_to_ids.iter()
374 }
375
376 #[allow(dead_code)]
379 pub fn tokens_to_ids_len(&self) -> usize {
380 self.tokens_to_ids.len()
381 }
382}
383
384impl Default for TokenDictionary {
385 fn default() -> Self {
386 Self::new(0)
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dictionary_new() {
        let dict = TokenDictionary::new(10);
        assert_eq!(dict.legalese_count(), 10);
        assert_eq!(dict.tokens_to_ids.len(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_new_with_legalese() {
        let mut dict = TokenDictionary::new_with_legalese_pairs(&[
            ("license", 0),
            ("copyright", 1),
            ("permission", 2),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);
        assert!(!dict.tokens_to_ids.is_empty());

        // Legalese words keep the IDs they were given.
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("copyright"), Some(tid(1)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(2)));

        // New tokens are numbered right after the legalese block.
        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, 3);
    }

    #[test]
    fn test_new_with_legalese_sorted() {
        let mut dict = TokenDictionary::new_with_legalese_pairs(&[
            ("copyright", 5),
            ("license", 0),
            ("permission", 10),
        ]);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);

        // Sparse, unordered IDs are preserved as given.
        assert_eq!(dict.get_token_id("copyright"), Some(tid(5)));
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(10)));

        // New tokens are numbered after the highest legalese ID.
        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, tid(11));
    }

    #[test]
    fn test_get_or_assign_new_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("world");

        // IDs start at the reserved legalese count and increase by one.
        assert_eq!(id1, 5);
        assert_eq!(id2, 6);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_or_assign_existing_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("hello");

        // Interning the same token twice must not allocate a second ID.
        assert_eq!(id1, id2);
        assert_eq!(dict.tokens_to_ids.len(), 1);
    }

    #[test]
    fn test_get_or_assign_with_preexisting_legalese() {
        let mut dict = TokenDictionary::new_with_legalese_pairs(&[("license", 0)]);

        // A legalese word resolves to its existing ID.
        let id = dict.get_or_assign("license");
        assert_eq!(id, 0);
        assert_eq!(dict.tokens_to_ids.len(), 1);

        // An unseen word receives the next free ID.
        let new_id = dict.get_or_assign("new");
        assert_eq!(new_id, 1);
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_existing_token() {
        let mut dict = TokenDictionary::new(5);

        dict.get_or_assign("hello");
        assert_eq!(dict.get_token_id("hello"), Some(tid(5)));
    }

    #[test]
    fn test_get_nonexistent_token() {
        let dict = TokenDictionary::new(5);
        assert_eq!(dict.get_token_id("hello"), None);
    }

    #[test]
    fn test_legalese_range() {
        let dict = TokenDictionary::new(10);

        // Raw IDs below the count fall inside the legalese range...
        assert!(0 < dict.legalese_count() as u16);
        assert!(5 < dict.legalese_count() as u16);
        assert!(9 < dict.legalese_count() as u16);

        // ...and IDs at or above it fall outside.
        assert!(10 >= dict.legalese_count() as u16);
        assert!(100 >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_legalese_range_with_actual_legalese() {
        let mut dict =
            TokenDictionary::new_with_legalese_pairs(&[("license", 0), ("copyright", 1)]);

        // Legalese IDs sit below the legalese count.
        assert!(dict.get_token_id("license").unwrap() < dict.legalese_count() as u16);
        assert!(dict.get_token_id("copyright").unwrap() < dict.legalese_count() as u16);

        // Regular tokens sit at or above it.
        let regular_id = dict.get_or_assign("regular");
        assert!(regular_id >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_token_dictionary_default() {
        let dict = TokenDictionary::default();
        assert_eq!(dict.legalese_count(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_get_alias() {
        let mut dict = TokenDictionary::new(5);
        dict.get_or_assign("hello");

        // `get` must agree with `get_token_id`.
        assert_eq!(dict.get("hello"), dict.get_token_id("hello"));
    }

    #[test]
    fn test_legalese_metadata() {
        let dict = TokenDictionary::new_with_legalese_pairs(&[("license", 0), ("2", 1)]);

        assert_eq!(dict.token_kind(tid(0)), TokenKind::Legalese);
        assert_eq!(dict.token_kind(tid(1)), TokenKind::Legalese);
        assert!(!dict.is_digit_only_token(tid(0)));
        assert!(dict.is_digit_only_token(tid(1)));

        let digit = dict.lookup("2").expect("digit token should be present");
        assert_eq!(digit.id, tid(1));
        assert!(digit.is_short_or_digit);

        // IDs without recorded metadata report the regular defaults.
        assert_eq!(dict.token_kind(tid(42)), TokenKind::Regular);
        assert!(!dict.is_digit_only_token(tid(42)));
    }

    #[test]
    fn test_intern_records_regular_metadata() {
        let mut dict = TokenDictionary::new(5);

        let word = dict.intern("hello");
        assert_eq!(word.id, tid(5));
        assert_eq!(word.kind, TokenKind::Regular);
        assert!(!word.is_digit_only);
        assert!(!word.is_short_or_digit);

        let number = dict.intern("42");
        assert!(number.is_digit_only);
        assert!(number.is_short_or_digit);

        let letter = dict.intern("a");
        assert!(!letter.is_digit_only);
        assert!(letter.is_short_or_digit);

        assert_eq!(dict.lookup("hello"), Some(word));
        assert_eq!(dict.intern("hello"), word);
    }

    #[test]
    fn test_classify_query_token() {
        let mut dict = TokenDictionary::new(0);
        let known = dict.intern("license");

        assert_eq!(dict.classify_query_token("license"), QueryToken::Known(known));
        assert_eq!(dict.classify_query_token("missing"), QueryToken::Unknown);
        assert_eq!(dict.lookup("missing"), None);
    }

    #[test]
    fn test_with_actual_legalese_module() {
        use crate::license_detection::rules::legalese;

        let legalese = legalese::archived_legalese();
        assert!(!legalese.is_empty(), "Should have legalese words");

        let mut dict = TokenDictionary::new_with_legalese(legalese);

        // Every legalese word is loaded exactly once.
        assert_eq!(dict.legalese_count(), legalese.len());
        assert_eq!(dict.tokens_to_ids.len(), legalese.len());

        let license_id = dict.get_token_id("license");
        assert!(license_id.is_some(), "License should be in dictionary");
        assert!(
            license_id.unwrap() < dict.legalese_count() as u16,
            "License should be a legalese token"
        );

        let copyrighted_id = dict.get_token_id("copyrighted");
        assert!(
            copyrighted_id.is_some(),
            "Copyrighted should be in dictionary"
        );
        assert!(
            copyrighted_id.unwrap() < dict.legalese_count() as u16,
            "Copyrighted should be a legalese token"
        );

        // A non-legalese word is numbered outside the legalese range.
        let hello_id = dict.get_or_assign("hello");
        assert!(hello_id >= dict.legalese_count() as u16);
    }
}