Skip to main content

reliakit_derive/
lib.rs

1//! Derive macros for `reliakit` traits.
2//!
3//! This crate provides `#[derive(...)]` support for the trait pairs defined by
4//! other `reliakit-*` crates. It is written using only the standard
5//! [`proc_macro`] API and pulls in no third-party crates. To stay free of a
6//! full Rust-grammar parser, it reads only what the generated code needs — the
7//! type name and its field shape — and rejects constructs it does not yet
8//! handle with a clear compile error rather than guessing.
9//!
10//! # Supported types
11//!
12//! - structs with named fields
13//! - tuple structs
14//! - unit structs
15//! - enums with unit, tuple, and struct variants
16//!
//! Unions, generic types (any item whose name is followed by a `<...>`
//! parameter list, including lifetime-only ones), enums with explicit
//! discriminants or a `#[repr(...)]`, and empty enums are rejected with a
//! compile error. The JSON derives currently cover structs only; enums are
//! rejected for now. The CSV derives cover only structs with named fields,
//! since CSV columns need names.
21//!
22//! # `reliakit-codec`
23//!
24//! [`CanonicalEncode`] and [`CanonicalDecode`] generate implementations of the
25//! same-named traits from `reliakit-codec`, encoding each field in declaration
26//! order. The derived code is exactly what a handwritten implementation would
27//! be — one `encode`/`decode` call per field, in order.
28//!
29//! ```
30//! # // The derives reference `::reliakit_codec`, which must be a dependency of
31//! # // the crate that uses them.
32//! use reliakit_codec::{decode_from_slice_exact, encode_to_vec};
33//! use reliakit_derive::{CanonicalDecode, CanonicalEncode};
34//!
35//! #[derive(Debug, PartialEq, CanonicalEncode, CanonicalDecode)]
36//! struct Point {
37//!     x: u16,
38//!     y: u16,
39//! }
40//!
41//! let encoded = encode_to_vec(&Point { x: 10, y: 20 }).unwrap();
42//! assert_eq!(encoded, [10, 0, 20, 0]);
43//! assert_eq!(decode_from_slice_exact::<Point>(&encoded).unwrap(), Point { x: 10, y: 20 });
44//! ```
45//!
46//! Enums are supported too. Each variant is tagged by its zero-based
47//! declaration index, encoded as a little-endian `u32`, followed by the
48//! variant's fields in declaration order:
49//!
50//! ```
51//! use reliakit_codec::{decode_from_slice_exact, encode_to_vec};
52//! use reliakit_derive::{CanonicalDecode, CanonicalEncode};
53//!
54//! #[derive(Debug, PartialEq, CanonicalEncode, CanonicalDecode)]
55//! enum Message {
56//!     Ping,
57//!     Pong,
58//! }
59//!
60//! assert_eq!(encode_to_vec(&Message::Ping).unwrap(), [0, 0, 0, 0]);
61//! assert_eq!(encode_to_vec(&Message::Pong).unwrap(), [1, 0, 0, 0]);
62//! assert_eq!(decode_from_slice_exact::<Message>(&[1, 0, 0, 0]).unwrap(), Message::Pong);
63//! ```
64//!
65//! # `reliakit-json`
66//!
67//! [`JsonEncode`] and [`JsonDecode`] generate implementations of the same-named
68//! `reliakit-json` traits. A struct with named fields becomes a JSON object in
69//! declaration order, a tuple struct becomes an array, and a unit struct
70//! becomes `null`. Decoding is strict; unknown object fields are ignored.
71//!
72//! ```
73//! use reliakit_derive::{JsonDecode, JsonEncode};
74//! use reliakit_json::{from_json_str, to_json_string};
75//!
76//! #[derive(Debug, PartialEq, JsonEncode, JsonDecode)]
77//! struct Point {
78//!     x: u16,
79//!     y: u16,
80//! }
81//!
82//! let json = to_json_string(&Point { x: 10, y: 20 });
83//! assert_eq!(json, r#"{"x":10,"y":20}"#);
84//! assert_eq!(from_json_str::<Point>(&json).unwrap(), Point { x: 10, y: 20 });
85//! ```
86//!
87//! # `reliakit-csv`
88//!
89//! [`CsvEncode`] and [`CsvDecode`] generate implementations of the same-named
90//! `reliakit-csv` traits. A struct with named fields becomes a CSV row, one
91//! column per field in declaration order, with the field names as the header.
92//! Because CSV columns need names, only structs with named fields are supported
93//! — tuple structs, unit structs, and enums are rejected. Decoding is strict:
//! the row must have exactly one field per struct field, and each must parse.
95//!
96//! ```
97//! use reliakit_csv::{from_csv_str, to_csv_string};
98//! use reliakit_derive::{CsvDecode, CsvEncode};
99//!
100//! #[derive(Debug, PartialEq, CsvEncode, CsvDecode)]
101//! struct Row {
102//!     id: u32,
103//!     name: String,
104//! }
105//!
106//! let rows = vec![Row { id: 1, name: "ada".into() }];
107//! let csv = to_csv_string(&rows);
108//! assert_eq!(csv, "id,name\r\n1,ada\r\n");
109//! assert_eq!(from_csv_str::<Row>(&csv).unwrap(), rows);
110//! ```
111
112#![forbid(unsafe_code)]
113#![warn(missing_docs)]
114
115use proc_macro::{Delimiter, Spacing, TokenStream, TokenTree};
116
117/// Derives `reliakit_codec::CanonicalEncode`, encoding each field in
118/// declaration order (for enums, the variant tag first).
119///
120/// See the [crate] documentation for supported types and limitations.
121#[proc_macro_derive(CanonicalEncode)]
122pub fn derive_canonical_encode(input: TokenStream) -> TokenStream {
123    match Parsed::from_input(input) {
124        Ok(parsed) => parsed.canonical_encode_impl(),
125        Err(message) => compile_error(&message),
126    }
127}
128
129/// Derives `reliakit_codec::CanonicalDecode`, decoding each field in
130/// declaration order (for enums, the variant tag first).
131///
132/// See the [crate] documentation for supported types and limitations.
133#[proc_macro_derive(CanonicalDecode)]
134pub fn derive_canonical_decode(input: TokenStream) -> TokenStream {
135    match Parsed::from_input(input) {
136        Ok(parsed) => parsed.canonical_decode_impl(),
137        Err(message) => compile_error(&message),
138    }
139}
140
141/// Derives `reliakit_json::JsonEncode`: a struct with named fields becomes a
142/// JSON object (in declaration order), a tuple struct becomes a JSON array, and
143/// a unit struct becomes `null`.
144///
145/// Enums are not supported yet. See the [crate] documentation.
146#[proc_macro_derive(JsonEncode)]
147pub fn derive_json_encode(input: TokenStream) -> TokenStream {
148    match Parsed::from_input(input).and_then(|parsed| parsed.json_encode_impl()) {
149        Ok(tokens) => tokens,
150        Err(message) => compile_error(&message),
151    }
152}
153
154/// Derives `reliakit_json::JsonDecode`, the inverse of [`macro@JsonEncode`].
155/// Decoding is strict: the JSON shape must match, and required object fields
156/// must be present; unknown object fields are ignored.
157///
158/// Enums are not supported yet. See the [crate] documentation.
159#[proc_macro_derive(JsonDecode)]
160pub fn derive_json_decode(input: TokenStream) -> TokenStream {
161    match Parsed::from_input(input).and_then(|parsed| parsed.json_decode_impl()) {
162        Ok(tokens) => tokens,
163        Err(message) => compile_error(&message),
164    }
165}
166
167/// Derives `reliakit_csv::CsvEncode`: a struct with named fields becomes a row,
168/// one column per field in declaration order, with the field names as the
169/// header.
170///
171/// Only structs with named fields are supported — CSV columns need names, so
172/// tuple structs, unit structs, and enums are rejected. See the [crate]
173/// documentation.
174#[proc_macro_derive(CsvEncode)]
175pub fn derive_csv_encode(input: TokenStream) -> TokenStream {
176    match Parsed::from_input(input).and_then(|parsed| parsed.csv_encode_impl()) {
177        Ok(tokens) => tokens,
178        Err(message) => compile_error(&message),
179    }
180}
181
182/// Derives `reliakit_csv::CsvDecode`, the inverse of [`macro@CsvEncode`].
183/// Decoding is strict: the row must have exactly one field per struct field,
184/// and each field must parse into its target type.
185///
186/// Only structs with named fields are supported. See the [crate] documentation.
187#[proc_macro_derive(CsvDecode)]
188pub fn derive_csv_decode(input: TokenStream) -> TokenStream {
189    match Parsed::from_input(input).and_then(|parsed| parsed.csv_decode_impl()) {
190        Ok(tokens) => tokens,
191        Err(message) => compile_error(&message),
192    }
193}
194
/// The keyword that introduced the derive input's item.
enum Kind {
    /// `struct Name ...`
    Struct,
    /// `enum Name { ... }`
    Enum,
    /// `union Name { ... }`
    Union,
}
201
/// How a struct body or a single enum variant holds its fields. It records only
/// what the code generators need.
enum Shape {
    /// No fields: a unit struct or a unit variant.
    Unit,
    /// Positional fields; only how many there are matters.
    Tuple(usize),
    /// Field names, in declaration order.
    Named(Vec<String>),
}
212
213/// One validated enum variant: its name and field shape.
214struct Variant {
215    name: String,
216    shape: Shape,
217}
218
219/// The validated body the derive will implement.
220enum Body {
221    /// A struct with the given field shape.
222    Struct(Shape),
223    /// An enum with the given variants, in declaration order.
224    Enum(Vec<Variant>),
225}
226
227/// A validated item ready for code generation.
228struct Parsed {
229    name: String,
230    body: Body,
231}
232
233/// One enum variant as read from tokens, before validation.
234struct RawVariant {
235    name: String,
236    /// The variant's field shape, or a message if its syntax is unsupported.
237    shape: Result<Shape, String>,
238    /// Whether the variant carried an explicit `= discriminant`.
239    has_discriminant: bool,
240}
241
242/// The item body as read from tokens, before validation.
243enum RawBody {
244    Struct(Shape),
245    Enum(Vec<RawVariant>),
246    Union,
247}
248
249/// The whole item as read from tokens, before any semantic validation. Kept
250/// free of `proc_macro` types so [`validate`] is pure and unit-testable.
251struct Raw {
252    name: String,
253    has_generics: bool,
254    saw_repr: bool,
255    body: RawBody,
256}
257
258impl Parsed {
259    /// Reads and validates a derive input.
260    fn from_input(input: TokenStream) -> Result<Self, String> {
261        validate(classify(input)?)
262    }
263
264    fn canonical_encode_impl(&self) -> TokenStream {
265        let statements = match &self.body {
266            Body::Struct(shape) => struct_encode_statements(shape),
267            Body::Enum(variants) => enum_encode_statements(variants),
268        };
269
270        format!(
271            "impl ::reliakit_codec::CanonicalEncode for {name} {{\n\
272             fn encode<__W: ::reliakit_codec::EncodeSink + ?Sized>(&self, __writer: &mut __W) \
273             -> ::core::result::Result<(), ::reliakit_codec::CodecError> {{\n\
274             {statements}\n\
275             ::core::result::Result::Ok(())\n\
276             }}\n\
277             }}",
278            name = self.name,
279        )
280        .parse()
281        .expect("reliakit-derive generated invalid CanonicalEncode tokens")
282    }
283
284    fn canonical_decode_impl(&self) -> TokenStream {
285        let value = match &self.body {
286            Body::Struct(shape) => struct_decode_value(shape),
287            Body::Enum(variants) => enum_decode_value(&self.name, variants),
288        };
289
290        format!(
291            "impl ::reliakit_codec::CanonicalDecode for {name} {{\n\
292             fn decode<__R: ::reliakit_codec::DecodeSource + ?Sized>(__reader: &mut __R) \
293             -> ::core::result::Result<Self, ::reliakit_codec::CodecError> {{\n\
294             {value}\n\
295             }}\n\
296             }}",
297            name = self.name,
298        )
299        .parse()
300        .expect("reliakit-derive generated invalid CanonicalDecode tokens")
301    }
302
303    fn json_encode_impl(&self) -> Result<TokenStream, String> {
304        let value = match &self.body {
305            Body::Struct(shape) => json_encode_value(shape),
306            Body::Enum(_) => {
307                return Err("reliakit-derive: JsonEncode does not support enums yet".into())
308            }
309        };
310
311        Ok(format!(
312            "impl ::reliakit_json::JsonEncode for {name} {{\n\
313             fn to_json_value(&self) -> ::reliakit_json::JsonValue {{\n\
314             {value}\n\
315             }}\n\
316             }}",
317            name = self.name,
318        )
319        .parse()
320        .expect("reliakit-derive generated invalid JsonEncode tokens"))
321    }
322
323    fn json_decode_impl(&self) -> Result<TokenStream, String> {
324        let body = match &self.body {
325            Body::Struct(shape) => json_decode_body(shape),
326            Body::Enum(_) => {
327                return Err("reliakit-derive: JsonDecode does not support enums yet".into())
328            }
329        };
330
331        Ok(format!(
332            "impl ::reliakit_json::JsonDecode for {name} {{\n\
333             fn from_json_value(__value: &::reliakit_json::JsonValue) \
334             -> ::core::result::Result<Self, ::reliakit_json::JsonDecodeError> {{\n\
335             {body}\n\
336             }}\n\
337             }}",
338            name = self.name,
339        )
340        .parse()
341        .expect("reliakit-derive generated invalid JsonDecode tokens"))
342    }
343
344    fn csv_encode_impl(&self) -> Result<TokenStream, String> {
345        let fields = csv_named_fields(&self.body, "CsvEncode")?;
346        let methods = csv_encode_methods(fields);
347        Ok(format!(
348            "impl ::reliakit_csv::CsvEncode for {name} {{\n{methods}\n}}",
349            name = self.name,
350        )
351        .parse()
352        .expect("reliakit-derive generated invalid CsvEncode tokens"))
353    }
354
355    fn csv_decode_impl(&self) -> Result<TokenStream, String> {
356        let fields = csv_named_fields(&self.body, "CsvDecode")?;
357        let method = csv_decode_method(fields);
358        Ok(format!(
359            "impl ::reliakit_csv::CsvDecode for {name} {{\n{method}\n}}",
360            name = self.name,
361        )
362        .parse()
363        .expect("reliakit-derive generated invalid CsvDecode tokens"))
364    }
365}
366
/// Maps a field identifier to its JSON object key. A raw identifier is written
/// with an `r#` marker that is not part of the name, so the marker is removed.
fn json_key(field: &str) -> &str {
    const RAW_MARKER: &str = "r#";
    let bare = field.strip_prefix(RAW_MARKER);
    bare.unwrap_or(field)
}
371
372/// The body of a struct's `JsonEncode::to_json_value`.
373fn json_encode_value(shape: &Shape) -> String {
374    match shape {
375        Shape::Named(fields) => {
376            let mut inserts = String::new();
377            for field in fields {
378                let key = json_key(field);
379                inserts.push_str(&format!(
380                    "__object.insert({key:?}.into(), \
381                     ::reliakit_json::JsonEncode::to_json_value(&self.{field}));",
382                ));
383            }
384            format!(
385                "let mut __object = ::reliakit_json::JsonObject::new();\n\
386                 {inserts}\n\
387                 ::reliakit_json::JsonValue::Object(__object)"
388            )
389        }
390        Shape::Tuple(count) => {
391            let mut items = String::new();
392            for index in 0..*count {
393                items.push_str(&format!(
394                    "::reliakit_json::JsonEncode::to_json_value(&self.{index}),"
395                ));
396            }
397            format!("::reliakit_json::JsonValue::array([{items}])")
398        }
399        Shape::Unit => "::reliakit_json::JsonValue::Null".to_string(),
400    }
401}
402
403/// The body of a struct's `JsonDecode::from_json_value`.
404fn json_decode_body(shape: &Shape) -> String {
405    match shape {
406        Shape::Named(fields) => {
407            let mut inner = String::new();
408            for field in fields {
409                let key = json_key(field);
410                let missing = format!("missing field `{key}`");
411                inner.push_str(&format!(
412                    "{field}: ::reliakit_json::JsonDecode::from_json_value(\
413                     __object.get({key:?}).ok_or_else(|| \
414                     ::reliakit_json::JsonDecodeError::missing_field({missing:?}))?)?,",
415                ));
416            }
417            format!(
418                "let __object = __value.as_object().ok_or_else(|| \
419                 ::reliakit_json::JsonDecodeError::unexpected_type(\"expected a JSON object\"))?;\n\
420                 ::core::result::Result::Ok(Self {{ {inner} }})"
421            )
422        }
423        Shape::Tuple(count) => {
424            let mut inner = String::new();
425            for index in 0..*count {
426                inner.push_str(&format!(
427                    "::reliakit_json::JsonDecode::from_json_value(&__array[{index}])?,"
428                ));
429            }
430            format!(
431                "let __array = __value.as_array().ok_or_else(|| \
432                 ::reliakit_json::JsonDecodeError::unexpected_type(\"expected a JSON array\"))?;\n\
433                 if __array.len() != {count} {{ return ::core::result::Result::Err(\
434                 ::reliakit_json::JsonDecodeError::unexpected_type(\
435                 \"JSON array has the wrong number of elements\")); }}\n\
436                 ::core::result::Result::Ok(Self({inner}))"
437            )
438        }
439        Shape::Unit => "if !__value.is_null() {\n\
440             return ::core::result::Result::Err(\
441             ::reliakit_json::JsonDecodeError::unexpected_type(\
442             \"expected JSON null for a unit struct\"));\n\
443             }\n\
444             ::core::result::Result::Ok(Self)"
445            .to_string(),
446    }
447}
448
/// Maps a field identifier to its CSV header column. A raw identifier's `r#`
/// marker is dropped so the column carries the plain name.
fn csv_column(field: &str) -> &str {
    if let Some(plain) = field.strip_prefix("r#") {
        return plain;
    }
    field
}
453
454/// Returns the named fields of a struct, or a reject message. CSV needs column
455/// names, so tuple structs, unit structs, and enums are rejected. Pure, so the
456/// reject decisions are unit-testable.
457fn csv_named_fields<'a>(body: &'a Body, trait_name: &str) -> Result<&'a [String], String> {
458    match body {
459        Body::Struct(Shape::Named(fields)) => Ok(fields),
460        Body::Struct(_) => Err(format!(
461            "reliakit-derive: {trait_name} requires a struct with named fields \
462             (CSV columns need names)"
463        )),
464        Body::Enum(_) => Err(format!(
465            "reliakit-derive: {trait_name} does not support enums"
466        )),
467    }
468}
469
470/// The `header` and `encode_fields` method bodies for a named struct.
471fn csv_encode_methods(fields: &[String]) -> String {
472    let mut header = String::new();
473    let mut pushes = String::new();
474    for field in fields {
475        let column = csv_column(field);
476        header.push_str(&format!("__header.push({column:?});"));
477        pushes.push_str(&format!(
478            "__out.push(::reliakit_csv::CsvField::encode_field(&self.{field}));"
479        ));
480    }
481    format!(
482        "fn header() -> ::reliakit_csv::__private::Vec<&'static str> {{\n\
483         let mut __header = ::reliakit_csv::__private::Vec::new();\n\
484         {header}\n\
485         __header\n\
486         }}\n\
487         fn encode_fields(&self, __out: &mut ::reliakit_csv::__private::Vec<\
488         ::reliakit_csv::__private::String>) {{\n\
489         {pushes}\n\
490         }}"
491    )
492}
493
/// Builds the `decode_fields` method definition for a struct with named fields.
/// The generated method rejects a row with the wrong number of fields. It then
/// decodes each field from its column, tagging any error with that column's
/// index.
fn csv_decode_method(fields: &[String]) -> String {
    let initializers: String = fields
        .iter()
        .enumerate()
        .map(|(column, field)| {
            format!(
                "{field}: ::reliakit_csv::CsvField::decode_field(__fields[{column}])\
                 .map_err(|__e| __e.at_field({column}))?,"
            )
        })
        .collect();
    let expected = fields.len();
    // The length check comes first, so the `__fields[..]` indexing never panics.
    format!(
        "fn decode_fields(__fields: &[&str]) \
         -> ::core::result::Result<Self, ::reliakit_csv::CsvDecodeError> {{\n\
         if __fields.len() != {expected} {{ return ::core::result::Result::Err(\
         ::reliakit_csv::CsvDecodeError::field_count()); }}\n\
         ::core::result::Result::Ok(Self {{ {initializers} }})\n\
         }}"
    )
}
513
514/// Validates a [`Raw`] item, rejecting unsupported forms with a descriptive
515/// message. Pure — it touches no `proc_macro` types, so it is unit-testable.
516fn validate(raw: Raw) -> Result<Parsed, String> {
517    match raw.body {
518        RawBody::Union => Err("reliakit-derive does not support unions".into()),
519        RawBody::Struct(shape) => {
520            if raw.has_generics {
521                return Err("reliakit-derive does not support generic types yet".into());
522            }
523            Ok(Parsed {
524                name: raw.name,
525                body: Body::Struct(shape),
526            })
527        }
528        RawBody::Enum(raw_variants) => {
529            if raw.has_generics {
530                return Err("reliakit-derive does not support generic types yet".into());
531            }
532            if raw.saw_repr {
533                return Err("reliakit-derive does not support `#[repr(...)]` on enums; \
534                            variant tags are always the u32 declaration index"
535                    .into());
536            }
537            let mut variants = Vec::new();
538            for raw_variant in raw_variants {
539                if raw_variant.has_discriminant {
540                    return Err(format!(
541                        "reliakit-derive does not support explicit enum discriminants \
542                         (`{} = ...`); variant tags are the u32 declaration index",
543                        raw_variant.name
544                    ));
545                }
546                match raw_variant.shape {
547                    Ok(shape) => variants.push(Variant {
548                        name: raw_variant.name,
549                        shape,
550                    }),
551                    Err(message) => return Err(message),
552                }
553            }
554            if variants.is_empty() {
555                return Err("reliakit-derive cannot derive for an empty enum \
556                            (there is no variant to encode or decode)"
557                    .into());
558            }
559            Ok(Parsed {
560                name: raw.name,
561                body: Body::Enum(variants),
562            })
563        }
564    }
565}
566
567/// Reads a derive input into a [`Raw`] item. Touches `proc_macro` types; its
568/// happy paths are exercised by the integration and example tests.
569fn classify(input: TokenStream) -> Result<Raw, String> {
570    let tokens: Vec<TokenTree> = input.into_iter().collect();
571
572    // Find the item keyword, skipping outer attributes and visibility, noting a
573    // `#[repr(...)]` so enums can reject it (struct behavior is unchanged).
574    let mut idx = 0;
575    let mut saw_repr = false;
576    let kind = loop {
577        match tokens.get(idx) {
578            Some(TokenTree::Ident(ident)) => match ident.to_string().as_str() {
579                "struct" => break Kind::Struct,
580                "enum" => break Kind::Enum,
581                "union" => break Kind::Union,
582                _ => idx += 1,
583            },
584            Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Bracket => {
585                if attr_is_repr(group.stream()) {
586                    saw_repr = true;
587                }
588                idx += 1;
589            }
590            Some(_) => idx += 1,
591            None => return Err("reliakit-derive: expected a struct, enum, or union".into()),
592        }
593    };
594
595    idx += 1;
596    let name = match tokens.get(idx) {
597        Some(TokenTree::Ident(ident)) => ident.to_string(),
598        _ => return Err("reliakit-derive: expected a type name after the item keyword".into()),
599    };
600    idx += 1;
601
602    let has_generics =
603        matches!(tokens.get(idx), Some(TokenTree::Punct(punct)) if punct.as_char() == '<');
604
605    let body = if has_generics {
606        // A generic item is rejected by validation before its body is used, and
607        // `idx` here points at the `<` parameters rather than the body, so don't
608        // try to read it. The placeholder body is never inspected.
609        match kind {
610            Kind::Struct => RawBody::Struct(Shape::Unit),
611            Kind::Enum => RawBody::Enum(Vec::new()),
612            Kind::Union => RawBody::Union,
613        }
614    } else {
615        match kind {
616            // The union body is never read: validation rejects unions outright.
617            Kind::Union => RawBody::Union,
618            Kind::Struct => match tokens.get(idx) {
619                Some(TokenTree::Group(group)) => match group.delimiter() {
620                    Delimiter::Brace => RawBody::Struct(Shape::Named(named_fields(group.stream()))),
621                    Delimiter::Parenthesis => {
622                        RawBody::Struct(Shape::Tuple(count_fields(group.stream())))
623                    }
624                    _ => return Err("reliakit-derive: unexpected struct body".into()),
625                },
626                Some(TokenTree::Punct(punct)) if punct.as_char() == ';' => {
627                    RawBody::Struct(Shape::Unit)
628                }
629                _ => return Err("reliakit-derive: unexpected struct body".into()),
630            },
631            Kind::Enum => match tokens.get(idx) {
632                Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => {
633                    RawBody::Enum(raw_variants(group.stream()))
634                }
635                _ => return Err("reliakit-derive: expected a braced enum body".into()),
636            },
637        }
638    };
639
640    Ok(Raw {
641        name,
642        has_generics,
643        saw_repr,
644        body,
645    })
646}
647
648/// Encode statements for a struct body (one `encode` call per field, in order).
649fn struct_encode_statements(shape: &Shape) -> String {
650    let mut body = String::new();
651    match shape {
652        Shape::Named(fields) => {
653            for field in fields {
654                body.push_str(&format!(
655                    "::reliakit_codec::CanonicalEncode::encode(&self.{field}, __writer)?;",
656                ));
657            }
658        }
659        Shape::Tuple(count) => {
660            for index in 0..*count {
661                body.push_str(&format!(
662                    "::reliakit_codec::CanonicalEncode::encode(&self.{index}, __writer)?;",
663                ));
664            }
665        }
666        Shape::Unit => {}
667    }
668    body
669}
670
671/// The decode body for a struct (returns `Ok(Self { .. })`).
672fn struct_decode_value(shape: &Shape) -> String {
673    let construct = match shape {
674        Shape::Named(fields) => {
675            let mut inner = String::new();
676            for field in fields {
677                inner.push_str(&format!(
678                    "{field}: ::reliakit_codec::CanonicalDecode::decode(__reader)?,",
679                ));
680            }
681            format!("Self {{ {inner} }}")
682        }
683        Shape::Tuple(count) => {
684            let mut inner = String::new();
685            for _ in 0..*count {
686                inner.push_str("::reliakit_codec::CanonicalDecode::decode(__reader)?,");
687            }
688            format!("Self({inner})")
689        }
690        Shape::Unit => "Self".to_string(),
691    };
692    format!("::core::result::Result::Ok({construct})")
693}
694
695/// Encode statements for an enum body: `match self { .. }`, where each arm
696/// writes the variant's `u32` declaration-index tag, then its fields in order.
697fn enum_encode_statements(variants: &[Variant]) -> String {
698    let mut arms = String::new();
699    for (index, variant) in variants.iter().enumerate() {
700        let tag = index as u32;
701        let name = &variant.name;
702        let tag_encode =
703            format!("::reliakit_codec::CanonicalEncode::encode(&{tag}u32, __writer)?;");
704        match &variant.shape {
705            Shape::Unit => {
706                arms.push_str(&format!("Self::{name} => {{ {tag_encode} }},"));
707            }
708            Shape::Tuple(count) => {
709                let mut pattern = String::new();
710                let mut encodes = String::new();
711                for i in 0..*count {
712                    if i > 0 {
713                        pattern.push_str(", ");
714                    }
715                    pattern.push_str(&format!("__f{i}"));
716                    encodes.push_str(&format!(
717                        "::reliakit_codec::CanonicalEncode::encode(__f{i}, __writer)?;",
718                    ));
719                }
720                arms.push_str(&format!(
721                    "Self::{name}({pattern}) => {{ {tag_encode} {encodes} }},"
722                ));
723            }
724            Shape::Named(fields) => {
725                let mut pattern = String::new();
726                let mut encodes = String::new();
727                for (i, field) in fields.iter().enumerate() {
728                    if i > 0 {
729                        pattern.push_str(", ");
730                    }
731                    // Bind each named field to a positional local to avoid any
732                    // collision with `__writer`.
733                    pattern.push_str(&format!("{field}: __f{i}"));
734                    encodes.push_str(&format!(
735                        "::reliakit_codec::CanonicalEncode::encode(__f{i}, __writer)?;",
736                    ));
737                }
738                arms.push_str(&format!(
739                    "Self::{name} {{ {pattern} }} => {{ {tag_encode} {encodes} }},"
740                ));
741            }
742        }
743    }
744    format!("match self {{ {arms} }}")
745}
746
747/// The decode body for an enum: read the `u32` tag, then build the matching
748/// variant. An unknown tag is an `invalid_value` codec error.
749fn enum_decode_value(name: &str, variants: &[Variant]) -> String {
750    let mut arms = String::new();
751    for (index, variant) in variants.iter().enumerate() {
752        let tag = index as u32;
753        let vname = &variant.name;
754        let construct = match &variant.shape {
755            Shape::Unit => format!("Self::{vname}"),
756            Shape::Tuple(count) => {
757                let mut inner = String::new();
758                for _ in 0..*count {
759                    inner.push_str("::reliakit_codec::CanonicalDecode::decode(__reader)?,");
760                }
761                format!("Self::{vname}({inner})")
762            }
763            Shape::Named(fields) => {
764                let mut inner = String::new();
765                for field in fields {
766                    inner.push_str(&format!(
767                        "{field}: ::reliakit_codec::CanonicalDecode::decode(__reader)?,",
768                    ));
769                }
770                format!("Self::{vname} {{ {inner} }}")
771            }
772        };
773        arms.push_str(&format!("{tag}u32 => {construct},"));
774    }
775
776    let message = format!("reliakit-derive: unknown variant tag for {name}");
777    format!(
778        "let __tag: u32 = ::reliakit_codec::CanonicalDecode::decode(__reader)?;\n\
779         ::core::result::Result::Ok(match __tag {{\n\
780         {arms}\n\
781         _ => return ::core::result::Result::Err(\
782         ::reliakit_codec::CodecError::invalid_value({message:?})),\n\
783         }})"
784    )
785}
786
787/// Reads enum variants into [`RawVariant`]s without validating them.
788fn raw_variants(stream: TokenStream) -> Vec<RawVariant> {
789    let mut variants = Vec::new();
790    for segment in top_level_segments(stream) {
791        if segment.is_empty() {
792            // A trailing comma produces an empty final segment.
793            continue;
794        }
795
796        // The variant name is the first identifier in the segment (any leading
797        // outer attributes are non-ident tokens and are skipped).
798        let name_idx = match segment
799            .iter()
800            .position(|t| matches!(t, TokenTree::Ident(_)))
801        {
802            Some(i) => i,
803            None => {
804                variants.push(RawVariant {
805                    name: String::new(),
806                    shape: Err("reliakit-derive: expected an enum variant name".into()),
807                    has_discriminant: false,
808                });
809                continue;
810            }
811        };
812        let name = match &segment[name_idx] {
813            TokenTree::Ident(ident) => ident.to_string(),
814            _ => unreachable!("position matched an ident"),
815        };
816
817        let mut has_discriminant = false;
818        let shape = match segment.get(name_idx + 1) {
819            None => Ok(Shape::Unit),
820            Some(TokenTree::Group(group)) => match group.delimiter() {
821                Delimiter::Parenthesis => Ok(Shape::Tuple(count_fields(group.stream()))),
822                Delimiter::Brace => Ok(Shape::Named(named_fields(group.stream()))),
823                _ => Err(format!(
824                    "reliakit-derive: unsupported syntax in enum variant `{name}`"
825                )),
826            },
827            // An explicit discriminant: record it; validation rejects it. The
828            // placeholder shape is never used.
829            Some(TokenTree::Punct(punct)) if punct.as_char() == '=' => {
830                has_discriminant = true;
831                Ok(Shape::Unit)
832            }
833            Some(_) => Err(format!(
834                "reliakit-derive: unsupported syntax in enum variant `{name}`"
835            )),
836        };
837
838        variants.push(RawVariant {
839            name,
840            shape,
841            has_discriminant,
842        });
843    }
844    variants
845}
846
/// Reports whether the body of an outer attribute (the tokens inside `#[...]`)
/// names `repr`, i.e. whether its first token is the identifier `repr`.
fn attr_is_repr(stream: TokenStream) -> bool {
    match stream.into_iter().next() {
        Some(TokenTree::Ident(first)) => first.to_string() == "repr",
        _ => false,
    }
}
851
852/// Collects the names of named fields in declaration order.
853fn named_fields(stream: TokenStream) -> Vec<String> {
854    let mut fields = Vec::new();
855    for segment in top_level_segments(stream) {
856        // The field name is the first ident immediately followed by a `:`
857        // (the field/type separator, which is an `Alone`-spaced colon).
858        for window in segment.windows(2) {
859            if let (TokenTree::Ident(ident), TokenTree::Punct(punct)) = (&window[0], &window[1]) {
860                if punct.as_char() == ':' && punct.spacing() == Spacing::Alone {
861                    fields.push(ident.to_string());
862                    break;
863                }
864            }
865        }
866    }
867    fields
868}
869
/// Counts the fields of a tuple body (the tokens inside `( ... )`).
///
/// Fields are separated by commas at the top level of the body. Commas inside
/// a generic argument list (`HashMap<K, V>`) are not separators. However,
/// `<`/`>` are plain `Punct` tokens rather than a delimited `Group`, so this
/// tracks angle-bracket depth itself instead of splitting on every top-level
/// comma. The `>` of a `->` arrow (as in `Box<dyn Fn(u8) -> u8>`) does not
/// close a bracket. A trailing comma does not add a field.
fn count_fields(stream: TokenStream) -> usize {
    let mut count = 0;
    // Current `<...>` nesting depth.
    let mut angle_depth = 0usize;
    // Whether the field currently being read has any tokens yet.
    let mut field_started = false;
    // Whether the previous token was a `-` glued to the next punct, i.e. the
    // first half of a `->` arrow.
    let mut after_arrow_minus = false;

    for token in stream {
        let mut is_arrow_minus = false;
        if let TokenTree::Punct(punct) = &token {
            match punct.as_char() {
                ',' if angle_depth == 0 => {
                    field_started = false;
                    after_arrow_minus = false;
                    continue;
                }
                '<' => angle_depth += 1,
                '>' if !after_arrow_minus => angle_depth = angle_depth.saturating_sub(1),
                '-' => is_arrow_minus = punct.spacing() == Spacing::Joint,
                _ => {}
            }
        }
        if !field_started {
            field_started = true;
            count += 1;
        }
        after_arrow_minus = is_arrow_minus;
    }
    count
}
877
/// Splits a token stream at its top-level commas, discarding the commas.
///
/// A comma nested inside a delimited group (`(..)`, `[..]`, `{..}`) belongs to
/// that single `Group` token and never splits. Commas between `<`/`>` generic
/// arguments are ordinary top-level puncts and do split. Every comma closes the
/// current segment, so consecutive commas yield an empty segment. An empty tail
/// is not emitted, so a trailing comma adds no final segment.
fn top_level_segments(stream: TokenStream) -> Vec<Vec<TokenTree>> {
    let mut finished: Vec<Vec<TokenTree>> = Vec::new();
    let mut open: Vec<TokenTree> = Vec::new();
    for tree in stream {
        let is_separator = matches!(&tree, TokenTree::Punct(p) if p.as_char() == ',');
        if is_separator {
            finished.push(core::mem::take(&mut open));
        } else {
            open.push(tree);
        }
    }
    if !open.is_empty() {
        finished.push(open);
    }
    finished
}
895
/// Returns a `::core::compile_error!` invocation reporting `message`.
///
/// The message is embedded through its `Debug` form, which quotes and escapes
/// it as a string literal. The `expect` records the invariant that this always
/// yields parseable tokens.
fn compile_error(message: &str) -> TokenStream {
    let invocation = format!("::core::compile_error!({message:?});");
    invocation
        .parse::<TokenStream>()
        .expect("compile_error message produced invalid tokens")
}
902
// Unit tests for the helpers that take and return plain Rust values
// (validation and generated-code string builders) rather than token streams.
#[cfg(test)]
mod tests {
    use super::*;

