// scirs2_text/language_model.rs
1//! N-gram Language Models
2//!
3//! This module provides statistical language models based on n-grams.
4//! N-gram language models estimate the probability of word sequences
5//! and can be used for text generation, auto-completion, and more.
6//!
7//! ## Overview
8//!
9//! An n-gram is a contiguous sequence of n items from a given text.
10//! The n-gram model estimates:
11//!
12//! P(w_n | w_1, w_2, ..., w_{n-1})
13//!
14//! ## Supported Models
15//!
16//! - **Unigram**: P(word)
17//! - **Bigram**: P(word | previous_word)
18//! - **Trigram**: P(word | previous_two_words)
19//! - **N-gram**: P(word | previous_n-1_words)
20//!
21//! ## Smoothing Techniques
22//!
23//! - **Laplace (Add-1) Smoothing**: Adds 1 to all counts
24//! - **Add-k Smoothing**: Adds k to all counts
25//! - **Kneser-Ney Smoothing**: Advanced smoothing based on continuation probability
26//!
27//! ## Quick Start
28//!
29//! ```rust
30//! use scirs2_text::language_model::{NgramModel, SmoothingMethod};
31//!
32//! // Create a bigram model
33//! let texts = vec![
34//!     "the quick brown fox jumps over the lazy dog",
35//!     "the dog was lazy but the fox was quick"
36//! ];
37//!
38//! let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
39//! model.train(&texts).expect("Training failed");
40//!
41//! // Calculate probability
42//! let prob = model.probability(&["the"], "quick").expect("Failed to get probability");
43//! println!("P(quick | the) = {}", prob);
44//!
45//! // Generate text
46//! let text = model.generate(10, Some("the")).expect("Generation failed");
47//! println!("Generated: {}", text);
48//! ```
49
50use crate::error::{Result, TextError};
51use crate::tokenize::{Tokenizer, WordTokenizer};
52use scirs2_core::random::prelude::*;
53use std::collections::HashMap;
54use std::fmt::Debug;
55
/// Smoothing methods for n-gram models
///
/// Smoothing assigns non-zero probability mass to n-grams that were never
/// observed during training; without it, any unseen n-gram has probability 0.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SmoothingMethod {
    /// No smoothing (maximum likelihood estimation); unseen n-grams get probability 0
    None,
    /// Laplace (add-1) smoothing: every count is incremented by 1
    Laplace,
    /// Add-k smoothing with a custom pseudo-count `k`
    AddK(f64),
    /// Kneser-Ney smoothing with the given discount parameter
    KneserNey(f64),
}
68
/// N-gram language model
///
/// Stores raw n-gram counts gathered by `train` and derives smoothed
/// conditional probabilities from them on demand.
pub struct NgramModel {
    /// Order of the n-gram model (n); 1 = unigram, 2 = bigram, ...
    n: usize,
    /// Smoothing method applied when computing probabilities
    smoothing: SmoothingMethod,
    /// N-gram counts: context (the previous n-1 words) -> word -> count
    ngram_counts: HashMap<Vec<String>, HashMap<String, usize>>,
    /// Total occurrences of each context; the normalizing denominator
    context_counts: HashMap<Vec<String>, usize>,
    /// Sorted vocabulary; includes the `<START>` and `<END>` boundary markers
    vocabulary: Vec<String>,
    /// Total number of n-gram windows counted during training
    total_words: usize,
    /// Tokenizer used to split input texts into words
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
}
86
87impl Debug for NgramModel {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("NgramModel")
90            .field("n", &self.n)
91            .field("smoothing", &self.smoothing)
92            .field("vocabulary_size", &self.vocabulary.len())
93            .field("total_words", &self.total_words)
94            .finish()
95    }
96}
97
impl Clone for NgramModel {
    // Manual impl: `Box<dyn Tokenizer + Send + Sync>` cannot be derived-cloned.
    fn clone(&self) -> Self {
        Self {
            n: self.n,
            smoothing: self.smoothing,
            ngram_counts: self.ngram_counts.clone(),
            context_counts: self.context_counts.clone(),
            vocabulary: self.vocabulary.clone(),
            total_words: self.total_words,
            // NOTE(review): a custom tokenizer set via `with_tokenizer` is NOT
            // carried over — the clone silently falls back to the default
            // `WordTokenizer`. Callers relying on a custom tokenizer must
            // re-set it after cloning.
            tokenizer: Box::new(WordTokenizer::default()),
        }
    }
}
111
112impl NgramModel {
113    /// Create a new n-gram model
114    ///
115    /// # Arguments
116    ///
117    /// * `n` - Order of the model (1 for unigram, 2 for bigram, etc.)
118    /// * `smoothing` - Smoothing method to use
119    pub fn new(n: usize, smoothing: SmoothingMethod) -> Self {
120        if n == 0 {
121            panic!("N-gram order must be at least 1");
122        }
123
124        Self {
125            n,
126            smoothing,
127            ngram_counts: HashMap::new(),
128            context_counts: HashMap::new(),
129            vocabulary: Vec::new(),
130            total_words: 0,
131            tokenizer: Box::new(WordTokenizer::default()),
132        }
133    }
134
    /// Set a custom tokenizer (builder-style, consumes and returns `self`).
    ///
    /// NOTE(review): cloning the model does not preserve a custom tokenizer;
    /// the clone reverts to the default `WordTokenizer`.
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer + Send + Sync>) -> Self {
        self.tokenizer = tokenizer;
        self
    }
