arcium_macros/lib.rs
1use arcium_program_macro::arcium_program_macro;
2use callback_macros::{callback_accs_derive, callback_ix_derive, CallbackAccArgs};
3use check_args_macro::check_args_fn;
4use init_comp_def_macros::{init_comp_def_derive, InitCompDefArgs};
5use proc_macro::TokenStream;
6use queue_comp_macros::{queue_comp_derive, QueueCompArgs};
7use quote::quote;
8use syn::{parse_macro_input, DeriveInput, Item, ItemFn, LitStr};
9use utils::{read_circuit_hash, ArciumCallbackArgs};
10
11mod arcium_program_macro;
12mod callback_macros;
13mod check_args_macro;
14mod gen_callback_types;
15mod init_comp_def_macros;
16mod queue_comp_macros;
17mod utils;
18mod validation;
19
20/// Automatically generates the callback instruction for a computation. The callback function must
21/// be named `<encrypted_ix>_callback` and take a single `SignedComputationOutputs<T>` argument in
22/// addition to the `Context` parameter.
23///
24/// The generic type parameter for `SignedComputationOutputs<T>` is automatically generated from
25/// your circuit's interface file (`build/<circuit_name>.idarc`). The generated type follows the
26/// naming convention `<CircuitName>Output` (e.g., circuit "my_circuit" generates
27/// `MyCircuitOutput`).
28///
29/// ```ignore
30/// #[callback_accounts("my_circuit")]
31/// #[derive(Accounts)]
32/// pub struct Callback<'info> {
33/// #[account(mut)]
34/// pub payer: Signer<'info>,
35/// pub arcium_program: Program<'info, Arcium>,
36/// #[account(
37/// address = derive_comp_def_pda!(COMP_DEF_OFFSET)
38/// )]
39/// pub comp_def_account: Account<'info, ComputationDefinitionAccount>,
40/// #[account(address = ::arcium_anchor::solana_instructions_sysvar::ID)]
41/// /// CHECK: instructions_sysvar, checked by the account constraint
42/// pub instructions_sysvar: UncheckedAccount<'info>,
43/// }
44///
45/// #[arcium_program]
46/// pub mod sample_program {
47/// // Will be called when the computation with circuit "my_circuit" resolves
48/// #[arcium_callback(encrypted_ix = "my_circuit")]
49/// pub fn my_circuit_callback(
50/// ctx: Context<Callback>,
51/// output: SignedComputationOutputs<MyCircuitOutput>,
52/// ) -> Result<()> {
53/// // Destructure and handle Success/Failure
54/// let result = match output {
55/// SignedComputationOutputs::Success(MyCircuitOutput { field_0 }) => field_0,
56/// SignedComputationOutputs::Failure => {
57/// return Err(ErrorCode::ComputationFailed.into());
58/// }
59/// };
60/// msg!("Computation succeeded with result: {:?}", result);
61/// Ok(())
62/// }
63/// }
64/// ```
65#[proc_macro_attribute]
66pub fn arcium_callback(args: TokenStream, item: TokenStream) -> TokenStream {
67 let args = parse_macro_input!(args as ArciumCallbackArgs);
68 let input_fn = parse_macro_input!(item as ItemFn);
69 callback_ix_derive(input_fn, args)
70}
71
72/// Validates the structure for queuing computations by checking the encrypted instruction exists,
73/// validating required account fields, and implementing the `QueueCompAccs` trait to make using it
74/// with queuing computations easy:
75///
76/// ```ignore
77/// #[queue_computation_accounts("add_together", payer)]
78/// #[derive(Accounts)]
79/// pub struct InitComputation<'info> {
80/// #[account(mut)]
81/// pub payer: Signer<'info>,
82/// #[account(
83/// address = derive_mxe_pda!()
84/// )]
85/// pub mxe_account: Account<'info, MXEAccount>,
86/// #[account(
87/// mut,
88/// address = derive_mempool_pda!()
89/// )]
90/// pub mempool_account: Account<'info, Mempool>,
91/// #[account(
92/// mut,
93/// address = derive_execpool_pda!()
94/// )]
95/// pub executing_pool: Account<'info, ExecutingPool>,
96/// #[account(
97/// address = derive_comp_def_pda!(COMP_DEF_OFFSET)
98/// )]
99/// pub comp_def_account: Account<'info, ComputationDefinitionAccount>,
100/// #[account(
101/// mut,
102/// address = derive_cluster_pda!(mxe_account)
103/// )]
104/// pub cluster_account: Account<'info, Cluster>,
105/// #[account(
106/// mut,
107/// address = ARCIUM_FEE_POOL_ACCOUNT_ADDRESS,
108/// )]
109/// pub pool_account: Account<'info, FeePool>,
110/// #[account(
111/// mut,
112/// address = ARCIUM_CLOCK_ACCOUNT_ADDRESS
113/// )]
114/// pub clock_account: Account<'info, ClockAccount>,
115/// pub system_program: Program<'info, System>,
116/// pub arcium_program: Program<'info, Arcium>,
117/// }
118///
119/// #[arcium_program]
120/// pub mod sample_program {
121/// pub fn submit_computation(
122/// ctx: Context<InitComputation>,
123/// x: [u8; 32],
124/// y: [u8; 32],
125/// computation_offset: u64,
126/// ) -> Result<()> {
127/// // This will queue a computation that will execute "add_together" circuit
128/// let args = ArgBuilder::new()
129/// .encrypted_u8(x)
130/// .encrypted_u8(y)
131/// .build();
132/// // Parameters: accs, computation_offset, args,callback_instructions, num_callback_txs, cu_price_micro
133/// queue_computation(
134/// &ctx.accounts,
135/// computation_offset,
136/// args,
137/// vec![AddTogetherCallback::callback_ix(&[])],
138/// 1,
139/// 0,
140/// )?;
141/// Ok(())
142/// }
143/// }
144/// ```
145#[proc_macro_attribute]
146pub fn queue_computation_accounts(
147 args: proc_macro::TokenStream,
148 item: proc_macro::TokenStream,
149) -> proc_macro::TokenStream {
150 // Parse the input tokens into a syntax tree
151 let mut input = parse_macro_input!(item as DeriveInput);
152 let args = parse_macro_input!(args as QueueCompArgs);
153 queue_comp_derive(&mut input, args)
154}
155
156/// Validates the structure for computation callbacks by checking the encrypted instruction exists,
157/// validating required account fields, and ensuring the structure has the correct fields for
158/// callbacks:
159///
160/// ```ignore
161/// #[callback_accounts("my_circuit")]
162/// #[derive(Accounts)]
163/// pub struct Callback<'info> {
164/// #[account(mut)]
165/// pub payer: Signer<'info>,
166/// pub arcium_program: Program<'info, Arcium>,
167/// #[account(
168/// address = derive_comp_def_pda!(COMP_DEF_OFFSET)
169/// )]
170/// pub comp_def_account: Account<'info, ComputationDefinitionAccount>,
171/// #[account(address = ::arcium_anchor::solana_instructions_sysvar::ID)]
172/// /// CHECK: instructions_sysvar, checked by the account constraint
173/// pub instructions_sysvar: UncheckedAccount<'info>,
174/// }
175///
176/// #[arcium_program]
177/// pub mod sample_program {
178/// // Will be called when the computation with circuit "my_circuit" resolves
179/// #[arcium_callback(encrypted_ix = "my_circuit")]
180/// pub fn my_circuit_callback(
181/// ctx: Context<Callback>,
182/// output: SignedComputationOutputs<MyCircuitOutput>,
183/// ) -> Result<()> {
184/// // Destructure and handle Success/Failure
185/// let result = match output {
186/// SignedComputationOutputs::Success(MyCircuitOutput { field_0 }) => field_0,
187/// SignedComputationOutputs::Failure => {
188/// return Err(ErrorCode::ComputationFailed.into());
189/// }
190/// };
191/// msg!("Computation succeeded with result: {:?}", result);
192/// Ok(())
193/// }
194/// }
195/// ```
196#[proc_macro_attribute]
197pub fn callback_accounts(
198 args: proc_macro::TokenStream,
199 item: proc_macro::TokenStream,
200) -> proc_macro::TokenStream {
201 // Parse the input tokens into a syntax tree
202 let input = parse_macro_input!(item as DeriveInput);
203 let args = parse_macro_input!(args as CallbackAccArgs);
204 callback_accs_derive(&input, args)
205}
206
207/// The #[arcium_program] attribute defines the module
208/// containing all instruction handlers defining all entries into a Solana program.
209/// Under the hood, it gets expanded to Anchor's `#[program]` and some additional definitions needed
210/// for Arcium.
211#[proc_macro_attribute]
212pub fn arcium_program(
213 _args: proc_macro::TokenStream,
214 item: proc_macro::TokenStream,
215) -> proc_macro::TokenStream {
216 // Parse the input tokens into a syntax tree
217 let mut input = parse_macro_input!(item as Item);
218 arcium_program_macro(&mut input)
219}
220
221/// Validates the structure for initializing computation definitions by checking the encrypted
222/// instruction exists, validating required account fields, and implementing the `InitCompDefAccs`
223/// trait to make using it with computation definitions easy:
224///
225/// ```ignore
226/// #[init_computation_definition_accounts("my_circuit", payer)]
227/// #[derive(Accounts)]
228/// pub struct InitMyCircuitCompDef<'info> {
229/// #[account(mut)]
230/// pub payer: Signer<'info>,
231/// #[account(
232/// mut,
233/// address = derive_mxe_pda!()
234/// )]
235/// pub mxe_account: Box<Account<'info, MXEAccount>>,
236/// #[account(mut)]
237/// /// CHECK: comp_def_account, checked by arcium program.
238/// /// Can't check it here as it's not initialized yet.
239/// pub comp_def_account: UncheckedAccount<'info>,
240/// pub arcium_program: Program<'info, Arcium>,
241/// pub system_program: Program<'info, System>,
242/// }
243/// ```
244#[proc_macro_attribute]
245pub fn init_computation_definition_accounts(args: TokenStream, item: TokenStream) -> TokenStream {
246 // Parse the macro arguments
247 let args = parse_macro_input!(args as InitCompDefArgs);
248
249 // Parse the input tokens into a syntax tree
250 let mut input = parse_macro_input!(item as DeriveInput);
251
252 // Call the inner function with the parsed arguments
253 init_comp_def_derive(&mut input, args)
254}
255
256/// Compile-time validation of computation arguments against an interface definition.
257///
258/// This macro provides **compile-time** verification that your computation arguments match the
259/// circuit interface, catching type mismatches and argument count errors before runtime.
260///
261/// ## Usage
262/// 1. Attach `#[check_args]` to your function
263/// 2. Attach `#[args("your_circuit_name")]` to your computation arguments (array literal or `vec!`
264/// macro)
265///
266/// ## Checks Performed
267/// - Correct number of arguments based on the interface file (`build/your_circuit_name.idarc`)
268/// - Correct argument types (e.g., `PlaintextU32`, `EncryptedU8`, `X25519Pubkey`)
269/// - Correct argument order matching the circuit interface
270///
271/// ## Compile-Time Errors
272/// If arguments don't match the interface, you'll see a compilation error like:
273/// ```text
274/// error: mismatched types: expected Parameter::PlaintextU64, found Parameter::PlaintextU32
275/// ```
276///
277/// ## Example
278/// ```ignore
279/// #[check_args]
280/// pub fn submit_computation(ctx: Context<MyAccounts>, computation_offset: u64) -> Result<()> {
281/// queue_computation(
282/// &ctx.accounts,
283/// computation_offset,
284/// #[args("add_together")] // Validates against build/add_together.idarc
285/// ArgBuilder::new()
286/// .plaintext_u64(10)
287/// .plaintext_u64(20)
288/// .build(),
289/// None,
290/// vec![],
291/// 1,
292/// 0,
293/// )?;
294/// Ok(())
295/// }
296/// ```
297#[proc_macro_attribute]
298pub fn check_args(_args: TokenStream, item: TokenStream) -> TokenStream {
299 // Parse as a Stmt
300 let input = parse_macro_input!(item as ItemFn);
301 check_args_fn(input).into()
302}
303
304/// Returns the SHA-256 hash of a compiled circuit as a `[u8; 32]` array at compile time.
305///
306/// The hash is computed over the serialized circuit bytecode and can be used for verifying
307/// off-chain circuit sources. This is the same hash that ARX nodes verify when fetching circuits
308/// from off-chain sources.
309///
310/// ## Example
311/// ```ignore
312/// use arcium_macros::circuit_hash;
313///
314/// let source = OffChainCircuitSource {
315/// source: "https://ipfs.io/ipfs/Qm...".into(),
316/// hash: circuit_hash!("my_circuit"),
317/// };
318/// ```
319#[proc_macro]
320pub fn circuit_hash(input: TokenStream) -> TokenStream {
321 let circuit_name = parse_macro_input!(input as LitStr);
322 let hash = read_circuit_hash(&circuit_name.value());
323 let hash_bytes = hash.iter();
324 quote! { [#(#hash_bytes),*] }.into()
325}