1use crate::{
11 embeddings::{EmbeddableContent, EmbeddingConfig, EmbeddingGenerator},
12 Vector,
13};
14use anyhow::{anyhow, Result};
15use scirs2_core::random::Random;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs::File;
19use std::io::{BufRead, BufReader};
20use std::path::Path;
21
/// Supported on-disk formats for pre-trained word vector models.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Word2VecFormat {
    /// word2vec text format: a `vocab_size dimensions` header line followed by
    /// one `word v1 v2 ...` line per vocabulary entry.
    Text,
    /// word2vec binary format: a text header line, then for each entry a
    /// space-terminated word followed by raw little-endian f32 values.
    Binary,
    /// GloVe text format: like `Text` but with no header line.
    GloVe,
}
32
/// Configuration for loading and using a pre-trained word2vec/GloVe model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Word2VecConfig {
    /// Path to the model file on disk; empty means "do not load at construction".
    pub model_path: String,
    /// On-disk format of the model file.
    pub format: Word2VecFormat,
    /// Expected embedding dimensionality; must match the model file's header when present.
    pub dimensions: usize,
    /// How per-word vectors are combined into one document vector.
    pub aggregation: AggregationMethod,
    /// Whether to build character n-gram (subword) embeddings after loading.
    pub use_subwords: bool,
    /// Minimum character n-gram length used for subwords.
    pub min_subword_len: usize,
    /// Maximum character n-gram length used for subwords.
    pub max_subword_len: usize,
    /// How to embed out-of-vocabulary words.
    pub oov_strategy: OovStrategy,
    /// Whether to L2-normalize the final document embedding.
    pub normalize: bool,
}
55
56impl Default for Word2VecConfig {
57 fn default() -> Self {
58 Self {
59 model_path: String::new(),
60 format: Word2VecFormat::Text,
61 dimensions: 300,
62 aggregation: AggregationMethod::Mean,
63 use_subwords: true,
64 min_subword_len: 3,
65 max_subword_len: 6,
66 oov_strategy: OovStrategy::Subword,
67 normalize: true,
68 }
69 }
70}
71
/// Strategies for combining word vectors into a single document vector.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum AggregationMethod {
    /// Element-wise arithmetic mean of all word vectors.
    Mean,
    /// Mean weighted by each word's relative frequency within the document.
    WeightedMean,
    /// Element-wise maximum over all word vectors.
    Max,
    /// Element-wise minimum over all word vectors.
    Min,
    /// Concatenation of mean and max, truncated back to the configured
    /// dimensionality (so in practice only the mean half survives).
    MeanMax,
    /// TF-IDF weighted mean; falls back to plain mean when no document
    /// frequencies have been set.
    TfIdfWeighted,
}
88
/// Fallback strategies for words missing from the loaded vocabulary.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum OovStrategy {
    /// Use an all-zero vector.
    Zero,
    /// Use a deterministic pseudo-random vector seeded from the word's hash.
    Random,
    /// Average the embeddings of the word's character n-grams
    /// (requires `use_subwords` to be enabled).
    Subword,
    /// Drop the word from the document entirely.
    Skip,
    /// Use a single learned OOV vector (the mean of all known embeddings).
    LearnedOov,
}
103
/// Embedding generator backed by pre-trained word2vec/GloVe word vectors.
pub struct Word2VecEmbeddingGenerator {
    // Word2vec-specific settings (model path, format, aggregation, OOV handling).
    config: Word2VecConfig,
    // Generic embedding settings exposed via `EmbeddingGenerator::config`.
    embedding_config: EmbeddingConfig,
    // Word -> vector lookup table loaded from the model file.
    embeddings: HashMap<String, Vec<f32>>,
    // Character n-gram -> averaged vector table, built when `use_subwords` is set.
    subword_embeddings: HashMap<String, Vec<f32>>,
    // Word -> IDF weight, used by TF-IDF aggregation.
    doc_frequencies: HashMap<String, f32>,
    // Mean-of-vocabulary vector used by the `LearnedOov` strategy.
    oov_embedding: Option<Vec<f32>>,
}
117
118impl Word2VecEmbeddingGenerator {
119 pub fn new(word2vec_config: Word2VecConfig, embedding_config: EmbeddingConfig) -> Result<Self> {
121 let mut generator = Self {
122 config: word2vec_config,
123 embedding_config,
124 embeddings: HashMap::new(),
125 subword_embeddings: HashMap::new(),
126 doc_frequencies: HashMap::new(),
127 oov_embedding: None,
128 };
129
130 let model_path = generator.config.model_path.clone();
132 if !model_path.is_empty() {
133 generator.load_model(&model_path)?;
134 }
135
136 Ok(generator)
137 }
138
139 pub fn load_model(&mut self, path: &str) -> Result<()> {
141 let path = Path::new(path);
142
143 if !path.exists() {
144 return Err(anyhow!("Model file not found: {}", path.display()));
145 }
146
147 match self.config.format {
148 Word2VecFormat::Text => self.load_text_format(path),
149 Word2VecFormat::Binary => self.load_binary_format(path),
150 Word2VecFormat::GloVe => self.load_glove_format(path),
151 }
152 }
153
154 fn load_text_format(&mut self, path: &Path) -> Result<()> {
156 let file = File::open(path)?;
157 let reader = BufReader::new(file);
158 let mut lines = reader.lines();
159
160 if let Some(Ok(header)) = lines.next() {
162 let parts: Vec<&str> = header.split_whitespace().collect();
163 if parts.len() == 2 {
164 let _vocab_size: usize = parts[0].parse()?;
165 let dimensions: usize = parts[1].parse()?;
166
167 if dimensions != self.config.dimensions {
168 return Err(anyhow!(
169 "Model dimensions ({}) don't match config ({})",
170 dimensions,
171 self.config.dimensions
172 ));
173 }
174 }
175 }
176
177 for line in lines {
179 let line = line?;
180 let parts: Vec<&str> = line.split_whitespace().collect();
181
182 if parts.len() < self.config.dimensions + 1 {
183 continue;
184 }
185
186 let word = parts[0].to_string();
187 let embedding: Result<Vec<f32>> = parts[1..=self.config.dimensions]
188 .iter()
189 .map(|s| s.parse::<f32>().map_err(Into::into))
190 .collect();
191
192 if let Ok(embedding) = embedding {
193 self.embeddings.insert(word, embedding);
194 }
195 }
196
197 if self.config.use_subwords {
199 self.generate_subword_embeddings()?;
200 }
201
202 if self.config.oov_strategy == OovStrategy::LearnedOov {
204 self.initialize_oov_embedding();
205 }
206
207 Ok(())
208 }
209
210 fn load_binary_format(&mut self, path: &Path) -> Result<()> {
212 use std::io::Read;
213
214 let mut file = File::open(path)?;
215 let mut buffer = Vec::new();
216 file.read_to_end(&mut buffer)?;
217
218 #[allow(unused_assignments)]
220 let mut pos = 0;
221
222 let header_end = buffer
224 .iter()
225 .position(|&b| b == b'\n')
226 .ok_or_else(|| anyhow!("Invalid binary format"))?;
227 let header = std::str::from_utf8(&buffer[..header_end])?;
228 let parts: Vec<&str> = header.split_whitespace().collect();
229
230 if parts.len() != 2 {
231 return Err(anyhow!("Invalid header format"));
232 }
233
234 let vocab_size: usize = parts[0].parse()?;
235 let dimensions: usize = parts[1].parse()?;
236
237 if dimensions != self.config.dimensions {
238 return Err(anyhow!(
239 "Model dimensions ({}) don't match config ({})",
240 dimensions,
241 self.config.dimensions
242 ));
243 }
244
245 pos = header_end + 1;
246
247 for _ in 0..vocab_size {
249 let word_start = pos;
251 while pos < buffer.len() && buffer[pos] != b' ' {
252 pos += 1;
253 }
254
255 if pos >= buffer.len() {
256 break;
257 }
258
259 let word = std::str::from_utf8(&buffer[word_start..pos])?.to_string();
260 pos += 1; let mut embedding = Vec::with_capacity(dimensions);
264 for _ in 0..dimensions {
265 if pos + 4 > buffer.len() {
266 break;
267 }
268
269 let bytes = [
270 buffer[pos],
271 buffer[pos + 1],
272 buffer[pos + 2],
273 buffer[pos + 3],
274 ];
275 let value = f32::from_le_bytes(bytes);
276 embedding.push(value);
277 pos += 4;
278 }
279
280 if embedding.len() == dimensions {
281 self.embeddings.insert(word, embedding);
282 }
283
284 if pos < buffer.len() && buffer[pos] == b'\n' {
286 pos += 1;
287 }
288 }
289
290 if self.config.use_subwords {
292 self.generate_subword_embeddings()?;
293 }
294
295 Ok(())
296 }
297
298 fn load_glove_format(&mut self, path: &Path) -> Result<()> {
300 let file = File::open(path)?;
301 let reader = BufReader::new(file);
302
303 for line in reader.lines() {
304 let line = line?;
305 let parts: Vec<&str> = line.split_whitespace().collect();
306
307 if parts.len() < self.config.dimensions + 1 {
308 continue;
309 }
310
311 let word = parts[0].to_string();
312 let embedding: Result<Vec<f32>> = parts[1..=self.config.dimensions]
313 .iter()
314 .map(|s| s.parse::<f32>().map_err(Into::into))
315 .collect();
316
317 if let Ok(embedding) = embedding {
318 self.embeddings.insert(word, embedding);
319 }
320 }
321
322 if self.config.use_subwords {
324 self.generate_subword_embeddings()?;
325 }
326
327 Ok(())
328 }
329
330 fn generate_subword_embeddings(&mut self) -> Result<()> {
332 let mut subword_counts: HashMap<String, usize> = HashMap::new();
333 let mut subword_sums: HashMap<String, Vec<f32>> = HashMap::new();
334
335 for (word, embedding) in &self.embeddings {
337 let subwords = self.get_subwords(word);
338
339 for subword in subwords {
340 *subword_counts.entry(subword.clone()).or_insert(0) += 1;
341
342 let sum = subword_sums
343 .entry(subword)
344 .or_insert_with(|| vec![0.0; self.config.dimensions]);
345 for (i, val) in embedding.iter().enumerate() {
346 sum[i] += val;
347 }
348 }
349 }
350
351 for (subword, count) in subword_counts {
353 if let Some(sum) = subword_sums.get(&subword) {
354 let avg: Vec<f32> = sum.iter().map(|&s| s / count as f32).collect();
355 self.subword_embeddings.insert(subword, avg);
356 }
357 }
358
359 Ok(())
360 }
361
362 fn get_subwords(&self, word: &str) -> Vec<String> {
364 let mut subwords = Vec::new();
365 let chars: Vec<char> = word.chars().collect();
366
367 for len in self.config.min_subword_len..=self.config.max_subword_len.min(chars.len()) {
368 for start in 0..=chars.len().saturating_sub(len) {
369 let subword: String = chars[start..start + len].iter().collect();
370 subwords.push(format!("<{subword}>")); }
372 }
373
374 subwords
375 }
376
377 fn initialize_oov_embedding(&mut self) {
379 let mut sum = vec![0.0; self.config.dimensions];
381 let count = self.embeddings.len() as f32;
382
383 for embedding in self.embeddings.values() {
384 for (i, val) in embedding.iter().enumerate() {
385 sum[i] += val;
386 }
387 }
388
389 self.oov_embedding = Some(sum.iter().map(|&s| s / count).collect());
390 }
391
392 fn get_word_embedding(&self, word: &str) -> Option<Vec<f32>> {
394 if let Some(embedding) = self.embeddings.get(word) {
396 return Some(embedding.clone());
397 }
398
399 if let Some(embedding) = self.embeddings.get(&word.to_lowercase()) {
401 return Some(embedding.clone());
402 }
403
404 match self.config.oov_strategy {
406 OovStrategy::Zero => Some(vec![0.0; self.config.dimensions]),
407 OovStrategy::Random => {
408 let mut hasher = std::collections::hash_map::DefaultHasher::new();
410 std::hash::Hash::hash(&word, &mut hasher);
411 let hash = std::hash::Hasher::finish(&hasher);
412
413 let mut rng = Random::seed(hash);
414 use scirs2_core::random::Random as RngTrait;
415
416 Some(
417 (0..self.config.dimensions)
418 .map(|_| rng.gen_range(-0.1..0.1))
419 .collect(),
420 )
421 }
422 OovStrategy::Subword => {
423 if self.config.use_subwords {
424 self.get_subword_embedding(word)
425 } else {
426 None
427 }
428 }
429 OovStrategy::Skip => None,
430 OovStrategy::LearnedOov => self.oov_embedding.clone(),
431 }
432 }
433
434 fn get_subword_embedding(&self, word: &str) -> Option<Vec<f32>> {
436 let subwords = self.get_subwords(word);
437 let mut sum = vec![0.0; self.config.dimensions];
438 let mut count = 0;
439
440 for subword in subwords {
441 if let Some(embedding) = self.subword_embeddings.get(&subword) {
442 for (i, val) in embedding.iter().enumerate() {
443 sum[i] += val;
444 }
445 count += 1;
446 }
447 }
448
449 if count > 0 {
450 Some(sum.iter().map(|&s| s / count as f32).collect())
451 } else {
452 None
453 }
454 }
455
456 fn tokenize(&self, text: &str) -> Vec<String> {
458 text.to_lowercase()
459 .split_whitespace()
460 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
461 .filter(|s| !s.is_empty())
462 .map(String::from)
463 .collect()
464 }
465
466 fn aggregate_embeddings(&self, word_embeddings: &[(String, Vec<f32>)]) -> Vec<f32> {
468 if word_embeddings.is_empty() {
469 return vec![0.0; self.config.dimensions];
470 }
471
472 match self.config.aggregation {
473 AggregationMethod::Mean => {
474 let mut sum = vec![0.0; self.config.dimensions];
475
476 for (_, embedding) in word_embeddings {
477 for (i, val) in embedding.iter().enumerate() {
478 sum[i] += val;
479 }
480 }
481
482 let count = word_embeddings.len() as f32;
483 sum.iter().map(|&s| s / count).collect()
484 }
485 AggregationMethod::WeightedMean => {
486 let mut word_counts: HashMap<String, usize> = HashMap::new();
488 for (word, _) in word_embeddings {
489 *word_counts.entry(word.clone()).or_insert(0) += 1;
490 }
491
492 let total_words = word_embeddings.len() as f32;
493 let mut weighted_sum = vec![0.0; self.config.dimensions];
494
495 for (word, embedding) in word_embeddings {
496 let weight = word_counts[word] as f32 / total_words;
497 for (i, val) in embedding.iter().enumerate() {
498 weighted_sum[i] += val * weight;
499 }
500 }
501
502 weighted_sum
503 }
504 AggregationMethod::Max => {
505 let mut max_vals = vec![f32::NEG_INFINITY; self.config.dimensions];
506
507 for (_, embedding) in word_embeddings {
508 for (i, val) in embedding.iter().enumerate() {
509 max_vals[i] = max_vals[i].max(*val);
510 }
511 }
512
513 max_vals
514 }
515 AggregationMethod::Min => {
516 let mut min_vals = vec![f32::INFINITY; self.config.dimensions];
517
518 for (_, embedding) in word_embeddings {
519 for (i, val) in embedding.iter().enumerate() {
520 min_vals[i] = min_vals[i].min(*val);
521 }
522 }
523
524 min_vals
525 }
526 AggregationMethod::MeanMax => {
527 let mean =
529 self.aggregate_embeddings_with_method(word_embeddings, AggregationMethod::Mean);
530 let max =
531 self.aggregate_embeddings_with_method(word_embeddings, AggregationMethod::Max);
532
533 let mut result = Vec::with_capacity(self.config.dimensions * 2);
534 result.extend(mean);
535 result.extend(max);
536
537 result.resize(self.config.dimensions, 0.0);
539 result
540 }
541 AggregationMethod::TfIdfWeighted => {
542 if self.doc_frequencies.is_empty() {
544 return self.aggregate_embeddings_with_method(
546 word_embeddings,
547 AggregationMethod::Mean,
548 );
549 }
550
551 let mut weighted_sum = vec![0.0; self.config.dimensions];
552 let mut total_weight = 0.0;
553
554 for (word, embedding) in word_embeddings {
555 let tf = word_embeddings.iter().filter(|(w, _)| w == word).count() as f32
556 / word_embeddings.len() as f32;
557 let idf = self.doc_frequencies.get(word).unwrap_or(&1.0);
558 let weight = tf * idf;
559
560 for (i, val) in embedding.iter().enumerate() {
561 weighted_sum[i] += val * weight;
562 }
563 total_weight += weight;
564 }
565
566 if total_weight > 0.0 {
567 weighted_sum.iter().map(|&s| s / total_weight).collect()
568 } else {
569 weighted_sum
570 }
571 }
572 }
573 }
574
575 fn aggregate_embeddings_with_method(
577 &self,
578 word_embeddings: &[(String, Vec<f32>)],
579 method: AggregationMethod,
580 ) -> Vec<f32> {
581 let _original_method = self.config.aggregation;
582 let mut config_clone = self.config.clone();
583 config_clone.aggregation = method;
584
585 let temp_self = Self {
586 config: config_clone,
587 embedding_config: self.embedding_config.clone(),
588 embeddings: self.embeddings.clone(),
589 subword_embeddings: self.subword_embeddings.clone(),
590 doc_frequencies: self.doc_frequencies.clone(),
591 oov_embedding: self.oov_embedding.clone(),
592 };
593
594 temp_self.aggregate_embeddings(word_embeddings)
595 }
596
597 pub fn set_document_frequencies(&mut self, frequencies: HashMap<String, f32>) {
599 self.doc_frequencies = frequencies;
600 }
601
602 pub fn calculate_document_frequencies(&mut self, documents: &[String]) -> Result<()> {
604 let total_docs = documents.len() as f32;
605 let mut doc_counts: HashMap<String, usize> = HashMap::new();
606
607 for doc in documents {
608 let words = self.tokenize(doc);
609 let unique_words: std::collections::HashSet<_> = words.into_iter().collect();
610
611 for word in unique_words {
612 *doc_counts.entry(word).or_insert(0) += 1;
613 }
614 }
615
616 self.doc_frequencies = doc_counts
618 .into_iter()
619 .map(|(word, count)| {
620 let idf = (total_docs / (count as f32 + 1.0)).ln();
621 (word, idf)
622 })
623 .collect();
624
625 Ok(())
626 }
627}
628
629impl EmbeddingGenerator for Word2VecEmbeddingGenerator {
630 fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
631 let text = content.to_text();
632 let words = self.tokenize(&text);
633
634 let mut word_embeddings = Vec::new();
636
637 for word in words {
638 if let Some(embedding) = self.get_word_embedding(&word) {
639 word_embeddings.push((word, embedding));
640 }
641 }
642
643 if word_embeddings.is_empty() {
644 return Ok(Vector::new(vec![0.0; self.config.dimensions]));
645 }
646
647 let mut document_embedding = self.aggregate_embeddings(&word_embeddings);
649
650 if self.config.normalize {
652 use oxirs_core::simd::SimdOps;
653 let norm = f32::norm(&document_embedding);
654 if norm > 0.0 {
655 for val in &mut document_embedding {
656 *val /= norm;
657 }
658 }
659 }
660
661 Ok(Vector::new(document_embedding))
662 }
663
664 fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
665 contents.iter().map(|c| self.generate(c)).collect()
668 }
669
670 fn dimensions(&self) -> usize {
671 self.config.dimensions
672 }
673
674 fn config(&self) -> &EmbeddingConfig {
675 &self.embedding_config
676 }
677}
678
impl crate::embeddings::AsAny for Word2VecEmbeddingGenerator {
    /// Upcasts to `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Mutable variant of `as_any`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
688
#[cfg(test)]
mod tests {
    use super::*;

