use crate::data_types::get_size_in_bits;
use crate::errors::Result;
use crate::graphs::{Graph, Node, NodeAnnotation, Operation};
use std::collections::HashSet;
use super::mpc_compiler::{get_zero_shares, recursively_sum_shares, PARTIES};
/// State of the resharing analysis for a single graph.
struct ResharingConfig {
    /// Nodes selected to have their output reshared; this is the result of the analysis.
    pub nodes_to_reshare: HashSet<Node>,
    /// Private nodes whose output currently carries an unreshared value: products of
    /// two private values and local operations derived from them.
    unreshared_nodes: HashSet<Node>,
}
impl ResharingConfig {
    /// Creates an analysis state with both node sets empty.
    fn new() -> Self {
        ResharingConfig {
            nodes_to_reshare: HashSet::new(),
            unreshared_nodes: HashSet::new(),
        }
    }

    /// Returns `true` when every dependency of `node` is in `private_nodes`
    /// (vacuously `true` for a node without dependencies).
    fn all_dependencies_private(node: &Node, private_nodes: &HashSet<Node>) -> bool {
        node.get_node_dependencies()
            .iter()
            .all(|dep| private_nodes.contains(dep))
    }

    /// Returns `true` when at least one dependency of `node` is currently unreshared.
    fn has_unreshared_dependency(&self, node: &Node) -> bool {
        node.get_node_dependencies()
            .iter()
            .any(|dep| self.unreshared_nodes.contains(dep))
    }

    /// Handles an operation that is evaluated locally on shares.
    ///
    /// * Broadcasting operations: the sizes (in bits) of the unreshared inputs are
    ///   summed. If the output is larger than that total, the unreshared inputs are
    ///   scheduled for resharing instead of the node. This is presumably because
    ///   resharing the smaller inputs is cheaper than resharing the broadcast result.
    ///   Otherwise the node becomes unreshared when that total is non-zero.
    /// * Other operations: the node becomes unreshared if any input is unreshared.
    ///
    /// # Errors
    ///
    /// Propagates failures from querying node types or their sizes in bits.
    fn local_operation_handler(&mut self, node: Node) -> Result<()> {
        if !node.get_operation().is_broadcasting_called() {
            if self.has_unreshared_dependency(&node) {
                self.unreshared_nodes.insert(node);
            }
            return Ok(());
        }
        let mut unreshared_input_size = 0;
        for dep in node.get_node_dependencies() {
            if self.unreshared_nodes.contains(&dep) {
                unreshared_input_size += get_size_in_bits(dep.get_type()?)?;
            }
        }
        let output_size = get_size_in_bits(node.get_type()?)?;
        if output_size > unreshared_input_size {
            self.ensure_dependencies_are_reshared(&node);
        } else if unreshared_input_size > 0 {
            self.unreshared_nodes.insert(node);
        }
        Ok(())
    }

    /// Schedules every currently unreshared dependency of `node` for resharing.
    ///
    /// Such dependencies are moved from `unreshared_nodes` to `nodes_to_reshare`.
    /// Later nodes consuming them therefore no longer see them as unreshared.
    fn ensure_dependencies_are_reshared(&mut self, node: &Node) {
        for dep_node in node.get_node_dependencies() {
            // `HashSet::remove` reports whether the element was present, so one lookup suffices.
            if self.unreshared_nodes.remove(&dep_node) {
                self.nodes_to_reshare.insert(dep_node);
            }
        }
    }

