candle_core_temp/
cpu_backend.rs

1use crate::backend::{BackendDevice, BackendStorage};
2use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
3use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
4use half::{bf16, f16};
5use rayon::prelude::*;
6
// Toggles for lowering convolutions to an im2col + matmul formulation, which trades extra
// memory for a matmul-friendly layout. NOTE(review): the conv kernels reading these flags are
// outside this chunk — confirm against the conv1d/conv2d implementations.
const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV2D: bool = true;
9
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
/// Backing storage for a tensor held on the CPU: one contiguous `Vec` per supported dtype.
/// The variant determines the dtype; layout/striding information lives in `Layout`, not here.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}
22
/// Marker type representing the CPU device; carries no state.
#[derive(Debug, Clone)]
pub struct CpuDevice;
25
/// Helper trait for unary, dtype-preserving operations: implement `f` once, generically over
/// the element type, and `map` dispatches it to whichever dtype the storage actually holds.
pub trait Map1 {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;

    /// Unwrap the storage, run `f` on the typed slice, and re-wrap in the same dtype variant.
    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
            CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
        }
    }
}
41
/// Like [`Map1`] but `f` controls the output storage itself: calling `wrap` re-wraps the
/// result in the input's dtype, while `f` may instead build a `CpuStorage` of a different
/// dtype (e.g. `U32` index output for argmin/argmax, see `ReduceIndex` below).
pub trait Map1Any {
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        vs: &[T],
        layout: &Layout,
        wrap: W,
    ) -> Result<CpuStorage>;

    /// Dispatch on the input dtype, passing the matching variant constructor as `wrap`.
    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
        match vs {
            CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
            CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
            CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
            CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
            CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
            CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
            CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
        }
    }
}
62
// Short alias to keep the dtype dispatch tables below readable.
type C = CpuStorage;
/// Helper trait for binary, dtype-preserving operations over two storages of the same dtype.
/// A dtype mismatch is reported as `DTypeMismatchBinaryOp` tagged with `Self::OP`.
pub trait Map2 {
    /// Operation name used in dtype-mismatch error messages.
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;

    /// Dispatch on matching dtype variants and re-wrap the result in the same dtype.
    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}
92
/// Like [`Map2`] but the result storage is always `U8` whatever the input dtype — used for
/// predicate-style operations such as the comparisons (see `Cmp` below).
pub trait Map2U8 {
    /// Operation name used in dtype-mismatch error messages.
    const OP: &'static str;
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;