    // End-to-end smoke test: vocabulary injected directly (no model file);
    // generate() should produce a vector of the configured dimensionality.
    #[test]
    fn test_word2vec_generator() {
        let config = Word2VecConfig {
            dimensions: 100,
            ..Default::default()
        };

        let embedding_config = EmbeddingConfig {
            model_name: "word2vec-test".to_string(),
            dimensions: 100,
            max_sequence_length: 512,
            normalize: true,
        };

        let mut generator = Word2VecEmbeddingGenerator::new(config, embedding_config).unwrap();

        // Bypass file loading by inserting vocabulary entries directly.
        generator
            .embeddings
            .insert("hello".to_string(), vec![0.1; 100]);
        generator
            .embeddings
            .insert("world".to_string(), vec![0.2; 100]);

        let content = EmbeddableContent::Text("hello world".to_string());
        let embedding = generator.generate(&content).unwrap();

        assert_eq!(embedding.dimensions, 100);
    }

    // With the default min_subword_len of 3, "hello" must yield every
    // marker-wrapped trigram.
    #[test]
    fn test_subword_generation() {
        let config = Word2VecConfig::default();
        let generator =
            Word2VecEmbeddingGenerator::new(config, EmbeddingConfig::default()).unwrap();

        let subwords = generator.get_subwords("hello");
        assert!(subwords.contains(&"<hel>".to_string()));
        assert!(subwords.contains(&"<ell>".to_string()));
        assert!(subwords.contains(&"<llo>".to_string()));
    }

