use anyhow::{Context, Result};
use clap::{Parser, Subcommand};
use notify::{Config as NotifyConfig, RecommendedWatcher, RecursiveMode, Watcher};
use std::path::PathBuf;
use std::sync::mpsc::channel;
use std::time::Duration;
use rescript_openapi::{codegen, ir, parser};
// Top-level command-line interface: one required subcommand.
// Deliberately commented with `//` rather than `///`: clap's derive turns
// doc comments into help text, which would alter `--help` output.
#[derive(Parser)]
#[command(
    name = "rescript-openapi",
    version,
    about = "Generate type-safe ReScript clients from OpenAPI specifications"
)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}
// Subcommands of the CLI. The `///` comments on variants and fields below are
// picked up by clap as `--help` text.
#[derive(Subcommand)]
enum Commands {
    /// Generate ReScript types, schemas and client code from an OpenAPI spec
    Generate {
        /// Path to the OpenAPI specification file
        #[arg(short, long)]
        input: PathBuf,
        /// Directory the generated files are written to
        #[arg(short, long, default_value = "src/api")]
        output: PathBuf,
        /// Prefix used for generated module file names
        #[arg(short, long, default_value = "Api")]
        module: String,
        /// Generate the schema module; pass --with-schema=false to disable (ignored with --unified)
        // A plain `bool` field gets `ArgAction::SetTrue`, which combined with a
        // `true` default could never be switched off. `Set` plus an optional
        // value keeps bare `--with-schema` working and also accepts `=false`.
        #[arg(
            long,
            value_name = "BOOL",
            action = clap::ArgAction::Set,
            num_args = 0..=1,
            require_equals = true,
            default_value_t = true,
            default_missing_value = "true"
        )]
        with_schema: bool,
        /// Generate the client module; pass --with-client=false to disable
        #[arg(
            long,
            value_name = "BOOL",
            action = clap::ArgAction::Set,
            num_args = 0..=1,
            require_equals = true,
            default_value_t = true,
            default_missing_value = "true"
        )]
        with_client: bool,
        /// Watch the input file and regenerate whenever it changes
        #[arg(short, long)]
        watch: bool,
        /// Print generated code to stdout instead of writing files
        #[arg(long)]
        dry_run: bool,
        /// Emit types and schemas together in a single module
        #[arg(long)]
        unified: bool,
        /// Client generation mode: full, functor-only or none
        #[arg(long, default_value = "full")]
        client_mode: String,
        /// Project config file (defaults to rescript-openapi.toml); its values override flags
        #[arg(short, long)]
        config: Option<PathBuf>,
        /// Variant encoding: polymorphic or standard
        #[arg(long, default_value = "polymorphic")]
        variant_mode: String,
    },
    /// Parse and validate an OpenAPI spec, reporting any diagnostics
    Validate {
        /// Path to the OpenAPI specification file
        #[arg(short, long)]
        input: PathBuf,
    },
    /// Print a summary (title, version, path and schema counts) of an OpenAPI spec
    Info {
        /// Path to the OpenAPI specification file
        #[arg(short, long)]
        input: PathBuf,
    },
}
/// One generated ReScript source file, held in memory until it is either
/// written to the output directory or printed (dry run).
#[derive(Debug, Clone, PartialEq, Eq)]
struct GeneratedCode {
    /// File name relative to the output directory (e.g. `ApiTypes.res`).
    filename: String,
    /// Complete file contents.
    content: String,
}
/// Parses the OpenAPI spec at `input_path`, lowers it to the IR and renders
/// the ReScript sources selected by `config`, without touching the filesystem.
///
/// Files are returned in this order:
/// - unified mode: a single `{prefix}.res` holding types and schemas;
/// - otherwise: `{prefix}Types.res`, then `{prefix}Schema.res` when
///   `config.generate_schema` is set;
/// - then `{prefix}Client.res` when `config.generate_client` is set.
///
/// # Errors
/// Fails if the spec cannot be parsed or lowered, or if any generator fails.
fn generate_code(
    input_path: &PathBuf,
    config: &codegen::Config,
) -> Result<Vec<GeneratedCode>> {
    let openapi = parser::parse_spec(input_path)
        .with_context(|| format!("Failed to parse OpenAPI spec: {:?}", input_path))?;
    let ir_spec = ir::lower(&openapi).context("Failed to lower OpenAPI spec to IR")?;
    let prefix = &config.module_prefix;

    let mut files = if config.unified_module {
        let mut module_source = String::new();
        module_source.push_str("// SPDX-License-Identifier: PMPL-1.0-or-later\n");
        module_source.push_str("// Generated by rescript-openapi - DO NOT EDIT\n");
        module_source.push_str(&format!(
            "// Source: {} v{}\n\n",
            ir_spec.title, ir_spec.version
        ));
        module_source.push_str("open RescriptCore\n");
        module_source.push_str("module S = RescriptSchema.S\n\n");
        // Types are grouped into strongly connected components in the order
        // `topological_sort_scc` yields them; each component's type
        // declarations come first, followed by one schema per member.
        for component in codegen::schema::topological_sort_scc(&ir_spec.types) {
            module_source.push_str(&codegen::types::generate_scc(&component, config));
            for type_def in component {
                module_source
                    .push_str(&codegen::schema::generate_schema_only(type_def, config));
                module_source.push('\n');
            }
            module_source.push('\n');
        }
        vec![GeneratedCode {
            filename: format!("{}.res", prefix),
            content: module_source,
        }]
    } else {
        let types_source =
            codegen::types::generate(&ir_spec, config).context("Failed to generate types")?;
        let mut split_files = vec![GeneratedCode {
            filename: format!("{}Types.res", prefix),
            content: types_source,
        }];
        if config.generate_schema {
            let schema_source = codegen::schema::generate(&ir_spec, config)
                .context("Failed to generate schema")?;
            split_files.push(GeneratedCode {
                filename: format!("{}Schema.res", prefix),
                content: schema_source,
            });
        }
        split_files
    };

    if config.generate_client {
        let client_source =
            codegen::client::generate(&ir_spec, config).context("Failed to generate client")?;
        files.push(GeneratedCode {
            filename: format!("{}Client.res", prefix),
            content: client_source,
        });
    }

    Ok(files)
}
/// Writes every generated file into `config.output_dir`, creating that
/// directory (and any missing parents) first. `std::fs::write` replaces any
/// existing file of the same name.
///
/// # Errors
/// Returns an error naming the directory or file that could not be created or
/// written. Files written before the failure are left in place.
fn write_generated_code(config: &codegen::Config, generated_files: &[GeneratedCode]) -> Result<()> {
    let out_dir = &config.output_dir;
    std::fs::create_dir_all(out_dir)
        .with_context(|| format!("Failed to create output directory: {:?}", out_dir))?;

    generated_files.iter().try_for_each(|file| {
        let target = out_dir.join(&file.filename);
        std::fs::write(&target, &file.content)
            .with_context(|| format!("Failed to write file: {:?}", target))
    })
}
/// Prints each generated file to stdout (the `--dry-run` output).
///
/// Each file is shown as a `// FILE: <name>` header, an 80-column `=` rule
/// and then its contents. Files after the first are preceded by a blank line
/// and another rule.
fn print_generated_code(generated_files: &[GeneratedCode]) {
    let rule = "=".repeat(80);
    let mut is_first = true;
    for file in generated_files {
        if !is_first {
            println!("\n{}", rule);
        }
        is_first = false;

        println!("// FILE: {}", file.filename);
        println!("{}", rule);
        println!("{}", file.content);
    }
}
/// Generates code for `input_path` and then either prints it (when
/// `dry_run_mode` is set) or writes it to `config.output_dir`. After a real
/// write it reports the output directory on stdout.
///
/// # Errors
/// Propagates failures from code generation or from writing the files.
fn run_generate(
    input_path: &PathBuf,
    config: &codegen::Config,
    dry_run_mode: bool,
) -> Result<()> {
    let files = generate_code(input_path, config)?;

    if !dry_run_mode {
        write_generated_code(config, &files)?;
        println!("Generated ReScript code in {:?}", config.output_dir);
        return Ok(());
    }

    print_generated_code(&files);
    Ok(())
}
/// Runs one generation pass, then watches the spec file and regenerates
/// whenever it is modified or (re)created. Blocks until the watcher's event
/// channel closes.
///
/// Generation failures (initial or on change) and watcher errors are reported
/// on stderr and do not stop the watch loop.
///
/// # Errors
/// Fails only if the file watcher cannot be created or the directory holding
/// `input_path` cannot be watched.
fn watch_and_regenerate(
    input_path: &PathBuf,
    config: &codegen::Config,
    dry_run_mode: bool,
) -> Result<()> {
    use notify::EventKind;

    println!("Watching {:?} for changes...", input_path);
    if let Err(error) = run_generate(input_path, config, dry_run_mode) {
        // `{:#}` prints the full anyhow context chain (e.g. the underlying
        // parse error), not just the outermost context message.
        eprintln!("Error during initial generation: {:#}", error);
    }

    let (sender, receiver) = channel();
    let notify_config = NotifyConfig::default().with_poll_interval(Duration::from_secs(1));
    let mut watcher: RecommendedWatcher =
        Watcher::new(sender, notify_config).context("Failed to create file watcher")?;

    // The containing directory is watched and events are filtered by file name
    // below (presumably so saves that replace the file are still seen).
    // `Path::parent` returns `Some("")` for a bare relative name such as
    // `spec.yaml`, and an empty path cannot be watched, so use "." instead.
    let watch_path = match input_path.parent() {
        Some(parent) if parent.as_os_str().is_empty() => std::path::Path::new("."),
        Some(parent) => parent,
        None => input_path.as_path(),
    };
    watcher
        .watch(watch_path, RecursiveMode::NonRecursive)
        .with_context(|| format!("Failed to watch path: {:?}", watch_path))?;

    let spec_file_name = input_path.file_name();
    println!("Press Ctrl+C to stop watching.\n");

    loop {
        let event = match receiver.recv() {
            Ok(Ok(event)) => event,
            Ok(Err(error)) => {
                eprintln!("Watch error: {}", error);
                continue;
            }
            // All senders dropped: the watcher is gone, so stop watching.
            Err(error) => {
                eprintln!("Channel receive error: {}", error);
                break;
            }
        };

        let touches_spec = event
            .paths
            .iter()
            .any(|path| path.file_name() == spec_file_name);
        let is_content_change = matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_));
        if !(touches_spec && is_content_change) {
            continue;
        }

        println!("\nFile changed, regenerating...");
        match run_generate(input_path, config, dry_run_mode) {
            // In dry-run mode the regenerated code itself was just printed.
            Ok(()) if !dry_run_mode => println!("Regeneration complete."),
            Ok(()) => {}
            Err(error) => eprintln!("Error during regeneration: {:#}", error),
        }
    }

    Ok(())
}
/// Configuration file looked up in the current directory when `--config` is
/// not given.
const DEFAULT_CONFIG_FILE: &str = "rescript-openapi.toml";

