use cudarc::driver::{CudaSlice, CudaStream};
use std::sync::Arc;
use crate::backend::{BackendError, BackendResult};
use crate::tensor::DType;
/// A tensor whose storage lives in a device buffer.
///
/// Pairs the device allocation with the logical shape and the element type
/// the buffer is meant to be interpreted as.
#[allow(dead_code)]
pub struct GpuTensor {
    /// Device allocation holding the tensor contents.
    data: GpuBuffer,
    /// Extent of each dimension; the element count is the product of these.
    shape: Vec<usize>,
    /// Logical element type recorded for `data`.
    dtype: DType,
}
/// Typed device allocation backing a [`GpuTensor`].
#[allow(dead_code)]
pub enum GpuBuffer {
    /// Buffer of `f32` elements.
    F32(CudaSlice<f32>),
    /// Buffer of 16-bit elements held as `u16`
    /// (presumably raw half-precision bit patterns — confirm with the kernels).
    F16(CudaSlice<u16>),
    /// Untyped byte buffer; `GpuTensor::alloc` uses it for the quantized
    /// dtypes and `GpuTensor::from_bytes` for any caller-supplied dtype.
    U8(CudaSlice<u8>),
}
#[allow(dead_code)]
impl GpuTensor {
pub fn alloc(stream: &Arc<CudaStream>, shape: Vec<usize>, dtype: DType) -> BackendResult<Self> {
let numel: usize = shape.iter().product();
let data = match dtype {
DType::F32 => {
let slice = stream
.alloc_zeros::<f32>(numel)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
GpuBuffer::F32(slice)
}
DType::F16 => {
let slice = stream
.alloc_zeros::<u16>(numel)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
GpuBuffer::F16(slice)
}
DType::Q4_0
| DType::Q4_1
| DType::Q5_0
| DType::Q5_1
| DType::Q8_0
| DType::Q8_1
| DType::Q2K
| DType::Q3K
| DType::Q4K
| DType::Q5K
| DType::Q6K
| DType::Q8K => {
let bytes = quantized_bytes(numel, dtype);
let slice = stream
.alloc_zeros::<u8>(bytes)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
GpuBuffer::U8(slice)
}
_ => return Err(BackendError::UnsupportedDType(dtype)),
};
Ok(Self { data, shape, dtype })
}
pub fn from_f32(
stream: &Arc<CudaStream>,
data: &[f32],
shape: Vec<usize>,
) -> BackendResult<Self> {
let slice = stream
.clone_htod(data)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
Ok(Self {
data: GpuBuffer::F32(slice),
shape,
dtype: DType::F32,
})
}
pub fn from_bytes(
stream: &Arc<CudaStream>,
data: &[u8],
shape: Vec<usize>,
dtype: DType,
) -> BackendResult<Self> {
let slice = stream
.clone_htod(data)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
Ok(Self {
data: GpuBuffer::U8(slice),
shape,
dtype,
})
}
pub fn to_f32(&self, stream: &Arc<CudaStream>) -> BackendResult<Vec<f32>> {
match &self.data {
GpuBuffer::F32(slice) => stream
.clone_dtoh(slice)
.map_err(|e| BackendError::OperationFailed(format!("{}", e))),
_ => Err(BackendError::DTypeMismatch {
expected: DType::F32,
got: self.dtype,
}),
}
}
pub fn as_f32_slice(&self) -> BackendResult<&CudaSlice<f32>> {
match &self.data {
GpuBuffer::F32(slice) => Ok(slice),
_ => Err(BackendError::DTypeMismatch {
expected: DType::F32,
got: self.dtype,
}),
}
}
pub fn as_f32_slice_mut(&mut self) -> BackendResult<&mut CudaSlice<f32>> {
match &mut self.data {
GpuBuffer::F32(slice) => Ok(slice),
_ => Err(BackendError::DTypeMismatch {
expected: DType::F32,
got: self.dtype,
}),
}
}
pub fn as_u8_slice(&self) -> BackendResult<&CudaSlice<u8>> {
match &self.data {
GpuBuffer::U8(slice) => Ok(slice),
_ => Err(BackendError::DTypeMismatch {
expected: DType::Q4K, got: self.dtype,
}),
}
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn numel(&self) -> usize {
self.shape.iter().product()
}
pub fn dtype(&self) -> DType {
self.dtype
}
}
/// Returns the number of bytes needed to hold `numel` elements of `dtype`.
///
/// The quantized dtypes are stored in fixed-size blocks: the `Q*_0` / `Q*_1`
/// types cover 32 elements per block and the `Q*K` types cover 256. The block
/// count is rounded up, so a trailing partial block still gets a full block of
/// storage and the result is never too small for `numel` elements. Any other
/// dtype is sized at 4 bytes per element.
///
/// NOTE(review): the per-block byte counts look like the GGML block layouts;
/// confirm against the kernels that read these buffers.
#[allow(dead_code)]
fn quantized_bytes(numel: usize, dtype: DType) -> usize {
    // (elements per block, bytes per block)
    let (block_elems, block_bytes): (usize, usize) = match dtype {
        DType::Q4_0 => (32, 18),
        DType::Q4_1 => (32, 20),
        DType::Q5_0 => (32, 22),
        DType::Q5_1 => (32, 24),
        DType::Q8_0 => (32, 34),
        DType::Q8_1 => (32, 36),
        DType::Q2K => (256, 84),
        DType::Q3K => (256, 110),
        DType::Q4K => (256, 144),
        DType::Q5K => (256, 176),
        DType::Q6K => (256, 210),
        DType::Q8K => (256, 292),
        // Not block-quantized: fall back to 4 bytes per element.
        _ => return numel * 4,
    };
    // Round up: truncating here would under-allocate whenever `numel` is not
    // an exact multiple of the block size.
    numel.div_ceil(block_elems) * block_bytes
}
/// Name-keyed store of tensors uploaded to the device through one stream.
///
/// Alongside the tensors it keeps a running total of the host-side byte
/// counts of the entries currently held.
#[allow(dead_code)]
pub struct GpuWeightCache {
    /// Stream used for every upload performed by this cache.
    stream: Arc<CudaStream>,
    /// Cached tensors keyed by name, each paired with the byte count that was
    /// added to `total_bytes` when it was uploaded.
    weights: std::collections::HashMap<String, (GpuTensor, usize)>,
    /// Sum of the byte counts of the entries currently in `weights`.
    total_bytes: usize,
}

#[allow(dead_code)]
impl GpuWeightCache {
    /// Creates an empty cache that uploads through `stream`.
    pub fn new(stream: Arc<CudaStream>) -> Self {
        Self {
            stream,
            weights: std::collections::HashMap::new(),
            total_bytes: 0,
        }
    }

