1use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
/// A performance contract as loaded from a YAML file (see [`load_contract`]).
///
/// `kind`, `name` and `version` are required when deserializing; every other
/// field falls back to its default when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceContract {
    /// Document kind; [`verify_contract`] records a failure when this is empty.
    pub kind: String,
    /// Contract name, copied into `ContractVerification::contract_name`.
    pub name: String,
    /// Contract version string (not interpreted by this module).
    pub version: String,
    /// Kernel identifier, used to locate cached profiles at
    /// `/tmp/cgp-<kernel>-<size>.json`. Empty means a domain-specific contract.
    #[serde(default)]
    pub kernel: String,
    /// Target hardware description; informational only, not checked by
    /// [`verify_contract`].
    #[serde(default)]
    pub hardware: HardwareSpec,
    /// Size-keyed performance bounds.
    #[serde(default)]
    pub bounds: Vec<PerformanceBound>,
    /// Per-metric min/max bounds keyed by metric name.
    /// NOTE(review): parsed and serialized, but [`verify_contract`] does not
    /// currently evaluate these.
    #[serde(default)]
    pub metrics: std::collections::HashMap<String, MetricBound>,
    /// Falsification checks evaluated against a cached profile.
    #[serde(default)]
    pub falsification: Vec<FalsificationCheck>,
    // Any top-level keys not listed above are collected here via `flatten`,
    // so unknown fields survive a load/save round trip.
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
29
/// Hardware a contract was generated on / targets. All fields are optional.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HardwareSpec {
    /// GPU model name (e.g. as reported by `nvidia-smi --query-gpu=name`).
    pub gpu: Option<String>,
    /// CPU model name.
    pub cpu: Option<String>,
    /// GPU compute capability as a string, e.g. `"8.9"`.
    pub compute_capability: Option<String>,
}
36
/// One size-keyed performance bound inside a contract.
///
/// Every criterion is optional; [`verify_contract`] only checks the ones that
/// are present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceBound {
    /// Problem dimensions. Accepts a single integer or a list (see
    /// `deserialize_size`). Only the first element is used to pick the cached
    /// profile file; an empty list marks a purely structural bound.
    #[serde(default, deserialize_with = "deserialize_size")]
    pub size: Vec<u32>,
    /// Upper bound on the profile's `wall_clock_time_us`, in microseconds
    /// (checked with `<=`).
    #[serde(default)]
    pub max_time_us: Option<f64>,
    /// Lower bound on the profile's throughput in TFLOP/s (checked with `>=`).
    #[serde(default)]
    pub min_tflops: Option<f64>,
    /// Allowed regression in percent. Written by `run_generate`; not evaluated
    /// by [`verify_contract`].
    #[serde(default)]
    pub max_regression_pct: Option<f64>,
    /// Lower bound on the profile's bandwidth in GB/s (checked with `>=`).
    #[serde(default)]
    pub min_bandwidth_gbps: Option<f64>,
    // Unrecognised keys of this bound, kept so they survive a round trip.
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
53
54fn deserialize_size<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
56where
57 D: serde::Deserializer<'de>,
58{
59 use serde::de;
60
61 struct SizeVisitor;
62 impl<'de> de::Visitor<'de> for SizeVisitor {
63 type Value = Vec<u32>;
64 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
65 f.write_str("an integer or sequence of integers")
66 }
67 fn visit_u64<E: de::Error>(self, v: u64) -> Result<Vec<u32>, E> {
68 Ok(vec![v as u32])
69 }
70 fn visit_i64<E: de::Error>(self, v: i64) -> Result<Vec<u32>, E> {
71 Ok(vec![v as u32])
72 }
73 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Vec<u32>, A::Error> {
74 let mut v = Vec::new();
75 while let Some(elem) = seq.next_element::<u32>()? {
76 v.push(elem);
77 }
78 Ok(v)
79 }
80 fn visit_none<E: de::Error>(self) -> Result<Vec<u32>, E> {
81 Ok(Vec::new())
82 }
83 fn visit_unit<E: de::Error>(self) -> Result<Vec<u32>, E> {
84 Ok(Vec::new())
85 }
86 }
87 deserializer.deserialize_any(SizeVisitor)
88}
89
/// Optional inclusive-or-exclusive (unspecified here) min/max bounds on a named
/// metric, as listed under a contract's `metrics` map.
/// NOTE(review): these are parsed but not evaluated by `verify_contract`, so the
/// comparison semantics (inclusive vs exclusive) are not yet defined in code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricBound {
    /// Lower bound, if any.
    pub min: Option<f64>,
    /// Upper bound, if any.
    pub max: Option<f64>,
}
95
/// A named falsification check evaluated against a cached profile.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FalsificationCheck {
    /// Check identifier, e.g. `FALSIFY-GEMM-001`. `verify_contract` fails the
    /// check if this is empty.
    pub name: String,
    /// Human-readable description, echoed in FAIL/skip messages.
    #[serde(default)]
    pub description: String,
    /// Expression of the form `<field> <op> <threshold>` (whitespace-separated),
    /// e.g. `"tflops > 9.0"`, interpreted by `evaluate_check`.
    #[serde(default)]
    pub check: String,
    // Unrecognised keys of this check, kept so they survive a round trip.
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
106
/// Outcome of [`verify_contract`]: human-readable messages bucketed by result.
#[derive(Debug, Default)]
pub struct ContractVerification {
    /// Name of the contract that was verified.
    pub contract_name: String,
    /// Messages for checks that passed.
    pub passed: Vec<String>,
    /// Messages for checks that failed, including structural errors.
    pub failed: Vec<String>,
    /// Messages for checks that could not be evaluated.
    pub skipped: Vec<String>,
}

