use crate::error::{AmateRSError, ErrorContext, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Coarse category recorded for an encrypted constant.
///
/// Unlike `EncryptedType`, this carries no bit width: `encrypt_circuit_constants`
/// records every integer width (u8/u16/u32/u64) as `ConstantType::Integer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConstantType {
    /// Integer constant of unspecified width; displayed as `integer`.
    Integer,
    /// Boolean constant; displayed as `boolean`.
    Boolean,
    /// Displayed as `float`.
    Float,
    /// Displayed as `bytes`.
    Bytes,
}
impl std::fmt::Display for ConstantType {
    /// Writes the lowercase category name (`integer`, `boolean`, `float`, `bytes`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ConstantType::Integer => "integer",
            ConstantType::Boolean => "boolean",
            ConstantType::Float => "float",
            ConstantType::Bytes => "bytes",
        };
        f.write_str(label)
    }
}
/// A node of a circuit expression tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitNode {
    /// Reads the named variable; its type comes from the circuit's variable table.
    Load(String),
    /// A plaintext constant.
    Constant(CircuitValue),
    /// A constant carried as opaque bytes.
    EncryptedConstant {
        /// Payload bytes. `encrypt_constant` writes a one-byte type tag followed by the
        /// ciphertext, but bytes passed in directly are not validated.
        data: Vec<u8>,
        /// Category of the value before encryption.
        original_type: ConstantType,
    },
    /// Arithmetic (`+`, `-`, `*`) or logical (`AND`, `OR`, `XOR`) operation.
    BinaryOp {
        op: BinaryOperator,
        left: Box<CircuitNode>,
        right: Box<CircuitNode>,
    },
    /// Logical `NOT` or arithmetic negation.
    UnaryOp {
        op: UnaryOperator,
        operand: Box<CircuitNode>,
    },
    /// Comparison of two operands of the same type; evaluates to `bool`.
    Compare {
        op: CompareOperator,
        left: Box<CircuitNode>,
        right: Box<CircuitNode>,
    },
}
/// Operator of a `CircuitNode::BinaryOp`.
///
/// `Add`/`Sub`/`Mul` take two numeric operands of the same type; `And`/`Or`/`Xor`
/// take two boolean operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    Add,
    Sub,
    Mul,
    And,
    Or,
    Xor,
}

impl BinaryOperator {
    /// Returns the symbol used when rendering this operator (`+`, `-`, `*`, `AND`, `OR`, `XOR`).
    ///
    /// The symbol is a `'static` literal, so it may outlive the operator it came from
    /// (the previous elided signature tied it to `&self`).
    pub fn as_str(&self) -> &'static str {
        match self {
            BinaryOperator::Add => "+",
            BinaryOperator::Sub => "-",
            BinaryOperator::Mul => "*",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Xor => "XOR",
        }
    }
}
/// Operator of a `CircuitNode::UnaryOp`.
///
/// `Not` takes a boolean operand; `Neg` takes a numeric operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    Not,
    Neg,
}

impl UnaryOperator {
    /// Returns the symbol used when rendering this operator (`NOT` or `-`).
    ///
    /// The symbol is a `'static` literal, so it may outlive the operator it came from.
    pub fn as_str(&self) -> &'static str {
        match self {
            UnaryOperator::Not => "NOT",
            UnaryOperator::Neg => "-",
        }
    }
}
/// Operator of a `CircuitNode::Compare`; both operands must share a type and the
/// comparison evaluates to `bool`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOperator {
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
}

impl CompareOperator {
    /// Returns the symbol used when rendering this comparison (`=`, `!=`, `<`, `<=`, `>`, `>=`).
    ///
    /// The symbol is a `'static` literal, so it may outlive the operator it came from.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompareOperator::Eq => "=",
            CompareOperator::Ne => "!=",
            CompareOperator::Lt => "<",
            CompareOperator::Le => "<=",
            CompareOperator::Gt => ">",
            CompareOperator::Ge => ">=",
        }
    }
}
/// A plaintext constant; the variant fixes its circuit type and bit width.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitValue {
    /// Typed as `bool`.
    Bool(bool),
    /// Typed as `u8`.
    U8(u8),
    /// Typed as `u16`.
    U16(u16),
    /// Typed as `u32`.
    U32(u32),
    /// Typed as `u64`.
    U64(u64),
}
impl std::fmt::Display for CircuitNode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CircuitNode::Load(name) => write!(f, "Load({})", name),
CircuitNode::Constant(value) => match value {
CircuitValue::Bool(v) => write!(f, "Const({})", v),
CircuitValue::U8(v) => write!(f, "Const({}u8)", v),
CircuitValue::U16(v) => write!(f, "Const({}u16)", v),
CircuitValue::U32(v) => write!(f, "Const({}u32)", v),
CircuitValue::U64(v) => write!(f, "Const({}u64)", v),
},
CircuitNode::EncryptedConstant {
data,
original_type,
} => {
write!(f, "EncryptedConst({}, {} bytes)", original_type, data.len())
}
CircuitNode::BinaryOp { op, left, right } => {
write!(f, "({} {} {})", left, op.as_str(), right)
}
CircuitNode::UnaryOp { op, operand } => {
write!(f, "{}({})", op.as_str(), operand)
}
CircuitNode::Compare { op, left, right } => {
write!(f, "({} {} {})", left, op.as_str(), right)
}
}
}
}
/// Type of a circuit variable, constant, or result.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptedType {
    Bool,
    U8,
    U16,
    U32,
    U64,
}

