use crate::utils::read_conf_ix_interface;
use arcis_interface::{CircuitInterface, Value};
use convert_case::{Case, Casing};
use std::collections::HashSet;
/// Tally of the encryption-related leaves reachable from an interface [`Value`].
///
/// The tally is used to recognise values that map onto the
/// `SharedEncryptedStruct`, `MXEEncryptedStruct` or `EncDataStruct` wrapper types.
#[derive(Debug)]
struct EncryptionComponents {
    /// Set when any `Value::PublicKey` leaf was seen.
    has_public_key: bool,
    /// Set when any 128-bit `Value::Scalar` leaf (treated as the nonce) was seen.
    has_nonce: bool,
    /// Number of `Value::Ciphertext` leaves seen.
    ciphertext_count: usize,
}

impl EncryptionComponents {
    /// Returns an empty tally (nothing seen yet).
    fn new() -> Self {
        EncryptionComponents {
            has_public_key: false,
            has_nonce: false,
            ciphertext_count: 0,
        }
    }

    /// Builds a tally by visiting every leaf reachable from `value`.
    fn from_value(value: &Value) -> Self {
        let mut tally = Self::new();
        tally.extract_from(value);
        tally
    }

    /// Classifies the tally.
    ///
    /// Returns the wrapper kind together with the ciphertext count, or `None`
    /// when no ciphertext was seen or a public key appeared without a nonce.
    fn get_type(&self) -> Option<(EncryptionType, usize)> {
        if self.ciphertext_count == 0 {
            return None;
        }
        let kind = match (self.has_public_key, self.has_nonce) {
            (true, true) => EncryptionType::Shared,
            (false, true) => EncryptionType::Mxe,
            (false, false) => EncryptionType::EncData,
            (true, false) => return None,
        };
        Some((kind, self.ciphertext_count))
    }

