candle_core/
device.rs

1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Identifies the physical device on which a `Device` lives.
///
/// Several `Device` values can share a single location (this is typically
/// the case for cuda devices).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceLocation {
    /// The CPU.
    Cpu,
    /// A CUDA GPU, identified by `gpu_id`.
    Cuda { gpu_id: usize },
    /// A Metal GPU, identified by `gpu_id`.
    Metal { gpu_id: usize },
}
13
14/// Cpu, Cuda, or Metal
15#[derive(Debug, Clone)]
16pub enum Device {
17    Cpu,
18    Cuda(crate::CudaDevice),
19    Metal(crate::MetalDevice),
20}
21
22pub trait NdArray {
23    fn shape(&self) -> Result<Shape>;
24
25    fn to_cpu_storage(&self) -> CpuStorage;
26}
27
28impl<S: WithDType> NdArray for S {
29    fn shape(&self) -> Result<Shape> {
30        Ok(Shape::from(()))
31    }
32
33    fn to_cpu_storage(&self) -> CpuStorage {
34        S::to_cpu_storage(&[*self])
35    }
36}
37
38impl<S: WithDType, const N: usize> NdArray for &[S; N] {
39    fn shape(&self) -> Result<Shape> {
40        Ok(Shape::from(self.len()))
41    }
42
43    fn to_cpu_storage(&self) -> CpuStorage {
44        S::to_cpu_storage(self.as_slice())
45    }
46}
47
48impl<S: WithDType> NdArray for &[S] {
49    fn shape(&self) -> Result<Shape> {
50        Ok(Shape::from(self.len()))
51    }
52
53    fn to_cpu_storage(&self) -> CpuStorage {
54        S::to_cpu_storage(self)
55    }
56}
57
58impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
59    fn shape(&self) -> Result<Shape> {
60        Ok(Shape::from((M, N)))
61    }
62
63    fn to_cpu_storage(&self) -> CpuStorage {
64        S::to_cpu_storage_owned(self.concat())
65    }
66}
67
68impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
69    for &[[[S; N3]; N2]; N1]
70{
71    fn shape(&self) -> Result<Shape> {
72        Ok(Shape::from((N1, N2, N3)))
73    }
74
75    fn to_cpu_storage(&self) -> CpuStorage {
76        let mut vec = Vec::with_capacity(N1 * N2 * N3);
77        for i1 in 0..N1 {
78            for i2 in 0..N2 {
79                vec.extend(self[i1][i2])
80            }
81        }
82        S::to_cpu_storage_owned(vec)
83    }
84}
85
86impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
87    for &[[[[S; N4]; N3]; N2]; N1]
88{
89    fn shape(&self) -> Result<Shape> {
90        Ok(Shape::from((N1, N2, N3, N4)))
91    }
92
93    fn to_cpu_storage(&self) -> CpuStorage {
94        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
95        for i1 in 0..N1 {
96            for i2 in 0..N2 {
97                for i3 in 0..N3 {
98                    vec.extend(self[i1][i2][i3])
99                }
100            }
101        }
102        S::to_cpu_storage_owned(vec)
103    }
104}
105
106impl<S: WithDType> NdArray for Vec<S> {
107    fn shape(&self) -> Result<Shape> {
108        Ok(Shape::from(self.len()))
109    }
110
111    fn to_cpu_storage(&self) -> CpuStorage {
112        S::to_cpu_storage(self.as_slice())
113    }
114}
115
116impl<S: WithDType> NdArray for Vec<&[S]> {
117    fn shape(&self) -> Result<Shape> {
118        if self.is_empty() {
119            crate::bail!("empty array")
120        }
121        let n = self.len();
122        let m = self[0].len();
123        for v in self.iter() {
124            if v.len() != m {
125                crate::bail!("two elements have different len {m} {}", v.len())
126            }
127        }
128        Ok(Shape::from((n, m)))
129    }
130
131    fn to_cpu_storage(&self) -> CpuStorage {
132        let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
133        S::to_cpu_storage_owned(data)
134    }
135}
136
137impl<S: WithDType> NdArray for Vec<Vec<S>> {
138    fn shape(&self) -> Result<Shape> {
139        if self.is_empty() {
140            crate::bail!("empty array")
141        }
142        let n = self.len();
143        let m = self[0].len();
144        for v in self.iter() {
145            if v.len() != m {
146                crate::bail!("two elements have different len {m} {}", v.len())
147            }
148        }
149        Ok(Shape::from((n, m)))
150    }
151
152    fn to_cpu_storage(&self) -> CpuStorage {
153        let len: usize = self.iter().map(|v| v.len()).sum();
154        let mut dst = Vec::with_capacity(len);
155        for v in self.iter() {
156            dst.extend(v.iter().copied());
157        }
158        S::to_cpu_storage_owned(dst)
159    }
160}
161
162impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
163    fn shape(&self) -> Result<Shape> {
164        if self.is_empty() {
165            crate::bail!("empty array")
166        }
167        let shape0 = self[0].shape()?;
168        let n = self.len();
169        for v in self.iter() {
170            let shape = v.shape()?;
171            if shape != shape0 {
172                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
173            }
174        }
175        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
176    }
177
178    fn to_cpu_storage(&self) -> CpuStorage {
179        if self.is_empty() {
180            return S::to_cpu_storage_owned(vec![]);
181        }
182        let len: usize = self
183            .iter()
184            .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
185            .sum();
186        let mut dst = Vec::with_capacity(len);
187        for v1 in self.iter() {
188            for v2 in v1.iter() {
189                dst.extend(v2.iter().copied());
190            }
191        }
192        S::to_cpu_storage_owned(dst)
193    }
194}
195
196impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
197    fn shape(&self) -> Result<Shape> {
198        if self.is_empty() {
199            crate::bail!("empty array")
200        }
201        let shape0 = self[0].shape()?;
202        let n = self.len();
203        for v in self.iter() {
204            let shape = v.shape()?;
205            if shape != shape0 {
206                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
207            }
208        }
209        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
210    }
211
212    fn to_cpu_storage(&self) -> CpuStorage {
213        let len: usize = self
214            .iter()
215            .map(|v| {
216                v.iter()
217                    .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
218                    .sum::<usize>()
219            })
220            .sum();
221        let mut dst = Vec::with_capacity(len);
222        for v1 in self.iter() {
223            for v2 in v1.iter() {
224                for v3 in v2.iter() {
225                    dst.extend(v3.iter().copied());
226                }
227            }
228        }
229        S::to_cpu_storage_owned(dst)
230    }
231}
232
233impl Device {
234    pub fn new_cuda(ordinal: usize) -> Result<Self> {
235        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
236    }
237
238    pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
239        match self {
240            Self::Cuda(d) => Ok(d),
241            Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
242            Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
243        }
244    }
245
246    pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
247        match self {
248            Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
249            Self::Cpu => crate::bail!("expected a metal device, got cpu"),
250            Self::Metal(d) => Ok(d),
251        }
252    }
253
254    pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
255        Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
256    }
257
258    pub fn new_metal(ordinal: usize) -> Result<Self> {
259        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
260    }
261
262    pub fn set_seed(&self, seed: u64) -> Result<()> {
263        match self {
264            Self::Cpu => CpuDevice.set_seed(seed),
265            Self::Cuda(c) => c.set_seed(seed),
266            Self::Metal(m) => m.set_seed(seed),
267        }
268    }
269
270    pub fn same_device(&self, rhs: &Self) -> bool {
271        match (self, rhs) {
272            (Self::Cpu, Self::Cpu) => true,
273            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
274            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
275            _ => false,
276        }
277    }
278
279    pub fn location(&self) -> DeviceLocation {
280        match self {
281            Self::Cpu => DeviceLocation::Cpu,
282            Self::Cuda(device) => device.location(),
283            Device::Metal(device) => device.location(),
284        }
285    }
286
287    pub fn is_cpu(&self) -> bool {
288        matches!(self, Self::Cpu)
289    }
290
291    pub fn is_cuda(&self) -> bool {
292        matches!(self, Self::Cuda(_))
293    }
294
295    pub fn is_metal(&self) -> bool {
296        matches!(self, Self::Metal(_))
297    }
298
299    pub fn supports_bf16(&self) -> bool {
300        match self {
301            Self::Cuda(_) | Self::Metal(_) => true,
302            Self::Cpu => false,
303        }
304    }
305
306    /// Return `BF16` for devices that support it, otherwise default to `F32`.
307    pub fn bf16_default_to_f32(&self) -> DType {
308        if self.supports_bf16() {
309            DType::BF16
310        } else {
311            DType::F32
312        }
313    }
314
315    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
316        if crate::utils::cuda_is_available() {
317            Self::new_cuda(ordinal)
318        } else {
319            Ok(Self::Cpu)
320        }
321    }
322
323    pub(crate) fn rand_uniform_f64(
324        &self,
325        lo: f64,
326        up: f64,
327        shape: &Shape,
328        dtype: DType,
329    ) -> Result<Storage> {
330        match self {
331            Device::Cpu => {
332                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
333                Ok(Storage::Cpu(storage))
334            }
335            Device::Cuda(device) => {
336                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
337                if dtype == DType::F16 || dtype == DType::BF16 {
338                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
339                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
340                } else {
341                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
342                    Ok(Storage::Cuda(storage))
343                }
344            }
345            Device::Metal(device) => {
346                let storage = device.rand_uniform(shape, dtype, lo, up)?;
347                Ok(Storage::Metal(storage))
348            }
349        }
350    }
351
352    pub(crate) fn rand_uniform<T: crate::FloatDType>(
353        &self,
354        lo: T,
355        up: T,
356        shape: &Shape,
357    ) -> Result<Storage> {
358        self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
359    }
360
361    pub(crate) fn rand_normal_f64(
362        &self,
363        mean: f64,
364        std: f64,
365        shape: &Shape,
366        dtype: DType,
367    ) -> Result<Storage> {
368        match self {
369            Device::Cpu => {
370                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
371                Ok(Storage::Cpu(storage))
372            }
373            Device::Cuda(device) => {
374                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
375                if dtype == DType::F16 || dtype == DType::BF16 {
376                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
377                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
378                } else {
379                    let storage = device.rand_normal(shape, dtype, mean, std)?;
380                    Ok(Storage::Cuda(storage))
381                }
382            }
383            Device::Metal(device) => {
384                let storage = device.rand_normal(shape, dtype, mean, std)?;
385                Ok(Storage::Metal(storage))
386            }
387        }
388    }
389
390    pub(crate) fn rand_normal<T: crate::FloatDType>(
391        &self,
392        mean: T,
393        std: T,
394        shape: &Shape,
395    ) -> Result<Storage> {
396        self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
397    }
398
399    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
400        match self {
401            Device::Cpu => {
402                let storage = CpuDevice.zeros_impl(shape, dtype)?;
403                Ok(Storage::Cpu(storage))
404            }
405            Device::Cuda(device) => {
406                let storage = device.zeros_impl(shape, dtype)?;
407                Ok(Storage::Cuda(storage))
408            }
409            Device::Metal(device) => {
410                let storage = device.zeros_impl(shape, dtype)?;
411                Ok(Storage::Metal(storage))
412            }
413        }
414    }
415
416    pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
417        match self {
418            Device::Cpu => {
419                let storage = CpuDevice.alloc_uninit(shape, dtype)?;
420                Ok(Storage::Cpu(storage))
421            }
422            Device::Cuda(device) => {
423                let storage = device.alloc_uninit(shape, dtype)?;
424                Ok(Storage::Cuda(storage))
425            }
426            Device::Metal(device) => {
427                let storage = device.alloc_uninit(shape, dtype)?;
428                Ok(Storage::Metal(storage))
429            }
430        }
431    }
432
433    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
434        match self {
435            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
436            Device::Cuda(device) => {
437                let storage = device.storage_from_slice(data)?;
438                Ok(Storage::Cuda(storage))
439            }
440            Device::Metal(device) => {
441                let storage = device.storage_from_slice(data)?;
442                Ok(Storage::Metal(storage))
443            }
444        }
445    }
446
447    pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
448        match self {
449            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
450            Device::Cuda(device) => {
451                let storage = array.to_cpu_storage();
452                let storage = device.storage_from_cpu_storage_owned(storage)?;
453                Ok(Storage::Cuda(storage))
454            }
455            Device::Metal(device) => {
456                let storage = array.to_cpu_storage();
457                let storage = device.storage_from_cpu_storage_owned(storage)?;
458                Ok(Storage::Metal(storage))
459            }
460        }
461    }
462
463    pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
464        match self {
465            Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
466            Device::Cuda(device) => {
467                let storage = S::to_cpu_storage_owned(data);
468                let storage = device.storage_from_cpu_storage_owned(storage)?;
469                Ok(Storage::Cuda(storage))
470            }
471            Device::Metal(device) => {
472                let storage = S::to_cpu_storage_owned(data);
473                let storage = device.storage_from_cpu_storage_owned(storage)?;
474                Ok(Storage::Metal(storage))
475            }
476        }
477    }
478
479    pub fn synchronize(&self) -> Result<()> {
480        match self {
481            Self::Cpu => Ok(()),
482            Self::Cuda(d) => d.synchronize(),
483            Self::Metal(d) => d.synchronize(),
484        }
485    }
486}