Skip to main content

karpal_verify/
gpu.rs

1#[cfg(not(feature = "std"))]
2use alloc::{string::String, vec, vec::Vec};
3#[cfg(feature = "std")]
4use std::{string::String, vec, vec::Vec};
5
6use karpal_proof::Property;
7
8use crate::{Declaration, Obligation, ObligationBundle, Origin, Sort, Term, VerificationTier};
9
10pub struct IsBufferAlignedTo16;
11impl Property for IsBufferAlignedTo16 {
12    const NAME: &'static str = "IsBufferAlignedTo16";
13}
14
15pub struct IsWorkgroupSizeDivisible;
16impl Property for IsWorkgroupSizeDivisible {
17    const NAME: &'static str = "IsWorkgroupSizeDivisible";
18}
19
20pub struct IsDispatchWithinLimits;
21impl Property for IsDispatchWithinLimits {
22    const NAME: &'static str = "IsDispatchWithinLimits";
23}
24
25pub struct IsMSLKernelDeterministic;
26impl Property for IsMSLKernelDeterministic {
27    const NAME: &'static str = "IsMSLKernelDeterministic";
28}
29
30/// Builder for GPU compute verification obligations.
31#[derive(Debug, Clone)]
32pub struct GpuObligationBundle {
33    bundle: ObligationBundle,
34}
35
36impl GpuObligationBundle {
37    pub fn metal_kernel(name: impl Into<String>, origin: Origin) -> Self {
38        Self {
39            bundle: ObligationBundle::new(name, origin),
40        }
41    }
42
43    pub fn with_buffer_alignment(mut self, buffer: impl Into<String>, alignment: i64) -> Self {
44        let buffer = buffer.into();
45        let property = if alignment == 16 {
46            IsBufferAlignedTo16::NAME
47        } else {
48            "IsBufferAligned"
49        };
50        self.bundle.push(Obligation {
51            name: format_obligation_name(&buffer, "buffer_alignment"),
52            property,
53            declarations: vec![Declaration::new(buffer.clone(), Sort::named("MTLBuffer"))],
54            assumptions: Vec::new(),
55            conclusion: Term::app("aligned_to", [Term::var(buffer), Term::int(alignment)]),
56            origin: self.bundle.origin.clone(),
57            tier: VerificationTier::External,
58        });
59        self
60    }
61
62    pub fn with_workgroup_divisibility(mut self, symbol: impl Into<String>, divisor: i64) -> Self {
63        let symbol = symbol.into();
64        self.bundle.push(Obligation {
65            name: format_obligation_name(&symbol, "workgroup_divisibility"),
66            property: IsWorkgroupSizeDivisible::NAME,
67            declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
68            assumptions: Vec::new(),
69            conclusion: Term::app("divisible_by", [Term::var(symbol), Term::int(divisor)]),
70            origin: self.bundle.origin.clone(),
71            tier: VerificationTier::External,
72        });
73        self
74    }
75
76    pub fn with_dispatch_limit(mut self, symbol: impl Into<String>, limit: i64) -> Self {
77        let symbol = symbol.into();
78        self.bundle.push(Obligation {
79            name: format_obligation_name(&symbol, "dispatch_limit"),
80            property: IsDispatchWithinLimits::NAME,
81            declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
82            assumptions: Vec::new(),
83            conclusion: Term::app(
84                "within_dispatch_limit",
85                [Term::var(symbol), Term::int(limit)],
86            ),
87            origin: self.bundle.origin.clone(),
88            tier: VerificationTier::External,
89        });
90        self
91    }
92
93    pub fn with_kernel_determinism(mut self, kernel: impl Into<String>) -> Self {
94        let kernel = kernel.into();
95        self.bundle.push(Obligation {
96            name: format_obligation_name(&kernel, "kernel_determinism"),
97            property: IsMSLKernelDeterministic::NAME,
98            declarations: vec![Declaration::new(kernel.clone(), Sort::named("MSLKernel"))],
99            assumptions: Vec::new(),
100            conclusion: Term::app("deterministic_kernel", [Term::var(kernel)]),
101            origin: self.bundle.origin.clone(),
102            tier: VerificationTier::External,
103        });
104        self
105    }
106
107    pub fn into_bundle(self) -> ObligationBundle {
108        self.bundle
109    }
110}
111
/// Derives an obligation name of the form `<sanitized symbol>_<suffix>`.
///
/// Each `char` of `symbol` that is not an ASCII letter, ASCII digit, or `_` is replaced by a
/// single `_`, so a multi-byte character still maps to one underscore. `suffix` is appended
/// verbatim after a `_` separator, and an empty `symbol` yields `_<suffix>`.
fn format_obligation_name(symbol: &str, suffix: &str) -> String {
    let sanitize = |c: char| match c {
        'a'..='z' | 'A'..='Z' | '0'..='9' | '_' => c,
        _ => '_',
    };
    let mut name = String::with_capacity(symbol.len() + 1 + suffix.len());
    name.extend(symbol.chars().map(sanitize));
    name.push('_');
    name.push_str(suffix);
    name
}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Origin, export_kani_bundle, export_lean_bundle, export_smt_bundle};

    /// Property names the reference bundle must contain, one per builder method.
    const EXPECTED_PROPERTIES: [&str; 4] = [
        "IsBufferAlignedTo16",
        "IsWorkgroupSizeDivisible",
        "IsDispatchWithinLimits",
        "IsMSLKernelDeterministic",
    ];

    /// Starts a builder for the reference `borsalino_kernel` used throughout these tests.
    fn reduce_kernel() -> GpuObligationBundle {
        GpuObligationBundle::metal_kernel(
            "borsalino_kernel",
            Origin::new("borsalino", "kernels::reduce"),
        )
    }

    /// Builds the reference bundle: one obligation from each builder method.
    fn reduce_kernel_bundle() -> ObligationBundle {
        reduce_kernel()
            .with_buffer_alignment("input", 16)
            .with_workgroup_divisibility("threads_per_group", 32)
            .with_dispatch_limit("grid_size", 65_535)
            .with_kernel_determinism("reduce_kernel")
            .into_bundle()
    }

    #[test]
    fn gpu_bundle_contains_expected_obligations() {
        let bundle = reduce_kernel_bundle();

        assert_eq!(bundle.obligations().len(), EXPECTED_PROPERTIES.len());
        for expected in EXPECTED_PROPERTIES {
            assert!(
                bundle
                    .obligations()
                    .iter()
                    .any(|obligation| obligation.property == expected),
                "missing obligation with property {expected}"
            );
        }
    }

    #[test]
    fn gpu_bundle_exports_through_all_backends() {
        let bundle = reduce_kernel_bundle();

        let smt = export_smt_bundle(&bundle);
        let lean = export_lean_bundle("GpuVerify", &bundle);
        let kani = export_kani_bundle(&bundle);

        assert_eq!(smt.len(), EXPECTED_PROPERTIES.len());
        assert!(lean.contains("deterministic_kernel"));
        assert_eq!(kani.len(), EXPECTED_PROPERTIES.len());
        assert!(kani[0].source.contains("kani::assert"));
    }

    #[test]
    fn non_16_alignment_uses_generic_property() {
        let bundle = reduce_kernel()
            .with_buffer_alignment("input", 64)
            .into_bundle();

        let obligations = bundle.obligations();
        assert_eq!(obligations.len(), 1);
        assert_eq!(obligations[0].property, "IsBufferAligned");
    }

    #[test]
    fn obligation_names_sanitize_symbol_and_append_suffix() {
        let bundle = reduce_kernel()
            .with_dispatch_limit("grid.size", 65_535)
            .into_bundle();

        assert_eq!(bundle.obligations()[0].name, "grid_size_dispatch_limit");
    }

    #[test]
    fn format_obligation_name_replaces_non_identifier_chars() {
        assert_eq!(format_obligation_name("a-b.c", "x"), "a_b_c_x");
        assert_eq!(format_obligation_name("", "x"), "_x");
        assert_eq!(format_obligation_name("é1", "x"), "_1_x");
    }
}
190}