use std::collections::HashSet;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use crate::qs::{format_ident, quote};
use heck::{ToPascalCase, ToSnakeCase};
use proc_macro2::{Ident, TokenStream};
use regex::Regex;
use std::io::Write;
#[cfg(test)]
use toml;
/// Asserts that `$output`, rendered through `$crate::pretty_print`, matches
/// the stored `insta` snapshot for the calling test.
///
/// `$output` is passed by reference to `$crate::pretty_print`; the resulting
/// value is what gets handed to `insta::assert_snapshot!`.
#[macro_export]
macro_rules! assert_tokens_snapshot {
($output:expr) => {{
// Build a dedicated settings object with `prepend_module_to_snapshot`
// switched off and `omit_expression` switched on.
let mut settings = insta::Settings::new();
settings.set_prepend_module_to_snapshot(false);
settings.set_omit_expression(true);
// The assertion runs inside `bind` so that it executes with the settings
// configured above.
settings.bind(|| {
let formatted_output = $crate::pretty_print(&$output);
insta::assert_snapshot!(formatted_output);
});
}};
}
/// Asserts that `$output`, rendered through `$crate::pretty_print`, compiles.
///
/// The rendered source is handed to
/// `$crate::test_helper::try_compilation_test_with_name`, using the last
/// `::`-separated segment of `stdext::function_name!()` as the test name.
///
/// # Panics
///
/// Panics with the returned error message when the compilation helper
/// reports a failure.
#[macro_export]
macro_rules! assert_rust_compilation {
    ($output:expr) => {{
        let rendered_source = $crate::pretty_print(&$output);
        // Only the trailing path segment of the enclosing function's name is
        // used to name the compile-test project.
        let qualified_name = stdext::function_name!();
        let short_name = qualified_name
            .split("::")
            .last()
            .unwrap_or(qualified_name)
            .replace("::", "_");
        match $crate::test_helper::try_compilation_test_with_name(&rendered_source, &short_name) {
            Ok(()) => {}
            Err(e) => panic!("Generated code failed to compile: {e}\n\n"),
        }
    }};
}
/// Compile-checks `generated_code` under the default project name.
///
/// This is a convenience wrapper around [`try_compilation_test_with_name`].
///
/// # Errors
///
/// Returns whatever error message [`try_compilation_test_with_name`] produces.
pub fn try_compilation_test(generated_code: &str) -> Result<(), String> {
    const DEFAULT_TEST_NAME: &str = "wgsl_bindgen_compile_test";
    try_compilation_test_with_name(generated_code, DEFAULT_TEST_NAME)
}
/// Writes `generated_code` into a throwaway Cargo project and checks that it
/// compiles.
///
/// The project is created at
/// `tests/output/compile_test_workspace/<test_name>` (relative to the current
/// working directory), and `test_name` is also used as the project name.
///
/// # Errors
///
/// Returns a message when the directory cannot be created, when the project
/// files cannot be written, when the compilation check cannot be run, or when
/// the generated code does not compile.
pub fn try_compilation_test_with_name(
    generated_code: &str,
    test_name: &str,
) -> Result<(), String> {
    let project_dir = std::path::Path::new("tests")
        .join("output")
        .join("compile_test_workspace")
        .join(test_name);
    std::fs::create_dir_all(&project_dir)
        .map_err(|e| format!("Failed to create temp directory: {e}"))?;

    let project = create_single_file_compile_test(&project_dir, test_name, generated_code)?;

    // `Ok(false)` means the check ran but the code was rejected; `Err` means
    // the check itself could not be carried out.
    match project.test_compilation() {
        Ok(true) => Ok(()),
        Ok(false) => Err(String::from(
            "Generated code failed to compile (see previous output for details)",
        )),
        Err(e) => Err(format!("Compilation test setup failed: {e}")),
    }
}
/// Lays out a minimal Cargo project in `workspace_dir` containing
/// `generated_code` as `src/lib.rs`.
///
/// The `Cargo.toml` is produced by [`generate_single_file_cargo_toml`] from
/// the dependency set returned by
/// [`detect_required_dependencies_from_content`].
///
/// # Errors
///
/// Returns a message when the `src` directory cannot be created or when
/// either file cannot be written.
fn create_single_file_compile_test(
    workspace_dir: &std::path::Path,
    project_name: &str,
    generated_code: &str,
) -> Result<SingleFileCompileTest, String> {
    let source_dir = workspace_dir.join("src");
    std::fs::create_dir_all(&source_dir)
        .map_err(|e| format!("Failed to create src dir: {e}"))?;

    let manifest = generate_single_file_cargo_toml(
        project_name,
        &detect_required_dependencies_from_content(generated_code),
    );
    std::fs::write(workspace_dir.join("Cargo.toml"), manifest)
        .map_err(|e| format!("Failed to write Cargo.toml: {e}"))?;
    std::fs::write(source_dir.join("lib.rs"), generated_code)
        .map_err(|e| format!("Failed to write lib.rs: {e}"))?;

    Ok(SingleFileCompileTest {
        workspace_dir: workspace_dir.to_owned(),
    })
}
/// Handle to a generated single-file Cargo project that can be compile-checked.
struct SingleFileCompileTest {
    /// Directory holding the project's `Cargo.toml`; `cargo` is run from here.
    workspace_dir: PathBuf,
}
impl SingleFileCompileTest {
    /// Runs `cargo check --all-features` inside the project directory.
    ///
    /// The `--target-dir` is set to `<root>/target`, where `<root>` is the
    /// nearest ancestor of the current working directory whose `Cargo.toml`
    /// passes the workspace heuristic below. If no ancestor qualifies, the
    /// current working directory itself is used.
    ///
    /// Returns `Ok(true)` when `cargo` exits successfully and `Ok(false)`
    /// otherwise; in the failing case the captured stdout and stderr are
    /// echoed to stderr.
    ///
    /// # Errors
    ///
    /// Returns an error when the `cargo` process cannot be spawned or its
    /// output cannot be collected.
    pub fn test_compilation(&self) -> Result<bool, Box<dyn std::error::Error>> {
        use std::process::Command;

        // A manifest qualifies when it contains `[workspace]` that is NOT
        // immediately followed by a blank line and another table header.
        // NOTE(review): the exclusion presumably skips manifests that only
        // declare an empty `[workspace]` table - confirm.
        let declares_workspace = |dir: &std::path::Path| -> bool {
            let manifest = dir.join("Cargo.toml");
            manifest.exists()
                && std::fs::read_to_string(&manifest)
                    .map(|contents| {
                        contents.contains("[workspace]")
                            && !contents.contains("[workspace]\n\n[")
                    })
                    .unwrap_or(false)
        };

        let start_dir = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
        let workspace_root = start_dir
            .ancestors()
            .find(|&dir| declares_workspace(dir))
            .map(|dir| dir.to_path_buf())
            .unwrap_or_else(|| start_dir.clone());

        // The path is passed to `arg` as-is (it accepts any `AsRef<OsStr>`),
        // so a non-UTF-8 directory name no longer causes a panic.
        let output = Command::new("cargo")
            .arg("check")
            .arg("--all-features")
            .arg("--target-dir")
            .arg(workspace_root.join("target"))
            .arg("--color=always")
            .current_dir(&self.workspace_dir)
            .env("TERM", "xterm-256color")
            .output()?;

        if output.status.success() {
            println!("✓ Generated file compiles successfully");
            Ok(true)
        } else {
            eprintln!("✗ Compilation failed:");
            eprintln!("stdout: {}", String::from_utf8_lossy(&output.stdout));
            eprintln!("stderr: {}", String::from_utf8_lossy(&output.stderr));
            Ok(false)
        }
    }
}
/// Returns the names of the crates the generated project should depend on.
///
/// `wgpu`, `glam`, `bytemuck` and `encase` are always part of the result,
/// regardless of `content`. `naga_oil` is added only when `content` contains
/// the substring `naga_oil::`.
fn detect_required_dependencies_from_content(
    content: &str,
) -> std::collections::HashSet<String> {
    // These four are included unconditionally, so scanning `content` for them
    // would have no effect on the result.
    const ALWAYS_INCLUDED: [&str; 4] = ["wgpu", "glam", "bytemuck", "encase"];

    let mut deps: std::collections::HashSet<String> = ALWAYS_INCLUDED
        .iter()
        .map(|name| name.to_string())
        .collect();

    if content.contains("naga_oil::") {
        deps.insert("naga_oil".to_string());
    }

    deps
}
/// Builds the `Cargo.toml` text for a single-file compile-test project.
///
/// The manifest declares a package named `project_name`, an empty
/// `[workspace]` table, and one `[dependencies]` entry per recognised name in
/// `dependencies` (`wgpu`, `glam`, `bytemuck`, `encase`, `naga_oil`). A `wgpu`
/// entry is always accompanied by a `naga` entry. Unrecognised names are
/// ignored.
///
/// In test builds, versions are taken from `read_workspace_dependencies()`
/// when available; otherwise (and in non-test builds) the hard-coded fallback
/// versions below are used.
///
/// Dependency entries are written in sorted name order so the generated text
/// is the same for the same inputs on every run.
fn generate_single_file_cargo_toml(
    project_name: &str,
    dependencies: &std::collections::HashSet<String>,
) -> String {
    /// Looks up `name` in `pinned`, falling back to `fallback` when absent.
    fn version_or<'a>(
        pinned: &'a std::collections::HashMap<String, String>,
        name: &str,
        fallback: &'a str,
    ) -> &'a str {
        pinned.get(name).map(String::as_str).unwrap_or(fallback)
    }

    let mut cargo_toml = format!(
        r#"[package]
name = "{project_name}"
version = "0.1.0"
edition = "2021"
# Empty workspace to avoid conflicts with parent workspace
[workspace]
[dependencies]
"#
    );

    #[cfg(test)]
    let workspace_deps = read_workspace_dependencies().unwrap_or_default();
    #[cfg(not(test))]
    let workspace_deps: std::collections::HashMap<String, String> =
        std::collections::HashMap::new();

    // `HashSet` iteration order is unspecified; sort so the emitted manifest
    // does not change from run to run.
    let mut ordered: Vec<&str> = dependencies.iter().map(String::as_str).collect();
    ordered.sort_unstable();

    for dep in ordered {
        let entry = match dep {
            "wgpu" => {
                let wgpu_version = version_or(&workspace_deps, "wgpu", "29.0");
                let naga_version = version_or(&workspace_deps, "naga", "29.0");
                format!("wgpu = {{ version = \"{wgpu_version}\", features = [\"wgsl\", \"naga-ir\"] }}\nnaga = {{ version = \"{naga_version}\", features = [\"wgsl-out\"] }}\n")
            }
            "glam" => {
                let version = version_or(&workspace_deps, "glam", "0.30");
                format!("glam = \"{version}\"\n")
            }
            "bytemuck" => {
                let version = version_or(&workspace_deps, "bytemuck", "1.13");
                format!("bytemuck = {{ version = \"{version}\", features = [\"derive\"] }}\n")
            }
            "encase" => {
                let version = version_or(&workspace_deps, "encase", "0.11");
                format!("encase = \"{version}\"\n")
            }
            "naga_oil" => {
                let version = version_or(&workspace_deps, "naga_oil", "0.22");
                format!("naga_oil = \"{version}\"\n")
            }
            // Names without a known manifest entry contribute nothing.
            _ => continue,
        };
        cargo_toml.push_str(&entry);
    }

    cargo_toml
}
/// Reads the `[workspace.dependencies]` table of the workspace manifest and
/// returns a map from dependency name to version string.
///
/// The manifest is the `Cargo.toml` in the directory returned by
/// [`find_workspace_root`]. An entry is included when its value is a plain
/// string, or a table with a string-valued `version` key; every other entry
/// is skipped. A manifest without a `workspace.dependencies` table yields an
/// empty map.
///
/// # Errors
///
/// Returns an error when the workspace root cannot be located, the manifest
/// cannot be read, or it fails to parse as TOML.
#[cfg(test)]
fn read_workspace_dependencies(
) -> Result<std::collections::HashMap<String, String>, Box<dyn std::error::Error>> {
    let manifest_path = find_workspace_root()?.join("Cargo.toml");
    let manifest: toml::Value = std::fs::read_to_string(manifest_path)?.parse()?;

    let versions: std::collections::HashMap<String, String> = manifest
        .get("workspace")
        .and_then(|workspace| workspace.get("dependencies"))
        .and_then(|dependencies| dependencies.as_table())
        .map(|table| {
            table
                .iter()
                .filter_map(|(name, spec)| {
                    let version = match spec {
                        toml::Value::String(text) => text.as_str(),
                        toml::Value::Table(fields) => fields.get("version")?.as_str()?,
                        _ => return None,
                    };
                    Some((name.clone(), version.to_string()))
                })
                .collect()
        })
        .unwrap_or_default();

    Ok(versions)
}
/// Walks upward from the current working directory and returns the first
/// directory whose `Cargo.toml` contains the text `[workspace]`.
///
/// # Errors
///
/// Returns an error when the current directory cannot be determined, when an
/// existing `Cargo.toml` along the way cannot be read, or when no directory
/// up to the filesystem root qualifies.
#[cfg(test)]
fn find_workspace_root() -> Result<PathBuf, Box<dyn std::error::Error>> {
    let start = std::env::current_dir()?;
    for dir in start.ancestors() {
        let manifest = dir.join("Cargo.toml");
        if manifest.exists() && std::fs::read_to_string(&manifest)?.contains("[workspace]") {
            return Ok(dir.to_path_buf());
        }
    }
    Err("Could not find workspace root".into())
}