1#[cfg(not(feature = "std"))]
5use alloc::{string::String, vec, vec::Vec};
6#[cfg(feature = "std")]
7use std::{string::String, vec, vec::Vec};
8
9use karpal_proof::Property;
10
11use crate::{Declaration, Obligation, ObligationBundle, Origin, Sort, Term, VerificationTier};
12
13pub struct IsBufferAlignedTo16;
14impl Property for IsBufferAlignedTo16 {
15 const NAME: &'static str = "IsBufferAlignedTo16";
16}
17
18pub struct IsWorkgroupSizeDivisible;
19impl Property for IsWorkgroupSizeDivisible {
20 const NAME: &'static str = "IsWorkgroupSizeDivisible";
21}
22
23pub struct IsDispatchWithinLimits;
24impl Property for IsDispatchWithinLimits {
25 const NAME: &'static str = "IsDispatchWithinLimits";
26}
27
28pub struct IsMSLKernelDeterministic;
29impl Property for IsMSLKernelDeterministic {
30 const NAME: &'static str = "IsMSLKernelDeterministic";
31}
32
33pub struct IsNumericallyCorrect;
53impl Property for IsNumericallyCorrect {
54 const NAME: &'static str = "IsNumericallyCorrect";
55}
56
57#[derive(Debug, Clone)]
59pub struct GpuObligationBundle {
60 bundle: ObligationBundle,
61}
62
63impl GpuObligationBundle {
64 pub fn metal_kernel(name: impl Into<String>, origin: Origin) -> Self {
65 Self {
66 bundle: ObligationBundle::new(name, origin),
67 }
68 }
69
70 pub fn with_buffer_alignment(mut self, buffer: impl Into<String>, alignment: i64) -> Self {
71 let buffer = buffer.into();
72 let property = if alignment == 16 {
73 IsBufferAlignedTo16::NAME
74 } else {
75 "IsBufferAligned"
76 };
77 self.bundle.push(Obligation {
78 name: format_obligation_name(&buffer, "buffer_alignment"),
79 property,
80 declarations: vec![Declaration::new(buffer.clone(), Sort::named("MTLBuffer"))],
81 assumptions: Vec::new(),
82 conclusion: Term::app("aligned_to", [Term::var(buffer), Term::int(alignment)]),
83 origin: self.bundle.origin.clone(),
84 tier: VerificationTier::External,
85 });
86 self
87 }
88
89 pub fn with_workgroup_divisibility(mut self, symbol: impl Into<String>, divisor: i64) -> Self {
90 let symbol = symbol.into();
91 self.bundle.push(Obligation {
92 name: format_obligation_name(&symbol, "workgroup_divisibility"),
93 property: IsWorkgroupSizeDivisible::NAME,
94 declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
95 assumptions: Vec::new(),
96 conclusion: Term::app("divisible_by", [Term::var(symbol), Term::int(divisor)]),
97 origin: self.bundle.origin.clone(),
98 tier: VerificationTier::External,
99 });
100 self
101 }
102
103 pub fn with_dispatch_limit(mut self, symbol: impl Into<String>, limit: i64) -> Self {
104 let symbol = symbol.into();
105 self.bundle.push(Obligation {
106 name: format_obligation_name(&symbol, "dispatch_limit"),
107 property: IsDispatchWithinLimits::NAME,
108 declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
109 assumptions: Vec::new(),
110 conclusion: Term::app(
111 "within_dispatch_limit",
112 [Term::var(symbol), Term::int(limit)],
113 ),
114 origin: self.bundle.origin.clone(),
115 tier: VerificationTier::External,
116 });
117 self
118 }
119
120 pub fn with_kernel_determinism(mut self, kernel: impl Into<String>) -> Self {
121 let kernel = kernel.into();
122 self.bundle.push(Obligation {
123 name: format_obligation_name(&kernel, "kernel_determinism"),
124 property: IsMSLKernelDeterministic::NAME,
125 declarations: vec![Declaration::new(kernel.clone(), Sort::named("MSLKernel"))],
126 assumptions: Vec::new(),
127 conclusion: Term::app("deterministic_kernel", [Term::var(kernel)]),
128 origin: self.bundle.origin.clone(),
129 tier: VerificationTier::External,
130 });
131 self
132 }
133
134 pub fn with_numerical_correctness(
158 mut self,
159 kernel: impl Into<String>,
160 threshold: i64,
161 trials: u32,
162 ) -> Self {
163 let kernel = kernel.into();
164 self.bundle.push(Obligation {
165 name: format_obligation_name(&kernel, "numerical_correctness"),
166 property: IsNumericallyCorrect::NAME,
167 declarations: vec![
168 Declaration::new("threshold", Sort::Int),
169 Declaration::new("trials", Sort::Int),
170 ],
171 assumptions: vec![
172 Term::app("le", [Term::int(0), Term::var("threshold")]),
173 Term::app("le", [Term::int(1), Term::var("trials")]),
174 ],
175 conclusion: Term::app(
176 "numerically_correct_under_exact_match",
177 [
178 Term::var(kernel.clone()),
179 Term::int(threshold),
180 Term::int(i64::from(trials)),
181 ],
182 ),
183 origin: self.bundle.origin.clone(),
184 tier: VerificationTier::Emergent,
185 });
186 self
187 }
188
189 pub fn into_bundle(self) -> ObligationBundle {
190 self.bundle
191 }
192}
193
/// Builds an obligation identifier of the form `<sanitized symbol>_<suffix>`.
///
/// Each character of `symbol` that is an ASCII letter or digit is kept. Every
/// other character is written as a single `_`, which leaves existing
/// underscores unchanged. `suffix` is appended verbatim after a joining `_`.
fn format_obligation_name(symbol: &str, suffix: &str) -> String {
    let mut name: String = symbol
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
        .collect();
    name.push('_');
    name.push_str(suffix);
    name
}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Origin, VerificationTier, export_kani_bundle, export_lean_bundle, export_smt_bundle,
    };

    /// Reduction-kernel fixture shared by the first two tests. It holds one
    /// alignment, one divisibility, one dispatch-limit and one determinism
    /// obligation.
    fn reduce_kernel_bundle() -> ObligationBundle {
        GpuObligationBundle::metal_kernel(
            "borsalino_kernel",
            Origin::new("borsalino", "kernels::reduce"),
        )
        .with_buffer_alignment("input", 16)
        .with_workgroup_divisibility("threads_per_group", 32)
        .with_dispatch_limit("grid_size", 65_535)
        .with_kernel_determinism("reduce_kernel")
        .into_bundle()
    }

    #[test]
    fn gpu_bundle_contains_expected_obligations() {
        let fixture = reduce_kernel_bundle();
        let has_property = |wanted: &str| {
            fixture
                .obligations()
                .iter()
                .any(|entry| entry.property == wanted)
        };

        for wanted in [
            "IsBufferAlignedTo16",
            "IsWorkgroupSizeDivisible",
            "IsDispatchWithinLimits",
            "IsMSLKernelDeterministic",
        ] {
            assert!(has_property(wanted));
        }
    }

    #[test]
    fn gpu_bundle_exports_through_all_backends() {
        let fixture = reduce_kernel_bundle();

        let smt_out = export_smt_bundle(&fixture);
        let lean_out = export_lean_bundle("GpuVerify", &fixture);
        let kani_out = export_kani_bundle(&fixture);

        assert_eq!(smt_out.len(), 4);
        assert!(lean_out.contains("deterministic_kernel"));
        assert_eq!(kani_out.len(), 4);
        assert!(kani_out[0].source.contains("kani::assert"));
    }

    #[test]
    fn numerical_correctness_property_has_name() {
        assert_eq!(IsNumericallyCorrect::NAME, "IsNumericallyCorrect");
    }

    #[test]
    fn with_numerical_correctness_adds_obligation() {
        let fixture = GpuObligationBundle::metal_kernel(
            "borsalino_matmul",
            Origin::new("borsalino", "kernels::matmul"),
        )
        .with_numerical_correctness("matmul_kernel", 2048, 16)
        .into_bundle();

        let entries = fixture.obligations();
        assert_eq!(entries.len(), 1);

        let only = &entries[0];
        assert_eq!(only.property, IsNumericallyCorrect::NAME);
        assert_eq!(only.tier, VerificationTier::Emergent);
    }

    #[test]
    fn numerical_correctness_exports_through_all_backends() {
        let fixture = GpuObligationBundle::metal_kernel(
            "borsalino_matmul",
            Origin::new("borsalino", "kernels::matmul"),
        )
        .with_buffer_alignment("input_a", 16)
        .with_buffer_alignment("input_b", 16)
        .with_numerical_correctness("matmul_kernel", 2048, 16)
        .into_bundle();

        let smt_out = export_smt_bundle(&fixture);
        let lean_out = export_lean_bundle("GpuVerify", &fixture);
        let kani_out = export_kani_bundle(&fixture);

        assert_eq!(smt_out.len(), 3);
        assert!(
            lean_out.contains("numerically_correct_under_exact_match"),
            "Lean export should contain the numerical correctness predicate"
        );
        assert_eq!(kani_out.len(), 3);
        let kani_mentions_predicate = kani_out
            .iter()
            .any(|harness| harness.source.contains("numerically_correct_under_exact_match"));
        assert!(
            kani_mentions_predicate,
            "Kani export should contain the numerical correctness predicate"
        );
    }

    #[test]
    fn numerical_correctness_records_threshold_and_trials() {
        let fixture = GpuObligationBundle::metal_kernel(
            "borsalino_gp",
            Origin::new("borsalino", "kernels::geometric_product"),
        )
        .with_numerical_correctness("gp_kernel", 2048, 32)
        .into_bundle();

        let rendered = format!("{:?}", fixture.obligations()[0].conclusion);

        assert!(
            rendered.contains("2048"),
            "Conclusion should contain the threshold value: {rendered}"
        );
        assert!(
            rendered.contains("32"),
            "Conclusion should contain the trials value: {rendered}"
        );
    }
}