use std::num::NonZero;
use fst::raw::{Fst, Output};
/// Splits text into segments by matching the keys of an FST against the input.
pub(crate) struct FstSegmenter<'fst> {
    /// FST whose keys are matched, byte by byte, as prefixes of the remaining input.
    words_fst: &'fst Fst<&'fst [u8]>,
    /// Policy for grouping characters at which no FST key starts.
    buffering_strategy: BufferingStrategy,
}
impl<'fst> FstSegmenter<'fst> {
    /// Creates a segmenter that matches against `words_fst` and groups
    /// unmatched characters according to `buffering_strategy`.
    pub(crate) fn new(
        words_fst: &'fst Fst<&'fst [u8]>,
        buffering_strategy: BufferingStrategy,
    ) -> Self {
        Self { words_fst, buffering_strategy }
    }

    /// Splits `to_segment` into consecutive, non-overlapping slices.
    ///
    /// At each position, the longest FST key that is a prefix of the remaining
    /// input is emitted as one segment. Characters at which no key starts are
    /// accumulated and emitted as one run. The run is emitted in three cases:
    /// just before the next match, when the strategy's `max_char_count` is
    /// reached, or at the end of the input.
    ///
    /// The returned iterator copies what it needs out of `self`. It therefore
    /// borrows only the FST (for `'fst`) and the input (for `'o`), not the
    /// segmenter itself.
    pub fn segment_str<'o>(
        &self,
        to_segment: &'o str,
    ) -> Box<dyn Iterator<Item = &'o str> + 'o>
    where
        'fst: 'o,
    {
        // Copy the fields out so the closure does not capture `&self`; this keeps
        // the iterator independent of how long the segmenter itself lives.
        let words_fst = self.words_fst;
        let buffering_strategy = self.buffering_strategy;
        let mut cursor = SegmentationCursor::new(to_segment);

        let iter = std::iter::from_fn(move || loop {
            // Input exhausted: flush any unmatched run that is still pending.
            let Some(remaining) = cursor.tail() else {
                return cursor.take_buffered_segment();
            };

            if let Some((_, length)) = find_longest_prefix(words_fst, remaining.as_bytes()) {
                // A pending run is emitted first; in that case the match is not
                // consumed and will be found again on the next call.
                return cursor.compute_next_segment(length);
            }

            match buffering_strategy {
                BufferingStrategy::UntilNextMatch { max_char_count } => {
                    if cursor.buffer_next_character(max_char_count).is_full() {
                        return cursor.take_buffered_segment();
                    }
                }
            }
        });

        Box::new(iter)
    }
}
/// Returns the smallest UTF-8 char boundary of `s` that is `>= length`.
///
/// If `length` exceeds `s.len()`, the result is clamped to `s.len()`.
///
/// Note: despite its name, this rounds *up* to the next boundary and never
/// down. So for a non-empty `s`, a non-zero `length` never yields `0`.
fn floor_char_boundary(s: &str, length: usize) -> usize {
    let end = s.len();
    // `end` is always a char boundary, so the search only falls through when
    // `length > end` (an empty range).
    (length..=end)
        .find(|&candidate| s.is_char_boundary(candidate))
        .unwrap_or(end)
}
/// Reports whether `s` holds at least `max_char_count` characters.
///
/// A `None` limit means "unbounded", so the result is always `false`.
fn is_max_char_count_reached(s: &str, max_char_count: Option<NonZero<usize>>) -> bool {
    // `limit` is non-zero, so `limit - 1` cannot underflow. A `limit`-th char
    // exists exactly when `count() >= limit`, and checking it this way stops
    // walking after `limit` characters.
    max_char_count.is_some_and(|limit| s.chars().nth(limit.get() - 1).is_some())
}
/// Finds the longest non-empty key of `fst` that is a byte prefix of `value`.
///
/// Returns `(output, len)`:
/// - `output` is the key's value, i.e. the concatenation of the transition
///   outputs along the path plus the final node's output.
/// - `len` is the key length in bytes.
///
/// Returns `None` when no key of length >= 1 prefixes `value`. An empty key
/// (a final root node) is never reported.
#[inline]
fn find_longest_prefix(fst: &Fst<&[u8]>, value: &[u8]) -> Option<(u64, usize)> {
    let mut current = fst.root();
    let mut accumulated = Output::zero();
    let mut best = None;

    for (index, byte) in value.iter().copied().enumerate() {
        // No outgoing edge for this byte: no longer key can match.
        let Some(position) = current.find_input(byte) else {
            break;
        };
        let transition = current.transition(position);
        accumulated = accumulated.cat(transition.out);
        current = fst.node(transition.addr);

        if current.is_final() {
            let key_value = accumulated.cat(current.final_output()).value();
            best = Some((key_value, index + 1));
        }
    }

    best
}
/// Policy for grouping characters at which no FST key starts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferingStrategy {
    /// Accumulate unmatched characters until the next FST match or the end of
    /// the input.
    UntilNextMatch {
        /// Emit the accumulated run as soon as it holds this many characters.
        /// `None` means the run is never cut short.
        max_char_count: Option<NonZero<usize>>,
    },
}
/// Result of adding one more unmatched character to the pending run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BufferState {
    /// The run reached its configured maximum length and should be emitted.
    Full,
    /// The run may keep growing.
    Buffering,
}

