// candle_core/cpu_backend/utils.rs

//! Helper functions to write CPU kernels.
2use crate::backend::BackendStorage;
3use crate::{Error, Layout, Result, WithDType};
4
5type C = super::CpuStorage;
6pub trait Map1 {
7    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
8
9    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
10        match vs {
11            C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
12            C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
13            C::I16(vs) => Ok(C::I16(self.f(vs, layout)?)),
14            C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
15            C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
16            C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
17            C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
18            C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
19            C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
20            C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)),
21            // Dummy types don't support Map1 operations
22            C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
23            C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
24            C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
25            C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1").bt()),
26        }
27    }
28}
29
30pub trait Map1Any {
31    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
32
33    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
34        match vs {
35            C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
36            C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
37            C::I16(vs) => Ok(self.f(vs, layout, C::I16)?),
38            C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
39            C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
40            C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
41            C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
42            C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
43            C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
44            C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?),
45            // Dummy types don't support Map1Any operations
46            C::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
47            C::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
48            C::F4(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
49            C::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(vs.dtype(), "map1any").bt()),
50        }
51    }
52}
53
54pub trait Map2 {
55    const OP: &'static str;
56    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
57
58    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
59        match (v1, v2) {
60            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
61            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
62            (C::I16(v1), C::I16(v2)) => Ok(C::I16(self.f(v1, l1, v2, l2)?)),
63            (C::I32(v1), C::I32(v2)) => Ok(C::I32(self.f(v1, l1, v2, l2)?)),
64            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
65            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
66            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
67            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
68            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
69            (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)),
70            _ => Err(Error::DTypeMismatchBinaryOp {
71                lhs: v1.dtype(),
72                rhs: v2.dtype(),
73                op: Self::OP,
74            }
75            .bt()),
76        }
77    }
78}
79
80pub trait Map2InPlace {
81    const OP: &'static str;
82    fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
83
84    fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
85        match (v1, v2) {
86            (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
87            (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
88            (C::I16(v1), C::I16(v2)) => self.f(v1, l1, v2, l2)?,
89            (C::I32(v1), C::I32(v2)) => self.f(v1, l1, v2, l2)?,
90            (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
91            (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
92            (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
93            (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
94            (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
95            (C::F8E4M3(v1), C::F8E4M3(v2)) => self.f(v1, l1, v2, l2)?,
96            (v1, v2) => Err(Error::DTypeMismatchBinaryOp {
97                lhs: v1.dtype(),
98                rhs: v2.dtype(),
99                op: Self::OP,
100            }
101            .bt())?,
102        };
103        Ok(())
104    }
105}
106
107pub trait Map2U8 {
108    const OP: &'static str;
109    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
110
111    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
112        match (v1, v2) {
113            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
114            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
115            (C::I16(v1), C::I16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
116            (C::I32(v1), C::I32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
117            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
118            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
119            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
120            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
121            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
122            (C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
123            _ => Err(Error::DTypeMismatchBinaryOp {
124                lhs: v1.dtype(),
125                rhs: v2.dtype(),
126                op: Self::OP,
127            }
128            .bt()),
129        }
130    }
131}
132
133pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
134    lhs_l: &Layout,
135    rhs_l: &Layout,
136    lhs: &[T],
137    rhs: &[T],
138    mut f: F,
139) -> Vec<U> {
140    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
141        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
142            .iter()
143            .zip(rhs[o_r1..o_r2].iter())
144            .map(|(&l, &r)| f(l, r))
145            .collect(),
146        (Some((o_l1, o_l2)), None) => {
147            // TODO: Maybe we want to avoid going through the layout twice.
148            match rhs_l.offsets_b() {
149                Some(ob) => {
150                    let mut i_in_block = 0;
151                    let mut i_right_broadcast = 0;
152                    lhs[o_l1..o_l2]
153                        .iter()
154                        .map(|&l| {
155                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
156                            i_right_broadcast += 1;
157                            if i_right_broadcast >= ob.right_broadcast {
158                                i_in_block += 1;
159                                i_right_broadcast = 0;
160                            }
161                            if i_in_block >= ob.len {
162                                i_in_block = 0
163                            }
164                            f(l, *r)
165                        })
166                        .collect()
167                }
168                None => lhs_l
169                    .strided_index()
170                    .zip(rhs_l.strided_index())
171                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
172                    .collect(),
173            }
174        }
175        (None, Some((o_r1, o_r2))) => {
176            // TODO: Maybe we want to avoid going through the layout twice.
177            match lhs_l.offsets_b() {
178                Some(ob) => {
179                    let mut i_in_block = 0;
180                    let mut i_right_broadcast = 0;
181                    rhs[o_r1..o_r2]
182                        .iter()
183                        .map(|&r| {
184                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
185                            i_right_broadcast += 1;
186                            if i_right_broadcast >= ob.right_broadcast {
187                                i_in_block += 1;
188                                i_right_broadcast = 0;
189                            }
190                            if i_in_block >= ob.len {
191                                i_in_block = 0
192                            }
193                            f(*l, r)
194                        })
195                        .collect()
196                }
197                None => lhs_l
198                    .strided_index()
199                    .zip(rhs_l.strided_index())
200                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
201                    .collect(),
202            }
203        }
204        _ => lhs_l
205            .strided_index()
206            .zip(rhs_l.strided_index())
207            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
208            .collect(),
209    }
210}
211
212// Similar to binary_map but with vectorized variants.
213pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
214    lhs_l: &Layout,
215    rhs_l: &Layout,
216    lhs: &[T],
217    rhs: &[T],
218    mut f: F,
219    mut f_vec: FV,
220) -> Vec<T> {
221    let el_count = lhs_l.shape().elem_count();
222    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
223        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
224            let mut ys: Vec<T> = Vec::with_capacity(el_count);
225            let ys_to_set = ys.spare_capacity_mut();
226            let ys_to_set = unsafe {
227                std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
228            };
229            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
230            // SAFETY: values are all set by f_vec.
231            unsafe { ys.set_len(el_count) };
232            ys
233        }
234        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
235            Some(ob) if ob.right_broadcast == 1 => {
236                let rhs = &rhs[ob.start..ob.start + ob.len];
237                let mut ys: Vec<T> = Vec::with_capacity(el_count);
238                let ys_to_set = ys.spare_capacity_mut();
239                let ys_to_set = unsafe {
240                    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
241                };
242                let mut dst_i = 0;
243                for src_i in (o_l1..o_l2).step_by(ob.len) {
244                    f_vec(
245                        &lhs[src_i..src_i + ob.len],
246                        rhs,
247                        &mut ys_to_set[dst_i..dst_i + ob.len],
248                    );
249                    dst_i += ob.len;
250                }
251                // SAFETY: values are all set by f_vec.
252                unsafe { ys.set_len(el_count) };
253                ys
254            }
255            Some(ob) => {
256                let rhs = &rhs[ob.start..ob.start + ob.len];
257                let mut ys = lhs[o_l1..o_l2].to_vec();
258                for idx_l in 0..ob.left_broadcast {
259                    let start = idx_l * ob.len * ob.right_broadcast;
260                    for (i, &r) in rhs.iter().enumerate() {
261                        let start = start + i * ob.right_broadcast;
262                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
263                            *v = f(*v, r)
264                        }
265                    }
266                }
267                ys
268            }
269            None => lhs_l
270                .strided_index()
271                .zip(rhs_l.strided_index())
272                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
273                .collect(),
274        },
275        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
276            Some(ob) if ob.right_broadcast == 1 => {
277                let lhs = &lhs[ob.start..ob.start + ob.len];
278                let mut ys: Vec<T> = Vec::with_capacity(el_count);
279                let ys_to_set = ys.spare_capacity_mut();
280                let ys_to_set = unsafe {
281                    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
282                };
283                let mut dst_i = 0;
284                for src_i in (o_r1..o_r2).step_by(ob.len) {
285                    f_vec(
286                        lhs,
287                        &rhs[src_i..src_i + ob.len],
288                        &mut ys_to_set[dst_i..dst_i + ob.len],
289                    );
290                    dst_i += ob.len;
291                }
292                // SAFETY: values are all set by f_vec.
293                unsafe { ys.set_len(el_count) };
294                ys
295            }
296            Some(ob) => {
297                let lhs = &lhs[ob.start..ob.start + ob.len];
298                let mut ys = rhs[o_r1..o_r2].to_vec();
299                for idx_l in 0..ob.left_broadcast {
300                    let start = idx_l * ob.len * ob.right_broadcast;
301                    for (i, &l) in lhs.iter().enumerate() {
302                        let start = start + i * ob.right_broadcast;
303                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
304                            *v = f(l, *v)
305                        }
306                    }
307                }
308                ys
309            }
310            None => lhs_l
311                .strided_index()
312                .zip(rhs_l.strided_index())
313                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
314                .collect(),
315        },
316        _ => lhs_l
317            .strided_index()
318            .zip(rhs_l.strided_index())
319            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
320            .collect(),
321    }
322}
323
324pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
325    vs: &[T],
326    layout: &Layout,
327    mut f: F,
328) -> Vec<U> {
329    match layout.strided_blocks() {
330        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
331            [start_offset..start_offset + len]
332            .iter()
333            .map(|&v| f(v))
334            .collect(),
335        crate::StridedBlocks::MultipleBlocks {
336            block_start_index,
337            block_len,
338        } => {
339            let mut result = Vec::with_capacity(layout.shape().elem_count());
340            // Specialize the case where block_len is one to avoid the second loop.
341            if block_len == 1 {
342                for index in block_start_index {
343                    let v = unsafe { vs.get_unchecked(index) };
344                    result.push(f(*v))
345                }
346            } else {
347                for index in block_start_index {
348                    for offset in 0..block_len {
349                        let v = unsafe { vs.get_unchecked(index + offset) };
350                        result.push(f(*v))
351                    }
352                }
353            }
354            result
355        }
356    }
357}
358
359pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
360    vs: &[T],
361    layout: &Layout,
362    mut f: F,
363    mut f_vec: FV,
364) -> Vec<U> {
365    match layout.strided_blocks() {
366        crate::StridedBlocks::SingleBlock { start_offset, len } => {
367            let mut ys: Vec<U> = Vec::with_capacity(len);
368            let ys_to_set = ys.spare_capacity_mut();
369            let ys_to_set = unsafe {
370                std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
371            };
372            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
373            // SAFETY: values are all set by f_vec.
374            unsafe { ys.set_len(len) };
375            ys
376        }
377        crate::StridedBlocks::MultipleBlocks {
378            block_start_index,
379            block_len,
380        } => {
381            let el_count = layout.shape().elem_count();
382            // Specialize the case where block_len is one to avoid the second loop.
383            if block_len == 1 {
384                let mut result = Vec::with_capacity(el_count);
385                for index in block_start_index {
386                    let v = unsafe { vs.get_unchecked(index) };
387                    result.push(f(*v))
388                }
389                result
390            } else {
391                let mut ys: Vec<U> = Vec::with_capacity(el_count);
392                let ys_to_set = ys.spare_capacity_mut();
393                let ys_to_set = unsafe {
394                    std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
395                };
396                let mut dst_index = 0;
397                for src_index in block_start_index {
398                    let vs = &vs[src_index..src_index + block_len];
399                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
400                    f_vec(vs, ys);
401                    dst_index += block_len;
402                }
403                // SAFETY: values are all set by f_vec.
404                unsafe { ys.set_len(el_count) };
405                ys
406            }
407        }
408    }
409}