1use proc_macro::TokenStream;
39use quote::quote;
40use syn::{
41 parenthesized, parse::Parse, parse::ParseStream, parse_macro_input, parse_quote,
42 punctuated::Punctuated, Data, DeriveInput, Field, Fields, Generics, Ident, Path, Token, Type,
43};
44
45#[proc_macro_derive(RequestFields, attributes(openpit, request_fields))]
46pub fn derive_request_fields(input: TokenStream) -> TokenStream {
47 let input = parse_macro_input!(input as DeriveInput);
48
49 match derive_request_fields_impl(input) {
50 Ok(tokens) => tokens.into(),
51 Err(err) => err.to_compile_error().into(),
52 }
53}
54
55fn derive_request_fields_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
56 let name = input.ident;
57 let generics = input.generics;
58
59 let data = match input.data {
60 Data::Struct(data) => data,
61 _ => {
62 return Err(syn::Error::new_spanned(
63 name,
64 "RequestFields can only be derived for structs",
65 ));
66 }
67 };
68
69 let fields = match data.fields {
70 Fields::Named(fields) => fields.named,
71 _ => {
72 return Err(syn::Error::new_spanned(
73 name,
74 "RequestFields requires named fields",
75 ));
76 }
77 };
78
79 let mut generated = Vec::new();
80 let mut seen_traits = std::collections::BTreeSet::new();
81 let mut explicit_inner: Option<&Field> = None;
82
83 for field in &fields {
84 let Some(field_ident) = &field.ident else {
85 continue;
86 };
87
88 reject_legacy_request_fields(field)?;
89
90 let parsed = parse_openpit_items(field)?;
91 if !parsed.inner {
92 for capability in parsed.capabilities {
93 register_trait_once(&mut seen_traits, &capability, field)?;
94 generated.push(impl_direct_trait(
95 &name,
96 &generics,
97 field_ident,
98 &capability,
99 ));
100 }
101 continue;
102 }
103
104 if explicit_inner.is_some() {
105 return Err(syn::Error::new_spanned(
106 field,
107 "only one #[openpit(inner)] field is allowed",
108 ));
109 }
110 explicit_inner = Some(field);
111
112 for capability in parsed.capabilities {
113 register_trait_once(&mut seen_traits, &capability, field)?;
114 generated.push(impl_passthrough_trait(
115 &name,
116 &generics,
117 field_ident,
118 &field.ty,
119 &capability,
120 ));
121 }
122 }
123
124 Ok(quote! {
125 #(#generated)*
126 })
127}
128
129fn register_trait_once(
130 seen_traits: &mut std::collections::BTreeSet<String>,
131 capability: &CapabilitySpec,
132 span: &impl quote::ToTokens,
133) -> syn::Result<()> {
134 let key = quote!(#capability).to_string();
135 if !seen_traits.insert(key.clone()) {
136 return Err(syn::Error::new_spanned(
137 span,
138 format!("duplicate trait mapping for {key}"),
139 ));
140 }
141 Ok(())
142}
143
144fn reject_legacy_request_fields(field: &Field) -> syn::Result<()> {
145 for attr in &field.attrs {
146 if attr.path().is_ident("request_fields") {
147 return Err(syn::Error::new_spanned(
148 attr,
149 "legacy #[request_fields(...)] is not supported; use #[openpit(...)]",
150 ));
151 }
152 }
153 Ok(())
154}
155
156fn parse_openpit_items(field: &Field) -> syn::Result<FieldOpenpitItems> {
157 let mut result = FieldOpenpitItems {
158 inner: false,
159 capabilities: Vec::new(),
160 };
161
162 for attr in &field.attrs {
163 if !attr.path().is_ident("openpit") {
164 continue;
165 }
166
167 let items =
168 attr.parse_args_with(Punctuated::<OpenpitAttrItem, Token![,]>::parse_terminated)?;
169 if items.is_empty() {
170 return Err(syn::Error::new_spanned(
171 attr,
172 "empty #[openpit(...)] is not allowed",
173 ));
174 }
175
176 for item in items {
177 match item {
178 OpenpitAttrItem::Inner(span) => {
179 if result.inner {
180 return Err(syn::Error::new_spanned(
181 span,
182 "duplicate `inner` marker in #[openpit(...)]",
183 ));
184 }
185 result.inner = true;
186 }
187 OpenpitAttrItem::Capability(spec) => result.capabilities.push(*spec),
188 }
189 }
190 }
191
192 Ok(result)
193}
194
195struct FieldOpenpitItems {
196 inner: bool,
197 capabilities: Vec<CapabilitySpec>,
198}
199
200enum OpenpitAttrItem {
201 Inner(Ident),
202 Capability(Box<CapabilitySpec>),
203}
204
205impl Parse for OpenpitAttrItem {
206 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
207 let path = input.parse::<Path>()?;
208 if path.is_ident("inner") {
209 if !input.is_empty() && !input.peek(Token![,]) {
210 return Err(input.error("`inner` must not have arguments"));
211 }
212 let ident = path
213 .get_ident()
214 .expect("inner path must have one segment")
215 .clone();
216 return Ok(OpenpitAttrItem::Inner(ident));
217 }
218
219 if !input.peek(syn::token::Paren) {
220 return Err(syn::Error::new_spanned(
221 path,
222 "invalid #[openpit(...)] item; expected `Trait(method -> ReturnType)` or `Trait(-> ReturnType)`",
223 ));
224 }
225
226 let content;
227 parenthesized!(content in input);
228
229 let method_ident = if content.peek(Token![->]) {
230 content.parse::<Token![->]>()?;
231 infer_method_from_trait_path(&path)?
232 } else {
233 let method = content.parse::<Ident>()?;
234 content.parse::<Token![->]>()?;
235 method
236 };
237 let return_ty = content.parse::<Type>()?;
238
239 if !content.is_empty() {
240 return Err(content.error("unexpected tokens in trait signature"));
241 }
242
243 Ok(OpenpitAttrItem::Capability(Box::new(CapabilitySpec {
244 trait_path: path,
245 method_ident,
246 return_ty,
247 })))
248 }
249}
250
251#[derive(Clone)]
252struct CapabilitySpec {
253 trait_path: Path,
254 method_ident: Ident,
255 return_ty: Type,
256}
257
258impl quote::ToTokens for CapabilitySpec {
259 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
260 let trait_path = &self.trait_path;
261 trait_path.to_tokens(tokens);
262 }
263}
264
265fn infer_method_from_trait_path(path: &Path) -> syn::Result<Ident> {
266 let Some(segment) = path.segments.last() else {
267 return Err(syn::Error::new_spanned(
268 path,
269 "trait path must have at least one segment",
270 ));
271 };
272
273 let trait_name = segment.ident.to_string();
274 let Some(stripped) = trait_name.strip_prefix("Has") else {
275 return Err(syn::Error::new_spanned(
276 &segment.ident,
277 "method inference requires a `Has*` trait name",
278 ));
279 };
280 if stripped.is_empty() {
281 return Err(syn::Error::new_spanned(
282 &segment.ident,
283 "trait name `Has` does not contain a method stem",
284 ));
285 }
286
287 let mut snake = String::new();
288 for (idx, ch) in stripped.chars().enumerate() {
289 if ch.is_uppercase() {
290 if idx > 0 {
291 snake.push('_');
292 }
293 for lower in ch.to_lowercase() {
294 snake.push(lower);
295 }
296 } else {
297 snake.push(ch);
298 }
299 }
300
301 Ok(Ident::new(&snake, segment.ident.span()))
302}
303
304fn impl_direct_trait(
305 name: &Ident,
306 generics: &Generics,
307 field_ident: &Ident,
308 capability: &CapabilitySpec,
309) -> proc_macro2::TokenStream {
310 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
311 let trait_path = &capability.trait_path;
312 let method_ident = &capability.method_ident;
313 let return_ty = &capability.return_ty;
314
315 quote! {
316 impl #impl_generics #trait_path for #name #ty_generics #where_clause {
317 fn #method_ident(&self) -> #return_ty {
318 self.#field_ident.#method_ident()
319 }
320 }
321 }
322}
323
324fn impl_passthrough_trait(
325 name: &Ident,
326 generics: &Generics,
327 inner_field_ident: &Ident,
328 inner_ty: &Type,
329 capability: &CapabilitySpec,
330) -> proc_macro2::TokenStream {
331 let trait_path = &capability.trait_path;
332 let method_ident = &capability.method_ident;
333 let return_ty = &capability.return_ty;
334
335 let mut impl_generics = generics.clone();
336 impl_generics
337 .make_where_clause()
338 .predicates
339 .push(parse_quote!(#inner_ty: #trait_path));
340 let (impl_generics, ty_generics, where_clause) = impl_generics.split_for_impl();
341
342 quote! {
343 impl #impl_generics #trait_path for #name #ty_generics #where_clause {
344 fn #method_ident(&self) -> #return_ty {
345 <#inner_ty as #trait_path>::#method_ident(&self.#inner_field_ident)
346 }
347 }
348 }
349}
350
#[cfg(test)]
mod tests {
    use quote::quote;
    use syn::punctuated::Punctuated;
    use syn::{parse_quote, parse_str, Data, DeriveInput, Field, Fields, Path};

