Skip to main content

openapi_to_rust/
cli.rs

1use crate::{CodeGenerator, GeneratorConfig, SchemaAnalyzer, streaming::StreamingConfig};
2use clap::{Arg, Command};
3use std::fs;
4use std::process;
5
6/// Configuration for the CLI helper
7pub struct CliConfig {
8    /// Name of the API for display purposes
9    pub api_name: &'static str,
10    /// Default module name
11    pub default_module_name: &'static str,
12    /// Streaming configuration for this API
13    pub streaming_config: Option<StreamingConfig>,
14    /// Enable Specta type derives for frontend integration
15    pub enable_specta: bool,
16}
17
18/// Run the complete generation CLI with the provided configuration
19pub async fn run_generation_cli(cli_config: CliConfig) {
20    let matches = Command::new("api-gen")
21        .version("0.1.0")
22        .about("Generate API types and streaming client")
23        .arg(
24            Arg::new("input")
25                .help("Input OpenAPI spec (file path or URL)")
26                .required(true)
27                .index(1),
28        )
29        .arg(
30            Arg::new("output-dir")
31                .long("output-dir")
32                .value_name("DIR")
33                .help("Output directory for generated files (default: src/generated)")
34                .default_value("src/generated"),
35        )
36        .arg(
37            Arg::new("module-name")
38                .short('m')
39                .long("module-name")
40                .value_name("NAME")
41                .help("Generated module name")
42                .default_value(cli_config.default_module_name),
43        )
44        .arg(
45            Arg::new("verbose")
46                .short('v')
47                .long("verbose")
48                .help("Enable verbose output")
49                .action(clap::ArgAction::SetTrue),
50        )
51        .arg(
52            Arg::new("dry-run")
53                .long("dry-run")
54                .help("Print generated code to stdout instead of writing to file")
55                .action(clap::ArgAction::SetTrue),
56        )
57        .get_matches();
58
59    let Some(input) = matches.get_one::<String>("input") else {
60        eprintln!("Error: missing required argument 'input'");
61        process::exit(1);
62    };
63    let Some(output_dir) = matches.get_one::<String>("output-dir") else {
64        eprintln!("Error: missing required argument 'output-dir'");
65        process::exit(1);
66    };
67    let Some(module_name) = matches.get_one::<String>("module-name") else {
68        eprintln!("Error: missing required argument 'module-name'");
69        process::exit(1);
70    };
71    let verbose = matches.get_flag("verbose");
72    let dry_run = matches.get_flag("dry-run");
73
74    if verbose {
75        println!("🚀 {} API Generator", cli_config.api_name);
76        println!("Input: {input}");
77        if !dry_run {
78            println!("Output: {output_dir}");
79        }
80        println!("Module: {module_name}");
81        if cli_config.streaming_config.is_some() {
82            println!("🌊 Streaming: enabled");
83        }
84        println!();
85    }
86
87    // Load OpenAPI spec
88    let spec_content = match load_spec(input, verbose).await {
89        Ok(content) => content,
90        Err(e) => {
91            eprintln!("❌ Error loading spec: {e}");
92            process::exit(1);
93        }
94    };
95
96    if verbose {
97        println!("📄 Loaded OpenAPI spec ({} bytes)", spec_content.len());
98    }
99
100    // Parse the spec
101    let spec_value: serde_json::Value = match parse_spec(&spec_content, input) {
102        Ok(value) => value,
103        Err(e) => {
104            eprintln!("❌ Error parsing spec: {e}");
105            process::exit(1);
106        }
107    };
108
109    if verbose {
110        if let Some(info) = spec_value.get("info") {
111            if let Some(title) = info.get("title").and_then(|t| t.as_str()) {
112                println!("📋 Title: {title}");
113            }
114            if let Some(version) = info.get("version").and_then(|v| v.as_str()) {
115                println!("🏷️  Version: {version}");
116            }
117        }
118        println!();
119    }
120
121    // Analyze schemas
122    if verbose {
123        println!("🔍 Analyzing schemas...");
124    }
125
126    let mut analyzer = match SchemaAnalyzer::new(spec_value) {
127        Ok(analyzer) => analyzer,
128        Err(e) => {
129            eprintln!("❌ Error creating analyzer: {e}");
130            process::exit(1);
131        }
132    };
133
134    let mut analysis = match analyzer.analyze() {
135        Ok(analysis) => analysis,
136        Err(e) => {
137            eprintln!("❌ Error analyzing schemas: {e}");
138            process::exit(1);
139        }
140    };
141
142    if verbose {
143        println!("📈 Found {} schemas", analysis.schemas.len());
144        println!("📈 Found {} operations", analysis.operations.len());
145        if let Some(ref config) = cli_config.streaming_config {
146            println!("🌊 Found {} streaming endpoints", config.endpoints.len());
147        }
148        println!();
149    }
150
151    // Generate code
152    if verbose {
153        let stream_status = if cli_config.streaming_config.is_some() {
154            "with streaming support"
155        } else {
156            ""
157        };
158        println!(
159            "⚙️  Generating {} API code {}...",
160            cli_config.api_name, stream_status
161        );
162    }
163
164    let config = GeneratorConfig {
165        module_name: module_name.clone(),
166        output_dir: output_dir.into(),
167        streaming_config: cli_config.streaming_config,
168        enable_specta: cli_config.enable_specta,
169        ..Default::default()
170    };
171
172    let generator = CodeGenerator::new(config);
173    let generation_result = match generator.generate_all(&mut analysis) {
174        Ok(result) => result,
175        Err(e) => {
176            eprintln!("❌ Error generating code: {e}");
177            process::exit(1);
178        }
179    };
180
181    if dry_run {
182        println!("=== Generated Files ===");
183        for file in &generation_result.files {
184            println!("\n--- {} ---", file.path.display());
185            println!("{}", file.content);
186        }
187        println!("\n--- {} ---", generation_result.mod_file.path.display());
188        println!("{}", generation_result.mod_file.content);
189    } else {
190        // Write all files to disk
191        if let Err(e) = generator.write_files(&generation_result) {
192            eprintln!("❌ Error writing files: {e}");
193            process::exit(1);
194        }
195
196        if verbose {
197            println!(
198                "✅ Generated {} files written to: {}",
199                generation_result.files.len() + 1,
200                generator.config().output_dir.display()
201            );
202            for file in &generation_result.files {
203                println!("   - {}", file.path.display());
204            }
205            println!("   - {}", generation_result.mod_file.path.display());
206        } else {
207            println!(
208                "✅ Generated {} files written to: {}",
209                generation_result.files.len() + 1,
210                generator.config().output_dir.display()
211            );
212        }
213    }
214}
215
216async fn load_spec(input: &str, verbose: bool) -> Result<String, Box<dyn std::error::Error>> {
217    if input.starts_with("http://") || input.starts_with("https://") {
218        // Load from URL
219        if verbose {
220            println!("🌐 Fetching from URL...");
221        }
222
223        let response = reqwest::get(input).await?;
224        if !response.status().is_success() {
225            return Err(format!("HTTP error: {}", response.status()).into());
226        }
227
228        let content = response.text().await?;
229        Ok(content)
230    } else {
231        // Load from file
232        if verbose {
233            println!("📁 Reading from file...");
234        }
235
236        let content = fs::read_to_string(input)?;
237        Ok(content)
238    }
239}
240
241fn parse_spec(content: &str, input: &str) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
242    // Determine format from extension or content
243    let is_yaml = input.ends_with(".yaml")
244        || input.ends_with(".yml")
245        || content.trim_start().starts_with("openapi:")
246        || content.trim_start().starts_with("swagger:");
247
248    if is_yaml {
249        let value: serde_json::Value = serde_yaml::from_str(content)?;
250        Ok(value)
251    } else {
252        let value: serde_json::Value = serde_json::from_str(content)?;
253        Ok(value)
254    }
255}