use crate::embedding::SubwordEmbedding;
use crate::ngram::{NgramEntry, NgramModel};
use dashmap::DashMap;
use liblevenshtein::dictionary::MutableMappedDictionary;
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(feature = "serde-extras")]
use std::path::Path;
/// How the n-gram score and the embedding score are combined into one
/// log-probability by `HybridLanguageModel::score`.
///
/// NOTE(review): this type is serialized through serde; the variant and field
/// order presumably forms part of the on-disk format — confirm before
/// reordering anything here.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum InterpolationStrategy {
/// Mixes the two probabilities linearly:
/// `alpha * p_ngram + (1 - alpha) * p_embedding`.
Linear {
/// Weight applied to the n-gram probability.
alpha: f64,
},
/// Mixes the two log-probabilities:
/// `alpha * log_p_ngram + (1 - alpha) * log_p_embedding`.
LogLinear {
/// Weight applied to the n-gram log-probability.
alpha: f64,
},
/// Uses the n-gram score when the n-gram model has a non-zero count for the
/// word on its own, and the embedding score otherwise.
NgramWithEmbeddingFallback,
/// Linear mixing whose weight grows with the number of context words:
/// `alpha = min(base_alpha + alpha_per_context * context.len(), max_alpha)`.
Dynamic {
/// Weight used when the context is empty.
base_alpha: f64,
/// Amount added to the weight for every context word.
alpha_per_context: f64,
/// Upper bound on the resulting weight.
max_alpha: f64,
},
}
impl Default for InterpolationStrategy {
fn default() -> Self {
Self::Linear { alpha: 0.8 }
}
}
/// Tunable settings for a `HybridLanguageModel`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HybridConfig {
/// How the n-gram and embedding scores are combined.
pub strategy: InterpolationStrategy,
/// Maximum number of scores kept in the model's score cache
/// (the constructors raise a value of 0 to 1).
pub cache_size: usize,
/// Smoothing value for the embedding score: its natural logarithm is the
/// lowest embedding log-probability returned for a non-empty context.
pub embedding_smoothing: f64,
/// Divisor applied to the cosine similarity before it is turned into a
/// log-probability.
pub temperature: f64,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
strategy: InterpolationStrategy::default(),
cache_size: 50_000,
embedding_smoothing: 1e-8,
temperature: 1.0,
}
}
}
/// Bounded cache of previously computed scores, keyed by a 64-bit hash of
/// the `(word, context)` pair.
struct ScoreCache {
/// Cached score for each key hash.
entries: DashMap<u64, f64>,
/// Key hashes in the order they were first inserted; the front of the queue
/// is the next entry to be evicted. Lookups do not reorder it.
access_order: Mutex<VecDeque<u64>>,
/// Entry count above which an insert evicts the oldest entry.
max_entries: usize,
/// Running count of cached entries, maintained alongside `entries`.
num_entries: AtomicUsize,
}
impl ScoreCache {
    /// Creates an empty cache that evicts once it holds more than
    /// `max_entries` scores.
    ///
    /// The initial allocation is capped at 10 000 slots so that a large limit
    /// does not reserve all of its memory up front.
    fn new(max_entries: usize) -> Self {
        let initial_capacity = max_entries.min(10_000);
        Self {
            entries: DashMap::with_capacity(initial_capacity),
            access_order: Mutex::new(VecDeque::with_capacity(initial_capacity)),
            max_entries,
            num_entries: AtomicUsize::new(0),
        }
    }

    /// Hashes a `(word, context)` pair into the 64-bit cache key.
    ///
    /// The context length is hashed before the context words. Only the hash
    /// is stored, so two different pairs that collide share one cache slot.
    fn compute_hash(word: &str, context: &[&str]) -> u64 {
        let mut hasher = DefaultHasher::new();
        word.hash(&mut hasher);
        context.len().hash(&mut hasher);
        for ctx in context {
            ctx.hash(&mut hasher);
        }
        hasher.finish()
    }

