switchboard_solana_macros/
lib.rs

1extern crate proc_macro;
2
3mod params;
4mod utils;
5
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{FnArg, ItemFn, Result as SynResult, ReturnType, Type};
9
10#[proc_macro_attribute]
11pub fn switchboard_function(attr: TokenStream, item: TokenStream) -> TokenStream {
12    // Parse the macro parameters to set a timeout
13    let macro_params = match syn::parse::<params::SwitchboardSolanaFunctionArgs>(attr.clone()) {
14        Ok(args) => args,
15        Err(err) => {
16            let e = syn::Error::new_spanned(
17                err.to_compile_error(),
18                format!("Failed to parse macro parameters: {:?}", err),
19            );
20
21            return e.to_compile_error().into();
22        }
23    };
24
25    // Try to build the token stream, return errors if failed
26    match build_token_stream(macro_params, item) {
27        Ok(token_stream) => token_stream,
28        Err(err) => err.to_compile_error().into(),
29    }
30}
31
32/// Validates whether the first param is &mut FunctionRunner
33fn validate_function_runner_param_mut_ref(input: &ItemFn) -> SynResult<()> {
34    // Extract the first parameter of the function.
35    let first_param_type = input.sig.inputs.iter().next().ok_or_else(|| {
36        syn::Error::new_spanned(
37            &input.sig,
38            "The switchboard_function must take at least one parameter",
39        )
40    })?;
41
42    // Ensure the first parameter is a typed argument.
43    let typed_arg = match first_param_type {
44        FnArg::Typed(typed) => typed,
45        _ => {
46            return Err(syn::Error::new_spanned(
47                first_param_type,
48                "Expected a typed parameter",
49            ));
50        }
51    };
52
53    // Check if the first parameter is a mutable reference to FunctionRunner.
54    let is_function_runner_param = if let Type::Reference(type_reference) = &*typed_arg.ty {
55        if let Type::Path(type_path) = &*type_reference.elem {
56            type_reference.mutability.is_some() && // Check for mutability
57                type_path.path.is_ident("FunctionRunner")
58        } else {
59            false
60        }
61    } else {
62        false
63    };
64
65    if !is_function_runner_param {
66        return Err(syn::Error::new_spanned(
67            &typed_arg.ty,
68            "First parameter must be of type `&mut FunctionRunner`",
69        ));
70    }
71
72    Ok(())
73}
74
75/// Validates whether the first param is Arc<FunctionRunner>
76fn validate_function_runner_param_arc(input: &ItemFn) -> SynResult<()> {
77    // Extract the first parameter of the function.
78    let first_param_type = input.sig.inputs.iter().next().ok_or_else(|| {
79        syn::Error::new_spanned(
80            &input.sig,
81            "The switchboard_function must take at least one parameter",
82        )
83    })?;
84
85    // Ensure the first parameter is a typed argument.
86    let typed_arg = match first_param_type {
87        FnArg::Typed(typed) => typed,
88        _ => {
89            return Err(syn::Error::new_spanned(
90                first_param_type,
91                "Expected a typed parameter",
92            ));
93        }
94    };
95
96    // Extract the inner type from the expected `Arc`.
97    let inner_type = utils::extract_inner_type_from_arc(&typed_arg.ty).ok_or_else(|| {
98        syn::Error::new_spanned(
99            &typed_arg.ty,
100            "Parameter must be of type Arc<FunctionRunner>",
101        )
102    })?;
103
104    // Check that the inner type of `Arc` is `FunctionRunner`.
105    let is_function_runner = if let Type::Path(type_path) = inner_type {
106        &type_path.path.segments.last().unwrap().ident == "FunctionRunner"
107    } else {
108        false
109    };
110
111    if !is_function_runner {
112        return Err(syn::Error::new_spanned(
113            &typed_arg.ty,
114            "Parameter inside Arc must be of type FunctionRunner",
115        ));
116    }
117
118    Ok(())
119}
120
121/// Validates whether the first param is FunctionRunner
122fn validate_function_runner_param(input: &ItemFn) -> SynResult<()> {
123    // Extract the first parameter of the function.
124    let first_param_type = input.sig.inputs.iter().next().ok_or_else(|| {
125        syn::Error::new_spanned(
126            &input.sig,
127            "The switchboard_function must take at least one parameter",
128        )
129    })?;
130
131    let typed_arg = match first_param_type {
132        FnArg::Typed(typed) => typed,
133        _ => {
134            return Err(syn::Error::new_spanned(
135                first_param_type,
136                "Expected a typed parameter",
137            ));
138        }
139    };
140
141    let inner_type = utils::extract_inner_type_from_arc(&typed_arg.ty).ok_or_else(|| {
142        syn::Error::new_spanned(
143            &typed_arg.ty,
144            "Parameter must be of type Arc<FunctionRunner>",
145        )
146    })?;
147
148    let is_function_runner = if let Type::Path(type_path) = inner_type {
149        &type_path.path.segments.last().unwrap().ident == "FunctionRunner"
150    } else {
151        false
152    };
153
154    if !is_function_runner {
155        return Err(syn::Error::new_spanned(
156            &typed_arg.ty,
157            "Parameter must be an Arc<FunctionRunner>",
158        ));
159    }
160
161    Ok(())
162}
163
164/// Helper function to validate the return type is a Result with the correct arguments.
165fn validate_function_return_type(input: &ItemFn) -> SynResult<()> {
166    let ty = match &input.sig.output {
167        ReturnType::Type(_, ty) => ty,
168        ReturnType::Default => {
169            return Err(syn::Error::new_spanned(
170                &input.sig.output,
171                "Function does not have a return type",
172            ));
173        }
174    };
175
176    let (ok_type, err_type) = utils::extract_result_args(ty).ok_or_else(|| {
177        syn::Error::new_spanned(&input.sig.output, "Return type must be a Result")
178    })?;
179
180    // Validate the inner Vec type
181    let inner_vec_type = utils::extract_inner_type_from_vec(ok_type).ok_or_else(|| {
182        syn::Error::new_spanned(
183            &input.sig.output,
184            "Ok variant of Result must be a Vec<Instruction>",
185        )
186    })?;
187
188    if !matches!(inner_vec_type, Type::Path(t) if t.path.is_ident("Instruction")) {
189        return Err(syn::Error::new_spanned(
190            &input.sig.output,
191            "Ok variant of Result must be a Vec<Instruction>",
192        ));
193    }
194
195    // Validate the error type
196    let error_type_path_segments = match err_type {
197        Type::Path(type_path) => &type_path.path.segments,
198        _ => {
199            return Err(syn::Error::new_spanned(
200                err_type,
201                "Error type must be a path type",
202            ));
203        }
204    };
205
206    // Check if the error type is SbFunctionError or switchboard_common::SbFunctionError
207    let is_sb_function_error = match error_type_path_segments.last() {
208        Some(last_segment) if last_segment.ident == "SbFunctionError" => true,
209        Some(last_segment) if last_segment.ident == "Error" => {
210            // If the last segment is "Error", check the preceding segment for "switchboard_common"
211            error_type_path_segments.len() > 1
212                && error_type_path_segments[error_type_path_segments.len() - 2].ident
213                    == "switchboard_common"
214        }
215        _ => false,
216    };
217
218    if !is_sb_function_error {
219        return Err(syn::Error::new_spanned(
220            &input.sig.output,
221            "The error variant in the Result return type should be SbFunctionError",
222        ));
223    }
224
225    Ok(())
226}
227
228fn validate_second_parameter(input: &ItemFn) -> SynResult<()> {
229    let second_param = input.sig.inputs.iter().nth(1).ok_or_else(|| {
230        syn::Error::new_spanned(
231            &input.sig,
232            "The switchboard_function must take two parameters",
233        )
234    })?;
235
236    let typed_arg = match second_param {
237        FnArg::Typed(typed) => typed,
238        _ => {
239            return Err(syn::Error::new_spanned(
240                second_param,
241                "Expected a typed second parameter",
242            ));
243        }
244    };
245
246    // Use the utility function to extract the inner type from a Vec
247    let inner_type = utils::extract_inner_type_from_vec(&typed_arg.ty).ok_or_else(|| {
248        syn::Error::new_spanned(
249            &typed_arg.ty,
250            "The second parameter must be of type Vec<u8>",
251        )
252    })?;
253
254    // Ensure the inner type of the Vec is u8
255    if let Type::Path(type_path) = inner_type {
256        if !type_path.path.is_ident("u8") {
257            return Err(syn::Error::new_spanned(
258                &typed_arg.ty,
259                "The second parameter must be of type Vec<u8>",
260            ));
261        }
262    } else {
263        return Err(syn::Error::new_spanned(
264            &typed_arg.ty,
265            "The second parameter must be of type Vec<u8>",
266        ));
267    }
268
269    Ok(())
270}
271
272fn build_token_stream(
273    _params: params::SwitchboardSolanaFunctionArgs,
274    item: TokenStream,
275) -> SynResult<TokenStream> {
276    let input: ItemFn = syn::parse(item.clone())?;
277    let function_name = &input.sig.ident;
278
279    // Validate that there's exactly one input of the correct type
280    if input.sig.inputs.len() != 2 {
281        return Err(
282            syn::Error::new_spanned(
283                &input.sig,
284                "The switchboard_function must take exactly one parameter of type 'Arc<FunctionRunner>' and 'Vec<u8>'"
285            )
286        );
287    }
288
289    validate_function_return_type(&input)?;
290
291    // Validate input parameters
292    // validate_function_runner_param_arc(&input)?;
293    validate_function_runner_param(&input)?;
294    validate_second_parameter(&input)?;
295
296    let expanded = quote! {
297            use switchboard_solana::prelude::*;
298
299            // Include the original function definition
300            #input
301
302            pub type SwitchboardFunctionResult<T> = std::result::Result<T, SbFunctionError>;
303
304            /// Run an async function and catch any panics
305            pub async fn run_switchboard_function<F, T>(
306                logic: F,
307            ) -> SwitchboardFunctionResult<()>
308            where
309                F: Fn(Arc<FunctionRunner>, Vec<u8>) -> T + Send + 'static,
310                T: futures::Future<Output = SwitchboardFunctionResult<Vec<Instruction>>>
311                    + Send,
312            {
313                // Initialize the runner
314                let mut runner = FunctionRunner::from_env(None).unwrap();
315
316                // Lets pre-load all of the accounts we'll need to yield our container parameters
317                runner.load_accounts().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
318
319                // Parse the container parameters based on our loaded accounts
320                let params = runner.load_params().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
321                let runner = Arc::new(runner);
322                /// TODO:
323                let commitment = None;
324                match logic(runner.clone(), params).await {
325                    Ok(ixs) => {
326                        runner
327                            .emit(ixs, Some(commitment.unwrap_or(solana_sdk::commitment_config::CommitmentConfig::confirmed())))
328                            .await
329                            .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
330
331                        Ok(())
332                    }
333                    Err(e) => {
334                        println!("Error: Switchboard function failed with error code: {:?}", e);
335                        let mut err_code = 199;
336                        if let SbFunctionError::FunctionError(code) = e {
337                            err_code = code;
338                        }
339                        runner
340                            .emit_error(err_code, None)
341                            .await
342                            .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
343
344                        Ok(())
345                    }
346                }
347            }
348
349            /// Run an async function and catch any panics
350            // Maintain separate implementations for better peace of mind.
351            pub async fn run_switchboard_function_simulation<F, T>(
352                logic: F,
353            ) -> SwitchboardFunctionResult<()>
354            where
355                F: Fn(Arc<FunctionRunner>, Vec<u8>) -> T + Send + 'static,
356                T: futures::Future<Output = SwitchboardFunctionResult<Vec<Instruction>>>
357                    + Send,
358            {
359                // Initialize the runner
360                let mut runner = FunctionRunner::from_env(None).unwrap();
361
362                // Lets pre-load all of the accounts we'll need to yield our container parameters
363                runner.load_accounts().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
364
365                // Parse the container parameters based on our loaded accounts
366                let params = runner.load_params().await.map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
367
368                let runner = Arc::new(runner);
369                match logic(runner.clone(), params).await {
370                    Ok(ixs) => {
371                        match runner.get_function_result(ixs.clone(), 0, None).await {
372                            Ok(function_result) => {
373                                let serialized_output = format!(
374                                    "{}{}",
375                                    FUNCTION_RESULT_PREFIX,
376                                    function_result.hex_encode()
377                                );
378
379                                println!("\n## Output\n{}", serialized_output);
380                                println!("\n## Instructions\n{:#?}", ixs.clone());
381                            }
382                            Err(e) => {
383                                panic!("Failed to get FunctionResult from ixs: {:?}", e);
384                            }
385                        }
386
387                        Ok(())
388                    }
389                    Err(e) => {
390                        println!("Error: Switchboard function failed with error code: {:?}", e);
391                        let mut err_code = 199;
392                        if let SbFunctionError::FunctionError(code) = e {
393                            err_code = code;
394                        }
395                        runner
396                            .emit_error(err_code, None)
397                            .await
398                            .map_err(|_e| SbFunctionError::FunctionResultEmitError)?;
399
400                        Ok(())
401                    }
402                }
403            }
404
405            #[tokio::main(worker_threads = 12)]
406            async fn main() -> SwitchboardFunctionResult<()> {
407                let is_simulation = match std::env::var("SWITCHBOARD_FUNCTION_SIMULATION") {
408                    Ok(value) => {
409                        let value = value.to_lowercase().trim().to_string();
410                        value == "1" || value == "true"
411                    }
412                    Err(_) => false,
413                };
414
415                if is_simulation {
416                    println!("[Debug] Simulation mode detected");
417                    #[cfg(feature = "dotenv")]
418                    dotenvy::dotenv().ok();
419
420
421
422                    run_switchboard_function_simulation(#function_name).await?;
423                } else {
424                    run_switchboard_function(#function_name).await?;
425                }
426
427
428                Ok(())
429            }
430    };
431
432    Ok(TokenStream::from(expanded))
433}
434
435#[proc_macro_attribute]
436pub fn sb_error(_attr: TokenStream, item: TokenStream) -> TokenStream {
437    let input = syn::parse_macro_input!(item as syn::DeriveInput);
438
439    let name = &input.ident;
440    let expanded = quote! {
441        #[derive(Clone, Copy, Debug, PartialEq)]
442        #[repr(u8)]
443        #input
444
445        impl From<#name> for SbFunctionError {
446            fn from(item: #name) -> Self {
447                SbFunctionError::FunctionError(item as u8 + 1)
448            }
449        }
450
451        impl From<#name> for u8 {
452            fn from(item: #name) -> Self {
453                item as u8 + 1
454            }
455        }
456
457        impl std::fmt::Display for #name {
458            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
459                write!(f, "{:?}", self)
460            }
461        }
462
463        impl std::error::Error for #name {}
464    };
465
466    TokenStream::from(expanded)
467}