Skip to main content

stepflow_flow/
json_schema.rs

// Copyright 2025 DataStax Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under
// the License.
12
//! Utilities for generating standalone JSON Schema documents from `schemars::JsonSchema` types.
//!
//! This module generates standalone JSON Schema draft 2020-12 documents suitable for
//! code-generation tools such as `datamodel-code-generator`.
17
18use serde_json::Value;
19
/// Selects how references to non-root types are emitted in a generated schema.
#[derive(Clone, Debug)]
pub enum Refs {
    /// Drop the `$defs` section entirely. `$ref`s still name the referenced types,
    /// but nothing in the document defines them. Yields compact schemas meant for
    /// documentation.
    Omit,
    /// Keep every referenced type under `$defs` and point at it locally
    /// (`#/$defs/TypeName`). Yields self-contained schemas usable for validation.
    Local,
    /// Point references at a schema hosted elsewhere. Each `#/$defs/TypeName`
    /// becomes `{base_url}#/$defs/TypeName`, where the payload is `base_url`.
    External(String),
}
33
34/// Generate a standalone JSON Schema document from a type implementing JsonSchema.
35///
36/// This function generates a compact schema without `$defs` - any referenced types
37/// will appear as `$ref` without definitions. This is suitable for component schemas
38/// used for documentation purposes.
39///
40/// For a complete schema with all `$defs` included, use [`generate_json_schema_with_defs`].
41pub fn generate_json_schema<T: schemars::JsonSchema>() -> Value {
42    generate_json_schema_with_refs::<T>(Refs::Omit)
43}
44
45/// Generate a standalone JSON Schema document with all `$defs` included.
46///
47/// This produces a fully self-contained schema suitable for validation
48/// without external references.
49pub fn generate_json_schema_with_defs<T: schemars::JsonSchema>() -> Value {
50    generate_json_schema_with_refs::<T>(Refs::Local)
51}
52
53/// Generate a JSON Schema document with configurable reference handling.
54///
55/// # Arguments
56/// * `refs` - Controls how external type references are handled:
57///   - `Refs::Omit` - Omit `$defs`, just reference by name
58///   - `Refs::Local` - Include schemas in `$defs` with local references
59///   - `Refs::External(url)` - Reference schemas from an external URL
60pub fn generate_json_schema_with_refs<T: schemars::JsonSchema>(refs: Refs) -> Value {
61    generate_json_schema_custom::<T>(refs, |_| {})
62}
63
64/// Generate a JSON Schema document with configurable reference handling and
65/// additional types seeded into `$defs`.
66///
67/// The `seed` callback receives a `&mut SchemaGenerator` before the root schema
68/// is finalised.  Calling `generator.subschema_for::<ExtraType>()` inside the
69/// callback ensures the type (and all its transitive deps) appear in `$defs`
70/// even when they are not reachable from the root type `T`.
71pub fn generate_json_schema_custom<T: schemars::JsonSchema>(
72    refs: Refs,
73    seed: impl FnOnce(&mut schemars::SchemaGenerator),
74) -> Value {
75    let settings = schemars::generate::SchemaSettings::draft2020_12();
76    let mut generator = settings.into_generator();
77    seed(&mut generator);
78    let schema = generator.into_root_schema_for::<T>();
79    let mut json = serde_json::to_value(schema).expect("Failed to serialize schema");
80
81    match refs {
82        Refs::Omit => {
83            // Remove $defs entirely
84            if let Some(obj) = json.as_object_mut() {
85                obj.remove("$defs");
86            }
87        }
88        Refs::Local => {
89            finalize_discriminators(&mut json);
90        }
91        Refs::External(ref base_url) => {
92            finalize_discriminators(&mut json);
93            // Transform #/$defs/X references to {base_url}#/$defs/X
94            transform_refs_external(&mut json, base_url);
95        }
96    }
97
98    json
99}
100
101/// Post-process a generated schema to make discriminated unions work correctly
102/// with code generators like `datamodel-code-generator`.
103///
104/// This runs three steps in order:
105/// 1. **Extract inline `oneOf` variants** to the definitions section — variants
106///    are keyed by their `title` attribute, so code generators produce the
107///    expected type names.
108/// 2. **Build discriminator mappings** by resolving `$ref` → definitions to read
109///    tag `const` values and populate `discriminator.mapping`.
110/// 3. **Add `default` alongside `const`** for discriminator tag properties —
111///    `datamodel-code-generator` uses `default` (not `const`) to set tag values.
112///
113/// Schemas are resolved using `#/$defs/` references. For OpenAPI documents
114/// where schemas live under `#/components/schemas/`, use
115/// [`finalize_discriminators_with_prefix`].
116pub fn finalize_discriminators(root: &mut Value) {
117    finalize_discriminators_with_prefix(root, "#/$defs/");
118}
119
120/// Like [`finalize_discriminators`], but with a configurable `$ref` prefix.
121///
122/// The `ref_prefix` determines both where definitions are stored in the JSON
123/// tree and the `$ref` prefix used in references:
124/// - `"#/$defs/"` — JSON Schema (definitions at `root.$defs`)
125/// - `"#/components/schemas/"` — OpenAPI (definitions at `root.components.schemas`)
126pub fn finalize_discriminators_with_prefix(root: &mut Value, ref_prefix: &str) {
127    flatten_string_enum_oneofs(root);
128    convert_nullable_anyof_to_oneof(root);
129    extract_inline_oneof_to_defs(root, ref_prefix);
130    build_discriminator_mappings(root, ref_prefix);
131    add_defaults_to_discriminator_consts(root, ref_prefix);
132}
133
/// Derive the JSON pointer (RFC 6901) of the definitions section from a `$ref` prefix.
///
/// The leading `#` (URI fragment marker) and a single trailing `/` are removed
/// independently, so a prefix that lacks either one is still handled:
/// - `"#/$defs/"` → `"/$defs"`
/// - `"#/components/schemas/"` → `"/components/schemas"`
/// - `"#/$defs"` → `"/$defs"`
fn defs_pointer(ref_prefix: &str) -> &str {
    let without_hash = ref_prefix.strip_prefix('#').unwrap_or(ref_prefix);
    // With no trailing slash, fall back to the hash-stripped string rather than
    // the raw prefix; otherwise the `#` would leak into the pointer.
    without_hash.strip_suffix('/').unwrap_or(without_hash)
}
145
146/// Navigate to (and create if needed) the definitions object at the path
147/// implied by `ref_prefix`.
148fn get_or_create_defs_mut<'a>(
149    root: &'a mut Value,
150    ref_prefix: &str,
151) -> &'a mut serde_json::Map<String, Value> {
152    let pointer = defs_pointer(ref_prefix);
153    let mut current = root;
154    for segment in pointer.split('/').filter(|s| !s.is_empty()) {
155        current = current
156            .as_object_mut()
157            .unwrap()
158            .entry(segment.to_string())
159            .or_insert_with(|| Value::Object(serde_json::Map::new()));
160    }
161    current.as_object_mut().unwrap()
162}
163
164/// Convert `oneOf` schemas of string-const variants into simple string enums.
165///
166/// schemars generates documented Rust enums as `oneOf` arrays with per-variant
167/// `const` + `description` entries.  This is valid JSON Schema but code generators
168/// (openapi-generator, datamodel-code-generator) produce broken or overly complex
169/// types because every variant resolves to `str`.
170///
171/// This rewrites such schemas into `{ "type": "string", "enum": ["a", "b", ...] }`
172/// which all code generators handle correctly.  Schemas that have a `discriminator`
173/// are left untouched — those are tagged unions, not simple enums.
174fn flatten_string_enum_oneofs(root: &mut Value) {
175    match root {
176        Value::Object(obj) => {
177            // Check if this object is a string-const oneOf (without a discriminator)
178            let should_flatten = !obj.contains_key("discriminator")
179                && obj
180                    .get("oneOf")
181                    .and_then(|v| v.as_array())
182                    .is_some_and(|arr| {
183                        !arr.is_empty()
184                            && arr.iter().all(|v| {
185                                v.get("type").and_then(|t| t.as_str()) == Some("string")
186                                    && v.get("const").is_some()
187                            })
188                    });
189
190            if should_flatten {
191                if let Some(Value::Array(one_of)) = obj.remove("oneOf") {
192                    let enum_values: Vec<Value> = one_of
193                        .iter()
194                        .filter_map(|v| v.get("const").cloned())
195                        .collect();
196
197                    // Append per-variant descriptions to the enum's description
198                    let case_docs: Vec<String> = one_of
199                        .iter()
200                        .filter_map(|v| {
201                            let name = v.get("const")?.as_str()?;
202                            let desc = v.get("description")?.as_str()?;
203                            Some(format!("* `{name}`: {desc}"))
204                        })
205                        .collect();
206
207                    if !case_docs.is_empty() {
208                        let existing = obj
209                            .get("description")
210                            .and_then(|d| d.as_str())
211                            .unwrap_or_default();
212                        let full = format!("{existing}\n\nCases:\n{}", case_docs.join("\n"));
213                        obj.insert("description".to_string(), Value::String(full));
214                    }
215
216                    obj.insert("type".to_string(), Value::String("string".to_string()));
217                    obj.insert("enum".to_string(), Value::Array(enum_values));
218                }
219            } else {
220                // Recurse into all values
221                for v in obj.values_mut() {
222                    flatten_string_enum_oneofs(v);
223                }
224            }
225        }
226        Value::Array(arr) => {
227            for v in arr.iter_mut() {
228                flatten_string_enum_oneofs(v);
229            }
230        }
231        _ => {}
232    }
233}
234
235/// Convert nullable `anyOf` patterns to `oneOf`.
236///
237/// schemars generates `Option<T>` as `anyOf: [T, {type: null}]`, but
238/// code generators like openapi-generator handle `oneOf` nullable patterns
239/// correctly (the existing `fix_any_type_from_dict` post-processing in the
240/// Python codegen handles `OneOf` references).  This matches the schema
241/// output that utoipa previously produced.
242fn convert_nullable_anyof_to_oneof(root: &mut Value) {
243    match root {
244        Value::Object(obj) => {
245            // Check for anyOf with exactly one null variant (nullable pattern)
246            let is_nullable_anyof =
247                obj.get("anyOf")
248                    .and_then(|v| v.as_array())
249                    .is_some_and(|arr| {
250                        arr.len() == 2
251                            && arr
252                                .iter()
253                                .any(|v| v.get("type").and_then(|t| t.as_str()) == Some("null"))
254                    });
255
256            if is_nullable_anyof && let Some(any_of) = obj.remove("anyOf") {
257                obj.insert("oneOf".to_string(), any_of);
258            }
259
260            for v in obj.values_mut() {
261                convert_nullable_anyof_to_oneof(v);
262            }
263        }
264        Value::Array(arr) => {
265            for v in arr.iter_mut() {
266                convert_nullable_anyof_to_oneof(v);
267            }
268        }
269        _ => {}
270    }
271}
272
273/// Extract inline oneOf variants to the definitions section in schemas with
274/// discriminators.
275///
276/// schemars inlines all variants in the `oneOf` array. Discriminator mappings
277/// require `$ref` paths, so this extracts inline variants to definitions (using
278/// their `title` as the key) and replaces them with `$ref` entries.
279fn extract_inline_oneof_to_defs(root: &mut Value, ref_prefix: &str) {
280    let mut extractions: Vec<(String, Value)> = Vec::new();
281    extract_inline_oneof_recursive(root, ref_prefix, &mut extractions);
282
283    if extractions.is_empty() {
284        return;
285    }
286
287    let defs = get_or_create_defs_mut(root, ref_prefix);
288
289    for (key, schema) in extractions {
290        if let Some(existing) = defs.get_mut(&key) {
291            // Collision: the variant's title matches an existing $defs key (the inner
292            // type). Merge the discriminator tag property into the existing entry so
293            // that code generators can read the tag const value.
294            merge_tag_properties(existing, &schema);
295        } else {
296            defs.insert(key, schema);
297        }
298    }
299}
300
301fn extract_inline_oneof_recursive(
302    value: &mut Value,
303    ref_prefix: &str,
304    extractions: &mut Vec<(String, Value)>,
305) {
306    match value {
307        Value::Object(obj) => {
308            if obj.contains_key("discriminator")
309                && let Some(Value::Array(one_of)) = obj.get_mut("oneOf")
310            {
311                for variant in one_of.iter_mut() {
312                    // Skip variants that are already pure $ref entries
313                    if variant
314                        .as_object()
315                        .is_some_and(|o| o.len() == 1 && o.contains_key("$ref"))
316                    {
317                        continue;
318                    }
319                    // Extract inline variants with titles to $defs
320                    if let Some(title) = variant
321                        .get("title")
322                        .and_then(|t| t.as_str())
323                        .map(|s| s.to_string())
324                    {
325                        extractions.push((title.clone(), variant.clone()));
326                        *variant = serde_json::json!({ "$ref": format!("{ref_prefix}{title}") });
327                    }
328                }
329            }
330
331            for v in obj.values_mut() {
332                extract_inline_oneof_recursive(v, ref_prefix, extractions);
333            }
334        }
335        Value::Array(arr) => {
336            for v in arr.iter_mut() {
337                extract_inline_oneof_recursive(v, ref_prefix, extractions);
338            }
339        }
340        _ => {}
341    }
342}
343
344/// Merge discriminator tag properties from an extracted variant into an existing `$defs` entry.
345///
346/// When a variant's title matches an existing `$defs` key (e.g., `StepflowPluginConfig`
347/// is both the inner type and the variant title), this adds the tag `const` property
348/// and updates `required` so code generators can resolve the discriminator tag value.
349fn merge_tag_properties(existing: &mut Value, variant: &Value) {
350    // Merge properties (adds tag property from variant)
351    if let Some(variant_props) = variant.get("properties").and_then(|p| p.as_object()) {
352        let def_props = existing
353            .as_object_mut()
354            .unwrap()
355            .entry("properties")
356            .or_insert_with(|| Value::Object(serde_json::Map::new()))
357            .as_object_mut()
358            .unwrap();
359        for (key, value) in variant_props {
360            def_props
361                .entry(key.clone())
362                .or_insert_with(|| value.clone());
363        }
364    }
365
366    // Merge required arrays
367    if let Some(variant_required) = variant.get("required").and_then(|r| r.as_array()) {
368        let def_required = existing
369            .as_object_mut()
370            .unwrap()
371            .entry("required")
372            .or_insert_with(|| Value::Array(Vec::new()))
373            .as_array_mut()
374            .unwrap();
375        for req in variant_required {
376            if !def_required.contains(req) {
377                def_required.push(req.clone());
378            }
379        }
380    }
381}
382
383/// Build discriminator mappings by resolving `$ref` → definition entries
384/// and reading tag `const` values.
385fn build_discriminator_mappings(root: &mut Value, ref_prefix: &str) {
386    let defs = root
387        .pointer(defs_pointer(ref_prefix))
388        .and_then(|v| v.as_object())
389        .cloned();
390
391    // Recursively process all schemas in the document
392    build_discriminator_mappings_recursive(root, ref_prefix, defs.as_ref());
393}
394
395fn build_discriminator_mappings_recursive(
396    value: &mut Value,
397    ref_prefix: &str,
398    defs: Option<&serde_json::Map<String, Value>>,
399) {
400    let Some(defs) = defs else { return };
401    match value {
402        Value::Object(obj) => {
403            // Check if this object has a discriminator that needs mapping completion
404            let needs_mapping = obj
405                .get("discriminator")
406                .is_some_and(|d| d.get("propertyName").is_some());
407
408            if needs_mapping
409                && let Some(property_name) = obj
410                    .get("discriminator")
411                    .and_then(|d| d.get("propertyName"))
412                    .and_then(|p| p.as_str())
413                    .map(|s| s.to_string())
414                && let Some(one_of) = obj.get("oneOf").and_then(|v| v.as_array())
415            {
416                let mut mapping = serde_json::Map::new();
417
418                for variant in one_of {
419                    // Resolve $ref to the definition entry
420                    if let Some(ref_path) = variant.get("$ref").and_then(|r| r.as_str())
421                        && let Some(def_key) = ref_path.strip_prefix(ref_prefix)
422                        && let Some(def_schema) = defs.get(def_key)
423                    {
424                        // Read the const value for the discriminator property
425                        if let Some(const_val) = def_schema
426                            .get("properties")
427                            .and_then(|p| p.get(&property_name))
428                            .and_then(|p| p.get("const"))
429                            .and_then(|c| c.as_str())
430                        {
431                            mapping
432                                .insert(const_val.to_string(), Value::String(ref_path.to_string()));
433                        }
434                    }
435                }
436
437                if !mapping.is_empty()
438                    && let Some(disc) = obj.get_mut("discriminator").and_then(|d| d.as_object_mut())
439                {
440                    // Replace the mapping entirely — the post-processing steps
441                    // (extract_inline_oneof_to_defs) may have changed $ref paths
442                    disc.insert("mapping".to_string(), Value::Object(mapping));
443                }
444            }
445
446            // Recurse into all values
447            for v in obj.values_mut() {
448                build_discriminator_mappings_recursive(v, ref_prefix, Some(defs));
449            }
450        }
451        Value::Array(arr) => {
452            for v in arr.iter_mut() {
453                build_discriminator_mappings_recursive(v, ref_prefix, Some(defs));
454            }
455        }
456        _ => {}
457    }
458}
459
460/// Add `default` alongside `const` for discriminator tag properties in definitions.
461///
462/// Code generators like `datamodel-code-generator` use `default` (not `const`)
463/// to determine tag values for generated tagged union types. This walks all
464/// definition entries referenced by discriminator mappings and adds `default`
465/// equal to `const` for the discriminator tag property.
466fn add_defaults_to_discriminator_consts(root: &mut Value, ref_prefix: &str) {
467    let Some(root_obj) = root.as_object() else {
468        return;
469    };
470
471    // Collect (def_key, property_name) pairs from all discriminator mappings
472    let mut targets: Vec<(String, String)> = Vec::new();
473    collect_discriminator_targets(root_obj, ref_prefix, &mut targets);
474
475    if targets.is_empty() {
476        return;
477    }
478
479    // Apply defaults to the collected targets
480    let pointer = defs_pointer(ref_prefix);
481    let Some(defs) = root.pointer_mut(pointer).and_then(|d| d.as_object_mut()) else {
482        return;
483    };
484
485    for (def_key, property_name) in targets {
486        if let Some(def_schema) = defs.get_mut(&def_key)
487            && let Some(prop) = def_schema
488                .get_mut("properties")
489                .and_then(|p| p.get_mut(&property_name))
490                .and_then(|p| p.as_object_mut())
491            && let Some(const_val) = prop.get("const").cloned()
492        {
493            prop.entry("default").or_insert(const_val);
494        }
495    }
496}
497
498fn collect_discriminator_targets(
499    value: &serde_json::Map<String, Value>,
500    ref_prefix: &str,
501    targets: &mut Vec<(String, String)>,
502) {
503    if let Some(disc) = value.get("discriminator").and_then(|d| d.as_object())
504        && let Some(property_name) = disc.get("propertyName").and_then(|p| p.as_str())
505        && let Some(mapping) = disc.get("mapping").and_then(|m| m.as_object())
506    {
507        for ref_path in mapping.values() {
508            if let Some(ref_str) = ref_path.as_str()
509                && let Some(def_key) = ref_str.strip_prefix(ref_prefix)
510            {
511                targets.push((def_key.to_string(), property_name.to_string()));
512            }
513        }
514    }
515
516    // Recurse into nested objects
517    for v in value.values() {
518        if let Some(obj) = v.as_object() {
519            collect_discriminator_targets(obj, ref_prefix, targets);
520        } else if let Some(arr) = v.as_array() {
521            for item in arr {
522                if let Some(obj) = item.as_object() {
523                    collect_discriminator_targets(obj, ref_prefix, targets);
524                }
525            }
526        }
527    }
528}
529
530/// Recursively transform `#/$defs/X` references to `{base_url}#/$defs/X`.
531fn transform_refs_external(value: &mut Value, base_url: &str) {
532    match value {
533        Value::Object(map) => {
534            if let Some(Value::String(ref_str)) = map.get_mut("$ref")
535                && let Some(name) = ref_str.strip_prefix("#/$defs/")
536            {
537                *ref_str = format!("{base_url}#/$defs/{name}");
538            }
539            for v in map.values_mut() {
540                transform_refs_external(v, base_url);
541            }
542        }
543        Value::Array(arr) => {
544            for v in arr.iter_mut() {
545                transform_refs_external(v, base_url);
546            }
547        }
548        _ => {}
549    }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Base URL used when exercising external reference rewriting.
    const FLOW_SCHEMA_URL: &str = "https://stepflow.org/schemas/v1/flow.json";

    #[test]
    fn test_generate_json_schema_has_required_fields() {
        use crate::schema::SchemaRef;

        let schema = generate_json_schema::<SchemaRef>();

        // The document must declare the JSON Schema 2020-12 dialect...
        let dialect = schema.get("$schema").and_then(Value::as_str);
        assert_eq!(
            dialect,
            Some("https://json-schema.org/draft/2020-12/schema")
        );

        // ...and carry a title.
        assert!(schema.get("title").is_some());
    }

    #[test]
    fn test_generate_json_schema_with_defs() {
        use crate::workflow::Flow;

        let schema = generate_json_schema_with_defs::<Flow>();

        // A self-contained schema keeps its dialect, title and definitions.
        for key in ["$schema", "title", "$defs"] {
            assert!(
                schema.get(key).is_some(),
                "generated schema is missing `{key}`"
            );
        }
    }

    #[test]
    fn test_transform_refs_external() {
        let mut document = serde_json::json!({
            "$ref": "#/$defs/MyType",
            "nested": { "$ref": "#/$defs/OtherType" },
            "array": [ { "$ref": "#/$defs/ArrayItem" } ]
        });

        transform_refs_external(&mut document, FLOW_SCHEMA_URL);

        let expected = serde_json::json!({
            "$ref": "https://stepflow.org/schemas/v1/flow.json#/$defs/MyType",
            "nested": {
                "$ref": "https://stepflow.org/schemas/v1/flow.json#/$defs/OtherType"
            },
            "array": [
                { "$ref": "https://stepflow.org/schemas/v1/flow.json#/$defs/ArrayItem" }
            ]
        });
        assert_eq!(document, expected);
    }
}