1use anyhow::Result;
5use serde::{Deserialize, Serialize};
6
/// Snapshot of one Rayon-backed parallel profiling run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RayonProfile {
    /// Wall-clock time of the parallel run, in microseconds.
    pub wall_time_us: f64,
    /// Wall-clock time of the single-threaded baseline, in microseconds.
    pub single_thread_time_us: f64,
    /// Speedup ratio (single-thread time / parallel time); see `compute_speedup`.
    pub parallel_speedup: f64,
    /// Number of worker threads used for the parallel run.
    pub num_threads: usize,
    /// Speedup divided by thread count (1.0 = perfect linear scaling);
    /// see `compute_efficiency`.
    pub parallel_efficiency: f64,
    /// Load-leveling score: 0.0 = perfectly even per-thread times, values
    /// toward 1.0 = imbalance (clamped coefficient of variation);
    /// see `compute_heijunka_score`.
    pub heijunka_score: f64,
    /// Estimated thread spawn/pool overhead, in microseconds.
    /// NOTE(review): populated by callers; not computed in this file — confirm units.
    pub thread_spawn_overhead_us: f64,
    /// Count of work-steal events.
    /// NOTE(review): semantics/source of this counter not visible here.
    pub work_steal_count: u64,
}
20
21impl RayonProfile {
22 pub fn compute_heijunka_score(per_thread_times: &[f64]) -> f64 {
26 if per_thread_times.is_empty() || per_thread_times.len() == 1 {
27 return 0.0;
28 }
29 let mean = per_thread_times.iter().sum::<f64>() / per_thread_times.len() as f64;
30 if mean == 0.0 {
31 return 0.0;
32 }
33 let variance = per_thread_times
34 .iter()
35 .map(|t| (t - mean).powi(2))
36 .sum::<f64>()
37 / per_thread_times.len() as f64;
38 let cv = variance.sqrt() / mean;
39 cv.min(1.0) }
41
42 pub fn compute_speedup(single_thread_us: f64, parallel_us: f64) -> f64 {
44 if parallel_us > 0.0 {
45 single_thread_us / parallel_us
46 } else {
47 0.0
48 }
49 }
50
51 pub fn compute_efficiency(speedup: f64, num_threads: usize) -> f64 {
53 if num_threads > 0 {
54 speedup / num_threads as f64
55 } else {
56 0.0
57 }
58 }
59}
60
/// Number of timing runs per configuration; the minimum over the runs is
/// reported to suppress noise from OS scheduling and cold caches.
const TIMING_RUNS: usize = 3;
63
/// Profile `function` at problem `size`: time the benchmark binary
/// single-threaded vs `thread_count`-way parallel and print a report.
///
/// `threads`: `None` or `Some("auto")` selects all logical CPUs; any other
/// value must parse as a thread count. When no benchmark binary is found,
/// falls back to printing a theoretical Amdahl's-law analysis.
///
/// NOTE(review): `function` and `size` are only echoed in the report — the
/// benchmark binary is invoked without them; confirm that is intended.
///
/// # Errors
/// Fails only when `threads` is a non-numeric string other than "auto".
pub fn profile_parallel(function: &str, size: u32, threads: Option<&str>) -> Result<()> {
    // Resolve the requested parallelism ("auto"/None => logical CPU count).
    let thread_count = match threads {
        Some("auto") | None => num_cpus::get(),
        Some(n) => n
            .parse()
            .map_err(|_| anyhow::anyhow!("Invalid thread count: {n}"))?,
    };

    println!("\n=== CGP Parallel Profile: {function} (size={size}, threads={thread_count}) ===\n");
    println!(" Backend: Rayon thread pool");
    println!(" Function: {function}");
    println!(" Size: {size}");
    println!(" Threads: {thread_count}");

    let binary = find_parallel_binary();
    match binary {
        Some(bin) => {
            println!(" Binary: {bin}");
            println!(" Timing: min of {TIMING_RUNS} runs");

            // Min-of-N wall time at 1 thread and at thread_count; the thread
            // count is forced via RAYON_NUM_THREADS inside the helper.
            let single_time = time_binary_min_of_n(&bin, 1, TIMING_RUNS);
            let parallel_time = time_binary_min_of_n(&bin, thread_count, TIMING_RUNS);

            match (single_time, parallel_time) {
                (Some(st), Some(pt)) => {
                    let speedup = RayonProfile::compute_speedup(st, pt);
                    let efficiency = RayonProfile::compute_efficiency(speedup, thread_count);

                    println!("\n Results:");
                    println!(" Single-thread: {st:.0} us");
                    println!(" {thread_count}-thread: {pt:.0} us");
                    println!(" Parallel speedup: {speedup:.2}x");
                    println!(" Efficiency: {:.1}%", efficiency * 100.0);

                    // Hard-coded ballpark for thread-pool startup cost in
                    // microseconds — not measured here.
                    let overhead_estimate = 40.0;
                    let overhead_pct = if pt > 0.0 {
                        overhead_estimate / pt * 100.0
                    } else {
                        0.0
                    };
                    println!(
                        " Thread overhead: ~{overhead_estimate:.0} us ({overhead_pct:.1}% of total)"
                    );

                    // Below ~500us the fixed overhead dominates any speedup.
                    if pt < 500.0 {
                        println!(
                            "\n \x1b[33m[WARN]\x1b[0m Workload <500us — thread overhead dominates"
                        );
                        println!(" Consider: sequential execution or batching");
                    }
                }
                _ => {
                    println!("\n Could not time binary — showing configuration only.");
                }
            }
        }
        None => {
            // No binary: print build instructions plus a purely theoretical
            // Amdahl's-law ceiling for several parallel fractions.
            println!(" No benchmark binary found.");
            println!(" Build with: cargo build --release --examples");
            println!(" Estimated metrics with synthetic data:");

            println!("\n Theoretical Analysis:");
            println!(
                " Amdahl's law: if 95% parallelizable, max speedup = {:.1}x",
                amdahl(0.95, thread_count)
            );
            println!(
                " If 90% parallelizable, max speedup = {:.1}x",
                amdahl(0.90, thread_count)
            );
            println!(
                " If 80% parallelizable, max speedup = {:.1}x",
                amdahl(0.80, thread_count)
            );
        }
    }

    println!();
    Ok(())
}
153
/// Sweep GEMM at `size` across a ladder of thread counts (capped at
/// `max_threads`, default = logical CPU count), printing a scaling table —
/// or a JSON array of [`ScalingPoint`]s when `json` is set.
///
/// Each point is min-of-`runs`; scaling is relative to the 1-thread
/// baseline parsed from the benchmark binary's output.
///
/// # Errors
/// Fails when no benchmark binary is found, when the 1-thread baseline
/// cannot be parsed, or when JSON serialization fails.
pub fn profile_scaling(
    size: u32,
    max_threads: Option<usize>,
    runs: usize,
    json: bool,
) -> anyhow::Result<()> {
    let max_t = max_threads.unwrap_or_else(num_cpus::get);
    let binary = find_parallel_binary();
    let bin = match &binary {
        Some(b) => b.as_str(),
        None => {
            anyhow::bail!(
                "No benchmark binary found. Build with: cargo build --release --example benchmark_matrix_suite --features parallel"
            )
        }
    };

    // Thread ladder: drop rungs above max_t, then ensure max_t itself is
    // measured even if it isn't one of the standard rungs.
    let mut thread_counts: Vec<usize> = vec![1, 2, 4, 8, 12, 16, 24, 32, 48, 64];
    thread_counts.retain(|&t| t <= max_t);
    if !thread_counts.contains(&max_t) {
        thread_counts.push(max_t);
    }

    // 1-thread baseline; all scaling factors are relative to this.
    let (baseline_ms, baseline_gflops) = parse_gemm_time(bin, size, 1, runs)
        .ok_or_else(|| anyhow::anyhow!(
            "Failed to parse GEMM {size}x{size} from benchmark output (1 thread). \
             Ensure the binary outputs 'Matrix Multiplication ({size}x{size}x{size})... X.XX ms (Y.YY GFLOPS)'"
        ))?;

    let mut results: Vec<ScalingPoint> = Vec::new();

    if !json {
        println!("\n=== CGP Parallel Scaling: GEMM {size}x{size}, min-of-{runs} runs ===\n");
        println!(
            " {:>8} | {:>10} | {:>10} | {:>8} | {:>6}",
            "Threads", "Time (ms)", "GFLOPS", "Scaling", "Notes"
        );
        println!(" {}", "-".repeat(60));
    }

    for &t in &thread_counts {
        match parse_gemm_time(bin, size, t, runs) {
            Some((time_ms, gflops)) => {
                let scaling = baseline_ms / time_ms;

                // Flag rungs achieving >= 90% of ideal linear scaling.
                let notes = if t == 1 {
                    "baseline".to_string()
                } else if scaling >= (t as f64) * 0.9 {
                    "near-linear".to_string()
                } else {
                    String::new()
                };

                if !json {
                    println!(
                        " {:>8} | {:>9.2} | {:>10.1} | {:>7.1}x | {notes}",
                        t, time_ms, gflops, scaling
                    );
                }
                results.push(ScalingPoint {
                    threads: t,
                    time_us: time_ms * 1000.0,
                    gflops,
                    scaling,
                });
            }
            None => {
                // A failed rung is reported but does not abort the sweep.
                if !json {
                    println!(" {:>8} | {:>10} | {:>10} | {:>8} |", t, "FAILED", "-", "-");
                }
            }
        }
    }

    // Summarize the highest-throughput point and its efficiency vs ideal
    // linear extrapolation of the baseline.
    if let Some(peak) = results.iter().max_by(|a, b| {
        a.gflops
            .partial_cmp(&b.gflops)
            .unwrap_or(std::cmp::Ordering::Equal)
    }) {
        if !json {
            println!(
                "\n Peak: {:.1} GFLOPS at {}T ({:.1}x scaling)",
                peak.gflops, peak.threads, peak.scaling
            );
            let theoretical_peak = baseline_gflops * peak.threads as f64;
            let efficiency = peak.gflops / theoretical_peak * 100.0;
            println!(" Efficiency: {efficiency:.1}% vs linear scaling");
        }
    }

    if json {
        println!("{}", serde_json::to_string_pretty(&results)?);
    } else {
        println!();
    }

    Ok(())
}
257
258#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
260pub struct ScalingPoint {
261 pub threads: usize,
262 pub time_us: f64,
263 pub gflops: f64,
264 pub scaling: f64,
265}
266
/// Amdahl's-law ideal speedup for a workload in which `parallel_fraction`
/// (expected in `[0, 1]`) scales perfectly across `threads` workers and the
/// remainder stays serial.
fn amdahl(parallel_fraction: f64, threads: usize) -> f64 {
    let serial_part = 1.0 - parallel_fraction;
    let parallel_part = parallel_fraction / threads as f64;
    (serial_part + parallel_part).recip()
}
271
/// Execute `binary` `runs` times with `RAYON_NUM_THREADS` forced to
/// `threads`, returning the minimum wall time in microseconds.
///
/// Returns `None` if the process cannot be launched or any run exits with
/// a failure status.
fn time_binary_min_of_n(binary: &str, threads: usize, runs: usize) -> Option<f64> {
    let mut fastest: Option<f64> = None;
    for _ in 0..runs {
        let begun = std::time::Instant::now();
        let out = std::process::Command::new(binary)
            .env("RAYON_NUM_THREADS", threads.to_string())
            .output()
            .ok()?;
        if !out.status.success() {
            return None;
        }
        // Whole-process wall time, including spawn cost, in microseconds.
        let micros = begun.elapsed().as_secs_f64() * 1e6;
        fastest = Some(match fastest {
            Some(prev) => prev.min(micros),
            None => micros,
        });
    }
    fastest
}
290
291fn parse_gemm_time(binary: &str, size: u32, threads: usize, runs: usize) -> Option<(f64, f64)> {
295 let pattern = format!("{size}x{size}x{size}");
296 let mut best_time_ms: Option<f64> = None;
297 let mut best_gflops: Option<f64> = None;
298
299 for _ in 0..runs {
300 let output = std::process::Command::new(binary)
301 .env("RAYON_NUM_THREADS", threads.to_string())
302 .output()
303 .ok()?;
304 if !output.status.success() {
305 return None;
306 }
307 let stdout = String::from_utf8_lossy(&output.stdout);
308 for line in stdout.lines() {
309 if line.contains("Matrix Multiplication") && line.contains(&pattern) {
310 if let Some(ms_str) = extract_between(line, "...", " ms") {
312 if let Ok(ms) = ms_str.trim().parse::<f64>() {
313 if best_time_ms.is_none_or(|best| ms < best) {
314 best_time_ms = Some(ms);
315 if let Some(gf_str) = extract_between(line, "(", " GFLOPS)") {
317 if let Ok(gf) = gf_str.trim().parse::<f64>() {
318 best_gflops = Some(gf);
319 }
320 }
321 }
322 }
323 }
324 }
325 }
326 }
327
328 match (best_time_ms, best_gflops) {
329 (Some(ms), Some(gf)) => Some((ms, gf)),
330 _ => None,
331 }
332}
333
/// Slice of `s` strictly between the last occurrence of `start` that
/// precedes the FIRST occurrence of `end`.
///
/// Returns `None` when either delimiter is absent (or `start` never
/// appears before `end`).
fn extract_between<'a>(s: &'a str, start: &str, end: &str) -> Option<&'a str> {
    let stop = s.find(end)?;
    let begin = s[..stop].rfind(start)? + start.len();
    Some(&s[begin..stop])
}
342
/// Locate the release-build `benchmark_matrix_suite` example, probing in
/// order: the `CARGO_TARGET_DIR` layout (when set and non-empty), a known
/// absolute target directory, then the workspace-local `./target`.
///
/// Returns the first candidate path that exists on disk, or `None`.
fn find_parallel_binary() -> Option<String> {
    let mut candidates: Vec<String> = Vec::new();
    if let Ok(dir) = std::env::var("CARGO_TARGET_DIR") {
        if !dir.is_empty() {
            candidates.push(format!("{dir}/release/examples/benchmark_matrix_suite"));
        }
    }
    candidates.push(
        "/mnt/nvme-raid0/targets/trueno/release/examples/benchmark_matrix_suite".to_string(),
    );
    candidates.push("./target/release/examples/benchmark_matrix_suite".to_string());
    candidates
        .into_iter()
        .find(|path| std::path::Path::new(path).exists())
}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical per-thread times: perfectly leveled, score must be 0.
    #[test]
    fn test_heijunka_perfect_balance() {
        let times = vec![10.0, 10.0, 10.0, 10.0];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!((score - 0.0).abs() < 1e-10);
    }

    /// One thread doing ~100x the work of the others must score high.
    #[test]
    fn test_heijunka_severe_imbalance() {
        let times = vec![100.0, 1.0, 1.0, 1.0];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!(
            score > 0.5,
            "Heijunka score {score} should be > 0.5 for severe imbalance"
        );
    }

    /// 90% of the work on one of eight threads: still clearly imbalanced.
    #[test]
    fn test_heijunka_90pct_imbalance() {
        let times = vec![
            900.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
            100.0 / 7.0,
        ];
        let score = RayonProfile::compute_heijunka_score(&times);
        assert!(
            score > 0.5,
            "Score {score} for 90% imbalance should be > 0.5"
        );
    }

    /// Degenerate inputs (no samples, single sample) score 0 by definition.
    #[test]
    fn test_heijunka_empty() {
        assert_eq!(RayonProfile::compute_heijunka_score(&[]), 0.0);
        assert_eq!(RayonProfile::compute_heijunka_score(&[42.0]), 0.0);
    }

    /// Speedup is the time ratio; zero parallel time yields 0, not a panic.
    #[test]
    fn test_compute_speedup() {
        assert!((RayonProfile::compute_speedup(1000.0, 250.0) - 4.0).abs() < 0.01);
        assert!((RayonProfile::compute_speedup(1000.0, 0.0)).abs() < 0.01);
    }

    /// Efficiency normalizes speedup by thread count.
    #[test]
    fn test_compute_efficiency() {
        assert!((RayonProfile::compute_efficiency(4.0, 8) - 0.5).abs() < 0.01);
        assert!((RayonProfile::compute_efficiency(8.0, 8) - 1.0).abs() < 0.01);
    }

    /// Amdahl limits: fully parallel => linear; fully serial => 1x.
    #[test]
    fn test_amdahl() {
        assert!((amdahl(1.0, 4) - 4.0).abs() < 0.01);
        assert!((amdahl(0.0, 4) - 1.0).abs() < 0.01);
        assert!((amdahl(0.5, 2) - 1.333).abs() < 0.01);
    }

    /// "auto" thread selection must succeed regardless of binary presence.
    #[test]
    fn test_profile_parallel_auto_threads() {
        let result = profile_parallel("gemm_heijunka", 4096, Some("auto"));
        assert!(result.is_ok());
    }

    /// Delimiter extraction against a representative benchmark line.
    #[test]
    fn test_extract_between() {
        let line = " Matrix Multiplication (1024x1024x1024)... 6.04 ms (355.35 GFLOPS)";
        assert_eq!(extract_between(line, "...", " ms"), Some(" 6.04"));
        assert_eq!(extract_between(line, "(", " GFLOPS)"), Some("355.35"));
        assert_eq!(extract_between(line, "missing", " end"), None);
    }

    /// ScalingPoint must round-trip through serde JSON.
    #[test]
    fn test_scaling_point_serialization() {
        let point = ScalingPoint {
            threads: 8,
            time_us: 5000.0,
            gflops: 420.0,
            scaling: 5.1,
        };
        let json = serde_json::to_string(&point).unwrap();
        assert!(json.contains("\"threads\":8"));
        assert!(json.contains("\"gflops\":420.0"));
        let decoded: ScalingPoint = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.threads, 8);
        assert!((decoded.scaling - 5.1).abs() < 0.01);
    }
}