candle_core/cpu_backend/mod.rs

//! Implementation of the backend functions for the CPU device.
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};
use rayon::prelude::*;

mod utils;
pub use utils::{
    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
};
mod conv2d;
use conv2d::Conv2D;
const USE_IM2COL_CONV1D: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;

// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
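/// Type-erased CPU tensor storage, with one variant per supported dtype.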
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I16(Vec<i16>),
    I32(Vec<i32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
    F8E4M3(Vec<F8E4M3>),
    // Dummy types that store raw bytes
    F6E2M3(Vec<u8>),
    F6E3M2(Vec<u8>),
    F4(Vec<u8>),
    F8E8M0(Vec<u8>),
}

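/// Borrowed view over CPU tensor storage, mirroring the variants of [`CpuStorage`].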
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
    U8(&'a [u8]),
    U32(&'a [u32]),
    I16(&'a [i16]),
    I32(&'a [i32]),
    I64(&'a [i64]),
    BF16(&'a [bf16]),
    F16(&'a [f16]),
    F32(&'a [f32]),
    F64(&'a [f64]),
    F8E4M3(&'a [F8E4M3]),
    // Dummy types that store raw bytes
    F6E2M3(&'a [u8]),
    F6E3M2(&'a [u8]),
    F4(&'a [u8]),
    F8E8M0(&'a [u8]),
}

#[derive(Debug, Clone)]
pub struct CpuDevice;

struct Cmp(CmpOp);
impl Map2U8 for Cmp {
    const OP: &'static str = "cmp";
    #[inline(always)]
    fn f<T: WithDType>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<u8>> {
        let dst = match self.0 {
            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
        };
        Ok(dst)
    }
}

struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);

impl<I: IntDType> Map2 for WCond<'_, I> {
    const OP: &'static str = "where";
    #[inline(always)]
    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
        let vs = match (
            self.1.contiguous_offsets(),
            t_l.contiguous_offsets(),
            f_l.contiguous_offsets(),
        ) {
            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
                let pred = &self.0[o1..o2];
                let t = &t[o_t1..o_t2];
                let f = &f[o_f1..o_f2];
                pred.iter()
                    .zip(t.iter().zip(f.iter()))
                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
                    .collect::<Vec<_>>()
            }
            _ => self
                .1
                .strided_index()
                .zip(t_l.strided_index().zip(f_l.strided_index()))
                .map(|(i_p, (i_t, i_f))| {
                    if self.0[i_p].is_true() {
                        t[i_t]
                    } else {
                        f[i_f]
                    }
                })
                .collect::<Vec<_>>(),
        };
        Ok(vs)
    }
}

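/// Single-dimension reduction tracking both values and positions: depending on
/// `use_min` and `return_index` this implements min/max and argmin/argmax.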
struct ReduceIndex {
    reduce_dim_index: usize,
    use_min: bool,
    return_index: bool,
}

impl ReduceIndex {
    // The value gets replaced if f(s[current_acc], s[i]) returns true.
    #[inline(always)]
    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
    where
        T: Clone + Copy,
        U: Clone + Copy,
        F: Fn(T, T) -> bool,
        G: Fn(T, usize) -> U,
    {
        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
        let dst_to_set = dst.spare_capacity_mut();
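        // Safety: every element of `dst_to_set` is written exactly once in the loops
        // below before `dst.set_len(dst_len)` is called, so no uninitialized value is
        // ever read through this transmuted slice.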
        let dst_to_set =
            unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                if reduce_dim_stride == 1 {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let start_src_i = start_src_i * reduce_dim_size;
                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
                        let mut acc = 0;
                        let mut val = src[0];
                        for (src_i, &s) in src.iter().enumerate() {
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                } else {
                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
                        let (p, q) = (
                            start_src_i / reduce_dim_stride,
                            start_src_i % reduce_dim_stride,
                        );
                        // start_src_i = p * reduce_dim_stride + q
                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
                        let src = &src[start_src_i..];
                        let mut acc = 0;
                        let mut val = src[0];
                        for src_i in 0..reduce_dim_size {
                            let s = src[src_i * reduce_dim_stride];
                            if f(val, s) {
                                acc = src_i;
                                val = s
                            }
                        }
                        *dst_v = g(val, acc)
                    }
                }
            }
            None => {
                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
                for (unstr_index, src_index) in l.strided_index().enumerate() {
                    let src = &src[src_index..];
                    let mut acc = 0;
                    let mut val = src[0];
                    for src_i in 0..reduce_dim_size {
                        let s = src[src_i * reduce_dim_stride];
                        if f(val, s) {
                            acc = src_i;
                            val = s
                        }
                    }
                    dst_to_set[unstr_index] = g(val, acc)
                }
            }
        }
        unsafe { dst.set_len(dst_len) };
        Ok(dst)
    }
}

impl Map1Any for ReduceIndex {
    #[inline(always)]
    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
        &self,
        src: &[T],
        src_l: &Layout,
        wrap: W,
    ) -> Result<CpuStorage> {
        if src_l.shape().elem_count() == 0 {
            Err(Error::EmptyTensor { op: "reduce" }.bt())?
        }
        let dst = match (self.return_index, self.use_min) {
            (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
            (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
            (true, true) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
            }
            (true, false) => {
                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
            }
        };
        Ok(dst)
    }
}

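/// Sum-reduction over the dimensions listed in `reduce_dims`. For each source
/// element, the destination index is obtained by collapsing the coordinates along
/// the reduced dimensions to 0, so the output has shape `dst_shape`.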
struct ReduceSum<'a> {
    dst_shape: &'a Shape,
    reduce_dims: &'a [usize],
    reduce_dims_and_stride: Vec<(usize, usize)>,
}

impl ReduceSum<'_> {
    #[inline(always)]
    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
    where
        T: WithDType,
    {
        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                // Handle the case where we reduce over the last dimensions separately as it is
                // fairly common and easy to optimize. This relies on the layout being contiguous!
                // reduce_dims is sorted, so check whether it covers the dims from some a to n-1.
                let reduce_over_last_dims = self
                    .reduce_dims
                    .iter()
                    .rev()
                    .enumerate()
                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
                if reduce_over_last_dims {
                    let reduce_sz = self
                        .reduce_dims_and_stride
                        .iter()
                        .map(|(u, _)| u)
                        .product::<usize>();
                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
                        let src_i = dst_i * reduce_sz;
                        unsafe {
                            T::vec_reduce_sum(
                                src[src_i..src_i + reduce_sz].as_ptr(),
                                dst_v,
                                reduce_sz,
                            )
                        };
                    }
                    return Ok(dst);
                };
                for (unstr_index, &src) in src.iter().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following in a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src;
                }
            }
            None => {
                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following in a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src[src_index];
                }
            }
        }
        Ok(dst)
    }
}

impl Map1 for ReduceSum<'_> {
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        self.fold_impl(src, src_l, T::zero())
    }
}

struct Affine(f64, f64);

impl Map1 for Affine {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let mul = T::from_f64(self.0);
        let add = T::from_f64(self.1);
        Ok(unary_map(vs, layout, |v| v * mul + add))
    }
}

struct AvgPool2D((usize, usize), (usize, usize));

impl Map1 for AvgPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        let scale = 1f64 / (k_h * k_w) as f64;
        let scale = T::from_f64(scale);
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut sum = T::zero();
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                sum += src[src_index + m * stride_h + n * stride_w]
                            }
                        }
                        dst[h_idx * w_out + w_idx] = sum * scale;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct MaxPool2D((usize, usize), (usize, usize));

impl Map1 for MaxPool2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
        let (k_h, k_w) = self.0;
        let (s_h, s_w) = self.1;
        let (b_sz, c, h, w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let h_out = (h - k_h) / s_h + 1;
        let w_out = (w - k_w) / s_w + 1;
        let src_index = layout.start_offset();
        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * h_out * w_out..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * h_out * w_out..];
                let src_index = src_index + c_idx * stride[1];
                for h_idx in 0..h_out {
                    for w_idx in 0..w_out {
                        let mut largest =
                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
                        for m in 0..k_h {
                            for n in 0..k_w {
                                let m = s_h * h_idx + m;
                                let n = s_w * w_idx + n;
                                if largest < src[src_index + m * stride_h + n * stride_w] {
                                    largest = src[src_index + m * stride_h + n * stride_w]
                                }
                            }
                        }
                        dst[h_idx * w_out + w_idx] = largest;
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleNearest1D(usize);

impl Map1 for UpsampleNearest1D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // TODO: Specialized implementation for the case 2*sz?
        let dst_sz = self.0;
        let (b_sz, c, src_sz) = layout.shape().dims3()?;
        let stride = layout.stride();
        let stride_sz = stride[2];
        let src_index = layout.start_offset();
        let scale_sz = src_sz as f64 / dst_sz as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
        let src_idxs = (0..dst_sz)
            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_sz..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_sz..];
                let src_index = src_index + c_idx * stride[1];
                for (idx, src_idx) in src_idxs.iter().enumerate() {
                    dst[idx] = src[src_index + src_idx * stride_sz]
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleNearest2D(usize, usize);

impl Map1 for UpsampleNearest2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        // TODO: Specialized implementation for the case 2*h, 2*w?
        let (dst_h, dst_w) = (self.0, self.1);
        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
        let stride = layout.stride();
        let (stride_h, stride_w) = (stride[2], stride[3]);
        let src_index = layout.start_offset();
        let scale_h = src_h as f64 / dst_h as f64;
        let scale_w = src_w as f64 / dst_w as f64;
        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
        let src_h_idxs = (0..dst_h)
            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
            .collect::<Vec<_>>();
        let src_w_idxs = (0..dst_w)
            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
            .collect::<Vec<_>>();
        for b_idx in 0..b_sz {
            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
            let src_index = src_index + b_idx * stride[0];
            for c_idx in 0..c {
                let dst = &mut dst[c_idx * dst_h * dst_w..];
                let src_index = src_index + c_idx * stride[1];
                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
                        dst[h_idx * dst_w + w_idx] = src[src_index]
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct UpsampleBilinear2D {
    target_h: usize,
    target_w: usize,
    align_corners: bool,
    scale_h_factor: Option<f64>,
    scale_w_factor: Option<f64>,
}

impl Map1 for UpsampleBilinear2D {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let (batch, channels, height_in, width_in) = layout.shape().dims4()?;
        let height_out = self.target_h;
        let width_out = self.target_w;

        // Early return for the identity case.
        if height_in == height_out && width_in == width_out {
            return Ok(src.to_vec());
        }

        let stride = layout.stride();
        let src_offset = layout.start_offset();

        // Calculate scale factors following PyTorch's area_pixel_compute_scale logic.
        let scale_h = if self.align_corners {
            if height_out > 1 {
                (height_in - 1) as f64 / (height_out - 1) as f64
            } else {
                0.0
            }
        } else {
            // PyTorch's compute_scales_value logic:
            // If scale_factor was provided, use 1.0 / scale_factor.
            // Otherwise, use input_size / output_size.
            if let Some(scale_factor) = self.scale_h_factor {
                1.0 / scale_factor
            } else {
                height_in as f64 / height_out as f64
            }
        };

        let scale_w = if self.align_corners {
            if width_out > 1 {
                (width_in - 1) as f64 / (width_out - 1) as f64
            } else {
                0.0
            }
        } else if let Some(scale_factor) = self.scale_w_factor {
            1.0 / scale_factor
        } else {
            width_in as f64 / width_out as f64
        };

        // Precompute indices and weights for height.
        let mut h_indices = Vec::with_capacity(height_out);
        for h_out in 0..height_out {
            let src_h = if self.align_corners {
                scale_h * h_out as f64
            } else {
                scale_h * (h_out as f64 + 0.5) - 0.5
            };
            let src_h_clamped = src_h.max(0.0);
            let h0 = src_h_clamped.floor() as usize;
            let h1 = (h0 + 1).min(height_in - 1);
            let weight_h = (src_h_clamped - h0 as f64).clamp(0.0, 1.0);
            h_indices.push((h0, h1, weight_h));
        }

        // Precompute indices and weights for width.
        let mut w_indices = Vec::with_capacity(width_out);
        for w_out in 0..width_out {
            let src_w = if self.align_corners {
                scale_w * w_out as f64
            } else {
                scale_w * (w_out as f64 + 0.5) - 0.5
            };
            let src_w_clamped = src_w.max(0.0);
            let w0 = src_w_clamped.floor() as usize;
            let w1 = (w0 + 1).min(width_in - 1);
            let weight_w = (src_w_clamped - w0 as f64).clamp(0.0, 1.0);
            w_indices.push((w0, w1, weight_w));
        }

        // Allocate output.
        let mut dst = vec![T::zero(); batch * channels * height_out * width_out];

        // Perform bilinear interpolation.
        for b in 0..batch {
            for c in 0..channels {
                let base_idx = src_offset + b * stride[0] + c * stride[1];
                let dst_base = (b * channels + c) * height_out * width_out;

                for (h_out, &(h0, h1, weight_h)) in h_indices.iter().enumerate() {
                    for (w_out, &(w0, w1, weight_w)) in w_indices.iter().enumerate() {
                        // Get the four neighboring pixels.
                        let idx_00 = base_idx + h0 * stride[2] + w0 * stride[3];
                        let idx_10 = base_idx + h0 * stride[2] + w1 * stride[3];
                        let idx_01 = base_idx + h1 * stride[2] + w0 * stride[3];
                        let idx_11 = base_idx + h1 * stride[2] + w1 * stride[3];

                        let v00 = src[idx_00].to_f64();
                        let v10 = src[idx_10].to_f64();
                        let v01 = src[idx_01].to_f64();
                        let v11 = src[idx_11].to_f64();

                        // Bilinear interpolation.
                        let v_top = v00 * (1.0 - weight_w) + v10 * weight_w;
                        let v_bottom = v01 * (1.0 - weight_w) + v11 * weight_w;
                        let value = v_top * (1.0 - weight_h) + v_bottom * weight_h;

                        dst[dst_base + h_out * width_out + w_out] = T::from_f64(value);
                    }
                }
            }
        }

        Ok(dst)
    }
}

struct Gather<'a, I: IntDType> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
}

impl<I: IntDType> Map1 for Gather<'_, I> {
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let src = match src_l.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
        };
        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let src_dims = src_l.dims();
        let dst_len: usize = ids_dims.iter().product();
        let dst_left_len: usize = ids_dims[..dim].iter().product();
        let dst_dim_len = ids_dims[dim];
        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();

        let src_dim_len = src_dims[dim];
        let src_right_len: usize = src_dims[dim + 1..].iter().product();

        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..dst_left_len {
            let start_src_idx = left_i * src_right_len * src_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..dst_dim_len {
                let start_dst_idx = start_dst_idx + i * dst_right_len;
                for right_i in 0..dst_right_len {
                    let dst_idx = start_dst_idx + right_i;
                    let index = ids[dst_idx];
                    if index == I::max_value() {
                        dst[dst_idx] = T::zero();
                    } else {
                        let index = index.as_usize();
                        if index >= src_dim_len {
                            Err(Error::InvalidIndex {
                                index,
                                size: src_dim_len,
                                op: "gather",
                            }
                            .bt())?
                        }
                        let src_idx = start_src_idx + index * src_right_len + right_i;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}

struct IndexSelect<'a, T: IntDType> {
    ids: &'a [T],
    ids_l: &'a Layout,
    dim: usize,
}

impl<I: IntDType> Map1 for IndexSelect<'_, I> {
    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
        let src = match layout.contiguous_offsets() {
            Some((a, b)) => &src[a..b],
            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
        };
        let dim = self.dim;
        let n_ids = match self.ids_l.dims() {
            [n_ids] => *n_ids,
            d => Err(Error::UnexpectedNumberOfDims {
                expected: 1,
                got: d.len(),
                shape: self.ids_l.shape().clone(),
            }
            .bt())?,
        };
        let stride_ids = self.ids_l.stride()[0];
        let mut dst_dims = layout.dims().to_vec();
        let src_dim = dst_dims[dim];
        dst_dims[dim] = n_ids;
        let dst_len: usize = dst_dims.iter().product();
        let left_len: usize = dst_dims[..dim].iter().product();
        let right_len: usize = dst_dims[dim + 1..].iter().product();
        let mut dst = vec![T::zero(); dst_len];
        for left_i in 0..left_len {
            let start_src_idx = left_i * right_len * src_dim;
            let start_dst_idx = left_i * right_len * n_ids;
            for i in 0..n_ids {
                let start_dst_idx = start_dst_idx + i * right_len;
                let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
                if index == I::max_value() {
                    dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
                } else {
                    let index = index.as_usize();
                    if index >= src_dim {
                        Err(Error::InvalidIndex {
                            index,
                            size: src_dim,
                            op: "index-select",
                        }
                        .bt())?
                    }
                    let start_src_idx = start_src_idx + index * right_len;
                    dst[start_dst_idx..start_dst_idx + right_len]
                        .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
                }
            }
        }
        Ok(dst)
    }
}

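// `ElemUpdate` abstracts the per-element update used by `Scatter`: `Set` overwrites
// the destination value (scatter) while `Add` accumulates into it (scatter-add).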
trait ElemUpdate {
    fn f<T: WithDType>(dst: &mut T, src: T);
}

struct Set;
struct Add;

impl ElemUpdate for Set {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst = src
    }
}

impl ElemUpdate for Add {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst += src
    }
}

struct Scatter<'a, I: IntDType, M: ElemUpdate> {
    ids: &'a [I],
    ids_l: &'a Layout,
    dim: usize,
    _phantom: std::marker::PhantomData<M>,
}

impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
        Self {
            ids,
            ids_l,
            dim,
            _phantom: Default::default(),
        }
    }
}

impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
    const OP: &'static str = "scatter";
    fn f<T: WithDType>(
        &self,
        dst: &mut [T],
        dst_l: &Layout,
        src: &[T],
        src_l: &Layout,
    ) -> Result<()> {
        let dst = match dst_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
            Some((o1, o2)) => &mut dst[o1..o2],
        };

        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };

        let dim = self.dim;
        let ids_dims = self.ids_l.dims();
        let dst_dims = dst_l.dims();
        let dst_dim_len = dst_dims[dim];
        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();

        let ids_left_len: usize = ids_dims[..dim].iter().product();
        let ids_dim_len = ids_dims[dim];
        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();

        let ids = match self.ids_l.contiguous_offsets() {
            Some((a, b)) => &self.ids[a..b],
            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
        };
        for left_i in 0..ids_left_len {
            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
            for i in 0..ids_dim_len {
                let start_ids_idx = start_ids_idx + i * ids_right_len;
                for right_i in 0..dst_right_len {
                    let ids_idx = start_ids_idx + right_i;
                    let index = ids[ids_idx];
                    if index == I::max_value() {
                        continue;
                    }
                    let index = index.as_usize();
                    if index >= dst_dim_len {
                        Err(Error::InvalidIndex {
                            index,
                            size: dst_dim_len,
                            op: "scatter",
                        }
                        .bt())?
                    }
                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
                    M::f(&mut dst[dst_idx], src[ids_idx])
                }
            }
        }

        Ok(())
    }
}

struct IndexAdd<'a, I: IntDType> {
    ids: &'a [I],
    dim: usize,
}

impl<I: IntDType> Map2 for IndexAdd<'_, I> {
    const OP: &'static str = "index-add";
    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
    // v1, l1 -> self
    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let dst_len = l1.shape().elem_count();
        let mut dst = vec![T::zero(); dst_len];
        copy_strided_src_(v1, &mut dst, 0, l1);
        let src = match src_l.contiguous_offsets() {
            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
            Some((o1, o2)) => &src[o1..o2],
        };
        let dim = self.dim;
        let max_idx = l1.dims()[dim];
        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
        let src_dim_sz = src_l.dims()[dim];
        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
        if dim == 0 {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                if *dst_idx == I::max_value() {
                    continue;
                }
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                let src_idx = src_idx * post_dim;
                let dst_idx = dst_idx * post_dim;
                let src = &src[src_idx..src_idx + post_dim];
                let dst = &mut dst[dst_idx..dst_idx + post_dim];
                for (d, &s) in dst.iter_mut().zip(src.iter()) {
                    *d += s
                }
            }
        } else {
            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
                if *dst_idx == I::max_value() {
                    continue;
                }
                let dst_idx = dst_idx.as_usize();
                if dst_idx >= max_idx {
                    Err(Error::InvalidIndex {
                        index: dst_idx,
                        op: "index-add",
                        size: max_idx,
                    })?
                }
                for pre_i in 0..pre_dim {
                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
                    let src = &src[pre_src_i..pre_src_i + post_dim];
                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
                        *d += s
                    }
                }
            }
        }
        Ok(dst)
    }
}

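// Copy `d1` rows of `d2` contiguous elements each from `src` to `dst`, with
// independent row strides and start offsets on each side.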
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    for i1 in 0..d1 {
        let dst_idx = i1 * dst_stride1 + dst_offset;
        let src_idx = i1 * src_stride1 + src_offset;
        let dst = &mut dst[dst_idx..dst_idx + d2];
        let src = &src[src_idx..src_idx + d2];
        dst.copy_from_slice(src)
    }
}

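// Copy a possibly strided source into the contiguous `dst` buffer starting at
// `dst_offset`, using the block decomposition of the source layout.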
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
            let to_copy = (dst.len() - dst_offset).min(len);
            dst[dst_offset..dst_offset + to_copy]
                .copy_from_slice(&src[start_offset..start_offset + to_copy])
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len: 1,
        } => {
            for (dst_index, src_index) in block_start_index.enumerate() {
                let dst_index = dst_index + dst_offset;
                if dst_index >= dst.len() {
                    break;
                }
                dst[dst_index] = src[src_index]
            }
        }
        crate::StridedBlocks::MultipleBlocks {
            block_start_index,
            block_len,
        } => {
            let mut dst_index = dst_offset;
            for src_index in block_start_index {
                let next_dst_index = dst_index + block_len;
                if dst_index >= dst.len() {
                    break;
                }
                let to_copy = usize::min(block_len, dst.len() - dst_index);
                dst[dst_index..dst_index + to_copy]
                    .copy_from_slice(&src[src_index..src_index + to_copy]);
                dst_index = next_dst_index
            }
        }
    }
}

struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);

impl Map2 for Conv1D<'_> {
    const OP: &'static str = "conv1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
        let dst_elems = p.c_out * l_out * p.b_size;
        // The output shape is [b_size, c_out, l_out]
        let dst = vec![T::zero(); dst_elems];

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        for b_idx in 0..p.b_size {
            for src_l in 0..p.l_in {
                for src_c_idx in 0..p.c_in {
                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
                }
            }
        }

        for offset in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let dst_idx = dst_c_idx * l_out;
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
                    for dst_l in 0..l_out {
                        let dst_idx = dst_idx + dst_l;
                        let src_l = p.stride * dst_l + offset * p.dilation;
                        if src_l < p.padding || src_l >= p.padding + p.l_in {
                            continue;
                        }
                        let src_l = src_l - p.padding;
                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
                        assert!(inp_cont.len() >= p.c_in);
                        assert!(k_cont.len() >= p.c_in);
                        let mut d = T::zero();
                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
                        let dst_p = dst.as_ptr();
                        // Safety: the dst_idx values are unique per dst_c_idx, which is the
                        // axis used to parallelise the tasks, so no two threads can write to
                        // the same location.
                        unsafe {
                            let ptr = dst_p.add(dst_idx) as *mut T;
                            *ptr += d
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

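// im2col for 1d convolutions: unfold every sliding window of size `l_k` into a row
// of a [b, l_out, c * l_k] buffer so that the convolution reduces to a matmul.
// The output length follows the usual formula:
// l_out = (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1.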
struct Im2Col1D {
    l_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col1D {
    fn l_out(&self, l: usize) -> usize {
        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
    }
}

impl Map1 for Im2Col1D {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            l_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, l) = layout.shape().dims3()?;
        let l_out = self.l_out(l);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * l_out * c * l_k];
        let (src_s0, src_s1, src_s2) = {
            let s = layout.stride();
            (s[0], s[1], s[2])
        };
        // TODO: provide specialized kernels for the common use cases.
        // - l_k = 1
        // - padding = 0
        // - stride = 1
        // - dilation = 1
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * l_out * c * l_k;
            for l_idx in 0..l_out {
                let dst_idx = dst_idx + l_idx * c * l_k;
                for c_idx in 0..c {
                    let dst_idx = dst_idx + c_idx * l_k;
                    let src_idx = c_idx * src_s1 + src_idx;
                    for l_k_idx in 0..l_k {
                        let src_l = l_idx * stride + l_k_idx * dilation;
                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
                            continue;
                        }
                        let src_l = src_l - padding;
                        let src_idx = src_idx + src_l * src_s2;
                        let dst_idx = dst_idx + l_k_idx;
                        dst[dst_idx] = src[src_idx]
                    }
                }
            }
        }
        Ok(dst)
    }
}

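// im2col for 2d convolutions: same idea as `Im2Col1D`, unfolding each (h_k, w_k)
// patch into a row of a [b, h_out, w_out, c * h_k * w_k] buffer.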
struct Im2Col {
    h_k: usize,
    w_k: usize,
    stride: usize,
    dilation: usize,
    padding: usize,
}

impl Im2Col {
    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
        (h_out, w_out)
    }
}

impl Map1 for Im2Col {
    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
        let &Self {
            h_k,
            w_k,
            stride,
            dilation,
            padding,
        } = self;
        let (b, c, h, w) = layout.shape().dims4()?;
        let (h_out, w_out) = self.hw_out(h, w);
        let src = &vs[layout.start_offset()..];
        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
        let (src_s0, src_s1, src_s2, src_s3) = {
            let s = layout.stride();
            (s[0], s[1], s[2], s[3])
        };
        // TODO: provide specialized kernels for the common use cases.
        // - h_k = w_k = 1
        // - padding = 0
        // - stride = 1
        // - dilation = 1
        for b_idx in 0..b {
            let src_idx = b_idx * src_s0;
            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
            for h_idx in 0..h_out {
                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
                for w_idx in 0..w_out {
                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
                    for c_idx in 0..c {
                        let dst_idx = dst_idx + c_idx * h_k * w_k;
                        let src_idx = c_idx * src_s1 + src_idx;
                        for h_k_idx in 0..h_k {
                            let src_h = h_idx * stride + h_k_idx * dilation;
                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
                                continue;
                            }
                            let src_h = src_h - padding;
                            let src_idx = src_idx + src_h * src_s2;
                            let dst_idx = dst_idx + h_k_idx * w_k;
                            for w_k_idx in 0..w_k {
                                let src_w = w_idx * stride + w_k_idx * dilation;
                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
                                    continue;
                                }
                                let src_w = src_w - padding;
                                let src_idx = src_idx + src_w * src_s3;
                                let dst_idx = dst_idx + w_k_idx;
                                dst[dst_idx] = src[src_idx]
                            }
                        }
                    }
                }
            }
        }
        Ok(dst)
    }
}

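// col2im for 1d transposed convolutions: the adjoint of im2col, scattering the
// [b, l_in, c_out, k_size] columns back into a [b, c_out, l_out] image with
// overlap-add where consecutive windows intersect.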
struct Col2Im1D {
    stride: usize,
}

impl Map1 for Col2Im1D {
    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
        let stride = self.stride;
        let l_out = (l_in - 1) * stride + k_size;
        let mut im = vec![T::zero(); b_size * c_out * l_out];
        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
        for l_in_i in 0..l_in {
            for k_i in 0..k_size {
                let l_out_i = l_in_i * stride + k_i;
                for b_i in 0..b_size {
                    for c_i in 0..c_out {
                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
                        im[dst_idx] += col[src_idx]
                    }
                }
            }
        }
        Ok(im)
    }
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl Map2 for ConvTranspose1D<'_> {
    const OP: &'static str = "conv_transpose1d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();

        // Output shape: [b_size, c_out, l_out].
        let dst_elems = p.c_out * l_out * p.b_size;
        let dst = vec![T::zero(); dst_elems];
        let dst_s0 = p.c_out * l_out;
        let dst_s1 = l_out;
        let dst_s2 = 1;

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
        let cont_s0 = p.l_in * p.c_in;
        let cont_s1 = p.c_in;
        for b_idx in 0..p.b_size {
            for l_idx in 0..p.l_in {
                for c_idx in 0..p.c_in {
                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
                    inp_cont[dst_idx] = inp[src_idx]
                }
            }
        }

        for k_idx in 0..p.k_size {
            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                let k_cont = (0..p.c_in)
                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
                    .collect::<Vec<_>>();
                for b_idx in 0..p.b_size {
                    for l_idx in 0..p.l_in {
                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
                        if out_idx < p.padding {
                            continue;
                        }
                        let out_idx = out_idx - p.padding;
                        if out_idx < l_out {
                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
                            let mut d = T::zero();
                            unsafe {
                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
                            }
                            let dst_p = dst.as_ptr();
                            // Safety: the dst_idx values are unique per dst_c_idx, which is
                            // the axis used to parallelise the tasks, so no two threads can
                            // write to the same location.
                            unsafe {
                                let ptr = dst_p.add(dst_idx) as *mut T;
                                *ptr += d
                            }
                        }
                    }
                }
            })
        }
        Ok(dst)
    }
}

struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);

impl Map2 for ConvTranspose2D<'_> {
    const OP: &'static str = "conv_transpose2d";
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
        let k = &k[k_l.start_offset()..];
        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
        let (out_h, out_w) = (p.out_h(), p.out_w());

        // Output shape: [b_size, c_out, out_h, out_w].
        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
        let dst_s0 = p.c_out * out_h * out_w;
        let dst_s1 = out_h * out_w;
        let dst_s2 = out_w;
        let dst_s3 = 1;

        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
        let cont_s0 = p.i_h * p.i_w * p.c_in;
        let cont_s1 = p.i_w * p.c_in;
        let cont_s2 = p.c_in;
        for b_idx in 0..p.b_size {
            for h_idx in 0..p.i_h {
                for w_idx in 0..p.i_w {
                    for c_idx in 0..p.c_in {
                        let src_idx =
                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
                        inp_cont[dst_idx] = inp[src_idx]
                    }
                }
            }
        }

        for k_y in 0..p.k_h {
            for k_x in 0..p.k_w {
                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
                    let k_cont = (0..p.c_in)
                        .map(|c_in_idx| {
                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
                        })
                        .collect::<Vec<_>>();
                    for b_idx in 0..p.b_size {
                        for inp_y in 0..p.i_h {
                            for inp_x in 0..p.i_w {
                                let out_x = inp_x * p.stride + k_x * p.dilation;
                                let out_y = inp_y * p.stride + k_y * p.dilation;
                                if out_x < p.padding || out_y < p.padding {
                                    continue;
                                }
                                let out_x = out_x - p.padding;
                                let out_y = out_y - p.padding;
                                if out_x < out_w && out_y < out_h {
                                    let inp_cont = &inp_cont
                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
                                    let dst_idx = b_idx * dst_s0
                                        + out_y * dst_s2
                                        + out_x * dst_s3
                                        + dst_c_idx * dst_s1;
                                    let mut d = T::zero();
                                    unsafe {
                                        T::vec_dot(
                                            inp_cont.as_ptr(),
                                            k_cont.as_ptr(),
                                            &mut d,
                                            p.c_in,
                                        )
                                    }
                                    let dst_p = dst.as_ptr();
                                    // Safety: the dst_idx values are unique per dst_c_idx,
                                    // which is the axis used to parallelise the tasks, so no
                                    // two threads can write to the same location.
                                    unsafe {
                                        let ptr = dst_p.add(dst_idx) as *mut T;
                                        *ptr += d
                                    }
                                }
                            }
                        }
                    }
                })
            }
        }
        Ok(dst)
    }
}

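/// Batched matrix multiplication, parameterized by the `(batch, m, n, k)` sizes.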
struct MatMul((usize, usize, usize, usize));

impl MatMul {
    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
            lhs_l: lhs_l.clone(),
            rhs_l: rhs_l.clone(),
            bmnk: self.0,
            msg,
        }))
        .bt()
    }

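    // Infer the element stride between consecutive batches for lhs and rhs.
    // Contiguous batch dims and size-1 (broadcastable) batch dims are supported;
    // any other striding pattern raises a MatMulUnexpectedStriding error.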
1329    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
1330        let lhs_stride = lhs_l.stride();
1331        let rhs_stride = rhs_l.stride();
1332        let rank = lhs_stride.len();
1333        let (_b, m, n, k) = self.0;
1334        let a_skip: usize = match lhs_stride[..rank - 2] {
1335            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
1336            [_, stride] if lhs_l.dims()[0] == 1 => stride,
1337            [stride, _] if lhs_l.dims()[1] == 1 => stride,
1338            [stride] => stride,
1339            [] => m * k,
1340            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
1341        };
1342        let b_skip: usize = match rhs_stride[..rank - 2] {
1343            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
1344            [_, stride] if rhs_l.dims()[0] == 1 => stride,
1345            [stride, _] if rhs_l.dims()[1] == 1 => stride,
1346            [stride] => stride,
1347            [] => n * k,
1348            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
1349        };
1350        Ok((a_skip, b_skip))
1351    }
1352}
1353
1354impl Map2 for MatMul {
1355    const OP: &'static str = "mat_mul";
1356
1357    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
1358    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1359        &self,
1360        lhs: &[T],
1361        lhs_l: &Layout,
1362        rhs: &[T],
1363        rhs_l: &Layout,
1364    ) -> Result<Vec<T>> {
1365        use gemm::{gemm, Parallelism};
1366
1367        match T::DTYPE {
1368            DType::F16 | DType::F32 | DType::F64 => {}
1369            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
1370        }
1371
1372        let (b, m, n, k) = self.0;
1373        let lhs = &lhs[lhs_l.start_offset()..];
1374        let rhs = &rhs[rhs_l.start_offset()..];
1375
1376        let lhs_stride = lhs_l.stride();
1377        let rhs_stride = rhs_l.stride();
1378        let rank = lhs_stride.len();
1379        let lhs_cs = lhs_stride[rank - 1];
1380        let lhs_rs = lhs_stride[rank - 2];
1381
1382        let rhs_cs = rhs_stride[rank - 1];
1383        let rhs_rs = rhs_stride[rank - 2];
1384
1385        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1386        let c_skip: usize = m * n;
1387
1388        let dst_shape: Shape = (m, n).into();
1389        let dst_strides = dst_shape.stride_contiguous();
1390        let dst_rs = dst_strides[0];
1391        let dst_cs = dst_strides[1];
1392
1393        let mut dst = vec![T::zero(); b * m * n];
1394        let num_threads = crate::utils::get_num_threads();
1395        let parallelism = if num_threads > 1 {
1396            Parallelism::Rayon(num_threads)
1397        } else {
1398            Parallelism::None
1399        };
1400        let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
1401            // a_skip and c_skip are now stale, but the collapsed batch size is 1 so
1402            // the loop below only runs step 0 and their values are never used.
1403            (1, b * m, n, k)
1404        } else if a_skip == 0 && b_skip == n * k {
1405            (1, m, b * n, k)
1406        } else {
1407            (b, m, n, k)
1408        };
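        // e.g. with contiguous lhs batches (a_skip == m * k), b = 8, m = 16 and
        // a batch-broadcast rhs (b_skip == 0), the loop below performs a single
        // gemm on an effective (128, k) lhs instead of 8 separate (16, k) ones.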
1409        for step in 0..b {
1410            let lhs_p = &lhs[step * a_skip..];
1411            let rhs_p = &rhs[step * b_skip..];
1412            let dst_p = &mut dst[step * c_skip..];
1413            unsafe {
1414                gemm(
1415                    /* m: usize = */ m,
1416                    /* n: usize = */ n,
1417                    /* k: usize = */ k,
1418                    /* dst: *mut T = */ dst_p.as_mut_ptr(),
1419                    /* dst_cs: isize = */ dst_cs as isize,
1420                    /* dst_rs: isize = */ dst_rs as isize,
1421                    /* read_dst: bool = */ false,
1422                    /* lhs: *const T = */ lhs_p.as_ptr(),
1423                    /* lhs_cs: isize = */ lhs_cs as isize,
1424                    /* lhs_rs: isize = */ lhs_rs as isize,
1425                    /* rhs: *const T = */ rhs_p.as_ptr(),
1426                    /* rhs_cs: isize = */ rhs_cs as isize,
1427                    /* rhs_rs: isize = */ rhs_rs as isize,
1428                    /* alpha: T = */ T::zero(),
1429                    /* beta: T = */ T::one(),
1430                    /* conj_dst: bool = */ false,
1431                    /* conj_lhs: bool = */ false,
1432                    /* conj_rhs: bool = */ false,
1433                    parallelism,
1434                )
1435            }
1436        }
1437        Ok(dst)
1438    }
1439
1440    #[cfg(feature = "accelerate")]
1441    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1442        &self,
1443        lhs: &[T],
1444        lhs_l: &Layout,
1445        rhs: &[T],
1446        rhs_l: &Layout,
1447    ) -> Result<Vec<T>> {
1448        let (b, m, n, k) = self.0;
1449        let lhs = &lhs[lhs_l.start_offset()..];
1450        let rhs = &rhs[rhs_l.start_offset()..];
1451
1452        let lhs_stride = lhs_l.stride();
1453        let rhs_stride = rhs_l.stride();
1454
1455        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1456        let c_skip: usize = m * n;
1457
1458        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1459        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1460        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1461        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1462
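        // Accelerate's gemm routines are column-major, so the operands and the
        // m/n sizes are swapped below: we compute dst^T = rhs^T * lhs^T, which
        // is dst laid out in row-major order.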
1463        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1464            (n as i32, b'N')
1465        } else if rhs_m1 == k && rhs_m2 == 1 {
1466            (k as i32, b'T')
1467        } else {
1468            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1469        };
1470        // The BLAS `b` operand below is our lhs, with dims (batching, m, k).
1471        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1472            (k as i32, b'N')
1473        } else if lhs_m1 == m && lhs_m2 == 1 {
1474            (m as i32, b'T')
1475        } else {
1476            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1477        };
1478
1479        let mut dst = vec![T::zero(); b * m * n];
1480        match T::DTYPE {
1481            DType::F16 => {
1482                crate::bail!("the accelerate backend does not support f16 matmul")
1483            }
1484            DType::F32 => {
1485                for step in 0..b {
1486                    let lhs_p = &lhs[step * a_skip..];
1487                    let rhs_p = &rhs[step * b_skip..];
1488                    let dst_p = &mut dst[step * c_skip..];
1489                    unsafe {
1490                        let a = rhs_p.as_ptr() as *const f32;
1491                        let b = lhs_p.as_ptr() as *const f32;
1492                        let c = dst_p.as_mut_ptr() as *mut f32;
1493                        let a = std::slice::from_raw_parts(a, a_skip);
1494                        let b = std::slice::from_raw_parts(b, b_skip);
1495                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1496                        crate::accelerate::sgemm(
1497                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1498                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1499                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1500                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1501                        )
1502                    }
1503                }
1504            }
1505            DType::F64 => {
1506                for step in 0..b {
1507                    let lhs_p = &lhs[step * a_skip..];
1508                    let rhs_p = &rhs[step * b_skip..];
1509                    let dst_p = &mut dst[step * c_skip..];
1510                    unsafe {
1511                        let a = rhs_p.as_ptr() as *const f64;
1512                        let b = lhs_p.as_ptr() as *const f64;
1513                        let c = dst_p.as_mut_ptr() as *mut f64;
1514                        let a = std::slice::from_raw_parts(a, a_skip);
1515                        let b = std::slice::from_raw_parts(b, b_skip);
1516                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1517                        crate::accelerate::dgemm(
1518                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1519                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1520                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1521                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1522                        )
1523                    }
1524                }
1525            }
1526            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1527        }
1528        Ok(dst)
1529    }
1530
1531    #[cfg(feature = "mkl")]
1532    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1533        &self,
1534        lhs: &[T],
1535        lhs_l: &Layout,
1536        rhs: &[T],
1537        rhs_l: &Layout,
1538    ) -> Result<Vec<T>> {
1539        let (b, m, n, k) = self.0;
1540        let lhs = &lhs[lhs_l.start_offset()..];
1541        let rhs = &rhs[rhs_l.start_offset()..];
1542
1543        let lhs_stride = lhs_l.stride();
1544        let rhs_stride = rhs_l.stride();
1545
1546        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1547        let c_skip: usize = m * n;
1548
1549        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1550        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1551        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1552        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1553
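        // Same column-major trick as in the accelerate path: operands and the
        // m/n sizes are swapped so the row-major dst comes out directly.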
1554        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1555            (n as i32, b'N')
1556        } else if rhs_m1 == k && rhs_m2 == 1 {
1557            (k as i32, b'T')
1558        } else {
1559            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1560        };
1561        // The BLAS `b` operand below is our lhs, with dims (batching, m, k).
1562        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1563            (k as i32, b'N')
1564        } else if lhs_m1 == m && lhs_m2 == 1 {
1565            (m as i32, b'T')
1566        } else {
1567            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1568        };
1569
1570        let mut dst = vec![T::zero(); b * m * n];
1571        match T::DTYPE {
1572            DType::F16 => {
1573                for step in 0..b {
1574                    let lhs_p = &lhs[step * a_skip..];
1575                    let rhs_p = &rhs[step * b_skip..];
1576                    let dst_p = &mut dst[step * c_skip..];
1577                    unsafe {
1578                        let a = rhs_p.as_ptr() as *const f16;
1579                        let b = lhs_p.as_ptr() as *const f16;
1580                        let c = dst_p.as_mut_ptr() as *mut f16;
1581                        let a = std::slice::from_raw_parts(a, a_skip);
1582                        let b = std::slice::from_raw_parts(b, b_skip);
1583                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1584                        crate::mkl::hgemm(
1585                            transa,
1586                            transb,
1587                            /* m= */ n as i32,
1588                            /* n= */ m as i32,
1589                            /* k= */ k as i32,
1590                            /* alpha= */ f16::ONE,
1591                            /* a= */ a,
1592                            /* lda= */ lda,
1593                            /* b= */ b,
1594                            /* ldb= */ ldb,
1595                            /* beta= */ f16::ZERO,
1596                            /* c= */ c,
1597                            /* ldc= */ n as i32,
1598                        )
1599                    }
1600                }
1601            }
1602            DType::F32 => {
1603                for step in 0..b {
1604                    let lhs_p = &lhs[step * a_skip..];
1605                    let rhs_p = &rhs[step * b_skip..];
1606                    let dst_p = &mut dst[step * c_skip..];
1607                    unsafe {
1608                        let a = rhs_p.as_ptr() as *const f32;
1609                        let b = lhs_p.as_ptr() as *const f32;
1610                        let c = dst_p.as_mut_ptr() as *mut f32;
1611                        let a = std::slice::from_raw_parts(a, a_skip);
1612                        let b = std::slice::from_raw_parts(b, b_skip);
1613                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1614                        crate::mkl::sgemm(
1615                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1616                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1617                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1618                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1619                        )
1620                    }
1621                }
1622            }
1623            DType::F64 => {
1624                for step in 0..b {
1625                    let lhs_p = &lhs[step * a_skip..];
1626                    let rhs_p = &rhs[step * b_skip..];
1627                    let dst_p = &mut dst[step * c_skip..];
1628                    unsafe {
1629                        let a = rhs_p.as_ptr() as *const f64;
1630                        let b = lhs_p.as_ptr() as *const f64;
1631                        let c = dst_p.as_mut_ptr() as *mut f64;
1632                        let a = std::slice::from_raw_parts(a, a_skip);
1633                        let b = std::slice::from_raw_parts(b, b_skip);
1634                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1635                        crate::mkl::dgemm(
1636                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1637                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1638                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1639                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1640                        )
1641                    }
1642                }
1643            }
1644            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1645        }
1646        Ok(dst)
1647    }
1648}
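
// Exactly one of the three feature-gated `f` implementations above is
// compiled in: the portable `gemm` crate path, Accelerate, or MKL.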
1649
1650fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
1651    if v.is_sign_positive() {
1652        v
1653    } else {
1654        (v.exp() - T::one()) * alpha
1655    }
1656}
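
// Test-only sanity sketch for `elu`: positive inputs pass through unchanged,
// negative inputs follow alpha * (exp(v) - 1).
#[cfg(test)]
mod elu_tests {
    use super::*;

    #[test]
    fn elu_matches_definition() {
        assert_eq!(elu(2.0f64, 1.0), 2.0);
        let v = -1.0f64;
        assert!((elu(v, 0.5) - 0.5 * (v.exp() - 1.0)).abs() < 1e-12);
    }
}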
1657
1658impl CpuStorage {
1659    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
1660        D::cpu_storage_as_slice(self)
1661    }
1662
1663    pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
1664        let storage0 = &storages[0];
1665        let s = match storage0 {
1666            Self::U8(_) => {
1667                let storages = storages
1668                    .iter()
1669                    .map(|s| match s {
1670                        Self::U8(s) => Ok(s.as_slice()),
1671                        _ => crate::bail!("dtype mismatch"),
1672                    })
1673                    .collect::<Result<Vec<_>>>()?
1674                    .concat();
1675                Self::U8(storages)
1676            }
1677            Self::U32(_) => {
1678                let storages = storages
1679                    .iter()
1680                    .map(|s| match s {
1681                        Self::U32(s) => Ok(s.as_slice()),
1682                        _ => crate::bail!("dtype mismatch"),
1683                    })
1684                    .collect::<Result<Vec<_>>>()?
1685                    .concat();
1686                Self::U32(storages)
1687            }
1688            Self::I16(_) => {
1689                let storages = storages
1690                    .iter()
1691                    .map(|s| match s {
1692                        Self::I16(s) => Ok(s.as_slice()),
1693                        _ => crate::bail!("dtype mismatch"),
1694                    })
1695                    .collect::<Result<Vec<_>>>()?
1696                    .concat();
1697                Self::I16(storages)
1698            }
1699            Self::I32(_) => {
1700                let storages = storages
1701                    .iter()
1702                    .map(|s| match s {
1703                        Self::I32(s) => Ok(s.as_slice()),
1704                        _ => crate::bail!("dtype mismatch"),
1705                    })
1706                    .collect::<Result<Vec<_>>>()?
1707                    .concat();
1708                Self::I32(storages)
1709            }
1710            Self::I64(_) => {
1711                let storages = storages
1712                    .iter()
1713                    .map(|s| match s {
1714                        Self::I64(s) => Ok(s.as_slice()),
1715                        _ => crate::bail!("dtype mismatch"),
1716                    })
1717                    .collect::<Result<Vec<_>>>()?
1718                    .concat();
1719                Self::I64(storages)
1720            }
1721            Self::BF16(_) => {
1722                let storages = storages
1723                    .iter()
1724                    .map(|s| match s {
1725                        Self::BF16(s) => Ok(s.as_slice()),
1726                        _ => crate::bail!("dtype mismatch"),
1727                    })
1728                    .collect::<Result<Vec<_>>>()?
1729                    .concat();
1730                Self::BF16(storages)
1731            }
1732            Self::F16(_) => {
1733                let storages = storages
1734                    .iter()
1735                    .map(|s| match s {
1736                        Self::F16(s) => Ok(s.as_slice()),
1737                        _ => crate::bail!("dtype mismatch"),
1738                    })
1739                    .collect::<Result<Vec<_>>>()?
1740                    .concat();
1741                Self::F16(storages)
1742            }
1743            Self::F32(_) => {
1744                let storages = storages
1745                    .iter()
1746                    .map(|s| match s {
1747                        Self::F32(s) => Ok(s.as_slice()),
1748                        _ => crate::bail!("dtype mismatch"),
1749                    })
1750                    .collect::<Result<Vec<_>>>()?
1751                    .concat();
1752                Self::F32(storages)
1753            }
1754            Self::F64(_) => {
1755                let storages = storages
1756                    .iter()
1757                    .map(|s| match s {
1758                        Self::F64(s) => Ok(s.as_slice()),
1759                        _ => crate::bail!("dtype mismatch"),
1760                    })
1761                    .collect::<Result<Vec<_>>>()?
1762                    .concat();
1763                Self::F64(storages)
1764            }
1765            Self::F8E4M3(_) => {
1766                let storages = storages
1767                    .iter()
1768                    .map(|s| match s {
1769                        Self::F8E4M3(s) => Ok(s.as_slice()),
1770                        _ => crate::bail!("dtype mismatch"),
1771                    })
1772                    .collect::<Result<Vec<_>>>()?
1773                    .concat();
1774                Self::F8E4M3(storages)
1775            }
1776            Self::F6E2M3(_) => {
1777                let storages = storages
1778                    .iter()
1779                    .map(|s| match s {
1780                        Self::F6E2M3(s) => Ok(s.as_slice()),
1781                        _ => crate::bail!("dtype mismatch"),
1782                    })
1783                    .collect::<Result<Vec<_>>>()?
1784                    .concat();
1785                Self::F6E2M3(storages)
1786            }
1787            Self::F6E3M2(_) => {
1788                let storages = storages
1789                    .iter()
1790                    .map(|s| match s {
1791                        Self::F6E3M2(s) => Ok(s.as_slice()),
1792                        _ => crate::bail!("dtype mismatch"),
1793                    })
1794                    .collect::<Result<Vec<_>>>()?
1795                    .concat();
1796                Self::F6E3M2(storages)
1797            }
1798            Self::F4(_) => {
1799                let storages = storages
1800                    .iter()
1801                    .map(|s| match s {
1802                        Self::F4(s) => Ok(s.as_slice()),
1803                        _ => crate::bail!("dtype mismatch"),
1804                    })
1805                    .collect::<Result<Vec<_>>>()?
1806                    .concat();
1807                Self::F4(storages)
1808            }
1809            Self::F8E8M0(_) => {
1810                let storages = storages
1811                    .iter()
1812                    .map(|s| match s {
1813                        Self::F8E8M0(s) => Ok(s.as_slice()),
1814                        _ => crate::bail!("dtype mismatch"),
1815                    })
1816                    .collect::<Result<Vec<_>>>()?
1817                    .concat();
1818                Self::F8E8M0(storages)
1819            }
1820        };
1821        Ok(s)
1822    }
1823}
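
// Test-only sketch of `CpuStorage::concat`: same-dtype storages are flattened
// into a single buffer, while mixed dtypes are rejected.
#[cfg(test)]
mod concat_tests {
    use super::*;

    #[test]
    fn concat_same_dtype() {
        let a = CpuStorage::U8(vec![1, 2]);
        let b = CpuStorage::U8(vec![3]);
        match CpuStorage::concat(&[a, b]).unwrap() {
            CpuStorage::U8(v) => assert_eq!(v, vec![1, 2, 3]),
            _ => panic!("expected U8 storage"),
        }
        let mixed = [CpuStorage::U8(vec![1]), CpuStorage::F32(vec![1.0])];
        assert!(CpuStorage::concat(&mixed).is_err());
    }
}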
1824
1825impl BackendStorage for CpuStorage {
1826    type Device = CpuDevice;
1827
1828    fn dtype(&self) -> DType {
1829        match self {
1830            Self::U8(_) => DType::U8,
1831            Self::U32(_) => DType::U32,
1832            Self::I16(_) => DType::I16,
1833            Self::I32(_) => DType::I32,
1834            Self::I64(_) => DType::I64,
1835            Self::BF16(_) => DType::BF16,
1836            Self::F16(_) => DType::F16,
1837            Self::F32(_) => DType::F32,
1838            Self::F64(_) => DType::F64,
1839            Self::F8E4M3(_) => DType::F8E4M3,
1840            Self::F6E2M3(_) => DType::F6E2M3,
1841            Self::F6E3M2(_) => DType::F6E3M2,
1842            Self::F4(_) => DType::F4,
1843            Self::F8E8M0(_) => DType::F8E8M0,
1844        }
1845    }
1846
1847    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
1848        // TODO: find a way around the quadratic number of cases below.
1849        match (self, dtype) {
1850            (Self::U8(storage), DType::BF16) => {
1851                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1852                Ok(Self::BF16(data))
1853            }
1854            (Self::U32(storage), DType::BF16) => {
1855                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1856                Ok(Self::BF16(data))
1857            }
1858            (Self::I64(storage), DType::BF16) => {
1859                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1860                Ok(Self::BF16(data))
1861            }
1862            (Self::BF16(storage), DType::BF16) => {
1863                let data = unary_map(storage, layout, |v| v);
1864                Ok(Self::BF16(data))
1865            }
1866            (Self::F16(storage), DType::BF16) => {
1867                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
1868                Ok(Self::BF16(data))
1869            }
1870            (Self::F32(storage), DType::BF16) => {
1871                let data = unary_map(storage, layout, bf16::from_f32);
1872                Ok(Self::BF16(data))
1873            }
1874            (Self::F64(storage), DType::BF16) => {
1875                let data = unary_map(storage, layout, bf16::from_f64);
1876                Ok(Self::BF16(data))
1877            }
1878            (Self::U8(storage), DType::F16) => {
1879                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1880                Ok(Self::F16(data))
1881            }
1882            (Self::U32(storage), DType::F16) => {
1883                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1884                Ok(Self::F16(data))
1885            }
1886            (Self::I64(storage), DType::F16) => {
1887                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1888                Ok(Self::F16(data))
1889            }
1890            (Self::BF16(storage), DType::F16) => {
1891                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
1892                Ok(Self::F16(data))
1893            }
1894            (Self::F16(storage), DType::F16) => {
1895                let data = unary_map(storage, layout, |v| v);
1896                Ok(Self::F16(data))
1897            }
1898            (Self::F32(storage), DType::F16) => {
1899                let data = unary_map(storage, layout, f16::from_f32);
1900                Ok(Self::F16(data))
1901            }
1902            (Self::F64(storage), DType::F16) => {
1903                let data = unary_map(storage, layout, f16::from_f64);
1904                Ok(Self::F16(data))
1905            }
1906            (Self::U8(storage), DType::F32) => {
1907                let data = unary_map(storage, layout, |v| v as f32);
1908                Ok(Self::F32(data))
1909            }
1910            (Self::U32(storage), DType::F32) => {
1911                let data = unary_map(storage, layout, |v| v as f32);
1912                Ok(Self::F32(data))
1913            }
1914            (Self::I64(storage), DType::F32) => {
1915                let data = unary_map(storage, layout, |v| v as f32);
1916                Ok(Self::F32(data))
1917            }
1918            (Self::BF16(storage), DType::F32) => {
1919                let data = unary_map(storage, layout, |v| v.to_f32());
1920                Ok(Self::F32(data))
1921            }
1922            (Self::F16(storage), DType::F32) => {
1923                let data = unary_map(storage, layout, |v| v.to_f32());
1924                Ok(Self::F32(data))
1925            }
1926            (Self::F32(storage), DType::F32) => {
1927                let data = unary_map(storage, layout, |v| v);
1928                Ok(Self::F32(data))
1929            }
1930            (Self::F64(storage), DType::F32) => {
1931                let data = unary_map(storage, layout, |v| v as f32);
1932                Ok(Self::F32(data))
1933            }
1934            (Self::U8(storage), DType::U8) => {
1935                let data = unary_map(storage, layout, |v| v);
1936                Ok(Self::U8(data))
1937            }
1938            (Self::BF16(storage), DType::U8) => {
1939                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1940                Ok(Self::U8(data))
1941            }
1942            (Self::F16(storage), DType::U8) => {
1943                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1944                Ok(Self::U8(data))
1945            }
1946            (Self::F32(storage), DType::U8) => {
1947                let data = unary_map(storage, layout, |v| v as u8);
1948                Ok(Self::U8(data))
1949            }
1950            (Self::F64(storage), DType::U8) => {
1951                let data = unary_map(storage, layout, |v| v as u8);
1952                Ok(Self::U8(data))
1953            }
1954            (Self::U32(storage), DType::U8) => {
1955                let data = unary_map(storage, layout, |v| v as u8);
1956                Ok(Self::U8(data))
1957            }
1958            (Self::I64(storage), DType::U8) => {
1959                let data = unary_map(storage, layout, |v| v as u8);
1960                Ok(Self::U8(data))
1961            }
1962            (Self::U8(storage), DType::U32) => {
1963                let data = unary_map(storage, layout, |v| v as u32);
1964                Ok(Self::U32(data))
1965            }
1966            (Self::U32(storage), DType::U32) => {
1967                let data = unary_map(storage, layout, |v| v);
1968                Ok(Self::U32(data))
1969            }
1970            (Self::I64(storage), DType::U32) => {
1971                let data = unary_map(storage, layout, |v| v as u32);
1972                Ok(Self::U32(data))
1973            }
1974            (Self::BF16(storage), DType::U32) => {
1975                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1976                Ok(Self::U32(data))
1977            }
1978            (Self::F16(storage), DType::U32) => {
1979                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1980                Ok(Self::U32(data))
1981            }
1982            (Self::F32(storage), DType::U32) => {
1983                let data = unary_map(storage, layout, |v| v as u32);
1984                Ok(Self::U32(data))
1985            }
1986            (Self::F64(storage), DType::U32) => {
1987                let data = unary_map(storage, layout, |v| v as u32);
1988                Ok(Self::U32(data))
1989            }
1990            (Self::U8(storage), DType::I64) => {
1991                let data = unary_map(storage, layout, |v| v as i64);
1992                Ok(Self::I64(data))
1993            }
1994            (Self::U32(storage), DType::I64) => {
1995                let data = unary_map(storage, layout, |v| v as i64);
1996                Ok(Self::I64(data))
1997            }
1998            (Self::I64(storage), DType::I64) => {
1999                let data = unary_map(storage, layout, |v| v);
2000                Ok(Self::I64(data))
2001            }
2002            (Self::BF16(storage), DType::I64) => {
2003                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2004                Ok(Self::I64(data))
2005            }
2006            (Self::F16(storage), DType::I64) => {
2007                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2008                Ok(Self::I64(data))
2009            }
2010            (Self::F32(storage), DType::I64) => {
2011                let data = unary_map(storage, layout, |v| v as i64);
2012                Ok(Self::I64(data))
2013            }
2014            (Self::F64(storage), DType::I64) => {
2015                let data = unary_map(storage, layout, |v| v as i64);
2016                Ok(Self::I64(data))
2017            }
2018            (Self::U8(storage), DType::F64) => {
2019                let data = unary_map(storage, layout, |v| v as f64);
2020                Ok(Self::F64(data))
2021            }
2022            (Self::U32(storage), DType::F64) => {
2023                let data = unary_map(storage, layout, |v| v as f64);
2024                Ok(Self::F64(data))
2025            }
2026            (Self::I64(storage), DType::F64) => {
2027                let data = unary_map(storage, layout, |v| v as f64);
2028                Ok(Self::F64(data))
2029            }
2030            (Self::BF16(storage), DType::F64) => {
2031                let data = unary_map(storage, layout, |v| v.to_f64());
2032                Ok(Self::F64(data))
2033            }
2034            (Self::F16(storage), DType::F64) => {
2035                let data = unary_map(storage, layout, |v| v.to_f64());
2036                Ok(Self::F64(data))
2037            }
2038            (Self::F32(storage), DType::F64) => {
2039                let data = unary_map(storage, layout, |v| v as f64);
2040                Ok(Self::F64(data))
2041            }
2042            (Self::F64(storage), DType::F64) => {
2043                let data = unary_map(storage, layout, |v| v);
2044                Ok(Self::F64(data))
2045            }
2046            // Conversions to F8E4M3
2047            (Self::U8(storage), DType::F8E4M3) => {
2048                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2049                Ok(Self::F8E4M3(data))
2050            }
2051            (Self::U32(storage), DType::F8E4M3) => {
2052                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2053                Ok(Self::F8E4M3(data))
2054            }
2055            (Self::I64(storage), DType::F8E4M3) => {
2056                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2057                Ok(Self::F8E4M3(data))
2058            }
2059            (Self::BF16(storage), DType::F8E4M3) => {
2060                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
2061                Ok(Self::F8E4M3(data))
2062            }
2063            (Self::F16(storage), DType::F8E4M3) => {
2064                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
2065                Ok(Self::F8E4M3(data))
2066            }
2067            (Self::F32(storage), DType::F8E4M3) => {
2068                let data = unary_map(storage, layout, F8E4M3::from_f32);
2069                Ok(Self::F8E4M3(data))
2070            }
2071            (Self::F64(storage), DType::F8E4M3) => {
2072                let data = unary_map(storage, layout, F8E4M3::from_f64);
2073                Ok(Self::F8E4M3(data))
2074            }
2075            (Self::F8E4M3(storage), DType::F8E4M3) => {
2076                let data = unary_map(storage, layout, |v| v);
2077                Ok(Self::F8E4M3(data))
2078            }
2079            // Conversions from F8E4M3
2080            (Self::F8E4M3(storage), DType::U8) => {
2081                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
2082                Ok(Self::U8(data))
2083            }
2084            (Self::F8E4M3(storage), DType::U32) => {
2085                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
2086                Ok(Self::U32(data))
2087            }
2088            (Self::F8E4M3(storage), DType::I64) => {
2089                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2090                Ok(Self::I64(data))
2091            }
2092            (Self::F8E4M3(storage), DType::BF16) => {
2093                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
2094                Ok(Self::BF16(data))
2095            }
2096            (Self::F8E4M3(storage), DType::F16) => {
2097                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
2098                Ok(Self::F16(data))
2099            }
2100            (Self::F8E4M3(storage), DType::F32) => {
2101                let data = unary_map(storage, layout, |v| v.to_f32());
2102                Ok(Self::F32(data))
2103            }
2104            (Self::F8E4M3(storage), DType::F64) => {
2105                let data = unary_map(storage, layout, |v| v.to_f64());
2106                Ok(Self::F64(data))
2107            }
2108            // Conversions to I16
2109            (Self::U8(storage), DType::I16) => {
2110                let data = unary_map(storage, layout, |v| v as i16);
2111                Ok(Self::I16(data))
2112            }
2113            (Self::U32(storage), DType::I16) => {
2114                let data = unary_map(storage, layout, |v| v as i16);
2115                Ok(Self::I16(data))
2116            }
2117            (Self::I16(storage), DType::I16) => {
2118                let data = unary_map(storage, layout, |v| v);
2119                Ok(Self::I16(data))
2120            }
2121            (Self::I32(storage), DType::I16) => {
2122                let data = unary_map(storage, layout, |v| v as i16);
2123                Ok(Self::I16(data))
2124            }
2125            (Self::I64(storage), DType::I16) => {
2126                let data = unary_map(storage, layout, |v| v as i16);
2127                Ok(Self::I16(data))
2128            }
2129            (Self::BF16(storage), DType::I16) => {
2130                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2131                Ok(Self::I16(data))
2132            }
2133            (Self::F16(storage), DType::I16) => {
2134                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2135                Ok(Self::I16(data))
2136            }
2137            (Self::F32(storage), DType::I16) => {
2138                let data = unary_map(storage, layout, |v| v as i16);
2139                Ok(Self::I16(data))
2140            }
2141            (Self::F64(storage), DType::I16) => {
2142                let data = unary_map(storage, layout, |v| v as i16);
2143                Ok(Self::I16(data))
2144            }
2145            (Self::F8E4M3(storage), DType::I16) => {
2146                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2147                Ok(Self::I16(data))
2148            }
2149            // Conversions to I32
2150            (Self::U8(storage), DType::I32) => {
2151                let data = unary_map(storage, layout, |v| v as i32);
2152                Ok(Self::I32(data))
2153            }
2154            (Self::U32(storage), DType::I32) => {
2155                let data = unary_map(storage, layout, |v| v as i32);
2156                Ok(Self::I32(data))
2157            }
2158            (Self::I16(storage), DType::I32) => {
2159                let data = unary_map(storage, layout, |v| v as i32);
2160                Ok(Self::I32(data))
2161            }
2162            (Self::I32(storage), DType::I32) => {
2163                let data = unary_map(storage, layout, |v| v);
2164                Ok(Self::I32(data))
2165            }
2166            (Self::I64(storage), DType::I32) => {
2167                let data = unary_map(storage, layout, |v| v as i32);
2168                Ok(Self::I32(data))
2169            }
2170            (Self::BF16(storage), DType::I32) => {
2171                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2172                Ok(Self::I32(data))
2173            }
2174            (Self::F16(storage), DType::I32) => {
2175                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2176                Ok(Self::I32(data))
2177            }
2178            (Self::F32(storage), DType::I32) => {
2179                let data = unary_map(storage, layout, |v| v as i32);
2180                Ok(Self::I32(data))
2181            }
2182            (Self::F64(storage), DType::I32) => {
2183                let data = unary_map(storage, layout, |v| v as i32);
2184                Ok(Self::I32(data))
2185            }
2186            (Self::F8E4M3(storage), DType::I32) => {
2187                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2188                Ok(Self::I32(data))
2189            }
2190            // Conversions from I16
2191            (Self::I16(storage), DType::U8) => {
2192                let data = unary_map(storage, layout, |v| v as u8);
2193                Ok(Self::U8(data))
2194            }
2195            (Self::I16(storage), DType::U32) => {
2196                let data = unary_map(storage, layout, |v| v as u32);
2197                Ok(Self::U32(data))
2198            }
2199            (Self::I16(storage), DType::I64) => {
2200                let data = unary_map(storage, layout, |v| v as i64);
2201                Ok(Self::I64(data))
2202            }
2203            (Self::I16(storage), DType::BF16) => {
2204                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2205                Ok(Self::BF16(data))
2206            }
2207            (Self::I16(storage), DType::F16) => {
2208                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2209                Ok(Self::F16(data))
2210            }
2211            (Self::I16(storage), DType::F32) => {
2212                let data = unary_map(storage, layout, |v| v as f32);
2213                Ok(Self::F32(data))
2214            }
2215            (Self::I16(storage), DType::F64) => {
2216                let data = unary_map(storage, layout, |v| v as f64);
2217                Ok(Self::F64(data))
2218            }
2219            (Self::I16(storage), DType::F8E4M3) => {
2220                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2221                Ok(Self::F8E4M3(data))
2222            }
2223            // Conversions from I32
2224            (Self::I32(storage), DType::U8) => {
2225                let data = unary_map(storage, layout, |v| v as u8);
2226                Ok(Self::U8(data))
2227            }
2228            (Self::I32(storage), DType::U32) => {
2229                let data = unary_map(storage, layout, |v| v as u32);
2230                Ok(Self::U32(data))
2231            }
2232            (Self::I32(storage), DType::I64) => {
2233                let data = unary_map(storage, layout, |v| v as i64);
2234                Ok(Self::I64(data))
2235            }
2236            (Self::I32(storage), DType::BF16) => {
2237                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2238                Ok(Self::BF16(data))
2239            }
2240            (Self::I32(storage), DType::F16) => {
2241                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2242                Ok(Self::F16(data))
2243            }
2244            (Self::I32(storage), DType::F32) => {
2245                let data = unary_map(storage, layout, |v| v as f32);
2246                Ok(Self::F32(data))
2247            }
2248            (Self::I32(storage), DType::F64) => {
2249                let data = unary_map(storage, layout, |v| v as f64);
2250                Ok(Self::F64(data))
2251            }
2252            (Self::I32(storage), DType::F8E4M3) => {
2253                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2254                Ok(Self::F8E4M3(data))
2255            }
2256            // Dummy types: return an error for all conversions to/from them.
2257            (_, DType::F6E2M3) | (_, DType::F6E3M2) | (_, DType::F4) | (_, DType::F8E8M0) => {
2258                Err(Error::UnsupportedDTypeForOp(dtype, "to_dtype").bt())
2259            }
2260            (Self::F6E2M3(_), _)
2261            | (Self::F6E3M2(_), _)
2262            | (Self::F4(_), _)
2263            | (Self::F8E8M0(_), _) => {
2264                Err(Error::UnsupportedDTypeForOp(self.dtype(), "to_dtype").bt())
2265            }
2266        }
2267    }
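
    // Note on the float -> integer conversions above: Rust `as` casts saturate
    // at the target type's range and map NaN to 0, e.g. `f32::NAN as u8 == 0`.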
2268
2269    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
2270        match op {
2271            ReduceOp::Sum => {
2272                let src_dims = layout.dims();
2273                let mut dst_dims = src_dims.to_vec();
2274                for &dim in reduce_dims.iter() {
2275                    dst_dims[dim] = 1;
2276                }
2277                let dst_shape = Shape::from(dst_dims);
2278                let mut reduce_dims = reduce_dims.to_vec();
2279                // Sort the reduce_dims: they have to be processed from left to right when
2280                // converting the indices.
2281                reduce_dims.sort();
2282                let reduce_dims_and_stride: Vec<_> = reduce_dims
2283                    .iter()
2284                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
2285                    .collect();
2286                ReduceSum {
2287                    dst_shape: &dst_shape,
2288                    reduce_dims: &reduce_dims,
2289                    reduce_dims_and_stride,
2290                }
2291                .map(self, layout)
2292            }
2293            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
2294                let reduce_dim_index = match reduce_dims {
2295                    [reduce_dim_index] => *reduce_dim_index,
2296                    _ => {
2297                        let op = match op {
2298                            ReduceOp::Min => "min",
2299                            ReduceOp::ArgMin => "argmin",
2300                            ReduceOp::Max => "max",
2301                            ReduceOp::ArgMax => "argmax",
2302                            _ => unreachable!(),
2303                        };
2304                        let dims = reduce_dims.to_vec();
2305                        Err(Error::OnlySingleDimension { op, dims })?
2306                    }
2307                };
2308                let (use_min, return_index) = match op {
2309                    ReduceOp::Min => (true, false),
2310                    ReduceOp::ArgMin => (true, true),
2311                    ReduceOp::Max => (false, false),
2312                    ReduceOp::ArgMax => (false, true),
2313                    _ => unreachable!(),
2314                };
2315                ReduceIndex {
2316                    reduce_dim_index,
2317                    use_min,
2318                    return_index,
2319                }
2320                .map(self, layout)
2321            }
2322        }
2323    }
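
    // e.g. summing a [2, 3, 4] tensor over dims {0, 2} above yields a [1, 3, 1]
    // output; min/max/argmin/argmax accept exactly one reduce dim and error out
    // otherwise.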
2324
2325    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
2326        Cmp(op).map(self, lhs_l, rhs, rhs_l)
2327    }
2328
2329    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
2330        Affine(mul, add).map(self, layout)
2331    }
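
    // `affine` above computes `v * mul + add` elementwise through the `Affine`
    // map defined earlier in this module.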
2332
2333    fn avg_pool2d(
2334        &self,
2335        layout: &Layout,
2336        kernel_size: (usize, usize),
2337        stride: (usize, usize),
2338    ) -> Result<Self> {
2339        AvgPool2D(kernel_size, stride).map(self, layout)
2340    }
2341
2342    fn max_pool2d(
2343        &self,
2344        layout: &Layout,
2345        kernel_size: (usize, usize),
2346        stride: (usize, usize),
2347    ) -> Result<Self> {
2348        MaxPool2D(kernel_size, stride).map(self, layout)
2349    }
2350
2351    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
2352        UpsampleNearest1D(sz).map(self, layout)
2353    }
2354
2355    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
2356        UpsampleNearest2D(h, w).map(self, layout)
2357    }
2358
2359    fn upsample_bilinear2d(
2360        &self,
2361        layout: &Layout,
2362        h: usize,
2363        w: usize,
2364        align_corners: bool,
2365        scale_h: Option<f64>,
2366        scale_w: Option<f64>,
2367    ) -> Result<Self> {
2368        UpsampleBilinear2D {
2369            target_h: h,
2370            target_w: w,
2371            align_corners,
2372            scale_h_factor: scale_h,
2373            scale_w_factor: scale_w,
2374        }
2375        .map(self, layout)
2376    }
2377
2378    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
2379        use num_traits::Float;
2380        // TODO: Have some generic map for functions that apply to num_traits::Float elements.
2381        match self {
2382            Self::BF16(storage) => {
2383                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
2384                Ok(Self::BF16(data))
2385            }
2386            Self::F16(storage) => {
2387                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
2388                Ok(Self::F16(data))
2389            }
2390            Self::F32(storage) => {
2391                let data = unary_map(storage, layout, |v| v.powf(e as f32));
2392                Ok(Self::F32(data))
2393            }
2394            Self::F64(storage) => {
2395                let data = unary_map(storage, layout, |v| v.powf(e));
2396                Ok(Self::F64(data))
2397            }
2398            Self::F8E4M3(storage) => {
2399                let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
2400                Ok(Self::F8E4M3(data))
2401            }
2402            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
2403            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()),
2404            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "powf").bt()),
2405            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "powf").bt()),
2406            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()),
2407            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "powf").bt()),
2408            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "powf").bt()),
2409            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "powf").bt()),
2410            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "powf").bt()),
2411        }
2412    }
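
    // Note: `powf` above converts the f64 exponent into the storage's own float
    // type first, so for bf16/f16/f8e4m3 the exponent itself gets rounded.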
2413
2414    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
2415        // TODO: Have some generic map for functions that apply to num_traits::Float elements.
2416        match self {
2417            Self::BF16(storage) => {
2418                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
2419                Ok(Self::BF16(data))
2420            }
2421            Self::F16(storage) => {
2422                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
2423                Ok(Self::F16(data))
2424            }
2425            Self::F32(storage) => {
2426                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
2427                Ok(Self::F32(data))
2428            }
2429            Self::F64(storage) => {
2430                let data = unary_map(storage, layout, |v| elu(v, alpha));
2431                Ok(Self::F64(data))
2432            }
2433            Self::F8E4M3(storage) => {
2434                let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
2435                Ok(Self::F8E4M3(data))
2436            }
2437            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2438            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2439            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
2440            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
2441            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2442            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "elu").bt()),
2443            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "elu").bt()),
2444            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "elu").bt()),
2445            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "elu").bt()),
2446        }
2447    }
2448
2449    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
2450        match self {
2451            Self::BF16(storage) => {
2452                if B::BF16_VEC {
2453                    let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
2454                    Ok(Self::BF16(data))
2455                } else {
2456                    let data = unary_map(storage, layout, B::bf16);
2457                    Ok(Self::BF16(data))
2458                }
2459            }
2460            Self::F16(storage) => {
2461                if B::F16_VEC {
2462                    let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
2463                    Ok(Self::F16(data))
2464                } else {
2465                    let data = unary_map(storage, layout, B::f16);
2466                    Ok(Self::F16(data))
2467                }
2468            }
2469            Self::F32(storage) => {
2470                if B::F32_VEC {
2471                    let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
2472                    Ok(Self::F32(data))
2473                } else {
2474                    let data = unary_map(storage, layout, B::f32);
2475                    Ok(Self::F32(data))
2476                }
2477            }
2478            Self::F64(storage) => {
2479                if B::F64_VEC {
2480                    let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
2481                    Ok(Self::F64(data))
2482                } else {
2483                    let data = unary_map(storage, layout, B::f64);
2484                    Ok(Self::F64(data))
2485                }
2486            }
2487            Self::U8(storage) => {
2488                let data = unary_map(storage, layout, B::u8);
2489                Ok(Self::U8(data))
2490            }
2491            Self::U32(storage) => {
2492                let data = unary_map(storage, layout, B::u32);
2493                Ok(Self::U32(data))
2494            }
2495            Self::I16(storage) => {
2496                let data = unary_map(storage, layout, B::i16);
2497                Ok(Self::I16(data))
2498            }
2499            Self::I32(storage) => {
2500                let data = unary_map(storage, layout, B::i32);
2501                Ok(Self::I32(data))
2502            }
2503            Self::I64(storage) => {
2504                let data = unary_map(storage, layout, B::i64);
2505                Ok(Self::I64(data))
2506            }
2507            Self::F8E4M3(storage) => {
2508                let data = unary_map(storage, layout, B::f8e4m3);
2509                Ok(Self::F8E4M3(data))
2510            }
2511            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "unary").bt()),
2512            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "unary").bt()),
2513            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "unary").bt()),
2514            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "unary").bt()),
2515        }
2516    }
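
    // `unary_impl` above dispatches to the vectorized kernel when the op
    // provides one (the B::*_VEC flags) and falls back to the scalar map
    // otherwise; integer and f8e4m3 storages always take the scalar path.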
2517
2518    fn binary_impl<B: BinaryOpT>(
2519        &self,
2520        rhs: &Self,
2521        lhs_l: &Layout,
2522        rhs_l: &Layout,
2523    ) -> Result<Self> {
2524        match (self, rhs) {
2525            (Self::BF16(lhs), Self::BF16(rhs)) => {
2526                let data = if B::BF16_VEC {
2527                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::bf16, B::bf16_vec)
2528                } else {
2529                    binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16)
2530                };
2531                Ok(Self::BF16(data))
2532            }
2533            (Self::F16(lhs), Self::F16(rhs)) => {
2534                let data = if B::F16_VEC {
2535                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f16, B::f16_vec)
2536                } else {
2537                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f16)
2538                };
2539                Ok(Self::F16(data))
2540            }
2541            (Self::F32(lhs), Self::F32(rhs)) => {
2542                let data = if B::F32_VEC {
2543                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f32, B::f32_vec)
2544                } else {
2545                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f32)
2546                };
2547                Ok(Self::F32(data))
2548            }
2549            (Self::F64(lhs), Self::F64(rhs)) => {
2550                let data = if B::F64_VEC {
2551                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::f64, B::f64_vec)
2552                } else {
2553                    binary_map(lhs_l, rhs_l, lhs, rhs, B::f64)
2554                };
2555                Ok(Self::F64(data))
2556            }
2557            (Self::U32(lhs), Self::U32(rhs)) => {
2558                let data = if B::U32_VEC {
2559                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u32, B::u32_vec)
2560                } else {
2561                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u32)
2562                };
2563                Ok(Self::U32(data))
2564            }
2565            (Self::I16(lhs), Self::I16(rhs)) => {
2566                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i16);
2567                Ok(Self::I16(data))
2568            }
2569            (Self::I32(lhs), Self::I32(rhs)) => {
2570                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::i32);
2571                Ok(Self::I32(data))
2572            }
2573            (Self::I64(lhs), Self::I64(rhs)) => {
2574                let data = if B::I64_VEC {
2575                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
2576                } else {
2577                    binary_map(lhs_l, rhs_l, lhs, rhs, B::i64)
2578                };
2579                Ok(Self::I64(data))
2580            }
2581            (Self::U8(lhs), Self::U8(rhs)) => {
2582                let data = if B::U8_VEC {
2583                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec)
2584                } else {
2585                    binary_map(lhs_l, rhs_l, lhs, rhs, B::u8)
2586                };
2587                Ok(Self::U8(data))
2588            }
2589            (Self::F8E4M3(lhs), Self::F8E4M3(rhs)) => {
2590                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f8e4m3);
2591                Ok(Self::F8E4M3(data))
2592            }
2593            _ => {
2594                // This should be covered by the dtype check performed by the caller.
2595                Err(Error::DTypeMismatchBinaryOp {
2596                    lhs: self.dtype(),
2597                    rhs: rhs.dtype(),
2598                    op: B::NAME,
2599                }
2600                .bt())
2601            }
2602        }
2603    }
2604
    fn copy2d(
        &self,
        dst: &mut Self,
        d1: usize,
        d2: usize,
        src_s: usize,
        dst_s: usize,
        src_o: usize,
        dst_o: usize,
    ) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
            (Self::U32(src), Self::U32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I16(src), Self::I16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I32(src), Self::I32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I64(src), Self::I64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::BF16(src), Self::BF16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F16(src), Self::F16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F32(src), Self::F32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F64(src), Self::F64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F4(src), Self::F4(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
            (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (_, dst) => {
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy2d",
                }
                .bt());
            }
        }
        Ok(())
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (Self::F4(src), Self::F4(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
            (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
                copy_strided_src_(src, dst, dst_offset, src_l)
            }
            (_, dst) => {
                // This should be covered by the dtype check above.
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy_strided",
                }
                .bt());
            }
        }
        Ok(())
    }

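    /// Element-wise select driven by an integer predicate: positions where the
    /// predicate `is_true()` take their value from `t`, the rest from `f`. Only
    /// integer predicate dtypes are accepted; the gathering itself is delegated
    /// to the `WCond` map.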
    fn where_cond(
        &self,
        layout: &Layout,
        t: &Self,
        t_l: &Layout,
        f: &Self,
        f_l: &Layout,
    ) -> Result<Self> {
        match self {
            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond").bt()),
        }
    }

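    /// 1D convolution. When `USE_IM2COL_CONV1D` is set, the input is unfolded
    /// with `Im2Col1D` and the convolution is reduced to a batched matmul with
    /// shapes `(b, m, k) x (b, k, n) -> (b, m, n)` where `m = l_out`,
    /// `n = c_out` and `k = k_size * c_in`; the result is then transposed back
    /// to `(b, c_out, l_out)`.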
    fn conv1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv1D,
    ) -> Result<Self> {
        if !USE_IM2COL_CONV1D {
            return Conv1D(params).map(self, l, kernel, kernel_l);
        }
        let op = Im2Col1D {
            l_k: params.k_size,
            padding: params.padding,
            stride: params.stride,
            dilation: params.dilation,
        };
        let col = op.map(self, l)?;
        let b = params.b_size;
        let n = params.c_out;
        let l_out = params.l_out();
        let k = op.l_k * params.c_in;
        let m = l_out;
        let col_l = Layout::contiguous((b, m, k));
        let res = if kernel_l.is_contiguous() {
            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
        } else {
            // Make the kernel contiguous if not already the case.
            let mut kernel_c = unsafe {
                self.device()
                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
            };
            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
            // The copy starts at offset 0 in `kernel_c`, so the matmul has to
            // run against the contiguous copy with a zero start offset rather
            // than the original kernel and its offset.
            let kernel_l = Layout::contiguous_with_offset((1, n, k), 0)
                .transpose(1, 2)?
                .broadcast_as((b, k, n))?;
            col.matmul(&kernel_c, (b, m, n, k), &col_l, &kernel_l)?
        };
        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
        res.copy_strided_src(&mut res_t, 0, &res_l)?;
        Ok(res_t)
    }

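    /// Transposed 1D convolution. A fast path is taken when the kernel is
    /// contiguous and there is no dilation, padding or output padding: the
    /// input is multiplied with the kernel reshaped to `(c_in, c_out * k_size)`
    /// and the result is folded back with `Col2Im1D`. Otherwise the direct
    /// `ConvTranspose1D` implementation is used.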
    fn conv_transpose1d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        let can_use_col2im = kernel_l.is_contiguous()
            && params.dilation == 1
            && params.padding == 0
            && params.output_padding == 0;
        if USE_COL2IM_CONV1D_TR && can_use_col2im {
            let (b_size, c_in, l_in) = l.shape().dims3()?;
            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
            // Redundant with the `can_use_col2im` check above, kept as a
            // defensive assertion.
            if !kernel_l.is_contiguous() {
                crate::bail!(
                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
                )
            }
            if c_in != c_in2 {
                crate::bail!(
                    "convtr1d: shape mismatch on c_in {:?} {:?}",
                    l.shape(),
                    kernel_l.shape()
                )
            }
            let col = {
                // This merges the last two dimensions of the kernel together.
                let kernel_l_mm = Layout::new(
                    (b_size, c_in, k_size * c_out).into(),
                    vec![0, k_size * c_out, 1],
                    kernel_l.start_offset(),
                );
                self.matmul(
                    kernel,
                    (
                        b_size,
                        /* m */ l_in,
                        /* n */ c_out * k_size,
                        /* k */ c_in,
                    ),
                    &l.transpose(1, 2)?,
                    &kernel_l_mm,
                )?
            };
            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
            Col2Im1D {
                stride: params.stride,
            }
            .map(&col, &col_l)
        } else {
            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
        }
    }

    fn conv2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConv2D,
    ) -> Result<Self> {
        Conv2D(params).map(self, l, kernel, kernel_l)
    }

    fn conv_transpose2d(
        &self,
        l: &Layout,
        kernel: &Self,
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose2D,
    ) -> Result<Self> {
        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
    }

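    /// Selects slices of `self` along `dim` according to the indices in `ids`:
    /// informally `out[.., i, ..] = self[.., ids[i], ..]` with the indexed axis
    /// at position `dim` (a sketch of the usual index-select convention, the
    /// details live in the `IndexSelect` map). Indices must be u8, u32 or i64.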
    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
        }
    }

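    /// Gathers values along `dim` using an index tensor of the same rank:
    /// informally, for `dim == 1` on a 2D input,
    /// `out[i][j] = self[i][ids[i][j]]` (a sketch of the usual gather
    /// convention, the details live in the `Gather` map). Indices must be u8,
    /// u32 or i64.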
    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
        match ids {
            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
        }
    }

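    /// Scatter-writes `src` into `self` along `dim`: informally, for `dim == 1`
    /// on a 2D input, `self[i][ids[i][j]] = src[i][j]` (a sketch; the exact
    /// behavior, including what happens when indices collide, is defined by the
    /// `Scatter` map with the `Set` reduction).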
    fn scatter_set(
        &mut self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<()> {
        match ids {
            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
        }
    }

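    /// Like `scatter_set` but accumulating: values from `src` are added to the
    /// destination elements instead of overwriting them (the `Add` reduction of
    /// the `Scatter` map), so colliding indices sum up.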
    fn scatter_add_set(
        &mut self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<()> {
        match ids {
            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I16(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
        }
    }

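    /// Returns a copy of `self` where, along `dim`, the slice at position
    /// `ids[i]` has had the `i`-th slice of `src` added to it (informally,
    /// `out[.., ids[i], ..] += src[.., i, ..]`). The index buffer has to be
    /// contiguous.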
    fn index_add(
        &self,
        l: &Layout,
        ids: &Self,
        ids_l: &Layout,
        src: &Self,
        src_l: &Layout,
        dim: usize,
    ) -> Result<Self> {
        match ids {
            Self::U8(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::U32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I16(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I32(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            Self::I64(ids) => {
                let ids = match ids_l.contiguous_offsets() {
                    Some((a, b)) => &ids[a..b],
                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
                };
                IndexAdd { ids, dim }.map(self, l, src, src_l)
            }
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
        }
    }

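    /// Batched matrix multiplication with explicit dimensions
    /// `bmnk = (b, m, n, k)`: a `(b, m, k)` lhs is multiplied with a
    /// `(b, k, n)` rhs to produce a `(b, m, n)` result, strides being taken
    /// from the two layouts.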
    fn matmul(
        &self,
        rhs: &Self,
        bmnk: (usize, usize, usize, usize),
        lhs_l: &Layout,
        rhs_l: &Layout,
    ) -> Result<Self> {
        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
    }

    fn device(&self) -> &Self::Device {
        &CpuDevice
    }

    fn try_clone(&self, _: &Layout) -> Result<Self> {
        Ok(self.clone())
    }

    fn to_cpu_storage(&self) -> Result<CpuStorage> {
        Ok(self.clone())
    }

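    /// Fills the positions selected by `l` with the scalar `s`, handling the
    /// three strided-block cases separately: one contiguous block, many
    /// one-element blocks, and many fixed-length blocks. The scalar dtype has
    /// to match the storage dtype, and the packed dummy dtypes are rejected.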
    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
        use crate::scalar::Scalar;
        fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
            match l.strided_blocks() {
                crate::StridedBlocks::SingleBlock { start_offset, len } => {
                    src[start_offset..start_offset + len].fill(s)
                }
                crate::StridedBlocks::MultipleBlocks {
                    block_start_index,
                    block_len: 1,
                } => {
                    for src_index in block_start_index {
                        src[src_index] = s
                    }
                }
                crate::StridedBlocks::MultipleBlocks {
                    block_start_index,
                    block_len,
                } => {
                    for src_index in block_start_index {
                        src[src_index..src_index + block_len].fill(s)
                    }
                }
            }
        }
        match (self, s) {
            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
            (Self::I16(storage), Scalar::I16(v)) => set(storage, l, v),
            (Self::I32(storage), Scalar::I32(v)) => set(storage, l, v),
            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
            (Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v),
            // Dummy types don't support scalar operations
            (Self::F6E2M3(_), _) => {
                crate::bail!("const_set not supported for dummy type F6E2M3")
            }
            (Self::F6E3M2(_), _) => {
                crate::bail!("const_set not supported for dummy type F6E3M2")
            }
            (Self::F4(_), _) => {
                crate::bail!("const_set not supported for dummy type F4")
            }
            (Self::F8E8M0(_), _) => {
                crate::bail!("const_set not supported for dummy type F8E8M0")
            }
            (st, s) => crate::bail!(
                "const_set dtype mismatch, expected {:?} but got {:?}",
                st.dtype(),
                s
            ),
        }
        Ok(())
    }
}

impl BackendDevice for CpuDevice {
    type Storage = CpuStorage;

    fn location(&self) -> crate::DeviceLocation {
        crate::DeviceLocation::Cpu
    }

    fn same_device(&self, _: &Self) -> bool {
        true
    }

    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
        Ok(T::to_cpu_storage(s))
    }

    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
        Ok(s.clone())
    }

    fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
        Ok(s)
    }

    fn new(_: usize) -> Result<Self> {
        Ok(Self)
    }

    fn set_seed(&self, _seed: u64) -> Result<()> {
        crate::bail!("cannot seed the CPU rng with set_seed")
    }

    fn get_current_seed(&self) -> Result<u64> {
        crate::bail!("cannot get the CPU rng seed with get_current_seed")
    }

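    /// Draws `shape.elem_count()` samples from a uniform distribution over
    /// `[min, max)`. Only float dtypes are supported; integer and packed dummy
    /// dtypes return `UnsupportedDTypeForOp`.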
    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::rng();
        match dtype {
            DType::U8
            | DType::U32
            | DType::I16
            | DType::I32
            | DType::I64
            | DType::F6E2M3
            | DType::F6E3M2
            | DType::F4
            | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<bf16, _>(uniform))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f16, _>(uniform))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F8E4M3 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
                        .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<F8E4M3, _>(uniform))
                }
                Ok(CpuStorage::F8E4M3(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform =
                    rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f32, _>(uniform))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(rng.sample::<f64, _>(uniform))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
        use rand::prelude::*;

        let elem_count = shape.elem_count();
        let mut rng = rand::rng();
        match dtype {
            DType::U8
            | DType::U32
            | DType::I16
            | DType::I32
            | DType::I64
            | DType::F6E2M3
            | DType::F6E3M2
            | DType::F4
            | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
            DType::BF16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::BF16(data))
            }
            DType::F16 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F16(data))
            }
            DType::F8E4M3 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
                    .map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F8E4M3(data))
            }
            DType::F32 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal =
                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F32(data))
            }
            DType::F64 => {
                let mut data = Vec::with_capacity(elem_count);
                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
                for _i in 0..elem_count {
                    data.push(normal.sample(&mut rng))
                }
                Ok(CpuStorage::F64(data))
            }
        }
    }

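    /// # Safety
    ///
    /// The returned storage is uninitialized: every element must be written
    /// before being read. As noted below, this is only done for `Copy` dtypes
    /// for which all bit patterns are valid values.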
    #[allow(clippy::uninit_vec)]
    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        // The code below is highly unsafe but hopefully not directly unsound as we only consider
        // types that are Copy, not Drop, and for which all bit patterns are proper values.
        // It's still pretty risky, see the following for more details:
        // https://github.com/rust-lang/rust-clippy/issues/4483
        let storage = match dtype {
            DType::U8 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::U8(v)
            }
            DType::U32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::U32(v)
            }
            DType::I16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::I16(v)
            }
            DType::I32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::I32(v)
            }
            DType::I64 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::I64(v)
            }
            DType::BF16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::BF16(v)
            }
            DType::F16 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F16(v)
            }
            DType::F32 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F32(v)
            }
            DType::F64 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F64(v)
            }
            DType::F8E4M3 => {
                let mut v = Vec::with_capacity(elem_count);
                v.set_len(elem_count);
                CpuStorage::F8E4M3(v)
            }
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                return Err(Error::UnsupportedDTypeForOp(dtype, "alloc_uninit").bt())
            }
        };
        Ok(storage)
    }

    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
        let elem_count = shape.elem_count();
        let storage = match dtype {
            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
            DType::I16 => CpuStorage::I16(vec![0i16; elem_count]),
            DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
            DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt())
            }
        };
        Ok(storage)
    }

    fn synchronize(&self) -> Result<()> {
        Ok(())
    }
}

#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+)) => {
        match $storage {
            $(CpuStorage::$dtypes(__e) => CpuStorage::$dtypes($fn(__e)),)*
            s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}
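
// A minimal sketch of how `map_dtype!` is meant to be used, plus a smoke test
// of `zeros_impl`. The `double_floats` helper and both tests are hypothetical
// illustrations added here, not part of the original API: the macro applies a
// per-buffer function to the dtypes listed in the trailing tuple and bails
// with `UnsupportedDTypeForOp` for everything else.
#[cfg(test)]
mod map_dtype_tests {
    use super::*;

    // Doubles every element of a float storage; other dtypes fall through to
    // the macro's fallback arm and return an error.
    fn double_floats(s: CpuStorage) -> Result<CpuStorage> {
        let out = crate::map_dtype!(
            "double-floats",
            s,
            |v: Vec<_>| v.into_iter().map(|x| x + x).collect::<Vec<_>>(),
            (F32, F64)
        );
        Ok(out)
    }

    #[test]
    fn map_dtype_applies_to_listed_dtypes() -> Result<()> {
        match double_floats(CpuStorage::F32(vec![1.0, 2.0]))? {
            CpuStorage::F32(v) => assert_eq!(v, vec![2.0, 4.0]),
            _ => panic!("expected F32 storage"),
        }
        // A dtype outside the listed tuple is rejected.
        assert!(double_floats(CpuStorage::U8(vec![1])).is_err());
        Ok(())
    }

    #[test]
    fn zeros_impl_fills_with_zeros() -> Result<()> {
        match CpuDevice.zeros_impl(&Shape::from((2, 3)), DType::F32)? {
            CpuStorage::F32(v) => assert_eq!(v, vec![0f32; 6]),
            _ => panic!("expected F32 storage"),
        }
        Ok(())
    }
}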