1use crate::Tensor;
4
5#[derive(Clone, Debug)]
7pub struct Batch {
8 pub inputs: Tensor,
10 pub targets: Tensor,
12}
13
14impl Batch {
15 pub fn new(inputs: Tensor, targets: Tensor) -> Self {
17 Self { inputs, targets }
18 }
19
20 pub fn size(&self) -> usize {
22 self.inputs.len()
23 }
24}
25
#[cfg(test)]
mod tests {
    use super::*;

    /// Building a batch from two 3-element tensors yields a batch whose
    /// reported size matches the input tensor's length.
    #[test]
    fn test_batch_creation() {
        let batch = Batch::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], false),
            Tensor::from_vec(vec![4.0, 5.0, 6.0], false),
        );

        let expected_size = 3;
        assert_eq!(batch.size(), expected_size);
    }
}