1#![warn(missing_docs)]
5#[cfg(test)]
6#[macro_use]
7extern crate quickcheck_macros;
8
9use lazy_static::lazy_static;
10use ndarray::prelude::*;
11use serde_derive::{Deserialize, Serialize};
12use std::cmp;
13use std::collections::HashMap;
14use std::ops::Range;
15
16#[cfg(feature = "tract-backend")]
18pub mod tract_backend;
19#[cfg(feature = "tract-backend")]
20pub use tract_backend::NNSplit;
21
22#[cfg(feature = "model-loader")]
24pub mod model_loader;
25
/// Name of one level in the split hierarchy (e.g. "Sentence", "Token").
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Level(pub String);
29
/// A node in the tree of splits produced for a text.
#[derive(Debug)]
pub enum Split<'a> {
    /// A leaf: a span of text that is not split any further.
    Text(&'a str),
    /// An inner node: the full covered text plus its child splits.
    Split((&'a str, Vec<Split<'a>>)),
}
38
39impl<'a> Split<'a> {
40 pub fn text(&self) -> &'a str {
42 match self {
43 Split::Split((text, _)) => text,
44 Split::Text(text) => text,
45 }
46 }
47
48 pub fn iter(&self) -> impl Iterator<Item = &Split<'a>> {
52 match self {
53 Split::Split((_, splits)) => splits.iter(),
54 Split::Text(_) => panic!("Can not iterate over Split::Text."),
55 }
56 }
57
58 pub fn flatten(&self, level: usize) -> Vec<&str> {
60 match self {
61 Split::Text(text) => vec![text],
62 Split::Split((_, parts)) => {
63 let mut out = Vec::new();
64
65 for part in parts {
66 if level == 0 {
67 out.push(part.text());
68 } else {
69 out.extend(part.flatten(level - 1));
70 }
71 }
72
73 out
74 }
75 }
76 }
77}
78
/// Splits `input` into its leading content and the trailing whitespace.
///
/// Always returns exactly two slices, `[content, trailing_whitespace]`;
/// either may be empty. The two slices concatenate back to `input`.
fn split_whitespace(input: &str) -> Vec<&str> {
    let boundary = input.trim_end().len();
    let (content, trailing) = input.split_at(boundary);
    vec![content, trailing]
}
83
/// Signature of a rule-based split function: maps a text to contiguous parts.
type SplitFunction = fn(&str) -> Vec<&str>;

lazy_static! {
    /// Registry of rule-based split functions, addressable by name from
    /// `SplitInstruction::Function`.
    static ref SPLIT_FUNCTIONS: HashMap<&'static str, SplitFunction> = {
        let mut map = HashMap::new();
        map.insert("whitespace", split_whitespace as SplitFunction);
        map
    };
}
93
/// One instruction describing how to split a text at one level.
#[derive(Serialize, Deserialize)]
pub enum SplitInstruction {
    /// Split wherever the prediction in this column exceeds the threshold.
    PredictionIndex(usize),
    /// Split with the named rule-based function from `SPLIT_FUNCTIONS`.
    Function(String),
}
102
/// An ordered list of split instructions defining the full split hierarchy.
#[derive(Serialize, Deserialize)]
pub struct SplitSequence {
    // (level name, instruction) pairs, applied outermost first.
    instructions: Vec<(Level, SplitInstruction)>,
}
108
impl SplitSequence {
    /// Creates a sequence from `(level, instruction)` pairs that are applied
    /// to a text in order, outermost level first.
    pub fn new(instructions: Vec<(Level, SplitInstruction)>) -> Self {
        SplitSequence { instructions }
    }

    /// Returns the level names of this sequence, in application order.
    pub fn get_levels(&self) -> Vec<&Level> {
        self.instructions.iter().map(|(level, _)| level).collect()
    }

