use cudarc::driver::{CudaDevice, CudaSlice};
use std::sync::Arc;
use std::collections::HashMap;
use crate::backend::{BackendError, BackendResult};
use crate::tensor::{DType, Tensor};
/// Registry of model weights that have been copied to a CUDA device.
///
/// Every weight is held on the device as `f32` data (`CudaSlice<f32>`) and is
/// looked up by a string key.
pub struct GpuWeightStore {
// Device that every weight in this store is uploaded to.
device: Arc<CudaDevice>,
// Uploaded weights, keyed by name.
weights: HashMap<String, GpuWeight>,
// Byte count reported by `vram_usage`; maintained by `upload`.
total_bytes: usize,
}
/// A single weight resident on the device as `f32` values.
pub struct GpuWeight {
/// Device buffer holding the weight's elements.
pub data: CudaSlice<f32>,
/// Dimensions of the weight (set by `GpuWeightStore::upload` from the source tensor's `shape()`).
pub shape: Vec<usize>,
/// Element count (set by `GpuWeightStore::upload` from the source tensor's `numel()`).
pub numel: usize,
}
impl GpuWeightStore {
    /// Size in bytes of one stored element; every weight lives on the device as `f32`.
    const F32_BYTES: usize = std::mem::size_of::<f32>();

    /// Creates an empty store whose weights will be uploaded to `device`.
    pub fn new(device: Arc<CudaDevice>) -> Self {
        Self {
            device,
            weights: HashMap::new(),
            total_bytes: 0,
        }
    }

    /// Copies `tensor` to the device as `f32` data and registers it.
    ///
    /// `F32` tensors are copied directly from their host buffer. Tensors of any
    /// other dtype are first dequantized on the CPU into a temporary `F32`
    /// tensor with `tensor.numel()` elements, which is then copied.
    ///
    /// The entry is keyed by the tensor's own name when `tensor.name()` returns
    /// one, and by `name` otherwise. Uploading under a key that is already
    /// present replaces (and drops) the previous entry. In that case
    /// [`vram_usage`](Self::vram_usage) is adjusted so the replaced weight is
    /// no longer counted.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the tensor as `f32` or from
    /// dequantization. Returns [`BackendError::AllocationFailed`] if the
    /// host-to-device copy fails.
    pub fn upload(&mut self, name: &str, tensor: &Tensor) -> BackendResult<()> {
        let numel = tensor.numel();
        let shape = tensor.shape().to_vec();
        // NOTE(review): the tensor's own name takes precedence over `name`;
        // callers that later `get` by the name they passed should confirm
        // that the two agree.
        let key = tensor.name().unwrap_or(name).to_string();

        // Copy straight from the borrowed host slice rather than cloning it
        // into an intermediate Vec first. Only non-F32 tensors need a
        // temporary buffer, and it is dropped as soon as the copy completes.
        let copied = if tensor.dtype() == DType::F32 {
            let host: &[f32] = tensor.as_f32()?;
            self.device.htod_sync_copy(host)
        } else {
            let mut dequant = Tensor::zeros(vec![numel], DType::F32);
            crate::backend::cpu::ops::dequantize(tensor, &mut dequant)?;
            let host: &[f32] = dequant.as_f32()?;
            self.device.htod_sync_copy(host)
        };
        let gpu_data = copied.map_err(|e| BackendError::AllocationFailed(e.to_string()))?;

        self.total_bytes += numel * Self::F32_BYTES;
        let weight = GpuWeight {
            data: gpu_data,
            shape,
            numel,
        };
        // `insert` replaces an existing entry with the same key. Subtract the
        // replaced weight so it is not double-counted; it is dropped at the
        // end of this block.
        if let Some(replaced) = self.weights.insert(key, weight) {
            self.total_bytes -= replaced.numel * Self::F32_BYTES;
        }
        Ok(())
    }