    /// Returns the cached score for `(word, context)`, if any.
    ///
    /// A hit does not change the eviction order.
    fn get(&self, word: &str, context: &[&str]) -> Option<f64> {
        let hash = Self::compute_hash(word, context);
        self.entries.get(&hash).map(|entry| *entry)
    }

    /// Stores `score` for `(word, context)`, evicting the oldest entry when
    /// the cache grows past `max_entries`.
    ///
    /// Overwriting an existing key updates the score in place and leaves the
    /// entry count and eviction order untouched.
    fn insert(&self, word: &str, context: &[&str], score: f64) {
        let hash = Self::compute_hash(word, context);

        // `DashMap::insert` hands back the previous value, so one call both
        // stores the score and tells us whether the key was new. A separate
        // `contains_key` check would let two threads inserting the same new
        // key both count it, leaving `num_entries` permanently too high.
        if self.entries.insert(hash, score).is_some() {
            return;
        }

        let count = self.num_entries.fetch_add(1, Ordering::Relaxed) + 1;
        self.access_order.lock().push_back(hash);

        if count > self.max_entries {
            self.evict_oldest();
        }
    }

    /// Removes the entry at the front of the insertion queue, if there is one.
    fn evict_oldest(&self) {
        // The queue lock is released before the map is touched.
        let hash_to_remove = self.access_order.lock().pop_front();

        if let Some(hash) = hash_to_remove {
            // Only decrement when an entry was actually removed.
            if self.entries.remove(&hash).is_some() {
                self.num_entries.fetch_sub(1, Ordering::Relaxed);
            }
        }
    }

    /// Drops every cached score and resets the entry count to zero.
    fn clear(&self) {
        self.entries.clear();
        self.access_order.lock().clear();
        self.num_entries.store(0, Ordering::Relaxed);
    }
}
impl Default for ScoreCache {
fn default() -> Self {
Self::new(HybridConfig::default().cache_size)
}
}
fn default_cache() -> ScoreCache {
ScoreCache::default()
}
/// Language model that scores words by combining an n-gram model with a
/// subword embedding model, according to the strategy in its `HybridConfig`.
///
/// `D` is the dictionary type that backs the n-gram model.
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(bound = "D: serde::Serialize + serde::de::DeserializeOwned")]
pub struct HybridLanguageModel<D>
where
D: MutableMappedDictionary<Value = NgramEntry> + Send + Sync,
{
/// N-gram component.
ngram: NgramModel<D>,
/// Embedding component.
embedding: SubwordEmbedding,
/// Interpolation and cache settings.
config: HybridConfig,
/// Memoized scores. Not serialized; on deserialization serde fills it with
/// `default_cache()`.
#[serde(skip, default = "default_cache")]
cache: ScoreCache,
}
impl<D> HybridLanguageModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry> + Send + Sync,
{
    /// Floor applied to each component log-probability before it is combined
    /// or returned.
    const MIN_LOG_PROB: f64 = -50.0;

    /// Builds a hybrid model from its two components and a configuration.
    ///
    /// The score cache is sized from `config.cache_size`, with a minimum of 1.
    pub fn new(ngram: NgramModel<D>, embedding: SubwordEmbedding, config: HybridConfig) -> Self {
        let cache = ScoreCache::new(config.cache_size.max(1));
        Self {
            ngram,
            embedding,
            config,
            cache,
        }
    }

    /// Builds a hybrid model using `HybridConfig::default()`.
    pub fn with_defaults(ngram: NgramModel<D>, embedding: SubwordEmbedding) -> Self {
        Self::new(ngram, embedding, HybridConfig::default())
    }

    /// Returns the n-gram component.
    pub fn ngram_model(&self) -> &NgramModel<D> {
        &self.ngram
    }

    /// Returns the embedding component.
    pub fn embedding_model(&self) -> &SubwordEmbedding {
        &self.embedding
    }

