use crate::Result;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
/// Per-word statistics stored for every word kept in a [`Vocabulary`].
///
/// Instances are created by [`Vocabulary::build_from_file`]; all fields are
/// plain values, so the struct is cheap to clone.
#[derive(Clone, Debug)]
pub struct WordInfo {
    /// Word ID exactly as it was parsed (as a `u32`) from the corpus file.
    pub word_id: u32,
    /// Position of the word in the count-sorted (most frequent first) list of
    /// retained words; assigned consecutively starting at 0.
    pub remapped_id: u32,
    /// Number of times the word occurred in the corpus file.
    pub count: u64,
    /// Value in `(0.0, 1.0]` derived from the word's relative frequency and the
    /// vocabulary's `sample` threshold; `1.0` when sampling is disabled.
    /// Presumably the probability of keeping an occurrence during
    /// subsampling (verify against the consumer of this field).
    pub sample_prob: f32,
}
/// Vocabulary of integer word IDs together with a dense re-indexing of the
/// words that survive the `min_count` filter.
///
/// The struct starts empty and is populated by `build_from_file`.
#[derive(Debug)]
pub struct Vocabulary {
    /// Retained words, keyed by their original word ID.
    words: HashMap<u32, WordInfo>,
    /// Reverse table: index is the remapped ID, value is the original word ID.
    remapped_to_word_id: Vec<u32>,
    /// Forward table: index is the original word ID, value is the remapped ID,
    /// or `None` for IDs that are not part of the vocabulary.
    word_id_to_remapped: Vec<Option<u32>>,
    /// Number of tokens counted in the corpus (before `min_count` filtering).
    total_words: u64,
    /// Largest original word ID among the retained words (0 when none).
    max_word_id: u32,
    /// Minimum occurrence count a word needs in order to be retained.
    min_count: u64,
    /// Threshold used to derive `WordInfo::sample_prob`; values `<= 0.0`
    /// disable sampling.
    sample: f64,
}
impl Vocabulary {
    /// Creates an empty vocabulary.
    ///
    /// * `min_count` - words occurring fewer than this many times are dropped
    ///   by [`Vocabulary::build_from_file`].
    /// * `sample` - threshold used to compute each word's `sample_prob`;
    ///   a value `<= 0.0` disables sampling (every word gets `1.0`).
    pub fn new(min_count: u64, sample: f64) -> Self {
        Self {
            words: HashMap::new(),
            remapped_to_word_id: Vec::new(),
            word_id_to_remapped: Vec::new(),
            total_words: 0,
            max_word_id: 0,
            min_count,
            sample,
        }
    }

    /// Builds the vocabulary from a text file of whitespace-separated word IDs.
    ///
    /// Every token that parses as a `u32` is counted; tokens that do not parse
    /// are silently skipped. Words with at least `min_count` occurrences are
    /// kept and assigned dense remapped IDs in order of descending count; ties
    /// are broken by ascending word ID so the assignment is reproducible
    /// between runs. For each kept word with relative frequency `f`,
    /// `sample_prob` is `min(1, sqrt(sample / f) + sample / f)`, or `1.0` when
    /// `sample <= 0.0`.
    ///
    /// Any vocabulary built by a previous call is replaced. A short summary is
    /// written to standard error on success.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or a line cannot be read
    /// (which includes lines that are not valid UTF-8). In that case the
    /// existing vocabulary is left untouched.
    ///
    /// # Memory
    ///
    /// The forward lookup table has `max_word_id + 1` entries, so a single very
    /// large word ID in the input makes this allocation correspondingly large.
    pub fn build_from_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        let reader = BufReader::new(File::open(path)?);

        // Pass 1: count occurrences. Nothing in `self` is modified until the
        // whole file has been read successfully.
        let mut counts: HashMap<u32, u64> = HashMap::new();
        let mut total = 0u64;
        for line in reader.lines() {
            let line = line?;
            // Best effort: tokens that are not valid `u32` IDs are ignored.
            for word_id in line
                .split_whitespace()
                .filter_map(|token| token.parse::<u32>().ok())
            {
                *counts.entry(word_id).or_insert(0) += 1;
                total += 1;
            }
        }
        let unique_words = counts.len();

