oscript_derive/
lib.rs

1#![feature(proc_macro_diagnostic)]
2extern crate proc_macro;
3use proc_macro::TokenStream;
4use quote::{format_ident, quote, ToTokens};
5use syn::{parse_macro_input, FnArg, ItemFn,Pat, PatType, Type};
6
7
8#[proc_macro_attribute]
9pub fn oscript_async_main(attr: TokenStream, item: TokenStream) -> TokenStream {
10    oscript_main_internal(attr, item, true)
11}
12
13
14#[proc_macro_attribute]
15pub fn oscript_main(attr: TokenStream, item: TokenStream) -> TokenStream {
16    oscript_main_internal(attr, item, false)
17}
18
19fn oscript_main_internal(attr: TokenStream, item: TokenStream, uses_async_macro: bool) -> TokenStream {
20    let mut use_tokio = uses_async_macro;
21
22    let parser = syn::meta::parser(|meta| {
23        if meta.path.is_ident("use_tokio") {
24            use_tokio = true;
25            Ok(())
26        } else {
27            Err(meta.error("unsupported attribute for [oscript..] on your main method.."))
28        }
29    });
30
31    parse_macro_input!(attr with parser);
32
33    let input = parse_macro_input!(item as ItemFn);
34
35    let is_marked_with_async = input.sig.asyncness.is_some();
36
37    if is_marked_with_async && !use_tokio {
38        panic!("Your main method is marked as async but you have not enabled tokio... Try using: #[oscript_async_main]");
39    } else if !is_marked_with_async && use_tokio {
40        proc_macro::Span::call_site()
41            .warning("#[oscript_async_main] used on non-async main will implicitly make it async! Consider marking it for clarity.")
42            .emit();
43    }
44
45    let fn_name = &input.sig.ident;
46    if fn_name != "main" {
47        panic!("Only the `main` function can be annotated with #[oscript_main]");
48    }
49    let fn_args = &input.sig.inputs;
50    
51   
52    // this just exists to make newlines work correctly for comments --> clap(help=..)
53    let doc_comments: Vec<String> = input
54        .attrs
55        .iter()
56        .flat_map(|attr| {
57            if attr.path().is_ident("doc") {
58                attr.to_token_stream()
59                    .to_string()
60                    .split("\n")
61                    .filter_map(|x|{
62                        x.split_once('=')
63                        .and_then(|(_, value)| value.split_once(']')
64                        .map(|(comment, _)| comment.trim().trim_matches('"').to_string().trim().to_string()))
65                    }).collect::<Vec<String>>()
66            } else {
67                vec![]
68            }
69        })
70        .collect();
71
72    let about_text = if !doc_comments.is_empty() {
73        let about = doc_comments.join("\n");
74        quote! { about=#about }
75    } else {
76        quote! { }
77    };
78
79    let struct_name = format_ident!("O{}Args", fn_name);
80
81    let mut struct_fields = Vec::new();
82    let mut fn_arg_conversions = Vec::new();
83
84    for arg in fn_args.iter() {
85        if let FnArg::Typed(PatType { pat, attrs, ty, .. }) = arg {
86            if let Pat::Ident(pat_ident) = &**pat {
87                let arg_name = &pat_ident.ident;
88               
89                let oargs = quote! {
90                    #[clap(long)] // gets overridden by any clap macro in attrs
91                    #(#attrs)*
92                };
93
94                match &**ty {
95                    // For references, use String as the type and handle the conversion
96                    Type::Reference(_) => {
97                        struct_fields.push(quote! {
98                            #oargs
99                            pub #arg_name: String
100                        });
101                        fn_arg_conversions.push(quote! {
102                            let #arg_name = args.#arg_name.as_str();
103                        });
104                    }
105
106                    // For all other types, use the type directly
107                    _ => {
108                        struct_fields.push(quote! {
109                            #oargs
110                            pub #arg_name: #ty
111                        });
112                        fn_arg_conversions.push(quote! {
113                            let #arg_name = args.#arg_name.clone();
114                        });
115                    }
116                }
117            }
118        }
119    }
120
121    
122
123    let clap_quote = quote! {
124        //use clap::*;
125        use clap::error::Error;
126        use clap::{Arg, ArgAction, ArgMatches, Args, Command, FromArgMatches, Parser};
127        #[derive(clap::Parser, Debug)]
128        pub struct #struct_name {
129            #(#struct_fields),*
130        }
131
132        #[derive(clap::Subcommand, Debug)]
133        pub enum Commands {
134            /// Generate shell completion scripts
135            GenerateCompletion {
136                #[clap(long)]
137                shell: String,
138            },
139            #[clap(#about_text)]
140            Run(#struct_name),
141        }
142
143        #[derive(clap::Parser, Debug)]
144        #[command(author, version)]
145        pub struct OscriptCli {
146            #[clap(subcommand)]
147            pub command: Commands,
148        }
149    };
150
151    let completion_logic = quote! {
152        use clap_complete::Shell;
153        let mut app = <OscriptCli as clap::CommandFactory>::command();
154        clap_complete::generate(shell, &mut app, env!("CARGO_PKG_NAME"), &mut std::io::stdout());
155    };
156
157    let fn_block = &input.block;
158    let visibility = &input.vis;
159
160    let async_attr_marker = if use_tokio {
161        quote! {
162            #[tokio::main]
163        }
164    } else {
165        quote! {}
166    };
167    let fn_marker = if use_tokio {
168        quote! {
169            async fn
170        }
171    } else {
172        quote! {
173            fn
174        }
175    };
176
177    let output = quote! {
178        #clap_quote
179        #async_attr_marker
180        #visibility #fn_marker #fn_name() {
181            let cli = OscriptCli::parse();
182            match cli.command {
183                Commands::GenerateCompletion { shell } => {
184                    let shell = shell.parse::<Shell>().expect("Invalid shell type");
185                    #completion_logic
186                    return;
187                }
188                Commands::Run(args) => {
189                    #(#fn_arg_conversions)*
190                    #fn_block
191                }
192            }
193        }
194    };
195
196    TokenStream::from(output)
197}