// rstsr_native_impl/cpu_serial/op_with_func.rs

//! Basic math operations.
//!
//! This file assumes that layouts are pre-processed and valid.

use crate::prelude_dev::*;

// Threshold on the innermost contiguous run length: switch to a tight
// indexed inner loop only when the contiguous block holds at least this
// many elements; shorter runs use plain per-element layout iteration.
const CONTIG_SWITCH: usize = 16;

10pub fn op_mutc_refa_refb_func_cpu_serial<TA, TB, TC, D>(
11    c: &mut [MaybeUninit<TC>],
12    lc: &Layout<D>,
13    a: &[TA],
14    la: &Layout<D>,
15    b: &[TB],
16    lb: &Layout<D>,
17    mut f: impl FnMut(&mut MaybeUninit<TC>, &TA, &TB),
18) -> Result<()>
19where
20    D: DimAPI,
21{
22    // re-align layouts
23    let layouts_full = translate_to_col_major(&[lc, la, lb], TensorIterOrder::K)?;
24    let layouts_full_ref = layouts_full.iter().collect_vec();
25    let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
26
27    // contiguous iteration if possible, otherwise use iterator of layout
28    if size_contig >= CONTIG_SWITCH {
29        let lc = &layouts_contig[0];
30        let la = &layouts_contig[1];
31        let lb = &layouts_contig[2];
32        layout_col_major_dim_dispatch_3(lc, la, lb, |(idx_c, idx_a, idx_b)| {
33            for i in 0..size_contig {
34                f(&mut c[idx_c + i], &a[idx_a + i], &b[idx_b + i]);
35            }
36        })
37    } else {
38        let lc = &layouts_full[0];
39        let la = &layouts_full[1];
40        let lb = &layouts_full[2];
41        layout_col_major_dim_dispatch_3(lc, la, lb, |(idx_c, idx_a, idx_b)| {
42            f(&mut c[idx_c], &a[idx_a], &b[idx_b]);
43        })
44    }
45}
46
47pub fn op_mutc_refa_numb_func_cpu_serial<TA, TB, TC, D>(
48    c: &mut [MaybeUninit<TC>],
49    lc: &Layout<D>,
50    a: &[TA],
51    la: &Layout<D>,
52    b: TB,
53    mut f: impl FnMut(&mut MaybeUninit<TC>, &TA, &TB),
54) -> Result<()>
55where
56    D: DimAPI,
57{
58    // re-align layouts
59    let layouts_full = translate_to_col_major(&[lc, la], TensorIterOrder::K)?;
60    let layouts_full_ref = layouts_full.iter().collect_vec();
61    let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
62
63    // contiguous iteration if possible, otherwise use iterator of layout
64    if size_contig >= CONTIG_SWITCH {
65        let lc = &layouts_contig[0];
66        let la = &layouts_contig[1];
67        layout_col_major_dim_dispatch_2(lc, la, |(idx_c, idx_a)| {
68            for i in 0..size_contig {
69                f(&mut c[idx_c + i], &a[idx_a + i], &b);
70            }
71        })
72    } else {
73        let lc = &layouts_full[0];
74        let la = &layouts_full[1];
75        layout_col_major_dim_dispatch_2(lc, la, |(idx_c, idx_a)| {
76            f(&mut c[idx_c], &a[idx_a], &b);
77        })
78    }
79}
80
81pub fn op_mutc_numa_refb_func_cpu_serial<TA, TB, TC, D>(
82    c: &mut [MaybeUninit<TC>],
83    lc: &Layout<D>,
84    a: TA,
85    b: &[TB],
86    lb: &Layout<D>,
87    mut f: impl FnMut(&mut MaybeUninit<TC>, &TA, &TB),
88) -> Result<()>
89where
90    D: DimAPI,
91{
92    // re-align layouts
93    let layouts_full = translate_to_col_major(&[lc, lb], TensorIterOrder::K)?;
94    let layouts_full_ref = layouts_full.iter().collect_vec();
95    let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
96
97    // contiguous iteration if possible, otherwise use iterator of layout
98    if size_contig >= CONTIG_SWITCH {
99        let lc = &layouts_contig[0];
100        let lb = &layouts_contig[1];
101        layout_col_major_dim_dispatch_2(lc, lb, |(idx_c, idx_b)| {
102            for i in 0..size_contig {
103                f(&mut c[idx_c + i], &a, &b[idx_b + i]);
104            }
105        })
106    } else {
107        let lc = &layouts_full[0];
108        let lb = &layouts_full[1];
109        layout_col_major_dim_dispatch_2(lc, lb, |(idx_c, idx_b)| {
110            f(&mut c[idx_c], &a, &b[idx_b]);
111        })
112    }
113}
114
115pub fn op_muta_refb_func_cpu_serial<TA, TB, D>(
116    a: &mut [MaybeUninit<TA>],
117    la: &Layout<D>,
118    b: &[TB],
119    lb: &Layout<D>,
120    mut f: impl FnMut(&mut MaybeUninit<TA>, &TB),
121) -> Result<()>
122where
123    D: DimAPI,
124{
125    // re-align layouts
126    let layouts_full = translate_to_col_major(&[la, lb], TensorIterOrder::K)?;
127    let layouts_full_ref = layouts_full.iter().collect_vec();
128    let (layouts_contig, size_contig) = translate_to_col_major_with_contig(&layouts_full_ref);
129
130    // contiguous iteration if possible, otherwise use iterator of layout
131    if size_contig >= CONTIG_SWITCH {
132        let la = &layouts_contig[0];
133        let lb = &layouts_contig[1];
134        layout_col_major_dim_dispatch_2(la, lb, |(idx_a, idx_b)| {
135            for i in 0..size_contig {
136                f(&mut a[idx_a + i], &b[idx_b + i]);
137            }
138        })
139    } else {
140        let la = &layouts_full[0];
141        let lb = &layouts_full[1];
142        layout_col_major_dim_dispatch_2(la, lb, |(idx_a, idx_b)| {
143            f(&mut a[idx_a], &b[idx_b]);
144        })
145    }
146}
147
148pub fn op_muta_numb_func_cpu_serial<TA, TB, D>(
149    a: &mut [MaybeUninit<TA>],
150    la: &Layout<D>,
151    b: TB,
152    mut f: impl FnMut(&mut MaybeUninit<TA>, &TB),
153) -> Result<()>
154where
155    D: DimAPI,
156{
157    let layout = translate_to_col_major_unary(la, TensorIterOrder::G)?;
158    let (layout_contig, size_contig) = translate_to_col_major_with_contig(&[&layout]);
159
160    if size_contig >= CONTIG_SWITCH {
161        let la = &layout_contig[0];
162        layout_col_major_dim_dispatch_1(la, |idx_a| {
163            for i in 0..size_contig {
164                f(&mut a[idx_a + i], &b);
165            }
166        })
167    } else {
168        let la = &layout;
169        layout_col_major_dim_dispatch_1(la, |idx_a| {
170            f(&mut a[idx_a], &b);
171        })
172    }
173}
174
175pub fn op_muta_func_cpu_serial<T, D>(
176    a: &mut [MaybeUninit<T>],
177    la: &Layout<D>,
178    mut f: impl FnMut(&mut MaybeUninit<T>),
179) -> Result<()>
180where
181    D: DimAPI,
182{
183    let layout = translate_to_col_major_unary(la, TensorIterOrder::G)?;
184    let (layout_contig, size_contig) = translate_to_col_major_with_contig(&[&layout]);
185
186    if size_contig >= CONTIG_SWITCH {
187        let la = &layout_contig[0];
188        layout_col_major_dim_dispatch_1(la, |idx_a| {
189            for i in 0..size_contig {
190                f(&mut a[idx_a + i]);
191            }
192        })
193    } else {
194        let la = &layout;
195        layout_col_major_dim_dispatch_1(la, |idx_a| {
196            f(&mut a[idx_a]);
197        })
198    }
199}