1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{Attribute, Data, DeriveInput, Fields, Lit, parse_macro_input};
5
6#[proc_macro_derive(ToStream, attributes(stream, field, variant))]
7pub fn derive_to_stream(input: TokenStream) -> TokenStream {
8 let input = parse_macro_input!(input as DeriveInput);
9 match derive_to_stream_inner(&input) {
10 Ok(tokens) => tokens.into(),
11 Err(err) => err.to_compile_error().into(),
12 }
13}
14
15#[proc_macro_derive(FromStream, attributes(stream, field, variant))]
16pub fn derive_from_stream(input: TokenStream) -> TokenStream {
17 let input = parse_macro_input!(input as DeriveInput);
18 match derive_from_stream_inner(&input) {
19 Ok(tokens) => tokens.into(),
20 Err(err) => err.to_compile_error().into(),
21 }
22}
23
24struct StreamAttrs {
25 bounds: Vec<syn::TypeParamBound>,
26}
27
28fn parse_stream_attrs(attrs: &[Attribute]) -> syn::Result<StreamAttrs> {
29 let mut bounds = Vec::new();
30
31 for attr in attrs {
32 if !attr.path().is_ident("stream") {
33 continue;
34 }
35
36 attr.parse_nested_meta(|meta| {
37 if meta.path.is_ident("bounds") {
38 let value = meta.value()?;
39 let lit: Lit = value.parse()?;
40 if let Lit::Str(s) = lit {
41 let parsed: syn::TypeParamBound = syn::parse_str(&s.value())?;
42 bounds.push(parsed);
43 Ok(())
44 } else {
45 Err(meta.error("Bounds must be a string literal"))
46 }
47 } else {
48 Err(meta.error("Unknown stream attribute"))
49 }
50 })?;
51 }
52
53 Ok(StreamAttrs { bounds })
54}
55
/// Per-field options collected from `#[field(...)]` attributes.
struct FieldAttrs {
    /// `#[field(ignore)]`: the field is not written. When reading, it is filled with
    /// `Default::default()`.
    ignore: bool,
    /// `#[field(order = N)]`: explicit sort key. `None` means the declaration position is used.
    order: Option<usize>,
}
60
61fn parse_field_attrs(attrs: &[Attribute]) -> syn::Result<FieldAttrs> {
62 let mut ignore = false;
63 let mut order = None;
64
65 for attr in attrs {
66 if !attr.path().is_ident("field") {
67 continue;
68 }
69
70 attr.parse_nested_meta(|meta| {
71 if meta.path.is_ident("ignore") {
72 ignore = true;
73 Ok(())
74 } else if meta.path.is_ident("order") {
75 let value = meta.value()?;
76 let lit: Lit = value.parse()?;
77 if let Lit::Int(i) = lit {
78 order = Some(i.base10_parse::<usize>()?);
79 Ok(())
80 } else {
81 Err(meta.error("Order must be an integer"))
82 }
83 } else {
84 Err(meta.error("Unknown field attribute"))
85 }
86 })?;
87 }
88
89 Ok(FieldAttrs { ignore, order })
90}
91
/// Per-variant options collected from `#[variant(...)]` attributes.
struct VariantAttrs {
    /// `#[variant(index = N)]`: explicit wire discriminant for the variant. `None` means
    /// the index continues from the previous variant.
    index: Option<usize>,
}
95
96fn parse_variant_attrs(attrs: &[Attribute]) -> syn::Result<VariantAttrs> {
97 let mut index = None;
98
99 for attr in attrs {
100 if !attr.path().is_ident("variant") {
101 continue;
102 }
103
104 attr.parse_nested_meta(|meta| {
105 if meta.path.is_ident("index") {
106 let value = meta.value()?;
107 let lit: Lit = value.parse()?;
108 if let Lit::Int(i) = lit {
109 index = Some(i.base10_parse::<usize>()?);
110 Ok(())
111 } else {
112 Err(meta.error("Index must be an integer"))
113 }
114 } else {
115 Err(meta.error("Unknown variant attribute"))
116 }
117 })?;
118 }
119
120 Ok(VariantAttrs { index })
121}
122
123struct OrderedField {
124 index: usize,
125 sort_key: usize,
126 attrs: FieldAttrs,
127}
128
129fn compute_field_order(fields: &Fields) -> syn::Result<Vec<OrderedField>> {
130 let mut ordered: Vec<OrderedField> = Vec::new();
131
132 for (i, field) in fields.iter().enumerate() {
133 let attrs = parse_field_attrs(&field.attrs)?;
134 let sort_key = attrs.order.unwrap_or(i);
135 ordered.push(OrderedField {
136 index: i,
137 sort_key,
138 attrs,
139 });
140 }
141
142 ordered.sort_by_key(|f| f.sort_key);
143 Ok(ordered)
144}
145
/// Which trait an expansion targets. It selects the bound placed on each type parameter.
enum DeriveMode {
    /// `#[derive(ToStream)]`: type parameters are bounded by `ToStream<__S>`.
    To,
    /// `#[derive(FromStream)]`: type parameters are bounded by `FromStream<__S>`.
    From,
}
150
151fn build_impl_generics(
152 input: &DeriveInput,
153 stream_attrs: &StreamAttrs,
154 mode: DeriveMode,
155) -> (TokenStream2, TokenStream2, TokenStream2) {
156 let name = &input.ident;
157
158 let mut impl_params = Vec::new();
159 let mut where_clauses = Vec::new();
160
161 for param in &input.generics.params {
162 impl_params.push(quote! { #param });
163 }
164
165 let bounds = &stream_attrs.bounds;
166 let s_bounds = if bounds.is_empty() {
167 quote! { __S }
168 } else {
169 quote! { __S: #(#bounds)+* }
170 };
171 impl_params.push(s_bounds);
172
173 for param in &input.generics.params {
174 if let syn::GenericParam::Type(t) = param {
175 let ident = &t.ident;
176 match mode {
177 DeriveMode::To => {
178 where_clauses.push(quote! { #ident: ::data_stream::ToStream<__S> });
179 }
180 DeriveMode::From => {
181 where_clauses.push(quote! { #ident: ::data_stream::FromStream<__S> });
182 }
183 }
184 }
185 }
186
187 if let Some(wc) = &input.generics.where_clause {
188 for pred in &wc.predicates {
189 where_clauses.push(quote! { #pred });
190 }
191 }
192
193 let impl_block = quote! { <#(#impl_params),*> };
194 let (_, ty_generics, _) = input.generics.split_for_impl();
195 let ty_block = quote! { #name #ty_generics };
196
197 let where_block = if where_clauses.is_empty() {
198 quote! {}
199 } else {
200 quote! { where #(#where_clauses),* }
201 };
202
203 (impl_block, ty_block, where_block)
204}
205
206fn resolve_enum_indices(data: &syn::DataEnum) -> syn::Result<Vec<usize>> {
207 let mut result = Vec::with_capacity(data.variants.len());
208 let mut seen = std::collections::HashSet::new();
209 let mut auto_index = 0;
210
211 for variant in &data.variants {
212 let vattrs = parse_variant_attrs(&variant.attrs)?;
213 let index = match vattrs.index {
214 Some(i) => i,
215 None => auto_index,
216 };
217
218 let index = index as usize;
219 if !seen.insert(index) {
220 return Err(syn::Error::new_spanned(
221 &variant.ident,
222 format!("Duplicate enum variant index: {index}"),
223 ));
224 }
225
226 result.push(index);
227 auto_index = index
228 .checked_add(1)
229 .ok_or_else(|| syn::Error::new_spanned(&variant.ident, "Enum index overflow"))?;
230 }
231
232 Ok(result)
233}
234
235fn derive_to_stream_inner(input: &DeriveInput) -> syn::Result<TokenStream2> {
236 let stream_attrs = parse_stream_attrs(&input.attrs)?;
237 let (impl_gen, ty_gen, where_block) = build_impl_generics(input, &stream_attrs, DeriveMode::To);
238
239 match &input.data {
240 Data::Struct(data) => {
241 derive_to_stream_struct(input, &data.fields, impl_gen, ty_gen, where_block)
242 }
243 Data::Enum(data) => derive_to_stream_enum(input, data, impl_gen, ty_gen, where_block),
244 Data::Union(_) => Err(syn::Error::new_spanned(input, "unions are not supported")),
245 }
246}
247
248fn derive_to_stream_struct(
249 _input: &DeriveInput,
250 fields: &Fields,
251 impl_gen: TokenStream2,
252 ty_gen: TokenStream2,
253 where_block: TokenStream2,
254) -> syn::Result<TokenStream2> {
255 let ordered = compute_field_order(fields)?;
256 let mut write_stmts = Vec::new();
257
258 for of in &ordered {
259 if of.attrs.ignore {
260 continue;
261 }
262
263 let field_access = match fields {
264 Fields::Named(named) => {
265 let ident = named.named[of.index].ident.as_ref().unwrap();
266 quote! { &self.#ident }
267 }
268 Fields::Unnamed(_) => {
269 let index = syn::Index::from(of.index);
270 quote! { &self.#index }
271 }
272 Fields::Unit => unreachable!(),
273 };
274
275 write_stmts.push(quote! {
276 ::data_stream::ToStream::<__S>::to_stream(#field_access, stream)?;
277 });
278 }
279
280 Ok(quote! {
281 impl #impl_gen ::data_stream::ToStream<__S> for #ty_gen #where_block {
282 fn to_stream<__W: ::std::io::Write>(&self, stream: &mut __W) -> ::std::io::Result<()> {
283 #(#write_stmts)*
284 Ok(())
285 }
286 }
287 })
288}
289
290fn derive_to_stream_enum(
291 _input: &DeriveInput,
292 data: &syn::DataEnum,
293 impl_gen: TokenStream2,
294 ty_gen: TokenStream2,
295 where_block: TokenStream2,
296) -> syn::Result<TokenStream2> {
297 let indices = resolve_enum_indices(data)?;
298 let mut match_arms = Vec::new();
299
300 for (variant, disc) in data.variants.iter().zip(indices.iter().copied()) {
301 let vident = &variant.ident;
302
303 let (pattern, field_writes) = match &variant.fields {
304 Fields::Unit => (quote! { Self::#vident }, quote! {}),
305 Fields::Unnamed(fields) => {
306 let bindings: Vec<_> = (0..fields.unnamed.len())
307 .map(|i| format_ident!("__f{}", i))
308 .collect();
309 let writes: Vec<_> = bindings
310 .iter()
311 .map(|b| {
312 quote! {
313 ::data_stream::ToStream::<__S>::to_stream(#b, stream)?;
314 }
315 })
316 .collect();
317 (
318 quote! { Self::#vident(#(#bindings),*) },
319 quote! { #(#writes)* },
320 )
321 }
322 Fields::Named(fields) => {
323 let field_idents: Vec<_> = fields
324 .named
325 .iter()
326 .map(|f| f.ident.as_ref().unwrap())
327 .collect();
328
329 let ordered = compute_field_order(&variant.fields)?;
330 let writes: Vec<_> = ordered
331 .iter()
332 .filter(|of| !of.attrs.ignore)
333 .map(|of| {
334 let ident = field_idents[of.index];
335 quote! {
336 ::data_stream::ToStream::<__S>::to_stream(#ident, stream)?;
337 }
338 })
339 .collect();
340
341 (
342 quote! { Self::#vident { #(#field_idents),* } },
343 quote! { #(#writes)* },
344 )
345 }
346 };
347
348 match_arms.push(quote! {
349 #pattern => {
350 ::data_stream::ToStream::<__S>::to_stream(&#disc, stream)?;
351 #field_writes
352 }
353 });
354 }
355
356 Ok(quote! {
357 impl #impl_gen ::data_stream::ToStream<__S> for #ty_gen #where_block {
358 fn to_stream<__W: ::std::io::Write>(&self, stream: &mut __W) -> ::std::io::Result<()> {
359 match self {
360 #(#match_arms)*
361 }
362 Ok(())
363 }
364 }
365 })
366}
367
368fn derive_from_stream_inner(input: &DeriveInput) -> syn::Result<TokenStream2> {
369 let stream_attrs = parse_stream_attrs(&input.attrs)?;
370 let (impl_gen, ty_gen, where_block) =
371 build_impl_generics(input, &stream_attrs, DeriveMode::From);
372
373 match &input.data {
374 Data::Struct(data) => {
375 derive_from_stream_struct(input, &data.fields, impl_gen, ty_gen, where_block)
376 }
377 Data::Enum(data) => derive_from_stream_enum(input, data, impl_gen, ty_gen, where_block),
378 Data::Union(_) => Err(syn::Error::new_spanned(input, "Unions are not supported")),
379 }
380}
381
382fn derive_from_stream_struct(
383 _input: &DeriveInput,
384 fields: &Fields,
385 impl_gen: TokenStream2,
386 ty_gen: TokenStream2,
387 where_block: TokenStream2,
388) -> syn::Result<TokenStream2> {
389 let ordered = compute_field_order(fields)?;
390 let construct = build_struct_construction(fields, &ordered, None);
391
392 Ok(quote! {
393 impl #impl_gen ::data_stream::FromStream<__S> for #ty_gen #where_block {
394 fn from_stream<__R: ::std::io::Read>(stream: &mut __R) -> ::std::io::Result<Self> {
395 #construct
396 }
397 }
398 })
399}
400
401fn build_struct_construction(
402 fields: &Fields,
403 ordered: &[OrderedField],
404 self_path: Option<&TokenStream2>,
405) -> TokenStream2 {
406 match fields {
407 Fields::Named(named) => {
408 let mut read_stmts = Vec::new();
409 let mut field_inits = Vec::new();
410
411 for of in ordered {
412 let field = &named.named[of.index];
413 let ident = field.ident.as_ref().unwrap();
414 let ty = &field.ty;
415
416 if of.attrs.ignore {
417 field_inits.push(quote! {
418 #ident: ::std::default::Default::default()
419 });
420 } else {
421 let temp = format_ident!("__field_{}", ident);
422 read_stmts.push(quote! {
423 let #temp: #ty = ::data_stream::FromStream::<__S>::from_stream(stream)?;
424 });
425 field_inits.push(quote! {
426 #ident: #temp
427 });
428 }
429 }
430
431 let prefix = self_path.map_or_else(|| quote! { Self }, |p| quote! { #p });
432
433 quote! {
434 #(#read_stmts)*
435 Ok(#prefix { #(#field_inits),* })
436 }
437 }
438 Fields::Unnamed(unnamed) => {
439 let mut read_stmts = Vec::new();
440 let mut read_temp_by_index: Vec<Option<syn::Ident>> = vec![None; unnamed.unnamed.len()];
441 let mut ignored_by_index = vec![false; unnamed.unnamed.len()];
442
443 for of in ordered {
444 let ty = &unnamed.unnamed[of.index].ty;
445 ignored_by_index[of.index] = of.attrs.ignore;
446
447 if !of.attrs.ignore {
448 let temp = format_ident!("__field_{}", of.index);
449 read_stmts.push(quote! {
450 let #temp: #ty = ::data_stream::FromStream::<__S>::from_stream(stream)?;
451 });
452 read_temp_by_index[of.index] = Some(temp);
453 }
454 }
455
456 let mut field_values = Vec::new();
457 for index in 0..unnamed.unnamed.len() {
458 if ignored_by_index[index] {
459 field_values.push(quote! {
460 ::std::default::Default::default()
461 });
462 } else {
463 let temp = read_temp_by_index[index].as_ref().unwrap();
464 field_values.push(quote! { #temp });
465 }
466 }
467
468 let prefix = self_path.map_or_else(|| quote! { Self }, |p| quote! { #p });
469
470 quote! {
471 #(#read_stmts)*
472 Ok(#prefix(#(#field_values),*))
473 }
474 }
475 Fields::Unit => {
476 let prefix = self_path.map_or_else(|| quote! { Self }, |p| quote! { #p });
477 quote! { Ok(#prefix) }
478 }
479 }
480}
481
482fn derive_from_stream_enum(
483 _input: &DeriveInput,
484 data: &syn::DataEnum,
485 impl_gen: TokenStream2,
486 ty_gen: TokenStream2,
487 where_block: TokenStream2,
488) -> syn::Result<TokenStream2> {
489 let indices = resolve_enum_indices(data)?;
490 let mut match_arms = Vec::new();
491
492 for (variant, disc) in data.variants.iter().zip(indices.iter().copied()) {
493 let vident = &variant.ident;
494 let self_path = quote! { Self::#vident };
495
496 let ordered = compute_field_order(&variant.fields)?;
497 let construct = build_struct_construction(&variant.fields, &ordered, Some(&self_path));
498
499 match_arms.push(quote! {
500 #disc => { #construct }
501 });
502 }
503
504 Ok(quote! {
505 impl #impl_gen ::data_stream::FromStream<__S> for #ty_gen #where_block {
506 fn from_stream<__R: ::std::io::Read>(stream: &mut __R) -> ::std::io::Result<Self> {
507 let __discriminant: usize = ::data_stream::FromStream::<__S>::from_stream(stream)?;
508 match __discriminant {
509 #(#match_arms)*
510 other => Err(::std::io::Error::new(
511 ::std::io::ErrorKind::InvalidData,
512 ::std::format!("Invalid enum discriminant: {}", other),
513 )),
514 }
515 }
516 }
517 })
518}