use crate::custom_ops::CustomOperationBody;
use crate::data_types::{array_type, scalar_size_in_bits, scalar_type, Type, BIT};
use crate::data_values::Value;
use crate::errors::Result;
use crate::graphs::{Context, Graph, Node, NodeAnnotation};
use crate::mpc::mpc_compiler::{check_private_tuple, KEY_LENGTH, PARTIES};
use crate::ops::utils::constant_scalar;
use serde::{Deserialize, Serialize};
/// Custom operation that truncates a public or secret-shared value by `scale`
/// (the divisor handed to `Node::truncate`).
///
/// Only signed scalar types are accepted when the operation is instantiated.
#[derive(Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(super) struct TruncateMPC {
    /// Truncation scale; with a scale of 1 the input is passed through unchanged.
    pub scale: u128,
}
#[typetag::serde]
impl CustomOperationBody for TruncateMPC {
    /// Instantiates the graph computing the truncation of a value by `self.scale`.
    ///
    /// Accepted argument shapes:
    /// * one argument: a public scalar or array of a signed type, truncated directly;
    /// * two arguments: a private value given as a tuple of shares, followed by a
    ///   tuple of PRF keys; the result is again a tuple of three shares.
    ///
    /// When `self.scale == 1` the (possibly shared) input is returned unchanged.
    ///
    /// # Errors
    /// Returns an error for unsigned scalar types, for an argument count other
    /// than 1 or 2, when the two arguments are not both tuples, or when
    /// `check_private_tuple` rejects either of them.
    ///
    /// # Panics
    /// Panics if a single argument is neither a scalar nor an array, which the
    /// type checker is expected to rule out.
    fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
        if argument_types.len() == 1 {
            let public_t = argument_types[0].clone();
            let st = match public_t.clone() {
                Type::Array(_, st) | Type::Scalar(st) => st,
                _ => panic!("Inconsistency with type checker"),
            };
            if !st.is_signed() {
                return Err(runtime_error!(
                    "Only signed types are supported by TruncateMPC"
                ));
            }
            let g = context.create_graph()?;
            let value = g.input(public_t)?;
            let result = match self.scale {
                1 => value,
                scale => value.truncate(scale)?,
            };
            result.set_as_output()?;
            g.finalize()?;
            return Ok(g);
        }
        if argument_types.len() != 2 {
            return Err(runtime_error!(
                "TruncateMPC should have either 1 or 2 inputs."
            ));
        }
        let shares_t = argument_types[0].clone();
        let keys_t = argument_types[1].clone();
        // Both arguments must be tuples; the shares are validated before the keys.
        let share_types = match (shares_t.clone(), keys_t.clone()) {
            (Type::Tuple(share_types), Type::Tuple(key_types)) => {
                check_private_tuple(share_types.clone())?;
                check_private_tuple(key_types)?;
                share_types
            }
            _ => {
                return Err(runtime_error!(
                    "TruncateMPC should have a private tuple and a tuple of keys as input"
                ));
            }
        };
        let share_t = (*share_types[0]).clone();
        if !share_t.get_scalar_type().is_signed() {
            return Err(runtime_error!(
                "Only signed types are supported by TruncateMPC"
            ));
        }
        let g = context.create_graph()?;
        let shares = g.input(shares_t)?;
        let prf_keys = g.input(keys_t)?;
        if self.scale == 1 {
            shares.set_as_output()?;
            g.finalize()?;
            return Ok(g);
        }
        // Mask `r` derived from the last key of the key tuple.
        let mask = g.prf(prf_keys.tuple_get(PARTIES as u64 - 1)?, 0, share_t)?;
        // Share 0 is truncated on its own.
        let share0 = shares.tuple_get(0)?.truncate(self.scale)?;
        // Shares 1 and 2 are merged, truncated and masked with `r`; the masked
        // value is annotated as sent from party 1 to party 0.
        let share1 = shares
            .tuple_get(1)?
            .add(shares.tuple_get(2)?)?
            .truncate(self.scale)?
            .subtract(mask.clone())?;
        let share1_sent = share1.nop()?;
        share1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
        // The mask itself becomes the third output share.
        g.create_tuple(vec![share0, share1_sent, mask])?
            .set_as_output()?;
        g.finalize()?;
        Ok(g)
    }

    /// Name of the operation, including its scale.
    fn get_name(&self) -> String {
        format!("TruncateMPC({scale})", scale = self.scale)
    }
}
/// Custom operation that truncates a public or secret-shared value by the
/// power of two `2^k`.
#[derive(Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(super) struct TruncateMPC2K {
    /// Exponent of the power-of-two truncation scale `2^k`.
    pub k: u64,
}
#[typetag::serde]
impl CustomOperationBody for TruncateMPC2K {
    /// Instantiates a graph that truncates its input by `2^k`.
    ///
    /// Accepted argument shapes:
    /// * one argument: a public scalar or array, truncated in the clear;
    /// * three arguments whose first one is a public scalar or array: the value is
    ///   truncated in the clear and the two key arguments are declared as unused
    ///   graph inputs, so that the graph has one input per argument;
    /// * three arguments `(shares, prf_mul_keys, key_2)`: a private value as a
    ///   tuple of shares, a tuple of PRF keys and a single PRF key; the output is a
    ///   tuple of three shares of the truncated value.
    ///
    /// # Errors
    /// Fails on a wrong argument count or malformed PRF keys. It also fails when `k`
    /// is too large: at most `bits - 2` for signed and `bits - 1` for unsigned
    /// private inputs, and at most 127 for public inputs. Larger values would
    /// overflow the shift/subtraction arithmetic of the protocol.
    ///
    /// # Panics
    /// Panics if the (public or shared) value is neither a scalar nor an array,
    /// which the type checker is expected to rule out.
    fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
        if argument_types.len() == 1 {
            if let Type::Array(_, _) | Type::Scalar(_) = argument_types[0].clone() {
                let scale = self.power_of_two_scale()?;
                let g = context.create_graph()?;
                let input = g.input(argument_types[0].clone())?;
                let o = if self.k == 0 {
                    input
                } else {
                    input.truncate(scale)?
                };
                o.set_as_output()?;
                g.finalize()?;
                return Ok(g);
            } else {
                panic!("Inconsistency with type checker");
            }
        }
        if argument_types.len() != 3 {
            return Err(runtime_error!("TruncateMPC2K should have 3 inputs."));
        }
        let share_types = match argument_types[0].clone() {
            Type::Tuple(v0) => {
                check_private_tuple(v0.clone())?;
                v0
            }
            _ => {
                // Public value: truncate in the clear.
                if !argument_types[0].is_array() && !argument_types[0].is_scalar() {
                    panic!("Inconsistency with type checker");
                }
                let scale = self.power_of_two_scale()?;
                let g = context.create_graph()?;
                let input = g.input(argument_types[0].clone())?;
                // Declare the key arguments as well (unused), mirroring the
                // private branch, so every argument has a matching graph input.
                let _prf_mul_keys = g.input(argument_types[1].clone())?;
                let _key_2 = g.input(argument_types[2].clone())?;
                let o = input.truncate(scale)?;
                o.set_as_output()?;
                g.finalize()?;
                return Ok(g);
            }
        };
        let key_type = array_type(vec![KEY_LENGTH], BIT);
        if let Type::Tuple(v0) = argument_types[1].clone() {
            check_private_tuple(v0.clone())?;
            for t in v0 {
                if *t != key_type {
                    return Err(runtime_error!("PRF key is of a wrong type"));
                }
            }
        } else {
            return Err(runtime_error!("PRF key is of a wrong type"));
        }
        if argument_types[2] != key_type {
            return Err(runtime_error!("PRF key is of a wrong type"));
        }
        let t = argument_types[0].clone();
        let input_t = (*share_types[0]).clone();
        if !input_t.is_array() && !input_t.is_scalar() {
            panic!("Inconsistency with type checker");
        }
        let st = input_t.get_scalar_type();
        let st_size = scalar_size_in_bits(st);
        // The protocol computes 2^(st_size - 1 - k) and, for signed types,
        // 2^(st_size - 2 - k); reject any k for which these would underflow.
        let max_k = if st.is_signed() {
            st_size.saturating_sub(2)
        } else {
            st_size.saturating_sub(1)
        };
        if self.k > max_k {
            return Err(runtime_error!(
                "TruncateMPC2K: k = {} is too large for scalar type {:?} (at most {} is supported)",
                self.k,
                st,
                max_k
            ));
        }
        let g = context.create_graph()?;
        let input_node = g.input(t)?;
        let prf_mul_type = argument_types[1].clone();
        let prf_mul_keys = g.input(prf_mul_type)?;
        let prf_truncate_type = argument_types[2].clone();
        let key_2 = g.input(prf_truncate_type)?;
        if self.k == 0 {
            input_node.set_as_output()?;
            g.finalize()?;
            return Ok(g);
        }
        let key_02 = prf_mul_keys.tuple_get(0)?;
        let key_12 = prf_mul_keys.tuple_get(2)?;
        // For signed inputs, share 0 is offset by 2^(st_size - 2); the offset
        // (scaled down by 2^k) is subtracted from the output at the end.
        let x0 = {
            let share = input_node.tuple_get(0)?;
            if st.is_signed() {
                let mod_fraction = constant_scalar(&g, 1u128 << (st_size - 2), st)?;
                share.add(mod_fraction)?
            } else {
                share
            }
        };
        let x1 = input_node.tuple_get(1)?;
        let x2 = input_node.tuple_get(2)?;
        // Random mask derived from the standalone key `key_2`.
        let r = g.prf(key_2, 0, input_t.clone())?;
        let unsigned_st = st.get_unsigned_counterpart();
        // Most significant bit of `r` (0 or 1).
        let r_msb = {
            let mask = constant_scalar(&g, 1u128 << (st_size - 1), unsigned_st)?.a2b()?;
            let r_msb_scaled = r.a2b()?.multiply(mask)?.b2a(unsigned_st)?;
            r_msb_scaled.truncate(1 << (st_size - 1))?.a2b()?.b2a(st)?
        };
        // Bits k..st_size-2 of `r`, shifted down by k.
        let r_truncated = {
            let mask = constant_scalar(
                &g,
                (1u128 << (st_size - 1)) - (1u128 << self.k),
                unsigned_st,
            )?
            .a2b()?;
            r.a2b()?.multiply(mask)?.b2a(st)?.truncate(1 << self.k)?
        };
        // Splits `val` into a share derived from `key_02` and the remainder; the
        // remainder is annotated as sent from party 2 to party 1.
        let share_for_two = |val: Node| -> Result<(Node, Node)> {
            let share0 = g.prf(key_02.clone(), 0, val.get_type()?)?;
            let share1 = val.subtract(share0.clone())?;
            let share1_sent = share1.nop()?;
            share1_sent.add_annotation(NodeAnnotation::Send(2, 1))?;
            Ok((share0, share1_sent))
        };
        let (r0, r1) = share_for_two(r)?;
        let (r_msb0, r_msb1) = share_for_two(r_msb)?;
        let (r_truncated0, r_truncated1) = share_for_two(r_truncated)?;
        let y0 = g.prf(key_02, 0, input_t.clone())?;
        let y2 = g.prf(key_12, 0, input_t)?;
        let z0 = x0.add(x1)?;
        let z1 = x2;
        // Reveal c = (x + offset) + r: each masked half is annotated as sent to the
        // other of parties 0 and 1, and both halves are added.
        let c_share0 = z0.add(r0)?;
        let c_share1 = z1.add(r1)?;
        let c_share0_sent = c_share0.nop()?;
        c_share0_sent.add_annotation(NodeAnnotation::Send(0, 1))?;
        let c_share1_sent = c_share1.nop()?;
        c_share1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
        let c = c_share0_sent.add(c_share1_sent)?;
        // c (read as unsigned) shifted down by k.
        let c_truncated = c
            .a2b()?
            .b2a(unsigned_st)?
            .truncate(1 << self.k)?
            .a2b()?
            .b2a(st)?;
        // Keep only the low st_size - 1 - k bits of the shifted c.
        let c_truncated_mod = {
            let mask = g
                .constant(
                    scalar_type(st),
                    Value::from_scalar((1u128 << (st_size - 1 - self.k)) - 1, st)?,
                )?
                .a2b()?;
            c_truncated.a2b()?.multiply(mask)?.b2a(st)?
        };
        // Most significant bit of c (0 or 1).
        let c_msb = c
            .a2b()?
            .b2a(unsigned_st)?
            .truncate(1 << (st_size - 1))?
            .a2b()?
            .b2a(st)?;
        // Shares of b = r_msb XOR c_msb, computed as r_msb + c_msb - 2 * r_msb * c_msb
        // (the public c_msb is added to share 0 only).
        let two = constant_scalar(&g, 2, st)?;
        let b0 = r_msb0
            .subtract(r_msb0.multiply(c_msb.clone())?.multiply(two.clone())?)?
            .add(c_msb.clone())?;
        let b1 = r_msb1.subtract(r_msb1.multiply(c_msb)?.multiply(two)?)?;
        // y' = b * 2^(st_size - 1 - k) - r_truncated + c_truncated_mod, in shares.
        let power2 = constant_scalar(&g, 1u128 << (st_size - 1 - self.k), st)?;
        let y_prime0 = b0
            .multiply(power2.clone())?
            .subtract(r_truncated0)?
            .add(c_truncated_mod)?;
        let y_prime1 = b1.multiply(power2)?.subtract(r_truncated1)?;
        // Re-share y' as (y0, y1, y2) with y0 and y2 derived from PRF keys.
        let y_tilde0 = y_prime0.subtract(y0.clone())?;
        let y_tilde0_sent = y_tilde0.nop()?;
        y_tilde0_sent.add_annotation(NodeAnnotation::Send(0, 1))?;
        let y_tilde1 = y_prime1.subtract(y2.clone())?;
        let y_tilde1_sent = y_tilde1.nop()?;
        y_tilde1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
        let y1 = {
            let sum01 = y_tilde0_sent.add(y_tilde1_sent)?;
            if st.is_signed() {
                // Undo the signed offset added to x0, now scaled down by 2^k.
                let mod_fraction = constant_scalar(&g, 1u128 << (st_size - 2 - self.k), st)?;
                sum01.subtract(mod_fraction)?
            } else {
                sum01
            }
        };
        g.create_tuple(vec![y0, y1, y2])?.set_as_output()?;
        g.finalize()?;
        Ok(g)
    }

    /// Name of the operation, including the exponent `k`.
    fn get_name(&self) -> String {
        format!("TruncateMPC2K({})", self.k)
    }
}

