1use crate::Result;
4use std::collections::HashMap;
5use std::fs::File;
6use std::io::{BufRead, BufReader};
7use std::path::Path;
8
/// Statistics and index mappings for a single word kept in a [`Vocabulary`].
#[derive(Debug, Clone)]
pub struct WordInfo {
    /// Original word ID as it appears in the corpus file.
    pub word_id: u32,
    /// Dense ID in `0..vocab.len()`; lower IDs belong to more frequent words.
    pub remapped_id: u32,
    /// Number of occurrences of this word in the corpus.
    pub count: u64,
    /// Subsampling probability `min(1, sqrt(t/f) + t/f)` (word2vec formula), where
    /// `t` is the vocabulary's `sample` threshold and `f` the word's relative
    /// frequency; `1.0` when subsampling is disabled.
    pub sample_prob: f32,
}

/// Vocabulary of integer word IDs built from a whitespace-separated corpus file.
///
/// Words occurring fewer than `min_count` times are dropped. The survivors receive
/// dense "remapped" IDs ordered by descending count. Both directions of the mapping
/// are stored so either lookup is a single indexed access.
#[derive(Debug)]
pub struct Vocabulary {
    /// Kept words, keyed by original word ID.
    words: HashMap<u32, WordInfo>,
    /// `remapped_to_word_id[r]` is the original ID of the word with remapped ID `r`.
    remapped_to_word_id: Vec<u32>,
    /// Table indexed by original word ID; `None` for IDs not in the vocabulary.
    /// Its length is `max_word_id + 1`, so memory grows with the largest kept ID.
    word_id_to_remapped: Vec<Option<u32>>,
    /// Number of parsed tokens, including words later dropped by `min_count`.
    total_words: u64,
    /// Largest original ID among kept words (0 when the vocabulary is empty).
    max_word_id: u32,
    /// Minimum occurrence count required to keep a word.
    min_count: u64,
    /// Subsampling threshold; a non-positive value disables subsampling.
    sample: f64,
}

impl Vocabulary {
    /// Creates an empty vocabulary with the given filtering and subsampling settings.
    ///
    /// `min_count` is the minimum number of occurrences a word needs to be kept.
    /// `sample` is the subsampling threshold; values that are not `> 0.0` disable it.
    /// Call [`Vocabulary::build_from_file`] to populate the vocabulary.
    pub fn new(min_count: u64, sample: f64) -> Self {
        Self {
            words: HashMap::new(),
            remapped_to_word_id: Vec::new(),
            word_id_to_remapped: Vec::new(),
            total_words: 0,
            max_word_id: 0,
            min_count,
            sample,
        }
    }

