mp4san_derive/
lib.rs

1use proc_macro::TokenStream;
2
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::{quote, quote_spanned};
5use syn::spanned::Spanned;
6use syn::{parse_macro_input, Data, DeriveInput, Expr, Ident, Index, Lit};
7use uuid::Uuid;
8
9#[proc_macro_derive(ParseBox, attributes(box_type))]
10pub fn derive_parse_box(input: TokenStream) -> TokenStream {
11    let input = parse_macro_input!(input as DeriveInput);
12    let ident = &input.ident;
13    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
14
15    if matches!(input.data, Data::Enum(_) | Data::Union(_)) {
16        // This one _does_ need a semicolon though.
17        return TokenStream::from(quote! {
18            std::compile_error!("this trait can only be derived for structs");
19        });
20    }
21    let box_type = extract_box_type(&input);
22    let read_fn = derive_read_fn(&input);
23
24    TokenStream::from(quote! {
25        #[automatically_derived]
26        impl #impl_generics mp4san::parse::ParseBox for #ident #ty_generics #where_clause {
27            fn box_type() -> mp4san::parse::BoxType {
28                #box_type
29            }
30
31            #read_fn
32        }
33    })
34}
35
36#[proc_macro_derive(ParsedBox, attributes(box_type))]
37pub fn derive_parsed_box(input: TokenStream) -> TokenStream {
38    let input = parse_macro_input!(input as DeriveInput);
39    let ident = &input.ident;
40    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
41
42    if matches!(input.data, Data::Enum(_) | Data::Union(_)) {
43        // This one _does_ need a semicolon though.
44        return TokenStream::from(quote! {
45            std::compile_error!("this trait can only be derived for structs");
46        });
47    }
48    let size = sum_box_size(&input);
49    let write_fn = derive_write_fn(&input);
50
51    TokenStream::from(quote! {
52        #[automatically_derived]
53        impl #impl_generics mp4san::parse::ParsedBox for #ident #ty_generics #where_clause {
54            fn encoded_len(&self) -> u64 {
55                #size
56            }
57
58            #write_fn
59        }
60    })
61}
62
63fn derive_write_fn(input: &DeriveInput) -> TokenStream2 {
64    let write_fields = match &input.data {
65        Data::Struct(struct_data) => {
66            let place_expr = struct_data.fields.iter().enumerate().map(|(index, field)| {
67                if let Some(ident) = &field.ident {
68                    quote_spanned! { field.span() => self.#ident }
69                } else {
70                    let tuple_index = Index::from(index);
71                    quote_spanned! { field.span() => self.#tuple_index }
72                }
73            });
74            quote! { #( mp4san::parse::Mp4Value::put_buf(&#place_expr, &mut *out); )* }
75        }
76        _ => unreachable!(),
77    };
78    quote! {
79        fn put_buf(&self, out: &mut dyn bytes::BufMut) {
80            #write_fields
81        }
82    }
83}
84
85fn derive_read_fn(input: &DeriveInput) -> TokenStream2 {
86    let ident = &input.ident;
87    match &input.data {
88        Data::Struct(struct_data) => {
89            let mut field_ty = Vec::new();
90            let mut field_ident = Vec::new();
91            let mut bind_ident = Vec::new();
92            for (index, field) in struct_data.fields.iter().enumerate() {
93                field_ty.push(field.ty.clone());
94                if let Some(ident) = &field.ident {
95                    field_ident.push(quote_spanned! { field.span() => #ident });
96                    bind_ident.push(ident.clone());
97                } else {
98                    let tuple_index = Index::from(index);
99                    field_ident.push(quote_spanned! { field.span() => #tuple_index });
100                    bind_ident.push(Ident::new(&format!("field_{index}"), Span::mixed_site()));
101                }
102            }
103            quote! {
104                fn parse(buf: &mut bytes::BytesMut) -> std::result::Result<Self, mp4san::Report<mp4san::parse::ParseError>> {
105                    #(
106                        let #bind_ident: #field_ty =
107                            mp4san::parse::error::ParseResultExt::while_parsing_field(
108                                mp4san::parse::Mp4Value::parse(&mut *buf),
109                                #ident::box_type(),
110                                stringify!(#field_ty),
111                            )?;
112                    )*
113                    if !buf.is_empty() {
114                        return
115                            mp4san::parse::error::ParseResultExt::while_parsing_box(
116                                mp4san::error::ResultExt::attach_printable(
117                                    Err(mp4san::parse::ParseError::InvalidInput.into()),
118                                    "extra unparsed data",
119                                ),
120                                #ident::box_type(),
121                            );
122                    }
123                    std::result::Result::Ok(#ident { #( #field_ident: #bind_ident ),* })
124                }
125            }
126        }
127        _ => unreachable!(),
128    }
129}
130
131fn extract_box_type(input: &DeriveInput) -> TokenStream2 {
132    let mut iter = input.attrs.iter().filter(|attr| attr.path().is_ident("box_type"));
133    let Some(attr) = iter.next() else {
134        // When emitting compiler errors, no semicolon should be placed after `compile_error!()`:
135        // doing so will generate extraneous errors (type mismatch errors, Rust parse errors, or the
136        // like) in addition to the error we intend to emit.
137        return quote! { std::compile_error!("missing `#[box_type]` attribute") };
138    };
139    if let Some(extra_attr) = iter.next() {
140        return quote_spanned! { extra_attr.span() =>
141            std::compile_error!("more than one `#[box_type]` attribute is not allowed")
142        };
143    }
144    let lit = match attr.meta.require_name_value().map(|name_value| &name_value.value).ok() {
145        Some(Expr::Lit(lit)) => &lit.lit,
146        _ => {
147            return quote_spanned! { attr.span() =>
148                std::compile_error!("`box_type` attribute must be of the form `#[box_type = ...]`")
149            }
150        }
151    };
152    match &lit {
153        Lit::Int(int_lit) => {
154            let int = match int_lit.base10_parse::<u128>() {
155                Ok(int) => int,
156                Err(error) => return error.into_compile_error(),
157            };
158            if let Ok(int) = u32::try_from(int) {
159                return quote! { mp4san::parse::BoxType::FourCC(mp4san::parse::FourCC { value: #int.to_be_bytes() }) };
160            } else {
161                return quote! { mp4san::parse::BoxType::Uuid(mp4san::parse::BoxUuid { value: #int.to_be_bytes() }) };
162            }
163        }
164        Lit::Str(string_lit) => {
165            let string = string_lit.value();
166            if let Ok(uuid) = Uuid::parse_str(&string) {
167                let int = uuid.as_u128();
168                return quote! { mp4san::parse::BoxType::Uuid(mp4san::parse::BoxUuid { value: #int.to_be_bytes() }) };
169            } else if string.len() == 4 {
170                return quote! {
171                    let type_string = #string_lit;
172                    let type_ = std::convert::TryInto::try_into(type_string.as_bytes()).unwrap();
173                    mp4san::parse::BoxType::FourCC(mp4san::parse::FourCC { value: type_ })
174                };
175            }
176        }
177        Lit::ByteStr(bytes_lit) => {
178            let bytes = bytes_lit.value();
179            if bytes.len() == 4 {
180                return quote! {
181                    mp4san::parse::BoxType::FourCC(mp4san::parse::FourCC { value: *#bytes_lit })
182                };
183            }
184        }
185        _ => {}
186    }
187    quote_spanned! { lit.span() => std::compile_error!(concat!(
188        r#"malformed `box_type` attribute input: try `"moov"`, `b"moov"`, or `0x6d6f6f76` for a"#,
189        r#" compact type, or `"a7b5465c-7eac-4caa-b744-bdc340127d37"` or"#,
190        r#" `0xa7b5465c_7eac_4caa_b744_bdc340127d37` for an extended type"#,
191    )) }
192}
193
194fn sum_box_size(derive_input: &DeriveInput) -> TokenStream2 {
195    let sum_expr = match &derive_input.data {
196        Data::Struct(struct_data) => {
197            let sum_expr = struct_data.fields.iter().enumerate().map(|(index, field)| {
198                if let Some(ident) = &field.ident {
199                    quote_spanned! { field.span() => mp4san::parse::Mp4Value::encoded_len(&self.#ident) }
200                } else {
201                    let tuple_index = Index::from(index);
202                    quote_spanned! { field.span() => mp4san::parse::Mp4Value::encoded_len(&self.#tuple_index) }
203                }
204            });
205            quote! { #(+ #sum_expr)* }
206        }
207        _ => unreachable!(),
208    };
209    quote! {
210        0 #sum_expr
211    }
212}