entrenar/hf_pipeline/dataset/batch.rs
1//! Batch struct for training
2
3use ndarray::Array2;
4
5/// Batch of examples for training
6#[derive(Debug, Clone)]
7pub struct Batch {
8 /// Input IDs [batch_size, max_seq_len]
9 pub input_ids: Array2<u32>,
10 /// Attention mask [batch_size, max_seq_len]
11 pub attention_mask: Array2<u8>,
12 /// Labels [batch_size, max_seq_len] (optional)
13 pub labels: Option<Array2<u32>>,
14 /// Original sequence lengths
15 pub lengths: Vec<usize>,
16}
17
18impl Batch {
19 /// Get batch size
20 #[must_use]
21 pub fn batch_size(&self) -> usize {
22 self.input_ids.nrows()
23 }
24
25 /// Get maximum sequence length
26 #[must_use]
27 pub fn max_seq_len(&self) -> usize {
28 contract_pre_seq_len_from_data!();
29 self.input_ids.ncols()
30 }
31}