Skip to main content

data_stream_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{Attribute, Data, DeriveInput, Fields, Lit, parse_macro_input};
5
6#[proc_macro_derive(ToStream, attributes(stream, field, variant))]
7pub fn derive_to_stream(input: TokenStream) -> TokenStream {
8    let input = parse_macro_input!(input as DeriveInput);
9    match derive_to_stream_inner(&input) {
10        Ok(tokens) => tokens.into(),
11        Err(err) => err.to_compile_error().into(),
12    }
13}
14
15#[proc_macro_derive(FromStream, attributes(stream, field, variant))]
16pub fn derive_from_stream(input: TokenStream) -> TokenStream {
17    let input = parse_macro_input!(input as DeriveInput);
18    match derive_from_stream_inner(&input) {
19        Ok(tokens) => tokens.into(),
20        Err(err) => err.to_compile_error().into(),
21    }
22}
23
24struct StreamAttrs {
25    bounds: Vec<syn::TypeParamBound>,
26}
27
28fn parse_stream_attrs(attrs: &[Attribute]) -> syn::Result<StreamAttrs> {
29    let mut bounds = Vec::new();
30
31    for attr in attrs {
32        if !attr.path().is_ident("stream") {
33            continue;
34        }
35
36        attr.parse_nested_meta(|meta| {
37            if meta.path.is_ident("bounds") {
38                let value = meta.value()?;
39                let lit: Lit = value.parse()?;
40                if let Lit::Str(s) = lit {
41                    let parsed: syn::TypeParamBound = syn::parse_str(&s.value())?;
42                    bounds.push(parsed);
43                    Ok(())
44                } else {
45                    Err(meta.error("Bounds must be a string literal"))
46                }
47            } else {
48                Err(meta.error("Unknown stream attribute"))
49            }
50        })?;
51    }
52
53    Ok(StreamAttrs { bounds })
54}
55
56struct FieldAttrs {
57    ignore: bool,
58    order: Option<usize>,
59}
60
61fn parse_field_attrs(attrs: &[Attribute]) -> syn::Result<FieldAttrs> {
62    let mut ignore = false;
63    let mut order = None;
64
65    for attr in attrs {
66        if !attr.path().is_ident("field") {
67            continue;
68        }
69
70        attr.parse_nested_meta(|meta| {
71            if meta.path.is_ident("ignore") {
72                ignore = true;
73                Ok(())
74            } else if meta.path.is_ident("order") {
75                let value = meta.value()?;
76                let lit: Lit = value.parse()?;
77                if let Lit::Int(i) = lit {
78                    order = Some(i.base10_parse::<usize>()?);
79                    Ok(())
80                } else {
81                    Err(meta.error("Order must be an integer"))
82                }
83            } else {
84                Err(meta.error("Unknown field attribute"))
85            }
86        })?;
87    }
88
89    Ok(FieldAttrs { ignore, order })
90}
91
92struct VariantAttrs {
93    index: Option<usize>,
94}
95
96fn parse_variant_attrs(attrs: &[Attribute]) -> syn::Result<VariantAttrs> {
97    let mut index = None;
98
99    for attr in attrs {
100        if !attr.path().is_ident("variant") {
101            continue;
102        }
103
104        attr.parse_nested_meta(|meta| {
105            if meta.path.is_ident("index") {
106                let value = meta.value()?;
107                let lit: Lit = value.parse()?;
108                if let Lit::Int(i) = lit {
109                    index = Some(i.base10_parse::<usize>()?);
110                    Ok(())
111                } else {
112                    Err(meta.error("Index must be an integer"))
113                }
114            } else {
115                Err(meta.error("Unknown variant attribute"))
116            }
117        })?;
118    }
119
120    Ok(VariantAttrs { index })
121}
122
123struct OrderedField {
124    index: usize,
125    sort_key: usize,
126    attrs: FieldAttrs,
127}
128
129fn compute_field_order(fields: &Fields) -> syn::Result<Vec<OrderedField>> {
130    let mut ordered: Vec<OrderedField> = Vec::new();
131
132    for (i, field) in fields.iter().enumerate() {
133        let attrs = parse_field_attrs(&field.attrs)?;
134        let sort_key = attrs.order.unwrap_or(i);
135        ordered.push(OrderedField {
136            index: i,
137            sort_key,
138            attrs,
139        });
140    }
141
142    ordered.sort_by_key(|f| f.sort_key);
143    Ok(ordered)
144}
145
146enum DeriveMode {
147    To,
148    From,
149}
150
151fn build_impl_generics(
152    input: &DeriveInput,
153    stream_attrs: &StreamAttrs,
154    mode: DeriveMode,
155) -> (TokenStream2, TokenStream2, TokenStream2) {
156    let name = &input.ident;
157
158    let mut impl_params = Vec::new();
159    let mut where_clauses = Vec::new();
160
161    for param in &input.generics.params {
162        impl_params.push(quote! { #param });
163    }
164
165    let bounds = &stream_attrs.bounds;
166    let s_bounds = if bounds.is_empty() {
167        quote! { __S }
168    } else {
169        quote! { __S: #(#bounds)+* }
170    };
171    impl_params.push(s_bounds);
172
173    for param in &input.generics.params {
174        if let syn::GenericParam::Type(t) = param {
175            let ident = &t.ident;
176            match mode {
177                DeriveMode::To => {
178                    where_clauses.push(quote! { #ident: ::data_stream::ToStream<__S> });
179                }
180                DeriveMode::From => {
181                    where_clauses.push(quote! { #ident: ::data_stream::FromStream<__S> });
182                }
183            }
184        }
185    }
186
187    if let Some(wc) = &input.generics.where_clause {
188        for pred in &wc.predicates {
189            where_clauses.push(quote! { #pred });
190        }
191    }
192
193    let impl_block = quote! { <#(#impl_params),*> };
194    let (_, ty_generics, _) = input.generics.split_for_impl();
195    let ty_block = quote! { #name #ty_generics };
196
197    let where_block = if where_clauses.is_empty() {
198        quote! {}
199    } else {
200        quote! { where #(#where_clauses),* }
201    };
202
203    (impl_block, ty_block, where_block)
204}
205
206fn resolve_enum_indices(data: &syn::DataEnum) -> syn::Result<Vec<usize>> {
207    let mut result = Vec::with_capacity(data.variants.len());
208    let mut seen = std::collections::HashSet::new();
209    let mut auto_index = 0;
210
211    for variant in &data.variants {
212        let vattrs = parse_variant_attrs(&variant.attrs)?;
213        let index = match vattrs.index {
214            Some(i) => i,
215            None => auto_index,
216        };
217
218        let index = index as usize;
219        if !seen.insert(index) {
220            return Err(syn::Error::new_spanned(
221                &variant.ident,
222                format!("Duplicate enum variant index: {index}"),
223            ));
224        }
225
226        result.push(index);
227        auto_index = index
228            .checked_add(1)
229            .ok_or_else(|| syn::Error::new_spanned(&variant.ident, "Enum index overflow"))?;
230    }
231
232    Ok(result)
233}
234
235fn derive_to_stream_inner(input: &DeriveInput) -> syn::Result<TokenStream2> {
236    let stream_attrs = parse_stream_attrs(&input.attrs)?;
237    let (impl_gen, ty_gen, where_block) = build_impl_generics(input, &stream_attrs, DeriveMode::To);
238
239    match &input.data {
240        Data::Struct(data) => {
241            derive_to_stream_struct(input, &data.fields, impl_gen, ty_gen, where_block)
242        }
243        Data::Enum(data) => derive_to_stream_enum(input, data, impl_gen, ty_gen, where_block),
244        Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
245    }
246}
247
248fn derive_to_stream_struct(
249    _input: &DeriveInput,
250    fields: &Fields,
251    impl_gen: TokenStream2,
252    ty_gen: TokenStream2,
253    where_block: TokenStream2,
254) -> syn::Result<TokenStream2> {
255    let ordered = compute_field_order(fields)?;
256    let mut write_stmts = Vec::new();
257
258    for of in &ordered {
259        if of.attrs.ignore {
260            continue;
261        }
262
263        let field_access = match fields {
264            Fields::Named(named) => {
265                let ident = named.named[of.index].ident.as_ref().unwrap();
266                quote! { &self.#ident }
267            }
268            Fields::Unnamed(_) => {
269                let index = syn::Index::from(of.index);
270                quote! { &self.#index }
271            }
272            Fields::Unit => unreachable!(),
273        };
274
275        write_stmts.push(quote! {
276            ::data_stream::ToStream::<__S>::to_stream(#field_access, stream)?;
277        });
278    }
279
280    Ok(quote! {
281        impl #impl_gen ::data_stream::ToStream<__S> for #ty_gen #where_block {
282            fn to_stream<__W: ::std::io::Write>(&self, stream: &mut __W) -> ::std::io::Result<()> {
283                #(#write_stmts)*
284                Ok(())
285            }
286        }
287    })
288}
289
290fn derive_to_stream_enum(
291    _input: &DeriveInput,
292    data: &syn::DataEnum,
293    impl_gen: TokenStream2,
294    ty_gen: TokenStream2,
295    where_block: TokenStream2,
296) -> syn::Result<TokenStream2> {
297    let indices = resolve_enum_indices(data)?;
298    let mut match_arms = Vec::new();
299
300    for (variant, disc) in data.variants.iter().zip(indices.iter().copied()) {
301        let vident = &variant.ident;
302
303        let (pattern, field_writes) = match &variant.fields {
304            Fields::Unit => (quote! { Self::#vident }, quote! {}),
305            Fields::Unnamed(fields) => {
306                let bindings: Vec<_> = (0..fields.unnamed.len())
307                    .map(|i| format_ident!("__f{}", i))
308                    .collect();
309                let writes: Vec<_> = bindings
310                    .iter()
311                    .map(|b| {
312                        quote! {
313                            ::data_stream::ToStream::<__S>::to_stream(#b, stream)?;
314                        }
315                    })
316                    .collect();
317                (
318                    quote! { Self::#vident(#(#bindings),*) },
319                    quote! { #(#writes)* },
320                )
321            }
322            Fields::Named(fields) => {
323                let field_idents: Vec<_> = fields
324                    .named
325                    .iter()
326                    .map(|f| f.ident.as_ref().unwrap())
327                    .collect();
328
329                let ordered = compute_field_order(&variant.fields)?;
330                let writes: Vec<_> = ordered
331                    .iter()
332                    .filter(|of| !of.attrs.ignore)
333                    .map(|of| {
334                        let ident = field_idents[of.index];
335                        quote! {
336                            ::data_stream::ToStream::<__S>::to_stream(#ident, stream)?;
337                        }
338                    })
339                    .collect();
340
341                (
342                    quote! { Self::#vident { #(#field_idents),* } },
343                    quote! { #(#writes)* },
344                )
345            }
346        };
347
348        match_arms.push(quote! {
349            #pattern => {
350                ::data_stream::ToStream::<__S>::to_stream(&#disc, stream)?;
351                #field_writes
352            }
353        });
354    }
355
356    Ok(quote! {
357        impl #impl_gen ::data_stream::ToStream<__S> for #ty_gen #where_block {
358            fn to_stream<__W: ::std::io::Write>(&self, stream: &mut __W) -> ::std::io::Result<()> {
359                match self {
360                    #(#match_arms)*
361                }
362                Ok(())
363            }
364        }
365    })
366}
367
368fn derive_from_stream_inner(input: &DeriveInput) -> syn::Result<TokenStream2> {
369    let stream_attrs = parse_stream_attrs(&input.attrs)?;
370    let (impl_gen, ty_gen, where_block) =
371        build_impl_generics(input, &stream_attrs, DeriveMode::From);
372
373    match &input.data {
374        Data::Struct(data) => {
375            derive_from_stream_struct(input, &data.fields, impl_gen, ty_gen, where_block)
376        }
377        Data::Enum(data) => derive_from_stream_enum(input, data, impl_gen, ty_gen, where_block),
378        Data::Union(_) => Err(syn::Error::new_spanned(input, "Unions are not supported")),
379    }
380}
381
382fn derive_from_stream_struct(
383    _input: &DeriveInput,
384    fields: &Fields,
385    impl_gen: TokenStream2,
386    ty_gen: TokenStream2,
387    where_block: TokenStream2,
388) -> syn::Result<TokenStream2> {
389    let ordered = compute_field_order(fields)?;
390    let construct = build_struct_construction(fields, &ordered, None);
391
392    Ok(quote! {
393        impl #impl_gen ::data_stream::FromStream<__S> for #ty_gen #where_block {
394            fn from_stream<__R: ::std::io::Read>(stream: &mut __R) -> ::std::io::Result<Self> {
395                #construct
396            }
397        }
398    })
399}
400
401fn build_struct_construction(
402    fields: &Fields,
403    ordered: &[OrderedField],
404    self_path: Option<&TokenStream2>,
405) -> TokenStream2 {
406    match fields {
407        Fields::Named(named) => {
408            let mut read_stmts = Vec::new();
409            let mut field_inits = Vec::new();
410
411            for of in ordered {
412                let field = &named.named[of.index];
413                let ident = field.ident.as_ref().unwrap();
414                let ty = &field.ty;
415
416                if of.attrs.ignore {
417                    field_inits.push(quote! {
418                        #ident: ::std::default::Default::default()
419                    });
420                } else {
421                    let temp = format_ident!("__field_{}", ident);
422                    read_stmts.push(quote! {
423                        let #temp: #ty = ::data_stream::FromStream::<__S>::from_stream(stream)?;
424                    });
425                    field_inits.push(quote! {
426                        #ident: #temp
427                    });
428                }
429            }
430
431            let prefix = self_path.map_or_else(|| quote! { Self }, |p| quote! { #p });
432
433            quote! {
434                #(#read_stmts)*
435                Ok(#prefix { #(#field_inits),* })
436            }
437        }
438        Fields::Unnamed(unnamed) => {
439            let mut read_stmts = Vec::new();
440            let mut read_temp_by_index: Vec<Option<syn::Ident>> = vec![None; unnamed.unnamed.len()];
441            let mut ignored_by_index = vec![false; unnamed.unnamed.len()];
442
443            for of in ordered {
444                let ty = &unnamed.unnamed[of.index].ty;
445                ignored_by_index[of.index] = of.attrs.ignore;
446
447                if !of.attrs.ignore {
448                    let temp = format_ident!("__field_{}", of.index);
449                    read_stmts.push(quote! {
450                        let #temp: #ty = ::data_stream::FromStream::<__S>::from_stream(stream)?;
451                    });
452                    read_temp_by_index[of.index] = Some(temp);
453                }
454            }
455
456            let mut field_values = Vec::new();
457            for index in 0..unnamed.unnamed.len() {
458                if ignored_by_index[index] {
459                    field_values.push(quote! {
460                        ::std::default::Default::default()
461                    });
462                } else {
463                    let temp = read_temp_by_index[index].as_ref().unwrap();
464                    field_values.push(quote! { #temp });
465                }
466            }
467
468            let prefix = self_path.map_or_else(|| quote! { Self }, |p| quote! { #p });
469
470            quote! {
471                #(#read_stmts)*
472                Ok(#prefix(#(#field_values),*))
473            }
474        }
475        Fields::Unit => {
476            let prefix = self_path.map_or_else(|| quote! { Self }, |p| quote! { #p });
477            quote! { Ok(#prefix) }
478        }
479    }
480}
481
482fn derive_from_stream_enum(
483    _input: &DeriveInput,
484    data: &syn::DataEnum,
485    impl_gen: TokenStream2,
486    ty_gen: TokenStream2,
487    where_block: TokenStream2,
488) -> syn::Result<TokenStream2> {
489    let indices = resolve_enum_indices(data)?;
490    let mut match_arms = Vec::new();
491
492    for (variant, disc) in data.variants.iter().zip(indices.iter().copied()) {
493        let vident = &variant.ident;
494        let self_path = quote! { Self::#vident };
495
496        let ordered = compute_field_order(&variant.fields)?;
497        let construct = build_struct_construction(&variant.fields, &ordered, Some(&self_path));
498
499        match_arms.push(quote! {
500            #disc => { #construct }
501        });
502    }
503
504    Ok(quote! {
505        impl #impl_gen ::data_stream::FromStream<__S> for #ty_gen #where_block {
506            fn from_stream<__R: ::std::io::Read>(stream: &mut __R) -> ::std::io::Result<Self> {
507                let __discriminant: usize = ::data_stream::FromStream::<__S>::from_stream(stream)?;
508                match __discriminant {
509                    #(#match_arms)*
510                    other => Err(::std::io::Error::new(
511                        ::std::io::ErrorKind::InvalidData,
512                        ::std::format!("Invalid enum discriminant: {}", other),
513                    )),
514                }
515            }
516        }
517    })
518}