//! Collate - Batch Assembly Functions
//!
//! Provides functions for combining individual samples into batches.
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

use axonml_tensor::Tensor;

// =============================================================================
// Collate Trait
// =============================================================================

/// Trait for collating samples into batches.
///
/// An implementor turns a `Vec` of individual samples of type `T` into a
/// single batched value of type [`Collate::Output`]. The `Send + Sync`
/// bound lets a collator be shared across threads.
pub trait Collate<T>: Send + Sync {
    /// The output batch type.
    type Output;

    /// Collates a vector of samples into a batch.
    ///
    /// Takes the batch by value so implementations may move the samples
    /// out rather than clone them.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}

// =============================================================================
// DefaultCollate
// =============================================================================

/// Default collation strategy that stacks tensors.
///
/// Unit struct: it carries no configuration, so any instance behaves
/// identically.
pub struct DefaultCollate;

30impl DefaultCollate {
31    /// Creates a new `DefaultCollate`.
32    #[must_use]
33    pub fn new() -> Self {
34        Self
35    }
36}
37
38impl Default for DefaultCollate {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
45    type Output = (Tensor<f32>, Tensor<f32>);
46
47    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
48        if batch.is_empty() {
49            return (
50                Tensor::from_vec(vec![], &[0]).unwrap(),
51                Tensor::from_vec(vec![], &[0]).unwrap(),
52            );
53        }
54
55        // Stack inputs
56        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
57        let stacked_x = stack_tensors(&inputs);
58
59        // Stack targets
60        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
61        let stacked_y = stack_tensors(&targets);
62
63        (stacked_x, stacked_y)
64    }
65}
66
67impl Collate<Tensor<f32>> for DefaultCollate {
68    type Output = Tensor<f32>;
69
70    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
71        stack_tensors(&batch)
72    }
73}

// =============================================================================
// StackCollate
// =============================================================================

/// Collation that stacks tensors along a new batch dimension.
pub struct StackCollate {
    // Dimension the new batch axis is inserted at (default: 0);
    // set via `with_dim`.
    /// Dimension to stack along (default: 0).
    dim: usize,
}

85impl StackCollate {
86    /// Creates a new `StackCollate` with default dimension 0.
87    #[must_use]
88    pub fn new() -> Self {
89        Self { dim: 0 }
90    }
91
92    /// Creates a `StackCollate` with specified dimension.
93    #[must_use]
94    pub fn with_dim(dim: usize) -> Self {
95        Self { dim }
96    }
97}
98
impl Default for StackCollate {
    /// Equivalent to [`StackCollate::new`]: stacks along dimension 0.
    fn default() -> Self {
        Self::new()
    }
}

105impl Collate<Tensor<f32>> for StackCollate {
106    type Output = Tensor<f32>;
107
108    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
109        if self.dim == 0 {
110            stack_tensors(&batch)
111        } else {
112            // For non-zero dimensions, we'd need more complex logic
113            // For now, always stack at dim 0
114            stack_tensors(&batch)
115        }
116    }
117}
118
119impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
120    type Output = (Tensor<f32>, Tensor<f32>);
121
122    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
123        if batch.is_empty() {
124            return (
125                Tensor::from_vec(vec![], &[0]).unwrap(),
126                Tensor::from_vec(vec![], &[0]).unwrap(),
127            );
128        }
129
130        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
131        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
132
133        (stack_tensors(&inputs), stack_tensors(&targets))
134    }
135}

// =============================================================================
// Helper Functions
// =============================================================================

141/// Stacks a vector of tensors along dimension 0.
142#[must_use]
143pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
144    if tensors.is_empty() {
145        return Tensor::from_vec(vec![], &[0]).unwrap();
146    }
147
148    let first_shape = tensors[0].shape();
149    let batch_size = tensors.len();
150
151    // New shape: [batch_size, ...original_shape]
152    let mut new_shape = vec![batch_size];
153    new_shape.extend_from_slice(first_shape);
154
155    // Concatenate all data
156    let mut all_data = Vec::new();
157    for tensor in tensors {
158        all_data.extend(tensor.to_vec());
159    }
160
161    Tensor::from_vec(all_data, &new_shape).unwrap()
162}
163
164/// Concatenates tensors along an existing dimension.
165#[must_use]
166pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
167    if tensors.is_empty() {
168        return Tensor::from_vec(vec![], &[0]).unwrap();
169    }
170
171    if tensors.len() == 1 {
172        return tensors[0].clone();
173    }
174
175    let first_shape = tensors[0].shape();
176
177    // Calculate new shape
178    let mut new_shape = first_shape.to_vec();
179    let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
180    new_shape[dim] = concat_size;
181
182    // For dim=0 concatenation, just append all data
183    if dim == 0 {
184        let mut all_data = Vec::new();
185        for tensor in tensors {
186            all_data.extend(tensor.to_vec());
187        }
188        return Tensor::from_vec(all_data, &new_shape).unwrap();
189    }
190
191    // For other dimensions, more complex interleaving is needed
192    // This is a simplified version that handles common cases
193    let mut all_data = Vec::new();
194    for tensor in tensors {
195        all_data.extend(tensor.to_vec());
196    }
197    Tensor::from_vec(all_data, &new_shape).unwrap()
198}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an `f32` tensor in tests.
    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data.to_vec(), shape).unwrap()
    }

    #[test]
    fn test_stack_tensors() {
        let batch = [
            t(&[1.0, 2.0], &[2]),
            t(&[3.0, 4.0], &[2]),
            t(&[5.0, 6.0], &[2]),
        ];

        let stacked = stack_tensors(&batch);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_tensors_2d() {
        let batch = [
            t(&[1.0, 2.0, 3.0, 4.0], &[2, 2]),
            t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]),
        ];

        assert_eq!(stack_tensors(&batch).shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_default_collate() {
        let samples = vec![
            (t(&[1.0, 2.0], &[2]), t(&[0.0], &[1])),
            (t(&[3.0, 4.0], &[2]), t(&[1.0], &[1])),
        ];

        let (x, y) = DefaultCollate::new().collate(samples);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
    }

    #[test]
    fn test_stack_collate() {
        let samples = vec![t(&[1.0, 2.0, 3.0], &[3]), t(&[4.0, 5.0, 6.0], &[3])];

        let stacked = StackCollate::new().collate(samples);
        assert_eq!(stacked.shape(), &[2, 3]);
    }

    #[test]
    fn test_empty_collate() {
        // An empty pair batch must produce two empty [0]-shaped tensors.
        let empty: Vec<(Tensor<f32>, Tensor<f32>)> = Vec::new();
        let (x, y) = DefaultCollate::new().collate(empty);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    #[test]
    fn test_concat_tensors() {
        let parts = [t(&[1.0, 2.0], &[2]), t(&[3.0, 4.0, 5.0], &[3])];

        let joined = concat_tensors(&parts, 0);
        assert_eq!(joined.shape(), &[5]);
        assert_eq!(joined.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }
}