//! Collate - Batch Assembly Functions
//!
//! # File
//! `crates/axonml-data/src/collate.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use axonml_tensor::Tensor;

// =============================================================================
// Collate Trait
// =============================================================================
/// Trait for collating individual dataset samples into batches.
///
/// Implementors receive a `Vec` of samples of type `T` and assemble them
/// into a single batched value of type [`Collate::Output`].
pub trait Collate<T>: Send + Sync {
    /// The batch type produced by [`Collate::collate`].
    type Output;

    /// Collates a vector of samples into a batch.
    ///
    /// Takes ownership of `batch`, so implementations may consume the
    /// samples without cloning.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}

// =============================================================================
// DefaultCollate
// =============================================================================

/// Default collation strategy that stacks tensors along a new leading
/// batch dimension.
///
/// Stateless, zero-sized marker type; the behavior lives in its
/// `Collate` trait impls.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultCollate;

impl DefaultCollate {
    /// Creates a new `DefaultCollate`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
52
53impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
54    type Output = (Tensor<f32>, Tensor<f32>);
55
56    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
57        if batch.is_empty() {
58            return (
59                Tensor::from_vec(vec![], &[0]).unwrap(),
60                Tensor::from_vec(vec![], &[0]).unwrap(),
61            );
62        }
63
64        // Stack inputs
65        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
66        let stacked_x = stack_tensors(&inputs);
67
68        // Stack targets
69        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
70        let stacked_y = stack_tensors(&targets);
71
72        (stacked_x, stacked_y)
73    }
74}
76impl Collate<Tensor<f32>> for DefaultCollate {
77    type Output = Tensor<f32>;
78
79    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
80        stack_tensors(&batch)
81    }
82}

// =============================================================================
// StackCollate
// =============================================================================

/// Collation that stacks tensors along a new batch dimension.
///
/// The target dimension is configurable via [`StackCollate::with_dim`],
/// although the current `Collate` impls only stack at dimension 0.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct StackCollate {
    /// Dimension to stack along (default: 0).
    dim: usize,
}

impl StackCollate {
    /// Creates a new `StackCollate` stacking along dimension 0.
    #[must_use]
    pub fn new() -> Self {
        Self::with_dim(0)
    }

    /// Creates a `StackCollate` stacking along `dim`.
    #[must_use]
    pub fn with_dim(dim: usize) -> Self {
        Self { dim }
    }
}
114impl Collate<Tensor<f32>> for StackCollate {
115    type Output = Tensor<f32>;
116
117    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
118        if self.dim == 0 {
119            stack_tensors(&batch)
120        } else {
121            // For non-zero dimensions, we'd need more complex logic
122            // For now, always stack at dim 0
123            stack_tensors(&batch)
124        }
125    }
126}
128impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
129    type Output = (Tensor<f32>, Tensor<f32>);
130
131    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
132        if batch.is_empty() {
133            return (
134                Tensor::from_vec(vec![], &[0]).unwrap(),
135                Tensor::from_vec(vec![], &[0]).unwrap(),
136            );
137        }
138
139        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
140        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
141
142        (stack_tensors(&inputs), stack_tensors(&targets))
143    }
144}

// =============================================================================
// Helper Functions
// =============================================================================

150/// Stacks a vector of tensors along dimension 0.
151#[must_use]
152pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
153    if tensors.is_empty() {
154        return Tensor::from_vec(vec![], &[0]).unwrap();
155    }
156
157    let first_shape = tensors[0].shape();
158    let batch_size = tensors.len();
159
160    // New shape: [batch_size, ...original_shape]
161    let mut new_shape = vec![batch_size];
162    new_shape.extend_from_slice(first_shape);
163
164    // Concatenate all data
165    let mut all_data = Vec::new();
166    for tensor in tensors {
167        all_data.extend(tensor.to_vec());
168    }
169
170    Tensor::from_vec(all_data, &new_shape).unwrap()
171}
173/// Concatenates tensors along an existing dimension.
174#[must_use]
175pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
176    if tensors.is_empty() {
177        return Tensor::from_vec(vec![], &[0]).unwrap();
178    }
179
180    if tensors.len() == 1 {
181        return tensors[0].clone();
182    }
183
184    let first_shape = tensors[0].shape();
185
186    // Calculate new shape
187    let mut new_shape = first_shape.to_vec();
188    let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
189    new_shape[dim] = concat_size;
190
191    // For dim=0 concatenation, just append all data
192    if dim == 0 {
193        let mut all_data = Vec::new();
194        for tensor in tensors {
195            all_data.extend(tensor.to_vec());
196        }
197        return Tensor::from_vec(all_data, &new_shape).unwrap();
198    }
199
200    // For other dimensions, more complex interleaving is needed
201    // This is a simplified version that handles common cases
202    let mut all_data = Vec::new();
203    for tensor in tensors {
204        all_data.extend(tensor.to_vec());
205    }
206    Tensor::from_vec(all_data, &new_shape).unwrap()
207}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor from data and shape, panicking on invalid input.
    fn t(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data, shape).unwrap()
    }

    #[test]
    fn test_stack_tensors() {
        let parts = [
            t(vec![1.0, 2.0], &[2]),
            t(vec![3.0, 4.0], &[2]),
            t(vec![5.0, 6.0], &[2]),
        ];

        let stacked = stack_tensors(&parts);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_tensors_2d() {
        let parts = [
            t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]),
            t(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]),
        ];

        assert_eq!(stack_tensors(&parts).shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_default_collate() {
        let samples = vec![
            (t(vec![1.0, 2.0], &[2]), t(vec![0.0], &[1])),
            (t(vec![3.0, 4.0], &[2]), t(vec![1.0], &[1])),
        ];

        let (x, y) = DefaultCollate::new().collate(samples);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
    }

    #[test]
    fn test_stack_collate() {
        let samples = vec![
            t(vec![1.0, 2.0, 3.0], &[3]),
            t(vec![4.0, 5.0, 6.0], &[3]),
        ];

        let stacked = StackCollate::new().collate(samples);
        assert_eq!(stacked.shape(), &[2, 3]);
    }

    #[test]
    fn test_empty_collate() {
        let empty: Vec<(Tensor<f32>, Tensor<f32>)> = Vec::new();
        let (x, y) = DefaultCollate::new().collate(empty);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    #[test]
    fn test_concat_tensors() {
        let a = t(vec![1.0, 2.0], &[2]);
        let b = t(vec![3.0, 4.0, 5.0], &[3]);

        let joined = concat_tensors(&[a, b], 0);
        assert_eq!(joined.shape(), &[5]);
        assert_eq!(joined.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }
}