use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use rand::prelude::SliceRandom;
use crate::text::{load_text_data, TextProcessor};
/// Default embedding dimension used by `TrainingConfig::new`.
pub const DEFAULT_EMBEDDING_DIM: usize = 300;

/// Default learning rate used by `TrainingConfig::new`.
pub const DEFAULT_LEARNING_RATE: f64 = 0.025;

/// Default number of training epochs used by `TrainingConfig::new`.
pub const DEFAULT_EPOCHS: usize = 10;

/// Default batch size used by `TrainingConfig::new`.
pub const DEFAULT_BATCH_SIZE: usize = 32;

/// Default context window size used by `TrainingConfig::new`.
pub const DEFAULT_CONTEXT_WINDOW: usize = 5;

/// Default number of negative samples used by `TrainingConfig::new`.
pub const DEFAULT_NEGATIVE_SAMPLES: usize = 5;

/// Default validation ratio.
///
/// NOTE(review): `TrainingConfig::new` leaves `validation_ratio` as `None` rather than
/// using this value — confirm whether it is consumed elsewhere.
pub const DEFAULT_VALIDATION_RATIO: f64 = 0.0;
/// Hyper-parameters and options for training an embedding model.
///
/// Construct with `TrainingConfig::new(model_type)` and adjust with the chainable
/// `with_*` setters. Derives serde `Serialize`/`Deserialize`; fields serialize in
/// declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
/// Embedding dimension (`new` uses `DEFAULT_EMBEDDING_DIM`).
pub embedding_dim: usize,
/// Learning rate (`new` uses `DEFAULT_LEARNING_RATE`).
/// NOTE(review): presumably the starting rate that `lr_schedule` adjusts — confirm in the trainer.
pub learning_rate: f64,
/// Number of training epochs (`new` uses `DEFAULT_EPOCHS`).
pub epochs: usize,
/// Batch size (`new` uses `DEFAULT_BATCH_SIZE`).
pub batch_size: usize,
/// Context window size (`new` uses `DEFAULT_CONTEXT_WINDOW`).
/// NOTE(review): not visible here whether this counts words on one side or both — confirm.
pub context_window: usize,
/// Number of negative samples (`new` uses `DEFAULT_NEGATIVE_SAMPLES`).
pub negative_samples: usize,
/// Model architecture to train.
pub model_type: ModelType,
/// Learning-rate schedule; `new` uses `LearningRateSchedule::Constant`.
pub lr_schedule: LearningRateSchedule,
/// Early-stopping settings; `None` by default (presumably meaning "disabled" — confirm).
pub early_stopping: Option<EarlyStoppingConfig>,
/// L2 regularization strength (`with_l2_regularization` calls it `lambda`); `None` by default.
pub l2_regularization: Option<f64>,
/// Gradient clipping threshold (`with_gradient_clip` calls it `max_norm`); `None` by default.
/// Note this is `f32`, unlike the other rate-like fields which are `f64`.
pub gradient_clip: Option<f32>,
/// Validation ratio, presumably the fraction of data held out — confirm; `None` by default.
/// `new` does not apply `DEFAULT_VALIDATION_RATIO`.
pub validation_ratio: Option<f64>,
/// Subsampling threshold (presumably for frequent words — confirm in the trainer); `None` by default.
pub subsample_threshold: Option<f64>,
/// Whether negative sampling uses a unigram distribution (semantics implemented elsewhere); `new` sets `true`.
pub use_unigram_negative_sampling: bool,
/// Number of warm-up epochs; `None` by default.
pub warmup_epochs: Option<usize>,
/// Checkpoint interval (unit not visible here — presumably epochs); `None` by default.
pub checkpoint_interval: Option<usize>,
/// Checkpoint destination path; `None` by default.
pub checkpoint_path: Option<String>,
/// Whether to use parallel training; `new` sets `false`.
pub use_parallel: bool,
}
// Builder-style API: every `with_*` method consumes `self` and returns the updated
// config, so results must be used (`#[must_use]`) — `cfg.with_dim(10);` alone is a no-op.
impl TrainingConfig {
    /// Creates a configuration for `model_type` with every other setting at its default.
    ///
    /// Numeric hyper-parameters come from the `DEFAULT_*` constants, the learning-rate
    /// schedule is [`LearningRateSchedule::Constant`], unigram negative sampling is enabled,
    /// parallel training is disabled, and all `Option` settings are `None`.
    #[must_use]
    pub fn new(model_type: ModelType) -> Self {
        Self {
            embedding_dim: DEFAULT_EMBEDDING_DIM,
            learning_rate: DEFAULT_LEARNING_RATE,
            epochs: DEFAULT_EPOCHS,
            batch_size: DEFAULT_BATCH_SIZE,
            context_window: DEFAULT_CONTEXT_WINDOW,
            negative_samples: DEFAULT_NEGATIVE_SAMPLES,
            model_type,
            lr_schedule: LearningRateSchedule::Constant,
            early_stopping: None,
            l2_regularization: None,
            gradient_clip: None,
            validation_ratio: None,
            subsample_threshold: None,
            use_unigram_negative_sampling: true,
            warmup_epochs: None,
            checkpoint_interval: None,
            checkpoint_path: None,
            use_parallel: false,
        }
    }

