use crossbeam::atomic::AtomicCell;
use sunscreen_compiler_common::GraphQuery;
use sunscreen_fhe_program::{FheProgram, Literal, Operation::*};
use sunscreen_runtime::traverse;
use std::collections::HashMap;
mod canonical_embedding_norm;
mod measured_model;
pub use canonical_embedding_norm::*;
pub use measured_model::*;
/// Standard deviation of the noise distribution (3.2).
pub const NOISE_STD_DEV: f64 = 3.2_f64;
/// How many standard deviations are used to bound the noise (6).
pub const NOISE_NUM_STD_DEVIATIONS: f64 = 6.0_f64;
/// `NOISE_STD_DEV * NOISE_NUM_STD_DEVIATIONS`.
///
/// NOTE(review): presumably used as a high-probability bound on the magnitude of sampled
/// noise by the noise-model implementations — confirm in those modules.
pub const NOISE_MAX: f64 = NOISE_STD_DEV * NOISE_NUM_STD_DEVIATIONS;
/// Statically estimates the invariant noise of every output ciphertext of `fhe_program`.
///
/// Each node's noise is computed by the matching [`NoiseModel`] method from the noise
/// already recorded for its operand node(s). `InputPlaintext` and `Literal` nodes carry no
/// noise (`0.0`). The closure reads operand noise, so it relies on [`traverse`] visiting
/// operands before the nodes that consume them.
///
/// # Returns
///
/// One value per `OutputCiphertext` node, in the order those nodes appear in
/// `fhe_program.graph.node_indices()`. That same order defines the `output_id` passed to
/// [`NoiseModel::output`].
///
/// # Panics
///
/// * if a node's operands cannot be resolved from the graph;
/// * if the right operand of `ShiftLeft`/`ShiftRight` is not a `Literal::U64` whose value
///   fits in an `i32`;
/// * if the traversal reports an error.
pub fn predict_noise(model: &(dyn NoiseModel + Sync), fhe_program: &FheProgram) -> Vec<f64> {
    // Size the table by the largest node index rather than `node_count()`: if the graph's
    // index space has holes (e.g. a stable graph after node removal), indexing by
    // `index()` would otherwise run past the end of the table.
    let index_bound = fhe_program
        .graph
        .node_indices()
        .map(|id| id.index() + 1)
        .max()
        .unwrap_or(0);

    // One cell per node index. `AtomicCell` lets the traversal closure store results through
    // a shared reference (presumably because `traverse` may run it concurrently — the model
    // is required to be `Sync`).
    let noise_levels: Vec<AtomicCell<f64>> =
        (0..index_bound).map(|_| AtomicCell::new(0.)).collect();

    // Maps an `OutputCiphertext` node's index to its zero-based position among all outputs.
    let node_id_to_output_id: HashMap<usize, usize> = fhe_program
        .graph
        .node_indices()
        .filter(|&id| matches!(fhe_program.graph[id].operation, OutputCiphertext))
        .enumerate()
        .map(|(output_num, id)| (id.index(), output_num))
        .collect();

    traverse(
        fhe_program,
        |node_id| {
            let node = &fhe_program.graph[node_id];
            let query = GraphQuery::new(&fhe_program.graph.0);

            let noise = match &node.operation {
                InputCiphertext(_) => model.encrypt(),
                InputPlaintext(_) => 0.0,
                Add => {
                    let (left, right) = query.get_binary_operands(node_id).unwrap();
                    model.add_ct_ct(
                        noise_levels[left.index()].load(),
                        noise_levels[right.index()].load(),
                    )
                }
                AddPlaintext => {
                    let (left, _) = query.get_binary_operands(node_id).unwrap();
                    model.add_ct_pt(noise_levels[left.index()].load())
                }
                Multiply => {
                    let (left, right) = query.get_binary_operands(node_id).unwrap();
                    model.mul_ct_ct(
                        noise_levels[left.index()].load(),
                        noise_levels[right.index()].load(),
                    )
                }
                MultiplyPlaintext => {
                    let (left, _) = query.get_binary_operands(node_id).unwrap();
                    model.mul_ct_pt(noise_levels[left.index()].load())
                }
                Relinearize => {
                    let x = query.get_unary_operand(node_id).unwrap();
                    model.relinearize(noise_levels[x.index()].load())
                }
                Negate => {
                    let x = query.get_unary_operand(node_id).unwrap();
                    model.neg(noise_levels[x.index()].load())
                }
                Sub => {
                    let (left, right) = query.get_binary_operands(node_id).unwrap();
                    model.sub_ct_ct(
                        noise_levels[left.index()].load(),
                        noise_levels[right.index()].load(),
                    )
                }
                SubPlaintext => {
                    let (left, _) = query.get_binary_operands(node_id).unwrap();
                    model.sub_ct_pt(noise_levels[left.index()].load())
                }
                OutputCiphertext => {
                    let x = query.get_unary_operand(node_id).unwrap();
                    let output_id = node_id_to_output_id[&node_id.index()];
                    model.output(output_id, noise_levels[x.index()].load())
                }
                Literal(_) => 0.0,
                ShiftLeft => {
                    let (left, right) = query.get_binary_operands(node_id).unwrap();
                    let b = match fhe_program.graph[right].operation {
                        // The model takes the shift amount as `i32`; reject values an `as`
                        // cast would silently wrap.
                        Literal(Literal::U64(v)) if v <= i32::MAX as u64 => v as i32,
                        _ => panic!(
                            "Illegal right operand for ShiftLeft: {:#?}",
                            fhe_program.graph[right].operation
                        ),
                    };
                    model.shift_left(noise_levels[left.index()].load(), b)
                }
                ShiftRight => {
                    let (left, right) = query.get_binary_operands(node_id).unwrap();
                    let b = match fhe_program.graph[right].operation {
                        Literal(Literal::U64(v)) if v <= i32::MAX as u64 => v as i32,
                        _ => panic!(
                            "Illegal right operand for ShiftRight: {:#?}",
                            fhe_program.graph[right].operation
                        ),
                    };
                    model.shift_right(noise_levels[left.index()].load(), b)
                }
                SwapRows => {
                    let x = query.get_unary_operand(node_id).unwrap();
                    model.swap_rows(noise_levels[x.index()].load())
                }
            };

            noise_levels[node_id.index()].store(noise);

            Ok(())
        },
        None,
    )
    .expect("noise prediction traversal failed");

    // Look results up by node index (not by position) so the output order matches
    // `node_id_to_output_id` even if the index space has holes.
    fhe_program
        .graph
        .node_indices()
        .filter(|&id| matches!(fhe_program.graph[id].operation, OutputCiphertext))
        .map(|id| noise_levels[id.index()].load())
        .collect()
}
/// Converts an invariant noise value into a noise budget in bits: `-log2(2 * invariant_noise)`.
///
/// A noise of `0.5` maps to a budget of `0`; smaller noise yields a larger budget.
pub fn noise_to_noise_budget(invariant_noise: f64) -> f64 {
    let doubled = 2. * invariant_noise;
    -doubled.log2()
}
/// Converts a noise budget in bits back into an invariant noise value: `2^(-budget) / 2`.
///
/// This is the inverse of `noise_to_noise_budget`.
pub fn noise_budget_to_noise(invariant_noise_budget: f64) -> f64 {
    let scaled = 2f64.powf(-invariant_noise_budget);
    scaled / 2.
}
/// A model of how invariant noise changes as ciphertexts pass through FHE operations.
///
/// `predict_noise` calls one method per node of an FHE program graph. Every
/// `*_invariant_noise` argument is the noise `predict_noise` recorded for the corresponding
/// operand node. Each method returns the estimated noise of that operation's result.
///
/// NOTE(review): these values are presumably in the same units that `noise_to_noise_budget`
/// converts into a noise budget — confirm against the implementations.
pub trait NoiseModel {
    /// Noise of a freshly encrypted ciphertext; used for every `InputCiphertext` node.
    fn encrypt(&self) -> f64;
    /// Noise after adding two ciphertexts (`Add` nodes).
    fn add_ct_ct(&self, a_invariant_noise: f64, b_invariant_noise: f64) -> f64;
    /// Noise after adding a plaintext to a ciphertext (`AddPlaintext` nodes); only the
    /// ciphertext operand's noise is supplied.
    fn add_ct_pt(&self, ct_invariant_noise: f64) -> f64;
    /// Noise after multiplying two ciphertexts (`Multiply` nodes).
    fn mul_ct_ct(&self, a_invariant_noise: f64, b_invariant_noise: f64) -> f64;
    /// Noise after multiplying a ciphertext by a plaintext (`MultiplyPlaintext` nodes); only
    /// the ciphertext operand's noise is supplied.
    fn mul_ct_pt(&self, a_invariant_noise: f64) -> f64;
    /// Noise after a `Relinearize` node.
    fn relinearize(&self, a_invariant_noise: f64) -> f64;
    /// Noise reported for a program output (`OutputCiphertext` nodes).
    ///
    /// `output_id` is the output's zero-based position among all `OutputCiphertext` nodes,
    /// in graph-index order. `invariant_noise` is the noise of the value being output.
    fn output(&self, output_id: usize, invariant_noise: f64) -> f64;
    /// Noise after negating a ciphertext (`Negate` nodes).
    fn neg(&self, invariant_noise: f64) -> f64;
    /// Noise after subtracting two ciphertexts (`Sub` nodes).
    fn sub_ct_ct(&self, a_invariant_noise: f64, b_invariant_noise: f64) -> f64;
    /// Noise after subtracting a plaintext from a ciphertext (`SubPlaintext` nodes); only the
    /// ciphertext operand's noise is supplied.
    fn sub_ct_pt(&self, a_invariant_noise: f64) -> f64;
    /// Noise after a `SwapRows` node.
    fn swap_rows(&self, a_invariant_noise: f64) -> f64;
    /// Noise after a `ShiftLeft` node. `places` comes from the node's `Literal::U64` right
    /// operand.
    fn shift_left(&self, a_invariant_noise: f64, places: i32) -> f64;
    /// Noise after a `ShiftRight` node. `places` comes from the node's `Literal::U64` right
    /// operand.
    fn shift_right(&self, a_invariant_noise: f64, places: i32) -> f64;
}
/// Converting a budget to noise and back must reproduce the original budget exactly
/// (42 bits maps to the power of two 2^-43, so no rounding occurs).
#[test]
fn can_roundtrip_noise_to_budget() {
    let original_budget = 42.;
    let roundtripped = noise_to_noise_budget(noise_budget_to_noise(original_budget));
    assert_eq!(roundtripped, original_budget);
}