use crate::{CodeGenerator, GeneratedFile, GeneratorConfig, SchemaAnalyzer};
use serde_json::Value;
use std::env;
use std::fs;
use std::path::PathBuf;
use std::process::Command;
use tempfile::TempDir;
/// Everything produced by `run_generation_test`: the on-disk project, the generated
/// files, and (if compilation was requested) the cargo output.
pub struct GenerationTestResult {
/// Owner of the temporary directory holding the project. `tempfile::TempDir` deletes the
/// directory when dropped, so keep this result alive while inspecting files on disk.
pub temp_dir: TempDir,
/// Root of the generated Cargo project: `<temp dir>/<test name>`.
pub project_dir: PathBuf,
/// Directory the generated code is written to: `<project_dir>/src/generated`.
pub generated_src_dir: PathBuf,
/// Generated files, with `path` relative to `generated_src_dir`.
pub files: Vec<GeneratedFile>,
/// Output of `cargo test` / `cargo check`, or `None` when compilation was not requested.
pub compilation_output: Option<std::process::Output>,
/// Path of the project's generated `Cargo.toml`.
pub cargo_toml_path: PathBuf,
}
impl GenerationTestResult {
    /// Reads `filename`, resolved relative to `generated_src_dir`, into a string.
    pub fn read_file(&self, filename: &str) -> std::io::Result<String> {
        fs::read_to_string(self.generated_src_dir.join(filename))
    }

    /// `true` only when cargo was actually run and exited successfully.
    pub fn compiled_successfully(&self) -> bool {
        match &self.compilation_output {
            Some(output) => output.status.success(),
            None => false,
        }
    }

    /// Cargo's stderr (lossily decoded as UTF-8), or an empty string if cargo never ran.
    pub fn compilation_errors(&self) -> String {
        match &self.compilation_output {
            Some(output) => String::from_utf8_lossy(&output.stderr).into_owned(),
            None => String::new(),
        }
    }

    /// Cargo's stdout (lossily decoded as UTF-8), or `None` if cargo never ran.
    pub fn test_output(&self) -> Option<String> {
        let output = self.compilation_output.as_ref()?;
        Some(String::from_utf8_lossy(&output.stdout).into_owned())
    }
}
/// Description of one end-to-end generation test, consumed by `run_generation_test`.
///
/// Can be built with `GenerationTest::new` plus the builder methods on it.
pub struct GenerationTest {
/// Test name: used as the project directory name under the temp dir and, with `-`
/// replaced by `_`, as the Cargo package name.
pub name: String,
/// Spec document handed to `SchemaAnalyzer::new`.
pub spec: Value,
/// Optional generator settings; unset values fall back to the harness defaults.
pub config_overrides: Option<GeneratorConfigOverrides>,
/// Scenarios rendered into `#[test]` functions in the generated crate's `lib.rs`.
pub test_scenarios: Vec<TestScenario>,
/// When compiling: run `cargo test -- --nocapture` if true, otherwise only `cargo check`.
pub run_tests: bool,
/// Whether cargo is invoked at all; when false, files are generated but not built.
pub compile: bool,
/// Extra `(crate name, TOML value)` pairs appended to `[dependencies]`. The value is written
/// verbatim after `name = `, so it must be a TOML string or inline table.
pub extra_dependencies: Vec<(String, String)>,
}
/// A single behavior to exercise against one generated type.
///
/// Each scenario is rendered by `generate_test_from_scenario` into one `#[test]`
/// function inside the generated crate's `lib.rs`.
#[derive(Clone, Debug)]
pub struct TestScenario {
    /// Name of the generated `#[test]` function; must be a valid Rust identifier.
    pub name: String,
    /// Schema name of the type under test; converted with `to_rust_type_name`.
    pub target_type: String,
    /// What the generated test does with the type.
    pub behavior: TestBehavior,
}

