use arcis_interface::{CircuitInterface, ScalarKind, Value};
use proc_macro2::TokenStream;
use quote::quote;
use sha2::{Digest, Sha256};
use std::fs;
use syn::{parse::Parse, punctuated::Punctuated, Meta, Token};
/// Arguments accepted by the Arcium callback macro.
#[derive(Debug, Clone)]
pub struct ArciumCallbackArgs {
    /// Name of the confidential instruction (circuit); artifacts are looked up under `build/`.
    pub encrypted_ix: String,
    /// Whether auto-serialization is enabled; the parser defaults this to `true` when omitted.
    pub auto_serialize: bool,
}
impl Parse for ArciumCallbackArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut encrypted_ix = None;
let mut auto_serialize = None;
let nested_meta_list = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
for nested_meta in nested_meta_list {
if let Meta::NameValue(nv) = nested_meta {
if nv.path.is_ident("encrypted_ix") {
if let syn::Expr::Lit(lit) = &nv.value {
if let syn::Lit::Str(s) = &lit.lit {
encrypted_ix = Some(s.value());
}
}
} else if nv.path.is_ident("auto_serialize") {
if let syn::Expr::Lit(lit) = &nv.value {
if let syn::Lit::Bool(b) = &lit.lit {
auto_serialize = Some(b.value);
}
}
}
}
}
if let Some(c) = encrypted_ix {
let args = ArciumCallbackArgs {
encrypted_ix: c,
auto_serialize: auto_serialize.unwrap_or(true), };
Ok(args)
} else {
panic!("Arcium callback derive requires a encrypted_ix = \"...\" parameter");
}
}
}
/// Asserts that the compiled circuit `build/<encrypted_ix_name>.arcis` exists,
/// relative to the current working directory.
///
/// # Panics
/// If metadata for that path cannot be read (missing file, or an inaccessible path).
pub fn check_encrypted_ix_path(encrypted_ix_name: &str) {
    let encrypted_ix_file_path = format!("build/{}.arcis", encrypted_ix_name);
    // Borrow the path; no need to clone the String just to stat it.
    if fs::metadata(&encrypted_ix_file_path).is_err() {
        panic!(
            "Confidential instruction was not found at path: {}",
            encrypted_ix_file_path,
        );
    }
}
/// Reads and parses the circuit interface description `build/<conf_ix_name>.idarc`,
/// relative to the current working directory.
///
/// # Panics
/// If the file cannot be read, or its contents cannot be parsed by
/// [`CircuitInterface::from_json`]. The message includes the path and the underlying error.
pub fn read_conf_ix_interface(conf_ix_name: &str) -> CircuitInterface {
    let conf_ix_file_path = format!("build/{}.idarc", conf_ix_name);
    let interface_json = fs::read_to_string(&conf_ix_file_path).unwrap_or_else(|e| {
        panic!(
            "Could not read confidential ix interface at path {}: {}",
            conf_ix_file_path, e
        )
    });
    CircuitInterface::from_json(&interface_json).unwrap_or_else(|e| {
        panic!(
            "Failed to parse interface from json at path {}: {:?}",
            conf_ix_file_path, e
        )
    })
}
/// Reads the raw bytes of the compiled circuit `build/<conf_ix_name>.arcis`,
/// relative to the current working directory.
///
/// # Panics
/// If the file cannot be read; the message includes the path and the I/O error.
pub fn read_compiled_conf_ix(conf_ix_name: &str) -> Vec<u8> {
    let conf_ix_file_path = format!("build/{}.arcis", conf_ix_name);
    fs::read(&conf_ix_file_path).unwrap_or_else(|e| {
        panic!(
            "Could not read compiled confidential ix at path {}: {}",
            conf_ix_file_path, e
        )
    })
}
/// Reads the `weight` field from the JSON file `build/<circuit_name>.weight`,
/// relative to the current working directory.
///
/// # Panics
/// If the file cannot be read or is not valid JSON, or if `weight` is missing or not
/// representable as a `u64`. Each message includes the file path (and the underlying
/// error where there is one).
pub fn read_circuit_weight(circuit_name: &str) -> u64 {
    let weight_path = format!("build/{}.weight", circuit_name);
    let content = fs::read_to_string(&weight_path).unwrap_or_else(|e| {
        panic!(
            "Could not read weight file at {}: {}. Run 'arcium build' first.",
            weight_path, e
        )
    });
    let json: serde_json::Value = serde_json::from_str(&content).unwrap_or_else(|e| {
        panic!("Failed to parse .weight JSON at {}: {}", weight_path, e)
    });
    json["weight"].as_u64().unwrap_or_else(|| {
        panic!(
            "Missing or invalid 'weight' field in .weight file {}",
            weight_path
        )
    })
}
/// Reads the 32-byte circuit hash stored as a JSON array in `build/<circuit_name>.hash`,
/// relative to the current working directory.
///
/// # Panics
/// If the file cannot be read, or its contents do not deserialize into `[u8; 32]`
/// (e.g. invalid JSON or an array of the wrong length). Each message includes the
/// path and the underlying error.
pub fn read_circuit_hash(circuit_name: &str) -> [u8; 32] {
    let hash_path = format!("build/{}.hash", circuit_name);
    let content = fs::read_to_string(&hash_path).unwrap_or_else(|e| {
        panic!(
            "Could not read hash file at {}: {}. Run 'arcium build' first.",
            hash_path, e
        )
    });
    serde_json::from_str(&content).unwrap_or_else(|e| {
        panic!(
            "Failed to parse .hash JSON at {}. Expected [u8; 32] array. Error: {}",
            hash_path, e
        )
    })
}
/// Derives a computation-definition offset from a name: the first four bytes of
/// `SHA-256(input)`, interpreted as a little-endian `u32`.
pub fn comp_def_offset(input: &str) -> u32 {
    let digest = Sha256::digest(input.as_bytes());
    let mut prefix = [0u8; 4];
    prefix.copy_from_slice(&digest[..4]);
    u32::from_le_bytes(prefix)
}
pub fn get_param_tokens_from_interface(circuit: &CircuitInterface) -> Vec<TokenStream> {
circuit
.inputs
.iter()
.flat_map(raw_input_to_param_tokens)
.collect()
}
pub fn get_output_tokens_from_interface(circuit: &CircuitInterface) -> Vec<TokenStream> {
circuit
.outputs
.iter()
.flat_map(raw_output_to_output_tokens)
.collect()
}
fn raw_input_to_param_tokens(val: &Value) -> Vec<TokenStream> {
match val {
Value::Bool => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextBool}],
Value::Scalar { size_in_bits, kind } => match kind {
ScalarKind::Unsigned => match size_in_bits {
8 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU8}],
16 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU16}],
32 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU32}],
64 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU64}],
128 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextU128}],
_ => panic!(
"Unsupported unsigned integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
size_in_bits
),
},
ScalarKind::Signed => match size_in_bits {
8 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI8}],
16 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI16}],
32 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI32}],
64 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI64}],
128 => vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextI128}],
_ => panic!(
"Unsupported signed integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
size_in_bits
),
},
},
Value::Ciphertext { size_in_bits: _ } => {
vec![quote! {::arcium_client::idl::arcium::types::Parameter::Ciphertext}]
}
Value::ArcisX25519Pubkey => {
vec![quote! {::arcium_client::idl::arcium::types::Parameter::ArcisX25519Pubkey}]
}
Value::Point => {
vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextPoint}]
}
Value::Float { size_in_bits } => {
if *size_in_bits != 64 {
panic!(
"Unsupported float size: {} bits. Only 64-bit floats (f64) are supported",
size_in_bits
);
}
vec![quote! {::arcium_client::idl::arcium::types::Parameter::PlaintextFloat}]
}
Value::Array(c) => c.iter().flat_map(raw_input_to_param_tokens).collect(),
Value::Tuple(c) => c.iter().flat_map(raw_input_to_param_tokens).collect(),
Value::Struct(c) => c.iter().flat_map(raw_input_to_param_tokens).collect(),
Value::MBool => panic!("Unsupported shared bool"),
Value::MScalar { size_in_bits: _ } => panic!("Unsupported shared scalar"),
Value::MFloat { size_in_bits: _ } => panic!("Unsupported shared float"),
}
}
fn raw_output_to_output_tokens(val: &Value) -> Vec<TokenStream> {
match val {
Value::Bool => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextBool}],
Value::Scalar { size_in_bits, kind } => match kind {
ScalarKind::Unsigned => match size_in_bits {
8 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU8}],
16 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU16}],
32 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU32}],
64 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU64}],
128 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextU128}],
_ => panic!(
"Unsupported unsigned integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
size_in_bits
),
},
ScalarKind::Signed => match size_in_bits {
8 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI8}],
16 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI16}],
32 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI32}],
64 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI64}],
128 => vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextI128}],
_ => panic!(
"Unsupported signed integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
size_in_bits
),
},
},
Value::Ciphertext { size_in_bits: _ } => {
vec![quote! {::arcium_client::idl::arcium::types::Output::Ciphertext}]
}
Value::ArcisX25519Pubkey => {
vec![quote! {::arcium_client::idl::arcium::types::Output::ArcisX25519Pubkey}]
}
Value::Point => {
vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextPoint}]
}
Value::Float { size_in_bits } => {
if *size_in_bits != 64 {
panic!(
"Unsupported float size: {} bits. Only 64-bit floats (f64) are supported",
size_in_bits
);
}
vec![quote! {::arcium_client::idl::arcium::types::Output::PlaintextFloat}]
}
Value::Array(c) => c.iter().flat_map(raw_output_to_output_tokens).collect(),
Value::Tuple(c) => c.iter().flat_map(raw_output_to_output_tokens).collect(),
Value::Struct(c) => c.iter().flat_map(raw_output_to_output_tokens).collect(),
Value::MBool => panic!("Raw encrypted outputs are not supported yet."),
Value::MScalar { size_in_bits: _ } => {
panic!("Raw encrypted outputs are not supported yet.")
}
Value::MFloat { size_in_bits: _ } => panic!("Raw encrypted outputs are not supported yet."),
}
}
// Not called from the code visible in this file.
// NOTE(review): confirm it is still used elsewhere before removing this allow.
#[allow(dead_code)]
/// Returns the 8-byte instruction discriminator of the `<circuit_name>_callback` instruction.
pub fn circuit_callback_discriminator(circuit_name: &str) -> [u8; 8] {
    calc_ix_discriminator(&format!("{}_callback", circuit_name))
}
// Only reachable via `circuit_callback_discriminator`, which may itself be unused.
#[allow(dead_code)]
/// Computes an instruction discriminator: the first 8 bytes of
/// `SHA-256("global:<ix_ident>")`.
fn calc_ix_discriminator(ix_ident: &str) -> [u8; 8] {
    let digest = Sha256::digest(format!("global:{}", ix_ident).as_bytes());
    let mut discriminator = [0u8; 8];
    discriminator.copy_from_slice(&digest[..8]);
    discriminator
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use std::sync::{Mutex, MutexGuard};
    use tempfile::TempDir;

    /// Serializes tests that change the process-wide current working directory.
    static DIR_MUTEX: Mutex<()> = Mutex::new(());

    /// Acquires `DIR_MUTEX`, recovering from poisoning so one failed test does not make
    /// every later test fail on `lock().unwrap()`.
    fn lock_dir() -> MutexGuard<'static, ()> {
        DIR_MUTEX
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Switches the current directory and restores the original on drop, even when the
    /// code under test panics (unlike a manual `set_current_dir` after the call).
    struct CwdGuard {
        original: PathBuf,
    }

    impl CwdGuard {
        fn enter(dir: &Path) -> Self {
            let original = std::env::current_dir().unwrap();
            std::env::set_current_dir(dir).unwrap();
            CwdGuard { original }
        }
    }

    impl Drop for CwdGuard {
        fn drop(&mut self) {
            // Best effort: panicking in Drop while already unwinding would abort.
            let _ = std::env::set_current_dir(&self.original);
        }
    }

    /// Creates a temp dir containing a `build/` directory populated with `files`.
    fn temp_build_dir(files: &[(&str, &str)]) -> TempDir {
        let temp_dir = TempDir::new().unwrap();
        let build_dir = temp_dir.path().join("build");
        fs::create_dir_all(&build_dir).unwrap();
        for (name, contents) in files {
            fs::write(build_dir.join(name), contents).unwrap();
        }
        temp_dir
    }

    /// Extracts the message from a panic payload (`String` or `&str`), or `""`.
    fn panic_message(err: &(dyn std::any::Any + Send)) -> &str {
        err.downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| err.downcast_ref::<&str>().copied())
            .unwrap_or("")
    }

    /// Runs `f` with `dir` as the working directory and asserts it panics with a message
    /// containing `expected`.
    fn assert_panics_with<T>(
        dir: &Path,
        expected: &str,
        f: impl FnOnce() -> T + std::panic::UnwindSafe,
    ) {
        let result = {
            let _cwd = CwdGuard::enter(dir);
            std::panic::catch_unwind(f)
        };
        let err = match result {
            Ok(_) => panic!("Expected a panic containing '{}', but none occurred", expected),
            Err(err) => err,
        };
        let msg = panic_message(&*err);
        assert!(
            msg.contains(expected),
            "Expected panic message to contain '{}', got: {}",
            expected,
            msg
        );
    }

    #[test]
    fn test_comp_def_offset() {
        let conf_ix_name = "add_together";
        let offset = comp_def_offset(conf_ix_name);
        assert_eq!(offset, 4005749700);
    }

    #[test]
    fn test_read_circuit_weight_valid() {
        let _lock = lock_dir();
        let temp_dir = temp_build_dir(&[("test_circuit.weight", r#"{"weight": 12345678}"#)]);
        let result = {
            let _cwd = CwdGuard::enter(temp_dir.path());
            read_circuit_weight("test_circuit")
        };
        assert_eq!(result, 12345678);
    }

    #[test]
    fn test_read_circuit_weight_missing_file() {
        let _lock = lock_dir();
        let temp_dir = temp_build_dir(&[]);
        assert_panics_with(temp_dir.path(), "Could not read weight file", || {
            read_circuit_weight("nonexistent_circuit")
        });
    }

    #[test]
    fn test_read_circuit_weight_invalid_json() {
        let _lock = lock_dir();
        let temp_dir = temp_build_dir(&[("invalid.weight", "not valid json")]);
        assert_panics_with(temp_dir.path(), "Failed to parse .weight JSON", || {
            read_circuit_weight("invalid")
        });
    }

    #[test]
    fn test_read_circuit_weight_missing_field() {
        let _lock = lock_dir();
        let temp_dir = temp_build_dir(&[("no_weight.weight", r#"{"other": 123}"#)]);
        assert_panics_with(temp_dir.path(), "Missing or invalid 'weight' field", || {
            read_circuit_weight("no_weight")
        });
    }

    #[test]
    fn test_read_circuit_hash_valid() {
        let _lock = lock_dir();
        let expected_hash: [u8; 32] = [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
            25, 26, 27, 28, 29, 30, 31, 32,
        ];
        let contents = serde_json::to_string(&expected_hash).unwrap();
        let temp_dir = temp_build_dir(&[("test_circuit.hash", contents.as_str())]);
        let result = {
            let _cwd = CwdGuard::enter(temp_dir.path());
            read_circuit_hash("test_circuit")
        };
        assert_eq!(result, expected_hash);
    }

    #[test]
    fn test_read_circuit_hash_missing_file() {
        let _lock = lock_dir();
        let temp_dir = temp_build_dir(&[]);
        assert_panics_with(temp_dir.path(), "Could not read hash file", || {
            read_circuit_hash("nonexistent_circuit")
        });
    }

    #[test]
    fn test_read_circuit_hash_invalid_json() {
        let _lock = lock_dir();
        let temp_dir = temp_build_dir(&[("invalid.hash", "not valid json")]);
        assert_panics_with(temp_dir.path(), "Failed to parse .hash JSON", || {
            read_circuit_hash("invalid")
        });
    }

    #[test]
    fn test_read_circuit_hash_wrong_size() {
        let _lock = lock_dir();
        let wrong_size: [u8; 16] = [1; 16];
        let contents = serde_json::to_string(&wrong_size).unwrap();
        let temp_dir = temp_build_dir(&[("wrong_size.hash", contents.as_str())]);
        assert_panics_with(temp_dir.path(), "Failed to parse .hash JSON", || {
            read_circuit_hash("wrong_size")
        });
    }
}