// axonml_data — collate.rs
1//! Collate - Batch Assembly Functions
2//!
3//! # File
4//! `crates/axonml-data/src/collate.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
//! April 14, 2026 11:15 PM EDT
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use axonml_tensor::Tensor;
19
20// =============================================================================
21// Collate Trait
22// =============================================================================
23
/// Trait for collating individual dataset samples into a single batch.
///
/// An implementor receives an owned `Vec` of samples of type `T` and
/// produces one batched value of [`Collate::Output`] (typically tensors
/// stacked along a new leading dimension). The `Send + Sync` bound lets a
/// collator be shared between data-loading threads.
pub trait Collate<T>: Send + Sync {
    /// The output batch type produced from a `Vec<T>` of samples.
    type Output;

    /// Collates a vector of samples into a batch.
    ///
    /// Takes the batch by value so implementations may consume the
    /// samples without cloning.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
32
33// =============================================================================
34// DefaultCollate
35// =============================================================================
36
/// Default collation strategy that stacks tensors along a new leading
/// batch dimension.
///
/// The unit struct carries no state, so `Default` is derived rather than
/// hand-written, and the value is freely `Copy`able.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DefaultCollate;

impl DefaultCollate {
    /// Creates a new `DefaultCollate`.
    ///
    /// Equivalent to `DefaultCollate::default()`; kept for call-site
    /// symmetry with the other collators.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
53
54impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
55    type Output = (Tensor<f32>, Tensor<f32>);
56
57    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
58        if batch.is_empty() {
59            return (
60                Tensor::from_vec(vec![], &[0]).unwrap(),
61                Tensor::from_vec(vec![], &[0]).unwrap(),
62            );
63        }
64
65        // Stack inputs
66        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
67        let stacked_x = stack_tensors(&inputs);
68
69        // Stack targets
70        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
71        let stacked_y = stack_tensors(&targets);
72
73        (stacked_x, stacked_y)
74    }
75}
76
77impl Collate<Tensor<f32>> for DefaultCollate {
78    type Output = Tensor<f32>;
79
80    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
81        stack_tensors(&batch)
82    }
83}
84
85// =============================================================================
86// StackCollate
87// =============================================================================
88
/// Collation that stacks tensors along a configurable batch dimension.
///
/// `Default` is derived: the zero value of `dim` matches the documented
/// default stacking dimension, so no hand-written impl is needed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct StackCollate {
    /// Dimension to stack along (default: 0).
    dim: usize,
}

