1pub use trueno_quant::{
23 Q4_K_BLOCK_BYTES, Q4_K_BLOCK_SIZE, Q5_K_BLOCK_BYTES, Q5_K_BLOCK_SIZE, Q6_K_BLOCK_BYTES,
24 Q6_K_BLOCK_SIZE,
25};
26
/// Memory layout of weight tensors handed to the kernels.
///
/// Only row-major layout is represented. The size formulas in this module
/// (`QuantFormat::expected_bytes`) count quantization blocks per row, i.e. they
/// assume every row is quantized independently and blocks never span rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorLayout {
    /// A `[rows, cols]` weight is stored row after row.
    RowMajor,
}
45
/// The single tensor layout assumed by this module's validation helpers.
pub const STACK_LAYOUT: TensorLayout = TensorLayout::RowMajor;
48
/// Static description of a block-quantized weight format.
///
/// Weights are split into blocks of `block_size` elements, each encoded in
/// `block_bytes` bytes. Used to compute and validate expected buffer sizes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QuantFormat {
    /// Human-readable format name (e.g. `"Q4_K"`), used in error messages.
    pub name: &'static str,
    /// Number of weight elements packed into one quantization block.
    pub block_size: usize,
    /// Size in bytes of one encoded block.
    pub block_bytes: usize,
    /// GGML type identifier; matched by `format_by_ggml_type`.
    pub ggml_type_id: u32,
}
68
/// GGML `Q4_K` super-block format (GGML type 12).
///
/// Block geometry is taken from `trueno_quant` so both crates agree on it.
pub const Q4_K: QuantFormat = QuantFormat {
    name: "Q4_K",
    block_size: Q4_K_BLOCK_SIZE,
    block_bytes: Q4_K_BLOCK_BYTES,
    ggml_type_id: 12,
};

/// GGML `Q5_K` super-block format (GGML type 13).
///
/// Block geometry is taken from `trueno_quant` so both crates agree on it.
pub const Q5_K: QuantFormat = QuantFormat {
    name: "Q5_K",
    block_size: Q5_K_BLOCK_SIZE,
    block_bytes: Q5_K_BLOCK_BYTES,
    ggml_type_id: 13,
};

