Skip to main content

embedding/
lib.rs

1//! Word embedding training library.
2//!
3//! This crate provides tools for training word embeddings from scratch
4//! using SkipGram, CBOW, and other models. It supports:
5//!
6//! - Mini-batch training with gradient clipping and L2 regularization
7//! - Learning rate scheduling (constant, exponential, step, cosine)
8//! - Early stopping and evaluation metrics
9//! - Text preprocessing (HTML stripping, URL removal, contraction expansion)
10//! - Source code preprocessing (comment stripping, camelCase splitting)
11//! - BPE subword tokenization
12//! - Export to Word2Vec, NumPy, ONNX, and binary formats
13//! - Semantic search, analogy solving, and embedding arithmetic
14//! - Incremental vocabulary updates and LSH-based approximate nearest neighbors
15//!
16//! # Example
17//!
18//! ```rust
19//! use embedding::*;
20//!
21//! let data = TrainingData::from_text("the cat sat on the mat");
22//! let config = TrainingConfig::new(ModelType::SkipGram)
23//!     .with_dim(8)
24//!     .with_epochs(2);
25//!
26//! let mut model = EmbeddingModel::new(config, data.vocab.len());
27//! // model.train(&data).unwrap();
28//! ```
29
30/// Low-level ONNX protobuf definitions generated by prost.
31pub mod onnx {
32    include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
33}
34
35pub mod config;
36pub mod evaluation;
37pub mod search;
38pub mod code;
39pub mod text;
40pub mod tokenizer;
41pub mod transfer;
42pub use config::*;
43pub use evaluation::*;
44pub use search::*;
45pub use code::*;
46pub use text::*;
47pub use tokenizer::*;
48pub use transfer::*;
49
50pub mod model;
51pub mod backend;
52pub mod benchmark;
53pub mod transformer;
54pub mod mmap;
55pub mod pretrained;
56mod training;
57mod export;
58pub mod cli;
59mod commands;
60pub use model::*;
61pub use backend::*;
62pub use benchmark::*;
63pub use transformer::*;
64pub use mmap::MmapEmbeddings;
65pub use pretrained::{PretrainedEmbeddings, PretrainedLoader};
66pub use training::IncrementalTrainer;
67
68#[cfg(test)]
69mod tests {
70    use super::*;
71    use ndarray::{Array, Array1, Array2};
72    use std::collections::HashMap;
73
#[test]
fn test_build_vocab() {
    // Two sentences that share "hello"; ids follow first appearance.
    let sentences: Vec<Vec<String>> = [["hello", "world"], ["hello", "rust"]]
        .iter()
        .map(|pair| pair.iter().map(|w| w.to_string()).collect())
        .collect();

    let (vocab, reverse_vocab) = build_vocab(&sentences);

    assert_eq!(vocab.len(), 3);
    assert_eq!(reverse_vocab.len(), 3);
    for (expected_id, word) in ["hello", "world", "rust"].iter().enumerate() {
        assert_eq!(vocab.get(*word), Some(&expected_id));
    }
}
89
#[test]
fn test_load_text_data() {
    let sentences = load_text_data("Hello world! This is a test.");

    // Two sentences, each tokenized and lowercased.
    assert_eq!(sentences.len(), 2);
    let (first, second) = (&sentences[0], &sentences[1]);
    assert_eq!(*first, vec!["hello", "world"]);
    assert_eq!(*second, vec!["this", "is", "a", "test"]);
}
99
100    fn make_test_data() -> TrainingData {
101        TrainingData::from_text("the cat sat on the mat. the dog sat on the log. the cat chased the dog.")
102    }
103
104    fn test_config(model_type: ModelType) -> TrainingConfig {
105        TrainingConfig::new(model_type)
106            .with_dim(8)
107            .with_learning_rate(0.1)
108            .with_epochs(2)
109            .with_batch_size(4)
110            .with_window(1)
111            .with_negative_samples(2)
112    }
113
#[test]
fn test_train_skipgram() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());

    let outcome = model.train(&data);
    assert!(outcome.is_ok());

    // Every sampled corpus word should now have a vector...
    for word in ["cat", "dog", "the"] {
        assert!(model.get_embedding(word, &data).is_some());
    }
    // ...and a pair of known words should yield a similarity score.
    assert!(model.similarity("cat", "dog", &data).is_some());
}
130
#[test]
fn test_train_cbow() {
    let corpus = make_test_data();
    let cfg = test_config(ModelType::Cbow);
    let mut model = EmbeddingModel::new(cfg, corpus.vocab.len());

    assert!(model.train(&corpus).is_ok());

    // After CBOW training, known words have vectors and a similarity score.
    for word in ["cat", "dog"] {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
    assert!(model.similarity("cat", "dog", &corpus).is_some());
}
143
#[test]
fn test_build_vocab_with_freq_counts_correctly() {
    let owned = |words: &[&str]| words.iter().map(|w| w.to_string()).collect::<Vec<_>>();
    let sentences = vec![owned(&["the", "cat", "sat"]), owned(&["the", "dog", "sat"])];

    let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);

    // Four distinct words; all three tables must agree on that size.
    for size in [vocab.len(), reverse_vocab.len(), word_freq.len()] {
        assert_eq!(size, 4);
    }

    // "the" and "sat" occur in both sentences; "cat" and "dog" only once each.
    let expected = [("the", 2), ("sat", 2), ("cat", 1), ("dog", 1)];
    for (word, count) in expected {
        assert_eq!(word_freq[vocab[word]], count);
    }
}
162
#[test]
fn test_unigram_negative_sampling_runs() {
    let corpus = make_test_data();
    let vocab_size = corpus.vocab.len();

    // Turn on the unigram negative-sampling option and train for two epochs.
    let cfg = test_config(ModelType::SkipGram)
        .with_unigram_negative_sampling(true)
        .with_epochs(2);

    let mut model = EmbeddingModel::new(cfg, vocab_size);
    let trained = model.train(&corpus).is_ok();
    assert!(trained);
}
172
#[test]
fn test_subsampling_runs() {
    let corpus = make_test_data();
    let vocab_size = corpus.vocab.len();

    // A very small subsampling threshold must not prevent training.
    let cfg = test_config(ModelType::SkipGram)
        .with_subsample_threshold(Some(1e-5))
        .with_epochs(2);

    let mut model = EmbeddingModel::new(cfg, vocab_size);
    let trained = model.train(&corpus).is_ok();
    assert!(trained);
}
182
#[test]
fn test_subsampling_drops_frequent_words() {
    let corpus = make_test_data();
    assert!(corpus.total_word_count() > 0);

    // A moderate threshold may discard frequent tokens such as "the" and "sat".
    // Only successful training is checked here, not which tokens were dropped.
    let cfg = TrainingConfig::new(ModelType::SkipGram)
        .with_dim(4)
        .with_epochs(1)
        .with_batch_size(2)
        .with_subsample_threshold(Some(1e-3));

    let mut model = EmbeddingModel::new(cfg, corpus.vocab.len());
    assert!(model.train(&corpus).is_ok());
}
198
#[test]
fn test_lr_warmup() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram)
        .with_warmup_epochs(Some(3))
        .with_epochs(5)
        .with_learning_rate(0.1);
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    assert!(model.train(&data).is_ok());

    // Sample the schedule for epochs 0..4 of a 5-epoch run.
    let rates: Vec<_> = (0..4).map(|epoch| model.get_learning_rate(epoch, 5)).collect();
    let (lr0, lr1, lr2, lr3) = (rates[0], rates[1], rates[2], rates[3]);

    assert!(lr0 < lr1 && lr1 < lr2, "LR should increase during warm-up");
    assert!(lr2 < lr3, "LR should reach base rate after warm-up");
    assert!(lr3 > 0.0, "LR after warm-up should be positive");
}
218
/// Saves a trained model to a checkpoint file and loads it back.
///
/// The restored model must keep the model type, vocabulary size and embedding
/// matrix shape.
#[test]
fn test_checkpoint_save_and_load() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram).with_epochs(2);
    let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
    model.train(&data).unwrap();

    let path = std::env::temp_dir().join("test_checkpoint.json");
    let path_str = path.to_str().unwrap();

    model.save_checkpoint(path_str, 2, 0.5).unwrap();
    let loaded = EmbeddingModel::load_checkpoint(path_str).unwrap();

    // Remove the file before asserting, so a failing assertion does not leave
    // it behind in the shared temp directory. Cleanup is best-effort.
    std::fs::remove_file(path_str).ok();

    assert_eq!(loaded.config.model_type, config.model_type);
    assert_eq!(loaded.vocab_size, model.vocab_size);
    assert_eq!(loaded.embeddings.shape(), model.embeddings.shape());
}
237
#[test]
fn test_parallel_training_skipgram() {
    let data = make_test_data();
    let parallel_cfg = test_config(ModelType::SkipGram)
        .with_parallel(true)
        .with_epochs(2);
    let mut model = EmbeddingModel::new(parallel_cfg, data.vocab.len());

    // The parallel code path must train and still produce word vectors.
    let result = model.train(&data);
    assert!(result.is_ok());
    assert!(model.get_embedding("cat", &data).is_some());
}
248
#[test]
fn test_parallel_training_cbow() {
    let data = make_test_data();
    let parallel_cfg = test_config(ModelType::Cbow)
        .with_parallel(true)
        .with_epochs(2);
    let mut model = EmbeddingModel::new(parallel_cfg, data.vocab.len());

    // The parallel code path must train and still produce word vectors.
    let result = model.train(&data);
    assert!(result.is_ok());
    assert!(model.get_embedding("dog", &data).is_some());
}
259
#[test]
fn test_save_embeddings() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    let out_path = std::env::temp_dir().join("test_embeddings_save.txt");
    let out = out_path.to_str().unwrap();

    assert!(model.save_embeddings(out, &data).is_ok());
    let written = std::fs::read_to_string(out).unwrap();
    // Both sample words should appear in the written file.
    for word in ["cat", "dog"] {
        assert!(written.contains(word));
    }

    std::fs::remove_file(out).ok();
}
278
#[test]
fn test_similarity_unknown_word() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    // An out-of-vocabulary word on either side yields no similarity.
    for (left, right) in [("cat", "nonexistent"), ("nonexistent", "dog")] {
        assert!(model.similarity(left, right, &data).is_none());
    }
}
289
#[test]
fn test_strip_html() {
    // Only HTML stripping is enabled; case and punctuation options stay off.
    let processor = TextProcessor {
        lowercase: false,
        remove_punctuation: false,
        remove_html: true,
        ..TextProcessor::default()
    };

    let sentences = processor.process_text("<p>Hello world!</p> This is a <b>test</b>.");

    assert_eq!(sentences.len(), 2);
    let [first, second] = [&sentences[0], &sentences[1]];
    assert_eq!(*first, vec!["Hello", "world"]);
    assert_eq!(*second, vec!["This", "is", "a", "test"]);
}
304
#[test]
fn test_strip_urls() {
    // URL removal, together with lowercasing and punctuation stripping.
    let processor = TextProcessor {
        lowercase: true,
        remove_punctuation: true,
        remove_urls: true,
        ..TextProcessor::default()
    };

    let input = "Visit https://example.com for info. See www.test.org too.";
    let sentences = processor.process_text(input);

    // Both the "https://" URL and the bare "www." URL are gone.
    assert_eq!(sentences.len(), 2);
    assert_eq!(sentences[0], ["visit", "for", "info"]);
    assert_eq!(sentences[1], ["see", "too"]);
}
319
#[test]
fn test_expand_contractions() {
    let processor = TextProcessor {
        lowercase: true,
        remove_punctuation: true,
        expand_contractions: true,
        ..TextProcessor::default()
    };

    let sentences = processor.process_text("I can't do this. It's a test.");
    assert_eq!(sentences.len(), 2);

    // Contractions are expanded ("can't" -> "cannot", "It's" -> "it is")
    // before punctuation is stripped.
    assert_eq!(sentences[0], ["i", "cannot", "do", "this"]);
    assert_eq!(sentences[1], ["it", "is", "a", "test"]);
}
335
#[test]
fn test_normalize_embeddings() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();
    model.normalize_embeddings();

    // Each row must have unit L2 norm; a zero norm is also accepted.
    let unit_or_zero = |norm: f32| (norm - 1.0).abs() < 1e-5 || norm == 0.0;
    let all_ok = model
        .embeddings
        .rows()
        .into_iter()
        .map(|row| row.iter().map(|&x| x * x).sum::<f32>().sqrt())
        .all(unit_or_zero);
    assert!(all_ok);
}
349
#[test]
fn test_analogy_unknown_word() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    // An out-of-vocabulary first term yields no analogy candidates.
    let answers = model.analogy("unknown", "cat", "dog", &data, 1);
    assert!(answers.is_empty());
}
359
#[test]
fn test_split_data() {
    // Ten one-word sentences: "a" through "j".
    let sentences: Vec<Vec<String>> = ('a'..='j').map(|c| vec![c.to_string()]).collect();
    let model = EmbeddingModel::new(test_config(ModelType::SkipGram), 1);

    // A 0.7 ratio keeps seven sentences for training and holds out three.
    let (train, val) = model.split_data(&sentences, 0.7);
    assert_eq!((train.len(), val.len()), (7, 3));
}
380
#[test]
fn test_gradient_clipping() {
    let data = make_test_data();
    let vocab_size = data.vocab.len();

    // A very aggressive clip value (0.001) must not break training.
    let mut config = test_config(ModelType::SkipGram);
    config.gradient_clip = Some(0.001);

    let mut model = EmbeddingModel::new(config, vocab_size);
    assert!(model.train(&data).is_ok());
    assert!(model.get_embedding("cat", &data).is_some());
}
392
#[test]
fn test_mini_batch_processing() {
    let data = make_test_data();

    // Batch size 1 (one example per update) and 8 (a real mini-batch) must both train.
    let mut trained = Vec::new();
    for batch_size in [1, 8] {
        let mut config = test_config(ModelType::SkipGram);
        config.batch_size = batch_size;
        let mut model = EmbeddingModel::new(config, data.vocab.len());
        assert!(model.train(&data).is_ok());
        trained.push(model);
    }

    // Every trained model should produce vectors for known words.
    for model in &trained {
        assert!(model.get_embedding("cat", &data).is_some());
    }
}
412
#[test]
fn test_empty_text() {
    // No input text means no sentences at all.
    assert!(load_text_data("").is_empty());
}
418
#[test]
fn test_single_word_text() {
    // A lone word without trailing punctuation still forms one sentence.
    let sentences = load_text_data("hello");
    assert_eq!(sentences, vec![vec!["hello"]]);
}
425
#[test]
fn test_learning_rate_schedules() {
    let data = make_test_data();

    // Each of the non-constant schedules must let training complete.
    let schedules = vec![
        LearningRateSchedule::Exponential { decay_rate: 0.9 },
        LearningRateSchedule::Step { step_size: 1, gamma: 0.5 },
        LearningRateSchedule::Cosine { t_max: 2 },
    ];

    for schedule in schedules {
        let mut config = test_config(ModelType::SkipGram);
        config.lr_schedule = schedule;
        let mut model = EmbeddingModel::new(config, data.vocab.len());
        assert!(model.train(&data).is_ok());
    }
}
445
#[test]
fn test_early_stopping() {
    let data = make_test_data();

    // Give training a ten-epoch budget plus an early-stopping rule
    // (patience 1, min_delta 0.001); training must still finish cleanly.
    let mut config = test_config(ModelType::SkipGram);
    config.epochs = 10;
    config.early_stopping = Some(EarlyStoppingConfig {
        patience: 1,
        min_delta: 0.001,
    });

    let mut model = EmbeddingModel::new(config, data.vocab.len());
    let result = model.train(&data);
    assert!(result.is_ok());
}
455
#[test]
fn test_word2vec_format_roundtrip() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    let w2v_path = std::env::temp_dir().join("test_word2vec.txt");
    let w2v = w2v_path.to_str().unwrap();

    // Write the Word2Vec text format, then read it back.
    assert!(model.save_word2vec_format(w2v, &data).is_ok());
    let (loaded, dim) = EmbeddingModel::load_word2vec_format(w2v).unwrap();

    // test_config uses 8-dimensional vectors.
    assert_eq!(dim, 8);
    for word in ["cat", "dog"] {
        let vector = loaded.get(word);
        assert!(vector.is_some());
        assert_eq!(vector.unwrap().len(), 8);
    }

    std::fs::remove_file(w2v).ok();
}
479
#[test]
fn test_bpe_tokenizer() {
    let corpus: Vec<String> = ["low", "lower", "lowest", "newer", "new", "widest", "wide"]
        .iter()
        .map(|s| s.to_string())
        .collect();

    // Train with a size/merge budget of 20.
    let tokenizer = BPETokenizer::train(&corpus, 20);
    // The vocabulary should have grown beyond the initial characters.
    assert!(tokenizer.vocab.len() >= 10);

    let tokens = tokenizer.encode("lowest");
    assert!(!tokens.is_empty());

    // Decoding drops the end-of-word marker and restores the original word.
    assert_eq!(tokenizer.decode(&tokens), "lowest");
}
505
#[test]
fn test_pretrained_embeddings_loading() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;

    // Deterministic fake value for component `i` of the word with id `id`.
    let fake_value = |id: usize, i: usize| (id * 10 + i) as f32 * 0.1;

    // Fake pre-trained file: a "<count> <dim>" header, then one
    // "<word> <v0> <v1> ..." line per vocabulary word.
    let mut contents = format!("{} {}\n", data.vocab.len(), dim);
    for (word_id, word) in data.reverse_vocab.iter().enumerate() {
        let vals: Vec<String> = (0..dim)
            .map(|i| format!("{:.6}", fake_value(word_id, i)))
            .collect();
        contents.push_str(&format!("{} {}\n", word, vals.join(" ")));
    }

    let temp_path = std::env::temp_dir().join("test_pretrained.txt");
    let path_str = temp_path.to_str().unwrap();
    std::fs::write(path_str, contents).unwrap();

    let model = EmbeddingModel::new_with_pretrained(config, data.vocab.len(), &data, path_str);
    assert!(model.is_ok());
    let model = model.unwrap();

    // The "cat" row must reproduce the values written for its id.
    let cat_id = *data.vocab.get("cat").unwrap();
    let cat_emb = model.get_embedding("cat", &data).unwrap();
    for (i, &val) in cat_emb.iter().enumerate() {
        let expected = fake_value(cat_id, i);
        assert!((val - expected).abs() < 1e-5, "Mismatch at index {}: got {}, expected {}", i, val, expected);
    }

    std::fs::remove_file(path_str).ok();
}
546
#[test]
fn test_semantic_search() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    let hits = model.semantic_search("cat", &data, 5);
    assert!(!hits.is_empty());
    // The query word itself must not be among its own results.
    assert!(hits.iter().all(|(word, _)| word != "cat"));
}
561
#[test]
fn test_embedding_arithmetic() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    // Combining two known vectors gives a vector of the model's dimensionality.
    let combined = model.embedding_arithmetic("cat", "dog", &data);
    assert!(combined.is_some());
    assert_eq!(combined.map(|v| v.len()), Some(dim));
}
573
#[test]
fn test_interpolate_embeddings() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    // The halfway point between two known words keeps the model's dimensionality.
    let midpoint = model.interpolate_embeddings("cat", "dog", &data, 0.5);
    assert!(midpoint.is_some());
    assert_eq!(midpoint.map(|v| v.len()), Some(dim));
}
585
#[test]
fn test_save_numpy_format() {
    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    let npy_path = std::env::temp_dir().join("test_numpy.npy");
    let npy = npy_path.to_str().unwrap();
    assert!(model.save_numpy_format(npy, &data).is_ok());

    // The file must exist and be non-empty.
    let size = std::fs::metadata(npy).unwrap().len();
    assert!(size > 0);

    std::fs::remove_file(npy).ok();
}
604
#[test]
fn test_stream_sentences() {
    let temp_path = std::env::temp_dir().join("test_stream.txt");
    let path_str = temp_path.to_str().unwrap();

    // Three sentences, one per newline-terminated line.
    let lines = [
        "the cat sat on the mat.",
        "the dog sat on the log.",
        "the cat chased the dog.",
    ];
    std::fs::write(path_str, lines.join("\n") + "\n").unwrap();

    let loader = DataLoader::new(4, false);
    let sentences: Vec<Vec<String>> = loader.stream_sentences(path_str).unwrap().collect();

    // Streaming must yield sentences, and none of them may be empty.
    assert!(!sentences.is_empty());
    assert!(!sentences.iter().any(|s| s.is_empty()));

    std::fs::remove_file(path_str).ok();
}
626
#[test]
fn test_incremental_vocab_update() {
    let mut data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    let (vocab_before, rows_before) = (data.vocab.len(), model.embeddings.nrows());

    let new_words: Vec<String> = ["elephant", "giraffe"].iter().map(|w| w.to_string()).collect();
    let added = model.incremental_vocab_update(&new_words, &mut data).unwrap();

    // Both words are new, so the vocabulary and the embedding matrix each grow by two.
    assert_eq!(added.len(), 2);
    assert_eq!(data.vocab.len(), vocab_before + 2);
    assert_eq!(model.embeddings.nrows(), rows_before + 2);
    assert_eq!(model.embeddings.ncols(), dim);

    // New and pre-existing words must both be retrievable.
    for word in ["elephant", "giraffe", "cat"] {
        assert!(model.get_embedding(word, &data).is_some());
    }
}
651
#[test]
fn test_lsh_index() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    let mut index = LSHIndex::new(4, dim);
    index.build(&model, &data);

    // On such a tiny vocabulary the index may legitimately return nothing.
    // The test only requires that querying does not panic and never returns
    // the query word itself.
    let neighbours = index.query("cat", &model, &data, 5);
    assert!(neighbours.iter().all(|(word, _)| word != "cat"));
}
670
#[test]
fn test_save_onnx_format() {
    use prost::Message;

    let data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    let onnx_path = std::env::temp_dir().join("test_model.onnx");
    let onnx_file = onnx_path.to_str().unwrap();
    assert!(model.save_onnx_format(onnx_file, &data).is_ok());

    // The file should hold at least a protobuf header plus data.
    let bytes = std::fs::read(onnx_file).unwrap();
    assert!(bytes.len() > 50);

    // It must decode as a ModelProto with the expected metadata...
    let decoded = onnx::ModelProto::decode(&bytes[..]);
    assert!(decoded.is_ok());
    let onnx::ModelProto { ir_version, producer_name, graph, .. } = decoded.unwrap();
    assert_eq!(ir_version, 9);
    assert_eq!(producer_name, "embedding-trainer");

    // ...and a graph with a single Gather node and a single initializer.
    let graph = graph.unwrap();
    assert_eq!(graph.name, "embedding_graph");
    assert_eq!(graph.node.len(), 1);
    assert_eq!(graph.node[0].op_type, "Gather");
    assert_eq!(graph.initializer.len(), 1);

    std::fs::remove_file(onnx_file).ok();
}
705
#[test]
fn test_sentence_embedding() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    let sentence: Vec<String> = ["the", "cat", "sat"].iter().map(|w| w.to_string()).collect();
    let embedding = model.sentence_embedding(&sentence, &data);
    assert!(embedding.is_some());
    assert_eq!(embedding.unwrap().len(), dim);

    // An empty sentence has no embedding.
    assert!(model.sentence_embedding(&[], &data).is_none());
}
722
#[test]
fn test_multimodal_fusion() {
    let text = Array::from_vec(vec![1.0, 2.0, 3.0]);
    let aux = Array::from_vec(vec![4.0, 5.0, 6.0]);
    let fusion = MultimodalFusion::new(3, 3, 3);

    // Concatenation places the two vectors end to end.
    let joined = fusion.concatenate(&text, &aux);
    assert_eq!(joined.len(), 6);
    assert_eq!(joined[0], 1.0);
    assert_eq!(joined[5], 6.0);

    // A weighted average of equal-length vectors...
    let averaged = fusion.weighted_average(&text, &aux, 0.5).unwrap();
    assert_eq!(averaged.len(), 3);
    assert!((averaged[0] - 2.5).abs() < 1e-6);
    // ...is refused when the lengths differ.
    let shorter = Array::from_vec(vec![1.0, 2.0]);
    assert!(fusion.weighted_average(&text, &shorter, 0.5).is_none());

    let attended = fusion.attention_fusion(&text, &aux).unwrap();
    assert_eq!(attended.len(), 3);

    // The cross-modal similarity score must lie in [-1, 1].
    let sim = MultimodalFusion::cross_modal_similarity(&text, &aux);
    assert!((-1.0..=1.0).contains(&sim));
}
752
#[test]
fn test_cross_lingual_aligner() {
    let mut aligner = CrossLingualAligner::new(4);

    // Before training, the projection leaves vectors unchanged.
    let probe = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let projected = aligner.align(&probe);
    assert_eq!(projected, probe);

    // Fit one synthetic pair in which the target is the source doubled.
    let source = Array::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
    let target = Array::from_vec(vec![2.0, 0.0, 0.0, 0.0]);
    aligner.train_from_dictionary(&[(source, target)], 100, 0.1);

    // The learned projection should map [1,0,0,0] close to [2,0,0,0].
    let unit = Array::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
    let result = aligner.align(&unit);
    assert!((result[0] - 2.0).abs() < 0.1, "Expected ~2.0, got {}", result[0]);
}
773
/// Adapting a trained model to a small domain corpus must succeed.
///
/// Afterwards every word of that corpus must be in the vocabulary, including
/// "a", which does not occur in the base corpus from `make_test_data`.
#[test]
fn test_domain_adapter() {
    let mut data = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    model.train(&data).unwrap();

    // Precondition: "a" is genuinely new, so the vocabulary check below means something.
    assert!(!data.vocab.contains_key("a"));

    let domain_sentences = vec![
        vec!["the".to_string(), "cat".to_string()],
        vec!["a".to_string(), "dog".to_string()],
    ];
    assert!(DomainAdapter::adapt(&mut model, &mut data, &domain_sentences, 1).is_ok());

    // "cat" is already in the base corpus, so checking it alone proves nothing.
    // Check every domain word instead; this covers the newly introduced "a".
    for word in domain_sentences.iter().flatten() {
        assert!(data.vocab.contains_key(word), "domain word {:?} missing from vocab", word);
    }
}
789
#[test]
fn test_document_embedder() {
    let data = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, data.vocab.len());
    model.train(&data).unwrap();

    let document: Vec<Vec<String>> = [["the", "cat"], ["a", "dog"]]
        .iter()
        .map(|s| s.iter().map(|w| w.to_string()).collect())
        .collect();
    let doc_vec = DocumentEmbedder::embed_document(&model, &data, &document);
    assert!(doc_vec.is_some());
    assert_eq!(doc_vec.unwrap().len(), dim);

    // A document with no sentences has no embedding.
    assert!(DocumentEmbedder::embed_document(&model, &data, &[]).is_none());
}
807
#[test]
fn test_zero_shot_transfer() {
    // One prototype per class, along two different axes.
    let prototypes: HashMap<String, _> = vec![
        ("class_a", Array::from_vec(vec![1.0, 0.0, 0.0])),
        ("class_b", Array::from_vec(vec![0.0, 1.0, 0.0])),
    ]
    .into_iter()
    .map(|(label, proto)| (label.to_string(), proto))
    .collect();

    // The query points almost entirely along class_a's axis.
    let query = Array::from_vec(vec![0.9, 0.1, 0.0]);
    let result = ZeroShotTransfer::classify(&query, &prototypes);
    assert!(result.is_some());
    if let Some((label, sim)) = result {
        assert_eq!(label, "class_a");
        assert!(sim > 0.9);
    }
}
823
#[test]
fn test_query_expander() {
    let data = make_test_data();
    let vocab_size = data.vocab.len();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), vocab_size);
    model.train(&data).unwrap();

    // Expansion keeps the original query term in first position.
    let expanded = QueryExpander::expand(&model, &data, "cat", 3);
    assert!(!expanded.is_empty());
    assert_eq!(expanded[0], "cat");
}
835
836    #[test]
837    fn test_hierarchical_clustering() {
838        let data = make_test_data();
839        let config = test_config(ModelType::SkipGram);
840        let mut model = EmbeddingModel::new(config, data.vocab.len());
841        model.train(&data).unwrap();
842
843        let clusters = HierarchicalClustering::cluster(&model, &data, 2);
844        assert_eq!(clusters.len(), 2);
845        // Every vocab word should belong to exactly one cluster
846        let mut all_words = std::collections::HashSet::new();
847        for c in &clusters {
848            for word in c {
849                assert!(all_words.insert(word.clone()));
850            }
851        }
852        assert_eq!(all_words.len(), data.vocab.len());
853    }
854
855    #[test]
856    fn test_unicode_normalization() {
857        let processor = TextProcessor {
858            lowercase: true,
859            remove_punctuation: false,
860            remove_numbers: false,
861            remove_stop_words: false,
862            remove_html: false,
863            remove_urls: false,
864            expand_contractions: false,
865            normalize_unicode: false,
866            language: "en".to_string(),
867        };
868        // e with combining acute (U+0065 U+0301) should match precomposed e-acute (U+00E9)
869        let text = "caf\u{0065}\u{0301}";
870        let sentences = processor.process_text(text);
871        assert_eq!(sentences.len(), 1);
872        assert_eq!(sentences[0].len(), 1);
873        // After NFC normalization it should be "café"
874        assert_eq!(sentences[0][0], "caf\u{00e9}");
875    }
876
877    #[test]
878    fn test_code_embedding_pipeline() {
879        let code = r#"
880            fn computeEmbeddingVector(input: Vec<f32>) -> Vec<f32> {
881                let result = vec![];
882                for x in input {
883                    result.push(x * 2.0);
884                }
885                result
886            }
887        "#;
888        let sentences = load_code_data(code);
889        assert!(!sentences.is_empty());
890
891        let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
892        assert!(vocab.contains_key("compute"));
893        assert!(vocab.contains_key("embedding"));
894        assert!(vocab.contains_key("vector"));
895        assert!(vocab.contains_key("result"));
896
897        let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
898        let config = test_config(ModelType::SkipGram);
899        let mut model = EmbeddingModel::new(config, data.vocab.len());
900        assert!(model.train(&data).is_ok());
901
902        // Check that code tokens have embeddings
903        assert!(model.get_embedding("embedding", &data).is_some());
904        assert!(model.get_embedding("vector", &data).is_some());
905    }
906
907    #[test]
908    fn test_western_language_embedding_pipeline() {
909        // French text with diacritics
910        let text = "Le chat noir dort sur le tapis. Le chien brun joue dans le jardin.";
911        let sentences = load_text_data(text);
912        assert!(!sentences.is_empty());
913
914        let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
915        // Verify French words are preserved including accented characters
916        assert!(vocab.contains_key("chat"));
917        assert!(vocab.contains_key("noir"));
918        assert!(vocab.contains_key("dort"));
919        assert!(vocab.contains_key("chien"));
920        assert!(vocab.contains_key("jardin"));
921
922        let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
923        let config = test_config(ModelType::SkipGram);
924        let mut model = EmbeddingModel::new(config, data.vocab.len());
925        assert!(model.train(&data).is_ok());
926
927        // French words should have embeddings
928        assert!(model.get_embedding("chat", &data).is_some());
929        assert!(model.get_embedding("chien", &data).is_some());
930        assert!(model.get_embedding("jardin", &data).is_some());
931    }
932
933    #[test]
934    fn test_chinese_embedding_pipeline() {
935        // Chinese text
936        let text = "猫坐在垫子上。狗在花园里玩。猫追狗。";
937        let sentences = load_text_data(text);
938        assert!(!sentences.is_empty());
939
940        let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
941        // Verify Chinese characters are tokenized individually
942        assert!(vocab.contains_key("猫"));
943        assert!(vocab.contains_key("坐"));
944        assert!(vocab.contains_key("狗"));
945        assert!(vocab.contains_key("花"));
946        assert!(vocab.contains_key("追"));
947
948        let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
949        let config = test_config(ModelType::SkipGram);
950        let mut model = EmbeddingModel::new(config, data.vocab.len());
951        assert!(model.train(&data).is_ok());
952
953        // Chinese characters should have embeddings
954        assert!(model.get_embedding("猫", &data).is_some());
955        assert!(model.get_embedding("狗", &data).is_some());
956        assert!(model.get_embedding("追", &data).is_some());
957    }
958
959    #[test]
960    fn test_japanese_embedding_pipeline() {
961        // Japanese text with hiragana and kanji
962        let text = "猫はマットの上に座っています。犬は庭で遊んでいます。";
963        let sentences = load_text_data(text);
964        assert!(!sentences.is_empty());
965
966        let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
967        // Verify Japanese characters are tokenized
968        assert!(vocab.contains_key("猫"));
969        assert!(vocab.contains_key("座"));
970        assert!(vocab.contains_key("犬"));
971        assert!(vocab.contains_key("遊"));
972
973        let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
974        let config = test_config(ModelType::SkipGram);
975        let mut model = EmbeddingModel::new(config, data.vocab.len());
976        assert!(model.train(&data).is_ok());
977
978        assert!(model.get_embedding("猫", &data).is_some());
979        assert!(model.get_embedding("犬", &data).is_some());
980    }
981
982    #[test]
983    fn test_subword_embedder() {
984        let embedder = SubwordEmbedder::new(3, 5);
985        let ngrams = embedder.ngrams("apple");
986        assert!(!ngrams.is_empty());
987        assert!(ngrams.contains(&"<ap".to_string()));
988        assert!(ngrams.contains(&"ple>".to_string()));
989
990        let mut vectors = HashMap::new();
991        vectors.insert("<ap".to_string(), Array::from_vec(vec![1.0, 0.0]));
992        vectors.insert("app".to_string(), Array::from_vec(vec![0.0, 1.0]));
993        vectors.insert("ppl".to_string(), Array::from_vec(vec![1.0, 1.0]));
994        vectors.insert("ple>".to_string(), Array::from_vec(vec![0.5, 0.5]));
995
996        let emb = embedder.embed("apple", &vectors);
997        assert!(emb.is_some());
998        let emb = emb.unwrap();
999        assert_eq!(emb.len(), 2);
1000    }
1001
1002    #[test]
1003    fn test_create_validation_data() {
1004        let data = make_test_data();
1005        let config = test_config(ModelType::SkipGram);
1006        let model = EmbeddingModel::new(config, data.vocab.len());
1007        let val_data = model.create_validation_data(&data.sentences);
1008        assert!(!val_data.positive_pairs.is_empty());
1009        assert!(!val_data.negative_pairs.is_empty());
1010    }
1011
1012    #[test]
1013    fn test_evaluate_produces_metrics() {
1014        let data = make_test_data();
1015        let config = test_config(ModelType::SkipGram);
1016        let mut model = EmbeddingModel::new(config, data.vocab.len());
1017        model.train(&data).unwrap();
1018
1019        let val_data = model.create_validation_data(&data.sentences);
1020        let metrics = model.evaluate(&data, &val_data);
1021        assert!(metrics.accuracy >= 0.0 && metrics.accuracy <= 1.0);
1022        assert!(metrics.precision >= 0.0 && metrics.precision <= 1.0);
1023        assert!(metrics.recall >= 0.0 && metrics.recall <= 1.0);
1024        assert!(metrics.f1_score >= 0.0 && metrics.f1_score <= 1.0);
1025        assert!(metrics.mean_similarity >= -1.0 && metrics.mean_similarity <= 1.0);
1026        assert!(metrics.embedding_quality_score >= 0.0 && metrics.embedding_quality_score <= 1.0);
1027    }
1028
1029    #[test]
1030    fn test_train_with_validation_split() {
1031        let data = make_test_data();
1032        let mut config = test_config(ModelType::SkipGram);
1033        config.validation_ratio = Some(0.3);
1034        let mut model = EmbeddingModel::new(config, data.vocab.len());
1035        assert!(model.train(&data).is_ok());
1036        assert!(model.get_embedding("cat", &data).is_some());
1037    }
1038
1039    #[test]
1040    fn test_cross_validation_basic() {
1041        let data = make_test_data();
1042        let config = test_config(ModelType::SkipGram);
1043        let model = EmbeddingModel::new(config, data.vocab.len());
1044
1045        let result = model.cross_validate(&data, 3).unwrap();
1046        assert_eq!(result.folds, 3);
1047        assert_eq!(result.per_fold_metrics.len(), 3);
1048
1049        // Averaged metrics should be within valid ranges
1050        assert!(result.averaged_metrics.accuracy >= 0.0 && result.averaged_metrics.accuracy <= 1.0);
1051        assert!(result.averaged_metrics.f1_score >= 0.0 && result.averaged_metrics.f1_score <= 1.0);
1052    }
1053
1054    #[test]
1055    fn test_cross_validation_invalid_k() {
1056        let data = make_test_data();
1057        let config = test_config(ModelType::SkipGram);
1058        let model = EmbeddingModel::new(config, data.vocab.len());
1059
1060        assert!(model.cross_validate(&data, 0).is_err());
1061        assert!(model.cross_validate(&data, 1).is_err());
1062        assert!(model.cross_validate(&data, 100).is_err());
1063    }
1064
1065    #[test]
1066    fn test_cross_validation_k_equals_2() {
1067        let data = make_test_data();
1068        let config = test_config(ModelType::SkipGram);
1069        let model = EmbeddingModel::new(config, data.vocab.len());
1070
1071        let result = model.cross_validate(&data, 2).unwrap();
1072        assert_eq!(result.folds, 2);
1073        assert_eq!(result.per_fold_metrics.len(), 2);
1074    }
1075
1076    #[test]
1077    fn test_l2_normalize_embeddings() {
1078        let data = make_test_data();
1079        let config = test_config(ModelType::SkipGram);
1080        let mut model = EmbeddingModel::new(config, data.vocab.len());
1081        model.train(&data).unwrap();
1082
1083        model.normalize_embeddings();
1084
1085        // Verify all embeddings are unit length
1086        for row in model.embeddings.rows() {
1087            let norm = row.iter().map(|&x| x * x).sum::<f32>().sqrt();
1088            assert!((norm - 1.0).abs() < 1e-5, "Expected unit norm, got {}", norm);
1089        }
1090    }
1091
1092    #[test]
1093    fn test_cross_validation_cbow() {
1094        let data = make_test_data();
1095        let config = test_config(ModelType::Cbow);
1096        let model = EmbeddingModel::new(config, data.vocab.len());
1097
1098        let result = model.cross_validate(&data, 2).unwrap();
1099        assert_eq!(result.folds, 2);
1100        assert!(result.averaged_metrics.accuracy >= 0.0);
1101    }
1102
1103    #[test]
1104    fn test_training_history_records_epochs() {
1105        let data = make_test_data();
1106        let config = test_config(ModelType::SkipGram);
1107        let mut model = EmbeddingModel::new(config, data.vocab.len());
1108        model.train(&data).unwrap();
1109
1110        assert!(!model.training_history.epochs.is_empty());
1111        let first = &model.training_history.epochs[0];
1112        assert!(first.loss >= 0.0);
1113        assert!(first.learning_rate > 0.0);
1114
1115        let json = model.training_history.to_json().unwrap();
1116        assert!(json.contains("loss"));
1117        assert!(json.contains("learning_rate"));
1118    }
1119
1120    #[test]
1121    fn test_wordpiece_tokenizer_train_encode_decode() {
1122        let corpus = vec![
1123            "hello".to_string(),
1124            "world".to_string(),
1125            "hello".to_string(),
1126            "world".to_string(),
1127        ];
1128        let tokenizer = tokenizer::WordPieceTokenizer::train(&corpus, 50);
1129        assert!(tokenizer.vocab_size > 0);
1130
1131        let tokens = tokenizer.encode("hello");
1132        assert!(!tokens.is_empty());
1133
1134        let decoded = tokenizer.decode(&tokens);
1135        assert_eq!(decoded, "hello");
1136    }
1137
1138    #[test]
1139    fn test_cpu_backend() {
1140        let backend = backend::CpuBackend::new();
1141        assert_eq!(backend.name(), "cpu");
1142
1143        let emb = backend.init_embeddings(10, 8);
1144        assert_eq!(emb.nrows(), 10);
1145        assert_eq!(emb.ncols(), 8);
1146
1147        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
1148        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
1149        assert!((backend.dot(&a, &b) - 32.0).abs() < 1e-5);
1150
1151        let mut c = a.clone();
1152        backend.add_scaled(&mut c, &b, 2.0);
1153        assert_eq!(c.to_vec(), vec![9.0, 12.0, 15.0]);
1154
1155        let m1 = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1156        let m2 = Array2::from_shape_vec((3, 2), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
1157        let result = backend.matmul(&m1, &m2);
1158        assert_eq!(result.nrows(), 2);
1159        assert_eq!(result.ncols(), 2);
1160    }
1161
1162    #[test]
1163    fn test_best_backend_returns_cpu() {
1164        let backend = backend::best_backend();
1165        assert!(!backend.name().is_empty());
1166    }
1167
1168    #[test]
1169    fn test_benchmark_load_and_evaluate() {
1170        let tsv = "cat\tdog\t0.8\ncat\tmat\t0.2\ndog\tmat\t0.1\n";
1171        let pairs = benchmark::BenchmarkEvaluator::load_from_tsv(tsv);
1172        assert_eq!(pairs.len(), 3);
1173        assert_eq!(pairs[0].word1, "cat");
1174
1175        let data = make_test_data();
1176        let config = test_config(ModelType::SkipGram);
1177        let mut model = EmbeddingModel::new(config, data.vocab.len());
1178        model.train(&data).unwrap();
1179
1180        let result = benchmark::BenchmarkEvaluator::evaluate(&model, &data, &pairs);
1181        assert_eq!(result.num_pairs, 3);
1182        // Some words may be OOV so num_evaluated <= 3
1183        assert!(result.num_evaluated <= 3);
1184        // Correlation is between -1 and 1
1185        assert!(result.correlation >= -1.0 && result.correlation <= 1.0);
1186    }
1187
1188    #[test]
1189    fn test_kmeans_clustering() {
1190        let data = make_test_data();
1191        let config = test_config(ModelType::SkipGram);
1192        let mut model = EmbeddingModel::new(config, data.vocab.len());
1193        model.train(&data).unwrap();
1194
1195        let clusters = search::KMeansClustering::cluster(&model, &data, 3, 20);
1196        assert!(!clusters.is_empty());
1197        assert!(clusters.len() <= 3);
1198
1199        let total_words: usize = clusters.iter().map(|c| c.len()).sum();
1200        assert_eq!(total_words, data.vocab.len());
1201    }
1202
1203    #[test]
1204    fn test_kmeans_clustering_k_greater_than_vocab() {
1205        let data = make_test_data();
1206        let config = test_config(ModelType::SkipGram);
1207        let mut model = EmbeddingModel::new(config, data.vocab.len());
1208        model.train(&data).unwrap();
1209
1210        // k larger than vocab should clamp to vocab size
1211        let clusters = search::KMeansClustering::cluster(&model, &data, 100, 10);
1212        assert_eq!(clusters.len(), data.vocab.len());
1213    }
1214
1215    #[test]
1216    fn test_transformer_encoder() {
1217        let encoder = TransformerEncoder::new(2, 2, 8, 16, 10);
1218        let tokens = ndarray::Array2::zeros((3, 8));
1219        let encoded = encoder.encode_sequence(&tokens);
1220        assert_eq!(encoded.nrows(), 3);
1221        assert_eq!(encoded.ncols(), 8);
1222    }
1223
1224    #[test]
1225    fn test_incremental_trainer() {
1226        let mut data = make_test_data();
1227        let config = test_config(ModelType::SkipGram);
1228        let mut model = EmbeddingModel::new(config, data.vocab.len());
1229        model.train(&data).unwrap();
1230
1231        let original_vocab = data.vocab.len();
1232        let new_sentences = vec![vec!["newword".to_string(), "cat".to_string()]];
1233
1234        IncrementalTrainer::update(&mut model, &mut data, &new_sentences, 1).unwrap();
1235
1236        // Vocabulary should have grown
1237        assert!(data.vocab.len() >= original_vocab);
1238        // Model should now know the new word
1239        assert!(data.vocab.contains_key("newword"));
1240    }
1241
1242    #[test]
1243    fn test_incremental_stream_train() {
1244        let mut data = make_test_data();
1245        let config = test_config(ModelType::SkipGram);
1246        let mut model = EmbeddingModel::new(config, data.vocab.len());
1247        model.train(&data).unwrap();
1248
1249        let sentences = vec![
1250            vec!["stream".to_string(), "word".to_string()],
1251            vec!["another".to_string(), "stream".to_string()],
1252        ];
1253
1254        IncrementalTrainer::stream_train(
1255            &mut model,
1256            &mut data,
1257            sentences.into_iter(),
1258            1,
1259            1,
1260        )
1261        .unwrap();
1262
1263        assert!(data.vocab.contains_key("stream"));
1264    }
1265
1266    #[test]
1267    fn test_mmap_embeddings_roundtrip() {
1268        let data = make_test_data();
1269        let config = test_config(ModelType::SkipGram);
1270        let mut model = EmbeddingModel::new(config, data.vocab.len());
1271        model.train(&data).unwrap();
1272
1273        let temp_path = std::env::temp_dir().join("test_mmap.bin");
1274        let path_str = temp_path.to_str().unwrap();
1275
1276        // Save in mmapable format
1277        model.save_mmapable_format(path_str, &data).unwrap();
1278
1279        // Load via mmap
1280        let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
1281        assert_eq!(mmap.vocab_size(), data.vocab.len());
1282        assert_eq!(mmap.dim(), model.config.embedding_dim);
1283
1284        // Verify a known word
1285        let cat_emb = mmap.get("cat").unwrap();
1286        assert_eq!(cat_emb.len(), model.config.embedding_dim);
1287
1288        // Verify it matches the in-memory embedding
1289        let cat_id = data.vocab["cat"];
1290        let model_cat: Vec<f32> = model.embeddings.row(cat_id).to_vec();
1291        assert_eq!(cat_emb, model_cat.as_slice());
1292
1293        std::fs::remove_file(path_str).ok();
1294    }
1295
1296    #[test]
1297    fn test_mmap_embeddings_iter() {
1298        let data = make_test_data();
1299        let config = test_config(ModelType::SkipGram);
1300        let mut model = EmbeddingModel::new(config, data.vocab.len());
1301        model.train(&data).unwrap();
1302
1303        let temp_path = std::env::temp_dir().join("test_mmap_iter.bin");
1304        let path_str = temp_path.to_str().unwrap();
1305
1306        model.save_mmapable_format(path_str, &data).unwrap();
1307        let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
1308
1309        let mut count = 0;
1310        for (word, emb) in mmap.iter() {
1311            assert!(!word.is_empty());
1312            assert_eq!(emb.len(), model.config.embedding_dim);
1313            count += 1;
1314        }
1315        assert_eq!(count, data.vocab.len());
1316
1317        std::fs::remove_file(path_str).ok();
1318    }
1319
1320    #[test]
1321    fn test_mmap_embeddings_missing_word() {
1322        let data = make_test_data();
1323        let config = test_config(ModelType::SkipGram);
1324        let mut model = EmbeddingModel::new(config, data.vocab.len());
1325        model.train(&data).unwrap();
1326
1327        let temp_path = std::env::temp_dir().join("test_mmap_missing.bin");
1328        let path_str = temp_path.to_str().unwrap();
1329
1330        model.save_mmapable_format(path_str, &data).unwrap();
1331        let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
1332
1333        assert!(mmap.get("nonexistent_word").is_none());
1334        assert!(mmap.get("cat").is_some());
1335
1336        std::fs::remove_file(path_str).ok();
1337    }
1338
1339    #[test]
1340    fn test_pretrained_loader_word2vec_text() {
1341        let temp = std::env::temp_dir().join("test_pretrained_w2v.txt");
1342        let path = temp.to_str().unwrap();
1343
1344        let content = "3 4\ncat 0.1 0.2 0.3 0.4\ndog 0.5 0.6 0.7 0.8\nfish 0.9 0.0 0.1 0.2\n";
1345        std::fs::write(path, content).unwrap();
1346
1347        let emb = PretrainedLoader::auto(path).unwrap();
1348        assert_eq!(emb.dim(), 4);
1349        assert_eq!(emb.vocab_size(), 3);
1350        assert!(emb.contains("cat"));
1351        assert!(emb.contains("dog"));
1352        assert!(!emb.contains("elephant"));
1353
1354        let cat = emb.get("cat").unwrap();
1355        assert_eq!(cat.len(), 4);
1356        assert!((cat[0] - 0.1).abs() < 1e-6);
1357
1358        std::fs::remove_file(path).ok();
1359    }
1360
1361    #[test]
1362    fn test_pretrained_embeddings_similarity() {
1363        let mut emb = PretrainedEmbeddings::new(3);
1364        emb.insert("a".to_string(), vec![1.0, 0.0, 0.0]);
1365        emb.insert("b".to_string(), vec![0.0, 1.0, 0.0]);
1366        emb.insert("c".to_string(), vec![1.0, 0.0, 0.0]);
1367
1368        // a and b are orthogonal
1369        let sim_ab = emb.similarity("a", "b").unwrap();
1370        assert!(sim_ab.abs() < 1e-5, "Orthogonal vectors should have ~0 similarity");
1371
1372        // a and c are identical
1373        let sim_ac = emb.similarity("a", "c").unwrap();
1374        assert!((sim_ac - 1.0).abs() < 1e-5, "Identical vectors should have similarity ~1");
1375
1376        // Missing word
1377        assert!(emb.similarity("a", "missing").is_none());
1378    }
1379
1380    #[test]
1381    fn test_pretrained_embeddings_most_similar() {
1382        let mut emb = PretrainedEmbeddings::new(2);
1383        emb.insert("king".to_string(),   vec![1.0, 0.0]);
1384        emb.insert("queen".to_string(),  vec![0.9, 0.1]);
1385        emb.insert("man".to_string(),     vec![0.1, 1.0]);
1386        emb.insert("woman".to_string(),  vec![0.2, 0.9]);
1387
1388        let similar = emb.most_similar("king", 2);
1389        assert_eq!(similar.len(), 2);
1390        assert_eq!(similar[0].0, "queen"); // closest to king
1391    }
1392
1393    #[test]
1394    fn test_pretrained_loader_glove_format() {
1395        let temp = std::env::temp_dir().join("test_pretrained_glove.txt");
1396        let path = temp.to_str().unwrap();
1397
1398        let content = "2 3\nhello 0.1 0.2 0.3\nworld 0.4 0.5 0.6\n";
1399        std::fs::write(path, content).unwrap();
1400
1401        let emb = PretrainedLoader::with_format(path, pretrained::PretrainedFormat::GloVe).unwrap();
1402        assert_eq!(emb.dim(), 3);
1403        assert_eq!(emb.vocab_size(), 2);
1404
1405        let hello = emb.get("hello").unwrap();
1406        assert_eq!(hello, &[0.1, 0.2, 0.3]);
1407
1408        std::fs::remove_file(path).ok();
1409    }
1410
1411    #[test]
1412    fn test_pretrained_init_model_from_pretrained() {
1413        let temp = std::env::temp_dir().join("test_pretrained_init.txt");
1414        let path = temp.to_str().unwrap();
1415
1416        // Pre-trained file with some of our test vocab
1417        let content = "3 8\ncat 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1\ndog 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2\nthe 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3\n";
1418        std::fs::write(path, content).unwrap();
1419
1420        let data = make_test_data();
1421        let config = test_config(ModelType::SkipGram);
1422        let model = EmbeddingModel::new_with_pretrained(config, data.vocab.len(), &data, path).unwrap();
1423
1424        // Words in the pretrained file should have those exact values
1425        let cat_id = data.vocab["cat"];
1426        let cat_emb = model.embeddings.row(cat_id);
1427        for &v in cat_emb.iter() {
1428            assert!((v - 0.1).abs() < 1e-5);
1429        }
1430
1431        std::fs::remove_file(path).ok();
1432    }
1433}