1use crate::{
11 embeddings::{EmbeddableContent, EmbeddingConfig, EmbeddingGenerator},
12 Vector,
13};
14use anyhow::{anyhow, Result};
15use scirs2_core::random::Random;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs::File;
19use std::io::{BufRead, BufReader};
20use std::path::Path;
21
/// Supported on-disk formats for pretrained word-vector models.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Word2VecFormat {
    /// word2vec text format: a "vocab_size dimensions" header line, then one
    /// "word v1 … vN" line per vocabulary entry.
    Text,
    /// Original word2vec binary format: text header, then words followed by
    /// raw little-endian f32 vectors.
    Binary,
    /// GloVe text format: no header, one "word v1 … vN" line per entry.
    GloVe,
}
32
/// Configuration for [`Word2VecEmbeddingGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Word2VecConfig {
    /// Path to the pretrained model file; empty means "load later".
    pub model_path: String,
    /// On-disk format of the model file.
    pub format: Word2VecFormat,
    /// Expected embedding dimensionality; validated against file headers.
    pub dimensions: usize,
    /// How per-word vectors are combined into one document vector.
    pub aggregation: AggregationMethod,
    /// Whether to build character n-gram (subword) embeddings after loading.
    pub use_subwords: bool,
    /// Minimum character n-gram length for subwords.
    pub min_subword_len: usize,
    /// Maximum character n-gram length for subwords.
    pub max_subword_len: usize,
    /// How out-of-vocabulary words are handled.
    pub oov_strategy: OovStrategy,
    /// Whether the final document embedding is normalized.
    pub normalize: bool,
}
55
impl Default for Word2VecConfig {
    /// Defaults: 300-d text-format model, mean aggregation, subword fallback
    /// for OOV words (n-grams of length 3..=6), normalized output.
    fn default() -> Self {
        Self {
            model_path: String::new(),
            format: Word2VecFormat::Text,
            dimensions: 300,
            aggregation: AggregationMethod::Mean,
            use_subwords: true,
            min_subword_len: 3,
            max_subword_len: 6,
            oov_strategy: OovStrategy::Subword,
            normalize: true,
        }
    }
}
71
/// Strategy for combining per-word vectors into a single document vector.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum AggregationMethod {
    /// Component-wise arithmetic mean.
    Mean,
    /// Mean weighted by each word's in-document frequency.
    WeightedMean,
    /// Component-wise maximum.
    Max,
    /// Component-wise minimum.
    Min,
    /// Concatenation of mean and max (truncated to `dimensions` — see
    /// `aggregate_embeddings`).
    MeanMax,
    /// TF-IDF weighted mean; requires a document-frequency table.
    TfIdfWeighted,
}
88
/// How to produce a vector for a word absent from the vocabulary.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum OovStrategy {
    /// Use an all-zero vector.
    Zero,
    /// Use a deterministic pseudo-random vector seeded by the word's hash.
    Random,
    /// Average the embeddings of the word's known character n-grams.
    Subword,
    /// Drop the word entirely.
    Skip,
    /// Use a learned OOV vector (mean of all known embeddings).
    LearnedOov,
}
103
/// Embedding generator backed by pretrained word vectors (word2vec / GloVe).
pub struct Word2VecEmbeddingGenerator {
    /// Word2vec-specific settings (model path, format, aggregation, OOV handling).
    config: Word2VecConfig,
    /// Generic embedding settings exposed through `EmbeddingGenerator::config`.
    embedding_config: EmbeddingConfig,
    /// word -> embedding vector.
    embeddings: HashMap<String, Vec<f32>>,
    /// character n-gram -> averaged embedding; built when `use_subwords` is set.
    subword_embeddings: HashMap<String, Vec<f32>>,
    /// word -> IDF weight; used by `AggregationMethod::TfIdfWeighted`.
    doc_frequencies: HashMap<String, f32>,
    /// Mean of all known embeddings; used by `OovStrategy::LearnedOov`.
    oov_embedding: Option<Vec<f32>>,
}
117
118impl Word2VecEmbeddingGenerator {
119 pub fn new(word2vec_config: Word2VecConfig, embedding_config: EmbeddingConfig) -> Result<Self> {
121 let mut generator = Self {
122 config: word2vec_config,
123 embedding_config,
124 embeddings: HashMap::new(),
125 subword_embeddings: HashMap::new(),
126 doc_frequencies: HashMap::new(),
127 oov_embedding: None,
128 };
129
130 let model_path = generator.config.model_path.clone();
132 if !model_path.is_empty() {
133 generator.load_model(&model_path)?;
134 }
135
136 Ok(generator)
137 }
138
139 pub fn load_model(&mut self, path: &str) -> Result<()> {
141 let path = Path::new(path);
142
143 if !path.exists() {
144 return Err(anyhow!("Model file not found: {}", path.display()));
145 }
146
147 match self.config.format {
148 Word2VecFormat::Text => self.load_text_format(path),
149 Word2VecFormat::Binary => self.load_binary_format(path),
150 Word2VecFormat::GloVe => self.load_glove_format(path),
151 }
152 }
153
154 fn load_text_format(&mut self, path: &Path) -> Result<()> {
156 let file = File::open(path)?;
157 let reader = BufReader::new(file);
158 let mut lines = reader.lines();
159
160 if let Some(Ok(header)) = lines.next() {
162 let parts: Vec<&str> = header.split_whitespace().collect();
163 if parts.len() == 2 {
164 let _vocab_size: usize = parts[0].parse()?;
165 let dimensions: usize = parts[1].parse()?;
166
167 if dimensions != self.config.dimensions {
168 return Err(anyhow!(
169 "Model dimensions ({}) don't match config ({})",
170 dimensions,
171 self.config.dimensions
172 ));
173 }
174 }
175 }
176
177 for line in lines {
179 let line = line?;
180 let parts: Vec<&str> = line.split_whitespace().collect();
181
182 if parts.len() < self.config.dimensions + 1 {
183 continue;
184 }
185
186 let word = parts[0].to_string();
187 let embedding: Result<Vec<f32>> = parts[1..=self.config.dimensions]
188 .iter()
189 .map(|s| s.parse::<f32>().map_err(Into::into))
190 .collect();
191
192 if let Ok(embedding) = embedding {
193 self.embeddings.insert(word, embedding);
194 }
195 }
196
197 if self.config.use_subwords {
199 self.generate_subword_embeddings()?;
200 }
201
202 if self.config.oov_strategy == OovStrategy::LearnedOov {
204 self.initialize_oov_embedding();
205 }
206
207 Ok(())
208 }
209
210 fn load_binary_format(&mut self, path: &Path) -> Result<()> {
212 use std::io::Read;
213
214 let mut file = File::open(path)?;
215 let mut buffer = Vec::new();
216 file.read_to_end(&mut buffer)?;
217
218 #[allow(unused_assignments)]
220 let mut pos = 0;
221
222 let header_end = buffer
224 .iter()
225 .position(|&b| b == b'\n')
226 .ok_or_else(|| anyhow!("Invalid binary format"))?;
227 let header = std::str::from_utf8(&buffer[..header_end])?;
228 let parts: Vec<&str> = header.split_whitespace().collect();
229
230 if parts.len() != 2 {
231 return Err(anyhow!("Invalid header format"));
232 }
233
234 let vocab_size: usize = parts[0].parse()?;
235 let dimensions: usize = parts[1].parse()?;
236
237 if dimensions != self.config.dimensions {
238 return Err(anyhow!(
239 "Model dimensions ({}) don't match config ({})",
240 dimensions,
241 self.config.dimensions
242 ));
243 }
244
245 pos = header_end + 1;
246
247 for _ in 0..vocab_size {
249 let word_start = pos;
251 while pos < buffer.len() && buffer[pos] != b' ' {
252 pos += 1;
253 }
254
255 if pos >= buffer.len() {
256 break;
257 }
258
259 let word = std::str::from_utf8(&buffer[word_start..pos])?.to_string();
260 pos += 1; let mut embedding = Vec::with_capacity(dimensions);
264 for _ in 0..dimensions {
265 if pos + 4 > buffer.len() {
266 break;
267 }
268
269 let bytes = [
270 buffer[pos],
271 buffer[pos + 1],
272 buffer[pos + 2],
273 buffer[pos + 3],
274 ];
275 let value = f32::from_le_bytes(bytes);
276 embedding.push(value);
277 pos += 4;
278 }
279
280 if embedding.len() == dimensions {
281 self.embeddings.insert(word, embedding);
282 }
283
284 if pos < buffer.len() && buffer[pos] == b'\n' {
286 pos += 1;
287 }
288 }
289
290 if self.config.use_subwords {
292 self.generate_subword_embeddings()?;
293 }
294
295 Ok(())
296 }
297
298 fn load_glove_format(&mut self, path: &Path) -> Result<()> {
300 let file = File::open(path)?;
301 let reader = BufReader::new(file);
302
303 for line in reader.lines() {
304 let line = line?;
305 let parts: Vec<&str> = line.split_whitespace().collect();
306
307 if parts.len() < self.config.dimensions + 1 {
308 continue;
309 }
310
311 let word = parts[0].to_string();
312 let embedding: Result<Vec<f32>> = parts[1..=self.config.dimensions]
313 .iter()
314 .map(|s| s.parse::<f32>().map_err(Into::into))
315 .collect();
316
317 if let Ok(embedding) = embedding {
318 self.embeddings.insert(word, embedding);
319 }
320 }
321
322 if self.config.use_subwords {
324 self.generate_subword_embeddings()?;
325 }
326
327 Ok(())
328 }
329
330 fn generate_subword_embeddings(&mut self) -> Result<()> {
332 let mut subword_counts: HashMap<String, usize> = HashMap::new();
333 let mut subword_sums: HashMap<String, Vec<f32>> = HashMap::new();
334
335 for (word, embedding) in &self.embeddings {
337 let subwords = self.get_subwords(word);
338
339 for subword in subwords {
340 *subword_counts.entry(subword.clone()).or_insert(0) += 1;
341
342 let sum = subword_sums
343 .entry(subword)
344 .or_insert_with(|| vec![0.0; self.config.dimensions]);
345 for (i, val) in embedding.iter().enumerate() {
346 sum[i] += val;
347 }
348 }
349 }
350
351 for (subword, count) in subword_counts {
353 if let Some(sum) = subword_sums.get(&subword) {
354 let avg: Vec<f32> = sum.iter().map(|&s| s / count as f32).collect();
355 self.subword_embeddings.insert(subword, avg);
356 }
357 }
358
359 Ok(())
360 }
361
362 fn get_subwords(&self, word: &str) -> Vec<String> {
364 let mut subwords = Vec::new();
365 let chars: Vec<char> = word.chars().collect();
366
367 for len in self.config.min_subword_len..=self.config.max_subword_len.min(chars.len()) {
368 for start in 0..=chars.len().saturating_sub(len) {
369 let subword: String = chars[start..start + len].iter().collect();
370 subwords.push(format!("<{subword}>")); }
372 }
373
374 subwords
375 }
376
377 fn initialize_oov_embedding(&mut self) {
379 let mut sum = vec![0.0; self.config.dimensions];
381 let count = self.embeddings.len() as f32;
382
383 for embedding in self.embeddings.values() {
384 for (i, val) in embedding.iter().enumerate() {
385 sum[i] += val;
386 }
387 }
388
389 self.oov_embedding = Some(sum.iter().map(|&s| s / count).collect());
390 }
391
392 fn get_word_embedding(&self, word: &str) -> Option<Vec<f32>> {
394 if let Some(embedding) = self.embeddings.get(word) {
396 return Some(embedding.clone());
397 }
398
399 if let Some(embedding) = self.embeddings.get(&word.to_lowercase()) {
401 return Some(embedding.clone());
402 }
403
404 match self.config.oov_strategy {
406 OovStrategy::Zero => Some(vec![0.0; self.config.dimensions]),
407 OovStrategy::Random => {
408 let mut hasher = std::collections::hash_map::DefaultHasher::new();
410 std::hash::Hash::hash(&word, &mut hasher);
411 let hash = std::hash::Hasher::finish(&hasher);
412
413 let mut rng = Random::seed(hash);
414
415 Some(
416 (0..self.config.dimensions)
417 .map(|_| rng.gen_range(-0.1..0.1))
418 .collect(),
419 )
420 }
421 OovStrategy::Subword => {
422 if self.config.use_subwords {
423 self.get_subword_embedding(word)
424 } else {
425 None
426 }
427 }
428 OovStrategy::Skip => None,
429 OovStrategy::LearnedOov => self.oov_embedding.clone(),
430 }
431 }
432
433 fn get_subword_embedding(&self, word: &str) -> Option<Vec<f32>> {
435 let subwords = self.get_subwords(word);
436 let mut sum = vec![0.0; self.config.dimensions];
437 let mut count = 0;
438
439 for subword in subwords {
440 if let Some(embedding) = self.subword_embeddings.get(&subword) {
441 for (i, val) in embedding.iter().enumerate() {
442 sum[i] += val;
443 }
444 count += 1;
445 }
446 }
447
448 if count > 0 {
449 Some(sum.iter().map(|&s| s / count as f32).collect())
450 } else {
451 None
452 }
453 }
454
455 fn tokenize(&self, text: &str) -> Vec<String> {
457 text.to_lowercase()
458 .split_whitespace()
459 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
460 .filter(|s| !s.is_empty())
461 .map(String::from)
462 .collect()
463 }
464
    /// Combines per-word embeddings into one document vector using the
    /// configured `AggregationMethod`. Returns a zero vector for empty input.
    fn aggregate_embeddings(&self, word_embeddings: &[(String, Vec<f32>)]) -> Vec<f32> {
        if word_embeddings.is_empty() {
            return vec![0.0; self.config.dimensions];
        }

        match self.config.aggregation {
            AggregationMethod::Mean => {
                // Component-wise arithmetic mean over all word vectors.
                let mut sum = vec![0.0; self.config.dimensions];

                for (_, embedding) in word_embeddings {
                    for (i, val) in embedding.iter().enumerate() {
                        sum[i] += val;
                    }
                }

                let count = word_embeddings.len() as f32;
                sum.iter().map(|&s| s / count).collect()
            }
            AggregationMethod::WeightedMean => {
                // Each occurrence is weighted by its word's relative frequency
                // in the document, so repeated words contribute quadratically.
                let mut word_counts: HashMap<String, usize> = HashMap::new();
                for (word, _) in word_embeddings {
                    *word_counts.entry(word.clone()).or_insert(0) += 1;
                }

                let total_words = word_embeddings.len() as f32;
                let mut weighted_sum = vec![0.0; self.config.dimensions];

                for (word, embedding) in word_embeddings {
                    let weight = word_counts[word] as f32 / total_words;
                    for (i, val) in embedding.iter().enumerate() {
                        weighted_sum[i] += val * weight;
                    }
                }

                weighted_sum
            }
            AggregationMethod::Max => {
                // Component-wise maximum.
                let mut max_vals = vec![f32::NEG_INFINITY; self.config.dimensions];

                for (_, embedding) in word_embeddings {
                    for (i, val) in embedding.iter().enumerate() {
                        max_vals[i] = max_vals[i].max(*val);
                    }
                }

                max_vals
            }
            AggregationMethod::Min => {
                // Component-wise minimum.
                let mut min_vals = vec![f32::INFINITY; self.config.dimensions];

                for (_, embedding) in word_embeddings {
                    for (i, val) in embedding.iter().enumerate() {
                        min_vals[i] = min_vals[i].min(*val);
                    }
                }

                min_vals
            }
            AggregationMethod::MeanMax => {
                let mean =
                    self.aggregate_embeddings_with_method(word_embeddings, AggregationMethod::Mean);
                let max =
                    self.aggregate_embeddings_with_method(word_embeddings, AggregationMethod::Max);

                let mut result = Vec::with_capacity(self.config.dimensions * 2);
                result.extend(mean);
                result.extend(max);

                // NOTE(review): this resize truncates the mean||max concatenation
                // back to `dimensions`, so MeanMax currently returns only the mean
                // part. Kept as-is because `dimensions()` promises this output size;
                // a true concat would need `dimensions * 2` everywhere.
                result.resize(self.config.dimensions, 0.0);
                result
            }
            AggregationMethod::TfIdfWeighted => {
                // Without an IDF table, fall back to a plain mean.
                if self.doc_frequencies.is_empty() {
                    return self.aggregate_embeddings_with_method(
                        word_embeddings,
                        AggregationMethod::Mean,
                    );
                }

                let mut weighted_sum = vec![0.0; self.config.dimensions];
                let mut total_weight = 0.0;

                for (word, embedding) in word_embeddings {
                    // Term frequency recomputed per occurrence: O(n^2) in document
                    // length — acceptable for short documents.
                    let tf = word_embeddings.iter().filter(|(w, _)| w == word).count() as f32
                        / word_embeddings.len() as f32;
                    // Unknown words default to an IDF of 1.0.
                    let idf = self.doc_frequencies.get(word).unwrap_or(&1.0);
                    let weight = tf * idf;

                    for (i, val) in embedding.iter().enumerate() {
                        weighted_sum[i] += val * weight;
                    }
                    total_weight += weight;
                }

                if total_weight > 0.0 {
                    weighted_sum.iter().map(|&s| s / total_weight).collect()
                } else {
                    weighted_sum
                }
            }
        }
    }
573
574 fn aggregate_embeddings_with_method(
576 &self,
577 word_embeddings: &[(String, Vec<f32>)],
578 method: AggregationMethod,
579 ) -> Vec<f32> {
580 let _original_method = self.config.aggregation;
581 let mut config_clone = self.config.clone();
582 config_clone.aggregation = method;
583
584 let temp_self = Self {
585 config: config_clone,
586 embedding_config: self.embedding_config.clone(),
587 embeddings: self.embeddings.clone(),
588 subword_embeddings: self.subword_embeddings.clone(),
589 doc_frequencies: self.doc_frequencies.clone(),
590 oov_embedding: self.oov_embedding.clone(),
591 };
592
593 temp_self.aggregate_embeddings(word_embeddings)
594 }
595
    /// Replaces the stored IDF table used by `AggregationMethod::TfIdfWeighted`.
    /// Keys are words; values are their IDF weights.
    pub fn set_document_frequencies(&mut self, frequencies: HashMap<String, f32>) {
        self.doc_frequencies = frequencies;
    }
600
601 pub fn calculate_document_frequencies(&mut self, documents: &[String]) -> Result<()> {
603 let total_docs = documents.len() as f32;
604 let mut doc_counts: HashMap<String, usize> = HashMap::new();
605
606 for doc in documents {
607 let words = self.tokenize(doc);
608 let unique_words: std::collections::HashSet<_> = words.into_iter().collect();
609
610 for word in unique_words {
611 *doc_counts.entry(word).or_insert(0) += 1;
612 }
613 }
614
615 self.doc_frequencies = doc_counts
617 .into_iter()
618 .map(|(word, count)| {
619 let idf = (total_docs / (count as f32 + 1.0)).ln();
620 (word, idf)
621 })
622 .collect();
623
624 Ok(())
625 }
626}
627
impl EmbeddingGenerator for Word2VecEmbeddingGenerator {
    /// Embeds `content`: tokenize, look up (or derive via the OOV strategy) a
    /// vector per word, aggregate, and optionally normalize.
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();
        let words = self.tokenize(&text);

