1use crate::error::{AmateRSError, ErrorContext, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Category of the plaintext value an encrypted constant was produced from.
///
/// The tag travels next to the ciphertext in [`CircuitNode::EncryptedConstant`]
/// and is what `Circuit::infer_type` consults for that node, so the payload
/// never has to be decrypted for type checking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConstantType {
    /// Integer constant. `encrypt_circuit_constants` uses this tag for every
    /// unsigned width (`U8`..`U64`); type inference maps it to `EncryptedType::U64`.
    Integer,
    /// Boolean constant; type inference maps it to `EncryptedType::Bool`.
    Boolean,
    /// Floating-point constant; type inference currently maps it to `EncryptedType::U64`.
    /// NOTE(review): nothing in this file produces this tag — confirm producers elsewhere.
    Float,
    /// Raw byte-string constant; type inference currently maps it to `EncryptedType::U64`.
    /// NOTE(review): nothing in this file produces this tag — confirm producers elsewhere.
    Bytes,
}
22
23impl std::fmt::Display for ConstantType {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 ConstantType::Integer => write!(f, "integer"),
27 ConstantType::Boolean => write!(f, "boolean"),
28 ConstantType::Float => write!(f, "float"),
29 ConstantType::Bytes => write!(f, "bytes"),
30 }
31 }
32}
33
/// One node of a circuit expression tree.
///
/// Leaves are variable loads and constants; interior nodes own their operands
/// through `Box`, so a whole circuit is a single owned tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitNode {
    /// Reads the variable with the given name. Type inference fails with
    /// "Undefined variable" when the name is missing from the circuit's
    /// variable table.
    Load(String),

    /// A plaintext literal embedded in the circuit.
    Constant(CircuitValue),

    /// A constant whose value is carried as opaque bytes instead of plaintext.
    EncryptedConstant {
        /// Encrypted payload. When produced by `encrypt_constant` this is a
        /// one-byte type tag followed by the XOR-masked little-endian value;
        /// the node itself does not enforce any layout.
        data: Vec<u8>,
        /// Category of the original plaintext, used for type inference.
        original_type: ConstantType,
    },

    /// Arithmetic or logical operation on two sub-circuits.
    BinaryOp {
        /// Operator to apply.
        op: BinaryOperator,
        left: Box<CircuitNode>,
        right: Box<CircuitNode>,
    },

    /// Operation on a single sub-circuit.
    UnaryOp {
        /// Operator to apply.
        op: UnaryOperator,
        operand: Box<CircuitNode>,
    },

    /// Comparison of two sub-circuits; type inference gives it a boolean
    /// result and requires both sides to have the same type.
    Compare {
        /// Comparison to perform.
        op: CompareOperator,
        left: Box<CircuitNode>,
        right: Box<CircuitNode>,
    },
}
75
/// Two-operand operators usable in a `CircuitNode::BinaryOp`.
///
/// `Add`, `Sub` and `Mul` are arithmetic (circuit type inference requires
/// numeric operands of one common type); `And`, `Or` and `Xor` are logical
/// (both operands must be boolean).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    Add,
    Sub,
    Mul,
    And,
    Or,
    Xor,
}

impl BinaryOperator {
    /// Returns the operator's display symbol (`+`, `-`, `*`, `AND`, `OR`, `XOR`).
    ///
    /// The result is a string literal, so it is returned as `&'static str`
    /// and can outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            BinaryOperator::Add => "+",
            BinaryOperator::Sub => "-",
            BinaryOperator::Mul => "*",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Xor => "XOR",
        }
    }
}
100
/// Single-operand operators usable in a `CircuitNode::UnaryOp`.
///
/// Circuit type inference requires a boolean operand for `Not` and a numeric
/// operand for `Neg`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    Not,
    Neg,
}

impl UnaryOperator {
    /// Returns the operator's display symbol (`NOT` or `-`).
    ///
    /// The result is a string literal, so it is returned as `&'static str`
    /// and can outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            UnaryOperator::Not => "NOT",
            UnaryOperator::Neg => "-",
        }
    }
}
117
/// Comparison operators usable in a `CircuitNode::Compare`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOperator {
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
}

