1use crate::{CodeGenerator, GeneratedFile, GeneratorConfig, SchemaAnalyzer};
2use serde_json::Value;
3use std::env;
4use std::fs;
5use std::path::PathBuf;
6use std::process::Command;
7use tempfile::TempDir;
8
/// Outcome of [`run_generation_test`]: where the scratch crate was written, which
/// files were generated, and (optionally) the captured output of the cargo run.
pub struct GenerationTestResult {
    /// Owns the temporary directory holding the scratch crate. Keep the result
    /// alive while inspecting files: `TempDir` removes the directory when dropped.
    pub temp_dir: TempDir,
    /// Root of the scratch crate (`<temp_dir>/<test name>`), containing `Cargo.toml`.
    pub project_dir: PathBuf,
    /// `src/generated` inside the scratch crate, where generated files are written.
    pub generated_src_dir: PathBuf,
    /// Files produced by the generator; their paths are relative to `generated_src_dir`.
    pub files: Vec<GeneratedFile>,
    /// Output of `cargo test` / `cargo check`, or `None` if compilation was disabled.
    pub compilation_output: Option<std::process::Output>,
    /// Path of the scratch crate's `Cargo.toml`.
    pub cargo_toml_path: PathBuf,
}
24
25impl GenerationTestResult {
26 pub fn read_file(&self, filename: &str) -> std::io::Result<String> {
28 let path = self.generated_src_dir.join(filename);
29 fs::read_to_string(path)
30 }
31
32 pub fn compiled_successfully(&self) -> bool {
34 self.compilation_output
35 .as_ref()
36 .map(|o| o.status.success())
37 .unwrap_or(false)
38 }
39
40 pub fn compilation_errors(&self) -> String {
42 self.compilation_output
43 .as_ref()
44 .map(|o| String::from_utf8_lossy(&o.stderr).to_string())
45 .unwrap_or_default()
46 }
47
48 pub fn test_output(&self) -> Option<String> {
50 self.compilation_output
51 .as_ref()
52 .map(|o| String::from_utf8_lossy(&o.stdout).to_string())
53 }
54}
55
/// Describes one end-to-end generation test, executed by [`run_generation_test`].
///
/// Build it with [`GenerationTest::new`] and the `with_*` / `run_tests` methods.
pub struct GenerationTest {
    /// Test name. Used as the scratch crate's directory name and, with `-`
    /// replaced by `_`, as its package name.
    pub name: String,
    /// OpenAPI document handed to `SchemaAnalyzer`.
    pub spec: Value,
    /// Optional overrides for selected generator settings.
    pub config_overrides: Option<GeneratorConfigOverrides>,
    /// Scenarios, each rendered as an extra `#[test]` function in the scratch crate.
    pub test_scenarios: Vec<TestScenario>,
    /// When compiling, run `cargo test -- --nocapture` (`true`) instead of `cargo check`.
    pub run_tests: bool,
    /// Whether cargo is invoked on the scratch crate at all.
    pub compile: bool,
    /// Extra `(name, value)` dependency entries; `value` is inserted verbatim as TOML.
    pub extra_dependencies: Vec<(String, String)>,
}
73
/// One generated `#[test]` function added to the scratch crate's `lib.rs`.
#[derive(Clone)]
pub struct TestScenario {
    /// Name of the generated test function. It must be a valid Rust identifier,
    /// distinct from other tests in the module (including `test_compilation`).
    pub name: String,
    /// Schema name of the type under test. It is converted to PascalCase by
    /// `to_rust_type_name` before being used in generated code.
    pub target_type: String,
    /// What the generated test does with the type.
    pub behavior: TestBehavior,
}
84
/// The body generated for a [`TestScenario`] (see `generate_test_from_scenario`).
#[derive(Clone)]
pub enum TestBehavior {
    /// Deserialize `json` into the target type and check `assertions`. Then
    /// serialize, re-deserialize, and require both `serde_json::Value` forms to be equal.
    RoundTrip {
        /// JSON document embedded in the generated test as a string literal.
        json: Value,
        /// `(field, expected)` pairs, each emitted as
        /// `assert_eq!(parsed.<field>, serde_json::json!(<expected>))`.
        assertions: Vec<(String, Value)>,
    },
    /// Build the target type with a struct literal, then check `assertions`.
    Construction {
        /// `(field, value)` pairs of the struct literal, rendered by `field_value_to_code`.
        /// The literal has no `..` base, so every field of the type must be listed.
        fields: Vec<(String, FieldValue)>,
        /// Checks run against the constructed `instance`.
        assertions: Vec<ConstructionAssertion>,
    },
    /// Only check that the target type name resolves (`let _: Option<T> = None;`).
    CompileOnly,
    /// Deserialize `json` into the target enum and require the `expected_variant` variant.
    EnumMatch {
        /// JSON document embedded in the generated test as a string literal.
        json: Value,
        /// Variant name, emitted as `<Type>::<expected_variant> { .. }`.
        expected_variant: String,
        /// `(field, expected)` pairs. Each field is bound in the match pattern and
        /// checked with `assert_eq!(<field>, serde_json::json!(<expected>))`.
        variant_assertions: Vec<(String, Value)>,
    },
}
115
/// A value placed in a generated struct literal (rendered by `field_value_to_code`).
#[derive(Clone)]
pub enum FieldValue {
    /// Rendered as an owned `String` (`"...".to_string()`).
    String(String),
    /// Rendered as an integer literal.
    Integer(i64),
    /// Rendered from the float's textual form.
    Float(f64),
    /// Rendered as `true` / `false`.
    Boolean(bool),
    /// Rendered as `None` (for `Option` fields).
    Null,
    /// Inserted verbatim, so it should be a full path expression such as `Status::Active`.
    EnumVariant(String),
    /// Rendered as `vec![...]` of the rendered elements.
    Array(Vec<FieldValue>),
    /// Rendered as a nested struct literal `Name { field: value, ... }`.
    Struct(String, Vec<(String, FieldValue)>),
}
136
/// A check on the `instance` built by [`TestBehavior::Construction`].
#[derive(Clone)]
pub enum ConstructionAssertion {
    /// The JSON produced by `serde_json::to_string(&instance)` must contain this text.
    JsonContains(String),
    /// `instance.<field>` must equal `serde_json::json!(<value>)`.
    FieldEquals(String, Value),
    /// `serde_json::to_string(&instance)` must succeed.
    CanSerialize,
}
147
/// Optional overrides applied on top of the fixed generator settings used by
/// [`run_generation_test`]. A `None` field keeps that setting's default.
///
/// Derives `Default`, so only the interesting fields need to be spelled out:
/// `GeneratorConfigOverrides { module_name: Some("api".into()), ..Default::default() }`.
#[derive(Clone, Debug, Default)]
pub struct GeneratorConfigOverrides {
    /// Module name passed to the generator; defaults to `"generated"`.
    pub module_name: Option<String>,
    /// Whether to generate the SSE client; defaults to `false`.
    pub enable_sse_client: Option<bool>,
    /// Whether to generate the async client; defaults to `false`.
    pub enable_async_client: Option<bool>,
}
154
155impl GenerationTest {
156 pub fn new(name: impl Into<String>, spec: Value) -> Self {
158 Self {
159 name: name.into(),
160 spec,
161 ..Default::default()
162 }
163 }
164
165 pub fn with_compilation(mut self) -> Self {
167 self.compile = true;
168 self
169 }
170
171 pub fn with_env_compilation(mut self) -> Self {
173 self.compile = should_compile_tests();
174 self
175 }
176
177 pub fn with_scenario(mut self, scenario: TestScenario) -> Self {
179 self.test_scenarios.push(scenario);
180 self
181 }
182
183 pub fn run_tests(mut self, run: bool) -> Self {
185 self.run_tests = run;
186 self
187 }
188}
189
190impl Default for GenerationTest {
191 fn default() -> Self {
192 Self {
193 name: "test".to_string(),
194 spec: Value::Null,
195 config_overrides: None,
196 test_scenarios: vec![],
197 run_tests: true,
198 extra_dependencies: vec![],
199 compile: false, }
201 }
202}
203
/// Reports whether generated crates should be compiled, based on the
/// `OPENAPI_GEN_COMPILE_TESTS` environment variable.
///
/// Returns `true` for exactly `1`, or for `true` in any letter case. Any other
/// value, an unset variable, or a non-UTF-8 value yields `false`.
pub fn should_compile_tests() -> bool {
    match env::var("OPENAPI_GEN_COMPILE_TESTS") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
210
211pub fn test_generation(name: &str, spec: Value) -> Result<String, Box<dyn std::error::Error>> {
214 let mut analyzer = SchemaAnalyzer::new(spec)?;
216 let mut analysis = analyzer.analyze()?;
217
218 let config = GeneratorConfig {
220 module_name: name.to_string(),
221 ..Default::default()
222 };
223 let generator = CodeGenerator::new(config);
224 let generated_code = generator.generate(&mut analysis)?;
225
226 insta::assert_snapshot!(name, &generated_code);
230
231 Ok(generated_code)
232}
233
234pub fn run_generation_test(
236 test: GenerationTest,
237) -> Result<GenerationTestResult, Box<dyn std::error::Error>> {
238 let temp_dir = TempDir::new()?;
240 let project_dir = temp_dir.path().join(&test.name);
241 fs::create_dir(&project_dir)?;
242
243 let src_dir = project_dir.join("src");
245 fs::create_dir(&src_dir)?;
246 let generated_src_dir = src_dir.join("generated");
247 fs::create_dir(&generated_src_dir)?;
248
249 let mut analyzer = SchemaAnalyzer::new(test.spec.clone())?;
251 let analysis = analyzer.analyze()?;
252
253 let config = GeneratorConfig {
255 spec_path: PathBuf::from("test.json"), output_dir: generated_src_dir.clone(),
257 module_name: test
258 .config_overrides
259 .as_ref()
260 .and_then(|o| o.module_name.clone())
261 .unwrap_or_else(|| "generated".to_string()),
262 enable_sse_client: test
263 .config_overrides
264 .as_ref()
265 .and_then(|o| o.enable_sse_client)
266 .unwrap_or(false),
267 enable_async_client: test
268 .config_overrides
269 .as_ref()
270 .and_then(|o| o.enable_async_client)
271 .unwrap_or(false),
272 enable_specta: false,
273 type_mappings: {
274 let mut mappings = std::collections::BTreeMap::new();
275 mappings.insert("integer".to_string(), "i64".to_string());
276 mappings.insert("number".to_string(), "f64".to_string());
277 mappings.insert("string".to_string(), "String".to_string());
278 mappings.insert("boolean".to_string(), "bool".to_string());
279 mappings
280 },
281 streaming_config: None,
282 nullable_field_overrides: std::collections::BTreeMap::new(),
283 schema_extensions: vec![],
284 http_client_config: None,
285 retry_config: None,
286 tracing_enabled: true,
287 auth_config: None,
288 enable_registry: false,
289 registry_only: false,
290 };
291
292 let generator = CodeGenerator::new(config);
294 let mut analysis_mut = analysis;
295 let types_content = generator.generate(&mut analysis_mut)?;
296
297 let files = vec![GeneratedFile {
299 path: "types.rs".into(),
300 content: types_content.clone(),
301 }];
302
303 for file in &files {
305 let dest_path = generated_src_dir.join(&file.path);
306 if let Some(parent) = dest_path.parent() {
307 fs::create_dir_all(parent)?;
308 }
309 fs::write(dest_path, &file.content)?;
310 }
311
312 let mut dependencies = vec![
314 ("serde", r#"{ version = "1.0", features = ["derive"] }"#),
315 ("serde_json", r#""1.0""#),
316 ("async-trait", r#""0.1""#),
317 (
318 "reqwest",
319 r#"{ version = "0.12", features = ["json", "stream"] }"#,
320 ),
321 ("futures-util", r#""0.3""#),
322 ("tokio", r#"{ version = "1.0", features = ["full"] }"#),
323 ("tracing", r#""0.1""#),
324 ];
325
326 for (name, version) in &test.extra_dependencies {
328 dependencies.push((name.as_str(), version.as_str()));
329 }
330
331 let deps_str = dependencies
332 .iter()
333 .map(|(name, ver)| format!("{name} = {ver}"))
334 .collect::<Vec<_>>()
335 .join("\n");
336
337 let cargo_toml = format!(
338 r#"[package]
339name = "{}"
340version = "0.1.0"
341edition = "2021"
342
343[dependencies]
344{}
345"#,
346 test.name.replace('-', "_"),
347 deps_str
348 );
349
350 let cargo_toml_path = project_dir.join("Cargo.toml");
351 fs::write(&cargo_toml_path, cargo_toml)?;
352
353 fs::write(generated_src_dir.join("mod.rs"), "pub mod types;\n")?;
355
356 let lib_rs = create_lib_rs(&test, &files);
358 fs::write(src_dir.join("lib.rs"), lib_rs)?;
359
360 let compilation_output = if !test.compile {
362 None
363 } else if test.run_tests {
364 Some(
366 Command::new("cargo")
367 .arg("test")
368 .arg("--")
369 .arg("--nocapture")
370 .current_dir(&project_dir)
371 .env("RUST_BACKTRACE", "1")
372 .output()?,
373 )
374 } else {
375 Some(
377 Command::new("cargo")
378 .arg("check")
379 .current_dir(&project_dir)
380 .env("RUST_BACKTRACE", "1")
381 .output()?,
382 )
383 };
384
385 if let Some(ref output) = compilation_output {
387 if !output.status.success() {
388 if let Ok(types_content) = fs::read_to_string(generated_src_dir.join("types.rs")) {
389 eprintln!("Generated types.rs:\n{types_content}");
390 }
391 }
392 }
393
394 Ok(GenerationTestResult {
395 temp_dir,
396 project_dir,
397 generated_src_dir,
398 files,
399 compilation_output,
400 cargo_toml_path,
401 })
402}
403
404fn create_lib_rs(test: &GenerationTest, _generated_files: &[GeneratedFile]) -> String {
405 let mut lib_content = String::from("pub mod generated;\n\n");
406
407 lib_content.push_str("#[cfg(test)]\nmod tests {\n");
409 lib_content.push_str(" use super::generated::types::*;\n");
410 lib_content.push_str(" use serde_json;\n\n");
411
412 lib_content.push_str(" #[test]\n");
414 lib_content.push_str(" fn test_compilation() {\n");
415 lib_content.push_str(&format!(
416 " // Generated types compile for test: {}\n",
417 test.name
418 ));
419 lib_content.push_str(" }\n\n");
420
421 for scenario in &test.test_scenarios {
423 lib_content.push_str(&generate_test_from_scenario(scenario));
424 lib_content.push('\n');
425 }
426
427 lib_content.push_str("}\n");
428 lib_content
429}
430
431fn generate_test_from_scenario(scenario: &TestScenario) -> String {
433 let mut test_code = String::new();
434
435 let target_type = to_rust_type_name(&scenario.target_type);
437
438 test_code.push_str(&format!(" #[test]\n fn {}() {{\n", scenario.name));
439
440 match &scenario.behavior {
441 TestBehavior::RoundTrip { json, assertions } => {
442 test_code.push_str(&format!(" let json_str = r#\"{json}\"#;\n"));
444 test_code.push_str(&format!(
445 " let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
446 ));
447
448 for (field, expected) in assertions {
450 test_code.push_str(&format!(
451 " assert_eq!(parsed.{field}, serde_json::json!({expected}));\n"
452 ));
453 }
454
455 test_code
457 .push_str(" let serialized = serde_json::to_string(&parsed).unwrap();\n");
458 test_code.push_str(&format!(
459 " let round_trip: {target_type} = serde_json::from_str(&serialized).unwrap();\n"
460 ));
461 test_code.push_str(" assert_eq!(serde_json::to_value(parsed).unwrap(), serde_json::to_value(round_trip).unwrap());\n");
462 }
463
464 TestBehavior::Construction { fields, assertions } => {
465 test_code.push_str(&format!(" let instance = {target_type} {{\n"));
467
468 for (field_name, field_value) in fields {
469 test_code.push_str(&format!(
470 " {}: {},\n",
471 field_name,
472 field_value_to_code(field_value)
473 ));
474 }
475
476 test_code.push_str(" };\n\n");
477
478 for assertion in assertions {
480 match assertion {
481 ConstructionAssertion::JsonContains(text) => {
482 test_code.push_str(
483 " let json = serde_json::to_string(&instance).unwrap();\n",
484 );
485 test_code
486 .push_str(&format!(" assert!(json.contains(\"{text}\"));\n"));
487 }
488 ConstructionAssertion::FieldEquals(field, value) => {
489 test_code.push_str(&format!(
490 " assert_eq!(instance.{field}, serde_json::json!({value}));\n"
491 ));
492 }
493 ConstructionAssertion::CanSerialize => {
494 test_code.push_str(
495 " let _json = serde_json::to_string(&instance).unwrap();\n",
496 );
497 }
498 }
499 }
500 }
501
502 TestBehavior::CompileOnly => {
503 test_code.push_str(&format!(
504 " // Type {target_type} exists and compiles\n"
505 ));
506 test_code.push_str(&format!(" let _: Option<{target_type}> = None;\n"));
507 }
508
509 TestBehavior::EnumMatch {
510 json,
511 expected_variant,
512 variant_assertions,
513 } => {
514 test_code.push_str(&format!(" let json_str = r#\"{json}\"#;\n"));
515 test_code.push_str(&format!(
516 " let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
517 ));
518 test_code.push_str(" match parsed {\n");
519
520 if variant_assertions.is_empty() {
522 test_code.push_str(&format!(
524 " {target_type}::{expected_variant} {{ .. }} => {{\n"
525 ));
526 test_code.push_str(" // Variant matched successfully\n");
527 } else {
528 let field_bindings: Vec<String> = variant_assertions
530 .iter()
531 .map(|(field, _)| field.clone())
532 .collect();
533
534 if field_bindings.is_empty() {
535 test_code.push_str(&format!(
536 " {target_type}::{expected_variant} {{ .. }} => {{\n"
537 ));
538 } else {
539 let bindings = field_bindings.join(", ");
541 test_code.push_str(&format!(
542 " {target_type}::{expected_variant} {{ {bindings}, .. }} => {{\n"
543 ));
544
545 for (field, expected) in variant_assertions {
547 test_code.push_str(&format!(
548 " assert_eq!({field}, serde_json::json!({expected}));\n"
549 ));
550 }
551 }
552 }
553
554 test_code.push_str(" }\n");
555 test_code.push_str(&format!(
556 " _ => panic!(\"Expected {expected_variant} variant\"),\n"
557 ));
558 test_code.push_str(" }\n");
559 }
560 }
561
562 test_code.push_str(" }\n");
563 test_code
564}
565
566fn field_value_to_code(value: &FieldValue) -> String {
568 match value {
569 FieldValue::String(s) => format!("\"{s}\".to_string()"),
570 FieldValue::Integer(i) => i.to_string(),
571 FieldValue::Float(f) => f.to_string(),
572 FieldValue::Boolean(b) => b.to_string(),
573 FieldValue::Null => "None".to_string(),
574 FieldValue::EnumVariant(variant) => variant.clone(),
575 FieldValue::Array(values) => {
576 let items: Vec<String> = values.iter().map(field_value_to_code).collect();
577 format!("vec![{}]", items.join(", "))
578 }
579 FieldValue::Struct(type_name, fields) => {
580 let mut code = format!("{type_name} {{\n");
581 for (field_name, field_value) in fields {
582 code.push_str(&format!(
583 " {}: {},\n",
584 field_name,
585 field_value_to_code(field_value)
586 ));
587 }
588 code.push_str(" }");
589 code
590 }
591 }
592}
593
594pub fn minimal_spec(schemas: Value) -> Value {
596 serde_json::json!({
597 "openapi": "3.0.0",
598 "info": {
599 "title": "Test API",
600 "version": "1.0.0"
601 },
602 "components": {
603 "schemas": schemas
604 }
605 })
606}
607
608pub fn assert_compilation_success(result: &GenerationTestResult) {
610 if let Some(output) = &result.compilation_output {
611 if !output.status.success() {
612 let stderr = String::from_utf8_lossy(&output.stderr);
613 let stdout = String::from_utf8_lossy(&output.stdout);
614 panic!("Compilation failed!\nSTDERR:\n{stderr}\nSTDOUT:\n{stdout}");
615 }
616 } else {
617 panic!("No compilation was run");
618 }
619}
620
621pub fn find_type_definition(result: &GenerationTestResult, type_name: &str) -> Option<String> {
623 for file in &result.files {
624 for line in file.content.lines() {
625 if (line.contains(&format!("struct {type_name}"))
626 || line.contains(&format!("enum {type_name}"))
627 || line.contains(&format!("type {type_name} =")))
628 && !line.trim().starts_with("//")
629 {
630 return Some(line.to_string());
631 }
632 }
633 }
634 None
635}
636
637pub fn assert_no_underscores_in_type_name(type_definition: &str) {
639 if let Some(type_name) = extract_type_name(type_definition) {
640 assert!(
641 !type_name.contains('_'),
642 "Type name '{type_name}' should not contain underscores"
643 );
644 }
645}
646
/// Extracts the declared name from a type-definition line such as
/// `pub struct Foo<T> {`, `enum Bar {`, or `pub type Baz = u8;`.
///
/// The token after the first `struct` / `enum` / `type` keyword is used, so
/// leading visibility modifiers (`pub`, `pub(crate)`) are skipped. Without a
/// keyword, the second whitespace-separated token is used instead. The name is
/// cut at the first character that cannot appear in an identifier (`<`, `=`,
/// `{`, `(`, `;`, ...). Returns `None` when there is no candidate token.
fn extract_type_name(type_def: &str) -> Option<String> {
    let tokens: Vec<&str> = type_def.split_whitespace().collect();
    let candidate = tokens
        .iter()
        .position(|token| matches!(*token, "struct" | "enum" | "type"))
        .and_then(|keyword_index| tokens.get(keyword_index + 1))
        .or_else(|| tokens.get(1))?;
    let name = candidate
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .next()?;
    Some(name.to_string())
}
656
/// Converts a schema name to a PascalCase Rust type name.
///
/// Any character other than an ASCII letter or digit acts as a word separator and
/// is dropped. The first character of each word is ASCII-uppercased and the rest
/// are kept as-is, so `user_profile` becomes `UserProfile`, `HTTPServer` stays
/// `HTTPServer`, and a leading digit (`2fa`) is left unchanged.
fn to_rust_type_name(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split(|c: char| !c.is_ascii_alphanumeric()) {
        // `split` yields empty words for leading, trailing or repeated separators.
        if word.is_empty() {
            continue;
        }
        // Every char in `word` is ASCII, so byte index 1 is a char boundary.
        let (head, tail) = word.split_at(1);
        pascal.push_str(&head.to_ascii_uppercase());
        pascal.push_str(tail);
    }
    pascal
}