Skip to main content

entrenar/config/train/batches/
rebatch.rs

1//! Batch re-batching utilities
2
3use crate::train::Batch;
4use crate::Tensor;
5
6/// Re-batch data into specified batch size
7#[allow(dead_code)]
8pub fn rebatch(batches: Vec<Batch>, batch_size: usize) -> Vec<Batch> {
9    // Flatten all data
10    let all_inputs: Vec<f32> =
11        batches.iter().flat_map(|b| b.inputs.data().iter().copied()).collect();
12    let all_targets: Vec<f32> =
13        batches.iter().flat_map(|b| b.targets.data().iter().copied()).collect();
14
15    if all_inputs.is_empty() {
16        return Vec::new();
17    }
18
19    // Determine feature dimensions from first batch
20    let input_dim = batches[0].inputs.len();
21    let target_dim = batches[0].targets.len();
22
23    // Re-batch
24    let num_examples = all_inputs.len() / input_dim;
25    let mut new_batches = Vec::new();
26
27    for chunk_start in (0..num_examples).step_by(batch_size) {
28        let chunk_end = (chunk_start + batch_size).min(num_examples);
29        let input_start = chunk_start * input_dim;
30        let input_end = chunk_end * input_dim;
31        let target_start = chunk_start * target_dim;
32        let target_end = chunk_end * target_dim;
33
34        new_batches.push(Batch::new(
35            Tensor::from_vec(all_inputs[input_start..input_end].to_vec(), false),
36            Tensor::from_vec(all_targets[target_start..target_end].to_vec(), false),
37        ));
38    }
39
40    new_batches
41}