1use crate::backend::BackendStorage;
2use crate::op::{self, CmpOp, ReduceOp};
3use crate::scalar::Scalar;
4use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
5use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
6
7#[derive(Debug)]
10pub enum Storage {
11 Cpu(CpuStorage),
12 Cuda(CudaStorage),
13 Metal(MetalStorage),
14}
15
16impl Storage {
17 pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
18 match self {
19 Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
20 Self::Cuda(storage) => {
21 let storage = storage.try_clone(layout)?;
22 Ok(Self::Cuda(storage))
23 }
24 Self::Metal(storage) => {
25 let storage = storage.try_clone(layout)?;
26 Ok(Self::Metal(storage))
27 }
28 }
29 }
30
31 pub fn device(&self) -> Device {
32 match self {
33 Self::Cpu(_) => Device::Cpu,
34 Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
35 Self::Metal(storage) => Device::Metal(storage.device().clone()),
36 }
37 }
38
39 pub fn dtype(&self) -> DType {
40 match self {
41 Self::Cpu(storage) => storage.dtype(),
42 Self::Cuda(storage) => storage.dtype(),
43 Self::Metal(storage) => storage.dtype(),
44 }
45 }
46
47 pub(crate) fn same_device(&self, rhs: &Self, op: &'static str) -> Result<()> {
48 let lhs_device = self.device();
49 let rhs_device = rhs.device();
50 let lhs = lhs_device.location();
51 let rhs = rhs_device.location();
52 let same_device = if self.device().is_metal() {
53 lhs_device.same_device(&rhs_device)
57 } else {
58 lhs == rhs
59 };
60 if !same_device {
61 Err(Error::DeviceMismatchBinaryOp { lhs, rhs, op }.bt())
62 } else {
63 Ok(())
64 }
65 }
66
67 pub(crate) fn same_dtype(&self, rhs: &Self, op: &'static str) -> Result<()> {
68 let lhs = self.dtype();
69 let rhs = rhs.dtype();
70 if lhs != rhs {
71 Err(Error::DTypeMismatchBinaryOp { lhs, rhs, op }.bt())
72 } else {
73 Ok(())
74 }
75 }
76
77 pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
78 match self {
79 Storage::Cpu(storage) => storage.const_set(v, l),
80 Storage::Cuda(storage) => storage.const_set(v, l),
81 Storage::Metal(storage) => storage.const_set(v, l),
82 }
83 }
84
85 pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
86 match self {
87 Storage::Cpu(storage) => {
88 let storage = storage.affine(layout, mul, add)?;
89 Ok(Self::Cpu(storage))
90 }
91 Self::Cuda(storage) => {
92 let storage = storage.affine(layout, mul, add)?;
93 Ok(Self::Cuda(storage))
94 }
95 Self::Metal(storage) => {
96 let storage = storage.affine(layout, mul, add)?;
97 Ok(Self::Metal(storage))
98 }
99 }
100 }
101
102 pub(crate) fn powf(&self, layout: &Layout, alpha: f64) -> Result<Self> {
103 match self {
104 Storage::Cpu(storage) => {
105 let storage = storage.powf(layout, alpha)?;
106 Ok(Self::Cpu(storage))
107 }
108 Self::Cuda(storage) => {
109 let storage = storage.powf(layout, alpha)?;
110 Ok(Self::Cuda(storage))
111 }
112 Self::Metal(storage) => {
113 let storage = storage.powf(layout, alpha)?;
114 Ok(Self::Metal(storage))
115 }
116 }
117 }
118
119 pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
120 match self {
121 Storage::Cpu(storage) => {
122 let storage = storage.elu(layout, alpha)?;
123 Ok(Self::Cpu(storage))
124 }
125 Self::Cuda(storage) => {
126 let storage = storage.elu(layout, alpha)?;
127 Ok(Self::Cuda(storage))
128 }
129 Self::Metal(storage) => {
130 let storage = storage.elu(layout, alpha)?;
131 Ok(Self::Metal(storage))
132 }
133 }
134 }
135
136 pub(crate) fn cmp(
137 &self,
138 op: CmpOp,
139 rhs: &Self,
140 lhs_layout: &Layout,
141 rhs_layout: &Layout,
142 ) -> Result<Self> {
143 self.same_device(rhs, "cmp")?;
144 self.same_dtype(rhs, "cmp")?;
145 match (self, rhs) {
146 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
147 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
148 Ok(Self::Cpu(storage))
149 }
150 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
151 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
152 Ok(Self::Cuda(storage))
153 }
154 (Self::Metal(lhs), Self::Metal(rhs)) => {
155 let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
156 Ok(Self::Metal(storage))
157 }
158 (lhs, rhs) => {
159 Err(Error::DeviceMismatchBinaryOp {
162 lhs: lhs.device().location(),
163 rhs: rhs.device().location(),
164 op: "cmp",
165 }
166 .bt())
167 }
168 }
169 }
170
171 pub(crate) fn reduce_op(&self, op: ReduceOp, layout: &Layout, s: &[usize]) -> Result<Self> {
172 match self {
173 Storage::Cpu(storage) => {
174 let storage = storage.reduce_op(op, layout, s)?;
175 Ok(Self::Cpu(storage))
176 }
177 Self::Cuda(storage) => {
178 let storage = storage.reduce_op(op, layout, s)?;
179 Ok(Self::Cuda(storage))
180 }
181 Self::Metal(storage) => {
182 let storage = storage.reduce_op(op, layout, s)?;
183 Ok(Self::Metal(storage))
184 }
185 }
186 }
187
188 pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
189 match self {
190 Storage::Cpu(storage) => {
191 let storage = storage.to_dtype(layout, dtype)?;
192 Ok(Self::Cpu(storage))
193 }
194 Self::Cuda(storage) => {
195 let storage = storage.to_dtype(layout, dtype)?;
196 Ok(Self::Cuda(storage))
197 }
198 Self::Metal(storage) => {
199 let storage = storage.to_dtype(layout, dtype)?;
200 Ok(Self::Metal(storage))
201 }
202 }
203 }
204
205 pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
206 match self {
207 Self::Cpu(storage) => {
208 let (storage, shape) = c.cpu_fwd(storage, l)?;
209 Ok((Self::Cpu(storage), shape))
210 }
211 Self::Cuda(storage) => {
212 let (storage, shape) = c.cuda_fwd(storage, l)?;
213 Ok((Self::Cuda(storage), shape))
214 }
215 Self::Metal(storage) => {
216 let (storage, shape) = c.metal_fwd(storage, l)?;
217 Ok((Self::Metal(storage), shape))
218 }
219 }
220 }
221
222 pub(crate) fn apply_op2(
223 &self,
224 l1: &Layout,
225 t2: &Self,
226 l2: &Layout,
227 c: &dyn CustomOp2,
228 ) -> Result<(Self, Shape)> {
229 self.same_device(t2, c.name())?;
230 match (self, t2) {
231 (Self::Cpu(s1), Self::Cpu(s2)) => {
232 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
233 Ok((Self::Cpu(s), shape))
234 }
235 (Self::Cuda(s1), Self::Cuda(s2)) => {
236 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
237 Ok((Self::Cuda(s), shape))
238 }
239 (Self::Metal(s1), Self::Metal(s2)) => {
240 let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
241 Ok((Self::Metal(s), shape))
242 }
243 _ => unreachable!(),
244 }
245 }
246
247 pub(crate) fn apply_op3(
248 &self,
249 l1: &Layout,
250 t2: &Self,
251 l2: &Layout,
252 t3: &Self,
253 l3: &Layout,
254 c: &dyn CustomOp3,
255 ) -> Result<(Self, Shape)> {
256 self.same_device(t2, c.name())?;
257 self.same_device(t3, c.name())?;
258 match (self, t2, t3) {
259 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
260 let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
261 Ok((Self::Cpu(s), shape))
262 }
263 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
264 let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
265 Ok((Self::Cuda(s), shape))
266 }
267 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
268 let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
269 Ok((Self::Metal(s), shape))
270 }
271 _ => unreachable!(),
272 }
273 }
274
275 pub(crate) fn inplace_op1(&mut self, l: &Layout, c: &dyn InplaceOp1) -> Result<()> {
276 match self {
277 Self::Cpu(storage) => c.cpu_fwd(storage, l),
278 Self::Cuda(storage) => c.cuda_fwd(storage, l),
279 Self::Metal(storage) => c.metal_fwd(storage, l),
280 }
281 }
282
283 pub(crate) fn inplace_op2(
284 &mut self,
285 l1: &Layout,
286 t2: &Self,
287 l2: &Layout,
288 c: &dyn InplaceOp2,
289 ) -> Result<()> {
290 self.same_device(t2, c.name())?;
291 match (self, t2) {
292 (Self::Cpu(s1), Self::Cpu(s2)) => c.cpu_fwd(s1, l1, s2, l2),
293 (Self::Cuda(s1), Self::Cuda(s2)) => c.cuda_fwd(s1, l1, s2, l2),
294 (Self::Metal(s1), Self::Metal(s2)) => c.metal_fwd(s1, l1, s2, l2),
295 _ => unreachable!(),
296 }
297 }
298
299 pub(crate) fn inplace_op3(
300 &mut self,
301 l1: &Layout,
302 t2: &Self,
303 l2: &Layout,
304 t3: &Self,
305 l3: &Layout,
306 c: &dyn InplaceOp3,
307 ) -> Result<()> {
308 self.same_device(t2, c.name())?;
309 self.same_device(t3, c.name())?;
310 match (self, t2, t3) {
311 (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => c.cpu_fwd(s1, l1, s2, l2, s3, l3),
312 (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => c.cuda_fwd(s1, l1, s2, l2, s3, l3),
313 (Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
314 c.metal_fwd(s1, l1, s2, l2, s3, l3)
315 }
316 _ => unreachable!(),
317 }
318 }
319
320 pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
321 match self {
322 Storage::Cpu(storage) => {
323 let storage = storage.unary_impl::<B>(layout)?;
324 Ok(Self::Cpu(storage))
325 }
326 Self::Cuda(storage) => {
327 let storage = storage.unary_impl::<B>(layout)?;
328 Ok(Self::Cuda(storage))
329 }
330 Self::Metal(storage) => {
331 let storage = storage.unary_impl::<B>(layout)?;
332 Ok(Self::Metal(storage))
333 }
334 }
335 }
336
337 pub(crate) fn binary_impl<B: op::BinaryOpT>(
338 &self,
339 rhs: &Self,
340 lhs_layout: &Layout,
341 rhs_layout: &Layout,
342 ) -> Result<Self> {
343 self.same_device(rhs, B::NAME)?;
344 self.same_dtype(rhs, B::NAME)?;
345 match (self, rhs) {
346 (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
347 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
348 Ok(Self::Cpu(storage))
349 }
350 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
351 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
352 Ok(Self::Cuda(storage))
353 }
354 (Self::Metal(lhs), Self::Metal(rhs)) => {
355 let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
356 Ok(Self::Metal(storage))
357 }
358 (lhs, rhs) => {
359 Err(Error::DeviceMismatchBinaryOp {
362 lhs: lhs.device().location(),
363 rhs: rhs.device().location(),
364 op: B::NAME,
365 }
366 .bt())
367 }
368 }
369 }
370
371 pub(crate) fn conv1d(
372 &self,
373 l: &Layout,
374 kernel: &Self,
375 kernel_l: &Layout,
376 params: &crate::conv::ParamsConv1D,
377 ) -> Result<Self> {
378 self.same_device(kernel, "conv1d")?;
379 self.same_dtype(kernel, "conv1d")?;
380 match (self, &kernel) {
381 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
382 let s = inp.conv1d(l, kernel, kernel_l, params)?;
383 Ok(Self::Cpu(s))
384 }
385 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
386 let s = inp.conv1d(l, kernel, kernel_l, params)?;
387 Ok(Self::Cuda(s))
388 }
389 (Storage::Metal(inp), Storage::Metal(kernel)) => {
390 let s = inp.conv1d(l, kernel, kernel_l, params)?;
391 Ok(Self::Metal(s))
392 }
393 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
394 lhs: lhs.device().location(),
395 rhs: rhs.device().location(),
396 op: "conv1d",
397 }
398 .bt()),
399 }
400 }
401
402 pub(crate) fn conv_transpose1d(
403 &self,
404 l: &Layout,
405 kernel: &Self,
406 kernel_l: &Layout,
407 params: &crate::conv::ParamsConvTranspose1D,
408 ) -> Result<Self> {
409 self.same_device(kernel, "conv-transpose1d")?;
410 self.same_dtype(kernel, "conv-transpose1d")?;
411 match (self, &kernel) {
412 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
413 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
414 Ok(Self::Cpu(s))
415 }
416 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
417 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
418 Ok(Self::Cuda(s))
419 }
420 (Storage::Metal(inp), Storage::Metal(kernel)) => {
421 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
422 Ok(Self::Metal(s))
423 }
424 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
425 lhs: lhs.device().location(),
426 rhs: rhs.device().location(),
427 op: "conv-transpose1d",
428 }
429 .bt()),
430 }
431 }
432
433 pub(crate) fn conv2d(
434 &self,
435 l: &Layout,
436 kernel: &Self,
437 kernel_l: &Layout,
438 params: &crate::conv::ParamsConv2D,
439 ) -> Result<Self> {
440 self.same_device(kernel, "conv2d")?;
441 self.same_dtype(kernel, "conv2d")?;
442 match (self, &kernel) {
443 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
444 let s = inp.conv2d(l, kernel, kernel_l, params)?;
445 Ok(Self::Cpu(s))
446 }
447 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
448 let s = inp.conv2d(l, kernel, kernel_l, params)?;
449 Ok(Self::Cuda(s))
450 }
451 (Storage::Metal(inp), Storage::Metal(kernel)) => {
452 let s = inp.conv2d(l, kernel, kernel_l, params)?;
453 Ok(Self::Metal(s))
454 }
455 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
456 lhs: lhs.device().location(),
457 rhs: rhs.device().location(),
458 op: "conv2d",
459 }
460 .bt()),
461 }
462 }
463
464 pub(crate) fn conv_transpose2d(
465 &self,
466 l: &Layout,
467 kernel: &Self,
468 kernel_l: &Layout,
469 params: &crate::conv::ParamsConvTranspose2D,
470 ) -> Result<Self> {
471 self.same_device(kernel, "conv_transpose2d")?;
472 self.same_dtype(kernel, "conv_transpose2d")?;
473 match (self, &kernel) {
474 (Storage::Cpu(inp), Storage::Cpu(kernel)) => {
475 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
476 Ok(Self::Cpu(s))
477 }
478 (Storage::Cuda(inp), Storage::Cuda(kernel)) => {
479 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
480 Ok(Self::Cuda(s))
481 }
482 (Storage::Metal(inp), Storage::Metal(kernel)) => {
483 let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
484 Ok(Self::Metal(s))
485 }
486 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
487 lhs: lhs.device().location(),
488 rhs: rhs.device().location(),
489 op: "conv_transpose2d",
490 }
491 .bt()),
492 }
493 }
494
495 pub(crate) fn avg_pool2d(
496 &self,
497 layout: &Layout,
498 kernel_size: (usize, usize),
499 stride: (usize, usize),
500 ) -> Result<Self> {
501 match self {
502 Storage::Cpu(storage) => {
503 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
504 Ok(Self::Cpu(storage))
505 }
506 Self::Cuda(storage) => {
507 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
508 Ok(Self::Cuda(storage))
509 }
510 Self::Metal(storage) => {
511 let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
512 Ok(Self::Metal(storage))
513 }
514 }
515 }
516
517 pub(crate) fn max_pool2d(
518 &self,
519 layout: &Layout,
520 kernel_size: (usize, usize),
521 stride: (usize, usize),
522 ) -> Result<Self> {
523 match self {
524 Storage::Cpu(storage) => {
525 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
526 Ok(Self::Cpu(storage))
527 }
528 Self::Cuda(storage) => {
529 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
530 Ok(Self::Cuda(storage))
531 }
532 Self::Metal(storage) => {
533 let storage = storage.max_pool2d(layout, kernel_size, stride)?;
534 Ok(Self::Metal(storage))
535 }
536 }
537 }
538
539 pub(crate) fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
540 match self {
541 Storage::Cpu(storage) => {
542 let storage = storage.upsample_nearest1d(layout, sz)?;
543 Ok(Self::Cpu(storage))
544 }
545 Self::Cuda(storage) => {
546 let storage = storage.upsample_nearest1d(layout, sz)?;
547 Ok(Self::Cuda(storage))
548 }
549 Self::Metal(storage) => {
550 let storage = storage.upsample_nearest1d(layout, sz)?;
551 Ok(Self::Metal(storage))
552 }
553 }
554 }
555
556 pub(crate) fn upsample_nearest2d(&self, layout: &Layout, h: usize, w: usize) -> Result<Self> {
557 match self {
558 Storage::Cpu(storage) => {
559 let storage = storage.upsample_nearest2d(layout, h, w)?;
560 Ok(Self::Cpu(storage))
561 }
562 Self::Cuda(storage) => {
563 let storage = storage.upsample_nearest2d(layout, h, w)?;
564 Ok(Self::Cuda(storage))
565 }
566 Self::Metal(storage) => {
567 let storage = storage.upsample_nearest2d(layout, h, w)?;
568 Ok(Self::Metal(storage))
569 }
570 }
571 }
572
573 pub(crate) fn upsample_bilinear2d(
574 &self,
575 layout: &Layout,
576 h: usize,
577 w: usize,
578 align_corners: bool,
579 scale_h: Option<f64>,
580 scale_w: Option<f64>,
581 ) -> Result<Self> {
582 match self {
583 Storage::Cpu(storage) => {
584 let storage =
585 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
586 Ok(Self::Cpu(storage))
587 }
588 Self::Cuda(storage) => {
589 let storage =
590 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
591 Ok(Self::Cuda(storage))
592 }
593 Self::Metal(storage) => {
594 let storage =
595 storage.upsample_bilinear2d(layout, h, w, align_corners, scale_h, scale_w)?;
596 Ok(Self::Metal(storage))
597 }
598 }
599 }
600
601 pub(crate) fn where_cond(
602 &self,
603 layout: &Layout,
604 t: &Self,
605 layout_t: &Layout,
606 f: &Self,
607 layout_f: &Layout,
608 ) -> Result<Self> {
609 self.same_device(t, "where")?;
610 self.same_device(f, "where")?;
611 t.same_dtype(f, "where")?;
612 match (self, t, f) {
613 (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
614 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
615 Ok(Self::Cpu(storage))
616 }
617 (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
618 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
619 Ok(Self::Cuda(storage))
620 }
621 (Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
622 let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
623 Ok(Self::Metal(storage))
624 }
625 (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
626 lhs: lhs.device().location(),
627 rhs: rhs.device().location(),
628 op: "where",
629 }
630 .bt()),
631 }
632 }
633
634 pub(crate) fn gather(
635 &self,
636 l: &Layout,
637 indexes: &Self,
638 indexes_l: &Layout,
639 d: usize,
640 ) -> Result<Self> {
641 self.same_device(indexes, "index-add")?;
642 match (self, indexes) {
643 (Self::Cpu(s), Self::Cpu(indexes)) => {
644 let storage = s.gather(l, indexes, indexes_l, d)?;
645 Ok(Self::Cpu(storage))
646 }
647 (Self::Cuda(s), Self::Cuda(indexes)) => {
648 let storage = s.gather(l, indexes, indexes_l, d)?;
649 Ok(Self::Cuda(storage))
650 }
651 (Self::Metal(s), Self::Metal(indexes)) => {
652 let storage = s.gather(l, indexes, indexes_l, d)?;
653 Ok(Self::Metal(storage))
654 }
655 _ => unreachable!(),
656 }
657 }
658
659 pub(crate) fn scatter_set(
660 &mut self,
661 l: &Layout,
662 indexes: &Self,
663 indexes_l: &Layout,
664 source: &Self,
665 source_l: &Layout,
666 d: usize,
667 ) -> Result<()> {
668 self.same_device(indexes, "scatter-set")?;
669 self.same_device(source, "scatter-set")?;
670 match (self, indexes, source) {
671 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
672 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
673 }
674 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
675 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
676 }
677 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
678 s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
679 }
680 _ => unreachable!(),
681 }
682 Ok(())
683 }
684
685 pub(crate) fn scatter_add(
686 &mut self,
687 l: &Layout,
688 indexes: &Self,
689 indexes_l: &Layout,
690 source: &Self,
691 source_l: &Layout,
692 d: usize,
693 ) -> Result<()> {
694 self.same_device(indexes, "scatter-add")?;
695 self.same_device(source, "scatter-add")?;
696 match (self, indexes, source) {
697 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
698 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
699 }
700 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
701 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
702 }
703 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
704 s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
705 }
706 _ => unreachable!(),
707 }
708 Ok(())
709 }
710
711 pub(crate) fn index_add(
712 &self,
713 l: &Layout,
714 indexes: &Self,
715 indexes_l: &Layout,
716 source: &Self,
717 source_l: &Layout,
718 d: usize,
719 ) -> Result<Self> {
720 self.same_device(indexes, "index-add")?;
721 self.same_device(source, "index-add")?;
722 match (self, indexes, source) {
723 (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
724 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
725 Ok(Self::Cpu(storage))
726 }
727 (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
728 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
729 Ok(Self::Cuda(storage))
730 }
731 (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
732 let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
733 Ok(Self::Metal(storage))
734 }
735 _ => unreachable!(),
736 }
737 }
738
739 pub(crate) fn index_select(
740 &self,
741 rhs: &Self,
742 lhs_l: &Layout,
743 rhs_l: &Layout,
744 d: usize,
745 ) -> Result<Self> {
746 self.same_device(rhs, "index-select")?;
747 match (self, rhs) {
748 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
749 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
750 Ok(Self::Cpu(storage))
751 }
752 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
753 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
754 Ok(Self::Cuda(storage))
755 }
756 (Self::Metal(lhs), Self::Metal(rhs)) => {
757 let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
758 Ok(Self::Metal(storage))
759 }
760 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
761 lhs: lhs.device().location(),
762 rhs: rhs.device().location(),
763 op: "index-select",
764 }
765 .bt()),
766 }
767 }
768
769 pub(crate) fn matmul(
770 &self,
771 rhs: &Self,
772 bmnk: (usize, usize, usize, usize),
773 lhs_layout: &Layout,
774 rhs_layout: &Layout,
775 ) -> Result<Self> {
776 self.same_device(rhs, "matmul")?;
777 self.same_dtype(rhs, "matmul")?;
778 match (self, rhs) {
779 (Self::Cpu(lhs), Self::Cpu(rhs)) => {
780 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
781 Ok(Self::Cpu(storage))
782 }
783 (Self::Cuda(lhs), Self::Cuda(rhs)) => {
784 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
785 Ok(Self::Cuda(storage))
786 }
787 (Self::Metal(lhs), Self::Metal(rhs)) => {
788 let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
789 Ok(Self::Metal(storage))
790 }
791 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
792 lhs: lhs.device().location(),
793 rhs: rhs.device().location(),
794 op: "matmul",
795 }
796 .bt()),
797 }
798 }
799
800 pub(crate) fn copy_strided_src(
802 &self,
803 dst: &mut Self,
804 dst_offset: usize,
805 src_l: &Layout,
806 ) -> Result<()> {
807 match (self, dst) {
808 (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
809 (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
810 (Self::Metal(src), Self::Metal(dst)) => {
811 Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
812 }
813 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
814 lhs: lhs.device().location(),
815 rhs: rhs.device().location(),
816 op: "copy",
817 }
818 .bt()),
819 }
820 }
821
822 #[allow(clippy::too_many_arguments)]
823 pub(crate) fn copy2d(
824 &self,
825 dst: &mut Self,
826 d1: usize,
827 d2: usize,
828 src_s: usize,
829 dst_s: usize,
830 src_o: usize,
831 dst_o: usize,
832 ) -> Result<()> {
833 match (self, dst) {
834 (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
835 (Self::Cuda(src), Self::Cuda(dst)) => {
836 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
837 }
838 (Self::Metal(src), Self::Metal(dst)) => {
839 Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
840 }
841 (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
842 lhs: lhs.device().location(),
843 rhs: rhs.device().location(),
844 op: "copy2d",
845 }
846 .bt()),
847 }
848 }
849}