bies/
lib.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5//! The algorithms in this project convert from a BIES matrix (the output of the LSTM segmentation neural network) to concrete segment boundaries.  In BIES, B = beginning of segment; I = inside segment; E = end of segment; and S = single segment (both beginning and end).
6//!
7//! These algorithms always produce valid breakpoint positions (at grapheme cluster boundaries); they don't assume that the neural network always predicts valid positions.
8//!
9//! # Example
10//!
11//! For example, suppose you had the following BIES matrix:
12//!
13//! <pre>
14//! |   B   |   I   |   E   |   S   |
15//! |-------|-------|-------|-------|
16//! | 0.01  | 0.01  | 0.01  | 0.97  |
17//! | 0.97  | 0.01  | 0.01  | 0.01  |
18//! | 0.01  | 0.97  | 0.01  | 0.01  |
19//! | 0.01  | 0.97  | 0.01  | 0.01  |
20//! | 0.01  | 0.01  | 0.97  | 0.01  |
21//! | 0.01  | 0.01  | 0.01  | 0.97  |
22//! | 0.97  | 0.01  | 0.01  | 0.01  |
23//! | 0.01  | 0.01  | 0.97  | 0.01  |
24//! </pre>
25//!
26//! This matrix resolves to:
27//!
28//! <pre>
29//! 01234567
30//! SBIIESBE
31//! </pre>
32//!
//! The breakpoints are then: 0, 1, 5, 6, and 8 (four segments).
34//!
35//! However, it could be the case that the algorithm's BIES are invalid.  For example, "BEE" is invalid, because the second "E" does not terminate any word.  The purpose of the algorithms in this project is to guarantee that valid breakpoints and BIES are always outputted.
36//!
37//! # Algorithms
38//!
39//! The following algorithms are implemented:
40//!
41//! **1a:** Step through each grapheme cluster boundary in the string. Look at the BIES vectors for the code points surrounding the boundary. The only valid results at that boundary are {EB, ES, SB, SS} (breakpoint) or {II, BI, IE, BE} (no breakpoint). Take the sum of the valid breakpoint and no-breakpoint probabilities, and decide whether to insert a breakpoint based on which sum is higher. Repeat for all grapheme cluster boundaries in the string. The output is a list of word boundaries, which can be converted back into BIES if desired.
42//!
43//! **1b:** Same as 1a, but instead of taking the sum, take the individual maximum.
44//!
45//! **2a:** Step through each element in the BIES sequence. For each element, look at the triplet containing the element and both of its neighbors. By induction, assume the first element in the triplet is correct. Now, depending on whether there is a code point boundary following the element, calculate the probabilities of all valid BIES for the triplet, and based on those results, pick the most likely value for the current element.
46//!
47//! **3a:** Exhaustively check the probabilities of all possible BIES for the string. This algorithm has exponential runtime.
48
49use itertools::Itertools;
50use partial_min_max::max;
51use std::default::Default;
52use std::fmt;
53use strum::EnumIter;
54use writeable::{LengthHint, Writeable};
55
/// The outcome of segmentation: the word boundaries found inside a sequence, together with
/// the length of that sequence.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct Breakpoints {
    /// Breakpoint positions in ascending order. Every element must lie between 0 and
    /// `length`, exclusive.
    pub breakpoints: Vec<usize>,
    /// Length of the whole sequence, i.e. the position at which the final word ends.
    pub length: usize,
}
63
/// One row of a BIES matrix: a value of type `F` for each of the four BIES tags.
///
/// No trait bound is placed on `F` here; the derived impls (`Clone`, `Copy`, `Debug`,
/// `PartialEq`) each add the bound they need, so the struct itself stays usable with any `F`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BiesVector<F> {
    /// Value for "B" (beginning of a segment).
    pub b: F,
    /// Value for "I" (inside a segment).
    pub i: F,
    /// Value for "E" (end of a segment).
    pub e: F,
    /// Value for "S" (single-element segment, both beginning and end).
    pub s: F,
}
71
72// TODO: Consider parameterizing the f32 to a trait
73#[derive(Clone, Debug, PartialEq)]
74pub struct BiesMatrix(pub Vec<BiesVector<f32>>);
75
76#[derive(Clone, PartialEq)]
77pub struct BiesString<'a>(&'a Breakpoints);
78
79#[derive(Clone, Copy, Debug, PartialEq, EnumIter)]
80pub enum Algorithm {
81    /// Algorithm 1a: check probabilities surrounding each valid breakpoint. Switch based on the sum.
82    Alg1a,
83
84    /// Algorithm 1b: check probabilities surrounding each valid breakpoint. Switch based on the individual max.
85    Alg1b,
86
87    /// Algorithm 2: step forward through the matrix and pick the highest probability at each step
88    Alg2a,
89
90    /// Algorithm 3: exhaustively check all combinations of breakpoints to find the highest true probability
91    Alg3a,
92}
93
94impl Breakpoints {
95    pub fn from_bies_matrix(
96        algorithm: Algorithm,
97        matrix: &BiesMatrix,
98        valid_breakpoints: impl Iterator<Item = usize>,
99    ) -> Self {
100        match algorithm {
101            Algorithm::Alg1a => Self::from_bies_matrix_1a(matrix, valid_breakpoints),
102            Algorithm::Alg1b => Self::from_bies_matrix_1b(matrix, valid_breakpoints),
103            Algorithm::Alg2a => Self::from_bies_matrix_2a(matrix, valid_breakpoints),
104            Algorithm::Alg3a => Self::from_bies_matrix_3a(matrix, valid_breakpoints),
105        }
106    }
107
108    #[allow(clippy::suspicious_operation_groupings)]
109    fn from_bies_matrix_1a(
110        matrix: &BiesMatrix,
111        valid_breakpoints: impl Iterator<Item = usize>,
112    ) -> Self {
113        let mut breakpoints = vec![];
114        for i in valid_breakpoints {
115            if i == 0 || i >= matrix.0.len() {
116                // TODO: Make fail-safe
117                panic!("Invalid i value");
118            }
119            let bies1 = &matrix.0[i - 1];
120            let bies2 = &matrix.0[i];
121            let break_score =
122                bies1.e * bies2.b + bies1.e * bies2.s + bies1.s * bies2.b + bies1.s * bies2.s;
123            let nobrk_score =
124                bies1.i * bies2.i + bies1.i * bies2.e + bies1.b * bies2.i + bies1.b * bies2.e;
125            if break_score > nobrk_score {
126                breakpoints.push(i);
127            }
128        }
129        Self {
130            breakpoints,
131            length: matrix.0.len(),
132        }
133    }
134
135    fn from_bies_matrix_1b(
136        matrix: &BiesMatrix,
137        valid_breakpoints: impl Iterator<Item = usize>,
138    ) -> Self {
139        let mut breakpoints = vec![];
140        for i in valid_breakpoints {
141            if i == 0 || i >= matrix.0.len() {
142                // TODO: Make fail-safe
143                panic!("Invalid i value");
144            }
145            let bies1 = &matrix.0[i - 1];
146            let bies2 = &matrix.0[i];
147            let mut candidate = (f32::NEG_INFINITY, false);
148            candidate = max(candidate, (bies1.e * bies2.b, true));
149            candidate = max(candidate, (bies1.e * bies2.s, true));
150            candidate = max(candidate, (bies1.s * bies2.b, true));
151            candidate = max(candidate, (bies1.s * bies2.s, true));
152            candidate = max(candidate, (bies1.i * bies2.i, false));
153            candidate = max(candidate, (bies1.i * bies2.e, false));
154            candidate = max(candidate, (bies1.b * bies2.i, false));
155            candidate = max(candidate, (bies1.b * bies2.e, false));
156            if candidate.1 {
157                breakpoints.push(i);
158            }
159        }
160        Self {
161            breakpoints,
162            length: matrix.0.len(),
163        }
164    }
165
166    fn from_bies_matrix_2a(
167        matrix: &BiesMatrix,
168        mut valid_breakpoints: impl Iterator<Item = usize>,
169    ) -> Self {
170        if matrix.0.len() <= 1 {
171            return Self::default();
172        }
173        let mut breakpoints = vec![];
174        let mut inside_word = false;
175        let mut next_valid_brkpt = valid_breakpoints.next();
176        for i in 0..(matrix.0.len() - 1) {
177            let bies1 = &matrix.0[i];
178            let bies2 = &matrix.0[i + 1];
179            let is_valid_brkpt = next_valid_brkpt == Some(i + 1);
180            let mut candidate = (f32::NEG_INFINITY, false);
181            if inside_word {
182                // IE, II
183                candidate = max(candidate, (bies1.i * bies2.e, false));
184                candidate = max(candidate, (bies1.i * bies2.i, false));
185                if is_valid_brkpt {
186                    // EB, ES
187                    candidate = max(candidate, (bies1.e * bies2.b, true));
188                    candidate = max(candidate, (bies1.e * bies2.s, true));
189                }
190            } else {
191                // BI, BE
192                candidate = max(candidate, (bies1.b * bies2.i, false));
193                candidate = max(candidate, (bies1.b * bies2.e, false));
194                if is_valid_brkpt {
195                    // SB, SS
196                    candidate = max(candidate, (bies1.s * bies2.b, true));
197                    candidate = max(candidate, (bies1.s * bies2.s, true));
198                }
199            }
200            if candidate.1 {
201                breakpoints.push(i + 1);
202            }
203            inside_word = !candidate.1;
204            if is_valid_brkpt {
205                next_valid_brkpt = valid_breakpoints.next();
206            }
207        }
208        Self {
209            breakpoints,
210            length: matrix.0.len(),
211        }
212    }
213
214    fn from_bies_matrix_3a(
215        matrix: &BiesMatrix,
216        valid_breakpoints: impl Iterator<Item = usize>,
217    ) -> Self {
218        let valid_breakpoints: Vec<usize> = valid_breakpoints.collect();
219        let mut best_log_probability = f32::NEG_INFINITY;
220        let mut breakpoints: Vec<usize> = vec![];
221        for i in 0..=valid_breakpoints.len() {
222            for combo in valid_breakpoints.iter().combinations(i) {
223                let mut log_probability = 0.0;
224                let mut add_word = |i: usize, j: usize| {
225                    if i == j - 1 {
226                        log_probability += matrix.0[i].s.ln();
227                    } else {
228                        log_probability += matrix.0[i].b.ln();
229                        for k in (i + 1)..(j - 1) {
230                            log_probability += matrix.0[k].i.ln();
231                        }
232                        log_probability += matrix.0[j - 1].e.ln();
233                    }
234                };
235                let mut i = 0;
236                for j in combo.iter().copied().copied() {
237                    add_word(i, j);
238                    i = j;
239                }
240                add_word(i, matrix.0.len());
241                if log_probability > best_log_probability {
242                    best_log_probability = log_probability;
243                    breakpoints = combo.iter().copied().copied().collect();
244                }
245            }
246        }
247        Self {
248            breakpoints,
249            length: matrix.0.len(),
250        }
251    }
252}
253
254impl<'a> From<&'a Breakpoints> for BiesString<'a> {
255    fn from(other: &'a Breakpoints) -> Self {
256        Self(other)
257    }
258}
259
260impl Writeable for BiesString<'_> {
261    fn write_to<W: std::fmt::Write + ?Sized>(&self, sink: &mut W) -> std::fmt::Result {
262        let mut write_bies_word = |i: usize, j: usize| -> fmt::Result {
263            if i == j - 1 {
264                sink.write_char('s')?;
265            } else {
266                sink.write_char('b')?;
267                for _ in (i + 1)..(j - 1) {
268                    sink.write_char('i')?;
269                }
270                sink.write_char('e')?;
271            }
272            Ok(())
273        };
274        let mut i = 0;
275        for j in self.0.breakpoints.iter().copied() {
276            write_bies_word(i, j)?;
277            i = j;
278        }
279        write_bies_word(i, self.0.length)?;
280        Ok(())
281    }
282
283    fn writeable_length_hint(&self) -> writeable::LengthHint {
284        LengthHint::exact(self.0.length)
285    }
286}
287
288impl fmt::Debug for BiesString<'_> {
289    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
290        self.write_to(f)
291    }
292}
293
// Invokes the `writeable` crate's macro for `BiesString`; judging by the macro name, this
// generates a `Display` impl backed by the `Writeable` impl (confirm against the crate docs).
294writeable::impl_display_with_writeable!(BiesString<'_>);