    /// Recursively applies the instruction at `instruction_idx` to `text`,
    /// building the split tree. Returns a `Split::Text` leaf once the
    /// instructions are exhausted.
    ///
    /// # Panics
    /// Panics if `predictions` does not have exactly one row per byte of
    /// `text`, or if a `Function` instruction names an unregistered function.
    fn inner_apply<'a>(
        &self,
        text: &'a str,
        predictions: ArrayView2<f32>,
        threshold: f32,
        instruction_idx: usize,
    ) -> Split<'a> {
        assert_eq!(
            predictions.shape()[0],
            text.len(),
            "length of predictions must be equal to the number of bytes in text"
        );

        if let Some((_, instruction)) = self.instructions.get(instruction_idx) {
            match instruction {
                SplitInstruction::PredictionIndex(idx) => {
                    // Byte positions right after which the prediction in
                    // column `idx` exceeds the threshold, i.e. candidate
                    // split boundaries.
                    let mut indices: Vec<_> = predictions
                        .slice(s![.., *idx])
                        .indexed_iter()
                        .filter_map(|(index, &item)| {
                            if item > threshold {
                                Some(index + 1)
                            } else {
                                None
                            }
                        })
                        .collect();

                    // Always close the final segment at the end of the text.
                    if indices.is_empty() || indices[indices.len() - 1] != text.len() {
                        indices.push(text.len());
                    }

                    let mut parts = Vec::new();
                    let mut prev = 0;

                    for raw_idx in indices {
                        // Skip boundaries already swallowed by a previous
                        // boundary that was moved forward (see loop below).
                        if prev >= raw_idx {
                            continue;
                        }

                        let mut idx = raw_idx;

                        // Advance `idx` until `prev..idx` lies on UTF-8 char
                        // boundaries (`str::get` returns None otherwise), so
                        // slicing cannot panic on multi-byte characters.
                        let part = loop {
                            if let Some(part) = text.get(prev..idx) {
                                break part;
                            }
                            idx += 1;
                        };

                        // Recurse with the matching rows of the predictions.
                        parts.push(self.inner_apply(
                            part,
                            predictions.slice(s![prev..idx, ..]),
                            threshold,
                            instruction_idx + 1,
                        ));

                        prev = idx;
                    }

                    Split::Split((text, parts))
                }
                SplitInstruction::Function(func_name) => Split::Split((
                    text,
                    (*SPLIT_FUNCTIONS.get(func_name.as_str()).unwrap())(text)
                        .iter()
                        .map(|part| {
                            // The parts are subslices of `text`, so pointer
                            // arithmetic recovers their byte offsets in it.
                            let start = part.as_ptr() as usize - text.as_ptr() as usize;
                            let end = start + part.len();

                            self.inner_apply(
                                part,
                                predictions.slice(s![start..end, ..]),
                                threshold,
                                instruction_idx + 1,
                            )
                        })
                        .collect::<Vec<Split>>(),
                )),
            }
        } else {
            Split::Text(text)
        }
    }

    /// Applies the full instruction sequence to `text`, using one row of
    /// `predictions` per byte and the given probability `threshold`.
    fn apply<'a>(&self, text: &'a str, predictions: ArrayView2<f32>, threshold: f32) -> Split<'a> {
        self.inner_apply(text, predictions, threshold, 0)
    }
}
208
/// Options controlling how texts are windowed and split.
#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct NNSplitOptions {
    /// Probability above which a split is made.
    #[serde(default = "NNSplitOptions::default_threshold")]
    pub threshold: f32,
    /// Byte offset between the starts of consecutive windows.
    #[serde(default = "NNSplitOptions::default_stride")]
    pub stride: usize,
    /// Maximum window length in bytes.
    #[serde(alias = "maxLength", default = "NNSplitOptions::default_max_length")]
    pub max_length: usize,
    /// Number of zero bytes padded onto each side of a text.
    #[serde(default = "NNSplitOptions::default_padding")]
    pub padding: usize,
    /// Every padded length is rounded up to a multiple of this divisor.
    #[serde(
        alias = "paddingDivisor",
        default = "NNSplitOptions::default_length_divisor"
    )]
    pub length_divisor: usize,
    /// Batch size for prediction (consumed by the model backends; not
    /// referenced by `NNSplitLogic` itself).
    #[serde(alias = "batchSize", default = "NNSplitOptions::default_batch_size")]
    pub batch_size: usize,
}
235
236impl NNSplitOptions {
237 fn default_threshold() -> f32 {
238 0.8
239 }
240
241 fn default_stride() -> usize {
242 NNSplitOptions::default_max_length() / 2
243 }
244
245 fn default_max_length() -> usize {
246 500
247 }
248
249 fn default_padding() -> usize {
250 5
251 }
252
253 fn default_batch_size() -> usize {
254 256
255 }
256
257 fn default_length_divisor() -> usize {
258 2
259 }
260}
261
262impl Default for NNSplitOptions {
263 fn default() -> Self {
264 NNSplitOptions {
265 threshold: NNSplitOptions::default_threshold(),
266 stride: NNSplitOptions::default_stride(),
267 max_length: NNSplitOptions::default_max_length(),
268 padding: NNSplitOptions::default_padding(),
269 batch_size: NNSplitOptions::default_batch_size(),
270 length_divisor: NNSplitOptions::default_length_divisor(),
271 }
272 }
273}
274
/// Backend-agnostic splitting logic: turns texts into network inputs and
/// turns network predictions back into split trees.
pub struct NNSplitLogic {
    options: NNSplitOptions,
    split_sequence: SplitSequence,
}
280
impl NNSplitLogic {
    /// Creates the backend-independent splitting logic.
    ///
    /// # Panics
    /// Panics if `options.max_length` is not divisible by
    /// `options.length_divisor`.
    pub fn new(options: NNSplitOptions, split_sequence: SplitSequence) -> Self {
        if options.max_length % options.length_divisor != 0 {
            panic!("max length must be divisible by length divisor.")
        }

        NNSplitLogic {
            options,
            split_sequence,
        }
    }

