Skip to main content

trueno/
contracts.rs

1//! GH-279: Kernel-Level Contracts for the Sovereign AI Stack
2//!
3//! Defines invariants that MUST hold for ANY data entering trueno compute kernels.
4//! Consumers (realizar, aprender) validate data against these contracts BEFORE
5//! calling any kernel. Violating a contract is a hard error, not a silent default.
6//!
7//! # Contract Hierarchy
8//!
9//! ```text
10//! aprender (import) ──► enforce_architecture_completeness()  [tensor names]
11//!                        │
12//! realizar (load)   ──► contract_gate::validate_model_load()  [architecture]
13//!                        │
14//! trueno (kernel)   ──► contracts::validate_weight_buffer()    [bytes & layout]
15//! ```
16//!
17//! This module is the bottom layer — raw buffer and layout validation.
18//! If these fail, the kernel WILL produce garbage or crash.
19
20// Re-export trueno-quant constants as the canonical source of truth for block sizes.
21// These constants are also used directly by trueno kernels (tiling, brick, etc.).
22pub use trueno_quant::{
23    Q4_K_BLOCK_BYTES, Q4_K_BLOCK_SIZE, Q5_K_BLOCK_BYTES, Q5_K_BLOCK_SIZE, Q6_K_BLOCK_BYTES,
24    Q6_K_BLOCK_SIZE,
25};
26
27// ============================================================================
28// Layout Contract (LAYOUT-001/002)
29// ============================================================================
30
/// Memory layout of every tensor handed to a trueno kernel.
///
/// APR files, realizar inference, and trueno kernels all agree on exactly one
/// layout: row-major. Column-major GGUF data is transposed once, at the
/// aprender import boundary, so nothing below that boundary ever sees it.
///
/// Kernel contract: element `(row, col)` lives at `weight[row * cols + col]`.
/// There is no runtime layout flag — feeding column-major data yields GARBAGE.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorLayout {
    /// Row-major storage: shape `[rows, cols]`, strides `[cols, 1]`.
    /// The sole layout trueno kernels accept.
    RowMajor,
}

/// Layout shared by the whole stack; every kernel assumes it.
pub const STACK_LAYOUT: TensorLayout = TensorLayout::RowMajor;
48
49// ============================================================================
50// Quantization Format Descriptors
51// ============================================================================
52
/// Block geometry of a single quantization format.
///
/// A format packs a fixed number of elements (`block_size`) into a fixed
/// number of bytes (`block_bytes`), so buffer sizes follow directly from the
/// element count. Kernels rely on this to size and validate their inputs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct QuantFormat {
    /// Display name, e.g. `"Q4_K"`.
    pub name: &'static str,
    /// Number of elements encoded by one block.
    pub block_size: usize,
    /// Encoded size of one block, in bytes.
    pub block_bytes: usize,
    /// Matching GGML type ID, used for GGUF interop.
    pub ggml_type_id: u32,
}
68
69/// Q4_K super-block format: 256 elements, 144 bytes (4.5 bits/weight)
70pub const Q4_K: QuantFormat = QuantFormat {
71    name: "Q4_K",
72    block_size: Q4_K_BLOCK_SIZE,
73    block_bytes: Q4_K_BLOCK_BYTES,
74    ggml_type_id: 12,
75};
76
77/// Q5_K super-block format: 256 elements, 176 bytes (5.5 bits/weight)
78pub const Q5_K: QuantFormat = QuantFormat {
79    name: "Q5_K",
80    block_size: Q5_K_BLOCK_SIZE,
81    block_bytes: Q5_K_BLOCK_BYTES,
82    ggml_type_id: 13,
83};
84
85/// Q6_K super-block format: 256 elements, 210 bytes (6.5 bits/weight)
86pub const Q6_K: QuantFormat = QuantFormat {
87    name: "Q6_K",
88    block_size: Q6_K_BLOCK_SIZE,
89    block_bytes: Q6_K_BLOCK_BYTES,
90    ggml_type_id: 14,
91};
92
93/// Q8_0 block format: 32 elements, 34 bytes
94pub const Q8_0: QuantFormat =
95    QuantFormat { name: "Q8_0", block_size: 32, block_bytes: 34, ggml_type_id: 8 };
96
97/// Q5_0 block format: 32 elements, 22 bytes
98pub const Q5_0: QuantFormat =
99    QuantFormat { name: "Q5_0", block_size: 32, block_bytes: 22, ggml_type_id: 6 };
100
101/// Q4_0 block format: 32 elements, 18 bytes
102pub const Q4_0: QuantFormat =
103    QuantFormat { name: "Q4_0", block_size: 32, block_bytes: 18, ggml_type_id: 2 };
104
105/// Q4_1 block format: 32 elements, 20 bytes
106pub const Q4_1: QuantFormat =
107    QuantFormat { name: "Q4_1", block_size: 32, block_bytes: 20, ggml_type_id: 3 };
108
109/// All supported quantization formats, ordered by GGML type ID.
110pub const ALL_FORMATS: &[QuantFormat] = &[Q4_0, Q4_1, Q5_0, Q8_0, Q4_K, Q5_K, Q6_K];
111
112/// Lookup a quantization format by GGML type ID.
113#[must_use]
114pub fn format_by_ggml_type(type_id: u32) -> Option<&'static QuantFormat> {
115    ALL_FORMATS.iter().find(|f| f.ggml_type_id == type_id)
116}
117
118// ============================================================================
119// Weight Buffer Validation
120// ============================================================================
121
/// Contract violation reported when a weight buffer fails validation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WeightBufferError {
    /// Name of the offending weight, e.g. `"blk.0.attn_q.weight"`.
    pub weight_name: String,
    /// Human-readable description of the violation.
    pub reason: String,
}

