1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Identifies a device by backend kind and, for GPU backends, by `gpu_id`.
///
/// The `Rocm` and `Vulkan` variants only exist when the cargo feature of the
/// same name is enabled.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DeviceLocation {
    /// The CPU.
    Cpu,
    /// A CUDA GPU identified by `gpu_id`.
    Cuda { gpu_id: usize },
    /// A Metal GPU identified by `gpu_id`.
    Metal { gpu_id: usize },
    /// A ROCm GPU identified by `gpu_id`.
    #[cfg(feature = "rocm")]
    Rocm { gpu_id: usize },
    /// A Vulkan GPU identified by `gpu_id`.
    #[cfg(feature = "vulkan")]
    Vulkan { gpu_id: usize },
}
25
26#[derive(Debug, Clone)]
28pub enum Device {
29 Cpu,
30 Cuda(crate::CudaDevice),
31 Metal(crate::MetalDevice),
32 #[cfg(feature = "rocm")]
33 Rocm(crate::RocmDevice),
34 #[cfg(feature = "vulkan")]
35 Vulkan(crate::VulkanDevice),
36}
37
38pub trait NdArray {
39 fn shape(&self) -> Result<Shape>;
40
41 fn to_cpu_storage(&self) -> CpuStorage;
42}
43
44impl<S: WithDType> NdArray for S {
45 fn shape(&self) -> Result<Shape> {
46 Ok(Shape::from(()))
47 }
48
49 fn to_cpu_storage(&self) -> CpuStorage {
50 S::to_cpu_storage(&[*self])
51 }
52}
53
54impl<S: WithDType, const N: usize> NdArray for &[S; N] {
55 fn shape(&self) -> Result<Shape> {
56 Ok(Shape::from(self.len()))
57 }
58
59 fn to_cpu_storage(&self) -> CpuStorage {
60 S::to_cpu_storage(self.as_slice())
61 }
62}
63
64impl<S: WithDType> NdArray for &[S] {
65 fn shape(&self) -> Result<Shape> {
66 Ok(Shape::from(self.len()))
67 }
68
69 fn to_cpu_storage(&self) -> CpuStorage {
70 S::to_cpu_storage(self)
71 }
72}
73
74impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
75 fn shape(&self) -> Result<Shape> {
76 Ok(Shape::from((M, N)))
77 }
78
79 fn to_cpu_storage(&self) -> CpuStorage {
80 S::to_cpu_storage_owned(self.concat())
81 }
82}
83
84impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
85 for &[[[S; N3]; N2]; N1]
86{
87 fn shape(&self) -> Result<Shape> {
88 Ok(Shape::from((N1, N2, N3)))
89 }
90
91 fn to_cpu_storage(&self) -> CpuStorage {
92 let mut vec = Vec::with_capacity(N1 * N2 * N3);
93 for i1 in 0..N1 {
94 for i2 in 0..N2 {
95 vec.extend(self[i1][i2])
96 }
97 }
98 S::to_cpu_storage_owned(vec)
99 }
100}
101
102impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
103 for &[[[[S; N4]; N3]; N2]; N1]
104{
105 fn shape(&self) -> Result<Shape> {
106 Ok(Shape::from((N1, N2, N3, N4)))
107 }
108
109 fn to_cpu_storage(&self) -> CpuStorage {
110 let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
111 for i1 in 0..N1 {
112 for i2 in 0..N2 {
113 for i3 in 0..N3 {
114 vec.extend(self[i1][i2][i3])
115 }
116 }
117 }
118 S::to_cpu_storage_owned(vec)
119 }
120}
121
122impl<S: WithDType> NdArray for Vec<S> {
123 fn shape(&self) -> Result<Shape> {
124 Ok(Shape::from(self.len()))
125 }
126
127 fn to_cpu_storage(&self) -> CpuStorage {
128 S::to_cpu_storage(self.as_slice())
129 }
130}
131
132impl<S: WithDType> NdArray for Vec<&[S]> {
133 fn shape(&self) -> Result<Shape> {
134 if self.is_empty() {
135 crate::bail!("empty array")
136 }
137 let n = self.len();
138 let m = self[0].len();
139 for v in self.iter() {
140 if v.len() != m {
141 crate::bail!("two elements have different len {m} {}", v.len())
142 }
143 }
144 Ok(Shape::from((n, m)))
145 }
146
147 fn to_cpu_storage(&self) -> CpuStorage {
148 let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
149 S::to_cpu_storage_owned(data)
150 }
151}
152
153impl<S: WithDType> NdArray for Vec<Vec<S>> {
154 fn shape(&self) -> Result<Shape> {
155 if self.is_empty() {
156 crate::bail!("empty array")
157 }
158 let n = self.len();
159 let m = self[0].len();
160 for v in self.iter() {
161 if v.len() != m {
162 crate::bail!("two elements have different len {m} {}", v.len())
163 }
164 }
165 Ok(Shape::from((n, m)))
166 }
167
168 fn to_cpu_storage(&self) -> CpuStorage {
169 let len: usize = self.iter().map(|v| v.len()).sum();
170 let mut dst = Vec::with_capacity(len);
171 for v in self.iter() {
172 dst.extend(v.iter().copied());
173 }
174 S::to_cpu_storage_owned(dst)
175 }
176}
177
178impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
179 fn shape(&self) -> Result<Shape> {
180 if self.is_empty() {
181 crate::bail!("empty array")
182 }
183 let shape0 = self[0].shape()?;
184 let n = self.len();
185 for v in self.iter() {
186 let shape = v.shape()?;
187 if shape != shape0 {
188 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
189 }
190 }
191 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
192 }
193
194 fn to_cpu_storage(&self) -> CpuStorage {
195 if self.is_empty() {
196 return S::to_cpu_storage_owned(vec![]);
197 }
198 let len: usize = self
199 .iter()
200 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
201 .sum();
202 let mut dst = Vec::with_capacity(len);
203 for v1 in self.iter() {
204 for v2 in v1.iter() {
205 dst.extend(v2.iter().copied());
206 }
207 }
208 S::to_cpu_storage_owned(dst)
209 }
210}
211
212impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
213 fn shape(&self) -> Result<Shape> {
214 if self.is_empty() {
215 crate::bail!("empty array")
216 }
217 let shape0 = self[0].shape()?;
218 let n = self.len();
219 for v in self.iter() {
220 let shape = v.shape()?;
221 if shape != shape0 {
222 crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
223 }
224 }
225 Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
226 }
227
228 fn to_cpu_storage(&self) -> CpuStorage {
229 let len: usize = self
230 .iter()
231 .map(|v| {
232 v.iter()
233 .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
234 .sum::<usize>()
235 })
236 .sum();
237 let mut dst = Vec::with_capacity(len);
238 for v1 in self.iter() {
239 for v2 in v1.iter() {
240 for v3 in v2.iter() {
241 dst.extend(v3.iter().copied());
242 }
243 }
244 }
245 S::to_cpu_storage_owned(dst)
246 }
247}
248
249impl Device {
250 pub fn new_cuda(ordinal: usize) -> Result<Self> {
251 Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
252 }
253
254 #[cfg(feature = "rocm")]
255 pub fn new_rocm(ordinal: usize) -> Result<Self> {
256 Ok(Self::Rocm(crate::RocmDevice::new(ordinal)?))
257 }
258 #[cfg(feature = "vulkan")]
259 pub fn new_vulkan(ordinal: usize) -> Result<Self> {
260 Ok(Self::Vulkan(crate::VulkanDevice::new(ordinal)?))
261 }
262
263 pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
264 match self {
265 Self::Cuda(d) => Ok(d),
266 Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
267 Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
268 #[cfg(feature = "rocm")]
269 Self::Rocm(_) => crate::bail!("expected a cuda device, got rocm"),
270 #[cfg(feature = "vulkan")]
271 Self::Vulkan(_) => crate::bail!("expected a cuda device, got vulkan"),
272 }
273 }
274
275 pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
276 match self {
277 Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
278 Self::Cpu => crate::bail!("expected a metal device, got cpu"),
279 Self::Metal(d) => Ok(d),
280 #[cfg(feature = "rocm")]
281 Self::Rocm(_) => crate::bail!("expected a metal device, got rocm"),
282 #[cfg(feature = "vulkan")]
283 Self::Vulkan(_) => crate::bail!("expected a metal device, got vulkan"),
284 }
285 }
286
287 #[cfg(feature = "rocm")]
288 pub fn as_rocm_device(&self) -> Result<&crate::RocmDevice> {
289 match self {
290 Self::Cuda(_) => crate::bail!("expected a rocm device, got cuda"),
291 Self::Cpu => crate::bail!("expected a rocm device, got cpu"),
292 Self::Metal(_) => crate::bail!("expected a rocm device, got Metal"),
293 Self::Rocm(d) => Ok(d),
294 }
295 }
296 #[cfg(feature = "vulkan")]
297 pub fn as_vulkan_device(&self) -> Result<&crate::VulkanDevice> {
298 match self {
299 Self::Cuda(_) => crate::bail!("expected a vulkan device, got cuda"),
300 Self::Cpu => crate::bail!("expected a vulkan device, got cpu"),
301 Self::Metal(_) => crate::bail!("expected a vulkan device, got Metal"),
302 Self::Vulkan(d) => Ok(d),
303 }
304 }
305
306 pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
307 Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
308 }
309
310 pub fn new_metal(ordinal: usize) -> Result<Self> {
311 Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
312 }
313
314 pub fn set_seed(&self, seed: u64) -> Result<()> {
315 match self {
316 Self::Cpu => CpuDevice.set_seed(seed),
317 Self::Cuda(c) => c.set_seed(seed),
318 Self::Metal(m) => m.set_seed(seed),
319 #[cfg(feature = "rocm")]
320 Self::Rocm(r) => r.set_seed(seed),
321 #[cfg(feature = "vulkan")]
322 Self::Vulkan(r) => r.set_seed(seed),
323 }
324 }
325
326 pub fn get_current_seed(&self) -> Result<u64> {
327 match self {
328 Self::Cpu => CpuDevice.get_current_seed(),
329 Self::Cuda(c) => c.get_current_seed(),
330 Self::Metal(m) => m.get_current_seed(),
331 #[cfg(feature = "rocm")]
332 Self::Rocm(r) => r.get_current_seed(),
333 #[cfg(feature = "vulkan")]
334 Self::Vulkan(r) => r.get_current_seed(),
335 }
336 }
337
338 pub fn same_device(&self, rhs: &Self) -> bool {
339 match (self, rhs) {
340 (Self::Cpu, Self::Cpu) => true,
341 (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
342 (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
343 #[cfg(feature = "rocm")]
344 (Self::Rocm(lhs), Self::Rocm(rhs)) => lhs.same_device(rhs),
345 #[cfg(feature = "vulkan")]
346 (Self::Vulkan(lhs), Self::Vulkan(rhs)) => lhs.same_device(rhs),
347 _ => false,
348 }
349 }
350
351 pub fn location(&self) -> DeviceLocation {
352 match self {
353 Self::Cpu => DeviceLocation::Cpu,
354 Self::Cuda(device) => device.location(),
355 Device::Metal(device) => device.location(),
356 #[cfg(feature = "rocm")]
357 Self::Rocm(device) => device.location(),
358 #[cfg(feature = "vulkan")]
359 Self::Vulkan(device) => device.location(),
360 }
361 }
362
363 pub fn is_cpu(&self) -> bool {
364 matches!(self, Self::Cpu)
365 }
366
367 pub fn is_cuda(&self) -> bool {
368 matches!(self, Self::Cuda(_))
369 }
370
371 pub fn is_metal(&self) -> bool {
372 matches!(self, Self::Metal(_))
373 }
374
375 pub fn is_rocm(&self) -> bool {
376 #[cfg(feature = "rocm")]
377 {
378 matches!(self, Self::Rocm(_))
379 }
380 #[cfg(not(feature = "rocm"))]
381 {
382 false
383 }
384 }
385
386 pub fn is_vulkan(&self) -> bool {
387 #[cfg(feature = "vulkan")]
388 {
389 matches!(self, Self::Vulkan(_))
390 }
391 #[cfg(not(feature = "vulkan"))]
392 {
393 false
394 }
395 }
396
397 pub fn supports_bf16(&self) -> bool {
398 match self {
399 Self::Cuda(_) | Self::Metal(_) => true,
400 Self::Cpu => false,
401 #[cfg(feature = "rocm")]
402 Self::Rocm(_) => true,
403 #[cfg(feature = "vulkan")]
406 Self::Vulkan(_) => false,
407 }
408 }
409
410 pub fn bf16_default_to_f32(&self) -> DType {
412 if self.supports_bf16() {
413 DType::BF16
414 } else {
415 DType::F32
416 }
417 }
418
419 pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
420 if crate::utils::cuda_is_available() {
421 Self::new_cuda(ordinal)
422 } else {
423 Ok(Self::Cpu)
424 }
425 }
426
427 pub fn metal_if_available(ordinal: usize) -> Result<Self> {
428 if crate::utils::metal_is_available() {
429 Self::new_metal(ordinal)
430 } else {
431 Ok(Self::Cpu)
432 }
433 }
434
435 pub(crate) fn rand_uniform_f64(
436 &self,
437 lo: f64,
438 up: f64,
439 shape: &Shape,
440 dtype: DType,
441 ) -> Result<Storage> {
442 match self {
443 Device::Cpu => {
444 let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
445 Ok(Storage::Cpu(storage))
446 }
447 Device::Cuda(device) => {
448 if dtype == DType::F16 || dtype == DType::BF16 {
450 let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
451 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
452 } else {
453 let storage = device.rand_uniform(shape, dtype, lo, up)?;
454 Ok(Storage::Cuda(storage))
455 }
456 }
457 Device::Metal(device) => {
458 let storage = device.rand_uniform(shape, dtype, lo, up)?;
459 Ok(Storage::Metal(storage))
460 }
461 #[cfg(feature = "rocm")]
462 Device::Rocm(device) => {
463 let storage = device.rand_uniform(shape, dtype, lo, up)?;
464 Ok(Storage::Rocm(storage))
465 }
466 #[cfg(feature = "vulkan")]
467 Device::Vulkan(device) => {
468 let storage = device.rand_uniform(shape, dtype, lo, up)?;
469 Ok(Storage::Vulkan(storage))
470 }
471 }
472 }
473
474 pub(crate) fn rand_uniform<T: crate::FloatDType>(
475 &self,
476 lo: T,
477 up: T,
478 shape: &Shape,
479 ) -> Result<Storage> {
480 self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
481 }
482
483 pub(crate) fn rand_normal_f64(
484 &self,
485 mean: f64,
486 std: f64,
487 shape: &Shape,
488 dtype: DType,
489 ) -> Result<Storage> {
490 match self {
491 Device::Cpu => {
492 let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
493 Ok(Storage::Cpu(storage))
494 }
495 Device::Cuda(device) => {
496 if dtype == DType::F16 || dtype == DType::BF16 {
498 let storage = device.rand_normal(shape, DType::F32, mean, std)?;
499 Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
500 } else {
501 let storage = device.rand_normal(shape, dtype, mean, std)?;
502 Ok(Storage::Cuda(storage))
503 }
504 }
505 Device::Metal(device) => {
506 let storage = device.rand_normal(shape, dtype, mean, std)?;
507 Ok(Storage::Metal(storage))
508 }
509 #[cfg(feature = "rocm")]
510 Device::Rocm(device) => {
511 let storage = device.rand_normal(shape, dtype, mean, std)?;
512 Ok(Storage::Rocm(storage))
513 }
514 #[cfg(feature = "vulkan")]
515 Device::Vulkan(device) => {
516 let storage = device.rand_normal(shape, dtype, mean, std)?;
517 Ok(Storage::Vulkan(storage))
518 }
519 }
520 }
521
522 pub(crate) fn rand_normal<T: crate::FloatDType>(
523 &self,
524 mean: T,
525 std: T,
526 shape: &Shape,
527 ) -> Result<Storage> {
528 self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
529 }
530
531 pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
532 match self {
533 Device::Cpu => {
534 let storage = CpuDevice.zeros_impl(shape, dtype)?;
535 Ok(Storage::Cpu(storage))
536 }
537 Device::Cuda(device) => {
538 let storage = device.zeros_impl(shape, dtype)?;
539 Ok(Storage::Cuda(storage))
540 }
541 Device::Metal(device) => {
542 let storage = device.zeros_impl(shape, dtype)?;
543 Ok(Storage::Metal(storage))
544 }
545 #[cfg(feature = "rocm")]
546 Device::Rocm(device) => {
547 let storage = device.zeros_impl(shape, dtype)?;
548 Ok(Storage::Rocm(storage))
549 }
550 #[cfg(feature = "vulkan")]
551 Device::Vulkan(device) => {
552 let storage = device.zeros_impl(shape, dtype)?;
553 Ok(Storage::Vulkan(storage))
554 }
555 }
556 }
557
558 pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
559 match self {
560 Device::Cpu => {
561 let storage = CpuDevice.alloc_uninit(shape, dtype)?;
562 Ok(Storage::Cpu(storage))
563 }
564 Device::Cuda(device) => {
565 let storage = device.alloc_uninit(shape, dtype)?;
566 Ok(Storage::Cuda(storage))
567 }
568 Device::Metal(device) => {
569 let storage = device.alloc_uninit(shape, dtype)?;
570 Ok(Storage::Metal(storage))
571 }
572 #[cfg(feature = "rocm")]
573 Device::Rocm(device) => {
574 let storage = device.alloc_uninit(shape, dtype)?;
575 Ok(Storage::Rocm(storage))
576 }
577 #[cfg(feature = "vulkan")]
578 Device::Vulkan(device) => {
579 let storage = device.alloc_uninit(shape, dtype)?;
580 Ok(Storage::Vulkan(storage))
581 }
582 }
583 }
584
585 pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
586 match self {
587 Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
588 Device::Cuda(device) => {
589 let storage = device.storage_from_slice(data)?;
590 Ok(Storage::Cuda(storage))
591 }
592 Device::Metal(device) => {
593 let storage = device.storage_from_slice(data)?;
594 Ok(Storage::Metal(storage))
595 }
596 #[cfg(feature = "rocm")]
597 Device::Rocm(device) => {
598 let storage = device.storage_from_slice(data)?;
599 Ok(Storage::Rocm(storage))
600 }
601 #[cfg(feature = "vulkan")]
602 Device::Vulkan(device) => {
603 let storage = device.storage_from_slice(data)?;
604 Ok(Storage::Vulkan(storage))
605 }
606 }
607 }
608
609 pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
610 match self {
611 Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
612 Device::Cuda(device) => {
613 let storage = array.to_cpu_storage();
614 let storage = device.storage_from_cpu_storage_owned(storage)?;
615 Ok(Storage::Cuda(storage))
616 }
617 Device::Metal(device) => {
618 let storage = array.to_cpu_storage();
619 let storage = device.storage_from_cpu_storage_owned(storage)?;
620 Ok(Storage::Metal(storage))
621 }
622 #[cfg(feature = "rocm")]
623 Device::Rocm(device) => {
624 let storage = array.to_cpu_storage();
625 let storage = device.storage_from_cpu_storage_owned(storage)?;
626 Ok(Storage::Rocm(storage))
627 }
628 #[cfg(feature = "vulkan")]
629 Device::Vulkan(device) => {
630 let storage = array.to_cpu_storage();
631 let storage = device.storage_from_cpu_storage_owned(storage)?;
632 Ok(Storage::Vulkan(storage))
633 }
634 }
635 }
636
637 pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
638 match self {
639 Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
640 Device::Cuda(device) => {
641 let storage = S::to_cpu_storage_owned(data);
642 let storage = device.storage_from_cpu_storage_owned(storage)?;
643 Ok(Storage::Cuda(storage))
644 }
645 Device::Metal(device) => {
646 let storage = S::to_cpu_storage_owned(data);
647 let storage = device.storage_from_cpu_storage_owned(storage)?;
648 Ok(Storage::Metal(storage))
649 }
650 #[cfg(feature = "rocm")]
651 Device::Rocm(device) => {
652 let storage = S::to_cpu_storage_owned(data);
653 let storage = device.storage_from_cpu_storage_owned(storage)?;
654 Ok(Storage::Rocm(storage))
655 }
656 #[cfg(feature = "vulkan")]
657 Device::Vulkan(device) => {
658 let storage = S::to_cpu_storage_owned(data);
659 let storage = device.storage_from_cpu_storage_owned(storage)?;
660 Ok(Storage::Vulkan(storage))
661 }
662 }
663 }
664
665 pub fn synchronize(&self) -> Result<()> {
666 match self {
667 Self::Cpu => Ok(()),
668 Self::Cuda(d) => d.synchronize(),
669 Self::Metal(d) => d.synchronize(),
670 #[cfg(feature = "rocm")]
671 Self::Rocm(d) => d.synchronize(),
672 #[cfg(feature = "vulkan")]
673 Self::Vulkan(d) => d.synchronize(),
674 }
675 }
676}