use tl_cuda::{CudaTensor, DType};
use serial_test::serial;
/// Dequantizes a single all-zero 144-byte block into a 256-element target and
/// checks the resulting shape, dtype, element count, and that every value is
/// finite.
///
/// An `Err` from `dequantize_q4_k` is only logged; it does not fail the test.
#[test]
#[serial]
fn test_dequantize_q4_k_shape() {
    const BLOCK_COUNT: usize = 1;
    const BLOCK_BYTES: usize = 144;

    let raw = vec![0u8; BLOCK_COUNT * BLOCK_BYTES];
    let input = CudaTensor::from_slice(&raw, &[raw.len()], DType::U8);
    // The uploaded tensor must keep the flat byte-length shape it was given.
    assert_eq!(input.shape(), &[raw.len()]);

    let target_shape = [BLOCK_COUNT * 256];
    let output = match input.dequantize_q4_k(&target_shape) {
        Ok(tensor) => tensor,
        Err(e) => {
            // Tolerated: report the error and stop without failing.
            eprintln!("dequantize_q4_k returned error (may be expected for zero input): {}", e);
            return;
        }
    };

    assert_eq!(output.shape(), &[256]);
    assert_eq!(output.dtype(), DType::F32);
    assert_eq!(output.elem_count(), 256);

    let values = output.to_vec::<f32>();
    assert!(
        !values.iter().any(|v| !v.is_finite()),
        "All outputs should be finite"
    );
}
/// Dequantizes four all-zero 144-byte blocks into a flat 1024-element target
/// and checks the output shape and element count.
///
/// An `Err` from `dequantize_q4_k` is only logged; it does not fail the test.
#[test]
#[serial]
fn test_dequantize_q4_k_multiple_blocks() {
    const BLOCK_COUNT: usize = 4;
    const BLOCK_BYTES: usize = 144;

    let raw = vec![0u8; BLOCK_COUNT * BLOCK_BYTES];
    let input = CudaTensor::from_slice(&raw, &[raw.len()], DType::U8);

    let target_shape = [BLOCK_COUNT * 256];
    match input.dequantize_q4_k(&target_shape) {
        Err(e) => {
            // Tolerated: report the error without failing.
            eprintln!("dequantize_q4_k error for multiple blocks: {}", e);
        }
        Ok(output) => {
            assert_eq!(output.shape(), &[1024]);
            assert_eq!(output.elem_count(), 1024);
        }
    }
}
/// Requests a 100-element target from a 144-byte input and expects
/// `dequantize_q4_k` to return an error.
#[test]
#[serial]
fn test_dequantize_q4_k_invalid_size() {
    let raw = [0u8; 144];
    let input = CudaTensor::from_slice(&raw, &[144], DType::U8);

    // 100 is not a multiple of 256, so the call is expected to be rejected.
    let result_err = input.dequantize_q4_k(&[100]);

    assert!(
        result_err.is_err(),
        "Should fail for non-256-aligned element count"
    );
}
/// Dequantizes two all-zero 144-byte blocks into a two-dimensional `[2, 256]`
/// target and checks that the output keeps that shape with 512 elements.
///
/// An `Err` from `dequantize_q4_k` is only logged; it does not fail the test.
#[test]
#[serial]
fn test_dequantize_q4_k_multidim_target() {
    const BLOCK_COUNT: usize = 2;
    const BLOCK_BYTES: usize = 144;

    let raw = vec![0u8; BLOCK_COUNT * BLOCK_BYTES];
    let input = CudaTensor::from_slice(&raw, &[raw.len()], DType::U8);

    let target_shape = [BLOCK_COUNT, 256];
    let output = match input.dequantize_q4_k(&target_shape) {
        Ok(tensor) => tensor,
        Err(e) => {
            // Tolerated: report the error and stop without failing.
            eprintln!("dequantize_q4_k error for multidim: {}", e);
            return;
        }
    };

    assert_eq!(output.shape(), &[2, 256]);
    assert_eq!(output.elem_count(), 512);
}
/// Dequantizes a single 144-byte block whose first two bytes are `0x00, 0x3C`
/// (all other bytes zero) and checks the output shape and that every value is
/// finite.
///
/// An `Err` from `dequantize_q4_k` is only logged; it does not fail the test.
#[test]
#[serial]
fn test_dequantize_q4_k_nonzero_data() {
    const BLOCK_COUNT: usize = 1;
    const BLOCK_BYTES: usize = 144;

    let mut raw = vec![0u8; BLOCK_COUNT * BLOCK_BYTES];
    // NOTE(review): presumably the little-endian f16 encoding of 1.0 for the
    // block's leading scale field — confirm against the kernel's block layout.
    raw[..2].copy_from_slice(&[0x00, 0x3C]);

    let input = CudaTensor::from_slice(&raw, &[raw.len()], DType::U8);
    let target_shape = [BLOCK_COUNT * 256];

    match input.dequantize_q4_k(&target_shape) {
        Err(e) => {
            // Tolerated: report the error without failing.
            eprintln!("dequantize_q4_k error for nonzero data: {}", e);
        }
        Ok(output) => {
            assert_eq!(output.shape(), &[256]);
            let values = output.to_vec::<f32>();
            assert!(
                !values.iter().any(|v| !v.is_finite()),
                "All outputs should be finite"
            );
        }
    }
}