// candle_core/cpu_backend/mod.rs
//! Implementation of the backend functions for the CPU device.
2use crate::backend::{BackendDevice, BackendStorage};
3use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
4use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
5use float8::F8E4M3;
6use half::{bf16, f16};
7use rayon::prelude::*;
8
9mod utils;
10pub use utils::{
11    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
12};
13mod conv2d;
14use conv2d::Conv2D;
15
// NOTE(review): compile-time switches read outside this chunk. Judging by their names they
// select the im2col-based conv1d path and the col2im-based transposed-conv1d path — confirm
// at the use sites before relying on this.
const USE_IM2COL_CONV1D: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;
18
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
// intercept the oom errors to avoid panicking and provide a proper error.
/// Owned CPU tensor storage: each variant holds a flat `Vec` of elements of the named type.
#[derive(Debug, Clone)]
pub enum CpuStorage {
    U8(Vec<u8>),
    U32(Vec<u32>),
    I16(Vec<i16>),
    I32(Vec<i32>),
    I64(Vec<i64>),
    BF16(Vec<bf16>),
    F16(Vec<f16>),
    F32(Vec<f32>),
    F64(Vec<f64>),
    F8E4M3(Vec<F8E4M3>),
    // Dummy types that store raw bytes: these formats have no element type here, so their
    // data is kept as an opaque `u8` buffer.
    F6E2M3(Vec<u8>),
    F6E3M2(Vec<u8>),
    F4(Vec<u8>),
    F8E8M0(Vec<u8>),
}
39
/// Borrowed counterpart of `CpuStorage`: each variant holds a slice of the same element
/// type as the `CpuStorage` variant of the same name.
#[derive(Debug, Clone)]
pub enum CpuStorageRef<'a> {
    U8(&'a [u8]),
    U32(&'a [u32]),
    I16(&'a [i16]),
    I32(&'a [i32]),
    I64(&'a [i64]),
    BF16(&'a [bf16]),
    F16(&'a [f16]),
    F32(&'a [f32]),
    F64(&'a [f64]),
    F8E4M3(&'a [F8E4M3]),
    // Dummy types that store raw bytes (opaque `u8` buffers, as in `CpuStorage`).
    F6E2M3(&'a [u8]),
    F6E3M2(&'a [u8]),
    F4(&'a [u8]),
    F8E8M0(&'a [u8]),
}
58
/// Marker type for the CPU device. It is a unit struct, so it carries no per-device state.
#[derive(Debug, Clone)]
pub struct CpuDevice;
61
62struct Cmp(CmpOp);
63impl Map2U8 for Cmp {
64    const OP: &'static str = "cmp";
65    #[inline(always)]
66    fn f<T: WithDType>(
67        &self,
68        lhs: &[T],
69        lhs_l: &Layout,
70        rhs: &[T],
71        rhs_l: &Layout,
72    ) -> Result<Vec<u8>> {
73        let dst = match self.0 {
74            CmpOp::Eq => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x == y)),
75            CmpOp::Ne => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x != y)),
76            CmpOp::Lt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x < y)),
77            CmpOp::Le => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x <= y)),
78            CmpOp::Gt => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x > y)),
79            CmpOp::Ge => binary_map(lhs_l, rhs_l, lhs, rhs, |x, y| u8::from(x >= y)),
80        };
81        Ok(dst)
82    }
83}
84
85struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
86
87impl<I: IntDType> Map2 for WCond<'_, I> {
88    const OP: &'static str = "where";
89    #[inline(always)]
90    fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
91        let vs = match (
92            self.1.contiguous_offsets(),
93            t_l.contiguous_offsets(),
94            f_l.contiguous_offsets(),
95        ) {
96            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
97                let pred = &self.0[o1..o2];
98                let t = &t[o_t1..o_t2];
99                let f = &f[o_f1..o_f2];
100                pred.iter()
101                    .zip(t.iter().zip(f.iter()))
102                    .map(|(p, (&t, &f))| if p.is_true() { t } else { f })
103                    .collect::<Vec<_>>()
104            }
105            _ => self
106                .1
107                .strided_index()
108                .zip(t_l.strided_index().zip(f_l.strided_index()))
109                .map(|(i_p, (i_t, i_f))| {
110                    if self.0[i_p].is_true() {
111                        t[i_t]
112                    } else {
113                        f[i_f]
114                    }
115                })
116                .collect::<Vec<_>>(),
117        };
118        Ok(vs)
119    }
120}
121
122struct ReduceIndex {
123    reduce_dim_index: usize,
124    use_min: bool,
125    return_index: bool,
126}
127
128impl ReduceIndex {
129    // The value gets replaced if f(s[current_acc], s[i]) returns true.
130    #[inline(always)]
131    fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
132    where
133        T: Clone + Copy,
134        U: Clone + Copy,
135        F: Fn(T, T) -> bool,
136        G: Fn(T, usize) -> U,
137    {
138        let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
139        let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
140        let dst_len = src_l.shape().elem_count() / reduce_dim_size;
141        let mut dst: Vec<U> = Vec::with_capacity(dst_len);
142        let dst_to_set = dst.spare_capacity_mut();
143        let dst_to_set =
144            unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit<U>], &mut [U]>(dst_to_set) };
145        match src_l.contiguous_offsets() {
146            Some((o1, o2)) => {
147                let src = &src[o1..o2];
148                if reduce_dim_stride == 1 {
149                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
150                        let start_src_i = start_src_i * reduce_dim_size;
151                        let src = &src[start_src_i..start_src_i + reduce_dim_size];
152                        let mut acc = 0;
153                        let mut val = src[0];
154                        for (src_i, &s) in src.iter().enumerate() {
155                            if f(val, s) {
156                                acc = src_i;
157                                val = s
158                            }
159                        }
160                        *dst_v = g(val, acc)
161                    }
162                } else {
163                    for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
164                        let (p, q) = (
165                            start_src_i / reduce_dim_stride,
166                            start_src_i % reduce_dim_stride,
167                        );
168                        // start_src_i = p * reduce_dim_stride + q
169                        let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
170                        let src = &src[start_src_i..];
171                        let mut acc = 0;
172                        let mut val = src[0];
173                        for src_i in 0..reduce_dim_size {
174                            let s = src[src_i * reduce_dim_stride];
175                            if f(val, s) {
176                                acc = src_i;
177                                val = s
178                            }
179                        }
180                        *dst_v = g(val, acc)
181                    }
182                }
183            }
184            None => {
185                let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
186                for (unstr_index, src_index) in l.strided_index().enumerate() {
187                    let src = &src[src_index..];
188                    let mut acc = 0;
189                    let mut val = src[0];
190                    for src_i in 0..reduce_dim_size {
191                        let s = src[src_i * reduce_dim_stride];
192                        if f(val, s) {
193                            acc = src_i;
194                            val = s
195                        }
196                    }
197                    dst_to_set[unstr_index] = g(val, acc)
198                }
199            }
200        }
201        unsafe { dst.set_len(dst_len) };
202        Ok(dst)
203    }
204}
205
206impl Map1Any for ReduceIndex {
207    #[inline(always)]
208    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
209        &self,
210        src: &[T],
211        src_l: &Layout,
212        wrap: W,
213    ) -> Result<CpuStorage> {
214        if src_l.shape().elem_count() == 0 {
215            Err(Error::EmptyTensor { op: "reduce" }.bt())?
216        }
217        let dst = match (self.return_index, self.use_min) {
218            (false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
219            (false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
220            (true, true) => {
221                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
222            }
223            (true, false) => {
224                CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
225            }
226        };
227        Ok(dst)
228    }
229}
230
231struct ReduceSum<'a> {
232    dst_shape: &'a Shape,
233    reduce_dims: &'a [usize],
234    reduce_dims_and_stride: Vec<(usize, usize)>,
235}
236
237impl ReduceSum<'_> {
238    #[inline(always)]
239    fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
240    where
241        T: WithDType,
242    {
243        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
244        match src_l.contiguous_offsets() {
245            Some((o1, o2)) => {
246                let src = &src[o1..o2];
247                // Handle the case where we reduce over the last dimensions separately as it is
248                // fairly common and easy to optimize. This rely on the layout being contiguous!
249                // reduce_dims is sorted, check if it is ranging from a to n-1.
250                let reduce_over_last_dims = self
251                    .reduce_dims
252                    .iter()
253                    .rev()
254                    .enumerate()
255                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
256                if reduce_over_last_dims {
257                    let reduce_sz = self
258                        .reduce_dims_and_stride
259                        .iter()
260                        .map(|(u, _)| u)
261                        .product::<usize>();
262                    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
263                        let src_i = dst_i * reduce_sz;
264                        unsafe {
265                            T::vec_reduce_sum(
266                                src[src_i..src_i + reduce_sz].as_ptr(),
267                                dst_v,
268                                reduce_sz,
269                            )
270                        };
271                    }
272                    return Ok(dst);
273                };
274                for (unstr_index, &src) in src.iter().enumerate() {
275                    let mut dst_index = unstr_index;
276                    // Set the reduce_dims indexes to 0.
277                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
278                        // The compiler is able to optimize the following in a single divmod op.
279                        let (pre, post) = (dst_index / stride, dst_index % stride);
280                        dst_index = (pre / dim) * stride + post;
281                    }
282                    dst[dst_index] += src;
283                }
284            }
285            None => {
286                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
287                    let mut dst_index = unstr_index;
288                    // Set the reduce_dims indexes to 0.
289                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
290                        // The compiler is able to optimize the following in a single divmod op.
291                        let (pre, post) = (dst_index / stride, dst_index % stride);
292                        dst_index = (pre / dim) * stride + post;
293                    }
294                    dst[dst_index] += src[src_index];
295                }
296            }
297        }
298        Ok(dst)
299    }
300}
301
302impl Map1 for ReduceSum<'_> {
303    #[inline(always)]
304    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
305        self.fold_impl(src, src_l, T::zero())
306    }
307}
308
309struct Affine(f64, f64);
310
311impl Map1 for Affine {
312    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
313        let mul = T::from_f64(self.0);
314        let add = T::from_f64(self.1);
315        Ok(unary_map(vs, layout, |v| v * mul + add))
316    }
317}
318
319struct AvgPool2D((usize, usize), (usize, usize));
320
321impl Map1 for AvgPool2D {
322    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
323        // https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html
324        let (k_h, k_w) = self.0;
325        let (s_h, s_w) = self.1;
326        let (b_sz, c, h, w) = layout.shape().dims4()?;
327        let stride = layout.stride();
328        let (stride_h, stride_w) = (stride[2], stride[3]);
329        let h_out = (h - k_h) / s_h + 1;
330        let w_out = (w - k_w) / s_w + 1;
331        let src_index = layout.start_offset();
332        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
333        let scale = 1f64 / (k_h * k_w) as f64;
334        let scale = T::from_f64(scale);
335        for b_idx in 0..b_sz {
336            let dst = &mut dst[b_idx * c * h_out * w_out..];
337            let src_index = src_index + b_idx * stride[0];
338            for c_idx in 0..c {
339                let dst = &mut dst[c_idx * h_out * w_out..];
340                let src_index = src_index + c_idx * stride[1];
341                for h_idx in 0..h_out {
342                    for w_idx in 0..w_out {
343                        let mut sum = T::zero();
344                        for m in 0..k_h {
345                            for n in 0..k_w {
346                                let m = s_h * h_idx + m;
347                                let n = s_w * w_idx + n;
348                                sum += src[src_index + m * stride_h + n * stride_w]
349                            }
350                        }
351                        dst[h_idx * w_out + w_idx] = sum * scale;
352                    }
353                }
354            }
355        }
356        Ok(dst)
357    }
358}
359
360struct MaxPool2D((usize, usize), (usize, usize));
361
362impl Map1 for MaxPool2D {
363    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
364        // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
365        let (k_h, k_w) = self.0;
366        let (s_h, s_w) = self.1;
367        let (b_sz, c, h, w) = layout.shape().dims4()?;
368        let stride = layout.stride();
369        let (stride_h, stride_w) = (stride[2], stride[3]);
370        let h_out = (h - k_h) / s_h + 1;
371        let w_out = (w - k_w) / s_w + 1;
372        let src_index = layout.start_offset();
373        let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];
374        for b_idx in 0..b_sz {
375            let dst = &mut dst[b_idx * c * h_out * w_out..];
376            let src_index = src_index + b_idx * stride[0];
377            for c_idx in 0..c {
378                let dst = &mut dst[c_idx * h_out * w_out..];
379                let src_index = src_index + c_idx * stride[1];
380                for h_idx in 0..h_out {
381                    for w_idx in 0..w_out {
382                        let mut largest =
383                            src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
384                        for m in 0..k_h {
385                            for n in 0..k_w {
386                                let m = s_h * h_idx + m;
387                                let n = s_w * w_idx + n;
388                                if largest < src[src_index + m * stride_h + n * stride_w] {
389                                    largest = src[src_index + m * stride_h + n * stride_w]
390                                }
391                            }
392                        }
393                        dst[h_idx * w_out + w_idx] = largest;
394                    }
395                }
396            }
397        }
398        Ok(dst)
399    }
400}
401
402struct UpsampleNearest1D(usize);
403
404impl Map1 for UpsampleNearest1D {
405    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
406        // TODO: Specialized implementation for the case 2*sz?
407        let dst_sz = self.0;
408        let (b_sz, c, src_sz) = layout.shape().dims3()?;
409        let stride = layout.stride();
410        let stride_sz = stride[2];
411        let src_index = layout.start_offset();
412        let scale_sz = src_sz as f64 / dst_sz as f64;
413        let mut dst = vec![T::zero(); b_sz * c * dst_sz];
414        let src_idxs = (0..dst_sz)
415            .map(|idx| usize::min(src_sz - 1, (idx as f64 * scale_sz) as usize))
416            .collect::<Vec<_>>();
417        for b_idx in 0..b_sz {
418            let dst = &mut dst[b_idx * c * dst_sz..];
419            let src_index = src_index + b_idx * stride[0];
420            for c_idx in 0..c {
421                let dst = &mut dst[c_idx * dst_sz..];
422                let src_index = src_index + c_idx * stride[1];
423                for (idx, src_idx) in src_idxs.iter().enumerate() {
424                    dst[idx] = src[src_index + src_idx * stride_sz]
425                }
426            }
427        }
428        Ok(dst)
429    }
430}
431
432struct UpsampleNearest2D(usize, usize);
433
434impl Map1 for UpsampleNearest2D {
435    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
436        // TODO: Specialized implementation for the case 2*h, 2*w?
437        let (dst_h, dst_w) = (self.0, self.1);
438        let (b_sz, c, src_h, src_w) = layout.shape().dims4()?;
439        let stride = layout.stride();
440        let (stride_h, stride_w) = (stride[2], stride[3]);
441        let src_index = layout.start_offset();
442        let scale_h = src_h as f64 / dst_h as f64;
443        let scale_w = src_w as f64 / dst_w as f64;
444        let mut dst = vec![T::zero(); b_sz * c * dst_h * dst_w];
445        let src_h_idxs = (0..dst_h)
446            .map(|h_idx| usize::min(src_h - 1, (h_idx as f64 * scale_h) as usize))
447            .collect::<Vec<_>>();
448        let src_w_idxs = (0..dst_w)
449            .map(|w_idx| usize::min(src_w - 1, (w_idx as f64 * scale_w) as usize))
450            .collect::<Vec<_>>();
451        for b_idx in 0..b_sz {
452            let dst = &mut dst[b_idx * c * dst_h * dst_w..];
453            let src_index = src_index + b_idx * stride[0];
454            for c_idx in 0..c {
455                let dst = &mut dst[c_idx * dst_h * dst_w..];
456                let src_index = src_index + c_idx * stride[1];
457                for (h_idx, src_h_idx) in src_h_idxs.iter().enumerate() {
458                    for (w_idx, src_w_idx) in src_w_idxs.iter().enumerate() {
459                        let src_index = src_index + src_h_idx * stride_h + src_w_idx * stride_w;
460                        dst[h_idx * dst_w + w_idx] = src[src_index]
461                    }
462                }
463            }
464        }
465        Ok(dst)
466    }
467}
468
469struct UpsampleBilinear2D {
470    target_h: usize,
471    target_w: usize,
472    align_corners: bool,
473    scale_h_factor: Option<f64>,
474    scale_w_factor: Option<f64>,
475}
476
477impl Map1 for UpsampleBilinear2D {
478    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
479        let (batch, channels, height_in, width_in) = layout.shape().dims4()?;
480        let height_out = self.target_h;
481        let width_out = self.target_w;
482
483        // Early return for identity case
484        if height_in == height_out && width_in == width_out {
485            return Ok(src.to_vec());
486        }
487
488        let stride = layout.stride();
489        let src_offset = layout.start_offset();
490
491        // Calculate scale factors following PyTorch's area_pixel_compute_scale logic
492        let scale_h = if self.align_corners {
493            if height_out > 1 {
494                (height_in - 1) as f64 / (height_out - 1) as f64
495            } else {
496                0.0
497            }
498        } else {
499            // PyTorch's compute_scales_value logic:
500            // If scale_factor was provided, use 1.0 / scale_factor
501            // Otherwise, use input_size / output_size
502            if let Some(scale_factor) = self.scale_h_factor {
503                1.0 / scale_factor
504            } else {
505                height_in as f64 / height_out as f64
506            }
507        };
508
509        let scale_w = if self.align_corners {
510            if width_out > 1 {
511                (width_in - 1) as f64 / (width_out - 1) as f64
512            } else {
513                0.0
514            }
515        } else if let Some(scale_factor) = self.scale_w_factor {
516            1.0 / scale_factor
517        } else {
518            width_in as f64 / width_out as f64
519        };
520
521        // Precompute indices and weights for height
522        let mut h_indices = Vec::with_capacity(height_out);
523        for h_out in 0..height_out {
524            let src_h = if self.align_corners {
525                scale_h * h_out as f64
526            } else {
527                scale_h * (h_out as f64 + 0.5) - 0.5
528            };
529            let src_h_clamped = src_h.max(0.0);
530            let h0 = src_h_clamped.floor() as usize;
531            let h1 = (h0 + 1).min(height_in - 1);
532            let weight_h = (src_h_clamped - h0 as f64).clamp(0.0, 1.0);
533            h_indices.push((h0, h1, weight_h));
534        }
535
536        // Precompute indices and weights for width
537        let mut w_indices = Vec::with_capacity(width_out);
538        for w_out in 0..width_out {
539            let src_w = if self.align_corners {
540                scale_w * w_out as f64
541            } else {
542                scale_w * (w_out as f64 + 0.5) - 0.5
543            };
544            let src_w_clamped = src_w.max(0.0);
545            let w0 = src_w_clamped.floor() as usize;
546            let w1 = (w0 + 1).min(width_in - 1);
547            let weight_w = (src_w_clamped - w0 as f64).clamp(0.0, 1.0);
548            w_indices.push((w0, w1, weight_w));
549        }
550
551        // Allocate output
552        let mut dst = vec![T::zero(); batch * channels * height_out * width_out];
553
554        // Perform bilinear interpolation
555        for b in 0..batch {
556            for c in 0..channels {
557                let base_idx = src_offset + b * stride[0] + c * stride[1];
558                let dst_base = (b * channels + c) * height_out * width_out;
559
560                for (h_out, &(h0, h1, weight_h)) in h_indices.iter().enumerate() {
561                    for (w_out, &(w0, w1, weight_w)) in w_indices.iter().enumerate() {
562                        // Get four neighboring pixels
563                        let idx_00 = base_idx + h0 * stride[2] + w0 * stride[3];
564                        let idx_10 = base_idx + h0 * stride[2] + w1 * stride[3];
565                        let idx_01 = base_idx + h1 * stride[2] + w0 * stride[3];
566                        let idx_11 = base_idx + h1 * stride[2] + w1 * stride[3];
567
568                        let v00 = src[idx_00].to_f64();
569                        let v10 = src[idx_10].to_f64();
570                        let v01 = src[idx_01].to_f64();
571                        let v11 = src[idx_11].to_f64();
572
573                        // Bilinear interpolation
574                        let v_top = v00 * (1.0 - weight_w) + v10 * weight_w;
575                        let v_bottom = v01 * (1.0 - weight_w) + v11 * weight_w;
576                        let value = v_top * (1.0 - weight_h) + v_bottom * weight_h;
577
578                        dst[dst_base + h_out * width_out + w_out] = T::from_f64(value);
579                    }
580                }
581            }
582        }
583
584        Ok(dst)
585    }
586}
587
588struct Gather<'a, I: IntDType> {
589    ids: &'a [I],
590    ids_l: &'a Layout,
591    dim: usize,
592}
593
594impl<I: IntDType> Map1 for Gather<'_, I> {
595    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
596        let ids = match self.ids_l.contiguous_offsets() {
597            Some((a, b)) => &self.ids[a..b],
598            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
599        };
600        let src = match src_l.contiguous_offsets() {
601            Some((a, b)) => &src[a..b],
602            None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
603        };
604        let dim = self.dim;
605        let ids_dims = self.ids_l.dims();
606        let src_dims = src_l.dims();
607        let dst_len: usize = ids_dims.iter().product();
608        let dst_left_len: usize = ids_dims[..dim].iter().product();
609        let dst_dim_len = ids_dims[dim];
610        let dst_right_len: usize = ids_dims[dim + 1..].iter().product();
611
612        let src_dim_len = src_dims[dim];
613        let src_right_len: usize = src_dims[dim + 1..].iter().product();
614
615        let mut dst = vec![T::zero(); dst_len];
616        for left_i in 0..dst_left_len {
617            let start_src_idx = left_i * src_right_len * src_dim_len;
618            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
619            for i in 0..dst_dim_len {
620                let start_dst_idx = start_dst_idx + i * dst_right_len;
621                for right_i in 0..dst_right_len {
622                    let dst_idx = start_dst_idx + right_i;
623                    let index = ids[dst_idx];
624                    if index == I::max_value() {
625                        dst[dst_idx] = T::zero();
626                    } else {
627                        let index = index.as_usize();
628                        if index >= src_dim_len {
629                            Err(Error::InvalidIndex {
630                                index,
631                                size: src_dim_len,
632                                op: "gather",
633                            }
634                            .bt())?
635                        }
636                        let src_idx = start_src_idx + index * src_right_len + right_i;
637                        dst[dst_idx] = src[src_idx]
638                    }
639                }
640            }
641        }
642        Ok(dst)
643    }
644}
645
646struct IndexSelect<'a, T: IntDType> {
647    ids: &'a [T],
648    ids_l: &'a Layout,
649    dim: usize,
650}
651
652impl<I: IntDType> Map1 for IndexSelect<'_, I> {
653    fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
654        let src = match layout.contiguous_offsets() {
655            Some((a, b)) => &src[a..b],
656            None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
657        };
658        let dim = self.dim;
659        let n_ids = match self.ids_l.dims() {
660            [n_ids] => *n_ids,
661            d => Err(Error::UnexpectedNumberOfDims {
662                expected: 1,
663                got: d.len(),
664                shape: self.ids_l.shape().clone(),
665            }
666            .bt())?,
667        };
668        let stride_ids = self.ids_l.stride()[0];
669        let mut dst_dims = layout.dims().to_vec();
670        let src_dim = dst_dims[dim];
671        dst_dims[dim] = n_ids;
672        let dst_len: usize = dst_dims.iter().product();
673        let left_len: usize = dst_dims[..dim].iter().product();
674        let right_len: usize = dst_dims[dim + 1..].iter().product();
675        let mut dst = vec![T::zero(); dst_len];
676        for left_i in 0..left_len {
677            let start_src_idx = left_i * right_len * src_dim;
678            let start_dst_idx = left_i * right_len * n_ids;
679            for i in 0..n_ids {
680                let start_dst_idx = start_dst_idx + i * right_len;
681                let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
682                if index == I::max_value() {
683                    dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
684                } else {
685                    let index = index.as_usize();
686                    if index >= src_dim {
687                        Err(Error::InvalidIndex {
688                            index,
689                            size: src_dim,
690                            op: "index-select",
691                        }
692                        .bt())?
693                    }
694                    let start_src_idx = start_src_idx + index * right_len;
695                    dst[start_dst_idx..start_dst_idx + right_len]
696                        .copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
697                }
698            }
699        }
700        Ok(dst)
701    }
702}
703
/// Element combiner used by `Scatter`: decides how a source value is merged into the
/// destination element it is scattered to.
trait ElemUpdate {
    /// Combines `src` into `*dst`.
    fn f<T: WithDType>(dst: &mut T, src: T);
}

