1use axonml_tensor::Tensor;
18
/// Merges a batch of individual dataset samples into one batched value.
///
/// Implementors consume the `Vec` of samples gathered for a batch and
/// combine them into a single `Output` (the provided implementations
/// stack tensors along a new leading batch dimension). The `Send + Sync`
/// bound allows a collator to be shared across data-loading threads.
pub trait Collate<T>: Send + Sync {
    /// The batched value produced from a `Vec<T>` of samples.
    type Output;

    /// Combines `batch` into a single `Self::Output`, consuming it.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}
31
/// Zero-configuration collator: stacks every sample along a new
/// leading batch dimension (see [`stack_tensors`]).
pub struct DefaultCollate;

impl DefaultCollate {
    /// Creates a new `DefaultCollate`. The collator is stateless, so
    /// this is equivalent to the unit value itself.
    #[must_use]
    pub fn new() -> Self {
        DefaultCollate
    }
}

impl Default for DefaultCollate {
    fn default() -> Self {
        DefaultCollate::new()
    }
}
52
53impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
54 type Output = (Tensor<f32>, Tensor<f32>);
55
56 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
57 if batch.is_empty() {
58 return (
59 Tensor::from_vec(vec![], &[0]).unwrap(),
60 Tensor::from_vec(vec![], &[0]).unwrap(),
61 );
62 }
63
64 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
66 let stacked_x = stack_tensors(&inputs);
67
68 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
70 let stacked_y = stack_tensors(&targets);
71
72 (stacked_x, stacked_y)
73 }
74}
75
76impl Collate<Tensor<f32>> for DefaultCollate {
77 type Output = Tensor<f32>;
78
79 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
80 stack_tensors(&batch)
81 }
82}
83
/// Collator that stacks sample tensors, carrying a target stacking
/// dimension.
pub struct StackCollate {
    // Dimension to stack along; 0 means a new leading batch axis.
    dim: usize,
}

impl StackCollate {
    /// Creates a collator that stacks along dimension 0.
    #[must_use]
    pub fn new() -> Self {
        StackCollate::with_dim(0)
    }

    /// Creates a collator that stacks along `dim`.
    #[must_use]
    pub fn with_dim(dim: usize) -> Self {
        StackCollate { dim }
    }
}
107
108impl Default for StackCollate {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl Collate<Tensor<f32>> for StackCollate {
115 type Output = Tensor<f32>;
116
117 fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
118 if self.dim == 0 {
119 stack_tensors(&batch)
120 } else {
121 stack_tensors(&batch)
124 }
125 }
126}
127
128impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
129 type Output = (Tensor<f32>, Tensor<f32>);
130
131 fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
132 if batch.is_empty() {
133 return (
134 Tensor::from_vec(vec![], &[0]).unwrap(),
135 Tensor::from_vec(vec![], &[0]).unwrap(),
136 );
137 }
138
139 let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
140 let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
141
142 (stack_tensors(&inputs), stack_tensors(&targets))
143 }
144}
145
146#[must_use]
152pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
153 if tensors.is_empty() {
154 return Tensor::from_vec(vec![], &[0]).unwrap();
155 }
156
157 let first_shape = tensors[0].shape();
158 let batch_size = tensors.len();
159
160 let mut new_shape = vec![batch_size];
162 new_shape.extend_from_slice(first_shape);
163
164 let mut all_data = Vec::new();
166 for tensor in tensors {
167 all_data.extend(tensor.to_vec());
168 }
169
170 Tensor::from_vec(all_data, &new_shape).unwrap()
171}
172
173#[must_use]
175pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
176 if tensors.is_empty() {
177 return Tensor::from_vec(vec![], &[0]).unwrap();
178 }
179
180 if tensors.len() == 1 {
181 return tensors[0].clone();
182 }
183
184 let first_shape = tensors[0].shape();
185
186 let mut new_shape = first_shape.to_vec();
188 let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
189 new_shape[dim] = concat_size;
190
191 if dim == 0 {
193 let mut all_data = Vec::new();
194 for tensor in tensors {
195 all_data.extend(tensor.to_vec());
196 }
197 return Tensor::from_vec(all_data, &new_shape).unwrap();
198 }
199
200 let mut all_data = Vec::new();
203 for tensor in tensors {
204 all_data.extend(tensor.to_vec());
205 }
206 Tensor::from_vec(all_data, &new_shape).unwrap()
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    // Stacking three 1-D tensors adds a leading batch axis of size 3
    // and lays the data out back to back.
    #[test]
    fn test_stack_tensors() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap();
        let t3 = Tensor::from_vec(vec![5.0, 6.0], &[2]).unwrap();

        let stacked = stack_tensors(&[t1, t2, t3]);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // Stacking preserves the full sample shape after the new batch axis.
    #[test]
    fn test_stack_tensors_2d() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let stacked = stack_tensors(&[t1, t2]);
        assert_eq!(stacked.shape(), &[2, 2, 2]);
    }

    // DefaultCollate stacks inputs and targets independently.
    #[test]
    fn test_default_collate() {
        let collate = DefaultCollate::new();

        let batch = vec![
            (
                Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
                Tensor::from_vec(vec![0.0], &[1]).unwrap(),
            ),
            (
                Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
                Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            ),
        ];

        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
    }

    // StackCollate::new() stacks along a new leading dimension.
    #[test]
    fn test_stack_collate() {
        let collate = StackCollate::new();

        let batch = vec![
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(),
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(),
        ];

        let result = collate.collate(batch);
        assert_eq!(result.shape(), &[2, 3]);
    }

    // An empty batch produces two empty 1-D tensors of shape [0].
    #[test]
    fn test_empty_collate() {
        let collate = DefaultCollate::new();
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    // Dim-0 concat sums the leading sizes; lengths may differ along dim 0.
    #[test]
    fn test_concat_tensors() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        let concat = concat_tensors(&[t1, t2], 0);
        assert_eq!(concat.shape(), &[5]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }
}