    /// Sets `embedding_dim`.
    #[must_use]
    pub fn with_dim(mut self, dim: usize) -> Self {
        self.embedding_dim = dim;
        self
    }

    /// Sets `learning_rate`.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets the number of training `epochs`.
    #[must_use]
    pub fn with_epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Sets `batch_size`.
    #[must_use]
    pub fn with_batch_size(mut self, bs: usize) -> Self {
        self.batch_size = bs;
        self
    }

    /// Sets `context_window`.
    #[must_use]
    pub fn with_window(mut self, window: usize) -> Self {
        self.context_window = window;
        self
    }

    /// Sets `negative_samples`.
    #[must_use]
    pub fn with_negative_samples(mut self, ns: usize) -> Self {
        self.negative_samples = ns;
        self
    }

    /// Replaces the learning-rate schedule.
    #[must_use]
    pub fn with_lr_schedule(mut self, schedule: LearningRateSchedule) -> Self {
        self.lr_schedule = schedule;
        self
    }

    /// Enables early stopping with the given `patience` and `min_delta`.
    #[must_use]
    pub fn with_early_stopping(mut self, patience: usize, min_delta: f64) -> Self {
        self.early_stopping = Some(EarlyStoppingConfig { patience, min_delta });
        self
    }

    /// Sets the L2 regularization strength to `Some(lambda)`.
    #[must_use]
    pub fn with_l2_regularization(mut self, lambda: f64) -> Self {
        self.l2_regularization = Some(lambda);
        self
    }

    /// Sets the gradient clipping threshold to `Some(max_norm)`.
    #[must_use]
    pub fn with_gradient_clip(mut self, max_norm: f32) -> Self {
        self.gradient_clip = Some(max_norm);
        self
    }

    /// Sets the validation ratio to `Some(ratio)`. The value is stored unchecked.
    #[must_use]
    pub fn with_validation_ratio(mut self, ratio: f64) -> Self {
        self.validation_ratio = Some(ratio);
        self
    }

    /// Sets (or clears, with `None`) the subsampling threshold.
    #[must_use]
    pub fn with_subsample_threshold(mut self, threshold: Option<f64>) -> Self {
        self.subsample_threshold = threshold;
        self
    }

    /// Enables or disables unigram negative sampling.
    #[must_use]
    pub fn with_unigram_negative_sampling(mut self, enabled: bool) -> Self {
        self.use_unigram_negative_sampling = enabled;
        self
    }

    /// Sets (or clears, with `None`) the number of warm-up epochs.
    #[must_use]
    pub fn with_warmup_epochs(mut self, epochs: Option<usize>) -> Self {
        self.warmup_epochs = epochs;
        self
    }

    /// Sets (or clears, with `None`) the checkpoint interval.
    #[must_use]
    pub fn with_checkpoint_interval(mut self, interval: Option<usize>) -> Self {
        self.checkpoint_interval = interval;
        self
    }

    /// Sets (or clears, with `None`) the checkpoint path.
    #[must_use]
    pub fn with_checkpoint_path(mut self, path: Option<String>) -> Self {
        self.checkpoint_path = path;
        self
    }