/// Loads the project configuration file, if any.
///
/// A path passed explicitly via `--config` must load successfully; a missing
/// file is reported as an error instead of being silently ignored. The
/// implicit [`DEFAULT_CONFIG_FILE`] is optional and skipped when absent.
fn load_project_config(explicit_path: Option<PathBuf>) -> Result<Option<codegen::ProjectConfig>> {
    let config_path = match explicit_path {
        Some(path) => path,
        None => {
            let default_path = PathBuf::from(DEFAULT_CONFIG_FILE);
            if !default_path.exists() {
                return Ok(None);
            }
            default_path
        }
    };
    let project_config = codegen::ProjectConfig::load(&config_path)
        .with_context(|| format!("Failed to load config file: {:?}", config_path))?;
    Ok(Some(project_config))
}

/// Parses the `--client-mode` value (`full`, `functor-only` or `none`).
fn parse_client_mode(value: &str) -> Result<codegen::ClientMode> {
    Ok(match value {
        "full" => codegen::ClientMode::Full,
        "functor-only" => codegen::ClientMode::FunctorOnly,
        "none" => codegen::ClientMode::None,
        other => anyhow::bail!("Invalid client-mode: {}", other),
    })
}

/// Parses the `--variant-mode` value (`polymorphic` or `standard`).
fn parse_variant_mode(value: &str) -> Result<codegen::VariantMode> {
    Ok(match value {
        "polymorphic" => codegen::VariantMode::Polymorphic,
        "standard" => codegen::VariantMode::Standard,
        other => anyhow::bail!("Invalid variant-mode: {}", other),
    })
}