    /// Builds the vocabulary from a file of whitespace-separated `u32` word IDs.
    ///
    /// Tokens that do not parse as `u32` are silently skipped. Any previously
    /// built state is replaced, so calling this again rebuilds from scratch rather
    /// than appending. State is only modified after the whole file has been read.
    ///
    /// Remapped IDs are assigned by descending count. Ties are broken by ascending
    /// original word ID, so the assignment is reproducible across runs.
    ///
    /// A short summary is printed to stderr.
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened, or if a line cannot be read
    /// (e.g. invalid UTF-8). The I/O error is propagated with `?`.
    pub fn build_from_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);

        let mut counts: HashMap<u32, u64> = HashMap::new();
        let mut total = 0u64;
        for line in reader.lines() {
            let line = line?;
            for word_id in line
                .split_whitespace()
                .filter_map(|token| token.parse::<u32>().ok())
            {
                *counts.entry(word_id).or_default() += 1;
                total += 1;
            }
        }

        let unique_words = counts.len();

        let mut kept: Vec<(u32, u64)> = counts
            .into_iter()
            .filter(|&(_, count)| count >= self.min_count)
            .collect();
        // HashMap iteration order is randomized, so break count ties by word ID to
        // keep remapped IDs deterministic.
        kept.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        let mut words = HashMap::with_capacity(kept.len());
        let mut remapped_to_word_id = Vec::with_capacity(kept.len());
        for (remapped_id, &(word_id, count)) in kept.iter().enumerate() {
            // At most 2^32 distinct u32 keys exist, so the index always fits in u32.
            let remapped_id = remapped_id as u32;
            words.insert(
                word_id,
                WordInfo {
                    word_id,
                    remapped_id,
                    count,
                    sample_prob: Self::subsample_prob(self.sample, count, total),
                },
            );
            remapped_to_word_id.push(word_id);
        }

        let max_word_id = remapped_to_word_id.iter().copied().max().unwrap_or(0);
        // Widen before adding 1 so a max ID of u32::MAX cannot overflow.
        let mut word_id_to_remapped = vec![None; max_word_id as usize + 1];
        for (remapped_id, &word_id) in remapped_to_word_id.iter().enumerate() {
            word_id_to_remapped[word_id as usize] = Some(remapped_id as u32);
        }

        // Commit the new state only now that the file has been read successfully.
        self.total_words = total;
        self.words = words;
        self.remapped_to_word_id = remapped_to_word_id;
        self.word_id_to_remapped = word_id_to_remapped;
        self.max_word_id = max_word_id;

        eprintln!("Vocabulary built:");
        eprintln!(" Total words: {}", self.total_words);
        eprintln!(" Unique words (before filtering): {}", unique_words);
        eprintln!(
            " Vocab size (after min_count={}): {}",
            self.min_count,
            self.words.len()
        );
        // `len() - 1` would underflow for an empty vocabulary.
        match self.words.len().checked_sub(1) {
            Some(last) => eprintln!(" Remapped IDs: 0-{} (dense indexing)", last),
            None => eprintln!(" Remapped IDs: none (empty vocabulary)"),
        }
        eprintln!(
            " Fast lookup table: {} entries",
            self.word_id_to_remapped.len()
        );

        Ok(())
    }

    /// Computes the word2vec subsampling value `min(1, sqrt(t/f) + t/f)`.
    ///
    /// Here `t` is `sample` and `f = count / total`. Returns `1.0` when `sample` is
    /// not positive. Callers must pass `total > 0`.
    fn subsample_prob(sample: f64, count: u64, total: u64) -> f32 {
        if sample > 0.0 {
            let freq = count as f64 / total as f64;
            let ratio = sample / freq;
            (ratio.sqrt() + ratio).min(1.0) as f32
        } else {
            1.0
        }
    }

    /// Returns `true` if `word_id` survived `min_count` filtering.
    pub fn contains(&self, word_id: u32) -> bool {
        self.words.contains_key(&word_id)
    }

    /// Returns the statistics for `word_id`, or `None` if it is not in the vocabulary.
    pub fn get(&self, word_id: u32) -> Option<&WordInfo> {
        self.words.get(&word_id)
    }

    /// Number of words kept in the vocabulary.
    pub fn len(&self) -> usize {
        self.words.len()
    }

    /// Returns `true` if no word is kept in the vocabulary.
    pub fn is_empty(&self) -> bool {
        self.words.is_empty()
    }

    /// Total number of parsed tokens in the corpus, including filtered-out words.
    pub fn total_words(&self) -> u64 {
        self.total_words
    }

    /// Largest original word ID among kept words (0 for an empty vocabulary).
    pub fn max_word_id(&self) -> u32 {
        self.max_word_id
    }

    /// Iterates over all kept words in arbitrary (hash map) order.
    pub fn iter(&self) -> impl Iterator<Item = &WordInfo> {
        self.words.values()
    }

    /// Maps a dense remapped ID back to its original word ID.
    pub fn get_word_id(&self, remapped_id: u32) -> Option<u32> {
        self.remapped_to_word_id.get(remapped_id as usize).copied()
    }

    /// Maps an original word ID to its dense remapped ID using the lookup table.
    ///
    /// Returns `None` for words that are not in the vocabulary, including IDs
    /// beyond the table's end.
    #[inline]
    pub fn get_remapped_id(&self, word_id: u32) -> Option<u32> {
        self.word_id_to_remapped
            .get(word_id as usize)
            .copied()
            .flatten()
    }
}