Skip to main content

rustify_ml/
profiler.rs

1use std::process::Command;
2
3use anyhow::{Context, Result, anyhow};
4use tracing::{info, warn};
5
6use crate::utils::{Hotspot, InputSource, ProfileSummary};
7
8/// Detect the Python executable name available on this system.
9/// Tries `python3` first (Linux/macOS convention), then falls back to `python`.
10/// Returns an error if neither is found.
11pub fn detect_python() -> Result<String> {
12    for candidate in &["python3", "python"] {
13        if let Ok(output) = Command::new(candidate).arg("--version").output()
14            && output.status.success()
15        {
16            return Ok(candidate.to_string());
17        }
18    }
19    Err(anyhow!(
20        "Python not found on PATH. Install Python 3.10+ and ensure it is on PATH."
21    ))
22}
23
24/// Check that the detected Python is >= 3.10. Warns (does not error) if older.
25fn check_python_version(python: &str) {
26    let check = Command::new(python)
27        .args([
28            "-c",
29            "import sys; ok = sys.version_info >= (3, 10); print('ok' if ok else f'old:{sys.version}')",
30        ])
31        .output();
32
33    match check {
34        Ok(out) if out.status.success() => {
35            let stdout = String::from_utf8_lossy(&out.stdout);
36            let result = stdout.trim();
37            if result == "ok" {
38                info!(python, "Python version check passed (>= 3.10)");
39            } else {
40                warn!(
41                    python,
42                    version = result.trim_start_matches("old:"),
43                    "Python < 3.10 detected; some profiling features may behave differently"
44                );
45            }
46        }
47        Ok(out) => {
48            let stderr = String::from_utf8_lossy(&out.stderr);
49            warn!(python, err = %stderr.trim(), "Python version check failed");
50        }
51        Err(e) => {
52            warn!(python, err = %e, "could not run Python version check");
53        }
54    }
55}
56
57/// Profile Python code with a configurable iteration count.
58/// Wraps `profile_input_core` with the given loop count.
59pub fn profile_input_with_iterations(
60    source: &InputSource,
61    threshold: f32,
62    iterations: u32,
63) -> Result<ProfileSummary> {
64    profile_input_core(source, threshold, iterations)
65}
66
67/// Profile Python code using the built-in cProfile via a Python subprocess.
68/// Uses a default of 100 iterations.
69pub fn profile_input(source: &InputSource, threshold: f32) -> Result<ProfileSummary> {
70    profile_input_core(source, threshold, 100)
71}
72
73/// Core profiling implementation.
74fn profile_input_core(
75    source: &InputSource,
76    threshold: f32,
77    iterations: u32,
78) -> Result<ProfileSummary> {
79    let python = detect_python()?;
80    check_python_version(&python);
81
82    // Write the Python script to a temp file for execution; keep dir alive for the duration of profiling
83    let (path, _tmpdir) = crate::utils::materialize_input(source)?;
84
85    let profiler = format!(
86        r#"
87import cProfile, pstats, runpy
88# Use a non-__main__ run_name to avoid executing script-side benchmarks guarded by
89# if __name__ == "__main__": blocks (prevents hangs during profiling).
90_iters = {iterations}
91prof = cProfile.Profile()
92prof.enable()
93for _ in range(_iters):
94    runpy.run_path(r"{path}", run_name="__rustify_profile__")
95prof.disable()
96stats = pstats.Stats(prof)
97total = sum(v[3] for v in stats.stats.values()) or 1e-9
98for (fname, line, func), stat in stats.stats.items():
99    ct = stat[3]
100    pct = (ct / total) * 100.0
101    print(f"{{pct:.2f}}% {{func}} {{fname}}:{{line}}")
102"#,
103        path = path.display()
104    );
105
106    let output = Command::new(&python)
107        .args(["-c", &profiler])
108        .output()
109        .with_context(|| {
110            format!(
111                "failed to run {} for profiling; ensure Python is installed",
112                python
113            )
114        })?;
115
116    if !output.status.success() {
117        let stderr = String::from_utf8_lossy(&output.stderr);
118        warn!("python profiling failed: {}", stderr.trim());
119        return Err(anyhow!("python profiling failed: {}", stderr.trim()));
120    }
121
122    let stdout = String::from_utf8_lossy(&output.stdout);
123    let hotspots = parse_hotspots(&stdout, threshold);
124    info!(
125        count = hotspots.len(),
126        threshold, "profiled hotspots collected"
127    );
128
129    Ok(ProfileSummary { hotspots })
130}
131
132fn parse_hotspots(stdout: &str, threshold: f32) -> Vec<Hotspot> {
133    let mut hotspots = Vec::new();
134    for line in stdout.lines() {
135        if let Some((percent_part, rest)) = line.split_once(' ')
136            && let Ok(percent) = percent_part.trim().trim_end_matches('%').parse::<f32>()
137        {
138            if rest.contains("<built-in") || rest.contains("<frozen") {
139                continue;
140            }
141            let mut parts = rest.rsplitn(2, ':');
142            if let (Some(line_part), Some(func_part)) = (parts.next(), parts.next())
143                && let Ok(line_no) = line_part.parse::<u32>()
144            {
145                hotspots.push(Hotspot {
146                    func: func_part.trim().to_string(),
147                    line: line_no,
148                    percent,
149                });
150            }
151        }
152    }
153
154    hotspots.retain(|h| h.percent >= threshold);
155    hotspots.sort_by(|a, b| b.percent.total_cmp(&a.percent));
156    hotspots
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-matching lines and built-ins are dropped, entries below the threshold are
    /// removed, and the survivors come back hottest-first.
    #[test]
    fn test_parse_hotspots_filters_and_sorts() {
        let profiler_output = [
            "42.10% foo /tmp/code.py:10",
            "\tnot-a-match",
            "15.00% <built-in>:0",
            "20.00% bar /tmp/code.py:20",
        ]
        .join("\n");

        let hotspots = parse_hotspots(&profiler_output, 18.0);
        let observed: Vec<(&str, u32)> = hotspots
            .iter()
            .map(|h| (h.func.as_str(), h.line))
            .collect();

        assert_eq!(observed, vec![("foo", 10), ("bar", 20)]);
    }
}