impl BufferState {
    /// Returns `true` when the pending run reached its limit and should be flushed.
    fn is_full(&self) -> bool {
        matches!(self, BufferState::Full)
    }
}
/// Walking state over the input string. Offsets are byte offsets into `to_segment`.
struct SegmentationCursor<'o> {
    /// The complete input being segmented.
    to_segment: &'o str,
    /// Start offset of the pending run of unmatched characters, if one is open.
    buffer_head_offset: Option<usize>,
    /// Offset of the first byte that has not been consumed yet.
    offset: usize,
}
impl<'o> SegmentationCursor<'o> {
    /// Creates a cursor at the start of `to_segment` with no pending run.
    fn new(to_segment: &'o str) -> Self {
        Self { to_segment, buffer_head_offset: None, offset: 0 }
    }

    /// Returns the unconsumed remainder of the input, or `None` once it is empty.
    fn tail(&self) -> Option<&'o str> {
        self.to_segment.get(self.offset..).filter(|s| !s.is_empty())
    }

    /// Closes the pending run of unmatched characters and returns it.
    ///
    /// Returns `None` when no run is open or the run is empty.
    fn take_buffered_segment(&mut self) -> Option<&'o str> {
        self.buffer_head_offset
            .take()
            .and_then(|head| self.to_segment.get(head..self.offset))
            .filter(|s| !s.is_empty())
    }

    /// Consumes the next character into the pending run.
    ///
    /// If no run is open, a new one is started at the current offset.
    ///
    /// Returns [`BufferState::Full`] once the run holds at least
    /// `max_char_count` characters, and [`BufferState::Buffering`] otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the input is already fully consumed; check [`Self::tail`] first.
    fn buffer_next_character(&mut self, max_char_count: Option<NonZero<usize>>) -> BufferState {
        let head = *self.buffer_head_offset.get_or_insert(self.offset);
        self.offset += self.next_character_length();
        let segment = &self.to_segment[head..self.offset];
        if is_max_char_count_reached(segment, max_char_count) {
            BufferState::Full
        } else {
            BufferState::Buffering
        }
    }

    /// Byte length (UTF-8) of the character starting at the current offset.
    ///
    /// # Panics
    ///
    /// Panics if the input is already fully consumed.
    fn next_character_length(&self) -> usize {
        self.to_segment[self.offset..]
            .chars()
            .next()
            .expect("next_character_length called with no input left; check `tail()` first")
            .len_utf8()
    }

    /// Produces the next segment for a match of `next_segment_length` bytes.
    ///
    /// If a run of unmatched characters is pending, that run is returned
    /// instead and the offset is left where it is, so the match is not
    /// consumed.
    ///
    /// Otherwise the cursor consumes `next_segment_length` bytes and returns
    /// them. The length is first rounded up to a char boundary by
    /// `floor_char_boundary`.
    ///
    /// Returns `None` if the input is exhausted and nothing is pending.
    fn compute_next_segment(&mut self, next_segment_length: usize) -> Option<&'o str> {
        if let Some(buffered_segment) = self.take_buffered_segment() {
            return Some(buffered_segment);
        }
        let tail = self.tail()?;
        let length = floor_char_boundary(tail, next_segment_length);
        self.offset += length;
        Some(&tail[..length])
    }
}