#[cfg(feature = "pyo3")]
mod py;
#[cfg(feature = "pyo3")]
pub use py::PyLongestStringMatching;
use crate::trie::Trie;
use flatbuffers;
use std::io::Write;
#[allow(dead_code, unused_imports, clippy::all)]
mod generated {
include!(concat!(env!("OUT_DIR"), "/lsm/model_generated.rs"));
}
use crate::persistence::ModelError;
pub trait BaseLongestStringMatching: Sized + Clone + Sync {
fn max_word_length(&self) -> usize;
fn trie(&self) -> &Trie<char, ()>;
fn trie_mut(&mut self) -> &mut Trie<char, ()>;
fn from_parts(max_word_length: usize, trie: Trie<char, ()>) -> Self;
fn fit(&mut self, sents: Vec<Vec<String>>) {
*self.trie_mut() = Trie::new();
for sent in sents {
for word in sent {
if word.chars().count() > 1 {
self.trie_mut().insert_seq(word.chars());
}
}
}
}
fn predict(&self, sent_strs: Vec<String>) -> Vec<Vec<String>> {
#[cfg(feature = "parallel")]
{
use rayon::prelude::*;
sent_strs
.into_par_iter()
.with_min_len(16)
.map(|sent_str| self.predict_sent(&sent_str))
.collect()
}
#[cfg(not(feature = "parallel"))]
{
sent_strs
.into_iter()
.map(|sent_str| self.predict_sent(&sent_str))
.collect()
}
}
fn save_to_path(&self, path: &str) -> Result<(), ModelError> {
let mut buf = Vec::new();
save_lsm_flatbuffers(self, &mut buf)?;
crate::persistence::save_zstd(path, &buf)
}
fn load_from_path(&mut self, path: &str) -> Result<(), ModelError> {
let bytes = crate::persistence::load_zstd(path, "longest string matching model")?;
load_lsm_flatbuffers(self, &bytes)
}
fn predict_sent(&self, sent_str: &str) -> Vec<String> {
let chars: Vec<char> = sent_str.chars().collect();
if chars.is_empty() {
return Vec::new();
}
let estimated_words = (chars.len() / 3).max(1);
let mut sent_predicted = Vec::with_capacity(estimated_words);
let mut i = 0;
while i < chars.len() {
let remaining = &chars[i..];
let max_len = std::cmp::min(remaining.len(), self.max_word_length());
let match_len = self.trie().longest_match(remaining, max_len);
if match_len > 0 {
let word: String = chars[i..i + match_len].iter().collect();
sent_predicted.push(word);
i += match_len;
} else {
sent_predicted.push(chars[i].to_string());
i += 1;
}
}
sent_predicted
}
fn predict_with_offsets(&self, sent_strs: Vec<String>) -> Vec<Vec<(String, (usize, usize))>> {
let words = self.predict(sent_strs);
crate::wordseg::attach_offsets(words)
}
}
fn save_lsm_flatbuffers<T: BaseLongestStringMatching, W: Write>(
model: &T,
writer: &mut W,
) -> Result<(), ModelError> {
use generated::rustling::lsm_fbs as fbs;
let mut builder = flatbuffers::FlatBufferBuilder::with_capacity(512 * 1024);
let mut words: Vec<String> = model
.trie()
.all_sequences()
.into_iter()
.map(|chars| chars.into_iter().collect())
.collect();
words.sort();
let word_strs: Vec<_> = words.iter().map(|w| builder.create_string(w)).collect();
let words_fb = builder.create_vector(&word_strs);
let lsm = fbs::LsmModel::create(
&mut builder,
&fbs::LsmModelArgs {
max_word_length: model.max_word_length() as u32,
words: Some(words_fb),
},
);
builder.finish(lsm, None);
writer
.write_all(builder.finished_data())
.map_err(|e| ModelError::Io(format!("Failed to write FlatBuffers data: {e}")))
}
fn load_lsm_flatbuffers<T: BaseLongestStringMatching>(
model: &mut T,
bytes: &[u8],
) -> Result<(), ModelError> {
use generated::rustling::lsm_fbs as fbs;
let opts = crate::persistence::flatbuffers_verifier_opts();
let lsm = flatbuffers::root_with_opts::<fbs::LsmModel>(&opts, bytes)
.map_err(|e| ModelError::ParseError(format!("Invalid FlatBuffers LSM data: {e}")))?;
let mut trie = Trie::new();
for word in lsm.words().iter() {
trie.insert_seq(word.chars());
}
*model.trie_mut() = trie;
Ok(())
}
#[derive(Clone, Debug)]
pub struct LongestStringMatching {
max_word_length: usize,
trie: Trie<char, ()>,
}
impl BaseLongestStringMatching for LongestStringMatching {
fn max_word_length(&self) -> usize {
self.max_word_length
}
fn trie(&self) -> &Trie<char, ()> {
&self.trie
}
fn trie_mut(&mut self) -> &mut Trie<char, ()> {
&mut self.trie
}
fn from_parts(max_word_length: usize, trie: Trie<char, ()>) -> Self {
Self {
max_word_length,
trie,
}
}
}
impl LongestStringMatching {
pub fn new(max_word_length: usize) -> Result<Self, ModelError> {
if max_word_length < 2 {
return Err(ModelError::ValidationError(format!(
"max_word_length must be >= 2 to be meaningful: {}",
max_word_length
)));
}
Ok(Self {
max_word_length,
trie: Trie::new(),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_valid() {
let model = LongestStringMatching::new(4);
assert!(model.is_ok());
let model = model.unwrap();
assert_eq!(model.max_word_length, 4);
}
#[test]
fn test_new_invalid_max_word_length() {
let result = LongestStringMatching::new(1);
assert!(result.is_err());
}
#[test]
fn test_basic() {
let mut model = LongestStringMatching::new(4).unwrap();
model.fit(vec![
vec![
"this".to_string(),
"is".to_string(),
"a".to_string(),
"sentence".to_string(),
],
vec![
"that".to_string(),
"is".to_string(),
"not".to_string(),
"a".to_string(),
"sentence".to_string(),
],
]);
let result = model.predict(vec!["thatisadog".to_string(), "thisisnotacat".to_string()]);
assert_eq!(
result,
vec![
vec!["that", "is", "a", "d", "o", "g"],
vec!["this", "is", "not", "a", "c", "a", "t"],
]
);
}
#[test]
fn test_empty_input() {
let mut model = LongestStringMatching::new(4).unwrap();
model.fit(vec![vec!["hello".to_string(), "world".to_string()]]);
let result = model.predict(vec!["".to_string()]);
assert_eq!(result, vec![Vec::<String>::new()]);
}
#[test]
fn test_no_training_data() {
let mut model = LongestStringMatching::new(4).unwrap();
model.fit(vec![]);
let result = model.predict(vec!["hello".to_string()]);
assert_eq!(result, vec![vec!["h", "e", "l", "l", "o"]]);
}
#[test]
fn test_single_char_words_ignored_in_training() {
let mut model = LongestStringMatching::new(4).unwrap();
model.fit(vec![vec![
"a".to_string(),
"b".to_string(),
"ab".to_string(),
]]);
let result = model.predict(vec!["abab".to_string()]);
assert_eq!(result, vec![vec!["ab", "ab"]]);
}
#[test]
fn test_unicode_chars() {
let mut model = LongestStringMatching::new(4).unwrap();
model.fit(vec![vec!["你好".to_string(), "世界".to_string()]]);
let result = model.predict(vec!["你好世界".to_string()]);
assert_eq!(result, vec![vec!["你好", "世界"]]);
}
#[test]
fn test_max_word_length_constraint() {
let mut model = LongestStringMatching::new(3).unwrap();
model.fit(vec![vec!["hello".to_string()]]);
let result = model.predict(vec!["hello".to_string()]);
assert_eq!(result, vec![vec!["h", "e", "l", "l", "o"]]);
}
#[test]
fn test_save_and_load() {
let mut model = LongestStringMatching::new(4).unwrap();
model.fit(vec![
vec![
"this".to_string(),
"is".to_string(),
"a".to_string(),
"sentence".to_string(),
],
vec![
"that".to_string(),
"is".to_string(),
"not".to_string(),
"a".to_string(),
"sentence".to_string(),
],
]);
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("model.bin");
let path_str = path.to_str().unwrap();
model.save_to_path(path_str).unwrap();
let mut loaded = LongestStringMatching::new(4).unwrap();
loaded.load_from_path(path_str).unwrap();
let input = vec!["thatisadog".to_string(), "thisisnotacat".to_string()];
assert_eq!(model.predict(input.clone()), loaded.predict(input));
}
}