1pub mod circuit;
39pub mod key_manager;
40pub mod keys;
41pub mod operations;
42pub mod optimizer;
43pub mod predicate;
44
45#[cfg(test)]
46mod filter_tests;
47
48pub use circuit::{
50 BinaryOperator, Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator,
51 EncryptedType, UnaryOperator,
52};
53pub use key_manager::{ClientId, KeyManager};
54pub use keys::{FheKeyPair, InMemoryKeyStorage, KeyStorage};
55pub use operations::{EncryptedBool, EncryptedU8, EncryptedU16, EncryptedU32, EncryptedU64};
56pub use optimizer::{CircuitOptimizer, DependencyGraph, NodeId, OptimizationStats};
57pub use predicate::{PredicateCompiler, compile_predicate};
58
59use crate::error::{AmateRSError, ErrorContext, Result};
60use crate::types::CipherBlob;
61use std::collections::HashMap;
62
/// Evaluates FHE [`Circuit`]s over serialized ciphertext inputs.
///
/// When optimization is enabled, each circuit is passed through a
/// [`CircuitOptimizer`] before evaluation.
#[derive(Debug, Clone)]
pub struct FheExecutor {
    /// Optimizer instance owned by the executor. `execute` clones it for each call
    /// when optimization is enabled; `optimization_stats` and `dependency_graph`
    /// read from this stored instance.
    optimizer: CircuitOptimizer,
    /// Whether `execute` runs the optimizer before evaluating a circuit.
    optimization_enabled: bool,
}
72
73impl FheExecutor {
74 pub fn new() -> Self {
76 Self {
77 optimizer: CircuitOptimizer::new(),
78 optimization_enabled: true,
79 }
80 }
81
82 pub fn with_optimization(enable: bool) -> Self {
84 Self {
85 optimizer: if enable {
86 CircuitOptimizer::new()
87 } else {
88 CircuitOptimizer::disabled()
89 },
90 optimization_enabled: enable,
91 }
92 }
93
94 pub fn optimization_stats(&self) -> &OptimizationStats {
96 self.optimizer.stats()
97 }
98
99 pub fn dependency_graph(&self) -> &DependencyGraph {
101 self.optimizer.dependency_graph()
102 }
103
104 #[cfg(feature = "compute")]
115 pub fn execute(
116 &self,
117 circuit: &Circuit,
118 inputs: &HashMap<String, CipherBlob>,
119 ) -> Result<CipherBlob> {
120 circuit.validate()?;
122
123 for var_name in circuit.variable_types.keys() {
125 if !inputs.contains_key(var_name) {
126 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
127 "Missing input for variable: {}",
128 var_name
129 ))));
130 }
131 }
132
133 let optimized = if self.optimization_enabled {
135 let mut optimizer = self.optimizer.clone();
137 optimizer.optimize(circuit.clone())?
138 } else {
139 circuit.clone()
140 };
141
142 let result_value = self.execute_node(&optimized.root, inputs, &optimized.variable_types)?;
144
145 match result_value {
147 EncryptedValue::Bool(v) => v.to_cipher_blob(),
148 EncryptedValue::U8(v) => v.to_cipher_blob(),
149 EncryptedValue::U16(v) => v.to_cipher_blob(),
150 EncryptedValue::U32(v) => v.to_cipher_blob(),
151 EncryptedValue::U64(v) => v.to_cipher_blob(),
152 }
153 }
154
155 #[cfg(not(feature = "compute"))]
157 pub fn execute(
158 &self,
159 _circuit: &Circuit,
160 _inputs: &HashMap<String, CipherBlob>,
161 ) -> Result<CipherBlob> {
162 Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
163 "FHE compute feature is not enabled".to_string(),
164 )))
165 }
166
167 #[cfg(feature = "compute")]
169 #[allow(clippy::only_used_in_recursion)]
170 fn execute_node(
171 &self,
172 node: &CircuitNode,
173 inputs: &HashMap<String, CipherBlob>,
174 variable_types: &HashMap<String, EncryptedType>,
175 ) -> Result<EncryptedValue> {
176 match node {
177 CircuitNode::Load(name) => {
178 let blob = inputs.get(name).ok_or_else(|| {
179 AmateRSError::FheComputation(ErrorContext::new(format!(
180 "Missing input: {}",
181 name
182 )))
183 })?;
184
185 let var_type = variable_types.get(name).ok_or_else(|| {
186 AmateRSError::FheComputation(ErrorContext::new(format!(
187 "Unknown variable type: {}",
188 name
189 )))
190 })?;
191
192 match var_type {
193 EncryptedType::Bool => {
194 Ok(EncryptedValue::Bool(EncryptedBool::from_cipher_blob(blob)?))
195 }
196 EncryptedType::U8 => {
197 Ok(EncryptedValue::U8(EncryptedU8::from_cipher_blob(blob)?))
198 }
199 EncryptedType::U16 => {
200 Ok(EncryptedValue::U16(EncryptedU16::from_cipher_blob(blob)?))
201 }
202 EncryptedType::U32 => {
203 Ok(EncryptedValue::U32(EncryptedU32::from_cipher_blob(blob)?))
204 }
205 EncryptedType::U64 => {
206 Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(blob)?))
207 }
208 }
209 }
210
211 CircuitNode::Constant(_value) => {
212 Err(AmateRSError::FheComputation(ErrorContext::new(
214 "Encrypted constants not yet supported".to_string(),
215 )))
216 }
217
218 CircuitNode::BinaryOp { op, left, right } => {
219 let left_val = self.execute_node(left, inputs, variable_types)?;
220 let right_val = self.execute_node(right, inputs, variable_types)?;
221
222 match (op, left_val, right_val) {
223 (BinaryOperator::And, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
225 Ok(EncryptedValue::Bool(l.and(&r)))
226 }
227 (BinaryOperator::Or, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
228 Ok(EncryptedValue::Bool(l.or(&r)))
229 }
230 (BinaryOperator::Xor, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
231 Ok(EncryptedValue::Bool(l.xor(&r)))
232 }
233
234 (BinaryOperator::Add, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
236 Ok(EncryptedValue::U8(l.add(&r)))
237 }
238 (BinaryOperator::Sub, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
239 Ok(EncryptedValue::U8(l.sub(&r)))
240 }
241 (BinaryOperator::Mul, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
242 Ok(EncryptedValue::U8(l.mul(&r)))
243 }
244
245 (BinaryOperator::Add, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
247 Ok(EncryptedValue::U16(l.add(&r)))
248 }
249 (BinaryOperator::Sub, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
250 Ok(EncryptedValue::U16(l.sub(&r)))
251 }
252 (BinaryOperator::Mul, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
253 Ok(EncryptedValue::U16(l.mul(&r)))
254 }
255
256 (BinaryOperator::Add, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
258 Ok(EncryptedValue::U32(l.add(&r)))
259 }
260 (BinaryOperator::Sub, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
261 Ok(EncryptedValue::U32(l.sub(&r)))
262 }
263 (BinaryOperator::Mul, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
264 Ok(EncryptedValue::U32(l.mul(&r)))
265 }
266
267 (BinaryOperator::Add, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
269 Ok(EncryptedValue::U64(l.add(&r)))
270 }
271 (BinaryOperator::Sub, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
272 Ok(EncryptedValue::U64(l.sub(&r)))
273 }
274 (BinaryOperator::Mul, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
275 Ok(EncryptedValue::U64(l.mul(&r)))
276 }
277
278 _ => Err(AmateRSError::FheComputation(ErrorContext::new(
279 "Type mismatch in binary operation".to_string(),
280 ))),
281 }
282 }
283
284 CircuitNode::UnaryOp { op, operand } => {
285 let operand_val = self.execute_node(operand, inputs, variable_types)?;
286
287 match (op, operand_val) {
288 (UnaryOperator::Not, EncryptedValue::Bool(v)) => {
289 Ok(EncryptedValue::Bool(v.not()))
290 }
291
292 _ => Err(AmateRSError::FheComputation(ErrorContext::new(
293 "Type mismatch in unary operation".to_string(),
294 ))),
295 }
296 }
297
298 CircuitNode::Compare { op, left, right } => {
299 let left_val = self.execute_node(left, inputs, variable_types)?;
300 let right_val = self.execute_node(right, inputs, variable_types)?;
301
302 match (left_val, right_val) {
303 (EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
304 let result = match op {
305 CompareOperator::Eq => l.eq(&r),
306 CompareOperator::Ne => l.ne(&r),
307 CompareOperator::Lt => l.lt(&r),
308 CompareOperator::Le => l.le(&r),
309 CompareOperator::Gt => l.gt(&r),
310 CompareOperator::Ge => l.ge(&r),
311 };
312 Ok(EncryptedValue::Bool(result))
313 }
314
315 (EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
316 let result = match op {
317 CompareOperator::Eq => l.eq(&r),
318 CompareOperator::Ne => l.ne(&r),
319 CompareOperator::Lt => l.lt(&r),
320 CompareOperator::Le => l.le(&r),
321 CompareOperator::Gt => l.gt(&r),
322 CompareOperator::Ge => l.ge(&r),
323 };
324 Ok(EncryptedValue::Bool(result))
325 }
326
327 (EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
328 let result = match op {
329 CompareOperator::Eq => l.eq(&r),
330 CompareOperator::Ne => l.ne(&r),
331 CompareOperator::Lt => l.lt(&r),
332 CompareOperator::Le => l.le(&r),
333 CompareOperator::Gt => l.gt(&r),
334 CompareOperator::Ge => l.ge(&r),
335 };
336 Ok(EncryptedValue::Bool(result))
337 }
338
339 (EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
340 let result = match op {
341 CompareOperator::Eq => l.eq(&r),
342 CompareOperator::Ne => l.ne(&r),
343 CompareOperator::Lt => l.lt(&r),
344 CompareOperator::Le => l.le(&r),
345 CompareOperator::Gt => l.gt(&r),
346 CompareOperator::Ge => l.ge(&r),
347 };
348 Ok(EncryptedValue::Bool(result))
349 }
350
351 _ => Err(AmateRSError::FheComputation(ErrorContext::new(
352 "Type mismatch in comparison".to_string(),
353 ))),
354 }
355 }
356 }
357 }
358}
359
360impl Default for FheExecutor {
361 fn default() -> Self {
362 Self::new()
363 }
364}
365
/// Typed ciphertext produced while evaluating a circuit node.
///
/// Each variant corresponds to one `EncryptedType` tag and wraps the matching
/// ciphertext type. `execute_node` returns this enum and checks operand types by
/// pattern matching on the variants. Only compiled with the `compute` feature.
#[cfg(feature = "compute")]
enum EncryptedValue {
    /// Encrypted boolean.
    Bool(EncryptedBool),
    /// Encrypted 8-bit unsigned integer.
    U8(EncryptedU8),
    /// Encrypted 16-bit unsigned integer.
    U16(EncryptedU16),
    /// Encrypted 32-bit unsigned integer.
    U32(EncryptedU32),
    /// Encrypted 64-bit unsigned integer.
    U64(EncryptedU64),
}
375
/// Legacy gate-level operation identifiers.
///
/// Deprecated since 0.1.0; build circuits from [`CircuitNode`] instead. The
/// executor in this module does not reference it. NOTE(review): presumably kept
/// only for backward compatibility; confirm before removing.
#[deprecated(since = "0.1.0", note = "Use CircuitNode instead")]
#[derive(Debug, Clone)]
pub enum Gate {
    Add,
    Mul,
    Not,
    Bootstrap,
}
387
/// End-to-end tests for `FheExecutor`. These tests generate real FHE keys, so
/// they are slow and only built with the `compute` feature.
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    #[test]
    fn test_fhe_executor_basic() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let sum_node = builder.add(a_node, b_node);

        let circuit = builder.build(sum_node)?;

        let a = EncryptedU8::encrypt(5, keypair.client_key());
        let b = EncryptedU8::encrypt(3, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedU8::from_cipher_blob(&result_blob)?;
        assert_eq!(result.decrypt(keypair.client_key()), 8);

        Ok(())
    }

    /// Exercises the branch of `execute` that skips the optimizer entirely.
    #[test]
    fn test_fhe_executor_without_optimization() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let sum_node = builder.add(a_node, b_node);

        let circuit = builder.build(sum_node)?;

        let a = EncryptedU8::encrypt(20, keypair.client_key());
        let b = EncryptedU8::encrypt(22, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::with_optimization(false);
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedU8::from_cipher_blob(&result_blob)?;
        assert_eq!(result.decrypt(keypair.client_key()), 42);

        Ok(())
    }

    #[test]
    fn test_fhe_executor_boolean() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x_node = builder.load("x");
        let y_node = builder.load("y");
        let and_node = builder.and(x_node, y_node);

        let circuit = builder.build(and_node)?;

        let x = EncryptedBool::encrypt(true, keypair.client_key());
        let y = EncryptedBool::encrypt(false, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), x.to_cipher_blob()?);
        inputs.insert("y".to_string(), y.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(!result.decrypt(keypair.client_key()));

        Ok(())
    }

    #[test]
    fn test_fhe_executor_comparison() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let gt_node = builder.gt(a_node, b_node);

        let circuit = builder.build(gt_node)?;

        let a = EncryptedU8::encrypt(10, keypair.client_key());
        let b = EncryptedU8::encrypt(5, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(result.decrypt(keypair.client_key()));

        Ok(())
    }

    /// A declared variable with no supplied ciphertext must be rejected.
    ///
    /// `execute` detects this before any homomorphic operation runs, so no key
    /// generation is needed here.
    #[test]
    fn test_missing_input_error() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("a", EncryptedType::U8);

        let a_node = builder.load("a");
        let circuit = builder.build(a_node)?;

        let inputs: HashMap<String, CipherBlob> = HashMap::new();
        let executor = FheExecutor::new();
        let result = executor.execute(&circuit, &inputs);

        // Pin the specific error kind, not just "some error".
        assert!(matches!(result, Err(AmateRSError::FheComputation(_))));

        Ok(())
    }
}