Skip to main content

olai_codegen/
openapi_enrich.rs

1use std::collections::HashSet;
2use std::fs;
3use std::path::Path;
4
5use glob::glob;
6use protobuf::Message;
7use serde_json::Map as JsonMap;
8use serde_yaml::Value as YamlValue;
9
10use crate::Result;
11use crate::parsing::{CodeGenMetadata, parse_file_descriptor_set};
12
13/// Run Pass 0 (gnostic ref fix), Pass 1 (validation enrichment), and optionally Pass 2 (path/body dedup).
14pub fn run(
15    spec: &Path,
16    jsonschema_dir: &Path,
17    camel_case: bool,
18    descriptors: Option<&Path>,
19) -> Result<()> {
20    let spec_str = fs::read_to_string(spec)
21        .map_err(|e| crate::Error::Build(format!("Failed to read {}: {}", spec.display(), e)))?;
22    let mut doc: YamlValue = serde_yaml::from_str(&spec_str).map_err(|e| {
23        crate::Error::Build(format!("Failed to parse YAML {}: {}", spec.display(), e))
24    })?;
25
26    fix_gnostic_refs(&mut doc);
27    enrich_from_jsonschema(&mut doc, jsonschema_dir, camel_case)?;
28
29    if let Some(desc_path) = descriptors {
30        let bytes = fs::read(desc_path).map_err(|e| {
31            crate::Error::Build(format!(
32                "Failed to read descriptors {}: {}",
33                desc_path.display(),
34                e
35            ))
36        })?;
37        let fds = protobuf::descriptor::FileDescriptorSet::parse_from_bytes(&bytes)
38            .map_err(|e| crate::Error::Build(format!("Failed to parse descriptors: {}", e)))?;
39        let metadata = parse_file_descriptor_set(&fds)?;
40        dedup_path_params(&mut doc, &metadata);
41    }
42
43    let out = serde_yaml::to_string(&doc)
44        .map_err(|e| crate::Error::Build(format!("Failed to serialize YAML: {}", e)))?;
45    fs::write(spec, out)
46        .map_err(|e| crate::Error::Build(format!("Failed to write {}: {}", spec.display(), e)))?;
47
48    Ok(())
49}
50
51// ── Pass 0 ────────────────────────────────────────────────────────────────────
52
53/// Rewrite gnostic-generated `#/$defs/a.b.v1.TypeName.schema.strict.json` refs to
54/// valid OpenAPI `#/components/schemas/TypeName` refs throughout the document.
55fn fix_gnostic_refs(doc: &mut YamlValue) {
56    match doc {
57        YamlValue::Mapping(map) => {
58            if let Some(YamlValue::String(s)) = map.get_mut("$ref") {
59                if let Some(fixed) = rewrite_gnostic_ref(s) {
60                    *s = fixed;
61                }
62            }
63            for (_, v) in map.iter_mut() {
64                fix_gnostic_refs(v);
65            }
66        }
67        YamlValue::Sequence(seq) => {
68            for item in seq.iter_mut() {
69                fix_gnostic_refs(item);
70            }
71        }
72        _ => {}
73    }
74}
75
/// Turn `#/$defs/<pkg>.<Type>.schema.strict.json` into
/// `#/components/schemas/<Type>`.
///
/// The type name is everything from the first uppercase character of the
/// stem onwards. Returns `None` when the prefix or suffix does not match, or
/// when the stem contains no uppercase character.
fn rewrite_gnostic_ref(ref_str: &str) -> Option<String> {
    let qualified = ref_str
        .strip_prefix("#/$defs/")
        .and_then(|rest| rest.strip_suffix(".schema.strict.json"))?;
    qualified
        .char_indices()
        .find(|(_, ch)| ch.is_uppercase())
        .map(|(idx, _)| format!("#/components/schemas/{}", &qualified[idx..]))
}
84
85// ── Pass 1 ────────────────────────────────────────────────────────────────────
86
87fn enrich_from_jsonschema(
88    spec: &mut YamlValue,
89    jsonschema_dir: &Path,
90    camel_case: bool,
91) -> Result<()> {
92    let pattern = jsonschema_dir
93        .join("*.schema.strict.bundle.json")
94        .to_string_lossy()
95        .into_owned();
96
97    let mut files: Vec<std::path::PathBuf> = glob(&pattern)
98        .map_err(|e| crate::Error::Build(format!("Glob pattern error: {}", e)))?
99        .filter_map(|r: Result<std::path::PathBuf, _>| r.ok())
100        .collect();
101    files.sort();
102
103    if files.is_empty() {
104        eprintln!(
105            "enrich-openapi: no JSON Schema files found in {}",
106            jsonschema_dir.display()
107        );
108        return Ok(());
109    }
110
111    let mut updated = 0usize;
112    let mut added = 0usize;
113
114    for path in &files {
115        let filename = path
116            .file_name()
117            .and_then(|f: &std::ffi::OsStr| f.to_str())
118            .unwrap_or_default();
119        let type_name = match type_name_from_filename(filename) {
120            Some(n) => n,
121            None => {
122                eprintln!("enrich-openapi: skipping {filename} (cannot extract type name)");
123                continue;
124            }
125        };
126
127        let content = fs::read_to_string(path).map_err(|e| {
128            crate::Error::Build(format!("Failed to read {}: {}", path.display(), e))
129        })?;
130        let bundle: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
131            crate::Error::Build(format!("Failed to parse JSON {}: {}", path.display(), e))
132        })?;
133
134        let defs: JsonMap<String, serde_json::Value> = bundle
135            .get("$defs")
136            .and_then(|v: &serde_json::Value| v.as_object())
137            .cloned()
138            .unwrap_or_default();
139
140        let root_ref = bundle
141            .get("$ref")
142            .and_then(|v: &serde_json::Value| v.as_str())
143            .unwrap_or("");
144        let root_key = root_ref.strip_prefix("#/$defs/").unwrap_or("");
145        let root_schema: serde_json::Value = match defs.get(root_key) {
146            Some(s) => s.clone(),
147            None => {
148                eprintln!(
149                    "enrich-openapi: could not resolve root $ref '{root_ref}' for {type_name}, skipping"
150                );
151                continue;
152            }
153        };
154
155        // Navigate to components.schemas.<TypeName>
156        let schemas = spec
157            .get_mut("components")
158            .and_then(|c| c.get_mut("schemas"));
159        let schemas = match schemas {
160            Some(s) => s,
161            None => {
162                eprintln!("enrich-openapi: openapi.yaml has no components.schemas, skipping");
163                break;
164            }
165        };
166
167        let exists = schemas
168            .as_mapping()
169            .map(|m| m.contains_key(type_name.as_str()))
170            .unwrap_or(false);
171
172        if !exists {
173            let ty = root_schema
174                .get("type")
175                .and_then(|v: &serde_json::Value| v.as_str())
176                .unwrap_or("object");
177            if let Some(map) = schemas.as_mapping_mut() {
178                let mut entry = serde_yaml::Mapping::new();
179                entry.insert(
180                    YamlValue::String("type".into()),
181                    YamlValue::String(ty.to_string()),
182                );
183                map.insert(
184                    YamlValue::String(type_name.clone()),
185                    YamlValue::Mapping(entry),
186                );
187            }
188            added += 1;
189        } else {
190            updated += 1;
191        }
192
193        let oa_schema = schemas
194            .as_mapping_mut()
195            .and_then(|m| m.get_mut(type_name.as_str()));
196        if let Some(oa) = oa_schema {
197            enrich_schema(oa, &root_schema, &defs, camel_case);
198        }
199    }
200
201    println!(
202        "enrich-openapi: enriched {} schemas ({} updated, {} added)",
203        updated + added,
204        updated,
205        added
206    );
207    Ok(())
208}
209
/// Derive the schema type name from a JSON Schema bundle filename.
///
/// The `.schema.strict.bundle.json` suffix is removed and everything from the
/// first uppercase character onwards is returned, e.g.
/// "dda.coordinator.v1.CreateWaveRequest.schema.strict.bundle.json" → "CreateWaveRequest".
/// Returns `None` when the suffix is missing or no uppercase character exists.
fn type_name_from_filename(filename: &str) -> Option<String> {
    filename
        .strip_suffix(".schema.strict.bundle.json")
        .and_then(|qualified| {
            qualified
                .char_indices()
                .find(|(_, ch)| ch.is_uppercase())
                .map(|(idx, _)| qualified[idx..].to_owned())
        })
}
217
/// Convert a snake_case identifier to camelCase.
///
/// Every underscore is dropped and the character that follows a run of
/// underscores is ASCII-uppercased. Text before the first underscore is copied
/// unchanged, so a leading underscore capitalises the first letter and a
/// trailing underscore simply disappears.
fn snake_to_camel(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    let mut segments = s.split('_');

    // The segment before the first underscore keeps its original casing.
    if let Some(head) = segments.next() {
        camel.push_str(head);
    }

    // Empty segments (from consecutive or trailing underscores) add nothing.
    for segment in segments {
        let mut rest = segment.chars();
        if let Some(initial) = rest.next() {
            camel.push(initial.to_ascii_uppercase());
            camel.push_str(rest.as_str());
        }
    }
    camel
}
234
235/// Resolve a `#/$defs/<key>` reference within a bundle's $defs map.
236fn resolve_ref<'a>(
237    ref_str: &str,
238    defs: &'a JsonMap<String, serde_json::Value>,
239) -> Option<&'a serde_json::Value> {
240    let key = ref_str.strip_prefix("#/$defs/")?;
241    defs.get(key)
242}
243
244const VALIDATION_FIELDS: &[&str] = &[
245    "minLength",
246    "maxLength",
247    "pattern",
248    "minimum",
249    "maximum",
250    "exclusiveMinimum",
251    "exclusiveMaximum",
252    "minItems",
253    "maxItems",
254    "enum",
255    "additionalProperties",
256    "required",
257    "description",
258    "title",
259];
260
261/// Merge validation fields from a JSON Schema node into an OpenAPI YAML schema node.
262fn merge_validation(source: &serde_json::Value, target: &mut YamlValue) {
263    for &key in VALIDATION_FIELDS {
264        let val = match source.get(key) {
265            Some(v) => v,
266            None => continue,
267        };
268
269        if key == "exclusiveMinimum" {
270            if let Some(n) = val.as_f64() {
271                // JSON Schema 2020-12 numeric form → OpenAPI 3.0 boolean form
272                yaml_set(
273                    target,
274                    "minimum",
275                    YamlValue::Number(serde_yaml::Number::from(n)),
276                );
277                yaml_set(target, "exclusiveMinimum", YamlValue::Bool(true));
278            }
279            continue;
280        }
281
282        if key == "exclusiveMaximum" {
283            if let Some(n) = val.as_f64() {
284                // JSON Schema 2020-12 numeric form → OpenAPI 3.0 boolean form
285                yaml_set(
286                    target,
287                    "maximum",
288                    YamlValue::Number(serde_yaml::Number::from(n)),
289                );
290                yaml_set(target, "exclusiveMaximum", YamlValue::Bool(true));
291            }
292            continue;
293        }
294
295        yaml_set(target, key, json_to_yaml(val));
296    }
297}
298
299/// Recursively enrich an OpenAPI schema YAML node with validation from a JSON Schema node.
300fn enrich_schema(
301    openapi: &mut YamlValue,
302    json_schema: &serde_json::Value,
303    defs: &JsonMap<String, serde_json::Value>,
304    camel_case: bool,
305) {
306    merge_validation(json_schema, openapi);
307
308    // Recurse into properties
309    if let Some(js_props) = json_schema.get("properties").and_then(|v| v.as_object()) {
310        for (snake_key, js_prop) in js_props {
311            let lookup_key = if camel_case {
312                snake_to_camel(snake_key)
313            } else {
314                snake_key.clone()
315            };
316
317            let resolved: std::borrow::Cow<serde_json::Value> =
318                if let Some(ref_str) = js_prop.get("$ref").and_then(|v| v.as_str()) {
319                    match resolve_ref(ref_str, defs) {
320                        Some(r) => std::borrow::Cow::Borrowed(r),
321                        None => std::borrow::Cow::Borrowed(js_prop),
322                    }
323                } else {
324                    std::borrow::Cow::Borrowed(js_prop)
325                };
326
327            if let Some(oa_prop) = openapi
328                .get_mut("properties")
329                .and_then(|p| p.get_mut(lookup_key.as_str()))
330            {
331                enrich_schema(oa_prop, &resolved, defs, camel_case);
332            }
333        }
334    }
335
336    // Recurse into items
337    if let Some(js_items) = json_schema.get("items") {
338        let resolved: std::borrow::Cow<serde_json::Value> =
339            if let Some(ref_str) = js_items.get("$ref").and_then(|v| v.as_str()) {
340                match resolve_ref(ref_str, defs) {
341                    Some(r) => std::borrow::Cow::Borrowed(r),
342                    None => std::borrow::Cow::Borrowed(js_items),
343                }
344            } else {
345                std::borrow::Cow::Borrowed(js_items)
346            };
347
348        if let Some(oa_items) = openapi.get_mut("items") {
349            enrich_schema(oa_items, &resolved, defs, camel_case);
350        }
351    }
352
353    // Recurse into combiners
354    for combiner in &["allOf", "oneOf", "anyOf"] {
355        if let Some(js_list) = json_schema.get(combiner).and_then(|v| v.as_array()) {
356            if let Some(oa_list) = openapi.get_mut(combiner).and_then(|v| v.as_sequence_mut()) {
357                for (i, js_entry) in js_list.iter().enumerate() {
358                    if i >= oa_list.len() {
359                        break;
360                    }
361                    let resolved: std::borrow::Cow<serde_json::Value> =
362                        if let Some(ref_str) = js_entry.get("$ref").and_then(|v| v.as_str()) {
363                            match resolve_ref(ref_str, defs) {
364                                Some(r) => std::borrow::Cow::Borrowed(r),
365                                None => std::borrow::Cow::Borrowed(js_entry),
366                            }
367                        } else {
368                            std::borrow::Cow::Borrowed(js_entry)
369                        };
370                    enrich_schema(&mut oa_list[i], &resolved, defs, camel_case);
371                }
372            }
373        }
374    }
375}
376
377// ── Pass 2 ────────────────────────────────────────────────────────────────────
378
379fn dedup_path_params(spec: &mut YamlValue, metadata: &CodeGenMetadata) {
380    let mut removed_total = 0usize;
381
382    for service in metadata.services.values() {
383        for method in &service.methods {
384            let path_params: HashSet<String> =
385                method.http_pattern.parameters.iter().cloned().collect();
386            if path_params.is_empty() {
387                continue;
388            }
389
390            let input_type = method
391                .input_type
392                .rfind('.')
393                .map(|i| &method.input_type[i + 1..])
394                .unwrap_or(&method.input_type);
395
396            let schema = spec
397                .get_mut("components")
398                .and_then(|c| c.get_mut("schemas"))
399                .and_then(|s| s.get_mut(input_type));
400
401            let schema = match schema {
402                Some(s) => s,
403                None => continue,
404            };
405
406            // Remove from properties
407            if let Some(props) = schema
408                .get_mut("properties")
409                .and_then(|p| p.as_mapping_mut())
410            {
411                let before = props.len();
412                props.retain(|k, _| {
413                    let key = k.as_str().unwrap_or("");
414                    !path_params.contains(key)
415                });
416                removed_total += before - props.len();
417            }
418
419            // Remove from required array
420            if let Some(required) = schema.get_mut("required").and_then(|r| r.as_sequence_mut()) {
421                required.retain(|v| {
422                    let key = v.as_str().unwrap_or("");
423                    !path_params.contains(key)
424                });
425            }
426        }
427    }
428
429    if removed_total > 0 {
430        println!(
431            "enrich-openapi: dedup removed {removed_total} path-bound field(s) from request body schemas"
432        );
433    }
434}
435
436// ── YAML helpers ──────────────────────────────────────────────────────────────
437
438fn yaml_set(target: &mut YamlValue, key: &str, value: YamlValue) {
439    if let YamlValue::Mapping(m) = target {
440        m.insert(YamlValue::String(key.to_string()), value);
441    }
442}
443
444fn json_to_yaml(v: &serde_json::Value) -> YamlValue {
445    match v {
446        serde_json::Value::Null => YamlValue::Null,
447        serde_json::Value::Bool(b) => YamlValue::Bool(*b),
448        serde_json::Value::Number(n) => {
449            if let Some(i) = n.as_i64() {
450                YamlValue::Number(serde_yaml::Number::from(i))
451            } else if let Some(f) = n.as_f64() {
452                YamlValue::Number(serde_yaml::Number::from(f))
453            } else {
454                YamlValue::String(n.to_string())
455            }
456        }
457        serde_json::Value::String(s) => YamlValue::String(s.clone()),
458        serde_json::Value::Array(arr) => {
459            YamlValue::Sequence(arr.iter().map(json_to_yaml).collect())
460        }
461        serde_json::Value::Object(obj) => {
462            let mut mapping = serde_yaml::Mapping::new();
463            for (k, val) in obj {
464                mapping.insert(YamlValue::String(k.clone()), json_to_yaml(val));
465            }
466            YamlValue::Mapping(mapping)
467        }
468    }
469}