ciphercore_base/graphs.rs
1//! Crucial structs, enums, functions and types to create computation graphs.
2use atomic_refcell::AtomicRefCell;
3use std::collections::HashMap;
4use std::fmt;
5use std::hash::Hash;
6use std::hash::Hasher;
7use std::ptr;
8use std::sync::Arc;
9use std::sync::Weak;
10
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13use crate::constants::type_size_limit_constants;
14use crate::custom_ops::CustomOperation;
15use crate::data_types::{get_size_estimation_in_bits, ArrayShape, ScalarType, Type};
16use crate::data_values::Value;
17use crate::errors::Result;
18use crate::type_inference::{create_type_inference_worker, TypeInferenceWorker};
19
20use crate::version::{VersionedData, DATA_VERSION};
21
22#[cfg(feature = "py-binding")]
23use crate::custom_ops::PyBindingCustomOperation;
24#[cfg(feature = "py-binding")]
25use crate::data_types::{PyBindingScalarType, PyBindingType};
26#[cfg(feature = "py-binding")]
27use crate::typed_value::PyBindingTypedValue;
28#[cfg(feature = "py-binding")]
29use pywrapper_macro::{enum_to_struct_wrapper, fn_wrapper, impl_wrapper, struct_wrapper};
30
31/// This enum represents different types of slice elements that are used to create indexing slices (see [Slice] and [Graph::get_slice]).
32///
33/// The semantics is similar to [the NumPy slice indexing](https://numpy.org/doc/stable/user/basics.indexing.html).
34#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
35#[cfg_attr(feature = "py-binding", enum_to_struct_wrapper)]
36pub enum SliceElement {
37 /// Single index of a given array dimension.
38 ///
39 /// The index is given by a signed integer. If negative, the index is interpreted as in [NumPy](https://numpy.org/doc/stable/user/basics.indexing.html).
40 ///
41 /// For example, to choose all the elements of an array with the last index in the first dimension and the first index in the second dimension, one can use a slice `vec![SingleIndex(-1), SingleIndex(0)]`.
42 SingleIndex(i64),
43 /// Sub-array denotes a range of indices of a given array dimension.
44 ///
45 /// It follows the description of [the NumPy basic slice](https://numpy.org/doc/stable/user/basics.indexing.html), which is defined by 3 signed integers: `start`, `stop`, `step`. `step` can't be equal to zero.
46 ///
47 /// For example, to choose all the elements of an array with even indices in the first dimension, one can use a slice `vec![SubArray(Some(0), None, Some(2))].
48 SubArray(Option<i64>, Option<i64>, Option<i64>),
49 /// Ellipsis denotes several dimensions where indices are not restricted.
50 ///
51 /// For example, to choose all the elements of an array with index `0` in the first dimension and index `2` in the last dimension, one can use a slice `vec![SingleIndex(0), Ellipsis, SingleIndex(2)]`.
52 Ellipsis,
53}
54
55/// Slice type denotes an indexing slice (see [NumPy slicing](https://numpy.org/doc/stable/user/basics.indexing.html)).
56///
57/// It is a vector of slice elements that describes the indices of a sub-array in any appropriate array.
58pub type Slice = Vec<SliceElement>;
59
60#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Copy)]
61#[cfg_attr(feature = "py-binding", enum_to_struct_wrapper)]
62pub enum JoinType {
63 Inner,
64 Left,
65 Union,
66 Full,
67}
68
69#[doc(hidden)]
70#[cfg(feature = "py-binding")]
71#[pyo3::pymethods]
72impl PyBindingJoinType {
73 #[staticmethod]
74 pub fn from_inner() -> Self {
75 PyBindingJoinType {
76 inner: JoinType::Inner,
77 }
78 }
79 #[staticmethod]
80 pub fn from_left() -> Self {
81 PyBindingJoinType {
82 inner: JoinType::Left,
83 }
84 }
85 #[staticmethod]
86 pub fn from_union() -> Self {
87 PyBindingJoinType {
88 inner: JoinType::Union,
89 }
90 }
91}
92
93/// Shard config contains the parameters of the Sharding operation, namely:
94///
95/// - number of shards into which input dataset will be split,
96/// - size of each shard, i.e., the number of rows in each shard,
97/// - headers of columns whose rows are hashed to find the index of a shard where the corresponding row will be placed.
98#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
99#[cfg_attr(feature = "py-binding", struct_wrapper)]
100pub struct ShardConfig {
101 pub num_shards: u64,
102 pub shard_size: u64,
103 /// headers of columns whose rows are hashed to find the index of a shard where the corresponding row will be placed
104 pub shard_headers: Vec<String>,
105}
106
107#[doc(hidden)]
108#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
109pub enum Operation {
110 Input(Type),
111 Zeros(Type),
112 Ones(Type),
113 Add,
114 Subtract,
115 Multiply,
116 // Elementwise multiplication of integer arrays by bit arrays.
117 // It leaves an integer array element as is or make it zero it depending on a bit array element.
118 MixedMultiply,
119 // Dot operation follows the numpy (tensor)dot semantics: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
120 Dot,
121 // Matmul operation follows the numpy matmul semantics: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html
122 // In particular, unlike Dot, it doesn't support scalar inputs.
123 Matmul,
124 Gemm(bool, bool),
125 Truncate(u128),
126 Sum(ArrayShape),
127 CumSum(u64),
128 PermuteAxes(ArrayShape),
129 Get(ArrayShape),
130 GetSlice(Slice),
131 Reshape(Type),
132 NOP,
133 Random(Type),
134 PRF(u64, Type),
135 PermutationFromPRF(u64, u64),
136 Stack(ArrayShape),
137 Concatenate(u64),
138 Constant(Type, Value),
139 A2B,
140 B2A(ScalarType),
141 CreateTuple,
142 CreateNamedTuple(Vec<String>),
143 CreateVector(Type),
144 TupleGet(u64),
145 NamedTupleGet(String),
146 VectorGet,
147 Zip,
148 Repeat(u64),
149 Call,
150 Iterate,
151 ArrayToVector,
152 VectorToArray,
153 // Operations that can't be compiled to MPC protocols
154 RandomPermutation(u64),
155 Gather(u64),
156 CuckooHash,
157 InversePermutation,
158 CuckooToPermutation,
159 DecomposeSwitchingMap(u64),
160 SegmentCumSum,
161 Shard(ShardConfig),
162 // SQL joins
163 Join(JoinType, HashMap<String, String>),
164 JoinWithColumnMasks(JoinType, HashMap<String, String>),
165 ApplyPermutation(bool),
166 Sort(String),
167 Custom(CustomOperation),
168 // Operations used for debugging graphs.
169 Print(String),
170 Assert(String),
171}
172
173impl fmt::Display for Operation {
174 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
175 let operation_name = if let Operation::Custom(custom_op) = self {
176 custom_op.get_name()
177 } else {
178 let operation_w_type_str = format!("{:?}", *self);
179 let split_for_operation = operation_w_type_str.split('(');
180 let vec_operation_and_types: Vec<&str> = split_for_operation.collect();
181 if vec_operation_and_types.is_empty() {
182 "-null-".to_owned()
183 } else {
184 vec_operation_and_types[0].to_owned()
185 }
186 };
187 write!(f, "{operation_name}")
188 }
189}
190
191impl Operation {
192 pub fn is_prf_operation(&self) -> bool {
193 matches!(
194 self,
195 Operation::PRF(_, _) | Operation::PermutationFromPRF(_, _)
196 )
197 }
198
199 pub fn is_broadcasting_called(&self) -> bool {
200 matches!(
201 self,
202 Operation::Add
203 | Operation::Subtract
204 | Operation::Multiply
205 | Operation::Matmul
206 | Operation::Gemm(_, _)
207 | Operation::MixedMultiply
208 | Operation::Stack(_)
209 )
210 }
211
212 pub fn is_mpc_compiled(&self) -> bool {
213 matches!(
214 self,
215 Operation::Input(_)
216 | Operation::Zeros(_)
217 | Operation::Ones(_)
218 | Operation::Add
219 | Operation::Subtract
220 | Operation::Multiply
221 | Operation::MixedMultiply
222 | Operation::Dot
223 | Operation::Matmul
224 | Operation::Gemm(_, _)
225 | Operation::Truncate(_)
226 | Operation::Sum(_)
227 | Operation::CumSum(_)
228 | Operation::PermuteAxes(_)
229 | Operation::Get(_)
230 | Operation::GetSlice(_)
231 | Operation::Reshape(_)
232 | Operation::Stack(_)
233 | Operation::Concatenate(_)
234 | Operation::Constant(_, _)
235 | Operation::A2B
236 | Operation::B2A(_)
237 | Operation::CreateTuple
238 | Operation::CreateNamedTuple(_)
239 | Operation::CreateVector(_)
240 | Operation::TupleGet(_)
241 | Operation::NamedTupleGet(_)
242 | Operation::VectorGet
243 | Operation::Zip
244 | Operation::Repeat(_)
245 | Operation::ArrayToVector
246 | Operation::VectorToArray
247 | Operation::Join(_, _)
248 | Operation::JoinWithColumnMasks(_, _)
249 | Operation::ApplyPermutation(_)
250 | Operation::Sort(_)
251 )
252 }
253
254 pub fn update_prf_id(&self, prf_id: u64) -> Result<Self> {
255 match self {
256 Operation::PRF(_, scalar_type) => Ok(Operation::PRF(prf_id, scalar_type.clone())),
257 Operation::PermutationFromPRF(_, size) => {
258 Ok(Operation::PermutationFromPRF(prf_id, *size))
259 }
260 _ => Err(runtime_error!("Operation is not a PRF operation")),
261 }
262 }
263
264 pub fn is_input(&self) -> bool {
265 matches!(self, Operation::Input(_))
266 }
267
268 pub fn is_const_optimizable(&self) -> Result<bool> {
269 match self {
270 // Zeros and Ones exist precisely because we don't want to store them as Constants
271 // to keep the graph size small.
272 Operation::Zeros(_) | Operation::Ones(_) => Ok(false),
273 op => Ok(!op.is_input() && !op.is_randomizing()?),
274 }
275 }
276
277 // If an operation computes a randomized output, return true
278 pub fn is_randomizing(&self) -> Result<bool> {
279 match self {
280 Operation::Random(_)
281 | Operation::RandomPermutation(_)
282 | Operation::CuckooToPermutation
283 | Operation::DecomposeSwitchingMap(_) => Ok(true),
284 Operation::Input(_)
285 | Operation::Zeros(_)
286 | Operation::Ones(_)
287 | Operation::A2B
288 | Operation::Add
289 | Operation::ApplyPermutation(_)
290 | Operation::ArrayToVector
291 | Operation::Assert(_)
292 | Operation::B2A(_)
293 | Operation::Subtract
294 | Operation::Multiply
295 | Operation::MixedMultiply
296 | Operation::Matmul
297 | Operation::Dot
298 | Operation::Gemm(_, _)
299 | Operation::Truncate(_)
300 | Operation::Sum(_)
301 | Operation::CumSum(_)
302 | Operation::Concatenate(_)
303 | Operation::CreateNamedTuple(_)
304 | Operation::CreateTuple
305 | Operation::CreateVector(_)
306 | Operation::CuckooHash
307 | Operation::SegmentCumSum
308 | Operation::PermuteAxes(_)
309 | Operation::Get(_)
310 | Operation::Gather(_)
311 | Operation::GetSlice(_)
312 | Operation::Reshape(_)
313 | Operation::NOP
314 | Operation::InversePermutation
315 | Operation::PRF(_, _)
316 | Operation::PermutationFromPRF(_, _)
317 | Operation::Stack(_)
318 | Operation::NamedTupleGet(_)
319 | Operation::Sort(_)
320 | Operation::TupleGet(_)
321 | Operation::Constant(_, _)
322 | Operation::VectorGet
323 | Operation::Zip
324 | Operation::Repeat(_)
325 | Operation::VectorToArray
326 | Operation::Join(_, _)
327 | Operation::JoinWithColumnMasks(_, _)
328 | Operation::Print(_)
329 | Operation::Shard(_) => Ok(false),
330 Operation::Call | Operation::Iterate => Err(runtime_error!(
331 "The status of operations calling other graphs cannot be defined"
332 )),
333 Operation::Custom(_) => Err(runtime_error!(
334 "The status of custom operations cannot be defined"
335 )),
336 }
337 }
338}
339
340struct NodeBody {
341 graph: WeakGraph,
342 node_dependencies: Vec<WeakNode>,
343 graph_dependencies: Vec<WeakGraph>,
344 operation: Operation,
345 id: u64,
346}
347
348#[derive(Serialize, Deserialize)]
349struct SerializableNodeBody {
350 node_dependencies: Vec<u64>,
351 graph_dependencies: Vec<u64>,
352 operation: Operation,
353}
354
355type NodeBodyPointer = Arc<AtomicRefCell<NodeBody>>;
356
357/// A structure that stores a pointer to a computation graph node that corresponds to an operation.
358///
359/// [Clone] trait duplicates the pointer, not the underlying nodes.
360///
361/// [PartialEq] trait compares pointers, not the related nodes.
362///
363/// # Example
364///
365/// ```
366/// # use ciphercore_base::graphs::create_context;
367/// # use ciphercore_base::data_types::{scalar_type, BIT};
368/// let c = create_context().unwrap();
369/// let g = c.create_graph().unwrap();
370/// let t = scalar_type(BIT);
371/// let n1 = g.input(t.clone()).unwrap();
372/// let n2 = g.input(t).unwrap();
373/// assert!(n1 != n2);
374/// let n3 = n1.clone();
375/// assert!(n1 == n3);
376/// ```
377#[cfg_attr(feature = "py-binding", struct_wrapper)]
378pub struct Node {
379 body: NodeBodyPointer,
380}
381
382type SerializableNode = Arc<SerializableNodeBody>;
383
384impl Clone for Node {
385 /// Returns a new [Node] value with a copy of the pointer to a node.
386 fn clone(&self) -> Self {
387 Node {
388 body: self.body.clone(),
389 }
390 }
391}
392
393impl PartialEq for Node {
394 /// Tests whether `self` and `other` nodes are equal via comparison of their respective pointers.
395 ///
396 /// # Arguments
397 ///
398 /// `other` - another [Node] value
399 ///
400 /// # Returns
401 ///
402 /// `true` if `self` and `other` are equal, `false` otherwise
403 fn eq(&self, other: &Self) -> bool {
404 Arc::ptr_eq(&self.body, &other.body)
405 }
406}
407
408impl Eq for Node {}
409
410impl Hash for Node {
411 /// Hashes the node pointer.
412 ///
413 /// # Arguments
414 ///
415 /// `state` - state of a hash function that is changed after hashing the node
416 fn hash<H: Hasher>(&self, state: &mut H) {
417 ptr::hash(&*self.body, state);
418 }
419}
420
421/// Public methods which supposed to be imported in Python.
422#[cfg_attr(feature = "py-binding", impl_wrapper)]
423impl Node {
424 /// Returns the parent graph that contains the node.
425 ///
426 /// # Returns
427 ///
428 /// Parent graph of the node
429 pub fn get_graph(&self) -> Graph {
430 self.body.borrow().graph.upgrade()
431 }
432
433 /// Returns the dependency nodes that are used to compute the value in the current node.
434 ///
435 /// # Returns
436 ///
437 /// Vector of nodes used by the node to perform its operation
438 pub fn get_node_dependencies(&self) -> Vec<Node> {
439 self.body
440 .borrow()
441 .node_dependencies
442 .iter()
443 .map(|n| n.upgrade())
444 .collect()
445 }
446
447 /// Returns the dependency graphs that are used to compute the value in the current node.
448 ///
449 /// These dependencies are non-empty only for `Call` and `Iterate` operations.
450 ///
451 /// # Returns
452 ///
453 /// Vector of graphs used by the node to perform its operation
454 pub fn get_graph_dependencies(&self) -> Vec<Graph> {
455 self.body
456 .borrow()
457 .graph_dependencies
458 .iter()
459 .map(|g| g.upgrade())
460 .collect()
461 }
462
463 /// Returns the ID of the node.
464 ///
465 /// A node ID is a serial number of a node between `0` and `n-1` where `n` is the number of nodes in the parent graph.
466 /// This number is equal to the number of nodes in the parent graph before this node was added to it.
467 ///
468 /// # Returns
469 ///
470 /// Node ID
471 pub fn get_id(&self) -> u64 {
472 self.body.borrow().id
473 }
474
475 /// Returns the pair of the parent graph ID and node ID
476 ///
477 /// # Returns
478 ///
479 /// (Graph ID, Node ID)
480 pub fn get_global_id(&self) -> (u64, u64) {
481 (self.get_graph().get_id(), self.get_id())
482 }
483
484 /// Returns the operation associated with the node.
485 ///
486 /// # Returns
487 ///
488 /// Operation associated with the node
489 pub fn get_operation(&self) -> Operation {
490 self.body.borrow().operation.clone()
491 }
492
493 /// Returns the type of the value computed by the node.
494 ///
495 /// # Returns
496 ///
497 /// Output type of the node operation
498 pub fn get_type(&self) -> Result<Type> {
499 let context = self.get_graph().get_context();
500
501 {
502 let context_body = context.body.borrow();
503 if let Some(tc) = &context_body.type_checker {
504 if let Some(cached_type) = tc.cached_node_type(self)? {
505 return Ok(cached_type);
506 }
507 }
508 }
509
510 let mut context_body = context.body.borrow_mut();
511 if let Some(tc) = &mut context_body.type_checker {
512 tc.process_node(self.clone())
513 } else {
514 Err(runtime_error!("Type checker is not available"))
515 }
516 }
517 /// Applies [Context::set_node_name] to the parent context and `this` node. Returns the clone of `this`.
518 ///
519 /// # Example
520 ///
521 /// ```
522 /// # use ciphercore_base::graphs::create_context;
523 /// # use ciphercore_base::data_types::{scalar_type, BIT};
524 /// let c = create_context().unwrap();
525 /// let g = c.create_graph().unwrap();
526 /// let t = scalar_type(BIT);
527 /// let n = g.input(t).unwrap();
528 /// n.set_name("XOR").unwrap();
529 /// ```
530 pub fn set_name(&self, name: &str) -> Result<Node> {
531 self.get_graph()
532 .get_context()
533 .set_node_name(self.clone(), name)?;
534 Ok(self.clone())
535 }
536
537 /// Applies [Context::get_node_name] to the parent context and `this` node.
538 ///
539 /// # Example
540 ///
541 /// ```
542 /// # use ciphercore_base::graphs::create_context;
543 /// # use ciphercore_base::data_types::{scalar_type, BIT};
544 /// let c = create_context().unwrap();
545 /// let g = c.create_graph().unwrap();
546 /// let t = scalar_type(BIT);
547 /// let n = g.input(t).unwrap();
548 /// n.set_name("XOR").unwrap();
549 /// assert_eq!(n.get_name().unwrap(), Some("XOR".to_owned()));
550 /// ```
551 pub fn get_name(&self) -> Result<Option<String>> {
552 self.get_graph().get_context().get_node_name(self.clone())
553 }
554
555 /// Adds a node to the parent graph that adds elementwise the array or scalar associated with the node to an array or scalar of the same scalar type associated with another node.
556 ///
557 /// Applies [Graph::add] to the parent graph, `this` node and the `b` node.
558 ///
559 /// # Example
560 ///
561 /// ```
562 /// # use ciphercore_base::graphs::create_context;
563 /// # use ciphercore_base::data_types::{BIT, scalar_type};
564 /// let c = create_context().unwrap();
565 /// let g = c.create_graph().unwrap();
566 /// let t = scalar_type(BIT);
567 /// let n1 = g.input(t.clone()).unwrap();
568 /// let n2 = g.input(t).unwrap();
569 /// let n3 = n1.add(n2).unwrap();
570 /// ```
571 pub fn add(&self, b: Node) -> Result<Node> {
572 self.get_graph().add(self.clone(), b)
573 }
574
575 /// Adds a node to the parent graph that subtracts elementwise the array or scalar of the same scalar type associated with another node from an array or scalar associated with the node.
576 ///
577 /// Applies [Graph::subtract] to the parent graph, `this` node and the `b` node.
578 ///
579 /// # Example
580 ///
581 /// ```
582 /// # use ciphercore_base::graphs::create_context;
583 /// # use ciphercore_base::data_types::{BIT, scalar_type};
584 /// let c = create_context().unwrap();
585 /// let g = c.create_graph().unwrap();
586 /// let t = scalar_type(BIT);
587 /// let n1 = g.input(t.clone()).unwrap();
588 /// let n2 = g.input(t).unwrap();
589 /// let n3 = n1.subtract(n2).unwrap();
590 /// ```
591 pub fn subtract(&self, b: Node) -> Result<Node> {
592 self.get_graph().subtract(self.clone(), b)
593 }
594
595 /// Adds a node to the parent graph that multiplies elementwise the array or scalar associated with the node by an array or scalar of the same scalar type associated with another node.
596 ///
597 /// Applies [Graph::multiply] to the parent graph, `this` node and the `b` node.
598 ///
599 /// # Example
600 ///
601 /// ```
602 /// # use ciphercore_base::graphs::create_context;
603 /// # use ciphercore_base::data_types::{BIT, scalar_type};
604 /// let c = create_context().unwrap();
605 /// let g = c.create_graph().unwrap();
606 /// let t = scalar_type(BIT);
607 /// let n1 = g.input(t.clone()).unwrap();
608 /// let n2 = g.input(t).unwrap();
609 /// let n3 = n1.multiply(n2).unwrap();
610 /// ```
611 pub fn multiply(&self, b: Node) -> Result<Node> {
612 self.get_graph().multiply(self.clone(), b)
613 }
614
615 /// Adds a node to the parent graph that multiplies elementwise the array or scalar associated with the node by a binary array or scalar associated with another node.
616 ///
617 /// Applies [Graph::mixed_multiply] to the parent graph, `this` node and the `b` node.
618 ///
619 /// # Example
620 ///
621 /// ```
622 /// # use ciphercore_base::graphs::create_context;
623 /// # use ciphercore_base::data_types::{BIT, INT32, scalar_type};
624 /// let c = create_context().unwrap();
625 /// let g = c.create_graph().unwrap();
626 /// let t = scalar_type(INT32);
627 /// let bit_t = scalar_type(BIT);
628 /// let n1 = g.input(t).unwrap();
629 /// let n2 = g.input(bit_t).unwrap();
630 /// let n3 = n1.mixed_multiply(n2).unwrap();
631 /// ```
632 pub fn mixed_multiply(&self, b: Node) -> Result<Node> {
633 self.get_graph().mixed_multiply(self.clone(), b)
634 }
635
636 /// Adds a node to the parent graph that computes the dot product of arrays or scalars associated with the node and another node.
637 ///
638 /// Applies [Graph::dot] to the parent graph, `this` node and the `b` node.
639 ///
640 /// # Example
641 ///
642 /// ```
643 /// # use ciphercore_base::graphs::create_context;
644 /// # use ciphercore_base::data_types::{INT32, array_type};
645 /// let c = create_context().unwrap();
646 /// let g = c.create_graph().unwrap();
647 /// let t = array_type(vec![10], INT32);
648 /// let n1 = g.input(t.clone()).unwrap();
649 /// let n2 = g.input(t).unwrap();
650 /// let n3 = n1.dot(n2).unwrap();
651 /// ```
652 pub fn dot(&self, b: Node) -> Result<Node> {
653 self.get_graph().dot(self.clone(), b)
654 }
655
656 /// Adds a node to the parent graph that computes the matrix product of two arrays associated with the node and another node.
657 ///
658 /// Applies [Graph::matmul] to the parent graph, `this` node and the `b` node.
659 ///
660 /// # Example
661 ///
662 /// ```
663 /// # use ciphercore_base::graphs::create_context;
664 /// # use ciphercore_base::data_types::{INT32, array_type};
665 /// let c = create_context().unwrap();
666 /// let g = c.create_graph().unwrap();
667 /// let t1 = array_type(vec![2, 3], INT32);
668 /// let t2 = array_type(vec![3, 2], INT32);
669 /// let n1 = g.input(t1).unwrap();
670 /// let n2 = g.input(t2).unwrap();
671 /// let n3 = n1.matmul(n2).unwrap();
672 /// ```
673 pub fn matmul(&self, b: Node) -> Result<Node> {
674 self.get_graph().matmul(self.clone(), b)
675 }
676
677 /// Adds a node to the parent graph that computes the generatl matrix product of two arrays associated with the node and another node.
678 ///
679 /// Applies [Graph::gemm] to the parent graph, `this` node and the `b` node.
680 ///
681 /// # Example
682 ///
683 /// ```
684 /// # use ciphercore_base::graphs::create_context;
685 /// # use ciphercore_base::data_types::{INT32, array_type};
686 /// let c = create_context().unwrap();
687 /// let g = c.create_graph().unwrap();
688 /// let t1 = array_type(vec![2, 3], INT32);
689 /// let t2 = array_type(vec![2, 3], INT32);
690 /// let n1 = g.input(t1).unwrap();
691 /// let n2 = g.input(t2).unwrap();
692 /// let n3 = n1.gemm(n2, false, true).unwrap();
693 /// ```
694 #[doc(hidden)]
695 pub fn gemm(&self, b: Node, transpose_a: bool, transpose_b: bool) -> Result<Node> {
696 self.get_graph()
697 .gemm(self.clone(), b, transpose_a, transpose_b)
698 }
699
700 /// Adds a node that computes a join of a given type on two named tuples along given key headers.
701 /// More detailed documentation can be found in [Graph::join].
702 ///
703 /// Applies [Graph::join] to the parent graph, `this` node and the `b` node.
704 ///
705 /// # Example
706 ///
707 /// ```
708 /// # use ciphercore_base::graphs::{create_context, JoinType};
709 /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type};
710 /// # use ciphercore_base::type_inference::NULL_HEADER;
711 /// # use std::collections::HashMap;
712 /// let c = create_context().unwrap();
713 /// let g = c.create_graph().unwrap();
714 /// let t1n = array_type(vec![100], BIT);
715 /// let t11 = array_type(vec![100], INT32);
716 /// let t12 = array_type(vec![100, 128], BIT);
717 /// let t13 = array_type(vec![100], INT64);
718 /// let t2n = array_type(vec![50], BIT);
719 /// let t21 = array_type(vec![50], INT32);
720 /// let t22 = array_type(vec![50, 128], BIT);
721 /// let t23 = array_type(vec![50], UINT8);
722 /// let t1 = named_tuple_type(vec![
723 /// (NULL_HEADER.to_owned(), t1n),
724 /// ("ID".to_owned(), t11),
725 /// ("Occupation".to_owned(), t12),
726 /// ("Revenue".to_owned(), t13),
727 /// ]);
728 /// let t2 = named_tuple_type(vec![
729 /// (NULL_HEADER.to_owned(), t2n),
730 /// ("ID".to_owned(), t21),
731 /// ("Job".to_owned(), t22),
732 /// ("Age".to_owned(), t23),
733 /// ]);
734 /// let n1 = g.input(t1).unwrap();
735 /// let n2 = g.input(t2).unwrap();
736 /// let n3 = n1.join(n2, JoinType::Inner, HashMap::from([
737 /// ("ID".to_owned(), "ID".to_owned()),
738 /// ("Occupation".to_owned(), "Job".to_owned()),
739 /// ])).unwrap();
740 /// ```
741 pub fn join(&self, b: Node, t: JoinType, headers: HashMap<String, String>) -> Result<Node> {
742 self.get_graph().join(self.clone(), b, t, headers)
743 }
744
745 /// Adds a node that computes a join of a given type on two named tuples along given key headers.
746 /// More detailed documentation can be found in [Graph::join_with_column_masks].
747 ///
748 /// Applies [Graph::join_with_column_masks] to the parent graph, `this` node and the `b` node.
749 ///
750 /// # Example
751 ///
752 /// ```
753 /// # use ciphercore_base::graphs::{create_context, JoinType};
754 /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type, tuple_type};
755 /// # use ciphercore_base::type_inference::NULL_HEADER;
756 /// # use std::collections::HashMap;
757 /// let c = create_context().unwrap();
758 /// let g = c.create_graph().unwrap();
759 /// let t1n = array_type(vec![100], BIT);
760 /// let t11 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT32)]);
761 /// let t12 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100, 128], BIT)]);
762 /// let t13 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT64)]);
763 /// let t2n = array_type(vec![50], BIT);
764 /// let t21 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], INT32)]);
765 /// let t22 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50, 128], BIT)]);
766 /// let t23 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], UINT8)]);
767 /// let t1 = named_tuple_type(vec![
768 /// (NULL_HEADER.to_owned(), t1n),
769 /// ("ID".to_owned(), t11),
770 /// ("Occupation".to_owned(), t12),
771 /// ("Revenue".to_owned(), t13),
772 /// ]);
773 /// let t2 = named_tuple_type(vec![
774 /// (NULL_HEADER.to_owned(), t2n),
775 /// ("ID".to_owned(), t21),
776 /// ("Job".to_owned(), t22),
777 /// ("Age".to_owned(), t23),
778 /// ]);
779 /// let n1 = g.input(t1).unwrap();
780 /// let n2 = g.input(t2).unwrap();
781 /// let n3 = n1.join_with_column_masks(n2, JoinType::Inner, HashMap::from([
782 /// ("ID".to_owned(), "ID".to_owned()),
783 /// ("Occupation".to_owned(), "Job".to_owned()),
784 /// ])).unwrap();
785 /// ```
786 pub fn join_with_column_masks(
787 &self,
788 b: Node,
789 t: JoinType,
790 headers: HashMap<String, String>,
791 ) -> Result<Node> {
792 self.get_graph()
793 .join_with_column_masks(self.clone(), b, t, headers)
794 }
795
796 /// Adds a node that applies a permutation to the array along the first dimension.
797 ///
798 /// # Arguments
799 ///
800 /// * `p` - node containing a permutation.
801 ///
802 /// # Returns
803 ///
804 /// New permuted node
805 ///
806 /// # Example
807 ///
808 /// ```
809 /// # use ciphercore_base::graphs::create_context;
810 /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
811 /// let c = create_context().unwrap();
812 /// let g = c.create_graph().unwrap();
813 /// let t = array_type(vec![25, 3], INT32);
814 /// let a = g.input(t).unwrap();
815 /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
816 /// let a = a.apply_permutation(p).unwrap();
817 /// ```
818 #[doc(hidden)]
819 pub fn apply_permutation(&self, p: Node) -> Result<Node> {
820 self.get_graph().apply_permutation(self.clone(), p)
821 }
822
823 /// Adds a node that applies an inverse permutation to the array along the first dimension.
824 ///
825 /// # Arguments
826 ///
827 /// * `p` - node containing a permutation.
828 ///
829 /// # Returns
830 ///
831 /// New permuted node
832 ///
833 /// # Example
834 ///
835 /// ```
836 /// # use ciphercore_base::graphs::create_context;
837 /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
838 /// let c = create_context().unwrap();
839 /// let g = c.create_graph().unwrap();
840 /// let t = array_type(vec![25, 3], INT32);
841 /// let a = g.input(t).unwrap();
842 /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
843 /// let a = a.apply_inverse_permutation(p).unwrap();
844 /// ```
845 #[doc(hidden)]
846 pub fn apply_inverse_permutation(&self, p: Node) -> Result<Node> {
847 self.get_graph().apply_inverse_permutation(self.clone(), p)
848 }
849
850 /// Adds a node that sorts a table given as named tuple according to the column given by the key argument.
851 /// The key column must be a 2-d BIT array of shape [n, b], interpreted as bitstrings of length b.
852 /// Other columns in the named tuple must be arrays of arbitrary type and shape, as long as they
853 /// share the first dimension: [n, ...].
854 /// Bitstrings are sorted lexicographically, and the sorting algorithm is stable: preserving relative
855 /// order of entries in other arrays where the corresponding key entries match.
856 ///
857 /// # Arguments
858 /// * `key` - name of the field to sort on it, this array must be 2-d of type BIT.
859 ///
860 /// # Returns
861 ///
862 /// New sorted node
863 ///
864 /// # Example
865 ///
866 /// ```
867 /// # use ciphercore_base::graphs::create_context;
868 /// # use ciphercore_base::data_types::{BIT, INT32, UINT64, array_type, named_tuple_type};
869 /// let c = create_context().unwrap();
870 /// let g = c.create_graph().unwrap();
871 /// let v1 = g.input(array_type(vec![20], INT32)).unwrap();
872 /// let v2 = g.input(array_type(vec![20, 10, 2], UINT64)).unwrap();
873 /// let k = g.input(array_type(vec![20, 32], BIT)).unwrap();
874 /// let a = g.create_named_tuple(vec![("key".to_string(), k), ("value1".to_string(), v1), ("value2".to_string(), v2)]).unwrap();
875 /// let a = a.sort("key".to_string()).unwrap();
876 /// ```
877 pub fn sort(&self, key: String) -> Result<Node> {
878 self.get_graph().sort(self.clone(), key)
879 }
880
881 /// Adds a node to the parent graph that divides a scalar or each entry of the array associated with the node by a positive constant integer `scale`.
882 ///
883 /// Applies [Graph::add] to the parent graph, `this` node and `scale`.
884 ///
885 /// # Example
886 ///
887 /// ```
888 /// # use ciphercore_base::graphs::create_context;
889 /// # use ciphercore_base::data_types::{INT32, array_type};
890 /// let c = create_context().unwrap();
891 /// let g = c.create_graph().unwrap();
892 /// let t = array_type(vec![2, 3], INT32);
893 /// let n1 = g.input(t).unwrap();
894 /// let n2 = n1.truncate(4).unwrap();
895 /// ```
896 pub fn truncate(&self, scale: u128) -> Result<Node> {
897 self.get_graph().truncate(self.clone(), scale)
898 }
899
900 /// Adds a node to the parent graph that computes the sum of entries of the array associated with the node along given axes.
901 ///
902 /// Applies [Graph::sum] to the parent graph, `this` node and `axes`.
903 ///
904 /// # Example
905 ///
906 /// ```
907 /// # use ciphercore_base::graphs::create_context;
908 /// # use ciphercore_base::data_types::{INT32, array_type};
909 /// let c = create_context().unwrap();
910 /// let g = c.create_graph().unwrap();
911 /// let t = array_type(vec![3, 2, 3], INT32);
912 /// let axes = vec![1, 0];
913 /// let n1 = g.input(t).unwrap();
914 /// let n2 = n1.sum(axes).unwrap();
915 /// ```
916 pub fn sum(&self, axes: ArrayShape) -> Result<Node> {
917 self.get_graph().sum(self.clone(), axes)
918 }
919
920 /// Adds a node to the parent graph that computes the cumulative sum of elements along a given axis.
921 ///
922 /// Applies [Graph::cum_sum] to the parent graph, `this` node and `axis`.
923 ///
924 /// # Example
925 ///
926 /// ```
927 /// # use ciphercore_base::graphs::create_context;
928 /// # use ciphercore_base::data_types::{INT32, array_type};
929 /// let c = create_context().unwrap();
930 /// let g = c.create_graph().unwrap();
931 /// let t = array_type(vec![3, 2], INT32);
932 /// let n1 = g.input(t).unwrap();
933 /// let n2 = n1.cum_sum(1).unwrap();
934 /// ```
935 pub fn cum_sum(&self, axis: u64) -> Result<Node> {
936 self.get_graph().cum_sum(self.clone(), axis)
937 }
938
939 /// Adds a node to the parent graph that permutes the array associated with the node along given axes.
940 ///
941 /// Applies [Graph::permute_axes] to the parent graph, `this` node and `axes`.
942 ///
943 /// # Example
944 ///
945 /// ```
946 /// # use ciphercore_base::graphs::create_context;
947 /// # use ciphercore_base::data_types::{INT32, array_type};
948 /// let c = create_context().unwrap();
949 /// let g = c.create_graph().unwrap();
950 /// let t = array_type(vec![3, 2, 3], INT32);
951 /// let axes = vec![1, 0, 2];
952 /// let n1 = g.input(t).unwrap();
953 /// let n2 = n1.permute_axes(axes).unwrap();
954 /// ```
955 pub fn permute_axes(&self, axes: ArrayShape) -> Result<Node> {
956 self.get_graph().permute_axes(self.clone(), axes)
957 }
958
959 /// Adds a node to the parent graph that inverts a given permutation.
960 ///
961 /// Applies [Graph::inverse_permutation] to the parent graph and `this` node.
962 #[doc(hidden)]
963 pub fn inverse_permutation(&self) -> Result<Node> {
964 self.get_graph().inverse_permutation(self.clone())
965 }
966
967 /// Adds a node to the parent graph that extracts a sub-array with a given index from the array associated with the node.
968 ///
969 /// Applies [Graph::get] to the parent graph, `this` node and `index`.
970 ///
971 /// # Example
972 ///
973 /// ```
974 /// # use ciphercore_base::graphs::create_context;
975 /// # use ciphercore_base::data_types::{INT32, array_type};
976 /// let c = create_context().unwrap();
977 /// let g = c.create_graph().unwrap();
978 /// let t = array_type(vec![3, 2, 3], INT32);
979 /// let index = vec![2];
980 /// let n1 = g.input(t).unwrap();
981 /// let n2 = n1.get(index).unwrap();
982 /// ```
983 pub fn get(&self, index: ArrayShape) -> Result<Node> {
984 self.get_graph().get(self.clone(), index)
985 }
986
987 /// Adds a node that extracts a sub-array corresponding to a given slice from the array associated with the node.
988 ///
989 /// Applies [Graph::get_slice] to the parent graph, `this` node and `slice`.
990 ///
991 /// # Example
992 ///
993 /// ```
994 /// # use ciphercore_base::graphs::{create_context, SliceElement};
995 /// # use ciphercore_base::data_types::{INT32, array_type};
996 /// let c = create_context().unwrap();
997 /// let g = c.create_graph().unwrap();
998 /// let t = array_type(vec![3, 2, 3], INT32);
999 /// let slice = vec![SliceElement::Ellipsis, SliceElement::SubArray(None, None, Some(-2))];
1000 /// let n1 = g.input(t).unwrap();
1001 /// let n2 = n1.get_slice(slice).unwrap();
1002 /// ```
1003 pub fn get_slice(&self, slice: Slice) -> Result<Node> {
1004 self.get_graph().get_slice(self.clone(), slice)
1005 }
1006
1007 /// Adds a node to the parent graph that reshapes a value associated with the node to a given compatible type.
1008 ///
1009 /// Applies [Graph::reshape] to the parent graph, `this` node and `new_type`.
1010 ///
1011 /// # Example
1012 ///
1013 /// ```
1014 /// # use ciphercore_base::graphs::create_context;
1015 /// # use ciphercore_base::data_types::{INT32, array_type};
1016 /// let c = create_context().unwrap();
1017 /// let g = c.create_graph().unwrap();
1018 /// let old_t = array_type(vec![3, 2, 3], INT32);
1019 /// let new_t = array_type(vec![3,6], INT32);
1020 /// let n1 = g.input(old_t).unwrap();
1021 /// let n2 = n1.reshape(new_t).unwrap();
1022 /// ```
1023 pub fn reshape(&self, new_type: Type) -> Result<Node> {
1024 self.get_graph().reshape(self.clone(), new_type)
1025 }
1026
1027 #[doc(hidden)]
1028 pub fn nop(&self) -> Result<Node> {
1029 self.get_graph().nop(self.clone())
1030 }
1031
1032 #[doc(hidden)]
1033 pub fn prf(&self, iv: u64, output_type: Type) -> Result<Node> {
1034 self.get_graph().prf(self.clone(), iv, output_type)
1035 }
1036
1037 #[doc(hidden)]
1038 pub fn permutation_from_prf(&self, iv: u64, n: u64) -> Result<Node> {
1039 self.get_graph().permutation_from_prf(self.clone(), iv, n)
1040 }
1041
1042 /// Adds a node to the parent graph converting an integer array or scalar associated with the node to the binary form.
1043 ///
1044 /// Applies [Graph::a2b] to the parent graph and `this` node.
1045 ///
1046 /// # Example
1047 ///
1048 /// ```
1049 /// # use ciphercore_base::graphs::create_context;
1050 /// # use ciphercore_base::data_types::{array_type, INT32};
1051 /// let c = create_context().unwrap();
1052 /// let g = c.create_graph().unwrap();
1053 /// let t = array_type(vec![3, 2], INT32);
1054 /// let n1 = g.input(t).unwrap();
1055 /// let n2 = n1.a2b().unwrap();
1056 /// ```
1057 pub fn a2b(&self) -> Result<Node> {
1058 self.get_graph().a2b(self.clone())
1059 }
1060
1061 /// Adds a node to the parent graph converting a binary array associated with the node to an array of a given scalar type.
1062 ///
1063 /// Applies [Graph::b2a] to the parent graph, `this` node and `scalar_type`.
1064 ///
1065 /// # Example
1066 ///
1067 /// ```
1068 /// # use ciphercore_base::graphs::create_context;
1069 /// # use ciphercore_base::data_types::{BIT, INT32, array_type};
1070 /// let c = create_context().unwrap();
1071 /// let g = c.create_graph().unwrap();
1072 /// let t = array_type(vec![3, 32], BIT);
1073 /// let n1 = g.input(t).unwrap();
1074 /// let n2 = n1.b2a(INT32).unwrap();
1075 /// ```
1076 pub fn b2a(&self, scalar_type: ScalarType) -> Result<Node> {
1077 self.get_graph().b2a(self.clone(), scalar_type)
1078 }
1079
1080 /// Adds a node that extracts an element of a tuple associated with the node.
1081 ///
1082 /// Applies [Graph::tuple_get] to the parent graph, `this` node and `index`.
1083 ///
1084 /// # Example
1085 ///
1086 /// ```
1087 /// # use ciphercore_base::data_types::{INT32, array_type};
1088 /// # use ciphercore_base::graphs::create_context;
1089 /// let c = create_context().unwrap();
1090 /// let g = c.create_graph().unwrap();
1091 /// let t1 = array_type(vec![3, 2, 3], INT32);
1092 /// let t2 = array_type(vec![2, 3], INT32);
1093 /// let n1 = g.input(t1).unwrap();
1094 /// let n2 = g.input(t2).unwrap();
1095 /// let n3 = g.create_tuple(vec![n1, n2]).unwrap();
1096 /// let n4 = n3.tuple_get(1).unwrap();
1097 /// ```
1098 pub fn tuple_get(&self, index: u64) -> Result<Node> {
1099 self.get_graph().tuple_get(self.clone(), index)
1100 }
1101
1102 /// Adds a node to the parent graph that extracts an element of a named tuple associated with the node.
1103 ///
1104 /// Applies [Graph::named_tuple_get] to the parent graph, `this` node and the `key` string.
1105 ///
1106 /// # Example
1107 ///
1108 /// ```
1109 /// # use ciphercore_base::graphs::create_context;
1110 /// # use ciphercore_base::data_types::{array_type, INT32};
1111 /// let c = create_context().unwrap();
1112 /// let g = c.create_graph().unwrap();
1113 /// let t1 = array_type(vec![3, 2, 3], INT32);
1114 /// let t2 = array_type(vec![2, 3], INT32);
1115 /// let n1 = g.input(t1).unwrap();
1116 /// let n2 = g.input(t2).unwrap();
1117 /// let n3 = g.create_named_tuple(vec![("node1".to_owned(), n1), ("node2".to_owned(), n2)]).unwrap();
1118 /// let n4 = n3.named_tuple_get("node2".to_owned()).unwrap();
1119 /// ```
1120 pub fn named_tuple_get(&self, key: String) -> Result<Node> {
1121 self.get_graph().named_tuple_get(self.clone(), key)
1122 }
1123
1124 /// Adds a node to the parent graph that extracts an element of a vector associated with the node.
1125 ///
1126 /// Applies [Graph::vector_get] to the parent graph, `this` node and the `index` node.
1127 ///
1128 /// # Example
1129 ///
1130 /// ```
1131 /// # use ciphercore_base::graphs::create_context;
1132 /// # use ciphercore_base::data_types::{UINT32, INT32, array_type, scalar_type};
1133 /// # use ciphercore_base::data_values::Value;
1134 /// let c = create_context().unwrap();
1135 /// let g = c.create_graph().unwrap();
1136 /// let t = array_type(vec![3, 2, 3], INT32);
1137 /// let n1 = g.input(t.clone()).unwrap();
1138 /// let n2 = g.input(t.clone()).unwrap();
1139 /// let n3 = g.create_vector(t, vec![n1,n2]).unwrap();
1140 /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
1141 /// let n4 = n3.vector_get(index).unwrap();
1142 /// ```
1143 pub fn vector_get(&self, index: Node) -> Result<Node> {
1144 self.get_graph().vector_get(self.clone(), index)
1145 }
1146
1147 /// Adds a node to the parent graph converting an array associated with the node to a vector.
1148 ///
1149 /// Applies [Graph::array_to_vector] to the parent graph and `this` node.
1150 ///
1151 /// # Example
1152 ///
1153 /// ```
1154 /// # use ciphercore_base::graphs::create_context;
1155 /// # use ciphercore_base::data_types::{array_type, scalar_type, INT32, UINT32};
1156 /// # use ciphercore_base::data_values::Value;
1157 /// let c = create_context().unwrap();
1158 /// let g = c.create_graph().unwrap();
1159 /// let t = array_type(vec![4, 3, 2], INT32);
1160 /// let n1 = g.input(t).unwrap();
1161 /// let n2 = g.array_to_vector(n1).unwrap();
1162 /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
1163 /// let n3 = n2.vector_get(index).unwrap();
1164 ///
1165 /// assert!(n2.get_type().unwrap().is_vector());
1166 /// assert_eq!(n3.get_type().unwrap().get_shape(), vec![3,2]);
1167 /// ```
1168 pub fn array_to_vector(&self) -> Result<Node> {
1169 self.get_graph().array_to_vector(self.clone())
1170 }
1171
1172 /// Adds a node to the parent graph converting a vector associated with the node to an array.
1173 ///
1174 /// Applies [Graph::vector_to_array] to the parent graph and `this` node.
1175 ///
1176 /// # Example
1177 ///
1178 /// ```
1179 /// # use ciphercore_base::graphs::create_context;
1180 /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
1181 /// let c = create_context().unwrap();
1182 /// let g = c.create_graph().unwrap();
1183 /// let t = array_type(vec![3, 2], INT32);
1184 /// let vec_t = vector_type(4, t);
1185 /// let n1 = g.input(vec_t).unwrap();
1186 /// let n2 = n1.vector_to_array().unwrap();
1187 ///
1188 /// assert!(n2.get_type().unwrap().is_array());
1189 /// assert_eq!(n2.get_type().unwrap().get_shape(), vec![4, 3, 2]);
1190 /// ```
1191 pub fn vector_to_array(&self) -> Result<Node> {
1192 self.get_graph().vector_to_array(self.clone())
1193 }
1194
1195 /// Adds a node to the parent graph converting a vector associated with the node to an array.
1196 ///
1197 /// Applies [Graph::gather] to the parent graph and `this` node.
1198 pub fn gather(&self, indices: Node, axis: u64) -> Result<Node> {
1199 self.get_graph().gather(self.clone(), indices, axis)
1200 }
1201
1202 /// Adds a node that creates a vector with `n` copies of a value of this node.
1203 ///
1204 /// Applies [Graph::repeat] to the parent graph, `this` node and `n`.
1205 ///
1206 /// # Example
1207 ///
1208 /// ```
1209 /// # use ciphercore_base::data_types::{INT32, array_type};
1210 /// # use ciphercore_base::graphs::create_context;
1211 /// let c = create_context().unwrap();
1212 /// let g = c.create_graph().unwrap();
1213 /// let t = array_type(vec![3, 2, 3], INT32);
1214 /// let n1 = g.input(t).unwrap();
1215 /// let n2 = n1.repeat(10).unwrap();
1216 /// ```
1217 pub fn repeat(&self, n: u64) -> Result<Node> {
1218 self.get_graph().repeat(self.clone(), n)
1219 }
1220
1221 /// Adds a node returning the Cuckoo hash map of an input array of binary strings using provided hash functions.
1222 ///
1223 /// Applies [Graph::cuckoo_hash] to the parent graph, `this` node and `hash_matrices`.
1224 #[doc(hidden)]
1225 pub fn cuckoo_hash(&self, hash_matrices: Node) -> Result<Node> {
1226 self.get_graph().cuckoo_hash(self.clone(), hash_matrices)
1227 }
1228
1229 /// Adds a node that, given an input multidimensional array A, binary one-dimensional array B (first dimension is n in both array) and starting value v, computes the following iteration
1230 ///
1231 /// output[i] = A[i-1] + B[i-1] * output[i-1]
1232 ///
1233 /// where i in {1,...,n} and output[0] = v.
1234 ///
1235 /// Applies [Graph::segment_cumsum] to the parent graph, `this` node, `binary_array` and `first_row`.
1236 #[doc(hidden)]
1237 pub fn segment_cumsum(&self, binary_array: Node, first_row: Node) -> Result<Node> {
1238 self.get_graph()
1239 .segment_cumsum(self.clone(), binary_array, first_row)
1240 }
1241
1242 /// Adds a node that computes sharding of a given table according to a given sharding config.
1243 /// Sharding config contains names of the columns whose hashed values are used for sharding.
1244 ///
1245 /// Each shard is accompanied by a Boolean mask indicating whether a corresponding row stems from the input table or padded (1 if a row comes from input).
1246 ///
1247 /// Applies [Graph::shard] to the parent graph, `this` node and `shard_config`.
1248 #[doc(hidden)]
1249 pub fn shard(&self, shard_config: ShardConfig) -> Result<Node> {
1250 self.get_graph().shard(self.clone(), shard_config)
1251 }
1252
1253 /// Adds a node that converts a switching map array into a tuple of the following components:
1254 /// - a permutation map array with deletion,
1255 /// - a duplication map array,
1256 /// - a permutation map array without deletion.
1257 ///
1258 /// The composition of these maps is equal to the input switching map, which is an array containing non-unique indices of some array.
1259 ///
1260 /// Applies [Graph::decompose_switching_map] to the parent graph and `this`.
1261 #[doc(hidden)]
1262 pub fn decompose_switching_map(&self, n: u64) -> Result<Node> {
1263 self.get_graph().decompose_switching_map(self.clone(), n)
1264 }
1265
1266 /// Adds a node that converts a Cuckoo hash table to a random permutation.
1267 ///
1268 /// Applies [Graph::cuckoo_to_permutation] to the parent graph and `this` node.
1269 #[doc(hidden)]
1270 pub fn cuckoo_to_permutation(&self) -> Result<Node> {
1271 self.get_graph().cuckoo_to_permutation(self.clone())
1272 }
1273
1274 /// Adds an operation which logs the value of the node at runtime.
1275 pub fn print(&self, message: String) -> Result<Node> {
1276 self.get_graph().print(message, self.clone())
1277 }
1278
1279 /// Applies [Graph::set_output_node] to the parent graph and `this` node.
1280 ///
1281 /// # Returns
1282 ///
1283 /// This node
1284 ///
1285 /// # Example
1286 ///
1287 /// ```
1288 /// # use ciphercore_base::graphs::create_context;
1289 /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
1290 /// let c = create_context().unwrap();
1291 /// let g = c.create_graph().unwrap();
1292 /// let t = array_type(vec![3, 2], INT32);
1293 /// let vec_t = vector_type(4, t);
1294 /// let n1 = g.input(vec_t).unwrap();
1295 /// let n2 = g.vector_to_array(n1).unwrap();
1296 /// n2.set_as_output().unwrap();
1297 /// g.finalize().unwrap();
1298 /// ```
1299 pub fn set_as_output(&self) -> Result<Node> {
1300 self.get_graph().set_output_node(self.clone())?;
1301 Ok(self.clone())
1302 }
1303}
1304
1305/// Methods which aren't supposed to be imported in Python.
1306impl Node {
1307 fn make_serializable(&self) -> SerializableNode {
1308 Arc::new(SerializableNodeBody {
1309 node_dependencies: self
1310 .get_node_dependencies()
1311 .iter()
1312 .map(|n| n.get_id())
1313 .collect(),
1314 graph_dependencies: self
1315 .get_graph_dependencies()
1316 .iter()
1317 .map(|n| n.get_id())
1318 .collect(),
1319 operation: self.get_operation(),
1320 })
1321 }
1322
1323 fn downgrade(&self) -> WeakNode {
1324 WeakNode {
1325 body: Arc::downgrade(&self.body),
1326 }
1327 }
1328
1329 #[doc(hidden)]
1330 pub fn add_annotation(&self, annotation: NodeAnnotation) -> Result<Node> {
1331 self.get_graph()
1332 .get_context()
1333 .add_node_annotation(self, annotation)?;
1334 Ok(self.clone())
1335 }
1336
1337 #[doc(hidden)]
1338 pub fn get_annotations(&self) -> Result<Vec<NodeAnnotation>> {
1339 self.get_graph()
1340 .get_context()
1341 .get_node_annotations(self.clone())
1342 }
1343}
/// Weak (non-owning) pointer to a node body; the counterpart of the strong pointer held by `Node`.
type WeakNodeBodyPointer = Weak<AtomicRefCell<NodeBody>>;
1345
1346struct WeakNode {
1347 body: WeakNodeBodyPointer,
1348}
1349
1350impl WeakNode {
1351 //upgrade function panics if the the Node pointer it downgraded from went out of scope
1352 fn upgrade(&self) -> Node {
1353 Node {
1354 body: self.body.upgrade().unwrap(),
1355 }
1356 }
1357}
1358
1359impl Clone for WeakNode {
1360 fn clone(&self) -> Self {
1361 WeakNode {
1362 body: self.body.clone(),
1363 }
1364 }
1365}
1366
/// Mutable state of a computation graph, shared through `GraphBodyPointer`.
struct GraphBody {
    /// Whether the graph has been finalized.
    finalized: bool,
    /// Nodes owned by this graph (presumably in creation order — TODO confirm against `add_node`).
    nodes: Vec<Node>,
    /// Output node, if one has been set; held weakly, so it does not keep the node alive by itself.
    output_node: Option<WeakNode>,
    /// Identifier of this graph (presumably unique within its context — TODO confirm).
    id: u64,
    /// Weak back-pointer to the context owning this graph.
    context: WeakContext,
}
1374
/// Serializable counterpart of `GraphBody`: nodes are stored as `SerializableNode`s and the
/// output node is referenced by id instead of by pointer.
// NOTE(review): field order is part of the serialized representation for order-sensitive
// serde formats; do not reorder fields without a data-version bump.
#[derive(Serialize, Deserialize)]
struct SerializableGraphBody {
    finalized: bool,
    nodes: Vec<SerializableNode>,
    /// Id of the output node, if one was set (presumably the value of `Node::get_id` — TODO confirm).
    output_node: Option<u64>,
}
1381
/// Shared, interior-mutable pointer to a `GraphBody`; every clone of a [Graph] shares one of these.
type GraphBodyPointer = Arc<AtomicRefCell<GraphBody>>;
1383
/// A structure that stores a pointer to a computation graph, where every node corresponds to an operation.
///
/// # Rust crates
///
/// [Clone] trait duplicates the pointer, not the underlying graph.
///
/// [PartialEq] trait compares pointers, not the related graphs.
///
/// [Hash] likewise hashes the pointer, which keeps it consistent with [PartialEq].
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// let c = create_context().unwrap();
/// let g1 = c.create_graph().unwrap();
/// let g2 = c.create_graph().unwrap();
/// assert_ne!(g1, g2);
/// let g3 = g1.clone();
/// assert_eq!(g1, g3);
/// ```
#[cfg_attr(feature = "py-binding", struct_wrapper)]
pub struct Graph {
    body: GraphBodyPointer,
}
1407
/// Reference-counted `SerializableGraphBody`.
type SerializableGraph = Arc<SerializableGraphBody>;
1409
1410impl fmt::Debug for Graph {
1411 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1412 f.debug_struct("Graph")
1413 .field("body", &self.body.as_ptr())
1414 .finish()
1415 }
1416}
1417
1418impl Clone for Graph {
1419 /// Returns a new [Graph] value with a copy of the pointer to a computation graph.
1420 fn clone(&self) -> Self {
1421 Graph {
1422 body: self.body.clone(),
1423 }
1424 }
1425}
1426
1427impl PartialEq for Graph {
1428 /// Tests whether `self` and `other` graphs are equal via comparison of their respective pointers.
1429 ///
1430 /// # Arguments
1431 ///
1432 /// `other` - another [Graph] value
1433 ///
1434 /// # Returns
1435 ///
1436 /// `true` if `self` and `other` are equal, `false` otherwise
1437 fn eq(&self, other: &Self) -> bool {
1438 Arc::ptr_eq(&self.body, &other.body)
1439 }
1440}
1441
1442impl Eq for Graph {}
1443
1444impl Hash for Graph {
1445 /// Hashes the graph pointer.
1446 ///
1447 /// # Arguments
1448 ///
1449 /// `state` - state of a hash function that is changed after hashing the graph
1450 fn hash<H: Hasher>(&self, state: &mut H) {
1451 ptr::hash(&*self.body, state);
1452 }
1453}
1454
1455/// Public methods which supposed to be imported in Python.
1456#[cfg_attr(feature = "py-binding", impl_wrapper)]
1457impl Graph {
1458 /// Applies [Context::set_main_graph] to the parent context and `this` graph. Returns the clone of `this`.
1459 ///
1460 /// # Returns
1461 ///
1462 /// This graph
1463 ///
1464 /// # Example
1465 ///
1466 /// ```
1467 /// # use ciphercore_base::graphs::create_context;
1468 /// # use ciphercore_base::data_types::{array_type, INT32};
1469 /// let c = create_context().unwrap();
1470 /// let g = c.create_graph().unwrap();
1471 /// let t = array_type(vec![3, 2], INT32);
1472 /// let n = g.input(t).unwrap();
1473 /// n.set_as_output().unwrap();
1474 /// g.finalize().unwrap();
1475 /// g.set_as_main().unwrap();
1476 /// ```
1477 pub fn set_as_main(&self) -> Result<Graph> {
1478 self.get_context().set_main_graph(self.clone())?;
1479 Ok(self.clone())
1480 }
1481
1482 /// Applies [Context::set_graph_name] to the parent context and `this` graph. Returns the clone of `this`.
1483 ///
1484 /// # Arguments
1485 ///
1486 /// `name` - name of the graph
1487 ///
1488 /// # Returns
1489 ///
1490 /// This graph
1491 ///
1492 /// # Example
1493 ///
1494 /// ```
1495 /// # use ciphercore_base::graphs::create_context;
1496 /// let c = create_context().unwrap();
1497 /// let g = c.create_graph().unwrap();
1498 /// g.set_name("relu").unwrap();
1499 /// ```
1500 pub fn set_name(&self, name: &str) -> Result<Graph> {
1501 self.get_context().set_graph_name(self.clone(), name)?;
1502 Ok(self.clone())
1503 }
1504
1505 /// Applies [Context::get_graph_name] to the parent context and `this` graph.
1506 ///
1507 /// # Example
1508 ///
1509 /// ```
1510 /// # use ciphercore_base::graphs::create_context;
1511 /// let c = create_context().unwrap();
1512 /// let g = c.create_graph().unwrap();
1513 /// g.set_name("relu").unwrap();
1514 /// assert_eq!(g.get_name().unwrap(), "relu".to_owned());
1515 /// ```
1516 pub fn get_name(&self) -> Result<String> {
1517 self.get_context().get_graph_name(self.clone())
1518 }
1519
1520 /// Applies [Context::retrieve_node] to the parent context and `this` graph.
1521 ///
1522 /// # Example
1523 ///
1524 /// ```
1525 /// # use ciphercore_base::graphs::create_context;
1526 /// # use ciphercore_base::data_types::{BIT, scalar_type};
1527 /// let c = create_context().unwrap();
1528 /// let g = c.create_graph().unwrap();
1529 /// let n = g.input(scalar_type(BIT)).unwrap();
1530 /// n.set_name("input_node").unwrap();
1531 /// assert!(n == g.retrieve_node("input_node").unwrap());
1532 /// ```
1533 pub fn retrieve_node(&self, name: &str) -> Result<Node> {
1534 self.get_context().retrieve_node(self.clone(), name)
1535 }
1536
1537 /// Adds an input node to the graph and returns it.
1538 ///
1539 /// During evaluation, input nodes require values to be supplied.
1540 ///
1541 /// # Arguments
1542 ///
1543 /// `input_type` - type of a new input node
1544 ///
1545 /// # Returns
1546 ///
1547 /// New input node
1548 ///
1549 /// # Example
1550 ///
1551 /// ```
1552 /// # use ciphercore_base::graphs::create_context;
1553 /// # use ciphercore_base::data_types::{BIT, scalar_type};
1554 /// let c = create_context().unwrap();
1555 /// let g = c.create_graph().unwrap();
1556 /// let t = scalar_type(BIT);
1557 /// let n = g.input(t).unwrap();
1558 /// ```
1559 pub fn input(&self, input_type: Type) -> Result<Node> {
1560 self.add_node(vec![], vec![], Operation::Input(input_type))
1561 }
1562
1563 /// Adds an node with zeros of given type.
1564 ///
1565 /// Compared to `constant` this node does produce a big value array in serialized graph.
1566 ///
1567 /// # Arguments
1568 ///
1569 /// `t` - node type
1570 ///
1571 /// # Returns
1572 ///
1573 /// New node with zeros of given type.
1574 ///
1575 /// # Example
1576 ///
1577 /// ```
1578 /// # use ciphercore_base::graphs::create_context;
1579 /// # use ciphercore_base::data_types::{UINT8, array_type};
1580 /// let c = create_context().unwrap();
1581 /// let g = c.create_graph().unwrap();
1582 /// let z = g.zeros(array_type(vec![10, 20], UINT8)).unwrap();
1583 /// ```
1584 pub fn zeros(&self, t: Type) -> Result<Node> {
1585 self.add_node(vec![], vec![], Operation::Zeros(t))
1586 }
1587
1588 /// Adds an node with ones of given type.
1589 ///
1590 /// Compared to `constant` this node does produce a big value array in serialized graph.
1591 ///
1592 /// # Arguments
1593 ///
1594 /// `t` - node type
1595 ///
1596 /// # Returns
1597 ///
1598 /// New node with ones of given type.
1599 ///
1600 /// # Example
1601 ///
1602 /// ```
1603 /// # use ciphercore_base::graphs::create_context;
1604 /// # use ciphercore_base::data_types::{UINT8, array_type};
1605 /// let c = create_context().unwrap();
1606 /// let g = c.create_graph().unwrap();
1607 /// let z = g.ones(array_type(vec![10, 20], UINT8)).unwrap();
1608 /// ```
1609 pub fn ones(&self, t: Type) -> Result<Node> {
1610 self.add_node(vec![], vec![], Operation::Ones(t))
1611 }
1612
1613 /// Adds a node that sums two arrays or scalars of the same scalar type elementwise.
1614 ///
1615 /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, adding two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1616 ///
1617 /// # Arguments
1618 ///
1619 /// * `a` - node containing the first term (array or scalar)
1620 /// * `b` - node containing the second term (array or scalar)
1621 ///
1622 /// # Returns
1623 ///
1624 /// New addition node
1625 ///
1626 /// # Example
1627 ///
1628 /// ```
1629 /// # use ciphercore_base::graphs::create_context;
1630 /// # use ciphercore_base::data_types::{BIT, scalar_type};
1631 /// let c = create_context().unwrap();
1632 /// let g = c.create_graph().unwrap();
1633 /// let t = scalar_type(BIT);
1634 /// let n1 = g.input(t.clone()).unwrap();
1635 /// let n2 = g.input(t).unwrap();
1636 /// let n3 = g.add(n1, n2).unwrap();
1637 /// ```
1638 pub fn add(&self, a: Node, b: Node) -> Result<Node> {
1639 self.add_node(vec![a, b], vec![], Operation::Add)
1640 }
1641
1642 /// Adds a node that subtracts two arrays or scalars of the same scalar type elementwise.
1643 ///
1644 /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, subtracting two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1645 ///
1646 /// # Arguments
1647 ///
1648 /// * `a` - node containing the minuend (array or scalar)
1649 /// * `b` - node containing the subtrahend (array or scalar)
1650 ///
1651 /// # Returns
1652 ///
1653 /// New subtraction node
1654 ///
1655 /// # Example
1656 ///
1657 /// ```
1658 /// # use ciphercore_base::graphs::create_context;
1659 /// # use ciphercore_base::data_types::{BIT, scalar_type};
1660 /// let c = create_context().unwrap();
1661 /// let g = c.create_graph().unwrap();
1662 /// let t = scalar_type(BIT);
1663 /// let n1 = g.input(t.clone()).unwrap();
1664 /// let n2 = g.input(t).unwrap();
1665 /// let n3 = g.subtract(n1, n2).unwrap();
1666 /// ```
1667 pub fn subtract(&self, a: Node, b: Node) -> Result<Node> {
1668 self.add_node(vec![a, b], vec![], Operation::Subtract)
1669 }
1670
1671 /// Adds a node that multiplies two arrays or scalars of the same scalar type elementwise.
1672 ///
1673 /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, multiplication of two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1674 ///
1675 /// # Arguments
1676 ///
1677 /// * `a` - node containing the first factor (array or scalar)
1678 /// * `b` - node containing the second factor (array or scalar)
1679 ///
1680 /// # Returns
1681 ///
1682 /// New multiplication node
1683 ///
1684 /// # Example
1685 ///
1686 /// ```
1687 /// # use ciphercore_base::graphs::create_context;
1688 /// # use ciphercore_base::data_types::{BIT, scalar_type};
1689 /// let c = create_context().unwrap();
1690 /// let g = c.create_graph().unwrap();
1691 /// let t = scalar_type(BIT);
1692 /// let n1 = g.input(t.clone()).unwrap();
1693 /// let n2 = g.input(t).unwrap();
1694 /// let n3 = g.multiply(n1, n2).unwrap();
1695 /// ```
1696 pub fn multiply(&self, a: Node, b: Node) -> Result<Node> {
1697 self.add_node(vec![a, b], vec![], Operation::Multiply)
1698 }
1699
1700 /// Adds a node that multiplies an integer array or scalar by a binary array or scalar elementwise.
1701 /// For each integer element, this operation returns this element or zero depending on the corresponding bit element.
1702 ///
1703 /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, multiplication of two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1704 ///
1705 /// # Arguments
1706 ///
1707 /// * `a` - node containing an integer array or scalar
1708 /// * `b` - node containing a binary array or scalar
1709 ///
1710 /// # Returns
1711 ///
1712 /// New mixed multiplication node
1713 ///
1714 /// # Example
1715 ///
1716 /// ```
1717 /// # use ciphercore_base::graphs::create_context;
1718 /// # use ciphercore_base::data_types::{BIT, INT32, scalar_type};
1719 /// let c = create_context().unwrap();
1720 /// let g = c.create_graph().unwrap();
1721 /// let t1 = scalar_type(INT32);
1722 /// let t2 = scalar_type(BIT);
1723 /// let n1 = g.input(t1).unwrap();
1724 /// let n2 = g.input(t2).unwrap();
1725 /// let n3 = g.mixed_multiply(n1, n2).unwrap();
1726 /// ```
1727 pub fn mixed_multiply(&self, a: Node, b: Node) -> Result<Node> {
1728 self.add_node(vec![a, b], vec![], Operation::MixedMultiply)
1729 }
1730
1731 /// Adds a node that computes the dot product according to [the NumPy rules](https://numpy.org/doc/stable/reference/generated/numpy.dot.html):
1732 /// * if both factors are 1-dimensional arrays, return their inner product;
1733 /// * if both factors are 2-dimensional arrays, return their matrix product;
1734 /// * if one of the factors is scalar, return the result of [multiply](Graph::multiply);
1735 /// * if the first factor is n-dimensional and the second one is 1-dimensional,
1736 /// compute the elementwise multiplication and return the sum over the last axis.
1737 /// * if both factors are n-dimensional (n>2), return the sum product
1738 /// over the last axis of the first factor and the second-to-last axis of the second factor, i.e.
1739 ///
1740 /// `dot(A, B)[i,j,k,m] = sum(A[i,j,:] * B[k,:,m])` (in the NumPy notation).
1741 ///
1742 /// # Arguments
1743 ///
1744 /// * `a` - node containing the first factor (array or scalar)
1745 /// * `b` - node containing the second factor (array or scalar)
1746 ///
1747 /// # Returns
1748 ///
1749 /// New dot product node
1750 ///
1751 /// # Example
1752 ///
1753 /// ```
1754 /// # use ciphercore_base::graphs::create_context;
1755 /// # use ciphercore_base::data_types::{INT32, array_type};
1756 /// let c = create_context().unwrap();
1757 /// let g = c.create_graph().unwrap();
1758 /// let t = array_type(vec![10], INT32);
1759 /// let n1 = g.input(t.clone()).unwrap();
1760 /// let n2 = g.input(t).unwrap();
1761 /// let n3 = g.dot(n1, n2).unwrap();
1762 /// ```
1763 pub fn dot(&self, a: Node, b: Node) -> Result<Node> {
1764 self.add_node(vec![a, b], vec![], Operation::Dot)
1765 }
1766
1767 /// Adds a node that computes the matrix product of two arrays according to [the NumPy rules](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
1768 ///
1769 /// Each array is represented as an array of 2-dimensional matrix elements and this node returns the elementwise product of such matrix arrays.
1770 ///
1771 /// # Arguments
1772 ///
1773 /// * `a` - node containing the first array
1774 /// * `b` - node containing the second array
1775 ///
1776 /// # Returns
1777 ///
1778 /// New matrix product node
1779 ///
1780 /// # Example
1781 ///
1782 /// ```
1783 /// # use ciphercore_base::graphs::create_context;
1784 /// # use ciphercore_base::data_types::{INT32, array_type};
1785 /// let c = create_context().unwrap();
1786 /// let g = c.create_graph().unwrap();
1787 /// let t1 = array_type(vec![2, 3], INT32);
1788 /// let t2 = array_type(vec![3, 2], INT32);
1789 /// let n1 = g.input(t1).unwrap();
1790 /// let n2 = g.input(t2).unwrap();
1791 /// let n3 = g.matmul(n1, n2).unwrap();
1792 /// ```
1793 pub fn matmul(&self, a: Node, b: Node) -> Result<Node> {
1794 self.add_node(vec![a, b], vec![], Operation::Matmul)
1795 }
1796
1797 /// Adds a node that computes the general matrix product of two arrays according to [the ONNX rules](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm) with `alpha = 1`, `beta = 0` and `C = 0`.
1798 ///
1799 /// Each array is represented as an array of 2-dimensional matrix elements and this node returns the elementwise product of such matrix arrays.
1800 /// Each matrix should have at least 2 dimensions.
1801 /// To multiply by 1-dimensional matrices (i.e., vectors), please resort to `matmul` or `dot`.
1802 ///
1803 /// # Arguments
1804 ///
1805 /// * `a` - node containing the first array
1806 /// * `b` - node containing the second array
1807 /// * `transpose_a` - if true, the first array will be transposed
1808 /// * `transpose_b` - if true, the second array will be transposed
1809 ///
1810 /// # Returns
1811 ///
1812 /// New Gemm node
1813 ///
1814 /// # Example
1815 ///
1816 /// ```
1817 /// # use ciphercore_base::graphs::create_context;
1818 /// # use ciphercore_base::data_types::{INT32, array_type};
1819 /// let c = create_context().unwrap();
1820 /// let g = c.create_graph().unwrap();
1821 /// let t1 = array_type(vec![2, 3], INT32);
1822 /// let t2 = array_type(vec![2, 3], INT32);
1823 /// let n1 = g.input(t1).unwrap();
1824 /// let n2 = g.input(t2).unwrap();
1825 /// let n3 = g.gemm(n1, n2, false, true).unwrap();
1826 /// ```
1827 #[doc(hidden)]
1828 pub fn gemm(&self, a: Node, b: Node, transpose_a: bool, transpose_b: bool) -> Result<Node> {
1829 self.add_node(
1830 vec![a, b],
1831 vec![],
1832 Operation::Gemm(transpose_a, transpose_b),
1833 )
1834 }
1835
1836 /// Adds a node that computes a join of a given type on two named tuples along given key headers.
1837 ///
1838 /// Each tuple should consist of arrays having the same number of rows, i.e. the first dimensions of these arrays should be equal.
1839 ///
1840 /// In addition, each named tuple should have a binary array named with NULL_HEADER that contains zeros in rows void of content; otherwise, it contains ones.
1841 /// This column is called the null column.
1842 ///
1843 /// Let `row key` be the bitstring obtained by concatenating data elements for given key headers.
1844 /// **WARNING**: Rows must have unique row keys, except for rows where NULL_HEADER is zero: those rows are ignored.
1845 ///
1846 /// This operation returns:
1847 /// - Inner join: a named tuple containing rows where input tuples have matching row keys.
1848 /// - Left join: a named tuple containing all the rows of the first input tuple merged with the rows of the second input tuple having same row keys.
1849 /// - Union join: a named tuple containing rows of the first input tuple that are not in the inner join and all the rows of the second tuple.
1850 /// In contrast to the SQL union, this operation does not require that input datasets have respective columns of the same type.
1851 /// This means that columns of both datasets are included and filled with zeros where no data can be retrieved.
1852 /// Namely, the rows of the second set in the union join will contain zeros in non-key columns of the first set and vice versa.
1853 /// - Full join: a named tuple containing all the rows of input tuples.
1854 /// If the row key of a row of the first set match with the row key of a row of the second set, they are merged into one.
1855 /// The order of rows goes as follows:
1856 /// 1. the rows of the first set that don't belong to the inner join.
1857 /// 2. all the rows of the second set including those merged with the rows of the first set as in inner join.
1858 /// In this form, full join is computed as `union_join(a, left_join(b, a))`.
1859 ///
1860 /// # Arguments
1861 ///
1862 /// * `a` - node containing the first named tuple
1863 /// * `b` - node containing the second named tuple
1864 /// * `t` - join type (Inner/Left/Union/Full)
1865 /// * `headers` - headers of columns along which the join is performed
1866 ///
1867 /// # Returns
1868 ///
1869 /// New join node
1870 ///
1871 /// # Example
1872 ///
1873 /// ```
1874 /// # use ciphercore_base::graphs::{create_context, JoinType};
1875 /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type};
1876 /// # use ciphercore_base::type_inference::NULL_HEADER;
1877 /// # use std::collections::HashMap;
1878 /// let c = create_context().unwrap();
1879 /// let g = c.create_graph().unwrap();
1880 /// let t1n = array_type(vec![100], BIT);
1881 /// let t11 = array_type(vec![100], INT32);
1882 /// let t12 = array_type(vec![100, 128], BIT);
1883 /// let t13 = array_type(vec![100], INT64);
1884 /// let t2n = array_type(vec![50], BIT);
1885 /// let t21 = array_type(vec![50], INT32);
1886 /// let t22 = array_type(vec![50, 128], BIT);
1887 /// let t23 = array_type(vec![50], UINT8);
1888 /// let t1 = named_tuple_type(vec![
1889 /// (NULL_HEADER.to_owned(), t1n),
1890 /// ("ID".to_owned(), t11),
1891 /// ("Occupation".to_owned(), t12),
1892 /// ("Revenue".to_owned(), t13),
1893 /// ]);
1894 /// let t2 = named_tuple_type(vec![
1895 /// (NULL_HEADER.to_owned(), t2n),
1896 /// ("ID".to_owned(), t21),
1897 /// ("Job".to_owned(), t22),
1898 /// ("Age".to_owned(), t23),
1899 /// ]);
1900 /// let n1 = g.input(t1).unwrap();
1901 /// let n2 = g.input(t2).unwrap();
1902 /// let n3 = g.join(n1, n2, JoinType::Inner, HashMap::from([
1903 /// ("ID".to_owned(), "ID".to_owned()),
1904 /// ("Occupation".to_owned(), "Job".to_owned()),
1905 /// ])).unwrap();
1906 /// ```
1907 pub fn join(
1908 &self,
1909 a: Node,
1910 b: Node,
1911 t: JoinType,
1912 headers: HashMap<String, String>,
1913 ) -> Result<Node> {
1914 self.add_node(vec![a, b], vec![], Operation::Join(t, headers))
1915 }
1916
1917 /// Adds a node that computes a join of a given type on two named tuples along given key headers.
1918 ///
1919 /// Each tuple should consist of pairs of arrays having the same number of rows, i.e. the first dimensions of these arrays should be equal.
1920 /// Each pair has a binary array that contains zeros in rows where the data array has no content and an array with data.
1921 ///
1922 /// In addition, each named tuple should have a binary array named with NULL_HEADER that contains zeros in rows void of content; otherwise, it contains ones.
1923 /// This column is called the null column.
1924 ///
1925 /// Let `row key` be the bitstring obtained by concatenating data elements for given key headers where all the corresponding mask elements are set to one.
1926 /// **WARNING**: Rows must have unique row keys, except for rows where NULL_HEADER is zero or at least one mask element in given key headers is zero.
1927 /// Rows with zero NULL_HEADER are ignored.
1928 /// We assume that rows with zero mask elements don't match with other rows.
1929 /// Thus, they can't show up in inner and left joins, but can be copied over to the result of union or full joins.
1930 ///
1931 /// This operation returns:
1932 /// - Inner join: a named tuple containing rows where input tuples have matching row keys.
1933 /// - Left join: a named tuple containing all the rows of the first input tuple merged with the rows of the second input tuple having same row keys.
1934 /// - Union join: a named tuple containing rows of the first input tuple that are not in the inner join and all the rows of the second tuple.
1935 /// In contrast to the SQL union, this operation does not require that input datasets have respective columns of the same type.
1936 /// This means that columns of both datasets are included and filled with zeros where no data can be retrieved.
1937 /// Namely, the rows of the second set in the union join will contain zeros in non-key columns of the first set and vice versa.
1938 /// - Full join: a named tuple containing all the rows of input tuples.
1939 /// If the row key of a row of the first set match with the row key of a row of the second set, they are merged into one.
1940 /// The order of rows goes as follows:
1941 /// 1. the rows of the first set that don't belong to the inner join.
1942 /// 2. all the rows of the second set including those merged with the rows of the first set as in inner join.
1943 /// In this form, full join is computed as `union_join(a, left_join(b, a))`.
1944 ///
1945 /// # Arguments
1946 ///
1947 /// * `a` - node containing the first named tuple
1948 /// * `b` - node containing the second named tuple
1949 /// * `t` - join type (Inner/Left/Union/Full)
1950 /// * `headers` - headers of columns along which the join is performed
1951 ///
1952 /// # Returns
1953 ///
1954 /// New join node
1955 ///
1956 /// # Example
1957 ///
1958 /// ```
1959 /// # use ciphercore_base::graphs::{create_context, JoinType};
1960 /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type, tuple_type};
1961 /// # use ciphercore_base::type_inference::NULL_HEADER;
1962 /// # use std::collections::HashMap;
1963 /// let c = create_context().unwrap();
1964 /// let g = c.create_graph().unwrap();
1965 /// let t1n = array_type(vec![100], BIT);
1966 /// let t11 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT32)]);
1967 /// let t12 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100, 128], BIT)]);
1968 /// let t13 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT64)]);
1969 /// let t2n = array_type(vec![50], BIT);
1970 /// let t21 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], INT32)]);
1971 /// let t22 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50, 128], BIT)]);
1972 /// let t23 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], UINT8)]);
1973 /// let t1 = named_tuple_type(vec![
1974 /// (NULL_HEADER.to_owned(), t1n),
1975 /// ("ID".to_owned(), t11),
1976 /// ("Occupation".to_owned(), t12),
1977 /// ("Revenue".to_owned(), t13),
1978 /// ]);
1979 /// let t2 = named_tuple_type(vec![
1980 /// (NULL_HEADER.to_owned(), t2n),
1981 /// ("ID".to_owned(), t21),
1982 /// ("Job".to_owned(), t22),
1983 /// ("Age".to_owned(), t23),
1984 /// ]);
1985 /// let n1 = g.input(t1).unwrap();
1986 /// let n2 = g.input(t2).unwrap();
1987 /// let n3 = g.join_with_column_masks(n1, n2, JoinType::Inner, HashMap::from([
1988 /// ("ID".to_owned(), "ID".to_owned()),
1989 /// ("Occupation".to_owned(), "Job".to_owned()),
1990 /// ])).unwrap();
1991 /// ```
1992 pub fn join_with_column_masks(
1993 &self,
1994 a: Node,
1995 b: Node,
1996 t: JoinType,
1997 headers: HashMap<String, String>,
1998 ) -> Result<Node> {
1999 self.add_node(
2000 vec![a, b],
2001 vec![],
2002 Operation::JoinWithColumnMasks(t, headers),
2003 )
2004 }
2005
2006 /// Adds a node that applies a permutation to the array along the first dimension.
2007 ///
2008 /// # Arguments
2009 ///
2010 /// * `a` - node containing an array to permute.
2011 /// * `p` - node containing a permutation.
2012 ///
2013 /// # Returns
2014 ///
2015 /// New permuted node
2016 ///
2017 /// # Example
2018 ///
2019 /// ```
2020 /// # use ciphercore_base::graphs::create_context;
2021 /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
2022 /// let c = create_context().unwrap();
2023 /// let g = c.create_graph().unwrap();
2024 /// let t = array_type(vec![25, 3], INT32);
2025 /// let a = g.input(t).unwrap();
2026 /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
2027 /// let a = g.apply_permutation(a, p).unwrap();
2028 /// ```
2029 #[doc(hidden)]
2030 pub fn apply_permutation(&self, a: Node, p: Node) -> Result<Node> {
2031 self.add_node(vec![a, p], vec![], Operation::ApplyPermutation(false))
2032 }
2033
2034 /// Adds a node that applies an inverse permutation to the array along the first dimension.
2035 ///
2036 /// # Arguments
2037 ///
2038 /// * `a` - node containing an array to permute.
2039 /// * `p` - node containing a permutation.
2040 ///
2041 /// # Returns
2042 ///
2043 /// New permuted node
2044 ///
2045 /// # Example
2046 ///
2047 /// ```
2048 /// # use ciphercore_base::graphs::create_context;
2049 /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
2050 /// let c = create_context().unwrap();
2051 /// let g = c.create_graph().unwrap();
2052 /// let t = array_type(vec![25, 3], INT32);
2053 /// let a = g.input(t).unwrap();
2054 /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
2055 /// let a = g.apply_inverse_permutation(a, p).unwrap();
2056 /// ```
2057 #[doc(hidden)]
2058 pub fn apply_inverse_permutation(&self, a: Node, p: Node) -> Result<Node> {
2059 self.add_node(vec![a, p], vec![], Operation::ApplyPermutation(true))
2060 }
2061
2062 /// Adds a node that sorts a table given as named tuple according to the column given by the key argument.
2063 /// The key column must be a 2-d BIT array of shape [n, b], interpreted as bitstrings of length b.
2064 /// Other columns in the named tuple must be arrays of arbitrary type and shape, as long as they
2065 /// share the first dimension: [n, ...].
2066 /// Bitstrings are sorted lexicographically, and the sorting algorithm is stable: preserving relative
2067 /// order of entries in other arrays where the corresponding key entries match.
2068 ///
2069 /// # Arguments
2070 /// * `a` - node containing a named tuple -- arrays to sort.
2071 /// * `key` - name of the field to sort on it, this array must be 2-d of type BIT.
2072 ///
2073 /// # Returns
2074 ///
2075 /// New sorted node
2076 ///
2077 /// # Example
2078 ///
2079 /// ```
2080 /// # use ciphercore_base::graphs::create_context;
2081 /// # use ciphercore_base::data_types::{BIT, INT32, UINT64, array_type, named_tuple_type};
2082 /// let c = create_context().unwrap();
2083 /// let g = c.create_graph().unwrap();
2084 /// let v1 = g.input(array_type(vec![20], INT32)).unwrap();
2085 /// let v2 = g.input(array_type(vec![20, 10, 2], UINT64)).unwrap();
2086 /// let k = g.input(array_type(vec![20, 32], BIT)).unwrap();
2087 /// let a = g.create_named_tuple(vec![("key".to_string(), k), ("value1".to_string(), v1), ("value2".to_string(), v2)]).unwrap();
2088 /// let a = g.sort(a, "key".to_string()).unwrap();
2089 /// ```
2090 pub fn sort(&self, a: Node, key: String) -> Result<Node> {
2091 self.add_node(vec![a], vec![], Operation::Sort(key))
2092 }
2093
2094 /// Adds a node that divides a scalar or each entry of an array by a positive constant integer `scale`.
2095 ///
2096 /// # Arguments
2097 ///
2098 /// * `a` - node containing a scalar or an array
2099 /// * `scale` - positive integer
2100 ///
2101 /// # Returns
2102 ///
2103 /// New truncate node
2104 ///
2105 /// # Example
2106 ///
2107 /// ```
2108 /// # use ciphercore_base::graphs::create_context;
2109 /// # use ciphercore_base::data_types::{INT32, array_type};
2110 /// let c = create_context().unwrap();
2111 /// let g = c.create_graph().unwrap();
2112 /// let t = array_type(vec![2, 3], INT32);
2113 /// let n1 = g.input(t).unwrap();
2114 /// let n2 = g.truncate(n1, 4).unwrap();
2115 /// ```
2116 pub fn truncate(&self, a: Node, scale: u128) -> Result<Node> {
2117 self.add_node(vec![a], vec![], Operation::Truncate(scale))
2118 }
2119
2120 /// Adds a node that computes the sum of entries of an array along given axes (see [numpy.sum](https://numpy.org/doc/stable/reference/generated/numpy.sum.html)).
2121 ///
2122 /// For example, summing the array `[[1000, 200], [30, 4]]` along the first or the second axes results in the arrays `[1030,204]` or `[1200,34]`, respectively. Summing along both axes yields `1234`.
2123 ///
2124 /// # Arguments
2125 ///
2126 /// * `a` - node containing an array
2127 /// * `axes` - indices of the axes of `a`
2128 ///
2129 /// # Returns
2130 ///
2131 /// New sum node
2132 ///
2133 /// # Example
2134 ///
2135 /// ```
2136 /// # use ciphercore_base::graphs::create_context;
2137 /// # use ciphercore_base::data_types::{INT32, array_type};
2138 /// let c = create_context().unwrap();
2139 /// let g = c.create_graph().unwrap();
2140 /// let t = array_type(vec![3, 2, 3], INT32);
2141 /// let axes = vec![1, 0];
2142 /// let n1 = g.input(t).unwrap();
2143 /// let n2 = g.sum(n1, axes).unwrap();
2144 /// ```
2145 pub fn sum(&self, a: Node, axes: ArrayShape) -> Result<Node> {
2146 self.add_node(vec![a], vec![], Operation::Sum(axes))
2147 }
2148
2149 /// Adds a node that computes the cumulative sum of elements along a given axis. (see [numpy.cumsum](https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html)).
2150 ///
2151 /// For example, summing the array `[[1000, 200], [30, 4]]` along the first or the second axes results in the arrays `[[1000, 200], [1030, 204]]` or `[[1000, 1200], [30, 34]]`, respectively.
2152 ///
2153 /// # Arguments
2154 ///
2155 /// * `a` - node containing an array
2156 /// * `axis` - axis along which the cumulative sum is computed
2157 ///
2158 /// # Returns
2159 ///
2160 /// New cumulative sum node
2161 ///
2162 /// # Example
2163 ///
2164 /// ```
2165 /// # use ciphercore_base::graphs::create_context;
2166 /// # use ciphercore_base::data_types::{INT32, array_type};
2167 /// let c = create_context().unwrap();
2168 /// let g = c.create_graph().unwrap();
2169 /// let t = array_type(vec![3, 2], INT32);
2170 /// let n1 = g.input(t).unwrap();
2171 /// let n2 = g.cum_sum(n1, 1).unwrap();
2172 /// ```
2173 pub fn cum_sum(&self, a: Node, axis: u64) -> Result<Node> {
2174 self.add_node(vec![a], vec![], Operation::CumSum(axis))
2175 }
2176
2177 /// Adds a node that permutes an array along given axes (see [numpy.transpose](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)). This function generalizes matrix transposition.
2178 ///
2179 /// For example, permutation of an array of shape `[a,b,c]` with permutation `[2,0,1]` results in an array of shape `[c,a,b]`.
2180 ///
2181 /// # Arguments
2182 ///
2183 /// * `a` - node containing an array
2184 /// * `axes` - indices of the axes of `a`
2185 ///
2186 /// # Returns
2187 ///
2188 /// New node with permuted axes
2189 ///
2190 /// # Example
2191 ///
2192 /// ```
2193 /// # use ciphercore_base::graphs::create_context;
2194 /// # use ciphercore_base::data_types::{INT32, array_type};
2195 /// let c = create_context().unwrap();
2196 /// let g = c.create_graph().unwrap();
2197 /// let t = array_type(vec![3, 2, 3], INT32);
2198 /// let axes = vec![1, 0, 2];
2199 /// let n1 = g.input(t).unwrap();
2200 /// let n2 = g.permute_axes(n1, axes).unwrap();
2201 /// ```
2202 pub fn permute_axes(&self, a: Node, axes: ArrayShape) -> Result<Node> {
2203 self.add_node(vec![a], vec![], Operation::PermuteAxes(axes))
2204 }
2205
2206 /// Adds a node to the parent graph that inverts a given permutation.
2207 ///
2208 /// An input permutation should be given by a 1-dimensional array of length n, containing unique integers between 0 and n-1.
2209 /// The i-th element of an output array is output[i] = j if input[j] = i.
2210 ///
2211 /// This operation could be realized through [the Scatter operation](https://en.wikipedia.org/wiki/Gather-scatter_(vector_addressing)#Scatter).
2212 /// However, the Scatter operation poses a security risk as the corresponding map should hide empty output positions.
2213 /// This is usually done by padding an input array with dummy values such that its size is equal to the output size.
2214 /// Then, the Scatter map can be turned into a permutation, which can be easily split into a composition of random permutation maps.
2215 /// But permutation maps can be performed by Gather, thus making Scatter unnecessary.
2216 ///
2217 /// **WARNING**: this function should not be used before MPC compilation.
2218 ///
2219 /// # Arguments
2220 ///
2221 /// `a` - node containing an array with permutation.
2222 #[doc(hidden)]
2223 pub fn inverse_permutation(&self, a: Node) -> Result<Node> {
2224 self.add_node(vec![a], vec![], Operation::InversePermutation)
2225 }
2226
2227 /// Adds a node that extracts a sub-array with a given index. This is a special case of [get_slice](Graph::get_slice) and corresponds to single element indexing as in [NumPy](https://numpy.org/doc/stable/user/basics.indexing.html).
2228 ///
2229 /// For example, given an array `A` of shape `[a,b,c,d]`, its subarray `B` of shape `[c,d]` with index `[i,j]` can be extracted as follows
2230 ///
2231 /// `B = A[i,j,:,:]` (in the NumPy notation)
2232 ///
2233 /// # Arguments
2234 ///
2235 /// * `a` - node containing an array
2236 /// * `index` - index of a sub-array
2237 ///
2238 /// # Returns
2239 ///
2240 /// New node containing an extracted sub-array
2241 ///
2242 /// # Example
2243 ///
2244 /// ```
2245 /// # use ciphercore_base::graphs::create_context;
2246 /// # use ciphercore_base::data_types::{INT32, array_type};
2247 /// let c = create_context().unwrap();
2248 /// let g = c.create_graph().unwrap();
2249 /// let t = array_type(vec![3, 2, 3], INT32);
2250 /// let index = vec![2];
2251 /// let n1 = g.input(t).unwrap();
2252 /// let n2 = g.get(n1, index).unwrap();
2253 /// ```
2254 pub fn get(&self, a: Node, index: ArrayShape) -> Result<Node> {
2255 self.add_node(vec![a], vec![], Operation::Get(index))
2256 }
2257
2258 /// Adds a node that extracts a sub-array corresponding to a given slice.
2259 ///
2260 /// Our slicing conventions follow [the NumPy rules](https://numpy.org/doc/stable/user/basics.indexing.html).
2261 ///
2262 /// For example, given an array `A` of shape `[a,b]`, its subarray `B` containing only the last 3 rows of `A` can be extracted as follows
2263 ///
2264 /// `get_slice(A, [-3::])[i,j] = A[a-3+i,j]`.
2265 ///
2266 /// Slices are defined as vectors of [SliceElements](SliceElement) that have 3 possible types:
2267 ///
2268 /// * [SingleIndex(`i64`)](SliceElement::SingleIndex) is used to extract all the elements with a given index in a respective dimension,
2269 /// * [SubArray(`Option<i64>, Option<i64>, Option<i64>`)](SliceElement::SubArray) describes the range of indices that should be extracted over a certain dimension (similar to the `a:b:c` notation in [NumPy](https://numpy.org/doc/stable/user/basics.indexing.html))
2270 /// * [Ellipsis](SliceElement::Ellipsis) describes several consecutive dimensions that must be extracted in full, e.g. the slice `[i,...,j]` can be used to extract all the elements with the index `i` in the first dimension and the index `j` in the last one, while the indices of all the other dimensions have no constraints. See [the NumPy slicing](https://numpy.org/doc/stable/user/basics.indexing.html) for more details.
2271 ///
2272 /// # Arguments
2273 ///
2274 /// * `a` - node containing an array
2275 /// * `slice` - array slice
2276 ///
2277 /// # Returns
2278 ///
2279 /// New node containing an extracted sub-array
2280 ///
2281 /// # Example
2282 ///
2283 /// ```
2284 /// # use ciphercore_base::graphs::{create_context, SliceElement};
2285 /// # use ciphercore_base::data_types::{INT32, array_type};
2286 /// let c = create_context().unwrap();
2287 /// let g = c.create_graph().unwrap();
2288 /// let t = array_type(vec![3, 2, 3], INT32);
2289 /// let slice = vec![SliceElement::Ellipsis, SliceElement::SubArray(None, None, Some(-2))];
2290 /// let n1 = g.input(t).unwrap();
2291 /// let n2 = g.get_slice(n1, slice).unwrap();
2292 /// ```
2293 pub fn get_slice(&self, a: Node, slice: Slice) -> Result<Node> {
2294 self.add_node(vec![a], vec![], Operation::GetSlice(slice))
2295 }
2296
2297 /// Adds a node that reshapes a value to a given compatible type (similar to [numpy.reshape](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.reshape.html?highlight=reshape#numpy.ndarray.reshape), but more general). Specifically,
2298 ///
2299 /// * if the input value is an array, it can be reshaped to any array with the same number of elements;
2300 /// * if the input value in the flattened form contains `n` arrays or scalars, it can be reshaped to any type with the same number of arrays and scalars. Each array can be reshaped as in the above rule.
2301 ///
2302 /// For example, an array of shape `[3,10,5]` can be reshaped to `[2,75]`. A tuple with arrays of shapes `[3,4]`, `[12]`, `[2,6]` can be reshaped to a vector with 3 array elements of shape `[2,2,3]`.
2303 ///
2304 /// # Arguments
2305 ///
2306 /// * `a` - node containing a value
2307 /// * `new_type` - type
2308 ///
2309 /// # Returns
2310 ///
2311 /// New node with a reshaped value
2312 ///
2313 /// # Example
2314 ///
2315 /// ```
2316 /// # use ciphercore_base::graphs::create_context;
2317 /// # use ciphercore_base::data_types::{INT32, array_type};
2318 /// let c = create_context().unwrap();
2319 /// let g = c.create_graph().unwrap();
2320 /// let old_t = array_type(vec![3, 2, 3], INT32);
2321 /// let new_t = array_type(vec![3,6], INT32);
2322 /// let n1 = g.input(old_t).unwrap();
2323 /// let n2 = g.reshape(n1, new_t).unwrap();
2324 /// ```
2325 pub fn reshape(&self, a: Node, new_type: Type) -> Result<Node> {
2326 let size_estimate = get_size_estimation_in_bits(new_type.clone());
2327 if size_estimate.is_err() {
2328 return Err(runtime_error!(
2329 "Trying to add a reshape node with invalid type size: {:?}",
2330 size_estimate
2331 ));
2332 }
2333 if size_estimate? > type_size_limit_constants::MAX_INDIVIDUAL_NODE_SIZE {
2334 return Err(runtime_error!(
2335 "Trying to add a reshape node larger than MAX_INDIVIDUAL_NODE_SIZE"
2336 ));
2337 }
2338 self.add_node(vec![a], vec![], Operation::Reshape(new_type))
2339 }
2340
2341 /// Adds a node creating a random value of a given type.
2342 ///
2343 /// **WARNING**: this function should not be used before MPC compilation.
2344 ///
2345 /// # Arguments
2346 ///
2347 /// `output_type` - type of a constant
2348 ///
2349 /// # Returns
2350 ///
2351 /// New random node
2352 ///
2353 /// # Example
2354 ///
2355 /// ```
2356 /// # use ciphercore_base::graphs::create_context;
2357 /// # use ciphercore_base::data_types::{BIT, scalar_type};
2358 /// let c = create_context().unwrap();
2359 /// let g = c.create_graph().unwrap();
2360 /// let t = scalar_type(BIT);
2361 /// let n = g.random(t).unwrap();
2362 /// ```
2363 #[doc(hidden)]
2364 pub fn random(&self, output_type: Type) -> Result<Node> {
2365 self.add_node(vec![], vec![], Operation::Random(output_type))
2366 }
2367
2368 /// Adds a node creating a random permutation map of a one-dimensional array of length `n`.
2369 ///
2370 /// This operation generates a random array of all 64-bit integers from 0 to n-1 in random order.
2371 ///
2372 /// **WARNING**: this function should not be used before MPC compilation.
2373 ///
2374 /// # Arguments
2375 ///
2376 /// `n` - length of permutation
2377 ///
2378 /// # Returns
2379 ///
2380 /// New random permutation node
2381 #[doc(hidden)]
2382 pub fn random_permutation(&self, n: u64) -> Result<Node> {
2383 self.add_node(vec![], vec![], Operation::RandomPermutation(n))
2384 }
2385
2386 /// Adds a node returning the Cuckoo hash map of an input array of binary strings using provided hash functions.
2387 ///
2388 /// Hash functions are defined as an array of binary matrices.
2389 /// The hash of an input string is a product of one of these matrices and this string.
2390 /// Hence, the last dimension of these matrices should coincide with the length of input strings.
2391 ///
2392 /// Random matrices yield a better success probability of hashing.
2393 ///
2394 /// If the input array has shape `[..., n, b]` and hash matrices are given as an `[h, m, b]`-array,
2395 /// then the hash map is an array of shape `[..., 2^m]`.
2396 /// The hash table element with index `[..., i]` is equal to `j` if the `[..., j]`-th input `b`-bit string is hashed to `i` by some of the given hash functions.
2397 ///
2398 /// The number of hash matrices (the first dimension of hash matrices) must be at least 3.
2399 ///
2400 /// A bigger ratio `m/n` leads to higher success probability (recommended one is `>=2`)
2401 ///
2402 /// **WARNING**: this function should not be used before MPC compilation.
2403 ///
2404 /// # Arguments
2405 ///
2406 /// - `array` - input array of binary strings of shape [..., n, b]
2407 /// - `hash_matrices` - random binary [h, m, b]-array.
2408 ///
2409 /// # Returns
2410 ///
2411 /// New CuckooHash node
2412 #[doc(hidden)]
2413 pub fn cuckoo_hash(&self, array: Node, hash_matrices: Node) -> Result<Node> {
2414 self.add_node(vec![array, hash_matrices], vec![], Operation::CuckooHash)
2415 }
2416
2417 /// Adds a node that, given an input multidimensional array A, binary one-dimensional array B (first dimension is n in both array) and starting value v, computes the following iteration
2418 ///
2419 /// output[i] = A[i-1] + B[i-1] * output[i-1]
2420 ///
2421 /// where i in {1,...,n} and output[0] = v.
2422 /// This is similar to computing cumulative sums of consecutive elements (segments) of the input array A.
2423 /// The locations of these segments are defined by the binary array B.
2424 ///
2425 /// This iteration is used in the Duplication protocol (see mpc::mpc_psi) and is done locally by one of the computing parties.
2426 ///
2427 /// **WARNING**: this function should not be used before MPC compilation.
2428 ///
2429 /// # Arguments
2430 ///
2431 /// - `input_array` - input array whose rows are summed within the iteration
2432 /// - `binary_array` - binary array indicating whether a row of the input array should be added to a previous row of the output array
2433 /// - `first_row` - first row of the output array
2434 ///
2435 /// # Returns
2436 ///
2437 /// New SegmentCumSum node containing the output array
2438 #[doc(hidden)]
2439 pub fn segment_cumsum(
2440 &self,
2441 input_array: Node,
2442 binary_array: Node,
2443 first_row: Node,
2444 ) -> Result<Node> {
2445 self.add_node(
2446 vec![input_array, binary_array, first_row],
2447 vec![],
2448 Operation::SegmentCumSum,
2449 )
2450 }
2451
2452 /// Adds a node that computes sharding of a given table according to a given sharding config.
2453 /// Sharding config contains names of the columns whose hashed values are used for sharding.
2454 /// The size of each shard (i.e., the number of rows) and the number of shards is given in the sharding config.
2455 /// The number of shards should be smaller than 700.
2456 ///
2457 ///
2458 /// If some resulting shards don't have `shard_size` elements, they're padded with zeros to reach this size.
2459 /// If the size of some shards exceeds `shard_size`, sharding fails.
2460 ///
2461 /// To choose these parameters, consult [the following paper](http://wwwmayr.informatik.tu-muenchen.de/personen/raab/publ/balls.pdf).
2462 /// Note that for large shard sizes and small number of shards, it holds that
2463 ///
2464 /// `shard_size = num_input_rows / num_shards + alpha * sqrt(2 * num_input_rows / num_shards * log(num_shards))`.
2465 ///
2466 /// With `alpha = 2`, it is possible to achieve failure probability 2^(-40) if `num_shards < 700` and `shard_size > 2^17`.
2467 ///
2468 ///
2469 /// Each shard is accompanied by a Boolean mask indicating whether a corresponding row stems from the input table or padded (1 if a row comes from input).
2470 /// The output is given in the form of a tuple of `(mask, shard)`, where `mask` is a binary array and `shard` is a table, i.e., named tuple.
2471 ///
2472 /// **WARNING**: this function cannot be compiled to an MPC protocol.
2473 ///
2474 /// # Arguments
2475 ///
2476 /// - `input_table` - named tuple of arrays containing data for sharding
2477 /// - `shard_config` - parameters of sharding: number of shards, shard size and names of columns that are hashed in sharding
2478 ///
2479 /// # Returns
2480 ///
2481 /// New Shard node containing a tuple of shards
2482 #[doc(hidden)]
2483 pub fn shard(&self, input_table: Node, shard_config: ShardConfig) -> Result<Node> {
2484 self.add_node(vec![input_table], vec![], Operation::Shard(shard_config))
2485 }
2486
2487 /// Adds a node that converts a switching map array into a random tuple of the following components:
2488 /// - a permutation map array with deletion (some indices of this map are uniformly random, see below),
2489 /// - a tuple of duplication map array and duplication bits,
2490 /// - a permutation map array without deletion.
2491 ///
2492 /// The composition of these maps is equal to the input switching map, which is an array containing non-unique indices of some array.
2493 ///
2494 /// To create a permutation with deletion, this operation first groups identical indices of the input map together and shifts other indices accordingly, e.g.
2495 ///
2496 /// [1, 4, 5, 7, 2, 4] -> [1, 4, 4, 5, 7, 2].
2497 ///
2498 /// This can be done by permutation p = [1, 2, 6, 3, 4, 5].
2499 /// Then, it replaces copies with unique random indices not present in the switching map, e.g.
2500 ///
2501 /// [1, 4, 4, 5, 7, 2] -> [1, 4, 3, 5, 7, 2].
2502 ///
2503 ///
2504 /// A duplication map is a tuple of two one-dimensional arrays of length `n`.
2505 /// The first array contains indices from `[0,n]` in the increasing order with possible repetitions.
2506 /// The second array contains only zeros and ones.
2507 /// If its i-th element is zero, it means that the duplication map doesn't change the i-th element of an array it acts upon.
2508 /// If map's i-th element is one, then the map copies the previous element of the result.
2509 /// This rules can be summarized by the following equation
2510 ///
2511 /// duplication_indices[i] = duplication_bits[i] * duplication_indices[i-1] + (1 - duplication_bits[i]) * i.
2512 ///
2513 /// A duplication map is created from the above switching map with grouped indices, replacing the first index occurrence with 0 and other copies with 1, e.g.
2514 ///
2515 /// [1, 4, 4, 5, 7, 2] -> ([0, 1, 1, 3, 4, 5], [0, 0, 1, 0, 0, 0]).
2516 ///
2517 /// The last permutation is the inverse of the above permutation p, i.e.
2518 ///
2519 /// [1, 2, 4, 5, 6, 3].
2520 ///
2521 /// This operation supports vectorization.
2522 ///
2523 /// **WARNING**: this function should not be used before MPC compilation.
2524 ///
2525 /// # Arguments
2526 ///
2527 /// - `switching_map` - an array of one-dimensional arrays containing non-unique indices of some array of length `n` (usually a simple hash table),
2528 /// - `n` - length of an array that can be mapped by the above switching map.
2529 ///
2530 /// # Returns
2531 ///
2532 /// New DecomposeSwitchingMap node
2533 #[doc(hidden)]
2534 pub fn decompose_switching_map(&self, switching_map: Node, n: u64) -> Result<Node> {
2535 self.add_node(
2536 vec![switching_map],
2537 vec![],
2538 Operation::DecomposeSwitchingMap(n),
2539 )
2540 }
2541
2542 /// Adds a node that converts a Cuckoo hash table to a random permutation.
2543 ///
2544 /// Conversion is done via replacing dummy hash elements by random indices such that the resulting array constitute a permutation.
2545 ///
2546 /// **WARNING**: this function should not be used before MPC compilation.
2547 ///
2548 /// # Arguments
2549 ///
2550 /// `cuckoo_map` - an array containing a Cuckoo hash map with dummy values
2551 ///
2552 /// # Returns
2553 ///
2554 /// New CuckooToPermutation node
2555 #[doc(hidden)]
2556 pub fn cuckoo_to_permutation(&self, cuckoo_map: Node) -> Result<Node> {
2557 self.add_node(vec![cuckoo_map], vec![], Operation::CuckooToPermutation)
2558 }
2559
2560 /// Adds a node that joins a sequence of arrays governed by a given shape.
2561 ///
2562 /// The input arrays should have the same shape or be able to be broadcast to the same shape.
2563 ///
2564 /// For example, stacking 2 arrays of shapes `[2,2]` and `[2,1]` with the outer shape `[2]` works as follows
2565 ///
2566 /// `stack(arrays=[[[1,2],[3,4]], [5,6]], shape=[2]) = [[[1,2],[3,4]], [[5,5], [6,6]]]`
2567 ///
2568 /// # Arguments
2569 ///
2570 /// * `nodes` - vector of nodes containing arrays
2571 /// * `outer_shape` - shape defining how the input arrays are arranged in the resulting array
2572 ///
2573 /// # Returns
2574 ///
2575 /// New stack node
2576 ///
2577 /// # Example
2578 ///
2579 /// ```
2580 /// # use ciphercore_base::graphs::create_context;
2581 /// # use ciphercore_base::data_types::{INT32, array_type};
2582 /// let c = create_context().unwrap();
2583 /// let g = c.create_graph().unwrap();
2584 /// let t1 = array_type(vec![3, 2, 3], INT32);
2585 /// let t2 = array_type(vec![2, 3], INT32);
2586 /// let shape = vec![2];
2587 /// let n1 = g.input(t1).unwrap();
2588 /// let n2 = g.input(t2).unwrap();
2589 /// let n3 = g.stack(vec![n1,n2], shape).unwrap();
2590 /// ```
2591 pub fn stack(&self, nodes: Vec<Node>, outer_shape: ArrayShape) -> Result<Node> {
2592 self.add_node(nodes, vec![], Operation::Stack(outer_shape))
2593 }
2594
2595 /// Adds a node that joins a sequence of arrays along a given axis.
2596 /// This operation is similar to [the NumPy concatenate](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html).
2597 ///
2598 /// The input arrays should have the same shape except in the given axis.
2599 ///
2600 /// # Arguments
2601 ///
2602 /// * `nodes` - vector of nodes containing arrays
2603 /// * `axis` - axis along which the above arrays are joined
2604 ///
2605 /// # Returns
2606 ///
2607 /// New Concatenate node
2608 ///
2609 /// # Example
2610 ///
2611 /// ```
2612 /// # use ciphercore_base::graphs::create_context;
2613 /// # use ciphercore_base::data_types::{INT32, array_type};
2614 /// let c = create_context().unwrap();
2615 /// let g = c.create_graph().unwrap();
2616 /// let t1 = array_type(vec![3, 2, 3], INT32);
2617 /// let t2 = array_type(vec![3, 2, 10], INT32);
2618 /// let shape = vec![2];
2619 /// let n1 = g.input(t1).unwrap();
2620 /// let n2 = g.input(t2).unwrap();
2621 /// let n3 = g.concatenate(vec![n1,n2], 2).unwrap();
2622 /// ```
2623 pub fn concatenate(&self, nodes: Vec<Node>, axis: u64) -> Result<Node> {
2624 self.add_node(nodes, vec![], Operation::Concatenate(axis))
2625 }
2626
2627 /// Adds a node creating a constant of a given type and value.
2628 ///
2629 /// # Arguments
2630 ///
2631 /// * `output_type` - type of a constant
2632 /// * `value` - value of a constant
2633 ///
2634 /// # Returns
2635 ///
2636 /// New constant node
2637 ///
2638 /// # Example
2639 ///
2640 /// ```
2641 /// # use ciphercore_base::graphs::create_context;
2642 /// # use ciphercore_base::data_types::{BIT, scalar_type};
2643 /// # use ciphercore_base::data_values::Value;
2644 /// let c = create_context().unwrap();
2645 /// let g = c.create_graph().unwrap();
2646 /// let t = scalar_type(BIT);
2647 /// let v = Value::from_scalar(0, BIT).unwrap();
2648 /// let n = g.constant(t, v).unwrap();
2649 /// ```
2650 pub fn constant(&self, output_type: Type, value: Value) -> Result<Node> {
2651 self.add_node(vec![], vec![], Operation::Constant(output_type, value))
2652 }
2653
2654 /// Adds a node converting an integer array or scalar to the binary form.
2655 ///
2656 /// Given an array of shape `[a,b,c]` and scalar type `st`, this node returns an array of shape `[a,b,c,s]` where `s` is the bit size of `st`. For example, an array of shape `[1,2,3]` with `INT32` entries will be converted to a binary array of shape `[1,2,3,32]`.
2657 ///
2658 /// # Arguments
2659 ///
2660 /// `a` - node containing an array or scalar
2661 ///
2662 /// # Returns
2663 ///
2664 /// New node converting an array/scalar to the binary form
2665 ///
2666 /// # Example
2667 ///
2668 /// ```
2669 /// # use ciphercore_base::graphs::create_context;
2670 /// # use ciphercore_base::data_types::{array_type, INT32};
2671 /// let c = create_context().unwrap();
2672 /// let g = c.create_graph().unwrap();
2673 /// let t = array_type(vec![3, 2], INT32);
2674 /// let n1 = g.input(t).unwrap();
2675 /// let n2 = g.a2b(n1).unwrap();
2676 /// ```
2677 pub fn a2b(&self, a: Node) -> Result<Node> {
2678 self.add_node(vec![a], vec![], Operation::A2B)
2679 }
2680
2681 /// Adds a node converting a binary array to an array of a given scalar type.
2682 ///
2683 /// Given a binary array of shape `[a,b,c]` and a scalar type `st` of bit size `c`, this node returns an array of shape `[a,b]` with `st` entries. For example, a binary array of shape `[2,3,32]` can be converted to an array of shape `[2,3]` with `INT32` entries.
2684 ///
2685 /// # Arguments
2686 ///
2687 /// * `a` - node containing an array or scalar
2688 /// * `scalar_type` - scalar type
2689 ///
2690 /// # Returns
2691 ///
2692 /// New node converting an array from the binary form
2693 ///
2694 /// # Example
2695 ///
2696 /// ```
2697 /// # use ciphercore_base::graphs::create_context;
2698 /// # use ciphercore_base::data_types::{BIT, INT32, array_type};
2699 /// let c = create_context().unwrap();
2700 /// let g = c.create_graph().unwrap();
2701 /// let t = array_type(vec![3, 32], BIT);
2702 /// let n1 = g.input(t).unwrap();
2703 /// let n2 = g.b2a(n1, INT32).unwrap();
2704 /// ```
2705 pub fn b2a(&self, a: Node, scalar_type: ScalarType) -> Result<Node> {
2706 self.add_node(vec![a], vec![], Operation::B2A(scalar_type))
2707 }
2708
2709 /// Adds a node that creates a tuple from several (possibly, zero) elements.
2710 ///
2711 /// # Arguments
2712 ///
2713 /// `elements` - vector of nodes
2714 ///
2715 /// # Returns
2716 ///
2717 /// New node with a tuple
2718 ///
2719 /// # Example
2720 ///
2721 /// ```
2722 /// # use ciphercore_base::graphs::create_context;
2723 /// # use ciphercore_base::data_types::{INT32, array_type};
2724 /// let c = create_context().unwrap();
2725 /// let g = c.create_graph().unwrap();
2726 /// let t1 = array_type(vec![3, 2, 3], INT32);
2727 /// let t2 = array_type(vec![2, 3], INT32);
2728 /// let n1 = g.input(t1).unwrap();
2729 /// let n2 = g.input(t2).unwrap();
2730 /// let n3 = g.create_tuple(vec![n1,n2]).unwrap();
2731 /// ```
2732 pub fn create_tuple(&self, elements: Vec<Node>) -> Result<Node> {
2733 self.add_node(elements, vec![], Operation::CreateTuple)
2734 }
2735
2736 /// Adds a node that creates a vector from several (possibly, zero) elements of the same type.
2737 ///
2738 /// # Arguments
2739 ///
2740 /// `elements` - vector of nodes
2741 ///
2742 /// # Returns
2743 ///
2744 /// New node with a created vector
2745 ///
2746 /// # Example
2747 ///
2748 /// ```
2749 /// # use ciphercore_base::graphs::create_context;
2750 /// # use ciphercore_base::data_types::{INT32, array_type};
2751 /// let c = create_context().unwrap();
2752 /// let g = c.create_graph().unwrap();
2753 /// let t = array_type(vec![3, 2, 3], INT32);
2754 /// let n1 = g.input(t.clone()).unwrap();
2755 /// let n2 = g.input(t.clone()).unwrap();
2756 /// let n3 = g.create_vector(t, vec![n1,n2]).unwrap();
2757 /// ```
2758 pub fn create_vector(&self, element_type: Type, elements: Vec<Node>) -> Result<Node> {
2759 self.add_node(elements, vec![], Operation::CreateVector(element_type))
2760 }
2761
2762 /// Adds a node that creates a named tuple from several (possibly, zero) elements.
2763 ///
2764 /// # Arguments
2765 ///
2766 /// `elements` - vector of pairs (node name, node)
2767 ///
2768 /// # Returns
2769 ///
2770 /// New node creating a named tuple
2771 ///
2772 /// # Example
2773 ///
2774 /// ```
2775 /// # use ciphercore_base::graphs::create_context;
2776 /// # use ciphercore_base::data_types::{INT32, array_type};
2777 /// let c = create_context().unwrap();
2778 /// let g = c.create_graph().unwrap();
2779 /// let t1 = array_type(vec![3, 2, 3], INT32);
2780 /// let t2 = array_type(vec![2, 3], INT32);
2781 /// let n1 = g.input(t1).unwrap();
2782 /// let n2 = g.input(t2).unwrap();
2783 /// let n3 = g.create_named_tuple(vec![("node1".to_owned(), n1), ("node2".to_owned(), n2)]).unwrap();
2784 /// ```
2785 pub fn create_named_tuple(&self, elements: Vec<(String, Node)>) -> Result<Node> {
2786 let mut nodes = vec![];
2787 let mut names = vec![];
2788 for (name, node) in elements {
2789 nodes.push(node);
2790 names.push(name);
2791 }
2792 self.add_node(nodes, vec![], Operation::CreateNamedTuple(names))
2793 }
2794
2795 /// Adds a node that extracts an element of a tuple.
2796 ///
2797 /// # Arguments
2798 ///
2799 /// * `tuple` - node containing a tuple
2800 /// * `index` - index of a tuple element between 0 and tuple length minus 1
2801 ///
2802 /// # Returns
2803 ///
2804 /// New node with an extracted element
2805 ///
2806 /// # Example
2807 ///
2808 /// ```
2809 /// # use ciphercore_base::data_types::{INT32, array_type};
2810 /// # use ciphercore_base::graphs::create_context;
2811 /// let c = create_context().unwrap();
2812 /// let g = c.create_graph().unwrap();
2813 /// let t1 = array_type(vec![3, 2, 3], INT32);
2814 /// let t2 = array_type(vec![2, 3], INT32);
2815 /// let n1 = g.input(t1).unwrap();
2816 /// let n2 = g.input(t2).unwrap();
2817 /// let n3 = g.create_tuple(vec![n1, n2]).unwrap();
2818 /// let n4 = g.tuple_get(n3, 1).unwrap();
2819 /// ```
2820 pub fn tuple_get(&self, tuple: Node, index: u64) -> Result<Node> {
2821 self.add_node(vec![tuple], vec![], Operation::TupleGet(index))
2822 }
2823
2824 /// Adds a node that extracts an element of a named tuple.
2825 ///
2826 /// # Arguments
2827 ///
2828 /// * `tuple` - node containing a named tuple
2829 /// * `key` - key of a tuple element
2830 ///
2831 /// # Returns
2832 ///
2833 /// New node extracting a tuple element
2834 ///
2835 /// # Example
2836 ///
2837 /// ```
2838 /// # use ciphercore_base::graphs::create_context;
2839 /// # use ciphercore_base::data_types::{array_type, INT32};
2840 /// let c = create_context().unwrap();
2841 /// let g = c.create_graph().unwrap();
2842 /// let t1 = array_type(vec![3, 2, 3], INT32);
2843 /// let t2 = array_type(vec![2, 3], INT32);
2844 /// let n1 = g.input(t1).unwrap();
2845 /// let n2 = g.input(t2).unwrap();
2846 /// let n3 = g.create_named_tuple(vec![("node1".to_owned(), n1), ("node2".to_owned(), n2)]).unwrap();
2847 /// let n4 = g.named_tuple_get(n3, "node2".to_owned()).unwrap();
2848 /// ```
2849 pub fn named_tuple_get(&self, tuple: Node, key: String) -> Result<Node> {
2850 self.add_node(vec![tuple], vec![], Operation::NamedTupleGet(key))
2851 }
2852
2853 /// Adds a node that extracts an element of a vector.
2854 ///
2855 /// # Arguments
2856 ///
2857 /// * `vec` - node containing a vector
2858 /// * `index` - node containing the index of a tuple element
2859 ///
2860 /// # Returns
2861 ///
2862 /// New node extracting a vector element
2863 ///
2864 /// # Example
2865 ///
2866 /// ```
2867 /// # use ciphercore_base::graphs::create_context;
2868 /// # use ciphercore_base::data_types::{UINT32, INT32, array_type, scalar_type};
2869 /// # use ciphercore_base::data_values::Value;
2870 /// let c = create_context().unwrap();
2871 /// let g = c.create_graph().unwrap();
2872 /// let t = array_type(vec![3, 2, 3], INT32);
2873 /// let n1 = g.input(t.clone()).unwrap();
2874 /// let n2 = g.input(t.clone()).unwrap();
2875 /// let n3 = g.create_vector(t, vec![n1,n2]).unwrap();
2876 /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
2877 /// let n4 = g.vector_get(n3, index).unwrap();
2878 /// ```
2879 pub fn vector_get(&self, vec: Node, index: Node) -> Result<Node> {
2880 self.add_node(vec![vec, index], vec![], Operation::VectorGet)
2881 }
2882
2883 /// Adds a node that takes vectors V<sub>1</sub>(n, t<sub>1</sub>), V<sub>2</sub>(n, t<sub>2</sub>), ..., V<sub>k</sub>(n, t<sub>k</sub>) of the same length and returns a vector V(n, tuple(t<sub>1</sub>, ..., t<sub>k</sub>)) (similar to [zip](https://doc.rust-lang.org/stable/std/iter/fn.zip.html)).
2884 ///
2885 /// # Arguments
2886 ///
2887 /// `nodes` - vector of nodes containing input vectors
2888 ///
2889 /// # Returns
2890 ///
2891 /// New zip node
2892 ///
2893 /// # Example
2894 ///
2895 /// ```
2896 /// # use ciphercore_base::data_types::{INT32, array_type, vector_type};
2897 /// # use ciphercore_base::graphs::create_context;
2898 /// let c = create_context().unwrap();
2899 /// let g = c.create_graph().unwrap();
2900 /// let t = array_type(vec![3, 2, 3], INT32);
2901 /// let vec_t = vector_type(3, t);
2902 /// let n1 = g.input(vec_t.clone()).unwrap();
2903 /// let n2 = g.input(vec_t.clone()).unwrap();
2904 /// let n3 = g.zip(vec![n1,n2]).unwrap();
2905 /// ```
2906 pub fn zip(&self, nodes: Vec<Node>) -> Result<Node> {
2907 self.add_node(nodes, vec![], Operation::Zip)
2908 }
2909
2910 /// Adds a node that creates a vector with `n` copies of a value of a given node.
2911 ///
2912 /// # Arguments
2913 ///
2914 /// * `a` - node containing a value
2915 /// * `n` - number of copies
2916 ///
2917 /// # Returns
2918 ///
2919 /// New repeat node
2920 ///
2921 /// # Example
2922 ///
2923 /// ```
2924 /// # use ciphercore_base::data_types::{INT32, array_type};
2925 /// # use ciphercore_base::graphs::create_context;
2926 /// let c = create_context().unwrap();
2927 /// let g = c.create_graph().unwrap();
2928 /// let t = array_type(vec![3, 2, 3], INT32);
2929 /// let n1 = g.input(t).unwrap();
2930 /// let n2 = g.repeat(n1, 10).unwrap();
2931 /// ```
2932 pub fn repeat(&self, a: Node, n: u64) -> Result<Node> {
2933 self.add_node(vec![a], vec![], Operation::Repeat(n))
2934 }
2935
2936 /// Adds a node that calls another graph with inputs contained in given nodes.
2937 ///
2938 /// The input graph must be finalized and have as many inputs as the number of provided arguments.
2939 ///
2940 /// For example, let `G` be a graph implementing the function `max(x,0)`, then `call(G, [17]) = max(17, 0)`.
2941 ///
2942 /// # Arguments
2943 ///
2944 /// * `graph` - graph with `n` input nodes
2945 /// * `arguments` - vector of `n` nodes
2946 ///
2947 /// # Returns
2948 ///
2949 /// New call node
2950 ///
2951 /// # Example
2952 ///
2953 /// ```
2954 /// # use ciphercore_base::data_types::{INT32, array_type};
2955 /// # use ciphercore_base::graphs::create_context;
2956 /// let c = create_context().unwrap();
2957 ///
2958 /// let g1 = c.create_graph().unwrap();
2959 /// let t = array_type(vec![3, 2, 3], INT32);
2960 /// let n1 = g1.input(t.clone()).unwrap();
2961 /// let n2 = g1.repeat(n1, 10).unwrap();
2962 /// let n3 = g1.vector_to_array(n2).unwrap();
2963 /// n3.set_as_output().unwrap();
2964 /// g1.finalize().unwrap();
2965 ///
2966 /// let g2 = c.create_graph().unwrap();
2967 /// let n4 = g2.input(t).unwrap();
2968 /// let n5 = g2.add(n4.clone(), n4).unwrap();
2969 /// let n6 = g2.call(g1, vec![n5]).unwrap();
2970 /// ```
2971 pub fn call(&self, graph: Graph, arguments: Vec<Node>) -> Result<Node> {
2972 self.add_node(arguments, vec![graph], Operation::Call)
2973 }
2974
2975 /// Adds a node that iteratively computes a given finalized graph on the elements of a given vector and updates the state value accordingly.
2976 ///
2977 /// This node calls another `graph` with 2 input nodes `old_state` and `input` and an output node that returns a [tuple](Type::Tuple) `(new_state, output)`. This graph is used to map the elements of a given vector `V` to another vector `W` as follows:
2978 /// ```text
2979 /// graph(state_0, V[0]) -> (state1, W[0]),
2980 /// graph(state_1, V[1]) -> (state2, W[1]),
2981 /// ...
2982 /// graph(state_k, V[k]) -> (final_state, W[k]).
2983 /// ```
2984 /// The output is a [tuple](Type::Tuple) `(final_state, W)`. The initial state `state_0` should be provided as an argument.
2985 ///
2986 /// This node generalize `map` and `reduce` procedures (see [MapReduce](https://en.wikipedia.org/wiki/MapReduce) for more details).
2987 ///
2988 /// For example, let `G` be a graph implementing the function `max(x,0)` and incrementing `state` if its output is negative, then `iterate(G, 0, [-1,2,0,3,2]) = (1, [0,2,0,3,2])`. The final state is equal to the number of negative values in the input vector.
2989 ///
2990 /// # Arguments
2991 ///
2992 /// * `graph` - graph with 2 input nodes of types T<sub>s</sub> and T<sub>i</sub> and returning a tuple of type (T<sub>s</sub>, T<sub>o</sub>)
2993 /// * `state` - node containing an initial state of type T<sub>s</sub>
2994 /// * `input` - node containing a vector with elements of type T<sub>i</sub>
2995 ///
2996 /// # Returns
2997 ///
2998 /// New iterate node
2999 ///
3000 /// # Example
3001 ///
3002 /// ```
3003 /// # use ciphercore_base::data_types::{INT32, BIT, scalar_type, vector_type};
3004 /// # use ciphercore_base::graphs::create_context;
3005 /// # use ciphercore_base::ops::utils::constant_scalar;
3006 /// let c = create_context().unwrap();
3007 ///
3008 /// let t_s = scalar_type(BIT);
3009 /// let t = scalar_type(INT32);
3010 /// let vec_t = vector_type(10, t.clone());
3011 ///
3012 /// // Graph that outputs 0 at even indices or input value at odd indices.
3013 /// let g1 = c.create_graph().unwrap();
3014 /// {
3015 /// let old_state = g1.input(t_s.clone()).unwrap();
3016 /// let input = g1.input(t.clone()).unwrap();
3017 /// let result = g1.mixed_multiply(input, old_state.clone()).unwrap();
3018 /// let new_state = g1.add(old_state, constant_scalar(&g1, 1, BIT).unwrap()).unwrap();
3019 /// let out_tuple = g1.create_tuple(vec![new_state, result]).unwrap();
3020 /// out_tuple.set_as_output().unwrap();
3021 /// g1.finalize().unwrap();
3022 /// }
3023 ///
3024 /// let g2 = c.create_graph().unwrap();
3025 /// let initial_state = constant_scalar(&g2, 0, BIT).unwrap();
3026 /// let input_vector = g2.input(vec_t).unwrap();
3027 /// g2.iterate(g1, initial_state, input_vector).unwrap();
3028 /// ```
3029 pub fn iterate(&self, graph: Graph, state: Node, input: Node) -> Result<Node> {
3030 self.add_node(vec![state, input], vec![graph], Operation::Iterate)
3031 }
3032
3033 /// Adds a node converting an array to a vector.
3034 ///
3035 /// Given an array of shape `[a,b,c]`, this node returns a vector of `a` arrays of shape `[b,c]`.
3036 ///
3037 /// # Arguments
3038 ///
3039 /// `a` - node containing an array
3040 ///
3041 /// # Returns
3042 ///
3043 /// New node converting an array to a vector
3044 ///
3045 /// # Example
3046 ///
3047 /// ```
3048 /// # use ciphercore_base::graphs::create_context;
3049 /// # use ciphercore_base::data_types::{array_type, scalar_type, INT32, UINT32};
3050 /// # use ciphercore_base::data_values::Value;
3051 /// let c = create_context().unwrap();
3052 /// let g = c.create_graph().unwrap();
3053 /// let t = array_type(vec![4, 3, 2], INT32);
3054 /// let n1 = g.input(t).unwrap();
3055 /// let n2 = g.array_to_vector(n1).unwrap();
3056 /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
3057 /// let n3 = g.vector_get(n2.clone(), index).unwrap();
3058 ///
3059 /// assert!(n2.get_type().unwrap().is_vector());
3060 /// assert_eq!(n3.get_type().unwrap().get_shape(), vec![3,2]);
3061 /// ```
3062 pub fn array_to_vector(&self, a: Node) -> Result<Node> {
3063 self.add_node(vec![a], vec![], Operation::ArrayToVector)
3064 }
3065
3066 /// Adds a node converting a vector to an array.
3067 ///
3068 /// Given a vector of `a` arrays of shape `[b,c]`, this node returns an array of shape `[a,b,c]`.
3069 ///
3070 /// # Arguments
3071 ///
3072 /// `a` - node containing a vector
3073 ///
3074 /// # Returns
3075 ///
3076 /// New node converting a vector to an array
3077 ///
3078 /// # Example
3079 ///
3080 /// ```
3081 /// # use ciphercore_base::graphs::create_context;
3082 /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3083 /// let c = create_context().unwrap();
3084 /// let g = c.create_graph().unwrap();
3085 /// let t = array_type(vec![3, 2], INT32);
3086 /// let vec_t = vector_type(4, t);
3087 /// let n1 = g.input(vec_t).unwrap();
3088 /// let n2 = g.vector_to_array(n1).unwrap();
3089 ///
3090 /// assert!(n2.get_type().unwrap().is_array());
3091 /// assert_eq!(n2.get_type().unwrap().get_shape(), vec![4, 3, 2]);
3092 /// ```
3093 pub fn vector_to_array(&self, a: Node) -> Result<Node> {
3094 self.add_node(vec![a], vec![], Operation::VectorToArray)
3095 }
3096
3097 /// Adds a node creating an array from the elements of an input array indexed by another array along a given axis.
3098 ///
3099 /// Given an input array, this node replaces the dimension `axis` with the dimensions introduced by the indexing array.
3100 ///
3101 /// Indices must be unique to prevent possible duplication of shares/ciphertexts.
3102 /// Such duplicates might cause devastating data leakage.
3103 ///
3104 /// This operation is similar to [the NumPy take operation](https://numpy.org/doc/stable/reference/generated/numpy.take.html).
3105 ///
3106 /// **WARNING**: this function should not be used before MPC compilation.
3107 ///
3108 /// # Arguments
3109 ///
3110 /// `input` - node containing an input array
3111 /// `indices` - node containing indices
3112 /// `axis` - index of the axis along which indices are chosen
3113 ///
3114 /// # Returns
3115 ///
3116 /// New Gather node
3117 #[doc(hidden)]
3118 pub fn gather(&self, input: Node, indices: Node, axis: u64) -> Result<Node> {
3119 self.add_node(vec![input, indices], vec![], Operation::Gather(axis))
3120 }
3121
3122 /// Checks that the graph has an output node and finalizes the graph.
3123 ///
3124 /// After finalization the graph can't be changed.
3125 ///
3126 /// # Returns
3127 ///
3128 /// Finalized graph
3129 ///
3130 /// # Example
3131 ///
3132 /// ```
3133 /// # use ciphercore_base::graphs::create_context;
3134 /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3135 /// let c = create_context().unwrap();
3136 /// let g = c.create_graph().unwrap();
3137 /// let t = array_type(vec![3, 2], INT32);
3138 /// let vec_t = vector_type(4, t);
3139 /// let n1 = g.input(vec_t).unwrap();
3140 /// let n2 = g.vector_to_array(n1).unwrap();
3141 /// n2.set_as_output().unwrap();
3142 /// g.finalize().unwrap();
3143 /// ```
3144 pub fn finalize(&self) -> Result<Graph> {
3145 let output_node = self.body.borrow_mut().output_node.clone();
3146 match output_node {
3147 Some(_) => {
3148 self.body.borrow_mut().finalized = true;
3149 Ok(self.clone())
3150 }
3151 None => Err(runtime_error!("Output node is not set")),
3152 }
3153 }
3154
3155 /// Returns the vector of nodes contained in the graph in order of construction.
3156 ///
3157 /// # Returns
3158 ///
3159 /// Vector of nodes of the graph
3160 pub fn get_nodes(&self) -> Vec<Node> {
3161 self.body.borrow().nodes.clone()
3162 }
3163
3164 /// Promotes a given node to the output node of the parent graph.
3165 ///
3166 /// # Arguments
3167 ///
3168 /// `output_node` - node to be set as output
3169 ///
3170 /// # Example
3171 ///
3172 /// ```
3173 /// # use ciphercore_base::graphs::create_context;
3174 /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3175 /// let c = create_context().unwrap();
3176 /// let g = c.create_graph().unwrap();
3177 /// let t = array_type(vec![3, 2], INT32);
3178 /// let vec_t = vector_type(4, t);
3179 /// let n1 = g.input(vec_t).unwrap();
3180 /// let n2 = g.vector_to_array(n1).unwrap();
3181 /// g.set_output_node(n2).unwrap();
3182 /// g.finalize().unwrap();
3183 /// ```
3184 pub fn set_output_node(&self, output_node: Node) -> Result<()> {
3185 let current_output_node = self.body.borrow().output_node.clone();
3186 match current_output_node {
3187 Some(_) => Err(runtime_error!("Output node is already set")),
3188 None => {
3189 if output_node.get_graph() != *self {
3190 Err(runtime_error!("Output node has to be from the same graph"))
3191 } else {
3192 self.body.borrow_mut().output_node = Some(output_node.downgrade());
3193 Ok(())
3194 }
3195 }
3196 }
3197 }
3198
3199 /// Returns the output node of the graph.
3200 ///
3201 /// # Returns
3202 ///
3203 /// Output node of the graph
3204 pub fn get_output_node(&self) -> Result<Node> {
3205 let current_output_node = self.body.borrow().output_node.clone();
3206 match current_output_node {
3207 Some(output_node) => Ok(output_node.upgrade()),
3208 None => Err(runtime_error!("Output node is not set")),
3209 }
3210 }
3211
3212 /// Returns the ID of the graph.
3213 ///
3214 /// A graph ID is a serial number of a graph between `0` and `n-1` where `n` is the number of graphs in the parent context.
3215 ///
3216 /// # Returns
3217 ///
3218 /// Graph ID
3219 pub fn get_id(&self) -> u64 {
3220 self.body.borrow().id
3221 }
3222
3223 /// Returns the number of the graph nodes.
3224 ///
3225 /// # Returns
3226 ///
3227 /// Number of the graph nodes
3228 pub fn get_num_nodes(&self) -> u64 {
3229 self.body.borrow().nodes.len() as u64
3230 }
3231
3232 /// Returns the node corresponding to a given ID.
3233 ///
3234 /// # Arguments
3235 ///
3236 /// `id` - node ID
3237 ///
3238 /// # Returns
3239 ///
3240 /// Node with a given ID
3241 pub fn get_node_by_id(&self, id: u64) -> Result<Node> {
3242 let nodes = &self.body.borrow().nodes;
3243 if id >= nodes.len() as u64 {
3244 Err(runtime_error!("Invalid id for the node retrieval"))
3245 } else {
3246 Ok(nodes[id as usize].clone())
3247 }
3248 }
3249
3250 /// Returns the context of the graph nodes.
3251 ///
3252 /// # Returns
3253 ///
3254 /// Context of the graph
3255 pub fn get_context(&self) -> Context {
3256 self.body.borrow().context.upgrade()
3257 }
3258
3259 /// Adds a node computing a given custom operation.
3260 ///
3261 /// Custom operations can be created by the user as public structs implementing the [CustomOperationBody](../custom_ops/trait.CustomOperationBody.html).
3262 ///
3263 /// # Arguments
3264 ///
3265 /// * `op` - custom operation
3266 /// * `arguments` - vector of nodes used as input for the custom operation
3267 ///
3268 /// # Returns
3269 ///
3270 /// New custom operation node
3271 ///
3272 /// # Example
3273 ///
3274 /// ```
3275 /// # use ciphercore_base::graphs::create_context;
3276 /// # use ciphercore_base::data_types::{array_type, BIT};
3277 /// # use ciphercore_base::custom_ops::{CustomOperation, Not};
3278 /// let c = create_context().unwrap();
3279 /// let g = c.create_graph().unwrap();
3280 /// let t = array_type(vec![3, 2], BIT);
3281 /// let n1 = g.input(t).unwrap();
3282 /// let n2 = g.custom_op(CustomOperation::new(Not {}), vec![n1]).unwrap();
3283 /// ```
3284 pub fn custom_op(&self, op: CustomOperation, arguments: Vec<Node>) -> Result<Node> {
3285 self.add_node(arguments, vec![], Operation::Custom(op))
3286 }
3287
3288 /// Adds a node which logs its input at runtime, and returns the input.
3289 /// This is intended to be used for debugging.
3290 ///
3291 /// # Arguments
3292 ///
3293 /// * `message` - Informational message to be printed
3294 /// * `input` - Node to be printed
3295 ///
3296 /// # Returns
3297 ///
3298 /// The value of the node.
3299 ///
3300 /// # Example
3301 ///
3302 /// ```
3303 /// # use ciphercore_base::graphs::create_context;
3304 /// # use ciphercore_base::data_types::{array_type, BIT};
3305 /// # use ciphercore_base::custom_ops::{CustomOperation, Not};
3306 /// let c = create_context().unwrap();
3307 /// let g = c.create_graph().unwrap();
3308 /// let t = array_type(vec![3, 2], BIT);
3309 /// let n1 = g.input(t).unwrap();
3310 /// let n2 = g.print("n1:".into(), n1).unwrap();
3311 /// ```
3312 pub fn print(&self, message: String, input: Node) -> Result<Node> {
3313 self.add_node(vec![input], vec![], Operation::Print(message))
3314 }
3315
3316 /// Adds a node which fails the execution at runtime if `condition` is false, and returns the `input` otherwise.
3317 /// This is intended to be used for debugging.
3318 ///
3319 /// # Arguments
3320 ///
3321 /// * `message` - message to be returned for the failed assertion.
3322 /// * `condition` - BIT to be checked in the assertion.
3323 /// * `input` - Node to be returned for pass-through.
3324 ///
3325 /// # Returns
3326 ///
3327 /// The value of the node.
3328 ///
3329 /// # Example
3330 ///
3331 /// ```
3332 /// # use ciphercore_base::graphs::create_context;
3333 /// # use ciphercore_base::data_types::{array_type, scalar_type, BIT};
3334 /// # use ciphercore_base::custom_ops::{CustomOperation, Not};
3335 /// let c = create_context().unwrap();
3336 /// let g = c.create_graph().unwrap();
3337 /// let cond = g.input(scalar_type(BIT)).unwrap();
3338 /// let t = array_type(vec![3, 2], BIT);
3339 /// let n1 = g.input(t).unwrap();
3340 /// let n2 = g.assert("Condition".into(), cond, n1).unwrap();
3341 /// ```
3342 pub fn assert(&self, message: String, condition: Node, input: Node) -> Result<Node> {
3343 self.add_node(vec![condition, input], vec![], Operation::Assert(message))
3344 }
3345}
3346
3347/// Methods which aren't supposed to be imported in Python.
3348impl Graph {
3349 /// Adds an operation node to the graph and returns it.
3350 ///
3351 /// # Arguments
3352 ///
3353 /// * `node_dependencies` - vector of nodes necessary to perform the given operation
3354 /// * `graph_dependencies` - vector of graphs necessary to perform the given operation
3355 /// * `operation` - operation performed by the node
3356 ///
3357 /// # Returns
3358 ///
3359 /// New operation node that gets added
3360 pub fn add_node(
3361 &self,
3362 node_dependencies: Vec<Node>,
3363 graph_dependencies: Vec<Graph>,
3364 operation: Operation,
3365 ) -> Result<Node> {
3366 if self.is_finalized() {
3367 return Err(runtime_error!("Can't add a node to a finalized graph"));
3368 }
3369 for dependency in &node_dependencies {
3370 if dependency.get_graph() != *self
3371 || dependency.get_id() >= self.body.borrow().nodes.len() as u64
3372 || self.body.borrow().nodes[dependency.get_id() as usize] != *dependency
3373 {
3374 return Err(runtime_error!(
3375 "Can't add a node with invalid node dependencies"
3376 ));
3377 }
3378 }
3379 for dependency in &graph_dependencies {
3380 if !dependency.is_finalized() {
3381 return Err(runtime_error!(
3382 "Can't add a node with not finilized graph dependency"
3383 ));
3384 }
3385 if dependency.get_id() >= self.get_id() {
3386 return Err(runtime_error!(
3387 "Can't add a node with graph dependency with bigger id. {} >= {}",
3388 dependency.get_id(),
3389 self.get_id()
3390 ));
3391 }
3392 if dependency.get_context() != self.get_context() {
3393 return Err(runtime_error!(
3394 "Can't add a node with graph dependency from different context"
3395 ));
3396 }
3397 }
3398 let id = self.body.borrow().nodes.len() as u64;
3399 let result = Node {
3400 body: Arc::new(AtomicRefCell::new(NodeBody {
3401 graph: self.downgrade(),
3402 node_dependencies: node_dependencies.iter().map(|n| n.downgrade()).collect(),
3403 graph_dependencies: graph_dependencies.iter().map(|g| g.downgrade()).collect(),
3404 operation,
3405 id,
3406 })),
3407 };
3408 {
3409 let mut cell = self.body.borrow_mut();
3410 cell.nodes.push(result.clone());
3411 }
3412 let mut context_has_type_checker = false;
3413 {
3414 let context = self.get_context();
3415 let mut context_cell = context.body.borrow_mut();
3416 let type_checker = &mut context_cell.type_checker;
3417 if type_checker.is_some() {
3418 context_has_type_checker = true;
3419 }
3420 }
3421 if context_has_type_checker {
3422 let type_checking_result = result.get_type();
3423 if type_checking_result.is_err() {
3424 self.remove_last_node(result)?;
3425 return Err(type_checking_result.expect_err("Should not be here"));
3426 }
3427 let type_result = type_checking_result?;
3428
3429 let size_estimate = get_size_estimation_in_bits(type_result);
3430 if size_estimate.is_err() {
3431 self.remove_last_node(result)?;
3432 return Err(runtime_error!("Trying to add a node with invalid size"));
3433 }
3434 if size_estimate? > type_size_limit_constants::MAX_INDIVIDUAL_NODE_SIZE {
3435 self.remove_last_node(result)?;
3436 return Err(runtime_error!(
3437 "Trying to add a node larger than MAX_INDIVIDUAL_NODE_SIZE"
3438 ));
3439 }
3440
3441 let context = self.get_context();
3442 let size_checking_result = context.try_update_total_size(result.clone());
3443 if size_checking_result.is_err() {
3444 self.remove_last_node(result)?;
3445 return Err(size_checking_result.expect_err("Should not be here"));
3446 }
3447 }
3448 Ok(result)
3449 }
3450
3451 fn remove_last_node(&self, n: Node) -> Result<()> {
3452 if n.get_graph() != *self {
3453 return Err(runtime_error!(
3454 "The node to be removed from a different graph"
3455 ));
3456 }
3457 {
3458 let cell = self.body.borrow();
3459 if n != *cell
3460 .nodes
3461 .last()
3462 .ok_or_else(|| runtime_error!("Nodes list is empty"))?
3463 {
3464 return Err(runtime_error!(
3465 "The node to be removed is not the last node"
3466 ));
3467 }
3468 };
3469 let context = self.get_context();
3470 context.unregister_node(n.clone())?;
3471 let mut context_body = context.body.borrow_mut();
3472 if let Some(tc) = &mut context_body.type_checker {
3473 tc.unregister_node(n)?;
3474 }
3475 let mut cell = self.body.borrow_mut();
3476 cell.nodes.pop();
3477 Ok(())
3478 }
3479
3480 pub(crate) fn nop(&self, a: Node) -> Result<Node> {
3481 self.add_node(vec![a], vec![], Operation::NOP)
3482 }
3483
3484 pub(crate) fn prf(&self, key: Node, iv: u64, output_type: Type) -> Result<Node> {
3485 self.add_node(vec![key], vec![], Operation::PRF(iv, output_type))
3486 }
3487
3488 pub(crate) fn permutation_from_prf(&self, key: Node, iv: u64, n: u64) -> Result<Node> {
3489 self.add_node(vec![key], vec![], Operation::PermutationFromPRF(iv, n))
3490 }
3491
3492 pub(super) fn is_finalized(&self) -> bool {
3493 self.body.borrow().finalized
3494 }
3495
3496 pub(super) fn check_finalized(&self) -> Result<()> {
3497 if !self.is_finalized() {
3498 return Err(runtime_error!("Graph is not finalized"));
3499 }
3500 Ok(())
3501 }
3502
3503 fn make_serializable(&self) -> SerializableGraph {
3504 let output_node = match self.get_output_node() {
3505 Ok(n) => Some(n.get_id()),
3506 Err(_) => None,
3507 };
3508 Arc::new(SerializableGraphBody {
3509 finalized: self.is_finalized(),
3510 nodes: self
3511 .get_nodes()
3512 .iter()
3513 .map(|n| n.make_serializable())
3514 .collect(),
3515 output_node,
3516 })
3517 }
3518
3519 fn downgrade(&self) -> WeakGraph {
3520 WeakGraph {
3521 body: Arc::downgrade(&self.body),
3522 }
3523 }
3524
3525 #[doc(hidden)]
3526 pub fn add_annotation(&self, annotation: GraphAnnotation) -> Result<Graph> {
3527 self.get_context().add_graph_annotation(self, annotation)?;
3528 Ok(self.clone())
3529 }
3530
3531 pub fn get_annotations(&self) -> Result<Vec<GraphAnnotation>> {
3532 self.get_context().get_graph_annotations(self.clone())
3533 }
3534
3535 /// Rearrange given input values according to the names and the order of the related input nodes.
3536 ///
3537 /// For example, given a graph with the first input node named 'A' and the second one named 'B' and input values `{'B': v, 'A': w}`, this function returns a vector `[w, v]`.
3538 ///
3539 /// # Arguments
3540 ///
3541 /// `values` - hashmap of values keyed by node names
3542 ///
3543 /// # Returns
3544 ///
3545 /// Vector of values arranged by node names
3546 ///
3547 /// # Example
3548 ///
3549 /// ```
3550 /// # use ciphercore_base::graphs::create_context;
3551 /// # use ciphercore_base::data_types::{BIT, scalar_type};
3552 /// # use std::collections::HashMap;
3553 /// let c = create_context().unwrap();
3554 /// let g = c.create_graph().unwrap();
3555 /// let t = scalar_type(BIT);
3556 /// let n1 = g.input(t.clone()).unwrap();
3557 /// n1.set_name("input1").unwrap();
3558 /// let n2 = g.input(t.clone()).unwrap();
3559 /// n2.set_name("input2").unwrap();
3560 ///
3561 /// let mut input_map = HashMap::new();
3562 /// input_map.insert("input2", 2);
3563 /// input_map.insert("input1", 1);
3564 /// let ordered_input = g.prepare_input_values(input_map).unwrap();
3565 ///
3566 /// assert_eq!(vec![1,2], ordered_input);
3567 /// ```
3568 pub fn prepare_input_values<T: Clone>(&self, values: HashMap<&str, T>) -> Result<Vec<T>> {
3569 self.get_context()
3570 .prepare_input_values(self.clone(), values)
3571 }
3572}
3573type WeakGraphBodyPointer = Weak<AtomicRefCell<GraphBody>>;
3574
3575struct WeakGraph {
3576 body: WeakGraphBodyPointer,
3577}
3578
3579impl WeakGraph {
3580 //upgrade function panics if the the Graph pointer it downgraded from went out of scope
3581 fn upgrade(&self) -> Graph {
3582 Graph {
3583 body: self.body.upgrade().unwrap(),
3584 }
3585 }
3586}
3587impl Clone for WeakGraph {
3588 fn clone(&self) -> Self {
3589 WeakGraph {
3590 body: self.body.clone(),
3591 }
3592 }
3593}
3594
#[doc(hidden)]
/// Various node-related properties which aren't used in the graph building
/// or type inference, but can be used in node expansion or MPC compilation.
///
/// Annotations are attached through `Context::add_node_annotation` and are preserved by
/// context serialization (`SerializableContextBody::nodes_annotations`).
///
/// NOTE(review): keep the variant order stable; serde formats that encode variants by
/// index would otherwise change their wire representation.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NodeAnnotation {
    /// NOTE(review): presumably marks a node whose operation is associative — confirm
    /// against the code that consumes this annotation.
    AssociativeOperation,
    /// NOTE(review): presumably marks a node whose value must stay private — confirm.
    Private,
    Send(u64, u64), // (sender_index, receiver_index); indices belong to the set 0..PARTIES
    /// NOTE(review): presumably tags a PRF-based multiplication step in MPC compilation —
    /// confirm.
    PRFMultiplication,
    /// NOTE(review): presumably tags a PRF-based binary-to-arithmetic conversion — confirm.
    PRFB2A,
    /// NOTE(review): presumably tags a PRF-based truncation step — confirm.
    PRFTruncate,
}
3607
#[doc(hidden)]
/// Various graph-related properties which aren't used in the graph building
/// or type inference, but can be used in node expansion or MPC compilation.
///
/// Annotations are attached through `Graph::add_annotation` (which forwards to
/// `Context::add_graph_annotation`) and are preserved by context serialization
/// (`SerializableContextBody::graphs_annotations`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum GraphAnnotation {
    /// NOTE(review): presumably marks a graph that computes an associative operation —
    /// confirm against the code that consumes this annotation.
    AssociativeOperation,
    /// NOTE(review): presumably marks a graph whose state fits in a single bit — confirm.
    OneBitState,
    /// NOTE(review): presumably marks a graph with a small state — confirm.
    SmallState,
}
3617
/// Mutable state of a [Context]; shared through `ContextBodyPointer`.
struct ContextBody {
    /// Set by `Context::finalize`; once `true`, new graphs and graph/node names are rejected.
    finalized: bool,
    /// All graphs of the context in creation order; a graph's ID is its index here.
    graphs: Vec<Graph>,
    /// Main graph, stored weakly; the strong reference to it is held in `graphs`.
    main_graph: Option<WeakGraph>,
    /// graph_id -> name
    graphs_names: HashMap<u64, String>,
    /// name -> graph_id
    graphs_names_inverse: HashMap<String, u64>,
    /// (graph_id, node_id) -> name
    nodes_names: HashMap<(u64, u64), String>,
    /// graph_id -> (name -> node_id)
    nodes_names_inverse: HashMap<u64, HashMap<String, u64>>,
    /// (graph_id, node_id) -> NodeAnnotation's
    nodes_annotations: HashMap<(u64, u64), Vec<NodeAnnotation>>,
    /// (graph_id) -> GraphAnnotation's
    graphs_annotations: HashMap<u64, Vec<GraphAnnotation>>,
    /// Accumulated size estimate (in bits) of the context's `Input` and `Constant` nodes,
    /// updated by `Context::try_update_total_size` and capped by `MAX_TOTAL_SIZE_NODES`.
    total_size_nodes: u64,
    /// Optional type inference worker; when present, `Graph::add_node` infers and
    /// size-checks the type of every new node.
    type_checker: Option<TypeInferenceWorker>,
}

