Skip to main content

entrenar/hf_pipeline/dataset/
batch.rs

//! `Batch` struct for training.

3use ndarray::Array2;
4
5/// Batch of examples for training
6#[derive(Debug, Clone)]
7pub struct Batch {
8    /// Input IDs [batch_size, max_seq_len]
9    pub input_ids: Array2<u32>,
10    /// Attention mask [batch_size, max_seq_len]
11    pub attention_mask: Array2<u8>,
12    /// Labels [batch_size, max_seq_len] (optional)
13    pub labels: Option<Array2<u32>>,
14    /// Original sequence lengths
15    pub lengths: Vec<usize>,
16}
17
18impl Batch {
19    /// Get batch size
20    #[must_use]
21    pub fn batch_size(&self) -> usize {
22        self.input_ids.nrows()
23    }
24
25    /// Get maximum sequence length
26    #[must_use]
27    pub fn max_seq_len(&self) -> usize {
28        contract_pre_seq_len_from_data!();
29        self.input_ids.ncols()
30    }
31}