// candle_core/device.rs

1use crate::backend::BackendDevice;
2use crate::cpu_backend::CpuDevice;
3use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
4
/// Identifies a physical device.
///
/// Several [`Device`] handles may point at the same location; this is typical for
/// CUDA, where distinct handles can target the same GPU.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceLocation {
    /// The host CPU.
    Cpu,
    /// A CUDA GPU, identified by `gpu_id`.
    Cuda { gpu_id: usize },
    /// A Metal GPU, identified by `gpu_id`.
    Metal { gpu_id: usize },
}
13
14/// Cpu, Cuda, or Metal
15#[derive(Debug, Clone)]
16pub enum Device {
17    Cpu,
18    Cuda(crate::CudaDevice),
19    Metal(crate::MetalDevice),
20}
21
22pub trait NdArray {
23    fn shape(&self) -> Result<Shape>;
24
25    fn to_cpu_storage(&self) -> CpuStorage;
26}
27
28impl<S: WithDType> NdArray for S {
29    fn shape(&self) -> Result<Shape> {
30        Ok(Shape::from(()))
31    }
32
33    fn to_cpu_storage(&self) -> CpuStorage {
34        S::to_cpu_storage(&[*self])
35    }
36}
37
38impl<S: WithDType, const N: usize> NdArray for &[S; N] {
39    fn shape(&self) -> Result<Shape> {
40        Ok(Shape::from(self.len()))
41    }
42
43    fn to_cpu_storage(&self) -> CpuStorage {
44        S::to_cpu_storage(self.as_slice())
45    }
46}
47
48impl<S: WithDType> NdArray for &[S] {
49    fn shape(&self) -> Result<Shape> {
50        Ok(Shape::from(self.len()))
51    }
52
53    fn to_cpu_storage(&self) -> CpuStorage {
54        S::to_cpu_storage(self)
55    }
56}
57
58impl<S: WithDType, const N: usize, const M: usize> NdArray for &[[S; N]; M] {
59    fn shape(&self) -> Result<Shape> {
60        Ok(Shape::from((M, N)))
61    }
62
63    fn to_cpu_storage(&self) -> CpuStorage {
64        S::to_cpu_storage_owned(self.concat())
65    }
66}
67
68impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize> NdArray
69    for &[[[S; N3]; N2]; N1]
70{
71    fn shape(&self) -> Result<Shape> {
72        Ok(Shape::from((N1, N2, N3)))
73    }
74
75    fn to_cpu_storage(&self) -> CpuStorage {
76        let mut vec = Vec::with_capacity(N1 * N2 * N3);
77        for i1 in 0..N1 {
78            for i2 in 0..N2 {
79                vec.extend(self[i1][i2])
80            }
81        }
82        S::to_cpu_storage_owned(vec)
83    }
84}
85
86impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4: usize> NdArray
87    for &[[[[S; N4]; N3]; N2]; N1]
88{
89    fn shape(&self) -> Result<Shape> {
90        Ok(Shape::from((N1, N2, N3, N4)))
91    }
92
93    fn to_cpu_storage(&self) -> CpuStorage {
94        let mut vec = Vec::with_capacity(N1 * N2 * N3 * N4);
95        for i1 in 0..N1 {
96            for i2 in 0..N2 {
97                for i3 in 0..N3 {
98                    vec.extend(self[i1][i2][i3])
99                }
100            }
101        }
102        S::to_cpu_storage_owned(vec)
103    }
104}
105
106impl<S: WithDType> NdArray for Vec<S> {
107    fn shape(&self) -> Result<Shape> {
108        Ok(Shape::from(self.len()))
109    }
110
111    fn to_cpu_storage(&self) -> CpuStorage {
112        S::to_cpu_storage(self.as_slice())
113    }
114}
115
116impl<S: WithDType> NdArray for Vec<&[S]> {
117    fn shape(&self) -> Result<Shape> {
118        if self.is_empty() {
119            crate::bail!("empty array")
120        }
121        let n = self.len();
122        let m = self[0].len();
123        for v in self.iter() {
124            if v.len() != m {
125                crate::bail!("two elements have different len {m} {}", v.len())
126            }
127        }
128        Ok(Shape::from((n, m)))
129    }
130
131    fn to_cpu_storage(&self) -> CpuStorage {
132        let data = self.iter().copied().flatten().copied().collect::<Vec<_>>();
133        S::to_cpu_storage_owned(data)
134    }
135}
136
137impl<S: WithDType> NdArray for Vec<Vec<S>> {
138    fn shape(&self) -> Result<Shape> {
139        if self.is_empty() {
140            crate::bail!("empty array")
141        }
142        let n = self.len();
143        let m = self[0].len();
144        for v in self.iter() {
145            if v.len() != m {
146                crate::bail!("two elements have different len {m} {}", v.len())
147            }
148        }
149        Ok(Shape::from((n, m)))
150    }
151
152    fn to_cpu_storage(&self) -> CpuStorage {
153        let len: usize = self.iter().map(|v| v.len()).sum();
154        let mut dst = Vec::with_capacity(len);
155        for v in self.iter() {
156            dst.extend(v.iter().copied());
157        }
158        S::to_cpu_storage_owned(dst)
159    }
160}
161
162impl<S: WithDType> NdArray for Vec<Vec<Vec<S>>> {
163    fn shape(&self) -> Result<Shape> {
164        if self.is_empty() {
165            crate::bail!("empty array")
166        }
167        let shape0 = self[0].shape()?;
168        let n = self.len();
169        for v in self.iter() {
170            let shape = v.shape()?;
171            if shape != shape0 {
172                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
173            }
174        }
175        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
176    }
177
178    fn to_cpu_storage(&self) -> CpuStorage {
179        if self.is_empty() {
180            return S::to_cpu_storage_owned(vec![]);
181        }
182        let len: usize = self
183            .iter()
184            .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
185            .sum();
186        let mut dst = Vec::with_capacity(len);
187        for v1 in self.iter() {
188            for v2 in v1.iter() {
189                dst.extend(v2.iter().copied());
190            }
191        }
192        S::to_cpu_storage_owned(dst)
193    }
194}
195
196impl<S: WithDType> NdArray for Vec<Vec<Vec<Vec<S>>>> {
197    fn shape(&self) -> Result<Shape> {
198        if self.is_empty() {
199            crate::bail!("empty array")
200        }
201        let shape0 = self[0].shape()?;
202        let n = self.len();
203        for v in self.iter() {
204            let shape = v.shape()?;
205            if shape != shape0 {
206                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
207            }
208        }
209        Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
210    }
211
212    fn to_cpu_storage(&self) -> CpuStorage {
213        let len: usize = self
214            .iter()
215            .map(|v| {
216                v.iter()
217                    .map(|v| v.iter().map(|v| v.len()).sum::<usize>())
218                    .sum::<usize>()
219            })
220            .sum();
221        let mut dst = Vec::with_capacity(len);
222        for v1 in self.iter() {
223            for v2 in v1.iter() {
224                for v3 in v2.iter() {
225                    dst.extend(v3.iter().copied());
226                }
227            }
228        }
229        S::to_cpu_storage_owned(dst)
230    }
231}
232
233impl Device {
234    pub fn new_cuda(ordinal: usize) -> Result<Self> {
235        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
236    }
237
238    pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
239        match self {
240            Self::Cuda(d) => Ok(d),
241            Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
242            Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
243        }
244    }
245
246    pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
247        match self {
248            Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
249            Self::Cpu => crate::bail!("expected a metal device, got cpu"),
250            Self::Metal(d) => Ok(d),
251        }
252    }
253
254    pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
255        Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
256    }
257
258    pub fn new_metal(ordinal: usize) -> Result<Self> {
259        Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
260    }
261
262    /// Run `f` with device specific context.
263    ///
264    /// On CPU this installs candle's private rayon thread pool for the
265    /// duration of `f`, keeping worker threads warm across the many short
266    /// parallel sections in a model forward pass. Currently noop for other backends.
267    pub fn with_context<F, R>(&self, f: F) -> R
268    where
269        F: FnOnce() -> R + Send,
270        R: Send,
271    {
272        match self {
273            Self::Cpu => crate::utils::with_threadpool(f),
274            _ => f(),
275        }
276    }
277
278    pub fn set_seed(&self, seed: u64) -> Result<()> {
279        match self {
280            Self::Cpu => CpuDevice.set_seed(seed),
281            Self::Cuda(c) => c.set_seed(seed),
282            Self::Metal(m) => m.set_seed(seed),
283        }
284    }
285
286    pub fn get_current_seed(&self) -> Result<u64> {
287        match self {
288            Self::Cpu => CpuDevice.get_current_seed(),
289            Self::Cuda(c) => c.get_current_seed(),
290            Self::Metal(m) => m.get_current_seed(),
291        }
292    }
293
294    pub fn same_device(&self, rhs: &Self) -> bool {
295        match (self, rhs) {
296            (Self::Cpu, Self::Cpu) => true,
297            (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs),
298            (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs),
299            _ => false,
300        }
301    }
302
303    pub fn location(&self) -> DeviceLocation {
304        match self {
305            Self::Cpu => DeviceLocation::Cpu,
306            Self::Cuda(device) => device.location(),
307            Device::Metal(device) => device.location(),
308        }
309    }
310
311    pub fn is_cpu(&self) -> bool {
312        matches!(self, Self::Cpu)
313    }
314
315    pub fn is_cuda(&self) -> bool {
316        matches!(self, Self::Cuda(_))
317    }
318
319    pub fn is_metal(&self) -> bool {
320        matches!(self, Self::Metal(_))
321    }
322
323    pub fn supports_bf16(&self) -> bool {
324        match self {
325            Self::Cuda(_) | Self::Metal(_) => true,
326            Self::Cpu => false,
327        }
328    }
329
330    /// Return `BF16` for devices that support it, otherwise default to `F32`.
331    pub fn bf16_default_to_f32(&self) -> DType {
332        if self.supports_bf16() {
333            DType::BF16
334        } else {
335            DType::F32
336        }
337    }
338
339    pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
340        if crate::utils::cuda_is_available() {
341            Self::new_cuda(ordinal)
342        } else {
343            Ok(Self::Cpu)
344        }
345    }
346
347    pub fn metal_if_available(ordinal: usize) -> Result<Self> {
348        if crate::utils::metal_is_available() {
349            Self::new_metal(ordinal)
350        } else {
351            Ok(Self::Cpu)
352        }
353    }
354
355    pub(crate) fn rand_uniform_f64(
356        &self,
357        lo: f64,
358        up: f64,
359        shape: &Shape,
360        dtype: DType,
361    ) -> Result<Storage> {
362        match self {
363            Device::Cpu => {
364                let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
365                Ok(Storage::Cpu(storage))
366            }
367            Device::Cuda(device) => {
368                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
369                if dtype == DType::F16 || dtype == DType::BF16 {
370                    let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
371                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
372                } else {
373                    let storage = device.rand_uniform(shape, dtype, lo, up)?;
374                    Ok(Storage::Cuda(storage))
375                }
376            }
377            Device::Metal(device) => {
378                let storage = device.rand_uniform(shape, dtype, lo, up)?;
379                Ok(Storage::Metal(storage))
380            }
381        }
382    }
383
384    pub(crate) fn rand_uniform<T: crate::FloatDType>(
385        &self,
386        lo: T,
387        up: T,
388        shape: &Shape,
389    ) -> Result<Storage> {
390        self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
391    }
392
393    pub(crate) fn rand_normal_f64(
394        &self,
395        mean: f64,
396        std: f64,
397        shape: &Shape,
398        dtype: DType,
399    ) -> Result<Storage> {
400        match self {
401            Device::Cpu => {
402                let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
403                Ok(Storage::Cpu(storage))
404            }
405            Device::Cuda(device) => {
406                // TODO: Remove the special case if we start supporting generating f16/bf16 directly.
407                if dtype == DType::F16 || dtype == DType::BF16 {
408                    let storage = device.rand_normal(shape, DType::F32, mean, std)?;
409                    Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
410                } else {
411                    let storage = device.rand_normal(shape, dtype, mean, std)?;
412                    Ok(Storage::Cuda(storage))
413                }
414            }
415            Device::Metal(device) => {
416                let storage = device.rand_normal(shape, dtype, mean, std)?;
417                Ok(Storage::Metal(storage))
418            }
419        }
420    }
421
422    pub(crate) fn rand_normal<T: crate::FloatDType>(
423        &self,
424        mean: T,
425        std: T,
426        shape: &Shape,
427    ) -> Result<Storage> {
428        self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
429    }
430
431    pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
432        match self {
433            Device::Cpu => {
434                let storage = CpuDevice.zeros_impl(shape, dtype)?;
435                Ok(Storage::Cpu(storage))
436            }
437            Device::Cuda(device) => {
438                let storage = device.zeros_impl(shape, dtype)?;
439                Ok(Storage::Cuda(storage))
440            }
441            Device::Metal(device) => {
442                let storage = device.zeros_impl(shape, dtype)?;
443                Ok(Storage::Metal(storage))
444            }
445        }
446    }
447
448    pub(crate) unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
449        match self {
450            Device::Cpu => {
451                let storage = CpuDevice.alloc_uninit(shape, dtype)?;
452                Ok(Storage::Cpu(storage))
453            }
454            Device::Cuda(device) => {
455                let storage = device.alloc_uninit(shape, dtype)?;
456                Ok(Storage::Cuda(storage))
457            }
458            Device::Metal(device) => {
459                let storage = device.alloc_uninit(shape, dtype)?;
460                Ok(Storage::Metal(storage))
461            }
462        }
463    }
464
465    pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
466        match self {
467            Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
468            Device::Cuda(device) => {
469                let storage = device.storage_from_slice(data)?;
470                Ok(Storage::Cuda(storage))
471            }
472            Device::Metal(device) => {
473                let storage = device.storage_from_slice(data)?;
474                Ok(Storage::Metal(storage))
475            }
476        }
477    }
478
479    pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
480        match self {
481            Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
482            Device::Cuda(device) => {
483                let storage = array.to_cpu_storage();
484                let storage = device.storage_from_cpu_storage_owned(storage)?;
485                Ok(Storage::Cuda(storage))
486            }
487            Device::Metal(device) => {
488                let storage = array.to_cpu_storage();
489                let storage = device.storage_from_cpu_storage_owned(storage)?;
490                Ok(Storage::Metal(storage))
491            }
492        }
493    }
494
495    pub(crate) fn storage_owned<S: WithDType>(&self, data: Vec<S>) -> Result<Storage> {
496        match self {
497            Device::Cpu => Ok(Storage::Cpu(S::to_cpu_storage_owned(data))),
498            Device::Cuda(device) => {
499                let storage = S::to_cpu_storage_owned(data);
500                let storage = device.storage_from_cpu_storage_owned(storage)?;
501                Ok(Storage::Cuda(storage))
502            }
503            Device::Metal(device) => {
504                let storage = S::to_cpu_storage_owned(data);
505                let storage = device.storage_from_cpu_storage_owned(storage)?;
506                Ok(Storage::Metal(storage))
507            }
508        }
509    }
510
511    pub fn synchronize(&self) -> Result<()> {
512        match self {
513            Self::Cpu => Ok(()),
514            Self::Cuda(d) => d.synchronize(),
515            Self::Metal(d) => d.synchronize(),
516        }
517    }
518}