1#![allow(unused)]
4
5use std::sync::Arc;
6
7use crate::clone_storage;
8
/// Handle to a host-side buffer, identified by its raw address and a device id.
///
/// Cloning goes through the manual `Clone` impl, which calls `clone_storage` on
/// `crate::CPU_STORAGE` before duplicating the fields.
pub struct Cpu {
    /// Raw address of the buffer, stored as an integer (cast to `*mut u8` when
    /// passed to `clone_storage`).
    pub(crate) ptr: u64,
    /// Device identifier forwarded to `clone_storage` alongside `ptr`.
    pub(crate) device_id: usize,
}
16
/// Handle to a CUDA device buffer (only compiled with the `cuda` feature).
#[cfg(feature = "cuda")]
pub struct Cuda {
    /// Raw device address of the buffer, stored as an integer.
    pub(crate) ptr: u64,
    /// Shared handle to the device that owns the buffer; its `ordinal()` is used
    /// as the device id when cloning.
    pub device: Arc<cudarc::driver::CudaDevice>,
    /// Compute capability encoded as `major * 10 + minor` (e.g. 8.6 -> 86), as
    /// computed in `Backend::<Cuda>::new`.
    pub cap: usize,
}
26
/// Generic backend wrapper around a concrete buffer handle `B` (`Cpu` or `Cuda`
/// in this file).
///
/// `Clone` is derived, so cloning a `Backend<B>` delegates to `B::clone`.
#[derive(Clone)]
pub struct Backend<B> {
    /// The concrete backend handle.
    pub inner: B,
}
37
38impl<B: BackendTy> std::fmt::Debug for Backend<B> {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 match B::ID {
41 0 => f.debug_struct("cpu").finish(),
42 1 => f.debug_struct("cuda").finish(),
43 _ => f.debug_struct("unknown").finish(),
44 }
45 }
46}
47
48impl Clone for Cpu {
49 fn clone(&self) -> Self {
50 if let Ok(mut storage) = crate::CPU_STORAGE.lock() {
51 clone_storage(self.ptr as *mut u8, self.device_id, &mut storage);
52 } else {
53 panic!("failed to lock CPU_STORAGE");
54 }
55 Cpu {
56 ptr: self.ptr,
57 device_id: self.device_id,
58 }
59 }
60}
61
62impl Backend<Cpu> {
63 pub fn new(address: u64, device_id: usize) -> Self {
65 Backend {
66 inner: Cpu {
67 ptr: address,
68 device_id,
69 },
70 }
71 }
72}
73
/// Duplicates a CUDA handle.
///
/// Before the fields are copied, `clone_storage` is called on the locked
/// `crate::CUDA_STORAGE` with this handle's address and the device ordinal
/// (presumably to record the additional reference — confirm against
/// `clone_storage`). The device `Arc` is shared, not deep-copied.
///
/// # Panics
/// Panics with "failed to lock CUDA_STORAGE" if the mutex lock returns an error.
#[cfg(feature = "cuda")]
impl Clone for Cuda {
    fn clone(&self) -> Self {
        if let Ok(mut storage) = crate::CUDA_STORAGE.lock() {
            clone_storage(self.ptr as *mut u8, self.device.ordinal(), &mut storage);
        } else {
            // Name the lock that actually failed (this previously said CPU_STORAGE).
            panic!("failed to lock CUDA_STORAGE");
        }
        Cuda {
            ptr: self.ptr,
            device: Arc::clone(&self.device),
            cap: self.cap,
        }
    }
}
89
#[cfg(feature = "cuda")]
impl Backend<Cuda> {
    /// Builds a CUDA backend for the buffer at `address` on `device`.
    ///
    /// Queries the device's compute capability (major first, then minor) and
    /// stores it as `major * 10 + minor` in `Cuda::cap`.
    ///
    /// # Panics
    /// Panics if either compute-capability attribute query fails.
    pub fn new(address: u64, device: Arc<cudarc::driver::CudaDevice>) -> Self {
        use cudarc::driver::sys::CUdevice_attribute as Attr;

        let major = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .expect("failed to get compute capability major when creating cuda backend");
        let minor = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .expect("failed to get compute capability minor when creating cuda backend");
        let cap = (major * 10 + minor) as usize;

        Self {
            inner: Cuda {
                ptr: address,
                cap,
                device,
            },
        }
    }
}
109
/// Access to the raw address backing a buffer handle.
pub trait Buffer {
    /// Returns the handle's raw address as an integer (both impls in this file
    /// return their `ptr` field).
    fn get_ptr(&self) -> u64;
}
117
118impl Buffer for Cpu {
119 fn get_ptr(&self) -> u64 {
120 self.ptr
121 }
122}
123
#[cfg(feature = "cuda")]
impl Buffer for Cuda {
    /// Returns the stored device address.
    fn get_ptr(&self) -> u64 {
        let Self { ptr, .. } = self;
        *ptr
    }
}
130
/// Compile-time tag identifying a backend kind.
pub trait BackendTy {
    /// Numeric backend id: `Cpu` uses 0 and `Cuda` uses 1. The `Debug` impl for
    /// `Backend<B>` maps these to "cpu" / "cuda" and anything else to "unknown".
    const ID: u8;
}
144
// CPU backend tag.
impl BackendTy for Cpu {
    const ID: u8 = 0;
}
148
// CUDA backend tag (only compiled with the `cuda` feature).
#[cfg(feature = "cuda")]
impl BackendTy for Cuda {
    const ID: u8 = 1;
}