use petgraph::stable_graph::NodeIndex;
use serde::{Deserialize, Serialize};
use sunscreen_backend::compile_inplace;
use sunscreen_compiler_common::{
CompilationResult, Context, EdgeInfo, NodeInfo, Operation as OperationTrait,
};
use sunscreen_fhe_program::{
FheProgram, Literal as FheProgramLiteral, Operation as FheProgramOperation, SchemeType,
};
use sunscreen_runtime::{InnerPlaintext, Params};
use std::cell::RefCell;
/// A constant value that can be embedded in an FHE program graph as a
/// [`FheOperation::Literal`] node.
///
/// Equality on this type is what `FheContextOps::add_literal` uses to decide
/// whether an existing literal node can be reused.
#[derive(Clone, Debug, Deserialize, Hash, Serialize, PartialEq, Eq)]
pub enum Literal {
/// An unsigned 64-bit integer constant.
U64(u64),
/// A constant stored as an [`InnerPlaintext`]. It is serialized with
/// `to_bytes()` when the program is lowered by `FheCompile::compile`.
Plaintext(InnerPlaintext),
}
/// The operations that may appear as nodes in a frontend FHE program graph.
///
/// Each variant is lowered to a backend `FheProgramOperation` by
/// `FheCompile::compile`.
// NOTE: the variant order is observable through the derived serde
// implementations for index-based formats, so do not reorder casually.
#[derive(Clone, Debug, Hash, Deserialize, Serialize, PartialEq, Eq)]
pub enum FheOperation {
/// A ciphertext input to the program.
InputCiphertext,
/// A plaintext input to the program.
InputPlaintext,
/// Binary, commutative addition.
Add,
/// Binary, commutative addition variant with a plaintext operand
/// (presumably ciphertext + plaintext; confirm against the backend).
AddPlaintext,
/// Binary, non-commutative subtraction.
Sub,
/// Binary, non-commutative subtraction variant with a plaintext operand
/// (presumably ciphertext - plaintext; confirm against the backend).
SubPlaintext,
/// Unary negation.
Negate,
/// Binary, commutative multiplication.
Multiply,
/// Binary, commutative multiplication variant with a plaintext operand
/// (presumably ciphertext * plaintext; confirm against the backend).
MultiplyPlaintext,
/// A constant value; see [`Literal`].
Literal(Literal),
/// Binary rotation to the left; lowered to the backend `ShiftLeft` operation.
RotateLeft,
/// Binary rotation to the right; lowered to the backend `ShiftRight` operation.
RotateRight,
/// Unary row swap.
SwapRows,
/// Marks a value as a program output; lowered to the backend
/// `OutputCiphertext` operation.
Output,
}
impl OperationTrait for FheOperation {
    /// Reports whether this operation is one of the two-operand (left/right)
    /// operations.
    fn is_binary(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Add
            | Self::AddPlaintext
            | Self::Sub
            | Self::SubPlaintext
            | Self::Multiply
            | Self::MultiplyPlaintext
            | Self::RotateLeft
            | Self::RotateRight => true,
            Self::InputCiphertext
            | Self::InputPlaintext
            | Self::Negate
            | Self::Literal(_)
            | Self::SwapRows
            | Self::Output => false,
        }
    }

    /// Reports whether this operation is commutative. Only the addition and
    /// multiplication variants are.
    fn is_commutative(&self) -> bool {
        match self {
            Self::Add | Self::AddPlaintext | Self::Multiply | Self::MultiplyPlaintext => true,
            Self::InputCiphertext
            | Self::InputPlaintext
            | Self::Sub
            | Self::SubPlaintext
            | Self::Negate
            | Self::Literal(_)
            | Self::RotateLeft
            | Self::RotateRight
            | Self::SwapRows
            | Self::Output => false,
        }
    }

    /// Reports whether this operation is one of the single-operand operations
    /// (`Negate` or `SwapRows`).
    ///
    /// NOTE(review): `Output` nodes are created through `add_unary_operation`
    /// but are not reported as unary here. That matches the existing behavior
    /// and is preserved; confirm it is intentional.
    fn is_unary(&self) -> bool {
        match self {
            Self::Negate | Self::SwapRows => true,
            Self::InputCiphertext
            | Self::InputPlaintext
            | Self::Add
            | Self::AddPlaintext
            | Self::Sub
            | Self::SubPlaintext
            | Self::Multiply
            | Self::MultiplyPlaintext
            | Self::Literal(_)
            | Self::RotateLeft
            | Self::RotateRight
            | Self::Output => false,
        }
    }

    /// Always `false`: no operation is reported as unordered.
    fn is_unordered(&self) -> bool {
        false
    }

    /// Always `false`: no operation is reported as ordered.
    fn is_ordered(&self) -> bool {
        false
    }
}
/// The compilation context used while building an FHE program: a [`Context`]
/// whose nodes are [`FheOperation`]s, parameterized with [`Params`].
pub type FheContext = Context<FheOperation, Params>;
/// The frontend compilation result: a [`CompilationResult`] over
/// [`FheOperation`] nodes. It is lowered to an [`FheProgram`] by
/// [`FheCompile::compile`].
pub type FheFrontendCompilation = CompilationResult<FheOperation>;
thread_local! {
/// Per-thread slot holding the currently active [`FheContext`], or `None`
/// when no context is active on this thread. [`with_fhe_ctx`] reads it.
///
/// NOTE(review): the reference is typed `&'static mut`. Whatever code
/// installs a context here must guarantee that the context stays valid for
/// as long as it remains in this slot; that code is not visible in this
/// section, so confirm there.
pub static CURRENT_FHE_CTX: RefCell<Option<&'static mut FheContext>> = RefCell::new(None);
}
/// Runs `f` with mutable access to the current thread's active
/// [`FheContext`] and returns whatever `f` returns.
///
/// # Panics
///
/// * If no context is installed in `CURRENT_FHE_CTX` on this thread.
/// * If the slot is already borrowed, e.g. when `f` itself calls
///   `with_fhe_ctx`: the `RefCell` borrow is held for the whole call to `f`.
pub fn with_fhe_ctx<F, R>(f: F) -> R
where
    F: FnOnce(&mut FheContext) -> R,
{
    CURRENT_FHE_CTX.with(|cell| {
        // The guard must stay alive while `f` runs, since `f` borrows through it.
        let mut slot = cell.borrow_mut();

        match slot.as_mut() {
            Some(active) => f(active),
            None => panic!("Called Ciphertext::new() outside of a context."),
        }
    })
}
/// Graph-building operations for an FHE program context.
///
/// Every method adds (or, for `add_literal`, possibly reuses) a node and
/// returns its [`NodeIndex`]. For two-operand methods, `left` and `right` are
/// the indices of the operand nodes.
pub trait FheContextOps {
/// Adds a ciphertext input node.
fn add_ciphertext_input(&mut self) -> NodeIndex;
/// Adds a plaintext input node.
fn add_plaintext_input(&mut self) -> NodeIndex;
/// Adds a literal node holding `plaintext`.
fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex;
/// Adds a subtraction node with operands `left` and `right`.
fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Adds a plaintext-subtraction node with operands `left` and `right`.
fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Adds a negation node whose operand is `x`.
fn add_negate(&mut self, x: NodeIndex) -> NodeIndex;
/// Adds an addition node with operands `left` and `right`.
fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Adds a plaintext-addition node with operands `left` and `right`.
fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Adds a multiplication node with operands `left` and `right`.
fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Adds a plaintext-multiplication node with operands `left` and `right`.
fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Returns a node for `literal`. The `FheContext` implementation in this
/// file returns an existing equal literal node when one is present and only
/// otherwise adds a new node.
fn add_literal(&mut self, literal: Literal) -> NodeIndex;
/// Adds a rotate-left node with operands `left` and `right`
/// (presumably `right` is the rotation amount; confirm with callers).
fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Adds a rotate-right node with operands `left` and `right`
/// (presumably `right` is the rotation amount; confirm with callers).
fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex;
/// Adds a row-swap node whose operand is `x`.
fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex;
/// Adds an output node whose operand is `i`.
fn add_output(&mut self, i: NodeIndex) -> NodeIndex;
}
impl FheContextOps for FheContext {
    // ---- Inputs -----------------------------------------------------------

