use anyhow::Result;
use cgp::{analysis, doctor, profilers};
use clap::{Parser, Subcommand};
// Top-level command-line interface for `cgp`, parsed with clap's derive API.
// NOTE: plain `//` comments are used on the clap-derived types in this file on purpose:
// clap turns `///` doc comments into `--help` text, so adding them would change CLI output.
#[derive(Parser)]
#[command(name = "cgp", version, about, long_about = None)]
struct Cli {
// `--json`: declared `global = true`, so clap also accepts it after any subcommand.
// `main` forwards it only to the handlers whose signatures take a `json` argument.
#[arg(long, global = true)]
json: bool,
// The subcommand selected on the command line; dispatched in `main`.
#[command(subcommand)]
command: Commands,
}
// Top-level `cgp` subcommands, each matched and dispatched in `main`.
// clap derives kebab-case long flags from field names (e.g. `check_regression` becomes
// `--check-regression`); fields without `#[arg(long)]` are positional arguments.
#[derive(Subcommand)]
enum Commands {
// `cgp profile <target>`: handed to `dispatch_profile` together with the global `--json`.
Profile {
#[command(subcommand)]
target: ProfileTarget,
},
// `cgp bench --bench <name>`: forwarded to `analysis::bench::run_bench`.
// Unlike most subcommands, it does not receive the global `--json` flag.
Bench {
#[arg(long)]
bench: String,
#[arg(long)]
counters: Option<String>,
#[arg(long)]
check_regression: bool,
// Defaults to 5. Its unit is decided by `run_bench`.
// NOTE(review): presumably a percentage; confirm.
#[arg(long, default_value = "5")]
threshold: f64,
#[arg(long)]
roofline: bool,
},
// `cgp roofline --target <t>`: forwarded to `analysis::roofline::run_roofline`.
Roofline {
#[arg(long)]
target: String,
#[arg(long)]
kernels: Option<String>,
#[arg(long)]
export: Option<String>,
#[arg(long)]
empirical: bool,
},
// `cgp diff`: every input is optional at parse time. Which combinations
// (baseline/current vs. before/after) are valid is decided by `analysis::diff::run_diff`.
Diff {
#[arg(long)]
baseline: Option<String>,
#[arg(long)]
current: Option<String>,
#[arg(long)]
before: Option<String>,
#[arg(long)]
after: Option<String>,
},
// `cgp contract <action>`: handed to `dispatch_contract`.
Contract {
#[command(subcommand)]
action: ContractAction,
},
// `cgp trace <binary>`: forwarded to `profilers::cuda::run_trace`.
Trace {
binary: String,
// Passed through as a raw string; its format is interpreted by the callee.
#[arg(long)]
duration: Option<String>,
},
// `cgp explain <target> [--kernel <k>]`: forwarded to `analysis::explain::run_explain`.
Explain {
target: String,
#[arg(long)]
kernel: Option<String>,
},
// `cgp tui`: placeholder. `main` only prints a not-implemented notice and succeeds.
Tui,
// `cgp baseline`: `--save` and `--load` are independent options at parse time;
// both are passed to `analysis::baseline::run_baseline`.
Baseline {
#[arg(long)]
save: Option<String>,
#[arg(long)]
load: Option<String>,
},
// `cgp doctor`: forwarded to `doctor::run_doctor` with the global `--json` flag.
Doctor,
// `cgp compete <workload> --ours <x> [--theirs <y>]...`: forwarded to
// `analysis::compete::run_compete`.
Compete {
workload: String,
#[arg(long)]
ours: String,
// `Vec<String>` with `long`: clap accepts `--theirs` repeatedly and collects each
// value; omitting it yields an empty vec.
#[arg(long)]
theirs: Vec<String>,
#[arg(long)]
label: Option<String>,
},
}
// Targets for `cgp profile <target>`, each routed by `dispatch_profile`.
// Field names map to kebab-case long flags; fields without `#[arg(long)]` are positional.
#[derive(Subcommand)]
enum ProfileTarget {
// `cgp profile kernel --name <n> --size <s>`: forwarded to `profilers::cuda::profile_kernel`.
Kernel {
#[arg(long)]
name: String,
#[arg(long)]
size: u32,
#[arg(long)]
roofline: bool,
#[arg(long)]
metrics: Option<String>,
},
// `cgp profile cublas --op <op> --size <s>`: forwarded to `profilers::cuda::profile_cublas`.
Cublas {
#[arg(long)]
op: String,
#[arg(long)]
size: u32,
},
// `cgp profile wgpu --shader <path>`: forwarded to `profilers::wgpu_profiler::profile_wgpu`.
Wgpu {
#[arg(long)]
shader: String,
#[arg(long)]
dispatch: Option<String>,
#[arg(long)]
target: Option<String>,
},
// `cgp profile metal`: only usable on macOS builds (see the `cfg` gates in
// `dispatch_profile`). Other platforms return an error.
Metal {
#[arg(long)]
shader: String,
#[arg(long)]
dispatch: Option<u32>,
},
// `cgp profile simd`: forwarded to `profilers::simd::profile_simd`; `--arch` is required.
Simd {
#[arg(long)]
function: String,
#[arg(long)]
size: u32,
#[arg(long)]
arch: String,
},
// `cgp profile wasm`: forwarded to `profilers::wasm::profile_wasm`.
Wasm {
#[arg(long)]
function: String,
#[arg(long)]
size: u32,
},
// `cgp profile quant`: either `--all`, or both `--kernel` and `--size`. clap enforces this
// through `required_unless_present = "all"`.
Quant {
#[arg(long, required_unless_present = "all")]
kernel: Option<String>,
// Kept as a string; `dispatch_profile`'s fallback uses the form "4096x1x4096".
#[arg(long, required_unless_present = "all")]
size: Option<String>,
#[arg(long)]
all: bool,
},
// `cgp profile scalar`: forwarded to `profilers::scalar::profile_scalar`.
Scalar {
#[arg(long)]
function: String,
#[arg(long)]
size: u32,
},
// `cgp profile parallel`: forwarded to `profilers::rayon_parallel::profile_parallel`.
Parallel {
#[arg(long)]
function: String,
#[arg(long)]
size: u32,
// Passed through as a raw string; its format is parsed by the callee.
#[arg(long)]
threads: Option<String>,
},
// `cgp profile compare`: forwarded to `analysis::compare::run_compare` with `--json`.
Compare {
#[arg(long)]
kernel: String,
#[arg(long)]
size: u32,
#[arg(long)]
backends: String,
},
// `cgp profile scaling`: forwarded to `profilers::rayon_parallel::profile_scaling` with `--json`.
Scaling {
#[arg(long)]
size: u32,
#[arg(long)]
max_threads: Option<usize>,
// Defaults to 3.
#[arg(long, default_value = "3")]
runs: usize,
},
// `cgp profile binary <path>`: forwarded to `profilers::cuda::profile_binary`.
Binary {
path: String,
#[arg(long)]
kernel_filter: Option<String>,
#[arg(long)]
trace: bool,
#[arg(long)]
duration: Option<String>,
},
// `cgp profile python ...`: every remaining argument, including ones that start with `-`,
// is captured verbatim and passed to `profilers::cuda::profile_python`.
Python {
#[arg(trailing_var_arg = true, allow_hyphen_values = true)]
args: Vec<String>,
},
// `cgp profile library --so <path> --symbol <name>`: currently a stub that only echoes
// the request (see `dispatch_profile`).
Library {
#[arg(long)]
so: String,
#[arg(long)]
symbol: String,
#[arg(long)]
args: Option<String>,
},
}
// Actions for `cgp contract <action>`, routed by `dispatch_contract`.
#[derive(Subcommand)]
enum ContractAction {
// `cgp contract verify`: forwarded to `analysis::contracts::run_verify`.
Verify {
#[arg(long)]
contracts_dir: Option<String>,
#[arg(long)]
contract: Option<String>,
#[arg(long)]
fail_on_regression: bool,
// `name = "self"` sets this argument's clap id to "self". The field itself is named
// `self_verify` because `self` is a Rust keyword.
// NOTE(review): confirm whether the resulting long flag is `--self` or `--self-verify`.
#[arg(long, name = "self")]
self_verify: bool,
},
// `cgp contract generate --kernel <k> --size <s>`: forwarded to
// `analysis::contracts::run_generate`.
Generate {
#[arg(long)]
kernel: String,
#[arg(long)]
size: u32,
// Defaults to 10. Its unit is decided by `run_generate`
// (NOTE(review): presumably a percentage; confirm).
#[arg(long, default_value = "10")]
tolerance: f64,
},
}
fn main() -> Result<()> {
let cli = Cli::parse();
let json = cli.json;
match cli.command {
Commands::Doctor => doctor::run_doctor(json),
Commands::Profile { target } => dispatch_profile(target, json),
Commands::Roofline {
target,
kernels,
export,
empirical,
} => analysis::roofline::run_roofline(
&target,
kernels.as_deref(),
export.as_deref(),
empirical,
json,
),
Commands::Bench {
bench,
counters,
check_regression,
threshold,
roofline,
} => analysis::bench::run_bench(
&bench,
counters.as_deref(),
check_regression,
threshold,
roofline,
),
Commands::Diff {
baseline,
current,
before,
after,
} => analysis::diff::run_diff(
baseline.as_deref(),
current.as_deref(),
before.as_deref(),
after.as_deref(),
json,
),
Commands::Contract { action } => dispatch_contract(action),
Commands::Trace { binary, duration } => {
profilers::cuda::run_trace(&binary, duration.as_deref())
}
Commands::Explain { target, kernel } => {
analysis::explain::run_explain(&target, kernel.as_deref())
}
Commands::Tui => {
println!("cgp tui: interactive mode (requires presentar)");
println!(" (Not yet implemented — use stdout commands for now)");
Ok(())
}
Commands::Baseline { save, load } => {
analysis::baseline::run_baseline(save.as_deref(), load.as_deref())
}
Commands::Compete {
workload,
ours,
theirs,
label,
} => analysis::compete::run_compete(&workload, &ours, &theirs, label.as_deref(), json),
}
}
/// Route a `cgp profile <target>` invocation to the matching profiler or analysis routine.
///
/// `json` is forwarded only to the handlers that accept it (`compare` and `scaling`).
///
/// # Errors
///
/// Propagates any error returned by the selected handler. On non-macOS builds the
/// `metal` target always fails and points the user at the `wgpu` target instead.
fn dispatch_profile(target: ProfileTarget, json: bool) -> Result<()> {
    match target {
        ProfileTarget::Kernel { name, size, roofline, metrics } => {
            profilers::cuda::profile_kernel(&name, size, roofline, metrics.as_deref())
        }
        ProfileTarget::Cublas { op, size } => profilers::cuda::profile_cublas(&op, size),
        ProfileTarget::Wgpu { shader, dispatch, target } => {
            profilers::wgpu_profiler::profile_wgpu(&shader, dispatch.as_deref(), target.as_deref())
        }
        ProfileTarget::Metal { shader, dispatch } => {
            #[cfg(target_os = "macos")]
            {
                // Currently only echoes the request; no Metal profiling is performed here.
                println!("cgp profile metal: shader={shader} dispatch={dispatch:?}");
                Ok(())
            }
            #[cfg(not(target_os = "macos"))]
            {
                // Consume the bindings so non-macOS builds don't warn about unused variables.
                let _ = (&shader, dispatch);
                // This CLI has no `--backend` flag; the Vulkan path is the `wgpu` profile target.
                anyhow::bail!("Metal backend requires macOS -- use `cgp profile wgpu` for Vulkan")
            }
        }
        ProfileTarget::Simd { function, size, arch } => {
            profilers::simd::profile_simd(&function, size, &arch)
        }
        ProfileTarget::Wasm { function, size } => profilers::wasm::profile_wasm(&function, size),
        ProfileTarget::Quant { kernel, size, all } => {
            if all {
                profilers::quant::profile_quant_all()
            } else {
                // clap requires `--kernel` and `--size` unless `--all` is given
                // (`required_unless_present`), so these fallbacks are purely defensive.
                profilers::quant::profile_quant(
                    kernel.as_deref().unwrap_or("q4k_gemv"),
                    size.as_deref().unwrap_or("4096x1x4096"),
                )
            }
        }
        ProfileTarget::Scalar { function, size } => {
            profilers::scalar::profile_scalar(&function, size)
        }
        ProfileTarget::Parallel { function, size, threads } => {
            profilers::rayon_parallel::profile_parallel(&function, size, threads.as_deref())
        }
        ProfileTarget::Compare { kernel, size, backends } => {
            analysis::compare::run_compare(&kernel, size, &backends, json)
        }
        ProfileTarget::Scaling { size, max_threads, runs } => {
            profilers::rayon_parallel::profile_scaling(size, max_threads, runs, json)
        }
        ProfileTarget::Binary { path, kernel_filter, trace, duration } => {
            profilers::cuda::profile_binary(
                &path,
                kernel_filter.as_deref(),
                trace,
                duration.as_deref(),
            )
        }
        ProfileTarget::Python { args } => profilers::cuda::profile_python(&args),
        ProfileTarget::Library { so, symbol, args } => {
            // Stub: echoes the request without loading the shared object.
            println!("cgp profile library: {so}::{symbol} args={args:?}");
            Ok(())
        }
    }
}
/// Route a `cgp contract <action>` invocation to the contracts analysis module.
///
/// # Errors
///
/// Propagates whatever the selected contracts routine returns.
fn dispatch_contract(action: ContractAction) -> Result<()> {
    match action {
        ContractAction::Generate { kernel, size, tolerance } => {
            analysis::contracts::run_generate(&kernel, size, tolerance)
        }
        ContractAction::Verify { contracts_dir, contract, fail_on_regression, self_verify } => {
            let dir = contracts_dir.as_deref();
            let only = contract.as_deref();
            // NOTE(review): `self_verify` is passed before `fail_on_regression`, the reverse of
            // the field declaration order. Confirm this matches `run_verify`'s parameter order.
            analysis::contracts::run_verify(dir, only, self_verify, fail_on_regression)
        }
    }
}