    /// Dispatch on matching dtype variants; the output is always wrapped as `U8`.
    fn map(
        &self,
        v1: &CpuStorage,
        l1: &Layout,
        v2: &CpuStorage,
        l2: &Layout,
    ) -> Result<CpuStorage> {
        match (v1, v2) {
            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
            _ => Err(Error::DTypeMismatchBinaryOp {
                lhs: v1.dtype(),
                rhs: v2.dtype(),
                op: Self::OP,
            }
            .bt()),
        }
    }
}
121
122struct Cmp(CmpOp);
123impl Map2U8 for Cmp {
124    const OP: &'static str = "cmp";
125    #[inline(always)]
126    fn f<T: WithDType>(
127        &self,
128        lhs: &[T],
129        lhs_l: &Layout,
130        rhs: &[T],
131        rhs_l: &Layout,
132    ) -> Result<Vec<u8>> {
133        let dst = match self.0 {
134            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
135            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
136            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
137            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
138            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
139            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
140        };
141        Ok(dst)
142    }
143}
144
145struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
146
147impl<'a, I: IntDType> Map2 for WCond<'a, I> {
148    const OP: &'static str = "where";
149    #[inline(always)]
150    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
151        let vs = match (
152            self.1.contiguous_offsets(),
153            t_l.contiguous_offsets(),
154            f_l.contiguous_offsets(),
155        ) {
156            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
157                let pred = &self.0[o1..o2];
158                let t = &t[o_t1..o_t2];
159                let f = &f[o_f1..o_f2];
160                pred.iter()
161                    .zip(t.iter().zip(f.iter()))
162                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
163                    .collect::<Vec<_>>()
164            }
165            _ => self
166                .1
167                .strided_index()
168                .zip(t_l.strided_index().zip(f_l.strided_index()))
169                .map(|(i_p, (i_t, i_f))| {
170                    if self.0[i_p].is_true() {
171                        t[i_t]
172                    } else {
173                        f[i_f]
174                    }
175                })
176                .collect::<Vec<_>>(),
177        };
178        Ok(vs)
179    }
180}
181
/// Configuration for min/max/argmin/argmax-style reductions along a single dimension.
struct ReduceIndex {
    // Dimension index (into the source shape) that gets reduced away.
    reduce_dim_index: usize,
    // When true, keep the smallest value (min/argmin); otherwise the largest (max/argmax).
    use_min: bool,
    // When true, output the winning element's index within the dim (as u32), not its value.
    return_index: bool,
}
187
impl ReduceIndex {
    // The value gets replaced if f(s[current_acc], s[i]) returns true.
    // `g` maps the winning (value, index-within-the-reduced-dim) pair to an output element,
    // so the same loop implements min/max (g keeps the value) and argmin/argmax (g keeps the
    // index). Ties keep the first occurrence since `f` must be strict for replacement.
    #[inline(always)]
    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
    where
        T: Clone + Copy,
        U: Clone + Copy,
        F: Fn(T, T) -> bool,
        G: Fn(T, usize) -> U,
    {
        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
        // One output element per slice taken along the reduced dimension.
        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
        let dst_to_set = dst.spare_capacity_mut();
        // NOTE(review): the spare capacity is uninitialized; viewing it as `&mut [U]` relies
        // on every element being written below before `set_len`, and on U: Copy (no drops).
        let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                if reduce_dim_stride == 1 {
                    // Fast path: the reduced dim is innermost, so each output element scans a
                    // contiguous chunk of `reduce_dim_size` source elements.
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let start_src_i = start_src_i * reduce_dim_size;
                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
                        let mut acc = 0;
                        let mut val = src[0];
                        for (src_i, &s) in src.iter().enumerate() {
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                } else {
                    // Contiguous storage but the reduced dim is not innermost: recover the
                    // slice's base offset from the flat output index, then step by the dim's
                    // stride inside the slice.
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let (p, q) = (
                            start_src_i / reduce_dim_stride,
                            start_src_i % reduce_dim_stride,
                        );
                        // start_src_i = p * reduce_dim_stride + q
                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
                        let src = &src[start_src_i..];
                        let mut acc = 0;
                        let mut val = src[0];
                        for src_i in 0..reduce_dim_size {
                            let s = src[src_i * reduce_dim_stride];
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                }
            }
            None => {
                // Non-contiguous input: narrow the layout to a single slot along the reduced
                // dim; each strided index it yields is the start offset of one slice to scan.
                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
                for (unstr_index, src_index) in l.strided_index().enumerate() {
                    let src = &src[src_index..];
                    let mut acc = 0;
                    let mut val = src[0];
                    for src_i in 0..reduce_dim_size {
                        let s = src[src_i * reduce_dim_stride];
                        if f(val, s) {
                            acc = src_i;
                            val = s
                        }
                    }
                    dst_to_set[unstr_index] = g(val, acc)
                }
            }
        }
        // SAFETY: all dst_len elements were initialized by one of the branches above.
        unsafe { dst.set_len(dst_len) };
        Ok(dst)
    }
}
264
265impl Map1Any for ReduceIndex {
266    #[inline(always)]
267    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
268        &self,
269        src: &[T],
270        src_l: &Layout,
271        wrap: W,
272    ) -> Result<CpuStorage> {
273        if src_l.shape().elem_count() == 0 {
274            Err(Error::EmptyTensor { op: "reduce" }.bt())?
275        }
276        let dst = match (self.return_index, self.use_min) {
277            (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
278            (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
279            (true, true) => {
280                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
281            }
282            (true, false) => {
283                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
284            }
285        };
286        Ok(dst)
287    }
288}
289
/// Configuration for a sum reduction over one or more dimensions.
struct ReduceSum<'a> {
    // Shape of the output (reduced dims kept with size 1, judging by the divmod collapse in
    // fold_impl — TODO(review) confirm against the caller, which is outside this chunk).
    dst_shape: &'a Shape,
    // Sorted list of dimension indices being reduced.
    reduce_dims: &'a [usize],
    // (dim_size, stride-in-dst) pair for each reduced dimension.
    reduce_dims_and_stride: Vec<(usize, usize)>,
}
295
impl<'a> ReduceSum<'a> {
    // Sum `src` over the dims listed in `reduce_dims`, accumulating into a buffer shaped as
    // `dst_shape` and seeded with `start_elt`.
    #[inline(always)]
    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
    where
        T: WithDType,
    {
        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                // Handle the case where we reduce over the last dimensions separately as it is
                // fairly common and easy to optimize. This rely on the layout being contiguous!
                // reduce_dims is sorted, check if it is ranging from a to n-1.
                let reduce_over_last_dims = self
                    .reduce_dims
                    .iter()
                    .rev()
                    .enumerate()
                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
                if reduce_over_last_dims {
                    // Each output element then sums a contiguous run of `reduce_sz` inputs,
                    // which lets us call the dtype's (possibly SIMD) vectorized reduction.
                    let reduce_sz = self
                        .reduce_dims_and_stride
                        .iter()
                        .map(|(u, _)| u)
                        .product::<usize>();
                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
                        let src_i = dst_i * reduce_sz;
                        // SAFETY: the slice taken below proves src_i..src_i + reduce_sz is in
                        // bounds; vec_reduce_sum's own contract is defined elsewhere —
                        // NOTE(review): confirm it reads exactly `reduce_sz` elements.
                        unsafe {
                            T::vec_reduce_sum(
                                src[src_i..src_i + reduce_sz].as_ptr(),
                                dst_v,
                                reduce_sz,
                            )
                        };
                    }
                    return Ok(dst);
                };
                for (unstr_index, &src) in src.iter().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following in a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src;
                }
            }
            None => {
                // Non-contiguous input: same index collapse as above, but source elements are
                // fetched through the layout's strided index.
                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following in a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src[src_index];
                }
            }
        }
        Ok(dst)
    }
}
360
impl<'a> Map1 for ReduceSum<'a> {
    /// Sum reduction entry point: starts the fold from the additive identity of `T`.
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        self.fold_impl(src, src_l, T::zero())
    }
}
367
368pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
369    vs: &[T],
370    layout: &Layout,
371    mut f: F,
372) -> Vec<U> {
373    match layout.strided_blocks() {
374        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
375            [start_offset..start_offset + len]
376            .iter()
377            .map(|&v| f(v))
378            .collect(),
379        crate::StridedBlocks::MultipleBlocks {
380            block_start_index,
381            block_len,
382        } => {
383            let mut result = Vec::with_capacity(layout.shape().elem_count());
384            // Specialize the case where block_len is one to avoid the second loop.
385            if block_len == 1 {
386                for index in block_start_index {
387                    let v = unsafe { vs.get_unchecked(index) };
388                    result.push(f(*v))
389                }
390            } else {
391                for index in block_start_index {
392                    for offset in 0..block_len {
393                        let v = unsafe { vs.get_unchecked(index + offset) };
394                        result.push(f(*v))
395                    }
396                }
397            }
398            result
399        }
400    }
401}
402
/// Like [`unary_map`] but with a vectorized variant: `f_vec` fills a whole output slice from
/// a whole input slice, while `f` is the scalar fallback used when no slice is available.
pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
    vs: &[T],
    layout: &Layout,
    mut f: F,
    mut f_vec: FV,
) -> Vec<U> {
    match layout.strided_blocks() {
        // Contiguous input: a single f_vec call over the full range.
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let mut ys: Vec<U> = Vec::with_capacity(len);
            let ys_to_set = ys.spare_capacity_mut();
            // NOTE(review): uninitialized spare capacity viewed as `&mut [U]`; relies on
            // f_vec writing every element before set_len (U: Copy, so nothing is dropped).
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(len) };
            ys
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let el_count = layout.shape().elem_count();
            // Specialize the case where block_len is one to avoid the second loop.
            if block_len == 1 {
                for index in block_start_index {
                    // SAFETY: indices produced by the layout are in bounds for `vs`.
                    let v = unsafe { vs.get_unchecked(index) };
                    result.push(f(*v))
                }
                result
            } else {
                // Blocks longer than one element: run f_vec block by block into the output.
                let mut ys: Vec<U> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                // NOTE(review): same uninitialized-capacity pattern as the SingleBlock arm.
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
                let mut dst_index = 0;
                for src_index in block_start_index {
                    let vs = &vs[src_index..src_index + block_len];
                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
                    f_vec(vs, ys);
                    dst_index += block_len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
        }
    }
}
450
// This function maps over two strided index sequences.
/// Apply `f` element-wise over `lhs` and `rhs` as described by their layouts. Fully
/// contiguous pairs and "contiguous vs broadcast block" pairs get fast paths; anything else
/// falls back to walking both strided index iterators in lockstep.
pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<U> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        // Both contiguous: straight element-wise zip.
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        (Some((o_l1, o_l2)), None) => {
            // TODO: Maybe we want to avoid going through the layout twice.
            match rhs_l.offsets_b() {
                Some(ob) => {
                    // rhs has broadcast-block structure: each of its `ob.len` elements is
                    // repeated `ob.right_broadcast` times and the block wraps around, so we
                    // maintain both counters while walking the contiguous lhs.
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    lhs[o_l1..o_l2]
                        .iter()
                        .map(|&l| {
                            // SAFETY: relies on the layout's broadcast descriptor keeping
                            // ob.start + i_in_block in bounds for rhs.
                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(l, *r)
                        })
                        .collect()
                }
                // No recognizable broadcast structure: generic strided walk.
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        (None, Some((o_r1, o_r2))) => {
            // Mirror of the arm above: lhs is the broadcast side, rhs is contiguous.
            // TODO: Maybe we want to avoid going through the layout twice.
            match lhs_l.offsets_b() {
                Some(ob) => {
                    let mut i_in_block = 0;
                    let mut i_right_broadcast = 0;
                    rhs[o_r1..o_r2]
                        .iter()
                        .map(|&r| {
                            // SAFETY: relies on the layout's broadcast descriptor keeping
                            // ob.start + i_in_block in bounds for lhs.
                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
                            i_right_broadcast += 1;
                            if i_right_broadcast >= ob.right_broadcast {
                                i_in_block += 1;
                                i_right_broadcast = 0;
                            }
                            if i_in_block >= ob.len {
                                i_in_block = 0
                            }
                            f(*l, r)
                        })
                        .collect()
                }
                None => lhs_l
                    .strided_index()
                    .zip(rhs_l.strided_index())
                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                    .collect(),
            }
        }
        // Neither side contiguous: generic strided walk.
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}
530
// Similar to binary_map but with vectorized variants.
/// `f_vec` fills a whole output slice from two matching input slices and is used whenever
/// both operands can be presented as contiguous chunks; `f` is the scalar fallback.
pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
    mut f_vec: FV,
) -> Vec<T> {
    let el_count = lhs_l.shape().elem_count();
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        // Both contiguous: a single f_vec call over the full range.
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
            let mut ys: Vec<T> = Vec::with_capacity(el_count);
            let ys_to_set = ys.spare_capacity_mut();
            // NOTE(review): uninitialized spare capacity viewed as `&mut [T]`; relies on
            // f_vec writing every element before set_len (T: Copy, so nothing is dropped).
            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
            // SAFETY: values are all set by f_vec.
            unsafe { ys.set_len(el_count) };
            ys
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            // rhs is an `ob.len`-element block repeated with no per-element repetition:
            // apply f_vec chunk by chunk of the lhs against the same rhs block.
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_l1..o_l2).step_by(ob.len) {
                    f_vec(
                        &lhs[src_i..src_i + ob.len],
                        rhs,
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            // General broadcast: start from a copy of lhs and fold each repeated rhs element
            // into its `right_broadcast`-wide run, scalar-wise.
            Some(ob) => {
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &r) in rhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(*v, r)
                        }
                    }
                }
                ys
            }
            // No broadcast structure: generic strided walk with the scalar fallback.
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        // Mirror of the arm above: lhs is the broadcast side, rhs is contiguous.
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
                let mut dst_i = 0;
                for src_i in (o_r1..o_r2).step_by(ob.len) {
                    f_vec(
                        lhs,
                        &rhs[src_i..src_i + ob.len],
                        &mut ys_to_set[dst_i..dst_i + ob.len],
                    );
                    dst_i += ob.len;
                }
                // SAFETY: values are all set by f_vec.
                unsafe { ys.set_len(el_count) };
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
                    for (i, &l) in lhs.iter().enumerate() {
                        let start = start + i * ob.right_broadcast;
                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
                            *v = f(l, *v)
                        }
                    }
                }
                ys
            }
            None => lhs_l
                .strided_index()
                .zip(rhs_l.strided_index())
                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
                .collect(),
        },
        // Neither side contiguous: generic strided walk.
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}
636
637struct Affine(f64, f64);
638
639impl Map1 for Affine {
640    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
641        let mul = T::from_f64(self.0);
642        let add = T::from_f64(self.1);
643        Ok(unary_map(vs, layout, |v| v * mul + add))
644    }
645}
646
647struct AvgPool2D((usize, usize), (usize, usize));
648
649impl Map1 for AvgPool2D {
650    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
651        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
652        let (k_h, k_w) = self.0;
653        let (s_h, s_w) = self.1;
654        let (b_sz, c, h, w) = layout.shape().dims4()?;
655        let stride = layout.stride();
656        let (stride_h, stride_w) = (stride[2], stride[3]);
657        let h_out = (h - k_h) / s_h + 1;
658        let w_out = (w - k_w) / s_w + 1;
659        let src_index = layout.start_offset();
660        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
661        let scale = 1f64 / (k_h * k_w) as f64;
662        let scale = T::from_f64(scale);
663        for b_idx in 0..b_sz {
664            let dst = &mut dst[b_idx * c * h_out * w_out..];
665            let src_index = src_index + b_idx * stride[0];
666            for c_idx in 0..c {
667                let dst = &mut dst[c_idx * h_out * w_out..];
668                let src_index = src_index + c_idx * stride[1];
669                for h_idx in 0..h_out {
670                    for w_idx in 0..w_out {
671                        let mut sum = T::zero();
672                        for m in 0..k_h {
673                            for n in 0..k_w {
674                                let m = s_h * h_idx + m;
675                                let n = s_w * w_idx + n;
676                                sum += src[src_index + m * stride_h + n * stride_w]
677                            }
678                        }
679                        dst[h_idx * w_out + w_idx] = sum * scale;
680                    }
681                }
682            }
683        }
684        Ok(dst)
685    }
686}
687
688struct MaxPool2D((usize, usize), (usize, usize));
689
690impl Map1 for MaxPool2D {
691    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
692        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
693        let (k_h, k_w) = self.0;
694        let (s_h, s_w) = self.1;
695        let (b_sz, c, h, w) = layout.shape().dims4()?;
696        let stride = layout.stride();
697        let (stride_h, stride_w) = (stride[2], stride[3]);
698        let h_out = (h - k_h) / s_h + 1;
699        let w_out = (w - k_w) / s_w + 1;
700        let src_index = layout.start_offset();
701        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
702        for b_idx in 0..b_sz {
703            let dst = &mut dst[b_idx * c * h_out * w_out..];
704            let src_index = src_index + b_idx * stride[0];
705            for c_idx in 0..c {
706                let dst = &mut dst[c_idx * h_out * w_out..];
707                let src_index = src_index + c_idx * stride[1];
708                for h_idx in 0..h_out {
709                    for w_idx in 0..w_out {
710                        let mut largest =
711                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
712                        for m in 0..k_h {
713                            for n in 0..k_w {
714                                let m = s_h * h_idx + m;
715                                let n = s_w * w_idx + n;
716                                if largest < src[src_index + m * stride_h + n * stride_w] {
717                                    largest = src[src_index + m * stride_h + n * stride_w]
718                                }
719                            }
720                        }
721                        dst[h_idx * w_out + w_idx] = largest;
722                    }
723                }
724            }
725        }
726        Ok(dst)
727    }
728}
729
730struct UpsampleNearest1D(usize);
731
732impl Map1 for UpsampleNearest1D {
733    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
734        // TODO: Specialized implementation for the case 2*sz?
735        let dst_sz = self.0;
736        let (b_sz, c, src_sz) = layout.shape().dims3()?;
737        let stride = layout.stride();
738        let stride_sz = stride[2];
739        let src_index = layout.start_offset();
740        let scale_sz = src_sz as f64 / dst_sz as f64;
741        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
742        let src_idxs = (0..dst_sz)
743            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
744            .collect::<Vec<_>>();
745        for b_idx in 0..b_sz {
746            let dst = &mut dst[b_idx * c * dst_sz..];
747            let src_index = src_index + b_idx * stride[0];
748            for c_idx in 0..c {
749                let dst = &mut dst[c_idx * dst_sz..];
750                let src_index = src_index + c_idx * stride[1];
751                for (idx, src_idx) in src_idxs.iter().enumerate() {
752                    dst[idx] = src[src_index + src_idx * stride_sz]
753                }
754            }
755        }
756        Ok(dst)
757    }
758}
759
760struct UpsampleNearest2D(usize, usize);
761
762impl Map1 for UpsampleNearest2D {
763    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
764        // TODO: Specialized implementation for the case 2*h, 2*w?
765        let (dst_h, dst_w) = (self.0, self.1);
766        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
767        let stride = layout.stride();
768        let (stride_h, stride_w) = (stride[2], stride[3]);
769        let src_index = layout.start_offset();
770        let scale_h = src_h as f64 / dst_h as f64;
771        let scale_w = src_w as f64 / dst_w as f64;
772        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
773        let src_h_idxs = (0..dst_h)
774            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
775            .collect::<Vec<_>>();
776        let src_w_idxs = (0..dst_w)
777            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
778            .collect::<Vec<_>>();
779        for b_idx in 0..b_sz {
780            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
781            let src_index = src_index + b_idx * stride[0];
782            for c_idx in 0..c {
783                let dst = &mut dst[c_idx * dst_h * dst_w..];
784                let src_index = src_index + c_idx * stride[1];
785                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
786                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
787                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
788                        dst[h_idx * dst_w + w_idx] = src[src_index]
789                    }
790                }
791            }
792        }
793        Ok(dst)
794    }
795}
796
/// Arguments for a gather along `dim`: `ids` (with layout `ids_l`) holds the
/// indices used to pick elements from the source tensor.
struct Gather<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}
802
impl<'a, I: IntDType> Map1 for Gather<'a, I> {
    /// Gather values from `src` along `self.dim`: the output has the shape of
    /// `ids` and takes `src[..., ids[..., i, ...], ...]` at each position.
    /// Both `ids` and `src` must be contiguous.
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let src_dims = src_l.dims();
        // The output has the shape of `ids`; decompose it as
        // [left, dim, right] around the gathered dimension.
        let dst_len: usize = ids_dims.iter().product();
        let dst_left_len: usize = ids_dims[..dim].iter().product();
        let dst_dim_len = ids_dims[dim];
        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();

