use crate::embedding::SubwordEmbedding;
/// Strategy used by `OovHandler` to assign a log-probability to a word.
///
/// The default strategy is [`OovStrategy::SubwordEmbedding`].
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum OovStrategy {
    /// Score the word by the cosine similarity between its embedding vector
    /// and the vector of the surrounding context.
    #[default]
    SubwordEmbedding,
    /// Always return the given log-probability, ignoring word and context.
    FixedProbability {
        /// Log-probability returned verbatim for every word.
        log_prob: f64,
    },
    /// Return `-ln(vocab_size)`, i.e. a uniform distribution over the
    /// embedding's vocabulary.
    Uniform,
    /// Score the word through its `k` most similar words in the embedding,
    /// weighting each neighbour's context similarity by its similarity to
    /// the word.
    SimilarWords {
        /// Number of most-similar words requested from the embedding.
        k: usize,
    },
}
/// Estimates log-probabilities for words according to a configured
/// `OovStrategy`, using a borrowed `SubwordEmbedding` for vector lookups.
pub struct OovHandler<'a> {
/// Borrowed embedding queried for word/sentence vectors, nearest
/// neighbours and vocabulary membership.
embedding: &'a SubwordEmbedding,
/// Strategy chosen at construction time; selects the estimator used by
/// `estimate_log_prob`.
strategy: OovStrategy,
/// Value of `embedding.vocab_size()` captured in `new`; used for the
/// uniform `-ln(vocab_size)` estimate.
vocab_size: usize,
}
impl<'a> OovHandler<'a> {
    /// Lower bound applied to similarity-derived log-probability estimates.
    const MIN_LOG_PROB: f64 = -20.0;
    /// Upper bound applied to similarity-derived estimates; keeps them
    /// strictly below zero.
    const MAX_LOG_PROB: f64 = -1e-6;

    /// Creates a handler that scores words with `strategy`, reading vectors
    /// from `embedding`.
    ///
    /// The vocabulary size is read from the embedding once, here.
    pub fn new(embedding: &'a SubwordEmbedding, strategy: OovStrategy) -> Self {
        Self {
            embedding,
            strategy,
            vocab_size: embedding.vocab_size(),
        }
    }

    /// Returns the estimated log-probability of `word` given `context`,
    /// dispatching on the configured strategy.
    ///
    /// `context` is only consulted by the `SubwordEmbedding` and
    /// `SimilarWords` strategies.
    pub fn estimate_log_prob(&self, word: &str, context: &[&str]) -> f64 {
        match self.strategy {
            OovStrategy::SubwordEmbedding => self.estimate_from_subwords(word, context),
            OovStrategy::FixedProbability { log_prob } => log_prob,
            OovStrategy::Uniform => self.uniform_log_prob(),
            OovStrategy::SimilarWords { k } => self.estimate_from_similar(word, context, k),
        }
    }

    /// Log-probability of a uniform draw from the vocabulary: `-ln(vocab_size)`.
    ///
    /// For an empty vocabulary `ln(0)` is `-inf`, so the negation would be
    /// `+inf`, which is not a valid log-probability. In that case the floor
    /// `MIN_LOG_PROB` is returned instead.
    fn uniform_log_prob(&self) -> f64 {
        if self.vocab_size == 0 {
            Self::MIN_LOG_PROB
        } else {
            -(self.vocab_size as f64).ln()
        }
    }

    /// Maps a cosine similarity to a heuristic log-probability
    /// (`similarity - 1`), bounded to `[MIN_LOG_PROB, MAX_LOG_PROB]`.
    // `max`/`min` is used instead of `clamp` on purpose: `f64::max` returns
    // the non-NaN operand, so a NaN similarity becomes `MIN_LOG_PROB`,
    // whereas `clamp` would propagate the NaN to the caller.
    #[allow(clippy::manual_clamp)]
    fn similarity_to_log_prob(similarity: f64) -> f64 {
        (similarity - 1.0)
            .max(Self::MIN_LOG_PROB)
            .min(Self::MAX_LOG_PROB)
    }

    /// Estimates the log-probability from the cosine similarity between the
    /// vector of `word` and the sentence vector of `context`.
    ///
    /// Falls back to the uniform estimate when `context` is empty.
    fn estimate_from_subwords(&self, word: &str, context: &[&str]) -> f64 {
        // Check the context first so no vector is built on the fallback path.
        if context.is_empty() {
            return self.uniform_log_prob();
        }
        let word_vec = self.embedding.word_vector(word);
        let context_vec = self.embedding.sentence_vector(context);
        let similarity = Self::cosine_similarity(&word_vec, &context_vec);
        Self::similarity_to_log_prob(f64::from(similarity))
    }

