use cubecl::prelude::*;
/// Number of quantized values stored in every block handled by this module.
pub const GGUF_BLOCK_SIZE: usize = 32;

// Serialized block sizes, written as the sum of each field in the order the
// `split_*_blocks` readers consume them (header fields first, quant payload last).

/// `Q4_0`: f16 scale + 16 bytes of 4-bit quants.
const Q4_0_BLOCK_BYTES: usize = 2 + 16;
/// `Q4_1`: f16 scale + f16 min + 16 bytes of 4-bit quants.
const Q4_1_BLOCK_BYTES: usize = 2 + 2 + 16;
/// `Q5_0`: f16 scale + 4-byte high-bit word + 16 bytes of low nibbles.
const Q5_0_BLOCK_BYTES: usize = 2 + 4 + 16;
/// `Q5_1`: f16 scale + f16 min + 4-byte high-bit word + 16 bytes of low nibbles.
const Q5_1_BLOCK_BYTES: usize = 2 + 2 + 4 + 16;
/// `Q8_0`: f16 scale + 32 signed bytes.
const Q8_0_BLOCK_BYTES: usize = 2 + 32;
/// `Q8_1`: two f32 header values + 32 signed bytes.
const Q8_1_BLOCK_BYTES: usize = 4 + 4 + 32;
/// The 32-element block quantization formats this module can split and dequantize.
// Variant names keep the underscore spelling of the GGUF type names (`Q4_0`, ...),
// which is why the camel-case lint is silenced.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GgufBlockKind {
    /// f16 scale and 4-bit quants.
    Q4_0,
    /// f16 scale, f16 min and 4-bit quants.
    Q4_1,
    /// f16 scale, a 32-bit high-bit word and 4-bit low quants.
    Q5_0,
    /// f16 scale, f16 min, a 32-bit high-bit word and 4-bit low quants.
    Q5_1,
    /// f16 scale and 8-bit signed quants.
    Q8_0,
    /// Two f32 header values and 8-bit signed quants.
    Q8_1,
}

impl GgufBlockKind {
    /// Number of values encoded by one block; the same for every kind.
    #[must_use]
    pub const fn block_elements(self) -> usize {
        GGUF_BLOCK_SIZE
    }