        let src_dim_len = src_dims[dim];
        let src_right_len: usize = src_dims[dim + 1..].iter().product();

        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..dst_left_len {
            let start_src_idx = left_i * src_right_len * src_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..dst_dim_len {
                let start_dst_idx = start_dst_idx + i * dst_right_len;
                for right_i in 0..dst_right_len {
                    let dst_idx = start_dst_idx + right_i;
                    // The index along `dim` comes from the ids tensor; reject
                    // out-of-range values instead of indexing past the end.
                    let index = ids[dst_idx].as_usize();
                    if index >= src_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim_len,
                            op: "gather",
                        }
                        .bt())?
                    }
                    let src_idx = start_src_idx + index * src_right_len + right_i;
                    dst[dst_idx] = src[src_idx]
                }
            }
        }
        Ok(dst)
    }
}
849
850struct IndexSelect<'a, T: IntDType> {
851    ids: &'a [T],
852    ids_l: &'a Layout,
853    dim: usize,
854}
855
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
    /// Select along `self.dim` the slices of `src` listed in the 1d `ids`
    /// tensor; the output keeps the input shape except that dimension's size
    /// becomes the number of ids. `src` must be contiguous.
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
        };
        let dim = self.dim;
        // `ids` must be a vector (rank 1).
        let n_ids = match self.ids_l.dims() {
            [n_ids] => *n_ids,
            d => Err(Error::UnexpectedNumberOfDims {
                expected: 1,
                got: d.len(),
                shape: self.ids_l.shape().clone(),
            }
            .bt())?,
        };
        let stride_ids = self.ids_l.stride()[0];
        // Output shape = input shape with `dim` replaced by `n_ids`,
        // decomposed as [left, dim, right].
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                // The ids tensor may be strided, hence the explicit offset math.
                let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
                if index >= src_dim {
                    Err(Error::InvalidIndex {
                        index,
                        size: src_dim,
                        op: "index-select",
                    }
                    .bt())?
                }
                // Copy one contiguous run of `right_len` elements at a time.
                let start_src_idx = start_src_idx + index * right_len;
                let start_dst_idx = start_dst_idx + i * right_len;
                dst[start_dst_idx..start_dst_idx + right_len]
                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
            }
        }
        Ok(dst)
    }
}
902
/// Arguments for a scatter-add along `dim`: `ids` (with layout `ids_l`) maps
/// each source element to the destination position it is accumulated into.
struct ScatterAdd<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}
908
909impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
910    const OP: &'static str = "scatter-add";
911    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
912        let dst_len = l1.shape().elem_count();
913        let mut dst = vec![T::zero(); dst_len];
914        copy_strided_src_(v1, &mut dst, 0, l1);
915        let src = match src_l.contiguous_offsets() {
916            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
917            Some((o1, o2)) => &src[o1..o2],
918        };
919
920        let dim = self.dim;
921        let ids_dims = self.ids_l.dims();
922        let dst_dims = l1.dims();
923        let dst_dim_len = dst_dims[dim];
924        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
925
926        let ids_left_len: usize = ids_dims[..dim].iter().product();
927        let ids_dim_len = ids_dims[dim];
928        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
929
930        let ids = match self.ids_l.contiguous_offsets() {
931            Some((a, b)) => &self.ids[a..b],
932            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
933        };
934        for left_i in 0..ids_left_len {
935            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
936            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
937            for i in 0..ids_dim_len {
938                let start_ids_idx = start_ids_idx + i * ids_right_len;
939                for right_i in 0..dst_right_len {
940                    let ids_idx = start_ids_idx + right_i;
941                    let index = ids[ids_idx].as_usize();
942                    if index >= dst_dim_len {
943                        Err(Error::InvalidIndex {
944                            index,
945                            size: dst_dim_len,
946                            op: "gather",
947                        }
948                        .bt())?
949                    }
950                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
951                    dst[dst_idx] += src[ids_idx]
952                }
953            }
954        }
955
956        Ok(dst)
957    }
958}
959
/// Arguments for an index-add along `dim`. No layout is carried for `ids`:
/// the slice is read in order, one index per slice of the source tensor
/// (presumably already made contiguous by the caller — TODO confirm).
struct IndexAdd<'a, I: IntDType> {
    ids: &'a [I],
    dim: usize,
}
964
965impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
966    const OP: &'static str = "index-add";
967    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
968    // v1, l1 -> self
969    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
970        let dst_len = l1.shape().elem_count();
971        let mut dst = vec![T::zero(); dst_len];
972        copy_strided_src_(v1, &mut dst, 0, l1);
973        let src = match src_l.contiguous_offsets() {
974            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
975            Some((o1, o2)) => &src[o1..o2],
976        };
977        let dim = self.dim;
978        let max_idx = l1.dims()[dim];
979        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
980        let src_dim_sz = src_l.dims()[dim];
981        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
982        if dim == 0 {
983            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
984                let dst_idx = dst_idx.as_usize();
985                if dst_idx >= max_idx {
986                    Err(Error::InvalidIndex {
987                        index: dst_idx,
988                        op: "index-add",
989                        size: max_idx,
990                    })?
991                }
992                let src_idx = src_idx * post_dim;
993                let dst_idx = dst_idx * post_dim;
994                let src = &src[src_idx..src_idx + post_dim];
995                let dst = &mut dst[dst_idx..dst_idx + post_dim];
996                for (d, &s) in dst.iter_mut().zip(src.iter()) {
997                    *d += s
998                }
999            }
1000        } else {
1001            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
1002                let dst_idx = dst_idx.as_usize();
1003                if dst_idx >= max_idx {
1004                    Err(Error::InvalidIndex {
1005                        index: dst_idx,
1006                        op: "index-add",
1007                        size: max_idx,
1008                    })?
1009                }
1010                for pre_i in 0..pre_dim {
1011                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
1012                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
1013                    let src = &src[pre_src_i..pre_src_i + post_dim];
1014                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
1015                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
1016                        *d += s
1017                    }
1018                }
1019            }
1020        }
1021        Ok(dst)
1022    }
1023}
1024
/// Copy the elements of `src`, traversed according to `src_l` (start offset
/// plus strides), into `dst` starting at `dst_offset`, writing them in
/// row-major order. Copying stops early once `dst` is full.
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        // Fully contiguous source: a single slice copy.
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let to_copy = (dst.len() - dst_offset).min(len);
            dst[dst_offset..dst_offset + to_copy]
                .copy_from_slice(&src[start_offset..start_offset + to_copy])
        }
        // Scattered single elements: copy them one by one.
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len: 1,
        } => {
            for (dst_index, src_index) in block_start_index.enumerate() {
                let dst_index = dst_index + dst_offset;
                if dst_index >= dst.len() {
                    break;
                }
                dst[dst_index] = src[src_index]
            }
        }
        // Runs of `block_len` contiguous elements: copy block by block.
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut dst_index = dst_offset;
            for src_index in block_start_index {
                let next_dst_index = dst_index + block_len;
                if dst_index >= dst.len() {
                    break;
                }
                // The final block may be truncated if `dst` runs out of room.
                let to_copy = usize::min(block_len, dst.len() - dst_index);
                dst[dst_index..dst_index + to_copy]
                    .copy_from_slice(&src[src_index..src_index + to_copy]);
                dst_index = next_dst_index
            }
        }
    }
}
1062
/// Direct 1d convolution, parameterized by [crate::conv::ParamsConv1D].
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
1064
impl<'a> Map2 for Conv1D<'a> {
    const OP: &'static str = "conv1d";
    /// Direct 1d convolution. `inp` is indexed as [b_size, c_in, l_in] and
    /// `k` as [c_out, c_in, k_size]; the output is [b_size, c_out, l_out].
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
        let dst_elems = p.c_out * l_out * p.b_size;
        // The output shape is [b_size, c_out, l_out]
        let dst = vec![T::zero(); dst_elems];

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        // Repack the input into a contiguous [b_size, l_in, c_in] buffer so
        // the inner dot product below runs over adjacent memory.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        for b_idx in 0..p.b_size {
            for src_l in 0..p.l_in {
                for src_c_idx in 0..p.c_in {
                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
                }
            }
        }

        // Accumulate one kernel tap at a time, parallelizing over output
        // channels.
        for offset in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let dst_idx = dst_c_idx * l_out;
                // Kernel values for this (output channel, tap), contiguous in c_in.
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                    for dst_l in 0..l_out {
                        let dst_idx = dst_idx + dst_l;
                        let src_l = p.stride * dst_l + offset * p.dilation;
                        // Positions that fall in the zero padding contribute
                        // nothing to the sum.
                        if src_l < p.padding || src_l >= p.padding + p.l_in {
                            continue;
                        }
                        let src_l = src_l - p.padding;
                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
                        assert!(inp_cont.len() >= p.c_in);
                        assert!(k_cont.len() >= p.c_in);
                        let mut d = T::zero();
                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
                        let dst_p = dst.as_ptr();
                        // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
                        // the different tasks so no two threads can try to write at the same
                        // location.
                        unsafe {
                            let ptr = dst_p.add(dst_idx) as *mut T;
                            *ptr += d
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}
1124
/// Im2col parameters for 1d convolutions: kernel length `l_k` plus the usual
/// stride / dilation / padding.
struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col1D {
    /// Number of output positions for an input of length `l` — the standard
    /// convolution output-size formula.
    fn l_out(&self, l: usize) -> usize {
        // Extent actually covered by the dilated kernel.
        let effective_k = self.dilation * (self.l_k - 1) + 1;
        (l + 2 * self.padding - effective_k) / self.stride + 1
    }
}
1137
impl Map1 for Im2Col1D {
    /// Unfold a [b, c, l] input into a [b, l_out, c * l_k] buffer so that the
    /// convolution can then be computed as a single matmul. Entries whose
    /// source position falls in the zero padding keep their `T::zero()` init.
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            l_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, l) = layout.shape().dims3()?;
        let l_out = self.l_out(l);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * l_out * c * l_k];
        let (src_s0, src_s1, src_s2) = {
            let s = layout.stride();
            (s[0], s[1], s[2])
        };
        // TODO: provide specialized kernels for the common use cases.
        // - l_k = 1
        // - padding = 0
        // - stride = 1
        // - dilation = 1
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * l_out * c * l_k;
            for l_idx in 0..l_out {
                let dst_idx = dst_idx + l_idx * c * l_k;
                for c_idx in 0..c {
                    let dst_idx = dst_idx + c_idx * l_k;
                    let src_idx = c_idx * src_s1 + src_idx;
                    for l_k_idx in 0..l_k {
                        let src_l = l_idx * stride + l_k_idx * dilation;
                        // In the padding region, leave the zero initialization.
                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
                            continue;
                        }
                        let src_l = src_l - padding;
                        let src_idx = src_idx + src_l * src_s2;
                        let dst_idx = dst_idx + l_k_idx;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}
1183
/// Im2col parameters for 2d convolutions: kernel size `h_k` x `w_k` plus the
/// usual stride / dilation / padding (shared by both spatial axes).
struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    /// Output (height, width) for an input of size (h, w) — the standard
    /// convolution output-size formula applied per axis.
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        // Extents actually covered by the dilated kernel on each axis.
        let eff_h = self.dilation * (self.h_k - 1) + 1;
        let eff_w = self.dilation * (self.w_k - 1) + 1;
        let h_out = (h + 2 * self.padding - eff_h) / self.stride + 1;
        let w_out = (w + 2 * self.padding - eff_w) / self.stride + 1;
        (h_out, w_out)
    }
}
1199
impl Map1 for Im2Col {
    /// Unfold a [b, c, h, w] input into a [b, h_out * w_out, c * h_k * w_k]
    /// buffer so that the convolution can then be computed as a single matmul.
    /// Entries whose source position falls in the zero padding keep their
    /// `T::zero()` initialization.
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            h_k,
            w_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, h, w) = layout.shape().dims4()?;
        let (h_out, w_out) = self.hw_out(h, w);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
        let (src_s0, src_s1, src_s2, src_s3) = {
            let s = layout.stride();
            (s[0], s[1], s[2], s[3])
        };
        // TODO: provide specialized kernels for the common use cases.
        // - h_k = w_k = 1
        // - padding = 0
        // - stride = 1
        // - dilation = 1
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
            for h_idx in 0..h_out {
                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
                for w_idx in 0..w_out {
                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                    for c_idx in 0..c {
                        let dst_idx = dst_idx + c_idx * h_k * w_k;
                        let src_idx = c_idx * src_s1 + src_idx;
                        for h_k_idx in 0..h_k {
                            let src_h = h_idx * stride + h_k_idx * dilation;
                            // Rows in the padding region keep the zero init.
                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
                                continue;
                            }
                            let src_h = src_h - padding;
                            let src_idx = src_idx + src_h * src_s2;
                            let dst_idx = dst_idx + h_k_idx * w_k;
                            for w_k_idx in 0..w_k {
                                let src_w = w_idx * stride + w_k_idx * dilation;
                                // Same for columns.
                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
                                    continue;
                                }
                                let src_w = src_w - padding;
                                let src_idx = src_idx + src_w * src_s3;
                                let dst_idx = dst_idx + w_k_idx;
                                dst[dst_idx] = src[src_idx]
                            }
                        }
                    }
                }
            }
        }
        Ok(dst)
    }
}
1258
/// Direct 2d convolution, parameterized by [crate::conv::ParamsConv2D].
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
1260
impl<'a> Map2 for Conv2D<'a> {
    const OP: &'static str = "conv2d";
    /// Direct 2d convolution. `inp` is indexed as [b_size, c_in, i_h, i_w]
    /// and `k` as [c_out, c_in, k_h, k_w]; the output is
    /// [b_size, c_out, out_h, out_w].
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        // Output shape: [b_size, c_out, out_h, out_w].
        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        // Repack the input into a contiguous [b_size, i_h, i_w, c_in] buffer
        // so the inner dot product below runs over adjacent memory.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        // Accumulate one kernel position (offset_h, offset_w) at a time,
        // parallelizing over output channels.
        for offset_h in 0..p.k_h {
            for offset_w in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let dst_idx = dst_c_idx * out_w * out_h;
                    // Kernel values for this (output channel, position),
                    // contiguous in c_in.
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[dst_c_idx * k_s0
                                + c_in_idx * k_s1
                                + offset_h * k_s2
                                + offset_w * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
                        for dst_h in 0..out_h {
                            let dst_idx = dst_idx + dst_h * out_w;
                            let src_h = p.stride * dst_h + offset_h * p.dilation;
                            // Rows falling in the zero padding contribute nothing.
                            if src_h < p.padding || src_h >= p.i_h + p.padding {
                                continue;
                            }
                            let src_h = src_h - p.padding;
                            for dst_w in 0..out_w {
                                let dst_idx = dst_idx + dst_w;
                                let src_w = p.stride * dst_w + offset_w * p.dilation;
                                // Same for columns in the padding region.
                                if src_w < p.padding || src_w >= p.i_w + p.padding {
                                    continue;
                                }
                                let src_w = src_w - p.padding;
                                let inp_cont = &inp_cont
                                    [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
                                assert!(inp_cont.len() >= p.c_in);
                                assert!(k_cont.len() >= p.c_in);
                                let mut d = T::zero();
                                unsafe {
                                    T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                                }
                                let dst_p = dst.as_ptr();
                                // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
                                // the different tasks so no two threads can try to write at the same
                                // location.
                                unsafe {
                                    let ptr = dst_p.add(dst_idx) as *mut T;
                                    *ptr += d
                                }
                            }
                        }
                    }
                });
            }
        }

        Ok(dst)
    }
}
1346
/// Transposed 2d convolution, parameterized by
/// [crate::conv::ParamsConvTranspose2D].
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
1348
impl<'a> Map2 for ConvTranspose2D<'a> {
    const OP: &'static str = "conv_transpose2d";
    /// Transposed 2d convolution. `inp` is indexed as [b_size, c_in, i_h, i_w]
    /// and `k` as [c_in, c_out, k_h, k_w]; the output is
    /// [b_size, c_out, out_h, out_w].
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        // Output shape: [b_size, c_out, out_h, out_w].
        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
        let dst_s0 = p.c_out * out_h * out_w;
        let dst_s1 = out_h * out_w;
        let dst_s2 = out_w;
        let dst_s3 = 1;

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        // Repack the input into a contiguous [b_size, i_h, i_w, c_in] buffer
        // so the inner dot product below runs over adjacent memory.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        // Scatter the contribution of each kernel position (k_y, k_x),
        // parallelizing over output channels.
        for k_y in 0..p.k_h {
            for k_x in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    // Kernel values for this (output channel, position),
                    // contiguous in c_in.
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        for inp_y in 0..p.i_h {
                            for inp_x in 0..p.i_w {
                                // Output position this input pixel contributes
                                // to for the current kernel offset; skipped
                                // when it lands in the cropped padding area.
                                let out_x = inp_x * p.stride + k_x * p.dilation;
                                let out_y = inp_y * p.stride + k_y * p.dilation;
                                if out_x < p.padding || out_y < p.padding {
                                    continue;
                                }
                                let out_x = out_x - p.padding;
                                let out_y = out_y - p.padding;
                                if out_x < out_w && out_y < out_h {
                                    let inp_cont = &inp_cont
                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
                                    let dst_idx = b_idx * dst_s0
                                        + out_y * dst_s2
                                        + out_x * dst_s3
                                        + dst_c_idx * dst_s1;
                                    let mut d = T::zero();
                                    unsafe {
                                        T::vec_dot(
                                            inp_cont.as_ptr(),
                                            k_cont.as_ptr(),
                                            &mut d,
                                            p.c_in,
                                        )
                                    }
                                    let dst_p = dst.as_ptr();
                                    // Safety: dst_idx are uniques per dst_c_idx which is used to
                                    // parallelise the different tasks so no two threads can try to
                                    // write at the same location.
                                    unsafe {
                                        let ptr = dst_p.add(dst_idx) as *mut T;
                                        *ptr += d
                                    }
                                }
                            }
                        }
                    }
                })
            }
        }
        Ok(dst)
    }
}
1436
/// Batched matrix multiplication; the tuple holds the (b, m, n, k) sizes.
struct MatMul((usize, usize, usize, usize));
1438
1439impl MatMul {
1440    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
1441        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
1442            lhs_l: lhs_l.clone(),
1443            rhs_l: rhs_l.clone(),
1444            bmnk: self.0,
1445            msg,
1446        }))
1447        .bt()
1448    }
1449}
1450
impl Map2 for MatMul {
    const OP: &'static str = "mat_mul";

