openapi_to_rust/
patterns.rs1use serde_json::Value;
2
/// Interpretation assigned to an `anyOf` schema by [`analyze_anyof_semantics`].
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so the classification can
/// be used directly as a `HashSet`/`HashMap` key or in `match` guards.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AnyOfSemantics {
    /// Exactly two alternatives, one of which is `{"type": "null"}` and the
    /// other is not.
    Nullable,
    /// Every alternative is a complex type according to [`is_complex_type`]
    /// (a `$ref`, an `object` type, or a schema with `properties`).
    Union,
    /// Any other combination of alternatives.
    FlexibleUnion,
}
12
/// Result of discriminated-union detection on a `oneOf`/`anyOf` schema.
///
/// Derives `PartialEq`/`Eq` so detection results can be compared by value
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiscriminatedUnionInfo {
    /// Name of the discriminator property: the explicit `propertyName` when a
    /// `discriminator` object is present, otherwise the field chosen by
    /// implicit detection.
    pub field: String,
    /// Discriminator mapping entries.
    ///
    /// Copied from the explicit `mapping` object when one is given; otherwise
    /// derived from the `$ref` variants as `name -> "schema:name"`. Left empty
    /// for implicitly detected unions.
    pub mappings: std::collections::HashMap<String, String>,
}
18
19pub fn analyze_anyof_semantics(schema: &Value) -> Option<AnyOfSemantics> {
20 let variants = schema.get("anyOf")?.as_array()?;
21
22 if variants.len() == 2 {
24 let has_null = variants.iter().any(is_null_type);
25 let has_type = variants.iter().any(|v| !is_null_type(v));
26
27 if has_null && has_type {
28 return Some(AnyOfSemantics::Nullable);
29 }
30 }
31
32 if variants.iter().all(is_complex_type) {
34 return Some(AnyOfSemantics::Union);
35 }
36
37 Some(AnyOfSemantics::FlexibleUnion)
39}
40
41pub fn is_null_type(schema: &Value) -> bool {
42 schema.get("type").and_then(|t| t.as_str()) == Some("null")
43}
44
45pub fn is_complex_type(schema: &Value) -> bool {
46 schema.get("$ref").is_some()
48 || schema.get("type").and_then(|t| t.as_str()) == Some("object")
49 || schema.get("properties").is_some()
50}
51
52pub fn detect_discriminated_union(schema: &Value) -> Option<DiscriminatedUnionInfo> {
53 if let Some(discriminator) = schema.get("discriminator") {
55 return parse_explicit_discriminator(discriminator, schema);
56 }
57
58 if let Some(variants) = schema.get("oneOf").or_else(|| schema.get("anyOf")) {
60 return detect_implicit_discriminator(variants);
61 }
62
63 None
64}
65
66fn parse_explicit_discriminator(
67 discriminator: &Value,
68 schema: &Value,
69) -> Option<DiscriminatedUnionInfo> {
70 let field = discriminator.get("propertyName")?.as_str()?.to_string();
71
72 let mut mappings = std::collections::HashMap::new();
74
75 if let Some(mapping) = discriminator.get("mapping") {
76 if let Some(mapping_obj) = mapping.as_object() {
77 for (key, value) in mapping_obj {
78 if let Some(value_str) = value.as_str() {
79 mappings.insert(key.clone(), value_str.to_string());
80 }
81 }
82 }
83 } else {
84 if let Some(variants) = schema.get("oneOf").or_else(|| schema.get("anyOf")) {
86 if let Some(variants_array) = variants.as_array() {
87 mappings = extract_variant_discriminator_mappings(variants_array, &field)?;
88 }
89 }
90 }
91
92 Some(DiscriminatedUnionInfo { field, mappings })
93}
94
95fn extract_variant_discriminator_mappings(
96 variants: &[Value],
97 _discriminator_field: &str,
98) -> Option<std::collections::HashMap<String, String>> {
99 let mut mappings = std::collections::HashMap::new();
100
101 for variant in variants {
102 if let Some(ref_str) = variant.get("$ref").and_then(|r| r.as_str()) {
103 if let Some(schema_name) = extract_schema_name_from_ref(ref_str) {
104 mappings.insert(schema_name.clone(), format!("schema:{schema_name}"));
108 }
109 }
110 }
111
112 if mappings.is_empty() {
113 None
114 } else {
115 Some(mappings)
116 }
117}
118
/// Extracts the schema name from a `$ref` string.
///
/// The name is the text after the last `/` (or the whole string when it
/// contains no `/`), e.g. `"#/components/schemas/Error"` yields `"Error"`.
///
/// Returns `None` when that final segment is empty — an empty reference or
/// one ending in `/` — so callers never receive an empty schema name.
fn extract_schema_name_from_ref(ref_str: &str) -> Option<String> {
    ref_str
        .rsplit('/')
        .next()
        // `rsplit` always yields at least one (possibly empty) segment, so
        // the empty case has to be rejected explicitly.
        .filter(|segment| !segment.is_empty())
        .map(str::to_owned)
}
122
123fn detect_implicit_discriminator(variants: &Value) -> Option<DiscriminatedUnionInfo> {
124 let variant_refs = extract_variant_refs(variants)?;
125
126 let common_discriminator = find_common_discriminator_field(&variant_refs)?;
128
129 Some(DiscriminatedUnionInfo {
133 field: common_discriminator,
134 mappings: std::collections::HashMap::new(),
135 })
136}
137
138fn extract_variant_refs(variants: &Value) -> Option<Vec<String>> {
139 let variants_array = variants.as_array()?;
140
141 let refs: Vec<String> = variants_array
142 .iter()
143 .filter_map(|v| v.get("$ref").and_then(|r| r.as_str()))
144 .map(|s| s.to_string())
145 .collect();
146
147 if refs.len() == variants_array.len() && refs.len() > 1 {
148 Some(refs)
149 } else {
150 None
151 }
152}
153
/// Names the discriminator field assumed to be shared by all variants.
///
/// The referenced schemas are not inspected: `_variant_refs` is ignored and
/// the answer is always `"type"`.
fn find_common_discriminator_field(_variant_refs: &[String]) -> Option<String> {
    const ASSUMED_FIELD: &str = "type";
    Some(String::from(ASSUMED_FIELD))
}
160
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A `$ref` paired with `{"type": "null"}` is classified as nullable.
    #[test]
    fn test_nullable_pattern_detection() {
        let schema = json!({
            "anyOf": [
                {"$ref": "#/components/schemas/Error"},
                {"type": "null"}
            ]
        });

        let detected = analyze_anyof_semantics(&schema);

        assert_eq!(detected, Some(AnyOfSemantics::Nullable));
    }

    /// Two `$ref` alternatives are classified as a plain union.
    #[test]
    fn test_union_pattern_detection() {
        let schema = json!({
            "anyOf": [
                {"$ref": "#/components/schemas/TextContent"},
                {"$ref": "#/components/schemas/ImageContent"}
            ]
        });

        let detected = analyze_anyof_semantics(&schema);

        assert_eq!(detected, Some(AnyOfSemantics::Union));
    }

    /// An explicit `discriminator` is detected and its property name reported.
    #[test]
    fn test_discriminated_union_detection() {
        let schema = json!({
            "oneOf": [
                {"$ref": "#/components/schemas/ResponseCreatedEvent"},
                {"$ref": "#/components/schemas/ResponseTextDeltaEvent"}
            ],
            "discriminator": {
                "propertyName": "type"
            }
        });

        // A `None` result fails this comparison just as a wrong field would.
        let field = detect_discriminated_union(&schema).map(|info| info.field);

        assert_eq!(field.as_deref(), Some("type"));
    }
}