// mavspec_rust_gen 0.6.7
//
// Rust code generation module for MAVSpec.
// Documentation
use mavinspect::protocol::Dialect;
use quote::{format_ident, quote};

use crate::conventions::{
    dialect_enum_name, dialect_enum_specta_name, message_struct_name, messages_enum_entry_name,
    microservice_msg_enum_specta_name,
};
use crate::conventions::{microservice_doc_mention, microservice_msg_enum_name};
use crate::specs::dialects::dialect::DialectModuleSpec;
use crate::specs::Spec;
use crate::templates::helpers::{make_serde_derive_annotation, make_specta_derive_annotation};

/// Generates the Rust source of a dialect module (or of a microservice module of a dialect).
///
/// The produced file contains:
/// - an inner module doc heading,
/// - `pub mod messages;` and `pub mod enums;` declarations, plus `pub mod microservices;` when
///   the dialect has microservices,
/// - the messages enum deriving `mavspec::rust::derive::Dialect`, with one variant per message,
/// - conversions, metadata and tests produced by the sibling generators.
///
/// # Panics
///
/// Panics if the generated token stream cannot be parsed back into a [`syn::File`].
pub fn dialect_module(spec: &DialectModuleSpec) -> syn::File {
    // Everything that differs between a whole dialect and a single microservice of it:
    // the module heading, the doc string of the enum and the Specta type name.
    let (leading_module_comment, messages_enum_comment, specta_name) =
        if let Some(msrv_name) = spec.msrv_name() {
            let msrv_mention = microservice_doc_mention(msrv_name);
            (
                format!(
                    "# MAVLink microservice {} for dialect `{}`",
                    msrv_mention,
                    spec.display_name()
                ),
                format!(
                    " Enum containing all messages within {} microservice of `{}` dialect.",
                    msrv_mention,
                    spec.display_name()
                ),
                microservice_msg_enum_specta_name(msrv_name, spec.canonical_name()),
            )
        } else {
            (
                format!("# MAVLink dialect `{}`", spec.display_name()),
                format!(
                    " Enum containing all messages within `{}` dialect.",
                    spec.display_name()
                ),
                dialect_enum_specta_name(spec.canonical_name()),
            )
        };

    let mavspec_import = spec.params().mavspec_import();

    // `messages` and `enums` are always declared; `microservices` only when there are any.
    // `Option<TokenStream>` interpolates to nothing when `None`.
    let microservices_module = spec
        .has_microservices()
        .then(|| quote! { pub mod microservices; });
    let dialect_imports = quote! {
        pub mod messages;
        pub mod enums;
        #microservices_module
    };

    // Optional `#[dialect(..)]` / `#[version(..)]` attributes; absent values emit no tokens.
    let dialect_id_attr = spec.dialect_id().map(|id| quote! { #[dialect(#id)] });
    let dialect_version_attr = spec
        .version()
        .map(|version| quote! { #[version(#version)] });

    let derive_serde = make_serde_derive_annotation(spec.params().serde);
    let derive_specta =
        make_specta_derive_annotation(spec.params().specta, Some(specta_name.as_str()));

    // One documented enum variant per message, wrapping the matching message struct.
    let messages_variants = spec.messages().iter().map(|message| {
        let name = message.name();
        let variant_doc = format!(" MAVLink message `{}`.", name);
        let variant_ident = format_ident!("{}", messages_enum_entry_name(name));
        let payload_ident = format_ident!("{}", message_struct_name(name));

        quote! {
            #[doc = #variant_doc]
            #variant_ident(messages::#payload_ident),
        }
    });

    let dialect_enum_ident = make_dialect_enum_ident(spec);

    let conversions = generate_conversions(spec);
    let metadata = generate_metadata(spec);
    let tests = generate_tests(spec);

    let module_tokens = quote! {
        #![doc = #leading_module_comment]

        #mavspec_import
        #dialect_imports

        #[doc = #messages_enum_comment]
        #[derive(mavspec::rust::derive::Dialect)]
        #dialect_id_attr
        #dialect_version_attr
        #[derive(core::clone::Clone, core::fmt::Debug, core::cmp::PartialEq)]
        #derive_specta
        #derive_serde
        #[allow(clippy::large_enum_variant)]
        pub enum #dialect_enum_ident {
            #(#messages_variants)*
        }

        #conversions

        #metadata

        #tests
    };

    syn::parse2(module_tokens).expect("Failed to parse dialect")
}

/// Generates a `#[cfg(test)]` module for the dialect enum.
///
/// The emitted `retrieve_message_info` test looks up every message ID of this module through
/// `Dialect::message_info` and asserts that the lookup succeeds and reports the same ID.
///
/// Returns an empty token stream when test generation is disabled in the spec parameters or
/// when the module has no messages.
fn generate_tests(spec: &DialectModuleSpec) -> proc_macro2::TokenStream {
    if !spec.params().generate_tests || spec.messages().is_empty() {
        return quote!();
    }

    let enum_ident = make_dialect_enum_ident(spec);
    let message_ids = spec.messages().iter().map(|message| message.id());

    quote! {
        #[cfg(test)]
        mod tests {
            use mavspec::rust::spec::Dialect;

            use super::*;

            #[test]
            fn retrieve_message_info() {
                for id in [
                    #(#message_ids,)*
                ] {
                    let msg_info = #enum_ident::message_info(id);
                    assert!(msg_info.is_ok());
                    assert_eq!(msg_info.unwrap().id(), id);
                }
            }
        }
    }
}

/// Generates conversions between a microservice message enum and its parent dialect enum.
///
/// For a microservice module (`spec.msrv()` is `Some`) this emits:
///
/// - `From<MicroserviceEnum> for ParentDialectEnum`: every microservice variant is mapped onto
///   the parent variant of the same name, converting the payload with `.into()`.
/// - `TryFrom<ParentDialectEnum> for MicroserviceEnum`: when [`is_same_message`] reports equal
///   strict fingerprints the payload is used as-is, otherwise it is converted with
///   `try_into()?`. Any parent message not covered by the microservice yields
///   `SpecError::NotInDialect` carrying that message's ID.
///
/// The parent dialect enum is referenced as `super::super::super::<ParentEnum>`, i.e. relative
/// to the generated microservice module.
///
/// For a plain dialect module (`spec.msrv()` is `None`) an empty token stream is returned.
fn generate_conversions(spec: &DialectModuleSpec) -> proc_macro2::TokenStream {
    match spec.msrv() {
        Some(msrv) => {
            let msrv_enum_ident = format_ident!("{}", microservice_msg_enum_name(msrv.name()));

            let parent_dialect_enum_ident =
                format_ident!("{}", dialect_enum_name(msrv.parent().name()));
            let parent_dialect_path_ident =
                quote! { super::super::super::#parent_dialect_enum_ident };

            let from_msrv_to_dialect_arms = spec.messages().iter().map(|msg| {
                let msg_variant_ident = format_ident!("{}", messages_enum_entry_name(msg.name()));
                quote! {
                    #msrv_enum_ident::#msg_variant_ident(msg) => #parent_dialect_path_ident::#msg_variant_ident(msg.into()),
                }
            });

            let from_dialect_to_msrv_arms = spec.messages().iter().map(|&msg| {
                let msg_variant_ident = format_ident!("{}", messages_enum_entry_name(msg.name()));
                // Matching strict fingerprints: the payload is passed through unchanged
                // (presumably the very same type). Otherwise it must be converted, which may fail.
                let msg_convert =
                    if is_same_message(msg.name(), spec.dialect(), msrv.parent()).is_some() {
                        quote! { msg }
                    } else {
                        quote! { msg.try_into()? }
                    };

                quote! {
                    #parent_dialect_path_ident::#msg_variant_ident(msg) => #msrv_enum_ident::#msg_variant_ident(#msg_convert),
                }
            });

            // Both generated `match`es may contain an unreachable catch-all arm: in `from` when
            // the microservice enum is fully covered, and in `try_from` when the microservice
            // contains every message of the parent dialect. The lint is allowed in both places
            // so generated crates stay warning-free.
            quote! {
                impl core::convert::From<#msrv_enum_ident> for #parent_dialect_path_ident {
                    fn from(value: #msrv_enum_ident) -> Self {
                        #[allow(unreachable_patterns)]
                        match value {
                            #(#from_msrv_to_dialect_arms)*
                            _ => unreachable!(),
                        }
                    }
                }

                impl core::convert::TryFrom<#parent_dialect_path_ident> for #msrv_enum_ident {
                    type Error = mavspec::rust::spec::SpecError;

                    #[allow(unreachable_patterns)]
                    fn try_from(value: #parent_dialect_path_ident) -> Result<Self, Self::Error> {
                        use mavspec::rust::spec::MessageSpec;
                        Ok(match value {
                            #(#from_dialect_to_msrv_arms)*
                            msg => return Err(Self::Error::NotInDialect(msg.id())),
                        })
                    }
                }
            }
        }
        None => quote! {},
    }
}

/// Generates metadata helpers on the dialect (or microservice) enum.
///
/// Currently this is a single associated function, `message_ids()`, which iterates over every
/// message ID of this module in ascending order. Returns an empty token stream when the
/// `metadata` spec parameter is disabled.
fn generate_metadata(spec: &DialectModuleSpec) -> proc_macro2::TokenStream {
    if spec.params().metadata {
        let enum_ident = make_dialect_enum_ident(spec);

        // IDs are emitted in ascending order regardless of the order of `spec.messages()`.
        let mut sorted_ids: Vec<_> = spec.messages().iter().map(|message| message.id()).collect();
        sorted_ids.sort();

        quote! {
            impl #enum_ident {
                /// Iterator over all message IDs within this dialect.
                ///
                /// Requires `metadata` feature flag to be enabled.
                pub fn message_ids() -> impl Iterator<Item=mavspec::rust::spec::types::MessageId> {
                    [#(#sorted_ids,)*].iter().copied()
                }
            }
        }
    } else {
        quote! {}
    }
}

/// Returns the identifier of the messages enum generated for this module.
///
/// For a microservice module this is the microservice message enum name derived from the
/// microservice name; otherwise it is the dialect enum name reported by the spec.
fn make_dialect_enum_ident(specs: &DialectModuleSpec) -> syn::Ident {
    // Only build the identifier that is actually returned.
    match specs.msrv_name() {
        Some(msrv_name) => format_ident!("{}", microservice_msg_enum_name(msrv_name)),
        None => format_ident!("{}", specs.enum_name()),
    }
}

/// Checks whether message `msg_name` has the same definition in `dialect` and `parent_dialect`.
///
/// Returns `Some(())` when the message exists in both dialects and its strict fingerprint,
/// computed against each message's own dialect, is equal in both. Returns `None` otherwise,
/// including when the message is missing from either dialect.
fn is_same_message(msg_name: &str, dialect: &Dialect, parent_dialect: &Dialect) -> Option<()> {
    let msg = dialect.get_message_by_name(msg_name)?;
    let parent_msg = parent_dialect.get_message_by_name(msg_name)?;

    // `dialect` and `parent_dialect` are already references; pass them straight through
    // instead of re-borrowing them into `&&Dialect`.
    let same_definition = msg.fingerprint_strict(Some(dialect))
        == parent_msg.fingerprint_strict(Some(parent_dialect));

    if same_definition {
        Some(())
    } else {
        None
    }
}