rescript_openapi/codegen/
schema.rs1use crate::ir::{ApiSpec, TypeDef, Field, RsType};
7use super::Config;
8use anyhow::Result;
9use heck::ToLowerCamelCase;
10use std::collections::{HashMap, HashSet, VecDeque};
11
12pub fn generate(spec: &ApiSpec, config: &Config) -> Result<String> {
13 let mut output = String::new();
14
15 output.push_str("// SPDX-License-Identifier: PMPL-1.0-or-later\n");
17 output.push_str("// Generated by rescript-openapi - DO NOT EDIT\n");
18 output.push_str(&format!("// Source: {} v{}\n\n", spec.title, spec.version));
19
20 output.push_str(&format!("open {}Types\n\n", config.module_prefix));
22
23 output.push_str("module S = RescriptSchema.S\n\n");
25
26 let sorted_types = topological_sort(&spec.types);
28
29 for type_def in sorted_types {
31 output.push_str(&generate_schema(type_def, config));
32 output.push('\n');
33 }
34
35 Ok(output)
36}
37
38pub fn generate_schema_only(type_def: &TypeDef, config: &Config) -> String {
39 let mut output = String::new();
40 output.push_str(&generate_schema(type_def, config));
41 output
42}
43
44pub fn get_dependencies(type_def: &TypeDef) -> HashSet<String> {
46 let mut deps = HashSet::new();
47
48 match type_def {
49 TypeDef::Record { fields, .. } => {
50 for field in fields {
51 collect_type_deps(&field.ty, &mut deps);
52 }
53 }
54 TypeDef::Variant { cases, .. } => {
55 for case in cases {
56 if let Some(ty) = &case.payload {
57 collect_type_deps(ty, &mut deps);
58 }
59 }
60 }
61 TypeDef::Alias { target, .. } => {
62 collect_type_deps(target, &mut deps);
63 }
64 }
65
66 deps
67}
68
69fn collect_type_deps(ty: &RsType, deps: &mut HashSet<String>) {
71 match ty {
72 RsType::Named(name) => {
73 deps.insert(name.to_lower_camel_case());
74 }
75 RsType::Option(inner) | RsType::Array(inner) | RsType::Dict(inner) => {
76 collect_type_deps(inner, deps);
77 }
78 RsType::Tuple(types) => {
79 for t in types {
80 collect_type_deps(t, deps);
81 }
82 }
83 _ => {}
84 }
85}
86
87pub fn topological_sort(types: &[TypeDef]) -> Vec<&TypeDef> {
89 let type_map: HashMap<String, &TypeDef> = types
91 .iter()
92 .map(|t| {
93 let name = match t {
94 TypeDef::Record { name, .. } => name.to_lower_camel_case(),
95 TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
96 TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
97 };
98 (name, t)
99 })
100 .collect();
101
102 let mut deps_map: HashMap<String, HashSet<String>> = HashMap::new();
104 let mut all_names: Vec<String> = Vec::new();
105
106 for type_def in types {
107 let name = match type_def {
108 TypeDef::Record { name, .. } => name.to_lower_camel_case(),
109 TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
110 TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
111 };
112 all_names.push(name.clone());
113
114 let deps = get_dependencies(type_def);
115 let filtered_deps: HashSet<String> = deps
117 .into_iter()
118 .filter(|d| type_map.contains_key(d))
119 .collect();
120 deps_map.insert(name, filtered_deps);
121 }
122
123 let mut in_degree: HashMap<String, usize> = HashMap::new();
126 for name in &all_names {
127 in_degree.insert(name.clone(), 0);
128 }
129
130 for (name, deps) in &deps_map {
132 *in_degree.get_mut(name).unwrap() += deps.len();
135 }
136
137 let mut zero_degree: Vec<String> = in_degree
139 .iter()
140 .filter(|(_, °ree)| degree == 0)
141 .map(|(name, _)| name.clone())
142 .collect();
143 zero_degree.sort();
144 let mut queue: VecDeque<String> = zero_degree.into_iter().collect();
145
146 let mut sorted: Vec<&TypeDef> = Vec::new();
147
148 while let Some(name) = queue.pop_front() {
149 if let Some(&type_def) = type_map.get(&name) {
150 sorted.push(type_def);
151 }
152
153 let mut newly_ready: Vec<String> = Vec::new();
156 for (other_name, other_deps) in &deps_map {
157 if other_deps.contains(&name) {
158 let degree = in_degree.get_mut(other_name).unwrap();
159 *degree -= 1;
160 if *degree == 0 {
161 newly_ready.push(other_name.clone());
162 }
163 }
164 }
165 newly_ready.sort();
166 for ready_name in newly_ready {
167 queue.push_back(ready_name);
168 }
169 }
170
171 if sorted.len() < types.len() {
173 for type_def in types {
174 let name = match type_def {
175 TypeDef::Record { name, .. } => name.to_lower_camel_case(),
176 TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
177 TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
178 };
179 if !sorted.iter().any(|t| {
180 let n = match t {
181 TypeDef::Record { name, .. } => name.to_lower_camel_case(),
182 TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
183 TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
184 };
185 n == name
186 }) {
187 sorted.push(type_def);
188 }
189 }
190 }
191
192 sorted
193}
194
195pub fn topological_sort_scc(types: &[TypeDef]) -> Vec<Vec<&TypeDef>> {
198 let mut type_map: HashMap<String, &TypeDef> = HashMap::new();
200 let mut name_to_index: HashMap<String, usize> = HashMap::new();
201 let mut index_to_name: Vec<String> = Vec::new();
202
203 for (i, t) in types.iter().enumerate() {
204 let name = match t {
205 TypeDef::Record { name, .. } => name.to_lower_camel_case(),
206 TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
207 TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
208 };
209 type_map.insert(name.clone(), t);
210 name_to_index.insert(name.clone(), i);
211 index_to_name.push(name);
212 }
213
214 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); types.len()];
216 for (i, type_def) in types.iter().enumerate() {
217 let deps = get_dependencies(type_def);
218 let mut sorted_deps: Vec<String> = deps.into_iter().collect();
219 sorted_deps.sort();
220
221 for dep_name in sorted_deps {
222 if let Some(&dep_idx) = name_to_index.get(&dep_name) {
223 adj[i].push(dep_idx);
226 }
227 }
228 }
229
230 let n = types.len();
232 let mut visited = vec![false; n];
233 let mut stack = Vec::new();
234 let mut on_stack = vec![false; n];
235 let mut ids = vec![-1; n];
236 let mut low = vec![-1; n];
237 let mut id_counter = 0;
238 let mut sccs: Vec<Vec<usize>> = Vec::new();
239
240 for i in 0..n {
241 if !visited[i] {
242 tarjan_dfs(
243 i,
244 &adj,
245 &mut visited,
246 &mut stack,
247 &mut on_stack,
248 &mut ids,
249 &mut low,
250 &mut id_counter,
251 &mut sccs,
252 );
253 }
254 }
255
256 let mut result = Vec::new();
264 for scc_indices in sccs {
265 let mut scc_types = Vec::new();
266 for &idx in &scc_indices {
267 scc_types.push(&types[idx]);
268 }
269 result.push(scc_types);
270 }
271
272 result
273}
274
/// Runs one depth-first step of Tarjan's strongly-connected-components
/// algorithm, starting at node `at`.
///
/// The function visits `at` and recurses into every unvisited successor in
/// `adj[at]`. Each component whose root is `at`, or a node discovered beneath
/// it, is appended to `sccs` as soon as it is complete. A component is
/// therefore always pushed after every component reachable from it.
///
/// * `visited` / `on_stack`: per-node flags. `stack` holds the nodes of the
///   components that are still open.
/// * `ids` / `low`: each node's discovery index and low-link value, both
///   written when the node is visited. `id_counter` is the next discovery
///   index to assign.
///
/// Recursion depth equals the longest DFS path, so an extremely deep
/// dependency chain could exhaust the thread stack.
// The algorithm's state is threaded through explicitly rather than bundled in
// a struct, which is why the long parameter list is allowed here.
#[allow(clippy::too_many_arguments)]
fn tarjan_dfs(
    at: usize,
    adj: &[Vec<usize>],
    visited: &mut [bool],
    stack: &mut Vec<usize>,
    on_stack: &mut [bool],
    ids: &mut [i32],
    low: &mut [i32],
    id_counter: &mut i32,
    sccs: &mut Vec<Vec<usize>>,
) {
    visited[at] = true;
    stack.push(at);
    on_stack[at] = true;
    ids[at] = *id_counter;
    low[at] = *id_counter;
    *id_counter += 1;

    for &to in &adj[at] {
        if !visited[to] {
            tarjan_dfs(to, adj, visited, stack, on_stack, ids, low, id_counter, sccs);
            low[at] = low[at].min(low[to]);
        } else if on_stack[to] {
            // `to` belongs to a component that is still open, so `at` joins it.
            low[at] = low[at].min(ids[to]);
        }
    }

    // `at` is a component root: everything above it on the stack belongs to it.
    if ids[at] == low[at] {
        let mut component = Vec::new();
        loop {
            let node = stack
                .pop()
                .expect("component root `at` is still on the stack");
            on_stack[node] = false;
            component.push(node);
            if node == at {
                break;
            }
        }
        sccs.push(component);
    }
}
318
319fn generate_schema(type_def: &TypeDef, config: &Config) -> String {
320 let mut output = String::new();
321
322 match type_def {
323 TypeDef::Record { name, doc, fields } => {
324 let schema_name = format!("{}Schema", name.to_lower_camel_case());
325
326 if let Some(doc) = doc {
327 output.push_str(&format!("/** Schema for {} */\n", doc));
328 }
329
330 let type_name = name.to_lower_camel_case();
331 output.push_str(&format!("let {}: S.t<{}> = S.object(s => ({{\n", schema_name, type_name));
332
333 for field in fields {
334 output.push_str(&generate_field_schema(field));
335 }
336
337 output.push_str(&format!("}}: {}))\n", type_name));
338 }
339
340 TypeDef::Variant { name, doc, cases } => {
341 let schema_name = format!("{}Schema", name.to_lower_camel_case());
342 let type_name = name.to_lower_camel_case();
343
344 if let Some(doc) = doc {
345 output.push_str(&format!("/** Schema for {} */\n", doc));
346 }
347
348 if cases.iter().all(|c| c.payload.is_none()) {
350 if config.variant_mode == super::VariantMode::Standard {
351 output.push_str(&format!("let {}: S.t<{}> = S.enum([\n", schema_name, type_name));
353 for case in cases {
354 output.push_str(&format!(" {},\n", case.name));
355 }
356 output.push_str("])\n");
357 } else {
358 output.push_str(&format!("let {}: S.t<{}> = S.union([\n", schema_name, type_name));
360
361 for case in cases {
362 output.push_str(&format!(
363 " S.literal(#{}),\n",
364 case.name
365 ));
366 }
367
368 output.push_str("])\n");
369 }
370 } else {
371 output.push_str(&format!("let {}: S.t<{}> = S.union([\n", schema_name, type_name));
373
374 for case in cases {
375 match &case.payload {
376 Some(ty) => {
377 output.push_str(&format!(
379 " {}->S.transform(s => {{\n parser: v => {}(v),\n serializer: v => switch v {{ | {}(x) => x | _ => S.fail(\"Expected {}\") }}\n }}),\n",
380 ty.to_schema(),
381 case.name,
382 case.name,
383 case.name
384 ));
385 }
386 None => {
387 output.push_str(&format!(
389 " S.literal(\"{}\")->S.transform(s => {{\n parser: _ => {},\n serializer: _ => \"{}\"\n }}),\n",
390 case.original_name,
391 case.name,
392 case.original_name
393 ));
394 }
395 }
396 }
397
398 output.push_str("])\n");
399 }
400 }
401
402 TypeDef::Alias { name, doc, target } => {
403 let schema_name = format!("{}Schema", name.to_lower_camel_case());
404
405 if let Some(doc) = doc {
406 output.push_str(&format!("/** Schema for {} */\n", doc));
407 }
408
409 output.push_str(&format!("let {} = {}\n", schema_name, target.to_schema()));
410 }
411 }
412
413 output
414}
415
416fn generate_field_schema(field: &Field) -> String {
417 let method = if field.optional { "fieldOr" } else { "field" };
418 let default = if field.optional {
419 ", None"
420 } else {
421 ""
422 };
423
424 let schema = field.ty.to_schema();
425
426 if field.name != field.original_name {
427 format!(
428 " {}: s.{}(\"{}\", {}{}),\n",
429 field.name,
430 method,
431 field.original_name,
432 schema,
433 default
434 )
435 } else {
436 format!(
437 " {}: s.{}(\"{}\", {}{}),\n",
438 field.name,
439 method,
440 field.name,
441 schema,
442 default
443 )
444 }
445}