    /// Returns the options this logic was created with.
    #[inline]
    pub fn options(&self) -> &NNSplitOptions {
        &self.options
    }

    /// Returns the split sequence this logic applies.
    #[inline]
    pub fn split_sequence(&self) -> &SplitSequence {
        &self.split_sequence
    }

    /// Pads `length` by `padding` on both sides, then rounds the result up
    /// to the next multiple of `length_divisor`.
    fn pad(&self, length: usize) -> usize {
        let padded = length + self.options.padding * 2;
        let remainder = padded % self.options.length_divisor;

        if remainder == 0 {
            padded
        } else {
            padded + (self.options.length_divisor - remainder)
        }
    }

    /// Converts `texts` into a batch of network inputs.
    ///
    /// Each text is zero-padded and cut into (possibly overlapping) windows
    /// of at most `max_length` bytes, advanced by `stride`. Returns the
    /// windows as rows of a byte matrix, plus for each row the index of its
    /// source text and the byte range it covers in that text's padded input.
    pub fn get_inputs_and_indices(
        &self,
        texts: &[&str],
    ) -> (Array2<u8>, Vec<(usize, Range<usize>)>) {
        // Width of each input row: longest padded text, capped at max_length.
        let maxlen = cmp::min(
            texts.iter().map(|x| self.pad(x.len())).max().unwrap_or(0),
            self.options.max_length,
        );

        let (all_inputs, all_indices) = texts
            .iter()
            .enumerate()
            .map(|(i, text)| {
                let mut text_inputs: Vec<u8> = Vec::new();
                let mut text_indices: Vec<(usize, Range<usize>)> = Vec::new();

                let length = self.pad(text.len());
                // Zero-filled buffer with the text's bytes placed after
                // `padding` leading zeros.
                let mut inputs = vec![0; length];

                for (j, byte) in text.bytes().enumerate() {
                    inputs[j + self.options.padding] = byte;
                }

                let mut start = 0;
                let mut end = 0;

                while end != length {
                    end = cmp::min(start + self.options.max_length, length);
                    // Re-anchor `start` so the final window keeps full
                    // length where possible (it may then overlap the
                    // previous window by more than `stride`).
                    start = if self.options.max_length > end {
                        0
                    } else {
                        end - self.options.max_length
                    };

                    // Copy the window into a fixed-width, zero-padded row.
                    let mut input_slice = vec![0u8; maxlen];
                    input_slice[..end - start].copy_from_slice(&inputs[start..end]);

                    text_inputs.extend(input_slice);
                    text_indices.push((i, start..end));

                    start += self.options.stride;
                }

                (text_inputs, text_indices)
            })
            .fold(
                (Vec::<u8>::new(), Vec::<(usize, Range<usize>)>::new()),
                |mut acc, (text_inputs, text_indices)| {
                    acc.0.extend(text_inputs);
                    acc.1.extend(text_indices);

                    acc
                },
            );

        // One row per recorded (text, range) index.
        let input_array = Array2::from_shape_vec((all_indices.len(), maxlen), all_inputs).unwrap();
        (input_array, all_indices)
    }

    /// Averages the per-window predictions back into one prediction matrix
    /// per text; bytes covered by several windows are averaged via a
    /// per-byte count.
    fn combine_predictions(
        &self,
        slice_predictions: ArrayView3<f32>,
        indices: Vec<(usize, Range<usize>)>,
        lengths: Vec<usize>,
    ) -> Vec<Array2<f32>> {
        let pred_dim = slice_predictions.shape()[2];
        // One (sum, count) accumulator pair per text.
        let mut preds_and_counts = lengths
            .iter()
            .map(|x| (Array2::zeros((*x, pred_dim)), Array2::zeros((*x, 1))))
            .collect::<Vec<_>>();

        for (slice_pred, (index, range)) in slice_predictions.outer_iter().zip(indices) {
            let (pred, count) = preds_and_counts
                .get_mut(index)
                .expect("slice index must be in bounds");

            // Accumulate this window's predictions over its byte range.
            let mut pred_slice = pred.slice_mut(s![range.start..range.end, ..]);
            pred_slice += &slice_pred.slice(s![..range.end - range.start, ..]);

            let mut count_slice = count.slice_mut(s![range.start..range.end, ..]);
            count_slice += 1f32;
        }

        // Element-wise mean: sum / count (counts broadcast over pred_dim).
        preds_and_counts
            .into_iter()
            .map(|(pred, count): (Array2<f32>, Array2<f32>)| (pred / count))
            .collect()
    }

