1use crate::error::{AmateRSError, ErrorContext, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum ConstantType {
13 Integer,
15 Boolean,
17 Float,
19 Bytes,
21}
22
23impl std::fmt::Display for ConstantType {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 ConstantType::Integer => write!(f, "integer"),
27 ConstantType::Boolean => write!(f, "boolean"),
28 ConstantType::Float => write!(f, "float"),
29 ConstantType::Bytes => write!(f, "bytes"),
30 }
31 }
32}
33
/// A node in a circuit expression tree evaluated over encrypted values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitNode {
    /// Reads the named input variable; its type is looked up in the circuit's
    /// `variable_types` map during type inference.
    Load(String),

    /// A plaintext constant embedded directly in the tree.
    Constant(CircuitValue),

    /// A constant carried as encrypted bytes instead of a plaintext value.
    EncryptedConstant {
        /// Opaque encrypted payload. Blobs produced by `encrypt_constant` are a
        /// one-byte type tag followed by the ciphertext; bytes passed to
        /// `CircuitBuilder::encrypted_constant` are stored as given.
        data: Vec<u8>,
        /// Type of the constant before encryption; drives type inference.
        original_type: ConstantType,
    },

    /// A two-operand operation, `left op right`.
    BinaryOp {
        /// Operation to apply.
        op: BinaryOperator,
        left: Box<CircuitNode>,
        right: Box<CircuitNode>,
    },

    /// A single-operand operation.
    UnaryOp {
        /// Operation to apply.
        op: UnaryOperator,
        operand: Box<CircuitNode>,
    },

    /// A comparison of two operands of the same type; type inference always
    /// yields a boolean for this node.
    Compare {
        /// Comparison to perform.
        op: CompareOperator,
        left: Box<CircuitNode>,
        right: Box<CircuitNode>,
    },

    /// `op` folded across all `operands`.
    ///
    /// Type inference and validation require at least two operands of one type
    /// and reject `BinaryOperator::Sub` (non-associative).
    NaryOp {
        /// Operation folded across the operands.
        op: BinaryOperator,
        operands: Vec<CircuitNode>,
    },
}
85
/// A two-input operation used by `BinaryOp` and `NaryOp` circuit nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    /// Addition; type inference requires numeric operands of the same type.
    Add,
    /// Subtraction; type inference requires numeric operands of the same type.
    Sub,
    /// Multiplication; type inference requires numeric operands of the same type.
    Mul,
    /// Logical AND; type inference requires boolean operands.
    And,
    /// Logical OR; type inference requires boolean operands.
    Or,
    /// Logical XOR; type inference requires boolean operands.
    Xor,
}

impl BinaryOperator {
    /// Returns the symbol used when rendering this operator.
    ///
    /// The result is a string literal, so it is `'static` rather than tied to
    /// the borrow of `self`; callers may keep it after the operator goes away.
    pub fn as_str(&self) -> &'static str {
        match self {
            BinaryOperator::Add => "+",
            BinaryOperator::Sub => "-",
            BinaryOperator::Mul => "*",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Xor => "XOR",
        }
    }
}
110
/// A single-input operation used by `UnaryOp` circuit nodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    /// Logical NOT; type inference requires a boolean operand.
    Not,
    /// Arithmetic negation; type inference requires a numeric operand.
    Neg,
}

impl UnaryOperator {
    /// Returns the symbol used when rendering this operator.
    ///
    /// The result is a string literal, so it is `'static` rather than tied to
    /// the borrow of `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            UnaryOperator::Not => "NOT",
            UnaryOperator::Neg => "-",
        }
    }
}
127
/// A comparison used by `Compare` circuit nodes; comparisons always yield a boolean.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOperator {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}

impl CompareOperator {
    /// Returns the symbol used when rendering this comparison.
    ///
    /// The result is a string literal, so it is `'static` rather than tied to
    /// the borrow of `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompareOperator::Eq => "=",
            CompareOperator::Ne => "!=",
            CompareOperator::Lt => "<",
            CompareOperator::Le => "<=",
            CompareOperator::Gt => ">",
            CompareOperator::Ge => ">=",
        }
    }
}
152
/// A plaintext value that can appear as a `CircuitNode::Constant`.
///
/// Each variant infers to the `EncryptedType` of the same name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitValue {
    /// A boolean constant.
    Bool(bool),
    /// An 8-bit unsigned constant.
    U8(u8),
    /// A 16-bit unsigned constant.
    U16(u16),
    /// A 32-bit unsigned constant.
    U32(u32),
    /// A 64-bit unsigned constant.
    U64(u64),
}
162
163impl std::fmt::Display for CircuitNode {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 match self {
166 CircuitNode::Load(name) => write!(f, "Load({})", name),
167 CircuitNode::Constant(value) => match value {
168 CircuitValue::Bool(v) => write!(f, "Const({})", v),
169 CircuitValue::U8(v) => write!(f, "Const({}u8)", v),
170 CircuitValue::U16(v) => write!(f, "Const({}u16)", v),
171 CircuitValue::U32(v) => write!(f, "Const({}u32)", v),
172 CircuitValue::U64(v) => write!(f, "Const({}u64)", v),
173 },
174 CircuitNode::EncryptedConstant {
175 data,
176 original_type,
177 } => {
178 write!(f, "EncryptedConst({}, {} bytes)", original_type, data.len())
179 }
180 CircuitNode::BinaryOp { op, left, right } => {
181 write!(f, "({} {} {})", left, op.as_str(), right)
182 }
183 CircuitNode::UnaryOp { op, operand } => {
184 write!(f, "{}({})", op.as_str(), operand)
185 }
186 CircuitNode::Compare { op, left, right } => {
187 write!(f, "({} {} {})", left, op.as_str(), right)
188 }
189 CircuitNode::NaryOp { op, operands } => {
190 write!(f, "{}(", op.as_str())?;
191 for (i, operand) in operands.iter().enumerate() {
192 if i > 0 {
193 write!(f, ", ")?;
194 }
195 write!(f, "{}", operand)?;
196 }
197 write!(f, ")")
198 }
199 }
200 }
201}
202
/// Plaintext type of a value flowing through a circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptedType {
    Bool,
    U8,
    U16,
    U32,
    U64,
}

