1use axonml_tensor::Tensor;
9
/// Strategy for merging a list of individual samples into one batch value.
///
/// The type parameter `T` is the per-sample type; the associated
/// [`Collate::Output`] is whatever the merged batch looks like.
/// Implementors must be `Send + Sync`.
pub trait Collate<T>: Send + Sync {
    /// The batched value produced by [`Collate::collate`].
    type Output;

    /// Consumes `batch` and merges its samples into a single
    /// [`Collate::Output`].
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
22
/// Stateless collation strategy.
///
/// The value carries no configuration, so it is cheap to copy and every
/// instance is interchangeable with every other.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultCollate;

impl DefaultCollate {
    /// Creates a new `DefaultCollate`.
    ///
    /// Equivalent to [`DefaultCollate::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
43
44impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
45 type Output = (Tensor<f32>, Tensor<f32>);
46
47 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
48 if batch.is_empty() {
49 return (
50 Tensor::from_vec(vec![], &[0]).unwrap(),
51 Tensor::from_vec(vec![], &[0]).unwrap(),
52 );
53 }
54
55 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
57 let stacked_x = stack_tensors(&inputs);
58
59 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
61 let stacked_y = stack_tensors(&targets);
62
63 (stacked_x, stacked_y)
64 }
65}
66
67impl Collate<Tensor<f32>> for DefaultCollate {
68 type Output = Tensor<f32>;
69
70 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
71 stack_tensors(&batch)
72 }
73}
74
/// Collation strategy configured with the axis along which samples are
/// stacked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StackCollate {
    /// Index of the axis selected for stacking.
    dim: usize,
}

impl StackCollate {
    /// Creates a `StackCollate` that uses axis `0`.
    #[must_use]
    pub const fn new() -> Self {
        Self { dim: 0 }
    }

    /// Creates a `StackCollate` that uses axis `dim`.
    #[must_use]
    pub const fn with_dim(dim: usize) -> Self {
        Self { dim }
    }
}

impl Default for StackCollate {
    /// Same as [`StackCollate::new`]: axis `0`.
    fn default() -> Self {
        Self::new()
    }
}
104
105impl Collate<Tensor<f32>> for StackCollate {
106 type Output = Tensor<f32>;
107
108 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
109 if self.dim == 0 {
110 stack_tensors(&batch)
111 } else {
112 stack_tensors(&batch)
115 }
116 }
117}
118
119impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
120 type Output = (Tensor<f32>, Tensor<f32>);
121
122 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
123 if batch.is_empty() {
124 return (
125 Tensor::from_vec(vec![], &[0]).unwrap(),
126 Tensor::from_vec(vec![], &[0]).unwrap(),
127 );
128 }
129
130 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
131 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
132
133 (stack_tensors(&inputs), stack_tensors(&targets))
134 }
135}
136
137#[must_use]
143pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
144 if tensors.is_empty() {
145 return Tensor::from_vec(vec![], &[0]).unwrap();
146 }
147
148 let first_shape = tensors[0].shape();
149 let batch_size = tensors.len();
150
151 let mut new_shape = vec![batch_size];
153 new_shape.extend_from_slice(first_shape);
154
155 let mut all_data = Vec::new();
157 for tensor in tensors {
158 all_data.extend(tensor.to_vec());
159 }
160
161 Tensor::from_vec(all_data, &new_shape).unwrap()
162}
163
164#[must_use]
166pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
167 if tensors.is_empty() {
168 return Tensor::from_vec(vec![], &[0]).unwrap();
169 }
170
171 if tensors.len() == 1 {
172 return tensors[0].clone();
173 }
174
175 let first_shape = tensors[0].shape();
176
177 let mut new_shape = first_shape.to_vec();
179 let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
180 new_shape[dim] = concat_size;
181
182 if dim == 0 {
184 let mut all_data = Vec::new();
185 for tensor in tensors {
186 all_data.extend(tensor.to_vec());
187 }
188 return Tensor::from_vec(all_data, &new_shape).unwrap();
189 }
190
191 let mut all_data = Vec::new();
194 for tensor in tensors {
195 all_data.extend(tensor.to_vec());
196 }
197 Tensor::from_vec(all_data, &new_shape).unwrap()
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `f32` tensor, panicking if the data does not fit the shape.
    fn tensor(data: Vec<f32>, shape: &[usize]) -> Tensor<f32> {
        Tensor::from_vec(data, shape).unwrap()
    }

    #[test]
    fn test_stack_tensors() {
        let t1 = tensor(vec![1.0, 2.0], &[2]);
        let t2 = tensor(vec![3.0, 4.0], &[2]);
        let t3 = tensor(vec![5.0, 6.0], &[2]);

        let stacked = stack_tensors(&[t1, t2, t3]);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_tensors_2d() {
        let t1 = tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let t2 = tensor(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);

        let stacked = stack_tensors(&[t1, t2]);
        assert_eq!(stacked.shape(), &[2, 2, 2]);
        // Data must follow sample order, not just have the right shape.
        assert_eq!(
            stacked.to_vec(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
    }

    #[test]
    fn test_stack_tensors_empty() {
        let none: Vec<Tensor<f32>> = Vec::new();
        let stacked = stack_tensors(&none);
        assert_eq!(stacked.shape(), &[0]);
    }

    #[test]
    fn test_default_collate() {
        let collate = DefaultCollate::new();

        let batch = vec![
            (tensor(vec![1.0, 2.0], &[2]), tensor(vec![0.0], &[1])),
            (tensor(vec![3.0, 4.0], &[2]), tensor(vec![1.0], &[1])),
        ];

        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
        assert_eq!(x.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(y.to_vec(), vec![0.0, 1.0]);
    }

    #[test]
    fn test_default_collate_tensors() {
        let collate = DefaultCollate::default();

        let batch: Vec<Tensor<f32>> = vec![
            tensor(vec![1.0, 2.0, 3.0], &[3]),
            tensor(vec![4.0, 5.0, 6.0], &[3]),
        ];

        let result = collate.collate(batch);
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_collate() {
        let collate = StackCollate::new();

        let batch: Vec<Tensor<f32>> = vec![
            tensor(vec![1.0, 2.0, 3.0], &[3]),
            tensor(vec![4.0, 5.0, 6.0], &[3]),
        ];

        let result = collate.collate(batch);
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_collate_pairs() {
        let collate = StackCollate::default();

        let batch = vec![
            (tensor(vec![1.0, 2.0], &[2]), tensor(vec![0.0], &[1])),
            (tensor(vec![3.0, 4.0], &[2]), tensor(vec![1.0], &[1])),
        ];

        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
        assert_eq!(x.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(y.to_vec(), vec![0.0, 1.0]);
    }

    #[test]
    fn test_empty_collate() {
        let collate = DefaultCollate::new();
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    #[test]
    fn test_empty_stack_collate_pairs() {
        let collate = StackCollate::new();
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    #[test]
    fn test_concat_tensors() {
        let t1 = tensor(vec![1.0, 2.0], &[2]);
        let t2 = tensor(vec![3.0, 4.0, 5.0], &[3]);

        let concat = concat_tensors(&[t1, t2], 0);
        assert_eq!(concat.shape(), &[5]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_concat_tensors_single() {
        let only = tensor(vec![1.0, 2.0], &[2]);

        let concat = concat_tensors(&[only], 0);
        assert_eq!(concat.shape(), &[2]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn test_concat_tensors_empty() {
        let none: Vec<Tensor<f32>> = Vec::new();
        let concat = concat_tensors(&none, 0);
        assert_eq!(concat.shape(), &[0]);
    }
}