    fn enum_raw(variants: Vec<RawVariant>, saw_repr: bool, has_generics: bool) -> Raw {
        Raw {
            name: "E".to_string(),
            has_generics,
            saw_repr,
            body: RawBody::Enum(variants),
        }
    }

    fn unit_variant(name: &str) -> RawVariant {
        RawVariant {
            name: name.to_string(),
            shape: Ok(Shape::Unit),
            has_discriminant: false,
        }
    }

    // `Parsed` deliberately has no `Debug`, so these avoid `unwrap`/`unwrap_err`.
    fn err_of(raw: Raw) -> String {
        match validate(raw) {
            Err(message) => message,
            Ok(_) => panic!("expected validation to reject the item"),
        }
    }

    fn ok_of(raw: Raw) -> Parsed {
        match validate(raw) {
            Ok(parsed) => parsed,
            Err(message) => panic!("unexpected validation error: {message}"),
        }
    }

    #[test]
    fn rejects_union() {
        let raw = Raw {
            name: "U".to_string(),
            has_generics: false,
            saw_repr: false,
            body: RawBody::Union,
        };
        assert!(err_of(raw).contains("does not support unions"));
    }

    #[test]
    fn rejects_generic_struct() {
        let raw = Raw {
            name: "S".to_string(),
            has_generics: true,
            saw_repr: false,
            body: RawBody::Struct(Shape::Unit),
        };
        assert!(err_of(raw).contains("does not support generic types yet"));
    }

