arcium-macros 0.5.0

Helper macros for developing Solana programs that integrate with the Arcium network.
Documentation
//! Utility functions for reading circuit interfaces and generating token streams.
//!
//! ## Key Functions
//! - `comp_def_offset`: Generates deterministic u32 offsets from circuit names using SHA256 (first
//!   4 bytes)
//! - Interface readers: Parse `.idarc` files (standard circuits)
//! - Token generators: Convert circuit interface types to Rust token streams for code generation
//!
//! ## Build Directory Structure
//! - `build/{name}.arcis` - Compiled circuit bytecode
//! - `build/{name}.idarc` - Circuit interface definition (JSON)
//! - `build/{name}.weight` - Circuit weight/profile information (JSON)

use arcis_interface::{CircuitInterface, ScalarKind, Value};
use proc_macro2::TokenStream;
use quote::quote;
use sha2::{Digest, Sha256};
use std::fs;
use syn::{parse::Parse, punctuated::Punctuated, Meta, Token};

/// Arguments for the `#[arcium_callback]` attribute macro.
///
/// Built by the `Parse` implementation from a comma-separated list of `name = value` pairs.
pub struct ArciumCallbackArgs {
    /// Contents of the string literal given as `encrypted_ix = "..."`. This argument is
    /// required; parsing fails without it.
    // NOTE(review): presumably the circuit name used to locate `build/{name}.*` artifacts —
    // confirm against the macro that consumes these args.
    pub encrypted_ix: String,
    /// When `true` (default), validates that the callback output type matches the auto-generated
    /// type. Set to `false` to use custom deserialization logic.
    pub auto_serialize: bool,
}

impl Parse for ArciumCallbackArgs {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut encrypted_ix = None;
        let mut auto_serialize = None;

        let nested_meta_list = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;

        for nested_meta in nested_meta_list {
            if let Meta::NameValue(nv) = nested_meta {
                if nv.path.is_ident("encrypted_ix") {
                    if let syn::Expr::Lit(lit) = &nv.value {
                        if let syn::Lit::Str(s) = &lit.lit {
                            encrypted_ix = Some(s.value());
                        }
                    }
                } else if nv.path.is_ident("auto_serialize") {
                    if let syn::Expr::Lit(lit) = &nv.value {
                        if let syn::Lit::Bool(b) = &lit.lit {
                            auto_serialize = Some(b.value);
                        }
                    }
                }
            }
        }

        if let Some(c) = encrypted_ix {
            let args = ArciumCallbackArgs {
                encrypted_ix: c,
                auto_serialize: auto_serialize.unwrap_or(true), // Default to true
            };
            Ok(args)
        } else {
            panic!("Arcium callback derive requires a encrypted_ix = \"...\" parameter");
        }
    }
}

/// Asserts that the compiled circuit file `build/{encrypted_ix_name}.arcis` exists.
///
/// The path is relative, so it is resolved against the process's current working directory.
///
/// # Panics
/// Panics, naming the path that was checked, when `fs::metadata` fails for that path
/// (the file is missing or its metadata cannot be read).
pub fn check_encrypted_ix_path(encrypted_ix_name: &str) {
    let encrypted_ix_file_path = format!("build/{}.arcis", encrypted_ix_name);
    // Borrow the path: `fs::metadata` accepts any `AsRef<Path>`, so no clone is needed.
    if fs::metadata(&encrypted_ix_file_path).is_err() {
        panic!(
            "Confidential instruction was not found at path: {}",
            encrypted_ix_file_path,
        );
    }
}

/// Loads and parses the circuit interface stored at `build/{conf_ix_name}.idarc`.
///
/// The path is relative to the current working directory.
///
/// # Panics
/// Panics when the file cannot be read as a string, or when
/// `CircuitInterface::from_json` rejects its contents.
pub fn read_conf_ix_interface(conf_ix_name: &str) -> CircuitInterface {
    let idarc_path = format!("build/{}.idarc", conf_ix_name);
    let interface_json = match fs::read_to_string(&idarc_path) {
        Ok(text) => text,
        Err(_) => panic!(
            "Could not read confidential ix interface at path {}",
            idarc_path
        ),
    };
    CircuitInterface::from_json(&interface_json).expect("Failed to parse interface from json")
}

