1#![deny(missing_docs)]
6
use std::{
    collections::HashMap,
    fs::File,
    io::Read,
    path::{Path, PathBuf},
};

use openapiv3::OpenAPI;
use proc_macro::TokenStream;
use progenitor_impl::{
    CrateVers, GenerationSettings, Generator, InterfaceStyle, TagStyle, TypePatch, UnknownPolicy,
};
use quote::{quote, ToTokens};
use schemars::schema::SchemaObject;
use serde::Deserialize;
use serde_tokenstream::{OrderedMap, ParseWrapper};
use syn::LitStr;
use token_utils::TypeAndImpls;
24
25mod token_utils;
26
27#[proc_macro]
123pub fn generate_api(item: TokenStream) -> TokenStream {
124 match do_generate_api(item) {
125 Err(err) => err.to_compile_error().into(),
126 Ok(out) => out,
127 }
128}
129
130#[derive(Deserialize)]
131struct MacroSettings {
132 spec: ParseWrapper<LitStr>,
133 #[serde(default)]
134 interface: InterfaceStyle,
135 #[serde(default)]
136 tags: TagStyle,
137
138 inner_type: Option<ParseWrapper<syn::Type>>,
139 pre_hook: Option<ParseWrapper<ClosureOrPath>>,
140 pre_hook_async: Option<ParseWrapper<ClosureOrPath>>,
141 post_hook: Option<ParseWrapper<ClosureOrPath>>,
142 post_hook_async: Option<ParseWrapper<ClosureOrPath>>,
143
144 map_type: Option<ParseWrapper<syn::Type>>,
145
146 #[serde(default)]
147 derives: Vec<ParseWrapper<syn::Path>>,
148
149 #[serde(default)]
150 unknown_crates: UnknownPolicy,
151 #[serde(default)]
152 crates: HashMap<CrateName, MacroCrateSpec>,
153
154 #[serde(default)]
155 patch: HashMap<ParseWrapper<syn::Ident>, MacroPatch>,
156 #[serde(default)]
157 replace: HashMap<ParseWrapper<syn::Ident>, ParseWrapper<TypeAndImpls>>,
158 #[serde(default)]
159 convert: OrderedMap<SchemaObject, ParseWrapper<TypeAndImpls>>,
160 timeout: Option<u64>,
161}
162
163#[derive(Deserialize)]
164struct MacroPatch {
165 #[serde(default)]
166 rename: Option<String>,
167 #[serde(default)]
168 derives: Vec<ParseWrapper<syn::Path>>,
169}
170
171impl From<MacroPatch> for TypePatch {
172 fn from(a: MacroPatch) -> Self {
173 let mut s = Self::default();
174 a.rename.iter().for_each(|rename| {
175 s.with_rename(rename);
176 });
177 a.derives.iter().for_each(|derive| {
178 s.with_derive(derive.to_token_stream().to_string());
179 });
180 s
181 }
182}
183
184#[derive(Debug)]
185struct ClosureOrPath(proc_macro2::TokenStream);
186
187impl syn::parse::Parse for ClosureOrPath {
188 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
189 let lookahead = input.lookahead1();
190
191 if lookahead.peek(syn::token::Paren) {
192 let group: proc_macro2::Group = input.parse()?;
193 return syn::parse2::<Self>(group.stream());
194 }
195
196 if let Ok(closure) = input.parse::<syn::ExprClosure>() {
197 return Ok(Self(closure.to_token_stream()));
198 }
199
200 input
201 .parse::<syn::Path>()
202 .map(|path| Self(path.to_token_stream()))
203 }
204}
205
206struct MacroCrateSpec {
207 original: Option<String>,
208 version: CrateVers,
209}
210
211impl<'de> Deserialize<'de> for MacroCrateSpec {
212 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
213 where
214 D: serde::Deserializer<'de>,
215 {
216 let ss = String::deserialize(deserializer)?;
217
218 let (original, vers_str) = if let Some(ii) = ss.find('@') {
219 let original_str = &ss[..ii];
220 let rest = &ss[ii + 1..];
221 if !is_crate(original_str) {
222 return Err(<D::Error as serde::de::Error>::invalid_value(
223 serde::de::Unexpected::Str(&ss),
224 &"valid crate name",
225 ));
226 }
227
228 (Some(original_str.to_string()), rest)
229 } else {
230 (None, ss.as_ref())
231 };
232
233 let Some(version) = CrateVers::parse(vers_str) else {
234 return Err(<D::Error as serde::de::Error>::invalid_value(
235 serde::de::Unexpected::Str(&ss),
236 &"valid version",
237 ));
238 };
239
240 Ok(Self { original, version })
241 }
242}
243
244#[derive(Hash, PartialEq, Eq)]
245struct CrateName(String);
246impl<'de> Deserialize<'de> for CrateName {
247 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
248 where
249 D: serde::Deserializer<'de>,
250 {
251 let ss = String::deserialize(deserializer)?;
252
253 if is_crate(&ss) {
254 Ok(Self(ss))
255 } else {
256 Err(<D::Error as serde::de::Error>::invalid_value(
257 serde::de::Unexpected::Str(&ss),
258 &"valid crate name",
259 ))
260 }
261 }
262}
263
/// Returns `true` if `s` looks like a crate name.
///
/// A crate name must be non-empty and consist only of alphanumeric
/// characters, `_`, and `-`. The empty string is rejected explicitly because
/// it can never name a crate (e.g. the `original` part of `"@1.0"`).
fn is_crate(s: &str) -> bool {
    !s.is_empty() && s.chars().all(|cc| cc.is_alphanumeric() || cc == '_' || cc == '-')
}
267
268fn open_file(path: PathBuf, span: proc_macro2::Span) -> Result<File, syn::Error> {
269 File::open(path.clone()).map_err(|e| {
270 let path_str = path.to_string_lossy();
271 syn::Error::new(span, format!("couldn't read file {}: {}", path_str, e))
272 })
273}
274
275fn do_generate_api(item: TokenStream) -> Result<TokenStream, syn::Error> {
276 let (spec, settings) = if let Ok(spec) = syn::parse::<LitStr>(item.clone()) {
277 (spec, GenerationSettings::default())
278 } else {
279 let MacroSettings {
280 spec,
281 interface,
282 tags,
283 inner_type,
284 pre_hook,
285 pre_hook_async,
286 post_hook,
287 post_hook_async,
288 map_type,
289 unknown_crates,
290 crates,
291 derives,
292 patch,
293 replace,
294 convert,
295 timeout,
296 } = serde_tokenstream::from_tokenstream(&item.into())?;
297
298 let mut settings = GenerationSettings::default();
299 settings.with_interface(interface);
300 settings.with_tag(tags);
301 inner_type.map(|inner_type| settings.with_inner_type(inner_type.to_token_stream()));
302 pre_hook.map(|pre_hook| settings.with_pre_hook(pre_hook.into_inner().0));
303 pre_hook_async
304 .map(|pre_hook_async| settings.with_pre_hook_async(pre_hook_async.into_inner().0));
305 post_hook.map(|post_hook| settings.with_post_hook(post_hook.into_inner().0));
306 post_hook_async
307 .map(|post_hook_async| settings.with_post_hook_async(post_hook_async.into_inner().0));
308 map_type.map(|map_type| settings.with_map_type(map_type.to_token_stream()));
309
310 settings.with_unknown_crates(unknown_crates);
311 crates.into_iter().for_each(
312 |(CrateName(crate_name), MacroCrateSpec { original, version })| {
313 if let Some(original_crate) = original {
314 settings.with_crate(original_crate, version, Some(&crate_name));
315 } else {
316 settings.with_crate(crate_name, version, None);
317 }
318 },
319 );
320
321 derives.into_iter().for_each(|derive| {
322 settings.with_derive(derive.to_token_stream());
323 });
324 patch.into_iter().for_each(|(type_name, patch)| {
325 settings.with_patch(type_name.to_token_stream().to_string(), &patch.into());
326 });
327 replace.into_iter().for_each(|(type_name, type_and_impls)| {
328 let type_name = type_name.to_token_stream();
329 let (replace_name, impls) = type_and_impls.into_inner().into_name_and_impls();
330 settings.with_replacement(type_name, replace_name, impls);
331 });
332 convert.into_iter().for_each(|(schema, type_and_impls)| {
333 let (type_name, impls) = type_and_impls.into_inner().into_name_and_impls();
334 settings.with_conversion(schema, type_name, impls);
335 });
336 if let Some(timeout) = timeout {
337 settings.with_timeout(timeout);
338 }
339 (spec.into_inner(), settings)
340 };
341
342 let dir = std::env::var("CARGO_MANIFEST_DIR").map_or_else(
343 |_| std::env::current_dir().unwrap(),
344 |s| Path::new(&s).to_path_buf(),
345 );
346
347 let path = dir.join(spec.value());
348 let path_str = path.to_string_lossy();
349
350 let mut f = open_file(path.clone(), spec.span())?;
351 let oapi: OpenAPI = match serde_json::from_reader(f) {
352 Ok(json_value) => json_value,
353 _ => {
354 f = open_file(path.clone(), spec.span())?;
355 serde_yaml::from_reader(f).map_err(|e| {
356 syn::Error::new(spec.span(), format!("failed to parse {}: {}", path_str, e))
357 })?
358 }
359 };
360
361 let mut builder = Generator::new(&settings);
362
363 let code = builder.generate_tokens(&oapi).map_err(|e| {
364 syn::Error::new(
365 spec.span(),
366 format!("generation error for {}: {}", spec.value(), e),
367 )
368 })?;
369
370 let output = quote! {
371 use progenitor::progenitor_client;
374
375 #code
376
377 const _: &str = include_str!(#path_str);
379 };
380
381 Ok(output.into())
382}