impl EncryptedType {
    /// Number of plaintext bits the type occupies (1 for `Bool`).
    pub fn bit_width(&self) -> usize {
        match *self {
            Self::Bool => 1,
            Self::U8 => 8,
            Self::U16 => 16,
            Self::U32 => 32,
            Self::U64 => 64,
        }
    }

    /// `true` for every unsigned integer type, i.e. everything except `Bool`.
    pub fn is_numeric(&self) -> bool {
        !self.is_boolean()
    }

    /// `true` only for `Bool`.
    pub fn is_boolean(&self) -> bool {
        *self == Self::Bool
    }
}

impl std::fmt::Display for EncryptedType {
    /// Writes the Rust-style type name: `bool`, `u8`, `u16`, `u32` or `u64`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::Bool => "bool",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
        })
    }
}
247
/// A type-checked circuit together with precomputed metrics.
///
/// Instances produced by `Circuit::new` have `result_type`, `depth` and
/// `gate_count` derived from `root`. The fields are public and are not
/// recomputed if they are mutated afterwards.
#[derive(Debug, Clone)]
pub struct Circuit {
    /// Root of the expression tree.
    pub root: CircuitNode,

    /// Types of the variables that `Load` nodes may reference.
    pub variable_types: HashMap<String, EncryptedType>,

    /// Type produced by evaluating `root`, as inferred by `Circuit::new`.
    pub result_type: EncryptedType,

    /// Longest path from `root` to a leaf, with leaves counting as 1 and an
    /// n-ary node contributing ceil(log2(n)) levels (at least 1).
    pub depth: usize,

