// Source file: openapi_to_rust/registry_generator.rs

//! Operation registry generation for OpenAPI specifications.
//!
//! Generates a static registry of operation metadata from analyzed OpenAPI operations.
//! The registry contains everything needed to:
//! - Build CLI subcommands dynamically (names, params, help text)
//! - Route and validate operations in a proxy (URL templates, param types, methods)
//!
//! The generated code is pure data — no HTTP client, no clap derives, no runtime dependencies
//! beyond serde. Consumers (CLI shims, proxy routers) interpret the registry generically.
10
11use crate::analysis::SchemaAnalysis;
12use crate::generator::CodeGenerator;
13use proc_macro2::TokenStream;
14use quote::quote;
15
16impl CodeGenerator {
17    /// Generate the registry.rs file content
18    pub fn generate_registry(&self, analysis: &SchemaAnalysis) -> crate::Result<String> {
19        let registry_types = Self::generate_registry_types();
20        let operation_defs = self.generate_operation_defs(analysis);
21
22        let tokens = quote! {
23            //! Auto-generated operation registry. Do not edit.
24
25            #registry_types
26            #operation_defs
27        };
28
29        let file = syn::parse2(tokens).map_err(|e| {
30            crate::GeneratorError::CodeGenError(format!("Failed to parse registry tokens: {}", e))
31        })?;
32        Ok(prettyplease::unparse(&file))
33    }
34
35    /// Generate the registry data types
36    fn generate_registry_types() -> TokenStream {
37        quote! {
38            /// HTTP method for an operation
39            #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
40            pub enum HttpMethod {
41                Get,
42                Post,
43                Put,
44                Patch,
45                Delete,
46            }
47
48            impl HttpMethod {
49                pub fn as_str(&self) -> &'static str {
50                    match self {
51                        Self::Get => "GET",
52                        Self::Post => "POST",
53                        Self::Put => "PUT",
54                        Self::Patch => "PATCH",
55                        Self::Delete => "DELETE",
56                    }
57                }
58            }
59
60            impl std::fmt::Display for HttpMethod {
61                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62                    f.write_str(self.as_str())
63                }
64            }
65
66            /// Where a parameter appears in the request
67            #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
68            pub enum ParamLocation {
69                Path,
70                Query,
71                Header,
72            }
73
74            /// Primitive type of a parameter (for validation and CLI parsing)
75            #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
76            pub enum ParamType {
77                String,
78                Integer,
79                Number,
80                Boolean,
81            }
82
83            /// Content type for request bodies
84            #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
85            pub enum BodyContentType {
86                Json,
87                FormUrlEncoded,
88                Multipart,
89                OctetStream,
90                TextPlain,
91            }
92
93            /// Definition of an operation parameter.
94            ///
95            /// Only `Serialize` is derived: this struct holds `&'static`
96            /// references to data baked into the binary, which cannot be
97            /// reconstructed by `Deserialize`.
98            #[derive(Debug, Clone, serde::Serialize)]
99            pub struct ParamDef {
100                pub name: &'static str,
101                pub location: ParamLocation,
102                pub required: bool,
103                pub param_type: ParamType,
104                pub description: Option<&'static str>,
105            }
106
107            /// Definition of a request body.
108            ///
109            /// `Serialize`-only for the same reason as [`ParamDef`].
110            #[derive(Debug, Clone, serde::Serialize)]
111            pub struct BodyDef {
112                pub content_type: BodyContentType,
113                /// Name of the schema type (for JSON/form bodies)
114                pub schema_name: Option<&'static str>,
115            }
116
117            /// A single operation in the registry.
118            ///
119            /// `Serialize`-only because of the `&'static` fields.
120            #[derive(Debug, Clone, serde::Serialize)]
121            pub struct OperationDef {
122                /// Unique operation identifier (e.g. "repos/get", "issues/create-comment")
123                pub id: &'static str,
124                /// HTTP method
125                pub method: HttpMethod,
126                /// URL path template with {param} placeholders
127                pub path: &'static str,
128                /// Short summary for CLI help
129                pub summary: Option<&'static str>,
130                /// Longer description
131                pub description: Option<&'static str>,
132                /// Parameters (path, query, header)
133                pub params: &'static [ParamDef],
134                /// Request body definition
135                pub body: Option<BodyDef>,
136                /// Response schema name for the success (2xx) case
137                pub response_schema: Option<&'static str>,
138            }
139
140            /// Look up an operation by ID
141            pub fn find_operation(id: &str) -> Option<&'static OperationDef> {
142                OPERATIONS.iter().find(|op| op.id == id)
143            }
144
145            /// List all operation IDs
146            pub fn operation_ids() -> impl Iterator<Item = &'static str> {
147                OPERATIONS.iter().map(|op| op.id)
148            }
149        }
150    }
151
152    /// Generate the static OPERATIONS slice from analyzed operations
153    fn generate_operation_defs(&self, analysis: &SchemaAnalysis) -> TokenStream {
154        let mut param_statics: Vec<TokenStream> = Vec::new();
155        let mut op_entries: Vec<TokenStream> = Vec::new();
156
157        // Sort for deterministic output
158        let mut sorted_ops: Vec<_> = analysis.operations.values().collect();
159        sorted_ops.sort_by_key(|op| &op.operation_id);
160
161        for op in sorted_ops {
162            let id = &op.operation_id;
163            let method = match op.method.as_str() {
164                "GET" => quote! { HttpMethod::Get },
165                "POST" => quote! { HttpMethod::Post },
166                "PUT" => quote! { HttpMethod::Put },
167                "PATCH" => quote! { HttpMethod::Patch },
168                "DELETE" => quote! { HttpMethod::Delete },
169                _ => quote! { HttpMethod::Get },
170            };
171            let path = &op.path;
172
173            let summary = match &op.summary {
174                Some(s) => quote! { Some(#s) },
175                None => quote! { None },
176            };
177            let description = match &op.description {
178                Some(d) => quote! { Some(#d) },
179                None => quote! { None },
180            };
181
182            // Generate params
183            let param_defs: Vec<TokenStream> = op
184                .parameters
185                .iter()
186                .map(|p| {
187                    let name = &p.name;
188                    let location = match p.location.as_str() {
189                        "path" => quote! { ParamLocation::Path },
190                        "query" => quote! { ParamLocation::Query },
191                        "header" => quote! { ParamLocation::Header },
192                        _ => quote! { ParamLocation::Query },
193                    };
194                    let required = p.required;
195                    let param_type = match p.rust_type.as_str() {
196                        "i64" | "i32" => quote! { ParamType::Integer },
197                        "f64" => quote! { ParamType::Number },
198                        "bool" => quote! { ParamType::Boolean },
199                        _ => quote! { ParamType::String },
200                    };
201                    let desc = match &p.description {
202                        Some(d) => quote! { Some(#d) },
203                        None => quote! { None },
204                    };
205                    quote! {
206                        ParamDef {
207                            name: #name,
208                            location: #location,
209                            required: #required,
210                            param_type: #param_type,
211                            description: #desc,
212                        }
213                    }
214                })
215                .collect();
216
217            // Sanitize operation ID to a valid Rust identifier for the static name
218            let sanitized_id: String = op
219                .operation_id
220                .chars()
221                .map(|c| {
222                    if c.is_ascii_alphanumeric() {
223                        c.to_ascii_uppercase()
224                    } else {
225                        '_'
226                    }
227                })
228                .collect();
229            let params_static_name = syn::Ident::new(
230                &format!("PARAMS_{sanitized_id}"),
231                proc_macro2::Span::call_site(),
232            );
233            let param_count = param_defs.len();
234
235            // Emit the param array as a separate static
236            param_statics.push(quote! {
237                static #params_static_name: [ParamDef; #param_count] = [#(#param_defs),*];
238            });
239
240            // Generate body def
241            let body = match &op.request_body {
242                Some(rb) => {
243                    use crate::analysis::RequestBodyContent;
244                    let (content_type, schema_name) = match rb {
245                        RequestBodyContent::Json { schema_name } => (
246                            quote! { BodyContentType::Json },
247                            quote! { Some(#schema_name) },
248                        ),
249                        RequestBodyContent::FormUrlEncoded { schema_name } => (
250                            quote! { BodyContentType::FormUrlEncoded },
251                            quote! { Some(#schema_name) },
252                        ),
253                        RequestBodyContent::Multipart => {
254                            (quote! { BodyContentType::Multipart }, quote! { None })
255                        }
256                        RequestBodyContent::OctetStream => {
257                            (quote! { BodyContentType::OctetStream }, quote! { None })
258                        }
259                        RequestBodyContent::TextPlain => {
260                            (quote! { BodyContentType::TextPlain }, quote! { None })
261                        }
262                    };
263                    quote! {
264                        Some(BodyDef {
265                            content_type: #content_type,
266                            schema_name: #schema_name,
267                        })
268                    }
269                }
270                None => quote! { None },
271            };
272
273            // Response schema (2xx)
274            let response_schema = op
275                .response_schemas
276                .get("200")
277                .or_else(|| op.response_schemas.get("201"))
278                .or_else(|| {
279                    op.response_schemas
280                        .iter()
281                        .find(|(code, _)| code.starts_with('2'))
282                        .map(|(_, v)| v)
283                });
284            let response_schema_token = match response_schema {
285                Some(s) => quote! { Some(#s) },
286                None => quote! { None },
287            };
288
289            op_entries.push(quote! {
290                OperationDef {
291                    id: #id,
292                    method: #method,
293                    path: #path,
294                    summary: #summary,
295                    description: #description,
296                    params: &#params_static_name,
297                    body: #body,
298                    response_schema: #response_schema_token,
299                }
300            });
301        }
302
303        let op_count = op_entries.len();
304        quote! {
305            #(#param_statics)*
306
307            pub static OPERATIONS: [OperationDef; #op_count] = [#(#op_entries),*];
308        }
309    }
310}