1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::scalar::Scalar;
4#[cfg(feature = "rocm")]
5use crate::RocmStorage;
6#[cfg(feature = "vulkan")]
7use crate::VulkanStorage;
8use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
9use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
10
/// Backend-tagged storage: each variant wraps the storage type of one backend.
///
/// The `Rocm` and `Vulkan` variants only exist when the cargo feature of the same
/// name is enabled.
#[derive(Debug)]
pub enum Storage {
    /// Storage held by the CPU backend.
    Cpu(CpuStorage),
    /// Storage held by the CUDA backend.
    Cuda(CudaStorage),
    /// Storage held by the Metal backend.
    Metal(MetalStorage),
    /// Storage held by the ROCm backend (feature `rocm`).
    #[cfg(feature = "rocm")]
    Rocm(RocmStorage),
    /// Storage held by the Vulkan backend (feature `vulkan`).
    #[cfg(feature = "vulkan")]
    Vulkan(VulkanStorage),
}
23
24impl Storage {
25 pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
26 match self {
27 Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
28 Self::Cuda(storage) => {
29 let storage = storage.try_clone(layout)?;
30 Ok(Self::Cuda(storage))
31 }
32 Self::Metal(storage) => {
33 let storage = storage.try_clone(layout)?;
34 Ok(Self::Metal(storage))
35 }
36 #[cfg(feature = "rocm")]
37 Self::Rocm(storage) => {
38 let storage = storage.try_clone(layout)?;
39 Ok(Self::Rocm(storage))
40 }
41 #[cfg(feature = "vulkan")]
42 Self::Vulkan(storage) => {
43 let storage = storage.try_clone(layout)?;
44 Ok(Self::Vulkan(storage))
45 }
46 }
47 }
48
49 pub fn device(&self) -> Device {
50 match self {
51 Self::Cpu(_) => Device::Cpu,
52 Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
53 Self::Metal(storage) => Device::Metal(storage.device().clone()),
54 #[cfg(feature = "rocm")]
55 Self::Rocm(storage) => Device::Rocm(storage.device().clone()),
56 #[cfg(feature = "vulkan")]
57 Self::Vulkan(storage) => Device::Vulkan(storage.device().clone()),
58 }
59 }
60
61 pub fn dtype(&self) -> DType {
62 match self {
63 Self::Cpu(storage) => storage.dtype(),
64 Self::Cuda(storage) => storage.dtype(),
65 Self::Metal(storage) => storage.dtype(),
66 #[cfg(feature = "rocm")]
67 Self::Rocm(storage) => storage.dtype(),
68 #[cfg(feature = "vulkan")]
69 Self::Vulkan(storage) => storage.dtype(),
70 }
71 }
72
73 pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
74 let lhs_device = self.device();
75 let rhs_device = rhs.device();
76 let lhs = lhs_device.location();
77 let rhs = rhs_device.location();
78 let same_device = if self.device().is_metal() {
79 lhs_device.same_device(&rhs_device)
83 } else {
84 lhs == rhs
85 };
86 if !same_device {
87 Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
88 } else {
89 Ok(())
90 }
91 }
92
93 pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
94 let lhs = self.dtype();
95 let rhs = rhs.dtype();
96 if lhs != rhs {
97 Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
98 } else {
99 Ok(())
100 }
101 }
102
103 pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
104 match self {
105 Storage::Cpu(storage) => storage.const_set(v, l),
106 Storage::Cuda(storage) => storage.const_set(v, l),
107 Storage::Metal(storage) => storage.const_set(v, l),
108 #[cfg(feature = "rocm")]
109 Storage::Rocm(storage) => storage.const_set(v, l),
110 #[cfg(feature = "vulkan")]
111 Storage::Vulkan(storage) => storage.const_set(v, l),
112 }
113 }
114
115 pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
116 match self {
117 Storage::Cpu(storage) => {
118 let storage = storage.affine(layout, mul, add)?;
119 Ok(Self::Cpu(storage))
120 }
121 Self::Cuda(storage) => {
122 let storage = storage.affine(layout, mul, add)?;
123 Ok(Self::Cuda(storage))
124 }
125 Self::Metal(storage) => {
126 let storage = storage.affine(layout, mul, add)?;
127 Ok(Self::Metal(storage))
128 }
129 #[cfg(feature = "rocm")]
130 Self::Rocm(storage) => {
131 let storage = storage.affine(layout, mul, add)?;
132 Ok(Self::Rocm(storage))
133 }
134 #[cfg(feature = "vulkan")]
135 Self::Vulkan(storage) => {
136 let storage = storage.affine(layout, mul, add)?;
137 Ok(Self::Vulkan(storage))
138 }
139 }
140 }
141
142 pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
143 match self {
144 Storage::Cpu(storage) => {
145 let storage = storage.powf(layout, alpha)?;
146 Ok(Self::Cpu(storage))
147 }
148 Self::Cuda(storage) => {
149 let storage = storage.powf(layout, alpha)?;
150 Ok(Self::Cuda(storage))
151 }
152 Self::Metal(storage) => {
153 let storage = storage.powf(layout, alpha)?;
154 Ok(Self::Metal(storage))
155 }
156 #[cfg(feature = "rocm")]
157 Self::Rocm(storage) => {
158 let storage = storage.powf(layout, alpha)?;
159 Ok(Self::Rocm(storage))
160 }
161 #[cfg(feature = "vulkan")]
162 Self::Vulkan(storage) => {
163 let storage = storage.powf(layout, alpha)?;
164 Ok(Self::Vulkan(storage))
165 }
166 }
167 }
168
169 pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
170 match self {
171 Storage::Cpu(storage) => {
172 let storage = storage.elu(layout, alpha)?;
173 Ok(Self::Cpu(storage))
174 }
175 Self::Cuda(storage) => {
176 let storage = storage.elu(layout, alpha)?;
177 Ok(Self::Cuda(storage))
178 }
179 Self::Metal(storage) => {
180 let storage = storage.elu(layout, alpha)?;
181 Ok(Self::Metal(storage))
182 }
183 #[cfg(feature = "rocm")]
184 Self::Rocm(storage) => {
185 let storage = storage.elu(layout, alpha)?;
186 Ok(Self::Rocm(storage))
187 }
188 #[cfg(feature = "vulkan")]
189 Self::Vulkan(storage) => {
190 let storage = storage.elu(layout, alpha)?;
191 Ok(Self::Vulkan(storage))
192 }
193 }
194 }
195
196 pub(crate) fn cmp(
197 &self,
198 op: CmpOp,
199 rhs: &Self,
200 lhs_layout: &Layout,
201 rhs_layout: &Layout,
202 ) -> Result<Self> {
203 self.same_device(rhs, "cmp")?;
204 self.same_dtype(rhs, "cmp")?;
205 match (self, rhs) {
206 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
207 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
208 Ok(Self::Cpu(storage))
209 }
210 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
211 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
212 Ok(Self::Cuda(storage))
213 }
214 (Self::Metal(lhs), Self::Metal(rhs)) => {
215 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
216 Ok(Self::Metal(storage))
217 }
218 #[cfg(feature = "rocm")]
219 (Self::Rocm(lhs), Self::Rocm(rhs)) => {
220 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
221 Ok(Self::Rocm(storage))
222 }
223 #[cfg(feature = "vulkan")]
224 (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
225 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
226 Ok(Self::Vulkan(storage))
227 }
228 (lhs, rhs) => {
229 Err(Error::DeviceMismatchBinaryOp {
232 lhs: lhs.device().location(),
233 rhs: rhs.device().location(),
234 op: "cmp",
235 }
236 .bt())
237 }
238 }
239 }
240
241 pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
242 match self {
243 Storage::Cpu(storage) => {
244 let storage = storage.reduce_op(op, layout, s)?;
245 Ok(Self::Cpu(storage))
246 }
247 Self::Cuda(storage) => {
248 let storage = storage.reduce_op(op, layout, s)?;
249 Ok(Self::Cuda(storage))
250 }
251 Self::Metal(storage) => {
252 let storage = storage.reduce_op(op, layout, s)?;
253 Ok(Self::Metal(storage))
254 }
255 #[cfg(feature = "rocm")]
256 Self::Rocm(storage) => {
257 let storage = storage.reduce_op(op, layout, s)?;
258 Ok(Self::Rocm(storage))
259 }
260 #[cfg(feature = "vulkan")]
261 Self::Vulkan(storage) => {
262 let storage = storage.reduce_op(op, layout, s)?;
263 Ok(Self::Vulkan(storage))
264 }
265 }
266 }
267
268 pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
269 match self {
270 Storage::Cpu(storage) => {
271 let storage = storage.to_dtype(layout, dtype)?;
272 Ok(Self::Cpu(storage))
273 }
274 Self::Cuda(storage) => {
275 let storage = storage.to_dtype(layout, dtype)?;
276 Ok(Self::Cuda(storage))
277 }
278 Self::Metal(storage) => {
279 let storage = storage.to_dtype(layout, dtype)?;
280 Ok(Self::Metal(storage))
281 }
282 #[cfg(feature = "rocm")]
283 Self::Rocm(storage) => {
284 let storage = storage.to_dtype(layout, dtype)?;
285 Ok(Self::Rocm(storage))
286 }
287 #[cfg(feature = "vulkan")]
288 Self::Vulkan(storage) => {
289 let storage = storage.to_dtype(layout, dtype)?;
290 Ok(Self::Vulkan(storage))
291 }
292 }
293 }
294
295 pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
296 match self {
297 Self::Cpu(storage) => {
298 let (storage, shape) = c.cpu_fwd(storage, l)?;
299 Ok((Self::Cpu(storage), shape))
300 }
301 Self::Cuda(storage) => {
302 let (storage, shape) = c.cuda_fwd(storage, l)?;
303 Ok((Self::Cuda(storage), shape))
304 }
305 Self::Metal(storage) => {
306 let (storage, shape) = c.metal_fwd(storage, l)?;
307 Ok((Self::Metal(storage), shape))
308 }
309 #[cfg(feature = "rocm")]
310 Self::Rocm(storage) => {
311 let (storage, shape) = c.rocm_fwd(storage, l)?;
312 Ok((Self::Rocm(storage), shape))
313 }
314 #[cfg(feature = "vulkan")]
315 Self::Vulkan(storage) => {
316 let (storage, shape) = c.vulkan_fwd(storage, l)?;
317 Ok((Self::Vulkan(storage), shape))
318 }
319 }
320 }
321
322 pub(crate) fn apply_op2(
323 &self,
324 l1: &Layout,
325 t2: &Self,
326 l2: &Layout,
327 c: &dyn CustomOp2,
328 ) -> Result<(Self, Shape)> {
329 self.same_device(t2, c.name())?;
330 match (self, t2) {
331 (Self::Cpu(s1), Self::Cpu(s2)) => {
332 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
333 Ok((Self::Cpu(s), shape))
334 }
335 (Self::Cuda(s1), Self::Cuda(s2)) => {
336 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
337 Ok((Self::Cuda(s), shape))
338 }
339 (Self::Metal(s1), Self::Metal(s2)) => {
340 let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
341 Ok((Self::Metal(s), shape))
342 }
343 #[cfg(feature = "rocm")]
344 (Self::Rocm(s1), Self::Rocm(s2)) => {
345 let (s, shape) = c.rocm_fwd(s1, l1, s2, l2)?;
346 Ok((Self::Rocm(s), shape))
347 }
348 #[cfg(feature = "vulkan")]
349 (Self::Vulkan(s1), Self::Vulkan(s2)) => {
350 let (s, shape) = c.vulkan_fwd(s1, l1, s2, l2)?;
351 Ok((Self::Vulkan(s), shape))
352 }
353 _ => unreachable!(),
354 }
355 }
356
357 pub(crate) fn apply_op3(
358 &self,
359 l1: &Layout,
360 t2: &Self,
361 l2: &Layout,
362 t3: &Self,
363 l3: &Layout,
364 c: &dyn CustomOp3,
365 ) -> Result<(Self, Shape)> {
366 self.same_device(t2, c.name())?;
367 self.same_device(t3, c.name())?;
368 match (self, t2, t3) {
369 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
370 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
371 Ok((Self::Cpu(s), shape))
372 }
373 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
374 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
375 Ok((Self::Cuda(s), shape))
376 }
377 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
378 let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
379 Ok((Self::Metal(s), shape))
380 }
381 #[cfg(feature = "rocm")]
382 (Self::Rocm(s1), Self::Rocm(s2), Self::Rocm(s3)) => {
383 let (s, shape) = c.rocm_fwd(s1, l1, s2, l2, s3, l3)?;
384 Ok((Self::Rocm(s), shape))
385 }
386 #[cfg(feature = "vulkan")]
387 (Self::Vulkan(s1), Self::Vulkan(s2), Self::Vulkan(s3)) => {
388 let (s, shape) = c.vulkan_fwd(s1, l1, s2, l2, s3, l3)?;
389 Ok((Self::Vulkan(s), shape))
390 }
391 _ => unreachable!(),
392 }
393 }
394
395 pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
396 match self {
397 Self::Cpu(storage) => c.cpu_fwd(storage, l),
398 Self::Cuda(storage) => c.cuda_fwd(storage, l),
399 Self::Metal(storage) => c.metal_fwd(storage, l),
400 #[cfg(feature = "rocm")]
401 Self::Rocm(storage) => c.rocm_fwd(storage, l),
402 #[cfg(feature = "vulkan")]
403 Self::Vulkan(storage) => c.vulkan_fwd(storage, l),
404 }
405 }
406
407 pub(crate) fn inplace_op2(
408 &mut self,
409 l1: &Layout,
410 t2: &Self,
411 l2: &Layout,
412 c: &dyn InplaceOp2,
413 ) -> Result<()> {
414 self.same_device(t2, c.name())?;
415 match (self, t2) {
416 (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
417 (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
418 (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
419 #[cfg(feature = "rocm")]
420 (Self::Rocm(s1), Self::Rocm(s2)) => c.rocm_fwd(s1, l1, s2, l2),
421 #[cfg(feature = "vulkan")]
422 (Self::Vulkan(s1), Self::Vulkan(s2)) => c.vulkan_fwd(s1, l1, s2, l2),
423 _ => unreachable!(),
424 }
425 }
426
427 pub(crate) fn inplace_op3(
428 &mut self,
429 l1: &Layout,
430 t2: &Self,
431 l2: &Layout,
432 t3: &Self,
433 l3: &Layout,
434 c: &dyn InplaceOp3,
435 ) -> Result<()> {
436 self.same_device(t2, c.name())?;
437 self.same_device(t3, c.name())?;
438 match (self, t2, t3) {
439 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
440 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
441 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
442 c.metal_fwd(s1, l1, s2, l2, s3, l3)
443 }
444 #[cfg(feature = "rocm")]
445 (Self::Rocm(s1), Self::Rocm(s2), Self::Rocm(s3)) => c.rocm_fwd(s1, l1, s2, l2, s3, l3),
446 #[cfg(feature = "vulkan")]
447 (Self::Vulkan(s1), Self::Vulkan(s2), Self::Vulkan(s3)) => {
448 c.vulkan_fwd(s1, l1, s2, l2, s3, l3)
449 }
450 _ => unreachable!(),
451 }
452 }
453
454 pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
455 match self {
456 Storage::Cpu(storage) => {
457 let storage = storage.unary_impl::<B>(layout)?;
458 Ok(Self::Cpu(storage))
459 }
460 Self::Cuda(storage) => {
461 let storage = storage.unary_impl::<B>(layout)?;
462 Ok(Self::Cuda(storage))
463 }
464 Self::Metal(storage) => {
465 let storage = storage.unary_impl::<B>(layout)?;
466 Ok(Self::Metal(storage))
467 }
468 #[cfg(feature = "rocm")]
469 Self::Rocm(storage) => {
470 let storage = storage.unary_impl::<B>(layout)?;
471 Ok(Self::Rocm(storage))
472 }
473 #[cfg(feature = "vulkan")]
474 Self::Vulkan(storage) => {
475 let storage = storage.unary_impl::<B>(layout)?;
476 Ok(Self::Vulkan(storage))
477 }
478 }
479 }
480
481 pub(crate) fn binary_impl<B: op::BinaryOpT>(
482 &self,
483 rhs: &Self,
484 lhs_layout: &Layout,
485 rhs_layout: &Layout,
486 ) -> Result<Self> {
487 self.same_device(rhs, B::NAME)?;
488 self.same_dtype(rhs, B::NAME)?;
489 match (self, rhs) {
490 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
491 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
492 Ok(Self::Cpu(storage))
493 }
494 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
495 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
496 Ok(Self::Cuda(storage))
497 }
498 (Self::Metal(lhs), Self::Metal(rhs)) => {
499 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
500 Ok(Self::Metal(storage))
501 }
502 #[cfg(feature = "rocm")]
503 (Self::Rocm(lhs), Self::Rocm(rhs)) => {
504 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
505 Ok(Self::Rocm(storage))
506 }
507 #[cfg(feature = "vulkan")]
508 (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
509 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
510 Ok(Self::Vulkan(storage))
511 }
512 (lhs, rhs) => {
513 Err(Error::DeviceMismatchBinaryOp {
516 lhs: lhs.device().location(),
517 rhs: rhs.device().location(),
518 op: B::NAME,
519 }
520 .bt())
521 }
522 }
523 }
524
525 pub(crate) fn conv1d(
526 &self,
527 l: &Layout,
528 kernel: &Self,
529 kernel_l: &Layout,
530 params: &crate::conv::ParamsConv1D,
531 ) -> Result<Self> {
532 self.same_device(kernel, "conv1d")?;
533 self.same_dtype(kernel, "conv1d")?;
534 match (self, &kernel) {
535 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
536 let s = inp.conv1d(l, kernel, kernel_l, params)?;
537 Ok(Self::Cpu(s))
538 }
539 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
540 let s = inp.conv1d(l, kernel, kernel_l, params)?;
541 Ok(Self::Cuda(s))
542 }
543 (Storage::Metal(inp), Storage::Metal(kernel)) => {
544 let s = inp.conv1d(l, kernel, kernel_l, params)?;
545 Ok(Self::Metal(s))
546 }
547 #[cfg(feature = "rocm")]
548 (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
549 let s = inp.conv1d(l, kernel, kernel_l, params)?;
550 Ok(Self::Rocm(s))
551 }
552 #[cfg(feature = "vulkan")]
553 (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
554 let s = inp.conv1d(l, kernel, kernel_l, params)?;
555 Ok(Self::Vulkan(s))
556 }
557 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
558 lhs: lhs.device().location(),
559 rhs: rhs.device().location(),
560 op: "conv1d",
561 }
562 .bt()),
563 }
564 }
565
566 pub(crate) fn conv_transpose1d(
567 &self,
568 l: &Layout,
569 kernel: &Self,
570 kernel_l: &Layout,
571 params: &crate::conv::ParamsConvTranspose1D,
572 ) -> Result<Self> {
573 self.same_device(kernel, "conv-transpose1d")?;
574 self.same_dtype(kernel, "conv-transpose1d")?;
575 match (self, &kernel) {
576 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
577 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
578 Ok(Self::Cpu(s))
579 }
580 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
581 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
582 Ok(Self::Cuda(s))
583 }
584 (Storage::Metal(inp), Storage::Metal(kernel)) => {
585 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
586 Ok(Self::Metal(s))
587 }
588 #[cfg(feature = "rocm")]
589 (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
590 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
591 Ok(Self::Rocm(s))
592 }
593 #[cfg(feature = "vulkan")]
594 (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
595 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
596 Ok(Self::Vulkan(s))
597 }
598 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
599 lhs: lhs.device().location(),
600 rhs: rhs.device().location(),
601 op: "conv-transpose1d",
602 }
603 .bt()),
604 }
605 }
606
607 pub(crate) fn conv2d(
608 &self,
609 l: &Layout,
610 kernel: &Self,
611 kernel_l: &Layout,
612 params: &crate::conv::ParamsConv2D,
613 ) -> Result<Self> {
614 self.same_device(kernel, "conv2d")?;
615 self.same_dtype(kernel, "conv2d")?;
616 match (self, &kernel) {
617 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
618 let s = inp.conv2d(l, kernel, kernel_l, params)?;
619 Ok(Self::Cpu(s))
620 }
621 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
622 let s = inp.conv2d(l, kernel, kernel_l, params)?;
623 Ok(Self::Cuda(s))
624 }
625 (Storage::Metal(inp), Storage::Metal(kernel)) => {
626 let s = inp.conv2d(l, kernel, kernel_l, params)?;
627 Ok(Self::Metal(s))
628 }
629 #[cfg(feature = "rocm")]
630 (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
631 let s = inp.conv2d(l, kernel, kernel_l, params)?;
632 Ok(Self::Rocm(s))
633 }
634 #[cfg(feature = "vulkan")]
635 (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
636 let s = inp.conv2d(l, kernel, kernel_l, params)?;
637 Ok(Self::Vulkan(s))
638 }
639 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
640 lhs: lhs.device().location(),
641 rhs: rhs.device().location(),
642 op: "conv2d",
643 }
644 .bt()),
645 }
646 }
647
648 pub(crate) fn conv_transpose2d(
649 &self,
650 l: &Layout,
651 kernel: &Self,
652 kernel_l: &Layout,
653 params: &crate::conv::ParamsConvTranspose2D,
654 ) -> Result<Self> {
655 self.same_device(kernel, "conv_transpose2d")?;
656 self.same_dtype(kernel, "conv_transpose2d")?;
657 match (self, &kernel) {
658 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
659 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
660 Ok(Self::Cpu(s))
661 }
662 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
663 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
664 Ok(Self::Cuda(s))
665 }
666 (Storage::Metal(inp), Storage::Metal(kernel)) => {
667 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
668 Ok(Self::Metal(s))
669 }
670 #[cfg(feature = "rocm")]
671 (Storage::Rocm(inp), Storage::Rocm(kernel)) => {
672 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
673 Ok(Self::Rocm(s))
674 }
675 #[cfg(feature = "vulkan")]
676 (Storage::Vulkan(inp), Storage::Vulkan(kernel)) => {
677 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
678 Ok(Self::Vulkan(s))
679 }
680 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
681 lhs: lhs.device().location(),
682 rhs: rhs.device().location(),
683 op: "conv_transpose2d",
684 }
685 .bt()),
686 }
687 }
688
689 pub(crate) fn avg_pool2d(
690 &self,
691 layout: &Layout,
692 kernel_size: (usize, usize),
693 stride: (usize, usize),
694 ) -> Result<Self> {
695 match self {
696 Storage::Cpu(storage) => {
697 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
698 Ok(Self::Cpu(storage))
699 }
700 Self::Cuda(storage) => {
701 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
702 Ok(Self::Cuda(storage))
703 }
704 Self::Metal(storage) => {
705 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
706 Ok(Self::Metal(storage))
707 }
708 #[cfg(feature = "rocm")]
709 Self::Rocm(storage) => {
710 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
711 Ok(Self::Rocm(storage))
712 }
713 #[cfg(feature = "vulkan")]
714 Self::Vulkan(storage) => {
715 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
716 Ok(Self::Vulkan(storage))
717 }
718 }
719 }
720
721 pub(crate) fn max_pool2d(
722 &self,
723 layout: &Layout,
724 kernel_size: (usize, usize),
725 stride: (usize, usize),
726 ) -> Result<Self> {
727 match self {
728 Storage::Cpu(storage) => {
729 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
730 Ok(Self::Cpu(storage))
731 }
732 Self::Cuda(storage) => {
733 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
734 Ok(Self::Cuda(storage))
735 }
736 Self::Metal(storage) => {
737 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
738 Ok(Self::Metal(storage))
739 }
740 #[cfg(feature = "rocm")]
741 Self::Rocm(storage) => {
742 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
743 Ok(Self::Rocm(storage))
744 }
745 #[cfg(feature = "vulkan")]
746 Self::Vulkan(storage) => {
747 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
748 Ok(Self::Vulkan(storage))
749 }
750 }
751 }
752
753 pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
754 match self {
755 Storage::Cpu(storage) => {
756 let storage = storage.upsample_nearest1d(layout, sz)?;
757 Ok(Self::Cpu(storage))
758 }
759 Self::Cuda(storage) => {
760 let storage = storage.upsample_nearest1d(layout, sz)?;
761 Ok(Self::Cuda(storage))
762 }
763 Self::Metal(storage) => {
764 let storage = storage.upsample_nearest1d(layout, sz)?;
765 Ok(Self::Metal(storage))
766 }
767 #[cfg(feature = "rocm")]
768 Self::Rocm(storage) => {
769 let storage = storage.upsample_nearest1d(layout, sz)?;
770 Ok(Self::Rocm(storage))
771 }
772 #[cfg(feature = "vulkan")]
773 Self::Vulkan(storage) => {
774 let storage = storage.upsample_nearest1d(layout, sz)?;
775 Ok(Self::Vulkan(storage))
776 }
777 }
778 }
779
780 pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
781 match self {
782 Storage::Cpu(storage) => {
783 let storage = storage.upsample_nearest2d(layout, h, w)?;
784 Ok(Self::Cpu(storage))
785 }
786 Self::Cuda(storage) => {
787 let storage = storage.upsample_nearest2d(layout, h, w)?;
788 Ok(Self::Cuda(storage))
789 }
790 Self::Metal(storage) => {
791 let storage = storage.upsample_nearest2d(layout, h, w)?;
792 Ok(Self::Metal(storage))
793 }
794 #[cfg(feature = "rocm")]
795 Self::Rocm(storage) => {
796 let storage = storage.upsample_nearest2d(layout, h, w)?;
797 Ok(Self::Rocm(storage))
798 }
799 #[cfg(feature = "vulkan")]
800 Self::Vulkan(storage) => {
801 let storage = storage.upsample_nearest2d(layout, h, w)?;
802 Ok(Self::Vulkan(storage))
803 }
804 }
805 }
806
807 pub(crate) fn upsample_bilinear2d(
808 &self,
809 layout: &Layout,
810 h: usize,
811 w: usize,
812 align_corners: bool,
813 scale_h: Option<f64>,
814 scale_w: Option<f64>,
815 ) -> Result<Self> {
816 match self {
817 Storage::Cpu(storage) => {
818 let storage =
819 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
820 Ok(Self::Cpu(storage))
821 }
822 Self::Cuda(storage) => {
823 let storage =
824 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
825 Ok(Self::Cuda(storage))
826 }
827 Self::Metal(storage) => {
828 let storage =
829 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
830 Ok(Self::Metal(storage))
831 }
832 #[cfg(feature = "rocm")]
833 Self::Rocm(storage) => {
834 let storage =
835 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
836 Ok(Self::Rocm(storage))
837 }
838 #[cfg(feature = "vulkan")]
839 Self::Vulkan(storage) => {
840 let storage =
841 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
842 Ok(Self::Vulkan(storage))
843 }
844 }
845 }
846
847 pub(crate) fn where_cond(
848 &self,
849 layout: &Layout,
850 t: &Self,
851 layout_t: &Layout,
852 f: &Self,
853 layout_f: &Layout,
854 ) -> Result<Self> {
855 self.same_device(t, "where")?;
856 self.same_device(f, "where")?;
857 t.same_dtype(f, "where")?;
858 match (self, t, f) {
859 (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
860 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
861 Ok(Self::Cpu(storage))
862 }
863 (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
864 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
865 Ok(Self::Cuda(storage))
866 }
867 (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
868 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
869 Ok(Self::Metal(storage))
870 }
871 #[cfg(feature = "rocm")]
872 (Self::Rocm(cond), Self::Rocm(t), Self::Rocm(f)) => {
873 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
874 Ok(Self::Rocm(storage))
875 }
876 #[cfg(feature = "vulkan")]
877 (Self::Vulkan(cond), Self::Vulkan(t), Self::Vulkan(f)) => {
878 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
879 Ok(Self::Vulkan(storage))
880 }
881 (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
882 lhs: lhs.device().location(),
883 rhs: rhs.device().location(),
884 op: "where",
885 }
886 .bt()),
887 }
888 }
889
890 pub(crate) fn gather(
891 &self,
892 l: &Layout,
893 indexes: &Self,
894 indexes_l: &Layout,
895 d: usize,
896 ) -> Result<Self> {
897 self.same_device(indexes, "index-add")?;
898 match (self, indexes) {
899 (Self::Cpu(s), Self::Cpu(indexes)) => {
900 let storage = s.gather(l, indexes, indexes_l, d)?;
901 Ok(Self::Cpu(storage))
902 }
903 (Self::Cuda(s), Self::Cuda(indexes)) => {
904 let storage = s.gather(l, indexes, indexes_l, d)?;
905 Ok(Self::Cuda(storage))
906 }
907 (Self::Metal(s), Self::Metal(indexes)) => {
908 let storage = s.gather(l, indexes, indexes_l, d)?;
909 Ok(Self::Metal(storage))
910 }
911 #[cfg(feature = "rocm")]
912 (Self::Rocm(s), Self::Rocm(indexes)) => {
913 let storage = s.gather(l, indexes, indexes_l, d)?;
914 Ok(Self::Rocm(storage))
915 }
916 #[cfg(feature = "vulkan")]
917 (Self::Vulkan(s), Self::Vulkan(indexes)) => {
918 let storage = s.gather(l, indexes, indexes_l, d)?;
919 Ok(Self::Vulkan(storage))
920 }
921 _ => unreachable!(),
922 }
923 }
924
925 pub(crate) fn scatter_set(
926 &mut self,
927 l: &Layout,
928 indexes: &Self,
929 indexes_l: &Layout,
930 source: &Self,
931 source_l: &Layout,
932 d: usize,
933 ) -> Result<()> {
934 self.same_device(indexes, "scatter-set")?;
935 self.same_device(source, "scatter-set")?;
936 match (self, indexes, source) {
937 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
938 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
939 }
940 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
941 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
942 }
943 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
944 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
945 }
946 #[cfg(feature = "rocm")]
947 (Self::Rocm(s), Self::Rocm(indexes), Self::Rocm(source)) => {
948 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
949 }
950 #[cfg(feature = "vulkan")]
951 (Self::Vulkan(s), Self::Vulkan(indexes), Self::Vulkan(source)) => {
952 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
953 }
954 _ => unreachable!(),
955 }
956 Ok(())
957 }
958
959 pub(crate) fn scatter_add(
960 &mut self,
961 l: &Layout,
962 indexes: &Self,
963 indexes_l: &Layout,
964 source: &Self,
965 source_l: &Layout,
966 d: usize,
967 ) -> Result<()> {
968 self.same_device(indexes, "scatter-add")?;
969 self.same_device(source, "scatter-add")?;
970 match (self, indexes, source) {
971 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
972 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
973 }
974 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
975 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
976 }
977 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
978 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
979 }
980 #[cfg(feature = "rocm")]
981 (Self::Rocm(s), Self::Rocm(indexes), Self::Rocm(source)) => {
982 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
983 }
984 #[cfg(feature = "vulkan")]
985 (Self::Vulkan(s), Self::Vulkan(indexes), Self::Vulkan(source)) => {
986 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
987 }
988 _ => unreachable!(),
989 }
990 Ok(())
991 }
992
993 pub(crate) fn index_add(
994 &self,
995 l: &Layout,
996 indexes: &Self,
997 indexes_l: &Layout,
998 source: &Self,
999 source_l: &Layout,
1000 d: usize,
1001 ) -> Result<Self> {
1002 self.same_device(indexes, "index-add")?;
1003 self.same_device(source, "index-add")?;
1004 match (self, indexes, source) {
1005 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
1006 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1007 Ok(Self::Cpu(storage))
1008 }
1009 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
1010 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1011 Ok(Self::Cuda(storage))
1012 }
1013 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
1014 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1015 Ok(Self::Metal(storage))
1016 }
1017 #[cfg(feature = "rocm")]
1018 (Self::Rocm(s), Self::Rocm(indexes), Self::Rocm(source)) => {
1019 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1020 Ok(Self::Rocm(storage))
1021 }
1022 #[cfg(feature = "vulkan")]
1023 (Self::Vulkan(s), Self::Vulkan(indexes), Self::Vulkan(source)) => {
1024 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
1025 Ok(Self::Vulkan(storage))
1026 }
1027 _ => unreachable!(),
1028 }
1029 }
1030
1031 pub(crate) fn index_select(
1032 &self,
1033 rhs: &Self,
1034 lhs_l: &Layout,
1035 rhs_l: &Layout,
1036 d: usize,
1037 ) -> Result<Self> {
1038 self.same_device(rhs, "index-select")?;
1039 match (self, rhs) {
1040 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
1041 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1042 Ok(Self::Cpu(storage))
1043 }
1044 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
1045 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1046 Ok(Self::Cuda(storage))
1047 }
1048 (Self::Metal(lhs), Self::Metal(rhs)) => {
1049 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1050 Ok(Self::Metal(storage))
1051 }
1052 #[cfg(feature = "rocm")]
1053 (Self::Rocm(lhs), Self::Rocm(rhs)) => {
1054 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1055 Ok(Self::Rocm(storage))
1056 }
1057 #[cfg(feature = "vulkan")]
1058 (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
1059 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
1060 Ok(Self::Vulkan(storage))
1061 }
1062 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1063 lhs: lhs.device().location(),
1064 rhs: rhs.device().location(),
1065 op: "index-select",
1066 }
1067 .bt()),
1068 }
1069 }
1070
1071 pub(crate) fn matmul(
1072 &self,
1073 rhs: &Self,
1074 bmnk: (usize, usize, usize, usize),
1075 lhs_layout: &Layout,
1076 rhs_layout: &Layout,
1077 ) -> Result<Self> {
1078 self.same_device(rhs, "matmul")?;
1079 self.same_dtype(rhs, "matmul")?;
1080 match (self, rhs) {
1081 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
1082 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1083 Ok(Self::Cpu(storage))
1084 }
1085 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
1086 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1087 Ok(Self::Cuda(storage))
1088 }
1089 (Self::Metal(lhs), Self::Metal(rhs)) => {
1090 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1091 Ok(Self::Metal(storage))
1092 }
1093 #[cfg(feature = "rocm")]
1094 (Self::Rocm(lhs), Self::Rocm(rhs)) => {
1095 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1096 Ok(Self::Rocm(storage))
1097 }
1098 #[cfg(feature = "vulkan")]
1099 (Self::Vulkan(lhs), Self::Vulkan(rhs)) => {
1100 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
1101 Ok(Self::Vulkan(storage))
1102 }
1103 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1104 lhs: lhs.device().location(),
1105 rhs: rhs.device().location(),
1106 op: "matmul",
1107 }
1108 .bt()),
1109 }
1110 }
1111
1112 pub(crate) fn copy_strided_src(
1114 &self,
1115 dst: &mut Self,
1116 dst_offset: usize,
1117 src_l: &Layout,
1118 ) -> Result<()> {
1119 match (self, dst) {
1120 (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
1121 (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
1122 (Self::Metal(src), Self::Metal(dst)) => {
1123 Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
1124 }
1125 #[cfg(feature = "rocm")]
1126 (Self::Rocm(src), Self::Rocm(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
1127 #[cfg(feature = "vulkan")]
1128 (Self::Vulkan(src), Self::Vulkan(dst)) => {
1129 Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
1130 }
1131 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1132 lhs: lhs.device().location(),
1133 rhs: rhs.device().location(),
1134 op: "copy",
1135 }
1136 .bt()),
1137 }
1138 }
1139
1140 #[allow(clippy::too_many_arguments)]
1141 pub(crate) fn copy2d(
1142 &self,
1143 dst: &mut Self,
1144 d1: usize,
1145 d2: usize,
1146 src_s: usize,
1147 dst_s: usize,
1148 src_o: usize,
1149 dst_o: usize,
1150 ) -> Result<()> {
1151 match (self, dst) {
1152 (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
1153 (Self::Cuda(src), Self::Cuda(dst)) => {
1154 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1155 }
1156 (Self::Metal(src), Self::Metal(dst)) => {
1157 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1158 }
1159 #[cfg(feature = "rocm")]
1160 (Self::Rocm(src), Self::Rocm(dst)) => {
1161 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1162 }
1163 #[cfg(feature = "vulkan")]
1164 (Self::Vulkan(src), Self::Vulkan(dst)) => {
1165 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
1166 }
1167 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
1168 lhs: lhs.device().location(),
1169 rhs: rhs.device().location(),
1170 op: "copy2d",
1171 }
1172 .bt()),
1173 }
1174 }
1175}