1use crate::compute::{
7 Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
8};
9use crate::error::{AmateRSError, ErrorContext, Result};
10use crate::types::{CipherBlob, ColumnRef, Predicate};
11
12pub struct PredicateCompiler {
34 builder: CircuitBuilder,
35}
36
37impl PredicateCompiler {
38 pub fn new() -> Self {
40 Self {
41 builder: CircuitBuilder::new(),
42 }
43 }
44
45 pub fn compile(&mut self, predicate: &Predicate, value_type: EncryptedType) -> Result<Circuit> {
69 self.builder.declare_variable("value", value_type);
71 self.builder.declare_variable("rhs", value_type);
72
73 let root = self.compile_node(predicate)?;
75
76 self.builder.build(root)
78 }
79
80 fn compile_node(&self, predicate: &Predicate) -> Result<CircuitNode> {
82 match predicate {
83 Predicate::Eq(col, _value) => {
84 self.validate_column(col)?;
86 let value_node = self.builder.load("value");
87 let rhs_node = self.builder.load("rhs");
88 Ok(self.builder.eq(value_node, rhs_node))
89 }
90
91 Predicate::Gt(col, _value) => {
92 self.validate_column(col)?;
94 let value_node = self.builder.load("value");
95 let rhs_node = self.builder.load("rhs");
96 Ok(self.builder.gt(value_node, rhs_node))
97 }
98
99 Predicate::Lt(col, _value) => {
100 self.validate_column(col)?;
102 let value_node = self.builder.load("value");
103 let rhs_node = self.builder.load("rhs");
104 Ok(self.builder.lt(value_node, rhs_node))
105 }
106
107 Predicate::Gte(col, _value) => {
108 self.validate_column(col)?;
110 let value_node = self.builder.load("value");
111 let rhs_node = self.builder.load("rhs");
112 let lt_node = self.builder.lt(value_node, rhs_node);
114 Ok(self.builder.not(lt_node))
115 }
116
117 Predicate::Lte(col, _value) => {
118 self.validate_column(col)?;
120 let value_node = self.builder.load("value");
121 let rhs_node = self.builder.load("rhs");
122 let gt_node = self.builder.gt(value_node, rhs_node);
124 Ok(self.builder.not(gt_node))
125 }
126
127 Predicate::And(left, right) => {
128 let left_circuit = self.compile_node(left)?;
133 let right_circuit = self.compile_node(right)?;
134 Ok(self.builder.and(left_circuit, right_circuit))
135 }
136
137 Predicate::Or(left, right) => {
138 let left_circuit = self.compile_node(left)?;
140 let right_circuit = self.compile_node(right)?;
141 Ok(self.builder.or(left_circuit, right_circuit))
142 }
143
144 Predicate::Not(pred) => {
145 let pred_circuit = self.compile_node(pred)?;
147 Ok(self.builder.not(pred_circuit))
148 }
149 }
150 }
151
152 fn validate_column(&self, col: &ColumnRef) -> Result<()> {
156 let _ = col;
160 Ok(())
161 }
162
163 pub fn extract_rhs_value(predicate: &Predicate) -> Result<CipherBlob> {
181 match predicate {
182 Predicate::Eq(_, value)
183 | Predicate::Gt(_, value)
184 | Predicate::Lt(_, value)
185 | Predicate::Gte(_, value)
186 | Predicate::Lte(_, value) => Ok(value.clone()),
187
188 Predicate::And(left, _right) => {
189 Self::extract_rhs_value(left)
191 }
192
193 Predicate::Or(left, _right) => {
194 Self::extract_rhs_value(left)
196 }
197
198 Predicate::Not(pred) => {
199 Self::extract_rhs_value(pred)
201 }
202 }
203 }
204
205 pub fn extract_all_rhs_values(predicate: &Predicate) -> Vec<CipherBlob> {
219 match predicate {
220 Predicate::Eq(_, value)
221 | Predicate::Gt(_, value)
222 | Predicate::Lt(_, value)
223 | Predicate::Gte(_, value)
224 | Predicate::Lte(_, value) => vec![value.clone()],
225
226 Predicate::And(left, right) => {
227 let mut values = Self::extract_all_rhs_values(left);
228 values.extend(Self::extract_all_rhs_values(right));
229 values
230 }
231
232 Predicate::Or(left, right) => {
233 let mut values = Self::extract_all_rhs_values(left);
234 values.extend(Self::extract_all_rhs_values(right));
235 values
236 }
237
238 Predicate::Not(pred) => Self::extract_all_rhs_values(pred),
239 }
240 }
241
242 pub fn infer_value_type(_predicate: &Predicate) -> Option<EncryptedType> {
255 None
259 }
260}
261
262impl Default for PredicateCompiler {
263 fn default() -> Self {
264 Self::new()
265 }
266}
267
268pub fn compile_predicate(predicate: &Predicate, value_type: EncryptedType) -> Result<Circuit> {
282 let mut compiler = PredicateCompiler::new();
283 compiler.compile(predicate, value_type)
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::col;

    /// Builds a one-byte stand-in ciphertext; the compiler never looks inside it.
    fn make_test_blob(value: u8) -> CipherBlob {
        CipherBlob::new(vec![value])
    }

    /// Compiles `predicate` over `U8` values with a fresh compiler.
    fn compile_u8(predicate: &Predicate) -> Result<Circuit> {
        PredicateCompiler::new().compile(predicate, EncryptedType::U8)
    }

    /// `age == limit`
    fn age_eq(limit: u8) -> Predicate {
        Predicate::Eq(col("age"), make_test_blob(limit))
    }

    /// `age > limit`
    fn age_gt(limit: u8) -> Predicate {
        Predicate::Gt(col("age"), make_test_blob(limit))
    }

    /// `age < limit`
    fn age_lt(limit: u8) -> Predicate {
        Predicate::Lt(col("age"), make_test_blob(limit))
    }

    #[test]
    fn test_compiler_creation() {
        let fresh = PredicateCompiler::new();
        assert_eq!(fresh.builder.variable_types().len(), 0);
    }

    #[test]
    fn test_compile_eq_predicate() -> Result<()> {
        let circuit = compile_u8(&age_eq(18))?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert_eq!(circuit.variable_types.len(), 2);
        for name in ["value", "rhs"] {
            assert!(circuit.variable_types.contains_key(name));
        }
        Ok(())
    }

    #[test]
    fn test_compile_gt_predicate() -> Result<()> {
        let circuit = compile_u8(&age_gt(18))?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(circuit.gate_count > 0);
        Ok(())
    }

    #[test]
    fn test_compile_lt_predicate() -> Result<()> {
        let circuit = compile_u8(&age_lt(65))?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_compile_gte_predicate() -> Result<()> {
        let circuit = compile_u8(&Predicate::Gte(col("age"), make_test_blob(18)))?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        // `>=` is lowered to `NOT (<)`, so the root must be a unary node.
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_lte_predicate() -> Result<()> {
        let circuit = compile_u8(&Predicate::Lte(col("age"), make_test_blob(65)))?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        // `<=` is lowered to `NOT (>)`, so the root must be a unary node.
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_and_predicate() -> Result<()> {
        let in_range = Predicate::And(Box::new(age_gt(18)), Box::new(age_lt(65)));
        let circuit = compile_u8(&in_range)?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::BinaryOp { .. }));
        // At least the two comparisons.
        assert!(circuit.gate_count >= 2);
        Ok(())
    }

    #[test]
    fn test_compile_or_predicate() -> Result<()> {
        let out_of_range = Predicate::Or(Box::new(age_lt(18)), Box::new(age_gt(65)));
        let circuit = compile_u8(&out_of_range)?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::BinaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_not_predicate() -> Result<()> {
        let negated = Predicate::Not(Box::new(age_eq(18)));
        let circuit = compile_u8(&negated)?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_complex_predicate() -> Result<()> {
        // (age > 18 AND age < 65) OR age == 100
        let in_range = Predicate::And(Box::new(age_gt(18)), Box::new(age_lt(65)));
        let complex = Predicate::Or(Box::new(in_range), Box::new(age_eq(100)));

        let circuit = compile_u8(&complex)?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        // Three comparisons, nested two logical levels deep.
        assert!(circuit.gate_count >= 3);
        assert!(circuit.depth >= 2);
        Ok(())
    }

    #[test]
    fn test_extract_rhs_value() -> Result<()> {
        let expected = make_test_blob(42);
        let predicate = Predicate::Gt(col("age"), expected.clone());

        assert_eq!(PredicateCompiler::extract_rhs_value(&predicate)?, expected);
        Ok(())
    }

    #[test]
    fn test_extract_rhs_from_and() -> Result<()> {
        let left_blob = make_test_blob(18);
        let right_blob = make_test_blob(65);
        let predicate = Predicate::And(
            Box::new(Predicate::Gt(col("age"), left_blob.clone())),
            Box::new(Predicate::Lt(col("age"), right_blob)),
        );

        // Only the left-most comparison's constant is returned.
        assert_eq!(PredicateCompiler::extract_rhs_value(&predicate)?, left_blob);
        Ok(())
    }

    #[test]
    fn test_extract_all_rhs_values() {
        let left_blob = make_test_blob(18);
        let right_blob = make_test_blob(65);
        let predicate = Predicate::And(
            Box::new(Predicate::Gt(col("age"), left_blob.clone())),
            Box::new(Predicate::Lt(col("age"), right_blob.clone())),
        );

        let values = PredicateCompiler::extract_all_rhs_values(&predicate);
        assert_eq!(values, vec![left_blob, right_blob]);
    }

    #[test]
    fn test_compile_predicate_helper() -> Result<()> {
        let circuit = compile_predicate(&age_eq(18), EncryptedType::U8)?;

        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_circuit_validation() -> Result<()> {
        let circuit = compile_u8(&age_gt(18))?;

        // A freshly compiled circuit must pass its own structural validation.
        circuit.validate()?;
        Ok(())
    }
}