// ciphercore_base/mpc/mpc_truncate.rs

1use crate::custom_ops::CustomOperationBody;
2use crate::data_types::{array_type, scalar_size_in_bits, scalar_type, Type, BIT};
3use crate::data_values::Value;
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node, NodeAnnotation};
6use crate::mpc::mpc_compiler::{check_private_tuple, KEY_LENGTH, PARTIES};
7use crate::ops::utils::constant_scalar;
8
9use serde::{Deserialize, Serialize};
10
11/// Truncate MPC operation for public and private data.
12///
13/// In contrast to plaintext Truncate, this operation might introduce 2 types of errors:
14/// 1. 1 bit of additive error in LSB.
15///    This bit comes from the fact that truncating the addends of the sum a = b + c by d bits
16///    can remove a carry bit propagated to the (d+1)-th bit of the sum.
17///    E.g., truncating the addends of 2 = 1 + 1 by 2 results in 1/2 + 1/2 = 0 != 2/2.
18/// 2. Additive error in MSBs.
19///    Since addition is done modulo 2^m, every sum can be written as a = b + c +- k * 2^m with k in {0,1}.
20///    But the truncation result is b/scale + c/scale = (a + k * 2^m)/scale. If k = 1, the error is 2^m/scale.
21///    The probability of this error is
22///    * 1 - (a + 1) / 2^m for unsigned types,
23///    * (|a| - 1) / m, if a < 0 and (a + 1) / m, if a >= 0 for signed types.  
24///    Therefore, this operation supports only signed types with a warning
25///    that it fails with probability < 2^(l-m) when |a| < 2^l.
26///
27#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
28pub(super) struct TruncateMPC {
29    pub scale: u128,
30}
31
32#[typetag::serde]
33impl CustomOperationBody for TruncateMPC {
34    fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
35        if argument_types.len() == 1 {
36            if let Type::Array(_, st) | Type::Scalar(st) = argument_types[0].clone() {
37                if !st.is_signed() {
38                    return Err(runtime_error!(
39                        "Only signed types are supported by TruncateMPC"
40                    ));
41                }
42                let g = context.create_graph()?;
43                let input = g.input(argument_types[0].clone())?;
44                let o = if self.scale == 1 {
45                    // Do nothing if scale is 1
46                    input
47                } else {
48                    input.truncate(self.scale)?
49                };
50                o.set_as_output()?;
51                g.finalize()?;
52                return Ok(g);
53            } else {
54                // Panics since:
55                // - the user has no direct access to this function.
56                // - the MPC compiler should pass correct arguments
57                // and this panic should never happen.
58                panic!("Inconsistency with type checker");
59            }
60        }
61        if argument_types.len() != 2 {
62            return Err(runtime_error!(
63                "TruncateMPC should have either 1 or 2 inputs."
64            ));
65        }
66
67        if let (Type::Tuple(v0), Type::Tuple(v1)) =
68            (argument_types[0].clone(), argument_types[1].clone())
69        {
70            check_private_tuple(v0)?;
71            check_private_tuple(v1)?;
72        } else {
73            return Err(runtime_error!(
74                "TruncateMPC should have a private tuple and a tuple of keys as input"
75            ));
76        }
77
78        let t = argument_types[0].clone();
79        let input_t = if let Type::Tuple(t_vec) = t.clone() {
80            (*t_vec[0]).clone()
81        } else {
82            panic!("Shouldn't be here");
83        };
84        if !input_t.get_scalar_type().is_signed() {
85            return Err(runtime_error!(
86                "Only signed types are supported by TruncateMPC"
87            ));
88        }
89
90        let g = context.create_graph()?;
91        let input_node = g.input(t)?;
92
93        let prf_type = argument_types[1].clone();
94        let prf_keys = g.input(prf_type)?;
95
96        // Do nothing if scale is 1.
97        if self.scale == 1 {
98            input_node.set_as_output()?;
99            g.finalize()?;
100            return Ok(g);
101        }
102
103        // Generate shares of a random value r = PRF_k(v) where k is known to parties 1 and 2 (it's the last key in the key triple).
104        let prf_key_parties_12 = prf_keys.tuple_get(PARTIES as u64 - 1)?;
105        let random_node = g.prf(prf_key_parties_12, 0, input_t)?;
106
107        let mut result_shares = vec![];
108        // 1st share of the result is the truncated 1st share of the input
109        let res0 = input_node.tuple_get(0)?.truncate(self.scale)?;
110        result_shares.push(res0);
111        // 2nd share of the results is the truncated sum of the 2nd and 3rd input shares minus r
112        let res1 = input_node
113            .tuple_get(1)?
114            .add(input_node.tuple_get(2)?)?
115            .truncate(self.scale)?
116            .subtract(random_node.clone())?;
117        let res1_sent = res1.nop()?;
118        // 2nd share should be sent to party 0
119        res1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
120        result_shares.push(res1_sent);
121        // 3rd share of the result is the random value r
122        result_shares.push(random_node);
123
124        g.create_tuple(result_shares)?.set_as_output()?;
125
126        g.finalize()?;
127        Ok(g)
128    }
129
130    fn get_name(&self) -> String {
131        format!("TruncateMPC({})", self.scale)
132    }
133}
134
135/// Truncate MPC operation for public and private data by a power of 2.
136///
137/// Signed input integers must be from the range [-modulus/4, modulus/4)
138/// and unsigned integers must be in the range [0, modulus/2) where modulus is the modulus of the input scalar type.
139///  
140/// This algorithm returns floor(x/2^k) + w where w = 1 with probability (x mod 2^k)/2^k, otherwise w=0.
141/// So the result is biased to round(x/2^k).
142///
143/// The corresponding protocol is described [here](https://eprint.iacr.org/2019/131.pdf#page=10) and runs as follows.
144///     0. The below protocol works correctly for integers in the range [0, modulus/2).
145///        For signed inputs, we add modulus/4 to input resulting in [0, modulus/2).
146///        For correctness, we should remove modulus/2^(k+2) after truncation since
147///        Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2).
148///
149/// Let x = (x0, x1, x2) is the 2-out-of-3 sharing of the (possibly, shifted) input.
150/// k_2 is a PRF key that is held only by party 2.
151/// k_02 is a PRF key that is held only by parties 0 and 2.
152/// k_12 is a PRF key that is held only by parties 1 and 2.
153/// The keys k_02 and k_12 are re-used multiplication keys.
154///     1. Party 2 generates a random integer r of the input scalar type.
155///     2. Party 2 extracts the MSB of r in the arithmetic form (r_msb).
156///     3. Party 2 removes the MSB of r and truncates the result by k bits (r_truncated = sum_(i=k)^(s-2) r_i * 2^(i-k) where s is the bitsize of the input scalar type)
157///     4. Party 2 creates 2-out-of-2 shares of r, r_msb and r_truncated.
158///        Such shares for a value val have the form (val0, val1) such that val = val0 + val1.
159///        The corresponding share val0 = PRF(k_02, iv_val0) of the aforementioned 3 values is generated by parties 0 and 2.
160///        The second share val1 = val - val0 is computed by party 2 and then it is sent to party 1.
161///     5. Parties 0 and 2 compute y0 = PRF(key_02, iv_y0).
162///        Parties 1 and 2 compute y2 = PRF(key_12, iv_y2).
163///        The pair (y0, y2) is a 2-out-of-3 share of the output known to party 2.
164///     6. Parties 0 and 1 create a 2-out-of-2 share of the input x.
165///        To obtain its share, party 0 sums its 2-out-of-3 shares to get z0 = x0 + x1.
166///        Party 1 takes z1 = x2.
167///     7. Given r from party 2, parties 0 and 1 compute 2-out-of-2 shares of c = x + r via c0 = z0 + r0 and c1 = z1 + r1.
168///     8. Parties 0 and 1 reveal c to each other and compute c_truncated_mod = (c/2^k) mod 2^(s-k-1).
169///        This is c truncated by k bits without its MSB.
170///     9. Parties 0 and 1 compute the MSB of c via c/2^(s-1).
171///     10. Parties 0 and 1 compute 2-out-of-2 shares of b = r_msb XOR c_msb using the following expressions:
172///             b0 = r_msb0 + c_msb - 2 * c_msb * r_msb0,
173///             b1 = r_msb1 - 2 * c_msb * r_msb1.
174///         Note that b0 + b1 = r_msb + c_msb - 2*c_msb*r_msb = r_msb XOR c_msb.
175///         All the above operations can be done locally as c_msb is known to parties 0 and 1.
176///     11. Parties 0 and 1 compute 2-out-of-2 shares of y' = c_truncated_mod - r_truncated + b * 2^(st_size-1-k).
177///         This value is equal to the desired result floor(x/2^k) + w.
178///     12. Party 0 masks y'0 with a random value y0 from party 2 as y_tilde0 = y'0 - y0 and sends it to party 1.
179///     13. Party 1 masks y'1 with a random value y2 from party 2 as y_tilde1 = y'1 - y2 and sends it to party 0.
180///     14. Parties 0 and 1 compute y1 = y_tilde0 + y_tilde1 = y' - y0 - y2.
181///         Together with y0 and y2 this value constitute the sharing of the truncation output.
182///     14!. If input is signed, we should remove modulus/2^(k+2) after truncation since
183///          Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2) as in Step 0.
184///     15. The protocol returns (y0, y1, y2).
185#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
186pub(super) struct TruncateMPC2K {
187    pub k: u64,
188}
189
190#[typetag::serde]
191impl CustomOperationBody for TruncateMPC2K {
192    fn instantiate(&self, context: Context, argument_types: Vec<Type>) -> Result<Graph> {
193        if argument_types.len() == 1 {
194            if let Type::Array(_, _) | Type::Scalar(_) = argument_types[0].clone() {
195                let g = context.create_graph()?;
196                let input = g.input(argument_types[0].clone())?;
197                let o = if self.k == 0 {
198                    // Do nothing if scale is 1
199                    input
200                } else {
201                    input.truncate(1 << self.k)?
202                };
203                o.set_as_output()?;
204                g.finalize()?;
205                return Ok(g);
206            } else {
207                // Panics since:
208                // - the user has no direct access to this function.
209                // - the MPC compiler should pass correct arguments
210                // and this panic should never happen.
211                panic!("Inconsistency with type checker");
212            }
213        }
214        if argument_types.len() != 3 {
215            return Err(runtime_error!("TruncateMPC2K should have 3 inputs."));
216        }
217        if let Type::Tuple(v0) = argument_types[0].clone() {
218            check_private_tuple(v0)?;
219        } else {
220            if !argument_types[0].is_array() && !argument_types[0].is_scalar() {
221                // Panics since:
222                // - the user has no direct access to this function.
223                // - the MPC compiler should pass correct arguments
224                // and this panic should never happen.
225                panic!("Inconsistency with type checker");
226            }
227            let g = context.create_graph()?;
228            let input = g.input(argument_types[0].clone())?;
229            let o = input.truncate(1 << self.k)?;
230            o.set_as_output()?;
231            g.finalize()?;
232            return Ok(g);
233        }
234
235        // Check PRF keys
236        let key_type = array_type(vec![KEY_LENGTH], BIT);
237        if let Type::Tuple(v0) = argument_types[1].clone() {
238            check_private_tuple(v0.clone())?;
239            for t in v0 {
240                if *t != key_type {
241                    return Err(runtime_error!("PRF key is of a wrong type"));
242                }
243            }
244        } else {
245            return Err(runtime_error!("PRF key is of a wrong type"));
246        }
247        if argument_types[2] != key_type {
248            return Err(runtime_error!("PRF key is of a wrong type"));
249        }
250
251        let t = argument_types[0].clone();
252        let input_t = if let Type::Tuple(t_vec) = t.clone() {
253            (*t_vec[0]).clone()
254        } else {
255            panic!("Shouldn't be here");
256        };
257        if !input_t.is_array() && !input_t.is_scalar() {
258            // Panics since:
259            // - the user has no direct access to this function.
260            // - the MPC compiler should pass correct arguments
261            // and this panic should never happen.
262            panic!("Inconsistency with type checker");
263        }
264
265        let g = context.create_graph()?;
266        let input_node = g.input(t)?;
267
268        // PRF keys
269        let prf_mul_type = argument_types[1].clone();
270        let prf_mul_keys = g.input(prf_mul_type)?;
271        let prf_truncate_type = argument_types[2].clone();
272        // PRF key k_2
273        let key_2 = g.input(prf_truncate_type)?;
274
275        if self.k == 0 {
276            input_node.set_as_output()?;
277            g.finalize()?;
278            return Ok(g);
279        }
280        // PRF key k_02, this is the last key in the multiplication PRF key triple
281        let key_02 = prf_mul_keys.tuple_get(0)?;
282        // PRF key k_12, this is the second key in the multiplication PRF key triple
283        let key_12 = prf_mul_keys.tuple_get(2)?;
284
285        let st = input_t.get_scalar_type();
286        let st_size = scalar_size_in_bits(st);
287
288        let x0 = {
289            let share = input_node.tuple_get(0)?;
290            // 0. The below protocol works correctly for integers in the range [0, modulus/2).
291            //    For signed inputs, we add modulus/4 to input resulting in input + modulus/4 in [0, modulus/2)
292            //    For correctness, we should remove modulus/2^(k+2) after truncation since
293            //    Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2)
294            if st.is_signed() {
295                // modulus/4
296                let mod_fraction = constant_scalar(&g, 1u128 << (st_size - 2), st)?;
297                share.add(mod_fraction)?
298            } else {
299                share
300            }
301        };
302        let x1 = input_node.tuple_get(1)?;
303        let x2 = input_node.tuple_get(2)?;
304
305        // 1. Party 2 generates a random integer r of the input scalar type.
306        let r = g.prf(key_2, 0, input_t.clone())?;
307
308        let unsigned_st = st.get_unsigned_counterpart();
309        // 2. Party 2 extracts the MSB of r in the arithmetic form (r_msb).
310        let r_msb = {
311            // (0,0, ..., 1)
312            let mask = constant_scalar(&g, 1u128 << (st_size - 1), unsigned_st)?.a2b()?;
313            // (0,0, ..., r_(st_size-1)) -> r_(st_size-1)*2^(st_size-1) as unsigned integer
314            let r_msb_scaled = r.a2b()?.multiply(mask)?.b2a(unsigned_st)?;
315            // (r_(st_size-1), 0, ..., 0) -> r_(st_size-1) of st type
316            r_msb_scaled.truncate(1 << (st_size - 1))?.a2b()?.b2a(st)?
317        };
318
319        // 3. Party 2 removes the MSB of r and truncates the result by k bits (r_truncated = sum_(i=k)^(st_size-2) r_i * 2^(i-k))
320        let r_truncated = {
321            // (0, ..., 0, 1, ..., 1, 0, ..., 0) to extract r_k, r_(k+1), ..., r_(st_size-2)
322            let mask = constant_scalar(
323                &g,
324                (1u128 << (st_size - 1)) - (1u128 << self.k),
325                unsigned_st,
326            )?
327            .a2b()?;
328            // r_k + r_(k+1) * 2 + ... + r_(st_size-2) * 2^(st_size-2-k)
329            r.a2b()?.multiply(mask)?.b2a(st)?.truncate(1 << self.k)?
330        };
331
332        // 4. Party 2 creates 2-out-of-2 shares of r, r_msb and r_truncated.
333        //    Such shares for a value val have the form (val0, val1) such that val = val0 + val1.
334        //    The corresponding share val0 = PRF(k_02, iv) of the aforementioned 3 values is generated by parties 0 and 2.
335        //    The second share val1 = val - val0 is computed by party 2 and then it is sent to party 1.
336        let share_for_two = |val: Node| -> Result<(Node, Node)> {
337            // first share val0 for party 0
338            let share0 = g.prf(key_02.clone(), 0, val.get_type()?)?;
339            // second share val1 for party 1
340            let share1 = val.subtract(share0.clone())?;
341            let share1_sent = share1.nop()?;
342            share1_sent.add_annotation(NodeAnnotation::Send(2, 1))?;
343            Ok((share0, share1_sent))
344        };
345        let (r0, r1) = share_for_two(r)?;
346        let (r_msb0, r_msb1) = share_for_two(r_msb)?;
347        let (r_truncated0, r_truncated1) = share_for_two(r_truncated)?;
348
349        // 5. Parties 0 and 2 compute y0 = PRF(key_02, iv).
350        //    Parties 1 and 2 compute y2 = PRF(key_12, iv).
351        //    The pair (y0, y2) is a 2-out-of-3 share of the output known to party 2.
352        let y0 = g.prf(key_02, 0, input_t.clone())?;
353        let y2 = g.prf(key_12, 0, input_t)?;
354
355        // 6. Party 0 and Party 1 create a 2-out-of-2 share of the input x.
356        //    To obtain its share, party 0 sums its 2-out-of-3 shares to get z0 = x0 + x1. Party 1 takes z1 = x2.
357        let z0 = x0.add(x1)?;
358        let z1 = x2;
359
360        // 7. Given r from party 2, parties 0 and 1 compute 2-out-of-2 shares of c = x + r via c0 = z0 + r0 and c1 = z1 + r1.
361        let c_share0 = z0.add(r0)?;
362        let c_share1 = z1.add(r1)?;
363
364        // 8. Parties 0 and 1 reveal c to each other and compute c_truncated_mod = (c/2^k) mod 2^(st_size-k-1).
365        //    This is c truncated by k bits without its MSB.
366        let c_share0_sent = c_share0.nop()?;
367        c_share0_sent.add_annotation(NodeAnnotation::Send(0, 1))?;
368        let c_share1_sent = c_share1.nop()?;
369        c_share1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
370        let c = c_share0_sent.add(c_share1_sent)?;
371        // Interpret c as unsigned integer and truncate
372        // (c / scale) mod 2^(st_size-1-k)
373        let c_truncated = c
374            .a2b()?
375            .b2a(unsigned_st)?
376            .truncate(1 << self.k)?
377            .a2b()?
378            .b2a(st)?;
379        let c_truncated_mod = {
380            // (1,1, ..., 1, 0, ..., 0) to perform mod 2^(st_size-1-k)
381            let mask = g
382                .constant(
383                    scalar_type(st),
384                    Value::from_scalar((1u128 << (st_size - 1 - self.k)) - 1, st)?,
385                )?
386                .a2b()?;
387            c_truncated.a2b()?.multiply(mask)?.b2a(st)?
388        };
389
390        // 9. Parties 0 and 1 compute the MSB of c via c/2^(st_size-1).
391        let c_msb = c
392            .a2b()?
393            .b2a(unsigned_st)?
394            .truncate(1 << (st_size - 1))?
395            .a2b()?
396            .b2a(st)?;
397
398        // 10. Parties 0 and 1 compute 2-out-of-2 shares of b = r_msb XOR c_msb using the following expressions:
399        //             b0 = r_msb0 + c_msb - 2 * c_msb * r_msb0,
400        //             b1 = r_msb1 - 2 * c_msb * r_msb1.
401        //     Note that b0 + b1 = r_msb + c_msb - 2*c_msb*r_msb = r_msb XOR c_msb.
402        //     All the above operations can be done locally as c_msb is known to parties 0 and 1.
403        let two = constant_scalar(&g, 2, st)?;
404        let b0 = r_msb0
405            .subtract(r_msb0.multiply(c_msb.clone())?.multiply(two.clone())?)?
406            .add(c_msb.clone())?;
407        let b1 = r_msb1.subtract(r_msb1.multiply(c_msb)?.multiply(two)?)?;
408
409        // 11. Parties 0 and 1 compute 2-out-of-2 shares of y' = c_truncated_mod - r_truncated + b * 2^(st_size-1-k).
410        //     This value is equal to the desired result floor(x/2^k) + w.
411        // 2^(st_size-1-k)
412        let power2 = constant_scalar(&g, 1u128 << (st_size - 1 - self.k), st)?;
413        // y' = c_truncated_mod - r_truncated + b * 2^(st_size-1-k)
414        // This is 2-out-of-2 sharing of the result
415        let y_prime0 = b0
416            .multiply(power2.clone())?
417            .subtract(r_truncated0)?
418            .add(c_truncated_mod)?;
419        let y_prime1 = b1.multiply(power2)?.subtract(r_truncated1)?;
420
421        // 12. Party 0 masks y'0 with a random value y0 from party 2 as y_tilde0 = y'0 - y0 and sends it to party 1.
422        let y_tilde0 = y_prime0.subtract(y0.clone())?;
423        let y_tilde0_sent = y_tilde0.nop()?;
424        y_tilde0_sent.add_annotation(NodeAnnotation::Send(0, 1))?;
425        // 13. Party 1 masks y'1 with a random value y2 from party 2 as y_tilde1 = y'1 - y2 and sends it to party 0.
426        let y_tilde1 = y_prime1.subtract(y2.clone())?;
427        let y_tilde1_sent = y_tilde1.nop()?;
428        y_tilde1_sent.add_annotation(NodeAnnotation::Send(1, 0))?;
429
430        // 14. Parties 0 and 1 compute y1 = y_tilde0 + y_tilde1 = y' - y0 - y2.
431        //     Together with y0 and y2 this value constitute the sharing of the truncation output.
432        let y1 = {
433            let sum01 = y_tilde0_sent.add(y_tilde1_sent)?;
434            if st.is_signed() {
435                // 14!. If input is signed, we should remove modulus/2^(k+2) after truncation since
436                //      Truncate(input + modulus/4, 2^k) = Truncate(input, 2^k) + modulus/2^(k+2)
437                let mod_fraction = constant_scalar(&g, 1u128 << (st_size - 2 - self.k), st)?;
438                sum01.subtract(mod_fraction)?
439            } else {
440                sum01
441            }
442        };
443
444        // 15. The protocol returns (y0, y1, y2).
445        g.create_tuple(vec![y0, y1, y2])?.set_as_output()?;
446
447        g.finalize()?;
448        Ok(g)
449    }
450
451    fn get_name(&self) -> String {
452        format!("TruncateMPC2K({})", self.k)
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::{add_u128, subtract_vectors_u128};
    use crate::data_types::{array_type, scalar_type, ScalarType, INT128, UINT128};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::inline::inline_ops::{InlineConfig, InlineMode};
    use crate::mpc::mpc_compiler::{prepare_for_mpc_evaluation, IOStatus, PARTIES};

    // Builds a context computing `truncate(input, scale)` and compiles it for MPC evaluation
    // with the given input owner and output parties.
    fn prepare_context(
        t: Type,
        party_id: IOStatus,
        output_parties: Vec<IOStatus>,
        scale: u128,
        inline_config: InlineConfig,
    ) -> Result<Context> {
        let c = simple_context(|g| {
            let i = g.input(t)?;
            g.truncate(i, scale)
        })?;

        prepare_for_mpc_evaluation(c, vec![vec![party_id]], vec![output_parties], inline_config)
    }

    // Encodes `input` as the graph input. Public and party-owned inputs are passed as plain
    // values. Shared inputs are split into the fixed shares (input - 3, 1, 2), which sum to input.
    fn prepare_input(input: Vec<u128>, input_status: IOStatus, t: Type) -> Result<Vec<Value>> {
        let mpc_input = match t {
            Type::Scalar(st) => {
                if input_status == IOStatus::Public || matches!(input_status, IOStatus::Party(_)) {
                    return Ok(vec![Value::from_scalar(input[0], st)?]);
                }

                // shares of input = (input - 3, 1, 2)
                let mut shares_vec = vec![];
                shares_vec.push(Value::from_scalar(
                    subtract_vectors_u128(&input, &[3], st.get_modulus())?[0],
                    st,
                )?);

                for i in 1..PARTIES as u64 {
                    shares_vec.push(Value::from_scalar(i, st)?);
                }
                shares_vec
            }
            Type::Array(_, st) => {
                if input_status == IOStatus::Public || matches!(input_status, IOStatus::Party(_)) {
                    return Ok(vec![Value::from_flattened_array(&input, st)?]);
                }

                // shares of input = (input - 3, 1, 2)
                let mut shares_vec = vec![];
                let threes = vec![3; input.len()];
                let first_share = subtract_vectors_u128(&input, &threes, st.get_modulus())?;
                shares_vec.push(Value::from_flattened_array(&first_share, st)?);

                for i in 1..PARTIES {
                    let share = vec![i; input.len()];
                    shares_vec.push(Value::from_flattened_array(&share, st)?);
                }
                shares_vec
            }
            _ => {
                panic!("Shouldn't be here");
            }
        };

        Ok(vec![Value::from_vector(mpc_input)])
    }

    // Checks every output element against the expected value. With `equal == true`, the
    // distance must be at most 1 (the truncation protocols may add a 1-bit error); with
    // `equal == false`, it must be larger than 1.
    // Signed values are compared as 128-bit two's complement integers and unsigned values
    // as plain u128 numbers. The full 128 bits are checked and the subtraction cannot overflow.
    fn compare_truncate_output(
        output: &[u128],
        expected: &[u128],
        equal: bool,
        st: ScalarType,
    ) -> Result<()> {
        if output.len() != expected.len() {
            return Err(runtime_error!(
                "Output and expected values have different lengths"
            ));
        }
        for (&out_value, &expected_value) in output.iter().zip(expected.iter()) {
            let dif = if st.is_signed() {
                (out_value as i128)
                    .wrapping_sub(expected_value as i128)
                    .unsigned_abs()
            } else if out_value >= expected_value {
                out_value - expected_value
            } else {
                expected_value - out_value
            };
            if equal && dif > 1 {
                return Err(runtime_error!("Output is too far from expected"));
            }
            if !equal && dif <= 1 {
                return Err(runtime_error!("Output is too close to expected"));
            }
        }

        Ok(())
    }

    // Evaluates the MPC graph on `inputs` and compares the (possibly reconstructed) result
    // with `expected`. If `output_parties` is empty, the output shares are summed modulo the
    // type's modulus first.
    fn check_output(
        mpc_graph: Graph,
        inputs: Vec<Value>,
        expected: Vec<u128>,
        output_parties: Vec<IOStatus>,
        t: Type,
    ) -> Result<()> {
        let output = random_evaluate(mpc_graph.clone(), inputs)?;
        let st = t.get_scalar_type();

        if output_parties.is_empty() {
            let out = output.access_vector(|v| {
                let modulus = st.get_modulus();
                let mut res = vec![0; expected.len()];
                for val in v {
                    let arr = match t.clone() {
                        Type::Scalar(_) => {
                            vec![val.to_u128(st)?]
                        }
                        Type::Array(_, _) => val.to_flattened_array_u128(t.clone())?,
                        _ => {
                            panic!("Shouldn't be here");
                        }
                    };
                    for i in 0..expected.len() {
                        res[i] = add_u128(res[i], arr[i], modulus);
                    }
                }
                Ok(res)
            })?;
            compare_truncate_output(&out, &expected, true, st)?;
        } else {
            assert!(output.check_type(t.clone())?);
            let out = match t.clone() {
                Type::Scalar(_) => vec![output.to_u128(st)?],
                Type::Array(_, _) => output.to_flattened_array_u128(t.clone())?,
                _ => {
                    panic!("Shouldn't be here");
                }
            };
            compare_truncate_output(&out, &expected, true, st)?;
        }

        Ok(())
    }

    fn truncate_helper(st: ScalarType, scale: u128) -> Result<()> {
        let helper = |t: Type,
                      input: Vec<u128>,
                      input_status: IOStatus,
                      output_parties: Vec<IOStatus>,
                      inline_config: InlineConfig|
         -> Result<()> {
            let mpc_context = prepare_context(
                t.clone(),
                input_status.clone(),
                output_parties.clone(),
                scale,
                inline_config,
            )?;
            let mpc_graph = mpc_context.get_main_graph()?;

            let mpc_input = prepare_input(input.clone(), input_status.clone(), t.clone())?;

            // Signed inputs are interpreted as 128-bit two's complement values, so large
            // INT128 inputs are checked in full instead of only their low 64 bits.
            let expected: Vec<u128> = if t.get_scalar_type().is_signed() {
                input
                    .iter()
                    .map(|&x| ((x as i128) / (scale as i128)) as u128)
                    .collect()
            } else {
                input.iter().map(|&x| x / scale).collect()
            };
            check_output(mpc_graph, mpc_input, expected, output_parties, t.clone())?;

            Ok(())
        };
        let inline_config_simple = InlineConfig {
            default_mode: InlineMode::Simple,
            ..Default::default()
        };
        let helper_runs = |inputs: Vec<u128>, t: Type| -> Result<()> {
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Shared,
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![IOStatus::Party(0)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Party(2),
                vec![],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Public,
                vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
                inline_config_simple.clone(),
            )?;
            helper(
                t.clone(),
                inputs.clone(),
                IOStatus::Public,
                vec![],
                inline_config_simple.clone(),
            )?;
            Ok(())
        };
        // This test should fail with a probability depending on input and the number of runs
        let helper_malformed = |inputs: Vec<u128>, t: Type, runs: u64| -> Result<()> {
            for _ in 0..runs {
                helper_runs(inputs.clone(), t.clone())?;
            }
            Ok(())
        };

        helper_runs(vec![0], scalar_type(st))?;
        helper_runs(vec![1000], scalar_type(st))?;
        helper_runs(vec![0, 0], array_type(vec![2], st))?;
        helper_runs(vec![2000, 255], array_type(vec![2], st))?;

        if scale.is_power_of_two() && !st.is_signed() {
            // 2^127 - 1, this is a maximal UINT128 value that can be truncated without errors by TruncateMPC2K
            helper_runs(vec![(1u128 << 127) - 1], scalar_type(st))?;
        }

        if st.is_signed() {
            // -1
            helper_runs(vec![u128::MAX], scalar_type(st))?;
            // -1000
            helper_runs(vec![u128::MAX - 999], scalar_type(st))?;
            // [-10. -1024]
            helper_runs(
                vec![u128::MAX - 9, u128::MAX - 1023],
                array_type(vec![2], st),
            )?;
            if scale.is_power_of_two() {
                // - 2^126, this is a minimal INT128 value that can be truncated without errors by TruncateMPC2K
                helper_runs(vec![-(1i128 << 126) as u128], scalar_type(st))?;
                // 2^126-1, this is a maximal INT128 value that can be truncated without errors by TruncateMPC2K
                helper_runs(vec![(1u128 << 126) - 1], scalar_type(st))?;
            }
        }

        // Probabilistic tests of TruncateMPC for big values in absolute size
        if scale != 1 && !scale.is_power_of_two() {
            // 2^127 - 1, should fail with probability 1 - 2^(-40)
            assert!(helper_malformed(vec![i128::MAX as u128], scalar_type(st), 40).is_err());
            // -2^127, should fail with probability 1 - 2^(-40)
            assert!(helper_malformed(vec![i128::MIN as u128], scalar_type(st), 40).is_err());
            // [2^127 - 1, 2^127 - 2]
            assert!(helper_malformed(
                vec![i128::MAX as u128, i128::MAX as u128 - 1],
                array_type(vec![2], st),
                40
            )
            .is_err());
            // [-2^127, -2^127 + 1]
            assert!(helper_malformed(
                vec![1u128 << 127, (1u128 << 127) + 1],
                array_type(vec![2], st),
                40
            )
            .is_err());
        }
        Ok(())
    }

    #[test]
    fn test_truncate() -> Result<()> {
        truncate_helper(UINT128, 1)?;
        truncate_helper(UINT128, 1 << 3)?;
        truncate_helper(UINT128, 1 << 7)?;
        truncate_helper(UINT128, 1 << 29)?;
        truncate_helper(UINT128, 1 << 31)?;

        truncate_helper(INT128, 1)?;
        truncate_helper(INT128, 15)?;
        truncate_helper(INT128, 1 << 3)?;
        truncate_helper(INT128, 1 << 7)?;
        truncate_helper(INT128, 1 << 29)?;
        truncate_helper(INT128, (1 << 29) - 1)?;

        assert!(truncate_helper(UINT128, 15).is_err());
        Ok(())
    }
}