Skip to main content

cgp/profilers/
rayon_parallel.rs

//! Rayon parallel profiling (spec section 4.9).
//! Reports parallel speedup, efficiency, and load balance (Heijunka score); the profile
//! type also carries thread-spawn overhead and work-stealing counts.
3
4use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7/// Rayon parallel profile output.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct RayonProfile {
10    pub wall_time_us: f64,
11    pub single_thread_time_us: f64,
12    pub parallel_speedup: f64,
13    pub num_threads: usize,
14    pub parallel_efficiency: f64,
15    /// 0.0 = perfect balance, 1.0 = all work on 1 thread.
16    pub heijunka_score: f64,
17    pub thread_spawn_overhead_us: f64,
18    pub work_steal_count: u64,
19}
20
21impl RayonProfile {
22    /// Compute Heijunka (load balance) score from per-thread work times.
23    /// Score = coefficient of variation of per-thread times.
24    /// 0.0 = perfect balance, higher = more imbalanced.
25    pub fn compute_heijunka_score(per_thread_times: &[f64]) -> f64 {
26        if per_thread_times.is_empty() || per_thread_times.len() == 1 {
27            return 0.0;
28        }
29        let mean = per_thread_times.iter().sum::<f64>() / per_thread_times.len() as f64;
30        if mean == 0.0 {
31            return 0.0;
32        }
33        let variance = per_thread_times
34            .iter()
35            .map(|t| (t - mean).powi(2))
36            .sum::<f64>()
37            / per_thread_times.len() as f64;
38        let cv = variance.sqrt() / mean;
39        cv.min(1.0) // Cap at 1.0
40    }
41
42    /// Estimate parallel speedup from single-thread and multi-thread wall times.
43    pub fn compute_speedup(single_thread_us: f64, parallel_us: f64) -> f64 {
44        if parallel_us > 0.0 {
45            single_thread_us / parallel_us
46        } else {
47            0.0
48        }
49    }
50
51    /// Compute parallel efficiency: speedup / num_threads (1.0 = ideal).
52    pub fn compute_efficiency(speedup: f64, num_threads: usize) -> f64 {
53        if num_threads > 0 {
54            speedup / num_threads as f64
55        } else {
56            0.0
57        }
58    }
59}
60
/// Number of repeated runs per configuration in `profile_parallel`; the
/// minimum wall time is kept to damp OS-scheduling and cache-state noise.
const TIMING_RUNS: usize = 3;
63
64/// Profile a parallel function.
65/// Runs the benchmark binary with RAYON_NUM_THREADS=1 and RAYON_NUM_THREADS=N,
66/// using min-of-3 timing for stability. Computes parallel metrics.
67pub fn profile_parallel(function: &str, size: u32, threads: Option<&str>) -> Result<()> {
68    let thread_count = match threads {
69        Some("auto") | None => num_cpus::get(),
70        Some(n) => n
71            .parse()
72            .map_err(|_| anyhow::anyhow!("Invalid thread count: {n}"))?,
73    };
74
75    println!("\n=== CGP Parallel Profile: {function} (size={size}, threads={thread_count}) ===\n");
76    println!("  Backend: Rayon thread pool");
77    println!("  Function: {function}");
78    println!("  Size: {size}");
79    println!("  Threads: {thread_count}");
80
81    // Try to find a benchmark binary
82    let binary = find_parallel_binary();
83    match binary {
84        Some(bin) => {
85            println!("  Binary: {bin}");
86            println!("  Timing: min of {TIMING_RUNS} runs");
87
88            // Run single-threaded (min of N runs)
89            let single_time = time_binary_min_of_n(&bin, 1, TIMING_RUNS);
90            // Run multi-threaded (min of N runs)
91            let parallel_time = time_binary_min_of_n(&bin, thread_count, TIMING_RUNS);
92
93            match (single_time, parallel_time) {
94                (Some(st), Some(pt)) => {
95                    let speedup = RayonProfile::compute_speedup(st, pt);
96                    let efficiency = RayonProfile::compute_efficiency(speedup, thread_count);
97
98                    println!("\n  Results:");
99                    println!("    Single-thread:      {st:.0} us");
100                    println!("    {thread_count}-thread:     {pt:.0} us");
101                    println!("    Parallel speedup:   {speedup:.2}x");
102                    println!("    Efficiency:         {:.1}%", efficiency * 100.0);
103
104                    // Estimate spawn overhead (~40us per thread::scope call)
105                    let overhead_estimate = 40.0; // us, from memory feedback
106                    let overhead_pct = if pt > 0.0 {
107                        overhead_estimate / pt * 100.0
108                    } else {
109                        0.0
110                    };
111                    println!(
112                        "    Thread overhead:     ~{overhead_estimate:.0} us ({overhead_pct:.1}% of total)"
113                    );
114
115                    // Warning for small workloads
116                    if pt < 500.0 {
117                        println!(
118                            "\n  \x1b[33m[WARN]\x1b[0m Workload <500us — thread overhead dominates"
119                        );
120                        println!("    Consider: sequential execution or batching");
121                    }
122                }
123                _ => {
124                    println!("\n  Could not time binary — showing configuration only.");
125                }
126            }
127        }
128        None => {
129            println!("  No benchmark binary found.");
130            println!("  Build with: cargo build --release --examples");
131            println!("  Estimated metrics with synthetic data:");
132
133            // Show theoretical analysis
134            println!("\n  Theoretical Analysis:");
135            println!(
136                "    Amdahl's law: if 95% parallelizable, max speedup = {:.1}x",
137                amdahl(0.95, thread_count)
138            );
139            println!(
140                "    If 90% parallelizable, max speedup = {:.1}x",
141                amdahl(0.90, thread_count)
142            );
143            println!(
144                "    If 80% parallelizable, max speedup = {:.1}x",
145                amdahl(0.80, thread_count)
146            );
147        }
148    }
149
150    println!();
151    Ok(())
152}
153
154/// Parallel scaling sweep — measure throughput at each thread count.
155/// Produces the scaling table used in spec Appendix A.2.
156pub fn profile_scaling(
157    size: u32,
158    max_threads: Option<usize>,
159    runs: usize,
160    json: bool,
161) -> anyhow::Result<()> {
162    let max_t = max_threads.unwrap_or_else(num_cpus::get);
163    let binary = resolve_bench_binary()?;
164    let bin = binary.as_str();
165
166    let thread_counts = build_thread_counts(max_t);
167
168    let (baseline_ms, baseline_gflops) = parse_gemm_time(bin, size, 1, runs)
169        .ok_or_else(|| anyhow::anyhow!(
170            "Failed to parse GEMM {size}x{size} from benchmark output (1 thread). \
171             Ensure the binary outputs 'Matrix Multiplication ({size}x{size}x{size})... X.XX ms (Y.YY GFLOPS)'"
172        ))?;
173
174    if !json {
175        println!("\n=== CGP Parallel Scaling: GEMM {size}x{size}, min-of-{runs} runs ===\n");
176        println!(
177            "  {:>8} | {:>10} | {:>10} | {:>8} | {:>6}",
178            "Threads", "Time (ms)", "GFLOPS", "Scaling", "Notes"
179        );
180        println!("  {}", "-".repeat(60));
181    }
182
183    let results = run_scaling_sweep(bin, size, &thread_counts, runs, baseline_ms, json);
184
185    if !json {
186        print_scaling_peak(&results, baseline_gflops);
187    }
188
189    if json {
190        println!("{}", serde_json::to_string_pretty(&results)?);
191    } else {
192        println!();
193    }
194
195    Ok(())
196}
197
198/// Locate the benchmark binary or bail with a build hint.
199fn resolve_bench_binary() -> anyhow::Result<String> {
200    find_parallel_binary().ok_or_else(|| anyhow::anyhow!(
201        "No benchmark binary found. Build with: cargo build --release --example benchmark_matrix_suite --features parallel"
202    ))
203}
204
/// Build the ascending list of thread counts for a scaling sweep.
///
/// Keeps the entries of a fixed ladder (1, 2, 4, 8, 12, 16, 24, 32, 48, 64)
/// that do not exceed `max_t`, then appends `max_t` itself if it is not on
/// the ladder, so the full requested width is always measured.
///
/// A `max_t` of 0 is treated as 1. A zero-thread run is meaningless, and
/// Rayon treats `RAYON_NUM_THREADS=0` as "use the default pool size", which
/// would silently measure something else.
fn build_thread_counts(max_t: usize) -> Vec<usize> {
    const LADDER: [usize; 10] = [1, 2, 4, 8, 12, 16, 24, 32, 48, 64];
    let max_t = max_t.max(1);
    let mut counts: Vec<usize> = LADDER.iter().copied().filter(|&t| t <= max_t).collect();
    // The filtered ladder is ascending and non-empty (1 <= max_t), so `max_t`
    // is present exactly when it is the last element.
    if counts.last() != Some(&max_t) {
        counts.push(max_t);
    }
    counts
}
214
215/// Run the scaling measurement loop, printing each row unless `json` is set.
216fn run_scaling_sweep(
217    bin: &str,
218    size: u32,
219    thread_counts: &[usize],
220    runs: usize,
221    baseline_ms: f64,
222    json: bool,
223) -> Vec<ScalingPoint> {
224    let mut results: Vec<ScalingPoint> = Vec::new();
225    for &t in thread_counts {
226        match parse_gemm_time(bin, size, t, runs) {
227            Some((time_ms, gflops)) => {
228                let scaling = baseline_ms / time_ms;
229                if !json {
230                    let notes = scaling_notes(t, scaling);
231                    println!(
232                        "  {:>8} | {:>9.2} | {:>10.1} | {:>7.1}x | {notes}",
233                        t, time_ms, gflops, scaling
234                    );
235                }
236                results.push(ScalingPoint {
237                    threads: t,
238                    time_us: time_ms * 1000.0,
239                    gflops,
240                    scaling,
241                });
242            }
243            None => {
244                if !json {
245                    println!("  {:>8} | {:>10} | {:>10} | {:>8} |", t, "FAILED", "-", "-");
246                }
247            }
248        }
249    }
250    results
251}
252
/// Note for one scaling-table row: `"baseline"` for the 1-thread row,
/// `"near-linear"` when speedup reaches at least 90% of the thread count,
/// and an empty string otherwise.
fn scaling_notes(threads: usize, scaling: f64) -> String {
    let near_linear_floor = threads as f64 * 0.9;
    let label = match threads {
        1 => "baseline",
        _ if scaling >= near_linear_floor => "near-linear",
        _ => "",
    };
    label.to_owned()
}
263
264/// Print the peak GFLOPS row and its efficiency vs ideal linear scaling.
265fn print_scaling_peak(results: &[ScalingPoint], baseline_gflops: f64) {
266    let Some(peak) = results.iter().max_by(|a, b| {
267        a.gflops
268            .partial_cmp(&b.gflops)
269            .unwrap_or(std::cmp::Ordering::Equal)
270    }) else {
271        return;
272    };
273    println!(
274        "\n  Peak: {:.1} GFLOPS at {}T ({:.1}x scaling)",
275        peak.gflops, peak.threads, peak.scaling
276    );
277    let theoretical_peak = baseline_gflops * peak.threads as f64;
278    let efficiency = peak.gflops / theoretical_peak * 100.0;
279    println!("  Efficiency: {efficiency:.1}% vs linear scaling");
280}
281
282/// A single point in a parallel scaling curve.
283#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
284pub struct ScalingPoint {
285    pub threads: usize,
286    pub time_us: f64,
287    pub gflops: f64,
288    pub scaling: f64,
289}
290
/// Amdahl's law: maximum speedup `1 / ((1 - p) + p / n)` when a fraction `p`
/// of the work (`parallel_fraction`) is spread across `n` (`threads`) workers.
///
/// `parallel_fraction` is clamped to `[0.0, 1.0]` and `threads` to at least 1.
/// Out-of-range (non-NaN) inputs therefore give a meaningful bound; before
/// this clamp, zero threads produced 0.0 or NaN through division by zero.
fn amdahl(parallel_fraction: f64, threads: usize) -> f64 {
    let p = parallel_fraction.clamp(0.0, 1.0);
    let n = threads.max(1) as f64;
    1.0 / ((1.0 - p) + p / n)
}
295
/// Time `binary` under `RAYON_NUM_THREADS=threads` and return the fastest of
/// `runs` wall-clock measurements, in microseconds.
///
/// Each sample spans the whole child process, from spawn to exit. Keeping the
/// minimum reduces noise from OS scheduling and cache state. Returns `None`
/// when `runs` is 0, when the process cannot be spawned, or as soon as any
/// run exits with a failure status.
fn time_binary_min_of_n(binary: &str, threads: usize, runs: usize) -> Option<f64> {
    let thread_env = threads.to_string();
    (0..runs)
        .try_fold(None, |fastest_us: Option<f64>, _| {
            let started = std::time::Instant::now();
            let output = std::process::Command::new(binary)
                .env("RAYON_NUM_THREADS", &thread_env)
                .output()
                .ok()?;
            if !output.status.success() {
                return None;
            }
            let sample_us = started.elapsed().as_secs_f64() * 1e6;
            Some(Some(fastest_us.map_or(sample_us, |f| f.min(sample_us))))
        })
        .flatten()
}
314
315/// Run benchmark binary and parse the GEMM time for a specific size.
316/// Parses lines like: "Matrix Multiplication (1024x1024x1024)...     6.04 ms  (355.35 GFLOPS)"
317/// Returns (time_ms, gflops) for the best of N runs.
318fn parse_gemm_time(binary: &str, size: u32, threads: usize, runs: usize) -> Option<(f64, f64)> {
319    let pattern = format!("{size}x{size}x{size}");
320    let mut best: Option<(f64, f64)> = None;
321
322    for _ in 0..runs {
323        let stdout = run_bench_once(binary, threads)?;
324        if let Some((ms, gf)) = extract_best_matching(&stdout, &pattern) {
325            if best.is_none_or(|(b_ms, _)| ms < b_ms) {
326                best = Some((ms, gf));
327            }
328        }
329    }
330
331    best
332}
333
/// Execute `binary` once with `RAYON_NUM_THREADS=threads` and capture its stdout.
///
/// Returns `None` if the process cannot be spawned or exits with a failure
/// status. Stdout is decoded lossily, so invalid UTF-8 becomes U+FFFD.
fn run_bench_once(binary: &str, threads: usize) -> Option<String> {
    let finished = std::process::Command::new(binary)
        .env("RAYON_NUM_THREADS", threads.to_string())
        .output()
        .ok()?;
    finished
        .status
        .success()
        .then(|| String::from_utf8_lossy(&finished.stdout).into_owned())
}
345
346/// Scan stdout for the best `(time_ms, gflops)` pair matching the size pattern.
347fn extract_best_matching(stdout: &str, pattern: &str) -> Option<(f64, f64)> {
348    let mut best: Option<(f64, f64)> = None;
349    for line in stdout.lines() {
350        let Some((ms, gf)) = parse_gemm_line(line, pattern) else {
351            continue;
352        };
353        if best.is_none_or(|(b_ms, _)| ms < b_ms) {
354            best = Some((ms, gf));
355        }
356    }
357    best
358}
359
360/// Parse a single "Matrix Multiplication (NxNxN)... X.XX ms (Y.YY GFLOPS)" line.
361fn parse_gemm_line(line: &str, pattern: &str) -> Option<(f64, f64)> {
362    if !line.contains("Matrix Multiplication") || !line.contains(pattern) {
363        return None;
364    }
365    let ms = extract_between(line, "...", " ms")?
366        .trim()
367        .parse::<f64>()
368        .ok()?;
369    let gf = extract_between(line, "(", " GFLOPS)")?
370        .trim()
371        .parse::<f64>()
372        .ok()?;
373    Some((ms, gf))
374}
375
/// Return the text between two markers: everything after the LAST `start`
/// that precedes the FIRST occurrence of `end`.
///
/// Returns `None` if `end` is absent or no `start` precedes it.
fn extract_between<'a>(s: &'a str, start: &str, end: &str) -> Option<&'a str> {
    let (before_end, _) = s.split_once(end)?;
    let (_, inner) = before_end.rsplit_once(start)?;
    Some(inner)
}
384
/// Locate the `benchmark_matrix_suite` example binary used for timing.
///
/// Candidate target directories are tried in order, and the first regular
/// file wins:
/// 1. `$CARGO_TARGET_DIR`, if the variable is set, non-empty, and valid UTF-8;
/// 2. a hard-coded developer target directory under `/mnt/nvme-raid0`;
/// 3. `./target`, relative to the current working directory.
///
/// Each directory is searched at `release/examples/benchmark_matrix_suite`.
/// The platform executable suffix is appended (`.exe` on Windows, empty
/// elsewhere). Paths are built with `PathBuf::join`.
fn find_parallel_binary() -> Option<String> {
    let exe_name = format!("benchmark_matrix_suite{}", std::env::consts::EXE_SUFFIX);

    let mut target_dirs: Vec<std::path::PathBuf> = Vec::new();
    if let Some(dir) = std::env::var("CARGO_TARGET_DIR")
        .ok()
        .filter(|d| !d.is_empty())
    {
        target_dirs.push(dir.into());
    }
    // NOTE(review): machine-specific location, kept for backward compatibility.
    target_dirs.push("/mnt/nvme-raid0/targets/trueno".into());
    target_dirs.push("./target".into());

    target_dirs
        .into_iter()
        .map(|dir| dir.join("release").join("examples").join(&exe_name))
        // `is_file` rather than `exists` so a same-named directory is not returned.
        .find(|path| path.is_file())
        .and_then(|path| path.into_os_string().into_string().ok())
}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Perfect balance: all threads same time.
    #[test]
    fn test_heijunka_perfect_balance() {
        let times = [10.0, 10.0, 10.0, 10.0];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!((score - 0.0).abs() < 1e-10);
    }

    /// Severe imbalance: one thread does all the work.
    #[test]
    fn test_heijunka_severe_imbalance() {
        let times = [100.0, 1.0, 1.0, 1.0];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!(
            score > 0.5,
            "Heijunka score {score} should be > 0.5 for severe imbalance"
        );
    }

    /// FALSIFY-CGP-081: Intentionally imbalanced workload should have high score.
    #[test]
    fn test_heijunka_90pct_imbalance() {
        let rest = 100.0 / 7.0;
        let times = [900.0, rest, rest, rest, rest, rest, rest, rest];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!(
            score > 0.5,
            "Score {score} for 90% imbalance should be > 0.5"
        );
    }

    /// The score is the coefficient of variation, capped at 1.0.
    #[test]
    fn test_heijunka_cv_and_cap() {
        // mean 2, population std-dev 1 → CV 0.5
        assert!((RayonProfile::compute_heijunka_score(&[3.0, 1.0]) - 0.5).abs() < 1e-12);
        // All work on one of two threads → CV exactly 1.0.
        assert_eq!(RayonProfile::compute_heijunka_score(&[1.0, 0.0]), 1.0);
        // All work on one of four threads → CV sqrt(3), capped to 1.0.
        assert_eq!(
            RayonProfile::compute_heijunka_score(&[100.0, 0.0, 0.0, 0.0]),
            1.0
        );
    }

    #[test]
    fn test_heijunka_empty() {
        assert_eq!(RayonProfile::compute_heijunka_score(&[]), 0.0);
        assert_eq!(RayonProfile::compute_heijunka_score(&[42.0]), 0.0);
    }

    #[test]
    fn test_compute_speedup() {
        assert!((RayonProfile::compute_speedup(1000.0, 250.0) - 4.0).abs() < 0.01);
        assert!((RayonProfile::compute_speedup(1000.0, 0.0)).abs() < 0.01);
    }

    #[test]
    fn test_compute_efficiency() {
        assert!((RayonProfile::compute_efficiency(4.0, 8) - 0.5).abs() < 0.01);
        assert!((RayonProfile::compute_efficiency(8.0, 8) - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_amdahl() {
        // 100% parallelizable with 4 threads = 4x speedup
        assert!((amdahl(1.0, 4) - 4.0).abs() < 0.01);
        // 0% parallelizable = 1x speedup
        assert!((amdahl(0.0, 4) - 1.0).abs() < 0.01);
        // 50% parallelizable with 2 threads = 1.33x
        assert!((amdahl(0.5, 2) - 1.333).abs() < 0.01);
    }

    #[test]
    fn test_profile_parallel_auto_threads() {
        let result = profile_parallel("gemm_heijunka", 4096, Some("auto"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_extract_between() {
        let line = "  Matrix Multiplication (1024x1024x1024)...     6.04 ms  (355.35 GFLOPS)";
        assert_eq!(extract_between(line, "...", " ms"), Some("     6.04"));
        // Uses rfind, so finds the LAST ( before " GFLOPS)"
        assert_eq!(extract_between(line, "(", " GFLOPS)"), Some("355.35"));
        assert_eq!(extract_between(line, "missing", " end"), None);
    }

    #[test]
    fn test_parse_gemm_line() {
        let line = "  Matrix Multiplication (1024x1024x1024)...     6.04 ms  (355.35 GFLOPS)";
        let (ms, gf) = parse_gemm_line(line, "1024x1024x1024").expect("line should parse");
        assert!((ms - 6.04).abs() < 1e-9);
        assert!((gf - 355.35).abs() < 1e-9);
        // Wrong size and non-GEMM lines are rejected.
        assert_eq!(parse_gemm_line(line, "512x512x512"), None);
        assert_eq!(
            parse_gemm_line(
                "Vector Add (1024x1024x1024)... 1.00 ms (2.00 GFLOPS)",
                "1024x1024x1024"
            ),
            None
        );
    }

    #[test]
    fn test_extract_best_matching_picks_fastest() {
        let stdout = "Matrix Multiplication (256x256x256)...  3.00 ms  (11.18 GFLOPS)\n\
                      Matrix Multiplication (256x256x256)...  2.00 ms  (16.78 GFLOPS)\n\
                      Matrix Multiplication (512x512x512)...  1.00 ms  (268.44 GFLOPS)\n";
        assert_eq!(
            extract_best_matching(stdout, "256x256x256"),
            Some((2.0, 16.78))
        );
        assert_eq!(extract_best_matching(stdout, "128x128x128"), None);
    }

    #[test]
    fn test_scaling_notes() {
        assert_eq!(scaling_notes(1, 0.8), "baseline");
        assert_eq!(scaling_notes(8, 7.5), "near-linear");
        assert_eq!(scaling_notes(8, 4.0), "");
    }

    #[test]
    fn test_build_thread_counts() {
        assert_eq!(build_thread_counts(8), vec![1, 2, 4, 8]);
        assert_eq!(build_thread_counts(6), vec![1, 2, 4, 6]);
        assert_eq!(
            build_thread_counts(96),
            vec![1, 2, 4, 8, 12, 16, 24, 32, 48, 64, 96]
        );
    }

    #[test]
    fn test_time_binary_missing_binary() {
        assert_eq!(
            time_binary_min_of_n("/nonexistent/cgp-bench-binary", 2, 1),
            None
        );
    }

    #[test]
    fn test_scaling_point_serialization() {
        let point = ScalingPoint {
            threads: 8,
            time_us: 5000.0,
            gflops: 420.0,
            scaling: 5.1,
        };
        let json = serde_json::to_string(&point).unwrap();
        assert!(json.contains("\"threads\":8"));
        assert!(json.contains("\"gflops\":420.0"));
        let decoded: ScalingPoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.threads, 8);
        assert!((decoded.scaling - 5.1).abs() < 0.01);
    }
}