    use super::{
        derive_request_fields_impl, infer_method_from_trait_path, parse_openpit_items,
        register_trait_once, CapabilitySpec, OpenpitAttrItem,
    };

    /// Blanks the ident of the first named field, a shape `parse_quote!` cannot
    /// produce on its own.
    ///
    /// Returns `false` and leaves `input` untouched for enums, non-named
    /// structs, and structs with no fields.
    fn clear_first_named_field_ident(input: &mut DeriveInput) -> bool {
        let Data::Struct(data) = &mut input.data else {
            return false;
        };
        let Fields::Named(fields) = &mut data.fields else {
            return false;
        };
        match fields.named.iter_mut().next() {
            Some(first) => {
                first.ident = None;
                true
            }
            None => false,
        }
    }

    /// Runs the derive on `input` and returns its error message, failing the
    /// test if the derive succeeds.
    fn derive_error(input: DeriveInput) -> String {
        derive_request_fields_impl(input)
            .expect_err("derive must reject input")
            .to_string()
    }

    #[test]
    fn infer_method_from_has_trait_converts_to_snake_case() {
        let path: Path = parse_quote!(crate::HasOrderPrice);
        let method = infer_method_from_trait_path(&path).expect("inference must succeed");
        assert_eq!(method.to_string(), "order_price");
    }

