mecrab_word2vec/
vocab.rs

1//! Vocabulary management for Word2Vec training
2
3use crate::Result;
4use std::collections::HashMap;
5use std::fs::File;
6use std::io::{BufRead, BufReader};
7use std::path::Path;
8
/// Per-word statistics stored in the vocabulary.
#[derive(Clone, Debug)]
pub struct WordInfo {
    /// Identifier of the word as it appears in the corpus.
    pub word_id: u32,
    /// Dense 0-based index used for training.
    pub remapped_id: u32,
    /// Number of occurrences counted for this word.
    pub count: u64,
    /// Subsampling probability, i.e. the probability of keeping the word.
    pub sample_prob: f32,
}
18
19/// Vocabulary with frequency counts and subsampling
20#[derive(Debug)]
21pub struct Vocabulary {
22    /// word_id → WordInfo
23    words: HashMap<u32, WordInfo>,
24    /// remapped_id → word_id (for saving output)
25    remapped_to_word_id: Vec<u32>,
26    /// Fast lookup: word_id → remapped_id (O(1) access for training)
27    word_id_to_remapped: Vec<Option<u32>>,
28    /// Total word count in corpus
29    total_words: u64,
30    /// Maximum word_id in vocabulary (for MCV1 format compatibility)
31    max_word_id: u32,
32    /// Minimum frequency threshold
33    min_count: u64,
34    /// Subsampling threshold (e.g., 1e-4)
35    sample: f64,
36}
37
38impl Vocabulary {
39    /// Create new vocabulary builder
40    pub fn new(min_count: u64, sample: f64) -> Self {
41        Self {
42            words: HashMap::new(),
43            remapped_to_word_id: Vec::new(),
44            word_id_to_remapped: Vec::new(),
45            total_words: 0,
46            max_word_id: 0,
47            min_count,
48            sample,
49        }
50    }
51
52    /// Build vocabulary from corpus file (space-separated word_ids per line)
53    pub fn build_from_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
54        let file = File::open(path)?;
55        let reader = BufReader::new(file);
56
57        // First pass: count frequencies
58        let mut counts: HashMap<u32, u64> = HashMap::new();
59        let mut total = 0u64;
60
61        for line in reader.lines() {
62            let line = line?;
63            for token in line.split_whitespace() {
64                if let Ok(word_id) = token.parse::<u32>() {
65                    *counts.entry(word_id).or_insert(0) += 1;
66                    total += 1;
67                }
68            }
69        }
70
71        self.total_words = total;
72
73        let unique_words = counts.len();
74
75        // Filter by min_count and compute subsampling probabilities
76        // First collect into a Vec for sorting
77        let mut word_list: Vec<(u32, u64)> = counts
78            .into_iter()
79            .filter(|(_, count)| *count >= self.min_count)
80            .collect();
81
82        // Sort by frequency (descending) for better cache locality
83        word_list.sort_by(|a, b| b.1.cmp(&a.1));
84
85        // Assign remapped_ids 0, 1, 2, ... (dense indexing)
86        for (remapped_id, (word_id, count)) in word_list.into_iter().enumerate() {
87            let freq = count as f64 / total as f64;
88            let sample_prob = if self.sample > 0.0 {
89                // Mikolov's subsampling formula
90                ((self.sample / freq).sqrt() + (self.sample / freq)).min(1.0) as f32
91            } else {
92                1.0
93            };
94
95            self.words.insert(
96                word_id,
97                WordInfo {
98                    word_id,
99                    remapped_id: remapped_id as u32,
100                    count,
101                    sample_prob,
102                },
103            );
104
105            // Build reverse mapping
106            self.remapped_to_word_id.push(word_id);
107
108            // Track maximum word_id (for MCV1 format)
109            if word_id > self.max_word_id {
110                self.max_word_id = word_id;
111            }
112        }
113
114        // Build fast lookup table: word_id → remapped_id (O(1) access)
115        self.word_id_to_remapped = vec![None; (self.max_word_id + 1) as usize];
116        for info in self.words.values() {
117            self.word_id_to_remapped[info.word_id as usize] = Some(info.remapped_id);
118        }
119
120        eprintln!("Vocabulary built:");
121        eprintln!("  Total words: {}", self.total_words);
122        eprintln!("  Unique words (before filtering): {}", unique_words);
123        eprintln!(
124            "  Vocab size (after min_count={}): {}",
125            self.min_count,
126            self.words.len()
127        );
128        eprintln!(
129            "  Remapped IDs: 0-{} (dense indexing)",
130            self.words.len() - 1
131        );
132        eprintln!(
133            "  Fast lookup table: {} entries",
134            self.word_id_to_remapped.len()
135        );
136
137        Ok(())
138    }
139
140    /// Check if word_id is in vocabulary
141    pub fn contains(&self, word_id: u32) -> bool {
142        self.words.contains_key(&word_id)
143    }
144
145    /// Get word info
146    pub fn get(&self, word_id: u32) -> Option<&WordInfo> {
147        self.words.get(&word_id)
148    }
149
150    /// Get vocabulary size
151    pub fn len(&self) -> usize {
152        self.words.len()
153    }
154
155    /// Check if vocabulary is empty
156    pub fn is_empty(&self) -> bool {
157        self.words.is_empty()
158    }
159
160    /// Get total word count
161    pub fn total_words(&self) -> u64 {
162        self.total_words
163    }
164
165    /// Get maximum word_id in vocabulary
166    pub fn max_word_id(&self) -> u32 {
167        self.max_word_id
168    }
169
170    /// Iterate over all words
171    pub fn iter(&self) -> impl Iterator<Item = &WordInfo> {
172        self.words.values()
173    }
174
175    /// Get word_id from remapped_id
176    pub fn get_word_id(&self, remapped_id: u32) -> Option<u32> {
177        self.remapped_to_word_id.get(remapped_id as usize).copied()
178    }
179
180    /// Fast lookup: word_id → remapped_id (O(1), for training hot path)
181    #[inline]
182    pub fn get_remapped_id(&self, word_id: u32) -> Option<u32> {
183        self.word_id_to_remapped
184            .get(word_id as usize)
185            .and_then(|&opt| opt)
186    }
187}