Skip to main content

entrenar/train/
batch.rs

1//! Batch data structure
2
3use crate::Tensor;
4
5/// A training batch containing inputs and targets
6#[derive(Clone, Debug)]
7pub struct Batch {
8    /// Input features
9    pub inputs: Tensor,
10    /// Target labels/values
11    pub targets: Tensor,
12}
13
14impl Batch {
15    /// Create a new batch
16    pub fn new(inputs: Tensor, targets: Tensor) -> Self {
17        Self { inputs, targets }
18    }
19
20    /// Get batch size (length of inputs)
21    pub fn size(&self) -> usize {
22        self.inputs.len()
23    }
24}
25
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` must store both tensors, and `size` must report the input length.
    #[test]
    fn test_batch_creation() {
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        let targets = Tensor::from_vec(vec![4.0, 5.0, 6.0], false);

        let batch = Batch::new(inputs, targets);

        assert_eq!(batch.size(), 3);
        // Guard against `new` dropping or swapping a tensor: the original
        // test only looked at `size()`, which reads `inputs` alone.
        assert_eq!(batch.inputs.len(), 3);
        assert_eq!(batch.targets.len(), 3);
    }

    /// `size` is defined by `inputs`; a differently sized `targets` must not affect it.
    #[test]
    fn test_batch_size_follows_inputs() {
        let inputs = Tensor::from_vec(vec![1.0, 2.0], false);
        let targets = Tensor::from_vec(vec![0.0], false);

        let batch = Batch::new(inputs, targets);

        assert_eq!(batch.size(), 2);
        assert_eq!(batch.targets.len(), 1);
    }

    /// The derived `Clone` must carry over both tensors unchanged in length.
    #[test]
    fn test_batch_clone_preserves_lengths() {
        let batch = Batch::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], false),
            Tensor::from_vec(vec![4.0, 5.0], false),
        );

        let copy = batch.clone();

        assert_eq!(copy.size(), batch.size());
        assert_eq!(copy.targets.len(), batch.targets.len());
    }
}