use std::ops::{Index, Range};
use std::str;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use smartstring::alias::String;
#[cfg(feature = "test-cases")]
pub mod test_cases;
#[cfg(feature = "__test_data")]
pub mod test_data;
#[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
pub struct Segmenter {
scores: HashMap<String, (f64, HashMap<String, f64>)>,
uni_total_log10: f64,
limit: usize,
}
impl Segmenter {
pub fn new<U, B>(unigrams: U, bigrams: B) -> Self
where
U: IntoIterator<Item = (String, f64)>,
B: IntoIterator<Item = ((String, String), f64)>,
{
let mut scores = HashMap::default();
let mut uni_total = 0.0;
for (word, uni) in unigrams {
scores.insert(word, (uni, HashMap::default()));
uni_total += uni;
}
let mut bi_total = 0.0;
for ((word1, word2), bi) in bigrams {
let Some((_, bi_scores)) = scores.get_mut(&word2) else {
continue;
};
bi_scores.insert(word1, bi);
bi_total += bi;
}
for (uni, bi_scores) in scores.values_mut() {
*uni = (*uni / uni_total).log10();
for bi in bi_scores.values_mut() {
*bi = (*bi / bi_total).log10();
}
}
Self {
uni_total_log10: uni_total.log10(),
scores,
limit: DEFAULT_LIMIT,
}
}
pub fn segment<'a>(
&self,
input: &str,
search: &'a mut Search,
) -> Result<impl Iterator<Item = &'a str> + ExactSizeIterator, InvalidCharacter> {
let state = SegmentState::new(Ascii::new(input)?, self, search);
if !input.is_empty() {
state.run();
}
Ok(search.result.iter().map(|v| v.as_str()))
}
pub fn score_sentence<'a>(&self, mut words: impl Iterator<Item = &'a str>) -> Option<f64> {
let mut prev = words.next()?;
let mut score = self.score(prev, None);
for word in words {
score += self.score(word, Some(prev));
prev = word;
}
Some(score)
}
fn score(&self, word: &str, previous: Option<&str>) -> f64 {
let (uni, bi_scores) = match self.scores.get(word) {
Some((uni, bi_scores)) => (uni, bi_scores),
None => return 1.0 - self.uni_total_log10 - word.len() as f64,
};
if let Some(prev) = previous {
if let Some(bi) = bi_scores.get(prev) {
if let Some((uni_prev, _)) = self.scores.get(prev) {
return bi - uni_prev;
}
}
}
*uni
}
pub fn set_limit(&mut self, limit: usize) {
self.limit = limit;
}
}
struct SegmentState<'a> {
data: &'a Segmenter,
text: Ascii<'a>,
search: &'a mut Search,
}
impl<'a> SegmentState<'a> {
fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
search.clear();
Self { data, text, search }
}
fn run(self) {
for end in 1..=self.text.len() {
let start = end.saturating_sub(self.data.limit);
for split in start..end {
let (prev, prev_score) = match split {
0 => (None, 0.0),
_ => {
let prefix = self.search.candidates[split - 1];
let word = &self.text[split - prefix.len..split];
(Some(word), prefix.score)
}
};
let word = &self.text[split..end];
let score = self.data.score(word, prev) + prev_score;
match self.search.candidates.get_mut(end - 1) {
Some(cur) if cur.score < score => {
cur.len = end - split;
cur.score = score;
}
None => self.search.candidates.push(Candidate {
len: end - split,
score,
}),
_ => {}
}
}
}
let mut end = self.text.len();
let mut best = self.search.candidates[end - 1];
loop {
let word = &self.text[end - best.len..end];
self.search.result.push(word.into());
end -= best.len;
if end == 0 {
break;
}
best = self.search.candidates[end - 1];
}
self.search.result.reverse();
}
}
#[derive(Clone, Default)]
pub struct Search {
candidates: Vec<Candidate>,
result: Vec<String>,
}
impl Search {
fn clear(&mut self) {
self.candidates.clear();
self.result.clear();
}
#[doc(hidden)]
pub fn get(&self, idx: usize) -> Option<&str> {
self.result.get(idx).map(|v| v.as_str())
}
}
#[derive(Clone, Copy, Debug, Default)]
struct Candidate {
len: usize,
score: f64,
}
#[derive(Debug)]
struct Ascii<'a>(&'a [u8]);
impl<'a> Ascii<'a> {
fn new(s: &'a str) -> Result<Self, InvalidCharacter> {
let bytes = s.as_bytes();
let valid = bytes
.iter()
.all(|b| b.is_ascii_lowercase() || b.is_ascii_digit());
match valid {
true => Ok(Self(bytes)),
false => Err(InvalidCharacter),
}
}
fn len(&self) -> usize {
self.0.len()
}
}
impl<'a> Index<Range<usize>> for Ascii<'a> {
type Output = str;
fn index(&self, index: Range<usize>) -> &Self::Output {
let bytes = self.0.index(index);
unsafe { str::from_utf8_unchecked(bytes) }
}
}
#[derive(Debug)]
pub struct InvalidCharacter;
impl std::error::Error for InvalidCharacter {}
impl std::fmt::Display for InvalidCharacter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("invalid character")
}
}
type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;
const DEFAULT_LIMIT: usize = 24;
#[cfg(test)]
pub mod tests {
use super::*;
#[test]
fn test_clean() {
Ascii::new("Can't buy me love!").unwrap_err();
let text = Ascii::new("cantbuymelove").unwrap();
assert_eq!(&text[0..text.len()], "cantbuymelove");
let text_with_numbers = Ascii::new("c4ntbuym3l0v3").unwrap();
assert_eq!(
&text_with_numbers[0..text_with_numbers.len()],
"c4ntbuym3l0v3"
);
}
}