astarte_device_sdk_derive/
lib.rs1use std::{collections::HashMap, fmt::Debug};
22
23use proc_macro::TokenStream;
24
25use proc_macro2::Ident;
26use quote::{quote, quote_spanned};
27use syn::{
28    parse::{Parse, ParseStream},
29    parse_macro_input, parse_quote,
30    punctuated::Punctuated,
31    spanned::Spanned,
32    Attribute, Expr, GenericParam, Generics, MetaNameValue, Token,
33};
34
35use crate::{case::RenameRule, event::FromEventDerive};
36
37mod case;
38mod event;
39
40#[derive(Debug, Default)]
52struct ObjectAttributes {
53    rename_all: Option<RenameRule>,
55}
56
57impl ObjectAttributes {
58    fn merge(self, other: Self) -> Self {
60        let rename_all = other.rename_all.or(self.rename_all);
61
62        Self { rename_all }
63    }
64}
65
66impl Parse for ObjectAttributes {
67    fn parse(input: ParseStream) -> syn::Result<Self> {
68        let mut attrs = parse_name_value_attrs(input)?;
69
70        let rename_all = attrs
71            .remove("rename_all")
72            .map(|expr| {
73                parse_str_lit(&expr).and_then(|rename| {
74                    RenameRule::from_str(&rename)
75                        .map_err(|_| syn::Error::new(expr.span(), "invalid rename rule"))
76                })
77            })
78            .transpose()?;
79
80        if let Some((_, expr)) = attrs.iter().next() {
81            return Err(syn::Error::new(expr.span(), "unrecognized attribute"));
82        }
83
84        Ok(Self { rename_all })
85    }
86}
87
88fn parse_name_value_attrs(
92    input: &syn::parse::ParseBuffer<'_>,
93) -> Result<HashMap<String, Expr>, syn::Error> {
94    Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)?
95        .into_iter()
96        .map(|v| {
97            v.path
98                .get_ident()
99                .ok_or_else(|| {
100                    syn::Error::new(v.span(), "expected an identifier like `rename_all`")
101                })
102                .map(|i| (i.to_string(), v.value))
103        })
104        .collect::<syn::Result<_>>()
105}
106
107fn parse_str_lit(expr: &Expr) -> syn::Result<String> {
109    match expr {
110        Expr::Lit(syn::ExprLit {
111            lit: syn::Lit::Str(lit),
112            ..
113        }) => Ok(lit.value()),
114        _ => Err(syn::Error::new(
115            expr.span(),
116            "expression must be a string literal",
117        )),
118    }
119}
120
121fn parse_bool_lit(expr: &Expr) -> syn::Result<bool> {
123    match expr {
124        Expr::Lit(syn::ExprLit {
125            lit: syn::Lit::Bool(lit),
126            ..
127        }) => Ok(lit.value()),
128        _ => Err(syn::Error::new(
129            expr.span(),
130            "expression must be a bool literal",
131        )),
132    }
133}
134
135struct ObjectDerive {
146    name: Ident,
147    attrs: ObjectAttributes,
148    fields: Vec<Ident>,
149    generics: Generics,
150}
151
152impl ObjectDerive {
153    fn quote(&self) -> proc_macro2::TokenStream {
154        let rename_rule = self.attrs.rename_all.unwrap_or_default();
155
156        let name = &self.name;
157        let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
158        let capacity = self.fields.len();
159        let fields = self.fields.iter().map(|i| {
160            let name = i.to_string();
161            let name = rename_rule.apply_to_field(&name);
162            quote_spanned! {i.span() =>
163                #[allow(unknown_lints)]
165                #[allow(clippy::unnecessary_fallible_conversions)]
166                let v: astarte_device_sdk::types::AstarteData = ::std::convert::TryInto::try_into(value.#i)?;
167                object.insert(#name.to_string(), v);
168            }
169        });
170
171        quote! {
172            impl #impl_generics ::std::convert::TryFrom<#name #ty_generics> for astarte_device_sdk::aggregate::AstarteObject #where_clause {
173                type Error = astarte_device_sdk::error::Error;
174
175                fn try_from(value: #name #ty_generics) -> ::std::result::Result<Self, Self::Error> {
176                    let mut object = Self::with_capacity(#capacity);
177                    #(#fields)*
178                    Ok(object)
179                }
180            }
181        }
182    }
183
184    pub fn add_trait_bound(mut generics: Generics) -> Generics {
185        for param in &mut generics.params {
186            if let GenericParam::Type(ref mut type_param) = *param {
187                type_param.bounds.push(parse_quote!(
188                    std::convert::TryInto<astarte_device_sdk::types::AstarteData, Error = astarte_device_sdk::error::Error>
189                ));
190            }
191        }
192        generics
193    }
194}
195
196impl Parse for ObjectDerive {
197    fn parse(input: ParseStream) -> syn::Result<Self> {
198        let ast = syn::DeriveInput::parse(input)?;
199
200        let attrs = ast
202            .attrs
203            .iter()
204            .filter_map(|a| parse_attribute_list::<ObjectAttributes>(a, "astarte_object"))
205            .collect::<Result<Vec<_>, _>>()?
206            .into_iter()
207            .reduce(|first, second| first.merge(second))
208            .unwrap_or_default();
209
210        let fields = parse_struct_fields(&ast)?;
211
212        let name = ast.ident;
213
214        let generics = Self::add_trait_bound(ast.generics);
215
216        Ok(Self {
217            name,
218            attrs,
219            fields,
220            generics,
221        })
222    }
223}
224
225fn parse_struct_fields(ast: &syn::DeriveInput) -> Result<Vec<Ident>, syn::Error> {
227    let syn::Data::Struct(ref st) = ast.data else {
228        return Err(syn::Error::new(ast.span(), "a named struct is required"));
229    };
230    let syn::Fields::Named(ref fields_named) = st.fields else {
231        return Err(syn::Error::new(ast.span(), "a nemed struct is required"));
232    };
233
234    let fields = fields_named
235        .named
236        .iter()
237        .map(|field| {
238            field
239                .ident
240                .clone()
241                .ok_or_else(|| syn::Error::new(field.span(), "field is not an ident"))
242        })
243        .collect::<Result<_, _>>()?;
244
245    Ok(fields)
246}
247
248pub(crate) fn parse_attribute_list<T>(attr: &Attribute, name: &str) -> Option<syn::Result<T>>
253where
254    T: Parse,
255{
256    let is_attr = attr
257        .path()
258        .get_ident()
259        .map(ToString::to_string)
260        .filter(|ident| ident == name)
261        .is_some();
262
263    if !is_attr {
264        return None;
265    }
266
267    match &attr.meta {
269        syn::Meta::Path(_) => None,
272        syn::Meta::NameValue(name) => Some(Err(syn::Error::new(
273            name.span(),
274            "cannot be used as a named value",
275        ))),
276        syn::Meta::List(list) => Some(syn::parse2::<T>(list.tokens.clone())),
277    }
278}
279
280#[proc_macro_derive(IntoAstarteObject, attributes(astarte_object))]
291pub fn astarte_aggregate_derive(input: TokenStream) -> TokenStream {
292    let aggregate = parse_macro_input!(input as ObjectDerive);
295
296    aggregate.quote().into()
298}
299
300#[proc_macro_derive(FromEvent, attributes(from_event, mapping))]
341pub fn from_event_derive(input: TokenStream) -> TokenStream {
342    let from_event = parse_macro_input!(input as FromEventDerive);
345
346    from_event.quote().into()
348}