use fst::Error;
use fst::raw::Fst;
use fst::raw::Output;
use std::vec::IntoIter;
use crate::SegmentedToken;
use crate::UseOrSubdivide;
use crate::segmentation::Segmenter;
/// A segmenter backed by a finite state transducer that holds a dictionary
/// of byte-string words.
///
/// `D` is the storage that backs the underlying [`Fst`]; anything that can be
/// viewed as a byte slice works.
#[derive(Clone)]
pub struct DecompositionFst<D>
where
    D: AsRef<[u8]>,
{
    /// The dictionary consulted when looking for known-word prefixes.
    dictionary: Fst<D>,
}
impl DecompositionFst<Vec<u8>> {
    /// Builds an in-memory decomposition dictionary from a sequence of words.
    ///
    /// Every item of `dict` is handed to [`Fst::from_iter_set`] as one key.
    ///
    /// NOTE(review): the `fst` crate is understood to require keys in
    /// lexicographic byte order; confirm against the crate version in use.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by [`Fst::from_iter_set`].
    pub fn from_dictionary<I, P>(dict: I) -> Result<Self, Error>
    where
        I: IntoIterator<Item = P>,
        P: AsRef<[u8]>,
    {
        Fst::from_iter_set(dict).map(Self::from_fst)
    }
}
impl<D> DecompositionFst<D>
where
D: AsRef<[u8]>,
{
pub fn from_fst(dictionary: Fst<D>) -> Self {
DecompositionFst { dictionary }
}
}
impl<D> Segmenter for DecompositionFst<D>
where
    D: AsRef<[u8]>,
{
    type SubdivisionIter<'a> = IntoIter<SegmentedToken<'a>>;

    /// Splits `token` into consecutive dictionary words, if that is possible.
    ///
    /// Tokens already flagged as known words are passed through untouched.
    /// Otherwise the token text is consumed greedily: at each position the
    /// longest dictionary entry that prefixes the remaining bytes is taken.
    /// There is no backtracking, so a token is only subdivided when this
    /// greedy walk consumes the text completely.
    ///
    /// Returns `UseOrSubdivide::Subdivide` with one sub-token per matched
    /// word (each marked as a known word) when the whole text was covered,
    /// and `UseOrSubdivide::Use(token)` with the original token otherwise.
    /// A token whose text is empty is trivially "covered" and therefore
    /// yields an empty subdivision.
    fn subdivide<'a>(
        &self,
        token: SegmentedToken<'a>,
    ) -> UseOrSubdivide<SegmentedToken<'a>, IntoIter<SegmentedToken<'a>>> {
        if token.is_known_word {
            return UseOrSubdivide::Use(token);
        }

        // Byte lengths of the successive dictionary matches.
        let mut cuts = Vec::new();
        // Byte offset of the first not-yet-matched byte.
        let mut pos: usize = 0;
        while let Some((_, length)) =
            find_longest_prefix(&self.dictionary, &token.text.as_bytes()[pos..])
        {
            cuts.push(length);
            pos += length;
            // Dictionary entries are arbitrary bytes, so a match may end in
            // the middle of a multi-byte character. The end of the string is
            // always a char boundary, so no separate end-of-text test is
            // needed here.
            if !token.text.is_char_boundary(pos) {
                eprintln!(
                    "Detected invalid utf8 in dictionary, not subdividing token: {:?}",
                    token.text
                );
                return UseOrSubdivide::Use(token);
            }
        }

        // Some trailing part of the token is not in the dictionary.
        if pos != token.text.len() {
            return UseOrSubdivide::Use(token);
        }

        let mut subsegments = Vec::<SegmentedToken>::with_capacity(cuts.len());
        let mut remaining = token.text;
        for length in cuts {
            // Every cut was verified to fall on a char boundary above, so
            // `split_at` cannot panic.
            let (word, tail) = remaining.split_at(length);
            remaining = tail;
            subsegments
                .push(SegmentedToken::new_derived_from(word, &token).with_is_kown_word(true));
        }
        UseOrSubdivide::Subdivide(subsegments.into_iter())
    }
}
/// Finds the longest key stored in `fst` that is a prefix of `value`.
///
/// Walks the transducer one input byte at a time and remembers the most
/// recent final state that was reached. The walk stops at the first byte for
/// which the current node has no transition, or when `value` is exhausted.
///
/// Returns `Some((output, len))`, where `output` is the accumulated output
/// value of the matched key and `len` is the number of bytes of `value` it
/// covers, or `None` if no key is a prefix of `value`. Finality is only
/// checked after a transition has been taken, so `len` is never zero.
#[inline]
fn find_longest_prefix<D>(fst: &Fst<D>, value: &[u8]) -> Option<(u64, usize)>
where
    D: AsRef<[u8]>,
{
    let mut current = fst.root();
    let mut accumulated = Output::zero();
    let mut longest = None;

    // `consumed` counts the bytes read so far, including the current one.
    for (&byte, consumed) in value.iter().zip(1..) {
        let Some(index) = current.find_input(byte) else {
            break;
        };
        let transition = current.transition(index);
        current = fst.node(transition.addr);
        accumulated = accumulated.cat(transition.out);
        if current.is_final() {
            let total = accumulated.cat(current.final_output());
            longest = Some((total.value(), consumed));
        }
    }

    longest
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::SubdivisionMap;
    use crate::initial_paragraph_splitter::InitialParagraphSplitter;

    /// Tokens made up entirely of dictionary words are split into those
    /// words; tokens with any unknown remainder ("fooquux") stay whole.
    #[test]
    fn test_decomposition_fst() {
        let expected = vec![
            "foo", "bar", "baz", " ", "fooquux", " ", "foo", " ", "baz", "baz",
        ];

        let words = vec!["bar", "baz", "foo"];
        let decomposer = DecompositionFst::from_dictionary(words).unwrap();
        let paragraphs = InitialParagraphSplitter::new("foobarbaz fooquux foo bazbaz");
        let segments = SubdivisionMap::new(paragraphs, |s| decomposer.subdivide(s));
        let actual: Vec<&str> = segments.map(|s| s.text).collect();

        assert_eq!(actual, expected);
    }
}