Skip to main content

entrenar/hf_pipeline/dataset/
collator.rs

1//! Distillation collator for batching examples
2
3use ndarray::Array2;
4
5use super::batch::Batch;
6use super::dataset_impl::Dataset;
7use super::example::Example;
8
/// Collator for batching examples with dynamic padding.
///
/// Each batch is padded only up to the longest example it contains (capped at
/// `max_length`), rather than to a fixed global length.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DistillationCollator {
    /// Padding token ID written into padded `input_ids` (and `labels`) positions.
    pub pad_token_id: u32,
    /// Maximum sequence length; longer examples are truncated to this many tokens.
    pub max_length: usize,
    /// Padding side (`true` = left, `false` = right).
    pub pad_left: bool,
}
18
19impl Default for DistillationCollator {
20    fn default() -> Self {
21        Self { pad_token_id: 0, max_length: 512, pad_left: false }
22    }
23}
24
25impl DistillationCollator {
26    /// Create new collator
27    #[must_use]
28    pub fn new(pad_token_id: u32) -> Self {
29        Self { pad_token_id, ..Default::default() }
30    }
31
32    /// Set maximum length
33    #[must_use]
34    pub fn max_length(mut self, len: usize) -> Self {
35        self.max_length = len;
36        self
37    }
38
39    /// Set padding side
40    #[must_use]
41    pub fn pad_left(mut self, left: bool) -> Self {
42        self.pad_left = left;
43        self
44    }
45
46    /// Collate examples into a batch
47    #[must_use]
48    pub fn collate(&self, examples: &[Example]) -> Batch {
49        if examples.is_empty() {
50            return Batch {
51                input_ids: Array2::zeros((0, 0)),
52                attention_mask: Array2::zeros((0, 0)),
53                labels: None,
54                lengths: vec![],
55            };
56        }
57
58        // Find max length in batch (capped at max_length)
59        let max_len = examples.iter().map(|e| e.len().min(self.max_length)).max().unwrap_or(0);
60
61        let batch_size = examples.len();
62        let mut input_ids = Array2::from_elem((batch_size, max_len), self.pad_token_id);
63        let mut attention_mask = Array2::zeros((batch_size, max_len));
64        let mut lengths = Vec::with_capacity(batch_size);
65
66        let has_labels = examples.iter().any(|e| e.labels.is_some());
67        let mut labels = if has_labels {
68            Some(Array2::from_elem((batch_size, max_len), self.pad_token_id))
69        } else {
70            None
71        };
72
73        for (i, example) in examples.iter().enumerate() {
74            let seq_len = example.len().min(self.max_length);
75            lengths.push(seq_len);
76
77            let (start, end) =
78                if self.pad_left { (max_len - seq_len, max_len) } else { (0, seq_len) };
79
80            // Copy input IDs
81            for (j, &token) in example.input_ids.iter().take(seq_len).enumerate() {
82                input_ids[[i, start + j]] = token;
83            }
84
85            // Set attention mask
86            for j in start..end {
87                attention_mask[[i, j]] = 1;
88            }
89
90            // Copy labels if present
91            if let (Some(ref mut label_arr), Some(ref ex_labels)) = (&mut labels, &example.labels) {
92                for (j, &token) in ex_labels.iter().take(seq_len).enumerate() {
93                    label_arr[[i, start + j]] = token;
94                }
95            }
96        }
97
98        Batch { input_ids, attention_mask, labels, lengths }
99    }
100
101    /// Create batches from dataset
102    pub fn batch_dataset(&self, dataset: &Dataset, batch_size: usize) -> Vec<Batch> {
103        dataset.examples().chunks(batch_size).map(|chunk| self.collate(chunk)).collect()
104    }
105}