Skip to main content

cgp/analysis/
contracts.rs

1//! Performance contract verification (CI/CD gate).
2//! Extends provable-contracts framework to performance bounds.
3//! See spec section 3.4 and 7.1.
4
5use anyhow::{Context, Result};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
/// A performance contract loaded from YAML.
///
/// `kind`, `name` and `version` are required; every other field falls back
/// to its `Default` when absent from the document. Keys that match no named
/// field are collected in `extra` rather than rejected.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceContract {
    /// Contract type tag; an empty value is reported as a failure by `verify_contract`.
    pub kind: String,
    /// Contract name, used as the label in verification reports.
    pub name: String,
    /// Contract version string (not interpreted anywhere in this file).
    pub version: String,
    /// Kernel identifier used to build the saved-profile path
    /// (`/tmp/cgp-{kernel}-{size}.json`). Empty for domain-specific contracts.
    #[serde(default)]
    pub kernel: String,
    /// Hardware the contract was written for.
    #[serde(default)]
    pub hardware: HardwareSpec,
    /// Per-size performance bounds checked against saved profiles.
    #[serde(default)]
    pub bounds: Vec<PerformanceBound>,
    /// Named metric ranges. NOTE(review): `verify_contract` in this file does
    /// not consult this map.
    #[serde(default)]
    pub metrics: std::collections::HashMap<String, MetricBound>,
    /// Falsification checks whose `check` expressions are evaluated against a profile.
    #[serde(default)]
    pub falsification: Vec<FalsificationCheck>,
    /// Absorb any extra fields from domain-specific contract schemas
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
29
/// Optional description of the hardware a contract targets.
///
/// All fields are optional; `Default` leaves every one of them `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HardwareSpec {
    /// GPU name (e.g. the string `run_generate` obtains from `nvidia-smi`).
    pub gpu: Option<String>,
    /// CPU name.
    pub cpu: Option<String>,
    /// Compute capability, kept as free-form text such as `"8.9"`.
    pub compute_capability: Option<String>,
}
36
/// One performance bound of a contract.
///
/// Every criterion is optional. `verify_contract` only compares the criteria
/// that are present, and only when a saved profile exists for the bound.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceBound {
    /// Problem size, written in YAML either as a single integer or as a
    /// sequence. `verify_contract` uses only the first element to locate the
    /// saved profile; an empty vector marks a bound without a size.
    #[serde(default, deserialize_with = "deserialize_size")]
    pub size: Vec<u32>,
    /// Upper limit on the profile's `timing.wall_clock_time_us`.
    #[serde(default)]
    pub max_time_us: Option<f64>,
    /// Lower limit on the profile's `throughput.tflops`.
    #[serde(default)]
    pub min_tflops: Option<f64>,
    /// Allowed regression in percent. NOTE(review): parsed and emitted by
    /// `run_generate`, but not checked by `verify_contract` in this file.
    #[serde(default)]
    pub max_regression_pct: Option<f64>,
    /// Lower limit on the profile's `throughput.bandwidth_gbps`.
    #[serde(default)]
    pub min_bandwidth_gbps: Option<f64>,
    /// Absorb domain-specific bound fields (operation, competitor, etc.)
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
53
54/// Accept both `size: 1024` (single int) and `size: [1024, 1024, 1024]` (sequence).
55fn deserialize_size<'de, D>(deserializer: D) -> Result<Vec<u32>, D::Error>
56where
57    D: serde::Deserializer<'de>,
58{
59    use serde::de;
60
61    struct SizeVisitor;
62    impl<'de> de::Visitor<'de> for SizeVisitor {
63        type Value = Vec<u32>;
64        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
65            f.write_str("an integer or sequence of integers")
66        }
67        fn visit_u64<E: de::Error>(self, v: u64) -> Result<Vec<u32>, E> {
68            Ok(vec![v as u32])
69        }
70        fn visit_i64<E: de::Error>(self, v: i64) -> Result<Vec<u32>, E> {
71            Ok(vec![v as u32])
72        }
73        fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Vec<u32>, A::Error> {
74            let mut v = Vec::new();
75            while let Some(elem) = seq.next_element::<u32>()? {
76                v.push(elem);
77            }
78            Ok(v)
79        }
80        fn visit_none<E: de::Error>(self) -> Result<Vec<u32>, E> {
81            Ok(Vec::new())
82        }
83        fn visit_unit<E: de::Error>(self) -> Result<Vec<u32>, E> {
84            Ok(Vec::new())
85        }
86    }
87    deserializer.deserialize_any(SizeVisitor)
88}
89
/// Optional lower and/or upper limit for a named metric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricBound {
    /// Smallest acceptable value, if any.
    pub min: Option<f64>,
    /// Largest acceptable value, if any.
    pub max: Option<f64>,
}
95
/// A named falsification check attached to a contract.
///
/// `verify_contract` reports a check with an empty `name` or empty `check`
/// as a failure; otherwise `check` is handed to `evaluate_check`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FalsificationCheck {
    /// Identifier shown in the verification report.
    pub name: String,
    /// Human-readable explanation, included in failure and skip messages.
    #[serde(default)]
    pub description: String,
    /// Expression of the form `field op value`, e.g. `tflops > 9.0`.
    #[serde(default)]
    pub check: String,
    /// Any additional keys present on the check entry.
    #[serde(flatten, default)]
    pub extra: std::collections::HashMap<String, serde_yaml_ng::Value>,
}
106
/// Result of verifying a single contract.
///
/// Each vector holds one human-readable message per check outcome.
#[derive(Debug)]
pub struct ContractVerification {
    /// Name copied from the verified contract.
    pub contract_name: String,
    /// Messages for checks that succeeded.
    pub passed: Vec<String>,
    /// Messages for checks that failed; any entry makes the contract fail.
    pub failed: Vec<String>,
    /// Messages for checks that could not be evaluated.
    pub skipped: Vec<String>,
}
115
impl ContractVerification {
    /// Returns `true` when no check failed. Skipped checks do not count as
    /// failures, so a verification with only skips still passes.
    pub fn is_pass(&self) -> bool {
        self.failed.is_empty()
    }
}
121
122/// Load a performance contract from a YAML file.
123pub fn load_contract(path: &Path) -> Result<PerformanceContract> {
124    let content = std::fs::read_to_string(path)
125        .with_context(|| format!("Failed to read contract: {}", path.display()))?;
126    let contract: PerformanceContract = serde_yaml_ng::from_str(&content)
127        .with_context(|| format!("Failed to parse contract: {}", path.display()))?;
128    Ok(contract)
129}
130
131/// Load all contracts from a directory.
132pub fn load_contracts_dir(dir: &Path) -> Result<Vec<PerformanceContract>> {
133    let mut contracts = Vec::new();
134    if dir.is_dir() {
135        for entry in std::fs::read_dir(dir)? {
136            let entry = entry?;
137            let path = entry.path();
138            if path.extension().is_some_and(|e| e == "yaml" || e == "yml") {
139                match load_contract(&path) {
140                    Ok(c) => contracts.push(c),
141                    Err(e) => eprintln!("Warning: skipping {}: {e}", path.display()),
142                }
143            }
144        }
145    }
146    Ok(contracts)
147}
148
149/// Verify a contract against measured values.
150/// Loads saved profiles from /tmp/cgp-{kernel}-{size}.json if available,
151/// and checks performance bounds + falsification expressions.
152pub fn verify_contract(contract: &PerformanceContract) -> ContractVerification {
153    let mut result = ContractVerification {
154        contract_name: contract.name.clone(),
155        passed: Vec::new(),
156        failed: Vec::new(),
157        skipped: Vec::new(),
158    };
159
160    // Validate structure
161    if contract.kind.is_empty() {
162        result
163            .failed
164            .push("Contract missing 'kind' field".to_string());
165    } else {
166        result.passed.push(format!("kind: {}", contract.kind));
167    }
168
169    if contract.kernel.is_empty() {
170        result
171            .skipped
172            .push("No kernel field — domain-specific contract".to_string());
173    } else {
174        result.passed.push(format!("kernel: {}", contract.kernel));
175    }
176
177    // Check bounds against saved profiles
178    for (i, bound) in contract.bounds.iter().enumerate() {
179        if bound.size.is_empty() {
180            // Domain-specific bounds without size — structural pass
181            result
182                .passed
183                .push(format!("Bound {i}: structural (no size)"));
184            continue;
185        }
186
187        let size = bound.size[0];
188        let profile_path = format!("/tmp/cgp-{}-{size}.json", contract.kernel);
189        let profile = std::path::Path::new(&profile_path)
190            .exists()
191            .then(|| crate::metrics::export::load_json(std::path::Path::new(&profile_path)).ok())
192            .flatten();
193
194        match profile {
195            Some(p) => {
196                // Check max_time_us
197                if let Some(max_time) = bound.max_time_us {
198                    if p.timing.wall_clock_time_us <= max_time {
199                        result.passed.push(format!(
200                            "Bound {i}: time {:.1}us <= {max_time:.1}us",
201                            p.timing.wall_clock_time_us
202                        ));
203                    } else {
204                        result.failed.push(format!(
205                            "Bound {i}: time {:.1}us > {max_time:.1}us EXCEEDED",
206                            p.timing.wall_clock_time_us
207                        ));
208                    }
209                }
210                // Check min_tflops
211                if let Some(min_tflops) = bound.min_tflops {
212                    if p.throughput.tflops >= min_tflops {
213                        result.passed.push(format!(
214                            "Bound {i}: {:.1} TFLOP/s >= {min_tflops:.1}",
215                            p.throughput.tflops
216                        ));
217                    } else {
218                        result.failed.push(format!(
219                            "Bound {i}: {:.1} TFLOP/s < {min_tflops:.1} BELOW MINIMUM",
220                            p.throughput.tflops
221                        ));
222                    }
223                }
224                // Check min_bandwidth_gbps
225                if let Some(min_bw) = bound.min_bandwidth_gbps {
226                    if p.throughput.bandwidth_gbps >= min_bw {
227                        result.passed.push(format!(
228                            "Bound {i}: {:.1} GB/s >= {min_bw:.1}",
229                            p.throughput.bandwidth_gbps
230                        ));
231                    } else {
232                        result.failed.push(format!(
233                            "Bound {i}: {:.1} GB/s < {min_bw:.1} BELOW MINIMUM",
234                            p.throughput.bandwidth_gbps
235                        ));
236                    }
237                }
238            }
239            None => {
240                // No profile available — validate structure only
241                result
242                    .passed
243                    .push(format!("Bound {i}: size {:?}", bound.size));
244                if bound.max_time_us.is_none()
245                    && bound.min_tflops.is_none()
246                    && bound.min_bandwidth_gbps.is_none()
247                {
248                    result
249                        .skipped
250                        .push(format!("Bound {i}: no criteria specified"));
251                }
252            }
253        }
254    }
255
256    // Evaluate falsification checks against profile data
257    for check in &contract.falsification {
258        if check.name.is_empty() || check.check.is_empty() {
259            result.failed.push(format!(
260                "Falsification '{}': missing name or check",
261                check.name
262            ));
263            continue;
264        }
265
266        // Try to find a profile and evaluate the check expression
267        let size = contract
268            .bounds
269            .first()
270            .and_then(|b| b.size.first())
271            .copied()
272            .unwrap_or(512);
273        let profile_path = format!("/tmp/cgp-{}-{size}.json", contract.kernel);
274        let profile = std::path::Path::new(&profile_path)
275            .exists()
276            .then(|| crate::metrics::export::load_json(std::path::Path::new(&profile_path)).ok())
277            .flatten();
278
279        match profile {
280            Some(p) => {
281                let pass = evaluate_check(&check.check, &p);
282                if pass {
283                    result.passed.push(format!("FALSIFY {}: PASS", check.name));
284                } else {
285                    result.failed.push(format!(
286                        "FALSIFY {}: FAIL ({})",
287                        check.name, check.description
288                    ));
289                }
290            }
291            None => {
292                result.skipped.push(format!(
293                    "FALSIFY {}: {} (no profile at {profile_path})",
294                    check.name, check.description
295                ));
296            }
297        }
298    }
299
300    result
301}
302
303/// Evaluate a simple check expression against a profile.
304/// Supports: "field > value", "field < value", "field >= value", "field == value"
305fn evaluate_check(expr: &str, profile: &crate::metrics::catalog::FullProfile) -> bool {
306    let parts: Vec<&str> = expr.split_whitespace().collect();
307    if parts.len() != 3 {
308        return false;
309    }
310    let field = parts[0];
311    let op = parts[1];
312    let threshold: f64 = match parts[2].parse() {
313        Ok(v) => v,
314        Err(_) => return false,
315    };
316
317    let value = match field {
318        "tflops" => profile.throughput.tflops,
319        "wall_clock_time_us" => profile.timing.wall_clock_time_us,
320        "bandwidth_gbps" => profile.throughput.bandwidth_gbps,
321        "arithmetic_intensity" => profile.throughput.arithmetic_intensity,
322        "warp_execution_efficiency" => profile
323            .gpu_compute
324            .as_ref()
325            .map_or(0.0, |g| g.warp_execution_efficiency_pct),
326        "achieved_occupancy" => profile
327            .gpu_compute
328            .as_ref()
329            .map_or(0.0, |g| g.achieved_occupancy_pct),
330        "global_load_efficiency" => profile
331            .gpu_memory
332            .as_ref()
333            .map_or(0.0, |g| g.global_load_efficiency_pct),
334        _ => return false,
335    };
336
337    match op {
338        ">" => value > threshold,
339        "<" => value < threshold,
340        ">=" => value >= threshold,
341        "<=" => value <= threshold,
342        "==" => (value - threshold).abs() < 0.001,
343        _ => false,
344    }
345}
346
347/// Run contract verification for a directory of contracts.
348pub fn run_verify(
349    contracts_dir: Option<&str>,
350    contract_file: Option<&str>,
351    self_verify: bool,
352    fail_on_regression: bool,
353) -> Result<()> {
354    let contracts = if let Some(dir) = contracts_dir {
355        load_contracts_dir(Path::new(dir))?
356    } else if let Some(file) = contract_file {
357        vec![load_contract(Path::new(file))?]
358    } else if self_verify {
359        let dir = Path::new("contracts/cgp");
360        if dir.exists() {
361            load_contracts_dir(dir)?
362        } else {
363            println!("No contracts found at contracts/cgp/");
364            return Ok(());
365        }
366    } else {
367        anyhow::bail!("Specify --contracts-dir, --contract, or --self");
368    };
369
370    println!("\n=== cgp Contract Verification ===\n");
371    let mut total_pass = 0;
372    let mut total_fail = 0;
373    let mut total_skip = 0;
374
375    for c in &contracts {
376        let result = verify_contract(c);
377        let status = if result.is_pass() {
378            "\x1b[32mPASS\x1b[0m"
379        } else {
380            "\x1b[31mFAIL\x1b[0m"
381        };
382        println!(
383            "  {} {} ({} pass, {} fail, {} skip)",
384            status,
385            c.name,
386            result.passed.len(),
387            result.failed.len(),
388            result.skipped.len()
389        );
390        total_pass += result.passed.len();
391        total_fail += result.failed.len();
392        total_skip += result.skipped.len();
393    }
394
395    println!("\n  Total: {total_pass} pass, {total_fail} fail, {total_skip} skip");
396    if total_fail > 0 && fail_on_regression {
397        anyhow::bail!("{total_fail} contract verification(s) failed");
398    }
399    println!();
400    Ok(())
401}
402
403/// Generate a performance contract YAML from a profile or estimated values.
404pub fn run_generate(kernel: &str, size: u32, tolerance: f64) -> Result<()> {
405    // Try to load a saved profile for this kernel
406    let profile_path = format!("/tmp/cgp-{kernel}-{size}.json");
407    let profile = if std::path::Path::new(&profile_path).exists() {
408        Some(crate::metrics::export::load_json(std::path::Path::new(
409            &profile_path,
410        ))?)
411    } else {
412        None
413    };
414
415    let (time_us, tflops) = match &profile {
416        Some(p) => (p.timing.wall_clock_time_us, p.throughput.tflops),
417        None => {
418            // Estimate for GEMM
419            let flops = 2.0 * (size as f64).powi(3);
420            let est_time = 23.2 * (size as f64 / 512.0).powi(3); // Scale from 512 baseline
421            let est_tflops = flops / (est_time * 1e-6) / 1e12;
422            (est_time, est_tflops)
423        }
424    };
425
426    let max_time = time_us * (1.0 + tolerance / 100.0);
427    let min_tflops = tflops * (1.0 - tolerance / 100.0);
428
429    // Detect GPU
430    let gpu_name = std::process::Command::new("nvidia-smi")
431        .args(["--query-gpu=name", "--format=csv,noheader"])
432        .output()
433        .ok()
434        .filter(|o| o.status.success())
435        .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
436        .unwrap_or_else(|| "Unknown GPU".to_string());
437
438    let contract_yaml = format!(
439        r#"# Generated by cgp contract generate
440# Kernel: {kernel} at size {size}x{size}x{size}
441# Tolerance: {tolerance}%
442kind: PerformanceContract
443name: {kernel}-{size}
444version: "1.0.0"
445kernel: {kernel}
446hardware:
447  gpu: "{gpu_name}"
448  compute_capability: "8.9"
449
450bounds:
451  - size: [{size}, {size}, {size}]
452    max_time_us: {max_time:.1}
453    min_tflops: {min_tflops:.1}
454    max_regression_pct: {tolerance}
455
456metrics:
457  warp_execution_efficiency:
458    min: 95.0
459  achieved_occupancy:
460    min: 25.0
461
462falsification:
463  - name: FALSIFY-{kernel_upper}-001
464    description: "{kernel} must achieve >{min_tflops:.1} TFLOP/s at {size}x{size}"
465    check: "tflops > {min_tflops:.1}"
466  - name: FALSIFY-{kernel_upper}-002
467    description: "{kernel} must complete in <{max_time:.1}us at {size}x{size}"
468    check: "wall_clock_time_us < {max_time:.1}"
469"#,
470        kernel = kernel,
471        size = size,
472        tolerance = tolerance,
473        gpu_name = gpu_name,
474        max_time = max_time,
475        min_tflops = min_tflops,
476        kernel_upper = kernel.to_uppercase().replace('-', "_"),
477    );
478
479    // Write to contracts directory
480    let contracts_dir = std::path::Path::new("contracts/cgp");
481    std::fs::create_dir_all(contracts_dir)?;
482    let contract_path = contracts_dir.join(format!("{kernel}-{size}-v1.yaml"));
483    std::fs::write(&contract_path, &contract_yaml)?;
484
485    println!("Generated contract: {}", contract_path.display());
486    println!();
487    print!("{contract_yaml}");
488
489    Ok(())
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// Kernel name reserved for unit tests. `verify_contract` reads
    /// `/tmp/cgp-{kernel}-{size}.json` when it exists, so using a real kernel
    /// name would let a leftover profile on the machine change the outcome
    /// of these tests.
    const TEST_KERNEL: &str = "cgp_unit_test_kernel_without_saved_profile";

    /// Build a fully populated contract with one bound, one metric and one
    /// falsification check.
    fn sample_contract() -> PerformanceContract {
        PerformanceContract {
            kind: "PerformanceContract".to_string(),
            name: "test-gemm-contract".to_string(),
            version: "1.0.0".to_string(),
            kernel: TEST_KERNEL.to_string(),
            hardware: HardwareSpec {
                gpu: Some("NVIDIA GeForce RTX 4090".to_string()),
                cpu: None,
                compute_capability: Some("8.9".to_string()),
            },
            bounds: vec![PerformanceBound {
                size: vec![512, 512, 512],
                max_time_us: Some(30.0),
                min_tflops: Some(9.0),
                max_regression_pct: Some(10.0),
                min_bandwidth_gbps: None,
                extra: Default::default(),
            }],
            metrics: {
                let mut m = std::collections::HashMap::new();
                m.insert(
                    "warp_execution_efficiency".to_string(),
                    MetricBound {
                        min: Some(95.0),
                        max: None,
                    },
                );
                m
            },
            falsification: vec![FalsificationCheck {
                name: "FALSIFY-TEST-001".to_string(),
                description: "CTA WMMA must achieve >9 TFLOP/s".to_string(),
                check: "tflops > 9.0".to_string(),
                extra: Default::default(),
            }],
            extra: Default::default(),
        }
    }

    #[test]
    fn test_verify_valid_contract() {
        let contract = sample_contract();
        let result = verify_contract(&contract);
        assert!(result.is_pass());
        assert!(!result.passed.is_empty());
    }

    #[test]
    fn test_verify_missing_kernel_is_skipped() {
        let mut contract = sample_contract();
        contract.kernel = String::new();
        let result = verify_contract(&contract);
        // Domain-specific contracts without kernel are allowed (skipped, not failed)
        assert!(result.is_pass());
        assert!(!result.skipped.is_empty());
    }

    #[test]
    fn test_contract_yaml_roundtrip() {
        let contract = sample_contract();
        let yaml = serde_yaml_ng::to_string(&contract).unwrap();
        let parsed: PerformanceContract = serde_yaml_ng::from_str(&yaml).unwrap();
        assert_eq!(parsed.name, contract.name);
        assert_eq!(parsed.kernel, contract.kernel);
        assert_eq!(parsed.bounds.len(), 1);
        assert_eq!(parsed.bounds[0].size, vec![512, 512, 512]);
    }

    #[test]
    fn test_contract_falsification_checks() {
        let contract = sample_contract();
        let result = verify_contract(&contract);
        // Falsification checks are skipped (need runtime data), not failed
        assert!(result.is_pass());
        assert!(!result.skipped.is_empty());
    }
}