1use proc_macro::TokenStream;
7use quote::quote;
8use syn::parse::Parser;
9use syn::punctuated::Punctuated;
10use syn::{parse_macro_input, Attribute, DeriveInput, Meta, Token};
11
12fn has_flag(attrs: &[Attribute], name: &str) -> bool {
13 for attr in attrs {
14 if !attr.path().is_ident("beam") {
15 continue;
16 }
17 if let Meta::List(list) = &attr.meta {
18 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
20 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
21 for meta in metas {
22 if let Meta::Path(path) = meta {
23 if path.is_ident(name) {
24 return true;
25 }
26 }
27 }
28 }
29 }
30 }
31 false
32}
33
34fn get_version_value(attrs: &[Attribute]) -> Option<syn::Ident> {
35 for attr in attrs {
36 if !attr.path().is_ident("beam") {
37 continue;
38 }
39 if let Meta::List(list) = &attr.meta {
40 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
41 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
42 for meta in metas {
43 if let Meta::NameValue(nv) = meta {
44 if nv.path.is_ident("min_version") {
45 if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. }) = &nv.value {
46 return Some(syn::Ident::new(&lit_str.value(), lit_str.span()));
47 }
48 }
49 }
50 }
51 }
52 }
53 }
54 None
55}
56
57fn get_profile_value(attrs: &[Attribute]) -> Option<u8> {
58 for attr in attrs {
59 if !attr.path().is_ident("beam") {
60 continue;
61 }
62 if let Meta::List(list) = &attr.meta {
63 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
64 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
65 for meta in metas {
66 if let Meta::NameValue(nv) = meta {
67 if nv.path.is_ident("profile") {
68 if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) = &nv.value {
69 if let Ok(profile) = lit_int.base10_parse::<u8>() {
70 return Some(profile);
71 }
72 }
73 }
74 }
75 }
76 }
77 }
78 }
79 None
80}
81
82fn get_profile_type(attrs: &[Attribute]) -> Option<syn::Type> {
83 for attr in attrs {
84 if !attr.path().is_ident("beam") {
85 continue;
86 }
87 if let Meta::List(list) = &attr.meta {
88 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
89 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
90 for meta in metas {
91 if let Meta::List(profile_list) = meta {
92 if profile_list.path.is_ident("profile") {
93 if let Ok(ty) = syn::parse2::<syn::Type>(profile_list.tokens.clone()) {
95 return Some(ty);
96 }
97 }
98 }
99 }
100 }
101 }
102 }
103 None
104}
105
106fn has_attr(attrs: &[Attribute], name: &str) -> bool {
107 attrs.iter().any(|attr| attr.path().is_ident(name))
108}
109
110fn get_error_message(attrs: &[Attribute]) -> Option<String> {
111 for attr in attrs {
112 if attr.path().is_ident("error") {
113 if let Meta::List(list) = &attr.meta {
114 if let Ok(lit_str) = syn::parse2::<syn::LitStr>(list.tokens.clone()) {
115 return Some(lit_str.value());
116 }
117 }
118 }
119 }
120 None
121}
122
123#[proc_macro_derive(Beamable, attributes(beam))]
128pub fn derive_beamable(input: TokenStream) -> TokenStream {
129 let input = parse_macro_input!(input as DeriveInput);
130 let name = &input.ident;
131
132 let confidential = has_flag(&input.attrs, "confidential");
133 let nonrep = has_flag(&input.attrs, "nonrepudiable");
134 let compressed = has_flag(&input.attrs, "compressed");
135 let prioritized = has_flag(&input.attrs, "prioritized");
136 let message_integrity = has_flag(&input.attrs, "message_integrity");
137 let frame_integrity = has_flag(&input.attrs, "frame_integrity");
138 let min_version = get_version_value(&input.attrs);
139 let profile_value = get_profile_value(&input.attrs);
140 let profile_type = get_profile_type(&input.attrs);
141
142 if profile_value.is_some() && profile_type.is_some() {
144 return syn::Error::new_spanned(
145 &input,
146 "Cannot specify both numeric profile (= N) and type-based profile (Type) simultaneously",
147 )
148 .to_compile_error()
149 .into();
150 }
151
152 let (profile_confidential, profile_nonrep, profile_min_version) = match profile_value {
154 Some(1) => (true, true, Some(syn::Ident::new("V1", name.span()))), Some(2) => (true, true, Some(syn::Ident::new("V1", name.span()))), Some(p) if p > 2 => (false, false, None),
157 _ => (false, false, None),
158 };
159
160 let final_confidential = profile_confidential || confidential;
162 let final_nonrep = profile_nonrep || nonrep;
163 let final_min_version = profile_min_version.or(min_version);
164 let final_message_integrity = message_integrity;
165 let final_frame_integrity = frame_integrity;
166
167 let mut feature_checks = Vec::new();
168
169 if final_confidential && !cfg!(feature = "aead") {
170 feature_checks.push(quote! {
171 compile_error!(concat!(
172 "Message type `", stringify!(#name), "` is marked as confidential ",
173 "but the `aead` feature is not enabled. ",
174 "Enable the feature in Cargo.toml: features = [\"aead\"]"
175 ));
176 });
177 }
178
179 if final_nonrep && !cfg!(feature = "signature") {
180 feature_checks.push(quote! {
181 compile_error!(concat!(
182 "Message type `", stringify!(#name), "` is marked as non-repudiable ",
183 "but the `signature` feature is not enabled. ",
184 "Enable the feature in Cargo.toml: features = [\"signature\"]"
185 ));
186 });
187 }
188
189 if compressed && !cfg!(feature = "compress") {
190 feature_checks.push(quote! {
191 compile_error!(concat!(
192 "Message type `", stringify!(#name), "` is marked as compressed ",
193 "but the `compress` feature is not enabled. ",
194 "Enable the feature in Cargo.toml: features = [\"compress\"]"
195 ));
196 });
197 }
198
199 if (final_message_integrity || final_frame_integrity) && !cfg!(feature = "digest") {
200 feature_checks.push(quote! {
201 compile_error!(concat!(
202 "Message type `", stringify!(#name), "` is marked as requiring message integrity ",
203 "but the `digest` feature is not enabled. ",
204 "Enable the feature in Cargo.toml: features = [\"digest\"]"
205 ));
206 });
207 }
208
209 let min_version_value = if let Some(version) = final_min_version {
210 quote! { ::tightbeam::Version::#version }
211 } else {
212 quote! { ::tightbeam::Version::V0 }
213 };
214
215 let _has_profile = profile_type.is_some();
216 let profile_type_impl = if let Some(profile_ty) = &profile_type {
220 quote! {
221 const HAS_PROFILE: bool = true;
222 ::tightbeam::__tb_if_crypto! { type Profile = #profile_ty; }
223 }
224 } else {
225 quote! {
227 const HAS_PROFILE: bool = false;
228 ::tightbeam::__tb_if_crypto! { type Profile = ::tightbeam::crypto::profiles::TightbeamProfile; }
229 }
230 };
231
232 let oid_validation_helpers = if let Some(profile_ty) = &profile_type {
237 quote! {
240 ::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_digest! {
241 impl ::tightbeam::builder::private::SealedDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
242 where
243 #name: ::tightbeam::Message,
244 {}
245
246 impl ::tightbeam::builder::CheckDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
247 where
248 #name: ::tightbeam::Message,
249 {
250 const RESULT: () = ();
251 }
252 } }
253
254 ::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_aead! {
255 impl ::tightbeam::builder::private::SealedAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
256 where
257 #name: ::tightbeam::Message,
258 {}
259
260 impl ::tightbeam::builder::CheckAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
261 where
262 #name: ::tightbeam::Message,
263 {
264 const RESULT: () = ();
265 }
266 } }
267
268 ::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_signature! {
269 impl ::tightbeam::builder::private::SealedSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
270 where
271 #name: ::tightbeam::Message,
272 {}
273
274 impl ::tightbeam::builder::CheckSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
275 where
276 #name: ::tightbeam::Message,
277 {
278 const RESULT: () = ();
279 }
280 } }
281 }
282 } else {
283 quote! {
286 ::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_digest! {
287 impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedDigestOid<D> for #name
288 where
289 #name: ::tightbeam::Message,
290 {}
291
292 impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckDigestOid<D> for #name
293 where
294 #name: ::tightbeam::Message,
295 {
296 const RESULT: () = ();
297 }
298 } }
299
300 ::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_aead! {
301 impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedAeadOid<C> for #name
302 where
303 #name: ::tightbeam::Message,
304 {}
305
306 impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckAeadOid<C> for #name
307 where
308 #name: ::tightbeam::Message,
309 {
310 const RESULT: () = ();
311 }
312 } }
313
314 ::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_signature! {
315 impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::private::SealedSignatureOid<S> for #name
316 where
317 #name: ::tightbeam::Message,
318 {}
319
320 impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::CheckSignatureOid<S> for #name
321 where
322 #name: ::tightbeam::Message,
323 {
324 const RESULT: () = ();
325 }
326 } }
327 }
328 };
329
330 let expanded = quote! {
331 const _: () = {
332 #(#feature_checks)*
333 };
334
335 impl ::tightbeam::Message for #name {
336 const MUST_BE_CONFIDENTIAL: bool = #final_confidential;
337 const MUST_BE_NON_REPUDIABLE: bool = #final_nonrep;
338 const MUST_HAVE_MESSAGE_INTEGRITY: bool = #final_message_integrity;
339 const MUST_HAVE_FRAME_INTEGRITY: bool = #final_frame_integrity;
340 const MUST_BE_COMPRESSED: bool = #compressed;
341 const MUST_BE_PRIORITIZED: bool = #prioritized;
342 const MIN_VERSION: ::tightbeam::Version = #min_version_value;
343 #profile_type_impl
344 }
345
346 #oid_validation_helpers
347 };
348
349 TokenStream::from(expanded)
350}
351
352#[proc_macro_derive(Flaggable)]
357pub fn derive_flaggable(input: TokenStream) -> TokenStream {
358 let input = parse_macro_input!(input as DeriveInput);
359 let name = &input.ident;
360 let name_str = name.to_string();
361
362 let expanded = quote! {
363 impl From<#name> for u8 {
364 fn from(val: #name) -> u8 {
365 val as u8
366 }
367 }
368
369 impl PartialEq<u8> for #name {
370 fn eq(&self, other: &u8) -> bool {
371 (*self as u8) == *other
372 }
373 }
374
375 impl #name {
376 pub const TYPE_NAME: &'static str = #name_str;
377 }
378 };
379
380 TokenStream::from(expanded)
381}
382
383#[proc_macro_derive(Errorizable, attributes(error, from))]
394pub fn derive_errorizable(input: TokenStream) -> TokenStream {
395 let input = parse_macro_input!(input as DeriveInput);
396 let name = &input.ident;
397
398 let data_enum = match &input.data {
399 syn::Data::Enum(data) => data,
400 _ => {
401 return syn::Error::new_spanned(&input, "Errorizable can only be derived for enums")
402 .to_compile_error()
403 .into();
404 }
405 };
406
407 let mut display_arms = Vec::new();
408 let mut from_impls = Vec::new();
409
410 for variant in &data_enum.variants {
411 let variant_name = &variant.ident;
412
413 let error_msg = get_error_message(&variant.attrs);
415 let has_from = has_attr(&variant.attrs, "from");
416
417 match &variant.fields {
419 syn::Fields::Unnamed(fields) => {
420 let field_count = fields.unnamed.len();
421 let field_bindings: Vec<_> = (0..field_count)
422 .map(|i| syn::Ident::new(&format!("f{i}"), variant_name.span()))
423 .collect();
424
425 if let Some(msg) = error_msg {
426 if msg.contains("{expected") || msg.contains("{received") {
428 display_arms.push(quote! {
430 #name::#variant_name(ref f0) => {
431 write!(f, #msg, expected = f0.expected, received = f0.received)
432 }
433 });
434 } else {
435 display_arms.push(quote! {
436 #name::#variant_name(#(ref #field_bindings),*) => {
437 write!(f, #msg, #(#field_bindings),*)
438 }
439 });
440 }
441 } else {
442 display_arms.push(quote! {
443 #name::#variant_name(#(ref #field_bindings),*) => {
444 write!(f, "{}", stringify!(#variant_name))
445 }
446 });
447 }
448
449 if has_from && field_count == 1 {
451 if let Some(field) = fields.unnamed.first() {
452 let field_type = &field.ty;
453 from_impls.push(quote! {
454 impl From<#field_type> for #name {
455 fn from(err: #field_type) -> Self {
456 #name::#variant_name(err)
457 }
458 }
459 });
460 }
461 }
462 }
463 syn::Fields::Named(fields) => {
464 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
465
466 if let Some(msg) = error_msg {
467 display_arms.push(quote! {
468 #name::#variant_name { #(ref #field_names),* } => {
469 write!(f, #msg, #(#field_names = #field_names),*)
470 }
471 });
472 } else {
473 display_arms.push(quote! {
474 #name::#variant_name { .. } => {
475 write!(f, "{}", stringify!(#variant_name))
476 }
477 });
478 }
479 }
480 syn::Fields::Unit => {
481 if let Some(msg) = error_msg {
482 display_arms.push(quote! {
483 #name::#variant_name => write!(f, #msg)
484 });
485 } else {
486 display_arms.push(quote! {
487 #name::#variant_name => write!(f, "{}", stringify!(#variant_name))
488 });
489 }
490 }
491 }
492 }
493
494 let expanded = quote! {
495 impl core::fmt::Display for #name {
496 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
497 match self {
498 #(#display_arms,)*
499 }
500 }
501 }
502
503 impl core::error::Error for #name {}
504
505 #(#from_impls)*
506 };
507
508 TokenStream::from(expanded)
509}