ciphercore_base/mpc/
mpc_compiler.rs

1use crate::custom_ops::{run_instantiation_pass, ContextMappings, CustomOperation, MappedContext};
2use crate::data_types::{array_type, scalar_type, tuple_type, Type, TypePointer, BIT, UINT64};
3use crate::data_values::Value;
4use crate::errors::Result;
5use crate::evaluators::Evaluator;
6use crate::graphs::{
7    copy_node_name, create_context, Context, Graph, Node, NodeAnnotation, Operation,
8};
9use crate::inline::inline_ops::{inline_operations, InlineConfig};
10use crate::mpc::mpc_apply_permutation::ApplyPermutationMPC;
11use crate::mpc::mpc_conversion::{A2BMPC, B2AMPC};
12use crate::mpc::mpc_truncate::{TruncateMPC, TruncateMPC2K};
13use crate::optimizer::optimize::optimize_context;
14
15use std::collections::HashMap;
16use std::collections::HashSet;
17
18use super::mpc_arithmetic::{add_mpc, general_multiply_mpc, subtract_mpc};
19use super::mpc_psi::JoinMPC;
20use super::mpc_radix_sort::RadixSortMPC;
21use super::resharing::{get_nodes_to_reshare, reshare};
22
/// Number of parties: the compiler implements the ABY3 protocol, which involves 3 parties.
pub const PARTIES: usize = 3;

/// Ownership status of input/output nodes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum IOStatus {
    /// The input/output is public, i.e. known to all the parties.
    Public,
    /// The input is privately owned by party `i`; the output is revealed to party `i`.
    Party(u64),
    /// The input/output is secret-shared, i.e. unknown to all the parties.
    Shared,
}