    /// Visits `value` and all of its array/struct/tuple descendants, recording
    /// public keys, 128-bit scalars and ciphertexts into `self`.
    fn extract_from(&mut self, value: &Value) {
        // Explicit work-list instead of recursion; visit order is irrelevant
        // because only flags and a count are accumulated.
        let mut pending: Vec<&Value> = vec![value];
        while let Some(current) = pending.pop() {
            match current {
                Value::PublicKey { .. } => self.has_public_key = true,
                Value::Scalar { size_in_bits: 128 } => self.has_nonce = true,
                Value::Ciphertext { .. } => self.ciphertext_count += 1,
                Value::Array(children) => pending.extend(children.iter()),
                Value::Struct(children) => pending.extend(children.iter()),
                Value::Tuple(children) => pending.extend(children.iter()),
                _ => {}
            }
        }
    }
}
/// Generates the `{Name}Output` callback struct for the confidential instruction
/// `conf_ix_name`, preceded by every helper struct its outputs reference.
///
/// `{Name}` is the Pascal-cased interface name. Output `i` becomes the field
/// `field_{i}`. Its type comes from `value_to_type_for_output`, and the helper
/// struct definitions come from `gen_all_custom_structs`.
pub fn gen_callback_output_struct(conf_ix_name: &str) -> proc_macro2::TokenStream {
    let iface: CircuitInterface = read_conf_ix_interface(conf_ix_name);
    // Computed once: the struct name and every field type derive from it.
    let pascal_name = iface.name.to_case(Case::Pascal);
    let struct_name = syn::Ident::new(
        &format!("{}Output", pascal_name),
        proc_macro2::Span::call_site(),
    );
    let custom_structs = gen_all_custom_structs(&iface);
    let fields = iface.outputs.iter().enumerate().map(|(i, val)| {
        let field_name = syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site());
        let ty = value_to_type_for_output(val, &pascal_name, i);
        quote::quote! { pub #field_name: #ty }
    });
    quote::quote! {
        #(#custom_structs)*
        #[derive(AnchorSerialize, AnchorDeserialize)]
        pub struct #struct_name {
            #(#fields),*
        }
    }
}
/// Maps an interface [`Value`] to the Rust type token(s) used to deserialize it.
///
/// Returns a `Vec` because a non-encrypted `Value::Struct` is flattened into the
/// types of its fields, and an empty array (or an array whose element has no
/// type) yields no type at all. Encrypted structs map onto the
/// `SharedEncryptedStruct` / `MXEEncryptedStruct` / `EncDataStruct` wrappers,
/// parameterised by their ciphertext count.
///
/// # Panics
/// Panics on scalar widths other than 8/16/32/64/128 bits and on float widths
/// other than 32/64 bits.
fn value_to_type(val: &Value) -> Vec<proc_macro2::TokenStream> {
    match val {
        Value::Ciphertext { .. } | Value::MScalar { .. } | Value::MFloat { .. } | Value::MBool => {
            vec![quote::quote!([u8; 32])]
        }
        Value::Scalar { size_in_bits } => match size_in_bits {
            8 => vec![quote::quote!(u8)],
            16 => vec![quote::quote!(u16)],
            32 => vec![quote::quote!(u32)],
            64 => vec![quote::quote!(u64)],
            128 => vec![quote::quote!(u128)],
            _ => panic!("Unsupported scalar size: {}", size_in_bits),
        },
        Value::Float { size_in_bits } => match size_in_bits {
            32 => vec![quote::quote!(f32)],
            64 => vec![quote::quote!(f64)],
            _ => panic!("Unsupported float size: {}", size_in_bits),
        },
        Value::Bool => vec![quote::quote!(bool)],
        Value::PublicKey { .. } => vec![quote::quote!([u8; 32])],
        Value::Point => vec![quote::quote!([u8; 32])],
        Value::Array(inner) => {
            // The element type is taken from the first element only.
            let elem_ty = inner
                .first()
                .and_then(|first| value_to_type(first).into_iter().next());
            match elem_ty {
                Some(elem_ty) => {
                    let len = inner.len();
                    // Interpolate through a binding: `quote!` only substitutes
                    // `#ident`, so the former `#(tys[0])` was emitted verbatim
                    // as the invalid tokens `# (tys [0])`.
                    vec![quote::quote!([#elem_ty; #len])]
                }
                None => vec![],
            }
        }
        Value::Tuple(inner) => {
            let tys = inner.iter().flat_map(value_to_type);
            vec![quote::quote!((#(#tys),*))]
        }
        // `val` is this very struct; no need to clone the subtree to inspect it.
        Value::Struct(inner) => match extract_and_get_encryption_type(val) {
            Some((EncryptionType::Shared, len)) => vec![quote::quote! {
                SharedEncryptedStruct<#len>
            }],
            Some((EncryptionType::Mxe, len)) => vec![quote::quote! {
                MXEEncryptedStruct<#len>
            }],
            Some((EncryptionType::EncData, len)) => vec![quote::quote! {
                EncDataStruct<#len>
            }],
            None => inner.iter().flat_map(value_to_type).collect(),
        },
    }
}
/// Classifies `value` as one of the encrypted wrapper shapes.
///
/// Returns the wrapper kind and its ciphertext count, or `None` for plain data.
fn extract_and_get_encryption_type(value: &Value) -> Option<(EncryptionType, usize)> {
    let components = EncryptionComponents::from_value(value);
    components.get_type()
}
/// Which encrypted wrapper type an interface value maps onto.
#[derive(PartialEq, Debug)]
enum EncryptionType {
    /// Public key + 128-bit nonce + ciphertexts; rendered as `SharedEncryptedStruct<N>`.
    Shared,
    /// 128-bit nonce + ciphertexts without a public key; rendered as `MXEEncryptedStruct<N>`.
    Mxe,
    /// Ciphertexts only; rendered as `EncDataStruct<N>`.
    EncData,
}
/// Generates every helper struct definition needed by the outputs of `iface`.
///
/// Output `idx` is rooted at the name `{PascalName}OutputStruct{idx}`, and
/// nested plain structs and tuples are discovered by `collect_structs`. A
/// single set of already-emitted names is shared across all outputs so that
/// each name is generated at most once.
pub fn gen_all_custom_structs(iface: &CircuitInterface) -> Vec<proc_macro2::TokenStream> {
    let pascal_name = iface.name.to_case(Case::Pascal);
    let mut emitted_names: HashSet<String> = HashSet::new();
    let mut definitions: Vec<proc_macro2::TokenStream> = Vec::new();
    iface
        .outputs
        .iter()
        .enumerate()
        .for_each(|(idx, output)| {
            let root_name = format!("{}OutputStruct{}", pascal_name, idx);
            collect_structs(output, &mut emitted_names, &mut definitions, &root_name);
        });
    definitions
}
/// Builds a nested struct name by appending the decimal form of `seen_count`
/// to `prefix`, e.g. `("Foo", 3)` → `"Foo3"`.
///
/// Despite the parameter name, callers in this module pass a field index (or `0`).
fn generate_nested_struct_name(prefix: &str, seen_count: usize) -> String {
    let suffix = seen_count.to_string();
    let mut name = String::with_capacity(prefix.len() + suffix.len());
    name.push_str(prefix);
    name.push_str(&suffix);
    name
}
/// Recursively collects the definitions of every plain (non-encrypted) struct
/// and tuple reachable from `val`, appending them to `structs`.
///
/// * `seen` — names already emitted; a name is emitted at most once.
/// * `structs` — output list; a parent is pushed before its children.
/// * `prefix` — the name this value's struct should take. A struct's or tuple's
///   `i`-th child is collected under `{name}{i}`, which is the name
///   `value_to_type_for_structs_with_index` references for that field.
///
/// Encrypted structs are skipped because they map onto predefined wrapper types.
/// For arrays, only the first element is inspected, and arrays of tuples are
/// deliberately not collected.
///
/// NOTE(review): names are plain index concatenations, so `{P}1` + `1` and
/// `{P}` + `11` both produce `{P}11`. The `seen` set then suppresses or renames
/// the later one — confirm real interfaces never hit this.
///
/// # Panics
/// Panics (via `value_to_type_for_structs_with_index`) when a tuple appears
/// as a field of a generated struct or tuple.
fn collect_structs(
    val: &Value,
    seen: &mut HashSet<String>,
    structs: &mut Vec<proc_macro2::TokenStream>,
    prefix: &str,
) {
    match val {
        Value::Struct(inner) => {
            // `val` is this very struct; no need to clone the subtree to inspect it.
            if extract_and_get_encryption_type(val).is_some() {
                return;
            }
            let struct_name = if prefix.contains("OutputStruct") && !seen.contains(prefix) {
                prefix.to_string()
            } else {
                generate_nested_struct_name(prefix, 0)
            };
            emit_struct_and_children(&struct_name, inner, seen, structs);
        }
        Value::Array(inner) => {
            // NOTE(review): `value_to_type_for_output` still names a top-level
            // array-of-tuple element `{Prefix}OutputStruct{i}`, but no struct is
            // emitted for it here — confirm that case is unsupported upstream.
            if let Some(first) = inner.first() {
                if !matches!(first, Value::Tuple(_)) {
                    collect_structs(first, seen, structs, prefix);
                }
            }
        }
        // A tuple always takes `prefix` verbatim as its struct name. An empty
        // tuple yields `pub struct Name { }`.
        Value::Tuple(inner) => emit_struct_and_children(prefix, inner, seen, structs),
        _ => {}
    }
}

/// Emits `pub struct {struct_name}` with one `field_{i}` per element of
/// `fields` (unless the name was already emitted). It then recurses into each
/// element under the prefix `{struct_name}{i}`.
fn emit_struct_and_children(
    struct_name: &str,
    fields: &[Value],
    seen: &mut HashSet<String>,
    structs: &mut Vec<proc_macro2::TokenStream>,
) {
    if !seen.insert(struct_name.to_string()) {
        return;
    }
    let ident = syn::Ident::new(struct_name, proc_macro2::Span::call_site());
    let field_types: Vec<_> = fields
        .iter()
        .enumerate()
        .map(|(i, v)| value_to_type_for_structs_with_index(v, struct_name, seen, i))
        .collect();
    let field_defs = field_types.iter().enumerate().map(|(i, ty)| {
        let field_name = syn::Ident::new(&format!("field_{}", i), proc_macro2::Span::call_site());
        quote::quote! { pub #field_name: #ty }
    });
    structs.push(quote::quote! {
        #[derive(AnchorSerialize, AnchorDeserialize)]
        pub struct #ident {
            #(#field_defs),*
        }
    });
    for (i, v) in fields.iter().enumerate() {
        let nested_prefix = generate_nested_struct_name(struct_name, i);
        collect_structs(v, seen, structs, &nested_prefix);
    }
}
/// Convenience wrapper: [`value_to_type_for_structs_with_index`] with index `0`.
// Kept for index-0 callers. The array arm below no longer routes through this
// wrapper because doing so discarded the caller's field index.
#[allow(dead_code)]
fn value_to_type_for_structs(
    val: &Value,
    prefix: &str,
    seen: &mut HashSet<String>,
) -> proc_macro2::TokenStream {
    value_to_type_for_structs_with_index(val, prefix, seen, 0)
}

/// Returns the field type for `val` when it is field `index` of the generated
/// struct named `prefix`.
///
/// Encrypted structs map onto the wrapper types. A plain struct maps onto
/// `{prefix}{index}`, the name `collect_structs` gives the struct it emits for
/// that field. Arrays keep the same `index` for their element, matching
/// `collect_structs`, which collects an array's element under the array's own
/// field prefix. Empty arrays become `[(); 0]`. `seen` is only threaded through.
///
/// # Panics
/// Panics on `Value::Tuple`, and on unsupported scalar/float widths (via
/// `value_to_type`).
fn value_to_type_for_structs_with_index(
    val: &Value,
    prefix: &str,
    seen: &mut HashSet<String>,
    index: usize,
) -> proc_macro2::TokenStream {
    match val {
        // `val` is this very struct; no need to clone the subtree to inspect it.
        Value::Struct(_) => match extract_and_get_encryption_type(val) {
            Some((EncryptionType::Shared, len)) => {
                quote::quote! { SharedEncryptedStruct<#len> }
            }
            Some((EncryptionType::Mxe, len)) => quote::quote! { MXEEncryptedStruct<#len> },
            Some((EncryptionType::EncData, len)) => quote::quote! { EncDataStruct<#len> },
            None => {
                let struct_name = format!("{}{}", prefix, index);
                let ident = syn::Ident::new(&struct_name, proc_macro2::Span::call_site());
                quote::quote! { #ident }
            }
        },
        Value::Array(inner) => match inner.first() {
            None => quote::quote!([(); 0]),
            Some(first) => {
                // Previously this recursed with index 0, so an array of structs
                // at field `i != 0` referenced `{prefix}0` instead of the
                // `{prefix}{i}` struct that `collect_structs` emits.
                let ty = value_to_type_for_structs_with_index(first, prefix, seen, index);
                let len = inner.len();
                quote::quote!([#ty; #len])
            }
        },
        Value::Tuple(_) => {
            panic!("Tuple in circuit return type is not supported.");
        }
        _ => value_to_type(val)
            .into_iter()
            .next()
            .unwrap_or_else(|| quote::quote!(())),
    }
}
/// Returns the type of top-level output `i` of the interface whose Pascal-cased
/// name is `prefix`.
///
/// Encrypted structs map onto the `SharedEncryptedStruct` / `MXEEncryptedStruct`
/// / `EncDataStruct` wrappers. Any other struct, and any tuple, maps onto
/// `{prefix}OutputStruct{i}`, the root name used by `gen_all_custom_structs`.
/// Arrays recurse into their first element with the same `i`, and empty arrays
/// become `[(); 0]`.
///
/// NOTE(review): for an array whose element is a tuple this still yields
/// `[{prefix}OutputStruct{i}; N]`, but `collect_structs` skips arrays of
/// tuples — confirm that shape cannot reach here.
///
/// # Panics
/// Panics on unsupported scalar/float widths (via `value_to_type`).
fn value_to_type_for_output(val: &Value, prefix: &str, i: usize) -> proc_macro2::TokenStream {
    // Name of the generated struct rooted at this output slot.
    let output_struct_type = || {
        let ident = syn::Ident::new(
            &format!("{}OutputStruct{}", prefix, i),
            proc_macro2::Span::call_site(),
        );
        quote::quote! { #ident }
    };
    match val {
        // `val` is this very struct; no need to clone the subtree to inspect it.
        Value::Struct(_) => match extract_and_get_encryption_type(val) {
            Some((EncryptionType::Shared, len)) => quote::quote! {
                SharedEncryptedStruct<#len>
            },
            Some((EncryptionType::Mxe, len)) => quote::quote! {
                MXEEncryptedStruct<#len>
            },
            Some((EncryptionType::EncData, len)) => quote::quote! {
                EncDataStruct<#len>
            },
            None => output_struct_type(),
        },
        Value::Array(inner) => match inner.first() {
            None => quote::quote!([(); 0]),
            Some(first) => {
                let ty = value_to_type_for_output(first, prefix, i);
                let len = inner.len();
                quote::quote!([#ty; #len])
            }
        },
        Value::Tuple(_) => output_struct_type(),
        _ => value_to_type(val)
            .into_iter()
            .next()
            .unwrap_or_else(|| quote::quote!(())),
    }
}
// Unit tests for encryption-shape detection, helper-struct generation and
// the naming scheme shared by `collect_structs` and the type mappers.
#[cfg(test)]
mod tests {
use super::*;
// Checks classification of the Shared / MXE / EncData shapes and that a plain
// struct (no ciphertexts) is not classified as encrypted.
#[test]
fn test_semantic_encryption_detection() {
// Shared: public key + 128-bit nonce + one ciphertext.
let shared_v = Value::Struct(vec![
Value::Struct(vec![
Value::PublicKey { size_in_bits: 255 },
Value::Scalar { size_in_bits: 128 },
]),
Value::Struct(vec![
Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
Value::Array(vec![]),
]),
]);
assert_eq!(
extract_and_get_encryption_type(&shared_v),
Some((EncryptionType::Shared, 1))
);
// MXE: 128-bit nonce + two ciphertexts, no public key.
let mxe_v = Value::Struct(vec![
Value::Struct(vec![Value::Scalar { size_in_bits: 128 }]),
Value::Struct(vec![
Value::Array(vec![
Value::Ciphertext { size_in_bits: 255 },
Value::Ciphertext { size_in_bits: 255 },
]),
Value::Array(vec![]),
]),
]);
assert_eq!(
extract_and_get_encryption_type(&mxe_v),
Some((EncryptionType::Mxe, 2))
);
// EncData: ciphertexts only.
let enc_data_v = Value::Struct(vec![Value::Struct(vec![
Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
Value::Array(vec![]),
])]);
assert_eq!(
extract_and_get_encryption_type(&enc_data_v),
Some((EncryptionType::EncData, 1))
);
let normal_v = Value::Struct(vec![Value::Scalar { size_in_bits: 32 }, Value::Bool]);
assert_eq!(extract_and_get_encryption_type(&normal_v), None);
}
// One plain struct output and one tuple output each yield one helper struct.
#[test]
fn test_gen_all_custom_structs() {
let iface = CircuitInterface {
name: "TestInterface".to_string(),
inputs: vec![],
outputs: vec![
Value::Struct(vec![Value::Scalar { size_in_bits: 32 }, Value::Bool]),
Value::Tuple(vec![
Value::Scalar { size_in_bits: 64 },
Value::Float { size_in_bits: 32 },
]),
],
};
let structs = gen_all_custom_structs(&iface);
assert_eq!(structs.len(), 2); }
// An empty tuple output still produces a (field-less) named struct, and the
// output type refers to that same name.
#[test]
fn test_empty_tuple_handling() {
let iface = CircuitInterface {
name: "ManticoreAuc".to_string(),
inputs: vec![],
outputs: vec![Value::Tuple(vec![])], };
let structs = gen_all_custom_structs(&iface);
assert_eq!(structs.len(), 1);
let struct_code = structs[0].to_string();
assert!(struct_code.contains("pub struct ManticoreAucOutputStruct0"));
assert!(struct_code.contains("# [derive (AnchorSerialize , AnchorDeserialize)]"));
assert!(!struct_code.contains("pub field_"));
let output_type = value_to_type_for_output(&Value::Tuple(vec![]), "ManticoreAuc", 0);
assert_eq!(output_type.to_string(), "ManticoreAucOutputStruct0");
}
// A tuple output holding an MXE struct (mapped to a wrapper, no helper struct)
// and a plain struct: expects the root tuple struct plus `{root}1`, and the
// root must reference `{root}1` by that exact name.
#[test]
fn test_naming_consistency_for_nested_structs() {
let iface = CircuitInterface {
name: "InsertOrder".to_string(),
inputs: vec![],
outputs: vec![Value::Tuple(vec![
Value::Struct(vec![
Value::Struct(vec![Value::Scalar { size_in_bits: 128 }]),
Value::Struct(vec![
Value::Array(vec![
Value::Ciphertext { size_in_bits: 255 },
Value::Ciphertext { size_in_bits: 255 },
]),
Value::Array(vec![]),
]),
]),
Value::Struct(vec![
Value::Scalar { size_in_bits: 64 }, Value::Scalar { size_in_bits: 64 }, Value::Scalar { size_in_bits: 64 }, Value::Scalar { size_in_bits: 64 }, Value::Scalar { size_in_bits: 64 }, Value::Scalar { size_in_bits: 64 }, ]),
])],
};
let structs = gen_all_custom_structs(&iface);
assert_eq!(structs.len(), 2);
let struct_strings: Vec<String> = structs.iter().map(|s| s.to_string()).collect();
assert!(struct_strings
.iter()
.any(|s| s.contains("pub struct InsertOrderOutputStruct0")));
assert!(struct_strings
.iter()
.any(|s| s.contains("pub struct InsertOrderOutputStruct01")));
// Relies on the parent being pushed before its children, so the first
// match is the root struct rather than `...Struct01`.
let main_struct = struct_strings
.iter()
.find(|s| s.contains("pub struct InsertOrderOutputStruct0"))
.expect("Main output struct should exist");
assert!(main_struct.contains("InsertOrderOutputStruct01"));
}
// Smoke-checks the top-level type mapping for plain struct, tuple, Shared and
// EncData outputs.
#[test]
fn test_value_to_type_for_output() {
let val = Value::Struct(vec![Value::Scalar { size_in_bits: 32 }, Value::Bool]);
let ty = value_to_type_for_output(&val, "Test", 0);
assert!(!ty.to_string().is_empty());
let val = Value::Tuple(vec![
Value::Scalar { size_in_bits: 64 },
Value::Float { size_in_bits: 32 },
]);
let ty = value_to_type_for_output(&val, "Test", 0);
assert!(!ty.to_string().is_empty());
let val = Value::Struct(vec![
Value::Struct(vec![
Value::PublicKey { size_in_bits: 255 },
Value::Scalar { size_in_bits: 128 },
]),
Value::Struct(vec![
Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
Value::Array(vec![]),
]),
]);
let ty = value_to_type_for_output(&val, "Test", 0);
assert!(ty.to_string().contains("SharedEncryptedStruct"));
let val = Value::Struct(vec![Value::Struct(vec![
Value::Array(vec![Value::Ciphertext { size_in_bits: 255 }]),
Value::Array(vec![]),
])]);
let ty = value_to_type_for_output(&val, "Test", 0);
assert!(ty.to_string().contains("EncDataStruct"));
}
// Nested names are plain concatenation of prefix and index.
#[test]
fn test_generate_nested_struct_name() {
assert_eq!(
generate_nested_struct_name("ComplexExample", 0),
"ComplexExample0"
);
assert_eq!(
generate_nested_struct_name("ComplexExample", 5),
"ComplexExample5"
);
assert_eq!(
generate_nested_struct_name("InsertOrderOutputStruct0", 1),
"InsertOrderOutputStruct01"
);
assert_eq!(
generate_nested_struct_name("ComplexExampleOutputStruct0", 2),
"ComplexExampleOutputStruct02"
);
}
// A struct output with a nested struct field and an array-of-tuple field;
// expects only the root and the nested struct to be emitted.
// NOTE(review): typing field 1 (`Array(Tuple(..))`) goes through
// `value_to_type_for_structs_with_index`, whose `Value::Tuple` arm panics
// ("Tuple in circuit return type is not supported."). As written this test
// appears to panic before the length assertion — confirm it actually passes.
#[test]
fn test_struct_generation_with_nested_types() {
let iface = CircuitInterface {
name: "NestedTest".to_string(),
inputs: vec![],
outputs: vec![Value::Struct(vec![
Value::Struct(vec![Value::Scalar { size_in_bits: 32 }, Value::Bool]),
Value::Array(vec![Value::Tuple(vec![
Value::Scalar { size_in_bits: 64 },
Value::Float { size_in_bits: 32 },
])]),
])],
};
let structs = gen_all_custom_structs(&iface);
assert_eq!(structs.len(), 2);
}
}