// rstsr_native_impl/cpu_serial/op_with_func.rs
use crate::prelude_dev::*;

// Threshold (in elements) for the innermost contiguous run: the kernels
// below take the contiguous fast path (a plain inner loop) only when the
// run is at least this long; otherwise each element is dispatched
// individually through the layout iterator.
const CONTIG_SWITCH: usize = 16;
10pub fn op_mutc_refa_refb_func_cpu_serial<TA, TB, TC, D>(
11 c: &mut [MaybeUninit<TC>],
12 lc: &Layout<D>,
13 a: &[TA],
14 la: &Layout<D>,
15 b: &[TB],
16 lb: &Layout<D>,
17 mut f: impl FnMut(&mut MaybeUninit<TC>, &TA, &TB),
18) -> Result<()>
19where
20 D: DimAPI,
21{
22 let layouts_full = translate_to_col_major(&[lc, la, lb], TensorIterOrder::K)?;
24 let layouts_full_ref = layouts_full.iter().collect_vec();
25 let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
26
27 if size_contig >= CONTIG_SWITCH {
29 let lc = &layouts_contig[0];
30 let la = &layouts_contig[1];
31 let lb = &layouts_contig[2];
32 layout_col_major_dim_dispatch_3(lc, la, lb, |(idx_c, idx_a, idx_b)| {
33 for i in 0..size_contig {
34 f(&mut c[idx_c + i], &a[idx_a + i], &b[idx_b + i]);
35 }
36 })
37 } else {
38 let lc = &layouts_full[0];
39 let la = &layouts_full[1];
40 let lb = &layouts_full[2];
41 layout_col_major_dim_dispatch_3(lc, la, lb, |(idx_c, idx_a, idx_b)| {
42 f(&mut c[idx_c], &a[idx_a], &b[idx_b]);
43 })
44 }
45}
46
47pub fn op_mutc_refa_numb_func_cpu_serial<TA, TB, TC, D>(
48 c: &mut [MaybeUninit<TC>],
49 lc: &Layout<D>,
50 a: &[TA],
51 la: &Layout<D>,
52 b: TB,
53 mut f: impl FnMut(&mut MaybeUninit<TC>, &TA, &TB),
54) -> Result<()>
55where
56 D: DimAPI,
57{
58 let layouts_full = translate_to_col_major(&[lc, la], TensorIterOrder::K)?;
60 let layouts_full_ref = layouts_full.iter().collect_vec();
61 let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
62
63 if size_contig >= CONTIG_SWITCH {
65 let lc = &layouts_contig[0];
66 let la = &layouts_contig[1];
67 layout_col_major_dim_dispatch_2(lc, la, |(idx_c, idx_a)| {
68 for i in 0..size_contig {
69 f(&mut c[idx_c + i], &a[idx_a + i], &b);
70 }
71 })
72 } else {
73 let lc = &layouts_full[0];
74 let la = &layouts_full[1];
75 layout_col_major_dim_dispatch_2(lc, la, |(idx_c, idx_a)| {
76 f(&mut c[idx_c], &a[idx_a], &b);
77 })
78 }
79}
80
81pub fn op_mutc_numa_refb_func_cpu_serial<TA, TB, TC, D>(
82 c: &mut [MaybeUninit<TC>],
83 lc: &Layout<D>,
84 a: TA,
85 b: &[TB],
86 lb: &Layout<D>,
87 mut f: impl FnMut(&mut MaybeUninit<TC>, &TA, &TB),
88) -> Result<()>
89where
90 D: DimAPI,
91{
92 let layouts_full = translate_to_col_major(&[lc, lb], TensorIterOrder::K)?;
94 let layouts_full_ref = layouts_full.iter().collect_vec();
95 let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
96
97 if size_contig >= CONTIG_SWITCH {
99 let lc = &layouts_contig[0];
100 let lb = &layouts_contig[1];
101 layout_col_major_dim_dispatch_2(lc, lb, |(idx_c, idx_b)| {
102 for i in 0..size_contig {
103 f(&mut c[idx_c + i], &a, &b[idx_b + i]);
104 }
105 })
106 } else {
107 let lc = &layouts_full[0];
108 let lb = &layouts_full[1];
109 layout_col_major_dim_dispatch_2(lc, lb, |(idx_c, idx_b)| {
110 f(&mut c[idx_c], &a, &b[idx_b]);
111 })
112 }
113}
114
115pub fn op_muta_refb_func_cpu_serial<TA, TB, D>(
116 a: &mut [MaybeUninit<TA>],
117 la: &Layout<D>,
118 b: &[TB],
119 lb: &Layout<D>,
120 mut f: impl FnMut(&mut MaybeUninit<TA>, &TB),
121) -> Result<()>
122where
123 D: DimAPI,
124{
125 let layouts_full = translate_to_col_major(&[la, lb], TensorIterOrder::K)?;
127 let layouts_full_ref = layouts_full.iter().collect_vec();
128 let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
129
130 if size_contig >= CONTIG_SWITCH {
132 let la = &layouts_contig[0];
133 let lb = &layouts_contig[1];
134 layout_col_major_dim_dispatch_2(la, lb, |(idx_a, idx_b)| {
135 for i in 0..size_contig {
136 f(&mut a[idx_a + i], &b[idx_b + i]);
137 }
138 })
139 } else {
140 let la = &layouts_full[0];
141 let lb = &layouts_full[1];
142 layout_col_major_dim_dispatch_2(la, lb, |(idx_a, idx_b)| {
143 f(&mut a[idx_a], &b[idx_b]);
144 })
145 }
146}
147
148pub fn op_muta_numb_func_cpu_serial<TA, TB, D>(
149 a: &mut [MaybeUninit<TA>],
150 la: &Layout<D>,
151 b: TB,
152 mut f: impl FnMut(&mut MaybeUninit<TA>, &TB),
153) -> Result<()>
154where
155 D: DimAPI,
156{
157 let layout = translate_to_col_major_unary(la, TensorIterOrder::G)?;
158 let (layout_contig, size_contig) = translate_to_col_major_with_contig(&[&layout]);
159
160 if size_contig >= CONTIG_SWITCH {
161 let la = &layout_contig[0];
162 layout_col_major_dim_dispatch_1(la, |idx_a| {
163 for i in 0..size_contig {
164 f(&mut a[idx_a + i], &b);
165 }
166 })
167 } else {
168 let la = &layout;
169 layout_col_major_dim_dispatch_1(la, |idx_a| {
170 f(&mut a[idx_a], &b);
171 })
172 }
173}
174
175pub fn op_muta_func_cpu_serial<T, D>(
176 a: &mut [MaybeUninit<T>],
177 la: &Layout<D>,
178 mut f: impl FnMut(&mut MaybeUninit<T>),
179) -> Result<()>
180where
181 D: DimAPI,
182{
183 let layout = translate_to_col_major_unary(la, TensorIterOrder::G)?;
184 let (layout_contig, size_contig) = translate_to_col_major_with_contig(&[&layout]);
185
186 if size_contig >= CONTIG_SWITCH {
187 let la = &layout_contig[0];
188 layout_col_major_dim_dispatch_1(la, |idx_a| {
189 for i in 0..size_contig {
190 f(&mut a[idx_a + i]);
191 }
192 })
193 } else {
194 let la = &layout;
195 layout_col_major_dim_dispatch_1(la, |idx_a| {
196 f(&mut a[idx_a]);
197 })
198 }
199}