candle_core/
device.rs

1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Identifies a physical device.
///
/// Several [`Device`] values may refer to the same `DeviceLocation`; this is
/// typically the case for CUDA devices.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceLocation {
    /// The host CPU.
    Cpu,
    /// A CUDA GPU.
    Cuda {
        /// Index identifying the GPU.
        gpu_id: usize,
    },
    /// A Metal GPU.
    Metal {
        /// Index identifying the GPU.
        gpu_id: usize,
    },
}
13
14/// Cpu, Cuda, or Metal
15#[derive(Debug, Clone)]
16pub enum Device {
17    Cpu,
18    Cuda(crate::CudaDevice),
19    Metal(crate::MetalDevice),
20}
21
22pub trait NdArray {
23    fn shape(&self) -> Result<Shape>;
24
25    fn to_cpu_storage(&self) -> CpuStorage;
26}
27
28impl<S: WithDType> NdArray for S {
29    fn shape(&self) -> Result<Shape> {
30        Ok(Shape::from(()))
31    }
32
33    fn to_cpu_storage(&self) -> CpuStorage {
34        S::to_cpu_storage(&[*self])
35    }
36}
37
38impl<S: WithDType, const N: usize> NdArray for &[S; N] {
39    fn shape(&self) -> Result<Shape> {
40        Ok(Shape::from(self.len()))
41    }
42
43    fn to_cpu_storage(&self) -> CpuStorage {
44        S::to_cpu_storage(self.as_slice())
45    }
46}
47
48impl<S: WithDType> NdArray for &[S] {
49    fn shape(&self) -> Result<Shape> {
50        Ok(Shape::from(self.len()))
51    }
52
53    fn to_cpu_storage(&self) -> CpuStorage {
54        S::to_cpu_storage(self)
55    }
56}
57
58impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
59    fn shape(&self) -> Result<Shape> {
60        Ok(Shape::from((M, N)))
61    }
62
63    fn to_cpu_storage(&self) -> CpuStorage {
64        S::to_cpu_storage_owned(self.concat())
65    }
66}
67
68impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
69    for &[[[S; N3]; N2]; N1]
70{
71    fn shape(&self) -> Result<Shape> {
72        Ok(Shape::from((N1, N2, N3)))
73    }
74
75    fn to_cpu_storage(&self) -> CpuStorage {
76        let mut vec = Vec::with_capacity(N1 * N2 * N3);
77        for i1 in 0..N1 {
78            for i2 in 0..N2 {
79                vec.extend(self[i1][i2])
80            }
81        }
82        S::to_cpu_storage_owned(vec)
83    }
84}
85
86impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
87    for &[[[[S; N4]; N3]; N2]; N1]
88{
89    fn shape(&self) -> Result<Shape> {
90        Ok(Shape::from((N1, N2, N3, N4)))
91    }
92
93    fn to_cpu_storage(&self) -> CpuStorage {
94        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
95        for i1 in 0..N1 {
96            for i2 in 0..N2 {
97                for i3 in 0..N3 {
98                    vec.extend(self[i1][i2][i3])
99                }
100            }
101        }
102        S::to_cpu_storage_owned(vec)
103    }
104}
105
106impl<S: NdArray> NdArray for Vec<S> {
107    fn shape(&self) -> Result<Shape> {
108        if self.is_empty() {
109            crate::bail!("empty array")
110        }
111        let shape0 = self[0].shape()?;
112        let n = self.len();
113        for v in self.iter() {
114            let shape = v.shape()?;
115            if shape != shape0 {
116                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
117            }
118        }
119        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
120    }
121
122    fn to_cpu_storage(&self) -> CpuStorage {
123        // This allocates intermediary memory and shouldn't be necessary.
124        let storages = self.iter().map(|v| v.to_cpu_storage()).collect::<Vec<_>>();
125        CpuStorage::concat(storages.as_slice()).unwrap()
126    }
127}
128
129impl Device {
130    pub fn new_cuda(ordinal: usize) -> Result<Self> {
131        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
132    }
133
134    pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
135        match self {
136            Self::Cuda(d) => Ok(d),
137            Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
138            Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
139        }
140    }
141
142    pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
143        match self {
144            Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
145            Self::Cpu => crate::bail!("expected a metal device, got cpu"),
146            Self::Metal(d) => Ok(d),
147        }
148    }
149
150    pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
151        Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
152    }
153
154    pub fn new_metal(ordinal: usize) -> Result<Self> {
155        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
156    }
157
158    pub fn set_seed(&self, seed: u64) -> Result<()> {
159        match self {
160            Self::Cpu => CpuDevice.set_seed(seed),
161            Self::Cuda(c) => c.set_seed(seed),
162            Self::Metal(m) => m.set_seed(seed),
163        }
164    }
165
166    pub fn same_device(&self, rhs: &Self) -> bool {
167        match (self, rhs) {
168            (Self::Cpu, Self::Cpu) => true,
169            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
170            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
171            _ => false,
172        }
173    }
174
175    pub fn location(&self) -> DeviceLocation {
176        match self {
177            Self::Cpu => DeviceLocation::Cpu,
178            Self::Cuda(device) => device.location(),
179            Device::Metal(device) => device.location(),
180        }
181    }
182
183    pub fn is_cpu(&self) -> bool {
184        matches!(self, Self::Cpu)
185    }
186
187    pub fn is_cuda(&self) -> bool {
188        matches!(self, Self::Cuda(_))
189    }
190
191    pub fn is_metal(&self) -> bool {
192        matches!(self, Self::Metal(_))
193    }
194
195    pub fn supports_bf16(&self) -> bool {
196        match self {
197            Self::Cuda(_) | Self::Metal(_) => true,
198            Self::Cpu => false,
199        }
200    }
201
202    /// Return `BF16` for devices that support it, otherwise default to `F32`.
203    pub fn bf16_default_to_f32(&self) -> DType {
204        if self.supports_bf16() {
205            DType::BF16
206        } else {
207            DType::F32
208        }
209    }
210
211    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
212        if crate::utils::cuda_is_available() {
213            Self::new_cuda(ordinal)
214        } else {
215            Ok(Self::Cpu)
216        }
217    }
218
219    pub(crate) fn rand_uniform_f64(
220        &self,
221        lo: f64,
222        up: f64,
223        shape: &Shape,
224        dtype: DType,
225    ) -> Result<Storage> {
226        match self {
227            Device::Cpu => {
228                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
229                Ok(Storage::Cpu(storage))
230            }
231            Device::Cuda(device) => {
232                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
233                if dtype == DType::F16 || dtype == DType::BF16 {
234                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
235                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
236                } else {
237                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
238                    Ok(Storage::Cuda(storage))
239                }
240            }
241            Device::Metal(device) => {
242                let storage = device.rand_uniform(shape, dtype, lo, up)?;
243                Ok(Storage::Metal(storage))
244            }
245        }
246    }
247
248    pub(crate) fn rand_uniform<T: crate::FloatDType>(
249        &self,
250        lo: T,
251        up: T,
252        shape: &Shape,
253    ) -> Result<Storage> {
254        self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
255    }
256
257    pub(crate) fn rand_normal_f64(
258        &self,
259        mean: f64,
260        std: f64,
261        shape: &Shape,
262        dtype: DType,
263    ) -> Result<Storage> {
264        match self {
265            Device::Cpu => {
266                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
267                Ok(Storage::Cpu(storage))
268            }
269            Device::Cuda(device) => {
270                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
271                if dtype == DType::F16 || dtype == DType::BF16 {
272                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
273                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
274                } else {
275                    let storage = device.rand_normal(shape, dtype, mean, std)?;
276                    Ok(Storage::Cuda(storage))
277                }
278            }
279            Device::Metal(device) => {
280                let storage = device.rand_normal(shape, dtype, mean, std)?;
281                Ok(Storage::Metal(storage))
282            }
283        }
284    }
285
286    pub(crate) fn rand_normal<T: crate::FloatDType>(
287        &self,
288        mean: T,
289        std: T,
290        shape: &Shape,
291    ) -> Result<Storage> {
292        self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
293    }
294
295    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
296        match self {
297            Device::Cpu => {
298                let storage = CpuDevice.ones_impl(shape, dtype)?;
299                Ok(Storage::Cpu(storage))
300            }
301            Device::Cuda(device) => {
302                let storage = device.ones_impl(shape, dtype)?;
303                Ok(Storage::Cuda(storage))
304            }
305            Device::Metal(device) => {
306                let storage = device.ones_impl(shape, dtype)?;
307                Ok(Storage::Metal(storage))
308            }
309        }
310    }
311
312    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
313        match self {
314            Device::Cpu => {
315                let storage = CpuDevice.zeros_impl(shape, dtype)?;
316                Ok(Storage::Cpu(storage))
317            }
318            Device::Cuda(device) => {
319                let storage = device.zeros_impl(shape, dtype)?;
320                Ok(Storage::Cuda(storage))
321            }
322            Device::Metal(device) => {
323                let storage = device.zeros_impl(shape, dtype)?;
324                Ok(Storage::Metal(storage))
325            }
326        }
327    }
328
329    pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
330        match self {
331            Device::Cpu => {
332                let storage = CpuDevice.alloc_uninit(shape, dtype)?;
333                Ok(Storage::Cpu(storage))
334            }
335            Device::Cuda(device) => {
336                let storage = device.alloc_uninit(shape, dtype)?;
337                Ok(Storage::Cuda(storage))
338            }
339            Device::Metal(device) => {
340                let storage = device.alloc_uninit(shape, dtype)?;
341                Ok(Storage::Metal(storage))
342            }
343        }
344    }
345
346    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
347        match self {
348            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
349            Device::Cuda(device) => {
350                let storage = device.storage_from_slice(data)?;
351                Ok(Storage::Cuda(storage))
352            }
353            Device::Metal(device) => {
354                let storage = device.storage_from_slice(data)?;
355                Ok(Storage::Metal(storage))
356            }
357        }
358    }
359
360    pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
361        match self {
362            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
363            Device::Cuda(device) => {
364                let storage = array.to_cpu_storage();
365                let storage = device.storage_from_cpu_storage_owned(storage)?;
366                Ok(Storage::Cuda(storage))
367            }
368            Device::Metal(device) => {
369                let storage = array.to_cpu_storage();
370                let storage = device.storage_from_cpu_storage_owned(storage)?;
371                Ok(Storage::Metal(storage))
372            }
373        }
374    }
375
376    pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
377        match self {
378            Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
379            Device::Cuda(device) => {
380                let storage = S::to_cpu_storage_owned(data);
381                let storage = device.storage_from_cpu_storage_owned(storage)?;
382                Ok(Storage::Cuda(storage))
383            }
384            Device::Metal(device) => {
385                let storage = S::to_cpu_storage_owned(data);
386                let storage = device.storage_from_cpu_storage_owned(storage)?;
387                Ok(Storage::Metal(storage))
388            }
389        }
390    }
391
392    pub fn synchronize(&self) -> Result<()> {
393        match self {
394            Self::Cpu => Ok(()),
395            Self::Cuda(d) => d.synchronize(),
396            Self::Metal(d) => d.synchronize(),
397        }
398    }
399}