Skip to main content

karpal_verify/
gpu.rs

1// Copyright (C) 2026 Industrial Algebra
2// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(not(feature = "std"))]
5use alloc::{string::String, vec, vec::Vec};
6#[cfg(feature = "std")]
7use std::{string::String, vec, vec::Vec};
8
9use karpal_proof::Property;
10
11use crate::{Declaration, Obligation, ObligationBundle, Origin, Sort, Term, VerificationTier};
12
/// Marker type for the property that a buffer is aligned to 16.
///
/// Its [`Property::NAME`] tags the obligations produced by
/// [`GpuObligationBundle::with_buffer_alignment`] when the requested
/// alignment is exactly 16. The unit is not stated in this file (presumably
/// bytes — confirm against the consumer of `aligned_to`).
pub struct IsBufferAlignedTo16;
impl Property for IsBufferAlignedTo16 {
    /// Identifier recorded in the obligation's `property` field.
    const NAME: &'static str = "IsBufferAlignedTo16";
}
17
/// Marker type for the property that a workgroup size is divisible by a
/// given divisor.
///
/// Its [`Property::NAME`] tags the obligations produced by
/// [`GpuObligationBundle::with_workgroup_divisibility`].
pub struct IsWorkgroupSizeDivisible;
impl Property for IsWorkgroupSizeDivisible {
    /// Identifier recorded in the obligation's `property` field.
    const NAME: &'static str = "IsWorkgroupSizeDivisible";
}
22
/// Marker type for the property that a dispatch size stays within a limit.
///
/// Its [`Property::NAME`] tags the obligations produced by
/// [`GpuObligationBundle::with_dispatch_limit`].
pub struct IsDispatchWithinLimits;
impl Property for IsDispatchWithinLimits {
    /// Identifier recorded in the obligation's `property` field.
    const NAME: &'static str = "IsDispatchWithinLimits";
}
27
/// Marker type for the property that an MSL kernel is deterministic.
///
/// Its [`Property::NAME`] tags the obligations produced by
/// [`GpuObligationBundle::with_kernel_determinism`].
pub struct IsMSLKernelDeterministic;
impl Property for IsMSLKernelDeterministic {
    /// Identifier recorded in the obligation's `property` field.
    const NAME: &'static str = "IsMSLKernelDeterministic";
}
32
/// A linear GPU kernel produces the numerically correct output under the
/// DeepReinforce exact-match protocol.
///
/// The protocol restricts kernel inputs to binary `{0, 1}` values so that
/// floating-point associativity holds exactly within the FP16 integer range
/// `[0, 2048]`. An FP32 CPU reference is then compared against the GPU output
/// with bit-exact equality at every position where the reference value is at
/// or below the threshold.
///
/// This is a **runtime** verification property: the obligation records that
/// a numerical check has been specified, but the actual check runs in the
/// consumer crate (e.g., Borsalino's `verify_numerical()`).
///
/// Obligations carrying this property are created by
/// [`GpuObligationBundle::with_numerical_correctness`].
///
/// # Limitations
///
/// Only applies to linear kernels. Non-linear operations (`log`, `exp`,
/// `tanh`) produce irrational outputs that cannot be checked with exact match.
///
/// Reference: <https://deep-reinforce.com/correctness_check.html>
pub struct IsNumericallyCorrect;
impl Property for IsNumericallyCorrect {
    /// Identifier recorded in the obligation's `property` field.
    const NAME: &'static str = "IsNumericallyCorrect";
}
56
/// Builder for GPU compute verification obligations.
///
/// Wraps an [`ObligationBundle`]. Each `with_*` method appends one
/// [`Obligation`] and hands the builder back so calls can be chained;
/// [`GpuObligationBundle::into_bundle`] yields the accumulated bundle.
#[derive(Debug, Clone)]
pub struct GpuObligationBundle {
    // Accumulated obligations; its `origin` is cloned onto every obligation
    // the builder appends.
    bundle: ObligationBundle,
}
62
63impl GpuObligationBundle {
64    pub fn metal_kernel(name: impl Into<String>, origin: Origin) -> Self {
65        Self {
66            bundle: ObligationBundle::new(name, origin),
67        }
68    }
69
70    pub fn with_buffer_alignment(mut self, buffer: impl Into<String>, alignment: i64) -> Self {
71        let buffer = buffer.into();
72        let property = if alignment == 16 {
73            IsBufferAlignedTo16::NAME
74        } else {
75            "IsBufferAligned"
76        };
77        self.bundle.push(Obligation {
78            name: format_obligation_name(&buffer, "buffer_alignment"),
79            property,
80            declarations: vec![Declaration::new(buffer.clone(), Sort::named("MTLBuffer"))],
81            assumptions: Vec::new(),
82            conclusion: Term::app("aligned_to", [Term::var(buffer), Term::int(alignment)]),
83            origin: self.bundle.origin.clone(),
84            tier: VerificationTier::External,
85        });
86        self
87    }
88
89    pub fn with_workgroup_divisibility(mut self, symbol: impl Into<String>, divisor: i64) -> Self {
90        let symbol = symbol.into();
91        self.bundle.push(Obligation {
92            name: format_obligation_name(&symbol, "workgroup_divisibility"),
93            property: IsWorkgroupSizeDivisible::NAME,
94            declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
95            assumptions: Vec::new(),
96            conclusion: Term::app("divisible_by", [Term::var(symbol), Term::int(divisor)]),
97            origin: self.bundle.origin.clone(),
98            tier: VerificationTier::External,
99        });
100        self
101    }
102
103    pub fn with_dispatch_limit(mut self, symbol: impl Into<String>, limit: i64) -> Self {
104        let symbol = symbol.into();
105        self.bundle.push(Obligation {
106            name: format_obligation_name(&symbol, "dispatch_limit"),
107            property: IsDispatchWithinLimits::NAME,
108            declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
109            assumptions: Vec::new(),
110            conclusion: Term::app(
111                "within_dispatch_limit",
112                [Term::var(symbol), Term::int(limit)],
113            ),
114            origin: self.bundle.origin.clone(),
115            tier: VerificationTier::External,
116        });
117        self
118    }
119
120    pub fn with_kernel_determinism(mut self, kernel: impl Into<String>) -> Self {
121        let kernel = kernel.into();
122        self.bundle.push(Obligation {
123            name: format_obligation_name(&kernel, "kernel_determinism"),
124            property: IsMSLKernelDeterministic::NAME,
125            declarations: vec![Declaration::new(kernel.clone(), Sort::named("MSLKernel"))],
126            assumptions: Vec::new(),
127            conclusion: Term::app("deterministic_kernel", [Term::var(kernel)]),
128            origin: self.bundle.origin.clone(),
129            tier: VerificationTier::External,
130        });
131        self
132    }
133
134    /// Declare that this kernel has a numerical correctness check specified.
135    ///
136    /// Records that the kernel will be verified against a FP32 CPU reference
137    /// using the DeepReinforce exact-match protocol. The actual runtime check
138    /// runs in the consumer crate (e.g., Borsalino's `verify_numerical()`).
139    ///
140    /// # Parameters
141    ///
142    /// - `kernel`: The kernel symbol being verified.
143    /// - `threshold`: The exact-match ceiling (2048 for FP16, the largest
144    ///   integer exactly representable in FP16). Positions where the reference
145    ///   output exceeds this value are ignored.
146    /// - `trials`: Number of random binary-input trials to run.
147    ///
148    /// # Tier
149    ///
150    /// Uses [`VerificationTier::Emergent`] — runtime discovery, not
151    /// compile-time phantom or external theorem prover.
152    ///
153    /// # Reference
154    ///
155    /// See [Towards a Reliable Kernel Correctness Check in Matrix
156    /// Multiplication](https://deep-reinforce.com/correctness_check.html).
157    pub fn with_numerical_correctness(
158        mut self,
159        kernel: impl Into<String>,
160        threshold: i64,
161        trials: u32,
162    ) -> Self {
163        let kernel = kernel.into();
164        self.bundle.push(Obligation {
165            name: format_obligation_name(&kernel, "numerical_correctness"),
166            property: IsNumericallyCorrect::NAME,
167            declarations: vec![
168                Declaration::new("threshold", Sort::Int),
169                Declaration::new("trials", Sort::Int),
170            ],
171            assumptions: vec![
172                Term::app("le", [Term::int(0), Term::var("threshold")]),
173                Term::app("le", [Term::int(1), Term::var("trials")]),
174            ],
175            conclusion: Term::app(
176                "numerically_correct_under_exact_match",
177                [
178                    Term::var(kernel.clone()),
179                    Term::int(threshold),
180                    Term::int(i64::from(trials)),
181                ],
182            ),
183            origin: self.bundle.origin.clone(),
184            tier: VerificationTier::Emergent,
185        });
186        self
187    }
188
189    pub fn into_bundle(self) -> ObligationBundle {
190        self.bundle
191    }
192}
193
/// Builds an obligation name of the form `<sanitized symbol>_<suffix>`.
///
/// Every character of `symbol` that is not an ASCII letter, ASCII digit, or
/// `_` is replaced by a single `_` (one replacement per character, so a
/// multi-byte character yields exactly one underscore). `suffix` is appended
/// verbatim after a separating `_`.
fn format_obligation_name(symbol: &str, suffix: &str) -> String {
    let sanitized = symbol.chars().map(|c| match c {
        'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => c,
        _ => '_',
    });

    // The sanitized symbol never needs more bytes than the input, so this
    // capacity covers the whole result.
    let mut name = String::with_capacity(symbol.len() + 1 + suffix.len());
    name.extend(sanitized);
    name.push('_');
    name.push_str(suffix);
    name
}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Origin, VerificationTier, export_kani_bundle, export_lean_bundle, export_smt_bundle,
    };

    /// Builder holding the four reduce-kernel obligations (alignment,
    /// workgroup divisibility, dispatch limit, determinism) shared by the
    /// structural and export tests.
    fn reduce_kernel_builder() -> GpuObligationBundle {
        let origin = Origin::new("borsalino", "kernels::reduce");
        GpuObligationBundle::metal_kernel("borsalino_kernel", origin)
            .with_buffer_alignment("input", 16)
            .with_workgroup_divisibility("threads_per_group", 32)
            .with_dispatch_limit("grid_size", 65_535)
            .with_kernel_determinism("reduce_kernel")
    }

    #[test]
    fn gpu_bundle_contains_expected_obligations() {
        let bundle = reduce_kernel_builder().into_bundle();
        let has_property = |wanted: &str| {
            bundle
                .obligations()
                .iter()
                .any(|obligation| obligation.property == wanted)
        };

        assert!(has_property("IsBufferAlignedTo16"));
        assert!(has_property("IsWorkgroupSizeDivisible"));
        assert!(has_property("IsDispatchWithinLimits"));
        assert!(has_property("IsMSLKernelDeterministic"));
    }

    #[test]
    fn gpu_bundle_exports_through_all_backends() {
        let bundle = reduce_kernel_builder().into_bundle();

        let smt = export_smt_bundle(&bundle);
        let lean = export_lean_bundle("GpuVerify", &bundle);
        let kani = export_kani_bundle(&bundle);

        // One exported artifact per obligation for SMT and Kani.
        assert_eq!(smt.len(), 4);
        assert!(lean.contains("deterministic_kernel"));
        assert_eq!(kani.len(), 4);
        assert!(kani[0].source.contains("kani::assert"));
    }

    // ── Numerical correctness (v0.6.0) ───────────────────────────
    //
    // Follows the DeepReinforce exact-match protocol:
    // https://deep-reinforce.com/correctness_check.html
    //
    // Linear kernels are exercised with binary {0, 1} inputs only, and the
    // GPU output must equal an FP32 CPU reference bit-for-bit wherever the
    // reference value is <= 2048 (the FP16 exact-integer ceiling).

    #[test]
    fn numerical_correctness_property_has_name() {
        assert_eq!(IsNumericallyCorrect::NAME, "IsNumericallyCorrect");
    }

    #[test]
    fn with_numerical_correctness_adds_obligation() {
        let origin = Origin::new("borsalino", "kernels::matmul");
        let bundle = GpuObligationBundle::metal_kernel("borsalino_matmul", origin)
            .with_numerical_correctness("matmul_kernel", 2048, 16)
            .into_bundle();

        assert_eq!(bundle.obligations().len(), 1);

        let recorded = &bundle.obligations()[0];
        assert_eq!(recorded.property, IsNumericallyCorrect::NAME);
        assert_eq!(recorded.tier, VerificationTier::Emergent);
    }

    #[test]
    fn numerical_correctness_exports_through_all_backends() {
        let origin = Origin::new("borsalino", "kernels::matmul");
        let bundle = GpuObligationBundle::metal_kernel("borsalino_matmul", origin)
            .with_buffer_alignment("input_a", 16)
            .with_buffer_alignment("input_b", 16)
            .with_numerical_correctness("matmul_kernel", 2048, 16)
            .into_bundle();

        let smt = export_smt_bundle(&bundle);
        let lean = export_lean_bundle("GpuVerify", &bundle);
        let kani = export_kani_bundle(&bundle);

        assert_eq!(smt.len(), 3);
        assert!(
            lean.contains("numerically_correct_under_exact_match"),
            "Lean export should contain the numerical correctness predicate"
        );
        assert_eq!(kani.len(), 3);

        let kani_mentions_predicate = kani
            .iter()
            .any(|h| h.source.contains("numerically_correct_under_exact_match"));
        assert!(
            kani_mentions_predicate,
            "Kani export should contain the numerical correctness predicate"
        );
    }

    #[test]
    fn numerical_correctness_records_threshold_and_trials() {
        let origin = Origin::new("borsalino", "kernels::geometric_product");
        let bundle = GpuObligationBundle::metal_kernel("borsalino_gp", origin)
            .with_numerical_correctness("gp_kernel", 2048, 32)
            .into_bundle();

        // Both numeric arguments must show up in the debug rendering of the
        // conclusion term.
        let conclusion_str = format!("{:?}", bundle.obligations()[0].conclusion);
        assert!(
            conclusion_str.contains("2048"),
            "Conclusion should contain the threshold value: {conclusion_str}"
        );
        assert!(
            conclusion_str.contains("32"),
            "Conclusion should contain the trials value: {conclusion_str}"
        );
    }
}