    /// Adds an `InputCiphertext` node.
    fn add_ciphertext_input(&mut self) -> NodeIndex {
        self.add_node(FheOperation::InputCiphertext)
    }

    /// Adds an `InputPlaintext` node.
    fn add_plaintext_input(&mut self) -> NodeIndex {
        self.add_node(FheOperation::InputPlaintext)
    }

    // ---- Literals ---------------------------------------------------------

    /// Adds a new literal node holding `plaintext`. Unlike `add_literal`, this
    /// always creates a fresh node.
    fn add_plaintext_literal(&mut self, plaintext: InnerPlaintext) -> NodeIndex {
        let op = FheOperation::Literal(Literal::Plaintext(plaintext));
        self.add_node(op)
    }

    /// Returns the node for `literal`, reusing the first existing literal node
    /// with an equal value and adding a new node only when none exists.
    fn add_literal(&mut self, literal: Literal) -> NodeIndex {
        // Linear scan over all nodes for an equal literal.
        let reusable = self.graph.node_indices().find(|&idx| {
            matches!(
                &self.graph[idx].operation,
                FheOperation::Literal(existing) if *existing == literal
            )
        });

        reusable.unwrap_or_else(|| self.add_node(FheOperation::Literal(literal)))
    }

    // ---- Arithmetic -------------------------------------------------------

