candle_core/
storage.rs

1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
4use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
5
6// We do not want to implement Clone on Storage as cloning may fail because of
7// out of memory. Instead try_clone should be used.
8#[derive(Debug)]
9pub enum Storage {
10    Cpu(CpuStorage),
11    Cuda(CudaStorage),
12    Metal(MetalStorage),
13}
14
15impl Storage {
16    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
17        match self {
18            Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
19            Self::Cuda(storage) => {
20                let storage = storage.try_clone(layout)?;
21                Ok(Self::Cuda(storage))
22            }
23            Self::Metal(storage) => {
24                let storage = storage.try_clone(layout)?;
25                Ok(Self::Metal(storage))
26            }
27        }
28    }
29
30    pub fn device(&self) -> Device {
31        match self {
32            Self::Cpu(_) => Device::Cpu,
33            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
34            Self::Metal(storage) => Device::Metal(storage.device().clone()),
35        }
36    }
37
38    pub fn dtype(&self) -> DType {
39        match self {
40            Self::Cpu(storage) => storage.dtype(),
41            Self::Cuda(storage) => storage.dtype(),
42            Self::Metal(storage) => storage.dtype(),
43        }
44    }
45
46    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
47        let lhs_device = self.device();
48        let rhs_device = rhs.device();
49        let lhs = lhs_device.location();
50        let rhs = rhs_device.location();
51        let same_device = if self.device().is_metal() {
52            // On metal, we require the device to be exactly the same rather than
53            // having the same location. In cuda this is not necessary as all CudaDevice on the
54            // same GPU will use the same cuda stream.
55            lhs_device.same_device(&rhs_device)
56        } else {
57            lhs == rhs
58        };
59        if !same_device {
60            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
61        } else {
62            Ok(())
63        }
64    }
65
66    pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
67        let lhs = self.dtype();
68        let rhs = rhs.dtype();
69        if lhs != rhs {
70            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
71        } else {
72            Ok(())
73        }
74    }
75
76    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
77        match self {
78            Storage::Cpu(storage) => {
79                let storage = storage.affine(layout, mul, add)?;
80                Ok(Self::Cpu(storage))
81            }
82            Self::Cuda(storage) => {
83                let storage = storage.affine(layout, mul, add)?;
84                Ok(Self::Cuda(storage))
85            }
86            Self::Metal(storage) => {
87                let storage = storage.affine(layout, mul, add)?;
88                Ok(Self::Metal(storage))
89            }
90        }
91    }
92
93    pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
94        match self {
95            Storage::Cpu(storage) => {
96                let storage = storage.powf(layout, alpha)?;
97                Ok(Self::Cpu(storage))
98            }
99            Self::Cuda(storage) => {
100                let storage = storage.powf(layout, alpha)?;
101                Ok(Self::Cuda(storage))
102            }
103            Self::Metal(storage) => {
104                let storage = storage.powf(layout, alpha)?;
105                Ok(Self::Metal(storage))
106            }
107        }
108    }
109
110    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
111        match self {
112            Storage::Cpu(storage) => {
113                let storage = storage.elu(layout, alpha)?;
114                Ok(Self::Cpu(storage))
115            }
116            Self::Cuda(storage) => {
117                let storage = storage.elu(layout, alpha)?;
118                Ok(Self::Cuda(storage))
119            }
120            Self::Metal(storage) => {
121                let storage = storage.elu(layout, alpha)?;
122                Ok(Self::Metal(storage))
123            }
124        }
125    }
126
127    pub(crate) fn cmp(
128        &self,
129        op: CmpOp,
130        rhs: &Self,
131        lhs_layout: &Layout,
132        rhs_layout: &Layout,
133    ) -> Result<Self> {
134        self.same_device(rhs, "cmp")?;
135        self.same_dtype(rhs, "cmp")?;
136        match (self, rhs) {
137            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
138                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
139                Ok(Self::Cpu(storage))
140            }
141            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
142                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
143                Ok(Self::Cuda(storage))
144            }
145            (Self::Metal(lhs), Self::Metal(rhs)) => {
146                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
147                Ok(Self::Metal(storage))
148            }
149            (lhs, rhs) => {
150                // Should not happen because of the same device check above but we're defensive
151                // anyway.
152                Err(Error::DeviceMismatchBinaryOp {
153                    lhs: lhs.device().location(),
154                    rhs: rhs.device().location(),
155                    op: "cmp",
156                }
157                .bt())
158            }
159        }
160    }
161
162    pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
163        match self {
164            Storage::Cpu(storage) => {
165                let storage = storage.reduce_op(op, layout, s)?;
166                Ok(Self::Cpu(storage))
167            }
168            Self::Cuda(storage) => {
169                let storage = storage.reduce_op(op, layout, s)?;
170                Ok(Self::Cuda(storage))
171            }
172            Self::Metal(storage) => {
173                let storage = storage.reduce_op(op, layout, s)?;
174                Ok(Self::Metal(storage))
175            }
176        }
177    }
178
179    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
180        match self {
181            Storage::Cpu(storage) => {
182                let storage = storage.to_dtype(layout, dtype)?;
183                Ok(Self::Cpu(storage))
184            }
185            Self::Cuda(storage) => {
186                let storage = storage.to_dtype(layout, dtype)?;
187                Ok(Self::Cuda(storage))
188            }
189            Self::Metal(storage) => {
190                let storage = storage.to_dtype(layout, dtype)?;
191                Ok(Self::Metal(storage))
192            }
193        }
194    }
195
196    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
197        match self {
198            Self::Cpu(storage) => {
199                let (storage, shape) = c.cpu_fwd(storage, l)?;
200                Ok((Self::Cpu(storage), shape))
201            }
202            Self::Cuda(storage) => {
203                let (storage, shape) = c.cuda_fwd(storage, l)?;
204                Ok((Self::Cuda(storage), shape))
205            }
206            Self::Metal(storage) => {
207                let (storage, shape) = c.metal_fwd(storage, l)?;
208                Ok((Self::Metal(storage), shape))
209            }
210        }
211    }
212
213    pub(crate) fn apply_op2(
214        &self,
215        l1: &Layout,
216        t2: &Self,
217        l2: &Layout,
218        c: &dyn CustomOp2,
219    ) -> Result<(Self, Shape)> {
220        self.same_device(t2, c.name())?;
221        match (self, t2) {
222            (Self::Cpu(s1), Self::Cpu(s2)) => {
223                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
224                Ok((Self::Cpu(s), shape))
225            }
226            (Self::Cuda(s1), Self::Cuda(s2)) => {
227                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
228                Ok((Self::Cuda(s), shape))
229            }
230            (Self::Metal(s1), Self::Metal(s2)) => {
231                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
232                Ok((Self::Metal(s), shape))
233            }
234            _ => unreachable!(),
235        }
236    }
237
238    pub(crate) fn apply_op3(
239        &self,
240        l1: &Layout,
241        t2: &Self,
242        l2: &Layout,
243        t3: &Self,
244        l3: &Layout,
245        c: &dyn CustomOp3,
246    ) -> Result<(Self, Shape)> {
247        self.same_device(t2, c.name())?;
248        self.same_device(t3, c.name())?;
249        match (self, t2, t3) {
250            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
251                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
252                Ok((Self::Cpu(s), shape))
253            }
254            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
255                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
256                Ok((Self::Cuda(s), shape))
257            }
258            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
259                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
260                Ok((Self::Metal(s), shape))
261            }
262            _ => unreachable!(),
263        }
264    }
265
266    pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
267        match self {
268            Self::Cpu(storage) => c.cpu_fwd(storage, l),
269            Self::Cuda(storage) => c.cuda_fwd(storage, l),
270            Self::Metal(storage) => c.metal_fwd(storage, l),
271        }
272    }
273
274    pub(crate) fn inplace_op2(
275        &mut self,
276        l1: &Layout,
277        t2: &Self,
278        l2: &Layout,
279        c: &dyn InplaceOp2,
280    ) -> Result<()> {
281        self.same_device(t2, c.name())?;
282        match (self, t2) {
283            (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
284            (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
285            (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
286            _ => unreachable!(),
287        }
288    }
289
290    pub(crate) fn inplace_op3(
291        &mut self,
292        l1: &Layout,
293        t2: &Self,
294        l2: &Layout,
295        t3: &Self,
296        l3: &Layout,
297        c: &dyn InplaceOp3,
298    ) -> Result<()> {
299        self.same_device(t2, c.name())?;
300        self.same_device(t3, c.name())?;
301        match (self, t2, t3) {
302            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
303            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
304            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
305                c.metal_fwd(s1, l1, s2, l2, s3, l3)
306            }
307            _ => unreachable!(),
308        }
309    }
310
311    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
312        match self {
313            Storage::Cpu(storage) => {
314                let storage = storage.unary_impl::<B>(layout)?;
315                Ok(Self::Cpu(storage))
316            }
317            Self::Cuda(storage) => {
318                let storage = storage.unary_impl::<B>(layout)?;
319                Ok(Self::Cuda(storage))
320            }
321            Self::Metal(storage) => {
322                let storage = storage.unary_impl::<B>(layout)?;
323                Ok(Self::Metal(storage))
324            }
325        }
326    }
327
328    pub(crate) fn binary_impl<B: op::BinaryOpT>(
329        &self,
330        rhs: &Self,
331        lhs_layout: &Layout,
332        rhs_layout: &Layout,
333    ) -> Result<Self> {
334        self.same_device(rhs, B::NAME)?;
335        self.same_dtype(rhs, B::NAME)?;
336        match (self, rhs) {
337            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
338                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
339                Ok(Self::Cpu(storage))
340            }
341            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
342                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
343                Ok(Self::Cuda(storage))
344            }
345            (Self::Metal(lhs), Self::Metal(rhs)) => {
346                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
347                Ok(Self::Metal(storage))
348            }
349            (lhs, rhs) => {
350                // Should not happen because of the same device check above but we're defensive
351                // anyway.
352                Err(Error::DeviceMismatchBinaryOp {
353                    lhs: lhs.device().location(),
354                    rhs: rhs.device().location(),
355                    op: B::NAME,
356                }
357                .bt())
358            }
359        }
360    }
361
362    pub(crate) fn conv1d(
363        &self,
364        l: &Layout,
365        kernel: &Self,
366        kernel_l: &Layout,
367        params: &crate::conv::ParamsConv1D,
368    ) -> Result<Self> {
369        self.same_device(kernel, "conv1d")?;
370        self.same_dtype(kernel, "conv1d")?;
371        match (self, &kernel) {
372            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
373                let s = inp.conv1d(l, kernel, kernel_l, params)?;
374                Ok(Self::Cpu(s))
375            }
376            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
377                let s = inp.conv1d(l, kernel, kernel_l, params)?;
378                Ok(Self::Cuda(s))
379            }
380            (Storage::Metal(inp), Storage::Metal(kernel)) => {
381                let s = inp.conv1d(l, kernel, kernel_l, params)?;
382                Ok(Self::Metal(s))
383            }
384            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
385                lhs: lhs.device().location(),
386                rhs: rhs.device().location(),
387                op: "conv1d",
388            }
389            .bt()),
390        }
391    }
392
393    pub(crate) fn conv_transpose1d(
394        &self,
395        l: &Layout,
396        kernel: &Self,
397        kernel_l: &Layout,
398        params: &crate::conv::ParamsConvTranspose1D,
399    ) -> Result<Self> {
400        self.same_device(kernel, "conv-transpose1d")?;
401        self.same_dtype(kernel, "conv-transpose1d")?;
402        match (self, &kernel) {
403            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
404                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
405                Ok(Self::Cpu(s))
406            }
407            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
408                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
409                Ok(Self::Cuda(s))
410            }
411            (Storage::Metal(inp), Storage::Metal(kernel)) => {
412                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
413                Ok(Self::Metal(s))
414            }
415            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
416                lhs: lhs.device().location(),
417                rhs: rhs.device().location(),
418                op: "conv-transpose1d",
419            }
420            .bt()),
421        }
422    }
423
424    pub(crate) fn conv2d(
425        &self,
426        l: &Layout,
427        kernel: &Self,
428        kernel_l: &Layout,
429        params: &crate::conv::ParamsConv2D,
430    ) -> Result<Self> {
431        self.same_device(kernel, "conv2d")?;
432        self.same_dtype(kernel, "conv2d")?;
433        match (self, &kernel) {
434            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
435                let s = inp.conv2d(l, kernel, kernel_l, params)?;
436                Ok(Self::Cpu(s))
437            }
438            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
439                let s = inp.conv2d(l, kernel, kernel_l, params)?;
440                Ok(Self::Cuda(s))
441            }
442            (Storage::Metal(inp), Storage::Metal(kernel)) => {
443                let s = inp.conv2d(l, kernel, kernel_l, params)?;
444                Ok(Self::Metal(s))
445            }
446            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
447                lhs: lhs.device().location(),
448                rhs: rhs.device().location(),
449                op: "conv2d",
450            }
451            .bt()),
452        }
453    }
454
455    pub(crate) fn conv_transpose2d(
456        &self,
457        l: &Layout,
458        kernel: &Self,
459        kernel_l: &Layout,
460        params: &crate::conv::ParamsConvTranspose2D,
461    ) -> Result<Self> {
462        self.same_device(kernel, "conv_transpose2d")?;
463        self.same_dtype(kernel, "conv_transpose2d")?;
464        match (self, &kernel) {
465            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
466                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
467                Ok(Self::Cpu(s))
468            }
469            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
470                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
471                Ok(Self::Cuda(s))
472            }
473            (Storage::Metal(inp), Storage::Metal(kernel)) => {
474                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
475                Ok(Self::Metal(s))
476            }
477            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
478                lhs: lhs.device().location(),
479                rhs: rhs.device().location(),
480                op: "conv_transpose2d",
481            }
482            .bt()),
483        }
484    }
485
486    pub(crate) fn avg_pool2d(
487        &self,
488        layout: &Layout,
489        kernel_size: (usize, usize),
490        stride: (usize, usize),
491    ) -> Result<Self> {
492        match self {
493            Storage::Cpu(storage) => {
494                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
495                Ok(Self::Cpu(storage))
496            }
497            Self::Cuda(storage) => {
498                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
499                Ok(Self::Cuda(storage))
500            }
501            Self::Metal(storage) => {
502                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
503                Ok(Self::Metal(storage))
504            }
505        }
506    }
507
508    pub(crate) fn max_pool2d(
509        &self,
510        layout: &Layout,
511        kernel_size: (usize, usize),
512        stride: (usize, usize),
513    ) -> Result<Self> {
514        match self {
515            Storage::Cpu(storage) => {
516                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
517                Ok(Self::Cpu(storage))
518            }
519            Self::Cuda(storage) => {
520                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
521                Ok(Self::Cuda(storage))
522            }
523            Self::Metal(storage) => {
524                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
525                Ok(Self::Metal(storage))
526            }
527        }
528    }
529
530    pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
531        match self {
532            Storage::Cpu(storage) => {
533                let storage = storage.upsample_nearest1d(layout, sz)?;
534                Ok(Self::Cpu(storage))
535            }
536            Self::Cuda(storage) => {
537                let storage = storage.upsample_nearest1d(layout, sz)?;
538                Ok(Self::Cuda(storage))
539            }
540            Self::Metal(storage) => {
541                let storage = storage.upsample_nearest1d(layout, sz)?;
542                Ok(Self::Metal(storage))
543            }
544        }
545    }
546
547    pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
548        match self {
549            Storage::Cpu(storage) => {
550                let storage = storage.upsample_nearest2d(layout, h, w)?;
551                Ok(Self::Cpu(storage))
552            }
553            Self::Cuda(storage) => {
554                let storage = storage.upsample_nearest2d(layout, h, w)?;
555                Ok(Self::Cuda(storage))
556            }
557            Self::Metal(storage) => {
558                let storage = storage.upsample_nearest2d(layout, h, w)?;
559                Ok(Self::Metal(storage))
560            }
561        }
562    }
563
564    pub(crate) fn where_cond(
565        &self,
566        layout: &Layout,
567        t: &Self,
568        layout_t: &Layout,
569        f: &Self,
570        layout_f: &Layout,
571    ) -> Result<Self> {
572        self.same_device(t, "where")?;
573        self.same_device(f, "where")?;
574        t.same_dtype(f, "where")?;
575        match (self, t, f) {
576            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
577                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
578                Ok(Self::Cpu(storage))
579            }
580            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
581                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
582                Ok(Self::Cuda(storage))
583            }
584            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
585                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
586                Ok(Self::Metal(storage))
587            }
588            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
589                lhs: lhs.device().location(),
590                rhs: rhs.device().location(),
591                op: "where",
592            }
593            .bt()),
594        }
595    }
596
597    pub(crate) fn gather(
598        &self,
599        l: &Layout,
600        indexes: &Self,
601        indexes_l: &Layout,
602        d: usize,
603    ) -> Result<Self> {
604        self.same_device(indexes, "index-add")?;
605        match (self, indexes) {
606            (Self::Cpu(s), Self::Cpu(indexes)) => {
607                let storage = s.gather(l, indexes, indexes_l, d)?;
608                Ok(Self::Cpu(storage))
609            }
610            (Self::Cuda(s), Self::Cuda(indexes)) => {
611                let storage = s.gather(l, indexes, indexes_l, d)?;
612                Ok(Self::Cuda(storage))
613            }
614            (Self::Metal(s), Self::Metal(indexes)) => {
615                let storage = s.gather(l, indexes, indexes_l, d)?;
616                Ok(Self::Metal(storage))
617            }
618            _ => unreachable!(),
619        }
620    }
621
622    pub(crate) fn scatter_add(
623        &self,
624        l: &Layout,
625        indexes: &Self,
626        indexes_l: &Layout,
627        source: &Self,
628        source_l: &Layout,
629        d: usize,
630    ) -> Result<Self> {
631        self.same_device(indexes, "scatter-add")?;
632        self.same_device(source, "scatter-add")?;
633        match (self, indexes, source) {
634            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
635                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
636                Ok(Self::Cpu(storage))
637            }
638            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
639                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
640                Ok(Self::Cuda(storage))
641            }
642            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
643                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
644                Ok(Self::Metal(storage))
645            }
646            _ => unreachable!(),
647        }
648    }
649
650    pub(crate) fn index_add(
651        &self,
652        l: &Layout,
653        indexes: &Self,
654        indexes_l: &Layout,
655        source: &Self,
656        source_l: &Layout,
657        d: usize,
658    ) -> Result<Self> {
659        self.same_device(indexes, "index-add")?;
660        self.same_device(source, "index-add")?;
661        match (self, indexes, source) {
662            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
663                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
664                Ok(Self::Cpu(storage))
665            }
666            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
667                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
668                Ok(Self::Cuda(storage))
669            }
670            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
671                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
672                Ok(Self::Metal(storage))
673            }
674            _ => unreachable!(),
675        }
676    }
677
678    pub(crate) fn index_select(
679        &self,
680        rhs: &Self,
681        lhs_l: &Layout,
682        rhs_l: &Layout,
683        d: usize,
684    ) -> Result<Self> {
685        self.same_device(rhs, "index-select")?;
686        match (self, rhs) {
687            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
688                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
689                Ok(Self::Cpu(storage))
690            }
691            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
692                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
693                Ok(Self::Cuda(storage))
694            }
695            (Self::Metal(lhs), Self::Metal(rhs)) => {
696                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
697                Ok(Self::Metal(storage))
698            }
699            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
700                lhs: lhs.device().location(),
701                rhs: rhs.device().location(),
702                op: "index-select",
703            }
704            .bt()),
705        }
706    }
707
708    pub(crate) fn matmul(
709        &self,
710        rhs: &Self,
711        bmnk: (usize, usize, usize, usize),
712        lhs_layout: &Layout,
713        rhs_layout: &Layout,
714    ) -> Result<Self> {
715        self.same_device(rhs, "matmul")?;
716        self.same_dtype(rhs, "matmul")?;
717        match (self, rhs) {
718            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
719                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
720                Ok(Self::Cpu(storage))
721            }
722            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
723                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
724                Ok(Self::Cuda(storage))
725            }
726            (Self::Metal(lhs), Self::Metal(rhs)) => {
727                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
728                Ok(Self::Metal(storage))
729            }
730            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
731                lhs: lhs.device().location(),
732                rhs: rhs.device().location(),
733                op: "matmul",
734            }
735            .bt()),
736        }
737    }
738
739    // self, the source can be strided whereas dst is contiguous.
740    pub(crate) fn copy_strided_src(
741        &self,
742        dst: &mut Self,
743        dst_offset: usize,
744        src_l: &Layout,
745    ) -> Result<()> {
746        match (self, dst) {
747            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
748            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
749            (Self::Metal(src), Self::Metal(dst)) => {
750                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
751            }
752            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
753                lhs: lhs.device().location(),
754                rhs: rhs.device().location(),
755                op: "copy",
756            }
757            .bt()),
758        }
759    }
760
761    #[allow(clippy::too_many_arguments)]
762    pub(crate) fn copy2d(
763        &self,
764        dst: &mut Self,
765        d1: usize,
766        d2: usize,
767        src_s: usize,
768        dst_s: usize,
769        src_o: usize,
770        dst_o: usize,
771    ) -> Result<()> {
772        match (self, dst) {
773            (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
774            (Self::Cuda(src), Self::Cuda(dst)) => {
775                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
776            }
777            (Self::Metal(src), Self::Metal(dst)) => {
778                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
779            }
780            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
781                lhs: lhs.device().location(),
782                rhs: rhs.device().location(),
783                op: "copy2d",
784            }
785            .bt()),
786        }
787    }
788}