instant_segment/
lib.rs

1use std::ops::{Index, Range};
2use std::str;
3
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6use smartstring::alias::String;
7
8#[cfg(feature = "test-cases")]
9pub mod test_cases;
10#[cfg(feature = "__test_data")]
11pub mod test_data;
12
13/// Central data structure used to calculate word probabilities
14#[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
15pub struct Segmenter {
16    // Maps a word to both its unigram score, as well has a nested HashMap in
17    // which the bigram score can be looked up using the previous word. Scores
18    // are base-10 logarithms of relative word frequencies
19    scores: HashMap<String, (f64, HashMap<String, f64>)>,
20    // Base-10 logarithm of the total count of unigrams
21    uni_total_log10: f64,
22    limit: usize,
23}
24
25impl Segmenter {
26    /// Create `Segmenter` from the given unigram and bigram counts.
27    ///
28    /// Note: the `String` types used in this API are defined in the `smartstring` crate. Any
29    /// `&str` or `String` can be converted into the `String` used here by calling `into()` on it.
30    pub fn new<U, B>(unigrams: U, bigrams: B) -> Self
31    where
32        U: IntoIterator<Item = (String, f64)>,
33        B: IntoIterator<Item = ((String, String), f64)>,
34    {
35        // Initially, `scores` contains the original unigram and bigram counts
36        let mut scores = HashMap::default();
37        let mut uni_total = 0.0;
38        for (word, uni) in unigrams {
39            scores.insert(word, (uni, HashMap::default()));
40            uni_total += uni;
41        }
42        let mut bi_total = 0.0;
43        for ((word1, word2), bi) in bigrams {
44            let Some((_, bi_scores)) = scores.get_mut(&word2) else {
45                // We throw away bigrams for which we do not have a unigram for
46                // the second word. This case shouldn't ever happen on
47                // real-world data, and in fact, it never happens on the word
48                // count lists shipped with this crate.
49                continue;
50            };
51            bi_scores.insert(word1, bi);
52            bi_total += bi;
53        }
54
55        // Now convert the counts in `scores` to the values we actually want,
56        // namely logarithms of relative frequencies
57        for (uni, bi_scores) in scores.values_mut() {
58            *uni = (*uni / uni_total).log10();
59            for bi in bi_scores.values_mut() {
60                *bi = (*bi / bi_total).log10();
61            }
62        }
63
64        Self {
65            uni_total_log10: uni_total.log10(),
66            scores,
67            limit: DEFAULT_LIMIT,
68        }
69    }
70
71    /// Segment the text in `input`
72    ///
73    /// Requires that the input `text` consists of lowercase ASCII characters only. Otherwise,
74    /// returns `Err(InvalidCharacter)`. The `search` parameter contains caches that are used
75    /// segmentation; passing it in allows the callers to reuse the cache allocations.
76    pub fn segment<'a>(
77        &self,
78        input: &str,
79        search: &'a mut Search,
80    ) -> Result<Segments<'a>, InvalidCharacter> {
81        let state = SegmentState::new(Ascii::new(input)?, self, search);
82        let score = match input {
83            "" => 0.0,
84            _ => state.run(),
85        };
86
87        Ok(Segments {
88            iter: search.result.iter(),
89            score,
90        })
91    }
92
93    /// Returns the sentence's score
94    ///
95    /// Returns the relative probability for the given sentence in the the corpus represented by
96    /// this `Segmenter`. Will return `None` iff given an empty iterator argument.
97    pub fn score_sentence<'a>(&self, mut words: impl Iterator<Item = &'a str>) -> Option<f64> {
98        let mut prev = words.next()?;
99        let mut score = self.score(prev, None);
100        for word in words {
101            score += self.score(word, Some(prev));
102            prev = word;
103        }
104        Some(score)
105    }
106
107    fn score(&self, word: &str, previous: Option<&str>) -> f64 {
108        let (uni, bi_scores) = match self.scores.get(word) {
109            Some((uni, bi_scores)) => (uni, bi_scores),
110            // Penalize words not found in the unigrams according
111            // to their length, a crucial heuristic.
112            //
113            // In the original presentation non-words are scored as
114            //
115            //    (1.0 - self.uni_total_log10 - word_len)
116            //
117            // However in practice this seems to under-penalize long non-words.  The intuition
118            // behind the variation used here is that it applies this penalty once for each word
119            // there "should" have been in the non-word's place.
120            //
121            // See <https://github.com/instant-labs/instant-segment/issues/53>.
122            None => {
123                let word_len = word.len() as f64;
124                let word_count = word_len / 5.0;
125                return (1.0 - self.uni_total_log10 - word_len) * word_count;
126            }
127        };
128
129        if let Some(prev) = previous {
130            if let Some(bi) = bi_scores.get(prev) {
131                if let Some((uni_prev, _)) = self.scores.get(prev) {
132                    // Conditional probability of the word given the previous
133                    // word. The technical name is "stupid backoff" and it's
134                    // not a probability distribution but it works well in practice.
135                    return bi - uni_prev;
136                }
137            }
138        }
139
140        *uni
141    }
142
143    /// Customize the word length `limit`
144    pub fn set_limit(&mut self, limit: usize) {
145        self.limit = limit;
146    }
147}
148
/// Iterator over the words produced by [`Segmenter::segment`].
///
/// Yields the segments in input order. The total score of the segmentation is available
/// through [`Segments::score`].
#[derive(Clone, Debug)]
pub struct Segments<'a> {
    iter: std::slice::Iter<'a, String>,
    score: f64,
}
153
154impl Segments<'_> {
155    /// Returns the score of the segmented text
156    pub fn score(&self) -> f64 {
157        self.score
158    }
159}
160
161impl<'a> Iterator for Segments<'a> {
162    type Item = &'a str;
163
164    fn next(&mut self) -> Option<Self::Item> {
165        self.iter.next().map(|v| v.as_str())
166    }
167}
168
169impl ExactSizeIterator for Segments<'_> {
170    fn len(&self) -> usize {
171        self.iter.len()
172    }
173}
174
175struct SegmentState<'a> {
176    data: &'a Segmenter,
177    text: Ascii<'a>,
178    search: &'a mut Search,
179}
180
181impl<'a> SegmentState<'a> {
182    fn new(text: Ascii<'a>, data: &'a Segmenter, search: &'a mut Search) -> Self {
183        search.clear();
184        Self { data, text, search }
185    }
186
187    fn run(self) -> f64 {
188        for end in 1..=self.text.len() {
189            let start = end.saturating_sub(self.data.limit);
190            for split in start..end {
191                let (prev, prev_score) = match split {
192                    0 => (None, 0.0),
193                    _ => {
194                        let prefix = self.search.candidates[split - 1];
195                        let word = &self.text[split - prefix.len..split];
196                        (Some(word), prefix.score)
197                    }
198                };
199
200                let word = &self.text[split..end];
201                let score = self.data.score(word, prev) + prev_score;
202                match self.search.candidates.get_mut(end - 1) {
203                    Some(cur) if cur.score < score => {
204                        cur.len = end - split;
205                        cur.score = score;
206                    }
207                    None => self.search.candidates.push(Candidate {
208                        len: end - split,
209                        score,
210                    }),
211                    _ => {}
212                }
213            }
214        }
215
216        let mut end = self.text.len();
217        let mut best = self.search.candidates[end - 1];
218        let score = best.score;
219        loop {
220            let word = &self.text[end - best.len..end];
221            self.search.result.push(word.into());
222
223            end -= best.len;
224            if end == 0 {
225                break;
226            }
227
228            best = self.search.candidates[end - 1];
229        }
230
231        self.search.result.reverse();
232        score
233    }
234}
235
/// Search state for a [`Segmenter`]
///
/// Holds the scratch buffers used by [`Segmenter::segment`]. The buffers are cleared at the
/// start of every `segment` call. Keep one `Search` around and pass it to repeated calls so
/// its allocations are reused.
#[derive(Clone, Debug, Default)]
pub struct Search {
    // `candidates[i]` holds the final-word length and cumulative score of the best
    // segmentation found for the first `i + 1` bytes of the input.
    candidates: Vec<Candidate>,
    // Words of the most recent segmentation, in input order.
    result: Vec<String>,
}

