line_cutter_macros/
lib.rs

1extern crate proc_macro;
2
3use quote::quote;
4use syn::Type;
5
6/// Representation of a field in a positionally encoded document
7struct PositionalField {
8    ident: syn::Ident,
9    start: usize, // Note: start is not validated (min should be 1)
10    size: usize,  // Note: size is not validated (min should be 1)
11    rust_type: syn::Type,
12    custom_decoder: Option<syn::Path>, // Optional custom decoder function
13    custom_encoder: Option<syn::Path>, // Optional custom encoder function
14}
15
16impl PositionalField {
17    /// returns the start index of the field within the encoded String (0-indexed)
18    fn get_start_index(&self) -> usize {
19        self.start - 1
20    }
21
22    /// returns the end index of the field within the encoded String (0-indexed)
23    fn get_end_index(&self) -> usize {
24        self.get_start_index() + self.size
25    }
26}
27
/// Rust field types the derive knows how to decode and encode without a
/// custom decoder/encoder.
///
/// The enum is a plain tag with no payload, so it derives the cheap standard
/// traits (`Copy`, equality, `Debug`) to make it easy to pass around, compare
/// and print in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SupportedRustTypes {
    /// `bool`
    Boolean,
    /// `i64`
    Int64,
    /// `i32`
    Int32,
    /// `u64`
    UInt64,
    /// `u32`
    UInt32,
    /// `u16`
    UInt16,
    /// `u8`
    UInt8,
    /// `String`
    String,
    /// `chrono::NaiveDate`
    NaiveDate,
    /// `chrono::NaiveDateTime`
    NaiveDateTime,
    /// `chrono::NaiveTime`
    NaiveTime,
    /// `chrono::TimeDelta`
    ChronoTimeDelta,
}
43
44impl TryFrom<&str> for SupportedRustTypes {
45    type Error = ();
46
47    fn try_from(value: &str) -> Result<Self, ()> {
48        match value {
49            "bool" => Ok(SupportedRustTypes::Boolean),
50            "i32" => Ok(SupportedRustTypes::Int32),
51            "i64" => Ok(SupportedRustTypes::Int64),
52            "u64" => Ok(SupportedRustTypes::UInt64),
53            "u32" => Ok(SupportedRustTypes::UInt32),
54            "u16" => Ok(SupportedRustTypes::UInt16),
55            "u8" => Ok(SupportedRustTypes::UInt8),
56            "String" => Ok(SupportedRustTypes::String),
57            "NaiveDate" => Ok(SupportedRustTypes::NaiveDate),
58            "chrono::NaiveDate" => Ok(SupportedRustTypes::NaiveDate),
59            "NaiveDateTime" => Ok(SupportedRustTypes::NaiveDateTime),
60            "chrono::NaiveDateTime" => Ok(SupportedRustTypes::NaiveDateTime),
61            "NaiveTime" => Ok(SupportedRustTypes::NaiveTime),
62            "chrono::NaiveTime" => Ok(SupportedRustTypes::NaiveTime),
63            "TimeDelta" => Ok(SupportedRustTypes::ChronoTimeDelta),
64            "chrono::TimeDelta" => Ok(SupportedRustTypes::ChronoTimeDelta),
65            _ => Err(()),
66        }
67    }
68}
69
70struct ParsedRustType {
71    optional: bool,
72    rust_type: SupportedRustTypes,
73}
74
75#[proc_macro_derive(PositionalText, attributes(positional_field))]
76pub fn positional_text(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
77    let input = syn::parse_macro_input!(input as syn::DeriveInput);
78    let struct_ident = &input.ident;
79    let struct_name: String = struct_ident.to_string();
80
81    // fail if data type is enum or union (we only support strcut)
82    let struct_data: &syn::DataStruct = match &input.data {
83        syn::Data::Struct(data_struct) => data_struct,
84        syn::Data::Enum(data_enum) => {
85            return syn::Error::new_spanned(
86                data_enum.enum_token,
87                "Enum is not supported. Only `struct`s are supoorted.",
88            )
89            .to_compile_error()
90            .into();
91        }
92        syn::Data::Union(data_union) => {
93            return syn::Error::new_spanned(
94                data_union.union_token,
95                "Union is not supported. Only `struct`s are supoorted.",
96            )
97            .to_compile_error()
98            .into();
99        }
100    };
101
102    // field_decoders are the code that converts input String into struct fields
103    // Will contain token streams for each field's decoding logic
104    let mut field_decoders = Vec::new();
105    // Will contain token streams for each field's encoding logic
106    let mut field_encoders: Vec<proc_macro2::TokenStream> = Vec::new();
107    // Will contain token streams for each field's validation logic
108    let mut field_validators: Vec<proc_macro2::TokenStream> = Vec::new();
109
110    // fail if the struct contains a non-named fields
111    let named_fields = match &struct_data.fields {
112        syn::Fields::Named(fields) => fields,
113        _ => panic!("Only named fields are supported"),
114    };
115    // computed length of the record
116    let mut record_length: usize = 0;
117    // Process each field of the struct
118    for field in &named_fields.named {
119        // parse the struct's field
120        let positional_field = match parse_field(field) {
121            Ok(positional_field) => positional_field,
122            Err(err_token_stream) => return err_token_stream.into(),
123        };
124        if positional_field.get_end_index() > record_length {
125            record_length = positional_field.get_end_index();
126        }
127
128        let field_decoder_res = gen_field_decoder(&positional_field, &struct_name);
129        match field_decoder_res {
130            Err(err) => return err.into(),
131            Ok(field_decoder) => field_decoders.push(field_decoder),
132        }
133
134        let field_encoder_res = gen_field_encoder(&positional_field);
135        match field_encoder_res {
136            Err(err) => return err.into(),
137            Ok(field_encoder) => field_encoders.push(field_encoder),
138        }
139
140        let field_validator_res = gen_field_validator(&positional_field);
141        match field_validator_res {
142            Err(err) => return err.into(),
143            Ok(field_validator) => field_validators.push(field_validator),
144        }
145    }
146
147    let expanded = quote! {
148
149        fn hhmmss_to_timedelta(hhmmss: &str) -> Result<chrono::TimeDelta, Box<dyn std::error::Error>> {
150            if hhmmss.len() != 6 {
151                return Err("Invalid HHMMSS format".into());
152            }
153
154            let hours: i64 = hhmmss[0..2].parse()?;
155            let minutes: i64 = hhmmss[2..4].parse()?;
156            if minutes > 59 {
157                return Err("Minutes must be <= 59".into());
158            }
159            let seconds: i64 = hhmmss[4..6].parse()?;
160            if seconds > 59 {
161                return Err("Seconds must be <= 59".into());
162            }
163
164            Ok(chrono::TimeDelta::hours(hours) + chrono::TimeDelta::minutes(minutes) + chrono::TimeDelta::seconds(seconds))
165        }
166
167        fn timedelta_to_hhmmss(duration: Option<chrono::TimeDelta>) -> String {
168            match duration {
169                Some(d) => {
170                    let total_seconds = d.num_seconds();
171                    let hours: i64 = total_seconds / 3600;
172                    let minutes: i64 = (total_seconds % 3600) / 60;
173                    let seconds: i64 = total_seconds % 60;
174                    format!("{:02}{:02}{:02}", hours, minutes, seconds)
175                },
176                None => "      ".to_string()
177            }
178        }
179
180        #[automatically_derived]
181        impl line_cutter::PositionalEncoded for #struct_ident {
182            fn decode(s: &str) -> Result<Self, line_cutter::DecodeError> {
183                // First, validate and pad the input string to the correct length
184                let s_str: String = if s.chars().count() > #record_length {
185                    Err(
186                        line_cutter::DecodeError {
187                            start: 0,
188                            end: s.chars().count(),
189                            input_value: s.to_string(),
190                            field_name: None,
191                            record_name: Some(stringify!(#struct_ident).to_string()),
192                            error: format!(
193                                "Invalid record length. Expected {} but found {}.",
194                                #record_length,
195                                s.chars().count(),
196                            ),
197                        }
198                    )?
199                } else {
200                    // Ensure it's the correct length by padding with spaces if needed
201                    // WARN: This could cause the decoding to fail later if added spaces are invalid.
202                    //       This is technically creating data. The records should be the correct length...
203                    let char_count = s.chars().count();
204                    if char_count < #record_length {
205                        format!("{}{}", s, " ".repeat(#record_length - char_count))
206                    } else {
207                        s.to_string()
208                    }
209                };
210
211                // Convert to Vec<char> for proper Unicode-aware character indexing
212                let chars: Vec<char> = s_str.chars().collect();
213
214                Ok(
215                    #struct_ident {
216                        #(#field_decoders)*
217                    }
218                )
219            }
220
221            fn encode(&self) -> String {
222                let mut result = String::new();
223                #(#field_encoders)*
224                result
225            }
226
227            fn validate(&self) -> Result<(), Vec<line_cutter::ValidationError>> {
228                let mut errors = Vec::new();
229                #(#field_validators)*
230                if errors.is_empty() {
231                    Ok(())
232                } else {
233                    Err(errors)
234                }
235            }
236        }
237    };
238    proc_macro::TokenStream::from(expanded)
239}
240
241/// Parses a field and it's attributes, returning a PositionalField or Error tokens to return.
242fn parse_field(input_field: &syn::Field) -> Result<PositionalField, proc_macro2::TokenStream> {
243    let field_ident = input_field.ident.as_ref().unwrap();
244    let rust_type = input_field.ty.clone();
245    let mut start: Option<usize> = None;
246    let mut size: Option<usize> = None;
247    let mut custom_decoder: Option<syn::Path> = None;
248    let mut custom_encoder: Option<syn::Path> = None;
249
250    let pos_field_attr = input_field
251        .attrs
252        .iter()
253        .find(|attr| attr.path().is_ident("positional_field"));
254
255    let pos_field_attr = pos_field_attr.ok_or(
256        syn::Error::new_spanned(
257            field_ident,
258            "Missing positional_field attribute, e.g. `#[positional_field(start = 1, size = 3)]`",
259        )
260        .to_compile_error(),
261    )?;
262
263    // extract the metadata in the parameter attributes (start, size, decoder, encoder)
264    pos_field_attr
265        .parse_nested_meta(|meta| {
266            if meta.path.is_ident("start") {
267                let value = meta.value()?; // this parses the `=`
268                let attr_value: syn::LitInt = value.parse()?; // this parses the value after =
269                start = Some(attr_value.base10_parse()?);
270                Ok(())
271            } else if meta.path.is_ident("size") {
272                let value = meta.value()?; // this parses the `=`
273                let attr_value: syn::LitInt = value.parse()?; // this parses the value after =
274                size = Some(attr_value.base10_parse()?);
275                Ok(())
276            } else if meta.path.is_ident("decoder") {
277                let value = meta.value()?; // this parses the `=`
278                let attr_value: syn::LitStr = value.parse()?; // this parses the string literal
279                custom_decoder = Some(attr_value.parse()?);
280                Ok(())
281            } else if meta.path.is_ident("encoder") {
282                let value = meta.value()?; // this parses the `=`
283                let attr_value: syn::LitStr = value.parse()?; // this parses the string literal
284                custom_encoder = Some(attr_value.parse()?);
285                Ok(())
286            } else {
287                Err(meta.error(format!(
288                    "unrecognized positional_field attribute {}",
289                    meta.path.get_ident().unwrap()
290                )))
291            }
292        })
293        .map_err(|err| err.into_compile_error())?;
294
295    // check that all the required fields are present
296    let start = start.ok_or(
297        syn::Error::new_spanned(
298            field_ident,
299            "`start` argument must be specified, e.g. `#[positional_field(start = 1, ...)]`",
300        )
301        .to_compile_error(),
302    )?;
303
304    let size = size.ok_or(
305        syn::Error::new_spanned(
306            field_ident,
307            "`size` argument must be specified, e.g. `#[positional_field(size = 3, ...)]`",
308        )
309        .to_compile_error(),
310    )?;
311
312    Ok(PositionalField {
313        ident: field_ident.clone(),
314        start,
315        size,
316        rust_type,
317        custom_decoder,
318        custom_encoder,
319    })
320}
321
322/// Parses a Rust type from a syn::Type into a ParsedRustType struct.
323/// Option<T> is supported where T is any of the above types.
324fn parse_rust_type(raw_type: &syn::Type) -> Result<ParsedRustType, proc_macro2::TokenStream> {
325    if let Type::Path(type_path) = raw_type {
326        if type_path.qself.is_some() {
327            return Err(syn::Error::new_spanned(
328                raw_type,
329                "Unsupported field type: qualified self types are not supported",
330            )
331            .to_compile_error());
332        }
333
334        if let Some(segment) = type_path.path.segments.last() {
335            let type_name = segment.ident.to_string();
336            if type_name == "Option" {
337                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
338                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
339                        // Recurse for the inner type, but we know it's optional.
340                        let inner_parsed = parse_rust_type(inner_ty)?;
341                        if inner_parsed.optional {
342                            return Err(syn::Error::new_spanned(
343                                raw_type,
344                                "Nested Options are not supported",
345                            )
346                            .to_compile_error());
347                        }
348                        return Ok(ParsedRustType {
349                            optional: true,
350                            rust_type: inner_parsed.rust_type,
351                        });
352                    }
353                }
354                return Err(syn::Error::new_spanned(
355                    raw_type,
356                    "Option without a type parameter is not supported",
357                )
358                .to_compile_error());
359            } else {
360                // Not an option, so it's a regular type.
361                // To support `chrono::NaiveDate`, we can convert the whole type path to a string.
362                let type_as_string = quote!(#raw_type).to_string().replace(' ', "");
363                let rust_type = SupportedRustTypes::try_from(type_as_string.as_str())
364                    .or_else(|_| SupportedRustTypes::try_from(type_name.as_str()))
365                    .map_err(|_| {
366                        syn::Error::new_spanned(
367                            raw_type,
368                            format!("Unsupported field type: {}", type_as_string),
369                        )
370                        .to_compile_error()
371                    })?;
372
373                return Ok(ParsedRustType {
374                    optional: false,
375                    rust_type,
376                });
377            }
378        }
379    }
380    Err(syn::Error::new_spanned(raw_type, "Unsupported field type.").to_compile_error())
381}
382
383/// Generates decoders based on the rust field type.
384/// Creates logic to decode a field from a fixed-width string to a Rust type.
385/// Uses the field's start and size attributes to extract the correct substring from the input string.
386fn gen_field_decoder(
387    positional_field: &PositionalField,
388    struct_name: &str,
389) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
390    let field_ident = &positional_field.ident;
391    let field_name = field_ident.to_string();
392    let start_index = positional_field.get_start_index();
393    let end_index = positional_field.get_end_index();
394
395    // If custom decoder is provided, use it instead of standard decoding
396    if let Some(custom_decoder) = &positional_field.custom_decoder {
397        return Ok(quote! {
398            #field_ident: {
399                let val: String = chars[#start_index..#end_index].iter().collect();
400                #custom_decoder(&val)
401                    .map_err(|e| line_cutter::DecodeError {
402                        start: #start_index,
403                        end: #end_index,
404                        input_value: val.to_string(),
405                        field_name: Some(#field_name.to_string()),
406                        record_name: Some(#struct_name.to_string()),
407                        error: e.to_string(),
408                    })?
409            },
410        });
411    }
412
413    let parsed_type = parse_rust_type(&positional_field.rust_type)?;
414
415    let inner_parser = match parsed_type.rust_type {
416        SupportedRustTypes::Boolean => quote! {
417            match val.trim() {
418                "Y" => Ok(true),
419                "N" => Ok(false),
420                _ => Err(format!("Invalid boolean value: '{}'", val)),
421            }
422        },
423        SupportedRustTypes::Int32 => {
424            quote! { val.trim().parse::<i32>().map_err(|e| format!("Failed to parse i32: {:?}", e)) }
425        }
426        SupportedRustTypes::Int64 => {
427            quote! { val.trim().parse::<i64>().map_err(|e| format!("Failed to parse i64: {:?}", e)) }
428        }
429        SupportedRustTypes::UInt32 => {
430            quote! { val.trim().parse::<u32>().map_err(|e| format!("Failed to parse u32: {:?}", e)) }
431        }
432        SupportedRustTypes::UInt64 => {
433            quote! { val.trim().parse::<u64>().map_err(|e| format!("Failed to parse u64: {:?}", e)) }
434        }
435        SupportedRustTypes::UInt16 => {
436            quote! { val.trim().parse::<u16>().map_err(|e| format!("Failed to parse u16: {:?}", e)) }
437        }
438        SupportedRustTypes::UInt8 => {
439            quote! { val.trim().parse::<u8>().map_err(|e| format!("Failed to parse u8: {:?}", e)) }
440        }
441        SupportedRustTypes::String => quote! {
442            val.trim().parse::<String>()
443        },
444        SupportedRustTypes::NaiveDate => quote! {
445            chrono::NaiveDate::parse_from_str(val.trim(), "%Y%m%d").map_err(|e| format!("Invalid date format: {:?}", e))
446        },
447        SupportedRustTypes::NaiveDateTime => quote! {
448            chrono::NaiveDateTime::parse_from_str(val.trim(), "%Y%m%d%H%M%S").map_err(|e| format!("Invalid datetime format: {:?}", e))
449        },
450        SupportedRustTypes::NaiveTime => quote! {
451            chrono::NaiveTime::parse_from_str(val.trim(), "%H%M%S").map_err(|e| format!("Invalid time format: {:?}", e))
452        },
453        SupportedRustTypes::ChronoTimeDelta => quote! {
454            hhmmss_to_timedelta(val.trim()).map_err(|e| format!("Invalid duration format: {:?}", e))
455        },
456    };
457
458    let final_decoder = if parsed_type.optional {
459        quote! {
460            {
461                let val: String = chars[#start_index..#end_index].iter().collect();
462                if val.trim().is_empty() {
463                    None
464                } else {
465                    let parsed = #inner_parser.map_err(|e| line_cutter::DecodeError {
466                        start: #start_index,
467                        end: #end_index,
468                        input_value: val.to_string(),
469                        field_name: Some(#field_name.to_string()),
470                        record_name: Some(#struct_name.to_string()),
471                        error: e.to_string(),
472                    })?;
473                    Some(parsed)
474                }
475            }
476        }
477    } else {
478        quote! {
479            {
480                let val: String = chars[#start_index..#end_index].iter().collect();
481                #inner_parser.map_err(|e| line_cutter::DecodeError {
482                    start: #start_index,
483                    end: #end_index,
484                    input_value: val.to_string(),
485                    field_name: Some(#field_name.to_string()),
486                    record_name: Some(#struct_name.to_string()),
487                    error: e.to_string(),
488                })?
489            }
490        }
491    };
492
493    Ok(quote! {
494        #field_ident: #final_decoder,
495    })
496}
497
498/// Generates encoders based on the Rust field type.
499/// Creates logic to encode a field from a Rust type to a fixed-width string.
500fn gen_field_encoder(
501    positional_field: &PositionalField,
502) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
503    let field_ident = &positional_field.ident;
504    let field_size = positional_field.size;
505
506    // If custom encoder is provided, use it instead of standard encoding
507    if let Some(custom_encoder) = &positional_field.custom_encoder {
508        return Ok(quote! {
509            result.push_str(&#custom_encoder(&self.#field_ident));
510        });
511    }
512
513    let parsed_type = parse_rust_type(&positional_field.rust_type)?;
514
515    let encoder_logic = match (parsed_type.optional, &parsed_type.rust_type) {
516        // Required fields
517        (false, SupportedRustTypes::String) => quote! {
518            result.push_str(&format!("{:<width$}", self.#field_ident, width = #field_size));
519        },
520        (false, SupportedRustTypes::Boolean) => quote! {
521            result.push_str(if self.#field_ident { "Y" } else { "N" });
522        },
523        (false, SupportedRustTypes::Int32 | SupportedRustTypes::Int64) => quote! {
524            if self.#field_ident >= 0 {
525                result.push_str(&format!("{:0>width$}", self.#field_ident, width = #field_size));
526            } else {
527                result.push_str(&format!("-{:0>width$}", self.#field_ident.abs(), width = #field_size - 1));
528            }
529        },
530        (
531            false,
532            SupportedRustTypes::UInt32
533            | SupportedRustTypes::UInt64
534            | SupportedRustTypes::UInt16
535            | SupportedRustTypes::UInt8,
536        ) => quote! {
537            result.push_str(&format!("{:0>width$}", self.#field_ident, width = #field_size));
538        },
539        (false, SupportedRustTypes::NaiveDate) => quote! {
540            result.push_str(&self.#field_ident.format("%Y%m%d").to_string());
541        },
542        (false, SupportedRustTypes::NaiveDateTime) => quote! {
543            result.push_str(&self.#field_ident.format("%Y%m%d%H%M%S").to_string());
544        },
545        (false, SupportedRustTypes::NaiveTime) => quote! {
546            result.push_str(&self.#field_ident.format("%H%M%S").to_string());
547        },
548        (false, SupportedRustTypes::ChronoTimeDelta) => quote! {
549            result.push_str(&timedelta_to_hhmmss(Some(self.#field_ident)));
550        },
551
552        // Optional fields
553        (true, SupportedRustTypes::String) => quote! {
554            result.push_str(&format!("{:<width$}", self.#field_ident.as_deref().unwrap_or(""), width = #field_size));
555        },
556        (true, SupportedRustTypes::Boolean) => quote! {
557            result.push_str(match self.#field_ident { Some(true) => "Y", Some(false) => "N", None => " " });
558        },
559        (true, SupportedRustTypes::Int32 | SupportedRustTypes::Int64) => quote! {
560            result.push_str(&match self.#field_ident {
561                Some(num) => {
562                    if num >= 0 {
563                        format!("{:0>width$}", num, width = #field_size)
564                    } else {
565                        format!("-{:0>width$}", num.abs(), width = #field_size - 1)
566                    }
567                },
568                None => " ".repeat(#field_size),
569            });
570        },
571        (
572            true,
573            SupportedRustTypes::UInt32
574            | SupportedRustTypes::UInt64
575            | SupportedRustTypes::UInt16
576            | SupportedRustTypes::UInt8,
577        ) => quote! {
578            result.push_str(&match self.#field_ident {
579                Some(num) => format!("{:0>width$}", num, width = #field_size),
580                None => " ".repeat(#field_size),
581            });
582        },
583        (true, SupportedRustTypes::NaiveDate) => quote! {
584            result.push_str(&match self.#field_ident {
585                Some(date) => date.format("%Y%m%d").to_string(),
586                None => " ".repeat(8),
587            });
588        },
589        (true, SupportedRustTypes::NaiveDateTime) => quote! {
590            result.push_str(&match self.#field_ident {
591                Some(dt) => dt.format("%Y%m%d%H%M%S").to_string(),
592                None => " ".repeat(14),
593            });
594        },
595        (true, SupportedRustTypes::NaiveTime) => quote! {
596            result.push_str(&match self.#field_ident {
597                Some(time) => time.format("%H%M%S").to_string(),
598                None => " ".repeat(6),
599            });
600        },
601        (true, SupportedRustTypes::ChronoTimeDelta) => quote! {
602            result.push_str(&timedelta_to_hhmmss(self.#field_ident));
603        },
604    };
605
606    Ok(encoder_logic)
607}
608
609/// Generates validators for each field
610fn gen_field_validator(
611    positional_field: &PositionalField,
612) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
613    let field_ident = &positional_field.ident;
614    let field_name = field_ident.to_string();
615    let field_size = positional_field.size;
616    let parsed_type = parse_rust_type(&positional_field.rust_type)?;
617
618    let validator_logic = match (parsed_type.optional, &parsed_type.rust_type) {
619        (false, SupportedRustTypes::String) => quote! {
620            if self.#field_ident.len() > #field_size {
621                errors.push(line_cutter::ValidationError {
622                    field_name: #field_name.to_string(),
623                    message: format!("max length is {} but found {}", #field_size, self.#field_ident.len()),
624                });
625            }
626        },
627        (true, SupportedRustTypes::String) => quote! {
628            if let Some(val) = &self.#field_ident {
629                if val.len() > #field_size {
630                    errors.push(line_cutter::ValidationError {
631                        field_name: #field_name.to_string(),
632                        message: format!("max length is {} but found {}", #field_size, val.len()),
633                    });
634                }
635            }
636        },
637        _ => quote! {},
638    };
639
640    Ok(validator_logic)
641}