nnsplit/
lib.rs

1//! Fast, robust text splitting with bindings for Python, Rust and Javascript. This crate contains the core splitting logic which is shared between Javascript, Python and Rust. Each binding then implements a backend separately.
2//!
3//! See [`tract_backend::NNSplit`](tract_backend/struct.NNSplit.html) for information for using NNSplit from Rust.
4#![warn(missing_docs)]
5#[cfg(test)]
6#[macro_use]
7extern crate quickcheck_macros;
8
9use lazy_static::lazy_static;
10use ndarray::prelude::*;
11use serde_derive::{Deserialize, Serialize};
12use std::cmp;
13use std::collections::HashMap;
14use std::ops::Range;
15
/// Backend to run models using tract.
17#[cfg(feature = "tract-backend")]
18pub mod tract_backend;
19#[cfg(feature = "tract-backend")]
20pub use tract_backend::NNSplit;
21
22/// Caching and downloading of models.
23#[cfg(feature = "model-loader")]
24pub mod model_loader;
25
/// A Split level, used to describe what this split corresponds to (e.g. a sentence).
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Level(pub String); // human-readable level name, e.g. "Sentence" or "Token"
29
/// A split text, possibly subdivided into smaller splits.
#[derive(Debug)]
pub enum Split<'a> {
    /// The lowest level of split.
    Text(&'a str),
    /// A split which contains one or more smaller splits.
    Split((&'a str, Vec<Split<'a>>)),
}

impl<'a> Split<'a> {
    /// Returns the encompassed text.
    pub fn text(&self) -> &'a str {
        // Both variants carry the full text as their first component.
        match *self {
            Split::Text(text) | Split::Split((text, _)) => text,
        }
    }

    /// Iterate over smaller splits.
    /// # Panics
    /// * If the Split is a `Split::Text` because the lowest level of split can not be iterated over.
    pub fn iter(&self) -> impl Iterator<Item = &Split<'a>> {
        if let Split::Split((_, splits)) = self {
            splits.iter()
        } else {
            panic!("Can not iterate over Split::Text.")
        }
    }

    /// Recursively flattens the split up to the given depth. Returns a vector with the
    /// text of each split found `level` levels below this one (`Split::Text` leaves
    /// encountered earlier are included as-is).
    pub fn flatten(&self, level: usize) -> Vec<&str> {
        match self {
            Split::Text(text) => vec![text],
            Split::Split((_, parts)) => parts
                .iter()
                .flat_map(|part| {
                    if level == 0 {
                        // Depth reached: take each direct child's text.
                        vec![part.text()]
                    } else {
                        part.flatten(level - 1)
                    }
                })
                .collect(),
        }
    }
}
78
// Splits `input` into its content and any trailing whitespace. Always returns
// exactly two subslices of the input (either may be empty) that tile it.
fn split_whitespace(input: &str) -> Vec<&str> {
    let boundary = input.trim_end().len();
    let (content, trailing) = input.split_at(boundary);
    vec![content, trailing]
}
83
// Signature of a rule-based split function: takes a string slice and returns
// consecutive subslices of it (`inner_apply` relies on the parts being
// subslices of the input to recover their byte offsets).
type SplitFunction = fn(&str) -> Vec<&str>;

lazy_static! {
    // Registry of rule-based split functions, addressable by name from
    // `SplitInstruction::Function`. Currently only "whitespace" is registered.
    static ref SPLIT_FUNCTIONS: HashMap<&'static str, SplitFunction> = {
        let mut map = HashMap::new();
        map.insert("whitespace", split_whitespace as SplitFunction);
        map
    };
}
93
#[derive(Serialize, Deserialize)]
/// Instruction to split text.
pub enum SplitInstruction {
    /// Instruction to split at the given index of the neural network predictions.
    PredictionIndex(usize),
    /// Instruction to split according to a function.
    /// The string must be a key of the `SPLIT_FUNCTIONS` registry (currently only "whitespace").
    Function(String),
}
102
#[derive(Serialize, Deserialize)]
/// Instructions for how to convert neural network outputs and a text to `Split` objects.
pub struct SplitSequence {
    // One (level, instruction) pair per split level, ordered from top
    // (largest splits) to bottom (smallest).
    instructions: Vec<(Level, SplitInstruction)>,
}
108
impl SplitSequence {
    /// Creates a new split sequence. Contains instructions for how to use model predictions to split a text.
    pub fn new(instructions: Vec<(Level, SplitInstruction)>) -> Self {
        SplitSequence { instructions }
    }

