1#![feature(proc_macro_diagnostic)]
2extern crate proc_macro;
3use proc_macro::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::{parse_macro_input, FnArg, ItemFn,Pat, PatType, Type};
6
7
8#[proc_macro_attribute]
9pub fn oscript_async_main(attr: TokenStream, item: TokenStream) -> TokenStream {
10 oscript_main_internal(attr, item, true)
11}
12
13
14#[proc_macro_attribute]
15pub fn oscript_main(attr: TokenStream, item: TokenStream) -> TokenStream {
16 oscript_main_internal(attr, item, false)
17}
18
19fn oscript_main_internal(attr: TokenStream, item: TokenStream, uses_async_macro: bool) -> TokenStream {
20 let mut use_tokio = uses_async_macro;
21
22 let parser = syn::meta::parser(|meta| {
23 if meta.path.is_ident("use_tokio") {
24 use_tokio = true;
25 Ok(())
26 } else {
27 Err(meta.error("unsupported attribute for [oscript..] on your main method.."))
28 }
29 });
30
31 parse_macro_input!(attr with parser);
32
33 let input = parse_macro_input!(item as ItemFn);
34
35 let is_marked_with_async = input.sig.asyncness.is_some();
36
37 if is_marked_with_async && !use_tokio {
38 panic!("Your main method is marked as async but you have not enabled tokio... Try using: #[oscript_async_main]");
39 } else if !is_marked_with_async && use_tokio {
40 proc_macro::Span::call_site()
41 .warning("#[oscript_async_main] used on non-async main will implicitly make it async! Consider marking it for clarity.")
42 .emit();
43 }
44
45 let fn_name = &input.sig.ident;
46 if fn_name != "main" {
47 panic!("Only the `main` function can be annotated with #[oscript_main]");
48 }
49 let fn_args = &input.sig.inputs;
50
51
52 let doc_comments: Vec<String> = input
54 .attrs
55 .iter()
56 .flat_map(|attr| {
57 if attr.path().is_ident("doc") {
58 attr.to_token_stream()
59 .to_string()
60 .split("\n")
61 .filter_map(|x|{
62 x.split_once('=')
63 .and_then(|(_, value)| value.split_once(']')
64 .map(|(comment, _)| comment.trim().trim_matches('"').to_string().trim().to_string()))
65 }).collect::<Vec<String>>()
66 } else {
67 vec![]
68 }
69 })
70 .collect();
71
72 let about_text = if !doc_comments.is_empty() {
73 let about = doc_comments.join("\n");
74 quote! { about=#about }
75 } else {
76 quote! { }
77 };
78
79 let struct_name = format_ident!("O{}Args", fn_name);
80
81 let mut struct_fields = Vec::new();
82 let mut fn_arg_conversions = Vec::new();
83
84 for arg in fn_args.iter() {
85 if let FnArg::Typed(PatType { pat, attrs, ty, .. }) = arg {
86 if let Pat::Ident(pat_ident) = &**pat {
87 let arg_name = &pat_ident.ident;
88
89 let oargs = quote! {
90 #[clap(long)] #(#attrs)*
92 };
93
94 match &**ty {
95 Type::Reference(_) => {
97 struct_fields.push(quote! {
98 #oargs
99 pub #arg_name: String
100 });
101 fn_arg_conversions.push(quote! {
102 let #arg_name = args.#arg_name.as_str();
103 });
104 }
105
106 _ => {
108 struct_fields.push(quote! {
109 #oargs
110 pub #arg_name: #ty
111 });
112 fn_arg_conversions.push(quote! {
113 let #arg_name = args.#arg_name.clone();
114 });
115 }
116 }
117 }
118 }
119 }
120
121
122
123 let clap_quote = quote! {
124 use clap::error::Error;
126 use clap::{Arg, ArgAction, ArgMatches, Args, Command, FromArgMatches, Parser};
127 #[derive(clap::Parser, Debug)]
128 pub struct #struct_name {
129 #(#struct_fields),*
130 }
131
132 #[derive(clap::Subcommand, Debug)]
133 pub enum Commands {
134 GenerateCompletion {
136 #[clap(long)]
137 shell: String,
138 },
139 #[clap(#about_text)]
140 Run(#struct_name),
141 }
142
143 #[derive(clap::Parser, Debug)]
144 #[command(author, version)]
145 pub struct OscriptCli {
146 #[clap(subcommand)]
147 pub command: Commands,
148 }
149 };
150
151 let completion_logic = quote! {
152 use clap_complete::Shell;
153 let mut app = <OscriptCli as clap::CommandFactory>::command();
154 clap_complete::generate(shell, &mut app, env!("CARGO_PKG_NAME"), &mut std::io::stdout());
155 };
156
157 let fn_block = &input.block;
158 let visibility = &input.vis;
159
160 let async_attr_marker = if use_tokio {
161 quote! {
162 #[tokio::main]
163 }
164 } else {
165 quote! {}
166 };
167 let fn_marker = if use_tokio {
168 quote! {
169 async fn
170 }
171 } else {
172 quote! {
173 fn
174 }
175 };
176
177 let output = quote! {
178 #clap_quote
179 #async_attr_marker
180 #visibility #fn_marker #fn_name() {
181 let cli = OscriptCli::parse();
182 match cli.command {
183 Commands::GenerateCompletion { shell } => {
184 let shell = shell.parse::<Shell>().expect("Invalid shell type");
185 #completion_logic
186 return;
187 }
188 Commands::Run(args) => {
189 #(#fn_arg_conversions)*
190 #fn_block
191 }
192 }
193 }
194 };
195
196 TokenStream::from(output)
197}