kpl_derive/
lib.rs

1use darling::FromMeta;
2use proc_macro::TokenStream;
3use quote::quote;
4use syn::{parse_macro_input, Data, DeriveInput, Fields};
5
6/// Procedural macro to generate API client code for stock API endpoints
7///
8/// Example usage:
9/// ```rust
10/// #[derive(ApiEndpoint)]
11/// #[endpoint(name = "历史每日涨跌统计")] // method and path are optional with defaults
12/// struct HisZhangFuDetail {
13///     // Fields can use serde rename attributes for proper parameter naming
14///     #[serde(rename = "VerSion")]
15///     ver_sion: String,
16/// }
17/// ```
18#[derive(Debug, FromMeta)]
19struct EndpointOpts {
20    name: String,
21    #[darling(default = "EndpointOpts::default_method")]
22    method: String,
23    #[darling(default = "EndpointOpts::default_path")]
24    path: String,
25    #[darling(default = "EndpointOpts::default_host")]
26    host: String,
27    #[darling(default)]
28    resp: Option<syn::Path>,
29}
30
31impl EndpointOpts {
32    fn default_method() -> String {
33        "GET".to_string()
34    }
35
36    fn default_path() -> String {
37        "/w1/api/index.php".to_string()
38    }
39
40    fn default_host() -> String {
41        "apphis.longhuvip.com".to_string()
42    }
43}
44
45#[proc_macro_derive(ApiEndpoint, attributes(endpoint, serde))]
46pub fn api_endpoint_derive(input: TokenStream) -> TokenStream {
47    let input = parse_macro_input!(input as DeriveInput);
48    let name = &input.ident;
49
50    // Parse endpoint options using darling
51    let endpoint_opts = match input
52        .attrs
53        .iter()
54        .find(|attr| attr.path().is_ident("endpoint"))
55        .map(|attr| EndpointOpts::from_meta(&attr.meta))
56        .transpose()
57    {
58        Ok(Some(opts)) => opts,
59        Ok(None) => EndpointOpts {
60            name: String::new(),
61            method: EndpointOpts::default_method(),
62            path: EndpointOpts::default_path(),
63            host: EndpointOpts::default_host(),
64            resp: None,
65        },
66        Err(e) => return TokenStream::from(e.write_errors()),
67    };
68
69    // Determine the actual return type to use in the generated code
70    let actual_response_type = match &endpoint_opts.resp {
71        Some(ty) => quote! { #ty },
72        None => quote! { serde_json::Value },
73    };
74
75    // Extract fields for query parameters
76    let fields = match &input.data {
77        Data::Struct(data) => match &data.fields {
78            Fields::Named(fields) => &fields.named,
79            _ => panic!("ApiEndpoint only supports structs with named fields"),
80        },
81        _ => panic!("ApiEndpoint only supports structs"),
82    };
83
84    // Generate field accessors for query parameters
85    let query_params = fields.iter().map(|field| {
86        let field_name = field.ident.as_ref().unwrap();
87
88        // Check for serde rename attribute
89        let mut rename_value = None;
90        for attr in &field.attrs {
91            if attr.path().is_ident("serde") {
92                let _ = attr.parse_nested_meta(|meta| {
93                    if let Some(ident) = meta.path.get_ident() {
94                        if ident == "rename" {
95                            rename_value = Some(meta.value()?.parse::<syn::LitStr>()?.value());
96                        }
97                    }
98                    Ok(())
99                });
100            }
101        }
102
103        // Use rename value if present, otherwise use field name
104        let param_name = rename_value.unwrap_or_else(|| field_name.to_string());
105        quote! {
106            (#param_name, self.#field_name.to_string())
107        }
108    });
109
110    let endpoint_name = &endpoint_opts.name;
111    let method = &endpoint_opts.method;
112    let path = &endpoint_opts.path;
113    let host = &endpoint_opts.host;
114
115    // Generate implementation with conditional deserialization logic based on response type
116    let output = if endpoint_opts.resp.is_some() {
117        quote! {
118            impl #name {
119                pub async fn execute(&self) -> Result<#actual_response_type, crate::error::ApiError> {
120                    self.execute_with_host(None).await
121                }
122
123                pub async fn execute_with_host(&self, host: Option<String>) -> Result<#actual_response_type, crate::error::ApiError> {
124                    use reqwest::Client;
125                    use serde_json::Value;
126
127                    let method = match reqwest::Method::from_bytes(#method.as_bytes()) {
128                        Ok(m) => m,
129                        Err(e) => return Err(crate::error::ApiError::InvalidMethod(e.to_string())),
130                    };
131
132                    let query_params = vec![
133                        #(#query_params),*
134                    ];
135
136                    let host = host.unwrap_or_else(|| #host.to_string());
137
138                    let response = Client::builder()
139                        .danger_accept_invalid_certs(true)
140                        .build()
141                        .unwrap()
142                        .request(method, format!("https://{}{}", host, #path))
143                        .header("User-Agent", "lhb/5.18.5 (com.kaipanla.www; build:2; iOS 18.3.0) Alamofire/4.9.1")
144                        .header("Content-Type", "application/x-www-form-urlencoded; application/x-www-form-urlencoded; charset=utf-8")
145                        .header("Accept-Language", "zh-Hans-CN;q=1.0")
146                        .query(&query_params)
147                        .send()
148                        .await
149                        .map_err(|e| crate::error::ApiError::from(e))?
150                        .error_for_status()
151                        .map_err(|e| crate::error::ApiError::RequestFailed(e.status().map_or(500, |s| s.as_u16())))?;
152
153                    let json_value: Value = response.json().await.map_err(|e| crate::error::ApiError::from(e))?;
154                    Ok(<#actual_response_type>::from(json_value))
155                }
156            }
157
158            impl std::fmt::Display for #name {
159                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160                    write!(f, "{}", #endpoint_name)
161                }
162            }
163        }
164    } else {
165        quote! {
166            impl #name {
167                pub async fn execute(&self) -> Result<#actual_response_type, crate::error::ApiError> {
168                    self.execute_with_host(None).await
169                }
170
171                pub async fn execute_with_host(&self, host: Option<String>) -> Result<#actual_response_type, crate::error::ApiError> {
172                    use reqwest::Client;
173
174                    let method = match reqwest::Method::from_bytes(#method.as_bytes()) {
175                        Ok(m) => m,
176                        Err(e) => return Err(crate::error::ApiError::InvalidMethod(e.to_string())),
177                    };
178
179                    let query_params = vec![
180                        #(#query_params),*
181                    ];
182
183                    let host = host.unwrap_or_else(|| #host.to_string());
184
185                    let response = Client::builder()
186                        .danger_accept_invalid_certs(true)
187                        .build()
188                        .unwrap()
189                        .request(method, format!("https://{}{}", host, #path))
190                        .header("User-Agent", "lhb/5.18.5 (com.kaipanla.www; build:2; iOS 18.3.0) Alamofire/4.9.1")
191                        .header("Content-Type", "application/x-www-form-urlencoded; application/x-www-form-urlencoded; charset=utf-8")
192                        .header("Accept-Language", "zh-Hans-CN;q=1.0")
193                        .query(&query_params)
194                        .send()
195                        .await
196                        .map_err(crate::error::ApiError::from)?
197                        .error_for_status()
198                        .map_err(|e| crate::error::ApiError::RequestFailed(e.status().map_or(500, |s| s.as_u16())))?;
199
200                    response.json::<#actual_response_type>()
201                        .await
202                        .map_err(ApiError::from)
203                }
204            }
205
206            impl std::fmt::Display for #name {
207                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208                    write!(f, "{}", #endpoint_name)
209                }
210            }
211        }
212    };
213
214    output.into()
215}