/// Returns the raw bytes of the compiled circuit at `build/{conf_ix_name}.arcis`.
///
/// The path is relative to the current working directory.
///
/// # Panics
/// Panics, naming the path, when the file cannot be read.
pub fn read_compiled_conf_ix(conf_ix_name: &str) -> Vec<u8> {
    let arcis_path = format!("build/{}.arcis", conf_ix_name);
    match fs::read(&arcis_path) {
        Ok(bytes) => bytes,
        Err(_) => panic!(
            "Could not read compiled confidential ix at path {}",
            arcis_path
        ),
    }
}

/// Returns the `weight` value recorded in `build/{circuit_name}.weight`.
///
/// The `.weight` file is JSON produced during circuit compilation. Reading the stored
/// number lets the macro embed the weight at compile time without re-implementing the
/// weight calculation formula.
///
/// # Panics
/// Panics when the file cannot be read, when its contents are not valid JSON, or when the
/// top-level `weight` entry is absent or not representable as a `u64`.
pub fn read_circuit_weight(circuit_name: &str) -> u64 {
    let weight_path = format!("build/{}.weight", circuit_name);
    let content = match fs::read_to_string(&weight_path) {
        Ok(text) => text,
        Err(_) => panic!(
            "Could not read weight file at {}. Run 'arcium build' first.",
            weight_path
        ),
    };

    let profile: serde_json::Value =
        serde_json::from_str(&content).expect("Failed to parse .weight JSON");

    // `get` yields `None` both for a missing key and for a non-object document, which the
    // single `expect` below reports with the same message in either case.
    profile
        .get("weight")
        .and_then(serde_json::Value::as_u64)
        .expect("Missing or invalid 'weight' field in .weight file")
}

/// Derives a deterministic `u32` offset from a circuit name.
///
/// The offset is the first four bytes of `SHA256(input)` interpreted as a little-endian
/// integer. Only 32 bits of the digest are kept, so distinct names can in principle map to
/// the same offset.
pub fn comp_def_offset(input: &str) -> u32 {
    let digest = Sha256::digest(input.as_bytes());
    let mut prefix = [0u8; 4];
    prefix.copy_from_slice(&digest[..4]);
    u32::from_le_bytes(prefix)
}

/// Transforms the circuit interface into a list of tokens that represent the parameters for the
/// circuit. You might be wondering why we do circuitinterface -> param_tokens, instead of
/// circuitinterface -> params -> param_tokens, as the latter would feel a bit cleaner. The reason
/// is that this would require us to import arcium_client as a dependency of arcium_macros,
/// which for some reason causes a billion errors in anchor programs that then want to use
/// arcium_macros (in spite of these programs having arcium_client as a dependency themselves!).
/// Therefore, this is the only way to do it.
pub fn get_param_tokens_from_interface(circuit: &CircuitInterface) -> Vec<TokenStream> {
    circuit
        .inputs
        .iter()
        .flat_map(raw_input_to_param_tokens)
        .collect()
}

pub fn get_output_tokens_from_interface(circuit: &CircuitInterface) -> Vec<TokenStream> {
    circuit
        .outputs
        .iter()
        .flat_map(raw_output_to_output_tokens)
        .collect()
}

