Skip to main content

pixiv3_rs_proc/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::quote;
6use syn::fold::Fold;
7use syn::parse::{Parse, ParseStream};
8use syn::punctuated::Punctuated;
9use syn::{
10    Attribute, Expr, Ident, LitStr, Token, Type, TypeReference, braced, parenthesized, token,
11};
12
13/// One param or data field: `name @ "key": Type = default => transmute` (all optional after :)
14struct ParamSpec {
15    name: Ident,
16    key_override: Option<LitStr>,
17    ty: Type,
18    default: Option<Expr>,
19    transmute: Option<Expr>,
20}
21
22impl Parse for ParamSpec {
23    fn parse(input: ParseStream) -> syn::Result<Self> {
24        let name: Ident = input.parse()?;
25        let key_override = if input.peek(Token![@]) {
26            input.parse::<Token![@]>()?;
27            let lit: LitStr = input.parse()?;
28            Some(lit)
29        } else {
30            None
31        };
32        input.parse::<Token![:]>()?;
33        let ty: Type = input.parse()?;
34        let default = if input.peek(Token![=]) && !input.peek2(Token![>]) {
35            input.parse::<Token![=]>()?;
36            Some(input.parse()?)
37        } else {
38            None
39        };
40        let transmute = if input.peek(Token![=>]) {
41            input.parse::<Token![=>]>()?;
42            Some(input.parse()?)
43        } else {
44            None
45        };
46        Ok(ParamSpec {
47            name,
48            key_override,
49            ty,
50            default,
51            transmute,
52        })
53    }
54}
55
56/// Optional "params [ ... ]" or "data [ ... ]" section
57struct ParamSection {
58    kind: Ident,
59    entries: Vec<ParamSpec>,
60}
61
62impl Parse for ParamSection {
63    fn parse(input: ParseStream) -> syn::Result<Self> {
64        let kind: Ident = input.parse()?;
65        let content;
66        syn::bracketed!(content in input);
67        let entries = content.parse_terminated(ParamSpec::parse, Token![,])?;
68        Ok(ParamSection {
69            kind,
70            entries: entries.into_iter().collect(),
71        })
72    }
73}
74
75/// Whether the endpoint returns a paged result, and generates an iterator
76/// function for the items.
77struct Paged {
78    #[cfg_attr(not(feature = "stream"), expect(dead_code))]
79    field: Ident,
80    #[cfg_attr(not(feature = "stream"), expect(dead_code))]
81    item_type: Type,
82    #[cfg_attr(not(feature = "stream"), expect(dead_code))]
83    next_url: Option<Ident>,
84}
85
86impl Parse for Paged {
87    fn parse(input: ParseStream) -> syn::Result<Self> {
88        let paged = input.parse::<Ident>()?;
89        if paged != "paged" {
90            return Err(syn::Error::new(paged.span(), "expected paged"));
91        }
92
93        let next_url = if input.peek(Token![@]) {
94            input.parse::<Token![@]>()?;
95            let next_url = input.parse::<Ident>()?;
96            Some(next_url)
97        } else {
98            None
99        };
100
101        let field: Ident = input.parse()?;
102        input.parse::<Token![:]>()?;
103        let item_type: Type = input.parse()?;
104        Ok(Paged {
105            field,
106            item_type,
107            next_url,
108        })
109    }
110}
111
112/// Full endpoint: attrs? name -> ReturnType (paged @next_url? field: ItemType)? { METHOD "url", params? data? }
113struct ApiEndpoint {
114    attrs: Vec<Attribute>,
115    name: Ident,
116    return_type: Type,
117    #[cfg_attr(not(feature = "stream"), expect(dead_code))]
118    paged: Option<Paged>,
119    method: Ident,
120    url: LitStr,
121    sections: Vec<ParamSection>,
122}
123
124impl Parse for ApiEndpoint {
125    fn parse(input: ParseStream) -> syn::Result<Self> {
126        let attrs = input.call(Attribute::parse_outer)?;
127        let name: Ident = input.parse()?;
128        input.parse::<Token![->]>()?;
129        let return_type: Type = input.parse()?;
130
131        let paged = if input.peek(token::Paren) {
132            let content;
133            parenthesized!(content in input);
134            let paged = content.parse::<Paged>()?;
135            Some(paged)
136        } else {
137            None
138        };
139
140        let content;
141        braced!(content in input);
142        let method: Ident = content.parse()?;
143        let url: LitStr = content.parse()?;
144        content.parse::<Token![,]>()?;
145
146        let mut sections = Vec::new();
147        while !content.is_empty() {
148            let lookahead = content.lookahead1();
149            if lookahead.peek(Ident) {
150                let section = content.parse::<ParamSection>()?;
151                sections.push(section);
152                if !content.is_empty() {
153                    content.parse::<Token![,]>()?;
154                }
155            } else {
156                return Err(lookahead.error());
157            }
158        }
159
160        Ok(ApiEndpoint {
161            attrs,
162            name,
163            return_type,
164            method,
165            url,
166            sections,
167            paged,
168        })
169    }
170}
171
172impl ApiEndpoint {
173    pub fn find_section(&self, kind: impl AsRef<str>) -> Option<&ParamSection> {
174        let kind = kind.as_ref();
175        self.sections.iter().find(|s| s.kind == kind)
176    }
177}
178
179/// Fold that adds explicit lifetimes ('a1, 'a2, ...) to all references in a type,
180/// and records them in the lifetimes vec.
181struct ExplicitLifetimeFolder {
182    counter: u32,
183    lifetimes: Vec<syn::Lifetime>,
184}
185
186impl ExplicitLifetimeFolder {
187    fn new() -> Self {
188        Self {
189            counter: 0,
190            lifetimes: Vec::new(),
191        }
192    }
193}
194
195impl Fold for ExplicitLifetimeFolder {
196    fn fold_type_reference(&mut self, mut ty_ref: TypeReference) -> TypeReference {
197        if let Some(lifetime) = &ty_ref.lifetime {
198            self.lifetimes.push(lifetime.clone());
199        } else {
200            self.counter += 1;
201            let lt = syn::Lifetime::new(&format!("'a{}", self.counter), Span::call_site());
202            self.lifetimes.push(lt.clone());
203            ty_ref.lifetime = Some(lt);
204            *ty_ref.elem = self.fold_type(*ty_ref.elem);
205        }
206
207        ty_ref
208    }
209}
210
211struct ApiEndpoints {
212    endpoints: Punctuated<ApiEndpoint, Token![;]>,
213}
214
215impl Parse for ApiEndpoints {
216    fn parse(input: ParseStream) -> syn::Result<Self> {
217        let endpoints = input.parse_terminated(ApiEndpoint::parse, Token![;])?;
218        Ok(ApiEndpoints { endpoints })
219    }
220}
221
222/// Generates async API methods on `AppPixivAPI` from endpoint definitions.
223///
224/// Syntax: one or more endpoints separated by `;`. Each endpoint:
225/// `/// doc? name -> ReturnType (paged @next_url? field: ItemType)? { GET|POST|DELETE "path", params [ ... ]? data [ ... ]? }`
226///
227/// - Params: `name: Type = default => transmute`; use `name @ "key": Type` to override query/form key.
228/// - Paged: `(paged illusts: IllustrationInfo)` generates a method returning a struct with `illusts` and `next_url`.
229///
230/// 根据端点定义在 `AppPixivAPI` 上生成异步 API 方法。语法:多个端点用 `;` 分隔;每条可含 doc、返回类型、可选 paged、方法、路径及 params/data。
231#[proc_macro]
232pub fn api_endpoints(input: TokenStream) -> TokenStream {
233    let endpoints = match syn::parse::<ApiEndpoints>(input) {
234        Ok(e) => e,
235        Err(e) => return e.to_compile_error().into(),
236    };
237
238    let mut expanded = TokenStream2::new();
239
240    for endpoint in endpoints.endpoints {
241        let attrs = &endpoint.attrs;
242        let name = &endpoint.name;
243        let return_type = &endpoint.return_type;
244        let method = &endpoint.method;
245        let url = &endpoint.url;
246
247        let mut fn_params = Vec::new();
248        let mut section_inits = Vec::new();
249        let mut section_bodies = Vec::new();
250        #[cfg(feature = "stream")]
251        let mut fn_args = Vec::new();
252        let mut folder = ExplicitLifetimeFolder::new();
253
254        for section in &endpoint.sections {
255            let kind = &section.kind;
256
257            for spec in &section.entries {
258                let name = &spec.name;
259                let ty = folder.fold_type(spec.ty.clone());
260
261                fn_params.push(quote! { #name: #ty, });
262
263                let key = if let Some(key) = &spec.key_override {
264                    quote! { #key }
265                } else {
266                    quote! { stringify!(#name) }
267                };
268
269                let mut body_for_this = TokenStream2::new();
270
271                if let Some(default) = &spec.default {
272                    body_for_this.extend(quote! {
273                        let #name = #name.unwrap_or_else(|| #default);
274                    });
275                }
276
277                if let Some(transmute) = &spec.transmute {
278                    body_for_this.extend(quote! {
279                        let #name = #transmute;
280                    });
281                }
282
283                body_for_this.extend(quote! {
284                    #kind.push(#key, #name);
285                });
286
287                section_bodies.push(quote! { { #body_for_this } });
288
289                #[cfg(feature = "stream")]
290                fn_args.push(quote! { #name, });
291            }
292
293            section_inits.push(quote! {
294                #[allow(unused_mut)]
295                let mut #kind: kv_pairs::KVPairs<'_> = kv_pairs::kv_pairs![];
296            });
297        }
298
299        let params = if endpoint.find_section("params").is_some() {
300            quote! { Some(params) }
301        } else {
302            quote! { None }
303        };
304        let data = if endpoint.find_section("data").is_some() {
305            quote! { Some(data) }
306        } else {
307            quote! { None }
308        };
309
310        let lifetimes = &folder.lifetimes;
311        let expanded_endpoint = quote! {
312            #(#attrs)*
313            #[allow(clippy::too_many_arguments)]
314            pub async fn #name<'a0 #(, #lifetimes)*>(
315                &'a0 self,
316                #(#fn_params)*
317                with_auth: bool,
318            ) -> Result<#return_type, crate::error::PixivError> {
319                let url = format!("{}{}", self.hosts, #url);
320                #(#section_inits)*
321                #(#section_bodies)*
322                crate::debug!("calling {} at {}", stringify!(#name), #url);
323                let r = self.do_api_request(crate::aapi::HttpMethod::#method, &url, None, #params, #data, with_auth).await?;
324                crate::models::parse_response_into::<#return_type>(r).await
325            }
326        };
327
328        expanded.extend(expanded_endpoint);
329
330        #[cfg(feature = "stream")]
331        if let Some(paged) = &endpoint.paged {
332            use quote::format_ident;
333
334            let iter_fn_name = format_ident!("{}_iter", name);
335            let item_field = &paged.field;
336            let item_type = &paged.item_type;
337            let next_url_field = paged
338                .next_url
339                .clone()
340                .unwrap_or_else(|| format_ident!("next_url"));
341            let iter_doc_comment = format!(
342                "Iterate over the results of {0}.\n\n{0}的迭代版本。",
343                stringify!(#name)
344            );
345
346            let iter_fn = quote! {
347                #[allow(clippy::too_many_arguments)]
348                #[doc = #iter_doc_comment]
349                pub fn #iter_fn_name<'a0 #(, #lifetimes)*>(
350                    &'a0 self,
351                    #(#fn_params)*
352                    with_auth: bool,
353                ) -> impl ::futures_core::stream::Stream<
354                    Item = Result<#item_type, crate::error::PixivError>
355                > + use<'a0 #(, #lifetimes)*> {
356                    crate::debug!("calling {} (iterable version of {})", stringify!(#name), stringify!(#iter_fn_name));
357
358                    async_stream::try_stream! {
359                        crate::debug!("{} first request to {}", stringify!(#iter_fn_name), #url);
360                        let mut result = self.#name(#(#fn_args)* with_auth).await?;
361                        let mut next_url = result.#next_url_field;
362
363                        loop {
364                            for item in result.#item_field {
365                                yield item;
366                            }
367
368                            match &next_url {
369                                Some(url) => {
370                                    crate::debug!("{} next request to {}", stringify!(#iter_fn_name), url);
371                                    result = self.visit_next_url::<#return_type>(url, with_auth).await?;
372                                    next_url = result.#next_url_field;
373                                }
374                                None => {
375                                    crate::debug!("{} reached end of results", stringify!(#iter_fn_name));
376                                    break;
377                                },
378                            }
379                        }
380                    }
381                }
382            };
383
384            expanded.extend(iter_fn);
385        }
386    }
387
388    TokenStream::from(expanded)
389}
390
391/// A no-op macro that does nothing. Used for placeholder or conditional compilation.
392#[proc_macro]
393pub fn no_op_macro(_: TokenStream) -> TokenStream {
394    TokenStream::new()
395}