use cudarc::driver::{CudaDevice, CudaSlice};
use std::sync::Arc;
use std::collections::HashMap;
use crate::backend::{BackendError, BackendResult};
use crate::tensor::{DType, Tensor};
/// A weight tensor stored on the GPU as `f32` elements, plus host-side shape metadata.
pub struct GpuWeight {
    /// Device buffer holding the tensor elements as `f32`.
    pub data: CudaSlice<f32>,

    /// Dimensions of the tensor; `from_tensor` copies these from the source tensor.
    pub shape: Vec<usize>,

    /// Element count; `from_tensor` sets this to the product of `shape`.
    pub numel: usize,
}
impl GpuWeight {
    /// Uploads `tensor` to `device` as an `f32` buffer.
    ///
    /// Tensors whose dtype is already `DType::F32` are uploaded straight from their
    /// host storage. Any other dtype is first converted on the CPU by
    /// `crate::backend::cpu::ops::dequantize` into a temporary F32 tensor of the
    /// same shape, and that temporary is uploaded instead.
    ///
    /// # Errors
    ///
    /// Propagates failures from `Tensor::as_f32` and from `dequantize`. A failed
    /// host-to-device copy is reported as `BackendError::AllocationFailed` with the
    /// message prefix `"GPU upload failed: "`.
    pub fn from_tensor(device: &Arc<CudaDevice>, tensor: &Tensor) -> BackendResult<Self> {
        let shape = tensor.shape().to_vec();
        let numel: usize = shape.iter().product();

        // Only initialised on the dequantize path. Declared out here so the slice
        // borrowed from it stays alive until the upload below has finished.
        let dequantized: Tensor;

        // Borrow the host data rather than cloning it into a fresh Vec: weights can
        // be very large and the upload only needs a read-only slice.
        let host: &[f32] = if tensor.dtype() == DType::F32 {
            tensor.as_f32()?
        } else {
            let mut scratch = Tensor::zeros(shape.clone(), DType::F32);
            crate::backend::cpu::ops::dequantize(tensor, &mut scratch)?;
            dequantized = scratch;
            dequantized.as_f32()?
        };

        let data = device
            .htod_sync_copy(host)
            .map_err(|e| BackendError::AllocationFailed(format!("GPU upload failed: {}", e)))?;

        Ok(Self { data, shape, numel })
    }
}
/// A linear layer whose parameters live on the GPU.
pub struct GpuLinear {
    /// Weight matrix on the device.
    pub weight: GpuWeight,

    /// Optional bias vector on the device; `None` when the layer has no bias.
    pub bias: Option<CudaSlice<f32>>,

    /// Taken from dimension 0 of the weight shape by `from_linear`.
    pub in_features: usize,

    /// Taken from dimension 1 of the weight shape by `from_linear`.
    pub out_features: usize,
}
impl GpuLinear {
    /// Builds a GPU linear layer from host tensors.
    ///
    /// The weight is uploaded through `GpuWeight::from_tensor`. `in_features` and
    /// `out_features` are read from the first and second dimensions of the
    /// uploaded weight's shape, in that order. The bias, when present, is read as
    /// `f32` and copied to the device unchanged.
    ///
    /// # Errors
    ///
    /// * Any error from `GpuWeight::from_tensor` or `Tensor::as_f32`.
    /// * `BackendError::AllocationFailed` if the weight has fewer than two
    ///   dimensions (previously this indexed out of bounds and panicked), or if
    ///   the bias upload fails.
    pub fn from_linear(
        device: &Arc<CudaDevice>,
        weight: &Tensor,
        bias: Option<&Tensor>,
    ) -> BackendResult<Self> {
        let gpu_weight = GpuWeight::from_tensor(device, weight)?;

        // Reject weights with fewer than two dimensions instead of panicking on
        // an out-of-bounds index. Extra trailing dimensions are still tolerated,
        // matching the earlier behaviour of reading only the first two.
        let (in_features, out_features) = match gpu_weight.shape.as_slice() {
            [first, second, ..] => (*first, *second),
            other => {
                return Err(BackendError::AllocationFailed(format!(
                    "linear weight must have at least 2 dimensions, got shape {:?}",
                    other
                )));
            }
        };

        let gpu_bias = match bias {
            Some(b) => {
                let bias_data = b.as_f32()?;
                let uploaded = device
                    .htod_sync_copy(bias_data)
                    .map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
                Some(uploaded)
            }
            None => None,
        };

        Ok(Self {
            weight: gpu_weight,
            bias: gpu_bias,
            in_features,
            out_features,
        })
    }
}
/// Parameters of an RMS-norm layer, with the weight vector stored on the GPU.
pub struct GpuRMSNorm {
    /// Weight vector on the device.
    pub weight: CudaSlice<f32>,