    // Mean/Max/Min aggregation over two known 3-d vectors has exact expected
    // output; normalization is disabled so values are directly comparable.
    #[test]
    fn test_aggregation_methods() {
        let mut config = Word2VecConfig {
            dimensions: 3,
            normalize: false,
            ..Default::default()
        };

        let embedding_config = EmbeddingConfig {
            model_name: "test".to_string(),
            dimensions: 3,
            max_sequence_length: 512,
            normalize: false,
        };

        for method in [
            AggregationMethod::Mean,
            AggregationMethod::Max,
            AggregationMethod::Min,
        ] {
            config.aggregation = method;
            let mut generator =
                Word2VecEmbeddingGenerator::new(config.clone(), embedding_config.clone()).unwrap();

            generator
                .embeddings
                .insert("a".to_string(), vec![1.0, 2.0, 3.0]);
            generator
                .embeddings
                .insert("b".to_string(), vec![4.0, 5.0, 6.0]);

            let content = EmbeddableContent::Text("a b".to_string());
            let embedding = generator.generate(&content).unwrap();

            match method {
                AggregationMethod::Mean => {
                    assert_eq!(embedding.as_f32(), vec![2.5, 3.5, 4.5]);
                }
                AggregationMethod::Max => {
                    assert_eq!(embedding.as_f32(), vec![4.0, 5.0, 6.0]);
                }
                AggregationMethod::Min => {
                    assert_eq!(embedding.as_f32(), vec![1.0, 2.0, 3.0]);
                }
                _ => {}
            }
        }
    }
}