// axonml_data/collate.rs
//! Collate - Batch Assembly Functions
//!
//! # File
//! `crates/axonml-data/src/collate.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
use axonml_tensor::Tensor;
18
19// =============================================================================
20// Collate Trait
21// =============================================================================
22
23/// Trait for collating samples into batches.
24pub trait Collate<T>: Send + Sync {
25    /// The output batch type.
26    type Output;
27
28    /// Collates a vector of samples into a batch.
29    fn collate(&self, batch: Vec<T>) -> Self::Output;
30}
31
32// =============================================================================
33// DefaultCollate
34// =============================================================================
35
36/// Default collation strategy that stacks tensors.
37pub struct DefaultCollate;
38
39impl DefaultCollate {
40    /// Creates a new `DefaultCollate`.
41    #[must_use]
42    pub fn new() -> Self {
43        Self
44    }
45}
46
47impl Default for DefaultCollate {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
54    type Output = (Tensor<f32>, Tensor<f32>);
55
56    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
57        if batch.is_empty() {
58            return (
59                Tensor::from_vec(vec![], &[0]).unwrap(),
60                Tensor::from_vec(vec![], &[0]).unwrap(),
61            );
62        }
63
64        // Stack inputs
65        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
66        let stacked_x = stack_tensors(&inputs);
67
68        // Stack targets
69        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
70        let stacked_y = stack_tensors(&targets);
71
72        (stacked_x, stacked_y)
73    }
74}
75
76impl Collate<Tensor<f32>> for DefaultCollate {
77    type Output = Tensor<f32>;
78
79    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
80        stack_tensors(&batch)
81    }
82}
83
84// =============================================================================
85// StackCollate
86// =============================================================================
87
88/// Collation that stacks tensors along a new batch dimension.
89pub struct StackCollate {
90    /// Dimension to stack along (default: 0).
91    dim: usize,
92}
93
94impl StackCollate {
95    /// Creates a new `StackCollate` with default dimension 0.
96    #[must_use]
97    pub fn new() -> Self {
98        Self { dim: 0 }
99    }
100
101    /// Creates a `StackCollate` with specified dimension.
102    #[must_use]
103    pub fn with_dim(dim: usize) -> Self {
104        Self { dim }
105    }
106}
107
108impl Default for StackCollate {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl Collate<Tensor<f32>> for StackCollate {
115    type Output = Tensor<f32>;
116
117    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
118        if batch.is_empty() {
119            return Tensor::from_vec(vec![], &[0]).unwrap();
120        }
121
122        if self.dim == 0 {
123            return stack_tensors(&batch);
124        }
125
126        // Stack along non-zero dimension:
127        // Insert a new dimension at self.dim, then concatenate.
128        // First, unsqueeze each tensor at self.dim.
129        let first_shape = batch[0].shape();
130        let ndim = first_shape.len();
131        let dim = self.dim.min(ndim); // clamp to valid range
132
133        // Build the shape with new dim inserted
134        let mut item_shape_expanded = Vec::with_capacity(ndim + 1);
135        item_shape_expanded.extend_from_slice(&first_shape[..dim]);
136        item_shape_expanded.push(1);
137        item_shape_expanded.extend_from_slice(&first_shape[dim..]);
138
139        // Reshape each tensor to have size-1 at self.dim, then concat along self.dim
140        let reshaped: Vec<Tensor<f32>> = batch
141            .iter()
142            .map(|t| Tensor::from_vec(t.to_vec(), &item_shape_expanded).unwrap())
143            .collect();
144
145        concat_tensors(&reshaped, dim)
146    }
147}
148
149impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
150    type Output = (Tensor<f32>, Tensor<f32>);
151
152    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
153        if batch.is_empty() {
154            return (
155                Tensor::from_vec(vec![], &[0]).unwrap(),
156                Tensor::from_vec(vec![], &[0]).unwrap(),
157            );
158        }
159
160        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
161        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
162
163        (stack_tensors(&inputs), stack_tensors(&targets))
164    }
165}
166
167// =============================================================================
168// Helper Functions
169// =============================================================================
170
171/// Stacks a vector of tensors along dimension 0.
172#[must_use]
173pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
174    if tensors.is_empty() {
175        return Tensor::from_vec(vec![], &[0]).unwrap();
176    }
177
178    let first_shape = tensors[0].shape();
179    let batch_size = tensors.len();
180
181    // New shape: [batch_size, ...original_shape]
182    let mut new_shape = vec![batch_size];
183    new_shape.extend_from_slice(first_shape);
184
185    // Concatenate all data
186    let mut all_data = Vec::new();
187    for tensor in tensors {
188        all_data.extend(tensor.to_vec());
189    }
190
191    Tensor::from_vec(all_data, &new_shape).unwrap()
192}
193
194/// Concatenates tensors along an existing dimension.
195#[must_use]
196pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
197    if tensors.is_empty() {
198        return Tensor::from_vec(vec![], &[0]).unwrap();
199    }
200
201    if tensors.len() == 1 {
202        return tensors[0].clone();
203    }
204
205    let first_shape = tensors[0].shape();
206    let ndim = first_shape.len();
207
208    // Calculate new shape
209    let mut new_shape = first_shape.to_vec();
210    let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
211    new_shape[dim] = concat_size;
212
213    // For dim=0 concatenation, just append all data
214    if dim == 0 {
215        let mut all_data = Vec::new();
216        for tensor in tensors {
217            all_data.extend(tensor.to_vec());
218        }
219        return Tensor::from_vec(all_data, &new_shape).unwrap();
220    }
221
222    // For non-zero dims, properly interleave data.
223    // outer_size = product of dims before `dim`
224    // inner_size = product of dims after `dim`
225    let outer_size: usize = first_shape[..dim].iter().product();
226    let inner_size: usize = if dim + 1 < ndim {
227        first_shape[dim + 1..].iter().product()
228    } else {
229        1
230    };
231
232    let total_elements: usize = new_shape.iter().product();
233    let mut result = Vec::with_capacity(total_elements);
234
235    // Precompute all tensor data
236    let all_vecs: Vec<Vec<f32>> = tensors.iter().map(|t| t.to_vec()).collect();
237
238    for o in 0..outer_size {
239        // For each tensor, copy its slice along the concat dimension
240        for (t_idx, tensor_data) in all_vecs.iter().enumerate() {
241            let t_dim_size = tensors[t_idx].shape()[dim];
242            let t_inner_stride = t_dim_size * inner_size;
243            let src_offset = o * t_inner_stride;
244            result.extend_from_slice(&tensor_data[src_offset..src_offset + t_inner_stride]);
245        }
246    }
247
248    Tensor::from_vec(result, &new_shape).unwrap()
249}
250
251// =============================================================================
252// Tests
253// =============================================================================
254
255#[cfg(test)]
256mod tests {
257    use super::*;
258
259    #[test]
260    fn test_stack_tensors() {
261        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
262        let t2 = Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap();
263        let t3 = Tensor::from_vec(vec![5.0, 6.0], &[2]).unwrap();
264
265        let stacked = stack_tensors(&[t1, t2, t3]);
266        assert_eq!(stacked.shape(), &[3, 2]);
267        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
268    }
269
270    #[test]
271    fn test_stack_tensors_2d() {
272        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
273        let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
274
275        let stacked = stack_tensors(&[t1, t2]);
276        assert_eq!(stacked.shape(), &[2, 2, 2]);
277    }
278
279    #[test]
280    fn test_default_collate() {
281        let collate = DefaultCollate::new();
282
283        let batch = vec![
284            (
285                Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
286                Tensor::from_vec(vec![0.0], &[1]).unwrap(),
287            ),
288            (
289                Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
290                Tensor::from_vec(vec![1.0], &[1]).unwrap(),
291            ),
292        ];
293
294        let (x, y) = collate.collate(batch);
295        assert_eq!(x.shape(), &[2, 2]);
296        assert_eq!(y.shape(), &[2, 1]);
297    }
298
299    #[test]
300    fn test_stack_collate() {
301        let collate = StackCollate::new();
302
303        let batch = vec![
304            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(),
305            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(),
306        ];
307
308        let result = collate.collate(batch);
309        assert_eq!(result.shape(), &[2, 3]);
310    }
311
312    #[test]
313    fn test_empty_collate() {
314        let collate = DefaultCollate::new();
315        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
316        let (x, y) = collate.collate(batch);
317        assert_eq!(x.shape(), &[0]);
318        assert_eq!(y.shape(), &[0]);
319    }
320
321    #[test]
322    fn test_concat_tensors_dim0() {
323        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
324        let t2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();
325
326        let concat = concat_tensors(&[t1, t2], 0);
327        assert_eq!(concat.shape(), &[5]);
328        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
329    }
330
331    #[test]
332    fn test_concat_tensors_dim1() {
333        // Two [2, 2] tensors → [2, 4]
334        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
335        let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
336
337        let concat = concat_tensors(&[t1, t2], 1);
338        assert_eq!(concat.shape(), &[2, 4]);
339        // Row 0: [1,2,5,6], Row 1: [3,4,7,8]
340        assert_eq!(
341            concat.to_vec(),
342            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
343        );
344    }
345}