1use axonml_tensor::Tensor;
18
/// A strategy for combining a batch of individual samples into one batched value.
///
/// Implementors must be `Send + Sync` so a single collate instance can be
/// shared across data-loading worker threads.
pub trait Collate<T>: Send + Sync {
    /// The batched value produced from a `Vec<T>` of samples.
    type Output;

    /// Collates `batch` into a single output value, consuming the samples.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
31
32pub struct DefaultCollate;
38
39impl DefaultCollate {
40 #[must_use]
42 pub fn new() -> Self {
43 Self
44 }
45}
46
47impl Default for DefaultCollate {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
54 type Output = (Tensor<f32>, Tensor<f32>);
55
56 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
57 if batch.is_empty() {
58 return (
59 Tensor::from_vec(vec![], &[0]).unwrap(),
60 Tensor::from_vec(vec![], &[0]).unwrap(),
61 );
62 }
63
64 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
66 let stacked_x = stack_tensors(&inputs);
67
68 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
70 let stacked_y = stack_tensors(&targets);
71
72 (stacked_x, stacked_y)
73 }
74}
75
76impl Collate<Tensor<f32>> for DefaultCollate {
77 type Output = Tensor<f32>;
78
79 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
80 stack_tensors(&batch)
81 }
82}
83
/// A collate strategy that stacks batch items along a configurable dimension.
pub struct StackCollate {
    /// Position at which the new batch axis is inserted (0 = leading).
    dim: usize,
}
93
94impl StackCollate {
95 #[must_use]
97 pub fn new() -> Self {
98 Self { dim: 0 }
99 }
100
101 #[must_use]
103 pub fn with_dim(dim: usize) -> Self {
104 Self { dim }
105 }
106}
107
108impl Default for StackCollate {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl Collate<Tensor<f32>> for StackCollate {
115 type Output = Tensor<f32>;
116
117 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
118 if batch.is_empty() {
119 return Tensor::from_vec(vec![], &[0]).unwrap();
120 }
121
122 if self.dim == 0 {
123 return stack_tensors(&batch);
124 }
125
126 let first_shape = batch[0].shape();
130 let ndim = first_shape.len();
131 let dim = self.dim.min(ndim); let mut item_shape_expanded = Vec::with_capacity(ndim + 1);
135 item_shape_expanded.extend_from_slice(&first_shape[..dim]);
136 item_shape_expanded.push(1);
137 item_shape_expanded.extend_from_slice(&first_shape[dim..]);
138
139 let reshaped: Vec<Tensor<f32>> = batch
141 .iter()
142 .map(|t| Tensor::from_vec(t.to_vec(), &item_shape_expanded).unwrap())
143 .collect();
144
145 concat_tensors(&reshaped, dim)
146 }
147}
148
149impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
150 type Output = (Tensor<f32>, Tensor<f32>);
151
152 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
153 if batch.is_empty() {
154 return (
155 Tensor::from_vec(vec![], &[0]).unwrap(),
156 Tensor::from_vec(vec![], &[0]).unwrap(),
157 );
158 }
159
160 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
161 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
162
163 (stack_tensors(&inputs), stack_tensors(&targets))
164 }
165}
166
167#[must_use]
173pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
174 if tensors.is_empty() {
175 return Tensor::from_vec(vec![], &[0]).unwrap();
176 }
177
178 let first_shape = tensors[0].shape();
179 let batch_size = tensors.len();
180
181 let mut new_shape = vec![batch_size];
183 new_shape.extend_from_slice(first_shape);
184
185 let mut all_data = Vec::new();
187 for tensor in tensors {
188 all_data.extend(tensor.to_vec());
189 }
190
191 Tensor::from_vec(all_data, &new_shape).unwrap()
192}
193
194#[must_use]
196pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
197 if tensors.is_empty() {
198 return Tensor::from_vec(vec![], &[0]).unwrap();
199 }
200
201 if tensors.len() == 1 {
202 return tensors[0].clone();
203 }
204
205 let first_shape = tensors[0].shape();
206 let ndim = first_shape.len();
207
208 let mut new_shape = first_shape.to_vec();
210 let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
211 new_shape[dim] = concat_size;
212
213 if dim == 0 {
215 let mut all_data = Vec::new();
216 for tensor in tensors {
217 all_data.extend(tensor.to_vec());
218 }
219 return Tensor::from_vec(all_data, &new_shape).unwrap();
220 }
221
222 let outer_size: usize = first_shape[..dim].iter().product();
226 let inner_size: usize = if dim + 1 < ndim {
227 first_shape[dim + 1..].iter().product()
228 } else {
229 1
230 };
231
232 let total_elements: usize = new_shape.iter().product();
233 let mut result = Vec::with_capacity(total_elements);
234
235 let all_vecs: Vec<Vec<f32>> = tensors.iter().map(|t| t.to_vec()).collect();
237
238 for o in 0..outer_size {
239 for (t_idx, tensor_data) in all_vecs.iter().enumerate() {
241 let t_dim_size = tensors[t_idx].shape()[dim];
242 let t_inner_stride = t_dim_size * inner_size;
243 let src_offset = o * t_inner_stride;
244 result.extend_from_slice(&tensor_data[src_offset..src_offset + t_inner_stride]);
245 }
246 }
247
248 Tensor::from_vec(result, &new_shape).unwrap()
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand fixture: builds a tensor or panics (fine in tests).
    fn t(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data, shape).unwrap()
    }

    #[test]
    fn test_stack_tensors() {
        let parts = [
            t(vec![1.0, 2.0], &[2]),
            t(vec![3.0, 4.0], &[2]),
            t(vec![5.0, 6.0], &[2]),
        ];

        let stacked = stack_tensors(&parts);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_tensors_2d() {
        let parts = [
            t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]),
            t(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]),
        ];

        assert_eq!(stack_tensors(&parts).shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_default_collate() {
        let batch = vec![
            (t(vec![1.0, 2.0], &[2]), t(vec![0.0], &[1])),
            (t(vec![3.0, 4.0], &[2]), t(vec![1.0], &[1])),
        ];

        let (x, y) = DefaultCollate::new().collate(batch);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
    }

    #[test]
    fn test_stack_collate() {
        let batch = vec![
            t(vec![1.0, 2.0, 3.0], &[3]),
            t(vec![4.0, 5.0, 6.0], &[3]),
        ];

        assert_eq!(StackCollate::new().collate(batch).shape(), &[2, 3]);
    }

    #[test]
    fn test_empty_collate() {
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = Vec::new();
        let (x, y) = DefaultCollate::new().collate(batch);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    #[test]
    fn test_concat_tensors_dim0() {
        let concat = concat_tensors(
            &[t(vec![1.0, 2.0], &[2]), t(vec![3.0, 4.0, 5.0], &[3])],
            0,
        );

        assert_eq!(concat.shape(), &[5]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_concat_tensors_dim1() {
        let concat = concat_tensors(
            &[
                t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]),
                t(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]),
            ],
            1,
        );

        assert_eq!(concat.shape(), &[2, 4]);
        assert_eq!(
            concat.to_vec(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }
}