use crate::custom_ops::{
run_instantiation_pass, ContextMappings, CustomOperation, CustomOperationBody,
};
use crate::data_types::{array_type, scalar_type, tuple_type, ScalarType, Type, BIT};
use crate::errors::Result;
use crate::graphs::util::simple_context;
use crate::graphs::SliceElement::SubArray;
use crate::graphs::{Context, Graph, Node, NodeAnnotation};
use crate::inline::inline_ops::{
inline_operations, DepthOptimizationLevel, InlineConfig, InlineMode,
};
use crate::mpc::mpc_compiler::{check_private_tuple, compile_to_mpc_graph, PARTIES};
use crate::ops::adder::BinaryAddTransposed;
use crate::ops::utils::{pull_out_bits, pull_out_bits_for_type, put_in_bits};
use crate::type_inference::a2b_type_inference;
use serde::{Deserialize, Serialize};
use super::mpc_arithmetic::{add_mpc, multiply_mpc};
use super::resharing::reshare;
/// Custom operation that performs the A2B conversion inside the MPC compiler.
///
/// Its `CustomOperationBody::instantiate` implementation accepts either a single
/// array/scalar argument (the plain `a2b` graph operation is applied) or a private
/// tuple together with a tuple of PRF keys (the conversion is built over the tuple
/// components).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub(super) struct A2BMPC {}
#[typetag::serde]
impl CustomOperationBody for A2BMPC {
fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
if argument_types.len() == 1 {
if let Type::Array(_, _) | Type::Scalar(_) = argument_types[0].clone() {
let g = context.create_graph()?;
let input = g.input(argument_types[0].clone())?;
g.a2b(input)?.set_as_output()?;
g.finalize()?;
return Ok(g);
} else {
panic!("Inconsistency with type checker");
}
}
if argument_types.len() != 2 {
return Err(runtime_error!("A2BMPC should have either 1 or 2 inputs."));
}
if let (Type::Tuple(v0), Type::Tuple(v1)) =
(argument_types[0].clone(), argument_types[1].clone())
{
check_private_tuple(v0)?;
check_private_tuple(v1)?;
} else {
return Err(runtime_error!(
"A2BMPC should have a private tuple and a tuple of keys as input"
));
}
let t = argument_types[0].clone();
let input_t = if let Type::Tuple(t_vec) = t.clone() {
(*t_vec[0]).clone()
} else {
return Err(runtime_error!("Shouldn't be here"));
};
let bits_t = pull_out_bits_for_type(a2b_type_inference(input_t)?)?;
let shift_mpc_g = get_left_shift_graph(context.clone(), bits_t.clone())?;
let adder_mpc_g = get_binary_adder_graph(context.clone(), bits_t)?;
let g = context.create_graph()?;
let input = g.input(t)?;
let prf_type = argument_types[1].clone();
let prf_keys = g.input(prf_type)?;
let mut bit_shares = vec![];
let mut input_bits = vec![];
for i in 0..PARTIES {
input_bits.push(pull_out_bits(input.tuple_get(i as u64)?.a2b()?)?);
}
let zero_bits = g.zeros(input_bits[0].get_type()?)?;
for share_id in 0..PARTIES {
let mut bit_share = vec![];
for (party_id, share) in input_bits.iter().enumerate().take(PARTIES) {
let bit_share_arith = if share_id == party_id {
share.clone()
} else {
zero_bits.clone()
};
bit_share.push(bit_share_arith);
}
let bit_share_tuple = g.create_tuple(bit_share)?;
bit_shares.push(bit_share_tuple);
}
let transposed_output = add_3_bitstrings(
g.clone(),
adder_mpc_g,
shift_mpc_g,
bit_shares[0].clone(),
bit_shares[1].clone(),
bit_shares[2].clone(),
prf_keys,
)?;
let mut output = vec![];
for i in 0..PARTIES {
output.push(put_in_bits(transposed_output.tuple_get(i as u64)?)?);
}
let o = g.create_tuple(output)?;
o.set_as_output()?;
g.finalize()?;
Ok(g)
}
fn get_name(&self) -> String {
"A2BMPC".to_owned()
}
}
/// Custom operation that performs the B2A conversion inside the MPC compiler.
///
/// Its `CustomOperationBody::instantiate` implementation accepts either a single
/// array/scalar argument (the plain `b2a` graph operation is applied) or three tuples:
/// a private tuple and two tuples of PRF keys.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub(super) struct B2AMPC {
/// Scalar type handed to `b2a`; it is also included in the operation name.
pub st: ScalarType,
}
#[typetag::serde]
impl CustomOperationBody for B2AMPC {
    /// Builds the graph that performs the B2A conversion to the scalar type `self.st`.
    ///
    /// Two call shapes are accepted:
    /// * one argument of array or scalar type: the graph applies `b2a` to it;
    /// * three arguments: a private tuple with the input, a private tuple of PRF keys
    ///   that is passed to `add_3_bitstrings`, and a tuple of two equally typed private
    ///   tuples of PRF keys used to generate random values.
    ///
    /// In the three-argument case two random tuples are generated with `prf`, the input
    /// and both random tuples are combined by `add_3_bitstrings`, the components of that
    /// result are added together, and the output tuple is
    /// `(0 - b2a(r0), b2a(combined), 0 - b2a(r1))`, where `r0` and `r1` are the sums of
    /// the components of the two random tuples.
    ///
    /// # Errors
    ///
    /// Returns a runtime error when the number of arguments is neither 1 nor 3, when a
    /// single argument is neither an array nor a scalar, or when the three arguments do
    /// not have the tuple structure described above. Errors raised by the underlying
    /// graph operations are propagated unchanged.
    fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
        if argument_types.len() == 1 {
            // Reported as an error instead of a panic so that callers get a `Result`
            // for this invalid input exactly like for every other one.
            if !matches!(argument_types[0], Type::Array(_, _) | Type::Scalar(_)) {
                return Err(runtime_error!("Inconsistency with type checker"));
            }
            let g = context.create_graph()?;
            let input = g.input(argument_types[0].clone())?;
            g.b2a(input, self.st)?.set_as_output()?;
            g.finalize()?;
            return Ok(g);
        }
        if argument_types.len() != 3 {
            return Err(runtime_error!("B2AMPC should have either 1 or 3 inputs."));
        }
        // Validate all three tuples first and only then derive the bit type of the
        // first input component, so no separate "impossible" branch (previously a
        // panic) is needed afterwards.
        let input_t = if let (Type::Tuple(v0), Type::Tuple(v1), Type::Tuple(v2)) = (
            argument_types[0].clone(),
            argument_types[1].clone(),
            argument_types[2].clone(),
        ) {
            check_private_tuple(v0.clone())?;
            check_private_tuple(v1)?;
            if v2.len() != 2 {
                return Err(runtime_error!(
                    "There should be {} PRF key triples, but {} provided",
                    2,
                    v2.len()
                ));
            }
            if *v2[0] != *v2[1] {
                return Err(runtime_error!("PRF keys should be of the same type"));
            }
            if let Type::Tuple(sub_v) = (*v2[0]).clone() {
                check_private_tuple(sub_v)?;
            } else {
                return Err(runtime_error!(
                    "Special PRF keys for B2A should be a tuple of tuples"
                ));
            }
            pull_out_bits_for_type((*v0[0]).clone())?
        } else {
            return Err(runtime_error!(
                "B2AMPC should have a private tuple and a tuple of keys as input"
            ));
        };
        // The helper graphs are created before the main graph; keep this order.
        let shift_mpc_g = get_left_shift_graph(context.clone(), input_t.clone())?;
        let adder_mpc_g = get_binary_adder_graph(context.clone(), input_t.clone())?;
        let g = context.create_graph()?;
        let input = g.input(argument_types[0].clone())?;
        let mut transposed_input = Vec::with_capacity(PARTIES);
        for i in 0..PARTIES {
            transposed_input.push(pull_out_bits(input.tuple_get(i as u64)?)?);
        }
        let input = g.create_tuple(transposed_input)?;
        // The key inputs are deliberately declared after the nodes above so that the
        // node order of the graph stays the same.
        let prf_for_mul_keys = g.input(argument_types[1].clone())?;
        let prf_for_random_keys = g.input(argument_types[2].clone())?;
        let mut bit_shares = vec![];
        let mut random_shares = vec![];
        for share_id in 0..(PARTIES - 1) as u64 {
            let prf_key_triple = prf_for_random_keys.tuple_get(share_id)?;
            // One PRF output per key of the triple.
            let mut random_share = vec![];
            for i in 0..PARTIES as u64 {
                let prf_key = prf_key_triple.tuple_get(i)?;
                random_share.push(g.prf(prf_key, 0, input_t.clone())?);
            }
            random_shares.push(g.create_tuple(random_share.clone())?);
            // Sum of the three PRF outputs of this triple.
            let bit_share = random_share[0]
                .add(random_share[1].clone())?
                .add(random_share[2].clone())?;
            bit_shares.push(bit_share);
        }
        let last_share_shared = add_3_bitstrings(
            g.clone(),
            adder_mpc_g,
            shift_mpc_g,
            input,
            random_shares[0].clone(),
            random_shares[1].clone(),
            prf_for_mul_keys,
        )?;
        // Components 0 and 2 are copied through `nop` nodes carrying `Send`
        // annotations and then added to component 1.
        let x1_share0 = last_share_shared.tuple_get(0)?.nop()?;
        x1_share0.add_annotation(NodeAnnotation::Send(0, 1))?;
        let x1_share2 = last_share_shared.tuple_get(2)?.nop()?;
        x1_share2.add_annotation(NodeAnnotation::Send(1, 0))?;
        let x1_revealed = last_share_shared
            .tuple_get(1)?
            .add(x1_share0)?
            .add(x1_share2)?;
        // The combined value becomes the middle element between the two random sums.
        bit_shares.insert(1, x1_revealed);
        let zero = g.zeros(scalar_type(self.st))?;
        let mut arith_shares = vec![];
        for share in bit_shares {
            arith_shares.push(put_in_bits(share)?.b2a(self.st)?);
        }
        // The two random sums enter the output negated.
        arith_shares[0] = zero.subtract(arith_shares[0].clone())?;
        arith_shares[2] = zero.subtract(arith_shares[2].clone())?;
        g.create_tuple(arith_shares)?.set_as_output()?;
        g.finalize()?;
        Ok(g)
    }

    /// Returns the operation name, which embeds the target scalar type.
    fn get_name(&self) -> String {
        format!("B2AMPC({})", self.st)
    }
}
/// Creates a finalized graph in `context` that shifts every component of a tuple of
/// `PARTIES` arrays of type `bits_t` by one position along axis 0.
///
/// For each component the last entry along axis 0 is dropped and a zero bit array
/// whose first dimension is 1 is prepended, so the component keeps its shape.
fn get_left_shift_graph(context: Context, bits_t: Type) -> Result<Graph> {
    let graph = context.create_graph()?;
    let shares = graph.input(tuple_type(vec![bits_t.clone(); PARTIES]))?;
    // Same shape as the input components except for a single entry along axis 0.
    let mut zero_row_shape = bits_t.get_shape();
    zero_row_shape[0] = 1;
    let zero_row = graph.zeros(array_type(zero_row_shape, BIT))?;
    let shifted_shares = (0..PARTIES)
        .map(|party| {
            let share = shares.tuple_get(party as u64)?;
            // Everything but the last entry along axis 0.
            let kept_rows = share.get_slice(vec![SubArray(None, Some(-1), None)])?;
            graph.concatenate(vec![zero_row.clone(), kept_rows], 0)
        })
        .collect::<Result<Vec<_>>>()?;
    graph.create_tuple(shifted_shares)?.set_as_output()?;
    graph.finalize()?;
    Ok(graph)
}
/// Builds the MPC version of a binary adder for two inputs of type `bits_t`.
///
/// A standalone context containing a single `BinaryAddTransposed` custom operation
/// (without the overflow bit) is instantiated, inlined in depth-optimized mode and
/// then compiled with `compile_to_mpc_graph` into `context`, with both inputs flagged
/// `true`.
fn get_binary_adder_graph(context: Context, bits_t: Type) -> Result<Graph> {
    let inline_config = InlineConfig {
        default_mode: InlineMode::DepthOptimized(DepthOptimizationLevel::Default),
        ..Default::default()
    };
    let plain_context = simple_context(|g| {
        let lhs = g.input(bits_t.clone())?;
        let rhs = g.input(bits_t)?;
        let adder = CustomOperation::new(BinaryAddTransposed {
            overflow_bit: false,
        });
        g.custom_op(adder, vec![lhs, rhs])
    })?;
    let instantiated_context = run_instantiation_pass(plain_context)?.get_context();
    let inlined_context = inline_operations(instantiated_context, inline_config)?;
    let plain_adder = inlined_context.get_main_graph()?;
    let mut mappings = ContextMappings::default();
    let mpc_adder = compile_to_mpc_graph(plain_adder, vec![true, true], context, &mut mappings)?;
    Ok(mpc_adder)
}
/// Adds the three bitstrings `a`, `b` and `c` inside the graph `g`.
///
/// The three operands are first reduced to two values:
/// * the sum `a + b + c` computed with `add_mpc`, and
/// * the carry `a * b + (a + b) * c` computed with `multiply_mpc` and `add_mpc`.
///
/// The carry is passed through `shift_g`, reshared with `prf_for_mul_keys`, and
/// finally `adder_g` is called on the keys, the sum and the shifted carry.
/// The operations are issued in a fixed order because every call appends nodes to `g`.
fn add_3_bitstrings(
    g: Graph,
    adder_g: Graph,
    shift_g: Graph,
    a: Node,
    b: Node,
    c: Node,
    prf_for_mul_keys: Node,
) -> Result<Node> {
    let a_plus_b = add_mpc(a.clone(), b.clone())?;
    let a_times_b = multiply_mpc(a, b, prf_for_mul_keys.clone(), false)?;
    let a_plus_b_times_c =
        multiply_mpc(a_plus_b.clone(), c.clone(), prf_for_mul_keys.clone(), false)?;
    let sum_bits = add_mpc(a_plus_b, c)?;
    let carry_bits = add_mpc(a_times_b, a_plus_b_times_c)?;
    let shifted = g.call(shift_g, vec![carry_bits])?;
    let shifted_carry = reshare(&shifted, &prf_for_mul_keys)?;
    g.call(adder_g, vec![prf_for_mul_keys, sum_bits, shifted_carry])
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::subtract_vectors_u128;
    use crate::data_types::{array_type, ScalarType, INT128, UINT128};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::Operation;
    use crate::inline::inline_ops::{InlineConfig, InlineMode};
    use crate::mpc::mpc_compiler::{prepare_for_mpc_evaluation, IOStatus};
    use crate::type_inference::a2b_type_inference;

    /// Wraps `op` into a one-node graph and prepares it for MPC evaluation.
    ///
    /// For `A2B` the graph input has type `t`; for any other operation it has the type
    /// returned by `a2b_type_inference(t)`.
    fn prepare_context(
        op: Operation,
        party_id: IOStatus,
        output_parties: Vec<IOStatus>,
        t: Type,
        inline_config: InlineConfig,
    ) -> Result<Context> {
        let is_a2b = op == Operation::A2B;
        let input_t = if is_a2b { t } else { a2b_type_inference(t)? };
        let plain_context = simple_context(|g| {
            let input_node = g.input(input_t)?;
            g.add_node(vec![input_node], vec![], op)
        })?;
        prepare_for_mpc_evaluation(
            plain_context,
            vec![vec![party_id]],
            vec![output_parties],
            inline_config,
        )
    }

    /// Turns the raw `input` numbers into the values fed to the compiled graph.
    ///
    /// Public inputs and inputs owned by a single party are passed as one plain value.
    /// Otherwise a vector of `PARTIES` values is produced: the values with indices
    /// `1..PARTIES` are filled with their own index, and the first one is `input ^ 3`
    /// for `B2A` or `input - 3` for any other operation.
    fn prepare_input(
        op: Operation,
        input: Vec<u128>,
        input_status: IOStatus,
        t: Type,
    ) -> Result<Vec<Value>> {
        let is_plain =
            input_status == IOStatus::Public || matches!(input_status, IOStatus::Party(_));
        let is_b2a = matches!(op, Operation::B2A(_));
        let shares = match t {
            Type::Scalar(st) => {
                if is_plain {
                    return Ok(vec![Value::from_scalar(input[0], st)?]);
                }
                let first_share = if is_b2a {
                    Value::from_scalar(input[0] ^ 3, st)?
                } else {
                    let shifted = subtract_vectors_u128(&input, &vec![3], st.get_modulus())?;
                    Value::from_scalar(shifted[0], st)?
                };
                let mut shares = vec![first_share];
                for i in 1..PARTIES {
                    shares.push(Value::from_scalar(i, st)?);
                }
                shares
            }
            Type::Array(_, st) => {
                if is_plain {
                    return Ok(vec![Value::from_flattened_array(&input, st)?]);
                }
                let first_share = if is_b2a {
                    let xored: Vec<u128> = input.iter().map(|x| *x ^ 3).collect();
                    Value::from_flattened_array(&xored, st)?
                } else {
                    let threes = vec![3; input.len()];
                    let shifted = subtract_vectors_u128(&input, &threes, st.get_modulus())?;
                    Value::from_flattened_array(&shifted, st)?
                };
                let mut shares = vec![first_share];
                for i in 1..PARTIES {
                    let constant_share = vec![i; input.len()];
                    shares.push(Value::from_flattened_array(&constant_share, st)?);
                }
                shares
            }
            _ => {
                panic!("Shouldn't be here");
            }
        };
        Ok(vec![Value::from_vector(shares)])
    }

    /// Reads `value` as a flat list of `u128` numbers according to the type `t`.
    fn decode(value: &Value, t: &Type, st: ScalarType) -> Result<Vec<u128>> {
        match t {
            Type::Scalar(_) => Ok(vec![value.to_u128(st)?]),
            Type::Array(_, _) => value.to_flattened_array_u128(t.clone()),
            _ => {
                panic!("Shouldn't be here");
            }
        }
    }

    /// Evaluates `mpc_graph` on `inputs` and asserts that the result equals `expected`.
    ///
    /// With an empty `output_parties` list the output is a vector of values that is
    /// combined first (XOR for `A2B`, wrapping addition otherwise); otherwise the output
    /// is checked against `t` and decoded directly. Both sides are reduced modulo the
    /// scalar type's modulus when it has one.
    fn check_output(
        op: Operation,
        mpc_graph: Graph,
        inputs: Vec<Value>,
        expected: Vec<u128>,
        output_parties: Vec<IOStatus>,
        t: Type,
    ) -> Result<()> {
        let output = random_evaluate(mpc_graph, inputs)?;
        let st = t.get_scalar_type();
        let is_a2b = op == Operation::A2B;
        let out = if output_parties.is_empty() {
            output.access_vector(|share_values| {
                let mut combined = vec![0; expected.len()];
                for share_value in share_values {
                    let numbers = decode(share_value, &t, st)?;
                    for (i, acc) in combined.iter_mut().enumerate() {
                        if is_a2b {
                            *acc ^= numbers[i];
                        } else {
                            *acc = acc.wrapping_add(numbers[i]);
                        }
                    }
                }
                Ok(combined)
            })?
        } else {
            assert!(output.check_type(t.clone())?);
            decode(&output, &t, st)?
        };
        let (expected, out) = match st.get_modulus() {
            Some(m) => (
                expected.iter().map(|x| x % m).collect(),
                out.iter().map(|x| x % m).collect(),
            ),
            None => (expected, out),
        };
        assert_eq!(out, expected);
        Ok(())
    }

    /// Runs `op` over a fixed set of scalar and array inputs of scalar type `st`, for
    /// several combinations of input status and output parties, and checks that the
    /// output equals the input numbers.
    fn conversion_test(op: Operation, st: ScalarType) -> Result<()> {
        if let Operation::B2A(st_b2a) = op.clone() {
            if st_b2a != st {
                panic!("The scalar type of B2A should be equal to the input scalar type");
            }
        }
        let run_case = |input: Vec<u128>,
                        input_status: IOStatus,
                        output_parties: Vec<IOStatus>,
                        inline_config: InlineConfig,
                        t: Type|
         -> Result<()> {
            let mpc_graph = prepare_context(
                op.clone(),
                input_status.clone(),
                output_parties.clone(),
                t.clone(),
                inline_config,
            )?
            .get_main_graph()?;
            let inputs = prepare_input(op.clone(), input.clone(), input_status, t.clone())?;
            check_output(op.clone(), mpc_graph, inputs, input, output_parties, t)
        };
        let inline_config_simple = InlineConfig {
            default_mode: InlineMode::Simple,
            ..Default::default()
        };
        let run_all_scenarios = |inputs: Vec<u128>, t: Type| -> Result<()> {
            // (input status, output parties), executed in this order.
            let scenarios = vec![
                (
                    IOStatus::Party(2),
                    vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                ),
                (
                    IOStatus::Party(2),
                    vec![IOStatus::Party(0), IOStatus::Party(1)],
                ),
                (IOStatus::Party(2), vec![IOStatus::Party(0)]),
                (IOStatus::Party(2), vec![]),
                (
                    IOStatus::Public,
                    vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                ),
                (IOStatus::Public, vec![]),
            ];
            for (input_status, output_parties) in scenarios {
                run_case(
                    inputs.clone(),
                    input_status,
                    output_parties,
                    inline_config_simple.clone(),
                    t.clone(),
                )?;
            }
            Ok(())
        };
        let scalar_inputs = vec![
            85,
            (-1233425456713636117134i128) as u128,
            1234531312235111221677134_u128,
        ];
        for value in scalar_inputs {
            run_all_scenarios(vec![value], scalar_type(st))?;
        }
        let array_inputs = vec![
            vec![2, 85],
            vec![0, 255],
            vec![1133353228592345678, (-12345677123142726513i128) as u128],
            vec![i128::MIN as u128, i128::MAX as u128, 0, u128::MAX],
        ];
        for values in array_inputs {
            let length = values.len() as u64;
            run_all_scenarios(values, array_type(vec![length], st))?;
        }
        Ok(())
    }

    #[test]
    fn test_a2b_mpc() {
        for st in vec![UINT128, INT128] {
            conversion_test(Operation::A2B, st).unwrap();
        }
    }

    #[test]
    fn test_b2a_mpc() {
        for st in vec![UINT128, INT128] {
            conversion_test(Operation::B2A(st), st).unwrap();
        }
    }
}