lox_derive/
lib.rs

1// SPDX-FileCopyrightText: 2024 Andrei Zisu <matzipan@gmail.com>
2// SPDX-FileCopyrightText: 2025 Helge Eichhorn <git@helgeeichhorn.de>
3//
4// SPDX-License-Identifier: MPL-2.0
5
6use std::iter::zip;
7
8use darling::{FromDeriveInput, FromMeta, util::Flag};
9use proc_macro_crate::{FoundCrate, crate_name};
10use proc_macro2::{Ident, Span};
11use quote::{ToTokens, format_ident, quote};
12use syn::{
13    Data, DeriveInput, Error, Field, Fields, GenericParam, Generics, Index, parse_macro_input,
14    parse_quote,
15};
16
17fn add_trait_bounds(mut generics: Generics) -> Generics {
18    for param in &mut generics.params {
19        if let GenericParam::Type(ref mut type_param) = *param {
20            // Add a trait bound to each type parameter
21            type_param.bounds.push(parse_quote!(::std::fmt::Debug));
22        }
23    }
24    generics
25}
26
27#[proc_macro_derive(ApproxEq)]
28pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29    let DeriveInput {
30        ident,
31        data,
32        generics,
33        ..
34    } = parse_macro_input!(input);
35
36    let lox_test_utils = match crate_name("lox-test-utils") {
37        Ok(FoundCrate::Itself) => quote!(crate),
38        Ok(FoundCrate::Name(name)) => {
39            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
40            quote!(::#ident)
41        }
42        Err(_) => quote!(::lox_test_utils),
43    };
44
45    let fields: Vec<proc_macro2::TokenStream> = match data {
46        Data::Struct(data) => match data.fields {
47            Fields::Named(fields) => fields
48                .named
49                .into_iter()
50                .map(|f| {
51                    let f = f.ident.unwrap();
52                    quote! {#f}
53                })
54                .collect::<Vec<proc_macro2::TokenStream>>(),
55            Fields::Unnamed(fields) => fields
56                .unnamed
57                .into_iter()
58                .enumerate()
59                .map(|(idx, _)| {
60                    let idx = Index::from(idx);
61                    quote! {#idx}
62                })
63                .collect(),
64            Fields::Unit => {
65                return Error::new(ident.span(), "unit structs are not supported")
66                    .into_compile_error()
67                    .into();
68            }
69        },
70        _ => {
71            return Error::new(ident.span(), format!("{} is not a struct", ident))
72                .into_compile_error()
73                .into();
74        }
75    }
76    .iter()
77    .map(|f| {
78        quote! {
79            results.merge(stringify!(#f).to_string(), self.#f.approx_eq(&rhs.#f, atol, rtol));
80        }
81    })
82    .collect();
83
84    let generics = add_trait_bounds(generics);
85    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
86
87    let results = quote! {#lox_test_utils::approx_eq::results::ApproxEqResults};
88
89    let output = quote! {
90        impl #impl_generics #lox_test_utils::approx_eq::ApproxEq for #ident #ty_generics #where_clause {
91            fn approx_eq(&self, rhs: &Self, atol: f64, rtol: f64) -> #results {
92                let mut results = #results::new();
93                #(#fields)*
94                results
95            }
96        }
97    };
98    output.into()
99}
100
/// Flags nested under the `disable` option of the `lox_time` helper
/// attribute. Each flag that is present switches off one group of impls
/// generated by `derive_offset_provider`.
#[derive(FromMeta)]
struct Scales {
    /// Skips the impls between `Ut1` and the other time scales.
    ut1: Flag,
    /// Skips the `Tt` <-> `Tdb` impls; the two-step impls are skipped as well.
    tdb: Flag,
    /// Skips all impls involving `DynTimeScale`.
    dynamic: Flag,
    /// Skips the two-step impls.
    multi: Flag,
}
108
/// Options read from the `lox_time` helper attribute on the derive input.
/// Every option is optional; absent options fall back to `Default`.
#[derive(FromDeriveInput, Default)]
#[darling(default, attributes(lox_time))]
struct Opts {
    /// Groups of generated impls to leave out.
    disable: Option<Scales>,
    /// Error type used by the generated `DynTimeScale` impls. When absent,
    /// `derive_offset_provider` falls back to `MissingEopProviderError`.
    error: Option<Ident>,
}
115
// Names of the time scale types handled by `derive_offset_provider`.
// The order is significant: the derive indexes this list positionally
// (0 = Tai ... 5 = Ut1) and relies on `Ut1` being the last entry.
const SCALES: [&str; 6] = ["Tai", "Tcb", "Tcg", "Tdb", "Tt", "Ut1"];
117
/// Derives `OffsetProvider` together with a family of `TryOffset` impls for
/// the annotated type.
///
/// The following impls are emitted, in this order:
///
/// 1. an empty `OffsetProvider` impl,
/// 2. identity `TryOffset<S, S>` impls for every scale in `SCALES`,
/// 3. `Tai` <-> `Tt`, `Tt` <-> `Tcg` and `Tcb` <-> `Tdb`,
/// 4. `Tt` <-> `Tdb`, unless the `tdb` flag is disabled,
/// 5. `Ut1` <-> every other scale, always returning
///    `MissingEopProviderError`, unless the `ut1` flag is disabled,
/// 6. impls involving `DynTimeScale`, unless the `dynamic` flag is disabled,
/// 7. two-step impls routed through an intermediate scale, unless the
///    `multi` or the `tdb` flag is disabled.
///
/// Options are read from the `lox_time` helper attribute (see `Opts`).
/// Only the identifier of the input type is used, so generic parameters of
/// the annotated type are not forwarded to the generated impls.
#[proc_macro_derive(OffsetProvider, attributes(lox_time))]
pub fn derive_offset_provider(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input);
    let opts = match Opts::from_derive_input(&input) {
        Ok(opts) => opts,
        Err(err) => return err.write_errors().into(),
    };
    let DeriveInput { ident, .. } = input;

    // Resolve the path under which `lox-time` is visible to the calling crate.
    let lox_time = match crate_name("lox-time") {
        Ok(FoundCrate::Itself) => quote!(crate),
        Ok(FoundCrate::Name(name)) => {
            let ident = syn::Ident::new(&name, proc_macro2::Span::call_site());
            quote!(::#ident)
        }
        Err(_) => quote!(::lox_time),
    };

    let mut output = quote! {
        impl #lox_time::offsets::OffsetProvider for #ident {}
    };

    let eop_error = quote! {#lox_time::offsets::MissingEopProviderError};
    let delta = quote! {#lox_time::deltas::TimeDelta};
    let try_offset = quote! {#lox_time::offsets::TryOffset};
    // Fully qualified paths of the scale types, in the same order as `SCALES`.
    let scales: Vec<proc_macro2::TokenStream> = SCALES
        .iter()
        .map(|scale| {
            let scale = format_ident!("{}", scale);
            quote! {#lox_time::time_scales::#scale}
        })
        .collect();
    // Positional lookups; these indices must match the order of `SCALES`.
    let tai = scales[0].clone();
    let tcb = scales[1].clone();
    let tcg = scales[2].clone();
    let tdb = scales[3].clone();
    let tt = scales[4].clone();
    let ut1 = scales[5].clone();

    // No-ops

    // Converting a scale to itself yields the default `TimeDelta`.
    for scale in &scales {
        output.extend(quote! {
            impl #try_offset<#scale, #scale> for #ident {
                type Error = ::core::convert::Infallible;

                fn try_offset(&self, _origin: #scale, _target: #scale, _delta: #lox_time::deltas::TimeDelta) -> Result<#delta, Self::Error> {
                    Ok(#delta::default())
                }
            }
        })
    }

    // TAI <-> TT

    // Constant offset: `D_TAI_TT` in one direction, its negation in the other.
    let d_tai_tt = quote! {#lox_time::offsets::D_TAI_TT};

    output.extend(quote! {
        impl #try_offset<#tai, #tt> for #ident
        {
            type Error = ::core::convert::Infallible;

            fn try_offset(
                &self,
                _origin: #tai,
                _target: #tt,
                _delta: #delta,
            ) -> Result<#delta, Self::Error> {
                Ok(#d_tai_tt)
            }
        }

        impl #try_offset<#tt, #tai> for #ident
        {
            type Error = ::core::convert::Infallible;

            fn try_offset(
                &self,
                _origin: #tt,
                _target: #tai,
                _delta: #delta,
            ) -> Result<#delta, Self::Error> {
                Ok(-#d_tai_tt)
            }
        }
    });

    // TT <-> TCG

    let tt_to_tcg = quote! {#lox_time::offsets::tt_to_tcg};
    let tcg_to_tt = quote! {#lox_time::offsets::tcg_to_tt};

    output.extend(quote! {
        impl #try_offset<#tt, #tcg> for #ident
        {
            type Error = ::core::convert::Infallible;

            fn try_offset(
                &self,
                _origin: #tt,
                _target: #tcg,
                delta: #delta,
            ) -> Result<#delta, Self::Error> {
                Ok(#tt_to_tcg(delta))
            }
        }

        impl #try_offset<#tcg, #tt> for #ident
        {
            type Error = ::core::convert::Infallible;

            fn try_offset(
                &self,
                _origin: #tcg,
                _target: #tt,
                delta: #delta,
            ) -> Result<#delta, Self::Error> {
                Ok(#tcg_to_tt(delta))
            }
        }
    });

    // TCB <-> TDB

    let tdb_to_tcb = quote! {#lox_time::offsets::tdb_to_tcb};
    let tcb_to_tdb = quote! {#lox_time::offsets::tcb_to_tdb};

    output.extend(quote! {
        impl #try_offset<#tdb, #tcb> for #ident
        {
            type Error = ::core::convert::Infallible;

            fn try_offset(
                &self,
                _origin: #tdb,
                _target: #tcb,
                delta: #delta,
            ) -> Result<#delta, Self::Error> {
                Ok(#tdb_to_tcb(delta))
            }
        }

        impl #try_offset<#tcb, #tdb> for #ident
        {
            type Error = ::core::convert::Infallible;

            fn try_offset(
                &self,
                _origin: #tcb,
                _target: #tdb,
                delta: #delta,
            ) -> Result<#delta, Self::Error> {
                Ok(#tcb_to_tdb(delta))
            }
        }
    });

    // TT <-> TDB

    let disable_tdb = opts
        .disable
        .as_ref()
        .is_some_and(|disable| disable.tdb.is_present());

    if !disable_tdb {
        let tdb_to_tt = quote! {#lox_time::offsets::tdb_to_tt};
        let tt_to_tdb = quote! {#lox_time::offsets::tt_to_tdb};

        output.extend(quote! {
            impl #try_offset<#tdb, #tt> for #ident
            {
                type Error = ::core::convert::Infallible;

                fn try_offset(
                    &self,
                    _origin: #tdb,
                    _target: #tt,
                    delta: #delta,
                ) -> Result<#delta, Self::Error> {
                    Ok(#tdb_to_tt(delta))
                }
            }

            impl #try_offset<#tt, #tdb> for #ident
            {
                type Error = ::core::convert::Infallible;

                fn try_offset(
                    &self,
                    _origin: #tt,
                    _target: #tdb,
                    delta: #delta,
                ) -> Result<#delta, Self::Error> {
                    Ok(#tt_to_tdb(delta))
                }
            }
        });
    }

    // UT1

    let disable_ut1 = opts
        .disable
        .as_ref()
        .is_some_and(|disable| disable.ut1.is_present());

    if !disable_ut1 {
        // `split_last` drops `Ut1` itself (the last entry), leaving the other
        // five scales. Every generated conversion fails with `eop_error`.
        for scale in scales.split_last().unwrap().1 {
            output.extend(quote! {
                impl #try_offset<#ut1, #scale> for #ident
                {
                    type Error = #eop_error;

                    fn try_offset(
                        &self,
                        _origin: #ut1,
                        _target: #scale,
                        delta: #delta,
                    ) -> Result<#delta, Self::Error> {
                        Err(#eop_error)
                    }
                }

                impl #try_offset<#scale, #ut1> for #ident
                {
                    type Error = #eop_error;

                    fn try_offset(
                        &self,
                        _origin: #scale,
                        _target: #ut1,
                        delta: #delta,
                    ) -> Result<#delta, Self::Error> {
                        Err(#eop_error)
                    }
                }
            });
        }
    }

    // DynTimeScale

    let disable_dyn = opts
        .disable
        .as_ref()
        .is_some_and(|disable| disable.dynamic.is_present());

    if !disable_dyn {
        let dyn_scale = quote! {#lox_time::time_scales::DynTimeScale};
        // `DynTimeScale` variants, index-aligned with `scales`.
        let dyn_scales: Vec<proc_macro2::TokenStream> = SCALES
            .iter()
            .map(|scale| {
                let scale = format_ident!("{}", scale);
                quote! {#dyn_scale::#scale}
            })
            .collect();
        // A user-supplied error type additionally gets a `From<Infallible>`
        // impl (built via `Default`) so that `?` works on the infallible
        // static conversions inside the dynamic dispatch below.
        let error = opts
            .error
            .map(|err| {
                let err = quote! {#err};
                // FIXME: Remove once `!` lands on stable.
                output.extend(quote! {
                    impl From<::core::convert::Infallible> for #err {
                        fn from(_: ::core::convert::Infallible) -> Self {
                            #err::default()
                        }
                    }
                });
                err
            })
            .unwrap_or(eop_error.clone());

        let mut arms = quote! {};

        // One match arm per ordered pair of distinct scales, delegating to the
        // corresponding static impl.
        for (dyn_scale1, scale1) in zip(&dyn_scales, &scales) {
            for (dyn_scale2, scale2) in zip(&dyn_scales, &scales) {
                if scale1.to_string() == scale2.to_string() {
                    continue;
                }
                arms.extend(quote! {
                    (#dyn_scale1, #dyn_scale2) => {
                        Ok(self.try_offset(#scale1, #scale2, delta)?)
                    }
                })
            }
        }

        output.extend(quote! {
            impl #try_offset<#dyn_scale, #dyn_scale> for #ident {
                type Error = #error;

                fn try_offset(&self, origin: #dyn_scale, target: #dyn_scale, delta: #delta) -> Result<#delta, Self::Error> {
                    match (origin, target) {
                        #arms
                        (_, _) => Ok(#delta::default()),
                    }
                }
            }
        });

        // Mixed dynamic/static impls for every static scale except `Ut1`.
        for scale in scales.split_last().unwrap().1 {
            let mut arms1 = quote! {};
            let mut arms2 = quote! {};

            // The inner `dyn_scale`/`scale` bindings shadow the outer ones
            // only for the duration of this loop.
            for (dyn_scale, scale) in zip(&dyn_scales, &scales) {
                arms1.extend(quote! {
                    #dyn_scale => Ok(self.try_offset(#scale, target, delta)?),
                });
                arms2.extend(quote! {
                    #dyn_scale => Ok(self.try_offset(origin, #scale, delta)?),
                });
            }

            output.extend(quote! {
                impl #try_offset<#dyn_scale, #scale> for #ident {
                    type Error = #error;

                    fn try_offset(&self, origin: #dyn_scale, target: #scale, delta: #delta) -> Result<#delta, Self::Error> {
                        match origin {
                            #arms1
                        }
                    }
                }

                impl #try_offset<#scale, #dyn_scale> for #ident {
                    type Error = #error;

                    fn try_offset(&self, origin: #scale, target: #dyn_scale, delta: #delta) -> Result<#delta, Self::Error> {
                        match target {
                            #arms2
                        }
                    }
                }
            });
        }
    }

    // Two-step transformations

    let disable_multi = opts
        .disable
        .as_ref()
        .is_some_and(|disable| disable.multi.is_present());

    if !disable_multi && !disable_tdb {
        // (origin, intermediate, target); each entry yields impls for both
        // directions through the same intermediate scale.
        let multis = [
            (&tai, &tt, &tdb),
            (&tdb, &tt, &tcg),
            (&tai, &tt, &tcg),
            (&tai, &tdb, &tcb),
            (&tt, &tdb, &tcb),
            (&tcb, &tdb, &tcg),
        ];

        let two_step = quote! {#lox_time::offsets::two_step_offset};

        for (origin, via, target) in multis {
            output.extend(quote!{
                impl #try_offset<#origin, #target> for #ident
                {
                    type Error = ::core::convert::Infallible;

                    fn try_offset(&self, origin: #origin, target: #target, delta: #delta) -> Result<#delta, Self::Error> {
                        Ok(#two_step(self, origin, #via, target, delta))
                    }
                }

                impl #try_offset<#target, #origin> for #ident
                {
                    type Error = ::core::convert::Infallible;

                    fn try_offset(&self, origin: #target, target: #origin, delta: #delta) -> Result<#delta, Self::Error> {
                        Ok(#two_step(self, origin, #via, target, delta))
                    }
                }
            });
        }
    }

    output.into()
}
499
500fn generate_call_to_deserializer_for_covariance_matrix_kvn_type(
501    expected_kvn_name: &str,
502) -> proc_macro2::TokenStream {
503    quote! {
504        match crate::ndm::kvn::parser::get_next_nonempty_line(lines) {
505            None => Err(crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
506                keyword: #expected_kvn_name.to_string()
507            }),
508            Some(next_line) => {
509                let result = crate::ndm::kvn::parser::parse_kvn_covariance_matrix(
510                    lines,
511                ).map_err(|x| match crate::ndm::kvn::KvnDeserializerErr::from(x) {
512                    crate::ndm::kvn::KvnDeserializerErr::InvalidCovarianceMatrixFormat { .. } => crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedKeyword {
513                        // This is empty because we just want to tell the
514                        // vector iterator to stop the iteration.
515                        found: "".to_string(),
516                        expected: "".to_string(),
517                    },
518                    crate::ndm::kvn::KvnDeserializerErr::UnexpectedEndOfInput { keyword } => crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
519                        keyword
520                    },
521                    e => e,
522                })?;
523
524                Ok(result)
525            }
526        }
527    }
528}
529
530fn generate_call_to_deserializer_for_kvn_type(
531    type_name: &str,
532    type_name_new: &syn::Path,
533    expected_kvn_name: &str,
534    with_keyword_matching: bool,
535    unpack_value: bool,
536) -> Result<proc_macro2::TokenStream, proc_macro::TokenStream> {
537    let unpack_insert = if unpack_value {
538        quote! { .value }
539    } else {
540        quote! {}
541    };
542
543    match type_name {
544        "String" | "f64" | "i32" | "u64" | "NonNegativeDouble" | "NegativeDouble"
545        | "PositiveDouble" => {
546            let parser = match type_name {
547                "String" => quote! {
548                    crate::ndm::kvn::parser::parse_kvn_string_line(
549                        next_line
550                    ).map_err(|x| crate::ndm::kvn::KvnDeserializerErr::from(x))?
551                },
552                "f64" | "NonNegativeDouble" | "NegativeDouble" | "PositiveDouble" => quote! {
553                    crate::ndm::kvn::parser::parse_kvn_numeric_line(
554                        next_line,
555                        true, //@TODO
556                    ).map_err(|x| crate::ndm::kvn::KvnDeserializerErr::from(x))?
557                },
558                "i32" | "u64" => quote! {
559                    crate::ndm::kvn::parser::parse_kvn_integer_line(
560                        next_line,
561                        true, //@TODO
562                    ).map_err(|x| crate::ndm::kvn::KvnDeserializerErr::from(x))?
563                },
564                // Assumes the match list here exhaustively matches the one from above
565                _ => unreachable!(),
566            };
567
568            if with_keyword_matching {
569                Ok(quote! {
570                    match crate::ndm::kvn::parser::get_next_nonempty_line(lines) {
571                        None => Err(crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
572                            keyword: #expected_kvn_name.to_string()
573                        }),
574                        Some(next_line) => {
575                            let line_matches = crate::ndm::kvn::parser::kvn_line_matches_key(
576                                #expected_kvn_name,
577                                next_line,
578                            )?;
579
580                            let result = if line_matches {
581                                let next_line = lines.next().unwrap();
582
583                                Ok(#parser #unpack_insert)
584                            } else {
585                                Err(crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedKeyword {
586                                    found: next_line.to_string(),
587                                    expected: #expected_kvn_name.to_string(),
588                                })
589                            };
590
591                            result
592                        }
593                    }
594                })
595            } else {
596                Ok(quote! {
597                   {
598                      let next_line = lines.next().unwrap();
599                      #parser #unpack_insert
600                   }
601                })
602            }
603        }
604        "common::StateVectorAccType" => Ok(quote! {
605            match crate::ndm::kvn::parser::get_next_nonempty_line(lines) {
606                None => Err(crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
607                    keyword: #expected_kvn_name.to_string()
608                }),
609                Some(next_line) => {
610                    let result = crate::ndm::kvn::parser::parse_kvn_state_vector(
611                        next_line,
612                    ).map_err(|x| match crate::ndm::kvn::KvnDeserializerErr::from(x) {
613                        crate::ndm::kvn::KvnDeserializerErr::InvalidStateVectorFormat { .. } => crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedKeyword {
614                            // This is empty because we just want to tell the
615                            // vector iterator to stop the iteration.
616                            found: "".to_string(),
617                            expected: "".to_string(),
618                        },
619                        e => e,
620                    }).map(|x| x.into());
621
622                    if result.is_ok() {
623                        let _ = lines.next().unwrap();
624                    }
625
626                    result
627                }
628            }
629        }),
630        _ => Ok(quote! {
631           {
632                let has_next_line = crate::ndm::kvn::parser::get_next_nonempty_line(lines).is_some();
633
634                let result = if has_next_line {
635                    #type_name_new::deserialize(lines)
636                } else {
637                    Err(crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
638                          keyword: #expected_kvn_name.to_string()
639                    })
640                };
641
642                result
643            }
644        }),
645    }
646}
647
648fn get_generic_type_argument(field: &Field) -> Option<(String, &syn::Path)> {
649    if let syn::Type::Path(type_path) = &field.ty {
650        let path_part = type_path.path.segments.first();
651        if let Some(path_part) = path_part
652            && let syn::PathArguments::AngleBracketed(type_argument) = &path_part.arguments
653            && let Some(syn::GenericArgument::Type(syn::Type::Path(r#type))) =
654                &type_argument.args.first()
655        {
656            return Some((
657                r#type
658                    .path
659                    .segments
660                    .clone()
661                    .into_iter()
662                    .map(|ident| ident.to_token_stream().to_string())
663                    .reduce(|a, b| a + "::" + &b)
664                    .unwrap(),
665                &r#type.path,
666            ));
667        }
668    }
669
670    None
671}
672
673fn generate_call_to_deserializer_for_option_type(
674    expected_kvn_name: &str,
675    field: &Field,
676) -> Result<proc_macro2::TokenStream, proc_macro::TokenStream> {
677    let (type_name, type_ident) = get_generic_type_argument(field).ok_or(
678        syn::Error::new_spanned(field, "Malformed type for `#[derive(KvnDeserialize)]`")
679            .into_compile_error(),
680    )?;
681
682    let deserializer_for_kvn_type = generate_call_to_deserializer_for_kvn_type(
683        type_name.as_ref(),
684        type_ident,
685        expected_kvn_name,
686        true,
687        true,
688    )?;
689
690    let condition_shortcut = match type_name.as_str() {
691        "String" | "f64" | "i32" | "u64" => quote! {},
692        _ => quote! { ! #type_ident::should_check_key_match() || },
693    };
694
695    let value = match type_name.as_ref() {
696        "NonNegativeDouble" | "NegativeDouble" | "PositiveDouble" => {
697            quote! { #type_ident (item) }
698        }
699        _ => quote! { item },
700    };
701
702    Ok(quote! {
703        match crate::ndm::kvn::parser::get_next_nonempty_line(lines) {
704            None => None,
705            Some(next_line) => {
706                let line_matches = crate::ndm::kvn::parser::kvn_line_matches_key(
707                    #expected_kvn_name,
708                    next_line,
709                )?;
710
711                if #condition_shortcut line_matches {
712                    let result = #deserializer_for_kvn_type;
713
714                    match result {
715                        Ok(item) => Some(#value),
716                        Err(crate::ndm::kvn::KvnDeserializerErr::UnexpectedKeyword { .. }) |
717                        Err(crate::ndm::kvn::KvnDeserializerErr::UnexpectedEndOfInput { .. }) => None,
718                        Err(e) => Err(e)?,
719                    }
720                } else {
721                    None
722                }
723            }
724        }
725    })
726}
727
728fn generate_call_to_deserializer_for_vec_type(
729    expected_kvn_name: &str,
730    field: &Field,
731) -> Result<proc_macro2::TokenStream, proc_macro::TokenStream> {
732    let (type_name, type_ident) = get_generic_type_argument(field).ok_or(
733        syn::Error::new_spanned(field, "Malformed type for `#[derive(KvnDeserialize)]`")
734            .into_compile_error(),
735    )?;
736
737    let expected_kvn_name = expected_kvn_name.trim_end_matches("_LIST");
738
739    let deserializer_for_kvn_type = generate_call_to_deserializer_for_kvn_type(
740        type_name.as_ref(),
741        type_ident,
742        expected_kvn_name,
743        true,
744        true,
745    )?;
746
747    Ok(quote! {
748        {
749            let mut items: Vec<#type_ident> = Vec::new();
750
751            let mut is_retry = false;
752
753            loop {
754                let result = #deserializer_for_kvn_type;
755
756                match result {
757                    Ok(item) => {
758                        is_retry = false;
759                        items.push(item)
760                    },
761                    Err(crate::ndm::kvn::KvnDeserializerErr::UnexpectedKeyword { .. }) |
762                    Err(crate::ndm::kvn::KvnDeserializerErr::UnexpectedEndOfInput { .. }) => if is_retry {
763                        break;
764                    } else {
765                        is_retry = true;
766                        continue;
767                    },
768                    Err(e) => Err(e)?,
769                }
770            }
771
772            items
773        }
774    })
775}
776
777fn get_prefix_and_postfix_keyword(attrs: &[syn::Attribute]) -> Option<(String, String)> {
778    let mut keyword: Option<syn::LitStr> = None;
779
780    for attr in attrs.iter() {
781        if !attr.path().is_ident("kvn") {
782            continue;
783        }
784
785        let _ = attr.parse_nested_meta(|meta| {
786            if meta.path.is_ident("prefix_and_postfix_keyword") {
787                let value = meta.value()?;
788                keyword = value.parse()?;
789
790                Ok(())
791            } else {
792                Err(meta.error("unsupported attribute"))
793            }
794        });
795    }
796
797    keyword.map(|keyword| {
798        let keyword = keyword.value().to_uppercase();
799
800        (format!("{keyword}_START"), format!("{keyword}_STOP"))
801    })
802}
803
804fn is_value_unit_struct(item: &DeriveInput) -> bool {
805    item.attrs.iter().any(|attr| {
806        attr.path().is_ident("kvn")
807            && attr
808                .parse_nested_meta(|meta| {
809                    if meta.path.is_ident("value_unit_struct") {
810                        Ok(())
811                    } else {
812                        Err(meta.error("unsupported attribute"))
813                    }
814                })
815                .is_ok()
816    })
817}
818
819fn extract_type_path(ty: &syn::Type) -> Option<&syn::Path> {
820    match *ty {
821        syn::Type::Path(ref typepath) if typepath.qself.is_none() => Some(&typepath.path),
822        _ => None,
823    }
824}
825
826fn handle_version_field(
827    type_name: &proc_macro2::Ident,
828    field: &syn::Field,
829) -> Result<(proc_macro2::TokenStream, proc_macro2::TokenStream), proc_macro2::TokenStream> {
830    let string_type_name = type_name.to_string();
831
832    if !string_type_name.ends_with("Type") {
833        return Err(syn::Error::new_spanned(
834            type_name,
835            "Types with \"version\" field should be of the form SomethingType",
836        )
837        .into_compile_error());
838    }
839
840    let message_type_name = string_type_name
841        .trim_end_matches("Type")
842        .to_string()
843        .to_uppercase();
844
845    let field_name = field.ident.as_ref().unwrap();
846
847    // 7.4.4 Keywords must be uppercase and must not contain blanks
848    let expected_kvn_name = format!("CCSDS_{message_type_name}_VERS");
849
850    // Unwrap is okay because we expect this to be a well defined type path,
851    // because this is not a general-purpose proc macro, but one that we
852    // control the input to ourselves.
853    let field_type = extract_type_path(&field.ty)
854        .unwrap()
855        .to_token_stream()
856        .to_string();
857    let field_type_new = extract_type_path(&field.ty).unwrap();
858
859    let parser = generate_call_to_deserializer_for_kvn_type(
860        &field_type,
861        field_type_new,
862        &expected_kvn_name,
863        true,
864        true,
865    )?;
866
867    Ok((
868        quote! { let #field_name = #parser?; },
869        quote! { #field_name, },
870    ))
871}
872
873fn deserializer_for_struct_with_named_fields(
874    type_name: &proc_macro2::Ident,
875    fields: &syn::punctuated::Punctuated<syn::Field, syn::token::Comma>,
876    is_value_unit_struct: bool,
877    struct_level_prefix_and_postfix_keyword: Option<(String, String)>,
878) -> proc_macro2::TokenStream {
879    if &type_name.to_string() == "UserDefinedType" {
880        //@TODO
881        return quote! {
882            Ok(Default::default())
883        };
884    }
885
886    if is_value_unit_struct {
887        let mut deserializer = None;
888        let mut unit_type: Option<String> = None;
889        let mut unit_field_name_ident: Option<&proc_macro2::Ident> = None;
890        let mut field_type: Option<String> = None;
891        let mut field_type_new: Option<&syn::Path> = None;
892
893        for (index, field) in fields.iter().enumerate() {
894            // Unwrap is okay because we only support named structs
895            let field_name_ident = field.ident.as_ref().unwrap();
896
897            let field_name = field_name_ident.to_token_stream().to_string();
898
899            match index {
900                0 => {
901                    if field_name.as_str() != "base" {
902                        return syn::Error::new_spanned(
903                            field,
904                            "The first field in a value unit struct should be called \"base\"",
905                        )
906                        .into_compile_error();
907                    }
908
909                    // Unwrap is okay because we expect this to be a well defined type path,
910                    // because this is not a general-purpose proc macro, but one that we
911                    // control the input to ourselves.
912                    let local_field_type = extract_type_path(&field.ty)
913                        .unwrap()
914                        .to_token_stream()
915                        .to_string();
916                    let local_field_type_new = extract_type_path(&field.ty).unwrap();
917
918                    match local_field_type.as_str() {
919                        "KvnDateTimeValue" | "String" | "f64" | "i32" | "NonNegativeDouble"
920                        | "NegativeDouble" | "PositiveDouble" => {
921                            match generate_call_to_deserializer_for_kvn_type(
922                                &local_field_type,
923                                local_field_type_new,
924                                "undefined",
925                                false,
926                                false,
927                            ) {
928                                Ok(deserializer_for_kvn_type) => {
929                                    deserializer = Some(deserializer_for_kvn_type)
930                                }
931                                Err(e) => return e.into(),
932                            }
933                        }
934
935                        _ => {
936                            return syn::Error::new_spanned(
937                                field,
938                                "Unsupported field type for deserializer",
939                            )
940                            .into_compile_error();
941                        }
942                    };
943
944                    field_type = Some(local_field_type);
945                    field_type_new = Some(local_field_type_new);
946                }
947                1 => {
948                    if field_name.as_str() != "units" && field_name.as_str() != "parameter" {
949                        return syn::Error::new_spanned(
950                             field,
951                             "The second field in a value unit struct should be called \"units\" or \"parameter\"",
952                         )
953                         .into_compile_error();
954                    }
955
956                    unit_type = get_generic_type_argument(field).map(|x| x.0);
957                    unit_field_name_ident = Some(field_name_ident);
958                }
959                _ => {
960                    return syn::Error::new_spanned(
961                        field,
962                        "Only two fields are allowed: \"base\" and (\"units\" or \"parameters\"",
963                    )
964                    .into_compile_error();
965                }
966            }
967        }
968
969        // This unwrap is okay because we know the field exists. If it didn't exist we would've thrown an error.
970        let unit_type = unit_type.unwrap();
971        let unit_field_name_ident = unit_field_name_ident.unwrap();
972        let field_type = field_type.unwrap();
973        let field_type_new = field_type_new.unwrap();
974
975        let unit_type_ident = syn::Ident::new(&unit_type, Span::call_site());
976
977        let base = match field_type.as_str() {
978            "NonNegativeDouble" | "NegativeDouble" | "PositiveDouble" => {
979                quote! { #field_type_new (kvn_value.value) }
980            }
981            _ => quote! { kvn_value.value },
982        };
983
984        match deserializer {
985            None => syn::Error::new_spanned(fields, "Unable to create deserializer for struct")
986                .into_compile_error(),
987            Some(deserializer) => quote! {
988                let kvn_value = #deserializer;
989                Ok(#type_name {
990                    base: #base,
991                    #unit_field_name_ident: kvn_value.unit.map(|unit| #unit_type_ident (unit)),
992                })
993            },
994        }
995    } else {
996        let field_deserializers: Result<Vec<_>, _> = fields.iter().filter(|field| {
997            // For OemCovarianceMatrixType we filter the types which start with
998            // cx, cy and cz because we populate those differently
999            if type_name != "OemCovarianceMatrixType" {
1000                return true
1001            }
1002
1003            // Unwrap is okay because we only support named structs
1004            let field_name = field.ident.as_ref().unwrap().to_token_stream().to_string();
1005
1006            !field_name.starts_with("cx")
1007                && !field_name.starts_with("cy")
1008                && !field_name.starts_with("cz")
1009        }).map(|field| {
1010                let field_name = field.ident.as_ref().unwrap();
1011
1012                // Unwrap is okay because we only support named structs
1013                // 7.4.4 Keywords must be uppercase and must not contain blanks
1014                let expected_kvn_name = field_name.to_token_stream().to_string().to_uppercase();
1015
1016                // Unwrap is okay because we expect this to be a well defined type path,
1017                // because this is not a general-purpose proc macro, but one that we
1018                // control the input to ourselves.
1019                let field_type_new = extract_type_path(&field.ty).unwrap();
1020
1021                // Unwrap is okay becuase we always expect at least one type
1022                let field_main_type = field_type_new.segments.iter()
1023                    .next_back()
1024                    .unwrap()
1025                    .ident
1026                    .to_string();
1027
1028                if field_name == "version" {
1029                    return handle_version_field(type_name, field);
1030                }
1031
1032                let parser = match field_main_type.as_str() {
1033                    "String" | "f64" | "i32" => {
1034                        let deserializer_for_kvn_type = generate_call_to_deserializer_for_kvn_type(
1035                            &field_main_type,
1036                            field_type_new,
1037                            &expected_kvn_name,
1038                            true,
1039                            true,
1040                        )?;
1041
1042                        quote! {
1043                            #deserializer_for_kvn_type?
1044                        }
1045                    }
1046                    "Option" => {
1047                        generate_call_to_deserializer_for_option_type(
1048                            &expected_kvn_name,
1049                            field
1050                        )?
1051                    }
1052                    "Vec" => generate_call_to_deserializer_for_vec_type(&expected_kvn_name, field)?,
1053                    _ => {
1054
1055                        let condition_shortcut = match field_main_type.as_str() {
1056                            "String" => quote! {},
1057                            _ => quote! { ! #field_type_new::should_check_key_match() || },
1058                        };
1059
1060                        quote! {
1061                            match crate::ndm::kvn::parser::get_next_nonempty_line(lines) {
1062                                None => Err(crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
1063                                    keyword: #expected_kvn_name.to_string()
1064                                })?,
1065                                Some(next_line) => {
1066                                    let line_matches = crate::ndm::kvn::parser::kvn_line_matches_key(
1067                                        #expected_kvn_name,
1068                                        next_line,
1069                                    )?;
1070
1071                                    if #condition_shortcut line_matches {
1072                                        #field_type_new::deserialize(lines)?
1073                                    } else {
1074                                        Err(crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedKeyword {
1075                                            found: next_line.to_string(),
1076                                            expected: #expected_kvn_name.to_string(),
1077                                        })?
1078                                    }
1079                                }
1080                            }
1081                        }
1082                    }
1083                };
1084
1085                let field_prefix_and_postfix_keyword_checks = get_prefix_and_postfix_keyword(&field.attrs);
1086
1087                let wrapped_parser = add_prefix_and_postfix_keyword_checks(field_prefix_and_postfix_keyword_checks, parser, true);
1088
1089                Ok((
1090                    quote! { let #field_name = #wrapped_parser; },
1091                    quote! { #field_name, }
1092                ))
1093             })
1094             .collect();
1095
1096        if let Err(e) = field_deserializers {
1097            return e;
1098        }
1099
1100        let mut field_deserializers = field_deserializers.unwrap();
1101
1102        if type_name == "OemCovarianceMatrixType" {
1103            let covariance_matrix_parser =
1104                generate_call_to_deserializer_for_covariance_matrix_kvn_type("COVARIANCE_MATRIX");
1105
1106            field_deserializers.push((
1107                quote! { let covariance_matrix = #covariance_matrix_parser?; },
1108                quote! {},
1109            ));
1110
1111            for (field, field_type) in [
1112                ("cx_x", "PositionCovarianceType"),
1113                ("cy_x", "PositionCovarianceType"),
1114                ("cy_y", "PositionCovarianceType"),
1115                ("cz_x", "PositionCovarianceType"),
1116                ("cz_y", "PositionCovarianceType"),
1117                ("cz_z", "PositionCovarianceType"),
1118                ("cx_dot_x", "PositionVelocityCovarianceType"),
1119                ("cx_dot_y", "PositionVelocityCovarianceType"),
1120                ("cx_dot_z", "PositionVelocityCovarianceType"),
1121                ("cx_dot_x_dot", "VelocityCovarianceType"),
1122                ("cy_dot_x", "PositionVelocityCovarianceType"),
1123                ("cy_dot_y", "PositionVelocityCovarianceType"),
1124                ("cy_dot_z", "PositionVelocityCovarianceType"),
1125                ("cy_dot_x_dot", "VelocityCovarianceType"),
1126                ("cy_dot_y_dot", "VelocityCovarianceType"),
1127                ("cz_dot_x", "PositionVelocityCovarianceType"),
1128                ("cz_dot_y", "PositionVelocityCovarianceType"),
1129                ("cz_dot_z", "PositionVelocityCovarianceType"),
1130                ("cz_dot_x_dot", "VelocityCovarianceType"),
1131                ("cz_dot_y_dot", "VelocityCovarianceType"),
1132                ("cz_dot_z_dot", "VelocityCovarianceType"),
1133            ] {
1134                let field_ident = syn::Ident::new(field, Span::call_site());
1135                let type_ident = syn::Ident::new(field_type, Span::call_site());
1136
1137                field_deserializers.push((
1138                    quote! {},
1139                    quote! {
1140                        #field_ident: #type_ident {
1141                            base: covariance_matrix.#field_ident,
1142                            units: None,
1143                        },
1144                    },
1145                ));
1146            }
1147        }
1148
1149        let (field_deserializers, fields): (Vec<_>, Vec<_>) =
1150            field_deserializers.into_iter().unzip();
1151
1152        let parser_to_wrap = quote! {
1153            #(#field_deserializers)*
1154
1155            Ok(#type_name {
1156                #(#fields)*
1157            })
1158        };
1159
1160        let wrapped_parser = add_prefix_and_postfix_keyword_checks(
1161            struct_level_prefix_and_postfix_keyword,
1162            parser_to_wrap,
1163            false,
1164        );
1165
1166        quote! {
1167            #wrapped_parser
1168        }
1169    }
1170}
1171
1172fn add_prefix_and_postfix_keyword_checks(
1173    prefix_and_postfix_keyword: Option<(String, String)>,
1174    parser_to_wrap: proc_macro2::TokenStream,
1175    is_field: bool,
1176) -> proc_macro2::TokenStream {
1177    match prefix_and_postfix_keyword {
1178        None => parser_to_wrap,
1179        Some((prefix_keyword, postfix_keyword)) => {
1180            let mismatch_handler = if is_field {
1181                quote! { Default::default() }
1182            } else {
1183                quote! {
1184                    Err(
1185                        crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedKeyword {
1186                            found: next_line.to_string(),
1187                            expected: #prefix_keyword.to_string(),
1188                        },
1189                    )?
1190                }
1191            };
1192
1193            quote! {
1194
1195                match crate::ndm::kvn::parser::get_next_nonempty_line(lines) {
1196                    None => Err(
1197                        crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
1198                            keyword: #prefix_keyword.to_string(),
1199                        },
1200                    )?,
1201
1202                    Some(next_line) => {
1203                        let line_matches = crate::ndm::kvn::parser::kvn_line_matches_key(
1204                            #prefix_keyword,
1205                            next_line,
1206                        )?;
1207
1208                        if line_matches {
1209                            lines.next().unwrap();
1210
1211                            let result = { #parser_to_wrap };
1212
1213                            match crate::ndm::kvn::parser::get_next_nonempty_line(lines) {
1214                                None =>  Err(
1215                                    crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedEndOfInput {
1216                                        keyword: #postfix_keyword.to_string(),
1217                                    },
1218                                )?,
1219
1220                                Some(next_line) => {
1221                                    let line_matches = crate::ndm::kvn::parser::kvn_line_matches_key(
1222                                        #postfix_keyword,
1223                                        next_line,
1224                                    )?;
1225
1226                                    if line_matches {
1227                                        lines.next().unwrap();
1228                                    } else {
1229                                        Err(
1230                                            crate::ndm::kvn::KvnDeserializerErr::<String>::UnexpectedKeyword {
1231                                                found: next_line.to_string(),
1232                                                expected: #postfix_keyword.to_string(),
1233                                            },
1234                                        )?
1235                                    }
1236                                }
1237                            };
1238
1239                            result
1240                        } else {
1241                            #mismatch_handler
1242                        }
1243                    }
1244                }
1245            }
1246        }
1247    }
1248}
1249
1250fn deserializers_for_struct_with_unnamed_fields(
1251    type_name: &proc_macro2::Ident,
1252    fields: &syn::punctuated::Punctuated<Field, syn::token::Comma>,
1253) -> proc_macro2::TokenStream {
1254    let field = fields
1255        .first()
1256        .expect("We expect exactly one item in structs with unnamed fields");
1257
1258    if &type_name.to_string() == "EpochType" {
1259        return quote! {
1260            Ok(#type_name (
1261                crate::ndm::kvn::parser::parse_kvn_datetime_line(
1262                    lines.next().unwrap()
1263                ).map_err(|x| crate::ndm::kvn::KvnDeserializerErr::from(x))
1264                .map(|x| x)?.full_value
1265            ))
1266        };
1267    }
1268
1269    // Unwrap is okay because we expect this to be a well defined type path,
1270    // because this is not a general-purpose proc macro, but one that we
1271    // control the input to ourselves.
1272    let field_type_new = extract_type_path(&field.ty).unwrap();
1273    let field_type = field_type_new.to_token_stream().to_string();
1274
1275    let deserializer_for_kvn_type = generate_call_to_deserializer_for_kvn_type(
1276        &field_type,
1277        field_type_new,
1278        "unnamed field",
1279        false,
1280        true,
1281    );
1282
1283    let deserializer_for_kvn_type = match deserializer_for_kvn_type {
1284        Ok(deserializer_for_kvn_type) => deserializer_for_kvn_type,
1285        Err(e) => return e.into(),
1286    };
1287
1288    quote! {
1289        Ok(#type_name (
1290            #deserializer_for_kvn_type
1291        ))
1292    }
1293}
1294
1295#[proc_macro_derive(KvnDeserialize, attributes(kvn))]
1296pub fn derive_kvn_deserialize(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
1297    let item = syn::parse_macro_input!(item as syn::DeriveInput);
1298    let type_name = &item.ident;
1299    let is_value_unit_struct = is_value_unit_struct(&item);
1300    let prefix_and_postfix_keyword = get_prefix_and_postfix_keyword(&item.attrs);
1301
1302    let Data::Struct(strukt) = item.data else {
1303        return syn::Error::new_spanned(
1304            &item,
1305            "only structs are supported for `#[derive(KvnDeserialize)]`",
1306        )
1307        .into_compile_error()
1308        .into();
1309    };
1310
1311    let (struct_deserializer, should_check_key_match) = match strukt.fields {
1312        Fields::Named(syn::FieldsNamed { named, .. }) => (
1313            deserializer_for_struct_with_named_fields(
1314                type_name,
1315                &named,
1316                is_value_unit_struct,
1317                prefix_and_postfix_keyword,
1318            ),
1319            is_value_unit_struct,
1320        ),
1321        Fields::Unnamed(syn::FieldsUnnamed { unnamed, .. }) => (
1322            deserializers_for_struct_with_unnamed_fields(type_name, &unnamed),
1323            true,
1324        ),
1325        _ => {
1326            return syn::Error::new_spanned(
1327                &strukt.fields,
1328                "only named fields are supported for `#[derive(KvnDeserialize)]`",
1329            )
1330            .into_compile_error()
1331            .into();
1332        }
1333    };
1334
1335    let (impl_generics, type_generics, where_clause) = item.generics.split_for_impl();
1336
1337    let struct_deserializer = quote! {
1338        impl #impl_generics crate::ndm::kvn::KvnDeserializer for #type_name #type_generics
1339        #where_clause
1340        {
1341            fn deserialize<'a>(lines: &mut ::std::iter::Peekable<impl Iterator<Item = &'a str>>)
1342            -> Result<#type_name, crate::ndm::kvn::KvnDeserializerErr<String>> {
1343                #struct_deserializer
1344            }
1345
1346            fn should_check_key_match () -> bool {
1347                #should_check_key_match
1348            }
1349        }
1350    };
1351
1352    struct_deserializer.into()
1353}