1use proc_macro::TokenStream;
2
3use proc_macro2::{Span, TokenStream as TokenStream2};
4use quote::{quote, quote_spanned};
5use syn::spanned::Spanned;
6use syn::{parse_macro_input, Data, DeriveInput, Expr, Ident, Index, Lit};
7use uuid::Uuid;
8
9#[proc_macro_derive(ParseBox, attributes(box_type))]
10pub fn derive_parse_box(input: TokenStream) -> TokenStream {
11 let input = parse_macro_input!(input as DeriveInput);
12 let ident = &input.ident;
13 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
14
15 if matches!(input.data, Data::Enum(_) | Data::Union(_)) {
16 return TokenStream::from(quote! {
18 std::compile_error!("this trait can only be derived for structs");
19 });
20 }
21 let box_type = extract_box_type(&input);
22 let read_fn = derive_read_fn(&input);
23
24 TokenStream::from(quote! {
25 #[automatically_derived]
26 impl #impl_generics mp4san::parse::ParseBox for #ident #ty_generics #where_clause {
27 fn box_type() -> mp4san::parse::BoxType {
28 #box_type
29 }
30
31 #read_fn
32 }
33 })
34}
35
36#[proc_macro_derive(ParsedBox, attributes(box_type))]
37pub fn derive_parsed_box(input: TokenStream) -> TokenStream {
38 let input = parse_macro_input!(input as DeriveInput);
39 let ident = &input.ident;
40 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
41
42 if matches!(input.data, Data::Enum(_) | Data::Union(_)) {
43 return TokenStream::from(quote! {
45 std::compile_error!("this trait can only be derived for structs");
46 });
47 }
48 let size = sum_box_size(&input);
49 let write_fn = derive_write_fn(&input);
50
51 TokenStream::from(quote! {
52 #[automatically_derived]
53 impl #impl_generics mp4san::parse::ParsedBox for #ident #ty_generics #where_clause {
54 fn encoded_len(&self) -> u64 {
55 #size
56 }
57
58 #write_fn
59 }
60 })
61}
62
63fn derive_write_fn(input: &DeriveInput) -> TokenStream2 {
64 let write_fields = match &input.data {
65 Data::Struct(struct_data) => {
66 let place_expr = struct_data.fields.iter().enumerate().map(|(index, field)| {
67 if let Some(ident) = &field.ident {
68 quote_spanned! { field.span() => self.#ident }
69 } else {
70 let tuple_index = Index::from(index);
71 quote_spanned! { field.span() => self.#tuple_index }
72 }
73 });
74 quote! { #( mp4san::parse::Mp4Value::put_buf(&#place_expr, &mut *out); )* }
75 }
76 _ => unreachable!(),
77 };
78 quote! {
79 fn put_buf(&self, out: &mut dyn bytes::BufMut) {
80 #write_fields
81 }
82 }
83}
84
85fn derive_read_fn(input: &DeriveInput) -> TokenStream2 {
86 let ident = &input.ident;
87 match &input.data {
88 Data::Struct(struct_data) => {
89 let mut field_ty = Vec::new();
90 let mut field_ident = Vec::new();
91 let mut bind_ident = Vec::new();
92 for (index, field) in struct_data.fields.iter().enumerate() {
93 field_ty.push(field.ty.clone());
94 if let Some(ident) = &field.ident {
95 field_ident.push(quote_spanned! { field.span() => #ident });
96 bind_ident.push(ident.clone());
97 } else {
98 let tuple_index = Index::from(index);
99 field_ident.push(quote_spanned! { field.span() => #tuple_index });
100 bind_ident.push(Ident::new(&format!("field_{index}"), Span::mixed_site()));
101 }
102 }
103 quote! {
104 fn parse(buf: &mut bytes::BytesMut) -> std::result::Result<Self, mp4san::Report<mp4san::parse::ParseError>> {
105 #(
106 let #bind_ident: #field_ty =
107 mp4san::parse::error::ParseResultExt::while_parsing_field(
108 mp4san::parse::Mp4Value::parse(&mut *buf),
109 #ident::box_type(),
110 stringify!(#field_ty),
111 )?;
112 )*
113 if !buf.is_empty() {
114 return
115 mp4san::parse::error::ParseResultExt::while_parsing_box(
116 mp4san::error::ResultExt::attach_printable(
117 Err(mp4san::parse::ParseError::InvalidInput.into()),
118 "extra unparsed data",
119 ),
120 #ident::box_type(),
121 );
122 }
123 std::result::Result::Ok(#ident { #( #field_ident: #bind_ident ),* })
124 }
125 }
126 }
127 _ => unreachable!(),
128 }
129}
130
131fn extract_box_type(input: &DeriveInput) -> TokenStream2 {
132 let mut iter = input.attrs.iter().filter(|attr| attr.path().is_ident("box_type"));
133 let Some(attr) = iter.next() else {
134 return quote! { std::compile_error!("missing `#[box_type]` attribute") };
138 };
139 if let Some(extra_attr) = iter.next() {
140 return quote_spanned! { extra_attr.span() =>
141 std::compile_error!("more than one `#[box_type]` attribute is not allowed")
142 };
143 }
144 let lit = match attr.meta.require_name_value().map(|name_value| &name_value.value).ok() {
145 Some(Expr::Lit(lit)) => &lit.lit,
146 _ => {
147 return quote_spanned! { attr.span() =>
148 std::compile_error!("`box_type` attribute must be of the form `#[box_type = ...]`")
149 }
150 }
151 };
152 match &lit {
153 Lit::Int(int_lit) => {
154 let int = match int_lit.base10_parse::<u128>() {
155 Ok(int) => int,
156 Err(error) => return error.into_compile_error(),
157 };
158 if let Ok(int) = u32::try_from(int) {
159 return quote! { mp4san::parse::BoxType::FourCC(mp4san::parse::FourCC { value: #int.to_be_bytes() }) };
160 } else {
161 return quote! { mp4san::parse::BoxType::Uuid(mp4san::parse::BoxUuid { value: #int.to_be_bytes() }) };
162 }
163 }
164 Lit::Str(string_lit) => {
165 let string = string_lit.value();
166 if let Ok(uuid) = Uuid::parse_str(&string) {
167 let int = uuid.as_u128();
168 return quote! { mp4san::parse::BoxType::Uuid(mp4san::parse::BoxUuid { value: #int.to_be_bytes() }) };
169 } else if string.len() == 4 {
170 return quote! {
171 let type_string = #string_lit;
172 let type_ = std::convert::TryInto::try_into(type_string.as_bytes()).unwrap();
173 mp4san::parse::BoxType::FourCC(mp4san::parse::FourCC { value: type_ })
174 };
175 }
176 }
177 Lit::ByteStr(bytes_lit) => {
178 let bytes = bytes_lit.value();
179 if bytes.len() == 4 {
180 return quote! {
181 mp4san::parse::BoxType::FourCC(mp4san::parse::FourCC { value: *#bytes_lit })
182 };
183 }
184 }
185 _ => {}
186 }
187 quote_spanned! { lit.span() => std::compile_error!(concat!(
188 r#"malformed `box_type` attribute input: try `"moov"`, `b"moov"`, or `0x6d6f6f76` for a"#,
189 r#" compact type, or `"a7b5465c-7eac-4caa-b744-bdc340127d37"` or"#,
190 r#" `0xa7b5465c_7eac_4caa_b744_bdc340127d37` for an extended type"#,
191 )) }
192}
193
194fn sum_box_size(derive_input: &DeriveInput) -> TokenStream2 {
195 let sum_expr = match &derive_input.data {
196 Data::Struct(struct_data) => {
197 let sum_expr = struct_data.fields.iter().enumerate().map(|(index, field)| {
198 if let Some(ident) = &field.ident {
199 quote_spanned! { field.span() => mp4san::parse::Mp4Value::encoded_len(&self.#ident) }
200 } else {
201 let tuple_index = Index::from(index);
202 quote_spanned! { field.span() => mp4san::parse::Mp4Value::encoded_len(&self.#tuple_index) }
203 }
204 });
205 quote! { #(+ #sum_expr)* }
206 }
207 _ => unreachable!(),
208 };
209 quote! {
210 0 #sum_expr
211 }
212}