    /// Gets the levels of this split sequence, from top (larger) to bottom (smaller).
    pub fn get_levels(&self) -> Vec<&Level> {
        self.instructions.iter().map(|(level, _)| level).collect()
    }

    // Recursively splits `text` according to the instruction at `instruction_idx`,
    // then descends into the next instruction for each produced part.
    // `predictions` must have one row per *byte* of `text` (asserted below).
    // When the instructions are exhausted, the text becomes a `Split::Text` leaf.
    fn inner_apply<'a>(
        &self,
        text: &'a str,
        predictions: ArrayView2<f32>,
        threshold: f32,
        instruction_idx: usize,
    ) -> Split<'a> {
        assert_eq!(
            predictions.shape()[0],
            text.len(),
            "length of predictions must be equal to the number of bytes in text"
        );

        if let Some((_, instruction)) = self.instructions.get(instruction_idx) {
            match instruction {
                SplitInstruction::PredictionIndex(idx) => {
                    // Byte positions to split *after*: every byte whose prediction
                    // in column `idx` exceeds the threshold (hence `index + 1`).
                    let mut indices: Vec<_> = predictions
                        .slice(s![.., *idx])
                        .indexed_iter()
                        .filter_map(|(index, &item)| {
                            if item > threshold {
                                Some(index + 1)
                            } else {
                                None
                            }
                        })
                        .collect();

                    // Always terminate the final part exactly at the end of the text.
                    if indices.is_empty() || indices[indices.len() - 1] != text.len() {
                        indices.push(text.len());
                    }

                    let mut parts = Vec::new();
                    let mut prev = 0;

                    for raw_idx in indices {
                        // Skip empty/overlapping ranges (can happen after widening below).
                        if prev >= raw_idx {
                            continue;
                        }

                        let mut idx = raw_idx;

                        // Predictions are per byte, so `raw_idx` may fall inside a
                        // multi-byte UTF-8 character. Widen the part until it ends on
                        // a char boundary (`str::get` returns None on a non-boundary).
                        // `text.len()` is always a valid boundary, so this terminates.
                        let part = loop {
                            if let Some(part) = text.get(prev..idx) {
                                break part;
                            }
                            idx += 1;
                        };

                        // Recurse into the next (smaller) split level with the
                        // matching slice of predictions.
                        parts.push(self.inner_apply(
                            part,
                            predictions.slice(s![prev..idx, ..]),
                            threshold,
                            instruction_idx + 1,
                        ));

                        prev = idx;
                    }

                    Split::Split((text, parts))
                }
                SplitInstruction::Function(func_name) => Split::Split((
                    text,
                    // Split functions return subslices of `text`, so each part's byte
                    // offset can be recovered from its pointer to slice the predictions.
                    (*SPLIT_FUNCTIONS.get(func_name.as_str()).unwrap())(text)
                        .iter()
                        .map(|part| {
                            let start = part.as_ptr() as usize - text.as_ptr() as usize;
                            let end = start + part.len();

                            self.inner_apply(
                                part,
                                predictions.slice(s![start..end, ..]),
                                threshold,
                                instruction_idx + 1,
                            )
                        })
                        .collect::<Vec<Split>>(),
                )),
            }
        } else {
            // No instruction left: this is the lowest split level.
            Split::Text(text)
        }
    }

    // Entry point: applies all instructions, starting from the first one.
    fn apply<'a>(&self, text: &'a str, predictions: ArrayView2<f32>, threshold: f32) -> Split<'a> {
        self.inner_apply(text, predictions, threshold, 0)
    }
}
208
/// Options for splitting text.
///
/// Every field has a default (see the `default_*` functions), and the camelCase
/// serde aliases are accepted when deserializing.
#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct NNSplitOptions {
    /// Threshold from 0 to 1 above which predictions will be considered positive.
    #[serde(default = "NNSplitOptions::default_threshold")]
    pub threshold: f32,
    /// How much to move the window after each prediction (comparable to stride of 1d convolution).
    #[serde(default = "NNSplitOptions::default_stride")]
    pub stride: usize,
    /// The maximum length of each cut (comparable to kernel size of 1d convolution).
    #[serde(alias = "maxLength", default = "NNSplitOptions::default_max_length")]
    pub max_length: usize,
    /// How much to zero pad the text on both sides.
    #[serde(default = "NNSplitOptions::default_padding")]
    pub padding: usize,
    /// Total length will be padded until it is divisible by this number. Allows some additional optimizations.
    #[serde(
        alias = "paddingDivisor",
        default = "NNSplitOptions::default_length_divisor"
    )]
    pub length_divisor: usize,
    /// Batch size to use.
    #[serde(alias = "batchSize", default = "NNSplitOptions::default_batch_size")]
    pub batch_size: usize,
}
235
236impl NNSplitOptions {
237    fn default_threshold() -> f32 {
238        0.8
239    }
240
241    fn default_stride() -> usize {
242        NNSplitOptions::default_max_length() / 2
243    }
244
245    fn default_max_length() -> usize {
246        500
247    }
248
249    fn default_padding() -> usize {
250        5
251    }
252
253    fn default_batch_size() -> usize {
254        256
255    }
256
257    fn default_length_divisor() -> usize {
258        2
259    }
260}
261
262impl Default for NNSplitOptions {
263    fn default() -> Self {
264        NNSplitOptions {
265            threshold: NNSplitOptions::default_threshold(),
266            stride: NNSplitOptions::default_stride(),
267            max_length: NNSplitOptions::default_max_length(),
268            padding: NNSplitOptions::default_padding(),
269            batch_size: NNSplitOptions::default_batch_size(),
270            length_divisor: NNSplitOptions::default_length_divisor(),
271        }
272    }
273}
274
/// The logic by which texts are split.
pub struct NNSplitLogic {
    // Windowing / thresholding configuration (see `NNSplitOptions`).
    options: NNSplitOptions,
    // How model predictions are turned into nested `Split`s.
    split_sequence: SplitSequence,
}
280
impl NNSplitLogic {
    /// Create new logic from options and a split sequence.
    ///
    /// # Panics
    /// - If the options are invalid, e.g. max_length % length_divisor != 0
    pub fn new(options: NNSplitOptions, split_sequence: SplitSequence) -> Self {
        // NOTE(review): presumably required so `max_length`-sized windows stay
        // aligned to `length_divisor` — confirm against the model's input contract.
        if options.max_length % options.length_divisor != 0 {
            panic!("max length must be divisible by length divisor.")
        }

        NNSplitLogic {
            options,
            split_sequence,
        }
    }

