// constensor_core/device.rs
#[cfg(feature = "cuda")]
use crate::cuda_backend::CudaDevice;
use crate::{
    cpu_storage::CpuDevice,
    storage::{BackendDevice, Storage},
    CompiledGraph, DType, GraphNode, Result, Shape,
};

9pub trait Dev: Clone {
11 fn resolve() -> Result<Device>;
12}
13
/// Zero-sized marker type selecting the host CPU backend.
///
/// Derives the common standard traits so it can be debug-printed, compared,
/// hashed, defaulted and copied freely, as expected of a public marker type.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Cpu;

17impl Dev for Cpu {
18 fn resolve() -> Result<Device> {
19 Ok(Device::Cpu)
20 }
21}
22
/// Zero-sized marker type selecting the CUDA device with ordinal `ORD`.
///
/// Only available with the `cuda` feature. In this module [`Dev`] is
/// implemented for ordinals 0 through 9. Derives the common standard traits
/// so it can be debug-printed, compared, hashed, defaulted and copied.
#[cfg(feature = "cuda")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Cuda<const ORD: usize>;

/// Implements [`Dev`] for `Cuda<ORD>` for every listed ordinal.
///
/// Each generated `resolve` constructs a new `CudaDevice` for its ordinal on
/// every call and propagates any error from `CudaDevice::new`.
///
/// Accepts one or more comma-separated ordinals (a trailing comma is allowed),
/// so both `cuda_device!(3)` and `cuda_device!(0, 1, 2)` are valid.
#[cfg(feature = "cuda")]
macro_rules! cuda_device {
    ($($ord:expr),+ $(,)?) => {
        $(
            // Braces make any `expr` fragment valid as a const-generic argument.
            impl Dev for Cuda<{ $ord }> {
                fn resolve() -> Result<Device> {
                    Ok(Device::Cuda(CudaDevice::new($ord)?))
                }
            }
        )+
    };
}

// Supported CUDA ordinals; extend this list to support more devices.
#[cfg(feature = "cuda")]
cuda_device!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

60#[cfg(feature = "cuda")]
61pub type BestDevice<const ORD: usize> = Cuda<ORD>;
62#[cfg(not(feature = "cuda"))]
63pub type BestDevice<const ORD: usize> = Cpu;
64
/// A concrete compute backend chosen at runtime.
///
/// Returned by [`Dev::resolve`] implementations and used to compile and run
/// graphs on the corresponding backend.
#[derive(Clone)]
pub enum Device {
    /// Handle to a CUDA device (present only with the `cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(CudaDevice),
    /// The host CPU backend.
    Cpu,
}

73impl Device {
74 pub fn run_graph<S: Shape, T: DType, D: Dev>(
75 &self,
76 graph: &CompiledGraph<S, T, D>,
77 ) -> Result<Storage<T>> {
78 match self {
79 #[cfg(feature = "cuda")]
80 Self::Cuda(cuda) => Ok(Storage::Cuda(cuda.run_graph::<S, T, D>(graph)?)),
81 Self::Cpu => Ok(Storage::Cpu(CpuDevice.run_graph::<S, T, D>(graph)?)),
82 }
83 }
84
85 pub fn compile<S: Shape, T: DType, D: Dev>(
86 &self,
87 graph: Vec<GraphNode<T>>,
88 ) -> Result<CompiledGraph<S, T, D>> {
89 match self {
90 #[cfg(feature = "cuda")]
91 Self::Cuda(cuda) => cuda.compile::<S, T, D>(graph),
92 Self::Cpu => CpuDevice.compile::<S, T, D>(graph),
93 }
94 }
95}