    /// Enables or disables parallel training.
    #[must_use]
    pub fn with_parallel(mut self, enabled: bool) -> Self {
        self.use_parallel = enabled;
        self
    }
}
/// How the learning rate changes over training.
///
/// Only the parameters are declared here; the formulas are applied elsewhere
/// (NOTE(review): the interpretations below follow the common meaning of each
/// schedule name — confirm against the trainer).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningRateSchedule {
/// Learning rate stays fixed. Default chosen by `TrainingConfig::new`.
Constant,
/// Presumably multiplies the rate by `decay_rate` each step/epoch — confirm.
Exponential { decay_rate: f64 },
/// Presumably multiplies the rate by `gamma` every `step_size` epochs — confirm.
Step { step_size: usize, gamma: f64 },
/// Presumably cosine annealing over `t_max` steps/epochs — confirm.
Cosine { t_max: usize },
}
/// Early-stopping parameters, set via `TrainingConfig::with_early_stopping`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
/// Presumably the number of epochs without sufficient improvement tolerated before stopping — confirm.
pub patience: usize,
/// Presumably the minimum change in the monitored metric that counts as improvement — confirm.
pub min_delta: f64,
}
/// Embedding model architecture selected in `TrainingConfig::model_type`.
///
/// Variant names follow word2vec terminology; the training logic for each lives outside this file.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModelType {
/// Skip-gram architecture.
SkipGram,
/// Continuous bag-of-words architecture.
Cbow,
}
/// A tokenized corpus together with its vocabulary and word frequencies.
///
/// All four fields are produced together by `TrainingData::from_text` / `from_file`
/// via `load_text_data` and `crate::text::build_vocab_with_freq`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingData {
/// Sentences, each a list of tokens.
pub sentences: Vec<Vec<String>>,
/// Presumably word -> integer id — confirm against `build_vocab_with_freq`.
pub vocab: HashMap<String, usize>,
/// Presumably id -> word, the inverse of `vocab` — confirm.
pub reverse_vocab: Vec<String>,
/// Per-word counts; presumably indexed by the same ids as `reverse_vocab` — confirm.
pub word_freq: Vec<usize>,
}
impl TrainingData {
    /// Builds training data from an in-memory text corpus.
    ///
    /// `load_text_data` turns `text` into the sentence list, and
    /// `crate::text::build_vocab_with_freq` derives the vocabulary, reverse vocabulary and
    /// word frequencies from those sentences.
    pub fn from_text(text: &str) -> Self {
        let sentences = load_text_data(text);
        let (vocab, reverse_vocab, word_freq) = crate::text::build_vocab_with_freq(&sentences);
        Self { sentences, vocab, reverse_vocab, word_freq }
    }

    /// Reads the whole file at `path` and builds training data from its contents,
    /// exactly as [`TrainingData::from_text`] would.
    ///
    /// # Errors
    ///
    /// Returns the I/O error's message if the file cannot be opened or read, including
    /// when its contents are not valid UTF-8.
    pub fn from_file(path: &str) -> Result<Self, String> {
        let content = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
        // Delegate so the two constructors cannot drift apart.
        Ok(Self::from_text(&content))
    }

    /// Sum of all entries in `word_freq`.
    #[must_use]
    pub fn total_word_count(&self) -> usize {
        self.word_freq.iter().sum()
    }
}
/// Groups sentences into batches and reads sentence data from files.
#[derive(Debug, Clone)]
pub struct DataLoader {
/// Maximum number of sentences per batch produced by `load_batches`.
pub batch_size: usize,
/// Whether `load_batches` randomizes sentence order.
pub shuffle: bool,
/// Optional file path, set via `set_file_path`.
/// NOTE(review): none of the `DataLoader` methods here read this field; they take the path as an argument.
pub file_path: Option<String>,
}
impl DataLoader {
pub fn new(batch_size: usize, shuffle: bool) -> Self {
Self {
batch_size,
shuffle,
file_path: None,
}
}
pub fn set_file_path(&mut self, path: String) {
self.file_path = Some(path);
}
pub fn load_batches(&self, sentences: &[Vec<String>]) -> Vec<Vec<Vec<String>>> {
let mut batches = Vec::new();
let mut current_batch = Vec::new();
for sentence in sentences {
current_batch.push(sentence.clone());
if current_batch.len() >= self.batch_size {
if self.shuffle {
let mut rng = rand::thread_rng();
current_batch.shuffle(&mut rng);
}
batches.push(current_batch.clone());
current_batch.clear();
}
}
if !current_batch.is_empty() {
if self.shuffle {
let mut rng = rand::thread_rng();
current_batch.shuffle(&mut rng);
}
batches.push(current_batch);
}
batches
}
pub fn load_lazily(&self, file_path: &str) -> Result<Vec<Vec<String>>, String> {
use std::fs::File;
use std::io::Read;
let mut file = File::open(file_path).map_err(|e| e.to_string())?;
let mut content = String::new();
file.read_to_string(&mut content).map_err(|e| e.to_string())?;
Ok(load_text_data(&content))
}
pub fn stream_sentences(&self, file_path: &str) -> Result<Box<dyn Iterator<Item = Vec<String>>>, String> {
use std::fs::File;
use std::io::{BufRead, BufReader};
let file = File::open(file_path).map_err(|e| e.to_string())?;
let reader = BufReader::new(file);
let processor = TextProcessor::default();
let iter = reader.lines().filter_map(move |line| {
let line = line.ok()?;
let sentences = processor.process_text(&line);
if sentences.is_empty() {
None
} else {
Some(sentences.into_iter().flatten().collect::<Vec<String>>())
}
});
Ok(Box::new(iter))
}
}