Skip to main content

hanzo_ml/
device.rs

1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Identifies the physical device a [`Device`] lives on.
///
/// A `DeviceLocation` represents a physical device, whereas multiple `Device`
/// values can live on the same location (typically for cuda devices).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DeviceLocation {
    /// The host CPU.
    Cpu,
    /// A CUDA GPU, identified by its index.
    Cuda { gpu_id: usize },
    /// A Metal GPU, identified by its index.
    Metal { gpu_id: usize },
    /// A ROCm GPU, identified by its index (only with the `rocm` feature).
    #[cfg(feature = "rocm")]
    Rocm { gpu_id: usize },
    /// A Vulkan GPU, identified by its index (only with the `vulkan` feature).
    #[cfg(feature = "vulkan")]
    Vulkan { gpu_id: usize },
}
25
26/// Cpu, Cuda, or Metal
27#[derive(Debug, Clone)]
28pub enum Device {
29    Cpu,
30    Cuda(crate::CudaDevice),
31    Metal(crate::MetalDevice),
32    #[cfg(feature = "rocm")]
33    Rocm(crate::RocmDevice),
34    #[cfg(feature = "vulkan")]
35    Vulkan(crate::VulkanDevice),
36}
37
38pub trait NdArray {
39    fn shape(&self) -> Result<Shape>;
40
41    fn to_cpu_storage(&self) -> CpuStorage;
42}
43
44impl<S: WithDType> NdArray for S {
45    fn shape(&self) -> Result<Shape> {
46        Ok(Shape::from(()))
47    }
48
49    fn to_cpu_storage(&self) -> CpuStorage {
50        S::to_cpu_storage(&[*self])
51    }
52}
53
54impl<S: WithDType, const N: usize> NdArray for &[S; N] {
55    fn shape(&self) -> Result<Shape> {
56        Ok(Shape::from(self.len()))
57    }
58
59    fn to_cpu_storage(&self) -> CpuStorage {
60        S::to_cpu_storage(self.as_slice())
61    }
62}
63
64impl<S: WithDType> NdArray for &[S] {
65    fn shape(&self) -> Result<Shape> {
66        Ok(Shape::from(self.len()))
67    }
68
69    fn to_cpu_storage(&self) -> CpuStorage {
70        S::to_cpu_storage(self)
71    }
72}
73
74impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
75    fn shape(&self) -> Result<Shape> {
76        Ok(Shape::from((M, N)))
77    }
78
79    fn to_cpu_storage(&self) -> CpuStorage {
80        S::to_cpu_storage_owned(self.concat())
81    }
82}
83
84impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
85    for &[[[S; N3]; N2]; N1]
86{
87    fn shape(&self) -> Result<Shape> {
88        Ok(Shape::from((N1, N2, N3)))
89    }
90
91    fn to_cpu_storage(&self) -> CpuStorage {
92        let mut vec = Vec::with_capacity(N1 * N2 * N3);
93        for i1 in 0..N1 {
94            for i2 in 0..N2 {
95                vec.extend(self[i1][i2])
96            }
97        }
98        S::to_cpu_storage_owned(vec)
99    }
100}
101
102impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
103    for &[[[[S; N4]; N3]; N2]; N1]
104{
105    fn shape(&self) -> Result<Shape> {
106        Ok(Shape::from((N1, N2, N3, N4)))
107    }
108
109    fn to_cpu_storage(&self) -> CpuStorage {
110        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
111        for i1 in 0..N1 {
112            for i2 in 0..N2 {
113                for i3 in 0..N3 {
114                    vec.extend(self[i1][i2][i3])
115                }
116            }
117        }
118        S::to_cpu_storage_owned(vec)
119    }
120}
121
122impl<S: WithDType> NdArray for Vec<S> {
123    fn shape(&self) -> Result<Shape> {
124        Ok(Shape::from(self.len()))
125    }
126
127    fn to_cpu_storage(&self) -> CpuStorage {
128        S::to_cpu_storage(self.as_slice())
129    }
130}
131
132impl<S: WithDType> NdArray for Vec<&[S]> {
133    fn shape(&self) -> Result<Shape> {
134        if self.is_empty() {
135            crate::bail!("empty array")
136        }
137        let n = self.len();
138        let m = self[0].len();
139        for v in self.iter() {
140            if v.len() != m {
141                crate::bail!("two elements have different len {m} {}", v.len())
142            }
143        }
144        Ok(Shape::from((n, m)))
145    }
146
147    fn to_cpu_storage(&self) -> CpuStorage {
148        let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
149        S::to_cpu_storage_owned(data)
150    }
151}
152
153impl<S: WithDType> NdArray for Vec<Vec<S>> {
154    fn shape(&self) -> Result<Shape> {
155        if self.is_empty() {
156            crate::bail!("empty array")
157        }
158        let n = self.len();
159        let m = self[0].len();
160        for v in self.iter() {
161            if v.len() != m {
162                crate::bail!("two elements have different len {m} {}", v.len())
163            }
164        }
165        Ok(Shape::from((n, m)))
166    }
167
168    fn to_cpu_storage(&self) -> CpuStorage {
169        let len: usize = self.iter().map(|v| v.len()).sum();
170        let mut dst = Vec::with_capacity(len);
171        for v in self.iter() {
172            dst.extend(v.iter().copied());
173        }
174        S::to_cpu_storage_owned(dst)
175    }
176}
177
178impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
179    fn shape(&self) -> Result<Shape> {
180        if self.is_empty() {
181            crate::bail!("empty array")
182        }
183        let shape0 = self[0].shape()?;
184        let n = self.len();
185        for v in self.iter() {
186            let shape = v.shape()?;
187            if shape != shape0 {
188                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
189            }
190        }
191        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
192    }
193
194    fn to_cpu_storage(&self) -> CpuStorage {
195        if self.is_empty() {
196            return S::to_cpu_storage_owned(vec![]);
197        }
198        let len: usize = self
199            .iter()
200            .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
201            .sum();
202        let mut dst = Vec::with_capacity(len);
203        for v1 in self.iter() {
204            for v2 in v1.iter() {
205                dst.extend(v2.iter().copied());
206            }
207        }
208        S::to_cpu_storage_owned(dst)
209    }
210}
211
212impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
213    fn shape(&self) -> Result<Shape> {
214        if self.is_empty() {
215            crate::bail!("empty array")
216        }
217        let shape0 = self[0].shape()?;
218        let n = self.len();
219        for v in self.iter() {
220            let shape = v.shape()?;
221            if shape != shape0 {
222                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
223            }
224        }
225        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
226    }
227
228    fn to_cpu_storage(&self) -> CpuStorage {
229        let len: usize = self
230            .iter()
231            .map(|v| {
232                v.iter()
233                    .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
234                    .sum::<usize>()
235            })
236            .sum();
237        let mut dst = Vec::with_capacity(len);
238        for v1 in self.iter() {
239            for v2 in v1.iter() {
240                for v3 in v2.iter() {
241                    dst.extend(v3.iter().copied());
242                }
243            }
244        }
245        S::to_cpu_storage_owned(dst)
246    }
247}
248
249impl Device {
250    pub fn new_cuda(ordinal: usize) -> Result<Self> {
251        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
252    }
253
254    #[cfg(feature = "rocm")]
255    pub fn new_rocm(ordinal: usize) -> Result<Self> {
256        Ok(Self::Rocm(crate::RocmDevice::new(ordinal)?))
257    }
258    #[cfg(feature = "vulkan")]
259    pub fn new_vulkan(ordinal: usize) -> Result<Self> {
260        Ok(Self::Vulkan(crate::VulkanDevice::new(ordinal)?))
261    }
262
263    pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
264        match self {
265            Self::Cuda(d) => Ok(d),
266            Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
267            Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
268            #[cfg(feature = "rocm")]
269            Self::Rocm(_) => crate::bail!("expected a cuda device, got rocm"),
270            #[cfg(feature = "vulkan")]
271            Self::Vulkan(_) => crate::bail!("expected a cuda device, got vulkan"),
272        }
273    }
274
275    pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
276        match self {
277            Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
278            Self::Cpu => crate::bail!("expected a metal device, got cpu"),
279            Self::Metal(d) => Ok(d),
280            #[cfg(feature = "rocm")]
281            Self::Rocm(_) => crate::bail!("expected a metal device, got rocm"),
282            #[cfg(feature = "vulkan")]
283            Self::Vulkan(_) => crate::bail!("expected a metal device, got vulkan"),
284        }
285    }
286
287    #[cfg(feature = "rocm")]
288    pub fn as_rocm_device(&self) -> Result<&crate::RocmDevice> {
289        match self {
290            Self::Cuda(_) => crate::bail!("expected a rocm device, got cuda"),
291            Self::Cpu => crate::bail!("expected a rocm device, got cpu"),
292            Self::Metal(_) => crate::bail!("expected a rocm device, got Metal"),
293            Self::Rocm(d) => Ok(d),
294        }
295    }
296    #[cfg(feature = "vulkan")]
297    pub fn as_vulkan_device(&self) -> Result<&crate::VulkanDevice> {
298        match self {
299            Self::Cuda(_) => crate::bail!("expected a vulkan device, got cuda"),
300            Self::Cpu => crate::bail!("expected a vulkan device, got cpu"),
301            Self::Metal(_) => crate::bail!("expected a vulkan device, got Metal"),
302            Self::Vulkan(d) => Ok(d),
303        }
304    }
305
306    pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
307        Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
308    }
309
310    pub fn new_metal(ordinal: usize) -> Result<Self> {
311        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
312    }
313
314    pub fn set_seed(&self, seed: u64) -> Result<()> {
315        match self {
316            Self::Cpu => CpuDevice.set_seed(seed),
317            Self::Cuda(c) => c.set_seed(seed),
318            Self::Metal(m) => m.set_seed(seed),
319            #[cfg(feature = "rocm")]
320            Self::Rocm(r) => r.set_seed(seed),
321            #[cfg(feature = "vulkan")]
322            Self::Vulkan(r) => r.set_seed(seed),
323        }
324    }
325
326    pub fn get_current_seed(&self) -> Result<u64> {
327        match self {
328            Self::Cpu => CpuDevice.get_current_seed(),
329            Self::Cuda(c) => c.get_current_seed(),
330            Self::Metal(m) => m.get_current_seed(),
331            #[cfg(feature = "rocm")]
332            Self::Rocm(r) => r.get_current_seed(),
333            #[cfg(feature = "vulkan")]
334            Self::Vulkan(r) => r.get_current_seed(),
335        }
336    }
337
338    pub fn same_device(&self, rhs: &Self) -> bool {
339        match (self, rhs) {
340            (Self::Cpu, Self::Cpu) => true,
341            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
342            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
343            #[cfg(feature = "rocm")]
344            (Self::Rocm(lhs), Self::Rocm(rhs)) => lhs.same_device(rhs),
345            #[cfg(feature = "vulkan")]
346            (Self::Vulkan(lhs), Self::Vulkan(rhs)) => lhs.same_device(rhs),
347            _ => false,
348        }
349    }
350
351    pub fn location(&self) -> DeviceLocation {
352        match self {
353            Self::Cpu => DeviceLocation::Cpu,
354            Self::Cuda(device) => device.location(),
355            Device::Metal(device) => device.location(),
356            #[cfg(feature = "rocm")]
357            Self::Rocm(device) => device.location(),
358            #[cfg(feature = "vulkan")]
359            Self::Vulkan(device) => device.location(),
360        }
361    }
362
363    pub fn is_cpu(&self) -> bool {
364        matches!(self, Self::Cpu)
365    }
366
367    pub fn is_cuda(&self) -> bool {
368        matches!(self, Self::Cuda(_))
369    }
370
371    pub fn is_metal(&self) -> bool {
372        matches!(self, Self::Metal(_))
373    }
374
375    pub fn is_rocm(&self) -> bool {
376        #[cfg(feature = "rocm")]
377        {
378            matches!(self, Self::Rocm(_))
379        }
380        #[cfg(not(feature = "rocm"))]
381        {
382            false
383        }
384    }
385
386    pub fn is_vulkan(&self) -> bool {
387        #[cfg(feature = "vulkan")]
388        {
389            matches!(self, Self::Vulkan(_))
390        }
391        #[cfg(not(feature = "vulkan"))]
392        {
393            false
394        }
395    }
396
397    pub fn supports_bf16(&self) -> bool {
398        match self {
399            Self::Cuda(_) | Self::Metal(_) => true,
400            Self::Cpu => false,
401            #[cfg(feature = "rocm")]
402            Self::Rocm(_) => true,
403            // Dozen/D3D12 Vulkan path on the 8060S has no native bf16; the
404            // backend is f32/u32-only, so default away from bf16.
405            #[cfg(feature = "vulkan")]
406            Self::Vulkan(_) => false,
407        }
408    }
409
410    /// Return `BF16` for devices that support it, otherwise default to `F32`.
411    pub fn bf16_default_to_f32(&self) -> DType {
412        if self.supports_bf16() {
413            DType::BF16
414        } else {
415            DType::F32
416        }
417    }
418
419    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
420        if crate::utils::cuda_is_available() {
421            Self::new_cuda(ordinal)
422        } else {
423            Ok(Self::Cpu)
424        }
425    }
426
427    pub fn metal_if_available(ordinal: usize) -> Result<Self> {
428        if crate::utils::metal_is_available() {
429            Self::new_metal(ordinal)
430        } else {
431            Ok(Self::Cpu)
432        }
433    }
434
435    pub(crate) fn rand_uniform_f64(
436        &self,
437        lo: f64,
438        up: f64,
439        shape: &Shape,
440        dtype: DType,
441    ) -> Result<Storage> {
442        match self {
443            Device::Cpu => {
444                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
445                Ok(Storage::Cpu(storage))
446            }
447            Device::Cuda(device) => {
448                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
449                if dtype == DType::F16 || dtype == DType::BF16 {
450                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
451                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
452                } else {
453                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
454                    Ok(Storage::Cuda(storage))
455                }
456            }
457            Device::Metal(device) => {
458                let storage = device.rand_uniform(shape, dtype, lo, up)?;
459                Ok(Storage::Metal(storage))
460            }
461            #[cfg(feature = "rocm")]
462            Device::Rocm(device) => {
463                let storage = device.rand_uniform(shape, dtype, lo, up)?;
464                Ok(Storage::Rocm(storage))
465            }
466            #[cfg(feature = "vulkan")]
467            Device::Vulkan(device) => {
468                let storage = device.rand_uniform(shape, dtype, lo, up)?;
469                Ok(Storage::Vulkan(storage))
470            }
471        }
472    }
473
474    pub(crate) fn rand_uniform<T: crate::FloatDType>(
475        &self,
476        lo: T,
477        up: T,
478        shape: &Shape,
479    ) -> Result<Storage> {
480        self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
481    }
482
483    pub(crate) fn rand_normal_f64(
484        &self,
485        mean: f64,
486        std: f64,
487        shape: &Shape,
488        dtype: DType,
489    ) -> Result<Storage> {
490        match self {
491            Device::Cpu => {
492                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
493                Ok(Storage::Cpu(storage))
494            }
495            Device::Cuda(device) => {
496                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
497                if dtype == DType::F16 || dtype == DType::BF16 {
498                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
499                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
500                } else {
501                    let storage = device.rand_normal(shape, dtype, mean, std)?;
502                    Ok(Storage::Cuda(storage))
503                }
504            }
505            Device::Metal(device) => {
506                let storage = device.rand_normal(shape, dtype, mean, std)?;
507                Ok(Storage::Metal(storage))
508            }
509            #[cfg(feature = "rocm")]
510            Device::Rocm(device) => {
511                let storage = device.rand_normal(shape, dtype, mean, std)?;
512                Ok(Storage::Rocm(storage))
513            }
514            #[cfg(feature = "vulkan")]
515            Device::Vulkan(device) => {
516                let storage = device.rand_normal(shape, dtype, mean, std)?;
517                Ok(Storage::Vulkan(storage))
518            }
519        }
520    }
521
522    pub(crate) fn rand_normal<T: crate::FloatDType>(
523        &self,
524        mean: T,
525        std: T,
526        shape: &Shape,
527    ) -> Result<Storage> {
528        self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
529    }
530
531    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
532        match self {
533            Device::Cpu => {
534                let storage = CpuDevice.zeros_impl(shape, dtype)?;
535                Ok(Storage::Cpu(storage))
536            }
537            Device::Cuda(device) => {
538                let storage = device.zeros_impl(shape, dtype)?;
539                Ok(Storage::Cuda(storage))
540            }
541            Device::Metal(device) => {
542                let storage = device.zeros_impl(shape, dtype)?;
543                Ok(Storage::Metal(storage))
544            }
545            #[cfg(feature = "rocm")]
546            Device::Rocm(device) => {
547                let storage = device.zeros_impl(shape, dtype)?;
548                Ok(Storage::Rocm(storage))
549            }
550            #[cfg(feature = "vulkan")]
551            Device::Vulkan(device) => {
552                let storage = device.zeros_impl(shape, dtype)?;
553                Ok(Storage::Vulkan(storage))
554            }
555        }
556    }
557
558    pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
559        match self {
560            Device::Cpu => {
561                let storage = CpuDevice.alloc_uninit(shape, dtype)?;
562                Ok(Storage::Cpu(storage))
563            }
564            Device::Cuda(device) => {
565                let storage = device.alloc_uninit(shape, dtype)?;
566                Ok(Storage::Cuda(storage))
567            }
568            Device::Metal(device) => {
569                let storage = device.alloc_uninit(shape, dtype)?;
570                Ok(Storage::Metal(storage))
571            }
572            #[cfg(feature = "rocm")]
573            Device::Rocm(device) => {
574                let storage = device.alloc_uninit(shape, dtype)?;
575                Ok(Storage::Rocm(storage))
576            }
577            #[cfg(feature = "vulkan")]
578            Device::Vulkan(device) => {
579                let storage = device.alloc_uninit(shape, dtype)?;
580                Ok(Storage::Vulkan(storage))
581            }
582        }
583    }
584
585    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
586        match self {
587            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
588            Device::Cuda(device) => {
589                let storage = device.storage_from_slice(data)?;
590                Ok(Storage::Cuda(storage))
591            }
592            Device::Metal(device) => {
593                let storage = device.storage_from_slice(data)?;
594                Ok(Storage::Metal(storage))
595            }
596            #[cfg(feature = "rocm")]
597            Device::Rocm(device) => {
598                let storage = device.storage_from_slice(data)?;
599                Ok(Storage::Rocm(storage))
600            }
601            #[cfg(feature = "vulkan")]
602            Device::Vulkan(device) => {
603                let storage = device.storage_from_slice(data)?;
604                Ok(Storage::Vulkan(storage))
605            }
606        }
607    }
608
609    pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
610        match self {
611            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
612            Device::Cuda(device) => {
613                let storage = array.to_cpu_storage();
614                let storage = device.storage_from_cpu_storage_owned(storage)?;
615                Ok(Storage::Cuda(storage))
616            }
617            Device::Metal(device) => {
618                let storage = array.to_cpu_storage();
619                let storage = device.storage_from_cpu_storage_owned(storage)?;
620                Ok(Storage::Metal(storage))
621            }
622            #[cfg(feature = "rocm")]
623            Device::Rocm(device) => {
624                let storage = array.to_cpu_storage();
625                let storage = device.storage_from_cpu_storage_owned(storage)?;
626                Ok(Storage::Rocm(storage))
627            }
628            #[cfg(feature = "vulkan")]
629            Device::Vulkan(device) => {
630                let storage = array.to_cpu_storage();
631                let storage = device.storage_from_cpu_storage_owned(storage)?;
632                Ok(Storage::Vulkan(storage))
633            }
634        }
635    }
636
637    pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
638        match self {
639            Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
640            Device::Cuda(device) => {
641                let storage = S::to_cpu_storage_owned(data);
642                let storage = device.storage_from_cpu_storage_owned(storage)?;
643                Ok(Storage::Cuda(storage))
644            }
645            Device::Metal(device) => {
646                let storage = S::to_cpu_storage_owned(data);
647                let storage = device.storage_from_cpu_storage_owned(storage)?;
648                Ok(Storage::Metal(storage))
649            }
650            #[cfg(feature = "rocm")]
651            Device::Rocm(device) => {
652                let storage = S::to_cpu_storage_owned(data);
653                let storage = device.storage_from_cpu_storage_owned(storage)?;
654                Ok(Storage::Rocm(storage))
655            }
656            #[cfg(feature = "vulkan")]
657            Device::Vulkan(device) => {
658                let storage = S::to_cpu_storage_owned(data);
659                let storage = device.storage_from_cpu_storage_owned(storage)?;
660                Ok(Storage::Vulkan(storage))
661            }
662        }
663    }
664
665    pub fn synchronize(&self) -> Result<()> {
666        match self {
667            Self::Cpu => Ok(()),
668            Self::Cuda(d) => d.synchronize(),
669            Self::Metal(d) => d.synchronize(),
670            #[cfg(feature = "rocm")]
671            Self::Rocm(d) => d.synchronize(),
672            #[cfg(feature = "vulkan")]
673            Self::Vulkan(d) => d.synchronize(),
674        }
675    }
676}