1use convert_case::{Case, Casing};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote, TokenStreamExt};
4use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, BareFnArg, ReturnType, Type};
5
/// Registers that carry syscall arguments one through six, in order.
///
/// This follows the x86-64 Linux `syscall` convention, which differs from the
/// ordinary function-call ABI by using `r10` instead of `rcx`: the `syscall`
/// instruction itself overwrites `rcx` with the return address.
const X64_ARG_REGS: [&str; 6] = [
    "rdi", // argument 1
    "rsi", // argument 2
    "rdx", // argument 3
    "r10", // argument 4
    "r8",  // argument 5
    "r9",  // argument 6
];
7
8#[proc_macro_attribute]
9pub fn syscall(attr: TokenStream, item: TokenStream) -> TokenStream {
10 let item_type = parse_macro_input!(item as syn::ItemType);
11
12 let bare_fn = match *item_type.ty.clone() {
13 syn::Type::BareFn(bare_fn) => bare_fn,
14 _ => {
15 panic!("Must be a fn type eg. fn(input: usize) -> usize");
16 }
17 };
18
19 let inputs = bare_fn.inputs;
20 if inputs.len() > 6 {
21 panic!("A syscall has a maximum of six arguments")
22 }
23
24 let vis = item_type.vis;
25 let fn_name = format_ident!("{}", item_type.ident.to_string().to_case(Case::Snake));
26 let return_type = bare_fn.output;
27
28 let never_return = never_returns(&return_type);
29
30 let x86_64_asm_tokens = x86_64_asm_tokens(&inputs, never_return);
31 let sys_num = proc_macro2::TokenStream::from(attr);
32
33 let tokens = if never_return {
34 quote! {
35 #[inline(always)]
36 #[cfg(target_arch = "x86_64")]
37 #vis unsafe fn #fn_name(#inputs) #return_type {
38 let rax = #sys_num;
39 #x86_64_asm_tokens
40 }
41 }
42 } else {
43 quote! {
44 #[inline(always)]
45 #[cfg(target_arch = "x86_64")]
46 #vis unsafe fn #fn_name(#inputs) #return_type {
47 let mut rax = #sys_num as _;
48 #x86_64_asm_tokens
49 rax
50 }
51 }
52 };
53
54 TokenStream::from(tokens)
55}
56
57fn never_returns(return_type: &syn::ReturnType) -> bool {
58 match &return_type {
59 ReturnType::Default => false,
60 ReturnType::Type(_, ty) => matches!(ty.as_ref(), Type::Never(_)),
61 }
62}
63
64fn x86_64_asm_tokens(
65 inputs: &Punctuated<BareFnArg, Comma>,
66 never_return: bool,
67) -> proc_macro2::TokenStream {
68 let map_fnargs_to_reg_tokens = inputs.iter().enumerate().map(|(i, e)| {
69 let register_ident = X64_ARG_REGS[i];
70 if let Some(variable_str) = &e.name {
71 let variable_ident = variable_str.0.clone();
72 quote!(in(#register_ident) #variable_ident)
73 } else {
74 panic!("BareFnArg must have a name")
75 }
76 });
77
78 let input = if never_return {
79 quote! {
80 in("rax") rax,
81 }
82 } else {
83 quote! {
84 inout("rax") rax,
85 }
86 };
87
88 let mut options = quote! { options(nostack), };
89 if never_return {
90 options.append_all(quote! { options(noreturn), });
91 }
92
93 quote! {
94 core::arch::asm!(
95 "syscall",
96 #input
97 #(#map_fnargs_to_reg_tokens),*,
98 clobber_abi("system"),
99 #options
100 );
101 }
102}