entrenar/config/train/batches/rebatch.rs
1//! Batch re-batching utilities
2
3use crate::train::Batch;
4use crate::Tensor;
5
6/// Re-batch data into specified batch size
7#[allow(dead_code)]
8pub fn rebatch(batches: Vec<Batch>, batch_size: usize) -> Vec<Batch> {
9 // Flatten all data
10 let all_inputs: Vec<f32> =
11 batches.iter().flat_map(|b| b.inputs.data().iter().copied()).collect();
12 let all_targets: Vec<f32> =
13 batches.iter().flat_map(|b| b.targets.data().iter().copied()).collect();
14
15 if all_inputs.is_empty() {
16 return Vec::new();
17 }
18
19 // Determine feature dimensions from first batch
20 let input_dim = batches[0].inputs.len();
21 let target_dim = batches[0].targets.len();
22
23 // Re-batch
24 let num_examples = all_inputs.len() / input_dim;
25 let mut new_batches = Vec::new();
26
27 for chunk_start in (0..num_examples).step_by(batch_size) {
28 let chunk_end = (chunk_start + batch_size).min(num_examples);
29 let input_start = chunk_start * input_dim;
30 let input_end = chunk_end * input_dim;
31 let target_start = chunk_start * target_dim;
32 let target_end = chunk_end * target_dim;
33
34 new_batches.push(Batch::new(
35 Tensor::from_vec(all_inputs[input_start..input_end].to_vec(), false),
36 Tensor::from_vec(all_targets[target_start..target_end].to_vec(), false),
37 ));
38 }
39
40 new_batches
41}