use crate::utils::read_conf_ix_interface;
use arcis_interface::{CircuitInterface, ScalarKind, Value};
use convert_case::{Case, Casing};
use std::collections::HashSet;
/// Bit width of a scalar that `EncryptionComponents` treats as an encryption nonce
/// (a 16-byte value).
const NONCE_SIZE_BITS: usize = 16 * 8;
fn value_size_in_bytes(val: &Value) -> usize {
match val {
Value::Ciphertext { .. } | Value::MScalar { .. } | Value::MFloat { .. } | Value::MBool => {
32
}
Value::Scalar { size_in_bits, .. } => *size_in_bits / 8,
Value::Float { size_in_bits } => *size_in_bits / 8,
Value::Bool => 1,
Value::ArcisX25519Pubkey => 32,
Value::Point => 32,
Value::Array(inner) => {
if inner.is_empty() {
0
} else {
value_size_in_bytes(&inner[0]) * inner.len()
}
}
Value::Tuple(inner) => inner.iter().map(value_size_in_bytes).sum(),
Value::Struct(inner) => {
match EncryptionComponents::from_value(&Value::Struct(inner.clone())).get_type() {
Some((EncryptionType::Shared, len)) => 32 + 16 + len * 32,
Some((EncryptionType::Mxe, len)) => 16 + len * 32,
Some((EncryptionType::EncData, len)) => len * 32,
None => inner.iter().map(value_size_in_bytes).sum(),
}
}
}
}
#[derive(Debug)]
struct EncryptionComponents {
has_public_key: bool,
has_nonce: bool,
ciphertext_count: usize,
}
impl EncryptionComponents {
fn new() -> Self {
Self {
has_public_key: false,
has_nonce: false,
ciphertext_count: 0,
}
}
fn from_value(value: &Value) -> Self {
let mut components = Self::new();
components.extract_from(value);
components
}
fn get_type(&self) -> Option<(EncryptionType, usize)> {
match (
self.has_public_key,
self.has_nonce,
self.ciphertext_count > 0,
) {
(true, true, true) => Some((EncryptionType::Shared, self.ciphertext_count)),
(false, true, true) => Some((EncryptionType::Mxe, self.ciphertext_count)),
(false, false, true) => Some((EncryptionType::EncData, self.ciphertext_count)),
_ => None,
}
}
fn extract_from(&mut self, value: &Value) {
match value {
Value::ArcisX25519Pubkey => {
self.has_public_key = true;
}
Value::Scalar {
size_in_bits: NONCE_SIZE_BITS,
..
} => {
self.has_nonce = true;
}
Value::Ciphertext { .. } => {
self.ciphertext_count += 1;
}
Value::Array(values) if !values.is_empty() => {
for v in values {
self.extract_from(v);
}
}
Value::Struct(values) | Value::Tuple(values) => {
for v in values {
self.extract_from(v);
}
}
_ => {}
}
}
}
/// Generates the `{Name}Output` struct for the confidential instruction `conf_ix_name`.
///
/// `{Name}` is the interface name in PascalCase. The result also contains:
/// * every custom struct the output fields refer to (see `gen_all_custom_structs`),
/// * a `SIZE` constant, both inherent and via `::arcium_anchor::HasSize`, equal to the
///   summed byte size of all outputs.
///
/// Output `i` becomes field `field_{i}`.
pub fn gen_callback_output_struct(conf_ix_name: &str) -> proc_macro2::TokenStream {
    let iface: CircuitInterface = read_conf_ix_interface(conf_ix_name);
    // Computed once: it prefixes the output struct name and every field type name.
    let pascal_name = iface.name.to_case(Case::Pascal);
    let struct_name = syn::Ident::new(
        &format!("{}Output", pascal_name),
        proc_macro2::Span::call_site(),
    );
    let custom_structs = gen_all_custom_structs(&iface);
    let fields = iface.outputs.iter().enumerate().map(|(i, val)| {
        let field_name = syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site());
        let ty = value_to_type_for_output(val, &pascal_name, i);
        quote::quote! { pub #field_name: #ty }
    });
    let total_size: usize = iface.outputs.iter().map(value_size_in_bytes).sum();
    quote::quote! {
        #(#custom_structs)*
        #[derive(AnchorSerialize, AnchorDeserialize)]
        pub struct #struct_name {
            #(#fields),*
        }
        impl #struct_name {
            pub const SIZE: usize = #total_size;
        }
        impl ::arcium_anchor::HasSize for #struct_name {
            const SIZE: usize = #total_size;
        }
    }
}
fn value_to_type(val: &Value) -> Vec<proc_macro2::TokenStream> {
match val {
Value::Ciphertext { .. } | Value::MScalar { .. } | Value::MFloat { .. } | Value::MBool => {
vec![quote::quote!([u8; 32])]
}
Value::Scalar { size_in_bits, kind } => match kind {
ScalarKind::Unsigned => match size_in_bits {
8 => vec![quote::quote!(u8)],
16 => vec![quote::quote!(u16)],
32 => vec![quote::quote!(u32)],
64 => vec![quote::quote!(u64)],
128 => vec![quote::quote!(u128)],
_ => panic!(
"Unsupported unsigned integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
size_in_bits
),
},
ScalarKind::Signed => match size_in_bits {
8 => vec![quote::quote!(i8)],
16 => vec![quote::quote!(i16)],
32 => vec![quote::quote!(i32)],
64 => vec![quote::quote!(i64)],
128 => vec![quote::quote!(i128)],
_ => panic!(
"Unsupported signed integer size: {} bits. Supported sizes are: 8, 16, 32, 64, 128",
size_in_bits
),
},
},
Value::Float { size_in_bits } => match size_in_bits {
32 => vec![quote::quote!(f32)],
64 => vec![quote::quote!(f64)],
_ => panic!(
"Unsupported float size: {} bits. Supported sizes are: 32 (f32), 64 (f64)",
size_in_bits
),
},
Value::Bool => vec![quote::quote!(bool)],
Value::ArcisX25519Pubkey => vec![quote::quote!([u8; 32])],
Value::Point => vec![quote::quote!([u8; 32])],
Value::Array(inner) => {
let len = inner.len();
if len == 0 {
return vec![];
}
let tys = value_to_type(&inner[0]);
if tys.is_empty() {
vec![]
} else {
vec![quote::quote!([#(tys[0]); #len])]
}
}
Value::Tuple(inner) => {
let tys = inner.iter().flat_map(value_to_type);
vec![quote::quote!((#(#tys),*))]
}
Value::Struct(inner) => {
match extract_and_get_encryption_type(&Value::Struct(inner.clone())) {
Some((EncryptionType::Shared, len)) => vec![quote::quote! {
SharedEncryptedStruct<#len>
}],
Some((EncryptionType::Mxe, len)) => vec![quote::quote! {
MXEEncryptedStruct<#len>
}],
Some((EncryptionType::EncData, len)) => vec![quote::quote! {
EncDataStruct<#len>
}],
None => {
inner.iter().flat_map(value_to_type).collect()
}
}
}
}
}
/// Classifies `value` as an encrypted-payload layout and returns it with its ciphertext
/// count. Returns `None` for a plain value.
///
/// This is shorthand for `EncryptionComponents::from_value(value).get_type()`.
fn extract_and_get_encryption_type(value: &Value) -> Option<(EncryptionType, usize)> {
    let components = EncryptionComponents::from_value(value);
    components.get_type()
}
/// Encrypted-payload layouts recognized by `EncryptionComponents::get_type`.
///
/// The enum is fieldless, so it is also `Copy` and `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EncryptionType {
    /// Public key + nonce + ciphertexts; emitted as `SharedEncryptedStruct<N>`.
    Shared,
    /// Nonce + ciphertexts (no public key); emitted as `MXEEncryptedStruct<N>`.
    Mxe,
    /// Ciphertexts only; emitted as `EncDataStruct<N>`.
    EncData,
}
/// Returns the definitions of every custom struct the outputs of `iface` need.
///
/// The search for output `i` is rooted at the name `{Name}OutputStruct{i}`, where `{Name}`
/// is `iface.name` in PascalCase. A single set of emitted names is shared across all
/// outputs, so no definition is produced twice.
pub fn gen_all_custom_structs(iface: &CircuitInterface) -> Vec<proc_macro2::TokenStream> {
    let pascal_name = iface.name.to_case(Case::Pascal);
    let mut emitted_names = HashSet::new();
    let mut definitions = Vec::new();
    iface
        .outputs
        .iter()
        .enumerate()
        .for_each(|(idx, output)| {
            let root_name = format!("{}OutputStruct{}", pascal_name, idx);
            collect_structs(output, &mut emitted_names, &mut definitions, &root_name);
        });
    definitions
}
/// Builds a nested struct name by appending the decimal index `seen_count` to `prefix`.
///
/// For example, `("Foo", 3)` gives `"Foo3"`.
fn generate_nested_struct_name(prefix: &str, seen_count: usize) -> String {
    let mut name = String::from(prefix);
    name.push_str(&seen_count.to_string());
    name
}
/// Appends to `structs` a definition for every plain (non-encrypted) struct or tuple found
/// in `val`, recursing into their fields and into array elements.
///
/// Naming:
/// * `prefix` is the name for a struct or tuple located at this position.
/// * A struct takes `prefix` as its name when `prefix` contains `"OutputStruct"` and has
///   not been emitted yet; otherwise its name is `prefix` followed by `0`.
/// * Fields of a generated struct or tuple named `N` are visited with prefix
///   `N{field_index}`. This matches the name that
///   `value_to_type_for_structs_with_index(.., N, .., field_index)` refers to.
///
/// Each definition derives `AnchorSerialize`/`AnchorDeserialize` and carries a `SIZE`
/// constant, both inherent and via `::arcium_anchor::HasSize`. `seen` holds the names
/// already emitted, so no name is emitted twice.
fn collect_structs(
    val: &Value,
    seen: &mut HashSet<String>,
    structs: &mut Vec<proc_macro2::TokenStream>,
    prefix: &str,
) {
    match val {
        Value::Struct(inner) => {
            // Encrypted payloads are represented by the generic encrypted-struct types.
            if extract_and_get_encryption_type(val).is_some() {
                return;
            }
            let struct_name = if prefix.contains("OutputStruct") && !seen.contains(prefix) {
                prefix.to_string()
            } else {
                generate_nested_struct_name(prefix, 0)
            };
            emit_struct_with_fields(&struct_name, inner, seen, structs);
        }
        Value::Array(inner) => {
            // The array's type refers to its element (struct or tuple) under the array's own
            // prefix, so the element must be generated under that prefix as well. Tuples are
            // included: skipping them left the referenced struct undefined.
            if let Some(first) = inner.first() {
                collect_structs(first, seen, structs, prefix);
            }
        }
        Value::Tuple(inner) => emit_struct_with_fields(prefix, inner, seen, structs),
        _ => {}
    }
}

/// Emits the struct `struct_name`, unless that name is already in `seen`.
///
/// The struct gets one `pub field_{i}` per entry of `fields`. Afterwards each field is
/// visited with prefix `{struct_name}{i}` to emit any structs nested inside it.
fn emit_struct_with_fields(
    struct_name: &str,
    fields: &[Value],
    seen: &mut HashSet<String>,
    structs: &mut Vec<proc_macro2::TokenStream>,
) {
    if !seen.insert(struct_name.to_string()) {
        return;
    }
    let ident = syn::Ident::new(struct_name, proc_macro2::Span::call_site());
    let field_defs: Vec<_> = fields
        .iter()
        .enumerate()
        .map(|(i, v)| {
            let field_name =
                syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site());
            let ty = value_to_type_for_structs_with_index(v, struct_name, seen, i);
            quote::quote! { pub #field_name: #ty }
        })
        .collect();
    let struct_size: usize = fields.iter().map(value_size_in_bytes).sum();
    structs.push(quote::quote! {
        #[derive(AnchorSerialize, AnchorDeserialize)]
        pub struct #ident {
            #(#field_defs),*
        }
        impl #ident {
            pub const SIZE: usize = #struct_size;
        }
        impl ::arcium_anchor::HasSize for #ident {
            const SIZE: usize = #struct_size;
        }
    });
    // Emit nested definitions after the parent.
    for (i, v) in fields.iter().enumerate() {
        collect_structs(v, seen, structs, &generate_nested_struct_name(struct_name, i));
    }
}
/// Shorthand for `value_to_type_for_structs_with_index` with field index `0`.
// Kept as a convenience wrapper; it may have no remaining call sites, hence the allow.
#[allow(dead_code)]
fn value_to_type_for_structs(
    val: &Value,
    prefix: &str,
    seen: &mut HashSet<String>,
) -> proc_macro2::TokenStream {
    const FIRST_FIELD: usize = 0;
    value_to_type_for_structs_with_index(val, prefix, seen, FIRST_FIELD)
}
fn value_to_type_for_structs_with_index(
val: &Value,
prefix: &str,
seen: &mut HashSet<String>,
index: usize,
) -> proc_macro2::TokenStream {
match val {
Value::Struct(inner) => {
match extract_and_get_encryption_type(&Value::Struct(inner.clone())) {
Some((EncryptionType::Shared, len)) => {
quote::quote! { SharedEncryptedStruct<#len> }
}
Some((EncryptionType::Mxe, len)) => quote::quote! { MXEEncryptedStruct<#len> },
Some((EncryptionType::EncData, len)) => quote::quote! { EncDataStruct<#len> },
None => {
let struct_name = format!("{}{}", prefix, index);
let ident = syn::Ident::new(&struct_name, proc_macro2::Span::call_site());
quote::quote! { #ident }
}
}
}
Value::Array(inner) => {
if inner.is_empty() {
quote::quote!([(); 0])
} else {
let ty = value_to_type_for_structs(&inner[0], prefix, seen);
let len = inner.len();
quote::quote!([#ty; #len])
}
}
Value::Tuple(_) => {
panic!("Tuple in circuit return type is not supported.");
}
_ => value_to_type(val)
.into_iter()
.next()
.unwrap_or_else(|| quote::quote!(())),
}
}
/// Returns the type of field `field_{i}` of the top-level output struct.
///
/// * A plain struct or tuple is referenced as `{prefix}OutputStruct{i}`, the root name
///   `gen_all_custom_structs` uses for output `i`.
/// * An encrypted struct maps to `SharedEncryptedStruct<N>`, `MXEEncryptedStruct<N>` or
///   `EncDataStruct<N>`.
/// * An array becomes `[T; len]`, typed from its first element; an empty array becomes
///   `[(); 0]`.
/// * Everything else is delegated to `value_to_type`, falling back to `()`.
fn value_to_type_for_output(val: &Value, prefix: &str, i: usize) -> proc_macro2::TokenStream {
    // Type reference shared by the plain-struct and tuple cases.
    let generated_struct = || {
        let ident = syn::Ident::new(
            &format!("{}OutputStruct{}", prefix, i),
            proc_macro2::Span::call_site(),
        );
        quote::quote! { #ident }
    };
    match val {
        // Classify `val` directly instead of rebuilding a deep-cloned `Value::Struct`.
        Value::Struct(_) => match extract_and_get_encryption_type(val) {
            Some((EncryptionType::Shared, len)) => quote::quote! {
                SharedEncryptedStruct<#len>
            },
            Some((EncryptionType::Mxe, len)) => quote::quote! {
                MXEEncryptedStruct<#len>
            },
            Some((EncryptionType::EncData, len)) => quote::quote! {
                EncDataStruct<#len>
            },
            None => generated_struct(),
        },
        Value::Array(inner) => match inner.first() {
            None => quote::quote!([(); 0]),
            Some(first) => {
                let ty = value_to_type_for_output(first, prefix, i);
                let len = inner.len();
                quote::quote!([#ty; #len])
            }
        },
        Value::Tuple(_) => generated_struct(),
        _ => value_to_type(val)
            .into_iter()
            .next()
            .unwrap_or_else(|| quote::quote!(())),
    }
}
// Unit tests for encryption-layout detection, custom-struct generation and naming,
// output type mapping, and serialized-size computation.
#[cfg(test)]
mod tests {
    use super::*;

    // Each encrypted layout (Shared / Mxe / EncData) is recognized with the right ciphertext
    // count; a struct without ciphertexts is not classified as encrypted.
    #[test]
    fn test_semantic_encryption_detection() {
        // Public key + 128-bit scalar (nonce) + one ciphertext => Shared.
        let shared_v = Value::Struct(vec![
            Value::Struct(vec![
                Value::ArcisX25519Pubkey,
                Value::Scalar {
                    size_in_bits: 128,
                    kind: ScalarKind::Unsigned,
                },
            ]),
            Value::Struct(vec![
                Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
                Value::Array(vec![]),
            ]),
        ]);
        assert_eq!(
            extract_and_get_encryption_type(&shared_v),
            Some((EncryptionType::Shared, 1))
        );
        // 128-bit scalar (nonce) + two ciphertexts, no public key => Mxe.
        let mxe_v = Value::Struct(vec![
            Value::Struct(vec![Value::Scalar {
                size_in_bits: 128,
                kind: ScalarKind::Unsigned,
            }]),
            Value::Struct(vec![
                Value::Array(vec![
                    Value::Ciphertext { size_in_bits: 255 },
                    Value::Ciphertext { size_in_bits: 255 },
                ]),
                Value::Array(vec![]),
            ]),
        ]);
        assert_eq!(
            extract_and_get_encryption_type(&mxe_v),
            Some((EncryptionType::Mxe, 2))
        );
        // Ciphertexts only => EncData.
        let enc_data_v = Value::Struct(vec![Value::Struct(vec![
            Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
            Value::Array(vec![]),
        ])]);
        assert_eq!(
            extract_and_get_encryption_type(&enc_data_v),
            Some((EncryptionType::EncData, 1))
        );
        // No ciphertexts => not an encrypted layout.
        let normal_v = Value::Struct(vec![
            Value::Scalar {
                size_in_bits: 32,
                kind: ScalarKind::Unsigned,
            },
            Value::Bool,
        ]);
        assert_eq!(extract_and_get_encryption_type(&normal_v), None);
    }

    // A plain struct output and a tuple output each produce exactly one generated struct.
    #[test]
    fn test_gen_all_custom_structs() {
        let iface = CircuitInterface {
            name: "TestInterface".to_string(),
            inputs: vec![],
            outputs: vec![
                Value::Struct(vec![
                    Value::Scalar {
                        size_in_bits: 32,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Bool,
                ]),
                Value::Tuple(vec![
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Float { size_in_bits: 32 },
                ]),
            ],
        };
        let structs = gen_all_custom_structs(&iface);
        assert_eq!(structs.len(), 2);
    }

    // An empty tuple output still yields a named, field-less struct, and the output field
    // type refers to that same name.
    #[test]
    fn test_empty_tuple_handling() {
        let iface = CircuitInterface {
            name: "ManticoreAuc".to_string(),
            inputs: vec![],
            outputs: vec![Value::Tuple(vec![])],
        };
        let structs = gen_all_custom_structs(&iface);
        assert_eq!(structs.len(), 1);
        let struct_code = structs[0].to_string();
        assert!(struct_code.contains("pub struct ManticoreAucOutputStruct0"));
        assert!(struct_code.contains("# [derive (AnchorSerialize , AnchorDeserialize)]"));
        assert!(!struct_code.contains("pub field_"));
        let output_type = value_to_type_for_output(&Value::Tuple(vec![]), "ManticoreAuc", 0);
        assert_eq!(output_type.to_string(), "ManticoreAucOutputStruct0");
    }

    // Tuple output 0 holds an Mxe-encrypted struct (field 0, no struct generated) and a plain
    // struct (field 1). The plain struct is named `InsertOrderOutputStruct0` + `1`, and the
    // parent struct refers to it by that name.
    #[test]
    fn test_naming_consistency_for_nested_structs() {
        let iface = CircuitInterface {
            name: "InsertOrder".to_string(),
            inputs: vec![],
            outputs: vec![Value::Tuple(vec![
                Value::Struct(vec![
                    Value::Struct(vec![Value::Scalar {
                        size_in_bits: 128,
                        kind: ScalarKind::Unsigned,
                    }]),
                    Value::Struct(vec![
                        Value::Array(vec![
                            Value::Ciphertext { size_in_bits: 255 },
                            Value::Ciphertext { size_in_bits: 255 },
                        ]),
                        Value::Array(vec![]),
                    ]),
                ]),
                Value::Struct(vec![
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                ]),
            ])],
        };
        let structs = gen_all_custom_structs(&iface);
        assert_eq!(structs.len(), 2);
        let struct_strings: Vec<String> = structs.iter().map(|s| s.to_string()).collect();
        assert!(struct_strings
            .iter()
            .any(|s| s.contains("pub struct InsertOrderOutputStruct0")));
        assert!(struct_strings
            .iter()
            .any(|s| s.contains("pub struct InsertOrderOutputStruct01")));
        // "pub struct InsertOrderOutputStruct0" is also a prefix of the nested struct's
        // declaration. `find` returns the parent because the parent is emitted first.
        let main_struct = struct_strings
            .iter()
            .find(|s| s.contains("pub struct InsertOrderOutputStruct0"))
            .expect("Main output struct should exist");
        assert!(main_struct.contains("InsertOrderOutputStruct01"));
    }

    // The output field type is a generated struct name for plain structs and tuples, and the
    // generic encrypted type for encrypted layouts.
    #[test]
    fn test_value_to_type_for_output() {
        let val = Value::Struct(vec![
            Value::Scalar {
                size_in_bits: 32,
                kind: ScalarKind::Unsigned,
            },
            Value::Bool,
        ]);
        let ty = value_to_type_for_output(&val, "Test", 0);
        assert!(!ty.to_string().is_empty());
        let val = Value::Tuple(vec![
            Value::Scalar {
                size_in_bits: 64,
                kind: ScalarKind::Unsigned,
            },
            Value::Float { size_in_bits: 32 },
        ]);
        let ty = value_to_type_for_output(&val, "Test", 0);
        assert!(!ty.to_string().is_empty());
        // Shared layout.
        let val = Value::Struct(vec![
            Value::Struct(vec![
                Value::ArcisX25519Pubkey,
                Value::Scalar {
                    size_in_bits: 128,
                    kind: ScalarKind::Unsigned,
                },
            ]),
            Value::Struct(vec![
                Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
                Value::Array(vec![]),
            ]),
        ]);
        let ty = value_to_type_for_output(&val, "Test", 0);
        assert!(ty.to_string().contains("SharedEncryptedStruct"));
        // EncData layout.
        let val = Value::Struct(vec![Value::Struct(vec![
            Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
            Value::Array(vec![]),
        ])]);
        let ty = value_to_type_for_output(&val, "Test", 0);
        assert!(ty.to_string().contains("EncDataStruct"));
    }

    // Nested names are the prefix followed by the decimal index.
    #[test]
    fn test_generate_nested_struct_name() {
        assert_eq!(
            generate_nested_struct_name("ComplexExample", 0),
            "ComplexExample0"
        );
        assert_eq!(
            generate_nested_struct_name("ComplexExample", 5),
            "ComplexExample5"
        );
        assert_eq!(
            generate_nested_struct_name("InsertOrderOutputStruct0", 1),
            "InsertOrderOutputStruct01"
        );
        assert_eq!(
            generate_nested_struct_name("ComplexExampleOutputStruct0", 2),
            "ComplexExampleOutputStruct02"
        );
    }

    // Outer struct + nested struct => 2 definitions. Only the array's first element (a
    // scalar) is inspected, so the array contributes no struct.
    #[test]
    fn test_struct_generation_with_nested_types() {
        let iface = CircuitInterface {
            name: "NestedTest".to_string(),
            inputs: vec![],
            outputs: vec![Value::Struct(vec![
                Value::Struct(vec![
                    Value::Scalar {
                        size_in_bits: 32,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Bool,
                ]),
                Value::Array(vec![
                    Value::Scalar {
                        size_in_bits: 64,
                        kind: ScalarKind::Unsigned,
                    },
                    Value::Float { size_in_bits: 32 },
                ]),
            ])],
        };
        let structs = gen_all_custom_structs(&iface);
        assert_eq!(structs.len(), 2);
    }

    // Primitive sizes: bits / 8 for scalars and floats, 1 for bool, 32 for ciphertexts.
    #[test]
    fn test_value_size_in_bytes_primitives() {
        assert_eq!(value_size_in_bytes(&Value::Bool), 1);
        assert_eq!(
            value_size_in_bytes(&Value::Scalar {
                size_in_bits: 32,
                kind: ScalarKind::Unsigned
            }),
            4
        );
        assert_eq!(
            value_size_in_bytes(&Value::Scalar {
                size_in_bits: 64,
                kind: ScalarKind::Unsigned
            }),
            8
        );
        assert_eq!(value_size_in_bytes(&Value::Float { size_in_bits: 32 }), 4);
        assert_eq!(
            value_size_in_bytes(&Value::Ciphertext { size_in_bits: 255 }),
            32
        );
    }

    // Composite sizes: a tuple is the sum of its parts (8 + 1 + 4), an array is element *
    // length (4 * 5), and a plain struct is the sum of its fields (4 + 1).
    #[test]
    fn test_value_size_in_bytes_composites() {
        assert_eq!(
            value_size_in_bytes(&Value::Tuple(vec![
                Value::Scalar {
                    size_in_bits: 64,
                    kind: ScalarKind::Unsigned
                },
                Value::Bool,
                Value::Float { size_in_bits: 32 },
            ])),
            13
        );
        assert_eq!(
            value_size_in_bytes(&Value::Array(vec![
                Value::Scalar {
                    size_in_bits: 32,
                    kind: ScalarKind::Unsigned
                };
                5
            ])),
            20
        );
        assert_eq!(
            value_size_in_bytes(&Value::Struct(vec![
                Value::Scalar {
                    size_in_bits: 32,
                    kind: ScalarKind::Unsigned
                },
                Value::Bool,
            ])),
            5
        );
    }

    // Encrypted layouts use fixed sizes:
    // Shared = 32 (key) + 16 (nonce) + 32 * 1 = 80; Mxe = 16 (nonce) + 32 * 2 = 80.
    #[test]
    fn test_value_size_in_bytes_encryption_types() {
        let shared_enc = Value::Struct(vec![
            Value::Struct(vec![
                Value::ArcisX25519Pubkey,
                Value::Scalar {
                    size_in_bits: 128,
                    kind: ScalarKind::Unsigned,
                },
            ]),
            Value::Struct(vec![
                Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
                Value::Array(vec![]),
            ]),
        ]);
        assert_eq!(value_size_in_bytes(&shared_enc), 80);
        let mxe_enc = Value::Struct(vec![
            Value::Struct(vec![Value::Scalar {
                size_in_bits: 128,
                kind: ScalarKind::Unsigned,
            }]),
            Value::Struct(vec![
                Value::Array(vec![
                    Value::Ciphertext { size_in_bits: 255 },
                    Value::Ciphertext { size_in_bits: 255 },
                ]),
                Value::Array(vec![]),
            ]),
        ]);
        assert_eq!(value_size_in_bytes(&mxe_enc), 80);
    }

    // The generated struct's SIZE constant equals the summed field sizes (u64 + bool = 9).
    #[test]
    fn test_generated_struct_has_correct_size_constant() {
        let iface = CircuitInterface {
            name: "SizeTest".to_string(),
            inputs: vec![],
            outputs: vec![Value::Tuple(vec![
                Value::Scalar {
                    size_in_bits: 64,
                    kind: ScalarKind::Unsigned,
                },
                Value::Bool,
            ])],
        };
        let structs = gen_all_custom_structs(&iface);
        let struct_code = structs[0].to_string();
        assert!(struct_code.contains("const SIZE : usize = 9usize"));
    }
}