use std::convert::Infallible;
use sunscreen_compiler_common::{
forward_traverse_mut,
transforms::{GraphTransforms, Transform},
EdgeInfo, GraphQuery, NodeInfo,
};
use sunscreen_fhe_program::{
FheProgram,
Operation::{self, *},
};
use petgraph::{stable_graph::NodeIndex, visit::EdgeRef, Direction};
/// [`GraphQuery`] over an FHE program graph whose nodes hold an [`Operation`]
/// and whose edges carry [`EdgeInfo`] (the operand slot a value feeds into).
type FheGraphQuery<'a> = GraphQuery<'a, NodeInfo<Operation>, EdgeInfo>;
/// Splices an [`Operation::Relinearize`] node after every
/// [`Operation::Multiply`] node in `ir`.
///
/// For each multiply, a new relinearize node is added and fed from the
/// multiply through a single [`EdgeInfo::Unary`] edge. Every edge that
/// previously left the multiply is then re-pointed to leave the relinearize
/// node instead, keeping its original [`EdgeInfo`] so consumers still receive
/// the value in the same operand slot. Only multiply nodes are rewritten.
///
/// # Panics
///
/// Panics if the traversal hands back an index with no node attached
/// (`get_node(..).unwrap()`), which would indicate a corrupted graph.
pub fn apply_insert_relinearizations(ir: &mut FheProgram) {
    // Produces the edits that place a relinearization between the multiply
    // `mul` and all of its current consumers.
    let relinearize_after = |mul: NodeIndex, query: FheGraphQuery| {
        let mut edits = GraphTransforms::new();

        let relin = edits.push(Transform::AddNode(NodeInfo {
            operation: Operation::Relinearize,
        }));
        edits.push(Transform::AddEdge(mul.into(), relin.into(), EdgeInfo::Unary));

        // Move each consumer edge from `mul` onto `relin`, preserving the
        // operand position the consumer originally received the product in.
        for edge in query.edges_directed(mul, Direction::Outgoing) {
            let consumer = edge.target();
            edits.push(Transform::RemoveEdge(mul.into(), consumer.into()));
            edits.push(Transform::AddEdge(
                relin.into(),
                consumer.into(),
                *edge.weight(),
            ));
        }

        edits
    };

    forward_traverse_mut(&mut ir.graph.0, |query, id| {
        let is_multiply = matches!(query.get_node(id).unwrap().operation, Multiply);

        let transforms = if is_multiply {
            relinearize_after(id, query)
        } else {
            GraphTransforms::default()
        };

        Ok::<_, Infallible>(transforms)
    })
    .unwrap();
}
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::stable_graph::NodeIndex;
    use sunscreen_compiler_common::GraphQuery;
    use sunscreen_fhe_program::{
        FheProgramTrait, Literal as FheProgramLiteral, Operation, SchemeType,
    };

    /// Builds a 7-node BFV program containing two multiplies:
    ///
    /// - `mul = (ct + 7) * 5`, whose only consumer is `add_2 = mul + 5`;
    /// - `add_2 * ct`, the final node, which has no consumers.
    fn create_test_dag() -> FheProgram {
        let mut ir = FheProgram::new(SchemeType::Bfv);
        let ct = ir.add_input_ciphertext(0);
        let l1 = ir.add_input_literal(FheProgramLiteral::from(7u64));
        let add = ir.add_add(ct, l1);
        let l2 = ir.add_input_literal(FheProgramLiteral::from(5u64));
        let mul = ir.add_multiply(add, l2);
        let add_2 = ir.add_add(mul, l2);
        ir.add_multiply(add_2, ct);
        ir
    }

    #[test]
    fn inserts_relinearizations() {
        let mut ir = create_test_dag();
        assert_eq!(ir.graph.node_count(), 7);

        apply_insert_relinearizations(&mut ir);

        // Exactly one relinearize node is added per multiply.
        assert_eq!(ir.graph.node_count(), 9);

        let query = GraphQuery::new(&ir.graph.0);

        let relin_nodes = ir
            .graph
            .node_indices()
            .filter(|i| {
                matches!(
                    query.get_node(*i).unwrap().operation,
                    Operation::Relinearize
                )
            })
            .collect::<Vec<NodeIndex>>();
        assert_eq!(relin_nodes.len(), 2);

        // Each relinearization has exactly one input, and that input is a multiply.
        assert!(relin_nodes
            .iter()
            .all(|id| query.neighbors_directed(*id, Direction::Incoming).count() == 1));
        assert!(relin_nodes.iter().all(|id| {
            query
                .neighbors_directed(*id, Direction::Incoming)
                .map(|src| query.get_node(src).unwrap())
                .all(|node| matches!(node.operation, Operation::Multiply))
        }));

        // Every multiply must now feed only its relinearization, through a
        // unary edge. Without this check, a pass that added the new edges but
        // forgot to remove the multiply's original consumer edges would still
        // satisfy every other assertion in this test.
        let mul_nodes = ir
            .graph
            .node_indices()
            .filter(|i| {
                matches!(
                    query.get_node(*i).unwrap().operation,
                    Operation::Multiply
                )
            })
            .collect::<Vec<NodeIndex>>();
        assert_eq!(mul_nodes.len(), 2);
        for id in &mul_nodes {
            let outgoing = query
                .edges_directed(*id, Direction::Outgoing)
                .collect::<Vec<_>>();
            assert_eq!(outgoing.len(), 1);
            assert!(matches!(outgoing[0].weight(), EdgeInfo::Unary));
            assert!(matches!(
                query.get_node(outgoing[0].target()).unwrap().operation,
                Operation::Relinearize
            ));
        }

        // The first multiply's consumer (`add_2`) was moved onto its
        // relinearization; the final multiply has no consumers to move.
        assert_eq!(
            query
                .neighbors_directed(relin_nodes[0], Direction::Outgoing)
                .count(),
            1
        );
        assert_eq!(
            query
                .neighbors_directed(relin_nodes[1], Direction::Outgoing)
                .count(),
            0
        );
        assert!(query
            .neighbors_directed(relin_nodes[0], Direction::Outgoing)
            .all(|i| matches!(query.get_node(i).unwrap().operation, Operation::Add)));
    }
}