    /// Inserts `tensor` under `name` and keeps `total_bytes` in step.
    ///
    /// If `name` was already present, the replaced entry's bytes are removed
    /// from the total first, so the total reflects only live entries.
    fn store(&mut self, name: String, tensor: GpuTensor, bytes: usize) {
        if let Some((_, replaced_bytes)) = self.weights.insert(name, (tensor, bytes)) {
            // `replaced_bytes` was added to the total when that entry was
            // stored, so this subtraction cannot underflow.
            self.total_bytes -= replaced_bytes;
        }
        self.total_bytes += bytes;
    }

    /// Uploads `data` as an `f32` tensor of the given `shape` and caches it under `name`.
    ///
    /// An existing entry with the same name is replaced and no longer counted
    /// in [`total_bytes`](Self::total_bytes).
    ///
    /// # Errors
    ///
    /// Returns the error from `GpuTensor::from_f32` if the upload fails; the
    /// cache is left unchanged in that case.
    pub fn upload_f32(
        &mut self,
        name: String,
        data: &[f32],
        shape: Vec<usize>,
    ) -> BackendResult<()> {
        let gpu_tensor = GpuTensor::from_f32(&self.stream, data, shape)?;
        self.store(name, gpu_tensor, data.len() * std::mem::size_of::<f32>());
        Ok(())
    }

    /// Uploads raw bytes as a tensor of the given `shape` and `dtype` and caches it under `name`.
    ///
    /// An existing entry with the same name is replaced and no longer counted
    /// in [`total_bytes`](Self::total_bytes).
    ///
    /// # Errors
    ///
    /// Returns the error from `GpuTensor::from_bytes` if the upload fails; the
    /// cache is left unchanged in that case.
    pub fn upload_quantized(
        &mut self,
        name: String,
        data: &[u8],
        shape: Vec<usize>,
        dtype: DType,
    ) -> BackendResult<()> {
        let gpu_tensor = GpuTensor::from_bytes(&self.stream, data, shape, dtype)?;
        self.store(name, gpu_tensor, data.len());
        Ok(())
    }

    /// Looks up a cached tensor by name.
    pub fn get(&self, name: &str) -> Option<&GpuTensor> {
        self.weights.get(name).map(|(tensor, _)| tensor)
    }

    /// Returns the total host-side byte count of the entries currently cached.
    pub fn total_bytes(&self) -> usize {
        self.total_bytes
    }
}