rs_matter_macros/
lib.rs

/*
 *
 *    Copyright (c) 2020-2022 Project CHIP Authors
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */
17
18use proc_macro::TokenStream;
19use proc_macro2::{Ident, Span};
20use quote::{format_ident, quote};
21use syn::Lit::{Int, Str};
22use syn::NestedMeta::{Lit, Meta};
23use syn::{parse_macro_input, DeriveInput, Lifetime};
24use syn::{
25    Meta::{List, NameValue},
26    MetaList, MetaNameValue, Type,
27};
28
29struct TlvArgs {
30    start: u8,
31    datatype: String,
32    unordered: bool,
33    lifetime: syn::Lifetime,
34}
35
36impl Default for TlvArgs {
37    fn default() -> Self {
38        Self {
39            start: 0,
40            datatype: "struct".to_string(),
41            unordered: false,
42            lifetime: Lifetime::new("'_", Span::call_site()),
43        }
44    }
45}
46
47fn parse_tlvargs(ast: &DeriveInput) -> TlvArgs {
48    let mut tlvargs: TlvArgs = Default::default();
49
50    if !ast.attrs.is_empty() {
51        if let List(MetaList {
52            path,
53            paren_token: _,
54            nested,
55        }) = ast.attrs[0].parse_meta().unwrap()
56        {
57            if path.is_ident("tlvargs") {
58                for a in nested {
59                    if let Meta(NameValue(MetaNameValue {
60                        path: key_path,
61                        eq_token: _,
62                        lit: key_val,
63                    })) = a
64                    {
65                        if key_path.is_ident("start") {
66                            if let Int(litint) = key_val {
67                                tlvargs.start = litint.base10_parse::<u8>().unwrap();
68                            }
69                        } else if key_path.is_ident("lifetime") {
70                            if let Str(litstr) = key_val {
71                                tlvargs.lifetime =
72                                    Lifetime::new(&litstr.value(), Span::call_site());
73                            }
74                        } else if key_path.is_ident("datatype") {
75                            if let Str(litstr) = key_val {
76                                tlvargs.datatype = litstr.value();
77                            }
78                        } else if key_path.is_ident("unordered") {
79                            tlvargs.unordered = true;
80                        }
81                    }
82                }
83            }
84        }
85    }
86    tlvargs
87}
88
89fn parse_tag_val(field: &syn::Field) -> Option<u8> {
90    if !field.attrs.is_empty() {
91        if let List(MetaList {
92            path,
93            paren_token: _,
94            nested,
95        }) = field.attrs[0].parse_meta().unwrap()
96        {
97            if path.is_ident("tagval") {
98                for a in nested {
99                    if let Lit(Int(litint)) = a {
100                        return Some(litint.base10_parse::<u8>().unwrap());
101                    }
102                }
103            }
104        }
105    }
106    None
107}
108
109fn get_crate_name() -> String {
110    let found_crate = proc_macro_crate::crate_name("rs-matter").unwrap_or_else(|err| {
111        eprintln!("Warning: defaulting to `crate` {err}");
112        proc_macro_crate::FoundCrate::Itself
113    });
114
115    match found_crate {
116        proc_macro_crate::FoundCrate::Itself => String::from("crate"),
117        proc_macro_crate::FoundCrate::Name(name) => name,
118    }
119}
120
121/// Generate a ToTlv implementation for a structure
122fn gen_totlv_for_struct(
123    fields: &syn::FieldsNamed,
124    struct_name: &proc_macro2::Ident,
125    tlvargs: &TlvArgs,
126    generics: &syn::Generics,
127) -> TokenStream {
128    let mut tag_start = tlvargs.start;
129    let datatype = format_ident!("start_{}", tlvargs.datatype);
130
131    let mut idents = Vec::new();
132    let mut tags = Vec::new();
133
134    for field in fields.named.iter() {
135        //        let field_name: &syn::Ident = field.ident.as_ref().unwrap();
136        //        let name: String = field_name.to_string();
137        //        let literal_key_str = syn::LitStr::new(&name, field.span());
138        //        let type_name = &field.ty;
139        //        keys.push(quote! { #literal_key_str });
140        idents.push(&field.ident);
141        //        types.push(type_name.to_token_stream());
142        if let Some(a) = parse_tag_val(field) {
143            tags.push(a);
144        } else {
145            tags.push(tag_start);
146            tag_start += 1;
147        }
148    }
149
150    let expanded = quote! {
151        impl #generics ToTLV for #struct_name #generics {
152            fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
153                let anchor = tw.get_tail();
154
155                if let Err(err) = (|| {
156                    tw. #datatype (tag_type)?;
157                    #(
158                        self.#idents.to_tlv(tw, TagType::Context(#tags))?;
159                    )*
160                    tw.end_container()
161                })() {
162                    tw.rewind_to(anchor);
163                    Err(err)
164                } else {
165                    Ok(())
166                }
167            }
168        }
169    };
170    //    panic!("The generated code is {}", expanded);
171    expanded.into()
172}
173
174/// Generate a ToTlv implementation for an enum
175fn gen_totlv_for_enum(
176    data_enum: &syn::DataEnum,
177    enum_name: &proc_macro2::Ident,
178    tlvargs: &TlvArgs,
179    generics: &syn::Generics,
180) -> TokenStream {
181    let mut tag_start = tlvargs.start;
182
183    let mut variant_names = Vec::new();
184    let mut types = Vec::new();
185    let mut tags = Vec::new();
186
187    for v in data_enum.variants.iter() {
188        variant_names.push(&v.ident);
189        if let syn::Fields::Unnamed(fields) = &v.fields {
190            if let Type::Path(path) = &fields.unnamed[0].ty {
191                types.push(&path.path.segments[0].ident);
192            } else {
193                panic!("Path not found {:?}", v.fields);
194            }
195        } else {
196            panic!("Unnamed field not found {:?}", v.fields);
197        }
198        tags.push(tag_start);
199        tag_start += 1;
200    }
201
202    let krate = Ident::new(&get_crate_name(), Span::call_site());
203
204    let expanded = quote! {
205        impl #generics #krate::tlv::ToTLV for #enum_name #generics {
206            fn to_tlv(&self, tw: &mut #krate::tlv::TLVWriter, tag_type: #krate::tlv::TagType) -> Result<(), #krate::error::Error> {
207                let anchor = tw.get_tail();
208
209                if let Err(err) = (|| {
210                    tw.start_struct(tag_type)?;
211                    match self {
212                        #(
213                            Self::#variant_names(c) => { c.to_tlv(tw, #krate::tlv::TagType::Context(#tags))?; },
214                        )*
215                    }
216                    tw.end_container()
217                })() {
218                    tw.rewind_to(anchor);
219                    Err(err)
220                } else {
221                    Ok(())
222                }
223            }
224        }
225    };
226
227    //    panic!("Expanded to {}", expanded);
228    expanded.into()
229}
230
231/// Derive ToTLV Macro
232///
233/// This macro works for structures. It will create an implementation
234/// of the ToTLV trait for that structure.  All the members of the
235/// structure, sequentially, will get Context tags starting from 0
236/// Some configurations are possible through the 'tlvargs' attributes.
237/// For example:
238///  #[tlvargs(start = 1, datatype = "list")]
239///
240/// start: This can be used to override the default tag from which the
241///        encoding starts (Default: 0)
242/// datatype: This can be used to define whether this data structure is
243///        to be encoded as a structure or list. Possible values: list
244///        (Default: struct)
245///
246/// Additionally, structure members can use the tagval attribute to
247/// define a specific tag to be used
248/// For example:
249///  #[argval(22)]
250///  name: u8,
251/// In the above case, the 'name' attribute will be encoded/decoded with
252/// the tag 22
253
254#[proc_macro_derive(ToTLV, attributes(tlvargs, tagval))]
255pub fn derive_totlv(item: TokenStream) -> TokenStream {
256    let ast = parse_macro_input!(item as DeriveInput);
257    let name = &ast.ident;
258
259    let tlvargs = parse_tlvargs(&ast);
260    let generics = ast.generics;
261
262    if let syn::Data::Struct(syn::DataStruct {
263        fields: syn::Fields::Named(ref fields),
264        ..
265    }) = ast.data
266    {
267        gen_totlv_for_struct(fields, name, &tlvargs, &generics)
268    } else if let syn::Data::Enum(data_enum) = ast.data {
269        gen_totlv_for_enum(&data_enum, name, &tlvargs, &generics)
270    } else {
271        panic!(
272            "Derive ToTLV - Only supported Struct for now {:?}",
273            ast.data
274        );
275    }
276}
277
278/// Generate a FromTlv implementation for a structure
279fn gen_fromtlv_for_struct(
280    fields: &syn::FieldsNamed,
281    struct_name: &proc_macro2::Ident,
282    tlvargs: TlvArgs,
283    generics: &syn::Generics,
284) -> TokenStream {
285    let mut tag_start = tlvargs.start;
286    let lifetime = tlvargs.lifetime;
287    let datatype = format_ident!("confirm_{}", tlvargs.datatype);
288
289    let mut idents = Vec::new();
290    let mut types = Vec::new();
291    let mut tags = Vec::new();
292
293    for field in fields.named.iter() {
294        let type_name = &field.ty;
295        if let Some(a) = parse_tag_val(field) {
296            // TODO: The current limitation with this is that a hard-coded integer
297            // value has to be mentioned in the tagval attribute. This is because
298            // our tags vector is for integers, and pushing an 'identifier' on it
299            // wouldn't work.
300            tags.push(a);
301        } else {
302            tags.push(tag_start);
303            tag_start += 1;
304        }
305        idents.push(&field.ident);
306
307        if let Type::Path(path) = type_name {
308            types.push(&path.path.segments[0].ident);
309        } else {
310            panic!("Don't know what to do {:?}", type_name);
311        }
312    }
313
314    let krate = Ident::new(&get_crate_name(), Span::call_site());
315
316    // Currently we don't use find_tag() because the tags come in sequential
317    // order. If ever the tags start coming out of order, we can use find_tag()
318    // instead
319    let expanded = if !tlvargs.unordered {
320        quote! {
321           impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics {
322               fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result<Self, #krate::error::Error> {
323                   let mut t_iter = t.#datatype ()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?;
324                   let mut item = t_iter.next();
325                   #(
326                       let #idents = if Some(true) == item.as_ref().map(|x| x.check_ctx_tag(#tags)) {
327                           let backup = item;
328                           item = t_iter.next();
329                           #types::from_tlv(&backup.unwrap())
330                       } else {
331                           #types::tlv_not_found()
332                       }?;
333                   )*
334                   Ok(Self {
335                       #(#idents,
336                       )*
337                   })
338               }
339           }
340        }
341    } else {
342        quote! {
343           impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics {
344               fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result<Self, #krate::error::Error> {
345                   #(
346                       let #idents = if let Ok(s) = t.find_tag(#tags as u32) {
347                           #types::from_tlv(&s)
348                       } else {
349                           #types::tlv_not_found()
350                       }?;
351                   )*
352
353                   Ok(Self {
354                       #(#idents,
355                       )*
356                   })
357               }
358           }
359        }
360    };
361    //        panic!("The generated code is {}", expanded);
362    expanded.into()
363}
364
365/// Generate a FromTlv implementation for an enum
366fn gen_fromtlv_for_enum(
367    data_enum: &syn::DataEnum,
368    enum_name: &proc_macro2::Ident,
369    tlvargs: TlvArgs,
370    generics: &syn::Generics,
371) -> TokenStream {
372    let mut tag_start = tlvargs.start;
373    let lifetime = tlvargs.lifetime;
374
375    let mut variant_names = Vec::new();
376    let mut types = Vec::new();
377    let mut tags = Vec::new();
378
379    for v in data_enum.variants.iter() {
380        variant_names.push(&v.ident);
381        if let syn::Fields::Unnamed(fields) = &v.fields {
382            if let Type::Path(path) = &fields.unnamed[0].ty {
383                types.push(&path.path.segments[0].ident);
384            } else {
385                panic!("Path not found {:?}", v.fields);
386            }
387        } else {
388            panic!("Unnamed field not found {:?}", v.fields);
389        }
390        tags.push(tag_start);
391        tag_start += 1;
392    }
393
394    let krate = Ident::new(&get_crate_name(), Span::call_site());
395
396    let expanded = quote! {
397           impl #generics #krate::tlv::FromTLV <#lifetime> for #enum_name #generics {
398               fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result<Self, #krate::error::Error> {
399                   let mut t_iter = t.confirm_struct()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?;
400                   let mut item = t_iter.next().ok_or_else(|| Error::new(#krate::error::ErrorCode::Invalid))?;
401                   if let TagType::Context(tag) = item.get_tag() {
402                       match tag {
403                           #(
404                               #tags => Ok(Self::#variant_names(#types::from_tlv(&item)?)),
405                           )*
406                           _ => Err(#krate::error::Error::new(#krate::error::ErrorCode::Invalid)),
407                       }
408                   } else {
409                       Err(#krate::error::Error::new(#krate::error::ErrorCode::TLVTypeMismatch))
410                   }
411               }
412           }
413    };
414
415    //        panic!("Expanded to {}", expanded);
416    expanded.into()
417}
418
419/// Derive FromTLV Macro
420///
421/// This macro works for structures. It will create an implementation
422/// of the FromTLV trait for that structure.  All the members of the
423/// structure, sequentially, will get Context tags starting from 0
424/// Some configurations are possible through the 'tlvargs' attributes.
425/// For example:
426///  #[tlvargs(lifetime = "'a", start = 1, datatype = "list", unordered)]
427///
428/// start: This can be used to override the default tag from which the
429///        decoding starts (Default: 0)
430/// datatype: This can be used to define whether this data structure is
431///        to be decoded as a structure or list. Possible values: list
432///        (Default: struct)
433/// lifetime: If the structure has a lifetime annotation, use this variable
434///        to indicate that. The 'impl' will then use that lifetime
435///        indicator.
436/// unordered: By default, the decoder expects that the tags are in
437///        sequentially increasing order. Set this if that is not the case.
438///
439/// Additionally, structure members can use the tagval attribute to
440/// define a specific tag to be used
441/// For example:
442///  #[argval(22)]
443///  name: u8,
444/// In the above case, the 'name' attribute will be encoded/decoded with
445/// the tag 22
446
447#[proc_macro_derive(FromTLV, attributes(tlvargs, tagval))]
448pub fn derive_fromtlv(item: TokenStream) -> TokenStream {
449    let ast = parse_macro_input!(item as DeriveInput);
450    let name = &ast.ident;
451
452    let tlvargs = parse_tlvargs(&ast);
453
454    let generics = ast.generics;
455
456    if let syn::Data::Struct(syn::DataStruct {
457        fields: syn::Fields::Named(ref fields),
458        ..
459    }) = ast.data
460    {
461        gen_fromtlv_for_struct(fields, name, tlvargs, &generics)
462    } else if let syn::Data::Enum(data_enum) = ast.data {
463        gen_fromtlv_for_enum(&data_enum, name, tlvargs, &generics)
464    } else {
465        panic!(
466            "Derive FromTLV - Only supported Struct for now {:?}",
467            ast.data
468        )
469    }
470}