fixedlength_format_parser/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use proc_macro2::{Ident, Span};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{parse_macro_input, Data, DeriveInput, Expr, Lit};
7
8#[derive(Debug)]
9struct RecordVariant {
10    kind: String,
11    enum_name: Ident,
12    variant_name: Ident,
13    fields: Vec<RecordField>,
14}
15
16impl ToTokens for RecordVariant {
17    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
18        let kind = &self.kind;
19        let enum_name = &self.enum_name;
20        let variant_name = &self.variant_name;
21        let fields = &self.fields;
22        tokens.append_all(quote! {
23            #kind => {
24                Ok(#enum_name::#variant_name {
25                    #(#fields),*
26                })
27            }
28        });
29    }
30}
31
32#[derive(Debug)]
33struct RecordField {
34    /// The name of the field in the enum variant.
35    name: Ident,
36    /// The point in the line at which this record begins.
37    from: usize,
38    /// The point in the line at which this record end (exclusive).
39    to: usize,
40    /// The kind of this record.
41    record_kind: String,
42    /// The ident of the error enum.
43    error_ident: Ident,
44}
45
46impl ToTokens for RecordField {
47    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
48        let name = &self.name;
49        let from = self.from;
50        let to = self.to;
51        let record_kind = &self.record_kind;
52        let error_ident = &self.error_ident;
53        let name_str = name.to_string();
54        tokens.append_all(quote! {
55            #name: s[#from..#to].parse().map_err(|_| #error_ident::FailedToParse {
56                record_type: #record_kind.to_string(),
57                field: #name_str.to_string(),
58            })?
59        });
60    }
61}
62
63#[proc_macro_derive(
64    FixedLengthFormatParser,
65    attributes(record_type, field_starts, field_ends, field_length)
66)]
67pub fn fixed_length_format_parser(item: TokenStream) -> TokenStream {
68    let input = parse_macro_input!(item as DeriveInput);
69    let target_ident = input.ident;
70    let error_ident = Ident::new(&format!("{}ParseError", target_ident), Span::call_site());
71    let visibility = input.vis;
72
73    let mut record_type_len = 0;
74    let mut known_variants = vec![];
75
76    // Validate that all record types are specified
77    // Validate that all record types are the same length
78
79    if let Data::Enum(enum_data) = input.data {
80        for variant in enum_data.variants {
81            assert!(variant.discriminant.is_none(), "Enum variants must not have a discriminant set to be built into a FixedLengthFormatParser.");
82
83            let mut current_cursor = 0;
84
85            for attr in variant.attrs {
86                let attr = attr.meta.require_name_value().unwrap();
87                if *attr.path.get_ident().unwrap() != "record_type" {
88                    panic!("Only the `record_type` attribute is expected on an enum variant.");
89                }
90                match &attr.value {
91                    Expr::Lit(literal) => {
92                        match &literal.lit {
93                            Lit::Str(st) => {
94                                let record_type = st.value();
95                                if record_type_len == 0 {
96                                    record_type_len = record_type.len();
97                                } else if record_type_len != record_type.len() {
98                                    panic!("All `record_type`s must be the same length.");
99                                }
100
101                                known_variants.push(RecordVariant {
102                                    kind: record_type.clone(),
103                                    enum_name: target_ident.clone(),
104                                    variant_name: variant.ident.clone(),
105                                    fields: variant.fields.iter().map(|f| {
106                                        let mut from = current_cursor;
107                                        let mut length = 0;
108                                        let mut to = current_cursor;
109
110                                        for attr in &f.attrs {
111                                            let attr = attr.meta.require_name_value().unwrap();
112                                            match attr.path.get_ident().unwrap().to_string().as_str() {
113                                                "field_starts" => {
114                                                    from = get_number(&attr.value);
115                                                    to = from + length;
116                                                },
117                                                "field_ends" => {
118                                                    to = get_number(&attr.value);
119                                                    length = to - from;
120                                                    current_cursor = to;
121                                                },
122                                                "field_length" => {
123                                                    length = get_number(&attr.value);
124                                                    to = from + length;
125                                                    current_cursor = to;
126                                                },
127                                                _ => {/* some other ident we don't care about */},
128                                            }
129                                        }
130
131                                        assert_ne!(length, 0, "`{}` field length is zero!", f.ident.as_ref().unwrap());
132
133                                        RecordField {
134                                            name: f.ident.clone().expect("the enum variants must be full structs, not tuples."),
135                                            from,
136                                            to,
137                                            record_kind: record_type.clone(),
138                                            error_ident: error_ident.clone(),
139                                        }
140                                    }).collect(),
141                                });
142                            },
143                            _ => panic!("`record_type` must specify a string literal, e.g.: #[record_type = \"HD\"]"),
144                        }
145                    },
146                    _ => panic!("`record_type` must specify a string literal, e.g.: #[record_type = \"HD\"]"),
147                }
148            }
149        }
150    } else {
151        panic!("FixedLengthFormatParser can only derive from enums.");
152    }
153
154    if record_type_len == 0 {
155        panic!("No `record_type`s have been specified, so the parser cannot be built.");
156    }
157
158    let expanded = quote! {
159        #[derive(Debug)]
160        #visibility enum #error_ident {
161            InvalidRecordType,
162            FailedToParse {
163                record_type: String,
164                field: String,
165            },
166        }
167        impl ::std::error::Error for #error_ident {}
168        impl ::std::fmt::Display for #error_ident {
169            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
170                match self {
171                    Self::InvalidRecordType => write!(f, "invalid record type"),
172                    Self::FailedToParse { record_type, field } => write!(f, "failed to parse field `{field}` in {record_type} record."),
173                }
174            }
175        }
176
177        impl ::std::str::FromStr for #target_ident {
178            type Err = #error_ident;
179
180            fn from_str(s: &str) -> Result<Self, Self::Err> {
181                let record_type = &s[0..#record_type_len];
182
183                match record_type {
184                    #(#known_variants),*
185                    _ => Err(#error_ident::InvalidRecordType),
186                }
187            }
188        }
189    };
190
191    TokenStream::from(expanded)
192}
193
194fn get_number(expr: &Expr) -> usize {
195    match &expr {
196        Expr::Lit(literal) => match &literal.lit {
197            Lit::Int(i) => i
198                .base10_parse()
199                .expect("expected number for field attribute"),
200            _ => panic!("expected number for field attribute"),
201        },
202        _ => panic!("expected number for field attribute"),
203    }
204}