pub use trueno_quant::{
Q4_K_BLOCK_BYTES, Q4_K_BLOCK_SIZE, Q5_K_BLOCK_BYTES, Q5_K_BLOCK_SIZE, Q6_K_BLOCK_BYTES,
Q6_K_BLOCK_SIZE,
};
/// Memory layout of a tensor.
///
/// `RowMajor` is the only variant, so no other layout can be expressed with
/// this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorLayout {
/// Row-major layout.
RowMajor,
}
/// The layout constant exported by this module; always [`TensorLayout::RowMajor`].
// NOTE(review): presumably the layout every kernel in the stack assumes — confirm against callers.
pub const STACK_LAYOUT: TensorLayout = TensorLayout::RowMajor;
/// Block geometry of one quantized weight format.
///
/// A row of `cols` elements is stored as whole blocks: each block covers
/// `block_size` elements and occupies `block_bytes` bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QuantFormat {
/// Human-readable format name (e.g. `"Q4_K"`); appears in validation error messages.
pub name: &'static str,
/// Number of elements covered by one block.
pub block_size: usize,
/// Number of bytes one block occupies.
pub block_bytes: usize,
/// Numeric type ID that `format_by_ggml_type` matches against.
pub ggml_type_id: u32,
}
/// Q4_K descriptor: block geometry comes from the re-exported `trueno_quant`
/// constants; type ID 12.
pub const Q4_K: QuantFormat = QuantFormat {
name: "Q4_K",
block_size: Q4_K_BLOCK_SIZE,
block_bytes: Q4_K_BLOCK_BYTES,
ggml_type_id: 12,
};
/// Q5_K descriptor: block geometry comes from the re-exported `trueno_quant`
/// constants; type ID 13.
pub const Q5_K: QuantFormat = QuantFormat {
name: "Q5_K",
block_size: Q5_K_BLOCK_SIZE,
block_bytes: Q5_K_BLOCK_BYTES,
ggml_type_id: 13,
};
/// Q6_K descriptor: block geometry comes from the re-exported `trueno_quant`
/// constants; type ID 14.
pub const Q6_K: QuantFormat = QuantFormat {
name: "Q6_K",
block_size: Q6_K_BLOCK_SIZE,
block_bytes: Q6_K_BLOCK_BYTES,
ggml_type_id: 14,
};
/// Q8_0 descriptor: 32 elements per 34-byte block; type ID 8.
pub const Q8_0: QuantFormat =
QuantFormat { name: "Q8_0", block_size: 32, block_bytes: 34, ggml_type_id: 8 };
/// Q5_0 descriptor: 32 elements per 22-byte block; type ID 6.
pub const Q5_0: QuantFormat =
QuantFormat { name: "Q5_0", block_size: 32, block_bytes: 22, ggml_type_id: 6 };
/// Q4_0 descriptor: 32 elements per 18-byte block; type ID 2.
pub const Q4_0: QuantFormat =
QuantFormat { name: "Q4_0", block_size: 32, block_bytes: 18, ggml_type_id: 2 };
/// Q4_1 descriptor: 32 elements per 20-byte block; type ID 3.
pub const Q4_1: QuantFormat =
QuantFormat { name: "Q4_1", block_size: 32, block_bytes: 20, ggml_type_id: 3 };
/// Every format descriptor defined in this module.
///
/// `format_by_ggml_type` searches this slice in order, so a format missing
/// from the list cannot be found by type ID.
pub const ALL_FORMATS: &[QuantFormat] = &[Q4_0, Q4_1, Q5_0, Q8_0, Q4_K, Q5_K, Q6_K];
/// Looks up the format descriptor registered under a numeric type ID.
///
/// Scans `ALL_FORMATS` front to back and returns the first entry whose
/// `ggml_type_id` equals `type_id`, or `None` when no entry matches.
#[must_use]
pub fn format_by_ggml_type(type_id: u32) -> Option<&'static QuantFormat> {
    for candidate in ALL_FORMATS {
        if candidate.ggml_type_id == type_id {
            return Some(candidate);
        }
    }
    None
}
/// Error returned when a weight or shape check in this module fails.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WeightBufferError {
    /// Name of the weight the failed check was run for.
    pub weight_name: String,
    /// Human-readable description of what did not match.
    pub reason: String,
}

impl std::fmt::Display for WeightBufferError {
    /// Renders as `Kernel contract violation for '<weight_name>': <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit the message piece by piece; like `write!`, this writes straight
        // to the formatter, so caller-supplied width/fill flags are not applied.
        f.write_str("Kernel contract violation for '")?;
        f.write_str(&self.weight_name)?;
        f.write_str("': ")?;
        f.write_str(&self.reason)
    }
}

impl std::error::Error for WeightBufferError {}
impl QuantFormat {
    /// Number of blocks needed to hold one row of `cols` elements, rounding up
    /// so a partially filled trailing block still counts as a whole block.
    ///
    /// Computed as quotient plus remainder check instead of
    /// `(cols + block_size - 1) / block_size`, which overflows for `cols`
    /// close to `usize::MAX`.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `block_size` is 0.
    const fn blocks_per_row(&self, cols: usize) -> usize {
        let whole_blocks = cols / self.block_size;
        if cols % self.block_size == 0 {
            whole_blocks
        } else {
            // Cannot overflow: a non-zero remainder implies block_size >= 2,
            // so whole_blocks <= usize::MAX / 2.
            whole_blocks + 1
        }
    }