    #[test]
    fn rejects_generic_enum() {
        let raw = enum_raw(vec![unit_variant("A")], false, true);
        assert!(err_of(raw).contains("does not support generic types yet"));
    }

    #[test]
    fn rejects_repr_enum() {
        let raw = enum_raw(vec![unit_variant("A")], true, false);
        assert!(err_of(raw).contains("does not support `#[repr(...)]` on enums"));
    }

    #[test]
    fn rejects_explicit_discriminant() {
        let raw = enum_raw(
            vec![RawVariant {
                name: "A".to_string(),
                shape: Ok(Shape::Unit),
                has_discriminant: true,
            }],
            false,
            false,
        );
        let err = err_of(raw);
        assert!(err.contains("does not support explicit enum discriminants"));
        assert!(err.contains("`A = ...`"));
    }

    #[test]
    fn rejects_empty_enum() {
        let raw = enum_raw(vec![], false, false);
        assert!(err_of(raw).contains("cannot derive for an empty enum"));
    }

    #[test]
    fn rejects_unsupported_variant_syntax() {
        let raw = enum_raw(
            vec![RawVariant {
                name: "A".to_string(),
                shape: Err("reliakit-derive: unsupported syntax in enum variant `A`".to_string()),
                has_discriminant: false,
            }],
            false,
            false,
        );
        assert!(err_of(raw).contains("unsupported syntax"));
    }

