1#![allow(unused)]
4
5use std::sync::Arc;
6
7use crate::clone_storage;
8
/// CPU-side buffer handle: a raw buffer address plus the id of the device it belongs to.
pub struct Cpu {
    /// Buffer address kept as an integer; cast to `*mut u8` when passed to `clone_storage`.
    pub(crate) ptr: u64,
    /// Device id passed to `clone_storage` together with `ptr`.
    pub(crate) device_id: usize,
}
16
/// CUDA-side buffer handle (only compiled with the `cuda` feature).
#[cfg(feature = "cuda")]
pub struct Cuda {
    /// Device buffer address kept as an integer; cast to `*mut u8` when passed to `clone_storage`.
    pub(crate) ptr: u64,
    /// Device owning the buffer; its `ordinal()` is passed to `clone_storage`.
    pub device: Arc<cudarc::driver::CudaDevice>,
    /// Compute capability of `device`, encoded as `major * 10 + minor`.
    pub cap: usize,
}
26
/// Backend-tagged buffer handle.
///
/// `inner` carries the backend-specific buffer data (e.g. `Cpu`), and `should_drop`
/// is a flag that [`Backend::should_drop`] reads and [`Backend::forget`] clears.
#[derive(Clone)]
pub struct Backend<B> {
    /// Backend-specific buffer data.
    pub inner: B,
    /// Drop flag. NOTE(review): presumably consulted by drop logic elsewhere in the
    /// crate to decide whether `inner` is released — confirm.
    pub should_drop: bool,
}
39
40impl<B: BackendTy> Backend<B> {
41 pub fn should_drop(&self) -> bool {
43 self.should_drop
44 }
45
46 pub fn forget(&mut self) {
48 self.should_drop = false;
49 }
50}
51
52impl<B: BackendTy> std::fmt::Debug for Backend<B> {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match B::ID {
55 0 => f.debug_struct("cpu").finish(),
56 1 => f.debug_struct("cuda").finish(),
57 _ => f.debug_struct("unknown").finish(),
58 }
59 }
60}
61
62impl Clone for Cpu {
63 fn clone(&self) -> Self {
64 if let Ok(mut storage) = crate::CPU_STORAGE.lock() {
65 clone_storage(self.ptr as *mut u8, self.device_id, &mut storage);
66 } else {
67 panic!("failed to lock CPU_STORAGE");
68 }
69 Cpu {
70 ptr: self.ptr,
71 device_id: self.device_id,
72 }
73 }
74}
75
76impl Backend<Cpu> {
77 pub fn new(address: u64, device_id: usize, should_drop: bool) -> Self {
79 Backend {
80 inner: Cpu {
81 ptr: address,
82 device_id,
83 },
84 should_drop,
85 }
86 }
87}
88
#[cfg(feature = "cuda")]
impl Clone for Cuda {
    /// Calls `clone_storage` for this buffer's pointer on its device ordinal while holding
    /// the `CUDA_STORAGE` lock, then returns a handle with the same pointer, a clone of the
    /// device `Arc`, and the same `cap`.
    ///
    /// NOTE(review): `clone_storage` presumably records an extra reference to the
    /// buffer — confirm in its definition.
    ///
    /// # Panics
    /// Panics with "failed to lock CUDA_STORAGE" if locking the storage fails.
    fn clone(&self) -> Self {
        match crate::CUDA_STORAGE.lock() {
            Ok(mut storage) => {
                clone_storage(self.ptr as *mut u8, self.device.ordinal(), &mut storage);
            }
            // Name the mutex that actually failed so the panic points at the right storage.
            Err(_) => panic!("failed to lock CUDA_STORAGE"),
        }
        Cuda {
            ptr: self.ptr,
            device: Arc::clone(&self.device),
            cap: self.cap,
        }
    }
}
104
#[cfg(feature = "cuda")]
impl Backend<Cuda> {
    /// Builds a CUDA backend handle for the buffer at `address` on `device`.
    ///
    /// Queries the device's compute capability and stores it in `cap` as
    /// `major * 10 + minor` (e.g. 8.6 becomes 86). `should_drop` initialises the drop flag.
    ///
    /// # Panics
    /// Panics if either compute-capability attribute query returns an error.
    pub fn new(address: u64, device: Arc<cudarc::driver::CudaDevice>, should_drop: bool) -> Self {
        use cudarc::driver::sys::CUdevice_attribute as Attr;

        let major = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .expect("failed to get compute capability major when creating cuda backend");
        let minor = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .expect("failed to get compute capability minor when creating cuda backend");
        let cap = (major * 10 + minor) as usize;

        Self {
            inner: Cuda {
                ptr: address,
                device,
                cap,
            },
            should_drop,
        }
    }
}
125
/// Access to a backend buffer's raw address.
pub trait Buffer {
    /// Returns the buffer address as a `u64`.
    fn get_ptr(&self) -> u64;
}
133
134impl Buffer for Cpu {
135 fn get_ptr(&self) -> u64 {
136 self.ptr
137 }
138}
139
#[cfg(feature = "cuda")]
impl Buffer for Cuda {
    /// Returns the stored CUDA device buffer address unchanged.
    fn get_ptr(&self) -> u64 {
        let Cuda { ptr, .. } = self;
        *ptr
    }
}
146
/// Compile-time tag identifying a backend kind.
pub trait BackendTy {
    /// Numeric backend id; `Backend`'s `Debug` impl maps 0 to `cpu`, 1 to `cuda`,
    /// and any other value to `unknown`.
    const ID: u8;
}
160
161impl BackendTy for Cpu {
162 const ID: u8 = 0;
163}
164
// The CUDA backend is tagged with id 1 (only compiled with the `cuda` feature).
#[cfg(feature = "cuda")]
impl BackendTy for Cuda {
    const ID: u8 = 1;
}