    /// Exact byte size of a `[rows, cols]` tensor in this format, or `None`
    /// when `block_size` is 0 or the size does not fit in `usize`.
    const fn checked_expected_bytes(&self, rows: usize, cols: usize) -> Option<usize> {
        if self.block_size == 0 {
            return None;
        }
        match rows.checked_mul(self.blocks_per_row(cols)) {
            Some(total_blocks) => total_blocks.checked_mul(self.block_bytes),
            None => None,
        }
    }

    /// Byte size of a `[rows, cols]` tensor stored in this format.
    ///
    /// Each row is padded up to a whole number of blocks, so the result is
    /// `rows * ceil(cols / block_size) * block_bytes`. If that product does not
    /// fit in `usize` the result saturates at `usize::MAX` instead of wrapping
    /// to a small, plausible-looking size.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `block_size` is 0.
    #[must_use]
    pub const fn expected_bytes(&self, rows: usize, cols: usize) -> usize {
        rows.saturating_mul(self.blocks_per_row(cols)).saturating_mul(self.block_bytes)
    }

    /// Checks that a buffer of `actual_bytes` bytes is exactly the size a
    /// `[rows, cols]` tensor in this format requires.
    ///
    /// # Errors
    ///
    /// Returns a [`WeightBufferError`] naming `weight_name` when
    /// - `actual_bytes` differs from the expected size, or
    /// - the expected size cannot be computed (`block_size` is 0, or the size
    ///   overflows `usize`). Such dimensions are rejected rather than allowed
    ///   to wrap around and accidentally match a small buffer.
    pub fn validate_buffer(
        &self,
        weight_name: &str,
        actual_bytes: usize,
        rows: usize,
        cols: usize,
    ) -> Result<(), WeightBufferError> {
        let expected = self.checked_expected_bytes(rows, cols).ok_or_else(|| WeightBufferError {
            weight_name: weight_name.to_string(),
            reason: format!(
                "{} buffer size mismatch: got {} bytes, but the expected size for [{}, {}] \
                 is not representable ({} elements/block, {} bytes/block)",
                self.name, actual_bytes, rows, cols, self.block_size, self.block_bytes,
            ),
        })?;

        if actual_bytes == expected {
            return Ok(());
        }

        Err(WeightBufferError {
            weight_name: weight_name.to_string(),
            reason: format!(
                "{} buffer size mismatch: got {} bytes, expected {} bytes \
                 for [{}, {}] ({} blocks/row * {} bytes/block * {} rows)",
                self.name,
                actual_bytes,
                expected,
                rows,
                cols,
                // block_size is known to be non-zero here: the checked size above succeeded.
                self.blocks_per_row(cols),
                self.block_bytes,
                rows,
            ),
        })
    }
}
/// Validates a quantized weight buffer given only its numeric type ID.
///
/// Resolves `ggml_type` through `format_by_ggml_type` and then delegates the
/// size check to that format's `validate_buffer`.
///
/// # Errors
///
/// Returns a [`WeightBufferError`] when the type ID is not registered, or
/// whatever error the format's `validate_buffer` reports.
pub fn validate_weight_buffer(
    weight_name: &str,
    ggml_type: u32,
    actual_bytes: usize,
    rows: usize,
    cols: usize,
) -> Result<(), WeightBufferError> {
    match format_by_ggml_type(ggml_type) {
        Some(quant) => quant.validate_buffer(weight_name, actual_bytes, rows, cols),
        None => Err(WeightBufferError {
            weight_name: weight_name.to_string(),
            reason: format!("Unknown GGML quantization type ID: {ggml_type}"),
        }),
    }
}
/// Checks that an F32 buffer holds exactly `rows * cols` elements.
///
/// # Errors
///
/// Returns a [`WeightBufferError`] naming `weight_name` when
/// - `actual_elements` differs from `rows * cols`, or
/// - `rows * cols` overflows `usize`. Such a shape is rejected rather than
///   allowed to wrap around and accidentally match a small buffer.
pub fn validate_f32_buffer(
    weight_name: &str,
    actual_elements: usize,
    rows: usize,
    cols: usize,
) -> Result<(), WeightBufferError> {
    let expected = rows.checked_mul(cols).ok_or_else(|| WeightBufferError {
        weight_name: weight_name.to_string(),
        reason: format!(
            "F32 element count mismatch: got {actual_elements}, but the element count \
             for [{rows}, {cols}] overflows usize"
        ),
    })?;

    if actual_elements == expected {
        return Ok(());
    }

    Err(WeightBufferError {
        weight_name: weight_name.to_string(),
        reason: format!(
            "F32 element count mismatch: got {actual_elements}, expected {expected} \
             for [{rows}, {cols}]"
        ),
    })
}
/// Checks that a weight matrix, input vector and output vector have
/// compatible lengths for a matrix-vector product.
///
/// The input must have `weight_cols` elements and the output `weight_rows`
/// elements. The input dimension is checked first, so when both are wrong only
/// the input mismatch is reported.
///
/// # Errors
///
/// Returns a [`WeightBufferError`] naming `weight_name` that describes the
/// first dimension that does not match.
pub fn validate_gemv_shapes(
    weight_name: &str,
    weight_rows: usize,
    weight_cols: usize,
    input_len: usize,
    output_len: usize,
) -> Result<(), WeightBufferError> {
    // Pick the description of the first violated constraint, if any.
    let reason = if weight_cols != input_len {
        format!(
            "GEMV input dimension mismatch: weight has {weight_cols} cols \
             but input has {input_len} elements"
        )
    } else if weight_rows != output_len {
        format!(
            "GEMV output dimension mismatch: weight has {weight_rows} rows \
             but output has {output_len} elements"
        )
    } else {
        return Ok(());
    };

    Err(WeightBufferError { weight_name: weight_name.to_string(), reason })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Side length of the square weight matrix used by most tests.
    const DIM: usize = 4096;
    /// Weight name passed to the buffer validators.
    const WEIGHT: &str = "test.weight";

    // ---- expected_bytes ----------------------------------------------------

    #[test]
    fn test_q4k_expected_bytes() {
        let bytes = Q4_K.expected_bytes(DIM, DIM);
        assert_eq!(bytes, 9_437_184);
    }

    #[test]
    fn test_q6k_expected_bytes() {
        let bytes = Q6_K.expected_bytes(DIM, DIM);
        assert_eq!(bytes, DIM * 16 * 210);
    }

    #[test]
    fn test_q8_0_expected_bytes() {
        let bytes = Q8_0.expected_bytes(DIM, DIM);
        assert_eq!(bytes, DIM * 128 * 34);
    }

    // ---- quantized buffer validation -----------------------------------------

    #[test]
    fn test_validate_buffer_ok() {
        let exact = Q4_K.expected_bytes(DIM, DIM);
        let outcome = Q4_K.validate_buffer(WEIGHT, exact, DIM, DIM);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_buffer_wrong_size() {
        let WeightBufferError { reason, .. } =
            Q4_K.validate_buffer(WEIGHT, 1000, DIM, DIM).unwrap_err();
        assert!(reason.contains("buffer size mismatch"));
    }

    #[test]
    fn test_validate_weight_buffer_unknown_type() {
        let WeightBufferError { reason, .. } =
            validate_weight_buffer(WEIGHT, 99, 1000, DIM, DIM).unwrap_err();
        assert!(reason.contains("Unknown GGML"));
    }

    // ---- F32 buffer validation -------------------------------------------------

    #[test]
    fn test_validate_f32_buffer_ok() {
        let outcome = validate_f32_buffer(WEIGHT, DIM * DIM, DIM, DIM);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_f32_buffer_mismatch() {
        let WeightBufferError { reason, .. } =
            validate_f32_buffer(WEIGHT, 100, DIM, DIM).unwrap_err();
        assert!(reason.contains("element count mismatch"));
    }

    // ---- GEMV shape validation -------------------------------------------------

    #[test]
    fn test_validate_gemv_shapes_ok() {
        let outcome = validate_gemv_shapes("test", DIM, DIM, DIM, DIM);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_gemv_shapes_input_mismatch() {
        let WeightBufferError { reason, .. } =
            validate_gemv_shapes("test", DIM, DIM, 2048, DIM).unwrap_err();
        assert!(reason.contains("input dimension mismatch"));
    }

    #[test]
    fn test_validate_gemv_shapes_output_mismatch() {
        let WeightBufferError { reason, .. } =
            validate_gemv_shapes("test", DIM, DIM, DIM, 2048).unwrap_err();
        assert!(reason.contains("output dimension mismatch"));
    }

    // ---- lookup and constants --------------------------------------------------

    #[test]
    fn test_format_lookup_all_types() {
        // (type ID, expected format name) for every registered format.
        let known: &[(u32, &str)] = &[
            (2, "Q4_0"),
            (3, "Q4_1"),
            (6, "Q5_0"),
            (8, "Q8_0"),
            (12, "Q4_K"),
            (13, "Q5_K"),
            (14, "Q6_K"),
        ];
        for &(type_id, expected_name) in known {
            assert_eq!(format_by_ggml_type(type_id).unwrap().name, expected_name);
        }
        assert!(format_by_ggml_type(99).is_none());
    }

    #[test]
    fn test_stack_layout_is_row_major() {
        assert_eq!(STACK_LAYOUT, TensorLayout::RowMajor);
    }

    #[test]
    fn test_non_aligned_cols() {
        // 100 columns need one block per row; 300 columns need two.
        let one_block_rows = Q4_K.expected_bytes(10, 100);
        let two_block_rows = Q4_K.expected_bytes(10, 300);
        assert_eq!(one_block_rows, 10 * 144);
        assert_eq!(two_block_rows, 10 * 2 * 144);
    }
}