Skip to main content

arcium_macros/
lib.rs

1use arcium_program_macro::arcium_program_macro;
2use callback_macros::{callback_accs_derive, callback_ix_derive, CallbackAccArgs};
3use check_args_macro::check_args_fn;
4use init_comp_def_macros::{init_comp_def_derive, InitCompDefArgs};
5use proc_macro::TokenStream;
6use queue_comp_macros::{queue_comp_derive, QueueCompArgs};
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput, Item, ItemFn, LitStr};
9use utils::{read_circuit_hash, ArciumCallbackArgs};
10
11mod arcium_program_macro;
12mod callback_macros;
13mod check_args_macro;
14mod gen_callback_types;
15mod init_comp_def_macros;
16mod queue_comp_macros;
17mod utils;
18mod validation;
19
20/// Automatically generates the callback instruction for a computation. The callback function must
21/// be named `<encrypted_ix>_callback` and take a single `SignedComputationOutputs<T>` argument in
22/// addition to the `Context` parameter.
23///
24/// The generic type parameter for `SignedComputationOutputs<T>` is automatically generated from
25/// your circuit's interface file (`build/<circuit_name>.idarc`). The generated type follows the
26/// naming convention `<CircuitName>Output` (e.g., circuit "my_circuit" generates
27/// `MyCircuitOutput`).
28///
29/// ```ignore
30/// #[callback_accounts("my_circuit")]
31/// #[derive(Accounts)]
32/// pub struct Callback<'info> {
33///     #[account(mut)]
34///     pub payer: Signer<'info>,
35///     pub arcium_program: Program<'info, Arcium>,
36///     #[account(
37///         address = derive_comp_def_pda!(COMP_DEF_OFFSET)
38///     )]
39///     pub comp_def_account: Account<'info, ComputationDefinitionAccount>,
40///     #[account(address = ::arcium_anchor::solana_instructions_sysvar::ID)]
41///     /// CHECK: instructions_sysvar, checked by the account constraint
42///     pub instructions_sysvar: UncheckedAccount<'info>,
43/// }
44///
45/// #[arcium_program]
46/// pub mod sample_program {
47///     // Will be called when the computation with circuit "my_circuit" resolves
48///     #[arcium_callback(encrypted_ix = "my_circuit")]
49///     pub fn my_circuit_callback(
50///         ctx: Context<Callback>,
51///         output: SignedComputationOutputs<MyCircuitOutput>,
52///     ) -> Result<()> {
53///         // Destructure and handle Success/Failure
54///         let result = match output {
55///             SignedComputationOutputs::Success(MyCircuitOutput { field_0 }) => field_0,
56///             SignedComputationOutputs::Failure => {
57///                 return Err(ErrorCode::ComputationFailed.into());
58///             }
59///         };
60///         msg!("Computation succeeded with result: {:?}", result);
61///         Ok(())
62///     }
63/// }
64/// ```
65#[proc_macro_attribute]
66pub fn arcium_callback(args: TokenStream, item: TokenStream) -> TokenStream {
67    let args = parse_macro_input!(args as ArciumCallbackArgs);
68    let input_fn = parse_macro_input!(item as ItemFn);
69    callback_ix_derive(input_fn, args)
70}
71
72/// Validates the structure for queuing computations by checking the encrypted instruction exists,
73/// validating required account fields, and implementing the `QueueCompAccs` trait to make using it
74/// with queuing computations easy:
75///
76/// ```ignore
77/// #[queue_computation_accounts("add_together", payer)]
78/// #[derive(Accounts)]
79/// pub struct InitComputation<'info> {
80///     #[account(mut)]
81///     pub payer: Signer<'info>,
82///     #[account(
83///         address = derive_mxe_pda!()
84///     )]
85///     pub mxe_account: Account<'info, MXEAccount>,
86///     #[account(
87///         mut,
88///         address = derive_mempool_pda!()
89///     )]
90///     pub mempool_account: Account<'info, Mempool>,
91///     #[account(
92///         mut,
93///         address = derive_execpool_pda!()
94///     )]
95///     pub executing_pool: Account<'info, ExecutingPool>,
96///     #[account(
97///         address = derive_comp_def_pda!(COMP_DEF_OFFSET)
98///     )]
99///     pub comp_def_account: Account<'info, ComputationDefinitionAccount>,
100///     #[account(
101///         mut,
102///         address = derive_cluster_pda!(mxe_account)
103///     )]
104///     pub cluster_account: Account<'info, Cluster>,
105///     #[account(
106///         mut,
107///         address = ARCIUM_FEE_POOL_ACCOUNT_ADDRESS,
108///     )]
109///     pub pool_account: Account<'info, FeePool>,
110///     #[account(
111///         mut,
112///         address = ARCIUM_CLOCK_ACCOUNT_ADDRESS
113///     )]
114///     pub clock_account: Account<'info, ClockAccount>,
115///     pub system_program: Program<'info, System>,
116///     pub arcium_program: Program<'info, Arcium>,
117/// }
118///
119/// #[arcium_program]
120/// pub mod sample_program {
121///     pub fn submit_computation(
122///         ctx: Context<InitComputation>,
123///         x: [u8; 32],
124///         y: [u8; 32],
125///         computation_offset: u64,
126///     ) -> Result<()> {
127///         // This will queue a computation that will execute "add_together" circuit
128///         let args = ArgBuilder::new()
129///             .encrypted_u8(x)
130///             .encrypted_u8(y)
131///             .build();
132///         // Parameters: accs, computation_offset, args,callback_instructions, num_callback_txs, cu_price_micro
133///         queue_computation(
134///             &ctx.accounts,
135///             computation_offset,
136///             args,
137///             vec![AddTogetherCallback::callback_ix(&[])],
138///             1,
139///             0,
140///         )?;
141///         Ok(())
142///     }
143/// }
144/// ```
145#[proc_macro_attribute]
146pub fn queue_computation_accounts(
147    args: proc_macro::TokenStream,
148    item: proc_macro::TokenStream,
149) -> proc_macro::TokenStream {
150    // Parse the input tokens into a syntax tree
151    let mut input = parse_macro_input!(item as DeriveInput);
152    let args = parse_macro_input!(args as QueueCompArgs);
153    queue_comp_derive(&mut input, args)
154}
155
156/// Validates the structure for computation callbacks by checking the encrypted instruction exists,
157/// validating required account fields, and ensuring the structure has the correct fields for
158/// callbacks:
159///
160/// ```ignore
161/// #[callback_accounts("my_circuit")]
162/// #[derive(Accounts)]
163/// pub struct Callback<'info> {
164///     #[account(mut)]
165///     pub payer: Signer<'info>,
166///     pub arcium_program: Program<'info, Arcium>,
167///     #[account(
168///         address = derive_comp_def_pda!(COMP_DEF_OFFSET)
169///     )]
170///     pub comp_def_account: Account<'info, ComputationDefinitionAccount>,
171///     #[account(address = ::arcium_anchor::solana_instructions_sysvar::ID)]
172///     /// CHECK: instructions_sysvar, checked by the account constraint
173///     pub instructions_sysvar: UncheckedAccount<'info>,
174/// }
175///
176/// #[arcium_program]
177/// pub mod sample_program {
178///     // Will be called when the computation with circuit "my_circuit" resolves
179///     #[arcium_callback(encrypted_ix = "my_circuit")]
180///     pub fn my_circuit_callback(
181///         ctx: Context<Callback>,
182///         output: SignedComputationOutputs<MyCircuitOutput>,
183///     ) -> Result<()> {
184///         // Destructure and handle Success/Failure
185///         let result = match output {
186///             SignedComputationOutputs::Success(MyCircuitOutput { field_0 }) => field_0,
187///             SignedComputationOutputs::Failure => {
188///                 return Err(ErrorCode::ComputationFailed.into());
189///             }
190///         };
191///         msg!("Computation succeeded with result: {:?}", result);
192///         Ok(())
193///     }
194/// }
195/// ```
196#[proc_macro_attribute]
197pub fn callback_accounts(
198    args: proc_macro::TokenStream,
199    item: proc_macro::TokenStream,
200) -> proc_macro::TokenStream {
201    // Parse the input tokens into a syntax tree
202    let input = parse_macro_input!(item as DeriveInput);
203    let args = parse_macro_input!(args as CallbackAccArgs);
204    callback_accs_derive(&input, args)
205}
206
207/// The #[arcium_program] attribute defines the module
208/// containing all instruction handlers defining all entries into a Solana program.
209/// Under the hood, it gets expanded to Anchor's `#[program]` and some additional definitions needed
210/// for Arcium.
211#[proc_macro_attribute]
212pub fn arcium_program(
213    _args: proc_macro::TokenStream,
214    item: proc_macro::TokenStream,
215) -> proc_macro::TokenStream {
216    // Parse the input tokens into a syntax tree
217    let mut input = parse_macro_input!(item as Item);
218    arcium_program_macro(&mut input)
219}
220
221/// Validates the structure for initializing computation definitions by checking the encrypted
222/// instruction exists, validating required account fields, and implementing the `InitCompDefAccs`
223/// trait to make using it with computation definitions easy:
224///
225/// ```ignore
226/// #[init_computation_definition_accounts("my_circuit", payer)]
227/// #[derive(Accounts)]
228/// pub struct InitMyCircuitCompDef<'info> {
229///     #[account(mut)]
230///     pub payer: Signer<'info>,
231///     #[account(
232///         mut,
233///         address = derive_mxe_pda!()
234///     )]
235///     pub mxe_account: Box<Account<'info, MXEAccount>>,
236///     #[account(mut)]
237///     /// CHECK: comp_def_account, checked by arcium program.
238///     /// Can't check it here as it's not initialized yet.
239///     pub comp_def_account: UncheckedAccount<'info>,
240///     pub arcium_program: Program<'info, Arcium>,
241///     pub system_program: Program<'info, System>,
242/// }
243/// ```
244#[proc_macro_attribute]
245pub fn init_computation_definition_accounts(args: TokenStream, item: TokenStream) -> TokenStream {
246    // Parse the macro arguments
247    let args = parse_macro_input!(args as InitCompDefArgs);
248
249    // Parse the input tokens into a syntax tree
250    let mut input = parse_macro_input!(item as DeriveInput);
251
252    // Call the inner function with the parsed arguments
253    init_comp_def_derive(&mut input, args)
254}
255
256/// Compile-time validation of computation arguments against an interface definition.
257///
258/// This macro provides **compile-time** verification that your computation arguments match the
259/// circuit interface, catching type mismatches and argument count errors before runtime.
260///
261/// ## Usage
262/// 1. Attach `#[check_args]` to your function
263/// 2. Attach `#[args("your_circuit_name")]` to your computation arguments (array literal or `vec!`
264///    macro)
265///
266/// ## Checks Performed
267/// - Correct number of arguments based on the interface file (`build/your_circuit_name.idarc`)
268/// - Correct argument types (e.g., `PlaintextU32`, `EncryptedU8`, `X25519Pubkey`)
269/// - Correct argument order matching the circuit interface
270///
271/// ## Compile-Time Errors
272/// If arguments don't match the interface, you'll see a compilation error like:
273/// ```text
274/// error: mismatched types: expected Parameter::PlaintextU64, found Parameter::PlaintextU32
275/// ```
276///
277/// ## Example
278/// ```ignore
279/// #[check_args]
280/// pub fn submit_computation(ctx: Context<MyAccounts>, computation_offset: u64) -> Result<()> {
281///     queue_computation(
282///         &ctx.accounts,
283///         computation_offset,
284///         #[args("add_together")]  // Validates against build/add_together.idarc
285///         ArgBuilder::new()
286///             .plaintext_u64(10)
287///             .plaintext_u64(20)
288///             .build(),
289///         None,
290///         vec![],
291///         1,
292///         0,
293///     )?;
294///     Ok(())
295/// }
296/// ```
297#[proc_macro_attribute]
298pub fn check_args(_args: TokenStream, item: TokenStream) -> TokenStream {
299    // Parse as a Stmt
300    let input = parse_macro_input!(item as ItemFn);
301    check_args_fn(input).into()
302}
303
304/// Returns the SHA-256 hash of a compiled circuit as a `[u8; 32]` array at compile time.
305///
306/// The hash is computed over the serialized circuit bytecode and can be used for verifying
307/// off-chain circuit sources. This is the same hash that ARX nodes verify when fetching circuits
308/// from off-chain sources.
309///
310/// ## Example
311/// ```ignore
312/// use arcium_macros::circuit_hash;
313///
314/// let source = OffChainCircuitSource {
315///     source: "https://ipfs.io/ipfs/Qm...".into(),
316///     hash: circuit_hash!("my_circuit"),
317/// };
318/// ```
319#[proc_macro]
320pub fn circuit_hash(input: TokenStream) -> TokenStream {
321    let circuit_name = parse_macro_input!(input as LitStr);
322    let hash = read_circuit_hash(&circuit_name.value());
323    let hash_bytes = hash.iter();
324    quote! { [#(#hash_bytes),*] }.into()
325}