Skip to main content

spikard_cli/codegen/
base.rs

//! Base trait for `OpenAPI` code generators
//!
//! Provides a language-neutral abstraction for code generation from `OpenAPI` specs,
//! eliminating duplication across Python, TypeScript, Ruby, and PHP generators.
5
6use crate::codegen::SchemaRegistry;
7use anyhow::Result;
8use openapiv3::{OpenAPI, Operation, ReferenceOr, Schema};
9
10/// Base trait for `OpenAPI` code generators to eliminate duplication across languages.
11///
12/// Implementors should override language-specific methods while leveraging shared
13/// default implementations for common patterns.
14pub trait OpenApiGenerator {
15    /// Get the `OpenAPI` specification
16    fn spec(&self) -> &OpenAPI;
17
18    /// Get the schema registry for reference resolution
19    fn registry(&self) -> &SchemaRegistry;
20
21    /// Generate the file header (imports, module declaration, etc.)
22    fn generate_header(&self) -> String;
23
24    /// Generate data models/DTOs from `OpenAPI` components
25    fn generate_models(&self) -> Result<String>;
26
27    /// Generate route handlers from `OpenAPI` paths
28    fn generate_routes(&self) -> Result<String>;
29
30    /// Generate file footer (bootstrap, exports, etc.)
31    fn generate_footer(&self) -> String {
32        String::new()
33    }
34
35    /// Orchestrate the full code generation pipeline
36    fn generate(&self) -> Result<String> {
37        let mut output = String::new();
38
39        output.push_str(&self.generate_header());
40        output.push_str(&self.generate_models()?);
41        output.push_str(&self.generate_routes()?);
42
43        let footer = self.generate_footer();
44        if !footer.is_empty() {
45            output.push_str(&footer);
46        }
47
48        Ok(output)
49    }
50
51    /// Iterate over all paths in the spec and apply a function to each operation
52    fn iter_paths<F>(&self, mut f: F) -> Result<()>
53    where
54        F: FnMut(&str, &str, &Operation) -> Result<()>,
55    {
56        for (path, path_item_ref) in &self.spec().paths.paths {
57            let path_item = match path_item_ref {
58                ReferenceOr::Item(item) => item,
59                ReferenceOr::Reference { .. } => continue,
60            };
61
62            if let Some(op) = &path_item.get {
63                f(path, "get", op)?;
64            }
65            if let Some(op) = &path_item.post {
66                f(path, "post", op)?;
67            }
68            if let Some(op) = &path_item.put {
69                f(path, "put", op)?;
70            }
71            if let Some(op) = &path_item.delete {
72                f(path, "delete", op)?;
73            }
74            if let Some(op) = &path_item.patch {
75                f(path, "patch", op)?;
76            }
77        }
78        Ok(())
79    }
80
81    /// Iterate over all component schemas and apply a function to each
82    fn iter_schemas<F>(&self, mut f: F) -> Result<()>
83    where
84        F: FnMut(&str, &Schema) -> Result<()>,
85    {
86        if let Some(components) = &self.spec().components {
87            for (name, schema_ref) in &components.schemas {
88                match schema_ref {
89                    ReferenceOr::Item(schema) => {
90                        f(name, schema)?;
91                    }
92                    ReferenceOr::Reference { .. } => continue,
93                }
94            }
95        }
96        Ok(())
97    }
98
99    /// Extract request body type from operation (looks for application/json)
100    fn extract_request_body_type(&self, operation: &Operation) -> Option<String> {
101        operation.request_body.as_ref().and_then(|body_ref| match body_ref {
102            ReferenceOr::Item(request_body) => request_body.content.get("application/json").and_then(|media_type| {
103                media_type
104                    .schema
105                    .as_ref()
106                    .map(|schema_ref| self.extract_type_from_schema_ref(schema_ref))
107            }),
108            ReferenceOr::Reference { reference } => {
109                let ref_name = reference.split('/').next_back().unwrap();
110                Some(self.format_type_name(ref_name))
111            }
112        })
113    }
114
115    /// Extract response type from operation (looks for 200/201 responses)
116    fn extract_response_type(&self, operation: &Operation) -> String {
117        use openapiv3::StatusCode;
118
119        let response = operation
120            .responses
121            .responses
122            .get(&StatusCode::Code(200))
123            .or_else(|| operation.responses.responses.get(&StatusCode::Code(201)))
124            .or_else(|| operation.responses.responses.get(&StatusCode::Range(2)));
125
126        if let Some(response_ref) = response {
127            match response_ref {
128                ReferenceOr::Item(response) => {
129                    if let Some(content) = response.content.get("application/json")
130                        && let Some(schema_ref) = &content.schema
131                    {
132                        return self.extract_type_from_schema_ref(schema_ref);
133                    }
134                }
135                ReferenceOr::Reference { reference } => {
136                    let ref_name = reference.split('/').next_back().unwrap();
137                    return self.format_type_name(ref_name);
138                }
139            }
140        }
141
142        self.default_response_type()
143    }
144
145    /// Extract type name from a schema reference or inline schema
146    fn extract_type_from_schema_ref(&self, schema_ref: &ReferenceOr<Schema>) -> String {
147        match schema_ref {
148            ReferenceOr::Reference { reference } => {
149                let ref_name = reference.split('/').next_back().unwrap();
150                self.format_type_name(ref_name)
151            }
152            ReferenceOr::Item(_schema) => self.default_response_type(),
153        }
154    }
155
156    /// Format a type name according to language conventions (`PascalCase` by default)
157    fn format_type_name(&self, name: &str) -> String {
158        heck::ToPascalCase::to_pascal_case(name)
159    }
160
161    /// Return the language's default response type (e.g., "dict[str, Any]", "Record<string, unknown>")
162    fn default_response_type(&self) -> String {
163        "unknown".to_string()
164    }
165
166    /// Generate operation ID (function/method name) from operation and path
167    fn generate_operation_id(&self, path: &str, method: &str, operation: &Operation) -> String {
168        operation
169            .operation_id
170            .as_ref()
171            .map(|id| heck::ToSnakeCase::to_snake_case(id.as_str()))
172            .unwrap_or_else(|| {
173                format!(
174                    "{}_{}",
175                    method,
176                    path.replace('/', "_").replace(['{', '}'], "").trim_matches('_')
177                )
178            })
179    }
180}