#![recursion_limit = "128"]
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{TokenStreamExt, quote, quote_spanned};
use syn::spanned::Spanned;
use syn::{
Error, Expr, ExprLit, ExprPath, ItemFn, ItemStruct, Lit, Visibility, parse_macro_input,
parse_quote, parse_quote_spanned,
};
/// Builds a compile-error token stream (via `syn::Error::to_compile_error`)
/// whose diagnostic points at the span of `$span`.
///
/// `$span` is anything with a `.span()` method; in this file that comes from
/// `syn::spanned::Spanned`, which must be in scope at the call site.
///
/// - `err!(span, message)` uses `message` as-is (a trailing comma is allowed).
/// - `err!(span, fmt, args...)` builds the message with `format!(fmt, args...)`.
macro_rules! err {
    // Plain message, optional trailing comma.
    ($span:expr, $message:expr $(,)?) => {
        Error::new($span.span(), $message).to_compile_error()
    };
    // Message is a `format!` template followed by its arguments.
    ($span:expr, $message:expr, $($args:expr),*) => {
        Error::new($span.span(), format!($message, $($args),*)).to_compile_error()
    };
}
/// Attribute macro that implements `::uefi::Identify` and
/// `::uefi::proto::Protocol` for the annotated struct.
///
/// The attribute argument supplies the value of `Identify::GUID`:
/// - a string literal, forwarded to `::uefi::guid!`, or
/// - a path, used verbatim (expected to name a `::uefi::Guid` constant).
///
/// Any other expression is replaced by a compile error in the position of the
/// GUID value. The struct's generics and where-clause are carried over to
/// both generated impls.
#[proc_macro_attribute]
pub fn unsafe_protocol(args: TokenStream, input: TokenStream) -> TokenStream {
    let guid_arg = parse_macro_input!(args as Expr);

    // Tokens that evaluate to the GUID (or a compile error on bad input).
    let guid_tokens = match &guid_arg {
        Expr::Lit(ExprLit {
            lit: Lit::Str(text), ..
        }) => quote!(::uefi::guid!(#text)),
        Expr::Path(ExprPath { path, .. }) => quote!(#path),
        unsupported => err!(
            unsupported,
            "macro input must be either a string literal or path to a constant"
        ),
    };

    let item = parse_macro_input!(input as ItemStruct);
    let struct_name = &item.ident;
    let (impl_params, type_params, bounds) = item.generics.split_for_impl();

    // `Identify` is implemented with `unsafe impl`; the GUID given in the
    // attribute is trusted as-is (NOTE(review): presumably the reason the
    // macro carries the `unsafe_` prefix).
    let expanded = quote! {
        #item

        unsafe impl #impl_params ::uefi::Identify for #struct_name #type_params #bounds {
            const GUID: ::uefi::Guid = #guid_tokens;
        }

        impl #impl_params ::uefi::proto::Protocol for #struct_name #type_params #bounds {}
    };
    TokenStream::from(expanded)
}
#[proc_macro_attribute]
pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream {
let mut errors = TokenStream2::new();
if !args.is_empty() {
errors.append_all(err!(
TokenStream2::from(args),
"Entry attribute accepts no arguments"
));
}
let mut f = parse_macro_input!(input as ItemFn);
if let Some(ref abi) = f.sig.abi {
errors.append_all(err!(abi, "Entry function must have no ABI modifier"));
}
if let Some(asyncness) = f.sig.asyncness {
errors.append_all(err!(asyncness, "Entry function should not be async"));
}
if let Some(constness) = f.sig.constness {
errors.append_all(err!(constness, "Entry function should not be const"));
}
if !f.sig.generics.params.is_empty() {
errors.append_all(err!(
f.sig.generics.params,
"Entry function should not be generic"
));
}
if !f.sig.inputs.is_empty() {
errors.append_all(err!(f.sig.inputs, "Entry function must have no arguments"));
}
if !errors.is_empty() {
return errors.into();
}
let signature_span = f.sig.span();
let image_handle_ident = quote!(internal_image_handle);
let system_table_ident = quote!(internal_system_table);
f.sig.inputs = parse_quote_spanned!(
signature_span=>
#image_handle_ident: ::uefi::Handle,
#system_table_ident: *const ::core::ffi::c_void,
);
f.block.stmts.insert(
0,
parse_quote! {
unsafe {
::uefi::boot::set_image_handle(#image_handle_ident);
::uefi::table::set_system_table(#system_table_ident.cast());
}
},
);
f.sig.abi = Some(parse_quote_spanned!(signature_span=> extern "efiapi"));
f.vis = Visibility::Inherited;
let unsafety = &f.sig.unsafety;
let fn_ident = &f.sig.ident;
let fn_output = &f.sig.output;
let expected_args = quote!(::uefi::Handle, *const core::ffi::c_void);
let fn_type_check = quote_spanned! {signature_span=>
const _:
#unsafety extern "efiapi" fn(#expected_args) -> ::uefi::Status =
#fn_ident as #unsafety extern "efiapi" fn(#expected_args) #fn_output;
};
let result = quote! {
#fn_type_check
#[unsafe(export_name = "efi_main")]
#f
};
result.into()
}