    /// Adds an `Add` node over `left` and `right`.
    fn add_addition(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::Add, left, right)
    }

    /// Adds an `AddPlaintext` node over `left` and `right`.
    fn add_addition_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::AddPlaintext, left, right)
    }

    /// Adds a `Sub` node over `left` and `right`.
    fn add_subtraction(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::Sub, left, right)
    }

    /// Adds a `SubPlaintext` node over `left` and `right`.
    fn add_subtraction_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::SubPlaintext, left, right)
    }

    /// Adds a `Multiply` node over `left` and `right`.
    fn add_multiplication(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::Multiply, left, right)
    }

    /// Adds a `MultiplyPlaintext` node over `left` and `right`.
    fn add_multiplication_plaintext(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::MultiplyPlaintext, left, right)
    }

    /// Adds a `Negate` node over `x`.
    fn add_negate(&mut self, x: NodeIndex) -> NodeIndex {
        self.add_unary_operation(FheOperation::Negate, x)
    }

    // ---- Rotations --------------------------------------------------------

    /// Adds a `RotateLeft` node over `left` and `right`.
    fn add_rotate_left(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::RotateLeft, left, right)
    }

    /// Adds a `RotateRight` node over `left` and `right`.
    fn add_rotate_right(&mut self, left: NodeIndex, right: NodeIndex) -> NodeIndex {
        self.add_binary_operation(FheOperation::RotateRight, left, right)
    }

    /// Adds a `SwapRows` node over `x`.
    fn add_swap_rows(&mut self, x: NodeIndex) -> NodeIndex {
        self.add_unary_operation(FheOperation::SwapRows, x)
    }

    // ---- Outputs ----------------------------------------------------------

    /// Adds an `Output` node over `i`.
    fn add_output(&mut self, i: NodeIndex) -> NodeIndex {
        self.add_unary_operation(FheOperation::Output, i)
    }
}
/// Lowers a frontend compilation into a backend [`FheProgram`].
pub trait FheCompile {
/// Produces the [`FheProgram`] corresponding to `self`.
fn compile(&self) -> FheProgram;
}
impl FheCompile for FheFrontendCompilation {
    /// Lowers the frontend graph to a BFV [`FheProgram`] and passes it through
    /// `compile_inplace`.
    ///
    /// Each frontend node is translated to its backend operation, and edges
    /// keep their left/right/unary role. Input nodes carry their own graph
    /// node index as the argument of the backend input operation.
    ///
    /// # Panics
    ///
    /// * If a plaintext literal fails to serialize.
    /// * If the graph contains an unordered or ordered edge.
    fn compile(&self) -> FheProgram {
        let mut fhe_program = FheProgram::new(SchemeType::Bfv);

        let lowered_graph = self.0.map(
            |index, node| {
                // Pick the backend operation first, then wrap it once below.
                let backend_op = match &node.operation {
                    FheOperation::InputCiphertext => {
                        FheProgramOperation::InputCiphertext(index.index())
                    }
                    FheOperation::InputPlaintext => {
                        FheProgramOperation::InputPlaintext(index.index())
                    }
                    FheOperation::Literal(literal) => FheProgramOperation::Literal(match literal {
                        Literal::U64(value) => FheProgramLiteral::U64(*value),
                        Literal::Plaintext(plaintext) => FheProgramLiteral::Plaintext(
                            plaintext.to_bytes().expect("Failed to serialize plaintext."),
                        ),
                    }),
                    FheOperation::Add => FheProgramOperation::Add,
                    FheOperation::AddPlaintext => FheProgramOperation::AddPlaintext,
                    FheOperation::Sub => FheProgramOperation::Sub,
                    FheOperation::SubPlaintext => FheProgramOperation::SubPlaintext,
                    FheOperation::Multiply => FheProgramOperation::Multiply,
                    FheOperation::MultiplyPlaintext => FheProgramOperation::MultiplyPlaintext,
                    FheOperation::Negate => FheProgramOperation::Negate,
                    // Frontend rotations correspond to the backend shift operations.
                    FheOperation::RotateLeft => FheProgramOperation::ShiftLeft,
                    FheOperation::RotateRight => FheProgramOperation::ShiftRight,
                    FheOperation::SwapRows => FheProgramOperation::SwapRows,
                    FheOperation::Output => FheProgramOperation::OutputCiphertext,
                };

                NodeInfo::new(backend_op)
            },
            |_, edge| match edge {
                EdgeInfo::Unary => EdgeInfo::Unary,
                EdgeInfo::Left => EdgeInfo::Left,
                EdgeInfo::Right => EdgeInfo::Right,
                EdgeInfo::Unordered => unreachable!("FHE programs have no unordered edges."),
                EdgeInfo::Ordered(_) => unreachable!("FHE programs have no ordered edges."),
            },
        );

        fhe_program.graph = CompilationResult(lowered_graph);

        compile_inplace(fhe_program)
    }
}