1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, LitStr, Type, parse_macro_input};
5
6#[proc_macro_derive(Extract, attributes(extract))]
28pub fn derive_extract(input: TokenStream) -> TokenStream {
29 let input = parse_macro_input!(input as DeriveInput);
30 match impl_extract(&input) {
31 Ok(ts) => ts.into(),
32 Err(e) => e.to_compile_error().into(),
33 }
34}
35
36struct FieldInfo {
37 name: syn::Ident,
38 is_option: bool,
39 args: ExtractArgs,
40}
41
42fn impl_extract(input: &DeriveInput) -> syn::Result<TokenStream2> {
43 let name = &input.ident;
44 let Data::Struct(data) = &input.data else {
45 return Err(syn::Error::new_spanned(
46 input,
47 "#[derive(Extract)] only supports structs",
48 ));
49 };
50 let Fields::Named(fields) = &data.fields else {
51 return Err(syn::Error::new_spanned(
52 input,
53 "#[derive(Extract)] requires named fields",
54 ));
55 };
56
57 let field_infos: Vec<FieldInfo> = fields
58 .named
59 .iter()
60 .map(|field| {
61 Ok(FieldInfo {
62 name: field.ident.as_ref().unwrap().clone(),
63 is_option: is_option_type(&field.ty),
64 args: parse_extract_args(field)?,
65 })
66 })
67 .collect::<syn::Result<Vec<_>>>()?;
68
69 let has_llm_fallback = field_infos.iter().any(|f| f.args.llm_fallback.is_some());
70
71 let sync_extraction: Vec<TokenStream2> = field_infos
73 .iter()
74 .map(|fi| {
75 let field_name = &fi.name;
76 let css = &fi.args.css;
77 let base = quote! { element.css(#css).first() };
78 let valued = match (&fi.args.attr, &fi.args.re) {
79 (Some(attr), _) => quote! { #base.and_then(|e| e.attr(#attr)) },
80 (_, Some(re)) => quote! { #base.and_then(|e| e.re_first(#re)) },
81 _ => quote! { #base.map(|e| e.text()) },
82 };
83 let transform_expr = match fi.args.transform.as_ref().map(|t| t.value()) {
84 Some(ref t) if t == "trim" => {
85 quote! { .map(|s: String| s.trim().to_string()) }
86 }
87 Some(ref t) if t == "lowercase" => {
88 quote! { .map(|s: String| s.to_lowercase()) }
89 }
90 Some(ref t) if t == "uppercase" => {
91 quote! { .map(|s: String| s.to_uppercase()) }
92 }
93 _ => quote! {},
94 };
95 let var = quote::format_ident!("__field_{}", field_name);
96 quote! { let mut #var: Option<String> = (#valued)#transform_expr; }
97 })
98 .collect();
99
100 let llm_block = if has_llm_fallback {
102 let schema_entries: Vec<TokenStream2> = field_infos
104 .iter()
105 .filter_map(|fi| {
106 fi.args.llm_fallback.as_ref().map(|hint_opt| {
107 let field_str = fi.name.to_string();
108 let hint = hint_opt
109 .as_ref()
110 .map(|s| s.value())
111 .unwrap_or_else(|| field_str.clone());
112 quote! {
113 props.insert(
114 #field_str.to_string(),
115 ::serde_json::json!({ "type": "string", "description": #hint }),
116 );
117 }
118 })
119 })
120 .collect();
121
122 let missing_checks: Vec<TokenStream2> = field_infos
124 .iter()
125 .filter_map(|fi| {
126 if fi.args.llm_fallback.is_some() {
127 let var = quote::format_ident!("__field_{}", fi.name);
128 Some(quote! { #var.as_ref().map(|s| s.trim().is_empty()).unwrap_or(true) })
129 } else {
130 None
131 }
132 })
133 .collect();
134
135 let fill_ins: Vec<TokenStream2> = field_infos
137 .iter()
138 .filter_map(|fi| {
139 if fi.args.llm_fallback.is_some() {
140 let field_str = fi.name.to_string();
141 let var = quote::format_ident!("__field_{}", fi.name);
142 Some(quote! {
143 if #var.as_ref().map(|s| s.trim().is_empty()).unwrap_or(true) {
144 #var = __llm_json.get(#field_str)
145 .and_then(|v| v.as_str())
146 .filter(|s| !s.trim().is_empty())
147 .map(|s| s.to_string());
148 }
149 })
150 } else {
151 None
152 }
153 })
154 .collect();
155
156 quote! {
157 if #(#missing_checks)||* {
158 if let Some(__llm_client) = llm {
159 let mut props = ::serde_json::Map::new();
160 #(#schema_entries)*
161 let __schema = ::serde_json::json!({
162 "type": "object",
163 "properties": props
164 });
165 let (__llm_json, _) = __llm_client
166 .extract_json(&__schema, element.outer_html())
167 .await?;
168 #(#fill_ins)*
169 }
170 }
171 }
172 } else {
173 quote! {}
174 };
175
176 let struct_fields: Vec<TokenStream2> = field_infos
178 .iter()
179 .map(|fi| {
180 let field_name = &fi.name;
181 let var = quote::format_ident!("__field_{}", field_name);
182 if fi.is_option {
183 quote! { #field_name: #var }
184 } else if let Some(default) = &fi.args.default_val {
185 quote! { #field_name: #var.unwrap_or_else(|| #default.to_string()) }
186 } else {
187 quote! { #field_name: #var.unwrap_or_default() }
188 }
189 })
190 .collect();
191
192 Ok(quote! {
193 #[::async_trait::async_trait]
194 impl ::kumo::extract::Extract for #name {
195 async fn extract_from(
196 element: &::kumo::extract::Element,
197 llm: ::std::option::Option<&dyn ::kumo::llm::client::LlmClient>,
198 ) -> ::std::result::Result<Self, ::kumo::error::KumoError> {
199 #(#sync_extraction)*
200 #llm_block
201 ::std::result::Result::Ok(#name {
202 #(#struct_fields),*
203 })
204 }
205 }
206 })
207}
208
209struct ExtractArgs {
210 css: LitStr,
211 attr: Option<LitStr>,
212 re: Option<LitStr>,
213 llm_fallback: Option<Option<LitStr>>,
215 default_val: Option<LitStr>,
217 transform: Option<LitStr>,
219}
220
221fn parse_extract_args(field: &syn::Field) -> syn::Result<ExtractArgs> {
222 let attr = field
223 .attrs
224 .iter()
225 .find(|a| a.path().is_ident("extract"))
226 .ok_or_else(|| {
227 syn::Error::new_spanned(field, "field is missing #[extract(css = \"...\")]")
228 })?;
229
230 let mut css: Option<LitStr> = None;
231 let mut attr_val: Option<LitStr> = None;
232 let mut re_val: Option<LitStr> = None;
233 let mut llm_fallback: Option<Option<LitStr>> = None;
234 let mut default_val: Option<LitStr> = None;
235 let mut transform: Option<LitStr> = None;
236
237 attr.parse_nested_meta(|meta| {
238 if meta.path.is_ident("css") {
239 css = Some(meta.value()?.parse()?);
240 } else if meta.path.is_ident("attr") {
241 attr_val = Some(meta.value()?.parse()?);
242 } else if meta.path.is_ident("re") {
243 re_val = Some(meta.value()?.parse()?);
244 } else if meta.path.is_ident("text") {
245 } else if meta.path.is_ident("llm_fallback") {
247 if meta.input.peek(syn::Token![=]) {
248 let hint: LitStr = meta.value()?.parse()?;
249 llm_fallback = Some(Some(hint));
250 } else {
251 llm_fallback = Some(None);
252 }
253 } else if meta.path.is_ident("default") {
254 default_val = Some(meta.value()?.parse()?);
255 } else if meta.path.is_ident("transform") {
256 let lit: LitStr = meta.value()?.parse()?;
257 let val = lit.value();
258 if !matches!(val.as_str(), "trim" | "lowercase" | "uppercase") {
259 return Err(syn::Error::new(
260 lit.span(),
261 format!("unknown transform `{val}` — valid values: trim, lowercase, uppercase"),
262 ));
263 }
264 transform = Some(lit);
265 } else {
266 let key = meta
267 .path
268 .get_ident()
269 .map(|i| i.to_string())
270 .unwrap_or_default();
271 return Err(meta.error(format!(
272 "unknown extract attribute `{key}` — valid keys: css, attr, re, text, llm_fallback, default, transform"
273 )));
274 }
275 Ok(())
276 })?;
277
278 let css =
279 css.ok_or_else(|| syn::Error::new_spanned(attr, "#[extract] requires css = \"selector\""))?;
280
281 Ok(ExtractArgs {
282 css,
283 attr: attr_val,
284 re: re_val,
285 llm_fallback,
286 default_val,
287 transform,
288 })
289}
290
291fn is_option_type(ty: &Type) -> bool {
292 if let Type::Path(tp) = ty
293 && let Some(seg) = tp.path.segments.last()
294 {
295 return seg.ident == "Option";
296 }
297 false
298}