1use crate::backend::dtype::Dtype;
26use half::f16;
27
28pub enum CpuBuf {
34 F32(Vec<f32>),
35 F16(Vec<f16>),
36 U32(Vec<u32>),
37 I32(Vec<i32>),
38 I8(Vec<i8>),
39}
40
41impl CpuBuf {
42 pub fn alloc(dtype: Dtype, n: usize) -> Self {
43 match dtype {
44 Dtype::F32 => CpuBuf::F32(vec![0.0; n]),
45 Dtype::F16 => CpuBuf::F16(vec![f16::ZERO; n]),
46 Dtype::U32 => CpuBuf::U32(vec![0u32; n]),
47 Dtype::I32 => CpuBuf::I32(vec![0i32; n]),
48 Dtype::I8 => CpuBuf::I8(vec![0i8; n]),
49 }
50 }
51
52 pub fn dtype(&self) -> Dtype {
53 match self {
54 CpuBuf::F32(_) => Dtype::F32,
55 CpuBuf::F16(_) => Dtype::F16,
56 CpuBuf::U32(_) => Dtype::U32,
57 CpuBuf::I32(_) => Dtype::I32,
58 CpuBuf::I8(_) => Dtype::I8,
59 }
60 }
61
62 pub fn len(&self) -> usize {
63 match self {
64 CpuBuf::F32(v) => v.len(),
65 CpuBuf::F16(v) => v.len(),
66 CpuBuf::U32(v) => v.len(),
67 CpuBuf::I32(v) => v.len(),
68 CpuBuf::I8(v) => v.len(),
69 }
70 }
71
72 pub fn is_empty(&self) -> bool {
73 self.len() == 0
74 }
75
76 pub fn as_f32(&self) -> &[f32] {
80 match self {
81 CpuBuf::F32(v) => v,
82 _ => panic!("CpuBuf::as_f32 on dtype {}", self.dtype().name()),
83 }
84 }
85 pub fn as_f32_mut(&mut self) -> &mut [f32] {
86 match self {
87 CpuBuf::F32(v) => v,
88 _ => panic!("CpuBuf::as_f32_mut on dtype {}", self.dtype().name()),
89 }
90 }
91 pub fn as_f16(&self) -> &[f16] {
92 match self {
93 CpuBuf::F16(v) => v,
94 _ => panic!("CpuBuf::as_f16 on dtype {}", self.dtype().name()),
95 }
96 }
97 pub fn as_u32(&self) -> &[u32] {
98 match self {
99 CpuBuf::U32(v) => v,
100 _ => panic!("CpuBuf::as_u32 on dtype {}", self.dtype().name()),
101 }
102 }
103 pub fn as_u32_mut(&mut self) -> &mut [u32] {
104 match self {
105 CpuBuf::U32(v) => v,
106 _ => panic!("CpuBuf::as_u32_mut on dtype {}", self.dtype().name()),
107 }
108 }
109 pub fn as_i32(&self) -> &[i32] {
110 match self {
111 CpuBuf::I32(v) => v,
112 _ => panic!("CpuBuf::as_i32 on dtype {}", self.dtype().name()),
113 }
114 }
115 pub fn as_i32_mut(&mut self) -> &mut [i32] {
116 match self {
117 CpuBuf::I32(v) => v,
118 _ => panic!("CpuBuf::as_i32_mut on dtype {}", self.dtype().name()),
119 }
120 }
121 pub fn as_i8(&self) -> &[i8] {
122 match self {
123 CpuBuf::I8(v) => v,
124 _ => panic!("CpuBuf::as_i8 on dtype {}", self.dtype().name()),
125 }
126 }
127
128 pub fn from_f32(data: Vec<f32>) -> Self {
129 CpuBuf::F32(data)
130 }
131 pub fn from_u32(data: Vec<u32>) -> Self {
132 CpuBuf::U32(data)
133 }
134 pub fn from_i32(data: Vec<i32>) -> Self {
135 CpuBuf::I32(data)
136 }
137}
138
/// Device buffer tagged by element type.
///
/// Each variant owns a `cudarc::driver::CudaSlice` of one element type. Only
/// compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub enum CudaBuf {
    /// 32-bit float elements.
    F32(cudarc::driver::CudaSlice<f32>),
    /// 16-bit float elements (`half::f16`).
    F16(cudarc::driver::CudaSlice<f16>),
    /// Unsigned 32-bit integer elements.
    U32(cudarc::driver::CudaSlice<u32>),
    /// Signed 32-bit integer elements.
    I32(cudarc::driver::CudaSlice<i32>),
    /// Signed 8-bit integer elements.
    I8(cudarc::driver::CudaSlice<i8>),
}
154
#[cfg(feature = "cuda")]
impl CudaBuf {
    /// Returns the [`Dtype`] tag corresponding to the active variant.
    pub fn dtype(&self) -> Dtype {
        match self {
            CudaBuf::F32(_) => Dtype::F32,
            CudaBuf::F16(_) => Dtype::F16,
            CudaBuf::U32(_) => Dtype::U32,
            CudaBuf::I32(_) => Dtype::I32,
            CudaBuf::I8(_) => Dtype::I8,
        }
    }

