1use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
/// A performance contract as read from a YAML document (see `load_contract`).
///
/// `kind`, `name` and `version` carry no serde default and are therefore
/// required in the document; every other field falls back to its `Default`
/// value when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceContract {
    /// Contract kind tag; verification records a failure when it is empty.
    pub kind: String,
    /// Contract name, copied into `ContractVerification::contract_name`.
    pub name: String,
    /// Contract version string; not interpreted by the code in this file.
    pub version: String,
    /// Kernel identifier used to build the profile path; may be empty.
    #[serde(default)]
    pub kernel: String,
    /// Hardware the contract was written for.
    #[serde(default)]
    pub hardware: HardwareSpec,
    /// Per-size performance bounds, verified in order.
    #[serde(default)]
    pub bounds: Vec<PerformanceBound>,
    /// Named metric ranges.
    // NOTE(review): `verify_contract` in this file does not read this map.
    #[serde(default)]
    pub metrics: std::collections::HashMap<String, MetricBound>,
    /// Falsification checks evaluated against a stored profile.
    #[serde(default)]
    pub falsification: Vec<FalsificationCheck>,
    /// Catch-all for any top-level keys not named above (serde `flatten`).
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
29
/// Optional description of the hardware a contract targets.
///
/// All fields are optional; the derived `Default` leaves each one `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HardwareSpec {
    /// GPU name, if given.
    pub gpu: Option<String>,
    /// CPU name, if given.
    pub cpu: Option<String>,
    /// Compute capability as a string (e.g. "8.9"), if given.
    pub compute_capability: Option<String>,
}
36
/// One performance bound inside a contract.
///
/// Every field is optional in the document; an absent `size` deserializes to
/// an empty vector and absent thresholds to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceBound {
    /// Problem size; accepts either a single integer or a sequence of
    /// integers (see `deserialize_size`). Only the first element is used when
    /// locating a profile.
    #[serde(default, deserialize_with = "deserialize_size")]
    pub size: Vec<u32>,
    /// Upper limit compared against the profile's `wall_clock_time_us`.
    #[serde(default)]
    pub max_time_us: Option<f64>,
    /// Lower limit compared against the profile's `tflops`.
    #[serde(default)]
    pub min_tflops: Option<f64>,
    /// Allowed regression percentage.
    // NOTE(review): none of the checks in this file evaluate this field.
    #[serde(default)]
    pub max_regression_pct: Option<f64>,
    /// Lower limit compared against the profile's `bandwidth_gbps`.
    #[serde(default)]
    pub min_bandwidth_gbps: Option<f64>,
    /// Catch-all for any other keys on the bound (serde `flatten`).
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
53
54fn deserialize_size<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
56where
57 D: serde::Deserializer<'de>,
58{
59 use serde::de;
60
61 struct SizeVisitor;
62 impl<'de> de::Visitor<'de> for SizeVisitor {
63 type Value = Vec<u32>;
64 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
65 f.write_str("an integer or sequence of integers")
66 }
67 fn visit_u64<E: de::Error>(self, v: u64) -> Result<Vec<u32>, E> {
68 Ok(vec![v as u32])
69 }
70 fn visit_i64<E: de::Error>(self, v: i64) -> Result<Vec<u32>, E> {
71 Ok(vec![v as u32])
72 }
73 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Vec<u32>, A::Error> {
74 let mut v = Vec::new();
75 while let Some(elem) = seq.next_element::<u32>()? {
76 v.push(elem);
77 }
78 Ok(v)
79 }
80 fn visit_none<E: de::Error>(self) -> Result<Vec<u32>, E> {
81 Ok(Vec::new())
82 }
83 fn visit_unit<E: de::Error>(self) -> Result<Vec<u32>, E> {
84 Ok(Vec::new())
85 }
86 }
87 deserializer.deserialize_any(SizeVisitor)
88}
89
/// Optional lower/upper limits for a named metric in
/// `PerformanceContract::metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricBound {
    /// Lower limit, if any.
    pub min: Option<f64>,
    /// Upper limit, if any.
    pub max: Option<f64>,
}
95
/// A named check that is evaluated against a stored kernel profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FalsificationCheck {
    /// Identifier shown in verification output; required in the document.
    pub name: String,
    /// Human-readable explanation, included in failure/skip messages.
    #[serde(default)]
    pub description: String,
    /// Expression of the form `<field> <op> <number>` (see `evaluate_check`);
    /// an empty value makes verification record a failure.
    #[serde(default)]
    pub check: String,
    /// Catch-all for any other keys on the check (serde `flatten`).
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
106
/// Outcome of verifying a single contract: one message per passed, failed
/// and skipped item.
#[derive(Debug)]
pub struct ContractVerification {
    /// Name of the contract these results belong to.
    pub contract_name: String,
    /// Messages for items that passed.
    pub passed: Vec<String>,
    /// Messages for items that failed.
    pub failed: Vec<String>,
    /// Messages for items that could not be evaluated.
    pub skipped: Vec<String>,
}