    /// Get the underlying NNSplitOptions.
    #[inline]
    pub fn options(&self) -> &NNSplitOptions {
        &self.options
    }

    /// Get the underlying SplitSequence.
    #[inline]
    pub fn split_sequence(&self) -> &SplitSequence {
        &self.split_sequence
    }

    // Length of a text of `length` bytes after adding `padding` zero bytes on
    // both sides, rounded up to the next multiple of `length_divisor`.
    fn pad(&self, length: usize) -> usize {
        let padded = length + self.options.padding * 2;
        let remainder = padded % self.options.length_divisor;

        if remainder == 0 {
            padded
        } else {
            // Round up so the total is divisible by `length_divisor`.
            padded + (self.options.length_divisor - remainder)
        }
    }

    /// Convert texts to neural network inputs. Returns:
    /// * An `ndarray::Array2` which can be fed into the neural network as is.
    /// * A vector of indices with information which positions in the text the array elements correspond to.
    pub fn get_inputs_and_indices(
        &self,
        texts: &[&str],
    ) -> (Array2<u8>, Vec<(usize, Range<usize>)>) {
        // Row width of the input array: the longest padded text, capped at
        // `max_length` (longer texts are covered by several windows).
        let maxlen = cmp::min(
            texts.iter().map(|x| self.pad(x.len())).max().unwrap_or(0),
            self.options.max_length,
        );

        let (all_inputs, all_indices) = texts
            .iter()
            .enumerate()
            .map(|(i, text)| {
                let mut text_inputs: Vec<u8> = Vec::new();
                let mut text_indices: Vec<(usize, Range<usize>)> = Vec::new();

                // Byte buffer for the whole text, zero-padded on both sides.
                let length = self.pad(text.len());
                let mut inputs = vec![0; length];

                for (j, byte) in text.bytes().enumerate() {
                    inputs[j + self.options.padding] = byte;
                }

                let mut start = 0;
                let mut end = 0;

                // Slide windows of at most `max_length` bytes over the buffer,
                // advancing `start` by `stride` each iteration. `start` is
                // re-clamped so the last window is right-aligned and always
                // ends exactly at `length`.
                while end != length {
                    end = cmp::min(start + self.options.max_length, length);
                    start = if self.options.max_length > end {
                        0
                    } else {
                        end - self.options.max_length
                    };

                    // Copy the window into a fixed-width, zero-padded row.
                    let mut input_slice = vec![0u8; maxlen];
                    input_slice[..end - start].copy_from_slice(&inputs[start..end]);

                    text_inputs.extend(input_slice);
                    // Remember which text (i) and which padded byte range this
                    // row covers, so predictions can be recombined later.
                    text_indices.push((i, start..end));

                    start += self.options.stride;
                }

                (text_inputs, text_indices)
            })
            .fold(
                (Vec::<u8>::new(), Vec::<(usize, Range<usize>)>::new()),
                |mut acc, (text_inputs, text_indices)| {
                    acc.0.extend(text_inputs);
                    acc.1.extend(text_indices);

                    acc
                },
            );

        let input_array = Array2::from_shape_vec((all_indices.len(), maxlen), all_inputs).unwrap();
        (input_array, all_indices)
    }