impl EncryptedType {
    /// Width of the type in bits (`1` for `Bool`).
    pub fn bit_width(&self) -> usize {
        match *self {
            Self::Bool => 1,
            Self::U8 => 8,
            Self::U16 => 16,
            Self::U32 => 32,
            Self::U64 => 64,
        }
    }

    /// `true` for every unsigned-integer type, i.e. everything except `Bool`.
    pub fn is_numeric(&self) -> bool {
        !self.is_boolean()
    }

    /// `true` only for `Bool`.
    pub fn is_boolean(&self) -> bool {
        *self == Self::Bool
    }
}

impl std::fmt::Display for EncryptedType {
    /// Writes the Rust-style type name (`bool`, `u8`, `u16`, `u32`, `u64`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Bool => "bool",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
        };
        f.write_str(name)
    }
}
/// A type-checked circuit together with metrics computed by `Circuit::new`.
#[derive(Debug, Clone)]
pub struct Circuit {
    /// Root of the expression tree.
    pub root: CircuitNode,
    /// Declared type of every variable that `Load` nodes may reference.
    pub variable_types: HashMap<String, EncryptedType>,
    /// Type produced by `root`, inferred in `Circuit::new`.
    pub result_type: EncryptedType,
    /// Number of nodes on the longest root-to-leaf path (a lone leaf has depth 1).
    /// Computed in `Circuit::new`; not refreshed if `root` is later modified.
    pub depth: usize,
    /// Number of operator nodes (binary, unary, comparison); leaves count 0.
    /// Computed in `Circuit::new`; not refreshed if `root` is later modified.
    pub gate_count: usize,
}
impl Circuit {
    /// Type-checks `root` against `variable_types` and wraps both in a `Circuit`.
    ///
    /// The result type, depth, and gate count are computed once here and stored in
    /// the corresponding public fields.
    ///
    /// # Errors
    /// Returns `AmateRSError::FheComputation` when `root` loads an undeclared variable
    /// or applies an operator to operands of the wrong or mismatched types.
    pub fn new(root: CircuitNode, variable_types: HashMap<String, EncryptedType>) -> Result<Self> {
        let result_type = Self::infer_type(&root, &variable_types)?;
        let depth = Self::compute_depth(&root);
        let gate_count = Self::count_gates(&root);
        Ok(Self {
            root,
            variable_types,
            result_type,
            depth,
            gate_count,
        })
    }