        let mut word_embeddings = Vec::new();

        for word in words {
            // Words for which the OOV strategy yields nothing are dropped.
            if let Some(embedding) = self.get_word_embedding(&word) {
                word_embeddings.push((word, embedding));
            }
        }

        // No known words at all: return an all-zero vector of the right size.
        if word_embeddings.is_empty() {
            return Ok(Vector::new(vec![0.0; self.config.dimensions]));
        }

        let mut document_embedding = self.aggregate_embeddings(&word_embeddings);

        if self.config.normalize {
            use oxirs_core::simd::SimdOps;
            // NOTE(review): assumes `f32::norm` from SimdOps is the Euclidean
            // (L2) norm — confirm against the oxirs_core documentation.
            let norm = f32::norm(&document_embedding);
            if norm > 0.0 {
                for val in &mut document_embedding {
                    *val /= norm;
                }
            }
        }

        Ok(Vector::new(document_embedding))
    }

    /// Embeds each item independently; short-circuits on the first error.
    fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
        contents.iter().map(|c| self.generate(c)).collect()
    }

    /// Output dimensionality of generated vectors.
    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    /// The generic embedding configuration supplied at construction.
    fn config(&self) -> &EmbeddingConfig {
        &self.embedding_config
    }
}
677
/// Enables downcasting a boxed/trait-object generator back to this concrete type.
impl crate::embeddings::AsAny for Word2VecEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
687
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end generate() with a hand-seeded two-word vocabulary.
    #[test]
    fn test_word2vec_generator() {
        let config = Word2VecConfig {
            dimensions: 100,
            ..Default::default()
        };

        let embedding_config = EmbeddingConfig {
            model_name: "word2vec-test".to_string(),
            dimensions: 100,
            max_sequence_length: 512,
            normalize: true,
        };

        let mut generator = Word2VecEmbeddingGenerator::new(config, embedding_config).unwrap();

        // Seed the vocabulary directly instead of loading a model file.
        generator
            .embeddings
            .insert("hello".to_string(), vec![0.1; 100]);
        generator
            .embeddings
            .insert("world".to_string(), vec![0.2; 100]);

        let content = EmbeddableContent::Text("hello world".to_string());
        let embedding = generator.generate(&content).unwrap();

        assert_eq!(embedding.dimensions, 100);
    }

    /// Character n-grams of "hello" must include all length-3 grams.
    #[test]
    fn test_subword_generation() {
        let config = Word2VecConfig::default();
        let generator =
            Word2VecEmbeddingGenerator::new(config, EmbeddingConfig::default()).unwrap();

        let subwords = generator.get_subwords("hello");
        assert!(subwords.contains(&"<hel>".to_string()));
        assert!(subwords.contains(&"<ell>".to_string()));
        assert!(subwords.contains(&"<llo>".to_string()));
    }

    /// Mean/Max/Min aggregation over two 3-d vectors with exact expected values.
    #[test]
    fn test_aggregation_methods() {
        let mut config = Word2VecConfig {
            dimensions: 3,
            // Normalization off so exact component values can be asserted.
            normalize: false,
            ..Default::default()
        };

        let embedding_config = EmbeddingConfig {
            model_name: "test".to_string(),
            dimensions: 3,
            max_sequence_length: 512,
            normalize: false,
        };

        for method in [
            AggregationMethod::Mean,
            AggregationMethod::Max,
            AggregationMethod::Min,
        ] {
            config.aggregation = method;
            let mut generator =
                Word2VecEmbeddingGenerator::new(config.clone(), embedding_config.clone()).unwrap();

            generator
                .embeddings
                .insert("a".to_string(), vec![1.0, 2.0, 3.0]);
            generator
                .embeddings
                .insert("b".to_string(), vec![4.0, 5.0, 6.0]);

            let content = EmbeddableContent::Text("a b".to_string());
            let embedding = generator.generate(&content).unwrap();

            match method {
                AggregationMethod::Mean => {
                    assert_eq!(embedding.as_f32(), vec![2.5, 3.5, 4.5]);
                }
                AggregationMethod::Max => {
                    assert_eq!(embedding.as_f32(), vec![4.0, 5.0, 6.0]);
                }
                AggregationMethod::Min => {
                    assert_eq!(embedding.as_f32(), vec![1.0, 2.0, 3.0]);
                }
                _ => {}
            }
        }
    }
}