impl ContractVerification {
    /// Returns `true` when no check failed. Skipped checks do not count against
    /// the contract.
    #[must_use]
    pub fn is_pass(&self) -> bool {
        self.failed.is_empty()
    }
}
121
122pub fn load_contract(path: &Path) -> Result<PerformanceContract> {
124 let content = std::fs::read_to_string(path)
125 .with_context(|| format!("Failed to read contract: {}", path.display()))?;
126 let contract: PerformanceContract = serde_yaml_ng::from_str(&content)
127 .with_context(|| format!("Failed to parse contract: {}", path.display()))?;
128 Ok(contract)
129}
130
131pub fn load_contracts_dir(dir: &Path) -> Result<Vec<PerformanceContract>> {
133 let mut contracts = Vec::new();
134 if dir.is_dir() {
135 for entry in std::fs::read_dir(dir)? {
136 let entry = entry?;
137 let path = entry.path();
138 if path.extension().is_some_and(|e| e == "yaml" || e == "yml") {
139 match load_contract(&path) {
140 Ok(c) => contracts.push(c),
141 Err(e) => eprintln!("Warning: skipping {}: {e}", path.display()),
142 }
143 }
144 }
145 }
146 Ok(contracts)
147}
148
149pub fn verify_contract(contract: &PerformanceContract) -> ContractVerification {
153 let mut result = ContractVerification {
154 contract_name: contract.name.clone(),
155 passed: Vec::new(),
156 failed: Vec::new(),
157 skipped: Vec::new(),
158 };
159
160 if contract.kind.is_empty() {
162 result
163 .failed
164 .push("Contract missing 'kind' field".to_string());
165 } else {
166 result.passed.push(format!("kind: {}", contract.kind));
167 }
168
169 if contract.kernel.is_empty() {
170 result
171 .skipped
172 .push("No kernel field — domain-specific contract".to_string());
173 } else {
174 result.passed.push(format!("kernel: {}", contract.kernel));
175 }
176
177 for (i, bound) in contract.bounds.iter().enumerate() {
179 if bound.size.is_empty() {
180 result
182 .passed
183 .push(format!("Bound {i}: structural (no size)"));
184 continue;
185 }
186
187 let size = bound.size[0];
188 let profile_path = format!("/tmp/cgp-{}-{size}.json", contract.kernel);
189 let profile = std::path::Path::new(&profile_path)
190 .exists()
191 .then(|| crate::metrics::export::load_json(std::path::Path::new(&profile_path)).ok())
192 .flatten();
193
194 match profile {
195 Some(p) => {
196 if let Some(max_time) = bound.max_time_us {
198 if p.timing.wall_clock_time_us <= max_time {
199 result.passed.push(format!(
200 "Bound {i}: time {:.1}us <= {max_time:.1}us",
201 p.timing.wall_clock_time_us
202 ));
203 } else {
204 result.failed.push(format!(
205 "Bound {i}: time {:.1}us > {max_time:.1}us EXCEEDED",
206 p.timing.wall_clock_time_us
207 ));
208 }
209 }
210 if let Some(min_tflops) = bound.min_tflops {
212 if p.throughput.tflops >= min_tflops {
213 result.passed.push(format!(
214 "Bound {i}: {:.1} TFLOP/s >= {min_tflops:.1}",
215 p.throughput.tflops
216 ));
217 } else {
218 result.failed.push(format!(
219 "Bound {i}: {:.1} TFLOP/s < {min_tflops:.1} BELOW MINIMUM",
220 p.throughput.tflops
221 ));
222 }
223 }
224 if let Some(min_bw) = bound.min_bandwidth_gbps {
226 if p.throughput.bandwidth_gbps >= min_bw {
227 result.passed.push(format!(
228 "Bound {i}: {:.1} GB/s >= {min_bw:.1}",
229 p.throughput.bandwidth_gbps
230 ));
231 } else {
232 result.failed.push(format!(
233 "Bound {i}: {:.1} GB/s < {min_bw:.1} BELOW MINIMUM",
234 p.throughput.bandwidth_gbps
235 ));
236 }
237 }
238 }
239 None => {
240 result
242 .passed
243 .push(format!("Bound {i}: size {:?}", bound.size));
244 if bound.max_time_us.is_none()
245 && bound.min_tflops.is_none()
246 && bound.min_bandwidth_gbps.is_none()
247 {
248 result
249 .skipped
250 .push(format!("Bound {i}: no criteria specified"));
251 }
252 }
253 }
254 }
255
256 for check in &contract.falsification {
258 if check.name.is_empty() || check.check.is_empty() {
259 result.failed.push(format!(
260 "Falsification '{}': missing name or check",
261 check.name
262 ));
263 continue;
264 }
265
266 let size = contract
268 .bounds
269 .first()
270 .and_then(|b| b.size.first())
271 .copied()
272 .unwrap_or(512);
273 let profile_path = format!("/tmp/cgp-{}-{size}.json", contract.kernel);
274 let profile = std::path::Path::new(&profile_path)
275 .exists()
276 .then(|| crate::metrics::export::load_json(std::path::Path::new(&profile_path)).ok())
277 .flatten();
278
279 match profile {
280 Some(p) => {
281 let pass = evaluate_check(&check.check, &p);
282 if pass {
283 result.passed.push(format!("FALSIFY {}: PASS", check.name));
284 } else {
285 result.failed.push(format!(
286 "FALSIFY {}: FAIL ({})",
287 check.name, check.description
288 ));
289 }
290 }
291 None => {
292 result.skipped.push(format!(
293 "FALSIFY {}: {} (no profile at {profile_path})",
294 check.name, check.description
295 ));
296 }
297 }
298 }
299
300 result
301}
302
303fn evaluate_check(expr: &str, profile: &crate::metrics::catalog::FullProfile) -> bool {
306 let parts: Vec<&str> = expr.split_whitespace().collect();
307 if parts.len() != 3 {
308 return false;
309 }
310 let field = parts[0];
311 let op = parts[1];
312 let threshold: f64 = match parts[2].parse() {
313 Ok(v) => v,
314 Err(_) => return false,
315 };
316
317 let value = match field {
318 "tflops" => profile.throughput.tflops,
319 "wall_clock_time_us" => profile.timing.wall_clock_time_us,
320 "bandwidth_gbps" => profile.throughput.bandwidth_gbps,
321 "arithmetic_intensity" => profile.throughput.arithmetic_intensity,
322 "warp_execution_efficiency" => profile
323 .gpu_compute
324 .as_ref()
325 .map_or(0.0, |g| g.warp_execution_efficiency_pct),
326 "achieved_occupancy" => profile
327 .gpu_compute
328 .as_ref()
329 .map_or(0.0, |g| g.achieved_occupancy_pct),
330 "global_load_efficiency" => profile
331 .gpu_memory
332 .as_ref()
333 .map_or(0.0, |g| g.global_load_efficiency_pct),
334 _ => return false,
335 };
336
337 match op {
338 ">" => value > threshold,
339 "<" => value < threshold,
340 ">=" => value >= threshold,
341 "<=" => value <= threshold,
342 "==" => (value - threshold).abs() < 0.001,
343 _ => false,
344 }
345}
346
347pub fn run_verify(
349 contracts_dir: Option<&str>,
350 contract_file: Option<&str>,
351 self_verify: bool,
352 fail_on_regression: bool,
353) -> Result<()> {
354 let contracts = if let Some(dir) = contracts_dir {
355 load_contracts_dir(Path::new(dir))?
356 } else if let Some(file) = contract_file {
357 vec![load_contract(Path::new(file))?]
358 } else if self_verify {
359 let dir = Path::new("contracts/cgp");
360 if dir.exists() {
361 load_contracts_dir(dir)?
362 } else {
363 println!("No contracts found at contracts/cgp/");
364 return Ok(());
365 }
366 } else {
367 anyhow::bail!("Specify --contracts-dir, --contract, or --self");
368 };
369
370 println!("\n=== cgp Contract Verification ===\n");
371 let mut total_pass = 0;
372 let mut total_fail = 0;
373 let mut total_skip = 0;
374
375 for c in &contracts {
376 let result = verify_contract(c);
377 let status = if result.is_pass() {
378 "\x1b[32mPASS\x1b[0m"
379 } else {
380 "\x1b[31mFAIL\x1b[0m"
381 };
382 println!(
383 " {} {} ({} pass, {} fail, {} skip)",
384 status,
385 c.name,
386 result.passed.len(),
387 result.failed.len(),
388 result.skipped.len()
389 );
390 total_pass += result.passed.len();
391 total_fail += result.failed.len();
392 total_skip += result.skipped.len();
393 }
394
395 println!("\n Total: {total_pass} pass, {total_fail} fail, {total_skip} skip");
396 if total_fail > 0 && fail_on_regression {
397 anyhow::bail!("{total_fail} contract verification(s) failed");
398 }
399 println!();
400 Ok(())
401}
402
403pub fn run_generate(kernel: &str, size: u32, tolerance: f64) -> Result<()> {
405 let profile_path = format!("/tmp/cgp-{kernel}-{size}.json");
407 let profile = if std::path::Path::new(&profile_path).exists() {
408 Some(crate::metrics::export::load_json(std::path::Path::new(
409 &profile_path,
410 ))?)
411 } else {
412 None
413 };
414
415 let (time_us, tflops) = match &profile {
416 Some(p) => (p.timing.wall_clock_time_us, p.throughput.tflops),
417 None => {
418 let flops = 2.0 * (size as f64).powi(3);
420 let est_time = 23.2 * (size as f64 / 512.0).powi(3); let est_tflops = flops / (est_time * 1e-6) / 1e12;
422 (est_time, est_tflops)
423 }
424 };
425
426 let max_time = time_us * (1.0 + tolerance / 100.0);
427 let min_tflops = tflops * (1.0 - tolerance / 100.0);
428
429 let gpu_name = std::process::Command::new("nvidia-smi")
431 .args(["--query-gpu=name", "--format=csv,noheader"])
432 .output()
433 .ok()
434 .filter(|o| o.status.success())
435 .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
436 .unwrap_or_else(|| "Unknown GPU".to_string());
437
438 let contract_yaml = format!(
439 r#"# Generated by cgp contract generate
440# Kernel: {kernel} at size {size}x{size}x{size}
441# Tolerance: {tolerance}%
442kind: PerformanceContract
443name: {kernel}-{size}
444version: "1.0.0"
445kernel: {kernel}
446hardware:
447 gpu: "{gpu_name}"
448 compute_capability: "8.9"
449
450bounds:
451 - size: [{size}, {size}, {size}]
452 max_time_us: {max_time:.1}
453 min_tflops: {min_tflops:.1}
454 max_regression_pct: {tolerance}
455
456metrics:
457 warp_execution_efficiency:
458 min: 95.0
459 achieved_occupancy:
460 min: 25.0
461
462falsification:
463 - name: FALSIFY-{kernel_upper}-001
464 description: "{kernel} must achieve >{min_tflops:.1} TFLOP/s at {size}x{size}"
465 check: "tflops > {min_tflops:.1}"
466 - name: FALSIFY-{kernel_upper}-002
467 description: "{kernel} must complete in <{max_time:.1}us at {size}x{size}"
468 check: "wall_clock_time_us < {max_time:.1}"
469"#,
470 kernel = kernel,
471 size = size,
472 tolerance = tolerance,
473 gpu_name = gpu_name,
474 max_time = max_time,
475 min_tflops = min_tflops,
476 kernel_upper = kernel.to_uppercase().replace('-', "_"),
477 );
478
479 let contracts_dir = std::path::Path::new("contracts/cgp");
481 std::fs::create_dir_all(contracts_dir)?;
482 let contract_path = contracts_dir.join(format!("{kernel}-{size}-v1.yaml"));
483 std::fs::write(&contract_path, &contract_yaml)?;
484
485 println!("Generated contract: {}", contract_path.display());
486 println!();
487 print!("{contract_yaml}");
488
489 Ok(())
490}
491
#[cfg(test)]
mod tests {
    //! Unit tests for contract parsing and verification.
    //!
    //! NOTE: the `verify_contract` tests assume there is no cached profile at
    //! `/tmp/cgp-gemm_cta_wmma_fp16-512.json`. If one exists, their outcome
    //! depends on its contents.
    use super::*;

