use anyhow::Result;
use std::path::Path;
/// Returns the name of the native graphics API wgpu is expected to use on the
/// platform this binary was compiled for.
///
/// The choice is made purely from the compile-time `target_os`:
/// Linux → `"Vulkan"`, macOS → `"Metal"`, Windows → `"DX12"`, anything else →
/// `"Unknown"`. No GPU or driver is queried at runtime.
pub fn detect_wgpu_backend() -> &'static str {
    if cfg!(target_os = "linux") {
        "Vulkan"
    } else if cfg!(target_os = "macos") {
        "Metal"
    } else if cfg!(target_os = "windows") {
        "DX12"
    } else {
        "Unknown"
    }
}
/// Reads a WGSL shader from disk and extracts lightweight metadata.
///
/// This is a line-oriented text scan, not a WGSL parse. `//` line comments are
/// ignored, so commented-out attributes are not reported. `/* ... */` block
/// comments are not recognised.
///
/// # Errors
///
/// Returns an error if `shader_path` does not exist or cannot be read as UTF-8 text.
pub fn validate_shader(shader_path: &str) -> Result<ShaderInfo> {
    let path = Path::new(shader_path);
    if !path.exists() {
        anyhow::bail!("Shader file not found: {shader_path}");
    }
    let content = std::fs::read_to_string(path)?;
    let lines = content.lines().count();
    // Scan only the code portion of each line so that e.g.
    // `// @workgroup_size(1)` left over from an edit is not picked up.
    let code_lines = || content.lines().map(strip_line_comment);
    let workgroup_size = code_lines().find_map(extract_workgroup_size);
    let has_compute = code_lines().any(|line| line.contains("@compute"));
    Ok(ShaderInfo {
        path: shader_path.to_string(),
        lines,
        workgroup_size,
        has_compute,
    })
}

/// Returns the part of `line` before any `//` line comment.
fn strip_line_comment(line: &str) -> &str {
    line.split_once("//").map_or(line, |(code, _)| code)
}

