astarte_device_sdk_derive/
lib.rs1use std::{collections::HashMap, fmt::Debug};
22
23use proc_macro::TokenStream;
24
25use proc_macro2::Ident;
26use quote::{quote, quote_spanned};
27use syn::{
28 parse::{Parse, ParseStream},
29 parse_macro_input, parse_quote,
30 punctuated::Punctuated,
31 spanned::Spanned,
32 Attribute, Expr, GenericParam, Generics, MetaNameValue, Token,
33};
34
35use crate::{case::RenameRule, event::FromEventDerive};
36
37mod case;
38mod event;
39
40#[derive(Debug, Default)]
52struct ObjectAttributes {
53 rename_all: Option<RenameRule>,
55}
56
57impl ObjectAttributes {
58 fn merge(self, other: Self) -> Self {
60 let rename_all = other.rename_all.or(self.rename_all);
61
62 Self { rename_all }
63 }
64}
65
66impl Parse for ObjectAttributes {
67 fn parse(input: ParseStream) -> syn::Result<Self> {
68 let mut attrs = parse_name_value_attrs(input)?;
69
70 let rename_all = attrs
71 .remove("rename_all")
72 .map(|expr| {
73 parse_str_lit(&expr).and_then(|rename| {
74 RenameRule::from_str(&rename)
75 .map_err(|_| syn::Error::new(expr.span(), "invalid rename rule"))
76 })
77 })
78 .transpose()?;
79
80 if let Some((_, expr)) = attrs.iter().next() {
81 return Err(syn::Error::new(expr.span(), "unrecognized attribute"));
82 }
83
84 Ok(Self { rename_all })
85 }
86}
87
88fn parse_name_value_attrs(
92 input: &syn::parse::ParseBuffer<'_>,
93) -> Result<HashMap<String, Expr>, syn::Error> {
94 Punctuated::<MetaNameValue, Token![,]>::parse_terminated(input)?
95 .into_iter()
96 .map(|v| {
97 v.path
98 .get_ident()
99 .ok_or_else(|| {
100 syn::Error::new(v.span(), "expected an identifier like `rename_all`")
101 })
102 .map(|i| (i.to_string(), v.value))
103 })
104 .collect::<syn::Result<_>>()
105}
106
107fn parse_str_lit(expr: &Expr) -> syn::Result<String> {
109 match expr {
110 Expr::Lit(syn::ExprLit {
111 lit: syn::Lit::Str(lit),
112 ..
113 }) => Ok(lit.value()),
114 _ => Err(syn::Error::new(
115 expr.span(),
116 "expression must be a string literal",
117 )),
118 }
119}
120
121fn parse_bool_lit(expr: &Expr) -> syn::Result<bool> {
123 match expr {
124 Expr::Lit(syn::ExprLit {
125 lit: syn::Lit::Bool(lit),
126 ..
127 }) => Ok(lit.value()),
128 _ => Err(syn::Error::new(
129 expr.span(),
130 "expression must be a bool literal",
131 )),
132 }
133}
134
135struct ObjectDerive {
146 name: Ident,
147 attrs: ObjectAttributes,
148 fields: Vec<Ident>,
149 generics: Generics,
150}
151
152impl ObjectDerive {
153 fn quote(&self) -> proc_macro2::TokenStream {
154 let rename_rule = self.attrs.rename_all.unwrap_or_default();
155
156 let name = &self.name;
157 let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
158 let capacity = self.fields.len();
159 let fields = self.fields.iter().map(|i| {
160 let name = i.to_string();
161 let name = rename_rule.apply_to_field(&name);
162 quote_spanned! {i.span() =>
163 #[allow(unknown_lints)]
165 #[allow(clippy::unnecessary_fallible_conversions)]
166 let v: astarte_device_sdk::types::AstarteData = ::std::convert::TryInto::try_into(value.#i)?;
167 object.insert(#name.to_string(), v);
168 }
169 });
170
171 quote! {
172 impl #impl_generics ::std::convert::TryFrom<#name #ty_generics> for astarte_device_sdk::aggregate::AstarteObject #where_clause {
173 type Error = astarte_device_sdk::error::Error;
174
175 fn try_from(value: #name #ty_generics) -> ::std::result::Result<Self, Self::Error> {
176 let mut object = Self::with_capacity(#capacity);
177 #(#fields)*
178 Ok(object)
179 }
180 }
181 }
182 }
183
184 pub fn add_trait_bound(mut generics: Generics) -> Generics {
185 for param in &mut generics.params {
186 if let GenericParam::Type(ref mut type_param) = *param {
187 type_param.bounds.push(parse_quote!(
188 std::convert::TryInto<astarte_device_sdk::types::AstarteData, Error = astarte_device_sdk::error::Error>
189 ));
190 }
191 }
192 generics
193 }
194}
195
196impl Parse for ObjectDerive {
197 fn parse(input: ParseStream) -> syn::Result<Self> {
198 let ast = syn::DeriveInput::parse(input)?;
199
200 let attrs = ast
202 .attrs
203 .iter()
204 .filter_map(|a| parse_attribute_list::<ObjectAttributes>(a, "astarte_object"))
205 .collect::<Result<Vec<_>, _>>()?
206 .into_iter()
207 .reduce(|first, second| first.merge(second))
208 .unwrap_or_default();
209
210 let fields = parse_struct_fields(&ast)?;
211
212 let name = ast.ident;
213
214 let generics = Self::add_trait_bound(ast.generics);
215
216 Ok(Self {
217 name,
218 attrs,
219 fields,
220 generics,
221 })
222 }
223}
224
225fn parse_struct_fields(ast: &syn::DeriveInput) -> Result<Vec<Ident>, syn::Error> {
227 let syn::Data::Struct(ref st) = ast.data else {
228 return Err(syn::Error::new(ast.span(), "a named struct is required"));
229 };
230 let syn::Fields::Named(ref fields_named) = st.fields else {
231 return Err(syn::Error::new(ast.span(), "a nemed struct is required"));
232 };
233
234 let fields = fields_named
235 .named
236 .iter()
237 .map(|field| {
238 field
239 .ident
240 .clone()
241 .ok_or_else(|| syn::Error::new(field.span(), "field is not an ident"))
242 })
243 .collect::<Result<_, _>>()?;
244
245 Ok(fields)
246}
247
248pub(crate) fn parse_attribute_list<T>(attr: &Attribute, name: &str) -> Option<syn::Result<T>>
253where
254 T: Parse,
255{
256 let is_attr = attr
257 .path()
258 .get_ident()
259 .map(ToString::to_string)
260 .filter(|ident| ident == name)
261 .is_some();
262
263 if !is_attr {
264 return None;
265 }
266
267 match &attr.meta {
269 syn::Meta::Path(_) => None,
272 syn::Meta::NameValue(name) => Some(Err(syn::Error::new(
273 name.span(),
274 "cannot be used as a named value",
275 ))),
276 syn::Meta::List(list) => Some(syn::parse2::<T>(list.tokens.clone())),
277 }
278}
279
280#[proc_macro_derive(IntoAstarteObject, attributes(astarte_object))]
291pub fn astarte_aggregate_derive(input: TokenStream) -> TokenStream {
292 let aggregate = parse_macro_input!(input as ObjectDerive);
295
296 aggregate.quote().into()
298}
299
300#[proc_macro_derive(FromEvent, attributes(from_event, mapping))]
341pub fn from_event_derive(input: TokenStream) -> TokenStream {
342 let from_event = parse_macro_input!(input as FromEventDerive);
345
346 from_event.quote().into()
348}