use axonml_tensor::Tensor;

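/// Collates a batch of individual samples into a single batched value.
///
/// Implementors combine a `Vec` of dataset items (e.g. tensors or
/// `(input, target)` pairs) into one output, typically by stacking along a
/// new leading batch dimension.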
pub trait Collate<T>: Send + Sync {
    /// The batched value produced from a `Vec` of samples.
    type Output;

    /// Combines a batch of samples into a single output.
    fn collate(&self, batch: Vec<T>) -> Self::Output;
}

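/// The default collate strategy: stacks sample tensors (or `(input, target)`
/// pairs) along a new leading batch dimension.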
pub struct DefaultCollate;

impl DefaultCollate {
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for DefaultCollate {
    fn default() -> Self {
        Self::new()
    }
}

impl Collate<(Tensor<f32>, Tensor<f32>)> for DefaultCollate {
    type Output = (Tensor<f32>, Tensor<f32>);

    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
        if batch.is_empty() {
            return (
                Tensor::from_vec(vec![], &[0]).unwrap(),
                Tensor::from_vec(vec![], &[0]).unwrap(),
            );
        }

        // Stack the inputs along a new batch dimension.
        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
        let stacked_x = stack_tensors(&inputs);

        // Stack the targets the same way.
        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();
        let stacked_y = stack_tensors(&targets);

        (stacked_x, stacked_y)
    }
}

impl Collate<Tensor<f32>> for DefaultCollate {
    type Output = Tensor<f32>;

    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
        stack_tensors(&batch)
    }
}

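/// A collate strategy that stacks sample tensors along a batch dimension.
///
/// Note: stacking is currently only performed along the leading dimension;
/// a non-zero `dim` falls back to the same behaviour (see `collate`).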
pub struct StackCollate {
    /// The dimension along which samples are stacked.
    dim: usize,
}

impl StackCollate {
    /// Creates a collator that stacks along the leading dimension.
    #[must_use]
    pub fn new() -> Self {
        Self { dim: 0 }
    }

    /// Creates a collator that stacks along `dim`.
    #[must_use]
    pub fn with_dim(dim: usize) -> Self {
        Self { dim }
    }
}

impl Default for StackCollate {
    fn default() -> Self {
        Self::new()
    }
}

impl Collate<Tensor<f32>> for StackCollate {
    type Output = Tensor<f32>;

    fn collate(&self, batch: Vec<Tensor<f32>>) -> Self::Output {
        if self.dim == 0 {
            stack_tensors(&batch)
        } else {
            // Stacking along non-leading dimensions is not implemented yet;
            // fall back to stacking along the leading dimension.
            stack_tensors(&batch)
        }
    }
}

impl Collate<(Tensor<f32>, Tensor<f32>)> for StackCollate {
    type Output = (Tensor<f32>, Tensor<f32>);

    fn collate(&self, batch: Vec<(Tensor<f32>, Tensor<f32>)>) -> Self::Output {
        if batch.is_empty() {
            return (
                Tensor::from_vec(vec![], &[0]).unwrap(),
                Tensor::from_vec(vec![], &[0]).unwrap(),
            );
        }

        let inputs: Vec<Tensor<f32>> = batch.iter().map(|(x, _)| x.clone()).collect();
        let targets: Vec<Tensor<f32>> = batch.iter().map(|(_, y)| y.clone()).collect();

        (stack_tensors(&inputs), stack_tensors(&targets))
    }
}

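/// Stacks tensors along a new leading batch dimension.
///
/// Each input of shape `[d0, d1, ...]` contributes one slice of the output,
/// which has shape `[N, d0, d1, ...]` for `N` input tensors. All inputs are
/// expected to share the same shape (this is not validated). Returns an
/// empty tensor for an empty slice.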
#[must_use]
pub fn stack_tensors(tensors: &[Tensor<f32>]) -> Tensor<f32> {
    if tensors.is_empty() {
        return Tensor::from_vec(vec![], &[0]).unwrap();
    }

    let first_shape = tensors[0].shape();
    let batch_size = tensors.len();

    // Output shape: [batch_size, ...shape of a single sample].
    let mut new_shape = vec![batch_size];
    new_shape.extend_from_slice(first_shape);

    // Concatenate the raw data of every tensor in batch order.
    let mut all_data = Vec::new();
    for tensor in tensors {
        all_data.extend(tensor.to_vec());
    }

    Tensor::from_vec(all_data, &new_shape).unwrap()
}

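/// Concatenates tensors along an existing dimension `dim`.
///
/// The inputs are expected to share their shape outside of `dim` (this is
/// not validated) and to use a row-major data layout. Returns an empty
/// tensor for an empty slice.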
#[must_use]
pub fn concat_tensors(tensors: &[Tensor<f32>], dim: usize) -> Tensor<f32> {
    if tensors.is_empty() {
        return Tensor::from_vec(vec![], &[0]).unwrap();
    }

    if tensors.len() == 1 {
        return tensors[0].clone();
    }

    let first_shape = tensors[0].shape();

    // The output keeps every dimension except `dim`, which grows to the sum
    // of the individual sizes along that dimension.
    let mut new_shape = first_shape.to_vec();
    let concat_size: usize = tensors.iter().map(|t| t.shape()[dim]).sum();
    new_shape[dim] = concat_size;

    if dim == 0 {
        // Along the leading dimension the raw buffers can simply be appended.
        let mut all_data = Vec::new();
        for tensor in tensors {
            all_data.extend(tensor.to_vec());
        }
        return Tensor::from_vec(all_data, &new_shape).unwrap();
    }

    // For non-leading dimensions the row-major chunks of each tensor must be
    // interleaved per outer index (shapes are assumed to match outside `dim`).
    let outer: usize = first_shape[..dim].iter().product();
    let inner: usize = first_shape[dim + 1..].iter().product();
    let data: Vec<Vec<f32>> = tensors.iter().map(|t| t.to_vec()).collect();

    let mut all_data = Vec::with_capacity(new_shape.iter().product());
    for o in 0..outer {
        for (i, tensor) in tensors.iter().enumerate() {
            let chunk = tensor.shape()[dim] * inner;
            all_data.extend_from_slice(&data[i][o * chunk..(o + 1) * chunk]);
        }
    }
    Tensor::from_vec(all_data, &new_shape).unwrap()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stack_tensors() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap();
        let t3 = Tensor::from_vec(vec![5.0, 6.0], &[2]).unwrap();

        let stacked = stack_tensors(&[t1, t2, t3]);
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_tensors_2d() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let t2 = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();

        let stacked = stack_tensors(&[t1, t2]);
        assert_eq!(stacked.shape(), &[2, 2, 2]);
    }

    #[test]
    fn test_default_collate() {
        let collate = DefaultCollate::new();

        let batch = vec![
            (
                Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(),
                Tensor::from_vec(vec![0.0], &[1]).unwrap(),
            ),
            (
                Tensor::from_vec(vec![3.0, 4.0], &[2]).unwrap(),
                Tensor::from_vec(vec![1.0], &[1]).unwrap(),
            ),
        ];

        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(y.shape(), &[2, 1]);
    }

    #[test]
    fn test_stack_collate() {
        let collate = StackCollate::new();

        let batch = vec![
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap(),
            Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap(),
        ];

        let result = collate.collate(batch);
        assert_eq!(result.shape(), &[2, 3]);
    }

    #[test]
    fn test_empty_collate() {
        let collate = DefaultCollate::new();
        let batch: Vec<(Tensor<f32>, Tensor<f32>)> = vec![];
        let (x, y) = collate.collate(batch);
        assert_eq!(x.shape(), &[0]);
        assert_eq!(y.shape(), &[0]);
    }

    #[test]
    fn test_concat_tensors() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let t2 = Tensor::from_vec(vec![3.0, 4.0, 5.0], &[3]).unwrap();

        let concat = concat_tensors(&[t1, t2], 0);
        assert_eq!(concat.shape(), &[5]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    }
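
    // Exercises concat_tensors along a non-leading dimension; this assumes
    // the row-major interleaving described in concat_tensors above.
    #[test]
    fn test_concat_tensors_dim1() {
        let t1 = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let t2 = Tensor::from_vec(vec![5.0, 6.0], &[2, 1]).unwrap();

        let concat = concat_tensors(&[t1, t2], 1);
        assert_eq!(concat.shape(), &[2, 3]);
        assert_eq!(concat.to_vec(), vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
    }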
}