Skip to main content

openapi_to_rust/
test_helpers.rs

1use crate::{CodeGenerator, GeneratedFile, GeneratorConfig, SchemaAnalyzer};
2use serde_json::Value;
3use std::env;
4use std::fs;
5use std::path::PathBuf;
6use std::process::Command;
7use tempfile::TempDir;
8
9/// Result of a generation test
10pub struct GenerationTestResult {
11    /// Temporary directory containing all generated files
12    pub temp_dir: TempDir,
13    /// Path to the test project directory
14    pub project_dir: PathBuf,
15    /// Path to the generated source directory
16    pub generated_src_dir: PathBuf,
17    /// Generated files
18    pub files: Vec<GeneratedFile>,
19    /// Compilation output (if compilation was run)
20    pub compilation_output: Option<std::process::Output>,
21    /// Path to Cargo.toml
22    pub cargo_toml_path: PathBuf,
23}
24
25impl GenerationTestResult {
26    /// Read a generated file by name
27    pub fn read_file(&self, filename: &str) -> std::io::Result<String> {
28        let path = self.generated_src_dir.join(filename);
29        fs::read_to_string(path)
30    }
31
32    /// Check if compilation succeeded
33    pub fn compiled_successfully(&self) -> bool {
34        self.compilation_output
35            .as_ref()
36            .map(|o| o.status.success())
37            .unwrap_or(false)
38    }
39
40    /// Get compilation errors
41    pub fn compilation_errors(&self) -> String {
42        self.compilation_output
43            .as_ref()
44            .map(|o| String::from_utf8_lossy(&o.stderr).to_string())
45            .unwrap_or_default()
46    }
47
48    /// Get test output
49    pub fn test_output(&self) -> Option<String> {
50        self.compilation_output
51            .as_ref()
52            .map(|o| String::from_utf8_lossy(&o.stdout).to_string())
53    }
54}
55
56/// Configuration for a generation test
57pub struct GenerationTest {
58    /// Name of the test (used for project name)
59    pub name: String,
60    /// OpenAPI specification
61    pub spec: Value,
62    /// Generator configuration (optional overrides)
63    pub config_overrides: Option<GeneratorConfigOverrides>,
64    /// Test scenarios to generate
65    pub test_scenarios: Vec<TestScenario>,
66    /// Whether to run tests after compilation (only if compilation is enabled)
67    pub run_tests: bool,
68    /// Whether to compile the generated code (false by default, must opt-in)
69    pub compile: bool,
70    /// Additional dependencies for Cargo.toml
71    pub extra_dependencies: Vec<(String, String)>,
72}
73
74/// A test scenario to generate
75#[derive(Clone)]
76pub struct TestScenario {
77    /// Name of the test function
78    pub name: String,
79    /// Type to test
80    pub target_type: String,
81    /// Test behavior
82    pub behavior: TestBehavior,
83}
84
85/// Different test behaviors we can generate
86#[derive(Clone)]
87pub enum TestBehavior {
88    /// Test serialization/deserialization round trip
89    RoundTrip {
90        /// JSON value to test with
91        json: Value,
92        /// Expected field values to assert (field_name -> expected_value)
93        assertions: Vec<(String, Value)>,
94    },
95    /// Test creating an instance with specific values
96    Construction {
97        /// Field values to set (field_name -> value_expr)
98        /// Value expressions are simple literals or enum variants
99        fields: Vec<(String, FieldValue)>,
100        /// Assertions to make on the created instance
101        assertions: Vec<ConstructionAssertion>,
102    },
103    /// Test that a type exists and compiles
104    CompileOnly,
105    /// Test enum variant matching
106    EnumMatch {
107        /// JSON to deserialize
108        json: Value,
109        /// Expected variant name
110        expected_variant: String,
111        /// Fields to check in the variant (field_name -> expected_value)
112        variant_assertions: Vec<(String, Value)>,
113    },
114}
115
/// A value for one field of the struct literal built by a
/// [`TestBehavior::Construction`] scenario; `field_value_to_code` renders it as Rust source.
///
/// Derives `Debug`/`PartialEq` so scenarios can be printed and compared in tests.
#[derive(Clone, Debug, PartialEq)]
pub enum FieldValue {
    /// A string; rendered as an owned `String` (`"...".to_string()`).
    String(String),
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal.
    Float(f64),
    /// `true` or `false`.
    Boolean(bool),
    /// `None`, for optional fields.
    Null,
    /// A path spliced in verbatim, e.g. `"MyEnum::Variant"`.
    EnumVariant(String),
    /// A `vec![...]` of the contained values.
    Array(Vec<FieldValue>),
    /// A nested struct literal: `(type_name, fields)`.
    Struct(String, Vec<(String, FieldValue)>),
}
136
137/// Assertion to make on a constructed value
138#[derive(Clone)]
139pub enum ConstructionAssertion {
140    /// Assert serialized JSON contains a string
141    JsonContains(String),
142    /// Assert a field equals a value
143    FieldEquals(String, Value),
144    /// Assert successful serialization
145    CanSerialize,
146}
147
/// Optional overrides for the generator configuration used by [`run_generation_test`].
///
/// A `None` field keeps the test default: module name `"generated"`, SSE client off,
/// async client off. Derives `Default` so callers can override a single field with
/// `GeneratorConfigOverrides { module_name: Some(..), ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct GeneratorConfigOverrides {
    /// Value for `GeneratorConfig::module_name`.
    pub module_name: Option<String>,
    /// Value for `GeneratorConfig::enable_sse_client`.
    pub enable_sse_client: Option<bool>,
    /// Value for `GeneratorConfig::enable_async_client`.
    pub enable_async_client: Option<bool>,
}
154
155impl GenerationTest {
156    /// Create a new test with the given name and spec
157    pub fn new(name: impl Into<String>, spec: Value) -> Self {
158        Self {
159            name: name.into(),
160            spec,
161            ..Default::default()
162        }
163    }
164
165    /// Enable compilation for this test
166    pub fn with_compilation(mut self) -> Self {
167        self.compile = true;
168        self
169    }
170
171    /// Enable compilation only if OPENAPI_GEN_COMPILE_TESTS is set
172    pub fn with_env_compilation(mut self) -> Self {
173        self.compile = should_compile_tests();
174        self
175    }
176
177    /// Add a test scenario
178    pub fn with_scenario(mut self, scenario: TestScenario) -> Self {
179        self.test_scenarios.push(scenario);
180        self
181    }
182
183    /// Set whether to run tests (only applies if compilation is enabled)
184    pub fn run_tests(mut self, run: bool) -> Self {
185        self.run_tests = run;
186        self
187    }
188}
189
190impl Default for GenerationTest {
191    fn default() -> Self {
192        Self {
193            name: "test".to_string(),
194            spec: Value::Null,
195            config_overrides: None,
196            test_scenarios: vec![],
197            run_tests: true,
198            extra_dependencies: vec![],
199            compile: false, // Opt-in compilation
200        }
201    }
202}
203
/// Whether full compilation tests are enabled.
///
/// Reads `OPENAPI_GEN_COMPILE_TESTS`: `1` or `true` (any case) enables them; an unset,
/// non-Unicode, or any other value disables them.
pub fn should_compile_tests() -> bool {
    match env::var("OPENAPI_GEN_COMPILE_TESTS") {
        Ok(value) => value == "1" || value.to_lowercase() == "true",
        Err(_) => false,
    }
}
210
211/// Fast test function that only validates syntax (no compilation)
212/// When called from tests, automatically creates/verifies snapshots
213pub fn test_generation(name: &str, spec: Value) -> Result<String, Box<dyn std::error::Error>> {
214    // Analyze the spec
215    let mut analyzer = SchemaAnalyzer::new(spec)?;
216    let mut analysis = analyzer.analyze()?;
217
218    // Generate code
219    let config = GeneratorConfig {
220        module_name: name.to_string(),
221        ..Default::default()
222    };
223    let generator = CodeGenerator::new(config);
224    let generated_code = generator.generate(&mut analysis)?;
225
226    // The code is already validated by syn in the generator
227
228    // Automatically assert snapshot when insta is available
229    insta::assert_snapshot!(name, &generated_code);
230
231    Ok(generated_code)
232}
233
234/// Run a complete generation test
235pub fn run_generation_test(
236    test: GenerationTest,
237) -> Result<GenerationTestResult, Box<dyn std::error::Error>> {
238    // Create temporary directory
239    let temp_dir = TempDir::new()?;
240    let project_dir = temp_dir.path().join(&test.name);
241    fs::create_dir(&project_dir)?;
242
243    // Set up paths
244    let src_dir = project_dir.join("src");
245    fs::create_dir(&src_dir)?;
246    let generated_src_dir = src_dir.join("generated");
247    fs::create_dir(&generated_src_dir)?;
248
249    // Analyze the spec
250    let mut analyzer = SchemaAnalyzer::new(test.spec.clone())?;
251    let analysis = analyzer.analyze()?;
252
253    // Configure generator
254    let config = GeneratorConfig {
255        spec_path: PathBuf::from("test.json"), // Not used when we pass analysis directly
256        output_dir: generated_src_dir.clone(),
257        module_name: test
258            .config_overrides
259            .as_ref()
260            .and_then(|o| o.module_name.clone())
261            .unwrap_or_else(|| "generated".to_string()),
262        enable_sse_client: test
263            .config_overrides
264            .as_ref()
265            .and_then(|o| o.enable_sse_client)
266            .unwrap_or(false),
267        enable_async_client: test
268            .config_overrides
269            .as_ref()
270            .and_then(|o| o.enable_async_client)
271            .unwrap_or(false),
272        enable_specta: false,
273        type_mappings: {
274            let mut mappings = std::collections::BTreeMap::new();
275            mappings.insert("integer".to_string(), "i64".to_string());
276            mappings.insert("number".to_string(), "f64".to_string());
277            mappings.insert("string".to_string(), "String".to_string());
278            mappings.insert("boolean".to_string(), "bool".to_string());
279            mappings
280        },
281        streaming_config: None,
282        nullable_field_overrides: std::collections::BTreeMap::new(),
283        schema_extensions: vec![],
284        http_client_config: None,
285        retry_config: None,
286        tracing_enabled: true,
287        auth_config: None,
288        enable_registry: false,
289        registry_only: false,
290    };
291
292    // Generate code
293    let generator = CodeGenerator::new(config);
294    let mut analysis_mut = analysis;
295    let types_content = generator.generate(&mut analysis_mut)?;
296
297    // Create files list with just the types file for now
298    let files = vec![GeneratedFile {
299        path: "types.rs".into(),
300        content: types_content.clone(),
301    }];
302
303    // Write generated files
304    for file in &files {
305        let dest_path = generated_src_dir.join(&file.path);
306        if let Some(parent) = dest_path.parent() {
307            fs::create_dir_all(parent)?;
308        }
309        fs::write(dest_path, &file.content)?;
310    }
311
312    // Create Cargo.toml
313    let mut dependencies = vec![
314        ("serde", r#"{ version = "1.0", features = ["derive"] }"#),
315        ("serde_json", r#""1.0""#),
316        ("async-trait", r#""0.1""#),
317        (
318            "reqwest",
319            r#"{ version = "0.12", features = ["json", "stream"] }"#,
320        ),
321        ("futures-util", r#""0.3""#),
322        ("tokio", r#"{ version = "1.0", features = ["full"] }"#),
323        ("tracing", r#""0.1""#),
324    ];
325
326    // Add extra dependencies
327    for (name, version) in &test.extra_dependencies {
328        dependencies.push((name.as_str(), version.as_str()));
329    }
330
331    let deps_str = dependencies
332        .iter()
333        .map(|(name, ver)| format!("{name} = {ver}"))
334        .collect::<Vec<_>>()
335        .join("\n");
336
337    let cargo_toml = format!(
338        r#"[package]
339name = "{}"
340version = "0.1.0"
341edition = "2021"
342
343[dependencies]
344{}
345"#,
346        test.name.replace('-', "_"),
347        deps_str
348    );
349
350    let cargo_toml_path = project_dir.join("Cargo.toml");
351    fs::write(&cargo_toml_path, cargo_toml)?;
352
353    // Create generated/mod.rs
354    fs::write(generated_src_dir.join("mod.rs"), "pub mod types;\n")?;
355
356    // Create lib.rs with tests
357    let lib_rs = create_lib_rs(&test, &files);
358    fs::write(src_dir.join("lib.rs"), lib_rs)?;
359
360    // Only compile if explicitly requested
361    let compilation_output = if !test.compile {
362        None
363    } else if test.run_tests {
364        // Run tests if requested
365        Some(
366            Command::new("cargo")
367                .arg("test")
368                .arg("--")
369                .arg("--nocapture")
370                .current_dir(&project_dir)
371                .env("RUST_BACKTRACE", "1")
372                .output()?,
373        )
374    } else {
375        // Just check compilation
376        Some(
377            Command::new("cargo")
378                .arg("check")
379                .current_dir(&project_dir)
380                .env("RUST_BACKTRACE", "1")
381                .output()?,
382        )
383    };
384
385    // Debug: print generated types for failing tests
386    if let Some(ref output) = compilation_output {
387        if !output.status.success() {
388            if let Ok(types_content) = fs::read_to_string(generated_src_dir.join("types.rs")) {
389                eprintln!("Generated types.rs:\n{types_content}");
390            }
391        }
392    }
393
394    Ok(GenerationTestResult {
395        temp_dir,
396        project_dir,
397        generated_src_dir,
398        files,
399        compilation_output,
400        cargo_toml_path,
401    })
402}
403
404fn create_lib_rs(test: &GenerationTest, _generated_files: &[GeneratedFile]) -> String {
405    let mut lib_content = String::from("pub mod generated;\n\n");
406
407    // Add test module
408    lib_content.push_str("#[cfg(test)]\nmod tests {\n");
409    lib_content.push_str("    use super::generated::types::*;\n");
410    lib_content.push_str("    use serde_json;\n\n");
411
412    // Always add basic compilation test
413    lib_content.push_str("    #[test]\n");
414    lib_content.push_str("    fn test_compilation() {\n");
415    lib_content.push_str(&format!(
416        "        // Generated types compile for test: {}\n",
417        test.name
418    ));
419    lib_content.push_str("    }\n\n");
420
421    // Generate tests from scenarios
422    for scenario in &test.test_scenarios {
423        lib_content.push_str(&generate_test_from_scenario(scenario));
424        lib_content.push('\n');
425    }
426
427    lib_content.push_str("}\n");
428    lib_content
429}
430
431/// Generate test code from a test scenario
432fn generate_test_from_scenario(scenario: &TestScenario) -> String {
433    let mut test_code = String::new();
434
435    // Convert type names to valid Rust type names (remove underscores)
436    let target_type = to_rust_type_name(&scenario.target_type);
437
438    test_code.push_str(&format!("    #[test]\n    fn {}() {{\n", scenario.name));
439
440    match &scenario.behavior {
441        TestBehavior::RoundTrip { json, assertions } => {
442            // Generate round-trip test
443            test_code.push_str(&format!("        let json_str = r#\"{json}\"#;\n"));
444            test_code.push_str(&format!(
445                "        let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
446            ));
447
448            // Add assertions
449            for (field, expected) in assertions {
450                test_code.push_str(&format!(
451                    "        assert_eq!(parsed.{field}, serde_json::json!({expected}));\n"
452                ));
453            }
454
455            // Test serialization round-trip
456            test_code
457                .push_str("        let serialized = serde_json::to_string(&parsed).unwrap();\n");
458            test_code.push_str(&format!(
459                "        let round_trip: {target_type} = serde_json::from_str(&serialized).unwrap();\n"
460            ));
461            test_code.push_str("        assert_eq!(serde_json::to_value(parsed).unwrap(), serde_json::to_value(round_trip).unwrap());\n");
462        }
463
464        TestBehavior::Construction { fields, assertions } => {
465            // Generate construction test
466            test_code.push_str(&format!("        let instance = {target_type} {{\n"));
467
468            for (field_name, field_value) in fields {
469                test_code.push_str(&format!(
470                    "            {}: {},\n",
471                    field_name,
472                    field_value_to_code(field_value)
473                ));
474            }
475
476            test_code.push_str("        };\n\n");
477
478            // Add assertions
479            for assertion in assertions {
480                match assertion {
481                    ConstructionAssertion::JsonContains(text) => {
482                        test_code.push_str(
483                            "        let json = serde_json::to_string(&instance).unwrap();\n",
484                        );
485                        test_code
486                            .push_str(&format!("        assert!(json.contains(\"{text}\"));\n"));
487                    }
488                    ConstructionAssertion::FieldEquals(field, value) => {
489                        test_code.push_str(&format!(
490                            "        assert_eq!(instance.{field}, serde_json::json!({value}));\n"
491                        ));
492                    }
493                    ConstructionAssertion::CanSerialize => {
494                        test_code.push_str(
495                            "        let _json = serde_json::to_string(&instance).unwrap();\n",
496                        );
497                    }
498                }
499            }
500        }
501
502        TestBehavior::CompileOnly => {
503            test_code.push_str(&format!(
504                "        // Type {target_type} exists and compiles\n"
505            ));
506            test_code.push_str(&format!("        let _: Option<{target_type}> = None;\n"));
507        }
508
509        TestBehavior::EnumMatch {
510            json,
511            expected_variant,
512            variant_assertions,
513        } => {
514            test_code.push_str(&format!("        let json_str = r#\"{json}\"#;\n"));
515            test_code.push_str(&format!(
516                "        let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
517            ));
518            test_code.push_str("        match parsed {\n");
519
520            // Generate match pattern for struct variants with named fields
521            if variant_assertions.is_empty() {
522                // No assertions needed, just check the variant
523                test_code.push_str(&format!(
524                    "            {target_type}::{expected_variant} {{ .. }} => {{\n"
525                ));
526                test_code.push_str("                // Variant matched successfully\n");
527            } else {
528                // Has assertions - generate struct variant pattern with field bindings
529                let field_bindings: Vec<String> = variant_assertions
530                    .iter()
531                    .map(|(field, _)| field.clone())
532                    .collect();
533
534                if field_bindings.is_empty() {
535                    test_code.push_str(&format!(
536                        "            {target_type}::{expected_variant} {{ .. }} => {{\n"
537                    ));
538                } else {
539                    // Always add .. to allow partial matching
540                    let bindings = field_bindings.join(", ");
541                    test_code.push_str(&format!(
542                        "            {target_type}::{expected_variant} {{ {bindings}, .. }} => {{\n"
543                    ));
544
545                    // Add assertions on the fields
546                    for (field, expected) in variant_assertions {
547                        test_code.push_str(&format!(
548                            "                assert_eq!({field}, serde_json::json!({expected}));\n"
549                        ));
550                    }
551                }
552            }
553
554            test_code.push_str("            }\n");
555            test_code.push_str(&format!(
556                "            _ => panic!(\"Expected {expected_variant} variant\"),\n"
557            ));
558            test_code.push_str("        }\n");
559        }
560    }
561
562    test_code.push_str("    }\n");
563    test_code
564}
565
566/// Convert a FieldValue to Rust code
567fn field_value_to_code(value: &FieldValue) -> String {
568    match value {
569        FieldValue::String(s) => format!("\"{s}\".to_string()"),
570        FieldValue::Integer(i) => i.to_string(),
571        FieldValue::Float(f) => f.to_string(),
572        FieldValue::Boolean(b) => b.to_string(),
573        FieldValue::Null => "None".to_string(),
574        FieldValue::EnumVariant(variant) => variant.clone(),
575        FieldValue::Array(values) => {
576            let items: Vec<String> = values.iter().map(field_value_to_code).collect();
577            format!("vec![{}]", items.join(", "))
578        }
579        FieldValue::Struct(type_name, fields) => {
580            let mut code = format!("{type_name} {{\n");
581            for (field_name, field_value) in fields {
582                code.push_str(&format!(
583                    "                {}: {},\n",
584                    field_name,
585                    field_value_to_code(field_value)
586                ));
587            }
588            code.push_str("            }");
589            code
590        }
591    }
592}
593
594/// Helper to create a minimal OpenAPI spec for testing
595pub fn minimal_spec(schemas: Value) -> Value {
596    serde_json::json!({
597        "openapi": "3.0.0",
598        "info": {
599            "title": "Test API",
600            "version": "1.0.0"
601        },
602        "components": {
603            "schemas": schemas
604        }
605    })
606}
607
608/// Assert that the test compiled successfully
609pub fn assert_compilation_success(result: &GenerationTestResult) {
610    if let Some(output) = &result.compilation_output {
611        if !output.status.success() {
612            let stderr = String::from_utf8_lossy(&output.stderr);
613            let stdout = String::from_utf8_lossy(&output.stdout);
614            panic!("Compilation failed!\nSTDERR:\n{stderr}\nSTDOUT:\n{stdout}");
615        }
616    } else {
617        panic!("No compilation was run");
618    }
619}
620
621/// Find a type definition in the generated files
622pub fn find_type_definition(result: &GenerationTestResult, type_name: &str) -> Option<String> {
623    for file in &result.files {
624        for line in file.content.lines() {
625            if (line.contains(&format!("struct {type_name}"))
626                || line.contains(&format!("enum {type_name}"))
627                || line.contains(&format!("type {type_name} =")))
628                && !line.trim().starts_with("//")
629            {
630                return Some(line.to_string());
631            }
632        }
633    }
634    None
635}
636
637/// Assert that a type name doesn't contain underscores
638pub fn assert_no_underscores_in_type_name(type_definition: &str) {
639    if let Some(type_name) = extract_type_name(type_definition) {
640        assert!(
641            !type_name.contains('_'),
642            "Type name '{type_name}' should not contain underscores"
643        );
644    }
645}
646
/// Extract the declared name from a definition line such as `pub struct Foo<T> {`.
///
/// The name is the token after the `struct` / `enum` / `type` keyword, so leading
/// visibility (`pub`, `pub(crate)`) is skipped. Previously the second token was always
/// used, which yielded `struct` for every `pub` item. Without a keyword, the second
/// token is used as before. Generics, `=`, `{`, `(`, `;` and `:` end the name.
/// Returns `None` when there is no such token or it is empty.
fn extract_type_name(type_def: &str) -> Option<String> {
    let tokens: Vec<&str> = type_def.split_whitespace().collect();
    let name_index = tokens
        .iter()
        .position(|token| matches!(*token, "struct" | "enum" | "type"))
        .map_or(1, |keyword_index| keyword_index + 1);
    let raw = *tokens.get(name_index)?;
    let name = raw
        .split(|c: char| matches!(c, '<' | '=' | '{' | '(' | ';' | ':'))
        .next()?
        .trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
656
/// Convert a schema name to a PascalCase Rust type name.
///
/// ASCII letters and digits are kept. Any other character, including `_`, `-`, `.`,
/// spaces and non-ASCII letters, is dropped and starts a new word. A word's first
/// lowercase letter is upper-cased; existing uppercase letters are left alone.
fn to_rust_type_name(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut at_word_start = true;

    for ch in s.chars() {
        if ch.is_ascii_lowercase() {
            out.push(if at_word_start { ch.to_ascii_uppercase() } else { ch });
            at_word_start = false;
        } else if ch.is_ascii_uppercase() || ch.is_ascii_digit() {
            out.push(ch);
            at_word_start = false;
        } else {
            at_word_start = true;
        }
    }

    out
}