/// ONNX types pulled in from the build output directory.
///
/// The module body is the file `onnx.rs` located under Cargo's `OUT_DIR`, spliced in
/// with `include!`. Presumably this is prost-generated protobuf code (the tests in
/// this file decode `onnx::ModelProto` through `prost::Message`) — confirm against
/// the build script.
pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}
pub mod config;
pub mod evaluation;
pub mod search;
pub mod code;
pub mod text;
pub mod tokenizer;
pub mod transfer;
pub use config::*;
pub use evaluation::*;
pub use search::*;
pub use code::*;
pub use text::*;
pub use tokenizer::*;
pub use transfer::*;
pub mod model;
mod training;
mod export;
pub mod cli;
mod commands;
pub use model::*;
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array, Array1, Array2};
use std::collections::HashMap;
#[test]
fn test_build_vocab() {
    // "hello" appears in both sentences, so only three distinct words exist.
    let corpus: Vec<Vec<String>> = [["hello", "world"], ["hello", "rust"]]
        .iter()
        .map(|words| words.iter().map(|w| (*w).to_string()).collect())
        .collect();

    let (word_to_id, id_to_word) = build_vocab(&corpus);

    // Both directions of the mapping cover the same three words.
    assert_eq!(word_to_id.len(), 3);
    assert_eq!(id_to_word.len(), 3);

    // Ids follow first-occurrence order.
    assert_eq!(word_to_id.get("hello"), Some(&0));
    assert_eq!(word_to_id.get("world"), Some(&1));
    assert_eq!(word_to_id.get("rust"), Some(&2));
}
#[test]
fn test_load_text_data() {
    // Two sentences with mixed case and terminal punctuation.
    let parsed = load_text_data("Hello world! This is a test.");

    assert_eq!(parsed.len(), 2);
    // Tokens are expected lowercased with the punctuation gone.
    assert_eq!(parsed[0], vec!["hello", "world"]);
    assert_eq!(parsed[1], vec!["this", "is", "a", "test"]);
}
/// Builds the small three-sentence corpus shared by most tests in this module,
/// together with the vocabulary derived from it.
fn make_test_data() -> TrainingData {
    let corpus = load_text_data(
        "the cat sat on the mat. the dog sat on the log. the cat chased the dog.",
    );
    let (word_to_id, id_to_word) = build_vocab(&corpus);
    TrainingData {
        sentences: corpus,
        vocab: word_to_id,
        reverse_vocab: id_to_word,
    }
}
/// Returns a deliberately tiny training configuration for the given model type.
///
/// All optional features (early stopping, L2 regularization, gradient clipping)
/// are disabled; individual tests switch them on as needed.
fn test_config(model_type: ModelType) -> TrainingConfig {
    TrainingConfig {
        model_type,
        // Small sizes so each test trains on very little work.
        embedding_dim: 8,
        epochs: 2,
        batch_size: 4,
        context_window: 1,
        negative_samples: 2,
        // Optimizer settings.
        learning_rate: 0.1,
        lr_schedule: LearningRateSchedule::Constant,
        // Optional features, off by default.
        early_stopping: None,
        l2_regularization: None,
        gradient_clip: None,
    }
}
#[test]
fn test_train_skipgram() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());

    assert!(model.train(&corpus).is_ok());

    // Every corpus word queried here must have an embedding after training.
    for word in ["cat", "dog", "the"] {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
    // Similarity between two known words is defined.
    assert!(model.similarity("cat", "dog", &corpus).is_some());
}
#[test]
fn test_train_cbow() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::Cbow), corpus.vocab.len());

    assert!(model.train(&corpus).is_ok());

    // Same smoke checks as the skip-gram test, on the CBOW variant.
    for word in ["cat", "dog"] {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
    assert!(model.similarity("cat", "dog", &corpus).is_some());
}
#[test]
fn test_save_embeddings() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    // Write into the OS temp directory under a name unique to this test.
    let out_file = std::env::temp_dir().join("test_embeddings_save.txt");
    let out_path = out_file.to_str().unwrap();
    assert!(model.save_embeddings(out_path, &corpus).is_ok());

    // The saved text must mention the words that were trained.
    let written = std::fs::read_to_string(out_path).unwrap();
    for word in ["cat", "dog"] {
        assert!(written.contains(word));
    }

    // Best-effort cleanup; a failure to delete is ignored.
    std::fs::remove_file(out_path).ok();
}
#[test]
fn test_similarity_unknown_word() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    // An out-of-vocabulary word on either side yields no similarity.
    let unknown_second = model.similarity("cat", "nonexistent", &corpus);
    assert!(unknown_second.is_none());
    let unknown_first = model.similarity("nonexistent", "dog", &corpus);
    assert!(unknown_first.is_none());
}
#[test]
fn test_strip_html() {
    // Only HTML removal is enabled explicitly; case is preserved.
    let html_stripper = TextProcessor {
        lowercase: false,
        remove_punctuation: false,
        remove_html: true,
        ..TextProcessor::default()
    };

    let parsed = html_stripper.process_text("<p>Hello world!</p> This is a <b>test</b>.");

    assert_eq!(parsed.len(), 2);
    assert_eq!(parsed[0], vec!["Hello", "world"]);
    assert_eq!(parsed[1], vec!["This", "is", "a", "test"]);
}
#[test]
fn test_strip_urls() {
    let url_stripper = TextProcessor {
        lowercase: true,
        remove_punctuation: true,
        remove_urls: true,
        ..TextProcessor::default()
    };

    // Covers both a scheme-prefixed URL and a bare `www.` host.
    let parsed =
        url_stripper.process_text("Visit https://example.com for info. See www.test.org too.");

    assert_eq!(parsed.len(), 2);
    assert_eq!(parsed[0], vec!["visit", "for", "info"]);
    assert_eq!(parsed[1], vec!["see", "too"]);
}
#[test]
fn test_expand_contractions() {
    let expander = TextProcessor {
        lowercase: true,
        remove_punctuation: true,
        expand_contractions: true,
        ..TextProcessor::default()
    };

    let parsed = expander.process_text("I can't do this. It's a test.");

    assert_eq!(parsed.len(), 2);
    // "can't" -> "cannot", "It's" -> "it is".
    assert_eq!(parsed[0], vec!["i", "cannot", "do", "this"]);
    assert_eq!(parsed[1], vec!["it", "is", "a", "test"]);
}
#[test]
fn test_normalize_embeddings() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    model.normalize_embeddings();

    // Each row must now be either a unit vector or exactly the zero vector.
    for embedding in model.embeddings.rows() {
        let squared_sum = embedding.iter().map(|&component| component * component).sum::<f32>();
        let length = squared_sum.sqrt();
        assert!(length == 0.0 || (length - 1.0).abs() < 1e-5);
    }
}
#[test]
fn test_analogy_unknown_word() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    // An analogy that involves an out-of-vocabulary word produces no candidates.
    let candidates = model.analogy("unknown", "cat", "dog", &corpus, 1);
    assert!(candidates.is_empty());
}
#[test]
fn test_split_data() {
    // Ten single-word sentences: "a" through "j".
    let corpus: Vec<Vec<String>> = "abcdefghij"
        .chars()
        .map(|letter| vec![letter.to_string()])
        .collect();

    let model = EmbeddingModel::new(test_config(ModelType::SkipGram), 1);
    let (train_part, validation_part) = model.split_data(&corpus, 0.7);

    // A 0.7 ratio over ten sentences gives a 7 / 3 split.
    assert_eq!(train_part.len(), 7);
    assert_eq!(validation_part.len(), 3);
}
#[test]
fn test_gradient_clipping() {
    let corpus = make_test_data();
    // Very small clip threshold; training must still complete.
    let clipped = TrainingConfig {
        gradient_clip: Some(0.001),
        ..test_config(ModelType::SkipGram)
    };
    let mut model = EmbeddingModel::new(clipped, corpus.vocab.len());

    assert!(model.train(&corpus).is_ok());
    assert!(model.get_embedding("cat", &corpus).is_some());
}
#[test]
fn test_mini_batch_processing() {
    let corpus = make_test_data();

    // Train once with a batch size of 1 ...
    let single = TrainingConfig {
        batch_size: 1,
        ..test_config(ModelType::SkipGram)
    };
    let mut single_model = EmbeddingModel::new(single, corpus.vocab.len());
    assert!(single_model.train(&corpus).is_ok());

    // ... and once with a batch size of 8.
    let batched = TrainingConfig {
        batch_size: 8,
        ..test_config(ModelType::SkipGram)
    };
    let mut batched_model = EmbeddingModel::new(batched, corpus.vocab.len());
    assert!(batched_model.train(&corpus).is_ok());

    // Both runs must leave an embedding for a known word.
    assert!(single_model.get_embedding("cat", &corpus).is_some());
    assert!(batched_model.get_embedding("cat", &corpus).is_some());
}
#[test]
fn test_empty_text() {
    // Empty input yields no sentences at all.
    assert!(load_text_data("").is_empty());
}
#[test]
fn test_single_word_text() {
    // A lone word with no terminator still forms one sentence.
    let parsed = load_text_data("hello");

    assert_eq!(parsed.len(), 1);
    assert_eq!(parsed[0], vec!["hello"]);
}
#[test]
fn test_learning_rate_schedules() {
    let corpus = make_test_data();

    // Training must succeed under each non-constant schedule, tried in turn.
    let schedules = [
        LearningRateSchedule::Exponential { decay_rate: 0.9 },
        LearningRateSchedule::Step { step_size: 1, gamma: 0.5 },
        LearningRateSchedule::Cosine { t_max: 2 },
    ];
    for schedule in schedules {
        let mut config = test_config(ModelType::SkipGram);
        config.lr_schedule = schedule;
        let mut model = EmbeddingModel::new(config, corpus.vocab.len());
        assert!(model.train(&corpus).is_ok());
    }
}
#[test]
fn test_early_stopping() {
    let corpus = make_test_data();
    // More epochs than the default, with early stopping enabled.
    let config = TrainingConfig {
        early_stopping: Some(EarlyStoppingConfig { patience: 1, min_delta: 0.001 }),
        epochs: 10,
        ..test_config(ModelType::SkipGram)
    };
    let mut model = EmbeddingModel::new(config, corpus.vocab.len());

    assert!(model.train(&corpus).is_ok());
}
#[test]
fn test_word2vec_format_roundtrip() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    let out_file = std::env::temp_dir().join("test_word2vec.txt");
    let out_path = out_file.to_str().unwrap();

    // Save, then read the same file back.
    assert!(model.save_word2vec_format(out_path, &corpus).is_ok());
    let (vectors, dimension) = EmbeddingModel::load_word2vec_format(out_path).unwrap();

    // The dimension matches `test_config`, and known words survive the round trip.
    assert_eq!(dimension, 8);
    assert!(vectors.contains_key("cat"));
    assert!(vectors.contains_key("dog"));
    assert_eq!(vectors.get("cat").unwrap().len(), 8);
    assert_eq!(vectors.get("dog").unwrap().len(), 8);

    // Best-effort cleanup.
    std::fs::remove_file(out_path).ok();
}
#[test]
fn test_bpe_tokenizer() {
    // Words with shared stems and suffixes.
    let training_words: Vec<String> = ["low", "lower", "lowest", "newer", "new", "widest", "wide"]
        .iter()
        .map(|w| (*w).to_string())
        .collect();

    let bpe = BPETokenizer::train(&training_words, 20);
    assert!(bpe.vocab.len() >= 10);

    // Encoding then decoding a training word must reproduce it.
    let pieces = bpe.encode("lowest");
    assert!(!pieces.is_empty());
    let restored = bpe.decode(&pieces);
    assert_eq!(restored, "lowest");
}
#[test]
fn test_pretrained_embeddings_loading() {
    use std::io::Write;

    let corpus = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;

    let fixture = std::env::temp_dir().join("test_pretrained.txt");
    let fixture_path = fixture.to_str().unwrap();

    // Write the fixture: a "<vocab size> <dim>" header, then one row per word whose
    // values are a deterministic function of the word id and column index.
    {
        let mut out = std::fs::File::create(fixture_path).unwrap();
        writeln!(out, "{} {}", corpus.vocab.len(), dim).unwrap();
        for (word_id, word) in corpus.reverse_vocab.iter().enumerate() {
            let columns: Vec<String> = (0..dim)
                .map(|col| format!("{:.6}", (word_id * 10 + col) as f32 * 0.1))
                .collect();
            writeln!(out, "{} {}", word, columns.join(" ")).unwrap();
        }
        // `out` is dropped (and the file closed) at the end of this scope.
    }

    let loaded = EmbeddingModel::new_with_pretrained(
        config,
        corpus.vocab.len(),
        &corpus,
        fixture_path,
    );
    assert!(loaded.is_ok());
    let model = loaded.unwrap();

    // The embedding for "cat" must equal the values written for its id.
    let cat_vector = model.get_embedding("cat", &corpus).unwrap();
    let cat_id = corpus.vocab.get("cat").unwrap();
    for (i, &val) in cat_vector.iter().enumerate() {
        let expected = (*cat_id * 10 + i) as f32 * 0.1;
        assert!(
            (val - expected).abs() < 1e-5,
            "Mismatch at index {}: got {}, expected {}",
            i,
            val,
            expected
        );
    }

    // Best-effort cleanup.
    std::fs::remove_file(fixture_path).ok();
}
#[test]
fn test_semantic_search() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    let hits = model.semantic_search("cat", &corpus, 5);

    assert!(!hits.is_empty());
    // The query word itself must never be returned as a hit.
    for (hit_word, _) in &hits {
        assert_ne!(hit_word, "cat");
    }
}
#[test]
fn test_embedding_arithmetic() {
    let corpus = make_test_data();
    let config = test_config(ModelType::SkipGram);
    // Remember the dimension before the config is handed to the model.
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, corpus.vocab.len());
    model.train(&corpus).unwrap();

    let combined = model.embedding_arithmetic("cat", "dog", &corpus);

    assert!(combined.is_some());
    assert_eq!(combined.unwrap().len(), dim);
}
#[test]
fn test_interpolate_embeddings() {
    let corpus = make_test_data();
    let config = test_config(ModelType::SkipGram);
    // Remember the dimension before the config is handed to the model.
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, corpus.vocab.len());
    model.train(&corpus).unwrap();

    // Interpolation factor 0.5 between two known words.
    let midpoint = model.interpolate_embeddings("cat", "dog", &corpus, 0.5);

    assert!(midpoint.is_some());
    assert_eq!(midpoint.unwrap().len(), dim);
}
#[test]
fn test_save_numpy_format() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    let out_file = std::env::temp_dir().join("test_numpy.npy");
    let out_path = out_file.to_str().unwrap();
    assert!(model.save_numpy_format(out_path, &corpus).is_ok());

    // Only checks that something was written; the contents are not parsed.
    let written_bytes = std::fs::metadata(out_path).unwrap().len();
    assert!(written_bytes > 0);

    // Best-effort cleanup.
    std::fs::remove_file(out_path).ok();
}
#[test]
fn test_stream_sentences() {
    use std::io::Write;

    let fixture = std::env::temp_dir().join("test_stream.txt");
    let fixture_path = fixture.to_str().unwrap();

    // Three one-sentence lines; the file is closed when `out` leaves this scope.
    {
        let mut out = std::fs::File::create(fixture_path).unwrap();
        writeln!(out, "the cat sat on the mat.").unwrap();
        writeln!(out, "the dog sat on the log.").unwrap();
        writeln!(out, "the cat chased the dog.").unwrap();
    }

    let loader = DataLoader::new(4, false);
    let streamed: Vec<Vec<String>> = loader.stream_sentences(fixture_path).unwrap().collect();

    // Something was produced, and no produced sentence is empty.
    assert!(!streamed.is_empty());
    assert!(streamed.iter().all(|sentence| !sentence.is_empty()));

    // Best-effort cleanup.
    std::fs::remove_file(fixture_path).ok();
}
#[test]
fn test_incremental_vocab_update() {
    let mut corpus = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, corpus.vocab.len());
    model.train(&corpus).unwrap();

    // Snapshot sizes before growing the vocabulary.
    let vocab_before = corpus.vocab.len();
    let rows_before = model.embeddings.nrows();

    let extra_words = vec!["elephant".to_string(), "giraffe".to_string()];
    let added = model.incremental_vocab_update(&extra_words, &mut corpus).unwrap();

    // Both words are new, so vocabulary and embedding matrix each grow by two.
    assert_eq!(added.len(), 2);
    assert_eq!(corpus.vocab.len(), vocab_before + 2);
    assert_eq!(model.embeddings.nrows(), rows_before + 2);
    assert_eq!(model.embeddings.ncols(), dim);

    // New words are addressable, and an existing word still is.
    for word in ["elephant", "giraffe", "cat"] {
        assert!(model.get_embedding(word, &corpus).is_some());
    }
}
#[test]
fn test_lsh_index() {
    let corpus = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, corpus.vocab.len());
    model.train(&corpus).unwrap();

    let mut index = LSHIndex::new(4, dim);
    index.build(&model, &corpus);

    // The result list may be empty; whatever is returned must exclude the query word.
    let neighbours = index.query("cat", &model, &corpus, 5);
    for (neighbour, _) in &neighbours {
        assert_ne!(neighbour, "cat");
    }
}
#[test]
fn test_save_onnx_format() {
    use prost::Message;

    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    let out_file = std::env::temp_dir().join("test_model.onnx");
    let out_path = out_file.to_str().unwrap();
    assert!(model.save_onnx_format(out_path, &corpus).is_ok());

    // The file must be larger than a trivial stub.
    let written_bytes = std::fs::metadata(out_path).unwrap().len();
    assert!(written_bytes > 50);

    // Decode the file back into a `ModelProto` and inspect its top-level fields.
    let raw = std::fs::read(out_path).unwrap();
    let parsed = onnx::ModelProto::decode(&raw[..]);
    assert!(parsed.is_ok());
    let proto = parsed.unwrap();
    assert_eq!(proto.ir_version, 9);
    assert_eq!(proto.producer_name, "embedding-trainer");

    // The graph is a single Gather node backed by one initializer.
    let graph = proto.graph.unwrap();
    assert_eq!(graph.name, "embedding_graph");
    assert_eq!(graph.node.len(), 1);
    assert_eq!(graph.node[0].op_type, "Gather");
    assert_eq!(graph.initializer.len(), 1);

    // Best-effort cleanup.
    std::fs::remove_file(out_path).ok();
}
#[test]
fn test_sentence_embedding() {
    let corpus = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, corpus.vocab.len());
    model.train(&corpus).unwrap();

    // A sentence of known words produces a vector of the configured width.
    let words = vec!["the".to_string(), "cat".to_string(), "sat".to_string()];
    let pooled = model.sentence_embedding(&words, &corpus);
    assert!(pooled.is_some());
    let pooled = pooled.unwrap();
    assert_eq!(pooled.len(), dim);

    // An empty sentence has no embedding.
    assert!(model.sentence_embedding(&[], &corpus).is_none());
}
#[test]
fn test_multimodal_fusion() {
    let text_vec = Array::from_vec(vec![1.0, 2.0, 3.0]);
    let aux_vec = Array::from_vec(vec![4.0, 5.0, 6.0]);
    let fusion = MultimodalFusion { text_dim: 3, aux_dim: 3 };

    // Concatenation keeps the text values first and the auxiliary values last.
    let joined = fusion.concatenate(&text_vec, &aux_vec);
    assert_eq!(joined.len(), 6);
    assert_eq!(joined[0], 1.0);
    assert_eq!(joined[5], 6.0);

    // Equal weighting: the first component is (1 + 4) / 2.
    let blended = fusion.weighted_average(&text_vec, &aux_vec, 0.5).unwrap();
    assert_eq!(blended.len(), 3);
    assert!((blended[0] - 2.5).abs() < 1e-6);

    // Mismatched lengths cannot be averaged.
    let too_short = Array::from_vec(vec![1.0, 2.0]);
    assert!(fusion.weighted_average(&text_vec, &too_short, 0.5).is_none());
}
#[test]
fn test_cross_lingual_aligner() {
    let mut aligner = CrossLingualAligner::new(4);

    // Before any training, aligning a vector returns it unchanged.
    let probe = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    let untouched = aligner.align(&probe);
    assert_eq!(untouched, probe);

    // Train on a single pair whose target doubles the first axis of the source.
    let source = Array::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
    let target = Array::from_vec(vec![2.0, 0.0, 0.0, 0.0]);
    aligner.train_from_dictionary(&[(source, target)], 100, 0.1);

    // The same source vector should now land close to the target.
    let unit_x = Array::from_vec(vec![1.0, 0.0, 0.0, 0.0]);
    let mapped = aligner.align(&unit_x);
    assert!((mapped[0] - 2.0).abs() < 0.1, "Expected ~2.0, got {}", mapped[0]);
}
#[test]
fn test_domain_adapter() {
    let mut corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    // The domain text reuses known words and adds "a", which the base corpus lacks.
    let domain_text = vec![
        vec!["the".to_string(), "cat".to_string()],
        vec!["a".to_string(), "dog".to_string()],
    ];
    let outcome = DomainAdapter::adapt(&mut model, &mut corpus, &domain_text, 1);

    assert!(outcome.is_ok());
    // Words known before adaptation are still in the vocabulary.
    assert!(corpus.vocab.contains_key("cat"));
}
#[test]
fn test_document_embedder() {
    let corpus = make_test_data();
    let config = test_config(ModelType::SkipGram);
    let dim = config.embedding_dim;
    let mut model = EmbeddingModel::new(config, corpus.vocab.len());
    model.train(&corpus).unwrap();

    // Two short sentences; "a" does not occur in the training corpus.
    let document = vec![
        vec!["the".to_string(), "cat".to_string()],
        vec!["a".to_string(), "dog".to_string()],
    ];
    let doc_vector = DocumentEmbedder::embed_document(&model, &corpus, &document);
    assert!(doc_vector.is_some());
    assert_eq!(doc_vector.unwrap().len(), dim);

    // A document with no sentences has no embedding.
    assert!(DocumentEmbedder::embed_document(&model, &corpus, &[]).is_none());
}
#[test]
fn test_zero_shot_transfer() {
    // Two orthogonal class prototypes.
    let mut prototypes = HashMap::new();
    prototypes.insert("class_a".to_string(), Array::from_vec(vec![1.0, 0.0, 0.0]));
    prototypes.insert("class_b".to_string(), Array::from_vec(vec![0.0, 1.0, 0.0]));

    // The query points almost entirely along the first prototype.
    let query = Array::from_vec(vec![0.9, 0.1, 0.0]);
    let verdict = ZeroShotTransfer::classify(&query, &prototypes);

    assert!(verdict.is_some());
    let (label, score) = verdict.unwrap();
    assert_eq!(label, "class_a");
    assert!(score > 0.9);
}
#[test]
fn test_query_expander() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    let terms = QueryExpander::expand(&model, &corpus, "cat", 3);

    assert!(!terms.is_empty());
    // The original query term comes first in the expansion.
    assert_eq!(terms[0], "cat");
}
#[test]
fn test_hierarchical_clustering() {
    let corpus = make_test_data();
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), corpus.vocab.len());
    model.train(&corpus).unwrap();

    let groups = HierarchicalClustering::cluster(&model, &corpus, 2);
    assert_eq!(groups.len(), 2);

    // The clusters must partition the vocabulary: no word appears twice
    // (`insert` returns false on a repeat) and none is missing.
    let mut seen = std::collections::HashSet::new();
    for word in groups.iter().flatten() {
        assert!(seen.insert(word.clone()));
    }
    assert_eq!(seen.len(), corpus.vocab.len());
}
#[test]
fn test_unicode_normalization() {
    // Every other transformation is switched off so that the assertion below can
    // only pass through Unicode normalization. `normalize_unicode` must therefore be
    // enabled: with it disabled the test would either fail or pass for reasons
    // unrelated to the feature named in the test.
    let processor = TextProcessor {
        lowercase: true,
        remove_punctuation: false,
        remove_numbers: false,
        remove_stop_words: false,
        remove_html: false,
        remove_urls: false,
        expand_contractions: false,
        normalize_unicode: true,
        language: "en".to_string(),
    };

    // Input ends in the decomposed form: "e" (U+0065) followed by a combining acute
    // accent (U+0301).
    let text = "caf\u{0065}\u{0301}";
    let sentences = processor.process_text(text);

    assert_eq!(sentences.len(), 1);
    assert_eq!(sentences[0].len(), 1);
    // Expected output uses the single precomposed code point U+00E9.
    assert_eq!(sentences[0][0], "caf\u{00e9}");
}
// End-to-end check on source-code input: load a code snippet, build a vocabulary,
// train a skip-gram model, and look up embeddings for identifier sub-words.
#[test]
fn test_code_embedding_pipeline() {
// The snippet is only ever treated as text, so it does not need to compile.
let code = r#"
fn computeEmbeddingVector(input: Vec<f32>) -> Vec<f32> {
let result = vec![];
for x in input {
result.push(x * 2.0);
}
result
}
"#;
let sentences = load_code_data(code);
assert!(!sentences.is_empty());
let (vocab, reverse_vocab) = build_vocab(&sentences);
// The camelCase identifier `computeEmbeddingVector` is expected to show up as the
// separate lowercase entries "compute", "embedding" and "vector".
assert!(vocab.contains_key("compute"));
assert!(vocab.contains_key("embedding"));
assert!(vocab.contains_key("vector"));
assert!(vocab.contains_key("result"));
let data = TrainingData { sentences, vocab, reverse_vocab };
let config = test_config(ModelType::SkipGram);
let mut model = EmbeddingModel::new(config, data.vocab.len());
// Training on the code-derived corpus must succeed and leave embeddings behind.
assert!(model.train(&data).is_ok());
assert!(model.get_embedding("embedding", &data).is_some());
assert!(model.get_embedding("vector", &data).is_some());
}
#[test]
fn test_western_language_embedding_pipeline() {
    // Two French sentences.
    let corpus_text = "Le chat noir dort sur le tapis. Le chien brun joue dans le jardin.";
    let sentences = load_text_data(corpus_text);
    assert!(!sentences.is_empty());

    let (vocab, reverse_vocab) = build_vocab(&sentences);
    // Each of these words must be present in the vocabulary.
    for word in ["chat", "noir", "dort", "chien", "jardin"] {
        assert!(vocab.contains_key(word));
    }

    let data = TrainingData { sentences, vocab, reverse_vocab };
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    assert!(model.train(&data).is_ok());

    for word in ["chat", "chien", "jardin"] {
        assert!(model.get_embedding(word, &data).is_some());
    }
}
#[test]
fn test_chinese_embedding_pipeline() {
    // Chinese text has no spaces; the assertions below expect single-character tokens.
    let corpus_text = "猫坐在垫子上。狗在花园里玩。猫追狗。";
    let sentences = load_text_data(corpus_text);
    assert!(!sentences.is_empty());

    let (vocab, reverse_vocab) = build_vocab(&sentences);
    for character in ["猫", "坐", "狗", "花", "追"] {
        assert!(vocab.contains_key(character));
    }

    let data = TrainingData { sentences, vocab, reverse_vocab };
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    assert!(model.train(&data).is_ok());

    for character in ["猫", "狗", "追"] {
        assert!(model.get_embedding(character, &data).is_some());
    }
}
#[test]
fn test_japanese_embedding_pipeline() {
    // Mixed kanji and kana; the assertions below check individual kanji only.
    let corpus_text = "猫はマットの上に座っています。犬は庭で遊んでいます。";
    let sentences = load_text_data(corpus_text);
    assert!(!sentences.is_empty());

    let (vocab, reverse_vocab) = build_vocab(&sentences);
    for kanji in ["猫", "座", "犬", "遊"] {
        assert!(vocab.contains_key(kanji));
    }

    let data = TrainingData { sentences, vocab, reverse_vocab };
    let mut model = EmbeddingModel::new(test_config(ModelType::SkipGram), data.vocab.len());
    assert!(model.train(&data).is_ok());

    for kanji in ["猫", "犬"] {
        assert!(model.get_embedding(kanji, &data).is_some());
    }
}
#[test]
fn test_subword_embedder() {
    let embedder = SubwordEmbedder::new(3, 5);

    // The expected n-grams carry "<" and ">" as word-boundary markers.
    let grams = embedder.ngrams("apple");
    assert!(!grams.is_empty());
    assert!(grams.contains(&"<ap".to_string()));
    assert!(grams.contains(&"ple>".to_string()));

    // A two-dimensional vector for a handful of n-grams.
    let mut gram_vectors = HashMap::new();
    let known_grams = [
        ("<ap", vec![1.0, 0.0]),
        ("app", vec![0.0, 1.0]),
        ("ppl", vec![1.0, 1.0]),
        ("ple>", vec![0.5, 0.5]),
    ];
    for (gram, values) in known_grams {
        gram_vectors.insert(gram.to_string(), Array::from_vec(values));
    }

    // The word embedding has the same width as the n-gram vectors.
    let composed = embedder.embed("apple", &gram_vectors);
    assert!(composed.is_some());
    let composed = composed.unwrap();
    assert_eq!(composed.len(), 2);
}
}