1use crate::custom_ops::{run_instantiation_pass, ContextMappings, CustomOperation, MappedContext};
2use crate::data_types::{array_type, scalar_type, tuple_type, Type, TypePointer, BIT, UINT64};
3use crate::data_values::Value;
4use crate::errors::Result;
5use crate::evaluators::Evaluator;
6use crate::graphs::{
7 copy_node_name, create_context, Context, Graph, Node, NodeAnnotation, Operation,
8};
9use crate::inline::inline_ops::{inline_operations, InlineConfig};
10use crate::mpc::mpc_apply_permutation::ApplyPermutationMPC;
11use crate::mpc::mpc_conversion::{A2BMPC, B2AMPC};
12use crate::mpc::mpc_truncate::{TruncateMPC, TruncateMPC2K};
13use crate::optimizer::optimize::optimize_context;
14
15use std::collections::HashMap;
16use std::collections::HashSet;
17
18use super::mpc_arithmetic::{add_mpc, general_multiply_mpc, subtract_mpc};
19use super::mpc_psi::JoinMPC;
20use super::mpc_radix_sort::RadixSortMPC;
21use super::resharing::{get_nodes_to_reshare, reshare};
22
/// Number of parties taking part in the MPC protocol.
///
/// Private values are represented in compiled graphs as tuples of `PARTIES` shares.
pub const PARTIES: usize = 3;
25
/// Ownership status of a graph input or output.
///
/// Used both to describe who provides each input and who should learn the output.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum IOStatus {
    /// The value is known to all parties in plaintext.
    Public,
    /// The value belongs to the party with the given ID (valid IDs are `0..PARTIES`).
    Party(u64),
    /// The value is already secret-shared, i.e. given as a tuple of `PARTIES` shares.
    Shared,
}
33
/// Length of a PRF key: keys are arrays of `BIT` elements with shape `[KEY_LENGTH]`.
pub const KEY_LENGTH: u64 = 128;
36
37pub(super) fn check_private_tuple(v: Vec<TypePointer>) -> Result<()> {
39 if v.len() != PARTIES {
40 return Err(runtime_error!(
41 "Private tuple should have {} values, but {} provided",
42 PARTIES,
43 v.len()
44 ));
45 }
46 let t = (*v[0]).clone();
47 for coef in v.iter().skip(1) {
48 if t != **coef {
49 return Err(runtime_error!(
50 "Private tuple should have value of the same type"
51 ));
52 }
53 }
54 Ok(())
55}
56
57fn is_one_node_private(nodes: &[Node], private_nodes: &HashSet<Node>) -> bool {
58 for node in nodes {
59 if private_nodes.contains(node) {
60 return true;
61 }
62 }
63 false
64}
65
66fn are_all_nodes_private(nodes: &[Node], private_nodes: &HashSet<Node>) -> bool {
67 for node in nodes {
68 if !private_nodes.contains(node) {
69 return false;
70 }
71 }
72 true
73}
74
75fn recursively_generate_node_shares(
82 g: Graph,
83 prf_keys: Vec<Node>,
84 t: Type,
85 node_to_share: Option<(Node, IOStatus)>,
86) -> Result<Vec<Node>> {
87 match t {
88 Type::Scalar(_) | Type::Array(_, _) => {
89 let mut random_shares = vec![];
90 for key in prf_keys {
91 let prf_i = g.prf(key, 0, t.clone())?;
92 random_shares.push(prf_i);
93 }
94 let mut node_shares = vec![];
95 for i in 0..PARTIES {
96 let ip1 = (i + 1) % PARTIES;
97 let alpha = g.subtract(random_shares[i].clone(), random_shares[ip1].clone())?;
98 node_shares.push(alpha);
99 }
100
101 match node_to_share {
102 Some((node, IOStatus::Party(id))) => {
103 node_shares[id as usize] = node_shares[id as usize].add(node)?;
104 }
105 Some((node, IOStatus::Public)) => {
106 node_shares[0] = node_shares[0].add(node)?;
107 }
108 Some((_, IOStatus::Shared)) => {
109 return Err(runtime_error!(
110 "Given node must belong to a party or be public"
111 ));
112 }
113 None => (),
114 }
115 Ok(node_shares)
116 }
117 Type::Tuple(types) => {
118 let mut unpacked_node_shares = vec![vec![]; PARTIES];
119 for (i, sub_t) in types.iter().enumerate() {
120 let sub_node_to_share = match node_to_share.clone() {
121 Some((node, party_id)) => Some((node.tuple_get(i as u64)?, party_id)),
122 None => None,
123 };
124 let sub_node_shares = recursively_generate_node_shares(
125 g.clone(),
126 prf_keys.clone(),
127 (**sub_t).clone(),
128 sub_node_to_share,
129 )?;
130 for party_id in 0..PARTIES {
131 unpacked_node_shares[party_id].push(sub_node_shares[party_id].clone());
132 }
133 }
134 let mut node_shares = vec![];
135 for unpacked_share in unpacked_node_shares {
136 node_shares.push(g.create_tuple(unpacked_share)?);
137 }
138 Ok(node_shares)
139 }
140 Type::Vector(length, element_type) => {
141 let mut unpacked_node_shares = vec![vec![]; PARTIES];
142 for i in 0..length {
143 let sub_node_to_share = match node_to_share.clone() {
144 Some((node, party_id)) => {
145 let i_node =
146 g.constant(scalar_type(UINT64), Value::from_scalar(i, UINT64)?)?;
147 Some((node.vector_get(i_node)?, party_id))
148 }
149 None => None,
150 };
151 let sub_node_shares = recursively_generate_node_shares(
152 g.clone(),
153 prf_keys.clone(),
154 (*element_type).clone(),
155 sub_node_to_share,
156 )?;
157 for party_id in 0..PARTIES {
158 unpacked_node_shares[party_id].push(sub_node_shares[party_id].clone());
159 }
160 }
161 let mut node_shares = vec![];
162 for unpacked_share in unpacked_node_shares {
163 node_shares.push(g.create_vector((*element_type).clone(), unpacked_share)?);
164 }
165 Ok(node_shares)
166 }
167 Type::NamedTuple(names_types) => {
168 let mut unpacked_node_shares = vec![vec![]; PARTIES];
169 for (name, sub_t) in &names_types {
170 let sub_node_to_share = match node_to_share.clone() {
171 Some((node, party_id)) => {
172 Some((node.named_tuple_get((*name).clone())?, party_id))
173 }
174 None => None,
175 };
176 let sub_node_shares = recursively_generate_node_shares(
177 g.clone(),
178 prf_keys.clone(),
179 (**sub_t).clone(),
180 sub_node_to_share,
181 )?;
182 for party_id in 0..PARTIES {
183 unpacked_node_shares[party_id]
184 .push(((*name).clone(), sub_node_shares[party_id].clone()));
185 }
186 }
187 let mut node_shares = vec![];
188 for unpacked_share in unpacked_node_shares {
189 node_shares.push(g.create_named_tuple(unpacked_share)?);
190 }
191 Ok(node_shares)
192 }
193 }
194}
195
196pub(crate) fn get_node_shares(
197 g: Graph,
198 prf_keys: Node,
199 t: Type,
200 node_to_share: Option<(Node, IOStatus)>,
201) -> Result<Vec<Node>> {
202 let mut prf_keys_vec = vec![];
203 for i in 0..PARTIES {
204 let key = g.tuple_get(prf_keys.clone(), i as u64)?;
205 prf_keys_vec.push(key);
206 }
207 recursively_generate_node_shares(g, prf_keys_vec, t, node_to_share)
208}
209
210pub(super) fn get_zero_shares(g: Graph, prf_keys: Node, t: Type) -> Result<Vec<Node>> {
211 get_node_shares(g, prf_keys, t, None)
212}
213
214pub(super) fn propagate_private_annotations(
219 graph: Graph,
220 is_input_private: Vec<bool>,
221) -> Result<(HashSet<Node>, bool, bool, bool)> {
222 let mut private_nodes: HashSet<Node> = HashSet::new();
223 let mut use_prf_for_mul = false;
224 let mut use_prf_for_b2a = false;
225 let mut use_prf_for_truncate2k = false;
226 let mut input_id = 0usize;
227 for node in graph.get_nodes() {
228 let op = node.get_operation();
229 match op {
230 Operation::Input(_) => {
231 if is_input_private[input_id] {
232 private_nodes.insert(node);
233 }
234 input_id += 1;
235 }
236 Operation::Add
237 | Operation::Subtract
238 | Operation::Multiply
239 | Operation::MixedMultiply
240 | Operation::Dot
241 | Operation::Matmul
242 | Operation::Gemm(_, _)
243 | Operation::Join(_, _)
244 | Operation::JoinWithColumnMasks(_, _)
245 | Operation::A2B
246 | Operation::B2A(_)
247 | Operation::PermuteAxes(_)
248 | Operation::ArrayToVector
249 | Operation::TupleGet(_)
250 | Operation::NamedTupleGet(_)
251 | Operation::VectorToArray
252 | Operation::GetSlice(_)
253 | Operation::Reshape(_)
254 | Operation::Sum(_)
255 | Operation::CumSum(_)
256 | Operation::Get(_)
257 | Operation::CreateTuple
258 | Operation::CreateNamedTuple(_)
259 | Operation::CreateVector(_)
260 | Operation::Stack(_)
261 | Operation::ApplyPermutation(_)
262 | Operation::Sort(_)
263 | Operation::Concatenate(_)
264 | Operation::Zip
265 | Operation::Repeat(_) => {
266 let dependencies = node.get_node_dependencies();
267 if is_one_node_private(&dependencies, &private_nodes) {
268 private_nodes.insert(node.clone());
269 if matches!(
270 op,
271 Operation::Join(_, _) | Operation::JoinWithColumnMasks(_, _)
272 ) {
273 use_prf_for_mul = true;
274 }
275 }
276 if ([
277 Operation::Multiply,
278 Operation::Dot,
279 Operation::Matmul,
280 Operation::A2B,
281 ]
282 .contains(&op)
283 || matches!(op, Operation::Gemm(_, _)))
284 && are_all_nodes_private(&dependencies, &private_nodes)
285 {
286 use_prf_for_mul = true;
287 }
288 if matches!(op, Operation::B2A(_))
289 && are_all_nodes_private(&dependencies, &private_nodes)
290 {
291 use_prf_for_mul = true;
292 use_prf_for_b2a = true;
293 }
294 if matches!(op, Operation::Sort(_)) && private_nodes.contains(&dependencies[0]) {
295 use_prf_for_mul = true;
296 }
297 if matches!(op, Operation::ApplyPermutation(_))
298 && private_nodes.contains(&dependencies[1])
299 {
300 use_prf_for_mul = true;
301 }
302 if matches!(op, Operation::MixedMultiply)
303 && private_nodes.contains(&dependencies[1])
304 {
305 use_prf_for_mul = true;
306 }
307 }
308 Operation::Truncate(scale) => {
309 let dependencies = node.get_node_dependencies();
310 if is_one_node_private(&dependencies, &private_nodes) {
311 private_nodes.insert(node.clone());
312 }
313
314 if are_all_nodes_private(&dependencies, &private_nodes) {
315 use_prf_for_mul = true;
316 if scale.is_power_of_two() {
317 use_prf_for_truncate2k = true;
318 }
319 }
320 }
321 Operation::Constant(_, _) | Operation::Zeros(_) | Operation::Ones(_) => {
322 }
324 Operation::VectorGet => {
325 let dependencies = node.get_node_dependencies();
326 if private_nodes.contains(&dependencies[1]) {
327 return Err(runtime_error!("VectorGet can't have a private index"));
328 }
329 if private_nodes.contains(&dependencies[0]) {
330 private_nodes.insert(node.clone());
331 }
332 }
333 _ => {
334 return Err(runtime_error!(
335 "MPC compiler can't preprocess inputs of {}",
336 op
337 ));
338 }
339 }
340 }
341 Ok((
342 private_nodes,
343 use_prf_for_mul,
344 use_prf_for_b2a,
345 use_prf_for_truncate2k,
346 ))
347}
348
349pub(super) fn is_array_shared(array: &Node) -> Result<bool> {
350 Ok(array.get_type()?.is_tuple())
351}
352
353pub(super) fn compile_to_mpc_graph(
354 in_graph: Graph,
355 is_input_private: Vec<bool>,
356 out_context: Context,
357 out_mapping: &mut ContextMappings,
358) -> Result<Graph> {
359 let out_graph = out_context.create_graph()?;
360
361 let (private_nodes, use_prf_for_mul, use_prf_for_b2a, use_prf_for_truncate2k) =
362 propagate_private_annotations(in_graph.clone(), is_input_private)?;
363 let prf_keys_mul = if use_prf_for_mul {
366 let key_t = array_type(vec![KEY_LENGTH], BIT);
368 let key_inputs = vec![key_t; PARTIES];
369 let keys_type = tuple_type(key_inputs);
370 let node = out_graph.input(keys_type)?;
371 node.add_annotation(NodeAnnotation::PRFMultiplication)?;
372 Some(node)
373 } else {
374 None
375 };
376 let prf_keys_b2a = if use_prf_for_b2a {
379 let key_t = array_type(vec![KEY_LENGTH], BIT);
381 let key_inputs = vec![key_t; PARTIES];
382 let key_triple_type = tuple_type(key_inputs);
383 let keys_type = tuple_type(vec![key_triple_type; 2]);
384 let node = out_graph.input(keys_type)?;
385 node.add_annotation(NodeAnnotation::PRFB2A)?;
386 Some(node)
387 } else {
388 None
389 };
390 let prf_keys_truncate2k = if use_prf_for_truncate2k {
393 let key_t = array_type(vec![KEY_LENGTH], BIT);
395 let node = out_graph.input(key_t)?;
396 node.add_annotation(NodeAnnotation::PRFTruncate)?;
397 Some(node)
398 } else {
399 None
400 };
401
402 let apply_op = |node_to_be_private: Node,
415 op: Operation,
416 node_dependencies: Vec<Node>,
417 old_dependencies: Vec<Node>|
418 -> Result<Node> {
419 if !private_nodes.contains(&node_to_be_private) {
420 return out_graph.add_node(node_dependencies, vec![], op);
421 }
422 if let Operation::Input(t) = op.clone() {
423 let tuple_t = tuple_type(vec![t; PARTIES]);
424 return out_graph.input(tuple_t);
425 }
426 let mut result_shares = vec![];
427 for i in 0..PARTIES {
428 let share = match op.clone() {
429 Operation::VectorGet => vec![
430 out_graph.tuple_get(node_dependencies[0].clone(), i as u64)?,
431 node_dependencies[1].clone(),
432 ],
433 _ => {
434 let mut share_vec = vec![];
435 for (j, old_node) in old_dependencies.iter().enumerate() {
436 if private_nodes.contains(old_node) {
437 let new_node =
439 out_graph.tuple_get(node_dependencies[j].clone(), i as u64)?;
440 share_vec.push(new_node);
441 } else {
442 if i == 0 {
445 share_vec.push(node_dependencies[j].clone())
446 } else {
447 share_vec.push(out_graph.zeros(node_dependencies[j].get_type()?)?);
448 }
449 }
450 }
451 share_vec
452 }
453 };
454 let result_share = out_graph.add_node(share, vec![], op.clone())?;
455 result_shares.push(result_share);
456 }
457 out_graph.create_tuple(result_shares)
458 };
459
460 let nodes_to_reshare = get_nodes_to_reshare(&in_graph, &private_nodes)?;
461
462 for node in in_graph.get_nodes() {
463 let op = node.get_operation();
464 let mut new_node = match op.clone() {
465 Operation::Input(_) => apply_op(node.clone(), op, vec![], vec![])?,
466 Operation::Add | Operation::Subtract => {
467 let dependencies = node.get_node_dependencies();
468 let input0 = dependencies[0].clone();
469 let input1 = dependencies[1].clone();
470 let new_input0 = out_mapping.get_node(input0);
471 let new_input1 = out_mapping.get_node(input1);
472 match op.clone() {
473 Operation::Add => add_mpc(new_input0, new_input1)?,
474 Operation::Subtract => subtract_mpc(new_input0, new_input1)?,
475 _ => panic!("Should not be here"),
476 }
477 }
478 Operation::Multiply
479 | Operation::MixedMultiply
480 | Operation::Dot
481 | Operation::Matmul
482 | Operation::Gemm(_, _) => {
483 let dependencies = node.get_node_dependencies();
484 let input0 = dependencies[0].clone();
485 let input1 = dependencies[1].clone();
486 let new_input0 = out_mapping.get_node(input0.clone());
487 let new_input1 = out_mapping.get_node(input1.clone());
488 general_multiply_mpc(new_input0, new_input1, op, prf_keys_mul.clone(), false)?
491 }
492 Operation::Join(join_t, headers) | Operation::JoinWithColumnMasks(join_t, headers) => {
493 let dependencies = node.get_node_dependencies();
494 let input0 = dependencies[0].clone();
495 let input1 = dependencies[1].clone();
496 let new_input0 = out_mapping.get_node(input0.clone());
497 let new_input1 = out_mapping.get_node(input1.clone());
498 let mut headers_vec = vec![];
499 for headers_pair in headers {
500 headers_vec.push(headers_pair);
501 }
502 let custom_op = match op {
503 Operation::Join(_, _) => CustomOperation::new(JoinMPC {
504 join_t,
505 headers: headers_vec,
506 has_column_masks: false,
507 }),
508 Operation::JoinWithColumnMasks(_, _) => CustomOperation::new(JoinMPC {
509 join_t,
510 headers: headers_vec,
511 has_column_masks: true,
512 }),
513 _ => {
514 return Err(runtime_error!("Shouldn't be here"));
515 }
516 };
517
518 if private_nodes.contains(&node) {
519 let keys = match prf_keys_mul {
522 Some(ref k) => k.clone(),
523 None => {
524 return Err(runtime_error!("Propagation of annotations failed"));
525 }
526 };
527 out_graph.custom_op(
528 custom_op,
529 vec![new_input0.clone(), new_input1.clone(), keys],
530 )?
531 } else {
532 out_graph.custom_op(custom_op, vec![new_input0.clone(), new_input1.clone()])?
533 }
534 }
535 Operation::ApplyPermutation(inverse_permutation) => {
536 let dependencies = node.get_node_dependencies();
537 let input = dependencies[0].clone();
538 let permutation = dependencies[1].clone();
539 let new_input = out_mapping.get_node(input.clone());
540 let new_permutation = out_mapping.get_node(permutation.clone());
541 let custom_op = CustomOperation::new(ApplyPermutationMPC {
542 inverse_permutation,
543 reveal_output: false,
544 });
545 if private_nodes.contains(&permutation) {
546 let keys = match prf_keys_mul {
549 Some(ref k) => k.clone(),
550 None => {
551 return Err(runtime_error!("Propagation of annotations failed"));
552 }
553 };
554 out_graph.custom_op(custom_op, vec![new_input, new_permutation, keys])?
555 } else {
556 out_graph.custom_op(custom_op, vec![new_input, new_permutation])?
557 }
558 }
559 Operation::Sort(key) => {
560 let dependencies = node.get_node_dependencies();
561 let mut mapped_dependencies = dependencies
562 .into_iter()
563 .map(|d| out_mapping.get_node(d))
564 .collect::<Vec<Node>>();
565 let custom_op = CustomOperation::new(RadixSortMPC::new(key));
566 if private_nodes.contains(&node) {
567 let keys = match prf_keys_mul {
570 Some(ref k) => k.clone(),
571 None => {
572 return Err(runtime_error!("Propagation of annotations failed"));
573 }
574 };
575 mapped_dependencies.push(keys);
576 }
577 out_graph.custom_op(custom_op, mapped_dependencies)?
578 }
579 Operation::Truncate(scale) => {
580 let dependencies = node.get_node_dependencies();
581 let input = dependencies[0].clone();
582 let new_input = out_mapping.get_node(input.clone());
583 let custom_op = if scale.is_power_of_two() {
584 let k = scale.trailing_zeros() as u64;
585 CustomOperation::new(TruncateMPC2K { k })
586 } else {
587 CustomOperation::new(TruncateMPC { scale })
588 };
589 if private_nodes.contains(&input) {
590 let prf_mul_keys = match prf_keys_mul {
593 Some(ref k) => k.clone(),
594 None => {
595 return Err(runtime_error!("Propagation of annotations failed"));
596 }
597 };
598
599 if scale.is_power_of_two() {
600 let prf_truncate_keys = match prf_keys_truncate2k {
601 Some(ref k) => k.clone(),
602 None => {
603 return Err(runtime_error!("Propagation of annotations failed"));
604 }
605 };
606
607 out_graph.custom_op(
608 custom_op,
609 vec![new_input.clone(), prf_mul_keys, prf_truncate_keys],
610 )?
611 } else {
612 out_graph.custom_op(custom_op, vec![new_input.clone(), prf_mul_keys])?
613 }
614 } else {
615 out_graph.custom_op(custom_op, vec![new_input.clone()])?
616 }
617 }
618 Operation::A2B => {
619 let dependencies = node.get_node_dependencies();
620 let input = dependencies[0].clone();
621 let new_input = out_mapping.get_node(input.clone());
622 let custom_op = CustomOperation::new(A2BMPC {});
623 if private_nodes.contains(&input) {
624 let keys = match prf_keys_mul {
627 Some(ref k) => k.clone(),
628 None => {
629 return Err(runtime_error!("Propagation of annotations failed"));
630 }
631 };
632 out_graph.custom_op(custom_op, vec![new_input.clone(), keys])?
633 } else {
634 out_graph.custom_op(custom_op, vec![new_input.clone()])?
635 }
636 }
637 Operation::B2A(st) => {
638 let dependencies = node.get_node_dependencies();
639 let input = dependencies[0].clone();
640 let new_input = out_mapping.get_node(input.clone());
641 let custom_op = CustomOperation::new(B2AMPC { st });
642 if private_nodes.contains(&input) {
643 let keys_mul = match prf_keys_mul {
646 Some(ref k) => k.clone(),
647 None => {
648 return Err(runtime_error!("Propagation of annotations failed"));
649 }
650 };
651 let keys_b2a = match prf_keys_b2a {
652 Some(ref k) => k.clone(),
653 None => {
654 return Err(runtime_error!("Propagation of annotations failed"));
655 }
656 };
657 out_graph.custom_op(custom_op, vec![new_input.clone(), keys_mul, keys_b2a])?
658 } else {
659 out_graph.custom_op(custom_op, vec![new_input.clone()])?
660 }
661 }
662 Operation::Constant(t, v) => out_graph.constant(t, v)?,
663 Operation::Zeros(t) => out_graph.zeros(t)?,
664 Operation::Ones(t) => out_graph.ones(t)?,
665 Operation::PermuteAxes(_)
666 | Operation::ArrayToVector
667 | Operation::VectorToArray
668 | Operation::TupleGet(_)
669 | Operation::NamedTupleGet(_)
670 | Operation::GetSlice(_)
671 | Operation::Reshape(_)
672 | Operation::Sum(_)
673 | Operation::CumSum(_)
674 | Operation::Get(_)
675 | Operation::Repeat(_) => {
676 let dependencies = node.get_node_dependencies();
677 let input = dependencies[0].clone();
678 let new_input = out_mapping.get_node(input.clone());
679 apply_op(input, op, vec![new_input], dependencies)?
680 }
681 Operation::VectorGet => {
682 let dependencies = node.get_node_dependencies();
683 let vector = dependencies[0].clone();
684 let index = dependencies[1].clone();
685 let new_vector = out_mapping.get_node(vector.clone());
686 let new_index = out_mapping.get_node(index.clone());
687
688 apply_op(vector, op, vec![new_vector, new_index], vec![])?
689 }
690 Operation::CreateTuple
691 | Operation::CreateNamedTuple(_)
692 | Operation::CreateVector(_)
693 | Operation::Stack(_)
694 | Operation::Concatenate(_)
695 | Operation::Zip => {
696 let dependencies = node.get_node_dependencies();
697 let new_dependencies: Vec<Node> = dependencies
698 .iter()
699 .map(|x| out_mapping.get_node((*x).clone()))
700 .collect();
701 apply_op(node.clone(), op, new_dependencies, dependencies)?
702 }
703 _ => {
704 return Err(runtime_error!(
705 "MPC compilation for {} not yet implemented",
706 op
707 ));
708 }
709 };
710 if private_nodes.contains(&node) {
711 new_node = if nodes_to_reshare.contains(&node) {
712 let keys_mul = match prf_keys_mul {
714 Some(ref k) => k.clone(),
715 None => {
716 return Err(runtime_error!("Propagation of annotations failed"));
717 }
718 };
719 reshare(&new_node, &keys_mul)?
720 } else {
721 new_node
722 };
723 new_node.add_annotation(NodeAnnotation::Private)?;
724 }
725 out_mapping.insert_node(node, new_node);
726 }
727 let old_output_node = in_graph.get_output_node()?;
728 let output_node = out_mapping.get_node(old_output_node);
729 out_graph.set_output_node(output_node)?;
730 out_graph.finalize()?;
731 Ok(out_graph)
732}
733
734fn contains_node_annotation(g: Graph, annotation: NodeAnnotation) -> Result<bool> {
735 let nodes = g.get_nodes();
736 for node in nodes {
737 let annotations = node.get_annotations()?;
738 if annotations.contains(&annotation) {
739 return Ok(true);
740 }
741 }
742 Ok(false)
743}
744
745fn share_node(g: Graph, node: Node, prf_keys: Node, status: IOStatus) -> Result<Node> {
746 let mut outputs = vec![];
747 let t = node.get_type()?;
748 let node_shares = get_node_shares(g.clone(), prf_keys, t, Some((node, status)))?;
749 for (i, node_share) in node_shares.iter().enumerate().take(PARTIES) {
751 let network_node = g.nop((*node_share).clone())?;
752 let im1 = ((i + PARTIES - 1) % PARTIES) as u64;
753 network_node.add_annotation(NodeAnnotation::Send(i as u64, im1))?;
754 outputs.push(network_node);
755 }
756 g.create_tuple(outputs)
757}
758
759fn share_input(g: Graph, node: Node, t: Type, prf_keys: Node, status: IOStatus) -> Result<Node> {
760 let plain_input = g.input(t)?;
761 copy_node_name(node, plain_input.clone())?;
762 share_node(g, plain_input, prf_keys, status)
763}
764
765pub(super) fn generate_prf_key_triple(g: Graph) -> Result<Vec<Node>> {
769 let key_t = array_type(vec![KEY_LENGTH], BIT);
770 let mut triple = vec![];
771 for party_id in 0..PARTIES {
772 let key = g.random(key_t.clone())?;
773 let key_sent = g.nop(key)?;
774 let prev_party_id = (party_id + PARTIES - 1) % PARTIES;
775 key_sent.add_annotation(NodeAnnotation::Send(party_id as u64, prev_party_id as u64))?;
776 triple.push(key_sent);
777 }
778 Ok(triple)
779}
780
781fn share_all_inputs(
782 in_graph: Graph,
783 out_graph: Graph,
784 input_party_map: Vec<IOStatus>,
785 prf_keys: Node,
786 is_prf_mul_key_needed: bool,
787 is_prf_b2a_key_needed: bool,
788 is_prf_truncate_key_needed: bool,
789) -> Result<Vec<Node>> {
790 let mut shared_inputs = if is_prf_mul_key_needed {
791 vec![prf_keys.clone()]
792 } else {
793 vec![]
794 };
795 if is_prf_b2a_key_needed {
796 let prf_b2a_key = {
803 let mut keys = vec![];
804 for _ in 0..2 {
805 let key_triple = generate_prf_key_triple(out_graph.clone())?;
806 keys.push(key_triple);
807 }
808 keys[1][0] = keys[1][0].nop()?;
814 keys[1][0].add_annotation(NodeAnnotation::Send(0, 1))?;
815 keys[0][1] = keys[0][1].nop()?;
817 keys[0][1].add_annotation(NodeAnnotation::Send(1, 2))?;
818 keys[1][1] = keys[1][1].nop()?;
819 keys[1][1].add_annotation(NodeAnnotation::Send(1, 2))?;
820 keys[0][2] = keys[0][2].nop()?;
822 keys[0][2].add_annotation(NodeAnnotation::Send(2, 0))?;
823
824 let key_triple0 = out_graph.create_tuple(keys[0].clone())?;
825 let key_triple1 = out_graph.create_tuple(keys[1].clone())?;
826 out_graph.create_tuple(vec![key_triple0, key_triple1])?
827 };
828 shared_inputs.push(prf_b2a_key);
829 }
830 if is_prf_truncate_key_needed {
831 let key_t = array_type(vec![KEY_LENGTH], BIT);
832 let prf_truncate_key = out_graph.random(key_t)?;
833 shared_inputs.push(prf_truncate_key);
834 }
835
836 let mut input_id = 0usize;
837 for node in in_graph.get_nodes() {
838 if let Operation::Input(t) = node.get_operation() {
839 let shared_input = match input_party_map[input_id] {
840 IOStatus::Party(_) => share_input(
841 out_graph.clone(),
842 node.clone(),
843 t,
844 prf_keys.clone(),
845 input_party_map[input_id].clone(),
846 )?,
847 IOStatus::Shared => {
848 let new_node = out_graph.input(tuple_type(vec![t.clone(); PARTIES]))?;
849 copy_node_name(node.clone(), new_node.clone())?;
850 new_node
851 }
852 IOStatus::Public => {
853 let new_node = out_graph.input(t)?;
854 copy_node_name(node.clone(), new_node.clone())?;
855 new_node
856 }
857 };
858 input_id += 1;
859 shared_inputs.push(shared_input);
860 }
861 }
862 Ok(shared_inputs)
863}
864
865pub(super) fn recursively_sum_shares(g: Graph, shares: Vec<Node>) -> Result<Node> {
866 let t = shares[0].get_type()?;
867 match t {
868 Type::Scalar(_) | Type::Array(_, _) => {
869 let mut res = shares[0].clone();
870 for share in shares.iter().skip(1) {
871 res = res.add(share.clone())?;
872 }
873 Ok(res)
874 }
875 Type::Tuple(types) => {
876 let mut revealed_sub_nodes = vec![];
877 for i in 0..types.len() as u64 {
878 let mut sub_shares = vec![];
879 for share in &shares {
880 let sub_share = share.tuple_get(i)?;
881 sub_shares.push(sub_share);
882 }
883 let revealed_sub_node = recursively_sum_shares(g.clone(), sub_shares)?;
884 revealed_sub_nodes.push(revealed_sub_node);
885 }
886 g.create_tuple(revealed_sub_nodes)
887 }
888 Type::Vector(length, element_type) => {
889 let mut revealed_sub_nodes = vec![];
890 for i in 0..length {
891 let i_node = g.constant(scalar_type(UINT64), Value::from_scalar(i, UINT64)?)?;
892 let mut sub_shares = vec![];
893 for share in &shares {
894 let sub_share = share.vector_get(i_node.clone())?;
895 sub_shares.push(sub_share);
896 }
897 let revealed_sub_node = recursively_sum_shares(g.clone(), sub_shares)?;
898 revealed_sub_nodes.push(revealed_sub_node);
899 }
900 g.create_vector((*element_type).clone(), revealed_sub_nodes)
901 }
902 Type::NamedTuple(names_types) => {
903 let mut revealed_sub_nodes = vec![];
904 for (name, _) in names_types {
905 let mut sub_shares = vec![];
906 for share in &shares {
907 let sub_share = share.named_tuple_get(name.clone())?;
908 sub_shares.push(sub_share);
909 }
910 let revealed_sub_node = recursively_sum_shares(g.clone(), sub_shares)?;
911 revealed_sub_nodes.push((name, revealed_sub_node));
912 }
913 g.create_named_tuple(revealed_sub_nodes)
914 }
915 }
916}
917
918fn reveal_output(g: Graph, out_node: Node, output_parties: Vec<IOStatus>) -> Result<Node> {
920 if output_parties.is_empty() {
922 return Ok(out_node);
923 }
924 let mut shares = vec![];
926 for i in 0..PARTIES as u64 {
927 let share = out_node.tuple_get(i)?;
928 shares.push(share);
929 }
930 if let IOStatus::Party(id) = output_parties[0] {
931 let party_id = id as usize;
932 let mut shares_to_reveal = shares.clone();
933 let prev_party_id = (party_id + PARTIES - 1) % PARTIES;
935 let missing_share = shares_to_reveal[prev_party_id].nop()?;
936 shares_to_reveal[prev_party_id] = missing_share
937 .add_annotation(NodeAnnotation::Send(prev_party_id as u64, party_id as u64))?;
938 let revealed_node = recursively_sum_shares(g, shares_to_reveal)?;
940 let result_node = if output_parties.len() > 1 {
942 let mut send_node = revealed_node;
943 for i in 1..PARTIES {
944 let party_to_send_id = (party_id + i) % PARTIES;
945 if output_parties.contains(&IOStatus::Party(party_to_send_id as u64)) {
946 send_node = send_node.nop()?;
947 send_node.add_annotation(NodeAnnotation::Send(
948 party_id as u64,
949 party_to_send_id as u64,
950 ))?;
951 }
952 }
953 send_node.nop()?
955 } else {
956 revealed_node
957 };
958 return Ok(result_node);
959 }
960 panic!("Shouldn't be here");
961}
962
963fn compile_to_mpc_context(
971 in_context: Context,
972 input_party_map: Vec<Vec<IOStatus>>,
973 output_parties: Vec<Vec<IOStatus>>,
974 out_context: Context,
975 out_mapping: &mut ContextMappings,
976) -> Result<()> {
977 in_context.check_finalized()?;
978
979 for (i, graph) in in_context.get_graphs().iter().enumerate() {
980 let is_input_private: Vec<bool> = input_party_map[i]
982 .iter()
983 .map(|x| *x != IOStatus::Public)
984 .collect();
985 let computation_graph = compile_to_mpc_graph(
986 graph.clone(),
987 is_input_private.clone(),
988 out_context.clone(),
989 out_mapping,
990 )?;
991
992 let new_graph = out_context.create_graph()?;
993 let prf_keys = {
995 let keys_vec = generate_prf_key_triple(new_graph.clone())?;
996 new_graph.create_tuple(keys_vec)?
997 };
998
999 let is_prf_mul_key_needed =
1001 contains_node_annotation(computation_graph.clone(), NodeAnnotation::PRFMultiplication)?;
1002 let is_prf_b2a_key_needed =
1003 contains_node_annotation(computation_graph.clone(), NodeAnnotation::PRFB2A)?;
1004 let is_prf_truncate_key_needed =
1005 contains_node_annotation(computation_graph.clone(), NodeAnnotation::PRFTruncate)?;
1006 let shared_input = share_all_inputs(
1007 graph.clone(),
1008 new_graph.clone(),
1009 input_party_map[i].clone(),
1010 prf_keys.clone(),
1011 is_prf_mul_key_needed,
1012 is_prf_b2a_key_needed,
1013 is_prf_truncate_key_needed,
1014 )?;
1015
1016 let shared_result = new_graph.call(computation_graph.clone(), shared_input)?;
1018 let is_output_private = {
1020 let out_node = computation_graph.get_output_node()?;
1021 let out_anno = out_node.get_annotations()?;
1022 out_anno.contains(&NodeAnnotation::Private)
1023 };
1024 let result = if is_output_private {
1025 reveal_output(new_graph.clone(), shared_result, output_parties[i].clone())?
1026 } else if output_parties[i].is_empty() {
1027 let node = share_node(
1029 new_graph.clone(),
1030 shared_result.clone(),
1031 prf_keys,
1032 IOStatus::Party(0),
1033 )?;
1034 node.add_annotation(NodeAnnotation::Private)?
1035 } else {
1036 shared_result
1037 };
1038 result.set_as_output()?;
1039 new_graph.finalize()?;
1040 out_mapping.insert_graph(graph.clone(), new_graph);
1041 }
1042 Ok(())
1043}
1044
1045fn compile_to_mpc(
1054 context: Context,
1055 input_party_map: Vec<Vec<IOStatus>>,
1056 output_parties: Vec<Vec<IOStatus>>,
1057) -> Result<MappedContext> {
1058 for sub_map in &input_party_map {
1059 for status in sub_map {
1060 if let IOStatus::Party(id) = *status {
1061 if id >= PARTIES as u64 {
1062 return Err(runtime_error!("Input party should have a valid party ID"));
1063 }
1064 }
1065 }
1066 }
1067 for sub_parties in &output_parties {
1068 for status in sub_parties {
1069 if let IOStatus::Party(id) = *status {
1070 if id >= PARTIES as u64 {
1071 return Err(runtime_error!("Output party should have a valid party ID"));
1072 }
1073 } else {
1074 return Err(runtime_error!(
1075 "Output status should be a party id or shared"
1076 ));
1077 }
1078 }
1079 }
1080 let new_context = create_context()?;
1081 let mut context_map = ContextMappings::default();
1082 compile_to_mpc_context(
1083 context.clone(),
1084 input_party_map,
1085 output_parties,
1086 new_context.clone(),
1087 &mut context_map,
1088 )?;
1089 let old_main_graph = context.get_main_graph()?;
1090 let main_graph = context_map.get_graph(old_main_graph);
1091 new_context.set_main_graph(main_graph)?;
1092 new_context.finalize()?;
1093 let mut mapped_context = MappedContext::new(new_context);
1094 mapped_context.mappings = context_map;
1095 Ok(mapped_context)
1096}
1097
1098pub fn uniquify_prf_id(context: Context) -> Result<Context> {
1101 let new_context = create_context()?;
1102 let mut context_map = ContextMappings::default();
1103 let graphs = context.get_graphs();
1104 let mut prf_id = 0;
1105 for graph in graphs {
1106 let out_graph = new_context.create_graph()?;
1107 let nodes = graph.get_nodes();
1108 for node in nodes {
1109 let op = node.get_operation();
1110 let op = if op.is_prf_operation() {
1111 prf_id += 1;
1112 op.update_prf_id(prf_id)?
1113 } else {
1114 op
1115 };
1116 let node_dependencies = node.get_node_dependencies();
1117 let new_node_dependencies: Vec<Node> = node_dependencies
1118 .iter()
1119 .map(|x| context_map.get_node((*x).clone()))
1120 .collect();
1121 let graph_dependencies = node.get_graph_dependencies();
1122 let new_graph_dependencies: Vec<Graph> = graph_dependencies
1123 .iter()
1124 .map(|x| context_map.get_graph((*x).clone()))
1125 .collect();
1126 let new_node = out_graph.add_node(new_node_dependencies, new_graph_dependencies, op)?;
1127 let annotations = node.get_annotations()?;
1128 for anno in annotations {
1129 new_node.add_annotation(anno)?;
1130 }
1131 copy_node_name(node.clone(), new_node.clone())?;
1132 context_map.insert_node(node, new_node);
1133 }
1134 let output_node = graph.get_output_node()?;
1135 let new_output_node = context_map.get_node(output_node);
1136 out_graph.set_output_node(new_output_node)?;
1137 out_graph.finalize()?;
1138 context_map.insert_graph(graph, out_graph.clone());
1139 }
1140 let old_main_graph = context.get_main_graph()?;
1141 let main_graph = context_map.get_graph(old_main_graph);
1142 new_context.set_main_graph(main_graph)?;
1143 new_context.finalize()?;
1144 Ok(new_context)
1145}
1146
1147pub fn prepare_for_mpc_evaluation(
1152 context: Context,
1153 input_party_map: Vec<Vec<IOStatus>>,
1154 output_parties: Vec<Vec<IOStatus>>,
1155 inline_config: InlineConfig,
1156) -> Result<Context> {
1157 let mpc_context = compile_to_mpc(context, input_party_map, output_parties)?.get_context();
1158 let instantiated_context = run_instantiation_pass(mpc_context)?.get_context();
1159 let inlined_context = inline_operations(instantiated_context, inline_config)?;
1160 uniquify_prf_id(inlined_context)
1161}
1162
1163fn print_stats(graph: Graph) -> Result<()> {
1164 let mut cnt = HashMap::<String, u64>::new();
1165 for node in graph.get_nodes() {
1166 let op_name = format!("{}", node.get_operation());
1167 *cnt.entry(op_name).or_insert(0) += 1;
1168 }
1169 let mut entries: Vec<(String, u64)> = cnt.iter().map(|e| (e.0.clone(), *e.1)).collect();
1170 entries.sort_by_key(|e| -(e.1 as i64));
1171 eprintln!("-------Stats--------");
1172 eprintln!("Total ops: {}", graph.get_nodes().len());
1173 for e in entries {
1174 eprintln!("{}\t{}", e.0, e.1);
1175 }
1176 Ok(())
1177}
1178
1179pub fn prepare_context<E>(
1180 context: Context,
1181 inline_config: InlineConfig,
1182 evaluator: E,
1183 print_unoptimized_stats: bool,
1184) -> Result<Context>
1185where
1186 E: Evaluator + Sized,
1187{
1188 eprintln!("Instantiating...");
1189 let context2 = run_instantiation_pass(context)?.get_context();
1190 eprintln!("Inlining...");
1191 let context3 = inline_operations(context2, inline_config)?;
1192 if print_unoptimized_stats {
1193 print_stats(context3.get_main_graph()?)?;
1194 }
1195 eprintln!("Optimizing...");
1196 optimize_context(context3, evaluator)
1197}
1198
1199pub fn compile_context<T, E>(
1202 context: Context,
1203 input_parties: Vec<IOStatus>,
1204 output_parties: Vec<IOStatus>,
1205 inline_config: InlineConfig,
1206 get_evaluator: T,
1207) -> Result<Context>
1208where
1209 T: Fn() -> Result<E>,
1210 E: Evaluator + Sized,
1211{
1212 let evaluator0 = get_evaluator()?;
1213 let context4 = prepare_context(context, inline_config.clone(), evaluator0, true)?;
1214 print_stats(context4.get_main_graph()?)?;
1215 let mut number_of_inputs = 0;
1216 for node in context4.get_main_graph()?.get_nodes() {
1217 if node.get_operation().is_input() {
1218 number_of_inputs += 1;
1219 }
1220 }
1221 if input_parties.len() != number_of_inputs {
1222 return Err(runtime_error!(
1223 "Invalid number of input parties: {} expected, but {} found",
1224 number_of_inputs,
1225 input_parties.len()
1226 ));
1227 }
1228 eprintln!("input_parties = {input_parties:?}");
1229 eprintln!("output_parties = {output_parties:?}");
1230 let compiled_context0 = prepare_for_mpc_evaluation(
1231 context4,
1232 vec![input_parties],
1233 vec![output_parties],
1234 inline_config,
1235 )?;
1236 print_stats(compiled_context0.get_main_graph()?)?;
1237
1238 let evaluator1 = get_evaluator()?;
1239 let compiled_context = optimize_context(compiled_context0, evaluator1)?;
1240 print_stats(compiled_context.get_main_graph()?)?;
1241 Ok(compiled_context)
1242}
1243
1244#[cfg(test)]
1245mod tests {
1246 use super::*;
1247 use crate::custom_ops::run_instantiation_pass;
1248 use crate::data_types::{
1249 array_type, get_types_vector, named_tuple_type, scalar_type, tuple_type, vector_type, BIT,
1250 INT32, INT64, UINT64, UINT8, VOID_TYPE,
1251 };
1252 use crate::data_values::Value;
1253 use crate::evaluators::random_evaluate;
1254 use crate::evaluators::simple_evaluator::evaluate_add_subtract_multiply;
1255 use crate::graphs::util::simple_context;
1256 use crate::graphs::SliceElement::{Ellipsis, SubArray};
1257 use crate::inline::inline_ops::{inline_operations, InlineConfig, InlineMode};
1258 use crate::random::PRNG;
1259
1260 use std::collections::HashMap;
1261
1262 #[test]
1263 fn test_malformed() {
1264 || -> Result<()> {
1265 let c = create_context()?;
1266 assert!(compile_to_mpc(
1267 c.clone(),
1268 vec![vec![IOStatus::Public]],
1269 vec![vec![IOStatus::Party(0)]]
1270 )
1271 .is_err());
1272 let g = c.create_graph()?;
1273 g.input(scalar_type(BIT))?.set_as_output()?;
1274 g.finalize()?;
1275 c.set_main_graph(g)?;
1276 c.finalize()?;
1277 assert!(compile_to_mpc(
1278 c.clone(),
1279 vec![vec![IOStatus::Party(3)]],
1280 vec![vec![IOStatus::Party(0)]]
1281 )
1282 .is_err());
1283 assert!(compile_to_mpc(
1284 c.clone(),
1285 vec![vec![IOStatus::Public]],
1286 vec![vec![IOStatus::Party(5)]]
1287 )
1288 .is_err());
1289 assert!(compile_to_mpc(
1290 c.clone(),
1291 vec![vec![IOStatus::Public]],
1292 vec![vec![IOStatus::Shared]]
1293 )
1294 .is_err());
1295 Ok(())
1296 }()
1297 .unwrap();
1298 }
1299
1300 fn reveal_private_value(value: Value, t: Type) -> Result<Value> {
1301 let shares = value.to_vector()?;
1302 if matches!(t.clone(), Type::Array(_, _) | Type::Scalar(_)) {
1303 let mut res = Value::zero_of_type(t.clone());
1304 for share in shares {
1305 res = evaluate_add_subtract_multiply(
1306 t.clone(),
1307 res.clone(),
1308 t.clone(),
1309 share,
1310 Operation::Add,
1311 t.clone(),
1312 )?;
1313 }
1314 return Ok(res);
1315 }
1316
1317 let vector_types = get_types_vector(t.clone())?;
1318 let mut shares_vec = vec![];
1319 for i in 0..PARTIES {
1320 shares_vec.push(shares[i].to_vector()?);
1321 }
1322 let mut res_vec = vec![];
1323 for i in 0..vector_types.len() {
1324 let mut tuple_vec = vec![];
1325 for j in 0..PARTIES {
1326 tuple_vec.push(shares_vec[j][i].clone());
1327 }
1328 let tuple = Value::from_vector(tuple_vec);
1329 res_vec.push(reveal_private_value(tuple, (*vector_types[i]).clone())?);
1330 }
1331 Ok(Value::from_vector(res_vec))
1332 }
1333
1334 #[test]
1335 fn test_input() {
1336 let seed = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F";
1337 let mut prng = PRNG::new(Some(*seed)).unwrap();
1338 let mut helper =
1339 |t: Type, input_status: IOStatus, output_parties: Vec<IOStatus>| -> Result<()> {
1340 let c = simple_context(|g| g.input(t.clone()))?;
1341 let mpc_mapped_context = compile_to_mpc(
1342 c,
1343 vec![vec![input_status.clone()]],
1344 vec![output_parties.clone()],
1345 )?;
1346 let mpc_context = mpc_mapped_context.get_context();
1347 let mpc_graph = mpc_context.get_main_graph()?;
1348 let mut inputs = vec![];
1349 if input_status == IOStatus::Shared {
1350 let tuple_t = tuple_type(vec![t.clone(); PARTIES]);
1351 inputs.push(prng.get_random_value(tuple_t.clone())?);
1352 } else {
1353 inputs.push(prng.get_random_value(t.clone())?);
1354 }
1355 let output = random_evaluate(mpc_graph.clone(), inputs.clone())?;
1356
1357 let mpc_computation_graph = mpc_context.get_graphs()[0].clone();
1358
1359 let computation_output_node = mpc_computation_graph.get_output_node()?;
1360 let computation_output_annotations = computation_output_node.get_annotations()?;
1361 if input_status != IOStatus::Public {
1362 let expected = if input_status == IOStatus::Shared {
1363 reveal_private_value(inputs[0].clone(), t.clone())?
1364 } else {
1365 inputs[0].clone()
1366 };
1367 if output_parties.is_empty() {
1369 let revealed_output = reveal_private_value(output.clone(), t.clone())?;
1370 assert!(output.check_type(tuple_type(vec![t.clone(); PARTIES]))?);
1371 assert_eq!(revealed_output, expected);
1372 } else {
1373 assert!(output.check_type(t.clone())?);
1375 assert_eq!(output, expected.clone());
1376 }
1377 assert!(computation_output_annotations.contains(&NodeAnnotation::Private));
1378 } else {
1379 if output_parties.is_empty() {
1381 let revealed_output = reveal_private_value(output.clone(), t.clone())?;
1382 assert!(output.check_type(tuple_type(vec![t.clone(); PARTIES]))?);
1383 assert_eq!(revealed_output, inputs[0]);
1384 let output_annotations = mpc_graph.get_output_node()?.get_annotations()?;
1386 assert!(output_annotations.contains(&NodeAnnotation::Private));
1387 } else {
1388 assert_eq!(output, inputs[0]);
1389 }
1390 assert!(!computation_output_annotations.contains(&NodeAnnotation::Private));
1392 }
1393 Ok(())
1394 };
1395
1396 helper(
1397 array_type(vec![2, 2], INT64),
1398 IOStatus::Party(0),
1399 vec![IOStatus::Party(1)],
1400 )
1401 .unwrap();
1402 helper(
1403 array_type(vec![2, 2], INT64),
1404 IOStatus::Public,
1405 vec![IOStatus::Party(0)],
1406 )
1407 .unwrap();
1408 helper(
1409 scalar_type(UINT64),
1410 IOStatus::Party(1),
1411 vec![IOStatus::Party(1), IOStatus::Party(2)],
1412 )
1413 .unwrap();
1414 helper(
1415 scalar_type(UINT64),
1416 IOStatus::Public,
1417 vec![IOStatus::Party(0)],
1418 )
1419 .unwrap();
1420 helper(
1421 scalar_type(UINT64),
1422 IOStatus::Shared,
1423 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1424 )
1425 .unwrap();
1426 helper(scalar_type(UINT64), IOStatus::Shared, vec![]).unwrap();
1427 helper(scalar_type(UINT64), IOStatus::Public, vec![]).unwrap();
1428 }
1429
1430 fn prepare_private_value(value: Value, t: Type) -> Result<Vec<Value>> {
1431 if let Type::Scalar(st) | Type::Array(_, st) = t.clone() {
1434 let mut res = vec![];
1435 let zero = Value::zero_of_type(t.clone());
1436 let one = Value::from_scalar(1, st)?;
1437 let two = Value::from_scalar(2, st)?;
1438 for i in 0..PARTIES {
1439 let (add_sub, l_value, r_value) = match i {
1440 0 => (Operation::Add, value.clone(), two.clone()),
1441 1 => (Operation::Subtract, zero.clone(), one.clone()),
1442 2 => (Operation::Subtract, zero.clone(), one.clone()),
1443 _ => panic!("More than 3 parties are not supported"),
1444 };
1445 let share = evaluate_add_subtract_multiply(
1446 t.clone(),
1447 l_value,
1448 scalar_type(st),
1449 r_value,
1450 add_sub,
1451 t.clone(),
1452 )?;
1453 res.push(share);
1454 }
1455 return Ok(res);
1456 }
1457
1458 let vector_types = get_types_vector(t.clone())?;
1459 let mut shares = vec![vec![]; PARTIES];
1460 value.access_vector(|vector_values| {
1461 for i in 0..vector_values.len() {
1462 let tuple_i =
1463 prepare_private_value(vector_values[i].clone(), (*vector_types[i]).clone())?;
1464 for j in 0..PARTIES {
1465 shares[j].push(tuple_i[j].clone())
1466 }
1467 }
1468 Ok(())
1469 })?;
1470 let mut res = vec![];
1471 for share in shares {
1472 res.push(Value::from_vector(share));
1473 }
1474 Ok(res)
1475 }
1476
1477 fn prepare_value(value: Value, t: Type, is_input_private: bool) -> Result<Value> {
1478 if is_input_private {
1479 let tuple = prepare_private_value(value, t)?;
1480 return Ok(Value::from_vector(tuple));
1481 }
1482 Ok(value)
1483 }
1484
1485 fn prepare_input(
1486 input_types: Vec<Type>,
1487 is_input_shared: Vec<bool>,
1488 ) -> Result<(Vec<Value>, Vec<Value>)> {
1489 let seed: [u8; 16] = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
1490 let mut prng = PRNG::new(Some(seed))?;
1491 let mut plain_inputs = vec![];
1492 let mut mpc_inputs = vec![];
1493 for i in 0..input_types.len() {
1494 let random_value = prng.get_random_value(input_types[i].clone())?;
1495 plain_inputs.push(random_value.clone());
1496 mpc_inputs.push(prepare_value(
1497 random_value,
1498 input_types[i].clone(),
1499 is_input_shared[i].clone(),
1500 )?);
1501 }
1502 Ok((plain_inputs, mpc_inputs))
1503 }
1504
1505 fn check_output(
1506 plain_graph: Graph,
1507 mpc_graph: Graph,
1508 plain_inputs: Vec<Value>,
1509 mpc_inputs: Vec<Value>,
1510 output_parties: Vec<IOStatus>,
1511 t: Type,
1512 ) -> Result<()> {
1513 let plain_output = random_evaluate(plain_graph.clone(), plain_inputs)?;
1514 let mpc_output = random_evaluate(mpc_graph.clone(), mpc_inputs)?;
1515
1516 if output_parties.is_empty() {
1517 assert!(mpc_output.check_type(tuple_type(vec![t.clone(); PARTIES]))?);
1519 let value_revealed = reveal_private_value(mpc_output.clone(), t.clone())?;
1520 assert_eq!(value_revealed, plain_output);
1521 } else {
1522 assert!(mpc_output.check_type(t.clone())?);
1523 assert_eq!(mpc_output, plain_output);
1524 }
1525
1526 Ok(())
1527 }
1528
1529 fn helper_one_input(
1530 input_types: Vec<Type>,
1531 op: Operation,
1532 input_party_map: Vec<IOStatus>,
1533 output_parties: Vec<IOStatus>,
1534 ) -> Result<()> {
1535 let c = simple_context(|g| {
1536 let mut input_nodes = vec![];
1537 for i in 0..input_types.len() {
1538 let input_node = g.input(input_types[i].clone())?;
1539 input_node.set_name(&format!("Input {}", i))?;
1540 input_nodes.push(input_node);
1541 }
1542 let o = if op != Operation::VectorGet {
1543 g.add_node(input_nodes, vec![], op)?
1544 } else {
1545 input_nodes[0].vector_get(g.zeros(scalar_type(UINT64))?)?
1546 };
1547 o.set_name("Plaintext operation")?;
1548 Ok(o)
1549 })?;
1550 let g = c.get_main_graph()?;
1551 let output_type = g.get_output_node()?.get_type()?;
1552
1553 let inline_config = InlineConfig {
1554 default_mode: InlineMode::Simple,
1555 ..Default::default()
1556 };
1557 let mpc_c = prepare_for_mpc_evaluation(
1558 c.clone(),
1559 vec![input_party_map.clone()],
1560 vec![output_parties.clone()],
1561 inline_config,
1562 )?;
1563 let mpc_graph = mpc_c.get_main_graph()?;
1564 let mpc_node_result = mpc_c.retrieve_node(mpc_graph.clone(), "Plaintext operation");
1566 assert!(mpc_node_result.is_err());
1567 for i in 0..input_types.len() {
1568 let node_name = format!("Input {}", i);
1569 let new_input_node = mpc_c.retrieve_node(mpc_graph.clone(), &node_name);
1570 assert!(new_input_node.is_ok());
1571 }
1572
1573 let is_input_shared = input_party_map
1574 .iter()
1575 .map(|x| *x == IOStatus::Shared)
1576 .collect();
1577 let (plain_inputs, mpc_inputs) = prepare_input(input_types.clone(), is_input_shared)?;
1578
1579 check_output(
1580 g,
1581 mpc_graph,
1582 plain_inputs,
1583 mpc_inputs,
1584 output_parties,
1585 output_type,
1586 )?;
1587 Ok(())
1588 }
1589
1590 fn test_helper_one_input(input_type: Type, op: Operation) -> Result<()> {
1591 helper_one_input(
1592 vec![input_type.clone()],
1593 op.clone(),
1594 vec![IOStatus::Party(0)],
1595 vec![IOStatus::Party(0)],
1596 )?;
1597 helper_one_input(
1598 vec![input_type.clone()],
1599 op.clone(),
1600 vec![IOStatus::Party(0)],
1601 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1602 )?;
1603 helper_one_input(
1604 vec![input_type.clone()],
1605 op.clone(),
1606 vec![IOStatus::Public],
1607 vec![IOStatus::Party(1)],
1608 )?;
1609 helper_one_input(
1610 vec![input_type.clone()],
1611 op.clone(),
1612 vec![IOStatus::Shared],
1613 vec![IOStatus::Party(0), IOStatus::Party(1)],
1614 )?;
1615 helper_one_input(
1616 vec![input_type.clone()],
1617 op.clone(),
1618 vec![IOStatus::Party(0)],
1619 vec![],
1620 )?;
1621 helper_one_input(vec![input_type], op, vec![IOStatus::Public], vec![])?;
1622 Ok(())
1623 }
1624
1625 #[test]
1626 fn test_permute_axes() {
1627 test_helper_one_input(
1628 array_type(vec![4, 2, 3], INT32),
1629 Operation::PermuteAxes(vec![2, 0, 1]),
1630 )
1631 .unwrap();
1632 }
1633
1634 #[test]
1635 fn test_array_to_vector() {
1636 test_helper_one_input(array_type(vec![3, 1], UINT8), Operation::ArrayToVector).unwrap();
1637 }
1638
1639 #[test]
1640 fn test_vector_to_array() {
1641 test_helper_one_input(
1642 vector_type(10, array_type(vec![4, 3], INT32)),
1643 Operation::VectorToArray,
1644 )
1645 .unwrap();
1646 }
1647
1648 #[test]
1649 fn test_vector_get() {
1650 test_helper_one_input(
1651 vector_type(10, array_type(vec![4, 3], INT32)),
1652 Operation::VectorGet,
1653 )
1654 .unwrap();
1655 }
1656
1657 #[test]
1658 fn test_get_slice() {
1659 test_helper_one_input(
1660 array_type(vec![10, 128], INT32),
1661 Operation::GetSlice(vec![Ellipsis, SubArray(None, Some(-1), None)]),
1662 )
1663 .unwrap();
1664 }
1665
1666 #[test]
1667 fn test_reshape() {
1668 test_helper_one_input(
1669 array_type(vec![10, 128], INT32),
1670 Operation::Reshape(array_type(vec![20, 64], INT32)),
1671 )
1672 .unwrap();
1673 }
1674
1675 #[test]
1676 fn test_tuple_get() {
1677 let t = array_type(vec![10, 128], INT32);
1678 test_helper_one_input(
1679 tuple_type(vec![t.clone(), scalar_type(UINT64), t]),
1680 Operation::TupleGet(1),
1681 )
1682 .unwrap();
1683 }
1684
1685 #[test]
1686 fn test_named_tuple_get() {
1687 let t = array_type(vec![10, 128], INT32);
1688 test_helper_one_input(
1689 named_tuple_type(vec![
1690 ("a".to_owned(), t.clone()),
1691 ("b".to_owned(), scalar_type(UINT64)),
1692 ("c".to_owned(), t),
1693 ]),
1694 Operation::NamedTupleGet("b".to_string()),
1695 )
1696 .unwrap();
1697 }
1698
1699 #[test]
1700 fn test_sum() {
1701 test_helper_one_input(
1702 array_type(vec![10, 5, 12], INT32),
1703 Operation::Sum(vec![0, 1]),
1704 )
1705 .unwrap();
1706 }
1707
1708 #[test]
1709 fn test_cum_sum() -> Result<()> {
1710 test_helper_one_input(array_type(vec![10, 5, 12], INT32), Operation::CumSum(0))?;
1711 test_helper_one_input(array_type(vec![10, 5, 12], INT32), Operation::CumSum(1))?;
1712 test_helper_one_input(array_type(vec![10, 5, 12], INT32), Operation::CumSum(2))?;
1713 Ok(())
1714 }
1715
1716 #[test]
1717 fn test_get() {
1718 test_helper_one_input(array_type(vec![10, 128], INT32), Operation::Get(vec![5, 4]))
1719 .unwrap();
1720 }
1721
1722 #[test]
1723 fn test_repeat() {
1724 test_helper_one_input(array_type(vec![10, 128], INT32), Operation::Repeat(10)).unwrap();
1725 }
1726
1727 fn helper_create_ops(
1728 input_types: Vec<Type>,
1729 op: Operation,
1730 input_party_map: Vec<IOStatus>,
1731 output_parties: Vec<IOStatus>,
1732 include_constant: bool,
1733 ) -> Result<()> {
1734 let c = create_context()?;
1735 let g = c.create_graph()?;
1736 let mut input_nodes = vec![];
1737 for i in 0..input_types.len() {
1738 let input_node = g.input(input_types[i].clone())?;
1739 input_node.set_name(&format!("Input {}", i))?;
1740 input_nodes.push(input_node);
1741 }
1742 let resolved_op = if include_constant {
1743 input_nodes.push(g.constant(
1744 input_types[0].clone(),
1745 Value::zero_of_type(input_types[0].clone()),
1746 )?);
1747 match op {
1748 Operation::CreateNamedTuple(mut names) => {
1749 names.push("const".to_owned());
1750 Operation::CreateNamedTuple(names)
1751 }
1752 Operation::Stack(outer_shape) => {
1753 let mut pr = 1;
1754 for x in &outer_shape {
1755 pr *= *x;
1756 }
1757 Operation::Stack(vec![pr + 1])
1758 }
1759 _ => op,
1760 }
1761 } else {
1762 op
1763 };
1764 let o = g.add_node(input_nodes, vec![], resolved_op)?;
1765 o.set_name("Plaintext operation")?;
1766 let output_type = o.get_type()?;
1767 g.set_output_node(o.clone())?;
1768 g.finalize()?;
1769 c.set_main_graph(g.clone())?;
1770 c.finalize()?;
1771
1772 let inline_config = InlineConfig {
1773 default_mode: InlineMode::Simple,
1774 ..Default::default()
1775 };
1776 let mpc_c = prepare_for_mpc_evaluation(
1777 c.clone(),
1778 vec![input_party_map.clone()],
1779 vec![output_parties.clone()],
1780 inline_config,
1781 )?;
1782 let mpc_graph = mpc_c.get_main_graph()?;
1783 let mpc_node_result = mpc_c.retrieve_node(mpc_graph.clone(), "Plaintext operation");
1785 assert!(mpc_node_result.is_err());
1786 for i in 0..input_types.len() {
1787 let node_name = format!("Input {}", i);
1788 let new_input_node = mpc_c.retrieve_node(mpc_graph.clone(), &node_name);
1789 assert!(new_input_node.is_ok());
1790 }
1791
1792 let is_input_shared = input_party_map
1793 .iter()
1794 .map(|x| *x == IOStatus::Shared)
1795 .collect();
1796 let (plain_inputs, mpc_inputs) = prepare_input(input_types.clone(), is_input_shared)?;
1797
1798 check_output(
1799 g,
1800 mpc_graph,
1801 plain_inputs,
1802 mpc_inputs,
1803 output_parties,
1804 output_type,
1805 )?;
1806 Ok(())
1807 }
1808
1809 fn test_helper_create_ops(input_types: Vec<Type>, op: Operation) -> Result<()> {
1810 helper_create_ops(
1811 input_types.clone(),
1812 op.clone(),
1813 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1814 vec![IOStatus::Party(0)],
1815 true,
1816 )?;
1817 helper_create_ops(
1818 input_types.clone(),
1819 op.clone(),
1820 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1821 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1822 true,
1823 )?;
1824 helper_create_ops(
1825 input_types.clone(),
1826 op.clone(),
1827 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1828 vec![IOStatus::Party(0)],
1829 true,
1830 )?;
1831 helper_create_ops(
1832 input_types.clone(),
1833 op.clone(),
1834 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1835 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1836 true,
1837 )?;
1838
1839 helper_create_ops(
1840 input_types.clone(),
1841 op.clone(),
1842 vec![IOStatus::Party(0), IOStatus::Public, IOStatus::Party(1)],
1843 vec![IOStatus::Party(0)],
1844 true,
1845 )?;
1846 helper_create_ops(
1847 input_types.clone(),
1848 op.clone(),
1849 vec![IOStatus::Party(0), IOStatus::Public, IOStatus::Party(1)],
1850 vec![IOStatus::Party(0)],
1851 true,
1852 )?;
1853
1854 helper_create_ops(
1855 input_types.clone(),
1856 op.clone(),
1857 vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1858 vec![IOStatus::Party(0)],
1859 false,
1860 )?;
1861 helper_create_ops(
1862 input_types.clone(),
1863 op.clone(),
1864 vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1865 vec![IOStatus::Party(0)],
1866 false,
1867 )?;
1868 helper_create_ops(
1869 input_types.clone(),
1870 op.clone(),
1871 vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1872 vec![IOStatus::Party(0), IOStatus::Party(1), IOStatus::Party(2)],
1873 false,
1874 )?;
1875 helper_create_ops(
1876 input_types.clone(),
1877 op.clone(),
1878 vec![IOStatus::Shared, IOStatus::Shared, IOStatus::Party(0)],
1879 vec![IOStatus::Party(0), IOStatus::Party(1)],
1880 true,
1881 )?;
1882 helper_create_ops(
1883 input_types.clone(),
1884 op.clone(),
1885 vec![IOStatus::Shared, IOStatus::Shared, IOStatus::Party(0)],
1886 vec![],
1887 true,
1888 )?;
1889 helper_create_ops(
1890 input_types.clone(),
1891 op.clone(),
1892 vec![IOStatus::Public, IOStatus::Public, IOStatus::Public],
1893 vec![],
1894 true,
1895 )?;
1896 Ok(())
1897 }
1898
1899 #[test]
1900 fn test_create_tuple() {
1901 let t = array_type(vec![10, 128], INT32);
1902 test_helper_create_ops(
1903 vec![t.clone(), scalar_type(UINT64), t.clone()],
1904 Operation::CreateTuple,
1905 )
1906 .unwrap();
1907 test_helper_create_ops(vec![t, VOID_TYPE], Operation::CreateTuple).unwrap();
1908 }
1909
1910 #[test]
1911 fn test_create_named_tuple() {
1912 let t = array_type(vec![10, 128], INT32);
1913 test_helper_create_ops(
1914 vec![t.clone(), scalar_type(UINT64), t.clone()],
1915 Operation::CreateNamedTuple(vec!["a".to_owned(), "b".to_owned(), "c".to_owned()]),
1916 )
1917 .unwrap();
1918 test_helper_create_ops(
1919 vec![t, VOID_TYPE],
1920 Operation::CreateNamedTuple(vec!["a".to_owned(), "b".to_owned()]),
1921 )
1922 .unwrap();
1923 }
1924
1925 #[test]
1926 fn test_create_vector() {
1927 let t = array_type(vec![10, 128], INT32);
1928 test_helper_create_ops(vec![t.clone(); 3], Operation::CreateVector(t)).unwrap();
1929 }
1930
1931 #[test]
1932 fn test_zip() {
1933 let t = vector_type(10, array_type(vec![4, 3], INT32));
1934 test_helper_create_ops(vec![t.clone(); 3], Operation::Zip).unwrap();
1935 }
1936 #[test]
1937 fn test_stack() {
1938 let t = array_type(vec![10, 128], INT32);
1939 test_helper_create_ops(vec![t.clone(); 3], Operation::Stack(vec![3])).unwrap();
1940 }
1941 #[test]
1942 fn test_concatenate() {
1943 let t1 = array_type(vec![10, 1, 10], INT32);
1944 let t2 = array_type(vec![10, 2, 10], INT32);
1945 let t3 = array_type(vec![10, 3, 10], INT32);
1946 test_helper_create_ops(vec![t1, t2, t3], Operation::Concatenate(1)).unwrap();
1947 }
1948
1949 fn check_prf_id(context: Context) -> Result<()> {
1951 let mut iv_node_map: HashMap<u64, Node> = HashMap::new();
1952 let graphs = context.get_graphs();
1953 for graph in graphs {
1954 let nodes = graph.get_nodes();
1955 for node in nodes {
1956 let iv = match node.get_operation() {
1957 Operation::PRF(iv, _) => iv,
1958 Operation::PermutationFromPRF(iv, _) => iv,
1959 _ => continue,
1960 };
1961 if let Some(other_node) = iv_node_map.get(&iv) {
1962 if *other_node != node {
1963 return Err(runtime_error!("PRF node with non-unique iv"));
1964 }
1965 } else {
1966 iv_node_map.insert(iv, node);
1967 }
1968 }
1969 }
1970 Ok(())
1971 }
1972
1973 #[test]
1974 fn test_prf_id() {
1975 || -> Result<()> {
1976 let c = create_context()?;
1977 let g1 = c.create_graph()?;
1978 {
1979 let i = g1.input(scalar_type(UINT64))?;
1980 let o = i.a2b()?;
1981 g1.set_output_node(o)?;
1982 g1.finalize()?;
1983 }
1984 let g2 = c.create_graph()?;
1985 {
1986 let i = g2.input(scalar_type(UINT64))?;
1987 let o = i.a2b()?;
1988 g2.set_output_node(o)?;
1989 g2.finalize()?;
1990 }
1991
1992 c.set_main_graph(g2)?;
1993 c.finalize()?;
1994
1995 let mpc_c = compile_to_mpc(
1996 c,
1997 vec![vec![IOStatus::Party(0)], vec![IOStatus::Party(1)]],
1998 vec![vec![IOStatus::Party(1)], vec![IOStatus::Party(2)]],
1999 )?
2000 .get_context();
2001 let instantiated_context = run_instantiation_pass(mpc_c)?.get_context();
2002 assert!(check_prf_id(instantiated_context.clone()).is_err());
2003 let inlined_context = inline_operations(
2004 instantiated_context.clone(),
2005 InlineConfig {
2006 default_mode: InlineMode::Simple,
2007 ..Default::default()
2008 },
2009 )?;
2010 assert!(check_prf_id(inlined_context.clone()).is_err());
2011
2012 let validated_instantiated_context = uniquify_prf_id(instantiated_context)?;
2013 assert!(check_prf_id(validated_instantiated_context).is_ok());
2014 let validated_inlined_context = uniquify_prf_id(inlined_context)?;
2015 assert!(check_prf_id(validated_inlined_context).is_ok());
2016 Ok(())
2017 }()
2018 .unwrap()
2019 }
2020
2021 #[test]
2022 fn test_prf_ids_for_permutation_from_prf() -> Result<()> {
2023 let c = create_context()?;
2024 let g1 = c.create_graph()?;
2025 {
2026 let k = g1.random(array_type(vec![128], BIT))?;
2027 g1.permutation_from_prf(k, 0, 10)?.set_as_output()?;
2028 g1.finalize()?;
2029 }
2030 let g2 = c.create_graph()?;
2031 {
2032 let k = g2.random(array_type(vec![128], BIT))?;
2033 g2.prf(k, 0, scalar_type(UINT64))?.set_as_output()?;
2034 g2.finalize()?;
2035 }
2036 let g3 = c.create_graph()?;
2037 {
2038 let k = g3.random(array_type(vec![128], BIT))?;
2039 g3.permutation_from_prf(k, 0, 11)?.set_as_output()?;
2040 g3.finalize()?;
2041 }
2042 let g4 = c.create_graph()?;
2043 {
2044 let k = g4.random(array_type(vec![128], BIT))?;
2045 g4.prf(k, 0, scalar_type(INT32))?.set_as_output()?;
2046 g4.finalize()?;
2047 }
2048
2049 c.set_main_graph(g2)?;
2050 c.finalize()?;
2051 assert!(check_prf_id(c.clone()).is_err());
2052
2053 let c = uniquify_prf_id(c)?;
2054 assert!(check_prf_id(c).is_ok());
2055 Ok(())
2056 }
2057
2058 #[test]
2059 fn test_resharing() -> Result<()> {
2060 {
2061 let c = create_context()?;
2062 let g = c.create_graph()?;
2063 let i1 = g.input(array_type(vec![2, 10], BIT))?;
2064 let i2 = g.input(array_type(vec![10, 3], BIT))?;
2065 let prod = i1.matmul(i2)?;
2066 let out = prod.sum(vec![0])?;
2067 out.set_as_output()?;
2068 g.finalize()?;
2069 g.set_as_main()?;
2070 c.finalize()?;
2071
2072 let shared_nodes = propagate_private_annotations(g.clone(), vec![true, true])?.0;
2073 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2074
2075 assert!(reshared_nodes.len() == 1);
2076 assert!(reshared_nodes.contains(&out));
2077
2078 let shared_nodes = propagate_private_annotations(g.clone(), vec![false, true])?.0;
2079 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2080
2081 assert!(reshared_nodes.len() == 0);
2082
2083 let shared_nodes = propagate_private_annotations(g.clone(), vec![false, false])?.0;
2084 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2085
2086 assert!(reshared_nodes.len() == 0);
2087 }
2088
2089 {
2090 let c = create_context()?;
2091 let g = c.create_graph()?;
2092 let i1 = g.input(array_type(vec![2, 10], BIT))?;
2093 let i2 = g.input(array_type(vec![10, 3], BIT))?;
2094 let prod = i1.matmul(i2)?;
2095 let i3 = g.input(array_type(vec![3, 4], BIT))?;
2096 let out = prod.matmul(i3)?;
2097 out.set_as_output()?;
2098 g.finalize()?;
2099 g.set_as_main()?;
2100 c.finalize()?;
2101
2102 let shared_nodes = propagate_private_annotations(g.clone(), vec![true, true, true])?.0;
2103 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2104
2105 assert!(reshared_nodes.len() == 2);
2106 assert!(reshared_nodes.contains(&prod));
2107 assert!(reshared_nodes.contains(&out));
2108
2109 let shared_nodes = propagate_private_annotations(g.clone(), vec![false, true, true])?.0;
2110 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2111
2112 assert!(reshared_nodes.len() == 1);
2113 assert!(reshared_nodes.contains(&out));
2114
2115 let shared_nodes = propagate_private_annotations(g.clone(), vec![true, true, false])?.0;
2116 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2117
2118 assert!(reshared_nodes.len() == 1);
2119 assert!(reshared_nodes.contains(&prod));
2120 }
2121 {
2122 let c = create_context()?;
2123 let g = c.create_graph()?;
2124 let i1 = g.input(array_type(vec![2, 10], BIT))?;
2125 let i2 = g.input(array_type(vec![10, 3], BIT))?;
2126 let prod12 = i1.matmul(i2)?;
2127 let i3 = g.input(array_type(vec![2, 4], BIT))?;
2128 let i4 = g.input(array_type(vec![4, 3], BIT))?;
2129 let prod34 = i3.matmul(i4)?;
2130 let out = prod12.add(prod34)?;
2131 out.set_as_output()?;
2132 g.finalize()?;
2133 g.set_as_main()?;
2134 c.finalize()?;
2135
2136 let shared_nodes = propagate_private_annotations(g.clone(), vec![true; 4])?.0;
2137 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2138
2139 assert!(reshared_nodes.len() == 1);
2140 assert!(reshared_nodes.contains(&out));
2141 }
2142 {
2143 let c = create_context()?;
2144 let g = c.create_graph()?;
2145 let i1 = g.input(array_type(vec![2, 10], INT32))?;
2146 let i2 = g.input(array_type(vec![10, 3], INT32))?;
2147 let prod12 = i1.matmul(i2)?;
2148 let i3 = g.input(array_type(vec![2, 3], INT32))?;
2149 let s123 = prod12.add(i3)?;
2150 let out = s123.a2b()?;
2151 out.set_as_output()?;
2152 g.finalize()?;
2153 g.set_as_main()?;
2154 c.finalize()?;
2155
2156 let shared_nodes = propagate_private_annotations(g.clone(), vec![true; 3])?.0;
2157 let reshared_nodes = get_nodes_to_reshare(&g, &shared_nodes)?;
2158
2159 assert!(reshared_nodes.len() == 1);
2160 assert!(reshared_nodes.contains(&s123));
2161 }
2162
2163 Ok(())
2164 }
2165}