1use burn::prelude::*;
2use burn::tensor::{Int, TensorData};
3use serde::{Deserialize, Serialize};
4
5use crate::error::DataValidationError;
6
/// Policy for the final window when the remaining tokens are shorter than the
/// configured sequence length.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TailStrategy {
    /// Keep the short tail window and fill it up to `seq_len` with the pad token.
    Pad,
    /// Discard the short tail window entirely.
    Drop,
}
12
/// Parameters controlling how a flat token stream is sliced into fixed-shape batches.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenBatchingConfig {
    /// Maximum number of rows per batch; the final batch may hold fewer.
    pub batch_size: usize,
    /// Length of every row; short tail windows are padded or dropped.
    pub seq_len: usize,
    /// Offset between the starts of consecutive windows.
    /// Defaults to `max(seq_len - 1, 1)` in `try_new`.
    pub stride: usize,
    /// Token id written into padded positions (defaults to 0).
    pub pad_token: usize,
    /// What to do with a final window shorter than `seq_len`.
    pub tail_strategy: TailStrategy,
}
21
22impl TokenBatchingConfig {
23 pub fn new(batch_size: usize, seq_len: usize) -> Self {
24 Self::try_new(batch_size, seq_len)
25 .unwrap_or_else(|error| panic!("invalid token batching configuration: {error}"))
26 }
27
28 pub fn try_new(batch_size: usize, seq_len: usize) -> Result<Self, DataValidationError> {
29 let config = Self {
30 batch_size,
31 seq_len,
32 stride: seq_len.saturating_sub(1).max(1),
33 pad_token: 0,
34 tail_strategy: TailStrategy::Pad,
35 };
36 config.validate()?;
37 Ok(config)
38 }
39
40 pub fn validate(&self) -> Result<(), DataValidationError> {
41 if self.batch_size == 0 {
42 return Err(DataValidationError::InvalidBatchSize(self.batch_size));
43 }
44 if self.seq_len == 0 {
45 return Err(DataValidationError::InvalidSeqLen(self.seq_len));
46 }
47 if self.stride == 0 {
48 return Err(DataValidationError::InvalidStride(self.stride));
49 }
50 Ok(())
51 }
52
53 pub fn with_stride(self, stride: usize) -> Self {
54 self.try_with_stride(stride)
55 .unwrap_or_else(|error| panic!("invalid token batching configuration: {error}"))
56 }
57
58 pub fn try_with_stride(mut self, stride: usize) -> Result<Self, DataValidationError> {
59 self.stride = stride;
60 self.validate()?;
61 Ok(self)
62 }
63
64 pub fn with_pad_token(mut self, pad_token: usize) -> Self {
65 self.pad_token = pad_token;
66 self
67 }
68
69 pub fn with_tail_strategy(mut self, tail_strategy: TailStrategy) -> Self {
70 self.tail_strategy = tail_strategy;
71 self
72 }
73}
74
/// A fixed-shape batch of token rows stored as one flattened buffer.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenBatch {
    // Row-major flattened tokens; length == shape[0] * shape[1].
    tokens: Vec<usize>,
    // [batch_size, seq_len].
    shape: [usize; 2],
    // Per-row count of real (non-pad) tokens.
    sequence_lengths: Vec<usize>,
    // Token id occupying padded positions.
    pad_token: usize,
}
82
83impl TokenBatch {
84 fn new(
85 rows: Vec<Vec<usize>>,
86 sequence_lengths: Vec<usize>,
87 seq_len: usize,
88 pad_token: usize,
89 ) -> Self {
90 assert!(
91 !rows.is_empty(),
92 "token batches must contain at least one row"
93 );
94 assert_eq!(
95 rows.len(),
96 sequence_lengths.len(),
97 "rows and sequence_lengths must align"
98 );
99 assert!(
100 rows.iter().all(|row| row.len() == seq_len),
101 "all rows must match the configured seq_len"
102 );
103
104 let batch_size = rows.len();
105 let tokens = rows.into_iter().flatten().collect();
106
107 Self {
108 tokens,
109 shape: [batch_size, seq_len],
110 sequence_lengths,
111 pad_token,
112 }
113 }
114
115 pub fn batch_size(&self) -> usize {
116 self.shape[0]
117 }
118
119 pub fn seq_len(&self) -> usize {
120 self.shape[1]
121 }
122
123 pub fn sequence_lengths(&self) -> &[usize] {
124 &self.sequence_lengths
125 }
126
127 pub fn num_tokens(&self) -> usize {
128 self.sequence_lengths.iter().sum()
129 }
130
131 pub fn num_predictions(&self) -> usize {
132 self.sequence_lengths
133 .iter()
134 .map(|length| length.saturating_sub(1))
135 .sum()
136 }
137
138 pub fn num_padded_tokens(&self) -> usize {
139 self.batch_size() * self.seq_len() - self.num_tokens()
140 }
141
142 pub fn pad_token(&self) -> usize {
143 self.pad_token
144 }
145
146 pub fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2, Int> {
147 let data = TensorData::new(
148 self.tokens
149 .iter()
150 .map(|token| *token as i64)
151 .collect::<Vec<_>>(),
152 self.shape,
153 );
154 Tensor::<B, 2, Int>::from_data(data, device)
155 }
156}
157
/// Lightweight, copyable statistics describing a [`TokenDataset`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenDatasetSummary {
    /// Length of the original token stream.
    pub num_source_tokens: usize,
    /// FNV-1a fingerprint of the source tokens (see `fingerprint_tokens`).
    pub source_fingerprint: u64,
    /// Number of batches produced.
    pub num_batches: usize,
    /// Total number of windows (rows) across all batches.
    pub num_sequences: usize,
    /// Configured row length.
    pub seq_len: usize,
    /// Configured batch size; the final batch may hold fewer rows.
    pub max_batch_size: usize,
    /// Total next-token prediction targets across all rows.
    pub num_predictions: usize,
    /// Total pad slots across all batches.
    pub num_padded_tokens: usize,
}
169
/// The output of [`TokenBatcher::batch_tokens`]: the batches plus aggregate
/// statistics computed while batching.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenDataset {
    // Configuration the batches were built with.
    config: TokenBatchingConfig,
    // Batches in stream order.
    batches: Vec<TokenBatch>,
    // Length of the original token stream.
    num_source_tokens: usize,
    // FNV-1a fingerprint of the source tokens.
    source_fingerprint: u64,
    // Total windows (rows) across all batches.
    num_sequences: usize,
    // Total next-token prediction targets.
    num_predictions: usize,
    // Total pad slots across all batches.
    num_padded_tokens: usize,
}
180
181impl TokenDataset {
182 fn new(
183 config: TokenBatchingConfig,
184 batches: Vec<TokenBatch>,
185 num_source_tokens: usize,
186 source_fingerprint: u64,
187 num_sequences: usize,
188 num_predictions: usize,
189 num_padded_tokens: usize,
190 ) -> Self {
191 Self {
192 config,
193 batches,
194 num_source_tokens,
195 source_fingerprint,
196 num_sequences,
197 num_predictions,
198 num_padded_tokens,
199 }
200 }
201
202 pub fn batches(&self) -> &[TokenBatch] {
203 &self.batches
204 }
205
206 pub fn config(&self) -> &TokenBatchingConfig {
207 &self.config
208 }
209
210 pub fn num_batches(&self) -> usize {
211 self.batches.len()
212 }
213
214 pub fn source_fingerprint(&self) -> u64 {
215 self.source_fingerprint
216 }
217
218 pub fn num_sequences(&self) -> usize {
219 self.num_sequences
220 }
221
222 pub fn num_predictions(&self) -> usize {
223 self.num_predictions
224 }
225
226 pub fn num_padded_tokens(&self) -> usize {
227 self.num_padded_tokens
228 }
229
230 pub fn summary(&self) -> TokenDatasetSummary {
231 TokenDatasetSummary {
232 num_source_tokens: self.num_source_tokens,
233 source_fingerprint: self.source_fingerprint,
234 num_batches: self.num_batches(),
235 num_sequences: self.num_sequences,
236 seq_len: self.config.seq_len,
237 max_batch_size: self.config.batch_size,
238 num_predictions: self.num_predictions,
239 num_padded_tokens: self.num_padded_tokens,
240 }
241 }
242}
243
/// Slices a flat token stream into fixed-shape [`TokenBatch`]es according to a
/// validated [`TokenBatchingConfig`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenBatcher {
    // Validated at construction time by `new` / `try_new`.
    config: TokenBatchingConfig,
}
248
249impl TokenBatcher {
250 pub fn new(config: TokenBatchingConfig) -> Self {
251 Self::try_new(config)
252 .unwrap_or_else(|error| panic!("invalid token batching configuration: {error}"))
253 }
254
255 pub fn try_new(config: TokenBatchingConfig) -> Result<Self, DataValidationError> {
256 config.validate()?;
257 Ok(Self { config })
258 }
259
260 pub fn config(&self) -> &TokenBatchingConfig {
261 &self.config
262 }
263
264 pub fn batch_tokens(&self, tokens: &[usize]) -> TokenDataset {
265 let windows = self.windows(tokens);
266 let num_sequences = windows.len();
267 let source_fingerprint = fingerprint_tokens(tokens);
268 let mut batches = Vec::new();
269 let mut rows = Vec::with_capacity(self.config.batch_size);
270 let mut sequence_lengths = Vec::with_capacity(self.config.batch_size);
271 let mut num_predictions = 0;
272 let mut num_padded_tokens = 0;
273
274 for (row, valid_len) in windows {
275 num_predictions += valid_len.saturating_sub(1);
276 num_padded_tokens += self.config.seq_len - valid_len;
277 rows.push(row);
278 sequence_lengths.push(valid_len);
279
280 if rows.len() == self.config.batch_size {
281 batches.push(TokenBatch::new(
282 core::mem::take(&mut rows),
283 core::mem::take(&mut sequence_lengths),
284 self.config.seq_len,
285 self.config.pad_token,
286 ));
287 }
288 }
289
290 if !rows.is_empty() {
291 batches.push(TokenBatch::new(
292 rows,
293 sequence_lengths,
294 self.config.seq_len,
295 self.config.pad_token,
296 ));
297 }
298
299 TokenDataset::new(
300 self.config.clone(),
301 batches,
302 tokens.len(),
303 source_fingerprint,
304 num_sequences,
305 num_predictions,
306 num_padded_tokens,
307 )
308 }
309
310 fn windows(&self, tokens: &[usize]) -> Vec<(Vec<usize>, usize)> {
311 if tokens.len() < 2 {
312 return Vec::new();
313 }
314
315 let mut start = 0;
316 let mut windows = Vec::new();
317
318 while start + 1 < tokens.len() {
319 let end = (start + self.config.seq_len).min(tokens.len());
320 let mut row = tokens[start..end].to_vec();
321 let valid_len = row.len();
322
323 if valid_len < 2 {
324 break;
325 }
326
327 if valid_len < self.config.seq_len {
328 if matches!(self.config.tail_strategy, TailStrategy::Drop) {
329 break;
330 }
331 row.resize(self.config.seq_len, self.config.pad_token);
332 }
333
334 windows.push((row, valid_len));
335 start += self.config.stride;
336 }
337
338 windows
339 }
340}
341
/// Computes an FNV-1a hash over the little-endian bytes of every token,
/// then folds in the bytes of the token count so streams of different
/// lengths fingerprint differently even when their token bytes agree.
fn fingerprint_tokens(tokens: &[usize]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    // FNV-1a step applied to each little-endian byte of a u64.
    let fold_value = |hash: u64, value: u64| {
        value.to_le_bytes().iter().fold(hash, |acc, &byte| {
            (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
    };

    let hash = tokens
        .iter()
        .fold(OFFSET_BASIS, |acc, &token| fold_value(acc, token as u64));
    fold_value(hash, tokens.len() as u64)
}