/// ONNX types, textually included from the `onnx.rs` file located in the
/// build's `OUT_DIR`.
// NOTE(review): presumably generated by a build script from the ONNX protobuf
// schema — confirm against build.rs.
pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
pub mod config;
pub mod evaluation;
pub mod search;
pub mod code;
pub mod text;
pub mod tokenizer;
pub mod transfer;
pub use config::*;
pub use evaluation::*;
pub use search::*;
pub use code::*;
pub use text::*;
pub use tokenizer::*;
pub use transfer::*;
pub mod model;
pub mod backend;
pub mod benchmark;
pub mod transformer;
pub mod mmap;
pub mod pretrained;
mod training;
mod export;
pub mod cli;
mod commands;
pub use model::*;
pub use backend::*;
pub use benchmark::*;
pub use transformer::*;
pub use mmap::MmapEmbeddings;
pub use pretrained::{PretrainedEmbeddings, PretrainedLoader};
pub use training::IncrementalTrainer;
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array, Array1, Array2};
use std::collections::HashMap;
/// `build_vocab` over two sentences with three distinct words must return
/// maps of size three and assign the ids asserted below.
#[test]
fn test_build_vocab() {
    let owned = |tokens: &[&str]| -> Vec<String> {
        tokens.iter().map(|t| t.to_string()).collect()
    };
    let corpus = vec![owned(&["hello", "world"]), owned(&["hello", "rust"])];
    let (word_to_id, id_to_word) = build_vocab(&corpus);
    assert_eq!(word_to_id.len(), 3);
    assert_eq!(id_to_word.len(), 3);
    // Expected id for each word.
    for (word, id) in [("hello", 0), ("world", 1), ("rust", 2)].iter() {
        assert_eq!(word_to_id.get(*word), Some(id));
    }
}
/// `load_text_data` on a two-sentence input must produce two token lists
/// equal to the lowercase words listed in `expected`.
#[test]
fn test_load_text_data() {
    let parsed = load_text_data("Hello world! This is a test.");
    assert_eq!(parsed.len(), 2);
    let expected = [vec!["hello", "world"], vec!["this", "is", "a", "test"]];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(&parsed[idx], want);
    }
}
/// Builds the small three-sentence corpus shared by the tests in this module.
fn make_test_data() -> TrainingData {
    let corpus = "the cat sat on the mat. the dog sat on the log. the cat chased the dog.";
    TrainingData::from_text(corpus)
}
/// Returns a small, fast training configuration for `model_type`:
/// dimension 8, learning rate 0.1, 2 epochs, batch size 4, window 1 and
/// 2 negative samples.
fn test_config(model_type: ModelType) -> TrainingConfig {
    let base = TrainingConfig::new(model_type);
    let sized = base.with_dim(8).with_learning_rate(0.1);
    let scheduled = sized.with_epochs(2).with_batch_size(4);
    scheduled.with_window(1).with_negative_samples(2)
}
/// Skip-gram training must succeed and afterwards expose embeddings for
/// corpus words as well as a similarity between "cat" and "dog".
#[test]
fn test_train_skipgram() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    for word in ["cat", "dog", "the"].iter().copied() {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
    assert!(model.similarity("cat", "dog", &corpus).is_some());
}
/// CBOW training must succeed and afterwards expose embeddings for "cat"
/// and "dog" plus a similarity between them.
#[test]
fn test_train_cbow() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::Cbow), corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    for word in ["cat", "dog"].iter().copied() {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
    assert!(model.similarity("cat", "dog", &corpus).is_some());
}
/// `build_vocab_with_freq` must report four distinct words and the
/// per-word occurrence counts listed below.
#[test]
fn test_build_vocab_with_freq_counts_correctly() {
    let corpus: Vec<Vec<String>> = [["the", "cat", "sat"], ["the", "dog", "sat"]]
        .iter()
        .map(|sentence| sentence.iter().map(|w| w.to_string()).collect())
        .collect();
    let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&corpus);
    assert_eq!(vocab.len(), 4);
    assert_eq!(reverse_vocab.len(), 4);
    assert_eq!(word_freq.len(), 4);
    // Frequencies are looked up through the id assigned to each word.
    for (word, count) in [("the", 2), ("sat", 2), ("cat", 1), ("dog", 1)].iter() {
        let id = vocab[*word];
        assert_eq!(word_freq[id], *count);
    }
}
/// Training with unigram negative sampling enabled must complete without error.
#[test]
fn test_unigram_negative_sampling_runs() {
    let corpus = make_test_data();
    let base = test_config(ModelType::SkipGram);
    let unigram_cfg = base.with_unigram_negative_sampling(true).with_epochs(2);
    let mut model = EmbeddingModel::new(unigram_cfg, corpus.vocab.len());
    let outcome = model.train(&corpus);
    assert!(outcome.is_ok());
}
/// Training with a subsample threshold of 1e-5 must complete without error.
#[test]
fn test_subsampling_runs() {
    let corpus = make_test_data();
    let base = test_config(ModelType::SkipGram);
    let subsample_cfg = base.with_subsample_threshold(Some(1e-5)).with_epochs(2);
    let mut model = EmbeddingModel::new(subsample_cfg, corpus.vocab.len());
    let outcome = model.train(&corpus);
    assert!(outcome.is_ok());
}
/// Checks the corpus reports a positive total word count, then that training
/// with a subsample threshold of 1e-3 completes without error.
#[test]
fn test_subsampling_drops_frequent_words() {
    let corpus = make_test_data();
    assert!(corpus.total_word_count() > 0);
    let base = TrainingConfig::new(ModelType::SkipGram);
    let sized = base.with_dim(4).with_epochs(1).with_batch_size(2);
    let subsample_cfg = sized.with_subsample_threshold(Some(1e-3));
    let mut model = EmbeddingModel::new(subsample_cfg, corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
}
/// Trains with a three-epoch warm-up and checks the ordering of the learning
/// rates that `get_learning_rate` reports for epochs 0 through 3 (of 5).
#[test]
fn test_lr_warmup() {
    let corpus = make_test_data();
    let warmup_cfg = test_config(ModelType::SkipGram)
        .with_warmup_epochs(Some(3))
        .with_epochs(5)
        .with_learning_rate(0.1);
    let mut model = EmbeddingModel::new(warmup_cfg, corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    // rates[i] is the learning rate reported for epoch i.
    let rates: Vec<_> = (0..4).map(|epoch| model.get_learning_rate(epoch, 5)).collect();
    assert!(rates[0] < rates[1] && rates[1] < rates[2], "LR should increase during warm-up");
    assert!(rates[2] < rates[3], "LR should reach base rate after warm-up");
    assert!(rates[3] > 0.0, "LR after warm-up should be positive");
}
/// Trains a skip-gram model, writes a checkpoint into the system temp
/// directory, reloads it, and checks that the model type, vocabulary size and
/// embedding shape survive the round trip. The checkpoint file is removed
/// afterwards, matching the other file-based tests in this module.
#[test]
fn test_checkpoint_save_and_load() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram).with_epochs(2);
    let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
    model.train(&data).unwrap();
    let path = std::env::temp_dir().join("test_checkpoint.json");
    let path_str = path.to_str().unwrap();
    model.save_checkpoint(path_str, 2, 0.5).unwrap();
    let loaded = EmbeddingModel::load_checkpoint(path_str);
    // Clean up before unwrapping/asserting so a failure does not leave the
    // checkpoint behind in the temp directory.
    std::fs::remove_file(path_str).ok();
    let loaded = loaded.unwrap();
    assert_eq!(loaded.config.model_type, config.model_type);
    assert_eq!(loaded.vocab_size, model.vocab_size);
    assert_eq!(loaded.embeddings.shape(), model.embeddings.shape());
}
/// Skip-gram training with the parallel flag set must succeed and leave an
/// embedding for "cat".
#[test]
fn test_parallel_training_skipgram() {
    let corpus = make_test_data();
    let base = test_config(ModelType::SkipGram);
    let parallel_cfg = base.with_parallel(true).with_epochs(2);
    let mut model = EmbeddingModel::new(parallel_cfg, corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    let cat = model.get_embedding("cat", &corpus);
    assert!(cat.is_some());
}
/// CBOW training with the parallel flag set must succeed and leave an
/// embedding for "dog".
#[test]
fn test_parallel_training_cbow() {
    let corpus = make_test_data();
    let base = test_config(ModelType::Cbow);
    let parallel_cfg = base.with_parallel(true).with_epochs(2);
    let mut model = EmbeddingModel::new(parallel_cfg, corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    let dog = model.get_embedding("dog", &corpus);
    assert!(dog.is_some());
}
/// Saves trained embeddings to a temp file and checks the written text
/// mentions "cat" and "dog"; the file is removed at the end.
#[test]
fn test_save_embeddings() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let out_path = std::env::temp_dir().join("test_embeddings_save.txt");
    let out_str = out_path.to_str().unwrap();
    assert!(model.save_embeddings(out_str, &corpus).is_ok());
    let written = std::fs::read_to_string(out_str).unwrap();
    for word in ["cat", "dog"].iter() {
        assert!(written.contains(*word));
    }
    std::fs::remove_file(out_str).ok();
}
/// `similarity` must return `None` when either word is missing from the
/// vocabulary.
#[test]
fn test_similarity_unknown_word() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let unknown_pairs = [("cat", "nonexistent"), ("nonexistent", "dog")];
    for (left, right) in unknown_pairs.iter().copied() {
        assert!(model.similarity(left, right, &corpus).is_none());
    }
}
/// With `remove_html` enabled, processing text containing tags must yield
/// the two token lists asserted below.
#[test]
fn test_strip_html() {
    let html_processor = TextProcessor {
        lowercase: false,
        remove_punctuation: false,
        remove_html: true,
        ..TextProcessor::default()
    };
    let parsed = html_processor.process_text("<p>Hello world!</p> This is a <b>test</b>.");
    assert_eq!(parsed.len(), 2);
    let expected = [vec!["Hello", "world"], vec!["This", "is", "a", "test"]];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(&parsed[idx], want);
    }
}
/// With `remove_urls` enabled, both URLs in the input must be absent from the
/// resulting token lists.
#[test]
fn test_strip_urls() {
    let url_processor = TextProcessor {
        lowercase: true,
        remove_punctuation: true,
        remove_urls: true,
        ..TextProcessor::default()
    };
    let input = "Visit https://example.com for info. See www.test.org too.";
    let parsed = url_processor.process_text(input);
    assert_eq!(parsed.len(), 2);
    let expected = [vec!["visit", "for", "info"], vec!["see", "too"]];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(&parsed[idx], want);
    }
}
/// With `expand_contractions` enabled, "can't" and "It's" must come out as
/// the expanded tokens asserted below.
#[test]
fn test_expand_contractions() {
    let expanding = TextProcessor {
        lowercase: true,
        remove_punctuation: true,
        expand_contractions: true,
        ..TextProcessor::default()
    };
    let parsed = expanding.process_text("I can't do this. It's a test.");
    assert_eq!(parsed.len(), 2);
    let expected = [vec!["i", "cannot", "do", "this"], vec!["it", "is", "a", "test"]];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(&parsed[idx], want);
    }
}
/// After `normalize_embeddings`, every embedding row must have a Euclidean
/// norm of 1 (within 1e-5) or be exactly zero.
#[test]
fn test_normalize_embeddings() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    model.normalize_embeddings();
    for row in model.embeddings.rows() {
        let squared_sum = row.iter().map(|&x| x * x).sum::<f32>();
        let norm = squared_sum.sqrt();
        let is_unit = (norm - 1.0).abs() < 1e-5;
        let is_zero = norm == 0.0;
        assert!(is_unit || is_zero);
    }
}
/// `analogy` must return an empty result when one of its words is unknown.
#[test]
fn test_analogy_unknown_word() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let answers = model.analogy("unknown", "cat", "dog", &corpus, 1);
    assert!(answers.is_empty());
}
/// Splitting ten single-word sentences with ratio 0.7 must give seven
/// training and three validation sentences.
#[test]
fn test_split_data() {
    // One sentence per letter "a" through "j".
    let corpus: Vec<Vec<String>> = "abcdefghij"
        .chars()
        .map(|letter| vec![letter.to_string()])
        .collect();
    let model = EmbeddingModel::new(test_config(ModelType::SkipGram), 1);
    let (train_part, val_part) = model.split_data(&corpus, 0.7);
    assert_eq!(train_part.len(), 7);
    assert_eq!(val_part.len(), 3);
}
/// Training with `gradient_clip` set to 0.001 must succeed and still produce
/// an embedding for "cat".
#[test]
fn test_gradient_clipping() {
    let corpus = make_test_data();
    let mut clipped_cfg = test_config(ModelType::SkipGram);
    clipped_cfg.gradient_clip = Some(0.001);
    let mut model = EmbeddingModel::new(clipped_cfg, corpus.vocab.len());
    let outcome = model.train(&corpus);
    assert!(outcome.is_ok());
    let cat = model.get_embedding("cat", &corpus);
    assert!(cat.is_some());
}
/// Training must succeed with batch sizes 1 and 8, and both resulting models
/// must expose an embedding for "cat".
#[test]
fn test_mini_batch_processing() {
    let corpus = make_test_data();
    // Train one model per batch size first, then inspect all of them.
    let trained: Vec<EmbeddingModel> = [1, 8]
        .iter()
        .map(|&batch| {
            let mut cfg = test_config(ModelType::SkipGram);
            cfg.batch_size = batch;
            let mut model = EmbeddingModel::new(cfg, corpus.vocab.len());
            assert!(model.train(&corpus).is_ok());
            model
        })
        .collect();
    for model in &trained {
        assert!(model.get_embedding("cat", &corpus).is_some());
    }
}
/// An empty input string must produce no sentences.
#[test]
fn test_empty_text() {
    let parsed = load_text_data("");
    assert!(parsed.is_empty());
}
/// A single word must produce exactly one sentence containing that word.
#[test]
fn test_single_word_text() {
    let parsed = load_text_data("hello");
    assert_eq!(parsed.len(), 1);
    let only = &parsed[0];
    assert_eq!(*only, vec!["hello"]);
}
/// Training must succeed under the exponential, step and cosine learning-rate
/// schedules.
#[test]
fn test_learning_rate_schedules() {
    let corpus = make_test_data();
    let schedules = vec![
        LearningRateSchedule::Exponential { decay_rate: 0.9 },
        LearningRateSchedule::Step { step_size: 1, gamma: 0.5 },
        LearningRateSchedule::Cosine { t_max: 2 },
    ];
    for schedule in schedules {
        let mut cfg = test_config(ModelType::SkipGram);
        cfg.lr_schedule = schedule;
        let mut model = EmbeddingModel::new(cfg, corpus.vocab.len());
        assert!(model.train(&corpus).is_ok());
    }
}
/// Training for up to ten epochs with early stopping configured
/// (patience 1, min delta 0.001) must complete without error.
#[test]
fn test_early_stopping() {
    let corpus = make_test_data();
    let mut cfg = test_config(ModelType::SkipGram);
    let stopping = EarlyStoppingConfig { patience: 1, min_delta: 0.001 };
    cfg.early_stopping = Some(stopping);
    cfg.epochs = 10;
    let mut model = EmbeddingModel::new(cfg, corpus.vocab.len());
    let outcome = model.train(&corpus);
    assert!(outcome.is_ok());
}
/// Saves embeddings in word2vec text format, reloads them, and checks the
/// reported dimension and the vectors for "cat" and "dog".
#[test]
fn test_word2vec_format_roundtrip() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let out_path = std::env::temp_dir().join("test_word2vec.txt");
    let out_str = out_path.to_str().unwrap();
    assert!(model.save_word2vec_format(out_str, &corpus).is_ok());
    let (vectors, dim) = EmbeddingModel::load_word2vec_format(out_str).unwrap();
    assert_eq!(dim, 8);
    let probes = ["cat", "dog"];
    for word in probes.iter() {
        assert!(vectors.contains_key(*word));
    }
    for word in probes.iter() {
        assert_eq!(vectors.get(*word).unwrap().len(), 8);
    }
    std::fs::remove_file(out_str).ok();
}
/// Trains a BPE tokenizer on a tiny corpus and checks that encoding then
/// decoding "lowest" returns the original word.
#[test]
fn test_bpe_tokenizer() {
    let corpus: Vec<String> = ["low", "lower", "lowest", "newer", "new", "widest", "wide"]
        .iter()
        .map(|w| w.to_string())
        .collect();
    let bpe = BPETokenizer::train(&corpus, 20);
    assert!(bpe.vocab.len() >= 10);
    let pieces = bpe.encode("lowest");
    assert!(!pieces.is_empty());
    let roundtrip = bpe.decode(&pieces);
    assert_eq!(roundtrip, "lowest");
}
/// Writes a synthetic pretrained-embedding file (header line plus one line
/// per vocabulary word), loads it via `new_with_pretrained`, and checks the
/// vector for "cat" matches the values that were written.
#[test]
fn test_pretrained_embeddings_loading() {
    use std::io::Write;
    let corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let dim = cfg.embedding_dim;
    let tmp = std::env::temp_dir().join("test_pretrained.txt");
    let tmp_str = tmp.to_str().unwrap();
    {
        // The file handle is closed at the end of this scope, before loading.
        let mut out = std::fs::File::create(tmp_str).unwrap();
        writeln!(out, "{} {}", corpus.vocab.len(), dim).unwrap();
        for (word_id, word) in corpus.reverse_vocab.iter().enumerate() {
            let mut cells: Vec<String> = Vec::new();
            for i in 0..dim {
                cells.push(format!("{:.6}", (word_id * 10 + i) as f32 * 0.1));
            }
            writeln!(out, "{} {}", word, cells.join(" ")).unwrap();
        }
    }
    let built = EmbeddingModel::new_with_pretrained(cfg, corpus.vocab.len(), &corpus, tmp_str);
    assert!(built.is_ok());
    let model = built.unwrap();
    let cat_vector = model.get_embedding("cat", &corpus).unwrap();
    let cat_id = *corpus.vocab.get("cat").unwrap();
    for (i, &val) in cat_vector.iter().enumerate() {
        let expected = (cat_id * 10 + i) as f32 * 0.1;
        let close = (val - expected).abs() < 1e-5;
        assert!(close, "Mismatch at index {}: got {}, expected {}", i, val, expected);
    }
    std::fs::remove_file(tmp_str).ok();
}
/// `semantic_search` for "cat" must return at least one hit and must not
/// return the query word itself.
#[test]
fn test_semantic_search() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let hits = model.semantic_search("cat", &corpus, 5);
    assert!(!hits.is_empty());
    for (candidate, _) in hits.iter() {
        assert_ne!(candidate, "cat");
    }
}
/// `embedding_arithmetic` on two known words must return a vector whose
/// length equals the configured embedding dimension.
#[test]
fn test_embedding_arithmetic() {
    let corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(cfg.clone(), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let combined = model.embedding_arithmetic("cat", "dog", &corpus);
    assert!(combined.is_some());
    let vector = combined.unwrap();
    assert_eq!(vector.len(), cfg.embedding_dim);
}
/// `interpolate_embeddings` at factor 0.5 must return a vector whose length
/// equals the configured embedding dimension.
#[test]
fn test_interpolate_embeddings() {
    let corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(cfg.clone(), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let blended = model.interpolate_embeddings("cat", "dog", &corpus, 0.5);
    assert!(blended.is_some());
    let vector = blended.unwrap();
    assert_eq!(vector.len(), cfg.embedding_dim);
}
/// `save_numpy_format` must succeed and write a non-empty file, which is
/// removed afterwards.
#[test]
fn test_save_numpy_format() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let out_path = std::env::temp_dir().join("test_numpy.npy");
    let out_str = out_path.to_str().unwrap();
    let saved = model.save_numpy_format(out_str, &corpus);
    assert!(saved.is_ok());
    let size = std::fs::metadata(out_str).unwrap().len();
    assert!(size > 0);
    std::fs::remove_file(out_str).ok();
}
/// Writes three sentences to a temp file and checks that
/// `DataLoader::stream_sentences` yields at least one sentence and no empty
/// ones.
#[test]
fn test_stream_sentences() {
    use std::io::Write;
    let tmp = std::env::temp_dir().join("test_stream.txt");
    let tmp_str = tmp.to_str().unwrap();
    {
        // The file handle is closed at the end of this scope, before reading.
        let mut out = std::fs::File::create(tmp_str).unwrap();
        let lines = [
            "the cat sat on the mat.",
            "the dog sat on the log.",
            "the cat chased the dog.",
        ];
        for line in lines.iter() {
            writeln!(out, "{}", line).unwrap();
        }
    }
    let loader = DataLoader::new(4, false);
    let streamed: Vec<Vec<String>> = loader.stream_sentences(tmp_str).unwrap().collect();
    assert!(!streamed.is_empty());
    for sentence in &streamed {
        assert!(!sentence.is_empty());
    }
    std::fs::remove_file(tmp_str).ok();
}
/// Adding two new words through `incremental_vocab_update` must grow both the
/// vocabulary and the embedding matrix by two rows while keeping the
/// embedding width and existing words intact.
#[test]
fn test_incremental_vocab_update() {
    let mut corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(cfg.clone(), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let vocab_before = corpus.vocab.len();
    let rows_before = model.embeddings.nrows();
    let fresh = vec![String::from("elephant"), String::from("giraffe")];
    let added = model.incremental_vocab_update(&fresh, &mut corpus).unwrap();
    assert_eq!(added.len(), 2);
    assert_eq!(corpus.vocab.len(), vocab_before + 2);
    assert_eq!(model.embeddings.nrows(), rows_before + 2);
    assert_eq!(model.embeddings.ncols(), cfg.embedding_dim);
    // Both the new words and a pre-existing word must be retrievable.
    for word in ["elephant", "giraffe", "cat"].iter().copied() {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
}
/// Builds an `LSHIndex` over a trained model and checks that querying "cat"
/// never returns "cat" itself.
#[test]
fn test_lsh_index() {
    let corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(cfg.clone(), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let mut index = LSHIndex::new(4, cfg.embedding_dim);
    index.build(&model, &corpus);
    let hits = index.query("cat", &model, &corpus, 5);
    for (candidate, _) in hits.iter() {
        assert_ne!(candidate, "cat");
    }
}
/// Exports a trained model with `save_onnx_format`, decodes the file as an
/// `onnx::ModelProto`, and checks the header fields and graph structure
/// asserted below.
#[test]
fn test_save_onnx_format() {
    use prost::Message;
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let out_path = std::env::temp_dir().join("test_model.onnx");
    let out_str = out_path.to_str().unwrap();
    assert!(model.save_onnx_format(out_str, &corpus).is_ok());
    let size = std::fs::metadata(out_str).unwrap().len();
    assert!(size > 50);
    let raw = std::fs::read(out_str).unwrap();
    let parsed = onnx::ModelProto::decode(&raw[..]);
    assert!(parsed.is_ok());
    let proto = parsed.unwrap();
    assert_eq!(proto.ir_version, 9);
    assert_eq!(proto.producer_name, "embedding-trainer");
    // The graph is expected to hold one Gather node and one initializer.
    let graph = proto.graph.unwrap();
    assert_eq!(graph.name, "embedding_graph");
    assert_eq!(graph.node.len(), 1);
    assert_eq!(graph.node[0].op_type, "Gather");
    assert_eq!(graph.initializer.len(), 1);
    std::fs::remove_file(out_str).ok();
}
/// `sentence_embedding` must return a vector of the configured dimension for
/// a known sentence and `None` for an empty one.
#[test]
fn test_sentence_embedding() {
    let corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(cfg.clone(), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let words = vec![String::from("the"), String::from("cat"), String::from("sat")];
    let pooled = model.sentence_embedding(&words, &corpus);
    assert!(pooled.is_some());
    let vector = pooled.unwrap();
    assert_eq!(vector.len(), cfg.embedding_dim);
    assert!(model.sentence_embedding(&[], &corpus).is_none());
}
/// Exercises `MultimodalFusion`: concatenation, weighted average (including
/// the mismatched-length case), attention fusion and cross-modal similarity.
#[test]
fn test_multimodal_fusion() {
    let text_vec = Array1::from_vec(vec![1.0, 2.0, 3.0]);
    let aux_vec = Array1::from_vec(vec![4.0, 5.0, 6.0]);
    let fusion = MultimodalFusion::new(3, 3, 3);

    let joined = fusion.concatenate(&text_vec, &aux_vec);
    assert_eq!(joined.len(), 6);
    assert_eq!(joined[0], 1.0);
    assert_eq!(joined[5], 6.0);

    let blended = fusion.weighted_average(&text_vec, &aux_vec, 0.5).unwrap();
    assert_eq!(blended.len(), 3);
    assert!((blended[0] - 2.5).abs() < 1e-6);

    // A shorter second vector must be rejected.
    let truncated = Array1::from_vec(vec![1.0, 2.0]);
    assert!(fusion.weighted_average(&text_vec, &truncated, 0.5).is_none());

    let attended = fusion.attention_fusion(&text_vec, &aux_vec).unwrap();
    assert_eq!(attended.len(), 3);

    let sim = MultimodalFusion::cross_modal_similarity(&text_vec, &aux_vec);
    assert!((-1.0..=1.0).contains(&sim));
}
/// A fresh `CrossLingualAligner` must leave a vector unchanged; after
/// training on one dictionary pair it must map the source vector close to the
/// target.
#[test]
fn test_cross_lingual_aligner() {
    let mut aligner = CrossLingualAligner::new(4);
    let probe = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    assert_eq!(aligner.align(&probe), probe);

    let source = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
    let target = Array1::from_vec(vec![2.0, 0.0, 0.0, 0.0]);
    aligner.train_from_dictionary(&[(source, target)], 100, 0.1);

    let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
    let mapped = aligner.align(&query);
    assert!((mapped[0] - 2.0).abs() < 0.1, "Expected ~2.0, got {}", mapped[0]);
}
/// `DomainAdapter::adapt` must succeed on two domain sentences and the
/// vocabulary must still contain "cat" afterwards.
#[test]
fn test_domain_adapter() {
    let mut corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(cfg.clone(), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let domain_corpus = vec![
        vec![String::from("the"), String::from("cat")],
        vec![String::from("a"), String::from("dog")],
    ];
    let adapted = DomainAdapter::adapt(&mut model, &mut corpus, &domain_corpus, 1);
    assert!(adapted.is_ok());
    assert!(corpus.vocab.contains_key("cat"));
}
/// `DocumentEmbedder::embed_document` must return a vector of the configured
/// dimension for two sentences and `None` for an empty document.
#[test]
fn test_document_embedder() {
    let corpus = make_test_data();
    let cfg = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(cfg.clone(), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let document = vec![
        vec![String::from("the"), String::from("cat")],
        vec![String::from("a"), String::from("dog")],
    ];
    let pooled = DocumentEmbedder::embed_document(&model, &corpus, &document);
    assert!(pooled.is_some());
    let vector = pooled.unwrap();
    assert_eq!(vector.len(), cfg.embedding_dim);
    assert!(DocumentEmbedder::embed_document(&model, &corpus, &[]).is_none());
}
/// `ZeroShotTransfer::classify` must pick "class_a" for a query that lies
/// close to its prototype, with similarity above 0.9.
#[test]
fn test_zero_shot_transfer() {
    let prototypes: HashMap<String, _> = vec![
        ("class_a".to_string(), Array1::from_vec(vec![1.0, 0.0, 0.0])),
        ("class_b".to_string(), Array1::from_vec(vec![0.0, 1.0, 0.0])),
    ]
    .into_iter()
    .collect();
    let query = Array1::from_vec(vec![0.9, 0.1, 0.0]);
    let verdict = ZeroShotTransfer::classify(&query, &prototypes);
    assert!(verdict.is_some());
    let (label, sim) = verdict.unwrap();
    assert_eq!(label, "class_a");
    assert!(sim > 0.9);
}
/// `QueryExpander::expand` must return a non-empty list whose first entry is
/// the original query word.
#[test]
fn test_query_expander() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let terms = QueryExpander::expand(&model, &corpus, "cat", 3);
    assert!(!terms.is_empty());
    assert_eq!(terms[0], "cat");
}
/// Clustering into two groups must return exactly two clusters that together
/// contain every vocabulary word exactly once.
#[test]
fn test_hierarchical_clustering() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let groups = HierarchicalClustering::cluster(&model, &corpus, 2);
    assert_eq!(groups.len(), 2);
    let mut seen = std::collections::HashSet::new();
    for word in groups.iter().flatten() {
        // `insert` returns false on a repeat, i.e. a word in two clusters.
        assert!(seen.insert(word.clone()));
    }
    assert_eq!(seen.len(), corpus.vocab.len());
}
/// With Unicode normalization enabled, a decomposed "e" + combining acute
/// accent (U+0065 U+0301) must come out of `process_text` as the single
/// precomposed character U+00E9.
#[test]
fn test_unicode_normalization() {
    let processor = TextProcessor {
        lowercase: true,
        remove_punctuation: false,
        remove_numbers: false,
        remove_stop_words: false,
        remove_html: false,
        remove_urls: false,
        expand_contractions: false,
        // Must be on: this is the feature under test, and the final assertion
        // expects the composed form of the input.
        normalize_unicode: true,
        language: "en".to_string(),
    };
    // Decomposed input: 'e' followed by a combining acute accent.
    let text = "caf\u{0065}\u{0301}";
    let sentences = processor.process_text(text);
    assert_eq!(sentences.len(), 1);
    assert_eq!(sentences[0].len(), 1);
    assert_eq!(sentences[0][0], "caf\u{00e9}");
}
/// Runs a code snippet through `load_code_data`, checks the identifier
/// pieces asserted below appear in the vocabulary, then trains a model and
/// looks up two of them.
#[test]
fn test_code_embedding_pipeline() {
    // The raw string is the tokenizer input and is kept byte-for-byte.
    let snippet = r#"
fn computeEmbeddingVector(input: Vec<f32>) -> Vec<f32> {
let result = vec![];
for x in input {
result.push(x * 2.0);
}
result
}
"#;
    let sentences = load_code_data(snippet);
    assert!(!sentences.is_empty());
    let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
    for token in ["compute", "embedding", "vector", "result"].iter() {
        assert!(vocab.contains_key(*token));
    }
    let corpus = TrainingData { sentences, vocab, reverse_vocab, word_freq };
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    for token in ["embedding", "vector"].iter().copied() {
        assert!(model.get_embedding(token, &corpus).is_some());
    }
}
/// Runs French text through the load / vocab / train pipeline and checks the
/// expected words are in the vocabulary and have embeddings.
#[test]
fn test_western_language_embedding_pipeline() {
    let french = "Le chat noir dort sur le tapis. Le chien brun joue dans le jardin.";
    let sentences = load_text_data(french);
    assert!(!sentences.is_empty());
    let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
    for word in ["chat", "noir", "dort", "chien", "jardin"].iter() {
        assert!(vocab.contains_key(*word));
    }
    let corpus = TrainingData { sentences, vocab, reverse_vocab, word_freq };
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    for word in ["chat", "chien", "jardin"].iter().copied() {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
}
/// Runs Chinese text through the load / vocab / train pipeline and checks the
/// expected single-character tokens are in the vocabulary and have embeddings.
#[test]
fn test_chinese_embedding_pipeline() {
    let chinese = "猫坐在垫子上。狗在花园里玩。猫追狗。";
    let sentences = load_text_data(chinese);
    assert!(!sentences.is_empty());
    let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
    for token in ["猫", "坐", "狗", "花", "追"].iter() {
        assert!(vocab.contains_key(*token));
    }
    let corpus = TrainingData { sentences, vocab, reverse_vocab, word_freq };
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    for token in ["猫", "狗", "追"].iter().copied() {
        assert!(model.get_embedding(token, &corpus).is_some());
    }
}
/// Runs Japanese text through the load / vocab / train pipeline and checks
/// the expected single-character tokens are in the vocabulary and have
/// embeddings.
#[test]
fn test_japanese_embedding_pipeline() {
    let japanese = "猫はマットの上に座っています。犬は庭で遊んでいます。";
    let sentences = load_text_data(japanese);
    assert!(!sentences.is_empty());
    let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
    for token in ["猫", "座", "犬", "遊"].iter() {
        assert!(vocab.contains_key(*token));
    }
    let corpus = TrainingData { sentences, vocab, reverse_vocab, word_freq };
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
    for token in ["猫", "犬"].iter().copied() {
        assert!(model.get_embedding(token, &corpus).is_some());
    }
}
/// Checks that `SubwordEmbedder::ngrams` emits boundary-marked n-grams for
/// "apple" and that `embed` combines supplied n-gram vectors into a
/// two-dimensional result.
#[test]
fn test_subword_embedder() {
    let embedder = SubwordEmbedder::new(3, 5);
    let grams = embedder.ngrams("apple");
    assert!(!grams.is_empty());
    for expected in ["<ap", "ple>"].iter() {
        assert!(grams.contains(&expected.to_string()));
    }
    let mut gram_vectors = HashMap::new();
    let entries = vec![
        ("<ap", vec![1.0, 0.0]),
        ("app", vec![0.0, 1.0]),
        ("ppl", vec![1.0, 1.0]),
        ("ple>", vec![0.5, 0.5]),
    ];
    for (gram, values) in entries {
        gram_vectors.insert(gram.to_string(), Array1::from_vec(values));
    }
    let pooled = embedder.embed("apple", &gram_vectors);
    assert!(pooled.is_some());
    let vector = pooled.unwrap();
    assert_eq!(vector.len(), 2);
}
/// `create_validation_data` must produce both positive and negative pairs
/// for the shared test corpus.
#[test]
fn test_create_validation_data() {
    let corpus = make_test_data();
    let model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    let validation = model.create_validation_data(&corpus.sentences);
    let has_positive = !validation.positive_pairs.is_empty();
    assert!(has_positive);
    let has_negative = !validation.negative_pairs.is_empty();
    assert!(has_negative);
}
/// Every metric returned by `evaluate` must fall inside its expected range:
/// [0, 1] for the scores and [-1, 1] for the mean similarity.
#[test]
fn test_evaluate_produces_metrics() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let validation = model.create_validation_data(&corpus.sentences);
    let metrics = model.evaluate(&corpus, &validation);
    assert!((0.0..=1.0).contains(&metrics.accuracy));
    assert!((0.0..=1.0).contains(&metrics.precision));
    assert!((0.0..=1.0).contains(&metrics.recall));
    assert!((0.0..=1.0).contains(&metrics.f1_score));
    assert!((-1.0..=1.0).contains(&metrics.mean_similarity));
    assert!((0.0..=1.0).contains(&metrics.embedding_quality_score));
}
/// Training with a validation ratio of 0.3 must succeed and still produce an
/// embedding for "cat".
#[test]
fn test_train_with_validation_split() {
    let corpus = make_test_data();
    let mut split_cfg = test_config(ModelType::SkipGram);
    split_cfg.validation_ratio = Some(0.3);
    let mut model = EmbeddingModel::new(split_cfg, corpus.vocab.len());
    let outcome = model.train(&corpus);
    assert!(outcome.is_ok());
    let cat = model.get_embedding("cat", &corpus);
    assert!(cat.is_some());
}
/// Three-fold cross-validation must report three folds with per-fold metrics
/// and averaged accuracy / F1 within [0, 1].
#[test]
fn test_cross_validation_basic() {
    let corpus = make_test_data();
    let model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    let report = model.cross_validate(&corpus, 3).unwrap();
    assert_eq!(report.folds, 3);
    assert_eq!(report.per_fold_metrics.len(), 3);
    let averaged = &report.averaged_metrics;
    assert!((0.0..=1.0).contains(&averaged.accuracy));
    assert!((0.0..=1.0).contains(&averaged.f1_score));
}
/// `cross_validate` must return an error for fold counts of 0, 1 and 100 on
/// the shared test corpus.
#[test]
fn test_cross_validation_invalid_k() {
    let corpus = make_test_data();
    let model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    for folds in [0, 1, 100].iter().copied() {
        assert!(model.cross_validate(&corpus, folds).is_err());
    }
}
/// Two-fold cross-validation must report two folds and two sets of per-fold
/// metrics.
#[test]
fn test_cross_validation_k_equals_2() {
    let corpus = make_test_data();
    let model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    let report = model.cross_validate(&corpus, 2).unwrap();
    assert_eq!(report.folds, 2);
    let fold_count = report.per_fold_metrics.len();
    assert_eq!(fold_count, 2);
}
/// After `normalize_embeddings`, every embedding row must have a Euclidean
/// norm of 1 within 1e-5 (zero rows are not tolerated here).
#[test]
fn test_l2_normalize_embeddings() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    model.normalize_embeddings();
    for row in model.embeddings.rows() {
        let squared_sum = row.iter().map(|&x| x * x).sum::<f32>();
        let norm = squared_sum.sqrt();
        let deviation = (norm - 1.0).abs();
        assert!(deviation < 1e-5, "Expected unit norm, got {}", norm);
    }
}
/// Two-fold cross-validation with a CBOW configuration must report two folds
/// and a non-negative averaged accuracy.
#[test]
fn test_cross_validation_cbow() {
    let corpus = make_test_data();
    let model = EmbeddingModel::new(test_config(ModelType::Cbow), corpus.vocab.len());
    let report = model.cross_validate(&corpus, 2).unwrap();
    assert_eq!(report.folds, 2);
    let accuracy = report.averaged_metrics.accuracy;
    assert!(accuracy >= 0.0);
}
/// Training must record at least one epoch in `training_history`, with a
/// non-negative loss and positive learning rate, and the history must
/// serialize to JSON containing those field names.
#[test]
fn test_training_history_records_epochs() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let history = &model.training_history;
    assert!(!history.epochs.is_empty());
    let first_epoch = &history.epochs[0];
    assert!(first_epoch.loss >= 0.0);
    assert!(first_epoch.learning_rate > 0.0);
    let serialized = history.to_json().unwrap();
    for field in ["loss", "learning_rate"].iter() {
        assert!(serialized.contains(*field));
    }
}
/// Trains a WordPiece tokenizer on a tiny corpus and checks that encoding
/// then decoding "hello" returns the original word.
#[test]
fn test_wordpiece_tokenizer_train_encode_decode() {
    let corpus: Vec<String> = ["hello", "world", "hello", "world"]
        .iter()
        .map(|w| w.to_string())
        .collect();
    let wordpiece = tokenizer::WordPieceTokenizer::train(&corpus, 50);
    assert!(wordpiece.vocab_size > 0);
    let pieces = wordpiece.encode("hello");
    assert!(!pieces.is_empty());
    let roundtrip = wordpiece.decode(&pieces);
    assert_eq!(roundtrip, "hello");
}
/// Exercises `CpuBackend`: its name, embedding initialisation shape, dot
/// product, scaled addition and matrix-multiply output shape.
#[test]
fn test_cpu_backend() {
    let cpu = backend::CpuBackend::new();
    assert_eq!(cpu.name(), "cpu");

    let table = cpu.init_embeddings(10, 8);
    assert_eq!(table.nrows(), 10);
    assert_eq!(table.ncols(), 8);

    let lhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);
    let rhs = Array1::from_vec(vec![4.0, 5.0, 6.0]);
    // 1*4 + 2*5 + 3*6 = 32
    let dot = cpu.dot(&lhs, &rhs);
    assert!((dot - 32.0).abs() < 1e-5);

    // lhs + 2 * rhs
    let mut accumulated = lhs.clone();
    cpu.add_scaled(&mut accumulated, &rhs, 2.0);
    assert_eq!(accumulated.to_vec(), vec![9.0, 12.0, 15.0]);

    let left = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let right = Array2::from_shape_vec((3, 2), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
    let product = cpu.matmul(&left, &right);
    assert_eq!(product.nrows(), 2);
    assert_eq!(product.ncols(), 2);
}
/// `best_backend` must return a backend that reports a non-empty name.
#[test]
fn test_best_backend_returns_cpu() {
    let chosen = backend::best_backend();
    let name = chosen.name();
    assert!(!name.is_empty());
}
/// Parses three word pairs from TSV text and evaluates them against a
/// trained model, checking the pair counts and that the correlation lies in
/// [-1, 1].
#[test]
fn test_benchmark_load_and_evaluate() {
    let pairs = benchmark::BenchmarkEvaluator::load_from_tsv(
        "cat\tdog\t0.8\ncat\tmat\t0.2\ndog\tmat\t0.1\n",
    );
    assert_eq!(pairs.len(), 3);
    assert_eq!(pairs[0].word1, "cat");
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();
    let report = benchmark::BenchmarkEvaluator::evaluate(&model, &corpus, &pairs);
    assert_eq!(report.num_pairs, 3);
    assert!(report.num_evaluated <= 3);
    assert!((-1.0..=1.0).contains(&report.correlation));
}
/// Clusters a trained model's vocabulary with k = 3 and checks that at most
/// three clusters come back and that their sizes add up to the vocabulary size.
#[test]
fn test_kmeans_clustering() {
    let corpus = make_test_data();
    let mut trained = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    trained.train(&corpus).unwrap();

    let groups = search::KMeansClustering::cluster(&trained, &corpus, 3, 20);
    assert!(!groups.is_empty());
    assert!(groups.len() <= 3);

    // Sum the cluster sizes.
    let mut assigned: usize = 0;
    for group in groups.iter() {
        assigned += group.len();
    }
    assert_eq!(assigned, corpus.vocab.len());
}
/// Requests far more clusters (100) than there are vocabulary entries and
/// expects the number of returned clusters to equal the vocabulary size.
#[test]
fn test_kmeans_clustering_k_greater_than_vocab() {
    let corpus = make_test_data();
    let mut trained = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    trained.train(&corpus).unwrap();

    let groups = search::KMeansClustering::cluster(&trained, &corpus, 100, 10);
    let vocab_size = corpus.vocab.len();
    assert_eq!(groups.len(), vocab_size);
}
/// Runs an all-zero 3x8 input through a `TransformerEncoder` and checks that
/// the output has the same shape as the input.
#[test]
fn test_transformer_encoder() {
    let model = TransformerEncoder::new(2, 2, 8, 16, 10);

    let input = ndarray::Array2::zeros((3, 8));
    let output = model.encode_sequence(&input);

    let (rows, cols) = (output.nrows(), output.ncols());
    assert_eq!(rows, 3);
    assert_eq!(cols, 8);
}
/// Trains a model, then feeds one extra sentence through
/// `IncrementalTrainer::update` and checks that the vocabulary did not shrink
/// and now contains the previously unseen word.
#[test]
fn test_incremental_trainer() {
    let mut corpus = make_test_data();
    let mut trained = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    trained.train(&corpus).unwrap();

    let vocab_before = corpus.vocab.len();

    // One sentence pairing an unseen word with a known one.
    let extra = vec![vec![String::from("newword"), String::from("cat")]];
    IncrementalTrainer::update(&mut trained, &mut corpus, &extra, 1).unwrap();

    let vocab_after = corpus.vocab.len();
    assert!(vocab_after >= vocab_before);
    assert!(corpus.vocab.contains_key("newword"));
}
/// Trains a model, streams two new sentences through
/// `IncrementalTrainer::stream_train`, and checks that a word from the stream
/// ends up in the vocabulary.
#[test]
fn test_incremental_stream_train() {
    let mut corpus = make_test_data();
    let mut trained = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    trained.train(&corpus).unwrap();

    let incoming = vec![
        vec![String::from("stream"), String::from("word")],
        vec![String::from("another"), String::from("stream")],
    ];
    let outcome =
        IncrementalTrainer::stream_train(&mut trained, &mut corpus, incoming.into_iter(), 1, 1);
    outcome.unwrap();

    assert!(corpus.vocab.contains_key("stream"));
}
/// Saves a trained model in the mmap-able format, loads it back, and checks
/// that the mapped view agrees with the in-memory model on vocabulary size,
/// dimension and the embedding row for "cat".
///
/// The temp file is owned by a drop guard, so it is removed even when an
/// assertion or `unwrap` panics part-way through the test.
#[test]
fn test_mmap_embeddings_roundtrip() {
    // Best-effort removal of the wrapped path when the guard goes out of scope.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    // Declared before `mmap`: locals are dropped in reverse declaration order,
    // so the loaded embeddings are released before the file is deleted.
    let temp_file = TempFileGuard(std::env::temp_dir().join("test_mmap.bin"));
    let path_str = temp_file.0.to_str().unwrap();
    model.save_mmapable_format(path_str, &data).unwrap();

    let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
    assert_eq!(mmap.vocab_size(), data.vocab.len());
    assert_eq!(mmap.dim(), model.config.embedding_dim);

    let cat_emb = mmap.get("cat").unwrap();
    assert_eq!(cat_emb.len(), model.config.embedding_dim);

    // The mapped row must match the in-memory row exactly.
    let cat_id = data.vocab["cat"];
    let model_cat: Vec<f32> = model.embeddings.row(cat_id).to_vec();
    assert_eq!(cat_emb, model_cat.as_slice());
}
/// Saves a trained model in the mmap-able format, loads it back, and checks
/// that iterating the loaded embeddings yields one non-empty word with a
/// full-length vector per vocabulary entry.
///
/// The temp file is owned by a drop guard, so it is removed even when an
/// assertion or `unwrap` panics part-way through the test.
#[test]
fn test_mmap_embeddings_iter() {
    // Best-effort removal of the wrapped path when the guard goes out of scope.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    // Declared before `mmap`: locals are dropped in reverse declaration order,
    // so the loaded embeddings are released before the file is deleted.
    let temp_file = TempFileGuard(std::env::temp_dir().join("test_mmap_iter.bin"));
    let path_str = temp_file.0.to_str().unwrap();
    model.save_mmapable_format(path_str, &data).unwrap();

    let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
    let mut count = 0;
    for (word, emb) in mmap.iter() {
        assert!(!word.is_empty());
        assert_eq!(emb.len(), model.config.embedding_dim);
        count += 1;
    }
    assert_eq!(count, data.vocab.len());
}
/// Saves a trained model in the mmap-able format, loads it back, and checks
/// that looking up an unknown word gives `None` while a known word is found.
///
/// The temp file is owned by a drop guard, so it is removed even when an
/// assertion or `unwrap` panics part-way through the test.
#[test]
fn test_mmap_embeddings_missing_word() {
    // Best-effort removal of the wrapped path when the guard goes out of scope.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    // Declared before `mmap`: locals are dropped in reverse declaration order,
    // so the loaded embeddings are released before the file is deleted.
    let temp_file = TempFileGuard(std::env::temp_dir().join("test_mmap_missing.bin"));
    let path_str = temp_file.0.to_str().unwrap();
    model.save_mmapable_format(path_str, &data).unwrap();

    let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
    assert!(mmap.get("nonexistent_word").is_none());
    assert!(mmap.get("cat").is_some());
}
/// Writes a small text embedding file (a "3 4" header line followed by three
/// word rows), loads it with `PretrainedLoader::auto`, and checks dimension,
/// vocabulary size, membership and the first component of "cat".
///
/// The temp file is owned by a drop guard, so it is removed even when an
/// assertion or `unwrap` panics part-way through the test.
#[test]
fn test_pretrained_loader_word2vec_text() {
    // Best-effort removal of the wrapped path when the guard goes out of scope.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let temp = TempFileGuard(std::env::temp_dir().join("test_pretrained_w2v.txt"));
    let path = temp.0.to_str().unwrap();
    let content = "3 4\ncat 0.1 0.2 0.3 0.4\ndog 0.5 0.6 0.7 0.8\nfish 0.9 0.0 0.1 0.2\n";
    std::fs::write(path, content).unwrap();

    let emb = PretrainedLoader::auto(path).unwrap();
    assert_eq!(emb.dim(), 4);
    assert_eq!(emb.vocab_size(), 3);
    assert!(emb.contains("cat"));
    assert!(emb.contains("dog"));
    assert!(!emb.contains("elephant"));

    let cat = emb.get("cat").unwrap();
    assert_eq!(cat.len(), 4);
    assert!((cat[0] - 0.1).abs() < 1e-6);
}
/// Checks `PretrainedEmbeddings::similarity` on hand-built vectors:
/// orthogonal vectors score about 0, identical vectors about 1, and a lookup
/// involving an unknown word gives `None`.
#[test]
fn test_pretrained_embeddings_similarity() {
    let mut table = PretrainedEmbeddings::new(3);
    // "a" and "c" hold the same vector; "b" is orthogonal to both.
    let entries = vec![
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.0, 1.0, 0.0]),
        ("c", vec![1.0, 0.0, 0.0]),
    ];
    for (word, vector) in entries {
        table.insert(word.to_string(), vector);
    }

    let orthogonal = table.similarity("a", "b").unwrap();
    assert!(orthogonal.abs() < 1e-5, "Orthogonal vectors should have ~0 similarity");

    let identical = table.similarity("a", "c").unwrap();
    assert!((identical - 1.0).abs() < 1e-5, "Identical vectors should have similarity ~1");

    let unknown = table.similarity("a", "missing");
    assert!(unknown.is_none());
}
/// Builds four 2-d embeddings and checks that the top-2 neighbours of "king"
/// are returned with "queen" ranked first.
#[test]
fn test_pretrained_embeddings_most_similar() {
    let mut table = PretrainedEmbeddings::new(2);
    table.insert(String::from("king"), vec![1.0, 0.0]);
    table.insert(String::from("queen"), vec![0.9, 0.1]);
    table.insert(String::from("man"), vec![0.1, 1.0]);
    table.insert(String::from("woman"), vec![0.2, 0.9]);

    let neighbours = table.most_similar("king", 2);
    assert_eq!(neighbours.len(), 2);

    let closest = &neighbours[0];
    assert_eq!(closest.0, "queen");
}
/// Writes a small text embedding file, loads it with an explicit
/// `PretrainedFormat::GloVe`, and checks dimension, vocabulary size and the
/// vector stored for "hello".
///
/// The temp file is owned by a drop guard, so it is removed even when an
/// assertion or `unwrap` panics part-way through the test.
#[test]
fn test_pretrained_loader_glove_format() {
    // Best-effort removal of the wrapped path when the guard goes out of scope.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let temp = TempFileGuard(std::env::temp_dir().join("test_pretrained_glove.txt"));
    let path = temp.0.to_str().unwrap();
    // NOTE(review): the fixture starts with a "2 3" header line; the
    // assertions below expect the GloVe loader not to count it as a word —
    // confirm this is the intended loader behaviour.
    let content = "2 3\nhello 0.1 0.2 0.3\nworld 0.4 0.5 0.6\n";
    std::fs::write(path, content).unwrap();

    let emb = PretrainedLoader::with_format(path, pretrained::PretrainedFormat::GloVe).unwrap();
    assert_eq!(emb.dim(), 3);
    assert_eq!(emb.vocab_size(), 2);

    let hello = emb.get("hello").unwrap();
    assert_eq!(hello, &[0.1, 0.2, 0.3]);
}
/// Writes an 8-dimensional pretrained file covering "cat", "dog" and "the",
/// builds a model with `EmbeddingModel::new_with_pretrained`, and checks that
/// every component of the model's "cat" row equals the pretrained value 0.1.
///
/// The temp file is owned by a drop guard, so it is removed even when an
/// assertion or `unwrap` panics part-way through the test.
#[test]
fn test_pretrained_init_model_from_pretrained() {
    // Best-effort removal of the wrapped path when the guard goes out of scope.
    struct TempFileGuard(std::path::PathBuf);
    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    let temp = TempFileGuard(std::env::temp_dir().join("test_pretrained_init.txt"));
    let path = temp.0.to_str().unwrap();
    let content = "3 8\ncat 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1\ndog 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2\nthe 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3\n";
    std::fs::write(path, content).unwrap();

    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let model = EmbeddingModel::new_with_pretrained(config, data.vocab.len(), &data, path).unwrap();

    let cat_id = data.vocab["cat"];
    let cat_emb = model.embeddings.row(cat_id);
    for &v in cat_emb.iter() {
        assert!((v - 0.1).abs() < 1e-5);
    }
}
}