1use std::error::Error;
16use std::ffi::{CStr, CString, NulError};
17use std::fmt;
18use std::os::raw::{c_char, c_float, c_int, c_uint, c_void};
19use std::path::Path;
20use std::ptr::NonNull;
21use std::sync::Arc;
22
// Feature-gated `commands` submodule: compiled only when at least one of the
// listed Cargo features is enabled.
#[cfg(any(
    feature = "tools",
    feature = "estimation",
    feature = "filter",
    feature = "interpolate"
))]
pub mod commands;

/// Crate-wide result type; the error side is always [`KenlmError`].
pub type Result<T> = std::result::Result<T, KenlmError>;

/// Vocabulary index of a word in a loaded model (exchanged with KenLM as a `c_uint`).
pub type WordIndex = u32;
36
/// Opaque handle type for a KenLM model that lives on the C side.
///
/// It is only ever used behind raw pointers (`*const RawModel` / `*mut RawModel`).
/// The zero-sized array keeps the type zero-sized and FFI-friendly. The marker field
/// opts the type out of the auto traits (`Send`, `Sync`, `Unpin`) that an empty
/// struct would otherwise receive, because nothing is known about the foreign
/// object's thread-safety or address stability. This is the pattern the Rustonomicon
/// recommends for opaque foreign types.
#[repr(C)]
struct RawModel {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
41
/// C-ABI mirror of the loader configuration.
///
/// It is filled by `kenlm_config_default` and passed to `kenlm_model_load`.
/// NOTE(review): the field order and types presumably must match the C shim's
/// struct exactly. Confirm against the shim header before editing.
#[repr(C)]
#[derive(Clone, Copy)]
struct RawConfig {
    // Discriminant of `LoadMethod`, cast to `c_int` in `Config::as_raw`.
    load_method: c_int,
    // Discriminant of `ArpaLoadComplain`, cast to `c_int` in `Config::as_raw`.
    arpa_complain: c_int,
    probing_multiplier: c_float,
    unknown_missing_logprob: c_float,
    // Boolean flag encoded as 0 / 1 (nonzero is read back as `true`).
    show_progress: u8,
}
51
/// C-ABI result record written by `kenlm_model_try_base_full_score`.
///
/// It is converted field-by-field into the public [`FullScore`].
/// NOTE(review): the layout presumably must match the C shim's struct exactly,
/// including padding before `extend_left`. Confirm against the shim header.
#[repr(C)]
#[derive(Clone, Copy)]
struct RawFullScore {
    // Exposed publicly as `FullScore::log_prob`.
    prob: c_float,
    ngram_length: u8,
    // Boolean encoded as a byte; nonzero is read as `true`.
    independent_left: u8,
    extend_left: u64,
    rest: c_float,
}
61
// Functions exported by the KenLM C shim. Conventions as used by the callers in this
// file:
// * `kenlm_model_load` returns NULL on failure.
// * Every `kenlm_*_try_*` function returns 0 on success and nonzero on failure.
// * After a failure, `kenlm_last_error` is consulted for a message; it may return NULL.
// * State pointers refer to buffers of exactly `kenlm_model_state_size` bytes.
extern "C" {
    // Writes the library's default loader configuration into `config`.
    fn kenlm_config_default(config: *mut RawConfig);
    // Loads a model from a NUL-terminated path; NULL on failure.
    fn kenlm_model_load(path: *const c_char, config: *const RawConfig) -> *mut RawModel;
    // Releases a model returned by `kenlm_model_load`.
    fn kenlm_model_free(model: *mut RawModel);
    // Most recent error message, or NULL.
    fn kenlm_last_error() -> *const c_char;
    // Size in bytes of one opaque scoring state for `model`.
    fn kenlm_model_state_size(model: *const RawModel) -> usize;
    fn kenlm_model_order(model: *const RawModel) -> u8;
    // Initialise `state` to the begin-of-sentence / null context respectively.
    fn kenlm_model_begin_sentence_write(model: *const RawModel, state: *mut c_void);
    fn kenlm_model_null_context_write(model: *const RawModel, state: *mut c_void);
    // Vocabulary lookup of a NUL-terminated word; result written to `out`.
    fn kenlm_model_try_index(
        model: *const RawModel,
        word: *const c_char,
        out: *mut c_uint,
    ) -> c_int;
    fn kenlm_model_begin_sentence_index(model: *const RawModel) -> c_uint;
    fn kenlm_model_end_sentence_index(model: *const RawModel) -> c_uint;
    fn kenlm_model_not_found_index(model: *const RawModel) -> c_uint;
    // Scores `word` after `in_state`, writing the successor state to `out_state`
    // and the score to `out`.
    fn kenlm_model_try_base_score(
        model: *const RawModel,
        in_state: *const c_void,
        word: c_uint,
        out_state: *mut c_void,
        out: *mut c_float,
    ) -> c_int;
    // Like `kenlm_model_try_base_score`, but fills a full `RawFullScore` record.
    fn kenlm_model_try_base_full_score(
        model: *const RawModel,
        in_state: *const c_void,
        word: c_uint,
        out_state: *mut c_void,
        out: *mut RawFullScore,
    ) -> c_int;
}
94
/// Every failure this crate can report.
#[derive(Debug)]
pub enum KenlmError {
    /// A Rust string handed to KenLM (a word or a path) contained a NUL byte,
    /// so it could not be turned into a C string.
    InteriorNul(NulError),
    /// A message obtained from KenLM's last-error slot, or a generic fallback text.
    ///
    /// Despite the name, this variant is also produced when indexing or scoring fails,
    /// not only when loading.
    Load(String),
    /// A [`State`] was handed to a [`Model`] other than the one that created it.
    StateModelMismatch,
}
105
106impl fmt::Display for KenlmError {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 match self {
109 KenlmError::InteriorNul(error) => {
110 write!(f, "string contains an interior NUL byte: {error}")
111 }
112 KenlmError::Load(error) => f.write_str(error),
113 KenlmError::StateModelMismatch => {
114 f.write_str("KenLM state was created by a different model")
115 }
116 }
117 }
118}
119
120impl Error for KenlmError {}
121
122impl From<NulError> for KenlmError {
123 fn from(value: NulError) -> Self {
124 KenlmError::InteriorNul(value)
125 }
126}
127
/// Strategy KenLM uses to bring a model file into memory.
///
/// The explicit discriminants are the integer codes sent to the C shim through
/// `RawConfig::load_method`. NOTE(review): these presumably mirror KenLM's
/// `util::LoadMethod` enumerators in order. Confirm before reordering.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadMethod {
    /// Code 0.
    Lazy = 0,
    /// Code 1.
    PopulateOrLazy = 1,
    /// Code 2.
    PopulateOrRead = 2,
    /// Code 3.
    Read = 3,
    /// Code 4.
    ParallelRead = 4,
}
138
/// How much KenLM complains while loading an ARPA file.
///
/// The explicit discriminants are the integer codes sent to the C shim through
/// `RawConfig::arpa_complain`. NOTE(review): these presumably mirror KenLM's
/// `Config::ARPALoadComplain` enumerators. Confirm before reordering.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ArpaLoadComplain {
    /// Code 0.
    All = 0,
    /// Code 1.
    Expensive = 1,
    /// Code 2.
    None = 2,
}
147
148#[derive(Clone, Copy, Debug)]
150pub struct Config {
151 pub load_method: LoadMethod,
152 pub arpa_complain: ArpaLoadComplain,
153 pub probing_multiplier: f32,
154 pub unknown_missing_logprob: f32,
155 pub show_progress: bool,
156}
157
158impl Config {
159 fn as_raw(self) -> RawConfig {
160 RawConfig {
161 load_method: self.load_method as c_int,
162 arpa_complain: self.arpa_complain as c_int,
163 probing_multiplier: self.probing_multiplier,
164 unknown_missing_logprob: self.unknown_missing_logprob,
165 show_progress: u8::from(self.show_progress),
166 }
167 }
168}
169
170impl Default for Config {
171 fn default() -> Self {
172 let mut raw = RawConfig {
173 load_method: LoadMethod::Lazy as c_int,
174 arpa_complain: ArpaLoadComplain::All as c_int,
175 probing_multiplier: 1.5,
176 unknown_missing_logprob: -100.0,
177 show_progress: 1,
178 };
179 unsafe {
182 kenlm_config_default(&mut raw);
183 }
184 Self {
185 load_method: match raw.load_method {
186 1 => LoadMethod::PopulateOrLazy,
187 2 => LoadMethod::PopulateOrRead,
188 3 => LoadMethod::Read,
189 4 => LoadMethod::ParallelRead,
190 _ => LoadMethod::Lazy,
191 },
192 arpa_complain: match raw.arpa_complain {
193 1 => ArpaLoadComplain::Expensive,
194 2 => ArpaLoadComplain::None,
195 _ => ArpaLoadComplain::All,
196 },
197 probing_multiplier: raw.probing_multiplier,
198 unknown_missing_logprob: raw.unknown_missing_logprob,
199 show_progress: raw.show_progress != 0,
200 }
201 }
202}
203
204pub struct Model {
206 raw: NonNull<RawModel>,
207 state_size: usize,
208 token: Arc<ModelToken>,
209}
210
/// Identity marker for one [`Model`] instance.
///
/// The model keeps an `Arc` of it and clones that `Arc` into every [`State`] it creates.
/// `Arc::ptr_eq` on the two then tells whether a state belongs to the model.
#[derive(Debug)]
struct ModelToken;
213
// SAFETY: after construction, `Model` only exposes `&self` methods, and all per-call
// scoring context lives in caller-owned `State` buffers. The handle is freed only in
// `Drop`, which needs exclusive ownership. NOTE(review): this relies on the KenLM C shim
// allowing its model handle to be used from any thread, and on `kenlm_last_error`
// reporting the error of the calling thread. Confirm both against the shim.
unsafe impl Send for Model {}
// SAFETY: shared `&Model` access only passes the handle as `*const RawModel` to query
// functions. NOTE(review): assumes those queries are safe to run concurrently on one
// model. Confirm against the shim.
unsafe impl Sync for Model {}
218
219impl Model {
220 pub fn new(path: impl AsRef<Path>) -> Result<Self> {
222 Self::with_config(path, Config::default())
223 }
224
225 pub fn with_config(path: impl AsRef<Path>, config: Config) -> Result<Self> {
227 let path = path.as_ref().as_os_str().to_string_lossy();
228 let path = CString::new(path.as_bytes())?;
229 let raw_config = config.as_raw();
230 let raw = unsafe { kenlm_model_load(path.as_ptr(), &raw_config) };
233 let raw = NonNull::new(raw).ok_or_else(last_error)?;
234 let state_size = unsafe { kenlm_model_state_size(raw.as_ptr()) };
237 Ok(Self {
238 raw,
239 state_size,
240 token: Arc::new(ModelToken),
241 })
242 }
243
244 pub fn order(&self) -> u8 {
246 unsafe { kenlm_model_order(self.raw.as_ptr()) }
248 }
249
250 pub fn contains(&self, word: &str) -> Result<bool> {
252 Ok(self.index(word)? != self.not_found_index())
253 }
254
255 pub fn index(&self, word: &str) -> Result<WordIndex> {
257 let word = CString::new(word)?;
258 let mut index = 0;
261 let status = unsafe { kenlm_model_try_index(self.raw.as_ptr(), word.as_ptr(), &mut index) };
262 if status == 0 {
263 Ok(index as WordIndex)
264 } else {
265 Err(last_error())
266 }
267 }
268
269 pub fn begin_sentence_index(&self) -> WordIndex {
271 unsafe { kenlm_model_begin_sentence_index(self.raw.as_ptr()) as WordIndex }
273 }
274
275 pub fn end_sentence_index(&self) -> WordIndex {
277 unsafe { kenlm_model_end_sentence_index(self.raw.as_ptr()) as WordIndex }
279 }
280
281 pub fn not_found_index(&self) -> WordIndex {
283 unsafe { kenlm_model_not_found_index(self.raw.as_ptr()) as WordIndex }
285 }
286
287 pub fn score(&self, sentence: &str, bos: bool, eos: bool) -> Result<f32> {
292 self.score_words(sentence.split_whitespace(), bos, eos)
293 }
294
295 pub fn score_words<'a>(
297 &self,
298 words: impl IntoIterator<Item = &'a str>,
299 bos: bool,
300 eos: bool,
301 ) -> Result<f32> {
302 let mut state = self.initial_state(bos);
303 let mut next = self.empty_state();
304 let mut total = 0.0;
305
306 for word in words {
307 let index = self.index(word)?;
308 total += self.base_score(&state, index, &mut next)?;
309 std::mem::swap(&mut state, &mut next);
310 }
311
312 if eos {
313 total += self.base_score(&state, self.end_sentence_index(), &mut next)?;
314 }
315
316 Ok(total)
317 }
318
319 pub fn perplexity(&self, sentence: &str) -> Result<f32> {
321 let words = sentence.split_whitespace().count() + 1;
322 Ok(10.0_f32.powf(-self.score(sentence, true, true)? / words as f32))
323 }
324
325 pub fn full_scores(&self, sentence: &str, bos: bool, eos: bool) -> Result<Vec<TokenScore>> {
327 self.full_scores_words(sentence.split_whitespace(), bos, eos)
328 }
329
330 pub fn full_scores_words<'a>(
332 &self,
333 words: impl IntoIterator<Item = &'a str>,
334 bos: bool,
335 eos: bool,
336 ) -> Result<Vec<TokenScore>> {
337 let mut state = self.initial_state(bos);
338 let mut next = self.empty_state();
339 let mut scores = Vec::new();
340
341 for word in words {
342 let index = self.index(word)?;
343 let full_score = self.base_full_score(&state, index, &mut next)?;
344 scores.push(TokenScore {
345 log_prob: full_score.log_prob,
346 ngram_length: full_score.ngram_length,
347 oov: index == self.not_found_index(),
348 });
349 std::mem::swap(&mut state, &mut next);
350 }
351
352 if eos {
353 let full_score = self.base_full_score(&state, self.end_sentence_index(), &mut next)?;
354 scores.push(TokenScore {
355 log_prob: full_score.log_prob,
356 ngram_length: full_score.ngram_length,
357 oov: false,
358 });
359 }
360
361 Ok(scores)
362 }
363
364 pub fn begin_sentence_state(&self) -> State {
366 let mut state = self.empty_state();
367 unsafe {
370 kenlm_model_begin_sentence_write(self.raw.as_ptr(), state.as_mut_ptr());
371 }
372 state
373 }
374
375 pub fn null_context_state(&self) -> State {
377 let mut state = self.empty_state();
378 unsafe {
381 kenlm_model_null_context_write(self.raw.as_ptr(), state.as_mut_ptr());
382 }
383 state
384 }
385
386 pub fn base_score(
388 &self,
389 in_state: &State,
390 word_index: WordIndex,
391 out_state: &mut State,
392 ) -> Result<f32> {
393 self.validate_state(in_state)?;
394 self.validate_state(out_state)?;
395 debug_assert!(!std::ptr::eq(in_state.as_ptr(), out_state.as_ptr()));
398 let mut score = 0.0;
399 let status = unsafe {
403 kenlm_model_try_base_score(
404 self.raw.as_ptr(),
405 in_state.as_ptr(),
406 word_index as c_uint,
407 out_state.as_mut_ptr(),
408 &mut score,
409 )
410 };
411 if status == 0 {
412 Ok(score)
413 } else {
414 Err(last_error())
415 }
416 }
417
418 pub fn base_full_score(
420 &self,
421 in_state: &State,
422 word_index: WordIndex,
423 out_state: &mut State,
424 ) -> Result<FullScore> {
425 self.validate_state(in_state)?;
426 self.validate_state(out_state)?;
427 debug_assert!(!std::ptr::eq(in_state.as_ptr(), out_state.as_ptr()));
430 let mut raw = RawFullScore {
431 prob: 0.0,
432 ngram_length: 0,
433 independent_left: 0,
434 extend_left: 0,
435 rest: 0.0,
436 };
437 let status = unsafe {
441 kenlm_model_try_base_full_score(
442 self.raw.as_ptr(),
443 in_state.as_ptr(),
444 word_index as c_uint,
445 out_state.as_mut_ptr(),
446 &mut raw,
447 )
448 };
449 if status != 0 {
450 return Err(last_error());
451 }
452 Ok(FullScore {
453 log_prob: raw.prob,
454 ngram_length: raw.ngram_length,
455 independent_left: raw.independent_left != 0,
456 extend_left: raw.extend_left,
457 rest: raw.rest,
458 })
459 }
460
461 fn initial_state(&self, bos: bool) -> State {
462 if bos {
463 self.begin_sentence_state()
464 } else {
465 self.null_context_state()
466 }
467 }
468
469 fn empty_state(&self) -> State {
470 State {
471 bytes: vec![0; self.state_size],
472 owner: Arc::clone(&self.token),
473 }
474 }
475
476 fn validate_state(&self, state: &State) -> Result<()> {
477 if state.bytes.len() != self.state_size || !Arc::ptr_eq(&state.owner, &self.token) {
478 return Err(KenlmError::StateModelMismatch);
479 }
480 Ok(())
481 }
482}
483
484impl Drop for Model {
485 fn drop(&mut self) {
486 unsafe {
489 kenlm_model_free(self.raw.as_ptr());
490 }
491 }
492}
493
494#[derive(Clone, Debug)]
496pub struct State {
497 bytes: Vec<u8>,
498 owner: Arc<ModelToken>,
499}
500
501impl State {
502 fn as_ptr(&self) -> *const c_void {
503 self.bytes.as_ptr().cast()
504 }
505
506 fn as_mut_ptr(&mut self) -> *mut c_void {
507 self.bytes.as_mut_ptr().cast()
508 }
509}
510
511impl PartialEq for State {
512 fn eq(&self, other: &Self) -> bool {
513 Arc::ptr_eq(&self.owner, &other.owner) && self.bytes == other.bytes
514 }
515}
516
// `State`'s `PartialEq` compares owner identity and raw bytes, both of which are
// reflexive, so the relation is a full equivalence.
impl Eq for State {}
518
/// Complete scoring record returned by [`Model::base_full_score`].
///
/// Fields are copied from KenLM's result record.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct FullScore {
    /// Log probability KenLM assigned to the word.
    pub log_prob: f32,
    /// KenLM's `ngram_length` for this lookup (presumably the matched n-gram length).
    pub ngram_length: u8,
    /// KenLM's `independent_left` flag.
    pub independent_left: bool,
    /// KenLM's `extend_left` value, passed through unchanged.
    pub extend_left: u64,
    /// KenLM's `rest` value, passed through unchanged.
    pub rest: f32,
}
528
/// Score of one token as reported by [`Model::full_scores`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct TokenScore {
    /// Log probability KenLM assigned to the token.
    pub log_prob: f32,
    /// KenLM's `ngram_length` for this token.
    pub ngram_length: u8,
    /// `true` when the word's index was [`Model::not_found_index`]; always `false`
    /// for the end-of-sentence entry.
    pub oov: bool,
}
536
537fn last_error() -> KenlmError {
538 let message = unsafe {
541 let ptr = kenlm_last_error();
542 if ptr.is_null() {
543 String::new()
544 } else {
545 CStr::from_ptr(ptr).to_string_lossy().into_owned()
546 }
547 };
548 if message.is_empty() {
549 KenlmError::Load("unknown KenLM error".to_string())
550 } else {
551 KenlmError::Load(message)
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Small ARPA fixture loaded by every test.
    const TEST_MODEL: &str = "lm/test.arpa";

    /// Loads the fixture with progress output disabled so test logs stay clean.
    fn load_test_model() -> Model {
        let config = Config {
            show_progress: false,
            ..Config::default()
        };
        Model::with_config(TEST_MODEL, config).unwrap()
    }

    #[test]
    fn loads_and_scores_test_model() {
        let model = load_test_model();

        assert!(model.order() > 0);
        assert!(model.contains("looking").unwrap());
        assert!(!model.contains("definitely-not-in-this-model").unwrap());

        let score = model.score("looking on a little", true, true).unwrap();
        assert!(score.is_finite());

        let full_scores = model
            .full_scores("looking on a little", true, true)
            .unwrap();
        // Four words plus the end-of-sentence token.
        assert_eq!(full_scores.len(), 5);
        assert!(full_scores.iter().all(|score| score.log_prob.is_finite()));
    }

    #[test]
    fn supports_stateful_scoring() {
        let model = load_test_model();

        let mut state = model.begin_sentence_state();
        let mut out = model.null_context_state();
        let looking = model.index("looking").unwrap();

        let score = model.base_score(&state, looking, &mut out).unwrap();
        assert!(score.is_finite());

        std::mem::swap(&mut state, &mut out);
        let full = model
            .base_full_score(&state, model.end_sentence_index(), &mut out)
            .unwrap();
        assert!(full.log_prob.is_finite());
    }

    #[test]
    fn rejects_states_from_other_models() {
        let first = load_test_model();
        let second = load_test_model();
        let word = second.index("looking").unwrap();

        // Foreign input state.
        let state = first.begin_sentence_state();
        let mut out = second.null_context_state();
        let error = second.base_score(&state, word, &mut out).unwrap_err();
        assert!(matches!(error, KenlmError::StateModelMismatch));

        // Foreign output state.
        let own_state = second.begin_sentence_state();
        let mut foreign_out = first.null_context_state();
        let error = second
            .base_full_score(&own_state, word, &mut foreign_out)
            .unwrap_err();
        assert!(matches!(error, KenlmError::StateModelMismatch));
    }

    #[test]
    fn rejects_interior_nul_in_words() {
        let model = load_test_model();
        assert!(matches!(
            model.index("look\0ing"),
            Err(KenlmError::InteriorNul(_))
        ));
        assert!(matches!(
            model.score("look\0ing", true, true),
            Err(KenlmError::InteriorNul(_))
        ));
    }

    #[test]
    fn sentence_and_word_apis_agree() {
        let model = load_test_model();
        let sentence = "looking on a little";

        let score = model.score(sentence, true, true).unwrap();
        let from_words = model.score_words(sentence.split(' '), true, true).unwrap();
        assert_eq!(score, from_words);

        let summed: f32 = model
            .full_scores(sentence, true, true)
            .unwrap()
            .iter()
            .map(|token| token.log_prob)
            .sum();
        assert!((score - summed).abs() < 1e-4);

        // Log probabilities are <= 0, so perplexity is at least 1.
        assert!(model.perplexity(sentence).unwrap() >= 1.0);
    }
}