Skip to main content

conduit_derive/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3//! Derive macros for conduit-core's `Encode` and `Decode` traits.
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{Data, DeriveInput, Fields, parse_macro_input};
8
9/// Derive the `Encode` trait for a struct with named fields.
10///
11/// Generates a `conduit_core::Encode` implementation that encodes each
12/// field in declaration order by delegating to the field type's own
13/// `Encode` impl.
14///
15/// # Example
16///
17/// ```rust,ignore
18/// use conduit_derive::Encode;
19///
20/// #[derive(Encode)]
21/// struct MarketTick {
22///     timestamp: i64,
23///     price: f64,
24///     volume: f64,
25///     side: u8,
26/// }
27/// ```
28#[proc_macro_derive(Encode)]
29pub fn derive_wire_encode(input: TokenStream) -> TokenStream {
30    let input = parse_macro_input!(input as DeriveInput);
31    match impl_wire_encode(&input) {
32        Ok(tokens) => tokens.into(),
33        Err(err) => err.to_compile_error().into(),
34    }
35}
36
37/// Derive the `Decode` trait for a struct with named fields.
38///
39/// Generates a `conduit_core::Decode` implementation that decodes each
40/// field in declaration order by delegating to the field type's own
41/// `Decode` impl, tracking the cumulative byte offset.
42///
43/// # Example
44///
45/// ```rust,ignore
46/// use conduit_derive::Decode;
47///
48/// #[derive(Decode)]
49/// struct MarketTick {
50///     timestamp: i64,
51///     price: f64,
52///     volume: f64,
53///     side: u8,
54/// }
55/// ```
56#[proc_macro_derive(Decode)]
57pub fn derive_wire_decode(input: TokenStream) -> TokenStream {
58    let input = parse_macro_input!(input as DeriveInput);
59    match impl_wire_decode(&input) {
60        Ok(tokens) => tokens.into(),
61        Err(err) => err.to_compile_error().into(),
62    }
63}
64
65/// Extract named fields from a `DeriveInput`, rejecting enums, unions, and
66/// tuple/unit structs with a compile error.
67fn named_fields(input: &DeriveInput) -> syn::Result<&syn::FieldsNamed> {
68    let name = &input.ident;
69    match &input.data {
70        Data::Struct(data) => match &data.fields {
71            Fields::Named(named) => Ok(named),
72            _ => Err(syn::Error::new_spanned(
73                name,
74                "Encode / Decode can only be derived for structs with named fields",
75            )),
76        },
77        Data::Enum(_) => Err(syn::Error::new_spanned(
78            name,
79            "Encode / Decode cannot be derived for enums",
80        )),
81        Data::Union(_) => Err(syn::Error::new_spanned(
82            name,
83            "Encode / Decode cannot be derived for unions",
84        )),
85    }
86}
87
88/// Reject generic structs with a compile error — wire encoding requires
89/// a fixed, concrete layout.
90fn reject_generics(input: &DeriveInput) -> syn::Result<()> {
91    if !input.generics.params.is_empty() {
92        return Err(syn::Error::new_spanned(
93            &input.generics,
94            "Encode / Decode cannot be derived for generic structs",
95        ));
96    }
97    Ok(())
98}
99
100/// Generate the `Encode` impl: encodes each named field in declaration
101/// order and sums their `encode_size()` for the total.
102fn impl_wire_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
103    reject_generics(input)?;
104    let name = &input.ident;
105    let fields = named_fields(input)?;
106
107    let encode_stmts: Vec<_> = fields
108        .named
109        .iter()
110        .map(|f| {
111            let ident = f.ident.as_ref().unwrap();
112            quote! {
113                conduit_core::Encode::encode(&self.#ident, buf);
114            }
115        })
116        .collect();
117
118    let size_terms: Vec<_> = fields
119        .named
120        .iter()
121        .map(|f| {
122            let ident = f.ident.as_ref().unwrap();
123            quote! {
124                conduit_core::Encode::encode_size(&self.#ident)
125            }
126        })
127        .collect();
128
129    // Handle the zero-field edge case: encode_size returns 0.
130    let size_expr = if size_terms.is_empty() {
131        quote! { 0 }
132    } else {
133        let first = &size_terms[0];
134        let rest = &size_terms[1..];
135        quote! { #first #(+ #rest)* }
136    };
137
138    Ok(quote! {
139        impl conduit_core::Encode for #name {
140            fn encode(&self, buf: &mut Vec<u8>) {
141                #(#encode_stmts)*
142            }
143
144            fn encode_size(&self) -> usize {
145                #size_expr
146            }
147        }
148    })
149}
150
151/// Generate the `Decode` impl: decodes each named field in declaration
152/// order, tracking the cumulative byte offset through the input slice.
153fn impl_wire_decode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
154    reject_generics(input)?;
155    let name = &input.ident;
156    let fields = named_fields(input)?;
157
158    let decode_stmts: Vec<_> = fields
159        .named
160        .iter()
161        .map(|f| {
162            let ident = f.ident.as_ref().unwrap();
163            quote! {
164                let (#ident, __n) = conduit_core::Decode::decode(&__data[__offset..])?;
165                __offset += __n;
166            }
167        })
168        .collect();
169
170    let field_names: Vec<_> = fields
171        .named
172        .iter()
173        .map(|f| f.ident.as_ref().unwrap())
174        .collect();
175
176    Ok(quote! {
177        impl conduit_core::Decode for #name {
178            fn decode(__data: &[u8]) -> Option<(Self, usize)> {
179                let mut __offset = 0usize;
180                #(#decode_stmts)*
181                Some((Self { #(#field_names),* }, __offset))
182            }
183        }
184    })
185}