switchboard_solana_macros/
lib.rs1extern crate proc_macro;
2
3mod params;
4mod utils;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{FnArg, ItemFn, Result as SynResult, ReturnType, Type};
9
10#[proc_macro_attribute]
11pub fn switchboard_function(attr: TokenStream, item: TokenStream) -> TokenStream {
12 let macro_params = match syn::parse::<params::SwitchboardSolanaFunctionArgs>(attr.clone()) {
14 Ok(args) => args,
15 Err(err) => {
16 let e = syn::Error::new_spanned(
17 err.to_compile_error(),
18 format!("Failed to parse macro parameters: {:?}", err),
19 );
20
21 return e.to_compile_error().into();
22 }
23 };
24
25 match build_token_stream(macro_params, item) {
27 Ok(token_stream) => token_stream,
28 Err(err) => err.to_compile_error().into(),
29 }
30}
31
32fn validate_function_runner_param_mut_ref(input: &ItemFn) -> SynResult<()> {
34 let first_param_type = input.sig.inputs.iter().next().ok_or_else(|| {
36 syn::Error::new_spanned(
37 &input.sig,
38 "The switchboard_function must take at least one parameter",
39 )
40 })?;
41
42 let typed_arg = match first_param_type {
44 FnArg::Typed(typed) => typed,
45 _ => {
46 return Err(syn::Error::new_spanned(
47 first_param_type,
48 "Expected a typed parameter",
49 ));
50 }
51 };
52
53 let is_function_runner_param = if let Type::Reference(type_reference) = &*typed_arg.ty {
55 if let Type::Path(type_path) = &*type_reference.elem {
56 type_reference.mutability.is_some() && type_path.path.is_ident("FunctionRunner")
58 } else {
59 false
60 }
61 } else {
62 false
63 };
64
65 if !is_function_runner_param {
66 return Err(syn::Error::new_spanned(
67 &typed_arg.ty,
68 "First parameter must be of type `&mut FunctionRunner`",
69 ));
70 }
71
72 Ok(())
73}
74
75fn validate_function_runner_param_arc(input: &ItemFn) -> SynResult<()> {
77 let first_param_type = input.sig.inputs.iter().next().ok_or_else(|| {
79 syn::Error::new_spanned(
80 &input.sig,
81 "The switchboard_function must take at least one parameter",
82 )
83 })?;
84
85 let typed_arg = match first_param_type {
87 FnArg::Typed(typed) => typed,
88 _ => {
89 return Err(syn::Error::new_spanned(
90 first_param_type,
91 "Expected a typed parameter",
92 ));
93 }
94 };
95
96 let inner_type = utils::extract_inner_type_from_arc(&typed_arg.ty).ok_or_else(|| {
98 syn::Error::new_spanned(
99 &typed_arg.ty,
100 "Parameter must be of type Arc<FunctionRunner>",
101 )
102 })?;
103
104 let is_function_runner = if let Type::Path(type_path) = inner_type {
106 &type_path.path.segments.last().unwrap().ident == "FunctionRunner"
107 } else {
108 false
109 };
110
111 if !is_function_runner {
112 return Err(syn::Error::new_spanned(
113 &typed_arg.ty,
114 "Parameter inside Arc must be of type FunctionRunner",
115 ));
116 }
117
118 Ok(())
119}
120
121fn validate_function_runner_param(input: &ItemFn) -> SynResult<()> {
123 let first_param_type = input.sig.inputs.iter().next().ok_or_else(|| {
125 syn::Error::new_spanned(
126 &input.sig,
127 "The switchboard_function must take at least one parameter",
128 )
129 })?;
130
131 let typed_arg = match first_param_type {
132 FnArg::Typed(typed) => typed,
133 _ => {
134 return Err(syn::Error::new_spanned(
135 first_param_type,
136 "Expected a typed parameter",
137 ));
138 }
139 };
140
141 let inner_type = utils::extract_inner_type_from_arc(&typed_arg.ty).ok_or_else(|| {
142 syn::Error::new_spanned(
143 &typed_arg.ty,
144 "Parameter must be of type Arc<FunctionRunner>",
145 )
146 })?;
147
148 let is_function_runner = if let Type::Path(type_path) = inner_type {
149 &type_path.path.segments.last().unwrap().ident == "FunctionRunner"
150 } else {
151 false
152 };
153
154 if !is_function_runner {
155 return Err(syn::Error::new_spanned(
156 &typed_arg.ty,
157 "Parameter must be an Arc<FunctionRunner>",
158 ));
159 }
160
161 Ok(())
162}
163
164fn validate_function_return_type(input: &ItemFn) -> SynResult<()> {
166 let ty = match &input.sig.output {
167 ReturnType::Type(_, ty) => ty,
168 ReturnType::Default => {
169 return Err(syn::Error::new_spanned(
170 &input.sig.output,
171 "Function does not have a return type",
172 ));
173 }
174 };
175
176 let (ok_type, err_type) = utils::extract_result_args(ty).ok_or_else(|| {
177 syn::Error::new_spanned(&input.sig.output, "Return type must be a Result")
178 })?;
179
180 let inner_vec_type = utils::extract_inner_type_from_vec(ok_type).ok_or_else(|| {
182 syn::Error::new_spanned(
183 &input.sig.output,
184 "Ok variant of Result must be a Vec<Instruction>",
185 )
186 })?;
187
188 if !matches!(inner_vec_type, Type::Path(t) if t.path.is_ident("Instruction")) {
189 return Err(syn::Error::new_spanned(
190 &input.sig.output,
191 "Ok variant of Result must be a Vec<Instruction>",
192 ));
193 }
194
195 let error_type_path_segments = match err_type {
197 Type::Path(type_path) => &type_path.path.segments,
198 _ => {
199 return Err(syn::Error::new_spanned(
200 err_type,
201 "Error type must be a path type",
202 ));
203 }
204 };
205
206 let is_sb_function_error = match error_type_path_segments.last() {
208 Some(last_segment) if last_segment.ident == "SbFunctionError" => true,
209 Some(last_segment) if last_segment.ident == "Error" => {
210 error_type_path_segments.len() > 1
212 && error_type_path_segments[error_type_path_segments.len() - 2].ident
213 == "switchboard_common"
214 }
215 _ => false,
216 };
217
218 if !is_sb_function_error {
219 return Err(syn::Error::new_spanned(
220 &input.sig.output,
221 "The error variant in the Result return type should be SbFunctionError",
222 ));
223 }
224
225 Ok(())
226}
227
228fn validate_second_parameter(input: &ItemFn) -> SynResult<()> {
229 let second_param = input.sig.inputs.iter().nth(1).ok_or_else(|| {
230 syn::Error::new_spanned(
231 &input.sig,
232 "The switchboard_function must take two parameters",
233 )
234 })?;
235
236 let typed_arg = match second_param {
237 FnArg::Typed(typed) => typed,
238 _ => {
239 return Err(syn::Error::new_spanned(
240 second_param,
241 "Expected a typed second parameter",
242 ));
243 }
244 };
245
246 let inner_type = utils::extract_inner_type_from_vec(&typed_arg.ty).ok_or_else(|| {
248 syn::Error::new_spanned(
249 &typed_arg.ty,
250 "The second parameter must be of type Vec<u8>",
251 )
252 })?;
253
254 if let Type::Path(type_path) = inner_type {
256 if !type_path.path.is_ident("u8") {
257 return Err(syn::Error::new_spanned(
258 &typed_arg.ty,
259 "The second parameter must be of type Vec<u8>",
260 ));
261 }
262 } else {
263 return Err(syn::Error::new_spanned(
264 &typed_arg.ty,
265 "The second parameter must be of type Vec<u8>",
266 ));
267 }
268
269 Ok(())
270}
271
272fn build_token_stream(
273 _params: params::SwitchboardSolanaFunctionArgs,
274 item: TokenStream,
275) -> SynResult<TokenStream> {
276 let input: ItemFn = syn::parse(item.clone())?;
277 let function_name = &input.sig.ident;
278
279 if input.sig.inputs.len() != 2 {
281 return Err(
282 syn::Error::new_spanned(
283 &input.sig,
284 "The switchboard_function must take exactly one parameter of type 'Arc<FunctionRunner>' and 'Vec<u8>'"
285 )
286 );
287 }
288
289 validate_function_return_type(&input)?;
290
291 validate_function_runner_param(&input)?;
294 validate_second_parameter(&input)?;
295
296 let expanded = quote! {
297 use switchboard_solana::prelude::*;
298
299 #input
301
302 pub type SwitchboardFunctionResult<T> = std::result::Result<T, SbFunctionError>;
303
304 pub async fn run_switchboard_function<F, T>(
306 logic: F,
307 ) -> SwitchboardFunctionResult<()>
308 where
309 F: Fn(Arc<FunctionRunner>, Vec<u8>) -> T + Send + 'static,
310 T: futures::Future<Output = SwitchboardFunctionResult<Vec<Instruction>>>
311 + Send,
312 {
313 let mut runner = FunctionRunner::from_env(None).unwrap();
315
316 runner.load_accounts().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
318
319 let params = runner.load_params().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
321 let runner = Arc::new(runner);
322 let commitment = None;
324 match logic(runner.clone(), params).await {
325 Ok(ixs) => {
326 runner
327 .emit(ixs, Some(commitment.unwrap_or(solana_sdk::commitment_config::CommitmentConfig::confirmed())))
328 .await
329 .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
330
331 Ok(())
332 }
333 Err(e) => {
334 println!("Error: Switchboard function failed with error code: {:?}", e);
335 let mut err_code = 199;
336 if let SbFunctionError::FunctionError(code) = e {
337 err_code = code;
338 }
339 runner
340 .emit_error(err_code, None)
341 .await
342 .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
343
344 Ok(())
345 }
346 }
347 }
348
349 pub async fn run_switchboard_function_simulation<F, T>(
352 logic: F,
353 ) -> SwitchboardFunctionResult<()>
354 where
355 F: Fn(Arc<FunctionRunner>, Vec<u8>) -> T + Send + 'static,
356 T: futures::Future<Output = SwitchboardFunctionResult<Vec<Instruction>>>
357 + Send,
358 {
359 let mut runner = FunctionRunner::from_env(None).unwrap();
361
362 runner.load_accounts().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
364
365 let params = runner.load_params().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
367
368 let runner = Arc::new(runner);
369 match logic(runner.clone(), params).await {
370 Ok(ixs) => {
371 match runner.get_function_result(ixs.clone(), 0, None).await {
372 Ok(function_result) => {
373 let serialized_output = format!(
374 "{}{}",
375 FUNCTION_RESULT_PREFIX,
376 function_result.hex_encode()
377 );
378
379 println!("\n## Output\n{}", serialized_output);
380 println!("\n## Instructions\n{:#?}", ixs.clone());
381 }
382 Err(e) => {
383 panic!("Failed to get FunctionResult from ixs: {:?}", e);
384 }
385 }
386
387 Ok(())
388 }
389 Err(e) => {
390 println!("Error: Switchboard function failed with error code: {:?}", e);
391 let mut err_code = 199;
392 if let SbFunctionError::FunctionError(code) = e {
393 err_code = code;
394 }
395 runner
396 .emit_error(err_code, None)
397 .await
398 .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
399
400 Ok(())
401 }
402 }
403 }
404
405 #[tokio::main(worker_threads = 12)]
406 async fn main() -> SwitchboardFunctionResult<()> {
407 let is_simulation = match std::env::var("SWITCHBOARD_FUNCTION_SIMULATION") {
408 Ok(value) => {
409 let value = value.to_lowercase().trim().to_string();
410 value == "1" || value == "true"
411 }
412 Err(_) => false,
413 };
414
415 if is_simulation {
416 println!("[Debug] Simulation mode detected");
417 #[cfg(feature = "dotenv")]
418 dotenvy::dotenv().ok();
419
420
421
422 run_switchboard_function_simulation(#function_name).await?;
423 } else {
424 run_switchboard_function(#function_name).await?;
425 }
426
427
428 Ok(())
429 }
430 };
431
432 Ok(TokenStream::from(expanded))
433}
434
435#[proc_macro_attribute]
436pub fn sb_error(_attr: TokenStream, item: TokenStream) -> TokenStream {
437 let input = syn::parse_macro_input!(item as syn::DeriveInput);
438
439 let name = &input.ident;
440 let expanded = quote! {
441 #[derive(Clone, Copy, Debug, PartialEq)]
442 #[repr(u8)]
443 #input
444
445 impl From<#name> for SbFunctionError {
446 fn from(item: #name) -> Self {
447 SbFunctionError::FunctionError(item as u8 + 1)
448 }
449 }
450
451 impl From<#name> for u8 {
452 fn from(item: #name) -> Self {
453 item as u8 + 1
454 }
455 }
456
457 impl std::fmt::Display for #name {
458 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
459 write!(f, "{:?}", self)
460 }
461 }
462
463 impl std::error::Error for #name {}
464 };
465
466 TokenStream::from(expanded)
467}