// tokenizers/processors/sequence.rs

1use crate::processors::PostProcessorWrapper;
2use crate::tokenizer::{Encoding, PostProcessor, Result};
3use crate::utils::macro_rules_attribute;
4use serde::{Deserialize, Serialize};
5
6#[derive(Clone, Debug, PartialEq, Eq)]
7#[macro_rules_attribute(impl_serde_type!)]
8pub struct Sequence {
9    processors: Vec<PostProcessorWrapper>,
10}
11
12impl Sequence {
13    pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
14        Self { processors }
15    }
16}
17
18impl PostProcessor for Sequence {
19    fn added_tokens(&self, is_pair: bool) -> usize {
20        self.processors
21            .iter()
22            .map(|p| p.added_tokens(is_pair))
23            .sum::<usize>()
24    }
25
26    fn process_encodings(
27        &self,
28        mut encodings: Vec<Encoding>,
29        add_special_tokens: bool,
30    ) -> Result<Vec<Encoding>> {
31        for processor in &self.processors {
32            encodings = processor.process_encodings(encodings, add_special_tokens)?;
33        }
34        Ok(encodings)
35    }
36}
37
#[cfg(test)]
mod tests {
    use super::*;
    use crate::processors::{ByteLevel, PostProcessorWrapper};
    use crate::tokenizer::{Encoding, PostProcessor};
    use std::collections::HashMap;

    /// The five byte-level tokens that every encoding in this test is built from.
    fn raw_tokens() -> Vec<String> {
        vec![
            "Ġ".into(),
            "ĠĠĠĠHelloĠĠ".into(),
            "ĠĠHello".into(),
            "HelloĠĠ".into(),
            "ĠĠĠĠ".into(),
        ]
    }

    #[test]
    fn process_chain() {
        // Offsets of the raw tokens as fed into the processors.
        let raw_offsets: Vec<(usize, usize)> =
            vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)];
        // Offsets the trimming ByteLevel processor is expected to produce.
        let trimmed_offsets: Vec<(usize, usize)> =
            vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)];

        let start = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            raw_tokens(),
            vec![],
            raw_offsets,
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        );

        let bytelevel = ByteLevel::default().trim_offsets(true);
        let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);

        let expected = Encoding::new(
            vec![0; 5],
            vec![0; 5],
            raw_tokens(),
            vec![],
            trimmed_offsets.clone(),
            vec![],
            vec![],
            vec![],
            vec![(0, 0..5)].into_iter().collect(),
        );

        // A one-element Sequence must behave exactly like its sole processor.
        assert_eq!(
            expected,
            bytelevel.process(start.clone(), None, false).unwrap()
        );
        assert_eq!(
            expected,
            sequence.process(start.clone(), None, false).unwrap()
        );

        // The pair case is the single-sequence fixture repeated twice, with the
        // second half tagged as type id 1 and sequence index 1.
        let mut pair_tokens = raw_tokens();
        pair_tokens.extend(raw_tokens());
        let pair_expected = Encoding::new(
            vec![0; 10],
            vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            pair_tokens,
            vec![],
            trimmed_offsets.repeat(2),
            vec![],
            vec![],
            vec![],
            vec![(0, 0..5), (1, 5..10)].into_iter().collect(),
        );

        assert_eq!(
            pair_expected,
            bytelevel
                .process(start.clone(), Some(start.clone()), false)
                .unwrap()
        );
        assert_eq!(
            pair_expected,
            sequence.process(start.clone(), Some(start), false).unwrap()
        );
    }
}