1use std::collections::HashMap;
7
8use rkyv::Archive;
9use serde::{Deserialize, Serialize};
10
11#[derive(
12 Debug,
13 Clone,
14 Copy,
15 PartialEq,
16 Eq,
17 Hash,
18 PartialOrd,
19 Ord,
20 Serialize,
21 Deserialize,
22 Archive,
23 rkyv::Serialize,
24 rkyv::Deserialize,
25)]
26#[rkyv(derive(Hash, Eq, PartialEq, PartialOrd, Ord))]
27pub struct TokenId(u16);
28
29impl TokenId {
30 pub const fn new(raw: u16) -> Self {
31 Self(raw)
32 }
33
34 pub const fn raw(self) -> u16 {
35 self.0
36 }
37
38 pub const fn as_usize(self) -> usize {
39 self.0 as usize
40 }
41
42 pub const fn to_le_bytes(self) -> [u8; 2] {
43 self.0.to_le_bytes()
44 }
45}
46
/// Test-only shorthand for building a [`TokenId`] from a raw `u16`.
#[cfg(test)]
pub const fn tid(raw: u16) -> TokenId {
    TokenId(raw)
}
51
52impl From<u16> for TokenId {
53 fn from(value: u16) -> Self {
54 Self(value)
55 }
56}
57
58impl From<TokenId> for u16 {
59 fn from(value: TokenId) -> Self {
60 value.0
61 }
62}
63
64impl PartialEq<u16> for TokenId {
65 fn eq(&self, other: &u16) -> bool {
66 self.0 == *other
67 }
68}
69
70impl PartialOrd<u16> for TokenId {
71 fn partial_cmp(&self, other: &u16) -> Option<std::cmp::Ordering> {
72 self.0.partial_cmp(other)
73 }
74}
75
76impl PartialEq<TokenId> for u16 {
77 fn eq(&self, other: &TokenId) -> bool {
78 *self == other.0
79 }
80}
81
82impl PartialOrd<TokenId> for u16 {
83 fn partial_cmp(&self, other: &TokenId) -> Option<std::cmp::Ordering> {
84 self.partial_cmp(&other.0)
85 }
86}
87
88#[derive(
89 Debug,
90 Clone,
91 Copy,
92 PartialEq,
93 Eq,
94 Hash,
95 Serialize,
96 Deserialize,
97 Archive,
98 rkyv::Serialize,
99 rkyv::Deserialize,
100)]
101pub enum TokenKind {
102 Legalese,
103 Regular,
104}
105
106#[derive(
107 Debug,
108 Clone,
109 Copy,
110 PartialEq,
111 Eq,
112 Hash,
113 Serialize,
114 Deserialize,
115 Archive,
116 rkyv::Serialize,
117 rkyv::Deserialize,
118)]
119pub struct KnownToken {
120 pub id: TokenId,
121 pub kind: TokenKind,
122 pub is_digit_only: bool,
123 pub is_short_or_digit: bool,
124}
125
126#[derive(
127 Debug,
128 Clone,
129 Copy,
130 PartialEq,
131 Eq,
132 Hash,
133 Serialize,
134 Deserialize,
135 Archive,
136 rkyv::Serialize,
137 rkyv::Deserialize,
138)]
139pub enum QueryToken {
140 Known(KnownToken),
141 Unknown,
142 Stopword,
143}
144
145#[derive(
146 Debug, Clone, Copy, Serialize, Deserialize, Archive, rkyv::Serialize, rkyv::Deserialize,
147)]
148pub struct TokenMetadata {
149 pub kind: TokenKind,
150 pub is_digit_only: bool,
151 pub is_short_or_digit: bool,
152}
153
154#[derive(Debug, Clone, Archive, rkyv::Serialize, rkyv::Deserialize)]
166pub struct TokenDictionary {
167 tokens_to_ids: HashMap<String, TokenId>,
169
170 token_metadata: Vec<Option<TokenMetadata>>,
171
172 len_legalese: usize,
174
175 next_id: TokenId,
177}
178
179impl TokenDictionary {
180 const DEFAULT_METADATA: TokenMetadata = TokenMetadata {
181 kind: TokenKind::Regular,
182 is_digit_only: false,
183 is_short_or_digit: false,
184 };
185
186 pub fn new_with_legalese(legalese_entries: &[(&str, u16)]) -> Self {
197 let mut tokens_to_ids = HashMap::new();
198 let max_existing_id = legalese_entries
199 .iter()
200 .map(|(_, token_id)| *token_id as usize)
201 .max()
202 .unwrap_or(0);
203 let mut token_metadata = vec![None; max_existing_id.saturating_add(1)];
204
205 for (word, token_id) in legalese_entries {
206 let id = TokenId::from(*token_id);
207 tokens_to_ids.insert(word.to_string(), id);
208 token_metadata[id.as_usize()] = Some(TokenMetadata {
209 kind: TokenKind::Legalese,
210 is_digit_only: word.chars().all(|c| c.is_ascii_digit()),
211 is_short_or_digit: word.len() == 1 || word.chars().all(|c| c.is_ascii_digit()),
212 });
213 }
214
215 let len_legalese = legalese_entries.len();
216 let next_id = TokenId::new((max_existing_id + 1).max(len_legalese) as u16);
217
218 Self {
219 tokens_to_ids,
220 token_metadata,
221 len_legalese,
222 next_id,
223 }
224 }
225
226 pub fn new(legalese_count: usize) -> Self {
234 Self {
235 tokens_to_ids: HashMap::new(),
236 token_metadata: Vec::new(),
237 len_legalese: legalese_count,
238 next_id: TokenId::new(legalese_count as u16),
239 }
240 }
241
242 fn metadata_for(&self, id: TokenId) -> TokenMetadata {
243 self.token_metadata
244 .get(id.as_usize())
245 .and_then(|meta| *meta)
246 .unwrap_or(Self::DEFAULT_METADATA)
247 }
248
249 fn build_known_token(&self, id: TokenId) -> KnownToken {
250 let metadata = self.metadata_for(id);
251 KnownToken {
252 id,
253 kind: metadata.kind,
254 is_digit_only: metadata.is_digit_only,
255 is_short_or_digit: metadata.is_short_or_digit,
256 }
257 }
258
259 fn insert_metadata(&mut self, id: TokenId, kind: TokenKind, token: &str) {
260 let raw = id.as_usize();
261 if self.token_metadata.len() <= raw {
262 self.token_metadata.resize(raw + 1, None);
263 }
264 self.token_metadata[raw] = Some(TokenMetadata {
265 kind,
266 is_digit_only: token.chars().all(|c| c.is_ascii_digit()),
267 is_short_or_digit: token.len() == 1 || token.chars().all(|c| c.is_ascii_digit()),
268 });
269 }
270
271 pub fn intern(&mut self, token: &str) -> KnownToken {
272 if let Some(&id) = self.tokens_to_ids.get(token) {
273 return self.build_known_token(id);
274 }
275
276 let id = self.next_id;
277 self.next_id = TokenId::new(self.next_id.raw() + 1);
278 self.tokens_to_ids.insert(token.to_string(), id);
279 self.insert_metadata(id, TokenKind::Regular, token);
280 self.build_known_token(id)
281 }
282
283 pub fn lookup(&self, token: &str) -> Option<KnownToken> {
284 self.tokens_to_ids
285 .get(token)
286 .copied()
287 .map(|id| self.build_known_token(id))
288 }
289
290 pub fn classify_query_token(&self, token: &str) -> QueryToken {
291 self.lookup(token)
292 .map_or(QueryToken::Unknown, QueryToken::Known)
293 }
294
295 pub fn token_kind(&self, token_id: TokenId) -> TokenKind {
296 self.metadata_for(token_id).kind
297 }
298
299 pub fn is_digit_only_token(&self, token_id: TokenId) -> bool {
300 self.metadata_for(token_id).is_digit_only
301 }
302
303 #[cfg(test)]
304 pub fn get_or_assign(&mut self, token: &str) -> TokenId {
305 self.intern(token).id
306 }
307
308 pub fn get_token_id(&self, token: &str) -> Option<TokenId> {
316 self.lookup(token).map(|token| token.id)
317 }
318
319 #[inline]
321 pub fn get(&self, token: &str) -> Option<TokenId> {
322 self.get_token_id(token)
323 }
324
325 pub const fn legalese_count(&self) -> usize {
327 self.len_legalese
328 }
329
330 #[cfg(test)]
332 pub fn tokens_to_ids(&self) -> impl Iterator<Item = (&String, &TokenId)> {
333 self.tokens_to_ids.iter()
334 }
335
336 #[allow(dead_code)]
339 pub fn tokens_to_ids_len(&self) -> usize {
340 self.tokens_to_ids.len()
341 }
342}
343
344impl Default for TokenDictionary {
345 fn default() -> Self {
346 Self::new(0)
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_dictionary_new() {
        let dict = TokenDictionary::new(10);
        assert_eq!(dict.legalese_count(), 10);
        assert_eq!(dict.tokens_to_ids.len(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_new_with_legalese() {
        let legalese = [("license", 0u16), ("copyright", 1u16), ("permission", 2u16)];
        let mut dict = TokenDictionary::new_with_legalese(&legalese);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);
        assert!(!dict.tokens_to_ids.is_empty());

        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("copyright"), Some(tid(1)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(2)));

        // The first regular token follows directly after the legalese ids.
        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, tid(3));
    }

    #[test]
    fn test_new_with_legalese_sorted() {
        // Ids are sparse and out of order; new ids must start past the maximum.
        let legalese = [("copyright", 5u16), ("license", 0u16), ("permission", 10u16)];
        let mut dict = TokenDictionary::new_with_legalese(&legalese);

        assert_eq!(dict.legalese_count(), 3);
        assert_eq!(dict.tokens_to_ids.len(), 3);

        assert_eq!(dict.get_token_id("copyright"), Some(tid(5)));
        assert_eq!(dict.get_token_id("license"), Some(tid(0)));
        assert_eq!(dict.get_token_id("permission"), Some(tid(10)));

        let test_id = dict.get_or_assign("test");
        assert_eq!(test_id, tid(11));
    }

    #[test]
    fn test_get_or_assign_new_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("world");

        assert_eq!(id1, tid(5));
        assert_eq!(id2, tid(6));
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_or_assign_existing_token() {
        let mut dict = TokenDictionary::new(5);

        let id1 = dict.get_or_assign("hello");
        let id2 = dict.get_or_assign("hello");

        assert_eq!(id1, id2);
        assert_eq!(dict.tokens_to_ids.len(), 1);
    }

    #[test]
    fn test_get_or_assign_with_preexisting_legalese() {
        let mut dict = TokenDictionary::new_with_legalese(&[("license", 0u16)]);

        // An existing legalese token keeps its id and is not duplicated.
        let id = dict.get_or_assign("license");
        assert_eq!(id, tid(0));
        assert_eq!(dict.tokens_to_ids.len(), 1);

        let new_id = dict.get_or_assign("new");
        assert_eq!(new_id, tid(1));
        assert_eq!(dict.tokens_to_ids.len(), 2);
    }

    #[test]
    fn test_get_existing_token() {
        let mut dict = TokenDictionary::new(5);

        dict.get_or_assign("hello");
        assert_eq!(dict.get_token_id("hello"), Some(tid(5)));
    }

    #[test]
    fn test_get_nonexistent_token() {
        let dict = TokenDictionary::new(5);
        assert_eq!(dict.get_token_id("hello"), None);
    }

    #[test]
    fn test_legalese_range() {
        let dict = TokenDictionary::new(10);
        let count = dict.legalese_count() as u16;

        // Ids below the legalese count fall inside the legalese range ...
        for raw in [0u16, 5, 9] {
            assert!(tid(raw) < count, "id {raw} should be in the legalese range");
        }
        // ... and ids at or above it fall outside.
        for raw in [10u16, 100] {
            assert!(tid(raw) >= count, "id {raw} should be outside the legalese range");
        }
    }

    #[test]
    fn test_legalese_range_with_actual_legalese() {
        let legalese = [("license", 0u16), ("copyright", 1u16)];
        let mut dict = TokenDictionary::new_with_legalese(&legalese);

        assert!(dict.get_token_id("license").unwrap() < dict.legalese_count() as u16);
        assert!(dict.get_token_id("copyright").unwrap() < dict.legalese_count() as u16);

        let regular_id = dict.get_or_assign("regular");
        assert!(regular_id >= dict.legalese_count() as u16);
    }

    #[test]
    fn test_token_dictionary_default() {
        let dict = TokenDictionary::default();
        assert_eq!(dict.legalese_count(), 0);
        assert!(dict.tokens_to_ids.is_empty());
    }

    #[test]
    fn test_get_alias() {
        let mut dict = TokenDictionary::new(5);
        dict.get_or_assign("hello");

        assert_eq!(dict.get("hello"), dict.get_token_id("hello"));
    }

    #[test]
    fn test_metadata_flags() {
        let mut dict = TokenDictionary::new_with_legalese(&[("license", 0u16), ("2024", 1u16)]);

        let license = dict.lookup("license").expect("legalese token is registered");
        assert_eq!(license.kind, TokenKind::Legalese);
        assert!(!license.is_digit_only);
        assert!(!license.is_short_or_digit);

        let year = dict.lookup("2024").expect("legalese token is registered");
        assert_eq!(year.kind, TokenKind::Legalese);
        assert!(year.is_digit_only);
        assert!(year.is_short_or_digit);
        assert!(dict.is_digit_only_token(year.id));

        // A single-character regular token is "short" but not digit-only.
        let short = dict.intern("a");
        assert_eq!(short.kind, TokenKind::Regular);
        assert!(!short.is_digit_only);
        assert!(short.is_short_or_digit);
        assert_eq!(dict.token_kind(short.id), TokenKind::Regular);
    }

    #[test]
    fn test_classify_query_token() {
        let dict = TokenDictionary::new_with_legalese(&[("license", 0u16)]);

        match dict.classify_query_token("license") {
            QueryToken::Known(token) => assert_eq!(token.id, tid(0)),
            other => panic!("expected a known token, got {other:?}"),
        }
        assert_eq!(dict.classify_query_token("missing"), QueryToken::Unknown);
    }

    #[test]
    fn test_with_actual_legalese_module() {
        use crate::license_detection::rules::legalese;

        let legalese_words = legalese::get_legalese_words();
        assert!(!legalese_words.is_empty(), "Should have legalese words");

        let mut dict = TokenDictionary::new_with_legalese(&legalese_words);

        assert_eq!(dict.legalese_count(), legalese_words.len());
        assert_eq!(dict.tokens_to_ids.len(), legalese_words.len());

        let license_id = dict.get_token_id("license");
        assert!(license_id.is_some(), "License should be in dictionary");
        assert!(
            license_id.unwrap() < dict.legalese_count() as u16,
            "License should be a legalese token"
        );

        let copyrighted_id = dict.get_token_id("copyrighted");
        assert!(
            copyrighted_id.is_some(),
            "Copyrighted should be in dictionary"
        );
        assert!(
            copyrighted_id.unwrap() < dict.legalese_count() as u16,
            "Copyrighted should be a legalese token"
        );

        let hello_id = dict.get_or_assign("hello");
        assert!(hello_id >= dict.legalese_count() as u16);
    }
}