    /// Computes `nodes_to_reshare` for `graph`. `private_nodes` is the set of nodes
    /// holding secret-shared values.
    ///
    /// Both node sets are reset first. Nodes are then visited in the order returned
    /// by `Graph::get_nodes`; the analysis relies on dependencies being visited
    /// before their consumers (TODO confirm this order is topological).
    /// Non-private nodes are skipped.
    ///
    /// * Products of two private values (`Multiply`, `Dot`, `Matmul`, `Gemm`) become
    ///   unreshared.
    /// * Local operations propagate that status.
    /// * Operations such as `Truncate`, `A2B` or `Sort` force resharing of their
    ///   unreshared dependencies.
    ///
    /// An output node that is still unreshared at the end is scheduled for
    /// resharing. `sanity_pass` then drops entries that turned out to be redundant.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * a private node's operation is not MPC compiled or is not recognized;
    /// * `MixedMultiply`/`ApplyPermutation` has fewer than two arguments;
    /// * querying types, sizes or the output node fails.
    fn compute_graph_resharing(
        &mut self,
        graph: &Graph,
        private_nodes: &HashSet<Node>,
    ) -> Result<()> {
        self.nodes_to_reshare = HashSet::new();
        self.unreshared_nodes = HashSet::new();
        for node in graph.get_nodes() {
            if !private_nodes.contains(&node) {
                continue;
            }
            let op = node.get_operation();
            if !op.is_mpc_compiled() {
                return Err(runtime_error!("This operation shouldn't be MPC compiled"));
            }
            match op {
                Operation::Input(_) => {}
                Operation::Add
                | Operation::Subtract
                | Operation::Sum(_)
                | Operation::CumSum(_)
                | Operation::Get(_)
                | Operation::Stack(_)
                | Operation::Concatenate(_)
                | Operation::Reshape(_)
                | Operation::PermuteAxes(_)
                | Operation::Zip
                | Operation::Repeat(_)
                | Operation::TupleGet(_)
                | Operation::CreateNamedTuple(_)
                | Operation::NamedTupleGet(_)
                | Operation::VectorToArray
                | Operation::VectorGet
                | Operation::CreateTuple
                | Operation::ArrayToVector
                | Operation::CreateVector(_) => {
                    self.local_operation_handler(node)?;
                }
                Operation::Multiply
                | Operation::Dot
                | Operation::Matmul
                | Operation::Gemm(_, _) => {
                    if Self::all_dependencies_private(&node, private_nodes) {
                        // Product of private inputs: inputs must be reshared first,
                        // and the result itself is unreshared.
                        self.ensure_dependencies_are_reshared(&node);
                        self.unreshared_nodes.insert(node);
                    } else {
                        self.local_operation_handler(node)?;
                    }
                }
                Operation::Join(_, _)
                | Operation::JoinWithColumnMasks(_, _)
                | Operation::Truncate(_)
                | Operation::A2B
                | Operation::B2A(_)
                | Operation::Sort(_)
                | Operation::GetSlice(_) => {
                    self.ensure_dependencies_are_reshared(&node);
                }
                Operation::MixedMultiply | Operation::ApplyPermutation(_) => {
                    let dependencies = node.get_node_dependencies();
                    // Report malformed nodes as errors instead of panicking on indexing.
                    let second_dep = dependencies.get(1).ok_or_else(|| {
                        runtime_error!("Operation {} must have at least two arguments", op)
                    })?;
                    if private_nodes.contains(second_dep) {
                        self.ensure_dependencies_are_reshared(&node);
                    } else {
                        self.local_operation_handler(node)?;
                    }
                }
                _ => {
                    return Err(runtime_error!("Unrecognized operation {}", op));
                }
            }
        }
        let out_node = graph.get_output_node()?;
        // Move (not copy) the output node, as `ensure_dependencies_are_reshared` does.
        // Once reshared, it must not make its consumers look unreshared in `sanity_pass`.
        if self.unreshared_nodes.remove(&out_node) {
            self.nodes_to_reshare.insert(out_node);
        }
        self.sanity_pass(graph, private_nodes)
    }