impl Search {
    /// Empty both buffers while keeping their allocated capacity.
    fn clear(&mut self) {
        self.candidates.clear();
        self.result.clear();
    }

    /// Returns the `idx`-th word of the most recent segmentation, if there is one.
    #[doc(hidden)]
    pub fn get(&self, idx: usize) -> Option<&str> {
        self.result.get(idx).map(|v| v.as_str())
    }
}

/// One dynamic-programming cell of a segmentation run.
#[derive(Clone, Copy, Debug, Default)]
struct Candidate {
    /// Length in bytes of the last word of the best segmentation ending at this position.
    len: usize,
    /// Cumulative score of that segmentation.
    score: f64,
}
260
261#[derive(Debug)]
262struct Ascii<'a>(&'a [u8]);
263
264impl<'a> Ascii<'a> {
265    fn new(s: &'a str) -> Result<Self, InvalidCharacter> {
266        let bytes = s.as_bytes();
267
268        let valid = bytes
269            .iter()
270            .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit());
271
272        match valid {
273            true => Ok(Self(bytes)),
274            false => Err(InvalidCharacter),
275        }
276    }
277
278    fn len(&self) -> usize {
279        self.0.len()
280    }
281}
282
283impl Index<Range<usize>> for Ascii<'_> {
284    type Output = str;
285
286    fn index(&self, index: Range<usize>) -> &Self::Output {
287        let bytes = self.0.index(index);
288        // Since `Ascii` can only be instantiated with ASCII characters, this should be safe
289        unsafe { str::from_utf8_unchecked(bytes) }
290    }
291}
292
/// Error returned by [`Segmenter::segment`] when given an invalid character
///
/// Only lowercase ASCII letters (`a`-`z`) and ASCII digits (`0`-`9`) are accepted. The
/// error is a plain marker: it is cheap to copy and can be compared directly, e.g.
/// `assert_eq!(result.unwrap_err(), InvalidCharacter)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct InvalidCharacter;