    /// Batched matmul on the portable CPU path via the `gemm` crate (used when neither
    /// the `mkl` nor the `accelerate` feature is enabled).
    ///
    /// `self.0` holds `(b, m, n, k)`; only F16/F32/F64 element types are accepted.
    /// The batch dimensions must collapse to a single per-step stride (or be absent),
    /// otherwise a `MatMulUnexpectedStriding` error is returned.
    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        use gemm::{gemm, Parallelism};

        match T::DTYPE {
            DType::F16 | DType::F32 | DType::F64 => {}
            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
        }

        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        // Row/column strides of the innermost (matrix) dimensions; transposed inputs
        // are handled directly by passing these strides to `gemm`.
        let lhs_cs = lhs_stride[rank - 1];
        let lhs_rs = lhs_stride[rank - 2];

        let rhs_cs = rhs_stride[rank - 1];
        let rhs_rs = rhs_stride[rank - 2];

        // Per-batch-step element offsets. Accepted shapes: a 4d layout whose two batch
        // strides collapse into one, a 3d layout with a single batch stride, or no
        // batch dimensions at all (then one full matrix per step).
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        // The output is always freshly allocated and contiguous, (m, n) per step.
        let dst_shape: Shape = (m, n).into();
        let dst_strides = dst_shape.stride_contiguous();
        let dst_rs = dst_strides[0];
        let dst_cs = dst_strides[1];