    /// Returns the weight stored under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&GpuWeight> {
        self.weights.get(name)
    }

    /// Returns `true` if a weight is stored under `name`.
    pub fn contains(&self, name: &str) -> bool {
        self.weights.contains_key(name)
    }

    /// Total bytes of `f32` weight data currently held in the store.
    pub fn vram_usage(&self) -> usize {
        self.total_bytes
    }

    /// The device that weights are uploaded to.
    pub fn device(&self) -> &Arc<CudaDevice> {
        &self.device
    }

    /// Number of stored weights.
    pub fn len(&self) -> usize {
        self.weights.len()
    }

    /// Returns `true` if no weights have been uploaded.
    pub fn is_empty(&self) -> bool {
        self.weights.is_empty()
    }
}
/// Uploads every weight of a transformer model into a fresh [`GpuWeightStore`].
///
/// Tensors are registered under GGUF-style names (`token_embd.weight`,
/// `blk.{i}.attn_q.weight`, ..., `output_norm.weight`, `output.weight`). They
/// are uploaded in a fixed order: the token embedding, then each layer in
/// turn, then the final norm and the output projection. Optional biases are
/// uploaded only when present. [`GpuWeightStore::upload`] keys an entry by the
/// tensor's own name when it has one, so the names given here apply only to
/// unnamed tensors.
///
/// Progress is reported on stderr every fourth layer, followed by a summary of
/// the number of stored weights and the total VRAM used.
///
/// # Errors
///
/// Returns the first error produced by [`GpuWeightStore::upload`]. The
/// partially filled store is dropped in that case.
pub fn upload_model_weights(
    device: Arc<CudaDevice>,
    layers: &[crate::model::TransformerLayer],
    embedding: &Tensor,
    output: &crate::model::layers::Linear,
    norm: &crate::model::layers::RMSNorm,
) -> BackendResult<GpuWeightStore> {
    // A progress line is printed for every PROGRESS_STRIDE-th layer.
    const PROGRESS_STRIDE: usize = 4;
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;

    let mut store = GpuWeightStore::new(device);
    store.upload("token_embd.weight", embedding)?;

    let layer_count = layers.len();
    for (idx, layer) in layers.iter().enumerate() {
        if idx % PROGRESS_STRIDE == 0 {
            eprintln!(" Layer {}/{}", idx + 1, layer_count);
        }

        let attn = &layer.attention;
        let ffn = &layer.ffn;
        // Per-layer tensors in upload order; `None` entries (absent biases)
        // are skipped.
        let tensors: [(&str, Option<&Tensor>); 12] = [
            ("attn_q.weight", Some(&attn.wq.weight)),
            ("attn_k.weight", Some(&attn.wk.weight)),
            ("attn_v.weight", Some(&attn.wv.weight)),
            ("attn_output.weight", Some(&attn.wo.weight)),
            ("attn_q.bias", attn.wq.bias.as_ref()),
            ("attn_k.bias", attn.wk.bias.as_ref()),
            ("attn_v.bias", attn.wv.bias.as_ref()),
            ("attn_norm.weight", Some(&layer.attn_norm.weight)),
            ("ffn_gate.weight", Some(&ffn.w_gate.weight)),
            ("ffn_up.weight", Some(&ffn.w_up.weight)),
            ("ffn_down.weight", Some(&ffn.w_down.weight)),
            ("ffn_norm.weight", Some(&layer.ffn_norm.weight)),
        ];
        for &(suffix, maybe_tensor) in tensors.iter() {
            if let Some(tensor) = maybe_tensor {
                store.upload(&format!("blk.{}.{}", idx, suffix), tensor)?;
            }
        }
    }

    store.upload("output_norm.weight", &norm.weight)?;
    store.upload("output.weight", &output.weight)?;
    if let Some(bias) = output.bias.as_ref() {
        store.upload("output.bias", bias)?;
    }

    let vram_mib = store.vram_usage() as f64 / BYTES_PER_MIB;
    eprintln!("Upload complete: {} weights, {:.1} MB VRAM", store.len(), vram_mib);
    Ok(store)
}