constensor_core/
device.rs

1#[cfg(feature = "cuda")]
2use crate::cuda_backend::CudaDevice;
3use crate::{
4    cpu_storage::CpuDevice,
5    storage::{BackendDevice, Storage},
6    CompiledGraph, DType, GraphNode, Result, Shape,
7};
8
9/// Marker trait for devices
10pub trait Dev: Clone {
11    fn resolve() -> Result<Device>;
12}
13
/// Type-level tag selecting the host CPU backend.
///
/// This is a unit struct carrying no data. It derives the common std traits so it can be
/// printed, compared, defaulted and used as a map/set key like any other small value type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Cpu;
16
17impl Dev for Cpu {
18    fn resolve() -> Result<Device> {
19        Ok(Device::Cpu)
20    }
21}
22
/// Type-level tag selecting the CUDA device with ordinal `ORD`.
///
/// Only ordinals `0` through `9` implement [`Dev`]. Add more ordinals to the
/// `cuda_device!` invocation to support additional GPUs.
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Cuda<const ORD: usize>;

/// Implements [`Dev`] for `Cuda<N>` for every ordinal `N` in the comma-separated list.
///
/// Each implementation resolves to `Device::Cuda(CudaDevice::new(N)?)`. Any error from
/// `CudaDevice::new` is propagated.
#[cfg(feature = "cuda")]
macro_rules! cuda_device {
    ($($ord:expr),+ $(,)?) => {
        $(
            impl Dev for Cuda<$ord> {
                fn resolve() -> Result<Device> {
                    Ok(Device::Cuda(CudaDevice::new($ord)?))
                }
            }
        )+
    };
}

// NOTE: `Dev` is implemented for CUDA ordinals 0 through 9 only; extend this list to
// support more devices.
#[cfg(feature = "cuda")]
cuda_device!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
59
60#[cfg(feature = "cuda")]
61pub type BestDevice<const ORD: usize> = Cuda<ORD>;
62#[cfg(not(feature = "cuda"))]
63pub type BestDevice<const ORD: usize> = Cpu;
64
/// A concrete, runtime device handle.
///
/// Obtain one from a type-level [`Dev`] tag via [`Dev::resolve`], or construct a variant
/// directly. The `Cuda` variant exists only when the crate is built with the `cuda` feature.
#[derive(Clone)]
pub enum Device {
    /// A CUDA device, wrapping the backend's `CudaDevice` handle.
    #[cfg(feature = "cuda")]
    Cuda(CudaDevice),
    /// The host CPU.
    Cpu,
}
72
73impl Device {
74    pub fn run_graph<S: Shape, T: DType, D: Dev>(
75        &self,
76        graph: &CompiledGraph<S, T, D>,
77    ) -> Result<Storage<T>> {
78        match self {
79            #[cfg(feature = "cuda")]
80            Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.run_graph::<S, T, D>(graph)?)),
81            Self::Cpu => Ok(Storage::Cpu(CpuDevice.run_graph::<S, T, D>(graph)?)),
82        }
83    }
84
85    pub fn compile<S: Shape, T: DType, D: Dev>(
86        &self,
87        graph: Vec<GraphNode<T>>,
88    ) -> Result<CompiledGraph<S, T, D>> {
89        match self {
90            #[cfg(feature = "cuda")]
91            Self::Cuda(cuda) => cuda.compile::<S, T, D>(graph),
92            Self::Cpu => CpuDevice.compile::<S, T, D>(graph),
93        }
94    }
95}