Skip to main content

ordo_derive/
lib.rs

//! Procedural macros for the Ordo rule engine.
//!
//! This crate provides derive macros for generating `TypedContext` implementations
//! that enable zero-overhead field access in JIT-compiled code.
//!
//! # Usage
//!
//! ```ignore
//! use ordo_derive::TypedContext;
//!
//! #[derive(TypedContext)]
//! #[repr(C)]  // Recommended for stable layout
//! pub struct LoanContext {
//!     pub amount: f64,
//!     pub credit_score: i32,
//!     pub approved: bool,
//! }
//! ```
//!
//! The macro generates an implementation of `TypedContext` that provides:
//! - A lazily built static `MessageSchema` describing the struct layout
//! - Direct field pointer access via `field_ptr()` for top-level field names
//! - Schemas for nested struct fields (as `FieldType::Message`), which can be
//!   used to resolve nested field paths
24
25use proc_macro::TokenStream;
26use quote::quote;
27use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
28
29/// Derive macro for generating TypedContext implementations
30///
31/// This macro generates code that enables zero-overhead field access
32/// in JIT-compiled expressions.
33///
34/// # Requirements
35///
36/// - Struct must have named fields
37/// - For optimal performance, use `#[repr(C)]` to ensure stable layout
38/// - Supported field types: bool, i32, i64, u32, u64, f32, f64
39///
40/// # Generated Code
41///
42/// The macro generates:
43/// 1. A static `MessageSchema` with field offsets computed at compile time
44/// 2. `TypedContext::schema()` returning the schema
45/// 3. `TypedContext::field_ptr()` for direct field access
46///
47/// # Example
48///
49/// ```ignore
50/// #[derive(TypedContext)]
51/// #[repr(C)]
52/// pub struct MyContext {
53///     pub value: f64,      // offset: 0
54///     pub count: i32,      // offset: 8
55///     pub active: bool,    // offset: 12
56/// }
57/// ```
58#[proc_macro_derive(TypedContext, attributes(typed_context))]
59pub fn derive_typed_context(input: TokenStream) -> TokenStream {
60    let input = parse_macro_input!(input as DeriveInput);
61
62    let name = &input.ident;
63    let name_str = name.to_string();
64
65    // Extract struct fields
66    let fields = match &input.data {
67        Data::Struct(data) => match &data.fields {
68            Fields::Named(fields) => &fields.named,
69            _ => {
70                return syn::Error::new_spanned(
71                    &input,
72                    "TypedContext can only be derived for structs with named fields",
73                )
74                .to_compile_error()
75                .into();
76            }
77        },
78        _ => {
79            return syn::Error::new_spanned(&input, "TypedContext can only be derived for structs")
80                .to_compile_error()
81                .into();
82        }
83    };
84
85    // Generate field schema entries and match arms
86    let mut schema_fields = Vec::new();
87    let mut match_arms = Vec::new();
88
89    for field in fields.iter() {
90        let field_name = field.ident.as_ref().unwrap();
91        let field_name_str = field_name.to_string();
92        let field_type = &field.ty;
93
94        // Map Rust type to FieldType
95        let field_type_expr = rust_type_to_field_type(field_type);
96
97        // Generate schema field entry
98        schema_fields.push(quote! {
99            ordo_core::context::FieldSchema::new(
100                #field_name_str,
101                #field_type_expr,
102                // Use memoffset-style offset calculation
103                {
104                    let uninit = ::std::mem::MaybeUninit::<#name>::uninit();
105                    let base_ptr = uninit.as_ptr();
106                    let field_ptr = unsafe { ::std::ptr::addr_of!((*base_ptr).#field_name) };
107                    (field_ptr as usize) - (base_ptr as usize)
108                },
109            )
110        });
111
112        // Generate match arm for field_ptr
113        match_arms.push(quote! {
114            #field_name_str => ::std::option::Option::Some((
115                ::std::ptr::addr_of!(self.#field_name) as *const u8,
116                #field_type_expr,
117            ))
118        });
119    }
120
121    // Generate the implementation
122    let expanded = quote! {
123        impl ordo_core::expr::jit::TypedContext for #name {
124            fn schema() -> &'static ordo_core::context::MessageSchema {
125                use ::std::sync::OnceLock;
126
127                static SCHEMA: OnceLock<ordo_core::context::MessageSchema> = OnceLock::new();
128                SCHEMA.get_or_init(|| {
129                    ordo_core::context::MessageSchema::new(
130                        #name_str,
131                        vec![
132                            #(#schema_fields,)*
133                        ],
134                    )
135                })
136            }
137
138            unsafe fn field_ptr(
139                &self,
140                field_name: &str,
141            ) -> ::std::option::Option<(*const u8, ordo_core::context::FieldType)> {
142                match field_name {
143                    #(#match_arms,)*
144                    _ => ::std::option::Option::None,
145                }
146            }
147        }
148    };
149
150    TokenStream::from(expanded)
151}
152
153/// Convert a Rust type to the corresponding FieldType expression
154fn rust_type_to_field_type(ty: &Type) -> proc_macro2::TokenStream {
155    let type_str = quote!(#ty).to_string().replace(' ', "");
156
157    match type_str.as_str() {
158        "bool" => quote!(ordo_core::context::FieldType::Bool),
159        "i32" => quote!(ordo_core::context::FieldType::Int32),
160        "i64" => quote!(ordo_core::context::FieldType::Int64),
161        "u32" => quote!(ordo_core::context::FieldType::UInt32),
162        "u64" => quote!(ordo_core::context::FieldType::UInt64),
163        "f32" => quote!(ordo_core::context::FieldType::Float32),
164        "f64" => quote!(ordo_core::context::FieldType::Float64),
165        "String" | "::std::string::String" | "std::string::String" => {
166            quote!(ordo_core::context::FieldType::String)
167        }
168        "Vec<u8>" | "::std::vec::Vec<u8>" => {
169            quote!(ordo_core::context::FieldType::Bytes)
170        }
171        _ => {
172            // For unknown types, try to treat as a nested message
173            // This requires the nested type to also implement TypedContext
174            quote! {
175                ordo_core::context::FieldType::Message(
176                    ::std::sync::Arc::new(<#ty as ordo_core::expr::jit::TypedContext>::schema().clone())
177                )
178            }
179        }
180    }
181}
182
183/// Derive macro for generating TypedContext for prost-generated types
184///
185/// This is similar to TypedContext but specifically handles prost attributes
186/// to extract proto tag numbers.
187#[proc_macro_derive(ProstTypedContext, attributes(prost))]
188pub fn derive_prost_typed_context(input: TokenStream) -> TokenStream {
189    let input = parse_macro_input!(input as DeriveInput);
190
191    let name = &input.ident;
192    let name_str = name.to_string();
193
194    // Extract struct fields
195    let fields = match &input.data {
196        Data::Struct(data) => match &data.fields {
197            Fields::Named(fields) => &fields.named,
198            _ => {
199                return syn::Error::new_spanned(
200                    &input,
201                    "ProstTypedContext can only be derived for structs with named fields",
202                )
203                .to_compile_error()
204                .into();
205            }
206        },
207        _ => {
208            return syn::Error::new_spanned(
209                &input,
210                "ProstTypedContext can only be derived for structs",
211            )
212            .to_compile_error()
213            .into();
214        }
215    };
216
217    // Generate field schema entries and match arms
218    let mut schema_fields = Vec::new();
219    let mut match_arms = Vec::new();
220
221    for field in fields.iter() {
222        let field_name = field.ident.as_ref().unwrap();
223        let field_name_str = field_name.to_string();
224        let field_type = &field.ty;
225
226        // Extract proto tag from #[prost(..., tag = "N")] if present
227        let proto_tag = extract_prost_tag(&field.attrs);
228
229        // Map Rust type to FieldType
230        let field_type_expr = rust_type_to_field_type(field_type);
231
232        // Generate schema field entry with proto tag
233        let schema_field = if let Some(tag) = proto_tag {
234            quote! {
235                ordo_core::context::FieldSchema::new(
236                    #field_name_str,
237                    #field_type_expr,
238                    {
239                        let uninit = ::std::mem::MaybeUninit::<#name>::uninit();
240                        let base_ptr = uninit.as_ptr();
241                        let field_ptr = unsafe { ::std::ptr::addr_of!((*base_ptr).#field_name) };
242                        (field_ptr as usize) - (base_ptr as usize)
243                    },
244                ).with_proto_tag(#tag)
245            }
246        } else {
247            quote! {
248                ordo_core::context::FieldSchema::new(
249                    #field_name_str,
250                    #field_type_expr,
251                    {
252                        let uninit = ::std::mem::MaybeUninit::<#name>::uninit();
253                        let base_ptr = uninit.as_ptr();
254                        let field_ptr = unsafe { ::std::ptr::addr_of!((*base_ptr).#field_name) };
255                        (field_ptr as usize) - (base_ptr as usize)
256                    },
257                )
258            }
259        };
260
261        schema_fields.push(schema_field);
262
263        // Generate match arm for field_ptr
264        match_arms.push(quote! {
265            #field_name_str => ::std::option::Option::Some((
266                ::std::ptr::addr_of!(self.#field_name) as *const u8,
267                #field_type_expr,
268            ))
269        });
270    }
271
272    // Generate the implementation
273    let expanded = quote! {
274        impl ordo_core::expr::jit::TypedContext for #name {
275            fn schema() -> &'static ordo_core::context::MessageSchema {
276                use ::std::sync::OnceLock;
277
278                static SCHEMA: OnceLock<ordo_core::context::MessageSchema> = OnceLock::new();
279                SCHEMA.get_or_init(|| {
280                    ordo_core::context::MessageSchema::new(
281                        #name_str,
282                        vec![
283                            #(#schema_fields,)*
284                        ],
285                    )
286                })
287            }
288
289            unsafe fn field_ptr(
290                &self,
291                field_name: &str,
292            ) -> ::std::option::Option<(*const u8, ordo_core::context::FieldType)> {
293                match field_name {
294                    #(#match_arms,)*
295                    _ => ::std::option::Option::None,
296                }
297            }
298        }
299    };
300
301    TokenStream::from(expanded)
302}
303
304/// Extract the proto tag number from prost attributes
305fn extract_prost_tag(attrs: &[syn::Attribute]) -> Option<u32> {
306    for attr in attrs {
307        if attr.path().is_ident("prost") {
308            // Parse the attribute content to find tag = "N"
309            if let Ok(syn::Meta::List(list)) = attr.parse_args::<syn::Meta>() {
310                for nested in list.tokens.clone().into_iter() {
311                    let token_str = nested.to_string();
312                    if token_str.starts_with("tag") {
313                        // Extract the number from tag = "N"
314                        if let Some(num_str) = token_str
315                            .split('=')
316                            .nth(1)
317                            .map(|s| s.trim().trim_matches('"').trim())
318                        {
319                            if let Ok(tag) = num_str.parse::<u32>() {
320                                return Some(tag);
321                            }
322                        }
323                    }
324                }
325            }
326
327            // Fallback: try to parse as a simple token stream and look for tag
328            let tokens = attr.meta.require_list().ok()?.tokens.to_string();
329            for part in tokens.split(',') {
330                let part = part.trim();
331                if part.starts_with("tag") {
332                    if let Some(num_str) = part
333                        .split('=')
334                        .nth(1)
335                        .map(|s| s.trim().trim_matches('"').trim())
336                    {
337                        if let Ok(tag) = num_str.parse::<u32>() {
338                            return Some(tag);
339                        }
340                    }
341                }
342            }
343        }
344    }
345    None
346}