1use proc_macro::TokenStream;
26use quote::quote;
27use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
28
29#[proc_macro_derive(TypedContext, attributes(typed_context))]
59pub fn derive_typed_context(input: TokenStream) -> TokenStream {
60 let input = parse_macro_input!(input as DeriveInput);
61
62 let name = &input.ident;
63 let name_str = name.to_string();
64
65 let fields = match &input.data {
67 Data::Struct(data) => match &data.fields {
68 Fields::Named(fields) => &fields.named,
69 _ => {
70 return syn::Error::new_spanned(
71 &input,
72 "TypedContext can only be derived for structs with named fields",
73 )
74 .to_compile_error()
75 .into();
76 }
77 },
78 _ => {
79 return syn::Error::new_spanned(&input, "TypedContext can only be derived for structs")
80 .to_compile_error()
81 .into();
82 }
83 };
84
85 let mut schema_fields = Vec::new();
87 let mut match_arms = Vec::new();
88
89 for field in fields.iter() {
90 let field_name = field.ident.as_ref().unwrap();
91 let field_name_str = field_name.to_string();
92 let field_type = &field.ty;
93
94 let field_type_expr = rust_type_to_field_type(field_type);
96
97 schema_fields.push(quote! {
99 ordo_core::context::FieldSchema::new(
100 #field_name_str,
101 #field_type_expr,
102 {
104 let uninit = ::std::mem::MaybeUninit::<#name>::uninit();
105 let base_ptr = uninit.as_ptr();
106 let field_ptr = unsafe { ::std::ptr::addr_of!((*base_ptr).#field_name) };
107 (field_ptr as usize) - (base_ptr as usize)
108 },
109 )
110 });
111
112 match_arms.push(quote! {
114 #field_name_str => ::std::option::Option::Some((
115 ::std::ptr::addr_of!(self.#field_name) as *const u8,
116 #field_type_expr,
117 ))
118 });
119 }
120
121 let expanded = quote! {
123 impl ordo_core::expr::jit::TypedContext for #name {
124 fn schema() -> &'static ordo_core::context::MessageSchema {
125 use ::std::sync::OnceLock;
126
127 static SCHEMA: OnceLock<ordo_core::context::MessageSchema> = OnceLock::new();
128 SCHEMA.get_or_init(|| {
129 ordo_core::context::MessageSchema::new(
130 #name_str,
131 vec![
132 #(#schema_fields,)*
133 ],
134 )
135 })
136 }
137
138 unsafe fn field_ptr(
139 &self,
140 field_name: &str,
141 ) -> ::std::option::Option<(*const u8, ordo_core::context::FieldType)> {
142 match field_name {
143 #(#match_arms,)*
144 _ => ::std::option::Option::None,
145 }
146 }
147 }
148 };
149
150 TokenStream::from(expanded)
151}
152
153fn rust_type_to_field_type(ty: &Type) -> proc_macro2::TokenStream {
155 let type_str = quote!(#ty).to_string().replace(' ', "");
156
157 match type_str.as_str() {
158 "bool" => quote!(ordo_core::context::FieldType::Bool),
159 "i32" => quote!(ordo_core::context::FieldType::Int32),
160 "i64" => quote!(ordo_core::context::FieldType::Int64),
161 "u32" => quote!(ordo_core::context::FieldType::UInt32),
162 "u64" => quote!(ordo_core::context::FieldType::UInt64),
163 "f32" => quote!(ordo_core::context::FieldType::Float32),
164 "f64" => quote!(ordo_core::context::FieldType::Float64),
165 "String" | "::std::string::String" | "std::string::String" => {
166 quote!(ordo_core::context::FieldType::String)
167 }
168 "Vec<u8>" | "::std::vec::Vec<u8>" => {
169 quote!(ordo_core::context::FieldType::Bytes)
170 }
171 _ => {
172 quote! {
175 ordo_core::context::FieldType::Message(
176 ::std::sync::Arc::new(<#ty as ordo_core::expr::jit::TypedContext>::schema().clone())
177 )
178 }
179 }
180 }
181}
182
183#[proc_macro_derive(ProstTypedContext, attributes(prost))]
188pub fn derive_prost_typed_context(input: TokenStream) -> TokenStream {
189 let input = parse_macro_input!(input as DeriveInput);
190
191 let name = &input.ident;
192 let name_str = name.to_string();
193
194 let fields = match &input.data {
196 Data::Struct(data) => match &data.fields {
197 Fields::Named(fields) => &fields.named,
198 _ => {
199 return syn::Error::new_spanned(
200 &input,
201 "ProstTypedContext can only be derived for structs with named fields",
202 )
203 .to_compile_error()
204 .into();
205 }
206 },
207 _ => {
208 return syn::Error::new_spanned(
209 &input,
210 "ProstTypedContext can only be derived for structs",
211 )
212 .to_compile_error()
213 .into();
214 }
215 };
216
217 let mut schema_fields = Vec::new();
219 let mut match_arms = Vec::new();
220
221 for field in fields.iter() {
222 let field_name = field.ident.as_ref().unwrap();
223 let field_name_str = field_name.to_string();
224 let field_type = &field.ty;
225
226 let proto_tag = extract_prost_tag(&field.attrs);
228
229 let field_type_expr = rust_type_to_field_type(field_type);
231
232 let schema_field = if let Some(tag) = proto_tag {
234 quote! {
235 ordo_core::context::FieldSchema::new(
236 #field_name_str,
237 #field_type_expr,
238 {
239 let uninit = ::std::mem::MaybeUninit::<#name>::uninit();
240 let base_ptr = uninit.as_ptr();
241 let field_ptr = unsafe { ::std::ptr::addr_of!((*base_ptr).#field_name) };
242 (field_ptr as usize) - (base_ptr as usize)
243 },
244 ).with_proto_tag(#tag)
245 }
246 } else {
247 quote! {
248 ordo_core::context::FieldSchema::new(
249 #field_name_str,
250 #field_type_expr,
251 {
252 let uninit = ::std::mem::MaybeUninit::<#name>::uninit();
253 let base_ptr = uninit.as_ptr();
254 let field_ptr = unsafe { ::std::ptr::addr_of!((*base_ptr).#field_name) };
255 (field_ptr as usize) - (base_ptr as usize)
256 },
257 )
258 }
259 };
260
261 schema_fields.push(schema_field);
262
263 match_arms.push(quote! {
265 #field_name_str => ::std::option::Option::Some((
266 ::std::ptr::addr_of!(self.#field_name) as *const u8,
267 #field_type_expr,
268 ))
269 });
270 }
271
272 let expanded = quote! {
274 impl ordo_core::expr::jit::TypedContext for #name {
275 fn schema() -> &'static ordo_core::context::MessageSchema {
276 use ::std::sync::OnceLock;
277
278 static SCHEMA: OnceLock<ordo_core::context::MessageSchema> = OnceLock::new();
279 SCHEMA.get_or_init(|| {
280 ordo_core::context::MessageSchema::new(
281 #name_str,
282 vec![
283 #(#schema_fields,)*
284 ],
285 )
286 })
287 }
288
289 unsafe fn field_ptr(
290 &self,
291 field_name: &str,
292 ) -> ::std::option::Option<(*const u8, ordo_core::context::FieldType)> {
293 match field_name {
294 #(#match_arms,)*
295 _ => ::std::option::Option::None,
296 }
297 }
298 }
299 };
300
301 TokenStream::from(expanded)
302}
303
304fn extract_prost_tag(attrs: &[syn::Attribute]) -> Option<u32> {
306 for attr in attrs {
307 if attr.path().is_ident("prost") {
308 if let Ok(syn::Meta::List(list)) = attr.parse_args::<syn::Meta>() {
310 for nested in list.tokens.clone().into_iter() {
311 let token_str = nested.to_string();
312 if token_str.starts_with("tag") {
313 if let Some(num_str) = token_str
315 .split('=')
316 .nth(1)
317 .map(|s| s.trim().trim_matches('"').trim())
318 {
319 if let Ok(tag) = num_str.parse::<u32>() {
320 return Some(tag);
321 }
322 }
323 }
324 }
325 }
326
327 let tokens = attr.meta.require_list().ok()?.tokens.to_string();
329 for part in tokens.split(',') {
330 let part = part.trim();
331 if part.starts_with("tag") {
332 if let Some(num_str) = part
333 .split('=')
334 .nth(1)
335 .map(|s| s.trim().trim_matches('"').trim())
336 {
337 if let Ok(tag) = num_str.parse::<u32>() {
338 return Some(tag);
339 }
340 }
341 }
342 }
343 }
344 }
345 None
346}