use std::collections::HashSet;
use std::path::Path;
use std::{env, fs};
/// Name prefix that marks a function under `src/tests/` as an integration test case.
const TEST_PREFIX: &str = "test_";
/// Banner comment written at the top of the generated `integration_tests.rs`.
const INTEGRATION_TESTS_HEADER: &str = r#"// Auto-generated integration tests
//
// This module is automatically generated by the build script from test functions
// named test_*. Do not edit manually.
"#;
/// Fixed `use` lines emitted into `integration_tests.rs`, ahead of the per-module glob
/// imports that `generate_integration_tests` appends.
const INTEGRATION_TESTS_IMPORTS: &str = r#"use anyhow::{anyhow, Result};
use miden_client_integration_tests::tests::config::ClientConfig;
use miden_client::rpc::Endpoint;
use url::Url;"#;
/// Template for one generated `#[tokio::test]` wrapper around a `test_*` function.
///
/// Filled with `str::replace`, not `format!`. `{ORIGINAL_FUNCTION_NAME}` is replaced with the
/// test function and `{TEST_FUNCTION_NAME}` with the wrapper's name. Because no format machinery
/// runs, the `{{`/`}}` pairs are emitted literally. They give the wrapper a redundant inner block,
/// which is still valid Rust.
///
/// The generated wrapper reads the RPC endpoint from `TEST_MIDEN_RPC_ENDPOINT` (default:
/// `Endpoint::localhost()`) and the timeout from `TEST_TIMEOUT` (default `10000`). It then calls
/// the test with the resulting `ClientConfig`.
///
/// The port comes from `Url::port_or_known_default`. `Url::port` yields `None` whenever the URL
/// uses its scheme's default port (e.g. a plain `https://host` endpoint), which would wrongly
/// reject such endpoints as "missing a port".
const TOKIO_TEST_WRAPPER: &str = r#"/// Auto-generated tokio test wrapper for {ORIGINAL_FUNCTION_NAME}
#[tokio::test]
async fn {TEST_FUNCTION_NAME}() -> Result<()> {{
// Use default test configuration
let endpoint_url = std::env::var("TEST_MIDEN_RPC_ENDPOINT")
.unwrap_or_else(|_| Endpoint::localhost().to_string());
let url = Url::parse(&endpoint_url).map_err(|_| anyhow!("Invalid RPC endpoint URL"))?;
let host = url
.host_str()
.ok_or_else(|| anyhow!("RPC endpoint URL is missing a host"))?
.to_string();
let port = url
.port_or_known_default()
.ok_or_else(|| anyhow!("RPC endpoint URL is missing a port"))?;
let endpoint = Endpoint::new(url.scheme().to_string(), host, Some(port));
let timeout = std::env::var("TEST_TIMEOUT")
.unwrap_or_else(|_| "10000".to_string())
.parse::<u64>()
.map_err(|_| anyhow!("Invalid timeout value"))?;
let client_config = ClientConfig::new(endpoint, timeout);
{ORIGINAL_FUNCTION_NAME}(client_config).await
}}"#;
/// Banner comment written at the top of the generated `test_registry.rs`.
const TEST_REGISTRY_HEADER: &str = r#"// Auto-generated test cases module
//
// This module is automatically generated by the build script from test functions
// named test_*. Do not edit manually.
"#;
/// `use` line for the generated registry. It expects `TestCase` and `TestCategory` to be
/// reachable as `super::…` from wherever the generated file is included.
const TEST_REGISTRY_IMPORTS: &str = r#"use super::{TestCase, TestCategory};"#;
/// Template for `get_all_tests()`. `{TEST_CASES}` is replaced with one `TestCase::new(..)` line
/// per discovered test. The template is filled with `str::replace`, not `format!`, so the
/// `{{`/`}}` pairs are emitted literally and produce a redundant inner block in the generated
/// function.
const TEST_REGISTRY_FUNCTION: &str = r#"/// Returns all available test cases organized by category.
pub fn get_all_tests() -> Vec<TestCase> {{
vec![
{TEST_CASES}
]
}}"#;
/// Build-script entry point.
///
/// Discovers `test_*` functions under `src/tests/` and writes two files into `OUT_DIR`:
/// `integration_tests.rs`, with one `#[tokio::test]` wrapper per test, and `test_registry.rs`,
/// with a `get_all_tests()` function listing every test case.
///
/// # Panics
///
/// Panics, failing the build with a descriptive message, if `OUT_DIR` is not set or a generated
/// file cannot be written.
fn main() {
    // Regenerate whenever a test source file or this script changes.
    println!("cargo:rerun-if-changed=src/tests/");
    println!("cargo:rerun-if-changed=build.rs");
    // NOTE(review): `info` is not a Cargo build-script instruction, so these lines are
    // presumably just informational build output (visible with `-vv`) — confirm.
    println!("cargo:info=Running build script to generate integration tests");

    let out_dir = env::var("OUT_DIR").expect("OUT_DIR is set by Cargo when running build scripts");
    let out_path = Path::new(&out_dir);

    let test_cases = collect_test_cases();
    println!("cargo:info=Found {} test cases", test_cases.len());

    let integration_path = out_path.join("integration_tests.rs");
    let integration_code = generate_integration_tests(&test_cases);
    fs::write(&integration_path, integration_code)
        .unwrap_or_else(|err| panic!("failed to write {}: {err}", integration_path.display()));
    println!("cargo:info=Generated tokio test wrappers in {}", integration_path.display());

    let generated_path = out_path.join("test_registry.rs");
    let generated_code = generate_test_case_vector(&test_cases);
    fs::write(&generated_path, generated_code)
        .unwrap_or_else(|err| panic!("failed to write {}: {err}", generated_path.display()));
    println!("cargo:info=Generated test case vector in {}", generated_path.display());
}
/// One discovered integration test function.
#[derive(Debug)]
struct TestCaseInfo {
/// Name written into the generated registry; currently always equal to `function_name`.
name: String,
/// Module the test lives in: the source file's stem (e.g. `foo` for `src/tests/foo.rs`).
category: String,
/// Rust identifier of the `test_*` function.
function_name: String,
}
/// Gathers every test case found under `src/tests/`, resolved relative to the build script's
/// working directory.
///
/// Returns an empty list when that directory does not exist.
fn collect_test_cases() -> Vec<TestCaseInfo> {
    let tests_root = Path::new("src/tests");
    let mut found = Vec::new();
    // `is_dir` is false for a missing path, so it also covers the existence check.
    if tests_root.is_dir() {
        collect_test_cases_recursive(tests_root, &mut found);
    }
    found
}
/// Walks `current_dir` recursively and appends the test cases of every `.rs` file found to
/// `test_cases`.
///
/// Entries are visited in sorted path order. `fs::read_dir` makes no ordering guarantee, and
/// sorting keeps the generated files identical from one build to the next.
///
/// # Panics
///
/// Panics, failing the build, if a directory or one of its entries cannot be read.
fn collect_test_cases_recursive(current_dir: &Path, test_cases: &mut Vec<TestCaseInfo>) {
    let entries = fs::read_dir(current_dir).unwrap_or_else(|err| {
        panic!("failed to read directory {}: {err}", current_dir.display())
    });
    let mut paths: Vec<_> = entries
        .map(|entry| {
            entry
                .unwrap_or_else(|err| {
                    panic!("failed to read an entry of {}: {err}", current_dir.display())
                })
                .path()
        })
        .collect();
    paths.sort();

    for path in paths {
        if path.is_dir() {
            collect_test_cases_recursive(&path, test_cases);
        } else if path.extension().and_then(|ext| ext.to_str()) == Some("rs") {
            test_cases.extend(collect_test_cases_from_file(&path));
        }
    }
}
/// Scans one `.rs` file line by line. Returns a `TestCaseInfo` for every `test_*` function
/// declaration recognised by `parse_test_function_name`.
///
/// Returns an empty list for files that do not map to a test category (see
/// `extract_category_from_path`). A file that cannot be read as UTF-8 text is also skipped, but a
/// `cargo:warning` is printed so its tests do not disappear from the generated files silently.
fn collect_test_cases_from_file(file_path: &Path) -> Vec<TestCaseInfo> {
    let category = match extract_category_from_path(file_path) {
        Some(category) => category,
        None => return Vec::new(),
    };
    let content = match fs::read_to_string(file_path) {
        Ok(content) => content,
        Err(err) => {
            println!(
                "cargo:warning=Skipping unreadable test file {}: {err}",
                file_path.display()
            );
            return Vec::new();
        },
    };

    content
        .lines()
        .filter_map(|line| parse_test_function_name(line.trim()))
        .map(|function_name| TestCaseInfo {
            name: function_name.clone(),
            category: category.clone(),
            function_name,
        })
        .collect()
}
/// Maps a test source file to its category: the file stem (`foo` for `src/tests/foo.rs`).
///
/// Returns `None` in three cases:
/// * the stem is `mod`;
/// * the stem is not valid UTF-8;
/// * the file does not live somewhere below a `src` → `tests` directory pair.
///
/// The location check compares path components rather than the path's string form. A
/// `"src/tests/"` substring test never matches on Windows, where `Path::join` inserts `\`
/// separators.
///
/// NOTE(review): files in nested directories (e.g. `src/tests/sub/foo.rs`) map to just their
/// stem, although their Rust module path would be `sub::foo` — confirm this is intended.
fn extract_category_from_path(file_path: &Path) -> Option<String> {
    let file_stem = file_path.file_stem()?.to_str()?;
    if file_stem == "mod" {
        return None;
    }
    // Only directories count: a file literally named `tests` must not satisfy the check.
    let parent_dirs: Vec<_> = file_path.parent()?.components().map(|c| c.as_os_str()).collect();
    let under_src_tests =
        parent_dirs.windows(2).any(|pair| pair[0] == "src" && pair[1] == "tests");
    under_src_tests.then(|| file_stem.to_string())
}
fn parse_test_function_name(line: &str) -> Option<String> {
let s = line.trim();
if s.is_empty() || s.starts_with("//") || s.starts_with("/*") {
return None;
}
let tokens: Vec<&str> = s.split_whitespace().collect();
let fn_pos = if tokens[0] == "pub" && tokens[1] == "async" && tokens[2] == "fn" {
2 } else if tokens[0] == "pub" && tokens[1] == "fn" {
1 } else {
return None;
};
let name_token = tokens.get(fn_pos + 1)?;
let ident: String = name_token
.chars()
.take_while(|c| c.is_ascii_alphanumeric() || *c == '_')
.collect();
let after_prefix = ident.strip_prefix(TEST_PREFIX)?;
if after_prefix.is_empty() {
return None;
}
Some(format!("{TEST_PREFIX}{after_prefix}"))
}
/// Renders the contents of `integration_tests.rs`.
///
/// The output contains the header, the fixed imports, one glob import per test module, and one
/// wrapper built from `TOKIO_TEST_WRAPPER` for every test case. Each wrapper is named after its
/// test function with `TEST_PREFIX` removed, so a `test_mint` function gets a `mint` wrapper.
///
/// Module imports are emitted in sorted order so the output is identical between builds.
/// `HashSet` iteration order is not stable across runs.
fn generate_integration_tests(test_cases: &[TestCaseInfo]) -> String {
    let mut result = String::new();
    result.push_str(INTEGRATION_TESTS_HEADER);
    result.push_str("\n\n");
    result.push_str(INTEGRATION_TESTS_IMPORTS);
    result.push('\n');

    let unique_modules: HashSet<&str> =
        test_cases.iter().map(|test_case| test_case.category.as_str()).collect();
    let mut modules: Vec<&str> = unique_modules.into_iter().collect();
    modules.sort_unstable();
    for module in modules {
        result.push_str(&format!("use miden_client_integration_tests::tests::{module}::*;\n"));
    }
    result.push('\n');

    for test_case in test_cases {
        let test_fn_name = test_case
            .function_name
            .strip_prefix(TEST_PREFIX)
            .expect("discovered test function names always start with TEST_PREFIX");
        let test_wrapper = TOKIO_TEST_WRAPPER
            .replace("{ORIGINAL_FUNCTION_NAME}", &test_case.function_name)
            .replace("{TEST_FUNCTION_NAME}", test_fn_name);
        result.push_str(&test_wrapper);
        result.push_str("\n\n");
    }
    result
}
/// Renders the contents of `test_registry.rs`.
///
/// The output contains the header, the registry imports, one glob import per test module, and
/// the `TEST_REGISTRY_FUNCTION` body. That body lists every test case as
/// `TestCase::new("<name>", TestCategory::<PascalCaseModule>, <function>)`.
///
/// Module imports are emitted in sorted order so the output is identical between builds.
/// `HashSet` iteration order is not stable across runs.
fn generate_test_case_vector(test_cases: &[TestCaseInfo]) -> String {
    let mut result = String::new();
    result.push_str(TEST_REGISTRY_HEADER);
    result.push_str("\n\n");
    result.push_str(TEST_REGISTRY_IMPORTS);
    result.push('\n');

    let unique_modules: HashSet<&str> =
        test_cases.iter().map(|test_case| test_case.category.as_str()).collect();
    let mut modules: Vec<&str> = unique_modules.into_iter().collect();
    modules.sort_unstable();
    for module in modules {
        result.push_str(&format!("use crate::tests::{module}::*;\n"));
    }

    let mut test_cases_str = String::new();
    for test_case in test_cases {
        // Enum variant names are the PascalCase form of the module (category) name.
        let category_variant =
            format!("TestCategory::{}", snake_case_to_pascal_case(&test_case.category));
        test_cases_str.push_str(&format!(
            " TestCase::new(\"{}\", {}, {}),\n",
            test_case.name, category_variant, test_case.function_name
        ));
    }

    result.push('\n');
    let function_code = TEST_REGISTRY_FUNCTION.replace("{TEST_CASES}", &test_cases_str);
    result.push_str(&function_code);
    result
}
/// Converts a `snake_case` identifier into `PascalCase` (`"account_tests"` → `"AccountTests"`).
///
/// Each `_`-separated segment has its first character upper-cased (Unicode-aware, so one
/// character may expand to several) and the rest kept verbatim. Empty segments produced by
/// leading, trailing, or repeated underscores contribute nothing.
fn snake_case_to_pascal_case(snake_str: &str) -> String {
    let mut pascal = String::with_capacity(snake_str.len());
    for segment in snake_str.split('_') {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}