1use crate::error::{AmateRSError, ErrorContext, Result};
7use std::collections::HashMap;
8
9#[derive(Debug, Clone, PartialEq, Eq)]
11pub enum CircuitNode {
12 Load(String),
14
15 Constant(CircuitValue),
17
18 BinaryOp {
20 op: BinaryOperator,
21 left: Box<CircuitNode>,
22 right: Box<CircuitNode>,
23 },
24
25 UnaryOp {
27 op: UnaryOperator,
28 operand: Box<CircuitNode>,
29 },
30
31 Compare {
33 op: CompareOperator,
34 left: Box<CircuitNode>,
35 right: Box<CircuitNode>,
36 },
37}
38
/// Two-operand operators usable in [`CircuitNode::BinaryOp`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOperator {
    /// Addition of two numeric operands.
    Add,
    /// Subtraction of two numeric operands.
    Sub,
    /// Multiplication of two numeric operands.
    Mul,
    /// Logical AND of two boolean operands.
    And,
    /// Logical OR of two boolean operands.
    Or,
    /// Logical XOR of two boolean operands.
    Xor,
}
49
50impl BinaryOperator {
51 pub fn as_str(&self) -> &str {
53 match self {
54 BinaryOperator::Add => "+",
55 BinaryOperator::Sub => "-",
56 BinaryOperator::Mul => "*",
57 BinaryOperator::And => "AND",
58 BinaryOperator::Or => "OR",
59 BinaryOperator::Xor => "XOR",
60 }
61 }
62}
63
/// Single-operand operators usable in [`CircuitNode::UnaryOp`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum UnaryOperator {
    /// Logical negation of a boolean operand.
    Not,
    /// Arithmetic negation of a numeric operand.
    Neg,
}
70
71impl UnaryOperator {
72 pub fn as_str(&self) -> &str {
74 match self {
75 UnaryOperator::Not => "NOT",
76 UnaryOperator::Neg => "-",
77 }
78 }
79}
80
/// Comparison operators usable in [`CircuitNode::Compare`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompareOperator {
    /// Equal to.
    Eq,
    /// Not equal to.
    Ne,
    /// Strictly less than.
    Lt,
    /// Less than or equal to.
    Le,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal to.
    Ge,
}
91
92impl CompareOperator {
93 pub fn as_str(&self) -> &str {
95 match self {
96 CompareOperator::Eq => "=",
97 CompareOperator::Ne => "!=",
98 CompareOperator::Lt => "<",
99 CompareOperator::Le => "<=",
100 CompareOperator::Gt => ">",
101 CompareOperator::Ge => ">=",
102 }
103 }
104}
105
/// A literal value carried by a [`CircuitNode::Constant`].
///
/// Each variant corresponds to the [`EncryptedType`] variant of the same name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CircuitValue {
    /// A boolean literal.
    Bool(bool),
    /// An 8-bit unsigned literal.
    U8(u8),
    /// A 16-bit unsigned literal.
    U16(u16),
    /// A 32-bit unsigned literal.
    U32(u32),
    /// A 64-bit unsigned literal.
    U64(u64),
}
115
/// The set of types a circuit variable, constant or result can have.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum EncryptedType {
    /// A single boolean.
    Bool,
    /// An 8-bit unsigned integer.
    U8,
    /// A 16-bit unsigned integer.
    U16,
    /// A 32-bit unsigned integer.
    U32,
    /// A 64-bit unsigned integer.
    U64,
}
125
126impl EncryptedType {
127 pub fn bit_width(&self) -> usize {
129 match self {
130 EncryptedType::Bool => 1,
131 EncryptedType::U8 => 8,
132 EncryptedType::U16 => 16,
133 EncryptedType::U32 => 32,
134 EncryptedType::U64 => 64,
135 }
136 }
137
138 pub fn is_numeric(&self) -> bool {
140 !matches!(self, EncryptedType::Bool)
141 }
142
143 pub fn is_boolean(&self) -> bool {
145 matches!(self, EncryptedType::Bool)
146 }
147}
148
149impl std::fmt::Display for EncryptedType {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 EncryptedType::Bool => write!(f, "bool"),
153 EncryptedType::U8 => write!(f, "u8"),
154 EncryptedType::U16 => write!(f, "u16"),
155 EncryptedType::U32 => write!(f, "u32"),
156 EncryptedType::U64 => write!(f, "u64"),
157 }
158 }
159}
160
161#[derive(Debug, Clone)]
163pub struct Circuit {
164 pub root: CircuitNode,
166
167 pub variable_types: HashMap<String, EncryptedType>,
169
170 pub result_type: EncryptedType,
172
173 pub depth: usize,
175
176 pub gate_count: usize,
178}
179
180impl Circuit {
181 pub fn new(root: CircuitNode, variable_types: HashMap<String, EncryptedType>) -> Result<Self> {
183 let result_type = Self::infer_type(&root, &variable_types)?;
184 let depth = Self::compute_depth(&root);
185 let gate_count = Self::count_gates(&root);
186
187 Ok(Self {
188 root,
189 variable_types,
190 result_type,
191 depth,
192 gate_count,
193 })
194 }
195
196 fn infer_type(
198 node: &CircuitNode,
199 variable_types: &HashMap<String, EncryptedType>,
200 ) -> Result<EncryptedType> {
201 match node {
202 CircuitNode::Load(name) => variable_types.get(name).copied().ok_or_else(|| {
203 AmateRSError::FheComputation(ErrorContext::new(format!(
204 "Undefined variable: {}",
205 name
206 )))
207 }),
208
209 CircuitNode::Constant(value) => Ok(match value {
210 CircuitValue::Bool(_) => EncryptedType::Bool,
211 CircuitValue::U8(_) => EncryptedType::U8,
212 CircuitValue::U16(_) => EncryptedType::U16,
213 CircuitValue::U32(_) => EncryptedType::U32,
214 CircuitValue::U64(_) => EncryptedType::U64,
215 }),
216
217 CircuitNode::BinaryOp { op, left, right } => {
218 let left_type = Self::infer_type(left, variable_types)?;
219 let right_type = Self::infer_type(right, variable_types)?;
220
221 match op {
222 BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
223 if left_type != EncryptedType::Bool || right_type != EncryptedType::Bool {
224 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
225 "Logical operation requires boolean operands, got {} and {}",
226 left_type, right_type
227 ))));
228 }
229 Ok(EncryptedType::Bool)
230 }
231
232 BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
233 if !left_type.is_numeric() || !right_type.is_numeric() {
234 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
235 "Arithmetic operation requires numeric operands, got {} and {}",
236 left_type, right_type
237 ))));
238 }
239
240 if left_type != right_type {
241 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
242 "Arithmetic operation requires matching types, got {} and {}",
243 left_type, right_type
244 ))));
245 }
246
247 Ok(left_type)
248 }
249 }
250 }
251
252 CircuitNode::UnaryOp { op, operand } => {
253 let operand_type = Self::infer_type(operand, variable_types)?;
254
255 match op {
256 UnaryOperator::Not => {
257 if operand_type != EncryptedType::Bool {
258 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
259 "NOT operation requires boolean operand, got {}",
260 operand_type
261 ))));
262 }
263 Ok(EncryptedType::Bool)
264 }
265
266 UnaryOperator::Neg => {
267 if !operand_type.is_numeric() {
268 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
269 "Negation operation requires numeric operand, got {}",
270 operand_type
271 ))));
272 }
273 Ok(operand_type)
274 }
275 }
276 }
277
278 CircuitNode::Compare { left, right, .. } => {
279 let left_type = Self::infer_type(left, variable_types)?;
280 let right_type = Self::infer_type(right, variable_types)?;
281
282 if left_type != right_type {
283 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
284 "Comparison requires matching types, got {} and {}",
285 left_type, right_type
286 ))));
287 }
288
289 Ok(EncryptedType::Bool)
290 }
291 }
292 }
293
294 fn compute_depth(node: &CircuitNode) -> usize {
296 match node {
297 CircuitNode::Load(_) | CircuitNode::Constant(_) => 1,
298
299 CircuitNode::BinaryOp { left, right, .. }
300 | CircuitNode::Compare { left, right, .. } => {
301 1 + Self::compute_depth(left).max(Self::compute_depth(right))
302 }
303
304 CircuitNode::UnaryOp { operand, .. } => 1 + Self::compute_depth(operand),
305 }
306 }
307
308 fn count_gates(node: &CircuitNode) -> usize {
310 match node {
311 CircuitNode::Load(_) | CircuitNode::Constant(_) => 0,
312
313 CircuitNode::BinaryOp { left, right, .. }
314 | CircuitNode::Compare { left, right, .. } => {
315 1 + Self::count_gates(left) + Self::count_gates(right)
316 }
317
318 CircuitNode::UnaryOp { operand, .. } => 1 + Self::count_gates(operand),
319 }
320 }
321
322 pub fn validate(&self) -> Result<()> {
324 Self::validate_node(&self.root, &self.variable_types)?;
325 Ok(())
326 }
327
328 fn validate_node(
329 node: &CircuitNode,
330 variable_types: &HashMap<String, EncryptedType>,
331 ) -> Result<()> {
332 match node {
333 CircuitNode::Load(name) => {
334 if !variable_types.contains_key(name) {
335 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
336 "Undefined variable: {}",
337 name
338 ))));
339 }
340 Ok(())
341 }
342
343 CircuitNode::Constant(_) => Ok(()),
344
345 CircuitNode::BinaryOp { left, right, .. }
346 | CircuitNode::Compare { left, right, .. } => {
347 Self::validate_node(left, variable_types)?;
348 Self::validate_node(right, variable_types)?;
349 Ok(())
350 }
351
352 CircuitNode::UnaryOp { operand, .. } => Self::validate_node(operand, variable_types),
353 }
354 }
355}
356
357#[derive(Default)]
359pub struct CircuitBuilder {
360 variable_types: HashMap<String, EncryptedType>,
361}
362
363impl CircuitBuilder {
364 pub fn new() -> Self {
365 Self::default()
366 }
367
368 pub fn variable_types(&self) -> &HashMap<String, EncryptedType> {
370 &self.variable_types
371 }
372
373 pub fn variable_types_clone(&self) -> HashMap<String, EncryptedType> {
375 self.variable_types.clone()
376 }
377
378 pub fn declare_variable(&mut self, name: impl Into<String>, ty: EncryptedType) -> &mut Self {
380 self.variable_types.insert(name.into(), ty);
381 self
382 }
383
384 pub fn build(&self, root: CircuitNode) -> Result<Circuit> {
386 Circuit::new(root, self.variable_types.clone())
387 }
388
389 pub fn load(&self, name: impl Into<String>) -> CircuitNode {
391 CircuitNode::Load(name.into())
392 }
393
394 pub fn constant(&self, value: CircuitValue) -> CircuitNode {
396 CircuitNode::Constant(value)
397 }
398
399 pub fn add(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
401 CircuitNode::BinaryOp {
402 op: BinaryOperator::Add,
403 left: Box::new(left),
404 right: Box::new(right),
405 }
406 }
407
408 pub fn sub(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
410 CircuitNode::BinaryOp {
411 op: BinaryOperator::Sub,
412 left: Box::new(left),
413 right: Box::new(right),
414 }
415 }
416
417 pub fn mul(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
419 CircuitNode::BinaryOp {
420 op: BinaryOperator::Mul,
421 left: Box::new(left),
422 right: Box::new(right),
423 }
424 }
425
426 pub fn and(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
428 CircuitNode::BinaryOp {
429 op: BinaryOperator::And,
430 left: Box::new(left),
431 right: Box::new(right),
432 }
433 }
434
435 pub fn or(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
437 CircuitNode::BinaryOp {
438 op: BinaryOperator::Or,
439 left: Box::new(left),
440 right: Box::new(right),
441 }
442 }
443
444 pub fn xor(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
446 CircuitNode::BinaryOp {
447 op: BinaryOperator::Xor,
448 left: Box::new(left),
449 right: Box::new(right),
450 }
451 }
452
453 pub fn not(&self, operand: CircuitNode) -> CircuitNode {
455 CircuitNode::UnaryOp {
456 op: UnaryOperator::Not,
457 operand: Box::new(operand),
458 }
459 }
460
461 pub fn eq(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
463 CircuitNode::Compare {
464 op: CompareOperator::Eq,
465 left: Box::new(left),
466 right: Box::new(right),
467 }
468 }
469
470 pub fn lt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
472 CircuitNode::Compare {
473 op: CompareOperator::Lt,
474 left: Box::new(left),
475 right: Box::new(right),
476 }
477 }
478
479 pub fn gt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
481 CircuitNode::Compare {
482 op: CompareOperator::Gt,
483 left: Box::new(left),
484 right: Box::new(right),
485 }
486 }
487}
488
/// Legacy circuit optimizer kept so existing callers keep compiling.
///
/// It is a stateless unit struct; its `optimize` method forwards to
/// `crate::compute::optimizer::CircuitOptimizer`.
#[derive(Clone, Debug, Default)]
#[deprecated(since = "0.1.0", note = "Use CircuitOptimizer from optimizer module instead")]
pub struct CircuitOptimizer;
499
500#[allow(deprecated)]
501impl CircuitOptimizer {
502 pub fn new() -> Self {
503 Self
504 }
505
506 pub fn optimize(&self, circuit: Circuit) -> Result<Circuit> {
511 let mut advanced_optimizer = crate::compute::optimizer::CircuitOptimizer::new();
513 advanced_optimizer.optimize(circuit)
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh builder with each `(name, type)` pair declared.
    fn builder_with(vars: &[(&str, EncryptedType)]) -> CircuitBuilder {
        let mut builder = CircuitBuilder::new();
        for &(name, ty) in vars {
            builder.declare_variable(name, ty);
        }
        builder
    }

    #[test]
    fn test_circuit_builder() -> Result<()> {
        let builder = builder_with(&[("a", EncryptedType::U8), ("b", EncryptedType::U8)]);
        let circuit = builder.build(builder.add(builder.load("a"), builder.load("b")))?;

        assert_eq!(circuit.result_type, EncryptedType::U8);
        assert_eq!(circuit.gate_count, 1);
        Ok(())
    }

    #[test]
    fn test_type_inference() -> Result<()> {
        let builder = builder_with(&[("x", EncryptedType::Bool), ("y", EncryptedType::Bool)]);
        let conjunction = builder.and(builder.load("x"), builder.load("y"));

        assert_eq!(builder.build(conjunction)?.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_type_mismatch_error() {
        let builder = builder_with(&[("a", EncryptedType::U8), ("b", EncryptedType::Bool)]);
        let mixed = builder.add(builder.load("a"), builder.load("b"));

        assert!(builder.build(mixed).is_err());
    }

    #[test]
    #[allow(deprecated)]
    fn test_constant_folding() -> Result<()> {
        let builder = CircuitBuilder::new();
        let five_plus_three = builder.add(
            builder.constant(CircuitValue::U8(5)),
            builder.constant(CircuitValue::U8(3)),
        );

        let circuit = Circuit::new(five_plus_three, HashMap::new())?;
        let optimized = CircuitOptimizer::new().optimize(circuit)?;

        if let CircuitNode::Constant(CircuitValue::U8(8)) = optimized.root {
            Ok(())
        } else {
            Err(AmateRSError::FheComputation(ErrorContext::new(
                "Constant folding failed".to_string(),
            )))
        }
    }

    #[test]
    fn test_circuit_depth() -> Result<()> {
        let builder = builder_with(&[
            ("a", EncryptedType::U8),
            ("b", EncryptedType::U8),
            ("c", EncryptedType::U8),
        ]);

        // (a + b) + c: leaf, inner add, outer add => three levels.
        let nested = builder.add(
            builder.add(builder.load("a"), builder.load("b")),
            builder.load("c"),
        );
        let circuit = builder.build(nested)?;

        assert_eq!(circuit.depth, 3);
        Ok(())
    }
}