140
141    /// Train the model on a corpus
142    pub fn train(&mut self, texts: &[&str]) -> Result<()> {
143        if texts.is_empty() {
144            return Err(TextError::InvalidInput(
145                "No texts provided for training".into(),
146            ));
147        }
148
149        // Clear existing data
150        self.ngram_counts.clear();
151        self.context_counts.clear();
152        self.vocabulary.clear();
153        self.total_words = 0;
154
155        // Collect vocabulary
156        let mut vocab_set = std::collections::HashSet::new();
157
158        for &text in texts {
159            let tokens = self.tokenizer.tokenize(text)?;
160
161            // Add start and end markers
162            let mut augmented_tokens = vec!["<START>".to_string(); self.n - 1];
163            augmented_tokens.extend(tokens);
164            augmented_tokens.push("<END>".to_string());
165
166            // Build vocabulary
167            for token in &augmented_tokens {
168                vocab_set.insert(token.clone());
169            }
170
171            // Count n-grams
172            for i in (self.n - 1)..augmented_tokens.len() {
173                let context = augmented_tokens[i - (self.n - 1)..i].to_vec();
174                let word = &augmented_tokens[i];
175
176                // Update n-gram counts
177                *self
178                    .ngram_counts
179                    .entry(context.clone())
180                    .or_default()
181                    .entry(word.clone())
182                    .or_insert(0) += 1;
183
184                // Update context counts
185                *self.context_counts.entry(context).or_insert(0) += 1;
186
187                self.total_words += 1;
188            }
189        }
190
191        self.vocabulary = vocab_set.into_iter().collect();
192        self.vocabulary.sort();
193
194        Ok(())
195    }
196
197    /// Calculate the probability of a word given its context
198    ///
199    /// # Arguments
200    ///
201    /// * `context` - The previous n-1 words
202    /// * `word` - The word to predict
203    ///
204    /// # Returns
205    ///
206    /// The probability P(word | context)
207    pub fn probability(&self, context: &[&str], word: &str) -> Result<f64> {
208        if context.len() != self.n - 1 {
209            return Err(TextError::InvalidInput(format!(
210                "Context must have exactly {} words for {}-gram model",
211                self.n - 1,
212                self.n
213            )));
214        }
215
216        let context_vec: Vec<String> = context.iter().map(|s| s.to_string()).collect();
217        let vocab_size = self.vocabulary.len();
218
219        match self.smoothing {
220            SmoothingMethod::None => {
221                // Maximum likelihood estimation
222                let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
223
224                if context_count == 0 {
225                    return Ok(0.0);
226                }
227
228                let ngram_count = self
229                    .ngram_counts
230                    .get(&context_vec)
231                    .and_then(|words| words.get(word))
232                    .copied()
233                    .unwrap_or(0);
234
235                Ok(ngram_count as f64 / context_count as f64)
236            }
237            SmoothingMethod::Laplace => {
238                // Add-1 smoothing
239                let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
240
241                let ngram_count = self
242                    .ngram_counts
243                    .get(&context_vec)
244                    .and_then(|words| words.get(word))
245                    .copied()
246                    .unwrap_or(0);
247
248                Ok((ngram_count + 1) as f64 / (context_count + vocab_size) as f64)
249            }
250            SmoothingMethod::AddK(k) => {
251                // Add-k smoothing
252                let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
253
254                let ngram_count = self
255                    .ngram_counts
256                    .get(&context_vec)
257                    .and_then(|words| words.get(word))
258                    .copied()
259                    .unwrap_or(0);
260
261                Ok((ngram_count as f64 + k) / (context_count as f64 + k * vocab_size as f64))
262            }
263            SmoothingMethod::KneserNey(discount) => {
264                // Simplified Kneser-Ney smoothing
265                let context_count = self.context_counts.get(&context_vec).copied().unwrap_or(0);
266
267                if context_count == 0 {
268                    return Ok(1.0 / vocab_size as f64);
269                }
270
271                let ngram_count = self
272                    .ngram_counts
273                    .get(&context_vec)
274                    .and_then(|words| words.get(word))
275                    .copied()
276                    .unwrap_or(0);
277
278                let adjusted_count = (ngram_count as f64 - discount).max(0.0);
279                let lambda = discount
280                    * self
281                        .ngram_counts
282                        .get(&context_vec)
283                        .map(|m| m.len())
284                        .unwrap_or(0) as f64
285                    / context_count as f64;
286
287                let continuation_prob = 1.0 / vocab_size as f64;
288
289                Ok(adjusted_count / context_count as f64 + lambda * continuation_prob)
290            }
291        }
292    }
293
294    /// Calculate perplexity on a test corpus
295    ///
296    /// Perplexity is a measure of how well the model predicts the test data.
297    /// Lower perplexity indicates better performance.
298    pub fn perplexity(&self, texts: &[&str]) -> Result<f64> {
299        if texts.is_empty() {
300            return Err(TextError::InvalidInput("No test texts provided".into()));
301        }
302
303        let mut log_prob_sum = 0.0;
304        let mut word_count = 0;
305
306        for &text in texts {
307            let tokens = self.tokenizer.tokenize(text)?;
308
309            let mut augmented_tokens = vec!["<START>".to_string(); self.n - 1];
310            augmented_tokens.extend(tokens);
311            augmented_tokens.push("<END>".to_string());
312
313            for i in (self.n - 1)..augmented_tokens.len() {
314                let context: Vec<&str> = augmented_tokens[i - (self.n - 1)..i]
315                    .iter()
316                    .map(|s| s.as_str())
317                    .collect();
318                let word = &augmented_tokens[i];
319
320                let prob = self.probability(&context, word)?;
321
322                if prob > 0.0 {
323                    log_prob_sum += prob.ln();
324                    word_count += 1;
325                } else {
326                    // Avoid log(0) by using a small probability
327                    log_prob_sum += f64::ln(1e-10);
328                    word_count += 1;
329                }
330            }
331        }
332
333        if word_count == 0 {
334            return Ok(f64::INFINITY);
335        }
336
337        Ok((-log_prob_sum / word_count as f64).exp())
338    }
339
340    /// Generate text using the language model
341    ///
342    /// # Arguments
343    ///
344    /// * `max_length` - Maximum number of words to generate
345    /// * `start_context` - Optional starting context (must have n-1 words)
346    ///
347    /// # Returns
348    ///
349    /// Generated text as a string
350    pub fn generate(&self, max_length: usize, start_context: Option<&str>) -> Result<String> {
351        let mut rng = scirs2_core::random::rng();
352        let mut generated = Vec::new();
353
354        // Initialize context
355        let mut context: Vec<String> = if let Some(start) = start_context {
356            let tokens = self.tokenizer.tokenize(start)?;
357            if tokens.len() < self.n - 1 {
358                let mut ctx = vec!["<START>".to_string(); self.n - 1 - tokens.len()];
359                ctx.extend(tokens);
360                ctx
361            } else {
362                tokens.into_iter().rev().take(self.n - 1).rev().collect()
363            }
364        } else {
365            vec!["<START>".to_string(); self.n - 1]
366        };
367
368        // Generate words
369        for _ in 0..max_length {
370            let context_refs: Vec<&str> = context.iter().map(|s| s.as_str()).collect();
371
372            // Get possible next words and their probabilities
373            let candidates = match self.ngram_counts.get(&context) {
374                Some(words) => words,
375                None => {
376                    // If context not found, sample from vocabulary
377                    break;
378                }
379            };
380
381            if candidates.is_empty() {
382                break;
383            }
384
385            // Sample next word based on probabilities
386            let total: usize = candidates.values().sum();
387            let mut threshold = rng.random_range(0..total);
388            let mut next_word = String::new();
389
390            for (word, &count) in candidates {
391                if threshold < count {
392                    next_word = word.clone();
393                    break;
394                }
395                threshold -= count;
396            }
397
398            if next_word == "<END>" {
399                break;
400            }
401
402            if next_word != "<START>" {
403                generated.push(next_word.clone());
404            }
405
406            // Update context
407            context.remove(0);
408            context.push(next_word);
409        }
410
411        Ok(generated.join(" "))
412    }
413
414    /// Get the most likely next words given a context
415    ///
416    /// # Arguments
417    ///
418    /// * `context` - The previous n-1 words
419    /// * `top_n` - Number of suggestions to return
420    ///
421    /// # Returns
422    ///
423    /// Vector of (word, probability) pairs, sorted by probability (descending)
424    pub fn suggest_next(&self, context: &[&str], top_n: usize) -> Result<Vec<(String, f64)>> {
425        if context.len() != self.n - 1 {
426            return Err(TextError::InvalidInput(format!(
427                "Context must have exactly {} words",
428                self.n - 1
429            )));
430        }
431
432        let context_vec: Vec<String> = context.iter().map(|s| s.to_string()).collect();
433
434        let candidates = match self.ngram_counts.get(&context_vec) {
435            Some(words) => words,
436            None => {
437                return Ok(Vec::new());
438            }
439        };
440
441        let mut suggestions: Vec<(String, f64)> = candidates
442            .keys()
443            .map(|word| {
444                let prob = self.probability(context, word).unwrap_or(0.0);
445                (word.clone(), prob)
446            })
447            .collect();
448
449        suggestions.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
450
451        Ok(suggestions.into_iter().take(top_n).collect())
452    }
453
    /// Get the n-gram order
    ///
    /// Returns the `n` supplied to `new` (1 = unigram, 2 = bigram, ...).
    pub fn order(&self) -> usize {
        self.n
    }
