1#![deny(warnings)]
2#![allow(unknown_lints)] #![allow(clippy::manual_is_multiple_of)] use proc_macro::TokenStream;
6use proc_macro2::{Span, TokenStream as TokenStream2};
7use quote::quote;
8use syn::{
9 parse::{self, Parse},
10 parse_macro_input, parse_quote,
11 punctuated::Punctuated,
12 spanned::Spanned,
13 FnArg, ItemFn, LitInt, LitStr, PatType, Path, ReturnType, Token, Type, Visibility,
14};
15
16#[proc_macro_attribute]
48pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream {
49 let f = parse_macro_input!(input as ItemFn);
50
51 #[cfg(not(feature = "u-boot"))]
52 let arguments_limit = 3;
53 #[cfg(feature = "u-boot")]
54 let arguments_limit = 2;
55
56 if f.sig.inputs.len() > arguments_limit {
58 return parse::Error::new(
59 f.sig.inputs.last().unwrap().span(),
60 "`#[entry]` function has too many arguments",
61 )
62 .to_compile_error()
63 .into();
64 }
65
66 #[cfg(not(feature = "u-boot"))]
67 for argument in f.sig.inputs.iter() {
68 if let Some(message) = check_argument_type(argument, "usize") {
69 return message;
70 };
71 }
72 #[cfg(feature = "u-boot")]
73 if let Some(message) = f
74 .sig
75 .inputs
76 .get(0)
77 .and_then(|argument| check_argument_type(argument, "c_int"))
78 {
79 return message;
80 }
81 #[cfg(feature = "u-boot")]
82 if let Some(message) = f
83 .sig
84 .inputs
85 .get(1)
86 .and_then(|argument| check_argument_type(argument, "*const *const c_char"))
87 {
88 return message;
89 }
90
91 let valid_signature = f.sig.constness.is_none()
93 && f.sig.asyncness.is_none()
94 && f.vis == Visibility::Inherited
95 && f.sig.abi.is_none()
96 && f.sig.generics.params.is_empty()
97 && f.sig.generics.where_clause.is_none()
98 && f.sig.variadic.is_none()
99 && match f.sig.output {
100 ReturnType::Default => false,
101 ReturnType::Type(_, ref ty) => matches!(**ty, Type::Never(_)),
102 };
103
104 if !valid_signature {
105 return parse::Error::new(
106 f.span(),
107 "`#[entry]` function must have signature `[unsafe] fn([arg0: usize, ...]) -> !`",
108 )
109 .to_compile_error()
110 .into();
111 }
112
113 if !args.is_empty() {
114 return parse::Error::new(Span::call_site(), "This attribute accepts no arguments")
115 .to_compile_error()
116 .into();
117 }
118
119 let attrs = f.attrs;
121 let unsafety = f.sig.unsafety;
122 let args = f.sig.inputs;
123 let stmts = f.block.stmts;
124
125 quote!(
126 #[allow(non_snake_case)]
127 #[export_name = "main"]
128 #(#attrs)*
129 pub #unsafety fn __risc_v_rt__main(#args) -> ! {
130 #(#stmts)*
131 }
132 )
133 .into()
134}
135
136fn strip_type_path(ty: &Type) -> Option<Type> {
137 match ty {
138 Type::Ptr(ty) => {
139 let mut ty = ty.clone();
140 ty.elem = Box::new(strip_type_path(&ty.elem)?);
141 Some(Type::Ptr(ty))
142 }
143 Type::Path(ty) => {
144 let mut ty = ty.clone();
145 let last_segment = ty.path.segments.last().unwrap().clone();
146 ty.path.segments = Punctuated::new();
147 ty.path.segments.push_value(last_segment);
148 Some(Type::Path(ty))
149 }
150 _ => None,
151 }
152}
153
154#[allow(unused)]
155fn is_correct_type(ty: &Type, name: &str) -> bool {
156 let correct: Type = syn::parse_str(name).unwrap();
157 if let Some(ty) = strip_type_path(ty) {
158 ty == correct
159 } else {
160 false
161 }
162}
163
164fn check_correct_type(argument: &PatType, ty: &str) -> Option<TokenStream> {
165 let inv_type_message = format!("argument type must be {ty}");
166
167 if !is_correct_type(&argument.ty, ty) {
168 let error = parse::Error::new(argument.ty.span(), inv_type_message);
169
170 Some(error.to_compile_error().into())
171 } else {
172 None
173 }
174}
175
176fn check_argument_type(argument: &FnArg, ty: &str) -> Option<TokenStream> {
177 let argument_error = parse::Error::new(argument.span(), "invalid argument");
178 let argument_error = argument_error.to_compile_error().into();
179
180 match argument {
181 FnArg::Typed(argument) => check_correct_type(argument, ty),
182 FnArg::Receiver(_) => Some(argument_error),
183 }
184}
185
186#[deprecated(note = "Use global_asm! to define the __pre_init function instead")]
218#[proc_macro_attribute]
219pub fn pre_init(args: TokenStream, input: TokenStream) -> TokenStream {
220 let f = parse_macro_input!(input as ItemFn);
221
222 let valid_signature = f.sig.constness.is_none()
224 && f.sig.asyncness.is_none()
225 && f.vis == Visibility::Inherited
226 && f.sig.unsafety.is_some()
227 && f.sig.abi.is_none()
228 && f.sig.inputs.is_empty()
229 && f.sig.generics.params.is_empty()
230 && f.sig.generics.where_clause.is_none()
231 && f.sig.variadic.is_none()
232 && match f.sig.output {
233 ReturnType::Default => true,
234 ReturnType::Type(_, ref ty) => match **ty {
235 Type::Tuple(ref tuple) => tuple.elems.is_empty(),
236 _ => false,
237 },
238 };
239
240 if !valid_signature {
241 return parse::Error::new(
242 f.span(),
243 "`#[pre_init]` function must have signature `unsafe fn()`",
244 )
245 .to_compile_error()
246 .into();
247 }
248
249 if !args.is_empty() {
250 return parse::Error::new(Span::call_site(), "This attribute accepts no arguments")
251 .to_compile_error()
252 .into();
253 }
254
255 let attrs = f.attrs;
257 let ident = f.sig.ident;
258 let block = f.block;
259
260 quote!(
261 #[export_name = "__pre_init"]
262 #(#attrs)*
263 pub unsafe fn #ident() #block
264 )
265 .into()
266}
267
268#[proc_macro_attribute]
292pub fn post_init(args: TokenStream, input: TokenStream) -> TokenStream {
293 let f = parse_macro_input!(input as ItemFn);
294
295 if f.sig.inputs.len() > 1 {
297 return parse::Error::new(
298 f.sig.inputs.last().unwrap().span(),
299 "`#[post_init]` function has too many arguments",
300 )
301 .to_compile_error()
302 .into();
303 }
304 for argument in f.sig.inputs.iter() {
305 if let Some(message) = check_argument_type(argument, "usize") {
306 return message;
307 };
308 }
309
310 let valid_signature = f.sig.constness.is_none()
312 && f.sig.asyncness.is_none()
313 && f.vis == Visibility::Inherited
314 && f.sig.abi.is_none()
315 && f.sig.generics.params.is_empty()
316 && f.sig.generics.where_clause.is_none()
317 && f.sig.variadic.is_none()
318 && match f.sig.output {
319 ReturnType::Default => true,
320 ReturnType::Type(_, ref ty) => match **ty {
321 Type::Tuple(ref tuple) => tuple.elems.is_empty(),
322 _ => false,
323 },
324 };
325
326 if !valid_signature {
327 return parse::Error::new(
328 f.span(),
329 "`#[post_init]` function must have signature `[unsafe] fn([usize])`",
330 )
331 .to_compile_error()
332 .into();
333 }
334
335 if !args.is_empty() {
336 return parse::Error::new(Span::call_site(), "This attribute accepts no arguments")
337 .to_compile_error()
338 .into();
339 }
340
341 let attrs = f.attrs;
343 let ident = f.sig.ident;
344 let args = f.sig.inputs;
345 let block = f.block;
346
347 quote!(
348 #[export_name = "__post_init"]
349 #(#attrs)*
350 unsafe fn #ident(#args) #block
351 )
352 .into()
353}
354
/// Parsed arguments of the assembly loop macros.
struct AsmLoopArgs {
    /// Assembly template; every `{}` in it is replaced by the loop index.
    asm_template: String,
    /// First index of the loop (inclusive).
    count_from: usize,
    /// End of the loop (exclusive).
    count_to: usize,
}
360
361impl Parse for AsmLoopArgs {
362 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
363 let template: LitStr = input.parse().unwrap();
364 _ = input.parse::<Token![,]>().unwrap();
365 let count: LitInt = input.parse().unwrap();
366 if input.parse::<Token![,]>().is_ok() {
367 let count_to: LitInt = input.parse().unwrap();
368 Ok(Self {
369 asm_template: template.value(),
370 count_from: count.base10_parse().unwrap(),
371 count_to: count_to.base10_parse().unwrap(),
372 })
373 } else {
374 Ok(Self {
375 asm_template: template.value(),
376 count_from: 0,
377 count_to: count.base10_parse().unwrap(),
378 })
379 }
380 }
381}
382
383#[proc_macro]
406pub fn loop_asm(input: TokenStream) -> TokenStream {
407 let args = parse_macro_input!(input as AsmLoopArgs);
408
409 let tokens = (args.count_from..args.count_to)
410 .map(|i| {
411 let i = i.to_string();
412 let asm = args.asm_template.replace("{}", &i);
413 format!("core::arch::asm!(\"{asm}\");")
414 })
415 .collect::<Vec<String>>()
416 .join("\n");
417 tokens.parse().unwrap()
418}
419
420#[proc_macro]
443pub fn loop_global_asm(input: TokenStream) -> TokenStream {
444 let args = parse_macro_input!(input as AsmLoopArgs);
445
446 let instructions = (args.count_from..args.count_to)
447 .map(|i| {
448 let i = i.to_string();
449 args.asm_template.replace("{}", &i)
450 })
451 .collect::<Vec<String>>()
452 .join("\n");
453
454 let res = format!("core::arch::global_asm!(\n\"{instructions}\"\n);");
455 res.parse().unwrap()
456}
457
/// Base RISC-V instruction sets distinguished by the code generators in this
/// file.
#[derive(Clone, Copy, Debug)]
enum RiscvArch {
    /// Selected by the name `rv32i`.
    Rv32I,
    /// Selected by the name `rv32e`.
    Rv32E,
    /// Selected by the name `rv64i`.
    Rv64I,
    /// Selected by the name `rv64e`.
    Rv64E,
}
465
466impl Parse for RiscvArch {
467 fn parse(input: parse::ParseStream) -> syn::Result<Self> {
468 let ident: syn::Ident = input.parse()?;
469 match ident.to_string().as_str() {
470 "rv32i" => Ok(Self::Rv32I),
471 "rv32e" => Ok(Self::Rv32E),
472 "rv64i" => Ok(Self::Rv64I),
473 "rv64e" => Ok(Self::Rv64E),
474 _ => Err(syn::Error::new(ident.span(), "Invalid RISC-V architecture")),
475 }
476 }
477}
478
479impl RiscvArch {
480 fn try_from_env() -> Option<Self> {
481 let arch = std::env::var("RISCV_RT_BASE_ISA").ok()?;
482 match arch.as_str() {
483 "rv32i" => Some(Self::Rv32I),
484 "rv32e" => Some(Self::Rv32E),
485 "rv64i" => Some(Self::Rv64I),
486 "rv64e" => Some(Self::Rv64E),
487 _ => None,
488 }
489 }
490
491 const fn width(&self) -> usize {
492 match self {
493 Self::Rv32I | Self::Rv32E => 4,
494 Self::Rv64I | Self::Rv64E => 8,
495 }
496 }
497
498 const fn store(&self) -> &str {
499 match self {
500 Self::Rv32I | Self::Rv32E => "sw",
501 Self::Rv64I | Self::Rv64E => "sd",
502 }
503 }
504
505 const fn load(&self) -> &str {
506 match self {
507 Self::Rv32I | Self::Rv32E => "lw",
508 Self::Rv64I | Self::Rv64E => "ld",
509 }
510 }
511
512 fn trap_frame(&self) -> Vec<&str> {
513 match self {
514 Self::Rv32I | Self::Rv64I => vec![
515 "ra", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a0", "a1", "a2", "a3", "a4", "a5",
516 "a6", "a7",
517 ],
518 Self::Rv32E | Self::Rv64E => {
519 vec!["ra", "t0", "t1", "t2", "a0", "a1", "a2", "a3", "a4", "a5"]
520 }
521 }
522 }
523
524 const fn byte_alignment(&self) -> usize {
530 match self {
531 Self::Rv32E | Self::Rv64E => 4,
532 _ => 16,
533 }
534 }
535}
536
537fn store_trap<T: FnMut(&str) -> bool>(arch: RiscvArch, mut filter: T) -> String {
546 let width = arch.width();
547 let store = arch.store();
548 arch.trap_frame()
549 .iter()
550 .enumerate()
551 .filter(|(_, ®)| !reg.starts_with('_') && filter(reg))
552 .map(|(i, reg)| format!("{store} {reg}, {i}*{width}(sp)"))
553 .collect::<Vec<_>>()
554 .join("\n ")
555}
556
557fn load_trap(arch: RiscvArch) -> String {
560 let width = arch.width();
561 let load = arch.load();
562 arch.trap_frame()
563 .iter()
564 .enumerate()
565 .filter(|(_, ®)| !reg.starts_with('_'))
566 .map(|(i, reg)| format!("{load} {reg}, {i}*{width}(sp)"))
567 .collect::<Vec<_>>()
568 .join("\n ")
569}
570
571#[proc_macro]
573pub fn llvm_arch_patch(_input: TokenStream) -> TokenStream {
574 let q = if let Ok(arch) = std::env::var("RISCV_RT_LLVM_ARCH_PATCH") {
575 let patch = format!(".attribute arch,\"{arch}\"");
576 quote! { core::arch::global_asm!{#patch} }
577 } else {
578 quote!(compile_error!("RISCV_RT_LLVM_ARCH_PATCH is not set"))
579 };
580 q.into()
581}
582
583#[proc_macro]
589pub fn default_start_trap(_input: TokenStream) -> TokenStream {
590 let arch = RiscvArch::try_from_env().unwrap();
591
592 let width = arch.width();
593 let trap_size = arch.trap_frame().len();
594 let byte_alignment = arch.byte_alignment();
595 if (trap_size * width) % byte_alignment != 0 {
597 return parse::Error::new(Span::call_site(), "Trap frame size must be 16-byte aligned")
598 .to_compile_error()
599 .into();
600 }
601 let store = store_trap(arch, |_| true);
602 let load = load_trap(arch);
603
604 #[cfg(feature = "s-mode")]
605 let ret = "sret";
606 #[cfg(not(feature = "s-mode"))]
607 let ret = "mret";
608
609 let pre_default_start_trap = if cfg!(feature = "pre-default-start-trap") {
610 r#"
611 j _pre_default_start_trap
612.global _pre_default_start_trap_ret
613_pre_default_start_trap_ret:
614 "#
615 } else {
616 ""
617 };
618
619 format!(
620 r#"
621core::arch::global_asm!(
622".section .trap.start, \"ax\"
623.balign 4 /* Alignment required for xtvec */
624.global _default_start_trap
625_default_start_trap:
626 {pre_default_start_trap}
627 addi sp, sp, - {trap_size} * {width}
628 {store}
629 add a0, sp, zero
630 jal ra, _start_trap_rust
631 {load}
632 addi sp, sp, {trap_size} * {width}
633 {ret}
634");"#
635 )
636 .parse()
637 .unwrap()
638}
639
/// Function-like macro that generates the shared assembly of vectored
/// interrupt handling (only available with the `v-trap` feature).
///
/// Two global symbols are defined in section `.trap.continue`:
///
/// - `_start_DefaultHandler_trap` allocates the trap frame, stores `a0`,
///   loads the address of `DefaultHandler` into `a0` and falls through.
/// - `_continue_interrupt_trap` stores the remaining trap frame registers,
///   calls the handler whose address is in `a0`, restores the frame, frees
///   it, and returns with `sret` (feature `s-mode`) or `mret`.
///
/// The target architecture is read from the `RISCV_RT_BASE_ISA` environment
/// variable; a missing or invalid value is reported as a compile error. The
/// macro input is ignored.
#[cfg(feature = "v-trap")]
#[proc_macro]
pub fn vectored_interrupt_trap(_input: TokenStream) -> TokenStream {
    // Report a diagnostic rather than panicking inside the macro.
    let arch = match RiscvArch::try_from_env() {
        Some(arch) => arch,
        None => {
            return parse::Error::new(
                Span::call_site(),
                "RISCV_RT_BASE_ISA must be set to one of `rv32i`, `rv32e`, `rv64i` or `rv64e`",
            )
            .to_compile_error()
            .into();
        }
    };
    let width = arch.width();
    let trap_size = arch.trap_frame().len();
    // `a0` is saved first because it is then reused for the handler address.
    let store_start = store_trap(arch, |reg| reg == "a0");
    let store_continue = store_trap(arch, |reg| reg != "a0");
    let load = load_trap(arch);

    #[cfg(feature = "s-mode")]
    let ret = "sret";
    #[cfg(not(feature = "s-mode"))]
    let ret = "mret";

    let instructions = format!(
        r#"
core::arch::global_asm!(
".section .trap.continue, \"ax\"

.balign 4
.global _start_DefaultHandler_trap
_start_DefaultHandler_trap:
 addi sp, sp, -{trap_size} * {width} // allocate space for trap frame
 {store_start} // store trap partially (only register a0)
 la a0, DefaultHandler // load interrupt handler address into a0

.global _continue_interrupt_trap
_continue_interrupt_trap:
 {store_continue} // store trap partially (all registers except a0)
 jalr ra, a0, 0 // jump to corresponding interrupt handler (address stored in a0)
 {load} // restore trap frame
 addi sp, sp, {trap_size} * {width} // deallocate space for trap frame
 {ret} // return from interrupt
");"#
    );

    instructions
        .parse()
        .expect("generated `global_asm!` invocation is valid Rust")
}
683
/// Kind of trap handler an attribute macro in this file registers.
#[derive(Clone, Copy, Debug)]
enum RiscvPacItem {
    /// Handler registered through `#[exception]`.
    Exception,
    /// Handler registered through `#[external_interrupt]`.
    ExternalInterrupt,
    /// Handler registered through `#[core_interrupt]`.
    CoreInterrupt,
}
690
691impl RiscvPacItem {
692 fn macro_id(&self) -> &str {
693 match self {
694 Self::Exception => "exception",
695 Self::ExternalInterrupt => "external_interrupt",
696 Self::CoreInterrupt => "core_interrupt",
697 }
698 }
699
700 fn valid_signature(&self) -> &str {
701 match self {
702 Self::Exception => "`[unsafe] [extern \"C\"] fn([&[mut] riscv_rt::TrapFrame]) [-> !]`",
703 _ => "`[unsafe] [extern \"C\"] fn() [-> !]`",
704 }
705 }
706
707 fn check_signature(&self, f: &ItemFn) -> bool {
708 let valid_args = match self {
709 Self::Exception => {
710 if f.sig.inputs.len() > 1 {
711 return false;
712 }
713 match f.sig.inputs.first() {
714 Some(FnArg::Typed(t)) => {
715 let first_param_type = *t.ty.clone();
716 let expected_types: Vec<Type> = vec![
717 parse_quote!(&riscv_rt::TrapFrame),
718 parse_quote!(&mut riscv_rt::TrapFrame),
719 ];
720 expected_types.contains(&first_param_type)
721 }
722 Some(_) => false,
723 None => true,
724 }
725 }
726 _ => f.sig.inputs.is_empty(),
727 };
728
729 valid_args
730 && f.sig.constness.is_none()
731 && f.sig.asyncness.is_none()
732 && f.vis == Visibility::Inherited
733 && match &f.sig.abi {
734 None => true,
735 Some(syn::Abi {
736 extern_token: _,
737 name: Some(name),
738 }) if name.value() == "C" => true,
739 _ => false,
740 }
741 && f.sig.generics.params.is_empty()
742 && f.sig.generics.where_clause.is_none()
743 && f.sig.variadic.is_none()
744 && match f.sig.output {
745 ReturnType::Default => true,
746 ReturnType::Type(_, ref ty) => matches!(**ty, Type::Never(_)),
747 }
748 }
749
750 fn impl_trait(&self) -> TokenStream2 {
751 match self {
752 Self::Exception => quote! { riscv_rt::ExceptionNumber },
753 Self::ExternalInterrupt => quote! { riscv_rt::ExternalInterruptNumber },
754 Self::CoreInterrupt => quote! { riscv_rt::CoreInterruptNumber },
755 }
756 }
757}
758
759#[proc_macro_attribute]
760pub fn exception(args: TokenStream, input: TokenStream) -> TokenStream {
775 trap(args, input, RiscvPacItem::Exception, None)
776}
777
778#[proc_macro_attribute]
779pub fn core_interrupt(args: TokenStream, input: TokenStream) -> TokenStream {
796 let arch = match () {
797 #[cfg(feature = "v-trap")]
798 () => RiscvArch::try_from_env(),
799 #[cfg(not(feature = "v-trap"))]
800 () => None,
801 };
802 trap(args, input, RiscvPacItem::CoreInterrupt, arch)
803}
804
805#[proc_macro_attribute]
806pub fn external_interrupt(args: TokenStream, input: TokenStream) -> TokenStream {
821 trap(args, input, RiscvPacItem::ExternalInterrupt, None)
822}
823
824fn trap(
825 args: TokenStream,
826 input: TokenStream,
827 pac_item: RiscvPacItem,
828 arch: Option<RiscvArch>,
829) -> TokenStream {
830 let f = parse_macro_input!(input as ItemFn);
831
832 if !pac_item.check_signature(&f) {
833 let msg = format!(
834 "`#[{}]` function must have signature {}",
835 pac_item.macro_id(),
836 pac_item.valid_signature()
837 );
838 return parse::Error::new(f.sig.span(), msg)
839 .to_compile_error()
840 .into();
841 }
842 if args.is_empty() {
843 let msg = format!(
844 "`#[{}]` attribute expects a path to a variant of an enum that implements the {} trait.",
845 pac_item.macro_id(),
846 pac_item.impl_trait()
847 );
848 return parse::Error::new(Span::call_site(), msg)
849 .to_compile_error()
850 .into();
851 }
852
853 let int_path = parse_macro_input!(args as Path);
854 let int_ident = &int_path.segments.last().unwrap().ident;
855 let export_name = format!("{int_ident:#}");
856
857 let start_trap = match arch {
858 #[cfg(feature = "v-trap")]
859 Some(arch) => {
860 let trap = start_interrupt_trap(int_ident, arch);
861 quote! {
862 #trap
863 }
864 }
865 _ => proc_macro2::TokenStream::new(),
866 };
867
868 let pac_trait = pac_item.impl_trait();
869
870 quote!(
871 const _: fn() = || {
873 fn assert_impl<T: #pac_trait>(_arg: T) {}
874 assert_impl(#int_path);
875 };
876
877 #start_trap
878
879 #[export_name = #export_name]
880 #f
881 )
882 .into()
883}
884
/// Generates the per-interrupt trap stub `_start_<ident>_trap` (only
/// available with the `v-trap` feature).
///
/// The stub lives in section `.trap.start.<ident>` and is only emitted for
/// `riscv32`/`riscv64` targets. It allocates the trap frame, stores `a0`,
/// loads the address of the handler named `ident` into `a0`, and jumps to
/// `_continue_interrupt_trap`.
#[cfg(feature = "v-trap")]
fn start_interrupt_trap(ident: &syn::Ident, arch: RiscvArch) -> proc_macro2::TokenStream {
    let instructions = format!(
        r#"
#[cfg(any(target_arch = "riscv32", target_arch = "riscv64"))]
core::arch::global_asm!(
 ".section .trap.start.{interrupt}, \"ax\"
 .balign 4
 .global _start_{interrupt}_trap
 _start_{interrupt}_trap:
 addi sp, sp, -{trap_size} * {width} // allocate space for trap frame
 {store} // store trap partially (only register a0)
 la a0, {interrupt} // load interrupt handler address into a0
 j _continue_interrupt_trap // jump to common part of interrupt trap
");"#,
        interrupt = ident,
        trap_size = arch.trap_frame().len(),
        width = arch.width(),
        // Only `a0` is saved here; the common continuation saves the rest.
        store = store_trap(arch, |r| r == "a0"),
    );

    instructions.parse().unwrap()
}