/// Shared, interior-mutable pointer to a `ContextBody`.
type ContextBodyPointer = Arc<AtomicRefCell<ContextBody>>;
3639
/// A structure that stores a pointer to a computation context that contains related computation graphs.
///
/// Context is a basic object to create computation graphs, arrange data flow between them and keep necessary information about them that is used for optimization, secure compilation and evaluation.
///
/// Context should have a main graph and be finalized in order to evaluate any of its graphs.
///
/// # Rust traits
///
/// [Clone] trait duplicates the pointer, not the underlying context.
///
/// [PartialEq] trait compares pointers, not the related contexts.
///
/// # Example
///
/// ```
/// # #[macro_use] extern crate maplit;
/// # fn main() {
/// # use ciphercore_base::graphs::{Context, create_context};
/// # use ciphercore_base::data_values::Value;
/// # use ciphercore_base::evaluators::random_evaluate;
/// # use ciphercore_base::data_types::{INT32, scalar_type};
/// # use ciphercore_base::errors::Result;
/// let context = || -> Result<Context> {
///     let context = create_context()?;
///     let graph = context.create_graph()?.set_name("main")?;
///     graph
///         .input(scalar_type(INT32))?
///         .set_name("a")?
///         .add(graph
///             .input(scalar_type(INT32))?
///             .set_name("b")?)?
///         .set_as_output()?;
///     graph.finalize()?.set_as_main()?;
///     context.finalize()?;
///     Ok(context)
/// }().unwrap();
///
/// let result = || -> Result<i32> {
///     let g = context.retrieve_graph("main")?;
///     let result = random_evaluate(
///         g.clone(),
///         g.prepare_input_values(
///             hashmap!{
///                 "a" => Value::from_scalar(123, INT32)?,
///                 "b" => Value::from_scalar(654, INT32)?,
///             },
///         )?,
///     )?;
///     let result = result.to_i32(INT32)?;
///     Ok(result)
/// }().unwrap();
///
/// assert_eq!(result, 777);
/// # }
/// ```
#[cfg_attr(feature = "py-binding", struct_wrapper)]
pub struct Context {
    body: ContextBodyPointer,
}
3699
3700impl fmt::Debug for Context {
3701 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3702 fmt::Display::fmt(self, f)
3703 }
3704}
3705
3706impl fmt::Display for Context {
3707 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3708 match serde_json::to_string(&self) {
3709 Ok(s) => write!(f, "{s}"),
3710 Err(_err) => Err(fmt::Error::default()),
3711 }
3712 }
3713}
3714
3715impl fmt::Display for Graph {
3716 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3717 write!(f, "Graph[num_nodes={}]", self.get_num_nodes())
3718 }
3719}
3720
3721impl fmt::Display for Node {
3722 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3723 write!(
3724 f,
3725 "Node[type={}]",
3726 self.get_type().map_err(|_| fmt::Error::default())?
3727 )
3728 }
3729}
3730
/// Serializable snapshot of a [Context].
///
/// Graphs are stored by value in creation order. Every other reference (main graph, names,
/// annotations) is stored as a numeric graph ID or a `(graph_id, node_id)` pair. The
/// context's hash maps are flattened into key-sorted vectors of pairs by `serialize_hashmap`.
#[derive(Serialize, Deserialize)]
struct SerializableContextBody {
    /// Whether the original context was finalized.
    finalized: bool,
    /// Serializable graphs; a graph's ID is its index in this vector.
    graphs: Vec<SerializableGraph>,
    /// ID of the main graph, if one was set.
    main_graph: Option<u64>,
    /// graph_id -> name
    graphs_names: Vec<(u64, String)>,
    /// (graph_id, node_id) -> name
    nodes_names: Vec<((u64, u64), String)>,
    /// (graph_id, node_id) -> NodeAnnotation's
    nodes_annotations: Vec<((u64, u64), Vec<NodeAnnotation>)>,
    /// (graph_id) -> GraphAnnotation's
    graphs_annotations: Vec<(u64, Vec<GraphAnnotation>)>,
}
3745
3746impl SerializableContextBody {
3747 fn recover_original_graph(
3748 serializable_graph: SerializableGraph,
3749 context: Context,
3750 ) -> Result<Graph> {
3751 let result_graph = context.create_graph()?;
3752 for node in &serializable_graph.nodes {
3753 let mut node_dependencies = vec![];
3754 for id in &node.node_dependencies {
3755 let current_nodes = &result_graph.body.borrow().nodes;
3756 if *id >= current_nodes.len() as u64 {
3757 return Err(runtime_error!("Non-existent node dependency"));
3758 }
3759 node_dependencies.push(current_nodes[*id as usize].clone());
3760 }
3761 let mut graph_dependencies = vec![];
3762 for id in &node.graph_dependencies {
3763 let context = result_graph.get_context();
3764 let current_graphs = &context.body.borrow().graphs;
3765 if *id >= current_graphs.len() as u64 {
3766 return Err(runtime_error!("Non-existent graph dependency"));
3767 }
3768 graph_dependencies.push(current_graphs[*id as usize].clone());
3769 }
3770 result_graph.add_node(
3771 node_dependencies,
3772 graph_dependencies,
3773 node.operation.clone(),
3774 )?;
3775 }
3776 if let Some(id) = serializable_graph.output_node {
3777 let rebuilt_output_node = {
3778 let current_nodes = &result_graph.body.borrow().nodes;
3779 if id >= current_nodes.len() as u64 {
3780 return Err(runtime_error!("Non-existent output node"));
3781 }
3782 current_nodes[id as usize].clone()
3783 };
3784 result_graph.set_output_node(rebuilt_output_node)?;
3785 }
3786 if serializable_graph.finalized {
3787 result_graph.finalize()?;
3788 }
3789 Ok(result_graph)
3790 }
3791
3792 fn recover_original_context(&self) -> Result<Context> {
3793 let result_context = create_context()?;
3794 for graph in &self.graphs {
3795 let _result_graph =
3796 Self::recover_original_graph(graph.clone(), result_context.clone())?;
3797 }
3798 if let Some(id) = self.main_graph {
3799 let rebuilt_main_graph = {
3800 let current_graphs = &result_context.body.borrow().graphs;
3801 if id >= current_graphs.len() as u64 {
3802 return Err(runtime_error!("Non-existent main graph"));
3803 }
3804 current_graphs[id as usize].clone()
3805 };
3806 result_context.set_main_graph(rebuilt_main_graph)?;
3807 }
3808 for (id, _) in &self.graphs_names {
3809 let current_graphs = &result_context.body.borrow().graphs;
3810 if *id >= current_graphs.len() as u64 {
3811 return Err(runtime_error!("graphs_names contain an invalid ID"));
3812 }
3813 }
3814 for ((graph_id, node_id), _) in &self.nodes_names {
3815 let current_graphs = &result_context.body.borrow().graphs;
3816 if *graph_id >= current_graphs.len() as u64 {
3817 return Err(runtime_error!("nodes_names contain an invalid graph ID"));
3818 }
3819 let current_nodes = ¤t_graphs[*graph_id as usize].body.borrow().nodes;
3820 if *node_id >= current_nodes.len() as u64 {
3821 return Err(runtime_error!("nodes_names contain an invalid node ID"));
3822 }
3823 }
3824 for (id, name) in &self.graphs_names {
3825 let current_graph = {
3826 let current_graphs = &result_context.body.borrow().graphs;
3827 current_graphs[*id as usize].clone()
3828 };
3829 result_context.set_graph_name(current_graph, name)?;
3830 }
3831 for ((graph_id, node_id), name) in &self.nodes_names {
3832 let current_node = {
3833 let current_graphs = &result_context.body.borrow().graphs;
3834 let current_nodes = ¤t_graphs[*graph_id as usize].body.borrow().nodes;
3835 current_nodes[*node_id as usize].clone()
3836 };
3837 result_context.set_node_name(current_node, name)?;
3838 }
3839 for (id, annotations) in &self.graphs_annotations {
3840 let current_graph = {
3841 let current_graphs = &result_context.body.borrow().graphs;
3842 current_graphs[*id as usize].clone()
3843 };
3844 for annotation in annotations {
3845 result_context.add_graph_annotation(¤t_graph, annotation.clone())?;
3846 }
3847 }
3848 for ((graph_id, node_id), annotations) in &self.nodes_annotations {
3849 let current_node = {
3850 let current_graphs = &result_context.body.borrow().graphs;
3851 let current_nodes = ¤t_graphs[*graph_id as usize].body.borrow().nodes;
3852 current_nodes[*node_id as usize].clone()
3853 };
3854 for annotation in annotations {
3855 result_context.add_node_annotation(¤t_node, annotation.clone())?;
3856 }
3857 }
3858 if self.finalized {
3859 result_context.finalize()?;
3860 }
3861 Ok(result_context)
3862 }
3863}
3864
3865type SerializableContext = Arc<SerializableContextBody>;
3866
3867impl Clone for Context {
3868 /// Returns a new [Context] value with a copy of the pointer to a node.
3869 fn clone(&self) -> Self {
3870 Context {
3871 body: self.body.clone(),
3872 }
3873 }
3874}
3875
3876impl PartialEq for Context {
3877 /// Tests whether `self` and `other` contexts are equal via comparison of their respective pointers.
3878 ///
3879 /// # Arguments
3880 ///
3881 /// `other` - another [Context] value
3882 ///
3883 /// # Returns
3884 ///
3885 /// `true` if `self` and `other` are equal, `false` otherwise
3886 fn eq(&self, other: &Self) -> bool {
3887 Arc::ptr_eq(&self.body, &other.body)
3888 }
3889}
3890
3891impl Eq for Context {}
3892
/// Public methods which are supposed to be imported in Python.
3894#[cfg_attr(feature = "py-binding", impl_wrapper)]
3895impl Context {
3896 /// Creates an empty computation graph in this context.
3897 ///
3898 /// # Returns
3899 ///
3900 /// New computation graph
3901 ///
3902 /// # Example
3903 ///
3904 /// ```
3905 /// # use ciphercore_base::graphs::create_context;
3906 /// let c = create_context().unwrap();
3907 /// let g = c.create_graph().unwrap();
3908 /// ```
3909 pub fn create_graph(&self) -> Result<Graph> {
3910 if self.body.borrow().finalized {
3911 return Err(runtime_error!("Can't add a graph to a finalized context"));
3912 }
3913 let id = self.body.borrow().graphs.len() as u64;
3914 let result = Graph {
3915 body: Arc::new(AtomicRefCell::new(GraphBody {
3916 finalized: false,
3917 nodes: vec![],
3918 output_node: None,
3919 id,
3920 context: self.downgrade(),
3921 })),
3922 };
3923 self.body.borrow_mut().graphs.push(result.clone());
3924 Ok(result)
3925 }
3926
3927 /// Finalizes the context if all its graphs are finalized and the main graph is set.
3928 ///
3929 /// After finalization the context can't be changed.
3930 ///
3931 /// # Returns
3932 ///
3933 /// Finalized context
3934 ///
3935 /// # Example
3936 ///
3937 /// ```
3938 /// # use ciphercore_base::graphs::create_context;
3939 /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3940 /// let c = create_context().unwrap();
3941 /// let g = c.create_graph().unwrap();
3942 /// let t = array_type(vec![3, 2], INT32);
3943 /// let vec_t = vector_type(4, t);
3944 /// let n1 = g.input(vec_t).unwrap();
3945 /// let n2 = g.vector_to_array(n1).unwrap();
3946 /// n2.set_as_output().unwrap();
3947 /// g.finalize().unwrap();
3948 /// c.set_main_graph(g).unwrap();
3949 /// c.finalize().unwrap();
3950 /// ```
3951 pub fn finalize(&self) -> Result<Context> {
3952 for graph in self.get_graphs() {
3953 graph.check_finalized()?;
3954 }
3955 let main_graph = self.body.borrow().main_graph.clone();
3956 match main_graph {
3957 Some(_) => {
3958 self.body.borrow_mut().finalized = true;
3959 Ok(self.clone())
3960 }
3961 _ => Err(runtime_error!(
3962 "Can't finalize the context without the main graph"
3963 )),
3964 }
3965 }
3966
3967 /// Promotes a graph to the main one in this context.
3968 ///
3969 /// # Arguments
3970 ///
3971 /// `graph` - graph
3972 ///
3973 /// # Returns
3974 ///
3975 /// This context
3976 ///
3977 /// # Example
3978 ///
3979 /// ```
3980 /// # use ciphercore_base::graphs::create_context;
3981 /// # use ciphercore_base::data_types::{array_type, INT32};
3982 /// let c = create_context().unwrap();
3983 /// let g = c.create_graph().unwrap();
3984 /// let t = array_type(vec![3, 2], INT32);
3985 /// let n = g.input(t).unwrap();
3986 /// n.set_as_output().unwrap();
3987 /// g.finalize().unwrap();
3988 /// c.set_main_graph(g).unwrap();
3989 /// ```
3990 pub fn set_main_graph(&self, graph: Graph) -> Result<Context> {
3991 let current_main_graph = self.body.borrow().main_graph.clone();
3992 match current_main_graph {
3993 Some(_) => Err(runtime_error!("Main graph is already set")),
3994 None => {
3995 if graph.get_context() != *self {
3996 return Err(runtime_error!("Main graph is from the wrong context"));
3997 }
3998 graph.check_finalized()?;
3999 self.body.borrow_mut().main_graph = Some(graph.downgrade());
4000 Ok(self.clone())
4001 }
4002 }
4003 }
4004
4005 /// Returns the vector of graphs contained in this context in order of creation.
4006 ///
4007 /// # Returns
4008 ///
4009 /// Vector of the graphs in this context
4010 pub fn get_graphs(&self) -> Vec<Graph> {
4011 self.body.borrow().graphs.clone()
4012 }
4013
4014 /// Does nothing if the context is finalized; otherwise returns a runtime error.
4015 ///
4016 /// # Returns
4017 ///
4018 /// Runtime error if this context is not finalized
4019 pub fn check_finalized(&self) -> Result<()> {
4020 if !self.is_finalized() {
4021 return Err(runtime_error!("Context is not finalized"));
4022 }
4023 Ok(())
4024 }
4025
4026 /// Returns the main graph of the context if it is already set.
4027 ///
4028 /// # Returns
4029 ///
4030 /// Main graph of the context
4031 pub fn get_main_graph(&self) -> Result<Graph> {
4032 match &self.body.borrow().main_graph {
4033 Some(g) => Ok(g.upgrade()),
4034 None => Err(runtime_error!("main graph is not set")),
4035 }
4036 }
4037
4038 /// Returns the number of graphs contained in this context.
4039 ///
4040 /// # Returns
4041 ///
4042 /// Number of the graphs in this context
4043 pub fn get_num_graphs(&self) -> u64 {
4044 self.body.borrow().graphs.len() as u64
4045 }
4046
4047 /// Returns the graph contained in this context with a given ID.
4048 ///
4049 /// A graph ID is a serial number of a graph between `0` and `n-1` where `n` is the number of graphs in this context.
4050 ///
4051 /// # Arguments
4052 ///
4053 /// `id` - ID of a graph
4054 ///
4055 /// # Returns
4056 ///
4057 /// Graph with the given ID
4058 pub fn get_graph_by_id(&self, id: u64) -> Result<Graph> {
4059 let graphs = &self.body.borrow().graphs;
4060 if id >= graphs.len() as u64 {
4061 Err(runtime_error!("Invalid id for the graph retrieval"))
4062 } else {
4063 Ok(graphs[id as usize].clone())
4064 }
4065 }
4066
4067 /// Returns the node contained in this context with a given global ID.
4068 ///
4069 /// The global ID of a node is a pair of the node ID and the ID of its parent graph.
4070 ///
4071 /// # Arguments
4072 ///
4073 /// `id` - tuple (graph ID, node ID)
4074 ///
4075 /// # Returns
4076 ///
4077 /// Node with the given global ID
4078 pub fn get_node_by_global_id(&self, id: (u64, u64)) -> Result<Node> {
4079 self.get_graph_by_id(id.0)?.get_node_by_id(id.1)
4080 }
4081
4082 /// Sets the name of a graph.
4083 ///
4084 /// A given name should be unique.
4085 ///
4086 /// # Arguments
4087 ///
4088 /// * `graph` - graph
4089 /// * `name` - name of the graph
4090 ///
4091 /// # Returns
4092 ///
4093 /// This context
4094 ///
4095 /// # Example
4096 ///
4097 /// ```
4098 /// # use ciphercore_base::graphs::create_context;
4099 /// let c = create_context().unwrap();
4100 /// let g = c.create_graph().unwrap();
4101 /// g.set_name("relu").unwrap();
4102 /// ```
4103 pub fn set_graph_name(&self, graph: Graph, name: &str) -> Result<Context> {
4104 if graph.get_context() != *self {
4105 return Err(runtime_error!(
4106 "The graph to be named is in a different context"
4107 ));
4108 }
4109 if self.is_finalized() {
4110 return Err(runtime_error!(
4111 "Can't set a graph name in a finalized context"
4112 ));
4113 }
4114 let id = graph.get_id();
4115 let name_owned = name.to_owned();
4116 let mut cell = self.body.borrow_mut();
4117 if cell.graphs_names.get(&id).is_some() {
4118 return Err(runtime_error!("Can't set the graph name twice"));
4119 }
4120 if cell.graphs_names_inverse.get(name).is_some() {
4121 return Err(runtime_error!("Graph names must be unique"));
4122 }
4123 cell.graphs_names.insert(id, name_owned.clone());
4124 cell.graphs_names_inverse.insert(name_owned, id);
4125 Ok(self.clone())
4126 }
4127
4128 /// Returns the name of a graph.
4129 ///
4130 /// # Arguments
4131 ///
4132 /// `graph` - graph
4133 ///
4134 /// # Returns
4135 ///
4136 /// Name of a given graph
4137 ///
4138 /// # Example
4139 ///
4140 /// ```
4141 /// # use ciphercore_base::graphs::create_context;
4142 /// let c = create_context().unwrap();
4143 /// let g = c.create_graph().unwrap();
4144 /// g.set_name("relu").unwrap();
4145 /// assert_eq!(c.get_graph_name(g).unwrap(), "relu".to_owned());
4146 /// ```
4147 pub fn get_graph_name(&self, graph: Graph) -> Result<String> {
4148 if graph.get_context() != *self {
4149 return Err(runtime_error!("The graph is in a different context"));
4150 }
4151 let cell = self.body.borrow();
4152 Ok(cell
4153 .graphs_names
4154 .get(&graph.get_id())
4155 .ok_or_else(|| runtime_error!("The graph does not have a name assigned"))?
4156 .clone())
4157 }
4158
4159 /// Returns the graph with a given name in this context.
4160 ///
4161 /// # Arguments
4162 ///
4163 /// `name` - graph name
4164 ///
4165 /// # Returns
4166 ///
4167 /// Graph with a given name
4168 ///
4169 /// # Example
4170 ///
4171 /// ```
4172 /// # use ciphercore_base::graphs::create_context;
4173 /// # use ciphercore_base::data_types::{BIT, scalar_type};
4174 /// let c = create_context().unwrap();
4175 /// let g = c.create_graph().unwrap();
4176 /// let n = g.input(scalar_type(BIT)).unwrap();
4177 /// g.set_name("input_graph").unwrap();
4178 /// assert!(g == c.retrieve_graph("input_graph").unwrap());
4179 /// ```
4180 pub fn retrieve_graph(&self, name: &str) -> Result<Graph> {
4181 let cell = self.body.borrow();
4182 let id = cell
4183 .graphs_names_inverse
4184 .get(name)
4185 .ok_or_else(|| runtime_error!("No graph with such a name exists"))?;
4186 let graph = cell.graphs[*id as usize].clone();
4187 Ok(graph)
4188 }
4189
4190 /// Sets the name of a node.
4191 ///
4192 /// A given name should be unique.
4193 ///
4194 /// # Arguments
4195 ///
4196 /// * `node` - node
4197 /// * `name` - name of a node
4198 ///
4199 /// # Returns
4200 ///
4201 /// This context
4202 ///
4203 /// # Example
4204 ///
4205 /// ```
4206 /// # use ciphercore_base::graphs::create_context;
4207 /// # use ciphercore_base::data_types::{scalar_type, BIT};
4208 /// let c = create_context().unwrap();
4209 /// let g = c.create_graph().unwrap();
4210 /// let t = scalar_type(BIT);
4211 /// let n = g.input(t).unwrap();
4212 /// c.set_node_name(n, "XOR").unwrap();
4213 /// ```
4214 pub fn set_node_name(&self, node: Node, name: &str) -> Result<Context> {
4215 if node.get_graph().get_context() != *self {
4216 return Err(runtime_error!(
4217 "The node to be named is in a different context"
4218 ));
4219 }
4220 if self.is_finalized() {
4221 return Err(runtime_error!(
4222 "Can't set a node name in a finalized context"
4223 ));
4224 }
4225 let node_id = node.get_id();
4226 let graph_id = node.get_graph().get_id();
4227 let mut cell = self.body.borrow_mut();
4228 if cell.nodes_names.get(&(graph_id, node_id)).is_some() {
4229 return Err(runtime_error!("Can't set the node name twice"));
4230 }
4231 if cell.nodes_names_inverse.get(&graph_id).is_none() {
4232 cell.nodes_names_inverse.insert(graph_id, HashMap::new());
4233 }
4234 let graph_map_inverse = cell
4235 .nodes_names_inverse
4236 .get_mut(&graph_id)
4237 .expect("Should not be here!");
4238 if graph_map_inverse.get(name).is_some() {
4239 return Err(runtime_error!(
4240 "Node names must be unique (within the graph)"
4241 ));
4242 }
4243 graph_map_inverse.insert(name.to_owned(), node_id);
4244 cell.nodes_names
4245 .insert((graph_id, node_id), name.to_owned());
4246 Ok(self.clone())
4247 }
4248
4249 /// Returns the name of a node.
4250 ///
4251 /// # Arguments
4252 ///
4253 /// `node` - node
4254 ///
4255 /// # Returns
4256 ///
4257 /// Name of a node or None if it doesn't have a name
4258 ///
4259 /// # Example
4260 ///
4261 /// ```
4262 /// # use ciphercore_base::graphs::create_context;
4263 /// # use ciphercore_base::data_types::{scalar_type, BIT};
4264 /// let c = create_context().unwrap();
4265 /// let g = c.create_graph().unwrap();
4266 /// let t = scalar_type(BIT);
4267 /// let n = g.input(t).unwrap();
4268 /// n.set_name("XOR").unwrap();
4269 /// assert_eq!(c.get_node_name(n).unwrap(), Some("XOR".to_owned()));
4270 /// ```
4271 pub fn get_node_name(&self, node: Node) -> Result<Option<String>> {
4272 if node.get_graph().get_context() != *self {
4273 return Err(runtime_error!("The node is in a different context"));
4274 }
4275 let node_id = node.get_id();
4276 let graph_id = node.get_graph().get_id();
4277 let cell = self.body.borrow();
4278 Ok(cell.nodes_names.get(&(graph_id, node_id)).cloned())
4279 }
4280
4281 /// Returns the node with a given name in a given graph.
4282 ///
4283 /// # Arguments
4284 ///
4285 /// * `graph` - graph
4286 /// * `name` - node name
4287 ///
4288 /// # Returns
4289 ///
4290 /// Node with a given name
4291 ///
4292 /// # Example
4293 ///
4294 /// ```
4295 /// # use ciphercore_base::graphs::create_context;
4296 /// # use ciphercore_base::data_types::{BIT, scalar_type};
4297 /// let c = create_context().unwrap();
4298 /// let g = c.create_graph().unwrap();
4299 /// let n = g.input(scalar_type(BIT)).unwrap();
4300 /// n.set_name("input_node").unwrap();
4301 /// assert!(n == c.retrieve_node(g, "input_node").unwrap());
4302 /// ```
4303 pub fn retrieve_node(&self, graph: Graph, name: &str) -> Result<Node> {
4304 if graph.get_context() != *self {
4305 return Err(runtime_error!("The graph is in a different context"));
4306 }
4307 let graph_id = graph.get_id();
4308 let cell = self.body.borrow();
4309 let node_id = cell
4310 .nodes_names_inverse
4311 .get(&graph_id)
4312 .ok_or_else(|| runtime_error!("The graph has no named nodes"))?
4313 .get(name)
4314 .ok_or_else(|| runtime_error!("Node with a given name does not exist"))?;
4315 Ok(graph.body.borrow().nodes[*node_id as usize].clone())
4316 }
4317 /// Check that two given contexts contain the same data, i.e. graphs, nodes, names, parameters.
4318 ///
4319 /// Underlying structures that contain pointers (graphs, nodes) are compared by data they refer to.
4320 ///
4321 /// # Arguments
4322 ///
4323 /// * `context2` - context to compare
4324 ///
4325 /// # Returns
4326 ///
4327 /// `true` if the given contexts contain the same content, otherwise `false`
4328 pub fn deep_equal(&self, context2: Context) -> bool {
4329 contexts_deep_equal(self.clone(), context2)
4330 }
4331}
4332
/// Flattens `map` into a vector of `(key, value)` pairs sorted by key.
///
/// Sorting makes the serialized form deterministic regardless of hash-map iteration order.
/// Only `K: Ord` is required; keys are moved rather than copied.
fn serialize_hashmap<K, V>(map: HashMap<K, V>) -> Vec<(K, V)>
where
    K: Ord,
{
    let mut entries: Vec<(K, V)> = map.into_iter().collect();
    // Hash-map keys are unique, so an unstable sort still yields a fully determined order.
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    entries
}
4341
4342/// Methods which aren't supposed to be imported in Python.
4343impl Context {
4344 pub(super) fn is_finalized(&self) -> bool {
4345 self.body.borrow().finalized
4346 }
4347
4348 fn make_serializable(&self) -> SerializableContext {
4349 let main_graph = match self.get_main_graph() {
4350 Ok(g) => Some(g.get_id()),
4351 Err(_) => None,
4352 };
4353 let cell = self.body.borrow();
4354 Arc::new(SerializableContextBody {
4355 finalized: self.is_finalized(),
4356 graphs: self
4357 .get_graphs()
4358 .iter()
4359 .map(|g| g.make_serializable())
4360 .collect(),
4361 main_graph,
4362 graphs_names: serialize_hashmap(cell.graphs_names.clone()),
4363 nodes_names: serialize_hashmap(cell.nodes_names.clone()),
4364 graphs_annotations: serialize_hashmap(cell.graphs_annotations.clone()),
4365 nodes_annotations: serialize_hashmap(cell.nodes_annotations.clone()),
4366 })
4367 }
4368
4369 fn add_type_checker(&self) -> Result<Context> {
4370 {
4371 let mut cell = self.body.borrow_mut();
4372 if cell.type_checker.is_some() {
4373 return Err(runtime_error!(
4374 "Type checker associated with the context already exists"
4375 ));
4376 }
4377 cell.type_checker = Some(create_type_inference_worker(self.clone()));
4378 }
4379 for graph in self.get_graphs() {
4380 for node in graph.get_nodes() {
4381 node.get_type()?;
4382 }
4383 }
4384 Ok(self.clone())
4385 }
4386
4387 fn get_total_size_nodes(&self) -> u64 {
4388 self.body.borrow().total_size_nodes
4389 }
4390
4391 fn set_total_size_nodes(&self, size: u64) {
4392 self.body.borrow_mut().total_size_nodes = size;
4393 }
4394
4395 fn try_update_total_size(&self, node: Node) -> Result<()> {
4396 let node_type = match node.get_operation() {
4397 Operation::Input(input_type) => input_type,
4398 Operation::Constant(t, _) => t,
4399 _ => return Ok(()),
4400 };
4401 if !node_type.is_valid() {
4402 return Err(runtime_error!("Node with an invalid type: {:?}", node_type));
4403 }
4404 let new_total_size = self
4405 .get_total_size_nodes()
4406 .checked_add(get_size_estimation_in_bits(node_type)?)
4407 .ok_or_else(|| runtime_error!("add overflow!"))?;
4408 if new_total_size > type_size_limit_constants::MAX_TOTAL_SIZE_NODES {
4409 return Err(runtime_error!(
4410 "Can't add a node: total size of nodes exceeds MAX_TOTAL_SIZE_NODES"
4411 ));
4412 }
4413 self.set_total_size_nodes(new_total_size);
4414 Ok(())
4415 }
4416
4417 fn unregister_node(&self, node: Node) -> Result<()> {
4418 if node.get_graph().get_context() != *self {
4419 return Err(runtime_error!(
4420 "The node to be unregister from a different context"
4421 ));
4422 }
4423 if self.is_finalized() {
4424 return Err(runtime_error!(
4425 "Can't unregister a node from a finalized context"
4426 ));
4427 }
4428
4429 let node_id = node.get_id();
4430 let graph_id = node.get_graph().get_id();
4431
4432 let mut cell = self.body.borrow_mut();
4433 let name_option = cell.nodes_names.remove(&(graph_id, node_id));
4434 cell.nodes_annotations.remove(&(graph_id, node_id));
4435 if cell.nodes_names_inverse.get(&graph_id).is_none() {
4436 return Ok(());
4437 }
4438 let graph_map_inverse = cell
4439 .nodes_names_inverse
4440 .get_mut(&graph_id)
4441 .expect("Should not be here!");
4442 if let Some(name) = name_option {
4443 graph_map_inverse.remove(&name);
4444 }
4445 Ok(())
4446 }
4447
4448 fn to_versioned_data(&self) -> Result<VersionedData> {
4449 VersionedData::create_versioned_data(
4450 DATA_VERSION,
4451 serde_json::to_string(&self.make_serializable())?,
4452 )
4453 }
4454 fn prepare_input_values<T: Clone>(
4455 &self,
4456 graph: Graph,
4457 values: HashMap<&str, T>,
4458 ) -> Result<Vec<T>> {
4459 if graph.get_context() != *self {
4460 return Err(runtime_error!("The graph is in a different context"));
4461 }
4462 let graph_id = graph.get_id();
4463 let cell = self.body.borrow();
4464 for node_name in values.keys() {
4465 cell.nodes_names_inverse
4466 .get(&graph_id)
4467 .ok_or_else(|| runtime_error!("Trying to call graph without named nodes"))?
4468 .get(node_name as &str)
4469 .ok_or_else(|| runtime_error!("Input with a given name is not found"))?;
4470 }
4471 let mut result = vec![];
4472 for node in graph.get_nodes() {
4473 if node.get_operation().is_input() {
4474 let node_id = node.get_id();
4475 let node_name = cell
4476 .nodes_names
4477 .get(&(graph_id, node_id))
4478 .ok_or_else(|| runtime_error!("Unnamed input"))?;
4479 let node_value = values
4480 .get(node_name as &str)
4481 .ok_or_else(|| runtime_error!("Unspecified input"))?
4482 .clone();
4483 result.push(node_value);
4484 }
4485 }
4486 Ok(result)
4487 }
4488
4489 pub(super) fn add_node_annotation(
4490 &self,
4491 node: &Node,
4492 annotation: NodeAnnotation,
4493 ) -> Result<Context> {
4494 if node.get_graph().get_context() != *self {
4495 return Err(runtime_error!(
4496 "The node to be annotated is in a different context"
4497 ));
4498 }
4499 if self.is_finalized() {
4500 return Err(runtime_error!(
4501 "Can't add a node annotation in a finalized context"
4502 ));
4503 }
4504 let node_id = node.get_id();
4505 let graph_id = node.get_graph().get_id();
4506 let key = (graph_id, node_id);
4507 let mut cell = self.body.borrow_mut();
4508 let annotations = cell.nodes_annotations.get_mut(&key);
4509 if let Some(annotation_vec) = annotations {
4510 annotation_vec.push(annotation);
4511 } else {
4512 cell.nodes_annotations.insert(key, vec![annotation]);
4513 }
4514 Ok(self.clone())
4515 }
4516
4517 pub(super) fn get_node_annotations(&self, node: Node) -> Result<Vec<NodeAnnotation>> {
4518 if node.get_graph().get_context() != *self {
4519 return Err(runtime_error!("The node is in a different context"));
4520 }
4521 let node_id = node.get_id();
4522 let graph_id = node.get_graph().get_id();
4523 let cell = self.body.borrow();
4524 Ok(cell
4525 .nodes_annotations
4526 .get(&(graph_id, node_id))
4527 .cloned()
4528 .unwrap_or_default())
4529 }
4530
4531 fn add_graph_annotation(&self, graph: &Graph, annotation: GraphAnnotation) -> Result<Context> {
4532 if graph.get_context() != *self {
4533 return Err(runtime_error!(
4534 "The graph to be annotated is in a different context"
4535 ));
4536 }
4537 if self.is_finalized() {
4538 return Err(runtime_error!(
4539 "Can't set a graph annotation in a finalized context"
4540 ));
4541 }
4542 let id = graph.get_id();
4543 let mut cell = self.body.borrow_mut();
4544 let annotations = cell.graphs_annotations.get_mut(&id);
4545 if let Some(annotation_vec) = annotations {
4546 annotation_vec.push(annotation);
4547 } else {
4548 cell.graphs_annotations.insert(id, vec![annotation]);
4549 }
4550 Ok(self.clone())
4551 }
4552
4553 fn get_graph_annotations(&self, graph: Graph) -> Result<Vec<GraphAnnotation>> {
4554 if graph.get_context() != *self {
4555 return Err(runtime_error!("The graph is in a different context"));
4556 }
4557 let cell = self.body.borrow();
4558 Ok(cell
4559 .graphs_annotations
4560 .get(&graph.get_id())
4561 .cloned()
4562 .unwrap_or_default())
4563 }
4564
4565 pub(super) fn downgrade(&self) -> WeakContext {
4566 WeakContext {
4567 body: Arc::downgrade(&self.body),
4568 }
4569 }
4570}
4571
4572impl Serialize for Context {
4573 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
4574 where
4575 S: Serializer,
4576 {
4577 let versioned_context = self
4578 .to_versioned_data()
4579 .expect("Error during conversion from Context into VersionedData");
4580 //VersionedData::from(self.clone());
4581 versioned_context.serialize(serializer)
4582 }
4583}
4584
4585impl<'de> Deserialize<'de> for Context {
4586 fn deserialize<D>(deserializer: D) -> std::result::Result<Context, D::Error>
4587 where
4588 D: Deserializer<'de>,
4589 {
4590 let versioned_context = VersionedData::deserialize(deserializer)?;
4591 if !versioned_context.check_version(DATA_VERSION) {
4592 Err(runtime_error!(
4593 "Context version doesn't match the requirement"
4594 ))
4595 .map_err(serde::de::Error::custom)
4596 } else {
4597 let serializable_context =
4598 serde_json::from_str::<SerializableContext>(versioned_context.get_data_string())
4599 .expect("Error during deserialization of SerializableContext");
4600 serializable_context
4601 .recover_original_context()
4602 .map_err(serde::de::Error::custom)
4603 }
4604 }
4605}
4606
4607/// In general, `create_unchecked_context()` should not return errors, but
4608/// we still make the result type Result<Context> for uniformity.
4609pub(super) fn create_unchecked_context() -> Result<Context> {
4610 Ok(Context {
4611 body: Arc::new(AtomicRefCell::new(ContextBody {
4612 finalized: false,
4613 graphs: vec![],
4614 main_graph: None,
4615 graphs_names: HashMap::new(),
4616 graphs_names_inverse: HashMap::new(),
4617 nodes_names: HashMap::new(),
4618 nodes_names_inverse: HashMap::new(),
4619 graphs_annotations: HashMap::new(),
4620 nodes_annotations: HashMap::new(),
4621 type_checker: None,
4622 total_size_nodes: 0,
4623 })),
4624 })
4625}
4626
4627/// Creates an empty computation context.
4628///
4629/// # Returns
4630///
4631/// New computation context
4632///
4633/// # Example
4634///
4635/// ```
4636/// # use ciphercore_base::graphs::create_context;
4637/// let c = create_context().unwrap();
4638/// ```
4639#[cfg_attr(feature = "py-binding", fn_wrapper)]
4640pub fn create_context() -> Result<Context> {
4641 let context = create_unchecked_context()?;
4642 context.add_type_checker()?;
4643 Ok(context)
4644}
4645
4646fn graphs_deep_equal(graph1: Graph, graph2: Graph) -> bool {
4647 let graph1_body = graph1.body.borrow();
4648 let graph2_body = graph2.body.borrow();
4649 if graph1_body.finalized != graph2_body.finalized {
4650 return false;
4651 }
4652 if graph1_body.nodes.len() != graph2_body.nodes.len() {
4653 return false;
4654 }
4655 for j in 0..graph1_body.nodes.len() {
4656 let node1 = graph1_body.nodes[j].clone();
4657 let node2 = graph2_body.nodes[j].clone();
4658 let node1_body = node1.body.borrow();
4659 let node2_body = node2.body.borrow();
4660 if node1_body.operation != node2_body.operation {
4661 return false;
4662 }
4663 let node_dependencies1: Vec<u64> = node1_body
4664 .node_dependencies
4665 .iter()
4666 .map(|n| n.upgrade().get_id())
4667 .collect();
4668 let node_dependencies2: Vec<u64> = node2_body
4669 .node_dependencies
4670 .iter()
4671 .map(|n| n.upgrade().get_id())
4672 .collect();
4673 if node_dependencies1 != node_dependencies2 {
4674 return false;
4675 }
4676 let graph_dependencies1: Vec<u64> = node1_body
4677 .graph_dependencies
4678 .iter()
4679 .map(|g| g.upgrade().get_id())
4680 .collect();
4681 let graph_dependencies2: Vec<u64> = node2_body
4682 .graph_dependencies
4683 .iter()
4684 .map(|g| g.upgrade().get_id())
4685 .collect();
4686 if graph_dependencies1 != graph_dependencies2 {
4687 return false;
4688 }
4689 }
4690 if graph1_body
4691 .output_node
4692 .clone()
4693 .map(|n| n.upgrade().get_id())
4694 != graph2_body
4695 .output_node
4696 .clone()
4697 .map(|n| n.upgrade().get_id())
4698 {
4699 return false;
4700 }
4701 true
4702}
4703
4704/// Check that two given contexts contain the same data, i.e. graphs, nodes, names, parameters.
4705///
4706/// Underlying structures that contain pointers (graphs, nodes) are compared by data they refer to.
4707///
4708/// # Arguments
4709///
4710/// * `context1` - first context to compare
4711/// * `context2` - second context to compare
4712///
4713/// # Returns
4714///
4715/// `true` if the given contexts contain the same content, otherwise `false`
4716pub fn contexts_deep_equal(context1: Context, context2: Context) -> bool {
4717 let body1 = context1.body.borrow();
4718 let body2 = context2.body.borrow();
4719 if body1.finalized != body2.finalized {
4720 return false;
4721 }
4722 if body1.graphs_names != body2.graphs_names {
4723 return false;
4724 }
4725 if body1.nodes_names != body2.nodes_names {
4726 return false;
4727 }
4728 if body1.nodes_annotations != body2.nodes_annotations {
4729 return false;
4730 }
4731 if body1.graphs_annotations != body2.graphs_annotations {
4732 return false;
4733 }
4734 if body1.graphs.len() != body2.graphs.len() {
4735 return false;
4736 }
4737 for i in 0..body1.graphs.len() {
4738 if !graphs_deep_equal(body1.graphs[i].clone(), body2.graphs[i].clone()) {
4739 return false;
4740 }
4741 }
4742 body1.main_graph.clone().map(|g| g.upgrade().get_id())
4743 == body2.main_graph.clone().map(|g| g.upgrade().get_id())
4744}
4745
4746// Pass the node name of `in_node` to `out_node` if it is present.
4747pub(crate) fn copy_node_name(in_node: Node, out_node: Node) -> Result<()> {
4748 if let Some(node_name) = in_node.get_name()? {
4749 out_node.set_name(&node_name)?;
4750 }
4751 Ok(())
4752}
4753
4754type WeakContextBodyPointer = Weak<AtomicRefCell<ContextBody>>;
4755
4756pub(super) struct WeakContext {
4757 body: WeakContextBodyPointer,
4758}
4759
4760impl WeakContext {
4761 //upgrade function panics if the the Context pointer it downgraded from went out of scope
4762 pub(super) fn upgrade(&self) -> Context {
4763 Context {
4764 body: self.body.upgrade().unwrap(),
4765 }
4766 }
4767}
4768
#[doc(hidden)]
#[cfg(feature = "py-binding")]
#[pyo3::pymethods]
impl PyBindingSliceElement {
    /// Python-side constructor for [SliceElement::SingleIndex].
    #[staticmethod]
    pub fn from_single_element(ind: i64) -> Self {
        let inner = SliceElement::SingleIndex(ind);
        Self { inner }
    }
    /// Python-side constructor for [SliceElement::SubArray] with optional bounds and step.
    #[staticmethod]
    pub fn from_sub_array(start: Option<i64>, end: Option<i64>, step: Option<i64>) -> Self {
        let inner = SliceElement::SubArray(start, end, step);
        Self { inner }
    }
    /// Python-side constructor for [SliceElement::Ellipsis].
    #[staticmethod]
    pub fn from_ellipsis() -> Self {
        let inner = SliceElement::Ellipsis;
        Self { inner }
    }
}
4792
4793pub mod util {
4794 use super::{create_context, Context, Graph, Node};
4795 use crate::errors::Result;
4796
4797 /// Creates a computation context with a single graph within it.
4798 ///
4799 /// The graph is passed to the provided `build_graph_fn` function to
4800 /// specify the computation. The node returned by the function is marked
4801 /// as output.
4802 ///
4803 /// # Returns
4804 ///
4805 /// New computation context
4806 ///
4807 /// # Example
4808 ///
4809 /// ```
4810 /// # use ciphercore_base::graphs::util::simple_context;
4811 /// # use ciphercore_base::data_types::{scalar_type, INT32};
4812 /// let c = simple_context(|g| {
4813 /// let a = g.input(scalar_type(INT32))?;
4814 /// let b = g.input(scalar_type(INT32))?;
4815 /// g.add(a, b)
4816 /// }).unwrap();
4817 /// ```
4818 pub fn simple_context<F>(build_graph_fn: F) -> Result<Context>
4819 where
4820 F: FnOnce(&Graph) -> Result<Node>,
4821 {
4822 let c = create_context()?;
4823 let g = c.create_graph()?;
4824 let out = build_graph_fn(&g)?;
4825 out.set_as_output()?;
4826 g.finalize()?;
4827 g.set_as_main()?;
4828 c.finalize()?;
4829 Ok(c)
4830 }
4831}
4832
4833#[cfg(test)]
4834mod tests {
4835 use super::*;
4836 use crate::data_types::{
4837 array_type, scalar_type, tuple_type, vector_type, BIT, UINT16, UINT64,
4838 };
4839 use crate::inline::inline_ops::InlineConfig;
4840 use crate::mpc::mpc_compiler::{prepare_for_mpc_evaluation, IOStatus};
4841 use crate::version::DATA_VERSION;
4842 use std::panic;
4843 use std::rc::Rc;
4844
    /// Builds one node for most graph operations and checks that every construction call
    /// returns `Ok`.
    ///
    /// The context comes from `create_unchecked_context`, which attaches no type checker.
    /// Operand types are therefore presumably not validated here (e.g. `dot`/`matmul` on
    /// scalar bits).
    #[test]
    fn test_wellformed_cases() {
        let context = create_unchecked_context().unwrap();
        let graph = context.create_graph().unwrap();
        let input1 = graph.input(scalar_type(BIT)).unwrap();
        let input2 = graph.input(scalar_type(BIT)).unwrap();
        // Binary operations on two scalar bit inputs.
        graph.add(input1.clone(), input2.clone()).unwrap();
        graph.subtract(input1.clone(), input2.clone()).unwrap();
        graph.multiply(input1.clone(), input2.clone()).unwrap();
        graph.dot(input1.clone(), input2.clone()).unwrap();
        graph.matmul(input1.clone(), input2.clone()).unwrap();
        graph.truncate(input1.clone(), 123).unwrap();
        // Shape-related operations on a [10, 20, 30] bit array.
        let input3 = graph.input(array_type(vec![10, 20, 30], BIT)).unwrap();
        graph.sum(input3.clone(), vec![1, 2]).unwrap();
        graph.permute_axes(input3.clone(), vec![1, 2, 0]).unwrap();
        graph.get(input3.clone(), vec![1, 2]).unwrap();
        graph
            .reshape(input3.clone(), array_type(vec![20, 300], BIT))
            .unwrap();
        graph.nop(input3.clone()).unwrap();
        // A random 128-bit key fed into `prf`.
        let key = graph.random(array_type(vec![128], BIT)).unwrap();
        graph
            .prf(key.clone(), 0, array_type(vec![10, 10], UINT64))
            .unwrap();
        graph
            .stack(vec![input1.clone(), input2.clone()], vec![2, 1])
            .unwrap();
        let c = graph
            .constant(scalar_type(BIT), Value::from_bytes(vec![1]))
            .unwrap();
        // `a2b` / `b2a` conversions on a UINT64 array.
        let input4 = graph.input(array_type(vec![10, 10], UINT64)).unwrap();
        let bits = graph.a2b(input4.clone()).unwrap();
        graph.b2a(bits.clone(), UINT64).unwrap();
        // Tuple, vector and named-tuple construction and access.
        let t = graph
            .create_tuple(vec![input1.clone(), input2.clone()])
            .unwrap();
        let _v = graph
            .create_vector(scalar_type(BIT), vec![input1.clone(), input2.clone()])
            .unwrap();
        let nt = graph
            .create_named_tuple(vec![
                ("Name".to_owned(), input1.clone()),
                ("Gender".to_owned(), input2.clone()),
            ])
            .unwrap();
        graph.tuple_get(t, 1).unwrap();
        graph.named_tuple_get(nt, "Gender".to_owned()).unwrap();
        let v = graph.repeat(c.clone(), 100).unwrap();
        graph.zip(vec![v.clone(), v.clone(), v.clone()]).unwrap();
        // UINT64 zero used as the index for `vector_get`.
        let zero = graph
            .constant(scalar_type(UINT64), Value::from_bytes(vec![0; 8]))
            .unwrap();
        graph.vector_get(v, zero).unwrap();
        graph.array_to_vector(input1.clone()).unwrap();
        graph.vector_to_array(input1.clone()).unwrap();
    }
4901
    /// Builds a 32-element bit adder from a single-bit adder graph via `iterate`.
    ///
    /// It then adds three vectors by nesting two `call`s of that adder, and finalizes the
    /// context with the three-operand graph as main.
    #[test]
    fn call_iterate_test() {
        let context = create_unchecked_context().unwrap();
        let single_bit_adder = context.create_graph().unwrap();
        {
            // Inputs: the carry bit (iteration state) and a pair (a, b) (iteration element).
            let carry = single_bit_adder.input(scalar_type(BIT)).unwrap();
            let inputs = single_bit_adder
                .input(tuple_type(vec![scalar_type(BIT), scalar_type(BIT)]))
                .unwrap();
            let a = single_bit_adder.tuple_get(inputs.clone(), 0).unwrap();
            let b = single_bit_adder.tuple_get(inputs.clone(), 1).unwrap();
            let ac = single_bit_adder.add(carry.clone(), a.clone()).unwrap();
            let bc = single_bit_adder.add(carry.clone(), b.clone()).unwrap();
            // Sum bit: carry + a + b.
            let result = single_bit_adder.add(ac.clone(), b.clone()).unwrap();
            // New carry: (carry + a) * (carry + b) + carry.
            let result_carry = single_bit_adder
                .add(
                    single_bit_adder.multiply(ac.clone(), bc.clone()).unwrap(),
                    carry,
                )
                .unwrap();
            // Output tuple is (new state, per-step output).
            let output = single_bit_adder
                .create_tuple(vec![result_carry.clone(), result.clone()])
                .unwrap();
            single_bit_adder.set_output_node(output).unwrap();
            single_bit_adder.finalize().unwrap();
        }
        let v32 = vector_type(32, scalar_type(BIT));
        let adder = context.create_graph().unwrap();
        {
            let a = adder.input(v32.clone()).unwrap();
            let b = adder.input(v32.clone()).unwrap();
            let azb = adder.zip(vec![a, b]).unwrap();
            // Initial carry is 0.
            let c = adder
                .constant(scalar_type(BIT), Value::from_bytes(vec![0]))
                .unwrap();
            let cr = adder.iterate(single_bit_adder, c, azb).unwrap();
            // Element 1 of the iterate result; presumably the vector of per-step sum bits.
            let r = adder.tuple_get(cr, 1).unwrap();
            adder.set_output_node(r).unwrap();
            adder.finalize().unwrap();
        }
        // Computes (a + b) + c with two calls of `adder`.
        let three_adder = context.create_graph().unwrap();
        let a = three_adder.input(v32.clone()).unwrap();
        let b = three_adder.input(v32.clone()).unwrap();
        let c = three_adder.input(v32.clone()).unwrap();
        let result = three_adder
            .call(
                adder.clone(),
                vec![three_adder.call(adder.clone(), vec![a, b]).unwrap(), c],
            )
            .unwrap();
        three_adder.set_output_node(result).unwrap();
        three_adder.finalize().unwrap();
        context.set_main_graph(three_adder).unwrap();
        context.finalize().unwrap();
    }
4957
    /// Checks that invalid node and graph manipulations return errors.
    #[test]
    fn test_malformed_graphs() {
        let context = create_unchecked_context().unwrap();
        let graph = context.create_graph().unwrap();
        let graph2 = context.create_graph().unwrap();
        let input1 = graph.input(scalar_type(UINT64)).unwrap();
        let input2 = graph2.input(scalar_type(UINT64)).unwrap();
        // Operands taken from two different graphs.
        let e1 = graph.add(input1.clone(), input2.clone());
        assert!(e1.is_err());
        // Hand-made node claiming id 0 in `graph`, but it is not the node `graph` holds at
        // that id (that node is `input1`).
        let fake_node = Node {
            body: Arc::new(AtomicRefCell::new(NodeBody {
                graph: graph.downgrade(),
                node_dependencies: vec![],
                graph_dependencies: vec![],
                operation: Operation::Input(scalar_type(BIT)),
                id: 0,
            })),
        };
        let e2 = graph.add(fake_node.clone(), input1.clone());
        assert!(e2.is_err());
        // Hand-made node with an id that `graph` does not have.
        let fake_node_2 = Node {
            body: Arc::new(AtomicRefCell::new(NodeBody {
                graph: graph.downgrade(),
                node_dependencies: vec![],
                graph_dependencies: vec![],
                operation: Operation::Input(scalar_type(BIT)),
                id: 31337,
            })),
        };
        let e3 = graph.add(fake_node_2.clone(), input1.clone());
        assert!(e3.is_err());
        graph.set_output_node(input1.clone()).unwrap();
        graph.finalize().unwrap();
        // No new nodes may be added to a finalized graph.
        let e4 = graph.add(input1.clone(), input1.clone());
        assert!(e4.is_err());
        // An empty graph without an output node cannot be finalized.
        let graph3 = context.create_graph().unwrap();
        let e5 = graph3.finalize();
        assert!(e5.is_err());
        // The output node must belong to the graph itself.
        let e6 = graph3.set_output_node(input1);
        assert!(e6.is_err());
    }
4999
    /// Checks the ordering constraints between graph finalization, main-graph selection and
    /// context finalization.
    #[test]
    fn test_malformed_contexts() {
        let context = create_unchecked_context().unwrap();
        // An empty context (no graphs, no main graph) cannot be finalized.
        let e1 = context.finalize();
        assert!(e1.is_err());
        let graph = context.create_graph().unwrap();
        // A graph without an output node cannot be finalized.
        let e2 = graph.finalize();
        assert!(e2.is_err());
        graph
            .set_output_node(graph.create_tuple(vec![]).unwrap())
            .unwrap();
        // A graph that is not finalized cannot become the main graph.
        let e4 = context.set_main_graph(graph.clone());
        assert!(e4.is_err());
        graph.finalize().unwrap();
        // The context cannot be finalized before a main graph is set.
        let e3 = context.finalize();
        assert!(e3.is_err());
        context.set_main_graph(graph.clone()).unwrap();
        context.finalize().unwrap();
    }
5019
    /// Checks that `call` and `iterate` reject callee graphs that are not finalized or that
    /// belong to a different context.
    #[test]
    fn test_malformed_call_iterate() {
        let context1 = create_unchecked_context().unwrap();
        let graph1 = context1.create_graph().unwrap();
        let output = graph1.create_tuple(vec![]).unwrap();
        graph1.set_output_node(output).unwrap();
        let graph2 = context1.create_graph().unwrap();
        // `graph1` is not finalized yet.
        let e1 = graph2.call(graph1.clone(), vec![]);
        assert!(e1.is_err());
        graph1.finalize().unwrap();
        graph2.call(graph1.clone(), vec![]).unwrap();
        let context2 = create_unchecked_context().unwrap();
        let graph3 = context2.create_graph().unwrap();
        // `graph1` lives in `context1`, the caller in `context2`.
        let e2 = graph3.call(graph1.clone(), vec![]);
        assert!(e2.is_err());
        // Iteration body: two empty-tuple inputs, output is a pair of empty tuples.
        let graph4 = context1.create_graph().unwrap();
        graph4.input(tuple_type(vec![])).unwrap();
        graph4.input(tuple_type(vec![])).unwrap();
        let t = graph4.create_tuple(vec![]).unwrap();
        let tt = graph4.create_tuple(vec![t.clone(), t.clone()]).unwrap();
        graph4.set_output_node(tt).unwrap();
        let graph5 = context1.create_graph().unwrap();
        let es = graph5.create_tuple(vec![]).unwrap();
        let v = graph5
            .repeat(graph5.create_tuple(vec![]).unwrap(), 10)
            .unwrap();
        // `graph4` is not finalized yet.
        let e3 = graph5.iterate(graph4.clone(), es.clone(), v.clone());
        assert!(e3.is_err());
        graph4.finalize().unwrap();
        graph5
            .iterate(graph4.clone(), es.clone(), v.clone())
            .unwrap();
        let graph6 = context2.create_graph().unwrap();
        let es = graph6.create_tuple(vec![]).unwrap();
        let v = graph6
            .repeat(graph6.create_tuple(vec![]).unwrap(), 10)
            .unwrap();
        // `graph4` lives in `context1`, the caller in `context2`.
        let e4 = graph6.iterate(graph4.clone(), es.clone(), v.clone());
        assert!(e4.is_err());
    }
5060
    /// Checks structural invariants of a finalized graph:
    ///
    /// * node ids equal their positions;
    /// * every node points back to its graph;
    /// * dependencies have smaller ids than their users;
    /// * the recorded operations match the construction sequence.
    #[test]
    fn test_graph_consistency() {
        let context = create_unchecked_context().unwrap();
        let graph = context.create_graph().unwrap();
        let input1 = graph.input(scalar_type(BIT)).unwrap();
        let input2 = graph.input(scalar_type(BIT)).unwrap();
        graph.add(input1.clone(), input2.clone()).unwrap();
        graph.set_output_node(input1.clone()).unwrap();
        graph.finalize().unwrap();
        for (i, node) in graph.get_nodes().iter().enumerate() {
            assert_eq!(node.get_id(), i as u64);
            assert!(graph == node.get_graph());
            for dependency in node.get_node_dependencies() {
                assert!(dependency.get_id() < node.get_id());
            }
        }
        // Expected sequence: Input, Input, Add.
        let operations: Vec<Operation> = graph
            .get_nodes()
            .iter()
            .map(|x| x.get_operation())
            .collect();
        assert!(operations.len() == 3);
        if !operations[0].is_input() {
            panic!("Input expected");
        }
        if !operations[1].is_input() {
            panic!("Input expected");
        }
        match operations[2] {
            Operation::Add => {}
            _ => {
                panic!("Add expected");
            }
        }
    }
5096
    /// Checks that a context cannot be finalized while any of its graphs is unfinalized.
    ///
    /// This holds even once a finalized main graph has been set.
    #[test]
    fn test_unfinalized_graphs() {
        let context = create_unchecked_context().unwrap();
        let e = context.finalize();
        assert!(e.is_err());
        let graph = context.create_graph().unwrap();
        let graph2 = context.create_graph().unwrap();
        let e = context.finalize();
        assert!(e.is_err());
        let i = graph2.input(scalar_type(BIT)).unwrap();
        graph2.set_output_node(i).unwrap();
        graph2.finalize().unwrap();
        context.set_main_graph(graph2).unwrap();
        // `graph` is still unfinalized, so the context cannot be finalized yet.
        let e = context.finalize();
        assert!(e.is_err());
        let ii = graph.input(scalar_type(BIT)).unwrap();
        graph.set_output_node(ii).unwrap();
        graph.finalize().unwrap();
        context.finalize().unwrap();
    }
5117
    /// Pins the exact JSON encoding of an `Operation::Constant`, including its embedded
    /// versioned value, and checks that it round-trips.
    #[test]
    fn test_operation_serialization() {
        let o = Operation::Constant(scalar_type(BIT), Value::from_bytes(vec![1]));
        let se = serde_json::to_string(&o).unwrap();
        assert_eq!(
            se,
            format!("{{\"Constant\":[{{\"Scalar\":\"bit\"}},{{\"version\":{},\"data\":\"{{\\\"body\\\":{{\\\"Bytes\\\":[1]}}}}\"}}]}}", DATA_VERSION)
        );
        let de = serde_json::from_str::<Operation>(&se).unwrap();
        assert_eq!(de, o);
    }
5129
    /// Returns 15 deterministic context builders.
    ///
    /// Calling the same builder twice yields deeply-equal contexts, while any two different
    /// builders yield contexts that differ in at least one compared property. The exact
    /// construction order matters because node and graph ids are positional.
    fn context_generators() -> Vec<Box<dyn Fn() -> Context>> {
        // Empty context.
        let context1 = || {
            let context = create_unchecked_context().unwrap();
            context
        };
        // One finalized identity graph set as main; context finalized.
        let context2 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph).unwrap();
            context.finalize().unwrap();
            context
        };
        // One empty, unfinalized graph.
        let context3 = || {
            let context = create_unchecked_context().unwrap();
            context.create_graph().unwrap();
            context
        };
        // One finalized identity graph; no main graph, context not finalized.
        let context4 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context
        };
        // One unfinalized graph containing a single input.
        let context5 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            graph.input(scalar_type(BIT)).unwrap();
            context
        };
        // One unfinalized graph containing a single constant.
        let context6 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            graph
                .constant(scalar_type(BIT), Value::from_bytes(vec![1]))
                .unwrap();
            context
        };
        // add(i1, i2).
        let context7 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i1 = graph.input(scalar_type(BIT)).unwrap();
            let i2 = graph.input(scalar_type(BIT)).unwrap();
            graph.add(i1, i2).unwrap();
            context
        };
        // add(i2, i1): differs from context7 only in dependency order.
        let context8 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i1 = graph.input(scalar_type(BIT)).unwrap();
            let i2 = graph.input(scalar_type(BIT)).unwrap();
            graph.add(i2, i1).unwrap();
            context
        };
        // Three graphs; graph3 calls graph1.
        let context9 = || {
            let context = create_unchecked_context().unwrap();
            let graph1 = context.create_graph().unwrap();
            let i1 = graph1.input(scalar_type(BIT)).unwrap();
            graph1.set_output_node(i1).unwrap();
            graph1.finalize().unwrap();
            let graph2 = context.create_graph().unwrap();
            let i2 = graph2.input(scalar_type(BIT)).unwrap();
            graph2.set_output_node(i2).unwrap();
            graph2.finalize().unwrap();
            let graph3 = context.create_graph().unwrap();
            let i = graph3.input(scalar_type(BIT)).unwrap();
            graph3.call(graph1, vec![i]).unwrap();
            context
        };
        // Same as context9 but graph3 calls graph2 (different graph dependency).
        let context10 = || {
            let context = create_unchecked_context().unwrap();
            let graph1 = context.create_graph().unwrap();
            let i1 = graph1.input(scalar_type(BIT)).unwrap();
            graph1.set_output_node(i1).unwrap();
            graph1.finalize().unwrap();
            let graph2 = context.create_graph().unwrap();
            let i2 = graph2.input(scalar_type(BIT)).unwrap();
            graph2.set_output_node(i2).unwrap();
            graph2.finalize().unwrap();
            let graph3 = context.create_graph().unwrap();
            let i = graph3.input(scalar_type(BIT)).unwrap();
            graph3.call(graph2, vec![i]).unwrap();
            context
        };
        // Same as context10 but the call is also graph3's output node.
        let context11 = || {
            let context = create_unchecked_context().unwrap();
            let graph1 = context.create_graph().unwrap();
            let i1 = graph1.input(scalar_type(BIT)).unwrap();
            graph1.set_output_node(i1).unwrap();
            graph1.finalize().unwrap();
            let graph2 = context.create_graph().unwrap();
            let i2 = graph2.input(scalar_type(BIT)).unwrap();
            graph2.set_output_node(i2).unwrap();
            graph2.finalize().unwrap();
            let graph3 = context.create_graph().unwrap();
            let i = graph3.input(scalar_type(BIT)).unwrap();
            let o = graph3.call(graph2, vec![i]).unwrap();
            graph3.set_output_node(o).unwrap();
            context
        };
        // Same as context2 but the context is not finalized.
        let context12 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph).unwrap();
            context
        };
        // Same as context2 plus the graph name "main".
        let context13 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph.clone()).unwrap();
            context.set_graph_name(graph, "main").unwrap();
            context.finalize().unwrap();
            context
        };
        // Same as context13 plus a node name and one graph and one node annotation.
        let context14 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i.clone()).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph.clone()).unwrap();
            context.set_graph_name(graph.clone(), "main").unwrap();
            context.set_node_name(i.clone(), "input").unwrap();
            context
                .add_graph_annotation(&graph, GraphAnnotation::AssociativeOperation)
                .unwrap();
            context
                .add_node_annotation(&i, NodeAnnotation::AssociativeOperation)
                .unwrap();
            context.finalize().unwrap();
            context
        };
        // Chain of 20 inputs summed left to right; inputs 1..=19 are named "input_{i}".
        let context15 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let mut x = graph.input(scalar_type(BIT)).unwrap();
            for i in 1..20 {
                let y = graph.input(scalar_type(BIT)).unwrap();
                y.set_name(format!("input_{}", i).as_str()).unwrap();
                x = graph.add(x, y).unwrap();
            }
            graph.set_output_node(x).unwrap();
            graph.finalize().unwrap();
            context
        };
        let mut closures: Vec<Box<dyn Fn() -> Context>> = vec![];
        closures.push(Box::new(context1));
        closures.push(Box::new(context2));
        closures.push(Box::new(context3));
        closures.push(Box::new(context4));
        closures.push(Box::new(context5));
        closures.push(Box::new(context6));
        closures.push(Box::new(context7));
        closures.push(Box::new(context8));
        closures.push(Box::new(context9));
        closures.push(Box::new(context10));
        closures.push(Box::new(context11));
        closures.push(Box::new(context12));
        closures.push(Box::new(context13));
        closures.push(Box::new(context14));
        closures.push(Box::new(context15));
        closures
    }