    /// A well-formed GEMM contract with one sized bound, one metric and one
    /// falsification check.
    fn sample_contract() -> PerformanceContract {
        PerformanceContract {
            kind: "PerformanceContract".to_string(),
            name: "test-gemm-contract".to_string(),
            version: "1.0.0".to_string(),
            kernel: "gemm_cta_wmma_fp16".to_string(),
            hardware: HardwareSpec {
                gpu: Some("NVIDIA GeForce RTX 4090".to_string()),
                cpu: None,
                compute_capability: Some("8.9".to_string()),
            },
            bounds: vec![PerformanceBound {
                size: vec![512, 512, 512],
                max_time_us: Some(30.0),
                min_tflops: Some(9.0),
                max_regression_pct: Some(10.0),
                min_bandwidth_gbps: None,
                extra: Default::default(),
            }],
            metrics: {
                let mut m = std::collections::HashMap::new();
                m.insert(
                    "warp_execution_efficiency".to_string(),
                    MetricBound {
                        min: Some(95.0),
                        max: None,
                    },
                );
                m
            },
            falsification: vec![FalsificationCheck {
                name: "FALSIFY-TEST-001".to_string(),
                description: "CTA WMMA must achieve >9 TFLOP/s".to_string(),
                check: "tflops > 9.0".to_string(),
                extra: Default::default(),
            }],
            extra: Default::default(),
        }
    }

