Skip to main content

cgp/analysis/
contracts.rs

1//! Performance contract verification (CI/CD gate).
2//! Extends provable-contracts framework to performance bounds.
3//! See spec section 3.4 and 7.1.
4
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
9/// A performance contract loaded from YAML.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct PerformanceContract {
12    pub kind: String,
13    pub name: String,
14    pub version: String,
15    #[serde(default)]
16    pub kernel: String,
17    #[serde(default)]
18    pub hardware: HardwareSpec,
19    #[serde(default)]
20    pub bounds: Vec<PerformanceBound>,
21    #[serde(default)]
22    pub metrics: std::collections::HashMap<String, MetricBound>,
23    #[serde(default)]
24    pub falsification: Vec<FalsificationCheck>,
25    /// Absorb any extra fields from domain-specific contract schemas
26    #[serde(flatten, default)]
27    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
28}
29
30#[derive(Debug, Clone, Default, Serialize, Deserialize)]
31pub struct HardwareSpec {
32    pub gpu: Option<String>,
33    pub cpu: Option<String>,
34    pub compute_capability: Option<String>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct PerformanceBound {
39    #[serde(default, deserialize_with = "deserialize_size")]
40    pub size: Vec<u32>,
41    #[serde(default)]
42    pub max_time_us: Option<f64>,
43    #[serde(default)]
44    pub min_tflops: Option<f64>,
45    #[serde(default)]
46    pub max_regression_pct: Option<f64>,
47    #[serde(default)]
48    pub min_bandwidth_gbps: Option<f64>,
49    /// Absorb domain-specific bound fields (operation, competitor, etc.)
50    #[serde(flatten, default)]
51    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
52}
53
54/// Accept both `size: 1024` (single int) and `size: [1024, 1024, 1024]` (sequence).
55fn deserialize_size<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
56where
57    D: serde::Deserializer<'de>,
58{
59    use serde::de;
60
61    struct SizeVisitor;
62    impl<'de> de::Visitor<'de> for SizeVisitor {
63        type Value = Vec<u32>;
64        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
65            f.write_str("an integer or sequence of integers")
66        }
67        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Vec<u32>, E> {
68            Ok(vec![v as u32])
69        }
70        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Vec<u32>, E> {
71            Ok(vec![v as u32])
72        }
73        fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Vec<u32>, A::Error> {
74            let mut v = Vec::new();
75            while let Some(elem) = seq.next_element::<u32>()? {
76                v.push(elem);
77            }
78            Ok(v)
79        }
80        fn visit_none<E: de::Error>(self) -> Result<Vec<u32>, E> {
81            Ok(Vec::new())
82        }
83        fn visit_unit<E: de::Error>(self) -> Result<Vec<u32>, E> {
84            Ok(Vec::new())
85        }
86    }
87    deserializer.deserialize_any(SizeVisitor)
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct MetricBound {
92    pub min: Option<f64>,
93    pub max: Option<f64>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct FalsificationCheck {
98    pub name: String,
99    #[serde(default)]
100    pub description: String,
101    #[serde(default)]
102    pub check: String,
103    #[serde(flatten, default)]
104    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
105}
106
/// Outcome of verifying one contract: human-readable messages sorted into
/// passed, failed and skipped buckets.
#[derive(Debug)]
pub struct ContractVerification {
    /// Name of the contract this report belongs to.
    pub contract_name: String,
    /// Messages for checks that succeeded.
    pub passed: Vec<String>,
    /// Messages for checks that did not hold.
    pub failed: Vec<String>,
    /// Messages for checks that could not be evaluated.
    pub skipped: Vec<String>,
}

impl ContractVerification {
    /// Returns `true` when no failure was recorded; skipped checks do not
    /// count against the contract.
    pub fn is_pass(&self) -> bool {
        let failures = &self.failed;
        failures.is_empty()
    }
}
121
122/// Load a performance contract from a YAML file.
123pub fn load_contract(path: &Path) -> Result<PerformanceContract> {
124    let content = std::fs::read_to_string(path)
125        .with_context(|| format!("Failed to read contract: {}", path.display()))?;
126    let contract: PerformanceContract = serde_yaml_ng::from_str(&content)
127        .with_context(|| format!("Failed to parse contract: {}", path.display()))?;
128    Ok(contract)
129}
130
131/// Load all contracts from a directory.
132pub fn load_contracts_dir(dir: &Path) -> Result<Vec<PerformanceContract>> {
133    let mut contracts = Vec::new();
134    if dir.is_dir() {
135        for entry in std::fs::read_dir(dir)? {
136            let entry = entry?;
137            let path = entry.path();
138            if path.extension().is_some_and(|e| e == "yaml" || e == "yml") {
139                match load_contract(&path) {
140                    Ok(c) => contracts.push(c),
141                    Err(e) => eprintln!("Warning: skipping {}: {e}", path.display()),
142                }
143            }
144        }
145    }
146    Ok(contracts)
147}
148
149/// Verify a contract against measured values.
150/// Loads saved profiles from /tmp/cgp-{kernel}-{size}.json if available,
151/// and checks performance bounds + falsification expressions.
152pub fn verify_contract(contract: &PerformanceContract) -> ContractVerification {
153    let mut result = ContractVerification {
154        contract_name: contract.name.clone(),
155        passed: Vec::new(),
156        failed: Vec::new(),
157        skipped: Vec::new(),
158    };
159    validate_contract_metadata(contract, &mut result);
160    for (i, bound) in contract.bounds.iter().enumerate() {
161        verify_single_bound(contract, bound, i, &mut result);
162    }
163    for check in &contract.falsification {
164        verify_single_falsification(contract, check, &mut result);
165    }
166    result
167}
168
169/// Check the contract has the required `kind` field and record `kernel` if present.
170fn validate_contract_metadata(contract: &PerformanceContract, result: &mut ContractVerification) {
171    if contract.kind.is_empty() {
172        result
173            .failed
174            .push("Contract missing 'kind' field".to_string());
175    } else {
176        result.passed.push(format!("kind: {}", contract.kind));
177    }
178    if contract.kernel.is_empty() {
179        result
180            .skipped
181            .push("No kernel field — domain-specific contract".to_string());
182    } else {
183        result.passed.push(format!("kernel: {}", contract.kernel));
184    }
185}
186
187/// Verify one `PerformanceBound`: structural pass if no size, else compare to profile data.
188fn verify_single_bound(
189    contract: &PerformanceContract,
190    bound: &PerformanceBound,
191    i: usize,
192    result: &mut ContractVerification,
193) {
194    if bound.size.is_empty() {
195        result
196            .passed
197            .push(format!("Bound {i}: structural (no size)"));
198        return;
199    }
200    let size = bound.size[0];
201    match load_kernel_profile(&contract.kernel, size) {
202        Some(p) => check_bound_thresholds(bound, i, &p, result),
203        None => check_bound_structural(bound, i, result),
204    }
205}
206
207/// Load `/tmp/cgp-{kernel}-{size}.json` profile if present and parseable.
208fn load_kernel_profile(kernel: &str, size: u32) -> Option<crate::metrics::catalog::FullProfile> {
209    let profile_path = format!("/tmp/cgp-{kernel}-{size}.json");
210    let path = std::path::Path::new(&profile_path);
211    if !path.exists() {
212        return None;
213    }
214    crate::metrics::export::load_json(path).ok()
215}
216
217/// Run each individual threshold check present on the bound.
218fn check_bound_thresholds(
219    bound: &PerformanceBound,
220    i: usize,
221    p: &crate::metrics::catalog::FullProfile,
222    result: &mut ContractVerification,
223) {
224    check_max_time(bound, i, p, result);
225    check_min_tflops(bound, i, p, result);
226    check_min_bandwidth(bound, i, p, result);
227}
228
229/// Compare `wall_clock_time_us` to `bound.max_time_us` (lower is better).
230fn check_max_time(
231    bound: &PerformanceBound,
232    i: usize,
233    p: &crate::metrics::catalog::FullProfile,
234    result: &mut ContractVerification,
235) {
236    let Some(max_time) = bound.max_time_us else {
237        return;
238    };
239    let actual = p.timing.wall_clock_time_us;
240    if actual <= max_time {
241        result
242            .passed
243            .push(format!("Bound {i}: time {actual:.1}us <= {max_time:.1}us"));
244    } else {
245        result.failed.push(format!(
246            "Bound {i}: time {actual:.1}us > {max_time:.1}us EXCEEDED"
247        ));
248    }
249}
250
251/// Compare measured TFLOP/s to `bound.min_tflops` (higher is better).
252fn check_min_tflops(
253    bound: &PerformanceBound,
254    i: usize,
255    p: &crate::metrics::catalog::FullProfile,
256    result: &mut ContractVerification,
257) {
258    let Some(min_tflops) = bound.min_tflops else {
259        return;
260    };
261    let actual = p.throughput.tflops;
262    if actual >= min_tflops {
263        result
264            .passed
265            .push(format!("Bound {i}: {actual:.1} TFLOP/s >= {min_tflops:.1}"));
266    } else {
267        result.failed.push(format!(
268            "Bound {i}: {actual:.1} TFLOP/s < {min_tflops:.1} BELOW MINIMUM"
269        ));
270    }
271}
272
273/// Compare measured GB/s to `bound.min_bandwidth_gbps` (higher is better).
274fn check_min_bandwidth(
275    bound: &PerformanceBound,
276    i: usize,
277    p: &crate::metrics::catalog::FullProfile,
278    result: &mut ContractVerification,
279) {
280    let Some(min_bw) = bound.min_bandwidth_gbps else {
281        return;
282    };
283    let actual = p.throughput.bandwidth_gbps;
284    if actual >= min_bw {
285        result
286            .passed
287            .push(format!("Bound {i}: {actual:.1} GB/s >= {min_bw:.1}"));
288    } else {
289        result.failed.push(format!(
290            "Bound {i}: {actual:.1} GB/s < {min_bw:.1} BELOW MINIMUM"
291        ));
292    }
293}
294
295/// No profile found: record the bound structurally and flag empty-criteria bounds.
296fn check_bound_structural(bound: &PerformanceBound, i: usize, result: &mut ContractVerification) {
297    result
298        .passed
299        .push(format!("Bound {i}: size {:?}", bound.size));
300    if bound.max_time_us.is_none()
301        && bound.min_tflops.is_none()
302        && bound.min_bandwidth_gbps.is_none()
303    {
304        result
305            .skipped
306            .push(format!("Bound {i}: no criteria specified"));
307    }
308}
309
310/// Verify one falsification clause: validate fields, then evaluate against a profile.
311fn verify_single_falsification(
312    contract: &PerformanceContract,
313    check: &FalsificationCheck,
314    result: &mut ContractVerification,
315) {
316    if check.name.is_empty() || check.check.is_empty() {
317        result.failed.push(format!(
318            "Falsification '{}': missing name or check",
319            check.name
320        ));
321        return;
322    }
323    let size = contract
324        .bounds
325        .first()
326        .and_then(|b| b.size.first())
327        .copied()
328        .unwrap_or(512);
329    let profile_path = format!("/tmp/cgp-{}-{size}.json", contract.kernel);
330    match load_kernel_profile(&contract.kernel, size) {
331        Some(p) => {
332            if evaluate_check(&check.check, &p) {
333                result.passed.push(format!("FALSIFY {}: PASS", check.name));
334            } else {
335                result.failed.push(format!(
336                    "FALSIFY {}: FAIL ({})",
337                    check.name, check.description
338                ));
339            }
340        }
341        None => {
342            result.skipped.push(format!(
343                "FALSIFY {}: {} (no profile at {profile_path})",
344                check.name, check.description
345            ));
346        }
347    }
348}
349
350/// Evaluate a simple check expression against a profile.
351/// Supports: "field > value", "field < value", "field >= value", "field == value"
352fn evaluate_check(expr: &str, profile: &crate::metrics::catalog::FullProfile) -> bool {
353    let parts: Vec<&str> = expr.split_whitespace().collect();
354    if parts.len() != 3 {
355        return false;
356    }
357    let field = parts[0];
358    let op = parts[1];
359    let threshold: f64 = match parts[2].parse() {
360        Ok(v) => v,
361        Err(_) => return false,
362    };
363
364    let value = match field {
365        "tflops" => profile.throughput.tflops,
366        "wall_clock_time_us" => profile.timing.wall_clock_time_us,
367        "bandwidth_gbps" => profile.throughput.bandwidth_gbps,
368        "arithmetic_intensity" => profile.throughput.arithmetic_intensity,
369        "warp_execution_efficiency" => profile
370            .gpu_compute
371            .as_ref()
372            .map_or(0.0, |g| g.warp_execution_efficiency_pct),
373        "achieved_occupancy" => profile
374            .gpu_compute
375            .as_ref()
376            .map_or(0.0, |g| g.achieved_occupancy_pct),
377        "global_load_efficiency" => profile
378            .gpu_memory
379            .as_ref()
380            .map_or(0.0, |g| g.global_load_efficiency_pct),
381        _ => return false,
382    };
383
384    match op {
385        ">" => value > threshold,
386        "<" => value < threshold,
387        ">=" => value >= threshold,
388        "<=" => value <= threshold,
389        "==" => (value - threshold).abs() < 0.001,
390        _ => false,
391    }
392}
393
394/// Run contract verification for a directory of contracts.
395pub fn run_verify(
396    contracts_dir: Option<&str>,
397    contract_file: Option<&str>,
398    self_verify: bool,
399    fail_on_regression: bool,
400) -> Result<()> {
401    let Some(contracts) = resolve_contracts_input(contracts_dir, contract_file, self_verify)?
402    else {
403        return Ok(());
404    };
405
406    println!("\n=== cgp Contract Verification ===\n");
407    let totals = run_verify_all(&contracts);
408    println!(
409        "\n  Total: {} pass, {} fail, {} skip",
410        totals.pass, totals.fail, totals.skip
411    );
412    if totals.fail > 0 && fail_on_regression {
413        anyhow::bail!("{} contract verification(s) failed", totals.fail);
414    }
415    println!();
416    Ok(())
417}
418
419/// Resolve the set of contracts to verify from CLI flags. Returns `None` when the
420/// self-verify directory is absent (early-exit signal for `run_verify`).
421fn resolve_contracts_input(
422    contracts_dir: Option<&str>,
423    contract_file: Option<&str>,
424    self_verify: bool,
425) -> Result<Option<Vec<PerformanceContract>>> {
426    if let Some(dir) = contracts_dir {
427        return Ok(Some(load_contracts_dir(Path::new(dir))?));
428    }
429    if let Some(file) = contract_file {
430        return Ok(Some(vec![load_contract(Path::new(file))?]));
431    }
432    if self_verify {
433        let dir = Path::new("contracts/cgp");
434        if !dir.exists() {
435            println!("No contracts found at contracts/cgp/");
436            return Ok(None);
437        }
438        return Ok(Some(load_contracts_dir(dir)?));
439    }
440    anyhow::bail!("Specify --contracts-dir, --contract, or --self");
441}
442
/// Running totals of check outcomes across all verified contracts.
#[derive(Default)]
struct VerifyTotals {
    /// Number of passed checks.
    pass: usize,
    /// Number of failed checks.
    fail: usize,
    /// Number of skipped checks.
    skip: usize,
}
449
450/// Verify every contract and print per-contract status; return aggregate counts.
451fn run_verify_all(contracts: &[PerformanceContract]) -> VerifyTotals {
452    let mut totals = VerifyTotals::default();
453    for c in contracts {
454        let result = verify_contract(c);
455        print_contract_status(c, &result);
456        totals.pass += result.passed.len();
457        totals.fail += result.failed.len();
458        totals.skip += result.skipped.len();
459    }
460    totals
461}
462
463fn print_contract_status(c: &PerformanceContract, result: &ContractVerification) {
464    let status = if result.is_pass() {
465        "\x1b[32mPASS\x1b[0m"
466    } else {
467        "\x1b[31mFAIL\x1b[0m"
468    };
469    println!(
470        "  {} {} ({} pass, {} fail, {} skip)",
471        status,
472        c.name,
473        result.passed.len(),
474        result.failed.len(),
475        result.skipped.len()
476    );
477}
478
479/// Generate a performance contract YAML from a profile or estimated values.
480pub fn run_generate(kernel: &str, size: u32, tolerance: f64) -> Result<()> {
481    // Try to load a saved profile for this kernel
482    let profile_path = format!("/tmp/cgp-{kernel}-{size}.json");
483    let profile = if std::path::Path::new(&profile_path).exists() {
484        Some(crate::metrics::export::load_json(std::path::Path::new(
485            &profile_path,
486        ))?)
487    } else {
488        None
489    };
490
491    let (time_us, tflops) = match &profile {
492        Some(p) => (p.timing.wall_clock_time_us, p.throughput.tflops),
493        None => {
494            // Estimate for GEMM
495            let flops = 2.0 * (size as f64).powi(3);
496            let est_time = 23.2 * (size as f64 / 512.0).powi(3); // Scale from 512 baseline
497            let est_tflops = flops / (est_time * 1e-6) / 1e12;
498            (est_time, est_tflops)
499        }
500    };
501
502    let max_time = time_us * (1.0 + tolerance / 100.0);
503    let min_tflops = tflops * (1.0 - tolerance / 100.0);
504
505    // Detect GPU
506    let gpu_name = std::process::Command::new("nvidia-smi")
507        .args(["--query-gpu=name", "--format=csv,noheader"])
508        .output()
509        .ok()
510        .filter(|o| o.status.success())
511        .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
512        .unwrap_or_else(|| "Unknown GPU".to_string());
513
514    let contract_yaml = format!(
515        r#"# Generated by cgp contract generate
516# Kernel: {kernel} at size {size}x{size}x{size}
517# Tolerance: {tolerance}%
518kind: PerformanceContract
519name: {kernel}-{size}
520version: "1.0.0"
521kernel: {kernel}
522hardware:
523  gpu: "{gpu_name}"
524  compute_capability: "8.9"
525
526bounds:
527  - size: [{size}, {size}, {size}]
528    max_time_us: {max_time:.1}
529    min_tflops: {min_tflops:.1}
530    max_regression_pct: {tolerance}
531
532metrics:
533  warp_execution_efficiency:
534    min: 95.0
535  achieved_occupancy:
536    min: 25.0
537
538falsification:
539  - name: FALSIFY-{kernel_upper}-001
540    description: "{kernel} must achieve >{min_tflops:.1} TFLOP/s at {size}x{size}"
541    check: "tflops > {min_tflops:.1}"
542  - name: FALSIFY-{kernel_upper}-002
543    description: "{kernel} must complete in <{max_time:.1}us at {size}x{size}"
544    check: "wall_clock_time_us < {max_time:.1}"
545"#,
546        kernel = kernel,
547        size = size,
548        tolerance = tolerance,
549        gpu_name = gpu_name,
550        max_time = max_time,
551        min_tflops = min_tflops,
552        kernel_upper = kernel.to_uppercase().replace('-', "_"),
553    );
554
555    // Write to contracts directory
556    let contracts_dir = std::path::Path::new("contracts/cgp");
557    std::fs::create_dir_all(contracts_dir)?;
558    let contract_path = contracts_dir.join(format!("{kernel}-{size}-v1.yaml"));
559    std::fs::write(&contract_path, &contract_yaml)?;
560
561    println!("Generated contract: {}", contract_path.display());
562    println!();
563    print!("{contract_yaml}");
564
565    Ok(())
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fully populated GEMM contract with one bound, one metric and
    /// one falsification clause.
    fn sample_contract() -> PerformanceContract {
        let gemm_bound = PerformanceBound {
            size: vec![512, 512, 512],
            max_time_us: Some(30.0),
            min_tflops: Some(9.0),
            max_regression_pct: Some(10.0),
            min_bandwidth_gbps: None,
            extra: Default::default(),
        };
        let metrics = std::collections::HashMap::from([(
            "warp_execution_efficiency".to_string(),
            MetricBound {
                min: Some(95.0),
                max: None,
            },
        )]);
        let clause = FalsificationCheck {
            name: "FALSIFY-TEST-001".to_string(),
            description: "CTA WMMA must achieve >9 TFLOP/s".to_string(),
            check: "tflops > 9.0".to_string(),
            extra: Default::default(),
        };

        PerformanceContract {
            kind: "PerformanceContract".to_string(),
            name: "test-gemm-contract".to_string(),
            version: "1.0.0".to_string(),
            kernel: "gemm_cta_wmma_fp16".to_string(),
            hardware: HardwareSpec {
                gpu: Some("NVIDIA GeForce RTX 4090".to_string()),
                cpu: None,
                compute_capability: Some("8.9".to_string()),
            },
            bounds: vec![gemm_bound],
            metrics,
            falsification: vec![clause],
            extra: Default::default(),
        }
    }

    /// A well-formed contract verifies without failures.
    #[test]
    fn test_verify_valid_contract() {
        let report = verify_contract(&sample_contract());
        assert!(report.is_pass());
        assert!(!report.passed.is_empty());
    }

    /// Domain-specific contracts without kernel are allowed (skipped, not failed).
    #[test]
    fn test_verify_missing_kernel_is_skipped() {
        let contract = PerformanceContract {
            kernel: String::new(),
            ..sample_contract()
        };
        let report = verify_contract(&contract);
        assert!(report.is_pass());
        assert!(!report.skipped.is_empty());
    }

    /// Serializing to YAML and parsing back preserves the key fields.
    #[test]
    fn test_contract_yaml_roundtrip() {
        let original = sample_contract();
        let yaml = serde_yaml_ng::to_string(&original).unwrap();
        let reparsed: PerformanceContract = serde_yaml_ng::from_str(&yaml).unwrap();
        assert_eq!(reparsed.name, original.name);
        assert_eq!(reparsed.kernel, original.kernel);
        assert_eq!(reparsed.bounds.len(), 1);
        assert_eq!(reparsed.bounds[0].size, vec![512, 512, 512]);
    }

    /// Falsification checks are skipped (need runtime data), not failed.
    /// NOTE(review): this relies on no saved profile existing at
    /// /tmp/cgp-gemm_cta_wmma_fp16-512.json on the test machine.
    #[test]
    fn test_contract_falsification_checks() {
        let report = verify_contract(&sample_contract());
        assert!(report.is_pass());
        assert!(!report.skipped.is_empty());
    }
}