    /// Number of operation nodes; an n-ary node counts as n - 1 gates.
    pub gate_count: usize,
}
266
267impl Circuit {
268 pub fn new(root: CircuitNode, variable_types: HashMap<String, EncryptedType>) -> Result<Self> {
270 let result_type = Self::infer_type(&root, &variable_types)?;
271 let depth = Self::compute_depth(&root);
272 let gate_count = Self::count_gates(&root);
273
274 Ok(Self {
275 root,
276 variable_types,
277 result_type,
278 depth,
279 gate_count,
280 })
281 }
282
283 fn infer_type(
285 node: &CircuitNode,
286 variable_types: &HashMap<String, EncryptedType>,
287 ) -> Result<EncryptedType> {
288 match node {
289 CircuitNode::Load(name) => variable_types.get(name).copied().ok_or_else(|| {
290 AmateRSError::FheComputation(ErrorContext::new(format!(
291 "Undefined variable: {}",
292 name
293 )))
294 }),
295
296 CircuitNode::Constant(value) => Ok(match value {
297 CircuitValue::Bool(_) => EncryptedType::Bool,
298 CircuitValue::U8(_) => EncryptedType::U8,
299 CircuitValue::U16(_) => EncryptedType::U16,
300 CircuitValue::U32(_) => EncryptedType::U32,
301 CircuitValue::U64(_) => EncryptedType::U64,
302 }),
303
304 CircuitNode::EncryptedConstant { original_type, .. } => {
305 Ok(match original_type {
306 ConstantType::Boolean => EncryptedType::Bool,
307 ConstantType::Integer | ConstantType::Float | ConstantType::Bytes => {
311 EncryptedType::U64
312 }
313 })
314 }
315
316 CircuitNode::BinaryOp { op, left, right } => {
317 let left_type = Self::infer_type(left, variable_types)?;
318 let right_type = Self::infer_type(right, variable_types)?;
319
320 match op {
321 BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
322 if left_type != EncryptedType::Bool || right_type != EncryptedType::Bool {
323 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
324 "Logical operation requires boolean operands, got {} and {}",
325 left_type, right_type
326 ))));
327 }
328 Ok(EncryptedType::Bool)
329 }
330
331 BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
332 if !left_type.is_numeric() || !right_type.is_numeric() {
333 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
334 "Arithmetic operation requires numeric operands, got {} and {}",
335 left_type, right_type
336 ))));
337 }
338
339 if left_type != right_type {
340 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
341 "Arithmetic operation requires matching types, got {} and {}",
342 left_type, right_type
343 ))));
344 }
345
346 Ok(left_type)
347 }
348 }
349 }
350
351 CircuitNode::UnaryOp { op, operand } => {
352 let operand_type = Self::infer_type(operand, variable_types)?;
353
354 match op {
355 UnaryOperator::Not => {
356 if operand_type != EncryptedType::Bool {
357 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
358 "NOT operation requires boolean operand, got {}",
359 operand_type
360 ))));
361 }
362 Ok(EncryptedType::Bool)
363 }
364
365 UnaryOperator::Neg => {
366 if !operand_type.is_numeric() {
367 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
368 "Negation operation requires numeric operand, got {}",
369 operand_type
370 ))));
371 }
372 Ok(operand_type)
373 }
374 }
375 }
376
377 CircuitNode::Compare { left, right, .. } => {
378 let left_type = Self::infer_type(left, variable_types)?;
379 let right_type = Self::infer_type(right, variable_types)?;
380
381 if left_type != right_type {
382 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
383 "Comparison requires matching types, got {} and {}",
384 left_type, right_type
385 ))));
386 }
387
388 Ok(EncryptedType::Bool)
389 }
390 CircuitNode::NaryOp { op, operands } => {
391 if operands.len() < 2 {
392 return Err(AmateRSError::FheComputation(ErrorContext::new(
393 "NaryOp requires at least 2 operands".to_string(),
394 )));
395 }
396 let first_type = Self::infer_type(&operands[0], variable_types)?;
397 for operand in &operands[1..] {
398 let t = Self::infer_type(operand, variable_types)?;
399 if t != first_type {
400 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
401 "NaryOp operands have mismatched types: {} and {}",
402 first_type, t
403 ))));
404 }
405 }
406 match op {
407 BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
408 if first_type != EncryptedType::Bool {
409 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
410 "Logical NaryOp requires boolean operands, got {}",
411 first_type
412 ))));
413 }
414 Ok(EncryptedType::Bool)
415 }
416 BinaryOperator::Add | BinaryOperator::Mul => {
417 if !first_type.is_numeric() {
418 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
419 "Arithmetic NaryOp requires numeric operands, got {}",
420 first_type
421 ))));
422 }
423 Ok(first_type)
424 }
425 BinaryOperator::Sub => Err(AmateRSError::FheComputation(ErrorContext::new(
426 "NaryOp is not valid for Sub (non-associative)".to_string(),
427 ))),
428 }
429 }
430 }
431 }
432
433 fn compute_depth(node: &CircuitNode) -> usize {
435 match node {
436 CircuitNode::Load(_)
437 | CircuitNode::Constant(_)
438 | CircuitNode::EncryptedConstant { .. } => 1,
439
440 CircuitNode::BinaryOp { left, right, .. }
441 | CircuitNode::Compare { left, right, .. } => {
442 1 + Self::compute_depth(left).max(Self::compute_depth(right))
443 }
444
445 CircuitNode::UnaryOp { operand, .. } => 1 + Self::compute_depth(operand),
446 CircuitNode::NaryOp { operands, .. } => {
447 let max_operand_depth = operands.iter().map(Self::compute_depth).max().unwrap_or(1);
448 let n = operands.len();
449 let log2_n = (n as f64).log2().ceil() as usize;
450 log2_n.max(1) + max_operand_depth
451 }
452 }
453 }
454
455 fn count_gates(node: &CircuitNode) -> usize {
457 match node {
458 CircuitNode::Load(_)
459 | CircuitNode::Constant(_)
460 | CircuitNode::EncryptedConstant { .. } => 0,
461
462 CircuitNode::BinaryOp { left, right, .. }
463 | CircuitNode::Compare { left, right, .. } => {
464 1 + Self::count_gates(left) + Self::count_gates(right)
465 }
466
467 CircuitNode::UnaryOp { operand, .. } => 1 + Self::count_gates(operand),
468 CircuitNode::NaryOp { operands, .. } => {
469 let inner_gates: usize = operands.iter().map(Self::count_gates).sum();
470 inner_gates + operands.len().saturating_sub(1)
471 }
472 }
473 }
474
475 pub fn validate(&self) -> Result<()> {
477 Self::validate_node(&self.root, &self.variable_types)?;
478 Ok(())
479 }
480
481 fn validate_node(
482 node: &CircuitNode,
483 variable_types: &HashMap<String, EncryptedType>,
484 ) -> Result<()> {
485 match node {
486 CircuitNode::Load(name) => {
487 if !variable_types.contains_key(name) {
488 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
489 "Undefined variable: {}",
490 name
491 ))));
492 }
493 Ok(())
494 }
495
496 CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => Ok(()),
497
498 CircuitNode::BinaryOp { left, right, .. }
499 | CircuitNode::Compare { left, right, .. } => {
500 Self::validate_node(left, variable_types)?;
501 Self::validate_node(right, variable_types)?;
502 Ok(())
503 }
504
505 CircuitNode::UnaryOp { operand, .. } => Self::validate_node(operand, variable_types),
506 CircuitNode::NaryOp { op, operands } => {
507 if operands.len() < 2 {
508 return Err(AmateRSError::FheComputation(ErrorContext::new(
509 "NaryOp requires at least 2 operands".to_string(),
510 )));
511 }
512 if op == &BinaryOperator::Sub {
513 return Err(AmateRSError::FheComputation(ErrorContext::new(
514 "NaryOp is not valid for Sub (non-associative)".to_string(),
515 )));
516 }
517 for operand in operands {
518 Self::validate_node(operand, variable_types)?;
519 }
520 Ok(())
521 }
522 }
523 }
524}
525
526#[derive(Default)]
528pub struct CircuitBuilder {
529 variable_types: HashMap<String, EncryptedType>,
530}
531
532impl CircuitBuilder {
533 pub fn new() -> Self {
534 Self::default()
535 }
536
537 pub fn variable_types(&self) -> &HashMap<String, EncryptedType> {
539 &self.variable_types
540 }
541
542 pub fn variable_types_clone(&self) -> HashMap<String, EncryptedType> {
544 self.variable_types.clone()
545 }
546
547 pub fn declare_variable(&mut self, name: impl Into<String>, ty: EncryptedType) -> &mut Self {
549 self.variable_types.insert(name.into(), ty);
550 self
551 }
552
553 pub fn build(&self, root: CircuitNode) -> Result<Circuit> {
555 Circuit::new(root, self.variable_types.clone())
556 }
557
558 pub fn load(&self, name: impl Into<String>) -> CircuitNode {
560 CircuitNode::Load(name.into())
561 }
562
563 pub fn constant(&self, value: CircuitValue) -> CircuitNode {
565 CircuitNode::Constant(value)
566 }
567
568 pub fn add(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
570 CircuitNode::BinaryOp {
571 op: BinaryOperator::Add,
572 left: Box::new(left),
573 right: Box::new(right),
574 }
575 }
576
577 pub fn sub(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
579 CircuitNode::BinaryOp {
580 op: BinaryOperator::Sub,
581 left: Box::new(left),
582 right: Box::new(right),
583 }
584 }
585
586 pub fn mul(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
588 CircuitNode::BinaryOp {
589 op: BinaryOperator::Mul,
590 left: Box::new(left),
591 right: Box::new(right),
592 }
593 }
594
595 pub fn and(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
597 CircuitNode::BinaryOp {
598 op: BinaryOperator::And,
599 left: Box::new(left),
600 right: Box::new(right),
601 }
602 }
603
604 pub fn or(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
606 CircuitNode::BinaryOp {
607 op: BinaryOperator::Or,
608 left: Box::new(left),
609 right: Box::new(right),
610 }
611 }
612
613 pub fn xor(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
615 CircuitNode::BinaryOp {
616 op: BinaryOperator::Xor,
617 left: Box::new(left),
618 right: Box::new(right),
619 }
620 }
621
622 pub fn not(&self, operand: CircuitNode) -> CircuitNode {
624 CircuitNode::UnaryOp {
625 op: UnaryOperator::Not,
626 operand: Box::new(operand),
627 }
628 }
629
630 pub fn eq(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
632 CircuitNode::Compare {
633 op: CompareOperator::Eq,
634 left: Box::new(left),
635 right: Box::new(right),
636 }
637 }
638
639 pub fn lt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
641 CircuitNode::Compare {
642 op: CompareOperator::Lt,
643 left: Box::new(left),
644 right: Box::new(right),
645 }
646 }
647
648 pub fn gt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
650 CircuitNode::Compare {
651 op: CompareOperator::Gt,
652 left: Box::new(left),
653 right: Box::new(right),
654 }
655 }
656
657 pub fn encrypted_constant(&self, data: Vec<u8>, original_type: ConstantType) -> CircuitNode {
659 CircuitNode::EncryptedConstant {
660 data,
661 original_type,
662 }
663 }
664}
665
666pub fn encrypt_constant(value: &CircuitValue, key: &[u8]) -> Result<Vec<u8>> {
680 if key.is_empty() {
681 return Err(AmateRSError::FheComputation(ErrorContext::new(
682 "Encryption key must not be empty".to_string(),
683 )));
684 }
685
686 let (type_tag, plaintext): (u8, Vec<u8>) = match value {
688 CircuitValue::Bool(v) => (0x00, vec![if *v { 1 } else { 0 }]),
689 CircuitValue::U8(v) => (0x01, v.to_le_bytes().to_vec()),
690 CircuitValue::U16(v) => (0x02, v.to_le_bytes().to_vec()),
691 CircuitValue::U32(v) => (0x03, v.to_le_bytes().to_vec()),
692 CircuitValue::U64(v) => (0x04, v.to_le_bytes().to_vec()),
693 };
694
695 let keystream = derive_keystream(key, plaintext.len());
697
698 let ciphertext: Vec<u8> = plaintext
700 .iter()
701 .zip(keystream.iter())
702 .map(|(p, k)| p ^ k)
703 .collect();
704
705 let mut output = Vec::with_capacity(1 + ciphertext.len());
707 output.push(type_tag);
708 output.extend_from_slice(&ciphertext);
709
710 Ok(output)
711}
712
713pub fn decrypt_constant(data: &[u8], key: &[u8]) -> Result<CircuitValue> {
718 if key.is_empty() {
719 return Err(AmateRSError::FheComputation(ErrorContext::new(
720 "Decryption key must not be empty".to_string(),
721 )));
722 }
723 if data.is_empty() {
724 return Err(AmateRSError::FheComputation(ErrorContext::new(
725 "Encrypted constant data is empty".to_string(),
726 )));
727 }
728
729 let type_tag = data[0];
730 let ciphertext = &data[1..];
731
732 let keystream = derive_keystream(key, ciphertext.len());
733 let plaintext: Vec<u8> = ciphertext
734 .iter()
735 .zip(keystream.iter())
736 .map(|(c, k)| c ^ k)
737 .collect();
738
739 match type_tag {
740 0x00 => {
741 if plaintext.is_empty() {
742 return Err(AmateRSError::FheComputation(ErrorContext::new(
743 "Encrypted boolean constant has no payload".to_string(),
744 )));
745 }
746 Ok(CircuitValue::Bool(plaintext[0] != 0))
747 }
748 0x01 => {
749 let arr: [u8; 1] = plaintext.as_slice().try_into().map_err(|_| {
750 AmateRSError::FheComputation(ErrorContext::new(
751 "Invalid encrypted u8 constant length".to_string(),
752 ))
753 })?;
754 Ok(CircuitValue::U8(u8::from_le_bytes(arr)))
755 }
756 0x02 => {
757 let arr: [u8; 2] = plaintext.as_slice().try_into().map_err(|_| {
758 AmateRSError::FheComputation(ErrorContext::new(
759 "Invalid encrypted u16 constant length".to_string(),
760 ))
761 })?;
762 Ok(CircuitValue::U16(u16::from_le_bytes(arr)))
763 }
764 0x03 => {
765 let arr: [u8; 4] = plaintext.as_slice().try_into().map_err(|_| {
766 AmateRSError::FheComputation(ErrorContext::new(
767 "Invalid encrypted u32 constant length".to_string(),
768 ))
769 })?;
770 Ok(CircuitValue::U32(u32::from_le_bytes(arr)))
771 }
772 0x04 => {
773 let arr: [u8; 8] = plaintext.as_slice().try_into().map_err(|_| {
774 AmateRSError::FheComputation(ErrorContext::new(
775 "Invalid encrypted u64 constant length".to_string(),
776 ))
777 })?;
778 Ok(CircuitValue::U64(u64::from_le_bytes(arr)))
779 }
780 _ => Err(AmateRSError::FheComputation(ErrorContext::new(format!(
781 "Unknown encrypted constant type tag: 0x{:02x}",
782 type_tag
783 )))),
784 }
785}
786
787pub fn encrypt_circuit_constants(node: &CircuitNode, key: &[u8]) -> Result<CircuitNode> {
792 match node {
793 CircuitNode::Load(name) => Ok(CircuitNode::Load(name.clone())),
794
795 CircuitNode::Constant(value) => {
796 let data = encrypt_constant(value, key)?;
797 let original_type = match value {
798 CircuitValue::Bool(_) => ConstantType::Boolean,
799 CircuitValue::U8(_)
800 | CircuitValue::U16(_)
801 | CircuitValue::U32(_)
802 | CircuitValue::U64(_) => ConstantType::Integer,
803 };
804 Ok(CircuitNode::EncryptedConstant {
805 data,
806 original_type,
807 })
808 }
809
810 CircuitNode::EncryptedConstant {
812 data,
813 original_type,
814 } => Ok(CircuitNode::EncryptedConstant {
815 data: data.clone(),
816 original_type: *original_type,
817 }),
818
819 CircuitNode::BinaryOp { op, left, right } => {
820 let left = encrypt_circuit_constants(left, key)?;
821 let right = encrypt_circuit_constants(right, key)?;
822 Ok(CircuitNode::BinaryOp {
823 op: *op,
824 left: Box::new(left),
825 right: Box::new(right),
826 })
827 }
828
829 CircuitNode::UnaryOp { op, operand } => {
830 let operand = encrypt_circuit_constants(operand, key)?;
831 Ok(CircuitNode::UnaryOp {
832 op: *op,
833 operand: Box::new(operand),
834 })
835 }
836
837 CircuitNode::Compare { op, left, right } => {
838 let left = encrypt_circuit_constants(left, key)?;
839 let right = encrypt_circuit_constants(right, key)?;
840 Ok(CircuitNode::Compare {
841 op: *op,
842 left: Box::new(left),
843 right: Box::new(right),
844 })
845 }
846 CircuitNode::NaryOp { op, operands } => {
847 let new_operands: Result<Vec<CircuitNode>> = operands
848 .iter()
849 .map(|o| encrypt_circuit_constants(o, key))
850 .collect();
851 Ok(CircuitNode::NaryOp {
852 op: *op,
853 operands: new_operands?,
854 })
855 }
856 }
857}
858
/// Derives `length` keystream bytes from `key`.
///
/// The stream is the concatenation of 8-byte blocks. Block `i` is the 64-bit
/// FNV-1a hash of `key || i.to_le_bytes()` (with `i` as a `u64` counter from 0),
/// written little-endian. The last block is truncated to fit.
///
/// NOTE(review): FNV-1a is not a cryptographic hash, and the stream depends
/// only on `key` (no nonce). Every call with the same key returns the same
/// prefix, so this provides obfuscation rather than confidentiality.
fn derive_keystream(key: &[u8], length: usize) -> Vec<u8> {
    /// 64-bit FNV-1a offset basis.
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;

    /// Continues a 64-bit FNV-1a hash from `state` over `bytes`.
    fn fnv1a(mut state: u64, bytes: &[u8]) -> u64 {
        /// 64-bit FNV prime.
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
        for &byte in bytes {
            state ^= u64::from(byte);
            state = state.wrapping_mul(FNV_PRIME);
        }
        state
    }

    // FNV-1a is streaming and the key prefix is identical for every block,
    // so hash it once instead of once per 8 output bytes.
    let key_state = fnv1a(FNV_OFFSET_BASIS, key);

    let mut keystream = Vec::with_capacity(length);
    let mut block_index: u64 = 0;
    while keystream.len() < length {
        let block = fnv1a(key_state, &block_index.to_le_bytes()).to_le_bytes();
        let take = block.len().min(length - keystream.len());
        keystream.extend_from_slice(&block[..take]);
        block_index += 1;
    }
    keystream
}
893
894pub fn is_encrypted_constant(node: &CircuitNode) -> bool {
896 matches!(node, CircuitNode::EncryptedConstant { .. })
897}
898
899pub fn count_plaintext_constants(node: &CircuitNode) -> usize {
901 match node {
902 CircuitNode::Constant(_) => 1,
903 CircuitNode::EncryptedConstant { .. } | CircuitNode::Load(_) => 0,
904 CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
905 count_plaintext_constants(left) + count_plaintext_constants(right)
906 }
907 CircuitNode::UnaryOp { operand, .. } => count_plaintext_constants(operand),
908 CircuitNode::NaryOp { operands, .. } => {
909 operands.iter().map(count_plaintext_constants).sum()
910 }
911 }
912}
913
914pub fn count_encrypted_constants(node: &CircuitNode) -> usize {
916 match node {
917 CircuitNode::EncryptedConstant { .. } => 1,
918 CircuitNode::Constant(_) | CircuitNode::Load(_) => 0,
919 CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
920 count_encrypted_constants(left) + count_encrypted_constants(right)
921 }
922 CircuitNode::UnaryOp { operand, .. } => count_encrypted_constants(operand),
923 CircuitNode::NaryOp { operands, .. } => {
924 operands.iter().map(count_encrypted_constants).sum()
925 }
926 }
927}
928
929#[derive(Debug, Clone, Default)]
934#[deprecated(
935 since = "0.1.0",
936 note = "Use CircuitOptimizer from optimizer module instead"
937)]
938pub struct CircuitOptimizer;
939
940#[allow(deprecated)]
941impl CircuitOptimizer {
942 pub fn new() -> Self {
943 Self
944 }
945
946 pub fn optimize(&self, circuit: Circuit) -> Result<Circuit> {
951 let mut advanced_optimizer = crate::compute::optimizer::CircuitOptimizer::new();
953 advanced_optimizer.optimize(circuit)
954 }
955}
956
957#[cfg(test)]
958mod tests {
959 use super::*;
960
961 #[test]
962 fn test_circuit_builder() -> Result<()> {
963 let mut builder = CircuitBuilder::new();
964 builder
965 .declare_variable("a", EncryptedType::U8)
966 .declare_variable("b", EncryptedType::U8);
967
968 let a = builder.load("a");
969 let b = builder.load("b");
970 let sum = builder.add(a, b);
971
972 let circuit = builder.build(sum)?;
973 assert_eq!(circuit.result_type, EncryptedType::U8);
974 assert_eq!(circuit.gate_count, 1);
975
976 Ok(())
977 }
978
979 #[test]
980 fn test_type_inference() -> Result<()> {
981 let mut builder = CircuitBuilder::new();
982 builder
983 .declare_variable("x", EncryptedType::Bool)
984 .declare_variable("y", EncryptedType::Bool);
985
986 let x = builder.load("x");
987 let y = builder.load("y");
988 let result = builder.and(x, y);
989
990 let circuit = builder.build(result)?;
991 assert_eq!(circuit.result_type, EncryptedType::Bool);
992
993 Ok(())
994 }
995
996 #[test]
997 fn test_type_mismatch_error() {
998 let mut builder = CircuitBuilder::new();
999 builder
1000 .declare_variable("a", EncryptedType::U8)
1001 .declare_variable("b", EncryptedType::Bool);
1002
1003 let a = builder.load("a");
1004 let b = builder.load("b");
1005 let invalid = builder.add(a, b);
1006
1007 let result = builder.build(invalid);
1008 assert!(result.is_err());
1009 }
1010
1011 #[test]
1012 #[allow(deprecated)]
1013 fn test_constant_folding() -> Result<()> {
1014 let optimizer = CircuitOptimizer::new();
1015 let builder = CircuitBuilder::new();
1016
1017 let a = builder.constant(CircuitValue::U8(5));
1018 let b = builder.constant(CircuitValue::U8(3));
1019 let sum = builder.add(a, b);
1020
1021 let circuit = Circuit::new(sum, HashMap::new())?;
1022 let optimized = optimizer.optimize(circuit)?;
1023
1024 match optimized.root {
1026 CircuitNode::Constant(CircuitValue::U8(8)) => Ok(()),
1027 _ => Err(AmateRSError::FheComputation(ErrorContext::new(
1028 "Constant folding failed".to_string(),
1029 ))),
1030 }
1031 }
1032
1033 #[test]
1034 fn test_circuit_depth() -> Result<()> {
1035 let mut builder = CircuitBuilder::new();
1036 builder
1037 .declare_variable("a", EncryptedType::U8)
1038 .declare_variable("b", EncryptedType::U8)
1039 .declare_variable("c", EncryptedType::U8);
1040
1041 let a = builder.load("a");
1042 let b = builder.load("b");
1043 let c = builder.load("c");
1044
1045 let sum1 = builder.add(a, b);
1047 let sum2 = builder.add(sum1, c);
1048
1049 let circuit = builder.build(sum2)?;
1050 assert_eq!(circuit.depth, 3); Ok(())
1053 }
1054
1055 #[test]
1058 fn test_encrypted_constant_creation() {
1059 let builder = CircuitBuilder::new();
1060 let enc = builder.encrypted_constant(vec![0xAA, 0xBB], ConstantType::Integer);
1061
1062 match enc {
1063 CircuitNode::EncryptedConstant {
1064 data,
1065 original_type,
1066 } => {
1067 assert_eq!(data, vec![0xAA, 0xBB]);
1068 assert_eq!(original_type, ConstantType::Integer);
1069 }
1070 _ => panic!("Expected EncryptedConstant"),
1071 }
1072 }
1073
1074 #[test]
1075 fn test_encrypt_constant_bool() -> Result<()> {
1076 let key = b"test-encryption-key";
1077 let value = CircuitValue::Bool(true);
1078
1079 let encrypted = encrypt_constant(&value, key)?;
1080 assert!(!encrypted.is_empty());
1081 assert_eq!(encrypted[0], 0x00);
1083
1084 let decrypted = decrypt_constant(&encrypted, key)?;
1085 assert_eq!(decrypted, value);
1086
1087 Ok(())
1088 }
1089
1090 #[test]
1091 fn test_encrypt_constant_u8() -> Result<()> {
1092 let key = b"test-key-u8";
1093 let value = CircuitValue::U8(42);
1094
1095 let encrypted = encrypt_constant(&value, key)?;
1096 assert_eq!(encrypted[0], 0x01); let decrypted = decrypt_constant(&encrypted, key)?;
1099 assert_eq!(decrypted, value);
1100
1101 Ok(())
1102 }
1103
1104 #[test]
1105 fn test_encrypt_constant_u16() -> Result<()> {
1106 let key = b"test-key-u16";
1107 let value = CircuitValue::U16(12345);
1108
1109 let encrypted = encrypt_constant(&value, key)?;
1110 assert_eq!(encrypted[0], 0x02);
1111
1112 let decrypted = decrypt_constant(&encrypted, key)?;
1113 assert_eq!(decrypted, value);
1114
1115 Ok(())
1116 }
1117
1118 #[test]
1119 fn test_encrypt_constant_u32() -> Result<()> {
1120 let key = b"test-key-u32";
1121 let value = CircuitValue::U32(1_000_000);
1122
1123 let encrypted = encrypt_constant(&value, key)?;
1124 assert_eq!(encrypted[0], 0x03);
1125
1126 let decrypted = decrypt_constant(&encrypted, key)?;
1127 assert_eq!(decrypted, value);
1128
1129 Ok(())
1130 }
1131
1132 #[test]
1133 fn test_encrypt_constant_u64() -> Result<()> {
1134 let key = b"test-key-u64";
1135 let value = CircuitValue::U64(0xDEAD_BEEF_CAFE_BABE);
1136
1137 let encrypted = encrypt_constant(&value, key)?;
1138 assert_eq!(encrypted[0], 0x04);
1139
1140 let decrypted = decrypt_constant(&encrypted, key)?;
1141 assert_eq!(decrypted, value);
1142
1143 Ok(())
1144 }
1145
1146 #[test]
1147 fn test_encrypt_constant_wrong_key_produces_wrong_value() -> Result<()> {
1148 let key1 = b"correct-key";
1149 let key2 = b"wrong-key!!";
1150 let value = CircuitValue::U8(42);
1151
1152 let encrypted = encrypt_constant(&value, key1)?;
1153 let decrypted = decrypt_constant(&encrypted, key2)?;
1154 assert_ne!(decrypted, value);
1156
1157 Ok(())
1158 }
1159
1160 #[test]
1161 fn test_encrypt_constant_empty_key_error() {
1162 let key: &[u8] = &[];
1163 let value = CircuitValue::U8(1);
1164
1165 let result = encrypt_constant(&value, key);
1166 assert!(result.is_err());
1167 }
1168
1169 #[test]
1170 fn test_decrypt_constant_empty_data_error() {
1171 let key = b"some-key";
1172 let result = decrypt_constant(&[], key);
1173 assert!(result.is_err());
1174 }
1175
1176 #[test]
1177 fn test_encrypt_circuit_constants_transforms_all() -> Result<()> {
1178 let builder = CircuitBuilder::new();
1179 let key = b"circuit-encryption-key";
1180
1181 let a = builder.constant(CircuitValue::U8(5));
1183 let b = builder.constant(CircuitValue::U8(3));
1184 let sum = builder.add(a, b);
1185
1186 assert_eq!(count_plaintext_constants(&sum), 2);
1188 assert_eq!(count_encrypted_constants(&sum), 0);
1189
1190 let encrypted = encrypt_circuit_constants(&sum, key)?;
1192
1193 assert_eq!(count_plaintext_constants(&encrypted), 0);
1195 assert_eq!(count_encrypted_constants(&encrypted), 2);
1196
1197 match &encrypted {
1199 CircuitNode::BinaryOp { op, left, right } => {
1200 assert_eq!(*op, BinaryOperator::Add);
1201 assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
1202 assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
1203 }
1204 _ => panic!("Expected BinaryOp after encryption"),
1205 }
1206
1207 Ok(())
1208 }
1209
1210 #[test]
1211 fn test_encrypt_circuit_constants_preserves_loads() -> Result<()> {
1212 let mut builder = CircuitBuilder::new();
1213 builder.declare_variable("x", EncryptedType::U8);
1214 let key = b"key-for-loads-test";
1215
1216 let x = builder.load("x");
1218 let c = builder.constant(CircuitValue::U8(10));
1219 let sum = builder.add(x, c);
1220
1221 let encrypted = encrypt_circuit_constants(&sum, key)?;
1222
1223 match &encrypted {
1225 CircuitNode::BinaryOp { left, right, .. } => {
1226 assert!(matches!(**left, CircuitNode::Load(ref name) if name == "x"));
1227 assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
1228 }
1229 _ => panic!("Expected BinaryOp"),
1230 }
1231
1232 Ok(())
1233 }
1234
1235 #[test]
1236 fn test_encrypt_circuit_constants_already_encrypted_pass_through() -> Result<()> {
1237 let builder = CircuitBuilder::new();
1238 let key = b"key-pass-through";
1239
1240 let enc = builder.encrypted_constant(vec![0x01, 0x02, 0x03], ConstantType::Integer);
1242 let original_data = vec![0x01, 0x02, 0x03];
1243
1244 let result = encrypt_circuit_constants(&enc, key)?;
1245
1246 match result {
1248 CircuitNode::EncryptedConstant {
1249 data,
1250 original_type,
1251 } => {
1252 assert_eq!(data, original_data);
1253 assert_eq!(original_type, ConstantType::Integer);
1254 }
1255 _ => panic!("Expected EncryptedConstant pass-through"),
1256 }
1257
1258 Ok(())
1259 }
1260
1261 #[test]
1262 fn test_encrypted_constant_display() {
1263 let node = CircuitNode::EncryptedConstant {
1264 data: vec![0xAA, 0xBB, 0xCC],
1265 original_type: ConstantType::Boolean,
1266 };
1267 let display = format!("{}", node);
1268 assert!(display.contains("EncryptedConst"));
1269 assert!(display.contains("boolean"));
1270 assert!(display.contains("3 bytes"));
1271 }
1272
1273 #[test]
1274 fn test_circuit_node_display_variants() {
1275 let load = CircuitNode::Load("x".to_string());
1277 assert_eq!(format!("{}", load), "Load(x)");
1278
1279 let constant = CircuitNode::Constant(CircuitValue::U8(42));
1281 assert_eq!(format!("{}", constant), "Const(42u8)");
1282
1283 let bool_const = CircuitNode::Constant(CircuitValue::Bool(true));
1285 assert_eq!(format!("{}", bool_const), "Const(true)");
1286 }
1287
1288 #[test]
1289 fn test_constant_type_display() {
1290 assert_eq!(format!("{}", ConstantType::Integer), "integer");
1291 assert_eq!(format!("{}", ConstantType::Boolean), "boolean");
1292 assert_eq!(format!("{}", ConstantType::Float), "float");
1293 assert_eq!(format!("{}", ConstantType::Bytes), "bytes");
1294 }
1295
#[test]
fn test_constant_type_variants() {
    // Every variant must compare equal to itself and unequal to every other.
    let all = [
        ConstantType::Integer,
        ConstantType::Boolean,
        ConstantType::Float,
        ConstantType::Bytes,
    ];

    for (row, lhs) in all.iter().enumerate() {
        for (col, rhs) in all.iter().enumerate() {
            let same_slot = row == col;
            assert_eq!(
                lhs == rhs,
                same_slot,
                "equality mismatch between {:?} and {:?}",
                lhs,
                rhs
            );
        }
    }
}
1315
#[test]
fn test_constant_type_serialization_roundtrip() -> Result<()> {
    // Each ConstantType must survive a serde_json encode/decode cycle.
    // serde errors are mapped into the crate's FheComputation error.
    let all = [
        ConstantType::Integer,
        ConstantType::Boolean,
        ConstantType::Float,
        ConstantType::Bytes,
    ];

    for original in all.iter() {
        let encoded = serde_json::to_string(original).map_err(|err| {
            AmateRSError::FheComputation(ErrorContext::new(format!(
                "Serialization failed: {}",
                err
            )))
        })?;

        let restored: ConstantType = serde_json::from_str(&encoded).map_err(|err| {
            AmateRSError::FheComputation(ErrorContext::new(format!(
                "Deserialization failed: {}",
                err
            )))
        })?;

        assert_eq!(restored, *original);
    }

    Ok(())
}
1343
#[test]
fn test_is_encrypted_constant() {
    // Only the EncryptedConstant variant is reported as encrypted.
    let positive = CircuitNode::EncryptedConstant {
        data: vec![1, 2, 3],
        original_type: ConstantType::Integer,
    };
    assert!(is_encrypted_constant(&positive));

    // A plaintext constant and a variable load are both non-encrypted.
    let negatives = [
        CircuitNode::Constant(CircuitValue::U8(5)),
        CircuitNode::Load("x".to_string()),
    ];
    for node in negatives.iter() {
        assert!(!is_encrypted_constant(node), "{:?} wrongly reported as encrypted", node);
    }
}
1358
#[test]
fn test_encrypted_constant_in_circuit_validation() -> Result<()> {
    // A circuit that is just an encrypted boolean constant should build,
    // validate, and report a Bool result type.
    let root = CircuitNode::EncryptedConstant {
        data: vec![0x00, 0x01],
        original_type: ConstantType::Boolean,
    };

    let circuit = Circuit::new(root, HashMap::new())?;
    circuit.validate()?;

    assert_eq!(circuit.result_type, EncryptedType::Bool);
    Ok(())
}
1372
#[test]
fn test_encrypted_constant_depth_and_gate_count() -> Result<()> {
    // A lone encrypted constant is expected to have depth 1 and no gates.
    let circuit_builder = CircuitBuilder::new();
    let leaf = circuit_builder.encrypted_constant(vec![0x01, 0x42], ConstantType::Integer);

    let circuit = Circuit::new(leaf, HashMap::new())?;

    assert_eq!((circuit.depth, circuit.gate_count), (1, 0));
    Ok(())
}
1385
#[test]
fn test_mixed_plain_and_encrypted_constants() -> Result<()> {
    let builder = CircuitBuilder::new();
    let key = b"mixed-circuit-key";

    // One plaintext constant, and one whose value was encrypted up front.
    let plain = builder.constant(CircuitValue::U8(10));
    let encrypted_data = encrypt_constant(&CircuitValue::U8(20), key)?;
    let enc = builder.encrypted_constant(encrypted_data, ConstantType::Integer);

    // A plaintext constant nested under a unary operator must still be counted.
    let not_node = CircuitNode::UnaryOp {
        op: UnaryOperator::Not,
        operand: Box::new(CircuitNode::Constant(CircuitValue::Bool(true))),
    };

    assert_eq!(count_plaintext_constants(&plain), 1);
    assert_eq!(count_encrypted_constants(&plain), 0);
    assert_eq!(count_plaintext_constants(&enc), 0);
    assert_eq!(count_encrypted_constants(&enc), 1);
    assert_eq!(count_plaintext_constants(&not_node), 1);
    assert_eq!(count_encrypted_constants(&not_node), 0);

    // Put both kinds into one tree so the counters are exercised on a
    // genuinely mixed circuit: each kind must be counted exactly once.
    let mixed = CircuitNode::BinaryOp {
        op: BinaryOperator::Add,
        left: Box::new(plain),
        right: Box::new(enc),
    };
    assert_eq!(count_plaintext_constants(&mixed), 1);
    assert_eq!(count_encrypted_constants(&mixed), 1);

    Ok(())
}
1415
#[test]
fn test_encrypt_constant_deterministic() -> Result<()> {
    // Encrypting the same value twice under the same key yields identical bytes.
    let key = b"deterministic-test-key";
    let value = CircuitValue::U32(999);

    let (first, second) = (
        encrypt_constant(&value, key)?,
        encrypt_constant(&value, key)?,
    );
    assert_eq!(first, second);

    Ok(())
}
1428
#[test]
fn test_encrypt_constant_different_keys_differ() -> Result<()> {
    let value = CircuitValue::U64(123456789);

    let under_alpha = encrypt_constant(&value, b"key-alpha")?;
    let under_bravo = encrypt_constant(&value, b"key-bravo")?;

    // The leading byte is the same regardless of key (presumably an
    // unencrypted header such as a type tag; confirm against encrypt_constant).
    assert_eq!(under_alpha[0], under_bravo[0]);
    // Everything after it must depend on the key.
    assert_ne!(under_alpha[1..], under_bravo[1..]);

    Ok(())
}
1444
#[test]
fn test_encrypt_decrypt_roundtrip_all_types() -> Result<()> {
    let key = b"roundtrip-all-types";

    // Smallest and largest value of every CircuitValue variant.
    let samples = [
        CircuitValue::Bool(false),
        CircuitValue::Bool(true),
        CircuitValue::U8(0),
        CircuitValue::U8(u8::MAX),
        CircuitValue::U16(0),
        CircuitValue::U16(u16::MAX),
        CircuitValue::U32(0),
        CircuitValue::U32(u32::MAX),
        CircuitValue::U64(0),
        CircuitValue::U64(u64::MAX),
    ];

    for sample in samples.iter() {
        let ciphertext = encrypt_constant(sample, key)?;
        let recovered = decrypt_constant(&ciphertext, key)?;
        assert_eq!(*sample, recovered, "Roundtrip failed for {:?}", sample);
    }

    Ok(())
}
1470
#[test]
fn test_encrypt_circuit_constants_nested() -> Result<()> {
    let builder = CircuitBuilder::new();
    let key = b"nested-circuit-key";

    // NOT(true AND false): two plaintext constants below the root.
    let t = builder.constant(CircuitValue::Bool(true));
    let f = builder.constant(CircuitValue::Bool(false));
    let and_node = builder.and(t, f);
    let not_node = builder.not(and_node);

    assert_eq!(count_plaintext_constants(&not_node), 2);
    assert_eq!(count_encrypted_constants(&not_node), 0);

    let encrypted = encrypt_circuit_constants(&not_node, key)?;

    // Every plaintext leaf must be replaced by an encrypted one.
    assert_eq!(count_plaintext_constants(&encrypted), 0);
    assert_eq!(count_encrypted_constants(&encrypted), 2);

    // The operator structure around the replaced leaves must be preserved.
    match &encrypted {
        CircuitNode::UnaryOp { op, operand } => {
            assert_eq!(*op, UnaryOperator::Not);
            match operand.as_ref() {
                CircuitNode::BinaryOp { op, left, right } => {
                    assert_eq!(*op, BinaryOperator::And);
                    assert!(is_encrypted_constant(left));
                    assert!(is_encrypted_constant(right));
                }
                _ => panic!("Expected BinaryOp inside UnaryOp"),
            }
        }
        _ => panic!("Expected UnaryOp at root"),
    }

    Ok(())
}
1508}