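//! Test helpers for `openapi_to_rust` (`openapi_to_rust/test_helpers.rs`): generate code
//! from OpenAPI specs into scratch crates and optionally compile/test them.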

1use crate::{CodeGenerator, GeneratedFile, GeneratorConfig, SchemaAnalyzer};
2use serde_json::Value;
3use std::env;
4use std::fs;
5use std::path::PathBuf;
6use std::process::Command;
7use tempfile::TempDir;
8
9/// Result of a generation test
10pub struct GenerationTestResult {
11    /// Temporary directory containing all generated files
12    pub temp_dir: TempDir,
13    /// Path to the test project directory
14    pub project_dir: PathBuf,
15    /// Path to the generated source directory
16    pub generated_src_dir: PathBuf,
17    /// Generated files
18    pub files: Vec<GeneratedFile>,
19    /// Compilation output (if compilation was run)
20    pub compilation_output: Option<std::process::Output>,
21    /// Path to Cargo.toml
22    pub cargo_toml_path: PathBuf,
23}
24
25impl GenerationTestResult {
26    /// Read a generated file by name
27    pub fn read_file(&self, filename: &str) -> std::io::Result<String> {
28        let path = self.generated_src_dir.join(filename);
29        fs::read_to_string(path)
30    }
31
32    /// Check if compilation succeeded
33    pub fn compiled_successfully(&self) -> bool {
34        self.compilation_output
35            .as_ref()
36            .map(|o| o.status.success())
37            .unwrap_or(false)
38    }
39
40    /// Get compilation errors
41    pub fn compilation_errors(&self) -> String {
42        self.compilation_output
43            .as_ref()
44            .map(|o| String::from_utf8_lossy(&o.stderr).to_string())
45            .unwrap_or_default()
46    }
47
48    /// Get test output
49    pub fn test_output(&self) -> Option<String> {
50        self.compilation_output
51            .as_ref()
52            .map(|o| String::from_utf8_lossy(&o.stdout).to_string())
53    }
54}
55
56/// Configuration for a generation test
57pub struct GenerationTest {
58    /// Name of the test (used for project name)
59    pub name: String,
60    /// OpenAPI specification
61    pub spec: Value,
62    /// Generator configuration (optional overrides)
63    pub config_overrides: Option<GeneratorConfigOverrides>,
64    /// Test scenarios to generate
65    pub test_scenarios: Vec<TestScenario>,
66    /// Whether to run tests after compilation (only if compilation is enabled)
67    pub run_tests: bool,
68    /// Whether to compile the generated code (false by default, must opt-in)
69    pub compile: bool,
70    /// Additional dependencies for Cargo.toml
71    pub extra_dependencies: Vec<(String, String)>,
72}
73
74/// A test scenario to generate
75#[derive(Clone)]
76pub struct TestScenario {
77    /// Name of the test function
78    pub name: String,
79    /// Type to test
80    pub target_type: String,
81    /// Test behavior
82    pub behavior: TestBehavior,
83}
84
85/// Different test behaviors we can generate
86#[derive(Clone)]
87pub enum TestBehavior {
88    /// Test serialization/deserialization round trip
89    RoundTrip {
90        /// JSON value to test with
91        json: Value,
92        /// Expected field values to assert (field_name -> expected_value)
93        assertions: Vec<(String, Value)>,
94    },
95    /// Test creating an instance with specific values
96    Construction {
97        /// Field values to set (field_name -> value_expr)
98        /// Value expressions are simple literals or enum variants
99        fields: Vec<(String, FieldValue)>,
100        /// Assertions to make on the created instance
101        assertions: Vec<ConstructionAssertion>,
102    },
103    /// Test that a type exists and compiles
104    CompileOnly,
105    /// Test enum variant matching
106    EnumMatch {
107        /// JSON to deserialize
108        json: Value,
109        /// Expected variant name
110        expected_variant: String,
111        /// Fields to check in the variant (field_name -> expected_value)
112        variant_assertions: Vec<(String, Value)>,
113    },
114}
115
/// A value for one field of a struct literal in a generated construction test.
///
/// `Debug` and `PartialEq` are derived so scenarios can be compared and printed in
/// failing assertions.
#[derive(Clone, Debug, PartialEq)]
pub enum FieldValue {
    /// A string; rendered as an owned `String`.
    String(String),
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal.
    Float(f64),
    /// A boolean literal.
    Boolean(bool),
    /// Rendered as `None`.
    Null,
    /// Path inserted verbatim, e.g. `"MyEnum::Variant"`.
    EnumVariant(String),
    /// Rendered as `vec![...]` of the element values.
    Array(Vec<FieldValue>),
    /// Nested struct literal: `(type name, fields)`.
    Struct(String, Vec<(String, FieldValue)>),
}
136
137/// Assertion to make on a constructed value
138#[derive(Clone)]
139pub enum ConstructionAssertion {
140    /// Assert serialized JSON contains a string
141    JsonContains(String),
142    /// Assert a field equals a value
143    FieldEquals(String, Value),
144    /// Assert successful serialization
145    CanSerialize,
146}
147
/// Optional overrides for the generator configuration used by [`run_generation_test`].
///
/// Every field is `None` by default, meaning "keep the helper's default". `Default` is
/// derived so callers can write
/// `GeneratorConfigOverrides { module_name: Some(..), ..Default::default() }`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GeneratorConfigOverrides {
    /// Module name passed to the generator (`None` → `"generated"`).
    pub module_name: Option<String>,
    /// Enable the SSE client (`None` → disabled).
    pub enable_sse_client: Option<bool>,
    /// Enable the async client (`None` → disabled).
    pub enable_async_client: Option<bool>,
}
154
155impl GenerationTest {
156    /// Create a new test with the given name and spec
157    pub fn new(name: impl Into<String>, spec: Value) -> Self {
158        Self {
159            name: name.into(),
160            spec,
161            ..Default::default()
162        }
163    }
164
165    /// Enable compilation for this test
166    pub fn with_compilation(mut self) -> Self {
167        self.compile = true;
168        self
169    }
170
171    /// Enable compilation only if OPENAPI_GEN_COMPILE_TESTS is set
172    pub fn with_env_compilation(mut self) -> Self {
173        self.compile = should_compile_tests();
174        self
175    }
176
177    /// Add a test scenario
178    pub fn with_scenario(mut self, scenario: TestScenario) -> Self {
179        self.test_scenarios.push(scenario);
180        self
181    }
182
183    /// Set whether to run tests (only applies if compilation is enabled)
184    pub fn run_tests(mut self, run: bool) -> Self {
185        self.run_tests = run;
186        self
187    }
188}
189
190impl Default for GenerationTest {
191    fn default() -> Self {
192        Self {
193            name: "test".to_string(),
194            spec: Value::Null,
195            config_overrides: None,
196            test_scenarios: vec![],
197            run_tests: true,
198            extra_dependencies: vec![],
199            compile: false, // Opt-in compilation
200        }
201    }
202}
203
/// Reports whether full compilation tests are enabled.
///
/// They are enabled when the `OPENAPI_GEN_COMPILE_TESTS` environment variable is `1` or
/// `true` (case-insensitive). An unset or non-UTF-8 variable means "disabled".
pub fn should_compile_tests() -> bool {
    match env::var("OPENAPI_GEN_COMPILE_TESTS") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
210
211/// Fast test function that only validates syntax (no compilation)
212/// When called from tests, automatically creates/verifies snapshots
213pub fn test_generation(name: &str, spec: Value) -> Result<String, Box<dyn std::error::Error>> {
214    // Analyze the spec
215    let mut analyzer = SchemaAnalyzer::new(spec)?;
216    let mut analysis = analyzer.analyze()?;
217
218    // Generate code
219    let config = GeneratorConfig {
220        module_name: name.to_string(),
221        ..Default::default()
222    };
223    let generator = CodeGenerator::new(config);
224    let generated_code = generator.generate(&mut analysis)?;
225
226    // The code is already validated by syn in the generator
227
228    // Automatically assert snapshot when insta is available
229    insta::assert_snapshot!(name, &generated_code);
230
231    Ok(generated_code)
232}
233
234/// Run a complete generation test
235pub fn run_generation_test(
236    test: GenerationTest,
237) -> Result<GenerationTestResult, Box<dyn std::error::Error>> {
238    // Create temporary directory
239    let temp_dir = TempDir::new()?;
240    let project_dir = temp_dir.path().join(&test.name);
241    fs::create_dir(&project_dir)?;
242
243    // Set up paths
244    let src_dir = project_dir.join("src");
245    fs::create_dir(&src_dir)?;
246    let generated_src_dir = src_dir.join("generated");
247    fs::create_dir(&generated_src_dir)?;
248
249    // Analyze the spec
250    let mut analyzer = SchemaAnalyzer::new(test.spec.clone())?;
251    let analysis = analyzer.analyze()?;
252
253    // Configure generator
254    let config = GeneratorConfig {
255        spec_path: PathBuf::from("test.json"), // Not used when we pass analysis directly
256        output_dir: generated_src_dir.clone(),
257        module_name: test
258            .config_overrides
259            .as_ref()
260            .and_then(|o| o.module_name.clone())
261            .unwrap_or_else(|| "generated".to_string()),
262        enable_sse_client: test
263            .config_overrides
264            .as_ref()
265            .and_then(|o| o.enable_sse_client)
266            .unwrap_or(false),
267        enable_async_client: test
268            .config_overrides
269            .as_ref()
270            .and_then(|o| o.enable_async_client)
271            .unwrap_or(false),
272        enable_specta: false,
273        type_mappings: {
274            let mut mappings = std::collections::BTreeMap::new();
275            mappings.insert("integer".to_string(), "i64".to_string());
276            mappings.insert("number".to_string(), "f64".to_string());
277            mappings.insert("string".to_string(), "String".to_string());
278            mappings.insert("boolean".to_string(), "bool".to_string());
279            mappings
280        },
281        streaming_config: None,
282        nullable_field_overrides: std::collections::BTreeMap::new(),
283        schema_extensions: vec![],
284        http_client_config: None,
285        retry_config: None,
286        tracing_enabled: true,
287        auth_config: None,
288    };
289
290    // Generate code
291    let generator = CodeGenerator::new(config);
292    let mut analysis_mut = analysis;
293    let types_content = generator.generate(&mut analysis_mut)?;
294
295    // Create files list with just the types file for now
296    let files = vec![GeneratedFile {
297        path: "types.rs".into(),
298        content: types_content.clone(),
299    }];
300
301    // Write generated files
302    for file in &files {
303        let dest_path = generated_src_dir.join(&file.path);
304        if let Some(parent) = dest_path.parent() {
305            fs::create_dir_all(parent)?;
306        }
307        fs::write(dest_path, &file.content)?;
308    }
309
310    // Create Cargo.toml
311    let mut dependencies = vec![
312        ("serde", r#"{ version = "1.0", features = ["derive"] }"#),
313        ("serde_json", r#""1.0""#),
314        ("async-trait", r#""0.1""#),
315        (
316            "reqwest",
317            r#"{ version = "0.12", features = ["json", "stream"] }"#,
318        ),
319        ("futures-util", r#""0.3""#),
320        ("tokio", r#"{ version = "1.0", features = ["full"] }"#),
321        ("tracing", r#""0.1""#),
322    ];
323
324    // Add extra dependencies
325    for (name, version) in &test.extra_dependencies {
326        dependencies.push((name.as_str(), version.as_str()));
327    }
328
329    let deps_str = dependencies
330        .iter()
331        .map(|(name, ver)| format!("{name} = {ver}"))
332        .collect::<Vec<_>>()
333        .join("\n");
334
335    let cargo_toml = format!(
336        r#"[package]
337name = "{}"
338version = "0.1.0"
339edition = "2021"
340
341[dependencies]
342{}
343"#,
344        test.name.replace('-', "_"),
345        deps_str
346    );
347
348    let cargo_toml_path = project_dir.join("Cargo.toml");
349    fs::write(&cargo_toml_path, cargo_toml)?;
350
351    // Create generated/mod.rs
352    fs::write(generated_src_dir.join("mod.rs"), "pub mod types;\n")?;
353
354    // Create lib.rs with tests
355    let lib_rs = create_lib_rs(&test, &files);
356    fs::write(src_dir.join("lib.rs"), lib_rs)?;
357
358    // Only compile if explicitly requested
359    let compilation_output = if !test.compile {
360        None
361    } else if test.run_tests {
362        // Run tests if requested
363        Some(
364            Command::new("cargo")
365                .arg("test")
366                .arg("--")
367                .arg("--nocapture")
368                .current_dir(&project_dir)
369                .env("RUST_BACKTRACE", "1")
370                .output()?,
371        )
372    } else {
373        // Just check compilation
374        Some(
375            Command::new("cargo")
376                .arg("check")
377                .current_dir(&project_dir)
378                .env("RUST_BACKTRACE", "1")
379                .output()?,
380        )
381    };
382
383    // Debug: print generated types for failing tests
384    if let Some(ref output) = compilation_output {
385        if !output.status.success() {
386            if let Ok(types_content) = fs::read_to_string(generated_src_dir.join("types.rs")) {
387                eprintln!("Generated types.rs:\n{types_content}");
388            }
389        }
390    }
391
392    Ok(GenerationTestResult {
393        temp_dir,
394        project_dir,
395        generated_src_dir,
396        files,
397        compilation_output,
398        cargo_toml_path,
399    })
400}
401
402fn create_lib_rs(test: &GenerationTest, _generated_files: &[GeneratedFile]) -> String {
403    let mut lib_content = String::from("pub mod generated;\n\n");
404
405    // Add test module
406    lib_content.push_str("#[cfg(test)]\nmod tests {\n");
407    lib_content.push_str("    use super::generated::types::*;\n");
408    lib_content.push_str("    use serde_json;\n\n");
409
410    // Always add basic compilation test
411    lib_content.push_str("    #[test]\n");
412    lib_content.push_str("    fn test_compilation() {\n");
413    lib_content.push_str(&format!(
414        "        // Generated types compile for test: {}\n",
415        test.name
416    ));
417    lib_content.push_str("    }\n\n");
418
419    // Generate tests from scenarios
420    for scenario in &test.test_scenarios {
421        lib_content.push_str(&generate_test_from_scenario(scenario));
422        lib_content.push('\n');
423    }
424
425    lib_content.push_str("}\n");
426    lib_content
427}
428
429/// Generate test code from a test scenario
430fn generate_test_from_scenario(scenario: &TestScenario) -> String {
431    let mut test_code = String::new();
432
433    // Convert type names to valid Rust type names (remove underscores)
434    let target_type = to_rust_type_name(&scenario.target_type);
435
436    test_code.push_str(&format!("    #[test]\n    fn {}() {{\n", scenario.name));
437
438    match &scenario.behavior {
439        TestBehavior::RoundTrip { json, assertions } => {
440            // Generate round-trip test
441            test_code.push_str(&format!("        let json_str = r#\"{json}\"#;\n"));
442            test_code.push_str(&format!(
443                "        let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
444            ));
445
446            // Add assertions
447            for (field, expected) in assertions {
448                test_code.push_str(&format!(
449                    "        assert_eq!(parsed.{field}, serde_json::json!({expected}));\n"
450                ));
451            }
452
453            // Test serialization round-trip
454            test_code
455                .push_str("        let serialized = serde_json::to_string(&parsed).unwrap();\n");
456            test_code.push_str(&format!(
457                "        let round_trip: {target_type} = serde_json::from_str(&serialized).unwrap();\n"
458            ));
459            test_code.push_str("        assert_eq!(serde_json::to_value(parsed).unwrap(), serde_json::to_value(round_trip).unwrap());\n");
460        }
461
462        TestBehavior::Construction { fields, assertions } => {
463            // Generate construction test
464            test_code.push_str(&format!("        let instance = {target_type} {{\n"));
465
466            for (field_name, field_value) in fields {
467                test_code.push_str(&format!(
468                    "            {}: {},\n",
469                    field_name,
470                    field_value_to_code(field_value)
471                ));
472            }
473
474            test_code.push_str("        };\n\n");
475
476            // Add assertions
477            for assertion in assertions {
478                match assertion {
479                    ConstructionAssertion::JsonContains(text) => {
480                        test_code.push_str(
481                            "        let json = serde_json::to_string(&instance).unwrap();\n",
482                        );
483                        test_code
484                            .push_str(&format!("        assert!(json.contains(\"{text}\"));\n"));
485                    }
486                    ConstructionAssertion::FieldEquals(field, value) => {
487                        test_code.push_str(&format!(
488                            "        assert_eq!(instance.{field}, serde_json::json!({value}));\n"
489                        ));
490                    }
491                    ConstructionAssertion::CanSerialize => {
492                        test_code.push_str(
493                            "        let _json = serde_json::to_string(&instance).unwrap();\n",
494                        );
495                    }
496                }
497            }
498        }
499
500        TestBehavior::CompileOnly => {
501            test_code.push_str(&format!(
502                "        // Type {target_type} exists and compiles\n"
503            ));
504            test_code.push_str(&format!("        let _: Option<{target_type}> = None;\n"));
505        }
506
507        TestBehavior::EnumMatch {
508            json,
509            expected_variant,
510            variant_assertions,
511        } => {
512            test_code.push_str(&format!("        let json_str = r#\"{json}\"#;\n"));
513            test_code.push_str(&format!(
514                "        let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
515            ));
516            test_code.push_str("        match parsed {\n");
517
518            // Generate match pattern for struct variants with named fields
519            if variant_assertions.is_empty() {
520                // No assertions needed, just check the variant
521                test_code.push_str(&format!(
522                    "            {target_type}::{expected_variant} {{ .. }} => {{\n"
523                ));
524                test_code.push_str("                // Variant matched successfully\n");
525            } else {
526                // Has assertions - generate struct variant pattern with field bindings
527                let field_bindings: Vec<String> = variant_assertions
528                    .iter()
529                    .map(|(field, _)| field.clone())
530                    .collect();
531
532                if field_bindings.is_empty() {
533                    test_code.push_str(&format!(
534                        "            {target_type}::{expected_variant} {{ .. }} => {{\n"
535                    ));
536                } else {
537                    // Always add .. to allow partial matching
538                    let bindings = field_bindings.join(", ");
539                    test_code.push_str(&format!(
540                        "            {target_type}::{expected_variant} {{ {bindings}, .. }} => {{\n"
541                    ));
542
543                    // Add assertions on the fields
544                    for (field, expected) in variant_assertions {
545                        test_code.push_str(&format!(
546                            "                assert_eq!({field}, serde_json::json!({expected}));\n"
547                        ));
548                    }
549                }
550            }
551
552            test_code.push_str("            }\n");
553            test_code.push_str(&format!(
554                "            _ => panic!(\"Expected {expected_variant} variant\"),\n"
555            ));
556            test_code.push_str("        }\n");
557        }
558    }
559
560    test_code.push_str("    }\n");
561    test_code
562}
563
564/// Convert a FieldValue to Rust code
565fn field_value_to_code(value: &FieldValue) -> String {
566    match value {
567        FieldValue::String(s) => format!("\"{s}\".to_string()"),
568        FieldValue::Integer(i) => i.to_string(),
569        FieldValue::Float(f) => f.to_string(),
570        FieldValue::Boolean(b) => b.to_string(),
571        FieldValue::Null => "None".to_string(),
572        FieldValue::EnumVariant(variant) => variant.clone(),
573        FieldValue::Array(values) => {
574            let items: Vec<String> = values.iter().map(field_value_to_code).collect();
575            format!("vec![{}]", items.join(", "))
576        }
577        FieldValue::Struct(type_name, fields) => {
578            let mut code = format!("{type_name} {{\n");
579            for (field_name, field_value) in fields {
580                code.push_str(&format!(
581                    "                {}: {},\n",
582                    field_name,
583                    field_value_to_code(field_value)
584                ));
585            }
586            code.push_str("            }");
587            code
588        }
589    }
590}
591
592/// Helper to create a minimal OpenAPI spec for testing
593pub fn minimal_spec(schemas: Value) -> Value {
594    serde_json::json!({
595        "openapi": "3.0.0",
596        "info": {
597            "title": "Test API",
598            "version": "1.0.0"
599        },
600        "components": {
601            "schemas": schemas
602        }
603    })
604}
605
606/// Assert that the test compiled successfully
607pub fn assert_compilation_success(result: &GenerationTestResult) {
608    if let Some(output) = &result.compilation_output {
609        if !output.status.success() {
610            let stderr = String::from_utf8_lossy(&output.stderr);
611            let stdout = String::from_utf8_lossy(&output.stdout);
612            panic!("Compilation failed!\nSTDERR:\n{stderr}\nSTDOUT:\n{stdout}");
613        }
614    } else {
615        panic!("No compilation was run");
616    }
617}
618
619/// Find a type definition in the generated files
620pub fn find_type_definition(result: &GenerationTestResult, type_name: &str) -> Option<String> {
621    for file in &result.files {
622        for line in file.content.lines() {
623            if (line.contains(&format!("struct {type_name}"))
624                || line.contains(&format!("enum {type_name}"))
625                || line.contains(&format!("type {type_name} =")))
626                && !line.trim().starts_with("//")
627            {
628                return Some(line.to_string());
629            }
630        }
631    }
632    None
633}
634
635/// Assert that a type name doesn't contain underscores
636pub fn assert_no_underscores_in_type_name(type_definition: &str) {
637    if let Some(type_name) = extract_type_name(type_definition) {
638        assert!(
639            !type_name.contains('_'),
640            "Type name '{type_name}' should not contain underscores"
641        );
642    }
643}
644
/// Extracts the declared type name from a single declaration line.
///
/// The name is the token following the first `struct`, `enum` or `type` keyword, so
/// visibility modifiers (`pub`, `pub(crate)`) and indentation are skipped. The name is cut
/// at the first non-identifier character (`<`, `=`, `(`, `{`, `;` …). When no keyword is
/// present, the second whitespace-separated token is used, as before. Returns `None` if
/// no non-empty name is found.
fn extract_type_name(type_def: &str) -> Option<String> {
    const KEYWORDS: [&str; 3] = ["struct", "enum", "type"];

    let tokens: Vec<&str> = type_def.split_whitespace().collect();
    let name_token = tokens
        .iter()
        .position(|token| KEYWORDS.contains(token))
        .map_or(tokens.get(1), |keyword_idx| tokens.get(keyword_idx + 1))?;

    let name = name_token
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .next()
        .unwrap_or("");
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
654
/// Converts `s` to the PascalCase spelling used for generated type names.
///
/// ASCII letters and digits are kept. Every other character (`_`, `-`, `.`, spaces,
/// non-ASCII characters …) is dropped and starts a new word, whose first ASCII letter is
/// upper-cased. Existing capitals are preserved (`HTTP_server` → `HTTPServer`).
fn to_rust_type_name(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    let mut at_word_start = true;

    for ch in s.chars() {
        if !ch.is_ascii_alphanumeric() {
            // Separator: drop it and capitalize whatever comes next.
            at_word_start = true;
            continue;
        }
        pascal.push(if at_word_start {
            ch.to_ascii_uppercase()
        } else {
            ch
        });
        at_word_start = false;
    }

    pascal
}