Skip to main content

openapi_to_rust/
patterns.rs

1use serde_json::Value;
2
/// How the variants of an `anyOf` schema should be interpreted when generating Rust types.
///
/// Produced by [`analyze_anyof_semantics`]. The enum carries no data, so it is
/// `Copy` and can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AnyOfSemantics {
    /// Nullable type pattern: exactly two variants, `[Type, null]`.
    Nullable,
    /// Union type pattern: every variant is a complex type (a `$ref`, an
    /// `"type": "object"` schema, or a schema with `properties`).
    Union,
    /// Flexible union: any other mix of variants, e.g. refs alongside primitives.
    FlexibleUnion,
}
12
/// Description of a discriminated ("tagged") union detected in a schema.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscriminatedUnionInfo {
    /// Name of the property whose value selects the variant
    /// (the OpenAPI `discriminator.propertyName`, or an assumed name when auto-detected).
    pub field: String,
    /// Discriminator mappings gathered by this module. Depending on how the union
    /// was detected, this is either a verbatim copy of `discriminator.mapping`,
    /// placeholder entries of the form `SchemaName -> "schema:SchemaName"` derived
    /// from `$ref` variants, or empty when nothing could be determined here.
    pub mappings: std::collections::HashMap<String, String>,
}
18
19pub fn analyze_anyof_semantics(schema: &Value) -> Option<AnyOfSemantics> {
20    let variants = schema.get("anyOf")?.as_array()?;
21
22    // Pattern 1: [Type, null] = nullable type
23    if variants.len() == 2 {
24        let has_null = variants.iter().any(is_null_type);
25        let has_type = variants.iter().any(|v| !is_null_type(v));
26
27        if has_null && has_type {
28            return Some(AnyOfSemantics::Nullable);
29        }
30    }
31
32    // Pattern 2: Multiple complex types = union
33    if variants.iter().all(is_complex_type) {
34        return Some(AnyOfSemantics::Union);
35    }
36
37    // Pattern 3: Mixed refs and primitives = flexible union
38    Some(AnyOfSemantics::FlexibleUnion)
39}
40
41pub fn is_null_type(schema: &Value) -> bool {
42    schema.get("type").and_then(|t| t.as_str()) == Some("null")
43}
44
45pub fn is_complex_type(schema: &Value) -> bool {
46    // Consider refs and objects as complex types
47    schema.get("$ref").is_some()
48        || schema.get("type").and_then(|t| t.as_str()) == Some("object")
49        || schema.get("properties").is_some()
50}
51
52pub fn detect_discriminated_union(schema: &Value) -> Option<DiscriminatedUnionInfo> {
53    // Check for explicit discriminator field
54    if let Some(discriminator) = schema.get("discriminator") {
55        return parse_explicit_discriminator(discriminator, schema);
56    }
57
58    // Auto-detect from oneOf/anyOf patterns
59    if let Some(variants) = schema.get("oneOf").or_else(|| schema.get("anyOf")) {
60        return detect_implicit_discriminator(variants);
61    }
62
63    None
64}
65
66fn parse_explicit_discriminator(
67    discriminator: &Value,
68    schema: &Value,
69) -> Option<DiscriminatedUnionInfo> {
70    let field = discriminator.get("propertyName")?.as_str()?.to_string();
71
72    // Extract mappings from discriminator.mapping if present
73    let mut mappings = std::collections::HashMap::new();
74
75    if let Some(mapping) = discriminator.get("mapping") {
76        if let Some(mapping_obj) = mapping.as_object() {
77            for (key, value) in mapping_obj {
78                if let Some(value_str) = value.as_str() {
79                    mappings.insert(key.clone(), value_str.to_string());
80                }
81            }
82        }
83    } else {
84        // No explicit mapping - need to extract from variant schemas
85        if let Some(variants) = schema.get("oneOf").or_else(|| schema.get("anyOf")) {
86            if let Some(variants_array) = variants.as_array() {
87                mappings = extract_variant_discriminator_mappings(variants_array, &field)?;
88            }
89        }
90    }
91
92    Some(DiscriminatedUnionInfo { field, mappings })
93}
94
95fn extract_variant_discriminator_mappings(
96    variants: &[Value],
97    _discriminator_field: &str,
98) -> Option<std::collections::HashMap<String, String>> {
99    let mut mappings = std::collections::HashMap::new();
100
101    for variant in variants {
102        if let Some(ref_str) = variant.get("$ref").and_then(|r| r.as_str()) {
103            if let Some(schema_name) = extract_schema_name_from_ref(ref_str) {
104                // For now, we can't resolve the actual schema here since we don't have access to all schemas
105                // The actual discriminator value extraction will be handled in analysis.rs
106                // We'll create a placeholder mapping using the schema name
107                mappings.insert(schema_name.clone(), format!("schema:{schema_name}"));
108            }
109        }
110    }
111
112    if mappings.is_empty() {
113        None
114    } else {
115        Some(mappings)
116    }
117}
118
/// Returns the final path segment of a JSON reference, e.g. `"Pet"` for
/// `"#/components/schemas/Pet"`. A reference without any `/` is returned whole.
///
/// Returns `None` when that final segment is empty, as for an empty reference
/// or one ending in `/`, since an empty string cannot name a schema.
///
/// NOTE(review): JSON Pointer escapes (`~0`, `~1`) are not decoded here.
fn extract_schema_name_from_ref(ref_str: &str) -> Option<String> {
    ref_str
        .rsplit('/')
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string)
}
122
123fn detect_implicit_discriminator(variants: &Value) -> Option<DiscriminatedUnionInfo> {
124    let variant_refs = extract_variant_refs(variants)?;
125
126    // Check if all variants have a common field that could be a discriminator
127    let common_discriminator = find_common_discriminator_field(&variant_refs)?;
128
129    // TODO: Verify each variant has a unique value for the discriminator
130    // This would require resolving the schema references
131
132    Some(DiscriminatedUnionInfo {
133        field: common_discriminator,
134        mappings: std::collections::HashMap::new(),
135    })
136}
137
138fn extract_variant_refs(variants: &Value) -> Option<Vec<String>> {
139    let variants_array = variants.as_array()?;
140
141    let refs: Vec<String> = variants_array
142        .iter()
143        .filter_map(|v| v.get("$ref").and_then(|r| r.as_str()))
144        .map(|s| s.to_string())
145        .collect();
146
147    if refs.len() == variants_array.len() && refs.len() > 1 {
148        Some(refs)
149    } else {
150        None
151    }
152}
153
/// Picks the property name assumed to discriminate between the given variants.
///
/// This is currently a stub. It ignores `_variant_refs` and always answers
/// `Some("type")`. A complete implementation would resolve each `$ref` and look
/// for a property shared by every variant.
fn find_common_discriminator_field(_variant_refs: &[String]) -> Option<String> {
    const ASSUMED_FIELD: &str = "type";
    Some(ASSUMED_FIELD.to_owned())
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `[Ref, null]` must be recognised as a nullable reference.
    #[test]
    fn test_nullable_pattern_detection() {
        let schema = json!({
            "anyOf": [
                {"$ref": "#/components/schemas/Error"},
                {"type": "null"}
            ]
        });

        let semantics = analyze_anyof_semantics(&schema);
        assert_eq!(semantics, Some(AnyOfSemantics::Nullable));
    }

    /// Two `$ref` variants form a plain union.
    #[test]
    fn test_union_pattern_detection() {
        let schema = json!({
            "anyOf": [
                {"$ref": "#/components/schemas/TextContent"},
                {"$ref": "#/components/schemas/ImageContent"}
            ]
        });

        let semantics = analyze_anyof_semantics(&schema);
        assert_eq!(semantics, Some(AnyOfSemantics::Union));
    }

    /// An explicit `discriminator.propertyName` is reported as the union's field.
    #[test]
    fn test_discriminated_union_detection() {
        let schema = json!({
            "oneOf": [
                {"$ref": "#/components/schemas/ResponseCreatedEvent"},
                {"$ref": "#/components/schemas/ResponseTextDeltaEvent"}
            ],
            "discriminator": {
                "propertyName": "type"
            }
        });

        let info = detect_discriminated_union(&schema)
            .expect("explicit discriminator should be detected");
        assert_eq!(info.field, "type");
    }
}