fn raw_input_to_param_tokens(val: &Value) -> Vec<TokenStream> {
    match val {
        Value::Bool => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextBool}],
        Value::Scalar { size_in_bits, kind } => match kind {
            ScalarKind::Unsigned => match size_in_bits {
                8 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU8}],
                16 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU16}],
                32 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU32}],
                64 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU64}],
                128 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU128}],
                _ => panic!(
                    "Unsupported unsigned integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
                    size_in_bits
                ),
            },
            ScalarKind::Signed => match size_in_bits {
                8 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI8}],
                16 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI16}],
                32 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI32}],
                64 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI64}],
                128 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI128}],
                _ => panic!(
                    "Unsupported signed integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
                    size_in_bits
                ),
            },
        },
        Value::Ciphertext { size_in_bits: _ } => {
            vec![quote! {::arcium_client::idl::arcium::types::Parameter::Ciphertext}]
        }
        Value::ArcisX25519Pubkey => {
            vec![quote! {::arcium_client::idl::arcium::types::Parameter::ArcisX25519Pubkey}]
        }
        Value::Point => {
            vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextPoint}]
        }
        Value::Float { size_in_bits } => {
            if *size_in_bits != 64 {
                panic!(
                    "Unsupported float size: {} bits. Only 64-bit floats (f64) are supported",
                    size_in_bits
                );
            }
            vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextFloat}]
        }
        Value::Array(c) => c.iter().flat_map(raw_input_to_param_tokens).collect(),
        Value::Tuple(c) => c.iter().flat_map(raw_input_to_param_tokens).collect(),
        Value::Struct(c) => c.iter().flat_map(raw_input_to_param_tokens).collect(),
        Value::MBool => panic!("Unsupported shared bool"),
        Value::MScalar { size_in_bits: _ } => panic!("Unsupported shared scalar"),
        Value::MFloat { size_in_bits: _ } => panic!("Unsupported shared float"),
    }
}

fn raw_output_to_output_tokens(val: &Value) -> Vec<TokenStream> {
    match val {
        Value::Bool => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextBool}],
        Value::Scalar { size_in_bits, kind } => match kind {
            ScalarKind::Unsigned => match size_in_bits {
                8 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU8}],
                16 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU16}],
                32 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU32}],
                64 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU64}],
                128 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU128}],
                _ => panic!(
                    "Unsupported unsigned integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
                    size_in_bits
                ),
            },
            ScalarKind::Signed => match size_in_bits {
                8 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI8}],
                16 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI16}],
                32 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI32}],
                64 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI64}],
                128 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI128}],
                _ => panic!(
                    "Unsupported signed integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
                    size_in_bits
                ),
            },
        },
        Value::Ciphertext { size_in_bits: _ } => {
            vec![quote! {::arcium_client::idl::arcium::types::Output::Ciphertext}]
        }
        Value::ArcisX25519Pubkey => {
            vec![quote! {::arcium_client::idl::arcium::types::Output::ArcisX25519Pubkey}]
        }
        Value::Point => {
            vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextPoint}]
        }
        Value::Float { size_in_bits } => {
            if *size_in_bits != 64 {
                panic!(
                    "Unsupported float size: {} bits. Only 64-bit floats (f64) are supported",
                    size_in_bits
                );
            }
            vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextFloat}]
        }
        Value::Array(c) => c.iter().flat_map(raw_output_to_output_tokens).collect(),
        Value::Tuple(c) => c.iter().flat_map(raw_output_to_output_tokens).collect(),
        Value::Struct(c) => c.iter().flat_map(raw_output_to_output_tokens).collect(),
        Value::MBool => panic!("Raw encrypted outputs are not supported yet."),
        Value::MScalar { size_in_bits: _ } => {
            panic!("Raw encrypted outputs are not supported yet.")
        }
        Value::MFloat { size_in_bits: _ } => panic!("Raw encrypted outputs are not supported yet."),
    }
}

/// Returns the 8-byte instruction discriminator for the `{circuit_name}_callback`
/// instruction, as computed by `calc_ix_discriminator`.
#[allow(dead_code)]
pub fn circuit_callback_discriminator(circuit_name: &str) -> [u8; 8] {
    calc_ix_discriminator(&format!("{}_callback", circuit_name))
}

