entrenar/hf_pipeline/dataset/
collator.rs1use ndarray::Array2;
4
5use super::batch::Batch;
6use super::dataset_impl::Dataset;
7use super::example::Example;
8
/// Configuration for collating tokenized examples into padded batches.
///
/// Construct one with [`DistillationCollator::new`] or [`Default::default`],
/// then adjust it through the builder-style setters or by writing the public
/// fields directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DistillationCollator {
    /// Token id written into every padding position of a collated batch.
    pub pad_token_id: u32,
    /// Maximum number of tokens kept per example; longer examples are truncated.
    pub max_length: usize,
    /// When `true`, padding is placed before the tokens; otherwise after them.
    pub pad_left: bool,
}
18
19impl Default for DistillationCollator {
20 fn default() -> Self {
21 Self { pad_token_id: 0, max_length: 512, pad_left: false }
22 }
23}
24
25impl DistillationCollator {
26 #[must_use]
28 pub fn new(pad_token_id: u32) -> Self {
29 Self { pad_token_id, ..Default::default() }
30 }
31
32 #[must_use]
34 pub fn max_length(mut self, len: usize) -> Self {
35 self.max_length = len;
36 self
37 }
38
39 #[must_use]
41 pub fn pad_left(mut self, left: bool) -> Self {
42 self.pad_left = left;
43 self
44 }
45
46 #[must_use]
48 pub fn collate(&self, examples: &[Example]) -> Batch {
49 if examples.is_empty() {
50 return Batch {
51 input_ids: Array2::zeros((0, 0)),
52 attention_mask: Array2::zeros((0, 0)),
53 labels: None,
54 lengths: vec![],
55 };
56 }
57
58 let max_len = examples.iter().map(|e| e.len().min(self.max_length)).max().unwrap_or(0);
60
61 let batch_size = examples.len();
62 let mut input_ids = Array2::from_elem((batch_size, max_len), self.pad_token_id);
63 let mut attention_mask = Array2::zeros((batch_size, max_len));
64 let mut lengths = Vec::with_capacity(batch_size);
65
66 let has_labels = examples.iter().any(|e| e.labels.is_some());
67 let mut labels = if has_labels {
68 Some(Array2::from_elem((batch_size, max_len), self.pad_token_id))
69 } else {
70 None
71 };
72
73 for (i, example) in examples.iter().enumerate() {
74 let seq_len = example.len().min(self.max_length);
75 lengths.push(seq_len);
76
77 let (start, end) =
78 if self.pad_left { (max_len - seq_len, max_len) } else { (0, seq_len) };
79
80 for (j, &token) in example.input_ids.iter().take(seq_len).enumerate() {
82 input_ids[[i, start + j]] = token;
83 }
84
85 for j in start..end {
87 attention_mask[[i, j]] = 1;
88 }
89
90 if let (Some(ref mut label_arr), Some(ref ex_labels)) = (&mut labels, &example.labels) {
92 for (j, &token) in ex_labels.iter().take(seq_len).enumerate() {
93 label_arr[[i, start + j]] = token;
94 }
95 }
96 }
97
98 Batch { input_ids, attention_mask, labels, lengths }
99 }
100
101 pub fn batch_dataset(&self, dataset: &Dataset, batch_size: usize) -> Vec<Batch> {
103 dataset.examples().chunks(batch_size).map(|chunk| self.collate(chunk)).collect()
104 }
105}