    #[test]
    fn test_verify_valid_contract() {
        let contract = sample_contract();
        let result = verify_contract(&contract);
        assert!(result.is_pass());
        assert!(!result.passed.is_empty());
    }

    #[test]
    fn test_verify_missing_kernel_is_skipped() {
        let mut contract = sample_contract();
        contract.kernel = String::new();
        let result = verify_contract(&contract);
        // A domain-specific contract (no kernel) is skipped, not failed.
        assert!(result.is_pass());
        assert!(!result.skipped.is_empty());
    }

    #[test]
    fn test_contract_yaml_roundtrip() {
        let contract = sample_contract();
        let yaml = serde_yaml_ng::to_string(&contract).unwrap();
        let parsed: PerformanceContract = serde_yaml_ng::from_str(&yaml).unwrap();
        assert_eq!(parsed.name, contract.name);
        assert_eq!(parsed.kernel, contract.kernel);
        assert_eq!(parsed.bounds.len(), 1);
        assert_eq!(parsed.bounds[0].size, vec![512, 512, 512]);
    }

    #[test]
    fn test_contract_falsification_checks() {
        let contract = sample_contract();
        let result = verify_contract(&contract);
        // Without a cached profile the falsification check is skipped.
        assert!(result.is_pass());
        assert!(!result.skipped.is_empty());
    }

    #[test]
    fn test_bound_size_accepts_scalar_null_and_missing() {
        // Exercises the non-sequence branches of `deserialize_size`.
        let yaml = "kind: PerformanceContract\n\
                    name: scalar-size\n\
                    version: \"1.0.0\"\n\
                    bounds:\n  \
                      - size: 1024\n    \
                        max_time_us: 50.0\n  \
                      - size: ~\n  \
                      - min_tflops: 1.0\n";
        let parsed: PerformanceContract = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(parsed.bounds.len(), 3);
        assert_eq!(parsed.bounds[0].size, vec![1024]);
        assert!(parsed.bounds[1].size.is_empty());
        assert!(parsed.bounds[2].size.is_empty());
        assert_eq!(parsed.bounds[2].min_tflops, Some(1.0));
    }
}