/// ONNX message types, textually included from a file in the build output directory.
pub mod onnx {
    // `onnx.rs` is read from `$OUT_DIR` at compile time; presumably it is
    // generated by a build script from the ONNX protobuf schema — TODO confirm.
    include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
34
// First group of public modules; each one is glob re-exported at the crate root below.
pub mod config;
pub mod evaluation;
pub mod search;
pub mod code;
pub mod text;
pub mod tokenizer;
pub mod transfer;
pub use config::*;
pub use evaluation::*;
pub use search::*;
pub use code::*;
pub use text::*;
pub use tokenizer::*;
pub use transfer::*;

// Second group of modules. `training`, `export` and `commands` are private to the crate.
pub mod model;
pub mod backend;
pub mod benchmark;
pub mod transformer;
pub mod mmap;
pub mod pretrained;
mod training;
mod export;
pub mod cli;
mod commands;
// Glob re-exports for `model`, `backend`, `benchmark` and `transformer`;
// `mmap`, `pretrained` and the private `training` module expose only the named items.
pub use model::*;
pub use backend::*;
pub use benchmark::*;
pub use transformer::*;
pub use mmap::MmapEmbeddings;
pub use pretrained::{PretrainedEmbeddings, PretrainedLoader};
pub use training::IncrementalTrainer;
67
68#[cfg(test)]
69mod tests {
70 use super::*;
71 use ndarray::{Array, Array1, Array2};
72 use std::collections::HashMap;
73
74 #[test]
75 fn test_build_vocab() {
76 let sentences = vec![
77 vec!["hello".to_string(), "world".to_string()],
78 vec!["hello".to_string(), "rust".to_string()],
79 ];
80
81 let (vocab, reverse_vocab) = build_vocab(&sentences);
82
83 assert_eq!(vocab.len(), 3);
84 assert_eq!(reverse_vocab.len(), 3);
85 assert_eq!(vocab.get("hello"), Some(&0));
86 assert_eq!(vocab.get("world"), Some(&1));
87 assert_eq!(vocab.get("rust"), Some(&2));
88 }
89
90 #[test]
91 fn test_load_text_data() {
92 let text = "Hello world! This is a test.";
93 let sentences = load_text_data(text);
94
95 assert_eq!(sentences.len(), 2);
96 assert_eq!(sentences[0], vec!["hello", "world"]);
97 assert_eq!(sentences[1], vec!["this", "is", "a", "test"]);
98 }
99
100 fn make_test_data() -> TrainingData {
101 TrainingData::from_text("the cat sat on the mat. the dog sat on the log. the cat chased the dog.")
102 }
103
104 fn test_config(model_type: ModelType) -> TrainingConfig {
105 TrainingConfig::new(model_type)
106 .with_dim(8)
107 .with_learning_rate(0.1)
108 .with_epochs(2)
109 .with_batch_size(4)
110 .with_window(1)
111 .with_negative_samples(2)
112 }
113
114 #[test]
115 fn test_train_skipgram() {
116 let data = make_test_data();
117 let config = test_config(ModelType::SkipGram);
118 let mut model = EmbeddingModel::new(config, data.vocab.len());
119
120 assert!(model.train(&data).is_ok());
121
122 assert!(model.get_embedding("cat", &data).is_some());
124 assert!(model.get_embedding("dog", &data).is_some());
125 assert!(model.get_embedding("the", &data).is_some());
126
127 assert!(model.similarity("cat", "dog", &data).is_some());
129 }
130
131 #[test]
132 fn test_train_cbow() {
133 let data = make_test_data();
134 let config = test_config(ModelType::Cbow);
135 let mut model = EmbeddingModel::new(config, data.vocab.len());
136
137 assert!(model.train(&data).is_ok());
138
139 assert!(model.get_embedding("cat", &data).is_some());
140 assert!(model.get_embedding("dog", &data).is_some());
141 assert!(model.similarity("cat", "dog", &data).is_some());
142 }
143
144 #[test]
145 fn test_build_vocab_with_freq_counts_correctly() {
146 let sentences = vec![
147 vec!["the".to_string(), "cat".to_string(), "sat".to_string()],
148 vec!["the".to_string(), "dog".to_string(), "sat".to_string()],
149 ];
150 let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
151 assert_eq!(vocab.len(), 4);
152 assert_eq!(reverse_vocab.len(), 4);
153 assert_eq!(word_freq.len(), 4);
154
155 let the_id = vocab["the"];
156 let sat_id = vocab["sat"];
157 assert_eq!(word_freq[the_id], 2);
158 assert_eq!(word_freq[sat_id], 2);
159 assert_eq!(word_freq[vocab["cat"]], 1);
160 assert_eq!(word_freq[vocab["dog"]], 1);
161 }
162
163 #[test]
164 fn test_unigram_negative_sampling_runs() {
165 let data = make_test_data();
166 let config = test_config(ModelType::SkipGram)
167 .with_unigram_negative_sampling(true)
168 .with_epochs(2);
169 let mut model = EmbeddingModel::new(config, data.vocab.len());
170 assert!(model.train(&data).is_ok());
171 }
172
173 #[test]
174 fn test_subsampling_runs() {
175 let data = make_test_data();
176 let config = test_config(ModelType::SkipGram)
177 .with_subsample_threshold(Some(1e-5))
178 .with_epochs(2);
179 let mut model = EmbeddingModel::new(config, data.vocab.len());
180 assert!(model.train(&data).is_ok());
181 }
182
183 #[test]
184 fn test_subsampling_drops_frequent_words() {
185 let data = make_test_data();
186 let total = data.total_word_count();
187 assert!(total > 0);
188
189 let config = TrainingConfig::new(ModelType::SkipGram)
191 .with_dim(4)
192 .with_epochs(1)
193 .with_batch_size(2)
194 .with_subsample_threshold(Some(1e-3));
195 let mut model = EmbeddingModel::new(config, data.vocab.len());
196 assert!(model.train(&data).is_ok());
197 }
198
199 #[test]
200 fn test_lr_warmup() {
201 let data = make_test_data();
202 let config = test_config(ModelType::SkipGram)
203 .with_warmup_epochs(Some(3))
204 .with_epochs(5)
205 .with_learning_rate(0.1);
206 let mut model = EmbeddingModel::new(config, data.vocab.len());
207 assert!(model.train(&data).is_ok());
208
209 let lr0 = model.get_learning_rate(0, 5);
211 let lr1 = model.get_learning_rate(1, 5);
212 let lr2 = model.get_learning_rate(2, 5);
213 let lr3 = model.get_learning_rate(3, 5);
214 assert!(lr0 < lr1 && lr1 < lr2, "LR should increase during warm-up");
215 assert!(lr2 < lr3, "LR should reach base rate after warm-up");
216 assert!(lr3 > 0.0, "LR after warm-up should be positive");
217 }
218
219 #[test]
220 fn test_checkpoint_save_and_load() {
221 let data = make_test_data();
222 let config = test_config(ModelType::SkipGram).with_epochs(2);
223 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
224 model.train(&data).unwrap();
225
226 let temp_dir = std::env::temp_dir();
227 let path = temp_dir.join("test_checkpoint.json");
228 let path_str = path.to_str().unwrap();
229
230 model.save_checkpoint(path_str, 2, 0.5).unwrap();
231 let loaded = EmbeddingModel::load_checkpoint(path_str).unwrap();
232
233 assert_eq!(loaded.config.model_type, config.model_type);
234 assert_eq!(loaded.vocab_size, model.vocab_size);
235 assert_eq!(loaded.embeddings.shape(), model.embeddings.shape());
236 }
237
238 #[test]
239 fn test_parallel_training_skipgram() {
240 let data = make_test_data();
241 let config = test_config(ModelType::SkipGram)
242 .with_parallel(true)
243 .with_epochs(2);
244 let mut model = EmbeddingModel::new(config, data.vocab.len());
245 assert!(model.train(&data).is_ok());
246 assert!(model.get_embedding("cat", &data).is_some());
247 }
248
249 #[test]
250 fn test_parallel_training_cbow() {
251 let data = make_test_data();
252 let config = test_config(ModelType::Cbow)
253 .with_parallel(true)
254 .with_epochs(2);
255 let mut model = EmbeddingModel::new(config, data.vocab.len());
256 assert!(model.train(&data).is_ok());
257 assert!(model.get_embedding("dog", &data).is_some());
258 }
259
260 #[test]
261 fn test_save_embeddings() {
262 let data = make_test_data();
263 let config = test_config(ModelType::SkipGram);
264 let mut model = EmbeddingModel::new(config, data.vocab.len());
265 model.train(&data).unwrap();
266
267 let temp_dir = std::env::temp_dir();
268 let path = temp_dir.join("test_embeddings_save.txt");
269 let path_str = path.to_str().unwrap();
270
271 assert!(model.save_embeddings(path_str, &data).is_ok());
272 let contents = std::fs::read_to_string(path_str).unwrap();
273 assert!(contents.contains("cat"));
274 assert!(contents.contains("dog"));
275
276 std::fs::remove_file(path_str).ok();
277 }
278
279 #[test]
280 fn test_similarity_unknown_word() {
281 let data = make_test_data();
282 let config = test_config(ModelType::SkipGram);
283 let mut model = EmbeddingModel::new(config, data.vocab.len());
284 model.train(&data).unwrap();
285
286 assert!(model.similarity("cat", "nonexistent", &data).is_none());
287 assert!(model.similarity("nonexistent", "dog", &data).is_none());
288 }
289
290 #[test]
291 fn test_strip_html() {
292 let processor = TextProcessor {
293 remove_html: true,
294 remove_punctuation: false,
295 lowercase: false,
296 ..TextProcessor::default()
297 };
298 let text = "<p>Hello world!</p> This is a <b>test</b>.";
299 let sentences = processor.process_text(text);
300 assert_eq!(sentences.len(), 2);
301 assert_eq!(sentences[0], vec!["Hello", "world"]);
302 assert_eq!(sentences[1], vec!["This", "is", "a", "test"]);
303 }
304
305 #[test]
306 fn test_strip_urls() {
307 let processor = TextProcessor {
308 remove_urls: true,
309 remove_punctuation: true,
310 lowercase: true,
311 ..TextProcessor::default()
312 };
313 let text = "Visit https://example.com for info. See www.test.org too.";
314 let sentences = processor.process_text(text);
315 assert_eq!(sentences.len(), 2);
316 assert_eq!(sentences[0], vec!["visit", "for", "info"]);
317 assert_eq!(sentences[1], vec!["see", "too"]);
318 }
319
320 #[test]
321 fn test_expand_contractions() {
322 let processor = TextProcessor {
323 expand_contractions: true,
324 remove_punctuation: true,
325 lowercase: true,
326 ..TextProcessor::default()
327 };
328 let text = "I can't do this. It's a test.";
329 let sentences = processor.process_text(text);
330 assert_eq!(sentences.len(), 2);
331 assert_eq!(sentences[0], vec!["i", "cannot", "do", "this"]);
333 assert_eq!(sentences[1], vec!["it", "is", "a", "test"]);
334 }
335
336 #[test]
337 fn test_normalize_embeddings() {
338 let data = make_test_data();
339 let config = test_config(ModelType::SkipGram);
340 let mut model = EmbeddingModel::new(config, data.vocab.len());
341 model.train(&data).unwrap();
342 model.normalize_embeddings();
343
344 for row in model.embeddings.rows() {
345 let norm = row.iter().map(|&x| x * x).sum::<f32>().sqrt();
346 assert!((norm - 1.0).abs() < 1e-5 || norm == 0.0);
347 }
348 }
349
350 #[test]
351 fn test_analogy_unknown_word() {
352 let data = make_test_data();
353 let config = test_config(ModelType::SkipGram);
354 let mut model = EmbeddingModel::new(config, data.vocab.len());
355 model.train(&data).unwrap();
356
357 assert!(model.analogy("unknown", "cat", "dog", &data, 1).is_empty());
358 }
359
360 #[test]
361 fn test_split_data() {
362 let sentences = vec![
363 vec!["a".to_string()],
364 vec!["b".to_string()],
365 vec!["c".to_string()],
366 vec!["d".to_string()],
367 vec!["e".to_string()],
368 vec!["f".to_string()],
369 vec!["g".to_string()],
370 vec!["h".to_string()],
371 vec!["i".to_string()],
372 vec!["j".to_string()],
373 ];
374 let config = test_config(ModelType::SkipGram);
375 let model = EmbeddingModel::new(config, 1);
376 let (train, val) = model.split_data(&sentences, 0.7);
377 assert_eq!(train.len(), 7);
378 assert_eq!(val.len(), 3);
379 }
380
381 #[test]
382 fn test_gradient_clipping() {
383 let data = make_test_data();
384 let mut config = test_config(ModelType::SkipGram);
385 config.gradient_clip = Some(0.001);
386 let mut model = EmbeddingModel::new(config, data.vocab.len());
387
388 assert!(model.train(&data).is_ok());
390 assert!(model.get_embedding("cat", &data).is_some());
391 }
392
393 #[test]
394 fn test_mini_batch_processing() {
395 let data = make_test_data();
396 let mut config1 = test_config(ModelType::SkipGram);
398 config1.batch_size = 1;
399 let mut model1 = EmbeddingModel::new(config1, data.vocab.len());
400 assert!(model1.train(&data).is_ok());
401
402 let mut config8 = test_config(ModelType::SkipGram);
404 config8.batch_size = 8;
405 let mut model8 = EmbeddingModel::new(config8, data.vocab.len());
406 assert!(model8.train(&data).is_ok());
407
408 assert!(model1.get_embedding("cat", &data).is_some());
410 assert!(model8.get_embedding("cat", &data).is_some());
411 }
412
413 #[test]
414 fn test_empty_text() {
415 let sentences = load_text_data("");
416 assert!(sentences.is_empty());
417 }
418
419 #[test]
420 fn test_single_word_text() {
421 let sentences = load_text_data("hello");
422 assert_eq!(sentences.len(), 1);
423 assert_eq!(sentences[0], vec!["hello"]);
424 }
425
426 #[test]
427 fn test_learning_rate_schedules() {
428 let data = make_test_data();
429
430 let mut config_exp = test_config(ModelType::SkipGram);
431 config_exp.lr_schedule = LearningRateSchedule::Exponential { decay_rate: 0.9 };
432 let mut model_exp = EmbeddingModel::new(config_exp, data.vocab.len());
433 assert!(model_exp.train(&data).is_ok());
434
435 let mut config_step = test_config(ModelType::SkipGram);
436 config_step.lr_schedule = LearningRateSchedule::Step { step_size: 1, gamma: 0.5 };
437 let mut model_step = EmbeddingModel::new(config_step, data.vocab.len());
438 assert!(model_step.train(&data).is_ok());
439
440 let mut config_cos = test_config(ModelType::SkipGram);
441 config_cos.lr_schedule = LearningRateSchedule::Cosine { t_max: 2 };
442 let mut model_cos = EmbeddingModel::new(config_cos, data.vocab.len());
443 assert!(model_cos.train(&data).is_ok());
444 }
445
446 #[test]
447 fn test_early_stopping() {
448 let data = make_test_data();
449 let mut config = test_config(ModelType::SkipGram);
450 config.early_stopping = Some(EarlyStoppingConfig { patience: 1, min_delta: 0.001 });
451 config.epochs = 10;
452 let mut model = EmbeddingModel::new(config, data.vocab.len());
453 assert!(model.train(&data).is_ok());
454 }
455
456 #[test]
457 fn test_word2vec_format_roundtrip() {
458 let data = make_test_data();
459 let config = test_config(ModelType::SkipGram);
460 let mut model = EmbeddingModel::new(config, data.vocab.len());
461 model.train(&data).unwrap();
462
463 let temp_path = std::env::temp_dir().join("test_word2vec.txt");
464 let path_str = temp_path.to_str().unwrap();
465
466 assert!(model.save_word2vec_format(path_str, &data).is_ok());
468
469 let (loaded, dim) = EmbeddingModel::load_word2vec_format(path_str).unwrap();
471 assert_eq!(dim, 8);
472 assert!(loaded.contains_key("cat"));
473 assert!(loaded.contains_key("dog"));
474 assert_eq!(loaded.get("cat").unwrap().len(), 8);
475 assert_eq!(loaded.get("dog").unwrap().len(), 8);
476
477 std::fs::remove_file(path_str).ok();
478 }
479
480 #[test]
481 fn test_bpe_tokenizer() {
482 let corpus = vec![
483 "low".to_string(),
484 "lower".to_string(),
485 "lowest".to_string(),
486 "newer".to_string(),
487 "new".to_string(),
488 "widest".to_string(),
489 "wide".to_string(),
490 ];
491
492 let tokenizer = BPETokenizer::train(&corpus, 20);
493
494 assert!(tokenizer.vocab.len() >= 10);
496
497 let tokens = tokenizer.encode("lowest");
499 assert!(!tokens.is_empty());
500
501 let decoded = tokenizer.decode(&tokens);
503 assert_eq!(decoded, "lowest");
504 }
505
506 #[test]
507 fn test_pretrained_embeddings_loading() {
508 let data = make_test_data();
509 let config = test_config(ModelType::SkipGram);
510
511 let temp_path = std::env::temp_dir().join("test_pretrained.txt");
513 let path_str = temp_path.to_str().unwrap();
514
515 let mut file = std::fs::File::create(path_str).unwrap();
516 use std::io::Write;
517 writeln!(file, "{} {}", data.vocab.len(), config.embedding_dim).unwrap();
518 for (word_id, word) in data.reverse_vocab.iter().enumerate() {
519 let vals: Vec<String> = (0..config.embedding_dim)
520 .map(|i| format!("{:.6}", (word_id * 10 + i) as f32 * 0.1))
521 .collect();
522 writeln!(file, "{} {}", word, vals.join(" ")).unwrap();
523 }
524 drop(file);
525
526 let model = EmbeddingModel::new_with_pretrained(
528 config,
529 data.vocab.len(),
530 &data,
531 path_str,
532 );
533 assert!(model.is_ok());
534
535 let model = model.unwrap();
536 let cat_emb = model.get_embedding("cat", &data).unwrap();
538 let cat_id = data.vocab.get("cat").unwrap();
539 for (i, &val) in cat_emb.iter().enumerate() {
540 let expected = (*cat_id * 10 + i) as f32 * 0.1;
541 assert!((val - expected).abs() < 1e-5, "Mismatch at index {}: got {}, expected {}", i, val, expected);
542 }
543
544 std::fs::remove_file(path_str).ok();
545 }
546
547 #[test]
548 fn test_semantic_search() {
549 let data = make_test_data();
550 let config = test_config(ModelType::SkipGram);
551 let mut model = EmbeddingModel::new(config, data.vocab.len());
552 model.train(&data).unwrap();
553
554 let results = model.semantic_search("cat", &data, 5);
555 assert!(!results.is_empty());
556 for (word, _) in &results {
558 assert_ne!(word, "cat");
559 }
560 }
561
562 #[test]
563 fn test_embedding_arithmetic() {
564 let data = make_test_data();
565 let config = test_config(ModelType::SkipGram);
566 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
567 model.train(&data).unwrap();
568
569 let result = model.embedding_arithmetic("cat", "dog", &data);
570 assert!(result.is_some());
571 assert_eq!(result.unwrap().len(), config.embedding_dim);
572 }
573
574 #[test]
575 fn test_interpolate_embeddings() {
576 let data = make_test_data();
577 let config = test_config(ModelType::SkipGram);
578 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
579 model.train(&data).unwrap();
580
581 let result = model.interpolate_embeddings("cat", "dog", &data, 0.5);
582 assert!(result.is_some());
583 assert_eq!(result.unwrap().len(), config.embedding_dim);
584 }
585
586 #[test]
587 fn test_save_numpy_format() {
588 let data = make_test_data();
589 let config = test_config(ModelType::SkipGram);
590 let mut model = EmbeddingModel::new(config, data.vocab.len());
591 model.train(&data).unwrap();
592
593 let temp_path = std::env::temp_dir().join("test_numpy.npy");
594 let path_str = temp_path.to_str().unwrap();
595
596 assert!(model.save_numpy_format(path_str, &data).is_ok());
597
598 let metadata = std::fs::metadata(path_str).unwrap();
600 assert!(metadata.len() > 0);
601
602 std::fs::remove_file(path_str).ok();
603 }
604
605 #[test]
606 fn test_stream_sentences() {
607 use std::io::Write;
608
609 let temp_path = std::env::temp_dir().join("test_stream.txt");
610 let path_str = temp_path.to_str().unwrap();
611
612 let mut file = std::fs::File::create(path_str).unwrap();
613 writeln!(file, "the cat sat on the mat.").unwrap();
614 writeln!(file, "the dog sat on the log.").unwrap();
615 writeln!(file, "the cat chased the dog.").unwrap();
616 drop(file);
617
618 let loader = DataLoader::new(4, false);
619 let sentences: Vec<Vec<String>> = loader.stream_sentences(path_str).unwrap().collect();
620 assert!(!sentences.is_empty());
621 assert!(sentences.iter().all(|s| !s.is_empty()));
623
624 std::fs::remove_file(path_str).ok();
625 }
626
627 #[test]
628 fn test_incremental_vocab_update() {
629 let mut data = make_test_data();
630 let config = test_config(ModelType::SkipGram);
631 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
632 model.train(&data).unwrap();
633
634 let old_vocab_size = data.vocab.len();
635 let old_emb_rows = model.embeddings.nrows();
636
637 let new_words = vec!["elephant".to_string(), "giraffe".to_string()];
638 let added = model.incremental_vocab_update(&new_words, &mut data).unwrap();
639
640 assert_eq!(added.len(), 2);
641 assert_eq!(data.vocab.len(), old_vocab_size + 2);
642 assert_eq!(model.embeddings.nrows(), old_emb_rows + 2);
643 assert_eq!(model.embeddings.ncols(), config.embedding_dim);
644
645 assert!(model.get_embedding("elephant", &data).is_some());
647 assert!(model.get_embedding("giraffe", &data).is_some());
648 assert!(model.get_embedding("cat", &data).is_some());
650 }
651
652 #[test]
653 fn test_lsh_index() {
654 let data = make_test_data();
655 let config = test_config(ModelType::SkipGram);
656 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
657 model.train(&data).unwrap();
658
659 let mut lsh = LSHIndex::new(4, config.embedding_dim);
660 lsh.build(&model, &data);
661
662 let results = lsh.query("cat", &model, &data, 5);
663 for (word, _) in &results {
667 assert_ne!(word, "cat");
668 }
669 }
670
671 #[test]
672 fn test_save_onnx_format() {
673 use prost::Message;
674
675 let data = make_test_data();
676 let config = test_config(ModelType::SkipGram);
677 let mut model = EmbeddingModel::new(config, data.vocab.len());
678 model.train(&data).unwrap();
679
680 let temp_path = std::env::temp_dir().join("test_model.onnx");
681 let path_str = temp_path.to_str().unwrap();
682
683 assert!(model.save_onnx_format(path_str, &data).is_ok());
684
685 let metadata = std::fs::metadata(path_str).unwrap();
687 assert!(metadata.len() > 50);
688
689 let bytes = std::fs::read(path_str).unwrap();
691 let decoded = onnx::ModelProto::decode(&bytes[..]);
692 assert!(decoded.is_ok());
693
694 let m = decoded.unwrap();
695 assert_eq!(m.ir_version, 9);
696 assert_eq!(m.producer_name, "embedding-trainer");
697 let graph = m.graph.unwrap();
698 assert_eq!(graph.name, "embedding_graph");
699 assert_eq!(graph.node.len(), 1);
700 assert_eq!(graph.node[0].op_type, "Gather");
701 assert_eq!(graph.initializer.len(), 1);
702
703 std::fs::remove_file(path_str).ok();
704 }
705
706 #[test]
707 fn test_sentence_embedding() {
708 let data = make_test_data();
709 let config = test_config(ModelType::SkipGram);
710 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
711 model.train(&data).unwrap();
712
713 let sentence = vec!["the".to_string(), "cat".to_string(), "sat".to_string()];
714 let emb = model.sentence_embedding(&sentence, &data);
715 assert!(emb.is_some());
716 let emb = emb.unwrap();
717 assert_eq!(emb.len(), config.embedding_dim);
718
719 assert!(model.sentence_embedding(&[], &data).is_none());
721 }
722
723 #[test]
724 fn test_multimodal_fusion() {
725 let text = Array::from_vec(vec![1.0, 2.0, 3.0]);
726 let aux = Array::from_vec(vec![4.0, 5.0, 6.0]);
727 let fusion = MultimodalFusion::new(3, 3, 3);
728
729 let concat = fusion.concatenate(&text, &aux);
731 assert_eq!(concat.len(), 6);
732 assert_eq!(concat[0], 1.0);
733 assert_eq!(concat[5], 6.0);
734
735 let avg = fusion.weighted_average(&text, &aux, 0.5).unwrap();
737 assert_eq!(avg.len(), 3);
738 assert!((avg[0] - 2.5).abs() < 1e-6);
739
740 let short = Array::from_vec(vec![1.0, 2.0]);
742 assert!(fusion.weighted_average(&text, &short, 0.5).is_none());
743
744 let attn = fusion.attention_fusion(&text, &aux).unwrap();
746 assert_eq!(attn.len(), 3);
747
748 let sim = MultimodalFusion::cross_modal_similarity(&text, &aux);
750 assert!(sim >= -1.0 && sim <= 1.0);
751 }
752
753 #[test]
754 fn test_cross_lingual_aligner() {
755 let dim = 4;
756 let mut aligner = CrossLingualAligner::new(dim);
757
758 let v = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
760 let aligned = aligner.align(&v);
761 assert_eq!(aligned, v);
762
763 let src = Array::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
765 let tgt = Array::from_vec(vec![2.0, 0.0, 0.0, 0.0]);
766 aligner.train_from_dictionary(&[(src, tgt)], 100, 0.1);
767
768 let test = Array::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
770 let result = aligner.align(&test);
771 assert!((result[0] - 2.0).abs() < 0.1, "Expected ~2.0, got {}", result[0]);
772 }
773
774 #[test]
775 fn test_domain_adapter() {
776 let mut data = make_test_data();
777 let config = test_config(ModelType::SkipGram);
778 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
779 model.train(&data).unwrap();
780
781 let domain_sentences = vec![
782 vec!["the".to_string(), "cat".to_string()],
783 vec!["a".to_string(), "dog".to_string()],
784 ];
785 assert!(DomainAdapter::adapt(&mut model, &mut data, &domain_sentences, 1).is_ok());
786 assert!(data.vocab.contains_key("cat"));
788 }
789
790 #[test]
791 fn test_document_embedder() {
792 let data = make_test_data();
793 let config = test_config(ModelType::SkipGram);
794 let mut model = EmbeddingModel::new(config.clone(), data.vocab.len());
795 model.train(&data).unwrap();
796
797 let sentences = vec![
798 vec!["the".to_string(), "cat".to_string()],
799 vec!["a".to_string(), "dog".to_string()],
800 ];
801 let doc = DocumentEmbedder::embed_document(&model, &data, &sentences);
802 assert!(doc.is_some());
803 assert_eq!(doc.unwrap().len(), config.embedding_dim);
804
805 assert!(DocumentEmbedder::embed_document(&model, &data, &[]).is_none());
806 }
807
808 #[test]
809 fn test_zero_shot_transfer() {
810 let proto_a = Array::from_vec(vec![1.0, 0.0, 0.0]);
811 let proto_b = Array::from_vec(vec![0.0, 1.0, 0.0]);
812 let mut prototypes = HashMap::new();
813 prototypes.insert("class_a".to_string(), proto_a);
814 prototypes.insert("class_b".to_string(), proto_b);
815
816 let query = Array::from_vec(vec![0.9, 0.1, 0.0]);
817 let result = ZeroShotTransfer::classify(&query, &prototypes);
818 assert!(result.is_some());
819 let (label, sim) = result.unwrap();
820 assert_eq!(label, "class_a");
821 assert!(sim > 0.9);
822 }
823
824 #[test]
825 fn test_query_expander() {
826 let data = make_test_data();
827 let config = test_config(ModelType::SkipGram);
828 let mut model = EmbeddingModel::new(config, data.vocab.len());
829 model.train(&data).unwrap();
830
831 let expanded = QueryExpander::expand(&model, &data, "cat", 3);
832 assert!(!expanded.is_empty());
833 assert_eq!(expanded[0], "cat");
834 }
835
836 #[test]
837 fn test_hierarchical_clustering() {
838 let data = make_test_data();
839 let config = test_config(ModelType::SkipGram);
840 let mut model = EmbeddingModel::new(config, data.vocab.len());
841 model.train(&data).unwrap();
842
843 let clusters = HierarchicalClustering::cluster(&model, &data, 2);
844 assert_eq!(clusters.len(), 2);
845 let mut all_words = std::collections::HashSet::new();
847 for c in &clusters {
848 for word in c {
849 assert!(all_words.insert(word.clone()));
850 }
851 }
852 assert_eq!(all_words.len(), data.vocab.len());
853 }
854
855 #[test]
856 fn test_unicode_normalization() {
857 let processor = TextProcessor {
858 lowercase: true,
859 remove_punctuation: false,
860 remove_numbers: false,
861 remove_stop_words: false,
862 remove_html: false,
863 remove_urls: false,
864 expand_contractions: false,
865 normalize_unicode: false,
866 language: "en".to_string(),
867 };
868 let text = "caf\u{0065}\u{0301}";
870 let sentences = processor.process_text(text);
871 assert_eq!(sentences.len(), 1);
872 assert_eq!(sentences[0].len(), 1);
873 assert_eq!(sentences[0][0], "caf\u{00e9}");
875 }
876
    /// End-to-end check of the source-code pipeline: tokenize a code snippet,
    /// build a vocabulary, train, and look up embeddings.
    ///
    /// The vocabulary assertions expect the camelCase identifier
    /// `computeEmbeddingVector` to appear as the lowercase parts "compute",
    /// "embedding" and "vector".
    #[test]
    fn test_code_embedding_pipeline() {
        // The snippet is only tokenized, never compiled.
        let code = r#"
            fn computeEmbeddingVector(input: Vec<f32>) -> Vec<f32> {
                let result = vec![];
                for x in input {
                    result.push(x * 2.0);
                }
                result
            }
        "#;
        let sentences = load_code_data(code);
        assert!(!sentences.is_empty());

        let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
        assert!(vocab.contains_key("compute"));
        assert!(vocab.contains_key("embedding"));
        assert!(vocab.contains_key("vector"));
        assert!(vocab.contains_key("result"));

        // Assemble the training data by hand from the pieces built above.
        let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
        let config = test_config(ModelType::SkipGram);
        let mut model = EmbeddingModel::new(config, data.vocab.len());
        assert!(model.train(&data).is_ok());

        assert!(model.get_embedding("embedding", &data).is_some());
        assert!(model.get_embedding("vector", &data).is_some());
    }
906
907 #[test]
908 fn test_western_language_embedding_pipeline() {
909 let text = "Le chat noir dort sur le tapis. Le chien brun joue dans le jardin.";
911 let sentences = load_text_data(text);
912 assert!(!sentences.is_empty());
913
914 let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
915 assert!(vocab.contains_key("chat"));
917 assert!(vocab.contains_key("noir"));
918 assert!(vocab.contains_key("dort"));
919 assert!(vocab.contains_key("chien"));
920 assert!(vocab.contains_key("jardin"));
921
922 let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
923 let config = test_config(ModelType::SkipGram);
924 let mut model = EmbeddingModel::new(config, data.vocab.len());
925 assert!(model.train(&data).is_ok());
926
927 assert!(model.get_embedding("chat", &data).is_some());
929 assert!(model.get_embedding("chien", &data).is_some());
930 assert!(model.get_embedding("jardin", &data).is_some());
931 }
932
933 #[test]
934 fn test_chinese_embedding_pipeline() {
935 let text = "猫坐在垫子上。狗在花园里玩。猫追狗。";
937 let sentences = load_text_data(text);
938 assert!(!sentences.is_empty());
939
940 let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
941 assert!(vocab.contains_key("猫"));
943 assert!(vocab.contains_key("坐"));
944 assert!(vocab.contains_key("狗"));
945 assert!(vocab.contains_key("花"));
946 assert!(vocab.contains_key("追"));
947
948 let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
949 let config = test_config(ModelType::SkipGram);
950 let mut model = EmbeddingModel::new(config, data.vocab.len());
951 assert!(model.train(&data).is_ok());
952
953 assert!(model.get_embedding("猫", &data).is_some());
955 assert!(model.get_embedding("狗", &data).is_some());
956 assert!(model.get_embedding("追", &data).is_some());
957 }
958
959 #[test]
960 fn test_japanese_embedding_pipeline() {
961 let text = "猫はマットの上に座っています。犬は庭で遊んでいます。";
963 let sentences = load_text_data(text);
964 assert!(!sentences.is_empty());
965
966 let (vocab, reverse_vocab, word_freq) = build_vocab_with_freq(&sentences);
967 assert!(vocab.contains_key("猫"));
969 assert!(vocab.contains_key("座"));
970 assert!(vocab.contains_key("犬"));
971 assert!(vocab.contains_key("遊"));
972
973 let data = TrainingData { sentences, vocab, reverse_vocab, word_freq };
974 let config = test_config(ModelType::SkipGram);
975 let mut model = EmbeddingModel::new(config, data.vocab.len());
976 assert!(model.train(&data).is_ok());
977
978 assert!(model.get_embedding("猫", &data).is_some());
979 assert!(model.get_embedding("犬", &data).is_some());
980 }
981
982 #[test]
983 fn test_subword_embedder() {
984 let embedder = SubwordEmbedder::new(3, 5);
985 let ngrams = embedder.ngrams("apple");
986 assert!(!ngrams.is_empty());
987 assert!(ngrams.contains(&"<ap".to_string()));
988 assert!(ngrams.contains(&"ple>".to_string()));
989
990 let mut vectors = HashMap::new();
991 vectors.insert("<ap".to_string(), Array::from_vec(vec![1.0, 0.0]));
992 vectors.insert("app".to_string(), Array::from_vec(vec![0.0, 1.0]));
993 vectors.insert("ppl".to_string(), Array::from_vec(vec![1.0, 1.0]));
994 vectors.insert("ple>".to_string(), Array::from_vec(vec![0.5, 0.5]));
995
996 let emb = embedder.embed("apple", &vectors);
997 assert!(emb.is_some());
998 let emb = emb.unwrap();
999 assert_eq!(emb.len(), 2);
1000 }
1001
1002 #[test]
1003 fn test_create_validation_data() {
1004 let data = make_test_data();
1005 let config = test_config(ModelType::SkipGram);
1006 let model = EmbeddingModel::new(config, data.vocab.len());
1007 let val_data = model.create_validation_data(&data.sentences);
1008 assert!(!val_data.positive_pairs.is_empty());
1009 assert!(!val_data.negative_pairs.is_empty());
1010 }
1011
1012 #[test]
1013 fn test_evaluate_produces_metrics() {
1014 let data = make_test_data();
1015 let config = test_config(ModelType::SkipGram);
1016 let mut model = EmbeddingModel::new(config, data.vocab.len());
1017 model.train(&data).unwrap();
1018
1019 let val_data = model.create_validation_data(&data.sentences);
1020 let metrics = model.evaluate(&data, &val_data);
1021 assert!(metrics.accuracy >= 0.0 && metrics.accuracy <= 1.0);
1022 assert!(metrics.precision >= 0.0 && metrics.precision <= 1.0);
1023 assert!(metrics.recall >= 0.0 && metrics.recall <= 1.0);
1024 assert!(metrics.f1_score >= 0.0 && metrics.f1_score <= 1.0);
1025 assert!(metrics.mean_similarity >= -1.0 && metrics.mean_similarity <= 1.0);
1026 assert!(metrics.embedding_quality_score >= 0.0 && metrics.embedding_quality_score <= 1.0);
1027 }
1028
1029 #[test]
1030 fn test_train_with_validation_split() {
1031 let data = make_test_data();
1032 let mut config = test_config(ModelType::SkipGram);
1033 config.validation_ratio = Some(0.3);
1034 let mut model = EmbeddingModel::new(config, data.vocab.len());
1035 assert!(model.train(&data).is_ok());
1036 assert!(model.get_embedding("cat", &data).is_some());
1037 }
1038
1039 #[test]
1040 fn test_cross_validation_basic() {
1041 let data = make_test_data();
1042 let config = test_config(ModelType::SkipGram);
1043 let model = EmbeddingModel::new(config, data.vocab.len());
1044
1045 let result = model.cross_validate(&data, 3).unwrap();
1046 assert_eq!(result.folds, 3);
1047 assert_eq!(result.per_fold_metrics.len(), 3);
1048
1049 assert!(result.averaged_metrics.accuracy >= 0.0 && result.averaged_metrics.accuracy <= 1.0);
1051 assert!(result.averaged_metrics.f1_score >= 0.0 && result.averaged_metrics.f1_score <= 1.0);
1052 }
1053
1054 #[test]
1055 fn test_cross_validation_invalid_k() {
1056 let data = make_test_data();
1057 let config = test_config(ModelType::SkipGram);
1058 let model = EmbeddingModel::new(config, data.vocab.len());
1059
1060 assert!(model.cross_validate(&data, 0).is_err());
1061 assert!(model.cross_validate(&data, 1).is_err());
1062 assert!(model.cross_validate(&data, 100).is_err());
1063 }
1064
1065 #[test]
1066 fn test_cross_validation_k_equals_2() {
1067 let data = make_test_data();
1068 let config = test_config(ModelType::SkipGram);
1069 let model = EmbeddingModel::new(config, data.vocab.len());
1070
1071 let result = model.cross_validate(&data, 2).unwrap();
1072 assert_eq!(result.folds, 2);
1073 assert_eq!(result.per_fold_metrics.len(), 2);
1074 }
1075
1076 #[test]
1077 fn test_l2_normalize_embeddings() {
1078 let data = make_test_data();
1079 let config = test_config(ModelType::SkipGram);
1080 let mut model = EmbeddingModel::new(config, data.vocab.len());
1081 model.train(&data).unwrap();
1082
1083 model.normalize_embeddings();
1084
1085 for row in model.embeddings.rows() {
1087 let norm = row.iter().map(|&x| x * x).sum::<f32>().sqrt();
1088 assert!((norm - 1.0).abs() < 1e-5, "Expected unit norm, got {}", norm);
1089 }
1090 }
1091
1092 #[test]
1093 fn test_cross_validation_cbow() {
1094 let data = make_test_data();
1095 let config = test_config(ModelType::Cbow);
1096 let model = EmbeddingModel::new(config, data.vocab.len());
1097
1098 let result = model.cross_validate(&data, 2).unwrap();
1099 assert_eq!(result.folds, 2);
1100 assert!(result.averaged_metrics.accuracy >= 0.0);
1101 }
1102
1103 #[test]
1104 fn test_training_history_records_epochs() {
1105 let data = make_test_data();
1106 let config = test_config(ModelType::SkipGram);
1107 let mut model = EmbeddingModel::new(config, data.vocab.len());
1108 model.train(&data).unwrap();
1109
1110 assert!(!model.training_history.epochs.is_empty());
1111 let first = &model.training_history.epochs[0];
1112 assert!(first.loss >= 0.0);
1113 assert!(first.learning_rate > 0.0);
1114
1115 let json = model.training_history.to_json().unwrap();
1116 assert!(json.contains("loss"));
1117 assert!(json.contains("learning_rate"));
1118 }
1119
1120 #[test]
1121 fn test_wordpiece_tokenizer_train_encode_decode() {
1122 let corpus = vec![
1123 "hello".to_string(),
1124 "world".to_string(),
1125 "hello".to_string(),
1126 "world".to_string(),
1127 ];
1128 let tokenizer = tokenizer::WordPieceTokenizer::train(&corpus, 50);
1129 assert!(tokenizer.vocab_size > 0);
1130
1131 let tokens = tokenizer.encode("hello");
1132 assert!(!tokens.is_empty());
1133
1134 let decoded = tokenizer.decode(&tokens);
1135 assert_eq!(decoded, "hello");
1136 }
1137
1138 #[test]
1139 fn test_cpu_backend() {
1140 let backend = backend::CpuBackend::new();
1141 assert_eq!(backend.name(), "cpu");
1142
1143 let emb = backend.init_embeddings(10, 8);
1144 assert_eq!(emb.nrows(), 10);
1145 assert_eq!(emb.ncols(), 8);
1146
1147 let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
1148 let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);
1149 assert!((backend.dot(&a, &b) - 32.0).abs() < 1e-5);
1150
1151 let mut c = a.clone();
1152 backend.add_scaled(&mut c, &b, 2.0);
1153 assert_eq!(c.to_vec(), vec![9.0, 12.0, 15.0]);
1154
1155 let m1 = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1156 let m2 = Array2::from_shape_vec((3, 2), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();
1157 let result = backend.matmul(&m1, &m2);
1158 assert_eq!(result.nrows(), 2);
1159 assert_eq!(result.ncols(), 2);
1160 }
1161
1162 #[test]
1163 fn test_best_backend_returns_cpu() {
1164 let backend = backend::best_backend();
1165 assert!(!backend.name().is_empty());
1166 }
1167
1168 #[test]
1169 fn test_benchmark_load_and_evaluate() {
1170 let tsv = "cat\tdog\t0.8\ncat\tmat\t0.2\ndog\tmat\t0.1\n";
1171 let pairs = benchmark::BenchmarkEvaluator::load_from_tsv(tsv);
1172 assert_eq!(pairs.len(), 3);
1173 assert_eq!(pairs[0].word1, "cat");
1174
1175 let data = make_test_data();
1176 let config = test_config(ModelType::SkipGram);
1177 let mut model = EmbeddingModel::new(config, data.vocab.len());
1178 model.train(&data).unwrap();
1179
1180 let result = benchmark::BenchmarkEvaluator::evaluate(&model, &data, &pairs);
1181 assert_eq!(result.num_pairs, 3);
1182 assert!(result.num_evaluated <= 3);
1184 assert!(result.correlation >= -1.0 && result.correlation <= 1.0);
1186 }
1187
1188 #[test]
1189 fn test_kmeans_clustering() {
1190 let data = make_test_data();
1191 let config = test_config(ModelType::SkipGram);
1192 let mut model = EmbeddingModel::new(config, data.vocab.len());
1193 model.train(&data).unwrap();
1194
1195 let clusters = search::KMeansClustering::cluster(&model, &data, 3, 20);
1196 assert!(!clusters.is_empty());
1197 assert!(clusters.len() <= 3);
1198
1199 let total_words: usize = clusters.iter().map(|c| c.len()).sum();
1200 assert_eq!(total_words, data.vocab.len());
1201 }
1202
1203 #[test]
1204 fn test_kmeans_clustering_k_greater_than_vocab() {
1205 let data = make_test_data();
1206 let config = test_config(ModelType::SkipGram);
1207 let mut model = EmbeddingModel::new(config, data.vocab.len());
1208 model.train(&data).unwrap();
1209
1210 let clusters = search::KMeansClustering::cluster(&model, &data, 100, 10);
1212 assert_eq!(clusters.len(), data.vocab.len());
1213 }
1214
1215 #[test]
1216 fn test_transformer_encoder() {
1217 let encoder = TransformerEncoder::new(2, 2, 8, 16, 10);
1218 let tokens = ndarray::Array2::zeros((3, 8));
1219 let encoded = encoder.encode_sequence(&tokens);
1220 assert_eq!(encoded.nrows(), 3);
1221 assert_eq!(encoded.ncols(), 8);
1222 }
1223
1224 #[test]
1225 fn test_incremental_trainer() {
1226 let mut data = make_test_data();
1227 let config = test_config(ModelType::SkipGram);
1228 let mut model = EmbeddingModel::new(config, data.vocab.len());
1229 model.train(&data).unwrap();
1230
1231 let original_vocab = data.vocab.len();
1232 let new_sentences = vec![vec!["newword".to_string(), "cat".to_string()]];
1233
1234 IncrementalTrainer::update(&mut model, &mut data, &new_sentences, 1).unwrap();
1235
1236 assert!(data.vocab.len() >= original_vocab);
1238 assert!(data.vocab.contains_key("newword"));
1240 }
1241
1242 #[test]
1243 fn test_incremental_stream_train() {
1244 let mut data = make_test_data();
1245 let config = test_config(ModelType::SkipGram);
1246 let mut model = EmbeddingModel::new(config, data.vocab.len());
1247 model.train(&data).unwrap();
1248
1249 let sentences = vec![
1250 vec!["stream".to_string(), "word".to_string()],
1251 vec!["another".to_string(), "stream".to_string()],
1252 ];
1253
1254 IncrementalTrainer::stream_train(
1255 &mut model,
1256 &mut data,
1257 sentences.into_iter(),
1258 1,
1259 1,
1260 )
1261 .unwrap();
1262
1263 assert!(data.vocab.contains_key("stream"));
1264 }
1265
1266 #[test]
1267 fn test_mmap_embeddings_roundtrip() {
1268 let data = make_test_data();
1269 let config = test_config(ModelType::SkipGram);
1270 let mut model = EmbeddingModel::new(config, data.vocab.len());
1271 model.train(&data).unwrap();
1272
1273 let temp_path = std::env::temp_dir().join("test_mmap.bin");
1274 let path_str = temp_path.to_str().unwrap();
1275
1276 model.save_mmapable_format(path_str, &data).unwrap();
1278
1279 let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
1281 assert_eq!(mmap.vocab_size(), data.vocab.len());
1282 assert_eq!(mmap.dim(), model.config.embedding_dim);
1283
1284 let cat_emb = mmap.get("cat").unwrap();
1286 assert_eq!(cat_emb.len(), model.config.embedding_dim);
1287
1288 let cat_id = data.vocab["cat"];
1290 let model_cat: Vec<f32> = model.embeddings.row(cat_id).to_vec();
1291 assert_eq!(cat_emb, model_cat.as_slice());
1292
1293 std::fs::remove_file(path_str).ok();
1294 }
1295
1296 #[test]
1297 fn test_mmap_embeddings_iter() {
1298 let data = make_test_data();
1299 let config = test_config(ModelType::SkipGram);
1300 let mut model = EmbeddingModel::new(config, data.vocab.len());
1301 model.train(&data).unwrap();
1302
1303 let temp_path = std::env::temp_dir().join("test_mmap_iter.bin");
1304 let path_str = temp_path.to_str().unwrap();
1305
1306 model.save_mmapable_format(path_str, &data).unwrap();
1307 let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
1308
1309 let mut count = 0;
1310 for (word, emb) in mmap.iter() {
1311 assert!(!word.is_empty());
1312 assert_eq!(emb.len(), model.config.embedding_dim);
1313 count += 1;
1314 }
1315 assert_eq!(count, data.vocab.len());
1316
1317 std::fs::remove_file(path_str).ok();
1318 }
1319
1320 #[test]
1321 fn test_mmap_embeddings_missing_word() {
1322 let data = make_test_data();
1323 let config = test_config(ModelType::SkipGram);
1324 let mut model = EmbeddingModel::new(config, data.vocab.len());
1325 model.train(&data).unwrap();
1326
1327 let temp_path = std::env::temp_dir().join("test_mmap_missing.bin");
1328 let path_str = temp_path.to_str().unwrap();
1329
1330 model.save_mmapable_format(path_str, &data).unwrap();
1331 let mmap = EmbeddingModel::load_mmap(path_str).unwrap();
1332
1333 assert!(mmap.get("nonexistent_word").is_none());
1334 assert!(mmap.get("cat").is_some());
1335
1336 std::fs::remove_file(path_str).ok();
1337 }
1338
1339 #[test]
1340 fn test_pretrained_loader_word2vec_text() {
1341 let temp = std::env::temp_dir().join("test_pretrained_w2v.txt");
1342 let path = temp.to_str().unwrap();
1343
1344 let content = "3 4\ncat 0.1 0.2 0.3 0.4\ndog 0.5 0.6 0.7 0.8\nfish 0.9 0.0 0.1 0.2\n";
1345 std::fs::write(path, content).unwrap();
1346
1347 let emb = PretrainedLoader::auto(path).unwrap();
1348 assert_eq!(emb.dim(), 4);
1349 assert_eq!(emb.vocab_size(), 3);
1350 assert!(emb.contains("cat"));
1351 assert!(emb.contains("dog"));
1352 assert!(!emb.contains("elephant"));
1353
1354 let cat = emb.get("cat").unwrap();
1355 assert_eq!(cat.len(), 4);
1356 assert!((cat[0] - 0.1).abs() < 1e-6);
1357
1358 std::fs::remove_file(path).ok();
1359 }
1360
1361 #[test]
1362 fn test_pretrained_embeddings_similarity() {
1363 let mut emb = PretrainedEmbeddings::new(3);
1364 emb.insert("a".to_string(), vec![1.0, 0.0, 0.0]);
1365 emb.insert("b".to_string(), vec![0.0, 1.0, 0.0]);
1366 emb.insert("c".to_string(), vec![1.0, 0.0, 0.0]);
1367
1368 let sim_ab = emb.similarity("a", "b").unwrap();
1370 assert!(sim_ab.abs() < 1e-5, "Orthogonal vectors should have ~0 similarity");
1371
1372 let sim_ac = emb.similarity("a", "c").unwrap();
1374 assert!((sim_ac - 1.0).abs() < 1e-5, "Identical vectors should have similarity ~1");
1375
1376 assert!(emb.similarity("a", "missing").is_none());
1378 }
1379
1380 #[test]
1381 fn test_pretrained_embeddings_most_similar() {
1382 let mut emb = PretrainedEmbeddings::new(2);
1383 emb.insert("king".to_string(), vec![1.0, 0.0]);
1384 emb.insert("queen".to_string(), vec![0.9, 0.1]);
1385 emb.insert("man".to_string(), vec![0.1, 1.0]);
1386 emb.insert("woman".to_string(), vec![0.2, 0.9]);
1387
1388 let similar = emb.most_similar("king", 2);
1389 assert_eq!(similar.len(), 2);
1390 assert_eq!(similar[0].0, "queen"); }
1392
1393 #[test]
1394 fn test_pretrained_loader_glove_format() {
1395 let temp = std::env::temp_dir().join("test_pretrained_glove.txt");
1396 let path = temp.to_str().unwrap();
1397
1398 let content = "2 3\nhello 0.1 0.2 0.3\nworld 0.4 0.5 0.6\n";
1399 std::fs::write(path, content).unwrap();
1400
1401 let emb = PretrainedLoader::with_format(path, pretrained::PretrainedFormat::GloVe).unwrap();
1402 assert_eq!(emb.dim(), 3);
1403 assert_eq!(emb.vocab_size(), 2);
1404
1405 let hello = emb.get("hello").unwrap();
1406 assert_eq!(hello, &[0.1, 0.2, 0.3]);
1407
1408 std::fs::remove_file(path).ok();
1409 }
1410
1411 #[test]
1412 fn test_pretrained_init_model_from_pretrained() {
1413 let temp = std::env::temp_dir().join("test_pretrained_init.txt");
1414 let path = temp.to_str().unwrap();
1415
1416 let content = "3 8\ncat 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1\ndog 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2\nthe 0.3 0.3 0.3 0.3 0.3 0.3 0.3 0.3\n";
1418 std::fs::write(path, content).unwrap();
1419
1420 let data = make_test_data();
1421 let config = test_config(ModelType::SkipGram);
1422 let model = EmbeddingModel::new_with_pretrained(config, data.vocab.len(), &data, path).unwrap();
1423
1424 let cat_id = data.vocab["cat"];
1426 let cat_emb = model.embeddings.row(cat_id);
1427 for &v in cat_emb.iter() {
1428 assert!((v - 0.1).abs() < 1e-5);
1429 }
1430
1431 std::fs::remove_file(path).ok();
1432 }
1433}