impl ContractVerification {
    /// Returns `true` when no failure was recorded; skipped items do not
    /// count against the contract.
    pub fn is_pass(&self) -> bool {
        let Self { failed, .. } = self;
        failed.is_empty()
    }
}
121
122pub fn load_contract(path: &Path) -> Result<PerformanceContract> {
124 let content = std::fs::read_to_string(path)
125 .with_context(|| format!("Failed to read contract: {}", path.display()))?;
126 let contract: PerformanceContract = serde_yaml_ng::from_str(&content)
127 .with_context(|| format!("Failed to parse contract: {}", path.display()))?;
128 Ok(contract)
129}
130
131pub fn load_contracts_dir(dir: &Path) -> Result<Vec<PerformanceContract>> {
133 let mut contracts = Vec::new();
134 if dir.is_dir() {
135 for entry in std::fs::read_dir(dir)? {
136 let entry = entry?;
137 let path = entry.path();
138 if path.extension().is_some_and(|e| e == "yaml" || e == "yml") {
139 match load_contract(&path) {
140 Ok(c) => contracts.push(c),
141 Err(e) => eprintln!("Warning: skipping {}: {e}", path.display()),
142 }
143 }
144 }
145 }
146 Ok(contracts)
147}
148
149pub fn verify_contract(contract: &PerformanceContract) -> ContractVerification {
153 let mut result = ContractVerification {
154 contract_name: contract.name.clone(),
155 passed: Vec::new(),
156 failed: Vec::new(),
157 skipped: Vec::new(),
158 };
159 validate_contract_metadata(contract, &mut result);
160 for (i, bound) in contract.bounds.iter().enumerate() {
161 verify_single_bound(contract, bound, i, &mut result);
162 }
163 for check in &contract.falsification {
164 verify_single_falsification(contract, check, &mut result);
165 }
166 result
167}
168
169fn validate_contract_metadata(contract: &PerformanceContract, result: &mut ContractVerification) {
171 if contract.kind.is_empty() {
172 result
173 .failed
174 .push("Contract missing 'kind' field".to_string());
175 } else {
176 result.passed.push(format!("kind: {}", contract.kind));
177 }
178 if contract.kernel.is_empty() {
179 result
180 .skipped
181 .push("No kernel field — domain-specific contract".to_string());
182 } else {
183 result.passed.push(format!("kernel: {}", contract.kernel));
184 }
185}
186
187fn verify_single_bound(
189 contract: &PerformanceContract,
190 bound: &PerformanceBound,
191 i: usize,
192 result: &mut ContractVerification,
193) {
194 if bound.size.is_empty() {
195 result
196 .passed
197 .push(format!("Bound {i}: structural (no size)"));
198 return;
199 }
200 let size = bound.size[0];
201 match load_kernel_profile(&contract.kernel, size) {
202 Some(p) => check_bound_thresholds(bound, i, &p, result),
203 None => check_bound_structural(bound, i, result),
204 }
205}
206
207fn load_kernel_profile(kernel: &str, size: u32) -> Option<crate::metrics::catalog::FullProfile> {
209 let profile_path = format!("/tmp/cgp-{kernel}-{size}.json");
210 let path = std::path::Path::new(&profile_path);
211 if !path.exists() {
212 return None;
213 }
214 crate::metrics::export::load_json(path).ok()
215}
216
217fn check_bound_thresholds(
219 bound: &PerformanceBound,
220 i: usize,
221 p: &crate::metrics::catalog::FullProfile,
222 result: &mut ContractVerification,
223) {
224 check_max_time(bound, i, p, result);
225 check_min_tflops(bound, i, p, result);
226 check_min_bandwidth(bound, i, p, result);
227}
228
229fn check_max_time(
231 bound: &PerformanceBound,
232 i: usize,
233 p: &crate::metrics::catalog::FullProfile,
234 result: &mut ContractVerification,
235) {
236 let Some(max_time) = bound.max_time_us else {
237 return;
238 };
239 let actual = p.timing.wall_clock_time_us;
240 if actual <= max_time {
241 result
242 .passed
243 .push(format!("Bound {i}: time {actual:.1}us <= {max_time:.1}us"));
244 } else {
245 result.failed.push(format!(
246 "Bound {i}: time {actual:.1}us > {max_time:.1}us EXCEEDED"
247 ));
248 }
249}
250
251fn check_min_tflops(
253 bound: &PerformanceBound,
254 i: usize,
255 p: &crate::metrics::catalog::FullProfile,
256 result: &mut ContractVerification,
257) {
258 let Some(min_tflops) = bound.min_tflops else {
259 return;
260 };
261 let actual = p.throughput.tflops;
262 if actual >= min_tflops {
263 result
264 .passed
265 .push(format!("Bound {i}: {actual:.1} TFLOP/s >= {min_tflops:.1}"));
266 } else {
267 result.failed.push(format!(
268 "Bound {i}: {actual:.1} TFLOP/s < {min_tflops:.1} BELOW MINIMUM"
269 ));
270 }
271}
272
273fn check_min_bandwidth(
275 bound: &PerformanceBound,
276 i: usize,
277 p: &crate::metrics::catalog::FullProfile,
278 result: &mut ContractVerification,
279) {
280 let Some(min_bw) = bound.min_bandwidth_gbps else {
281 return;
282 };
283 let actual = p.throughput.bandwidth_gbps;
284 if actual >= min_bw {
285 result
286 .passed
287 .push(format!("Bound {i}: {actual:.1} GB/s >= {min_bw:.1}"));
288 } else {
289 result.failed.push(format!(
290 "Bound {i}: {actual:.1} GB/s < {min_bw:.1} BELOW MINIMUM"
291 ));
292 }
293}
294
295fn check_bound_structural(bound: &PerformanceBound, i: usize, result: &mut ContractVerification) {
297 result
298 .passed
299 .push(format!("Bound {i}: size {:?}", bound.size));
300 if bound.max_time_us.is_none()
301 && bound.min_tflops.is_none()
302 && bound.min_bandwidth_gbps.is_none()
303 {
304 result
305 .skipped
306 .push(format!("Bound {i}: no criteria specified"));
307 }
308}
309
310fn verify_single_falsification(
312 contract: &PerformanceContract,
313 check: &FalsificationCheck,
314 result: &mut ContractVerification,
315) {
316 if check.name.is_empty() || check.check.is_empty() {
317 result.failed.push(format!(
318 "Falsification '{}': missing name or check",
319 check.name
320 ));
321 return;
322 }
323 let size = contract
324 .bounds
325 .first()
326 .and_then(|b| b.size.first())
327 .copied()
328 .unwrap_or(512);
329 let profile_path = format!("/tmp/cgp-{}-{size}.json", contract.kernel);
330 match load_kernel_profile(&contract.kernel, size) {
331 Some(p) => {
332 if evaluate_check(&check.check, &p) {
333 result.passed.push(format!("FALSIFY {}: PASS", check.name));
334 } else {
335 result.failed.push(format!(
336 "FALSIFY {}: FAIL ({})",
337 check.name, check.description
338 ));
339 }
340 }
341 None => {
342 result.skipped.push(format!(
343 "FALSIFY {}: {} (no profile at {profile_path})",
344 check.name, check.description
345 ));
346 }
347 }
348}
349
350fn evaluate_check(expr: &str, profile: &crate::metrics::catalog::FullProfile) -> bool {
353 let parts: Vec<&str> = expr.split_whitespace().collect();
354 if parts.len() != 3 {
355 return false;
356 }
357 let field = parts[0];
358 let op = parts[1];
359 let threshold: f64 = match parts[2].parse() {
360 Ok(v) => v,
361 Err(_) => return false,
362 };
363
364 let value = match field {
365 "tflops" => profile.throughput.tflops,
366 "wall_clock_time_us" => profile.timing.wall_clock_time_us,
367 "bandwidth_gbps" => profile.throughput.bandwidth_gbps,
368 "arithmetic_intensity" => profile.throughput.arithmetic_intensity,
369 "warp_execution_efficiency" => profile
370 .gpu_compute
371 .as_ref()
372 .map_or(0.0, |g| g.warp_execution_efficiency_pct),
373 "achieved_occupancy" => profile
374 .gpu_compute
375 .as_ref()
376 .map_or(0.0, |g| g.achieved_occupancy_pct),
377 "global_load_efficiency" => profile
378 .gpu_memory
379 .as_ref()
380 .map_or(0.0, |g| g.global_load_efficiency_pct),
381 _ => return false,
382 };
383
384 match op {
385 ">" => value > threshold,
386 "<" => value < threshold,
387 ">=" => value >= threshold,
388 "<=" => value <= threshold,
389 "==" => (value - threshold).abs() < 0.001,
390 _ => false,
391 }
392}
393
394pub fn run_verify(
396 contracts_dir: Option<&str>,
397 contract_file: Option<&str>,
398 self_verify: bool,
399 fail_on_regression: bool,
400) -> Result<()> {
401 let Some(contracts) = resolve_contracts_input(contracts_dir, contract_file, self_verify)?
402 else {
403 return Ok(());
404 };
405
406 println!("\n=== cgp Contract Verification ===\n");
407 let totals = run_verify_all(&contracts);
408 println!(
409 "\n Total: {} pass, {} fail, {} skip",
410 totals.pass, totals.fail, totals.skip
411 );
412 if totals.fail > 0 && fail_on_regression {
413 anyhow::bail!("{} contract verification(s) failed", totals.fail);
414 }
415 println!();
416 Ok(())
417}
418
419fn resolve_contracts_input(
422 contracts_dir: Option<&str>,
423 contract_file: Option<&str>,
424 self_verify: bool,
425) -> Result<Option<Vec<PerformanceContract>>> {
426 if let Some(dir) = contracts_dir {
427 return Ok(Some(load_contracts_dir(Path::new(dir))?));
428 }
429 if let Some(file) = contract_file {
430 return Ok(Some(vec![load_contract(Path::new(file))?]));
431 }
432 if self_verify {
433 let dir = Path::new("contracts/cgp");
434 if !dir.exists() {
435 println!("No contracts found at contracts/cgp/");
436 return Ok(None);
437 }
438 return Ok(Some(load_contracts_dir(dir)?));
439 }
440 anyhow::bail!("Specify --contracts-dir, --contract, or --self");
441}
442
/// Running totals of passed, failed and skipped items across all verified
/// contracts; starts at zero via `Default`.
#[derive(Default)]
struct VerifyTotals {
    /// Number of passed items.
    pass: usize,
    /// Number of failed items.
    fail: usize,
    /// Number of skipped items.
    skip: usize,
}
449
450fn run_verify_all(contracts: &[PerformanceContract]) -> VerifyTotals {
452 let mut totals = VerifyTotals::default();
453 for c in contracts {
454 let result = verify_contract(c);
455 print_contract_status(c, &result);
456 totals.pass += result.passed.len();
457 totals.fail += result.failed.len();
458 totals.skip += result.skipped.len();
459 }
460 totals
461}
462
463fn print_contract_status(c: &PerformanceContract, result: &ContractVerification) {
464 let status = if result.is_pass() {
465 "\x1b[32mPASS\x1b[0m"
466 } else {
467 "\x1b[31mFAIL\x1b[0m"
468 };
469 println!(
470 " {} {} ({} pass, {} fail, {} skip)",
471 status,
472 c.name,
473 result.passed.len(),
474 result.failed.len(),
475 result.skipped.len()
476 );
477}
478
479pub fn run_generate(kernel: &str, size: u32, tolerance: f64) -> Result<()> {
481 let profile_path = format!("/tmp/cgp-{kernel}-{size}.json");
483 let profile = if std::path::Path::new(&profile_path).exists() {
484 Some(crate::metrics::export::load_json(std::path::Path::new(
485 &profile_path,
486 ))?)
487 } else {
488 None
489 };
490
491 let (time_us, tflops) = match &profile {
492 Some(p) => (p.timing.wall_clock_time_us, p.throughput.tflops),
493 None => {
494 let flops = 2.0 * (size as f64).powi(3);
496 let est_time = 23.2 * (size as f64 / 512.0).powi(3); let est_tflops = flops / (est_time * 1e-6) / 1e12;
498 (est_time, est_tflops)
499 }
500 };
501
502 let max_time = time_us * (1.0 + tolerance / 100.0);
503 let min_tflops = tflops * (1.0 - tolerance / 100.0);
504
505 let gpu_name = std::process::Command::new("nvidia-smi")
507 .args(["--query-gpu=name", "--format=csv,noheader"])
508 .output()
509 .ok()
510 .filter(|o| o.status.success())
511 .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
512 .unwrap_or_else(|| "Unknown GPU".to_string());
513
514 let contract_yaml = format!(
515 r#"# Generated by cgp contract generate
516# Kernel: {kernel} at size {size}x{size}x{size}
517# Tolerance: {tolerance}%
518kind: PerformanceContract
519name: {kernel}-{size}
520version: "1.0.0"
521kernel: {kernel}
522hardware:
523 gpu: "{gpu_name}"
524 compute_capability: "8.9"
525
526bounds:
527 - size: [{size}, {size}, {size}]
528 max_time_us: {max_time:.1}
529 min_tflops: {min_tflops:.1}
530 max_regression_pct: {tolerance}
531
532metrics:
533 warp_execution_efficiency:
534 min: 95.0
535 achieved_occupancy:
536 min: 25.0
537
538falsification:
539 - name: FALSIFY-{kernel_upper}-001
540 description: "{kernel} must achieve >{min_tflops:.1} TFLOP/s at {size}x{size}"
541 check: "tflops > {min_tflops:.1}"
542 - name: FALSIFY-{kernel_upper}-002
543 description: "{kernel} must complete in <{max_time:.1}us at {size}x{size}"
544 check: "wall_clock_time_us < {max_time:.1}"
545"#,
546 kernel = kernel,
547 size = size,
548 tolerance = tolerance,
549 gpu_name = gpu_name,
550 max_time = max_time,
551 min_tflops = min_tflops,
552 kernel_upper = kernel.to_uppercase().replace('-', "_"),
553 );
554
555 let contracts_dir = std::path::Path::new("contracts/cgp");
557 std::fs::create_dir_all(contracts_dir)?;
558 let contract_path = contracts_dir.join(format!("{kernel}-{size}-v1.yaml"));
559 std::fs::write(&contract_path, &contract_yaml)?;
560
561 println!("Generated contract: {}", contract_path.display());
562 println!();
563 print!("{contract_yaml}");
564
565 Ok(())
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fully-populated GEMM contract with one bound, one metric and
    /// one falsification check.
    fn sample_contract() -> PerformanceContract {
        let bound = PerformanceBound {
            size: vec![512, 512, 512],
            max_time_us: Some(30.0),
            min_tflops: Some(9.0),
            max_regression_pct: Some(10.0),
            min_bandwidth_gbps: None,
            extra: Default::default(),
        };
        let falsify = FalsificationCheck {
            name: String::from("FALSIFY-TEST-001"),
            description: String::from("CTA WMMA must achieve >9 TFLOP/s"),
            check: String::from("tflops > 9.0"),
            extra: Default::default(),
        };
        let metrics = std::collections::HashMap::from([(
            String::from("warp_execution_efficiency"),
            MetricBound {
                min: Some(95.0),
                max: None,
            },
        )]);

        PerformanceContract {
            kind: String::from("PerformanceContract"),
            name: String::from("test-gemm-contract"),
            version: String::from("1.0.0"),
            kernel: String::from("gemm_cta_wmma_fp16"),
            hardware: HardwareSpec {
                gpu: Some(String::from("NVIDIA GeForce RTX 4090")),
                cpu: None,
                compute_capability: Some(String::from("8.9")),
            },
            bounds: vec![bound],
            metrics,
            falsification: vec![falsify],
            extra: Default::default(),
        }
    }

    // NOTE(review): these tests go through `load_kernel_profile`, which
    // probes /tmp for a stored profile, so a leftover profile file for the
    // sample kernel can change their outcome.

    /// A well-formed contract verifies without failures.
    #[test]
    fn test_verify_valid_contract() {
        let outcome = verify_contract(&sample_contract());
        assert!(outcome.is_pass());
        assert!(!outcome.passed.is_empty());
    }

    /// An empty kernel is reported as skipped, not as a failure.
    #[test]
    fn test_verify_missing_kernel_is_skipped() {
        let contract = PerformanceContract {
            kernel: String::new(),
            ..sample_contract()
        };
        let outcome = verify_contract(&contract);
        assert!(outcome.is_pass());
        assert!(!outcome.skipped.is_empty());
    }

    /// Serializing to YAML and parsing back preserves the key fields.
    #[test]
    fn test_contract_yaml_roundtrip() {
        let original = sample_contract();
        let yaml = serde_yaml_ng::to_string(&original).unwrap();
        let reparsed: PerformanceContract = serde_yaml_ng::from_str(&yaml).unwrap();
        assert_eq!(reparsed.name, original.name);
        assert_eq!(reparsed.kernel, original.kernel);
        assert_eq!(reparsed.bounds.len(), 1);
        assert_eq!(reparsed.bounds[0].size, vec![512, 512, 512]);
    }

    /// Without a stored profile the falsification check ends up skipped.
    #[test]
    fn test_contract_falsification_checks() {
        let outcome = verify_contract(&sample_contract());
        assert!(outcome.is_pass());
        assert!(!outcome.skipped.is_empty());
    }
}