1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::scalar::Scalar;
4use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
5use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
6
/// Backend-tagged tensor storage.
///
/// Each variant wraps the storage type of exactly one backend; the methods on
/// `Storage` match on the variant and forward to that backend's implementation.
#[derive(Debug)]
pub enum Storage {
    /// Storage owned by the CPU backend.
    Cpu(CpuStorage),
    /// Storage owned by the CUDA backend.
    Cuda(CudaStorage),
    /// Storage owned by the Metal backend.
    Metal(MetalStorage),
}
15
16impl Storage {
17 pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
18 match self {
19 Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
20 Self::Cuda(storage) => {
21 let storage = storage.try_clone(layout)?;
22 Ok(Self::Cuda(storage))
23 }
24 Self::Metal(storage) => {
25 let storage = storage.try_clone(layout)?;
26 Ok(Self::Metal(storage))
27 }
28 }
29 }
30
31 pub fn device(&self) -> Device {
32 match self {
33 Self::Cpu(_) => Device::Cpu,
34 Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
35 Self::Metal(storage) => Device::Metal(storage.device().clone()),
36 }
37 }
38
39 pub fn dtype(&self) -> DType {
40 match self {
41 Self::Cpu(storage) => storage.dtype(),
42 Self::Cuda(storage) => storage.dtype(),
43 Self::Metal(storage) => storage.dtype(),
44 }
45 }
46
47 pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
48 let lhs_device = self.device();
49 let rhs_device = rhs.device();
50 let lhs = lhs_device.location();
51 let rhs = rhs_device.location();
52 let same_device = if self.device().is_metal() {
53 lhs_device.same_device(&rhs_device)
57 } else {
58 lhs == rhs
59 };
60 if !same_device {
61 Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
62 } else {
63 Ok(())
64 }
65 }
66
67 pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
68 let lhs = self.dtype();
69 let rhs = rhs.dtype();
70 if lhs != rhs {
71 Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
72 } else {
73 Ok(())
74 }
75 }
76
77 pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
78 match self {
79 Storage::Cpu(storage) => storage.const_set(v, l),
80 Storage::Cuda(storage) => storage.const_set(v, l),
81 Storage::Metal(storage) => storage.const_set(v, l),
82 }
83 }
84
85 pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
86 match self {
87 Storage::Cpu(storage) => {
88 let storage = storage.affine(layout, mul, add)?;
89 Ok(Self::Cpu(storage))
90 }
91 Self::Cuda(storage) => {
92 let storage = storage.affine(layout, mul, add)?;
93 Ok(Self::Cuda(storage))
94 }
95 Self::Metal(storage) => {
96 let storage = storage.affine(layout, mul, add)?;
97 Ok(Self::Metal(storage))
98 }
99 }
100 }
101
102 pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
103 match self {
104 Storage::Cpu(storage) => {
105 let storage = storage.powf(layout, alpha)?;
106 Ok(Self::Cpu(storage))
107 }
108 Self::Cuda(storage) => {
109 let storage = storage.powf(layout, alpha)?;
110 Ok(Self::Cuda(storage))
111 }
112 Self::Metal(storage) => {
113 let storage = storage.powf(layout, alpha)?;
114 Ok(Self::Metal(storage))
115 }
116 }
117 }
118
119 pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
120 match self {
121 Storage::Cpu(storage) => {
122 let storage = storage.elu(layout, alpha)?;
123 Ok(Self::Cpu(storage))
124 }
125 Self::Cuda(storage) => {
126 let storage = storage.elu(layout, alpha)?;
127 Ok(Self::Cuda(storage))
128 }
129 Self::Metal(storage) => {
130 let storage = storage.elu(layout, alpha)?;
131 Ok(Self::Metal(storage))
132 }
133 }
134 }
135
136 pub(crate) fn cmp(
137 &self,
138 op: CmpOp,
139 rhs: &Self,
140 lhs_layout: &Layout,
141 rhs_layout: &Layout,
142 ) -> Result<Self> {
143 self.same_device(rhs, "cmp")?;
144 self.same_dtype(rhs, "cmp")?;
145 match (self, rhs) {
146 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
147 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
148 Ok(Self::Cpu(storage))
149 }
150 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
151 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
152 Ok(Self::Cuda(storage))
153 }
154 (Self::Metal(lhs), Self::Metal(rhs)) => {
155 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
156 Ok(Self::Metal(storage))
157 }
158 (lhs, rhs) => {
159 Err(Error::DeviceMismatchBinaryOp {
162 lhs: lhs.device().location(),
163 rhs: rhs.device().location(),
164 op: "cmp",
165 }
166 .bt())
167 }
168 }
169 }
170
171 pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
172 match self {
173 Storage::Cpu(storage) => {
174 let storage = storage.reduce_op(op, layout, s)?;
175 Ok(Self::Cpu(storage))
176 }
177 Self::Cuda(storage) => {
178 let storage = storage.reduce_op(op, layout, s)?;
179 Ok(Self::Cuda(storage))
180 }
181 Self::Metal(storage) => {
182 let storage = storage.reduce_op(op, layout, s)?;
183 Ok(Self::Metal(storage))
184 }
185 }
186 }
187
188 pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
189 match self {
190 Storage::Cpu(storage) => {
191 let storage = storage.to_dtype(layout, dtype)?;
192 Ok(Self::Cpu(storage))
193 }
194 Self::Cuda(storage) => {
195 let storage = storage.to_dtype(layout, dtype)?;
196 Ok(Self::Cuda(storage))
197 }
198 Self::Metal(storage) => {
199 let storage = storage.to_dtype(layout, dtype)?;
200 Ok(Self::Metal(storage))
201 }
202 }
203 }
204
205 pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
206 match self {
207 Self::Cpu(storage) => {
208 let (storage, shape) = c.cpu_fwd(storage, l)?;
209 Ok((Self::Cpu(storage), shape))
210 }
211 Self::Cuda(storage) => {
212 let (storage, shape) = c.cuda_fwd(storage, l)?;
213 Ok((Self::Cuda(storage), shape))
214 }
215 Self::Metal(storage) => {
216 let (storage, shape) = c.metal_fwd(storage, l)?;
217 Ok((Self::Metal(storage), shape))
218 }
219 }
220 }
221
222 pub(crate) fn apply_op2(
223 &self,
224 l1: &Layout,
225 t2: &Self,
226 l2: &Layout,
227 c: &dyn CustomOp2,
228 ) -> Result<(Self, Shape)> {
229 self.same_device(t2, c.name())?;
230 match (self, t2) {
231 (Self::Cpu(s1), Self::Cpu(s2)) => {
232 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
233 Ok((Self::Cpu(s), shape))
234 }
235 (Self::Cuda(s1), Self::Cuda(s2)) => {
236 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
237 Ok((Self::Cuda(s), shape))
238 }
239 (Self::Metal(s1), Self::Metal(s2)) => {
240 let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
241 Ok((Self::Metal(s), shape))
242 }
243 _ => unreachable!(),
244 }
245 }
246
247 pub(crate) fn apply_op3(
248 &self,
249 l1: &Layout,
250 t2: &Self,
251 l2: &Layout,
252 t3: &Self,
253 l3: &Layout,
254 c: &dyn CustomOp3,
255 ) -> Result<(Self, Shape)> {
256 self.same_device(t2, c.name())?;
257 self.same_device(t3, c.name())?;
258 match (self, t2, t3) {
259 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
260 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
261 Ok((Self::Cpu(s), shape))
262 }
263 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
264 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
265 Ok((Self::Cuda(s), shape))
266 }
267 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
268 let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
269 Ok((Self::Metal(s), shape))
270 }
271 _ => unreachable!(),
272 }
273 }
274
275 pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
276 match self {
277 Self::Cpu(storage) => c.cpu_fwd(storage, l),
278 Self::Cuda(storage) => c.cuda_fwd(storage, l),
279 Self::Metal(storage) => c.metal_fwd(storage, l),
280 }
281 }
282
283 pub(crate) fn inplace_op2(
284 &mut self,
285 l1: &Layout,
286 t2: &Self,
287 l2: &Layout,
288 c: &dyn InplaceOp2,
289 ) -> Result<()> {
290 self.same_device(t2, c.name())?;
291 match (self, t2) {
292 (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
293 (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
294 (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
295 _ => unreachable!(),
296 }
297 }
298
299 pub(crate) fn inplace_op3(
300 &mut self,
301 l1: &Layout,
302 t2: &Self,
303 l2: &Layout,
304 t3: &Self,
305 l3: &Layout,
306 c: &dyn InplaceOp3,
307 ) -> Result<()> {
308 self.same_device(t2, c.name())?;
309 self.same_device(t3, c.name())?;
310 match (self, t2, t3) {
311 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
312 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
313 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
314 c.metal_fwd(s1, l1, s2, l2, s3, l3)
315 }
316 _ => unreachable!(),
317 }
318 }
319
320 pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
321 match self {
322 Storage::Cpu(storage) => {
323 let storage = storage.unary_impl::<B>(layout)?;
324 Ok(Self::Cpu(storage))
325 }
326 Self::Cuda(storage) => {
327 let storage = storage.unary_impl::<B>(layout)?;
328 Ok(Self::Cuda(storage))
329 }
330 Self::Metal(storage) => {
331 let storage = storage.unary_impl::<B>(layout)?;
332 Ok(Self::Metal(storage))
333 }
334 }
335 }
336
337 pub(crate) fn binary_impl<B: op::BinaryOpT>(
338 &self,
339 rhs: &Self,
340 lhs_layout: &Layout,
341 rhs_layout: &Layout,
342 ) -> Result<Self> {
343 self.same_device(rhs, B::NAME)?;
344 self.same_dtype(rhs, B::NAME)?;
345 match (self, rhs) {
346 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
347 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
348 Ok(Self::Cpu(storage))
349 }
350 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
351 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
352 Ok(Self::Cuda(storage))
353 }
354 (Self::Metal(lhs), Self::Metal(rhs)) => {
355 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
356 Ok(Self::Metal(storage))
357 }
358 (lhs, rhs) => {
359 Err(Error::DeviceMismatchBinaryOp {
362 lhs: lhs.device().location(),
363 rhs: rhs.device().location(),
364 op: B::NAME,
365 }
366 .bt())
367 }
368 }
369 }
370
371 pub(crate) fn conv1d(
372 &self,
373 l: &Layout,
374 kernel: &Self,
375 kernel_l: &Layout,
376 params: &crate::conv::ParamsConv1D,
377 ) -> Result<Self> {
378 self.same_device(kernel, "conv1d")?;
379 self.same_dtype(kernel, "conv1d")?;
380 match (self, &kernel) {
381 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
382 let s = inp.conv1d(l, kernel, kernel_l, params)?;
383 Ok(Self::Cpu(s))
384 }
385 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
386 let s = inp.conv1d(l, kernel, kernel_l, params)?;
387 Ok(Self::Cuda(s))
388 }
389 (Storage::Metal(inp), Storage::Metal(kernel)) => {
390 let s = inp.conv1d(l, kernel, kernel_l, params)?;
391 Ok(Self::Metal(s))
392 }
393 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
394 lhs: lhs.device().location(),
395 rhs: rhs.device().location(),
396 op: "conv1d",
397 }
398 .bt()),
399 }
400 }
401
402 pub(crate) fn conv_transpose1d(
403 &self,
404 l: &Layout,
405 kernel: &Self,
406 kernel_l: &Layout,
407 params: &crate::conv::ParamsConvTranspose1D,
408 ) -> Result<Self> {
409 self.same_device(kernel, "conv-transpose1d")?;
410 self.same_dtype(kernel, "conv-transpose1d")?;
411 match (self, &kernel) {
412 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
413 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
414 Ok(Self::Cpu(s))
415 }
416 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
417 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
418 Ok(Self::Cuda(s))
419 }
420 (Storage::Metal(inp), Storage::Metal(kernel)) => {
421 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
422 Ok(Self::Metal(s))
423 }
424 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
425 lhs: lhs.device().location(),
426 rhs: rhs.device().location(),
427 op: "conv-transpose1d",
428 }
429 .bt()),
430 }
431 }
432
433 pub(crate) fn conv2d(
434 &self,
435 l: &Layout,
436 kernel: &Self,
437 kernel_l: &Layout,
438 params: &crate::conv::ParamsConv2D,
439 ) -> Result<Self> {
440 self.same_device(kernel, "conv2d")?;
441 self.same_dtype(kernel, "conv2d")?;
442 match (self, &kernel) {
443 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
444 let s = inp.conv2d(l, kernel, kernel_l, params)?;
445 Ok(Self::Cpu(s))
446 }
447 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
448 let s = inp.conv2d(l, kernel, kernel_l, params)?;
449 Ok(Self::Cuda(s))
450 }
451 (Storage::Metal(inp), Storage::Metal(kernel)) => {
452 let s = inp.conv2d(l, kernel, kernel_l, params)?;
453 Ok(Self::Metal(s))
454 }
455 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
456 lhs: lhs.device().location(),
457 rhs: rhs.device().location(),
458 op: "conv2d",
459 }
460 .bt()),
461 }
462 }
463
464 pub(crate) fn conv_transpose2d(
465 &self,
466 l: &Layout,
467 kernel: &Self,
468 kernel_l: &Layout,
469 params: &crate::conv::ParamsConvTranspose2D,
470 ) -> Result<Self> {
471 self.same_device(kernel, "conv_transpose2d")?;
472 self.same_dtype(kernel, "conv_transpose2d")?;
473 match (self, &kernel) {
474 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
475 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
476 Ok(Self::Cpu(s))
477 }
478 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
479 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
480 Ok(Self::Cuda(s))
481 }
482 (Storage::Metal(inp), Storage::Metal(kernel)) => {
483 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
484 Ok(Self::Metal(s))
485 }
486 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
487 lhs: lhs.device().location(),
488 rhs: rhs.device().location(),
489 op: "conv_transpose2d",
490 }
491 .bt()),
492 }
493 }
494
495 pub(crate) fn avg_pool2d(
496 &self,
497 layout: &Layout,
498 kernel_size: (usize, usize),
499 stride: (usize, usize),
500 ) -> Result<Self> {
501 match self {
502 Storage::Cpu(storage) => {
503 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
504 Ok(Self::Cpu(storage))
505 }
506 Self::Cuda(storage) => {
507 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
508 Ok(Self::Cuda(storage))
509 }
510 Self::Metal(storage) => {
511 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
512 Ok(Self::Metal(storage))
513 }
514 }
515 }
516
517 pub(crate) fn max_pool2d(
518 &self,
519 layout: &Layout,
520 kernel_size: (usize, usize),
521 stride: (usize, usize),
522 ) -> Result<Self> {
523 match self {
524 Storage::Cpu(storage) => {
525 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
526 Ok(Self::Cpu(storage))
527 }
528 Self::Cuda(storage) => {
529 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
530 Ok(Self::Cuda(storage))
531 }
532 Self::Metal(storage) => {
533 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
534 Ok(Self::Metal(storage))
535 }
536 }
537 }
538
539 pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
540 match self {
541 Storage::Cpu(storage) => {
542 let storage = storage.upsample_nearest1d(layout, sz)?;
543 Ok(Self::Cpu(storage))
544 }
545 Self::Cuda(storage) => {
546 let storage = storage.upsample_nearest1d(layout, sz)?;
547 Ok(Self::Cuda(storage))
548 }
549 Self::Metal(storage) => {
550 let storage = storage.upsample_nearest1d(layout, sz)?;
551 Ok(Self::Metal(storage))
552 }
553 }
554 }
555
556 pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
557 match self {
558 Storage::Cpu(storage) => {
559 let storage = storage.upsample_nearest2d(layout, h, w)?;
560 Ok(Self::Cpu(storage))
561 }
562 Self::Cuda(storage) => {
563 let storage = storage.upsample_nearest2d(layout, h, w)?;
564 Ok(Self::Cuda(storage))
565 }
566 Self::Metal(storage) => {
567 let storage = storage.upsample_nearest2d(layout, h, w)?;
568 Ok(Self::Metal(storage))
569 }
570 }
571 }
572
573 pub(crate) fn where_cond(
574 &self,
575 layout: &Layout,
576 t: &Self,
577 layout_t: &Layout,
578 f: &Self,
579 layout_f: &Layout,
580 ) -> Result<Self> {
581 self.same_device(t, "where")?;
582 self.same_device(f, "where")?;
583 t.same_dtype(f, "where")?;
584 match (self, t, f) {
585 (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
586 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
587 Ok(Self::Cpu(storage))
588 }
589 (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
590 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
591 Ok(Self::Cuda(storage))
592 }
593 (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
594 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
595 Ok(Self::Metal(storage))
596 }
597 (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
598 lhs: lhs.device().location(),
599 rhs: rhs.device().location(),
600 op: "where",
601 }
602 .bt()),
603 }
604 }
605
606 pub(crate) fn gather(
607 &self,
608 l: &Layout,
609 indexes: &Self,
610 indexes_l: &Layout,
611 d: usize,
612 ) -> Result<Self> {
613 self.same_device(indexes, "index-add")?;
614 match (self, indexes) {
615 (Self::Cpu(s), Self::Cpu(indexes)) => {
616 let storage = s.gather(l, indexes, indexes_l, d)?;
617 Ok(Self::Cpu(storage))
618 }
619 (Self::Cuda(s), Self::Cuda(indexes)) => {
620 let storage = s.gather(l, indexes, indexes_l, d)?;
621 Ok(Self::Cuda(storage))
622 }
623 (Self::Metal(s), Self::Metal(indexes)) => {
624 let storage = s.gather(l, indexes, indexes_l, d)?;
625 Ok(Self::Metal(storage))
626 }
627 _ => unreachable!(),
628 }
629 }
630
631 pub(crate) fn scatter_set(
632 &mut self,
633 l: &Layout,
634 indexes: &Self,
635 indexes_l: &Layout,
636 source: &Self,
637 source_l: &Layout,
638 d: usize,
639 ) -> Result<()> {
640 self.same_device(indexes, "scatter-set")?;
641 self.same_device(source, "scatter-set")?;
642 match (self, indexes, source) {
643 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
644 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
645 }
646 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
647 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
648 }
649 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
650 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
651 }
652 _ => unreachable!(),
653 }
654 Ok(())
655 }
656
657 pub(crate) fn scatter_add(
658 &mut self,
659 l: &Layout,
660 indexes: &Self,
661 indexes_l: &Layout,
662 source: &Self,
663 source_l: &Layout,
664 d: usize,
665 ) -> Result<()> {
666 self.same_device(indexes, "scatter-add")?;
667 self.same_device(source, "scatter-add")?;
668 match (self, indexes, source) {
669 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
670 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
671 }
672 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
673 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
674 }
675 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
676 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
677 }
678 _ => unreachable!(),
679 }
680 Ok(())
681 }
682
683 pub(crate) fn index_add(
684 &self,
685 l: &Layout,
686 indexes: &Self,
687 indexes_l: &Layout,
688 source: &Self,
689 source_l: &Layout,
690 d: usize,
691 ) -> Result<Self> {
692 self.same_device(indexes, "index-add")?;
693 self.same_device(source, "index-add")?;
694 match (self, indexes, source) {
695 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
696 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
697 Ok(Self::Cpu(storage))
698 }
699 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
700 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
701 Ok(Self::Cuda(storage))
702 }
703 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
704 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
705 Ok(Self::Metal(storage))
706 }
707 _ => unreachable!(),
708 }
709 }
710
711 pub(crate) fn index_select(
712 &self,
713 rhs: &Self,
714 lhs_l: &Layout,
715 rhs_l: &Layout,
716 d: usize,
717 ) -> Result<Self> {
718 self.same_device(rhs, "index-select")?;
719 match (self, rhs) {
720 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
721 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
722 Ok(Self::Cpu(storage))
723 }
724 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
725 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
726 Ok(Self::Cuda(storage))
727 }
728 (Self::Metal(lhs), Self::Metal(rhs)) => {
729 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
730 Ok(Self::Metal(storage))
731 }
732 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
733 lhs: lhs.device().location(),
734 rhs: rhs.device().location(),
735 op: "index-select",
736 }
737 .bt()),
738 }
739 }
740
741 pub(crate) fn matmul(
742 &self,
743 rhs: &Self,
744 bmnk: (usize, usize, usize, usize),
745 lhs_layout: &Layout,
746 rhs_layout: &Layout,
747 ) -> Result<Self> {
748 self.same_device(rhs, "matmul")?;
749 self.same_dtype(rhs, "matmul")?;
750 match (self, rhs) {
751 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
752 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
753 Ok(Self::Cpu(storage))
754 }
755 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
756 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
757 Ok(Self::Cuda(storage))
758 }
759 (Self::Metal(lhs), Self::Metal(rhs)) => {
760 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
761 Ok(Self::Metal(storage))
762 }
763 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
764 lhs: lhs.device().location(),
765 rhs: rhs.device().location(),
766 op: "matmul",
767 }
768 .bt()),
769 }
770 }
771
772 pub(crate) fn copy_strided_src(
774 &self,
775 dst: &mut Self,
776 dst_offset: usize,
777 src_l: &Layout,
778 ) -> Result<()> {
779 match (self, dst) {
780 (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
781 (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
782 (Self::Metal(src), Self::Metal(dst)) => {
783 Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
784 }
785 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
786 lhs: lhs.device().location(),
787 rhs: rhs.device().location(),
788 op: "copy",
789 }
790 .bt()),
791 }
792 }
793
794 #[allow(clippy::too_many_arguments)]
795 pub(crate) fn copy2d(
796 &self,
797 dst: &mut Self,
798 d1: usize,
799 d2: usize,
800 src_s: usize,
801 dst_s: usize,
802 src_o: usize,
803 dst_o: usize,
804 ) -> Result<()> {
805 match (self, dst) {
806 (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
807 (Self::Cuda(src), Self::Cuda(dst)) => {
808 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
809 }
810 (Self::Metal(src), Self::Metal(dst)) => {
811 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
812 }
813 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
814 lhs: lhs.device().location(),
815 rhs: rhs.device().location(),
816 op: "copy2d",
817 }
818 .bt()),
819 }
820 }
821}