// candle_core/cpu_backend/utils.rs

//! Helper functions to write CPU kernels.
2use crate::backend::BackendStorage;
3use crate::{Error, Layout, Result, WithDType};
4
/// Short local alias for the CPU storage enum whose variants the `Map*` traits below match on.
type C = super::CpuStorage;
6pub trait Map1 {
7    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
8
9    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
10        match vs {
11            C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
12            C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
13            C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
14            C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
15            C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
16            C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
17            C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
18        }
19    }
20}
21
22pub trait Map1Any {
23    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
24
25    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
26        match vs {
27            C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
28            C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
29            C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
30            C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
31            C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
32            C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
33            C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
34        }
35    }
36}
37
38pub trait Map2 {
39    const OP: &'static str;
40    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
41
42    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
43        match (v1, v2) {
44            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
45            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
46            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
47            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
48            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
49            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
50            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
51            _ => Err(Error::DTypeMismatchBinaryOp {
52                lhs: v1.dtype(),
53                rhs: v2.dtype(),
54                op: Self::OP,
55            }
56            .bt()),
57        }
58    }
59}
60
61pub trait Map2InPlace {
62    const OP: &'static str;
63    fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
64
65    fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
66        match (v1, v2) {
67            (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
68            (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
69            (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
70            (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
71            (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
72            (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
73            (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
74            (v1, v2) => Err(Error::DTypeMismatchBinaryOp {
75                lhs: v1.dtype(),
76                rhs: v2.dtype(),
77                op: Self::OP,
78            }
79            .bt())?,
80        };
81        Ok(())
82    }
83}
84
85pub trait Map2U8 {
86    const OP: &'static str;
87    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
88
89    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
90        match (v1, v2) {
91            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
92            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
93            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
94            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
95            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
96            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
97            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
98            _ => Err(Error::DTypeMismatchBinaryOp {
99                lhs: v1.dtype(),
100                rhs: v2.dtype(),
101                op: Self::OP,
102            }
103            .bt()),
104        }
105    }
106}
107
108pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
109    lhs_l: &Layout,
110    rhs_l: &Layout,
111    lhs: &[T],
112    rhs: &[T],
113    mut f: F,
114) -> Vec<U> {
115    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
116        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
117            .iter()
118            .zip(rhs[o_r1..o_r2].iter())
119            .map(|(&l, &r)| f(l, r))
120            .collect(),
121        (Some((o_l1, o_l2)), None) => {
122            // TODO: Maybe we want to avoid going through the layout twice.
123            match rhs_l.offsets_b() {
124                Some(ob) => {
125                    let mut i_in_block = 0;
126                    let mut i_right_broadcast = 0;
127                    lhs[o_l1..o_l2]
128                        .iter()
129                        .map(|&l| {
130                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
131                            i_right_broadcast += 1;
132                            if i_right_broadcast >= ob.right_broadcast {
133                                i_in_block += 1;
134                                i_right_broadcast = 0;
135                            }
136                            if i_in_block >= ob.len {
137                                i_in_block = 0
138                            }
139                            f(l, *r)
140                        })
141                        .collect()
142                }
143                None => lhs_l
144                    .strided_index()
145                    .zip(rhs_l.strided_index())
146                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
147                    .collect(),
148            }
149        }
150        (None, Some((o_r1, o_r2))) => {
151            // TODO: Maybe we want to avoid going through the layout twice.
152            match lhs_l.offsets_b() {
153                Some(ob) => {
154                    let mut i_in_block = 0;
155                    let mut i_right_broadcast = 0;
156                    rhs[o_r1..o_r2]
157                        .iter()
158                        .map(|&r| {
159                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
160                            i_right_broadcast += 1;
161                            if i_right_broadcast >= ob.right_broadcast {
162                                i_in_block += 1;
163                                i_right_broadcast = 0;
164                            }
165                            if i_in_block >= ob.len {
166                                i_in_block = 0
167                            }
168                            f(*l, r)
169                        })
170                        .collect()
171                }
172                None => lhs_l
173                    .strided_index()
174                    .zip(rhs_l.strided_index())
175                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
176                    .collect(),
177            }
178        }
179        _ => lhs_l
180            .strided_index()
181            .zip(rhs_l.strided_index())
182            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
183            .collect(),
184    }
185}
186
187// Similar to binary_map but with vectorized variants.
188pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
189    lhs_l: &Layout,
190    rhs_l: &Layout,
191    lhs: &[T],
192    rhs: &[T],
193    mut f: F,
194    mut f_vec: FV,
195) -> Vec<T> {
196    let el_count = lhs_l.shape().elem_count();
197    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
198        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
199            let mut ys: Vec<T> = Vec::with_capacity(el_count);
200            let ys_to_set = ys.spare_capacity_mut();
201            let ys_to_set = unsafe {
202                std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
203            };
204            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
205            // SAFETY: values are all set by f_vec.
206            unsafe { ys.set_len(el_count) };
207            ys
208        }
209        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
210            Some(ob) if ob.right_broadcast == 1 => {
211                let rhs = &rhs[ob.start..ob.start + ob.len];
212                let mut ys: Vec<T> = Vec::with_capacity(el_count);
213                let ys_to_set = ys.spare_capacity_mut();
214                let ys_to_set = unsafe {
215                    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
216                };
217                let mut dst_i = 0;
218                for src_i in (o_l1..o_l2).step_by(ob.len) {
219                    f_vec(
220                        &lhs[src_i..src_i + ob.len],
221                        rhs,
222                        &mut ys_to_set[dst_i..dst_i + ob.len],
223                    );
224                    dst_i += ob.len;
225                }
226                // SAFETY: values are all set by f_vec.
227                unsafe { ys.set_len(el_count) };
228                ys
229            }
230            Some(ob) => {
231                let rhs = &rhs[ob.start..ob.start + ob.len];
232                let mut ys = lhs[o_l1..o_l2].to_vec();
233                for idx_l in 0..ob.left_broadcast {
234                    let start = idx_l * ob.len * ob.right_broadcast;
235                    for (i, &r) in rhs.iter().enumerate() {
236                        let start = start + i * ob.right_broadcast;
237                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
238                            *v = f(*v, r)
239                        }
240                    }
241                }
242                ys
243            }
244            None => lhs_l
245                .strided_index()
246                .zip(rhs_l.strided_index())
247                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
248                .collect(),
249        },
250        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
251            Some(ob) if ob.right_broadcast == 1 => {
252                let lhs = &lhs[ob.start..ob.start + ob.len];
253                let mut ys: Vec<T> = Vec::with_capacity(el_count);
254                let ys_to_set = ys.spare_capacity_mut();
255                let ys_to_set = unsafe {
256                    std::mem::transmute::<&mut [std::mem::MaybeUninit<T>], &mut [T]>(ys_to_set)
257                };
258                let mut dst_i = 0;
259                for src_i in (o_r1..o_r2).step_by(ob.len) {
260                    f_vec(
261                        lhs,
262                        &rhs[src_i..src_i + ob.len],
263                        &mut ys_to_set[dst_i..dst_i + ob.len],
264                    );
265                    dst_i += ob.len;
266                }
267                // SAFETY: values are all set by f_vec.
268                unsafe { ys.set_len(el_count) };
269                ys
270            }
271            Some(ob) => {
272                let lhs = &lhs[ob.start..ob.start + ob.len];
273                let mut ys = rhs[o_r1..o_r2].to_vec();
274                for idx_l in 0..ob.left_broadcast {
275                    let start = idx_l * ob.len * ob.right_broadcast;
276                    for (i, &l) in lhs.iter().enumerate() {
277                        let start = start + i * ob.right_broadcast;
278                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
279                            *v = f(l, *v)
280                        }
281                    }
282                }
283                ys
284            }
285            None => lhs_l
286                .strided_index()
287                .zip(rhs_l.strided_index())
288                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
289                .collect(),
290        },
291        _ => lhs_l
292            .strided_index()
293            .zip(rhs_l.strided_index())
294            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
295            .collect(),
296    }
297}
298
299pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
300    vs: &[T],
301    layout: &Layout,
302    mut f: F,
303) -> Vec<U> {
304    match layout.strided_blocks() {
305        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
306            [start_offset..start_offset + len]
307            .iter()
308            .map(|&v| f(v))
309            .collect(),
310        crate::StridedBlocks::MultipleBlocks {
311            block_start_index,
312            block_len,
313        } => {
314            let mut result = Vec::with_capacity(layout.shape().elem_count());
315            // Specialize the case where block_len is one to avoid the second loop.
316            if block_len == 1 {
317                for index in block_start_index {
318                    let v = unsafe { vs.get_unchecked(index) };
319                    result.push(f(*v))
320                }
321            } else {
322                for index in block_start_index {
323                    for offset in 0..block_len {
324                        let v = unsafe { vs.get_unchecked(index + offset) };
325                        result.push(f(*v))
326                    }
327                }
328            }
329            result
330        }
331    }
332}
333
334pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
335    vs: &[T],
336    layout: &Layout,
337    mut f: F,
338    mut f_vec: FV,
339) -> Vec<U> {
340    match layout.strided_blocks() {
341        crate::StridedBlocks::SingleBlock { start_offset, len } => {
342            let mut ys: Vec<U> = Vec::with_capacity(len);
343            let ys_to_set = ys.spare_capacity_mut();
344            let ys_to_set = unsafe {
345                std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
346            };
347            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
348            // SAFETY: values are all set by f_vec.
349            unsafe { ys.set_len(len) };
350            ys
351        }
352        crate::StridedBlocks::MultipleBlocks {
353            block_start_index,
354            block_len,
355        } => {
356            let el_count = layout.shape().elem_count();
357            // Specialize the case where block_len is one to avoid the second loop.
358            if block_len == 1 {
359                let mut result = Vec::with_capacity(el_count);
360                for index in block_start_index {
361                    let v = unsafe { vs.get_unchecked(index) };
362                    result.push(f(*v))
363                }
364                result
365            } else {
366                let mut ys: Vec<U> = Vec::with_capacity(el_count);
367                let ys_to_set = ys.spare_capacity_mut();
368                let ys_to_set = unsafe {
369                    std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(ys_to_set)
370                };
371                let mut dst_index = 0;
372                for src_index in block_start_index {
373                    let vs = &vs[src_index..src_index + block_len];
374                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
375                    f_vec(vs, ys);
376                    dst_index += block_len;
377                }
378                // SAFETY: values are all set by f_vec.
379                unsafe { ys.set_len(el_count) };
380                ys
381            }
382        }
383    }
384}