Skip to main content

kumo_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, LitStr, Type, parse_macro_input};
5
6/// Derive macro that generates an [`Extract`] implementation for a struct.
7///
8/// Each field must carry `#[extract(css = "selector")]` plus optional modifiers:
9/// - `attr = "name"` — read an HTML attribute instead of text content
10/// - `re = r"pattern"` — apply a regex and take the first capture / match
11/// - `text` — explicit text content (the default; can be omitted)
12/// - `llm_fallback = "hint"` — fall back to LLM when selector returns empty
13/// - `llm_fallback` — same, using field name as the extraction hint
14///
15/// `String` fields use `unwrap_or_default()` on missing matches.
16/// `Option<String>` fields stay as `Option` (no unwrap).
17///
18/// ```rust,ignore
19/// #[derive(Extract, Serialize)]
20/// struct Book {
21///     #[extract(css = "h3 a", attr = "title")]
22///     title: String,
23///     #[extract(css = ".price_color", llm_fallback = "the price in GBP")]
24///     price: String,
25/// }
26/// ```
27#[proc_macro_derive(Extract, attributes(extract))]
28pub fn derive_extract(input: TokenStream) -> TokenStream {
29    let input = parse_macro_input!(input as DeriveInput);
30    match impl_extract(&input) {
31        Ok(ts) => ts.into(),
32        Err(e) => e.to_compile_error().into(),
33    }
34}
35
36struct FieldInfo {
37    name: syn::Ident,
38    is_option: bool,
39    args: ExtractArgs,
40}
41
42fn impl_extract(input: &DeriveInput) -> syn::Result<TokenStream2> {
43    let name = &input.ident;
44    let Data::Struct(data) = &input.data else {
45        return Err(syn::Error::new_spanned(
46            input,
47            "#[derive(Extract)] only supports structs",
48        ));
49    };
50    let Fields::Named(fields) = &data.fields else {
51        return Err(syn::Error::new_spanned(
52            input,
53            "#[derive(Extract)] requires named fields",
54        ));
55    };
56
57    let field_infos: Vec<FieldInfo> = fields
58        .named
59        .iter()
60        .map(|field| {
61            Ok(FieldInfo {
62                name: field.ident.as_ref().unwrap().clone(),
63                is_option: is_option_type(&field.ty),
64                args: parse_extract_args(field)?,
65            })
66        })
67        .collect::<syn::Result<Vec<_>>>()?;
68
69    let has_llm_fallback = field_infos.iter().any(|f| f.args.llm_fallback.is_some());
70
71    // Generate per-field sync extraction (as Option<String> for everything).
72    let sync_extraction: Vec<TokenStream2> = field_infos
73        .iter()
74        .map(|fi| {
75            let field_name = &fi.name;
76            let css = &fi.args.css;
77            let base = quote! { element.css(#css).first() };
78            let valued = match (&fi.args.attr, &fi.args.re) {
79                (Some(attr), _) => quote! { #base.and_then(|e| e.attr(#attr)) },
80                (_, Some(re)) => quote! { #base.and_then(|e| e.re_first(#re)) },
81                _ => quote! { #base.map(|e| e.text()) },
82            };
83            let transform_expr = match fi.args.transform.as_ref().map(|t| t.value()) {
84                Some(ref t) if t == "trim" => {
85                    quote! { .map(|s: String| s.trim().to_string()) }
86                }
87                Some(ref t) if t == "lowercase" => {
88                    quote! { .map(|s: String| s.to_lowercase()) }
89                }
90                Some(ref t) if t == "uppercase" => {
91                    quote! { .map(|s: String| s.to_uppercase()) }
92                }
93                _ => quote! {},
94            };
95            let var = quote::format_ident!("__field_{}", field_name);
96            quote! { let mut #var: Option<String> = (#valued)#transform_expr; }
97        })
98        .collect();
99
100    // Generate LLM fallback block (only if any field has llm_fallback).
101    let llm_block = if has_llm_fallback {
102        // Build the schema properties entries for all llm_fallback fields.
103        let schema_entries: Vec<TokenStream2> = field_infos
104            .iter()
105            .filter_map(|fi| {
106                fi.args.llm_fallback.as_ref().map(|hint_opt| {
107                    let field_str = fi.name.to_string();
108                    let hint = hint_opt
109                        .as_ref()
110                        .map(|s| s.value())
111                        .unwrap_or_else(|| field_str.clone());
112                    quote! {
113                        props.insert(
114                            #field_str.to_string(),
115                            ::serde_json::json!({ "type": "string", "description": #hint }),
116                        );
117                    }
118                })
119            })
120            .collect();
121
122        // Generate the missing-check condition.
123        let missing_checks: Vec<TokenStream2> = field_infos
124            .iter()
125            .filter_map(|fi| {
126                if fi.args.llm_fallback.is_some() {
127                    let var = quote::format_ident!("__field_{}", fi.name);
128                    Some(quote! { #var.as_ref().map(|s| s.trim().is_empty()).unwrap_or(true) })
129                } else {
130                    None
131                }
132            })
133            .collect();
134
135        // Generate the fill-in assignments after the LLM call.
136        let fill_ins: Vec<TokenStream2> = field_infos
137            .iter()
138            .filter_map(|fi| {
139                if fi.args.llm_fallback.is_some() {
140                    let field_str = fi.name.to_string();
141                    let var = quote::format_ident!("__field_{}", fi.name);
142                    Some(quote! {
143                        if #var.as_ref().map(|s| s.trim().is_empty()).unwrap_or(true) {
144                            #var = __llm_json.get(#field_str)
145                                .and_then(|v| v.as_str())
146                                .filter(|s| !s.trim().is_empty())
147                                .map(|s| s.to_string());
148                        }
149                    })
150                } else {
151                    None
152                }
153            })
154            .collect();
155
156        quote! {
157            if #(#missing_checks)||* {
158                if let Some(__llm_client) = llm {
159                    let mut props = ::serde_json::Map::new();
160                    #(#schema_entries)*
161                    let __schema = ::serde_json::json!({
162                        "type": "object",
163                        "properties": props
164                    });
165                    let (__llm_json, _) = __llm_client
166                        .extract_json(&__schema, element.outer_html())
167                        .await?;
168                    #(#fill_ins)*
169                }
170            }
171        }
172    } else {
173        quote! {}
174    };
175
176    // Generate struct construction expressions.
177    let struct_fields: Vec<TokenStream2> = field_infos
178        .iter()
179        .map(|fi| {
180            let field_name = &fi.name;
181            let var = quote::format_ident!("__field_{}", field_name);
182            if fi.is_option {
183                quote! { #field_name: #var }
184            } else if let Some(default) = &fi.args.default_val {
185                quote! { #field_name: #var.unwrap_or_else(|| #default.to_string()) }
186            } else {
187                quote! { #field_name: #var.unwrap_or_default() }
188            }
189        })
190        .collect();
191
192    Ok(quote! {
193        #[::async_trait::async_trait]
194        impl ::kumo::extract::Extract for #name {
195            async fn extract_from(
196                element: &::kumo::extract::Element,
197                llm: ::std::option::Option<&dyn ::kumo::llm::client::LlmClient>,
198            ) -> ::std::result::Result<Self, ::kumo::error::KumoError> {
199                #(#sync_extraction)*
200                #llm_block
201                ::std::result::Result::Ok(#name {
202                    #(#struct_fields),*
203                })
204            }
205        }
206    })
207}
208
209struct ExtractArgs {
210    css: LitStr,
211    attr: Option<LitStr>,
212    re: Option<LitStr>,
213    /// `Some(Some(hint))` = `llm_fallback = "hint"`, `Some(None)` = bare `llm_fallback`.
214    llm_fallback: Option<Option<LitStr>>,
215    /// Fallback string for `String` fields when the selector returns empty.
216    default_val: Option<LitStr>,
217    /// Named transform: "trim", "lowercase", or "uppercase".
218    transform: Option<LitStr>,
219}
220
221fn parse_extract_args(field: &syn::Field) -> syn::Result<ExtractArgs> {
222    let attr = field
223        .attrs
224        .iter()
225        .find(|a| a.path().is_ident("extract"))
226        .ok_or_else(|| {
227            syn::Error::new_spanned(field, "field is missing #[extract(css = \"...\")]")
228        })?;
229
230    let mut css: Option<LitStr> = None;
231    let mut attr_val: Option<LitStr> = None;
232    let mut re_val: Option<LitStr> = None;
233    let mut llm_fallback: Option<Option<LitStr>> = None;
234    let mut default_val: Option<LitStr> = None;
235    let mut transform: Option<LitStr> = None;
236
237    attr.parse_nested_meta(|meta| {
238        if meta.path.is_ident("css") {
239            css = Some(meta.value()?.parse()?);
240        } else if meta.path.is_ident("attr") {
241            attr_val = Some(meta.value()?.parse()?);
242        } else if meta.path.is_ident("re") {
243            re_val = Some(meta.value()?.parse()?);
244        } else if meta.path.is_ident("text") {
245            // explicit text — no-op, it's the default
246        } else if meta.path.is_ident("llm_fallback") {
247            if meta.input.peek(syn::Token![=]) {
248                let hint: LitStr = meta.value()?.parse()?;
249                llm_fallback = Some(Some(hint));
250            } else {
251                llm_fallback = Some(None);
252            }
253        } else if meta.path.is_ident("default") {
254            default_val = Some(meta.value()?.parse()?);
255        } else if meta.path.is_ident("transform") {
256            let lit: LitStr = meta.value()?.parse()?;
257            let val = lit.value();
258            if !matches!(val.as_str(), "trim" | "lowercase" | "uppercase") {
259                return Err(syn::Error::new(
260                    lit.span(),
261                    format!("unknown transform `{val}` — valid values: trim, lowercase, uppercase"),
262                ));
263            }
264            transform = Some(lit);
265        } else {
266            let key = meta
267                .path
268                .get_ident()
269                .map(|i| i.to_string())
270                .unwrap_or_default();
271            return Err(meta.error(format!(
272                "unknown extract attribute `{key}` — valid keys: css, attr, re, text, llm_fallback, default, transform"
273            )));
274        }
275        Ok(())
276    })?;
277
278    let css =
279        css.ok_or_else(|| syn::Error::new_spanned(attr, "#[extract] requires css = \"selector\""))?;
280
281    Ok(ExtractArgs {
282        css,
283        attr: attr_val,
284        re: re_val,
285        llm_fallback,
286        default_val,
287        transform,
288    })
289}
290
291fn is_option_type(ty: &Type) -> bool {
292    if let Type::Path(tp) = ty
293        && let Some(seg) = tp.path.segments.last()
294    {
295        return seg.ident == "Option";
296    }
297    false
298}