Skip to main content

kenlm/
lib.rs

1//! Safe Rust bindings for KenLM language model inference.
2//!
3//! The crate compiles KenLM's query-only C++ sources and exposes a small Rust
4//! API for loading ARPA or binary models, scoring sentences, inspecting
5//! vocabulary membership, and doing explicit stateful scoring.
6//!
7//! # Safety model
8//!
9//! The public API is safe Rust. C++ exceptions are caught in the C++ shim and
10//! converted into [`KenlmError`] values. Opaque [`State`] values carry an
11//! internal model identity token, so stateful scoring rejects states created by
12//! a different [`Model`] before calling into KenLM. The raw KenLM model handle
13//! is owned by [`Model`] and released exactly once in `Drop`.
14
15use std::error::Error;
16use std::ffi::{CStr, CString, NulError};
17use std::fmt;
18use std::os::raw::{c_char, c_float, c_int, c_uint, c_void};
19use std::path::Path;
20use std::ptr::NonNull;
21use std::sync::Arc;
22
23#[cfg(any(
24    feature = "tools",
25    feature = "estimation",
26    feature = "filter",
27    feature = "interpolate"
28))]
29pub mod commands;
30
/// Crate-local result type: every fallible API in this crate reports a
/// [`KenlmError`].
pub type Result<T> = std::result::Result<T, KenlmError>;

