1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
3use quote::quote as q;
4
5use turbolift_internals::extract_function;
6
/// Turns the annotated function into a remotely executed service (distributed build).
///
/// `distribution_platform_` is an expression that is spliced into the generated code as
/// `#distribution_platform.lock().await`. It is presumably a static async mutex around a
/// `DistributionPlatform` implementation (assumption; confirm against callers).
///
/// At compile time this macro:
/// 1. generates a service `main.rs` that serves the function (renamed `<name>_raw`) with
///    actix-web at `<name>/{_turbolift_run_id}/<params...>`, plus `/health-probe` and a
///    404 default handler;
/// 2. copies the entries of the current working directory (except the build cache and
///    `./target`) into `CACHE_PATH/<name>`, installs the generated `main.rs`, edits the
///    copy's `Cargo.toml`, and lints it (lint failures are logged and ignored);
/// 3. compresses that project into `CACHE_PATH/<name>_source.tar` and embeds it into the
///    caller with `include_bytes!`;
/// 4. replaces the annotated function with
///    `async fn <name>(..) -> turbolift::DistributionResult<T>`. That function declares the
///    project on the platform if `has_declared` reports it is not yet declared, dispatches
///    the call, and deserializes the JSON response with `serde_json`.
///
/// # Panics
/// Panics (reported by rustc as a failed macro expansion) when any of the file-system or
/// build-cache steps above fail.
#[cfg(feature = "distributed")]
#[proc_macro_attribute]
#[tracing::instrument]
pub fn on(distribution_platform_: TokenStream, function_: TokenStream) -> TokenStream {
    use quote::{format_ident, ToTokens};
    use std::fs;
    use std::path::PathBuf;

    use turbolift_internals::{build_project, CACHE_PATH};

    // Name of the extra leading path parameter that carries the run id.
    const RUN_ID_NAME: &str = "_turbolift_run_id";

    let distribution_platform = TokenStream2::from(distribution_platform_);
    let function = TokenStream2::from(function_);

    // --- Analyse the annotated function ---------------------------------------------
    let original_target_function = extract_function::get_fn_item(function.clone());
    let original_target_function_ident = original_target_function.sig.ident.clone();
    let original_target_function_name = original_target_function_ident.to_string();

    // Inside the service, the real implementation is compiled as `<name>_raw`.
    let mut target_function = original_target_function.clone();
    target_function.sig.ident = format_ident!("{}_raw", target_function.sig.ident);
    let signature = target_function.sig.clone();
    let function_name = signature.ident;
    let function_name_string = function_name.to_string();
    let typed_params = signature.inputs;
    let untyped_params = extract_function::to_untyped_params(typed_params.clone());

    // The HTTP handler receives the run id as an additional first path parameter.
    let mut untyped_params_with_run_id = untyped_params.clone();
    untyped_params_with_run_id.insert(
        0,
        Box::new(syn::Pat::Ident(syn::PatIdent {
            attrs: Vec::new(),
            by_ref: None,
            mutability: None,
            ident: Ident::new(RUN_ID_NAME, Span::call_site()),
            subpat: None,
        })),
    );
    let untyped_params_tokens_with_run_id = untyped_params_with_run_id.to_token_stream();
    let untyped_params_tokens = untyped_params.to_token_stream();
    let params_as_path = extract_function::to_path_params(untyped_params.clone());
    // `{{`/`}}` are literal braces: `<name>/{_turbolift_run_id}/<params_as_path>`.
    let wrapper_route = format!(
        "{}/{{{}}}/{}",
        original_target_function_name, RUN_ID_NAME, params_as_path
    );
    let mut param_types = extract_function::to_param_types(typed_params.clone());
    // Type of the run-id path segment, matching the parameter inserted above.
    param_types.insert(0, Box::new(syn::Type::Verbatim(q! { String })));
    let params_vec = extract_function::params_json_vec(untyped_params.clone());
    let result_type = extract_function::get_result_type(&signature.output);
    let dummy_function = extract_function::make_dummy_function(
        original_target_function,
        &function_name_string,
        untyped_params,
    );

    // --- Generate the service's `main.rs` -------------------------------------------
    let sanitized_file = extract_function::get_sanitized_file(&function);
    let main_file = q! {
        #sanitized_file
        use turbolift::tokio_compat_02::FutureExt;

        #dummy_function
        #target_function

        async fn health_probe(_req: turbolift::actix_web::HttpRequest) -> impl turbolift::actix_web::Responder {
            turbolift::actix_web::HttpResponse::Ok()
        }

        #[turbolift::tracing::instrument]
        async fn turbolift_wrapper(turbolift::actix_web::web::Path((#untyped_params_tokens_with_run_id)): turbolift::actix_web::web::Path<(#param_types)>) -> impl turbolift::actix_web::Responder {
            turbolift::actix_web::HttpResponse::Ok()
                .json(#function_name(#untyped_params_tokens))
        }

        #[turbolift::tracing::instrument]
        fn main() {
            turbolift::actix_web::rt::System::new("main".to_string())
                .block_on(async move {
                    let args: Vec<String> = std::env::args().collect();
                    let ip_and_port = &args[1];
                    turbolift::tracing::info!("service main() started. ip_and_port parsed.");
                    turbolift::actix_web::HttpServer::new(
                        ||
                        turbolift::actix_web::App::new()
                            .route(
                                #wrapper_route, turbolift::actix_web::web::get().to(turbolift_wrapper)
                            )
                            .route(
                                "/health-probe", turbolift::actix_web::web::get().to(health_probe)
                            )
                            .default_service(
                                turbolift::actix_web::web::get()
                                    .to(
                                        |req: turbolift::actix_web::HttpRequest|
                                        turbolift::actix_web::HttpResponse::NotFound().body(
                                            format!("endpoint not found: {}", req.uri())
                                        )
                                    )
                            )
                    )
                    .bind(ip_and_port)?
                    .run()
                    .compat()
                    .await
                }).unwrap();
        }
    };

    // --- Materialise a standalone copy of the project in the build cache -------------
    let function_cache_proj_path = CACHE_PATH.join(&original_target_function_name);
    fs::create_dir_all(&function_cache_proj_path)
        .expect("could not create function build cache directory");
    let files_to_copy: Vec<PathBuf> = fs::read_dir(".")
        .expect("could not read dir")
        .map(|res| res.expect("could not read entry").path())
        // Skip the cache itself (avoids copying it into itself) and build artifacts.
        .filter(|path| path.file_name() != CACHE_PATH.file_name())
        .filter(|path| path.to_str() != Some("./target"))
        .collect();
    fs_extra::copy_items(
        &files_to_copy,
        &function_cache_proj_path,
        &fs_extra::dir::CopyOptions {
            overwrite: true,
            ..Default::default()
        },
    )
    .expect("error copying items to build cache");

    let target_main_file = function_cache_proj_path.join("src").join("main.rs");
    fs::write(target_main_file, main_file.to_string()).expect("error editing project main.rs");

    let project_dir = fs::canonicalize(".").expect("could not canonicalize path to project dir");
    build_project::edit_cargo_file(
        project_dir.as_path(),
        &function_cache_proj_path.join("Cargo.toml"),
        &original_target_function_name,
    )
    .expect("error editing cargo file");

    // Lint failures are deliberately non-fatal: report them and keep going.
    if let Err(e) = build_project::lint(&function_cache_proj_path) {
        tracing::error!(
            error = e.as_ref() as &(dyn std::error::Error + 'static),
            "ignoring linting error"
        );
    }

    // --- Embed the compressed project into the caller --------------------------------
    let project_source_binary = {
        let tar = extract_function::make_compressed_proj_src(&function_cache_proj_path);
        let tar_file = CACHE_PATH.join(format!("{}_source.tar", original_target_function_name));
        fs::write(&tar_file, tar).expect("failure writing bin");
        let tar_file = tar_file
            .canonicalize()
            .expect("error canonicalizing tar file location");
        let tar_file_str = tar_file
            .to_str()
            .expect("failure converting file path to str");
        // Interpolating a `&str` makes `quote` emit a correctly escaped string literal.
        // Backslashes (Windows paths) or quotes in the path therefore can't corrupt the
        // generated tokens, as they could when the path was `format!`-ed into source text.
        q! { std::include_bytes!(#tar_file_str) }
    };

    // --- Replace the annotated function with the declare-and-dispatch client ---------
    let declare_and_dispatch = q! {
        extern crate turbolift;

        #[turbolift::tracing::instrument]
        async fn #original_target_function_ident(#typed_params) ->
            turbolift::DistributionResult<#result_type> {
            use turbolift::distributed_platform::DistributionPlatform;
            use turbolift::tokio_compat_02::FutureExt;

            let mut platform = #distribution_platform.lock().await;

            if !platform.has_declared(#original_target_function_name) {
                platform
                    .declare(#original_target_function_name, #project_source_binary)
                    .compat()
                    .await?;
            }

            let params = #params_vec.join("/");
            let resp_string = platform
                .dispatch(
                    #original_target_function_name,
                    params
                )
                .compat()
                .await?;
            Ok(turbolift::serde_json::from_str(&resp_string)?)
        }
    };
    declare_and_dispatch.into()
}
237
238#[cfg(not(feature = "distributed"))]
239#[proc_macro_attribute]
240pub fn on(_distribution_platform: TokenStream, function_: TokenStream) -> TokenStream {
241 let function = TokenStream2::from(function_);
243 let mut wrapped_original_function = extract_function::get_fn_item(function);
244 let original_target_function_ident = wrapped_original_function.sig.ident.clone();
245 let signature = wrapped_original_function.sig.clone();
246 let typed_params = signature.inputs;
247 let untyped_params = extract_function::to_untyped_params(typed_params.clone());
248 let output_type = extract_function::get_result_type(&signature.output);
249 wrapped_original_function.sig.ident = Ident::new("wrapped_function", Span::call_site());
250
251 let async_function = q! {
252 extern crate turbolift;
253
254 #[turbolift::tracing::instrument]
255 async fn #original_target_function_ident(#typed_params) -> turbolift::DistributionResult<#output_type> {
256 #wrapped_original_function
257 Ok(wrapped_function(#untyped_params))
258 }
259 };
260 async_function.into()
261}
262
263#[proc_macro_attribute]
264pub fn with(_attr: TokenStream, _item: TokenStream) -> TokenStream {
265 unimplemented!()
266}