axonml_data/
collate.rs

1//! Collate - Batch Assembly Functions
2//!
3//! Provides functions for combining individual samples into batches.
4//!
5//! @version 0.1.0
6//! @author `AutomataNexus` Development Team
7
8use axonml_tensor::Tensor;
9
10// =============================================================================
11// Collate Trait
12// =============================================================================
13
/// Strategy for merging individual samples into a single batch.
///
/// `T` is the per-sample type. Every implementor must also be `Send + Sync`.
pub trait Collate<T>
where
    Self: Send + Sync,
{
    /// Batch type produced by [`Collate::collate`].
    type Output;

    /// Consumes `batch` (one entry per sample) and returns the assembled batch.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
22
23// =============================================================================
24// DefaultCollate
25// =============================================================================
26
/// Default collation strategy that stacks tensors.
///
/// This is a stateless unit type, so it is cheap to copy and can be printed
/// for diagnostics.
#[derive(Debug, Clone, Copy)]
pub struct DefaultCollate;

impl DefaultCollate {
    /// Creates a new `DefaultCollate`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for DefaultCollate {
    /// Equivalent to [`DefaultCollate::new`].
    fn default() -> Self {
        Self::new()
    }
}
42
43impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
44    type Output = (Tensor<f32>, Tensor<f32>);
45
46    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
47        if batch.is_empty() {
48            return (
49                Tensor::from_vec(vec![], &[0]).unwrap(),
50                Tensor::from_vec(vec![], &[0]).unwrap(),
51            );
52        }
53
54        // Stack inputs
55        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
56        let stacked_x = stack_tensors(&inputs);
57
58        // Stack targets
59        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
60        let stacked_y = stack_tensors(&targets);
61
62        (stacked_x, stacked_y)
63    }
64}
65
66impl Collate<Tensor<f32>> for DefaultCollate {
67    type Output = Tensor<f32>;
68
69    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
70        stack_tensors(&batch)
71    }
72}
73
74// =============================================================================
75// StackCollate
76// =============================================================================
77
/// Collation that stacks tensors along a new batch dimension.
///
/// This is a small, plain-data configuration type, so it is `Copy`,
/// comparable, and printable for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StackCollate {
    /// Requested dimension to stack along (default: 0).
    dim: usize,
}

impl StackCollate {
    /// Creates a new `StackCollate` with default dimension 0.
    #[must_use]
    pub fn new() -> Self {
        Self { dim: 0 }
    }

    /// Creates a `StackCollate` with the specified dimension.
    #[must_use]
    pub fn with_dim(dim: usize) -> Self {
        Self { dim }
    }
}

impl Default for StackCollate {
    /// Equivalent to [`StackCollate::new`] (dimension 0).
    fn default() -> Self {
        Self::new()
    }
}
101
102impl Collate<Tensor<f32>> for StackCollate {
103    type Output = Tensor<f32>;
104
105    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
106        if self.dim == 0 {
107            stack_tensors(&batch)
108        } else {
109            // For non-zero dimensions, we'd need more complex logic
110            // For now, always stack at dim 0
111            stack_tensors(&batch)
112        }
113    }
114}
115
116impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
117    type Output = (Tensor<f32>, Tensor<f32>);
118
119    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
120        if batch.is_empty() {
121            return (
122                Tensor::from_vec(vec![], &[0]).unwrap(),
123                Tensor::from_vec(vec![], &[0]).unwrap(),
124            );
125        }
126
127        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
128        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
129
130        (stack_tensors(&inputs), stack_tensors(&targets))
131    }
132}
133
134// =============================================================================
135// Helper Functions
136// =============================================================================
137
138/// Stacks a vector of tensors along dimension 0.
139#[must_use] pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
140    if tensors.is_empty() {
141        return Tensor::from_vec(vec![], &[0]).unwrap();
142    }
143
144    let first_shape = tensors[0].shape();
145    let batch_size = tensors.len();
146
147    // New shape: [batch_size, ...original_shape]
148    let mut new_shape = vec![batch_size];
149    new_shape.extend_from_slice(first_shape);
150
151    // Concatenate all data
152    let mut all_data = Vec::new();
153    for tensor in tensors {
154        all_data.extend(tensor.to_vec());
155    }
156
157    Tensor::from_vec(all_data, &new_shape).unwrap()
158}
159
160/// Concatenates tensors along an existing dimension.
161#[must_use] pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
162    if tensors.is_empty() {
163        return Tensor::from_vec(vec![], &[0]).unwrap();
164    }
165
166    if tensors.len() == 1 {
167        return tensors[0].clone();
168    }
169
170    let first_shape = tensors[0].shape();
171
172    // Calculate new shape
173    let mut new_shape = first_shape.to_vec();
174    let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
175    new_shape[dim] = concat_size;
176
177    // For dim=0 concatenation, just append all data
178    if dim == 0 {
179        let mut all_data = Vec::new();
180        for tensor in tensors {
181            all_data.extend(tensor.to_vec());
182        }
183        return Tensor::from_vec(all_data, &new_shape).unwrap();
184    }
185
186    // For other dimensions, more complex interleaving is needed
187    // This is a simplified version that handles common cases
188    let mut all_data = Vec::new();
189    for tensor in tensors {
190        all_data.extend(tensor.to_vec());
191    }
192    Tensor::from_vec(all_data, &new_shape).unwrap()
193}
194
195// =============================================================================
196// Tests
197// =============================================================================
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Stacking 1-D tensors adds a leading batch axis and keeps element order.
    #[test]
    fn test_stack_tensors() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap();
        let t3 = Tensor::from_vec(vec![5.0, 6.0], &[2]).unwrap();

        let stacked = stack_tensors(&[t1, t2, t3]);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// Stacking 2-D tensors yields a 3-D tensor with the samples laid end to end.
    #[test]
    fn test_stack_tensors_2d() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let stacked = stack_tensors(&[t1, t2]);
        assert_eq!(stacked.shape(), &[2, 2, 2]);
        // Check the data as well as the shape, so ordering regressions are caught.
        assert_eq!(
            stacked.to_vec(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
    }

    /// `DefaultCollate` stacks inputs and targets independently.
    #[test]
    fn test_default_collate() {
        let collate = DefaultCollate::new();

        let batch = vec![
            (
                Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
                Tensor::from_vec(vec![0.0], &[1]).unwrap(),
            ),
            (
                Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
                Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            ),
        ];

        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
        assert_eq!(x.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(y.to_vec(), vec![0.0, 1.0]);
    }

    /// `StackCollate` with the default dimension stacks along axis 0.
    #[test]
    fn test_stack_collate() {
        let collate = StackCollate::new();

        let batch = vec![
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(),
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(),
        ];

        let result = collate.collate(batch);
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// An empty batch collates to two empty tensors of shape `[0]`.
    #[test]
    fn test_empty_collate() {
        let collate = DefaultCollate::new();
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    /// Concatenating along dim 0 appends the data and sums the axis sizes.
    #[test]
    fn test_concat_tensors() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        let concat = concat_tensors(&[t1, t2], 0);
        assert_eq!(concat.shape(), &[5]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }
}