/// Returns the first eight bytes of `SHA256("global:{ix_ident}")`.
#[allow(dead_code)]
fn calc_ix_discriminator(ix_ident: &str) -> [u8; 8] {
    let digest = Sha256::digest(format!("global:{}", ix_ident).as_bytes());
    let mut discriminator = [0u8; 8];
    // `zip` stops after the eight destination slots, i.e. at the digest's first 8 bytes.
    for (dst, src) in discriminator.iter_mut().zip(digest.iter()) {
        *dst = *src;
    }
    discriminator
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    // Serializes the tests that change the process-wide current directory.
    static DIR_MUTEX: Mutex<()> = Mutex::new(());

    /// Switches the current directory while holding `DIR_MUTEX`, and switches back on drop.
    ///
    /// Restoring in `Drop` means a panicking test cannot leave the process inside a
    /// temporary directory that is about to be deleted.
    struct CwdGuard {
        original: PathBuf,
        // Declared after `original` and released only after `Drop::drop` has restored the
        // directory, so no other test observes the temporary cwd.
        _lock: MutexGuard<'static, ()>,
    }

    impl CwdGuard {
        fn enter(dir: &Path) -> Self {
            // A test that failed while holding the lock poisons it; the protected value is
            // `()`, so it is always safe to keep going instead of cascading the failure.
            let lock = DIR_MUTEX
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            Self {
                original,
                _lock: lock,
            }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Best effort: panicking inside `drop` during unwinding would abort the process.
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Creates a temporary project directory. When `weight_file` is
    /// `Some((circuit, contents))`, also writes `build/{circuit}.weight` inside it.
    fn project_dir(weight_file: Option<(&str, &str)>) -> TempDir {
        let temp_dir = TempDir::new().unwrap();
        if let Some((circuit, contents)) = weight_file {
            let build_dir = temp_dir.path().join("build");
            fs::create_dir_all(&build_dir).unwrap();
            fs::write(build_dir.join(format!("{}.weight", circuit)), contents).unwrap();
        }
        temp_dir
    }

    /// Runs `read_circuit_weight(circuit_name)` from inside `project` and asserts that it
    /// panics with a message containing `expected`.
    fn assert_weight_panics_with(project: &TempDir, circuit_name: &str, expected: &str) {
        let outcome = {
            let _cwd = CwdGuard::enter(project.path());
            std::panic::catch_unwind(|| read_circuit_weight(circuit_name))
        };

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        // Panic payloads are `String` for formatted messages and `&str` for literal ones.
        let msg = err
            .downcast_ref::<String>()
            .map(|s| s.as_str())
            .or_else(|| err.downcast_ref::<&str>().copied())
            .unwrap_or("");
        assert!(
            msg.contains(expected),
            "Expected panic message to contain '{}', got: {}",
            expected,
            msg
        );
    }

    #[test]
    fn test_comp_def_offset() {
        let conf_ix_name = "add_together";
        let offset = comp_def_offset(conf_ix_name);
        assert_eq!(offset, 4005749700);
    }

    #[test]
    fn test_read_circuit_weight_valid() {
        // `temp_dir` is declared first so it is dropped last, after the cwd is restored.
        let temp_dir = project_dir(Some(("test_circuit", r#"{"weight": 12345678}"#)));
        let _cwd = CwdGuard::enter(temp_dir.path());
        assert_eq!(read_circuit_weight("test_circuit"), 12345678);
    }

    #[test]
    fn test_read_circuit_weight_missing_file() {
        let temp_dir = project_dir(None);
        assert_weight_panics_with(&temp_dir, "nonexistent_circuit", "Could not read weight file");
    }

    #[test]
    fn test_read_circuit_weight_invalid_json() {
        let temp_dir = project_dir(Some(("invalid", "not valid json")));
        assert_weight_panics_with(&temp_dir, "invalid", "Failed to parse .weight JSON");
    }

    #[test]
    fn test_read_circuit_weight_missing_field() {
        let temp_dir = project_dir(Some(("no_weight", r#"{"other": 123}"#)));
        assert_weight_panics_with(&temp_dir, "no_weight", "Missing or invalid 'weight' field");
    }
}