    /// Infers the type produced by `node`, checking operand types recursively.
    ///
    /// * `Load` takes its declared type; plaintext constants take their own width.
    /// * Encrypted constants: `Boolean` is `bool`. `Integer` takes the width encoded in
    ///   the payload's leading tag byte when the payload has exactly the layout
    ///   `encrypt_constant` produces, otherwise `u64`. `Float` and `Bytes` are `u64`.
    /// * `AND`/`OR`/`XOR` and `NOT` need `bool` operands and yield `bool`.
    /// * `+`/`-`/`*` need two numeric operands of the same type and yield that type;
    ///   negation needs a numeric operand and yields its type.
    /// * Comparisons need operands of the same type and yield `bool`.
    fn infer_type(
        node: &CircuitNode,
        variable_types: &HashMap<String, EncryptedType>,
    ) -> Result<EncryptedType> {
        match node {
            CircuitNode::Load(name) => variable_types.get(name).copied().ok_or_else(|| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Undefined variable: {}",
                    name
                )))
            }),
            CircuitNode::Constant(value) => Ok(match value {
                CircuitValue::Bool(_) => EncryptedType::Bool,
                CircuitValue::U8(_) => EncryptedType::U8,
                CircuitValue::U16(_) => EncryptedType::U16,
                CircuitValue::U32(_) => EncryptedType::U32,
                CircuitValue::U64(_) => EncryptedType::U64,
            }),
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => Ok(match original_type {
                ConstantType::Boolean => EncryptedType::Bool,
                // `encrypt_circuit_constants` records every integer width as `Integer`.
                // Reading the clear-text tag keeps e.g. `x: u8 + 10u8` well-typed after
                // its constant is encrypted, instead of turning it into `u8 + u64`.
                ConstantType::Integer => {
                    Self::tagged_integer_type(data).unwrap_or(EncryptedType::U64)
                }
                ConstantType::Float | ConstantType::Bytes => EncryptedType::U64,
            }),
            CircuitNode::BinaryOp { op, left, right } => {
                let left_type = Self::infer_type(left, variable_types)?;
                let right_type = Self::infer_type(right, variable_types)?;
                match op {
                    BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
                        if left_type != EncryptedType::Bool || right_type != EncryptedType::Bool {
                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                                "Logical operation requires boolean operands, got {} and {}",
                                left_type, right_type
                            ))));
                        }
                        Ok(EncryptedType::Bool)
                    }
                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
                        if !left_type.is_numeric() || !right_type.is_numeric() {
                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                                "Arithmetic operation requires numeric operands, got {} and {}",
                                left_type, right_type
                            ))));
                        }
                        if left_type != right_type {
                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                                "Arithmetic operation requires matching types, got {} and {}",
                                left_type, right_type
                            ))));
                        }
                        Ok(left_type)
                    }
                }
            }
            CircuitNode::UnaryOp { op, operand } => {
                let operand_type = Self::infer_type(operand, variable_types)?;
                match op {
                    UnaryOperator::Not => {
                        if operand_type != EncryptedType::Bool {
                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                                "NOT operation requires boolean operand, got {}",
                                operand_type
                            ))));
                        }
                        Ok(EncryptedType::Bool)
                    }
                    UnaryOperator::Neg => {
                        if !operand_type.is_numeric() {
                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                                "Negation operation requires numeric operand, got {}",
                                operand_type
                            ))));
                        }
                        Ok(operand_type)
                    }
                }
            }
            CircuitNode::Compare { left, right, .. } => {
                let left_type = Self::infer_type(left, variable_types)?;
                let right_type = Self::infer_type(right, variable_types)?;
                if left_type != right_type {
                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                        "Comparison requires matching types, got {} and {}",
                        left_type, right_type
                    ))));
                }
                Ok(EncryptedType::Bool)
            }
        }
    }

    /// Recovers an integer width from a payload laid out as `encrypt_constant` writes it:
    /// a clear-text tag byte (`0x01` u8, `0x02` u16, `0x03` u32, `0x04` u64) followed by
    /// exactly that type's byte count of ciphertext.
    ///
    /// Returns `None` for any other shape so the caller can fall back to `u64`.
    fn tagged_integer_type(data: &[u8]) -> Option<EncryptedType> {
        let (&tag, payload) = data.split_first()?;
        let ty = match tag {
            0x01 => EncryptedType::U8,
            0x02 => EncryptedType::U16,
            0x03 => EncryptedType::U32,
            0x04 => EncryptedType::U64,
            _ => return None,
        };
        // Require the exact payload length so arbitrary byte blobs whose first byte
        // happens to look like a tag keep the historical `u64` typing.
        if payload.len() * 8 == ty.bit_width() {
            Some(ty)
        } else {
            None
        }
    }

    /// Number of nodes on the longest root-to-leaf path; a lone leaf has depth 1.
    fn compute_depth(node: &CircuitNode) -> usize {
        match node {
            CircuitNode::Load(_)
            | CircuitNode::Constant(_)
            | CircuitNode::EncryptedConstant { .. } => 1,
            CircuitNode::BinaryOp { left, right, .. }
            | CircuitNode::Compare { left, right, .. } => {
                1 + Self::compute_depth(left).max(Self::compute_depth(right))
            }
            CircuitNode::UnaryOp { operand, .. } => 1 + Self::compute_depth(operand),
        }
    }

    /// Number of operator nodes (binary, unary, comparison); leaves contribute 0.
    fn count_gates(node: &CircuitNode) -> usize {
        match node {
            CircuitNode::Load(_)
            | CircuitNode::Constant(_)
            | CircuitNode::EncryptedConstant { .. } => 0,
            CircuitNode::BinaryOp { left, right, .. }
            | CircuitNode::Compare { left, right, .. } => {
                1 + Self::count_gates(left) + Self::count_gates(right)
            }
            CircuitNode::UnaryOp { operand, .. } => 1 + Self::count_gates(operand),
        }
    }

    /// Checks that every `Load` in the tree names a variable in `variable_types`.
    ///
    /// This does not repeat type inference.
    ///
    /// # Errors
    /// Returns `AmateRSError::FheComputation` naming the first undefined variable found.
    pub fn validate(&self) -> Result<()> {
        Self::validate_node(&self.root, &self.variable_types)?;
        Ok(())
    }

    /// Recursive worker for [`Circuit::validate`]; visits left operands before right.
    fn validate_node(
        node: &CircuitNode,
        variable_types: &HashMap<String, EncryptedType>,
    ) -> Result<()> {
        match node {
            CircuitNode::Load(name) => {
                if !variable_types.contains_key(name) {
                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                        "Undefined variable: {}",
                        name
                    ))));
                }
                Ok(())
            }
            CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => Ok(()),
            CircuitNode::BinaryOp { left, right, .. }
            | CircuitNode::Compare { left, right, .. } => {
                Self::validate_node(left, variable_types)?;
                Self::validate_node(right, variable_types)?;
                Ok(())
            }
            CircuitNode::UnaryOp { operand, .. } => Self::validate_node(operand, variable_types),
        }
    }
}
/// Collects variable declarations and offers constructors for circuit nodes.
///
/// Nodes are plain values; type checking happens only in `CircuitBuilder::build`.
#[derive(Default)]
pub struct CircuitBuilder {
    /// Declared variable names and their types, passed to `Circuit::new` on build.
    variable_types: HashMap<String, EncryptedType>,
}
impl CircuitBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn variable_types(&self) -> &HashMap<String, EncryptedType> {
&self.variable_types
}
pub fn variable_types_clone(&self) -> HashMap<String, EncryptedType> {
self.variable_types.clone()
}
pub fn declare_variable(&mut self, name: impl Into<String>, ty: EncryptedType) -> &mut Self {
self.variable_types.insert(name.into(), ty);
self
}
pub fn build(&self, root: CircuitNode) -> Result<Circuit> {
Circuit::new(root, self.variable_types.clone())
}
pub fn load(&self, name: impl Into<String>) -> CircuitNode {
CircuitNode::Load(name.into())
}
pub fn constant(&self, value: CircuitValue) -> CircuitNode {
CircuitNode::Constant(value)
}
pub fn add(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::BinaryOp {
op: BinaryOperator::Add,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn sub(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::BinaryOp {
op: BinaryOperator::Sub,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn mul(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::BinaryOp {
op: BinaryOperator::Mul,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn and(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::BinaryOp {
op: BinaryOperator::And,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn or(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::BinaryOp {
op: BinaryOperator::Or,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn xor(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::BinaryOp {
op: BinaryOperator::Xor,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn not(&self, operand: CircuitNode) -> CircuitNode {
CircuitNode::UnaryOp {
op: UnaryOperator::Not,
operand: Box::new(operand),
}
}
pub fn eq(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::Compare {
op: CompareOperator::Eq,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn lt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::Compare {
op: CompareOperator::Lt,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn gt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
CircuitNode::Compare {
op: CompareOperator::Gt,
left: Box::new(left),
right: Box::new(right),
}
}
pub fn encrypted_constant(&self, data: Vec<u8>, original_type: ConstantType) -> CircuitNode {
CircuitNode::EncryptedConstant {
data,
original_type,
}
}
}
/// Encrypts a plaintext constant under `key`.
///
/// Output layout: one clear-text type tag (`0x00` bool, `0x01` u8, `0x02` u16,
/// `0x03` u32, `0x04` u64), then the value's little-endian bytes (a bool is one byte,
/// `0` or `1`) XOR-ed with `derive_keystream(key, len)`. The output depends only on
/// `value` and `key`, so encryption is deterministic.
///
/// # Errors
/// Returns `AmateRSError::FheComputation` if `key` is empty.
pub fn encrypt_constant(value: &CircuitValue, key: &[u8]) -> Result<Vec<u8>> {
    if key.is_empty() {
        return Err(AmateRSError::FheComputation(ErrorContext::new(
            "Encryption key must not be empty".to_string(),
        )));
    }

    let (type_tag, plain_bytes): (u8, Vec<u8>) = match *value {
        CircuitValue::Bool(flag) => (0x00, vec![u8::from(flag)]),
        CircuitValue::U8(n) => (0x01, n.to_le_bytes().to_vec()),
        CircuitValue::U16(n) => (0x02, n.to_le_bytes().to_vec()),
        CircuitValue::U32(n) => (0x03, n.to_le_bytes().to_vec()),
        CircuitValue::U64(n) => (0x04, n.to_le_bytes().to_vec()),
    };

    let pad = derive_keystream(key, plain_bytes.len());
    let mut sealed = Vec::with_capacity(1 + plain_bytes.len());
    sealed.push(type_tag);
    sealed.extend(plain_bytes.iter().zip(pad.iter()).map(|(p, k)| p ^ k));
    Ok(sealed)
}
/// Decrypts a constant produced by `encrypt_constant`.
///
/// `data` is a clear-text type tag followed by the ciphertext; the ciphertext is
/// XOR-ed with `derive_keystream(key, len)` and decoded according to the tag.
///
/// A wrong key is not detected: it yields a different value of the tagged type.
///
/// # Errors
/// Returns `AmateRSError::FheComputation` when `key` or `data` is empty, the tag is
/// unknown, or the payload length does not match the tagged type (1 byte for `bool`
/// and `u8`, 2/4/8 bytes for `u16`/`u32`/`u64`).
pub fn decrypt_constant(data: &[u8], key: &[u8]) -> Result<CircuitValue> {
    if key.is_empty() {
        return Err(AmateRSError::FheComputation(ErrorContext::new(
            "Decryption key must not be empty".to_string(),
        )));
    }
    let (type_tag, ciphertext) = match data.split_first() {
        Some((&tag, rest)) => (tag, rest),
        None => {
            return Err(AmateRSError::FheComputation(ErrorContext::new(
                "Encrypted constant data is empty".to_string(),
            )))
        }
    };

    let keystream = derive_keystream(key, ciphertext.len());
    let plaintext: Vec<u8> = ciphertext
        .iter()
        .zip(keystream.iter())
        .map(|(c, k)| c ^ k)
        .collect();

    match type_tag {
        // Bool payloads must be exactly one byte, like every integer width below;
        // trailing bytes previously slipped through silently.
        0x00 => match plaintext.as_slice() {
            [] => Err(AmateRSError::FheComputation(ErrorContext::new(
                "Encrypted boolean constant has no payload".to_string(),
            ))),
            [byte] => Ok(CircuitValue::Bool(*byte != 0)),
            _ => Err(AmateRSError::FheComputation(ErrorContext::new(
                "Invalid encrypted bool constant length".to_string(),
            ))),
        },
        0x01 => fixed_width_payload::<1>(&plaintext, "u8")
            .map(|bytes| CircuitValue::U8(u8::from_le_bytes(bytes))),
        0x02 => fixed_width_payload::<2>(&plaintext, "u16")
            .map(|bytes| CircuitValue::U16(u16::from_le_bytes(bytes))),
        0x03 => fixed_width_payload::<4>(&plaintext, "u32")
            .map(|bytes| CircuitValue::U32(u32::from_le_bytes(bytes))),
        0x04 => fixed_width_payload::<8>(&plaintext, "u64")
            .map(|bytes| CircuitValue::U64(u64::from_le_bytes(bytes))),
        _ => Err(AmateRSError::FheComputation(ErrorContext::new(format!(
            "Unknown encrypted constant type tag: 0x{:02x}",
            type_tag
        )))),
    }
}

/// Converts a decrypted payload into an `N`-byte array, rejecting any other length
/// with "Invalid encrypted {type_name} constant length".
fn fixed_width_payload<const N: usize>(payload: &[u8], type_name: &str) -> Result<[u8; N]> {
    payload.try_into().map_err(|_| {
        AmateRSError::FheComputation(ErrorContext::new(format!(
            "Invalid encrypted {} constant length",
            type_name
        )))
    })
}
/// Returns a copy of `node` with every plaintext `Constant` replaced by an
/// `EncryptedConstant` produced by `encrypt_constant`.
///
/// `Load` nodes and already-encrypted constants are copied unchanged, and operator
/// nodes are rebuilt around their transformed children (left before right). `Bool`
/// constants are recorded as `ConstantType::Boolean`; every integer width is
/// recorded as `ConstantType::Integer`, and its exact width survives only in the
/// ciphertext's leading tag byte.
///
/// # Errors
/// Propagates the first error from `encrypt_constant` (for example an empty `key`).
pub fn encrypt_circuit_constants(node: &CircuitNode, key: &[u8]) -> Result<CircuitNode> {
    let rewritten = match node {
        CircuitNode::Load(_) | CircuitNode::EncryptedConstant { .. } => node.clone(),
        CircuitNode::Constant(value) => {
            let data = encrypt_constant(value, key)?;
            let original_type = match value {
                CircuitValue::Bool(_) => ConstantType::Boolean,
                CircuitValue::U8(_)
                | CircuitValue::U16(_)
                | CircuitValue::U32(_)
                | CircuitValue::U64(_) => ConstantType::Integer,
            };
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            }
        }
        CircuitNode::BinaryOp { op, left, right } => CircuitNode::BinaryOp {
            op: *op,
            left: Box::new(encrypt_circuit_constants(left, key)?),
            right: Box::new(encrypt_circuit_constants(right, key)?),
        },
        CircuitNode::UnaryOp { op, operand } => CircuitNode::UnaryOp {
            op: *op,
            operand: Box::new(encrypt_circuit_constants(operand, key)?),
        },
        CircuitNode::Compare { op, left, right } => CircuitNode::Compare {
            op: *op,
            left: Box::new(encrypt_circuit_constants(left, key)?),
            right: Box::new(encrypt_circuit_constants(right, key)?),
        },
    };
    Ok(rewritten)
}
/// Expands `key` into exactly `length` bytes of deterministic keystream.
///
/// Block `i` is the 64-bit FNV-1a hash of `key` followed by the little-endian bytes
/// of `i` (as `u64`), emitted little-endian. Blocks are concatenated and the last
/// one is truncated to fit.
///
/// NOTE(security): this is obfuscation, not encryption. FNV-1a is not a
/// cryptographic PRF and there is no nonce, so every value encrypted under the same
/// key reuses the same keystream prefix. One known plaintext/ciphertext pair reveals
/// the keystream for that length.
fn derive_keystream(key: &[u8], length: usize) -> Vec<u8> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    /// Feeds `bytes` into an FNV-1a state and returns the updated state.
    fn fnv1a_extend(mut state: u64, bytes: &[u8]) -> u64 {
        for &byte in bytes {
            state ^= u64::from(byte);
            state = state.wrapping_mul(FNV_PRIME);
        }
        state
    }

    // Every block hash starts with the same key bytes, so hash the key once and
    // resume from that state for each block instead of re-hashing it per block.
    let key_state = fnv1a_extend(FNV_OFFSET_BASIS, key);

    let mut keystream = Vec::with_capacity(length);
    let mut block_index: u64 = 0;
    while keystream.len() < length {
        let block = fnv1a_extend(key_state, &block_index.to_le_bytes()).to_le_bytes();
        let take = (length - keystream.len()).min(block.len());
        keystream.extend_from_slice(&block[..take]);
        block_index += 1;
    }
    keystream
}
/// Returns `true` when `node` itself is a `CircuitNode::EncryptedConstant`.
/// Children are not inspected.
pub fn is_encrypted_constant(node: &CircuitNode) -> bool {
    matches!(*node, CircuitNode::EncryptedConstant { .. })
}
/// Counts the plaintext `CircuitNode::Constant` leaves in the tree rooted at `node`.
///
/// Walks the tree with an explicit work list rather than recursion.
pub fn count_plaintext_constants(node: &CircuitNode) -> usize {
    let mut pending: Vec<&CircuitNode> = vec![node];
    let mut total = 0;
    while let Some(current) = pending.pop() {
        match current {
            CircuitNode::Constant(_) => total += 1,
            CircuitNode::EncryptedConstant { .. } | CircuitNode::Load(_) => {}
            CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
                pending.push(left);
                pending.push(right);
            }
            CircuitNode::UnaryOp { operand, .. } => pending.push(operand),
        }
    }
    total
}
/// Counts the `CircuitNode::EncryptedConstant` leaves in the tree rooted at `node`.
///
/// Walks the tree with an explicit work list rather than recursion.
pub fn count_encrypted_constants(node: &CircuitNode) -> usize {
    let mut pending: Vec<&CircuitNode> = vec![node];
    let mut total = 0;
    while let Some(current) = pending.pop() {
        match current {
            CircuitNode::EncryptedConstant { .. } => total += 1,
            CircuitNode::Constant(_) | CircuitNode::Load(_) => {}
            CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
                pending.push(left);
                pending.push(right);
            }
            CircuitNode::UnaryOp { operand, .. } => pending.push(operand),
        }
    }
    total
}
/// Legacy, stateless optimizer kept for backward compatibility.
///
/// `optimize` forwards to `crate::compute::optimizer::CircuitOptimizer`.
#[derive(Debug, Clone, Default)]
#[deprecated(
    since = "0.1.0",
    note = "Use CircuitOptimizer from optimizer module instead"
)]
pub struct CircuitOptimizer;
#[allow(deprecated)]
impl CircuitOptimizer {
    /// Creates the legacy optimizer (it holds no state).
    pub fn new() -> Self {
        Self
    }

    /// Optimizes `circuit` with a freshly constructed
    /// `crate::compute::optimizer::CircuitOptimizer`.
    ///
    /// # Errors
    /// Returns whatever error the delegated optimizer reports.
    pub fn optimize(&self, circuit: Circuit) -> Result<Circuit> {
        crate::compute::optimizer::CircuitOptimizer::new().optimize(circuit)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Builder, type inference, and circuit metrics ---

    #[test]
    fn test_circuit_builder() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);
        let a = builder.load("a");
        let b = builder.load("b");
        let sum = builder.add(a, b);
        let circuit = builder.build(sum)?;
        assert_eq!(circuit.result_type, EncryptedType::U8);
        assert_eq!(circuit.gate_count, 1);
        Ok(())
    }

    #[test]
    fn test_type_inference() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);
        let x = builder.load("x");
        let y = builder.load("y");
        let result = builder.and(x, y);
        let circuit = builder.build(result)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_type_mismatch_error() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::Bool);
        let a = builder.load("a");
        let b = builder.load("b");
        let invalid = builder.add(a, b);
        let result = builder.build(invalid);
        assert!(result.is_err());
    }

    #[test]
    #[allow(deprecated)]
    fn test_constant_folding() -> Result<()> {
        let optimizer = CircuitOptimizer::new();
        let builder = CircuitBuilder::new();
        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);
        let circuit = Circuit::new(sum, HashMap::new())?;
        let optimized = optimizer.optimize(circuit)?;
        match optimized.root {
            CircuitNode::Constant(CircuitValue::U8(8)) => Ok(()),
            _ => Err(AmateRSError::FheComputation(ErrorContext::new(
                "Constant folding failed".to_string(),
            ))),
        }
    }

    #[test]
    fn test_circuit_depth() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8)
            .declare_variable("c", EncryptedType::U8);
        let a = builder.load("a");
        let b = builder.load("b");
        let c = builder.load("c");
        let sum1 = builder.add(a, b);
        let sum2 = builder.add(sum1, c);
        let circuit = builder.build(sum2)?;
        // Leaves count as depth 1, so ((a + b) + c) has depth 3.
        assert_eq!(circuit.depth, 3);
        Ok(())
    }

    // --- Encrypted constants ---

    #[test]
    fn test_encrypted_constant_creation() {
        let builder = CircuitBuilder::new();
        let enc = builder.encrypted_constant(vec![0xAA, 0xBB], ConstantType::Integer);
        match enc {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, vec![0xAA, 0xBB]);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant"),
        }
    }

    #[test]
    fn test_encrypt_constant_bool() -> Result<()> {
        let key = b"test-encryption-key";
        let value = CircuitValue::Bool(true);
        let encrypted = encrypt_constant(&value, key)?;
        assert!(!encrypted.is_empty());
        assert_eq!(encrypted[0], 0x00);
        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u8() -> Result<()> {
        let key = b"test-key-u8";
        let value = CircuitValue::U8(42);
        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x01);
        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u16() -> Result<()> {
        let key = b"test-key-u16";
        let value = CircuitValue::U16(12345);
        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x02);
        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u32() -> Result<()> {
        let key = b"test-key-u32";
        let value = CircuitValue::U32(1_000_000);
        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x03);
        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u64() -> Result<()> {
        let key = b"test-key-u64";
        let value = CircuitValue::U64(0xDEAD_BEEF_CAFE_BABE);
        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x04);
        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_wrong_key_produces_wrong_value() -> Result<()> {
        let key1 = b"correct-key";
        let key2 = b"wrong-key!!";
        let value = CircuitValue::U8(42);
        let encrypted = encrypt_constant(&value, key1)?;
        let decrypted = decrypt_constant(&encrypted, key2)?;
        assert_ne!(decrypted, value);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_empty_key_error() {
        let key: &[u8] = &[];
        let value = CircuitValue::U8(1);
        let result = encrypt_constant(&value, key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_constant_empty_data_error() {
        let key = b"some-key";
        let result = decrypt_constant(&[], key);
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypt_circuit_constants_transforms_all() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"circuit-encryption-key";
        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);
        assert_eq!(count_plaintext_constants(&sum), 2);
        assert_eq!(count_encrypted_constants(&sum), 0);
        let encrypted = encrypt_circuit_constants(&sum, key)?;
        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);
        match &encrypted {
            CircuitNode::BinaryOp { op, left, right } => {
                assert_eq!(*op, BinaryOperator::Add);
                assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp after encryption"),
        }
        Ok(())
    }

    #[test]
    fn test_encrypt_circuit_constants_preserves_loads() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::U8);
        let key = b"key-for-loads-test";
        let x = builder.load("x");
        let c = builder.constant(CircuitValue::U8(10));
        let sum = builder.add(x, c);
        let encrypted = encrypt_circuit_constants(&sum, key)?;
        match &encrypted {
            CircuitNode::BinaryOp { left, right, .. } => {
                assert!(matches!(**left, CircuitNode::Load(ref name) if name == "x"));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp"),
        }
        Ok(())
    }

    #[test]
    fn test_encrypt_circuit_constants_already_encrypted_pass_through() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"key-pass-through";
        let enc = builder.encrypted_constant(vec![0x01, 0x02, 0x03], ConstantType::Integer);
        let original_data = vec![0x01, 0x02, 0x03];
        let result = encrypt_circuit_constants(&enc, key)?;
        match result {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, original_data);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant pass-through"),
        }
        Ok(())
    }

    // --- Display and ConstantType ---

    #[test]
    fn test_encrypted_constant_display() {
        let node = CircuitNode::EncryptedConstant {
            data: vec![0xAA, 0xBB, 0xCC],
            original_type: ConstantType::Boolean,
        };
        let display = format!("{}", node);
        assert!(display.contains("EncryptedConst"));
        assert!(display.contains("boolean"));
        assert!(display.contains("3 bytes"));
    }

    #[test]
    fn test_circuit_node_display_variants() {
        let load = CircuitNode::Load("x".to_string());
        assert_eq!(format!("{}", load), "Load(x)");
        let constant = CircuitNode::Constant(CircuitValue::U8(42));
        assert_eq!(format!("{}", constant), "Const(42u8)");
        let bool_const = CircuitNode::Constant(CircuitValue::Bool(true));
        assert_eq!(format!("{}", bool_const), "Const(true)");
    }

    #[test]
    fn test_constant_type_display() {
        assert_eq!(format!("{}", ConstantType::Integer), "integer");
        assert_eq!(format!("{}", ConstantType::Boolean), "boolean");
        assert_eq!(format!("{}", ConstantType::Float), "float");
        assert_eq!(format!("{}", ConstantType::Bytes), "bytes");
    }

    #[test]
    fn test_constant_type_variants() {
        let variants = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];
        for (i, a) in variants.iter().enumerate() {
            for (j, b) in variants.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b);
                } else {
                    assert_ne!(a, b);
                }
            }
        }
    }

    #[test]
    fn test_constant_type_serialization_roundtrip() -> Result<()> {
        let types = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];
        for ct in &types {
            let json = serde_json::to_string(ct).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Serialization failed: {}",
                    e
                )))
            })?;
            let deserialized: ConstantType = serde_json::from_str(&json).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Deserialization failed: {}",
                    e
                )))
            })?;
            assert_eq!(*ct, deserialized);
        }
        Ok(())
    }

    #[test]
    fn test_is_encrypted_constant() {
        let enc = CircuitNode::EncryptedConstant {
            data: vec![1, 2, 3],
            original_type: ConstantType::Integer,
        };
        assert!(is_encrypted_constant(&enc));
        let plain = CircuitNode::Constant(CircuitValue::U8(5));
        assert!(!is_encrypted_constant(&plain));
        let load = CircuitNode::Load("x".to_string());
        assert!(!is_encrypted_constant(&load));
    }

    #[test]
    fn test_encrypted_constant_in_circuit_validation() -> Result<()> {
        let enc = CircuitNode::EncryptedConstant {
            data: vec![0x00, 0x01],
            original_type: ConstantType::Boolean,
        };
        let circuit = Circuit::new(enc, HashMap::new())?;
        circuit.validate()?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_encrypted_constant_depth_and_gate_count() -> Result<()> {
        let builder = CircuitBuilder::new();
        let enc = builder.encrypted_constant(vec![0x01, 0x42], ConstantType::Integer);
        let circuit = Circuit::new(enc, HashMap::new())?;
        assert_eq!(circuit.depth, 1);
        assert_eq!(circuit.gate_count, 0);
        Ok(())
    }

    #[test]
    fn test_mixed_plain_and_encrypted_constants() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"mixed-circuit-key";
        let plain = builder.constant(CircuitValue::U8(10));
        let encrypted_data = encrypt_constant(&CircuitValue::U8(20), key)?;
        let enc = builder.encrypted_constant(encrypted_data, ConstantType::Integer);
        let negation = CircuitNode::UnaryOp {
            op: UnaryOperator::Not,
            operand: Box::new(CircuitNode::Constant(CircuitValue::Bool(true))),
        };
        assert_eq!(count_plaintext_constants(&plain), 1);
        assert_eq!(count_encrypted_constants(&plain), 0);
        assert_eq!(count_plaintext_constants(&enc), 0);
        assert_eq!(count_encrypted_constants(&enc), 1);
        assert_eq!(count_plaintext_constants(&negation), 1);
        assert_eq!(count_encrypted_constants(&negation), 0);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_deterministic() -> Result<()> {
        let key = b"deterministic-test-key";
        let value = CircuitValue::U32(999);
        let enc1 = encrypt_constant(&value, key)?;
        let enc2 = encrypt_constant(&value, key)?;
        assert_eq!(enc1, enc2);
        Ok(())
    }

    #[test]
    fn test_encrypt_constant_different_keys_differ() -> Result<()> {
        let key1 = b"key-alpha";
        let key2 = b"key-bravo";
        let value = CircuitValue::U64(123456789);
        let enc1 = encrypt_constant(&value, key1)?;
        let enc2 = encrypt_constant(&value, key2)?;
        // The type tag is stored in the clear; only the payload depends on the key.
        assert_eq!(enc1[0], enc2[0]);
        assert_ne!(enc1[1..], enc2[1..]);
        Ok(())
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_all_types() -> Result<()> {
        let key = b"roundtrip-all-types";
        let values = vec![
            CircuitValue::Bool(false),
            CircuitValue::Bool(true),
            CircuitValue::U8(0),
            CircuitValue::U8(255),
            CircuitValue::U16(0),
            CircuitValue::U16(65535),
            CircuitValue::U32(0),
            CircuitValue::U32(u32::MAX),
            CircuitValue::U64(0),
            CircuitValue::U64(u64::MAX),
        ];
        for value in &values {
            let encrypted = encrypt_constant(value, key)?;
            let decrypted = decrypt_constant(&encrypted, key)?;
            assert_eq!(*value, decrypted, "Roundtrip failed for {:?}", value);
        }
        Ok(())
    }

    #[test]
    fn test_encrypt_circuit_constants_nested() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"nested-circuit-key";
        let t = builder.constant(CircuitValue::Bool(true));
        let f = builder.constant(CircuitValue::Bool(false));
        let and_node = builder.and(t, f);
        let inverted = builder.not(and_node);
        assert_eq!(count_plaintext_constants(&inverted), 2);
        assert_eq!(count_encrypted_constants(&inverted), 0);
        let encrypted = encrypt_circuit_constants(&inverted, key)?;
        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);
        match &encrypted {
            CircuitNode::UnaryOp { op, operand } => {
                assert_eq!(*op, UnaryOperator::Not);
                match operand.as_ref() {
                    CircuitNode::BinaryOp { op, left, right } => {
                        assert_eq!(*op, BinaryOperator::And);
                        assert!(is_encrypted_constant(left));
                        assert!(is_encrypted_constant(right));
                    }
                    _ => panic!("Expected BinaryOp inside UnaryOp"),
                }
            }
            _ => panic!("Expected UnaryOp at root"),
        }
        Ok(())
    }
}