Skip to main content

tightbeam_derive/
lib.rs

//! Derive macros for TightBeam message types
//!
//! This crate provides the `#[derive(Beamable)]` macro, which implements the
//! `Message` trait for structs, along with `#[derive(Flaggable)]` for flag
//! enums and `#[derive(Errorizable)]` for error enums.
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::parse::Parser;
9use syn::punctuated::Punctuated;
10use syn::{parse_macro_input, Attribute, DeriveInput, Meta, Token};
11
12fn has_flag(attrs: &[Attribute], name: &str) -> bool {
13	for attr in attrs {
14		if !attr.path().is_ident("beam") {
15			continue;
16		}
17		if let Meta::List(list) = &attr.meta {
18			// Allow mixing identifiers and name-value pairs in #[beam(...)]
19			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
20			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
21				for meta in metas {
22					if let Meta::Path(path) = meta {
23						if path.is_ident(name) {
24							return true;
25						}
26					}
27				}
28			}
29		}
30	}
31	false
32}
33
34fn get_version_value(attrs: &[Attribute]) -> Option<syn::Ident> {
35	for attr in attrs {
36		if !attr.path().is_ident("beam") {
37			continue;
38		}
39		if let Meta::List(list) = &attr.meta {
40			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
41			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
42				for meta in metas {
43					if let Meta::NameValue(nv) = meta {
44						if nv.path.is_ident("min_version") {
45							if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. }) = &nv.value {
46								return Some(syn::Ident::new(&lit_str.value(), lit_str.span()));
47							}
48						}
49					}
50				}
51			}
52		}
53	}
54	None
55}
56
57fn get_profile_value(attrs: &[Attribute]) -> Option<u8> {
58	for attr in attrs {
59		if !attr.path().is_ident("beam") {
60			continue;
61		}
62		if let Meta::List(list) = &attr.meta {
63			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
64			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
65				for meta in metas {
66					if let Meta::NameValue(nv) = meta {
67						if nv.path.is_ident("profile") {
68							if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) = &nv.value {
69								if let Ok(profile) = lit_int.base10_parse::<u8>() {
70									return Some(profile);
71								}
72							}
73						}
74					}
75				}
76			}
77		}
78	}
79	None
80}
81
82fn get_profile_type(attrs: &[Attribute]) -> Option<syn::Type> {
83	for attr in attrs {
84		if !attr.path().is_ident("beam") {
85			continue;
86		}
87		if let Meta::List(list) = &attr.meta {
88			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
89			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
90				for meta in metas {
91					if let Meta::List(profile_list) = meta {
92						if profile_list.path.is_ident("profile") {
93							// Parse the content inside profile(...) as a type
94							if let Ok(ty) = syn::parse2::<syn::Type>(profile_list.tokens.clone()) {
95								return Some(ty);
96							}
97						}
98					}
99				}
100			}
101		}
102	}
103	None
104}
105
106fn has_attr(attrs: &[Attribute], name: &str) -> bool {
107	attrs.iter().any(|attr| attr.path().is_ident(name))
108}
109
110fn get_error_message(attrs: &[Attribute]) -> Option<String> {
111	for attr in attrs {
112		if attr.path().is_ident("error") {
113			if let Meta::List(list) = &attr.meta {
114				if let Ok(lit_str) = syn::parse2::<syn::LitStr>(list.tokens.clone()) {
115					return Some(lit_str.value());
116				}
117			}
118		}
119	}
120	None
121}
122
123/// Derive macro for implementing `Message`
124///
125/// This macro can be applied to any struct that implements the necessary
126/// serialization traits (typically `der::Sequence`).
127#[proc_macro_derive(Beamable, attributes(beam))]
128pub fn derive_beamable(input: TokenStream) -> TokenStream {
129	let input = parse_macro_input!(input as DeriveInput);
130	let name = &input.ident;
131
132	let confidential = has_flag(&input.attrs, "confidential");
133	let nonrep = has_flag(&input.attrs, "nonrepudiable");
134	let compressed = has_flag(&input.attrs, "compressed");
135	let prioritized = has_flag(&input.attrs, "prioritized");
136	let message_integrity = has_flag(&input.attrs, "message_integrity");
137	let frame_integrity = has_flag(&input.attrs, "frame_integrity");
138	let min_version = get_version_value(&input.attrs);
139	let profile_value = get_profile_value(&input.attrs);
140	let profile_type = get_profile_type(&input.attrs);
141
142	// Validate that we don't have both numeric and type-based profiles
143	if profile_value.is_some() && profile_type.is_some() {
144		return syn::Error::new_spanned(
145			&input,
146			"Cannot specify both numeric profile (= N) and type-based profile (Type) simultaneously",
147		)
148		.to_compile_error()
149		.into();
150	}
151
152	// Profile-based security requirements
153	let (profile_confidential, profile_nonrep, profile_min_version) = match profile_value {
154		Some(1) => (true, true, Some(syn::Ident::new("V1", name.span()))), // FIPS
155		Some(2) => (true, true, Some(syn::Ident::new("V1", name.span()))), // Standard
156		Some(p) if p > 2 => (false, false, None),
157		_ => (false, false, None),
158	};
159
160	// Apply profile requirements (override individual flags)
161	let final_confidential = profile_confidential || confidential;
162	let final_nonrep = profile_nonrep || nonrep;
163	let final_min_version = profile_min_version.or(min_version);
164	let final_message_integrity = message_integrity;
165	let final_frame_integrity = frame_integrity;
166
167	let mut feature_checks = Vec::new();
168
169	if final_confidential && !cfg!(feature = "aead") {
170		feature_checks.push(quote! {
171			compile_error!(concat!(
172				"Message type `", stringify!(#name), "` is marked as confidential ",
173				"but the `aead` feature is not enabled. ",
174				"Enable the feature in Cargo.toml: features = [\"aead\"]"
175			));
176		});
177	}
178
179	if final_nonrep && !cfg!(feature = "signature") {
180		feature_checks.push(quote! {
181			compile_error!(concat!(
182				"Message type `", stringify!(#name), "` is marked as non-repudiable ",
183				"but the `signature` feature is not enabled. ",
184				"Enable the feature in Cargo.toml: features = [\"signature\"]"
185			));
186		});
187	}
188
189	if compressed && !cfg!(feature = "compress") {
190		feature_checks.push(quote! {
191			compile_error!(concat!(
192				"Message type `", stringify!(#name), "` is marked as compressed ",
193				"but the `compress` feature is not enabled. ",
194				"Enable the feature in Cargo.toml: features = [\"compress\"]"
195			));
196		});
197	}
198
199	if (final_message_integrity || final_frame_integrity) && !cfg!(feature = "digest") {
200		feature_checks.push(quote! {
201			compile_error!(concat!(
202				"Message type `", stringify!(#name), "` is marked as requiring message integrity ",
203				"but the `digest` feature is not enabled. ",
204				"Enable the feature in Cargo.toml: features = [\"digest\"]"
205			));
206		});
207	}
208
209	let min_version_value = if let Some(version) = final_min_version {
210		quote! { ::tightbeam::Version::#version }
211	} else {
212		quote! { ::tightbeam::Version::V0 }
213	};
214
215	let _has_profile = profile_type.is_some();
216	// `Message::Profile` is gated behind tightbeam's `crypto` feature, so the
217	// associated-type definition is emitted through `__tb_if_crypto!`, which is
218	// resolved in tightbeam's feature context rather than the consumer's.
219	let profile_type_impl = if let Some(profile_ty) = &profile_type {
220		quote! {
221			const HAS_PROFILE: bool = true;
222			::tightbeam::__tb_if_crypto! { type Profile = #profile_ty; }
223		}
224	} else {
225		// Always define HAS_PROFILE, even when false (needed for checker trait impls)
226		quote! {
227			const HAS_PROFILE: bool = false;
228			::tightbeam::__tb_if_crypto! { type Profile = ::tightbeam::crypto::profiles::TightbeamProfile; }
229		}
230	};
231
232	// Generate checker trait implementations for compile-time OID validation
233	// When HAS_PROFILE = true: generates impls ONLY for the matching OID type from the profile (compile-time enforcement)
234	// When HAS_PROFILE = false: generates generic impls for all OID types (no enforcement, allows any)
235	// All types using #[derive(Beamable)] get these impls - types not using derive must implement manually
236	let oid_validation_helpers = if let Some(profile_ty) = &profile_type {
237		// We know the profile type, so we can reference its associated types directly
238		// ONLY implement for the exact OID types from the profile - wrong OIDs will fail to compile
239		quote! {
240			::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_digest! {
241				impl ::tightbeam::builder::private::SealedDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
242				where
243					#name: ::tightbeam::Message,
244				{}
245
246				impl ::tightbeam::builder::CheckDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
247				where
248					#name: ::tightbeam::Message,
249				{
250					const RESULT: () = ();
251				}
252			} }
253
254			::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_aead! {
255				impl ::tightbeam::builder::private::SealedAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
256				where
257					#name: ::tightbeam::Message,
258				{}
259
260				impl ::tightbeam::builder::CheckAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
261				where
262					#name: ::tightbeam::Message,
263				{
264					const RESULT: () = ();
265				}
266			} }
267
268			::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_signature! {
269				impl ::tightbeam::builder::private::SealedSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
270				where
271					#name: ::tightbeam::Message,
272				{}
273
274				impl ::tightbeam::builder::CheckSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
275				where
276					#name: ::tightbeam::Message,
277				{
278					const RESULT: () = ();
279				}
280			} }
281		}
282	} else {
283		// When HAS_PROFILE = false, generate generic impls for all OID types (no enforcement)
284		// These allow FrameBuilder methods to work for types without profiles
285		quote! {
286			::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_digest! {
287				impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedDigestOid<D> for #name
288				where
289					#name: ::tightbeam::Message,
290				{}
291
292				impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckDigestOid<D> for #name
293				where
294					#name: ::tightbeam::Message,
295				{
296					const RESULT: () = ();
297				}
298			} }
299
300			::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_aead! {
301				impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedAeadOid<C> for #name
302				where
303					#name: ::tightbeam::Message,
304				{}
305
306				impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckAeadOid<C> for #name
307				where
308					#name: ::tightbeam::Message,
309				{
310					const RESULT: () = ();
311				}
312			} }
313
314			::tightbeam::__tb_if_builder! { ::tightbeam::__tb_if_signature! {
315				impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::private::SealedSignatureOid<S> for #name
316				where
317					#name: ::tightbeam::Message,
318				{}
319
320				impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::CheckSignatureOid<S> for #name
321				where
322					#name: ::tightbeam::Message,
323				{
324					const RESULT: () = ();
325				}
326			} }
327		}
328	};
329
330	let expanded = quote! {
331		const _: () = {
332			#(#feature_checks)*
333		};
334
335		impl ::tightbeam::Message for #name {
336			const MUST_BE_CONFIDENTIAL: bool = #final_confidential;
337			const MUST_BE_NON_REPUDIABLE: bool = #final_nonrep;
338			const MUST_HAVE_MESSAGE_INTEGRITY: bool = #final_message_integrity;
339			const MUST_HAVE_FRAME_INTEGRITY: bool = #final_frame_integrity;
340			const MUST_BE_COMPRESSED: bool = #compressed;
341			const MUST_BE_PRIORITIZED: bool = #prioritized;
342			const MIN_VERSION: ::tightbeam::Version = #min_version_value;
343			#profile_type_impl
344		}
345
346		#oid_validation_helpers
347	};
348
349	TokenStream::from(expanded)
350}
351
352/// Derive macro for implementing flag enum traits
353///
354/// This macro automatically adds the necessary attributes and trait
355/// implementations for flag enums used with the TightBeam flag system.
356#[proc_macro_derive(Flaggable)]
357pub fn derive_flaggable(input: TokenStream) -> TokenStream {
358	let input = parse_macro_input!(input as DeriveInput);
359	let name = &input.ident;
360	let name_str = name.to_string();
361
362	let expanded = quote! {
363		impl From<#name> for u8 {
364			fn from(val: #name) -> u8 {
365				val as u8
366			}
367		}
368
369		impl PartialEq<u8> for #name {
370			fn eq(&self, other: &u8) -> bool {
371				(*self as u8) == *other
372			}
373		}
374
375		impl #name {
376			pub const TYPE_NAME: &'static str = #name_str;
377		}
378	};
379
380	TokenStream::from(expanded)
381}
382
383/// Derive macro for implementing error traits with automatic Display and From
384/// implementations
385///
386/// This macro automatically implements `Display`, `Error`, and `From`
387/// conversions for error enums, similar to the `snafu` crate.
388///
389/// # Attributes
390///
391/// - `#[error("format string")]` - Specifies the display format for the variant
392/// - `#[from]` - Automatically implements `From` for the wrapped type
393#[proc_macro_derive(Errorizable, attributes(error, from))]
394pub fn derive_errorizable(input: TokenStream) -> TokenStream {
395	let input = parse_macro_input!(input as DeriveInput);
396	let name = &input.ident;
397
398	let data_enum = match &input.data {
399		syn::Data::Enum(data) => data,
400		_ => {
401			return syn::Error::new_spanned(&input, "Errorizable can only be derived for enums")
402				.to_compile_error()
403				.into();
404		}
405	};
406
407	let mut display_arms = Vec::new();
408	let mut from_impls = Vec::new();
409
410	for variant in &data_enum.variants {
411		let variant_name = &variant.ident;
412
413		// Get the error message from #[error("...")] attribute
414		let error_msg = get_error_message(&variant.attrs);
415		let has_from = has_attr(&variant.attrs, "from");
416
417		// Build the display match arm based on variant fields
418		match &variant.fields {
419			syn::Fields::Unnamed(fields) => {
420				let field_count = fields.unnamed.len();
421				let field_bindings: Vec<_> = (0..field_count)
422					.map(|i| syn::Ident::new(&format!("f{i}"), variant_name.span()))
423					.collect();
424
425				if let Some(msg) = error_msg {
426					// Check if format string contains field accessors like {expected} or {received}
427					if msg.contains("{expected") || msg.contains("{received") {
428						// Assume single field with .expected and .received properties
429						display_arms.push(quote! {
430							#name::#variant_name(ref f0) => {
431								write!(f, #msg, expected = f0.expected, received = f0.received)
432							}
433						});
434					} else {
435						display_arms.push(quote! {
436							#name::#variant_name(#(ref #field_bindings),*) => {
437								write!(f, #msg, #(#field_bindings),*)
438							}
439						});
440					}
441				} else {
442					display_arms.push(quote! {
443						#name::#variant_name(#(ref #field_bindings),*) => {
444							write!(f, "{}", stringify!(#variant_name))
445						}
446					});
447				}
448
449				// Generate From impl if #[from] is present and there's exactly one field
450				if has_from && field_count == 1 {
451					if let Some(field) = fields.unnamed.first() {
452						let field_type = &field.ty;
453						from_impls.push(quote! {
454							impl From<#field_type> for #name {
455								fn from(err: #field_type) -> Self {
456									#name::#variant_name(err)
457								}
458							}
459						});
460					}
461				}
462			}
463			syn::Fields::Named(fields) => {
464				let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
465
466				if let Some(msg) = error_msg {
467					display_arms.push(quote! {
468						#name::#variant_name { #(ref #field_names),* } => {
469							write!(f, #msg, #(#field_names = #field_names),*)
470						}
471					});
472				} else {
473					display_arms.push(quote! {
474						#name::#variant_name { .. } => {
475							write!(f, "{}", stringify!(#variant_name))
476						}
477					});
478				}
479			}
480			syn::Fields::Unit => {
481				if let Some(msg) = error_msg {
482					display_arms.push(quote! {
483						#name::#variant_name => write!(f, #msg)
484					});
485				} else {
486					display_arms.push(quote! {
487						#name::#variant_name => write!(f, "{}", stringify!(#variant_name))
488					});
489				}
490			}
491		}
492	}
493
494	let expanded = quote! {
495		impl core::fmt::Display for #name {
496			fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
497				match self {
498					#(#display_arms,)*
499				}
500			}
501		}
502
503		impl core::error::Error for #name {}
504
505		#(#from_impls)*
506	};
507
508	TokenStream::from(expanded)
509}