impl CompareOperator {
    /// Returns the operator's display symbol (`=`, `!=`, `<`, `<=`, `>`, `>=`).
    ///
    /// The result is a string literal, so it is returned as `&'static str`
    /// and can outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompareOperator::Eq => "=",
            CompareOperator::Ne => "!=",
            CompareOperator::Lt => "<",
            CompareOperator::Le => "<=",
            CompareOperator::Gt => ">",
            CompareOperator::Ge => ">=",
        }
    }
}
142
/// A plaintext literal that can appear in a circuit as `CircuitNode::Constant`.
///
/// Each variant corresponds to the `EncryptedType` of the same name during
/// type inference.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitValue {
    /// Boolean literal.
    Bool(bool),
    /// 8-bit unsigned literal.
    U8(u8),
    /// 16-bit unsigned literal.
    U16(u16),
    /// 32-bit unsigned literal.
    U32(u32),
    /// 64-bit unsigned literal.
    U64(u64),
}
152
153impl std::fmt::Display for CircuitNode {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 match self {
156 CircuitNode::Load(name) => write!(f, "Load({})", name),
157 CircuitNode::Constant(value) => match value {
158 CircuitValue::Bool(v) => write!(f, "Const({})", v),
159 CircuitValue::U8(v) => write!(f, "Const({}u8)", v),
160 CircuitValue::U16(v) => write!(f, "Const({}u16)", v),
161 CircuitValue::U32(v) => write!(f, "Const({}u32)", v),
162 CircuitValue::U64(v) => write!(f, "Const({}u64)", v),
163 },
164 CircuitNode::EncryptedConstant {
165 data,
166 original_type,
167 } => {
168 write!(f, "EncryptedConst({}, {} bytes)", original_type, data.len())
169 }
170 CircuitNode::BinaryOp { op, left, right } => {
171 write!(f, "({} {} {})", left, op.as_str(), right)
172 }
173 CircuitNode::UnaryOp { op, operand } => {
174 write!(f, "{}({})", op.as_str(), operand)
175 }
176 CircuitNode::Compare { op, left, right } => {
177 write!(f, "({} {} {})", left, op.as_str(), right)
178 }
179 }
180 }
181}
182
/// Static type of an encrypted value flowing through a circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptedType {
    Bool,
    U8,
    U16,
    U32,
    U64,
}

impl EncryptedType {
    /// Number of plaintext bits represented by this type (`Bool` counts as 1).
    pub fn bit_width(&self) -> usize {
        match *self {
            Self::Bool => 1,
            Self::U8 => 8,
            Self::U16 => 16,
            Self::U32 => 32,
            Self::U64 => 64,
        }
    }

    /// `true` for every integer type, i.e. everything except `Bool`.
    pub fn is_numeric(&self) -> bool {
        !self.is_boolean()
    }

    /// `true` only for `Bool`.
    pub fn is_boolean(&self) -> bool {
        *self == Self::Bool
    }
}

/// Prints the type using its Rust primitive name (`bool`, `u8`, ...).
impl std::fmt::Display for EncryptedType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Self::Bool => "bool",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
        };
        f.write_str(name)
    }
}
227
/// A type-checked circuit together with metrics derived from its tree.
///
/// `Circuit::new` fills `result_type`, `depth` and `gate_count` from `root`.
/// All fields are public, so nothing keeps the derived fields in sync if
/// `root` is modified afterwards.
#[derive(Debug, Clone)]
pub struct Circuit {
    /// Root node of the expression tree.
    pub root: CircuitNode,

    /// Declared type of every variable that `Load` nodes may reference.
    pub variable_types: HashMap<String, EncryptedType>,

    /// Type the root node evaluates to, as inferred at construction.
    pub result_type: EncryptedType,

    /// Number of nodes on the longest root-to-leaf path (a lone leaf has depth 1).
    pub depth: usize,

