1extern crate proc_macro;
7
8use proc_macro::TokenStream;
9use proc_macro_error2::{abort, proc_macro_error};
10use proc_macro2::TokenStream as TokenStream2;
11use quote::quote;
12use syn::{DeriveInput, Meta, parse_macro_input};
13
14fn get_bits(expr_call: &syn::ExprCall) -> syn::Expr {
15 if let syn::Expr::Path(ep) = &*expr_call.func {
16 if !ep.path.is_ident("Bits") {
17 abort!(
18 expr_call,
19 "Unexpected function name in coder: {}",
20 ep.path.get_ident().unwrap()
21 );
22 }
23 if expr_call.args.len() != 1 {
24 abort!(
25 expr_call,
26 "Unexpected number of arguments for Bits() in coder: {}",
27 expr_call.args.len()
28 );
29 }
30 return expr_call.args[0].clone();
31 }
32 abort!(expr_call, "Unexpected function call in coder");
33}
34
35fn parse_single_coder(input: &syn::Expr, extra_lit: Option<&syn::ExprLit>) -> TokenStream2 {
36 match &input {
37 syn::Expr::Lit(lit) => match extra_lit {
38 None => quote! {U32::Val(#lit)},
39 Some(elit) => quote! {U32::Val(#lit + #elit)},
40 },
41 syn::Expr::Call(expr_call) => {
42 let bits = get_bits(expr_call);
43 match extra_lit {
44 None => quote! {U32::Bits(#bits)},
45 Some(elit) => quote! {U32::BitsOffset{n: #bits, off: #elit}},
46 }
47 }
48 syn::Expr::Binary(syn::ExprBinary {
49 attrs: _,
50 left,
51 op: syn::BinOp::Add(_),
52 right,
53 }) => {
54 let (left, right) = if let syn::Expr::Lit(_) = **left {
55 (right, left)
56 } else {
57 (left, right)
58 };
59 match (&**left, &**right) {
60 (syn::Expr::Call(expr_call), syn::Expr::Lit(lit)) => {
61 let bits = get_bits(expr_call);
62 match extra_lit {
63 None => quote! {U32::BitsOffset{n: #bits, off: #lit}},
64 Some(elit) => quote! {U32::BitsOffset{n: #bits, off: #lit + #elit}},
65 }
66 }
67 _ => abort!(
68 input,
69 "Unexpected expression in coder, must be Bits(a) + b, Bits(a), or b"
70 ),
71 }
72 }
73 _ => abort!(
74 input,
75 "Unexpected expression in coder, must be Bits(a) + b, Bits(a), or b"
76 ),
77 }
78}
79
80fn parse_coder(input: &syn::Expr) -> TokenStream2 {
81 let parse_u2s = |expr_call: &syn::ExprCall, lit: Option<&syn::ExprLit>| {
82 if let syn::Expr::Path(ep) = &*expr_call.func {
83 if !ep.path.is_ident("u2S") {
84 let coder = parse_single_coder(input, None);
85 return quote! {U32Coder::Direct(#coder)};
86 }
87 if expr_call.args.len() != 4 {
88 abort!(
89 input,
90 "Unexpected number of arguments for U32() in coder: {}",
91 expr_call.args.len()
92 );
93 }
94 let args = vec![
95 parse_single_coder(&expr_call.args[0], lit),
96 parse_single_coder(&expr_call.args[1], lit),
97 parse_single_coder(&expr_call.args[2], lit),
98 parse_single_coder(&expr_call.args[3], lit),
99 ];
100 return quote! {U32Coder::Select(#(#args),*)};
101 }
102 abort!(input, "Unexpected function call in coder");
103 };
104
105 match &input {
106 syn::Expr::Call(expr_call) => parse_u2s(expr_call, None),
107 syn::Expr::Binary(syn::ExprBinary {
108 attrs: _,
109 left,
110 op: syn::BinOp::Add(_),
111 right,
112 }) => {
113 let (left, right) = if let syn::Expr::Lit(_) = **left {
114 (right, left)
115 } else {
116 (left, right)
117 };
118 match (&**left, &**right) {
119 (syn::Expr::Call(expr_call), syn::Expr::Lit(lit)) => {
120 parse_u2s(expr_call, Some(lit))
121 }
122 _ => abort!(
123 input,
124 "Unexpected expression in coder, must be (u2S|Bits)(a) + b, (u2S|Bits)(a), or b"
125 ),
126 }
127 }
128 _ => parse_single_coder(input, None),
129 }
130}
131
132fn parse_size_coder(mut input: syn::Expr) -> TokenStream2 {
133 match input {
134 syn::Expr::Call(syn::ExprCall {
135 ref func,
136 ref mut args,
137 ..
138 }) => {
139 if args.len() != 1 {
140 abort!(input, "Expected 1 argument in sized_coder inner call");
141 }
142
143 match &**func {
144 syn::Expr::Path(expr_path) if expr_path.path.is_ident("implicit") => {
145 let arg = args.first().unwrap().clone();
146 parse_coder(&arg)
147 }
148 syn::Expr::Path(expr_path) if expr_path.path.is_ident("explicit") => {
149 quote! { U32Coder::Direct(U32::Val(#args)) }
150 }
151 _ => abort!(
152 input,
153 "Unexpected expression in size_coder, must be 'implicit()' or 'explicit()'"
154 ),
155 }
156 }
157 _ => abort!(
158 input,
159 "Unexpected expression in size_coder, must be 'implicit()' or 'explicit()'"
160 ),
161 }
162}
163
164fn prettify_condition(cond: &syn::Expr) -> String {
165 (quote! {#cond})
166 .to_string()
167 .replace(" . ", ".")
168 .replace("! ", "!")
169 .replace(" :: ", "::")
170}
171
/// A condition attached to a field (`#[condition(...)]`) or to a coder
/// choice (`#[select_coder(...)]`), possibly combined with the struct's
/// `all_default` field.
#[derive(Debug)]
struct Condition {
    /// The user-written condition; `None` when the condition exists only
    /// because the struct has an `all_default` field.
    expr: Option<syn::Expr>,
    /// When true, the generated condition is prefixed with `!<all_default field>`.
    has_all_default: bool,
    /// Compact rendering of `expr` (from `prettify_condition`) used in trace output.
    pretty: String,
}
178
179impl Condition {
180 fn get_expr(&self, all_default_field: &Option<syn::Ident>) -> Option<TokenStream2> {
181 if self.has_all_default {
182 let all_default = all_default_field.as_ref().unwrap();
183 match &self.expr {
184 Some(expr) => Some(quote! { !#all_default && (#expr) }),
185 None => Some(quote! { !#all_default }),
186 }
187 } else {
188 self.expr.as_ref().map(|expr| quote! {#expr})
189 }
190 }
191 fn get_pretty(&self, all_default_field: &Option<syn::Ident>) -> String {
192 if self.has_all_default {
193 let all_default = all_default_field.as_ref().unwrap();
194 let all_default = "!".to_owned() + "e! {#all_default}.to_string();
195 match &self.expr {
196 Some(_) => all_default + " && (" + &self.pretty + ")",
197 None => all_default,
198 }
199 } else {
200 self.pretty.clone()
201 }
202 }
203}
204
/// Tokens of a parsed coder configuration expression, as produced by
/// `parse_coder` or `parse_size_coder`.
#[derive(Debug, Clone)]
struct U32 {
    coder: TokenStream2,
}
209
/// How a field's value is decoded. Determines both the coder configuration
/// type (`Coder::ty`) and the configuration expression (`Coder::config`)
/// emitted in generated code.
#[derive(Debug)]
// NOTE(review): presumably allowed because `Select` embeds a whole `Condition`
// (including a `syn::Expr`) and is much larger than the other variants; these
// values are short-lived at macro-expansion time — confirm before boxing.
#[allow(clippy::large_enum_variant)]
enum Coder {
    /// No `coder`/`select_coder` attribute; the configuration is `()`.
    WithoutConfig,
    /// From `#[coder(...)]`.
    U32(U32),
    /// From `#[select_coder(cond)]` with `#[coder_true(...)]` / `#[coder_false(...)]`.
    Select(Condition, U32, U32),
    /// From `#[size_coder(...)]`: the size coder and the coder of each element.
    Vector(U32, Box<Coder>),
}
218
219impl Coder {
220 fn ty(&self) -> TokenStream2 {
221 match self {
222 Coder::WithoutConfig => quote! {()},
223 Coder::U32(..) => quote! {U32Coder},
224 Coder::Select(..) => quote! {U32Coder},
225 Coder::Vector(_, value_coder) => {
226 let value_coder_ty = value_coder.ty();
227 quote! {VectorCoder<#value_coder_ty>}
228 }
229 }
230 }
231
232 fn config(&self, all_default_field: &Option<syn::Ident>) -> TokenStream2 {
233 match self {
234 Coder::WithoutConfig => quote! { () },
235 Coder::U32(U32 { coder }) => quote! { #coder },
236 Coder::Select(condition, U32 { coder: coder_true }, U32 { coder: coder_false }) => {
237 let cnd = condition.get_expr(all_default_field).unwrap();
238 quote! {
239 if #cnd { #coder_true } else { #coder_false }
240 }
241 }
242 Coder::Vector(U32 { coder }, value_coder) => {
243 let value_coder = value_coder.config(all_default_field);
244 quote! {VectorCoder{size_coder: #coder, value_coder: #value_coder}}
245 }
246 }
247 }
248}
249
/// How a field is read, chosen in `Field::parse` from whether the field has a
/// condition and whether it has a default.
#[derive(Debug)]
enum FieldKind {
    /// No condition: read with `read_unconditional`.
    Unconditional(Coder),
    /// Condition but no default: read with `read_conditional`.
    Conditional(Condition, Coder),
    /// Condition and default: read with `read_defaulted` (or
    /// `read_defaulted_element` when `default_element` is set).
    Defaulted(Condition, Coder),
}
256
/// A struct field together with everything extracted from its attributes.
#[derive(Debug)]
struct Field {
    name: proc_macro2::Ident,
    kind: FieldKind,
    ty: syn::Type,
    /// Default value tokens: from `#[default(...)]`, `true` for `#[all_default]`,
    /// or `Vec::new()` for a `#[size_coder]` field without an explicit default.
    default: Option<TokenStream2>,
    /// Tokens from `#[default_element(...)]`; rejected together with an explicit `#[default]`.
    default_element: Option<TokenStream2>,
    /// Field initializers from `#[nonserialized(...)]`, placed inside the
    /// field's `Nonserialized` struct literal.
    nonserialized_inits: Vec<TokenStream2>,
}
266
267impl Field {
268 fn parse(f: &syn::Field, num: usize, all_default_field: &mut Option<syn::Ident>) -> Field {
269 let mut condition = None;
270 let mut default = None;
271 let mut coder = None;
272
273 let mut select_coder = None;
274 let mut coder_true = None;
275 let mut coder_false = None;
276
277 let mut is_all_default = false;
278
279 let mut size_coder = None;
280
281 let mut nonserialized = vec![];
282
283 let mut default_element = None;
284
285 for a in &f.attrs {
287 match a.path().get_ident().map(syn::Ident::to_string).as_deref() {
288 Some("coder") => {
289 if coder.is_some() {
290 abort!(f, "Repeated coder");
291 }
292 let coder_ast = a.parse_args::<syn::Expr>().unwrap();
293 coder = Some(Coder::U32(U32 {
294 coder: parse_coder(&coder_ast),
295 }));
296 }
297 Some("default") => {
298 if default.is_some() {
299 abort!(f, "Repeated default");
300 }
301 let default_expr = a.parse_args::<syn::Expr>().unwrap();
302 default = Some(quote! {#default_expr});
303 }
304 Some("default_element") => {
305 if default_element.is_some() {
306 abort!(f, "Repeated default_element")
307 }
308 let default_element_expr = a.parse_args::<syn::Expr>().unwrap();
309 default_element = Some(quote! { #default_element_expr })
310 }
311 Some("condition") => {
312 if condition.is_some() {
313 abort!(f, "Repeated condition");
314 }
315 let condition_ast = a.parse_args::<syn::Expr>().unwrap();
316 let pretty_cond = prettify_condition(&condition_ast);
317 condition = Some(Condition {
318 expr: Some(condition_ast),
319 has_all_default: all_default_field.is_some(),
320 pretty: pretty_cond,
321 });
322 }
323 Some("all_default") => {
324 if num != 0 {
325 abort!(f, "all_default is not the first field");
326 }
327 if default.is_some() {
328 abort!(f, "all_default has an implicit default");
329 }
330 is_all_default = true;
331 default = Some(quote! { true });
332 }
333 Some("select_coder") => {
334 if select_coder.is_some() {
335 abort!(f, "Repeated select_coder");
336 }
337 let condition_ast = a.parse_args::<syn::Expr>().unwrap();
338 let pretty_cond = prettify_condition(&condition_ast);
339 select_coder = Some(Condition {
340 expr: Some(condition_ast),
341 has_all_default: false,
342 pretty: pretty_cond,
343 });
344 }
345 Some("coder_false") => {
346 if coder_false.is_some() {
347 abort!(f, "Repeated coder_false");
348 }
349 let coder_ast = a.parse_args::<syn::Expr>().unwrap();
350 coder_false = Some(U32 {
351 coder: parse_coder(&coder_ast),
352 });
353 }
354 Some("coder_true") => {
355 if coder_true.is_some() {
356 abort!(f, "Repeated coder_true");
357 }
358 let coder_ast = a.parse_args::<syn::Expr>().unwrap();
359 coder_true = Some(U32 {
360 coder: parse_coder(&coder_ast),
361 });
362 }
363 Some("size_coder") => {
364 if size_coder.is_some() {
365 abort!(f, "Repeated size_coder");
366 }
367 let coder_ast = a.parse_args::<syn::Expr>().unwrap();
368 size_coder = Some(U32 {
369 coder: parse_size_coder(coder_ast),
370 });
371 }
372 Some("nonserialized") => {
373 let Meta::List(ns) = &a.meta else {
374 abort!(a, "Invalid attribute");
375 };
376 let stream = &ns.tokens;
377 nonserialized.push(quote! {#stream});
378 }
379 _ => {}
380 }
381 }
382
383 if default.is_some() && default_element.is_some() {
384 abort!(f, "default is incompatible with default_element");
385 }
386
387 if let Some(select_coder) = select_coder {
388 if coder_true.is_none() || coder_false.is_none() {
389 abort!(
390 f,
391 "Invalid field, select_coder is set but coder_true or coder_false are not"
392 )
393 }
394 if coder.is_some() {
395 abort!(f, "Invalid field, select_coder and coder are both present")
396 }
397 coder = Some(Coder::Select(
398 select_coder,
399 coder_true.unwrap(),
400 coder_false.unwrap(),
401 ))
402 }
403
404 let condition = if condition.is_some() || all_default_field.is_none() {
405 condition
406 } else {
407 Some(Condition {
408 expr: None,
409 has_all_default: true,
410 pretty: String::new(),
411 })
412 };
413
414 let mut coder = coder.unwrap_or_else(|| Coder::WithoutConfig);
416
417 if let Some(c) = size_coder {
418 if default.is_none() {
419 default = Some(quote! { Vec::new() });
420 }
421
422 coder = Coder::Vector(c, Box::new(coder))
423 }
424
425 let ident = f.ident.as_ref().unwrap();
426
427 let kind = match (condition, default.is_some()) {
428 (None, _) => FieldKind::Unconditional(coder),
429 (Some(cond), false) => FieldKind::Conditional(cond, coder),
430 (Some(cond), true) => FieldKind::Defaulted(cond, coder),
431 };
432 if is_all_default {
433 *all_default_field = Some(f.ident.as_ref().unwrap().clone());
434 }
435 Field {
436 name: ident.clone(),
437 kind,
438 ty: f.ty.clone(),
439 default,
440 default_element,
441 nonserialized_inits: nonserialized,
442 }
443 }
444
445 fn read_fun(&self, all_default_field: &Option<syn::Ident>) -> TokenStream2 {
447 let ident = &self.name;
448 let ty = &self.ty;
449 let nonserialized_inits = &self.nonserialized_inits;
450 match &self.kind {
451 FieldKind::Unconditional(coder) => {
452 let cfg_ty = coder.ty();
453 let cfg = coder.config(all_default_field);
454 let trc = quote! {
455 crate::util::tracing_wrappers::trace!("Setting {} to {:?}. total_bits_read: {}, peek: {:08b}", stringify!(#ident), #ident, br.total_bits_read(), br.peek(8));
456 };
457 quote! {
458 let #ident = {
459 let cfg = #cfg;
460 type NS = <#ty as UnconditionalCoder<#cfg_ty>>::Nonserialized;
461 let nonserialized = NS { #(#nonserialized_inits),* };
462 <#ty>::read_unconditional(&cfg, br, &nonserialized)?
463 };
464 #trc
465 }
466 }
467 FieldKind::Conditional(condition, coder) => {
468 let cfg_ty = coder.ty();
469 let cfg = coder.config(all_default_field);
470 let cnd = condition.get_expr(all_default_field).unwrap();
471 let pretty_cnd = condition.get_pretty(all_default_field);
472 let trc = quote! {
473 crate::util::tracing_wrappers::trace!("{} is {}, setting {} to {:?}. total_bits_read: {}, peek {:08b}", #pretty_cnd, #cnd, stringify!(#ident), #ident, br.total_bits_read(), br.peek(8));
474 };
475 quote! {
476 let #ident = {
477 let cond = #cnd;
478 let cfg = #cfg;
479 type NS = <#ty as ConditionalCoder<#cfg_ty>>::Nonserialized;
480 let nonserialized = NS { #(#nonserialized_inits),* };
481 <#ty>::read_conditional(&cfg, cond, br, &nonserialized)?
482 };
483 #trc
484 }
485 }
486 FieldKind::Defaulted(condition, coder) => {
487 let cfg_ty = coder.ty();
488 let cfg = coder.config(all_default_field);
489 let cnd = condition.get_expr(all_default_field).unwrap();
490 let pretty_cnd = condition.get_pretty(all_default_field);
491 let default = &self.default;
492 let trc = quote! {
493 crate::util::tracing_wrappers::trace!("{} is {}, setting {} to {:?}. total_bits_read: {}, peek {:08b}", #pretty_cnd, #cnd, stringify!(#ident), #ident, br.total_bits_read(), br.peek(8));
494 };
495
496 let (read_fn, default) = if let Some(def) = &self.default_element {
497 (quote! { read_defaulted_element }, Some(def))
498 } else {
499 (quote! { read_defaulted }, default.as_ref())
500 };
501
502 quote! {
503 let #ident = {
504 let cond = #cnd;
505 let cfg = #cfg;
506 type NS = <#ty as DefaultedCoder<#cfg_ty>>::Nonserialized;
507 let field_nonserialized = NS { #(#nonserialized_inits),* };
508 let default = #default;
509 <#ty>::#read_fn(&cfg, cond, default, br, &field_nonserialized)?
510 };
511 #trc
512 }
513 }
514 }
515 }
516
517 fn default_code(&self) -> TokenStream2 {
519 let ident = &self.name;
520 let ty = &self.ty;
521 let nonserialized_inits = &self.nonserialized_inits;
522 let default = &self.default;
523 match &self.kind {
524 FieldKind::Defaulted(_, coder) => {
525 let cfg_ty = coder.ty();
526 let default = &self.default;
527
528 quote! {
529 let #ident = {
530 type NS = <#ty as DefaultedCoder<#cfg_ty>>::Nonserialized;
531 let field_nonserialized = NS { #(#nonserialized_inits),* };
532 #default
533 };
534 }
535 }
536 _ => quote! { let #ident = #default; },
537 }
538 }
539}
540
541fn derive_struct(input: &DeriveInput) -> TokenStream2 {
542 let name = &input.ident;
543
544 let validate = input.attrs.iter().any(|a| a.path().is_ident("validate"));
545 let nonserialized: Vec<_> = input
546 .attrs
547 .iter()
548 .filter_map(|a| {
549 if a.path().is_ident("nonserialized") {
550 Some(a.parse_args::<syn::Expr>().unwrap())
551 } else {
552 None
553 }
554 })
555 .collect();
556 if nonserialized.len() > 1 {
557 abort!(input, "repeated nonserialized");
558 }
559 let nonserialized = if nonserialized.is_empty() {
560 quote! {Empty}
561 } else {
562 let v = &nonserialized[0];
563 quote! {#v}
564 };
565
566 let data = if let syn::Data::Struct(struct_data) = &input.data {
567 struct_data
568 } else {
569 abort!(input, "derive_struct didn't get a struct");
570 };
571
572 let fields = if let syn::Fields::Named(syn::FieldsNamed {
573 brace_token: _,
574 named,
575 }) = &data.fields
576 {
577 named
578 } else {
579 abort!(data.fields, "only named fields are supported (for now?)");
580 };
581
582 let mut all_default_field = None;
583
584 let fields: Vec<_> = fields
585 .iter()
586 .enumerate()
587 .map(|(n, f)| Field::parse(f, n, &mut all_default_field))
588 .collect();
589 let fields_read = fields.iter().map(|x| x.read_fun(&all_default_field));
590 let fields_names = fields.iter().map(|x| &x.name);
591
592 let impl_default = if fields.iter().all(|x| x.default.is_some()) {
593 let field_init = fields.iter().map(Field::default_code);
594 let struct_init = fields.iter().map(|f| {
595 let ident = &f.name;
596 quote! { #ident }
597 });
598 quote! {
599 impl #name {
600 pub fn default(nonserialized: &#nonserialized) -> #name {
601 #(#field_init)*
602 #name {
603 #(#struct_init),*
604 }
605 }
606 }
607
608 }
609 } else {
610 quote! {}
611 };
612
613 let impl_validate = if validate {
614 quote! { return_value.check(nonserialized)?; }
615 } else {
616 quote! {}
617 };
618
619 let align = match input.attrs.iter().any(|a| a.path().is_ident("aligned")) {
620 true => quote! { br.jump_to_byte_boundary()?; },
621 false => quote! {},
622 };
623
624 quote! {
625 #impl_default
626 impl crate::headers::encodings::UnconditionalCoder<()> for #name {
627 type Nonserialized = #nonserialized;
628 #[cold]
629 #[inline(never)]
630 fn read_unconditional(_: &(), br: &mut BitReader, nonserialized: &Self::Nonserialized) -> Result<#name, Error> {
631 use crate::headers::encodings::UnconditionalCoder;
632 use crate::headers::encodings::ConditionalCoder;
633 use crate::headers::encodings::DefaultedCoder;
634 use crate::headers::encodings::DefaultedElementCoder;
635 #align
636 #(#fields_read)*
637 let return_value = #name {
638 #(#fields_names),*
639 };
640 #impl_validate
641 Ok(return_value)
642 }
643 }
644 }
645}
646
647fn derive_enum(input: &DeriveInput) -> TokenStream2 {
648 let name = &input.ident;
649 quote! {
650 impl crate::headers::encodings::UnconditionalCoder<U32Coder> for #name {
651 type Nonserialized = Empty;
652 fn read_unconditional(config: &U32Coder, br: &mut BitReader, _: &Empty) -> Result<#name, Error> {
653 use num_traits::FromPrimitive;
654 let u = u32::read_unconditional(config, br, &Empty{})?;
655 if let Some(e) = #name::from_u32(u) {
656 Ok(e)
657 } else {
658 Err(Error::InvalidEnum(u, stringify!(#name).to_string()))
659 }
660 }
661 }
662 impl crate::headers::encodings::UnconditionalCoder<()> for #name {
663 type Nonserialized = Empty;
664 fn read_unconditional(config: &(), br: &mut BitReader, nonserialized: &Empty) -> Result<#name, Error> {
665 #name::read_unconditional(
666 &U32Coder::Select(
667 U32::Val(0), U32::Val(1),
668 U32::BitsOffset{n: 4, off: 2},
669 U32::BitsOffset{n: 6, off: 18}), br, nonserialized)
670 }
671 }
672 }
673}
674
675#[proc_macro_error]
676#[proc_macro_derive(
677 UnconditionalCoder,
678 attributes(
679 coder,
680 condition,
681 default,
682 default_element,
683 all_default,
684 select_coder,
685 coder_true,
686 coder_false,
687 validate,
688 size_coder,
689 nonserialized,
690 aligned,
691 )
692)]
693pub fn derive_jxl_headers(input: TokenStream) -> TokenStream {
694 let input = parse_macro_input!(input as DeriveInput);
695
696 match &input.data {
697 syn::Data::Struct(_) => derive_struct(&input).into(),
698 syn::Data::Enum(_) => derive_enum(&input).into(),
699 _ => abort!(input, "Only implemented for struct"),
700 }
701}
702
/// Attribute macro that returns the annotated item unchanged and ignores its
/// arguments.
#[proc_macro_attribute]
pub fn noop(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
707
/// Expands to one `#[test]` function per `.jxl` file found directly inside
/// `jxl/resources/test` and its `conformance_test_images` subdirectory
/// (resolved relative to this crate's manifest directory).
///
/// Each generated test is named `<fn_name>_<file stem>` and calls
/// `fn_name(&Path::new(<path>)).unwrap()`, so the call site must have `Path`
/// in scope. Characters that are not valid in an identifier are replaced by
/// `_`. Files are visited in sorted order so the expansion is deterministic.
#[cfg(feature = "test")]
#[proc_macro]
pub fn for_each_test_file(input: TokenStream) -> TokenStream {
    use std::{fs, path::Path};
    use syn::Ident;

    let fn_name = parse_macro_input!(input as Ident);
    let root_test_dir = Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("..")
        .join("jxl")
        .join("resources")
        .join("test");
    let conformance_test_dir = root_test_dir.join("conformance_test_images");

    let mut tests = vec![];

    for test_dir in [root_test_dir, conformance_test_dir] {
        let mut paths: Vec<_> = fs::read_dir(&test_dir)
            .unwrap_or_else(|e| {
                panic!("cannot read test directory {}: {e}", test_dir.display())
            })
            .map(|entry| {
                entry
                    .unwrap_or_else(|e| {
                        panic!("cannot read entry of {}: {e}", test_dir.display())
                    })
                    .path()
            })
            .filter(|path| path.extension().is_some_and(|ext| ext == "jxl"))
            .collect();
        // `read_dir` order is platform-dependent; sort so the expansion is stable.
        paths.sort();

        for path in paths {
            let pathname = path.to_string_lossy();
            let relative_path = path
                .strip_prefix(&test_dir)
                .expect("read_dir entries live inside the directory being read")
                .to_string_lossy()
                .replace('/', "_slash_");
            let stem = relative_path
                .strip_suffix(".jxl")
                .expect("only files with a `jxl` extension are kept");
            // `Ident::new` panics on characters such as `-` or `.`, so map
            // everything outside `[A-Za-z0-9_]` to `_`.
            let stem: String = stem
                .chars()
                .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
                .collect();
            let test_name = Ident::new(&format!("{fn_name}_{stem}"), fn_name.span());
            tests.push(quote! {
                #[test]
                fn #test_name() {
                    #fn_name(&Path::new(#pathname)).unwrap()
                }
            });
        }
    }

    quote! {
        #(#tests)*
    }
    .into()
}