Skip to main content

rustify_ml/
utils.rs

1use std::fs;
2use std::path::PathBuf;
3
4use anyhow::{Context, Result};
5use tempfile::TempDir;
6
/// Where the Python code to accelerate comes from.
///
/// Every variant carries the code itself in memory (returned by [`extract_code`]);
/// [`materialize_input`] turns any variant into a real file on disk.
#[derive(Debug, Clone)]
pub enum InputSource {
    /// A local file on disk.
    File {
        /// Location of the file. `materialize_input` copies from this path
        /// (not from `code`) and reuses its file name.
        path: PathBuf,
        /// Contents of the file (presumably read from `path` when the value was
        /// built — confirm at the construction site).
        code: String,
    },
    /// Inline code with no backing file; materialized as `input.py`.
    Snippet(String),
    /// A file taken from a Git repository.
    Git {
        // NOTE(review): format of `repo` (URL vs. owner/name, etc.) is not visible here — confirm.
        repo: String,
        /// Path of the file within the repository; only its final component
        /// (file name) is used when materializing.
        path: PathBuf,
        /// Contents of the file; this is what gets written when materializing.
        code: String,
    },
}
20
/// A function reported as consuming a notable share of profiled runtime.
#[derive(Debug, Clone)]
pub struct Hotspot {
    /// Name of the hot function.
    pub func: String,
    /// Source line associated with the function, shown in the `Line` column
    /// (presumably the function's definition line — confirm against the profiler).
    pub line: u32,
    /// Share of profiled time, in percent: [`print_hotspot_table`] prints it
    /// under a `% Time` header with a `%` suffix (presumably in the 0–100 range).
    pub percent: f32,
}
27
/// Result of a profiling run: the hotspots that were found.
///
/// `Default` yields an empty summary (no hotspots).
#[derive(Debug, Clone, Default)]
pub struct ProfileSummary {
    /// Hot functions reported by the profiler.
    // NOTE(review): ordering is decided by whoever fills this in; `print_hotspot_table`
    // labels its input "ranked by CPU time" — confirm the producer sorts.
    pub hotspots: Vec<Hotspot>,
}
32
/// A function selected as a translation target.
///
/// Carries the same `func`/`line`/`percent` data as [`Hotspot`], plus the
/// reason it was selected.
#[derive(Debug, Clone)]
pub struct TargetSpec {
    /// Name of the targeted function.
    pub func: String,
    /// Source line associated with the function.
    pub line: u32,
    /// Share of profiled time in percent (presumably copied from the matching
    /// [`Hotspot`] — confirm with the selector).
    pub percent: f32,
    /// Free-form explanation of why this function was chosen.
    pub reason: String,
}
40
/// Outcome of generating the Rust extension crate.
#[derive(Debug, Clone)]
pub struct GenerationResult {
    /// Directory containing the generated crate.
    pub crate_dir: PathBuf,
    /// Names of the functions emitted into the crate (presumably one entry per
    /// target — confirm with the generator).
    pub generated_functions: Vec<String>,
    /// Number of functions that fell back rather than being fully translated.
    // NOTE(review): exact fallback semantics live in the generator — confirm.
    pub fallback_functions: usize,
}
47
48/// Materialize the input source into a concrete file path for profiling/build steps.
49/// Returns the path and a TempDir to keep the file alive for the caller's scope.
50pub fn materialize_input(source: &InputSource) -> Result<(PathBuf, TempDir)> {
51    let tmpdir = tempfile::tempdir().context("failed to create temp dir for input")?;
52    let filename = match source {
53        InputSource::File { path, .. } => path
54            .file_name()
55            .map(PathBuf::from)
56            .unwrap_or_else(|| PathBuf::from("input.py")),
57        InputSource::Snippet(_) => PathBuf::from("input.py"),
58        InputSource::Git { path, .. } => path
59            .file_name()
60            .map(PathBuf::from)
61            .unwrap_or_else(|| PathBuf::from("input.py")),
62    };
63    let path = tmpdir.path().join(filename);
64
65    match source {
66        InputSource::File { path: src, .. } => {
67            fs::copy(src, &path)
68                .with_context(|| format!("failed to copy input from {}", src.display()))?;
69        }
70        InputSource::Snippet(code) => {
71            fs::write(&path, code).context("failed to write snippet to temp file")?;
72        }
73        InputSource::Git { code, .. } => {
74            fs::write(&path, code).context("failed to write git file to temp file")?;
75        }
76    }
77
78    Ok((path, tmpdir))
79}
80
/// One row in the post-generation summary table printed to stdout.
///
/// Rendered by [`print_summary`], which counts a row as a fallback when its
/// `translation` is `"Partial"`.
#[derive(Debug, Clone)]
pub struct AccelerateRow {
    /// Function name, shown in the `Func` column.
    pub func: String,
    /// Source line, shown in the `Line` column.
    pub line: u32,
    /// Share of profiled time, printed with one decimal place and a `%` suffix.
    pub pct_time: f32,
    pub translation: &'static str, // "Full" | "Partial"
    pub status: String,            // "Success" | "Fallback: <reason>"
}
90
91/// Print a simple ASCII summary table to stdout.
92pub fn print_summary(rows: &[AccelerateRow], crate_dir: &std::path::Path) {
93    let total = rows.len();
94    let fallbacks = rows.iter().filter(|r| r.translation == "Partial").count();
95    println!();
96    println!(
97        "Accelerated {}/{} targets ({} fallback{})",
98        total - fallbacks,
99        total,
100        fallbacks,
101        if fallbacks == 1 { "" } else { "s" }
102    );
103    println!();
104    println!(
105        "{:<22} | {:>4} | {:>6} | {:<11} | Status",
106        "Func", "Line", "% Time", "Translation"
107    );
108    println!("{}", "-".repeat(22 + 3 + 4 + 3 + 6 + 3 + 11 + 3 + 20));
109    for row in rows {
110        println!(
111            "{:<22} | {:>4} | {:>5.1}% | {:<11} | {}",
112            row.func, row.line, row.pct_time, row.translation, row.status
113        );
114    }
115    println!();
116    println!("Generated: {}", crate_dir.display());
117    println!(
118        "Install:   cd {} && maturin develop --release",
119        crate_dir.display()
120    );
121    println!();
122}
123
124/// Print a hotspot table to stdout (used by --list-targets).
125pub fn print_hotspot_table(hotspots: &[Hotspot]) {
126    println!();
127    if hotspots.is_empty() {
128        println!("No hotspots found above threshold.");
129        println!();
130        return;
131    }
132    println!("Hotspots (ranked by CPU time):");
133    println!();
134    println!("{:<30} | {:>4} | {:>7}", "Function", "Line", "% Time");
135    println!("{}", "-".repeat(30 + 3 + 4 + 3 + 7));
136    for h in hotspots {
137        println!("{:<30} | {:>4} | {:>6.2}%", h.func, h.line, h.percent);
138    }
139    println!();
140    println!("Run with --threshold <N> to filter; use --function <name> to target one directly.");
141    println!();
142}
143
144pub fn extract_code(source: &InputSource) -> Result<String> {
145    match source {
146        InputSource::File { code, .. } => Ok(code.clone()),
147        InputSource::Snippet(code) => Ok(code.clone()),
148        InputSource::Git { code, .. } => Ok(code.clone()),
149    }
150}