impl std::fmt::Display for WeightBufferError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { weight_name, reason } = self;
        write!(f, "Kernel contract violation for '{weight_name}': {reason}")
    }
}

impl std::error::Error for WeightBufferError {}
138
139impl QuantFormat {
140    /// Compute the expected byte size for a weight matrix [rows, cols].
141    ///
142    /// The matrix is stored as `rows` independent row vectors, each quantized
143    /// into ceil(cols / block_size) blocks of `block_bytes` bytes.
144    ///
145    /// # Row-Major Contract
146    ///
147    /// For GEMV `y = W·x` where W is [out_dim, in_dim]:
148    /// - `rows` = out_dim (number of output features)
149    /// - `cols` = in_dim (number of input features, quantized along this axis)
150    #[must_use]
151    pub const fn expected_bytes(&self, rows: usize, cols: usize) -> usize {
152        let blocks_per_row = (cols + self.block_size - 1) / self.block_size;
153        rows * blocks_per_row * self.block_bytes
154    }
155
156    /// Validate that a weight buffer has the correct size for [rows, cols].
157    ///
158    /// # Errors
159    ///
160    /// Returns `WeightBufferError` if `actual_bytes` does not match the expected
161    /// size for the given dimensions and quantization format.
162    pub fn validate_buffer(
163        &self,
164        weight_name: &str,
165        actual_bytes: usize,
166        rows: usize,
167        cols: usize,
168    ) -> Result<(), WeightBufferError> {
169        let expected = self.expected_bytes(rows, cols);
170        if actual_bytes != expected {
171            return Err(WeightBufferError {
172                weight_name: weight_name.to_string(),
173                reason: format!(
174                    "{} buffer size mismatch: got {} bytes, expected {} bytes \
175                     for [{}, {}] ({} blocks/row * {} bytes/block * {} rows)",
176                    self.name,
177                    actual_bytes,
178                    expected,
179                    rows,
180                    cols,
181                    (cols + self.block_size - 1) / self.block_size,
182                    self.block_bytes,
183                    rows,
184                ),
185            });
186        }
187        Ok(())
188    }
189}
190
191/// Validate a weight buffer against a known GGML type ID.
192///
193/// This is the primary entry point for realizar and aprender to validate
194/// quantized weight buffers before passing them to trueno kernels.
195///
196/// # Arguments
197///
198/// * `weight_name` - Human-readable name (for error messages)
199/// * `ggml_type` - GGML quantization type ID
200/// * `actual_bytes` - Actual buffer size in bytes
201/// * `rows` - Number of rows (out_dim for GEMV)
202/// * `cols` - Number of columns (in_dim for GEMV)
203///
204/// # Errors
205///
206/// Returns `WeightBufferError` if:
207/// - The GGML type ID is unknown
208/// - The buffer size doesn't match expected dimensions
209pub fn validate_weight_buffer(
210    weight_name: &str,
211    ggml_type: u32,
212    actual_bytes: usize,
213    rows: usize,
214    cols: usize,
215) -> Result<(), WeightBufferError> {
216    let format = format_by_ggml_type(ggml_type).ok_or_else(|| WeightBufferError {
217        weight_name: weight_name.to_string(),
218        reason: format!("Unknown GGML quantization type ID: {ggml_type}"),
219    })?;
220    format.validate_buffer(weight_name, actual_bytes, rows, cols)
221}
222
223/// Validate that an F32 weight buffer has correct element count.
224///
225/// For unquantized (F32) weights, the buffer must have exactly `rows * cols * 4` bytes
226/// (or `rows * cols` elements).
227///
228/// # Errors
229///
230/// Returns `WeightBufferError` if the element count doesn't match.
231pub fn validate_f32_buffer(
232    weight_name: &str,
233    actual_elements: usize,
234    rows: usize,
235    cols: usize,
236) -> Result<(), WeightBufferError> {
237    let expected = rows * cols;
238    if actual_elements != expected {
239        return Err(WeightBufferError {
240            weight_name: weight_name.to_string(),
241            reason: format!(
242                "F32 element count mismatch: got {actual_elements}, expected {expected} \
243                 for [{rows}, {cols}]"
244            ),
245        });
246    }
247    Ok(())
248}
249
250// ============================================================================
251// Matmul Shape Contract
252// ============================================================================
253
254/// Validate GEMV shape invariants for row-major layout.
255///
256/// For `y = W · x` with row-major W[out_dim, in_dim]:
257/// - `weight_rows` MUST equal `out_dim`
258/// - `weight_cols` MUST equal `in_dim`
259/// - `input_len` MUST equal `in_dim`
260/// - `output_len` MUST equal `out_dim`
261///
262/// # Errors
263///
264/// Returns `WeightBufferError` describing the shape mismatch.
265pub fn validate_gemv_shapes(
266    weight_name: &str,
267    weight_rows: usize,
268    weight_cols: usize,
269    input_len: usize,
270    output_len: usize,
271) -> Result<(), WeightBufferError> {
272    if weight_cols != input_len {
273        return Err(WeightBufferError {
274            weight_name: weight_name.to_string(),
275            reason: format!(
276                "GEMV input dimension mismatch: weight has {weight_cols} cols \
277                 but input has {input_len} elements"
278            ),
279        });
280    }
281    if weight_rows != output_len {
282        return Err(WeightBufferError {
283            weight_name: weight_name.to_string(),
284            reason: format!(
285                "GEMV output dimension mismatch: weight has {weight_rows} rows \
286                 but output has {output_len} elements"
287            ),
288        });
289    }
290    Ok(())
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_q4k_expected_bytes() {
        // 256 elements per block, 144 bytes per block
        // For a [4096, 4096] weight: 4096 * (4096/256) * 144 = 4096 * 16 * 144 = 9_437_184
        assert_eq!(Q4_K.expected_bytes(4096, 4096), 9_437_184);
    }

    #[test]
    fn test_q5k_expected_bytes() {
        // 256 elements per block, 176 bytes per block
        assert_eq!(Q5_K.expected_bytes(4096, 4096), 4096 * 16 * 176);
    }

    #[test]
    fn test_q6k_expected_bytes() {
        // 256 elements per block, 210 bytes per block
        assert_eq!(Q6_K.expected_bytes(4096, 4096), 4096 * 16 * 210);
    }

    #[test]
    fn test_q8_0_expected_bytes() {
        // 32 elements per block, 34 bytes per block
        // For [4096, 4096]: 4096 * (4096/32) * 34 = 4096 * 128 * 34
        assert_eq!(Q8_0.expected_bytes(4096, 4096), 4096 * 128 * 34);
    }

    #[test]
    fn test_bits_per_weight() {
        // bits/weight = block_bytes * 8 / block_size, scaled by 10_000 so the
        // comparison stays in exact integer arithmetic.
        let scaled = |f: &QuantFormat| f.block_bytes * 8 * 10_000 / f.block_size;
        assert_eq!(scaled(&Q4_0), 45_000);
        assert_eq!(scaled(&Q4_1), 50_000);
        assert_eq!(scaled(&Q5_0), 55_000);
        assert_eq!(scaled(&Q8_0), 85_000);
        assert_eq!(scaled(&Q4_K), 45_000);
        assert_eq!(scaled(&Q5_K), 55_000);
        assert_eq!(scaled(&Q6_K), 65_625);
    }

    #[test]
    fn test_all_formats_sorted_and_unique() {
        assert!(ALL_FORMATS.windows(2).all(|w| w[0].ggml_type_id < w[1].ggml_type_id));
    }

    #[test]
    fn test_validate_buffer_ok() {
        let bytes = Q4_K.expected_bytes(4096, 4096);
        assert!(Q4_K.validate_buffer("test.weight", bytes, 4096, 4096).is_ok());
    }

    #[test]
    fn test_validate_buffer_wrong_size() {
        let err = Q4_K.validate_buffer("test.weight", 1000, 4096, 4096).unwrap_err();
        assert!(err.reason.contains("buffer size mismatch"));
    }

    #[test]
    fn test_validate_weight_buffer_ok() {
        let bytes = Q6_K.expected_bytes(2, 512);
        assert!(validate_weight_buffer("test.weight", 14, bytes, 2, 512).is_ok());
    }

    #[test]
    fn test_validate_weight_buffer_unknown_type() {
        let err = validate_weight_buffer("test.weight", 99, 1000, 4096, 4096).unwrap_err();
        assert!(err.reason.contains("Unknown GGML"));
    }

    #[test]
    fn test_error_display() {
        let err = Q4_K.validate_buffer("blk.0.attn_q.weight", 1, 1, 256).unwrap_err();
        assert!(err
            .to_string()
            .starts_with("Kernel contract violation for 'blk.0.attn_q.weight': "));
    }

    #[test]
    fn test_validate_f32_buffer_ok() {
        assert!(validate_f32_buffer("test.weight", 4096 * 4096, 4096, 4096).is_ok());
    }

    #[test]
    fn test_validate_f32_buffer_mismatch() {
        let err = validate_f32_buffer("test.weight", 100, 4096, 4096).unwrap_err();
        assert!(err.reason.contains("element count mismatch"));
    }

    #[test]
    fn test_validate_gemv_shapes_ok() {
        assert!(validate_gemv_shapes("test", 4096, 4096, 4096, 4096).is_ok());
    }

    #[test]
    fn test_validate_gemv_shapes_input_mismatch() {
        let err = validate_gemv_shapes("test", 4096, 4096, 2048, 4096).unwrap_err();
        assert!(err.reason.contains("input dimension mismatch"));
    }

    #[test]
    fn test_validate_gemv_shapes_output_mismatch() {
        let err = validate_gemv_shapes("test", 4096, 4096, 4096, 2048).unwrap_err();
        assert!(err.reason.contains("output dimension mismatch"));
    }

    #[test]
    fn test_validate_gemv_shapes_reports_input_first() {
        // Both dimensions wrong: the input-side mismatch is reported.
        let err = validate_gemv_shapes("test", 4096, 4096, 2048, 2048).unwrap_err();
        assert!(err.reason.contains("input dimension mismatch"));
    }

    #[test]
    fn test_format_lookup_all_types() {
        assert_eq!(format_by_ggml_type(2).unwrap().name, "Q4_0");
        assert_eq!(format_by_ggml_type(3).unwrap().name, "Q4_1");
        assert_eq!(format_by_ggml_type(6).unwrap().name, "Q5_0");
        assert_eq!(format_by_ggml_type(8).unwrap().name, "Q8_0");
        assert_eq!(format_by_ggml_type(12).unwrap().name, "Q4_K");
        assert_eq!(format_by_ggml_type(13).unwrap().name, "Q5_K");
        assert_eq!(format_by_ggml_type(14).unwrap().name, "Q6_K");
        assert!(format_by_ggml_type(99).is_none());
    }

    #[test]
    fn test_stack_layout_is_row_major() {
        assert_eq!(STACK_LAYOUT, TensorLayout::RowMajor);
    }

    #[test]
    fn test_non_aligned_cols() {
        // 100 cols doesn't divide evenly into 256-element blocks
        // ceil(100/256) = 1 block per row
        assert_eq!(Q4_K.expected_bytes(10, 100), 10 * 144);
        // 300 cols: ceil(300/256) = 2 blocks per row
        assert_eq!(Q4_K.expected_bytes(10, 300), 10 * 2 * 144);
    }
}