Skip to main content

hanzo_ml/
storage.rs

1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::scalar::Scalar;
4#[cfg(feature = "rocm")]
5use crate::RocmStorage;
6#[cfg(feature = "vulkan")]
7use crate::VulkanStorage;
8use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
9use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
10
11// We do not want to implement Clone on Storage as cloning may fail because of
12// out of memory. Instead try_clone should be used.
13#[derive(Debug)]
14pub enum Storage {
15    Cpu(CpuStorage),
16    Cuda(CudaStorage),
17    Metal(MetalStorage),
18    #[cfg(feature = "rocm")]
19    Rocm(RocmStorage),
20    #[cfg(feature = "vulkan")]
21    Vulkan(VulkanStorage),
22}
23
24impl Storage {
25    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
26        match self {
27            Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
28            Self::Cuda(storage) => {
29                let storage = storage.try_clone(layout)?;
30                Ok(Self::Cuda(storage))
31            }
32            Self::Metal(storage) => {
33                let storage = storage.try_clone(layout)?;
34                Ok(Self::Metal(storage))
35            }
36            #[cfg(feature = "rocm")]
37            Self::Rocm(storage) => {
38                let storage = storage.try_clone(layout)?;
39                Ok(Self::Rocm(storage))
40            }
41            #[cfg(feature = "vulkan")]
42            Self::Vulkan(storage) => {
43                let storage = storage.try_clone(layout)?;
44                Ok(Self::Vulkan(storage))
45            }
46        }
47    }
48
49    pub fn device(&self) -> Device {
50        match self {
51            Self::Cpu(_) => Device::Cpu,
52            Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
53            Self::Metal(storage) => Device::Metal(storage.device().clone()),
54            #[cfg(feature = "rocm")]
55            Self::Rocm(storage) => Device::Rocm(storage.device().clone()),
56            #[cfg(feature = "vulkan")]
57            Self::Vulkan(storage) => Device::Vulkan(storage.device().clone()),
58        }
59    }
60
61    pub fn dtype(&self) -> DType {
62        match self {
63            Self::Cpu(storage) => storage.dtype(),
64            Self::Cuda(storage) => storage.dtype(),
65            Self::Metal(storage) => storage.dtype(),
66            #[cfg(feature = "rocm")]
67            Self::Rocm(storage) => storage.dtype(),
68            #[cfg(feature = "vulkan")]
69            Self::Vulkan(storage) => storage.dtype(),
70        }
71    }
72
73    pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
74        let lhs_device = self.device();
75        let rhs_device = rhs.device();
76        let lhs = lhs_device.location();
77        let rhs = rhs_device.location();
78        let same_device = if self.device().is_metal() {
79            // On metal, we require the device to be exactly the same rather than
80            // having the same location. In cuda this is not necessary as all CudaDevice on the
81            // same GPU will use the same cuda stream.
82            lhs_device.same_device(&rhs_device)
83        } else {
84            lhs == rhs
85        };
86        if !same_device {
87            Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
88        } else {
89            Ok(())
90        }
91    }
92
93    pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
94        let lhs = self.dtype();
95        let rhs = rhs.dtype();
96        if lhs != rhs {
97            Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
98        } else {
99            Ok(())
100        }
101    }
102
103    pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
104        match self {
105            Storage::Cpu(storage) => storage.const_set(v, l),
106            Storage::Cuda(storage) => storage.const_set(v, l),
107            Storage::Metal(storage) => storage.const_set(v, l),
108            #[cfg(feature = "rocm")]
109            Storage::Rocm(storage) => storage.const_set(v, l),
110            #[cfg(feature = "vulkan")]
111            Storage::Vulkan(storage) => storage.const_set(v, l),
112        }
113    }
114
115    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
116        match self {
117            Storage::Cpu(storage) => {
118                let storage = storage.affine(layout, mul, add)?;
119                Ok(Self::Cpu(storage))
120            }
121            Self::Cuda(storage) => {
122                let storage = storage.affine(layout, mul, add)?;
123                Ok(Self::Cuda(storage))
124            }
125            Self::Metal(storage) => {
126                let storage = storage.affine(layout, mul, add)?;
127                Ok(Self::Metal(storage))
128            }
129            #[cfg(feature = "rocm")]
130            Self::Rocm(storage) => {
131                let storage = storage.affine(layout, mul, add)?;
132                Ok(Self::Rocm(storage))
133            }
134            #[cfg(feature = "vulkan")]
135            Self::Vulkan(storage) => {
136                let storage = storage.affine(layout, mul, add)?;
137                Ok(Self::Vulkan(storage))
138            }
139        }
140    }
141
142    pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
143        match self {
144            Storage::Cpu(storage) => {
145                let storage = storage.powf(layout, alpha)?;
146                Ok(Self::Cpu(storage))
147            }
148            Self::Cuda(storage) => {
149                let storage = storage.powf(layout, alpha)?;
150                Ok(Self::Cuda(storage))
151            }
152            Self::Metal(storage) => {
153                let storage = storage.powf(layout, alpha)?;
154                Ok(Self::Metal(storage))
155            }
156            #[cfg(feature = "rocm")]
157            Self::Rocm(storage) => {
158                let storage = storage.powf(layout, alpha)?;
159                Ok(Self::Rocm(storage))
160            }
161            #[cfg(feature = "vulkan")]
162            Self::Vulkan(storage) => {
163                let storage = storage.powf(layout, alpha)?;
164                Ok(Self::Vulkan(storage))
165            }
166        }
167    }
168
169    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
170        match self {
171            Storage::Cpu(storage) => {
172                let storage = storage.elu(layout, alpha)?;
173                Ok(Self::Cpu(storage))
174            }
175            Self::Cuda(storage) => {
176                let storage = storage.elu(layout, alpha)?;
177                Ok(Self::Cuda(storage))
178            }
179            Self::Metal(storage) => {
180                let storage = storage.elu(layout, alpha)?;
181                Ok(Self::Metal(storage))
182            }
183            #[cfg(feature = "rocm")]
184            Self::Rocm(storage) => {
185                let storage = storage.elu(layout, alpha)?;
186                Ok(Self::Rocm(storage))
187            }
188            #[cfg(feature = "vulkan")]
189            Self::Vulkan(storage) => {
190                let storage = storage.elu(layout, alpha)?;
191                Ok(Self::Vulkan(storage))
192            }
193        }
194    }
195
196    pub(crate) fn cmp(
197        &self,
198        op: CmpOp,
199        rhs: &Self,
200        lhs_layout: &Layout,
201        rhs_layout: &Layout,
202    ) -> Result<Self> {
203        self.same_device(rhs, "cmp")?;
204        self.same_dtype(rhs, "cmp")?;
205        match (self, rhs) {
206            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
207                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
208                Ok(Self::Cpu(storage))
209            }
210            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
211                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
212                Ok(Self::Cuda(storage))
213            }
214            (Self::Metal(lhs), Self::Metal(rhs)) => {
215                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
216                Ok(Self::Metal(storage))
217            }
218            #[cfg(feature = "rocm")]
219            (Self::Rocm(lhs), Self::Rocm(rhs)) => {
220                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
221                Ok(Self::Rocm(storage))
222            }
223            #[cfg(feature = "vulkan")]
224            (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
225                let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
226                Ok(Self::Vulkan(storage))
227            }
228            (lhs, rhs) => {
229                // Should not happen because of the same device check above but we're defensive
230                // anyway.
231                Err(Error::DeviceMismatchBinaryOp {
232                    lhs: lhs.device().location(),
233                    rhs: rhs.device().location(),
234                    op: "cmp",
235                }
236                .bt())
237            }
238        }
239    }
240
241    pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
242        match self {
243            Storage::Cpu(storage) => {
244                let storage = storage.reduce_op(op, layout, s)?;
245                Ok(Self::Cpu(storage))
246            }
247            Self::Cuda(storage) => {
248                let storage = storage.reduce_op(op, layout, s)?;
249                Ok(Self::Cuda(storage))
250            }
251            Self::Metal(storage) => {
252                let storage = storage.reduce_op(op, layout, s)?;
253                Ok(Self::Metal(storage))
254            }
255            #[cfg(feature = "rocm")]
256            Self::Rocm(storage) => {
257                let storage = storage.reduce_op(op, layout, s)?;
258                Ok(Self::Rocm(storage))
259            }
260            #[cfg(feature = "vulkan")]
261            Self::Vulkan(storage) => {
262                let storage = storage.reduce_op(op, layout, s)?;
263                Ok(Self::Vulkan(storage))
264            }
265        }
266    }
267
268    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
269        match self {
270            Storage::Cpu(storage) => {
271                let storage = storage.to_dtype(layout, dtype)?;
272                Ok(Self::Cpu(storage))
273            }
274            Self::Cuda(storage) => {
275                let storage = storage.to_dtype(layout, dtype)?;
276                Ok(Self::Cuda(storage))
277            }
278            Self::Metal(storage) => {
279                let storage = storage.to_dtype(layout, dtype)?;
280                Ok(Self::Metal(storage))
281            }
282            #[cfg(feature = "rocm")]
283            Self::Rocm(storage) => {
284                let storage = storage.to_dtype(layout, dtype)?;
285                Ok(Self::Rocm(storage))
286            }
287            #[cfg(feature = "vulkan")]
288            Self::Vulkan(storage) => {
289                let storage = storage.to_dtype(layout, dtype)?;
290                Ok(Self::Vulkan(storage))
291            }
292        }
293    }
294
295    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
296        match self {
297            Self::Cpu(storage) => {
298                let (storage, shape) = c.cpu_fwd(storage, l)?;
299                Ok((Self::Cpu(storage), shape))
300            }
301            Self::Cuda(storage) => {
302                let (storage, shape) = c.cuda_fwd(storage, l)?;
303                Ok((Self::Cuda(storage), shape))
304            }
305            Self::Metal(storage) => {
306                let (storage, shape) = c.metal_fwd(storage, l)?;
307                Ok((Self::Metal(storage), shape))
308            }
309            #[cfg(feature = "rocm")]
310            Self::Rocm(storage) => {
311                let (storage, shape) = c.rocm_fwd(storage, l)?;
312                Ok((Self::Rocm(storage), shape))
313            }
314            #[cfg(feature = "vulkan")]
315            Self::Vulkan(storage) => {
316                let (storage, shape) = c.vulkan_fwd(storage, l)?;
317                Ok((Self::Vulkan(storage), shape))
318            }
319        }
320    }
321
322    pub(crate) fn apply_op2(
323        &self,
324        l1: &Layout,
325        t2: &Self,
326        l2: &Layout,
327        c: &dyn CustomOp2,
328    ) -> Result<(Self, Shape)> {
329        self.same_device(t2, c.name())?;
330        match (self, t2) {
331            (Self::Cpu(s1), Self::Cpu(s2)) => {
332                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
333                Ok((Self::Cpu(s), shape))
334            }
335            (Self::Cuda(s1), Self::Cuda(s2)) => {
336                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
337                Ok((Self::Cuda(s), shape))
338            }
339            (Self::Metal(s1), Self::Metal(s2)) => {
340                let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
341                Ok((Self::Metal(s), shape))
342            }
343            #[cfg(feature = "rocm")]
344            (Self::Rocm(s1), Self::Rocm(s2)) => {
345                let (s, shape) = c.rocm_fwd(s1, l1, s2, l2)?;
346                Ok((Self::Rocm(s), shape))
347            }
348            #[cfg(feature = "vulkan")]
349            (Self::Vulkan(s1), Self::Vulkan(s2)) => {
350                let (s, shape) = c.vulkan_fwd(s1, l1, s2, l2)?;
351                Ok((Self::Vulkan(s), shape))
352            }
353            _ => unreachable!(),
354        }
355    }
356
357    pub(crate) fn apply_op3(
358        &self,
359        l1: &Layout,
360        t2: &Self,
361        l2: &Layout,
362        t3: &Self,
363        l3: &Layout,
364        c: &dyn CustomOp3,
365    ) -> Result<(Self, Shape)> {
366        self.same_device(t2, c.name())?;
367        self.same_device(t3, c.name())?;
368        match (self, t2, t3) {
369            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
370                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
371                Ok((Self::Cpu(s), shape))
372            }
373            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
374                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
375                Ok((Self::Cuda(s), shape))
376            }
377            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
378                let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
379                Ok((Self::Metal(s), shape))
380            }
381            #[cfg(feature = "rocm")]
382            (Self::Rocm(s1), Self::Rocm(s2), Self::Rocm(s3)) => {
383                let (s, shape) = c.rocm_fwd(s1, l1, s2, l2, s3, l3)?;
384                Ok((Self::Rocm(s), shape))
385            }
386            #[cfg(feature = "vulkan")]
387            (Self::Vulkan(s1), Self::Vulkan(s2), Self::Vulkan(s3)) => {
388                let (s, shape) = c.vulkan_fwd(s1, l1, s2, l2, s3, l3)?;
389                Ok((Self::Vulkan(s), shape))
390            }
391            _ => unreachable!(),
392        }
393    }
394
395    pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
396        match self {
397            Self::Cpu(storage) => c.cpu_fwd(storage, l),
398            Self::Cuda(storage) => c.cuda_fwd(storage, l),
399            Self::Metal(storage) => c.metal_fwd(storage, l),
400            #[cfg(feature = "rocm")]
401            Self::Rocm(storage) => c.rocm_fwd(storage, l),
402            #[cfg(feature = "vulkan")]
403            Self::Vulkan(storage) => c.vulkan_fwd(storage, l),
404        }
405    }
406
407    pub(crate) fn inplace_op2(
408        &mut self,
409        l1: &Layout,
410        t2: &Self,
411        l2: &Layout,
412        c: &dyn InplaceOp2,
413    ) -> Result<()> {
414        self.same_device(t2, c.name())?;
415        match (self, t2) {
416            (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
417            (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
418            (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
419            #[cfg(feature = "rocm")]
420            (Self::Rocm(s1), Self::Rocm(s2)) => c.rocm_fwd(s1, l1, s2, l2),
421            #[cfg(feature = "vulkan")]
422            (Self::Vulkan(s1), Self::Vulkan(s2)) => c.vulkan_fwd(s1, l1, s2, l2),
423            _ => unreachable!(),
424        }
425    }
426
427    pub(crate) fn inplace_op3(
428        &mut self,
429        l1: &Layout,
430        t2: &Self,
431        l2: &Layout,
432        t3: &Self,
433        l3: &Layout,
434        c: &dyn InplaceOp3,
435    ) -> Result<()> {
436        self.same_device(t2, c.name())?;
437        self.same_device(t3, c.name())?;
438        match (self, t2, t3) {
439            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
440            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
441            (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
442                c.metal_fwd(s1, l1, s2, l2, s3, l3)
443            }
444            #[cfg(feature = "rocm")]
445            (Self::Rocm(s1), Self::Rocm(s2), Self::Rocm(s3)) => c.rocm_fwd(s1, l1, s2, l2, s3, l3),
446            #[cfg(feature = "vulkan")]
447            (Self::Vulkan(s1), Self::Vulkan(s2), Self::Vulkan(s3)) => {
448                c.vulkan_fwd(s1, l1, s2, l2, s3, l3)
449            }
450            _ => unreachable!(),
451        }
452    }
453
454    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
455        match self {
456            Storage::Cpu(storage) => {
457                let storage = storage.unary_impl::<B>(layout)?;
458                Ok(Self::Cpu(storage))
459            }
460            Self::Cuda(storage) => {
461                let storage = storage.unary_impl::<B>(layout)?;
462                Ok(Self::Cuda(storage))
463            }
464            Self::Metal(storage) => {
465                let storage = storage.unary_impl::<B>(layout)?;
466                Ok(Self::Metal(storage))
467            }
468            #[cfg(feature = "rocm")]
469            Self::Rocm(storage) => {
470                let storage = storage.unary_impl::<B>(layout)?;
471                Ok(Self::Rocm(storage))
472            }
473            #[cfg(feature = "vulkan")]
474            Self::Vulkan(storage) => {
475                let storage = storage.unary_impl::<B>(layout)?;
476                Ok(Self::Vulkan(storage))
477            }
478        }
479    }
480
481    pub(crate) fn binary_impl<B: op::BinaryOpT>(
482        &self,
483        rhs: &Self,
484        lhs_layout: &Layout,
485        rhs_layout: &Layout,
486    ) -> Result<Self> {
487        self.same_device(rhs, B::NAME)?;
488        self.same_dtype(rhs, B::NAME)?;
489        match (self, rhs) {
490            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
491                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
492                Ok(Self::Cpu(storage))
493            }
494            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
495                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
496                Ok(Self::Cuda(storage))
497            }
498            (Self::Metal(lhs), Self::Metal(rhs)) => {
499                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
500                Ok(Self::Metal(storage))
501            }
502            #[cfg(feature = "rocm")]
503            (Self::Rocm(lhs), Self::Rocm(rhs)) => {
504                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
505                Ok(Self::Rocm(storage))
506            }
507            #[cfg(feature = "vulkan")]
508            (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
509                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
510                Ok(Self::Vulkan(storage))
511            }
512            (lhs, rhs) => {
513                // Should not happen because of the same device check above but we're defensive
514                // anyway.
515                Err(Error::DeviceMismatchBinaryOp {
516                    lhs: lhs.device().location(),
517                    rhs: rhs.device().location(),
518                    op: B::NAME,
519                }
520                .bt())
521            }
522        }
523    }
524
525    pub(crate) fn conv1d(
526        &self,
527        l: &Layout,
528        kernel: &Self,
529        kernel_l: &Layout,
530        params: &crate::conv::ParamsConv1D,
531    ) -> Result<Self> {
532        self.same_device(kernel, "conv1d")?;
533        self.same_dtype(kernel, "conv1d")?;
534        match (self, &kernel) {
535            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
536                let s = inp.conv1d(l, kernel, kernel_l, params)?;
537                Ok(Self::Cpu(s))
538            }
539            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
540                let s = inp.conv1d(l, kernel, kernel_l, params)?;
541                Ok(Self::Cuda(s))
542            }
543            (Storage::Metal(inp), Storage::Metal(kernel)) => {
544                let s = inp.conv1d(l, kernel, kernel_l, params)?;
545                Ok(Self::Metal(s))
546            }
547            #[cfg(feature = "rocm")]
548            (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
549                let s = inp.conv1d(l, kernel, kernel_l, params)?;
550                Ok(Self::Rocm(s))
551            }
552            #[cfg(feature = "vulkan")]
553            (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
554                let s = inp.conv1d(l, kernel, kernel_l, params)?;
555                Ok(Self::Vulkan(s))
556            }
557            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
558                lhs: lhs.device().location(),
559                rhs: rhs.device().location(),
560                op: "conv1d",
561            }
562            .bt()),
563        }
564    }
565
566    pub(crate) fn conv_transpose1d(
567        &self,
568        l: &Layout,
569        kernel: &Self,
570        kernel_l: &Layout,
571        params: &crate::conv::ParamsConvTranspose1D,
572    ) -> Result<Self> {
573        self.same_device(kernel, "conv-transpose1d")?;
574        self.same_dtype(kernel, "conv-transpose1d")?;
575        match (self, &kernel) {
576            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
577                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
578                Ok(Self::Cpu(s))
579            }
580            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
581                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
582                Ok(Self::Cuda(s))
583            }
584            (Storage::Metal(inp), Storage::Metal(kernel)) => {
585                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
586                Ok(Self::Metal(s))
587            }
588            #[cfg(feature = "rocm")]
589            (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
590                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
591                Ok(Self::Rocm(s))
592            }
593            #[cfg(feature = "vulkan")]
594            (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
595                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
596                Ok(Self::Vulkan(s))
597            }
598            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
599                lhs: lhs.device().location(),
600                rhs: rhs.device().location(),
601                op: "conv-transpose1d",
602            }
603            .bt()),
604        }
605    }
606
607    pub(crate) fn conv2d(
608        &self,
609        l: &Layout,
610        kernel: &Self,
611        kernel_l: &Layout,
612        params: &crate::conv::ParamsConv2D,
613    ) -> Result<Self> {
614        self.same_device(kernel, "conv2d")?;
615        self.same_dtype(kernel, "conv2d")?;
616        match (self, &kernel) {
617            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
618                let s = inp.conv2d(l, kernel, kernel_l, params)?;
619                Ok(Self::Cpu(s))
620            }
621            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
622                let s = inp.conv2d(l, kernel, kernel_l, params)?;
623                Ok(Self::Cuda(s))
624            }
625            (Storage::Metal(inp), Storage::Metal(kernel)) => {
626                let s = inp.conv2d(l, kernel, kernel_l, params)?;
627                Ok(Self::Metal(s))
628            }
629            #[cfg(feature = "rocm")]
630            (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
631                let s = inp.conv2d(l, kernel, kernel_l, params)?;
632                Ok(Self::Rocm(s))
633            }
634            #[cfg(feature = "vulkan")]
635            (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
636                let s = inp.conv2d(l, kernel, kernel_l, params)?;
637                Ok(Self::Vulkan(s))
638            }
639            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
640                lhs: lhs.device().location(),
641                rhs: rhs.device().location(),
642                op: "conv2d",
643            }
644            .bt()),
645        }
646    }
647
648    pub(crate) fn conv_transpose2d(
649        &self,
650        l: &Layout,
651        kernel: &Self,
652        kernel_l: &Layout,
653        params: &crate::conv::ParamsConvTranspose2D,
654    ) -> Result<Self> {
655        self.same_device(kernel, "conv_transpose2d")?;
656        self.same_dtype(kernel, "conv_transpose2d")?;
657        match (self, &kernel) {
658            (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
659                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
660                Ok(Self::Cpu(s))
661            }
662            (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
663                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
664                Ok(Self::Cuda(s))
665            }
666            (Storage::Metal(inp), Storage::Metal(kernel)) => {
667                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
668                Ok(Self::Metal(s))
669            }
670            #[cfg(feature = "rocm")]
671            (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
672                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
673                Ok(Self::Rocm(s))
674            }
675            #[cfg(feature = "vulkan")]
676            (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
677                let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
678                Ok(Self::Vulkan(s))
679            }
680            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
681                lhs: lhs.device().location(),
682                rhs: rhs.device().location(),
683                op: "conv_transpose2d",
684            }
685            .bt()),
686        }
687    }
688
689    pub(crate) fn avg_pool2d(
690        &self,
691        layout: &Layout,
692        kernel_size: (usize, usize),
693        stride: (usize, usize),
694    ) -> Result<Self> {
695        match self {
696            Storage::Cpu(storage) => {
697                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
698                Ok(Self::Cpu(storage))
699            }
700            Self::Cuda(storage) => {
701                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
702                Ok(Self::Cuda(storage))
703            }
704            Self::Metal(storage) => {
705                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
706                Ok(Self::Metal(storage))
707            }
708            #[cfg(feature = "rocm")]
709            Self::Rocm(storage) => {
710                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
711                Ok(Self::Rocm(storage))
712            }
713            #[cfg(feature = "vulkan")]
714            Self::Vulkan(storage) => {
715                let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
716                Ok(Self::Vulkan(storage))
717            }
718        }
719    }
720
721    pub(crate) fn max_pool2d(
722        &self,
723        layout: &Layout,
724        kernel_size: (usize, usize),
725        stride: (usize, usize),
726    ) -> Result<Self> {
727        match self {
728            Storage::Cpu(storage) => {
729                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
730                Ok(Self::Cpu(storage))
731            }
732            Self::Cuda(storage) => {
733                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
734                Ok(Self::Cuda(storage))
735            }
736            Self::Metal(storage) => {
737                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
738                Ok(Self::Metal(storage))
739            }
740            #[cfg(feature = "rocm")]
741            Self::Rocm(storage) => {
742                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
743                Ok(Self::Rocm(storage))
744            }
745            #[cfg(feature = "vulkan")]
746            Self::Vulkan(storage) => {
747                let storage = storage.max_pool2d(layout, kernel_size, stride)?;
748                Ok(Self::Vulkan(storage))
749            }
750        }
751    }
752
753    pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
754        match self {
755            Storage::Cpu(storage) => {
756                let storage = storage.upsample_nearest1d(layout, sz)?;
757                Ok(Self::Cpu(storage))
758            }
759            Self::Cuda(storage) => {
760                let storage = storage.upsample_nearest1d(layout, sz)?;
761                Ok(Self::Cuda(storage))
762            }
763            Self::Metal(storage) => {
764                let storage = storage.upsample_nearest1d(layout, sz)?;
765                Ok(Self::Metal(storage))
766            }
767            #[cfg(feature = "rocm")]
768            Self::Rocm(storage) => {
769                let storage = storage.upsample_nearest1d(layout, sz)?;
770                Ok(Self::Rocm(storage))
771            }
772            #[cfg(feature = "vulkan")]
773            Self::Vulkan(storage) => {
774                let storage = storage.upsample_nearest1d(layout, sz)?;
775                Ok(Self::Vulkan(storage))
776            }
777        }
778    }
779
780    pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
781        match self {
782            Storage::Cpu(storage) => {
783                let storage = storage.upsample_nearest2d(layout, h, w)?;
784                Ok(Self::Cpu(storage))
785            }
786            Self::Cuda(storage) => {
787                let storage = storage.upsample_nearest2d(layout, h, w)?;
788                Ok(Self::Cuda(storage))
789            }
790            Self::Metal(storage) => {
791                let storage = storage.upsample_nearest2d(layout, h, w)?;
792                Ok(Self::Metal(storage))
793            }
794            #[cfg(feature = "rocm")]
795            Self::Rocm(storage) => {
796                let storage = storage.upsample_nearest2d(layout, h, w)?;
797                Ok(Self::Rocm(storage))
798            }
799            #[cfg(feature = "vulkan")]
800            Self::Vulkan(storage) => {
801                let storage = storage.upsample_nearest2d(layout, h, w)?;
802                Ok(Self::Vulkan(storage))
803            }
804        }
805    }
806
807    pub(crate) fn upsample_bilinear2d(
808        &self,
809        layout: &Layout,
810        h: usize,
811        w: usize,
812        align_corners: bool,
813        scale_h: Option<f64>,
814        scale_w: Option<f64>,
815    ) -> Result<Self> {
816        match self {
817            Storage::Cpu(storage) => {
818                let storage =
819                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
820                Ok(Self::Cpu(storage))
821            }
822            Self::Cuda(storage) => {
823                let storage =
824                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
825                Ok(Self::Cuda(storage))
826            }
827            Self::Metal(storage) => {
828                let storage =
829                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
830                Ok(Self::Metal(storage))
831            }
832            #[cfg(feature = "rocm")]
833            Self::Rocm(storage) => {
834                let storage =
835                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
836                Ok(Self::Rocm(storage))
837            }
838            #[cfg(feature = "vulkan")]
839            Self::Vulkan(storage) => {
840                let storage =
841                    storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
842                Ok(Self::Vulkan(storage))
843            }
844        }
845    }
846
847    pub(crate) fn where_cond(
848        &self,
849        layout: &Layout,
850        t: &Self,
851        layout_t: &Layout,
852        f: &Self,
853        layout_f: &Layout,
854    ) -> Result<Self> {
855        self.same_device(t, "where")?;
856        self.same_device(f, "where")?;
857        t.same_dtype(f, "where")?;
858        match (self, t, f) {
859            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
860                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
861                Ok(Self::Cpu(storage))
862            }
863            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
864                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
865                Ok(Self::Cuda(storage))
866            }
867            (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
868                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
869                Ok(Self::Metal(storage))
870            }
871            #[cfg(feature = "rocm")]
872            (Self::Rocm(cond), Self::Rocm(t), Self::Rocm(f)) => {
873                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
874                Ok(Self::Rocm(storage))
875            }
876            #[cfg(feature = "vulkan")]
877            (Self::Vulkan(cond), Self::Vulkan(t), Self::Vulkan(f)) => {
878                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
879                Ok(Self::Vulkan(storage))
880            }
881            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
882                lhs: lhs.device().location(),
883                rhs: rhs.device().location(),
884                op: "where",
885            }
886            .bt()),
887        }
888    }
889
890    pub(crate) fn gather(
891        &self,
892        l: &Layout,
893        indexes: &Self,
894        indexes_l: &Layout,
895        d: usize,
896    ) -> Result<Self> {
897        self.same_device(indexes, "index-add")?;
898        match (self, indexes) {
899            (Self::Cpu(s), Self::Cpu(indexes)) => {
900                let storage = s.gather(l, indexes, indexes_l, d)?;
901                Ok(Self::Cpu(storage))
902            }
903            (Self::Cuda(s), Self::Cuda(indexes)) => {
904                let storage = s.gather(l, indexes, indexes_l, d)?;
905                Ok(Self::Cuda(storage))
906            }
907            (Self::Metal(s), Self::Metal(indexes)) => {
908                let storage = s.gather(l, indexes, indexes_l, d)?;
909                Ok(Self::Metal(storage))
910            }
911            #[cfg(feature = "rocm")]
912            (Self::Rocm(s), Self::Rocm(indexes)) => {
913                let storage = s.gather(l, indexes, indexes_l, d)?;
914                Ok(Self::Rocm(storage))
915            }
916            #[cfg(feature = "vulkan")]
917            (Self::Vulkan(s), Self::Vulkan(indexes)) => {
918                let storage = s.gather(l, indexes, indexes_l, d)?;
919                Ok(Self::Vulkan(storage))
920            }
921            _ => unreachable!(),
922        }
923    }
924
925    pub(crate) fn scatter_set(
926        &mut self,
927        l: &Layout,
928        indexes: &Self,
929        indexes_l: &Layout,
930        source: &Self,
931        source_l: &Layout,
932        d: usize,
933    ) -> Result<()> {
934        self.same_device(indexes, "scatter-set")?;
935        self.same_device(source, "scatter-set")?;
936        match (self, indexes, source) {
937            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
938                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
939            }
940            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
941                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
942            }
943            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
944                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
945            }
946            #[cfg(feature = "rocm")]
947            (Self::Rocm(s), Self::Rocm(indexes), Self::Rocm(source)) => {
948                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
949            }
950            #[cfg(feature = "vulkan")]
951            (Self::Vulkan(s), Self::Vulkan(indexes), Self::Vulkan(source)) => {
952                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
953            }
954            _ => unreachable!(),
955        }
956        Ok(())
957    }
958
959    pub(crate) fn scatter_add(
960        &mut self,
961        l: &Layout,
962        indexes: &Self,
963        indexes_l: &Layout,
964        source: &Self,
965        source_l: &Layout,
966        d: usize,
967    ) -> Result<()> {
968        self.same_device(indexes, "scatter-add")?;
969        self.same_device(source, "scatter-add")?;
970        match (self, indexes, source) {
971            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
972                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
973            }
974            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
975                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
976            }
977            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
978                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
979            }
980            #[cfg(feature = "rocm")]
981            (Self::Rocm(s), Self::Rocm(indexes), Self::Rocm(source)) => {
982                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
983            }
984            #[cfg(feature = "vulkan")]
985            (Self::Vulkan(s), Self::Vulkan(indexes), Self::Vulkan(source)) => {
986                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
987            }
988            _ => unreachable!(),
989        }
990        Ok(())
991    }
992
993    pub(crate) fn index_add(
994        &self,
995        l: &Layout,
996        indexes: &Self,
997        indexes_l: &Layout,
998        source: &Self,
999        source_l: &Layout,
1000        d: usize,
1001    ) -> Result<Self> {
1002        self.same_device(indexes, "index-add")?;
1003        self.same_device(source, "index-add")?;
1004        match (self, indexes, source) {
1005            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
1006                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1007                Ok(Self::Cpu(storage))
1008            }
1009            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
1010                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1011                Ok(Self::Cuda(storage))
1012            }
1013            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
1014                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1015                Ok(Self::Metal(storage))
1016            }
1017            #[cfg(feature = "rocm")]
1018            (Self::Rocm(s), Self::Rocm(indexes), Self::Rocm(source)) => {
1019                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1020                Ok(Self::Rocm(storage))
1021            }
1022            #[cfg(feature = "vulkan")]
1023            (Self::Vulkan(s), Self::Vulkan(indexes), Self::Vulkan(source)) => {
1024                let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1025                Ok(Self::Vulkan(storage))
1026            }
1027            _ => unreachable!(),
1028        }
1029    }
1030
1031    pub(crate) fn index_select(
1032        &self,
1033        rhs: &Self,
1034        lhs_l: &Layout,
1035        rhs_l: &Layout,
1036        d: usize,
1037    ) -> Result<Self> {
1038        self.same_device(rhs, "index-select")?;
1039        match (self, rhs) {
1040            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
1041                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1042                Ok(Self::Cpu(storage))
1043            }
1044            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
1045                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1046                Ok(Self::Cuda(storage))
1047            }
1048            (Self::Metal(lhs), Self::Metal(rhs)) => {
1049                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1050                Ok(Self::Metal(storage))
1051            }
1052            #[cfg(feature = "rocm")]
1053            (Self::Rocm(lhs), Self::Rocm(rhs)) => {
1054                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1055                Ok(Self::Rocm(storage))
1056            }
1057            #[cfg(feature = "vulkan")]
1058            (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
1059                let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1060                Ok(Self::Vulkan(storage))
1061            }
1062            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1063                lhs: lhs.device().location(),
1064                rhs: rhs.device().location(),
1065                op: "index-select",
1066            }
1067            .bt()),
1068        }
1069    }
1070
1071    pub(crate) fn matmul(
1072        &self,
1073        rhs: &Self,
1074        bmnk: (usize, usize, usize, usize),
1075        lhs_layout: &Layout,
1076        rhs_layout: &Layout,
1077    ) -> Result<Self> {
1078        self.same_device(rhs, "matmul")?;
1079        self.same_dtype(rhs, "matmul")?;
1080        match (self, rhs) {
1081            (Self::Cpu(lhs), Self::Cpu(rhs)) => {
1082                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1083                Ok(Self::Cpu(storage))
1084            }
1085            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
1086                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1087                Ok(Self::Cuda(storage))
1088            }
1089            (Self::Metal(lhs), Self::Metal(rhs)) => {
1090                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1091                Ok(Self::Metal(storage))
1092            }
1093            #[cfg(feature = "rocm")]
1094            (Self::Rocm(lhs), Self::Rocm(rhs)) => {
1095                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1096                Ok(Self::Rocm(storage))
1097            }
1098            #[cfg(feature = "vulkan")]
1099            (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
1100                let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1101                Ok(Self::Vulkan(storage))
1102            }
1103            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1104                lhs: lhs.device().location(),
1105                rhs: rhs.device().location(),
1106                op: "matmul",
1107            }
1108            .bt()),
1109        }
1110    }
1111
1112    // self, the source can be strided whereas dst is contiguous.
1113    pub(crate) fn copy_strided_src(
1114        &self,
1115        dst: &mut Self,
1116        dst_offset: usize,
1117        src_l: &Layout,
1118    ) -> Result<()> {
1119        match (self, dst) {
1120            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
1121            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
1122            (Self::Metal(src), Self::Metal(dst)) => {
1123                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
1124            }
1125            #[cfg(feature = "rocm")]
1126            (Self::Rocm(src), Self::Rocm(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
1127            #[cfg(feature = "vulkan")]
1128            (Self::Vulkan(src), Self::Vulkan(dst)) => {
1129                Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
1130            }
1131            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1132                lhs: lhs.device().location(),
1133                rhs: rhs.device().location(),
1134                op: "copy",
1135            }
1136            .bt()),
1137        }
1138    }
1139
1140    #[allow(clippy::too_many_arguments)]
1141    pub(crate) fn copy2d(
1142        &self,
1143        dst: &mut Self,
1144        d1: usize,
1145        d2: usize,
1146        src_s: usize,
1147        dst_s: usize,
1148        src_o: usize,
1149        dst_o: usize,
1150    ) -> Result<()> {
1151        match (self, dst) {
1152            (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
1153            (Self::Cuda(src), Self::Cuda(dst)) => {
1154                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1155            }
1156            (Self::Metal(src), Self::Metal(dst)) => {
1157                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1158            }
1159            #[cfg(feature = "rocm")]
1160            (Self::Rocm(src), Self::Rocm(dst)) => {
1161                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1162            }
1163            #[cfg(feature = "vulkan")]
1164            (Self::Vulkan(src), Self::Vulkan(dst)) => {
1165                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1166            }
1167            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1168                lhs: lhs.device().location(),
1169                rhs: rhs.device().location(),
1170                op: "copy2d",
1171            }
1172            .bt()),
1173        }
1174    }
1175}