Skip to main content

rescript_openapi/codegen/
schema.rs

1// SPDX-License-Identifier: PMPL-1.0-or-later
2// SPDX-FileCopyrightText: 2026 Jonathan D.A. Jewell
3
4//! rescript-schema validator generation with topological sorting
5
6use crate::ir::{ApiSpec, TypeDef, Field, RsType};
7use super::Config;
8use anyhow::Result;
9use heck::ToLowerCamelCase;
10use std::collections::{HashMap, HashSet, VecDeque};
11
12pub fn generate(spec: &ApiSpec, config: &Config) -> Result<String> {
13    let mut output = String::new();
14
15    // Header
16    output.push_str("// SPDX-License-Identifier: PMPL-1.0-or-later\n");
17    output.push_str("// Generated by rescript-openapi - DO NOT EDIT\n");
18    output.push_str(&format!("// Source: {} v{}\n\n", spec.title, spec.version));
19
20    // Import types
21    output.push_str(&format!("open {}Types\n\n", config.module_prefix));
22
23    // Module alias for rescript-schema
24    output.push_str("module S = RescriptSchema.S\n\n");
25
26    // Topologically sort types by dependencies
27    let sorted_types = topological_sort(&spec.types);
28
29    // Generate schema for each type in dependency order
30    for type_def in sorted_types {
31        output.push_str(&generate_schema(type_def, config));
32        output.push('\n');
33    }
34
35    Ok(output)
36}
37
38pub fn generate_schema_only(type_def: &TypeDef, config: &Config) -> String {
39    let mut output = String::new();
40    output.push_str(&generate_schema(type_def, config));
41    output
42}
43
44/// Extract type dependencies from a TypeDef
45pub fn get_dependencies(type_def: &TypeDef) -> HashSet<String> {
46    let mut deps = HashSet::new();
47
48    match type_def {
49        TypeDef::Record { fields, .. } => {
50            for field in fields {
51                collect_type_deps(&field.ty, &mut deps);
52            }
53        }
54        TypeDef::Variant { cases, .. } => {
55            for case in cases {
56                if let Some(ty) = &case.payload {
57                    collect_type_deps(ty, &mut deps);
58                }
59            }
60        }
61        TypeDef::Alias { target, .. } => {
62            collect_type_deps(target, &mut deps);
63        }
64    }
65
66    deps
67}
68
69/// Recursively collect Named type dependencies
70fn collect_type_deps(ty: &RsType, deps: &mut HashSet<String>) {
71    match ty {
72        RsType::Named(name) => {
73            deps.insert(name.to_lower_camel_case());
74        }
75        RsType::Option(inner) | RsType::Array(inner) | RsType::Dict(inner) => {
76            collect_type_deps(inner, deps);
77        }
78        RsType::Tuple(types) => {
79            for t in types {
80                collect_type_deps(t, deps);
81            }
82        }
83        _ => {}
84    }
85}
86
87/// Topologically sort types so dependencies come before dependents
88pub fn topological_sort(types: &[TypeDef]) -> Vec<&TypeDef> {
89    // Build name -> TypeDef map
90    let type_map: HashMap<String, &TypeDef> = types
91        .iter()
92        .map(|t| {
93            let name = match t {
94                TypeDef::Record { name, .. } => name.to_lower_camel_case(),
95                TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
96                TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
97            };
98            (name, t)
99        })
100        .collect();
101
102    // Build dependency graph (adjacency list: type -> types it depends on)
103    let mut deps_map: HashMap<String, HashSet<String>> = HashMap::new();
104    let mut all_names: Vec<String> = Vec::new();
105
106    for type_def in types {
107        let name = match type_def {
108            TypeDef::Record { name, .. } => name.to_lower_camel_case(),
109            TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
110            TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
111        };
112        all_names.push(name.clone());
113
114        let deps = get_dependencies(type_def);
115        // Only keep deps that are actually in our type set
116        let filtered_deps: HashSet<String> = deps
117            .into_iter()
118            .filter(|d| type_map.contains_key(d))
119            .collect();
120        deps_map.insert(name, filtered_deps);
121    }
122
123    // Kahn's algorithm for topological sort
124    // Calculate in-degree (how many types depend on this type)
125    let mut in_degree: HashMap<String, usize> = HashMap::new();
126    for name in &all_names {
127        in_degree.insert(name.clone(), 0);
128    }
129
130    // Build reverse graph: for each type, which types depend on it
131    for (name, deps) in &deps_map {
132        // Each dependency means 'name' has an incoming edge
133        // (dep must come before name in the sorted order)
134        *in_degree.get_mut(name).unwrap() += deps.len();
135    }
136
137    // Start with types that have no dependencies (sorted for deterministic order)
138    let mut zero_degree: Vec<String> = in_degree
139        .iter()
140        .filter(|(_, &degree)| degree == 0)
141        .map(|(name, _)| name.clone())
142        .collect();
143    zero_degree.sort();
144    let mut queue: VecDeque<String> = zero_degree.into_iter().collect();
145
146    let mut sorted: Vec<&TypeDef> = Vec::new();
147
148    while let Some(name) = queue.pop_front() {
149        if let Some(&type_def) = type_map.get(&name) {
150            sorted.push(type_def);
151        }
152
153        // For each type that depends on this one, decrease its in-degree
154        // Collect newly ready types and sort them for deterministic order
155        let mut newly_ready: Vec<String> = Vec::new();
156        for (other_name, other_deps) in &deps_map {
157            if other_deps.contains(&name) {
158                let degree = in_degree.get_mut(other_name).unwrap();
159                *degree -= 1;
160                if *degree == 0 {
161                    newly_ready.push(other_name.clone());
162                }
163            }
164        }
165        newly_ready.sort();
166        for ready_name in newly_ready {
167            queue.push_back(ready_name);
168        }
169    }
170
171    // If we didn't get all types, there's a cycle - just append remaining
172    if sorted.len() < types.len() {
173        for type_def in types {
174            let name = match type_def {
175                TypeDef::Record { name, .. } => name.to_lower_camel_case(),
176                TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
177                TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
178            };
179            if !sorted.iter().any(|t| {
180                let n = match t {
181                    TypeDef::Record { name, .. } => name.to_lower_camel_case(),
182                    TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
183                    TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
184                };
185                n == name
186            }) {
187                sorted.push(type_def);
188            }
189        }
190    }
191
192    sorted
193}
194
195/// Sort types into Strongly Connected Components (SCCs) in topological order
196/// Returns a list of components, where each component is a list of mutually recursive types.
197pub fn topological_sort_scc(types: &[TypeDef]) -> Vec<Vec<&TypeDef>> {
198    // Build name -> TypeDef map and index map
199    let mut type_map: HashMap<String, &TypeDef> = HashMap::new();
200    let mut name_to_index: HashMap<String, usize> = HashMap::new();
201    let mut index_to_name: Vec<String> = Vec::new();
202
203    for (i, t) in types.iter().enumerate() {
204        let name = match t {
205            TypeDef::Record { name, .. } => name.to_lower_camel_case(),
206            TypeDef::Variant { name, .. } => name.to_lower_camel_case(),
207            TypeDef::Alias { name, .. } => name.to_lower_camel_case(),
208        };
209        type_map.insert(name.clone(), t);
210        name_to_index.insert(name.clone(), i);
211        index_to_name.push(name);
212    }
213
214    // Build adjacency list
215    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); types.len()];
216    for (i, type_def) in types.iter().enumerate() {
217        let deps = get_dependencies(type_def);
218        let mut sorted_deps: Vec<String> = deps.into_iter().collect();
219        sorted_deps.sort();
220        
221        for dep_name in sorted_deps {
222            if let Some(&dep_idx) = name_to_index.get(&dep_name) {
223                // Dependency: i depends on dep_idx
224                // For Tarjan's on dependency graph, edge i -> dep_idx
225                adj[i].push(dep_idx);
226            }
227        }
228    }
229
230    // Tarjan's Algorithm
231    let n = types.len();
232    let mut visited = vec![false; n];
233    let mut stack = Vec::new();
234    let mut on_stack = vec![false; n];
235    let mut ids = vec![-1; n];
236    let mut low = vec![-1; n];
237    let mut id_counter = 0;
238    let mut sccs: Vec<Vec<usize>> = Vec::new();
239
240    for i in 0..n {
241        if !visited[i] {
242            tarjan_dfs(
243                i,
244                &adj,
245                &mut visited,
246                &mut stack,
247                &mut on_stack,
248                &mut ids,
249                &mut low,
250                &mut id_counter,
251                &mut sccs,
252            );
253        }
254    }
255
256    // Tarjan's returns SCCs in reverse topological order (leaves first)
257    // We want dependencies first, so we use the order as is (reverse topo).
258    // Wait, Tarjan returns reverse topological order of the condensation graph?
259    // "The SCCs are found in reverse topological order."
260    // If A depends on B, and they are in different SCCs, B's SCC will be found first.
261    // This is exactly what we want: print dependencies (B) before dependents (A).
262
263    let mut result = Vec::new();
264    for scc_indices in sccs {
265        let mut scc_types = Vec::new();
266        for &idx in &scc_indices {
267            scc_types.push(&types[idx]);
268        }
269        result.push(scc_types);
270    }
271
272    result
273}
274
/// One depth-first visit of Tarjan's strongly-connected-components algorithm.
///
/// Visits `at` and everything reachable from it that has not been visited
/// yet, appending each completed component to `sccs`. A component lists its
/// nodes in the order they are popped from `stack`, so the node that started
/// the component comes last.
///
/// * `adj` - adjacency list; `adj[v]` holds the successors of `v`.
/// * `visited`, `on_stack` - per-node flags, indexed like `adj`.
/// * `stack` - nodes of the components currently being built.
/// * `ids` - discovery index assigned to each visited node.
/// * `low` - smallest discovery index known to be reachable from each node.
/// * `id_counter` - next discovery index to hand out.
///
/// # Panics
///
/// Panics if `at` or any successor index is out of bounds for the slices.
// The algorithm's bookkeeping is threaded through explicitly instead of being
// bundled into a struct, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn tarjan_dfs(
    at: usize,
    adj: &[Vec<usize>],
    visited: &mut [bool],
    stack: &mut Vec<usize>,
    on_stack: &mut [bool],
    ids: &mut [i32],
    low: &mut [i32],
    id_counter: &mut i32,
    sccs: &mut Vec<Vec<usize>>,
) {
    visited[at] = true;
    stack.push(at);
    on_stack[at] = true;
    ids[at] = *id_counter;
    low[at] = *id_counter;
    *id_counter += 1;

    for &to in &adj[at] {
        if !visited[to] {
            tarjan_dfs(to, adj, visited, stack, on_stack, ids, low, id_counter, sccs);
            low[at] = low[at].min(low[to]);
        } else if on_stack[to] {
            // Back edge into the component under construction.
            low[at] = low[at].min(ids[to]);
        }
    }

    // `at` is the root of a component when nothing reachable from it leads
    // back to an earlier node: everything above it on the stack belongs to it.
    if ids[at] == low[at] {
        let mut component = Vec::new();
        // `at` was pushed on entry and is still on the stack, so the loop
        // always ends through the `break`.
        while let Some(node) = stack.pop() {
            on_stack[node] = false;
            component.push(node);
            if node == at {
                break;
            }
        }
        sccs.push(component);
    }
}
318
319fn generate_schema(type_def: &TypeDef, config: &Config) -> String {
320    let mut output = String::new();
321
322    match type_def {
323        TypeDef::Record { name, doc, fields } => {
324            let schema_name = format!("{}Schema", name.to_lower_camel_case());
325
326            if let Some(doc) = doc {
327                output.push_str(&format!("/** Schema for {} */\n", doc));
328            }
329
330            let type_name = name.to_lower_camel_case();
331            output.push_str(&format!("let {}: S.t<{}> = S.object(s => ({{\n", schema_name, type_name));
332
333            for field in fields {
334                output.push_str(&generate_field_schema(field));
335            }
336
337            output.push_str(&format!("}}: {}))\n", type_name));
338        }
339
340        TypeDef::Variant { name, doc, cases } => {
341            let schema_name = format!("{}Schema", name.to_lower_camel_case());
342            let type_name = name.to_lower_camel_case();
343
344            if let Some(doc) = doc {
345                output.push_str(&format!("/** Schema for {} */\n", doc));
346            }
347
348            // String enum variant
349            if cases.iter().all(|c| c.payload.is_none()) {
350                if config.variant_mode == super::VariantMode::Standard {
351                    // Standard variant - use S.enum
352                    output.push_str(&format!("let {}: S.t<{}> = S.enum([\n", schema_name, type_name));
353                    for case in cases {
354                        output.push_str(&format!("  {},\n", case.name));
355                    }
356                    output.push_str("])\n");
357                } else {
358                    // Polymorphic variant - use S.union with literals
359                    output.push_str(&format!("let {}: S.t<{}> = S.union([\n", schema_name, type_name));
360
361                    for case in cases {
362                        output.push_str(&format!(
363                            "  S.literal(#{}),\n",
364                            case.name
365                        ));
366                    }
367
368                    output.push_str("])\n");
369                }
370            } else {
371                // oneOf/anyOf variant - wrap each referenced type's schema
372                output.push_str(&format!("let {}: S.t<{}> = S.union([\n", schema_name, type_name));
373
374                for case in cases {
375                    match &case.payload {
376                        Some(ty) => {
377                            // Wrap the inner schema to transform to variant constructor
378                            output.push_str(&format!(
379                                "  {}->S.transform(s => {{\n    parser: v => {}(v),\n    serializer: v => switch v {{ | {}(x) => x | _ => S.fail(\"Expected {}\") }}\n  }}),\n",
380                                ty.to_schema(),
381                                case.name,
382                                case.name,
383                                case.name
384                            ));
385                        }
386                        None => {
387                            // No payload - literal variant
388                            output.push_str(&format!(
389                                "  S.literal(\"{}\")->S.transform(s => {{\n    parser: _ => {},\n    serializer: _ => \"{}\"\n  }}),\n",
390                                case.original_name,
391                                case.name,
392                                case.original_name
393                            ));
394                        }
395                    }
396                }
397
398                output.push_str("])\n");
399            }
400        }
401
402        TypeDef::Alias { name, doc, target } => {
403            let schema_name = format!("{}Schema", name.to_lower_camel_case());
404
405            if let Some(doc) = doc {
406                output.push_str(&format!("/** Schema for {} */\n", doc));
407            }
408
409            output.push_str(&format!("let {} = {}\n", schema_name, target.to_schema()));
410        }
411    }
412
413    output
414}
415
416fn generate_field_schema(field: &Field) -> String {
417    let method = if field.optional { "fieldOr" } else { "field" };
418    let default = if field.optional {
419        ", None"
420    } else {
421        ""
422    };
423
424    let schema = field.ty.to_schema();
425
426    if field.name != field.original_name {
427        format!(
428            "  {}: s.{}(\"{}\", {}{}),\n",
429            field.name,
430            method,
431            field.original_name,
432            schema,
433            default
434        )
435    } else {
436        format!(
437            "  {}: s.{}(\"{}\", {}{}),\n",
438            field.name,
439            method,
440            field.name,
441            schema,
442            default
443        )
444    }
445}