Skip to main content

embedding/
config.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use rand::prelude::SliceRandom;
4use crate::text::{load_text_data, TextProcessor};
5
// ---- Integer-valued training defaults -------------------------------------

/// Length of each learned word vector when none is configured.
pub const DEFAULT_EMBEDDING_DIM: usize = 300;
/// Number of training epochs when none is configured.
pub const DEFAULT_EPOCHS: usize = 10;
/// Mini-batch size when none is configured.
pub const DEFAULT_BATCH_SIZE: usize = 32;
/// Context window size when none is configured.
pub const DEFAULT_CONTEXT_WINDOW: usize = 5;
/// Number of negative samples when none is configured.
pub const DEFAULT_NEGATIVE_SAMPLES: usize = 5;

// ---- Floating-point training defaults -------------------------------------

/// Base learning rate when none is configured (0.025).
pub const DEFAULT_LEARNING_RATE: f64 = 2.5e-2;
/// Default validation ratio; `0.0` means no validation.
///
/// Note: `TrainingConfig::new` leaves `validation_ratio` as `None` and does not read
/// this constant.
pub const DEFAULT_VALIDATION_RATIO: f64 = 0.0;
20
21/// Training hyperparameters for embedding models.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TrainingConfig {
24    pub embedding_dim: usize,
25    pub learning_rate: f64,
26    pub epochs: usize,
27    pub batch_size: usize,
28    pub context_window: usize,
29    pub negative_samples: usize,
30    pub model_type: ModelType,
31    pub lr_schedule: LearningRateSchedule,
32    pub early_stopping: Option<EarlyStoppingConfig>,
33    pub l2_regularization: Option<f64>,
34    pub gradient_clip: Option<f32>,
35    pub validation_ratio: Option<f64>,
36    /// Sub-sampling threshold for frequent words (Mikolov et al.).
37    /// `None` disables sub-sampling. Typical value: `1e-5`.
38    pub subsample_threshold: Option<f64>,
39    /// If `true`, sample negative examples using the unigram distribution
40    /// raised to the 3/4 power instead of uniform random.
41    pub use_unigram_negative_sampling: bool,
42    /// Number of epochs for linear LR warm-up. `None` disables warm-up.
43    pub warmup_epochs: Option<usize>,
44    /// Save a checkpoint every N epochs. `None` disables checkpointing.
45    pub checkpoint_interval: Option<usize>,
46    /// Directory to write checkpoint files. Defaults to current directory.
47    pub checkpoint_path: Option<String>,
48    /// If `true`, process sentences in parallel during training.
49    pub use_parallel: bool,
50}
51
52impl TrainingConfig {
53    /// Creates a new [`TrainingConfig`] with sensible defaults.
54    ///
55    /// # Example
56    /// ```rust
57    /// use embedding::{TrainingConfig, ModelType};
58    /// let config = TrainingConfig::new(ModelType::SkipGram);
59    /// ```
60    pub fn new(model_type: ModelType) -> Self {
61        Self {
62            embedding_dim: DEFAULT_EMBEDDING_DIM,
63            learning_rate: DEFAULT_LEARNING_RATE,
64            epochs: DEFAULT_EPOCHS,
65            batch_size: DEFAULT_BATCH_SIZE,
66            context_window: DEFAULT_CONTEXT_WINDOW,
67            negative_samples: DEFAULT_NEGATIVE_SAMPLES,
68            model_type,
69            lr_schedule: LearningRateSchedule::Constant,
70            early_stopping: None,
71            l2_regularization: None,
72            gradient_clip: None,
73            validation_ratio: None,
74            subsample_threshold: None,
75            use_unigram_negative_sampling: true,
76            warmup_epochs: None,
77            checkpoint_interval: None,
78            checkpoint_path: None,
79            use_parallel: false,
80        }
81    }
82
83    /// Fluent setter for embedding dimension.
84    pub fn with_dim(mut self, dim: usize) -> Self {
85        self.embedding_dim = dim;
86        self
87    }
88
89    /// Fluent setter for learning rate.
90    pub fn with_learning_rate(mut self, lr: f64) -> Self {
91        self.learning_rate = lr;
92        self
93    }
94
95    /// Fluent setter for number of epochs.
96    pub fn with_epochs(mut self, epochs: usize) -> Self {
97        self.epochs = epochs;
98        self
99    }
100
101    /// Fluent setter for batch size.
102    pub fn with_batch_size(mut self, bs: usize) -> Self {
103        self.batch_size = bs;
104        self
105    }
106
107    /// Fluent setter for context window.
108    pub fn with_window(mut self, window: usize) -> Self {
109        self.context_window = window;
110        self
111    }
112
113    /// Fluent setter for negative samples.
114    pub fn with_negative_samples(mut self, ns: usize) -> Self {
115        self.negative_samples = ns;
116        self
117    }
118
119    /// Fluent setter for learning rate schedule.
120    pub fn with_lr_schedule(mut self, schedule: LearningRateSchedule) -> Self {
121        self.lr_schedule = schedule;
122        self
123    }
124
125    /// Fluent setter for early stopping.
126    pub fn with_early_stopping(mut self, patience: usize, min_delta: f64) -> Self {
127        self.early_stopping = Some(EarlyStoppingConfig { patience, min_delta });
128        self
129    }
130
131    /// Fluent setter for L2 regularization.
132    pub fn with_l2_regularization(mut self, lambda: f64) -> Self {
133        self.l2_regularization = Some(lambda);
134        self
135    }
136
137    /// Fluent setter for gradient clip.
138    pub fn with_gradient_clip(mut self, max_norm: f32) -> Self {
139        self.gradient_clip = Some(max_norm);
140        self
141    }
142
143    /// Fluent setter for validation ratio.
144    pub fn with_validation_ratio(mut self, ratio: f64) -> Self {
145        self.validation_ratio = Some(ratio);
146        self
147    }
148
149    /// Fluent setter for sub-sampling threshold (`None` disables sub-sampling).
150    pub fn with_subsample_threshold(mut self, threshold: Option<f64>) -> Self {
151        self.subsample_threshold = threshold;
152        self
153    }
154
155    /// Fluent setter for unigram negative sampling.
156    pub fn with_unigram_negative_sampling(mut self, enabled: bool) -> Self {
157        self.use_unigram_negative_sampling = enabled;
158        self
159    }
160
161    /// Fluent setter for LR warm-up epochs (`None` disables warm-up).
162    pub fn with_warmup_epochs(mut self, epochs: Option<usize>) -> Self {
163        self.warmup_epochs = epochs;
164        self
165    }
166
167    /// Fluent setter for checkpoint interval (`None` disables checkpointing).
168    pub fn with_checkpoint_interval(mut self, interval: Option<usize>) -> Self {
169        self.checkpoint_interval = interval;
170        self
171    }
172
173    /// Fluent setter for checkpoint directory.
174    pub fn with_checkpoint_path(mut self, path: Option<String>) -> Self {
175        self.checkpoint_path = path;
176        self
177    }
178
179    /// Fluent setter for parallel training.
180    pub fn with_parallel(mut self, enabled: bool) -> Self {
181        self.use_parallel = enabled;
182        self
183    }
184}
185
186/// Learning rate schedule variants.
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub enum LearningRateSchedule {
189    Constant,
190    Exponential { decay_rate: f64 },
191    Step { step_size: usize, gamma: f64 },
192    Cosine { t_max: usize },
193}
194
195/// Configuration for early stopping during training.
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct EarlyStoppingConfig {
198    pub patience: usize,
199    pub min_delta: f64,
200}
201
202/// Supported embedding model architectures.
203#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
204pub enum ModelType {
205    SkipGram,
206    Cbow,
207}
208
209/// Container for tokenized sentences and the vocabulary mapping.
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct TrainingData {
212    pub sentences: Vec<Vec<String>>,
213    pub vocab: HashMap<String, usize>,
214    pub reverse_vocab: Vec<String>,
215    /// Per-vocab-ID word frequencies used for negative-sampling and sub-sampling.
216    pub word_freq: Vec<usize>,
217}
218
219impl TrainingData {
220    /// Creates [`TrainingData`] from raw text by tokenizing and building the vocabulary.
221    ///
222    /// # Example
223    /// ```rust
224    /// use embedding::TrainingData;
225    /// let data = TrainingData::from_text("the cat sat on the mat");
226    /// ```
227    pub fn from_text(text: &str) -> Self {
228        let sentences = load_text_data(text);
229        let (vocab, reverse_vocab, word_freq) = crate::text::build_vocab_with_freq(&sentences);
230        Self { sentences, vocab, reverse_vocab, word_freq }
231    }
232
233    /// Creates [`TrainingData`] from a file by reading, tokenizing, and building the vocabulary.
234    ///
235    /// # Errors
236    /// Returns an error if the file cannot be read.
237    pub fn from_file(path: &str) -> Result<Self, String> {
238        let content = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
239        let sentences = load_text_data(&content);
240        let (vocab, reverse_vocab, word_freq) = crate::text::build_vocab_with_freq(&sentences);
241        Ok(Self { sentences, vocab, reverse_vocab, word_freq })
242    }
243
244    /// Total number of word occurrences across all sentences.
245    pub fn total_word_count(&self) -> usize {
246        self.word_freq.iter().sum()
247    }
248}
249
/// Splits sentence collections into batches and reads sentences from files.
#[derive(Clone, Debug)]
pub struct DataLoader {
    /// Target number of sentences per batch in `load_batches` (the last batch may be
    /// smaller).
    pub batch_size: usize,
    /// Whether `load_batches` shuffles sentences.
    pub shuffle: bool,
    /// Path stored by `set_file_path`.
    ///
    /// Note: `load_lazily` and `stream_sentences` receive their path as an argument and
    /// do not read this field.
    pub file_path: Option<String>,
}
257
258impl DataLoader {
259    /// Creates a new data loader with the given batch size and options.
260    pub fn new(batch_size: usize, shuffle: bool) -> Self {
261        Self {
262            batch_size,
263            shuffle,
264            file_path: None,
265        }
266    }
267    
268    /// Sets the file path for lazy loading.
269    pub fn set_file_path(&mut self, path: String) {
270        self.file_path = Some(path);
271    }
272    
273    /// Groups sentences into fixed-size batches.
274    pub fn load_batches(&self, sentences: &[Vec<String>]) -> Vec<Vec<Vec<String>>> {
275        let mut batches = Vec::new();
276        let mut current_batch = Vec::new();
277        
278        for sentence in sentences {
279            current_batch.push(sentence.clone());
280            
281            if current_batch.len() >= self.batch_size {
282                if self.shuffle {
283                    let mut rng = rand::thread_rng();
284                    current_batch.shuffle(&mut rng);
285                }
286                batches.push(current_batch.clone());
287                current_batch.clear();
288            }
289        }
290        
291        // Add remaining sentences as the last batch
292        if !current_batch.is_empty() {
293            if self.shuffle {
294                let mut rng = rand::thread_rng();
295                current_batch.shuffle(&mut rng);
296            }
297            batches.push(current_batch);
298        }
299        
300        batches
301    }
302    
303    /// Loads sentences from a file.
304    pub fn load_lazily(&self, file_path: &str) -> Result<Vec<Vec<String>>, String> {
305        use std::fs::File;
306        use std::io::Read;
307
308        let mut file = File::open(file_path).map_err(|e| e.to_string())?;
309        let mut content = String::new();
310        file.read_to_string(&mut content).map_err(|e| e.to_string())?;
311
312        Ok(load_text_data(&content))
313    }
314
315    /// Returns a lazy iterator over sentences from a file.
316    pub fn stream_sentences(&self, file_path: &str) -> Result<Box<dyn Iterator<Item = Vec<String>>>, String> {
317        use std::fs::File;
318        use std::io::{BufRead, BufReader};
319
320        let file = File::open(file_path).map_err(|e| e.to_string())?;
321        let reader = BufReader::new(file);
322        let processor = TextProcessor::default();
323
324        let iter = reader.lines().filter_map(move |line| {
325            let line = line.ok()?;
326            let sentences = processor.process_text(&line);
327            if sentences.is_empty() {
328                None
329            } else {
330                Some(sentences.into_iter().flatten().collect::<Vec<String>>())
331            }
332        });
333
334        Ok(Box::new(iter))
335    }
336}