oxirs_vec/word2vec.rs

//! Word2Vec embedding integration for text content
//!
//! This module provides Word2Vec-based embeddings with support for:
//! - Pre-trained model loading
//! - Document embedding aggregation
//! - Subword handling
//! - Out-of-vocabulary management
//! - Hierarchical softmax support
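//!
//! # Example
//!
//! A minimal usage sketch: load a pre-trained text-format model and embed a
//! short document. The `model.txt` path is a placeholder, and the imports
//! assume this module is exported as `oxirs_vec::word2vec`.
//!
//! ```no_run
//! use oxirs_vec::embeddings::{EmbeddableContent, EmbeddingConfig, EmbeddingGenerator};
//! use oxirs_vec::word2vec::{Word2VecConfig, Word2VecEmbeddingGenerator, Word2VecFormat};
//!
//! # fn main() -> anyhow::Result<()> {
//! let config = Word2VecConfig {
//!     model_path: "model.txt".to_string(),
//!     format: Word2VecFormat::Text,
//!     dimensions: 300,
//!     ..Default::default()
//! };
//!
//! let generator = Word2VecEmbeddingGenerator::new(config, EmbeddingConfig::default())?;
//! let vector = generator.generate(&EmbeddableContent::Text("hello world".to_string()))?;
//! assert_eq!(vector.dimensions, 300);
//! # Ok(())
//! # }
//! ```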

use crate::{
    embeddings::{EmbeddableContent, EmbeddingConfig, EmbeddingGenerator},
    Vector,
};
use anyhow::{anyhow, Result};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;

/// Word2Vec model format
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Word2VecFormat {
    /// word2vec text format: a "<vocab_size> <dimensions>" header line, then
    /// one word and its space-separated vector per line
    Text,
    /// word2vec binary format: a text header line followed by packed
    /// little-endian f32 vectors
    Binary,
    /// GloVe text format: like the text format but without a header line
    GloVe,
}

/// Word2Vec configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Word2VecConfig {
    /// Path to pre-trained model file
    pub model_path: String,
    /// Model format
    pub format: Word2VecFormat,
    /// Embedding dimensions
    pub dimensions: usize,
    /// Aggregation method for document embeddings
    pub aggregation: AggregationMethod,
    /// Enable subword handling
    pub use_subwords: bool,
    /// Minimum subword length
    pub min_subword_len: usize,
    /// Maximum subword length
    pub max_subword_len: usize,
    /// Out-of-vocabulary strategy
    pub oov_strategy: OovStrategy,
    /// Whether to normalize embeddings
    pub normalize: bool,
}

impl Default for Word2VecConfig {
    fn default() -> Self {
        Self {
            model_path: String::new(),
            format: Word2VecFormat::Text,
            dimensions: 300,
            aggregation: AggregationMethod::Mean,
            use_subwords: true,
            min_subword_len: 3,
            max_subword_len: 6,
            oov_strategy: OovStrategy::Subword,
            normalize: true,
        }
    }
}

/// Aggregation method for combining word embeddings into document embeddings
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum AggregationMethod {
    /// Simple average of word embeddings
    Mean,
    /// Weighted average by term frequency
    WeightedMean,
    /// Max pooling across dimensions
    Max,
    /// Min pooling across dimensions
    Min,
    /// Concatenation of mean and max
    MeanMax,
    /// TF-IDF weighted average
    TfIdfWeighted,
}

/// Out-of-vocabulary word handling strategy
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum OovStrategy {
    /// Use zero vector
    Zero,
    /// Use random vector
    Random,
    /// Use subword embeddings
    Subword,
    /// Skip OOV words
    Skip,
    /// Use a learned OOV embedding
    LearnedOov,
}

/// Word2Vec embedding generator
pub struct Word2VecEmbeddingGenerator {
    config: Word2VecConfig,
    embedding_config: EmbeddingConfig,
    /// Word embeddings lookup table
    embeddings: HashMap<String, Vec<f32>>,
    /// Subword embeddings for OOV handling
    subword_embeddings: HashMap<String, Vec<f32>>,
    /// Document frequency for TF-IDF weighting
    doc_frequencies: HashMap<String, f32>,
    /// Learned OOV embedding
    oov_embedding: Option<Vec<f32>>,
}

impl Word2VecEmbeddingGenerator {
    /// Create a new Word2Vec embedding generator
    pub fn new(word2vec_config: Word2VecConfig, embedding_config: EmbeddingConfig) -> Result<Self> {
        let mut generator = Self {
            config: word2vec_config,
            embedding_config,
            embeddings: HashMap::new(),
            subword_embeddings: HashMap::new(),
            doc_frequencies: HashMap::new(),
            oov_embedding: None,
        };

        // Load pre-trained embeddings if path is provided
        let model_path = generator.config.model_path.clone();
        if !model_path.is_empty() {
            generator.load_model(&model_path)?;
        }

        Ok(generator)
    }

    /// Load pre-trained Word2Vec model
    pub fn load_model(&mut self, path: &str) -> Result<()> {
        let path = Path::new(path);

        if !path.exists() {
            return Err(anyhow!("Model file not found: {}", path.display()));
        }

        match self.config.format {
            Word2VecFormat::Text => self.load_text_format(path),
            Word2VecFormat::Binary => self.load_binary_format(path),
            Word2VecFormat::GloVe => self.load_glove_format(path),
        }
    }

    /// Load Word2Vec text format
    fn load_text_format(&mut self, path: &Path) -> Result<()> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let mut lines = reader.lines();

        // First line contains vocab size and dimensions
        if let Some(Ok(header)) = lines.next() {
            let parts: Vec<&str> = header.split_whitespace().collect();
            if parts.len() == 2 {
                let _vocab_size: usize = parts[0].parse()?;
                let dimensions: usize = parts[1].parse()?;

                if dimensions != self.config.dimensions {
                    return Err(anyhow!(
                        "Model dimensions ({}) don't match config ({})",
                        dimensions,
                        self.config.dimensions
                    ));
                }
            }
        }

        // Read embeddings
        for line in lines {
            let line = line?;
            let parts: Vec<&str> = line.split_whitespace().collect();

            if parts.len() < self.config.dimensions + 1 {
                continue;
            }

            let word = parts[0].to_string();
            let embedding: Result<Vec<f32>> = parts[1..=self.config.dimensions]
                .iter()
                .map(|s| s.parse::<f32>().map_err(Into::into))
                .collect();

            if let Ok(embedding) = embedding {
                self.embeddings.insert(word, embedding);
            }
        }

        // Generate subword embeddings if enabled
        if self.config.use_subwords {
            self.generate_subword_embeddings()?;
        }

        // Initialize OOV embedding if using learned strategy
        if self.config.oov_strategy == OovStrategy::LearnedOov {
            self.initialize_oov_embedding();
        }

        Ok(())
    }

    /// Load Word2Vec binary format
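    ///
    /// Expected layout (as parsed below): a text header line
    /// `"<vocab_size> <dimensions>\n"`, then one record per word consisting
    /// of the word bytes, a single space, `dimensions` little-endian f32
    /// values, and an optional trailing newline.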
    fn load_binary_format(&mut self, path: &Path) -> Result<()> {
        use std::io::Read;

        let mut file = File::open(path)?;
        let mut buffer = Vec::new();
        file.read_to_end(&mut buffer)?;

        // Read header
        let header_end = buffer
            .iter()
            .position(|&b| b == b'\n')
            .ok_or_else(|| anyhow!("Invalid binary format"))?;
        let header = std::str::from_utf8(&buffer[..header_end])?;
        let parts: Vec<&str> = header.split_whitespace().collect();

        if parts.len() != 2 {
            return Err(anyhow!("Invalid header format"));
        }

        let vocab_size: usize = parts[0].parse()?;
        let dimensions: usize = parts[1].parse()?;

        if dimensions != self.config.dimensions {
            return Err(anyhow!(
                "Model dimensions ({}) don't match config ({})",
                dimensions,
                self.config.dimensions
            ));
        }

        // Start parsing records just past the header newline
        let mut pos = header_end + 1;

        // Read embeddings
        for _ in 0..vocab_size {
            // Read word until space
            let word_start = pos;
            while pos < buffer.len() && buffer[pos] != b' ' {
                pos += 1;
            }

            if pos >= buffer.len() {
                break;
            }

            let word = std::str::from_utf8(&buffer[word_start..pos])?.to_string();
            pos += 1; // Skip space

            // Read embedding values
            let mut embedding = Vec::with_capacity(dimensions);
            for _ in 0..dimensions {
                if pos + 4 > buffer.len() {
                    break;
                }

                let bytes = [
                    buffer[pos],
                    buffer[pos + 1],
                    buffer[pos + 2],
                    buffer[pos + 3],
                ];
                let value = f32::from_le_bytes(bytes);
                embedding.push(value);
                pos += 4;
            }

            if embedding.len() == dimensions {
                self.embeddings.insert(word, embedding);
            }

            // Skip newline if present
            if pos < buffer.len() && buffer[pos] == b'\n' {
                pos += 1;
            }
        }

        // Generate subword embeddings if enabled
        if self.config.use_subwords {
            self.generate_subword_embeddings()?;
        }

        // Initialize OOV embedding if using learned strategy (kept consistent
        // with the text-format loader)
        if self.config.oov_strategy == OovStrategy::LearnedOov {
            self.initialize_oov_embedding();
        }

        Ok(())
    }

    /// Load GloVe format
    fn load_glove_format(&mut self, path: &Path) -> Result<()> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);

        for line in reader.lines() {
            let line = line?;
            let parts: Vec<&str> = line.split_whitespace().collect();

            if parts.len() < self.config.dimensions + 1 {
                continue;
            }

            let word = parts[0].to_string();
            let embedding: Result<Vec<f32>> = parts[1..=self.config.dimensions]
                .iter()
                .map(|s| s.parse::<f32>().map_err(Into::into))
                .collect();

            if let Ok(embedding) = embedding {
                self.embeddings.insert(word, embedding);
            }
        }

        // Generate subword embeddings if enabled
        if self.config.use_subwords {
            self.generate_subword_embeddings()?;
        }

        // Initialize OOV embedding if using learned strategy (kept consistent
        // with the text-format loader)
        if self.config.oov_strategy == OovStrategy::LearnedOov {
            self.initialize_oov_embedding();
        }

        Ok(())
    }

    /// Generate subword embeddings from word embeddings
    fn generate_subword_embeddings(&mut self) -> Result<()> {
        let mut subword_counts: HashMap<String, usize> = HashMap::new();
        let mut subword_sums: HashMap<String, Vec<f32>> = HashMap::new();

        // Collect subwords from vocabulary
        for (word, embedding) in &self.embeddings {
            let subwords = self.get_subwords(word);

            for subword in subwords {
                *subword_counts.entry(subword.clone()).or_insert(0) += 1;

                let sum = subword_sums
                    .entry(subword)
                    .or_insert_with(|| vec![0.0; self.config.dimensions]);
                for (i, val) in embedding.iter().enumerate() {
                    sum[i] += val;
                }
            }
        }

        // Average subword embeddings
        for (subword, count) in subword_counts {
            if let Some(sum) = subword_sums.get(&subword) {
                let avg: Vec<f32> = sum.iter().map(|&s| s / count as f32).collect();
                self.subword_embeddings.insert(subword, avg);
            }
        }

        Ok(())
    }

    /// Get subwords for a given word
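    ///
    /// Character n-grams are wrapped in angle brackets to mark them as
    /// subwords: with the default `min_subword_len = 3` and
    /// `max_subword_len = 6`, "hello" yields `<hel>`, `<ell>`, `<llo>`,
    /// `<hell>`, `<ello>`, and `<hello>`.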
    fn get_subwords(&self, word: &str) -> Vec<String> {
        let mut subwords = Vec::new();
        let chars: Vec<char> = word.chars().collect();

        for len in self.config.min_subword_len..=self.config.max_subword_len.min(chars.len()) {
            for start in 0..=chars.len().saturating_sub(len) {
                let subword: String = chars[start..start + len].iter().collect();
                subwords.push(format!("<{subword}>")); // Mark as subword
            }
        }

        subwords
    }

    /// Initialize learned OOV embedding
    fn initialize_oov_embedding(&mut self) {
        // Average all embeddings to create the OOV embedding; skip an empty
        // vocabulary to avoid dividing by zero
        if self.embeddings.is_empty() {
            return;
        }

        let mut sum = vec![0.0; self.config.dimensions];
        let count = self.embeddings.len() as f32;

        for embedding in self.embeddings.values() {
            for (i, val) in embedding.iter().enumerate() {
                sum[i] += val;
            }
        }

        self.oov_embedding = Some(sum.iter().map(|&s| s / count).collect());
    }

    /// Get embedding for a word
    fn get_word_embedding(&self, word: &str) -> Option<Vec<f32>> {
        // Try exact match first
        if let Some(embedding) = self.embeddings.get(word) {
            return Some(embedding.clone());
        }

        // Try lowercase
        if let Some(embedding) = self.embeddings.get(&word.to_lowercase()) {
            return Some(embedding.clone());
        }

        // Handle OOV
        match self.config.oov_strategy {
            OovStrategy::Zero => Some(vec![0.0; self.config.dimensions]),
            OovStrategy::Random => {
                // Generate deterministic "random" vector based on word hash
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                std::hash::Hash::hash(&word, &mut hasher);
                let hash = std::hash::Hasher::finish(&hasher);

                let mut rng = Random::seed(hash);

                Some(
                    (0..self.config.dimensions)
                        .map(|_| rng.gen_range(-0.1..0.1))
                        .collect(),
                )
            }
            OovStrategy::Subword => {
                if self.config.use_subwords {
                    self.get_subword_embedding(word)
                } else {
                    None
                }
            }
            OovStrategy::Skip => None,
            OovStrategy::LearnedOov => self.oov_embedding.clone(),
        }
    }

    /// Get subword-based embedding for OOV word
    fn get_subword_embedding(&self, word: &str) -> Option<Vec<f32>> {
        let subwords = self.get_subwords(word);
        let mut sum = vec![0.0; self.config.dimensions];
        let mut count = 0;

        for subword in subwords {
            if let Some(embedding) = self.subword_embeddings.get(&subword) {
                for (i, val) in embedding.iter().enumerate() {
                    sum[i] += val;
                }
                count += 1;
            }
        }

        if count > 0 {
            Some(sum.iter().map(|&s| s / count as f32).collect())
        } else {
            None
        }
    }

    /// Tokenize text into words
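    ///
    /// Lowercases, splits on whitespace, and trims leading/trailing
    /// non-alphanumeric characters, so "Hello, World!" becomes
    /// `["hello", "world"]`.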
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect()
    }

    /// Aggregate word embeddings into a document embedding
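    ///
    /// Returns a zero vector of the configured dimensionality when
    /// `word_embeddings` is empty.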
    fn aggregate_embeddings(&self, word_embeddings: &[(String, Vec<f32>)]) -> Vec<f32> {
        if word_embeddings.is_empty() {
            return vec![0.0; self.config.dimensions];
        }

        match self.config.aggregation {
            AggregationMethod::Mean => {
                let mut sum = vec![0.0; self.config.dimensions];

                for (_, embedding) in word_embeddings {
                    for (i, val) in embedding.iter().enumerate() {
                        sum[i] += val;
                    }
                }

                let count = word_embeddings.len() as f32;
                sum.iter().map(|&s| s / count).collect()
            }
            AggregationMethod::WeightedMean => {
                // Weight by word frequency in document
                let mut word_counts: HashMap<String, usize> = HashMap::new();
                for (word, _) in word_embeddings {
                    *word_counts.entry(word.clone()).or_insert(0) += 1;
                }

                let total_words = word_embeddings.len() as f32;
                let mut weighted_sum = vec![0.0; self.config.dimensions];

                for (word, embedding) in word_embeddings {
                    let weight = word_counts[word] as f32 / total_words;
                    for (i, val) in embedding.iter().enumerate() {
                        weighted_sum[i] += val * weight;
                    }
                }

                weighted_sum
            }
            AggregationMethod::Max => {
                let mut max_vals = vec![f32::NEG_INFINITY; self.config.dimensions];

                for (_, embedding) in word_embeddings {
                    for (i, val) in embedding.iter().enumerate() {
                        max_vals[i] = max_vals[i].max(*val);
                    }
                }

                max_vals
            }
            AggregationMethod::Min => {
                let mut min_vals = vec![f32::INFINITY; self.config.dimensions];

                for (_, embedding) in word_embeddings {
                    for (i, val) in embedding.iter().enumerate() {
                        min_vals[i] = min_vals[i].min(*val);
                    }
                }

                min_vals
            }
            AggregationMethod::MeanMax => {
                // Concatenate mean and max embeddings
                let mean =
                    self.aggregate_embeddings_with_method(word_embeddings, AggregationMethod::Mean);
                let max =
                    self.aggregate_embeddings_with_method(word_embeddings, AggregationMethod::Max);

                let mut result = Vec::with_capacity(self.config.dimensions * 2);
                result.extend(mean);
                result.extend(max);

                // Truncate back to the configured dimensionality so the output
                // matches `dimensions()`; note this keeps only the mean half of
                // the concatenation.
                result.resize(self.config.dimensions, 0.0);
                result
            }
            AggregationMethod::TfIdfWeighted => {
                // Weight by TF-IDF if document frequencies are available
                if self.doc_frequencies.is_empty() {
                    // Fall back to mean if no doc frequencies
                    return self.aggregate_embeddings_with_method(
                        word_embeddings,
                        AggregationMethod::Mean,
                    );
                }

                let mut weighted_sum = vec![0.0; self.config.dimensions];
                let mut total_weight = 0.0;

                for (word, embedding) in word_embeddings {
                    let tf = word_embeddings.iter().filter(|(w, _)| w == word).count() as f32
                        / word_embeddings.len() as f32;
                    let idf = self.doc_frequencies.get(word).unwrap_or(&1.0);
                    let weight = tf * idf;

                    for (i, val) in embedding.iter().enumerate() {
                        weighted_sum[i] += val * weight;
                    }
                    total_weight += weight;
                }

                if total_weight > 0.0 {
                    weighted_sum.iter().map(|&s| s / total_weight).collect()
                } else {
                    weighted_sum
                }
            }
        }
    }

    /// Helper method for recursive aggregation
    ///
    /// Note: this re-runs aggregation with a different method by building a
    /// temporary generator, which clones the lookup tables; simple, but not
    /// cheap for large vocabularies.
    fn aggregate_embeddings_with_method(
        &self,
        word_embeddings: &[(String, Vec<f32>)],
        method: AggregationMethod,
    ) -> Vec<f32> {
        let mut config_clone = self.config.clone();
        config_clone.aggregation = method;

        let temp_self = Self {
            config: config_clone,
            embedding_config: self.embedding_config.clone(),
            embeddings: self.embeddings.clone(),
            subword_embeddings: self.subword_embeddings.clone(),
            doc_frequencies: self.doc_frequencies.clone(),
            oov_embedding: self.oov_embedding.clone(),
        };

        temp_self.aggregate_embeddings(word_embeddings)
    }

    /// Set document frequencies for TF-IDF weighting
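    ///
    /// The supplied values are used directly as IDF weights by
    /// `AggregationMethod::TfIdfWeighted`.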
    pub fn set_document_frequencies(&mut self, frequencies: HashMap<String, f32>) {
        self.doc_frequencies = frequencies;
    }

    /// Calculate document frequencies from a corpus
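    ///
    /// Stores a smoothed IDF score per word:
    /// `idf(w) = ln(total_docs / (doc_count(w) + 1))`.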
    pub fn calculate_document_frequencies(&mut self, documents: &[String]) -> Result<()> {
        let total_docs = documents.len() as f32;
        let mut doc_counts: HashMap<String, usize> = HashMap::new();

        for doc in documents {
            let words = self.tokenize(doc);
            let unique_words: std::collections::HashSet<_> = words.into_iter().collect();

            for word in unique_words {
                *doc_counts.entry(word).or_insert(0) += 1;
            }
        }

        // Calculate IDF scores
        self.doc_frequencies = doc_counts
            .into_iter()
            .map(|(word, count)| {
                let idf = (total_docs / (count as f32 + 1.0)).ln();
                (word, idf)
            })
            .collect();

        Ok(())
    }
}

impl EmbeddingGenerator for Word2VecEmbeddingGenerator {
    fn generate(&self, content: &EmbeddableContent) -> Result<Vector> {
        let text = content.to_text();
        let words = self.tokenize(&text);

        // Get embeddings for each word
        let mut word_embeddings = Vec::new();

        for word in words {
            if let Some(embedding) = self.get_word_embedding(&word) {
                word_embeddings.push((word, embedding));
            }
        }

        if word_embeddings.is_empty() {
            return Ok(Vector::new(vec![0.0; self.config.dimensions]));
        }

        // Aggregate embeddings
        let mut document_embedding = self.aggregate_embeddings(&word_embeddings);

        // Normalize if configured
        if self.config.normalize {
            use oxirs_core::simd::SimdOps;
            let norm = f32::norm(&document_embedding);
            if norm > 0.0 {
                for val in &mut document_embedding {
                    *val /= norm;
                }
            }
        }

        Ok(Vector::new(document_embedding))
    }

    fn generate_batch(&self, contents: &[EmbeddableContent]) -> Result<Vec<Vector>> {
        // For Word2Vec, batch processing doesn't provide significant benefits
        // since we're just looking up pre-computed embeddings
        contents.iter().map(|c| self.generate(c)).collect()
    }

    fn dimensions(&self) -> usize {
        self.config.dimensions
    }

    fn config(&self) -> &EmbeddingConfig {
        &self.embedding_config
    }
}

impl crate::embeddings::AsAny for Word2VecEmbeddingGenerator {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_word2vec_generator() {
        let config = Word2VecConfig {
            dimensions: 100,
            ..Default::default()
        };

        let embedding_config = EmbeddingConfig {
            model_name: "word2vec-test".to_string(),
            dimensions: 100,
            max_sequence_length: 512,
            normalize: true,
        };

        let mut generator = Word2VecEmbeddingGenerator::new(config, embedding_config).unwrap();

        // Add some test embeddings
        generator
            .embeddings
            .insert("hello".to_string(), vec![0.1; 100]);
        generator
            .embeddings
            .insert("world".to_string(), vec![0.2; 100]);

        // Test embedding generation
        let content = EmbeddableContent::Text("hello world".to_string());
        let embedding = generator.generate(&content).unwrap();

        assert_eq!(embedding.dimensions, 100);
    }

    #[test]
    fn test_subword_generation() {
        let config = Word2VecConfig::default();
        let generator =
            Word2VecEmbeddingGenerator::new(config, EmbeddingConfig::default()).unwrap();

        let subwords = generator.get_subwords("hello");
        assert!(subwords.contains(&"<hel>".to_string()));
        assert!(subwords.contains(&"<ell>".to_string()));
        assert!(subwords.contains(&"<llo>".to_string()));
    }

    #[test]
    fn test_aggregation_methods() {
        let mut config = Word2VecConfig {
            dimensions: 3,
            normalize: false,
            ..Default::default()
        };

        let embedding_config = EmbeddingConfig {
            model_name: "test".to_string(),
            dimensions: 3,
            max_sequence_length: 512,
            normalize: false,
        };

        // Test different aggregation methods
        for method in [
            AggregationMethod::Mean,
            AggregationMethod::Max,
            AggregationMethod::Min,
        ] {
            config.aggregation = method;
            let mut generator =
                Word2VecEmbeddingGenerator::new(config.clone(), embedding_config.clone()).unwrap();

            generator
                .embeddings
                .insert("a".to_string(), vec![1.0, 2.0, 3.0]);
            generator
                .embeddings
                .insert("b".to_string(), vec![4.0, 5.0, 6.0]);

            let content = EmbeddableContent::Text("a b".to_string());
            let embedding = generator.generate(&content).unwrap();

            match method {
                AggregationMethod::Mean => {
                    assert_eq!(embedding.as_f32(), vec![2.5, 3.5, 4.5]);
                }
                AggregationMethod::Max => {
                    assert_eq!(embedding.as_f32(), vec![4.0, 5.0, 6.0]);
                }
                AggregationMethod::Min => {
                    assert_eq!(embedding.as_f32(), vec![1.0, 2.0, 3.0]);
                }
                _ => {}
            }
        }
    }
}