    /// Number of operation nodes (binary, unary and comparison) in the tree.
    pub gate_count: usize,
}
246
247impl Circuit {
248 pub fn new(root: CircuitNode, variable_types: HashMap<String, EncryptedType>) -> Result<Self> {
250 let result_type = Self::infer_type(&root, &variable_types)?;
251 let depth = Self::compute_depth(&root);
252 let gate_count = Self::count_gates(&root);
253
254 Ok(Self {
255 root,
256 variable_types,
257 result_type,
258 depth,
259 gate_count,
260 })
261 }
262
263 fn infer_type(
265 node: &CircuitNode,
266 variable_types: &HashMap<String, EncryptedType>,
267 ) -> Result<EncryptedType> {
268 match node {
269 CircuitNode::Load(name) => variable_types.get(name).copied().ok_or_else(|| {
270 AmateRSError::FheComputation(ErrorContext::new(format!(
271 "Undefined variable: {}",
272 name
273 )))
274 }),
275
276 CircuitNode::Constant(value) => Ok(match value {
277 CircuitValue::Bool(_) => EncryptedType::Bool,
278 CircuitValue::U8(_) => EncryptedType::U8,
279 CircuitValue::U16(_) => EncryptedType::U16,
280 CircuitValue::U32(_) => EncryptedType::U32,
281 CircuitValue::U64(_) => EncryptedType::U64,
282 }),
283
284 CircuitNode::EncryptedConstant { original_type, .. } => {
285 Ok(match original_type {
286 ConstantType::Boolean => EncryptedType::Bool,
287 ConstantType::Integer | ConstantType::Float | ConstantType::Bytes => {
291 EncryptedType::U64
292 }
293 })
294 }
295
296 CircuitNode::BinaryOp { op, left, right } => {
297 let left_type = Self::infer_type(left, variable_types)?;
298 let right_type = Self::infer_type(right, variable_types)?;
299
300 match op {
301 BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
302 if left_type != EncryptedType::Bool || right_type != EncryptedType::Bool {
303 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
304 "Logical operation requires boolean operands, got {} and {}",
305 left_type, right_type
306 ))));
307 }
308 Ok(EncryptedType::Bool)
309 }
310
311 BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
312 if !left_type.is_numeric() || !right_type.is_numeric() {
313 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
314 "Arithmetic operation requires numeric operands, got {} and {}",
315 left_type, right_type
316 ))));
317 }
318
319 if left_type != right_type {
320 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
321 "Arithmetic operation requires matching types, got {} and {}",
322 left_type, right_type
323 ))));
324 }
325
326 Ok(left_type)
327 }
328 }
329 }
330
331 CircuitNode::UnaryOp { op, operand } => {
332 let operand_type = Self::infer_type(operand, variable_types)?;
333
334 match op {
335 UnaryOperator::Not => {
336 if operand_type != EncryptedType::Bool {
337 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
338 "NOT operation requires boolean operand, got {}",
339 operand_type
340 ))));
341 }
342 Ok(EncryptedType::Bool)
343 }
344
345 UnaryOperator::Neg => {
346 if !operand_type.is_numeric() {
347 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
348 "Negation operation requires numeric operand, got {}",
349 operand_type
350 ))));
351 }
352 Ok(operand_type)
353 }
354 }
355 }
356
357 CircuitNode::Compare { left, right, .. } => {
358 let left_type = Self::infer_type(left, variable_types)?;
359 let right_type = Self::infer_type(right, variable_types)?;
360
361 if left_type != right_type {
362 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
363 "Comparison requires matching types, got {} and {}",
364 left_type, right_type
365 ))));
366 }
367
368 Ok(EncryptedType::Bool)
369 }
370 }
371 }
372
373 fn compute_depth(node: &CircuitNode) -> usize {
375 match node {
376 CircuitNode::Load(_)
377 | CircuitNode::Constant(_)
378 | CircuitNode::EncryptedConstant { .. } => 1,
379
380 CircuitNode::BinaryOp { left, right, .. }
381 | CircuitNode::Compare { left, right, .. } => {
382 1 + Self::compute_depth(left).max(Self::compute_depth(right))
383 }
384
385 CircuitNode::UnaryOp { operand, .. } => 1 + Self::compute_depth(operand),
386 }
387 }
388
389 fn count_gates(node: &CircuitNode) -> usize {
391 match node {
392 CircuitNode::Load(_)
393 | CircuitNode::Constant(_)
394 | CircuitNode::EncryptedConstant { .. } => 0,
395
396 CircuitNode::BinaryOp { left, right, .. }
397 | CircuitNode::Compare { left, right, .. } => {
398 1 + Self::count_gates(left) + Self::count_gates(right)
399 }
400
401 CircuitNode::UnaryOp { operand, .. } => 1 + Self::count_gates(operand),
402 }
403 }
404
405 pub fn validate(&self) -> Result<()> {
407 Self::validate_node(&self.root, &self.variable_types)?;
408 Ok(())
409 }
410
411 fn validate_node(
412 node: &CircuitNode,
413 variable_types: &HashMap<String, EncryptedType>,
414 ) -> Result<()> {
415 match node {
416 CircuitNode::Load(name) => {
417 if !variable_types.contains_key(name) {
418 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
419 "Undefined variable: {}",
420 name
421 ))));
422 }
423 Ok(())
424 }
425
426 CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => Ok(()),
427
428 CircuitNode::BinaryOp { left, right, .. }
429 | CircuitNode::Compare { left, right, .. } => {
430 Self::validate_node(left, variable_types)?;
431 Self::validate_node(right, variable_types)?;
432 Ok(())
433 }
434
435 CircuitNode::UnaryOp { operand, .. } => Self::validate_node(operand, variable_types),
436 }
437 }
438}
439
440#[derive(Default)]
442pub struct CircuitBuilder {
443 variable_types: HashMap<String, EncryptedType>,
444}
445
446impl CircuitBuilder {
447 pub fn new() -> Self {
448 Self::default()
449 }
450
451 pub fn variable_types(&self) -> &HashMap<String, EncryptedType> {
453 &self.variable_types
454 }
455
456 pub fn variable_types_clone(&self) -> HashMap<String, EncryptedType> {
458 self.variable_types.clone()
459 }
460
461 pub fn declare_variable(&mut self, name: impl Into<String>, ty: EncryptedType) -> &mut Self {
463 self.variable_types.insert(name.into(), ty);
464 self
465 }
466
467 pub fn build(&self, root: CircuitNode) -> Result<Circuit> {
469 Circuit::new(root, self.variable_types.clone())
470 }
471
472 pub fn load(&self, name: impl Into<String>) -> CircuitNode {
474 CircuitNode::Load(name.into())
475 }
476
477 pub fn constant(&self, value: CircuitValue) -> CircuitNode {
479 CircuitNode::Constant(value)
480 }
481
482 pub fn add(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
484 CircuitNode::BinaryOp {
485 op: BinaryOperator::Add,
486 left: Box::new(left),
487 right: Box::new(right),
488 }
489 }
490
491 pub fn sub(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
493 CircuitNode::BinaryOp {
494 op: BinaryOperator::Sub,
495 left: Box::new(left),
496 right: Box::new(right),
497 }
498 }
499
500 pub fn mul(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
502 CircuitNode::BinaryOp {
503 op: BinaryOperator::Mul,
504 left: Box::new(left),
505 right: Box::new(right),
506 }
507 }
508
509 pub fn and(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
511 CircuitNode::BinaryOp {
512 op: BinaryOperator::And,
513 left: Box::new(left),
514 right: Box::new(right),
515 }
516 }
517
518 pub fn or(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
520 CircuitNode::BinaryOp {
521 op: BinaryOperator::Or,
522 left: Box::new(left),
523 right: Box::new(right),
524 }
525 }
526
527 pub fn xor(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
529 CircuitNode::BinaryOp {
530 op: BinaryOperator::Xor,
531 left: Box::new(left),
532 right: Box::new(right),
533 }
534 }
535
536 pub fn not(&self, operand: CircuitNode) -> CircuitNode {
538 CircuitNode::UnaryOp {
539 op: UnaryOperator::Not,
540 operand: Box::new(operand),
541 }
542 }
543
544 pub fn eq(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
546 CircuitNode::Compare {
547 op: CompareOperator::Eq,
548 left: Box::new(left),
549 right: Box::new(right),
550 }
551 }
552
553 pub fn lt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
555 CircuitNode::Compare {
556 op: CompareOperator::Lt,
557 left: Box::new(left),
558 right: Box::new(right),
559 }
560 }
561
562 pub fn gt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
564 CircuitNode::Compare {
565 op: CompareOperator::Gt,
566 left: Box::new(left),
567 right: Box::new(right),
568 }
569 }
570
571 pub fn encrypted_constant(&self, data: Vec<u8>, original_type: ConstantType) -> CircuitNode {
573 CircuitNode::EncryptedConstant {
574 data,
575 original_type,
576 }
577 }
578}
579
580pub fn encrypt_constant(value: &CircuitValue, key: &[u8]) -> Result<Vec<u8>> {
594 if key.is_empty() {
595 return Err(AmateRSError::FheComputation(ErrorContext::new(
596 "Encryption key must not be empty".to_string(),
597 )));
598 }
599
600 let (type_tag, plaintext): (u8, Vec<u8>) = match value {
602 CircuitValue::Bool(v) => (0x00, vec![if *v { 1 } else { 0 }]),
603 CircuitValue::U8(v) => (0x01, v.to_le_bytes().to_vec()),
604 CircuitValue::U16(v) => (0x02, v.to_le_bytes().to_vec()),
605 CircuitValue::U32(v) => (0x03, v.to_le_bytes().to_vec()),
606 CircuitValue::U64(v) => (0x04, v.to_le_bytes().to_vec()),
607 };
608
609 let keystream = derive_keystream(key, plaintext.len());
611
612 let ciphertext: Vec<u8> = plaintext
614 .iter()
615 .zip(keystream.iter())
616 .map(|(p, k)| p ^ k)
617 .collect();
618
619 let mut output = Vec::with_capacity(1 + ciphertext.len());
621 output.push(type_tag);
622 output.extend_from_slice(&ciphertext);
623
624 Ok(output)
625}
626
627pub fn decrypt_constant(data: &[u8], key: &[u8]) -> Result<CircuitValue> {
632 if key.is_empty() {
633 return Err(AmateRSError::FheComputation(ErrorContext::new(
634 "Decryption key must not be empty".to_string(),
635 )));
636 }
637 if data.is_empty() {
638 return Err(AmateRSError::FheComputation(ErrorContext::new(
639 "Encrypted constant data is empty".to_string(),
640 )));
641 }
642
643 let type_tag = data[0];
644 let ciphertext = &data[1..];
645
646 let keystream = derive_keystream(key, ciphertext.len());
647 let plaintext: Vec<u8> = ciphertext
648 .iter()
649 .zip(keystream.iter())
650 .map(|(c, k)| c ^ k)
651 .collect();
652
653 match type_tag {
654 0x00 => {
655 if plaintext.is_empty() {
656 return Err(AmateRSError::FheComputation(ErrorContext::new(
657 "Encrypted boolean constant has no payload".to_string(),
658 )));
659 }
660 Ok(CircuitValue::Bool(plaintext[0] != 0))
661 }
662 0x01 => {
663 let arr: [u8; 1] = plaintext.as_slice().try_into().map_err(|_| {
664 AmateRSError::FheComputation(ErrorContext::new(
665 "Invalid encrypted u8 constant length".to_string(),
666 ))
667 })?;
668 Ok(CircuitValue::U8(u8::from_le_bytes(arr)))
669 }
670 0x02 => {
671 let arr: [u8; 2] = plaintext.as_slice().try_into().map_err(|_| {
672 AmateRSError::FheComputation(ErrorContext::new(
673 "Invalid encrypted u16 constant length".to_string(),
674 ))
675 })?;
676 Ok(CircuitValue::U16(u16::from_le_bytes(arr)))
677 }
678 0x03 => {
679 let arr: [u8; 4] = plaintext.as_slice().try_into().map_err(|_| {
680 AmateRSError::FheComputation(ErrorContext::new(
681 "Invalid encrypted u32 constant length".to_string(),
682 ))
683 })?;
684 Ok(CircuitValue::U32(u32::from_le_bytes(arr)))
685 }
686 0x04 => {
687 let arr: [u8; 8] = plaintext.as_slice().try_into().map_err(|_| {
688 AmateRSError::FheComputation(ErrorContext::new(
689 "Invalid encrypted u64 constant length".to_string(),
690 ))
691 })?;
692 Ok(CircuitValue::U64(u64::from_le_bytes(arr)))
693 }
694 _ => Err(AmateRSError::FheComputation(ErrorContext::new(format!(
695 "Unknown encrypted constant type tag: 0x{:02x}",
696 type_tag
697 )))),
698 }
699}
700
701pub fn encrypt_circuit_constants(node: &CircuitNode, key: &[u8]) -> Result<CircuitNode> {
706 match node {
707 CircuitNode::Load(name) => Ok(CircuitNode::Load(name.clone())),
708
709 CircuitNode::Constant(value) => {
710 let data = encrypt_constant(value, key)?;
711 let original_type = match value {
712 CircuitValue::Bool(_) => ConstantType::Boolean,
713 CircuitValue::U8(_)
714 | CircuitValue::U16(_)
715 | CircuitValue::U32(_)
716 | CircuitValue::U64(_) => ConstantType::Integer,
717 };
718 Ok(CircuitNode::EncryptedConstant {
719 data,
720 original_type,
721 })
722 }
723
724 CircuitNode::EncryptedConstant {
726 data,
727 original_type,
728 } => Ok(CircuitNode::EncryptedConstant {
729 data: data.clone(),
730 original_type: *original_type,
731 }),
732
733 CircuitNode::BinaryOp { op, left, right } => {
734 let left = encrypt_circuit_constants(left, key)?;
735 let right = encrypt_circuit_constants(right, key)?;
736 Ok(CircuitNode::BinaryOp {
737 op: *op,
738 left: Box::new(left),
739 right: Box::new(right),
740 })
741 }
742
743 CircuitNode::UnaryOp { op, operand } => {
744 let operand = encrypt_circuit_constants(operand, key)?;
745 Ok(CircuitNode::UnaryOp {
746 op: *op,
747 operand: Box::new(operand),
748 })
749 }
750
751 CircuitNode::Compare { op, left, right } => {
752 let left = encrypt_circuit_constants(left, key)?;
753 let right = encrypt_circuit_constants(right, key)?;
754 Ok(CircuitNode::Compare {
755 op: *op,
756 left: Box::new(left),
757 right: Box::new(right),
758 })
759 }
760 }
761}
762
/// Expands `key` into `length` pseudo-random bytes.
///
/// The stream is produced in 8-byte blocks. Block `i` is the little-endian
/// encoding of the 64-bit FNV-1a hash of `key` followed by `i` as a
/// little-endian `u64`; the final block is truncated to fit `length`.
/// The output depends only on `key` and `length`, and a shorter request is
/// always a prefix of a longer one for the same key.
///
/// FNV-1a is not a cryptographic hash; this keystream offers obfuscation only.
fn derive_keystream(key: &[u8], length: usize) -> Vec<u8> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    /// Feeds `bytes` into an FNV-1a state and returns the updated state.
    fn absorb(state: u64, bytes: &[u8]) -> u64 {
        bytes.iter().fold(state, |hash, &byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
    }

    // The key prefix is identical for every block, so hash it once up front.
    let key_state = absorb(FNV_OFFSET_BASIS, key);

    let mut keystream = Vec::with_capacity(length);
    let mut block_index: u64 = 0;

    while keystream.len() < length {
        let block = absorb(key_state, &block_index.to_le_bytes()).to_le_bytes();
        // Copy the whole block, or only what is still missing for the last one.
        let wanted = block.len().min(length - keystream.len());
        keystream.extend_from_slice(&block[..wanted]);
        block_index += 1;
    }

    keystream
}
797
798pub fn is_encrypted_constant(node: &CircuitNode) -> bool {
800 matches!(node, CircuitNode::EncryptedConstant { .. })
801}
802
803pub fn count_plaintext_constants(node: &CircuitNode) -> usize {
805 match node {
806 CircuitNode::Constant(_) => 1,
807 CircuitNode::EncryptedConstant { .. } | CircuitNode::Load(_) => 0,
808 CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
809 count_plaintext_constants(left) + count_plaintext_constants(right)
810 }
811 CircuitNode::UnaryOp { operand, .. } => count_plaintext_constants(operand),
812 }
813}
814
815pub fn count_encrypted_constants(node: &CircuitNode) -> usize {
817 match node {
818 CircuitNode::EncryptedConstant { .. } => 1,
819 CircuitNode::Constant(_) | CircuitNode::Load(_) => 0,
820 CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
821 count_encrypted_constants(left) + count_encrypted_constants(right)
822 }
823 CircuitNode::UnaryOp { operand, .. } => count_encrypted_constants(operand),
824 }
825}
826
/// Legacy optimizer handle kept in this module for existing callers.
///
/// It carries no state. Its `optimize` method forwards to
/// `crate::compute::optimizer::CircuitOptimizer`, which is the type the
/// deprecation note asks callers to use instead.
#[derive(Debug, Clone, Default)]
#[deprecated(
    since = "0.1.0",
    note = "Use CircuitOptimizer from optimizer module instead"
)]
pub struct CircuitOptimizer;
837
838#[allow(deprecated)]
839impl CircuitOptimizer {
840 pub fn new() -> Self {
841 Self
842 }
843
844 pub fn optimize(&self, circuit: Circuit) -> Result<Circuit> {
849 let mut advanced_optimizer = crate::compute::optimizer::CircuitOptimizer::new();
851 advanced_optimizer.optimize(circuit)
852 }
853}
854
/// Unit tests for circuit construction, type inference, display formatting
/// and the constant encryption helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding two `u8` variables yields a `u8` circuit with exactly one gate.
    #[test]
    fn test_circuit_builder() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let sum = builder.add(a, b);

        let circuit = builder.build(sum)?;
        assert_eq!(circuit.result_type, EncryptedType::U8);
        assert_eq!(circuit.gate_count, 1);

        Ok(())
    }

    /// A logical AND of two booleans is inferred as boolean.
    #[test]
    fn test_type_inference() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x = builder.load("x");
        let y = builder.load("y");
        let result = builder.and(x, y);

        let circuit = builder.build(result)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);

        Ok(())
    }

    /// Adding a `u8` to a boolean is rejected at build time.
    #[test]
    fn test_type_mismatch_error() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::Bool);

        let a = builder.load("a");
        let b = builder.load("b");
        let invalid = builder.add(a, b);

        let result = builder.build(invalid);
        assert!(result.is_err());
    }

    /// The deprecated optimizer shim still folds `5 + 3` into the constant `8`.
    #[test]
    #[allow(deprecated)]
    fn test_constant_folding() -> Result<()> {
        let optimizer = CircuitOptimizer::new();
        let builder = CircuitBuilder::new();

        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        let circuit = Circuit::new(sum, HashMap::new())?;
        let optimized = optimizer.optimize(circuit)?;

        match optimized.root {
            CircuitNode::Constant(CircuitValue::U8(8)) => Ok(()),
            _ => Err(AmateRSError::FheComputation(ErrorContext::new(
                "Constant folding failed".to_string(),
            ))),
        }
    }

    /// `(a + b) + c` has depth 3: outer add, inner add, leaf load.
    #[test]
    fn test_circuit_depth() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8)
            .declare_variable("c", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let c = builder.load("c");

        let sum1 = builder.add(a, b);
        let sum2 = builder.add(sum1, c);

        let circuit = builder.build(sum2)?;
        assert_eq!(circuit.depth, 3);

        Ok(())
    }

    /// The builder stores the encrypted payload and type tag unchanged.
    #[test]
    fn test_encrypted_constant_creation() {
        let builder = CircuitBuilder::new();
        let enc = builder.encrypted_constant(vec![0xAA, 0xBB], ConstantType::Integer);

        match enc {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, vec![0xAA, 0xBB]);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant"),
        }
    }

    /// Booleans carry type tag `0x00` and round-trip.
    #[test]
    fn test_encrypt_constant_bool() -> Result<()> {
        let key = b"test-encryption-key";
        let value = CircuitValue::Bool(true);

        let encrypted = encrypt_constant(&value, key)?;
        assert!(!encrypted.is_empty());
        assert_eq!(encrypted[0], 0x00);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// `u8` values carry type tag `0x01` and round-trip.
    #[test]
    fn test_encrypt_constant_u8() -> Result<()> {
        let key = b"test-key-u8";
        let value = CircuitValue::U8(42);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x01);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// `u16` values carry type tag `0x02` and round-trip.
    #[test]
    fn test_encrypt_constant_u16() -> Result<()> {
        let key = b"test-key-u16";
        let value = CircuitValue::U16(12345);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x02);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// `u32` values carry type tag `0x03` and round-trip.
    #[test]
    fn test_encrypt_constant_u32() -> Result<()> {
        let key = b"test-key-u32";
        let value = CircuitValue::U32(1_000_000);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x03);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// `u64` values carry type tag `0x04` and round-trip.
    #[test]
    fn test_encrypt_constant_u64() -> Result<()> {
        let key = b"test-key-u64";
        let value = CircuitValue::U64(0xDEAD_BEEF_CAFE_BABE);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x04);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// Decrypting with a different key succeeds but yields a different value.
    #[test]
    fn test_encrypt_constant_wrong_key_produces_wrong_value() -> Result<()> {
        let key1 = b"correct-key";
        let key2 = b"wrong-key!!";
        let value = CircuitValue::U8(42);

        let encrypted = encrypt_constant(&value, key1)?;
        let decrypted = decrypt_constant(&encrypted, key2)?;
        assert_ne!(decrypted, value);

        Ok(())
    }

    /// An empty key is rejected on encryption.
    #[test]
    fn test_encrypt_constant_empty_key_error() {
        let key: &[u8] = &[];
        let value = CircuitValue::U8(1);

        let result = encrypt_constant(&value, key);
        assert!(result.is_err());
    }

    /// Empty ciphertext is rejected on decryption.
    #[test]
    fn test_decrypt_constant_empty_data_error() {
        let key = b"some-key";
        let result = decrypt_constant(&[], key);
        assert!(result.is_err());
    }

    /// Every plaintext constant in a tree is replaced by an encrypted one.
    #[test]
    fn test_encrypt_circuit_constants_transforms_all() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"circuit-encryption-key";

        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        assert_eq!(count_plaintext_constants(&sum), 2);
        assert_eq!(count_encrypted_constants(&sum), 0);

        let encrypted = encrypt_circuit_constants(&sum, key)?;

        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);

        match &encrypted {
            CircuitNode::BinaryOp { op, left, right } => {
                assert_eq!(*op, BinaryOperator::Add);
                assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp after encryption"),
        }

        Ok(())
    }

    /// Variable loads pass through constant encryption untouched.
    #[test]
    fn test_encrypt_circuit_constants_preserves_loads() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::U8);
        let key = b"key-for-loads-test";

        let x = builder.load("x");
        let c = builder.constant(CircuitValue::U8(10));
        let sum = builder.add(x, c);

        let encrypted = encrypt_circuit_constants(&sum, key)?;

        match &encrypted {
            CircuitNode::BinaryOp { left, right, .. } => {
                assert!(matches!(**left, CircuitNode::Load(ref name) if name == "x"));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp"),
        }

        Ok(())
    }

    /// Already-encrypted constants are copied, not encrypted a second time.
    #[test]
    fn test_encrypt_circuit_constants_already_encrypted_pass_through() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"key-pass-through";

        let enc = builder.encrypted_constant(vec![0x01, 0x02, 0x03], ConstantType::Integer);
        let original_data = vec![0x01, 0x02, 0x03];

        let result = encrypt_circuit_constants(&enc, key)?;

        match result {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, original_data);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant pass-through"),
        }

        Ok(())
    }

    /// Display of an encrypted constant shows its tag and length, not its bytes.
    #[test]
    fn test_encrypted_constant_display() {
        let node = CircuitNode::EncryptedConstant {
            data: vec![0xAA, 0xBB, 0xCC],
            original_type: ConstantType::Boolean,
        };
        let display = format!("{}", node);
        assert!(display.contains("EncryptedConst"));
        assert!(display.contains("boolean"));
        assert!(display.contains("3 bytes"));
    }

    /// Display output for the leaf node kinds.
    #[test]
    fn test_circuit_node_display_variants() {
        let load = CircuitNode::Load("x".to_string());
        assert_eq!(format!("{}", load), "Load(x)");

        let constant = CircuitNode::Constant(CircuitValue::U8(42));
        assert_eq!(format!("{}", constant), "Const(42u8)");

        let bool_const = CircuitNode::Constant(CircuitValue::Bool(true));
        assert_eq!(format!("{}", bool_const), "Const(true)");
    }

    /// Display output for every `ConstantType`.
    #[test]
    fn test_constant_type_display() {
        assert_eq!(format!("{}", ConstantType::Integer), "integer");
        assert_eq!(format!("{}", ConstantType::Boolean), "boolean");
        assert_eq!(format!("{}", ConstantType::Float), "float");
        assert_eq!(format!("{}", ConstantType::Bytes), "bytes");
    }

    /// All `ConstantType` variants are pairwise distinct.
    #[test]
    fn test_constant_type_variants() {
        let variants = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];
        for (i, a) in variants.iter().enumerate() {
            for (j, b) in variants.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b);
                } else {
                    assert_ne!(a, b);
                }
            }
        }
    }

    /// Every `ConstantType` survives a JSON round-trip.
    #[test]
    fn test_constant_type_serialization_roundtrip() -> Result<()> {
        let types = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];

        for ct in &types {
            let json = serde_json::to_string(ct).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Serialization failed: {}",
                    e
                )))
            })?;
            let deserialized: ConstantType = serde_json::from_str(&json).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Deserialization failed: {}",
                    e
                )))
            })?;
            assert_eq!(*ct, deserialized);
        }

        Ok(())
    }

    /// Only `EncryptedConstant` nodes are reported as encrypted.
    #[test]
    fn test_is_encrypted_constant() {
        let enc = CircuitNode::EncryptedConstant {
            data: vec![1, 2, 3],
            original_type: ConstantType::Integer,
        };
        assert!(is_encrypted_constant(&enc));

        let plain = CircuitNode::Constant(CircuitValue::U8(5));
        assert!(!is_encrypted_constant(&plain));

        let load = CircuitNode::Load("x".to_string());
        assert!(!is_encrypted_constant(&load));
    }

    /// An encrypted boolean constant validates and is typed as `Bool`.
    #[test]
    fn test_encrypted_constant_in_circuit_validation() -> Result<()> {
        let enc = CircuitNode::EncryptedConstant {
            data: vec![0x00, 0x01],
            original_type: ConstantType::Boolean,
        };
        let circuit = Circuit::new(enc, HashMap::new())?;
        circuit.validate()?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);

        Ok(())
    }

    /// A lone encrypted constant is a leaf: depth 1, no gates.
    #[test]
    fn test_encrypted_constant_depth_and_gate_count() -> Result<()> {
        let builder = CircuitBuilder::new();

        let enc = builder.encrypted_constant(vec![0x01, 0x42], ConstantType::Integer);
        let circuit = Circuit::new(enc, HashMap::new())?;
        assert_eq!(circuit.depth, 1);
        assert_eq!(circuit.gate_count, 0);

        Ok(())
    }

    /// The two counters distinguish plaintext from encrypted constants.
    #[test]
    fn test_mixed_plain_and_encrypted_constants() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"mixed-circuit-key";

        let plain = builder.constant(CircuitValue::U8(10));
        let encrypted_data = encrypt_constant(&CircuitValue::U8(20), key)?;
        let enc = builder.encrypted_constant(encrypted_data, ConstantType::Integer);

        let not_node = CircuitNode::UnaryOp {
            op: UnaryOperator::Not,
            operand: Box::new(CircuitNode::Constant(CircuitValue::Bool(true))),
        };

        assert_eq!(count_plaintext_constants(&plain), 1);
        assert_eq!(count_encrypted_constants(&plain), 0);
        assert_eq!(count_plaintext_constants(&enc), 0);
        assert_eq!(count_encrypted_constants(&enc), 1);
        assert_eq!(count_plaintext_constants(&not_node), 1);
        assert_eq!(count_encrypted_constants(&not_node), 0);

        Ok(())
    }

    /// Encrypting the same value with the same key gives identical bytes.
    #[test]
    fn test_encrypt_constant_deterministic() -> Result<()> {
        let key = b"deterministic-test-key";
        let value = CircuitValue::U32(999);

        let enc1 = encrypt_constant(&value, key)?;
        let enc2 = encrypt_constant(&value, key)?;
        assert_eq!(enc1, enc2);

        Ok(())
    }

    /// Different keys share the clear-text type tag but differ in the payload.
    #[test]
    fn test_encrypt_constant_different_keys_differ() -> Result<()> {
        let key1 = b"key-alpha";
        let key2 = b"key-bravo";
        let value = CircuitValue::U64(123456789);

        let enc1 = encrypt_constant(&value, key1)?;
        let enc2 = encrypt_constant(&value, key2)?;
        assert_eq!(enc1[0], enc2[0]);
        assert_ne!(enc1[1..], enc2[1..]);

        Ok(())
    }

    /// Boundary values of every supported type round-trip.
    #[test]
    fn test_encrypt_decrypt_roundtrip_all_types() -> Result<()> {
        let key = b"roundtrip-all-types";

        let values = vec![
            CircuitValue::Bool(false),
            CircuitValue::Bool(true),
            CircuitValue::U8(0),
            CircuitValue::U8(255),
            CircuitValue::U16(0),
            CircuitValue::U16(65535),
            CircuitValue::U32(0),
            CircuitValue::U32(u32::MAX),
            CircuitValue::U64(0),
            CircuitValue::U64(u64::MAX),
        ];

        for value in &values {
            let encrypted = encrypt_constant(value, key)?;
            let decrypted = decrypt_constant(&encrypted, key)?;
            assert_eq!(*value, decrypted, "Roundtrip failed for {:?}", value);
        }

        Ok(())
    }

    /// Constants nested below unary and binary operators are all encrypted,
    /// and the operator structure is preserved.
    #[test]
    fn test_encrypt_circuit_constants_nested() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"nested-circuit-key";

        let t = builder.constant(CircuitValue::Bool(true));
        let f = builder.constant(CircuitValue::Bool(false));
        let and_node = builder.and(t, f);
        let not_node = builder.not(and_node);

        assert_eq!(count_plaintext_constants(&not_node), 2);
        assert_eq!(count_encrypted_constants(&not_node), 0);

        let encrypted = encrypt_circuit_constants(&not_node, key)?;

        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);

        match &encrypted {
            CircuitNode::UnaryOp { op, operand } => {
                assert_eq!(*op, UnaryOperator::Not);
                match operand.as_ref() {
                    CircuitNode::BinaryOp { op, left, right } => {
                        assert_eq!(*op, BinaryOperator::And);
                        assert!(is_encrypted_constant(left));
                        assert!(is_encrypted_constant(right));
                    }
                    _ => panic!("Expected BinaryOp inside UnaryOp"),
                }
            }
            _ => panic!("Expected UnaryOp at root"),
        }

        Ok(())
    }
}