use ndarray::Array2;
/// A group of 2-D integer arrays bundled together with per-item lengths.
///
/// The row count of `input_ids` is what [`Batch::batch_size`] reports and its
/// column count is what [`Batch::max_seq_len`] reports.
#[derive(Clone, Debug)]
pub struct Batch {
    /// 2-D array of `u32` ids; its shape defines the batch dimensions.
    pub input_ids: Array2<u32>,
    /// 2-D array of `u8` mask values.
    // NOTE(review): presumably the same shape as `input_ids`, with non-zero
    // marking real (non-padding) positions — confirm against the producer.
    pub attention_mask: Array2<u8>,
    /// Optional 2-D array of `u32` labels; `None` when no labels are attached.
    pub labels: Option<Array2<u32>>,
    /// One `usize` per entry.
    // NOTE(review): presumably the unpadded length of each row of
    // `input_ids` — confirm against the producer.
    pub lengths: Vec<usize>,
}
impl Batch {
    /// Returns the number of rows in `input_ids`.
    #[must_use]
    pub fn batch_size(&self) -> usize {
        let (rows, _cols) = self.input_ids.dim();
        rows
    }

    /// Returns the number of columns in `input_ids`.
    #[must_use]
    pub fn max_seq_len(&self) -> usize {
        let (_rows, cols) = self.input_ids.dim();
        cols
    }
}