hpt_allocator/
backend.rs

1//! a module to define the backend of the tensor
2
3#![allow(unused)]
4
5use std::sync::Arc;
6
7use crate::clone_storage;
8
/// Backend for tensors whose data lives in CPU memory.
///
/// It keeps the address of the data memory as a plain integer, together
/// with the device id that is handed to the storage bookkeeping when the
/// backend is cloned.
pub struct Cpu {
    /// address of the data memory, stored as an integer
    pub(crate) ptr: u64,
    /// device id associated with the memory
    pub(crate) device_id: usize,
}
16
/// Backend for tensors whose data lives in CUDA device memory.
///
/// Only available when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub struct Cuda {
    /// address of the data memory, stored as an integer
    pub(crate) ptr: u64,
    /// handle of the CUDA device
    pub device: Arc<cudarc::driver::CudaDevice>,
    /// compute capability; `Backend::<Cuda>::new` fills it with `major * 10 + minor`
    pub cap: usize,
}
26
/// Generic wrapper around a concrete tensor backend.
///
/// The wrapped backend value stores the pointer of the data memory; the
/// wrapper adds a flag telling whether that data should be dropped.
///
/// It is used when a tensor is freed (`free`) or cloned (`clone`).
#[derive(Clone)]
pub struct Backend<B> {
    /// the concrete backend value
    pub inner: B,
    /// whether the data should be dropped (original note: "the data comes
    /// from the user")
    pub should_drop: bool,
}
39
40impl<B: BackendTy> Backend<B> {
41    /// get the should drop flag
42    pub fn should_drop(&self) -> bool {
43        self.should_drop
44    }
45
46    /// forget the backend
47    pub fn forget(&mut self) {
48        self.should_drop = false;
49    }
50}
51
52impl<B: BackendTy> std::fmt::Debug for Backend<B> {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match B::ID {
55            0 => f.debug_struct("cpu").finish(),
56            1 => f.debug_struct("cuda").finish(),
57            _ => f.debug_struct("unknown").finish(),
58        }
59    }
60}
61
62impl Clone for Cpu {
63    fn clone(&self) -> Self {
64        if let Ok(mut storage) = crate::CPU_STORAGE.lock() {
65            clone_storage(self.ptr as *mut u8, self.device_id, &mut storage);
66        } else {
67            panic!("failed to lock CPU_STORAGE");
68        }
69        Cpu {
70            ptr: self.ptr,
71            device_id: self.device_id,
72        }
73    }
74}
75
76impl Backend<Cpu> {
77    /// create a new Cpu backend
78    pub fn new(address: u64, device_id: usize, should_drop: bool) -> Self {
79        Backend {
80            inner: Cpu {
81                ptr: address,
82                device_id,
83            },
84            should_drop,
85        }
86    }
87}
88
#[cfg(feature = "cuda")]
impl Clone for Cuda {
    /// Produces a `Cuda` with the same pointer, device handle and compute
    /// capability.
    ///
    /// Before returning, `clone_storage` is called with the pointer, the
    /// device ordinal and the locked `CUDA_STORAGE` (presumably to record one
    /// more user of the allocation — confirm against `clone_storage`).
    ///
    /// # Panics
    ///
    /// Panics when the `CUDA_STORAGE` lock cannot be acquired.
    fn clone(&self) -> Self {
        {
            // Keep the guard in its own scope so the lock is released
            // before the new value is built.
            let mut storage = match crate::CUDA_STORAGE.lock() {
                Ok(guard) => guard,
                // The lock taken here is CUDA_STORAGE, so the message must
                // name it (it previously said CPU_STORAGE).
                Err(_) => panic!("failed to lock CUDA_STORAGE"),
            };
            clone_storage(self.ptr as *mut u8, self.device.ordinal(), &mut storage);
        }
        Cuda {
            ptr: self.ptr,
            device: self.device.clone(),
            cap: self.cap,
        }
    }
}
104
#[cfg(feature = "cuda")]
impl Backend<Cuda> {
    /// Builds a CUDA backend from a raw address and a device handle.
    ///
    /// The compute capability is read from the device and stored in
    /// `inner.cap` as `major * 10 + minor`.
    ///
    /// * `address` - address of the data memory, stored as `inner.ptr`
    /// * `device` - CUDA device handle, stored as `inner.device`
    /// * `should_drop` - initial value of the `should_drop` flag
    ///
    /// # Panics
    ///
    /// Panics if either compute-capability attribute cannot be read from
    /// the device.
    pub fn new(address: u64, device: Arc<cudarc::driver::CudaDevice>, should_drop: bool) -> Self {
        let major = device
            .attribute(
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
            )
            .expect("failed to get compute capability major when creating cuda backend");
        let minor = device
            .attribute(
                cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
            )
            .expect("failed to get compute capability minor when creating cuda backend");
        let cap = (major * 10 + minor) as usize;
        let inner = Cuda {
            ptr: address,
            device,
            cap,
        };
        Self { inner, should_drop }
    }
}
125
/// Access to the address of a backend's data memory.
///
/// Implementors expose the pointer of the data memory as an integer.
pub trait Buffer {
    /// Returns the pointer of the data memory as a `u64`.
    fn get_ptr(&self) -> u64;
}
133
134impl Buffer for Cpu {
135    fn get_ptr(&self) -> u64 {
136        self.ptr
137    }
138}
139
#[cfg(feature = "cuda")]
impl Buffer for Cuda {
    /// Returns the stored data-memory address unchanged.
    fn get_ptr(&self) -> u64 {
        let Cuda { ptr, .. } = self;
        *ptr
    }
}
146
/// Compile-time identifier of a backend type.
///
/// Every backend exposes a numeric id through [`BackendTy::ID`]:
///
/// | id | backend |
/// |----|---------|
/// | 0  | Cpu     |
/// | 1  | Cuda    |
/// | 2  | Wgpu    |
pub trait BackendTy {
    /// backend id
    const ID: u8;
}
160
161impl BackendTy for Cpu {
162    const ID: u8 = 0;
163}
164
#[cfg(feature = "cuda")]
impl BackendTy for Cuda {
    /// id of the Cuda backend
    const ID: u8 = 1;
}