use std::{
collections::{HashMap, VecDeque},
fmt::Write,
ops::{Deref, DerefMut},
sync::Arc,
};
use mux_circuits::{MuxCircuit, MuxEdgeInfo, MuxOp};
use parasol_concurrency::AtomicRefCell;
use petgraph::{Direction, prelude::StableGraph, stable_graph::NodeIndex, visit::EdgeRef};
use crate::crypto::{
Encryption, L0LweCiphertext, L1GgswCiphertext, L1GlevCiphertext, L1GlweCiphertext,
L1LweCiphertext, ciphertext::CiphertextType,
};
// Shared ciphertext handles. Each alias is an `Arc` (shared ownership) around a
// `parasol_concurrency::AtomicRefCell` (presumably a thread-safe cell with runtime-checked
// borrows, similar to `RefCell` — confirm). The `Arc` lets a graph node and the caller refer
// to the same ciphertext: `FheOp::Input*`/`FheOp::Output*` nodes hold these handles, and
// `FheCircuit::insert_mux_circuit_l1glwe_outputs` hands clones of them back to its caller.

/// Shared handle to an [`L0LweCiphertext`].
pub type SharedL0LweCiphertext = Arc<AtomicRefCell<L0LweCiphertext>>;
/// Shared handle to an [`L1LweCiphertext`].
pub type SharedL1LweCiphertext = Arc<AtomicRefCell<L1LweCiphertext>>;
/// Shared handle to an [`L1GlweCiphertext`].
pub type SharedL1GlweCiphertext = Arc<AtomicRefCell<L1GlweCiphertext>>;
/// Shared handle to an [`L1GgswCiphertext`].
pub type SharedL1GgswCiphertext = Arc<AtomicRefCell<L1GgswCiphertext>>;
/// Shared handle to an [`L1GlevCiphertext`].
pub type SharedL1GlevCiphertext = Arc<AtomicRefCell<L1GlevCiphertext>>;
/// A single operation (node) in an [`FheCircuit`].
///
/// `Input*` and `Output*` variants carry a shared handle to a caller-visible ciphertext.
/// All other variants are computations or constants whose operands arrive over incoming
/// edges, labelled with an [`FheEdge`] that says which operand each edge supplies.
#[derive(Clone)]
pub enum FheOp {
    /// Source node holding an existing L0 LWE ciphertext.
    InputLwe0(SharedL0LweCiphertext),
    /// Source node holding an existing L1 LWE ciphertext.
    InputLwe1(SharedL1LweCiphertext),
    /// Source node holding an existing L1 GLWE ciphertext.
    InputGlwe1(SharedL1GlweCiphertext),
    /// Source node holding an existing L1 GGSW ciphertext.
    InputGgsw1(SharedL1GgswCiphertext),
    /// Source node holding an existing L1 GLEV ciphertext.
    InputGlev1(SharedL1GlevCiphertext),
    /// Sink node for an L0 LWE result. The shared ciphertext presumably receives the value
    /// of the node's single (`Unary`) input when the circuit is evaluated — confirm.
    OutputLwe0(SharedL0LweCiphertext),
    /// Sink node for an L1 LWE result (see `OutputLwe0`).
    OutputLwe1(SharedL1LweCiphertext),
    /// Sink node for an L1 GLWE result (see `OutputLwe0`).
    OutputGlwe1(SharedL1GlweCiphertext),
    /// Sink node for an L1 GGSW result (see `OutputLwe0`).
    OutputGgsw1(SharedL1GgswCiphertext),
    /// Sink node for an L1 GLEV result (see `OutputLwe0`).
    OutputGlev1(SharedL1GlevCiphertext),
    /// Turns an L1 GLWE ciphertext into an L1 LWE ciphertext. The `usize` is presumably
    /// the coefficient index to extract; this file only ever uses `0`.
    SampleExtract(usize),
    /// Key-switches an L1 LWE ciphertext down to an L0 LWE ciphertext.
    KeyswitchL1toL0,
    /// Presumably homomorphic negation; not constructed in this file — confirm operand type.
    Not,
    /// Presumably homomorphic addition of two GLWE ciphertexts (operands probably on the
    /// `Left`/`Right` edges); not constructed in this file — confirm.
    GlweAdd,
    /// Multiplexer over GLWE ciphertexts with `Low`/`High`/`Sel` inputs (the GLWE
    /// counterpart of `GlevCMux`).
    CMux,
    /// Multiplexer over GLEV ciphertexts. It takes `Low`/`High` GLEV inputs and a GGSW
    /// select input on the `Sel` edge.
    GlevCMux,
    /// Multiplies the GGSW ciphertext on the `Ggsw` edge by the GLWE ciphertext on the
    /// `Glwe` edge; the result is an L1 GLWE ciphertext.
    MultiplyGgswGlwe,
    /// Converts an L0 LWE ciphertext into an L1 GGSW ciphertext.
    CircuitBootstrap,
    /// Converts an L1 GLEV ciphertext into an L1 GGSW ciphertext.
    SchemeSwitch,
    /// Constant-zero L0 LWE source (how the constant is materialized is not visible here).
    ZeroLwe0,
    /// Constant-one L0 LWE source.
    OneLwe0,
    /// Constant-zero L1 GLWE source.
    ZeroGlwe1,
    /// Constant-one L1 GLWE source.
    OneGlwe1,
    /// Constant-zero L1 GGSW source.
    ZeroGgsw1,
    /// Constant-one L1 GGSW source.
    OneGgsw1,
    /// Constant-zero L1 GLEV source.
    ZeroGlev1,
    /// Constant-one L1 GLEV source.
    OneGlev1,
    /// Semantics not visible in this file; presumably marks a value as no longer needed — confirm.
    Retire,
    /// No-op placeholder; semantics not visible in this file — confirm.
    Nop,
    /// Printed by `Debug` as `Rotate(n)`. It presumably multiplies a polynomial ciphertext
    /// by `X^n` (a rotation) — confirm.
    MulXN(usize),
}
impl std::fmt::Debug for FheOp {
    /// Writes only the variant name. Ciphertext payloads and the `SampleExtract` index
    /// are omitted, and `MulXN(n)` is rendered as `Rotate(n)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only `MulXN` needs a formatted label. It is built into this buffer so that
        // every match arm can yield a `&str`.
        let mut rotate_label = String::new();
        let op = match self {
            Self::InputLwe0(_) => "InputLwe0",
            Self::InputLwe1(_) => "InputLwe1",
            Self::InputGlwe1(_) => "InputGlwe1",
            Self::InputGgsw1(_) => "InputGgsw1",
            Self::InputGlev1(_) => "InputGlev1",
            Self::OutputLwe0(_) => "OutputLwe0",
            Self::OutputLwe1(_) => "OutputLwe1",
            Self::OutputGlwe1(_) => "OutputGlwe1",
            Self::OutputGgsw1(_) => "OutputGgsw1",
            Self::OutputGlev1(_) => "OutputGlev1",
            Self::SampleExtract(_) => "SampleExtract",
            Self::Not => "Not",
            Self::GlweAdd => "GlweAdd",
            Self::KeyswitchL1toL0 => "KeyswitchL1toL0",
            Self::CMux => "CMux",
            Self::GlevCMux => "GlevCMux",
            Self::MultiplyGgswGlwe => "MultiplyGgswGlwe",
            Self::CircuitBootstrap => "CircuitBootstrap",
            Self::ZeroLwe0 => "ZeroLwe0",
            // Previously mislabelled as "ZeroLwe1", which made the one-constant
            // indistinguishable from a (non-existent) zero constant in debug dumps.
            Self::OneLwe0 => "OneLwe0",
            Self::ZeroGlwe1 => "ZeroGlwe1",
            Self::OneGlwe1 => "OneGlwe1",
            Self::ZeroGgsw1 => "ZeroGgsw1",
            Self::OneGgsw1 => "OneGgsw1",
            Self::ZeroGlev1 => "ZeroGlev1",
            Self::OneGlev1 => "OneGlev1",
            Self::Retire => "Retire",
            Self::Nop => "Nop",
            Self::MulXN(amt) => {
                write!(&mut rotate_label, "Rotate({amt})")?;
                &rotate_label
            }
            Self::SchemeSwitch => "SchemeSwitch",
        };
        f.write_str(op)
    }
}
/// Label on an [`FheCircuit`] edge. It tells the consuming node which of its operands the
/// producing node supplies.
#[derive(Copy, Clone, Debug)]
pub enum FheEdge {
    /// Data input of a mux. `insert_ciphertext_conversion` wires the zero constant here, so
    /// this is the input chosen when the select bit is 0.
    Low,
    /// Data input of a mux chosen when the select bit is 1 (the one constant in
    /// `insert_ciphertext_conversion`).
    High,
    /// Select input of a mux (a GGSW-producing node).
    Sel,
    /// The single input of a one-operand node, e.g. `SampleExtract`, `KeyswitchL1toL0`,
    /// `CircuitBootstrap`, `SchemeSwitch`, or an `Output*` sink.
    Unary,
    /// GLWE operand of `MultiplyGgswGlwe`.
    Glwe,
    /// GGSW operand of `MultiplyGgswGlwe`.
    Ggsw,
    /// Presumably the first operand of a two-input op such as `GlweAdd`; not used in this file.
    Left,
    /// Presumably the second operand of a two-input op such as `GlweAdd`; not used in this file.
    Right,
}
/// A dataflow graph of FHE operations.
///
/// Nodes are [`FheOp`]s. An edge `a -> b` means the result of `a` is an operand of `b`, and
/// its [`FheEdge`] label says which operand.
#[derive(Debug)]
pub struct FheCircuit {
    /// The underlying graph; graph methods can also be called on the circuit directly
    /// through `Deref`/`DerefMut`.
    pub graph: StableGraph<FheOp, FheEdge>,
}
impl Deref for FheCircuit {
type Target = StableGraph<FheOp, FheEdge>;
fn deref(&self) -> &Self::Target {
&self.graph
}
}
impl DerefMut for FheCircuit {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.graph
}
}
impl Default for FheCircuit {
fn default() -> Self {
Self::new()
}
}
/// Ciphertext flavour used when a mux circuit is lowered into an [`FheCircuit`]
/// (see [`FheCircuit::insert_mux_circuit`]).
///
/// This is a plain, field-less, by-value mode flag, so it derives the common value traits
/// (`Debug` for diagnostics, `Copy`/`Eq` so callers can reuse and compare it).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MuxMode {
    /// Lower to `CMux` nodes with `ZeroGlwe1`/`OneGlwe1` constants.
    Glwe,
    /// Lower to `GlevCMux` nodes with `ZeroGlev1`/`OneGlev1` constants.
    Glev,
}
impl MuxMode {
pub fn mux(&self) -> FheOp {
match self {
Self::Glwe => FheOp::CMux,
Self::Glev => FheOp::GlevCMux,
}
}
pub fn zero(&self) -> FheOp {
match self {
Self::Glwe => FheOp::ZeroGlwe1,
Self::Glev => FheOp::ZeroGlev1,
}
}
pub fn one(&self) -> FheOp {
match self {
Self::Glwe => FheOp::OneGlwe1,
Self::Glev => FheOp::OneGlev1,
}
}
}
impl FheCircuit {
pub fn new() -> Self {
Self {
graph: StableGraph::new(),
}
}
pub fn insert_mux_circuit(
&mut self,
mux_circuit: &MuxCircuit,
nodes_to_inputs: &[NodeIndex],
mux_mode: MuxMode,
) -> Vec<NodeIndex> {
assert_eq!(mux_circuit.inputs.len(), nodes_to_inputs.len());
let mut node_renames = HashMap::new();
let mut outputs = vec![
NodeIndex::default();
mux_circuit
.graph
.node_weights()
.filter(|x| matches!(x, MuxOp::Output(_)))
.count()
];
for i in mux_circuit.graph.node_indices() {
let mux_op = mux_circuit.graph[i];
let fhe_equivalent_op = match mux_op {
MuxOp::Mux => Some(mux_mode.mux()),
MuxOp::One => Some(mux_mode.one()),
MuxOp::Zero => Some(mux_mode.zero()),
_ => None,
};
if let Some(n) = fhe_equivalent_op {
let new_idx = self.graph.add_node(n);
node_renames.insert(i, new_idx);
}
}
for (fhe_provided_op_index, mux_op_index) in
nodes_to_inputs.iter().zip(mux_circuit.inputs.iter())
{
if !matches!(mux_circuit.graph[*mux_op_index], MuxOp::Variable(_)) {
panic!("Mux trees can only be connected to Variable nodes.");
}
if !matches!(
self.graph[*fhe_provided_op_index],
FheOp::InputGgsw1(_)
| FheOp::CircuitBootstrap
| FheOp::ZeroGgsw1
| FheOp::OneGgsw1
| FheOp::SchemeSwitch
) {
panic!("Mux trees can only be connected to Ggsw, CBS, or Scheme switch nodes.");
}
for e in mux_circuit
.graph
.edges_directed(*mux_op_index, petgraph::Direction::Outgoing)
{
let target = node_renames.get(&e.target()).unwrap();
self.graph
.add_edge(*fhe_provided_op_index, *target, Self::map_edge(e.weight()));
}
}
for i in mux_circuit
.graph
.node_indices()
.filter(|n| matches!(mux_circuit.graph[*n], MuxOp::Output(_)))
{
let o = mux_circuit.graph[i];
match o {
MuxOp::Output(o) => {
let prev = mux_circuit
.graph
.edges_directed(i, petgraph::Direction::Incoming)
.nth(0)
.unwrap();
let idx = node_renames.get(&prev.source()).unwrap();
outputs[o as usize] = *idx;
}
_ => unreachable!(),
}
}
for i in mux_circuit.graph.node_indices() {
let node = mux_circuit.graph[i];
if matches!(node, MuxOp::Output(_)) || matches!(node, MuxOp::Variable(_)) {
continue;
}
for e in mux_circuit
.graph
.edges_directed(i, petgraph::Direction::Outgoing)
{
let src = node_renames.get(&e.source()).unwrap();
let dst = node_renames.get(&e.target());
if let Some(dst) = dst {
self.graph.add_edge(*src, *dst, Self::map_edge(e.weight()));
}
}
}
outputs
}
pub fn insert_mux_circuit_output_glwe1_outputs(
&mut self,
mux_circuit: &MuxCircuit,
nodes_to_inputs: &[NodeIndex],
enc: &Encryption,
) -> Vec<NodeIndex> {
let cmux_outputs = self.insert_mux_circuit(mux_circuit, nodes_to_inputs, MuxMode::Glwe);
let glwe_outputs = (0..cmux_outputs.len())
.map(|_| Arc::new(AtomicRefCell::new(enc.allocate_glwe_l1())))
.collect::<Vec<_>>();
cmux_outputs
.iter()
.zip(glwe_outputs.iter())
.map(|(cmux_out, glwe_out)| {
let o = self.graph.add_node(FheOp::OutputGlwe1(glwe_out.clone()));
self.graph.add_edge(*cmux_out, o, FheEdge::Unary);
o
})
.collect::<Vec<_>>()
}
pub fn insert_mux_circuit_l1glwe_outputs(
&mut self,
mux_circuit: &MuxCircuit,
nodes_to_inputs: &[NodeIndex],
enc: &Encryption,
) -> Vec<Arc<AtomicRefCell<L1GlweCiphertext>>> {
let glwe_outputs =
self.insert_mux_circuit_output_glwe1_outputs(mux_circuit, nodes_to_inputs, enc);
glwe_outputs
.iter()
.map(|x| {
let node = self.graph.node_weight(*x).unwrap();
match node {
FheOp::OutputGlwe1(x) => x.clone(),
_ => unreachable!(),
}
})
.collect::<Vec<_>>()
}
pub fn insert_mux_circuit_and_connect_inputs(
&mut self,
mux_circuit: &MuxCircuit,
inputs: &[Arc<AtomicRefCell<L1GlweCiphertext>>],
enc: &Encryption,
) -> Vec<Arc<AtomicRefCell<L1GlweCiphertext>>> {
let node_indices = inputs
.iter()
.map(|input| {
let i = self.add_node(FheOp::InputGlwe1(input.clone()));
let se = self.add_node(FheOp::SampleExtract(0));
self.add_edge(i, se, FheEdge::Unary);
let ks = self.add_node(FheOp::KeyswitchL1toL0);
self.add_edge(se, ks, FheEdge::Unary);
let cbs = self.add_node(FheOp::CircuitBootstrap);
self.add_edge(ks, cbs, FheEdge::Unary);
cbs
})
.collect::<Vec<_>>();
self.insert_mux_circuit_l1glwe_outputs(mux_circuit, &node_indices, enc)
}
fn map_edge(e: &MuxEdgeInfo) -> FheEdge {
match e {
MuxEdgeInfo::High => FheEdge::High,
MuxEdgeInfo::Low => FheEdge::Low,
MuxEdgeInfo::Select => FheEdge::Sel,
MuxEdgeInfo::Output => unreachable!(),
}
}
}
impl From<StableGraph<FheOp, FheEdge>> for FheCircuit {
fn from(value: StableGraph<FheOp, FheEdge>) -> Self {
Self { graph: value }
}
}
/// Extracts the sub-graph of `graph` made of `nodes` and all of their transitive
/// predecessors (every node from which one of `nodes` can be reached along edges).
///
/// Node and edge weights are cloned into a fresh graph. Every edge of `graph` whose
/// target is retained is copied; its source is retained too, because all predecessors
/// are kept. Node and edge insertion order in the result is deterministic: nodes follow
/// the breadth-first discovery order, and edges are grouped by target in that same order.
///
/// Returns the new graph together with a map from each retained node's index in `graph`
/// to its index in the new graph.
///
/// # Panics
///
/// Panics if an index in `nodes` is not present in `graph`.
pub fn prune<N: Clone, E: Clone>(
    graph: &StableGraph<N, E>,
    nodes: &[NodeIndex],
) -> (StableGraph<N, E>, HashMap<NodeIndex, NodeIndex>) {
    let mut out_graph = StableGraph::new();
    let mut rename = HashMap::new();
    // Retained nodes in the order they were copied. Edges are re-created by walking this
    // list rather than `rename`, whose iteration order is randomized per run.
    let mut copied = Vec::new();
    let mut queue: VecDeque<NodeIndex> = nodes.iter().copied().collect();

    // Breadth-first walk against the edge direction, starting from `nodes`.
    while let Some(cur_id) = queue.pop_front() {
        if let std::collections::hash_map::Entry::Vacant(e) = rename.entry(cur_id) {
            let weight = graph
                .node_weight(cur_id)
                .expect("prune: node index not present in graph");
            e.insert(out_graph.add_node(weight.clone()));
            copied.push(cur_id);
        }
        for next in graph.neighbors_directed(cur_id, Direction::Incoming) {
            if let std::collections::hash_map::Entry::Vacant(e) = rename.entry(next) {
                let weight = graph
                    .node_weight(next)
                    .expect("prune: node index not present in graph");
                e.insert(out_graph.add_node(weight.clone()));
                copied.push(next);
                queue.push_back(next);
            }
        }
    }

    for &old in &copied {
        for e in graph.edges_directed(old, Direction::Incoming) {
            out_graph.add_edge(rename[&e.source()], rename[&e.target()], e.weight().clone());
        }
    }

    (out_graph, rename)
}
pub fn insert_ciphertext_conversion(
graph: &mut FheCircuit,
cur_node: NodeIndex,
in_type: CiphertextType,
out_type: CiphertextType,
) -> NodeIndex {
if in_type == out_type {
return cur_node;
}
let (conv_idx, next_type) = match in_type {
CiphertextType::L0LweCiphertext => {
let idx = graph.add_node(FheOp::CircuitBootstrap);
graph.add_edge(cur_node, idx, FheEdge::Unary);
(idx, CiphertextType::L1GgswCiphertext)
}
CiphertextType::L1GgswCiphertext => {
if out_type == CiphertextType::L1GlevCiphertext {
let idx = graph.add_node(FheOp::GlevCMux);
let zero = graph.add_node(FheOp::ZeroGlev1);
let one = graph.add_node(FheOp::OneGlev1);
graph.add_edge(zero, idx, FheEdge::Low);
graph.add_edge(one, idx, FheEdge::High);
graph.add_edge(cur_node, idx, FheEdge::Sel);
(idx, out_type)
} else {
let idx = graph.add_node(FheOp::MultiplyGgswGlwe);
let one = graph.add_node(FheOp::OneGlwe1);
graph.add_edge(one, idx, FheEdge::Glwe);
graph.add_edge(cur_node, idx, FheEdge::Ggsw);
(idx, CiphertextType::L1GlweCiphertext)
}
}
CiphertextType::L1GlweCiphertext => {
let idx = graph.add_node(FheOp::SampleExtract(0));
graph.add_edge(cur_node, idx, FheEdge::Unary);
(idx, CiphertextType::L1LweCiphertext)
}
CiphertextType::L1LweCiphertext => {
let idx = graph.add_node(FheOp::KeyswitchL1toL0);
graph.add_edge(cur_node, idx, FheEdge::Unary);
(idx, CiphertextType::L0LweCiphertext)
}
CiphertextType::L1GlevCiphertext => {
let idx = graph.add_node(FheOp::SchemeSwitch);
graph.add_edge(cur_node, idx, FheEdge::Unary);
(idx, CiphertextType::L1GgswCiphertext)
}
};
insert_ciphertext_conversion(graph, conv_idx, next_type, out_type)
}