        let mut dst = vec![T::zero(); b * m * n];
        let num_threads = crate::utils::get_num_threads();
        let parallelism = if num_threads > 1 {
            Parallelism::Rayon(num_threads)
        } else {
            Parallelism::None
        };
        for step in 0..b {
            let lhs_p = &lhs[step * a_skip..];
            let rhs_p = &rhs[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            // SAFETY: the pointers plus the row/column strides above describe
            // in-bounds m*k, k*n and m*n views of the slices; `read_dst = false`
            // with `alpha = 0`, `beta = 1` overwrites dst with `lhs * rhs`.
            unsafe {
                gemm(
                    /* m: usize = */ m,
                    /* n: usize = */ n,
                    /* k: usize = */ k,
                    /* dst: *mut T = */ dst_p.as_mut_ptr(),
                    /* dst_cs: isize = */ dst_cs as isize,
                    /* dst_rs: isize = */ dst_rs as isize,
                    /* read_dst: bool = */ false,
                    /* lhs: *const T = */ lhs_p.as_ptr(),
                    /* lhs_cs: isize = */ lhs_cs as isize,
                    /* lhs_rs: isize = */ lhs_rs as isize,
                    /* rhs: *const T = */ rhs_p.as_ptr(),
                    /* rhs_cs: isize = */ rhs_cs as isize,
                    /* rhs_rs: isize = */ rhs_rs as isize,
                    /* alpha: T = */ T::zero(),
                    /* beta: T = */ T::one(),
                    /* conj_dst: bool = */ false,
                    /* conj_lhs: bool = */ false,
                    /* conj_rhs: bool = */ false,
                    parallelism,
                )
            }
        }
        Ok(dst)
    }