/// The kind of check a `TestScenario` performs.
#[derive(Clone, Debug)]
pub enum TestBehavior {
    /// Deserialize `json`, compare each `(field, expected JSON)` in `assertions`, then
    /// check the value is unchanged by a serialize/deserialize round trip.
    RoundTrip {
        json: Value,
        assertions: Vec<(String, Value)>,
    },
    /// Build the type from a struct literal made of `fields`, then run `assertions`.
    Construction {
        fields: Vec<(String, FieldValue)>,
        assertions: Vec<ConstructionAssertion>,
    },
    /// Only check that the type name resolves in the generated crate.
    CompileOnly,
    /// Deserialize `json` and require the `expected_variant` variant; each
    /// `(field, expected JSON)` in `variant_assertions` is bound from the variant and compared.
    EnumMatch {
        json: Value,
        expected_variant: String,
        variant_assertions: Vec<(String, Value)>,
    },
}

/// A value expression used when building struct literals in generated tests.
#[derive(Clone, Debug)]
pub enum FieldValue {
    /// Emitted as an owned `String` (`"…".to_string()`).
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
    /// Emitted as `None`, for optional fields.
    Null,
    /// Emitted verbatim, e.g. `Status::Active`.
    EnumVariant(String),
    /// Emitted as `vec![…]`.
    Array(Vec<FieldValue>),
    /// Emitted as a nested struct literal: type name plus its fields.
    Struct(String, Vec<(String, FieldValue)>),
}

