1#![allow(dead_code)]
2
3use darling::FromMeta;
4use proc_macro::TokenStream;
5use quote::quote;
6use syn::{parse_macro_input, Data, DeriveInput, Fields, LitInt, Type};
7
8#[derive(Debug, FromMeta)]
10struct EndpointOpts {
11 api: String,
13 desc: String,
15 #[darling(default)]
17 resp: Option<syn::Path>,
18}
19
20#[derive(Debug, FromMeta)]
22struct ResponseOpts {
23 api: String,
25}
26
27#[proc_macro_derive(TsEndpoint, attributes(endpoint, fields))]
38pub fn ts_endpoint_derive(input: TokenStream) -> TokenStream {
39 let input = parse_macro_input!(input as DeriveInput);
40 let name = &input.ident;
41 let requester_name = syn::Ident::new(&format!("{}Requester", name), name.span());
42
43 let endpoint_opts = match input
45 .attrs
46 .iter()
47 .find(|attr| attr.path().is_ident("endpoint"))
48 .map(|attr| EndpointOpts::from_meta(&attr.meta))
49 .transpose()
50 {
51 Ok(Some(opts)) => opts,
52 Ok(None) => {
53 return syn::Error::new_spanned(
54 input.ident.clone(),
55 "Missing #[endpoint(...)] attribute",
56 )
57 .to_compile_error()
58 .into()
59 }
60 Err(e) => return TokenStream::from(e.write_errors()),
61 };
62
63 let fields = match &input.data {
65 Data::Struct(data) => match &data.fields {
66 Fields::Named(fields) => &fields.named,
67 _ => {
68 return syn::Error::new_spanned(
69 input.ident.clone(),
70 "TsEndpoint only supports structs with named fields",
71 )
72 .to_compile_error()
73 .into()
74 }
75 },
76 _ => {
77 return syn::Error::new_spanned(input.ident.clone(), "TsEndpoint only supports structs")
78 .to_compile_error()
79 .into()
80 }
81 };
82
83 let param_fields = fields.iter().map(|field| {
85 let field_name = field.ident.as_ref().unwrap();
86 let field_name_str = field_name.to_string();
87
88 let mut rename_value = None;
90 for attr in &field.attrs {
91 if attr.path().is_ident("serde") {
92 let _ = attr.parse_nested_meta(|meta| {
93 if meta.path.is_ident("rename") {
94 rename_value = Some(meta.value()?.parse::<syn::LitStr>()?.value());
95 }
96 Ok(())
97 });
98 }
99 }
100
101 let param_name = rename_value.unwrap_or_else(|| field_name_str.clone());
103
104 quote! {
105 params.insert(#param_name.to_string(), serde_json::to_value(&self.#field_name)?);
106 }
107 });
108
109 let api_name = &endpoint_opts.api;
111 let api_desc = &endpoint_opts.desc;
112
113 let resp_type = endpoint_opts.resp.as_ref().map(|path| quote! { #path });
115
116 let ts_requester_impl = if let Some(resp_type) = resp_type.clone() {
118 quote! {
119 pub struct #requester_name {
121 request: #name,
122 fields: Option<Vec<&'static str>>,
123 }
124
125 impl #requester_name {
126 pub fn new(request: #name, fields: Option<Vec<&'static str>>) -> Self {
127 Self { request, fields }
128 }
129
130 pub fn with_fields(mut self, fields: Vec<&'static str>) -> Self {
131 self.fields = Some(fields);
132 self
133 }
134
135 pub async fn execute(self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
136 self.request.__execute_request(self.fields).await
137 }
138
139 pub async fn execute_typed(self) -> Result<Vec<#resp_type>, Box<dyn std::error::Error>> {
140 let fields_to_use = if self.fields.is_none() {
142 let field_names = <#resp_type>::get_field_names();
144 Some(field_names)
145 } else {
146 self.fields
147 };
148
149 let json = self.request.__execute_request(fields_to_use).await?;
151 let res = <#resp_type>::from_json(&json);
152 res
153 }
154
155 pub async fn execute_as_dicts(self) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>, Box<dyn std::error::Error>> {
156 use serde_json::Value;
157 use std::collections::HashMap;
158
159 let json = self.request.__execute_request(self.fields).await?;
161
162 let data = json.get("data")
164 .ok_or("Missing 'data' field in response")?;
165
166 let fields = data.get("fields")
167 .ok_or("Missing 'fields' field in data")?
168 .as_array()
169 .ok_or("'fields' is not an array")?;
170
171 let items = data.get("items")
172 .ok_or("Missing 'items' field in data")?
173 .as_array()
174 .ok_or("'items' is not an array")?;
175
176 let mut result = Vec::with_capacity(items.len());
178
179 for item_value in items {
180 let item = item_value.as_array()
181 .ok_or("Item is not an array")?;
182
183 let mut map = HashMap::new();
184
185 for (i, field) in fields.iter().enumerate() {
187 if i < item.len() {
188 let field_name = field.as_str()
189 .ok_or("Field name is not a string")?
190 .to_string();
191
192 map.insert(field_name, item[i].clone());
193 }
194 }
195
196 result.push(map);
197 }
198
199 Ok(result)
200 }
201 }
202 }
203 } else {
204 quote! {
205 pub struct #requester_name {
207 request: #name,
208 fields: Option<Vec<&'static str>>,
209 }
210
211 impl #requester_name {
212 pub fn new(request: #name, fields: Option<Vec<&'static str>>) -> Self {
213 Self { request, fields }
214 }
215
216 pub fn with_fields(mut self, fields: Vec<&'static str>) -> Self {
217 self.fields = Some(fields);
218 self
219 }
220
221 pub async fn execute(self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
222 self.request.__execute_request(self.fields).await
223 }
224
225 pub async fn execute_as_dicts(self) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>, Box<dyn std::error::Error>> {
226 use serde_json::Value;
227 use std::collections::HashMap;
228
229 let json = self.request.__execute_request(self.fields).await?;
231
232 let data = json.get("data")
234 .ok_or("Missing 'data' field in response")?;
235
236 let fields = data.get("fields")
237 .ok_or("Missing 'fields' field in data")?
238 .as_array()
239 .ok_or("'fields' is not an array")?;
240
241 let items = data.get("items")
242 .ok_or("Missing 'items' field in data")?
243 .as_array()
244 .ok_or("'items' is not an array")?;
245
246 let mut result = Vec::with_capacity(items.len());
248
249 for item_value in items {
250 let item = item_value.as_array()
251 .ok_or("Item is not an array")?;
252
253 let mut map = HashMap::new();
254
255 for (i, field) in fields.iter().enumerate() {
257 if i < item.len() {
258 let field_name = field.as_str()
259 .ok_or("Field name is not a string")?
260 .to_string();
261
262 map.insert(field_name, item[i].clone());
263 }
264 }
265
266 result.push(map);
267 }
268
269 Ok(result)
270 }
271 }
272 }
273 };
274
275 let impl_struct = quote! {
277 impl #name {
278 pub fn api_name(&self) -> &'static str {
280 #api_name
281 }
282
283 pub fn description(&self) -> &'static str {
285 #api_desc
286 }
287
288 pub fn with_fields(self, fields: Vec<&'static str>) -> #requester_name {
290 #requester_name::new(self, Some(fields))
291 }
292
293 pub async fn execute(self) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
295 self.__execute_request(None).await
296 }
297
298 pub async fn execute_typed(self) -> Result<Vec<#resp_type>, Box<dyn std::error::Error>> {
300 let requester = #requester_name::new(self, None);
302 requester.execute_typed().await
303 }
304
305 #[doc(hidden)]
307 pub(crate) async fn __execute_request(&self, fields: Option<Vec<&str>>) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
308 use serde_json::{json, Map, Value};
309 use reqwest::Client;
310 use dotenvy::dotenv;
311 use std::env;
312
313 dotenv().ok();
315
316 let token = env::var("TUSHARE_TOKEN")
318 .map_err(|_| "TUSHARE_TOKEN environment variable not set")?;
319
320 let mut params = Map::new();
322 #(#param_fields)*
323
324 let mut request_body = Map::new();
326 request_body.insert("api_name".to_string(), Value::String(#api_name.to_string()));
327 request_body.insert("token".to_string(), Value::String(token));
328 request_body.insert("params".to_string(), Value::Object(params));
329
330 if let Some(field_list) = fields {
332 request_body.insert("fields".to_string(),
333 Value::String(field_list.join(",")));
334 }
335
336 let client = Client::new();
338 let response = client
339 .post("http://api.tushare.pro/")
340 .header("Content-Type", "application/json")
341 .body(serde_json::to_string(&Value::Object(request_body))?)
342 .send()
343 .await?;
344
345 if !response.status().is_success() {
346 return Err(format!("Request failed with status: {}", response.status()).into());
347 }
348
349 let json = response.json::<Value>().await?;
350 Ok(json)
351 }
352 }
353 };
354
355 let output = quote! {
357 #impl_struct
358 #ts_requester_impl
359 };
360
361 output.into()
362}
363
364#[proc_macro_derive(TsResponse, attributes(response, ts_field))]
382pub fn ts_response_derive(input: TokenStream) -> TokenStream {
383 let input = parse_macro_input!(input as DeriveInput);
384 let name = &input.ident;
385
386 let response_opts = match input
388 .attrs
389 .iter()
390 .find(|attr| attr.path().is_ident("response"))
391 .map(|attr| ResponseOpts::from_meta(&attr.meta))
392 .transpose()
393 {
394 Ok(Some(opts)) => opts,
395 Ok(None) => {
396 return syn::Error::new_spanned(
397 input.ident.clone(),
398 "Missing #[response(...)] attribute",
399 )
400 .to_compile_error()
401 .into()
402 }
403 Err(e) => return TokenStream::from(e.write_errors()),
404 };
405
406 let fields = match &input.data {
408 Data::Struct(data) => match &data.fields {
409 Fields::Named(fields) => &fields.named,
410 _ => {
411 return syn::Error::new_spanned(
412 input.ident.clone(),
413 "TsResponse only supports structs with named fields",
414 )
415 .to_compile_error()
416 .into()
417 }
418 },
419 _ => {
420 return syn::Error::new_spanned(input.ident.clone(), "TsResponse only supports structs")
421 .to_compile_error()
422 .into()
423 }
424 };
425
426 let field_parsers = fields.iter().map(|field| {
428 let field_name = field.ident.as_ref().unwrap();
429 let field_type = &field.ty;
430
431 let mut field_index = None;
433 let mut has_serde_default = false;
435
436 for attr in &field.attrs {
437 if attr.path().is_ident("ts_field") {
438 match attr.meta.require_list() {
439 Ok(nested) => {
440 let lit: LitInt = match syn::parse2(nested.tokens.clone()) {
442 Ok(lit) => lit,
443 Err(e) => return e.to_compile_error(),
444 };
445 field_index = Some(lit.base10_parse::<usize>().unwrap());
446 }
447 Err(e) => return e.to_compile_error(),
448 }
449 } else if attr.path().is_ident("serde") {
450 let _ = attr.parse_nested_meta(|meta| {
452 if meta.path.is_ident("default") {
453 has_serde_default = true;
454 }
455 Ok(())
458 });
459 }
461 }
462
463 let index = match field_index {
464 Some(idx) => idx,
465 None => {
466 return syn::Error::new_spanned(field_name, "Missing #[ts_field(index)] attribute")
467 .to_compile_error()
468 }
469 };
470
471 let from_value = if field_type_is_option(field_type) {
472 quote! {
474 let #field_name = if item.len() > #index {
475 let val = &item[#index];
476 if val.is_null() {
477 None
478 } else {
479 Some(serde_json::from_value(val.clone())?)
480 }
481 } else {
482 None };
484 }
485 } else if has_serde_default {
486 quote! {
488 let #field_name:#field_type = if item.len() > #index {
489 let val = &item[#index];
490 if val.is_null() {
491 Default::default() } else {
493 serde_json::from_value(val.clone()).unwrap_or_default()
495 }
496 } else {
497 Default::default() };
499 }
500 } else {
501 quote! {
503 let #field_name = if item.len() > #index {
504 let val = &item[#index];
505 if val.is_null() {
507 return Err(format!("Field '{}' at index {} is null, but type is not Option and #[serde(default)] is not specified", stringify!(#field_name), #index).into());
508 }
509 serde_json::from_value(val.clone())?
510 } else {
511 return Err(format!("Field index {} out of bounds for required field '{}'", #index, stringify!(#field_name)).into());
512 };
513 }
514 };
515
516 quote! { #from_value }
517 });
518
519 let field_names: Vec<_> = fields
521 .iter()
522 .map(|field| field.ident.as_ref().unwrap().clone())
523 .collect();
524
525 let struct_field_tokens = {
527 let field_idents = &field_names;
528 quote! {
529 #(#field_idents),*
530 }
531 };
532
533 let api_name = &response_opts.api;
535
536 let output = quote! {
538 impl #name {
539 pub fn from_json(json: &serde_json::Value) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
541 use serde_json::Value;
542
543 let data = json.get("data")
545 .ok_or_else(|| "Missing 'data' field in response")?;
546
547 let items = data.get("items")
548 .ok_or_else(|| "Missing 'items' field in data")?
549 .as_array()
550 .ok_or_else(|| "'items' is not an array")?;
551
552 let mut result = Vec::with_capacity(items.len());
553
554 for item_value in items {
555 let item = item_value.as_array()
556 .ok_or_else(|| "Item is not an array")?;
557
558 #(#field_parsers)*
559
560 result.push(Self {
561 #struct_field_tokens
562 });
563 }
564
565 Ok(result)
566 }
567
568 pub fn api_name() -> &'static str {
570 #api_name
571 }
572
573 pub fn get_field_names() -> Vec<&'static str> {
575 vec![
576 #(stringify!(#field_names)),*
577 ]
578 }
579 }
580
581 impl From<serde_json::Value> for #name {
583 fn from(value: serde_json::Value) -> Self {
584 panic!("Direct conversion from Value to {} is not supported, use from_json instead", stringify!(#name));
587 }
588 }
589 };
590
591 output.into()
592}
593
594fn field_type_is_option(ty: &Type) -> bool {
596 if let Type::Path(type_path) = ty {
597 if let Some(segment) = type_path.path.segments.first() {
598 return segment.ident == "Option";
599 }
600 }
601 false
602}