candle_core/cpu_backend/mod.rs

1//! Implementation of Backend Fns for CPU
2use crate::backend::{BackendDevice, BackendStorage};
3use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
4use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
5use half::{bf16, f16};
6use rayon::prelude::*;
7
8mod utils;
9pub use utils::{
10    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
11};
12
// Compile-time switches choosing between convolution implementation strategies on the CPU
// backend. They are not referenced in the visible part of this file; NOTE(review): the
// descriptions below are derived from the names only — confirm at the use sites.
/// Presumably selects the im2col-based path for 1D convolutions when `true`.
const USE_IM2COL_CONV1D: bool = true;
/// Presumably selects the col2im-based path for transposed 1D convolutions when `true`.
const USE_COL2IM_CONV1D_TR: bool = true;
/// Presumably selects the im2col-based path for 2D convolutions when `true`.
const USE_IM2COL_CONV2D: bool = true;
16
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
/// Owned element buffer used by the CPU backend: one variant per supported element type,
/// each holding its elements in a `Vec`.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}
29
/// Borrowed counterpart of `CpuStorage`: one variant per supported element type, each
/// holding a slice of elements with lifetime `'a`.
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
    U8(&'a [u8]),
    U32(&'a [u32]),
    I64(&'a [i64]),
    BF16(&'a [bf16]),
    F16(&'a [f16]),
    F32(&'a [f32]),
    F64(&'a [f64]),
}
40
/// Unit marker type identifying the CPU device (it carries no state).
#[derive(Debug, Clone)]
pub struct CpuDevice;
43
44struct Cmp(CmpOp);
45impl Map2U8 for Cmp {
46    const OP: &'static str = "cmp";
47    #[inline(always)]
48    fn f<T: WithDType>(
49        &self,
50        lhs: &[T],
51        lhs_l: &Layout,
52        rhs: &[T],
53        rhs_l: &Layout,
54    ) -> Result<Vec<u8>> {
55        let dst = match self.0 {
56            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
57            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
58            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
59            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
60            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
61            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
62        };
63        Ok(dst)
64    }
65}
66
67struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
68
69impl<I: IntDType> Map2 for WCond<'_, I> {
70    const OP: &'static str = "where";
71    #[inline(always)]
72    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
73        let vs = match (
74            self.1.contiguous_offsets(),
75            t_l.contiguous_offsets(),
76            f_l.contiguous_offsets(),
77        ) {
78            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
79                let pred = &self.0[o1..o2];
80                let t = &t[o_t1..o_t2];
81                let f = &f[o_f1..o_f2];
82                pred.iter()
83                    .zip(t.iter().zip(f.iter()))
84                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
85                    .collect::<Vec<_>>()
86            }
87            _ => self
88                .1
89                .strided_index()
90                .zip(t_l.strided_index().zip(f_l.strided_index()))
91                .map(|(i_p, (i_t, i_f))| {
92                    if self.0[i_p].is_true() {
93                        t[i_t]
94                    } else {
95                        f[i_f]
96                    }
97                })
98                .collect::<Vec<_>>(),
99        };
100        Ok(vs)
101    }
102}
103
/// Min/max reduction along a single dimension, returning either the extremal values or their
/// positions along that dimension.
struct ReduceIndex {
    // Index of the dimension being reduced.
    reduce_dim_index: usize,
    // `true` keeps the minimum, `false` keeps the maximum.
    use_min: bool,
    // `true` returns the position of the extremum (as `u32`) instead of its value.
    return_index: bool,
}
109
110impl ReduceIndex {
111    // The value gets replaced if f(s[current_acc], s[i]) returns true.
112    #[inline(always)]
113    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
114    where
115        T: Clone + Copy,
116        U: Clone + Copy,
117        F: Fn(T, T) -> bool,
118        G: Fn(T, usize) -> U,
119    {
120        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
121        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
122        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
123        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
124        let dst_to_set = dst.spare_capacity_mut();
125        let dst_to_set =
126            unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
127        match src_l.contiguous_offsets() {
128            Some((o1, o2)) => {
129                let src = &src[o1..o2];
130                if reduce_dim_stride == 1 {
131                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
132                        let start_src_i = start_src_i * reduce_dim_size;
133                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
134                        let mut acc = 0;
135                        let mut val = src[0];
136                        for (src_i, &s) in src.iter().enumerate() {
137                            if f(val, s) {
138                                acc = src_i;
139                                val = s
140                            }
141                        }
142                        *dst_v = g(val, acc)
143                    }
144                } else {
145                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
146                        let (p, q) = (
147                            start_src_i / reduce_dim_stride,
148                            start_src_i % reduce_dim_stride,
149                        );
150                        // start_src_i = p * reduce_dim_stride + q
151                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
152                        let src = &src[start_src_i..];
153                        let mut acc = 0;
154                        let mut val = src[0];
155                        for src_i in 0..reduce_dim_size {
156                            let s = src[src_i * reduce_dim_stride];
157                            if f(val, s) {
158                                acc = src_i;
159                                val = s
160                            }
161                        }
162                        *dst_v = g(val, acc)
163                    }
164                }
165            }
166            None => {
167                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
168                for (unstr_index, src_index) in l.strided_index().enumerate() {
169                    let src = &src[src_index..];
170                    let mut acc = 0;
171                    let mut val = src[0];
172                    for src_i in 0..reduce_dim_size {
173                        let s = src[src_i * reduce_dim_stride];
174                        if f(val, s) {
175                            acc = src_i;
176                            val = s
177                        }
178                    }
179                    dst_to_set[unstr_index] = g(val, acc)
180                }
181            }
182        }
183        unsafe { dst.set_len(dst_len) };
184        Ok(dst)
185    }
186}
187
188impl Map1Any for ReduceIndex {
189    #[inline(always)]
190    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
191        &self,
192        src: &[T],
193        src_l: &Layout,
194        wrap: W,
195    ) -> Result<CpuStorage> {
196        if src_l.shape().elem_count() == 0 {
197            Err(Error::EmptyTensor { op: "reduce" }.bt())?
198        }
199        let dst = match (self.return_index, self.use_min) {
200            (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
201            (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
202            (true, true) => {
203                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
204            }
205            (true, false) => {
206                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
207            }
208        };
209        Ok(dst)
210    }
211}
212
/// Sum reduction over one or more dimensions.
struct ReduceSum<'a> {
    // Shape of the reduced output. NOTE(review): the index arithmetic in `fold_impl` collapses
    // each reduced dim to size 1, so this presumably keeps those dims with size 1 — confirm.
    dst_shape: &'a Shape,
    // Dimensions to reduce; `fold_impl` relies on these being sorted in ascending order.
    reduce_dims: &'a [usize],
    // One `(dim size, stride)` pair per reduced dimension. The stride is used as the divisor
    // isolating that dimension's coordinate in a row-major linear index.
    reduce_dims_and_stride: Vec<(usize, usize)>,
}
218
219impl ReduceSum<'_> {
220    #[inline(always)]
221    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
222    where
223        T: WithDType,
224    {
225        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
226        match src_l.contiguous_offsets() {
227            Some((o1, o2)) => {
228                let src = &src[o1..o2];
229                // Handle the case where we reduce over the last dimensions separately as it is
230                // fairly common and easy to optimize. This rely on the layout being contiguous!
231                // reduce_dims is sorted, check if it is ranging from a to n-1.
232                let reduce_over_last_dims = self
233                    .reduce_dims
234                    .iter()
235                    .rev()
236                    .enumerate()
237                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
238                if reduce_over_last_dims {
239                    let reduce_sz = self
240                        .reduce_dims_and_stride
241                        .iter()
242                        .map(|(u, _)| u)
243                        .product::<usize>();
244                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
245                        let src_i = dst_i * reduce_sz;
246                        unsafe {
247                            T::vec_reduce_sum(
248                                src[src_i..src_i + reduce_sz].as_ptr(),
249                                dst_v,
250                                reduce_sz,
251                            )
252                        };
253                    }
254                    return Ok(dst);
255                };
256                for (unstr_index, &src) in src.iter().enumerate() {
257                    let mut dst_index = unstr_index;
258                    // Set the reduce_dims indexes to 0.
259                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
260                        // The compiler is able to optimize the following in a single divmod op.
261                        let (pre, post) = (dst_index / stride, dst_index % stride);
262                        dst_index = (pre / dim) * stride + post;
263                    }
264                    dst[dst_index] += src;
265                }
266            }
267            None => {
268                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
269                    let mut dst_index = unstr_index;
270                    // Set the reduce_dims indexes to 0.
271                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
272                        // The compiler is able to optimize the following in a single divmod op.
273                        let (pre, post) = (dst_index / stride, dst_index % stride);
274                        dst_index = (pre / dim) * stride + post;
275                    }
276                    dst[dst_index] += src[src_index];
277                }
278            }
279        }
280        Ok(dst)
281    }
282}
283
284impl Map1 for ReduceSum<'_> {
285    #[inline(always)]
286    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
287        self.fold_impl(src, src_l, T::zero())
288    }
289}
290
291struct Affine(f64, f64);
292
293impl Map1 for Affine {
294    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
295        let mul = T::from_f64(self.0);
296        let add = T::from_f64(self.1);
297        Ok(unary_map(vs, layout, |v| v * mul + add))
298    }
299}
300
301struct AvgPool2D((usize, usize), (usize, usize));
302
303impl Map1 for AvgPool2D {
304    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
305        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
306        let (k_h, k_w) = self.0;
307        let (s_h, s_w) = self.1;
308        let (b_sz, c, h, w) = layout.shape().dims4()?;
309        let stride = layout.stride();
310        let (stride_h, stride_w) = (stride[2], stride[3]);
311        let h_out = (h - k_h) / s_h + 1;
312        let w_out = (w - k_w) / s_w + 1;
313        let src_index = layout.start_offset();
314        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
315        let scale = 1f64 / (k_h * k_w) as f64;
316        let scale = T::from_f64(scale);
317        for b_idx in 0..b_sz {
318            let dst = &mut dst[b_idx * c * h_out * w_out..];
319            let src_index = src_index + b_idx * stride[0];
320            for c_idx in 0..c {
321                let dst = &mut dst[c_idx * h_out * w_out..];
322                let src_index = src_index + c_idx * stride[1];
323                for h_idx in 0..h_out {
324                    for w_idx in 0..w_out {
325                        let mut sum = T::zero();
326                        for m in 0..k_h {
327                            for n in 0..k_w {
328                                let m = s_h * h_idx + m;
329                                let n = s_w * w_idx + n;
330                                sum += src[src_index + m * stride_h + n * stride_w]
331                            }
332                        }
333                        dst[h_idx * w_out + w_idx] = sum * scale;
334                    }
335                }
336            }
337        }
338        Ok(dst)
339    }
340}
341
342struct MaxPool2D((usize, usize), (usize, usize));
343
344impl Map1 for MaxPool2D {
345    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
346        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
347        let (k_h, k_w) = self.0;
348        let (s_h, s_w) = self.1;
349        let (b_sz, c, h, w) = layout.shape().dims4()?;
350        let stride = layout.stride();
351        let (stride_h, stride_w) = (stride[2], stride[3]);
352        let h_out = (h - k_h) / s_h + 1;
353        let w_out = (w - k_w) / s_w + 1;
354        let src_index = layout.start_offset();
355        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
356        for b_idx in 0..b_sz {
357            let dst = &mut dst[b_idx * c * h_out * w_out..];
358            let src_index = src_index + b_idx * stride[0];
359            for c_idx in 0..c {
360                let dst = &mut dst[c_idx * h_out * w_out..];
361                let src_index = src_index + c_idx * stride[1];
362                for h_idx in 0..h_out {
363                    for w_idx in 0..w_out {
364                        let mut largest =
365                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
366                        for m in 0..k_h {
367                            for n in 0..k_w {
368                                let m = s_h * h_idx + m;
369                                let n = s_w * w_idx + n;
370                                if largest < src[src_index + m * stride_h + n * stride_w] {
371                                    largest = src[src_index + m * stride_h + n * stride_w]
372                                }
373                            }
374                        }
375                        dst[h_idx * w_out + w_idx] = largest;
376                    }
377                }
378            }
379        }
380        Ok(dst)
381    }
382}
383
384struct UpsampleNearest1D(usize);
385
386impl Map1 for UpsampleNearest1D {
387    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
388        // TODO: Specialized implementation for the case 2*sz?
389        let dst_sz = self.0;
390        let (b_sz, c, src_sz) = layout.shape().dims3()?;
391        let stride = layout.stride();
392        let stride_sz = stride[2];
393        let src_index = layout.start_offset();
394        let scale_sz = src_sz as f64 / dst_sz as f64;
395        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
396        let src_idxs = (0..dst_sz)
397            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
398            .collect::<Vec<_>>();
399        for b_idx in 0..b_sz {
400            let dst = &mut dst[b_idx * c * dst_sz..];
401            let src_index = src_index + b_idx * stride[0];
402            for c_idx in 0..c {
403                let dst = &mut dst[c_idx * dst_sz..];
404                let src_index = src_index + c_idx * stride[1];
405                for (idx, src_idx) in src_idxs.iter().enumerate() {
406                    dst[idx] = src[src_index + src_idx * stride_sz]
407                }
408            }
409        }
410        Ok(dst)
411    }
412}
413
414struct UpsampleNearest2D(usize, usize);
415
416impl Map1 for UpsampleNearest2D {
417    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
418        // TODO: Specialized implementation for the case 2*h, 2*w?
419        let (dst_h, dst_w) = (self.0, self.1);
420        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
421        let stride = layout.stride();
422        let (stride_h, stride_w) = (stride[2], stride[3]);
423        let src_index = layout.start_offset();
424        let scale_h = src_h as f64 / dst_h as f64;
425        let scale_w = src_w as f64 / dst_w as f64;
426        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
427        let src_h_idxs = (0..dst_h)
428            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
429            .collect::<Vec<_>>();
430        let src_w_idxs = (0..dst_w)
431            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
432            .collect::<Vec<_>>();
433        for b_idx in 0..b_sz {
434            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
435            let src_index = src_index + b_idx * stride[0];
436            for c_idx in 0..c {
437                let dst = &mut dst[c_idx * dst_h * dst_w..];
438                let src_index = src_index + c_idx * stride[1];
439                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
440                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
441                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
442                        dst[h_idx * dst_w + w_idx] = src[src_index]
443                    }
444                }
445            }
446        }
447        Ok(dst)
448    }
449}
450
451struct Gather<'a, I: IntDType> {
452    ids: &'a [I],
453    ids_l: &'a Layout,
454    dim: usize,
455}
456
457impl<I: IntDType> Map1 for Gather<'_, I> {
458    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
459        let ids = match self.ids_l.contiguous_offsets() {
460            Some((a, b)) => &self.ids[a..b],
461            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
462        };
463        let src = match src_l.contiguous_offsets() {
464            Some((a, b)) => &src[a..b],
465            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
466        };
467        let dim = self.dim;
468        let ids_dims = self.ids_l.dims();
469        let src_dims = src_l.dims();
470        let dst_len: usize = ids_dims.iter().product();
471        let dst_left_len: usize = ids_dims[..dim].iter().product();
472        let dst_dim_len = ids_dims[dim];
473        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
474
475        let src_dim_len = src_dims[dim];
476        let src_right_len: usize = src_dims[dim + 1..].iter().product();
477
478        let mut dst = vec![T::zero(); dst_len];
479        for left_i in 0..dst_left_len {
480            let start_src_idx = left_i * src_right_len * src_dim_len;
481            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
482            for i in 0..dst_dim_len {
483                let start_dst_idx = start_dst_idx + i * dst_right_len;
484                for right_i in 0..dst_right_len {
485                    let dst_idx = start_dst_idx + right_i;
486                    let index = ids[dst_idx].as_usize();
487                    if index >= src_dim_len {
488                        Err(Error::InvalidIndex {
489                            index,
490                            size: src_dim_len,
491                            op: "gather",
492                        }
493                        .bt())?
494                    }
495                    let src_idx = start_src_idx + index * src_right_len + right_i;
496                    dst[dst_idx] = src[src_idx]
497                }
498            }
499        }
500        Ok(dst)
501    }
502}
503
504struct IndexSelect<'a, T: IntDType> {
505    ids: &'a [T],
506    ids_l: &'a Layout,
507    dim: usize,
508}
509
510impl<I: IntDType> Map1 for IndexSelect<'_, I> {
511    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
512        let src = match layout.contiguous_offsets() {
513            Some((a, b)) => &src[a..b],
514            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
515        };
516        let dim = self.dim;
517        let n_ids = match self.ids_l.dims() {
518            [n_ids] => *n_ids,
519            d => Err(Error::UnexpectedNumberOfDims {
520                expected: 1,
521                got: d.len(),
522                shape: self.ids_l.shape().clone(),
523            }
524            .bt())?,
525        };
526        let stride_ids = self.ids_l.stride()[0];
527        let mut dst_dims = layout.dims().to_vec();
528        let src_dim = dst_dims[dim];
529        dst_dims[dim] = n_ids;
530        let dst_len: usize = dst_dims.iter().product();
531        let left_len: usize = dst_dims[..dim].iter().product();
532        let right_len: usize = dst_dims[dim + 1..].iter().product();
533        let mut dst = vec![T::zero(); dst_len];
534        for left_i in 0..left_len {
535            let start_src_idx = left_i * right_len * src_dim;
536            let start_dst_idx = left_i * right_len * n_ids;
537            for i in 0..n_ids {
538                let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
539                if index >= src_dim {
540                    Err(Error::InvalidIndex {
541                        index,
542                        size: src_dim,
543                        op: "index-select",
544                    }
545                    .bt())?
546                }
547                let start_src_idx = start_src_idx + index * right_len;
548                let start_dst_idx = start_dst_idx + i * right_len;
549                dst[start_dst_idx..start_dst_idx + right_len]
550                    .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
551            }
552        }
553        Ok(dst)
554    }
555}
556
/// Per-element update rule used by `Scatter` to combine a source value into a destination
/// element.
trait ElemUpdate {
    fn f<T: WithDType>(dst: &mut T, src: T);
}