/// KenLM vocabulary index.
///
/// Exchanged with the C++ shim as `c_uint` (converted with `as` casts at the
/// FFI call sites).
pub type WordIndex = u32;
36
// Opaque stand-in for the C++ model object. The zero-length private field
// means Rust code can never construct or inspect one; it is only ever used
// behind `*const RawModel` / `*mut RawModel` / `NonNull<RawModel>`.
#[repr(C)]
struct RawModel {
    _private: [u8; 0],
}
41
// FFI mirror of the shim's `KenlmConfig`. Field order and types must match the
// C++ definition exactly; do not reorder. Built from `Config::as_raw` and
// filled in by `kenlm_config_default`.
#[repr(C)]
#[derive(Clone, Copy)]
struct RawConfig {
    // `LoadMethod` discriminant.
    load_method: c_int,
    // `ArpaLoadComplain` discriminant.
    arpa_complain: c_int,
    probing_multiplier: c_float,
    unknown_missing_logprob: c_float,
    // Boolean carried as a byte: 0 = false, anything else = true.
    show_progress: u8,
}
51
// FFI mirror of the shim's full-score output struct, written by
// `kenlm_model_try_base_full_score`. Field order and types must match the C++
// definition exactly; it is converted into the public `FullScore`.
#[repr(C)]
#[derive(Clone, Copy)]
struct RawFullScore {
    prob: c_float,
    ngram_length: u8,
    // Boolean carried as a byte (converted with `!= 0`).
    independent_left: u8,
    extend_left: u64,
    rest: c_float,
}
61
// Entry points exported by the crate's C++ shim around KenLM.
//
// Fallible calls report failure either by returning null (`kenlm_model_load`)
// or a non-zero status (`kenlm_model_try_*`); the Rust side then fetches the
// message with `kenlm_last_error`. Infallible accessors take a live model
// handle and return plain values.
extern "C" {
    // Writes KenLM's default loading options into `*config`.
    fn kenlm_config_default(config: *mut RawConfig);
    // Loads a model from a NUL-terminated path; returns null on failure.
    fn kenlm_model_load(path: *const c_char, config: *const RawConfig) -> *mut RawModel;
    // Releases a handle previously returned by `kenlm_model_load`.
    fn kenlm_model_free(model: *mut RawModel);
    // Message for the most recent failure on this thread; may be null.
    fn kenlm_last_error() -> *const c_char;
    // Size in bytes of one state buffer for this model.
    fn kenlm_model_state_size(model: *const RawModel) -> usize;
    fn kenlm_model_order(model: *const RawModel) -> u8;
    // Write the `<s>` / null-context state into a `state_size`-byte buffer.
    fn kenlm_model_begin_sentence_write(model: *const RawModel, state: *mut c_void);
    fn kenlm_model_null_context_write(model: *const RawModel, state: *mut c_void);
    // Look up `word`, writing its index to `*out`; returns 0 on success.
    fn kenlm_model_try_index(
        model: *const RawModel,
        word: *const c_char,
        out: *mut c_uint,
    ) -> c_int;
    fn kenlm_model_begin_sentence_index(model: *const RawModel) -> c_uint;
    fn kenlm_model_end_sentence_index(model: *const RawModel) -> c_uint;
    fn kenlm_model_not_found_index(model: *const RawModel) -> c_uint;
    // Score `word` from `in_state`, writing the next state and the score;
    // returns 0 on success. `in_state` and `out_state` must be distinct.
    fn kenlm_model_try_base_score(
        model: *const RawModel,
        in_state: *const c_void,
        word: c_uint,
        out_state: *mut c_void,
        out: *mut c_float,
    ) -> c_int;
    // Like `kenlm_model_try_base_score`, but writes full score metadata.
    fn kenlm_model_try_base_full_score(
        model: *const RawModel,
        in_state: *const c_void,
        word: c_uint,
        out_state: *mut c_void,
        out: *mut RawFullScore,
    ) -> c_int;
}
94
/// Errors returned by the KenLM bindings.
#[derive(Debug)]
pub enum KenlmError {
    /// A path or word contained an interior NUL byte and cannot cross the C ABI.
    InteriorNul(NulError),
    /// KenLM (via the C++ shim) reported a failure.
    ///
    /// Despite the name, this carries the shim's error message for every
    /// failed call: model loading as well as vocabulary lookup and scoring.
    /// A null or empty shim message becomes `"unknown KenLM error"`.
    Load(String),
    /// A state created by one model was used with another model.
    StateModelMismatch,
}
105
106impl fmt::Display for KenlmError {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        match self {
109            KenlmError::InteriorNul(error) => {
110                write!(f, "string contains an interior NUL byte: {error}")
111            }
112            KenlmError::Load(error) => f.write_str(error),
113            KenlmError::StateModelMismatch => {
114                f.write_str("KenLM state was created by a different model")
115            }
116        }
117    }
118}
119
// `source()` keeps its default (`None`): the only wrapped cause, the
// `NulError` inside `InteriorNul`, is already rendered by `Display`, so
// exposing it again would print it twice in error-chain reporters.
impl Error for KenlmError {}
121
122impl From<NulError> for KenlmError {
123    fn from(value: NulError) -> Self {
124        KenlmError::InteriorNul(value)
125    }
126}
127
/// How KenLM should bring binary model data into memory.
///
/// Discriminants are passed to the C++ shim as `c_int`. They presumably mirror
/// KenLM's `util::LoadMethod` enum (same order); confirm against the shim
/// before reordering or renumbering.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum LoadMethod {
    /// mmap without prepopulating (KenLM `LAZY`, per `util/mmap.hh`).
    Lazy = 0,
    /// On Linux, mmap with `MAP_POPULATE` (KenLM `POPULATE_OR_LAZY`).
    PopulateOrLazy = 1,
    /// Populate on Linux, otherwise malloc and read (KenLM `POPULATE_OR_READ`).
    PopulateOrRead = 2,
    /// malloc and read (KenLM `READ`).
    Read = 3,
    /// malloc and read in parallel (KenLM `PARALLEL_READ`).
    ParallelRead = 4,
}
138
/// How loudly KenLM should complain about loading ARPA instead of binary data.
///
/// Discriminants are passed to the C++ shim as `c_int`. They presumably mirror
/// KenLM's `Config::ARPALoadComplain` (`ALL`, `EXPENSIVE`, `NONE`); confirm
/// against the shim before reordering.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum ArpaLoadComplain {
    /// Always complain when loading ARPA.
    All = 0,
    /// Presumably complain only when ARPA loading is expensive — confirm
    /// against KenLM's `lm/config.hh`.
    Expensive = 1,
    /// Never complain.
    None = 2,
}
147
/// Runtime loading options for KenLM models.
///
/// [`Config::default`] asks the C++ shim for KenLM's own defaults. Field
/// meanings below follow KenLM's `lm/config.hh`; the values are forwarded to
/// the shim unchanged.
#[derive(Clone, Copy, Debug)]
pub struct Config {
    /// How binary model data is brought into memory.
    pub load_method: LoadMethod,
    /// How loudly to complain when loading an ARPA file.
    pub arpa_complain: ArpaLoadComplain,
    /// Size multiplier for the probing hash table (KenLM requires > 1).
    pub probing_multiplier: f32,
    /// Log probability substituted for `<unk>` when the model lacks it.
    pub unknown_missing_logprob: f32,
    /// Whether KenLM prints loading progress.
    pub show_progress: bool,
}
157
158impl Config {
159    fn as_raw(self) -> RawConfig {
160        RawConfig {
161            load_method: self.load_method as c_int,
162            arpa_complain: self.arpa_complain as c_int,
163            probing_multiplier: self.probing_multiplier,
164            unknown_missing_logprob: self.unknown_missing_logprob,
165            show_progress: u8::from(self.show_progress),
166        }
167    }
168}
169
170impl Default for Config {
171    fn default() -> Self {
172        let mut raw = RawConfig {
173            load_method: LoadMethod::Lazy as c_int,
174            arpa_complain: ArpaLoadComplain::All as c_int,
175            probing_multiplier: 1.5,
176            unknown_missing_logprob: -100.0,
177            show_progress: 1,
178        };
179        // SAFETY: `raw` points to a valid, writable `RawConfig` with the same C
180        // layout as `KenlmConfig`. The C++ wrapper only writes scalar fields.
181        unsafe {
182            kenlm_config_default(&mut raw);
183        }
184        Self {
185            load_method: match raw.load_method {
186                1 => LoadMethod::PopulateOrLazy,
187                2 => LoadMethod::PopulateOrRead,
188                3 => LoadMethod::Read,
189                4 => LoadMethod::ParallelRead,
190                _ => LoadMethod::Lazy,
191            },
192            arpa_complain: match raw.arpa_complain {
193                1 => ArpaLoadComplain::Expensive,
194                2 => ArpaLoadComplain::None,
195                _ => ArpaLoadComplain::All,
196            },
197            probing_multiplier: raw.probing_multiplier,
198            unknown_missing_logprob: raw.unknown_missing_logprob,
199            show_progress: raw.show_progress != 0,
200        }
201    }
202}
203
/// A KenLM model loaded from an ARPA or KenLM binary file.
///
/// The underlying KenLM handle is released when the `Model` is dropped.
pub struct Model {
    // Owned handle returned by `kenlm_model_load`; freed exactly once in `Drop`.
    raw: NonNull<RawModel>,
    // Byte size of one KenLM state, from `kenlm_model_state_size`. Every
    // `State` this model creates is allocated with exactly this length.
    state_size: usize,
    // Identity token shared with every `State` this model creates. Ownership
    // is checked with `Arc::ptr_eq`, i.e. by allocation, not by value.
    token: Arc<ModelToken>,
}
210
// Identity marker with no data. Each `Model` allocates its own
// `Arc<ModelToken>`, and states are matched to models with `Arc::ptr_eq`, so
// only the allocation's address is meaningful.
#[derive(Debug)]
struct ModelToken;
213
// KenLM model scoring is read-only after construction. Callers provide separate
// state buffers for each transition, so sharing a loaded model is safe.
// NOTE(review): this soundness argument assumes the shim's query entry points
// never mutate shared model state, and that its last-error storage is
// per-thread — confirm against the C++ shim when changing it.
unsafe impl Send for Model {}
unsafe impl Sync for Model {}
218
219impl Model {
220    /// Load a language model with default configuration.
221    pub fn new(path: impl AsRef<Path>) -> Result<Self> {
222        Self::with_config(path, Config::default())
223    }
224
225    /// Load a language model with explicit configuration.
226    pub fn with_config(path: impl AsRef<Path>, config: Config) -> Result<Self> {
227        let path = path.as_ref().as_os_str().to_string_lossy();
228        let path = CString::new(path.as_bytes())?;
229        let raw_config = config.as_raw();
230        // SAFETY: `path` and `raw_config` are valid for the duration of the call.
231        // The C++ wrapper catches exceptions and returns null on failure.
232        let raw = unsafe { kenlm_model_load(path.as_ptr(), &raw_config) };
233        let raw = NonNull::new(raw).ok_or_else(last_error)?;
234        // SAFETY: `raw` is a non-null KenLM handle returned by
235        // `kenlm_model_load` and remains owned by `Self` until `Drop`.
236        let state_size = unsafe { kenlm_model_state_size(raw.as_ptr()) };
237        Ok(Self {
238            raw,
239            state_size,
240            token: Arc::new(ModelToken),
241        })
242    }
243
244    /// Return the n-gram order of the model.
245    pub fn order(&self) -> u8 {
246        // SAFETY: `self.raw` is a live KenLM handle for the lifetime of `self`.
247        unsafe { kenlm_model_order(self.raw.as_ptr()) }
248    }
249
250    /// Return true when `word` exists in the model vocabulary.
251    pub fn contains(&self, word: &str) -> Result<bool> {
252        Ok(self.index(word)? != self.not_found_index())
253    }
254
255    /// Return KenLM's vocabulary index for `word`, or the not-found index for OOV words.
256    pub fn index(&self, word: &str) -> Result<WordIndex> {
257        let word = CString::new(word)?;
258        // SAFETY: `self.raw` is live and `word` is a valid NUL-terminated C
259        // string for the duration of the call.
260        let mut index = 0;
261        let status = unsafe { kenlm_model_try_index(self.raw.as_ptr(), word.as_ptr(), &mut index) };
262        if status == 0 {
263            Ok(index as WordIndex)
264        } else {
265            Err(last_error())
266        }
267    }
268
269    /// Return the index for `<s>`.
270    pub fn begin_sentence_index(&self) -> WordIndex {
271        // SAFETY: `self.raw` is a live KenLM handle for the lifetime of `self`.
272        unsafe { kenlm_model_begin_sentence_index(self.raw.as_ptr()) as WordIndex }
273    }
274
275    /// Return the index for `</s>`.
276    pub fn end_sentence_index(&self) -> WordIndex {
277        // SAFETY: `self.raw` is a live KenLM handle for the lifetime of `self`.
278        unsafe { kenlm_model_end_sentence_index(self.raw.as_ptr()) as WordIndex }
279    }
280
281    /// Return the vocabulary index used for out-of-vocabulary words.
282    pub fn not_found_index(&self) -> WordIndex {
283        // SAFETY: `self.raw` is a live KenLM handle for the lifetime of `self`.
284        unsafe { kenlm_model_not_found_index(self.raw.as_ptr()) as WordIndex }
285    }
286
287    /// Score a whitespace-tokenized sentence, returning log10 probability.
288    ///
289    /// With `bos = true` and `eos = true`, this returns
290    /// `log10 p(sentence </s> | <s>)`.
291    pub fn score(&self, sentence: &str, bos: bool, eos: bool) -> Result<f32> {
292        self.score_words(sentence.split_whitespace(), bos, eos)
293    }
294
295    /// Score pre-tokenized words, returning log10 probability.
296    pub fn score_words<'a>(
297        &self,
298        words: impl IntoIterator<Item = &'a str>,
299        bos: bool,
300        eos: bool,
301    ) -> Result<f32> {
302        let mut state = self.initial_state(bos);
303        let mut next = self.empty_state();
304        let mut total = 0.0;
305
306        for word in words {
307            let index = self.index(word)?;
308            total += self.base_score(&state, index, &mut next)?;
309            std::mem::swap(&mut state, &mut next);
310        }
311
312        if eos {
313            total += self.base_score(&state, self.end_sentence_index(), &mut next)?;
314        }
315
316        Ok(total)
317    }
318
319    /// Return perplexity for a complete whitespace-tokenized sentence.
320    pub fn perplexity(&self, sentence: &str) -> Result<f32> {
321        let words = sentence.split_whitespace().count() + 1;
322        Ok(10.0_f32.powf(-self.score(sentence, true, true)? / words as f32))
323    }
324
325    /// Return per-token full scores for a whitespace-tokenized sentence.
326    pub fn full_scores(&self, sentence: &str, bos: bool, eos: bool) -> Result<Vec<TokenScore>> {
327        self.full_scores_words(sentence.split_whitespace(), bos, eos)
328    }
329
330    /// Return per-token full scores for pre-tokenized words.
331    pub fn full_scores_words<'a>(
332        &self,
333        words: impl IntoIterator<Item = &'a str>,
334        bos: bool,
335        eos: bool,
336    ) -> Result<Vec<TokenScore>> {
337        let mut state = self.initial_state(bos);
338        let mut next = self.empty_state();
339        let mut scores = Vec::new();
340
341        for word in words {
342            let index = self.index(word)?;
343            let full_score = self.base_full_score(&state, index, &mut next)?;
344            scores.push(TokenScore {
345                log_prob: full_score.log_prob,
346                ngram_length: full_score.ngram_length,
347                oov: index == self.not_found_index(),
348            });
349            std::mem::swap(&mut state, &mut next);
350        }
351
352        if eos {
353            let full_score = self.base_full_score(&state, self.end_sentence_index(), &mut next)?;
354            scores.push(TokenScore {
355                log_prob: full_score.log_prob,
356                ngram_length: full_score.ngram_length,
357                oov: false,
358            });
359        }
360
361        Ok(scores)
362    }
363
364    /// Create a state initialized to beginning-of-sentence context.
365    pub fn begin_sentence_state(&self) -> State {
366        let mut state = self.empty_state();
367        // SAFETY: `state` is exactly `self.state_size` bytes and belongs to
368        // this model. KenLM writes a POD state into the provided buffer.
369        unsafe {
370            kenlm_model_begin_sentence_write(self.raw.as_ptr(), state.as_mut_ptr());
371        }
372        state
373    }
374
375    /// Create a state initialized to null context.
376    pub fn null_context_state(&self) -> State {
377        let mut state = self.empty_state();
378        // SAFETY: `state` is exactly `self.state_size` bytes and belongs to
379        // this model. KenLM writes a POD state into the provided buffer.
380        unsafe {
381            kenlm_model_null_context_write(self.raw.as_ptr(), state.as_mut_ptr());
382        }
383        state
384    }
385
386    /// Score `word_index` from `in_state`, writing the next state into `out_state`.
387    pub fn base_score(
388        &self,
389        in_state: &State,
390        word_index: WordIndex,
391        out_state: &mut State,
392    ) -> Result<f32> {
393        self.validate_state(in_state)?;
394        self.validate_state(out_state)?;
395        // Safe Rust prevents passing the exact same `State` as both `&State`
396        // and `&mut State`; KenLM additionally requires distinct buffers.
397        debug_assert!(!std::ptr::eq(in_state.as_ptr(), out_state.as_ptr()));
398        let mut score = 0.0;
399        // SAFETY: states were created by this model and have the exact byte
400        // size KenLM reported. Input and output buffers are distinct. The C++
401        // wrapper catches exceptions and reports them through its status code.
402        let status = unsafe {
403            kenlm_model_try_base_score(
404                self.raw.as_ptr(),
405                in_state.as_ptr(),
406                word_index as c_uint,
407                out_state.as_mut_ptr(),
408                &mut score,
409            )
410        };
411        if status == 0 {
412            Ok(score)
413        } else {
414            Err(last_error())
415        }
416    }
417
418    /// Return KenLM's full score metadata for a state transition.
419    pub fn base_full_score(
420        &self,
421        in_state: &State,
422        word_index: WordIndex,
423        out_state: &mut State,
424    ) -> Result<FullScore> {
425        self.validate_state(in_state)?;
426        self.validate_state(out_state)?;
427        // Safe Rust prevents passing the exact same `State` as both `&State`
428        // and `&mut State`; KenLM additionally requires distinct buffers.
429        debug_assert!(!std::ptr::eq(in_state.as_ptr(), out_state.as_ptr()));
430        let mut raw = RawFullScore {
431            prob: 0.0,
432            ngram_length: 0,
433            independent_left: 0,
434            extend_left: 0,
435            rest: 0.0,
436        };
437        // SAFETY: states were created by this model and have the exact byte
438        // size KenLM reported. Input and output buffers are distinct. The C++
439        // wrapper catches exceptions and reports them through its status code.
440        let status = unsafe {
441            kenlm_model_try_base_full_score(
442                self.raw.as_ptr(),
443                in_state.as_ptr(),
444                word_index as c_uint,
445                out_state.as_mut_ptr(),
446                &mut raw,
447            )
448        };
449        if status != 0 {
450            return Err(last_error());
451        }
452        Ok(FullScore {
453            log_prob: raw.prob,
454            ngram_length: raw.ngram_length,
455            independent_left: raw.independent_left != 0,
456            extend_left: raw.extend_left,
457            rest: raw.rest,
458        })
459    }
460
461    fn initial_state(&self, bos: bool) -> State {
462        if bos {
463            self.begin_sentence_state()
464        } else {
465            self.null_context_state()
466        }
467    }
468
469    fn empty_state(&self) -> State {
470        State {
471            bytes: vec![0; self.state_size],
472            owner: Arc::clone(&self.token),
473        }
474    }
475
476    fn validate_state(&self, state: &State) -> Result<()> {
477        if state.bytes.len() != self.state_size || !Arc::ptr_eq(&state.owner, &self.token) {
478            return Err(KenlmError::StateModelMismatch);
479        }
480        Ok(())
481    }
482}
483
// Releases the KenLM handle owned by this model. Any `State` values that
// outlive the model keep only their byte buffers and the identity token; they
// hold no pointer into the freed model.
impl Drop for Model {
    fn drop(&mut self) {
        // SAFETY: `self.raw` was returned by `kenlm_model_load` and has not
        // been freed yet. `Drop` runs exactly once for `Model`.
        unsafe {
            kenlm_model_free(self.raw.as_ptr());
        }
    }
}
493
/// Opaque KenLM state memory used for incremental scoring.
///
/// Obtain states from [`Model::begin_sentence_state`] or
/// [`Model::null_context_state`]. A state is only accepted by the model that
/// created it; passing it to another model yields
/// [`KenlmError::StateModelMismatch`].
#[derive(Clone, Debug)]
pub struct State {
    // Raw KenLM state bytes; exactly `state_size` long for the owning model.
    bytes: Vec<u8>,
    // Identity token of the creating model, compared with `Arc::ptr_eq`.
    owner: Arc<ModelToken>,
}
500
501impl State {
502    fn as_ptr(&self) -> *const c_void {
503        self.bytes.as_ptr().cast()
504    }
505
506    fn as_mut_ptr(&mut self) -> *mut c_void {
507        self.bytes.as_mut_ptr().cast()
508    }
509}
510
511impl PartialEq for State {
512    fn eq(&self, other: &Self) -> bool {
513        Arc::ptr_eq(&self.owner, &other.owner) && self.bytes == other.bytes
514    }
515}
516
517impl Eq for State {}
518
/// Detailed score metadata for one state transition.
///
/// Produced by [`Model::base_full_score`]; each field is copied from the
/// shim's raw full-score struct.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct FullScore {
    /// Log10 probability of the word given the input state (KenLM `prob`).
    pub log_prob: f32,
    /// KenLM's `ngram_length`; presumably the order of the matched n-gram.
    pub ngram_length: u8,
    /// KenLM's `independent_left` flag (converted from a 0/1 byte).
    pub independent_left: bool,
    /// KenLM's `extend_left` value, passed through unchanged.
    pub extend_left: u64,
    /// KenLM's `rest` score, passed through unchanged.
    pub rest: f32,
}
528
/// Sentence-level per-token score, including OOV metadata.
///
/// Returned by [`Model::full_scores`] and [`Model::full_scores_words`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TokenScore {
    /// Log10 probability of this token given its context.
    pub log_prob: f32,
    /// Same as [`FullScore::ngram_length`].
    pub ngram_length: u8,
    /// True when the word's index equals [`Model::not_found_index`]; always
    /// false for the trailing `</s>` entry.
    pub oov: bool,
}
536
537fn last_error() -> KenlmError {
538    // SAFETY: `kenlm_last_error` returns a pointer to thread-local storage in
539    // the C++ wrapper. It is valid until the next wrapper call on this thread.
540    let message = unsafe {
541        let ptr = kenlm_last_error();
542        if ptr.is_null() {
543            String::new()
544        } else {
545            CStr::from_ptr(ptr).to_string_lossy().into_owned()
546        }
547    };
548    if message.is_empty() {
549        KenlmError::Load("unknown KenLM error".to_string())
550    } else {
551        KenlmError::Load(message)
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Small ARPA fixture shared by every test.
    const TEST_MODEL: &str = "lm/test.arpa";

    /// Load the fixture with progress output disabled so test logs stay clean.
    fn load_quiet() -> Model {
        let quiet = Config {
            show_progress: false,
            ..Config::default()
        };
        Model::with_config(TEST_MODEL, quiet).unwrap()
    }

    #[test]
    fn loads_and_scores_test_model() {
        let model = load_quiet();
        let sentence = "looking on a little";

        assert!(model.order() > 0);
        assert!(model.contains("looking").unwrap());
        assert!(!model.contains("definitely-not-in-this-model").unwrap());

        let total = model.score(sentence, true, true).unwrap();
        assert!(total.is_finite());

        let per_token = model.full_scores(sentence, true, true).unwrap();
        // Four words plus the `</s>` appended because `eos` is true.
        assert_eq!(per_token.len(), 5);
        assert!(per_token.iter().all(|token| token.log_prob.is_finite()));
    }

    #[test]
    fn supports_stateful_scoring() {
        let model = load_quiet();
        let looking = model.index("looking").unwrap();

        let mut current = model.begin_sentence_state();
        let mut successor = model.null_context_state();

        let first = model.base_score(&current, looking, &mut successor).unwrap();
        assert!(first.is_finite());

        // Advance: the state after "looking" becomes the next input.
        std::mem::swap(&mut current, &mut successor);
        let last = model
            .base_full_score(&current, model.end_sentence_index(), &mut successor)
            .unwrap();
        assert!(last.log_prob.is_finite());
    }

    #[test]
    fn rejects_states_from_other_models() {
        let first = load_quiet();
        let second = load_quiet();

        let foreign = first.begin_sentence_state();
        let mut out = second.null_context_state();
        let word = second.index("looking").unwrap();

        let error = second.base_score(&foreign, word, &mut out).unwrap_err();
        assert!(matches!(error, KenlmError::StateModelMismatch));
    }
}