Skip to main content

openpit_derive/
lib.rs

1// Copyright The Pit Project Owners. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16// Please see https://github.com/openpitkit and the OWNERS file for details.
//! Procedural macros for the `openpit` SDK.
//!
//! This crate provides derive macros that generate the request-field capability trait
//! implementations expected by `openpit` policies.
//!
//! # `RequestFields`
//!
//! Derives capability impls for wrapper structs with named fields; enums, unions, tuple structs,
//! and unit structs are rejected with a compile-time error.
//!
//! Field-level `#[openpit(...)]` items:
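//!
//! - `inner`: marks the field used for passthrough delegation; at most one field per struct may
//!   carry it.
//! - `TraitPath(method -> ReturnType)`: generates a direct impl of `TraitPath` whose `method`
//!   forwards to `self.<field>.method()`.
//! - `TraitPath(-> ReturnType)`: same as above, with the method name inferred from a `Has*`
//!   trait name (for example, `HasOrderPrice` infers `order_price`).
//!
//! On a field marked with `inner`, trait items instead generate passthrough impls bounded by
//! `where InnerType: TraitPath`. Each trait path may be mapped at most once per struct.
//!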
//! The legacy `#[request_fields(...)]` field attribute is rejected with a compile-time error
//! that points to `#[openpit(...)]`.
37
38use proc_macro::TokenStream;
39use quote::quote;
40use syn::{
41    parenthesized, parse::Parse, parse::ParseStream, parse_macro_input, parse_quote,
42    punctuated::Punctuated, Data, DeriveInput, Field, Fields, Generics, Ident, Path, Token, Type,
43};
44
45#[proc_macro_derive(RequestFields, attributes(openpit, request_fields))]
46pub fn derive_request_fields(input: TokenStream) -> TokenStream {
47    let input = parse_macro_input!(input as DeriveInput);
48
49    match derive_request_fields_impl(input) {
50        Ok(tokens) => tokens.into(),
51        Err(err) => err.to_compile_error().into(),
52    }
53}
54
55fn derive_request_fields_impl(input: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
56    let name = input.ident;
57    let generics = input.generics;
58
59    let data = match input.data {
60        Data::Struct(data) => data,
61        _ => {
62            return Err(syn::Error::new_spanned(
63                name,
64                "RequestFields can only be derived for structs",
65            ));
66        }
67    };
68
69    let fields = match data.fields {
70        Fields::Named(fields) => fields.named,
71        _ => {
72            return Err(syn::Error::new_spanned(
73                name,
74                "RequestFields requires named fields",
75            ));
76        }
77    };
78
79    let mut generated = Vec::new();
80    let mut seen_traits = std::collections::BTreeSet::new();
81    let mut explicit_inner: Option<&Field> = None;
82
83    for field in &fields {
84        let Some(field_ident) = &field.ident else {
85            continue;
86        };
87
88        reject_legacy_request_fields(field)?;
89
90        let parsed = parse_openpit_items(field)?;
91        if !parsed.inner {
92            for capability in parsed.capabilities {
93                register_trait_once(&mut seen_traits, &capability, field)?;
94                generated.push(impl_direct_trait(
95                    &name,
96                    &generics,
97                    field_ident,
98                    &capability,
99                ));
100            }
101            continue;
102        }
103
104        if explicit_inner.is_some() {
105            return Err(syn::Error::new_spanned(
106                field,
107                "only one #[openpit(inner)] field is allowed",
108            ));
109        }
110        explicit_inner = Some(field);
111
112        for capability in parsed.capabilities {
113            register_trait_once(&mut seen_traits, &capability, field)?;
114            generated.push(impl_passthrough_trait(
115                &name,
116                &generics,
117                field_ident,
118                &field.ty,
119                &capability,
120            ));
121        }
122    }
123
124    Ok(quote! {
125        #(#generated)*
126    })
127}
128
129fn register_trait_once(
130    seen_traits: &mut std::collections::BTreeSet<String>,
131    capability: &CapabilitySpec,
132    span: &impl quote::ToTokens,
133) -> syn::Result<()> {
134    let key = quote!(#capability).to_string();
135    if !seen_traits.insert(key.clone()) {
136        return Err(syn::Error::new_spanned(
137            span,
138            format!("duplicate trait mapping for {key}"),
139        ));
140    }
141    Ok(())
142}
143
144fn reject_legacy_request_fields(field: &Field) -> syn::Result<()> {
145    for attr in &field.attrs {
146        if attr.path().is_ident("request_fields") {
147            return Err(syn::Error::new_spanned(
148                attr,
149                "legacy #[request_fields(...)] is not supported; use #[openpit(...)]",
150            ));
151        }
152    }
153    Ok(())
154}
155
156fn parse_openpit_items(field: &Field) -> syn::Result<FieldOpenpitItems> {
157    let mut result = FieldOpenpitItems {
158        inner: false,
159        capabilities: Vec::new(),
160    };
161
162    for attr in &field.attrs {
163        if !attr.path().is_ident("openpit") {
164            continue;
165        }
166
167        let items =
168            attr.parse_args_with(Punctuated::<OpenpitAttrItem, Token![,]>::parse_terminated)?;
169        if items.is_empty() {
170            return Err(syn::Error::new_spanned(
171                attr,
172                "empty #[openpit(...)] is not allowed",
173            ));
174        }
175
176        for item in items {
177            match item {
178                OpenpitAttrItem::Inner(span) => {
179                    if result.inner {
180                        return Err(syn::Error::new_spanned(
181                            span,
182                            "duplicate `inner` marker in #[openpit(...)]",
183                        ));
184                    }
185                    result.inner = true;
186                }
187                OpenpitAttrItem::Capability(spec) => result.capabilities.push(*spec),
188            }
189        }
190    }
191
192    Ok(result)
193}
194
195struct FieldOpenpitItems {
196    inner: bool,
197    capabilities: Vec<CapabilitySpec>,
198}
199
200enum OpenpitAttrItem {
201    Inner(Ident),
202    Capability(Box<CapabilitySpec>),
203}
204
205impl Parse for OpenpitAttrItem {
206    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
207        let path = input.parse::<Path>()?;
208        if path.is_ident("inner") {
209            if !input.is_empty() && !input.peek(Token![,]) {
210                return Err(input.error("`inner` must not have arguments"));
211            }
212            let ident = path
213                .get_ident()
214                .expect("inner path must have one segment")
215                .clone();
216            return Ok(OpenpitAttrItem::Inner(ident));
217        }
218
219        if !input.peek(syn::token::Paren) {
220            return Err(syn::Error::new_spanned(
221                path,
222                "invalid #[openpit(...)] item; expected `Trait(method -> ReturnType)` or `Trait(-> ReturnType)`",
223            ));
224        }
225
226        let content;
227        parenthesized!(content in input);
228
229        let method_ident = if content.peek(Token![->]) {
230            content.parse::<Token![->]>()?;
231            infer_method_from_trait_path(&path)?
232        } else {
233            let method = content.parse::<Ident>()?;
234            content.parse::<Token![->]>()?;
235            method
236        };
237        let return_ty = content.parse::<Type>()?;
238
239        if !content.is_empty() {
240            return Err(content.error("unexpected tokens in trait signature"));
241        }
242
243        Ok(OpenpitAttrItem::Capability(Box::new(CapabilitySpec {
244            trait_path: path,
245            method_ident,
246            return_ty,
247        })))
248    }
249}
250
251#[derive(Clone)]
252struct CapabilitySpec {
253    trait_path: Path,
254    method_ident: Ident,
255    return_ty: Type,
256}
257
258impl quote::ToTokens for CapabilitySpec {
259    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
260        let trait_path = &self.trait_path;
261        trait_path.to_tokens(tokens);
262    }
263}
264
265fn infer_method_from_trait_path(path: &Path) -> syn::Result<Ident> {
266    let Some(segment) = path.segments.last() else {
267        return Err(syn::Error::new_spanned(
268            path,
269            "trait path must have at least one segment",
270        ));
271    };
272
273    let trait_name = segment.ident.to_string();
274    let Some(stripped) = trait_name.strip_prefix("Has") else {
275        return Err(syn::Error::new_spanned(
276            &segment.ident,
277            "method inference requires a `Has*` trait name",
278        ));
279    };
280    if stripped.is_empty() {
281        return Err(syn::Error::new_spanned(
282            &segment.ident,
283            "trait name `Has` does not contain a method stem",
284        ));
285    }
286
287    let mut snake = String::new();
288    for (idx, ch) in stripped.chars().enumerate() {
289        if ch.is_uppercase() {
290            if idx > 0 {
291                snake.push('_');
292            }
293            for lower in ch.to_lowercase() {
294                snake.push(lower);
295            }
296        } else {
297            snake.push(ch);
298        }
299    }
300
301    Ok(Ident::new(&snake, segment.ident.span()))
302}
303
304fn impl_direct_trait(
305    name: &Ident,
306    generics: &Generics,
307    field_ident: &Ident,
308    capability: &CapabilitySpec,
309) -> proc_macro2::TokenStream {
310    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
311    let trait_path = &capability.trait_path;
312    let method_ident = &capability.method_ident;
313    let return_ty = &capability.return_ty;
314
315    quote! {
316        impl #impl_generics #trait_path for #name #ty_generics #where_clause {
317            fn #method_ident(&self) -> #return_ty {
318                self.#field_ident.#method_ident()
319            }
320        }
321    }
322}
323
324fn impl_passthrough_trait(
325    name: &Ident,
326    generics: &Generics,
327    inner_field_ident: &Ident,
328    inner_ty: &Type,
329    capability: &CapabilitySpec,
330) -> proc_macro2::TokenStream {
331    let trait_path = &capability.trait_path;
332    let method_ident = &capability.method_ident;
333    let return_ty = &capability.return_ty;
334
335    let mut impl_generics = generics.clone();
336    impl_generics
337        .make_where_clause()
338        .predicates
339        .push(parse_quote!(#inner_ty: #trait_path));
340    let (impl_generics, ty_generics, where_clause) = impl_generics.split_for_impl();
341
342    quote! {
343        impl #impl_generics #trait_path for #name #ty_generics #where_clause {
344            fn #method_ident(&self) -> #return_ty {
345                <#inner_ty as #trait_path>::#method_ident(&self.#inner_field_ident)
346            }
347        }
348    }
349}
350
#[cfg(test)]
mod tests {
    use quote::quote;
    use syn::punctuated::Punctuated;
    use syn::{parse_quote, parse_str, Data, DeriveInput, Field, Fields, Path};

    use super::{
        derive_request_fields_impl, infer_method_from_trait_path, parse_openpit_items,
        register_trait_once, CapabilitySpec, OpenpitAttrItem,
    };

    /// Removes the identifier of the first named field to simulate a malformed AST.
    ///
    /// Returns `true` when the identifier was cleared. Returns `false`, leaving `input`
    /// untouched, when `input` is not a struct with named fields.
    fn clear_first_named_field_ident(input: &mut DeriveInput) -> bool {
        let Data::Struct(data) = &mut input.data else {
            return false;
        };
        let Fields::Named(fields) = &mut data.fields else {
            return false;
        };
        fields.named[0].ident = None;
        true
    }

    /// Returns the method name of a capability item, or `None` for the `inner` marker.
    fn capability_method_name(item: OpenpitAttrItem) -> Option<String> {
        match item {
            OpenpitAttrItem::Inner(_) => None,
            OpenpitAttrItem::Capability(spec) => Some(spec.method_ident.to_string()),
        }
    }

    // Method-name inference.

    #[test]
    fn infer_method_from_has_trait_converts_to_snake_case() {
        let trait_path: Path = parse_quote!(crate::HasOrderPrice);
        let inferred = infer_method_from_trait_path(&trait_path).expect("inference must succeed");
        assert_eq!(inferred.to_string(), "order_price");
    }

    #[test]
    fn infer_method_from_trait_rejects_non_has_prefix() {
        let trait_path: Path = parse_quote!(crate::TraitWithoutPrefix);
        let error = infer_method_from_trait_path(&trait_path)
            .expect_err("must reject trait without Has");
        assert_eq!(
            error.to_string(),
            "method inference requires a `Has*` trait name"
        );
    }

    #[test]
    fn infer_method_from_has_rejects_empty_stem() {
        let trait_path: Path = parse_quote!(Has);
        let error = infer_method_from_trait_path(&trait_path)
            .expect_err("empty method stem must reject");
        assert_eq!(
            error.to_string(),
            "trait name `Has` does not contain a method stem"
        );
    }

    #[test]
    fn infer_method_from_empty_path_rejects() {
        let empty = Path {
            leading_colon: None,
            segments: Punctuated::new(),
        };
        let error = infer_method_from_trait_path(&empty).expect_err("empty path must reject");
        assert_eq!(error.to_string(), "trait path must have at least one segment");
    }

    // Field attribute parsing.

    #[test]
    fn parse_openpit_items_rejects_empty_attribute() {
        let field: Field = parse_quote! {
            #[openpit()]
            operation: Operation
        };
        let Err(error) = parse_openpit_items(&field) else {
            panic!("empty attribute must reject");
        };
        assert_eq!(error.to_string(), "empty #[openpit(...)] is not allowed");
    }

    #[test]
    fn parse_openpit_items_rejects_duplicate_inner_marker() {
        let field: Field = parse_quote! {
            #[openpit(inner, inner)]
            operation: Operation
        };
        let Err(error) = parse_openpit_items(&field) else {
            panic!("duplicate inner must reject");
        };
        assert_eq!(
            error.to_string(),
            "duplicate `inner` marker in #[openpit(...)]"
        );
    }

    #[test]
    fn parse_openpit_items_parses_inner_and_capabilities() {
        let field: Field = parse_quote! {
            #[openpit(inner, crate::HasPnl(-> Result<Pnl, RequestFieldAccessError>))]
            operation: Operation
        };
        let parsed = parse_openpit_items(&field).expect("must parse valid attribute");
        assert!(parsed.inner);

        let [capability] = parsed.capabilities.as_slice() else {
            panic!(
                "expected exactly one capability, got {}",
                parsed.capabilities.len()
            );
        };
        let trait_path = &capability.trait_path;
        assert_eq!(quote!(#trait_path).to_string(), "crate :: HasPnl");
        assert_eq!(capability.method_ident.to_string(), "pnl");
    }

    #[test]
    fn parse_openpit_items_ignores_non_openpit_attributes() {
        let field: Field = parse_quote! {
            #[serde(default)]
            operation: Operation
        };
        let parsed = parse_openpit_items(&field).expect("must ignore non-openpit attributes");
        assert!(!parsed.inner);
        assert!(parsed.capabilities.is_empty());
    }

    // Single attribute item parsing.

    #[test]
    fn parse_openpit_attr_item_parses_inferred_method_signature() {
        let source = "HasPnl(-> Result<Pnl, RequestFieldAccessError>)";
        let item: OpenpitAttrItem = parse_str(source).expect("must parse inferred signature");
        assert_eq!(capability_method_name(item).as_deref(), Some("pnl"));
    }

    #[test]
    fn parse_openpit_attr_item_parses_explicit_method_signature() {
        let source = "HasInstrument(instrument -> Result<&Instrument, RequestFieldAccessError>)";
        let item: OpenpitAttrItem = parse_str(source).expect("must parse explicit signature");
        assert_eq!(capability_method_name(item).as_deref(), Some("instrument"));
    }

    #[test]
    fn parse_openpit_attr_item_parses_inner_marker() {
        let item: OpenpitAttrItem = parse_str("inner").expect("must parse inner marker");
        assert_eq!(capability_method_name(item), None);
    }

    // Duplicate detection.

    #[test]
    fn register_trait_once_rejects_duplicates() {
        let mut seen = std::collections::BTreeSet::new();
        let capability = CapabilitySpec {
            trait_path: parse_quote!(crate::HasInstrument),
            method_ident: parse_quote!(instrument),
            return_ty: parse_quote!(Result<&Instrument, RequestFieldAccessError>),
        };

        register_trait_once(&mut seen, &capability, &capability)
            .expect("first mapping must register");
        let error = register_trait_once(&mut seen, &capability, &capability)
            .expect_err("duplicate mapping must reject");
        assert_eq!(
            error.to_string(),
            "duplicate trait mapping for crate :: HasInstrument"
        );
    }

    // Whole-derive expansion and the malformed-AST helper.

    #[test]
    fn derive_skips_field_without_ident_when_ast_is_malformed() {
        let mut input: DeriveInput = parse_quote! {
            struct Wrapper {
                operation: Operation,
            }
        };
        assert!(clear_first_named_field_ident(&mut input));

        let generated =
            derive_request_fields_impl(input).expect("malformed field without ident is skipped");
        assert!(generated.is_empty());
    }

    #[test]
    fn clear_first_named_field_ident_returns_false_for_non_struct() {
        let mut input: DeriveInput = parse_quote! {
            enum Wrapper {
                A,
            }
        };
        assert!(!clear_first_named_field_ident(&mut input));
    }

    #[test]
    fn clear_first_named_field_ident_returns_false_for_unnamed_struct() {
        let mut input: DeriveInput = parse_quote! {
            struct Wrapper(u64);
        };
        assert!(!clear_first_named_field_ident(&mut input));
    }

    #[test]
    fn derive_request_fields_impl_generates_passthrough_for_inner_capability() {
        let input: DeriveInput = parse_quote! {
            struct Wrapper<T> {
                #[openpit(inner, HasPnl(-> Result<Pnl, RequestFieldAccessError>))]
                inner: T,
            }
        };

        let rendered = derive_request_fields_impl(input)
            .expect("derive generation must succeed")
            .to_string();
        for expected in [
            "impl < T > HasPnl for Wrapper < T > where T : HasPnl",
            "< T as HasPnl > :: pnl",
            "& self . inner",
        ] {
            assert!(
                rendered.contains(expected),
                "missing `{expected}` in {rendered}"
            );
        }
    }
}