/// Update rule that overwrites the destination element with the source value.
struct Set;
/// Update rule that adds the source value to the destination element.
struct Add;

impl ElemUpdate for Set {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst = src
    }
}

impl ElemUpdate for Add {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst += src
    }
}
575
576struct Scatter<'a, I: IntDType, M: ElemUpdate> {
577    ids: &'a [I],
578    ids_l: &'a Layout,
579    dim: usize,
580    _phantom: std::marker::PhantomData<M>,
581}
582
583impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
584    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
585        Self {
586            ids,
587            ids_l,
588            dim,
589            _phantom: Default::default(),
590        }
591    }
592}
593
594impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
595    const OP: &'static str = "scatter";
596    fn f<T: WithDType>(
597        &self,
598        dst: &mut [T],
599        dst_l: &Layout,
600        src: &[T],
601        src_l: &Layout,
602    ) -> Result<()> {
603        let dst = match dst_l.contiguous_offsets() {
604            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
605            Some((o1, o2)) => &mut dst[o1..o2],
606        };
607
608        let src = match src_l.contiguous_offsets() {
609            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
610            Some((o1, o2)) => &src[o1..o2],
611        };
612
613        let dim = self.dim;
614        let ids_dims = self.ids_l.dims();
615        let dst_dims = dst_l.dims();
616        let dst_dim_len = dst_dims[dim];
617        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
618
619        let ids_left_len: usize = ids_dims[..dim].iter().product();
620        let ids_dim_len = ids_dims[dim];
621        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
622
623        let ids = match self.ids_l.contiguous_offsets() {
624            Some((a, b)) => &self.ids[a..b],
625            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
626        };
627        for left_i in 0..ids_left_len {
628            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
629            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
630            for i in 0..ids_dim_len {
631                let start_ids_idx = start_ids_idx + i * ids_right_len;
632                for right_i in 0..dst_right_len {
633                    let ids_idx = start_ids_idx + right_i;
634                    let index = ids[ids_idx].as_usize();
635                    if index >= dst_dim_len {
636                        Err(Error::InvalidIndex {
637                            index,
638                            size: dst_dim_len,
639                            op: "gather",
640                        }
641                        .bt())?
642                    }
643                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
644                    M::f(&mut dst[dst_idx], src[ids_idx])
645                }
646            }
647        }
648
649        Ok(())
650    }
651}
652
653struct IndexAdd<'a, I: IntDType> {
654    ids: &'a [I],
655    dim: usize,
656}
657
658impl<I: IntDType> Map2 for IndexAdd<'_, I> {
659    const OP: &'static str = "index-add";
660    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
661    // v1, l1 -> self
662    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
663        let dst_len = l1.shape().elem_count();
664        let mut dst = vec![T::zero(); dst_len];
665        copy_strided_src_(v1, &mut dst, 0, l1);
666        let src = match src_l.contiguous_offsets() {
667            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
668            Some((o1, o2)) => &src[o1..o2],
669        };
670        let dim = self.dim;
671        let max_idx = l1.dims()[dim];
672        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
673        let src_dim_sz = src_l.dims()[dim];
674        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
675        if dim == 0 {
676            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
677                let dst_idx = dst_idx.as_usize();
678                if dst_idx >= max_idx {
679                    Err(Error::InvalidIndex {
680                        index: dst_idx,
681                        op: "index-add",
682                        size: max_idx,
683                    })?
684                }
685                let src_idx = src_idx * post_dim;
686                let dst_idx = dst_idx * post_dim;
687                let src = &src[src_idx..src_idx + post_dim];
688                let dst = &mut dst[dst_idx..dst_idx + post_dim];
689                for (d, &s) in dst.iter_mut().zip(src.iter()) {
690                    *d += s
691                }
692            }
693        } else {
694            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
695                let dst_idx = dst_idx.as_usize();
696                if dst_idx >= max_idx {
697                    Err(Error::InvalidIndex {
698                        index: dst_idx,
699                        op: "index-add",
700                        size: max_idx,
701                    })?
702                }
703                for pre_i in 0..pre_dim {
704                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
705                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
706                    let src = &src[pre_src_i..pre_src_i + post_dim];
707                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
708                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
709                        *d += s
710                    }
711                }
712            }
713        }
714        Ok(dst)
715    }
716}
717
/// Copies a `d1 x d2` block of rows from `src` into `dst`.
///
/// Row `r` (for `r` in `0..d1`) is the `d2` contiguous elements starting at
/// `src_offset + r * src_stride1` in `src`. It is written to the `d2` elements starting at
/// `dst_offset + r * dst_stride1` in `dst`.
///
/// # Panics
/// Panics if any row falls outside `src` or `dst`.
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    let row_starts =
        (0..d1).map(|row| (dst_offset + row * dst_stride1, src_offset + row * src_stride1));
    for (dst_start, src_start) in row_starts {
        dst[dst_start..dst_start + d2].copy_from_slice(&src[src_start..src_start + d2]);
    }
}
737
738fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
739    match src_l.strided_blocks() {
740        crate::StridedBlocks::SingleBlock { start_offset, len } => {
741            let to_copy = (dst.len() - dst_offset).min(len);
742            dst[dst_offset..dst_offset + to_copy]
743                .copy_from_slice(&src[start_offset..start_offset + to_copy])
744        }
745        crate::StridedBlocks::MultipleBlocks {
746            block_start_index,
747            block_len: 1,
748        } => {
749            for (dst_index, src_index) in block_start_index.enumerate() {
750                let dst_index = dst_index + dst_offset;
751                if dst_index >= dst.len() {
752                    break;
753                }
754                dst[dst_index] = src[src_index]
755            }
756        }
757        crate::StridedBlocks::MultipleBlocks {
758            block_start_index,
759            block_len,
760        } => {
761            let mut dst_index = dst_offset;
762            for src_index in block_start_index {
763                let next_dst_index = dst_index + block_len;
764                if dst_index >= dst.len() {
765                    break;
766                }
767                let to_copy = usize::min(block_len, dst.len() - dst_index);
768                dst[dst_index..dst_index + to_copy]
769                    .copy_from_slice(&src[src_index..src_index + to_copy]);
770                dst_index = next_dst_index
771            }
772        }
773    }
774}
775
776struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
777
778impl Map2 for Conv1D<'_> {
779    const OP: &'static str = "conv1d";
780    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
781        let p = self.0;
782        let inp = &inp[inp_l.start_offset()..];
783        let k = &k[k_l.start_offset()..];
784        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
785        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
786        let l_out = p.l_out();
787        let dst_elems = p.c_out * l_out * p.b_size;
788        // The output shape is [b_size, c_out, l_out]
789        let dst = vec![T::zero(); dst_elems];
790
791        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
792        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
793        for b_idx in 0..p.b_size {
794            for src_l in 0..p.l_in {
795                for src_c_idx in 0..p.c_in {
796                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
797                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
798                }
799            }
800        }
801
802        for offset in 0..p.k_size {
803            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
804                let dst_idx = dst_c_idx * l_out;
805                let k_cont = (0..p.c_in)
806                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
807                    .collect::<Vec<_>>();
808                for b_idx in 0..p.b_size {
809                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
810                    for dst_l in 0..l_out {
811                        let dst_idx = dst_idx + dst_l;
812                        let src_l = p.stride * dst_l + offset * p.dilation;
813                        if src_l < p.padding || src_l >= p.padding + p.l_in {
814                            continue;
815                        }
816                        let src_l = src_l - p.padding;
817                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
818                        assert!(inp_cont.len() >= p.c_in);
819                        assert!(k_cont.len() >= p.c_in);
820                        let mut d = T::zero();
821                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
822                        let dst_p = dst.as_ptr();
823                        // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
824                        // the different tasks so no two threads can try to write at the same
825                        // location.
826                        unsafe {
827                            let ptr = dst_p.add(dst_idx) as *mut T;
828                            *ptr += d
829                        }
830                    }
831                }
832            })
833        }
834        Ok(dst)
835    }
836}
837
struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col1D {
    /// Number of output positions for an input of length `l`.
    ///
    /// This is `(l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1`: the padded length
    /// minus the dilated receptive field, divided by the stride, plus the first window.
    fn l_out(&self, l: usize) -> usize {
        let padded_len = l + 2 * self.padding;
        let receptive_field = self.dilation * (self.l_k - 1) + 1;
        (padded_len - receptive_field) / self.stride + 1
    }
}
850
851impl Map1 for Im2Col1D {
852    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
853        let &Self {
854            l_k,
855            stride,
856            dilation,
857            padding,
858        } = self;
859        let (b, c, l) = layout.shape().dims3()?;
860        let l_out = self.l_out(l);
861        let src = &vs[layout.start_offset()..];
862        let mut dst = vec![T::zero(); b * l_out * c * l_k];
863        let (src_s0, src_s1, src_s2) = {
864            let s = layout.stride();
865            (s[0], s[1], s[2])
866        };
867        // TODO: provide specialized kernels for the common use cases.
868        // - l_k = 1
869        // - padding = 0
870        // - stride = 1
871        // - dilation = 1
872        for b_idx in 0..b {
873            let src_idx = b_idx * src_s0;
874            let dst_idx = b_idx * l_out * c * l_k;
875            for l_idx in 0..l_out {
876                let dst_idx = dst_idx + l_idx * c * l_k;
877                for c_idx in 0..c {
878                    let dst_idx = dst_idx + c_idx * l_k;
879                    let src_idx = c_idx * src_s1 + src_idx;
880                    for l_k_idx in 0..l_k {
881                        let src_l = l_idx * stride + l_k_idx * dilation;
882                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
883                            continue;
884                        }
885                        let src_l = src_l - padding;
886                        let src_idx = src_idx + src_l * src_s2;
887                        let dst_idx = dst_idx + l_k_idx;
888                        dst[dst_idx] = src[src_idx]
889                    }
890                }
891            }
892        }
893        Ok(dst)
894    }
895}
896
struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    /// Spatial output size `(h_out, w_out)` for an `h x w` input.
    ///
    /// Each axis uses `(len + 2 * padding - dilation * (k - 1) - 1) / stride + 1` with its own
    /// kernel extent `k`.
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let out_len = |len: usize, k: usize| {
            (len + 2 * self.padding - self.dilation * (k - 1) - 1) / self.stride + 1
        };
        (out_len(h, self.h_k), out_len(w, self.w_k))
    }
}
912
913impl Map1 for Im2Col {
914    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
915        let &Self {
916            h_k,
917            w_k,
918            stride,
919            dilation,
920            padding,
921        } = self;
922        let (b, c, h, w) = layout.shape().dims4()?;
923        let (h_out, w_out) = self.hw_out(h, w);
924        let src = &vs[layout.start_offset()..];
925        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
926        let (src_s0, src_s1, src_s2, src_s3) = {
927            let s = layout.stride();
928            (s[0], s[1], s[2], s[3])
929        };
930        // TODO: provide specialized kernels for the common use cases.
931        // - h_k = w_k = 1
932        // - padding = 0
933        // - stride = 1
934        // - dilation = 1
935        for b_idx in 0..b {
936            let src_idx = b_idx * src_s0;
937            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
938            for h_idx in 0..h_out {
939                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
940                for w_idx in 0..w_out {
941                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
942                    for c_idx in 0..c {
943                        let dst_idx = dst_idx + c_idx * h_k * w_k;
944                        let src_idx = c_idx * src_s1 + src_idx;
945                        for h_k_idx in 0..h_k {
946                            let src_h = h_idx * stride + h_k_idx * dilation;
947                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
948                                continue;
949                            }
950                            let src_h = src_h - padding;
951                            let src_idx = src_idx + src_h * src_s2;
952                            let dst_idx = dst_idx + h_k_idx * w_k;
953                            for w_k_idx in 0..w_k {
954                                let src_w = w_idx * stride + w_k_idx * dilation;
955                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
956                                    continue;
957                                }
958                                let src_w = src_w - padding;
959                                let src_idx = src_idx + src_w * src_s3;
960                                let dst_idx = dst_idx + w_k_idx;
961                                dst[dst_idx] = src[src_idx]
962                            }
963                        }
964                    }
965                }
966            }
967        }
968        Ok(dst)
969    }
970}
971
972struct Col2Im1D {
973    stride: usize,
974}
975
976impl Map1 for Col2Im1D {
977    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
978        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
979        let stride = self.stride;
980        let l_out = (l_in - 1) * stride + k_size;
981        let mut im = vec![T::zero(); b_size * c_out * l_out];
982        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
983        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
984        for l_in_i in 0..l_in {
985            for k_i in 0..k_size {
986                let l_out_i = l_in_i * stride + k_i;
987                for b_i in 0..b_size {
988                    for c_i in 0..c_out {
989                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
990                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
991                        im[dst_idx] += col[src_idx]
992                    }
993                }
994            }
995        }
996        Ok(im)
997    }
998}
999
1000struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
1001
1002impl Map2 for ConvTranspose1D<'_> {
1003    const OP: &'static str = "conv_transpose1d";
1004    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1005        let p = self.0;
1006        let inp = &inp[inp_l.start_offset()..];
1007        let k = &k[k_l.start_offset()..];
1008        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
1009        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
1010        let l_out = p.l_out();
1011
1012        // Output shape: [b_size, c_out, l_out].
1013        let dst_elems = p.c_out * l_out * p.b_size;
1014        let dst = vec![T::zero(); dst_elems];
1015        let dst_s0 = p.c_out * l_out;
1016        let dst_s1 = l_out;
1017        let dst_s2 = 1;
1018
1019        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
1020        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
1021        let cont_s0 = p.l_in * p.c_in;
1022        let cont_s1 = p.c_in;
1023        for b_idx in 0..p.b_size {
1024            for l_idx in 0..p.l_in {
1025                for c_idx in 0..p.c_in {
1026                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
1027                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
1028                    inp_cont[dst_idx] = inp[src_idx]
1029                }
1030            }
1031        }
1032
1033        for k_idx in 0..p.k_size {
1034            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1035                let k_cont = (0..p.c_in)
1036                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
1037                    .collect::<Vec<_>>();
1038                for b_idx in 0..p.b_size {
1039                    for l_idx in 0..p.l_in {
1040                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
1041                        if out_idx < p.padding {
1042                            continue;
1043                        }
1044                        let out_idx = out_idx - p.padding;
1045                        if out_idx < l_out {
1046                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
1047                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
1048                            let mut d = T::zero();
1049                            unsafe {
1050                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
1051                            }
1052                            let dst_p = dst.as_ptr();
1053                            // Safety: dst_idx are uniques per dst_c_idx which is used to
1054                            // parallelise the different tasks so no two threads can try to
1055                            // write at the same location.
1056                            unsafe {
1057                                let ptr = dst_p.add(dst_idx) as *mut T;
1058                                *ptr += d
1059                            }
1060                        }
1061                    }
1062                }
1063            })
1064        }
1065        Ok(dst)
1066    }
1067}
1068
1069struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
1070
1071impl Map2 for Conv2D<'_> {
1072    const OP: &'static str = "conv2d";
1073    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1074        let p = self.0;
1075        let inp = &inp[inp_l.start_offset()..];
1076        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
1077        let k = &k[k_l.start_offset()..];
1078        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
1079        let (out_h, out_w) = (p.out_h(), p.out_w());
1080
1081        // Output shape: [b_size, c_out, out_h, out_w].
1082        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
1083
1084        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
1085        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
1086        let cont_s0 = p.i_h * p.i_w * p.c_in;
1087        let cont_s1 = p.i_w * p.c_in;
1088        let cont_s2 = p.c_in;
1089        for b_idx in 0..p.b_size {
1090            for h_idx in 0..p.i_h {
1091                for w_idx in 0..p.i_w {
1092                    for c_idx in 0..p.c_in {
1093                        let src_idx =
1094                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
1095                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
1096                        inp_cont[dst_idx] = inp[src_idx]
1097                    }
1098                }
1099            }
1100        }
1101
1102        for offset_h in 0..p.k_h {
1103            for offset_w in 0..p.k_w {
1104                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1105                    let dst_idx = dst_c_idx * out_w * out_h;
1106                    let k_cont = (0..p.c_in)
1107                        .map(|c_in_idx| {
1108                            k[dst_c_idx * k_s0
1109                                + c_in_idx * k_s1
1110                                + offset_h * k_s2
1111                                + offset_w * k_s3]
1112                        })
1113                        .collect::<Vec<_>>();
1114                    for b_idx in 0..p.b_size {
1115                        let dst_idx = dst_idx + b_idx * p.c_out * out_h * out_w;
1116                        for dst_h in 0..out_h {
1117                            let dst_idx = dst_idx + dst_h * out_w;
1118                            let src_h = p.stride * dst_h + offset_h * p.dilation;
1119                            if src_h < p.padding || src_h >= p.i_h + p.padding {
1120                                continue;
1121                            }
1122                            let src_h = src_h - p.padding;
1123                            for dst_w in 0..out_w {
1124                                let dst_idx = dst_idx + dst_w;
1125                                let src_w = p.stride * dst_w + offset_w * p.dilation;
1126                                if src_w < p.padding || src_w >= p.i_w + p.padding {
1127                                    continue;
1128                                }
1129                                let src_w = src_w - p.padding;
1130                                let inp_cont = &inp_cont
1131                                    [b_idx * cont_s0 + src_h * cont_s1 + src_w * cont_s2..];
1132                                assert!(inp_cont.len() >= p.c_in);
1133                                assert!(k_cont.len() >= p.c_in);
1134                                let mut d = T::zero();
1135                                unsafe {
1136                                    T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
1137                                }
1138                                let dst_p = dst.as_ptr();
1139                                // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
1140                                // the different tasks so no two threads can try to write at the same
1141                                // location.
1142                                unsafe {
1143                                    let ptr = dst_p.add(dst_idx) as *mut T;
1144                                    *ptr += d
1145                                }
1146                            }
1147                        }
1148                    }
1149                });
1150            }
1151        }
1152
1153        Ok(dst)
1154    }
1155}
1156
1157struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
1158
1159impl Map2 for ConvTranspose2D<'_> {
1160    const OP: &'static str = "conv_transpose2d";
1161    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1162        let p = self.0;
1163        let inp = &inp[inp_l.start_offset()..];
1164        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
1165        let k = &k[k_l.start_offset()..];
1166        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
1167        let (out_h, out_w) = (p.out_h(), p.out_w());
1168
1169        // Output shape: [b_size, c_out, out_h, out_w].
1170        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
1171        let dst_s0 = p.c_out * out_h * out_w;
1172        let dst_s1 = out_h * out_w;
1173        let dst_s2 = out_w;
1174        let dst_s3 = 1;
1175
1176        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
1177        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
1178        let cont_s0 = p.i_h * p.i_w * p.c_in;
1179        let cont_s1 = p.i_w * p.c_in;
1180        let cont_s2 = p.c_in;
1181        for b_idx in 0..p.b_size {
1182            for h_idx in 0..p.i_h {
1183                for w_idx in 0..p.i_w {
1184                    for c_idx in 0..p.c_in {
1185                        let src_idx =
1186                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
1187                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
1188                        inp_cont[dst_idx] = inp[src_idx]
1189                    }
1190                }
1191            }
1192        }
1193
1194        for k_y in 0..p.k_h {
1195            for k_x in 0..p.k_w {
1196                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1197                    let k_cont = (0..p.c_in)
1198                        .map(|c_in_idx| {
1199                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
1200                        })
1201                        .collect::<Vec<_>>();
1202                    for b_idx in 0..p.b_size {
1203                        for inp_y in 0..p.i_h {
1204                            for inp_x in 0..p.i_w {
1205                                let out_x = inp_x * p.stride + k_x * p.dilation;
1206                                let out_y = inp_y * p.stride + k_y * p.dilation;
1207                                if out_x < p.padding || out_y < p.padding {
1208                                    continue;
1209                                }
1210                                let out_x = out_x - p.padding;
1211                                let out_y = out_y - p.padding;
1212                                if out_x < out_w && out_y < out_h {
1213                                    let inp_cont = &inp_cont
1214                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
1215                                    let dst_idx = b_idx * dst_s0
1216                                        + out_y * dst_s2
1217                                        + out_x * dst_s3
1218                                        + dst_c_idx * dst_s1;
1219                                    let mut d = T::zero();
1220                                    unsafe {
1221                                        T::vec_dot(
1222                                            inp_cont.as_ptr(),
1223                                            k_cont.as_ptr(),
1224                                            &mut d,
1225                                            p.c_in,
1226                                        )
1227                                    }
1228                                    let dst_p = dst.as_ptr();
1229                                    // Safety: dst_idx are uniques per dst_c_idx which is used to
1230                                    // parallelise the different tasks so no two threads can try to
1231                                    // write at the same location.
1232                                    unsafe {
1233                                        let ptr = dst_p.add(dst_idx) as *mut T;
1234                                        *ptr += d
1235                                    }
1236                                }
1237                            }
1238                        }
1239                    }
1240                })
1241            }
1242        }
1243        Ok(dst)
1244    }
1245}
1246
1247struct MatMul((usize, usize, usize, usize));
1248
1249impl MatMul {
1250    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
1251        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
1252            lhs_l: lhs_l.clone(),
1253            rhs_l: rhs_l.clone(),
1254            bmnk: self.0,
1255            msg,
1256        }))
1257        .bt()
1258    }
1259
1260    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
1261        let lhs_stride = lhs_l.stride();
1262        let rhs_stride = rhs_l.stride();
1263        let rank = lhs_stride.len();
1264        let (_b, m, n, k) = self.0;
1265        let a_skip: usize = match lhs_stride[..rank - 2] {
1266            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
1267            [_, stride] if lhs_l.dims()[0] == 1 => stride,
1268            [stride, _] if lhs_l.dims()[1] == 1 => stride,
1269            [stride] => stride,
1270            [] => m * k,
1271            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
1272        };
1273        let b_skip: usize = match rhs_stride[..rank - 2] {
1274            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
1275            [_, stride] if rhs_l.dims()[0] == 1 => stride,
1276            [stride, _] if rhs_l.dims()[1] == 1 => stride,
1277            [stride] => stride,
1278            [] => n * k,
1279            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
1280        };
1281        Ok((a_skip, b_skip))
1282    }
1283}
1284
1285impl Map2 for MatMul {
1286    const OP: &'static str = "mat_mul";
1287
1288    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
1289    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1290        &self,
1291        lhs: &[T],
1292        lhs_l: &Layout,
1293        rhs: &[T],
1294        rhs_l: &Layout,
1295    ) -> Result<Vec<T>> {
1296        use gemm::{gemm, Parallelism};
1297
1298        match T::DTYPE {
1299            DType::F16 | DType::F32 | DType::F64 => {}
1300            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
1301        }
1302
1303        let (b, m, n, k) = self.0;
1304        let lhs = &lhs[lhs_l.start_offset()..];
1305        let rhs = &rhs[rhs_l.start_offset()..];
1306
1307        let lhs_stride = lhs_l.stride();
1308        let rhs_stride = rhs_l.stride();
1309        let rank = lhs_stride.len();
1310        let lhs_cs = lhs_stride[rank - 1];
1311        let lhs_rs = lhs_stride[rank - 2];
1312
1313        let rhs_cs = rhs_stride[rank - 1];
1314        let rhs_rs = rhs_stride[rank - 2];
1315
1316        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1317        let c_skip: usize = m * n;
1318
1319        let dst_shape: Shape = (m, n).into();
1320        let dst_strides = dst_shape.stride_contiguous();
1321        let dst_rs = dst_strides[0];
1322        let dst_cs = dst_strides[1];
1323
1324        let mut dst = vec![T::zero(); b * m * n];
1325        let num_threads = crate::utils::get_num_threads();
1326        let parallelism = if num_threads > 1 {
1327            Parallelism::Rayon(num_threads)
1328        } else {
1329            Parallelism::None
1330        };
1331        let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
1332            // a_skip and c_skip should be updated but step is always 0 so
1333            // it wouldn't matter.
1334            (1, b * m, n, k)
1335        } else if a_skip == 0 && b_skip == n * k {
1336            (1, m, b * n, k)
1337        } else {
1338            (b, m, n, k)
1339        };
1340        for step in 0..b {
1341            let lhs_p = &lhs[step * a_skip..];
1342            let rhs_p = &rhs[step * b_skip..];
1343            let dst_p = &mut dst[step * c_skip..];
1344            unsafe {
1345                gemm(
1346                    /* m: usize = */ m,
1347                    /* n: usize = */ n,
1348                    /* k: usize = */ k,
1349                    /* dst: *mut T = */ dst_p.as_mut_ptr(),
1350                    /* dst_cs: isize = */ dst_cs as isize,
1351                    /* dst_rs: isize = */ dst_rs as isize,
1352                    /* read_dst: bool = */ false,
1353                    /* lhs: *const T = */ lhs_p.as_ptr(),
1354                    /* lhs_cs: isize = */ lhs_cs as isize,
1355                    /* lhs_rs: isize = */ lhs_rs as isize,
1356                    /* rhs: *const T = */ rhs_p.as_ptr(),
1357                    /* rhs_cs: isize = */ rhs_cs as isize,
1358                    /* rhs_rs: isize = */ rhs_rs as isize,
1359                    /* alpha: T = */ T::zero(),
1360                    /* beta: T = */ T::one(),
1361                    /* conj_dst: bool = */ false,
1362                    /* conj_lhs: bool = */ false,
1363                    /* conj_rhs: bool = */ false,
1364                    parallelism,
1365                )
1366            }
1367        }
1368        Ok(dst)
1369    }
1370
1371    #[cfg(feature = "accelerate")]
1372    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1373        &self,
1374        lhs: &[T],
1375        lhs_l: &Layout,
1376        rhs: &[T],
1377        rhs_l: &Layout,
1378    ) -> Result<Vec<T>> {
1379        let (b, m, n, k) = self.0;
1380        let lhs = &lhs[lhs_l.start_offset()..];
1381        let rhs = &rhs[rhs_l.start_offset()..];
1382
1383        let lhs_stride = lhs_l.stride();
1384        let rhs_stride = rhs_l.stride();
1385
1386        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1387        let c_skip: usize = m * n;
1388
1389        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1390        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1391        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1392        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1393
1394        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1395            (n as i32, b'N')
1396        } else if rhs_m1 == k && rhs_m2 == 1 {
1397            (k as i32, b'T')
1398        } else {
1399            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1400        };
1401        // The b tensor has dims batching, m, k (lhs)
1402        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1403            (k as i32, b'N')
1404        } else if lhs_m1 == m && lhs_m2 == 1 {
1405            (m as i32, b'T')
1406        } else {
1407            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1408        };
1409
1410        let mut dst = vec![T::zero(); b * m * n];
1411        match T::DTYPE {
1412            DType::F16 => {
1413                crate::bail!("the accelerate backend does not support f16 matmul")
1414            }
1415            DType::F32 => {
1416                for step in 0..b {
1417                    let lhs_p = &lhs[step * a_skip..];
1418                    let rhs_p = &rhs[step * b_skip..];
1419                    let dst_p = &mut dst[step * c_skip..];
1420                    unsafe {
1421                        let a = rhs_p.as_ptr() as *const f32;
1422                        let b = lhs_p.as_ptr() as *const f32;
1423                        let c = dst_p.as_mut_ptr() as *mut f32;
1424                        let a = std::slice::from_raw_parts(a, a_skip);
1425                        let b = std::slice::from_raw_parts(b, b_skip);
1426                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1427                        crate::accelerate::sgemm(
1428                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1429                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1430                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1431                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1432                        )
1433                    }
1434                }
1435            }
1436            DType::F64 => {
1437                for step in 0..b {
1438                    let lhs_p = &lhs[step * a_skip..];
1439                    let rhs_p = &rhs[step * b_skip..];
1440                    let dst_p = &mut dst[step * c_skip..];
1441                    unsafe {
1442                        let a = rhs_p.as_ptr() as *const f64;
1443                        let b = lhs_p.as_ptr() as *const f64;
1444                        let c = dst_p.as_mut_ptr() as *mut f64;
1445                        let a = std::slice::from_raw_parts(a, a_skip);
1446                        let b = std::slice::from_raw_parts(b, b_skip);
1447                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1448                        crate::accelerate::dgemm(
1449                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1450                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1451                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1452                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1453                        )
1454                    }
1455                }
1456            }
1457            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1458        }
1459        Ok(dst)
1460    }
1461
1462    #[cfg(feature = "mkl")]
1463    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1464        &self,
1465        lhs: &[T],
1466        lhs_l: &Layout,
1467        rhs: &[T],
1468        rhs_l: &Layout,
1469    ) -> Result<Vec<T>> {
1470        let (b, m, n, k) = self.0;
1471        let lhs = &lhs[lhs_l.start_offset()..];
1472        let rhs = &rhs[rhs_l.start_offset()..];
1473
1474        let lhs_stride = lhs_l.stride();
1475        let rhs_stride = rhs_l.stride();
1476
1477        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1478        let c_skip: usize = m * n;
1479
1480        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1481        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1482        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1483        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1484
1485        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1486            (n as i32, b'N')
1487        } else if rhs_m1 == k && rhs_m2 == 1 {
1488            (k as i32, b'T')
1489        } else {
1490            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1491        };
1492        // The b tensor has dims batching, m, k (lhs)
1493        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1494            (k as i32, b'N')
1495        } else if lhs_m1 == m && lhs_m2 == 1 {
1496            (m as i32, b'T')
1497        } else {
1498            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1499        };
1500
1501        let mut dst = vec![T::zero(); b * m * n];
1502        match T::DTYPE {
1503            DType::F16 => {
1504                for step in 0..b {
1505                    let lhs_p = &lhs[step * a_skip..];
1506                    let rhs_p = &rhs[step * b_skip..];
1507                    let dst_p = &mut dst[step * c_skip..];
1508                    unsafe {
1509                        let a = rhs_p.as_ptr() as *const f16;
1510                        let b = lhs_p.as_ptr() as *const f16;
1511                        let c = dst_p.as_mut_ptr() as *mut f16;
1512                        let a = std::slice::from_raw_parts(a, a_skip);
1513                        let b = std::slice::from_raw_parts(b, b_skip);
1514                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1515                        crate::mkl::hgemm(
1516                            transa,
1517                            transb,
1518                            /* m= */ n as i32,
1519                            /* n= */ m as i32,
1520                            /* k= */ k as i32,
1521                            /* alpha= */ f16::ONE,
1522                            /* a= */ a,
1523                            /* lda= */ lda,
1524                            /* b= */ b,
1525                            /* ldb= */ ldb,
1526                            /* beta= */ f16::ZERO,
1527                            /* c= */ c,
1528                            /* ldc= */ n as i32,
1529                        )
1530                    }
1531                }
1532            }
1533            DType::F32 => {
1534                for step in 0..b {
1535                    let lhs_p = &lhs[step * a_skip..];
1536                    let rhs_p = &rhs[step * b_skip..];
1537                    let dst_p = &mut dst[step * c_skip..];
1538                    unsafe {
1539                        let a = rhs_p.as_ptr() as *const f32;
1540                        let b = lhs_p.as_ptr() as *const f32;
1541                        let c = dst_p.as_mut_ptr() as *mut f32;
1542                        let a = std::slice::from_raw_parts(a, a_skip);
1543                        let b = std::slice::from_raw_parts(b, b_skip);
1544                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1545                        crate::mkl::sgemm(
1546                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1547                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1548                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1549                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1550                        )
1551                    }
1552                }
1553            }
1554            DType::F64 => {
1555                for step in 0..b {
1556                    let lhs_p = &lhs[step * a_skip..];
1557                    let rhs_p = &rhs[step * b_skip..];
1558                    let dst_p = &mut dst[step * c_skip..];
1559                    unsafe {
1560                        let a = rhs_p.as_ptr() as *const f64;
1561                        let b = lhs_p.as_ptr() as *const f64;
1562                        let c = dst_p.as_mut_ptr() as *mut f64;
1563                        let a = std::slice::from_raw_parts(a, a_skip);
1564                        let b = std::slice::from_raw_parts(b, b_skip);
1565                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1566                        crate::mkl::dgemm(
1567                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1568                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1569                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1570                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1571                        )
1572                    }
1573                }
1574            }
1575            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1576        }
1577        Ok(dst)
1578    }
1579}
1580
1581fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
1582    if v.is_sign_positive() {
1583        v
1584    } else {
1585        (v.exp() - T::one()) * alpha
1586    }
1587}
1588
1589impl CpuStorage {
1590    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
1591        D::cpu_storage_as_slice(self)
1592    }
1593
1594    pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
1595        let storage0 = &storages[0];
1596        let s = match storage0 {
1597            Self::U8(_) => {
1598                let storages = storages
1599                    .iter()
1600                    .map(|s| match s {
1601                        Self::U8(s) => Ok(s.as_slice()),
1602                        _ => crate::bail!("dtype mismatch"),
1603                    })
1604                    .collect::<Result<Vec<_>>>()?
1605                    .concat();
1606                Self::U8(storages)
1607            }
1608            Self::U32(_) => {
1609                let storages = storages
1610                    .iter()
1611                    .map(|s| match s {
1612                        Self::U32(s) => Ok(s.as_slice()),
1613                        _ => crate::bail!("dtype mismatch"),
1614                    })
1615                    .collect::<Result<Vec<_>>>()?
1616                    .concat();
1617                Self::U32(storages)
1618            }
1619            Self::I64(_) => {
1620                let storages = storages
1621                    .iter()
1622                    .map(|s| match s {
1623                        Self::I64(s) => Ok(s.as_slice()),
1624                        _ => crate::bail!("dtype mismatch"),
1625                    })
1626                    .collect::<Result<Vec<_>>>()?
1627                    .concat();
1628                Self::I64(storages)
1629            }
1630            Self::BF16(_) => {
1631                let storages = storages
1632                    .iter()
1633                    .map(|s| match s {
1634                        Self::BF16(s) => Ok(s.as_slice()),
1635                        _ => crate::bail!("dtype mismatch"),
1636                    })
1637                    .collect::<Result<Vec<_>>>()?
1638                    .concat();
1639                Self::BF16(storages)
1640            }
1641            Self::F16(_) => {
1642                let storages = storages
1643                    .iter()
1644                    .map(|s| match s {
1645                        Self::F16(s) => Ok(s.as_slice()),
1646                        _ => crate::bail!("dtype mismatch"),
1647                    })
1648                    .collect::<Result<Vec<_>>>()?
1649                    .concat();
1650                Self::F16(storages)
1651            }
1652            Self::F32(_) => {
1653                let storages = storages
1654                    .iter()
1655                    .map(|s| match s {
1656                        Self::F32(s) => Ok(s.as_slice()),
1657                        _ => crate::bail!("dtype mismatch"),
1658                    })
1659                    .collect::<Result<Vec<_>>>()?
1660                    .concat();
1661                Self::F32(storages)
1662            }
1663            Self::F64(_) => {
1664                let storages = storages
1665                    .iter()
1666                    .map(|s| match s {
1667                        Self::F64(s) => Ok(s.as_slice()),
1668                        _ => crate::bail!("dtype mismatch"),
1669                    })
1670                    .collect::<Result<Vec<_>>>()?
1671                    .concat();
1672                Self::F64(storages)
1673            }
1674        };
1675        Ok(s)
1676    }
1677}
1678
1679impl BackendStorage for CpuStorage {
1680    type Device = CpuDevice;
1681
1682    fn dtype(&self) -> DType {
1683        match self {
1684            Self::U8(_) => DType::U8,
1685            Self::U32(_) => DType::U32,
1686            Self::I64(_) => DType::I64,
1687            Self::BF16(_) => DType::BF16,
1688            Self::F16(_) => DType::F16,
1689            Self::F32(_) => DType::F32,
1690            Self::F64(_) => DType::F64,
1691        }
1692    }
1693
1694    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
1695        // TODO: find a way around the quadratic number of cases below.
1696        match (self, dtype) {
1697            (Self::U8(storage), DType::BF16) => {
1698                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1699                Ok(Self::BF16(data))
1700            }
1701            (Self::U32(storage), DType::BF16) => {
1702                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1703                Ok(Self::BF16(data))
1704            }
1705            (Self::I64(storage), DType::BF16) => {
1706                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1707                Ok(Self::BF16(data))
1708            }
1709            (Self::BF16(storage), DType::BF16) => {
1710                let data = unary_map(storage, layout, |v| v);
1711                Ok(Self::BF16(data))
1712            }
1713            (Self::F16(storage), DType::BF16) => {
1714                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
1715                Ok(Self::BF16(data))
1716            }
1717            (Self::F32(storage), DType::BF16) => {
1718                let data = unary_map(storage, layout, bf16::from_f32);
1719                Ok(Self::BF16(data))
1720            }
1721            (Self::F64(storage), DType::BF16) => {
1722                let data = unary_map(storage, layout, bf16::from_f64);
1723                Ok(Self::BF16(data))
1724            }
1725            (Self::U8(storage), DType::F16) => {
1726                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1727                Ok(Self::F16(data))
1728            }
1729            (Self::U32(storage), DType::F16) => {
1730                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1731                Ok(Self::F16(data))
1732            }
1733            (Self::I64(storage), DType::F16) => {
1734                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1735                Ok(Self::F16(data))
1736            }
1737            (Self::BF16(storage), DType::F16) => {
1738                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
1739                Ok(Self::F16(data))
1740            }
1741            (Self::F16(storage), DType::F16) => {
1742                let data = unary_map(storage, layout, |v| v);
1743                Ok(Self::F16(data))
1744            }
1745            (Self::F32(storage), DType::F16) => {
1746                let data = unary_map(storage, layout, f16::from_f32);
1747                Ok(Self::F16(data))
1748            }
1749            (Self::F64(storage), DType::F16) => {
1750                let data = unary_map(storage, layout, f16::from_f64);
1751                Ok(Self::F16(data))
1752            }
1753            (Self::U8(storage), DType::F32) => {
1754                let data = unary_map(storage, layout, |v| v as f32);
1755                Ok(Self::F32(data))
1756            }
1757            (Self::U32(storage), DType::F32) => {
1758                let data = unary_map(storage, layout, |v| v as f32);
1759                Ok(Self::F32(data))
1760            }
1761            (Self::I64(storage), DType::F32) => {
1762                let data = unary_map(storage, layout, |v| v as f32);
1763                Ok(Self::F32(data))
1764            }
1765            (Self::BF16(storage), DType::F32) => {
1766                let data = unary_map(storage, layout, |v| v.to_f32());
1767                Ok(Self::F32(data))
1768            }
1769            (Self::F16(storage), DType::F32) => {
1770                let data = unary_map(storage, layout, |v| v.to_f32());
1771                Ok(Self::F32(data))
1772            }
1773            (Self::F32(storage), DType::F32) => {
1774                let data = unary_map(storage, layout, |v| v);
1775                Ok(Self::F32(data))
1776            }
1777            (Self::F64(storage), DType::F32) => {
1778                let data = unary_map(storage, layout, |v| v as f32);
1779                Ok(Self::F32(data))
1780            }
1781            (Self::U8(storage), DType::U8) => {
1782                let data = unary_map(storage, layout, |v| v);
1783                Ok(Self::U8(data))
1784            }
1785            (Self::BF16(storage), DType::U8) => {
1786                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1787                Ok(Self::U8(data))
1788            }
1789            (Self::F16(storage), DType::U8) => {
1790                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1791                Ok(Self::U8(data))
1792            }
1793            (Self::F32(storage), DType::U8) => {
1794                let data = unary_map(storage, layout, |v| v as u8);
1795                Ok(Self::U8(data))
1796            }
1797            (Self::F64(storage), DType::U8) => {
1798                let data = unary_map(storage, layout, |v| v as u8);
1799                Ok(Self::U8(data))
1800            }
1801            (Self::U32(storage), DType::U8) => {
1802                let data = unary_map(storage, layout, |v| v as u8);
1803                Ok(Self::U8(data))
1804            }
1805            (Self::I64(storage), DType::U8) => {
1806                let data = unary_map(storage, layout, |v| v as u8);
1807                Ok(Self::U8(data))
1808            }
1809            (Self::U8(storage), DType::U32) => {
1810                let data = unary_map(storage, layout, |v| v as u32);
1811                Ok(Self::U32(data))
1812            }
1813            (Self::U32(storage), DType::U32) => {
1814                let data = unary_map(storage, layout, |v| v);
1815                Ok(Self::U32(data))
1816            }
1817            (Self::I64(storage), DType::U32) => {
1818                let data = unary_map(storage, layout, |v| v as u32);
1819                Ok(Self::U32(data))
1820            }
1821            (Self::BF16(storage), DType::U32) => {
1822                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1823                Ok(Self::U32(data))
1824            }
1825            (Self::F16(storage), DType::U32) => {
1826                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1827                Ok(Self::U32(data))
1828            }
1829            (Self::F32(storage), DType::U32) => {
1830                let data = unary_map(storage, layout, |v| v as u32);
1831                Ok(Self::U32(data))
1832            }
1833            (Self::F64(storage), DType::U32) => {
1834                let data = unary_map(storage, layout, |v| v as u32);
1835                Ok(Self::U32(data))
1836            }
1837            (Self::U8(storage), DType::I64) => {
1838                let data = unary_map(storage, layout, |v| v as i64);
1839                Ok(Self::I64(data))
1840            }
1841            (Self::U32(storage), DType::I64) => {
1842                let data = unary_map(storage, layout, |v| v as i64);
1843                Ok(Self::I64(data))
1844            }
1845            (Self::I64(storage), DType::I64) => {
1846                let data = unary_map(storage, layout, |v| v);
1847                Ok(Self::I64(data))
1848            }
1849            (Self::BF16(storage), DType::I64) => {
1850                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
1851                Ok(Self::I64(data))
1852            }
1853            (Self::F16(storage), DType::I64) => {
1854                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
1855                Ok(Self::I64(data))
1856            }
1857            (Self::F32(storage), DType::I64) => {
1858                let data = unary_map(storage, layout, |v| v as i64);
1859                Ok(Self::I64(data))
1860            }
1861            (Self::F64(storage), DType::I64) => {
1862                let data = unary_map(storage, layout, |v| v as i64);
1863                Ok(Self::I64(data))
1864            }
1865            (Self::U8(storage), DType::F64) => {
1866                let data = unary_map(storage, layout, |v| v as f64);
1867                Ok(Self::F64(data))
1868            }
1869            (Self::U32(storage), DType::F64) => {
1870                let data = unary_map(storage, layout, |v| v as f64);
1871                Ok(Self::F64(data))
1872            }
1873            (Self::I64(storage), DType::F64) => {
1874                let data = unary_map(storage, layout, |v| v as f64);
1875                Ok(Self::F64(data))
1876            }
1877            (Self::BF16(storage), DType::F64) => {
1878                let data = unary_map(storage, layout, |v| v.to_f64());
1879                Ok(Self::F64(data))
1880            }
1881            (Self::F16(storage), DType::F64) => {
1882                let data = unary_map(storage, layout, |v| v.to_f64());
1883                Ok(Self::F64(data))
1884            }
1885            (Self::F32(storage), DType::F64) => {
1886                let data = unary_map(storage, layout, |v| v as f64);
1887                Ok(Self::F64(data))
1888            }
1889            (Self::F64(storage), DType::F64) => {
1890                let data = unary_map(storage, layout, |v| v);
1891                Ok(Self::F64(data))
1892            }
1893        }
1894    }
1895
1896    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
1897        match op {
1898            ReduceOp::Sum => {
1899                let src_dims = layout.dims();
1900                let mut dst_dims = src_dims.to_vec();
1901                for &dim in reduce_dims.iter() {
1902                    dst_dims[dim] = 1;
1903                }
1904                let dst_shape = Shape::from(dst_dims);
1905                let mut reduce_dims = reduce_dims.to_vec();
1906                // Sort the reduce_dims as they have to be processed from left to right when converting the
1907                // indexes.
1908                reduce_dims.sort();
1909                let reduce_dims_and_stride: Vec<_> = reduce_dims
1910                    .iter()
1911                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
1912                    .collect();
1913                ReduceSum {
1914                    dst_shape: &dst_shape,
1915                    reduce_dims: &reduce_dims,
1916                    reduce_dims_and_stride,
1917                }
1918                .map(self, layout)
1919            }
1920            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
1921                let reduce_dim_index = match reduce_dims {
1922                    [reduce_dim_index] => *reduce_dim_index,
1923                    _ => {
1924                        let op = match op {
1925                            ReduceOp::Min => "min",
1926                            ReduceOp::ArgMin => "argmin",
1927                            ReduceOp::Max => "max",
1928                            ReduceOp::ArgMax => "argmax",
1929                            _ => unreachable!(),
1930                        };
1931                        let dims = reduce_dims.to_vec();
1932                        Err(Error::OnlySingleDimension { op, dims })?
1933                    }
1934                };
1935                let (use_min, return_index) = match op {
1936                    ReduceOp::Min => (true, false),
1937                    ReduceOp::ArgMin => (true, true),
1938                    ReduceOp::Max => (false, false),
1939                    ReduceOp::ArgMax => (false, true),
1940                    _ => unreachable!(),
1941                };
1942                ReduceIndex {
1943                    reduce_dim_index,
1944                    use_min,
1945                    return_index,
1946                }
1947                .map(self, layout)
1948            }
1949        }
1950    }
1951
1952    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
1953        Cmp(op).map(self, lhs_l, rhs, rhs_l)
1954    }
1955
1956    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
1957        Affine(mul, add).map(self, layout)
1958    }
1959
1960    fn avg_pool2d(
1961        &self,
1962        layout: &Layout,
1963        kernel_size: (usize, usize),
1964        stride: (usize, usize),
1965    ) -> Result<Self> {
1966        AvgPool2D(kernel_size, stride).map(self, layout)
1967    }
1968
1969    fn max_pool2d(
1970        &self,
1971        layout: &Layout,
1972        kernel_size: (usize, usize),
1973        stride: (usize, usize),
1974    ) -> Result<Self> {
1975        MaxPool2D(kernel_size, stride).map(self, layout)
1976    }
1977
1978    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
1979        UpsampleNearest1D(sz).map(self, layout)
1980    }
1981
1982    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
1983        UpsampleNearest2D(h, w).map(self, layout)
1984    }
1985
1986    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
1987        use num_traits::Float;
1988        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
1989        match self {
1990            Self::BF16(storage) => {
1991                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
1992                Ok(Self::BF16(data))
1993            }
1994            Self::F16(storage) => {
1995                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
1996                Ok(Self::F16(data))
1997            }
1998            Self::F32(storage) => {
1999                let data = unary_map(storage, layout, |v| v.powf(e as f32));
2000                Ok(Self::F32(data))
2001            }
2002            Self::F64(storage) => {
2003                let data = unary_map(storage, layout, |v| v.powf(e));
2004                Ok(Self::F64(data))
2005            }
2006            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2007            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2008            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2009        }
2010    }
2011
2012    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
2013        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
2014        match self {
2015            Self::BF16(storage) => {
2016                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
2017                Ok(Self::BF16(data))
2018            }
2019            Self::F16(storage) => {
2020                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
2021                Ok(Self::F16(data))
2022            }
2023            Self::F32(storage) => {
2024                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
2025                Ok(Self::F32(data))
2026            }
2027            Self::F64(storage) => {
2028                let data = unary_map(storage, layout, |v| elu(v, alpha));
2029                Ok(Self::F64(data))
2030            }
2031            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2032            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2033            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2034        }
2035    }
2036
2037    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
2038        match self {
2039            Self::BF16(storage) => {
2040                if B::BF16_VEC {
2041                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
2042                    Ok(Self::BF16(data))
2043                } else {
2044                    let data = unary_map(storage, layout, B::bf16);
2045                    Ok(Self::BF16(data))
2046                }
2047            }
2048            Self::F16(storage) => {
2049                if B::F16_VEC {
2050                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
2051                    Ok(Self::F16(data))
2052                } else {
2053                    let data = unary_map(storage, layout, B::f16);
2054                    Ok(Self::F16(data))
2055                }
2056            }
2057            Self::F32(storage) => {
2058                if B::F32_VEC {
2059                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
2060                    Ok(Self::F32(data))
2061                } else {
2062                    let data = unary_map(storage, layout, B::f32);
2063                    Ok(Self::F32(data))
2064                }
2065            }
2066            Self::F64(storage) => {
2067                if B::F64_VEC {
2068                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
2069                    Ok(Self::F64(data))
2070                } else {
2071                    let data = unary_map(storage, layout, B::f64);
2072                    Ok(Self::F64(data))
2073                }
2074            }
2075            Self::U8(storage) => {
2076                let data = unary_map(storage, layout, B::u8);
2077                Ok(Self::U8(data))
2078            }
2079            Self::U32(storage) => {
2080                let data = unary_map(storage, layout, B::u32);
2081                Ok(Self::U32(data))
2082            }
2083            Self::I64(storage) => {
2084                let data = unary_map(storage, layout, B::i64);
2085                Ok(Self::I64(data))
2086            }
2087        }
2088    }
2089
2090    fn binary_impl<B: BinaryOpT>(
2091        &self,
2092        rhs: &Self,
2093        lhs_l: &Layout,
2094        rhs_l: &Layout,
2095    ) -> Result<Self> {
2096        match (self, rhs) {
2097            (Self::BF16(lhs), Self::BF16(rhs)) => {
2098                let data = if B::BF16_VEC {
2099                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
2100                } else {
2101                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
2102                };
2103                Ok(Self::BF16(data))
2104            }
2105            (Self::F16(lhs), Self::F16(rhs)) => {
2106                let data = if B::F16_VEC {
2107                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
2108                } else {
2109                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
2110                };
2111                Ok(Self::F16(data))
2112            }
2113            (Self::F32(lhs), Self::F32(rhs)) => {
2114                let data = if B::F32_VEC {
2115                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
2116                } else {
2117                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
2118                };
2119                Ok(Self::F32(data))
2120            }
2121            (Self::F64(lhs), Self::F64(rhs)) => {
2122                let data = if B::F64_VEC {
2123                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
2124                } else {
2125                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
2126                };
2127                Ok(Self::F64(data))
2128            }
2129            (Self::U32(lhs), Self::U32(rhs)) => {
2130                let data = if B::U32_VEC {
2131                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
2132                } else {
2133                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
2134                };
2135                Ok(Self::U32(data))
2136            }
2137            (Self::I64(lhs), Self::I64(rhs)) => {
2138                let data = if B::I64_VEC {
2139                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
2140                } else {
2141                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
2142                };
2143                Ok(Self::I64(data))
2144            }
2145            (Self::U8(lhs), Self::U8(rhs)) => {
2146                let data = if B::U8_VEC {
2147                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
2148                } else {
2149                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
2150                };
2151                Ok(Self::U8(data))
2152            }
2153            _ => {
2154                // This should be covered by the dtype check above.
2155                Err(Error::DTypeMismatchBinaryOp {
2156                    lhs: self.dtype(),
2157                    rhs: rhs.dtype(),
2158                    op: B::NAME,
2159                }
2160                .bt())
2161            }
2162        }
2163    }
2164
2165    fn copy2d(
2166        &self,
2167        dst: &mut Self,
2168        d1: usize,
2169        d2: usize,
2170        src_s: usize,
2171        dst_s: usize,
2172        src_o: usize,
2173        dst_o: usize,
2174    ) -> Result<()> {
2175        match (self, dst) {
2176            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
2177            (Self::U32(src), Self::U32(dst)) => {
2178                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2179            }
2180            (Self::I64(src), Self::I64(dst)) => {
2181                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2182            }
2183            (Self::BF16(src), Self::BF16(dst)) => {
2184                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2185            }
2186            (Self::F16(src), Self::F16(dst)) => {
2187                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2188            }
2189            (Self::F32(src), Self::F32(dst)) => {
2190                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2191            }
2192            (Self::F64(src), Self::F64(dst)) => {
2193                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2194            }
2195            (_, dst) => {
2196                return Err(Error::DTypeMismatchBinaryOp {
2197                    lhs: self.dtype(),
2198                    rhs: dst.dtype(),
2199                    op: "copy2d",
2200                }
2201                .bt());
2202            }
2203        }
2204        Ok(())
2205    }
2206
2207    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
2208        match (self, dst) {
2209            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2210            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2211            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2212            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2213            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2214            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2215            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2216            (_, dst) => {
2217                // This should be covered by the dtype check above.
2218                return Err(Error::DTypeMismatchBinaryOp {
2219                    lhs: self.dtype(),
2220                    rhs: dst.dtype(),
2221                    op: "copy_strided",
2222                }
2223                .bt());
2224            }
2225        }
2226        Ok(())
2227    }
2228
2229    fn where_cond(
2230        &self,
2231        layout: &Layout,
2232        t: &Self,
2233        t_l: &Layout,
2234        f: &Self,
2235        f_l: &Layout,
2236    ) -> Result<Self> {
2237        match self {
2238            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2239            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2240            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2241            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
2242        }
2243    }
2244
2245    fn conv1d(
2246        &self,
2247        l: &Layout,
2248        kernel: &Self,
2249        kernel_l: &Layout,
2250        params: &crate::conv::ParamsConv1D,
2251    ) -> Result<Self> {
2252        if !USE_IM2COL_CONV1D {
2253            return Conv1D(params).map(self, l, kernel, kernel_l);
2254        }
2255        let op = Im2Col1D {
2256            l_k: params.k_size,
2257            padding: params.padding,
2258            stride: params.stride,
2259            dilation: params.dilation,
2260        };
2261        let col = op.map(self, l)?;
2262        let b = params.b_size;
2263        let n = params.c_out;
2264        let l_out = params.l_out();
2265        let k = op.l_k * params.c_in;
2266        let m = l_out;
2267        let col_l = Layout::contiguous((b, m, k));
2268        let res = if kernel_l.is_contiguous() {
2269            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2270                .transpose(1, 2)?
2271                .broadcast_as((b, k, n))?;
2272            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2273        } else {
2274            // Make the kernel contiguous if not already the case.
2275            let mut kernel_c = unsafe {
2276                self.device()
2277                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
2278            };
2279            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2280            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2281                .transpose(1, 2)?
2282                .broadcast_as((b, k, n))?;
2283            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2284        };
2285        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
2286        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
2287        res.copy_strided_src(&mut res_t, 0, &res_l)?;
2288        Ok(res_t)
2289    }
2290
2291    fn conv_transpose1d(
2292        &self,
2293        l: &Layout,
2294        kernel: &Self,
2295        kernel_l: &Layout,
2296        params: &crate::conv::ParamsConvTranspose1D,
2297    ) -> Result<Self> {
2298        let can_use_col2im = kernel_l.is_contiguous()
2299            && params.dilation == 1
2300            && params.padding == 0
2301            && params.output_padding == 0;
2302        if USE_COL2IM_CONV1D_TR && can_use_col2im {
2303            let (b_size, c_in, l_in) = l.shape().dims3()?;
2304            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
2305            if !kernel_l.is_contiguous() {
2306                crate::bail!(
2307                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
2308                )
2309            }
2310            if c_in != c_in2 {
2311                crate::bail!(
2312                    "convtr1d: shape mismatch on c_in {:?} {:?}",
2313                    l.shape(),
2314                    kernel_l.shape()
2315                )
2316            }
2317            let col = {
2318                // This merges the last two dimensions of the kernel together.
2319                let kernel_l_mm = Layout::new(
2320                    (b_size, c_in, k_size * c_out).into(),
2321                    vec![0, k_size * c_out, 1],
2322                    kernel_l.start_offset(),
2323                );
2324                self.matmul(
2325                    kernel,
2326                    (
2327                        b_size,
2328                        /* m */ l_in,
2329                        /* n */ c_out * k_size,
2330                        /* k */ c_in,
2331                    ),
2332                    &l.transpose(1, 2)?,
2333                    &kernel_l_mm,
2334                )?
2335            };
2336            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
2337            Col2Im1D {
2338                stride: params.stride,
2339            }
2340            .map(&col, &col_l)
2341        } else {
2342            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
2343        }
2344    }
2345
2346    fn conv2d(
2347        &self,
2348        l: &Layout,
2349        kernel: &Self,
2350        kernel_l: &Layout,
2351        params: &crate::conv::ParamsConv2D,
2352    ) -> Result<Self> {
2353        if !USE_IM2COL_CONV2D {
2354            return Conv2D(params).map(self, l, kernel, kernel_l);
2355        }
2356        let op = Im2Col {
2357            h_k: params.k_h,
2358            w_k: params.k_w,
2359            padding: params.padding,
2360            stride: params.stride,
2361            dilation: params.dilation,
2362        };
2363        let col = op.map(self, l)?;
2364        let b = params.b_size;
2365        let n = params.c_out;
2366        let (h_out, w_out) = (params.out_h(), params.out_w());
2367        let k = op.h_k * op.w_k * params.c_in;
2368        let m = h_out * w_out;
2369        let col_l = Layout::contiguous((b, m, k));
2370        let res = if kernel_l.is_contiguous() {
2371            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2372                .transpose(1, 2)?
2373                .broadcast_as((b, k, n))?;
2374            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2375        } else {
2376            // Make the kernel contiguous if not already the case.
2377            let mut kernel_c = unsafe {
2378                self.device()
2379                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
2380            };
2381            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2382            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2383                .transpose(1, 2)?
2384                .broadcast_as((b, k, n))?;
2385            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2386        };
2387        let res_l = Layout::contiguous((b, h_out, w_out, params.c_out))
2388            .transpose(1, 2)?
2389            .transpose(1, 3)?;
2390        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
2391        res.copy_strided_src(&mut res_t, 0, &res_l)?;
2392        Ok(res_t)
2393    }
2394
2395    fn conv_transpose2d(
2396        &self,
2397        l: &Layout,
2398        kernel: &Self,
2399        kernel_l: &Layout,
2400        params: &crate::conv::ParamsConvTranspose2D,
2401    ) -> Result<Self> {
2402        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
2403    }
2404
2405    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
2406        match ids {
2407            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2408            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2409            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2410            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
2411        }
2412    }
2413
2414    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
2415        match ids {
2416            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
2417            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
2418            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
2419            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
2420        }
2421    }
2422
2423    fn scatter_set(
2424        &mut self,
2425        l: &Layout,
2426        ids: &Self,
2427        ids_l: &Layout,
2428        src: &Self,
2429        src_l: &Layout,
2430        dim: usize,
2431    ) -> Result<()> {
2432        match ids {
2433            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2434            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2435            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2436            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
2437        }
2438    }
2439
2440    fn scatter_add_set(
2441        &mut self,
2442        l: &Layout,
2443        ids: &Self,
2444        ids_l: &Layout,
2445        src: &Self,
2446        src_l: &Layout,
2447        dim: usize,
2448    ) -> Result<()> {
2449        match ids {
2450            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2451            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2452            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2453            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
2454        }
2455    }
2456
2457    fn index_add(
2458        &self,
2459        l: &Layout,
2460        ids: &Self,
2461        ids_l: &Layout,
2462        src: &Self,
2463        src_l: &Layout,
2464        dim: usize,
2465    ) -> Result<Self> {
2466        match ids {
2467            Self::U8(ids) => {
2468                let ids = match ids_l.contiguous_offsets() {
2469                    Some((a, b)) => &ids[a..b],
2470                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2471                };
2472                IndexAdd { ids, dim }.map(self, l, src, src_l)
2473            }
2474            Self::U32(ids) => {
2475                let ids = match ids_l.contiguous_offsets() {
2476                    Some((a, b)) => &ids[a..b],
2477                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2478                };
2479                IndexAdd { ids, dim }.map(self, l, src, src_l)
2480            }
2481            Self::I64(ids) => {
2482                let ids = match ids_l.contiguous_offsets() {
2483                    Some((a, b)) => &ids[a..b],
2484                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2485                };
2486                IndexAdd { ids, dim }.map(self, l, src, src_l)
2487            }
2488            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
2489        }
2490    }
2491
2492    fn matmul(
2493        &self,
2494        rhs: &Self,
2495        bmnk: (usize, usize, usize, usize),
2496        lhs_l: &Layout,
2497        rhs_l: &Layout,
2498    ) -> Result<Self> {
2499        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
2500    }
2501
2502    fn device(&self) -> &Self::Device {
2503        &CpuDevice
2504    }
2505
2506    fn try_clone(&self, _: &Layout) -> Result<Self> {
2507        Ok(self.clone())
2508    }
2509
2510    fn to_cpu_storage(&self) -> Result<CpuStorage> {
2511        Ok(self.clone())
2512    }
2513
2514    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
2515        use crate::scalar::Scalar;
2516        fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
2517            match l.strided_blocks() {
2518                crate::StridedBlocks::SingleBlock { start_offset, len } => {
2519                    src[start_offset..start_offset + len].fill(s)
2520                }
2521                crate::StridedBlocks::MultipleBlocks {
2522                    block_start_index,
2523                    block_len: 1,
2524                } => {
2525                    for src_index in block_start_index {
2526                        src[src_index] = s
2527                    }
2528                }
2529                crate::StridedBlocks::MultipleBlocks {
2530                    block_start_index,
2531                    block_len,
2532                } => {
2533                    for src_index in block_start_index {
2534                        src[src_index..src_index + block_len].fill(s)
2535                    }
2536                }
2537            }
2538        }
2539        match (self, s) {
2540            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
2541            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
2542            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
2543            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
2544            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
2545            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
2546            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
2547            (st, s) => crate::bail!(
2548                "const_set dtype mismatch, expected {:?} but got {:?}",
2549                st.dtype(),
2550                s
2551            ),
2552        }
2553        Ok(())
2554    }
2555}
2556
2557impl BackendDevice for CpuDevice {
2558    type Storage = CpuStorage;
2559
2560    fn location(&self) -> crate::DeviceLocation {
2561        crate::DeviceLocation::Cpu
2562    }
2563
2564    fn same_device(&self, _: &Self) -> bool {
2565        true
2566    }
2567
2568    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
2569        Ok(T::to_cpu_storage(s))
2570    }
2571
2572    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
2573        Ok(s.clone())
2574    }
2575
2576    fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
2577        Ok(s)
2578    }
2579
2580    fn new(_: usize) -> Result<Self> {
2581        Ok(Self)
2582    }
2583
2584    fn set_seed(&self, _seed: u64) -> Result<()> {
2585        crate::bail!("cannot seed the CPU rng with set_seed")
2586    }
2587
2588    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
2589        use rand::prelude::*;
2590
2591        let elem_count = shape.elem_count();
2592        let mut rng = rand::rng();
2593        match dtype {
2594            DType::U8 | DType::U32 | DType::I64 => {
2595                Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
2596            }
2597            DType::BF16 => {
2598                let mut data = Vec::with_capacity(elem_count);
2599                let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
2600                    .map_err(Error::wrap)?;
2601                for _i in 0..elem_count {
2602                    data.push(rng.sample::<bf16, _>(uniform))
2603                }
2604                Ok(CpuStorage::BF16(data))
2605            }
2606            DType::F16 => {
2607                let mut data = Vec::with_capacity(elem_count);
2608                let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
2609                    .map_err(Error::wrap)?;
2610                for _i in 0..elem_count {
2611                    data.push(rng.sample::<f16, _>(uniform))
2612                }
2613                Ok(CpuStorage::F16(data))
2614            }
2615            DType::F32 => {
2616                let mut data = Vec::with_capacity(elem_count);
2617                let uniform =
2618                    rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
2619                for _i in 0..elem_count {
2620                    data.push(rng.sample::<f32, _>(uniform))
2621                }
2622                Ok(CpuStorage::F32(data))
2623            }
2624            DType::F64 => {
2625                let mut data = Vec::with_capacity(elem_count);
2626                let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
2627                for _i in 0..elem_count {
2628                    data.push(rng.sample::<f64, _>(uniform))
2629                }
2630                Ok(CpuStorage::F64(data))
2631            }
2632        }
2633    }
2634
2635    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
2636        use rand::prelude::*;
2637
2638        let elem_count = shape.elem_count();
2639        let mut rng = rand::rng();
2640        match dtype {
2641            DType::U8 | DType::U32 | DType::I64 => {
2642                Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
2643            }
2644            DType::BF16 => {
2645                let mut data = Vec::with_capacity(elem_count);
2646                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
2647                    .map_err(Error::wrap)?;
2648                for _i in 0..elem_count {
2649                    data.push(normal.sample(&mut rng))
2650                }
2651                Ok(CpuStorage::BF16(data))
2652            }
2653            DType::F16 => {
2654                let mut data = Vec::with_capacity(elem_count);
2655                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
2656                    .map_err(Error::wrap)?;
2657                for _i in 0..elem_count {
2658                    data.push(normal.sample(&mut rng))
2659                }
2660                Ok(CpuStorage::F16(data))
2661            }
2662            DType::F32 => {
2663                let mut data = Vec::with_capacity(elem_count);
2664                let normal =
2665                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
2666                for _i in 0..elem_count {
2667                    data.push(normal.sample(&mut rng))
2668                }
2669                Ok(CpuStorage::F32(data))
2670            }
2671            DType::F64 => {
2672                let mut data = Vec::with_capacity(elem_count);
2673                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
2674                for _i in 0..elem_count {
2675                    data.push(normal.sample(&mut rng))
2676                }
2677                Ok(CpuStorage::F64(data))
2678            }
2679        }
2680    }
2681
2682    #[allow(clippy::uninit_vec)]
2683    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
2684        let elem_count = shape.elem_count();
2685        // The code below is highly unsafe but hopefully not directly unsound as we only consider
2686        // types that are Copy, not Drop, and for which all bit patterns are proper values.
2687        // It's still pretty risky, see the following for more details:
2688        // https://github.com/rust-lang/rust-clippy/issues/4483
2689        let storage = match dtype {
2690            DType::U8 => {
2691                let mut v = Vec::with_capacity(elem_count);
2692                v.set_len(elem_count);
2693                CpuStorage::U8(v)
2694            }
2695            DType::U32 => {
2696                let mut v = Vec::with_capacity(elem_count);
2697                v.set_len(elem_count);
2698                CpuStorage::U32(v)
2699            }
2700            DType::I64 => {
2701                let mut v = Vec::with_capacity(elem_count);
2702                v.set_len(elem_count);
2703                CpuStorage::I64(v)
2704            }
2705            DType::BF16 => {
2706                let mut v = Vec::with_capacity(elem_count);
2707                v.set_len(elem_count);
2708                CpuStorage::BF16(v)
2709            }
2710            DType::F16 => {
2711                let mut v = Vec::with_capacity(elem_count);
2712                v.set_len(elem_count);
2713                CpuStorage::F16(v)
2714            }
2715            DType::F32 => {
2716                let mut v = Vec::with_capacity(elem_count);
2717                v.set_len(elem_count);
2718                CpuStorage::F32(v)
2719            }
2720            DType::F64 => {
2721                let mut v = Vec::with_capacity(elem_count);
2722                v.set_len(elem_count);
2723                CpuStorage::F64(v)
2724            }
2725        };
2726        Ok(storage)
2727    }
2728
2729    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
2730        let elem_count = shape.elem_count();
2731        let storage = match dtype {
2732            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
2733            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
2734            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
2735            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
2736            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
2737            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
2738            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
2739        };
2740        Ok(storage)
2741    }
2742
2743    fn synchronize(&self) -> Result<()> {
2744        Ok(())
2745    }
2746}
2747
/// Applies `$fn` to the inner buffer of the `CpuStorage` held in `$storage` when its variant is
/// one of the listed `$dtypes`, and wraps the result back into that same variant.
///
/// For any other variant, the `?` makes the enclosing function return an
/// `UnsupportedDTypeForOp` error tagged with `$name`.
#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__buf) => CpuStorage::$dtypes($fn(__buf)),)*
            __other => Err(Error::UnsupportedDTypeForOp(__other.dtype(), $name).bt())?,
        }
    };
}