impl StackCollate {
    /// Creates a new `StackCollate` that stacks along dimension 0.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a `StackCollate` that stacks along the given dimension.
    #[must_use]
    pub fn with_dim(dim: usize) -> Self {
        Self { dim }
    }
}
114
115impl Collate<Tensor<f32>> for StackCollate {
116    type Output = Tensor<f32>;
117
118    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
119        if batch.is_empty() {
120            return Tensor::from_vec(vec![], &[0]).unwrap();
121        }
122
123        if self.dim == 0 {
124            return stack_tensors(&batch);
125        }
126
127        // Stack along non-zero dimension:
128        // Insert a new dimension at self.dim, then concatenate.
129        // First, unsqueeze each tensor at self.dim.
130        let first_shape = batch[0].shape();
131        let ndim = first_shape.len();
132        let dim = self.dim.min(ndim); // clamp to valid range
133
134        // Build the shape with new dim inserted
135        let mut item_shape_expanded = Vec::with_capacity(ndim + 1);
136        item_shape_expanded.extend_from_slice(&first_shape[..dim]);
137        item_shape_expanded.push(1);
138        item_shape_expanded.extend_from_slice(&first_shape[dim..]);
139
140        // Reshape each tensor to have size-1 at self.dim, then concat along self.dim
141        let reshaped: Vec<Tensor<f32>> = batch
142            .iter()
143            .map(|t| Tensor::from_vec(t.to_vec(), &item_shape_expanded).unwrap())
144            .collect();
145
146        concat_tensors(&reshaped, dim)
147    }
148}
149
150impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
151    type Output = (Tensor<f32>, Tensor<f32>);
152
153    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
154        if batch.is_empty() {
155            return (
156                Tensor::from_vec(vec![], &[0]).unwrap(),
157                Tensor::from_vec(vec![], &[0]).unwrap(),
158            );
159        }
160
161        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
162        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
163
164        (stack_tensors(&inputs), stack_tensors(&targets))
165    }
166}
167
168// =============================================================================
169// Helper Functions
170// =============================================================================
171
172/// Stacks a vector of tensors along dimension 0.
173#[must_use]
174pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
175    if tensors.is_empty() {
176        return Tensor::from_vec(vec![], &[0]).unwrap();
177    }
178
179    let first_shape = tensors[0].shape();
180    let batch_size = tensors.len();
181
182    // New shape: [batch_size, ...original_shape]
183    let mut new_shape = vec![batch_size];
184    new_shape.extend_from_slice(first_shape);
185
186    // Concatenate all data
187    let mut all_data = Vec::new();
188    for tensor in tensors {
189        all_data.extend(tensor.to_vec());
190    }
191
192    Tensor::from_vec(all_data, &new_shape).unwrap()
193}
194
195/// Concatenates tensors along an existing dimension.
196#[must_use]
197pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
198    if tensors.is_empty() {
199        return Tensor::from_vec(vec![], &[0]).unwrap();
200    }
201
202    if tensors.len() == 1 {
203        return tensors[0].clone();
204    }
205
206    let first_shape = tensors[0].shape();
207    let ndim = first_shape.len();
208
209    // Calculate new shape
210    let mut new_shape = first_shape.to_vec();
211    let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
212    new_shape[dim] = concat_size;
213
214    // For dim=0 concatenation, just append all data
215    if dim == 0 {
216        let mut all_data = Vec::new();
217        for tensor in tensors {
218            all_data.extend(tensor.to_vec());
219        }
220        return Tensor::from_vec(all_data, &new_shape).unwrap();
221    }
222
223    // For non-zero dims, properly interleave data.
224    // outer_size = product of dims before `dim`
225    // inner_size = product of dims after `dim`
226    let outer_size: usize = first_shape[..dim].iter().product();
227    let inner_size: usize = if dim + 1 < ndim {
228        first_shape[dim + 1..].iter().product()
229    } else {
230        1
231    };
232
233    let total_elements: usize = new_shape.iter().product();
234    let mut result = Vec::with_capacity(total_elements);
235
236    // Precompute all tensor data
237    let all_vecs: Vec<Vec<f32>> = tensors.iter().map(|t| t.to_vec()).collect();
238
239    for o in 0..outer_size {
240        // For each tensor, copy its slice along the concat dimension
241        for (t_idx, tensor_data) in all_vecs.iter().enumerate() {
242            let t_dim_size = tensors[t_idx].shape()[dim];
243            let t_inner_stride = t_dim_size * inner_size;
244            let src_offset = o * t_inner_stride;
245            result.extend_from_slice(&tensor_data[src_offset..src_offset + t_inner_stride]);
246        }
247    }
248
249    Tensor::from_vec(result, &new_shape).unwrap()
250}
251
252// =============================================================================
253// Tests
254// =============================================================================
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand tensor builder; panics on an invalid data/shape pair.
    fn t(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data, shape).unwrap()
    }

    #[test]
    fn test_stack_tensors() {
        let parts = [
            t(vec![1.0, 2.0], &[2]),
            t(vec![3.0, 4.0], &[2]),
            t(vec![5.0, 6.0], &[2]),
        ];

        let out = stack_tensors(&parts);
        assert_eq!(out.shape(), &[3, 2]);
        assert_eq!(out.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_tensors_2d() {
        let parts = [
            t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]),
            t(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]),
        ];

        assert_eq!(stack_tensors(&parts).shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_default_collate() {
        let samples = vec![
            (t(vec![1.0, 2.0], &[2]), t(vec![0.0], &[1])),
            (t(vec![3.0, 4.0], &[2]), t(vec![1.0], &[1])),
        ];

        let (inputs, targets) = DefaultCollate::new().collate(samples);
        assert_eq!(inputs.shape(), &[2, 2]);
        assert_eq!(targets.shape(), &[2, 1]);
    }

    #[test]
    fn test_stack_collate() {
        let samples = vec![
            t(vec![1.0, 2.0, 3.0], &[3]),
            t(vec![4.0, 5.0, 6.0], &[3]),
        ];

        assert_eq!(StackCollate::new().collate(samples).shape(), &[2, 3]);
    }

    #[test]
    fn test_empty_collate() {
        let empty: Vec<(Tensor<f32>, Tensor<f32>)> = Vec::new();
        let (inputs, targets) = DefaultCollate::new().collate(empty);
        assert_eq!(inputs.shape(), &[0]);
        assert_eq!(targets.shape(), &[0]);
    }

    #[test]
    fn test_concat_tensors_dim0() {
        let parts = [
            t(vec![1.0, 2.0], &[2]),
            t(vec![3.0, 4.0, 5.0], &[3]),
        ];

        let out = concat_tensors(&parts, 0);
        assert_eq!(out.shape(), &[5]);
        assert_eq!(out.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_concat_tensors_dim1() {
        // Two [2, 2] tensors → [2, 4]
        let parts = [
            t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]),
            t(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]),
        ];

        let out = concat_tensors(&parts, 1);
        assert_eq!(out.shape(), &[2, 4]);
        // Row 0: [1,2,5,6], Row 1: [3,4,7,8]
        assert_eq!(
            out.to_vec(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }
}