Skip to main content

photon_ring_derive/
lib.rs

1// Copyright 2026 Photon Ring Contributors
2// SPDX-License-Identifier: Apache-2.0
3
4//! Derive macros for [`photon_ring::Pod`] and [`photon_ring::Message`].
5//!
6//! ## `Pod` derive
7//!
8//! ```ignore
//! #[repr(C)]
//! #[derive(Clone, Copy, photon_ring::Pod)]
10//! struct Quote {
11//!     price: f64,
12//!     volume: u32,
13//! }
14//! ```
15//!
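//! This generates `unsafe impl photon_ring::Pod for Quote {}` with a
//! compile-time check that every field type implements `Pod`. A derive
//! macro cannot modify the struct it is attached to, so `#[repr(C)]` and
//! `#[derive(Clone, Copy)]` are not added — write them on the struct yourself.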
19//!
20//! ## `Message` derive
21//!
22//! ```ignore
23//! #[derive(photon_ring::Message)]
24//! struct Order {
25//!     price: f64,
26//!     qty: u32,
27//!     side: Side,        // any #[repr(u8)] enum
28//!     filled: bool,
29//!     tag: Option<u32>,
30//! }
31//! ```
32//!
33//! Generates a Pod-compatible wire struct (`OrderWire`) plus `From`
34//! conversions in both directions. See [`derive_message`] for details.
35
36use proc_macro::TokenStream;
37use proc_macro2::Span;
38use quote::{format_ident, quote};
39use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericArgument, PathArguments, Type};
40
41/// Derive `Pod` for a struct.
42///
43/// Requirements:
44/// - Must be a struct (not enum or union).
45/// - All fields must implement `Pod`.
46/// - The struct will be given `#[repr(C)]` semantics (the macro adds
47///   `Clone` and `Copy` derives and the `unsafe impl Pod`).
48///
49/// # Example
50///
51/// ```ignore
52/// #[derive(photon_ring::Pod)]
53/// struct Tick {
54///     price: f64,
55///     volume: u32,
56///     _pad: u32,
57/// }
58/// ```
59#[proc_macro_derive(Pod)]
60pub fn derive_pod(input: TokenStream) -> TokenStream {
61    let input = parse_macro_input!(input as DeriveInput);
62    let name = &input.ident;
63    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
64
65    // Only structs are supported
66    let fields = match &input.data {
67        Data::Struct(s) => match &s.fields {
68            Fields::Named(f) => f.named.iter().collect::<Vec<_>>(),
69            Fields::Unnamed(f) => f.unnamed.iter().collect::<Vec<_>>(),
70            Fields::Unit => vec![],
71        },
72        _ => {
73            return syn::Error::new_spanned(&input.ident, "Pod can only be derived for structs")
74                .to_compile_error()
75                .into();
76        }
77    };
78
79    // Generate compile-time assertions that every field is Pod
80    let field_assertions = fields.iter().map(|f| {
81        let ty = &f.ty;
82        quote! {
83            const _: () = {
84                fn _assert_pod<T: photon_ring::Pod>() {}
85                fn _check() { _assert_pod::<#ty>(); }
86            };
87        }
88    });
89
90    let expanded = quote! {
91        // Compile-time field checks
92        #(#field_assertions)*
93
94        // Safety: all fields verified to be Pod via compile-time assertions above.
95        // The derive macro only applies to structs, and Pod requires that every
96        // bit pattern is valid — which holds when all fields are Pod.
97        unsafe impl #impl_generics photon_ring::Pod for #name #ty_generics #where_clause {}
98    };
99
100    TokenStream::from(expanded)
101}
102
103// ---------------------------------------------------------------------------
104// Message derive
105// ---------------------------------------------------------------------------
106
107/// Classification of a field type for wire conversion.
108enum FieldKind {
109    /// Numeric or array — passes through unchanged.
110    Passthrough,
111    /// `bool` → `u8`.
112    Bool,
113    /// `usize` → `u64`.
114    Usize,
115    /// `isize` → `i64`.
116    Isize,
117    /// `Option<T>` where T is a numeric type → `u64`.
118    /// Stores the inner type for the back-conversion cast.
119    OptionNumeric(Type),
120    /// Any other type — assumed to be a `#[repr(u8)]` enum → `u8`.
121    Enum,
122}
123
124/// Returns true if `ty` is a known Pod-passthrough numeric or float type.
125fn is_numeric(ty: &Type) -> bool {
126    if let Type::Path(p) = ty {
127        if let Some(seg) = p.path.segments.last() {
128            let id = seg.ident.to_string();
129            return matches!(
130                id.as_str(),
131                "u8" | "u16"
132                    | "u32"
133                    | "u64"
134                    | "u128"
135                    | "i8"
136                    | "i16"
137                    | "i32"
138                    | "i64"
139                    | "i128"
140                    | "f32"
141                    | "f64"
142            );
143        }
144    }
145    false
146}
147
148/// Classify a field's type into a [`FieldKind`].
149fn classify(ty: &Type) -> FieldKind {
150    match ty {
151        // Arrays `[T; N]` — passthrough (must be Pod).
152        Type::Array(_) => FieldKind::Passthrough,
153
154        Type::Path(p) => {
155            let seg = match p.path.segments.last() {
156                Some(s) => s,
157                None => return FieldKind::Enum,
158            };
159            let id = seg.ident.to_string();
160
161            match id.as_str() {
162                // Numerics — passthrough
163                "u8" | "u16" | "u32" | "u64" | "u128" | "i8" | "i16" | "i32" | "i64" | "i128"
164                | "f32" | "f64" => FieldKind::Passthrough,
165
166                "bool" => FieldKind::Bool,
167                "usize" => FieldKind::Usize,
168                "isize" => FieldKind::Isize,
169
170                "Option" => {
171                    // Extract inner type from Option<T>
172                    if let PathArguments::AngleBracketed(args) = &seg.arguments {
173                        if let Some(GenericArgument::Type(inner)) = args.args.first() {
174                            if is_numeric(inner) {
175                                return FieldKind::OptionNumeric(inner.clone());
176                            }
177                        }
178                    }
179                    // Non-numeric Option — not supported, will be caught by
180                    // the compile-time size assertion below.
181                    FieldKind::Enum
182                }
183
184                // Anything else — assume enum
185                _ => FieldKind::Enum,
186            }
187        }
188
189        _ => FieldKind::Enum,
190    }
191}
192
193/// Derive a Pod-compatible wire struct with `From` conversions.
194///
195/// Given a struct with fields that may include `bool`, `Option<numeric>`,
196/// `usize`/`isize`, and `#[repr(u8)]` enums, generates:
197///
198/// 1. **`{Name}Wire`** — a `#[repr(C)] Clone + Copy` struct with all fields
199///    converted to Pod-safe types, plus `unsafe impl Pod`.
200/// 2. **`From<Name> for {Name}Wire`** — converts the domain struct to wire.
201/// 3. **`From<{Name}Wire> for Name`** — converts the wire struct back.
202///
203/// # Field type mappings
204///
205/// | Source type | Wire type | To wire | From wire |
206/// |---|---|---|---|
207/// | `f32`, `f64`, `u8`..`u128`, `i8`..`i128` | same | passthrough | passthrough |
208/// | `usize` | `u64` | `as u64` | `as usize` |
209/// | `isize` | `i64` | `as i64` | `as isize` |
210/// | `bool` | `u8` | `if v { 1 } else { 0 }` | `v != 0` |
211/// | `Option<T>` (T numeric) | `u64` | `Some(x) => x as u64, None => 0` | `0 => None, v => Some(v as T)` |
212/// | `[T; N]` (T: Pod) | same | passthrough | passthrough |
213/// | Any other type | `u8` | `v as u8` | `transmute(v)` |
214///
215/// # Enum safety
216///
217/// Enum fields are converted via `transmute` from `u8`. The enum **must**
218/// have `#[repr(u8)]` — the macro emits a compile-time `size_of` check to
219/// enforce this.
220///
221/// # Example
222///
223/// ```ignore
224/// #[repr(u8)]
225/// #[derive(Clone, Copy)]
226/// enum Side { Buy = 0, Sell = 1 }
227///
228/// #[derive(photon_ring::Message)]
229/// struct Order {
230///     price: f64,
231///     qty: u32,
232///     side: Side,
233///     filled: bool,
234///     tag: Option<u32>,
235/// }
236/// // Generates: OrderWire, From<Order> for OrderWire, From<OrderWire> for Order
237/// ```
238#[proc_macro_derive(Message)]
239pub fn derive_message(input: TokenStream) -> TokenStream {
240    let input = parse_macro_input!(input as DeriveInput);
241    let name = &input.ident;
242    let wire_name = format_ident!("{}Wire", name);
243
244    // Only named structs are supported
245    let fields = match &input.data {
246        Data::Struct(s) => match &s.fields {
247            Fields::Named(f) => f.named.iter().collect::<Vec<_>>(),
248            _ => {
249                return syn::Error::new_spanned(
250                    &input.ident,
251                    "Message can only be derived for structs with named fields",
252                )
253                .to_compile_error()
254                .into();
255            }
256        },
257        _ => {
258            return syn::Error::new_spanned(
259                &input.ident,
260                "Message can only be derived for structs",
261            )
262            .to_compile_error()
263            .into();
264        }
265    };
266
267    let mut wire_fields = Vec::new();
268    let mut to_wire = Vec::new();
269    let mut from_wire = Vec::new();
270    let mut assertions = Vec::new();
271
272    for field in &fields {
273        let fname = field.ident.as_ref().unwrap();
274        let fty = &field.ty;
275        let kind = classify(fty);
276
277        match kind {
278            FieldKind::Passthrough => {
279                wire_fields.push(quote! { pub #fname: #fty });
280                to_wire.push(quote! { #fname: src.#fname });
281                from_wire.push(quote! { #fname: src.#fname });
282            }
283            FieldKind::Bool => {
284                wire_fields.push(quote! { pub #fname: u8 });
285                to_wire.push(quote! { #fname: if src.#fname { 1 } else { 0 } });
286                from_wire.push(quote! { #fname: src.#fname != 0 });
287            }
288            FieldKind::Usize => {
289                wire_fields.push(quote! { pub #fname: u64 });
290                to_wire.push(quote! { #fname: src.#fname as u64 });
291                from_wire.push(quote! { #fname: src.#fname as usize });
292            }
293            FieldKind::Isize => {
294                wire_fields.push(quote! { pub #fname: i64 });
295                to_wire.push(quote! { #fname: src.#fname as i64 });
296                from_wire.push(quote! { #fname: src.#fname as isize });
297            }
298            FieldKind::OptionNumeric(inner) => {
299                wire_fields.push(quote! { pub #fname: u64 });
300                to_wire.push(quote! {
301                    #fname: match src.#fname {
302                        Some(v) => v as u64,
303                        None => 0,
304                    }
305                });
306                from_wire.push(quote! {
307                    #fname: if src.#fname == 0 {
308                        None
309                    } else {
310                        Some(src.#fname as #inner)
311                    }
312                });
313            }
314            FieldKind::Enum => {
315                wire_fields.push(quote! { pub #fname: u8 });
316                to_wire.push(quote! { #fname: src.#fname as u8 });
317                from_wire.push(quote! {
318                    #fname: unsafe { core::mem::transmute::<u8, #fty>(src.#fname) }
319                });
320                // Compile-time assertion: enum must be 1 byte (#[repr(u8)])
321                let msg = format!(
322                    "Message derive: field `{}` has type `{}` which is not 1 byte. \
323                     Enum fields must have #[repr(u8)].",
324                    fname,
325                    quote! { #fty },
326                );
327                let msg_lit = syn::LitStr::new(&msg, Span::call_site());
328                assertions.push(quote! {
329                    const _: () = {
330                        assert!(
331                            core::mem::size_of::<#fty>() == 1,
332                            #msg_lit,
333                        );
334                    };
335                });
336            }
337        }
338    }
339
340    let expanded = quote! {
341        // Compile-time assertions for enum fields
342        #(#assertions)*
343
344        /// Auto-generated Pod-compatible wire struct for [`#name`].
345        #[repr(C)]
346        #[derive(Clone, Copy)]
347        pub struct #wire_name {
348            #(#wire_fields),*
349        }
350
351        // Safety: all fields of the wire struct are plain numeric types
352        // (u8, u32, u64, f32, f64, etc.) where every bit pattern is valid.
353        unsafe impl photon_ring::Pod for #wire_name {}
354
355        impl From<#name> for #wire_name {
356            #[inline]
357            fn from(src: #name) -> Self {
358                #wire_name {
359                    #(#to_wire),*
360                }
361            }
362        }
363
364        impl From<#wire_name> for #name {
365            #[inline]
366            fn from(src: #wire_name) -> Self {
367                #name {
368                    #(#from_wire),*
369                }
370            }
371        }
372    };
373
374    TokenStream::from(expanded)
375}