candle_core/
storage.rs

1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::scalar::Scalar;
4use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
5use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
6
7// We do not want to implement Clone on Storage as cloning may fail because of
8// out of memory. Instead try_clone should be used.
9#[derive(Debug)]
10pub enum Storage {
11    Cpu(CpuStorage),
12    Cuda(CudaStorage),
13    Metal(MetalStorage),
14}
15
16impl Storage {
17    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
18        match self {
19            Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
20            Self::Cuda(storage) => {
21                let storage = storage.try_clone(layout)?;
22                Ok(Self::Cuda(storage))
23            }
24            Self::Metal(storage) => {
25                let storage = storage.try_clone(layout)?;
26                Ok(Self::Metal(storage))
27            }
28        }
29    }
30
31    pub fn device(&self) -> Device {
32        match self {
33            Self::Cpu(_) => Device::Cpu,
34            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
35            Self::Metal(storage) => Device::Metal(storage.device().clone()),
36        }
37    }
38
39    pub fn dtype(&self) -> DType {
40        match self {
41            Self::Cpu(storage) => storage.dtype(),
42            Self::Cuda(storage) => storage.dtype(),
43            Self::Metal(storage) => storage.dtype(),
44        }
45    }
46
47    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
48        let lhs_device = self.device();
49        let rhs_device = rhs.device();
50        let lhs = lhs_device.location();
51        let rhs = rhs_device.location();
52        let same_device = if self.device().is_metal() {
53            // On metal, we require the device to be exactly the same rather than
54            // having the same location. In cuda this is not necessary as all CudaDevice on the
55            // same GPU will use the same cuda stream.
56            lhs_device.same_device(&rhs_device)
57        } else {
58            lhs == rhs
59        };
60        if !same_device {
61            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
62        } else {
63            Ok(())
64        }
65    }
66
67    pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
68        let lhs = self.dtype();
69        let rhs = rhs.dtype();
70        if lhs != rhs {
71            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
72        } else {
73            Ok(())
74        }
75    }
76
77    pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
78        match self {
79            Storage::Cpu(storage) => storage.const_set(v, l),
80            Storage::Cuda(storage) => storage.const_set(v, l),
81            Storage::Metal(storage) => storage.const_set(v, l),
82        }
83    }
84
85    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
86        match self {
87            Storage::Cpu(storage) => {
88                let storage = storage.affine(layout, mul, add)?;
89                Ok(Self::Cpu(storage))
90            }
91            Self::Cuda(storage) => {
92                let storage = storage.affine(layout, mul, add)?;
93                Ok(Self::Cuda(storage))
94            }
95            Self::Metal(storage) => {
96                let storage = storage.affine(layout, mul, add)?;
97                Ok(Self::Metal(storage))
98            }
99        }
100    }
101
102    pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
103        match self {
104            Storage::Cpu(storage) => {
105                let storage = storage.powf(layout, alpha)?;
106                Ok(Self::Cpu(storage))
107            }
108            Self::Cuda(storage) => {
109                let storage = storage.powf(layout, alpha)?;
110                Ok(Self::Cuda(storage))
111            }
112            Self::Metal(storage) => {
113                let storage = storage.powf(layout, alpha)?;
114                Ok(Self::Metal(storage))
115            }
116        }
117    }
118
119    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
120        match self {
121            Storage::Cpu(storage) => {
122                let storage = storage.elu(layout, alpha)?;
123                Ok(Self::Cpu(storage))
124            }
125            Self::Cuda(storage) => {
126                let storage = storage.elu(layout, alpha)?;
127                Ok(Self::Cuda(storage))
128            }
129            Self::Metal(storage) => {
130                let storage = storage.elu(layout, alpha)?;
131                Ok(Self::Metal(storage))
132            }
133        }
134    }
135
136    pub(crate) fn cmp(
137        &self,
138        op: CmpOp,
139        rhs: &Self,
140        lhs_layout: &Layout,
141        rhs_layout: &Layout,
142    ) -> Result<Self> {
143        self.same_device(rhs, "cmp")?;
144        self.same_dtype(rhs, "cmp")?;
145        match (self, rhs) {
146            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
147                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
148                Ok(Self::Cpu(storage))
149            }
150            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
151                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
152                Ok(Self::Cuda(storage))
153            }
154            (Self::Metal(lhs), Self::Metal(rhs)) => {
155                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
156                Ok(Self::Metal(storage))
157            }
158            (lhs, rhs) => {
159                // Should not happen because of the same device check above but we're defensive
160                // anyway.
161                Err(Error::DeviceMismatchBinaryOp {
162                    lhs: lhs.device().location(),
163                    rhs: rhs.device().location(),
164                    op: "cmp",
165                }
166                .bt())
167            }
168        }
169    }
170
171    pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
172        match self {
173            Storage::Cpu(storage) => {
174                let storage = storage.reduce_op(op, layout, s)?;
175                Ok(Self::Cpu(storage))
176            }
177            Self::Cuda(storage) => {
178                let storage = storage.reduce_op(op, layout, s)?;
179                Ok(Self::Cuda(storage))
180            }
181            Self::Metal(storage) => {
182                let storage = storage.reduce_op(op, layout, s)?;
183                Ok(Self::Metal(storage))
184            }
185        }
186    }
187
188    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
189        match self {
190            Storage::Cpu(storage) => {
191                let storage = storage.to_dtype(layout, dtype)?;
192                Ok(Self::Cpu(storage))
193            }
194            Self::Cuda(storage) => {
195                let storage = storage.to_dtype(layout, dtype)?;
196                Ok(Self::Cuda(storage))
197            }
198            Self::Metal(storage) => {
199                let storage = storage.to_dtype(layout, dtype)?;
200                Ok(Self::Metal(storage))
201            }
202        }
203    }
204
205    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
206        match self {
207            Self::Cpu(storage) => {
208                let (storage, shape) = c.cpu_fwd(storage, l)?;
209                Ok((Self::Cpu(storage), shape))
210            }
211            Self::Cuda(storage) => {
212                let (storage, shape) = c.cuda_fwd(storage, l)?;
213                Ok((Self::Cuda(storage), shape))
214            }
215            Self::Metal(storage) => {
216                let (storage, shape) = c.metal_fwd(storage, l)?;
217                Ok((Self::Metal(storage), shape))
218            }
219        }
220    }
221
222    pub(crate) fn apply_op2(
223        &self,
224        l1: &Layout,
225        t2: &Self,
226        l2: &Layout,
227        c: &dyn CustomOp2,
228    ) -> Result<(Self, Shape)> {
229        self.same_device(t2, c.name())?;
230        match (self, t2) {
231            (Self::Cpu(s1), Self::Cpu(s2)) => {
232                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
233                Ok((Self::Cpu(s), shape))
234            }
235            (Self::Cuda(s1), Self::Cuda(s2)) => {
236                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
237                Ok((Self::Cuda(s), shape))
238            }
239            (Self::Metal(s1), Self::Metal(s2)) => {
240                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
241                Ok((Self::Metal(s), shape))
242            }
243            _ => unreachable!(),
244        }
245    }
246
247    pub(crate) fn apply_op3(
248        &self,
249        l1: &Layout,
250        t2: &Self,
251        l2: &Layout,
252        t3: &Self,
253        l3: &Layout,
254        c: &dyn CustomOp3,
255    ) -> Result<(Self, Shape)> {
256        self.same_device(t2, c.name())?;
257        self.same_device(t3, c.name())?;
258        match (self, t2, t3) {
259            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
260                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
261                Ok((Self::Cpu(s), shape))
262            }
263            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
264                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
265                Ok((Self::Cuda(s), shape))
266            }
267            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
268                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
269                Ok((Self::Metal(s), shape))
270            }
271            _ => unreachable!(),
272        }
273    }
274
275    pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
276        match self {
277            Self::Cpu(storage) => c.cpu_fwd(storage, l),
278            Self::Cuda(storage) => c.cuda_fwd(storage, l),
279            Self::Metal(storage) => c.metal_fwd(storage, l),
280        }
281    }
282
283    pub(crate) fn inplace_op2(
284        &mut self,
285        l1: &Layout,
286        t2: &Self,
287        l2: &Layout,
288        c: &dyn InplaceOp2,
289    ) -> Result<()> {
290        self.same_device(t2, c.name())?;
291        match (self, t2) {
292            (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
293            (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
294            (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
295            _ => unreachable!(),
296        }
297    }
298
299    pub(crate) fn inplace_op3(
300        &mut self,
301        l1: &Layout,
302        t2: &Self,
303        l2: &Layout,
304        t3: &Self,
305        l3: &Layout,
306        c: &dyn InplaceOp3,
307    ) -> Result<()> {
308        self.same_device(t2, c.name())?;
309        self.same_device(t3, c.name())?;
310        match (self, t2, t3) {
311            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
312            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
313            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
314                c.metal_fwd(s1, l1, s2, l2, s3, l3)
315            }
316            _ => unreachable!(),
317        }
318    }
319
320    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
321        match self {
322            Storage::Cpu(storage) => {
323                let storage = storage.unary_impl::<B>(layout)?;
324                Ok(Self::Cpu(storage))
325            }
326            Self::Cuda(storage) => {
327                let storage = storage.unary_impl::<B>(layout)?;
328                Ok(Self::Cuda(storage))
329            }
330            Self::Metal(storage) => {
331                let storage = storage.unary_impl::<B>(layout)?;
332                Ok(Self::Metal(storage))
333            }
334        }
335    }
336
337    pub(crate) fn binary_impl<B: op::BinaryOpT>(
338        &self,
339        rhs: &Self,
340        lhs_layout: &Layout,
341        rhs_layout: &Layout,
342    ) -> Result<Self> {
343        self.same_device(rhs, B::NAME)?;
344        self.same_dtype(rhs, B::NAME)?;
345        match (self, rhs) {
346            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
347                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
348                Ok(Self::Cpu(storage))
349            }
350            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
351                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
352                Ok(Self::Cuda(storage))
353            }
354            (Self::Metal(lhs), Self::Metal(rhs)) => {
355                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
356                Ok(Self::Metal(storage))
357            }
358            (lhs, rhs) => {
359                // Should not happen because of the same device check above but we're defensive
360                // anyway.
361                Err(Error::DeviceMismatchBinaryOp {
362                    lhs: lhs.device().location(),
363                    rhs: rhs.device().location(),
364                    op: B::NAME,
365                }
366                .bt())
367            }
368        }
369    }
370
371    pub(crate) fn conv1d(
372        &self,
373        l: &Layout,
374        kernel: &Self,
375        kernel_l: &Layout,
376        params: &crate::conv::ParamsConv1D,
377    ) -> Result<Self> {
378        self.same_device(kernel, "conv1d")?;
379        self.same_dtype(kernel, "conv1d")?;
380        match (self, &kernel) {
381            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
382                let s = inp.conv1d(l, kernel, kernel_l, params)?;
383                Ok(Self::Cpu(s))
384            }
385            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
386                let s = inp.conv1d(l, kernel, kernel_l, params)?;
387                Ok(Self::Cuda(s))
388            }
389            (Storage::Metal(inp), Storage::Metal(kernel)) => {
390                let s = inp.conv1d(l, kernel, kernel_l, params)?;
391                Ok(Self::Metal(s))
392            }
393            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
394                lhs: lhs.device().location(),
395                rhs: rhs.device().location(),
396                op: "conv1d",
397            }
398            .bt()),
399        }
400    }
401
402    pub(crate) fn conv_transpose1d(
403        &self,
404        l: &Layout,
405        kernel: &Self,
406        kernel_l: &Layout,
407        params: &crate::conv::ParamsConvTranspose1D,
408    ) -> Result<Self> {
409        self.same_device(kernel, "conv-transpose1d")?;
410        self.same_dtype(kernel, "conv-transpose1d")?;
411        match (self, &kernel) {
412            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
413                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
414                Ok(Self::Cpu(s))
415            }
416            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
417                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
418                Ok(Self::Cuda(s))
419            }
420            (Storage::Metal(inp), Storage::Metal(kernel)) => {
421                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
422                Ok(Self::Metal(s))
423            }
424            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
425                lhs: lhs.device().location(),
426                rhs: rhs.device().location(),
427                op: "conv-transpose1d",
428            }
429            .bt()),
430        }
431    }
432
433    pub(crate) fn conv2d(
434        &self,
435        l: &Layout,
436        kernel: &Self,
437        kernel_l: &Layout,
438        params: &crate::conv::ParamsConv2D,
439    ) -> Result<Self> {
440        self.same_device(kernel, "conv2d")?;
441        self.same_dtype(kernel, "conv2d")?;
442        match (self, &kernel) {
443            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
444                let s = inp.conv2d(l, kernel, kernel_l, params)?;
445                Ok(Self::Cpu(s))
446            }
447            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
448                let s = inp.conv2d(l, kernel, kernel_l, params)?;
449                Ok(Self::Cuda(s))
450            }
451            (Storage::Metal(inp), Storage::Metal(kernel)) => {
452                let s = inp.conv2d(l, kernel, kernel_l, params)?;
453                Ok(Self::Metal(s))
454            }
455            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
456                lhs: lhs.device().location(),
457                rhs: rhs.device().location(),
458                op: "conv2d",
459            }
460            .bt()),
461        }
462    }
463
464    pub(crate) fn conv_transpose2d(
465        &self,
466        l: &Layout,
467        kernel: &Self,
468        kernel_l: &Layout,
469        params: &crate::conv::ParamsConvTranspose2D,
470    ) -> Result<Self> {
471        self.same_device(kernel, "conv_transpose2d")?;
472        self.same_dtype(kernel, "conv_transpose2d")?;
473        match (self, &kernel) {
474            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
475                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
476                Ok(Self::Cpu(s))
477            }
478            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
479                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
480                Ok(Self::Cuda(s))
481            }
482            (Storage::Metal(inp), Storage::Metal(kernel)) => {
483                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
484                Ok(Self::Metal(s))
485            }
486            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
487                lhs: lhs.device().location(),
488                rhs: rhs.device().location(),
489                op: "conv_transpose2d",
490            }
491            .bt()),
492        }
493    }
494
495    pub(crate) fn avg_pool2d(
496        &self,
497        layout: &Layout,
498        kernel_size: (usize, usize),
499        stride: (usize, usize),
500    ) -> Result<Self> {
501        match self {
502            Storage::Cpu(storage) => {
503                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
504                Ok(Self::Cpu(storage))
505            }
506            Self::Cuda(storage) => {
507                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
508                Ok(Self::Cuda(storage))
509            }
510            Self::Metal(storage) => {
511                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
512                Ok(Self::Metal(storage))
513            }
514        }
515    }
516
517    pub(crate) fn max_pool2d(
518        &self,
519        layout: &Layout,
520        kernel_size: (usize, usize),
521        stride: (usize, usize),
522    ) -> Result<Self> {
523        match self {
524            Storage::Cpu(storage) => {
525                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
526                Ok(Self::Cpu(storage))
527            }
528            Self::Cuda(storage) => {
529                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
530                Ok(Self::Cuda(storage))
531            }
532            Self::Metal(storage) => {
533                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
534                Ok(Self::Metal(storage))
535            }
536        }
537    }
538
539    pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
540        match self {
541            Storage::Cpu(storage) => {
542                let storage = storage.upsample_nearest1d(layout, sz)?;
543                Ok(Self::Cpu(storage))
544            }
545            Self::Cuda(storage) => {
546                let storage = storage.upsample_nearest1d(layout, sz)?;
547                Ok(Self::Cuda(storage))
548            }
549            Self::Metal(storage) => {
550                let storage = storage.upsample_nearest1d(layout, sz)?;
551                Ok(Self::Metal(storage))
552            }
553        }
554    }
555
556    pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
557        match self {
558            Storage::Cpu(storage) => {
559                let storage = storage.upsample_nearest2d(layout, h, w)?;
560                Ok(Self::Cpu(storage))
561            }
562            Self::Cuda(storage) => {
563                let storage = storage.upsample_nearest2d(layout, h, w)?;
564                Ok(Self::Cuda(storage))
565            }
566            Self::Metal(storage) => {
567                let storage = storage.upsample_nearest2d(layout, h, w)?;
568                Ok(Self::Metal(storage))
569            }
570        }
571    }
572
573    pub(crate) fn where_cond(
574        &self,
575        layout: &Layout,
576        t: &Self,
577        layout_t: &Layout,
578        f: &Self,
579        layout_f: &Layout,
580    ) -> Result<Self> {
581        self.same_device(t, "where")?;
582        self.same_device(f, "where")?;
583        t.same_dtype(f, "where")?;
584        match (self, t, f) {
585            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
586                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
587                Ok(Self::Cpu(storage))
588            }
589            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
590                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
591                Ok(Self::Cuda(storage))
592            }
593            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
594                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
595                Ok(Self::Metal(storage))
596            }
597            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
598                lhs: lhs.device().location(),
599                rhs: rhs.device().location(),
600                op: "where",
601            }
602            .bt()),
603        }
604    }
605
606    pub(crate) fn gather(
607        &self,
608        l: &Layout,
609        indexes: &Self,
610        indexes_l: &Layout,
611        d: usize,
612    ) -> Result<Self> {
613        self.same_device(indexes, "index-add")?;
614        match (self, indexes) {
615            (Self::Cpu(s), Self::Cpu(indexes)) => {
616                let storage = s.gather(l, indexes, indexes_l, d)?;
617                Ok(Self::Cpu(storage))
618            }
619            (Self::Cuda(s), Self::Cuda(indexes)) => {
620                let storage = s.gather(l, indexes, indexes_l, d)?;
621                Ok(Self::Cuda(storage))
622            }
623            (Self::Metal(s), Self::Metal(indexes)) => {
624                let storage = s.gather(l, indexes, indexes_l, d)?;
625                Ok(Self::Metal(storage))
626            }
627            _ => unreachable!(),
628        }
629    }
630
631    pub(crate) fn scatter_set(
632        &mut self,
633        l: &Layout,
634        indexes: &Self,
635        indexes_l: &Layout,
636        source: &Self,
637        source_l: &Layout,
638        d: usize,
639    ) -> Result<()> {
640        self.same_device(indexes, "scatter-set")?;
641        self.same_device(source, "scatter-set")?;
642        match (self, indexes, source) {
643            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
644                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
645            }
646            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
647                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
648            }
649            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
650                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
651            }
652            _ => unreachable!(),
653        }
654        Ok(())
655    }
656
657    pub(crate) fn scatter_add(
658        &mut self,
659        l: &Layout,
660        indexes: &Self,
661        indexes_l: &Layout,
662        source: &Self,
663        source_l: &Layout,
664        d: usize,
665    ) -> Result<()> {
666        self.same_device(indexes, "scatter-add")?;
667        self.same_device(source, "scatter-add")?;
668        match (self, indexes, source) {
669            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
670                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
671            }
672            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
673                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
674            }
675            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
676                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
677            }
678            _ => unreachable!(),
679        }
680        Ok(())
681    }
682
683    pub(crate) fn index_add(
684        &self,
685        l: &Layout,
686        indexes: &Self,
687        indexes_l: &Layout,
688        source: &Self,
689        source_l: &Layout,
690        d: usize,
691    ) -> Result<Self> {
692        self.same_device(indexes, "index-add")?;
693        self.same_device(source, "index-add")?;
694        match (self, indexes, source) {
695            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
696                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
697                Ok(Self::Cpu(storage))
698            }
699            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
700                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
701                Ok(Self::Cuda(storage))
702            }
703            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
704                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
705                Ok(Self::Metal(storage))
706            }
707            _ => unreachable!(),
708        }
709    }
710
711    pub(crate) fn index_select(
712        &self,
713        rhs: &Self,
714        lhs_l: &Layout,
715        rhs_l: &Layout,
716        d: usize,
717    ) -> Result<Self> {
718        self.same_device(rhs, "index-select")?;
719        match (self, rhs) {
720            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
721                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
722                Ok(Self::Cpu(storage))
723            }
724            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
725                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
726                Ok(Self::Cuda(storage))
727            }
728            (Self::Metal(lhs), Self::Metal(rhs)) => {
729                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
730                Ok(Self::Metal(storage))
731            }
732            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
733                lhs: lhs.device().location(),
734                rhs: rhs.device().location(),
735                op: "index-select",
736            }
737            .bt()),
738        }
739    }
740
741    pub(crate) fn matmul(
742        &self,
743        rhs: &Self,
744        bmnk: (usize, usize, usize, usize),
745        lhs_layout: &Layout,
746        rhs_layout: &Layout,
747    ) -> Result<Self> {
748        self.same_device(rhs, "matmul")?;
749        self.same_dtype(rhs, "matmul")?;
750        match (self, rhs) {
751            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
752                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
753                Ok(Self::Cpu(storage))
754            }
755            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
756                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
757                Ok(Self::Cuda(storage))
758            }
759            (Self::Metal(lhs), Self::Metal(rhs)) => {
760                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
761                Ok(Self::Metal(storage))
762            }
763            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
764                lhs: lhs.device().location(),
765                rhs: rhs.device().location(),
766                op: "matmul",
767            }
768            .bt()),
769        }
770    }
771
772    // self, the source can be strided whereas dst is contiguous.
773    pub(crate) fn copy_strided_src(
774        &self,
775        dst: &mut Self,
776        dst_offset: usize,
777        src_l: &Layout,
778    ) -> Result<()> {
779        match (self, dst) {
780            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
781            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
782            (Self::Metal(src), Self::Metal(dst)) => {
783                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
784            }
785            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
786                lhs: lhs.device().location(),
787                rhs: rhs.device().location(),
788                op: "copy",
789            }
790            .bt()),
791        }
792    }
793
794    #[allow(clippy::too_many_arguments)]
795    pub(crate) fn copy2d(
796        &self,
797        dst: &mut Self,
798        d1: usize,
799        d2: usize,
800        src_s: usize,
801        dst_s: usize,
802        src_o: usize,
803        dst_o: usize,
804    ) -> Result<()> {
805        match (self, dst) {
806            (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
807            (Self::Cuda(src), Self::Cuda(dst)) => {
808                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
809            }
810            (Self::Metal(src), Self::Metal(dst)) => {
811                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
812            }
813            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
814                lhs: lhs.device().location(),
815                rhs: rhs.device().location(),
816                op: "copy2d",
817            }
818            .bt()),
819        }
820    }
821}