    /// Turns raw per-window predictions into split trees for `texts`.
    ///
    /// `slice_preds` and `indices` must come from a model run on the output
    /// of [`NNSplitLogic::get_inputs_and_indices`] for the same `texts`.
    pub fn split<'a>(
        &self,
        texts: &[&'a str],
        slice_preds: Array3<f32>,
        indices: Vec<(usize, Range<usize>)>,
    ) -> Vec<Split<'a>> {
        let padded_preds = self.combine_predictions(
            (&slice_preds).into(),
            indices,
            texts.iter().map(|x| self.pad(x.len())).collect(),
        );

        // Strip the zero padding so predictions align with text bytes again.
        let preds = padded_preds
            .iter()
            .zip(texts)
            .map(|(x, text)| {
                x.slice(s![
                    self.options.padding..self.options.padding + text.len(),
                    ..
                ])
            })
            .collect::<Vec<_>>();

        texts
            .iter()
            .zip(preds)
            .map(|(text, pred)| {
                self.split_sequence
                    .apply(text, pred, self.options.threshold)
            })
            .collect()
    }
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{thread_rng, Rng};

    // Test splitter that stands in for a real model by producing random
    // predictions of the right shape.
    struct DummyNNSplit {
        logic: NNSplitLogic,
    }

    impl DummyNNSplit {
        fn new(options: NNSplitOptions) -> Self {
            DummyNNSplit {
                logic: NNSplitLogic::new(
                    options,
                    SplitSequence::new(vec![
                        (
                            Level("Sentence".into()),
                            SplitInstruction::PredictionIndex(0),
                        ),
                        (Level("Token".into()), SplitInstruction::PredictionIndex(1)),
                        (
                            Level("Whitespace".into()),
                            SplitInstruction::Function("whitespace".into()),
                        ),
                    ]),
                ),
            }
        }

        // Produces uniformly random "predictions" matching the input shape,
        // with two prediction columns (sentence, token).
        fn predict(&self, input: Array2<u8>) -> Array3<f32> {
            let n = input.shape()[0];
            let length = input.shape()[1];
            let dim = 2usize;

            let mut rng = thread_rng();

            let mut blob = Vec::new();
            for _ in 0..n * length * dim {
                blob.push(rng.gen_range(0.0..1.0));
            }

            Array3::from_shape_vec((n, length, dim), blob).unwrap()
        }

        // Full pipeline: windowing -> (random) prediction -> split trees.
        pub fn split<'a>(&self, texts: &[&'a str]) -> Vec<Split<'a>> {
            let (input, indices) = self.logic.get_inputs_and_indices(texts);
            let slice_preds = self.predict(input);

            self.logic.split(texts, slice_preds, indices)
        }
    }

    #[test]
    fn split_instructions_work() {
        let instructions = SplitSequence::new(vec![
            (Level("Token".into()), SplitInstruction::PredictionIndex(0)),
            (
                Level("Whitespace".into()),
                SplitInstruction::Function("whitespace".into()),
            ),
        ]);

        let input = "This is a test.";
        // One prediction per byte; a 1.0 marks the byte before each boundary.
        let mut predictions = array![[0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1.]];
        predictions.swap_axes(0, 1);
        let predictions: ArrayView2<f32> = (&predictions).into();

        let splits = instructions.apply(input, predictions, 0.5);
        assert_eq!(splits.flatten(0), ["This ", "is ", "a ", "test", "."]);
        assert_eq!(
            splits.flatten(1),
            ["This", " ", "is", " ", "a", " ", "test", "", ".", ""]
        );
    }

    #[test]
    fn splitter_works() {
        let options = NNSplitOptions {
            stride: 5,
            max_length: 20,
            ..NNSplitOptions::default()
        };
        let splitter = DummyNNSplit::new(options);

        // Only checks that the pipeline runs without panicking.
        splitter.split(&["This is a short test.", "This is another short test."]);
    }

    #[test]
    fn splitter_works_on_empty_input() {
        let splitter = DummyNNSplit::new(NNSplitOptions::default());

        let splits = splitter.split(&[]);
        assert!(splits.is_empty());
    }

    // Property: at every flatten level the parts partition the input, so
    // their lengths must sum to the input length.
    #[quickcheck]
    fn length_invariant(text: String) -> bool {
        let splitter = DummyNNSplit::new(NNSplitOptions::default());

        let split = &splitter.split(&[&text])[0];

        let mut sums: Vec<usize> = Vec::new();
        sums.push(split.iter().map(|x| x.text().len()).sum());

        for i in 0..4 {
            sums.push(split.flatten(i).iter().map(|x| x.len()).sum());
        }

        sums.into_iter().all(|sum| sum == text.len())
    }
}