1#![allow(unused)]
4
5use std::sync::Arc;
6
7use crate::clone_storage;
8
/// Host (CPU) backend handle: the address of a buffer plus a device id.
///
/// The handle itself owns no memory-management logic; its manual `Clone`
/// impl reports every copy to `clone_storage` under the `crate::CPU_STORAGE`
/// lock.
pub struct Cpu {
    /// Address of the underlying buffer, kept as an integer and cast to
    /// `*mut u8` when handed to `clone_storage`.
    pub(crate) ptr: u64,
    /// Device id forwarded to `clone_storage`.
    /// NOTE(review): its meaning for a CPU backend is not visible here — confirm.
    pub(crate) device_id: usize,
}
16
/// CUDA backend handle: the address of a device buffer, the device it lives
/// on, and an encoded compute capability.
///
/// Only compiled with the `cuda` feature.
#[cfg(feature = "cuda")]
pub struct Cuda {
    /// Address of the underlying device buffer, kept as an integer.
    pub(crate) ptr: u64,
    /// Shared handle to the CUDA device; its `ordinal()` is what the `Clone`
    /// impl passes to `clone_storage`.
    pub device: Arc<cudarc::driver::CudaDevice>,
    /// Compute capability encoded as `major * 10 + minor` (e.g. 8.6 -> 86)
    /// when the handle is built by `Backend::<Cuda>::new`.
    pub cap: usize,
}
26
/// Generic wrapper around a concrete backend handle (`Cpu` or `Cuda` here).
///
/// The derived `Clone` requires `B: Clone` and simply clones the inner
/// handle, so for `Cpu`/`Cuda` cloning a `Backend` goes through their
/// `clone_storage` bookkeeping.
#[derive(Clone)]
pub struct Backend<B> {
    /// The wrapped backend handle.
    pub _backend: B,
}
37
38impl Clone for Cpu {
39 fn clone(&self) -> Self {
40 if let Ok(mut storage) = crate::CPU_STORAGE.lock() {
41 clone_storage(self.ptr as *mut u8, self.device_id, &mut storage);
42 } else {
43 panic!("failed to lock CPU_STORAGE");
44 }
45 Cpu {
46 ptr: self.ptr,
47 device_id: self.device_id,
48 }
49 }
50}
51
52impl Backend<Cpu> {
53 pub fn new(address: u64, device_id: usize) -> Self {
55 Backend {
56 _backend: Cpu {
57 ptr: address,
58 device_id,
59 },
60 }
61 }
62}
63
#[cfg(feature = "cuda")]
impl Clone for Cuda {
    /// Returns a second handle to the same device buffer.
    ///
    /// The buffer is first reported to `clone_storage` while
    /// `crate::CUDA_STORAGE` is locked, keyed by its address and the device's
    /// `ordinal()`. That presumably bumps a reference count; confirm in
    /// `clone_storage`. Only the `Arc` to the device is cloned; no device
    /// memory is copied.
    ///
    /// # Panics
    ///
    /// Panics with "failed to lock CUDA_STORAGE" if the lock cannot be
    /// acquired (e.g. a poisoned mutex, assuming a std `Mutex`).
    fn clone(&self) -> Self {
        if let Ok(mut storage) = crate::CUDA_STORAGE.lock() {
            clone_storage(self.ptr as *mut u8, self.device.ordinal(), &mut storage);
        } else {
            // Name the lock actually being taken so the panic points at the CUDA registry.
            panic!("failed to lock CUDA_STORAGE");
        }
        Cuda {
            ptr: self.ptr,
            device: self.device.clone(),
            cap: self.cap,
        }
    }
}
79
#[cfg(feature = "cuda")]
impl Backend<Cuda> {
    /// Wraps an existing device buffer address in a CUDA backend.
    ///
    /// Queries the device's compute capability once and stores it as
    /// `major * 10 + minor`.
    ///
    /// # Panics
    ///
    /// Panics if either compute-capability attribute query fails.
    pub fn new(address: u64, device: Arc<cudarc::driver::CudaDevice>) -> Self {
        use cudarc::driver::sys::CUdevice_attribute as Attr;

        let major = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .expect("failed to get compute capability major when creating cuda backend");
        let minor = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .expect("failed to get compute capability minor when creating cuda backend");
        let cap = (major * 10 + minor) as usize;

        Self {
            _backend: Cuda {
                ptr: address,
                device,
                cap,
            },
        }
    }
}
99
/// Types that expose the raw address of their underlying buffer.
pub trait Buffer {
    /// Returns the buffer address as an integer (both impls here return the
    /// stored `ptr` field unchanged).
    fn get_ptr(&self) -> u64;
}
107
108impl Buffer for Cpu {
109 fn get_ptr(&self) -> u64 {
110 self.ptr
111 }
112}
113
#[cfg(feature = "cuda")]
impl Buffer for Cuda {
    /// Hands back the stored device buffer address as-is.
    fn get_ptr(&self) -> u64 {
        let Self { ptr, .. } = self;
        *ptr
    }
}
120
/// Compile-time tag distinguishing backend kinds.
pub trait BackendTy {
    /// Numeric backend identifier; in this file `Cpu` uses 0 and `Cuda` uses 1.
    const ID: u8;
}
134
/// The CPU backend is tagged with `ID == 0`.
impl BackendTy for Cpu {
    const ID: u8 = 0;
}
138
/// The CUDA backend is tagged with `ID == 1` (only with the `cuda` feature).
#[cfg(feature = "cuda")]
impl BackendTy for Cuda {
    const ID: u8 = 1;
}