Skip to main content

unistructgen_openapi_parser/
schema.rs

1//! Schema conversion from OpenAPI to IR
2
3use crate::error::{OpenApiError, Result};
4use crate::options::OpenApiParserOptions;
5use crate::types::{
6    extract_type_name_from_ref, openapi_type_to_ir, sanitize_field_name, to_pascal_case,
7};
8use crate::validation::extract_validation_constraints;
9use openapiv3::{
10    OpenAPI, ReferenceOr, Schema, SchemaKind, Type,
11};
12use std::collections::HashSet;
13use unistructgen_core::{
14    IREnum, IREnumVariant, IRField, IRStruct, IRType, IRTypeRef, PrimitiveKind,
15};
16
17/// Schema converter that maintains context and handles references
18pub struct SchemaConverter<'a> {
19    /// The OpenAPI specification
20    spec: &'a OpenAPI,
21    /// Parser options
22    options: &'a OpenApiParserOptions,
23    /// Generated type names to avoid duplicates
24    generated_types: HashSet<String>,
25    /// Current depth for recursion prevention
26    current_depth: usize,
27    /// Reference resolution stack for cycle detection
28    reference_stack: Vec<String>,
29    /// Inline enum types generated from properties
30    inline_enum_types: Vec<IRType>,
31}
32
33impl<'a> SchemaConverter<'a> {
34    /// Create a new schema converter
35    pub fn new(spec: &'a OpenAPI, options: &'a OpenApiParserOptions) -> Self {
36        Self {
37            spec,
38            options,
39            generated_types: HashSet::new(),
40            current_depth: 0,
41            reference_stack: Vec::new(),
42            inline_enum_types: Vec::new(),
43        }
44    }
45
46    /// Convert all schemas from components to IR types
47    pub fn convert_all_schemas(&mut self) -> Result<Vec<IRType>> {
48        let mut types = Vec::new();
49
50        if let Some(components) = &self.spec.components {
51            for (name, schema_ref) in &components.schemas {
52                let schema = match schema_ref {
53                    ReferenceOr::Item(schema) => schema,
54                    ReferenceOr::Reference { .. } => {
55                        // Skip references at the top level
56                        continue;
57                    }
58                };
59
60                let ir_type = self.convert_schema(name, schema)?;
61                if let Some(ty) = ir_type {
62                    types.push(ty);
63                }
64            }
65        }
66
67        // Add all inline enum types generated from properties
68        types.extend(self.inline_enum_types.drain(..));
69
70        Ok(types)
71    }
72
73    /// Convert a single schema to IR type
74    pub fn convert_schema(&mut self, name: &str, schema: &Schema) -> Result<Option<IRType>> {
75        // Check depth limit
76        if self.current_depth >= self.options.max_depth {
77            return Err(OpenApiError::invalid_spec(format!(
78                "Maximum schema depth ({}) exceeded for '{}'",
79                self.options.max_depth, name
80            )));
81        }
82
83        self.current_depth += 1;
84        let result = self.convert_schema_impl(name, schema);
85        self.current_depth -= 1;
86
87        result
88    }
89
90    fn convert_schema_impl(&mut self, name: &str, schema: &Schema) -> Result<Option<IRType>> {
91        match &schema.schema_kind {
92            SchemaKind::Type(Type::Object(obj_type)) => {
93                let struct_name = self.options.format_type_name(&to_pascal_case(name));
94
95                // Check if already generated
96                if self.generated_types.contains(&struct_name) {
97                    return Ok(None);
98                }
99                self.generated_types.insert(struct_name.clone());
100
101                let mut ir_struct = IRStruct::new(struct_name);
102
103                // Add documentation
104                if self.options.generate_docs {
105                    if let Some(desc) = &schema.schema_data.description {
106                        ir_struct.doc = Some(desc.clone());
107                    }
108                }
109
110                // Add derives
111                if self.options.derive_serde {
112                    ir_struct.add_derive("serde::Serialize".to_string());
113                    ir_struct.add_derive("serde::Deserialize".to_string());
114                }
115                if self.options.derive_default {
116                    ir_struct.add_derive("Default".to_string());
117                }
118                if self.options.generate_validation {
119                    ir_struct.add_derive("validator::Validate".to_string());
120                }
121
122                // Convert properties to fields
123                let required_fields: HashSet<_> =
124                    obj_type.required.iter().map(|s| s.as_str()).collect();
125
126                for (field_name, property_ref) in &obj_type.properties {
127                    let property = match property_ref {
128                        ReferenceOr::Item(schema) => schema,
129                        ReferenceOr::Reference { reference: _ } => {
130                            // We need to resolve this reference manually
131                            // For now, skip or use a simplified approach
132                            continue;
133                        }
134                    };
135                    let field =
136                        self.convert_property(field_name, property, &required_fields)?;
137                    ir_struct.add_field(field);
138                }
139
140                Ok(Some(IRType::Struct(ir_struct)))
141            }
142
143            SchemaKind::Type(Type::String(string_type)) if !string_type.enumeration.is_empty() => {
144                // Enum type
145                let enum_name = self.options.format_type_name(&to_pascal_case(name));
146
147                if self.generated_types.contains(&enum_name) {
148                    return Ok(None);
149                }
150                self.generated_types.insert(enum_name.clone());
151
152                let mut ir_enum = IREnum {
153                    name: enum_name,
154                    variants: Vec::new(),
155                    derives: vec![
156                        "Debug".to_string(),
157                        "Clone".to_string(),
158                        "PartialEq".to_string(),
159                        "Eq".to_string(),
160                        "Hash".to_string(),
161                    ],
162                    doc: schema.schema_data.description.clone(),
163                };
164
165                if self.options.derive_serde {
166                    ir_enum.derives.push("serde::Serialize".to_string());
167                    ir_enum.derives.push("serde::Deserialize".to_string());
168                }
169
170                for variant_value in &string_type.enumeration {
171                    if let Some(variant_str) = variant_value {
172                        let pascal_name = to_pascal_case(variant_str);
173                        let variant = IREnumVariant {
174                            name: pascal_name.clone(),
175                            source_value: if pascal_name != *variant_str {
176                                Some(variant_str.clone())
177                            } else {
178                                None
179                            },
180                            doc: None,
181                        };
182                        ir_enum.variants.push(variant);
183                    }
184                }
185
186                Ok(Some(IRType::Enum(ir_enum)))
187            }
188
189            SchemaKind::AllOf { all_of } => {
190                // Merge all schemas
191                self.convert_all_of(name, all_of)
192            }
193
194            SchemaKind::OneOf { one_of } => {
195                // Create enum with variants
196                self.convert_one_of(name, one_of)
197            }
198
199            SchemaKind::AnyOf { any_of } => {
200                // Similar to oneOf
201                self.convert_any_of(name, any_of)
202            }
203
204            _ => {
205                // Other types (primitives) don't generate standalone types
206                Ok(None)
207            }
208        }
209    }
210
211    /// Convert a property to an IR field
212    fn convert_property(
213        &mut self,
214        name: &str,
215        schema: &Schema,
216        required_fields: &HashSet<&str>,
217    ) -> Result<IRField> {
218        let field_name = sanitize_field_name(name);
219        let is_required = required_fields.contains(name);
220
221        // Determine field type
222        let mut field_type = match &schema.schema_kind {
223            SchemaKind::Type(Type::Object(_)) => {
224                // Nested object - generate a new type
225                let nested_name = to_pascal_case(name);
226                if let Some(IRType::Struct(nested_struct)) =
227                    self.convert_schema(&nested_name, schema)?
228                {
229                    IRTypeRef::Named(nested_struct.name)
230                } else {
231                    IRTypeRef::Primitive(PrimitiveKind::Json)
232                }
233            }
234            SchemaKind::Type(Type::String(string_type)) if !string_type.enumeration.is_empty() => {
235                // Inline enum - generate enum type
236                let enum_name = to_pascal_case(name);
237
238                // Check if not already generated
239                if !self.generated_types.contains(&enum_name) {
240                    self.generated_types.insert(enum_name.clone());
241
242                    let mut ir_enum = IREnum {
243                        name: enum_name.clone(),
244                        variants: Vec::new(),
245                        derives: vec![
246                            "Debug".to_string(),
247                            "Clone".to_string(),
248                            "PartialEq".to_string(),
249                            "Eq".to_string(),
250                            "Hash".to_string(),
251                        ],
252                        doc: schema.schema_data.description.clone(),
253                    };
254
255                    if self.options.derive_serde {
256                        ir_enum.derives.push("serde::Serialize".to_string());
257                        ir_enum.derives.push("serde::Deserialize".to_string());
258                    }
259
260                    for variant_value in &string_type.enumeration {
261                        if let Some(variant_str) = variant_value {
262                            let pascal_name = to_pascal_case(variant_str);
263                            let variant = IREnumVariant {
264                                name: pascal_name.clone(),
265                                source_value: if pascal_name != *variant_str {
266                                    Some(variant_str.clone())
267                                } else {
268                                    None
269                                },
270                                doc: None,
271                            };
272                            ir_enum.variants.push(variant);
273                        }
274                    }
275
276                    self.inline_enum_types.push(IRType::Enum(ir_enum));
277                }
278
279                IRTypeRef::Named(enum_name)
280            }
281            _ => openapi_type_to_ir(schema, Some(name))?,
282        };
283
284        // Make optional if not required or if option is set
285        if !is_required || self.options.make_fields_optional {
286            field_type = field_type.make_optional();
287        }
288
289        let mut field = IRField::new(field_name.clone(), field_type);
290
291        // Set source name for serde rename if different
292        if field_name != name {
293            field.source_name = Some(name.to_string());
294            field.attributes.push(format!("#[serde(rename = \"{}\")]", name));
295        }
296
297        // Add documentation
298        if self.options.generate_docs {
299            if let Some(desc) = &schema.schema_data.description {
300                field.doc = Some(desc.clone());
301            }
302        }
303
304        // Extract validation constraints
305        if self.options.generate_validation {
306            field.constraints = extract_validation_constraints(schema);
307        }
308
309        field.optional = !is_required;
310
311        Ok(field)
312    }
313
314    /// Resolve a schema reference to the actual schema
315    fn resolve_schema_ref(&self, schema_ref: &'a ReferenceOr<Schema>) -> Result<&'a Schema> {
316        match schema_ref {
317            ReferenceOr::Item(schema) => Ok(schema),
318            ReferenceOr::Reference { reference } => {
319                // Check for circular references
320                if self.reference_stack.contains(reference) {
321                    return Err(OpenApiError::circular_reference(reference.clone()));
322                }
323
324                // Extract schema name from reference
325                let schema_name = extract_type_name_from_ref(reference);
326
327                // Look up in components
328                let components = self.spec.components.as_ref().ok_or_else(|| {
329                    OpenApiError::reference_resolution(
330                        reference.clone(),
331                        "no components in spec".to_string(),
332                    )
333                })?;
334
335                let found_schema_ref = components.schemas.get(&schema_name).ok_or_else(|| {
336                    OpenApiError::reference_resolution(
337                        reference.clone(),
338                        format!("schema '{}' not found in components", schema_name),
339                    )
340                })?;
341
342                match found_schema_ref {
343                    ReferenceOr::Item(schema) => Ok(schema),
344                    ReferenceOr::Reference { .. } => Err(OpenApiError::reference_resolution(
345                        reference.clone(),
346                        "nested references not supported".to_string(),
347                    )),
348                }
349            }
350        }
351    }
352
353    /// Convert allOf schema composition
354    fn convert_all_of(
355        &mut self,
356        name: &str,
357        schemas: &[ReferenceOr<Schema>],
358    ) -> Result<Option<IRType>> {
359        let struct_name = self.options.format_type_name(&to_pascal_case(name));
360
361        if self.generated_types.contains(&struct_name) {
362            return Ok(None);
363        }
364        self.generated_types.insert(struct_name.clone());
365
366        let mut ir_struct = IRStruct::new(struct_name);
367
368        // Merge all schemas
369        // First, collect all field info to avoid borrow checker issues
370        let mut fields_to_process = Vec::new();
371
372        for schema_ref in schemas {
373            let schema = self.resolve_schema_ref(schema_ref)?;
374
375            if let SchemaKind::Type(Type::Object(obj_type)) = &schema.schema_kind {
376                let required: HashSet<_> = obj_type.required.iter().map(|s| s.as_str()).collect();
377
378                for (field_name, property_ref) in &obj_type.properties {
379                    let property = match property_ref {
380                        ReferenceOr::Item(schema) => schema,
381                        ReferenceOr::Reference { .. } => continue,
382                    };
383                    // Clone the necessary data to avoid holding the borrow
384                    fields_to_process.push((
385                        field_name.clone(),
386                        property.clone(),
387                        required.clone(),
388                    ));
389                }
390            }
391        }
392
393        // Now process the fields with mutable access to self
394        for (field_name, property, required) in fields_to_process {
395            let required_set: HashSet<&str> = required.iter().map(|s| s.as_ref()).collect();
396            let field = self.convert_property(&field_name, &property, &required_set)?;
397            ir_struct.add_field(field);
398        }
399
400        // Add standard derives
401        if self.options.derive_serde {
402            ir_struct.add_derive("serde::Serialize".to_string());
403            ir_struct.add_derive("serde::Deserialize".to_string());
404        }
405
406        Ok(Some(IRType::Struct(ir_struct)))
407    }
408
409    /// Convert oneOf schema composition (creates enum)
410    fn convert_one_of(
411        &mut self,
412        name: &str,
413        schemas: &[ReferenceOr<Schema>],
414    ) -> Result<Option<IRType>> {
415        let enum_name = self.options.format_type_name(&to_pascal_case(name));
416
417        if self.generated_types.contains(&enum_name) {
418            return Ok(None);
419        }
420        self.generated_types.insert(enum_name.clone());
421
422        let mut ir_enum = IREnum {
423            name: enum_name,
424            variants: Vec::new(),
425            derives: vec![
426                "Debug".to_string(),
427                "Clone".to_string(),
428                "PartialEq".to_string(),
429            ],
430            doc: None,
431        };
432
433        if self.options.derive_serde {
434            ir_enum.derives.push("serde::Serialize".to_string());
435            ir_enum.derives.push("serde::Deserialize".to_string());
436        }
437
438        // Create variant for each schema
439        for (idx, _schema_ref) in schemas.iter().enumerate() {
440            let variant_name = format!("Variant{}", idx + 1);
441            let variant = IREnumVariant {
442                name: variant_name,
443                source_value: None,
444                doc: None,
445            };
446            ir_enum.variants.push(variant);
447        }
448
449        Ok(Some(IRType::Enum(ir_enum)))
450    }
451
452    /// Convert anyOf schema composition
453    fn convert_any_of(
454        &mut self,
455        name: &str,
456        schemas: &[ReferenceOr<Schema>],
457    ) -> Result<Option<IRType>> {
458        // Similar to oneOf for now
459        self.convert_one_of(name, schemas)
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the camelCase → snake_case conversion, the `type` → `type_` keyword
    /// escape, and the `_` prefix for a name that starts with a digit.
    #[test]
    fn test_sanitize_field_name() {
        let cases: [(&str, &str); 3] = [
            ("userName", "user_name"),
            ("type", "type_"),
            ("123field", "_123field"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(sanitize_field_name(input), expected);
        }
    }
}