1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
4use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
5
6#[derive(Debug)]
9pub enum Storage {
10 Cpu(CpuStorage),
11 Cuda(CudaStorage),
12 Metal(MetalStorage),
13}
14
15impl Storage {
16 pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
17 match self {
18 Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
19 Self::Cuda(storage) => {
20 let storage = storage.try_clone(layout)?;
21 Ok(Self::Cuda(storage))
22 }
23 Self::Metal(storage) => {
24 let storage = storage.try_clone(layout)?;
25 Ok(Self::Metal(storage))
26 }
27 }
28 }
29
30 pub fn device(&self) -> Device {
31 match self {
32 Self::Cpu(_) => Device::Cpu,
33 Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
34 Self::Metal(storage) => Device::Metal(storage.device().clone()),
35 }
36 }
37
38 pub fn dtype(&self) -> DType {
39 match self {
40 Self::Cpu(storage) => storage.dtype(),
41 Self::Cuda(storage) => storage.dtype(),
42 Self::Metal(storage) => storage.dtype(),
43 }
44 }
45
46 pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
47 let lhs_device = self.device();
48 let rhs_device = rhs.device();
49 let lhs = lhs_device.location();
50 let rhs = rhs_device.location();
51 let same_device = if self.device().is_metal() {
52 lhs_device.same_device(&rhs_device)
56 } else {
57 lhs == rhs
58 };
59 if !same_device {
60 Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
61 } else {
62 Ok(())
63 }
64 }
65
66 pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
67 let lhs = self.dtype();
68 let rhs = rhs.dtype();
69 if lhs != rhs {
70 Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
71 } else {
72 Ok(())
73 }
74 }
75
76 pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
77 match self {
78 Storage::Cpu(storage) => {
79 let storage = storage.affine(layout, mul, add)?;
80 Ok(Self::Cpu(storage))
81 }
82 Self::Cuda(storage) => {
83 let storage = storage.affine(layout, mul, add)?;
84 Ok(Self::Cuda(storage))
85 }
86 Self::Metal(storage) => {
87 let storage = storage.affine(layout, mul, add)?;
88 Ok(Self::Metal(storage))
89 }
90 }
91 }
92
93 pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
94 match self {
95 Storage::Cpu(storage) => {
96 let storage = storage.powf(layout, alpha)?;
97 Ok(Self::Cpu(storage))
98 }
99 Self::Cuda(storage) => {
100 let storage = storage.powf(layout, alpha)?;
101 Ok(Self::Cuda(storage))
102 }
103 Self::Metal(storage) => {
104 let storage = storage.powf(layout, alpha)?;
105 Ok(Self::Metal(storage))
106 }
107 }
108 }
109
110 pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
111 match self {
112 Storage::Cpu(storage) => {
113 let storage = storage.elu(layout, alpha)?;
114 Ok(Self::Cpu(storage))
115 }
116 Self::Cuda(storage) => {
117 let storage = storage.elu(layout, alpha)?;
118 Ok(Self::Cuda(storage))
119 }
120 Self::Metal(storage) => {
121 let storage = storage.elu(layout, alpha)?;
122 Ok(Self::Metal(storage))
123 }
124 }
125 }
126
127 pub(crate) fn cmp(
128 &self,
129 op: CmpOp,
130 rhs: &Self,
131 lhs_layout: &Layout,
132 rhs_layout: &Layout,
133 ) -> Result<Self> {
134 self.same_device(rhs, "cmp")?;
135 self.same_dtype(rhs, "cmp")?;
136 match (self, rhs) {
137 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
138 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
139 Ok(Self::Cpu(storage))
140 }
141 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
142 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
143 Ok(Self::Cuda(storage))
144 }
145 (Self::Metal(lhs), Self::Metal(rhs)) => {
146 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
147 Ok(Self::Metal(storage))
148 }
149 (lhs, rhs) => {
150 Err(Error::DeviceMismatchBinaryOp {
153 lhs: lhs.device().location(),
154 rhs: rhs.device().location(),
155 op: "cmp",
156 }
157 .bt())
158 }
159 }
160 }
161
162 pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
163 match self {
164 Storage::Cpu(storage) => {
165 let storage = storage.reduce_op(op, layout, s)?;
166 Ok(Self::Cpu(storage))
167 }
168 Self::Cuda(storage) => {
169 let storage = storage.reduce_op(op, layout, s)?;
170 Ok(Self::Cuda(storage))
171 }
172 Self::Metal(storage) => {
173 let storage = storage.reduce_op(op, layout, s)?;
174 Ok(Self::Metal(storage))
175 }
176 }
177 }
178
179 pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
180 match self {
181 Storage::Cpu(storage) => {
182 let storage = storage.to_dtype(layout, dtype)?;
183 Ok(Self::Cpu(storage))
184 }
185 Self::Cuda(storage) => {
186 let storage = storage.to_dtype(layout, dtype)?;
187 Ok(Self::Cuda(storage))
188 }
189 Self::Metal(storage) => {
190 let storage = storage.to_dtype(layout, dtype)?;
191 Ok(Self::Metal(storage))
192 }
193 }
194 }
195
196 pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
197 match self {
198 Self::Cpu(storage) => {
199 let (storage, shape) = c.cpu_fwd(storage, l)?;
200 Ok((Self::Cpu(storage), shape))
201 }
202 Self::Cuda(storage) => {
203 let (storage, shape) = c.cuda_fwd(storage, l)?;
204 Ok((Self::Cuda(storage), shape))
205 }
206 Self::Metal(storage) => {
207 let (storage, shape) = c.metal_fwd(storage, l)?;
208 Ok((Self::Metal(storage), shape))
209 }
210 }
211 }
212
213 pub(crate) fn apply_op2(
214 &self,
215 l1: &Layout,
216 t2: &Self,
217 l2: &Layout,
218 c: &dyn CustomOp2,
219 ) -> Result<(Self, Shape)> {
220 self.same_device(t2, c.name())?;
221 match (self, t2) {
222 (Self::Cpu(s1), Self::Cpu(s2)) => {
223 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
224 Ok((Self::Cpu(s), shape))
225 }
226 (Self::Cuda(s1), Self::Cuda(s2)) => {
227 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
228 Ok((Self::Cuda(s), shape))
229 }
230 (Self::Metal(s1), Self::Metal(s2)) => {
231 let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
232 Ok((Self::Metal(s), shape))
233 }
234 _ => unreachable!(),
235 }
236 }
237
238 pub(crate) fn apply_op3(
239 &self,
240 l1: &Layout,
241 t2: &Self,
242 l2: &Layout,
243 t3: &Self,
244 l3: &Layout,
245 c: &dyn CustomOp3,
246 ) -> Result<(Self, Shape)> {
247 self.same_device(t2, c.name())?;
248 self.same_device(t3, c.name())?;
249 match (self, t2, t3) {
250 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
251 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
252 Ok((Self::Cpu(s), shape))
253 }
254 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
255 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
256 Ok((Self::Cuda(s), shape))
257 }
258 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
259 let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
260 Ok((Self::Metal(s), shape))
261 }
262 _ => unreachable!(),
263 }
264 }
265
266 pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
267 match self {
268 Self::Cpu(storage) => c.cpu_fwd(storage, l),
269 Self::Cuda(storage) => c.cuda_fwd(storage, l),
270 Self::Metal(storage) => c.metal_fwd(storage, l),
271 }
272 }
273
274 pub(crate) fn inplace_op2(
275 &mut self,
276 l1: &Layout,
277 t2: &Self,
278 l2: &Layout,
279 c: &dyn InplaceOp2,
280 ) -> Result<()> {
281 self.same_device(t2, c.name())?;
282 match (self, t2) {
283 (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
284 (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
285 (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
286 _ => unreachable!(),
287 }
288 }
289
290 pub(crate) fn inplace_op3(
291 &mut self,
292 l1: &Layout,
293 t2: &Self,
294 l2: &Layout,
295 t3: &Self,
296 l3: &Layout,
297 c: &dyn InplaceOp3,
298 ) -> Result<()> {
299 self.same_device(t2, c.name())?;
300 self.same_device(t3, c.name())?;
301 match (self, t2, t3) {
302 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
303 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
304 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
305 c.metal_fwd(s1, l1, s2, l2, s3, l3)
306 }
307 _ => unreachable!(),
308 }
309 }
310
311 pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
312 match self {
313 Storage::Cpu(storage) => {
314 let storage = storage.unary_impl::<B>(layout)?;
315 Ok(Self::Cpu(storage))
316 }
317 Self::Cuda(storage) => {
318 let storage = storage.unary_impl::<B>(layout)?;
319 Ok(Self::Cuda(storage))
320 }
321 Self::Metal(storage) => {
322 let storage = storage.unary_impl::<B>(layout)?;
323 Ok(Self::Metal(storage))
324 }
325 }
326 }
327
328 pub(crate) fn binary_impl<B: op::BinaryOpT>(
329 &self,
330 rhs: &Self,
331 lhs_layout: &Layout,
332 rhs_layout: &Layout,
333 ) -> Result<Self> {
334 self.same_device(rhs, B::NAME)?;
335 self.same_dtype(rhs, B::NAME)?;
336 match (self, rhs) {
337 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
338 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
339 Ok(Self::Cpu(storage))
340 }
341 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
342 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
343 Ok(Self::Cuda(storage))
344 }
345 (Self::Metal(lhs), Self::Metal(rhs)) => {
346 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
347 Ok(Self::Metal(storage))
348 }
349 (lhs, rhs) => {
350 Err(Error::DeviceMismatchBinaryOp {
353 lhs: lhs.device().location(),
354 rhs: rhs.device().location(),
355 op: B::NAME,
356 }
357 .bt())
358 }
359 }
360 }
361
362 pub(crate) fn conv1d(
363 &self,
364 l: &Layout,
365 kernel: &Self,
366 kernel_l: &Layout,
367 params: &crate::conv::ParamsConv1D,
368 ) -> Result<Self> {
369 self.same_device(kernel, "conv1d")?;
370 self.same_dtype(kernel, "conv1d")?;
371 match (self, &kernel) {
372 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
373 let s = inp.conv1d(l, kernel, kernel_l, params)?;
374 Ok(Self::Cpu(s))
375 }
376 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
377 let s = inp.conv1d(l, kernel, kernel_l, params)?;
378 Ok(Self::Cuda(s))
379 }
380 (Storage::Metal(inp), Storage::Metal(kernel)) => {
381 let s = inp.conv1d(l, kernel, kernel_l, params)?;
382 Ok(Self::Metal(s))
383 }
384 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
385 lhs: lhs.device().location(),
386 rhs: rhs.device().location(),
387 op: "conv1d",
388 }
389 .bt()),
390 }
391 }
392
393 pub(crate) fn conv_transpose1d(
394 &self,
395 l: &Layout,
396 kernel: &Self,
397 kernel_l: &Layout,
398 params: &crate::conv::ParamsConvTranspose1D,
399 ) -> Result<Self> {
400 self.same_device(kernel, "conv-transpose1d")?;
401 self.same_dtype(kernel, "conv-transpose1d")?;
402 match (self, &kernel) {
403 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
404 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
405 Ok(Self::Cpu(s))
406 }
407 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
408 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
409 Ok(Self::Cuda(s))
410 }
411 (Storage::Metal(inp), Storage::Metal(kernel)) => {
412 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
413 Ok(Self::Metal(s))
414 }
415 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
416 lhs: lhs.device().location(),
417 rhs: rhs.device().location(),
418 op: "conv-transpose1d",
419 }
420 .bt()),
421 }
422 }
423
424 pub(crate) fn conv2d(
425 &self,
426 l: &Layout,
427 kernel: &Self,
428 kernel_l: &Layout,
429 params: &crate::conv::ParamsConv2D,
430 ) -> Result<Self> {
431 self.same_device(kernel, "conv2d")?;
432 self.same_dtype(kernel, "conv2d")?;
433 match (self, &kernel) {
434 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
435 let s = inp.conv2d(l, kernel, kernel_l, params)?;
436 Ok(Self::Cpu(s))
437 }
438 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
439 let s = inp.conv2d(l, kernel, kernel_l, params)?;
440 Ok(Self::Cuda(s))
441 }
442 (Storage::Metal(inp), Storage::Metal(kernel)) => {
443 let s = inp.conv2d(l, kernel, kernel_l, params)?;
444 Ok(Self::Metal(s))
445 }
446 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
447 lhs: lhs.device().location(),
448 rhs: rhs.device().location(),
449 op: "conv2d",
450 }
451 .bt()),
452 }
453 }
454
455 pub(crate) fn conv_transpose2d(
456 &self,
457 l: &Layout,
458 kernel: &Self,
459 kernel_l: &Layout,
460 params: &crate::conv::ParamsConvTranspose2D,
461 ) -> Result<Self> {
462 self.same_device(kernel, "conv_transpose2d")?;
463 self.same_dtype(kernel, "conv_transpose2d")?;
464 match (self, &kernel) {
465 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
466 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
467 Ok(Self::Cpu(s))
468 }
469 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
470 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
471 Ok(Self::Cuda(s))
472 }
473 (Storage::Metal(inp), Storage::Metal(kernel)) => {
474 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
475 Ok(Self::Metal(s))
476 }
477 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
478 lhs: lhs.device().location(),
479 rhs: rhs.device().location(),
480 op: "conv_transpose2d",
481 }
482 .bt()),
483 }
484 }
485
486 pub(crate) fn avg_pool2d(
487 &self,
488 layout: &Layout,
489 kernel_size: (usize, usize),
490 stride: (usize, usize),
491 ) -> Result<Self> {
492 match self {
493 Storage::Cpu(storage) => {
494 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
495 Ok(Self::Cpu(storage))
496 }
497 Self::Cuda(storage) => {
498 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
499 Ok(Self::Cuda(storage))
500 }
501 Self::Metal(storage) => {
502 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
503 Ok(Self::Metal(storage))
504 }
505 }
506 }
507
508 pub(crate) fn max_pool2d(
509 &self,
510 layout: &Layout,
511 kernel_size: (usize, usize),
512 stride: (usize, usize),
513 ) -> Result<Self> {
514 match self {
515 Storage::Cpu(storage) => {
516 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
517 Ok(Self::Cpu(storage))
518 }
519 Self::Cuda(storage) => {
520 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
521 Ok(Self::Cuda(storage))
522 }
523 Self::Metal(storage) => {
524 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
525 Ok(Self::Metal(storage))
526 }
527 }
528 }
529
530 pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
531 match self {
532 Storage::Cpu(storage) => {
533 let storage = storage.upsample_nearest1d(layout, sz)?;
534 Ok(Self::Cpu(storage))
535 }
536 Self::Cuda(storage) => {
537 let storage = storage.upsample_nearest1d(layout, sz)?;
538 Ok(Self::Cuda(storage))
539 }
540 Self::Metal(storage) => {
541 let storage = storage.upsample_nearest1d(layout, sz)?;
542 Ok(Self::Metal(storage))
543 }
544 }
545 }
546
547 pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
548 match self {
549 Storage::Cpu(storage) => {
550 let storage = storage.upsample_nearest2d(layout, h, w)?;
551 Ok(Self::Cpu(storage))
552 }
553 Self::Cuda(storage) => {
554 let storage = storage.upsample_nearest2d(layout, h, w)?;
555 Ok(Self::Cuda(storage))
556 }
557 Self::Metal(storage) => {
558 let storage = storage.upsample_nearest2d(layout, h, w)?;
559 Ok(Self::Metal(storage))
560 }
561 }
562 }
563
564 pub(crate) fn where_cond(
565 &self,
566 layout: &Layout,
567 t: &Self,
568 layout_t: &Layout,
569 f: &Self,
570 layout_f: &Layout,
571 ) -> Result<Self> {
572 self.same_device(t, "where")?;
573 self.same_device(f, "where")?;
574 t.same_dtype(f, "where")?;
575 match (self, t, f) {
576 (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
577 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
578 Ok(Self::Cpu(storage))
579 }
580 (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
581 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
582 Ok(Self::Cuda(storage))
583 }
584 (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
585 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
586 Ok(Self::Metal(storage))
587 }
588 (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
589 lhs: lhs.device().location(),
590 rhs: rhs.device().location(),
591 op: "where",
592 }
593 .bt()),
594 }
595 }
596
597 pub(crate) fn gather(
598 &self,
599 l: &Layout,
600 indexes: &Self,
601 indexes_l: &Layout,
602 d: usize,
603 ) -> Result<Self> {
604 self.same_device(indexes, "index-add")?;
605 match (self, indexes) {
606 (Self::Cpu(s), Self::Cpu(indexes)) => {
607 let storage = s.gather(l, indexes, indexes_l, d)?;
608 Ok(Self::Cpu(storage))
609 }
610 (Self::Cuda(s), Self::Cuda(indexes)) => {
611 let storage = s.gather(l, indexes, indexes_l, d)?;
612 Ok(Self::Cuda(storage))
613 }
614 (Self::Metal(s), Self::Metal(indexes)) => {
615 let storage = s.gather(l, indexes, indexes_l, d)?;
616 Ok(Self::Metal(storage))
617 }
618 _ => unreachable!(),
619 }
620 }
621
622 pub(crate) fn scatter_add(
623 &self,
624 l: &Layout,
625 indexes: &Self,
626 indexes_l: &Layout,
627 source: &Self,
628 source_l: &Layout,
629 d: usize,
630 ) -> Result<Self> {
631 self.same_device(indexes, "scatter-add")?;
632 self.same_device(source, "scatter-add")?;
633 match (self, indexes, source) {
634 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
635 let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
636 Ok(Self::Cpu(storage))
637 }
638 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
639 let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
640 Ok(Self::Cuda(storage))
641 }
642 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
643 let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
644 Ok(Self::Metal(storage))
645 }
646 _ => unreachable!(),
647 }
648 }
649
650 pub(crate) fn index_add(
651 &self,
652 l: &Layout,
653 indexes: &Self,
654 indexes_l: &Layout,
655 source: &Self,
656 source_l: &Layout,
657 d: usize,
658 ) -> Result<Self> {
659 self.same_device(indexes, "index-add")?;
660 self.same_device(source, "index-add")?;
661 match (self, indexes, source) {
662 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
663 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
664 Ok(Self::Cpu(storage))
665 }
666 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
667 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
668 Ok(Self::Cuda(storage))
669 }
670 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
671 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
672 Ok(Self::Metal(storage))
673 }
674 _ => unreachable!(),
675 }
676 }
677
678 pub(crate) fn index_select(
679 &self,
680 rhs: &Self,
681 lhs_l: &Layout,
682 rhs_l: &Layout,
683 d: usize,
684 ) -> Result<Self> {
685 self.same_device(rhs, "index-select")?;
686 match (self, rhs) {
687 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
688 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
689 Ok(Self::Cpu(storage))
690 }
691 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
692 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
693 Ok(Self::Cuda(storage))
694 }
695 (Self::Metal(lhs), Self::Metal(rhs)) => {
696 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
697 Ok(Self::Metal(storage))
698 }
699 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
700 lhs: lhs.device().location(),
701 rhs: rhs.device().location(),
702 op: "index-select",
703 }
704 .bt()),
705 }
706 }
707
708 pub(crate) fn matmul(
709 &self,
710 rhs: &Self,
711 bmnk: (usize, usize, usize, usize),
712 lhs_layout: &Layout,
713 rhs_layout: &Layout,
714 ) -> Result<Self> {
715 self.same_device(rhs, "matmul")?;
716 self.same_dtype(rhs, "matmul")?;
717 match (self, rhs) {
718 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
719 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
720 Ok(Self::Cpu(storage))
721 }
722 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
723 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
724 Ok(Self::Cuda(storage))
725 }
726 (Self::Metal(lhs), Self::Metal(rhs)) => {
727 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
728 Ok(Self::Metal(storage))
729 }
730 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
731 lhs: lhs.device().location(),
732 rhs: rhs.device().location(),
733 op: "matmul",
734 }
735 .bt()),
736 }
737 }
738
739 pub(crate) fn copy_strided_src(
741 &self,
742 dst: &mut Self,
743 dst_offset: usize,
744 src_l: &Layout,
745 ) -> Result<()> {
746 match (self, dst) {
747 (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
748 (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
749 (Self::Metal(src), Self::Metal(dst)) => {
750 Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
751 }
752 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
753 lhs: lhs.device().location(),
754 rhs: rhs.device().location(),
755 op: "copy",
756 }
757 .bt()),
758 }
759 }
760
761 #[allow(clippy::too_many_arguments)]
762 pub(crate) fn copy2d(
763 &self,
764 dst: &mut Self,
765 d1: usize,
766 d2: usize,
767 src_s: usize,
768 dst_s: usize,
769 src_o: usize,
770 dst_o: usize,
771 ) -> Result<()> {
772 match (self, dst) {
773 (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
774 (Self::Cuda(src), Self::Cuda(dst)) => {
775 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
776 }
777 (Self::Metal(src), Self::Metal(dst)) => {
778 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
779 }
780 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
781 lhs: lhs.device().location(),
782 rhs: rhs.device().location(),
783 op: "copy2d",
784 }
785 .bt()),
786 }
787 }
788}