    /// Returns the configuration this model was built with.
    pub fn config(&self) -> &HybridConfig {
        &self.config
    }

    /// Scores `word` given the preceding `context` words.
    ///
    /// A cached score is returned when one exists; otherwise the score is
    /// computed according to the configured strategy and stored in the cache.
    pub fn score(&self, word: &str, context: &[&str]) -> f64 {
        if let Some(cached_score) = self.cache.get(word, context) {
            return cached_score;
        }

        let score = match self.config.strategy {
            InterpolationStrategy::Linear { alpha } => self.score_linear(word, context, alpha),
            InterpolationStrategy::LogLinear { alpha } => {
                self.score_log_linear(word, context, alpha)
            }
            InterpolationStrategy::NgramWithEmbeddingFallback => {
                self.score_with_fallback(word, context)
            }
            InterpolationStrategy::Dynamic {
                base_alpha,
                alpha_per_context,
                max_alpha,
            } => {
                // The n-gram weight grows with the context length, up to `max_alpha`.
                let alpha = (base_alpha + alpha_per_context * context.len() as f64).min(max_alpha);
                self.score_linear(word, context, alpha)
            }
        };

        self.cache.insert(word, context, score);
        score
    }

    /// Linear interpolation in probability space, returned as a log-probability.
    fn score_linear(&self, word: &str, context: &[&str], alpha: f64) -> f64 {
        let ngram_prob = self
            .ngram
            .log_prob(word, context)
            .max(Self::MIN_LOG_PROB)
            .exp();
        let embedding_prob = self
            .embedding_log_prob(word, context)
            .max(Self::MIN_LOG_PROB)
            .exp();

        let combined_prob = alpha * ngram_prob + (1.0 - alpha) * embedding_prob;
        // Keep the logarithm finite when the mixture is zero or negative.
        combined_prob.max(f64::MIN_POSITIVE).ln()
    }

    /// Weighted sum of the two floored log-probabilities.
    fn score_log_linear(&self, word: &str, context: &[&str], alpha: f64) -> f64 {
        let ngram_log_prob = self.ngram.log_prob(word, context).max(Self::MIN_LOG_PROB);
        let embedding_log_prob = self
            .embedding_log_prob(word, context)
            .max(Self::MIN_LOG_PROB);

        alpha * ngram_log_prob + (1.0 - alpha) * embedding_log_prob
    }

    /// Uses the n-gram score when the n-gram model has a non-zero count for
    /// `word` on its own, and the embedding score otherwise.
    fn score_with_fallback(&self, word: &str, context: &[&str]) -> f64 {
        let raw = if self.ngram.count(&[word]) > 0 {
            self.ngram.log_prob(word, context)
        } else {
            self.embedding_log_prob(word, context)
        };
        raw.max(Self::MIN_LOG_PROB)
    }

    /// Embedding-based score for `word` given `context`.
    ///
    /// With an empty context this is the uniform value `-ln(vocab_size)`.
    /// Otherwise it is `cosine(word, context) / temperature - 1`, floored at
    /// `ln(embedding_smoothing)`.
    fn embedding_log_prob(&self, word: &str, context: &[&str]) -> f64 {
        if context.is_empty() {
            // Treat an empty vocabulary as size 1 so the result is 0 rather
            // than `-ln(0) = +inf`.
            let vocab_size = (self.embedding.vocab_size() as f64).max(1.0);
            return -vocab_size.ln();
        }

        let word_vec = self.embedding.word_vector(word);
        let context_vec = self.embedding.sentence_vector(context);
        let similarity = Self::cosine_similarity(&word_vec, &context_vec);

        // NOTE(review): a temperature of 0 makes this division non-finite;
        // the config is presumably expected to hold a positive value.
        let scaled_sim = f64::from(similarity) / self.config.temperature;
        let log_prob = scaled_sim - 1.0;
        log_prob.max(self.config.embedding_smoothing.ln())
    }