458
    /// Get the vocabulary size
    ///
    /// Counts distinct tokens seen during training, including the `<START>`
    /// and `<END>` boundary markers; zero before `train` is called.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    // Unigram model: empty context, probability of a single word.
    #[test]
    fn test_unigram_model() {
        let texts = vec!["the cat sat on the mat", "the dog sat on the log"];

        let mut model = NgramModel::new(1, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        // "the" appears 4 times out of ~14 total words
        let prob = model
            .probability(&[], "the")
            .expect("Failed to get probability");
        assert!(prob > 0.0);
    }

    // Bigram model: one-word context.
    #[test]
    fn test_bigram_model() {
        let texts = vec!["the cat sat", "the dog sat"];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        // P(cat | the) should be non-zero
        let prob = model
            .probability(&["the"], "cat")
            .expect("Failed to get probability");
        assert!(prob > 0.0);

        // P(dog | the) should be non-zero
        let prob = model
            .probability(&["the"], "dog")
            .expect("Failed to get probability");
        assert!(prob > 0.0);
    }

    // Trigram model: two-word context.
    #[test]
    fn test_trigram_model() {
        let texts = vec!["the quick brown fox", "the quick red fox"];

        let mut model = NgramModel::new(3, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        // P(brown | the quick)
        let prob = model
            .probability(&["the", "quick"], "brown")
            .expect("Failed to get probability");
        assert!(prob > 0.0);
    }

    // Smoothing must assign non-zero probability to n-grams never seen in training.
    #[test]
    fn test_smoothing_methods() {
        let texts = vec!["the cat sat"];

        // Test Laplace smoothing
        let mut model_laplace = NgramModel::new(2, SmoothingMethod::Laplace);
        model_laplace.train(&texts).expect("Training failed");

        // "the dog" never occurs in the training text.
        let prob_laplace = model_laplace
            .probability(&["the"], "dog")
            .expect("Failed to get probability");
        assert!(
            prob_laplace > 0.0,
            "Laplace smoothing should give non-zero probability to unseen n-grams"
        );

        // Test Add-k smoothing
        let mut model_addk = NgramModel::new(2, SmoothingMethod::AddK(0.5));
        model_addk.train(&texts).expect("Training failed");

        let prob_addk = model_addk
            .probability(&["the"], "dog")
            .expect("Failed to get probability");
        assert!(prob_addk > 0.0);
    }

    // Sampling-based generation; only checks that some text is produced.
    #[test]
    fn test_text_generation() {
        let texts = vec![
            "the quick brown fox jumps over the lazy dog",
            "the quick brown dog runs fast",
        ];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        let generated = model.generate(10, Some("the")).expect("Generation failed");
        assert!(!generated.is_empty());
    }

    // Perplexity on text drawn from the training distribution should be finite.
    #[test]
    fn test_perplexity() {
        let train_texts = vec!["the cat sat on the mat"];
        let test_texts = vec!["the cat sat"];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&train_texts).expect("Training failed");

        let perplexity = model
            .perplexity(&test_texts)
            .expect("Failed to calculate perplexity");
        assert!(perplexity > 0.0);
        assert!(perplexity.is_finite());
    }

    // Next-word suggestions come from continuations observed after the context.
    #[test]
    fn test_suggest_next() {
        let texts = vec!["the cat sat", "the cat ran", "the dog sat"];

        let mut model = NgramModel::new(2, SmoothingMethod::Laplace);
        model.train(&texts).expect("Training failed");

        let suggestions = model
            .suggest_next(&["the"], 3)
            .expect("Failed to get suggestions");

        assert!(!suggestions.is_empty());
        // "cat" and "dog" should be among the suggestions
        assert!(suggestions.iter().any(|(word, _)| word == "cat"));
    }
}