use crate::prelude_dev::*;
use core::mem::transmute;

/// Minimum length of a contiguous inner run for which the manually unrolled
/// reduction path is taken; shorter runs fall back to plain iteration.
const CONTIG_SWITCH: usize = 32;

/// Reduce a slice with 8-way manual unrolling.
///
/// Eight independent partial accumulators are folded with `f`, merged with
/// `f_sum`, and the at-most-7 leftover elements are folded into the result.
pub fn unrolled_reduce<TI, TS, I, F, FSum>(mut xs: &[TI], init: I, f: F, f_sum: FSum) -> TS
where
    TI: Clone,
    TS: Clone,
    I: Fn() -> TS,
    F: Fn(TS, TI) -> TS,
    FSum: Fn(TS, TS) -> TS,
{
    let mut acc = init();
    // eight independent partial accumulators, advanced one 8-element chunk at a time
    let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) =
        (init(), init(), init(), init(), init(), init(), init(), init());
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());

        xs = &xs[8..];
    }
    acc = f_sum(acc.clone(), f_sum(p0, p4));
    acc = f_sum(acc.clone(), f_sum(p1, p5));
    acc = f_sum(acc.clone(), f_sum(p2, p6));
    acc = f_sum(acc.clone(), f_sum(p3, p7));

    // fold the remaining (at most 7) elements directly into the accumulator
    for (i, x) in xs.iter().enumerate() {
        if i >= 7 {
            break;
        }
        acc = f(acc.clone(), x.clone())
    }
    acc
}

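// Illustrative usage sketch (added for documentation, not part of the original
// source): summing a plain slice. `init` supplies the additive identity, `f`
// folds one element into a partial accumulator, and `f_sum` merges partials.
#[test]
fn unrolled_reduce_sum_sketch() {
    let xs: Vec<i64> = (1..=100).collect();
    let sum = unrolled_reduce(&xs, || 0_i64, |acc, x| acc + x, |a, b| a + b);
    assert_eq!(sum, 5050);
}
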
/// Reduce two zipped slices with 8-way manual unrolling.
///
/// Binary counterpart of [`unrolled_reduce`]: `f` folds one pair of elements
/// into a partial accumulator, and `f_sum` merges the partials.
pub fn unrolled_binary_reduce<TI1, TI2, TS, I, F, FSum>(
    mut xs1: &[TI1],
    mut xs2: &[TI2],
    init: I,
    f: F,
    f_sum: FSum,
) -> TS
where
    TI1: Clone,
    TI2: Clone,
    TS: Clone,
    I: Fn() -> TS,
    F: Fn(TS, (TI1, TI2)) -> TS,
    FSum: Fn(TS, TS) -> TS,
{
    let mut acc = init();
    // eight independent partial accumulators, advanced one 8-element chunk at a time
    let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) =
        (init(), init(), init(), init(), init(), init(), init(), init());
    while xs1.len() >= 8 && xs2.len() >= 8 {
        p0 = f(p0, (xs1[0].clone(), xs2[0].clone()));
        p1 = f(p1, (xs1[1].clone(), xs2[1].clone()));
        p2 = f(p2, (xs1[2].clone(), xs2[2].clone()));
        p3 = f(p3, (xs1[3].clone(), xs2[3].clone()));
        p4 = f(p4, (xs1[4].clone(), xs2[4].clone()));
        p5 = f(p5, (xs1[5].clone(), xs2[5].clone()));
        p6 = f(p6, (xs1[6].clone(), xs2[6].clone()));
        p7 = f(p7, (xs1[7].clone(), xs2[7].clone()));

        xs1 = &xs1[8..];
        xs2 = &xs2[8..];
    }
    acc = f_sum(acc.clone(), f_sum(p0, p4));
    acc = f_sum(acc.clone(), f_sum(p1, p5));
    acc = f_sum(acc.clone(), f_sum(p2, p6));
    acc = f_sum(acc.clone(), f_sum(p3, p7));

    // fold the remaining (at most 7) element pairs directly into the accumulator
    for (i, (x1, x2)) in (xs1.iter().zip(xs2.iter())).enumerate() {
        if i >= 7 {
            break;
        }
        acc = f(acc.clone(), (x1.clone(), x2.clone()))
    }
    acc
}

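// Illustrative usage sketch (added for documentation, not part of the original
// source): an integer dot product expressed as a binary reduction, checked
// against a straightforward zip-fold.
#[test]
fn unrolled_binary_reduce_dot_sketch() {
    let xs: Vec<i64> = (0..64).collect();
    let ys: Vec<i64> = (0..64).rev().collect();
    let dot = unrolled_binary_reduce(&xs, &ys, || 0_i64, |acc, (x, y)| acc + x * y, |a, b| a + b);
    let reference: i64 = xs.iter().zip(ys.iter()).map(|(x, y)| x * y).sum();
    assert_eq!(dot, reference);
}
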
/// Reduce all elements of `a`, viewed through layout `la`, into a single value.
///
/// When the innermost contiguous run is at least [`CONTIG_SWITCH`] elements
/// long, each run is reduced with [`unrolled_reduce`] and the partial results
/// are merged with `f_sum`; otherwise elements are folded one by one in
/// col-major iteration order. The final accumulator is mapped to `TO` by `f_out`.
pub fn reduce_all_cpu_serial<TI, TS, TO, D, I, F, FSum, FOut>(
    a: &[TI],
    la: &Layout<D>,
    init: I,
    f: F,
    f_sum: FSum,
    f_out: FOut,
) -> Result<TO>
where
    TI: Clone,
    TS: Clone,
    D: DimAPI,
    I: Fn() -> TS,
    F: Fn(TS, TI) -> TS,
    FSum: Fn(TS, TS) -> TS,
    FOut: Fn(TS) -> TO,
{
    let layout = translate_to_col_major_unary(la, TensorIterOrder::K)?;
    let (layout_contig, size_contig) = translate_to_col_major_with_contig(&[&layout]);

    if size_contig >= CONTIG_SWITCH {
        // contiguous path: reduce each contiguous run with the unrolled kernel
        let mut acc = init();
        layout_col_major_dim_dispatch_1(&layout_contig[0], |idx_a| {
            let slc = &a[idx_a..idx_a + size_contig];
            let acc_inner = unrolled_reduce(slc, &init, &f, &f_sum);
            acc = f_sum(acc.clone(), acc_inner);
        })?;
        Ok(f_out(acc))
    } else {
        // strided path: fold element by element in col-major order
        let iter_a = IterLayoutColMajor::new(&layout)?;
        let acc = iter_a.fold(init(), |acc, idx| f(acc, a[idx].clone()));
        Ok(f_out(acc))
    }
}

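// Illustrative usage sketch (added for documentation, not part of the original
// source; assumes the `.c()` shape-to-layout constructor used elsewhere in this
// crate is available on `Vec<usize>` shapes): sum over a 6 x 8 row-major layout,
// large enough to exercise the unrolled contiguous path.
#[test]
fn reduce_all_cpu_serial_sum_sketch() {
    let a: Vec<i64> = (0..48).collect();
    let la = vec![6_usize, 8].c();
    let total = reduce_all_cpu_serial(&a, &la, || 0_i64, |acc, x| acc + x, |p, q| p + q, |s| s).unwrap();
    assert_eq!(total, (0..48).sum::<i64>());
}
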
/// Reduce `a` over the given `axes`, producing one value per element of the
/// remaining axes; the accumulator and output share the same type `TS`.
///
/// Returns the reduced data together with the layout of the output tensor.
pub fn reduce_axes_cpu_serial<TI, TS, I, F, FSum, FOut>(
    a: &[TI],
    la: &Layout<IxD>,
    axes: &[isize],
    init: I,
    f: F,
    f_sum: FSum,
    f_out: FOut,
) -> Result<(Vec<TS>, Layout<IxD>)>
where
    TI: Clone,
    TS: Clone,
    I: Fn() -> TS,
    F: Fn(TS, TI) -> TS,
    FSum: Fn(TS, TS) -> TS,
    FOut: Fn(TS) -> TS,
{
    // split the layout into the axes to be reduced and the remaining axes
    let (layout_axes, layout_rest) = la.dim_split_axes(axes)?;
    let layout_axes = translate_to_col_major_unary(&layout_axes, TensorIterOrder::default())?;

    // layout of the output tensor (remaining axes only)
    let layout_out = layout_for_array_copy(&layout_rest, TensorIterOrder::default())?;

    let (_, size_contig) = translate_to_col_major_with_contig(&[&layout_axes]);
    if size_contig >= CONTIG_SWITCH {
        // contiguous path: each output element is a full reduction over the reduced axes
        let layouts_swapped = translate_to_col_major(&[&layout_out, &layout_rest], TensorIterOrder::default())?;
        let layout_out_swapped = &layouts_swapped[0];
        let layout_rest_swapped = &layouts_swapped[1];

        let iter_out_swapped = IterLayoutRowMajor::new(layout_out_swapped)?;
        let iter_rest_swapped = IterLayoutRowMajor::new(layout_rest_swapped)?;

        let mut layout_inner = layout_axes.clone();

        let len_out = layout_out.size();
        let mut out: Vec<MaybeUninit<TS>> = unsafe { uninitialized_vec(len_out)? };

        izip!(iter_out_swapped, iter_rest_swapped).try_for_each(|(idx_out, idx_rest)| -> Result<()> {
            unsafe { layout_inner.set_offset(idx_rest) };
            let acc = reduce_all_cpu_serial(a, &layout_inner, &init, &f, &f_sum, &f_out)?;
            out[idx_out] = MaybeUninit::new(acc);
            Ok(())
        })?;
        let out = unsafe { transmute::<Vec<MaybeUninit<TS>>, Vec<TS>>(out) };
        Ok((out, layout_out))
    } else {
        // strided path: accumulate into the output buffer in place, one slice of
        // the reduced axes at a time, then finalize with `f_out`
        let iter_layout_axes = IterLayoutRowMajor::new(&layout_axes)?;

        let mut layout_inner = layout_rest.clone();

        let len_out = layout_out.size();
        let init_val = init();
        let out = vec![init_val; len_out];
        let mut out = unsafe { transmute::<Vec<TS>, Vec<MaybeUninit<TS>>>(out) };

        let f_add = |a: &mut MaybeUninit<TS>, b: &TI| unsafe {
            a.write(f(a.assume_init_read(), b.clone()));
        };

        for idx_axes in iter_layout_axes {
            unsafe { layout_inner.set_offset(idx_axes) };
            op_muta_refb_func_cpu_serial(&mut out, &layout_out, a, &layout_inner, f_add)?;
        }
        let fin_inplace = |a: &mut MaybeUninit<TS>| unsafe {
            a.write(f_out(a.assume_init_read()));
        };
        op_muta_func_cpu_serial(&mut out, &layout_out, fin_inplace)?;
        let out = unsafe { transmute::<Vec<MaybeUninit<TS>>, Vec<TS>>(out) };
        Ok((out, layout_out))
    }
}

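// Illustrative usage sketch (added for documentation, not part of the original
// source; assumes the `.c()` shape-to-layout constructor used elsewhere in this
// crate is available on `Vec<usize>` shapes): sum a 3 x 4 row-major tensor over
// its last axis, yielding one value per row.
#[test]
fn reduce_axes_cpu_serial_sum_sketch() {
    let a: Vec<i64> = (0..12).collect();
    let la = vec![3_usize, 4].c();
    let (out, layout_out) =
        reduce_axes_cpu_serial(&a, &la, &[1], || 0_i64, |acc, x| acc + x, |p, q| p + q, |s| s).unwrap();
    assert_eq!(layout_out.size(), 3);
    assert_eq!(out, vec![6, 22, 38]);
}
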
/// Same as [`reduce_axes_cpu_serial`], but the finalized output type `TO` may
/// differ from the accumulator type `TS`, so the in-place accumulation path
/// finalizes into a separate output buffer.
pub fn reduce_axes_difftype_cpu_serial<TI, TS, TO, I, F, FSum, FOut>(
    a: &[TI],
    la: &Layout<IxD>,
    axes: &[isize],
    init: I,
    f: F,
    f_sum: FSum,
    f_out: FOut,
) -> Result<(Vec<TO>, Layout<IxD>)>
where
    TI: Clone,
    TS: Clone,
    I: Fn() -> TS,
    F: Fn(TS, TI) -> TS,
    FSum: Fn(TS, TS) -> TS,
    FOut: Fn(TS) -> TO,
{
    // split the layout into the axes to be reduced and the remaining axes
    let (layout_axes, layout_rest) = la.dim_split_axes(axes)?;
    let layout_axes = translate_to_col_major_unary(&layout_axes, TensorIterOrder::default())?;

    // layout of the output tensor (remaining axes only)
    let layout_out = layout_for_array_copy(&layout_rest, TensorIterOrder::default())?;

    let (_, size_contig) = translate_to_col_major_with_contig(&[&layout_axes]);
    if size_contig >= CONTIG_SWITCH {
        // contiguous path: each output element is a full reduction over the reduced axes
        let layouts_swapped = translate_to_col_major(&[&layout_out, &layout_rest], TensorIterOrder::default())?;
        let layout_out_swapped = &layouts_swapped[0];
        let layout_rest_swapped = &layouts_swapped[1];

        let iter_out_swapped = IterLayoutRowMajor::new(layout_out_swapped)?;
        let iter_rest_swapped = IterLayoutRowMajor::new(layout_rest_swapped)?;

        let mut layout_inner = layout_axes.clone();

        let len_out = layout_out.size();
        let mut out: Vec<MaybeUninit<TO>> = unsafe { uninitialized_vec(len_out)? };

        izip!(iter_out_swapped, iter_rest_swapped).try_for_each(|(idx_out, idx_rest)| -> Result<()> {
            unsafe { layout_inner.set_offset(idx_rest) };
            let acc = reduce_all_cpu_serial(a, &layout_inner, &init, &f, &f_sum, &f_out)?;
            out[idx_out] = MaybeUninit::new(acc);
            Ok(())
        })?;
        let out = unsafe { transmute::<Vec<MaybeUninit<TO>>, Vec<TO>>(out) };
        Ok((out, layout_out))
    } else {
        // strided path: accumulate into a `TS` buffer in place, then convert to
        // `TO` with `f_out` into a separate buffer
        let iter_layout_axes = IterLayoutRowMajor::new(&layout_axes)?;

        let mut layout_inner = layout_rest.clone();

        let len_out = layout_out.size();
        let init_val = init();
        let out = vec![init_val; len_out];
        let mut out = unsafe { transmute::<Vec<TS>, Vec<MaybeUninit<TS>>>(out) };

        let f_add = |a: &mut MaybeUninit<TS>, b: &TI| unsafe {
            a.write(f(a.assume_init_read(), b.clone()));
        };

        for idx_axes in iter_layout_axes {
            unsafe { layout_inner.set_offset(idx_axes) };
            op_muta_refb_func_cpu_serial(&mut out, &layout_out, a, &layout_inner, f_add)?;
        }

        let mut out_converted = unsafe { uninitialized_vec(len_out)? };
        let f_out = |a: &mut MaybeUninit<TO>, b: &MaybeUninit<TS>| unsafe {
            a.write(f_out(b.assume_init_read()));
        };
        op_muta_refb_func_cpu_serial(&mut out_converted, &layout_out, &out, &layout_out, f_out)?;
        let out_converted = unsafe { transmute::<Vec<MaybeUninit<TO>>, Vec<TO>>(out_converted) };
        Ok((out_converted, layout_out))
    }
}

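// Illustrative usage sketch (added for documentation, not part of the original
// source; assumes the `.c()` shape-to-layout constructor used elsewhere in this
// crate is available on `Vec<usize>` shapes): a row-wise mean of an integer
// tensor, accumulated as `f64` and finalized by the division in `f_out`.
#[test]
fn reduce_axes_difftype_cpu_serial_mean_sketch() {
    let a: Vec<i32> = (0..12).collect();
    let la = vec![3_usize, 4].c();
    let (out, _layout_out) = reduce_axes_difftype_cpu_serial(
        &a,
        &la,
        &[1],
        || 0_f64,
        |acc, x| acc + x as f64,
        |p, q| p + q,
        |s| s / 4.0,
    )
    .unwrap();
    assert_eq!(out, vec![1.5, 5.5, 9.5]);
}
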
/// Reduce paired elements of `a` and `b` (e.g. an inner product) into a single
/// value, choosing between the unrolled contiguous path and plain col-major
/// iteration as in [`reduce_all_cpu_serial`].
pub fn reduce_all_binary_cpu_serial<TI1, TI2, TS, TO, D, I, F, FSum, FOut>(
    a: &[TI1],
    la: &Layout<D>,
    b: &[TI2],
    lb: &Layout<D>,
    init: I,
    f: F,
    f_sum: FSum,
    f_out: FOut,
) -> Result<TO>
where
    TI1: Clone,
    TI2: Clone,
    TS: Clone,
    D: DimAPI,
    I: Fn() -> TS,
    F: Fn(TS, (TI1, TI2)) -> TS,
    FSum: Fn(TS, TS) -> TS,
    FOut: Fn(TS) -> TO,
{
    let layouts_full = translate_to_col_major(&[la, lb], TensorIterOrder::K)?;
    let layouts_full_ref = layouts_full.iter().collect_vec();
    let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);

    if size_contig >= CONTIG_SWITCH {
        // contiguous path: reduce matching contiguous runs with the unrolled kernel
        let mut acc = init();
        let la = &layouts_contig[0];
        let lb = &layouts_contig[1];
        layout_col_major_dim_dispatch_2(la, lb, |(idx_a, idx_b)| {
            let slc_a = &a[idx_a..idx_a + size_contig];
            let slc_b = &b[idx_b..idx_b + size_contig];
            let acc_inner = unrolled_binary_reduce(slc_a, slc_b, &init, &f, &f_sum);
            acc = f_sum(acc.clone(), acc_inner);
        })?;
        Ok(f_out(acc))
    } else {
        // strided path: fold element pairs in col-major order
        let la = &layouts_full[0];
        let lb = &layouts_full[1];
        let iter_a = IterLayoutColMajor::new(la)?;
        let iter_b = IterLayoutColMajor::new(lb)?;
        let acc =
            izip!(iter_a, iter_b).fold(init(), |acc, (idx_a, idx_b)| f(acc, (a[idx_a].clone(), b[idx_b].clone())));
        Ok(f_out(acc))
    }
}

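// Illustrative usage sketch (added for documentation, not part of the original
// source; assumes the `.c()` shape-to-layout constructor used elsewhere in this
// crate is available on `Vec<usize>` shapes): an inner product of two 1-D
// tensors, checked against a straightforward zip-fold.
#[test]
fn reduce_all_binary_cpu_serial_dot_sketch() {
    let a: Vec<i64> = (0..40).collect();
    let b: Vec<i64> = (0..40).map(|i| i * i).collect();
    let la = vec![40_usize].c();
    let lb = vec![40_usize].c();
    let dot = reduce_all_binary_cpu_serial(
        &a,
        &la,
        &b,
        &lb,
        || 0_i64,
        |acc, (x, y)| acc + x * y,
        |p, q| p + q,
        |s| s,
    )
    .unwrap();
    let reference: i64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    assert_eq!(dot, reference);
}
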
/// Argument reduction (e.g. argmax/argmin) over all elements, returning the
/// unraveled (multi-dimensional) index of the selected element.
///
/// `f_comp` decides whether the current value should replace the accumulated
/// one; `f_eq` detects ties, which are resolved toward the smaller unraveled
/// index. Fails for empty layouts, or when the comparators never accept any
/// value (both returning `None` throughout).
pub fn reduce_all_unraveled_arg_cpu_serial<T, D, Fcomp, Feq>(
    a: &[T],
    la: &Layout<D>,
    f_comp: Fcomp,
    f_eq: Feq,
) -> Result<D>
where
    T: Clone,
    D: DimAPI,
    Fcomp: Fn(Option<T>, T) -> Option<bool>,
    Feq: Fn(Option<T>, T) -> Option<bool>,
{
    rstsr_assert!(la.size() > 0, InvalidLayout, "empty sequence is not allowed for reduce_arg.")?;

    let fold_func = |acc: Option<(D, T)>, (cur_idx, cur_offset): (D, usize)| {
        let cur_val = a[cur_offset].clone();

        let comp = f_comp(acc.as_ref().map(|(_, val)| val.clone()), cur_val.clone());
        if let Some(comp) = comp {
            if comp {
                // current value wins the comparison
                Some((cur_idx, cur_val))
            } else {
                // on equality, prefer the smaller unraveled index
                let comp_eq = f_eq(acc.as_ref().map(|(_, val)| val.clone()), cur_val.clone());
                if comp_eq.is_some_and(|x| x) {
                    if let Some(acc_idx) = acc.as_ref().map(|(idx, _)| idx.clone()) {
                        if cur_idx < acc_idx {
                            Some((cur_idx, cur_val))
                        } else {
                            acc
                        }
                    } else {
                        Some((cur_idx, cur_val))
                    }
                } else {
                    acc
                }
            }
        } else {
            // comparison is undecidable: keep the accumulator
            acc
        }
    };

    let iter_a = IndexedIterLayout::new(la, RowMajor)?;
    let acc = iter_a.into_iter().fold(None, fold_func);
    if acc.is_none() {
        rstsr_raise!(InvalidValue, "reduce_arg seems not returning a valid value.")?;
    }
    Ok(acc.unwrap().0)
}

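// Illustrative usage sketch (added for documentation, not part of the original
// source; assumes the `.c()` shape-to-layout constructor used elsewhere in this
// crate is available on `Vec<usize>` shapes): argmax over a 2 x 3 tensor,
// returning the unraveled index of the first maximum.
#[test]
fn reduce_all_unraveled_argmax_sketch() {
    let a = vec![1_i64, 7, 3, 7, 2, 5];
    let la = vec![2_usize, 3].c();
    let idx = reduce_all_unraveled_arg_cpu_serial(
        &a,
        &la,
        |acc, x| Some(acc.map_or(true, |v| x > v)),
        |acc, x| Some(acc.map_or(false, |v| x == v)),
    )
    .unwrap();
    // both maxima equal 7; the tie resolves to the smaller unraveled index [0, 1]
    assert_eq!(idx, vec![0, 1]);
}
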
/// Argument reduction over the given `axes`, returning one unraveled index
/// (within the reduced axes) per element of the remaining layout.
pub fn reduce_axes_unraveled_arg_cpu_serial<T, D, Fcomp, Feq>(
    a: &[T],
    la: &Layout<D>,
    axes: &[isize],
    f_comp: Fcomp,
    f_eq: Feq,
) -> Result<(Vec<IxD>, Layout<IxD>)>
where
    T: Clone,
    D: DimAPI,
    Fcomp: Fn(Option<T>, T) -> Option<bool>,
    Feq: Fn(Option<T>, T) -> Option<bool>,
{
    rstsr_assert!(la.size() > 0, InvalidLayout, "empty sequence is not allowed for reduce_arg.")?;

    // split the layout into the axes to be reduced and the remaining axes
    let (layout_axes, layout_rest) = la.dim_split_axes(axes)?;

    // layout of the output tensor (remaining axes only)
    let layout_out = layout_for_array_copy(&layout_rest, TensorIterOrder::default())?;

    let layouts_swapped = translate_to_col_major(&[&layout_out, &layout_rest], TensorIterOrder::default())?;
    let layout_out_swapped = &layouts_swapped[0];
    let layout_rest_swapped = &layouts_swapped[1];

    let iter_out_swapped = IterLayoutRowMajor::new(layout_out_swapped)?;
    let iter_rest_swapped = IterLayoutRowMajor::new(layout_rest_swapped)?;

    let mut layout_inner = layout_axes.clone();

    let len_out = layout_out.size();
    let mut out: Vec<MaybeUninit<IxD>> = unsafe { uninitialized_vec(len_out)? };

    // one full argument reduction over the reduced axes per output element
    izip!(iter_out_swapped, iter_rest_swapped).try_for_each(|(idx_out, idx_rest)| -> Result<()> {
        unsafe { layout_inner.set_offset(idx_rest) };
        let acc = reduce_all_unraveled_arg_cpu_serial(a, &layout_inner, &f_comp, &f_eq)?;
        out[idx_out] = MaybeUninit::new(acc);
        Ok(())
    })?;
    let out = unsafe { transmute::<Vec<MaybeUninit<IxD>>, Vec<IxD>>(out) };
    Ok((out, layout_out))
}

/// Argument reduction over all elements, returning the index flattened
/// according to `order` (row- or col-major) instead of an unraveled index.
pub fn reduce_all_arg_cpu_serial<T, D, Fcomp, Feq>(
    a: &[T],
    la: &Layout<D>,
    f_comp: Fcomp,
    f_eq: Feq,
    order: FlagOrder,
) -> Result<usize>
where
    T: Clone,
    D: DimAPI,
    Fcomp: Fn(Option<T>, T) -> Option<bool>,
    Feq: Fn(Option<T>, T) -> Option<bool>,
{
    let idx = reduce_all_unraveled_arg_cpu_serial(a, la, f_comp, f_eq)?;
    // flatten the unraveled index with a contiguous pseudo layout of the same shape
    let pseudo_shape = la.shape();
    let pseudo_layout = match order {
        RowMajor => pseudo_shape.c(),
        ColMajor => pseudo_shape.f(),
    };
    unsafe { Ok(pseudo_layout.index_uncheck(idx.as_ref()) as usize) }
}

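// Illustrative usage sketch (added for documentation, not part of the original
// source; assumes the `.c()` shape-to-layout constructor used elsewhere in this
// crate is available on `Vec<usize>` shapes): argmax flattened to a row-major
// linear index.
#[test]
fn reduce_all_arg_flat_argmax_sketch() {
    let a = vec![3_i64, 9, 4, 1, 9, 2];
    let la = vec![2_usize, 3].c();
    let pos = reduce_all_arg_cpu_serial(
        &a,
        &la,
        |acc, x| Some(acc.map_or(true, |v| x > v)),
        |acc, x| Some(acc.map_or(false, |v| x == v)),
        RowMajor,
    )
    .unwrap();
    // the first 9 sits at unraveled index [0, 1], i.e. flat position 1 in row-major order
    assert_eq!(pos, 1);
}
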
/// Argument reduction over the given `axes`, returning flattened indices
/// (within the reduced axes, in `order`) per element of the remaining layout.
pub fn reduce_axes_arg_cpu_serial<T, D, Fcomp, Feq>(
    a: &[T],
    la: &Layout<D>,
    axes: &[isize],
    f_comp: Fcomp,
    f_eq: Feq,
    order: FlagOrder,
) -> Result<(Vec<usize>, Layout<IxD>)>
where
    T: Clone,
    D: DimAPI,
    Fcomp: Fn(Option<T>, T) -> Option<bool>,
    Feq: Fn(Option<T>, T) -> Option<bool>,
{
    let (idx, layout) = reduce_axes_unraveled_arg_cpu_serial(a, la, axes, f_comp, f_eq)?;
    // flatten each unraveled index with a contiguous pseudo layout in the requested order
    let pseudo_shape = layout.shape();
    let pseudo_layout = match order {
        RowMajor => pseudo_shape.c(),
        ColMajor => pseudo_shape.f(),
    };
    let out = idx.into_iter().map(|x| unsafe { pseudo_layout.index_uncheck(x.as_ref()) as usize }).collect();
    Ok((out, layout))
}