    // Recombines per-window predictions into one prediction array per text.
    // `lengths` are the padded text lengths; positions covered by several
    // overlapping windows are averaged via per-position counts.
    fn combine_predictions(
        &self,
        slice_predictions: ArrayView3<f32>,
        indices: Vec<(usize, Range<usize>)>,
        lengths: Vec<usize>,
    ) -> Vec<Array2<f32>> {
        let pred_dim = slice_predictions.shape()[2];
        // One (sum, count) accumulator pair per text.
        let mut preds_and_counts = lengths
            .iter()
            .map(|x| (Array2::zeros((*x, pred_dim)), Array2::zeros((*x, 1))))
            .collect::<Vec<_>>();

        for (slice_pred, (index, range)) in slice_predictions.outer_iter().zip(indices) {
            let (pred, count) = preds_and_counts
                .get_mut(index)
                .expect("slice index must be in bounds");

            // Only the first `range.end - range.start` rows of the window are
            // real data; the rest is zero-padding up to the fixed row width.
            let mut pred_slice = pred.slice_mut(s![range.start..range.end, ..]);
            pred_slice += &slice_pred.slice(s![..range.end - range.start, ..]);

            let mut count_slice = count.slice_mut(s![range.start..range.end, ..]);
            count_slice += 1f32;
        }

        // Mean over all windows covering each position.
        preds_and_counts
            .into_iter()
            .map(|(pred, count): (Array2<f32>, Array2<f32>)| (pred / count))
            .collect()
    }

    /// Splits the text, given predictions by a neural network and indices
    /// with information which positions in the text the predictions correspond to.
    pub fn split<'a>(
        &self,
        texts: &[&'a str],
        slice_preds: Array3<f32>,
        indices: Vec<(usize, Range<usize>)>,
    ) -> Vec<Split<'a>> {
        let padded_preds = self.combine_predictions(
            (&slice_preds).into(),
            indices,
            texts.iter().map(|x| self.pad(x.len())).collect(),
        );

        // Strip the zero-padding so each prediction array lines up with the
        // bytes of its text again.
        let preds = padded_preds
            .iter()
            .zip(texts)
            .map(|(x, text)| {
                x.slice(s![
                    self.options.padding..self.options.padding + text.len(),
                    ..
                ])
            })
            .collect::<Vec<_>>();

