hpt_allocator/
backend.rs

1//! a module to define the backend of the tensor
2
3#![allow(unused)]
4
5use std::sync::Arc;
6
7use crate::clone_storage;
8
/// Backend handle for memory that lives on the CPU.
///
/// The handle does not own a typed pointer; it only records the address of
/// the data buffer as an integer together with the id of the device the
/// buffer belongs to.
pub struct Cpu {
    /// Address of the data memory, stored as a plain integer.
    pub(crate) ptr: u64,
    /// Id of the device associated with this buffer.
    pub(crate) device_id: usize,
}
16
/// Backend handle for memory that lives on a CUDA device.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub struct Cuda {
    /// Address of the data memory, stored as a plain integer.
    pub(crate) ptr: u64,
    /// Shared handle to the CUDA device this buffer belongs to.
    pub device: Arc<cudarc::driver::CudaDevice>,
    /// Compute capability of `device`; `Backend::<Cuda>::new` encodes it as
    /// `major * 10 + minor`.
    pub cap: usize,
}
26
/// Generic wrapper around a concrete device backend `B`.
///
/// The wrapped backend records the address of the tensor's data memory and
/// is what gets consulted when a tensor is `free`d or `clone`d.
///
/// `Clone` is derived, so cloning the wrapper clones the inner backend and
/// therefore runs whatever `B::clone` does.
#[derive(Clone)]
pub struct Backend<B> {
    /// The concrete backend value held by this wrapper.
    pub _backend: B,
}
37
38impl Clone for Cpu {
39    fn clone(&self) -> Self {
40        if let Ok(mut storage) = crate::CPU_STORAGE.lock() {
41            clone_storage(self.ptr as *mut u8, self.device_id, &mut storage);
42        } else {
43            panic!("failed to lock CPU_STORAGE");
44        }
45        Cpu {
46            ptr: self.ptr,
47            device_id: self.device_id,
48        }
49    }
50}
51
52impl Backend<Cpu> {
53    /// create a new Cpu backend
54    pub fn new(address: u64, device_id: usize) -> Self {
55        Backend {
56            _backend: Cpu {
57                ptr: address,
58                device_id,
59            },
60        }
61    }
62}
63
#[cfg(feature = "cuda")]
impl Clone for Cuda {
    /// Produces a second handle to the same address on the same device.
    ///
    /// Before the new handle is built, the clone is reported to the global
    /// `CUDA_STORAGE` through `clone_storage`, keyed by the device ordinal
    /// (the bookkeeping it performs is defined elsewhere in the crate).
    ///
    /// # Panics
    ///
    /// Panics if the `CUDA_STORAGE` lock cannot be acquired.
    fn clone(&self) -> Self {
        // The guard only lives for the duration of this `match`, so the lock
        // is released before the new handle is constructed.
        match crate::CUDA_STORAGE.lock() {
            Ok(mut storage) => {
                clone_storage(self.ptr as *mut u8, self.device.ordinal(), &mut storage);
            }
            // The message names the lock that is actually taken here.
            Err(_) => panic!("failed to lock CUDA_STORAGE"),
        }
        Cuda {
            ptr: self.ptr,
            // Only bumps the reference count of the shared device handle.
            device: Arc::clone(&self.device),
            cap: self.cap,
        }
    }
}
79
#[cfg(feature = "cuda")]
impl Backend<Cuda> {
    /// Builds a CUDA backend wrapper from a raw buffer address and a device.
    ///
    /// The device is queried for its compute capability, which is stored in
    /// `cap` as `major * 10 + minor`.
    ///
    /// # Panics
    ///
    /// Panics if either the major or the minor compute-capability attribute
    /// cannot be read from `device`.
    pub fn new(address: u64, device: Arc<cudarc::driver::CudaDevice>) -> Self {
        use cudarc::driver::sys::CUdevice_attribute as Attr;

        let major = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR)
            .expect("failed to get compute capability major when creating cuda backend");
        let minor = device
            .attribute(Attr::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR)
            .expect("failed to get compute capability minor when creating cuda backend");
        let cap = (major * 10 + minor) as usize;

        Self {
            _backend: Cuda {
                ptr: address,
                device,
                cap,
            },
        }
    }
}
99
/// Common accessor for backend handles.
///
/// Implementors expose the address of their data memory as an integer.
pub trait Buffer {
    /// Returns the address of the data memory.
    fn get_ptr(&self) -> u64;
}
107
108impl Buffer for Cpu {
109    fn get_ptr(&self) -> u64 {
110        self.ptr
111    }
112}
113
#[cfg(feature = "cuda")]
impl Buffer for Cuda {
    /// Returns the address stored in this CUDA backend handle.
    fn get_ptr(&self) -> u64 {
        let Self { ptr, .. } = self;
        *ptr
    }
}
120
/// Compile-time identifier for a backend type.
///
/// Each backend is tagged with a small numeric id:
///
/// | id | backend |
/// |----|---------|
/// | 0  | Cpu     |
/// | 1  | Cuda    |
/// | 2  | Wgpu    |
pub trait BackendTy {
    /// Numeric id of the backend.
    const ID: u8;
}
134
135impl BackendTy for Cpu {
136    const ID: u8 = 0;
137}
138
/// The CUDA backend carries id `1`.
#[cfg(feature = "cuda")]
impl BackendTy for Cuda {
    const ID: u8 = 1;
}