1pub mod circuit;
39pub mod gpu;
40pub mod key_manager;
41pub mod keys;
42pub mod operations;
43pub mod optimizer;
44pub mod plan_cache;
45pub mod planner;
46pub mod predicate;
47
48#[cfg(test)]
49mod filter_tests;
50
51pub use circuit::{
53 BinaryOperator, Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator,
54 ConstantType, EncryptedType, UnaryOperator, count_encrypted_constants,
55 count_plaintext_constants, decrypt_constant, encrypt_circuit_constants, encrypt_constant,
56 is_encrypted_constant,
57};
58pub use key_manager::{ClientId, KeyManager};
59pub use keys::{FheKeyPair, InMemoryKeyStorage, KeyStorage};
60pub use operations::{EncryptedBool, EncryptedU8, EncryptedU16, EncryptedU32, EncryptedU64};
61pub use optimizer::{CircuitOptimizer, DependencyGraph, NodeId, OptimizationStats};
62pub use planner::{LogicalPlan, PhysicalPlan, PlanCost, PlannerStats, QueryPlanner};
63pub use predicate::{PredicateCompiler, compile_predicate};
64
65use crate::error::{AmateRSError, ErrorContext, Result};
66use crate::types::CipherBlob;
67use std::collections::HashMap;
68
69#[derive(Debug, Clone)]
74pub struct FheExecutor {
75 optimizer: CircuitOptimizer,
76 optimization_enabled: bool,
77}
78
79impl FheExecutor {
80 pub fn new() -> Self {
82 Self {
83 optimizer: CircuitOptimizer::new(),
84 optimization_enabled: true,
85 }
86 }
87
88 pub fn with_optimization(enable: bool) -> Self {
90 Self {
91 optimizer: if enable {
92 CircuitOptimizer::new()
93 } else {
94 CircuitOptimizer::disabled()
95 },
96 optimization_enabled: enable,
97 }
98 }
99
100 pub fn optimization_stats(&self) -> &OptimizationStats {
102 self.optimizer.stats()
103 }
104
105 pub fn dependency_graph(&self) -> &DependencyGraph {
107 self.optimizer.dependency_graph()
108 }
109
110 #[cfg(feature = "compute")]
121 pub fn execute(
122 &self,
123 circuit: &Circuit,
124 inputs: &HashMap<String, CipherBlob>,
125 ) -> Result<CipherBlob> {
126 circuit.validate()?;
128
129 for var_name in circuit.variable_types.keys() {
131 if !inputs.contains_key(var_name) {
132 return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
133 "Missing input for variable: {}",
134 var_name
135 ))));
136 }
137 }
138
139 let optimized = if self.optimization_enabled {
141 let mut optimizer = self.optimizer.clone();
143 optimizer.optimize(circuit.clone())?
144 } else {
145 circuit.clone()
146 };
147
148 let result_value = self.execute_node(&optimized.root, inputs, &optimized.variable_types)?;
150
151 match result_value {
153 EncryptedValue::Bool(v) => v.to_cipher_blob(),
154 EncryptedValue::U8(v) => v.to_cipher_blob(),
155 EncryptedValue::U16(v) => v.to_cipher_blob(),
156 EncryptedValue::U32(v) => v.to_cipher_blob(),
157 EncryptedValue::U64(v) => v.to_cipher_blob(),
158 }
159 }
160
161 #[cfg(not(feature = "compute"))]
163 pub fn execute(
164 &self,
165 _circuit: &Circuit,
166 _inputs: &HashMap<String, CipherBlob>,
167 ) -> Result<CipherBlob> {
168 Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
169 "FHE compute feature is not enabled".to_string(),
170 )))
171 }
172
173 #[cfg(feature = "compute")]
175 #[allow(clippy::only_used_in_recursion)]
176 fn execute_node(
177 &self,
178 node: &CircuitNode,
179 inputs: &HashMap<String, CipherBlob>,
180 variable_types: &HashMap<String, EncryptedType>,
181 ) -> Result<EncryptedValue> {
182 match node {
183 CircuitNode::Load(name) => {
184 let blob = inputs.get(name).ok_or_else(|| {
185 AmateRSError::FheComputation(ErrorContext::new(format!(
186 "Missing input: {}",
187 name
188 )))
189 })?;
190
191 let var_type = variable_types.get(name).ok_or_else(|| {
192 AmateRSError::FheComputation(ErrorContext::new(format!(
193 "Unknown variable type: {}",
194 name
195 )))
196 })?;
197
198 match var_type {
199 EncryptedType::Bool => {
200 Ok(EncryptedValue::Bool(EncryptedBool::from_cipher_blob(blob)?))
201 }
202 EncryptedType::U8 => {
203 Ok(EncryptedValue::U8(EncryptedU8::from_cipher_blob(blob)?))
204 }
205 EncryptedType::U16 => {
206 Ok(EncryptedValue::U16(EncryptedU16::from_cipher_blob(blob)?))
207 }
208 EncryptedType::U32 => {
209 Ok(EncryptedValue::U32(EncryptedU32::from_cipher_blob(blob)?))
210 }
211 EncryptedType::U64 => {
212 Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(blob)?))
213 }
214 }
215 }
216
217 CircuitNode::Constant(_value) => {
218 Err(AmateRSError::FheComputation(ErrorContext::new(
222 "Plaintext constants cannot be used in FHE execution. \
223 Use encrypt_circuit_constants() to encrypt constants before evaluation."
224 .to_string(),
225 )))
226 }
227
228 CircuitNode::EncryptedConstant {
229 data,
230 original_type,
231 } => {
232 let blob = CipherBlob::new(data.clone());
236 match original_type {
237 ConstantType::Boolean => Ok(EncryptedValue::Bool(
238 EncryptedBool::from_cipher_blob(&blob)?,
239 )),
240 ConstantType::Integer => {
241 Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(&blob)?))
245 }
246 ConstantType::Float | ConstantType::Bytes => {
247 Err(AmateRSError::FheComputation(ErrorContext::new(format!(
248 "EncryptedConstant of type {} is not directly evaluable in FHE circuits",
249 original_type
250 ))))
251 }
252 }
253 }
254
255 CircuitNode::BinaryOp { op, left, right } => {
256 let left_val = self.execute_node(left, inputs, variable_types)?;
257 let right_val = self.execute_node(right, inputs, variable_types)?;
258
259 match (op, left_val, right_val) {
260 (BinaryOperator::And, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
262 Ok(EncryptedValue::Bool(l.and(&r)))
263 }
264 (BinaryOperator::Or, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
265 Ok(EncryptedValue::Bool(l.or(&r)))
266 }
267 (BinaryOperator::Xor, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
268 Ok(EncryptedValue::Bool(l.xor(&r)))
269 }
270
271 (BinaryOperator::Add, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
273 Ok(EncryptedValue::U8(l.add(&r)))
274 }
275 (BinaryOperator::Sub, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
276 Ok(EncryptedValue::U8(l.sub(&r)))
277 }
278 (BinaryOperator::Mul, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
279 Ok(EncryptedValue::U8(l.mul(&r)))
280 }
281
282 (BinaryOperator::Add, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
284 Ok(EncryptedValue::U16(l.add(&r)))
285 }
286 (BinaryOperator::Sub, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
287 Ok(EncryptedValue::U16(l.sub(&r)))
288 }
289 (BinaryOperator::Mul, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
290 Ok(EncryptedValue::U16(l.mul(&r)))
291 }
292
293 (BinaryOperator::Add, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
295 Ok(EncryptedValue::U32(l.add(&r)))
296 }
297 (BinaryOperator::Sub, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
298 Ok(EncryptedValue::U32(l.sub(&r)))
299 }
300 (BinaryOperator::Mul, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
301 Ok(EncryptedValue::U32(l.mul(&r)))
302 }
303
304 (BinaryOperator::Add, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
306 Ok(EncryptedValue::U64(l.add(&r)))
307 }
308 (BinaryOperator::Sub, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
309 Ok(EncryptedValue::U64(l.sub(&r)))
310 }
311 (BinaryOperator::Mul, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
312 Ok(EncryptedValue::U64(l.mul(&r)))
313 }
314
315 _ => Err(AmateRSError::FheComputation(ErrorContext::new(
316 "Type mismatch in binary operation".to_string(),
317 ))),
318 }
319 }
320
321 CircuitNode::UnaryOp { op, operand } => {
322 let operand_val = self.execute_node(operand, inputs, variable_types)?;
323
324 match (op, operand_val) {
325 (UnaryOperator::Not, EncryptedValue::Bool(v)) => {
326 Ok(EncryptedValue::Bool(v.not()))
327 }
328
329 _ => Err(AmateRSError::FheComputation(ErrorContext::new(
330 "Type mismatch in unary operation".to_string(),
331 ))),
332 }
333 }
334
335 CircuitNode::Compare { op, left, right } => {
336 let left_val = self.execute_node(left, inputs, variable_types)?;
337 let right_val = self.execute_node(right, inputs, variable_types)?;
338
339 match (left_val, right_val) {
340 (EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
341 let result = match op {
342 CompareOperator::Eq => l.eq(&r),
343 CompareOperator::Ne => l.ne(&r),
344 CompareOperator::Lt => l.lt(&r),
345 CompareOperator::Le => l.le(&r),
346 CompareOperator::Gt => l.gt(&r),
347 CompareOperator::Ge => l.ge(&r),
348 };
349 Ok(EncryptedValue::Bool(result))
350 }
351
352 (EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
353 let result = match op {
354 CompareOperator::Eq => l.eq(&r),
355 CompareOperator::Ne => l.ne(&r),
356 CompareOperator::Lt => l.lt(&r),
357 CompareOperator::Le => l.le(&r),
358 CompareOperator::Gt => l.gt(&r),
359 CompareOperator::Ge => l.ge(&r),
360 };
361 Ok(EncryptedValue::Bool(result))
362 }
363
364 (EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
365 let result = match op {
366 CompareOperator::Eq => l.eq(&r),
367 CompareOperator::Ne => l.ne(&r),
368 CompareOperator::Lt => l.lt(&r),
369 CompareOperator::Le => l.le(&r),
370 CompareOperator::Gt => l.gt(&r),
371 CompareOperator::Ge => l.ge(&r),
372 };
373 Ok(EncryptedValue::Bool(result))
374 }
375
376 (EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
377 let result = match op {
378 CompareOperator::Eq => l.eq(&r),
379 CompareOperator::Ne => l.ne(&r),
380 CompareOperator::Lt => l.lt(&r),
381 CompareOperator::Le => l.le(&r),
382 CompareOperator::Gt => l.gt(&r),
383 CompareOperator::Ge => l.ge(&r),
384 };
385 Ok(EncryptedValue::Bool(result))
386 }
387
388 _ => Err(AmateRSError::FheComputation(ErrorContext::new(
389 "Type mismatch in comparison".to_string(),
390 ))),
391 }
392 }
393 }
394 }
395}
396
397impl Default for FheExecutor {
398 fn default() -> Self {
399 Self::new()
400 }
401}
402
/// Intermediate result produced while evaluating a circuit node.
///
/// Each variant wraps one of the encrypted operand types, letting the
/// evaluator dispatch on the runtime type of a sub-expression.
#[cfg(feature = "compute")]
enum EncryptedValue {
    /// An encrypted boolean.
    Bool(EncryptedBool),
    /// An encrypted 8-bit unsigned integer.
    U8(EncryptedU8),
    /// An encrypted 16-bit unsigned integer.
    U16(EncryptedU16),
    /// An encrypted 32-bit unsigned integer.
    U32(EncryptedU32),
    /// An encrypted 64-bit unsigned integer.
    U64(EncryptedU64),
}
412
/// Legacy gate-level operation descriptor.
///
/// Superseded by `CircuitNode`; retained only so existing callers keep compiling.
#[deprecated(since = "0.1.0", note = "Use CircuitNode instead")]
#[derive(Clone, Debug)]
pub enum Gate {
    /// Addition gate.
    Add,
    /// Multiplication gate.
    Mul,
    /// Negation gate.
    Not,
    /// Bootstrapping step (presumably the FHE noise-refresh operation — confirm).
    Bootstrap,
}
424
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    /// Builds the circuit `a + b` over two declared `U8` variables.
    fn u8_sum_circuit() -> Result<Circuit> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let sum_node = builder.add(a_node, b_node);

        let circuit = builder.build(sum_node)?;
        Ok(circuit)
    }

    #[test]
    fn test_fhe_executor_basic() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        // Install the server key before evaluating the circuit.
        keypair.set_as_global_server_key();

        let circuit = u8_sum_circuit()?;

        let a = EncryptedU8::encrypt(5, keypair.client_key());
        let b = EncryptedU8::encrypt(3, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedU8::from_cipher_blob(&result_blob)?;
        assert_eq!(result.decrypt(keypair.client_key()), 8);

        Ok(())
    }

    #[test]
    fn test_fhe_executor_without_optimization() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let circuit = u8_sum_circuit()?;

        let a = EncryptedU8::encrypt(5, keypair.client_key());
        let b = EncryptedU8::encrypt(3, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        // Exercises the path that evaluates the circuit without an optimization pass.
        let executor = FheExecutor::with_optimization(false);
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedU8::from_cipher_blob(&result_blob)?;
        assert_eq!(result.decrypt(keypair.client_key()), 8);

        Ok(())
    }

    #[test]
    fn test_fhe_executor_boolean() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x_node = builder.load("x");
        let y_node = builder.load("y");
        let and_node = builder.and(x_node, y_node);

        let circuit = builder.build(and_node)?;

        let x = EncryptedBool::encrypt(true, keypair.client_key());
        let y = EncryptedBool::encrypt(false, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), x.to_cipher_blob()?);
        inputs.insert("y".to_string(), y.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(!result.decrypt(keypair.client_key()));

        Ok(())
    }

    #[test]
    fn test_fhe_executor_comparison() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let gt_node = builder.gt(a_node, b_node);

        let circuit = builder.build(gt_node)?;

        let a = EncryptedU8::encrypt(10, keypair.client_key());
        let b = EncryptedU8::encrypt(5, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(result.decrypt(keypair.client_key()));

        Ok(())
    }

    #[test]
    fn test_missing_input_error() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder.declare_variable("a", EncryptedType::U8);

        let a_node = builder.load("a");
        let circuit = builder.build(a_node)?;

        let inputs = HashMap::new();
        let executor = FheExecutor::new();
        let result = executor.execute(&circuit, &inputs);

        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_partial_inputs_error() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        // `b` is declared by the circuit but deliberately not supplied.
        let circuit = u8_sum_circuit()?;
        let a = EncryptedU8::encrypt(5, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result = executor.execute(&circuit, &inputs);

        assert!(matches!(result, Err(AmateRSError::FheComputation(_))));

        Ok(())
    }
}