finance_query_derive/
lib.rs

1//! Derive macros for finance-query library.
2//!
3//! Provides `#[derive(ToDataFrame)]` for automatic DataFrame conversion.
4
5use proc_macro::TokenStream;
6use proc_macro2::TokenStream as TokenStream2;
7use quote::quote;
8use syn::{
9    Data, DeriveInput, Fields, GenericArgument, PathArguments, Type, TypePath, parse_macro_input,
10};
11
12/// Derive macro for automatic DataFrame conversion.
13///
14/// Generates a `to_dataframe(&self) -> PolarsResult<DataFrame>` method
15/// that converts all struct fields to DataFrame columns.
16///
17/// # Supported Types
18///
19/// - `String` → String column
20/// - `Option<String>` → nullable String column
21/// - `Option<FormattedValue<f64>>` → extracts `.raw` as `Option<f64>`
22/// - `Option<FormattedValue<i64>>` → extracts `.raw` as `Option<i64>`
23/// - `i32`, `i64`, `f64`, `bool` → direct columns
24/// - `Option<T>` for primitives → nullable columns
25/// - Nested structs/Vec → skipped (complex types not suitable for flat DataFrame)
26///
27/// # Example
28///
29/// ```ignore
30/// #[derive(ToDataFrame)]
31/// pub struct Quote {
32///     pub symbol: String,
33///     pub price: Option<FormattedValue<f64>>,
34/// }
35///
36/// // Generates:
37/// impl Quote {
38///     pub fn to_dataframe(&self) -> PolarsResult<DataFrame> {
39///         df![
40///             "symbol" => [self.symbol.as_str()],
41///             "price" => [self.price.as_ref().and_then(|v| v.raw)],
42///         ]
43///     }
44/// }
45/// ```
46#[proc_macro_derive(ToDataFrame)]
47pub fn derive_to_dataframe(input: TokenStream) -> TokenStream {
48    let input = parse_macro_input!(input as DeriveInput);
49    let name = &input.ident;
50
51    let fields = match &input.data {
52        Data::Struct(data) => match &data.fields {
53            Fields::Named(fields) => &fields.named,
54            _ => {
55                return syn::Error::new_spanned(
56                    &input,
57                    "ToDataFrame only supports structs with named fields",
58                )
59                .to_compile_error()
60                .into();
61            }
62        },
63        _ => {
64            return syn::Error::new_spanned(&input, "ToDataFrame only supports structs")
65                .to_compile_error()
66                .into();
67        }
68    };
69
70    let mut column_names: Vec<String> = Vec::new();
71    let mut column_values: Vec<TokenStream2> = Vec::new();
72
73    for field in fields.iter() {
74        let field_name = field.ident.as_ref().unwrap();
75        let field_name_str = to_snake_case(&field_name.to_string());
76        let field_type = &field.ty;
77
78        if let Some(value_expr) = generate_column_value(field_name, field_type) {
79            column_names.push(field_name_str);
80            column_values.push(value_expr);
81        }
82        // Skip fields that return None (complex nested types)
83    }
84
85    // Generate vec column value expressions (for vec_to_dataframe)
86    let mut vec_column_values: Vec<TokenStream2> = Vec::new();
87    for field in fields.iter() {
88        let field_name = field.ident.as_ref().unwrap();
89        let field_type = &field.ty;
90
91        if let Some(value_expr) = generate_vec_column_value(field_name, field_type) {
92            vec_column_values.push(value_expr);
93        }
94    }
95
96    let expanded = quote! {
97        #[cfg(feature = "dataframe")]
98        impl #name {
99            /// Converts this struct to a single-row polars DataFrame.
100            ///
101            /// All scalar fields are included as columns. Nested objects
102            /// and complex types are excluded.
103            ///
104            /// This method is auto-generated by the `ToDataFrame` derive macro.
105            pub fn to_dataframe(&self) -> ::polars::prelude::PolarsResult<::polars::prelude::DataFrame> {
106                use ::polars::prelude::*;
107                df![
108                    #( #column_names => #column_values ),*
109                ]
110            }
111
112            /// Converts a slice of structs to a multi-row polars DataFrame.
113            ///
114            /// All scalar fields are included as columns. Nested objects
115            /// and complex types are excluded.
116            ///
117            /// This method is auto-generated by the `ToDataFrame` derive macro.
118            pub fn vec_to_dataframe(items: &[Self]) -> ::polars::prelude::PolarsResult<::polars::prelude::DataFrame> {
119                use ::polars::prelude::*;
120                df![
121                    #( #column_names => #vec_column_values ),*
122                ]
123            }
124        }
125    };
126
127    TokenStream::from(expanded)
128}
129
/// Converts a field identifier into its DataFrame column name.
///
/// No case conversion is performed; the name is used as written. The one
/// normalization is removing a raw-identifier prefix, so a field declared as
/// `r#type` (whose `Ident::to_string()` is `"r#type"`) produces the column
/// `"type"`.
fn to_snake_case(s: &str) -> String {
    s.strip_prefix("r#").unwrap_or(s).to_string()
}
134
135/// Generates the value expression for a DataFrame column based on field type.
136///
137/// Returns `None` for complex types that should be skipped.
138fn generate_column_value(field_name: &syn::Ident, field_type: &Type) -> Option<TokenStream2> {
139    match field_type {
140        // Direct String
141        Type::Path(type_path) if is_string(type_path) => {
142            Some(quote! { [self.#field_name.as_str()] })
143        }
144
145        // Direct FormattedValue<T> - extract .raw
146        Type::Path(type_path) if is_formatted_value(type_path) => {
147            Some(quote! { [self.#field_name.raw] })
148        }
149
150        // Option<T>
151        Type::Path(type_path) if is_option(type_path) => {
152            let inner_type = get_option_inner_type(type_path)?;
153            generate_option_value(field_name, inner_type)
154        }
155
156        // Direct primitives: i32, i64, f64, bool
157        Type::Path(type_path) if is_primitive(type_path) => Some(quote! { [self.#field_name] }),
158
159        // Skip all other types (Vec, nested structs, etc.)
160        _ => None,
161    }
162}
163
164/// Generates the value expression for a DataFrame column when iterating over a Vec.
165///
166/// Returns `None` for complex types that should be skipped.
167fn generate_vec_column_value(field_name: &syn::Ident, field_type: &Type) -> Option<TokenStream2> {
168    match field_type {
169        // Direct String
170        Type::Path(type_path) if is_string(type_path) => {
171            Some(quote! { items.iter().map(|item| item.#field_name.as_str()).collect::<Vec<_>>() })
172        }
173
174        // Direct FormattedValue<T> - extract .raw
175        Type::Path(type_path) if is_formatted_value(type_path) => {
176            Some(quote! { items.iter().map(|item| item.#field_name.raw).collect::<Vec<_>>() })
177        }
178
179        // Option<T>
180        Type::Path(type_path) if is_option(type_path) => {
181            let inner_type = get_option_inner_type(type_path)?;
182            generate_vec_option_value(field_name, inner_type)
183        }
184
185        // Direct primitives: i32, i64, f64, bool
186        Type::Path(type_path) if is_primitive(type_path) => {
187            Some(quote! { items.iter().map(|item| item.#field_name).collect::<Vec<_>>() })
188        }
189
190        // Skip all other types (Vec, nested structs, etc.)
191        _ => None,
192    }
193}
194
195/// Generates value expression for Option<T> fields when iterating over a Vec.
196fn generate_vec_option_value(field_name: &syn::Ident, inner_type: &Type) -> Option<TokenStream2> {
197    match inner_type {
198        // Option<String>
199        Type::Path(type_path) if is_string(type_path) => Some(
200            quote! { items.iter().map(|item| item.#field_name.as_deref()).collect::<Vec<_>>() },
201        ),
202
203        // Option<FormattedValue<T>> - extract .raw
204        Type::Path(type_path) if is_formatted_value(type_path) => Some(
205            quote! { items.iter().map(|item| item.#field_name.as_ref().and_then(|v| v.raw)).collect::<Vec<_>>() },
206        ),
207
208        // Option<primitive>
209        Type::Path(type_path) if is_primitive(type_path) => {
210            Some(quote! { items.iter().map(|item| item.#field_name).collect::<Vec<_>>() })
211        }
212
213        // Skip complex Option<T> types
214        _ => None,
215    }
216}
217
218/// Generates value expression for Option<T> fields.
219fn generate_option_value(field_name: &syn::Ident, inner_type: &Type) -> Option<TokenStream2> {
220    match inner_type {
221        // Option<String>
222        Type::Path(type_path) if is_string(type_path) => {
223            Some(quote! { [self.#field_name.as_deref()] })
224        }
225
226        // Option<FormattedValue<T>> - extract .raw
227        Type::Path(type_path) if is_formatted_value(type_path) => {
228            Some(quote! { [self.#field_name.as_ref().and_then(|v| v.raw)] })
229        }
230
231        // Option<primitive>
232        Type::Path(type_path) if is_primitive(type_path) => Some(quote! { [self.#field_name] }),
233
234        // Skip complex Option<T> types
235        _ => None,
236    }
237}
238
239/// Checks if a type path is `String`.
240fn is_string(type_path: &TypePath) -> bool {
241    type_path
242        .path
243        .segments
244        .last()
245        .map(|seg| seg.ident == "String")
246        .unwrap_or(false)
247}
248
249/// Checks if a type path is `Option<T>`.
250fn is_option(type_path: &TypePath) -> bool {
251    type_path
252        .path
253        .segments
254        .last()
255        .map(|seg| seg.ident == "Option")
256        .unwrap_or(false)
257}
258
259/// Checks if a type path is `FormattedValue<T>`.
260fn is_formatted_value(type_path: &TypePath) -> bool {
261    type_path
262        .path
263        .segments
264        .last()
265        .map(|seg| seg.ident == "FormattedValue")
266        .unwrap_or(false)
267}
268
269/// Checks if a type path is a primitive type (i32, i64, f64, bool).
270fn is_primitive(type_path: &TypePath) -> bool {
271    type_path
272        .path
273        .segments
274        .last()
275        .map(|seg| {
276            let name = seg.ident.to_string();
277            matches!(
278                name.as_str(),
279                "i32" | "i64" | "f64" | "bool" | "u32" | "u64"
280            )
281        })
282        .unwrap_or(false)
283}
284
285/// Extracts the inner type from Option<T>.
286fn get_option_inner_type(type_path: &TypePath) -> Option<&Type> {
287    let segment = type_path.path.segments.last()?;
288    if segment.ident != "Option" {
289        return None;
290    }
291
292    match &segment.arguments {
293        PathArguments::AngleBracketed(args) => args.args.first().and_then(|arg| {
294            if let GenericArgument::Type(ty) = arg {
295                Some(ty)
296            } else {
297                None
298            }
299        }),
300        _ => None,
301    }
302}