/// Extracts the trimmed argument text of an `@workgroup_size(...)` attribute on `line`.
///
/// The opening parenthesis is searched for only *after* the attribute name, and
/// nesting depth is tracked so that const-expressions such as
/// `@workgroup_size((4 * 2), 1)` are returned whole. Returns `None` when the
/// attribute is absent or its argument list is not closed on this line.
fn extract_workgroup_size(line: &str) -> Option<String> {
    let (_, after_attr) = line.split_once("@workgroup_size")?;
    // WGSL allows blankspace between the attribute name and `(`.
    let args = after_attr.trim_start().strip_prefix('(')?;
    let mut depth = 0usize;
    for (idx, ch) in args.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' if depth == 0 => return Some(args[..idx].trim().to_string()),
            ')' => depth -= 1,
            _ => {}
        }
    }
    None
}
/// Lightweight metadata gathered from a WGSL shader file by [`validate_shader`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShaderInfo {
    /// Path the shader was read from, exactly as supplied by the caller.
    pub path: String,
    /// Number of lines in the file, as counted by [`str::lines`].
    pub lines: usize,
    /// Argument text of the first `@workgroup_size(...)` attribute found, if any.
    pub workgroup_size: Option<String>,
    /// Whether an `@compute` attribute was found in the shader source.
    pub has_compute: bool,
}
pub fn parse_dispatch(dispatch: &str) -> Result<[u32; 3]> {
let parts: Vec<&str> = dispatch.split(',').collect();
if parts.len() != 3 {
anyhow::bail!("Dispatch must be X,Y,Z format (got: {dispatch})");
}
Ok([
parts[0].trim().parse()?,
parts[1].trim().parse()?,
parts[2].trim().parse()?,
])
}
/// Prints a profiling summary for a wgpu compute shader to stdout.
///
/// `target` defaults to `"native"`. When `dispatch` is given it must be an
/// `"X,Y,Z"` workgroup count (see [`parse_dispatch`]). If `shader` names an
/// existing file, metadata from [`validate_shader`] is printed. Validation
/// problems are reported as warnings rather than returned as errors. For the
/// `"web"` target, Chrome/Chromium executables are looked up with `which`.
///
/// # Errors
///
/// Returns an error only if `dispatch` is present but malformed.
pub fn profile_wgpu(shader: &str, dispatch: Option<&str>, target: Option<&str>) -> Result<()> {
    let target_str = target.unwrap_or("native");
    let backend = detect_wgpu_backend();
    println!("\n=== CGP wgpu Profile: {shader} (target={target_str}) ===\n");
    println!(" Shader: {shader}");
    println!(" Backend: wgpu ({backend})");
    if let Some(spec) = dispatch {
        let dims = parse_dispatch(spec)?;
        println!(" Dispatch: {}x{}x{}", dims[0], dims[1], dims[2]);
        // Dispatch dimensions are workgroup counts, so their product is the total
        // number of workgroups. Widen to u128: three u32 factors can reach ~2^96,
        // which overflows u64 (and panics in debug builds).
        let total_workgroups: u128 = dims.iter().copied().map(u128::from).product();
        println!(" Total workgroups: {total_workgroups}");
    }
    println!(" Method: TIMESTAMP_QUERY for GPU-side timing (~1ns resolution)");
    if Path::new(shader).exists() {
        match validate_shader(shader) {
            Ok(info) => {
                println!(" Shader lines: {}", info.lines);
                if let Some(ws) = &info.workgroup_size {
                    println!(" Workgroup size: {ws}");
                }
                if !info.has_compute {
                    println!(" \x1b[33m[WARN]\x1b[0m No @compute entry point found in shader");
                }
            }
            Err(e) => println!(" \x1b[33m[WARN]\x1b[0m Shader validation: {e}"),
        }
    }
    if target_str == "web" {
        // `any` short-circuits in the same order as the original `||` chain.
        let has_chrome = ["google-chrome", "chromium", "chromium-browser"]
            .iter()
            .any(|&bin| which::which(bin).is_ok());
        if has_chrome {
            println!(" Browser: headless Chrome (Chrome DevTools Protocol)");
        } else {
            // Report the backend actually detected rather than a hardcoded list
            // that is wrong on Windows (DX12).
            println!(" No browser found -- falling back to wgpu native ({backend})");
        }
    }
    println!();
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `contents` to a per-test temp file and returns its path.
    fn write_temp_shader(tag: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir()
            .join(format!("cgp_wgpu_test_{tag}_{}.wgsl", std::process::id()));
        std::fs::write(&path, contents).expect("temp dir should be writable");
        path
    }

    #[test]
    fn test_wgpu_profile_runs() {
        let result = profile_wgpu("test.wgsl", Some("256,256,1"), Some("native"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_wgpu_profile_bad_dispatch_is_error() {
        assert!(profile_wgpu("test.wgsl", Some("256,256"), None).is_err());
        assert!(profile_wgpu("test.wgsl", Some("x,1,1"), None).is_err());
    }

    #[test]
    fn test_detect_backend() {
        let backend = detect_wgpu_backend();
        assert!(!backend.is_empty());
        #[cfg(target_os = "linux")]
        assert_eq!(backend, "Vulkan");
        #[cfg(target_os = "macos")]
        assert_eq!(backend, "Metal");
        #[cfg(target_os = "windows")]
        assert_eq!(backend, "DX12");
    }

    #[test]
    fn test_parse_dispatch() {
        let dims = parse_dispatch("256,256,1").unwrap();
        assert_eq!(dims, [256, 256, 1]);
    }

    #[test]
    fn test_parse_dispatch_whitespace() {
        assert_eq!(parse_dispatch(" 8 , 4 ,2 ").unwrap(), [8, 4, 2]);
    }

    #[test]
    fn test_parse_dispatch_bad() {
        assert!(parse_dispatch("256,256").is_err());
        assert!(parse_dispatch("1,2,3,4").is_err());
        assert!(parse_dispatch("a,b,c").is_err());
        assert!(parse_dispatch("1,2,-3").is_err());
        assert!(parse_dispatch("").is_err());
    }

    #[test]
    fn test_validate_shader_missing() {
        assert!(validate_shader("/tmp/nonexistent_shader.wgsl").is_err());
    }

    #[test]
    fn test_validate_shader_extracts_metadata() {
        let path = write_temp_shader(
            "compute",
            "@compute @workgroup_size(8, 8, 1)\nfn main() {}\n",
        );
        let result = validate_shader(path.to_str().expect("temp path is UTF-8"));
        let _ = std::fs::remove_file(&path);
        let info = result.unwrap();
        assert_eq!(info.lines, 2);
        assert_eq!(info.workgroup_size.as_deref(), Some("8, 8, 1"));
        assert!(info.has_compute);
    }

    #[test]
    fn test_validate_shader_without_compute() {
        let path = write_temp_shader("plain", "fn helper() -> u32 { return 1u; }\n");
        let result = validate_shader(path.to_str().expect("temp path is UTF-8"));
        let _ = std::fs::remove_file(&path);
        let info = result.unwrap();
        assert_eq!(info.lines, 1);
        assert_eq!(info.workgroup_size, None);
        assert!(!info.has_compute);
    }

    #[test]
    fn test_web_target_graceful() {
        let result = profile_wgpu("test.wgsl", None, Some("web"));
        assert!(result.is_ok());
    }
}