riscv_rt_macros/
lib.rs

1#![deny(warnings)]
2#![allow(unknown_lints)] // reason = "required for next line"
3#![allow(clippy::manual_is_multiple_of)] // reason = "requires MSRV bump"
4
5use proc_macro::TokenStream;
6use proc_macro2::{Span, TokenStream as TokenStream2};
7use quote::quote;
8use syn::{
9    parse::{self, Parse},
10    parse_macro_input, parse_quote,
11    punctuated::Punctuated,
12    spanned::Spanned,
13    FnArg, ItemFn, LitInt, LitStr, PatType, Path, ReturnType, Token, Type, Visibility,
14};
15
16/// Attribute to declare the entry point of the program
17///
18/// **IMPORTANT**: This attribute must appear exactly *once* in the dependency graph. Also, if you
19/// are using Rust 1.30 the attribute must be used on a reachable item (i.e. there must be no
20/// private modules between the item and the root of the crate); if the item is in the root of the
21/// crate you'll be fine. This reachability restriction doesn't apply to Rust 1.31 and newer releases.
22///
23/// The specified function will be called by the reset handler *after* RAM has been initialized.
24/// If present, the FPU will also be enabled before the function is called.
25///
26/// The type of the specified function must be `[unsafe] fn() -> !` (never ending function)
27///
28/// # Properties
29///
30/// The entry point will be called by the reset handler. The program can't reference to the entry
31/// point, much less invoke it.
32///
33/// # Examples
34///
35/// - Simple entry point
36///
37/// ``` no_run
38/// # #![no_main]
39/// # use riscv_rt_macros::entry;
40/// #[entry]
41/// fn main() -> ! {
42///     loop {
43///         /* .. */
44///     }
45/// }
46/// ```
47#[proc_macro_attribute]
48pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream {
49    let f = parse_macro_input!(input as ItemFn);
50
51    #[cfg(not(feature = "u-boot"))]
52    let arguments_limit = 3;
53    #[cfg(feature = "u-boot")]
54    let arguments_limit = 2;
55
56    // check the function arguments
57    if f.sig.inputs.len() > arguments_limit {
58        return parse::Error::new(
59            f.sig.inputs.last().unwrap().span(),
60            "`#[entry]` function has too many arguments",
61        )
62        .to_compile_error()
63        .into();
64    }
65
66    #[cfg(not(feature = "u-boot"))]
67    for argument in f.sig.inputs.iter() {
68        if let Some(message) = check_argument_type(argument, "usize") {
69            return message;
70        };
71    }
72    #[cfg(feature = "u-boot")]
73    if let Some(message) = f
74        .sig
75        .inputs
76        .get(0)
77        .and_then(|argument| check_argument_type(argument, "c_int"))
78    {
79        return message;
80    }
81    #[cfg(feature = "u-boot")]
82    if let Some(message) = f
83        .sig
84        .inputs
85        .get(1)
86        .and_then(|argument| check_argument_type(argument, "*const *const c_char"))
87    {
88        return message;
89    }
90
91    // check the function signature
92    let valid_signature = f.sig.constness.is_none()
93        && f.sig.asyncness.is_none()
94        && f.vis == Visibility::Inherited
95        && f.sig.abi.is_none()
96        && f.sig.generics.params.is_empty()
97        && f.sig.generics.where_clause.is_none()
98        && f.sig.variadic.is_none()
99        && match f.sig.output {
100            ReturnType::Default => false,
101            ReturnType::Type(_, ref ty) => matches!(**ty, Type::Never(_)),
102        };
103
104    if !valid_signature {
105        return parse::Error::new(
106            f.span(),
107            "`#[entry]` function must have signature `[unsafe] fn([arg0: usize, ...]) -> !`",
108        )
109        .to_compile_error()
110        .into();
111    }
112
113    if !args.is_empty() {
114        return parse::Error::new(Span::call_site(), "This attribute accepts no arguments")
115            .to_compile_error()
116            .into();
117    }
118
119    // XXX should we blacklist other attributes?
120    let attrs = f.attrs;
121    let unsafety = f.sig.unsafety;
122    let args = f.sig.inputs;
123    let stmts = f.block.stmts;
124
125    quote!(
126        #[allow(non_snake_case)]
127        #[export_name = "main"]
128        #(#attrs)*
129        pub #unsafety fn __risc_v_rt__main(#args) -> ! {
130            #(#stmts)*
131        }
132    )
133    .into()
134}
135
136fn strip_type_path(ty: &Type) -> Option<Type> {
137    match ty {
138        Type::Ptr(ty) => {
139            let mut ty = ty.clone();
140            ty.elem = Box::new(strip_type_path(&ty.elem)?);
141            Some(Type::Ptr(ty))
142        }
143        Type::Path(ty) => {
144            let mut ty = ty.clone();
145            let last_segment = ty.path.segments.last().unwrap().clone();
146            ty.path.segments = Punctuated::new();
147            ty.path.segments.push_value(last_segment);
148            Some(Type::Path(ty))
149        }
150        _ => None,
151    }
152}
153
154#[allow(unused)]
155fn is_correct_type(ty: &Type, name: &str) -> bool {
156    let correct: Type = syn::parse_str(name).unwrap();
157    if let Some(ty) = strip_type_path(ty) {
158        ty == correct
159    } else {
160        false
161    }
162}
163
164fn check_correct_type(argument: &PatType, ty: &str) -> Option<TokenStream> {
165    let inv_type_message = format!("argument type must be {ty}");
166
167    if !is_correct_type(&argument.ty, ty) {
168        let error = parse::Error::new(argument.ty.span(), inv_type_message);
169
170        Some(error.to_compile_error().into())
171    } else {
172        None
173    }
174}
175
176fn check_argument_type(argument: &FnArg, ty: &str) -> Option<TokenStream> {
177    let argument_error = parse::Error::new(argument.span(), "invalid argument");
178    let argument_error = argument_error.to_compile_error().into();
179
180    match argument {
181        FnArg::Typed(argument) => check_correct_type(argument, ty),
182        FnArg::Receiver(_) => Some(argument_error),
183    }
184}
185
186/// Attribute to mark which function will be called at the beginning of the reset handler.
187/// You must enable the `pre_init` feature in the `riscv-rt` crate to use this macro.
188///
189/// # IMPORTANT
190///
191/// This attribute is **deprecated**, as it is not safe to run Rust code **before** the static
192/// variables are initialized. The recommended way to run code before the static variables
193/// are initialized is to use the `global_asm!` macro to define the `__pre_init` function.
194///
195/// This attribute can appear at most *once* in the dependency graph. Also, if you
196/// are using Rust 1.30 the attribute must be used on a reachable item (i.e. there must be no
197/// private modules between the item and the root of the crate); if the item is in the root of the
198/// crate you'll be fine. This reachability restriction doesn't apply to Rust 1.31 and newer
199/// releases.
200///
201/// The function must have the signature of `unsafe fn()`.
202///
203/// The function passed will be called before static variables are initialized. Any access of static
204/// variables will result in undefined behavior.
205///
206/// # Examples
207///
208/// ```
209/// # use riscv_rt_macros::pre_init;
210/// #[pre_init]
211/// unsafe fn before_main() {
212///     // do something here
213/// }
214///
215/// # fn main() {}
216/// ```
217#[deprecated(note = "Use global_asm! to define the __pre_init function instead")]
218#[proc_macro_attribute]
219pub fn pre_init(args: TokenStream, input: TokenStream) -> TokenStream {
220    let f = parse_macro_input!(input as ItemFn);
221
222    // check the function signature
223    let valid_signature = f.sig.constness.is_none()
224        && f.sig.asyncness.is_none()
225        && f.vis == Visibility::Inherited
226        && f.sig.unsafety.is_some()
227        && f.sig.abi.is_none()
228        && f.sig.inputs.is_empty()
229        && f.sig.generics.params.is_empty()
230        && f.sig.generics.where_clause.is_none()
231        && f.sig.variadic.is_none()
232        && match f.sig.output {
233            ReturnType::Default => true,
234            ReturnType::Type(_, ref ty) => match **ty {
235                Type::Tuple(ref tuple) => tuple.elems.is_empty(),
236                _ => false,
237            },
238        };
239
240    if !valid_signature {
241        return parse::Error::new(
242            f.span(),
243            "`#[pre_init]` function must have signature `unsafe fn()`",
244        )
245        .to_compile_error()
246        .into();
247    }
248
249    if !args.is_empty() {
250        return parse::Error::new(Span::call_site(), "This attribute accepts no arguments")
251            .to_compile_error()
252            .into();
253    }
254
255    // XXX should we blacklist other attributes?
256    let attrs = f.attrs;
257    let ident = f.sig.ident;
258    let block = f.block;
259
260    quote!(
261        #[export_name = "__pre_init"]
262        #(#attrs)*
263        pub unsafe fn #ident() #block
264    )
265    .into()
266}
267
268/// Attribute to mark which function will be called before jumping to the entry point.
269/// You must enable the `post-init` feature in the `riscv-rt` crate to use this macro.
270///
271/// In contrast with `__pre_init`, this function is called after the static variables
272/// are initialized, so it is safe to access them. It is also safe to run Rust code.
273///
274/// The function must have the signature of `[unsafe] fn([usize])`, where the argument
275/// corresponds to the hart ID of the current hart. This is useful for multi-hart systems
276/// to perform hart-specific initialization.
277///
278/// # IMPORTANT
279///
280/// This attribute can appear at most *once* in the dependency graph.
281///
282/// # Examples
283///
284/// ```
285/// use riscv_rt_macros::post_init;
286/// #[post_init]
287/// unsafe fn before_main(hart_id: usize) {
288///     // do something here
289/// }
290/// ```
291#[proc_macro_attribute]
292pub fn post_init(args: TokenStream, input: TokenStream) -> TokenStream {
293    let f = parse_macro_input!(input as ItemFn);
294
295    // check the function arguments
296    if f.sig.inputs.len() > 1 {
297        return parse::Error::new(
298            f.sig.inputs.last().unwrap().span(),
299            "`#[post_init]` function has too many arguments",
300        )
301        .to_compile_error()
302        .into();
303    }
304    for argument in f.sig.inputs.iter() {
305        if let Some(message) = check_argument_type(argument, "usize") {
306            return message;
307        };
308    }
309
310    // check the function signature
311    let valid_signature = f.sig.constness.is_none()
312        && f.sig.asyncness.is_none()
313        && f.vis == Visibility::Inherited
314        && f.sig.abi.is_none()
315        && f.sig.generics.params.is_empty()
316        && f.sig.generics.where_clause.is_none()
317        && f.sig.variadic.is_none()
318        && match f.sig.output {
319            ReturnType::Default => true,
320            ReturnType::Type(_, ref ty) => match **ty {
321                Type::Tuple(ref tuple) => tuple.elems.is_empty(),
322                _ => false,
323            },
324        };
325
326    if !valid_signature {
327        return parse::Error::new(
328            f.span(),
329            "`#[post_init]` function must have signature `[unsafe] fn([usize])`",
330        )
331        .to_compile_error()
332        .into();
333    }
334
335    if !args.is_empty() {
336        return parse::Error::new(Span::call_site(), "This attribute accepts no arguments")
337            .to_compile_error()
338            .into();
339    }
340
341    // XXX should we blacklist other attributes?
342    let attrs = f.attrs;
343    let ident = f.sig.ident;
344    let args = f.sig.inputs;
345    let block = f.block;
346
347    quote!(
348        #[export_name = "__post_init"]
349        #(#attrs)*
350        unsafe fn #ident(#args) #block
351    )
352    .into()
353}
354
355struct AsmLoopArgs {
356    asm_template: String,
357    count_from: usize,
358    count_to: usize,
359}
360
361impl Parse for AsmLoopArgs {
362    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
363        let template: LitStr = input.parse().unwrap();
364        _ = input.parse::<Token![,]>().unwrap();
365        let count: LitInt = input.parse().unwrap();
366        if input.parse::<Token![,]>().is_ok() {
367            let count_to: LitInt = input.parse().unwrap();
368            Ok(Self {
369                asm_template: template.value(),
370                count_from: count.base10_parse().unwrap(),
371                count_to: count_to.base10_parse().unwrap(),
372            })
373        } else {
374            Ok(Self {
375                asm_template: template.value(),
376                count_from: 0,
377                count_to: count.base10_parse().unwrap(),
378            })
379        }
380    }
381}
382
383/// Loops an asm expression n times.
384///
385/// `loop_asm!` takes 2 or 3 arguments, the first is a string literal and the rest are a number literal
386/// See [the formatting syntax documentation in `std::fmt`](../std/fmt/index.html) for details.
387///
388/// Argument 1 is an assembly expression, all "{}" in this assembly expression will be replaced with the
389/// current loop index.
390///
391/// If 2 arguments are provided, the loop will start at 0 and end at the number provided in argument 2.
392///
393/// If 3 arguments are provided, the loop will start at the number provided in argument 2 and end at
394/// the number provided in argument 3.
395///
396/// # Examples
397///
398/// ```
399/// # use riscv_rt_macros::loop_asm;
400/// unsafe {
401///     loop_asm!("fmv.w.x f{}, x0", 32); // => core::arch::asm!("fmv.w.x f0, x0") ... core::arch::asm!("fmv.w.x f31, x0")
402///     loop_asm!("fmv.w.x f{}, x0", 1, 32); // => core::arch::asm!("fmv.w.x f1, x0") ... core::arch::asm!("fmv.w.x f31, x0")
403/// }
404/// ```
405#[proc_macro]
406pub fn loop_asm(input: TokenStream) -> TokenStream {
407    let args = parse_macro_input!(input as AsmLoopArgs);
408
409    let tokens = (args.count_from..args.count_to)
410        .map(|i| {
411            let i = i.to_string();
412            let asm = args.asm_template.replace("{}", &i);
413            format!("core::arch::asm!(\"{asm}\");")
414        })
415        .collect::<Vec<String>>()
416        .join("\n");
417    tokens.parse().unwrap()
418}
419
420/// Loops a global_asm expression n times.
421///
422/// `loop_global_asm!` takes 2 or 3 arguments, the first is a string literal and the rest are a number literal
423/// See [the formatting syntax documentation in `std::fmt`](../std/fmt/index.html) for details.
424///
425/// Argument 1 is an assembly expression, all "{}" in this assembly expression will be replaced with the
426/// current loop index.
427///
428/// If 2 arguments are provided, the loop will start at 0 and end at the number provided in argument 2.
429///
430/// If 3 arguments are provided, the loop will start at the number provided in argument 2 and end at
431/// the number provided in argument 3.
432///
433/// # Examples
434///
435/// ```
436/// # use riscv_rt_macros::loop_global_asm;
437/// unsafe {
438///     loop_global_asm!("fmv.w.x f{}, x0", 32); // => core::arch::global_asm!("fmv.w.x f0, x0") ... core::arch::global_asm!("fmv.w.x f31, x0")
439///     loop_global_asm!("fmv.w.x f{}, x0", 1, 32); // => core::arch::global_asm!("fmv.w.x f1, x0") ... core::arch::global_asm!("fmv.w.x f31, x0")
440/// }
441/// ```
442#[proc_macro]
443pub fn loop_global_asm(input: TokenStream) -> TokenStream {
444    let args = parse_macro_input!(input as AsmLoopArgs);
445
446    let instructions = (args.count_from..args.count_to)
447        .map(|i| {
448            let i = i.to_string();
449            args.asm_template.replace("{}", &i)
450        })
451        .collect::<Vec<String>>()
452        .join("\n");
453
454    let res = format!("core::arch::global_asm!(\n\"{instructions}\"\n);");
455    res.parse().unwrap()
456}
457
458#[derive(Clone, Copy, Debug)]
459enum RiscvArch {
460    Rv32I,
461    Rv32E,
462    Rv64I,
463    Rv64E,
464}
465
466impl Parse for RiscvArch {
467    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
468        let ident: syn::Ident = input.parse()?;
469        match ident.to_string().as_str() {
470            "rv32i" => Ok(Self::Rv32I),
471            "rv32e" => Ok(Self::Rv32E),
472            "rv64i" => Ok(Self::Rv64I),
473            "rv64e" => Ok(Self::Rv64E),
474            _ => Err(syn::Error::new(ident.span(), "Invalid RISC-V architecture")),
475        }
476    }
477}
478
479impl RiscvArch {
480    fn try_from_env() -> Option<Self> {
481        let arch = std::env::var("RISCV_RT_BASE_ISA").ok()?;
482        match arch.as_str() {
483            "rv32i" => Some(Self::Rv32I),
484            "rv32e" => Some(Self::Rv32E),
485            "rv64i" => Some(Self::Rv64I),
486            "rv64e" => Some(Self::Rv64E),
487            _ => None,
488        }
489    }
490
491    const fn width(&self) -> usize {
492        match self {
493            Self::Rv32I | Self::Rv32E => 4,
494            Self::Rv64I | Self::Rv64E => 8,
495        }
496    }
497
498    const fn store(&self) -> &str {
499        match self {
500            Self::Rv32I | Self::Rv32E => "sw",
501            Self::Rv64I | Self::Rv64E => "sd",
502        }
503    }
504
505    const fn load(&self) -> &str {
506        match self {
507            Self::Rv32I | Self::Rv32E => "lw",
508            Self::Rv64I | Self::Rv64E => "ld",
509        }
510    }
511
512    fn trap_frame(&self) -> Vec<&str> {
513        match self {
514            Self::Rv32I | Self::Rv64I => vec![
515                "ra", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a0", "a1", "a2", "a3", "a4", "a5",
516                "a6", "a7",
517            ],
518            Self::Rv32E | Self::Rv64E => {
519                vec!["ra", "t0", "t1", "t2", "a0", "a1", "a2", "a3", "a4", "a5"]
520            }
521        }
522    }
523
524    /// Standard RISC-V ABI requires the stack to be 16-byte aligned.
525    /// However, in LLVM, for RV32E and RV64E, the stack must be 4-byte aligned
526    /// to be compatible with the implementation of ilp32e in GCC
527    ///
528    /// Related: https://llvm.org/docs/RISCVUsage.html
529    const fn byte_alignment(&self) -> usize {
530        match self {
531            Self::Rv32E | Self::Rv64E => 4,
532            _ => 16,
533        }
534    }
535}
536
537/// Generate the assembly instructions to store the trap frame.
538///
539/// The `arch` parameter is used to determine the width of the registers.
540///
541/// The `filter` function is used to filter which registers to store.
542/// This is useful to optimize the binary size in vectored interrupt mode, which divides the trap
543/// frame storage in two parts: the first part saves space in the stack and stores only the `a0` register,
544/// while the second part stores the remaining registers.
545fn store_trap<T: FnMut(&str) -> bool>(arch: RiscvArch, mut filter: T) -> String {
546    let width = arch.width();
547    let store = arch.store();
548    arch.trap_frame()
549        .iter()
550        .enumerate()
551        .filter(|(_, &reg)| !reg.starts_with('_') && filter(reg))
552        .map(|(i, reg)| format!("{store} {reg}, {i}*{width}(sp)"))
553        .collect::<Vec<_>>()
554        .join("\n    ")
555}
556
557/// Generate the assembly instructions to load the trap frame.
558/// The `arch` parameter is used to determine the width of the registers.
559fn load_trap(arch: RiscvArch) -> String {
560    let width = arch.width();
561    let load = arch.load();
562    arch.trap_frame()
563        .iter()
564        .enumerate()
565        .filter(|(_, &reg)| !reg.starts_with('_'))
566        .map(|(i, reg)| format!("{load} {reg}, {i}*{width}(sp)"))
567        .collect::<Vec<_>>()
568        .join("\n    ")
569}
570
571/// Temporary patch macro to deal with LLVM bug
572#[proc_macro]
573pub fn llvm_arch_patch(_input: TokenStream) -> TokenStream {
574    let q = if let Ok(arch) = std::env::var("RISCV_RT_LLVM_ARCH_PATCH") {
575        let patch = format!(".attribute arch,\"{arch}\"");
576        quote! { core::arch::global_asm!{#patch} }
577    } else {
578        quote!(compile_error!("RISCV_RT_LLVM_ARCH_PATCH is not set"))
579    };
580    q.into()
581}
582
583/// Generates `_default_start_trap` function in assembly.
584/// If no `_start_trap` function is defined, the linker will use this function as the default.
585///
586/// This implementation stores all registers in the trap frame and calls `_start_trap_rust`.
587/// The trap frame is allocated on the stack and deallocated after the call.
588#[proc_macro]
589pub fn default_start_trap(_input: TokenStream) -> TokenStream {
590    let arch = RiscvArch::try_from_env().unwrap();
591
592    let width = arch.width();
593    let trap_size = arch.trap_frame().len();
594    let byte_alignment = arch.byte_alignment();
595    // ensure we do not break that sp is 16-byte aligned
596    if (trap_size * width) % byte_alignment != 0 {
597        return parse::Error::new(Span::call_site(), "Trap frame size must be 16-byte aligned")
598            .to_compile_error()
599            .into();
600    }
601    let store = store_trap(arch, |_| true);
602    let load = load_trap(arch);
603
604    #[cfg(feature = "s-mode")]
605    let ret = "sret";
606    #[cfg(not(feature = "s-mode"))]
607    let ret = "mret";
608
609    let pre_default_start_trap = if cfg!(feature = "pre-default-start-trap") {
610        r#"
611        j _pre_default_start_trap
612.global _pre_default_start_trap_ret
613_pre_default_start_trap_ret:
614        "#
615    } else {
616        ""
617    };
618
619    format!(
620        r#"
621core::arch::global_asm!(
622".section .trap.start, \"ax\"
623.balign 4 /* Alignment required for xtvec */
624.global _default_start_trap
625_default_start_trap:
626    {pre_default_start_trap}
627    addi sp, sp, - {trap_size} * {width}
628    {store}
629    add a0, sp, zero
630    jal ra, _start_trap_rust
631    {load}
632    addi sp, sp, {trap_size} * {width}
633    {ret}
634");"#
635    )
636    .parse()
637    .unwrap()
638}
639
/// Generates global `_start_DefaultHandler_trap` and `_continue_interrupt_trap` functions in assembly.
/// The `_start_DefaultHandler_trap` function stores the trap frame partially (only register a0) and
/// jumps to the interrupt handler. The `_continue_interrupt_trap` function stores the trap frame
/// partially (all registers except a0), jumps to the interrupt handler, and restores the trap frame.
///
/// The base ISA is read from the `RISCV_RT_BASE_ISA` environment variable at expansion time.
/// If it is missing or unrecognized, a compile error is emitted instead of panicking.
#[cfg(feature = "v-trap")]
#[proc_macro]
pub fn vectored_interrupt_trap(_input: TokenStream) -> TokenStream {
    let arch = match RiscvArch::try_from_env() {
        Some(arch) => arch,
        None => {
            return parse::Error::new(
                Span::call_site(),
                "RISCV_RT_BASE_ISA is not set or is not one of rv32i, rv32e, rv64i, rv64e",
            )
            .to_compile_error()
            .into();
        }
    };
    let width = arch.width();
    let trap_size = arch.trap_frame().len();
    let store_start = store_trap(arch, |reg| reg == "a0");
    let store_continue = store_trap(arch, |reg| reg != "a0");
    let load = load_trap(arch);

    #[cfg(feature = "s-mode")]
    let ret = "sret";
    #[cfg(not(feature = "s-mode"))]
    let ret = "mret";

    let instructions = format!(
        r#"
core::arch::global_asm!(
".section .trap.continue, \"ax\"

.balign 4
.global _start_DefaultHandler_trap
_start_DefaultHandler_trap:
    addi sp, sp, -{trap_size} * {width} // allocate space for trap frame
    {store_start}                       // store trap partially (only register a0)
    la a0, DefaultHandler               // load interrupt handler address into a0

.global _continue_interrupt_trap
_continue_interrupt_trap:
    {store_continue}                   // store trap partially (all registers except a0)
    jalr ra, a0, 0                     // jump to corresponding interrupt handler (address stored in a0)
    {load}                             // restore trap frame
    addi sp, sp, {trap_size} * {width} // deallocate space for trap frame
    {ret}                              // return from interrupt
");"#
    );

    instructions.parse().unwrap()
}
683
684#[derive(Clone, Copy, Debug)]
685enum RiscvPacItem {
686    Exception,
687    ExternalInterrupt,
688    CoreInterrupt,
689}
690
691impl RiscvPacItem {
692    fn macro_id(&self) -> &str {
693        match self {
694            Self::Exception => "exception",
695            Self::ExternalInterrupt => "external_interrupt",
696            Self::CoreInterrupt => "core_interrupt",
697        }
698    }
699
700    fn valid_signature(&self) -> &str {
701        match self {
702            Self::Exception => "`[unsafe] [extern \"C\"] fn([&[mut] riscv_rt::TrapFrame]) [-> !]`",
703            _ => "`[unsafe] [extern \"C\"] fn() [-> !]`",
704        }
705    }
706
707    fn check_signature(&self, f: &ItemFn) -> bool {
708        let valid_args = match self {
709            Self::Exception => {
710                if f.sig.inputs.len() > 1 {
711                    return false;
712                }
713                match f.sig.inputs.first() {
714                    Some(FnArg::Typed(t)) => {
715                        let first_param_type = *t.ty.clone();
716                        let expected_types: Vec<Type> = vec![
717                            parse_quote!(&riscv_rt::TrapFrame),
718                            parse_quote!(&mut riscv_rt::TrapFrame),
719                        ];
720                        expected_types.contains(&first_param_type)
721                    }
722                    Some(_) => false,
723                    None => true,
724                }
725            }
726            _ => f.sig.inputs.is_empty(),
727        };
728
729        valid_args
730            && f.sig.constness.is_none()
731            && f.sig.asyncness.is_none()
732            && f.vis == Visibility::Inherited
733            && match &f.sig.abi {
734                None => true,
735                Some(syn::Abi {
736                    extern_token: _,
737                    name: Some(name),
738                }) if name.value() == "C" => true,
739                _ => false,
740            }
741            && f.sig.generics.params.is_empty()
742            && f.sig.generics.where_clause.is_none()
743            && f.sig.variadic.is_none()
744            && match f.sig.output {
745                ReturnType::Default => true,
746                ReturnType::Type(_, ref ty) => matches!(**ty, Type::Never(_)),
747            }
748    }
749
750    fn impl_trait(&self) -> TokenStream2 {
751        match self {
752            Self::Exception => quote! { riscv_rt::ExceptionNumber },
753            Self::ExternalInterrupt => quote! { riscv_rt::ExternalInterruptNumber },
754            Self::CoreInterrupt => quote! { riscv_rt::CoreInterruptNumber },
755        }
756    }
757}
758
759#[proc_macro_attribute]
760/// Attribute to declare an exception handler.
761///
762/// The function must have the signature `[unsafe] [extern "C"] fn([&[mut] riscv_rt::TrapFrame]) [-> !]`.
763///
764/// The argument of the macro must be a path to a variant of an enum that implements the `riscv_rt::ExceptionNumber` trait.
765///
766/// # Example
767///
768/// ``` ignore,no_run
769/// #[riscv_rt::exception(riscv::interrupt::Exception::LoadMisaligned)]
770/// fn load_misaligned(trap_frame: &mut riscv_rt::TrapFrame) -> ! {
771///     loop{};
772/// }
773/// ```
774pub fn exception(args: TokenStream, input: TokenStream) -> TokenStream {
775    trap(args, input, RiscvPacItem::Exception, None)
776}
777
778#[proc_macro_attribute]
779/// Attribute to declare a core interrupt handler.
780///
781/// The function must have the signature `[unsafe] [extern "C"] fn() [-> !]`.
782///
783/// The argument of the macro must be a path to a variant of an enum that implements the `riscv_rt::CoreInterruptNumber` trait.
784///
785/// If the `v-trap` feature is enabled, this macro generates the corresponding interrupt trap handler in assembly.
786///
787/// # Example
788///
789/// ``` ignore,no_run
790/// #[riscv_rt::core_interrupt(riscv::interrupt::Interrupt::SupervisorSoft)]
791/// fn supervisor_soft() -> ! {
792///     loop{};
793/// }
794/// ```
795pub fn core_interrupt(args: TokenStream, input: TokenStream) -> TokenStream {
796    let arch = match () {
797        #[cfg(feature = "v-trap")]
798        () => RiscvArch::try_from_env(),
799        #[cfg(not(feature = "v-trap"))]
800        () => None,
801    };
802    trap(args, input, RiscvPacItem::CoreInterrupt, arch)
803}
804
805#[proc_macro_attribute]
806/// Attribute to declare an external interrupt handler.
807///
808/// The function must have the signature `[unsafe] [extern "C"] fn() [-> !]`.
809///
810/// The argument of the macro must be a path to a variant of an enum that implements the `riscv_rt::ExternalInterruptNumber` trait.
811///
812/// # Example
813///
814/// ``` ignore,no_run
815/// #[riscv_rt::external_interrupt(e310x::interrupt::Interrupt::GPIO0)]
816/// fn gpio0() -> ! {
817///     loop{};
818/// }
819/// ```
820pub fn external_interrupt(args: TokenStream, input: TokenStream) -> TokenStream {
821    trap(args, input, RiscvPacItem::ExternalInterrupt, None)
822}
823
824fn trap(
825    args: TokenStream,
826    input: TokenStream,
827    pac_item: RiscvPacItem,
828    arch: Option<RiscvArch>,
829) -> TokenStream {
830    let f = parse_macro_input!(input as ItemFn);
831
832    if !pac_item.check_signature(&f) {
833        let msg = format!(
834            "`#[{}]` function must have signature {}",
835            pac_item.macro_id(),
836            pac_item.valid_signature()
837        );
838        return parse::Error::new(f.sig.span(), msg)
839            .to_compile_error()
840            .into();
841    }
842    if args.is_empty() {
843        let msg = format!(
844            "`#[{}]` attribute expects a path to a variant of an enum that implements the {} trait.",
845            pac_item.macro_id(),
846            pac_item.impl_trait()
847        );
848        return parse::Error::new(Span::call_site(), msg)
849            .to_compile_error()
850            .into();
851    }
852
853    let int_path = parse_macro_input!(args as Path);
854    let int_ident = &int_path.segments.last().unwrap().ident;
855    let export_name = format!("{int_ident:#}");
856
857    let start_trap = match arch {
858        #[cfg(feature = "v-trap")]
859        Some(arch) => {
860            let trap = start_interrupt_trap(int_ident, arch);
861            quote! {
862                #trap
863            }
864        }
865        _ => proc_macro2::TokenStream::new(),
866    };
867
868    let pac_trait = pac_item.impl_trait();
869
870    quote!(
871        // Compile-time check to ensure the trap path implements the trap trait
872        const _: fn() = || {
873            fn assert_impl<T: #pac_trait>(_arg: T) {}
874            assert_impl(#int_path);
875        };
876
877        #start_trap
878
879        #[export_name = #export_name]
880        #f
881    )
882    .into()
883}
884
/// Generates the per-interrupt assembly entry `_start_<interrupt>_trap` for vectored mode.
///
/// The generated code is placed in its own `.trap.start.<interrupt>` section. It:
/// 1. allocates the trap frame;
/// 2. saves only `a0`;
/// 3. loads the address of the symbol named after `ident` into `a0`;
/// 4. jumps to the shared `_continue_interrupt_trap` routine.
///
/// The whole `global_asm!` is gated on a RISC-V `target_arch`.
#[cfg(feature = "v-trap")]
fn start_interrupt_trap(ident: &syn::Ident, arch: RiscvArch) -> proc_macro2::TokenStream {
    let trap_size = arch.trap_frame().len();
    let width = arch.width();
    let store = store_trap(arch, |register| register == "a0");
    let interrupt = ident.to_string();

    let source = format!(
        r#"
#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))]
core::arch::global_asm!(
    ".section .trap.start.{interrupt}, \"ax\"
    .balign 4
    .global _start_{interrupt}_trap
    _start_{interrupt}_trap:
        addi sp, sp, -{trap_size} * {width} // allocate space for trap frame
        {store}                             // store trap partially (only register a0)
        la a0, {interrupt}                  // load interrupt handler address into a0
        j _continue_interrupt_trap          // jump to common part of interrupt trap
");"#
    );

    source.parse().unwrap()
}