ciphercore_base/mpc/mpc_truncate.rs
1use crate::custom_ops::CustomOperationBody;
2use crate::data_types::{array_type, scalar_size_in_bits, scalar_type, Type, BIT};
3use crate::data_values::Value;
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node, NodeAnnotation};
6use crate::mpc::mpc_compiler::{check_private_tuple, KEY_LENGTH, PARTIES};
7use crate::ops::utils::constant_scalar;
8
9use serde::{Deserialize, Serialize};
10
11/// Truncate MPC operation for public and private data.
12///
13/// In contrast to plaintext Truncate, this operation might introduce 2 types of errors:
14/// 1. 1 bit of additive error in LSB.
15/// This bit comes from the fact that truncating the addends of the sum a = b + c by d bits
16/// can remove a carry bit propagated to the (d+1)-th bit of the sum.
17/// E.g., truncating the addends of 2 = 1 + 1 by 2 results in 1/2 + 1/2 = 0 != 2/2.
18/// 2. Additive error in MSBs.
19/// Since addition is done modulo 2^m, every sum can be written as a = b + c +- k * 2^m with k in {0,1}.
20/// But the truncation result is b/scale + c/scale = (a + k * 2^m)/scale. If k = 1, the error is 2^m/scale.
21/// The probability of this error is
22/// * 1 - (a + 1) / 2^m for unsigned types,
23/// * (|a| - 1) / m, if a < 0 and (a + 1) / m, if a >= 0 for signed types.
24/// Therefore, this operation supports only signed types with a warning
25/// that it fails with probability < 2^(l-m) when |a| < 2^l.
26///
27#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
28pub(super) struct TruncateMPC {
29 pub scale: u128,
30}
31
32#[typetag::serde]
33impl CustomOperationBody for TruncateMPC {
34 fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
35 if argument_types.len() == 1 {
36 if let Type::Array(_, st) | Type::Scalar(st) = argument_types[0].clone() {
37 if !st.is_signed() {
38 return Err(runtime_error!(
39 "Only signed types are supported by TruncateMPC"
40 ));
41 }
42 let g = context.create_graph()?;
43 let input = g.input(argument_types[0].clone())?;
44 let o = if self.scale == 1 {
45 // Do nothing if scale is 1
46 input
47 } else {
48 input.truncate(self.scale)?
49 };
50 o.set_as_output()?;
51 g.finalize()?;
52 return Ok(g);
53 } else {
54 // Panics since:
55 // - the user has no direct access to this function.
56 // - the MPC compiler should pass correct arguments
57 // and this panic should never happen.
58 panic!("Inconsistency with type checker");
59 }
60 }
61 if argument_types.len() != 2 {
62 return Err(runtime_error!(
63 "TruncateMPC should have either 1 or 2 inputs."
64 ));
65 }
66
67 if let (Type::Tuple(v0), Type::Tuple(v1)) =
68 (argument_types[0].clone(), argument_types[1].clone())
69 {
70 check_private_tuple(v0)?;
71 check_private_tuple(v1)?;
72 } else {
73 return Err(runtime_error!(
74 "TruncateMPC should have a private tuple and a tuple of keys as input"
75 ));
76 }
77
78 let t = argument_types[0].clone();
79 let input_t = if let Type::Tuple(t_vec) = t.clone() {
80 (*t_vec[0]).clone()
81 } else {
82 panic!("Shouldn't be here");
83 };
84 if !input_t.get_scalar_type().is_signed() {
85 return Err(runtime_error!(
86 "Only signed types are supported by TruncateMPC"
87 ));
88 }
89
90 let g = context.create_graph()?;
91 let input_node = g.input(t)?;
92
93 let prf_type = argument_types[1].clone();
94 let prf_keys = g.input(prf_type)?;
95
96 // Do nothing if scale is 1.
97 if self.scale == 1 {
98 input_node.set_as_output()?;
99 g.finalize()?;
100 return Ok(g);
101 }
102
103 // Generate shares of a random value r = PRF_k(v) where k is known to parties 1 and 2 (it's the last key in the key triple).
104 let prf_key_parties_12 = prf_keys.tuple_get(PARTIES as u64 - 1)?;
105 let random_node = g.prf(prf_key_parties_12, 0, input_t)?;
106
107 let mut result_shares = vec![];
108 // 1st share of the result is the truncated 1st share of the input
109 let res0 = input_node.tuple_get(0)?.truncate(self.scale)?;
110 result_shares.push(res0);
111 // 2nd share of the results is the truncated sum of the 2nd and 3rd input shares minus r
112 let res1 = input_node
113 .tuple_get(1)?
114 .add(input_node.tuple_get(2)?)?
115 .truncate(self.scale)?
116 .subtract(random_node.clone())?;
117 let res1_sent = res1.nop()?;
118 // 2nd share should be sent to party 0
119 res1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
120 result_shares.push(res1_sent);
121 // 3rd share of the result is the random value r
122 result_shares.push(random_node);
123
124 g.create_tuple(result_shares)?.set_as_output()?;
125
126 g.finalize()?;
127 Ok(g)
128 }
129
130 fn get_name(&self) -> String {
131 format!("TruncateMPC({})", self.scale)
132 }
133}
134
135/// Truncate MPC operation for public and private data by a power of 2.
136///
137/// Signed input integers must be from the range [-modulus/4, modulus/4)
138/// and unsigned integers must be in the range [0, modulus/2) where modulus is the modulus of the input scalar type.
139///
140/// This algorithm returns floor(x/2^k) + w where w = 1 with probability (x mod 2^k)/2^k, otherwise w=0.
141/// So the result is biased to round(x/2^k).
142///
143/// The corresponding protocol is described [here](https://eprint.iacr.org/2019/131.pdf#page=10) and runs as follows.
144/// 0. The below protocol works correctly for integers in the range [0, modulus/2).
145/// For signed inputs, we add modulus/4 to input resulting in [0, modulus/2).
146/// For correctness, we should remove modulus/2^(k+2) after truncation since
147/// Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2).
148///
149/// Let x = (x0, x1, x2) is the 2-out-of-3 sharing of the (possibly, shifted) input.
150/// k_2 is a PRF key that is held only by party 2.
151/// k_02 is a PRF key that is held only by parties 0 and 2.
152/// k_12 is a PRF key that is held only by parties 1 and 2.
153/// The keys k_02 and k_12 are re-used multiplication keys.
154/// 1. Party 2 generates a random integer r of the input scalar type.
155/// 2. Party 2 extracts the MSB of r in the arithmetic form (r_msb).
156/// 3. Party 2 removes the MSB of r and truncates the result by k bits (r_truncated = sum_(i=k)^(s-2) r_i * 2^(i-k) where s is the bitsize of the input scalar type)
157/// 4. Party 2 creates 2-out-of-2 shares of r, r_msb and r_truncated.
158/// Such shares for a value val have the form (val0, val1) such that val = val0 + val1.
159/// The corresponding share val0 = PRF(k_02, iv_val0) of the aforementioned 3 values is generated by parties 0 and 2.
160/// The second share val1 = val - val0 is computed by party 2 and then it is sent to party 1.
161/// 5. Parties 0 and 2 compute y0 = PRF(key_02, iv_y0).
162/// Parties 1 and 2 compute y2 = PRF(key_12, iv_y2).
163/// The pair (y0, y2) is a 2-out-of-3 share of the output known to party 2.
164/// 6. Parties 0 and 1 create a 2-out-of-2 share of the input x.
165/// To obtain its share, party 0 sums its 2-out-of-3 shares to get z0 = x0 + x1.
166/// Party 1 takes z1 = x2.
167/// 7. Given r from party 2, parties 0 and 1 compute 2-out-of-2 shares of c = x + r via c0 = z0 + r0 and c1 = z1 + r1.
168/// 8. Parties 0 and 1 reveal c to each other and compute c_truncated_mod = (c/2^k) mod 2^(s-k-1).
169/// This is c truncated by k bits without its MSB.
170/// 9. Parties 0 and 1 compute the MSB of c via c/2^(s-1).
171/// 10. Parties 0 and 1 compute 2-out-of-2 shares of b = r_msb XOR c_msb using the following expressions:
172/// b0 = r_msb0 + c_msb - 2 * c_msb * r_msb0,
173/// b1 = r_msb1 - 2 * c_msb * r_msb1.
174/// Note that b0 + b1 = r_msb + c_msb - 2*c_msb*r_msb = r_msb XOR c_msb.
175/// All the above operations can be done locally as c_msb is known to parties 0 and 1.
176/// 11. Parties 0 and 1 compute 2-out-of-2 shares of y' = c_truncated_mod - r_truncated + b * 2^(st_size-1-k).
177/// This value is equal to the desired result floor(x/2^k) + w.
178/// 12. Party 0 masks y'0 with a random value y0 from party 2 as y_tilde0 = y'0 - y0 and sends it to party 1.
179/// 13. Party 1 masks y'1 with a random value y2 from party 2 as y_tilde1 = y'1 - y2 and sends it to party 0.
180/// 14. Parties 0 and 1 compute y1 = y_tilde0 + y_tilde1 = y' - y0 - y2.
181/// Together with y0 and y2 this value constitute the sharing of the truncation output.
182/// 14!. If input is signed, we should remove modulus/2^(k+2) after truncation since
183/// Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2) as in Step 0.
184/// 15. The protocol returns (y0, y1, y2).
185#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
186pub(super) struct TruncateMPC2K {
187 pub k: u64,
188}
189
190#[typetag::serde]
191impl CustomOperationBody for TruncateMPC2K {
192 fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
193 if argument_types.len() == 1 {
194 if let Type::Array(_, _) | Type::Scalar(_) = argument_types[0].clone() {
195 let g = context.create_graph()?;
196 let input = g.input(argument_types[0].clone())?;
197 let o = if self.k == 0 {
198 // Do nothing if scale is 1
199 input
200 } else {
201 input.truncate(1 << self.k)?
202 };
203 o.set_as_output()?;
204 g.finalize()?;
205 return Ok(g);
206 } else {
207 // Panics since:
208 // - the user has no direct access to this function.
209 // - the MPC compiler should pass correct arguments
210 // and this panic should never happen.
211 panic!("Inconsistency with type checker");
212 }
213 }
214 if argument_types.len() != 3 {
215 return Err(runtime_error!("TruncateMPC2K should have 3 inputs."));
216 }
217 if let Type::Tuple(v0) = argument_types[0].clone() {
218 check_private_tuple(v0)?;
219 } else {
220 if !argument_types[0].is_array() && !argument_types[0].is_scalar() {
221 // Panics since:
222 // - the user has no direct access to this function.
223 // - the MPC compiler should pass correct arguments
224 // and this panic should never happen.
225 panic!("Inconsistency with type checker");
226 }
227 let g = context.create_graph()?;
228 let input = g.input(argument_types[0].clone())?;
229 let o = input.truncate(1 << self.k)?;
230 o.set_as_output()?;
231 g.finalize()?;
232 return Ok(g);
233 }
234
235 // Check PRF keys
236 let key_type = array_type(vec![KEY_LENGTH], BIT);
237 if let Type::Tuple(v0) = argument_types[1].clone() {
238 check_private_tuple(v0.clone())?;
239 for t in v0 {
240 if *t != key_type {
241 return Err(runtime_error!("PRF key is of a wrong type"));
242 }
243 }
244 } else {
245 return Err(runtime_error!("PRF key is of a wrong type"));
246 }
247 if argument_types[2] != key_type {
248 return Err(runtime_error!("PRF key is of a wrong type"));
249 }
250
251 let t = argument_types[0].clone();
252 let input_t = if let Type::Tuple(t_vec) = t.clone() {
253 (*t_vec[0]).clone()
254 } else {
255 panic!("Shouldn't be here");
256 };
257 if !input_t.is_array() && !input_t.is_scalar() {
258 // Panics since:
259 // - the user has no direct access to this function.
260 // - the MPC compiler should pass correct arguments
261 // and this panic should never happen.
262 panic!("Inconsistency with type checker");
263 }
264
265 let g = context.create_graph()?;
266 let input_node = g.input(t)?;
267
268 // PRF keys
269 let prf_mul_type = argument_types[1].clone();
270 let prf_mul_keys = g.input(prf_mul_type)?;
271 let prf_truncate_type = argument_types[2].clone();
272 // PRF key k_2
273 let key_2 = g.input(prf_truncate_type)?;
274
275 if self.k == 0 {
276 input_node.set_as_output()?;
277 g.finalize()?;
278 return Ok(g);
279 }
280 // PRF key k_02, this is the last key in the multiplication PRF key triple
281 let key_02 = prf_mul_keys.tuple_get(0)?;
282 // PRF key k_12, this is the second key in the multiplication PRF key triple
283 let key_12 = prf_mul_keys.tuple_get(2)?;
284
285 let st = input_t.get_scalar_type();
286 let st_size = scalar_size_in_bits(st);
287
288 let x0 = {
289 let share = input_node.tuple_get(0)?;
290 // 0. The below protocol works correctly for integers in the range [0, modulus/2).
291 // For signed inputs, we add modulus/4 to input resulting in input + modulus/4 in [0, modulus/2)
292 // For correctness, we should remove modulus/2^(k+2) after truncation since
293 // Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2)
294 if st.is_signed() {
295 // modulus/4
296 let mod_fraction = constant_scalar(&g, 1u128 << (st_size - 2), st)?;
297 share.add(mod_fraction)?
298 } else {
299 share
300 }
301 };
302 let x1 = input_node.tuple_get(1)?;
303 let x2 = input_node.tuple_get(2)?;
304
305 // 1. Party 2 generates a random integer r of the input scalar type.
306 let r = g.prf(key_2, 0, input_t.clone())?;
307
308 let unsigned_st = st.get_unsigned_counterpart();
309 // 2. Party 2 extracts the MSB of r in the arithmetic form (r_msb).
310 let r_msb = {
311 // (0,0, ..., 1)
312 let mask = constant_scalar(&g, 1u128 << (st_size - 1), unsigned_st)?.a2b()?;
313 // (0,0, ..., r_(st_size-1)) -> r_(st_size-1)*2^(st_size-1) as unsigned integer
314 let r_msb_scaled = r.a2b()?.multiply(mask)?.b2a(unsigned_st)?;
315 // (r_(st_size-1), 0, ..., 0) -> r_(st_size-1) of st type
316 r_msb_scaled.truncate(1 << (st_size - 1))?.a2b()?.b2a(st)?
317 };
318
319 // 3. Party 2 removes the MSB of r and truncates the result by k bits (r_truncated = sum_(i=k)^(st_size-2) r_i * 2^(i-k))
320 let r_truncated = {
321 // (0, ..., 0, 1, ..., 1, 0, ..., 0) to extract r_k, r_(k+1), ..., r_(st_size-2)
322 let mask = constant_scalar(
323 &g,
324 (1u128 << (st_size - 1)) - (1u128 << self.k),
325 unsigned_st,
326 )?
327 .a2b()?;
328 // r_k + r_(k+1) * 2 + ... + r_(st_size-2) * 2^(st_size-2-k)
329 r.a2b()?.multiply(mask)?.b2a(st)?.truncate(1 << self.k)?
330 };
331
332 // 4. Party 2 creates 2-out-of-2 shares of r, r_msb and r_truncated.
333 // Such shares for a value val have the form (val0, val1) such that val = val0 + val1.
334 // The corresponding share val0 = PRF(k_02, iv) of the aforementioned 3 values is generated by parties 0 and 2.
335 // The second share val1 = val - val0 is computed by party 2 and then it is sent to party 1.
336 let share_for_two = |val: Node| -> Result<(Node, Node)> {
337 // first share val0 for party 0
338 let share0 = g.prf(key_02.clone(), 0, val.get_type()?)?;
339 // second share val1 for party 1
340 let share1 = val.subtract(share0.clone())?;
341 let share1_sent = share1.nop()?;
342 share1_sent.add_annotation(NodeAnnotation::Send(2, 1))?;
343 Ok((share0, share1_sent))
344 };
345 let (r0, r1) = share_for_two(r)?;
346 let (r_msb0, r_msb1) = share_for_two(r_msb)?;
347 let (r_truncated0, r_truncated1) = share_for_two(r_truncated)?;
348
349 // 5. Parties 0 and 2 compute y0 = PRF(key_02, iv).
350 // Parties 1 and 2 compute y2 = PRF(key_12, iv).
351 // The pair (y0, y2) is a 2-out-of-3 share of the output known to party 2.
352 let y0 = g.prf(key_02, 0, input_t.clone())?;
353 let y2 = g.prf(key_12, 0, input_t)?;
354
355 // 6. Party 0 and Party 1 create a 2-out-of-2 share of the input x.
356 // To obtain its share, party 0 sums its 2-out-of-3 shares to get z0 = x0 + x1. Party 1 takes z1 = x2.
357 let z0 = x0.add(x1)?;
358 let z1 = x2;
359
360 // 7. Given r from party 2, parties 0 and 1 compute 2-out-of-2 shares of c = x + r via c0 = z0 + r0 and c1 = z1 + r1.
361 let c_share0 = z0.add(r0)?;
362 let c_share1 = z1.add(r1)?;
363
364 // 8. Parties 0 and 1 reveal c to each other and compute c_truncated_mod = (c/2^k) mod 2^(st_size-k-1).
365 // This is c truncated by k bits without its MSB.
366 let c_share0_sent = c_share0.nop()?;
367 c_share0_sent.add_annotation(NodeAnnotation::Send(0, 1))?;
368 let c_share1_sent = c_share1.nop()?;
369 c_share1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
370 let c = c_share0_sent.add(c_share1_sent)?;
371 // Interpret c as unsigned integer and truncate
372 // (c / scale) mod 2^(st_size-1-k)
373 let c_truncated = c
374 .a2b()?
375 .b2a(unsigned_st)?
376 .truncate(1 << self.k)?
377 .a2b()?
378 .b2a(st)?;
379 let c_truncated_mod = {
380 // (1,1, ..., 1, 0, ..., 0) to perform mod 2^(st_size-1-k)
381 let mask = g
382 .constant(
383 scalar_type(st),
384 Value::from_scalar((1u128 << (st_size - 1 - self.k)) - 1, st)?,
385 )?
386 .a2b()?;
387 c_truncated.a2b()?.multiply(mask)?.b2a(st)?
388 };
389
390 // 9. Parties 0 and 1 compute the MSB of c via c/2^(st_size-1).
391 let c_msb = c
392 .a2b()?
393 .b2a(unsigned_st)?
394 .truncate(1 << (st_size - 1))?
395 .a2b()?
396 .b2a(st)?;
397
398 // 10. Parties 0 and 1 compute 2-out-of-2 shares of b = r_msb XOR c_msb using the following expressions:
399 // b0 = r_msb0 + c_msb - 2 * c_msb * r_msb0,
400 // b1 = r_msb1 - 2 * c_msb * r_msb1.
401 // Note that b0 + b1 = r_msb + c_msb - 2*c_msb*r_msb = r_msb XOR c_msb.
402 // All the above operations can be done locally as c_msb is known to parties 0 and 1.
403 let two = constant_scalar(&g, 2, st)?;
404 let b0 = r_msb0
405 .subtract(r_msb0.multiply(c_msb.clone())?.multiply(two.clone())?)?
406 .add(c_msb.clone())?;
407 let b1 = r_msb1.subtract(r_msb1.multiply(c_msb)?.multiply(two)?)?;
408
409 // 11. Parties 0 and 1 compute 2-out-of-2 shares of y' = c_truncated_mod - r_truncated + b * 2^(st_size-1-k).
410 // This value is equal to the desired result floor(x/2^k) + w.
411 // 2^(st_size-1-k)
412 let power2 = constant_scalar(&g, 1u128 << (st_size - 1 - self.k), st)?;
413 // y' = c_truncated_mod - r_truncated + b * 2^(st_size-1-k)
414 // This is 2-out-of-2 sharing of the result
415 let y_prime0 = b0
416 .multiply(power2.clone())?
417 .subtract(r_truncated0)?
418 .add(c_truncated_mod)?;
419 let y_prime1 = b1.multiply(power2)?.subtract(r_truncated1)?;
420
421 // 12. Party 0 masks y'0 with a random value y0 from party 2 as y_tilde0 = y'0 - y0 and sends it to party 1.
422 let y_tilde0 = y_prime0.subtract(y0.clone())?;
423 let y_tilde0_sent = y_tilde0.nop()?;
424 y_tilde0_sent.add_annotation(NodeAnnotation::Send(0, 1))?;
425 // 13. Party 1 masks y'1 with a random value y2 from party 2 as y_tilde1 = y'1 - y2 and sends it to party 0.
426 let y_tilde1 = y_prime1.subtract(y2.clone())?;
427 let y_tilde1_sent = y_tilde1.nop()?;
428 y_tilde1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
429
430 // 14. Parties 0 and 1 compute y1 = y_tilde0 + y_tilde1 = y' - y0 - y2.
431 // Together with y0 and y2 this value constitute the sharing of the truncation output.
432 let y1 = {
433 let sum01 = y_tilde0_sent.add(y_tilde1_sent)?;
434 if st.is_signed() {
435 // 14!. If input is signed, we should remove modulus/2^(k+2) after truncation since
436 // Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2)
437 let mod_fraction = constant_scalar(&g, 1u128 << (st_size - 2 - self.k), st)?;
438 sum01.subtract(mod_fraction)?
439 } else {
440 sum01
441 }
442 };
443
444 // 15. The protocol returns (y0, y1, y2).
445 g.create_tuple(vec![y0, y1, y2])?.set_as_output()?;
446
447 g.finalize()?;
448 Ok(g)
449 }
450
451 fn get_name(&self) -> String {
452 format!("TruncateMPC2K({})", self.k)
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::{add_u128, subtract_vectors_u128};
    use crate::data_types::{array_type, scalar_type, ScalarType, INT128, UINT128};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::inline::inline_ops::{InlineConfig, InlineMode};
    use crate::mpc::mpc_compiler::{prepare_for_mpc_evaluation, IOStatus, PARTIES};

    fn prepare_context(
        t: Type,
        party_id: IOStatus,
        output_parties: Vec<IOStatus>,
        scale: u128,
        inline_config: InlineConfig,
    ) -> Result<Context> {
        let c = simple_context(|g| {
            let i = g.input(t)?;
            g.truncate(i, scale)
        })?;

        prepare_for_mpc_evaluation(c, vec![vec![party_id]], vec![output_parties], inline_config)
    }

    fn prepare_input(input: Vec<u128>, input_status: IOStatus, t: Type) -> Result<Vec<Value>> {
        let mpc_input = match t {
            Type::Scalar(st) => {
                if input_status == IOStatus::Public || matches!(input_status, IOStatus::Party(_)) {
                    return Ok(vec![Value::from_scalar(input[0], st)?]);
                }

                // shares of input = (input - 3, 1, 2)
                let mut shares_vec = vec![];
                shares_vec.push(Value::from_scalar(
                    subtract_vectors_u128(&input, &[3], st.get_modulus())?[0],
                    st,
                )?);

                for i in 1..PARTIES as u64 {
                    shares_vec.push(Value::from_scalar(i, st)?);
                }
                shares_vec
            }
            Type::Array(_, st) => {
                if input_status == IOStatus::Public || matches!(input_status, IOStatus::Party(_)) {
                    return Ok(vec![Value::from_flattened_array(&input, st)?]);
                }

                // shares of input = (input - 3, 1, 2)
                let mut shares_vec = vec![];
                let threes = vec![3; input.len()];
                let first_share = subtract_vectors_u128(&input, &threes, st.get_modulus())?;
                shares_vec.push(Value::from_flattened_array(&first_share, st)?);

                for i in 1..PARTIES {
                    let share = vec![i; input.len()];
                    shares_vec.push(Value::from_flattened_array(&share, st)?);
                }
                shares_vec
            }
            _ => {
                panic!("Shouldn't be here");
            }
        };

        Ok(vec![Value::from_vector(mpc_input)])
    }

    // Checks that every output value is within distance 1 of the expected value (`equal` = true)
    // or farther than 1 from it (`equal` = false).
    // Values are ring elements of `st` in their u128 representation, so the distance is measured
    // modulo the modulus of `st` (e.g. -1 and 0 are at distance 1). Unlike a plain subtraction,
    // this neither truncates 128-bit values nor underflows when the output is below the expected value.
    fn compare_truncate_output(
        output: &[u128],
        expected: &[u128],
        equal: bool,
        st: ScalarType,
    ) -> Result<()> {
        let modulus = st.get_modulus();
        let forward = subtract_vectors_u128(output, expected, modulus)?;
        let backward = subtract_vectors_u128(expected, output, modulus)?;
        for (f, b) in forward.into_iter().zip(backward) {
            let dif = f.min(b);
            if equal && dif > 1 {
                return Err(runtime_error!("Output is too far from expected"));
            }
            if !equal && dif <= 1 {
                return Err(runtime_error!("Output is too close to expected"));
            }
        }

        Ok(())
    }

    fn check_output(
        mpc_graph: Graph,
        inputs: Vec<Value>,
        expected: Vec<u128>,
        output_parties: Vec<IOStatus>,
        t: Type,
    ) -> Result<()> {
        let output = random_evaluate(mpc_graph.clone(), inputs)?;
        let st = t.get_scalar_type();

        if output_parties.is_empty() {
            let out = output.access_vector(|v| {
                let modulus = st.get_modulus();
                let mut res = vec![0; expected.len()];
                for val in v {
                    let arr = match t.clone() {
                        Type::Scalar(_) => {
                            vec![val.to_u128(st)?]
                        }
                        Type::Array(_, _) => val.to_flattened_array_u128(t.clone())?,
                        _ => {
                            panic!("Shouldn't be here");
                        }
                    };
                    for i in 0..expected.len() {
                        res[i] = add_u128(res[i], arr[i], modulus);
                    }
                }
                Ok(res)
            })?;
            compare_truncate_output(&out, &expected, true, st)?;
        } else {
            assert!(output.check_type(t.clone())?);
            let out = match t.clone() {
                Type::Scalar(_) => vec![output.to_u128(st)?],
                Type::Array(_, _) => output.to_flattened_array_u128(t.clone())?,
                _ => {
                    panic!("Shouldn't be here");
                }
            };
            compare_truncate_output(&out, &expected, true, st)?;
        }

        Ok(())
    }

    fn truncate_helper(st: ScalarType, scale: u128) -> Result<()> {
        let helper = |t: Type,
                      input: Vec<u128>,
                      input_status: IOStatus,
                      output_parties: Vec<IOStatus>,
                      inline_config: InlineConfig|
         -> Result<()> {
            let mpc_context = prepare_context(
                t.clone(),
                input_status.clone(),
                output_parties.clone(),
                scale,
                inline_config,
            )?;
            let mpc_graph = mpc_context.get_main_graph()?;

            let mpc_input = prepare_input(input.clone(), input_status.clone(), t.clone())?;

            // Inputs are 128-bit representations of ring elements. Signed values are
            // reinterpreted as i128; an i64 cast would drop the upper 64 bits and make the
            // checks of large inputs (e.g. +-2^126) meaningless.
            let expected: Vec<u128> = if t.get_scalar_type().is_signed() {
                input
                    .iter()
                    .map(|x| ((*x as i128) / (scale as i128)) as u128)
                    .collect()
            } else {
                input.iter().map(|x| x / scale).collect()
            };
            check_output(mpc_graph, mpc_input, expected, output_parties, t.clone())?;

            Ok(())
        };
        let inline_config_simple = InlineConfig {
            default_mode: InlineMode::Simple,
            ..Default::default()
        };
        let helper_runs = |inputs: Vec<u128>, t: Type| -> Result<()> {
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Shared,
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![IOStatus::Party(0)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Public,
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Public,
                vec![],
                inline_config_simple.clone(),
            )?;
            Ok(())
        };
        // This test should fail with a probability depending on input and the number of runs
        let helper_malformed = |inputs: Vec<u128>, t: Type, runs: u64| -> Result<()> {
            for _ in 0..runs {
                helper_runs(inputs.clone(), t.clone())?;
            }
            Ok(())
        };

        helper_runs(vec![0], scalar_type(st))?;
        helper_runs(vec![1000], scalar_type(st))?;
        helper_runs(vec![0, 0], array_type(vec![2], st))?;
        helper_runs(vec![2000, 255], array_type(vec![2], st))?;

        if scale.is_power_of_two() && !st.is_signed() {
            // 2^127 - 1, this is a maximal UINT128 value that can be truncated without errors by TruncateMPC2K
            helper_runs(vec![(1u128 << 127) - 1], scalar_type(st))?;
        }

        if st.is_signed() {
            // -1
            helper_runs(vec![u128::MAX], scalar_type(st))?;
            // -1000
            helper_runs(vec![u128::MAX - 999], scalar_type(st))?;
            // [-10, -1024]
            helper_runs(
                vec![u128::MAX - 9, u128::MAX - 1023],
                array_type(vec![2], st),
            )?;
            if scale.is_power_of_two() {
                // - 2^126, this is a minimal INT128 value that can be truncated without errors by TruncateMPC2K
                helper_runs(vec![-(1i128 << 126) as u128], scalar_type(st))?;
                // 2^126-1, this is a maximal INT128 value that can be truncated without errors by TruncateMPC2K
                helper_runs(vec![(1u128 << 126) - 1], scalar_type(st))?;
            }
        }

        // Probabilistic tests of TruncateMPC for big values in absolute size
        if scale != 1 && !scale.is_power_of_two() {
            // 2^127 - 1, should fail with probability 1 - 2^(-40)
            assert!(helper_malformed(vec![i128::MAX as u128], scalar_type(st), 40).is_err());
            // -2^127, should fail with probability 1 - 2^(-40)
            assert!(helper_malformed(vec![i128::MIN as u128], scalar_type(st), 40).is_err());
            // [2^127 - 1, 2^127 - 2]
            assert!(helper_malformed(
                vec![i128::MAX as u128, i128::MAX as u128 - 1],
                array_type(vec![2], st),
                40
            )
            .is_err());
            // [-2^127, -2^127 + 1]
            assert!(helper_malformed(
                vec![1u128 << 127, (1u128 << 127) + 1],
                array_type(vec![2], st),
                40
            )
            .is_err());
        }
        Ok(())
    }

    #[test]
    fn test_truncate() -> Result<()> {
        truncate_helper(UINT128, 1)?;
        truncate_helper(UINT128, 1 << 3)?;
        truncate_helper(UINT128, 1 << 7)?;
        truncate_helper(UINT128, 1 << 29)?;
        truncate_helper(UINT128, 1 << 31)?;

        truncate_helper(INT128, 1)?;
        truncate_helper(INT128, 15)?;
        truncate_helper(INT128, 1 << 3)?;
        truncate_helper(INT128, 1 << 7)?;
        truncate_helper(INT128, 1 << 29)?;
        truncate_helper(INT128, (1 << 29) - 1)?;

        assert!(truncate_helper(UINT128, 15).is_err());
        Ok(())
    }
}