    #[test]
    fn accepts_struct() {
        let raw = Raw {
            name: "S".to_string(),
            has_generics: false,
            saw_repr: false,
            body: RawBody::Struct(Shape::Named(vec!["x".to_string()])),
        };
        let parsed = ok_of(raw);
        assert_eq!(parsed.name, "S");
        assert!(matches!(parsed.body, Body::Struct(Shape::Named(_))));
    }

    #[test]
    fn accepts_enum_preserving_variant_order() {
        let raw = enum_raw(
            vec![
                unit_variant("A"),
                RawVariant {
                    name: "B".to_string(),
                    shape: Ok(Shape::Tuple(1)),
                    has_discriminant: false,
                },
                RawVariant {
                    name: "C".to_string(),
                    shape: Ok(Shape::Named(vec!["id".to_string()])),
                    has_discriminant: false,
                },
            ],
            false,
            false,
        );
        match ok_of(raw).body {
            Body::Enum(variants) => {
                let names: Vec<&str> = variants.iter().map(|v| v.name.as_str()).collect();
                assert_eq!(names, ["A", "B", "C"]);
            }
            Body::Struct(_) => panic!("expected an enum body"),
        }
    }

    // The enum decode body is a pure string builder, so its tag-to-variant
    // mapping and per-field reads can be pinned without a token stream.
    #[test]
    fn enum_decode_value_maps_tags_in_declaration_order() {
        let variants = vec![
            Variant {
                name: "A".to_string(),
                shape: Shape::Unit,
            },
            Variant {
                name: "B".to_string(),
                shape: Shape::Tuple(2),
            },
            Variant {
                name: "C".to_string(),
                shape: Shape::Named(vec!["x".to_string()]),
            },
        ];
        let body = enum_decode_value("E", &variants);
        let read = "::reliakit_codec::CanonicalDecode::decode(__reader)?";
        assert!(body.starts_with(&format!("let __tag: u32 = {read};")));
        assert!(body.contains("0u32 => Self::A,"));
        assert!(body.contains(&format!("1u32 => Self::B({read},{read},),")));
        assert!(body.contains(&format!("2u32 => Self::C {{ x: {read}, }},")));
        assert!(body.contains("invalid_value(\"reliakit-derive: unknown variant tag for E\")"));
    }