impl std::error::Error for InvalidCharacter {}

impl std::fmt::Display for InvalidCharacter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid character")
    }
}
304
/// Map type used for all word-score tables in this file: `rustc_hash`'s `FxHashMap`
/// (a `HashMap` using the Fx hasher).
type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;

/// Initial maximum length, in bytes, of a single word considered by `Segmenter::segment`;
/// it can be changed per `Segmenter` with `Segmenter::set_limit`.
const DEFAULT_LIMIT: usize = 24;
308
#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    fn test_clean() {
        Ascii::new("Can't buy me love!").unwrap_err();
        let text = Ascii::new("cantbuymelove").unwrap();
        assert_eq!(&text[0..text.len()], "cantbuymelove");
        let text_with_numbers = Ascii::new("c4ntbuym3l0v3").unwrap();
        assert_eq!(
            &text_with_numbers[0..text_with_numbers.len()],
            "c4ntbuym3l0v3"
        );
    }

    /// Tiny two-word model: "hello" and "world" with equal counts and no bigrams.
    fn hello_world() -> Segmenter {
        Segmenter::new(
            vec![
                (String::from("hello"), 10.0),
                (String::from("world"), 10.0),
            ],
            Vec::new(),
        )
    }

    #[test]
    fn segments_known_words() {
        let segmenter = hello_world();
        let mut search = Search::default();
        let segments = segmenter.segment("helloworld", &mut search).unwrap();
        assert_eq!(segments.len(), 2);
        // Each word has relative frequency 1/2 and no bigram applies.
        let expected = 2.0 * 0.5f64.log10();
        assert!((segments.score() - expected).abs() < 1e-9);
        assert_eq!(segments.collect::<Vec<_>>(), ["hello", "world"]);
    }

    #[test]
    fn segments_empty_input() {
        let segmenter = hello_world();
        let mut search = Search::default();
        let segments = segmenter.segment("", &mut search).unwrap();
        assert_eq!(segments.score(), 0.0);
        assert_eq!(segments.len(), 0);
    }

    #[test]
    fn rejects_invalid_input() {
        let segmenter = hello_world();
        let mut search = Search::default();
        assert!(segmenter.segment("Hello World", &mut search).is_err());
    }

    #[test]
    fn scores_sentences() {
        let segmenter = hello_world();
        assert_eq!(segmenter.score_sentence(std::iter::empty()), None);
        let score = segmenter
            .score_sentence(["hello", "world"].iter().copied())
            .unwrap();
        assert!((score - 2.0 * 0.5f64.log10()).abs() < 1e-9);
    }
}