use crate::compute::{
Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
};
use crate::error::{AmateRSError, ErrorContext, Result};
use crate::types::{CipherBlob, ColumnRef, Predicate};
/// Lowers a [`Predicate`] tree into a boolean-valued [`Circuit`].
///
/// The compiler owns a single `CircuitBuilder`; operand variables are declared on
/// it and every circuit node produced during compilation is built through it.
pub struct PredicateCompiler {
    // Builder used both to declare the operand variables and to construct nodes.
    builder: CircuitBuilder,
}
impl PredicateCompiler {
/// Creates a compiler backed by a freshly constructed `CircuitBuilder`
/// with no variables declared yet.
pub fn new() -> Self {
    let builder = CircuitBuilder::new();
    Self { builder }
}
/// Compiles `predicate` into a [`Circuit`] whose operands have type `value_type`.
///
/// Before lowering, two variables are declared on the builder, both typed as
/// `value_type`: `"value"` (the left-hand operand of every comparison) and
/// `"rhs"` (the right-hand operand of every comparison).
///
/// # Errors
/// Returns any error raised while lowering the predicate tree or while
/// building the final circuit.
pub fn compile(&mut self, predicate: &Predicate, value_type: EncryptedType) -> Result<Circuit> {
    for operand in ["value", "rhs"] {
        self.builder.declare_variable(operand, value_type);
    }
    let root_node = self.compile_node(predicate)?;
    self.builder.build(root_node)
}
fn compile_node(&self, predicate: &Predicate) -> Result<CircuitNode> {
match predicate {
Predicate::Eq(col, _value) => {
self.validate_column(col)?;
let value_node = self.builder.load("value");
let rhs_node = self.builder.load("rhs");
Ok(self.builder.eq(value_node, rhs_node))
}
Predicate::Gt(col, _value) => {
self.validate_column(col)?;
let value_node = self.builder.load("value");
let rhs_node = self.builder.load("rhs");
Ok(self.builder.gt(value_node, rhs_node))
}
Predicate::Lt(col, _value) => {
self.validate_column(col)?;
let value_node = self.builder.load("value");
let rhs_node = self.builder.load("rhs");
Ok(self.builder.lt(value_node, rhs_node))
}
Predicate::Gte(col, _value) => {
self.validate_column(col)?;
let value_node = self.builder.load("value");
let rhs_node = self.builder.load("rhs");
let lt_node = self.builder.lt(value_node, rhs_node);
Ok(self.builder.not(lt_node))
}
Predicate::Lte(col, _value) => {
self.validate_column(col)?;
let value_node = self.builder.load("value");
let rhs_node = self.builder.load("rhs");
let gt_node = self.builder.gt(value_node, rhs_node);
Ok(self.builder.not(gt_node))
}
Predicate::And(left, right) => {
let left_circuit = self.compile_node(left)?;
let right_circuit = self.compile_node(right)?;
Ok(self.builder.and(left_circuit, right_circuit))
}
Predicate::Or(left, right) => {
let left_circuit = self.compile_node(left)?;
let right_circuit = self.compile_node(right)?;
Ok(self.builder.or(left_circuit, right_circuit))
}
Predicate::Not(pred) => {
let pred_circuit = self.compile_node(pred)?;
Ok(self.builder.not(pred_circuit))
}
}
}
/// Checks whether `col` may be referenced by a predicate.
///
/// Currently a placeholder: the column is not inspected and every column is
/// accepted with `Ok(())`.
fn validate_column(&self, _col: &ColumnRef) -> Result<()> {
    Ok(())
}
/// Returns a clone of the right-hand literal of the left-most comparison in
/// `predicate`.
///
/// `And` and `Or` nodes are descended through their left child only, and `Not`
/// through its operand. For compound predicates, literals further to the right
/// are therefore not returned; `extract_all_rhs_values` collects every literal.
///
/// # Errors
/// The current implementation always returns `Ok`.
pub fn extract_rhs_value(predicate: &Predicate) -> Result<CipherBlob> {
    let mut node = predicate;
    loop {
        node = match node {
            Predicate::Eq(_, literal)
            | Predicate::Gt(_, literal)
            | Predicate::Lt(_, literal)
            | Predicate::Gte(_, literal)
            | Predicate::Lte(_, literal) => return Ok(literal.clone()),
            // Only the left branch of a binary connective is followed.
            Predicate::And(first, _) | Predicate::Or(first, _) => &**first,
            Predicate::Not(inner) => &**inner,
        };
    }
}
/// Collects a clone of the right-hand literal of every comparison in
/// `predicate`.
///
/// Literals are returned in left-to-right leaf order: the left subtree of an
/// `And`/`Or` contributes before its right subtree.
pub fn extract_all_rhs_values(predicate: &Predicate) -> Vec<CipherBlob> {
    let mut literals = Vec::new();
    // Explicit stack of nodes still to visit. Children are pushed right-first
    // so the left subtree is popped, and collected, first.
    let mut stack: Vec<&Predicate> = vec![predicate];
    while let Some(node) = stack.pop() {
        match node {
            Predicate::Eq(_, literal)
            | Predicate::Gt(_, literal)
            | Predicate::Lt(_, literal)
            | Predicate::Gte(_, literal)
            | Predicate::Lte(_, literal) => literals.push(literal.clone()),
            Predicate::And(left, right) | Predicate::Or(left, right) => {
                stack.push(&**right);
                stack.push(&**left);
            }
            Predicate::Not(inner) => stack.push(&**inner),
        }
    }
    literals
}
/// Attempts to infer the encrypted operand type from `predicate`.
///
/// Not implemented yet: the predicate is never inspected and no type is ever
/// inferred, so callers have to supply the type themselves (for example via
/// the `value_type` argument of `compile`).
pub fn infer_value_type(_predicate: &Predicate) -> Option<EncryptedType> {
    Option::None
}
}
impl Default for PredicateCompiler {
fn default() -> Self {
Self::new()
}
}
/// Compiles `predicate` with a throw-away [`PredicateCompiler`].
///
/// Convenience wrapper for callers that do not need to keep the compiler
/// around after a single compilation.
///
/// # Errors
/// Returns any error produced by [`PredicateCompiler::compile`].
pub fn compile_predicate(predicate: &Predicate, value_type: EncryptedType) -> Result<Circuit> {
    PredicateCompiler::new().compile(predicate, value_type)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::col;

    /// Wraps a single byte in a `CipherBlob`, standing in for an encrypted literal.
    fn make_test_blob(value: u8) -> CipherBlob {
        CipherBlob::new(vec![value])
    }

    /// Builds a comparison on column `age`, e.g. `age(Predicate::Gt, 18)`.
    fn age(op: fn(ColumnRef, CipherBlob) -> Predicate, value: u8) -> Predicate {
        op(col("age"), make_test_blob(value))
    }

    /// Compiles `predicate` with a fresh compiler over `U8` operands.
    fn compile_u8(predicate: &Predicate) -> Result<Circuit> {
        let mut compiler = PredicateCompiler::new();
        compiler.compile(predicate, EncryptedType::U8)
    }

    #[test]
    fn test_compiler_creation() {
        let fresh = PredicateCompiler::new();
        assert_eq!(fresh.builder.variable_types().len(), 0);
    }

    #[test]
    fn test_compile_eq_predicate() -> Result<()> {
        let circuit = compile_u8(&age(Predicate::Eq, 18))?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert_eq!(circuit.variable_types.len(), 2);
        for name in ["value", "rhs"] {
            assert!(circuit.variable_types.contains_key(name));
        }
        Ok(())
    }

    #[test]
    fn test_compile_gt_predicate() -> Result<()> {
        let circuit = compile_u8(&age(Predicate::Gt, 18))?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(circuit.gate_count > 0);
        Ok(())
    }

    #[test]
    fn test_compile_lt_predicate() -> Result<()> {
        let circuit = compile_u8(&age(Predicate::Lt, 65))?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_compile_gte_predicate() -> Result<()> {
        // `>=` is lowered to NOT(<), so the root is a unary gate.
        let circuit = compile_u8(&age(Predicate::Gte, 18))?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_lte_predicate() -> Result<()> {
        // `<=` is lowered to NOT(>), so the root is a unary gate.
        let circuit = compile_u8(&age(Predicate::Lte, 65))?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_and_predicate() -> Result<()> {
        let predicate = Predicate::And(
            Box::new(age(Predicate::Gt, 18)),
            Box::new(age(Predicate::Lt, 65)),
        );
        let circuit = compile_u8(&predicate)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::BinaryOp { .. }));
        assert!(circuit.gate_count >= 2);
        Ok(())
    }

    #[test]
    fn test_compile_or_predicate() -> Result<()> {
        let predicate = Predicate::Or(
            Box::new(age(Predicate::Lt, 18)),
            Box::new(age(Predicate::Gt, 65)),
        );
        let circuit = compile_u8(&predicate)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::BinaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_not_predicate() -> Result<()> {
        let predicate = Predicate::Not(Box::new(age(Predicate::Eq, 18)));
        let circuit = compile_u8(&predicate)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(matches!(circuit.root, CircuitNode::UnaryOp { .. }));
        Ok(())
    }

    #[test]
    fn test_compile_complex_predicate() -> Result<()> {
        // (age > 18 AND age < 65) OR age == 100
        let in_range = Predicate::And(
            Box::new(age(Predicate::Gt, 18)),
            Box::new(age(Predicate::Lt, 65)),
        );
        let predicate = Predicate::Or(Box::new(in_range), Box::new(age(Predicate::Eq, 100)));
        let circuit = compile_u8(&predicate)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert!(circuit.gate_count >= 3);
        assert!(circuit.depth >= 2);
        Ok(())
    }

    #[test]
    fn test_extract_rhs_value() -> Result<()> {
        let blob = make_test_blob(42);
        let predicate = Predicate::Gt(col("age"), blob.clone());
        assert_eq!(PredicateCompiler::extract_rhs_value(&predicate)?, blob);
        Ok(())
    }

    #[test]
    fn test_extract_rhs_from_and() -> Result<()> {
        let (low, high) = (make_test_blob(18), make_test_blob(65));
        let predicate = Predicate::And(
            Box::new(Predicate::Gt(col("age"), low.clone())),
            Box::new(Predicate::Lt(col("age"), high)),
        );
        // Only the left-most literal is reported.
        assert_eq!(PredicateCompiler::extract_rhs_value(&predicate)?, low);
        Ok(())
    }

    #[test]
    fn test_extract_all_rhs_values() {
        let (low, high) = (make_test_blob(18), make_test_blob(65));
        let predicate = Predicate::And(
            Box::new(Predicate::Gt(col("age"), low.clone())),
            Box::new(Predicate::Lt(col("age"), high.clone())),
        );
        let values = PredicateCompiler::extract_all_rhs_values(&predicate);
        assert_eq!(values, vec![low, high]);
    }

    #[test]
    fn test_compile_predicate_helper() -> Result<()> {
        let circuit = compile_predicate(&age(Predicate::Eq, 18), EncryptedType::U8)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_circuit_validation() -> Result<()> {
        let circuit = compile_u8(&age(Predicate::Gt, 18))?;
        circuit.validate()?;
        Ok(())
    }
}