    #[test]
    fn csv_rejects_non_named_structs_and_enums() {
        assert!(
            csv_named_fields(&Body::Struct(Shape::Tuple(2)), "CsvEncode")
                .unwrap_err()
                .contains("requires a struct with named fields")
        );
        assert!(csv_named_fields(&Body::Struct(Shape::Unit), "CsvDecode")
            .unwrap_err()
            .contains("named fields"));
        let enum_body = Body::Enum(vec![Variant {
            name: "A".to_string(),
            shape: Shape::Unit,
        }]);
        assert!(csv_named_fields(&enum_body, "CsvEncode")
            .unwrap_err()
            .contains("does not support enums"));
    }

    #[test]
    fn csv_named_struct_builds_methods() {
        let body = Body::Struct(Shape::Named(vec!["id".to_string(), "r#type".to_string()]));
        let fields = csv_named_fields(&body, "CsvEncode").expect("named struct accepted");
        let enc = csv_encode_methods(fields);
        assert!(enc.contains("__header.push(\"id\")"));
        // The `r#` prefix is dropped for the column name but kept for field access.
        assert!(enc.contains("__header.push(\"type\")"));
        assert!(enc.contains("encode_field(&self.id)"));
        assert!(enc.contains("encode_field(&self.r#type)"));
        let dec = csv_decode_method(fields);
        assert!(dec.contains("__fields.len() != 2"));
        assert!(dec.contains("__fields[0]"));
        assert!(dec.contains("at_field(1)"));
    }
}