1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use rand::prelude::SliceRandom;
4use crate::text::{load_text_data, TextProcessor};
5
/// Default for `TrainingConfig::embedding_dim`, applied by `TrainingConfig::new`.
pub const DEFAULT_EMBEDDING_DIM: usize = 300;
/// Default for `TrainingConfig::learning_rate`, applied by `TrainingConfig::new`.
pub const DEFAULT_LEARNING_RATE: f64 = 0.025;
/// Default for `TrainingConfig::epochs`, applied by `TrainingConfig::new`.
pub const DEFAULT_EPOCHS: usize = 10;
/// Default for `TrainingConfig::batch_size`, applied by `TrainingConfig::new`.
pub const DEFAULT_BATCH_SIZE: usize = 32;
/// Default for `TrainingConfig::context_window`, applied by `TrainingConfig::new`.
pub const DEFAULT_CONTEXT_WINDOW: usize = 5;
/// Default for `TrainingConfig::negative_samples`, applied by `TrainingConfig::new`.
pub const DEFAULT_NEGATIVE_SAMPLES: usize = 5;
/// Default validation ratio.
///
/// NOTE(review): `TrainingConfig::new` leaves `validation_ratio` as `None` and does not
/// read this constant; presumably it is the fallback a consumer applies when no ratio
/// is configured — confirm against the caller.
pub const DEFAULT_VALIDATION_RATIO: f64 = 0.0;
20
/// Hyper-parameters and runtime options for a training run.
///
/// Build one with `TrainingConfig::new` and adjust it through the chained `with_*`
/// methods. This type only stores the values; how each one is interpreted is decided
/// by the code that consumes the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Embedding dimension; `new` sets it to `DEFAULT_EMBEDDING_DIM`.
    pub embedding_dim: usize,
    /// Learning rate; `new` sets it to `DEFAULT_LEARNING_RATE`.
    pub learning_rate: f64,
    /// Number of epochs; `new` sets it to `DEFAULT_EPOCHS`.
    pub epochs: usize,
    /// Batch size; `new` sets it to `DEFAULT_BATCH_SIZE`.
    pub batch_size: usize,
    /// Context window size; `new` sets it to `DEFAULT_CONTEXT_WINDOW`.
    /// Presumably counted in words on each side of the target — TODO confirm.
    pub context_window: usize,
    /// Negative sample count; `new` sets it to `DEFAULT_NEGATIVE_SAMPLES`.
    pub negative_samples: usize,
    /// Model variant, supplied by the caller of `new`.
    pub model_type: ModelType,
    /// Learning-rate schedule; `LearningRateSchedule::Constant` by default.
    pub lr_schedule: LearningRateSchedule,
    /// Early-stopping settings; `None` by default.
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// L2 coefficient (`lambda` in `with_l2_regularization`); `None` by default.
    pub l2_regularization: Option<f64>,
    /// Clipping bound (`max_norm` in `with_gradient_clip`); `None` by default.
    /// Stored as `f32`, unlike the other floating-point fields, which are `f64`.
    pub gradient_clip: Option<f32>,
    /// Validation ratio; `None` by default. No range check is applied when it is set.
    /// Presumably the fraction of data held out for validation — TODO confirm.
    pub validation_ratio: Option<f64>,
    /// Subsampling threshold; `None` by default.
    /// Presumably the frequent-word subsampling threshold — TODO confirm.
    pub subsample_threshold: Option<f64>,
    /// Whether unigram negative sampling is enabled; `true` by default.
    pub use_unigram_negative_sampling: bool,
    /// Number of warm-up epochs; `None` by default.
    pub warmup_epochs: Option<usize>,
    /// Checkpoint interval; `None` by default. Presumably measured in epochs — TODO confirm.
    pub checkpoint_interval: Option<usize>,
    /// Checkpoint location, stored as a plain `String`; `None` by default.
    pub checkpoint_path: Option<String>,
    /// Whether parallel training is requested; `false` by default.
    pub use_parallel: bool,
}
51
52impl TrainingConfig {
53 pub fn new(model_type: ModelType) -> Self {
61 Self {
62 embedding_dim: DEFAULT_EMBEDDING_DIM,
63 learning_rate: DEFAULT_LEARNING_RATE,
64 epochs: DEFAULT_EPOCHS,
65 batch_size: DEFAULT_BATCH_SIZE,
66 context_window: DEFAULT_CONTEXT_WINDOW,
67 negative_samples: DEFAULT_NEGATIVE_SAMPLES,
68 model_type,
69 lr_schedule: LearningRateSchedule::Constant,
70 early_stopping: None,
71 l2_regularization: None,
72 gradient_clip: None,
73 validation_ratio: None,
74 subsample_threshold: None,
75 use_unigram_negative_sampling: true,
76 warmup_epochs: None,
77 checkpoint_interval: None,
78 checkpoint_path: None,
79 use_parallel: false,
80 }
81 }
82
83 pub fn with_dim(mut self, dim: usize) -> Self {
85 self.embedding_dim = dim;
86 self
87 }
88
89 pub fn with_learning_rate(mut self, lr: f64) -> Self {
91 self.learning_rate = lr;
92 self
93 }
94
95 pub fn with_epochs(mut self, epochs: usize) -> Self {
97 self.epochs = epochs;
98 self
99 }
100
101 pub fn with_batch_size(mut self, bs: usize) -> Self {
103 self.batch_size = bs;
104 self
105 }
106
107 pub fn with_window(mut self, window: usize) -> Self {
109 self.context_window = window;
110 self
111 }
112
113 pub fn with_negative_samples(mut self, ns: usize) -> Self {
115 self.negative_samples = ns;
116 self
117 }
118
119 pub fn with_lr_schedule(mut self, schedule: LearningRateSchedule) -> Self {
121 self.lr_schedule = schedule;
122 self
123 }
124
125 pub fn with_early_stopping(mut self, patience: usize, min_delta: f64) -> Self {
127 self.early_stopping = Some(EarlyStoppingConfig { patience, min_delta });
128 self
129 }
130
131 pub fn with_l2_regularization(mut self, lambda: f64) -> Self {
133 self.l2_regularization = Some(lambda);
134 self
135 }
136
137 pub fn with_gradient_clip(mut self, max_norm: f32) -> Self {
139 self.gradient_clip = Some(max_norm);
140 self
141 }
142
143 pub fn with_validation_ratio(mut self, ratio: f64) -> Self {
145 self.validation_ratio = Some(ratio);
146 self
147 }
148
149 pub fn with_subsample_threshold(mut self, threshold: Option<f64>) -> Self {
151 self.subsample_threshold = threshold;
152 self
153 }
154
155 pub fn with_unigram_negative_sampling(mut self, enabled: bool) -> Self {
157 self.use_unigram_negative_sampling = enabled;
158 self
159 }
160
161 pub fn with_warmup_epochs(mut self, epochs: Option<usize>) -> Self {
163 self.warmup_epochs = epochs;
164 self
165 }
166
167 pub fn with_checkpoint_interval(mut self, interval: Option<usize>) -> Self {
169 self.checkpoint_interval = interval;
170 self
171 }
172
173 pub fn with_checkpoint_path(mut self, path: Option<String>) -> Self {
175 self.checkpoint_path = path;
176 self
177 }
178
179 pub fn with_parallel(mut self, enabled: bool) -> Self {
181 self.use_parallel = enabled;
182 self
183 }
184}
185
/// Selects how the learning rate is scheduled over a run.
///
/// This enum only carries the parameters; the formulas are applied by whatever code
/// consumes `TrainingConfig::lr_schedule`, which is not part of this definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningRateSchedule {
    /// No parameters; the default chosen by `TrainingConfig::new`.
    Constant,
    /// Parameterised by `decay_rate`.
    /// Presumably a multiplicative decay applied per epoch — TODO confirm.
    Exponential { decay_rate: f64 },
    /// Parameterised by `step_size` and `gamma`.
    /// Presumably the rate is scaled by `gamma` every `step_size` epochs — TODO confirm.
    Step { step_size: usize, gamma: f64 },
    /// Parameterised by `t_max`.
    /// Presumably the length of the cosine cycle — TODO confirm.
    Cosine { t_max: usize },
}
194
/// Early-stopping parameters, created by `TrainingConfig::with_early_stopping`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
    /// Presumably the number of non-improving epochs tolerated before stopping — TODO confirm.
    pub patience: usize,
    /// Presumably the smallest change that counts as an improvement — TODO confirm.
    pub min_delta: f64,
}
201
/// Model variant to train; required by `TrainingConfig::new`.
///
/// NOTE(review): the variant names suggest the word2vec skip-gram and continuous
/// bag-of-words architectures — confirm against the trainer.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModelType {
    /// Skip-gram variant.
    SkipGram,
    /// CBOW variant.
    Cbow,
}
208
/// A loaded corpus together with the vocabulary tables derived from it.
///
/// Both constructors fill `sentences` from `load_text_data` and the remaining three
/// fields from `crate::text::build_vocab_with_freq`, in that order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingData {
    /// Output of `load_text_data`: one `Vec<String>` per sentence.
    pub sentences: Vec<Vec<String>>,
    /// First value returned by `build_vocab_with_freq`.
    /// Presumably maps each word to its index — TODO confirm.
    pub vocab: HashMap<String, usize>,
    /// Second value returned by `build_vocab_with_freq`.
    /// Presumably maps an index back to its word — TODO confirm.
    pub reverse_vocab: Vec<String>,
    /// Third value returned by `build_vocab_with_freq`; summed by `total_word_count`.
    /// Presumably per-word occurrence counts indexed like `reverse_vocab` — TODO confirm.
    pub word_freq: Vec<usize>,
}
218
219impl TrainingData {
220 pub fn from_text(text: &str) -> Self {
228 let sentences = load_text_data(text);
229 let (vocab, reverse_vocab, word_freq) = crate::text::build_vocab_with_freq(&sentences);
230 Self { sentences, vocab, reverse_vocab, word_freq }
231 }
232
233 pub fn from_file(path: &str) -> Result<Self, String> {
238 let content = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
239 let sentences = load_text_data(&content);
240 let (vocab, reverse_vocab, word_freq) = crate::text::build_vocab_with_freq(&sentences);
241 Ok(Self { sentences, vocab, reverse_vocab, word_freq })
242 }
243
244 pub fn total_word_count(&self) -> usize {
246 self.word_freq.iter().sum()
247 }
248}
249
/// Groups sentences into batches and reads sentence data from files.
#[derive(Debug, Clone)]
pub struct DataLoader {
    /// Number of sentences `load_batches` collects before closing a batch.
    pub batch_size: usize,
    /// When `true`, `load_batches` randomises the order of sentences inside each batch.
    pub shuffle: bool,
    /// Optional path; `None` after `new`, set by `set_file_path`.
    /// The loading methods visible in this file take their path as an argument and
    /// do not read this field.
    pub file_path: Option<String>,
}
257
258impl DataLoader {
259 pub fn new(batch_size: usize, shuffle: bool) -> Self {
261 Self {
262 batch_size,
263 shuffle,
264 file_path: None,
265 }
266 }
267
268 pub fn set_file_path(&mut self, path: String) {
270 self.file_path = Some(path);
271 }
272
273 pub fn load_batches(&self, sentences: &[Vec<String>]) -> Vec<Vec<Vec<String>>> {
275 let mut batches = Vec::new();
276 let mut current_batch = Vec::new();
277
278 for sentence in sentences {
279 current_batch.push(sentence.clone());
280
281 if current_batch.len() >= self.batch_size {
282 if self.shuffle {
283 let mut rng = rand::thread_rng();
284 current_batch.shuffle(&mut rng);
285 }
286 batches.push(current_batch.clone());
287 current_batch.clear();
288 }
289 }
290
291 if !current_batch.is_empty() {
293 if self.shuffle {
294 let mut rng = rand::thread_rng();
295 current_batch.shuffle(&mut rng);
296 }
297 batches.push(current_batch);
298 }
299
300 batches
301 }
302
303 pub fn load_lazily(&self, file_path: &str) -> Result<Vec<Vec<String>>, String> {
305 use std::fs::File;
306 use std::io::Read;
307
308 let mut file = File::open(file_path).map_err(|e| e.to_string())?;
309 let mut content = String::new();
310 file.read_to_string(&mut content).map_err(|e| e.to_string())?;
311
312 Ok(load_text_data(&content))
313 }
314
315 pub fn stream_sentences(&self, file_path: &str) -> Result<Box<dyn Iterator<Item = Vec<String>>>, String> {
317 use std::fs::File;
318 use std::io::{BufRead, BufReader};
319
320 let file = File::open(file_path).map_err(|e| e.to_string())?;
321 let reader = BufReader::new(file);
322 let processor = TextProcessor::default();
323
324 let iter = reader.lines().filter_map(move |line| {
325 let line = line.ok()?;
326 let sentences = processor.process_text(&line);
327 if sentences.is_empty() {
328 None
329 } else {
330 Some(sentences.into_iter().flatten().collect::<Vec<String>>())
331 }
332 });
333
334 Ok(Box::new(iter))
335 }
336}