1use crate::tokenizer::{Encoding, Result};
2use serde::{Deserialize, Serialize};
3use std::cmp;
4use std::mem;
5
/// Which side of a sequence the truncation is applied to.
///
/// The value is forwarded unchanged to `Encoding::truncate`. `Left` presumably drops ids
/// from the start and `Right` from the end — confirm in `Encoding::truncate`, which is not
/// visible here. The derived `Default` is `Right`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
pub enum TruncationDirection {
    /// Truncate from the left-hand side of the sequence.
    Left,
    /// Truncate from the right-hand side of the sequence (the default).
    #[default]
    Right,
}
12
13impl std::convert::AsRef<str> for TruncationDirection {
14 fn as_ref(&self) -> &str {
15 match self {
16 TruncationDirection::Left => "left",
17 TruncationDirection::Right => "right",
18 }
19 }
20}
21
/// Parameters controlling how [`truncate_encodings`] shortens a sequence or a pair.
///
/// Keep the field order as-is: serde serializes struct fields in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationParams {
    /// Side to truncate from. Falls back to `TruncationDirection::default()` when the
    /// field is missing from serialized input, so configs written before this field
    /// existed still deserialize.
    #[serde(default)]
    pub direction: TruncationDirection,
    /// Maximum combined number of ids across the sequence and its optional pair.
    pub max_length: usize,
    /// How the excess ids are taken from the two sequences.
    pub strategy: TruncationStrategy,
    /// Forwarded unchanged to every `Encoding::truncate` call (presumably the overlap
    /// kept between overflowing chunks — confirm in `Encoding::truncate`).
    pub stride: usize,
}
30
31impl Default for TruncationParams {
32 fn default() -> Self {
33 Self {
34 max_length: 512,
35 strategy: TruncationStrategy::default(),
36 stride: 0,
37 direction: TruncationDirection::default(),
38 }
39 }
40}
41
/// Errors reported by [`truncate_encodings`] when the chosen strategy cannot be applied.
#[derive(thiserror::Error, Debug)]
pub enum TruncationError {
    /// `OnlySecond` truncation was needed but no pair encoding was supplied.
    #[error("Truncation error: Second sequence not provided")]
    SecondSequenceNotProvided,
    /// The sequence selected by `OnlyFirst`/`OnlySecond` has no more ids than must be
    /// removed, so it cannot absorb the whole excess on its own.
    #[error("Truncation error: Sequence to truncate too short to respect the provided max_length")]
    SequenceTooShort,
}
51
52#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq)]
53pub enum TruncationStrategy {
54 LongestFirst,
55 OnlyFirst,
56 OnlySecond,
57}
58
59impl Default for TruncationStrategy {
60 fn default() -> Self {
61 Self::LongestFirst
62 }
63}
64
65impl std::convert::AsRef<str> for TruncationStrategy {
66 fn as_ref(&self) -> &str {
67 match self {
68 Self::LongestFirst => "longest_first",
69 Self::OnlyFirst => "only_first",
70 Self::OnlySecond => "only_second",
71 }
72 }
73}
74
75pub fn truncate_encodings(
76 mut encoding: Encoding,
77 mut pair_encoding: Option<Encoding>,
78 params: &TruncationParams,
79) -> Result<(Encoding, Option<Encoding>)> {
80 if params.max_length == 0 {
81 encoding.truncate(0, params.stride, params.direction);
82 if let Some(other_encoding) = pair_encoding.as_mut() {
83 other_encoding.truncate(0, params.stride, params.direction);
84 }
85 return Ok((encoding, pair_encoding));
86 }
87
88 let total_length = encoding.get_ids().len()
89 + pair_encoding
90 .as_ref()
91 .map(|e| e.get_ids().len())
92 .unwrap_or(0);
93 let to_remove = if total_length > params.max_length {
94 total_length - params.max_length
95 } else {
96 return Ok((encoding, pair_encoding));
97 };
98
99 match params.strategy {
100 TruncationStrategy::LongestFirst => {
101 if let Some(other_encoding) = pair_encoding.as_mut() {
102 let mut n1 = encoding.get_ids().len();
116 let mut n2 = other_encoding.get_ids().len();
117 let mut swap = false;
118
119 if n1 > n2 {
121 swap = true;
122 mem::swap(&mut n1, &mut n2);
123 }
124
125 if n1 > params.max_length {
126 n2 = n1;
130 } else {
131 n2 = cmp::max(n1, params.max_length - n1);
132 }
133
134 if n1 + n2 > params.max_length {
135 n1 = params.max_length / 2;
136 n2 = n1 + params.max_length % 2;
137 }
138
139 if swap {
141 mem::swap(&mut n1, &mut n2);
142 }
143 encoding.truncate(n1, params.stride, params.direction);
144 other_encoding.truncate(n2, params.stride, params.direction);
145 } else {
146 encoding.truncate(total_length - to_remove, params.stride, params.direction);
147 }
148 }
149 TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
150 let target = if params.strategy == TruncationStrategy::OnlyFirst {
151 Ok(&mut encoding)
152 } else if let Some(encoding) = pair_encoding.as_mut() {
153 Ok(encoding)
154 } else {
155 Err(Box::new(TruncationError::SecondSequenceNotProvided))
156 }?;
157
158 let target_len = target.get_ids().len();
159 if target_len > to_remove {
160 target.truncate(target_len - to_remove, params.stride, params.direction);
161 } else {
162 return Err(Box::new(TruncationError::SequenceTooShort));
163 }
164 }
165 }
166 Ok((encoding, pair_encoding))
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::Encoding;
    use std::collections::HashMap;

    /// An encoding with no ids.
    fn get_empty() -> Encoding {
        Encoding::new(
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        )
    }

    /// An encoding with 2 ids.
    fn get_short() -> Encoding {
        Encoding::new(
            vec![1, 2],
            vec![0, 0],
            vec![String::from("a"), String::from("b")],
            vec![Some(0), Some(1)],
            vec![(0, 1), (1, 2)],
            vec![0, 0],
            vec![1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// An encoding with 4 ids.
    fn get_medium() -> Encoding {
        Encoding::new(
            vec![3, 4, 5, 6],
            vec![0, 0, 0, 0],
            vec![
                String::from("d"),
                String::from("e"),
                String::from("f"),
                String::from("g"),
            ],
            vec![Some(0), Some(1), Some(2), Some(3)],
            vec![(0, 1), (1, 2), (2, 3), (3, 4)],
            vec![0, 0, 0, 0],
            vec![1, 1, 1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// An encoding with 8 ids.
    fn get_long() -> Encoding {
        Encoding::new(
            vec![7, 8, 9, 10, 11, 12, 13, 14],
            vec![0, 0, 0, 0, 0, 0, 0, 0],
            vec![
                String::from("h"),
                String::from("i"),
                String::from("j"),
                String::from("k"),
                String::from("l"),
                String::from("m"),
                String::from("n"),
                String::from("o"),
            ],
            vec![
                Some(0),
                Some(1),
                Some(2),
                Some(3),
                Some(4),
                Some(5),
                Some(6),
                Some(7),
            ],
            vec![
                (0, 1),
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 5),
                (5, 6),
                (6, 7),
                (7, 8),
            ],
            vec![0, 0, 0, 0, 0, 0, 0, 0],
            vec![1, 1, 1, 1, 1, 1, 1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// Truncates the pair with `params` and checks the resulting lengths are `n1` and `n2`.
    fn truncate_and_assert(
        encoding1: Encoding,
        encoding2: Encoding,
        params: &TruncationParams,
        n1: usize,
        n2: usize,
    ) {
        match truncate_encodings(encoding1, Some(encoding2), params) {
            Ok((e1, Some(e2))) => {
                assert_eq!(e1.get_ids().len(), n1);
                assert_eq!(e2.get_ids().len(), n2);
            }
            Ok((_, None)) => panic!("pair encoding was dropped by truncation"),
            Err(err) => panic!("unexpected truncation error: {}", err),
        }
    }

    /// Truncates with `params` and checks that it fails with the `expected` message.
    fn truncate_and_expect_error(
        encoding1: Encoding,
        encoding2: Option<Encoding>,
        params: &TruncationParams,
        expected: &str,
    ) {
        match truncate_encodings(encoding1, encoding2, params) {
            Err(err) => assert_eq!(err.to_string(), expected),
            Ok(_) => panic!("expected truncation to fail with: {}", expected),
        }
    }

    #[test]
    fn truncate_encodings_longest_first() {
        let params = TruncationParams {
            max_length: 7,
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_empty(), get_empty(), &params, 0, 0);
        truncate_and_assert(get_empty(), get_short(), &params, 0, 2);
        truncate_and_assert(get_empty(), get_medium(), &params, 0, 4);
        truncate_and_assert(get_empty(), get_long(), &params, 0, 7);

        truncate_and_assert(get_short(), get_empty(), &params, 2, 0);
        truncate_and_assert(get_short(), get_short(), &params, 2, 2);
        truncate_and_assert(get_short(), get_medium(), &params, 2, 4);
        truncate_and_assert(get_short(), get_long(), &params, 2, 5);

        truncate_and_assert(get_medium(), get_empty(), &params, 4, 0);
        truncate_and_assert(get_medium(), get_short(), &params, 4, 2);
        truncate_and_assert(get_medium(), get_medium(), &params, 3, 4);
        truncate_and_assert(get_medium(), get_long(), &params, 3, 4);

        truncate_and_assert(get_long(), get_empty(), &params, 7, 0);
        truncate_and_assert(get_long(), get_short(), &params, 5, 2);
        truncate_and_assert(get_long(), get_medium(), &params, 4, 3);
        truncate_and_assert(get_long(), get_long(), &params, 3, 4);
    }

    #[test]
    fn truncate_encodings_single_sequence() {
        let params = TruncationParams {
            max_length: 7,
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        match truncate_encodings(get_long(), None, &params) {
            Ok((encoding, None)) => assert_eq!(encoding.get_ids().len(), 7),
            Ok((_, Some(_))) => panic!("no pair encoding was provided"),
            Err(err) => panic!("unexpected truncation error: {}", err),
        }
    }

    #[test]
    fn truncate_encodings_only_first() {
        let params = TruncationParams {
            max_length: 7,
            strategy: TruncationStrategy::OnlyFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_long(), get_short(), &params, 5, 2);
        truncate_and_assert(get_long(), get_medium(), &params, 3, 4);
        truncate_and_expect_error(
            get_short(),
            Some(get_long()),
            &params,
            "Truncation error: Sequence to truncate too short to respect the provided max_length",
        );
    }

    #[test]
    fn truncate_encodings_only_second() {
        let params = TruncationParams {
            max_length: 7,
            strategy: TruncationStrategy::OnlySecond,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_short(), get_long(), &params, 2, 5);
        truncate_and_assert(get_medium(), get_long(), &params, 4, 3);
        truncate_and_expect_error(
            get_long(),
            Some(get_short()),
            &params,
            "Truncation error: Sequence to truncate too short to respect the provided max_length",
        );
        truncate_and_expect_error(
            get_long(),
            None,
            &params,
            "Truncation error: Second sequence not provided",
        );
    }

    #[test]
    fn truncate_encodings_empty() {
        let params = TruncationParams {
            max_length: 0,
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_empty(), get_short(), &params, 0, 0);
        truncate_and_assert(get_medium(), get_medium(), &params, 0, 0);
        truncate_and_assert(get_long(), get_long(), &params, 0, 0);
    }

    #[test]
    fn test_deserialize_defaults() {
        let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;

        let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();

        assert_eq!(params.direction, TruncationDirection::Right);
    }
}