/// Overwrites the destination element with the source value.
struct Set;
/// Accumulates the source value into the destination element (`+=`).
struct Add;

impl ElemUpdate for Set {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst = src
    }
}

impl ElemUpdate for Add {
    fn f<T: WithDType>(dst: &mut T, src: T) {
        *dst += src
    }
}
722
723struct Scatter<'a, I: IntDType, M: ElemUpdate> {
724    ids: &'a [I],
725    ids_l: &'a Layout,
726    dim: usize,
727    _phantom: std::marker::PhantomData<M>,
728}
729
730impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
731    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
732        Self {
733            ids,
734            ids_l,
735            dim,
736            _phantom: Default::default(),
737        }
738    }
739}
740
741impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
742    const OP: &'static str = "scatter";
743    fn f<T: WithDType>(
744        &self,
745        dst: &mut [T],
746        dst_l: &Layout,
747        src: &[T],
748        src_l: &Layout,
749    ) -> Result<()> {
750        let dst = match dst_l.contiguous_offsets() {
751            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
752            Some((o1, o2)) => &mut dst[o1..o2],
753        };
754
755        let src = match src_l.contiguous_offsets() {
756            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
757            Some((o1, o2)) => &src[o1..o2],
758        };
759
760        let dim = self.dim;
761        let ids_dims = self.ids_l.dims();
762        let dst_dims = dst_l.dims();
763        let dst_dim_len = dst_dims[dim];
764        let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
765
766        let ids_left_len: usize = ids_dims[..dim].iter().product();
767        let ids_dim_len = ids_dims[dim];
768        let ids_right_len: usize = ids_dims[dim + 1..].iter().product();
769
770        let ids = match self.ids_l.contiguous_offsets() {
771            Some((a, b)) => &self.ids[a..b],
772            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
773        };
774        for left_i in 0..ids_left_len {
775            let start_ids_idx = left_i * ids_right_len * ids_dim_len;
776            let start_dst_idx = left_i * dst_right_len * dst_dim_len;
777            for i in 0..ids_dim_len {
778                let start_ids_idx = start_ids_idx + i * ids_right_len;
779                for right_i in 0..dst_right_len {
780                    let ids_idx = start_ids_idx + right_i;
781                    let index = ids[ids_idx];
782                    if index == I::max_value() {
783                        continue;
784                    }
785                    let index = index.as_usize();
786                    if index >= dst_dim_len {
787                        Err(Error::InvalidIndex {
788                            index,
789                            size: dst_dim_len,
790                            op: "gather",
791                        }
792                        .bt())?
793                    }
794                    let dst_idx = start_dst_idx + index * dst_right_len + right_i;
795                    M::f(&mut dst[dst_idx], src[ids_idx])
796                }
797            }
798        }
799
800        Ok(())
801    }
802}
803
804struct IndexAdd<'a, I: IntDType> {
805    ids: &'a [I],
806    dim: usize,
807}
808
809impl<I: IntDType> Map2 for IndexAdd<'_, I> {
810    const OP: &'static str = "index-add";
811    // https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
812    // v1, l1 -> self
813    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
814        let dst_len = l1.shape().elem_count();
815        let mut dst = vec![T::zero(); dst_len];
816        copy_strided_src_(v1, &mut dst, 0, l1);
817        let src = match src_l.contiguous_offsets() {
818            None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
819            Some((o1, o2)) => &src[o1..o2],
820        };
821        let dim = self.dim;
822        let max_idx = l1.dims()[dim];
823        let pre_dim = src_l.dims()[..dim].iter().product::<usize>();
824        let src_dim_sz = src_l.dims()[dim];
825        let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
826        if dim == 0 {
827            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
828                if *dst_idx == I::max_value() {
829                    continue;
830                }
831                let dst_idx = dst_idx.as_usize();
832                if dst_idx >= max_idx {
833                    Err(Error::InvalidIndex {
834                        index: dst_idx,
835                        op: "index-add",
836                        size: max_idx,
837                    })?
838                }
839                let src_idx = src_idx * post_dim;
840                let dst_idx = dst_idx * post_dim;
841                let src = &src[src_idx..src_idx + post_dim];
842                let dst = &mut dst[dst_idx..dst_idx + post_dim];
843                for (d, &s) in dst.iter_mut().zip(src.iter()) {
844                    *d += s
845                }
846            }
847        } else {
848            for (src_idx, dst_idx) in self.ids.iter().enumerate() {
849                if *dst_idx == I::max_value() {
850                    continue;
851                }
852                let dst_idx = dst_idx.as_usize();
853                if dst_idx >= max_idx {
854                    Err(Error::InvalidIndex {
855                        index: dst_idx,
856                        op: "index-add",
857                        size: max_idx,
858                    })?
859                }
860                for pre_i in 0..pre_dim {
861                    let pre_src_i = (pre_i * src_dim_sz + src_idx) * post_dim;
862                    let pre_dst_i = (pre_i * max_idx + dst_idx) * post_dim;
863                    let src = &src[pre_src_i..pre_src_i + post_dim];
864                    let dst = &mut dst[pre_dst_i..pre_dst_i + post_dim];
865                    for (d, &s) in dst.iter_mut().zip(src.iter()) {
866                        *d += s
867                    }
868                }
869            }
870        }
871        Ok(dst)
872    }
873}
874
/// Copies a `d1 x d2` 2-D region from `src` into `dst`.
///
/// Row `r` of the region starts at `src_offset + r * src_stride1` in `src` and at
/// `dst_offset + r * dst_stride1` in `dst`. Each row is `d2` consecutive elements.
///
/// # Panics
/// Panics if any addressed row falls outside `src` or `dst`.
#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    let rows_are_packed = src_stride1 == d2 && dst_stride1 == d2;
    if rows_are_packed {
        // Rows sit back to back on both sides, so the region is a single contiguous run.
        let total = d1 * d2;
        let target = &mut dst[dst_offset..dst_offset + total];
        target.copy_from_slice(&src[src_offset..src_offset + total]);
    } else {
        for row in 0..d1 {
            let from = src_offset + row * src_stride1;
            let to = dst_offset + row * dst_stride1;
            let target = &mut dst[to..to + d2];
            target.copy_from_slice(&src[from..from + d2]);
        }
    }
}
901
902fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
903    match src_l.strided_blocks() {
904        crate::StridedBlocks::SingleBlock { start_offset, len } => dst
905            [dst_offset..dst_offset + len]
906            .copy_from_slice(&src[start_offset..start_offset + len]),
907        crate::StridedBlocks::UniformBlocks {
908            start_offset,
909            block_len,
910            count,
911            src_stride,
912        } => copy2d_(
913            src,
914            dst,
915            count,
916            block_len,
917            src_stride,
918            block_len,
919            start_offset,
920            dst_offset,
921        ),
922        crate::StridedBlocks::MultipleBlocks {
923            block_start_index,
924            block_len: 1,
925        } => {
926            let n = block_start_index.len();
927            let dst = &mut dst[dst_offset..dst_offset + n];
928            for (dst_elem, src_index) in dst.iter_mut().zip(block_start_index) {
929                *dst_elem = src[src_index]
930            }
931        }
932        crate::StridedBlocks::MultipleBlocks {
933            block_start_index,
934            block_len,
935        } => {
936            let n_blocks = block_start_index.len();
937            let dst = &mut dst[dst_offset..dst_offset + n_blocks * block_len];
938            let mut dst_index = 0;
939            for src_index in block_start_index {
940                dst[dst_index..dst_index + block_len]
941                    .copy_from_slice(&src[src_index..src_index + block_len]);
942                dst_index += block_len;
943            }
944        }
945    }
946}
947
948struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
949
950impl Map2 for Conv1D<'_> {
951    const OP: &'static str = "conv1d";
952    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
953        let p = self.0;
954        let inp = &inp[inp_l.start_offset()..];
955        let k = &k[k_l.start_offset()..];
956        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
957        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
958        let l_out = p.l_out();
959        let dst_elems = p.c_out * l_out * p.b_size;
960        // The output shape is [b_size, c_out, l_out]
961        let dst = vec![T::zero(); dst_elems];
962
963        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
964        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
965        for b_idx in 0..p.b_size {
966            for src_l in 0..p.l_in {
967                for src_c_idx in 0..p.c_in {
968                    let inp_idx = b_idx * inp_s0 + src_c_idx * inp_s1 + src_l * inp_s2;
969                    inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in + src_c_idx] = inp[inp_idx]
970                }
971            }
972        }
973
974        for offset in 0..p.k_size {
975            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
976                let dst_idx = dst_c_idx * l_out;
977                let k_cont = (0..p.c_in)
978                    .map(|c_in_idx| k[dst_c_idx * k_s0 + c_in_idx * k_s1 + offset * k_s2])
979                    .collect::<Vec<_>>();
980                for b_idx in 0..p.b_size {
981                    let dst_idx = dst_idx + b_idx * p.c_out * l_out;
982                    for dst_l in 0..l_out {
983                        let dst_idx = dst_idx + dst_l;
984                        let src_l = p.stride * dst_l + offset * p.dilation;
985                        if src_l < p.padding || src_l >= p.padding + p.l_in {
986                            continue;
987                        }
988                        let src_l = src_l - p.padding;
989                        let inp_cont = &inp_cont[b_idx * p.l_in * p.c_in + src_l * p.c_in..];
990                        assert!(inp_cont.len() >= p.c_in);
991                        assert!(k_cont.len() >= p.c_in);
992                        let mut d = T::zero();
993                        unsafe { T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in) }
994                        let dst_p = dst.as_ptr();
995                        // Safety: dst_idx are uniques per dst_c_idx which is used to parallelise
996                        // the different tasks so no two threads can try to write at the same
997                        // location.
998                        unsafe {
999                            let ptr = dst_p.add(dst_idx) as *mut T;
1000                            *ptr += d
1001                        }
1002                    }
1003                }
1004            })
1005        }
1006        Ok(dst)
1007    }
1008}
1009
1010struct Im2Col1D {
1011    l_k: usize,
1012    stride: usize,
1013    dilation: usize,
1014    padding: usize,
1015}
1016
1017impl Im2Col1D {
1018    fn l_out(&self, l: usize) -> usize {
1019        (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
1020    }
1021}
1022
1023impl Map1 for Im2Col1D {
1024    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
1025        let &Self {
1026            l_k,
1027            stride,
1028            dilation,
1029            padding,
1030        } = self;
1031        let (b, c, l) = layout.shape().dims3()?;
1032        let l_out = self.l_out(l);
1033        let src = &vs[layout.start_offset()..];
1034        let mut dst = vec![T::zero(); b * l_out * c * l_k];
1035        let (src_s0, src_s1, src_s2) = {
1036            let s = layout.stride();
1037            (s[0], s[1], s[2])
1038        };
1039        // TODO: provide specialized kernels for the common use cases.
1040        // - l_k = 1
1041        // - padding = 0
1042        // - stride = 1
1043        // - dilation = 1
1044        for b_idx in 0..b {
1045            let src_idx = b_idx * src_s0;
1046            let dst_idx = b_idx * l_out * c * l_k;
1047            for l_idx in 0..l_out {
1048                let dst_idx = dst_idx + l_idx * c * l_k;
1049                for c_idx in 0..c {
1050                    let dst_idx = dst_idx + c_idx * l_k;
1051                    let src_idx = c_idx * src_s1 + src_idx;
1052                    for l_k_idx in 0..l_k {
1053                        let src_l = l_idx * stride + l_k_idx * dilation;
1054                        if padding != 0 && (src_l < padding || src_l >= l + padding) {
1055                            continue;
1056                        }
1057                        let src_l = src_l - padding;
1058                        let src_idx = src_idx + src_l * src_s2;
1059                        let dst_idx = dst_idx + l_k_idx;
1060                        dst[dst_idx] = src[src_idx]
1061                    }
1062                }
1063            }
1064        }
1065        Ok(dst)
1066    }
1067}
1068
1069struct Im2Col {
1070    h_k: usize,
1071    w_k: usize,
1072    stride: usize,
1073    dilation: usize,
1074    padding: usize,
1075}
1076
1077impl Im2Col {
1078    fn hw_out(&self, h: usize, w: usize) -> (usize, usize) {
1079        let h_out = (h + 2 * self.padding - self.dilation * (self.h_k - 1) - 1) / self.stride + 1;
1080        let w_out = (w + 2 * self.padding - self.dilation * (self.w_k - 1) - 1) / self.stride + 1;
1081        (h_out, w_out)
1082    }
1083}
1084
1085impl Map1 for Im2Col {
1086    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
1087        let &Self {
1088            h_k,
1089            w_k,
1090            stride,
1091            dilation,
1092            padding,
1093        } = self;
1094        let (b, c, h, w) = layout.shape().dims4()?;
1095        let (h_out, w_out) = self.hw_out(h, w);
1096        let src = &vs[layout.start_offset()..];
1097        let mut dst = vec![T::zero(); b * h_out * w_out * c * h_k * w_k];
1098        let (src_s0, src_s1, src_s2, src_s3) = {
1099            let s = layout.stride();
1100            (s[0], s[1], s[2], s[3])
1101        };
1102        // TODO: provide specialized kernels for the common use cases.
1103        // - h_k = w_k = 1
1104        // - padding = 0
1105        // - stride = 1
1106        // - dilation = 1
1107        for b_idx in 0..b {
1108            let src_idx = b_idx * src_s0;
1109            let dst_idx = b_idx * h_out * w_out * c * h_k * w_k;
1110            for h_idx in 0..h_out {
1111                let dst_idx = dst_idx + h_idx * w_out * c * h_k * w_k;
1112                for w_idx in 0..w_out {
1113                    let dst_idx = dst_idx + w_idx * c * h_k * w_k;
1114                    for c_idx in 0..c {
1115                        let dst_idx = dst_idx + c_idx * h_k * w_k;
1116                        let src_idx = c_idx * src_s1 + src_idx;
1117                        for h_k_idx in 0..h_k {
1118                            let src_h = h_idx * stride + h_k_idx * dilation;
1119                            if padding != 0 && (src_h < padding || src_h >= h + padding) {
1120                                continue;
1121                            }
1122                            let src_h = src_h - padding;
1123                            let src_idx = src_idx + src_h * src_s2;
1124                            let dst_idx = dst_idx + h_k_idx * w_k;
1125                            for w_k_idx in 0..w_k {
1126                                let src_w = w_idx * stride + w_k_idx * dilation;
1127                                if padding != 0 && (src_w < padding || src_w >= w + padding) {
1128                                    continue;
1129                                }
1130                                let src_w = src_w - padding;
1131                                let src_idx = src_idx + src_w * src_s3;
1132                                let dst_idx = dst_idx + w_k_idx;
1133                                dst[dst_idx] = src[src_idx]
1134                            }
1135                        }
1136                    }
1137                }
1138            }
1139        }
1140        Ok(dst)
1141    }
1142}
1143
1144struct Col2Im1D {
1145    stride: usize,
1146}
1147
1148impl Map1 for Col2Im1D {
1149    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
1150        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
1151        let stride = self.stride;
1152        let l_out = (l_in - 1) * stride + k_size;
1153        let mut im = vec![T::zero(); b_size * c_out * l_out];
1154        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
1155        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
1156        for l_in_i in 0..l_in {
1157            for k_i in 0..k_size {
1158                let l_out_i = l_in_i * stride + k_i;
1159                for b_i in 0..b_size {
1160                    for c_i in 0..c_out {
1161                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
1162                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
1163                        im[dst_idx] += col[src_idx]
1164                    }
1165                }
1166            }
1167        }
1168        Ok(im)
1169    }
1170}
1171
1172struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
1173
1174impl Map2 for ConvTranspose1D<'_> {
1175    const OP: &'static str = "conv_transpose1d";
1176    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1177        let p = self.0;
1178        let inp = &inp[inp_l.start_offset()..];
1179        let k = &k[k_l.start_offset()..];
1180        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
1181        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
1182        let l_out = p.l_out();
1183
1184        // Output shape: [b_size, c_out, l_out].
1185        let dst_elems = p.c_out * l_out * p.b_size;
1186        let dst = vec![T::zero(); dst_elems];
1187        let dst_s0 = p.c_out * l_out;
1188        let dst_s1 = l_out;
1189        let dst_s2 = 1;
1190
1191        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
1192        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.l_in];
1193        let cont_s0 = p.l_in * p.c_in;
1194        let cont_s1 = p.c_in;
1195        for b_idx in 0..p.b_size {
1196            for l_idx in 0..p.l_in {
1197                for c_idx in 0..p.c_in {
1198                    let src_idx = b_idx * inp_s0 + c_idx * inp_s1 + l_idx * inp_s2;
1199                    let dst_idx = b_idx * cont_s0 + l_idx * cont_s1 + c_idx;
1200                    inp_cont[dst_idx] = inp[src_idx]
1201                }
1202            }
1203        }
1204
1205        for k_idx in 0..p.k_size {
1206            (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1207                let k_cont = (0..p.c_in)
1208                    .map(|c_in_idx| k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_idx * k_s2])
1209                    .collect::<Vec<_>>();
1210                for b_idx in 0..p.b_size {
1211                    for l_idx in 0..p.l_in {
1212                        let out_idx = l_idx * p.stride + k_idx * p.dilation;
1213                        if out_idx < p.padding {
1214                            continue;
1215                        }
1216                        let out_idx = out_idx - p.padding;
1217                        if out_idx < l_out {
1218                            let inp_cont = &inp_cont[b_idx * cont_s0 + l_idx * cont_s1..];
1219                            let dst_idx = b_idx * dst_s0 + out_idx * dst_s2 + dst_c_idx * dst_s1;
1220                            let mut d = T::zero();
1221                            unsafe {
1222                                T::vec_dot(inp_cont.as_ptr(), k_cont.as_ptr(), &mut d, p.c_in)
1223                            }
1224                            let dst_p = dst.as_ptr();
1225                            // Safety: dst_idx are uniques per dst_c_idx which is used to
1226                            // parallelise the different tasks so no two threads can try to
1227                            // write at the same location.
1228                            unsafe {
1229                                let ptr = dst_p.add(dst_idx) as *mut T;
1230                                *ptr += d
1231                            }
1232                        }
1233                    }
1234                }
1235            })
1236        }
1237        Ok(dst)
1238    }
1239}
1240
1241struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
1242
1243impl Map2 for ConvTranspose2D<'_> {
1244    const OP: &'static str = "conv_transpose2d";
1245    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
1246        let p = self.0;
1247        let inp = &inp[inp_l.start_offset()..];
1248        let (inp_s0, inp_s1, inp_s2, inp_s3) = crate::shape::dims4(inp_l.stride())?;
1249        let k = &k[k_l.start_offset()..];
1250        let (k_s0, k_s1, k_s2, k_s3) = crate::shape::dims4(k_l.stride())?;
1251        let (out_h, out_w) = (p.out_h(), p.out_w());
1252
1253        // Output shape: [b_size, c_out, out_h, out_w].
1254        let dst = vec![T::zero(); p.b_size * p.c_out * out_h * out_w];
1255        let dst_s0 = p.c_out * out_h * out_w;
1256        let dst_s1 = out_h * out_w;
1257        let dst_s2 = out_w;
1258        let dst_s3 = 1;
1259
1260        // TODO: Avoid making this copy if `inp` already has the appropriate layout.
1261        let mut inp_cont = vec![T::zero(); p.b_size * p.c_in * p.i_h * p.i_w];
1262        let cont_s0 = p.i_h * p.i_w * p.c_in;
1263        let cont_s1 = p.i_w * p.c_in;
1264        let cont_s2 = p.c_in;
1265        for b_idx in 0..p.b_size {
1266            for h_idx in 0..p.i_h {
1267                for w_idx in 0..p.i_w {
1268                    for c_idx in 0..p.c_in {
1269                        let src_idx =
1270                            b_idx * inp_s0 + c_idx * inp_s1 + h_idx * inp_s2 + w_idx * inp_s3;
1271                        let dst_idx = b_idx * cont_s0 + h_idx * cont_s1 + w_idx * cont_s2 + c_idx;
1272                        inp_cont[dst_idx] = inp[src_idx]
1273                    }
1274                }
1275            }
1276        }
1277
1278        for k_y in 0..p.k_h {
1279            for k_x in 0..p.k_w {
1280                (0..p.c_out).into_par_iter().for_each(|dst_c_idx| {
1281                    let k_cont = (0..p.c_in)
1282                        .map(|c_in_idx| {
1283                            k[c_in_idx * k_s0 + dst_c_idx * k_s1 + k_y * k_s2 + k_x * k_s3]
1284                        })
1285                        .collect::<Vec<_>>();
1286                    for b_idx in 0..p.b_size {
1287                        for inp_y in 0..p.i_h {
1288                            for inp_x in 0..p.i_w {
1289                                let out_x = inp_x * p.stride + k_x * p.dilation;
1290                                let out_y = inp_y * p.stride + k_y * p.dilation;
1291                                if out_x < p.padding || out_y < p.padding {
1292                                    continue;
1293                                }
1294                                let out_x = out_x - p.padding;
1295                                let out_y = out_y - p.padding;
1296                                if out_x < out_w && out_y < out_h {
1297                                    let inp_cont = &inp_cont
1298                                        [b_idx * cont_s0 + inp_y * cont_s1 + inp_x * cont_s2..];
1299                                    let dst_idx = b_idx * dst_s0
1300                                        + out_y * dst_s2
1301                                        + out_x * dst_s3
1302                                        + dst_c_idx * dst_s1;
1303                                    let mut d = T::zero();
1304                                    unsafe {
1305                                        T::vec_dot(
1306                                            inp_cont.as_ptr(),
1307                                            k_cont.as_ptr(),
1308                                            &mut d,
1309                                            p.c_in,
1310                                        )
1311                                    }
1312                                    let dst_p = dst.as_ptr();
1313                                    // Safety: dst_idx are uniques per dst_c_idx which is used to
1314                                    // parallelise the different tasks so no two threads can try to
1315                                    // write at the same location.
1316                                    unsafe {
1317                                        let ptr = dst_p.add(dst_idx) as *mut T;
1318                                        *ptr += d
1319                                    }
1320                                }
1321                            }
1322                        }
1323                    }
1324                })
1325            }
1326        }
1327        Ok(dst)
1328    }
1329}
1330
1331struct MatMul((usize, usize, usize, usize));
1332
1333impl MatMul {
1334    fn striding_error(&self, lhs_l: &Layout, rhs_l: &Layout, msg: &'static str) -> Error {
1335        Error::MatMulUnexpectedStriding(Box::new(crate::error::MatMulUnexpectedStriding {
1336            lhs_l: lhs_l.clone(),
1337            rhs_l: rhs_l.clone(),
1338            bmnk: self.0,
1339            msg,
1340        }))
1341        .bt()
1342    }
1343
1344    fn ab_skip(&self, lhs_l: &Layout, rhs_l: &Layout) -> Result<(usize, usize)> {
1345        let lhs_stride = lhs_l.stride();
1346        let rhs_stride = rhs_l.stride();
1347        let rank = lhs_stride.len();
1348        let (_b, m, n, k) = self.0;
1349        let a_skip: usize = match lhs_stride[..rank - 2] {
1350            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
1351            [_, stride] if lhs_l.dims()[0] == 1 => stride,
1352            [stride, _] if lhs_l.dims()[1] == 1 => stride,
1353            [stride] => stride,
1354            [] => m * k,
1355            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
1356        };
1357        let b_skip: usize = match rhs_stride[..rank - 2] {
1358            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
1359            [_, stride] if rhs_l.dims()[0] == 1 => stride,
1360            [stride, _] if rhs_l.dims()[1] == 1 => stride,
1361            [stride] => stride,
1362            [] => n * k,
1363            _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
1364        };
1365        Ok((a_skip, b_skip))
1366    }
1367}
1368
1369impl Map2 for MatMul {
1370    const OP: &'static str = "mat_mul";
1371
1372    #[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
1373    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1374        &self,
1375        lhs: &[T],
1376        lhs_l: &Layout,
1377        rhs: &[T],
1378        rhs_l: &Layout,
1379    ) -> Result<Vec<T>> {
1380        use gemm::{gemm, Parallelism};
1381
1382        match T::DTYPE {
1383            DType::F16 | DType::F32 | DType::F64 => {}
1384            _ => Err(Error::UnsupportedDTypeForOp(T::DTYPE, "matmul").bt())?,
1385        }
1386
1387        let (b, m, n, k) = self.0;
1388        let lhs = &lhs[lhs_l.start_offset()..];
1389        let rhs = &rhs[rhs_l.start_offset()..];
1390
1391        let lhs_stride = lhs_l.stride();
1392        let rhs_stride = rhs_l.stride();
1393        let rank = lhs_stride.len();
1394        let lhs_cs = lhs_stride[rank - 1];
1395        let lhs_rs = lhs_stride[rank - 2];
1396
1397        let rhs_cs = rhs_stride[rank - 1];
1398        let rhs_rs = rhs_stride[rank - 2];
1399
1400        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1401        let c_skip: usize = m * n;
1402
1403        let dst_shape: Shape = (m, n).into();
1404        let dst_strides = dst_shape.stride_contiguous();
1405        let dst_rs = dst_strides[0];
1406        let dst_cs = dst_strides[1];
1407
1408        let mut dst = vec![T::zero(); b * m * n];
1409        let num_threads = crate::utils::get_num_threads();
1410        let parallelism = if num_threads > 1 {
1411            Parallelism::Rayon(num_threads)
1412        } else {
1413            Parallelism::None
1414        };
1415        let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
1416            // a_skip and c_skip should be updated but step is always 0 so
1417            // it wouldn't matter.
1418            (1, b * m, n, k)
1419        } else if a_skip == 0 && b_skip == n * k {
1420            (1, m, b * n, k)
1421        } else {
1422            (b, m, n, k)
1423        };
1424        for step in 0..b {
1425            let lhs_p = &lhs[step * a_skip..];
1426            let rhs_p = &rhs[step * b_skip..];
1427            let dst_p = &mut dst[step * c_skip..];
1428            unsafe {
1429                gemm(
1430                    /* m: usize = */ m,
1431                    /* n: usize = */ n,
1432                    /* k: usize = */ k,
1433                    /* dst: *mut T = */ dst_p.as_mut_ptr(),
1434                    /* dst_cs: isize = */ dst_cs as isize,
1435                    /* dst_rs: isize = */ dst_rs as isize,
1436                    /* read_dst: bool = */ false,
1437                    /* lhs: *const T = */ lhs_p.as_ptr(),
1438                    /* lhs_cs: isize = */ lhs_cs as isize,
1439                    /* lhs_rs: isize = */ lhs_rs as isize,
1440                    /* rhs: *const T = */ rhs_p.as_ptr(),
1441                    /* rhs_cs: isize = */ rhs_cs as isize,
1442                    /* rhs_rs: isize = */ rhs_rs as isize,
1443                    /* alpha: T = */ T::zero(),
1444                    /* beta: T = */ T::one(),
1445                    /* conj_dst: bool = */ false,
1446                    /* conj_lhs: bool = */ false,
1447                    /* conj_rhs: bool = */ false,
1448                    parallelism,
1449                )
1450            }
1451        }
1452        Ok(dst)
1453    }
1454
1455    #[cfg(feature = "accelerate")]
1456    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1457        &self,
1458        lhs: &[T],
1459        lhs_l: &Layout,
1460        rhs: &[T],
1461        rhs_l: &Layout,
1462    ) -> Result<Vec<T>> {
1463        let (b, m, n, k) = self.0;
1464        let lhs = &lhs[lhs_l.start_offset()..];
1465        let rhs = &rhs[rhs_l.start_offset()..];
1466
1467        let lhs_stride = lhs_l.stride();
1468        let rhs_stride = rhs_l.stride();
1469
1470        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1471        let c_skip: usize = m * n;
1472
1473        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1474        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1475        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1476        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1477
1478        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1479            (n as i32, b'N')
1480        } else if rhs_m1 == k && rhs_m2 == 1 {
1481            (k as i32, b'T')
1482        } else {
1483            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1484        };
1485        // The b tensor has dims batching, m, k (lhs)
1486        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1487            (k as i32, b'N')
1488        } else if lhs_m1 == m && lhs_m2 == 1 {
1489            (m as i32, b'T')
1490        } else {
1491            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1492        };
1493
1494        let mut dst = vec![T::zero(); b * m * n];
1495        match T::DTYPE {
1496            DType::F16 => {
1497                crate::bail!("the accelerate backend does not support f16 matmul")
1498            }
1499            DType::F32 => {
1500                for step in 0..b {
1501                    let lhs_p = &lhs[step * a_skip..];
1502                    let rhs_p = &rhs[step * b_skip..];
1503                    let dst_p = &mut dst[step * c_skip..];
1504                    unsafe {
1505                        let a = rhs_p.as_ptr() as *const f32;
1506                        let b = lhs_p.as_ptr() as *const f32;
1507                        let c = dst_p.as_mut_ptr() as *mut f32;
1508                        let a = std::slice::from_raw_parts(a, a_skip);
1509                        let b = std::slice::from_raw_parts(b, b_skip);
1510                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1511                        crate::accelerate::sgemm(
1512                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1513                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1514                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1515                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1516                        )
1517                    }
1518                }
1519            }
1520            DType::F64 => {
1521                for step in 0..b {
1522                    let lhs_p = &lhs[step * a_skip..];
1523                    let rhs_p = &rhs[step * b_skip..];
1524                    let dst_p = &mut dst[step * c_skip..];
1525                    unsafe {
1526                        let a = rhs_p.as_ptr() as *const f64;
1527                        let b = lhs_p.as_ptr() as *const f64;
1528                        let c = dst_p.as_mut_ptr() as *mut f64;
1529                        let a = std::slice::from_raw_parts(a, a_skip);
1530                        let b = std::slice::from_raw_parts(b, b_skip);
1531                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1532                        crate::accelerate::dgemm(
1533                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1534                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1535                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1536                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1537                        )
1538                    }
1539                }
1540            }
1541            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1542        }
1543        Ok(dst)
1544    }
1545
1546    #[cfg(feature = "mkl")]
1547    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
1548        &self,
1549        lhs: &[T],
1550        lhs_l: &Layout,
1551        rhs: &[T],
1552        rhs_l: &Layout,
1553    ) -> Result<Vec<T>> {
1554        let (b, m, n, k) = self.0;
1555        let lhs = &lhs[lhs_l.start_offset()..];
1556        let rhs = &rhs[rhs_l.start_offset()..];
1557
1558        let lhs_stride = lhs_l.stride();
1559        let rhs_stride = rhs_l.stride();
1560
1561        let (a_skip, b_skip) = self.ab_skip(lhs_l, rhs_l)?;
1562        let c_skip: usize = m * n;
1563
1564        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
1565        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
1566        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
1567        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
1568
1569        let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
1570            (n as i32, b'N')
1571        } else if rhs_m1 == k && rhs_m2 == 1 {
1572            (k as i32, b'T')
1573        } else {
1574            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
1575        };
1576        // The b tensor has dims batching, m, k (lhs)
1577        let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
1578            (k as i32, b'N')
1579        } else if lhs_m1 == m && lhs_m2 == 1 {
1580            (m as i32, b'T')
1581        } else {
1582            Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
1583        };
1584
1585        let mut dst = vec![T::zero(); b * m * n];
1586        match T::DTYPE {
1587            DType::F16 => {
1588                for step in 0..b {
1589                    let lhs_p = &lhs[step * a_skip..];
1590                    let rhs_p = &rhs[step * b_skip..];
1591                    let dst_p = &mut dst[step * c_skip..];
1592                    unsafe {
1593                        let a = rhs_p.as_ptr() as *const f16;
1594                        let b = lhs_p.as_ptr() as *const f16;
1595                        let c = dst_p.as_mut_ptr() as *mut f16;
1596                        let a = std::slice::from_raw_parts(a, a_skip);
1597                        let b = std::slice::from_raw_parts(b, b_skip);
1598                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1599                        crate::mkl::hgemm(
1600                            transa,
1601                            transb,
1602                            /* m= */ n as i32,
1603                            /* n= */ m as i32,
1604                            /* k= */ k as i32,
1605                            /* alpha= */ f16::ONE,
1606                            /* a= */ a,
1607                            /* lda= */ lda,
1608                            /* b= */ b,
1609                            /* ldb= */ ldb,
1610                            /* beta= */ f16::ZERO,
1611                            /* c= */ c,
1612                            /* ldc= */ n as i32,
1613                        )
1614                    }
1615                }
1616            }
1617            DType::F32 => {
1618                for step in 0..b {
1619                    let lhs_p = &lhs[step * a_skip..];
1620                    let rhs_p = &rhs[step * b_skip..];
1621                    let dst_p = &mut dst[step * c_skip..];
1622                    unsafe {
1623                        let a = rhs_p.as_ptr() as *const f32;
1624                        let b = lhs_p.as_ptr() as *const f32;
1625                        let c = dst_p.as_mut_ptr() as *mut f32;
1626                        let a = std::slice::from_raw_parts(a, a_skip);
1627                        let b = std::slice::from_raw_parts(b, b_skip);
1628                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1629                        crate::mkl::sgemm(
1630                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1631                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1632                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1633                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1634                        )
1635                    }
1636                }
1637            }
1638            DType::F64 => {
1639                for step in 0..b {
1640                    let lhs_p = &lhs[step * a_skip..];
1641                    let rhs_p = &rhs[step * b_skip..];
1642                    let dst_p = &mut dst[step * c_skip..];
1643                    unsafe {
1644                        let a = rhs_p.as_ptr() as *const f64;
1645                        let b = lhs_p.as_ptr() as *const f64;
1646                        let c = dst_p.as_mut_ptr() as *mut f64;
1647                        let a = std::slice::from_raw_parts(a, a_skip);
1648                        let b = std::slice::from_raw_parts(b, b_skip);
1649                        let c = std::slice::from_raw_parts_mut(c, c_skip);
1650                        crate::mkl::dgemm(
1651                            transa, transb, /* m= */ n as i32, /* n= */ m as i32,
1652                            /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
1653                            /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
1654                            /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
1655                        )
1656                    }
1657                }
1658            }
1659            dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
1660        }
1661        Ok(dst)
1662    }
1663}
1664
1665fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
1666    if v.is_sign_positive() {
1667        v
1668    } else {
1669        (v.exp() - T::one()) * alpha
1670    }
1671}
1672
1673impl CpuStorage {
1674    pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
1675        D::cpu_storage_as_slice(self)
1676    }
1677
1678    pub fn concat(storages: &[CpuStorage]) -> Result<CpuStorage> {
1679        let storage0 = &storages[0];
1680        let s = match storage0 {
1681            Self::U8(_) => {
1682                let storages = storages
1683                    .iter()
1684                    .map(|s| match s {
1685                        Self::U8(s) => Ok(s.as_slice()),
1686                        _ => crate::bail!("dtype mismatch"),
1687                    })
1688                    .collect::<Result<Vec<_>>>()?
1689                    .concat();
1690                Self::U8(storages)
1691            }
1692            Self::U32(_) => {
1693                let storages = storages
1694                    .iter()
1695                    .map(|s| match s {
1696                        Self::U32(s) => Ok(s.as_slice()),
1697                        _ => crate::bail!("dtype mismatch"),
1698                    })
1699                    .collect::<Result<Vec<_>>>()?
1700                    .concat();
1701                Self::U32(storages)
1702            }
1703            Self::I16(_) => {
1704                let storages = storages
1705                    .iter()
1706                    .map(|s| match s {
1707                        Self::I16(s) => Ok(s.as_slice()),
1708                        _ => crate::bail!("dtype mismatch"),
1709                    })
1710                    .collect::<Result<Vec<_>>>()?
1711                    .concat();
1712                Self::I16(storages)
1713            }
1714            Self::I32(_) => {
1715                let storages = storages
1716                    .iter()
1717                    .map(|s| match s {
1718                        Self::I32(s) => Ok(s.as_slice()),
1719                        _ => crate::bail!("dtype mismatch"),
1720                    })
1721                    .collect::<Result<Vec<_>>>()?
1722                    .concat();
1723                Self::I32(storages)
1724            }
1725            Self::I64(_) => {
1726                let storages = storages
1727                    .iter()
1728                    .map(|s| match s {
1729                        Self::I64(s) => Ok(s.as_slice()),
1730                        _ => crate::bail!("dtype mismatch"),
1731                    })
1732                    .collect::<Result<Vec<_>>>()?
1733                    .concat();
1734                Self::I64(storages)
1735            }
1736            Self::BF16(_) => {
1737                let storages = storages
1738                    .iter()
1739                    .map(|s| match s {
1740                        Self::BF16(s) => Ok(s.as_slice()),
1741                        _ => crate::bail!("dtype mismatch"),
1742                    })
1743                    .collect::<Result<Vec<_>>>()?
1744                    .concat();
1745                Self::BF16(storages)
1746            }
1747            Self::F16(_) => {
1748                let storages = storages
1749                    .iter()
1750                    .map(|s| match s {
1751                        Self::F16(s) => Ok(s.as_slice()),
1752                        _ => crate::bail!("dtype mismatch"),
1753                    })
1754                    .collect::<Result<Vec<_>>>()?
1755                    .concat();
1756                Self::F16(storages)
1757            }
1758            Self::F32(_) => {
1759                let storages = storages
1760                    .iter()
1761                    .map(|s| match s {
1762                        Self::F32(s) => Ok(s.as_slice()),
1763                        _ => crate::bail!("dtype mismatch"),
1764                    })
1765                    .collect::<Result<Vec<_>>>()?
1766                    .concat();
1767                Self::F32(storages)
1768            }
1769            Self::F64(_) => {
1770                let storages = storages
1771                    .iter()
1772                    .map(|s| match s {
1773                        Self::F64(s) => Ok(s.as_slice()),
1774                        _ => crate::bail!("dtype mismatch"),
1775                    })
1776                    .collect::<Result<Vec<_>>>()?
1777                    .concat();
1778                Self::F64(storages)
1779            }
1780            Self::F8E4M3(_) => {
1781                let storages = storages
1782                    .iter()
1783                    .map(|s| match s {
1784                        Self::F8E4M3(s) => Ok(s.as_slice()),
1785                        _ => crate::bail!("dtype mismatch"),
1786                    })
1787                    .collect::<Result<Vec<_>>>()?
1788                    .concat();
1789                Self::F8E4M3(storages)
1790            }
1791            Self::F6E2M3(_) => {
1792                let storages = storages
1793                    .iter()
1794                    .map(|s| match s {
1795                        Self::F6E2M3(s) => Ok(s.as_slice()),
1796                        _ => crate::bail!("dtype mismatch"),
1797                    })
1798                    .collect::<Result<Vec<_>>>()?
1799                    .concat();
1800                Self::F6E2M3(storages)
1801            }
1802            Self::F6E3M2(_) => {
1803                let storages = storages
1804                    .iter()
1805                    .map(|s| match s {
1806                        Self::F6E3M2(s) => Ok(s.as_slice()),
1807                        _ => crate::bail!("dtype mismatch"),
1808                    })
1809                    .collect::<Result<Vec<_>>>()?
1810                    .concat();
1811                Self::F6E3M2(storages)
1812            }
1813            Self::F4(_) => {
1814                let storages = storages
1815                    .iter()
1816                    .map(|s| match s {
1817                        Self::F4(s) => Ok(s.as_slice()),
1818                        _ => crate::bail!("dtype mismatch"),
1819                    })
1820                    .collect::<Result<Vec<_>>>()?
1821                    .concat();
1822                Self::F4(storages)
1823            }
1824            Self::F8E8M0(_) => {
1825                let storages = storages
1826                    .iter()
1827                    .map(|s| match s {
1828                        Self::F8E8M0(s) => Ok(s.as_slice()),
1829                        _ => crate::bail!("dtype mismatch"),
1830                    })
1831                    .collect::<Result<Vec<_>>>()?
1832                    .concat();
1833                Self::F8E8M0(storages)
1834            }
1835        };
1836        Ok(s)
1837    }
1838}
1839
1840impl BackendStorage for CpuStorage {
1841    type Device = CpuDevice;
1842
1843    fn dtype(&self) -> DType {
1844        match self {
1845            Self::U8(_) => DType::U8,
1846            Self::U32(_) => DType::U32,
1847            Self::I16(_) => DType::I16,
1848            Self::I32(_) => DType::I32,
1849            Self::I64(_) => DType::I64,
1850            Self::BF16(_) => DType::BF16,
1851            Self::F16(_) => DType::F16,
1852            Self::F32(_) => DType::F32,
1853            Self::F64(_) => DType::F64,
1854            Self::F8E4M3(_) => DType::F8E4M3,
1855            Self::F6E2M3(_) => DType::F6E2M3,
1856            Self::F6E3M2(_) => DType::F6E3M2,
1857            Self::F4(_) => DType::F4,
1858            Self::F8E8M0(_) => DType::F8E8M0,
1859        }
1860    }
1861
1862    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
1863        // TODO: find a way around the quadratic number of cases below.
1864        match (self, dtype) {
1865            (Self::U8(storage), DType::BF16) => {
1866                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1867                Ok(Self::BF16(data))
1868            }
1869            (Self::U32(storage), DType::BF16) => {
1870                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1871                Ok(Self::BF16(data))
1872            }
1873            (Self::I64(storage), DType::BF16) => {
1874                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
1875                Ok(Self::BF16(data))
1876            }
1877            (Self::BF16(storage), DType::BF16) => {
1878                let data = unary_map(storage, layout, |v| v);
1879                Ok(Self::BF16(data))
1880            }
1881            (Self::F16(storage), DType::BF16) => {
1882                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
1883                Ok(Self::BF16(data))
1884            }
1885            (Self::F32(storage), DType::BF16) => {
1886                let data = unary_map(storage, layout, bf16::from_f32);
1887                Ok(Self::BF16(data))
1888            }
1889            (Self::F64(storage), DType::BF16) => {
1890                let data = unary_map(storage, layout, bf16::from_f64);
1891                Ok(Self::BF16(data))
1892            }
1893            (Self::U8(storage), DType::F16) => {
1894                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1895                Ok(Self::F16(data))
1896            }
1897            (Self::U32(storage), DType::F16) => {
1898                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1899                Ok(Self::F16(data))
1900            }
1901            (Self::I64(storage), DType::F16) => {
1902                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
1903                Ok(Self::F16(data))
1904            }
1905            (Self::BF16(storage), DType::F16) => {
1906                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
1907                Ok(Self::F16(data))
1908            }
1909            (Self::F16(storage), DType::F16) => {
1910                let data = unary_map(storage, layout, |v| v);
1911                Ok(Self::F16(data))
1912            }
1913            (Self::F32(storage), DType::F16) => {
1914                let data = unary_map(storage, layout, f16::from_f32);
1915                Ok(Self::F16(data))
1916            }
1917            (Self::F64(storage), DType::F16) => {
1918                let data = unary_map(storage, layout, f16::from_f64);
1919                Ok(Self::F16(data))
1920            }
1921            (Self::U8(storage), DType::F32) => {
1922                let data = unary_map(storage, layout, |v| v as f32);
1923                Ok(Self::F32(data))
1924            }
1925            (Self::U32(storage), DType::F32) => {
1926                let data = unary_map(storage, layout, |v| v as f32);
1927                Ok(Self::F32(data))
1928            }
1929            (Self::I64(storage), DType::F32) => {
1930                let data = unary_map(storage, layout, |v| v as f32);
1931                Ok(Self::F32(data))
1932            }
1933            (Self::BF16(storage), DType::F32) => {
1934                let data = unary_map(storage, layout, |v| v.to_f32());
1935                Ok(Self::F32(data))
1936            }
1937            (Self::F16(storage), DType::F32) => {
1938                let data = unary_map(storage, layout, |v| v.to_f32());
1939                Ok(Self::F32(data))
1940            }
1941            (Self::F32(storage), DType::F32) => {
1942                let data = unary_map(storage, layout, |v| v);
1943                Ok(Self::F32(data))
1944            }
1945            (Self::F64(storage), DType::F32) => {
1946                let data = unary_map(storage, layout, |v| v as f32);
1947                Ok(Self::F32(data))
1948            }
1949            (Self::U8(storage), DType::U8) => {
1950                let data = unary_map(storage, layout, |v| v);
1951                Ok(Self::U8(data))
1952            }
1953            (Self::BF16(storage), DType::U8) => {
1954                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1955                Ok(Self::U8(data))
1956            }
1957            (Self::F16(storage), DType::U8) => {
1958                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
1959                Ok(Self::U8(data))
1960            }
1961            (Self::F32(storage), DType::U8) => {
1962                let data = unary_map(storage, layout, |v| v as u8);
1963                Ok(Self::U8(data))
1964            }
1965            (Self::F64(storage), DType::U8) => {
1966                let data = unary_map(storage, layout, |v| v as u8);
1967                Ok(Self::U8(data))
1968            }
1969            (Self::U32(storage), DType::U8) => {
1970                let data = unary_map(storage, layout, |v| v as u8);
1971                Ok(Self::U8(data))
1972            }
1973            (Self::I64(storage), DType::U8) => {
1974                let data = unary_map(storage, layout, |v| v as u8);
1975                Ok(Self::U8(data))
1976            }
1977            (Self::U8(storage), DType::U32) => {
1978                let data = unary_map(storage, layout, |v| v as u32);
1979                Ok(Self::U32(data))
1980            }
1981            (Self::U32(storage), DType::U32) => {
1982                let data = unary_map(storage, layout, |v| v);
1983                Ok(Self::U32(data))
1984            }
1985            (Self::I64(storage), DType::U32) => {
1986                let data = unary_map(storage, layout, |v| v as u32);
1987                Ok(Self::U32(data))
1988            }
1989            (Self::BF16(storage), DType::U32) => {
1990                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1991                Ok(Self::U32(data))
1992            }
1993            (Self::F16(storage), DType::U32) => {
1994                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
1995                Ok(Self::U32(data))
1996            }
1997            (Self::F32(storage), DType::U32) => {
1998                let data = unary_map(storage, layout, |v| v as u32);
1999                Ok(Self::U32(data))
2000            }
2001            (Self::F64(storage), DType::U32) => {
2002                let data = unary_map(storage, layout, |v| v as u32);
2003                Ok(Self::U32(data))
2004            }
2005            (Self::U8(storage), DType::I64) => {
2006                let data = unary_map(storage, layout, |v| v as i64);
2007                Ok(Self::I64(data))
2008            }
2009            (Self::U32(storage), DType::I64) => {
2010                let data = unary_map(storage, layout, |v| v as i64);
2011                Ok(Self::I64(data))
2012            }
2013            (Self::I64(storage), DType::I64) => {
2014                let data = unary_map(storage, layout, |v| v);
2015                Ok(Self::I64(data))
2016            }
2017            (Self::BF16(storage), DType::I64) => {
2018                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2019                Ok(Self::I64(data))
2020            }
2021            (Self::F16(storage), DType::I64) => {
2022                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2023                Ok(Self::I64(data))
2024            }
2025            (Self::F32(storage), DType::I64) => {
2026                let data = unary_map(storage, layout, |v| v as i64);
2027                Ok(Self::I64(data))
2028            }
2029            (Self::F64(storage), DType::I64) => {
2030                let data = unary_map(storage, layout, |v| v as i64);
2031                Ok(Self::I64(data))
2032            }
2033            (Self::U8(storage), DType::F64) => {
2034                let data = unary_map(storage, layout, |v| v as f64);
2035                Ok(Self::F64(data))
2036            }
2037            (Self::U32(storage), DType::F64) => {
2038                let data = unary_map(storage, layout, |v| v as f64);
2039                Ok(Self::F64(data))
2040            }
2041            (Self::I64(storage), DType::F64) => {
2042                let data = unary_map(storage, layout, |v| v as f64);
2043                Ok(Self::F64(data))
2044            }
2045            (Self::BF16(storage), DType::F64) => {
2046                let data = unary_map(storage, layout, |v| v.to_f64());
2047                Ok(Self::F64(data))
2048            }
2049            (Self::F16(storage), DType::F64) => {
2050                let data = unary_map(storage, layout, |v| v.to_f64());
2051                Ok(Self::F64(data))
2052            }
2053            (Self::F32(storage), DType::F64) => {
2054                let data = unary_map(storage, layout, |v| v as f64);
2055                Ok(Self::F64(data))
2056            }
2057            (Self::F64(storage), DType::F64) => {
2058                let data = unary_map(storage, layout, |v| v);
2059                Ok(Self::F64(data))
2060            }
2061            // Conversions to F8E4M3
2062            (Self::U8(storage), DType::F8E4M3) => {
2063                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2064                Ok(Self::F8E4M3(data))
2065            }
2066            (Self::U32(storage), DType::F8E4M3) => {
2067                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2068                Ok(Self::F8E4M3(data))
2069            }
2070            (Self::I64(storage), DType::F8E4M3) => {
2071                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2072                Ok(Self::F8E4M3(data))
2073            }
2074            (Self::BF16(storage), DType::F8E4M3) => {
2075                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
2076                Ok(Self::F8E4M3(data))
2077            }
2078            (Self::F16(storage), DType::F8E4M3) => {
2079                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
2080                Ok(Self::F8E4M3(data))
2081            }
2082            (Self::F32(storage), DType::F8E4M3) => {
2083                let data = unary_map(storage, layout, F8E4M3::from_f32);
2084                Ok(Self::F8E4M3(data))
2085            }
2086            (Self::F64(storage), DType::F8E4M3) => {
2087                let data = unary_map(storage, layout, F8E4M3::from_f64);
2088                Ok(Self::F8E4M3(data))
2089            }
2090            (Self::F8E4M3(storage), DType::F8E4M3) => {
2091                let data = unary_map(storage, layout, |v| v);
2092                Ok(Self::F8E4M3(data))
2093            }
2094            // Conversions from F8E4M3
2095            (Self::F8E4M3(storage), DType::U8) => {
2096                let data = unary_map(storage, layout, |v| v.to_f32() as u8);
2097                Ok(Self::U8(data))
2098            }
2099            (Self::F8E4M3(storage), DType::U32) => {
2100                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
2101                Ok(Self::U32(data))
2102            }
2103            (Self::F8E4M3(storage), DType::I64) => {
2104                let data = unary_map(storage, layout, |v| v.to_f32() as i64);
2105                Ok(Self::I64(data))
2106            }
2107            (Self::F8E4M3(storage), DType::BF16) => {
2108                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
2109                Ok(Self::BF16(data))
2110            }
2111            (Self::F8E4M3(storage), DType::F16) => {
2112                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
2113                Ok(Self::F16(data))
2114            }
2115            (Self::F8E4M3(storage), DType::F32) => {
2116                let data = unary_map(storage, layout, |v| v.to_f32());
2117                Ok(Self::F32(data))
2118            }
2119            (Self::F8E4M3(storage), DType::F64) => {
2120                let data = unary_map(storage, layout, |v| v.to_f64());
2121                Ok(Self::F64(data))
2122            }
2123            // Conversions to I16
2124            (Self::U8(storage), DType::I16) => {
2125                let data = unary_map(storage, layout, |v| v as i16);
2126                Ok(Self::I16(data))
2127            }
2128            (Self::U32(storage), DType::I16) => {
2129                let data = unary_map(storage, layout, |v| v as i16);
2130                Ok(Self::I16(data))
2131            }
2132            (Self::I16(storage), DType::I16) => {
2133                let data = unary_map(storage, layout, |v| v);
2134                Ok(Self::I16(data))
2135            }
2136            (Self::I32(storage), DType::I16) => {
2137                let data = unary_map(storage, layout, |v| v as i16);
2138                Ok(Self::I16(data))
2139            }
2140            (Self::I64(storage), DType::I16) => {
2141                let data = unary_map(storage, layout, |v| v as i16);
2142                Ok(Self::I16(data))
2143            }
2144            (Self::BF16(storage), DType::I16) => {
2145                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2146                Ok(Self::I16(data))
2147            }
2148            (Self::F16(storage), DType::I16) => {
2149                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2150                Ok(Self::I16(data))
2151            }
2152            (Self::F32(storage), DType::I16) => {
2153                let data = unary_map(storage, layout, |v| v as i16);
2154                Ok(Self::I16(data))
2155            }
2156            (Self::F64(storage), DType::I16) => {
2157                let data = unary_map(storage, layout, |v| v as i16);
2158                Ok(Self::I16(data))
2159            }
2160            (Self::F8E4M3(storage), DType::I16) => {
2161                let data = unary_map(storage, layout, |v| v.to_f32() as i16);
2162                Ok(Self::I16(data))
2163            }
2164            // Conversions to I32
2165            (Self::U8(storage), DType::I32) => {
2166                let data = unary_map(storage, layout, |v| v as i32);
2167                Ok(Self::I32(data))
2168            }
2169            (Self::U32(storage), DType::I32) => {
2170                let data = unary_map(storage, layout, |v| v as i32);
2171                Ok(Self::I32(data))
2172            }
2173            (Self::I16(storage), DType::I32) => {
2174                let data = unary_map(storage, layout, |v| v as i32);
2175                Ok(Self::I32(data))
2176            }
2177            (Self::I32(storage), DType::I32) => {
2178                let data = unary_map(storage, layout, |v| v);
2179                Ok(Self::I32(data))
2180            }
2181            (Self::I64(storage), DType::I32) => {
2182                let data = unary_map(storage, layout, |v| v as i32);
2183                Ok(Self::I32(data))
2184            }
2185            (Self::BF16(storage), DType::I32) => {
2186                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2187                Ok(Self::I32(data))
2188            }
2189            (Self::F16(storage), DType::I32) => {
2190                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2191                Ok(Self::I32(data))
2192            }
2193            (Self::F32(storage), DType::I32) => {
2194                let data = unary_map(storage, layout, |v| v as i32);
2195                Ok(Self::I32(data))
2196            }
2197            (Self::F64(storage), DType::I32) => {
2198                let data = unary_map(storage, layout, |v| v as i32);
2199                Ok(Self::I32(data))
2200            }
2201            (Self::F8E4M3(storage), DType::I32) => {
2202                let data = unary_map(storage, layout, |v| v.to_f32() as i32);
2203                Ok(Self::I32(data))
2204            }
2205            // Conversions from I16
2206            (Self::I16(storage), DType::U8) => {
2207                let data = unary_map(storage, layout, |v| v as u8);
2208                Ok(Self::U8(data))
2209            }
2210            (Self::I16(storage), DType::U32) => {
2211                let data = unary_map(storage, layout, |v| v as u32);
2212                Ok(Self::U32(data))
2213            }
2214            (Self::I16(storage), DType::I64) => {
2215                let data = unary_map(storage, layout, |v| v as i64);
2216                Ok(Self::I64(data))
2217            }
2218            (Self::I16(storage), DType::BF16) => {
2219                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2220                Ok(Self::BF16(data))
2221            }
2222            (Self::I16(storage), DType::F16) => {
2223                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2224                Ok(Self::F16(data))
2225            }
2226            (Self::I16(storage), DType::F32) => {
2227                let data = unary_map(storage, layout, |v| v as f32);
2228                Ok(Self::F32(data))
2229            }
2230            (Self::I16(storage), DType::F64) => {
2231                let data = unary_map(storage, layout, |v| v as f64);
2232                Ok(Self::F64(data))
2233            }
2234            (Self::I16(storage), DType::F8E4M3) => {
2235                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2236                Ok(Self::F8E4M3(data))
2237            }
2238            // Conversions from I32
2239            (Self::I32(storage), DType::U8) => {
2240                let data = unary_map(storage, layout, |v| v as u8);
2241                Ok(Self::U8(data))
2242            }
2243            (Self::I32(storage), DType::U32) => {
2244                let data = unary_map(storage, layout, |v| v as u32);
2245                Ok(Self::U32(data))
2246            }
2247            (Self::I32(storage), DType::I64) => {
2248                let data = unary_map(storage, layout, |v| v as i64);
2249                Ok(Self::I64(data))
2250            }
2251            (Self::I32(storage), DType::BF16) => {
2252                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
2253                Ok(Self::BF16(data))
2254            }
2255            (Self::I32(storage), DType::F16) => {
2256                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
2257                Ok(Self::F16(data))
2258            }
2259            (Self::I32(storage), DType::F32) => {
2260                let data = unary_map(storage, layout, |v| v as f32);
2261                Ok(Self::F32(data))
2262            }
2263            (Self::I32(storage), DType::F64) => {
2264                let data = unary_map(storage, layout, |v| v as f64);
2265                Ok(Self::F64(data))
2266            }
2267            (Self::I32(storage), DType::F8E4M3) => {
2268                let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
2269                Ok(Self::F8E4M3(data))
2270            }
2271            // Dummy types - return error for all conversions to/from dummy types
2272            (_, DType::F6E2M3) | (_, DType::F6E3M2) | (_, DType::F4) | (_, DType::F8E8M0) => {
2273                Err(Error::UnsupportedDTypeForOp(dtype, "to_dtype").bt())
2274            }
2275            (Self::F6E2M3(_), _)
2276            | (Self::F6E3M2(_), _)
2277            | (Self::F4(_), _)
2278            | (Self::F8E8M0(_), _) => {
2279                Err(Error::UnsupportedDTypeForOp(self.dtype(), "to_dtype").bt())
2280            }
2281        }
2282    }
2283
2284    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
2285        match op {
2286            ReduceOp::Sum => {
2287                let src_dims = layout.dims();
2288                let mut dst_dims = src_dims.to_vec();
2289                for &dim in reduce_dims.iter() {
2290                    dst_dims[dim] = 1;
2291                }
2292                let dst_shape = Shape::from(dst_dims);
2293                let mut reduce_dims = reduce_dims.to_vec();
2294                // Sort the reduce_dims as they have to be processed from left to right when converting the
2295                // indexes.
2296                reduce_dims.sort();
2297                let reduce_dims_and_stride: Vec<_> = reduce_dims
2298                    .iter()
2299                    .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
2300                    .collect();
2301                ReduceSum {
2302                    dst_shape: &dst_shape,
2303                    reduce_dims: &reduce_dims,
2304                    reduce_dims_and_stride,
2305                }
2306                .map(self, layout)
2307            }
2308            ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
2309                let reduce_dim_index = match reduce_dims {
2310                    [reduce_dim_index] => *reduce_dim_index,
2311                    _ => {
2312                        let op = match op {
2313                            ReduceOp::Min => "min",
2314                            ReduceOp::ArgMin => "argmin",
2315                            ReduceOp::Max => "max",
2316                            ReduceOp::ArgMax => "argmax",
2317                            _ => unreachable!(),
2318                        };
2319                        let dims = reduce_dims.to_vec();
2320                        Err(Error::OnlySingleDimension { op, dims })?
2321                    }
2322                };
2323                let (use_min, return_index) = match op {
2324                    ReduceOp::Min => (true, false),
2325                    ReduceOp::ArgMin => (true, true),
2326                    ReduceOp::Max => (false, false),
2327                    ReduceOp::ArgMax => (false, true),
2328                    _ => unreachable!(),
2329                };
2330                ReduceIndex {
2331                    reduce_dim_index,
2332                    use_min,
2333                    return_index,
2334                }
2335                .map(self, layout)
2336            }
2337        }
2338    }
2339
2340    fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
2341        Cmp(op).map(self, lhs_l, rhs, rhs_l)
2342    }
2343
2344    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
2345        Affine(mul, add).map(self, layout)
2346    }
2347
2348    fn avg_pool2d(
2349        &self,
2350        layout: &Layout,
2351        kernel_size: (usize, usize),
2352        stride: (usize, usize),
2353    ) -> Result<Self> {
2354        AvgPool2D(kernel_size, stride).map(self, layout)
2355    }
2356
2357    fn max_pool2d(
2358        &self,
2359        layout: &Layout,
2360        kernel_size: (usize, usize),
2361        stride: (usize, usize),
2362    ) -> Result<Self> {
2363        MaxPool2D(kernel_size, stride).map(self, layout)
2364    }
2365
2366    fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
2367        UpsampleNearest1D(sz).map(self, layout)
2368    }
2369
2370    fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
2371        UpsampleNearest2D(h, w).map(self, layout)
2372    }
2373
2374    fn upsample_bilinear2d(
2375        &self,
2376        layout: &Layout,
2377        h: usize,
2378        w: usize,
2379        align_corners: bool,
2380        scale_h: Option<f64>,
2381        scale_w: Option<f64>,
2382    ) -> Result<Self> {
2383        UpsampleBilinear2D {
2384            target_h: h,
2385            target_w: w,
2386            align_corners,
2387            scale_h_factor: scale_h,
2388            scale_w_factor: scale_w,
2389        }
2390        .map(self, layout)
2391    }
2392
2393    fn powf(&self, layout: &Layout, e: f64) -> Result<Self> {
2394        use num_traits::Float;
2395        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
2396        match self {
2397            Self::BF16(storage) => {
2398                let data = unary_map(storage, layout, |v| v.powf(bf16::from_f64(e)));
2399                Ok(Self::BF16(data))
2400            }
2401            Self::F16(storage) => {
2402                let data = unary_map(storage, layout, |v| v.powf(f16::from_f64(e)));
2403                Ok(Self::F16(data))
2404            }
2405            Self::F32(storage) => {
2406                let data = unary_map(storage, layout, |v| v.powf(e as f32));
2407                Ok(Self::F32(data))
2408            }
2409            Self::F64(storage) => {
2410                let data = unary_map(storage, layout, |v| v.powf(e));
2411                Ok(Self::F64(data))
2412            }
2413            Self::F8E4M3(storage) => {
2414                let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
2415                Ok(Self::F8E4M3(data))
2416            }
2417            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "powf").bt()),
2418            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "powf").bt()),
2419            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "powf").bt()),
2420            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "powf").bt()),
2421            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "powf").bt()),
2422            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "powf").bt()),
2423            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "powf").bt()),
2424            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "powf").bt()),
2425            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "powf").bt()),
2426        }
2427    }
2428
2429    fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
2430        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
2431        match self {
2432            Self::BF16(storage) => {
2433                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
2434                Ok(Self::BF16(data))
2435            }
2436            Self::F16(storage) => {
2437                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
2438                Ok(Self::F16(data))
2439            }
2440            Self::F32(storage) => {
2441                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
2442                Ok(Self::F32(data))
2443            }
2444            Self::F64(storage) => {
2445                let data = unary_map(storage, layout, |v| elu(v, alpha));
2446                Ok(Self::F64(data))
2447            }
2448            Self::F8E4M3(storage) => {
2449                let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
2450                Ok(Self::F8E4M3(data))
2451            }
2452            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
2453            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
2454            Self::I16(_) => Err(Error::UnsupportedDTypeForOp(DType::I16, "elu").bt()),
2455            Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
2456            Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
2457            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "elu").bt()),
2458            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "elu").bt()),
2459            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "elu").bt()),
2460            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "elu").bt()),
2461        }
2462    }
2463
2464    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
2465        match self {
2466            Self::BF16(storage) => {
2467                let data = unary_map_vec(storage, layout, B::bf16, B::bf16_vec);
2468                Ok(Self::BF16(data))
2469            }
2470            Self::F16(storage) => {
2471                let data = unary_map_vec(storage, layout, B::f16, B::f16_vec);
2472                Ok(Self::F16(data))
2473            }
2474            Self::F32(storage) => {
2475                let data = unary_map_vec(storage, layout, B::f32, B::f32_vec);
2476                Ok(Self::F32(data))
2477            }
2478            Self::F64(storage) => {
2479                let data = unary_map_vec(storage, layout, B::f64, B::f64_vec);
2480                Ok(Self::F64(data))
2481            }
2482            Self::U8(storage) => {
2483                let data = unary_map(storage, layout, B::u8);
2484                Ok(Self::U8(data))
2485            }
2486            Self::U32(storage) => {
2487                let data = unary_map(storage, layout, B::u32);
2488                Ok(Self::U32(data))
2489            }
2490            Self::I16(storage) => {
2491                let data = unary_map(storage, layout, B::i16);
2492                Ok(Self::I16(data))
2493            }
2494            Self::I32(storage) => {
2495                let data = unary_map(storage, layout, B::i32);
2496                Ok(Self::I32(data))
2497            }
2498            Self::I64(storage) => {
2499                let data = unary_map(storage, layout, B::i64);
2500                Ok(Self::I64(data))
2501            }
2502            Self::F8E4M3(storage) => {
2503                let data = unary_map(storage, layout, B::f8e4m3);
2504                Ok(Self::F8E4M3(data))
2505            }
2506            Self::F6E2M3(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E2M3, "unary").bt()),
2507            Self::F6E3M2(_) => Err(Error::UnsupportedDTypeForOp(DType::F6E3M2, "unary").bt()),
2508            Self::F4(_) => Err(Error::UnsupportedDTypeForOp(DType::F4, "unary").bt()),
2509            Self::F8E8M0(_) => Err(Error::UnsupportedDTypeForOp(DType::F8E8M0, "unary").bt()),
2510        }
2511    }
2512
2513    fn binary_impl<B: BinaryOpT>(
2514        &self,
2515        rhs: &Self,
2516        lhs_l: &Layout,
2517        rhs_l: &Layout,
2518    ) -> Result<Self> {
2519        match (self, rhs) {
2520            (Self::BF16(lhs), Self::BF16(rhs)) => {
2521                let data = binary_map_vec(
2522                    lhs_l,
2523                    rhs_l,
2524                    lhs,
2525                    rhs,
2526                    B::bf16,
2527                    B::bf16_vec,
2528                    B::bf16_scalar_vec,
2529                );
2530                Ok(Self::BF16(data))
2531            }
2532            (Self::F16(lhs), Self::F16(rhs)) => {
2533                let data = binary_map_vec(
2534                    lhs_l,
2535                    rhs_l,
2536                    lhs,
2537                    rhs,
2538                    B::f16,
2539                    B::f16_vec,
2540                    B::f16_scalar_vec,
2541                );
2542                Ok(Self::F16(data))
2543            }
2544            (Self::F32(lhs), Self::F32(rhs)) => {
2545                let data = binary_map_vec(
2546                    lhs_l,
2547                    rhs_l,
2548                    lhs,
2549                    rhs,
2550                    B::f32,
2551                    B::f32_vec,
2552                    B::f32_scalar_vec,
2553                );
2554                Ok(Self::F32(data))
2555            }
2556            (Self::F64(lhs), Self::F64(rhs)) => {
2557                let data = binary_map_vec(
2558                    lhs_l,
2559                    rhs_l,
2560                    lhs,
2561                    rhs,
2562                    B::f64,
2563                    B::f64_vec,
2564                    B::f64_scalar_vec,
2565                );
2566                Ok(Self::F64(data))
2567            }
2568            (Self::U32(lhs), Self::U32(rhs)) => {
2569                let data = binary_map_vec(
2570                    lhs_l,
2571                    rhs_l,
2572                    lhs,
2573                    rhs,
2574                    B::u32,
2575                    B::u32_vec,
2576                    B::u32_scalar_vec,
2577                );
2578                Ok(Self::U32(data))
2579            }
2580            (Self::I16(lhs), Self::I16(rhs)) => {
2581                let data = binary_map_vec(
2582                    lhs_l,
2583                    rhs_l,
2584                    lhs,
2585                    rhs,
2586                    B::i16,
2587                    B::i16_vec,
2588                    B::i16_scalar_vec,
2589                );
2590                Ok(Self::I16(data))
2591            }
2592            (Self::I32(lhs), Self::I32(rhs)) => {
2593                let data = binary_map_vec(
2594                    lhs_l,
2595                    rhs_l,
2596                    lhs,
2597                    rhs,
2598                    B::i32,
2599                    B::i32_vec,
2600                    B::i32_scalar_vec,
2601                );
2602                Ok(Self::I32(data))
2603            }
2604            (Self::I64(lhs), Self::I64(rhs)) => {
2605                let data = binary_map_vec(
2606                    lhs_l,
2607                    rhs_l,
2608                    lhs,
2609                    rhs,
2610                    B::i64,
2611                    B::i64_vec,
2612                    B::i64_scalar_vec,
2613                );
2614                Ok(Self::I64(data))
2615            }
2616            (Self::U8(lhs), Self::U8(rhs)) => {
2617                let data =
2618                    binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::u8, B::u8_vec, B::u8_scalar_vec);
2619                Ok(Self::U8(data))
2620            }
2621            (Self::F8E4M3(lhs), Self::F8E4M3(rhs)) => {
2622                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f8e4m3);
2623                Ok(Self::F8E4M3(data))
2624            }
2625            _ => {
2626                // This should be covered by the dtype check above.
2627                Err(Error::DTypeMismatchBinaryOp {
2628                    lhs: self.dtype(),
2629                    rhs: rhs.dtype(),
2630                    op: B::NAME,
2631                }
2632                .bt())
2633            }
2634        }
2635    }
2636
2637    fn copy2d(
2638        &self,
2639        dst: &mut Self,
2640        d1: usize,
2641        d2: usize,
2642        src_s: usize,
2643        dst_s: usize,
2644        src_o: usize,
2645        dst_o: usize,
2646    ) -> Result<()> {
2647        match (self, dst) {
2648            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
2649            (Self::U32(src), Self::U32(dst)) => {
2650                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2651            }
2652            (Self::I16(src), Self::I16(dst)) => {
2653                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2654            }
2655            (Self::I32(src), Self::I32(dst)) => {
2656                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2657            }
2658            (Self::I64(src), Self::I64(dst)) => {
2659                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2660            }
2661            (Self::BF16(src), Self::BF16(dst)) => {
2662                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2663            }
2664            (Self::F16(src), Self::F16(dst)) => {
2665                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2666            }
2667            (Self::F32(src), Self::F32(dst)) => {
2668                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2669            }
2670            (Self::F64(src), Self::F64(dst)) => {
2671                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2672            }
2673            (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
2674                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2675            }
2676            (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
2677                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2678            }
2679            (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
2680                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2681            }
2682            (Self::F4(src), Self::F4(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
2683            (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
2684                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
2685            }
2686            (_, dst) => {
2687                return Err(Error::DTypeMismatchBinaryOp {
2688                    lhs: self.dtype(),
2689                    rhs: dst.dtype(),
2690                    op: "copy2d",
2691                }
2692                .bt());
2693            }
2694        }
2695        Ok(())
2696    }
2697
2698    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
2699        match (self, dst) {
2700            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2701            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2702            (Self::I16(src), Self::I16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2703            (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2704            (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2705            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2706            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2707            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2708            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2709            (Self::F8E4M3(src), Self::F8E4M3(dst)) => {
2710                copy_strided_src_(src, dst, dst_offset, src_l)
2711            }
2712            (Self::F6E2M3(src), Self::F6E2M3(dst)) => {
2713                copy_strided_src_(src, dst, dst_offset, src_l)
2714            }
2715            (Self::F6E3M2(src), Self::F6E3M2(dst)) => {
2716                copy_strided_src_(src, dst, dst_offset, src_l)
2717            }
2718            (Self::F4(src), Self::F4(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
2719            (Self::F8E8M0(src), Self::F8E8M0(dst)) => {
2720                copy_strided_src_(src, dst, dst_offset, src_l)
2721            }
2722            (_, dst) => {
2723                // This should be covered by the dtype check above.
2724                return Err(Error::DTypeMismatchBinaryOp {
2725                    lhs: self.dtype(),
2726                    rhs: dst.dtype(),
2727                    op: "copy_strided",
2728                }
2729                .bt());
2730            }
2731        }
2732        Ok(())
2733    }
2734
2735    fn where_cond(
2736        &self,
2737        layout: &Layout,
2738        t: &Self,
2739        t_l: &Layout,
2740        f: &Self,
2741        f_l: &Layout,
2742    ) -> Result<Self> {
2743        match self {
2744            Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2745            Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2746            Self::I16(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2747            Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2748            Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
2749            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
2750        }
2751    }
2752
2753    fn conv1d(
2754        &self,
2755        l: &Layout,
2756        kernel: &Self,
2757        kernel_l: &Layout,
2758        params: &crate::conv::ParamsConv1D,
2759    ) -> Result<Self> {
2760        if !USE_IM2COL_CONV1D {
2761            return Conv1D(params).map(self, l, kernel, kernel_l);
2762        }
2763        let op = Im2Col1D {
2764            l_k: params.k_size,
2765            padding: params.padding,
2766            stride: params.stride,
2767            dilation: params.dilation,
2768        };
2769        let col = op.map(self, l)?;
2770        let b = params.b_size;
2771        let n = params.c_out;
2772        let l_out = params.l_out();
2773        let k = op.l_k * params.c_in;
2774        let m = l_out;
2775        let col_l = Layout::contiguous((b, m, k));
2776        let res = if kernel_l.is_contiguous() {
2777            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2778                .transpose(1, 2)?
2779                .broadcast_as((b, k, n))?;
2780            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2781        } else {
2782            // Make the kernel contiguous if not already the case.
2783            let mut kernel_c = unsafe {
2784                self.device()
2785                    .alloc_uninit(kernel_l.shape(), kernel.dtype())?
2786            };
2787            kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
2788            let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
2789                .transpose(1, 2)?
2790                .broadcast_as((b, k, n))?;
2791            col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
2792        };
2793        let res_l = Layout::contiguous((b, l_out, params.c_out)).transpose(1, 2)?;
2794        let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
2795        res.copy_strided_src(&mut res_t, 0, &res_l)?;
2796        Ok(res_t)
2797    }
2798
2799    fn conv_transpose1d(
2800        &self,
2801        l: &Layout,
2802        kernel: &Self,
2803        kernel_l: &Layout,
2804        params: &crate::conv::ParamsConvTranspose1D,
2805    ) -> Result<Self> {
2806        let can_use_col2im = kernel_l.is_contiguous()
2807            && params.dilation == 1
2808            && params.padding == 0
2809            && params.output_padding == 0;
2810        if USE_COL2IM_CONV1D_TR && can_use_col2im {
2811            let (b_size, c_in, l_in) = l.shape().dims3()?;
2812            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
2813            if !kernel_l.is_contiguous() {
2814                crate::bail!(
2815                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
2816                )
2817            }
2818            if c_in != c_in2 {
2819                crate::bail!(
2820                    "convtr1d: shape mismatch on c_in {:?} {:?}",
2821                    l.shape(),
2822                    kernel_l.shape()
2823                )
2824            }
2825            let col = {
2826                // This merges the last two dimensions of the kernel together.
2827                let kernel_l_mm = Layout::new(
2828                    (b_size, c_in, k_size * c_out).into(),
2829                    vec![0, k_size * c_out, 1],
2830                    kernel_l.start_offset(),
2831                );
2832                self.matmul(
2833                    kernel,
2834                    (
2835                        b_size,
2836                        /* m */ l_in,
2837                        /* n */ c_out * k_size,
2838                        /* k */ c_in,
2839                    ),
2840                    &l.transpose(1, 2)?,
2841                    &kernel_l_mm,
2842                )?
2843            };
2844            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
2845            Col2Im1D {
2846                stride: params.stride,
2847            }
2848            .map(&col, &col_l)
2849        } else {
2850            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
2851        }
2852    }
2853
2854    fn conv2d(
2855        &self,
2856        l: &Layout,
2857        kernel: &Self,
2858        kernel_l: &Layout,
2859        params: &crate::conv::ParamsConv2D,
2860    ) -> Result<Self> {
2861        Conv2D(params).map(self, l, kernel, kernel_l)
2862    }
2863
2864    fn conv_transpose2d(
2865        &self,
2866        l: &Layout,
2867        kernel: &Self,
2868        kernel_l: &Layout,
2869        params: &crate::conv::ParamsConvTranspose2D,
2870    ) -> Result<Self> {
2871        ConvTranspose2D(params).map(self, l, kernel, kernel_l)
2872    }
2873
2874    fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
2875        match ids {
2876            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2877            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2878            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
2879            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
2880        }
2881    }
2882
2883    fn gather(&self, l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
2884        match ids {
2885            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
2886            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
2887            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
2888            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
2889        }
2890    }
2891
2892    fn scatter_set(
2893        &mut self,
2894        l: &Layout,
2895        ids: &Self,
2896        ids_l: &Layout,
2897        src: &Self,
2898        src_l: &Layout,
2899        dim: usize,
2900    ) -> Result<()> {
2901        match ids {
2902            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2903            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2904            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
2905            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
2906        }
2907    }
2908
2909    fn scatter_add_set(
2910        &mut self,
2911        l: &Layout,
2912        ids: &Self,
2913        ids_l: &Layout,
2914        src: &Self,
2915        src_l: &Layout,
2916        dim: usize,
2917    ) -> Result<()> {
2918        match ids {
2919            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2920            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2921            Self::I16(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2922            Self::I32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2923            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
2924            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
2925        }
2926    }
2927
2928    fn index_add(
2929        &self,
2930        l: &Layout,
2931        ids: &Self,
2932        ids_l: &Layout,
2933        src: &Self,
2934        src_l: &Layout,
2935        dim: usize,
2936    ) -> Result<Self> {
2937        match ids {
2938            Self::U8(ids) => {
2939                let ids = match ids_l.contiguous_offsets() {
2940                    Some((a, b)) => &ids[a..b],
2941                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2942                };
2943                IndexAdd { ids, dim }.map(self, l, src, src_l)
2944            }
2945            Self::U32(ids) => {
2946                let ids = match ids_l.contiguous_offsets() {
2947                    Some((a, b)) => &ids[a..b],
2948                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2949                };
2950                IndexAdd { ids, dim }.map(self, l, src, src_l)
2951            }
2952            Self::I16(ids) => {
2953                let ids = match ids_l.contiguous_offsets() {
2954                    Some((a, b)) => &ids[a..b],
2955                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2956                };
2957                IndexAdd { ids, dim }.map(self, l, src, src_l)
2958            }
2959            Self::I32(ids) => {
2960                let ids = match ids_l.contiguous_offsets() {
2961                    Some((a, b)) => &ids[a..b],
2962                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2963                };
2964                IndexAdd { ids, dim }.map(self, l, src, src_l)
2965            }
2966            Self::I64(ids) => {
2967                let ids = match ids_l.contiguous_offsets() {
2968                    Some((a, b)) => &ids[a..b],
2969                    None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
2970                };
2971                IndexAdd { ids, dim }.map(self, l, src, src_l)
2972            }
2973            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
2974        }
2975    }
2976
2977    fn matmul(
2978        &self,
2979        rhs: &Self,
2980        bmnk: (usize, usize, usize, usize),
2981        lhs_l: &Layout,
2982        rhs_l: &Layout,
2983    ) -> Result<Self> {
2984        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
2985    }
2986
2987    fn device(&self) -> &Self::Device {
2988        &CpuDevice
2989    }
2990
2991    fn try_clone(&self, _: &Layout) -> Result<Self> {
2992        Ok(self.clone())
2993    }
2994
2995    fn to_cpu_storage(&self) -> Result<CpuStorage> {
2996        Ok(self.clone())
2997    }
2998
2999    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
3000        use crate::scalar::Scalar;
3001        fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
3002            match l.strided_blocks() {
3003                crate::StridedBlocks::SingleBlock { start_offset, len } => {
3004                    src[start_offset..start_offset + len].fill(s)
3005                }
3006                crate::StridedBlocks::UniformBlocks {
3007                    start_offset,
3008                    block_len,
3009                    count,
3010                    src_stride,
3011                } => {
3012                    for i in 0..count {
3013                        let start = start_offset + i * src_stride;
3014                        src[start..start + block_len].fill(s)
3015                    }
3016                }
3017                crate::StridedBlocks::MultipleBlocks {
3018                    block_start_index,
3019                    block_len: 1,
3020                } => {
3021                    for src_index in block_start_index {
3022                        src[src_index] = s
3023                    }
3024                }
3025                crate::StridedBlocks::MultipleBlocks {
3026                    block_start_index,
3027                    block_len,
3028                } => {
3029                    for src_index in block_start_index {
3030                        src[src_index..src_index + block_len].fill(s)
3031                    }
3032                }
3033            }
3034        }
3035        match (self, s) {
3036            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
3037            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
3038            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
3039            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
3040            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
3041            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
3042            (Self::I16(storage), Scalar::I16(v)) => set(storage, l, v),
3043            (Self::I32(storage), Scalar::I32(v)) => set(storage, l, v),
3044            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
3045            (Self::F8E4M3(storage), Scalar::F8E4M3(v)) => set(storage, l, v),
3046            // Dummy types don't support scalar operations
3047            (Self::F6E2M3(_), _) => {
3048                crate::bail!("const_set not supported for dummy type F6E2M3")
3049            }
3050            (Self::F6E3M2(_), _) => {
3051                crate::bail!("const_set not supported for dummy type F6E3M2")
3052            }
3053            (Self::F4(_), _) => {
3054                crate::bail!("const_set not supported for dummy type F4")
3055            }
3056            (Self::F8E8M0(_), _) => {
3057                crate::bail!("const_set not supported for dummy type F8E8M0")
3058            }
3059            (st, s) => crate::bail!(
3060                "const_set dtype mismatch, expected {:?} but got {:?}",
3061                st.dtype(),
3062                s
3063            ),
3064        }
3065        Ok(())
3066    }
3067}
3068
3069impl BackendDevice for CpuDevice {
3070    type Storage = CpuStorage;
3071
3072    fn location(&self) -> crate::DeviceLocation {
3073        crate::DeviceLocation::Cpu
3074    }
3075
3076    fn same_device(&self, _: &Self) -> bool {
3077        true
3078    }
3079
3080    fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
3081        Ok(T::to_cpu_storage(s))
3082    }
3083
3084    fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
3085        Ok(s.clone())
3086    }
3087
3088    fn storage_from_cpu_storage_owned(&self, s: CpuStorage) -> Result<Self::Storage> {
3089        Ok(s)
3090    }
3091
3092    fn new(_: usize) -> Result<Self> {
3093        Ok(Self)
3094    }
3095
3096    fn set_seed(&self, _seed: u64) -> Result<()> {
3097        crate::bail!("cannot seed the CPU rng with set_seed")
3098    }
3099
3100    fn get_current_seed(&self) -> Result<u64> {
3101        crate::bail!("cannot get the CPU rng seed with get_current_seed")
3102    }
3103
3104    fn rand_uniform(&self, shape: &Shape, dtype: DType, min: f64, max: f64) -> Result<CpuStorage> {
3105        use rand::prelude::*;
3106
3107        let elem_count = shape.elem_count();
3108        let mut rng = rand::rng();
3109        match dtype {
3110            DType::U8
3111            | DType::U32
3112            | DType::I16
3113            | DType::I32
3114            | DType::I64
3115            | DType::F6E2M3
3116            | DType::F6E3M2
3117            | DType::F4
3118            | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()),
3119            DType::BF16 => {
3120                let mut data = Vec::with_capacity(elem_count);
3121                let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
3122                    .map_err(Error::wrap)?;
3123                for _i in 0..elem_count {
3124                    data.push(rng.sample::<bf16, _>(uniform))
3125                }
3126                Ok(CpuStorage::BF16(data))
3127            }
3128            DType::F16 => {
3129                let mut data = Vec::with_capacity(elem_count);
3130                let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
3131                    .map_err(Error::wrap)?;
3132                for _i in 0..elem_count {
3133                    data.push(rng.sample::<f16, _>(uniform))
3134                }
3135                Ok(CpuStorage::F16(data))
3136            }
3137            DType::F8E4M3 => {
3138                let mut data = Vec::with_capacity(elem_count);
3139                let uniform =
3140                    rand::distr::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max))
3141                        .map_err(Error::wrap)?;
3142                for _i in 0..elem_count {
3143                    data.push(rng.sample::<F8E4M3, _>(uniform))
3144                }
3145                Ok(CpuStorage::F8E4M3(data))
3146            }
3147            DType::F32 => {
3148                let mut data = Vec::with_capacity(elem_count);
3149                let uniform =
3150                    rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
3151                for _i in 0..elem_count {
3152                    data.push(rng.sample::<f32, _>(uniform))
3153                }
3154                Ok(CpuStorage::F32(data))
3155            }
3156            DType::F64 => {
3157                let mut data = Vec::with_capacity(elem_count);
3158                let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
3159                for _i in 0..elem_count {
3160                    data.push(rng.sample::<f64, _>(uniform))
3161                }
3162                Ok(CpuStorage::F64(data))
3163            }
3164        }
3165    }
3166
3167    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CpuStorage> {
3168        use rand::prelude::*;
3169
3170        let elem_count = shape.elem_count();
3171        let mut rng = rand::rng();
3172        match dtype {
3173            DType::U8
3174            | DType::U32
3175            | DType::I16
3176            | DType::I32
3177            | DType::I64
3178            | DType::F6E2M3
3179            | DType::F6E3M2
3180            | DType::F4
3181            | DType::F8E8M0 => Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()),
3182            DType::BF16 => {
3183                let mut data = Vec::with_capacity(elem_count);
3184                let normal = rand_distr::Normal::new(bf16::from_f64(mean), bf16::from_f64(std))
3185                    .map_err(Error::wrap)?;
3186                for _i in 0..elem_count {
3187                    data.push(normal.sample(&mut rng))
3188                }
3189                Ok(CpuStorage::BF16(data))
3190            }
3191            DType::F16 => {
3192                let mut data = Vec::with_capacity(elem_count);
3193                let normal = rand_distr::Normal::new(f16::from_f64(mean), f16::from_f64(std))
3194                    .map_err(Error::wrap)?;
3195                for _i in 0..elem_count {
3196                    data.push(normal.sample(&mut rng))
3197                }
3198                Ok(CpuStorage::F16(data))
3199            }
3200            DType::F8E4M3 => {
3201                let mut data = Vec::with_capacity(elem_count);
3202                let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
3203                    .map_err(Error::wrap)?;
3204                for _i in 0..elem_count {
3205                    data.push(normal.sample(&mut rng))
3206                }
3207                Ok(CpuStorage::F8E4M3(data))
3208            }
3209            DType::F32 => {
3210                let mut data = Vec::with_capacity(elem_count);
3211                let normal =
3212                    rand_distr::Normal::new(mean as f32, std as f32).map_err(Error::wrap)?;
3213                for _i in 0..elem_count {
3214                    data.push(normal.sample(&mut rng))
3215                }
3216                Ok(CpuStorage::F32(data))
3217            }
3218            DType::F64 => {
3219                let mut data = Vec::with_capacity(elem_count);
3220                let normal = rand_distr::Normal::new(mean, std).map_err(Error::wrap)?;
3221                for _i in 0..elem_count {
3222                    data.push(normal.sample(&mut rng))
3223                }
3224                Ok(CpuStorage::F64(data))
3225            }
3226        }
3227    }
3228
3229    #[allow(clippy::uninit_vec)]
3230    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
3231        let elem_count = shape.elem_count();
3232        // The code below is highly unsafe but hopefully not directly unsound as we only consider
3233        // types that are Copy, not Drop, and for which all bit patterns are proper values.
3234        // It's still pretty risky, see the following for more details:
3235        // https://github.com/rust-lang/rust-clippy/issues/4483
3236        let storage = match dtype {
3237            DType::U8 => {
3238                let mut v = Vec::with_capacity(elem_count);
3239                v.set_len(elem_count);
3240                CpuStorage::U8(v)
3241            }
3242            DType::U32 => {
3243                let mut v = Vec::with_capacity(elem_count);
3244                v.set_len(elem_count);
3245                CpuStorage::U32(v)
3246            }
3247            DType::I16 => {
3248                let mut v = Vec::with_capacity(elem_count);
3249                v.set_len(elem_count);
3250                CpuStorage::I16(v)
3251            }
3252            DType::I32 => {
3253                let mut v = Vec::with_capacity(elem_count);
3254                v.set_len(elem_count);
3255                CpuStorage::I32(v)
3256            }
3257            DType::I64 => {
3258                let mut v = Vec::with_capacity(elem_count);
3259                v.set_len(elem_count);
3260                CpuStorage::I64(v)
3261            }
3262            DType::BF16 => {
3263                let mut v = Vec::with_capacity(elem_count);
3264                v.set_len(elem_count);
3265                CpuStorage::BF16(v)
3266            }
3267            DType::F16 => {
3268                let mut v = Vec::with_capacity(elem_count);
3269                v.set_len(elem_count);
3270                CpuStorage::F16(v)
3271            }
3272            DType::F32 => {
3273                let mut v = Vec::with_capacity(elem_count);
3274                v.set_len(elem_count);
3275                CpuStorage::F32(v)
3276            }
3277            DType::F64 => {
3278                let mut v = Vec::with_capacity(elem_count);
3279                v.set_len(elem_count);
3280                CpuStorage::F64(v)
3281            }
3282            DType::F8E4M3 => {
3283                let mut v = Vec::with_capacity(elem_count);
3284                v.set_len(elem_count);
3285                CpuStorage::F8E4M3(v)
3286            }
3287            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
3288                return Err(Error::UnsupportedDTypeForOp(dtype, "alloc_uninit").bt())
3289            }
3290        };
3291        Ok(storage)
3292    }
3293
3294    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
3295        let elem_count = shape.elem_count();
3296        let storage = match dtype {
3297            DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
3298            DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
3299            DType::I16 => CpuStorage::I16(vec![0i16; elem_count]),
3300            DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
3301            DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
3302            DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
3303            DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
3304            DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
3305            DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
3306            DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
3307            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
3308                return Err(Error::UnsupportedDTypeForOp(dtype, "zeros").bt())
3309            }
3310        };
3311        Ok(storage)
3312    }
3313
3314    fn synchronize(&self) -> Result<()> {
3315        Ok(())
3316    }
3317}
3318
/// Applies `$fn` to the inner buffer of `$storage` when its variant is one of the
/// listed dtypes, and re-wraps the result in the same `CpuStorage` variant.
///
/// Any other variant produces an `UnsupportedDTypeForOp` error naming `$name`. That
/// error is propagated with `?`, so the macro must be invoked inside a function whose
/// error type can be built from the crate's `Error`.
///
/// Types are referenced through `$crate::` so that callers do not need `CpuStorage` or
/// `Error` in scope: this macro is exported. A trailing comma is accepted in the dtype
/// list, e.g. `(F32, F64,)`.
#[macro_export]
macro_rules! map_dtype {
    ($name:expr, $storage:ident, $fn:expr, ($($dtypes:ident),+ $(,)?)) => {
        match $storage {
            $($crate::CpuStorage::$dtypes(__e) => $crate::CpuStorage::$dtypes($fn(__e)),)*
            s => Err($crate::Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?,
        }
    };
}