    /// Batched matmul backed by Apple Accelerate's `sgemm`/`dgemm`.
    ///
    /// The BLAS routines are column-major, so the call passes the operands swapped
    /// (rhs as `a`, lhs as `b`) with `m` and `n` exchanged, which yields the
    /// row-major result directly. Only F32/F64 are supported here; f16 bails.
    #[cfg(feature = "accelerate")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();

        // Per-batch-step element offsets, same contract as the gemm-crate path.
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        // Innermost strides, used to decide whether each operand is plain or
        // transposed (the only two layouts BLAS can consume without a copy).
        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        // The b tensor has dims batching, m, k (lhs)
        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                crate::bail!("the accelerate backend does not support f16 matmul")
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY: T::DTYPE is F32 so the element type is f32; the slice
                    // lengths are the per-step skips computed above.
                    // NOTE(review): `a_skip`/`b_skip` are batch strides, assumed to be
                    // at least the matrix sizes read by sgemm — TODO confirm.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::sgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY: same as the F32 arm, with f64 elements.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::accelerate::dgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }

    /// Batched matmul backed by Intel MKL's `hgemm`/`sgemm`/`dgemm`.
    ///
    /// Same column-major operand-swapping scheme as the accelerate path; unlike
    /// accelerate this one also supports f16 via `hgemm`.
    #[cfg(feature = "mkl")]
    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];

        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();

        // Per-batch-step element offsets, same contract as the gemm-crate path.
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
        };
        let c_skip: usize = m * n;

        // Innermost strides decide plain vs transposed BLAS operand layout.
        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

        let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
            (n as i32, b'N')
        } else if rhs_m1 == k && rhs_m2 == 1 {
            (k as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
        };
        // The b tensor has dims batching, m, k (lhs)
        let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
            (k as i32, b'N')
        } else if lhs_m1 == m && lhs_m2 == 1 {
            (m as i32, b'T')
        } else {
            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
        };

        let mut dst = vec![T::zero(); b * m * n];
        match T::DTYPE {
            DType::F16 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY: T::DTYPE is F16 so the element type is f16; slice
                    // lengths are the per-step skips computed above.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f16;
                        let b = lhs_p.as_ptr() as *const f16;
                        let c = dst_p.as_mut_ptr() as *mut f16;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::hgemm(
                            transa,
                            transb,
                            /* m= */ n as i32,
                            /* n= */ m as i32,
                            /* k= */ k as i32,
                            /* alpha= */ f16::ONE,
                            /* a= */ a,
                            /* lda= */ lda,
                            /* b= */ b,
                            /* ldb= */ ldb,
                            /* beta= */ f16::ZERO,
                            /* c= */ c,
                            /* ldc= */ n as i32,
                        )
                    }
                }
            }
            DType::F32 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY: same as the F16 arm, with f32 elements.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f32;
                        let b = lhs_p.as_ptr() as *const f32;
                        let c = dst_p.as_mut_ptr() as *mut f32;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::sgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            DType::F64 => {
                for step in 0..b {
                    let lhs_p = &lhs[step * a_skip..];
                    let rhs_p = &rhs[step * b_skip..];
                    let dst_p = &mut dst[step * c_skip..];
                    // SAFETY: same as the F16 arm, with f64 elements.
                    unsafe {
                        let a = rhs_p.as_ptr() as *const f64;
                        let b = lhs_p.as_ptr() as *const f64;
                        let c = dst_p.as_mut_ptr() as *mut f64;
                        let a = std::slice::from_raw_parts(a, a_skip);
                        let b = std::slice::from_raw_parts(b, b_skip);
                        let c = std::slice::from_raw_parts_mut(c, c_skip);
                        crate::mkl::dgemm(
                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                        )
                    }
                }
            }
            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
        }
        Ok(dst)
    }
}
1772
1773fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
1774    if v.is_sign_positive() {
1775        v
1776    } else {
1777        (v.exp() - T::one()) * alpha
1778    }
1779}
1780
1781impl CpuStorage {
1782    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
1783        D::cpu_storage_as_slice(self)
1784    }
1785
1786    pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
1787        let storage0 = &storages[0];
1788        let s = match storage0 {
1789            Self::U8(_) => {
1790                let storages = storages
1791                    .iter()
1792                    .map(|s| match s {
1793                        Self::U8(s) => Ok(s.as_slice()),
1794                        _ => crate::bail!("dtype mismatch"),
1795                    })
1796                    .collect::<Result<Vec<_>>>()?
1797                    .concat();
1798                Self::U8(storages)
1799            }
1800            Self::U32(_) => {
1801                let storages = storages
1802                    .iter()
1803                    .map(|s| match s {
1804                        Self::U32(s) => Ok(s.as_slice()),
1805                        _ => crate::bail!("dtype mismatch"),
1806                    })
1807                    .collect::<Result<Vec<_>>>()?
1808                    .concat();
1809                Self::U32(storages)
1810            }
1811            Self::I64(_) => {
1812                let storages = storages
1813                    .iter()
1814                    .map(|s| match s {
1815                        Self::I64(s) => Ok(s.as_slice()),
1816                        _ => crate::bail!("dtype mismatch"),
1817                    })
1818                    .collect::<Result<Vec<_>>>()?
1819                    .concat();
1820                Self::I64(storages)
1821            }
1822            Self::BF16(_) => {
1823                let storages = storages
1824                    .iter()
1825                    .map(|s| match s {
1826                        Self::BF16(s) => Ok(s.as_slice()),
1827                        _ => crate::bail!("dtype mismatch"),
1828                    })
1829                    .collect::<Result<Vec<_>>>()?
1830                    .concat();
1831                Self::BF16(storages)
1832            }
1833            Self::F16(_) => {
1834                let storages = storages
1835                    .iter()
1836                    .map(|s| match s {
1837                        Self::F16(s) => Ok(s.as_slice()),
1838                        _ => crate::bail!("dtype mismatch"),
1839                    })
1840                    .collect::<Result<Vec<_>>>()?
1841                    .concat();
1842                Self::F16(storages)
1843            }
1844            Self::F32(_) => {
1845                let storages = storages
1846                    .iter()
1847                    .map(|s| match s {
1848                        Self::F32(s) => Ok(s.as_slice()),
1849                        _ => crate::bail!("dtype mismatch"),
1850                    })
1851                    .collect::<Result<Vec<_>>>()?
1852                    .concat();
1853                Self::F32(storages)
1854            }
1855            Self::F64(_) => {
1856                let storages = storages
1857                    .iter()
1858                    .map(|s| match s {
1859                        Self::F64(s) => Ok(s.as_slice()),
1860                        _ => crate::bail!("dtype mismatch"),
1861                    })
1862                    .collect::<Result<Vec<_>>>()?
1863                    .concat();
1864                Self::F64(storages)
1865            }
1866        };
1867        Ok(s)
1868    }
1869}
1870
1871impl BackendStorage for CpuStorage {
1872    type Device = CpuDevice;
1873
    /// Return the `DType` tag matching the active storage variant.
    fn dtype(&self) -> DType {
        match self {
            Self::U8(_) => DType::U8,
            Self::U32(_) => DType::U32,
            Self::I64(_) => DType::I64,
            Self::BF16(_) => DType::BF16,
            Self::F16(_) => DType::F16,
            Self::F32(_) => DType::F32,
            Self::F64(_) => DType::F64,
        }
    }
1885
    /// Convert the storage to `dtype`, materializing a new contiguous buffer.
    ///
    /// Every (source, target) dtype pair is spelled out explicitly; the arms are
    /// grouped by target dtype. Float-to-integer `as` casts follow Rust semantics
    /// (saturating, NaN maps to 0). A same-dtype conversion still copies via the
    /// identity closure.
    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
        // TODO: find a way around the quadratic number of cases below.
        match (self, dtype) {
            // Conversions to BF16. Integer sources go through f32 first.
            (Self::U8(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::U32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::I64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                Ok(Self::BF16(data))
            }
            (Self::BF16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::BF16(data))
            }
            (Self::F16(storage), DType::BF16) => {
                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                Ok(Self::BF16(data))
            }
            (Self::F32(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f32);
                Ok(Self::BF16(data))
            }
            (Self::F64(storage), DType::BF16) => {
                let data = unary_map(storage, layout, bf16::from_f64);
                Ok(Self::BF16(data))
            }
            // Conversions to F16.
            (Self::U8(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::U32(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::I64(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                Ok(Self::F16(data))
            }
            (Self::BF16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                Ok(Self::F16(data))
            }
            (Self::F16(storage), DType::F16) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F16(data))
            }
            (Self::F32(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f32);
                Ok(Self::F16(data))
            }
            (Self::F64(storage), DType::F16) => {
                let data = unary_map(storage, layout, f16::from_f64);
                Ok(Self::F16(data))
            }
            // Conversions to F32.
            (Self::U8(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::U32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::I64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            (Self::BF16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F16(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v.to_f32());
                Ok(Self::F32(data))
            }
            (Self::F32(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F32(data))
            }
            (Self::F64(storage), DType::F32) => {
                let data = unary_map(storage, layout, |v| v as f32);
                Ok(Self::F32(data))
            }
            // Conversions to U8. Half floats are widened to f32 before the cast.
            (Self::U8(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U8(data))
            }
            (Self::BF16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F16(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
                Ok(Self::U8(data))
            }
            (Self::F32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::F64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::U32(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            (Self::I64(storage), DType::U8) => {
                let data = unary_map(storage, layout, |v| v as u8);
                Ok(Self::U8(data))
            }
            // Conversions to U32.
            (Self::U8(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::U32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::U32(data))
            }
            (Self::I64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::BF16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F16(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                Ok(Self::U32(data))
            }
            (Self::F32(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            (Self::F64(storage), DType::U32) => {
                let data = unary_map(storage, layout, |v| v as u32);
                Ok(Self::U32(data))
            }
            // Conversions to I64.
            (Self::U8(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::U32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::I64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::I64(data))
            }
            (Self::BF16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F16(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
                Ok(Self::I64(data))
            }
            (Self::F32(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            (Self::F64(storage), DType::I64) => {
                let data = unary_map(storage, layout, |v| v as i64);
                Ok(Self::I64(data))
            }
            // Conversions to F64.
            (Self::U8(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::U32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::I64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::BF16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F16(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v.to_f64());
                Ok(Self::F64(data))
            }
            (Self::F32(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v as f64);
                Ok(Self::F64(data))
            }
            (Self::F64(storage), DType::F64) => {
                let data = unary_map(storage, layout, |v| v);
                Ok(Self::F64(data))
            }
        }
    }
2087
2088    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
2089        match op {
2090            ReduceOp::Sum => {
2091                let src_dims = layout.dims();
2092                let mut dst_dims = src_dims.to_vec();
2093                for &dim in reduce_dims.iter() {
2094                    dst_dims[dim] = 1;
2095                }
2096                let dst_shape = Shape::from(dst_dims);
2097                let mut reduce_dims = reduce_dims.to_vec();
2098                // Sort the reduce_dims as they have to be processed from left to right when converting the
2099                // indexes.
2100                reduce_dims.sort();
2101                let reduce_dims_and_stride: Vec<_> = reduce_dims
2102                    .iter()
2103                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
2104                    .collect();
2105                ReduceSum {
2106                    dst_shape: &dst_shape,
2107                    reduce_dims: &reduce_dims,
2108                    reduce_dims_and_stride,
2109                }
2110                .map(self, layout)
2111            }
2112            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
2113                let reduce_dim_index = match reduce_dims {
2114                    [reduce_dim_index] => *reduce_dim_index,
2115                    _ => {
2116                        let op = match op {
2117                            ReduceOp::Min => "min",
2118                            ReduceOp::ArgMin => "argmin",
2119                            ReduceOp::Max => "max",
2120                            ReduceOp::ArgMax => "argmax",
2121                            _ => unreachable!(),
2122                        };
2123                        let dims = reduce_dims.to_vec();
2124                        Err(Error::OnlySingleDimension { op, dims })?
2125                    }
2126                };
2127                let (use_min, return_index) = match op {
2128                    ReduceOp::Min => (true, false),
2129                    ReduceOp::ArgMin => (true, true),
2130                    ReduceOp::Max => (false, false),
2131                    ReduceOp::ArgMax => (false, true),
2132                    _ => unreachable!(),
2133                };
2134                ReduceIndex {
2135                    reduce_dim_index,
2136                    use_min,
2137                    return_index,
2138                }
2139                .map(self, layout)
2140            }
2141        }
2142    }
2143
    /// Element-wise comparison with `rhs`; delegates to the `Cmp` binary map.
    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
        Cmp(op).map(self, lhs_l, rhs, rhs_l)
    }
2147
    /// Element-wise affine transform via the `Affine(mul, add)` map
    /// (presumably `v * mul + add`; see `Affine::f` for the exact semantics).
    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
        Affine(mul, add).map(self, layout)
    }
2151
    /// 2d average pooling; delegates to the `AvgPool2D` map with the given
    /// `(h, w)` kernel size and stride.
    fn avg_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        AvgPool2D(kernel_size, stride).map(self, layout)
    }
2160
    /// 2d max pooling; delegates to the `MaxPool2D` map with the given
    /// `(h, w)` kernel size and stride.
    fn max_pool2d(
        &self,
        layout: &Layout,
        kernel_size: (usize, usize),
        stride: (usize, usize),
    ) -> Result<Self> {
        MaxPool2D(kernel_size, stride).map(self, layout)
    }
2169
    /// 1d nearest-neighbor upsampling to length `sz`; delegates to `UpsampleNearest1D`.
    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
        UpsampleNearest1D(sz).map(self, layout)
    }
2173
    /// 2d nearest-neighbor upsampling to `(h, w)`; delegates to `UpsampleNearest2D`.
    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
        UpsampleNearest2D(h, w).map(self, layout)
    }
2177
2178    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
2179        use num_traits::Float;
2180        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
2181        match self {
2182            Self::BF16(storage) => {
2183                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
2184                Ok(Self::BF16(data))
2185            }
2186            Self::F16(storage) => {
2187                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
2188                Ok(Self::F16(data))
2189            }
2190            Self::F32(storage) => {
2191                let data = unary_map(storage, layout, |v| v.powf(e as f32));
2192                Ok(Self::F32(data))
2193            }
2194            Self::F64(storage) => {
2195                let data = unary_map(storage, layout, |v| v.powf(e));
2196                Ok(Self::F64(data))
2197            }
2198            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2199            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2200            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2201        }
2202    }
2203
    /// Element-wise ELU activation for float storages (see the free `elu` helper).
    ///
    /// `alpha` is converted to the element type per dtype; integer dtypes are
    /// unsupported and return an error.
    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                // `f32::from_f64` here resolves through a trait in scope (not a std
                // inherent method) — presumably `WithDType::from_f64`.
                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, alpha));
                Ok(Self::F64(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
        }
    }
2228
2229    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
2230        match self {
2231            Self::BF16(storage) => {
2232                if B::BF16_VEC {
2233                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
2234                    Ok(Self::BF16(data))
2235                } else {
2236                    let data = unary_map(storage, layout, B::bf16);
2237                    Ok(Self::BF16(data))
2238                }
2239            }
2240            Self::F16(storage) => {
2241                if B::F16_VEC {
2242                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
2243                    Ok(Self::F16(data))
2244                } else {
2245                    let data = unary_map(storage, layout, B::f16);
2246                    Ok(Self::F16(data))
2247                }
2248            }
2249            Self::F32(storage) => {
2250                if B::F32_VEC {
2251                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
2252                    Ok(Self::F32(data))
2253                } else {
2254                    let data = unary_map(storage, layout, B::f32);
2255                    Ok(Self::F32(data))
2256                }
2257            }
2258            Self::F64(storage) => {
2259                if B::F64_VEC {
2260                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
2261                    Ok(Self::F64(data))
2262                } else {
2263                    let data = unary_map(storage, layout, B::f64);
2264                    Ok(Self::F64(data))
2265                }
2266            }
2267            Self::U8(storage) => {
2268                let data = unary_map(storage, layout, B::u8);
2269                Ok(Self::U8(data))
2270            }
2271            Self::U32(storage) => {
2272                let data = unary_map(storage, layout, B::u32);
2273                Ok(Self::U32(data))
2274            }
2275            Self::I64(storage) => {
2276                let data = unary_map(storage, layout, B::i64);
2277                Ok(Self::I64(data))
2278            }
2279        }
2280    }
2281
2282    fn binary_impl<B: BinaryOpT>(
2283        &self,
2284        rhs: &Self,
2285        lhs_l: &Layout,
2286        rhs_l: &Layout,
2287    ) -> Result<Self> {
2288        match (self, rhs) {
2289            (Self::BF16(lhs), Self::BF16(rhs)) => {
2290                let data = if B::BF16_VEC {
2291                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
2292                } else {
2293                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
2294                };
2295                Ok(Self::BF16(data))
2296            }
2297            (Self::F16(lhs), Self::F16(rhs)) => {
2298                let data = if B::F16_VEC {
2299                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
2300                } else {
2301                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
2302                };
2303                Ok(Self::F16(data))
2304            }
2305            (Self::F32(lhs), Self::F32(rhs)) => {
2306                let data = if B::F32_VEC {
2307                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
2308                } else {
2309                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
2310                };
2311                Ok(Self::F32(data))
2312            }
2313            (Self::F64(lhs), Self::F64(rhs)) => {
2314                let data = if B::F64_VEC {
2315                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
2316                } else {
2317                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
2318                };
2319                Ok(Self::F64(data))
2320            }
2321            (Self::U32(lhs), Self::U32(rhs)) => {
2322                let data = if B::U32_VEC {
2323                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
2324                } else {
2325                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
2326                };
2327                Ok(Self::U32(data))
2328            }
2329            (Self::I64(lhs), Self::I64(rhs)) => {
2330                let data = if B::I64_VEC {
2331                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
2332                } else {
2333                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
2334                };
2335                Ok(Self::I64(data))
2336            }
2337            (Self::U8(lhs), Self::U8(rhs)) => {
2338                let data = if B::U8_VEC {
2339                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
2340                } else {
2341                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
2342                };
2343                Ok(Self::U8(data))
2344            }
2345            _ => {
2346                // This should be covered by the dtype check above.
2347                Err(Error::DTypeMismatchBinaryOp {
2348                    lhs: self.dtype(),
2349                    rhs: rhs.dtype(),
2350                    op: B::NAME,
2351                }
2352                .bt())
2353            }
2354        }
2355    }
2356
2357    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
2358        match (self, dst) {
2359            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2360            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2361            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2362            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2363            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2364            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2365            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2366            (_, dst) => {
2367                // This should be covered by the dtype check above.
2368                return Err(Error::DTypeMismatchBinaryOp {
2369                    lhs: self.dtype(),
2370                    rhs: dst.dtype(),
2371                    op: "copy_strided",
2372                }
2373                .bt());
2374            }
2375        }
2376        Ok(())
2377    }
2378
2379    fn where_cond(
2380        &self,
2381        layout: &Layout,
2382        t: &Self,
2383        t_l: &Layout,
2384        f: &Self,
2385        f_l: &Layout,
2386    ) -> Result<Self> {
2387        match self {
2388            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2389            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2390            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2391            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
2392        }
2393    }
2394
2395    fn conv1d(
2396        &self,
2397        l: &Layout,
2398        kernel: &Self,
2399        kernel_l: &Layout,
2400        params: &crate::conv::ParamsConv1D,
2401    ) -> Result<Self> {
2402        if !USE_IM2COL_CONV1D {
2403            return Conv1D(params).map(self, l, kernel, kernel_l);
2404        }
2405        let op = Im2Col1D {
2406            l_k: params.k_size,
2407            padding: params.padding,
2408            stride: params.stride,
2409            dilation: params.dilation,
2410        };
2411        let col = op.map(self, l)?;
2412        let b = params.b_size;
2413        let n = params.c_out;
2414        let l_out = params.l_out();
2415        let k = op.l_k * params.c_in;
2416        let m = l_out;
2417        let col_l = Layout::contiguous((b, m, k));
2418        let res = if kernel_l.is_contiguous() {
2419            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2420                .transpose(1, 2)?
2421                .broadcast_as((b, k, n))?;
2422            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2423        } else {
2424            // Make the kernel contiguous if not already the case.
2425            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
2426            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2427            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2428                .transpose(1, 2)?
2429                .broadcast_as((b, k, n))?;
2430            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2431        };
2432        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
2433        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
2434        res.copy_strided_src(&mut res_t, 0, &res_l)?;
2435        Ok(res_t)
2436    }
2437
2438    fn conv2d(
2439        &self,
2440        l: &Layout,
2441        kernel: &Self,
2442        kernel_l: &Layout,
2443        params: &crate::conv::ParamsConv2D,
2444    ) -> Result<Self> {
2445        if !USE_IM2COL_CONV2D {
2446            return Conv2D(params).map(self, l, kernel, kernel_l);
2447        }
2448        let op = Im2Col {
2449            h_k: params.k_h,
2450            w_k: params.k_w,
2451            padding: params.padding,
2452            stride: params.stride,
2453            dilation: params.dilation,
2454        };
2455        let col = op.map(self, l)?;
2456        let b = params.b_size;
2457        let n = params.c_out;
2458        let (h_out, w_out) = (params.out_h(), params.out_w());
2459        let k = op.h_k * op.w_k * params.c_in;
2460        let m = h_out * w_out;
2461        let col_l = Layout::contiguous((b, m, k));
2462        let res = if kernel_l.is_contiguous() {
2463            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2464                .transpose(1, 2)?
2465                .broadcast_as((b, k, n))?;
2466            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2467        } else {
2468            // Make the kernel contiguous if not already the case.
2469            let mut kernel_c = self.device().zeros_impl(kernel_l.shape(), kernel.dtype())?;
2470            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2471            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2472                .transpose(1, 2)?
2473                .broadcast_as((b, k, n))?;
2474            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2475        };
2476        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
2477            .transpose(1, 2)?
2478            .transpose(1, 3)?;
2479        let mut res_t = self.device().zeros_impl(res_l.shape(), res.dtype())?;
2480        res.copy_strided_src(&mut res_t, 0, &res_l)?;
2481        Ok(res_t)
2482    }
2483
2484    fn conv_transpose2d(
2485        &self,
2486        l: &Layout,
2487        kernel: &Self,
2488        kernel_l: &Layout,
2489        params: &crate::conv::ParamsConvTranspose2D,
2490    ) -> Result<Self> {
2491        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
2492    }
2493
2494    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
2495        match ids {
2496            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2497            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2498            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2499            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
2500        }
2501    }
2502
2503    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
2504        match ids {
2505            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
2506            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
2507            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
2508            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
2509        }
2510    }
2511
2512    fn scatter_add(
2513        &self,
2514        l: &Layout,
2515        ids: &Self,
2516        ids_l: &Layout,
2517        src: &Self,
2518        src_l: &Layout,
2519        dim: usize,
2520    ) -> Result<Self> {
2521        match ids {
2522            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
2523            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
2524            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
2525            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
2526        }
2527    }
2528
2529    fn index_add(
2530        &self,
2531        l: &Layout,
2532        ids: &Self,
2533        ids_l: &Layout,
2534        src: &Self,
2535        src_l: &Layout,
2536        dim: usize,
2537    ) -> Result<Self> {
2538        match ids {
2539            Self::U8(ids) => {
2540                let ids = match ids_l.contiguous_offsets() {
2541                    Some((a, b)) => &ids[a..b],
2542                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2543                };
2544                IndexAdd { ids, dim }.map(self, l, src, src_l)
2545            }
2546            Self::U32(ids) => {
2547                let ids = match ids_l.contiguous_offsets() {
2548                    Some((a, b)) => &ids[a..b],
2549                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2550                };
2551                IndexAdd { ids, dim }.map(self, l, src, src_l)
2552            }
2553            Self::I64(ids) => {
2554                let ids = match ids_l.contiguous_offsets() {
2555                    Some((a, b)) => &ids[a..b],
2556                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2557                };
2558                IndexAdd { ids, dim }.map(self, l, src, src_l)
2559            }
2560            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
2561        }
2562    }
2563
2564    fn matmul(
2565        &self,
2566        rhs: &Self,
2567        bmnk: (usize, usize, usize, usize),
2568        lhs_l: &Layout,
2569        rhs_l: &Layout,
2570    ) -> Result<Self> {
2571        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
2572    }
2573
    // All CPU storages share the single zero-sized CpuDevice instance.
    fn device(&self) -> &Self::Device {
        &CpuDevice
    }