5303
5304 fn test_context_deep_equal_helper_equal<F>(f: F)
5305 where
5306 F: Fn() -> Context,
5307 {
5308 let context1 = f();
5309 let context2 = f();
5310 assert!(context1 != context2);
5311 assert!(contexts_deep_equal(context1, context2));
5312 }
5313
5314 fn test_context_deep_equal_helper_nonequal<F1, F2>(f1: F1, f2: F2)
5315 where
5316 F1: Fn() -> Context,
5317 F2: Fn() -> Context,
5318 {
5319 let context1 = f1();
5320 let context2 = f2();
5321 assert!(context1 != context2);
5322 assert!(!contexts_deep_equal(context1, context2));
5323 }
5324
5325 #[test]
5326 fn test_context_deep_equal() {
5327 let generators = context_generators();
5328 for i in 0..generators.len() {
5329 test_context_deep_equal_helper_equal(&generators[i]);
5330 for j in 0..i {
5331 test_context_deep_equal_helper_nonequal(&generators[i], &generators[j]);
5332 }
5333 }
5334 }
5335
5336 pub fn deserialize_error_lenient(serialized_string: &str, error_msg: &str) {
5337 use std::panic::catch_unwind;
5338 panic::set_hook(Box::new(|_info| {
5339 // See: https://stackoverflow.com/questions/35559267/suppress-panic-output-in-rust-when-using-paniccatch-unwind
5340 }));
5341 let result = catch_unwind(|| serde_json::from_str::<Context>(serialized_string).unwrap());
5342 // This is a (nasty) hack.
5343 // We check whether the returned error contain the expected error message.
5344 use ciphercore_utils::execute_main::extract_panic_message;
5345 if let Err(e) = result {
5346 match extract_panic_message(e) {
5347 Some(msg) => {
5348 if !msg.contains(error_msg) {
5349 panic!("Undesirable panic: {}", msg);
5350 }
5351 }
5352 None => panic!("Panic of unknown type"),
5353 }
5354 } else {
5355 panic!("Expected error not occur")
5356 }
5357 }
5358
5359 use std::{
5360 fs::File,
5361 io::{prelude::*, BufReader},
5362 path::Path,
5363 };
5364
5365 fn lines_from_file(filename: impl AsRef<Path>) -> Vec<String> {
5366 let file = File::open(filename).expect("no such file");
5367 let buf = BufReader::new(file);
5368 buf.lines()
5369 .map(|l| l.expect("Could not parse line"))
5370 .collect()
5371 }
5372
5373 #[test]
5374 fn test_context_serialize() {
5375 let generators = context_generators();
5376 let contexts: Vec<Context> = generators.iter().map(|generator| generator()).collect();
5377 let serialized_contexts: Vec<String> = contexts
5378 .iter()
5379 .map(|context| serde_json::to_string(context).unwrap())
5380 .collect();
5381 let deserialized_contexts: Vec<Context> = serialized_contexts
5382 .iter()
5383 .map(|serialized_context| serde_json::from_str(serialized_context).unwrap())
5384 .collect();
5385 assert_eq!(contexts.len(), deserialized_contexts.len());
5386 for i in 0..contexts.len() {
5387 assert!(contexts[i] != deserialized_contexts[i]);
5388 assert!(contexts_deep_equal(
5389 contexts[i].clone(),
5390 deserialized_contexts[i].clone()
5391 ));
5392 assert_eq!(
5393 serialized_contexts[i],
5394 serde_json::to_string(&deserialized_contexts[i]).unwrap()
5395 )
5396 }
5397
5398 //Read test cases from golden file
5399 let test_case = lines_from_file("./src/test_data/version_testcase.txt");
5400 assert_eq!(serde_json::to_string(&contexts[0]).unwrap(), test_case[0]);
5401
5402 //Following test case expect an error message "Non-existent main graph" which is caused by the field "\"main_graph\":918276318"
5403 deserialize_error_lenient(&test_case[1], "Non-existent main graph");
5404 assert_eq!(serde_json::to_string(&contexts[9]).unwrap(), test_case[2]);
5405 //Following test case expect an error message "Non-existent node dependency" which is caused by the field "\"node_dependencies\":[918723]"
5406 deserialize_error_lenient(&test_case[3], "Non-existent node dependency");
5407 //Following test case expect an error message "Non-existent graph dependency" which is caused by the field "\"graph_dependencies\":[918723]"
5408 deserialize_error_lenient(&test_case[4], "Non-existent graph dependency");
5409 assert_eq!(serde_json::to_string(&contexts[13]).unwrap(), test_case[5]);
5410 //Following test case expect an error message "Non-existent output node" which is caused by the field "\"output_node\":9817273"
5411 deserialize_error_lenient(&test_case[6], "Non-existent output node");
5412 //Following test case expect an error message "graphs_names contain an invalid ID" which is caused by the field "\"graphs_names\":[[8079123,\"main\"]]"
5413 deserialize_error_lenient(&test_case[7], "graphs_names contain an invalid ID");
5414 //Following test case expect an error message "nodes_names contain an invalid graph ID" which is caused by the field "\"nodes_names\":[[[8079123,0],\"input\"]]"
5415 deserialize_error_lenient(&test_case[8], "nodes_names contain an invalid graph ID");
5416 //Following test case expect an error message "nodes_names contain an invalid graph ID" which is caused by the field "\"nodes_names\":[[[0,8079123],\"input\"]]"
5417 deserialize_error_lenient(&test_case[9], "nodes_names contain an invalid node ID");
5418 //Following test case expect an error message "Context version doesn't match the requirement" which is caused by its old version number
5419 deserialize_error_lenient(
5420 &test_case[10],
5421 "Context version doesn't match the requirement",
5422 );
5423 //Following test case expect an error message "Context version doesn't match the requirement" which is caused by its old version number. Although its payload is unsupported, this should not cause any error before passing the version check.
5424 deserialize_error_lenient(
5425 &test_case[11],
5426 "Context version doesn't match the requirement",
5427 );
5428 }
5429
5430 use crate::data_types::INT32;
5431 use crate::data_values::Value;
5432 use crate::evaluators::random_evaluate;
5433 use std::iter::FromIterator;
5434
5435 #[test]
5436 fn test_named_contexts() {
5437 let helper = || -> Result<Context> {
5438 let context = create_context()?;
5439 let graph = context.create_graph()?;
5440 let input_a = graph.input(scalar_type(INT32))?;
5441 let input_b = graph.input(scalar_type(INT32))?;
5442 let output = graph.add(input_a.clone(), input_b.clone())?;
5443 graph.set_output_node(output.clone())?;
5444 graph.finalize()?;
5445 context.set_main_graph(graph.clone())?;
5446 assert!(context.get_graph_name(graph.clone()).is_err());
5447 assert!(context.retrieve_graph("main").is_err());
5448 assert!(context.get_node_name(input_a.clone())?.is_none());
5449 assert!(context.retrieve_node(graph.clone(), "a").is_err());
5450 context.set_graph_name(graph.clone(), "main")?;
5451 context.set_node_name(input_a.clone(), "a")?;
5452 assert!(context.retrieve_node(graph.clone(), "b").is_err());
5453 context.set_node_name(input_b.clone(), "b")?;
5454 context.finalize()?;
5455 assert_eq!(context.get_graph_name(graph.clone())?, "main");
5456 assert_eq!(
5457 context.get_node_name(input_a.clone())?,
5458 Some("a".to_owned())
5459 );
5460 assert_eq!(
5461 context.get_node_name(input_b.clone())?,
5462 Some("b".to_owned())
5463 );
5464 assert!(context.retrieve_node(graph.clone(), "a")? == input_a.clone());
5465 Ok(context)
5466 };
5467 let context = helper().unwrap();
5468 let helper2 = |context: Context| -> Result<i32> {
5469 let other_context = create_context()?;
5470 let other_graph = other_context.create_graph()?;
5471 let input = other_graph.input(scalar_type(BIT))?;
5472 let other_input = other_graph.input(scalar_type(BIT))?;
5473 assert!(context
5474 .prepare_input_values::<Value>(other_graph.clone(), HashMap::new())
5475 .is_err());
5476 assert!(other_context
5477 .prepare_input_values::<Value>(
5478 other_graph.clone(),
5479 HashMap::from_iter([("a", Value::from_scalar(123, INT32)?)])
5480 )
5481 .is_err());
5482 other_context.set_node_name(input, "b")?;
5483 assert!(other_context
5484 .prepare_input_values::<Value>(
5485 other_graph.clone(),
5486 HashMap::from_iter([("a", Value::from_scalar(123, INT32)?)])
5487 )
5488 .is_err());
5489 assert!(other_context
5490 .prepare_input_values::<Value>(
5491 other_graph.clone(),
5492 HashMap::from_iter([("b", Value::from_scalar(123, INT32)?)])
5493 )
5494 .is_err());
5495 other_context.set_node_name(other_input, "c")?;
5496 assert!(other_context
5497 .prepare_input_values::<Value>(
5498 other_graph,
5499 HashMap::from_iter([("b", Value::from_scalar(123, INT32)?)])
5500 )
5501 .is_err());
5502 let g = context.retrieve_graph("main")?;
5503 let result = random_evaluate(
5504 g.clone(),
5505 context.prepare_input_values(
5506 g.clone(),
5507 HashMap::from_iter([
5508 ("a", Value::from_scalar(123, INT32)?),
5509 ("b", Value::from_scalar(456, INT32)?),
5510 ]),
5511 )?,
5512 )?;
5513 let result = result.to_i32(INT32)?;
5514 Ok(result)
5515 };
5516 assert_eq!(helper2(context).unwrap(), 579);
5517 let helper3 = |context: Context| -> Result<()> {
5518 let other_context = create_context()?;
5519 let other_graph = other_context.create_graph()?;
5520 let other_node = other_graph.input(scalar_type(BIT))?;
5521 assert!(context
5522 .set_graph_name(other_graph.clone(), "outside")
5523 .is_err());
5524 assert!(context.get_graph_name(other_graph.clone()).is_err());
5525 assert!(context
5526 .set_node_name(other_node.clone(), "outside")
5527 .is_err());
5528 assert!(context.get_node_name(other_node.clone()).is_err());
5529 assert!(context.retrieve_node(other_graph.clone(), "a").is_err());
5530 Ok(())
5531 };
5532 helper3(helper().unwrap()).unwrap();
5533 let helper4 = || -> Result<()> {
5534 let context = create_context()?;
5535 let graph = context.create_graph()?;
5536 let input = graph.input(scalar_type(BIT))?;
5537 graph.set_output_node(input.clone())?;
5538 graph.finalize()?;
5539 context.set_main_graph(graph.clone())?;
5540 context.finalize()?;
5541 assert!(context.set_graph_name(graph, "main").is_err());
5542 assert!(context.set_node_name(input, "input").is_err());
5543 Ok(())
5544 };
5545 helper4().unwrap();
5546 let helper5 = || -> Result<()> {
5547 let context = create_context()?;
5548 let graph = context.create_graph()?;
5549 let input = graph.input(scalar_type(BIT))?;
5550 let other_graph = context.create_graph()?;
5551 let other_input = graph.input(scalar_type(BIT))?;
5552 context.set_graph_name(graph.clone(), "main")?;
5553 assert!(context.set_graph_name(graph, "main3").is_err());
5554 assert!(context.set_graph_name(other_graph, "main").is_err());
5555 context.set_node_name(input.clone(), "input")?;
5556 assert!(context.set_node_name(input, "input3").is_err());
5557 assert!(context.set_node_name(other_input, "input").is_err());
5558 Ok(())
5559 };
5560 helper5().unwrap();
5561 }
5562
5563 #[test]
5564 fn test_context_type_checking() {
5565 || -> Result<()> {
5566 let context = create_context()?;
5567 let g = context.create_graph()?;
5568 let i = g.input(tuple_type(vec![]))?;
5569 assert!(g.add(i.clone(), i.clone()).is_err());
5570 // Now checking that the node actually have not gotten added by accident
5571 assert_eq!(g.get_nodes().len(), 1);
5572 Ok(())
5573 }()
5574 .unwrap();
5575 }
5576
5577 fn generate_pair_of_equal_contexts() -> Vec<(Context, Context)> {
5578 let context1 = || -> Result<Context> {
5579 let context = create_unchecked_context()?;
5580 let g = context.create_graph()?;
5581 let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
5582 g.set_output_node(i)?;
5583 g.finalize()?;
5584 g.set_as_main()?;
5585 Ok(context)
5586 }()
5587 .unwrap();
5588 let context2 = || -> Result<Context> {
5589 let context = create_unchecked_context()?;
5590 let g = context.create_graph()?;
5591 let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
5592 i.set_as_output()?;
5593 g.finalize()?;
5594 context.set_main_graph(g)?;
5595 Ok(context)
5596 }()
5597 .unwrap();
5598 let context3 = || -> Result<Context> {
5599 let context = create_unchecked_context()?;
5600 let g = context.create_graph()?;
5601 context.set_graph_name(g, "random graph name")?;
5602 Ok(context)
5603 }()
5604 .unwrap();
5605 let context4 = || -> Result<Context> {
5606 let context = create_unchecked_context()?;
5607 context.create_graph()?.set_name("random graph name")?;
5608 Ok(context)
5609 }()
5610 .unwrap();
5611 let context5 = || -> Result<Context> {
5612 let context = create_unchecked_context()?;
5613 let g = context.create_graph()?;
5614 let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
5615 context.set_node_name(i, "random node name")?;
5616 Ok(context)
5617 }()
5618 .unwrap();
5619 let context6 = || -> Result<Context> {
5620 let context = create_unchecked_context()?;
5621 let g = context.create_graph()?;
5622 g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?
5623 .set_name("random node name")?;
5624 Ok(context)
5625 }()
5626 .unwrap();
5627 let context7 = || -> Result<Context> {
5628 let context = create_unchecked_context()?;
5629 let g = context.create_graph()?;
5630 let i1 = g.input(scalar_type(BIT))?;
5631 let i2 = g.input(scalar_type(BIT))?;
5632 g.add(i1.clone(), i2.clone())?;
5633 g.subtract(i1.clone(), i2.clone())?;
5634 g.multiply(i1.clone(), i2.clone())?;
5635 g.dot(i1.clone(), i2.clone())?;
5636 g.matmul(i1.clone(), i2.clone())?;
5637 g.truncate(i1.clone(), 123)?;
5638 g.sum(i1.clone(), vec![1, 4, 7])?;
5639 g.permute_axes(i1.clone(), vec![1, 4, 7])?;
5640 g.get(i1.clone(), vec![1, 4])?;
5641 g.reshape(i1.clone(), array_type(vec![12, 34], BIT))?;
5642 g.nop(i1.clone())?;
5643 g.prf(i1.clone(), 123, scalar_type(BIT))?;
5644 g.a2b(i1.clone())?;
5645 g.b2a(i1.clone(), BIT)?;
5646 g.tuple_get(i1.clone(), 0)?;
5647 g.named_tuple_get(i1.clone(), "field name".to_owned())?;
5648 g.vector_get(i1.clone(), i2)?;
5649 g.array_to_vector(i1.clone())?;
5650 g.vector_to_array(i1.clone())?;
5651 g.repeat(i1.clone(), 123)?;
5652 Ok(context)
5653 }()
5654 .unwrap();
5655 let context8 = || -> Result<Context> {
5656 let context = create_unchecked_context()?;
5657 let g = context.create_graph()?;
5658 let i1 = g.input(scalar_type(BIT))?;
5659 let i2 = g.input(scalar_type(BIT))?;
5660 i1.add(i2.clone())?;
5661 i1.subtract(i2.clone())?;
5662 i1.multiply(i2.clone())?;
5663 i1.dot(i2.clone())?;
5664 i1.matmul(i2.clone())?;
5665 i1.truncate(123)?;
5666 i1.sum(vec![1, 4, 7])?;
5667 i1.permute_axes(vec![1, 4, 7])?;
5668 i1.get(vec![1, 4])?;
5669 i1.reshape(array_type(vec![12, 34], BIT))?;
5670 i1.nop()?;
5671 i1.prf(123, scalar_type(BIT))?;
5672 i1.a2b()?;
5673 i1.b2a(BIT)?;
5674 i1.tuple_get(0)?;
5675 i1.named_tuple_get("field name".to_owned())?;
5676 i1.vector_get(i2)?;
5677 i1.array_to_vector()?;
5678 i1.vector_to_array()?;
5679 i1.repeat(123)?;
5680 Ok(context)
5681 }()
5682 .unwrap();
5683 let result = vec![
5684 (context1, context2),
5685 (context3, context4),
5686 (context5, context6),
5687 (context7, context8),
5688 ];
5689 result
5690 }
5691
5692 #[test]
5693 fn test_node_graph_helpers() {
5694 let pairs_of_contexts = generate_pair_of_equal_contexts();
5695 for (context1, context2) in pairs_of_contexts {
5696 assert!(contexts_deep_equal(context1, context2));
5697 }
5698 || -> Result<()> {
5699 let context = create_context()?;
5700 let g = context.create_graph()?.set_name("graph name")?;
5701 let i = g.input(scalar_type(BIT))?.set_name("node name")?;
5702 assert_eq!(g.get_name()?, "graph name");
5703 assert!(g.retrieve_node("node name")? == i);
5704 assert_eq!(i.get_name()?, Some("node name".to_owned()));
5705 assert_eq!(
5706 g.prepare_input_values(hashmap!("node name" => Value::from_scalar(1, BIT)?))?,
5707 vec![Value::from_scalar(1, BIT)?]
5708 );
5709 Ok(())
5710 }()
5711 .unwrap();
5712 }
5713
5714 #[test]
5715 fn test_operation_fmt_display() {
5716 let test_operation_fmt_display_helper = || -> Result<()> {
5717 let o0 = Rc::new(Operation::Input(scalar_type(UINT16)));
5718 assert_eq!(format!("{}", o0), "Input");
5719 let o1 = Rc::new(Operation::Add);
5720 assert_eq!(format!("{}", o1), "Add");
5721 let o2 = Rc::new(Operation::Truncate(10));
5722 assert_eq!(format!("{}", o2), "Truncate");
5723 let o3 = Rc::new(Operation::Get(vec![10, 20]));
5724 assert_eq!(format!("{}", o3), "Get");
5725 let o4 = Rc::new(Operation::NOP);
5726 assert_eq!(format!("{}", o4), "NOP");
5727 let o5 = Rc::new(Operation::CreateNamedTuple(vec![
5728 "Name".to_string(),
5729 "Address".to_string(),
5730 ]));
5731 assert_eq!(format!("{}", o5), "CreateNamedTuple");
5732 let o6 = Rc::new(Operation::NamedTupleGet("Name".to_string()));
5733 assert_eq!(format!("{}", o6), "NamedTupleGet");
5734 Ok(())
5735 };
5736 test_operation_fmt_display_helper().unwrap();
5737 }
5738
5739 #[test]
5740 fn test_annotations() {
5741 let test_annotations_helper = || -> Result<()> {
5742 let context = create_context()?;
5743 let g = context.create_graph()?;
5744 let i = g.input(scalar_type(BIT))?;
5745 g.add_annotation(GraphAnnotation::AssociativeOperation)?;
5746 i.add_annotation(NodeAnnotation::AssociativeOperation)?;
5747 assert_eq!(
5748 g.get_annotations()?,
5749 vec![GraphAnnotation::AssociativeOperation]
5750 );
5751 assert_eq!(
5752 i.get_annotations()?,
5753 vec![NodeAnnotation::AssociativeOperation]
5754 );
5755 Ok(())
5756 };
5757 test_annotations_helper().unwrap();
5758 }
5759
5760 async fn parallel_get_type(output: Node) -> Result<Type> {
5761 output.get_type()
5762 }
5763
5764 async fn parallel_random_evaluate(graph: Graph, context: Context) -> Result<Value> {
5765 random_evaluate(
5766 graph.clone(),
5767 context.prepare_input_values(
5768 graph,
5769 HashMap::from_iter([
5770 ("one", Value::from_scalar(123, INT32)?),
5771 ("two", Value::from_scalar(456, INT32)?),
5772 ]),
5773 )?,
5774 )
5775 }
5776
5777 async fn parallel_prepare_for_mpc_evaluation(
5778 context: Context,
5779 input_party_map: Vec<Vec<IOStatus>>,
5780 output_parties: Vec<Vec<IOStatus>>,
5781 inline_config: InlineConfig,
5782 ) -> Result<Context> {
5783 prepare_for_mpc_evaluation(context, input_party_map, output_parties, inline_config)
5784 }
5785
5786 #[tokio::test(flavor = "multi_thread", worker_threads = 50)]
5787 async fn test_parallel_after_finalize() -> Result<()> {
5788 let context = create_context()?;
5789 let graph = context.create_graph()?;
5790 let input1 = graph.input(scalar_type(INT32))?;
5791 let input2 = graph.input(scalar_type(INT32))?;
5792 let output = graph.add(input1.clone(), input2.clone())?;
5793 graph.set_output_node(output.clone())?;
5794 graph.finalize()?;
5795 context.set_main_graph(graph.clone())?;
5796
5797 context.set_node_name(input1.clone(), "one")?;
5798 context.set_node_name(input2.clone(), "two")?;
5799
5800 context.finalize()?;
5801
5802 assert!(output.clone().get_type().is_ok());
5803
5804 const PAR_ITERS: usize = 2001;
5805
5806 let mut get_type_futures = vec![];
5807 for _ in 0..PAR_ITERS {
5808 get_type_futures.push(parallel_get_type(output.clone()));
5809 }
5810 futures::future::try_join_all(get_type_futures).await?;
5811
5812 let mut get_random_evaluate_futures = vec![];
5813 for _ in 0..PAR_ITERS {
5814 get_random_evaluate_futures
5815 .push(parallel_random_evaluate(graph.clone(), context.clone()))
5816 }
5817 futures::future::try_join_all(get_random_evaluate_futures).await?;
5818
5819 let input_parties = vec![IOStatus::Party(0), IOStatus::Party(1)];
5820 let output_parties = vec![IOStatus::Party(0)];
5821
5822 let mut get_mpc_eval_futures = vec![];
5823 for _ in 0..PAR_ITERS {
5824 get_mpc_eval_futures.push(parallel_prepare_for_mpc_evaluation(
5825 context.clone(),
5826 vec![input_parties.clone()],
5827 vec![output_parties.clone()],
5828 InlineConfig::default(),
5829 ));
5830 }
5831 futures::future::try_join_all(get_mpc_eval_futures).await?;
5832
5833 Ok(())
5834 }
5835}