1use proc_macro2::Span;
2use shuttle_common::models::infra::InfraRequest;
3use syn::{
4 meta::{parser, ParseNestedMeta},
5 parse::Parser,
6 parse_file, parse_quote,
7 spanned::Spanned,
8 Attribute, Item, ItemFn, LitStr, Meta, MetaList, Path,
9};
10
11pub fn parse_infra_from_code(rust_source_code: &str) -> Result<Option<InfraRequest>, syn::Error> {
14 let Some((_main_fn, main_attr)) = find_runtime_main_fn(rust_source_code)? else {
15 return Err(syn::Error::new(
16 Span::call_site(),
17 "No function using #[shuttle_runtime::main] found",
18 ));
19 };
20
21 parse_infra_from_meta(&main_attr.meta)
23}
24
25pub fn find_runtime_main_fn(
27 rust_source_code: &str,
28) -> Result<Option<(ItemFn, Attribute)>, syn::Error> {
29 let ast = parse_file(rust_source_code)?;
30
31 let main_fn_and_attr = ast.items.into_iter().find_map(|item| match item {
32 Item::Fn(item_fn) => main_fn_and_attr(item_fn),
33 _ => None,
34 });
35
36 Ok(main_fn_and_attr)
37}
38
39pub fn main_fn_and_attr(item_fn: ItemFn) -> Option<(ItemFn, Attribute)> {
41 let runtime_main_path: Path = parse_quote! { shuttle_runtime::main };
42 let codegen_main_path: Path = parse_quote! { shuttle_codegen::main };
43 item_fn
44 .attrs
45 .clone()
46 .into_iter()
47 .find(|attr| attr.path() == &runtime_main_path || attr.path() == &codegen_main_path)
48 .map(|attr| (item_fn, attr))
49}
50
51fn parse_infra_from_meta(meta: &Meta) -> Result<Option<InfraRequest>, syn::Error> {
52 match meta {
53 Meta::Path(_) => Ok(None),
55 Meta::List(ref meta_list) => parse_infra_from_meta_list(meta_list).map(Some),
57 Meta::NameValue(_) => Err(syn::Error::new(
59 meta.span(),
60 "Expected plain attribute or list",
61 )),
62 }
63}
64
65fn parse_infra_from_meta_list(meta_list: &MetaList) -> Result<InfraRequest, syn::Error> {
66 let mut infra_parser = InfraAttrParser::default();
67 let meta_parser = parser(|meta| infra_parser.parse_nested_meta(meta));
68 meta_parser.parse2(meta_list.tokens.clone())?;
69 Ok(infra_parser.into_infra())
70}
71
72#[derive(Default)]
73pub struct InfraAttrParser(InfraRequest);
74impl InfraAttrParser {
75 pub fn parse_nested_meta(&mut self, meta: ParseNestedMeta) -> Result<(), syn::Error> {
79 let key = meta.path.require_ident()?.to_string();
80 let value = meta.value()?;
81 match key.as_str() {
82 "instance_size" => {
83 self.0.instance_size =
84 Some(value.parse::<LitStr>()?.value().parse().map_err(|e| {
85 syn::Error::new(value.span(), format!("Invalid value: {e}"))
86 })?);
87 }
88 unknown_key => {
89 return Err(syn::Error::new(
90 key.span(),
91 format!("Invalid macro attribute key: '{unknown_key}'"),
92 ))
93 }
94 }
95 Ok(())
96 }
97 pub fn into_infra(self) -> InfraRequest {
98 self.0
99 }
100}
101
#[cfg(test)]
mod tests {
    use shuttle_common::models::project::ComputeTier;

    use super::*;

    /// Request expected when `instance_size` is the only argument given.
    fn infra_with_size(size: ComputeTier) -> InfraRequest {
        InfraRequest {
            instance_size: Some(size),
            ..Default::default()
        }
    }

    /// Message of the error produced when parsing the attribute's meta.
    fn meta_error(attr: &Attribute) -> String {
        parse_infra_from_meta(&attr.meta).unwrap_err().to_string()
    }