/// Bit length of PRF keys.
pub const KEY_LENGTH: u64 = 128;
36
37/// Checks whether a private tuple value has the correct number of shares
38pub(super) fn check_private_tuple(v: Vec<TypePointer>) -> Result<()> {
39    if v.len() != PARTIES {
40        return Err(runtime_error!(
41            "Private tuple should have {} values, but {} provided",
42            PARTIES,
43            v.len()
44        ));
45    }
46    let t = (*v[0]).clone();
47    for coef in v.iter().skip(1) {
48        if t != **coef {
49            return Err(runtime_error!(
50                "Private tuple should have value of the same type"
51            ));
52        }
53    }
54    Ok(())
55}
56
57fn is_one_node_private(nodes: &[Node], private_nodes: &HashSet<Node>) -> bool {
58    for node in nodes {
59        if private_nodes.contains(node) {
60            return true;
61        }
62    }
63    false
64}
65
66fn are_all_nodes_private(nodes: &[Node], private_nodes: &HashSet<Node>) -> bool {
67    for node in nodes {
68        if !private_nodes.contains(node) {
69            return false;
70        }
71    }
72    true
73}
74
75/// Generate random share of a given node: (node + alpha_0, alpha_1, alpha_2),
76/// where alpha_i = PRF(k_i, 0) - PRF(k_(i+1 % 3), 0).
77/// If no node is given, generate random shares of zero.
78/// The input node is given as a pair of the node and the index of the party that wants to share this node.
79/// The iv of any PRF call is set to zero here, but it will be changed to a unique
80/// number by `uniquify_prf_id` when preparing an MPC graph for evaluation.
81fn recursively_generate_node_shares(
82    g: Graph,
83    prf_keys: Vec<Node>,
84    t: Type,
85    node_to_share: Option<(Node, IOStatus)>,
86) -> Result<Vec<Node>> {
87    match t {
88        Type::Scalar(_) | Type::Array(_, _) => {
89            let mut random_shares = vec![];
90            for key in prf_keys {
91                let prf_i = g.prf(key, 0, t.clone())?;
92                random_shares.push(prf_i);
93            }
94            let mut node_shares = vec![];
95            for i in 0..PARTIES {
96                let ip1 = (i + 1) % PARTIES;
97                let alpha = g.subtract(random_shares[i].clone(), random_shares[ip1].clone())?;
98                node_shares.push(alpha);
99            }
100
101            match node_to_share {
102                Some((node, IOStatus::Party(id))) => {
103                    node_shares[id as usize] = node_shares[id as usize].add(node)?;
104                }
105                Some((node, IOStatus::Public)) => {
106                    node_shares[0] = node_shares[0].add(node)?;
107                }
108                Some((_, IOStatus::Shared)) => {
109                    return Err(runtime_error!(
110                        "Given node must belong to a party or be public"
111                    ));
112                }
113                None => (),
114            }
115            Ok(node_shares)
116        }
117        Type::Tuple(types) => {
118            let mut unpacked_node_shares = vec![vec![]; PARTIES];
119            for (i, sub_t) in types.iter().enumerate() {
120                let sub_node_to_share = match node_to_share.clone() {
121                    Some((node, party_id)) => Some((node.tuple_get(i as u64)?, party_id)),
122                    None => None,
123                };
124                let sub_node_shares = recursively_generate_node_shares(
125                    g.clone(),
126                    prf_keys.clone(),
127                    (**sub_t).clone(),
128                    sub_node_to_share,
129                )?;
130                for party_id in 0..PARTIES {
131                    unpacked_node_shares[party_id].push(sub_node_shares[party_id].clone());
132                }
133            }
134            let mut node_shares = vec![];
135            for unpacked_share in unpacked_node_shares {
136                node_shares.push(g.create_tuple(unpacked_share)?);
137            }
138            Ok(node_shares)
139        }
140        Type::Vector(length, element_type) => {
141            let mut unpacked_node_shares = vec![vec![]; PARTIES];
142            for i in 0..length {
143                let sub_node_to_share = match node_to_share.clone() {
144                    Some((node, party_id)) => {
145                        let i_node =
146                            g.constant(scalar_type(UINT64), Value::from_scalar(i, UINT64)?)?;
147                        Some((node.vector_get(i_node)?, party_id))
148                    }
149                    None => None,
150                };
151                let sub_node_shares = recursively_generate_node_shares(
152                    g.clone(),
153                    prf_keys.clone(),
154                    (*element_type).clone(),
155                    sub_node_to_share,
156                )?;
157                for party_id in 0..PARTIES {
158                    unpacked_node_shares[party_id].push(sub_node_shares[party_id].clone());
159                }
160            }
161            let mut node_shares = vec![];
162            for unpacked_share in unpacked_node_shares {
163                node_shares.push(g.create_vector((*element_type).clone(), unpacked_share)?);
164            }
165            Ok(node_shares)
166        }
167        Type::NamedTuple(names_types) => {
168            let mut unpacked_node_shares = vec![vec![]; PARTIES];
169            for (name, sub_t) in &names_types {
170                let sub_node_to_share = match node_to_share.clone() {
171                    Some((node, party_id)) => {
172                        Some((node.named_tuple_get((*name).clone())?, party_id))
173                    }
174                    None => None,
175                };
176                let sub_node_shares = recursively_generate_node_shares(
177                    g.clone(),
178                    prf_keys.clone(),
179                    (**sub_t).clone(),
180                    sub_node_to_share,
181                )?;
182                for party_id in 0..PARTIES {
183                    unpacked_node_shares[party_id]
184                        .push(((*name).clone(), sub_node_shares[party_id].clone()));
185                }
186            }
187            let mut node_shares = vec![];
188            for unpacked_share in unpacked_node_shares {
189                node_shares.push(g.create_named_tuple(unpacked_share)?);
190            }
191            Ok(node_shares)
192        }
193    }
194}
195
196pub(crate) fn get_node_shares(
197    g: Graph,
198    prf_keys: Node,
199    t: Type,
200    node_to_share: Option<(Node, IOStatus)>,
201) -> Result<Vec<Node>> {
202    let mut prf_keys_vec = vec![];
203    for i in 0..PARTIES {
204        let key = g.tuple_get(prf_keys.clone(), i as u64)?;
205        prf_keys_vec.push(key);
206    }
207    recursively_generate_node_shares(g, prf_keys_vec, t, node_to_share)
208}
209
210pub(super) fn get_zero_shares(g: Graph, prf_keys: Node, t: Type) -> Result<Vec<Node>> {
211    get_node_shares(g, prf_keys, t, None)
212}
213
214/// Returns the hash set of the private nodes of the given graph,
215/// a Boolean value indicating whether PRF keys should be used for multiplication,
216/// a Boolean value indicating whether PRF keys should be used for B2A.
217/// a Boolean value indicating whether PRF keys should be used for Truncate2k.
218pub(super) fn propagate_private_annotations(
219    graph: Graph,
220    is_input_private: Vec<bool>,
221) -> Result<(HashSet<Node>, bool, bool, bool)> {
222    let mut private_nodes: HashSet<Node> = HashSet::new();
223    let mut use_prf_for_mul = false;
224    let mut use_prf_for_b2a = false;
225    let mut use_prf_for_truncate2k = false;
226    let mut input_id = 0usize;
227    for node in graph.get_nodes() {
228        let op = node.get_operation();
229        match op {
230            Operation::Input(_) => {
231                if is_input_private[input_id] {
232                    private_nodes.insert(node);
233                }
234                input_id += 1;
235            }
236            Operation::Add
237            | Operation::Subtract
238            | Operation::Multiply
239            | Operation::MixedMultiply
240            | Operation::Dot
241            | Operation::Matmul
242            | Operation::Gemm(_, _)
243            | Operation::Join(_, _)
244            | Operation::JoinWithColumnMasks(_, _)
245            | Operation::A2B
246            | Operation::B2A(_)
247            | Operation::PermuteAxes(_)
248            | Operation::ArrayToVector
249            | Operation::TupleGet(_)
250            | Operation::NamedTupleGet(_)
251            | Operation::VectorToArray
252            | Operation::GetSlice(_)
253            | Operation::Reshape(_)
254            | Operation::Sum(_)
255            | Operation::CumSum(_)
256            | Operation::Get(_)
257            | Operation::CreateTuple
258            | Operation::CreateNamedTuple(_)
259            | Operation::CreateVector(_)
260            | Operation::Stack(_)
261            | Operation::ApplyPermutation(_)
262            | Operation::Sort(_)
263            | Operation::Concatenate(_)
264            | Operation::Zip
265            | Operation::Repeat(_) => {
266                let dependencies = node.get_node_dependencies();
267                if is_one_node_private(&dependencies, &private_nodes) {
268                    private_nodes.insert(node.clone());
269                    if matches!(
270                        op,
271                        Operation::Join(_, _) | Operation::JoinWithColumnMasks(_, _)
272                    ) {
273                        use_prf_for_mul = true;
274                    }
275                }
276                if ([
277                    Operation::Multiply,
278                    Operation::Dot,
279                    Operation::Matmul,
280                    Operation::A2B,
281                ]
282                .contains(&op)
283                    || matches!(op, Operation::Gemm(_, _)))
284                    && are_all_nodes_private(&dependencies, &private_nodes)
285                {
286                    use_prf_for_mul = true;
287                }
288                if matches!(op, Operation::B2A(_))
289                    && are_all_nodes_private(&dependencies, &private_nodes)
290                {
291                    use_prf_for_mul = true;
292                    use_prf_for_b2a = true;
293                }
294                if matches!(op, Operation::Sort(_)) && private_nodes.contains(&dependencies[0]) {
295                    use_prf_for_mul = true;
296                }
297                if matches!(op, Operation::ApplyPermutation(_))
298                    && private_nodes.contains(&dependencies[1])
299                {
300                    use_prf_for_mul = true;
301                }
302                if matches!(op, Operation::MixedMultiply)
303                    && private_nodes.contains(&dependencies[1])
304                {
305                    use_prf_for_mul = true;
306                }
307            }
308            Operation::Truncate(scale) => {
309                let dependencies = node.get_node_dependencies();
310                if is_one_node_private(&dependencies, &private_nodes) {
311                    private_nodes.insert(node.clone());
312                }
313
314                if are_all_nodes_private(&dependencies, &private_nodes) {
315                    use_prf_for_mul = true;
316                    if scale.is_power_of_two() {
317                        use_prf_for_truncate2k = true;
318                    }
319                }
320            }
321            Operation::Constant(_, _) | Operation::Zeros(_) | Operation::Ones(_) => {
322                // Constants are always public
323            }
324            Operation::VectorGet => {
325                let dependencies = node.get_node_dependencies();
326                if private_nodes.contains(&dependencies[1]) {
327                    return Err(runtime_error!("VectorGet can't have a private index"));
328                }
329                if private_nodes.contains(&dependencies[0]) {
330                    private_nodes.insert(node.clone());
331                }
332            }
333            _ => {
334                return Err(runtime_error!(
335                    "MPC compiler can't preprocess inputs of {}",
336                    op
337                ));
338            }
339        }
340    }
341    Ok((
342        private_nodes,
343        use_prf_for_mul,
344        use_prf_for_b2a,
345        use_prf_for_truncate2k,
346    ))
347}
348
349pub(super) fn is_array_shared(array: &Node) -> Result<bool> {
350    Ok(array.get_type()?.is_tuple())
351}
352
353pub(super) fn compile_to_mpc_graph(
354    in_graph: Graph,
355    is_input_private: Vec<bool>,
356    out_context: Context,
357    out_mapping: &mut ContextMappings,
358) -> Result<Graph> {
359    let out_graph = out_context.create_graph()?;
360
361    let (private_nodes, use_prf_for_mul, use_prf_for_b2a, use_prf_for_truncate2k) =
362        propagate_private_annotations(in_graph.clone(), is_input_private)?;
363    // Input tuple of PRF keys for multiplication if needed
364    // If created, these are the first input node of a graph
365    let prf_keys_mul = if use_prf_for_mul {
366        // PRF key type
367        let key_t = array_type(vec![KEY_LENGTH], BIT);
368        let key_inputs = vec![key_t; PARTIES];
369        let keys_type = tuple_type(key_inputs);
370        let node = out_graph.input(keys_type)?;
371        node.add_annotation(NodeAnnotation::PRFMultiplication)?;
372        Some(node)
373    } else {
374        None
375    };
376    // Input tuple of PRF keys for B2A if needed
377    // If created, these are the second input node of a graph
378    let prf_keys_b2a = if use_prf_for_b2a {
379        // PRF key type
380        let key_t = array_type(vec![KEY_LENGTH], BIT);
381        let key_inputs = vec![key_t; PARTIES];
382        let key_triple_type = tuple_type(key_inputs);
383        let keys_type = tuple_type(vec![key_triple_type; 2]);
384        let node = out_graph.input(keys_type)?;
385        node.add_annotation(NodeAnnotation::PRFB2A)?;
386        Some(node)
387    } else {
388        None
389    };
390    // Input tuple of PRF keys for Truncate if needed
391    // If created, these are the second input node of a graph
392    let prf_keys_truncate2k = if use_prf_for_truncate2k {
393        // PRF key type
394        let key_t = array_type(vec![KEY_LENGTH], BIT);
395        let node = out_graph.input(key_t)?;
396        node.add_annotation(NodeAnnotation::PRFTruncate)?;
397        Some(node)
398    } else {
399        None
400    };
401
402    // This helper closure applies the operation (op) of a given node (node_to_be_private)
403    // to given input nodes (node_dependencies).
404    // If the node is public, the operation is applied as in the plaintext evaluator.
405    // If the node is private, the helper extracts secret shares from input nodes
406    // or promote public nodes to private and apply the operation on each share.
407    // The following assumptions should be satisfied by the helper inputs:
408    // - every input of a public node is a public node;
409    // - a constant node is public;
410    // - every private node has at least one private input node;
411    // all the public inputs are promoted to private (except for VectorGet);
412    // - a private VectorGet node has a private input vector node and a public index node.
413    // TODO: extract it to the method.
414    let apply_op = |node_to_be_private: Node,
415                    op: Operation,
416                    node_dependencies: Vec<Node>,
417                    old_dependencies: Vec<Node>|
418     -> Result<Node> {
419        if !private_nodes.contains(&node_to_be_private) {
420            return out_graph.add_node(node_dependencies, vec![], op);
421        }
422        if let Operation::Input(t) = op.clone() {
423            let tuple_t = tuple_type(vec![t; PARTIES]);
424            return out_graph.input(tuple_t);
425        }
426        let mut result_shares = vec![];
427        for i in 0..PARTIES {
428            let share = match op.clone() {
429                Operation::VectorGet => vec![
430                    out_graph.tuple_get(node_dependencies[0].clone(), i as u64)?,
431                    node_dependencies[1].clone(),
432                ],
433                _ => {
434                    let mut share_vec = vec![];
435                    for (j, old_node) in old_dependencies.iter().enumerate() {
436                        if private_nodes.contains(old_node) {
437                            // if node is private, take the corresponding share
438                            let new_node =
439                                out_graph.tuple_get(node_dependencies[j].clone(), i as u64)?;
440                            share_vec.push(new_node);
441                        } else {
442                            // if node is public, we promote it to private by splitting into (node, 0, 0)
443                            // where 0 is a node containing zeros
444                            if i == 0 {
445                                share_vec.push(node_dependencies[j].clone())
446                            } else {
447                                share_vec.push(out_graph.zeros(node_dependencies[j].get_type()?)?);
448                            }
449                        }
450                    }
451                    share_vec
452                }
453            };
454            let result_share = out_graph.add_node(share, vec![], op.clone())?;
455            result_shares.push(result_share);
456        }
457        out_graph.create_tuple(result_shares)
458    };
459
460    let nodes_to_reshare = get_nodes_to_reshare(&in_graph, &private_nodes)?;
461
462    for node in in_graph.get_nodes() {
463        let op = node.get_operation();
464        let mut new_node = match op.clone() {
465            Operation::Input(_) => apply_op(node.clone(), op, vec![], vec![])?,
466            Operation::Add | Operation::Subtract => {
467                let dependencies = node.get_node_dependencies();
468                let input0 = dependencies[0].clone();
469                let input1 = dependencies[1].clone();
470                let new_input0 = out_mapping.get_node(input0);
471                let new_input1 = out_mapping.get_node(input1);
472                match op.clone() {
473                    Operation::Add => add_mpc(new_input0, new_input1)?,
474                    Operation::Subtract => subtract_mpc(new_input0, new_input1)?,
475                    _ => panic!("Should not be here"),
476                }
477            }
478            Operation::Multiply
479            | Operation::MixedMultiply
480            | Operation::Dot
481            | Operation::Matmul
482            | Operation::Gemm(_, _) => {
483                let dependencies = node.get_node_dependencies();
484                let input0 = dependencies[0].clone();
485                let input1 = dependencies[1].clone();
486                let new_input0 = out_mapping.get_node(input0.clone());
487                let new_input1 = out_mapping.get_node(input1.clone());
488                // Don't reshare the product node.
489                // Let compiler to check the following nodes and decide.
490                general_multiply_mpc(new_input0, new_input1, op, prf_keys_mul.clone(), false)?
491            }
492            Operation::Join(join_t, headers) | Operation::JoinWithColumnMasks(join_t, headers) => {
493                let dependencies = node.get_node_dependencies();
494                let input0 = dependencies[0].clone();
495                let input1 = dependencies[1].clone();
496                let new_input0 = out_mapping.get_node(input0.clone());
497                let new_input1 = out_mapping.get_node(input1.clone());
498                let mut headers_vec = vec![];
499                for headers_pair in headers {
500                    headers_vec.push(headers_pair);
501                }
502                let custom_op = match op {
503                    Operation::Join(_, _) => CustomOperation::new(JoinMPC {
504                        join_t,
505                        headers: headers_vec,
506                        has_column_masks: false,
507                    }),
508                    Operation::JoinWithColumnMasks(_, _) => CustomOperation::new(JoinMPC {
509                        join_t,
510                        headers: headers_vec,
511                        has_column_masks: true,
512                    }),
513                    _ => {
514                        return Err(runtime_error!("Shouldn't be here"));
515                    }
516                };
517
518                if private_nodes.contains(&node) {
519                    // If one input set is private, MPC protocols requires invoking PRFs.
520                    // Thus, PRF keys must be provided.
521                    let keys = match prf_keys_mul {
522                        Some(ref k) => k.clone(),
523                        None => {
524                            return Err(runtime_error!("Propagation of annotations failed"));
525                        }
526                    };
527                    out_graph.custom_op(
528                        custom_op,
529                        vec![new_input0.clone(), new_input1.clone(), keys],
530                    )?
531                } else {
532                    out_graph.custom_op(custom_op, vec![new_input0.clone(), new_input1.clone()])?
533                }
534            }
535            Operation::ApplyPermutation(inverse_permutation) => {
536                let dependencies = node.get_node_dependencies();
537                let input = dependencies[0].clone();
538                let permutation = dependencies[1].clone();
539                let new_input = out_mapping.get_node(input.clone());
540                let new_permutation = out_mapping.get_node(permutation.clone());
541                let custom_op = CustomOperation::new(ApplyPermutationMPC {
542                    inverse_permutation,
543                    reveal_output: false,
544                });
545                if private_nodes.contains(&permutation) {
546                    // If the permutation is private, MPC protocols requires invoking PRFs.
547                    // Thus, PRF keys must be provided.
548                    let keys = match prf_keys_mul {
549                        Some(ref k) => k.clone(),
550                        None => {
551                            return Err(runtime_error!("Propagation of annotations failed"));
552                        }
553                    };
554                    out_graph.custom_op(custom_op, vec![new_input, new_permutation, keys])?
555                } else {
556                    out_graph.custom_op(custom_op, vec![new_input, new_permutation])?
557                }
558            }
559            Operation::Sort(key) => {
560                let dependencies = node.get_node_dependencies();
561                let mut mapped_dependencies = dependencies
562                    .into_iter()
563                    .map(|d| out_mapping.get_node(d))
564                    .collect::<Vec<Node>>();
565                let custom_op = CustomOperation::new(RadixSortMPC::new(key));
566                if private_nodes.contains(&node) {
567                    // If one input set is private, MPC protocols requires invoking PRFs.
568                    // Thus, PRF keys must be provided.
569                    let keys = match prf_keys_mul {
570                        Some(ref k) => k.clone(),
571                        None => {
572                            return Err(runtime_error!("Propagation of annotations failed"));
573                        }
574                    };
575                    mapped_dependencies.push(keys);
576                }
577                out_graph.custom_op(custom_op, mapped_dependencies)?
578            }
579            Operation::Truncate(scale) => {
580                let dependencies = node.get_node_dependencies();
581                let input = dependencies[0].clone();
582                let new_input = out_mapping.get_node(input.clone());
583                let custom_op = if scale.is_power_of_two() {
584                    let k = scale.trailing_zeros() as u64;
585                    CustomOperation::new(TruncateMPC2K { k })
586                } else {
587                    CustomOperation::new(TruncateMPC { scale })
588                };
589                if private_nodes.contains(&input) {
590                    // If input is private, the MPC protocol requires invoking PRFs.
591                    // Thus, PRF keys must be provided.
592                    let prf_mul_keys = match prf_keys_mul {
593                        Some(ref k) => k.clone(),
594                        None => {
595                            return Err(runtime_error!("Propagation of annotations failed"));
596                        }
597                    };
598
599                    if scale.is_power_of_two() {
600                        let prf_truncate_keys = match prf_keys_truncate2k {
601                            Some(ref k) => k.clone(),
602                            None => {
603                                return Err(runtime_error!("Propagation of annotations failed"));
604                            }
605                        };
606
607                        out_graph.custom_op(
608                            custom_op,
609                            vec![new_input.clone(), prf_mul_keys, prf_truncate_keys],
610                        )?
611                    } else {
612                        out_graph.custom_op(custom_op, vec![new_input.clone(), prf_mul_keys])?
613                    }
614                } else {
615                    out_graph.custom_op(custom_op, vec![new_input.clone()])?
616                }
617            }
618            Operation::A2B => {
619                let dependencies = node.get_node_dependencies();
620                let input = dependencies[0].clone();
621                let new_input = out_mapping.get_node(input.clone());
622                let custom_op = CustomOperation::new(A2BMPC {});
623                if private_nodes.contains(&input) {
624                    // If input is private, the MPC protocol requires invoking PRFs.
625                    // Thus, PRF keys must be provided.
626                    let keys = match prf_keys_mul {
627                        Some(ref k) => k.clone(),
628                        None => {
629                            return Err(runtime_error!("Propagation of annotations failed"));
630                        }
631                    };
632                    out_graph.custom_op(custom_op, vec![new_input.clone(), keys])?
633                } else {
634                    out_graph.custom_op(custom_op, vec![new_input.clone()])?
635                }
636            }
637            Operation::B2A(st) => {
638                let dependencies = node.get_node_dependencies();
639                let input = dependencies[0].clone();
640                let new_input = out_mapping.get_node(input.clone());
641                let custom_op = CustomOperation::new(B2AMPC { st });
642                if private_nodes.contains(&input) {
643                    // If input is private, the MPC protocol requires invoking PRFs.
644                    // Thus, PRF keys must be provided.
645                    let keys_mul = match prf_keys_mul {
646                        Some(ref k) => k.clone(),
647                        None => {
648                            return Err(runtime_error!("Propagation of annotations failed"));
649                        }
650                    };
651                    let keys_b2a = match prf_keys_b2a {
652                        Some(ref k) => k.clone(),
653                        None => {
654                            return Err(runtime_error!("Propagation of annotations failed"));
655                        }
656                    };
657                    out_graph.custom_op(custom_op, vec![new_input.clone(), keys_mul, keys_b2a])?
658                } else {
659                    out_graph.custom_op(custom_op, vec![new_input.clone()])?
660                }
661            }
662            Operation::Constant(t, v) => out_graph.constant(t, v)?,
663            Operation::Zeros(t) => out_graph.zeros(t)?,
664            Operation::Ones(t) => out_graph.ones(t)?,
665            Operation::PermuteAxes(_)
666            | Operation::ArrayToVector
667            | Operation::VectorToArray
668            | Operation::TupleGet(_)
669            | Operation::NamedTupleGet(_)
670            | Operation::GetSlice(_)
671            | Operation::Reshape(_)
672            | Operation::Sum(_)
673            | Operation::CumSum(_)
674            | Operation::Get(_)
675            | Operation::Repeat(_) => {
676                let dependencies = node.get_node_dependencies();
677                let input = dependencies[0].clone();
678                let new_input = out_mapping.get_node(input.clone());
679                apply_op(input, op, vec![new_input], dependencies)?
680            }
681            Operation::VectorGet => {
682                let dependencies = node.get_node_dependencies();
683                let vector = dependencies[0].clone();
684                let index = dependencies[1].clone();
685                let new_vector = out_mapping.get_node(vector.clone());
686                let new_index = out_mapping.get_node(index.clone());
687
688                apply_op(vector, op, vec![new_vector, new_index], vec![])?
689            }
690            Operation::CreateTuple
691            | Operation::CreateNamedTuple(_)
692            | Operation::CreateVector(_)
693            | Operation::Stack(_)
694            | Operation::Concatenate(_)
695            | Operation::Zip => {
696                let dependencies = node.get_node_dependencies();
697                let new_dependencies: Vec<Node> = dependencies
698                    .iter()
699                    .map(|x| out_mapping.get_node((*x).clone()))
700                    .collect();
701                apply_op(node.clone(), op, new_dependencies, dependencies)?
702            }
703            _ => {
704                return Err(runtime_error!(
705                    "MPC compilation for {} not yet implemented",
706                    op
707                ));
708            }
709        };
710        if private_nodes.contains(&node) {
711            new_node = if nodes_to_reshare.contains(&node) {
712                // Node is 3-out-of-3 and needs resharing to 2-out-of-3 sharing
713                let keys_mul = match prf_keys_mul {
714                    Some(ref k) => k.clone(),
715                    None => {
716                        return Err(runtime_error!("Propagation of annotations failed"));
717                    }
718                };
719                reshare(&new_node, &keys_mul)?
720            } else {
721                new_node
722            };
723            new_node.add_annotation(NodeAnnotation::Private)?;
724        }
725        out_mapping.insert_node(node, new_node);
726    }
727    let old_output_node = in_graph.get_output_node()?;
728    let output_node = out_mapping.get_node(old_output_node);
729    out_graph.set_output_node(output_node)?;
730    out_graph.finalize()?;
731    Ok(out_graph)
732}
733
734fn contains_node_annotation(g: Graph, annotation: NodeAnnotation) -> Result<bool> {
735    let nodes = g.get_nodes();
736    for node in nodes {
737        let annotations = node.get_annotations()?;
738        if annotations.contains(&annotation) {
739            return Ok(true);
740        }
741    }
742    Ok(false)
743}
744
745fn share_node(g: Graph, node: Node, prf_keys: Node, status: IOStatus) -> Result<Node> {
746    let mut outputs = vec![];
747    let t = node.get_type()?;
748    let node_shares = get_node_shares(g.clone(), prf_keys, t, Some((node, status)))?;
749    // networking
750    for (i, node_share) in node_shares.iter().enumerate().take(PARTIES) {
751        let network_node = g.nop((*node_share).clone())?;
752        let im1 = ((i + PARTIES - 1) % PARTIES) as u64;
753        network_node.add_annotation(NodeAnnotation::Send(i as u64, im1))?;
754        outputs.push(network_node);
755    }
756    g.create_tuple(outputs)
757}
758
759fn share_input(g: Graph, node: Node, t: Type, prf_keys: Node, status: IOStatus) -> Result<Node> {
760    let plain_input = g.input(t)?;
761    copy_node_name(node, plain_input.clone())?;
762    share_node(g, plain_input, prf_keys, status)
763}
764
765/// Generates a triple of random PRF keys (k_0, k_1, k_2) such that k_i is generated by party i.
766/// The keys are then distributed such that
767/// the ith party has k_i and k_{i+1} (the index is taken modulo 3).
768pub(super) fn generate_prf_key_triple(g: Graph) -> Result<Vec<Node>> {
769    let key_t = array_type(vec![KEY_LENGTH], BIT);
770    let mut triple = vec![];
771    for party_id in 0..PARTIES {
772        let key = g.random(key_t.clone())?;
773        let key_sent = g.nop(key)?;
774        let prev_party_id = (party_id + PARTIES - 1) % PARTIES;
775        key_sent.add_annotation(NodeAnnotation::Send(party_id as u64, prev_party_id as u64))?;
776        triple.push(key_sent);
777    }
778    Ok(triple)
779}
780
/// Builds, in `out_graph`, the argument list used to call the MPC computation graph compiled
/// from `in_graph`.
///
/// The returned vector contains, in this order:
/// - `prf_keys` (only if `is_prf_mul_key_needed`),
/// - a tuple of two PRF key triples used for B2A (only if `is_prf_b2a_key_needed`),
/// - a random `KEY_LENGTH`-bit key node created with `out_graph.random` (only if
///   `is_prf_truncate_key_needed`),
/// - one node per `Input` node of `in_graph`, in the order of `in_graph.get_nodes()`:
///   - `IOStatus::Party(_)`: a fresh plaintext input secret-shared by `share_input`,
///   - `IOStatus::Shared`: a fresh input of a tuple type with `PARTIES` copies of the input type,
///   - `IOStatus::Public`: a fresh plaintext input of the same type.
///
/// Input names are copied from the inputs of `in_graph` to the new input nodes.
/// NOTE(review): this order presumably mirrors the inputs expected by the graph produced by
/// `compile_to_mpc_graph` — confirm there before changing it.
///
/// # Panics
/// Panics if `input_party_map` has fewer entries than `in_graph` has `Input` nodes.
fn share_all_inputs(
    in_graph: Graph,
    out_graph: Graph,
    input_party_map: Vec<IOStatus>,
    prf_keys: Node,
    is_prf_mul_key_needed: bool,
    is_prf_b2a_key_needed: bool,
    is_prf_truncate_key_needed: bool,
) -> Result<Vec<Node>> {
    let mut shared_inputs = if is_prf_mul_key_needed {
        vec![prf_keys.clone()]
    } else {
        vec![]
    };
    if is_prf_b2a_key_needed {
        // Create PRF keys for B2A.
        // These are 2 tuples ((k_00, k_01, k_02), (k_10, k_11, k_12)) where
        // k_00, k_01, k_02, k_10, k_11 should be known to party 0,
        // k_01, k_02, k_10, k_11, k_12 should be known to party 1,
        // all these keys should be known to party 2.
        // Party i generates keys k_0i and k_1i and sends them to other parties.
        let prf_b2a_key = {
            let mut keys = vec![];
            for _ in 0..2 {
                let key_triple = generate_prf_key_triple(out_graph.clone())?;
                keys.push(key_triple);
            }
            // Stopping here will result in the following access pattern
            // k_00, k_01, k_10, k_11 should be known to party 0,
            // k_01, k_02, k_11, k_12 should be known to party 1,
            // k_00, k_02, k_10, k_12 should be known to party 2.
            // This means that party 0 should send k_10 to party 1,
            keys[1][0] = keys[1][0].nop()?;
            keys[1][0].add_annotation(NodeAnnotation::Send(0, 1))?;
            // party 1 should send k_01, k_11 to party 2
            keys[0][1] = keys[0][1].nop()?;
            keys[0][1].add_annotation(NodeAnnotation::Send(1, 2))?;
            keys[1][1] = keys[1][1].nop()?;
            keys[1][1].add_annotation(NodeAnnotation::Send(1, 2))?;
            // and party 2 should send k_02 to party 0.
            keys[0][2] = keys[0][2].nop()?;
            keys[0][2].add_annotation(NodeAnnotation::Send(2, 0))?;

            let key_triple0 = out_graph.create_tuple(keys[0].clone())?;
            let key_triple1 = out_graph.create_tuple(keys[1].clone())?;
            out_graph.create_tuple(vec![key_triple0, key_triple1])?
        };
        shared_inputs.push(prf_b2a_key);
    }
    if is_prf_truncate_key_needed {
        // A single random key node; no Send annotation is attached here.
        let key_t = array_type(vec![KEY_LENGTH], BIT);
        let prf_truncate_key = out_graph.random(key_t)?;
        shared_inputs.push(prf_truncate_key);
    }

    // Index into `input_party_map`, advanced once per Input node of `in_graph`.
    let mut input_id = 0usize;
    for node in in_graph.get_nodes() {
        if let Operation::Input(t) = node.get_operation() {
            let shared_input = match input_party_map[input_id] {
                // Privately owned input: create a plaintext input and secret-share it.
                IOStatus::Party(_) => share_input(
                    out_graph.clone(),
                    node.clone(),
                    t,
                    prf_keys.clone(),
                    input_party_map[input_id].clone(),
                )?,
                // Already shared input: arrives as a tuple of PARTIES values of type `t`.
                IOStatus::Shared => {
                    let new_node = out_graph.input(tuple_type(vec![t.clone(); PARTIES]))?;
                    copy_node_name(node.clone(), new_node.clone())?;
                    new_node
                }
                // Public input: passed through as a plaintext input of type `t`.
                IOStatus::Public => {
                    let new_node = out_graph.input(t)?;
                    copy_node_name(node.clone(), new_node.clone())?;
                    new_node
                }
            };
            input_id += 1;
            shared_inputs.push(shared_input);
        }
    }
    Ok(shared_inputs)
}
864
865pub(super) fn recursively_sum_shares(g: Graph, shares: Vec<Node>) -> Result<Node> {
866    let t = shares[0].get_type()?;
867    match t {
868        Type::Scalar(_) | Type::Array(_, _) => {
869            let mut res = shares[0].clone();
870            for share in shares.iter().skip(1) {
871                res = res.add(share.clone())?;
872            }
873            Ok(res)
874        }
875        Type::Tuple(types) => {
876            let mut revealed_sub_nodes = vec![];
877            for i in 0..types.len() as u64 {
878                let mut sub_shares = vec![];
879                for share in &shares {
880                    let sub_share = share.tuple_get(i)?;
881                    sub_shares.push(sub_share);
882                }
883                let revealed_sub_node = recursively_sum_shares(g.clone(), sub_shares)?;
884                revealed_sub_nodes.push(revealed_sub_node);
885            }
886            g.create_tuple(revealed_sub_nodes)
887        }
888        Type::Vector(length, element_type) => {
889            let mut revealed_sub_nodes = vec![];
890            for i in 0..length {
891                let i_node = g.constant(scalar_type(UINT64), Value::from_scalar(i, UINT64)?)?;
892                let mut sub_shares = vec![];
893                for share in &shares {
894                    let sub_share = share.vector_get(i_node.clone())?;
895                    sub_shares.push(sub_share);
896                }
897                let revealed_sub_node = recursively_sum_shares(g.clone(), sub_shares)?;
898                revealed_sub_nodes.push(revealed_sub_node);
899            }
900            g.create_vector((*element_type).clone(), revealed_sub_nodes)
901        }
902        Type::NamedTuple(names_types) => {
903            let mut revealed_sub_nodes = vec![];
904            for (name, _) in names_types {
905                let mut sub_shares = vec![];
906                for share in &shares {
907                    let sub_share = share.named_tuple_get(name.clone())?;
908                    sub_shares.push(sub_share);
909                }
910                let revealed_sub_node = recursively_sum_shares(g.clone(), sub_shares)?;
911                revealed_sub_nodes.push((name, revealed_sub_node));
912            }
913            g.create_named_tuple(revealed_sub_nodes)
914        }
915    }
916}
917
918/// Output parties ids must be in the range 0..PARTIES.
919fn reveal_output(g: Graph, out_node: Node, output_parties: Vec<IOStatus>) -> Result<Node> {
920    // If there are no parties obtaining revealed output, return output in the shared form
921    if output_parties.is_empty() {
922        return Ok(out_node);
923    }
924    // Extract output shares
925    let mut shares = vec![];
926    for i in 0..PARTIES as u64 {
927        let share = out_node.tuple_get(i)?;
928        shares.push(share);
929    }
930    if let IOStatus::Party(id) = output_parties[0] {
931        let party_id = id as usize;
932        let mut shares_to_reveal = shares.clone();
933        // Networking to obtain missing shares
934        let prev_party_id = (party_id + PARTIES - 1) % PARTIES;
935        let missing_share = shares_to_reveal[prev_party_id].nop()?;
936        shares_to_reveal[prev_party_id] = missing_share
937            .add_annotation(NodeAnnotation::Send(prev_party_id as u64, party_id as u64))?;
938        // Sum shares
939        let revealed_node = recursively_sum_shares(g, shares_to_reveal)?;
940        // If there are other parties waiting for a revealed value, send it to them
941        let result_node = if output_parties.len() > 1 {
942            let mut send_node = revealed_node;
943            for i in 1..PARTIES {
944                let party_to_send_id = (party_id + i) % PARTIES;
945                if output_parties.contains(&IOStatus::Party(party_to_send_id as u64)) {
946                    send_node = send_node.nop()?;
947                    send_node.add_annotation(NodeAnnotation::Send(
948                        party_id as u64,
949                        party_to_send_id as u64,
950                    ))?;
951                }
952            }
953            // Output node can't have Send annotation
954            send_node.nop()?
955        } else {
956            revealed_node
957        };
958        return Ok(result_node);
959    }
960    panic!("Shouldn't be here");
961}
962
/// Compiles all the graphs of an already inlined context into graphs for secure computation and adds them to another context.
/// Namely, every plaintext operation is replaced by a related MPC protocol from the ABY3 framework.
/// The given input-party map assigns every input to one of the following statuses:
/// - public,
/// - already shared,
/// - should be shared by certain party.
/// The `output_parties` argument contains ids of the parties (from 0..PARTIES) that obtain the revealed result of MPC computation.
///
/// For every graph `i` of `in_context`, this creates in `out_context`:
/// - a computation graph (via `compile_to_mpc_graph`), and
/// - a wrapper graph that generates PRF keys, builds the (shared) inputs with `share_all_inputs`,
///   calls the computation graph and then either reveals its private output to
///   `output_parties[i]`, secret-shares a public output from party 0 (when `output_parties[i]`
///   is empty), or returns the public output as is.
/// The wrapper graph is registered in `out_mapping` as the image of the original graph.
///
/// # Panics
/// Indexes `input_party_map[i]` and `output_parties[i]` for every graph, so it panics if either
/// has fewer entries than `in_context` has graphs.
fn compile_to_mpc_context(
    in_context: Context,
    input_party_map: Vec<Vec<IOStatus>>,
    output_parties: Vec<Vec<IOStatus>>,
    out_context: Context,
    out_mapping: &mut ContextMappings,
) -> Result<()> {
    in_context.check_finalized()?;

    for (i, graph) in in_context.get_graphs().iter().enumerate() {
        // compile the current graph to MPC
        let is_input_private: Vec<bool> = input_party_map[i]
            .iter()
            .map(|x| *x != IOStatus::Public)
            .collect();
        let computation_graph = compile_to_mpc_graph(
            graph.clone(),
            is_input_private.clone(),
            out_context.clone(),
            out_mapping,
        )?;

        let new_graph = out_context.create_graph()?;
        // Input tuple of PRF keys for zero sharing.
        let prf_keys = {
            let keys_vec = generate_prf_key_triple(new_graph.clone())?;
            new_graph.create_tuple(keys_vec)?
        };

        // Determine which extra PRF keys the computation graph needs (from its node annotations),
        // then build its arguments: the needed keys followed by the (shared) graph inputs.
        let is_prf_mul_key_needed =
            contains_node_annotation(computation_graph.clone(), NodeAnnotation::PRFMultiplication)?;
        let is_prf_b2a_key_needed =
            contains_node_annotation(computation_graph.clone(), NodeAnnotation::PRFB2A)?;
        let is_prf_truncate_key_needed =
            contains_node_annotation(computation_graph.clone(), NodeAnnotation::PRFTruncate)?;
        let shared_input = share_all_inputs(
            graph.clone(),
            new_graph.clone(),
            input_party_map[i].clone(),
            prf_keys.clone(),
            is_prf_mul_key_needed,
            is_prf_b2a_key_needed,
            is_prf_truncate_key_needed,
        )?;

        // compute the MPC graph on the shared input
        let shared_result = new_graph.call(computation_graph.clone(), shared_input)?;
        // reveal the output to the given parties
        let is_output_private = {
            let out_node = computation_graph.get_output_node()?;
            let out_anno = out_node.get_annotations()?;
            out_anno.contains(&NodeAnnotation::Private)
        };
        let result = if is_output_private {
            reveal_output(new_graph.clone(), shared_result, output_parties[i].clone())?
        } else if output_parties[i].is_empty() {
            // if output is public and it should be secretly shared (no output parties), party 0 creates its secret sharing
            let node = share_node(
                new_graph.clone(),
                shared_result.clone(),
                prf_keys,
                IOStatus::Party(0),
            )?;
            node.add_annotation(NodeAnnotation::Private)?
        } else {
            shared_result
        };
        result.set_as_output()?;
        new_graph.finalize()?;
        out_mapping.insert_graph(graph.clone(), new_graph);
    }
    Ok(())
}
1044
1045/// Compiles all the graphs of an already inlined context into graphs for secure computation.
1046/// Namely, every plaintext operation is replaced by a related MPC protocol from the ABY3 framework.
1047/// The given input-party map describes what statuses of inputs (public, shared or owned by a party).
1048/// The `output_parties` argument contains ids of the parties (from 0..PARTIES) that obtain the revealed result of MPC computation.
1049/// Input of PRF nodes is always zero. Thus, the resulting context is insecure to evaluate!
1050/// To guarantee security, unique PRF inputs are assigned later.
1051/// If private, the output of the main graph is always a tuple of 3 elements where the first element is known to the first party,
1052/// the second to the second one etc. Thus, the first tuple element can be either a share or a revealed value known to the first party.
1053fn compile_to_mpc(
1054    context: Context,
1055    input_party_map: Vec<Vec<IOStatus>>,
1056    output_parties: Vec<Vec<IOStatus>>,
1057) -> Result<MappedContext> {
1058    for sub_map in &input_party_map {
1059        for status in sub_map {
1060            if let IOStatus::Party(id) = *status {
1061                if id >= PARTIES as u64 {
1062                    return Err(runtime_error!("Input party should have a valid party ID"));
1063                }
1064            }
1065        }
1066    }
1067    for sub_parties in &output_parties {
1068        for status in sub_parties {
1069            if let IOStatus::Party(id) = *status {
1070                if id >= PARTIES as u64 {
1071                    return Err(runtime_error!("Output party should have a valid party ID"));
1072                }
1073            } else {
1074                return Err(runtime_error!(
1075                    "Output status should be a party id or shared"
1076                ));
1077            }
1078        }
1079    }
1080    let new_context = create_context()?;
1081    let mut context_map = ContextMappings::default();
1082    compile_to_mpc_context(
1083        context.clone(),
1084        input_party_map,
1085        output_parties,
1086        new_context.clone(),
1087        &mut context_map,
1088    )?;
1089    let old_main_graph = context.get_main_graph()?;
1090    let main_graph = context_map.get_graph(old_main_graph);
1091    new_context.set_main_graph(main_graph)?;
1092    new_context.finalize()?;
1093    let mut mapped_context = MappedContext::new(new_context);
1094    mapped_context.mappings = context_map;
1095    Ok(mapped_context)
1096}
1097
1098/// Creates a new copy of an input context with PRF nodes containing globally unique inputs (iv's).
1099/// These global inputs are taken from the set {1,2,...,n} where n is the total number of PRF nodes.
1100pub fn uniquify_prf_id(context: Context) -> Result<Context> {
1101    let new_context = create_context()?;
1102    let mut context_map = ContextMappings::default();
1103    let graphs = context.get_graphs();
1104    let mut prf_id = 0;
1105    for graph in graphs {
1106        let out_graph = new_context.create_graph()?;
1107        let nodes = graph.get_nodes();
1108        for node in nodes {
1109            let op = node.get_operation();
1110            let op = if op.is_prf_operation() {
1111                prf_id += 1;
1112                op.update_prf_id(prf_id)?
1113            } else {
1114                op
1115            };
1116            let node_dependencies = node.get_node_dependencies();
1117            let new_node_dependencies: Vec<Node> = node_dependencies
1118                .iter()
1119                .map(|x| context_map.get_node((*x).clone()))
1120                .collect();
1121            let graph_dependencies = node.get_graph_dependencies();
1122            let new_graph_dependencies: Vec<Graph> = graph_dependencies
1123                .iter()
1124                .map(|x| context_map.get_graph((*x).clone()))
1125                .collect();
1126            let new_node = out_graph.add_node(new_node_dependencies, new_graph_dependencies, op)?;
1127            let annotations = node.get_annotations()?;
1128            for anno in annotations {
1129                new_node.add_annotation(anno)?;
1130            }
1131            copy_node_name(node.clone(), new_node.clone())?;
1132            context_map.insert_node(node, new_node);
1133        }
1134        let output_node = graph.get_output_node()?;
1135        let new_output_node = context_map.get_node(output_node);
1136        out_graph.set_output_node(new_output_node)?;
1137        out_graph.finalize()?;
1138        context_map.insert_graph(graph, out_graph.clone());
1139    }
1140    let old_main_graph = context.get_main_graph()?;
1141    let main_graph = context_map.get_graph(old_main_graph);
1142    new_context.set_main_graph(main_graph)?;
1143    new_context.finalize()?;
1144    Ok(new_context)
1145}
1146
1147/// Converts a given inlined context to its counterpart that operates on MPC shares and is ready for evaluation.
1148/// It includes a call to the MPC compiler, the custom operation instantiation and inlining with a given configuration.
1149/// After inlining this function provides a unique input to every PRF node.
1150/// The resulting context preserves only the names of input nodes.
1151pub fn prepare_for_mpc_evaluation(
1152    context: Context,
1153    input_party_map: Vec<Vec<IOStatus>>,
1154    output_parties: Vec<Vec<IOStatus>>,
1155    inline_config: InlineConfig,
1156) -> Result<Context> {
1157    let mpc_context = compile_to_mpc(context, input_party_map, output_parties)?.get_context();
1158    let instantiated_context = run_instantiation_pass(mpc_context)?.get_context();
1159    let inlined_context = inline_operations(instantiated_context, inline_config)?;
1160    uniquify_prf_id(inlined_context)
1161}
1162
1163fn print_stats(graph: Graph) -> Result<()> {
1164    let mut cnt = HashMap::<String, u64>::new();
1165    for node in graph.get_nodes() {
1166        let op_name = format!("{}", node.get_operation());
1167        *cnt.entry(op_name).or_insert(0) += 1;
1168    }
1169    let mut entries: Vec<(String, u64)> = cnt.iter().map(|e| (e.0.clone(), *e.1)).collect();
1170    entries.sort_by_key(|e| -(e.1 as i64));
1171    eprintln!("-------Stats--------");
1172    eprintln!("Total ops: {}", graph.get_nodes().len());
1173    for e in entries {
1174        eprintln!("{}\t{}", e.0, e.1);
1175    }
1176    Ok(())
1177}
1178
1179pub fn prepare_context<E>(
1180    context: Context,
1181    inline_config: InlineConfig,
1182    evaluator: E,
1183    print_unoptimized_stats: bool,
1184) -> Result<Context>
1185where
1186    E: Evaluator + Sized,
1187{
1188    eprintln!("Instantiating...");
1189    let context2 = run_instantiation_pass(context)?.get_context();
1190    eprintln!("Inlining...");
1191    let context3 = inline_operations(context2, inline_config)?;
1192    if print_unoptimized_stats {
1193        print_stats(context3.get_main_graph()?)?;
1194    }
1195    eprintln!("Optimizing...");
1196    optimize_context(context3, evaluator)
1197}
1198
1199/// Takes raw context (no inlining, etc.), and runs the whole pipeline (instantiation+inlining+MPC) on it,
1200/// to prepare to be used in runtime.
1201pub fn compile_context<T, E>(
1202    context: Context,
1203    input_parties: Vec<IOStatus>,
1204    output_parties: Vec<IOStatus>,
1205    inline_config: InlineConfig,
1206    get_evaluator: T,
1207) -> Result<Context>
1208where
1209    T: Fn() -> Result<E>,
1210    E: Evaluator + Sized,
1211{
1212    let evaluator0 = get_evaluator()?;
1213    let context4 = prepare_context(context, inline_config.clone(), evaluator0, true)?;
1214    print_stats(context4.get_main_graph()?)?;
1215    let mut number_of_inputs = 0;
1216    for node in context4.get_main_graph()?.get_nodes() {
1217        if node.get_operation().is_input() {
1218            number_of_inputs += 1;
1219        }
1220    }
1221    if input_parties.len() != number_of_inputs {
1222        return Err(runtime_error!(
1223            "Invalid number of input parties: {} expected, but {} found",
1224            number_of_inputs,
1225            input_parties.len()
1226        ));
1227    }
1228    eprintln!("input_parties = {input_parties:?}");
1229    eprintln!("output_parties = {output_parties:?}");
1230    let compiled_context0 = prepare_for_mpc_evaluation(
1231        context4,
1232        vec![input_parties],
1233        vec![output_parties],
1234        inline_config,
1235    )?;
1236    print_stats(compiled_context0.get_main_graph()?)?;
1237
1238    let evaluator1 = get_evaluator()?;
1239    let compiled_context = optimize_context(compiled_context0, evaluator1)?;
1240    print_stats(compiled_context.get_main_graph()?)?;
1241    Ok(compiled_context)
1242}
1243
1244#[cfg(test)]
1245mod tests {
1246    use super::*;
1247    use crate::custom_ops::run_instantiation_pass;
1248    use crate::data_types::{
1249        array_type, get_types_vector, named_tuple_type, scalar_type, tuple_type, vector_type, BIT,
1250        INT32, INT64, UINT64, UINT8, VOID_TYPE,
1251    };
1252    use crate::data_values::Value;
1253    use crate::evaluators::random_evaluate;
1254    use crate::evaluators::simple_evaluator::evaluate_add_subtract_multiply;
1255    use crate::graphs::util::simple_context;
1256    use crate::graphs::SliceElement::{Ellipsis, SubArray};
1257    use crate::inline::inline_ops::{inline_operations, InlineConfig, InlineMode};
1258    use crate::random::PRNG;
1259
1260    use std::collections::HashMap;
1261
1262    #[test]
1263    fn test_malformed() {
1264        || -> Result<()> {
1265            let c = create_context()?;
1266            assert!(compile_to_mpc(
1267                c.clone(),
1268                vec![vec![IOStatus::Public]],
1269                vec![vec![IOStatus::Party(0)]]
1270            )
1271            .is_err());
1272            let g = c.create_graph()?;
1273            g.input(scalar_type(BIT))?.set_as_output()?;
1274            g.finalize()?;
1275            c.set_main_graph(g)?;
1276            c.finalize()?;
1277            assert!(compile_to_mpc(
1278                c.clone(),
1279                vec![vec![IOStatus::Party(3)]],
1280                vec![vec![IOStatus::Party(0)]]
1281            )
1282            .is_err());
1283            assert!(compile_to_mpc(
1284                c.clone(),
1285                vec![vec![IOStatus::Public]],
1286                vec![vec![IOStatus::Party(5)]]
1287            )
1288            .is_err());
1289            assert!(compile_to_mpc(
1290                c.clone(),
1291                vec![vec![IOStatus::Public]],
1292                vec![vec![IOStatus::Shared]]
1293            )
1294            .is_err());
1295            Ok(())
1296        }()
1297        .unwrap();
1298    }
1299
1300    fn reveal_private_value(value: Value, t: Type) -> Result<Value> {
1301        let shares = value.to_vector()?;
1302        if matches!(t.clone(), Type::Array(_, _) | Type::Scalar(_)) {
1303            let mut res = Value::zero_of_type(t.clone());
1304            for share in shares {
1305                res = evaluate_add_subtract_multiply(
1306                    t.clone(),
1307                    res.clone(),
1308                    t.clone(),
1309                    share,
1310                    Operation::Add,
1311                    t.clone(),
1312                )?;
1313            }
1314            return Ok(res);
1315        }
1316
1317        let vector_types = get_types_vector(t.clone())?;
1318        let mut shares_vec = vec![];
1319        for i in 0..PARTIES {
1320            shares_vec.push(shares[i].to_vector()?);
1321        }
1322        let mut res_vec = vec![];
1323        for i in 0..vector_types.len() {
1324            let mut tuple_vec = vec![];
1325            for j in 0..PARTIES {
1326                tuple_vec.push(shares_vec[j][i].clone());
1327            }
1328            let tuple = Value::from_vector(tuple_vec);
1329            res_vec.push(reveal_private_value(tuple, (*vector_types[i]).clone())?);
1330        }
1331        Ok(Value::from_vector(res_vec))
1332    }
1333
1334    #[test]
1335    fn test_input() {
1336        let seed = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
1337        let mut prng = PRNG::new(Some(*seed)).unwrap();
1338        let mut helper =
1339            |t: Type, input_status: IOStatus, output_parties: Vec<IOStatus>| -> Result<()> {
1340                let c = simple_context(|g| g.input(t.clone()))?;
1341                let mpc_mapped_context = compile_to_mpc(
1342                    c,
1343                    vec![vec![input_status.clone()]],
1344                    vec![output_parties.clone()],
1345                )?;
1346                let mpc_context = mpc_mapped_context.get_context();
1347                let mpc_graph = mpc_context.get_main_graph()?;
1348                let mut inputs = vec![];
1349                if input_status == IOStatus::Shared {
1350                    let tuple_t = tuple_type(vec![t.clone(); PARTIES]);
1351                    inputs.push(prng.get_random_value(tuple_t.clone())?);
1352                } else {
1353                    inputs.push(prng.get_random_value(t.clone())?);
1354                }
1355                let output = random_evaluate(mpc_graph.clone(), inputs.clone())?;
1356
1357                let mpc_computation_graph = mpc_context.get_graphs()[0].clone();
1358
1359                let computation_output_node = mpc_computation_graph.get_output_node()?;
1360                let computation_output_annotations = computation_output_node.get_annotations()?;
1361                if input_status != IOStatus::Public {
1362                    let expected = if input_status == IOStatus::Shared {
1363                        reveal_private_value(inputs[0].clone(), t.clone())?
1364                    } else {
1365                        inputs[0].clone()
1366                    };
1367                    // check that output is a sharing of expected
1368                    if output_parties.is_empty() {
1369                        let revealed_output = reveal_private_value(output.clone(), t.clone())?;
1370                        assert!(output.check_type(tuple_type(vec![t.clone(); PARTIES]))?);
1371                        assert_eq!(revealed_output, expected);
1372                    } else {
1373                        // check that output is of the right type
1374                        assert!(output.check_type(t.clone())?);
1375                        assert_eq!(output, expected.clone());
1376                    }
1377                    assert!(computation_output_annotations.contains(&NodeAnnotation::Private));
1378                } else {
1379                    // public input must be shared
1380                    if output_parties.is_empty() {
1381                        let revealed_output = reveal_private_value(output.clone(), t.clone())?;
1382                        assert!(output.check_type(tuple_type(vec![t.clone(); PARTIES]))?);
1383                        assert_eq!(revealed_output, inputs[0]);
1384                        // check that the final output is private (since it's shared)
1385                        let output_annotations = mpc_graph.get_output_node()?.get_annotations()?;
1386                        assert!(output_annotations.contains(&NodeAnnotation::Private));
1387                    } else {
1388                        assert_eq!(output, inputs[0]);
1389                    }
1390                    // computation output should be public on public inputs
1391                    assert!(!computation_output_annotations.contains(&NodeAnnotation::Private));
1392                }
1393                Ok(())
1394            };
1395
1396        helper(
1397            array_type(vec![2, 2], INT64),
1398            IOStatus::Party(0),
1399            vec![IOStatus::Party(1)],
1400        )
1401        .unwrap();
1402        helper(
1403            array_type(vec![2, 2], INT64),
1404            IOStatus::Public,
1405            vec![IOStatus::Party(0)],
1406        )
1407        .unwrap();
1408        helper(
1409            scalar_type(UINT64),
1410            IOStatus::Party(1),
1411            vec![IOStatus::Party(1), IOStatus::Party(2)],
1412        )
1413        .unwrap();
1414        helper(
1415            scalar_type(UINT64),
1416            IOStatus::Public,
1417            vec![IOStatus::Party(0)],
1418        )
1419        .unwrap();
1420        helper(
1421            scalar_type(UINT64),
1422            IOStatus::Shared,
1423            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1424        )
1425        .unwrap();
1426        helper(scalar_type(UINT64), IOStatus::Shared, vec![]).unwrap();
1427        helper(scalar_type(UINT64), IOStatus::Public, vec![]).unwrap();
1428    }
1429
1430    fn prepare_private_value(value: Value, t: Type) -> Result<Vec<Value>> {
1431        // private shares of value are generated as
1432        // value = (value + 2, -1, -1)
1433        if let Type::Scalar(st) | Type::Array(_, st) = t.clone() {
1434            let mut res = vec![];
1435            let zero = Value::zero_of_type(t.clone());
1436            let one = Value::from_scalar(1, st)?;
1437            let two = Value::from_scalar(2, st)?;
1438            for i in 0..PARTIES {
1439                let (add_sub, l_value, r_value) = match i {
1440                    0 => (Operation::Add, value.clone(), two.clone()),
1441                    1 => (Operation::Subtract, zero.clone(), one.clone()),
1442                    2 => (Operation::Subtract, zero.clone(), one.clone()),
1443                    _ => panic!("More than 3 parties are not supported"),
1444                };
1445                let share = evaluate_add_subtract_multiply(
1446                    t.clone(),
1447                    l_value,
1448                    scalar_type(st),
1449                    r_value,
1450                    add_sub,
1451                    t.clone(),
1452                )?;
1453                res.push(share);
1454            }
1455            return Ok(res);
1456        }
1457
1458        let vector_types = get_types_vector(t.clone())?;
1459        let mut shares = vec![vec![]; PARTIES];
1460        value.access_vector(|vector_values| {
1461            for i in 0..vector_values.len() {
1462                let tuple_i =
1463                    prepare_private_value(vector_values[i].clone(), (*vector_types[i]).clone())?;
1464                for j in 0..PARTIES {
1465                    shares[j].push(tuple_i[j].clone())
1466                }
1467            }
1468            Ok(())
1469        })?;
1470        let mut res = vec![];
1471        for share in shares {
1472            res.push(Value::from_vector(share));
1473        }
1474        Ok(res)
1475    }
1476
1477    fn prepare_value(value: Value, t: Type, is_input_private: bool) -> Result<Value> {
1478        if is_input_private {
1479            let tuple = prepare_private_value(value, t)?;
1480            return Ok(Value::from_vector(tuple));
1481        }
1482        Ok(value)
1483    }
1484
1485    fn prepare_input(
1486        input_types: Vec<Type>,
1487        is_input_shared: Vec<bool>,
1488    ) -> Result<(Vec<Value>, Vec<Value>)> {
1489        let seed: [u8; 16] = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
1490        let mut prng = PRNG::new(Some(seed))?;
1491        let mut plain_inputs = vec![];
1492        let mut mpc_inputs = vec![];
1493        for i in 0..input_types.len() {
1494            let random_value = prng.get_random_value(input_types[i].clone())?;
1495            plain_inputs.push(random_value.clone());
1496            mpc_inputs.push(prepare_value(
1497                random_value,
1498                input_types[i].clone(),
1499                is_input_shared[i].clone(),
1500            )?);
1501        }
1502        Ok((plain_inputs, mpc_inputs))
1503    }
1504
1505    fn check_output(
1506        plain_graph: Graph,
1507        mpc_graph: Graph,
1508        plain_inputs: Vec<Value>,
1509        mpc_inputs: Vec<Value>,
1510        output_parties: Vec<IOStatus>,
1511        t: Type,
1512    ) -> Result<()> {
1513        let plain_output = random_evaluate(plain_graph.clone(), plain_inputs)?;
1514        let mpc_output = random_evaluate(mpc_graph.clone(), mpc_inputs)?;
1515
1516        if output_parties.is_empty() {
1517            // check that mpc_output is a sharing of plain_output
1518            assert!(mpc_output.check_type(tuple_type(vec![t.clone(); PARTIES]))?);
1519            let value_revealed = reveal_private_value(mpc_output.clone(), t.clone())?;
1520            assert_eq!(value_revealed, plain_output);
1521        } else {
1522            assert!(mpc_output.check_type(t.clone())?);
1523            assert_eq!(mpc_output, plain_output);
1524        }
1525
1526        Ok(())
1527    }
1528
1529    fn helper_one_input(
1530        input_types: Vec<Type>,
1531        op: Operation,
1532        input_party_map: Vec<IOStatus>,
1533        output_parties: Vec<IOStatus>,
1534    ) -> Result<()> {
1535        let c = simple_context(|g| {
1536            let mut input_nodes = vec![];
1537            for i in 0..input_types.len() {
1538                let input_node = g.input(input_types[i].clone())?;
1539                input_node.set_name(&format!("Input {}", i))?;
1540                input_nodes.push(input_node);
1541            }
1542            let o = if op != Operation::VectorGet {
1543                g.add_node(input_nodes, vec![], op)?
1544            } else {
1545                input_nodes[0].vector_get(g.zeros(scalar_type(UINT64))?)?
1546            };
1547            o.set_name("Plaintext operation")?;
1548            Ok(o)
1549        })?;
1550        let g = c.get_main_graph()?;
1551        let output_type = g.get_output_node()?.get_type()?;
1552
1553        let inline_config = InlineConfig {
1554            default_mode: InlineMode::Simple,
1555            ..Default::default()
1556        };
1557        let mpc_c = prepare_for_mpc_evaluation(
1558            c.clone(),
1559            vec![input_party_map.clone()],
1560            vec![output_parties.clone()],
1561            inline_config,
1562        )?;
1563        let mpc_graph = mpc_c.get_main_graph()?;
1564        // Check names
1565        let mpc_node_result = mpc_c.retrieve_node(mpc_graph.clone(), "Plaintext operation");
1566        assert!(mpc_node_result.is_err());
1567        for i in 0..input_types.len() {
1568            let node_name = format!("Input {}", i);
1569            let new_input_node = mpc_c.retrieve_node(mpc_graph.clone(), &node_name);
1570            assert!(new_input_node.is_ok());
1571        }
1572
1573        let is_input_shared = input_party_map
1574            .iter()
1575            .map(|x| *x == IOStatus::Shared)
1576            .collect();
1577        let (plain_inputs, mpc_inputs) = prepare_input(input_types.clone(), is_input_shared)?;
1578
1579        check_output(
1580            g,
1581            mpc_graph,
1582            plain_inputs,
1583            mpc_inputs,
1584            output_parties,
1585            output_type,
1586        )?;
1587        Ok(())
1588    }
1589
1590    fn test_helper_one_input(input_type: Type, op: Operation) -> Result<()> {
1591        helper_one_input(
1592            vec![input_type.clone()],
1593            op.clone(),
1594            vec![IOStatus::Party(0)],
1595            vec![IOStatus::Party(0)],
1596        )?;
1597        helper_one_input(
1598            vec![input_type.clone()],
1599            op.clone(),
1600            vec![IOStatus::Party(0)],
1601            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1602        )?;
1603        helper_one_input(
1604            vec![input_type.clone()],
1605            op.clone(),
1606            vec![IOStatus::Public],
1607            vec![IOStatus::Party(1)],
1608        )?;
1609        helper_one_input(
1610            vec![input_type.clone()],
1611            op.clone(),
1612            vec![IOStatus::Shared],
1613            vec![IOStatus::Party(0), IOStatus::Party(1)],
1614        )?;
1615        helper_one_input(
1616            vec![input_type.clone()],
1617            op.clone(),
1618            vec![IOStatus::Party(0)],
1619            vec![],
1620        )?;
1621        helper_one_input(vec![input_type], op, vec![IOStatus::Public], vec![])?;
1622        Ok(())
1623    }
1624
1625    #[test]
1626    fn test_permute_axes() {
1627        test_helper_one_input(
1628            array_type(vec![4, 2, 3], INT32),
1629            Operation::PermuteAxes(vec![2, 0, 1]),
1630        )
1631        .unwrap();
1632    }
1633
1634    #[test]
1635    fn test_array_to_vector() {
1636        test_helper_one_input(array_type(vec![3, 1], UINT8), Operation::ArrayToVector).unwrap();
1637    }
1638
1639    #[test]
1640    fn test_vector_to_array() {
1641        test_helper_one_input(
1642            vector_type(10, array_type(vec![4, 3], INT32)),
1643            Operation::VectorToArray,
1644        )
1645        .unwrap();
1646    }
1647
1648    #[test]
1649    fn test_vector_get() {
1650        test_helper_one_input(
1651            vector_type(10, array_type(vec![4, 3], INT32)),
1652            Operation::VectorGet,
1653        )
1654        .unwrap();
1655    }
1656
1657    #[test]
1658    fn test_get_slice() {
1659        test_helper_one_input(
1660            array_type(vec![10, 128], INT32),
1661            Operation::GetSlice(vec![Ellipsis, SubArray(None, Some(-1), None)]),
1662        )
1663        .unwrap();
1664    }
1665
1666    #[test]
1667    fn test_reshape() {
1668        test_helper_one_input(
1669            array_type(vec![10, 128], INT32),
1670            Operation::Reshape(array_type(vec![20, 64], INT32)),
1671        )
1672        .unwrap();
1673    }
1674
1675    #[test]
1676    fn test_tuple_get() {
1677        let t = array_type(vec![10, 128], INT32);
1678        test_helper_one_input(
1679            tuple_type(vec![t.clone(), scalar_type(UINT64), t]),
1680            Operation::TupleGet(1),
1681        )
1682        .unwrap();
1683    }
1684
1685    #[test]
1686    fn test_named_tuple_get() {
1687        let t = array_type(vec![10, 128], INT32);
1688        test_helper_one_input(
1689            named_tuple_type(vec![
1690                ("a".to_owned(), t.clone()),
1691                ("b".to_owned(), scalar_type(UINT64)),
1692                ("c".to_owned(), t),
1693            ]),
1694            Operation::NamedTupleGet("b".to_string()),
1695        )
1696        .unwrap();
1697    }
1698
1699    #[test]
1700    fn test_sum() {
1701        test_helper_one_input(
1702            array_type(vec![10, 5, 12], INT32),
1703            Operation::Sum(vec![0, 1]),
1704        )
1705        .unwrap();
1706    }
1707
1708    #[test]
1709    fn test_cum_sum() -> Result<()> {
1710        test_helper_one_input(array_type(vec![10, 5, 12], INT32), Operation::CumSum(0))?;
1711        test_helper_one_input(array_type(vec![10, 5, 12], INT32), Operation::CumSum(1))?;
1712        test_helper_one_input(array_type(vec![10, 5, 12], INT32), Operation::CumSum(2))?;
1713        Ok(())
1714    }
1715
1716    #[test]
1717    fn test_get() {
1718        test_helper_one_input(array_type(vec![10, 128], INT32), Operation::Get(vec![5, 4]))
1719            .unwrap();
1720    }
1721
1722    #[test]
1723    fn test_repeat() {
1724        test_helper_one_input(array_type(vec![10, 128], INT32), Operation::Repeat(10)).unwrap();
1725    }
1726
1727    fn helper_create_ops(
1728        input_types: Vec<Type>,
1729        op: Operation,
1730        input_party_map: Vec<IOStatus>,
1731        output_parties: Vec<IOStatus>,
1732        include_constant: bool,
1733    ) -> Result<()> {
1734        let c = create_context()?;
1735        let g = c.create_graph()?;
1736        let mut input_nodes = vec![];
1737        for i in 0..input_types.len() {
1738            let input_node = g.input(input_types[i].clone())?;
1739            input_node.set_name(&format!("Input {}", i))?;
1740            input_nodes.push(input_node);
1741        }
1742        let resolved_op = if include_constant {
1743            input_nodes.push(g.constant(
1744                input_types[0].clone(),
1745                Value::zero_of_type(input_types[0].clone()),
1746            )?);
1747            match op {
1748                Operation::CreateNamedTuple(mut names) => {
1749                    names.push("const".to_owned());
1750                    Operation::CreateNamedTuple(names)
1751                }
1752                Operation::Stack(outer_shape) => {
1753                    let mut pr = 1;
1754                    for x in &outer_shape {
1755                        pr *= *x;
1756                    }
1757                    Operation::Stack(vec![pr + 1])
1758                }
1759                _ => op,
1760            }
1761        } else {
1762            op
1763        };
1764        let o = g.add_node(input_nodes, vec![], resolved_op)?;
1765        o.set_name("Plaintext operation")?;
1766        let output_type = o.get_type()?;
1767        g.set_output_node(o.clone())?;
1768        g.finalize()?;
1769        c.set_main_graph(g.clone())?;
1770        c.finalize()?;
1771
1772        let inline_config = InlineConfig {
1773            default_mode: InlineMode::Simple,
1774            ..Default::default()
1775        };
1776        let mpc_c = prepare_for_mpc_evaluation(
1777            c.clone(),
1778            vec![input_party_map.clone()],
1779            vec![output_parties.clone()],
1780            inline_config,
1781        )?;
1782        let mpc_graph = mpc_c.get_main_graph()?;
1783        // Check names
1784        let mpc_node_result = mpc_c.retrieve_node(mpc_graph.clone(), "Plaintext operation");
1785        assert!(mpc_node_result.is_err());
1786        for i in 0..input_types.len() {
1787            let node_name = format!("Input {}", i);
1788            let new_input_node = mpc_c.retrieve_node(mpc_graph.clone(), &node_name);
1789            assert!(new_input_node.is_ok());
1790        }
1791
1792        let is_input_shared = input_party_map
1793            .iter()
1794            .map(|x| *x == IOStatus::Shared)
1795            .collect();
1796        let (plain_inputs, mpc_inputs) = prepare_input(input_types.clone(), is_input_shared)?;
1797
1798        check_output(
1799            g,
1800            mpc_graph,
1801            plain_inputs,
1802            mpc_inputs,
1803            output_parties,
1804            output_type,
1805        )?;
1806        Ok(())
1807    }
1808
1809    fn test_helper_create_ops(input_types: Vec<Type>, op: Operation) -> Result<()> {
1810        helper_create_ops(
1811            input_types.clone(),
1812            op.clone(),
1813            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1814            vec![IOStatus::Party(0)],
1815            true,
1816        )?;
1817        helper_create_ops(
1818            input_types.clone(),
1819            op.clone(),
1820            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1821            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1822            true,
1823        )?;
1824        helper_create_ops(
1825            input_types.clone(),
1826            op.clone(),
1827            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1828            vec![IOStatus::Party(0)],
1829            true,
1830        )?;
1831        helper_create_ops(
1832            input_types.clone(),
1833            op.clone(),
1834            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1835            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1836            true,
1837        )?;
1838
1839        helper_create_ops(
1840            input_types.clone(),
1841            op.clone(),
1842            vec![IOStatus::Party(0), IOStatus::Public, IOStatus::Party(1)],
1843            vec![IOStatus::Party(0)],
1844            true,
1845        )?;
1846        helper_create_ops(
1847            input_types.clone(),
1848            op.clone(),
1849            vec![IOStatus::Party(0), IOStatus::Public, IOStatus::Party(1)],
1850            vec![IOStatus::Party(0)],
1851            true,
1852        )?;
1853
1854        helper_create_ops(
1855            input_types.clone(),
1856            op.clone(),
1857            vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1858            vec![IOStatus::Party(0)],
1859            false,
1860        )?;
1861        helper_create_ops(
1862            input_types.clone(),
1863            op.clone(),
1864            vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1865            vec![IOStatus::Party(0)],
1866            false,
1867        )?;
1868        helper_create_ops(
1869            input_types.clone(),
1870            op.clone(),
1871            vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1872            vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1873            false,
1874        )?;
1875        helper_create_ops(
1876            input_types.clone(),
1877            op.clone(),
1878            vec![IOStatus::Shared, IOStatus::Shared, IOStatus::Party(0)],
1879            vec![IOStatus::Party(0), IOStatus::Party(1)],
1880            true,
1881        )?;
1882        helper_create_ops(
1883            input_types.clone(),
1884            op.clone(),
1885            vec![IOStatus::Shared, IOStatus::Shared, IOStatus::Party(0)],
1886            vec![],
1887            true,
1888        )?;
1889        helper_create_ops(
1890            input_types.clone(),
1891            op.clone(),
1892            vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1893            vec![],
1894            true,
1895        )?;
1896        Ok(())
1897    }
1898
1899    #[test]
1900    fn test_create_tuple() {
1901        let t = array_type(vec![10, 128], INT32);
1902        test_helper_create_ops(
1903            vec![t.clone(), scalar_type(UINT64), t.clone()],
1904            Operation::CreateTuple,
1905        )
1906        .unwrap();
1907        test_helper_create_ops(vec![t, VOID_TYPE], Operation::CreateTuple).unwrap();
1908    }
1909
1910    #[test]
1911    fn test_create_named_tuple() {
1912        let t = array_type(vec![10, 128], INT32);
1913        test_helper_create_ops(
1914            vec![t.clone(), scalar_type(UINT64), t.clone()],
1915            Operation::CreateNamedTuple(vec!["a".to_owned(), "b".to_owned(), "c".to_owned()]),
1916        )
1917        .unwrap();
1918        test_helper_create_ops(
1919            vec![t, VOID_TYPE],
1920            Operation::CreateNamedTuple(vec!["a".to_owned(), "b".to_owned()]),
1921        )
1922        .unwrap();
1923    }
1924
1925    #[test]
1926    fn test_create_vector() {
1927        let t = array_type(vec![10, 128], INT32);
1928        test_helper_create_ops(vec![t.clone(); 3], Operation::CreateVector(t)).unwrap();
1929    }
1930
1931    #[test]
1932    fn test_zip() {
1933        let t = vector_type(10, array_type(vec![4, 3], INT32));
1934        test_helper_create_ops(vec![t.clone(); 3], Operation::Zip).unwrap();
1935    }
1936    #[test]
1937    fn test_stack() {
1938        let t = array_type(vec![10, 128], INT32);
1939        test_helper_create_ops(vec![t.clone(); 3], Operation::Stack(vec![3])).unwrap();
1940    }
1941    #[test]
1942    fn test_concatenate() {
1943        let t1 = array_type(vec![10, 1, 10], INT32);
1944        let t2 = array_type(vec![10, 2, 10], INT32);
1945        let t3 = array_type(vec![10, 3, 10], INT32);
1946        test_helper_create_ops(vec![t1, t2, t3], Operation::Concatenate(1)).unwrap();
1947    }
1948
1949    // Checks that every PRF node of a context has a unique input.
1950    fn check_prf_id(context: Context) -> Result<()> {
1951        let mut iv_node_map: HashMap<u64, Node> = HashMap::new();
1952        let graphs = context.get_graphs();
1953        for graph in graphs {
1954            let nodes = graph.get_nodes();
1955            for node in nodes {
1956                let iv = match node.get_operation() {
1957                    Operation::PRF(iv, _) => iv,
1958                    Operation::PermutationFromPRF(iv, _) => iv,
1959                    _ => continue,
1960                };
1961                if let Some(other_node) = iv_node_map.get(&iv) {
1962                    if *other_node != node {
1963                        return Err(runtime_error!("PRF node with non-unique iv"));
1964                    }
1965                } else {
1966                    iv_node_map.insert(iv, node);
1967                }
1968            }
1969        }
1970        Ok(())
1971    }
1972
1973    #[test]
1974    fn test_prf_id() {
1975        || -> Result<()> {
1976            let c = create_context()?;
1977            let g1 = c.create_graph()?;
1978            {
1979                let i = g1.input(scalar_type(UINT64))?;
1980                let o = i.a2b()?;
1981                g1.set_output_node(o)?;
1982                g1.finalize()?;
1983            }
1984            let g2 = c.create_graph()?;
1985            {
1986                let i = g2.input(scalar_type(UINT64))?;
1987                let o = i.a2b()?;
1988                g2.set_output_node(o)?;
1989                g2.finalize()?;
1990            }
1991
1992            c.set_main_graph(g2)?;
1993            c.finalize()?;
1994
1995            let mpc_c = compile_to_mpc(
1996                c,
1997                vec![vec![IOStatus::Party(0)], vec![IOStatus::Party(1)]],
1998                vec![vec![IOStatus::Party(1)], vec![IOStatus::Party(2)]],
1999            )?
2000            .get_context();
2001            let instantiated_context = run_instantiation_pass(mpc_c)?.get_context();
2002            assert!(check_prf_id(instantiated_context.clone()).is_err());
2003            let inlined_context = inline_operations(
2004                instantiated_context.clone(),
2005                InlineConfig {
2006                    default_mode: InlineMode::Simple,
2007                    ..Default::default()
2008                },
2009            )?;
2010            assert!(check_prf_id(inlined_context.clone()).is_err());
2011
2012            let validated_instantiated_context = uniquify_prf_id(instantiated_context)?;
2013            assert!(check_prf_id(validated_instantiated_context).is_ok());
2014            let validated_inlined_context = uniquify_prf_id(inlined_context)?;
2015            assert!(check_prf_id(validated_inlined_context).is_ok());
2016            Ok(())
2017        }()
2018        .unwrap()
2019    }
2020
2021    #[test]
2022    fn test_prf_ids_for_permutation_from_prf() -> Result<()> {
2023        let c = create_context()?;
2024        let g1 = c.create_graph()?;
2025        {
2026            let k = g1.random(array_type(vec![128], BIT))?;
2027            g1.permutation_from_prf(k, 0, 10)?.set_as_output()?;
2028            g1.finalize()?;
2029        }
2030        let g2 = c.create_graph()?;
2031        {
2032            let k = g2.random(array_type(vec![128], BIT))?;
2033            g2.prf(k, 0, scalar_type(UINT64))?.set_as_output()?;
2034            g2.finalize()?;
2035        }
2036        let g3 = c.create_graph()?;
2037        {
2038            let k = g3.random(array_type(vec![128], BIT))?;
2039            g3.permutation_from_prf(k, 0, 11)?.set_as_output()?;
2040            g3.finalize()?;
2041        }
2042        let g4 = c.create_graph()?;
2043        {
2044            let k = g4.random(array_type(vec![128], BIT))?;
2045            g4.prf(k, 0, scalar_type(INT32))?.set_as_output()?;
2046            g4.finalize()?;
2047        }
2048
2049        c.set_main_graph(g2)?;
2050        c.finalize()?;
2051        assert!(check_prf_id(c.clone()).is_err());
2052
2053        let c = uniquify_prf_id(c)?;
2054        assert!(check_prf_id(c).is_ok());
2055        Ok(())
2056    }
2057
2058    #[test]
2059    fn test_resharing() -> Result<()> {
2060        {
2061            let c = create_context()?;
2062            let g = c.create_graph()?;
2063            let i1 = g.input(array_type(vec![2, 10], BIT))?;
2064            let i2 = g.input(array_type(vec![10, 3], BIT))?;
2065            let prod = i1.matmul(i2)?;
2066            let out = prod.sum(vec![0])?;
2067            out.set_as_output()?;
2068            g.finalize()?;
2069            g.set_as_main()?;
2070            c.finalize()?;
2071
2072            let shared_nodes = propagate_private_annotations(g.clone(), vec![true, true])?.0;
2073            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2074
2075            assert!(reshared_nodes.len() == 1);
2076            assert!(reshared_nodes.contains(&out));
2077
2078            let shared_nodes = propagate_private_annotations(g.clone(), vec![false, true])?.0;
2079            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2080
2081            assert!(reshared_nodes.len() == 0);
2082
2083            let shared_nodes = propagate_private_annotations(g.clone(), vec![false, false])?.0;
2084            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2085
2086            assert!(reshared_nodes.len() == 0);
2087        }
2088
2089        {
2090            let c = create_context()?;
2091            let g = c.create_graph()?;
2092            let i1 = g.input(array_type(vec![2, 10], BIT))?;
2093            let i2 = g.input(array_type(vec![10, 3], BIT))?;
2094            let prod = i1.matmul(i2)?;
2095            let i3 = g.input(array_type(vec![3, 4], BIT))?;
2096            let out = prod.matmul(i3)?;
2097            out.set_as_output()?;
2098            g.finalize()?;
2099            g.set_as_main()?;
2100            c.finalize()?;
2101
2102            let shared_nodes = propagate_private_annotations(g.clone(), vec![true, true, true])?.0;
2103            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2104
2105            assert!(reshared_nodes.len() == 2);
2106            assert!(reshared_nodes.contains(&prod));
2107            assert!(reshared_nodes.contains(&out));
2108
2109            let shared_nodes = propagate_private_annotations(g.clone(), vec![false, true, true])?.0;
2110            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2111
2112            assert!(reshared_nodes.len() == 1);
2113            assert!(reshared_nodes.contains(&out));
2114
2115            let shared_nodes = propagate_private_annotations(g.clone(), vec![true, true, false])?.0;
2116            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2117
2118            assert!(reshared_nodes.len() == 1);
2119            assert!(reshared_nodes.contains(&prod));
2120        }
2121        {
2122            let c = create_context()?;
2123            let g = c.create_graph()?;
2124            let i1 = g.input(array_type(vec![2, 10], BIT))?;
2125            let i2 = g.input(array_type(vec![10, 3], BIT))?;
2126            let prod12 = i1.matmul(i2)?;
2127            let i3 = g.input(array_type(vec![2, 4], BIT))?;
2128            let i4 = g.input(array_type(vec![4, 3], BIT))?;
2129            let prod34 = i3.matmul(i4)?;
2130            let out = prod12.add(prod34)?;
2131            out.set_as_output()?;
2132            g.finalize()?;
2133            g.set_as_main()?;
2134            c.finalize()?;
2135
2136            let shared_nodes = propagate_private_annotations(g.clone(), vec![true; 4])?.0;
2137            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2138
2139            assert!(reshared_nodes.len() == 1);
2140            assert!(reshared_nodes.contains(&out));
2141        }
2142        {
2143            let c = create_context()?;
2144            let g = c.create_graph()?;
2145            let i1 = g.input(array_type(vec![2, 10], INT32))?;
2146            let i2 = g.input(array_type(vec![10, 3], INT32))?;
2147            let prod12 = i1.matmul(i2)?;
2148            let i3 = g.input(array_type(vec![2, 3], INT32))?;
2149            let s123 = prod12.add(i3)?;
2150            let out = s123.a2b()?;
2151            out.set_as_output()?;
2152            g.finalize()?;
2153            g.set_as_main()?;
2154            c.finalize()?;
2155
2156            let shared_nodes = propagate_private_annotations(g.clone(), vec![true; 3])?.0;
2157            let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2158
2159            assert!(reshared_nodes.len() == 1);
2160            assert!(reshared_nodes.contains(&s123));
2161        }
2162
2163        Ok(())
2164    }
2165}