hpt_allocator/
backend.rs

1//! a module to define the backend of the tensor
2
3#![allow(unused)]
4
5use std::sync::Arc;
6
7use crate::clone_storage;
8
/// CPU backend descriptor.
///
/// Holds the address of the tensor's data memory together with the id of the
/// device that memory is associated with.
pub struct Cpu {
    /// Address of the data memory, kept as a plain integer.
    pub(crate) ptr: u64,
    /// Device id; handed to `clone_storage` whenever this value is cloned.
    pub(crate) device_id: usize,
}
16
#[cfg(feature = "cuda")]
/// CUDA backend descriptor.
///
/// Only compiled when the `cuda` feature is enabled.
pub struct Cuda {
    /// Address of the data memory, kept as a plain integer.
    pub(crate) ptr: u64,
    /// Shared handle to the CUDA device this backend belongs to.
    pub device: Arc<cudarc::driver::CudaDevice>,
    /// Compute capability, encoded as `major * 10 + minor`.
    pub cap: usize,
}
26
/// Generic wrapper around a concrete tensor backend.
///
/// The wrapped backend value stores the pointer of the data memory; this
/// wrapper is what gets used when a tensor is freed or cloned.
///
/// `Clone` is derived, so cloning a `Backend<B>` clones `inner` through
/// `B`'s own `Clone` implementation.
#[derive(Clone)]
pub struct Backend<B> {
    /// The wrapped concrete backend.
    pub inner: B,
}
37
38impl<B: BackendTy> std::fmt::Debug for Backend<B> {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match B::ID {
41            0 => f.debug_struct("cpu").finish(),
42            1 => f.debug_struct("cuda").finish(),
43            _ => f.debug_struct("unknown").finish(),
44        }
45    }
46}
47
48impl Clone for Cpu {
49    fn clone(&self) -> Self {
50        if let Ok(mut storage) = crate::CPU_STORAGE.lock() {
51            clone_storage(self.ptr as *mut u8, self.device_id, &mut storage);
52        } else {
53            panic!("failed to lock CPU_STORAGE");
54        }
55        Cpu {
56            ptr: self.ptr,
57            device_id: self.device_id,
58        }
59    }
60}
61
62impl Backend<Cpu> {
63    /// create a new Cpu backend
64    pub fn new(address: u64, device_id: usize) -> Self {
65        Backend {
66            inner: Cpu {
67                ptr: address,
68                device_id,
69            },
70        }
71    }
72}
73
#[cfg(feature = "cuda")]
impl Clone for Cuda {
    /// Produces a second handle to the same memory.
    ///
    /// While holding the `CUDA_STORAGE` lock, `clone_storage` is called with
    /// this backend's pointer and the device ordinal (presumably to record
    /// one more owner of the allocation; confirm against `clone_storage`).
    /// The returned value carries the same `ptr` and `cap` and shares the
    /// same device handle.
    ///
    /// # Panics
    ///
    /// Panics when the `CUDA_STORAGE` lock cannot be acquired.
    fn clone(&self) -> Self {
        match crate::CUDA_STORAGE.lock() {
            Ok(mut storage) => {
                clone_storage(self.ptr as *mut u8, self.device.ordinal(), &mut storage);
            }
            // The lock taken here is CUDA_STORAGE, so the message names it
            // (it previously named CPU_STORAGE by mistake).
            Err(_) => panic!("failed to lock CUDA_STORAGE"),
        }
        Cuda {
            ptr: self.ptr,
            device: self.device.clone(),
            cap: self.cap,
        }
    }
}
89
#[cfg(feature = "cuda")]
impl Backend<Cuda> {
    /// Builds a CUDA backend for the memory at `address` on `device`.
    ///
    /// The device is asked for its compute-capability major and minor
    /// attributes, which are combined into `cap` as `major * 10 + minor`.
    ///
    /// # Panics
    ///
    /// Panics if either attribute query fails.
    pub fn new(address: u64, device: Arc<cudarc::driver::CudaDevice>) -> Self {
        use cudarc::driver::sys::CUdevice_attribute as Attribute;

        let major = device
            .attribute(Attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .expect("failed to get compute capability major when creating cuda backend");
        let minor = device
            .attribute(Attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .expect("failed to get compute capability minor when creating cuda backend");
        let cap = (major * 10 + minor) as usize;

        Self {
            inner: Cuda {
                ptr: address,
                device,
                cap,
            },
        }
    }
}
109
/// Access to a backend's data pointer.
///
/// Implemented by backend types so callers can read the address of the data
/// memory without knowing the concrete backend.
pub trait Buffer {
    /// Returns the address of the data memory as an integer.
    fn get_ptr(&self) -> u64;
}
117
118impl Buffer for Cpu {
119    fn get_ptr(&self) -> u64 {
120        self.ptr
121    }
122}
123
#[cfg(feature = "cuda")]
impl Buffer for Cuda {
    /// Returns the stored data-memory address unchanged.
    fn get_ptr(&self) -> u64 {
        let Self { ptr, .. } = self;
        *ptr
    }
}
130
/// Compile-time identifier for a backend type.
///
/// Each backend type exposes a numeric id through [`BackendTy::ID`]:
///
/// * `0`: Cpu
/// * `1`: Cuda
/// * `2`: Wgpu (listed here; no implementor is visible in this file)
pub trait BackendTy {
    /// Backend id.
    const ID: u8;
}
144
145impl BackendTy for Cpu {
146    const ID: u8 = 0;
147}
148
#[cfg(feature = "cuda")]
impl BackendTy for Cuda {
    /// The CUDA backend is identified by id `1`.
    const ID: u8 = 1;
}