    /// Size in bytes of one serialized block of this kind.
    #[must_use]
    pub const fn block_bytes(self) -> usize {
        match self {
            GgufBlockKind::Q4_0 => Q4_0_BLOCK_BYTES,
            GgufBlockKind::Q4_1 => Q4_1_BLOCK_BYTES,
            GgufBlockKind::Q5_0 => Q5_0_BLOCK_BYTES,
            GgufBlockKind::Q5_1 => Q5_1_BLOCK_BYTES,
            GgufBlockKind::Q8_0 => Q8_0_BLOCK_BYTES,
            GgufBlockKind::Q8_1 => Q8_1_BLOCK_BYTES,
        }
    }
}
/// Decodes a little-endian IEEE-754 half-precision value from two bytes and
/// widens it to `f32`.
#[inline]
fn read_f16_to_f32(b0: u8, b1: u8) -> f32 {
    let bits = u16::from(b0) | (u16::from(b1) << 8);
    half::f16::from_bits(bits).to_f32()
}
/// Assembles four bytes, least-significant first, into one `u32`.
#[inline]
fn pack_4_bytes_le(b0: u8, b1: u8, b2: u8, b3: u8) -> u32 {
    u32::from(b0) | (u32::from(b1) << 8) | (u32::from(b2) << 16) | (u32::from(b3) << 24)
}
/// Splits raw `Q4_0` blocks into per-block scales and packed nibble words.
///
/// Each block is `Q4_0_BLOCK_BYTES` (18) bytes: a little-endian f16 scale
/// followed by 16 bytes of 4-bit quants. The quant bytes are regrouped into 4
/// little-endian `u32` words per block. Only the first
/// `num_blocks * Q4_0_BLOCK_BYTES` bytes of `raw` are read; trailing bytes are ignored.
///
/// Returns `(scales, nibbles)` with `num_blocks` and `4 * num_blocks` entries.
///
/// # Panics
///
/// If `raw` is shorter than `num_blocks * Q4_0_BLOCK_BYTES`, or if that product
/// overflows `usize`.
pub fn split_q4_0_blocks(raw: &[u8], num_blocks: usize) -> (Vec<f32>, Vec<u32>) {
    // Checked so a huge `num_blocks` cannot wrap to a small size in release builds.
    let need = num_blocks
        .checked_mul(Q4_0_BLOCK_BYTES)
        .expect("split_q4_0_blocks: block count overflows usize");
    assert!(
        raw.len() >= need,
        "split_q4_0_blocks: need {need} bytes for {num_blocks} blocks, got {}",
        raw.len()
    );
    let mut scales = Vec::with_capacity(num_blocks);
    let mut nibbles = Vec::with_capacity(num_blocks * 4);
    // An exact-length prefix yields exactly `num_blocks` chunks, with no per-byte bounds checks.
    for block in raw[..need].chunks_exact(Q4_0_BLOCK_BYTES) {
        let (header, quants) = block.split_at(2);
        scales.push(read_f16_to_f32(header[0], header[1]));
        nibbles.extend(
            quants
                .chunks_exact(4)
                .map(|w| pack_4_bytes_le(w[0], w[1], w[2], w[3])),
        );
    }
    (scales, nibbles)
}
/// Splits raw `Q4_1` blocks into per-block scales, per-block mins and packed
/// nibble words.
///
/// Each block is `Q4_1_BLOCK_BYTES` (20) bytes: a little-endian f16 scale, a
/// little-endian f16 min, then 16 bytes of 4-bit quants. The quants are regrouped
/// into 4 little-endian `u32` words per block. Bytes past the first `num_blocks`
/// blocks are ignored.
///
/// Returns `(scales, mins, nibbles)` with `num_blocks`, `num_blocks` and
/// `4 * num_blocks` entries.
///
/// # Panics
///
/// If `raw` is shorter than `num_blocks * Q4_1_BLOCK_BYTES`, or if that product
/// overflows `usize`.
pub fn split_q4_1_blocks(raw: &[u8], num_blocks: usize) -> (Vec<f32>, Vec<f32>, Vec<u32>) {
    // Checked so a huge `num_blocks` cannot wrap to a small size in release builds.
    let need = num_blocks
        .checked_mul(Q4_1_BLOCK_BYTES)
        .expect("split_q4_1_blocks: block count overflows usize");
    assert!(
        raw.len() >= need,
        "split_q4_1_blocks: need {need} bytes for {num_blocks} blocks, got {}",
        raw.len()
    );
    let mut scales = Vec::with_capacity(num_blocks);
    let mut mins = Vec::with_capacity(num_blocks);
    let mut nibbles = Vec::with_capacity(num_blocks * 4);
    for block in raw[..need].chunks_exact(Q4_1_BLOCK_BYTES) {
        let (header, quants) = block.split_at(4);
        scales.push(read_f16_to_f32(header[0], header[1]));
        mins.push(read_f16_to_f32(header[2], header[3]));
        nibbles.extend(
            quants
                .chunks_exact(4)
                .map(|w| pack_4_bytes_le(w[0], w[1], w[2], w[3])),
        );
    }
    (scales, mins, nibbles)
}
/// Splits raw `Q8_0` blocks into per-block scales and packed quant words.
///
/// Each block is `Q8_0_BLOCK_BYTES` (34) bytes: a little-endian f16 scale
/// followed by 32 signed 8-bit quants. The quants are regrouped into 8
/// little-endian `u32` words per block, with the sign handled by the consumer.
/// Bytes past the first `num_blocks` blocks are ignored.
///
/// Returns `(scales, bytes)` with `num_blocks` and `8 * num_blocks` entries.
///
/// # Panics
///
/// If `raw` is shorter than `num_blocks * Q8_0_BLOCK_BYTES`, or if that product
/// overflows `usize`.
pub fn split_q8_0_blocks(raw: &[u8], num_blocks: usize) -> (Vec<f32>, Vec<u32>) {
    // Checked so a huge `num_blocks` cannot wrap to a small size in release builds.
    let need = num_blocks
        .checked_mul(Q8_0_BLOCK_BYTES)
        .expect("split_q8_0_blocks: block count overflows usize");
    assert!(
        raw.len() >= need,
        "split_q8_0_blocks: need {need} bytes for {num_blocks} blocks, got {}",
        raw.len()
    );
    let mut scales = Vec::with_capacity(num_blocks);
    let mut bytes = Vec::with_capacity(num_blocks * 8);
    for block in raw[..need].chunks_exact(Q8_0_BLOCK_BYTES) {
        let (header, quants) = block.split_at(2);
        scales.push(read_f16_to_f32(header[0], header[1]));
        bytes.extend(
            quants
                .chunks_exact(4)
                .map(|w| pack_4_bytes_le(w[0], w[1], w[2], w[3])),
        );
    }
    (scales, bytes)
}
/// Splits raw `Q5_0` blocks into per-block scales, high-bit words and packed
/// low-nibble words.
///
/// Each block is `Q5_0_BLOCK_BYTES` (22) bytes:
/// - a little-endian f16 scale;
/// - a little-endian 32-bit word of high bits;
/// - 16 bytes of 4-bit low quants, regrouped into 4 little-endian `u32` words.
///
/// Bytes past the first `num_blocks` blocks are ignored.
///
/// Returns `(scales, qh, nibbles)` with `num_blocks`, `num_blocks` and
/// `4 * num_blocks` entries.
///
/// # Panics
///
/// If `raw` is shorter than `num_blocks * Q5_0_BLOCK_BYTES`, or if that product
/// overflows `usize`.
pub fn split_q5_0_blocks(raw: &[u8], num_blocks: usize) -> (Vec<f32>, Vec<u32>, Vec<u32>) {
    // Checked so a huge `num_blocks` cannot wrap to a small size in release builds.
    let need = num_blocks
        .checked_mul(Q5_0_BLOCK_BYTES)
        .expect("split_q5_0_blocks: block count overflows usize");
    assert!(
        raw.len() >= need,
        "split_q5_0_blocks: need {need} bytes for {num_blocks} blocks, got {}",
        raw.len()
    );
    let mut scales = Vec::with_capacity(num_blocks);
    let mut qh = Vec::with_capacity(num_blocks);
    let mut nibbles = Vec::with_capacity(num_blocks * 4);
    for block in raw[..need].chunks_exact(Q5_0_BLOCK_BYTES) {
        let (scale, rest) = block.split_at(2);
        let (qh_bytes, quants) = rest.split_at(4);
        scales.push(read_f16_to_f32(scale[0], scale[1]));
        qh.push(pack_4_bytes_le(qh_bytes[0], qh_bytes[1], qh_bytes[2], qh_bytes[3]));
        nibbles.extend(
            quants
                .chunks_exact(4)
                .map(|w| pack_4_bytes_le(w[0], w[1], w[2], w[3])),
        );
    }
    (scales, qh, nibbles)
}
/// Splits raw `Q5_1` blocks into per-block scales, mins, high-bit words and
/// packed low-nibble words.
///
/// Each block is `Q5_1_BLOCK_BYTES` (24) bytes:
/// - a little-endian f16 scale;
/// - a little-endian f16 min;
/// - a little-endian 32-bit word of high bits;
/// - 16 bytes of 4-bit low quants, regrouped into 4 little-endian `u32` words.
///
/// Bytes past the first `num_blocks` blocks are ignored.
///
/// Returns `(scales, mins, qh, nibbles)` with `num_blocks`, `num_blocks`,
/// `num_blocks` and `4 * num_blocks` entries.
///
/// # Panics
///
/// If `raw` is shorter than `num_blocks * Q5_1_BLOCK_BYTES`, or if that product
/// overflows `usize`.
pub fn split_q5_1_blocks(
    raw: &[u8],
    num_blocks: usize,
) -> (Vec<f32>, Vec<f32>, Vec<u32>, Vec<u32>) {
    // Checked so a huge `num_blocks` cannot wrap to a small size in release builds.
    let need = num_blocks
        .checked_mul(Q5_1_BLOCK_BYTES)
        .expect("split_q5_1_blocks: block count overflows usize");
    assert!(
        raw.len() >= need,
        "split_q5_1_blocks: need {need} bytes for {num_blocks} blocks, got {}",
        raw.len()
    );
    let mut scales = Vec::with_capacity(num_blocks);
    let mut mins = Vec::with_capacity(num_blocks);
    let mut qh = Vec::with_capacity(num_blocks);
    let mut nibbles = Vec::with_capacity(num_blocks * 4);
    for block in raw[..need].chunks_exact(Q5_1_BLOCK_BYTES) {
        let (header, rest) = block.split_at(4);
        let (qh_bytes, quants) = rest.split_at(4);
        scales.push(read_f16_to_f32(header[0], header[1]));
        mins.push(read_f16_to_f32(header[2], header[3]));
        qh.push(pack_4_bytes_le(qh_bytes[0], qh_bytes[1], qh_bytes[2], qh_bytes[3]));
        nibbles.extend(
            quants
                .chunks_exact(4)
                .map(|w| pack_4_bytes_le(w[0], w[1], w[2], w[3])),
        );
    }
    (scales, mins, qh, nibbles)
}
/// Splits raw `Q8_1` blocks into per-block scales, per-block second header
/// values (returned as `mins`) and packed quant words.
///
/// Each block is `Q8_1_BLOCK_BYTES` (40) bytes:
/// - two little-endian f32 header values;
/// - 32 signed 8-bit quants, regrouped into 8 little-endian `u32` words.
///
/// Bytes past the first `num_blocks` blocks are ignored.
///
/// NOTE(review): upstream ggml documents the second `block_q8_1` field as
/// `s = d * sum(qs)`, not an additive offset. Confirm that the producer of this
/// data really stores a min here.
///
/// Returns `(scales, mins, bytes)` with `num_blocks`, `num_blocks` and
/// `8 * num_blocks` entries.
///
/// # Panics
///
/// If `raw` is shorter than `num_blocks * Q8_1_BLOCK_BYTES`, or if that product
/// overflows `usize`.
pub fn split_q8_1_blocks(raw: &[u8], num_blocks: usize) -> (Vec<f32>, Vec<f32>, Vec<u32>) {
    // Checked so a huge `num_blocks` cannot wrap to a small size in release builds.
    let need = num_blocks
        .checked_mul(Q8_1_BLOCK_BYTES)
        .expect("split_q8_1_blocks: block count overflows usize");
    assert!(
        raw.len() >= need,
        "split_q8_1_blocks: need {need} bytes for {num_blocks} blocks, got {}",
        raw.len()
    );
    let mut scales = Vec::with_capacity(num_blocks);
    let mut mins = Vec::with_capacity(num_blocks);
    let mut bytes = Vec::with_capacity(num_blocks * 8);
    for block in raw[..need].chunks_exact(Q8_1_BLOCK_BYTES) {
        let (header, quants) = block.split_at(8);
        scales.push(f32::from_le_bytes([header[0], header[1], header[2], header[3]]));
        mins.push(f32::from_le_bytes([header[4], header[5], header[6], header[7]]));
        bytes.extend(
            quants
                .chunks_exact(4)
                .map(|w| pack_4_bytes_le(w[0], w[1], w[2], w[3])),
        );
    }
    (scales, mins, bytes)
}
#[cube(launch_unchecked)]
/// Dequantizes `Q4_0` data, one unit per output element:
/// `out[t] = scales[t / 32] * (q - 8)`, where `q` is a 4-bit quant read from `nibbles`.
///
/// Expects one scale per 32-element block and 4 packed `u32` words (16 quant
/// bytes) per block, the shapes `split_q4_0_blocks` returns. Only `out` is
/// range-checked here; the caller must size `scales` and `nibbles` to cover
/// `out.len()` elements.
///
/// Within a block, element `2k` is the low nibble of quant byte `k` and element
/// `2k + 1` is its high nibble. NOTE(review): llama.cpp's reference `Q4_0`
/// dequantizer maps low nibbles to elements `0..16` and high nibbles to
/// `16..32`. Confirm that the producer of this data uses the interleaved order
/// assumed here.
pub fn kernel_dequantize_q4_0<F: Float>(
    scales: &Array<F>,
    nibbles: &Array<u32>,
    out: &mut Array<F>,
) {
    // Units past the end of `out` do nothing (the launch grid may be rounded up).
    if ABSOLUTE_POS < out.len() {
        let t = ABSOLUTE_POS;
        let block_id = t / 32;
        let elem = t % 32;
        // Two quants per byte: even elements take the low nibble, odd the high one.
        let byte_idx = elem / 2;
        let is_high = elem % 2;
        // Four quant bytes per packed word, four words per block.
        let u32_idx_in_block = byte_idx / 4;
        let byte_in_u32 = byte_idx % 4;
        let global_u32 = block_id * 4 + u32_idx_in_block;
        let packed = nibbles[global_u32];
        let byte = (packed >> (byte_in_u32 as u32 * 8u32)) & 0xFFu32;
        let nibble = (byte >> (is_high as u32 * 4u32)) & 0xFu32;
        // Recentre the unsigned nibble 0..=15 to the signed range -8..=7.
        let nibble_f = F::cast_from(nibble) - F::new(8.0);
        out[t] = scales[block_id] * nibble_f;
    }
}
#[cube(launch_unchecked)]
/// Dequantizes `Q4_1` data, one unit per output element:
/// `out[t] = scales[b] * q + mins[b]` with `b = t / 32` and `q` an unsigned 4-bit quant.
///
/// Expects one scale and one min per 32-element block and 4 packed `u32` words
/// per block, the shapes `split_q4_1_blocks` returns. Only `out` is range-checked
/// here. Nibble order within a block is interleaved: element `2k` is the low
/// nibble of byte `k` and element `2k + 1` is its high nibble. NOTE(review): this
/// differs from llama.cpp's reference layout; confirm it against the producer.
pub fn kernel_dequantize_q4_1<F: Float>(
    scales: &Array<F>,
    mins: &Array<F>,
    nibbles: &Array<u32>,
    out: &mut Array<F>,
) {
    // Units past the end of `out` do nothing (the launch grid may be rounded up).
    if ABSOLUTE_POS < out.len() {
        let t = ABSOLUTE_POS;
        let block_id = t / 32;
        let elem = t % 32;
        // Two quants per byte: even elements take the low nibble, odd the high one.
        let byte_idx = elem / 2;
        let is_high = elem % 2;
        // Four quant bytes per packed word, four words per block.
        let u32_idx_in_block = byte_idx / 4;
        let byte_in_u32 = byte_idx % 4;
        let global_u32 = block_id * 4 + u32_idx_in_block;
        let packed = nibbles[global_u32];
        let byte = (packed >> (byte_in_u32 as u32 * 8u32)) & 0xFFu32;
        let nibble = (byte >> (is_high as u32 * 4u32)) & 0xFu32;
        // Unlike Q4_0 there is no recentring: the per-block min supplies the offset.
        let nibble_f = F::cast_from(nibble);
        out[t] = scales[block_id] * nibble_f + mins[block_id];
    }
}
#[cube(launch_unchecked)]
/// Dequantizes `Q8_0` data, one unit per output element:
/// `out[t] = scales[t / 32] * q`, where `q` is a signed 8-bit quant.
///
/// Expects one scale per 32-element block and 8 packed `u32` words (32 quant
/// bytes, little-endian) per block, the shapes `split_q8_0_blocks` returns.
/// Only `out` is range-checked here.
pub fn kernel_dequantize_q8_0<F: Float>(scales: &Array<F>, bytes: &Array<u32>, out: &mut Array<F>) {
    // Units past the end of `out` do nothing (the launch grid may be rounded up).
    if ABSOLUTE_POS < out.len() {
        let t = ABSOLUTE_POS;
        let block_id = t / 32;
        let elem = t % 32;
        // Four quant bytes per packed word, eight words per block.
        let u32_idx_in_block = elem / 4;
        let byte_in_u32 = elem % 4;
        let global_u32 = block_id * 8 + u32_idx_in_block;
        let packed = bytes[global_u32];
        let byte = (packed >> (byte_in_u32 as u32 * 8u32)) & 0xFFu32;
        // Two's-complement sign extension done in float arithmetic:
        // bytes >= 128 map to byte - 256.
        let sign_bit = (byte >> 7u32) & 1u32;
        let byte_f = F::cast_from(byte);
        let sign_f = F::cast_from(sign_bit);
        let signed_f = byte_f - sign_f * F::new(256.0);
        out[t] = scales[block_id] * signed_f;
    }
}
#[cube(launch_unchecked)]
/// Dequantizes `Q5_0` data, one unit per output element:
/// `out[t] = scales[b] * (q - 16)` with `b = t / 32`, where `q` is a 5-bit value.
///
/// `q` combines a low nibble from `nibbles` with bit 4 taken from bit `t % 32` of
/// `qh[b]`. Expects one scale and one `qh` word per block and 4 packed nibble
/// words per block, the shapes `split_q5_0_blocks` returns. Only `out` is
/// range-checked here. Nibble order is interleaved (element `2k` is the low
/// nibble of byte `k`). NOTE(review): confirm this matches the producer's layout.
pub fn kernel_dequantize_q5_0<F: Float>(
    scales: &Array<F>,
    qh: &Array<u32>,
    nibbles: &Array<u32>,
    out: &mut Array<F>,
) {
    // Units past the end of `out` do nothing (the launch grid may be rounded up).
    if ABSOLUTE_POS < out.len() {
        let t = ABSOLUTE_POS;
        let block_id = t / 32;
        let elem = t % 32;
        // Two quants per byte: even elements take the low nibble, odd the high one.
        let byte_idx = elem / 2;
        let is_high = elem % 2;
        // Four quant bytes per packed word, four words per block.
        let u32_idx_in_block = byte_idx / 4;
        let byte_in_u32 = byte_idx % 4;
        let global_u32 = block_id * 4 + u32_idx_in_block;
        let packed = nibbles[global_u32];
        let byte = (packed >> (byte_in_u32 as u32 * 8u32)) & 0xFFu32;
        let low_nibble = (byte >> (is_high as u32 * 4u32)) & 0xFu32;
        let qh_word = qh[block_id];
        // byte_idx * 2 + is_high == elem, so bit `elem` of qh is this element's bit 4.
        let bit_pos = byte_idx as u32 * 2u32 + is_high as u32;
        let high_bit = (qh_word >> bit_pos) & 1u32;
        let val5 = low_nibble | (high_bit << 4u32);
        // Recentre the unsigned 0..=31 value to the signed range -16..=15.
        let val_f = F::cast_from(val5) - F::new(16.0);
        out[t] = scales[block_id] * val_f;
    }
}
#[cube(launch_unchecked)]
/// Dequantizes `Q5_1` data, one unit per output element:
/// `out[t] = scales[b] * q + mins[b]` with `b = t / 32`, where `q` is an unsigned
/// 5-bit value.
///
/// `q` combines a low nibble from `nibbles` with bit 4 taken from bit `t % 32` of
/// `qh[b]`. Expects one scale, one min and one `qh` word per block and 4 packed
/// nibble words per block, the shapes `split_q5_1_blocks` returns. Only `out` is
/// range-checked here. Nibble order is interleaved (element `2k` is the low
/// nibble of byte `k`). NOTE(review): confirm this matches the producer's layout.
pub fn kernel_dequantize_q5_1<F: Float>(
    scales: &Array<F>,
    mins: &Array<F>,
    qh: &Array<u32>,
    nibbles: &Array<u32>,
    out: &mut Array<F>,
) {
    // Units past the end of `out` do nothing (the launch grid may be rounded up).
    if ABSOLUTE_POS < out.len() {
        let t = ABSOLUTE_POS;
        let block_id = t / 32;
        let elem = t % 32;
        // Two quants per byte: even elements take the low nibble, odd the high one.
        let byte_idx = elem / 2;
        let is_high = elem % 2;
        // Four quant bytes per packed word, four words per block.
        let u32_idx_in_block = byte_idx / 4;
        let byte_in_u32 = byte_idx % 4;
        let global_u32 = block_id * 4 + u32_idx_in_block;
        let packed = nibbles[global_u32];
        let byte = (packed >> (byte_in_u32 as u32 * 8u32)) & 0xFFu32;
        let low_nibble = (byte >> (is_high as u32 * 4u32)) & 0xFu32;
        let qh_word = qh[block_id];
        // byte_idx * 2 + is_high == elem, so bit `elem` of qh is this element's bit 4.
        let bit_pos = byte_idx as u32 * 2u32 + is_high as u32;
        let high_bit = (qh_word >> bit_pos) & 1u32;
        let val5 = low_nibble | (high_bit << 4u32);
        // No recentring: the per-block min supplies the offset.
        let val_f = F::cast_from(val5);
        out[t] = scales[block_id] * val_f + mins[block_id];
    }
}
#[cube(launch_unchecked)]
/// Dequantizes `Q8_1` data, one unit per output element:
/// `out[t] = scales[b] * q + mins[b]` with `b = t / 32`, where `q` is a signed
/// 8-bit quant.
///
/// Expects one scale and one min per block and 8 packed `u32` words per block,
/// the shapes `split_q8_1_blocks` returns. Only `out` is range-checked here.
/// NOTE(review): upstream ggml stores `s = d * sum(qs)` in the second `Q8_1`
/// header field and does not add it during dequantization. Confirm that adding
/// `mins` is the intended semantics for this data.
pub fn kernel_dequantize_q8_1<F: Float>(
    scales: &Array<F>,
    mins: &Array<F>,
    bytes: &Array<u32>,
    out: &mut Array<F>,
) {
    // Units past the end of `out` do nothing (the launch grid may be rounded up).
    if ABSOLUTE_POS < out.len() {
        let t = ABSOLUTE_POS;
        let block_id = t / 32;
        let elem = t % 32;
        // Four quant bytes per packed word, eight words per block.
        let u32_idx_in_block = elem / 4;
        let byte_in_u32 = elem % 4;
        let global_u32 = block_id * 8 + u32_idx_in_block;
        let packed = bytes[global_u32];
        let byte = (packed >> (byte_in_u32 as u32 * 8u32)) & 0xFFu32;
        // Two's-complement sign extension done in float arithmetic:
        // bytes >= 128 map to byte - 256.
        let sign_bit = (byte >> 7u32) & 1u32;
        let byte_f = F::cast_from(byte);
        let sign_f = F::cast_from(sign_bit);
        let signed_f = byte_f - sign_f * F::new(256.0);
        out[t] = scales[block_id] * signed_f + mins[block_id];
    }
}
/// Uploads split `Q4_0` data, launches `kernel_dequantize_q4_0`, and returns the
/// handle of a device buffer holding `num_elements` `f32` values.
///
/// `scales` and `nibbles` are expected in the shapes `split_q4_0_blocks` returns.
///
/// # Panics
///
/// Panics if any of the following holds:
/// - `num_elements != 32 * scales.len()`;
/// - `nibbles.len() != 4 * scales.len()`;
/// - `num_elements` does not fit in `u32`.
///
/// The size checks run in release builds too. The kernel is launched unchecked,
/// so a mismatch would otherwise become an out-of-bounds device read.
pub fn dequantize_q4_0_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    scales: &[f32],
    nibbles: &[u32],
    num_elements: usize,
) -> cubecl::server::Handle {
    assert_eq!(
        num_elements,
        scales.len() * 32,
        "dequantize_q4_0_to_gpu: num_elements must equal 32 * scales.len()"
    );
    assert_eq!(
        nibbles.len(),
        scales.len() * 4,
        "dequantize_q4_0_to_gpu: nibbles must hold 4 words per block"
    );
    let launch_len = u32::try_from(num_elements)
        .expect("dequantize_q4_0_to_gpu: num_elements exceeds u32::MAX");
    let scales_handle = client.create_from_slice(f32::as_bytes(scales));
    // SAFETY: `u8` has alignment 1 and every bit pattern is valid. The length is
    // exactly `size_of_val(nibbles)`, so the byte view covers only the borrowed
    // slice, and that borrow outlives this call.
    let nibbles_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(nibbles.as_ptr().cast::<u8>(), std::mem::size_of_val(nibbles))
    });
    let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());
    let (count, dim) = crate::elementwise_launch_dims(launch_len);
    // SAFETY: the asserts above ensure that, for every `t < num_elements`,
    // `scales[t / 32]` and `nibbles[(t / 32) * 4 + 0..=3]` are in bounds. `out`
    // has exactly `num_elements` slots and the kernel guards its writes.
    unsafe {
        kernel_dequantize_q4_0::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(scales_handle, scales.len()),
            ArrayArg::from_raw_parts(nibbles_handle, nibbles.len()),
            ArrayArg::from_raw_parts(out_handle.clone(), num_elements),
        );
    }
    out_handle
}
/// Uploads split `Q4_1` data, launches `kernel_dequantize_q4_1`, and returns the
/// handle of a device buffer holding `num_elements` `f32` values.
///
/// The inputs are expected in the shapes `split_q4_1_blocks` returns.
///
/// # Panics
///
/// Panics if any of the following holds:
/// - `num_elements != 32 * scales.len()`;
/// - `mins.len() != scales.len()`;
/// - `nibbles.len() != 4 * scales.len()`;
/// - `num_elements` does not fit in `u32`.
///
/// The size checks run in release builds too, because the kernel is launched
/// unchecked.
pub fn dequantize_q4_1_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    scales: &[f32],
    mins: &[f32],
    nibbles: &[u32],
    num_elements: usize,
) -> cubecl::server::Handle {
    assert_eq!(
        num_elements,
        scales.len() * 32,
        "dequantize_q4_1_to_gpu: num_elements must equal 32 * scales.len()"
    );
    assert_eq!(
        scales.len(),
        mins.len(),
        "dequantize_q4_1_to_gpu: scales and mins must have the same length"
    );
    assert_eq!(
        nibbles.len(),
        scales.len() * 4,
        "dequantize_q4_1_to_gpu: nibbles must hold 4 words per block"
    );
    let launch_len = u32::try_from(num_elements)
        .expect("dequantize_q4_1_to_gpu: num_elements exceeds u32::MAX");
    let scales_handle = client.create_from_slice(f32::as_bytes(scales));
    let mins_handle = client.create_from_slice(f32::as_bytes(mins));
    // SAFETY: `u8` has alignment 1 and every bit pattern is valid. The length is
    // exactly `size_of_val(nibbles)`, so the byte view covers only the borrowed slice.
    let nibbles_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(nibbles.as_ptr().cast::<u8>(), std::mem::size_of_val(nibbles))
    });
    let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());
    let (count, dim) = crate::elementwise_launch_dims(launch_len);
    // SAFETY: the asserts above keep every `scales`/`mins`/`nibbles` index the
    // kernel computes for `t < num_elements` in bounds. `out` has exactly
    // `num_elements` slots.
    unsafe {
        kernel_dequantize_q4_1::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(scales_handle, scales.len()),
            ArrayArg::from_raw_parts(mins_handle, mins.len()),
            ArrayArg::from_raw_parts(nibbles_handle, nibbles.len()),
            ArrayArg::from_raw_parts(out_handle.clone(), num_elements),
        );
    }
    out_handle
}
/// Uploads split `Q8_0` data, launches `kernel_dequantize_q8_0`, and returns the
/// handle of a device buffer holding `num_elements` `f32` values.
///
/// The inputs are expected in the shapes `split_q8_0_blocks` returns.
///
/// # Panics
///
/// Panics if any of the following holds:
/// - `num_elements != 32 * scales.len()`;
/// - `bytes.len() != 8 * scales.len()`;
/// - `num_elements` does not fit in `u32`.
///
/// The size checks run in release builds too, because the kernel is launched
/// unchecked.
pub fn dequantize_q8_0_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    scales: &[f32],
    bytes: &[u32],
    num_elements: usize,
) -> cubecl::server::Handle {
    assert_eq!(
        num_elements,
        scales.len() * 32,
        "dequantize_q8_0_to_gpu: num_elements must equal 32 * scales.len()"
    );
    assert_eq!(
        bytes.len(),
        scales.len() * 8,
        "dequantize_q8_0_to_gpu: bytes must hold 8 words per block"
    );
    let launch_len = u32::try_from(num_elements)
        .expect("dequantize_q8_0_to_gpu: num_elements exceeds u32::MAX");
    let scales_handle = client.create_from_slice(f32::as_bytes(scales));
    // SAFETY: `u8` has alignment 1 and every bit pattern is valid. The length is
    // exactly `size_of_val(bytes)`, so the byte view covers only the borrowed slice.
    let bytes_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), std::mem::size_of_val(bytes))
    });
    let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());
    let (count, dim) = crate::elementwise_launch_dims(launch_len);
    // SAFETY: the asserts above ensure that, for every `t < num_elements`,
    // `scales[t / 32]` and `bytes[(t / 32) * 8 + 0..=7]` are in bounds.
    unsafe {
        kernel_dequantize_q8_0::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(scales_handle, scales.len()),
            ArrayArg::from_raw_parts(bytes_handle, bytes.len()),
            ArrayArg::from_raw_parts(out_handle.clone(), num_elements),
        );
    }
    out_handle
}
/// Uploads split `Q5_0` data, launches `kernel_dequantize_q5_0`, and returns the
/// handle of a device buffer holding `num_elements` `f32` values.
///
/// The inputs are expected in the shapes `split_q5_0_blocks` returns.
///
/// # Panics
///
/// Panics if any of the following holds:
/// - `num_elements != 32 * scales.len()`;
/// - `qh.len() != scales.len()`;
/// - `nibbles.len() != 4 * scales.len()`;
/// - `num_elements` does not fit in `u32`.
///
/// The size checks run in release builds too, because the kernel is launched
/// unchecked.
pub fn dequantize_q5_0_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    scales: &[f32],
    qh: &[u32],
    nibbles: &[u32],
    num_elements: usize,
) -> cubecl::server::Handle {
    assert_eq!(
        num_elements,
        scales.len() * 32,
        "dequantize_q5_0_to_gpu: num_elements must equal 32 * scales.len()"
    );
    assert_eq!(
        qh.len(),
        scales.len(),
        "dequantize_q5_0_to_gpu: qh must hold one word per block"
    );
    assert_eq!(
        nibbles.len(),
        scales.len() * 4,
        "dequantize_q5_0_to_gpu: nibbles must hold 4 words per block"
    );
    let launch_len = u32::try_from(num_elements)
        .expect("dequantize_q5_0_to_gpu: num_elements exceeds u32::MAX");
    let scales_handle = client.create_from_slice(f32::as_bytes(scales));
    // SAFETY: `u8` has alignment 1 and every bit pattern is valid. The length is
    // exactly `size_of_val(qh)`, so the byte view covers only the borrowed slice.
    let qh_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(qh.as_ptr().cast::<u8>(), std::mem::size_of_val(qh))
    });
    // SAFETY: same reasoning as for `qh`.
    let nibbles_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(nibbles.as_ptr().cast::<u8>(), std::mem::size_of_val(nibbles))
    });
    let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());
    let (count, dim) = crate::elementwise_launch_dims(launch_len);
    // SAFETY: the asserts above keep every `scales`/`qh`/`nibbles` index the
    // kernel computes for `t < num_elements` in bounds.
    unsafe {
        kernel_dequantize_q5_0::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(scales_handle, scales.len()),
            ArrayArg::from_raw_parts(qh_handle, qh.len()),
            ArrayArg::from_raw_parts(nibbles_handle, nibbles.len()),
            ArrayArg::from_raw_parts(out_handle.clone(), num_elements),
        );
    }
    out_handle
}
/// Uploads split `Q5_1` data, launches `kernel_dequantize_q5_1`, and returns the
/// handle of a device buffer holding `num_elements` `f32` values.
///
/// The inputs are expected in the shapes `split_q5_1_blocks` returns.
///
/// # Panics
///
/// Panics if any of the following holds:
/// - `num_elements != 32 * scales.len()`;
/// - `mins` or `qh` do not have one entry per block;
/// - `nibbles.len() != 4 * scales.len()`;
/// - `num_elements` does not fit in `u32`.
///
/// The size checks run in release builds too, because the kernel is launched
/// unchecked.
pub fn dequantize_q5_1_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    scales: &[f32],
    mins: &[f32],
    qh: &[u32],
    nibbles: &[u32],
    num_elements: usize,
) -> cubecl::server::Handle {
    assert_eq!(
        num_elements,
        scales.len() * 32,
        "dequantize_q5_1_to_gpu: num_elements must equal 32 * scales.len()"
    );
    assert_eq!(
        mins.len(),
        scales.len(),
        "dequantize_q5_1_to_gpu: mins must hold one value per block"
    );
    assert_eq!(
        qh.len(),
        scales.len(),
        "dequantize_q5_1_to_gpu: qh must hold one word per block"
    );
    assert_eq!(
        nibbles.len(),
        scales.len() * 4,
        "dequantize_q5_1_to_gpu: nibbles must hold 4 words per block"
    );
    let launch_len = u32::try_from(num_elements)
        .expect("dequantize_q5_1_to_gpu: num_elements exceeds u32::MAX");
    let scales_handle = client.create_from_slice(f32::as_bytes(scales));
    let mins_handle = client.create_from_slice(f32::as_bytes(mins));
    // SAFETY: `u8` has alignment 1 and every bit pattern is valid. The length is
    // exactly `size_of_val(qh)`, so the byte view covers only the borrowed slice.
    let qh_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(qh.as_ptr().cast::<u8>(), std::mem::size_of_val(qh))
    });
    // SAFETY: same reasoning as for `qh`.
    let nibbles_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(nibbles.as_ptr().cast::<u8>(), std::mem::size_of_val(nibbles))
    });
    let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());
    let (count, dim) = crate::elementwise_launch_dims(launch_len);
    // SAFETY: the asserts above keep every input index the kernel computes for
    // `t < num_elements` in bounds.
    unsafe {
        kernel_dequantize_q5_1::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(scales_handle, scales.len()),
            ArrayArg::from_raw_parts(mins_handle, mins.len()),
            ArrayArg::from_raw_parts(qh_handle, qh.len()),
            ArrayArg::from_raw_parts(nibbles_handle, nibbles.len()),
            ArrayArg::from_raw_parts(out_handle.clone(), num_elements),
        );
    }
    out_handle
}
/// Uploads split `Q8_1` data, launches `kernel_dequantize_q8_1`, and returns the
/// handle of a device buffer holding `num_elements` `f32` values.
///
/// The inputs are expected in the shapes `split_q8_1_blocks` returns.
///
/// # Panics
///
/// Panics if any of the following holds:
/// - `num_elements != 32 * scales.len()`;
/// - `mins.len() != scales.len()`;
/// - `bytes.len() != 8 * scales.len()`;
/// - `num_elements` does not fit in `u32`.
///
/// The size checks run in release builds too, because the kernel is launched
/// unchecked.
pub fn dequantize_q8_1_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    scales: &[f32],
    mins: &[f32],
    bytes: &[u32],
    num_elements: usize,
) -> cubecl::server::Handle {
    assert_eq!(
        num_elements,
        scales.len() * 32,
        "dequantize_q8_1_to_gpu: num_elements must equal 32 * scales.len()"
    );
    assert_eq!(
        mins.len(),
        scales.len(),
        "dequantize_q8_1_to_gpu: mins must hold one value per block"
    );
    assert_eq!(
        bytes.len(),
        scales.len() * 8,
        "dequantize_q8_1_to_gpu: bytes must hold 8 words per block"
    );
    let launch_len = u32::try_from(num_elements)
        .expect("dequantize_q8_1_to_gpu: num_elements exceeds u32::MAX");
    let scales_handle = client.create_from_slice(f32::as_bytes(scales));
    let mins_handle = client.create_from_slice(f32::as_bytes(mins));
    // SAFETY: `u8` has alignment 1 and every bit pattern is valid. The length is
    // exactly `size_of_val(bytes)`, so the byte view covers only the borrowed slice.
    let bytes_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(bytes.as_ptr().cast::<u8>(), std::mem::size_of_val(bytes))
    });
    let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());
    let (count, dim) = crate::elementwise_launch_dims(launch_len);
    // SAFETY: the asserts above keep every `scales`/`mins`/`bytes` index the
    // kernel computes for `t < num_elements` in bounds.
    unsafe {
        kernel_dequantize_q8_1::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(scales_handle, scales.len()),
            ArrayArg::from_raw_parts(mins_handle, mins.len()),
            ArrayArg::from_raw_parts(bytes_handle, bytes.len()),
            ArrayArg::from_raw_parts(out_handle.clone(), num_elements),
        );
    }
    out_handle
}
#[cfg(test)]
/// CPU mirror of `kernel_dequantize_q4_0`, used to validate GPU output.
///
/// Emits 32 values per entry of `scales`. In block `b`, element `e` is the 4-bit
/// field at bit offset `4 * (e % 8)` of word `nibbles[4 * b + e / 8]`. That field
/// is recentred by subtracting 8 and then multiplied by the block scale.
///
/// Panics if `nibbles` holds fewer than `4 * scales.len()` words.
pub(crate) fn dequantize_q4_0_reference(scales: &[f32], nibbles: &[u32]) -> Vec<f32> {
    scales
        .iter()
        .enumerate()
        .flat_map(|(block, &scale)| {
            (0..32u32).map(move |elem| {
                let word = nibbles[block * 4 + (elem / 8) as usize];
                // Byte `(e / 2) % 4`, nibble `e % 2`, combined into one shift.
                let quant = (word >> ((elem % 8) * 4)) & 0xF;
                scale * (quant as f32 - 8.0)
            })
        })
        .collect()
}
#[cfg(test)]
/// CPU mirror of `kernel_dequantize_q4_1`, used to validate GPU output.
///
/// Emits 32 values per entry of `scales`. Each value is `scale * q + min`, where
/// `q` is the unsigned 4-bit field at bit offset `4 * (e % 8)` of word
/// `nibbles[4 * b + e / 8]`.
///
/// Panics if `mins` is shorter than `scales` or `nibbles` holds fewer than
/// `4 * scales.len()` words.
pub(crate) fn dequantize_q4_1_reference(scales: &[f32], mins: &[f32], nibbles: &[u32]) -> Vec<f32> {
    scales
        .iter()
        .enumerate()
        .flat_map(|(block, &scale)| {
            let min_v = mins[block];
            (0..32u32).map(move |elem| {
                let word = nibbles[block * 4 + (elem / 8) as usize];
                let quant = (word >> ((elem % 8) * 4)) & 0xF;
                scale * (quant as f32) + min_v
            })
        })
        .collect()
}
#[cfg(test)]
/// CPU mirror of `kernel_dequantize_q8_0`, used to validate GPU output.
///
/// Emits 32 values per entry of `scales`. Element `e` of block `b` is byte
/// `e % 4` of word `bytes[8 * b + e / 4]`, read as a two's-complement `i8` and
/// multiplied by the block scale.
///
/// Panics if `bytes` holds fewer than `8 * scales.len()` words.
pub(crate) fn dequantize_q8_0_reference(scales: &[f32], bytes: &[u32]) -> Vec<f32> {
    scales
        .iter()
        .enumerate()
        .flat_map(|(block, &scale)| {
            (0..32u32).map(move |elem| {
                let word = bytes[block * 8 + (elem / 4) as usize];
                // Truncate to the addressed byte, then reinterpret it as signed.
                let quant = (word >> ((elem % 4) * 8)) as u8 as i8;
                scale * f32::from(quant)
            })
        })
        .collect()
}
#[cfg(test)]
/// CPU mirror of `kernel_dequantize_q5_0`, used to validate GPU output.
///
/// Emits 32 values per entry of `scales`. Element `e` of block `b` combines the
/// 4-bit field at bit offset `4 * (e % 8)` of `nibbles[4 * b + e / 8]` with bit `e`
/// of `qh[b]` as its fifth bit. The 5-bit value is recentred by subtracting 16 and
/// scaled.
///
/// Panics if `qh` is shorter than `scales` or `nibbles` holds fewer than
/// `4 * scales.len()` words.
pub(crate) fn dequantize_q5_0_reference(scales: &[f32], qh: &[u32], nibbles: &[u32]) -> Vec<f32> {
    scales
        .iter()
        .enumerate()
        .flat_map(|(block, &scale)| {
            let qh_word = qh[block];
            (0..32u32).map(move |elem| {
                let word = nibbles[block * 4 + (elem / 8) as usize];
                let low = (word >> ((elem % 8) * 4)) & 0xF;
                let high = (qh_word >> elem) & 1;
                let quant = low | (high << 4);
                scale * (quant as f32 - 16.0)
            })
        })
        .collect()
}
#[cfg(test)]
/// CPU mirror of `kernel_dequantize_q5_1`, used to validate GPU output.
///
/// Emits 32 values per entry of `scales`. Element `e` of block `b` is the unsigned
/// 5-bit value built from the 4-bit field at bit offset `4 * (e % 8)` of
/// `nibbles[4 * b + e / 8]` plus bit `e` of `qh[b]`. The result is
/// `scale * value + min`.
///
/// Panics if `mins` or `qh` is shorter than `scales`, or if `nibbles` holds fewer
/// than `4 * scales.len()` words.
pub(crate) fn dequantize_q5_1_reference(
    scales: &[f32],
    mins: &[f32],
    qh: &[u32],
    nibbles: &[u32],
) -> Vec<f32> {
    scales
        .iter()
        .enumerate()
        .flat_map(|(block, &scale)| {
            let min_v = mins[block];
            let qh_word = qh[block];
            (0..32u32).map(move |elem| {
                let word = nibbles[block * 4 + (elem / 8) as usize];
                let low = (word >> ((elem % 8) * 4)) & 0xF;
                let high = (qh_word >> elem) & 1;
                let quant = low | (high << 4);
                scale * (quant as f32) + min_v
            })
        })
        .collect()
}
#[cfg(test)]
/// CPU mirror of `kernel_dequantize_q8_1`, used to validate GPU output.
///
/// Emits 32 values per entry of `scales`. Element `e` of block `b` is byte `e % 4`
/// of word `bytes[8 * b + e / 4]`, read as a two's-complement `i8`. The result is
/// `scale * value + min`.
///
/// Panics if `mins` is shorter than `scales` or `bytes` holds fewer than
/// `8 * scales.len()` words.
pub(crate) fn dequantize_q8_1_reference(scales: &[f32], mins: &[f32], bytes: &[u32]) -> Vec<f32> {
    scales
        .iter()
        .enumerate()
        .flat_map(|(block, &scale)| {
            let min_v = mins[block];
            (0..32u32).map(move |elem| {
                let word = bytes[block * 8 + (elem / 4) as usize];
                // Truncate to the addressed byte, then reinterpret it as signed.
                let quant = (word >> ((elem % 4) * 8)) as u8 as i8;
                scale * f32::from(quant) + min_v
            })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    //! Host-side tests for the block splitters and the CPU reference dequantizers.

    use super::*;
    use approx::assert_relative_eq;

    /// Appends 32 4-bit values as 16 bytes, masking each to 4 bits. Element `2k`
    /// goes in the low nibble and element `2k + 1` in the high nibble of byte `k`,
    /// the order the reference dequantizers read back.
    fn push_nibble_pairs(raw: &mut Vec<u8>, nibs: &[u8; 32]) {
        raw.extend(
            nibs.chunks_exact(2)
                .map(|pair| ((pair[1] & 0xF) << 4) | (pair[0] & 0xF)),
        );
    }

    /// Builds the 5-bit ramp `value(i) = i` for `i in 0..32`. Returns the low
    /// nibbles and the `qh` word whose bit `i` carries bit 4 of element `i`.
    fn ramp_5bit() -> ([u8; 32], u32) {
        let mut low = [0u8; 32];
        let mut qh: u32 = 0;
        for (i, v) in low.iter_mut().enumerate() {
            let val5 = i as u8;
            *v = val5 & 0xF;
            if val5 >= 16 {
                qh |= 1 << i;
            }
        }
        (low, qh)
    }

    /// Serializes `Q4_0` blocks and returns them with the values a correct
    /// dequantizer must produce (computed with the f16-rounded scale).
    fn build_q4_0_blocks(blocks: &[(f32, [u8; 32])]) -> (Vec<u8>, Vec<f32>) {
        let mut raw = Vec::with_capacity(blocks.len() * Q4_0_BLOCK_BYTES);
        let mut expected = Vec::with_capacity(blocks.len() * 32);
        for &(scale, nibs) in blocks {
            let s_f16 = half::f16::from_f32(scale);
            raw.extend_from_slice(&s_f16.to_bits().to_le_bytes());
            push_nibble_pairs(&mut raw, &nibs);
            let s = s_f16.to_f32();
            expected.extend(nibs.iter().map(|&n| s * (n as f32 - 8.0)));
        }
        (raw, expected)
    }

    #[test]
    fn split_q4_0_recovers_scales_and_nibbles() {
        let blocks = [
            (1.0, [0u8; 32]),
            (
                0.5,
                [
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13, 12, 11, 10,
                    9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
                ],
            ),
            (-2.5, [15u8; 32]),
        ];
        let (raw, expected) = build_q4_0_blocks(&blocks);
        let (scales, nibbles) = split_q4_0_blocks(&raw, blocks.len());
        assert_eq!(scales.len(), blocks.len());
        assert_eq!(nibbles.len(), blocks.len() * 4);
        for (i, &(s, _)) in blocks.iter().enumerate() {
            assert_relative_eq!(scales[i], half::f16::from_f32(s).to_f32(), epsilon = 1e-3);
        }
        let dequantized = dequantize_q4_0_reference(&scales, &nibbles);
        assert_eq!(dequantized.len(), expected.len());
        for (got, want) in dequantized.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-3);
        }
    }

    #[test]
    fn q4_0_reference_matches_serialize_dequant_arithmetic() {
        let scale = 0.25_f32;
        let mut nibs = [0u8; 32];
        for (i, n) in nibs.iter_mut().enumerate() {
            *n = (i % 16) as u8;
        }
        let (raw, _) = build_q4_0_blocks(&[(scale, nibs)]);
        let (scales, nibbles) = split_q4_0_blocks(&raw, 1);
        let dequant = dequantize_q4_0_reference(&scales, &nibbles);
        let s = half::f16::from_f32(scale).to_f32();
        for (i, &v) in dequant.iter().enumerate() {
            let expected_n = (i % 16) as f32 - 8.0;
            assert_relative_eq!(v, s * expected_n, epsilon = 1e-3);
        }
    }

    /// Serializes `Q4_1` blocks given `(scale, min, nibbles)` triples.
    fn build_q4_1_blocks(blocks: &[(f32, f32, [u8; 32])]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(blocks.len() * Q4_1_BLOCK_BYTES);
        for &(scale, min_v, nibs) in blocks {
            raw.extend_from_slice(&half::f16::from_f32(scale).to_bits().to_le_bytes());
            raw.extend_from_slice(&half::f16::from_f32(min_v).to_bits().to_le_bytes());
            push_nibble_pairs(&mut raw, &nibs);
        }
        raw
    }

    #[test]
    fn split_q4_1_recovers_scales_mins_nibbles() {
        let blocks = [
            (1.0, 0.0, [3u8; 32]),
            (
                0.5,
                1.5,
                [
                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13, 12, 11, 10,
                    9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
                ],
            ),
        ];
        let raw = build_q4_1_blocks(&blocks);
        let (scales, mins, nibbles) = split_q4_1_blocks(&raw, blocks.len());
        assert_eq!(scales.len(), blocks.len());
        assert_eq!(mins.len(), blocks.len());
        assert_eq!(nibbles.len(), blocks.len() * 4);
        let dequant = dequantize_q4_1_reference(&scales, &mins, &nibbles);
        assert_eq!(dequant.len(), blocks.len() * 32);
        // Check every element of every block, not just the first block.
        for (b, &(scale, min_v, nibs)) in blocks.iter().enumerate() {
            let s = half::f16::from_f32(scale).to_f32();
            let m = half::f16::from_f32(min_v).to_f32();
            for (i, &n) in nibs.iter().enumerate() {
                assert_relative_eq!(dequant[b * 32 + i], s * n as f32 + m, epsilon = 1e-3);
            }
        }
    }

    /// Serializes `Q8_0` blocks given `(scale, quants)` pairs.
    fn build_q8_0_blocks(blocks: &[(f32, [i8; 32])]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(blocks.len() * Q8_0_BLOCK_BYTES);
        for &(scale, vals) in blocks {
            raw.extend_from_slice(&half::f16::from_f32(scale).to_bits().to_le_bytes());
            raw.extend(vals.iter().map(|&v| v as u8));
        }
        raw
    }

    #[test]
    fn split_q8_0_recovers_scales_and_signed_bytes() {
        let mut vals = [0i8; 32];
        for (i, v) in vals.iter_mut().enumerate() {
            *v = ((i as i32) - 16) as i8;
        }
        let blocks = [(2.0, vals)];
        let raw = build_q8_0_blocks(&blocks);
        let (scales, bytes) = split_q8_0_blocks(&raw, blocks.len());
        assert_eq!(scales.len(), 1);
        assert_eq!(bytes.len(), 8);
        let dequant = dequantize_q8_0_reference(&scales, &bytes);
        let s = half::f16::from_f32(2.0).to_f32();
        for (i, &v) in dequant.iter().enumerate() {
            let expected = s * (vals[i] as f32);
            assert_relative_eq!(v, expected, epsilon = 1e-3);
        }
    }

    #[test]
    fn q8_0_sign_extension_handles_full_range() {
        let vals: [i8; 32] = std::array::from_fn(|i| ((i as i32) * 8 - 128) as i8);
        let raw = build_q8_0_blocks(&[(1.0, vals)]);
        let (scales, bytes) = split_q8_0_blocks(&raw, 1);
        let dequant = dequantize_q8_0_reference(&scales, &bytes);
        let s = half::f16::from_f32(1.0).to_f32();
        for (i, &v) in dequant.iter().enumerate() {
            let expected = s * (vals[i] as f32);
            assert_relative_eq!(v, expected, epsilon = 1e-3);
        }
    }

    #[test]
    fn block_kind_metadata_constants() {
        assert_eq!(GgufBlockKind::Q4_0.block_elements(), 32);
        assert_eq!(GgufBlockKind::Q4_0.block_bytes(), 18);
        assert_eq!(GgufBlockKind::Q4_1.block_bytes(), 20);
        assert_eq!(GgufBlockKind::Q8_0.block_bytes(), 34);
    }

    #[test]
    #[should_panic(expected = "split_q4_0_blocks: need 18 bytes")]
    fn split_q4_0_panics_on_short_input() {
        let short = vec![0u8; 17];
        let _ = split_q4_0_blocks(&short, 1);
    }

    #[test]
    #[should_panic(expected = "split_q8_1_blocks: need 40 bytes")]
    fn split_q8_1_panics_on_short_input() {
        let short = vec![0u8; 39];
        let _ = split_q8_1_blocks(&short, 1);
    }

    #[test]
    fn splitters_accept_zero_blocks() {
        assert_eq!(split_q4_0_blocks(&[], 0), (vec![], vec![]));
        assert_eq!(split_q4_1_blocks(&[], 0), (vec![], vec![], vec![]));
        assert_eq!(split_q5_0_blocks(&[], 0), (vec![], vec![], vec![]));
        assert_eq!(split_q5_1_blocks(&[], 0), (vec![], vec![], vec![], vec![]));
        assert_eq!(split_q8_0_blocks(&[], 0), (vec![], vec![]));
        assert_eq!(split_q8_1_blocks(&[], 0), (vec![], vec![], vec![]));
    }

    #[test]
    fn split_ignores_trailing_bytes() {
        let (mut raw, expected) = build_q4_0_blocks(&[(1.0, [9u8; 32])]);
        raw.extend_from_slice(&[0xAB; 7]);
        let (scales, nibbles) = split_q4_0_blocks(&raw, 1);
        assert_eq!(scales.len(), 1);
        assert_eq!(nibbles.len(), 4);
        assert_eq!(dequantize_q4_0_reference(&scales, &nibbles), expected);
    }

    #[test]
    fn random_q4_0_blocks_round_trip_through_split_then_dequant() {
        // Deterministic LCG so failures are reproducible.
        let mut state: u32 = 0x1234_5678;
        let mut next = || {
            state = state.wrapping_mul(1_103_515_245).wrapping_add(12345);
            state
        };
        let num_blocks = 64;
        let mut raw = Vec::with_capacity(num_blocks * Q4_0_BLOCK_BYTES);
        let mut expected = Vec::with_capacity(num_blocks * 32);
        for _ in 0..num_blocks {
            let scale_f32 = (next() as f32 / u32::MAX as f32) * 4.0 - 2.0;
            let scale = half::f16::from_f32(scale_f32);
            raw.extend_from_slice(&scale.to_bits().to_le_bytes());
            let s = scale.to_f32();
            let mut nibs = [0u8; 32];
            for n in nibs.iter_mut() {
                *n = (next() & 0xF) as u8;
            }
            push_nibble_pairs(&mut raw, &nibs);
            expected.extend(nibs.iter().map(|&n| s * (n as f32 - 8.0)));
        }
        let (scales, nibbles) = split_q4_0_blocks(&raw, num_blocks);
        let got = dequantize_q4_0_reference(&scales, &nibbles);
        assert_eq!(got.len(), expected.len());
        for (a, b) in got.iter().zip(expected.iter()) {
            assert_relative_eq!(*a, *b, epsilon = 1e-3);
        }
    }

    /// Serializes `Q5_0` blocks given `(scale, qh, low nibbles)` triples.
    fn build_q5_0_blocks(blocks: &[(f32, u32, [u8; 32])]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(blocks.len() * Q5_0_BLOCK_BYTES);
        for &(scale, qh, vals) in blocks {
            raw.extend_from_slice(&half::f16::from_f32(scale).to_bits().to_le_bytes());
            raw.extend_from_slice(&qh.to_le_bytes());
            push_nibble_pairs(&mut raw, &vals);
        }
        raw
    }

    #[test]
    fn split_q5_0_recovers_scales_qh_and_nibbles() {
        let (vals, qh) = ramp_5bit();
        let raw = build_q5_0_blocks(&[(1.5, qh, vals)]);
        let (scales, qh_out, nibs) = split_q5_0_blocks(&raw, 1);
        assert_eq!(scales.len(), 1);
        assert_eq!(qh_out.len(), 1);
        assert_eq!(qh_out[0], qh);
        assert_eq!(nibs.len(), 4);
        let dequant = dequantize_q5_0_reference(&scales, &qh_out, &nibs);
        let s = half::f16::from_f32(1.5).to_f32();
        for (i, &v) in dequant.iter().enumerate() {
            let val5 = i as f32;
            let expected = s * (val5 - 16.0);
            assert_relative_eq!(v, expected, epsilon = 1e-3);
        }
    }

    /// Serializes `Q5_1` blocks given `(scale, min, qh, low nibbles)` tuples.
    fn build_q5_1_blocks(blocks: &[(f32, f32, u32, [u8; 32])]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(blocks.len() * Q5_1_BLOCK_BYTES);
        for &(scale, min_v, qh, vals) in blocks {
            raw.extend_from_slice(&half::f16::from_f32(scale).to_bits().to_le_bytes());
            raw.extend_from_slice(&half::f16::from_f32(min_v).to_bits().to_le_bytes());
            raw.extend_from_slice(&qh.to_le_bytes());
            push_nibble_pairs(&mut raw, &vals);
        }
        raw
    }

    #[test]
    fn split_q5_1_recovers_scales_mins_qh_and_nibbles() {
        let (vals, qh) = ramp_5bit();
        let raw = build_q5_1_blocks(&[(0.5, 2.0, qh, vals)]);
        let (scales, mins, qh_out, nibs) = split_q5_1_blocks(&raw, 1);
        assert_eq!(scales.len(), 1);
        assert_eq!(mins.len(), 1);
        assert_eq!(qh_out, vec![qh]);
        assert_eq!(nibs.len(), 4);
        let dequant = dequantize_q5_1_reference(&scales, &mins, &qh_out, &nibs);
        let s = half::f16::from_f32(0.5).to_f32();
        let m = half::f16::from_f32(2.0).to_f32();
        for (i, &v) in dequant.iter().enumerate() {
            let expected = s * (i as f32) + m;
            assert_relative_eq!(v, expected, epsilon = 1e-3);
        }
    }

    /// Serializes `Q8_1` blocks given `(scale, second header value, quants)` triples.
    fn build_q8_1_blocks(blocks: &[(f32, f32, [i8; 32])]) -> Vec<u8> {
        let mut raw = Vec::with_capacity(blocks.len() * Q8_1_BLOCK_BYTES);
        for &(scale, min_v, vals) in blocks {
            raw.extend_from_slice(&scale.to_le_bytes());
            raw.extend_from_slice(&min_v.to_le_bytes());
            raw.extend(vals.iter().map(|&v| v as u8));
        }
        raw
    }

    #[test]
    fn split_q8_1_recovers_scales_mins_and_signed_bytes() {
        let vals: [i8; 32] = std::array::from_fn(|i| ((i as i32) - 16) as i8);
        let raw = build_q8_1_blocks(&[(1.5, 3.0, vals)]);
        let (scales, mins, bytes) = split_q8_1_blocks(&raw, 1);
        assert_eq!(scales.len(), 1);
        assert_eq!(mins.len(), 1);
        assert_relative_eq!(scales[0], 1.5, epsilon = 1e-9);
        assert_relative_eq!(mins[0], 3.0, epsilon = 1e-9);
        assert_eq!(bytes.len(), 8);
        let dequant = dequantize_q8_1_reference(&scales, &mins, &bytes);
        for (i, &v) in dequant.iter().enumerate() {
            let expected = 1.5 * (vals[i] as f32) + 3.0;
            assert_relative_eq!(v, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn block_kind_metadata_constants_extended() {
        assert_eq!(GgufBlockKind::Q5_0.block_bytes(), 22);
        assert_eq!(GgufBlockKind::Q5_1.block_bytes(), 24);
        assert_eq!(GgufBlockKind::Q8_1.block_bytes(), 40);
        assert_eq!(GgufBlockKind::Q5_0.block_elements(), 32);
        assert_eq!(GgufBlockKind::Q5_1.block_elements(), 32);
        assert_eq!(GgufBlockKind::Q8_1.block_elements(), 32);
    }
}
#[cube(launch_unchecked)]
/// Copies `logits` into `out`, replacing every position whose `mask` word is `0`
/// with `F::min_value()`. Any non-zero mask word keeps the logit.
///
/// One unit per element. Only `out` is range-checked here, so `logits` and `mask`
/// must be at least `out.len()` long. NOTE(review): `F::min_value()` is presumably
/// the most negative finite value rather than `-inf`. Confirm that this is what
/// the downstream softmax/sampling expects, especially when every token is masked.
pub fn kernel_apply_token_mask<F: Float>(logits: &Array<F>, mask: &Array<u32>, out: &mut Array<F>) {
    // Units past the end of `out` do nothing (the launch grid may be rounded up).
    if ABSOLUTE_POS < out.len() {
        let i = ABSOLUTE_POS;
        let allow = mask[i];
        let v = logits[i];
        if allow != 0u32 {
            out[i] = v;
        } else {
            // Disallowed token: push its logit to the lowest representable value.
            out[i] = F::min_value();
        }
    }
}
/// Uploads `logits` and `mask`, launches `kernel_apply_token_mask`, and returns
/// the handle of a device buffer with `logits.len()` `f32` values. In that
/// buffer, each position whose mask word is `0` holds `f32::min_value()`
/// instead of the logit.
///
/// # Panics
///
/// If `logits.len() != mask.len()` or the length does not fit in `u32`. The length
/// check runs in release builds too, because the kernel reads `mask` unchecked.
pub fn apply_token_mask_to_gpu<R: Runtime>(
    client: &ComputeClient<R>,
    logits: &[f32],
    mask: &[u32],
) -> cubecl::server::Handle {
    assert_eq!(
        logits.len(),
        mask.len(),
        "apply_token_mask_to_gpu: logits and mask must have the same length"
    );
    let n = logits.len();
    let launch_len =
        u32::try_from(n).expect("apply_token_mask_to_gpu: logits length exceeds u32::MAX");
    let logits_handle = client.create_from_slice(f32::as_bytes(logits));
    // SAFETY: `u8` has alignment 1 and every bit pattern is valid. The length is
    // exactly `size_of_val(mask)`, so the byte view covers only the borrowed slice.
    let mask_handle = client.create_from_slice(unsafe {
        std::slice::from_raw_parts(mask.as_ptr().cast::<u8>(), std::mem::size_of_val(mask))
    });
    let out_handle = client.empty(std::mem::size_of_val(logits));
    let (count, dim) = crate::elementwise_launch_dims(launch_len);
    // SAFETY: `logits`, `mask` and `out` all hold exactly `n` elements (asserted
    // above), so every index `i < n` the kernel touches is in bounds.
    unsafe {
        kernel_apply_token_mask::launch_unchecked::<f32, R>(
            client,
            count,
            dim,
            ArrayArg::from_raw_parts(logits_handle, n),
            ArrayArg::from_raw_parts(mask_handle, n),
            ArrayArg::from_raw_parts(out_handle.clone(), n),
        );
    }
    out_handle
}
#[cfg(all(test, feature = "cuda"))]
mod cuda_tests {
    //! GPU-vs-CPU parity tests for the dequantization and token-mask kernels.
    //!
    //! Each dequantization test builds a small synthetic block buffer from a
    //! deterministic LCG, splits it with the host-side `split_*` helpers, and
    //! checks that the CUDA kernel output matches the CPU reference element by
    //! element. All tests run on CUDA device index 0.

    use super::*;
    use approx::{assert_relative_eq, relative_eq};
    use cubecl_cuda::{CudaDevice, CudaRuntime};

    /// Tolerance passed as `epsilon` to approx's relative comparison when
    /// checking kernel output against the CPU reference.
    const DEQUANT_EPSILON: f32 = 1e-5;

    /// Returns a compute client bound to CUDA device 0.
    fn cuda_client() -> ComputeClient<CudaRuntime> {
        let device = CudaDevice { index: 0 };
        CudaRuntime::client(&device)
    }

    /// Reads a device buffer back to the host and reinterprets it as `f32`s.
    fn read_f32(client: &ComputeClient<CudaRuntime>, handle: cubecl::server::Handle) -> Vec<f32> {
        let bytes = client.read_one(handle).expect("CUDA read_one failed");
        f32::from_bytes(&bytes).to_vec()
    }

    /// Asserts that `got` and `expected` have the same length and agree
    /// element-wise, naming the first mismatching index on failure.
    fn assert_matches_reference(got: &[f32], expected: &[f32]) {
        assert_eq!(got.len(), expected.len(), "output length differs from reference");
        for (i, (&a, &b)) in got.iter().zip(expected).enumerate() {
            assert!(
                relative_eq!(a, b, epsilon = DEQUANT_EPSILON),
                "mismatch at element {i}: gpu = {a}, reference = {b}"
            );
        }
    }

    /// Deterministic linear congruential generator (mod 2^32) for test data.
    ///
    /// The low-order bits of a power-of-two-modulus LCG have very short
    /// periods (bit k repeats every 2^(k+1) draws), so small values are taken
    /// from the high bits of the state rather than by masking the low bits.
    struct Lcg {
        state: u32,
        mul: u32,
        inc: u32,
    }

    impl Lcg {
        fn new(seed: u32, mul: u32, inc: u32) -> Self {
            Self { state: seed, mul, inc }
        }

        /// Advances the generator and returns the new 32-bit state.
        fn next_u32(&mut self) -> u32 {
            self.state = self.state.wrapping_mul(self.mul).wrapping_add(self.inc);
            self.state
        }

        /// Returns the next state scaled into roughly `[0, 1]`.
        fn unit_f32(&mut self) -> f32 {
            self.next_u32() as f32 / u32::MAX as f32
        }

        /// Returns a value in `0..=15` taken from the top four bits of the state.
        fn nibble(&mut self) -> u8 {
            (self.next_u32() >> 28) as u8
        }

        /// Returns a byte taken from the top eight bits of the state.
        fn byte(&mut self) -> u8 {
            (self.next_u32() >> 24) as u8
        }
    }

    /// Appends the little-endian bit pattern of `value` rounded to f16.
    fn push_f16(raw: &mut Vec<u8>, value: f32) {
        raw.extend_from_slice(&half::f16::from_f32(value).to_bits().to_le_bytes());
    }

    /// Appends 16 bytes, each packing two random 4-bit values:
    /// `lo` in bits 0-3 and `hi` in bits 4-7.
    fn push_packed_nibbles(raw: &mut Vec<u8>, rng: &mut Lcg) {
        for _ in 0..16 {
            let lo = rng.nibble();
            let hi = rng.nibble();
            raw.push((hi << 4) | lo);
        }
    }

    #[test]
    fn q4_0_kernel_runs_on_gpu_and_matches_reference() {
        let client = cuda_client();
        let mut rng = Lcg::new(0xCAFE_BABE, 1_103_515_245, 12345);
        let num_blocks = 4;
        let mut raw = Vec::with_capacity(num_blocks * Q4_0_BLOCK_BYTES);
        for _ in 0..num_blocks {
            push_f16(&mut raw, rng.unit_f32() * 4.0 - 2.0);
            push_packed_nibbles(&mut raw, &mut rng);
        }
        let (scales, nibbles) = split_q4_0_blocks(&raw, num_blocks);
        let expected = dequantize_q4_0_reference(&scales, &nibbles);
        let handle =
            dequantize_q4_0_to_gpu(&client, &scales, &nibbles, num_blocks * GGUF_BLOCK_SIZE);
        assert_matches_reference(&read_f32(&client, handle), &expected);
    }

    #[test]
    fn q4_1_kernel_runs_on_gpu_and_matches_reference() {
        let client = cuda_client();
        let mut rng = Lcg::new(0x4242_4242, 48271, 7);
        let num_blocks = 3;
        let mut raw = Vec::with_capacity(num_blocks * Q4_1_BLOCK_BYTES);
        for _ in 0..num_blocks {
            let scale = rng.unit_f32() * 2.0;
            let min = rng.unit_f32() * 4.0 - 2.0;
            push_f16(&mut raw, scale);
            push_f16(&mut raw, min);
            push_packed_nibbles(&mut raw, &mut rng);
        }
        let (scales, mins, nibbles) = split_q4_1_blocks(&raw, num_blocks);
        let expected = dequantize_q4_1_reference(&scales, &mins, &nibbles);
        let handle = dequantize_q4_1_to_gpu(
            &client,
            &scales,
            &mins,
            &nibbles,
            num_blocks * GGUF_BLOCK_SIZE,
        );
        assert_matches_reference(&read_f32(&client, handle), &expected);
    }

    #[test]
    fn q8_0_kernel_runs_on_gpu_and_matches_reference() {
        let client = cuda_client();
        let mut rng = Lcg::new(0xDEAD_BEEF, 214013, 2531011);
        let num_blocks = 5;
        let mut raw = Vec::with_capacity(num_blocks * Q8_0_BLOCK_BYTES);
        for _ in 0..num_blocks {
            push_f16(&mut raw, rng.unit_f32() * 0.5);
            raw.extend((0..GGUF_BLOCK_SIZE).map(|_| rng.byte()));
        }
        let (scales, bytes) = split_q8_0_blocks(&raw, num_blocks);
        let expected = dequantize_q8_0_reference(&scales, &bytes);
        let handle =
            dequantize_q8_0_to_gpu(&client, &scales, &bytes, num_blocks * GGUF_BLOCK_SIZE);
        assert_matches_reference(&read_f32(&client, handle), &expected);
    }

    #[test]
    fn q5_0_kernel_runs_on_gpu_and_matches_reference() {
        let client = cuda_client();
        let mut rng = Lcg::new(0xFACE_FEED, 1664525, 1013904223);
        let num_blocks = 4;
        let mut raw = Vec::with_capacity(num_blocks * Q5_0_BLOCK_BYTES);
        for _ in 0..num_blocks {
            push_f16(&mut raw, rng.unit_f32() * 2.0);
            // Fifth (high) bits for all 32 quants, stored as one little-endian u32.
            raw.extend_from_slice(&rng.next_u32().to_le_bytes());
            push_packed_nibbles(&mut raw, &mut rng);
        }
        let (scales, qh, nibs) = split_q5_0_blocks(&raw, num_blocks);
        let expected = dequantize_q5_0_reference(&scales, &qh, &nibs);
        let handle =
            dequantize_q5_0_to_gpu(&client, &scales, &qh, &nibs, num_blocks * GGUF_BLOCK_SIZE);
        assert_matches_reference(&read_f32(&client, handle), &expected);
    }

    #[test]
    fn q5_1_kernel_runs_on_gpu_and_matches_reference() {
        let client = cuda_client();
        let mut rng = Lcg::new(0xBABE_CAFE, 22695477, 1);
        let num_blocks = 4;
        let mut raw = Vec::with_capacity(num_blocks * Q5_1_BLOCK_BYTES);
        for _ in 0..num_blocks {
            let scale = rng.unit_f32() * 2.0;
            let min = rng.unit_f32() * 4.0 - 2.0;
            push_f16(&mut raw, scale);
            push_f16(&mut raw, min);
            raw.extend_from_slice(&rng.next_u32().to_le_bytes());
            push_packed_nibbles(&mut raw, &mut rng);
        }
        let (scales, mins, qh, nibs) = split_q5_1_blocks(&raw, num_blocks);
        let expected = dequantize_q5_1_reference(&scales, &mins, &qh, &nibs);
        let handle = dequantize_q5_1_to_gpu(
            &client,
            &scales,
            &mins,
            &qh,
            &nibs,
            num_blocks * GGUF_BLOCK_SIZE,
        );
        assert_matches_reference(&read_f32(&client, handle), &expected);
    }

    #[test]
    fn token_mask_kernel_runs_on_gpu_and_replaces_disallowed_with_min_value() {
        let client = cuda_client();
        let logits: Vec<f32> = (0..16).map(|i| (i as f32) - 7.5).collect();
        // Even indices allowed (1), odd indices disallowed (0).
        let mask: Vec<u32> = (0..16).map(|i| u32::from(i % 2 == 0)).collect();
        let handle = apply_token_mask_to_gpu(&client, &logits, &mask);
        let got = read_f32(&client, handle);
        assert_eq!(got.len(), logits.len());
        for (i, (&v, (&logit, &allowed))) in got.iter().zip(logits.iter().zip(&mask)).enumerate() {
            if allowed != 0 {
                assert_relative_eq!(v, logit, epsilon = 1e-6);
            } else {
                assert!(v <= -1.0e30, "expected sentinel at idx {i}, got {v}");
            }
        }
    }

    #[test]
    fn q8_1_kernel_runs_on_gpu_and_matches_reference() {
        let client = cuda_client();
        let mut rng = Lcg::new(0x1357_9BDF, 1664525, 1013904223);
        let num_blocks = 3;
        let mut raw = Vec::with_capacity(num_blocks * Q8_1_BLOCK_BYTES);
        for _ in 0..num_blocks {
            // This block layout stores scale and min as full f32 values.
            let scale = rng.unit_f32() * 0.5;
            let min = rng.unit_f32() * 4.0 - 2.0;
            raw.extend_from_slice(&scale.to_le_bytes());
            raw.extend_from_slice(&min.to_le_bytes());
            raw.extend((0..GGUF_BLOCK_SIZE).map(|_| rng.byte()));
        }
        let (scales, mins, bytes) = split_q8_1_blocks(&raw, num_blocks);
        let expected = dequantize_q8_1_reference(&scales, &mins, &bytes);
        let handle = dequantize_q8_1_to_gpu(
            &client,
            &scales,
            &mins,
            &bytes,
            num_blocks * GGUF_BLOCK_SIZE,
        );
        assert_matches_reference(&read_f32(&client, handle), &expected);
    }
}