    /// Cosine similarity of two vectors; `0.0` when either has zero norm.
    fn cosine_similarity(a: &ndarray::Array1<f32>, b: &ndarray::Array1<f32>) -> f32 {
        let dot = a.dot(b);
        let norm_a = a.dot(a).sqrt();
        let norm_b = b.dot(b).sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }

    /// Sums the score of every word in `words`, each conditioned on up to
    /// `order - 1` preceding words, where `order` is the n-gram model's order.
    ///
    /// Returns `0.0` for an empty slice.
    pub fn sentence_log_prob(&self, words: &[&str]) -> f64 {
        if words.is_empty() {
            return 0.0;
        }

        // Saturating so that an order of 0 means "no context" rather than an
        // unsigned underflow.
        let max_context = self.ngram.order().saturating_sub(1);

        let mut total_log_prob = 0.0;
        for (i, word) in words.iter().enumerate() {
            let context_start = i.saturating_sub(max_context);
            // The context is a window of the input; no copy is needed.
            total_log_prob += self.score(word, &words[context_start..i]);
        }
        total_log_prob
    }

    /// Perplexity of `words`: `exp(-sentence_log_prob / len)`.
    ///
    /// Returns `f64::INFINITY` for an empty slice.
    pub fn perplexity(&self, words: &[&str]) -> f64 {
        if words.is_empty() {
            return f64::INFINITY;
        }

        let log_prob = self.sentence_log_prob(words);
        let avg_log_prob = log_prob / words.len() as f64;
        (-avg_log_prob).exp()
    }

    /// Returns the highest-scoring candidate after `context`, with its score.
    ///
    /// Returns `None` only when `candidates` is empty. On ties the earliest
    /// candidate wins. The returned word is always one of `candidates`, even
    /// when every score is `-inf` or NaN.
    pub fn predict_next(&self, context: &[&str], candidates: &[&str]) -> Option<(String, f64)> {
        let mut best: Option<(&str, f64)> = None;

        for &candidate in candidates {
            let score = self.score(candidate, context);
            let improves = match best {
                None => true,
                // A real score also replaces a NaN leader, which would
                // otherwise win every comparison by default.
                Some((_, best_score)) => {
                    score > best_score || (best_score.is_nan() && !score.is_nan())
                }
            };
            if improves {
                best = Some((candidate, score));
            }
        }

        best.map(|(word, score)| (word.to_string(), score))
    }