    #[test]
    fn infer_method_from_trait_rejects_non_has_prefix() {
        let path: Path = parse_quote!(crate::TraitWithoutPrefix);
        let err = infer_method_from_trait_path(&path).expect_err("must reject trait without Has");
        assert_eq!(
            err.to_string(),
            "method inference requires a `Has*` trait name"
        );
    }

    #[test]
    fn infer_method_from_has_rejects_empty_stem() {
        let path: Path = parse_quote!(Has);
        let err = infer_method_from_trait_path(&path).expect_err("empty method stem must reject");
        assert_eq!(
            err.to_string(),
            "trait name `Has` does not contain a method stem"
        );
    }

    #[test]
    fn infer_method_from_empty_path_rejects() {
        let path = Path {
            leading_colon: None,
            segments: Punctuated::new(),
        };
        let err = infer_method_from_trait_path(&path).expect_err("empty path must reject");
        assert_eq!(err.to_string(), "trait path must have at least one segment");
    }

    #[test]
    fn parse_openpit_items_rejects_empty_attribute() {
        let field: Field = parse_quote!(
            #[openpit()]
            operation: Operation
        );
        let err = parse_openpit_items(&field)
            .err()
            .expect("empty attribute must reject");
        assert_eq!(err.to_string(), "empty #[openpit(...)] is not allowed");
    }

