amaters_core/compute/
predicate.rs

1//! Predicate-to-FHE-Circuit Compiler
2//!
3//! This module provides compilation of AmateRS query predicates into FHE circuits
4//! that can be executed on encrypted data without revealing plaintext values.
5
6use crate::compute::{
7    Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
8};
9use crate::error::{AmateRSError, ErrorContext, Result};
10use crate::types::{CipherBlob, ColumnRef, Predicate};
11
12/// Compiles query predicates into executable FHE circuits
13///
14/// The PredicateCompiler transforms high-level query predicates (like `age > 18`)
15/// into FHE circuits that can evaluate these conditions on encrypted data.
16/// The result is always an encrypted boolean indicating whether the predicate
17/// matches or not.
18///
19/// # Example
20///
21/// ```rust,ignore
22/// use amaters_core::compute::{PredicateCompiler, EncryptedType};
23/// use amaters_core::types::{Predicate, col, CipherBlob};
24///
25/// let mut compiler = PredicateCompiler::new();
26///
27/// // Compile: age > 18
28/// let predicate = Predicate::Gt(col("age"), encrypted_18);
29/// let circuit = compiler.compile(&predicate, EncryptedType::U8)?;
30///
31/// // The circuit can now be executed on encrypted age values
32/// ```
33pub struct PredicateCompiler {
34    builder: CircuitBuilder,
35}
36
37impl PredicateCompiler {
38    /// Create a new predicate compiler
39    pub fn new() -> Self {
40        Self {
41            builder: CircuitBuilder::new(),
42        }
43    }
44
45    /// Compile a predicate into an FHE circuit
46    ///
47    /// The resulting circuit will have inputs for:
48    /// - `value`: The encrypted column value to test
49    /// - `rhs`: The encrypted comparison value (right-hand side)
50    ///
51    /// The circuit output is an encrypted boolean indicating the predicate result.
52    ///
53    /// # Arguments
54    ///
55    /// * `predicate` - The predicate to compile
56    /// * `value_type` - The encrypted type of the values being compared
57    ///
58    /// # Returns
59    ///
60    /// A `Circuit` that evaluates the predicate on encrypted data
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if:
65    /// - The predicate references undefined columns
66    /// - Type inference fails
67    /// - The circuit construction is invalid
68    pub fn compile(&mut self, predicate: &Predicate, value_type: EncryptedType) -> Result<Circuit> {
69        // Declare variables for the circuit
70        self.builder.declare_variable("value", value_type);
71        self.builder.declare_variable("rhs", value_type);
72
73        // Compile the predicate into a circuit node
74        let root = self.compile_node(predicate)?;
75
76        // Build and return the circuit
77        self.builder.build(root)
78    }
79
80    /// Recursively compile a predicate node into a circuit node
81    fn compile_node(&self, predicate: &Predicate) -> Result<CircuitNode> {
82        match predicate {
83            Predicate::Eq(col, _value) => {
84                // Equality: value == rhs
85                self.validate_column(col)?;
86                let value_node = self.builder.load("value");
87                let rhs_node = self.builder.load("rhs");
88                Ok(self.builder.eq(value_node, rhs_node))
89            }
90
91            Predicate::Gt(col, _value) => {
92                // Greater than: value > rhs
93                self.validate_column(col)?;
94                let value_node = self.builder.load("value");
95                let rhs_node = self.builder.load("rhs");
96                Ok(self.builder.gt(value_node, rhs_node))
97            }
98
99            Predicate::Lt(col, _value) => {
100                // Less than: value < rhs
101                self.validate_column(col)?;
102                let value_node = self.builder.load("value");
103                let rhs_node = self.builder.load("rhs");
104                Ok(self.builder.lt(value_node, rhs_node))
105            }
106
107            Predicate::Gte(col, _value) => {
108                // Greater than or equal: value >= rhs
109                self.validate_column(col)?;
110                let value_node = self.builder.load("value");
111                let rhs_node = self.builder.load("rhs");
112                // Implement as NOT (value < rhs)
113                let lt_node = self.builder.lt(value_node, rhs_node);
114                Ok(self.builder.not(lt_node))
115            }
116
117            Predicate::Lte(col, _value) => {
118                // Less than or equal: value <= rhs
119                self.validate_column(col)?;
120                let value_node = self.builder.load("value");
121                let rhs_node = self.builder.load("rhs");
122                // Implement as NOT (value > rhs)
123                let gt_node = self.builder.gt(value_node, rhs_node);
124                Ok(self.builder.not(gt_node))
125            }
126
127            Predicate::And(left, right) => {
128                // Logical AND: left AND right
129                // Note: This requires both predicates to reference the same value
130                // For now, we'll compile recursively but this may need refinement
131                // for multi-column predicates
132                let left_circuit = self.compile_node(left)?;
133                let right_circuit = self.compile_node(right)?;
134                Ok(self.builder.and(left_circuit, right_circuit))
135            }
136
137            Predicate::Or(left, right) => {
138                // Logical OR: left OR right
139                let left_circuit = self.compile_node(left)?;
140                let right_circuit = self.compile_node(right)?;
141                Ok(self.builder.or(left_circuit, right_circuit))
142            }
143
144            Predicate::Not(pred) => {
145                // Logical NOT: NOT pred
146                let pred_circuit = self.compile_node(pred)?;
147                Ok(self.builder.not(pred_circuit))
148            }
149        }
150    }
151
152    /// Validate that a column reference is supported
153    ///
154    /// For now, we only support single-column predicates with the column named "value"
155    fn validate_column(&self, col: &ColumnRef) -> Result<()> {
156        // In the current design, we're evaluating predicates on individual values
157        // The column reference should match what we're testing
158        // For now, we accept any column name since we're binding it to "value"
159        let _ = col;
160        Ok(())
161    }
162
163    /// Extract the RHS (right-hand side) value from a predicate
164    ///
165    /// This walks the predicate tree to find comparison values.
166    /// For composite predicates (And/Or/Not), it extracts from the first
167    /// comparison it encounters.
168    ///
169    /// # Arguments
170    ///
171    /// * `predicate` - The predicate to extract from
172    ///
173    /// # Returns
174    ///
175    /// The encrypted value used in the predicate comparison
176    ///
177    /// # Errors
178    ///
179    /// Returns an error if the predicate contains no comparison operations
180    pub fn extract_rhs_value(predicate: &Predicate) -> Result<CipherBlob> {
181        match predicate {
182            Predicate::Eq(_, value)
183            | Predicate::Gt(_, value)
184            | Predicate::Lt(_, value)
185            | Predicate::Gte(_, value)
186            | Predicate::Lte(_, value) => Ok(value.clone()),
187
188            Predicate::And(left, _right) => {
189                // For AND, extract from left (could also merge both)
190                Self::extract_rhs_value(left)
191            }
192
193            Predicate::Or(left, _right) => {
194                // For OR, extract from left
195                Self::extract_rhs_value(left)
196            }
197
198            Predicate::Not(pred) => {
199                // For NOT, extract from inner predicate
200                Self::extract_rhs_value(pred)
201            }
202        }
203    }
204
205    /// Extract all RHS values from a predicate
206    ///
207    /// For composite predicates, this returns all comparison values.
208    /// This is useful for complex predicates like `age > 18 AND age < 65`
209    /// which have multiple RHS values.
210    ///
211    /// # Arguments
212    ///
213    /// * `predicate` - The predicate to extract from
214    ///
215    /// # Returns
216    ///
217    /// A vector of all encrypted values used in comparisons
218    pub fn extract_all_rhs_values(predicate: &Predicate) -> Vec<CipherBlob> {
219        match predicate {
220            Predicate::Eq(_, value)
221            | Predicate::Gt(_, value)
222            | Predicate::Lt(_, value)
223            | Predicate::Gte(_, value)
224            | Predicate::Lte(_, value) => vec![value.clone()],
225
226            Predicate::And(left, right) => {
227                let mut values = Self::extract_all_rhs_values(left);
228                values.extend(Self::extract_all_rhs_values(right));
229                values
230            }
231
232            Predicate::Or(left, right) => {
233                let mut values = Self::extract_all_rhs_values(left);
234                values.extend(Self::extract_all_rhs_values(right));
235                values
236            }
237
238            Predicate::Not(pred) => Self::extract_all_rhs_values(pred),
239        }
240    }
241
242    /// Get the required encrypted type for a predicate's values
243    ///
244    /// This analyzes the predicate to determine what type of encrypted values
245    /// it operates on. This is useful for automatic type inference.
246    ///
247    /// # Arguments
248    ///
249    /// * `predicate` - The predicate to analyze
250    ///
251    /// # Returns
252    ///
253    /// The encrypted type hint, or None if it cannot be determined
254    pub fn infer_value_type(_predicate: &Predicate) -> Option<EncryptedType> {
255        // For now, we don't have type information in the predicate itself
256        // This would require extending the Predicate enum with type metadata
257        // or analyzing the CipherBlob metadata
258        None
259    }
260}
261
262impl Default for PredicateCompiler {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268/// Helper function to compile a simple predicate
269///
270/// This is a convenience wrapper around PredicateCompiler for single predicates.
271///
272/// # Example
273///
274/// ```rust,ignore
275/// use amaters_core::compute::{compile_predicate, EncryptedType};
276/// use amaters_core::types::{Predicate, col};
277///
278/// let predicate = Predicate::Gt(col("age"), encrypted_18);
279/// let circuit = compile_predicate(&predicate, EncryptedType::U8)?;
280/// ```
281pub fn compile_predicate(predicate: &Predicate, value_type: EncryptedType) -> Result<Circuit> {
282    let mut compiler = PredicateCompiler::new();
283    compiler.compile(predicate, value_type)
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::col;

    /// Wrap a single plaintext byte as a stand-in ciphertext.
    fn make_test_blob(value: u8) -> CipherBlob {
        CipherBlob::new(vec![value])
    }

    /// Compile `predicate` over `U8` values with a brand-new compiler.
    fn compile_u8(predicate: &Predicate) -> Result<Circuit> {
        let mut fresh = PredicateCompiler::new();
        fresh.compile(predicate, EncryptedType::U8)
    }

    /// `age > n`
    fn age_gt(n: u8) -> Predicate {
        Predicate::Gt(col("age"), make_test_blob(n))
    }

    /// `age < n`
    fn age_lt(n: u8) -> Predicate {
        Predicate::Lt(col("age"), make_test_blob(n))
    }

    /// `age == n`
    fn age_eq(n: u8) -> Predicate {
        Predicate::Eq(col("age"), make_test_blob(n))
    }

    /// `lhs AND rhs`
    fn and_pred(lhs: Predicate, rhs: Predicate) -> Predicate {
        Predicate::And(Box::new(lhs), Box::new(rhs))
    }

    /// `lhs OR rhs`
    fn or_pred(lhs: Predicate, rhs: Predicate) -> Predicate {
        Predicate::Or(Box::new(lhs), Box::new(rhs))
    }

    /// Every compiled predicate must produce an encrypted boolean.
    fn assert_bool_output(circuit: &Circuit) {
        assert_eq!(circuit.result_type, EncryptedType::Bool);
    }

    #[test]
    fn test_compiler_creation() {
        let fresh = PredicateCompiler::new();
        assert_eq!(0, fresh.builder.variable_types().len());
    }

    #[test]
    fn test_compile_eq_predicate() -> Result<()> {
        let circuit = compile_u8(&age_eq(18))?;

        assert_bool_output(&circuit);
        assert_eq!(circuit.variable_types.len(), 2);
        for input in ["value", "rhs"] {
            assert!(circuit.variable_types.contains_key(input));
        }
        Ok(())
    }

    #[test]
    fn test_compile_gt_predicate() -> Result<()> {
        let circuit = compile_u8(&age_gt(18))?;

        assert_bool_output(&circuit);
        assert!(circuit.gate_count > 0);
        Ok(())
    }

    #[test]
    fn test_compile_lt_predicate() -> Result<()> {
        let circuit = compile_u8(&age_lt(65))?;

        assert_bool_output(&circuit);
        Ok(())
    }

    #[test]
    fn test_compile_gte_predicate() -> Result<()> {
        let circuit = compile_u8(&Predicate::Gte(col("age"), make_test_blob(18)))?;

        assert_bool_output(&circuit);
        // Lowered as NOT (value < rhs), so the root is a unary NOT gate.
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_lte_predicate() -> Result<()> {
        let circuit = compile_u8(&Predicate::Lte(col("age"), make_test_blob(65)))?;

        assert_bool_output(&circuit);
        // Lowered as NOT (value > rhs), so the root is a unary NOT gate.
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_and_predicate() -> Result<()> {
        // age > 18 AND age < 65
        let circuit = compile_u8(&and_pred(age_gt(18), age_lt(65)))?;

        assert_bool_output(&circuit);
        assert!(matches!(circuit.root, CircuitNode::BinaryOp { .. }));
        // The AND adds a gate on top of the two comparisons.
        assert!(circuit.gate_count >= 2);
        Ok(())
    }

    #[test]
    fn test_compile_or_predicate() -> Result<()> {
        // age < 18 OR age > 65
        let circuit = compile_u8(&or_pred(age_lt(18), age_gt(65)))?;

        assert_bool_output(&circuit);
        assert!(matches!(circuit.root, CircuitNode::BinaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_not_predicate() -> Result<()> {
        // NOT (age == 18)
        let circuit = compile_u8(&Predicate::Not(Box::new(age_eq(18))))?;

        assert_bool_output(&circuit);
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_complex_predicate() -> Result<()> {
        // (age > 18 AND age < 65) OR age == 100
        let nested = or_pred(and_pred(age_gt(18), age_lt(65)), age_eq(100));
        let circuit = compile_u8(&nested)?;

        assert_bool_output(&circuit);
        // Several comparisons plus nested logic gates.
        assert!(circuit.gate_count >= 3);
        assert!(circuit.depth >= 2);
        Ok(())
    }

    #[test]
    fn test_extract_rhs_value() -> Result<()> {
        let expected = make_test_blob(42);
        let predicate = Predicate::Gt(col("age"), expected.clone());

        assert_eq!(PredicateCompiler::extract_rhs_value(&predicate)?, expected);
        Ok(())
    }

    #[test]
    fn test_extract_rhs_from_and() -> Result<()> {
        let (low, high) = (make_test_blob(18), make_test_blob(65));
        let predicate = and_pred(
            Predicate::Gt(col("age"), low.clone()),
            Predicate::Lt(col("age"), high),
        );

        // The left-hand comparison wins.
        assert_eq!(PredicateCompiler::extract_rhs_value(&predicate)?, low);
        Ok(())
    }

    #[test]
    fn test_extract_all_rhs_values() {
        let (low, high) = (make_test_blob(18), make_test_blob(65));
        let predicate = and_pred(
            Predicate::Gt(col("age"), low.clone()),
            Predicate::Lt(col("age"), high.clone()),
        );

        let found = PredicateCompiler::extract_all_rhs_values(&predicate);
        assert_eq!(found.len(), 2);
        assert_eq!(found[0], low);
        assert_eq!(found[1], high);
    }

    #[test]
    fn test_compile_predicate_helper() -> Result<()> {
        let circuit = compile_predicate(&age_eq(18), EncryptedType::U8)?;

        assert_bool_output(&circuit);
        Ok(())
    }

    #[test]
    fn test_circuit_validation() -> Result<()> {
        let circuit = compile_u8(&age_gt(18))?;

        // A freshly compiled circuit must pass its own validation.
        circuit.validate()?;
        Ok(())
    }
}