/// GGML `Q6_K` super-block format (GGML type 14).
///
/// Block geometry is taken from `trueno_quant` so both crates agree on it.
pub const Q6_K: QuantFormat = QuantFormat {
    name: "Q6_K",
    block_size: Q6_K_BLOCK_SIZE,
    block_bytes: Q6_K_BLOCK_BYTES,
    ggml_type_id: 14,
};
92
93pub const Q8_0: QuantFormat =
95 QuantFormat { name: "Q8_0", block_size: 32, block_bytes: 34, ggml_type_id: 8 };
96
97pub const Q5_0: QuantFormat =
99 QuantFormat { name: "Q5_0", block_size: 32, block_bytes: 22, ggml_type_id: 6 };
100
101pub const Q4_0: QuantFormat =
103 QuantFormat { name: "Q4_0", block_size: 32, block_bytes: 18, ggml_type_id: 2 };
104
105pub const Q4_1: QuantFormat =
107 QuantFormat { name: "Q4_1", block_size: 32, block_bytes: 20, ggml_type_id: 3 };
108
/// Every quantization format this module knows about.
///
/// `format_by_ggml_type` returns the first entry with a matching
/// `ggml_type_id`, so each entry's type ID should be unique.
pub const ALL_FORMATS: &[QuantFormat] = &[Q4_0, Q4_1, Q5_0, Q8_0, Q4_K, Q5_K, Q6_K];
111
112#[must_use]
114pub fn format_by_ggml_type(type_id: u32) -> Option<&'static QuantFormat> {
115 ALL_FORMATS.iter().find(|f| f.ggml_type_id == type_id)
116}
117
/// Error returned when a weight buffer or GEMV call violates the size/shape
/// contract checked by this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WeightBufferError {
    /// Name of the offending weight, as supplied by the caller.
    pub weight_name: String,
    /// Human-readable description of the violation.
    pub reason: String,
}
130
131impl std::fmt::Display for WeightBufferError {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 write!(f, "Kernel contract violation for '{}': {}", self.weight_name, self.reason)
134 }
135}
136
137impl std::error::Error for WeightBufferError {}
138
139impl QuantFormat {
140 #[must_use]
151 pub const fn expected_bytes(&self, rows: usize, cols: usize) -> usize {
152 let blocks_per_row = (cols + self.block_size - 1) / self.block_size;
153 rows * blocks_per_row * self.block_bytes
154 }
155
156 pub fn validate_buffer(
163 &self,
164 weight_name: &str,
165 actual_bytes: usize,
166 rows: usize,
167 cols: usize,
168 ) -> Result<(), WeightBufferError> {
169 let expected = self.expected_bytes(rows, cols);
170 if actual_bytes != expected {
171 return Err(WeightBufferError {
172 weight_name: weight_name.to_string(),
173 reason: format!(
174 "{} buffer size mismatch: got {} bytes, expected {} bytes \
175 for [{}, {}] ({} blocks/row * {} bytes/block * {} rows)",
176 self.name,
177 actual_bytes,
178 expected,
179 rows,
180 cols,
181 (cols + self.block_size - 1) / self.block_size,
182 self.block_bytes,
183 rows,
184 ),
185 });
186 }
187 Ok(())
188 }
189}
190
191pub fn validate_weight_buffer(
210 weight_name: &str,
211 ggml_type: u32,
212 actual_bytes: usize,
213 rows: usize,
214 cols: usize,
215) -> Result<(), WeightBufferError> {
216 let format = format_by_ggml_type(ggml_type).ok_or_else(|| WeightBufferError {
217 weight_name: weight_name.to_string(),
218 reason: format!("Unknown GGML quantization type ID: {ggml_type}"),
219 })?;
220 format.validate_buffer(weight_name, actual_bytes, rows, cols)
221}
222
223pub fn validate_f32_buffer(
232 weight_name: &str,
233 actual_elements: usize,
234 rows: usize,
235 cols: usize,
236) -> Result<(), WeightBufferError> {
237 let expected = rows * cols;
238 if actual_elements != expected {
239 return Err(WeightBufferError {
240 weight_name: weight_name.to_string(),
241 reason: format!(
242 "F32 element count mismatch: got {actual_elements}, expected {expected} \
243 for [{rows}, {cols}]"
244 ),
245 });
246 }
247 Ok(())
248}
249
250pub fn validate_gemv_shapes(
266 weight_name: &str,
267 weight_rows: usize,
268 weight_cols: usize,
269 input_len: usize,
270 output_len: usize,
271) -> Result<(), WeightBufferError> {
272 if weight_cols != input_len {
273 return Err(WeightBufferError {
274 weight_name: weight_name.to_string(),
275 reason: format!(
276 "GEMV input dimension mismatch: weight has {weight_cols} cols \
277 but input has {input_len} elements"
278 ),
279 });
280 }
281 if weight_rows != output_len {
282 return Err(WeightBufferError {
283 weight_name: weight_name.to_string(),
284 reason: format!(
285 "GEMV output dimension mismatch: weight has {weight_rows} rows \
286 but output has {output_len} elements"
287 ),
288 });
289 }
290 Ok(())
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_q4k_expected_bytes() {
        // 4096 / 256 = 16 super-blocks per row, 144 bytes each.
        assert_eq!(Q4_K.expected_bytes(4096, 4096), 9_437_184);
    }

    #[test]
    fn test_q5k_expected_bytes() {
        // 16 super-blocks per row, 176 bytes each.
        assert_eq!(Q5_K.expected_bytes(4096, 4096), 4096 * 16 * 176);
    }

    #[test]
    fn test_q6k_expected_bytes() {
        assert_eq!(Q6_K.expected_bytes(4096, 4096), 4096 * 16 * 210);
    }

    #[test]
    fn test_q8_0_expected_bytes() {
        // 4096 / 32 = 128 blocks per row, 34 bytes each.
        assert_eq!(Q8_0.expected_bytes(4096, 4096), 4096 * 128 * 34);
    }

    #[test]
    fn test_block32_formats_expected_bytes() {
        assert_eq!(Q4_0.expected_bytes(4096, 4096), 4096 * 128 * 18);
        assert_eq!(Q4_1.expected_bytes(4096, 4096), 4096 * 128 * 20);
        assert_eq!(Q5_0.expected_bytes(4096, 4096), 4096 * 128 * 22);
    }

    #[test]
    fn test_zero_dims_expect_zero_bytes() {
        assert_eq!(Q4_K.expected_bytes(0, 4096), 0);
        assert_eq!(Q4_K.expected_bytes(4096, 0), 0);
        assert!(Q4_K.validate_buffer("empty", 0, 0, 4096).is_ok());
    }

    #[test]
    fn test_validate_buffer_ok() {
        let bytes = Q4_K.expected_bytes(4096, 4096);
        assert!(Q4_K.validate_buffer("test.weight", bytes, 4096, 4096).is_ok());
    }

    #[test]
    fn test_validate_buffer_wrong_size() {
        let err = Q4_K.validate_buffer("test.weight", 1000, 4096, 4096).unwrap_err();
        assert!(err.reason.contains("buffer size mismatch"));
    }

    #[test]
    fn test_validate_buffer_error_message() {
        let err = Q4_K.validate_buffer("blk.0.attn_q.weight", 1000, 4096, 4096).unwrap_err();
        assert_eq!(err.weight_name, "blk.0.attn_q.weight");
        assert_eq!(
            err.reason,
            "Q4_K buffer size mismatch: got 1000 bytes, expected 9437184 bytes for [4096, 4096] \
             (16 blocks/row * 144 bytes/block * 4096 rows)"
        );
    }

    #[test]
    fn test_weight_buffer_error_display() {
        let err = WeightBufferError { weight_name: "w".to_string(), reason: "bad".to_string() };
        assert_eq!(err.to_string(), "Kernel contract violation for 'w': bad");
    }

    #[test]
    fn test_validate_weight_buffer_unknown_type() {
        let err = validate_weight_buffer("test.weight", 99, 1000, 4096, 4096).unwrap_err();
        assert!(err.reason.contains("Unknown GGML"));
    }

    #[test]
    fn test_validate_weight_buffer_dispatches_to_format() {
        let bytes = Q8_0.expected_bytes(2, 64);
        assert!(validate_weight_buffer("w", 8, bytes, 2, 64).is_ok());
        assert!(validate_weight_buffer("w", 8, bytes + 1, 2, 64).is_err());
    }

    #[test]
    fn test_validate_f32_buffer_ok() {
        assert!(validate_f32_buffer("test.weight", 4096 * 4096, 4096, 4096).is_ok());
    }

    #[test]
    fn test_validate_f32_buffer_mismatch() {
        let err = validate_f32_buffer("test.weight", 100, 4096, 4096).unwrap_err();
        assert!(err.reason.contains("element count mismatch"));
    }

    #[test]
    fn test_validate_gemv_shapes_ok() {
        assert!(validate_gemv_shapes("test", 4096, 4096, 4096, 4096).is_ok());
    }

    #[test]
    fn test_validate_gemv_shapes_input_mismatch() {
        let err = validate_gemv_shapes("test", 4096, 4096, 2048, 4096).unwrap_err();
        assert!(err.reason.contains("input dimension mismatch"));
    }

    #[test]
    fn test_validate_gemv_shapes_output_mismatch() {
        let err = validate_gemv_shapes("test", 4096, 4096, 4096, 2048).unwrap_err();
        assert!(err.reason.contains("output dimension mismatch"));
    }

    #[test]
    fn test_format_lookup_all_types() {
        assert_eq!(format_by_ggml_type(2).unwrap().name, "Q4_0");
        assert_eq!(format_by_ggml_type(3).unwrap().name, "Q4_1");
        assert_eq!(format_by_ggml_type(6).unwrap().name, "Q5_0");
        assert_eq!(format_by_ggml_type(8).unwrap().name, "Q8_0");
        assert_eq!(format_by_ggml_type(12).unwrap().name, "Q4_K");
        assert_eq!(format_by_ggml_type(13).unwrap().name, "Q5_K");
        assert_eq!(format_by_ggml_type(14).unwrap().name, "Q6_K");
        assert!(format_by_ggml_type(99).is_none());
    }

    #[test]
    fn test_all_formats_round_trip_through_lookup() {
        // Duplicate type IDs would make the lookup silently return the wrong format.
        for format in ALL_FORMATS {
            assert_eq!(format_by_ggml_type(format.ggml_type_id), Some(format));
        }
    }

    #[test]
    fn test_stack_layout_is_row_major() {
        assert_eq!(STACK_LAYOUT, TensorLayout::RowMajor);
    }

    #[test]
    fn test_non_aligned_cols() {
        // 100 cols still occupy one full 256-element super-block per row.
        assert_eq!(Q4_K.expected_bytes(10, 100), 10 * 144);
        // 300 cols round up to two super-blocks per row.
        assert_eq!(Q4_K.expected_bytes(10, 300), 10 * 2 * 144);
    }
}