use std::sync::Arc;
use candle_core::{DType, Storage, Tensor};
use cudarc::driver::{CudaSlice, CudaStream};
/// A device-resident F16 buffer together with its element count.
///
/// Built by [`GpuWeight::from_tensor`], which copies a candle tensor's data into
/// a new buffer via `clone_dtod` on a caller-supplied stream.
pub struct GpuWeight {
/// Owned device buffer holding the weight's F16 elements.
pub slice: CudaSlice<half::f16>,
/// Number of F16 elements; `from_tensor` sets this to the source tensor's `elem_count()`.
pub len: usize,
}
impl GpuWeight {
pub fn from_tensor(
tensor: &Tensor,
target_stream: &Arc<CudaStream>,
) -> candle_core::Result<Self> {
if tensor.dtype() != DType::F16 {
candle_core::bail!("GpuWeight: expected F16, got {:?}", tensor.dtype());
}
let tensor = tensor.contiguous()?;
let len = tensor.elem_count();
let (storage, layout) = tensor.storage_and_layout();
let cuda_storage = match &*storage {
Storage::Cuda(cs) => cs,
_ => candle_core::bail!("GpuWeight: tensor must be on CUDA"),
};
let src = cuda_storage.as_cuda_slice::<half::f16>()?;
let offset = layout.start_offset();
if offset != 0 {
tracing::warn!(
"GpuWeight: tensor has non-zero start_offset={}, len={}, storage_len={}",
offset,
len,
src.len()
);
}
let src_view = src.slice(offset..offset + len);
let owned = target_stream
.clone_dtod(&src_view)
.map_err(|e| candle_core::Error::Msg(format!("weight clone_dtod: {e}")))?;
drop(storage);
Ok(Self { slice: owned, len })
}
}
/// Device buffers for a quantized linear weight stored as packed `i32` values
/// plus F16 scales. This is the payload of [`LinearWeight::Int4`].
///
/// NOTE(review): the exact packing layout is not visible here. This covers how
/// many quantized values sit in each `i32` (presumably 4-bit, per the `Int4`
/// variant name) and along which axis. Confirm against the kernels that
/// consume this struct.
pub struct GpuQuantWeight {
/// Packed quantized weight values.
pub qweight: CudaSlice<i32>,
/// F16 scale factors; presumably one per quantization group (see `group_size`). TODO confirm.
pub scales: CudaSlice<half::f16>,
/// Packed zero points, if present. Presumably `None` when `symmetric` is true; verify with the loader.
pub qzeros: Option<CudaSlice<i32>>,
/// Presumably the input (reduction) dimension of the linear layer. TODO confirm.
pub k: usize,
/// Presumably the output dimension of the linear layer. TODO confirm.
pub n: usize,
/// Presumably the number of weight elements sharing one scale/zero point (inferred from name).
pub group_size: usize,
/// Presumably whether quantization is symmetric, i.e. no zero-point offset (inferred from name).
pub symmetric: bool,
}
/// Storage for a linear layer's weight: either plain F16 or one of two
/// quantized formats (`Int4`, `Marlin`; see `LinearWeight::is_quantized`).
pub enum LinearWeight {
/// Unquantized F16 weight.
Fp16(GpuWeight),
/// Quantized weight held as packed `i32` values plus F16 scales (see [`GpuQuantWeight`]).
Int4(GpuQuantWeight),
/// Quantized weight in the format defined by `crate::marlin::MarlinWeight`.
Marlin(crate::marlin::MarlinWeight),
}
impl LinearWeight {
    /// Borrows the underlying F16 device buffer.
    ///
    /// # Panics
    ///
    /// Panics if the weight is quantized (`Int4` or `Marlin`). Call
    /// [`LinearWeight::is_quantized`] first when the variant is not known.
    pub fn as_fp16(&self) -> &CudaSlice<half::f16> {
        let weight = match self {
            Self::Fp16(weight) => weight,
            Self::Int4(_) | Self::Marlin(_) => {
                panic!("Cannot get fp16 slice from quantized weight")
            }
        };
        &weight.slice
    }

    /// Reports whether this weight uses a quantized format (`Int4` or `Marlin`)
    /// rather than plain F16.
    pub fn is_quantized(&self) -> bool {
        match self {
            Self::Fp16(_) => false,
            Self::Int4(_) | Self::Marlin(_) => true,
        }
    }
}
/// Device weights for a single transformer layer.
///
/// NOTE(review): the field roles below are inferred from their names. Confirm
/// them against the forward pass that consumes this struct.
pub struct LayerWeights {
/// Presumably the norm weight applied to the layer input, before attention.
pub input_ln_w: GpuWeight,
/// Presumably a fused Q/K/V projection held as one weight; may be quantized.
pub qkv_w: LinearWeight,
/// Optional query-normalization weight. `None` presumably means the model does not use it.
pub q_norm_w: Option<GpuWeight>,
/// Optional key-normalization weight, the counterpart of `q_norm_w`.
pub k_norm_w: Option<GpuWeight>,
/// Presumably the attention output projection; may be quantized.
pub o_w: LinearWeight,
/// Presumably the norm weight applied after attention, before the MLP.
pub post_ln_w: GpuWeight,
/// Presumably a fused MLP gate/up projection; may be quantized.
pub gate_up_w: LinearWeight,
/// Presumably the MLP down projection; may be quantized.
pub down_w: LinearWeight,
}
/// All device-resident weights for a transformer model: embedding table,
/// per-layer weights, final norm, output head, and RoPE tables.
///
/// NOTE(review): shapes and layouts are not visible here. The descriptions
/// below are inferred from field names.
pub struct TransformerGpuWeights {
/// Token embedding table (presumably `[vocab, hidden]`; TODO confirm).
pub embed_table: GpuWeight,
/// Weights for each transformer layer, presumably in execution order.
pub layers: Vec<LayerWeights>,
/// Presumably the norm weight applied after the last layer.
pub final_norm_w: GpuWeight,
/// Presumably the output (LM head) projection. As a `LinearWeight` it may be quantized.
pub lm_head_w: LinearWeight,
/// Presumably a precomputed RoPE cosine table; layout TODO confirm.
pub rope_cos: GpuWeight,
/// Presumably a precomputed RoPE sine table; layout TODO confirm.
pub rope_sin: GpuWeight,
}
/// Qwen3 model weights, stored in the generic [`TransformerGpuWeights`] layout.
pub type Qwen3Weights = TransformerGpuWeights;