// deep_delta_learning/data.rs

use burn::prelude::*;
use burn::tensor::{Int, TensorData};
use serde::{Deserialize, Serialize};

use crate::error::DataValidationError;
/// How to handle the final window when the token stream does not divide
/// evenly into windows of `seq_len` tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TailStrategy {
    /// Keep the short tail window and right-pad it with the configured pad token.
    Pad,
    /// Discard the short tail window entirely.
    Drop,
}
/// Configuration for slicing a flat token stream into fixed-shape batches.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenBatchingConfig {
    /// Maximum number of rows per batch (the final batch may be smaller).
    pub batch_size: usize,
    /// Number of token columns in every row; shorter tail rows are padded or
    /// dropped according to `tail_strategy`.
    pub seq_len: usize,
    /// Distance between the start indices of consecutive windows.
    pub stride: usize,
    /// Token id used to fill padded positions.
    pub pad_token: usize,
    /// Policy for the final, possibly short, window.
    pub tail_strategy: TailStrategy,
}
22impl TokenBatchingConfig {
23    pub fn new(batch_size: usize, seq_len: usize) -> Self {
24        Self::try_new(batch_size, seq_len)
25            .unwrap_or_else(|error| panic!("invalid token batching configuration: {error}"))
26    }
27
28    pub fn try_new(batch_size: usize, seq_len: usize) -> Result<Self, DataValidationError> {
29        let config = Self {
30            batch_size,
31            seq_len,
32            stride: seq_len.saturating_sub(1).max(1),
33            pad_token: 0,
34            tail_strategy: TailStrategy::Pad,
35        };
36        config.validate()?;
37        Ok(config)
38    }
39
40    pub fn validate(&self) -> Result<(), DataValidationError> {
41        if self.batch_size == 0 {
42            return Err(DataValidationError::InvalidBatchSize(self.batch_size));
43        }
44        if self.seq_len == 0 {
45            return Err(DataValidationError::InvalidSeqLen(self.seq_len));
46        }
47        if self.stride == 0 {
48            return Err(DataValidationError::InvalidStride(self.stride));
49        }
50        Ok(())
51    }
52
53    pub fn with_stride(self, stride: usize) -> Self {
54        self.try_with_stride(stride)
55            .unwrap_or_else(|error| panic!("invalid token batching configuration: {error}"))
56    }
57
58    pub fn try_with_stride(mut self, stride: usize) -> Result<Self, DataValidationError> {
59        self.stride = stride;
60        self.validate()?;
61        Ok(self)
62    }
63
64    pub fn with_pad_token(mut self, pad_token: usize) -> Self {
65        self.pad_token = pad_token;
66        self
67    }
68
69    pub fn with_tail_strategy(mut self, tail_strategy: TailStrategy) -> Self {
70        self.tail_strategy = tail_strategy;
71        self
72    }
73}
/// A dense `[batch_size, seq_len]` grid of token ids plus the unpadded
/// length of each row.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenBatch {
    // Row-major token ids; length is shape[0] * shape[1].
    tokens: Vec<usize>,
    // [batch_size, seq_len].
    shape: [usize; 2],
    // Number of real (non-pad) tokens per row; one entry per row.
    sequence_lengths: Vec<usize>,
    // Token id occupying padded positions.
    pad_token: usize,
}
83impl TokenBatch {
84    fn new(
85        rows: Vec<Vec<usize>>,
86        sequence_lengths: Vec<usize>,
87        seq_len: usize,
88        pad_token: usize,
89    ) -> Self {
90        assert!(
91            !rows.is_empty(),
92            "token batches must contain at least one row"
93        );
94        assert_eq!(
95            rows.len(),
96            sequence_lengths.len(),
97            "rows and sequence_lengths must align"
98        );
99        assert!(
100            rows.iter().all(|row| row.len() == seq_len),
101            "all rows must match the configured seq_len"
102        );
103
104        let batch_size = rows.len();
105        let tokens = rows.into_iter().flatten().collect();
106
107        Self {
108            tokens,
109            shape: [batch_size, seq_len],
110            sequence_lengths,
111            pad_token,
112        }
113    }
114
115    pub fn batch_size(&self) -> usize {
116        self.shape[0]
117    }
118
119    pub fn seq_len(&self) -> usize {
120        self.shape[1]
121    }
122
123    pub fn sequence_lengths(&self) -> &[usize] {
124        &self.sequence_lengths
125    }
126
127    pub fn num_tokens(&self) -> usize {
128        self.sequence_lengths.iter().sum()
129    }
130
131    pub fn num_predictions(&self) -> usize {
132        self.sequence_lengths
133            .iter()
134            .map(|length| length.saturating_sub(1))
135            .sum()
136    }
137
138    pub fn num_padded_tokens(&self) -> usize {
139        self.batch_size() * self.seq_len() - self.num_tokens()
140    }
141
142    pub fn pad_token(&self) -> usize {
143        self.pad_token
144    }
145
146    pub fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2, Int> {
147        let data = TensorData::new(
148            self.tokens
149                .iter()
150                .map(|token| *token as i64)
151                .collect::<Vec<_>>(),
152            self.shape,
153        );
154        Tensor::<B, 2, Int>::from_data(data, device)
155    }
156}
/// Copyable snapshot of a [`TokenDataset`]'s bookkeeping counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenDatasetSummary {
    /// Number of tokens in the original source stream.
    pub num_source_tokens: usize,
    /// FNV-1a fingerprint of the source stream (see `fingerprint_tokens`).
    pub source_fingerprint: u64,
    /// Number of batches produced.
    pub num_batches: usize,
    /// Total windowed sequences across all batches.
    pub num_sequences: usize,
    /// Configured window width.
    pub seq_len: usize,
    /// Configured batch size; the final batch may be smaller.
    pub max_batch_size: usize,
    /// Total next-token prediction targets across all sequences.
    pub num_predictions: usize,
    /// Total pad-token positions across all sequences.
    pub num_padded_tokens: usize,
}
/// Pre-batched token windows plus the bookkeeping totals accumulated while
/// batching; produced by `TokenBatcher::batch_tokens`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenDataset {
    // Configuration the batches were produced with.
    config: TokenBatchingConfig,
    // Batches in source order; all but possibly the last hold batch_size rows.
    batches: Vec<TokenBatch>,
    // Token count of the original source stream.
    num_source_tokens: usize,
    // FNV-1a fingerprint of the source stream.
    source_fingerprint: u64,
    // Total windowed sequences across all batches.
    num_sequences: usize,
    // Total next-token prediction targets.
    num_predictions: usize,
    // Total pad-token positions.
    num_padded_tokens: usize,
}
181impl TokenDataset {
182    fn new(
183        config: TokenBatchingConfig,
184        batches: Vec<TokenBatch>,
185        num_source_tokens: usize,
186        source_fingerprint: u64,
187        num_sequences: usize,
188        num_predictions: usize,
189        num_padded_tokens: usize,
190    ) -> Self {
191        Self {
192            config,
193            batches,
194            num_source_tokens,
195            source_fingerprint,
196            num_sequences,
197            num_predictions,
198            num_padded_tokens,
199        }
200    }
201
202    pub fn batches(&self) -> &[TokenBatch] {
203        &self.batches
204    }
205
206    pub fn config(&self) -> &TokenBatchingConfig {
207        &self.config
208    }
209
210    pub fn num_batches(&self) -> usize {
211        self.batches.len()
212    }
213
214    pub fn source_fingerprint(&self) -> u64 {
215        self.source_fingerprint
216    }
217
218    pub fn num_sequences(&self) -> usize {
219        self.num_sequences
220    }
221
222    pub fn num_predictions(&self) -> usize {
223        self.num_predictions
224    }
225
226    pub fn num_padded_tokens(&self) -> usize {
227        self.num_padded_tokens
228    }
229
230    pub fn summary(&self) -> TokenDatasetSummary {
231        TokenDatasetSummary {
232            num_source_tokens: self.num_source_tokens,
233            source_fingerprint: self.source_fingerprint,
234            num_batches: self.num_batches(),
235            num_sequences: self.num_sequences,
236            seq_len: self.config.seq_len,
237            max_batch_size: self.config.batch_size,
238            num_predictions: self.num_predictions,
239            num_padded_tokens: self.num_padded_tokens,
240        }
241    }
242}
/// Slices flat token streams into [`TokenDataset`]s according to a validated
/// [`TokenBatchingConfig`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenBatcher {
    // Invariant: `config.validate()` succeeded at construction time.
    config: TokenBatchingConfig,
}
249impl TokenBatcher {
250    pub fn new(config: TokenBatchingConfig) -> Self {
251        Self::try_new(config)
252            .unwrap_or_else(|error| panic!("invalid token batching configuration: {error}"))
253    }
254
255    pub fn try_new(config: TokenBatchingConfig) -> Result<Self, DataValidationError> {
256        config.validate()?;
257        Ok(Self { config })
258    }
259
260    pub fn config(&self) -> &TokenBatchingConfig {
261        &self.config
262    }
263
264    pub fn batch_tokens(&self, tokens: &[usize]) -> TokenDataset {
265        let windows = self.windows(tokens);
266        let num_sequences = windows.len();
267        let source_fingerprint = fingerprint_tokens(tokens);
268        let mut batches = Vec::new();
269        let mut rows = Vec::with_capacity(self.config.batch_size);
270        let mut sequence_lengths = Vec::with_capacity(self.config.batch_size);
271        let mut num_predictions = 0;
272        let mut num_padded_tokens = 0;
273
274        for (row, valid_len) in windows {
275            num_predictions += valid_len.saturating_sub(1);
276            num_padded_tokens += self.config.seq_len - valid_len;
277            rows.push(row);
278            sequence_lengths.push(valid_len);
279
280            if rows.len() == self.config.batch_size {
281                batches.push(TokenBatch::new(
282                    core::mem::take(&mut rows),
283                    core::mem::take(&mut sequence_lengths),
284                    self.config.seq_len,
285                    self.config.pad_token,
286                ));
287            }
288        }
289
290        if !rows.is_empty() {
291            batches.push(TokenBatch::new(
292                rows,
293                sequence_lengths,
294                self.config.seq_len,
295                self.config.pad_token,
296            ));
297        }
298
299        TokenDataset::new(
300            self.config.clone(),
301            batches,
302            tokens.len(),
303            source_fingerprint,
304            num_sequences,
305            num_predictions,
306            num_padded_tokens,
307        )
308    }
309
310    fn windows(&self, tokens: &[usize]) -> Vec<(Vec<usize>, usize)> {
311        if tokens.len() < 2 {
312            return Vec::new();
313        }
314
315        let mut start = 0;
316        let mut windows = Vec::new();
317
318        while start + 1 < tokens.len() {
319            let end = (start + self.config.seq_len).min(tokens.len());
320            let mut row = tokens[start..end].to_vec();
321            let valid_len = row.len();
322
323            if valid_len < 2 {
324                break;
325            }
326
327            if valid_len < self.config.seq_len {
328                if matches!(self.config.tail_strategy, TailStrategy::Drop) {
329                    break;
330                }
331                row.resize(self.config.seq_len, self.config.pad_token);
332            }
333
334            windows.push((row, valid_len));
335            start += self.config.stride;
336        }
337
338        windows
339    }
340}
/// Computes a stable FNV-1a fingerprint of the token stream.
///
/// Each token is hashed via its little-endian `u64` bytes, and the stream
/// length is folded in last so that streams differing only by trailing
/// content length do not collide trivially.
fn fingerprint_tokens(tokens: &[usize]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    // FNV-1a step: xor each byte in, then multiply by the prime.
    let mix_word = |mut acc: u64, word: u64| -> u64 {
        for byte in word.to_le_bytes() {
            acc = (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME);
        }
        acc
    };

    let hashed_tokens = tokens
        .iter()
        .fold(OFFSET_BASIS, |acc, &token| mix_word(acc, token as u64));
    mix_word(hashed_tokens, tokens.len() as u64)
}