1use axonml_tensor::Tensor;
19
/// Strategy for combining a vector of individual dataset samples into a
/// single batched value.
///
/// `Send + Sync` is required so collators can be shared across data-loader
/// worker threads.
pub trait Collate<T>: Send + Sync {
    /// The batched result produced from a `Vec<T>` of samples.
    type Output;

    /// Merges `batch` into a single batched value.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
32
33pub struct DefaultCollate;
39
impl DefaultCollate {
    /// Creates a new `DefaultCollate`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
47
impl Default for DefaultCollate {
    /// Equivalent to [`DefaultCollate::new`].
    fn default() -> Self {
        Self::new()
    }
}
53
54impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
55 type Output = (Tensor<f32>, Tensor<f32>);
56
57 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
58 if batch.is_empty() {
59 return (
60 Tensor::from_vec(vec![], &[0]).unwrap(),
61 Tensor::from_vec(vec![], &[0]).unwrap(),
62 );
63 }
64
65 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
67 let stacked_x = stack_tensors(&inputs);
68
69 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
71 let stacked_y = stack_tensors(&targets);
72
73 (stacked_x, stacked_y)
74 }
75}
76
77impl Collate<Tensor<f32>> for DefaultCollate {
78 type Output = Tensor<f32>;
79
80 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
81 stack_tensors(&batch)
82 }
83}
84
/// Collate strategy that stacks samples along a configurable dimension.
pub struct StackCollate {
    /// Position at which the new batch axis is inserted (0 = leading).
    dim: usize,
}
94
95impl StackCollate {
96 #[must_use]
98 pub fn new() -> Self {
99 Self { dim: 0 }
100 }
101
102 #[must_use]
104 pub fn with_dim(dim: usize) -> Self {
105 Self { dim }
106 }
107}
108
impl Default for StackCollate {
    /// Defaults to stacking along dimension 0.
    fn default() -> Self {
        Self::new()
    }
}
114
115impl Collate<Tensor<f32>> for StackCollate {
116 type Output = Tensor<f32>;
117
118 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
119 if batch.is_empty() {
120 return Tensor::from_vec(vec![], &[0]).unwrap();
121 }
122
123 if self.dim == 0 {
124 return stack_tensors(&batch);
125 }
126
127 let first_shape = batch[0].shape();
131 let ndim = first_shape.len();
132 let dim = self.dim.min(ndim); let mut item_shape_expanded = Vec::with_capacity(ndim + 1);
136 item_shape_expanded.extend_from_slice(&first_shape[..dim]);
137 item_shape_expanded.push(1);
138 item_shape_expanded.extend_from_slice(&first_shape[dim..]);
139
140 let reshaped: Vec<Tensor<f32>> = batch
142 .iter()
143 .map(|t| Tensor::from_vec(t.to_vec(), &item_shape_expanded).unwrap())
144 .collect();
145
146 concat_tensors(&reshaped, dim)
147 }
148}
149
150impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
151 type Output = (Tensor<f32>, Tensor<f32>);
152
153 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
154 if batch.is_empty() {
155 return (
156 Tensor::from_vec(vec![], &[0]).unwrap(),
157 Tensor::from_vec(vec![], &[0]).unwrap(),
158 );
159 }
160
161 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
162 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
163
164 (stack_tensors(&inputs), stack_tensors(&targets))
165 }
166}
167
168#[must_use]
174pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
175 if tensors.is_empty() {
176 return Tensor::from_vec(vec![], &[0]).unwrap();
177 }
178
179 let first_shape = tensors[0].shape();
180 let batch_size = tensors.len();
181
182 let mut new_shape = vec![batch_size];
184 new_shape.extend_from_slice(first_shape);
185
186 let mut all_data = Vec::new();
188 for tensor in tensors {
189 all_data.extend(tensor.to_vec());
190 }
191
192 Tensor::from_vec(all_data, &new_shape).unwrap()
193}
194
195#[must_use]
197pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
198 if tensors.is_empty() {
199 return Tensor::from_vec(vec![], &[0]).unwrap();
200 }
201
202 if tensors.len() == 1 {
203 return tensors[0].clone();
204 }
205
206 let first_shape = tensors[0].shape();
207 let ndim = first_shape.len();
208
209 let mut new_shape = first_shape.to_vec();
211 let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
212 new_shape[dim] = concat_size;
213
214 if dim == 0 {
216 let mut all_data = Vec::new();
217 for tensor in tensors {
218 all_data.extend(tensor.to_vec());
219 }
220 return Tensor::from_vec(all_data, &new_shape).unwrap();
221 }
222
223 let outer_size: usize = first_shape[..dim].iter().product();
227 let inner_size: usize = if dim + 1 < ndim {
228 first_shape[dim + 1..].iter().product()
229 } else {
230 1
231 };
232
233 let total_elements: usize = new_shape.iter().product();
234 let mut result = Vec::with_capacity(total_elements);
235
236 let all_vecs: Vec<Vec<f32>> = tensors.iter().map(|t| t.to_vec()).collect();
238
239 for o in 0..outer_size {
240 for (t_idx, tensor_data) in all_vecs.iter().enumerate() {
242 let t_dim_size = tensors[t_idx].shape()[dim];
243 let t_inner_stride = t_dim_size * inner_size;
244 let src_offset = o * t_inner_stride;
245 result.extend_from_slice(&tensor_data[src_offset..src_offset + t_inner_stride]);
246 }
247 }
248
249 Tensor::from_vec(result, &new_shape).unwrap()
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// 1-D tensors stack into a 2-D tensor with batch as dim 0.
    #[test]
    fn test_stack_tensors() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap();
        let t3 = Tensor::from_vec(vec![5.0, 6.0], &[2]).unwrap();

        let stacked = stack_tensors(&[t1, t2, t3]);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    /// Stacking 2-D items yields a 3-D batch.
    #[test]
    fn test_stack_tensors_2d() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let stacked = stack_tensors(&[t1, t2]);
        assert_eq!(stacked.shape(), &[2, 2, 2]);
    }

    /// Pair batches stack inputs and targets independently.
    #[test]
    fn test_default_collate() {
        let collate = DefaultCollate::new();

        let batch = vec![
            (
                Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
                Tensor::from_vec(vec![0.0], &[1]).unwrap(),
            ),
            (
                Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
                Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            ),
        ];

        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
    }

    /// Default `StackCollate` behaves like dim-0 stacking.
    #[test]
    fn test_stack_collate() {
        let collate = StackCollate::new();

        let batch = vec![
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(),
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(),
        ];

        let result = collate.collate(batch);
        assert_eq!(result.shape(), &[2, 3]);
    }

    /// `with_dim(1)` inserts the batch axis after the first item axis,
    /// interleaving samples column-wise.
    #[test]
    fn test_stack_collate_dim1() {
        let collate = StackCollate::with_dim(1);

        let batch = vec![
            Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
            Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
        ];

        let result = collate.collate(batch);
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![1.0, 3.0, 2.0, 4.0]);
    }

    /// Empty batches collate to empty `[0]`-shaped tensors, not a panic.
    #[test]
    fn test_empty_collate() {
        let collate = DefaultCollate::new();
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    /// Concat along dim 0 allows differing sizes on that dim.
    #[test]
    fn test_concat_tensors_dim0() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        let concat = concat_tensors(&[t1, t2], 0);
        assert_eq!(concat.shape(), &[5]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    /// Concat along dim 1 interleaves the row slabs of each input.
    #[test]
    fn test_concat_tensors_dim1() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let concat = concat_tensors(&[t1, t2], 1);
        assert_eq!(concat.shape(), &[2, 4]);
        assert_eq!(
            concat.to_vec(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }
}