1use std::collections::HashMap;
4
5use heck::ToUpperCamelCase;
6use openapiv3::{
7 IntegerFormat, NumberFormat, ReferenceOr, Schema, SchemaKind, Type, VariantOrUnknownOrEmpty,
8};
9use proc_macro2::TokenStream;
10use quote::{format_ident, quote};
11use typify::{TypeSpace, TypeSpaceSettings};
12
13use crate::openapi::{Operation, OperationParam, ParamLocation, ParsedSpec, RequestBody};
14use crate::{Error, Result};
15
16pub struct TypeGenerator {
18 type_space: TypeSpace,
19 #[allow(dead_code)]
21 type_names: HashMap<String, String>,
22}
23
24impl TypeGenerator {
25 pub fn new(spec: &ParsedSpec) -> Result<Self> {
27 let settings = TypeSpaceSettings::default();
28 let mut type_space = TypeSpace::new(&settings);
29
30 if let Some(components) = &spec.components {
32 let schemas = components
33 .schemas
34 .iter()
35 .map(|(name, schema)| {
36 let schema = convert_to_schemars(schema)?;
37 Ok((name.clone(), schema))
38 })
39 .collect::<Result<Vec<_>>>()?;
40
41 type_space
42 .add_ref_types(schemas.into_iter())
43 .map_err(|e| Error::TypeGenError(e.to_string()))?;
44 }
45
46 Ok(Self {
47 type_space,
48 type_names: HashMap::new(),
49 })
50 }
51
52 pub fn generate_all_types(&self) -> TokenStream {
54 self.type_space.to_stream()
55 }
56
57 pub fn get_type_name(&self, reference: &str) -> Option<String> {
59 let name = reference.strip_prefix("#/components/schemas/")?;
61 Some(name.to_upper_camel_case())
62 }
63
64 pub fn type_for_schema(&self, schema: &ReferenceOr<Schema>, name_hint: &str) -> TokenStream {
66 match schema {
67 ReferenceOr::Reference { reference } => {
68 if let Some(type_name) = self.get_type_name(reference) {
69 let ident = format_ident!("{}", type_name);
70 quote! { #ident }
71 } else {
72 quote! { serde_json::Value }
73 }
74 }
75 ReferenceOr::Item(schema) => self.type_for_inline_schema(schema, name_hint),
76 }
77 }
78
79 pub fn type_for_boxed_schema(
81 &self,
82 schema: &ReferenceOr<Box<Schema>>,
83 name_hint: &str,
84 ) -> TokenStream {
85 match schema {
86 ReferenceOr::Reference { reference } => {
87 if let Some(type_name) = self.get_type_name(reference) {
88 let ident = format_ident!("{}", type_name);
89 quote! { #ident }
90 } else {
91 quote! { serde_json::Value }
92 }
93 }
94 ReferenceOr::Item(schema) => self.type_for_inline_schema(schema, name_hint),
95 }
96 }
97
98 fn type_for_inline_schema(&self, schema: &Schema, name_hint: &str) -> TokenStream {
100 match &schema.schema_kind {
101 SchemaKind::Type(Type::String(_)) => quote! { String },
102 SchemaKind::Type(Type::Integer(int_type)) => match &int_type.format {
103 VariantOrUnknownOrEmpty::Item(IntegerFormat::Int32) => quote! { i32 },
104 VariantOrUnknownOrEmpty::Item(IntegerFormat::Int64) => quote! { i64 },
105 _ => quote! { i64 },
106 },
107 SchemaKind::Type(Type::Number(num_type)) => match &num_type.format {
108 VariantOrUnknownOrEmpty::Item(NumberFormat::Float) => quote! { f32 },
109 VariantOrUnknownOrEmpty::Item(NumberFormat::Double) => quote! { f64 },
110 _ => quote! { f64 },
111 },
112 SchemaKind::Type(Type::Boolean(_)) => quote! { bool },
113 SchemaKind::Type(Type::Array(arr)) => {
114 if let Some(items) = &arr.items {
115 let inner = self.type_for_boxed_schema(items, &format!("{}Item", name_hint));
116 quote! { Vec<#inner> }
117 } else {
118 quote! { Vec<serde_json::Value> }
119 }
120 }
121 SchemaKind::Type(Type::Object(_)) => {
122 quote! { serde_json::Value }
125 }
126 _ => quote! { serde_json::Value },
127 }
128 }
129
130 pub fn path_param_type(&self, param: &OperationParam) -> TokenStream {
132 if let Some(schema) = ¶m.schema {
133 self.type_for_schema(schema, ¶m.name.to_upper_camel_case())
134 } else {
135 quote! { String }
136 }
137 }
138
139 pub fn query_param_type(&self, param: &OperationParam) -> TokenStream {
141 if let Some(schema) = ¶m.schema {
142 self.type_for_schema(schema, ¶m.name.to_upper_camel_case())
143 } else {
144 quote! { String }
145 }
146 }
147
148 pub fn request_body_type(&self, body: &RequestBody, op_name: &str) -> TokenStream {
150 if let Some(schema) = &body.schema {
151 self.type_for_schema(schema, &format!("{}Body", op_name.to_upper_camel_case()))
152 } else {
153 quote! { serde_json::Value }
154 }
155 }
156
157 #[allow(dead_code)]
159 pub fn response_type(
160 &self,
161 schema: &Option<ReferenceOr<Schema>>,
162 op_name: &str,
163 status: u16,
164 ) -> TokenStream {
165 if let Some(schema) = schema {
166 self.type_for_schema(
167 schema,
168 &format!("{}Response{}", op_name.to_upper_camel_case(), status),
169 )
170 } else {
171 quote! { () }
172 }
173 }
174
175 pub fn generate_query_struct(&self, op: &Operation) -> Option<(syn::Ident, TokenStream)> {
177 let query_params: Vec<_> = op
178 .parameters
179 .iter()
180 .filter(|p| p.location == ParamLocation::Query)
181 .collect();
182
183 if query_params.is_empty() {
184 return None;
185 }
186
187 let struct_name = format_ident!(
188 "{}Query",
189 op.operation_id
190 .as_deref()
191 .unwrap_or(&op.path)
192 .to_upper_camel_case()
193 );
194
195 let fields = query_params.iter().map(|param| {
196 let name = format_ident!("{}", heck::AsSnakeCase(¶m.name).to_string());
197 let ty = self.query_param_type(param);
198
199 if param.required {
200 quote! { pub #name: #ty }
201 } else {
202 quote! { pub #name: Option<#ty> }
203 }
204 });
205
206 let definition = quote! {
207 #[derive(Debug, Clone, serde::Deserialize)]
208 pub struct #struct_name {
209 #(#fields,)*
210 }
211 };
212
213 Some((struct_name, definition))
214 }
215
216 pub fn generate_path_type(&self, op: &Operation) -> TokenStream {
218 let path_params: Vec<_> = op
219 .parameters
220 .iter()
221 .filter(|p| p.location == ParamLocation::Path)
222 .collect();
223
224 if path_params.is_empty() {
225 return quote! { () };
226 }
227
228 if path_params.len() == 1 {
229 return self.path_param_type(path_params[0]);
230 }
231
232 let types = path_params.iter().map(|p| self.path_param_type(p));
233 quote! { (#(#types),*) }
234 }
235}
236
237fn convert_to_schemars(schema: &ReferenceOr<Schema>) -> Result<schemars::schema::Schema> {
239 let json = serde_json::to_value(schema)
241 .map_err(|e| Error::TypeGenError(format!("failed to serialize schema: {}", e)))?;
242
243 serde_json::from_value(json)
244 .map_err(|e| Error::TypeGenError(format!("failed to convert schema: {}", e)))
245}