    /// Stored exactly as passed to `from_rms_norm`.
    /// NOTE(review): presumably the epsilon added for numerical stability; confirm against the kernel.
    pub eps: f32,

    /// Number of elements in the weight vector, as set by `from_rms_norm`.
    pub hidden_size: usize,
}
impl GpuRMSNorm {
    /// Copies the `f32` contents of `weight` to `device` and records `eps`.
    ///
    /// `hidden_size` is set to the number of elements in the host weight slice.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Tensor::as_f32`; a failed upload is returned as
    /// `BackendError::AllocationFailed` carrying the driver error text.
    pub fn from_rms_norm(device: &Arc<CudaDevice>, weight: &Tensor, eps: f32) -> BackendResult<Self> {
        let host = weight.as_f32()?;
        let element_count = host.len();

        match device.htod_sync_copy(host) {
            Ok(uploaded) => Ok(Self {
                weight: uploaded,
                eps,
                hidden_size: element_count,
            }),
            Err(e) => Err(BackendError::AllocationFailed(format!("{}", e))),
        }
    }
}
/// GPU-resident parameters and configuration for one attention block.
///
/// NOTE(review): `wq`/`wk`/`wv`/`wo` are presumably the query, key, value and
/// output projections; the names suggest this but nothing in this file confirms it.
pub struct GpuAttention {
    /// Linear layer `wq`.
    pub wq: GpuLinear,
    /// Linear layer `wk`.
    pub wk: GpuLinear,
    /// Linear layer `wv`.
    pub wv: GpuLinear,
    /// Linear layer `wo`.
    pub wo: GpuLinear,

    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Size of each head.
    pub head_dim: usize,

    /// NOTE(review): presumably selects the NeoX-style rotary embedding variant; verify at the use site.
    pub use_neox_rope: bool,
}
/// The three GPU linear layers of a feed-forward block.
///
/// NOTE(review): which of `w1`/`w2`/`w3` is the gate, up or down projection is
/// not visible in this file; check the forward pass before relying on a mapping.
pub struct GpuFFN {
    /// Linear layer `w1`.
    pub w1: GpuLinear,
    /// Linear layer `w2`.
    pub w2: GpuLinear,
    /// Linear layer `w3`.
    pub w3: GpuLinear,
}
/// One transformer layer: two RMS-norm parameter sets, an attention block and a feed-forward block.
pub struct GpuTransformerLayer {
    /// RMS-norm parameters associated with the attention block.
    pub attention_norm: GpuRMSNorm,

    /// Attention weights and configuration.
    pub attention: GpuAttention,

    /// RMS-norm parameters associated with the feed-forward block.
    pub ffn_norm: GpuRMSNorm,

    /// Feed-forward weights.
    pub ffn: GpuFFN,
}
/// Key and value cache buffers on the GPU, with a position counter.
///
/// When built with `new`, each buffer holds
/// `num_kv_heads * max_seq_len * head_dim` zero-initialised `f32` values.
pub struct GpuKVCache {
    /// Device buffer for cached keys.
    pub k_cache: CudaSlice<f32>,
    /// Device buffer for cached values.
    pub v_cache: CudaSlice<f32>,

    /// Starts at 0 and is set back to 0 by `reset`.
    /// NOTE(review): presumably the next write position in the sequence; confirm with callers.
    pub pos: usize,

    /// Sequence-length capacity the buffers were sized for.
    pub max_seq_len: usize,
    /// Number of key/value heads the buffers were sized for.
    pub num_kv_heads: usize,
    /// Per-head dimension the buffers were sized for.
    pub head_dim: usize,
}
impl GpuKVCache {
    /// Allocates zero-filled key and value buffers on `device`.
    ///
    /// Each buffer holds `num_kv_heads * max_seq_len * head_dim` `f32` elements and
    /// `pos` starts at 0.
    ///
    /// # Errors
    ///
    /// Returns `BackendError::AllocationFailed` if the element count does not fit
    /// in `usize`, or if either device allocation fails.
    pub fn new(
        device: &Arc<CudaDevice>,
        num_kv_heads: usize,
        max_seq_len: usize,
        head_dim: usize,
    ) -> BackendResult<Self> {
        // A plain `a * b * c` would wrap in release builds and silently yield an
        // undersized cache; checked arithmetic turns that into an error.
        let cache_size = num_kv_heads
            .checked_mul(max_seq_len)
            .and_then(|partial| partial.checked_mul(head_dim))
            .ok_or_else(|| {
                BackendError::AllocationFailed(format!(
                    "KV cache size overflows usize: {} kv heads * {} positions * {} head dim",
                    num_kv_heads, max_seq_len, head_dim
                ))
            })?;

        let k_cache = device
            .alloc_zeros::<f32>(cache_size)
            .map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;
        let v_cache = device
            .alloc_zeros::<f32>(cache_size)
            .map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?;

        Ok(Self {
            k_cache,
            v_cache,
            pos: 0,
            max_seq_len,
            num_kv_heads,
            head_dim,
        })
    }