    /// Message of the error produced when parsing a whole source file.
    fn code_error(source: &str) -> String {
        parse_infra_from_code(source).unwrap_err().to_string()
    }

    #[test]
    fn infra_meta() {
        // Valid value.
        let attr: Attribute = parse_quote! { #[shuttle_runtime::main(instance_size = "m")] };
        assert_eq!(
            parse_infra_from_meta(&attr.meta).unwrap().unwrap(),
            infra_with_size(ComputeTier::M)
        );

        // Value that is not a known size (trailing comma is accepted).
        let attr: Attribute = parse_quote! { #[shuttle_runtime::main(instance_size = "xyz",)] };
        assert_eq!(meta_error(&attr), "Invalid value: Matching variant not found");

        // Empty argument list gives the default request.
        let attr: Attribute = parse_quote! { #[shuttle_runtime::main()] };
        assert_eq!(
            parse_infra_from_meta(&attr.meta).unwrap().unwrap(),
            InfraRequest::default()
        );

        // No argument list at all gives no request.
        let attr: Attribute = parse_quote! { #[shuttle_runtime::main] };
        assert_eq!(parse_infra_from_meta(&attr.meta).unwrap(), None);

        // Name-value form is rejected.
        let attr: Attribute = parse_quote! { #[shuttle_runtime = "132"] };
        assert_eq!(meta_error(&attr), "Expected plain attribute or list");

        // A lone comma is not a valid argument.
        let attr: Attribute = parse_quote! { #[shuttle_runtime::main(,)] };
        assert_eq!(
            meta_error(&attr),
            "unexpected token in nested attribute, expected ident"
        );
    }

    #[test]
    fn find_main_fn() {
        let is_found = |source: &str| find_runtime_main_fn(source).unwrap().is_some();

        // Fully qualified runtime attribute among other items.
        assert!(is_found(
            r#"
            use abc::def;

            fn blob() -> u8 {}

            #[shuttle_runtime::main]
            async fn main() -> ShuttleAxum {}
            "#
        ));

        // Fully qualified codegen attribute.
        assert!(is_found(
            r#"
            #[shuttle_codegen::main]
            async fn main() -> ShuttleAxum {}
            "#
        ));

        // Imported, unqualified attribute is not recognized.
        assert!(!is_found(
            r#"
            use shuttle_runtime::main;
            #[main]
            async fn main() -> ShuttleAxum {}
            "#
        ));

        // Functions nested in a module are not inspected.
        assert!(!is_found(
            r#"
            mod not_root {
                #[shuttle_runtime::main]
                async fn main() -> ShuttleAxum {}
            }
            "#
        ));
    }

    #[test]
    fn parse() {
        // Every delimiter style is accepted for the argument list.
        let accepted = vec![
            (
                r#"
                #[shuttle_runtime::main(
                    instance_size = "m",
                )]
                async fn main() -> ShuttleAxum {}
                "#,
                ComputeTier::M,
            ),
            (
                r#"
                #[shuttle_runtime::main { instance_size = "xxl" } ]
                async fn main() -> ShuttleAxum {}
                "#,
                ComputeTier::XXL,
            ),
            (
                r#"
                #[shuttle_runtime::main[instance_size = "xs"]]
                async fn main() -> ShuttleAxum {}
                "#,
                ComputeTier::XS,
            ),
        ];
        for (source, size) in accepted {
            assert_eq!(
                parse_infra_from_code(source).unwrap().unwrap(),
                infra_with_size(size)
            );
        }

        // Non-string value.
        assert_eq!(
            code_error(
                r#"
                #[shuttle_runtime::main(instance_size = 500000)]
                async fn main() -> ShuttleAxum {}
                "#
            ),
            "expected string literal"
        );

        // Unknown key.
        assert_eq!(
            code_error(
                r#"
                #[shuttle_runtime::main(leet = 1337)]
                async fn main() -> ShuttleAxum {}
                "#
            ),
            "Invalid macro attribute key: 'leet'"
        );
    }
}