use crate::tensor::CudaTensor;
use crate::DType;
use tl_backend::{self, BackendResult, DType as BackendDType, GpuTensor};
fn to_backend_dtype(dtype: DType) -> BackendDType {
match dtype {
DType::F32 => BackendDType::F32,
DType::I64 => BackendDType::I64,
DType::I32 => BackendDType::I32,
DType::F16 => BackendDType::F16,
DType::U8 => BackendDType::U8,
}
}
fn from_backend_dtype(dtype: BackendDType) -> DType {
match dtype {
BackendDType::F32 => DType::F32,
BackendDType::I64 => DType::I64,
BackendDType::I32 => DType::I32,
BackendDType::F16 => DType::F16,
BackendDType::U8 => DType::U8,
_ => DType::F32,
}
}
impl GpuTensor for CudaTensor {
fn shape(&self) -> &[usize] {
&self.shape
}
fn dtype(&self) -> BackendDType {
to_backend_dtype(self.dtype)
}
fn to_vec_f32(&self) -> Vec<f32> {
self.to_vec::<f32>()
}
fn to_vec_i64(&self) -> Vec<i64> {
self.to_vec::<i64>()
}
fn from_slice_f32(data: &[f32], shape: &[usize]) -> BackendResult<Self> {
Ok(CudaTensor::from_slice(data, shape, DType::F32))
}
fn from_slice_i64(data: &[i64], shape: &[usize]) -> BackendResult<Self> {
Ok(CudaTensor::from_slice(data, shape, DType::I64))
}
fn zeros(shape: &[usize], dtype: BackendDType) -> BackendResult<Self> {
Ok(CudaTensor::zeros(shape, from_backend_dtype(dtype)))
}
fn ones(shape: &[usize], dtype: BackendDType) -> BackendResult<Self> {
Ok(CudaTensor::ones(shape, from_backend_dtype(dtype)))
}
fn randn(shape: &[usize], dtype: BackendDType) -> BackendResult<Self> {
Ok(CudaTensor::randn(shape, from_backend_dtype(dtype)))
}
fn arange(start: i64, end: i64, dtype: BackendDType) -> BackendResult<Self> {
let local_dtype = from_backend_dtype(dtype);
let count = (end - start) as usize;
match local_dtype {
DType::F32 => {
let data: Vec<f32> = (start..end).map(|i| i as f32).collect();
Ok(CudaTensor::from_slice(&data, &[count], local_dtype))
}
DType::I64 => {
let data: Vec<i64> = (start..end).collect();
Ok(CudaTensor::from_slice(&data, &[count], local_dtype))
}
_ => {
let data: Vec<f32> = (start..end).map(|i| i as f32).collect();
Ok(CudaTensor::from_slice(&data, &[count], DType::F32))
}
}
}
fn clone_data(&self) -> BackendResult<Self> {
CudaTensor::clone_data(self)
}
}
impl CudaTensor {
pub fn to_dtype(&self, target: DType) -> BackendResult<Self> {
if self.dtype == target {
return self.clone_data();
}
match (self.dtype, target) {
(DType::F32, DType::I64) => {
let data = self.to_vec::<f32>();
let i64_data: Vec<i64> = data.iter().map(|&x| x as i64).collect();
Ok(CudaTensor::from_slice(&i64_data, &self.shape, DType::I64))
}
(DType::I64, DType::F32) => {
let data = self.to_vec::<i64>();
let f32_data: Vec<f32> = data.iter().map(|&x| x as f32).collect();
Ok(CudaTensor::from_slice(&f32_data, &self.shape, DType::F32))
}
(DType::U8, DType::F32) => {
let data = self.to_vec::<u8>();
let f32_data: Vec<f32> = data.iter().map(|&x| x as f32).collect();
Ok(CudaTensor::from_slice(&f32_data, &self.shape, DType::F32))
}
(DType::U8, DType::I64) => {
let data = self.to_vec::<u8>();
let i64_data: Vec<i64> = data.iter().map(|&x| x as i64).collect();
Ok(CudaTensor::from_slice(&i64_data, &self.shape, DType::I64))
}
(DType::F32, DType::U8) => {
let data = self.to_vec::<f32>();
let u8_data: Vec<u8> = data.iter().map(|&x| x.clamp(0.0, 255.0) as u8).collect();
Ok(CudaTensor::from_slice(&u8_data, &self.shape, DType::U8))
}
_ => {
let f32_data = self.to_vec::<f32>();
Ok(CudaTensor::from_slice(&f32_data, &self.shape, target))
}
}
}
pub fn dequantize_q4_k(&self, target_shape: &[usize]) -> BackendResult<Self> {
use tl_backend::BackendError;
let num_elements: usize = target_shape.iter().product();
if num_elements % 256 != 0 {
return Err(BackendError::ArgumentError(format!(
"Q4_K num_elements must be divisible by 256, got {}",
num_elements
)));
}
let num_blocks = num_elements / 256;
let raw = self.to_vec::<u8>();
let expected_bytes = num_blocks * 144;
if raw.len() < expected_bytes {
return Err(BackendError::ArgumentError(format!(
"Q4_K: expected {} bytes, got {}",
expected_bytes,
raw.len()
)));
}
let mut output = vec![0.0f32; num_elements];
for block_idx in 0..num_blocks {
let base = block_idx * 144;
let d = f16_to_f32(raw[base], raw[base + 1]);
let dmin = f16_to_f32(raw[base + 2], raw[base + 3]);
let sc = &raw[base + 4..base + 16];
let qs = &raw[base + 16..base + 144];
let out_base = block_idx * 256;
for j in 0..8 {
let (scale, min) = decode_q4k_scale_min(sc, j);
let s = d * scale as f32;
let m = dmin * min as f32;
let sub_base = out_base + j * 32;
let q_offset = j * 16; for k in 0..16 {
let byte = qs[q_offset + k];
let lo = (byte & 0x0F) as f32;
let hi = ((byte >> 4) & 0x0F) as f32;
output[sub_base + k] = lo * s - m;
output[sub_base + k + 16] = hi * s - m;
}
}
}
Ok(CudaTensor::from_slice(&output, target_shape, DType::F32))
}
}
fn f16_to_f32(lo: u8, hi: u8) -> f32 {
let bits = (hi as u16) << 8 | lo as u16;
let sign = ((bits >> 15) & 1) as u32;
let exp = ((bits >> 10) & 0x1F) as u32;
let frac = (bits & 0x3FF) as u32;
if exp == 0 {
if frac == 0 {
f32::from_bits(sign << 31)
} else {
let mut val = frac as f32 / 1024.0;
val *= 2.0f32.powi(-14);
if sign == 1 {
-val
} else {
val
}
}
} else if exp == 31 {
if frac == 0 {
if sign == 1 {
f32::NEG_INFINITY
} else {
f32::INFINITY
}
} else {
f32::NAN
}
} else {
let f32_exp = exp as i32 - 15 + 127;
let f32_bits = (sign << 31) | ((f32_exp as u32) << 23) | (frac << 13);
f32::from_bits(f32_bits)
}
}
fn decode_q4k_scale_min(sc: &[u8], j: usize) -> (u8, u8) {
if j < 4 {
let scale = (sc[j] & 0x3F) as u8;
let min = (sc[j + 4] & 0x3F) as u8;
(scale, min)
} else {
let k = j - 4;
let scale_lo = (sc[k] >> 6) as u8;
let scale_hi = ((sc[j + 4] & 0x0F) as u8) << 2;
let scale = scale_lo | scale_hi;
let min_lo = (sc[k + 4] >> 6) as u8;
let min_hi = ((sc[j + 4] >> 4) & 0x0F) << 2;
let min = min_lo | (min_hi as u8);
(scale, min)
}
}