1use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct RayonProfile {
10 pub wall_time_us: f64,
11 pub single_thread_time_us: f64,
12 pub parallel_speedup: f64,
13 pub num_threads: usize,
14 pub parallel_efficiency: f64,
15 pub heijunka_score: f64,
17 pub thread_spawn_overhead_us: f64,
18 pub work_steal_count: u64,
19}
20
21impl RayonProfile {
22 pub fn compute_heijunka_score(per_thread_times: &[f64]) -> f64 {
26 if per_thread_times.is_empty() || per_thread_times.len() == 1 {
27 return 0.0;
28 }
29 let mean = per_thread_times.iter().sum::<f64>() / per_thread_times.len() as f64;
30 if mean == 0.0 {
31 return 0.0;
32 }
33 let variance = per_thread_times
34 .iter()
35 .map(|t| (t - mean).powi(2))
36 .sum::<f64>()
37 / per_thread_times.len() as f64;
38 let cv = variance.sqrt() / mean;
39 cv.min(1.0) }
41
42 pub fn compute_speedup(single_thread_us: f64, parallel_us: f64) -> f64 {
44 if parallel_us > 0.0 {
45 single_thread_us / parallel_us
46 } else {
47 0.0
48 }
49 }
50
51 pub fn compute_efficiency(speedup: f64, num_threads: usize) -> f64 {
53 if num_threads > 0 {
54 speedup / num_threads as f64
55 } else {
56 0.0
57 }
58 }
59}
60
/// How many times `profile_parallel` runs the benchmark binary per
/// configuration; the fastest of these runs is the one reported.
const TIMING_RUNS: usize = 3;
63
64pub fn profile_parallel(function: &str, size: u32, threads: Option<&str>) -> Result<()> {
68 let thread_count = match threads {
69 Some("auto") | None => num_cpus::get(),
70 Some(n) => n
71 .parse()
72 .map_err(|_| anyhow::anyhow!("Invalid thread count: {n}"))?,
73 };
74
75 println!("\n=== CGP Parallel Profile: {function} (size={size}, threads={thread_count}) ===\n");
76 println!(" Backend: Rayon thread pool");
77 println!(" Function: {function}");
78 println!(" Size: {size}");
79 println!(" Threads: {thread_count}");
80
81 let binary = find_parallel_binary();
83 match binary {
84 Some(bin) => {
85 println!(" Binary: {bin}");
86 println!(" Timing: min of {TIMING_RUNS} runs");
87
88 let single_time = time_binary_min_of_n(&bin, 1, TIMING_RUNS);
90 let parallel_time = time_binary_min_of_n(&bin, thread_count, TIMING_RUNS);
92
93 match (single_time, parallel_time) {
94 (Some(st), Some(pt)) => {
95 let speedup = RayonProfile::compute_speedup(st, pt);
96 let efficiency = RayonProfile::compute_efficiency(speedup, thread_count);
97
98 println!("\n Results:");
99 println!(" Single-thread: {st:.0} us");
100 println!(" {thread_count}-thread: {pt:.0} us");
101 println!(" Parallel speedup: {speedup:.2}x");
102 println!(" Efficiency: {:.1}%", efficiency * 100.0);
103
104 let overhead_estimate = 40.0; let overhead_pct = if pt > 0.0 {
107 overhead_estimate / pt * 100.0
108 } else {
109 0.0
110 };
111 println!(
112 " Thread overhead: ~{overhead_estimate:.0} us ({overhead_pct:.1}% of total)"
113 );
114
115 if pt < 500.0 {
117 println!(
118 "\n \x1b[33m[WARN]\x1b[0m Workload <500us — thread overhead dominates"
119 );
120 println!(" Consider: sequential execution or batching");
121 }
122 }
123 _ => {
124 println!("\n Could not time binary — showing configuration only.");
125 }
126 }
127 }
128 None => {
129 println!(" No benchmark binary found.");
130 println!(" Build with: cargo build --release --examples");
131 println!(" Estimated metrics with synthetic data:");
132
133 println!("\n Theoretical Analysis:");
135 println!(
136 " Amdahl's law: if 95% parallelizable, max speedup = {:.1}x",
137 amdahl(0.95, thread_count)
138 );
139 println!(
140 " If 90% parallelizable, max speedup = {:.1}x",
141 amdahl(0.90, thread_count)
142 );
143 println!(
144 " If 80% parallelizable, max speedup = {:.1}x",
145 amdahl(0.80, thread_count)
146 );
147 }
148 }
149
150 println!();
151 Ok(())
152}
153
154pub fn profile_scaling(
157 size: u32,
158 max_threads: Option<usize>,
159 runs: usize,
160 json: bool,
161) -> anyhow::Result<()> {
162 let max_t = max_threads.unwrap_or_else(num_cpus::get);
163 let binary = resolve_bench_binary()?;
164 let bin = binary.as_str();
165
166 let thread_counts = build_thread_counts(max_t);
167
168 let (baseline_ms, baseline_gflops) = parse_gemm_time(bin, size, 1, runs)
169 .ok_or_else(|| anyhow::anyhow!(
170 "Failed to parse GEMM {size}x{size} from benchmark output (1 thread). \
171 Ensure the binary outputs 'Matrix Multiplication ({size}x{size}x{size})... X.XX ms (Y.YY GFLOPS)'"
172 ))?;
173
174 if !json {
175 println!("\n=== CGP Parallel Scaling: GEMM {size}x{size}, min-of-{runs} runs ===\n");
176 println!(
177 " {:>8} | {:>10} | {:>10} | {:>8} | {:>6}",
178 "Threads", "Time (ms)", "GFLOPS", "Scaling", "Notes"
179 );
180 println!(" {}", "-".repeat(60));
181 }
182
183 let results = run_scaling_sweep(bin, size, &thread_counts, runs, baseline_ms, json);
184
185 if !json {
186 print_scaling_peak(&results, baseline_gflops);
187 }
188
189 if json {
190 println!("{}", serde_json::to_string_pretty(&results)?);
191 } else {
192 println!();
193 }
194
195 Ok(())
196}
197
198fn resolve_bench_binary() -> anyhow::Result<String> {
200 find_parallel_binary().ok_or_else(|| anyhow::anyhow!(
201 "No benchmark binary found. Build with: cargo build --release --example benchmark_matrix_suite --features parallel"
202 ))
203}
204
/// Builds the ascending list of thread counts to benchmark.
///
/// The list contains every standard step (1, 2, 4, 8, 12, 16, 24, 32, 48, 64)
/// that does not exceed `max_t`, followed by `max_t` itself when it is not
/// one of those steps.
///
/// A `max_t` of zero is treated as one: zero threads cannot run anything, and
/// the result would otherwise be `[0]`, a sweep without a one-thread row.
fn build_thread_counts(max_t: usize) -> Vec<usize> {
    // Must stay sorted ascending; `take_while` below relies on it.
    const STANDARD_STEPS: [usize; 10] = [1, 2, 4, 8, 12, 16, 24, 32, 48, 64];

    let max_t = max_t.max(1);
    let mut thread_counts: Vec<usize> = STANDARD_STEPS
        .iter()
        .copied()
        .take_while(|&step| step <= max_t)
        .collect();

    // If `max_t` is a standard step it is necessarily the last one kept.
    if thread_counts.last() != Some(&max_t) {
        thread_counts.push(max_t);
    }
    thread_counts
}
214
215fn run_scaling_sweep(
217 bin: &str,
218 size: u32,
219 thread_counts: &[usize],
220 runs: usize,
221 baseline_ms: f64,
222 json: bool,
223) -> Vec<ScalingPoint> {
224 let mut results: Vec<ScalingPoint> = Vec::new();
225 for &t in thread_counts {
226 match parse_gemm_time(bin, size, t, runs) {
227 Some((time_ms, gflops)) => {
228 let scaling = baseline_ms / time_ms;
229 if !json {
230 let notes = scaling_notes(t, scaling);
231 println!(
232 " {:>8} | {:>9.2} | {:>10.1} | {:>7.1}x | {notes}",
233 t, time_ms, gflops, scaling
234 );
235 }
236 results.push(ScalingPoint {
237 threads: t,
238 time_us: time_ms * 1000.0,
239 gflops,
240 scaling,
241 });
242 }
243 None => {
244 if !json {
245 println!(" {:>8} | {:>10} | {:>10} | {:>8} |", t, "FAILED", "-", "-");
246 }
247 }
248 }
249 }
250 results
251}
252
/// Returns the text for the "Notes" column of a scaling table row.
///
/// One thread is labelled `baseline`; any other count whose scaling reaches
/// at least 90% of the thread count is labelled `near-linear`; everything
/// else gets an empty note.
fn scaling_notes(threads: usize, scaling: f64) -> String {
    let label = match threads {
        1 => "baseline",
        n if scaling >= (n as f64) * 0.9 => "near-linear",
        _ => "",
    };
    label.to_string()
}
263
264fn print_scaling_peak(results: &[ScalingPoint], baseline_gflops: f64) {
266 let Some(peak) = results.iter().max_by(|a, b| {
267 a.gflops
268 .partial_cmp(&b.gflops)
269 .unwrap_or(std::cmp::Ordering::Equal)
270 }) else {
271 return;
272 };
273 println!(
274 "\n Peak: {:.1} GFLOPS at {}T ({:.1}x scaling)",
275 peak.gflops, peak.threads, peak.scaling
276 );
277 let theoretical_peak = baseline_gflops * peak.threads as f64;
278 let efficiency = peak.gflops / theoretical_peak * 100.0;
279 println!(" Efficiency: {efficiency:.1}% vs linear scaling");
280}
281
282#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
284pub struct ScalingPoint {
285 pub threads: usize,
286 pub time_us: f64,
287 pub gflops: f64,
288 pub scaling: f64,
289}
290
/// Amdahl's-law speedup bound: `1 / ((1 - p) + p / n)` for a parallelizable
/// fraction `p` (`parallel_fraction`) run on `n` (`threads`) threads.
fn amdahl(parallel_fraction: f64, threads: usize) -> f64 {
    let serial_share = 1.0 - parallel_fraction;
    let parallel_share = parallel_fraction / threads as f64;
    (serial_share + parallel_share).recip()
}
295
/// Runs `binary` `runs` times with `RAYON_NUM_THREADS` set to `threads` and
/// returns the shortest wall-clock time in microseconds.
///
/// The measured time covers spawning the process and waiting for it to exit.
/// Returns `None` as soon as one run cannot be started or exits
/// unsuccessfully, and also when `runs` is zero.
fn time_binary_min_of_n(binary: &str, threads: usize, runs: usize) -> Option<f64> {
    let rayon_threads = threads.to_string();

    let time_one_run = || -> Option<f64> {
        let started = std::time::Instant::now();
        let finished = std::process::Command::new(binary)
            .env("RAYON_NUM_THREADS", &rayon_threads)
            .output()
            .ok()?;
        finished
            .status
            .success()
            .then(|| started.elapsed().as_secs_f64() * 1e6)
    };

    // Collecting into `Option<Vec<_>>` stops at the first failed run.
    let samples: Option<Vec<f64>> = (0..runs).map(|_| time_one_run()).collect();
    samples?.into_iter().reduce(f64::min)
}
314
315fn parse_gemm_time(binary: &str, size: u32, threads: usize, runs: usize) -> Option<(f64, f64)> {
319 let pattern = format!("{size}x{size}x{size}");
320 let mut best: Option<(f64, f64)> = None;
321
322 for _ in 0..runs {
323 let stdout = run_bench_once(binary, threads)?;
324 if let Some((ms, gf)) = extract_best_matching(&stdout, &pattern) {
325 if best.is_none_or(|(b_ms, _)| ms < b_ms) {
326 best = Some((ms, gf));
327 }
328 }
329 }
330
331 best
332}
333
/// Runs `binary` once with `RAYON_NUM_THREADS` set to `threads` and returns
/// its standard output (decoded lossily as UTF-8).
///
/// Returns `None` when the process cannot be started or exits unsuccessfully.
fn run_bench_once(binary: &str, threads: usize) -> Option<String> {
    let outcome = std::process::Command::new(binary)
        .env("RAYON_NUM_THREADS", threads.to_string())
        .output();

    match outcome {
        Ok(finished) if finished.status.success() => {
            Some(String::from_utf8_lossy(&finished.stdout).into_owned())
        }
        _ => None,
    }
}
345
346fn extract_best_matching(stdout: &str, pattern: &str) -> Option<(f64, f64)> {
348 let mut best: Option<(f64, f64)> = None;
349 for line in stdout.lines() {
350 let Some((ms, gf)) = parse_gemm_line(line, pattern) else {
351 continue;
352 };
353 if best.is_none_or(|(b_ms, _)| ms < b_ms) {
354 best = Some((ms, gf));
355 }
356 }
357 best
358}
359
360fn parse_gemm_line(line: &str, pattern: &str) -> Option<(f64, f64)> {
362 if !line.contains("Matrix Multiplication") || !line.contains(pattern) {
363 return None;
364 }
365 let ms = extract_between(line, "...", " ms")?
366 .trim()
367 .parse::<f64>()
368 .ok()?;
369 let gf = extract_between(line, "(", " GFLOPS)")?
370 .trim()
371 .parse::<f64>()
372 .ok()?;
373 Some((ms, gf))
374}
375
/// Returns the text that sits between `start` and `end` in `s`.
///
/// `end` is located at its first occurrence in `s`; `start` is then located
/// at its last occurrence before that point. Returns `None` when `end` is
/// absent or when `start` does not occur before it.
fn extract_between<'a>(s: &'a str, start: &str, end: &str) -> Option<&'a str> {
    let (before_end, _) = s.split_once(end)?;
    let (_, inner) = before_end.rsplit_once(start)?;
    Some(inner)
}
384
/// Looks for the `benchmark_matrix_suite` example binary and returns the
/// first candidate path that exists.
///
/// Candidates are tried in this order: under `$CARGO_TARGET_DIR` (only when
/// that variable is set to a non-empty value), under a fixed absolute target
/// directory, and under `./target`.
fn find_parallel_binary() -> Option<String> {
    let from_env = std::env::var("CARGO_TARGET_DIR")
        .ok()
        .filter(|dir| !dir.is_empty())
        .map(|target_dir| format!("{target_dir}/release/examples/benchmark_matrix_suite"));

    let fixed_locations = [
        "/mnt/nvme-raid0/targets/trueno/release/examples/benchmark_matrix_suite",
        "./target/release/examples/benchmark_matrix_suite",
    ];

    from_env
        .into_iter()
        .chain(fixed_locations.iter().map(|path| (*path).to_string()))
        .find(|candidate| std::path::Path::new(candidate).exists())
}
405
/// Unit tests for the scoring helpers, the Amdahl bound, and the benchmark
/// output parsing.
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical per-thread times mean no imbalance at all.
    #[test]
    fn test_heijunka_perfect_balance() {
        let times = [10.0, 10.0, 10.0, 10.0];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!(score.abs() < 1e-10);
    }

    /// One thread doing almost all of the work must score high.
    #[test]
    fn test_heijunka_severe_imbalance() {
        let times = [100.0, 1.0, 1.0, 1.0];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!(
            score > 0.5,
            "Heijunka score {score} should be > 0.5 for severe imbalance"
        );
    }

    /// One of eight threads carrying 90% of the total time.
    #[test]
    fn test_heijunka_90pct_imbalance() {
        let times = [
            900.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
        ];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!(
            score > 0.5,
            "Score {score} for 90% imbalance should be > 0.5"
        );
    }

    /// Fewer than two samples cannot show imbalance.
    #[test]
    fn test_heijunka_empty() {
        assert_eq!(RayonProfile::compute_heijunka_score(&[]), 0.0);
        assert_eq!(RayonProfile::compute_heijunka_score(&[42.0]), 0.0);
    }

    #[test]
    fn test_compute_speedup() {
        assert!((RayonProfile::compute_speedup(1000.0, 250.0) - 4.0).abs() < 0.01);
        assert!((RayonProfile::compute_speedup(1000.0, 0.0)).abs() < 0.01);
    }

    #[test]
    fn test_compute_efficiency() {
        assert!((RayonProfile::compute_efficiency(4.0, 8) - 0.5).abs() < 0.01);
        assert!((RayonProfile::compute_efficiency(8.0, 8) - 1.0).abs() < 0.01);
        assert_eq!(RayonProfile::compute_efficiency(8.0, 0), 0.0);
    }

    #[test]
    fn test_amdahl() {
        // Fully parallel work scales with the thread count.
        assert!((amdahl(1.0, 4) - 4.0).abs() < 0.01);
        // Fully serial work never speeds up.
        assert!((amdahl(0.0, 4) - 1.0).abs() < 0.01);
        assert!((amdahl(0.5, 2) - 1.333).abs() < 0.01);
    }

    /// Must succeed whether or not a benchmark binary is present.
    #[test]
    fn test_profile_parallel_auto_threads() {
        let result = profile_parallel("gemm_heijunka", 4096, Some("auto"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_profile_parallel_rejects_non_numeric_threads() {
        assert!(profile_parallel("gemm_heijunka", 4096, Some("many")).is_err());
    }

    #[test]
    fn test_extract_between() {
        let line = " Matrix Multiplication (1024x1024x1024)... 6.04 ms (355.35 GFLOPS)";
        assert_eq!(extract_between(line, "...", " ms"), Some(" 6.04"));
        assert_eq!(extract_between(line, "(", " GFLOPS)"), Some("355.35"));
        assert_eq!(extract_between(line, "missing", " end"), None);
    }

    #[test]
    fn test_parse_gemm_line() {
        let line = " Matrix Multiplication (1024x1024x1024)... 6.04 ms (355.35 GFLOPS)";
        assert_eq!(parse_gemm_line(line, "1024x1024x1024"), Some((6.04, 355.35)));
        assert_eq!(parse_gemm_line(line, "512x512x512"), None);
        assert_eq!(parse_gemm_line("unrelated output", "1024x1024x1024"), None);
    }

    #[test]
    fn test_extract_best_matching_picks_fastest() {
        let stdout = "\
 Matrix Multiplication (64x64x64)... 0.50 ms (1.00 GFLOPS)
 Matrix Multiplication (64x64x64)... 0.25 ms (2.00 GFLOPS)
 Matrix Multiplication (128x128x128)... 0.10 ms (9.00 GFLOPS)";
        assert_eq!(extract_best_matching(stdout, "64x64x64"), Some((0.25, 2.0)));
        assert_eq!(extract_best_matching(stdout, "256x256x256"), None);
    }

    #[test]
    fn test_build_thread_counts() {
        assert_eq!(build_thread_counts(1), vec![1]);
        assert_eq!(build_thread_counts(6), vec![1, 2, 4, 6]);
        assert_eq!(build_thread_counts(8), vec![1, 2, 4, 8]);
    }

    #[test]
    fn test_scaling_notes() {
        assert_eq!(scaling_notes(1, 1.0), "baseline");
        assert_eq!(scaling_notes(4, 3.7), "near-linear");
        assert_eq!(scaling_notes(4, 2.0), "");
    }

    #[test]
    fn test_scaling_point_serialization() {
        let point = ScalingPoint {
            threads: 8,
            time_us: 5000.0,
            gflops: 420.0,
            scaling: 5.1,
        };
        let json = serde_json::to_string(&point).unwrap();
        assert!(json.contains("\"threads\":8"));
        assert!(json.contains("\"gflops\":420.0"));
        let decoded: ScalingPoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.threads, 8);
        assert!((decoded.scaling - 5.1).abs() < 0.01);
    }
}