Skip to main content

affn_derive/
lib.rs

1//! # affn-derive
2//!
3//! Derive macros for the `affn` crate, providing `#[derive(ReferenceFrame)]`
4//! and `#[derive(ReferenceCenter)]` for convenient frame and center definitions.
5//!
6//! ## Usage
7//!
8//! These derives are re-exported from `affn`, so you typically use them as:
9//!
10//! ```rust,ignore
11//! use affn::{ReferenceFrame, ReferenceCenter};
12//!
13//! #[derive(Debug, Copy, Clone, ReferenceFrame)]
14//! struct MyFrame;
15//!
16//! #[derive(Debug, Copy, Clone, ReferenceCenter)]
17//! struct MyCenter;
18//! ```
19//!
20//! ## Attributes
21//!
22//! ### `#[derive(ReferenceFrame)]`
23//!
24//! - `#[frame(name = "CustomName")]` - Override the frame name (defaults to struct name)
25//! - `#[frame(polar = "dec", azimuth = "ra")]` - Also implement `SphericalNaming` with custom names
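//! - `#[frame(distance = "altitude")]` - Override the distance name reported by `SphericalNaming`
//!   (defaults to "distance"). Requires `polar` and `azimuth`.
//! - `#[frame(inherent)]` - Generate inherent methods on `Direction<F>` and `Position<C,F,U>`.
//!   Requires `polar` and `azimuth`, which must then be valid Rust identifiers because they
//!   become the generated accessor method names. Only valid when the frame is defined in the
//!   same crate as `Direction`/`Position`.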
29//!
30//! ### `#[derive(ReferenceCenter)]`
31//!
32//! - `#[center(name = "CustomName")]` - Override the center name (defaults to struct name)
33//! - `#[center(params = MyParamsType)]` - Specify the `Params` associated type (defaults to `()`)
34//! - `#[center(affine = false)]` - Skip implementing `AffineCenter` marker trait
35
36use proc_macro::TokenStream;
37use proc_macro2::TokenStream as TokenStream2;
38use quote::quote;
39use syn::{parse_macro_input, DeriveInput, Expr, Lit, Meta, Type};
40
41// =============================================================================
42// ReferenceFrame derive
43// =============================================================================
44
45/// Derive macro for implementing [`ReferenceFrame`](affn::frames::ReferenceFrame).
46///
47/// # Example
48///
49/// ```rust,ignore
50/// use affn::ReferenceFrame;
51///
52/// #[derive(Debug, Copy, Clone, ReferenceFrame)]
53/// struct ICRS;
54///
55/// assert_eq!(ICRS::frame_name(), "ICRS");
56/// ```
57///
58/// ## Custom Name
59///
60/// ```rust,ignore
61/// #[derive(Debug, Copy, Clone, ReferenceFrame)]
62/// #[frame(name = "International Celestial Reference System")]
63/// struct ICRS;
64///
65/// assert_eq!(ICRS::frame_name(), "International Celestial Reference System");
66/// ```
67///
68/// ## SphericalNaming
69///
70/// When `polar` and `azimuth` attributes are provided, the macro also implements
71/// [`SphericalNaming`](affn::frames::SphericalNaming):
72///
73/// ```rust,ignore
74/// #[derive(Debug, Copy, Clone, ReferenceFrame)]
75/// #[frame(polar = "dec", azimuth = "ra")]
76/// struct ICRS;
77///
78/// assert_eq!(ICRS::polar_name(), "dec");
79/// assert_eq!(ICRS::azimuth_name(), "ra");
80/// assert_eq!(ICRS::distance_name(), "distance"); // default
81/// ```
82///
83/// With custom distance name:
84///
85/// ```rust,ignore
86/// #[derive(Debug, Copy, Clone, ReferenceFrame)]
87/// #[frame(polar = "lat", azimuth = "lon", distance = "altitude")]
88/// struct ITRF;
89/// ```
90#[proc_macro_derive(ReferenceFrame, attributes(frame))]
91pub fn derive_reference_frame(input: TokenStream) -> TokenStream {
92    let input = parse_macro_input!(input as DeriveInput);
93    match derive_reference_frame_impl(input) {
94        Ok(tokens) => tokens.into(),
95        Err(err) => err.to_compile_error().into(),
96    }
97}
98
/// Options collected from `#[frame(...)]` attributes.
///
/// `Default` leaves every option unset: all `Option`s are `None` and the flag is `false`.
#[derive(Default)]
struct FrameAttributes {
    /// `#[frame(inherent)]`: emit inherent impls on `Direction<F>` and `Position<C, F, U>`.
    /// Only usable when the frame lives in the same crate as those types.
    inherent: bool,
    /// Override for `frame_name()`; the struct identifier is used when absent.
    name: Option<String>,
    /// `SphericalNaming` polar-angle label (e.g. "dec", "lat", "alt").
    polar: Option<String>,
    /// `SphericalNaming` azimuthal-angle label (e.g. "ra", "lon", "az").
    azimuth: Option<String>,
    /// `SphericalNaming` radial label; "distance" is used when absent.
    distance: Option<String>,
}
114
115fn derive_reference_frame_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
116    let name = &input.ident;
117
118    // Parse #[frame(...)] attributes
119    let attrs = parse_frame_attributes(&input)?;
120
121    let name_expr = match &attrs.name {
122        Some(custom_name) => quote! { #custom_name },
123        None => {
124            let name_str = name.to_string();
125            quote! { #name_str }
126        }
127    };
128
129    // Generate SphericalNaming impl (always) + inherent methods (when `inherent` flag set)
130    let spherical_impl = match (&attrs.polar, &attrs.azimuth) {
131        (Some(polar), Some(azimuth)) => {
132            let distance = attrs.distance.as_deref().unwrap_or("distance");
133
134            let polar_ident = syn::Ident::new(polar, proc_macro2::Span::call_site());
135            let azimuth_ident = syn::Ident::new(azimuth, proc_macro2::Span::call_site());
136
137            // SphericalNaming is always generated (trait impl, not inherent)
138            let naming_impl = quote! {
139                impl ::affn::frames::SphericalNaming for #name {
140                    fn polar_name() -> &'static str {
141                        #polar
142                    }
143                    fn azimuth_name() -> &'static str {
144                        #azimuth
145                    }
146                    fn distance_name() -> &'static str {
147                        #distance
148                    }
149                }
150            };
151
152            // Inherent impls: only generated when `inherent` flag is set.
153            // These require Direction/Position to be in the same crate as the frame.
154            let inherent_impl = if attrs.inherent {
155                // Determine constructor parameter order:
156                // IAU convention: polar first for alt/az, azimuth first for everything else
157                let polar_first = polar == "alt";
158
159                let (first_param, second_param) = if polar_first {
160                    (&polar_ident, &azimuth_ident)
161                } else {
162                    (&azimuth_ident, &polar_ident)
163                };
164
165                // new_raw always takes (polar, azimuth)
166                let (polar_arg, azimuth_arg) = (&polar_ident, &azimuth_ident);
167
168                let polar_doc = format!("Returns the {} angle in degrees.", polar);
169                let azimuth_doc = format!("Returns the {} angle in degrees.", azimuth);
170                let dir_new_doc = format!(
171                    "Creates a new direction from {} and {} (canonicalized).",
172                    first_param, second_param
173                );
174                let pos_new_doc = format!(
175                    "Creates a new position from {}, {}, and distance (canonicalized).",
176                    first_param, second_param
177                );
178
179                quote! {
180                    // ── Direction<F>: inherent named constructor + getters ──
181
182                    impl ::affn::spherical::Direction<#name> {
183                        #[doc = #dir_new_doc]
184                        #[inline]
185                        pub fn new(
186                            #first_param: ::qtty::Degrees,
187                            #second_param: ::qtty::Degrees,
188                        ) -> Self {
189                            Self::new_raw(
190                                #polar_arg .wrap_quarter_fold(),
191                                #azimuth_arg .normalize(),
192                            )
193                        }
194
195                        #[doc = #polar_doc]
196                        #[inline]
197                        pub fn #polar_ident(&self) -> ::qtty::Degrees {
198                            self.polar
199                        }
200
201                        #[doc = #azimuth_doc]
202                        #[inline]
203                        pub fn #azimuth_ident(&self) -> ::qtty::Degrees {
204                            self.azimuth
205                        }
206                    }
207
208                    // ── Position<C, F, U>: inherent named getters (any center) ──
209
210                    impl<C, U> ::affn::spherical::Position<C, #name, U>
211                    where
212                        C: ::affn::centers::ReferenceCenter,
213                        U: ::qtty::LengthUnit,
214                    {
215                        #[doc = #polar_doc]
216                        #[inline]
217                        pub fn #polar_ident(&self) -> ::qtty::Degrees {
218                            self.polar
219                        }
220
221                        #[doc = #azimuth_doc]
222                        #[inline]
223                        pub fn #azimuth_ident(&self) -> ::qtty::Degrees {
224                            self.azimuth
225                        }
226                    }
227
228                    // ── Position<C, F, U>: named constructor (only Params = ()) ──
229
230                    impl<C, U> ::affn::spherical::Position<C, #name, U>
231                    where
232                        C: ::affn::centers::ReferenceCenter<Params = ()>,
233                        U: ::qtty::LengthUnit,
234                    {
235                        #[doc = #pos_new_doc]
236                        #[inline]
237                        pub fn new<T: Into<::qtty::Quantity<U>>>(
238                            #first_param: ::qtty::Degrees,
239                            #second_param: ::qtty::Degrees,
240                            distance: T,
241                        ) -> Self {
242                            Self::new_raw(
243                                #polar_arg .wrap_quarter_fold(),
244                                #azimuth_arg .normalize(),
245                                distance.into(),
246                            )
247                        }
248                    }
249                }
250            } else {
251                quote! {}
252            };
253
254            quote! {
255                #naming_impl
256                #inherent_impl
257            }
258        }
259        (Some(_), None) => {
260            return Err(syn::Error::new_spanned(
261                &input.ident,
262                "`polar` attribute requires `azimuth` to also be specified",
263            ));
264        }
265        (None, Some(_)) => {
266            return Err(syn::Error::new_spanned(
267                &input.ident,
268                "`azimuth` attribute requires `polar` to also be specified",
269            ));
270        }
271        (None, None) => quote! {},
272    };
273
274    let expanded = quote! {
275        impl ::affn::frames::ReferenceFrame for #name {
276            fn frame_name() -> &'static str {
277                #name_expr
278            }
279        }
280
281        #spherical_impl
282    };
283
284    Ok(expanded)
285}
286
287fn parse_frame_attributes(input: &DeriveInput) -> syn::Result<FrameAttributes> {
288    let mut attrs = FrameAttributes::default();
289
290    for attr in &input.attrs {
291        if attr.path().is_ident("frame") {
292            let nested = attr.parse_args_with(
293                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
294            )?;
295
296            for meta in nested {
297                match &meta {
298                    Meta::Path(path) if path.is_ident("inherent") => {
299                        attrs.inherent = true;
300                    }
301                    Meta::NameValue(nv) => {
302                        let value_str = extract_string_literal(&nv.value)?;
303
304                        if nv.path.is_ident("name") {
305                            attrs.name = Some(value_str);
306                        } else if nv.path.is_ident("polar") {
307                            attrs.polar = Some(value_str);
308                        } else if nv.path.is_ident("azimuth") {
309                            attrs.azimuth = Some(value_str);
310                        } else if nv.path.is_ident("distance") {
311                            attrs.distance = Some(value_str);
312                        }
313                    }
314                    _ => {}
315                }
316            }
317        }
318    }
319
320    Ok(attrs)
321}
322
323/// Extract a string literal from an expression, or return an error.
324fn extract_string_literal(expr: &Expr) -> syn::Result<String> {
325    if let Expr::Lit(expr_lit) = expr {
326        if let Lit::Str(lit_str) = &expr_lit.lit {
327            return Ok(lit_str.value());
328        }
329    }
330    Err(syn::Error::new_spanned(expr, "expected string literal"))
331}
332
333// =============================================================================
334// ReferenceCenter derive
335// =============================================================================
336
337/// Derive macro for implementing [`ReferenceCenter`](affn::centers::ReferenceCenter).
338///
339/// By default, this also implements [`AffineCenter`](affn::centers::AffineCenter).
340///
341/// # Example
342///
343/// ```rust,ignore
344/// use affn::ReferenceCenter;
345///
346/// #[derive(Debug, Copy, Clone, ReferenceCenter)]
347/// struct Heliocentric;
348///
349/// assert_eq!(Heliocentric::center_name(), "Heliocentric");
350/// ```
351///
352/// ## Custom Parameters
353///
354/// ```rust,ignore
355/// use affn::ReferenceCenter;
356///
357/// #[derive(Clone, Debug, Default, PartialEq)]
358/// struct ObserverLocation {
359///     lat: f64,
360///     lon: f64,
361/// }
362///
363/// #[derive(Debug, Copy, Clone, ReferenceCenter)]
364/// #[center(params = ObserverLocation)]
365/// struct Topocentric;
366/// ```
367///
368/// ## Skip AffineCenter
369///
370/// ```rust,ignore
371/// #[derive(Debug, Copy, Clone, ReferenceCenter)]
372/// #[center(affine = false)]
373/// struct NonAffineCenter;
374/// ```
375#[proc_macro_derive(ReferenceCenter, attributes(center))]
376pub fn derive_reference_center(input: TokenStream) -> TokenStream {
377    let input = parse_macro_input!(input as DeriveInput);
378    match derive_reference_center_impl(input) {
379        Ok(tokens) => tokens.into(),
380        Err(err) => err.to_compile_error().into(),
381    }
382}
383
384struct CenterAttributes {
385    name: Option<String>,
386    params: Option<Type>,
387    affine: bool,
388}
389
390impl Default for CenterAttributes {
391    fn default() -> Self {
392        Self {
393            name: None,
394            params: None,
395            affine: true,
396        }
397    }
398}
399
400fn derive_reference_center_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
401    let name = &input.ident;
402
403    // Parse #[center(...)] attributes
404    let attrs = parse_center_attributes(&input)?;
405
406    let name_expr = match attrs.name {
407        Some(custom_name) => quote! { #custom_name },
408        None => {
409            let name_str = name.to_string();
410            quote! { #name_str }
411        }
412    };
413
414    let params_type = match attrs.params {
415        Some(ty) => quote! { #ty },
416        None => quote! { () },
417    };
418
419    let affine_impl = if attrs.affine {
420        quote! {
421            impl ::affn::centers::AffineCenter for #name {}
422        }
423    } else {
424        quote! {}
425    };
426
427    let expanded = quote! {
428        impl ::affn::centers::ReferenceCenter for #name {
429            type Params = #params_type;
430
431            fn center_name() -> &'static str {
432                #name_expr
433            }
434        }
435
436        #affine_impl
437    };
438
439    Ok(expanded)
440}
441
442fn parse_center_attributes(input: &DeriveInput) -> syn::Result<CenterAttributes> {
443    let mut attrs = CenterAttributes::default();
444
445    for attr in &input.attrs {
446        if attr.path().is_ident("center") {
447            let nested = attr.parse_args_with(
448                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
449            )?;
450
451            for meta in nested {
452                match meta {
453                    Meta::NameValue(nv) => {
454                        if nv.path.is_ident("name") {
455                            if let Expr::Lit(expr_lit) = &nv.value {
456                                if let Lit::Str(lit_str) = &expr_lit.lit {
457                                    attrs.name = Some(lit_str.value());
458                                    continue;
459                                }
460                            }
461                            return Err(syn::Error::new_spanned(
462                                &nv.value,
463                                "expected string literal for `name`",
464                            ));
465                        } else if nv.path.is_ident("params") {
466                            // Parse as a type path
467                            if let Expr::Path(expr_path) = &nv.value {
468                                attrs.params = Some(Type::Path(syn::TypePath {
469                                    qself: None,
470                                    path: expr_path.path.clone(),
471                                }));
472                                continue;
473                            }
474                            return Err(syn::Error::new_spanned(
475                                &nv.value,
476                                "expected type for `params`",
477                            ));
478                        } else if nv.path.is_ident("affine") {
479                            if let Expr::Lit(expr_lit) = &nv.value {
480                                if let Lit::Bool(lit_bool) = &expr_lit.lit {
481                                    attrs.affine = lit_bool.value();
482                                    continue;
483                                }
484                            }
485                            return Err(syn::Error::new_spanned(
486                                &nv.value,
487                                "expected boolean for `affine`",
488                            ));
489                        }
490                    }
491                    _ => {}
492                }
493            }
494        }
495    }
496
497    Ok(attrs)
498}