ts_derive/
lib.rs

1#![allow(dead_code)]
2
3use darling::FromMeta;
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{parse_macro_input, Data, DeriveInput, Fields, LitInt, Type};
7
8/// Options for the TsEndpoint derive macro
9#[derive(Debug, FromMeta)]
10struct EndpointOpts {
11    /// API name for the Tushare request
12    api: String,
13    /// Description of the API endpoint
14    desc: String,
15    /// Response type (optional)
16    #[darling(default)]
17    resp: Option<syn::Path>,
18}
19
20/// Options for the TsResponse derive macro
21#[derive(Debug, FromMeta)]
22struct ResponseOpts {
23    /// API name for the Tushare response
24    api: String,
25}
26
27/// Derive macro for Tushare API endpoints
28///
29/// Example usage:
30/// ```rust
31/// #[derive(TsEndpoint)]
32/// #[endpoint(api = "api_name", desc = "description", resp = MyResponseType)]
33/// struct MyRequest {
34///     // ... fields ...
35/// }
36/// ```
37#[proc_macro_derive(TsEndpoint, attributes(endpoint, fields))]
38pub fn ts_endpoint_derive(input: TokenStream) -> TokenStream {
39    let input = parse_macro_input!(input as DeriveInput);
40    let name = &input.ident;
41    let requester_name = syn::Ident::new(&format!("{}Requester", name), name.span());
42
43    // Parse endpoint options using darling
44    let endpoint_opts = match input
45        .attrs
46        .iter()
47        .find(|attr| attr.path().is_ident("endpoint"))
48        .map(|attr| EndpointOpts::from_meta(&attr.meta))
49        .transpose()
50    {
51        Ok(Some(opts)) => opts,
52        Ok(None) => {
53            return syn::Error::new_spanned(
54                input.ident.clone(),
55                "Missing #[endpoint(...)] attribute",
56            )
57            .to_compile_error()
58            .into()
59        }
60        Err(e) => return TokenStream::from(e.write_errors()),
61    };
62
63    // Extract fields for request parameters
64    let fields = match &input.data {
65        Data::Struct(data) => match &data.fields {
66            Fields::Named(fields) => &fields.named,
67            _ => {
68                return syn::Error::new_spanned(
69                    input.ident.clone(),
70                    "TsEndpoint only supports structs with named fields",
71                )
72                .to_compile_error()
73                .into()
74            }
75        },
76        _ => {
77            return syn::Error::new_spanned(input.ident.clone(), "TsEndpoint only supports structs")
78                .to_compile_error()
79                .into()
80        }
81    };
82
83    // Generate field serialization for the params object
84    let param_fields = fields.iter().map(|field| {
85        let field_name = field.ident.as_ref().unwrap();
86        let field_name_str = field_name.to_string();
87
88        // Check for serde rename attribute
89        let mut rename_value = None;
90        for attr in &field.attrs {
91            if attr.path().is_ident("serde") {
92                let _ = attr.parse_nested_meta(|meta| {
93                    if meta.path.is_ident("rename") {
94                        rename_value = Some(meta.value()?.parse::<syn::LitStr>()?.value());
95                    }
96                    Ok(())
97                });
98            }
99        }
100
101        // Use rename value if present, otherwise use field name
102        let param_name = rename_value.unwrap_or_else(|| field_name_str.clone());
103
104        quote! {
105            params.insert(#param_name.to_string(), serde_json::to_value(&self.#field_name)?);
106        }
107    });
108
109    // Get API name and description from the endpoint options
110    let api_name = &endpoint_opts.api;
111    let api_desc = &endpoint_opts.desc;
112
113    // Check if response type is specified
114    let resp_type = endpoint_opts.resp.as_ref().map(|path| quote! { #path });
115
116    // Generate the TsRequesterImpl struct implementation with a unique name
117    let ts_requester_impl = if let Some(resp_type) = resp_type.clone() {
118        quote! {
119            // 定义单独的TsRequester结构体和impl,这个结构体是在当前crate中的
120            pub struct #requester_name {
121                request: #name,
122                fields: Option<Vec<&'static str>>,
123            }
124
125            impl #requester_name {
126                pub fn new(request: #name, fields: Option<Vec<&'static str>>) -> Self {
127                    Self { request, fields }
128                }
129
130                pub fn with_fields(mut self, fields: Vec<&'static str>) -> Self {
131                    self.fields = Some(fields);
132                    self
133                }
134
135                pub async fn execute(self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
136                    self.request.__execute_request(self.fields).await
137                }
138
139                pub async fn execute_typed(self) -> Result<Vec<#resp_type>, Box<dyn std::error::Error>> {
140                    // If fields are not provided, extract field names from the response struct
141                    let fields_to_use = if self.fields.is_none() {
142                        // Get field names from the response struct by reflection
143                        let field_names = <#resp_type>::get_field_names();
144                        Some(field_names)
145                    } else {
146                        self.fields
147                    };
148
149                    // Execute with the fields (either provided or derived)
150                    let json = self.request.__execute_request(fields_to_use).await?;
151                    let res = <#resp_type>::from_json(&json);
152                    res
153                }
154
155                pub async fn execute_as_dicts(self) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>, Box<dyn std::error::Error>> {
156                    use serde_json::Value;
157                    use std::collections::HashMap;
158
159                    // 直接使用__execute_request而不是execute,以便保留字段信息
160                    let json = self.request.__execute_request(self.fields).await?;
161
162                    // Extract fields and items
163                    let data = json.get("data")
164                        .ok_or("Missing 'data' field in response")?;
165
166                    let fields = data.get("fields")
167                        .ok_or("Missing 'fields' field in data")?
168                        .as_array()
169                        .ok_or("'fields' is not an array")?;
170
171                    let items = data.get("items")
172                        .ok_or("Missing 'items' field in data")?
173                        .as_array()
174                        .ok_or("'items' is not an array")?;
175
176                    // Convert to Vec<HashMap<String, Value>>
177                    let mut result = Vec::with_capacity(items.len());
178
179                    for item_value in items {
180                        let item = item_value.as_array()
181                            .ok_or("Item is not an array")?;
182
183                        let mut map = HashMap::new();
184
185                        // Map fields to values
186                        for (i, field) in fields.iter().enumerate() {
187                            if i < item.len() {
188                                let field_name = field.as_str()
189                                    .ok_or("Field name is not a string")?
190                                    .to_string();
191
192                                map.insert(field_name, item[i].clone());
193                            }
194                        }
195
196                        result.push(map);
197                    }
198
199                    Ok(result)
200                }
201            }
202        }
203    } else {
204        quote! {
205            // 定义单独的TsRequester结构体和impl,这个结构体是在当前crate中的
206            pub struct #requester_name {
207                request: #name,
208                fields: Option<Vec<&'static str>>,
209            }
210
211            impl #requester_name {
212                pub fn new(request: #name, fields: Option<Vec<&'static str>>) -> Self {
213                    Self { request, fields }
214                }
215
216                pub fn with_fields(mut self, fields: Vec<&'static str>) -> Self {
217                    self.fields = Some(fields);
218                    self
219                }
220
221                pub async fn execute(self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
222                    self.request.__execute_request(self.fields).await
223                }
224
225                pub async fn execute_as_dicts(self) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>, Box<dyn std::error::Error>> {
226                    use serde_json::Value;
227                    use std::collections::HashMap;
228
229                    // 直接使用__execute_request而不是execute,以便保留字段信息
230                    let json = self.request.__execute_request(self.fields).await?;
231
232                    // Extract fields and items
233                    let data = json.get("data")
234                        .ok_or("Missing 'data' field in response")?;
235
236                    let fields = data.get("fields")
237                        .ok_or("Missing 'fields' field in data")?
238                        .as_array()
239                        .ok_or("'fields' is not an array")?;
240
241                    let items = data.get("items")
242                        .ok_or("Missing 'items' field in data")?
243                        .as_array()
244                        .ok_or("'items' is not an array")?;
245
246                    // Convert to Vec<HashMap<String, Value>>
247                    let mut result = Vec::with_capacity(items.len());
248
249                    for item_value in items {
250                        let item = item_value.as_array()
251                            .ok_or("Item is not an array")?;
252
253                        let mut map = HashMap::new();
254
255                        // Map fields to values
256                        for (i, field) in fields.iter().enumerate() {
257                            if i < item.len() {
258                                let field_name = field.as_str()
259                                    .ok_or("Field name is not a string")?
260                                    .to_string();
261
262                                map.insert(field_name, item[i].clone());
263                            }
264                        }
265
266                        result.push(map);
267                    }
268
269                    Ok(result)
270                }
271            }
272        }
273    };
274
275    // Generate impl for the struct
276    let impl_struct = quote! {
277        impl #name {
278            /// Get the API name
279            pub fn api_name(&self) -> &'static str {
280                #api_name
281            }
282
283            /// Get the API description
284            pub fn description(&self) -> &'static str {
285                #api_desc
286            }
287
288            /// Start chain with fields
289            pub fn with_fields(self, fields: Vec<&'static str>) -> #requester_name {
290                #requester_name::new(self, Some(fields))
291            }
292
293            /// Execute without fields
294            pub async fn execute(self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
295                self.__execute_request(None).await
296            }
297
298            /// Execute with typed response, automatically deriving fields from response struct
299            pub async fn execute_typed(self) -> Result<Vec<#resp_type>, Box<dyn std::error::Error>> {
300                // Create requester and call its execute_typed method
301                let requester = #requester_name::new(self, None);
302                requester.execute_typed().await
303            }
304
305            // Inner method used by TsRequester
306            #[doc(hidden)]
307            pub(crate) async fn __execute_request(&self, fields: Option<Vec<&str>>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
308                use serde_json::{json, Map, Value};
309                use reqwest::Client;
310                use dotenvy::dotenv;
311                use std::env;
312
313                // Load environment variables
314                dotenv().ok();
315
316                // Get token from environment
317                let token = env::var("TUSHARE_TOKEN")
318                    .map_err(|_| "TUSHARE_TOKEN environment variable not set")?;
319
320                // Build params object
321                let mut params = Map::new();
322                #(#param_fields)*
323
324                // Create request body
325                let mut request_body = Map::new();
326                request_body.insert("api_name".to_string(), Value::String(#api_name.to_string()));
327                request_body.insert("token".to_string(), Value::String(token));
328                request_body.insert("params".to_string(), Value::Object(params));
329
330                // Add fields if provided
331                if let Some(field_list) = fields {
332                    request_body.insert("fields".to_string(),
333                        Value::String(field_list.join(",")));
334                }
335
336                // Send request
337                let client = Client::new();
338                let response = client
339                    .post("http://api.tushare.pro/")
340                    .header("Content-Type", "application/json")
341                    .body(serde_json::to_string(&Value::Object(request_body))?)
342                    .send()
343                    .await?;
344
345                if !response.status().is_success() {
346                    return Err(format!("Request failed with status: {}", response.status()).into());
347                }
348
349                let json = response.json::<Value>().await?;
350                Ok(json)
351            }
352        }
353    };
354
355    // Combine implementations
356    let output = quote! {
357        #impl_struct
358        #ts_requester_impl
359    };
360
361    output.into()
362}
363
364/// Derive macro for Tushare API response models
365///
366/// This macro will create structs that represent the response data from a Tushare API call.
367/// It automatically maps the fields to the data items in the response.
368///
369/// Example usage:
370/// ```rust
371/// #[derive(TsResponse)]
372/// #[response(api = "api_name")]
373/// struct MyResponseData {
374///     #[ts_field(0)]
375///     field_one: String,
376///     #[ts_field(1)]
377///     field_two: i64,
378///     // ...
379/// }
380/// ```
381#[proc_macro_derive(TsResponse, attributes(response, ts_field))]
382pub fn ts_response_derive(input: TokenStream) -> TokenStream {
383    let input = parse_macro_input!(input as DeriveInput);
384    let name = &input.ident;
385
386    // Parse response options
387    let response_opts = match input
388        .attrs
389        .iter()
390        .find(|attr| attr.path().is_ident("response"))
391        .map(|attr| ResponseOpts::from_meta(&attr.meta))
392        .transpose()
393    {
394        Ok(Some(opts)) => opts,
395        Ok(None) => {
396            return syn::Error::new_spanned(
397                input.ident.clone(),
398                "Missing #[response(...)] attribute",
399            )
400            .to_compile_error()
401            .into()
402        }
403        Err(e) => return TokenStream::from(e.write_errors()),
404    };
405
406    // Extract fields for response data
407    let fields = match &input.data {
408        Data::Struct(data) => match &data.fields {
409            Fields::Named(fields) => &fields.named,
410            _ => {
411                return syn::Error::new_spanned(
412                    input.ident.clone(),
413                    "TsResponse only supports structs with named fields",
414                )
415                .to_compile_error()
416                .into()
417            }
418        },
419        _ => {
420            return syn::Error::new_spanned(input.ident.clone(), "TsResponse only supports structs")
421                .to_compile_error()
422                .into()
423        }
424    };
425
426    // Generate field parsing for the response items
427    let field_parsers = fields.iter().map(|field| {
428        let field_name = field.ident.as_ref().unwrap();
429        let field_type = &field.ty;
430
431        // Extract index from ts_field attribute
432        let mut field_index = None;
433        // Check for #[serde(default)] attribute
434        let mut has_serde_default = false;
435
436        for attr in &field.attrs {
437            if attr.path().is_ident("ts_field") {
438                match attr.meta.require_list() {
439                    Ok(nested) => {
440                        // Parse the first token in the list as a literal integer
441                        let lit: LitInt = match syn::parse2(nested.tokens.clone()) {
442                            Ok(lit) => lit,
443                            Err(e) => return e.to_compile_error(),
444                        };
445                        field_index = Some(lit.base10_parse::<usize>().unwrap());
446                    }
447                    Err(e) => return e.to_compile_error(),
448                }
449            } else if attr.path().is_ident("serde") {
450                 // Use parse_nested_meta for a more robust check
451                 let _ = attr.parse_nested_meta(|meta| {
452                    if meta.path.is_ident("default") {
453                        has_serde_default = true;
454                    }
455                    // Ignore other serde attributes like rename, skip, etc.
456                    // Need to handle potential errors within meta if attributes are complex
457                    Ok(())
458                });
459                // Ignoring potential parse_nested_meta error for simplicity for now
460            }
461        }
462
463        let index = match field_index {
464            Some(idx) => idx,
465            None => {
466                return syn::Error::new_spanned(field_name, "Missing #[ts_field(index)] attribute")
467                    .to_compile_error()
468            }
469        };
470
471        let from_value = if field_type_is_option(field_type) {
472            // Logic for Option<T>
473            quote! {
474                let #field_name = if item.len() > #index {
475                    let val = &item[#index];
476                    if val.is_null() {
477                        None
478                    } else {
479                        Some(serde_json::from_value(val.clone())?)
480                    }
481                } else {
482                    None // Treat missing index as None for Option types
483                };
484            }
485        } else if has_serde_default {
486            // Logic for non-Option<T> with #[serde(default)]
487             quote! {
488                let #field_name:#field_type = if item.len() > #index {
489                    let val = &item[#index];
490                    if val.is_null() {
491                        Default::default() // Use default if null
492                    } else {
493                        // Using unwrap_or_default() on the Result is cleaner
494                        serde_json::from_value(val.clone()).unwrap_or_default()
495                    }
496                } else {
497                    Default::default() // Use default if index out of bounds
498                };
499            }
500        } else {
501            // Logic for non-Option<T> *without* #[serde(default)]
502            quote! {
503                let #field_name = if item.len() > #index {
504                    let val = &item[#index];
505                     // Error on null for non-optional, non-default fields
506                    if val.is_null() {
507                         return Err(format!("Field '{}' at index {} is null, but type is not Option and #[serde(default)] is not specified", stringify!(#field_name), #index).into());
508                    }
509                    serde_json::from_value(val.clone())?
510                } else {
511                    return Err(format!("Field index {} out of bounds for required field '{}'", #index, stringify!(#field_name)).into());
512                };
513            }
514        };
515
516        quote! { #from_value }
517    });
518
519    // 生成字段名称列表(用于构造和获取字段名)
520    let field_names: Vec<_> = fields
521        .iter()
522        .map(|field| field.ident.as_ref().unwrap().clone())
523        .collect();
524
525    // 生成用于构造结构体的字段列表
526    let struct_field_tokens = {
527        let field_idents = &field_names;
528        quote! {
529            #(#field_idents),*
530        }
531    };
532
533    // Get API name
534    let api_name = &response_opts.api;
535
536    // Generate implementation for parsing response
537    let output = quote! {
538        impl #name {
539            /// Parse a list of items from Tushare API response
540            pub fn from_json(json: &serde_json::Value) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
541                use serde_json::Value;
542
543                // Extract data from response
544                let data = json.get("data")
545                    .ok_or_else(|| "Missing 'data' field in response")?;
546
547                let items = data.get("items")
548                    .ok_or_else(|| "Missing 'items' field in data")?
549                    .as_array()
550                    .ok_or_else(|| "'items' is not an array")?;
551
552                let mut result = Vec::with_capacity(items.len());
553
554                for item_value in items {
555                    let item = item_value.as_array()
556                        .ok_or_else(|| "Item is not an array")?;
557
558                    #(#field_parsers)*
559
560                    result.push(Self {
561                        #struct_field_tokens
562                    });
563                }
564
565                Ok(result)
566            }
567
568            /// Get the API name for this response
569            pub fn api_name() -> &'static str {
570                #api_name
571            }
572
573            /// Get field names from the response struct
574            pub fn get_field_names() -> Vec<&'static str> {
575                vec![
576                    #(stringify!(#field_names)),*
577                ]
578            }
579        }
580
581        // Implement From<Value> to allow automatic conversion from JSON
582        impl From<serde_json::Value> for #name {
583            fn from(value: serde_json::Value) -> Self {
584                // This is just a placeholder implementation to satisfy the trait bound
585                // The actual conversion is handled by the from_json method
586                panic!("Direct conversion from Value to {} is not supported, use from_json instead", stringify!(#name));
587            }
588        }
589    };
590
591    output.into()
592}
593
594/// Check if a type is an Option<T>
595fn field_type_is_option(ty: &Type) -> bool {
596    if let Type::Path(type_path) = ty {
597        if let Some(segment) = type_path.path.segments.first() {
598            return segment.ident == "Option";
599        }
600    }
601    false
602}