1use std::ops::{Index, Range};
2use std::str;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6use smartstring::alias::String;
7
8#[cfg(feature = "test-cases")]
9pub mod test_cases;
10#[cfg(feature = "__test_data")]
11pub mod test_data;
12
/// Word-frequency model used to score and segment text.
///
/// Built with [`Segmenter::new`] from unigram and bigram counts.
// NOTE(review): this derive is gated on the `with-serde` feature, while the
// `serde` import at the top of this file is gated on `serde` — confirm that
// the two features are always enabled together.
#[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
pub struct Segmenter {
    // Maps each known word to its unigram score and to a nested map that is
    // keyed by the preceding word and holds the bigram score. Once
    // `Segmenter::new` has finished, all scores are base-10 logarithms of
    // relative frequencies.
    scores: HashMap<String, (f64, HashMap<String, f64>)>,
    // Base-10 logarithm of the sum of all unigram counts.
    uni_total_log10: f64,
    // Maximum length of a single candidate word considered while segmenting.
    limit: usize,
}
24
25impl Segmenter {
26 pub fn new<U, B>(unigrams: U, bigrams: B) -> Self
31 where
32 U: IntoIterator<Item = (String, f64)>,
33 B: IntoIterator<Item = ((String, String), f64)>,
34 {
35 let mut scores = HashMap::default();
37 let mut uni_total = 0.0;
38 for (word, uni) in unigrams {
39 scores.insert(word, (uni, HashMap::default()));
40 uni_total += uni;
41 }
42 let mut bi_total = 0.0;
43 for ((word1, word2), bi) in bigrams {
44 let Some((_, bi_scores)) = scores.get_mut(&word2) else {
45 continue;
50 };
51 bi_scores.insert(word1, bi);
52 bi_total += bi;
53 }
54
55 for (uni, bi_scores) in scores.values_mut() {
58 *uni = (*uni / uni_total).log10();
59 for bi in bi_scores.values_mut() {
60 *bi = (*bi / bi_total).log10();
61 }
62 }
63
64 Self {
65 uni_total_log10: uni_total.log10(),
66 scores,
67 limit: DEFAULT_LIMIT,
68 }
69 }
70
71 pub fn segment<'a>(
77 &self,
78 input: &str,
79 search: &'a mut Search,
80 ) -> Result<Segments<'a>, InvalidCharacter> {
81 let state = SegmentState::new(Ascii::new(input)?, self, search);
82 let score = match input {
83 "" => 0.0,
84 _ => state.run(),
85 };
86
87 Ok(Segments {
88 iter: search.result.iter(),
89 score,
90 })
91 }
92
93 pub fn score_sentence<'a>(&self, mut words: impl Iterator<Item = &'a str>) -> Option<f64> {
98 let mut prev = words.next()?;
99 let mut score = self.score(prev, None);
100 for word in words {
101 score += self.score(word, Some(prev));
102 prev = word;
103 }
104 Some(score)
105 }
106
107 fn score(&self, word: &str, previous: Option<&str>) -> f64 {
108 let (uni, bi_scores) = match self.scores.get(word) {
109 Some((uni, bi_scores)) => (uni, bi_scores),
110 None => {
123 let word_len = word.len() as f64;
124 let word_count = word_len / 5.0;
125 return (1.0 - self.uni_total_log10 - word_len) * word_count;
126 }
127 };
128
129 if let Some(prev) = previous {
130 if let Some(bi) = bi_scores.get(prev) {
131 if let Some((uni_prev, _)) = self.scores.get(prev) {
132 return bi - uni_prev;
136 }
137 }
138 }
139
140 *uni
141 }
142
143 pub fn set_limit(&mut self, limit: usize) {
145 self.limit = limit;
146 }
147}
148
/// Iterator over the words of a segmented input, together with the score of
/// that segmentation. Returned by [`Segmenter::segment`].
pub struct Segments<'a> {
    // Words of the segmentation, borrowed from the `Search` that was passed
    // to `Segmenter::segment`.
    iter: std::slice::Iter<'a, String>,
    // Score of the segmentation (`0.0` for empty input).
    score: f64,
}
153
154impl Segments<'_> {
155 pub fn score(&self) -> f64 {
157 self.score
158 }
159}
160
161impl<'a> Iterator for Segments<'a> {
162 type Item = &'a str;
163
164 fn next(&mut self) -> Option<Self::Item> {
165 self.iter.next().map(|v| v.as_str())
166 }
167}
168
169impl ExactSizeIterator for Segments<'_> {
170 fn len(&self) -> usize {
171 self.iter.len()
172 }
173}
174
/// Working state for a single segmentation run.
struct SegmentState<'a> {
    // Scoring model: supplies word scores and the word length limit.
    data: &'a Segmenter,
    // Validated input text.
    text: Ascii<'a>,
    // Caller-provided buffers; filled with the per-position candidates and
    // the resulting words.
    search: &'a mut Search,
}
180
181impl<'a> SegmentState<'a> {
182 fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
183 search.clear();
184 Self { data, text, search }
185 }
186
187 fn run(self) -> f64 {
188 for end in 1..=self.text.len() {
189 let start = end.saturating_sub(self.data.limit);
190 for split in start..end {
191 let (prev, prev_score) = match split {
192 0 => (None, 0.0),
193 _ => {
194 let prefix = self.search.candidates[split - 1];
195 let word = &self.text[split - prefix.len..split];
196 (Some(word), prefix.score)
197 }
198 };
199
200 let word = &self.text[split..end];
201 let score = self.data.score(word, prev) + prev_score;
202 match self.search.candidates.get_mut(end - 1) {
203 Some(cur) if cur.score < score => {
204 cur.len = end - split;
205 cur.score = score;
206 }
207 None => self.search.candidates.push(Candidate {
208 len: end - split,
209 score,
210 }),
211 _ => {}
212 }
213 }
214 }
215
216 let mut end = self.text.len();
217 let mut best = self.search.candidates[end - 1];
218 let score = best.score;
219 loop {
220 let word = &self.text[end - best.len..end];
221 self.search.result.push(word.into());
222
223 end -= best.len;
224 if end == 0 {
225 break;
226 }
227
228 best = self.search.candidates[end - 1];
229 }
230
231 self.search.result.reverse();
232 score
233 }
234}
235
/// Reusable working buffers for [`Segmenter::segment`].
///
/// The buffers are cleared at the start of every segmentation, so one
/// `Search` can be passed to any number of `segment` calls.
#[derive(Clone, Default)]
pub struct Search {
    // Best candidate found for each input position; entry `i` describes the
    // prefix that ends after byte `i`.
    candidates: Vec<Candidate>,
    // Words of the most recent segmentation, in input order.
    result: Vec<String>,
}
242
243impl Search {
244 fn clear(&mut self) {
245 self.candidates.clear();
246 self.result.clear();
247 }
248
249 #[doc(hidden)]
250 pub fn get(&self, idx: usize) -> Option<&str> {
251 self.result.get(idx).map(|v| v.as_str())
252 }
253}
254
/// Best way found so far to segment the input up to a given position.
#[derive(Clone, Copy, Debug, Default)]
struct Candidate {
    // Length of the last word, i.e. the word that ends at this position.
    len: usize,
    // Score of the input up to this position when it ends with that word.
    score: f64,
}
260
/// Borrowed input bytes that have passed the check in [`Ascii::new`]
/// (lowercase ASCII letters and ASCII digits only).
#[derive(Debug)]
struct Ascii<'a>(&'a [u8]);
263
264impl<'a> Ascii<'a> {
265 fn new(s: &'a str) -> Result<Self, InvalidCharacter> {
266 let bytes = s.as_bytes();
267
268 let valid = bytes
269 .iter()
270 .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit());
271
272 match valid {
273 true => Ok(Self(bytes)),
274 false => Err(InvalidCharacter),
275 }
276 }
277
278 fn len(&self) -> usize {
279 self.0.len()
280 }
281}
282
283impl Index<Range<usize>> for Ascii<'_> {
284 type Output = str;
285
286 fn index(&self, index: Range<usize>) -> &Self::Output {
287 let bytes = self.0.index(index);
288 unsafe { str::from_utf8_unchecked(bytes) }
290 }
291}
292
/// Error returned when the input contains a character that is not accepted
/// for segmentation.
#[derive(Debug)]
pub struct InvalidCharacter;

impl std::fmt::Display for InvalidCharacter {
    /// Writes the fixed message for this error; formatting options such as
    /// width are not applied.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "invalid character")
    }
}

impl std::error::Error for InvalidCharacter {}
304
// Hash map type used for the score tables in this file: `rustc_hash`'s
// `FxHashMap`.
type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;

// Word length limit that a new `Segmenter` starts with; it can be changed
// through `Segmenter::set_limit`.
const DEFAULT_LIMIT: usize = 24;
308
#[cfg(test)]
pub mod tests {
    use super::*;

    /// `Ascii::new` rejects text with characters other than lowercase ASCII
    /// letters and digits, and accepted text can be read back unchanged.
    #[test]
    fn test_clean() {
        let rejected = Ascii::new("Can't buy me love!");
        assert!(rejected.is_err());

        for &accepted in &["cantbuymelove", "c4ntbuym3l0v3"] {
            let text = Ascii::new(accepted).unwrap();
            let whole = 0..text.len();
            assert_eq!(&text[whole], accepted);
        }
    }
}