candle_core/
storage.rs

1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::scalar::Scalar;
4use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
5use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
6
7// We do not want to implement Clone on Storage as cloning may fail because of
8// out of memory. Instead try_clone should be used.
9#[derive(Debug)]
10pub enum Storage {
11    Cpu(CpuStorage),
12    Cuda(CudaStorage),
13    Metal(MetalStorage),
14}
15
16impl Storage {
17    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
18        match self {
19            Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
20            Self::Cuda(storage) => {
21                let storage = storage.try_clone(layout)?;
22                Ok(Self::Cuda(storage))
23            }
24            Self::Metal(storage) => {
25                let storage = storage.try_clone(layout)?;
26                Ok(Self::Metal(storage))
27            }
28        }
29    }
30
31    pub fn device(&self) -> Device {
32        match self {
33            Self::Cpu(_) => Device::Cpu,
34            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
35            Self::Metal(storage) => Device::Metal(storage.device().clone()),
36        }
37    }
38
39    pub fn dtype(&self) -> DType {
40        match self {
41            Self::Cpu(storage) => storage.dtype(),
42            Self::Cuda(storage) => storage.dtype(),
43            Self::Metal(storage) => storage.dtype(),
44        }
45    }
46
47    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
48        let lhs_device = self.device();
49        let rhs_device = rhs.device();
50        let lhs = lhs_device.location();
51        let rhs = rhs_device.location();
52        let same_device = if self.device().is_metal() {
53            // On metal, we require the device to be exactly the same rather than
54            // having the same location. In cuda this is not necessary as all CudaDevice on the
55            // same GPU will use the same cuda stream.
56            lhs_device.same_device(&rhs_device)
57        } else {
58            lhs == rhs
59        };
60        if !same_device {
61            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
62        } else {
63            Ok(())
64        }
65    }
66
67    pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
68        let lhs = self.dtype();
69        let rhs = rhs.dtype();
70        if lhs != rhs {
71            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
72        } else {
73            Ok(())
74        }
75    }
76
77    pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
78        match self {
79            Storage::Cpu(storage) => storage.const_set(v, l),
80            Storage::Cuda(storage) => storage.const_set(v, l),
81            Storage::Metal(storage) => storage.const_set(v, l),
82        }
83    }
84
85    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
86        match self {
87            Storage::Cpu(storage) => {
88                let storage = storage.affine(layout, mul, add)?;
89                Ok(Self::Cpu(storage))
90            }
91            Self::Cuda(storage) => {
92                let storage = storage.affine(layout, mul, add)?;
93                Ok(Self::Cuda(storage))
94            }
95            Self::Metal(storage) => {
96                let storage = storage.affine(layout, mul, add)?;
97                Ok(Self::Metal(storage))
98            }
99        }
100    }
101
102    pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
103        match self {
104            Storage::Cpu(storage) => {
105                let storage = storage.powf(layout, alpha)?;
106                Ok(Self::Cpu(storage))
107            }
108            Self::Cuda(storage) => {
109                let storage = storage.powf(layout, alpha)?;
110                Ok(Self::Cuda(storage))
111            }
112            Self::Metal(storage) => {
113                let storage = storage.powf(layout, alpha)?;
114                Ok(Self::Metal(storage))
115            }
116        }
117    }
118
119    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
120        match self {
121            Storage::Cpu(storage) => {
122                let storage = storage.elu(layout, alpha)?;
123                Ok(Self::Cpu(storage))
124            }
125            Self::Cuda(storage) => {
126                let storage = storage.elu(layout, alpha)?;
127                Ok(Self::Cuda(storage))
128            }
129            Self::Metal(storage) => {
130                let storage = storage.elu(layout, alpha)?;
131                Ok(Self::Metal(storage))
132            }
133        }
134    }
135
136    pub(crate) fn cmp(
137        &self,
138        op: CmpOp,
139        rhs: &Self,
140        lhs_layout: &Layout,
141        rhs_layout: &Layout,
142    ) -> Result<Self> {
143        self.same_device(rhs, "cmp")?;
144        self.same_dtype(rhs, "cmp")?;
145        match (self, rhs) {
146            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
147                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
148                Ok(Self::Cpu(storage))
149            }
150            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
151                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
152                Ok(Self::Cuda(storage))
153            }
154            (Self::Metal(lhs), Self::Metal(rhs)) => {
155                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
156                Ok(Self::Metal(storage))
157            }
158            (lhs, rhs) => {
159                // Should not happen because of the same device check above but we're defensive
160                // anyway.
161                Err(Error::DeviceMismatchBinaryOp {
162                    lhs: lhs.device().location(),
163                    rhs: rhs.device().location(),
164                    op: "cmp",
165                }
166                .bt())
167            }
168        }
169    }
170
171    pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
172        match self {
173            Storage::Cpu(storage) => {
174                let storage = storage.reduce_op(op, layout, s)?;
175                Ok(Self::Cpu(storage))
176            }
177            Self::Cuda(storage) => {
178                let storage = storage.reduce_op(op, layout, s)?;
179                Ok(Self::Cuda(storage))
180            }
181            Self::Metal(storage) => {
182                let storage = storage.reduce_op(op, layout, s)?;
183                Ok(Self::Metal(storage))
184            }
185        }
186    }
187
188    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
189        match self {
190            Storage::Cpu(storage) => {
191                let storage = storage.to_dtype(layout, dtype)?;
192                Ok(Self::Cpu(storage))
193            }
194            Self::Cuda(storage) => {
195                let storage = storage.to_dtype(layout, dtype)?;
196                Ok(Self::Cuda(storage))
197            }
198            Self::Metal(storage) => {
199                let storage = storage.to_dtype(layout, dtype)?;
200                Ok(Self::Metal(storage))
201            }
202        }
203    }
204
205    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
206        match self {
207            Self::Cpu(storage) => {
208                let (storage, shape) = c.cpu_fwd(storage, l)?;
209                Ok((Self::Cpu(storage), shape))
210            }
211            Self::Cuda(storage) => {
212                let (storage, shape) = c.cuda_fwd(storage, l)?;
213                Ok((Self::Cuda(storage), shape))
214            }
215            Self::Metal(storage) => {
216                let (storage, shape) = c.metal_fwd(storage, l)?;
217                Ok((Self::Metal(storage), shape))
218            }
219        }
220    }
221
222    pub(crate) fn apply_op2(
223        &self,
224        l1: &Layout,
225        t2: &Self,
226        l2: &Layout,
227        c: &dyn CustomOp2,
228    ) -> Result<(Self, Shape)> {
229        self.same_device(t2, c.name())?;
230        match (self, t2) {
231            (Self::Cpu(s1), Self::Cpu(s2)) => {
232                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
233                Ok((Self::Cpu(s), shape))
234            }
235            (Self::Cuda(s1), Self::Cuda(s2)) => {
236                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
237                Ok((Self::Cuda(s), shape))
238            }
239            (Self::Metal(s1), Self::Metal(s2)) => {
240                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
241                Ok((Self::Metal(s), shape))
242            }
243            _ => unreachable!(),
244        }
245    }
246
247    pub(crate) fn apply_op3(
248        &self,
249        l1: &Layout,
250        t2: &Self,
251        l2: &Layout,
252        t3: &Self,
253        l3: &Layout,
254        c: &dyn CustomOp3,
255    ) -> Result<(Self, Shape)> {
256        self.same_device(t2, c.name())?;
257        self.same_device(t3, c.name())?;
258        match (self, t2, t3) {
259            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
260                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
261                Ok((Self::Cpu(s), shape))
262            }
263            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
264                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
265                Ok((Self::Cuda(s), shape))
266            }
267            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
268                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
269                Ok((Self::Metal(s), shape))
270            }
271            _ => unreachable!(),
272        }
273    }
274
275    pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
276        match self {
277            Self::Cpu(storage) => c.cpu_fwd(storage, l),
278            Self::Cuda(storage) => c.cuda_fwd(storage, l),
279            Self::Metal(storage) => c.metal_fwd(storage, l),
280        }
281    }
282
283    pub(crate) fn inplace_op2(
284        &mut self,
285        l1: &Layout,
286        t2: &Self,
287        l2: &Layout,
288        c: &dyn InplaceOp2,
289    ) -> Result<()> {
290        self.same_device(t2, c.name())?;
291        match (self, t2) {
292            (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
293            (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
294            (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
295            _ => unreachable!(),
296        }
297    }
298
299    pub(crate) fn inplace_op3(
300        &mut self,
301        l1: &Layout,
302        t2: &Self,
303        l2: &Layout,
304        t3: &Self,
305        l3: &Layout,
306        c: &dyn InplaceOp3,
307    ) -> Result<()> {
308        self.same_device(t2, c.name())?;
309        self.same_device(t3, c.name())?;
310        match (self, t2, t3) {
311            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
312            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
313            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
314                c.metal_fwd(s1, l1, s2, l2, s3, l3)
315            }
316            _ => unreachable!(),
317        }
318    }
319
320    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
321        match self {
322            Storage::Cpu(storage) => {
323                let storage = storage.unary_impl::<B>(layout)?;
324                Ok(Self::Cpu(storage))
325            }
326            Self::Cuda(storage) => {
327                let storage = storage.unary_impl::<B>(layout)?;
328                Ok(Self::Cuda(storage))
329            }
330            Self::Metal(storage) => {
331                let storage = storage.unary_impl::<B>(layout)?;
332                Ok(Self::Metal(storage))
333            }
334        }
335    }
336
337    pub(crate) fn binary_impl<B: op::BinaryOpT>(
338        &self,
339        rhs: &Self,
340        lhs_layout: &Layout,
341        rhs_layout: &Layout,
342    ) -> Result<Self> {
343        self.same_device(rhs, B::NAME)?;
344        self.same_dtype(rhs, B::NAME)?;
345        match (self, rhs) {
346            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
347                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
348                Ok(Self::Cpu(storage))
349            }
350            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
351                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
352                Ok(Self::Cuda(storage))
353            }
354            (Self::Metal(lhs), Self::Metal(rhs)) => {
355                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
356                Ok(Self::Metal(storage))
357            }
358            (lhs, rhs) => {
359                // Should not happen because of the same device check above but we're defensive
360                // anyway.
361                Err(Error::DeviceMismatchBinaryOp {
362                    lhs: lhs.device().location(),
363                    rhs: rhs.device().location(),
364                    op: B::NAME,
365                }
366                .bt())
367            }
368        }
369    }
370
371    pub(crate) fn conv1d(
372        &self,
373        l: &Layout,
374        kernel: &Self,
375        kernel_l: &Layout,
376        params: &crate::conv::ParamsConv1D,
377    ) -> Result<Self> {
378        self.same_device(kernel, "conv1d")?;
379        self.same_dtype(kernel, "conv1d")?;
380        match (self, &kernel) {
381            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
382                let s = inp.conv1d(l, kernel, kernel_l, params)?;
383                Ok(Self::Cpu(s))
384            }
385            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
386                let s = inp.conv1d(l, kernel, kernel_l, params)?;
387                Ok(Self::Cuda(s))
388            }
389            (Storage::Metal(inp), Storage::Metal(kernel)) => {
390                let s = inp.conv1d(l, kernel, kernel_l, params)?;
391                Ok(Self::Metal(s))
392            }
393            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
394                lhs: lhs.device().location(),
395                rhs: rhs.device().location(),
396                op: "conv1d",
397            }
398            .bt()),
399        }
400    }
401
402    pub(crate) fn conv_transpose1d(
403        &self,
404        l: &Layout,
405        kernel: &Self,
406        kernel_l: &Layout,
407        params: &crate::conv::ParamsConvTranspose1D,
408    ) -> Result<Self> {
409        self.same_device(kernel, "conv-transpose1d")?;
410        self.same_dtype(kernel, "conv-transpose1d")?;
411        match (self, &kernel) {
412            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
413                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
414                Ok(Self::Cpu(s))
415            }
416            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
417                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
418                Ok(Self::Cuda(s))
419            }
420            (Storage::Metal(inp), Storage::Metal(kernel)) => {
421                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
422                Ok(Self::Metal(s))
423            }
424            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
425                lhs: lhs.device().location(),
426                rhs: rhs.device().location(),
427                op: "conv-transpose1d",
428            }
429            .bt()),
430        }
431    }
432
433    pub(crate) fn conv2d(
434        &self,
435        l: &Layout,
436        kernel: &Self,
437        kernel_l: &Layout,
438        params: &crate::conv::ParamsConv2D,
439    ) -> Result<Self> {
440        self.same_device(kernel, "conv2d")?;
441        self.same_dtype(kernel, "conv2d")?;
442        match (self, &kernel) {
443            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
444                let s = inp.conv2d(l, kernel, kernel_l, params)?;
445                Ok(Self::Cpu(s))
446            }
447            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
448                let s = inp.conv2d(l, kernel, kernel_l, params)?;
449                Ok(Self::Cuda(s))
450            }
451            (Storage::Metal(inp), Storage::Metal(kernel)) => {
452                let s = inp.conv2d(l, kernel, kernel_l, params)?;
453                Ok(Self::Metal(s))
454            }
455            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
456                lhs: lhs.device().location(),
457                rhs: rhs.device().location(),
458                op: "conv2d",
459            }
460            .bt()),
461        }
462    }
463
464    pub(crate) fn conv_transpose2d(
465        &self,
466        l: &Layout,
467        kernel: &Self,
468        kernel_l: &Layout,
469        params: &crate::conv::ParamsConvTranspose2D,
470    ) -> Result<Self> {
471        self.same_device(kernel, "conv_transpose2d")?;
472        self.same_dtype(kernel, "conv_transpose2d")?;
473        match (self, &kernel) {
474            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
475                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
476                Ok(Self::Cpu(s))
477            }
478            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
479                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
480                Ok(Self::Cuda(s))
481            }
482            (Storage::Metal(inp), Storage::Metal(kernel)) => {
483                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
484                Ok(Self::Metal(s))
485            }
486            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
487                lhs: lhs.device().location(),
488                rhs: rhs.device().location(),
489                op: "conv_transpose2d",
490            }
491            .bt()),
492        }
493    }
494
495    pub(crate) fn avg_pool2d(
496        &self,
497        layout: &Layout,
498        kernel_size: (usize, usize),
499        stride: (usize, usize),
500    ) -> Result<Self> {
501        match self {
502            Storage::Cpu(storage) => {
503                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
504                Ok(Self::Cpu(storage))
505            }
506            Self::Cuda(storage) => {
507                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
508                Ok(Self::Cuda(storage))
509            }
510            Self::Metal(storage) => {
511                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
512                Ok(Self::Metal(storage))
513            }
514        }
515    }
516
517    pub(crate) fn max_pool2d(
518        &self,
519        layout: &Layout,
520        kernel_size: (usize, usize),
521        stride: (usize, usize),
522    ) -> Result<Self> {
523        match self {
524            Storage::Cpu(storage) => {
525                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
526                Ok(Self::Cpu(storage))
527            }
528            Self::Cuda(storage) => {
529                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
530                Ok(Self::Cuda(storage))
531            }
532            Self::Metal(storage) => {
533                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
534                Ok(Self::Metal(storage))
535            }
536        }
537    }
538
539    pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
540        match self {
541            Storage::Cpu(storage) => {
542                let storage = storage.upsample_nearest1d(layout, sz)?;
543                Ok(Self::Cpu(storage))
544            }
545            Self::Cuda(storage) => {
546                let storage = storage.upsample_nearest1d(layout, sz)?;
547                Ok(Self::Cuda(storage))
548            }
549            Self::Metal(storage) => {
550                let storage = storage.upsample_nearest1d(layout, sz)?;
551                Ok(Self::Metal(storage))
552            }
553        }
554    }
555
556    pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
557        match self {
558            Storage::Cpu(storage) => {
559                let storage = storage.upsample_nearest2d(layout, h, w)?;
560                Ok(Self::Cpu(storage))
561            }
562            Self::Cuda(storage) => {
563                let storage = storage.upsample_nearest2d(layout, h, w)?;
564                Ok(Self::Cuda(storage))
565            }
566            Self::Metal(storage) => {
567                let storage = storage.upsample_nearest2d(layout, h, w)?;
568                Ok(Self::Metal(storage))
569            }
570        }
571    }
572
573    pub(crate) fn upsample_bilinear2d(
574        &self,
575        layout: &Layout,
576        h: usize,
577        w: usize,
578        align_corners: bool,
579        scale_h: Option<f64>,
580        scale_w: Option<f64>,
581    ) -> Result<Self> {
582        match self {
583            Storage::Cpu(storage) => {
584                let storage =
585                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
586                Ok(Self::Cpu(storage))
587            }
588            Self::Cuda(storage) => {
589                let storage =
590                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
591                Ok(Self::Cuda(storage))
592            }
593            Self::Metal(storage) => {
594                let storage =
595                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
596                Ok(Self::Metal(storage))
597            }
598        }
599    }
600
601    pub(crate) fn where_cond(
602        &self,
603        layout: &Layout,
604        t: &Self,
605        layout_t: &Layout,
606        f: &Self,
607        layout_f: &Layout,
608    ) -> Result<Self> {
609        self.same_device(t, "where")?;
610        self.same_device(f, "where")?;
611        t.same_dtype(f, "where")?;
612        match (self, t, f) {
613            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
614                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
615                Ok(Self::Cpu(storage))
616            }
617            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
618                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
619                Ok(Self::Cuda(storage))
620            }
621            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
622                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
623                Ok(Self::Metal(storage))
624            }
625            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
626                lhs: lhs.device().location(),
627                rhs: rhs.device().location(),
628                op: "where",
629            }
630            .bt()),
631        }
632    }
633
634    pub(crate) fn gather(
635        &self,
636        l: &Layout,
637        indexes: &Self,
638        indexes_l: &Layout,
639        d: usize,
640    ) -> Result<Self> {
641        self.same_device(indexes, "index-add")?;
642        match (self, indexes) {
643            (Self::Cpu(s), Self::Cpu(indexes)) => {
644                let storage = s.gather(l, indexes, indexes_l, d)?;
645                Ok(Self::Cpu(storage))
646            }
647            (Self::Cuda(s), Self::Cuda(indexes)) => {
648                let storage = s.gather(l, indexes, indexes_l, d)?;
649                Ok(Self::Cuda(storage))
650            }
651            (Self::Metal(s), Self::Metal(indexes)) => {
652                let storage = s.gather(l, indexes, indexes_l, d)?;
653                Ok(Self::Metal(storage))
654            }
655            _ => unreachable!(),
656        }
657    }
658
659    pub(crate) fn scatter_set(
660        &mut self,
661        l: &Layout,
662        indexes: &Self,
663        indexes_l: &Layout,
664        source: &Self,
665        source_l: &Layout,
666        d: usize,
667    ) -> Result<()> {
668        self.same_device(indexes, "scatter-set")?;
669        self.same_device(source, "scatter-set")?;
670        match (self, indexes, source) {
671            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
672                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
673            }
674            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
675                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
676            }
677            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
678                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
679            }
680            _ => unreachable!(),
681        }
682        Ok(())
683    }
684
685    pub(crate) fn scatter_add(
686        &mut self,
687        l: &Layout,
688        indexes: &Self,
689        indexes_l: &Layout,
690        source: &Self,
691        source_l: &Layout,
692        d: usize,
693    ) -> Result<()> {
694        self.same_device(indexes, "scatter-add")?;
695        self.same_device(source, "scatter-add")?;
696        match (self, indexes, source) {
697            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
698                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
699            }
700            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
701                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
702            }
703            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
704                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
705            }
706            _ => unreachable!(),
707        }
708        Ok(())
709    }
710
711    pub(crate) fn index_add(
712        &self,
713        l: &Layout,
714        indexes: &Self,
715        indexes_l: &Layout,
716        source: &Self,
717        source_l: &Layout,
718        d: usize,
719    ) -> Result<Self> {
720        self.same_device(indexes, "index-add")?;
721        self.same_device(source, "index-add")?;
722        match (self, indexes, source) {
723            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
724                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
725                Ok(Self::Cpu(storage))
726            }
727            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
728                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
729                Ok(Self::Cuda(storage))
730            }
731            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
732                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
733                Ok(Self::Metal(storage))
734            }
735            _ => unreachable!(),
736        }
737    }
738
739    pub(crate) fn index_select(
740        &self,
741        rhs: &Self,
742        lhs_l: &Layout,
743        rhs_l: &Layout,
744        d: usize,
745    ) -> Result<Self> {
746        self.same_device(rhs, "index-select")?;
747        match (self, rhs) {
748            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
749                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
750                Ok(Self::Cpu(storage))
751            }
752            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
753                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
754                Ok(Self::Cuda(storage))
755            }
756            (Self::Metal(lhs), Self::Metal(rhs)) => {
757                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
758                Ok(Self::Metal(storage))
759            }
760            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
761                lhs: lhs.device().location(),
762                rhs: rhs.device().location(),
763                op: "index-select",
764            }
765            .bt()),
766        }
767    }
768
769    pub(crate) fn matmul(
770        &self,
771        rhs: &Self,
772        bmnk: (usize, usize, usize, usize),
773        lhs_layout: &Layout,
774        rhs_layout: &Layout,
775    ) -> Result<Self> {
776        self.same_device(rhs, "matmul")?;
777        self.same_dtype(rhs, "matmul")?;
778        match (self, rhs) {
779            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
780                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
781                Ok(Self::Cpu(storage))
782            }
783            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
784                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
785                Ok(Self::Cuda(storage))
786            }
787            (Self::Metal(lhs), Self::Metal(rhs)) => {
788                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
789                Ok(Self::Metal(storage))
790            }
791            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
792                lhs: lhs.device().location(),
793                rhs: rhs.device().location(),
794                op: "matmul",
795            }
796            .bt()),
797        }
798    }
799
800    // self, the source can be strided whereas dst is contiguous.
801    pub(crate) fn copy_strided_src(
802        &self,
803        dst: &mut Self,
804        dst_offset: usize,
805        src_l: &Layout,
806    ) -> Result<()> {
807        match (self, dst) {
808            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
809            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
810            (Self::Metal(src), Self::Metal(dst)) => {
811                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
812            }
813            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
814                lhs: lhs.device().location(),
815                rhs: rhs.device().location(),
816                op: "copy",
817            }
818            .bt()),
819        }
820    }
821
822    #[allow(clippy::too_many_arguments)]
823    pub(crate) fn copy2d(
824        &self,
825        dst: &mut Self,
826        d1: usize,
827        d2: usize,
828        src_s: usize,
829        dst_s: usize,
830        src_o: usize,
831        dst_o: usize,
832    ) -> Result<()> {
833        match (self, dst) {
834            (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
835            (Self::Cuda(src), Self::Cuda(dst)) => {
836                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
837            }
838            (Self::Metal(src), Self::Metal(dst)) => {
839                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
840            }
841            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
842                lhs: lhs.device().location(),
843                rhs: rhs.device().location(),
844                op: "copy2d",
845            }
846            .bt()),
847        }
848    }
849}