/// A check applied to an instance built by `TestBehavior::Construction`.
#[derive(Clone, Debug)]
pub enum ConstructionAssertion {
    /// The serialized JSON contains this substring.
    JsonContains(String),
    /// `instance.<field>` equals the given JSON value.
    FieldEquals(String, Value),
    /// Serialization succeeds.
    CanSerialize,
}
/// Optional overrides for the `GeneratorConfig` built by `run_generation_test`.
///
/// A `None` field keeps the harness default: module name `"generated"`, SSE client
/// disabled, async client disabled. `Default` leaves every field `None`, so callers can
/// write `GeneratorConfigOverrides { module_name: Some(..), ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct GeneratorConfigOverrides {
    /// Module name passed to the generator.
    pub module_name: Option<String>,
    /// Whether the generator should emit an SSE client.
    pub enable_sse_client: Option<bool>,
    /// Whether the generator should emit an async client.
    pub enable_async_client: Option<bool>,
}
impl GenerationTest {
pub fn new(name: impl Into<String>, spec: Value) -> Self {
Self {
name: name.into(),
spec,
..Default::default()
}
}
pub fn with_compilation(mut self) -> Self {
self.compile = true;
self
}
pub fn with_env_compilation(mut self) -> Self {
self.compile = should_compile_tests();
self
}
pub fn with_scenario(mut self, scenario: TestScenario) -> Self {
self.test_scenarios.push(scenario);
self
}
pub fn run_tests(mut self, run: bool) -> Self {
self.run_tests = run;
self
}
}
impl Default for GenerationTest {
fn default() -> Self {
Self {
name: "test".to_string(),
spec: Value::Null,
config_overrides: None,
test_scenarios: vec![],
run_tests: true,
extra_dependencies: vec![],
compile: false, }
}
}
/// Whether generated projects should be compiled, driven by `OPENAPI_GEN_COMPILE_TESTS`.
///
/// Accepts `1` or `true` (any case); an unset, non-Unicode, or other value means `false`.
pub fn should_compile_tests() -> bool {
    match env::var("OPENAPI_GEN_COMPILE_TESTS") {
        Ok(flag) => flag == "1" || flag.to_lowercase() == "true",
        Err(_) => false,
    }
}
/// Analyzes `spec`, generates code with the default configuration (module name `name`),
/// and checks the output against the `insta` snapshot named `name`.
///
/// Returns the generated code so callers can make further assertions.
pub fn test_generation(name: &str, spec: Value) -> Result<String, Box<dyn std::error::Error>> {
    let mut schema_analyzer = SchemaAnalyzer::new(spec)?;
    let mut schema_analysis = schema_analyzer.analyze()?;
    let generator = CodeGenerator::new(GeneratorConfig {
        module_name: name.to_string(),
        ..Default::default()
    });
    let code = generator.generate(&mut schema_analysis)?;
    insta::assert_snapshot!(name, &code);
    Ok(code)
}
pub fn run_generation_test(
test: GenerationTest,
) -> Result<GenerationTestResult, Box<dyn std::error::Error>> {
let temp_dir = TempDir::new()?;
let project_dir = temp_dir.path().join(&test.name);
fs::create_dir(&project_dir)?;
let src_dir = project_dir.join("src");
fs::create_dir(&src_dir)?;
let generated_src_dir = src_dir.join("generated");
fs::create_dir(&generated_src_dir)?;
let mut analyzer = SchemaAnalyzer::new(test.spec.clone())?;
let analysis = analyzer.analyze()?;
let config = GeneratorConfig {
spec_path: PathBuf::from("test.json"), output_dir: generated_src_dir.clone(),
module_name: test
.config_overrides
.as_ref()
.and_then(|o| o.module_name.clone())
.unwrap_or_else(|| "generated".to_string()),
enable_sse_client: test
.config_overrides
.as_ref()
.and_then(|o| o.enable_sse_client)
.unwrap_or(false),
enable_async_client: test
.config_overrides
.as_ref()
.and_then(|o| o.enable_async_client)
.unwrap_or(false),
enable_specta: false,
type_mappings: {
let mut mappings = std::collections::BTreeMap::new();
mappings.insert("integer".to_string(), "i64".to_string());
mappings.insert("number".to_string(), "f64".to_string());
mappings.insert("string".to_string(), "String".to_string());
mappings.insert("boolean".to_string(), "bool".to_string());
mappings
},
streaming_config: None,
nullable_field_overrides: std::collections::BTreeMap::new(),
schema_extensions: vec![],
http_client_config: None,
retry_config: None,
tracing_enabled: true,
auth_config: None,
enable_registry: false,
registry_only: false,
};
let generator = CodeGenerator::new(config);
let mut analysis_mut = analysis;
let types_content = generator.generate(&mut analysis_mut)?;
let files = vec![GeneratedFile {
path: "types.rs".into(),
content: types_content.clone(),
}];
for file in &files {
let dest_path = generated_src_dir.join(&file.path);
if let Some(parent) = dest_path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(dest_path, &file.content)?;
}
let mut dependencies = vec![
("serde", r#"{ version = "1.0", features = ["derive"] }"#),
("serde_json", r#""1.0""#),
("async-trait", r#""0.1""#),
(
"reqwest",
r#"{ version = "0.12", features = ["json", "stream"] }"#,
),
("futures-util", r#""0.3""#),
("tokio", r#"{ version = "1.0", features = ["full"] }"#),
("tracing", r#""0.1""#),
];
for (name, version) in &test.extra_dependencies {
dependencies.push((name.as_str(), version.as_str()));
}
let deps_str = dependencies
.iter()
.map(|(name, ver)| format!("{name} = {ver}"))
.collect::<Vec<_>>()
.join("\n");
let cargo_toml = format!(
r#"[package]
name = "{}"
version = "0.1.0"
edition = "2021"
[dependencies]
{}
"#,
test.name.replace('-', "_"),
deps_str
);
let cargo_toml_path = project_dir.join("Cargo.toml");
fs::write(&cargo_toml_path, cargo_toml)?;
fs::write(generated_src_dir.join("mod.rs"), "pub mod types;\n")?;
let lib_rs = create_lib_rs(&test, &files);
fs::write(src_dir.join("lib.rs"), lib_rs)?;
let compilation_output = if !test.compile {
None
} else if test.run_tests {
Some(
Command::new("cargo")
.arg("test")
.arg("--")
.arg("--nocapture")
.current_dir(&project_dir)
.env("RUST_BACKTRACE", "1")
.output()?,
)
} else {
Some(
Command::new("cargo")
.arg("check")
.current_dir(&project_dir)
.env("RUST_BACKTRACE", "1")
.output()?,
)
};
if let Some(ref output) = compilation_output {
if !output.status.success() {
if let Ok(types_content) = fs::read_to_string(generated_src_dir.join("types.rs")) {
eprintln!("Generated types.rs:\n{types_content}");
}
}
}
Ok(GenerationTestResult {
temp_dir,
project_dir,
generated_src_dir,
files,
compilation_output,
cargo_toml_path,
})
}
/// Builds the generated crate's `lib.rs`.
///
/// The file declares `pub mod generated;` plus a `#[cfg(test)]` module that glob-imports
/// the generated types, contains a trivial `test_compilation` test, and then one test per
/// scenario in `test.test_scenarios`.
fn create_lib_rs(test: &GenerationTest, _generated_files: &[GeneratedFile]) -> String {
    const PRELUDE: &str = concat!(
        "pub mod generated;\n\n",
        "#[cfg(test)]\nmod tests {\n",
        " use super::generated::types::*;\n",
        " use serde_json;\n\n",
        " #[test]\n",
        " fn test_compilation() {\n",
    );
    let smoke_test_tail = format!(
        " // Generated types compile for test: {}\n }}\n\n",
        test.name
    );
    let scenario_tests: String = test
        .test_scenarios
        .iter()
        .map(|scenario| format!("{}\n", generate_test_from_scenario(scenario)))
        .collect();
    [
        PRELUDE,
        smoke_test_tail.as_str(),
        scenario_tests.as_str(),
        "}\n",
    ]
    .concat()
}
/// Renders one `#[test]` function for `scenario`, to be placed inside the generated
/// crate's `mod tests` (which glob-imports the generated types).
///
/// JSON inputs are embedded as raw string literals with enough `#`s to survive any `"#`
/// inside the JSON, and `JsonContains` text is embedded as an escaped string literal, so
/// inputs containing quotes or backslashes still produce valid Rust.
fn generate_test_from_scenario(scenario: &TestScenario) -> String {
    let mut test_code = String::new();
    let target_type = to_rust_type_name(&scenario.target_type);
    test_code.push_str(&format!(" #[test]\n fn {}() {{\n", scenario.name));
    match &scenario.behavior {
        TestBehavior::RoundTrip { json, assertions } => {
            test_code.push_str(&format!(
                " let json_str = {};\n",
                raw_string_literal(&json.to_string())
            ));
            test_code.push_str(&format!(
                " let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
            ));
            for (field, expected) in assertions {
                test_code.push_str(&format!(
                    " assert_eq!(parsed.{field}, serde_json::json!({expected}));\n"
                ));
            }
            test_code
                .push_str(" let serialized = serde_json::to_string(&parsed).unwrap();\n");
            test_code.push_str(&format!(
                " let round_trip: {target_type} = serde_json::from_str(&serialized).unwrap();\n"
            ));
            test_code.push_str(" assert_eq!(serde_json::to_value(parsed).unwrap(), serde_json::to_value(round_trip).unwrap());\n");
        }
        TestBehavior::Construction { fields, assertions } => {
            test_code.push_str(&format!(" let instance = {target_type} {{\n"));
            for (field_name, field_value) in fields {
                test_code.push_str(&format!(
                    " {}: {},\n",
                    field_name,
                    field_value_to_code(field_value)
                ));
            }
            test_code.push_str(" };\n\n");
            for assertion in assertions {
                match assertion {
                    ConstructionAssertion::JsonContains(text) => {
                        test_code.push_str(
                            " let json = serde_json::to_string(&instance).unwrap();\n",
                        );
                        // `{:?}` renders `text` as an escaped Rust string literal, so quotes
                        // (common in JSON fragments like `"id":1`) stay valid syntax.
                        test_code.push_str(&format!(" assert!(json.contains({text:?}));\n"));
                    }
                    ConstructionAssertion::FieldEquals(field, value) => {
                        test_code.push_str(&format!(
                            " assert_eq!(instance.{field}, serde_json::json!({value}));\n"
                        ));
                    }
                    ConstructionAssertion::CanSerialize => {
                        test_code.push_str(
                            " let _json = serde_json::to_string(&instance).unwrap();\n",
                        );
                    }
                }
            }
        }
        TestBehavior::CompileOnly => {
            test_code.push_str(&format!(
                " // Type {target_type} exists and compiles\n"
            ));
            test_code.push_str(&format!(" let _: Option<{target_type}> = None;\n"));
        }
        TestBehavior::EnumMatch {
            json,
            expected_variant,
            variant_assertions,
        } => {
            test_code.push_str(&format!(
                " let json_str = {};\n",
                raw_string_literal(&json.to_string())
            ));
            test_code.push_str(&format!(
                " let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
            ));
            test_code.push_str(" match parsed {\n");
            if variant_assertions.is_empty() {
                test_code.push_str(&format!(
                    " {target_type}::{expected_variant} {{ .. }} => {{\n"
                ));
                test_code.push_str(" // Variant matched successfully\n");
            } else {
                // Bind each asserted field by name out of the variant, then compare it.
                let bindings = variant_assertions
                    .iter()
                    .map(|(field, _)| field.as_str())
                    .collect::<Vec<_>>()
                    .join(", ");
                test_code.push_str(&format!(
                    " {target_type}::{expected_variant} {{ {bindings}, .. }} => {{\n"
                ));
                for (field, expected) in variant_assertions {
                    test_code.push_str(&format!(
                        " assert_eq!({field}, serde_json::json!({expected}));\n"
                    ));
                }
            }
            test_code.push_str(" }\n");
            test_code.push_str(&format!(
                " _ => panic!(\"Expected {expected_variant} variant\"),\n"
            ));
            test_code.push_str(" }\n");
        }
    }
    test_code.push_str(" }\n");
    test_code
}

/// Wraps `content` in a raw string literal (`r#"…"#`, `r##"…"##`, …) using the fewest `#`s
/// (at least one) such that no `"` followed by that many `#`s occurs inside `content`;
/// otherwise the literal would terminate early.
fn raw_string_literal(content: &str) -> String {
    let mut hashes = String::from("#");
    while content.contains(&format!("\"{hashes}")) {
        hashes.push('#');
    }
    format!("r{hashes}\"{content}\"{hashes}")
}
/// Renders `value` as a Rust expression for use in a generated struct literal.
///
/// Strings become escaped literals plus `.to_string()`. Floats always keep a fractional
/// part or exponent so they stay `f64` literals; non-finite floats map to the `f64`
/// constants. `Null` becomes `None`, enum variants are emitted verbatim, arrays become
/// `vec![…]`, and structs become nested struct literals.
fn field_value_to_code(value: &FieldValue) -> String {
    match value {
        // `Debug` on a str yields a valid Rust string literal with quotes/backslashes escaped.
        FieldValue::String(s) => format!("{s:?}.to_string()"),
        FieldValue::Integer(i) => i.to_string(),
        FieldValue::Float(f) if f.is_nan() => "f64::NAN".to_string(),
        FieldValue::Float(f) if f.is_infinite() => {
            let constant = if f.is_sign_positive() {
                "f64::INFINITY"
            } else {
                "f64::NEG_INFINITY"
            };
            constant.to_string()
        }
        // `Display` prints `1.0` as `1`, which is an integer literal and would not type-check
        // against an `f64` field; `Debug` prints `1.0` (or `1e20`), a valid float literal.
        FieldValue::Float(f) => format!("{f:?}"),
        FieldValue::Boolean(b) => b.to_string(),
        FieldValue::Null => "None".to_string(),
        FieldValue::EnumVariant(variant) => variant.clone(),
        FieldValue::Array(values) => {
            let items: Vec<String> = values.iter().map(field_value_to_code).collect();
            format!("vec![{}]", items.join(", "))
        }
        FieldValue::Struct(type_name, fields) => {
            let mut code = format!("{type_name} {{\n");
            for (field_name, field_value) in fields {
                code.push_str(&format!(
                    " {}: {},\n",
                    field_name,
                    field_value_to_code(field_value)
                ));
            }
            code.push_str(" }");
            code
        }
    }
}
/// Wraps `schemas` in the smallest OpenAPI 3.0.0 document the analyzer accepts:
/// a fixed `info` block plus `components.schemas`.
pub fn minimal_spec(schemas: Value) -> Value {
    let info = serde_json::json!({
        "title": "Test API",
        "version": "1.0.0"
    });
    let components = serde_json::json!({ "schemas": schemas });
    serde_json::json!({
        "openapi": "3.0.0",
        "info": info,
        "components": components
    })
}
/// Panics unless cargo was run for `result` and exited successfully.
///
/// # Panics
/// If no compilation was run, or if it failed (the message includes cargo's stderr and stdout).
pub fn assert_compilation_success(result: &GenerationTestResult) {
    match &result.compilation_output {
        None => panic!("No compilation was run"),
        Some(output) if output.status.success() => {}
        Some(output) => {
            let stderr = String::from_utf8_lossy(&output.stderr);
            let stdout = String::from_utf8_lossy(&output.stdout);
            panic!("Compilation failed!\nSTDERR:\n{stderr}\nSTDOUT:\n{stdout}");
        }
    }
}
/// Returns the first non-comment line in `result.files` that declares `type_name` as a
/// `struct`, an `enum`, or a type alias (`type <name> =`).
///
/// The name must match as a whole identifier, so looking up `Pet` does not return the
/// line declaring `PetStatus`.
pub fn find_type_definition(result: &GenerationTestResult, type_name: &str) -> Option<String> {
    // Built once instead of re-formatting three patterns for every line scanned.
    let patterns = [
        format!("struct {type_name}"),
        format!("enum {type_name}"),
        format!("type {type_name} ="),
    ];
    result
        .files
        .iter()
        .flat_map(|file| file.content.lines())
        .find(|line| {
            !line.trim().starts_with("//")
                && patterns
                    .iter()
                    .any(|pattern| contains_whole_word(line, pattern))
        })
        .map(str::to_string)
}

/// Whether `pattern` occurs in `line` without being glued to neighbouring identifier
/// characters: `struct Pet` matches `pub struct Pet {` but not `pub struct PetStatus {`.
/// Edges of `pattern` that are not identifier characters (e.g. a trailing `=`) are not checked.
fn contains_whole_word(line: &str, pattern: &str) -> bool {
    let is_ident = |c: char| c.is_alphanumeric() || c == '_';
    let check_before = pattern.chars().next().map_or(false, is_ident);
    let check_after = pattern.chars().next_back().map_or(false, is_ident);
    line.match_indices(pattern).any(|(start, matched)| {
        let before_ok =
            !check_before || !line[..start].chars().next_back().map_or(false, is_ident);
        let after_ok = !check_after
            || !line[start + matched.len()..]
                .chars()
                .next()
                .map_or(false, is_ident);
        before_ok && after_ok
    })
}
/// Asserts that the type declared on `type_definition` has no `_` in its name.
///
/// Does nothing when no type name can be extracted from the line.
pub fn assert_no_underscores_in_type_name(type_definition: &str) {
    let type_name = match extract_type_name(type_definition) {
        Some(name) => name,
        None => return,
    };
    assert!(
        !type_name.contains('_'),
        "Type name '{type_name}' should not contain underscores"
    );
}
/// Extracts the declared type name from a definition line such as `pub struct Foo<T> {`.
///
/// The candidate is the token after the `struct`/`enum`/`type` keyword. Lines without one
/// of those keywords fall back to the second whitespace-separated token. The candidate is
/// cut at its first non-identifier character (`<`, `=`, `{`, `(`, `;`, …). Returns `None`
/// when there is no candidate or it is empty.
fn extract_type_name(type_def: &str) -> Option<String> {
    let tokens: Vec<&str> = type_def.split_whitespace().collect();
    // Visibility modifiers (`pub`, `pub(crate)`) shift the name's position in the line,
    // so locate it relative to the declaration keyword rather than by a fixed index.
    let keyword = tokens
        .iter()
        .position(|token| matches!(*token, "struct" | "enum" | "type"));
    let candidate = match keyword {
        Some(keyword_index) => tokens.get(keyword_index + 1)?,
        None => tokens.get(1)?,
    };
    let name = candidate
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .next()?;
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
/// Converts a schema name to UpperCamelCase using ASCII rules only.
///
/// Any character other than an ASCII letter or digit is dropped and starts a new word,
/// whose first lowercase letter is capitalized. Existing uppercase letters are kept, so
/// `HTTPResponse` stays as-is, and a letter right after a digit is not capitalized.
fn to_rust_type_name(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut capitalize = true;
    for ch in s.chars() {
        if ch.is_ascii_lowercase() {
            out.push(if capitalize { ch.to_ascii_uppercase() } else { ch });
            capitalize = false;
        } else if ch.is_ascii_uppercase() || ch.is_ascii_digit() {
            out.push(ch);
            capitalize = false;
        } else {
            // Separators and non-ASCII characters are dropped and act as word boundaries.
            capitalize = true;
        }
    }
    out
}