1#[cfg(not(feature = "std"))]
2use alloc::{string::String, vec, vec::Vec};
3#[cfg(feature = "std")]
4use std::{string::String, vec, vec::Vec};
5
6use karpal_proof::Property;
7
8use crate::{Declaration, Obligation, ObligationBundle, Origin, Sort, Term, VerificationTier};
9
10pub struct IsBufferAlignedTo16;
11impl Property for IsBufferAlignedTo16 {
12 const NAME: &'static str = "IsBufferAlignedTo16";
13}
14
15pub struct IsWorkgroupSizeDivisible;
16impl Property for IsWorkgroupSizeDivisible {
17 const NAME: &'static str = "IsWorkgroupSizeDivisible";
18}
19
20pub struct IsDispatchWithinLimits;
21impl Property for IsDispatchWithinLimits {
22 const NAME: &'static str = "IsDispatchWithinLimits";
23}
24
25pub struct IsMSLKernelDeterministic;
26impl Property for IsMSLKernelDeterministic {
27 const NAME: &'static str = "IsMSLKernelDeterministic";
28}
29
30#[derive(Debug, Clone)]
32pub struct GpuObligationBundle {
33 bundle: ObligationBundle,
34}
35
36impl GpuObligationBundle {
37 pub fn metal_kernel(name: impl Into<String>, origin: Origin) -> Self {
38 Self {
39 bundle: ObligationBundle::new(name, origin),
40 }
41 }
42
43 pub fn with_buffer_alignment(mut self, buffer: impl Into<String>, alignment: i64) -> Self {
44 let buffer = buffer.into();
45 let property = if alignment == 16 {
46 IsBufferAlignedTo16::NAME
47 } else {
48 "IsBufferAligned"
49 };
50 self.bundle.push(Obligation {
51 name: format_obligation_name(&buffer, "buffer_alignment"),
52 property,
53 declarations: vec![Declaration::new(buffer.clone(), Sort::named("MTLBuffer"))],
54 assumptions: Vec::new(),
55 conclusion: Term::app("aligned_to", [Term::var(buffer), Term::int(alignment)]),
56 origin: self.bundle.origin.clone(),
57 tier: VerificationTier::External,
58 });
59 self
60 }
61
62 pub fn with_workgroup_divisibility(mut self, symbol: impl Into<String>, divisor: i64) -> Self {
63 let symbol = symbol.into();
64 self.bundle.push(Obligation {
65 name: format_obligation_name(&symbol, "workgroup_divisibility"),
66 property: IsWorkgroupSizeDivisible::NAME,
67 declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
68 assumptions: Vec::new(),
69 conclusion: Term::app("divisible_by", [Term::var(symbol), Term::int(divisor)]),
70 origin: self.bundle.origin.clone(),
71 tier: VerificationTier::External,
72 });
73 self
74 }
75
76 pub fn with_dispatch_limit(mut self, symbol: impl Into<String>, limit: i64) -> Self {
77 let symbol = symbol.into();
78 self.bundle.push(Obligation {
79 name: format_obligation_name(&symbol, "dispatch_limit"),
80 property: IsDispatchWithinLimits::NAME,
81 declarations: vec![Declaration::new(symbol.clone(), Sort::Int)],
82 assumptions: Vec::new(),
83 conclusion: Term::app(
84 "within_dispatch_limit",
85 [Term::var(symbol), Term::int(limit)],
86 ),
87 origin: self.bundle.origin.clone(),
88 tier: VerificationTier::External,
89 });
90 self
91 }
92
93 pub fn with_kernel_determinism(mut self, kernel: impl Into<String>) -> Self {
94 let kernel = kernel.into();
95 self.bundle.push(Obligation {
96 name: format_obligation_name(&kernel, "kernel_determinism"),
97 property: IsMSLKernelDeterministic::NAME,
98 declarations: vec![Declaration::new(kernel.clone(), Sort::named("MSLKernel"))],
99 assumptions: Vec::new(),
100 conclusion: Term::app("deterministic_kernel", [Term::var(kernel)]),
101 origin: self.bundle.origin.clone(),
102 tier: VerificationTier::External,
103 });
104 self
105 }
106
107 pub fn into_bundle(self) -> ObligationBundle {
108 self.bundle
109 }
110}
111
/// Builds an obligation name of the form `<symbol>_<suffix>`.
///
/// Every character of `symbol` other than an ASCII letter, ASCII digit or `_` is
/// replaced with `_`; `suffix` is appended verbatim. If the sanitized symbol would begin
/// with an ASCII digit, a `_` is prepended, because identifiers in Rust and Lean, and
/// SMT-LIB simple symbols, may not start with a digit. An empty `symbol` yields
/// `_<suffix>`.
fn format_obligation_name(symbol: &str, suffix: &str) -> String {
    // NOTE(review): the character filter suggests the exporters embed this name as an
    // identifier; confirm against the `export_*_bundle` implementations.
    // Capacity upper bound: optional leading `_`, one byte per sanitized char (every
    // replacement is the 1-byte `_`), the separator, and the suffix.
    let mut out = String::with_capacity(symbol.len() + suffix.len() + 2);
    if symbol.starts_with(|ch: char| ch.is_ascii_digit()) {
        out.push('_');
    }
    out.extend(symbol.chars().map(|ch| {
        if ch.is_ascii_alphanumeric() || ch == '_' {
            ch
        } else {
            '_'
        }
    }));
    out.push('_');
    out.push_str(suffix);
    out
}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ObligationBundle, Origin, export_kani_bundle, export_lean_bundle, export_smt_bundle,
    };

    /// Reference bundle with one obligation from each builder method.
    fn sample_bundle() -> ObligationBundle {
        GpuObligationBundle::metal_kernel(
            "borsalino_kernel",
            Origin::new("borsalino", "kernels::reduce"),
        )
        .with_buffer_alignment("input", 16)
        .with_workgroup_divisibility("threads_per_group", 32)
        .with_dispatch_limit("grid_size", 65_535)
        .with_kernel_determinism("reduce_kernel")
        .into_bundle()
    }

    #[test]
    fn gpu_bundle_contains_expected_obligations() {
        let bundle = sample_bundle();
        for expected in [
            "IsBufferAlignedTo16",
            "IsWorkgroupSizeDivisible",
            "IsDispatchWithinLimits",
            "IsMSLKernelDeterministic",
        ] {
            assert!(
                bundle
                    .obligations()
                    .iter()
                    .any(|obligation| obligation.property == expected),
                "missing obligation with property {expected}"
            );
        }
    }

    #[test]
    fn non_16_alignment_uses_generic_property() {
        let bundle = GpuObligationBundle::metal_kernel(
            "borsalino_kernel",
            Origin::new("borsalino", "kernels::copy"),
        )
        .with_buffer_alignment("scratch", 64)
        .into_bundle();

        let properties: Vec<&str> = bundle
            .obligations()
            .iter()
            .map(|obligation| obligation.property)
            .collect();
        assert_eq!(properties, ["IsBufferAligned"]);
    }

    #[test]
    fn obligation_names_are_sanitized() {
        let bundle = GpuObligationBundle::metal_kernel(
            "borsalino_kernel",
            Origin::new("borsalino", "kernels::reduce"),
        )
        .with_dispatch_limit("grid.size-x", 1024)
        .into_bundle();

        let names: Vec<&str> = bundle
            .obligations()
            .iter()
            .map(|obligation| obligation.name.as_str())
            .collect();
        assert_eq!(names, ["grid_size_x_dispatch_limit"]);
    }

    #[test]
    fn gpu_bundle_exports_through_all_backends() {
        let bundle = sample_bundle();

        let smt = export_smt_bundle(&bundle);
        let lean = export_lean_bundle("GpuVerify", &bundle);
        let kani = export_kani_bundle(&bundle);

        assert_eq!(smt.len(), 4);
        assert!(lean.contains("deterministic_kernel"));
        assert_eq!(kani.len(), 4);
        assert!(kani[0].source.contains("kani::assert"));
    }
}