tokenizers/utils/
truncation.rs

1use crate::tokenizer::{Encoding, Result};
2use serde::{Deserialize, Serialize};
3use std::cmp;
4use std::mem;
5
6#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
7pub enum TruncationDirection {
8    Left,
9    #[default]
10    Right,
11}
12
13impl std::convert::AsRef<str> for TruncationDirection {
14    fn as_ref(&self) -> &str {
15        match self {
16            TruncationDirection::Left => "left",
17            TruncationDirection::Right => "right",
18        }
19    }
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TruncationParams {
24    #[serde(default)]
25    pub direction: TruncationDirection,
26    pub max_length: usize,
27    pub strategy: TruncationStrategy,
28    pub stride: usize,
29}
30
31impl Default for TruncationParams {
32    fn default() -> Self {
33        Self {
34            max_length: 512,
35            strategy: TruncationStrategy::default(),
36            stride: 0,
37            direction: TruncationDirection::default(),
38        }
39    }
40}
41
42#[derive(thiserror::Error, Debug)]
43pub enum TruncationError {
44    /// We are supposed to truncate the pair sequence, but it has not been provided.
45    #[error("Truncation error: Second sequence not provided")]
46    SecondSequenceNotProvided,
47    /// We cannot truncate the target sequence enough to respect the provided max length.
48    #[error("Truncation error: Sequence to truncate too short to respect the provided max_length")]
49    SequenceTooShort,
50}
51
52#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)]
53pub enum TruncationStrategy {
54    LongestFirst,
55    OnlyFirst,
56    OnlySecond,
57}
58
59impl Default for TruncationStrategy {
60    fn default() -> Self {
61        Self::LongestFirst
62    }
63}
64
65impl std::convert::AsRef<str> for TruncationStrategy {
66    fn as_ref(&self) -> &str {
67        match self {
68            Self::LongestFirst => "longest_first",
69            Self::OnlyFirst => "only_first",
70            Self::OnlySecond => "only_second",
71        }
72    }
73}
74
75pub fn truncate_encodings(
76    mut encoding: Encoding,
77    mut pair_encoding: Option<Encoding>,
78    params: &TruncationParams,
79) -> Result<(Encoding, Option<Encoding>)> {
80    if params.max_length == 0 {
81        encoding.truncate(0, params.stride, params.direction);
82        if let Some(other_encoding) = pair_encoding.as_mut() {
83            other_encoding.truncate(0, params.stride, params.direction);
84        }
85        return Ok((encoding, pair_encoding));
86    }
87
88    let total_length = encoding.get_ids().len()
89        + pair_encoding
90            .as_ref()
91            .map(|e| e.get_ids().len())
92            .unwrap_or(0);
93    let to_remove = if total_length > params.max_length {
94        total_length - params.max_length
95    } else {
96        return Ok((encoding, pair_encoding));
97    };
98
99    match params.strategy {
100        TruncationStrategy::LongestFirst => {
101            if let Some(other_encoding) = pair_encoding.as_mut() {
102                // Assuming n1 <= n2, there are 3 cases
103                // Case 1:
104                //   No truncation needs to be performed.
105                //   This scenario is handled before the match.
106                // Case 2:
107                //   Only the longer input needs to be truncated.
108                //   n1 = n1
109                //   n2 = max_length - n1
110                // Case 3:
111                //   Both inputs must be truncated.
112                //   n1 = max_length / 2
113                //   n2 = n1 + max_length % 2
114
115                let mut n1 = encoding.get_ids().len();
116                let mut n2 = other_encoding.get_ids().len();
117                let mut swap = false;
118
119                // Ensure n1 is the length of the shortest input
120                if n1 > n2 {
121                    swap = true;
122                    mem::swap(&mut n1, &mut n2);
123                }
124
125                if n1 > params.max_length {
126                    // This needs to be a special case
127                    // to avoid max_length - n1 < 0
128                    // since n1 and n2 are unsigned
129                    n2 = n1;
130                } else {
131                    n2 = cmp::max(n1, params.max_length - n1);
132                }
133
134                if n1 + n2 > params.max_length {
135                    n1 = params.max_length / 2;
136                    n2 = n1 + params.max_length % 2;
137                }
138
139                // Swap lengths if we swapped previosuly
140                if swap {
141                    mem::swap(&mut n1, &mut n2);
142                }
143                encoding.truncate(n1, params.stride, params.direction);
144                other_encoding.truncate(n2, params.stride, params.direction);
145            } else {
146                encoding.truncate(total_length - to_remove, params.stride, params.direction);
147            }
148        }
149        TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
150            let target = if params.strategy == TruncationStrategy::OnlyFirst {
151                Ok(&mut encoding)
152            } else if let Some(encoding) = pair_encoding.as_mut() {
153                Ok(encoding)
154            } else {
155                Err(Box::new(TruncationError::SecondSequenceNotProvided))
156            }?;
157
158            let target_len = target.get_ids().len();
159            if target_len > to_remove {
160                target.truncate(target_len - to_remove, params.stride, params.direction);
161            } else {
162                return Err(Box::new(TruncationError::SequenceTooShort));
163            }
164        }
165    }
166    Ok((encoding, pair_encoding))
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::Encoding;
    use std::collections::HashMap;

    /// An encoding without any token.
    fn get_empty() -> Encoding {
        Encoding::new(
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        )
    }

    /// An encoding with 2 tokens.
    fn get_short() -> Encoding {
        Encoding::new(
            vec![1, 2],
            vec![0, 0],
            vec![String::from("a"), String::from("b")],
            vec![Some(0), Some(1)],
            vec![(0, 1), (1, 2)],
            vec![0, 0],
            vec![1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// An encoding with 4 tokens.
    fn get_medium() -> Encoding {
        Encoding::new(
            vec![3, 4, 5, 6],
            vec![0, 0, 0, 0],
            vec![
                String::from("d"),
                String::from("e"),
                String::from("f"),
                String::from("g"),
            ],
            vec![Some(0), Some(1), Some(2), Some(3)],
            vec![(0, 1), (1, 2), (2, 3), (3, 4)],
            vec![0, 0, 0, 0],
            vec![1, 1, 1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// An encoding with 8 tokens.
    fn get_long() -> Encoding {
        Encoding::new(
            vec![7, 8, 9, 10, 11, 12, 13, 14],
            vec![0, 0, 0, 0, 0, 0, 0, 0],
            vec![
                String::from("h"),
                String::from("i"),
                String::from("j"),
                String::from("k"),
                String::from("l"),
                String::from("m"),
                String::from("n"),
                String::from("o"),
            ],
            vec![
                Some(0),
                Some(1),
                Some(2),
                Some(3),
                Some(4),
                Some(5),
                Some(6),
                Some(7),
            ],
            vec![
                (0, 1),
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 5),
                (5, 6),
                (6, 7),
                (7, 8),
            ],
            vec![0, 0, 0, 0, 0, 0, 0, 0],
            vec![1, 1, 1, 1, 1, 1, 1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// Builds truncation parameters with no stride and right-side truncation.
    fn make_params(max_length: usize, strategy: TruncationStrategy) -> TruncationParams {
        TruncationParams {
            max_length,
            strategy,
            stride: 0,
            direction: TruncationDirection::Right,
        }
    }

    /// Truncates a pair and asserts the resulting lengths of both sequences.
    fn truncate_and_assert(
        encoding1: Encoding,
        encoding2: Encoding,
        params: &TruncationParams,
        n1: usize,
        n2: usize,
    ) {
        let (e1, e2) = truncate_encodings(encoding1, Some(encoding2), params)
            .expect("truncation should succeed");
        let e2 = e2.expect("the pair encoding should be returned");
        assert_eq!(e1.get_ids().len(), n1, "length of the first sequence");
        assert_eq!(e2.get_ids().len(), n2, "length of the second sequence");
    }

    /// Runs the truncation and asserts that it fails with `expected`, comparing the errors
    /// through their `Display` output.
    fn truncate_and_assert_error(
        encoding1: Encoding,
        encoding2: Option<Encoding>,
        params: &TruncationParams,
        expected: TruncationError,
    ) {
        match truncate_encodings(encoding1, encoding2, params) {
            Ok(_) => panic!("expected truncation to fail with {expected:?}"),
            Err(err) => assert_eq!(err.to_string(), expected.to_string()),
        }
    }

    #[test]
    fn truncate_encodings_longest_first() {
        let params = TruncationParams {
            max_length: 7,
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_empty(), get_empty(), &params, 0, 0);
        truncate_and_assert(get_empty(), get_short(), &params, 0, 2);
        truncate_and_assert(get_empty(), get_medium(), &params, 0, 4);
        truncate_and_assert(get_empty(), get_long(), &params, 0, 7);

        truncate_and_assert(get_short(), get_empty(), &params, 2, 0);
        truncate_and_assert(get_short(), get_short(), &params, 2, 2);
        truncate_and_assert(get_short(), get_medium(), &params, 2, 4);
        truncate_and_assert(get_short(), get_long(), &params, 2, 5);

        truncate_and_assert(get_medium(), get_empty(), &params, 4, 0);
        truncate_and_assert(get_medium(), get_short(), &params, 4, 2);
        truncate_and_assert(get_medium(), get_medium(), &params, 3, 4);
        truncate_and_assert(get_medium(), get_long(), &params, 3, 4);

        truncate_and_assert(get_long(), get_empty(), &params, 7, 0);
        truncate_and_assert(get_long(), get_short(), &params, 5, 2);
        truncate_and_assert(get_long(), get_medium(), &params, 4, 3);
        truncate_and_assert(get_long(), get_long(), &params, 3, 4);
    }

    #[test]
    fn truncate_single_encoding_longest_first() {
        let params = make_params(7, TruncationStrategy::LongestFirst);

        let (encoding, pair) = truncate_encodings(get_long(), None, &params).unwrap();
        assert_eq!(encoding.get_ids().len(), 7);
        assert!(pair.is_none());

        let (encoding, pair) = truncate_encodings(get_short(), None, &params).unwrap();
        assert_eq!(encoding.get_ids().len(), 2);
        assert!(pair.is_none());
    }

    #[test]
    fn truncate_encodings_only_first() {
        let params = make_params(7, TruncationStrategy::OnlyFirst);

        truncate_and_assert(get_long(), get_short(), &params, 5, 2);
        truncate_and_assert(get_long(), get_medium(), &params, 3, 4);
        // Already within budget: nothing is removed.
        truncate_and_assert(get_medium(), get_short(), &params, 4, 2);

        truncate_and_assert_error(
            get_short(),
            Some(get_long()),
            &params,
            TruncationError::SequenceTooShort,
        );

        // Removing every id of the target is rejected too, not only removing more.
        let tight = make_params(8, TruncationStrategy::OnlyFirst);
        truncate_and_assert_error(
            get_medium(),
            Some(get_long()),
            &tight,
            TruncationError::SequenceTooShort,
        );
    }

    #[test]
    fn truncate_encodings_only_second() {
        let params = make_params(7, TruncationStrategy::OnlySecond);

        truncate_and_assert(get_short(), get_long(), &params, 2, 5);
        truncate_and_assert(get_medium(), get_long(), &params, 4, 3);

        truncate_and_assert_error(
            get_long(),
            Some(get_short()),
            &params,
            TruncationError::SequenceTooShort,
        );
        truncate_and_assert_error(
            get_long(),
            None,
            &params,
            TruncationError::SecondSequenceNotProvided,
        );

        // No truncation needed, so the missing pair is not an error.
        let (encoding, pair) = truncate_encodings(get_short(), None, &params).unwrap();
        assert_eq!(encoding.get_ids().len(), 2);
        assert!(pair.is_none());
    }

    #[test]
    fn truncate_encodings_empty() {
        let params = TruncationParams {
            max_length: 0,
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_empty(), get_short(), &params, 0, 0);
        truncate_and_assert(get_medium(), get_medium(), &params, 0, 0);
        truncate_and_assert(get_long(), get_long(), &params, 0, 0);
    }

    #[test]
    fn test_deserialize_defaults() {
        let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;

        let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();

        assert_eq!(params.direction, TruncationDirection::Right);
    }
}