1use itertools::Itertools;
50use partial_min_max::max;
51use std::default::Default;
52use std::fmt;
53use strum::EnumIter;
54use writeable::{LengthHint, Writeable};
55
/// A segmentation of a sequence of `length` positions into words.
///
/// Only the interior boundaries are stored; the boundaries at position `0` and at
/// position `length` are implicit.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Breakpoints {
    /// Positions at which a new word starts. Consumers in this file (for example the
    /// BIES string rendering) walk the list front to back and treat each entry as the
    /// end of the current word and the start of the next one.
    pub breakpoints: Vec<usize>,
    /// Total number of positions in the segmented sequence.
    pub length: usize,
}
63
/// The four BIES tag scores for a single position.
///
/// The tag letters match the characters emitted when a segmentation is rendered as a
/// BIES string in this file: `b`/`i`/`e` for the first, interior and last position of a
/// multi-position word, and `s` for a word that is a single position long.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BiesVector<F>
where
    F: fmt::Debug,
{
    /// Score for "this position begins a multi-position word".
    pub b: F,
    /// Score for "this position is inside a word (neither first nor last)".
    pub i: F,
    /// Score for "this position ends a multi-position word".
    pub e: F,
    /// Score for "this position is a whole word by itself".
    pub s: F,
}
71
72#[derive(Clone, Debug, PartialEq)]
74pub struct BiesMatrix(pub Vec<BiesVector<f32>>);
75
76#[derive(Clone, PartialEq)]
77pub struct BiesString<'a>(&'a Breakpoints);
78
79#[derive(Clone, Copy, Debug, PartialEq, EnumIter)]
80pub enum Algorithm {
81 Alg1a,
83
84 Alg1b,
86
87 Alg2a,
89
90 Alg3a,
92}
93
94impl Breakpoints {
95 pub fn from_bies_matrix(
96 algorithm: Algorithm,
97 matrix: &BiesMatrix,
98 valid_breakpoints: impl Iterator<Item = usize>,
99 ) -> Self {
100 match algorithm {
101 Algorithm::Alg1a => Self::from_bies_matrix_1a(matrix, valid_breakpoints),
102 Algorithm::Alg1b => Self::from_bies_matrix_1b(matrix, valid_breakpoints),
103 Algorithm::Alg2a => Self::from_bies_matrix_2a(matrix, valid_breakpoints),
104 Algorithm::Alg3a => Self::from_bies_matrix_3a(matrix, valid_breakpoints),
105 }
106 }
107
108 #[allow(clippy::suspicious_operation_groupings)]
109 fn from_bies_matrix_1a(
110 matrix: &BiesMatrix,
111 valid_breakpoints: impl Iterator<Item = usize>,
112 ) -> Self {
113 let mut breakpoints = vec![];
114 for i in valid_breakpoints {
115 if i == 0 || i >= matrix.0.len() {
116 panic!("Invalid i value");
118 }
119 let bies1 = &matrix.0[i - 1];
120 let bies2 = &matrix.0[i];
121 let break_score =
122 bies1.e * bies2.b + bies1.e * bies2.s + bies1.s * bies2.b + bies1.s * bies2.s;
123 let nobrk_score =
124 bies1.i * bies2.i + bies1.i * bies2.e + bies1.b * bies2.i + bies1.b * bies2.e;
125 if break_score > nobrk_score {
126 breakpoints.push(i);
127 }
128 }
129 Self {
130 breakpoints,
131 length: matrix.0.len(),
132 }
133 }
134
135 fn from_bies_matrix_1b(
136 matrix: &BiesMatrix,
137 valid_breakpoints: impl Iterator<Item = usize>,
138 ) -> Self {
139 let mut breakpoints = vec![];
140 for i in valid_breakpoints {
141 if i == 0 || i >= matrix.0.len() {
142 panic!("Invalid i value");
144 }
145 let bies1 = &matrix.0[i - 1];
146 let bies2 = &matrix.0[i];
147 let mut candidate = (f32::NEG_INFINITY, false);
148 candidate = max(candidate, (bies1.e * bies2.b, true));
149 candidate = max(candidate, (bies1.e * bies2.s, true));
150 candidate = max(candidate, (bies1.s * bies2.b, true));
151 candidate = max(candidate, (bies1.s * bies2.s, true));
152 candidate = max(candidate, (bies1.i * bies2.i, false));
153 candidate = max(candidate, (bies1.i * bies2.e, false));
154 candidate = max(candidate, (bies1.b * bies2.i, false));
155 candidate = max(candidate, (bies1.b * bies2.e, false));
156 if candidate.1 {
157 breakpoints.push(i);
158 }
159 }
160 Self {
161 breakpoints,
162 length: matrix.0.len(),
163 }
164 }
165
166 fn from_bies_matrix_2a(
167 matrix: &BiesMatrix,
168 mut valid_breakpoints: impl Iterator<Item = usize>,
169 ) -> Self {
170 if matrix.0.len() <= 1 {
171 return Self::default();
172 }
173 let mut breakpoints = vec![];
174 let mut inside_word = false;
175 let mut next_valid_brkpt = valid_breakpoints.next();
176 for i in 0..(matrix.0.len() - 1) {
177 let bies1 = &matrix.0[i];
178 let bies2 = &matrix.0[i + 1];
179 let is_valid_brkpt = next_valid_brkpt == Some(i + 1);
180 let mut candidate = (f32::NEG_INFINITY, false);
181 if inside_word {
182 candidate = max(candidate, (bies1.i * bies2.e, false));
184 candidate = max(candidate, (bies1.i * bies2.i, false));
185 if is_valid_brkpt {
186 candidate = max(candidate, (bies1.e * bies2.b, true));
188 candidate = max(candidate, (bies1.e * bies2.s, true));
189 }
190 } else {
191 candidate = max(candidate, (bies1.b * bies2.i, false));
193 candidate = max(candidate, (bies1.b * bies2.e, false));
194 if is_valid_brkpt {
195 candidate = max(candidate, (bies1.s * bies2.b, true));
197 candidate = max(candidate, (bies1.s * bies2.s, true));
198 }
199 }
200 if candidate.1 {
201 breakpoints.push(i + 1);
202 }
203 inside_word = !candidate.1;
204 if is_valid_brkpt {
205 next_valid_brkpt = valid_breakpoints.next();
206 }
207 }
208 Self {
209 breakpoints,
210 length: matrix.0.len(),
211 }
212 }
213
214 fn from_bies_matrix_3a(
215 matrix: &BiesMatrix,
216 valid_breakpoints: impl Iterator<Item = usize>,
217 ) -> Self {
218 let valid_breakpoints: Vec<usize> = valid_breakpoints.collect();
219 let mut best_log_probability = f32::NEG_INFINITY;
220 let mut breakpoints: Vec<usize> = vec![];
221 for i in 0..=valid_breakpoints.len() {
222 for combo in valid_breakpoints.iter().combinations(i) {
223 let mut log_probability = 0.0;
224 let mut add_word = |i: usize, j: usize| {
225 if i == j - 1 {
226 log_probability += matrix.0[i].s.ln();
227 } else {
228 log_probability += matrix.0[i].b.ln();
229 for k in (i + 1)..(j - 1) {
230 log_probability += matrix.0[k].i.ln();
231 }
232 log_probability += matrix.0[j - 1].e.ln();
233 }
234 };
235 let mut i = 0;
236 for j in combo.iter().copied().copied() {
237 add_word(i, j);
238 i = j;
239 }
240 add_word(i, matrix.0.len());
241 if log_probability > best_log_probability {
242 best_log_probability = log_probability;
243 breakpoints = combo.iter().copied().copied().collect();
244 }
245 }
246 }
247 Self {
248 breakpoints,
249 length: matrix.0.len(),
250 }
251 }
252}
253
254impl<'a> From<&'a Breakpoints> for BiesString<'a> {
255 fn from(other: &'a Breakpoints) -> Self {
256 Self(other)
257 }
258}
259
260impl Writeable for BiesString<'_> {
261 fn write_to<W: std::fmt::Write + ?Sized>(&self, sink: &mut W) -> std::fmt::Result {
262 let mut write_bies_word = |i: usize, j: usize| -> fmt::Result {
263 if i == j - 1 {
264 sink.write_char('s')?;
265 } else {
266 sink.write_char('b')?;
267 for _ in (i + 1)..(j - 1) {
268 sink.write_char('i')?;
269 }
270 sink.write_char('e')?;
271 }
272 Ok(())
273 };
274 let mut i = 0;
275 for j in self.0.breakpoints.iter().copied() {
276 write_bies_word(i, j)?;
277 i = j;
278 }
279 write_bies_word(i, self.0.length)?;
280 Ok(())
281 }
282
283 fn writeable_length_hint(&self) -> writeable::LengthHint {
284 LengthHint::exact(self.0.length)
285 }
286}
287
288impl fmt::Debug for BiesString<'_> {
289 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
290 self.write_to(f)
291 }
292}
293
// Generates the `fmt::Display` implementation for `BiesString` from its `Writeable`
// implementation, so `Display` output matches what `write_to` emits.
// NOTE(review): behavior is defined by the `writeable` crate's macro - confirm against the
// crate version in use.
writeable::impl_display_with_writeable!(BiesString<'_>);