impl TruncateMPC2K {
    /// Returns the truncation scale `2^k` as a `u128`, or an error when `k >= 128`
    /// (shifting a `u128` by 128 or more bits overflows).
    fn power_of_two_scale(&self) -> Result<u128> {
        if self.k >= 128 {
            return Err(runtime_error!(
                "TruncateMPC2K: k = {} is too large (at most 127 is supported)",
                self.k
            ));
        }
        Ok(1u128 << self.k)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::{add_u128, subtract_vectors_u128};
    use crate::data_types::{array_type, scalar_type, ScalarType, INT128, UINT128};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::inline::inline_ops::{InlineConfig, InlineMode};
    use crate::mpc::mpc_compiler::{prepare_for_mpc_evaluation, IOStatus, PARTIES};

    /// Builds a context truncating one input of type `t` by `scale` and compiles it
    /// for MPC with the given input/output parties.
    fn prepare_context(
        t: Type,
        party_id: IOStatus,
        output_parties: Vec<IOStatus>,
        scale: u128,
        inline_config: InlineConfig,
    ) -> Result<Context> {
        let c = simple_context(|g| {
            let i = g.input(t)?;
            g.truncate(i, scale)
        })?;
        prepare_for_mpc_evaluation(c, vec![vec![party_id]], vec![output_parties], inline_config)
    }

    /// Encodes `input` as the MPC graph input for `input_status`: a plain value for
    /// public/party inputs, otherwise a tuple of shares `(input - 3, 1, 2)`.
    fn prepare_input(input: Vec<u128>, input_status: IOStatus, t: Type) -> Result<Vec<Value>> {
        let mpc_input = match t {
            Type::Scalar(st) => {
                if input_status == IOStatus::Public || matches!(input_status, IOStatus::Party(_)) {
                    return Ok(vec![Value::from_scalar(input[0], st)?]);
                }
                let mut shares_vec = vec![];
                shares_vec.push(Value::from_scalar(
                    subtract_vectors_u128(&input, &[3], st.get_modulus())?[0],
                    st,
                )?);
                for i in 1..PARTIES as u64 {
                    shares_vec.push(Value::from_scalar(i, st)?);
                }
                shares_vec
            }
            Type::Array(_, st) => {
                if input_status == IOStatus::Public || matches!(input_status, IOStatus::Party(_)) {
                    return Ok(vec![Value::from_flattened_array(&input, st)?]);
                }
                let mut shares_vec = vec![];
                let threes = vec![3; input.len()];
                let first_share = subtract_vectors_u128(&input, &threes, st.get_modulus())?;
                shares_vec.push(Value::from_flattened_array(&first_share, st)?);
                for i in 1..PARTIES {
                    let share = vec![i; input.len()];
                    shares_vec.push(Value::from_flattened_array(&share, st)?);
                }
                shares_vec
            }
            _ => {
                panic!("Shouldn't be here");
            }
        };
        Ok(vec![Value::from_vector(mpc_input)])
    }

    /// Absolute difference of two unsigned values, without risking underflow.
    fn abs_diff_u128(a: u128, b: u128) -> u128 {
        if a >= b {
            a - b
        } else {
            b - a
        }
    }

    /// Checks that every output value is within 1 of the expected one (`equal`) or
    /// further than 1 away from it (`!equal`).
    ///
    /// Signed values are compared as 128-bit two's-complement integers using a
    /// wrapping difference, which is exact for INT128 (the signed type used in
    /// these tests). NOTE(review): for narrower signed types this assumes values
    /// are stored sign-extended to 128 bits — confirm before reusing.
    fn compare_truncate_output(
        output: &[u128],
        expected: &[u128],
        equal: bool,
        st: ScalarType,
    ) -> Result<()> {
        if output.len() != expected.len() {
            return Err(runtime_error!("Output length differs from expected"));
        }
        for (out_value, expected_value) in output.iter().zip(expected) {
            let dif = if st.is_signed() {
                (*out_value as i128)
                    .wrapping_sub(*expected_value as i128)
                    .unsigned_abs()
            } else {
                abs_diff_u128(*out_value, *expected_value)
            };
            if equal && dif > 1 {
                return Err(runtime_error!("Output is too far from expected"));
            }
            if !equal && dif <= 1 {
                return Err(runtime_error!("Output is too close to expected"));
            }
        }
        Ok(())
    }

    /// Evaluates `mpc_graph` on `inputs` and compares the (revealed or summed)
    /// result against `expected`.
    fn check_output(
        mpc_graph: Graph,
        inputs: Vec<Value>,
        expected: Vec<u128>,
        output_parties: Vec<IOStatus>,
        t: Type,
    ) -> Result<()> {
        let output = random_evaluate(mpc_graph.clone(), inputs)?;
        let st = t.get_scalar_type();
        if output_parties.is_empty() {
            // The output stays shared: reconstruct it by summing the shares.
            let out = output.access_vector(|v| {
                let modulus = st.get_modulus();
                let mut res = vec![0; expected.len()];
                for val in v {
                    let arr = match t.clone() {
                        Type::Scalar(_) => {
                            vec![val.to_u128(st)?]
                        }
                        Type::Array(_, _) => val.to_flattened_array_u128(t.clone())?,
                        _ => {
                            panic!("Shouldn't be here");
                        }
                    };
                    for i in 0..expected.len() {
                        res[i] = add_u128(res[i], arr[i], modulus);
                    }
                }
                Ok(res)
            })?;
            compare_truncate_output(&out, &expected, true, st)?;
        } else {
            assert!(output.check_type(t.clone())?);
            let out = match t.clone() {
                Type::Scalar(_) => vec![output.to_u128(st)?],
                Type::Array(_, _) => output.to_flattened_array_u128(t.clone())?,
                _ => {
                    panic!("Shouldn't be here");
                }
            };
            compare_truncate_output(&out, &expected, true, st)?;
        }
        Ok(())
    }

    /// Runs the truncation-by-`scale` scenarios for scalar type `st`.
    fn truncate_helper(st: ScalarType, scale: u128) -> Result<()> {
        let helper = |t: Type,
                      input: Vec<u128>,
                      input_status: IOStatus,
                      output_parties: Vec<IOStatus>,
                      inline_config: InlineConfig|
         -> Result<()> {
            let mpc_context = prepare_context(
                t.clone(),
                input_status.clone(),
                output_parties.clone(),
                scale,
                inline_config,
            )?;
            let mpc_graph = mpc_context.get_main_graph()?;
            let mpc_input = prepare_input(input.clone(), input_status.clone(), t.clone())?;
            // Signed inputs are interpreted as full 128-bit two's-complement values
            // (casting to i64 would drop the high bits of INT128 inputs).
            let expected: Vec<u128> = if t.get_scalar_type().is_signed() {
                input
                    .iter()
                    .map(|x| ((*x as i128) / (scale as i128)) as u128)
                    .collect()
            } else {
                input.iter().map(|x| x / scale).collect()
            };
            check_output(mpc_graph, mpc_input, expected, output_parties, t.clone())?;
            Ok(())
        };
        let inline_config_simple = InlineConfig {
            default_mode: InlineMode::Simple,
            ..Default::default()
        };
        let helper_runs = |inputs: Vec<u128>, t: Type| -> Result<()> {
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Shared,
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![IOStatus::Party(0)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Public,
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Public,
                vec![],
                inline_config_simple.clone(),
            )?;
            Ok(())
        };
        // Repeats the scenarios: large inputs make the non-power-of-two protocol fail
        // only with some probability, so an error is expected in at least one run.
        let helper_malformed = |inputs: Vec<u128>, t: Type, runs: u64| -> Result<()> {
            for _ in 0..runs {
                helper_runs(inputs.clone(), t.clone())?;
            }
            Ok(())
        };
        helper_runs(vec![0], scalar_type(st))?;
        helper_runs(vec![1000], scalar_type(st))?;
        helper_runs(vec![0, 0], array_type(vec![2], st))?;
        helper_runs(vec![2000, 255], array_type(vec![2], st))?;
        if scale.is_power_of_two() && !st.is_signed() {
            helper_runs(vec![(1u128 << 127) - 1], scalar_type(st))?;
        }
        if st.is_signed() {
            helper_runs(vec![u128::MAX], scalar_type(st))?;
            helper_runs(vec![u128::MAX - 999], scalar_type(st))?;
            helper_runs(
                vec![u128::MAX - 9, u128::MAX - 1023],
                array_type(vec![2], st),
            )?;
            if scale.is_power_of_two() {
                helper_runs(vec![-(1i128 << 126) as u128], scalar_type(st))?;
                helper_runs(vec![(1u128 << 126) - 1], scalar_type(st))?;
            }
        }
        if scale != 1 && !scale.is_power_of_two() {
            assert!(helper_malformed(vec![i128::MAX as u128], scalar_type(st), 40).is_err());
            assert!(helper_malformed(vec![i128::MIN as u128], scalar_type(st), 40).is_err());
            assert!(helper_malformed(
                vec![i128::MAX as u128, i128::MAX as u128 - 1],
                array_type(vec![2], st),
                40
            )
            .is_err());
            assert!(helper_malformed(
                vec![1u128 << 127, (1u128 << 127) + 1],
                array_type(vec![2], st),
                40
            )
            .is_err());
        }
        Ok(())
    }

    #[test]
    fn test_truncate() -> Result<()> {
        truncate_helper(UINT128, 1)?;
        truncate_helper(UINT128, 1 << 3)?;
        truncate_helper(UINT128, 1 << 7)?;
        truncate_helper(UINT128, 1 << 29)?;
        truncate_helper(UINT128, 1 << 31)?;
        truncate_helper(INT128, 1)?;
        truncate_helper(INT128, 15)?;
        truncate_helper(INT128, 1 << 3)?;
        truncate_helper(INT128, 1 << 7)?;
        truncate_helper(INT128, 1 << 29)?;
        truncate_helper(INT128, (1 << 29) - 1)?;
        assert!(truncate_helper(UINT128, 15).is_err());
        Ok(())
    }
}