1use crate::{CodeGenerator, GeneratedFile, GeneratorConfig, SchemaAnalyzer};
2use serde_json::Value;
3use std::env;
4use std::fs;
5use std::path::PathBuf;
6use std::process::Command;
7use tempfile::TempDir;
8
/// Everything produced by [`run_generation_test`]: the scratch Cargo project on disk, the
/// generated files, and (optionally) the captured output of building it with `cargo`.
pub struct GenerationTestResult {
    /// Owns the scratch directory that holds the project; keep this value alive while
    /// inspecting files (`tempfile::TempDir` removes the directory when it is dropped).
    pub temp_dir: TempDir,
    /// Root of the generated Cargo project (`<temp_dir>/<test name>`).
    pub project_dir: PathBuf,
    /// The project's `src/generated` directory, where generated sources are written.
    pub generated_src_dir: PathBuf,
    /// The generated files, with paths relative to `generated_src_dir`.
    pub files: Vec<GeneratedFile>,
    /// Output of `cargo test` / `cargo check`, or `None` when compilation was not requested.
    pub compilation_output: Option<std::process::Output>,
    /// Path of the project's generated `Cargo.toml`.
    pub cargo_toml_path: PathBuf,
}
24
25impl GenerationTestResult {
26 pub fn read_file(&self, filename: &str) -> std::io::Result<String> {
28 let path = self.generated_src_dir.join(filename);
29 fs::read_to_string(path)
30 }
31
32 pub fn compiled_successfully(&self) -> bool {
34 self.compilation_output
35 .as_ref()
36 .map(|o| o.status.success())
37 .unwrap_or(false)
38 }
39
40 pub fn compilation_errors(&self) -> String {
42 self.compilation_output
43 .as_ref()
44 .map(|o| String::from_utf8_lossy(&o.stderr).to_string())
45 .unwrap_or_default()
46 }
47
48 pub fn test_output(&self) -> Option<String> {
50 self.compilation_output
51 .as_ref()
52 .map(|o| String::from_utf8_lossy(&o.stdout).to_string())
53 }
54}
55
/// Description of one code-generation test, consumed by [`run_generation_test`].
///
/// Usually built with [`GenerationTest::new`] plus the `with_*` / `run_tests` builder methods.
pub struct GenerationTest {
    /// Test name; used as the scratch project's directory name and, with `-` replaced by `_`,
    /// as its Cargo package name.
    pub name: String,
    /// OpenAPI document handed to `SchemaAnalyzer`.
    pub spec: Value,
    /// Optional tweaks to the generator configuration (`None` keeps the harness defaults).
    pub config_overrides: Option<GeneratorConfigOverrides>,
    /// Scenarios, each rendered as a `#[test]` function in the project's `lib.rs`.
    pub test_scenarios: Vec<TestScenario>,
    /// When compiling, run `cargo test` (`true`) rather than `cargo check` (`false`).
    pub run_tests: bool,
    /// Whether cargo is invoked on the generated project at all.
    pub compile: bool,
    /// Extra `[dependencies]` entries as `(crate name, raw TOML value)`,
    /// e.g. `("uuid", "\"1\"")`.
    pub extra_dependencies: Vec<(String, String)>,
}
73
/// One generated `#[test]` function that exercises a single generated type.
#[derive(Clone)]
pub struct TestScenario {
    /// Name of the generated test function; it is emitted as `fn <name>()`, so it must be a
    /// valid Rust identifier.
    pub name: String,
    /// Schema/type name; converted to UpperCamelCase (`to_rust_type_name`) before use.
    pub target_type: String,
    /// What the generated test does with the target type.
    pub behavior: TestBehavior,
}
84
/// The body of a scenario's generated test.
#[derive(Clone)]
pub enum TestBehavior {
    /// Parses `json` into the target type and checks each assertion. It then re-serializes,
    /// re-parses, and asserts that both values serialize to the same JSON.
    RoundTrip {
        /// Input document deserialized into the target type.
        json: Value,
        /// `(field, expected)` pairs; `parsed.<field>` is compared with `json!(expected)`.
        assertions: Vec<(String, Value)>,
    },
    /// Builds the target type with a struct literal, then runs each assertion on it.
    Construction {
        /// `(field, value)` pairs rendered into the struct literal.
        fields: Vec<(String, FieldValue)>,
        /// Checks applied to the constructed instance.
        assertions: Vec<ConstructionAssertion>,
    },
    /// Only checks that the type exists, via `let _: Option<Type> = None;`.
    CompileOnly,
    /// Parses `json` into the target (enum) type and matches it against `expected_variant`.
    /// Any other variant panics.
    EnumMatch {
        /// Input document deserialized into the target type.
        json: Value,
        /// Variant name the parsed value must have.
        expected_variant: String,
        /// `(field, expected)` pairs bound from the matched variant and compared with
        /// `json!(expected)`.
        variant_assertions: Vec<(String, Value)>,
    },
}
115
/// A literal value used to build struct fields in [`TestBehavior::Construction`] scenarios.
///
/// The harness renders each variant as Rust source. For example, `String` becomes a string
/// literal with `.to_string()`, `Null` becomes `None`, and `Array` becomes `vec![...]`.
/// `Debug` and `PartialEq` are derived so values can be compared and printed in assertions.
#[derive(Clone, Debug, PartialEq)]
pub enum FieldValue {
    /// A string field value.
    String(String),
    /// An integer literal.
    Integer(i64),
    /// A floating-point literal.
    Float(f64),
    /// A boolean literal.
    Boolean(bool),
    /// An absent optional value, rendered as `None`.
    Null,
    /// An enum variant path emitted verbatim, e.g. `"Status::Active"`.
    EnumVariant(String),
    /// A list of values rendered as `vec![...]`.
    Array(Vec<FieldValue>),
    /// A nested struct literal: type name plus `(field, value)` pairs.
    Struct(String, Vec<(String, FieldValue)>),
}
136
/// A check applied to the instance built by [`TestBehavior::Construction`].
#[derive(Clone)]
pub enum ConstructionAssertion {
    /// The instance's `serde_json::to_string` output must contain this substring.
    JsonContains(String),
    /// `instance.<field>` must equal `serde_json::json!(<value>)`.
    FieldEquals(String, Value),
    /// Serializing the instance with `serde_json::to_string` must succeed.
    CanSerialize,
}
147
/// Optional overrides for the generator configuration used by [`run_generation_test`].
///
/// Each `None` keeps the harness default: `"generated"` for the module name and `false` for
/// both client toggles. `Default` means "no overrides", so a caller can set a single field
/// with `..Default::default()`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GeneratorConfigOverrides {
    /// Module name passed to the generator.
    pub module_name: Option<String>,
    /// Whether to generate the SSE client.
    pub enable_sse_client: Option<bool>,
    /// Whether to generate the async client.
    pub enable_async_client: Option<bool>,
}
154
155impl GenerationTest {
156 pub fn new(name: impl Into<String>, spec: Value) -> Self {
158 Self {
159 name: name.into(),
160 spec,
161 ..Default::default()
162 }
163 }
164
165 pub fn with_compilation(mut self) -> Self {
167 self.compile = true;
168 self
169 }
170
171 pub fn with_env_compilation(mut self) -> Self {
173 self.compile = should_compile_tests();
174 self
175 }
176
177 pub fn with_scenario(mut self, scenario: TestScenario) -> Self {
179 self.test_scenarios.push(scenario);
180 self
181 }
182
183 pub fn run_tests(mut self, run: bool) -> Self {
185 self.run_tests = run;
186 self
187 }
188}
189
190impl Default for GenerationTest {
191 fn default() -> Self {
192 Self {
193 name: "test".to_string(),
194 spec: Value::Null,
195 config_overrides: None,
196 test_scenarios: vec![],
197 run_tests: true,
198 extra_dependencies: vec![],
199 compile: false, }
201 }
202}
203
/// Reports whether generated projects should actually be compiled.
///
/// Reads `OPENAPI_GEN_COMPILE_TESTS`. The values `1` and `true` (case-insensitive) enable
/// compilation. Any other value, an unset variable, or a non-Unicode value disables it.
pub fn should_compile_tests() -> bool {
    match env::var("OPENAPI_GEN_COMPILE_TESTS") {
        Ok(raw) => raw == "1" || raw.to_lowercase() == "true",
        Err(_) => false,
    }
}
210
211pub fn test_generation(name: &str, spec: Value) -> Result<String, Box<dyn std::error::Error>> {
214 let mut analyzer = SchemaAnalyzer::new(spec)?;
216 let mut analysis = analyzer.analyze()?;
217
218 let config = GeneratorConfig {
220 module_name: name.to_string(),
221 ..Default::default()
222 };
223 let generator = CodeGenerator::new(config);
224 let generated_code = generator.generate(&mut analysis)?;
225
226 insta::assert_snapshot!(name, &generated_code);
230
231 Ok(generated_code)
232}
233
234pub fn run_generation_test(
236 test: GenerationTest,
237) -> Result<GenerationTestResult, Box<dyn std::error::Error>> {
238 let temp_dir = TempDir::new()?;
240 let project_dir = temp_dir.path().join(&test.name);
241 fs::create_dir(&project_dir)?;
242
243 let src_dir = project_dir.join("src");
245 fs::create_dir(&src_dir)?;
246 let generated_src_dir = src_dir.join("generated");
247 fs::create_dir(&generated_src_dir)?;
248
249 let mut analyzer = SchemaAnalyzer::new(test.spec.clone())?;
251 let analysis = analyzer.analyze()?;
252
253 let config = GeneratorConfig {
255 spec_path: PathBuf::from("test.json"), output_dir: generated_src_dir.clone(),
257 module_name: test
258 .config_overrides
259 .as_ref()
260 .and_then(|o| o.module_name.clone())
261 .unwrap_or_else(|| "generated".to_string()),
262 enable_sse_client: test
263 .config_overrides
264 .as_ref()
265 .and_then(|o| o.enable_sse_client)
266 .unwrap_or(false),
267 enable_async_client: test
268 .config_overrides
269 .as_ref()
270 .and_then(|o| o.enable_async_client)
271 .unwrap_or(false),
272 enable_specta: false,
273 type_mappings: {
274 let mut mappings = std::collections::BTreeMap::new();
275 mappings.insert("integer".to_string(), "i64".to_string());
276 mappings.insert("number".to_string(), "f64".to_string());
277 mappings.insert("string".to_string(), "String".to_string());
278 mappings.insert("boolean".to_string(), "bool".to_string());
279 mappings
280 },
281 streaming_config: None,
282 nullable_field_overrides: std::collections::BTreeMap::new(),
283 schema_extensions: vec![],
284 http_client_config: None,
285 retry_config: None,
286 tracing_enabled: true,
287 auth_config: None,
288 };
289
290 let generator = CodeGenerator::new(config);
292 let mut analysis_mut = analysis;
293 let types_content = generator.generate(&mut analysis_mut)?;
294
295 let files = vec![GeneratedFile {
297 path: "types.rs".into(),
298 content: types_content.clone(),
299 }];
300
301 for file in &files {
303 let dest_path = generated_src_dir.join(&file.path);
304 if let Some(parent) = dest_path.parent() {
305 fs::create_dir_all(parent)?;
306 }
307 fs::write(dest_path, &file.content)?;
308 }
309
310 let mut dependencies = vec![
312 ("serde", r#"{ version = "1.0", features = ["derive"] }"#),
313 ("serde_json", r#""1.0""#),
314 ("async-trait", r#""0.1""#),
315 (
316 "reqwest",
317 r#"{ version = "0.12", features = ["json", "stream"] }"#,
318 ),
319 ("futures-util", r#""0.3""#),
320 ("tokio", r#"{ version = "1.0", features = ["full"] }"#),
321 ("tracing", r#""0.1""#),
322 ];
323
324 for (name, version) in &test.extra_dependencies {
326 dependencies.push((name.as_str(), version.as_str()));
327 }
328
329 let deps_str = dependencies
330 .iter()
331 .map(|(name, ver)| format!("{name} = {ver}"))
332 .collect::<Vec<_>>()
333 .join("\n");
334
335 let cargo_toml = format!(
336 r#"[package]
337name = "{}"
338version = "0.1.0"
339edition = "2021"
340
341[dependencies]
342{}
343"#,
344 test.name.replace('-', "_"),
345 deps_str
346 );
347
348 let cargo_toml_path = project_dir.join("Cargo.toml");
349 fs::write(&cargo_toml_path, cargo_toml)?;
350
351 fs::write(generated_src_dir.join("mod.rs"), "pub mod types;\n")?;
353
354 let lib_rs = create_lib_rs(&test, &files);
356 fs::write(src_dir.join("lib.rs"), lib_rs)?;
357
358 let compilation_output = if !test.compile {
360 None
361 } else if test.run_tests {
362 Some(
364 Command::new("cargo")
365 .arg("test")
366 .arg("--")
367 .arg("--nocapture")
368 .current_dir(&project_dir)
369 .env("RUST_BACKTRACE", "1")
370 .output()?,
371 )
372 } else {
373 Some(
375 Command::new("cargo")
376 .arg("check")
377 .current_dir(&project_dir)
378 .env("RUST_BACKTRACE", "1")
379 .output()?,
380 )
381 };
382
383 if let Some(ref output) = compilation_output {
385 if !output.status.success() {
386 if let Ok(types_content) = fs::read_to_string(generated_src_dir.join("types.rs")) {
387 eprintln!("Generated types.rs:\n{types_content}");
388 }
389 }
390 }
391
392 Ok(GenerationTestResult {
393 temp_dir,
394 project_dir,
395 generated_src_dir,
396 files,
397 compilation_output,
398 cargo_toml_path,
399 })
400}
401
402fn create_lib_rs(test: &GenerationTest, _generated_files: &[GeneratedFile]) -> String {
403 let mut lib_content = String::from("pub mod generated;\n\n");
404
405 lib_content.push_str("#[cfg(test)]\nmod tests {\n");
407 lib_content.push_str(" use super::generated::types::*;\n");
408 lib_content.push_str(" use serde_json;\n\n");
409
410 lib_content.push_str(" #[test]\n");
412 lib_content.push_str(" fn test_compilation() {\n");
413 lib_content.push_str(&format!(
414 " // Generated types compile for test: {}\n",
415 test.name
416 ));
417 lib_content.push_str(" }\n\n");
418
419 for scenario in &test.test_scenarios {
421 lib_content.push_str(&generate_test_from_scenario(scenario));
422 lib_content.push('\n');
423 }
424
425 lib_content.push_str("}\n");
426 lib_content
427}
428
429fn generate_test_from_scenario(scenario: &TestScenario) -> String {
431 let mut test_code = String::new();
432
433 let target_type = to_rust_type_name(&scenario.target_type);
435
436 test_code.push_str(&format!(" #[test]\n fn {}() {{\n", scenario.name));
437
438 match &scenario.behavior {
439 TestBehavior::RoundTrip { json, assertions } => {
440 test_code.push_str(&format!(" let json_str = r#\"{json}\"#;\n"));
442 test_code.push_str(&format!(
443 " let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
444 ));
445
446 for (field, expected) in assertions {
448 test_code.push_str(&format!(
449 " assert_eq!(parsed.{field}, serde_json::json!({expected}));\n"
450 ));
451 }
452
453 test_code
455 .push_str(" let serialized = serde_json::to_string(&parsed).unwrap();\n");
456 test_code.push_str(&format!(
457 " let round_trip: {target_type} = serde_json::from_str(&serialized).unwrap();\n"
458 ));
459 test_code.push_str(" assert_eq!(serde_json::to_value(parsed).unwrap(), serde_json::to_value(round_trip).unwrap());\n");
460 }
461
462 TestBehavior::Construction { fields, assertions } => {
463 test_code.push_str(&format!(" let instance = {target_type} {{\n"));
465
466 for (field_name, field_value) in fields {
467 test_code.push_str(&format!(
468 " {}: {},\n",
469 field_name,
470 field_value_to_code(field_value)
471 ));
472 }
473
474 test_code.push_str(" };\n\n");
475
476 for assertion in assertions {
478 match assertion {
479 ConstructionAssertion::JsonContains(text) => {
480 test_code.push_str(
481 " let json = serde_json::to_string(&instance).unwrap();\n",
482 );
483 test_code
484 .push_str(&format!(" assert!(json.contains(\"{text}\"));\n"));
485 }
486 ConstructionAssertion::FieldEquals(field, value) => {
487 test_code.push_str(&format!(
488 " assert_eq!(instance.{field}, serde_json::json!({value}));\n"
489 ));
490 }
491 ConstructionAssertion::CanSerialize => {
492 test_code.push_str(
493 " let _json = serde_json::to_string(&instance).unwrap();\n",
494 );
495 }
496 }
497 }
498 }
499
500 TestBehavior::CompileOnly => {
501 test_code.push_str(&format!(
502 " // Type {target_type} exists and compiles\n"
503 ));
504 test_code.push_str(&format!(" let _: Option<{target_type}> = None;\n"));
505 }
506
507 TestBehavior::EnumMatch {
508 json,
509 expected_variant,
510 variant_assertions,
511 } => {
512 test_code.push_str(&format!(" let json_str = r#\"{json}\"#;\n"));
513 test_code.push_str(&format!(
514 " let parsed: {target_type} = serde_json::from_str(json_str).unwrap();\n"
515 ));
516 test_code.push_str(" match parsed {\n");
517
518 if variant_assertions.is_empty() {
520 test_code.push_str(&format!(
522 " {target_type}::{expected_variant} {{ .. }} => {{\n"
523 ));
524 test_code.push_str(" // Variant matched successfully\n");
525 } else {
526 let field_bindings: Vec<String> = variant_assertions
528 .iter()
529 .map(|(field, _)| field.clone())
530 .collect();
531
532 if field_bindings.is_empty() {
533 test_code.push_str(&format!(
534 " {target_type}::{expected_variant} {{ .. }} => {{\n"
535 ));
536 } else {
537 let bindings = field_bindings.join(", ");
539 test_code.push_str(&format!(
540 " {target_type}::{expected_variant} {{ {bindings}, .. }} => {{\n"
541 ));
542
543 for (field, expected) in variant_assertions {
545 test_code.push_str(&format!(
546 " assert_eq!({field}, serde_json::json!({expected}));\n"
547 ));
548 }
549 }
550 }
551
552 test_code.push_str(" }\n");
553 test_code.push_str(&format!(
554 " _ => panic!(\"Expected {expected_variant} variant\"),\n"
555 ));
556 test_code.push_str(" }\n");
557 }
558 }
559
560 test_code.push_str(" }\n");
561 test_code
562}
563
564fn field_value_to_code(value: &FieldValue) -> String {
566 match value {
567 FieldValue::String(s) => format!("\"{s}\".to_string()"),
568 FieldValue::Integer(i) => i.to_string(),
569 FieldValue::Float(f) => f.to_string(),
570 FieldValue::Boolean(b) => b.to_string(),
571 FieldValue::Null => "None".to_string(),
572 FieldValue::EnumVariant(variant) => variant.clone(),
573 FieldValue::Array(values) => {
574 let items: Vec<String> = values.iter().map(field_value_to_code).collect();
575 format!("vec![{}]", items.join(", "))
576 }
577 FieldValue::Struct(type_name, fields) => {
578 let mut code = format!("{type_name} {{\n");
579 for (field_name, field_value) in fields {
580 code.push_str(&format!(
581 " {}: {},\n",
582 field_name,
583 field_value_to_code(field_value)
584 ));
585 }
586 code.push_str(" }");
587 code
588 }
589 }
590}
591
592pub fn minimal_spec(schemas: Value) -> Value {
594 serde_json::json!({
595 "openapi": "3.0.0",
596 "info": {
597 "title": "Test API",
598 "version": "1.0.0"
599 },
600 "components": {
601 "schemas": schemas
602 }
603 })
604}
605
606pub fn assert_compilation_success(result: &GenerationTestResult) {
608 if let Some(output) = &result.compilation_output {
609 if !output.status.success() {
610 let stderr = String::from_utf8_lossy(&output.stderr);
611 let stdout = String::from_utf8_lossy(&output.stdout);
612 panic!("Compilation failed!\nSTDERR:\n{stderr}\nSTDOUT:\n{stdout}");
613 }
614 } else {
615 panic!("No compilation was run");
616 }
617}
618
619pub fn find_type_definition(result: &GenerationTestResult, type_name: &str) -> Option<String> {
621 for file in &result.files {
622 for line in file.content.lines() {
623 if (line.contains(&format!("struct {type_name}"))
624 || line.contains(&format!("enum {type_name}"))
625 || line.contains(&format!("type {type_name} =")))
626 && !line.trim().starts_with("//")
627 {
628 return Some(line.to_string());
629 }
630 }
631 }
632 None
633}
634
635pub fn assert_no_underscores_in_type_name(type_definition: &str) {
637 if let Some(type_name) = extract_type_name(type_definition) {
638 assert!(
639 !type_name.contains('_'),
640 "Type name '{type_name}' should not contain underscores"
641 );
642 }
643}
644
/// Extracts the declared type name from a single-line Rust type declaration.
///
/// Takes the token that follows the first `struct`, `enum` or `type` keyword, so visibility
/// modifiers such as `pub` or `pub(crate)` are skipped. The name is cut at the first
/// character that cannot be part of an identifier, which drops generics and any `=`, `{`,
/// `(` or `;`. If no keyword is present, falls back to the second whitespace-separated token.
///
/// Returns `None` when there is no candidate token or the candidate yields an empty name.
fn extract_type_name(type_def: &str) -> Option<String> {
    const DECLARATION_KEYWORDS: [&str; 3] = ["struct", "enum", "type"];

    let tokens: Vec<&str> = type_def.split_whitespace().collect();
    // Taking `tokens[1]` unconditionally returned the keyword itself for `pub struct Foo`.
    let name_index = tokens
        .iter()
        .position(|token| DECLARATION_KEYWORDS.contains(token))
        .map_or(1, |keyword_index| keyword_index + 1);

    let name: String = tokens
        .get(name_index)?
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();

    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
654
/// Converts a schema-style name to UpperCamelCase.
///
/// - ASCII lowercase letters at the start, or right after a skipped character, are uppercased.
/// - ASCII uppercase letters and digits are kept as-is.
/// - Every other character (separators such as `_`, `-`, `.`, space, and any non-ASCII
///   character) is dropped and starts a new word.
fn to_rust_type_name(s: &str) -> String {
    let mut capitalize_next = true;
    s.chars()
        .filter_map(|ch| {
            if ch.is_ascii_lowercase() {
                let emitted = if capitalize_next {
                    ch.to_ascii_uppercase()
                } else {
                    ch
                };
                capitalize_next = false;
                Some(emitted)
            } else if ch.is_ascii_uppercase() || ch.is_ascii_digit() {
                capitalize_next = false;
                Some(ch)
            } else {
                capitalize_next = true;
                None
            }
        })
        .collect()
}