    #[test]
    fn parse_openpit_items_rejects_duplicate_inner_marker() {
        let field: Field = parse_quote!(
            #[openpit(inner, inner)]
            operation: Operation
        );
        let err = parse_openpit_items(&field)
            .err()
            .expect("duplicate inner must reject");
        assert_eq!(
            err.to_string(),
            "duplicate `inner` marker in #[openpit(...)]"
        );
    }

    #[test]
    fn parse_openpit_items_parses_inner_and_capabilities() {
        let field: Field = parse_quote!(
            #[openpit(inner, crate::HasPnl(-> Result<Pnl, RequestFieldAccessError>))]
            operation: Operation
        );
        let parsed = parse_openpit_items(&field).expect("must parse valid attribute");
        assert!(parsed.inner);
        assert_eq!(parsed.capabilities.len(), 1);
        let capability = &parsed.capabilities[0];
        let trait_path = &capability.trait_path;
        assert_eq!(quote!(#trait_path).to_string(), "crate :: HasPnl");
        assert_eq!(capability.method_ident.to_string(), "pnl");
    }

    #[test]
    fn parse_openpit_items_ignores_non_openpit_attributes() {
        let field: Field = parse_quote!(
            #[serde(default)]
            operation: Operation
        );
        let parsed = parse_openpit_items(&field).expect("must ignore non-openpit attributes");
        assert!(!parsed.inner);
        assert!(parsed.capabilities.is_empty());
    }

    #[test]
    fn register_trait_once_rejects_duplicates() {
        let mut seen = std::collections::BTreeSet::new();
        let capability = CapabilitySpec {
            trait_path: parse_quote!(crate::HasInstrument),
            method_ident: parse_quote!(instrument),
            return_ty: parse_quote!(Result<&Instrument, RequestFieldAccessError>),
        };
        register_trait_once(&mut seen, &capability, &capability)
            .expect("first mapping must register");
        let err = register_trait_once(&mut seen, &capability, &capability)
            .expect_err("duplicate mapping must reject");
        assert_eq!(
            err.to_string(),
            "duplicate trait mapping for crate :: HasInstrument"
        );
    }

    #[test]
    fn derive_skips_field_without_ident_when_ast_is_malformed() {
        let mut input: DeriveInput = parse_quote!(
            struct Wrapper {
                operation: Operation,
            }
        );
        assert!(clear_first_named_field_ident(&mut input));

        let generated =
            derive_request_fields_impl(input).expect("malformed field without ident is skipped");
        assert!(generated.is_empty());
    }

    #[test]
    fn clear_first_named_field_ident_returns_false_for_non_struct() {
        let mut input: DeriveInput = parse_quote!(
            enum Wrapper {
                A,
            }
        );
        assert!(!clear_first_named_field_ident(&mut input));
    }

    #[test]
    fn clear_first_named_field_ident_returns_false_for_unnamed_struct() {
        let mut input: DeriveInput = parse_quote!(
            struct Wrapper(u64);
        );
        assert!(!clear_first_named_field_ident(&mut input));
    }

    #[test]
    fn derive_rejects_enum() {
        let input: DeriveInput = parse_quote!(
            enum Wrapper {
                A,
            }
        );
        assert_eq!(
            derive_error(input),
            "RequestFields can only be derived for structs"
        );
    }

    #[test]
    fn derive_rejects_tuple_struct() {
        let input: DeriveInput = parse_quote!(
            struct Wrapper(u64);
        );
        assert_eq!(derive_error(input), "RequestFields requires named fields");
    }

    #[test]
    fn derive_rejects_legacy_request_fields_attribute() {
        let input: DeriveInput = parse_quote!(
            struct Wrapper {
                #[request_fields(pnl)]
                operation: Operation,
            }
        );
        assert_eq!(
            derive_error(input),
            "legacy #[request_fields(...)] is not supported; use #[openpit(...)]"
        );
    }

    #[test]
    fn derive_rejects_second_inner_field() {
        let input: DeriveInput = parse_quote!(
            struct Wrapper {
                #[openpit(inner)]
                first: First,
                #[openpit(inner)]
                second: Second,
            }
        );
        assert_eq!(
            derive_error(input),
            "only one #[openpit(inner)] field is allowed"
        );
    }

    #[test]
    fn derive_rejects_trait_mapped_on_two_fields() {
        let input: DeriveInput = parse_quote!(
            struct Wrapper {
                #[openpit(HasPnl(-> Pnl))]
                first: First,
                #[openpit(HasPnl(-> Pnl))]
                second: Second,
            }
        );
        assert_eq!(derive_error(input), "duplicate trait mapping for HasPnl");
    }

    #[test]
    fn derive_request_fields_impl_generates_direct_call_for_plain_field() {
        let input: DeriveInput = parse_quote!(
            struct Wrapper {
                #[openpit(HasPnl(-> Result<Pnl, RequestFieldAccessError>))]
                pnl: PnlSource,
            }
        );

        let generated = derive_request_fields_impl(input).expect("derive generation must succeed");
        let generated_src = generated.to_string();
        assert!(generated_src.contains("impl HasPnl for Wrapper"));
        assert!(generated_src.contains("self . pnl . pnl"));
    }

    #[test]
    fn parse_openpit_attr_item_parses_inferred_method_signature() {
        let item: OpenpitAttrItem = parse_str("HasPnl(-> Result<Pnl, RequestFieldAccessError>)")
            .expect("must parse inferred signature");
        assert_eq!(capability_method_name(item).as_deref(), Some("pnl"));
    }

    #[test]
    fn parse_openpit_attr_item_parses_explicit_method_signature() {
        let item: OpenpitAttrItem =
            parse_str("HasInstrument(instrument -> Result<&Instrument, RequestFieldAccessError>)")
                .expect("must parse explicit signature");
        assert_eq!(capability_method_name(item).as_deref(), Some("instrument"));
    }

    #[test]
    fn parse_openpit_attr_item_parses_inner_marker() {
        let item: OpenpitAttrItem = parse_str("inner").expect("must parse inner marker");
        assert_eq!(capability_method_name(item), None);
    }

    #[test]
    fn parse_openpit_attr_item_rejects_inner_with_arguments() {
        let err = parse_str::<OpenpitAttrItem>("inner(x)")
            .err()
            .expect("inner with arguments must reject");
        assert_eq!(err.to_string(), "`inner` must not have arguments");
    }

    #[test]
    fn parse_openpit_attr_item_rejects_trait_without_signature() {
        let err = parse_str::<OpenpitAttrItem>("HasPnl")
            .err()
            .expect("trait without signature must reject");
        assert_eq!(
            err.to_string(),
            "invalid #[openpit(...)] item; expected `Trait(method -> ReturnType)` or `Trait(-> ReturnType)`"
        );
    }

    #[test]
    fn parse_openpit_attr_item_rejects_trailing_signature_tokens() {
        let err = parse_str::<OpenpitAttrItem>("HasPnl(-> Pnl extra)")
            .err()
            .expect("trailing tokens must reject");
        assert_eq!(err.to_string(), "unexpected tokens in trait signature");
    }

    #[test]
    fn derive_request_fields_impl_generates_passthrough_for_inner_capability() {
        let input: DeriveInput = parse_quote!(
            struct Wrapper<T> {
                #[openpit(inner, HasPnl(-> Result<Pnl, RequestFieldAccessError>))]
                inner: T,
            }
        );

        let generated = derive_request_fields_impl(input).expect("derive generation must succeed");
        let generated_src = generated.to_string();
        assert!(generated_src.contains("impl < T > HasPnl for Wrapper < T > where T : HasPnl"));
        assert!(generated_src.contains("< T as HasPnl > :: pnl"));
        assert!(generated_src.contains("& self . inner"));
    }

    /// Returns the method name of a capability item, or `None` for `inner`.
    fn capability_method_name(item: OpenpitAttrItem) -> Option<String> {
        match item {
            OpenpitAttrItem::Capability(spec) => Some(spec.method_ident.to_string()),
            OpenpitAttrItem::Inner(_) => None,
        }
    }
}