/// CLI entry point: dispatches to `generate`, `validate` or `info`.
///
/// `validate` exits the process with status 1 when the spec has diagnostics.
fn main() -> Result<()> {
    let cli = Cli::parse();
    match cli.command {
        Commands::Generate {
            input,
            output,
            module,
            with_schema,
            with_client,
            watch,
            dry_run,
            unified,
            client_mode,
            config,
            variant_mode,
        } => {
            let project_config = load_project_config(config)?;
            // NOTE: any setting present in the project config file takes
            // precedence over the matching command-line flag, whether that flag
            // was passed explicitly or left at its default.
            let file_settings = project_config.as_ref();

            // Mode strings from the CLI are only validated when the config file
            // does not already provide the mode.
            let resolved_client_mode = match file_settings.and_then(|c| c.client_mode) {
                Some(mode) => mode,
                None => parse_client_mode(&client_mode)?,
            };
            let resolved_variant_mode = match file_settings.and_then(|c| c.variant_mode) {
                Some(mode) => mode,
                None => parse_variant_mode(&variant_mode)?,
            };

            let config = codegen::Config {
                output_dir: file_settings.and_then(|c| c.output.clone()).unwrap_or(output),
                module_prefix: file_settings.and_then(|c| c.module.clone()).unwrap_or(module),
                generate_schema: file_settings.and_then(|c| c.with_schema).unwrap_or(with_schema),
                generate_client: file_settings.and_then(|c| c.with_client).unwrap_or(with_client),
                unified_module: file_settings.and_then(|c| c.unified).unwrap_or(unified),
                client_mode: resolved_client_mode,
                variant_mode: resolved_variant_mode,
            };

            if watch {
                watch_and_regenerate(&input, &config, dry_run)?;
            } else {
                run_generate(&input, &config, dry_run)?;
            }
        }
        Commands::Validate { input } => {
            let spec = parser::parse_spec(&input)
                .with_context(|| format!("Failed to parse OpenAPI spec: {:?}", input))?;
            let diagnostics = parser::validate(&spec);
            if diagnostics.is_empty() {
                println!("OpenAPI spec is valid");
            } else {
                for diagnostic in &diagnostics {
                    eprintln!("{}", diagnostic);
                }
                std::process::exit(1);
            }
        }
        Commands::Info { input } => {
            let spec = parser::parse_spec(&input)
                .with_context(|| format!("Failed to parse OpenAPI spec: {:?}", input))?;
            println!("Title: {}", spec.info.title);
            println!("Version: {}", spec.info.version);
            if let Some(description) = &spec.info.description {
                println!("Description: {}", description);
            }
            println!("Paths: {}", spec.paths.paths.len());
            let schema_count = spec
                .components
                .as_ref()
                .map(|components| components.schemas.len())
                .unwrap_or(0);
            println!("Schemas: {}", schema_count);
        }
    }
    Ok(())
}