        texts
            .iter()
            .zip(preds)
            .map(|(text, pred)| {
                self.split_sequence
                    .apply(text, pred, self.options.threshold)
            })
            .collect()
    }
}
447
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{thread_rng, Rng};

    // A splitter whose "model" returns uniform random predictions, so the
    // splitting logic can be exercised without a real neural network backend.
    struct DummyNNSplit {
        logic: NNSplitLogic,
    }

    impl DummyNNSplit {
        fn new(options: NNSplitOptions) -> Self {
            DummyNNSplit {
                logic: NNSplitLogic::new(
                    options,
                    SplitSequence::new(vec![
                        (
                            Level("Sentence".into()),
                            SplitInstruction::PredictionIndex(0),
                        ),
                        (Level("Token".into()), SplitInstruction::PredictionIndex(1)),
                        (
                            Level("Whitespace".into()),
                            SplitInstruction::Function("whitespace".into()),
                        ),
                    ]),
                ),
            }
        }

        // Stand-in for a model forward pass: random values in [0, 1), one
        // prediction per byte and per prediction-index level.
        fn predict(&self, input: Array2<u8>) -> Array3<f32> {
            let n = input.shape()[0];
            let length = input.shape()[1];
            let dim = 2usize;

            let mut rng = thread_rng();

            let mut blob = Vec::new();
            for _ in 0..n * length * dim {
                blob.push(rng.gen_range(0.0..1.0));
            }

            Array3::from_shape_vec((n, length, dim), blob).unwrap()
        }

        pub fn split<'a>(&self, texts: &[&'a str]) -> Vec<Split<'a>> {
            let (input, indices) = self.logic.get_inputs_and_indices(texts);
            let slice_preds = self.predict(input);

            self.logic.split(texts, slice_preds, indices)
        }
    }

    #[test]
    fn split_instructions_work() {
        let instructions = SplitSequence::new(vec![
            (Level("Token".into()), SplitInstruction::PredictionIndex(0)),
            (
                Level("Whitespace".into()),
                SplitInstruction::Function("whitespace".into()),
            ),
        ]);

        let input = "This is a test.";
        // One prediction per byte of the input; 1.0 marks a split after that byte.
        let mut predictions = array![[0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1.]];
        predictions.swap_axes(0, 1);
        let predictions: ArrayView2<f32> = (&predictions).into();

        let splits = instructions.apply(input, predictions, 0.5);
        assert_eq!(splits.flatten(0), ["This ", "is ", "a ", "test", "."]);
        assert_eq!(
            splits.flatten(1),
            ["This", " ", "is", " ", "a", " ", "test", "", ".", ""]
        );
    }

    #[test]
    fn splitter_works() {
        // Small window and stride to force multiple overlapping windows per text.
        let options = NNSplitOptions {
            stride: 5,
            max_length: 20,
            ..NNSplitOptions::default()
        };
        let splitter = DummyNNSplit::new(options);

        // sample text must only contain chars which are 1 byte long, so that `DummyNNSplit`
        // can not generate splits which are not char boundaries
        splitter.split(&["This is a short test.", "This is another short test."]);
    }

    #[test]
    fn splitter_works_on_empty_input() {
        let splitter = DummyNNSplit::new(NNSplitOptions::default());

        let splits = splitter.split(&[]);
        assert!(splits.is_empty());
    }

    // Property: splitting never loses or duplicates bytes — at every flatten
    // depth, the leaf lengths sum to the input length (for arbitrary,
    // possibly multi-byte, input).
    #[quickcheck]
    fn length_invariant(text: String) -> bool {
        let splitter = DummyNNSplit::new(NNSplitOptions::default());

        let split = &splitter.split(&[&text])[0];

        let mut sums: Vec<usize> = Vec::new();
        sums.push(split.iter().map(|x| x.text().len()).sum());

        for i in 0..4 {
            sums.push(split.flatten(i).iter().map(|x| x.len()).sum());
        }

        sums.into_iter().all(|sum| sum == text.len())
    }
}