        let min_count = self.min_count;
        let mut word_list: Vec<(u32, u64)> = counts
            .into_iter()
            .filter(|&(_, count)| count >= min_count)
            .collect();
        // Most frequent first. `HashMap` iteration order is arbitrary, so ties
        // are broken by word ID to keep remapped IDs deterministic.
        word_list.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

        // Drop anything left over from an earlier build so the tables cannot
        // end up out of sync with `words`.
        self.words.clear();
        self.remapped_to_word_id.clear();
        self.total_words = total;
        self.max_word_id = word_list.iter().map(|&(id, _)| id).max().unwrap_or(0);
        // Widen before adding 1: `max_word_id + 1` in `u32` overflows when a
        // word ID equals `u32::MAX`.
        self.word_id_to_remapped = vec![None; self.max_word_id as usize + 1];

        self.words.reserve(word_list.len());
        self.remapped_to_word_id.reserve(word_list.len());
        let total_f = total as f64;
        for (index, (word_id, count)) in word_list.into_iter().enumerate() {
            // There are at most 2^32 distinct `u32` IDs, so the largest index
            // (2^32 - 1) still fits in a `u32`.
            let remapped_id = index as u32;
            let sample_prob = if self.sample > 0.0 {
                // `count >= 1` here, hence `total >= 1` and the frequency is
                // strictly positive.
                let ratio = self.sample / (count as f64 / total_f);
                (ratio.sqrt() + ratio).min(1.0) as f32
            } else {
                1.0
            };
            self.words.insert(
                word_id,
                WordInfo {
                    word_id,
                    remapped_id,
                    count,
                    sample_prob,
                },
            );
            self.remapped_to_word_id.push(word_id);
            self.word_id_to_remapped[word_id as usize] = Some(remapped_id);
        }

        eprintln!("Vocabulary built:");
        eprintln!(" Total words: {}", self.total_words);
        eprintln!(" Unique words (before filtering): {}", unique_words);
        eprintln!(
            " Vocab size (after min_count={}): {}",
            self.min_count,
            self.words.len()
        );
        // `len() - 1` would underflow for an empty vocabulary.
        match self.words.len().checked_sub(1) {
            Some(last) => eprintln!(" Remapped IDs: 0-{} (dense indexing)", last),
            None => eprintln!(" Remapped IDs: none (vocabulary is empty)"),
        }
        eprintln!(
            " Fast lookup table: {} entries",
            self.word_id_to_remapped.len()
        );
        Ok(())
    }

    /// Returns `true` if `word_id` was retained in the vocabulary.
    pub fn contains(&self, word_id: u32) -> bool {
        self.words.contains_key(&word_id)
    }

    /// Returns the statistics for `word_id`, or `None` if it is not in the
    /// vocabulary.
    pub fn get(&self, word_id: u32) -> Option<&WordInfo> {
        self.words.get(&word_id)
    }

    /// Number of words retained in the vocabulary.
    pub fn len(&self) -> usize {
        self.words.len()
    }

    /// Returns `true` if the vocabulary holds no words.
    pub fn is_empty(&self) -> bool {
        self.words.is_empty()
    }

    /// Number of tokens counted in the corpus, including occurrences of words
    /// that were later removed by the `min_count` filter.
    pub fn total_words(&self) -> u64 {
        self.total_words
    }

    /// Largest original word ID among the retained words (0 if there are none).
    pub fn max_word_id(&self) -> u32 {
        self.max_word_id
    }

    /// Iterates over all retained words in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = &WordInfo> {
        self.words.values()
    }

    /// Maps a dense remapped ID back to the original word ID, or `None` if
    /// `remapped_id` is out of range.
    pub fn get_word_id(&self, remapped_id: u32) -> Option<u32> {
        self.remapped_to_word_id.get(remapped_id as usize).copied()
    }

    /// Maps an original word ID to its dense remapped ID via the lookup table,
    /// or `None` if the word is not in the vocabulary.
    #[inline]
    pub fn get_remapped_id(&self, word_id: u32) -> Option<u32> {
        self.word_id_to_remapped
            .get(word_id as usize)
            .copied()
            .flatten()
    }
}