    /// Sets `pos` back to 0. The device buffers are not modified.
    pub fn reset(&mut self) {
        self.pos = 0;
    }
}
/// A complete model held on one CUDA device: weights, per-layer KV caches,
/// configuration values and scratch buffers.
pub struct GpuModel {
    /// Shared handle to the CUDA device.
    pub device: Arc<CudaDevice>,

    /// Token embedding table on the device.
    pub token_embedding: GpuWeight,
    /// Transformer layers.
    pub layers: Vec<GpuTransformerLayer>,
    /// RMS-norm parameters stored at model level.
    /// NOTE(review): presumably the final norm before the output projection; confirm.
    pub norm: GpuRMSNorm,
    /// Output linear layer.
    pub output: GpuLinear,

    /// KV caches.
    /// NOTE(review): presumably one per entry in `layers`; verify where the model is built.
    pub kv_caches: Vec<GpuKVCache>,

    /// Hidden size.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Vocabulary size.
    pub vocab_size: usize,
    /// Number of layers.
    pub num_layers: usize,

    /// NOTE(review): presumably the rotary-embedding frequency base; confirm at the use site.
    pub freq_base: f32,
    /// NOTE(review): presumably the rotary-embedding frequency scale; confirm at the use site.
    pub freq_scale: f32,

    /// Preallocated device buffers for intermediate results.
    pub scratch: GpuScratchBuffers,
}
/// Device buffers for intermediate results.
///
/// The lengths given below are the ones `GpuScratchBuffers::new` allocates.
pub struct GpuScratchBuffers {
    /// `hidden_size` elements.
    pub hidden: CudaSlice<f32>,
    /// `hidden_size` elements.
    pub residual: CudaSlice<f32>,
    /// `hidden_size` elements.
    pub attn_out: CudaSlice<f32>,

    /// `intermediate_size` elements.
    pub ffn_gate: CudaSlice<f32>,
    /// `intermediate_size` elements.
    pub ffn_up: CudaSlice<f32>,
    /// `hidden_size` elements.
    pub ffn_out: CudaSlice<f32>,

    /// `num_heads * head_dim` elements.
    pub q: CudaSlice<f32>,
    /// `num_kv_heads * head_dim` elements.
    pub k: CudaSlice<f32>,
    /// `num_kv_heads * head_dim` elements.
    pub v: CudaSlice<f32>,
    /// `num_heads * max_seq_len` elements.
    pub attn_scores: CudaSlice<f32>,

    /// `vocab_size` elements.
    pub logits: CudaSlice<f32>,
}
impl GpuScratchBuffers {
pub fn new(
device: &Arc<CudaDevice>,
hidden_size: usize,
intermediate_size: usize,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
max_seq_len: usize,
vocab_size: usize,
) -> BackendResult<Self> {
Ok(Self {
hidden: device.alloc_zeros::<f32>(hidden_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
residual: device.alloc_zeros::<f32>(hidden_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
attn_out: device.alloc_zeros::<f32>(hidden_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
ffn_gate: device.alloc_zeros::<f32>(intermediate_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
ffn_up: device.alloc_zeros::<f32>(intermediate_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
ffn_out: device.alloc_zeros::<f32>(hidden_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
q: device.alloc_zeros::<f32>(num_heads * head_dim)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
k: device.alloc_zeros::<f32>(num_kv_heads * head_dim)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
v: device.alloc_zeros::<f32>(num_kv_heads * head_dim)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
attn_scores: device.alloc_zeros::<f32>(num_heads * max_seq_len)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
logits: device.alloc_zeros::<f32>(vocab_size)
.map_err(|e| BackendError::AllocationFailed(format!("{}", e)))?,
})
}
}