use crate::ir::{ApiSpec, TypeDef, Field, RsType};
use super::Config;
use anyhow::Result;
use heck::ToLowerCamelCase;
use std::collections::{HashMap, HashSet, VecDeque};
/// Builds the complete ReScript schema module for `spec`.
///
/// The module starts with an SPDX/provenance header, opens `<module_prefix>Types`,
/// binds `S` to `RescriptSchema.S`, and then contains one schema per type in the
/// order returned by [`topological_sort`], each followed by an extra newline.
///
/// # Errors
///
/// Never returns an error at present; every path ends in `Ok`.
pub fn generate(spec: &ApiSpec, config: &Config) -> Result<String> {
    let preamble = format!(
        "// SPDX-License-Identifier: PMPL-1.0-or-later\n\
         // Generated by rescript-openapi - DO NOT EDIT\n\
         // Source: {} v{}\n\n\
         open {}Types\n\n\
         module S = RescriptSchema.S\n\n",
        spec.title, spec.version, config.module_prefix,
    );

    let schemas: String = topological_sort(&spec.types)
        .into_iter()
        .map(|type_def| generate_schema(type_def, config) + "\n")
        .collect();

    Ok(preamble + &schemas)
}
/// Renders the schema for a single type definition, without the module
/// header (license line, `open`, `module S = ...`) that [`generate`] emits.
pub fn generate_schema_only(type_def: &TypeDef, config: &Config) -> String {
    generate_schema(type_def, config)
}
/// Returns the lowerCamelCase names of every named type referenced by `type_def`.
///
/// Record field types, variant payloads and alias targets are walked through
/// [`collect_type_deps`]. The names are not checked against any set of known
/// definitions; callers filter them as needed.
pub fn get_dependencies(type_def: &TypeDef) -> HashSet<String> {
    let mut found = HashSet::new();
    match type_def {
        TypeDef::Alias { target, .. } => collect_type_deps(target, &mut found),
        TypeDef::Record { fields, .. } => {
            fields
                .iter()
                .for_each(|field| collect_type_deps(&field.ty, &mut found));
        }
        TypeDef::Variant { cases, .. } => {
            // Payload-free cases reference no other type.
            for payload in cases.iter().filter_map(|case| case.payload.as_ref()) {
                collect_type_deps(payload, &mut found);
            }
        }
    }
    found
}
/// Adds to `deps` the lowerCamelCase name of every `RsType::Named` reachable
/// from `ty`, descending recursively through `Option`, `Array`, `Dict` and
/// `Tuple`. Every other type form contributes nothing.
fn collect_type_deps(ty: &RsType, deps: &mut HashSet<String>) {
    match ty {
        RsType::Named(name) => {
            deps.insert(name.to_lower_camel_case());
        }
        RsType::Option(inner) | RsType::Array(inner) | RsType::Dict(inner) => {
            collect_type_deps(inner, deps)
        }
        RsType::Tuple(members) => members
            .iter()
            .for_each(|member| collect_type_deps(member, deps)),
        _ => {}
    }
}
/// Orders `types` so that each definition comes after the definitions it references.
///
/// Dependencies come from [`get_dependencies`] and are matched by lowerCamelCase
/// name; references to names not defined in `types` are ignored. This is Kahn's
/// algorithm. Types with no unresolved dependencies are emitted first. Whenever
/// several types become ready at the same moment, they are emitted in alphabetical
/// order, so the result is deterministic.
///
/// Some types never become ready: those on a reference cycle (including a type
/// that references itself) and everything that depends on one. They are appended
/// afterwards in their original order, so within that tail a type may precede its
/// dependencies.
///
/// If several definitions share a lowerCamelCase name, only one of them appears in
/// the result.
pub fn topological_sort(types: &[TypeDef]) -> Vec<&TypeDef> {
    /// The name other definitions use to reference `type_def`.
    fn name_of(type_def: &TypeDef) -> String {
        match type_def {
            TypeDef::Record { name, .. } => name.to_lower_camel_case(),
            TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
            TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
        }
    }

    // On a name clash the later definition wins.
    let type_map: HashMap<String, &TypeDef> = types
        .iter()
        .map(|type_def| (name_of(type_def), type_def))
        .collect();

    // name -> names it depends on, restricted to types defined in `types`.
    let mut deps_map: HashMap<String, HashSet<String>> = HashMap::with_capacity(types.len());
    for type_def in types {
        let deps = get_dependencies(type_def)
            .into_iter()
            .filter(|dep| type_map.contains_key(dep))
            .collect();
        deps_map.insert(name_of(type_def), deps);
    }

    // Reverse edges (dependency -> dependents). Emitting a type then only touches
    // the types waiting on it instead of rescanning every entry of `deps_map`.
    let mut dependents: HashMap<&str, Vec<&str>> = HashMap::new();
    for (name, deps) in &deps_map {
        for dep in deps {
            dependents
                .entry(dep.as_str())
                .or_default()
                .push(name.as_str());
        }
    }

    // Number of each type's dependencies that have not been emitted yet.
    let mut in_degree: HashMap<&str, usize> = deps_map
        .iter()
        .map(|(name, deps)| (name.as_str(), deps.len()))
        .collect();

    let mut initially_ready: Vec<&str> = in_degree
        .iter()
        .filter(|&(_, &degree)| degree == 0)
        .map(|(&name, _)| name)
        .collect();
    initially_ready.sort_unstable();
    let mut queue: VecDeque<&str> = initially_ready.into();

    let mut sorted: Vec<&TypeDef> = Vec::with_capacity(types.len());
    let mut emitted: HashSet<String> = HashSet::with_capacity(types.len());
    while let Some(name) = queue.pop_front() {
        if let Some(&type_def) = type_map.get(name) {
            sorted.push(type_def);
        }
        emitted.insert(name.to_owned());

        let mut newly_ready: Vec<&str> = Vec::new();
        for &dependent in dependents.get(name).into_iter().flatten() {
            let degree = in_degree
                .get_mut(dependent)
                .expect("every dependent is a key of deps_map, hence of in_degree");
            *degree -= 1;
            if *degree == 0 {
                newly_ready.push(dependent);
            }
        }
        // Sort each batch so ties are broken alphabetically, independent of hash order.
        newly_ready.sort_unstable();
        queue.extend(newly_ready);
    }

    // Whatever is left sits on, or behind, a cycle. Append it in source order,
    // at most one definition per name.
    if sorted.len() < types.len() {
        for type_def in types {
            if emitted.insert(name_of(type_def)) {
                sorted.push(type_def);
            }
        }
    }

    sorted
}
/// Groups `types` into strongly connected components of the "references" graph.
///
/// An edge runs from each type to every type it references. References come from
/// [`get_dependencies`] and are matched by lowerCamelCase name; references to
/// unknown names are ignored. Mutually referencing types therefore land in the
/// same component.
///
/// Components are returned in dependency order: a component is listed only after
/// every component it references. Within a component, types appear in the order
/// Tarjan's algorithm pops them off its stack.
///
/// If several definitions share a lowerCamelCase name, references to that name
/// resolve to the last such definition.
pub fn topological_sort_scc(types: &[TypeDef]) -> Vec<Vec<&TypeDef>> {
    // Later definitions overwrite earlier ones on a name clash.
    let name_to_index: HashMap<String, usize> = types
        .iter()
        .enumerate()
        .map(|(index, type_def)| {
            let name = match type_def {
                TypeDef::Record { name, .. } => name.to_lower_camel_case(),
                TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
                TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
            };
            (name, index)
        })
        .collect();

    // Visit dependencies alphabetically so the DFS, and therefore the output,
    // does not depend on `HashSet` iteration order.
    let adj: Vec<Vec<usize>> = types
        .iter()
        .map(|type_def| {
            let mut dep_names: Vec<String> = get_dependencies(type_def).into_iter().collect();
            dep_names.sort();
            dep_names
                .iter()
                .filter_map(|dep| name_to_index.get(dep).copied())
                .collect()
        })
        .collect();

    let n = types.len();
    let mut visited = vec![false; n];
    let mut stack = Vec::new();
    let mut on_stack = vec![false; n];
    let mut ids = vec![-1; n];
    let mut low = vec![-1; n];
    let mut id_counter = 0;
    let mut sccs: Vec<Vec<usize>> = Vec::new();
    for start in 0..n {
        if !visited[start] {
            tarjan_dfs(
                start,
                &adj,
                &mut visited,
                &mut stack,
                &mut on_stack,
                &mut ids,
                &mut low,
                &mut id_counter,
                &mut sccs,
            );
        }
    }

    sccs.into_iter()
        .map(|component| component.into_iter().map(|index| &types[index]).collect())
        .collect()
}
/// Recursive step of Tarjan's strongly-connected-components algorithm.
///
/// Visits `at`, then every not-yet-visited node reachable from it through `adj`.
/// When a node turns out to be the root of a component (`ids[at] == low[at]`), that
/// component is popped off `stack` and appended to `sccs`. A component is therefore
/// pushed only after every component reachable from it.
///
/// * `adj` — adjacency list; `adj[i]` lists the nodes `i` has edges to.
/// * `visited`, `on_stack`, `ids`, `low` — per-node DFS state, each of length `adj.len()`;
///   `ids`/`low` are expected to start at `-1` for unvisited nodes.
/// * `stack` — nodes belonging to components that are still being built.
/// * `id_counter` — next discovery index to hand out.
/// * `sccs` — output; each entry lists one component's node indices in pop order.
///
/// Recursion depth equals the longest DFS path, so very deep reference chains use
/// proportional call-stack space.
// The nine parameters are exactly the classic Tarjan DFS state; the caller owns
// them, so bundling them into a struct would only move the same fields around.
#[allow(clippy::too_many_arguments)]
fn tarjan_dfs(
    at: usize,
    adj: &[Vec<usize>],
    visited: &mut [bool],
    stack: &mut Vec<usize>,
    on_stack: &mut [bool],
    ids: &mut [i32],
    low: &mut [i32],
    id_counter: &mut i32,
    sccs: &mut Vec<Vec<usize>>,
) {
    visited[at] = true;
    stack.push(at);
    on_stack[at] = true;
    ids[at] = *id_counter;
    low[at] = *id_counter;
    *id_counter += 1;

    for &to in &adj[at] {
        if !visited[to] {
            tarjan_dfs(to, adj, visited, stack, on_stack, ids, low, id_counter, sccs);
            low[at] = low[at].min(low[to]);
        } else if on_stack[to] {
            // Back edge into the component currently being built.
            low[at] = low[at].min(ids[to]);
        }
    }

    if ids[at] == low[at] {
        let mut component = Vec::new();
        loop {
            let node = stack
                .pop()
                .expect("`at` was pushed above and is still on the stack");
            on_stack[node] = false;
            component.push(node);
            if node == at {
                break;
            }
        }
        sccs.push(component);
    }
}
fn generate_schema(type_def: &TypeDef, config: &Config) -> String {
let mut output = String::new();
match type_def {
TypeDef::Record { name, doc, fields } => {
let schema_name = format!("{}Schema", name.to_lower_camel_case());
if let Some(doc) = doc {
output.push_str(&format!("/** Schema for {} */\n", doc));
}
let type_name = name.to_lower_camel_case();
output.push_str(&format!("let {}: S.t<{}> = S.object(s => ({{\n", schema_name, type_name));
for field in fields {
output.push_str(&generate_field_schema(field));
}
output.push_str(&format!("}}: {}))\n", type_name));
}
TypeDef::Variant { name, doc, cases } => {
let schema_name = format!("{}Schema", name.to_lower_camel_case());
let type_name = name.to_lower_camel_case();
if let Some(doc) = doc {
output.push_str(&format!("/** Schema for {} */\n", doc));
}
if cases.iter().all(|c| c.payload.is_none()) {
if config.variant_mode == super::VariantMode::Standard {
output.push_str(&format!("let {}: S.t<{}> = S.enum([\n", schema_name, type_name));
for case in cases {
output.push_str(&format!(" {},\n", case.name));
}
output.push_str("])\n");
} else {
output.push_str(&format!("let {}: S.t<{}> = S.union([\n", schema_name, type_name));
for case in cases {
output.push_str(&format!(
" S.literal(#{}),\n",
case.name
));
}
output.push_str("])\n");
}
} else {
output.push_str(&format!("let {}: S.t<{}> = S.union([\n", schema_name, type_name));
for case in cases {
match &case.payload {
Some(ty) => {
output.push_str(&format!(
" {}->S.transform(s => {{\n parser: v => {}(v),\n serializer: v => switch v {{ | {}(x) => x | _ => S.fail(\"Expected {}\") }}\n }}),\n",
ty.to_schema(),
case.name,
case.name,
case.name
));
}
None => {
output.push_str(&format!(
" S.literal(\"{}\")->S.transform(s => {{\n parser: _ => {},\n serializer: _ => \"{}\"\n }}),\n",
case.original_name,
case.name,
case.original_name
));
}
}
}
output.push_str("])\n");
}
}
TypeDef::Alias { name, doc, target } => {
let schema_name = format!("{}Schema", name.to_lower_camel_case());
if let Some(doc) = doc {
output.push_str(&format!("/** Schema for {} */\n", doc));
}
output.push_str(&format!("let {} = {}\n", schema_name, target.to_schema()));
}
}
output
}
/// Renders one record-field line of an `S.object` body, ending in `,\n`.
///
/// The JSON key is always `field.original_name`; the ReScript record label is
/// `field.name`. Optional fields are read with `s.fieldOr(..., None)`, required
/// ones with `s.field(...)`.
fn generate_field_schema(field: &Field) -> String {
    let (method, default) = if field.optional {
        ("fieldOr", ", None")
    } else {
        ("field", "")
    };
    let schema = field.ty.to_schema();
    // `original_name` is used unconditionally: when it equals `name`, writing
    // either one yields the same key text.
    format!(
        " {}: s.{}(\"{}\", {}{}),\n",
        field.name, method, field.original_name, schema, default
    )
}