    /// Empties the score cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}
// SAFETY — NOTE(review): this asserts by hand that the model may be moved to
// another thread. It is only sound if `NgramModel<D>`, `SubwordEmbedding` and
// the score cache are themselves safe to send; nothing in this file proves
// that. If every field is already `Send`, the compiler derives this impl
// automatically and the `unsafe` is unnecessary — confirm and consider
// removing it.
unsafe impl<D> Send for HybridLanguageModel<D> where
D: MutableMappedDictionary<Value = NgramEntry> + Send + Sync
{
}
// SAFETY — NOTE(review): this asserts by hand that `&HybridLanguageModel<D>`
// may be shared between threads. It is only sound if `NgramModel<D>`,
// `SubwordEmbedding` and the score cache tolerate concurrent shared access;
// nothing in this file proves that. If every field is already `Sync`, the
// compiler derives this impl automatically and the `unsafe` is unnecessary —
// confirm and consider removing it.
unsafe impl<D> Sync for HybridLanguageModel<D> where
D: MutableMappedDictionary<Value = NgramEntry> + Send + Sync
{
}
#[cfg(feature = "serde-extras")]
impl<D> HybridLanguageModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry>
        + Send
        + Sync
        + serde::Serialize
        + serde::de::DeserializeOwned,
{
    /// Serializes the model with bincode and writes it to `path`.
    ///
    /// The score cache is not written.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, serialization fails,
    /// or the buffered output cannot be flushed.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> crate::Result<()> {
        let file = std::fs::File::create(path)?;
        let mut writer = std::io::BufWriter::new(file);
        bincode::serialize_into(&mut writer, self)?;
        // Flush explicitly: a flush performed by `Drop` discards any write
        // error, which would report success for a truncated file.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Reads a model previously written by [`save`](Self::save).
    ///
    /// The score cache starts empty and is sized from the loaded
    /// configuration, matching what [`new`](Self::new) does.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or deserialization fails.
    pub fn load<P: AsRef<Path>>(path: P) -> crate::Result<Self> {
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        let mut model: Self = bincode::deserialize_from(reader)?;
        // Deserialization installs a default-sized cache; replace it so the
        // loaded `cache_size` is honored.
        model.cache = ScoreCache::new(model.config.cache_size.max(1));
        Ok(model)
    }
}
/// Serializable snapshot of a hybrid model that carries a
/// `PortableNgramModel` in place of the dictionary-backed n-gram model.
///
/// Produced by `HybridLanguageModel::to_portable` and consumed by
/// `HybridLanguageModel::load_portable`. The score cache is not part of it.
#[cfg(feature = "serde-extras")]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PortableHybridModel {
/// Portable form of the n-gram component.
pub ngram: crate::ngram::PortableNgramModel,
/// Embedding component.
pub embedding: SubwordEmbedding,
/// Configuration of the model the snapshot was taken from.
pub config: HybridConfig,
}
#[cfg(feature = "serde-extras")]
impl<D> HybridLanguageModel<D>
where
    D: MutableMappedDictionary<Value = NgramEntry> + Send + Sync,
{
    /// Builds a [`PortableHybridModel`] from this model.
    ///
    /// The n-gram component is converted with its own `to_portable`; the
    /// embedding and configuration are cloned. The score cache is left out.
    pub fn to_portable(&self) -> PortableHybridModel
    where
        D: crate::ngram::IterableDictionary,
    {
        PortableHybridModel {
            ngram: self.ngram.to_portable(),
            embedding: self.embedding.clone(),
            config: self.config.clone(),
        }
    }

    /// Writes the portable form of this model to `path` using bincode.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, serialization fails,
    /// or the buffered output cannot be flushed.
    pub fn save_portable<P: AsRef<Path>>(&self, path: P) -> crate::Result<()>
    where
        D: crate::ngram::IterableDictionary,
    {
        let portable = self.to_portable();
        let file = std::fs::File::create(path)?;
        let mut writer = std::io::BufWriter::new(file);
        bincode::serialize_into(&mut writer, &portable)?;
        // Flush explicitly: a flush performed by `Drop` discards any write
        // error, which would report success for a truncated file.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Reads a portable model from `path` and rebuilds a full model from it.
    ///
    /// `dictionary_factory` is handed to the n-gram loader to supply the
    /// dictionary. The score cache starts empty and is sized from the loaded
    /// configuration, with a minimum of 1.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened, deserialization fails,
    /// or the n-gram model cannot be rebuilt.
    pub fn load_portable<P, F>(path: P, dictionary_factory: F) -> crate::Result<Self>
    where
        P: AsRef<Path>,
        F: FnOnce() -> D,
    {
        let file = std::fs::File::open(path)?;
        let reader = std::io::BufReader::new(file);
        let portable: PortableHybridModel = bincode::deserialize_from(reader)?;
        let PortableHybridModel {
            ngram: portable_ngram,
            embedding,
            config,
        } = portable;

        let ngram = crate::ngram::NgramModel::load_portable_from_portable(
            portable_ngram,
            dictionary_factory,
        )?;
        let cache = ScoreCache::new(config.cache_size.max(1));

        Ok(Self {
            ngram,
            embedding,
            config,
            cache,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::corpus::PlaintextReader;
    use crate::embedding::EmbeddingTrainerBuilder;
    use crate::ngram::TrainerBuilder;
    use liblevenshtein::dictionary::pathmap::PathMapDictionary;
    use std::io::Write;
    use tempfile::TempDir;

    /// Small repetitive corpus used to train every fixture model.
    const TRAINING_TEXT: &str = "the quick brown fox the quick brown dog the lazy fox \
         the quick brown fox the quick brown dog the lazy fox \
         the quick brown fox the quick brown dog the lazy fox";

    /// Writes `content` to `test.txt` inside `dir` and returns the file path.
    fn create_test_corpus(dir: &std::path::Path, content: &str) -> std::path::PathBuf {
        let corpus_path = dir.join("test.txt");
        std::fs::File::create(&corpus_path)
            .expect("Failed to create test file")
            .write_all(content.as_bytes())
            .expect("Failed to write test file");
        corpus_path
    }

    /// Trains an order-3 n-gram model and a 10-dimensional embedding model on
    /// the shared corpus.
    fn create_test_models() -> (NgramModel<PathMapDictionary<NgramEntry>>, SubwordEmbedding) {
        let workspace = TempDir::new().expect("Failed to create temp dir");
        let corpus = create_test_corpus(workspace.path(), TRAINING_TEXT);

        let ngram_model = TrainerBuilder::new(PathMapDictionary::<NgramEntry>::new())
            .order(3)
            .train(PlaintextReader::from_file(&corpus).expect("Failed to create reader"))
            .expect("N-gram training failed");

        let embedding_model = EmbeddingTrainerBuilder::new()
            .dim(10)
            .window_size(2)
            .min_count(1)
            .epochs(2)
            .train(PlaintextReader::from_file(&corpus).expect("Failed to create reader"))
            .expect("Embedding training failed");

        (ngram_model, embedding_model)
    }

    /// Builds a hybrid model with the default configuration from freshly
    /// trained fixture models.
    fn default_hybrid() -> HybridLanguageModel<PathMapDictionary<NgramEntry>> {
        let (ngram, embedding) = create_test_models();
        HybridLanguageModel::with_defaults(ngram, embedding)
    }

    #[test]
    fn test_hybrid_creation() {
        let (ngram, embedding) = create_test_models();
        let _hybrid = HybridLanguageModel::new(ngram, embedding, HybridConfig::default());
    }

    #[test]
    fn test_hybrid_score() {
        let model = default_hybrid();
        let value = model.score("fox", &["the", "quick"]);
        assert!(value.is_finite());
        assert!(value < 0.0);
    }

    #[test]
    fn test_sentence_log_prob() {
        let model = default_hybrid();
        let total = model.sentence_log_prob(&["the", "quick", "brown", "fox"]);
        assert!(total.is_finite());
        assert!(total < 0.0);
    }

    #[test]
    fn test_perplexity() {
        let model = default_hybrid();
        let ppl = model.perplexity(&["the", "quick", "brown", "fox"]);
        assert!(ppl.is_finite());
        assert!(ppl > 0.0);
    }

    #[test]
    fn test_interpolation_strategies() {
        let (ngram, embedding) = create_test_models();
        let strategies = [
            InterpolationStrategy::Linear { alpha: 0.5 },
            InterpolationStrategy::LogLinear { alpha: 0.5 },
            InterpolationStrategy::NgramWithEmbeddingFallback,
        ];

        // Each strategy gets its own model so that no cached score is shared.
        for strategy in strategies.iter().copied() {
            let config = HybridConfig {
                strategy,
                ..Default::default()
            };
            let model = HybridLanguageModel::new(ngram.clone(), embedding.clone(), config);
            assert!(model.score("fox", &["the"]).is_finite());
        }
    }

    #[test]
    fn test_cache() {
        let model = default_hybrid();
        let first = model.score("fox", &["the", "quick"]);
        let second = model.score("fox", &["the", "quick"]);
        assert_eq!(first, second);
        model.clear_cache();
    }

    #[test]
    fn test_predict_next() {
        let model = default_hybrid();
        let candidates = ["fox", "dog", "cat"];
        let prediction = model.predict_next(&["the", "quick", "brown"], &candidates);
        assert!(prediction.is_some());
        let (word, score) = prediction.unwrap();
        assert!(candidates.contains(&word.as_str()));
        assert!(score.is_finite());
    }

    #[cfg(feature = "serde-extras")]
    mod serde_tests {
        use super::*;
        use liblevenshtein::dictionary::dynamic_dawg_char::DynamicDawgChar;

        /// Same fixtures as `create_test_models`, but with a dictionary type
        /// that the roundtrip test can serialize.
        fn create_serializable_test_models(
        ) -> (NgramModel<DynamicDawgChar<NgramEntry>>, SubwordEmbedding) {
            let workspace = TempDir::new().expect("Failed to create temp dir");
            let corpus = create_test_corpus(workspace.path(), TRAINING_TEXT);

            let ngram_model = TrainerBuilder::new(DynamicDawgChar::<NgramEntry>::new())
                .order(3)
                .train(PlaintextReader::from_file(&corpus).expect("Failed to create reader"))
                .expect("N-gram training failed");

            let embedding_model = EmbeddingTrainerBuilder::new()
                .dim(10)
                .window_size(2)
                .min_count(1)
                .epochs(2)
                .train(PlaintextReader::from_file(&corpus).expect("Failed to create reader"))
                .expect("Embedding training failed");

            (ngram_model, embedding_model)
        }

        #[test]
        fn test_hybrid_save_load_roundtrip() {
            let (ngram, embedding) = create_serializable_test_models();
            let hybrid = HybridLanguageModel::new(
                ngram,
                embedding,
                HybridConfig {
                    strategy: InterpolationStrategy::Linear { alpha: 0.7 },
                    cache_size: 1000,
                    ..Default::default()
                },
            );

            let temp_file = tempfile::NamedTempFile::new().expect("Failed to create temp file");
            let model_path = temp_file.path();
            hybrid
                .save(model_path)
                .expect("Failed to save hybrid model");

            let saved_len = std::fs::metadata(model_path)
                .expect("Failed to get file metadata")
                .len();
            assert!(saved_len > 0, "Saved model file should not be empty");

            let loaded: HybridLanguageModel<DynamicDawgChar<NgramEntry>> =
                HybridLanguageModel::load(model_path).expect("Failed to load hybrid model");

            // The configuration must survive the roundtrip.
            assert_eq!(hybrid.config().cache_size, loaded.config().cache_size);
            if let (
                InterpolationStrategy::Linear { alpha: a1 },
                InterpolationStrategy::Linear { alpha: a2 },
            ) = (hybrid.config().strategy, loaded.config().strategy)
            {
                assert!((a1 - a2).abs() < 1e-10, "Alpha should match");
            } else {
                panic!("Strategy should match");
            }

            // Both models must score identically.
            let probe_context = ["the", "quick"];
            let sentence = ["the", "quick", "brown", "fox"];

            let orig_score = hybrid.score("fox", &probe_context);
            let loaded_score = loaded.score("fox", &probe_context);
            assert!(
                (orig_score - loaded_score).abs() < 1e-10,
                "Scores should match after roundtrip: {} vs {}",
                orig_score,
                loaded_score
            );

            let orig_sentence = hybrid.sentence_log_prob(&sentence);
            let loaded_sentence = loaded.sentence_log_prob(&sentence);
            assert!(
                (orig_sentence - loaded_sentence).abs() < 1e-10,
                "Sentence scores should match: {} vs {}",
                orig_sentence,
                loaded_sentence
            );

            let orig_ppl = hybrid.perplexity(&sentence);
            let loaded_ppl = loaded.perplexity(&sentence);
            assert!(
                (orig_ppl - loaded_ppl).abs() < 1e-8,
                "Perplexity should match: {} vs {}",
                orig_ppl,
                loaded_ppl
            );
        }
    }
}