2577
    // The layout argument is ignored: the whole underlying buffer is cloned,
    // not just the elements the layout addresses.
    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }
2581
    // Already on the host, so a plain clone suffices.
    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Ok(self.clone())
    }
2585}
2586
impl BackendDevice for CpuDevice {
    type Storage = CpuStorage;

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cpu
    }

    // There is a single host device, so any two CpuDevice values compare equal.
    fn same_device(&self, _: &Self) -> bool {
        true
    }

    // A CpuStorage already lives on the host: "transferring" it is a clone.
    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
        Ok(s.clone())
    }

    // The device ordinal is ignored: there is only one CPU device.
    fn new(_: usize) -> Result<Self> {
        Ok(Self)
    }

    // Sampling below goes through `rand::thread_rng`, which this API cannot
    // reseed, hence the hard error rather than a silent no-op.
    fn set_seed(&self, _seed: u64) -> Result<()> {
        crate::bail!("cannot seed the CPU rng with set_seed")
    }

    // Allocates a new storage filled with samples from `Uniform::new(min, max)`
    // (half-open range). Float dtypes only; for bf16/f16 the bounds are
    // converted to the target dtype before sampling.
    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 | DType::I64 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
            }
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
                for _i in 0..elem_count {
                    data.push(rng.sample::<bf16, _>(uniform))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
                for _i in 0..elem_count {
                    data.push(rng.sample::<f16, _>(uniform))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(uniform))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distributions::Uniform::new(min, max);
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(uniform))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    // Allocates a new storage filled with samples from a normal distribution
    // with the given mean and standard deviation. Float dtypes only; an
    // invalid std (e.g. negative) surfaces as a wrapped rand_distr error.
    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::thread_rng();
        match dtype {
            DType::U8 | DType::U32 | DType::I64 => {
                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
            }
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal =
                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    // Allocates a storage of the requested shape/dtype filled with ones.
    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
            DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
            DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
        };
        Ok(storage)
    }

    // Allocates a storage of the requested shape/dtype filled with zeros.
    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
        };
        Ok(storage)
    }
}
2731
/// Applies `$fn` to the inner `Vec` of `$storage` for each listed dtype
/// variant, rewrapping the result in the same `CpuStorage` variant. Any other
/// dtype bubbles up an `UnsupportedDTypeForOp` error tagged with `$name`
/// (note the trailing `?`: the surrounding function must return `Result`).
#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
            s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}