    /// Removes entries made redundant by later decisions.
    ///
    /// A node may have been marked unreshared, or scheduled for resharing, because a
    /// dependency was unreshared at that time. That dependency may later have been
    /// moved to `nodes_to_reshare`. Visiting nodes in `Graph::get_nodes` order, any
    /// tracked node with no remaining unreshared dependency is dropped from both
    /// sets. Products of private values are exempt, since their result is unreshared
    /// by itself.
    fn sanity_pass(&mut self, graph: &Graph, shared_nodes: &HashSet<Node>) -> Result<()> {
        for node in graph.get_nodes() {
            let is_private_product = matches!(
                node.get_operation(),
                Operation::Multiply | Operation::Matmul | Operation::Dot | Operation::Gemm(_, _)
            ) && Self::all_dependencies_private(&node, shared_nodes);
            if is_private_product {
                continue;
            }
            let is_tracked =
                self.nodes_to_reshare.contains(&node) || self.unreshared_nodes.contains(&node);
            if !is_tracked || self.has_unreshared_dependency(&node) {
                continue;
            }
            self.nodes_to_reshare.remove(&node);
            self.unreshared_nodes.remove(&node);
        }
        Ok(())
    }
}
pub(super) fn get_nodes_to_reshare(
graph: &Graph,
shared_nodes: &HashSet<Node>,
) -> Result<HashSet<Node>> {
let mut resharing_config = ResharingConfig::new();
resharing_config.compute_graph_resharing(graph, shared_nodes)?;
Ok(resharing_config.nodes_to_reshare)
}
/// Builds the resharing sub-graph for a value held as a tuple of `PARTIES` shares.
///
/// For each share `i`:
/// * it is summed with the `i`-th element produced by `get_zero_shares`
///   (presumably a sharing of zero derived from `prf_keys`);
/// * the sum is wrapped in a `nop` node;
/// * that node is annotated with `NodeAnnotation::Send(i, i - 1 mod PARTIES)`.
///
/// Returns a tuple of the resulting nodes.
///
/// # Errors
///
/// Returns an error if extracting a share via `tuple_get` fails (presumably when
/// `input_shares` is not a tuple with at least `PARTIES` elements). Also returns an
/// error if building any of the graph nodes or annotations fails.
pub(super) fn reshare(input_shares: &Node, prf_keys: &Node) -> Result<Node> {
    let g = input_shares.get_graph();
    // Propagate `tuple_get` failures instead of panicking inside a fallible function.
    let input_shares_vec = (0..PARTIES)
        .map(|i| input_shares.tuple_get(i as u64))
        .collect::<Result<Vec<Node>>>()?;
    let zero_shares =
        get_zero_shares(g.clone(), prf_keys.clone(), input_shares_vec[0].get_type()?)?;
    let mut output_shares_vec = Vec::with_capacity(PARTIES);
    for i in 0..PARTIES {
        let masked_share = recursively_sum_shares(
            g.clone(),
            vec![input_shares_vec[i].clone(), zero_shares[i].clone()],
        )?;
        let sent_share = g.nop(masked_share)?;
        // Index of the previous party, wrapping around.
        let im1 = ((i + PARTIES - 1) % PARTIES) as u64;
        sent_share.add_annotation(NodeAnnotation::Send(i as u64, im1))?;
        output_shares_vec.push(sent_share);
    }
    g.create_tuple(output_shares_vec)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data_types::{array_type, BIT, INT32};
    use crate::graphs::create_context;
    use crate::mpc::mpc_compiler::propagate_private_annotations;

    /// Runs `propagate_private_annotations` on `g` with `private_inputs`, then feeds
    /// the resulting shared-node set to `get_nodes_to_reshare`.
    fn reshared_nodes_for(g: &Graph, private_inputs: Vec<bool>) -> Result<HashSet<Node>> {
        let shared_nodes = propagate_private_annotations(g.clone(), private_inputs)?.0;
        get_nodes_to_reshare(g, &shared_nodes)
    }

    #[test]
    fn test_resharing() -> Result<()> {
        {
            // Product followed by a local sum.
            let c = create_context()?;
            let g = c.create_graph()?;
            let i1 = g.input(array_type(vec![2, 10], BIT))?;
            let i2 = g.input(array_type(vec![10, 3], BIT))?;
            let prod = i1.matmul(i2)?;
            let out = prod.sum(vec![0])?;
            out.set_as_output()?;
            g.finalize()?;
            g.set_as_main()?;
            c.finalize()?;
            let reshared_nodes = reshared_nodes_for(&g, vec![true, true])?;
            assert_eq!(reshared_nodes.len(), 1);
            assert!(reshared_nodes.contains(&out));
            assert!(reshared_nodes_for(&g, vec![false, true])?.is_empty());
            assert!(reshared_nodes_for(&g, vec![false, false])?.is_empty());
        }
        {
            // Two chained products.
            let c = create_context()?;
            let g = c.create_graph()?;
            let i1 = g.input(array_type(vec![2, 10], BIT))?;
            let i2 = g.input(array_type(vec![10, 3], BIT))?;
            let prod = i1.matmul(i2)?;
            let i3 = g.input(array_type(vec![3, 4], BIT))?;
            let out = prod.matmul(i3)?;
            out.set_as_output()?;
            g.finalize()?;
            g.set_as_main()?;
            c.finalize()?;
            let reshared_nodes = reshared_nodes_for(&g, vec![true, true, true])?;
            assert_eq!(reshared_nodes.len(), 2);
            assert!(reshared_nodes.contains(&prod));
            assert!(reshared_nodes.contains(&out));
            let reshared_nodes = reshared_nodes_for(&g, vec![false, true, true])?;
            assert_eq!(reshared_nodes.len(), 1);
            assert!(reshared_nodes.contains(&out));
            let reshared_nodes = reshared_nodes_for(&g, vec![true, true, false])?;
            assert_eq!(reshared_nodes.len(), 1);
            assert!(reshared_nodes.contains(&prod));
        }
        {
            // Sum of two products: only the sum is reshared.
            let c = create_context()?;
            let g = c.create_graph()?;
            let i1 = g.input(array_type(vec![2, 10], BIT))?;
            let i2 = g.input(array_type(vec![10, 3], BIT))?;
            let prod12 = i1.matmul(i2)?;
            let i3 = g.input(array_type(vec![2, 4], BIT))?;
            let i4 = g.input(array_type(vec![4, 3], BIT))?;
            let prod34 = i3.matmul(i4)?;
            let out = prod12.add(prod34)?;
            out.set_as_output()?;
            g.finalize()?;
            g.set_as_main()?;
            c.finalize()?;
            let reshared_nodes = reshared_nodes_for(&g, vec![true; 4])?;
            assert_eq!(reshared_nodes.len(), 1);
            assert!(reshared_nodes.contains(&out));
        }
        {
            // A2B forces resharing of its unreshared input.
            let c = create_context()?;
            let g = c.create_graph()?;
            let i1 = g.input(array_type(vec![2, 10], INT32))?;
            let i2 = g.input(array_type(vec![10, 3], INT32))?;
            let prod12 = i1.matmul(i2)?;
            let i3 = g.input(array_type(vec![2, 3], INT32))?;
            let s123 = prod12.add(i3)?;
            let out = s123.a2b()?;
            out.set_as_output()?;
            g.finalize()?;
            g.set_as_main()?;
            c.finalize()?;
            let reshared_nodes = reshared_nodes_for(&g, vec![true; 3])?;
            assert_eq!(reshared_nodes.len(), 1);
            assert!(reshared_nodes.contains(&s123));
        }
        Ok(())
    }
}