    /// Estimates the log-probability from the `k` words most similar to
    /// `word`.
    ///
    /// Each neighbour with a positive similarity contributes its cosine
    /// similarity to the context, weighted by its similarity to `word`.
    /// Falls back to the uniform estimate when there are no neighbours or
    /// none has a positive similarity. With an empty `context` a zero vector
    /// is used, so every context similarity is `0.0`.
    fn estimate_from_similar(&self, word: &str, context: &[&str], k: usize) -> f64 {
        let similar = self.embedding.most_similar(word, k);
        if similar.is_empty() {
            return self.uniform_log_prob();
        }

        // The context vector does not depend on the neighbour, so build it
        // once instead of once per loop iteration.
        let context_vec = if context.is_empty() {
            ndarray::Array1::zeros(self.embedding.dim())
        } else {
            self.embedding.sentence_vector(context)
        };

        let mut weighted_sum = 0.0;
        let mut weight_total = 0.0;
        for (sim_word, similarity) in &similar {
            if *similarity > 0.0 {
                let weight = *similarity as f64;
                let sim_vec = self.embedding.word_vector(sim_word);
                let context_sim = Self::cosine_similarity(&sim_vec, &context_vec);
                weighted_sum += weight * f64::from(context_sim);
                weight_total += weight;
            }
        }

        if weight_total > 0.0 {
            Self::similarity_to_log_prob(weighted_sum / weight_total)
        } else {
            self.uniform_log_prob()
        }
    }

    /// Cosine similarity of `a` and `b`; `0.0` if either vector has zero norm.
    fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
        let dot = a.dot(b);
        let norm_a = a.dot(a).sqrt();
        let norm_b = b.dot(b).sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }

    /// Returns `true` when the embedding does not contain `word`.
    pub fn is_oov(&self, word: &str) -> bool {
        !self.embedding.contains(word)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corpus::PlaintextReader;
    use crate::embedding::EmbeddingTrainerBuilder;
    use std::io::Write;
    use tempfile::TempDir;

    /// Trains a small 10-dimensional embedding on a short repeated sentence
    /// that is first written to a file inside a temporary directory.
    fn create_test_embedding() -> SubwordEmbedding {
        let content = "the quick brown fox the quick brown dog the lazy fox \
                       the quick brown fox the quick brown dog the lazy fox";

        let dir = TempDir::new().expect("Failed to create temp dir");
        let path = dir.path().join("test.txt");
        let mut file = std::fs::File::create(&path).expect("Failed to create test file");
        file.write_all(content.as_bytes())
            .expect("Failed to write test file");

        let reader = PlaintextReader::from_file(&path).expect("Failed to create reader");
        EmbeddingTrainerBuilder::new()
            .dim(10)
            .window_size(2)
            .min_count(1)
            .epochs(2)
            .train(reader)
            .expect("Training failed")
    }

    /// The subword strategy yields a finite, negative estimate when a
    /// context is supplied.
    #[test]
    fn test_oov_subword_strategy() {
        let embedding = create_test_embedding();
        let estimate = OovHandler::new(&embedding, OovStrategy::SubwordEmbedding)
            .estimate_log_prob("unknown", &["the", "quick"]);

        assert!(estimate.is_finite());
        assert!(estimate < 0.0);
    }

    /// The fixed strategy returns the configured value unchanged.
    #[test]
    fn test_oov_fixed_probability() {
        let embedding = create_test_embedding();
        let strategy = OovStrategy::FixedProbability { log_prob: -10.0 };
        let estimate = OovHandler::new(&embedding, strategy).estimate_log_prob("unknown", &["the"]);

        assert_eq!(estimate, -10.0);
    }

    /// The uniform strategy yields a finite, negative estimate even without
    /// any context.
    #[test]
    fn test_oov_uniform() {
        let embedding = create_test_embedding();
        let estimate =
            OovHandler::new(&embedding, OovStrategy::Uniform).estimate_log_prob("unknown", &[]);

        assert!(estimate.is_finite());
        assert!(estimate < 0.0);
    }

    /// Words from the training text are in-vocabulary; an unseen token is not.
    #[test]
    fn test_is_oov() {
        let embedding = create_test_embedding();
        let oov = OovHandler::new(&embedding, OovStrategy::default());

        assert!(!oov.is_oov("the"));
        assert!(oov.is_oov("xyz123"));
    }
}