    /// Returns the element count reported by the wrapped slice's `len()`.
    pub fn len(&self) -> usize {
        match self {
            CudaBuf::F32(s) => s.len(),
            CudaBuf::F16(s) => s.len(),
            CudaBuf::U32(s) => s.len(),
            CudaBuf::I32(s) => s.len(),
            CudaBuf::I8(s) => s.len(),
        }
    }

    /// Returns `true` when the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    // Typed accessors. Each one is `#[track_caller]` so that a dtype mismatch
    // panic points at the offending call site rather than at this file.

    /// Borrows the wrapped `CudaSlice<f32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F32` variant.
    #[track_caller]
    pub fn as_f32(&self) -> &cudarc::driver::CudaSlice<f32> {
        match self {
            CudaBuf::F32(s) => s,
            _ => panic!("CudaBuf::as_f32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<f32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F32` variant.
    #[track_caller]
    pub fn as_f32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<f32> {
        match self {
            CudaBuf::F32(s) => s,
            _ => panic!("CudaBuf::as_f32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<f16>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F16` variant.
    #[track_caller]
    pub fn as_f16(&self) -> &cudarc::driver::CudaSlice<f16> {
        match self {
            CudaBuf::F16(s) => s,
            _ => panic!("CudaBuf::as_f16 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<f16>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `F16` variant.
    #[track_caller]
    pub fn as_f16_mut(&mut self) -> &mut cudarc::driver::CudaSlice<f16> {
        match self {
            CudaBuf::F16(s) => s,
            _ => panic!("CudaBuf::as_f16_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<u32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `U32` variant.
    #[track_caller]
    pub fn as_u32(&self) -> &cudarc::driver::CudaSlice<u32> {
        match self {
            CudaBuf::U32(s) => s,
            _ => panic!("CudaBuf::as_u32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<u32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `U32` variant.
    #[track_caller]
    pub fn as_u32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<u32> {
        match self {
            CudaBuf::U32(s) => s,
            _ => panic!("CudaBuf::as_u32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<i32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I32` variant.
    #[track_caller]
    pub fn as_i32(&self) -> &cudarc::driver::CudaSlice<i32> {
        match self {
            CudaBuf::I32(s) => s,
            _ => panic!("CudaBuf::as_i32 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<i32>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I32` variant.
    #[track_caller]
    pub fn as_i32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<i32> {
        match self {
            CudaBuf::I32(s) => s,
            _ => panic!("CudaBuf::as_i32_mut on dtype {}", self.dtype().name()),
        }
    }

    /// Borrows the wrapped `CudaSlice<i8>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I8` variant.
    #[track_caller]
    pub fn as_i8(&self) -> &cudarc::driver::CudaSlice<i8> {
        match self {
            CudaBuf::I8(s) => s,
            _ => panic!("CudaBuf::as_i8 on dtype {}", self.dtype().name()),
        }
    }

    /// Mutably borrows the wrapped `CudaSlice<i8>`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not the `I8` variant.
    #[track_caller]
    pub fn as_i8_mut(&mut self) -> &mut cudarc::driver::CudaSlice<i8> {
        match self {
            CudaBuf::I8(s) => s,
            _ => panic!("CudaBuf::as_i8_mut on dtype {}", self.dtype().name()),
        }
    }

    // Constructors that take ownership of an existing device slice.

    /// Wraps an owned `CudaSlice<f32>` as an `F32` buffer.
    pub fn from_f32(s: cudarc::driver::CudaSlice<f32>) -> Self {
        CudaBuf::F32(s)
    }

    /// Wraps an owned `CudaSlice<f16>` as an `F16` buffer.
    pub fn from_f16(s: cudarc::driver::CudaSlice<f16>) -> Self {
        CudaBuf::F16(s)
    }

    /// Wraps an owned `CudaSlice<u32>` as a `U32` buffer.
    pub fn from_u32(s: cudarc::driver::CudaSlice<u32>) -> Self {
        CudaBuf::U32(s)
    }

    /// Wraps an owned `CudaSlice<i32>` as an `I32` buffer.
    pub fn from_i32(s: cudarc::driver::CudaSlice<i32>) -> Self {
        CudaBuf::I32(s)
    }

    /// Wraps an owned `CudaSlice<i8>` as an `I8` buffer.
    pub fn from_i8(s: cudarc::driver::CudaSlice<i8>) -> Self {
        CudaBuf::I8(s)
    }

    /// Reinterprets the wrapped slice as a view of `len` elements of `T`,
    /// regardless of the buffer's own dtype, by forwarding to
    /// `CudaSlice::transmute` on the active variant.
    ///
    /// Returns whatever `CudaSlice::transmute` returns; presumably `None`
    /// when `len` elements of `T` do not fit in the allocation — confirm
    /// against the cudarc documentation.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of `CudaSlice::transmute`.
    /// NOTE(review): this most likely means the device bytes must be valid
    /// (and suitably aligned) when read as `T` — verify against cudarc.
    pub unsafe fn transmute<T>(&self, len: usize) -> Option<cudarc::driver::CudaView<'_, T>> {
        // SAFETY (all arms): the obligation is passed through unchanged to the
        // caller of this `unsafe fn`; nothing extra is assumed here.
        match self {
            CudaBuf::F32(s) => unsafe { s.transmute(len) },
            CudaBuf::F16(s) => unsafe { s.transmute(len) },
            CudaBuf::U32(s) => unsafe { s.transmute(len) },
            CudaBuf::I32(s) => unsafe { s.transmute(len) },
            CudaBuf::I8(s) => unsafe { s.transmute(len) },
        }
    }
}
276
/// Lets a shared `&CudaBuf` be pushed as a kernel launch argument.
///
/// The call is forwarded to the `arg` implementation for the wrapped
/// `&CudaSlice<_>` of whichever variant is active.
// NOTE(review): soundness of this `unsafe impl` rests on the forwarded-to
// cudarc impls for `&CudaSlice<T>` — confirm against the trait's contract.
#[cfg(feature = "cuda")]
unsafe impl<'a, 'b: 'a> cudarc::driver::PushKernelArg<&'b CudaBuf>
    for cudarc::driver::LaunchArgs<'a>
{
    fn arg(&mut self, arg: &'b CudaBuf) -> &mut Self {
        // One arm per variant: the element type differs, so each arm resolves
        // to a different `PushKernelArg` impl.
        match arg {
            CudaBuf::F32(slice) => self.arg(slice),
            CudaBuf::F16(slice) => self.arg(slice),
            CudaBuf::U32(slice) => self.arg(slice),
            CudaBuf::I32(slice) => self.arg(slice),
            CudaBuf::I8(slice) => self.arg(slice),
        }
    }
}
295
/// Lets an exclusive `&mut CudaBuf` be pushed as a kernel launch argument.
///
/// The call is forwarded to the `arg` implementation for the wrapped
/// `&mut CudaSlice<_>` of whichever variant is active.
// NOTE(review): soundness of this `unsafe impl` rests on the forwarded-to
// cudarc impls for `&mut CudaSlice<T>` — confirm against the trait's contract.
#[cfg(feature = "cuda")]
unsafe impl<'a, 'b: 'a> cudarc::driver::PushKernelArg<&'b mut CudaBuf>
    for cudarc::driver::LaunchArgs<'a>
{
    fn arg(&mut self, arg: &'b mut CudaBuf) -> &mut Self {
        // One arm per variant: the element type differs, so each arm resolves
        // to a different `PushKernelArg` impl.
        match arg {
            CudaBuf::F32(slice) => self.arg(slice),
            CudaBuf::F16(slice) => self.arg(slice),
            CudaBuf::U32(slice) => self.arg(slice),
            CudaBuf::I32(slice) => self.arg(slice),
            CudaBuf::I8(slice) => self.arg(slice),
        }
    }
}