tract_data/
scatter.rs

1use crate::prelude::*;
2use ndarray::Dimension;
3
4pub(crate) unsafe fn scatter_contig_data<T: Datum>(
5    mut src: *const T,
6    dst: *mut T,
7    dst_len_and_strides: &[(usize, usize)],
8) {
9    unsafe {
10        match *dst_len_and_strides {
11            [(len_a, stride_a)] => {
12                for a in 0..len_a {
13                    *dst.add(a * stride_a) = (*src).clone();
14                    src = src.offset(1);
15                }
16            }
17            [(len_a, stride_a), (len_b, stride_b)] => {
18                for a in 0..len_a {
19                    for b in 0..len_b {
20                        *dst.add(a * stride_a + b * stride_b) = (*src).clone();
21                        src = src.offset(1);
22                    }
23                }
24            }
25            [(len_a, stride_a), (len_b, stride_b), (len_c, stride_c)] => {
26                for a in 0..len_a {
27                    for b in 0..len_b {
28                        for c in 0..len_c {
29                            *dst.add(a * stride_a + b * stride_b + c * stride_c) = (*src).clone();
30                            src = src.offset(1);
31                        }
32                    }
33                }
34            }
35            [(len_a, stride_a), (len_b, stride_b), (len_c, stride_c), (len_d, stride_d)] => {
36                for a in 0..len_a {
37                    for b in 0..len_b {
38                        for c in 0..len_c {
39                            for d in 0..len_d {
40                                *dst.add(
41                                    a * stride_a + b * stride_b + c * stride_c + d * stride_d,
42                                ) = (*src).clone();
43                                src = src.offset(1);
44                            }
45                        }
46                    }
47                }
48            }
49            [(len_a, stride_a), (len_b, stride_b), (len_c, stride_c), (len_d, stride_d), (len_e, stride_e)] => {
50                for a in 0..len_a {
51                    for b in 0..len_b {
52                        for c in 0..len_c {
53                            for d in 0..len_d {
54                                for e in 0..len_e {
55                                    *dst.add(
56                                        a * stride_a
57                                            + b * stride_b
58                                            + c * stride_c
59                                            + d * stride_d
60                                            + e * stride_e,
61                                    ) = (*src).clone();
62                                    src = src.offset(1);
63                                }
64                            }
65                        }
66                    }
67                }
68            }
69            _ => {
70                let shape: TVec<usize> = dst_len_and_strides.iter().map(|pair| pair.0).collect();
71                for coords in ndarray::indices(&*shape) {
72                    let offset = coords
73                        .slice()
74                        .iter()
75                        .zip(dst_len_and_strides.iter())
76                        .map(|(x, (_len, stride))| x * stride)
77                        .sum::<usize>();
78                    *dst.add(offset) = (*src).clone();
79                    src = src.offset(1);
80                }
81            }
82        }
83    }
84}