ciphercore_base/
graphs.rs

1//! Crucial structs, enums, functions and types to create computation graphs.
2use atomic_refcell::AtomicRefCell;
3use std::collections::HashMap;
4use std::fmt;
5use std::hash::Hash;
6use std::hash::Hasher;
7use std::ptr;
8use std::sync::Arc;
9use std::sync::Weak;
10
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13use crate::constants::type_size_limit_constants;
14use crate::custom_ops::CustomOperation;
15use crate::data_types::{get_size_estimation_in_bits, ArrayShape, ScalarType, Type};
16use crate::data_values::Value;
17use crate::errors::Result;
18use crate::type_inference::{create_type_inference_worker, TypeInferenceWorker};
19
20use crate::version::{VersionedData, DATA_VERSION};
21
22#[cfg(feature = "py-binding")]
23use crate::custom_ops::PyBindingCustomOperation;
24#[cfg(feature = "py-binding")]
25use crate::data_types::{PyBindingScalarType, PyBindingType};
26#[cfg(feature = "py-binding")]
27use crate::typed_value::PyBindingTypedValue;
28#[cfg(feature = "py-binding")]
29use pywrapper_macro::{enum_to_struct_wrapper, fn_wrapper, impl_wrapper, struct_wrapper};
30
31/// This enum represents different types of slice elements that are used to create indexing slices (see [Slice] and [Graph::get_slice]).
32///
33/// The semantics is similar to [the NumPy slice indexing](https://numpy.org/doc/stable/user/basics.indexing.html).
34#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
35#[cfg_attr(feature = "py-binding", enum_to_struct_wrapper)]
36pub enum SliceElement {
37    /// Single index of a given array dimension.
38    ///
39    /// The index is given by a signed integer. If negative, the index is interpreted as in [NumPy](https://numpy.org/doc/stable/user/basics.indexing.html).
40    ///
41    /// For example, to choose all the elements of an array with the last index in the first dimension and the first index in the second dimension, one can use a slice `vec![SingleIndex(-1), SingleIndex(0)]`.
42    SingleIndex(i64),
43    /// Sub-array denotes a range of indices of a given array dimension.
44    ///
45    /// It follows the description of [the NumPy basic slice](https://numpy.org/doc/stable/user/basics.indexing.html), which is defined by 3 signed integers: `start`, `stop`, `step`. `step` can't be equal to zero.
46    ///
47    /// For example, to choose all the elements of an array with even indices in the first dimension, one can use a slice `vec![SubArray(Some(0), None, Some(2))].
48    SubArray(Option<i64>, Option<i64>, Option<i64>),
49    /// Ellipsis denotes several dimensions where indices are not restricted.
50    ///
51    /// For example, to choose all the elements of an array with index `0` in the first dimension and index `2` in the last dimension, one can use a slice `vec![SingleIndex(0), Ellipsis, SingleIndex(2)]`.
52    Ellipsis,
53}
54
55/// Slice type denotes an indexing slice (see [NumPy slicing](https://numpy.org/doc/stable/user/basics.indexing.html)).
56///
57/// It is a vector of slice elements that describes the indices of a sub-array in any appropriate array.
58pub type Slice = Vec<SliceElement>;
59
/// Kind of join computed by the `Join` and `JoinWithColumnMasks` operations
/// (see [Node::join] and [Node::join_with_column_masks]).
///
/// NOTE(review): the row-level semantics of each variant are implemented outside this file;
/// the per-variant notes below follow the usual SQL meaning of the names and should be
/// confirmed against the type-inference/evaluation code.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Copy)]
#[cfg_attr(feature = "py-binding", enum_to_struct_wrapper)]
pub enum JoinType {
    /// Presumably an inner join: only rows whose key columns match in both inputs (TODO confirm).
    Inner,
    /// Presumably a left (outer) join: every row of the first input is kept (TODO confirm).
    Left,
    /// Union-style join; its exact semantics are not evident from this file (TODO document).
    Union,
    /// Presumably a full outer join: rows of both inputs are kept (TODO confirm).
    Full,
}
68
/// Python-side constructors for [JoinType], one per variant, so every join kind can be
/// created from Python.
#[doc(hidden)]
#[cfg(feature = "py-binding")]
#[pyo3::pymethods]
impl PyBindingJoinType {
    /// Returns a wrapper around [JoinType::Inner].
    #[staticmethod]
    pub fn from_inner() -> Self {
        PyBindingJoinType {
            inner: JoinType::Inner,
        }
    }
    /// Returns a wrapper around [JoinType::Left].
    #[staticmethod]
    pub fn from_left() -> Self {
        PyBindingJoinType {
            inner: JoinType::Left,
        }
    }
    /// Returns a wrapper around [JoinType::Union].
    #[staticmethod]
    pub fn from_union() -> Self {
        PyBindingJoinType {
            inner: JoinType::Union,
        }
    }
    /// Returns a wrapper around [JoinType::Full].
    ///
    /// Without this constructor `JoinType::Full` could not be built from Python.
    #[staticmethod]
    pub fn from_full() -> Self {
        PyBindingJoinType {
            inner: JoinType::Full,
        }
    }
}
92
/// Shard config contains the parameters of the Sharding operation, namely:
///
/// - number of shards into which input dataset will be split,
/// - size of each shard, i.e., the number of rows in each shard,
/// - headers of columns whose rows are hashed to find the index of a shard where the corresponding row will be placed.
///
/// NOTE: field order is part of the encoding for positional serde formats; keep it stable.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "py-binding", struct_wrapper)]
pub struct ShardConfig {
    /// Number of shards into which the input dataset is split.
    pub num_shards: u64,
    /// Size of each shard, i.e., the number of rows in each shard.
    pub shard_size: u64,
    /// headers of columns whose rows are hashed to find the index of a shard where the corresponding row will be placed
    pub shard_headers: Vec<String>,
}
106
/// Operation performed by a graph node (see [Node::get_operation]).
///
/// NOTE: variant order is part of the encoding for positional serde formats; append new
/// variants instead of reordering.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum Operation {
    Input(Type),
    Zeros(Type),
    Ones(Type),
    Add,
    Subtract,
    Multiply,
    // Elementwise multiplication of integer arrays by bit arrays.
    // It either leaves an integer array element as is or makes it zero, depending on the corresponding bit array element.
    MixedMultiply,
    // Dot operation follows the numpy (tensor)dot semantics: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
    Dot,
    // Matmul operation follows the numpy matmul semantics: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html
    // In particular, unlike Dot, it doesn't support scalar inputs.
    Matmul,
    // General matrix product; the flags correspond to `transpose_a`/`transpose_b` of `Node::gemm` (presumably in that order — confirm in `Graph::gemm`).
    Gemm(bool, bool),
    Truncate(u128),
    Sum(ArrayShape),
    CumSum(u64),
    PermuteAxes(ArrayShape),
    Get(ArrayShape),
    GetSlice(Slice),
    Reshape(Type),
    NOP,
    Random(Type),
    // The `u64` is the PRF id, which `Operation::update_prf_id` rewrites.
    PRF(u64, Type),
    // The first `u64` is the PRF id (rewritten by `Operation::update_prf_id`); the second is the permutation size.
    PermutationFromPRF(u64, u64),
    Stack(ArrayShape),
    Concatenate(u64),
    Constant(Type, Value),
    A2B,
    B2A(ScalarType),
    CreateTuple,
    CreateNamedTuple(Vec<String>),
    CreateVector(Type),
    TupleGet(u64),
    NamedTupleGet(String),
    VectorGet,
    Zip,
    Repeat(u64),
    // `Call` and `Iterate` are the operations that depend on other graphs (see `Node::get_graph_dependencies`).
    Call,
    Iterate,
    ArrayToVector,
    VectorToArray,
    // Operations that can't be compiled to MPC protocols.
    // NOTE(review): `Join`, `JoinWithColumnMasks`, `ApplyPermutation` and `Sort` further below are
    // nevertheless listed by `Operation::is_mpc_compiled`; this heading appears to hold only up to
    // `Shard` — confirm.
    RandomPermutation(u64),
    Gather(u64),
    CuckooHash,
    InversePermutation,
    CuckooToPermutation,
    DecomposeSwitchingMap(u64),
    SegmentCumSum,
    Shard(ShardConfig),
    // SQL joins
    Join(JoinType, HashMap<String, String>),
    JoinWithColumnMasks(JoinType, HashMap<String, String>),
    ApplyPermutation(bool),
    Sort(String),
    Custom(CustomOperation),
    // Operations used for debugging graphs.
    Print(String),
    Assert(String),
}
172
173impl fmt::Display for Operation {
174    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
175        let operation_name = if let Operation::Custom(custom_op) = self {
176            custom_op.get_name()
177        } else {
178            let operation_w_type_str = format!("{:?}", *self);
179            let split_for_operation = operation_w_type_str.split('(');
180            let vec_operation_and_types: Vec<&str> = split_for_operation.collect();
181            if vec_operation_and_types.is_empty() {
182                "-null-".to_owned()
183            } else {
184                vec_operation_and_types[0].to_owned()
185            }
186        };
187        write!(f, "{operation_name}")
188    }
189}
190
191impl Operation {
192    pub fn is_prf_operation(&self) -> bool {
193        matches!(
194            self,
195            Operation::PRF(_, _) | Operation::PermutationFromPRF(_, _)
196        )
197    }
198
199    pub fn is_broadcasting_called(&self) -> bool {
200        matches!(
201            self,
202            Operation::Add
203                | Operation::Subtract
204                | Operation::Multiply
205                | Operation::Matmul
206                | Operation::Gemm(_, _)
207                | Operation::MixedMultiply
208                | Operation::Stack(_)
209        )
210    }
211
212    pub fn is_mpc_compiled(&self) -> bool {
213        matches!(
214            self,
215            Operation::Input(_)
216                | Operation::Zeros(_)
217                | Operation::Ones(_)
218                | Operation::Add
219                | Operation::Subtract
220                | Operation::Multiply
221                | Operation::MixedMultiply
222                | Operation::Dot
223                | Operation::Matmul
224                | Operation::Gemm(_, _)
225                | Operation::Truncate(_)
226                | Operation::Sum(_)
227                | Operation::CumSum(_)
228                | Operation::PermuteAxes(_)
229                | Operation::Get(_)
230                | Operation::GetSlice(_)
231                | Operation::Reshape(_)
232                | Operation::Stack(_)
233                | Operation::Concatenate(_)
234                | Operation::Constant(_, _)
235                | Operation::A2B
236                | Operation::B2A(_)
237                | Operation::CreateTuple
238                | Operation::CreateNamedTuple(_)
239                | Operation::CreateVector(_)
240                | Operation::TupleGet(_)
241                | Operation::NamedTupleGet(_)
242                | Operation::VectorGet
243                | Operation::Zip
244                | Operation::Repeat(_)
245                | Operation::ArrayToVector
246                | Operation::VectorToArray
247                | Operation::Join(_, _)
248                | Operation::JoinWithColumnMasks(_, _)
249                | Operation::ApplyPermutation(_)
250                | Operation::Sort(_)
251        )
252    }
253
254    pub fn update_prf_id(&self, prf_id: u64) -> Result<Self> {
255        match self {
256            Operation::PRF(_, scalar_type) => Ok(Operation::PRF(prf_id, scalar_type.clone())),
257            Operation::PermutationFromPRF(_, size) => {
258                Ok(Operation::PermutationFromPRF(prf_id, *size))
259            }
260            _ => Err(runtime_error!("Operation is not a PRF operation")),
261        }
262    }
263
264    pub fn is_input(&self) -> bool {
265        matches!(self, Operation::Input(_))
266    }
267
268    pub fn is_const_optimizable(&self) -> Result<bool> {
269        match self {
270            // Zeros and Ones exist precisely because we don't want to store them as Constants
271            // to keep the graph size small.
272            Operation::Zeros(_) | Operation::Ones(_) => Ok(false),
273            op => Ok(!op.is_input() && !op.is_randomizing()?),
274        }
275    }
276
277    // If an operation computes a randomized output, return true
278    pub fn is_randomizing(&self) -> Result<bool> {
279        match self {
280            Operation::Random(_)
281            | Operation::RandomPermutation(_)
282            | Operation::CuckooToPermutation
283            | Operation::DecomposeSwitchingMap(_) => Ok(true),
284            Operation::Input(_)
285            | Operation::Zeros(_)
286            | Operation::Ones(_)
287            | Operation::A2B
288            | Operation::Add
289            | Operation::ApplyPermutation(_)
290            | Operation::ArrayToVector
291            | Operation::Assert(_)
292            | Operation::B2A(_)
293            | Operation::Subtract
294            | Operation::Multiply
295            | Operation::MixedMultiply
296            | Operation::Matmul
297            | Operation::Dot
298            | Operation::Gemm(_, _)
299            | Operation::Truncate(_)
300            | Operation::Sum(_)
301            | Operation::CumSum(_)
302            | Operation::Concatenate(_)
303            | Operation::CreateNamedTuple(_)
304            | Operation::CreateTuple
305            | Operation::CreateVector(_)
306            | Operation::CuckooHash
307            | Operation::SegmentCumSum
308            | Operation::PermuteAxes(_)
309            | Operation::Get(_)
310            | Operation::Gather(_)
311            | Operation::GetSlice(_)
312            | Operation::Reshape(_)
313            | Operation::NOP
314            | Operation::InversePermutation
315            | Operation::PRF(_, _)
316            | Operation::PermutationFromPRF(_, _)
317            | Operation::Stack(_)
318            | Operation::NamedTupleGet(_)
319            | Operation::Sort(_)
320            | Operation::TupleGet(_)
321            | Operation::Constant(_, _)
322            | Operation::VectorGet
323            | Operation::Zip
324            | Operation::Repeat(_)
325            | Operation::VectorToArray
326            | Operation::Join(_, _)
327            | Operation::JoinWithColumnMasks(_, _)
328            | Operation::Print(_)
329            | Operation::Shard(_) => Ok(false),
330            Operation::Call | Operation::Iterate => Err(runtime_error!(
331                "The status of operations calling other graphs cannot be defined"
332            )),
333            Operation::Custom(_) => Err(runtime_error!(
334                "The status of custom operations cannot be defined"
335            )),
336        }
337    }
338}
339
340struct NodeBody {
341    graph: WeakGraph,
342    node_dependencies: Vec<WeakNode>,
343    graph_dependencies: Vec<WeakGraph>,
344    operation: Operation,
345    id: u64,
346}
347
/// Serde-friendly form of a node body, in which dependencies are referred to by numeric IDs
/// instead of weak pointers.
///
/// NOTE: field order is part of the encoding for positional serde formats; do not reorder.
#[derive(Serialize, Deserialize)]
struct SerializableNodeBody {
    // IDs of the dependency nodes (presumably node IDs within the same graph — confirm in the serializer).
    node_dependencies: Vec<u64>,
    // IDs of the dependency graphs (presumably graph IDs within the context — confirm in the serializer).
    graph_dependencies: Vec<u64>,
    // Operation performed by the node.
    operation: Operation,
}
354
355type NodeBodyPointer = Arc<AtomicRefCell<NodeBody>>;
356
357/// A structure that stores a pointer to a computation graph node that corresponds to an operation.
358///
359/// [Clone] trait duplicates the pointer, not the underlying nodes.
360///
361/// [PartialEq] trait compares pointers, not the related nodes.
362///
363/// # Example
364///
365/// ```
366/// # use ciphercore_base::graphs::create_context;
367/// # use ciphercore_base::data_types::{scalar_type, BIT};
368/// let c = create_context().unwrap();
369/// let g = c.create_graph().unwrap();
370/// let t = scalar_type(BIT);
371/// let n1 = g.input(t.clone()).unwrap();
372/// let n2 = g.input(t).unwrap();
373/// assert!(n1 != n2);
374/// let n3 = n1.clone();
375/// assert!(n1 == n3);
376/// ```
377#[cfg_attr(feature = "py-binding", struct_wrapper)]
378pub struct Node {
379    body: NodeBodyPointer,
380}
381
382type SerializableNode = Arc<SerializableNodeBody>;
383
384impl Clone for Node {
385    /// Returns a new [Node] value with a copy of the pointer to a node.
386    fn clone(&self) -> Self {
387        Node {
388            body: self.body.clone(),
389        }
390    }
391}
392
393impl PartialEq for Node {
394    /// Tests whether `self` and `other` nodes are equal via comparison of their respective pointers.
395    ///
396    /// # Arguments
397    ///
398    /// `other` - another [Node] value
399    ///
400    /// # Returns
401    ///
402    /// `true` if `self` and `other` are equal, `false` otherwise
403    fn eq(&self, other: &Self) -> bool {
404        Arc::ptr_eq(&self.body, &other.body)
405    }
406}
407
408impl Eq for Node {}
409
410impl Hash for Node {
411    /// Hashes the node pointer.
412    ///
413    /// # Arguments
414    ///
415    /// `state` - state of a hash function that is changed after hashing the node
416    fn hash<H: Hasher>(&self, state: &mut H) {
417        ptr::hash(&*self.body, state);
418    }
419}
420
/// Public methods that are exposed to Python when the `py-binding` feature is enabled.
422#[cfg_attr(feature = "py-binding", impl_wrapper)]
423impl Node {
424    /// Returns the parent graph that contains the node.
425    ///
426    /// # Returns
427    ///
428    /// Parent graph of the node
429    pub fn get_graph(&self) -> Graph {
430        self.body.borrow().graph.upgrade()
431    }
432
433    /// Returns the dependency nodes that are used to compute the value in the current node.
434    ///
435    /// # Returns
436    ///
437    /// Vector of nodes used by the node to perform its operation
438    pub fn get_node_dependencies(&self) -> Vec<Node> {
439        self.body
440            .borrow()
441            .node_dependencies
442            .iter()
443            .map(|n| n.upgrade())
444            .collect()
445    }
446
447    /// Returns the dependency graphs that are used to compute the value in the current node.
448    ///
449    /// These dependencies are non-empty only for `Call` and `Iterate` operations.
450    ///
451    /// # Returns
452    ///
453    /// Vector of graphs used by the node to perform its operation
454    pub fn get_graph_dependencies(&self) -> Vec<Graph> {
455        self.body
456            .borrow()
457            .graph_dependencies
458            .iter()
459            .map(|g| g.upgrade())
460            .collect()
461    }
462
463    /// Returns the ID of the node.
464    ///
465    /// A node ID is a serial number of a node between `0` and `n-1` where `n` is the number of nodes in the parent graph.
466    /// This number is equal to the number of nodes in the parent graph before this node was added to it.
467    ///
468    /// # Returns
469    ///
470    /// Node ID
471    pub fn get_id(&self) -> u64 {
472        self.body.borrow().id
473    }
474
475    /// Returns the pair of the parent graph ID and node ID
476    ///
477    /// # Returns
478    ///
479    /// (Graph ID, Node ID)
480    pub fn get_global_id(&self) -> (u64, u64) {
481        (self.get_graph().get_id(), self.get_id())
482    }
483
484    /// Returns the operation associated with the node.
485    ///
486    /// # Returns
487    ///
488    /// Operation associated with the node
489    pub fn get_operation(&self) -> Operation {
490        self.body.borrow().operation.clone()
491    }
492
    /// Returns the type of the value computed by the node.
    ///
    /// The type is taken from the type checker of the parent context: a cached type is returned
    /// if available; otherwise the type checker processes the node.
    ///
    /// # Returns
    ///
    /// Output type of the node operation
    ///
    /// # Errors
    ///
    /// Fails if the context has no type checker, and propagates errors reported by the type
    /// checker while looking up or processing the node.
    pub fn get_type(&self) -> Result<Type> {
        let context = self.get_graph().get_context();

        // Fast path: consult the type checker's cache under a shared borrow.
        // The inner block releases that borrow before `borrow_mut()` below; holding both at once
        // would conflict (context.body is presumably an AtomicRefCell, whose borrows are checked
        // at runtime — confirm in the Context definition).
        {
            let context_body = context.body.borrow();
            if let Some(tc) = &context_body.type_checker {
                if let Some(cached_type) = tc.cached_node_type(self)? {
                    return Ok(cached_type);
                }
            }
        }

        // Slow path: processing the node needs mutable access to the type checker.
        let mut context_body = context.body.borrow_mut();
        if let Some(tc) = &mut context_body.type_checker {
            tc.process_node(self.clone())
        } else {
            Err(runtime_error!("Type checker is not available"))
        }
    }
517    /// Applies [Context::set_node_name] to the parent context and `this` node. Returns the clone of `this`.
518    ///
519    /// # Example
520    ///
521    /// ```
522    /// # use ciphercore_base::graphs::create_context;
523    /// # use ciphercore_base::data_types::{scalar_type, BIT};
524    /// let c = create_context().unwrap();
525    /// let g = c.create_graph().unwrap();
526    /// let t = scalar_type(BIT);
527    /// let n = g.input(t).unwrap();
528    /// n.set_name("XOR").unwrap();
529    /// ```
530    pub fn set_name(&self, name: &str) -> Result<Node> {
531        self.get_graph()
532            .get_context()
533            .set_node_name(self.clone(), name)?;
534        Ok(self.clone())
535    }
536
537    /// Applies [Context::get_node_name] to the parent context and `this` node.
538    ///
539    /// # Example
540    ///
541    /// ```
542    /// # use ciphercore_base::graphs::create_context;
543    /// # use ciphercore_base::data_types::{scalar_type, BIT};
544    /// let c = create_context().unwrap();
545    /// let g = c.create_graph().unwrap();
546    /// let t = scalar_type(BIT);
547    /// let n = g.input(t).unwrap();
548    /// n.set_name("XOR").unwrap();
549    /// assert_eq!(n.get_name().unwrap(), Some("XOR".to_owned()));
550    /// ```
551    pub fn get_name(&self) -> Result<Option<String>> {
552        self.get_graph().get_context().get_node_name(self.clone())
553    }
554
555    /// Adds a node to the parent graph that adds elementwise the array or scalar associated with the node to an array or scalar of the same scalar type associated with another node.
556    ///
557    /// Applies [Graph::add] to the parent graph, `this` node and the `b` node.
558    ///
559    /// # Example
560    ///
561    /// ```
562    /// # use ciphercore_base::graphs::create_context;
563    /// # use ciphercore_base::data_types::{BIT, scalar_type};
564    /// let c = create_context().unwrap();
565    /// let g = c.create_graph().unwrap();
566    /// let t = scalar_type(BIT);
567    /// let n1 = g.input(t.clone()).unwrap();
568    /// let n2 = g.input(t).unwrap();
569    /// let n3 = n1.add(n2).unwrap();
570    /// ```
571    pub fn add(&self, b: Node) -> Result<Node> {
572        self.get_graph().add(self.clone(), b)
573    }
574
575    /// Adds a node to the parent graph that subtracts elementwise the array or scalar of the same scalar type associated with another node from an array or scalar associated with the node.
576    ///
577    /// Applies [Graph::subtract] to the parent graph, `this` node and the `b` node.
578    ///
579    /// # Example
580    ///
581    /// ```
582    /// # use ciphercore_base::graphs::create_context;
583    /// # use ciphercore_base::data_types::{BIT, scalar_type};
584    /// let c = create_context().unwrap();
585    /// let g = c.create_graph().unwrap();
586    /// let t = scalar_type(BIT);
587    /// let n1 = g.input(t.clone()).unwrap();
588    /// let n2 = g.input(t).unwrap();
589    /// let n3 = n1.subtract(n2).unwrap();
590    /// ```
591    pub fn subtract(&self, b: Node) -> Result<Node> {
592        self.get_graph().subtract(self.clone(), b)
593    }
594
595    /// Adds a node to the parent graph that multiplies elementwise the array or scalar associated with the node by an array or scalar of the same scalar type associated with another node.
596    ///
597    /// Applies [Graph::multiply] to the parent graph, `this` node and the `b` node.
598    ///
599    /// # Example
600    ///
601    /// ```
602    /// # use ciphercore_base::graphs::create_context;
603    /// # use ciphercore_base::data_types::{BIT, scalar_type};
604    /// let c = create_context().unwrap();
605    /// let g = c.create_graph().unwrap();
606    /// let t = scalar_type(BIT);
607    /// let n1 = g.input(t.clone()).unwrap();
608    /// let n2 = g.input(t).unwrap();
609    /// let n3 = n1.multiply(n2).unwrap();
610    /// ```
611    pub fn multiply(&self, b: Node) -> Result<Node> {
612        self.get_graph().multiply(self.clone(), b)
613    }
614
615    /// Adds a node to the parent graph that multiplies elementwise the array or scalar associated with the node by a binary array or scalar associated with another node.
616    ///
617    /// Applies [Graph::mixed_multiply] to the parent graph, `this` node and the `b` node.
618    ///
619    /// # Example
620    ///
621    /// ```
622    /// # use ciphercore_base::graphs::create_context;
623    /// # use ciphercore_base::data_types::{BIT, INT32, scalar_type};
624    /// let c = create_context().unwrap();
625    /// let g = c.create_graph().unwrap();
626    /// let t = scalar_type(INT32);
627    /// let bit_t = scalar_type(BIT);
628    /// let n1 = g.input(t).unwrap();
629    /// let n2 = g.input(bit_t).unwrap();
630    /// let n3 = n1.mixed_multiply(n2).unwrap();
631    /// ```
632    pub fn mixed_multiply(&self, b: Node) -> Result<Node> {
633        self.get_graph().mixed_multiply(self.clone(), b)
634    }
635
636    /// Adds a node to the parent graph that computes the dot product of arrays or scalars associated with the node and another node.
637    ///
638    /// Applies [Graph::dot] to the parent graph, `this` node and the `b` node.
639    ///
640    /// # Example
641    ///
642    /// ```
643    /// # use ciphercore_base::graphs::create_context;
644    /// # use ciphercore_base::data_types::{INT32, array_type};
645    /// let c = create_context().unwrap();
646    /// let g = c.create_graph().unwrap();
647    /// let t = array_type(vec![10], INT32);
648    /// let n1 = g.input(t.clone()).unwrap();
649    /// let n2 = g.input(t).unwrap();
650    /// let n3 = n1.dot(n2).unwrap();
651    /// ```
652    pub fn dot(&self, b: Node) -> Result<Node> {
653        self.get_graph().dot(self.clone(), b)
654    }
655
656    /// Adds a node to the parent graph that computes the matrix product of two arrays associated with the node and another node.
657    ///
658    /// Applies [Graph::matmul] to the parent graph, `this` node and the `b` node.
659    ///
660    /// # Example
661    ///
662    /// ```
663    /// # use ciphercore_base::graphs::create_context;
664    /// # use ciphercore_base::data_types::{INT32, array_type};
665    /// let c = create_context().unwrap();
666    /// let g = c.create_graph().unwrap();
667    /// let t1 = array_type(vec![2, 3], INT32);
668    /// let t2 = array_type(vec![3, 2], INT32);
669    /// let n1 = g.input(t1).unwrap();
670    /// let n2 = g.input(t2).unwrap();
671    /// let n3 = n1.matmul(n2).unwrap();
672    /// ```
673    pub fn matmul(&self, b: Node) -> Result<Node> {
674        self.get_graph().matmul(self.clone(), b)
675    }
676
677    /// Adds a node to the parent graph that computes the generatl matrix product of two arrays associated with the node and another node.
678    ///
679    /// Applies [Graph::gemm] to the parent graph, `this` node and the `b` node.
680    ///
681    /// # Example
682    ///
683    /// ```
684    /// # use ciphercore_base::graphs::create_context;
685    /// # use ciphercore_base::data_types::{INT32, array_type};
686    /// let c = create_context().unwrap();
687    /// let g = c.create_graph().unwrap();
688    /// let t1 = array_type(vec![2, 3], INT32);
689    /// let t2 = array_type(vec![2, 3], INT32);
690    /// let n1 = g.input(t1).unwrap();
691    /// let n2 = g.input(t2).unwrap();
692    /// let n3 = n1.gemm(n2, false, true).unwrap();
693    /// ```
694    #[doc(hidden)]
695    pub fn gemm(&self, b: Node, transpose_a: bool, transpose_b: bool) -> Result<Node> {
696        self.get_graph()
697            .gemm(self.clone(), b, transpose_a, transpose_b)
698    }
699
700    /// Adds a node that computes a join of a given type on two named tuples along given key headers.
701    /// More detailed documentation can be found in [Graph::join].
702    ///
703    /// Applies [Graph::join] to the parent graph, `this` node and the `b` node.
704    ///
705    /// # Example
706    ///
707    /// ```
708    /// # use ciphercore_base::graphs::{create_context, JoinType};
709    /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type};
710    /// # use ciphercore_base::type_inference::NULL_HEADER;
711    /// # use std::collections::HashMap;
712    /// let c = create_context().unwrap();
713    /// let g = c.create_graph().unwrap();
714    /// let t1n = array_type(vec![100], BIT);
715    /// let t11 = array_type(vec![100], INT32);
716    /// let t12 = array_type(vec![100, 128], BIT);
717    /// let t13 = array_type(vec![100], INT64);
718    /// let t2n = array_type(vec![50], BIT);
719    /// let t21 = array_type(vec![50], INT32);
720    /// let t22 = array_type(vec![50, 128], BIT);
721    /// let t23 = array_type(vec![50], UINT8);
722    /// let t1 = named_tuple_type(vec![
723    ///     (NULL_HEADER.to_owned(), t1n),
724    ///     ("ID".to_owned(), t11),
725    ///     ("Occupation".to_owned(), t12),
726    ///     ("Revenue".to_owned(), t13),
727    /// ]);
728    /// let t2 = named_tuple_type(vec![
729    ///     (NULL_HEADER.to_owned(), t2n),
730    ///     ("ID".to_owned(), t21),
731    ///     ("Job".to_owned(), t22),
732    ///     ("Age".to_owned(), t23),
733    /// ]);
734    /// let n1 = g.input(t1).unwrap();
735    /// let n2 = g.input(t2).unwrap();
736    /// let n3 = n1.join(n2, JoinType::Inner, HashMap::from([
737    ///     ("ID".to_owned(), "ID".to_owned()),
738    ///     ("Occupation".to_owned(), "Job".to_owned()),
739    /// ])).unwrap();
740    /// ```
741    pub fn join(&self, b: Node, t: JoinType, headers: HashMap<String, String>) -> Result<Node> {
742        self.get_graph().join(self.clone(), b, t, headers)
743    }
744
745    /// Adds a node that computes a join of a given type on two named tuples along given key headers.
746    /// More detailed documentation can be found in [Graph::join_with_column_masks].
747    ///
748    /// Applies [Graph::join_with_column_masks] to the parent graph, `this` node and the `b` node.
749    ///
750    /// # Example
751    ///
752    /// ```
753    /// # use ciphercore_base::graphs::{create_context, JoinType};
754    /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type, tuple_type};
755    /// # use ciphercore_base::type_inference::NULL_HEADER;
756    /// # use std::collections::HashMap;
757    /// let c = create_context().unwrap();
758    /// let g = c.create_graph().unwrap();
759    /// let t1n = array_type(vec![100], BIT);
760    /// let t11 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT32)]);
761    /// let t12 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100, 128], BIT)]);
762    /// let t13 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT64)]);
763    /// let t2n = array_type(vec![50], BIT);
764    /// let t21 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], INT32)]);
765    /// let t22 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50, 128], BIT)]);
766    /// let t23 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], UINT8)]);
767    /// let t1 = named_tuple_type(vec![
768    ///     (NULL_HEADER.to_owned(), t1n),
769    ///     ("ID".to_owned(), t11),
770    ///     ("Occupation".to_owned(), t12),
771    ///     ("Revenue".to_owned(), t13),
772    /// ]);
773    /// let t2 = named_tuple_type(vec![
774    ///     (NULL_HEADER.to_owned(), t2n),
775    ///     ("ID".to_owned(), t21),
776    ///     ("Job".to_owned(), t22),
777    ///     ("Age".to_owned(), t23),
778    /// ]);
779    /// let n1 = g.input(t1).unwrap();
780    /// let n2 = g.input(t2).unwrap();
781    /// let n3 = n1.join_with_column_masks(n2, JoinType::Inner, HashMap::from([
782    ///     ("ID".to_owned(), "ID".to_owned()),
783    ///     ("Occupation".to_owned(), "Job".to_owned()),
784    /// ])).unwrap();
785    /// ```
786    pub fn join_with_column_masks(
787        &self,
788        b: Node,
789        t: JoinType,
790        headers: HashMap<String, String>,
791    ) -> Result<Node> {
792        self.get_graph()
793            .join_with_column_masks(self.clone(), b, t, headers)
794    }
795
796    /// Adds a node that applies a permutation to the array along the first dimension.
797    ///
798    /// # Arguments
799    ///
800    /// * `p` - node containing a permutation.
801    ///
802    /// # Returns
803    ///
804    /// New permuted node
805    ///
806    /// # Example
807    ///
808    /// ```
809    /// # use ciphercore_base::graphs::create_context;
810    /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
811    /// let c = create_context().unwrap();
812    /// let g = c.create_graph().unwrap();
813    /// let t = array_type(vec![25, 3], INT32);
814    /// let a = g.input(t).unwrap();
815    /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
816    /// let a = a.apply_permutation(p).unwrap();
817    /// ```
818    #[doc(hidden)]
819    pub fn apply_permutation(&self, p: Node) -> Result<Node> {
820        self.get_graph().apply_permutation(self.clone(), p)
821    }
822
823    /// Adds a node that applies an inverse permutation to the array along the first dimension.
824    ///
825    /// # Arguments
826    ///
827    /// * `p` - node containing a permutation.
828    ///
829    /// # Returns
830    ///
831    /// New permuted node
832    ///
833    /// # Example
834    ///
835    /// ```
836    /// # use ciphercore_base::graphs::create_context;
837    /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
838    /// let c = create_context().unwrap();
839    /// let g = c.create_graph().unwrap();
840    /// let t = array_type(vec![25, 3], INT32);
841    /// let a = g.input(t).unwrap();
842    /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
843    /// let a = a.apply_inverse_permutation(p).unwrap();
844    /// ```
845    #[doc(hidden)]
846    pub fn apply_inverse_permutation(&self, p: Node) -> Result<Node> {
847        self.get_graph().apply_inverse_permutation(self.clone(), p)
848    }
849
850    /// Adds a node that sorts a table given as named tuple according to the column given by the key argument.
851    /// The key column must be a 2-d BIT array of shape [n, b], interpreted as bitstrings of length b.
852    /// Other columns in the named tuple must be arrays of arbitrary type and shape, as long as they
853    /// share the first dimension: [n, ...].
854    /// Bitstrings are sorted lexicographically, and the sorting algorithm is stable: preserving relative
855    /// order of entries in other arrays where the corresponding key entries match.
856    ///
857    /// # Arguments
858    /// * `key` - name of the field to sort on it, this array must be 2-d of type BIT.
859    ///
860    /// # Returns
861    ///
862    /// New sorted node
863    ///
864    /// # Example
865    ///
866    /// ```
867    /// # use ciphercore_base::graphs::create_context;
868    /// # use ciphercore_base::data_types::{BIT, INT32, UINT64, array_type, named_tuple_type};
869    /// let c = create_context().unwrap();
870    /// let g = c.create_graph().unwrap();
871    /// let v1 = g.input(array_type(vec![20], INT32)).unwrap();
872    /// let v2 = g.input(array_type(vec![20, 10, 2], UINT64)).unwrap();
873    /// let k = g.input(array_type(vec![20, 32], BIT)).unwrap();
874    /// let a = g.create_named_tuple(vec![("key".to_string(), k), ("value1".to_string(), v1), ("value2".to_string(), v2)]).unwrap();
875    /// let a = a.sort("key".to_string()).unwrap();
876    /// ```
877    pub fn sort(&self, key: String) -> Result<Node> {
878        self.get_graph().sort(self.clone(), key)
879    }
880
881    /// Adds a node to the parent graph that divides a scalar or each entry of the array associated with the node by a positive constant integer `scale`.
882    ///
883    /// Applies [Graph::add] to the parent graph, `this` node and `scale`.
884    ///
885    /// # Example
886    ///
887    /// ```
888    /// # use ciphercore_base::graphs::create_context;
889    /// # use ciphercore_base::data_types::{INT32, array_type};
890    /// let c = create_context().unwrap();
891    /// let g = c.create_graph().unwrap();
892    /// let t = array_type(vec![2, 3], INT32);
893    /// let n1 = g.input(t).unwrap();
894    /// let n2 = n1.truncate(4).unwrap();
895    /// ```
896    pub fn truncate(&self, scale: u128) -> Result<Node> {
897        self.get_graph().truncate(self.clone(), scale)
898    }
899
900    /// Adds a node to the parent graph that computes the sum of entries of the array associated with the node along given axes.
901    ///
902    /// Applies [Graph::sum] to the parent graph, `this` node and `axes`.
903    ///
904    /// # Example
905    ///
906    /// ```
907    /// # use ciphercore_base::graphs::create_context;
908    /// # use ciphercore_base::data_types::{INT32, array_type};
909    /// let c = create_context().unwrap();
910    /// let g = c.create_graph().unwrap();
911    /// let t = array_type(vec![3, 2, 3], INT32);
912    /// let axes = vec![1, 0];
913    /// let n1 = g.input(t).unwrap();
914    /// let n2 = n1.sum(axes).unwrap();
915    /// ```
916    pub fn sum(&self, axes: ArrayShape) -> Result<Node> {
917        self.get_graph().sum(self.clone(), axes)
918    }
919
920    /// Adds a node to the parent graph that computes the cumulative sum of elements along a given axis.
921    ///
922    /// Applies [Graph::cum_sum] to the parent graph, `this` node and `axis`.
923    ///
924    /// # Example
925    ///
926    /// ```
927    /// # use ciphercore_base::graphs::create_context;
928    /// # use ciphercore_base::data_types::{INT32, array_type};
929    /// let c = create_context().unwrap();
930    /// let g = c.create_graph().unwrap();
931    /// let t = array_type(vec![3, 2], INT32);
932    /// let n1 = g.input(t).unwrap();
933    /// let n2 = n1.cum_sum(1).unwrap();
934    /// ```
935    pub fn cum_sum(&self, axis: u64) -> Result<Node> {
936        self.get_graph().cum_sum(self.clone(), axis)
937    }
938
939    /// Adds a node to the parent graph that permutes the array associated with the node along given axes.
940    ///
941    /// Applies [Graph::permute_axes] to the parent graph, `this` node and `axes`.
942    ///
943    /// # Example
944    ///
945    /// ```
946    /// # use ciphercore_base::graphs::create_context;
947    /// # use ciphercore_base::data_types::{INT32, array_type};
948    /// let c = create_context().unwrap();
949    /// let g = c.create_graph().unwrap();
950    /// let t = array_type(vec![3, 2, 3], INT32);
951    /// let axes = vec![1, 0, 2];
952    /// let n1 = g.input(t).unwrap();
953    /// let n2 = n1.permute_axes(axes).unwrap();
954    /// ```
955    pub fn permute_axes(&self, axes: ArrayShape) -> Result<Node> {
956        self.get_graph().permute_axes(self.clone(), axes)
957    }
958
959    /// Adds a node to the parent graph that inverts a given permutation.
960    ///
961    /// Applies [Graph::inverse_permutation] to the parent graph and `this` node.
962    #[doc(hidden)]
963    pub fn inverse_permutation(&self) -> Result<Node> {
964        self.get_graph().inverse_permutation(self.clone())
965    }
966
967    /// Adds a node to the parent graph that extracts a sub-array with a given index from the array associated with the node.
968    ///
969    /// Applies [Graph::get] to the parent graph, `this` node and `index`.
970    ///
971    /// # Example
972    ///
973    /// ```
974    /// # use ciphercore_base::graphs::create_context;
975    /// # use ciphercore_base::data_types::{INT32, array_type};
976    /// let c = create_context().unwrap();
977    /// let g = c.create_graph().unwrap();
978    /// let t = array_type(vec![3, 2, 3], INT32);
979    /// let index = vec![2];
980    /// let n1 = g.input(t).unwrap();
981    /// let n2 = n1.get(index).unwrap();
982    /// ```
983    pub fn get(&self, index: ArrayShape) -> Result<Node> {
984        self.get_graph().get(self.clone(), index)
985    }
986
987    /// Adds a node that extracts a sub-array corresponding to a given slice from the array associated with the node.
988    ///
989    /// Applies [Graph::get_slice] to the parent graph, `this` node and `slice`.
990    ///
991    /// # Example
992    ///
993    /// ```
994    /// # use ciphercore_base::graphs::{create_context, SliceElement};
995    /// # use ciphercore_base::data_types::{INT32, array_type};
996    /// let c = create_context().unwrap();
997    /// let g = c.create_graph().unwrap();
998    /// let t = array_type(vec![3, 2, 3], INT32);
999    /// let slice = vec![SliceElement::Ellipsis, SliceElement::SubArray(None, None, Some(-2))];
1000    /// let n1 = g.input(t).unwrap();
1001    /// let n2 = n1.get_slice(slice).unwrap();
1002    /// ```
1003    pub fn get_slice(&self, slice: Slice) -> Result<Node> {
1004        self.get_graph().get_slice(self.clone(), slice)
1005    }
1006
1007    /// Adds a node to the parent graph that reshapes a value associated with the node to a given compatible type.
1008    ///
1009    /// Applies [Graph::reshape] to the parent graph, `this` node and `new_type`.
1010    ///
1011    /// # Example
1012    ///
1013    /// ```
1014    /// # use ciphercore_base::graphs::create_context;
1015    /// # use ciphercore_base::data_types::{INT32, array_type};
1016    /// let c = create_context().unwrap();
1017    /// let g = c.create_graph().unwrap();
1018    /// let old_t = array_type(vec![3, 2, 3], INT32);
1019    /// let new_t = array_type(vec![3,6], INT32);
1020    /// let n1 = g.input(old_t).unwrap();
1021    /// let n2 = n1.reshape(new_t).unwrap();
1022    /// ```
1023    pub fn reshape(&self, new_type: Type) -> Result<Node> {
1024        self.get_graph().reshape(self.clone(), new_type)
1025    }
1026
1027    #[doc(hidden)]
1028    pub fn nop(&self) -> Result<Node> {
1029        self.get_graph().nop(self.clone())
1030    }
1031
1032    #[doc(hidden)]
1033    pub fn prf(&self, iv: u64, output_type: Type) -> Result<Node> {
1034        self.get_graph().prf(self.clone(), iv, output_type)
1035    }
1036
1037    #[doc(hidden)]
1038    pub fn permutation_from_prf(&self, iv: u64, n: u64) -> Result<Node> {
1039        self.get_graph().permutation_from_prf(self.clone(), iv, n)
1040    }
1041
1042    /// Adds a node to the parent graph converting an integer array or scalar associated with the node to the binary form.
1043    ///
1044    /// Applies [Graph::a2b] to the parent graph and `this` node.
1045    ///
1046    /// # Example
1047    ///
1048    /// ```
1049    /// # use ciphercore_base::graphs::create_context;
1050    /// # use ciphercore_base::data_types::{array_type, INT32};
1051    /// let c = create_context().unwrap();
1052    /// let g = c.create_graph().unwrap();
1053    /// let t = array_type(vec![3, 2], INT32);
1054    /// let n1 = g.input(t).unwrap();
1055    /// let n2 = n1.a2b().unwrap();
1056    /// ```
1057    pub fn a2b(&self) -> Result<Node> {
1058        self.get_graph().a2b(self.clone())
1059    }
1060
1061    /// Adds a node to the parent graph converting a binary array associated with the node to an array of a given scalar type.
1062    ///
1063    /// Applies [Graph::b2a] to the parent graph, `this` node and `scalar_type`.
1064    ///
1065    /// # Example
1066    ///
1067    /// ```
1068    /// # use ciphercore_base::graphs::create_context;
1069    /// # use ciphercore_base::data_types::{BIT, INT32, array_type};
1070    /// let c = create_context().unwrap();
1071    /// let g = c.create_graph().unwrap();
1072    /// let t = array_type(vec![3, 32], BIT);
1073    /// let n1 = g.input(t).unwrap();
1074    /// let n2 = n1.b2a(INT32).unwrap();
1075    /// ```
1076    pub fn b2a(&self, scalar_type: ScalarType) -> Result<Node> {
1077        self.get_graph().b2a(self.clone(), scalar_type)
1078    }
1079
1080    /// Adds a node that extracts an element of a tuple associated with the node.
1081    ///
1082    /// Applies [Graph::tuple_get] to the parent graph, `this` node and `index`.
1083    ///
1084    /// # Example
1085    ///
1086    /// ```
1087    /// # use ciphercore_base::data_types::{INT32, array_type};
1088    /// # use ciphercore_base::graphs::create_context;
1089    /// let c = create_context().unwrap();
1090    /// let g = c.create_graph().unwrap();
1091    /// let t1 = array_type(vec![3, 2, 3], INT32);
1092    /// let t2 = array_type(vec![2, 3], INT32);
1093    /// let n1 = g.input(t1).unwrap();
1094    /// let n2 = g.input(t2).unwrap();
1095    /// let n3 = g.create_tuple(vec![n1, n2]).unwrap();
1096    /// let n4 = n3.tuple_get(1).unwrap();
1097    /// ```
1098    pub fn tuple_get(&self, index: u64) -> Result<Node> {
1099        self.get_graph().tuple_get(self.clone(), index)
1100    }
1101
1102    /// Adds a node to the parent graph that extracts an element of a named tuple associated with the node.
1103    ///
1104    /// Applies [Graph::named_tuple_get] to the parent graph, `this` node and the `key` string.
1105    ///
1106    /// # Example
1107    ///
1108    /// ```
1109    /// # use ciphercore_base::graphs::create_context;
1110    /// # use ciphercore_base::data_types::{array_type, INT32};
1111    /// let c = create_context().unwrap();
1112    /// let g = c.create_graph().unwrap();
1113    /// let t1 = array_type(vec![3, 2, 3], INT32);
1114    /// let t2 = array_type(vec![2, 3], INT32);
1115    /// let n1 = g.input(t1).unwrap();
1116    /// let n2 = g.input(t2).unwrap();
1117    /// let n3 = g.create_named_tuple(vec![("node1".to_owned(), n1), ("node2".to_owned(), n2)]).unwrap();
1118    /// let n4 = n3.named_tuple_get("node2".to_owned()).unwrap();
1119    /// ```
1120    pub fn named_tuple_get(&self, key: String) -> Result<Node> {
1121        self.get_graph().named_tuple_get(self.clone(), key)
1122    }
1123
1124    /// Adds a node to the parent graph that extracts an element of a vector associated with the node.
1125    ///
1126    /// Applies [Graph::vector_get] to the parent graph, `this` node and the `index` node.
1127    ///
1128    /// # Example
1129    ///
1130    /// ```
1131    /// # use ciphercore_base::graphs::create_context;
1132    /// # use ciphercore_base::data_types::{UINT32, INT32, array_type, scalar_type};
1133    /// # use ciphercore_base::data_values::Value;
1134    /// let c = create_context().unwrap();
1135    /// let g = c.create_graph().unwrap();
1136    /// let t = array_type(vec![3, 2, 3], INT32);
1137    /// let n1 = g.input(t.clone()).unwrap();
1138    /// let n2 = g.input(t.clone()).unwrap();
1139    /// let n3 = g.create_vector(t, vec![n1,n2]).unwrap();
1140    /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
1141    /// let n4 = n3.vector_get(index).unwrap();
1142    /// ```
1143    pub fn vector_get(&self, index: Node) -> Result<Node> {
1144        self.get_graph().vector_get(self.clone(), index)
1145    }
1146
1147    /// Adds a node to the parent graph converting an array associated with the node to a vector.
1148    ///
1149    /// Applies [Graph::array_to_vector] to the parent graph and `this` node.
1150    ///
1151    /// # Example
1152    ///
1153    /// ```
1154    /// # use ciphercore_base::graphs::create_context;
1155    /// # use ciphercore_base::data_types::{array_type, scalar_type, INT32, UINT32};
1156    /// # use ciphercore_base::data_values::Value;
1157    /// let c = create_context().unwrap();
1158    /// let g = c.create_graph().unwrap();
1159    /// let t = array_type(vec![4, 3, 2], INT32);
1160    /// let n1 = g.input(t).unwrap();
1161    /// let n2 = g.array_to_vector(n1).unwrap();
1162    /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
1163    /// let n3 = n2.vector_get(index).unwrap();
1164    ///
1165    /// assert!(n2.get_type().unwrap().is_vector());
1166    /// assert_eq!(n3.get_type().unwrap().get_shape(), vec![3,2]);
1167    /// ```
1168    pub fn array_to_vector(&self) -> Result<Node> {
1169        self.get_graph().array_to_vector(self.clone())
1170    }
1171
1172    /// Adds a node to the parent graph converting a vector associated with the node to an array.
1173    ///
1174    /// Applies [Graph::vector_to_array] to the parent graph and `this` node.
1175    ///
1176    /// # Example
1177    ///
1178    /// ```
1179    /// # use ciphercore_base::graphs::create_context;
1180    /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
1181    /// let c = create_context().unwrap();
1182    /// let g = c.create_graph().unwrap();
1183    /// let t = array_type(vec![3, 2], INT32);
1184    /// let vec_t = vector_type(4, t);
1185    /// let n1 = g.input(vec_t).unwrap();
1186    /// let n2 = n1.vector_to_array().unwrap();
1187    ///
1188    /// assert!(n2.get_type().unwrap().is_array());
1189    /// assert_eq!(n2.get_type().unwrap().get_shape(), vec![4, 3, 2]);
1190    /// ```
1191    pub fn vector_to_array(&self) -> Result<Node> {
1192        self.get_graph().vector_to_array(self.clone())
1193    }
1194
1195    /// Adds a node to the parent graph converting a vector associated with the node to an array.
1196    ///
1197    /// Applies [Graph::gather] to the parent graph and `this` node.
1198    pub fn gather(&self, indices: Node, axis: u64) -> Result<Node> {
1199        self.get_graph().gather(self.clone(), indices, axis)
1200    }
1201
1202    /// Adds a node that creates a vector with `n` copies of a value of this node.
1203    ///
1204    /// Applies [Graph::repeat] to the parent graph, `this` node and `n`.
1205    ///
1206    /// # Example
1207    ///
1208    /// ```
1209    /// # use ciphercore_base::data_types::{INT32, array_type};
1210    /// # use ciphercore_base::graphs::create_context;
1211    /// let c = create_context().unwrap();
1212    /// let g = c.create_graph().unwrap();
1213    /// let t = array_type(vec![3, 2, 3], INT32);
1214    /// let n1 = g.input(t).unwrap();
1215    /// let n2 = n1.repeat(10).unwrap();
1216    /// ```
1217    pub fn repeat(&self, n: u64) -> Result<Node> {
1218        self.get_graph().repeat(self.clone(), n)
1219    }
1220
1221    /// Adds a node returning the Cuckoo hash map of an input array of binary strings using provided hash functions.
1222    ///
1223    /// Applies [Graph::cuckoo_hash] to the parent graph, `this` node and `hash_matrices`.
1224    #[doc(hidden)]
1225    pub fn cuckoo_hash(&self, hash_matrices: Node) -> Result<Node> {
1226        self.get_graph().cuckoo_hash(self.clone(), hash_matrices)
1227    }
1228
1229    /// Adds a node that, given an input multidimensional array A, binary one-dimensional array B (first dimension is n in both array) and starting value v, computes the following iteration
1230    ///
1231    /// output[i] = A[i-1] + B[i-1] * output[i-1]
1232    ///
1233    /// where i in {1,...,n} and output[0] = v.
1234    ///
1235    /// Applies [Graph::segment_cumsum] to the parent graph, `this` node, `binary_array` and `first_row`.
1236    #[doc(hidden)]
1237    pub fn segment_cumsum(&self, binary_array: Node, first_row: Node) -> Result<Node> {
1238        self.get_graph()
1239            .segment_cumsum(self.clone(), binary_array, first_row)
1240    }
1241
1242    /// Adds a node that computes sharding of a given table according to a given sharding config.
1243    /// Sharding config contains names of the columns whose hashed values are used for sharding.
1244    ///
1245    /// Each shard is accompanied by a Boolean mask indicating whether a corresponding row stems from the input table or padded (1 if a row comes from input).
1246    ///
1247    /// Applies [Graph::shard] to the parent graph, `this` node and `shard_config`.
1248    #[doc(hidden)]
1249    pub fn shard(&self, shard_config: ShardConfig) -> Result<Node> {
1250        self.get_graph().shard(self.clone(), shard_config)
1251    }
1252
1253    /// Adds a node that converts a switching map array into a tuple of the following components:
1254    /// - a permutation map array with deletion,
1255    /// - a duplication map array,
1256    /// - a permutation map array without deletion.
1257    ///
1258    /// The composition of these maps is equal to the input switching map, which is an array containing non-unique indices of some array.
1259    ///
1260    /// Applies [Graph::decompose_switching_map] to the parent graph and `this`.
1261    #[doc(hidden)]
1262    pub fn decompose_switching_map(&self, n: u64) -> Result<Node> {
1263        self.get_graph().decompose_switching_map(self.clone(), n)
1264    }
1265
1266    /// Adds a node that converts a Cuckoo hash table to a random permutation.
1267    ///
1268    /// Applies [Graph::cuckoo_to_permutation] to the parent graph and `this` node.
1269    #[doc(hidden)]
1270    pub fn cuckoo_to_permutation(&self) -> Result<Node> {
1271        self.get_graph().cuckoo_to_permutation(self.clone())
1272    }
1273
1274    /// Adds an operation which logs the value of the node at runtime.
1275    pub fn print(&self, message: String) -> Result<Node> {
1276        self.get_graph().print(message, self.clone())
1277    }
1278
1279    /// Applies [Graph::set_output_node] to the parent graph and `this` node.
1280    ///
1281    /// # Returns
1282    ///
1283    /// This node
1284    ///
1285    /// # Example
1286    ///
1287    /// ```
1288    /// # use ciphercore_base::graphs::create_context;
1289    /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
1290    /// let c = create_context().unwrap();
1291    /// let g = c.create_graph().unwrap();
1292    /// let t = array_type(vec![3, 2], INT32);
1293    /// let vec_t = vector_type(4, t);
1294    /// let n1 = g.input(vec_t).unwrap();
1295    /// let n2 = g.vector_to_array(n1).unwrap();
1296    /// n2.set_as_output().unwrap();
1297    /// g.finalize().unwrap();
1298    /// ```
1299    pub fn set_as_output(&self) -> Result<Node> {
1300        self.get_graph().set_output_node(self.clone())?;
1301        Ok(self.clone())
1302    }
1303}
1304
1305/// Methods which aren't supposed to be imported in Python.
1306impl Node {
1307    fn make_serializable(&self) -> SerializableNode {
1308        Arc::new(SerializableNodeBody {
1309            node_dependencies: self
1310                .get_node_dependencies()
1311                .iter()
1312                .map(|n| n.get_id())
1313                .collect(),
1314            graph_dependencies: self
1315                .get_graph_dependencies()
1316                .iter()
1317                .map(|n| n.get_id())
1318                .collect(),
1319            operation: self.get_operation(),
1320        })
1321    }
1322
1323    fn downgrade(&self) -> WeakNode {
1324        WeakNode {
1325            body: Arc::downgrade(&self.body),
1326        }
1327    }
1328
1329    #[doc(hidden)]
1330    pub fn add_annotation(&self, annotation: NodeAnnotation) -> Result<Node> {
1331        self.get_graph()
1332            .get_context()
1333            .add_node_annotation(self, annotation)?;
1334        Ok(self.clone())
1335    }
1336
1337    #[doc(hidden)]
1338    pub fn get_annotations(&self) -> Result<Vec<NodeAnnotation>> {
1339        self.get_graph()
1340            .get_context()
1341            .get_node_annotations(self.clone())
1342    }
1343}
1344type WeakNodeBodyPointer = Weak<AtomicRefCell<NodeBody>>;
1345
1346struct WeakNode {
1347    body: WeakNodeBodyPointer,
1348}
1349
1350impl WeakNode {
1351    //upgrade function panics if the the Node pointer it downgraded from went out of scope
1352    fn upgrade(&self) -> Node {
1353        Node {
1354            body: self.body.upgrade().unwrap(),
1355        }
1356    }
1357}
1358
1359impl Clone for WeakNode {
1360    fn clone(&self) -> Self {
1361        WeakNode {
1362            body: self.body.clone(),
1363        }
1364    }
1365}
1366
/// Internal state of a computation graph; shared through `GraphBodyPointer`.
struct GraphBody {
    // Whether the graph has been finalized.
    finalized: bool,
    // Nodes owned (strongly) by the graph.
    nodes: Vec<Node>,
    // Output node, held as a weak reference. NOTE(review): presumably also stored in `nodes` — confirm.
    output_node: Option<WeakNode>,
    // Identifier of the graph. NOTE(review): presumably unique within its context — confirm.
    id: u64,
    // Parent context, held as a weak reference.
    context: WeakContext,
}

/// Serde-friendly counterpart of `GraphBody`: nodes are stored in serializable form and
/// the output node is a `u64` (presumably its node ID — TODO confirm).
///
/// NOTE: serde derive serializes fields in declaration order, so reordering fields may
/// break compatibility of order-sensitive formats.
#[derive(Serialize, Deserialize)]
struct SerializableGraphBody {
    finalized: bool,
    nodes: Vec<SerializableNode>,
    output_node: Option<u64>,
}

/// Shared, interior-mutable pointer to a graph body.
type GraphBodyPointer = Arc<AtomicRefCell<GraphBody>>;
1383
/// A structure that stores a pointer to a computation graph, where every node corresponds to an operation.
///
/// # Rust traits
///
/// [Clone] trait duplicates the pointer, not the underlying graph.
///
/// [PartialEq] trait compares pointers, not the related graphs.
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// let c = create_context().unwrap();
/// let g1 = c.create_graph().unwrap();
/// let g2 = c.create_graph().unwrap();
/// assert_ne!(g1, g2);
/// let g3 = g1.clone();
/// assert_eq!(g1, g3);
/// ```
#[cfg_attr(feature = "py-binding", struct_wrapper)]
pub struct Graph {
    // Shared pointer to the graph state; all clones of a `Graph` point to the same body.
    body: GraphBodyPointer,
}

// Reference-counted serializable graph body.
type SerializableGraph = Arc<SerializableGraphBody>;
1409
1410impl fmt::Debug for Graph {
1411    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1412        f.debug_struct("Graph")
1413            .field("body", &self.body.as_ptr())
1414            .finish()
1415    }
1416}
1417
1418impl Clone for Graph {
1419    /// Returns a new [Graph] value with a copy of the pointer to a computation graph.
1420    fn clone(&self) -> Self {
1421        Graph {
1422            body: self.body.clone(),
1423        }
1424    }
1425}
1426
1427impl PartialEq for Graph {
1428    /// Tests whether `self` and `other` graphs are equal via comparison of their respective pointers.
1429    ///
1430    /// # Arguments
1431    ///
1432    /// `other` - another [Graph] value
1433    ///
1434    /// # Returns
1435    ///
1436    /// `true` if `self` and `other` are equal, `false` otherwise
1437    fn eq(&self, other: &Self) -> bool {
1438        Arc::ptr_eq(&self.body, &other.body)
1439    }
1440}
1441
1442impl Eq for Graph {}
1443
1444impl Hash for Graph {
1445    /// Hashes the graph pointer.
1446    ///
1447    /// # Arguments
1448    ///
1449    /// `state` - state of a hash function that is changed after hashing the graph
1450    fn hash<H: Hasher>(&self, state: &mut H) {
1451        ptr::hash(&*self.body, state);
1452    }
1453}
1454
/// Public methods which are supposed to be imported in Python.
1456#[cfg_attr(feature = "py-binding", impl_wrapper)]
1457impl Graph {
1458    /// Applies [Context::set_main_graph] to the parent context and `this` graph. Returns the clone of `this`.
1459    ///
1460    /// # Returns
1461    ///
1462    /// This graph
1463    ///
1464    /// # Example
1465    ///
1466    /// ```
1467    /// # use ciphercore_base::graphs::create_context;
1468    /// # use ciphercore_base::data_types::{array_type, INT32};
1469    /// let c = create_context().unwrap();
1470    /// let g = c.create_graph().unwrap();
1471    /// let t = array_type(vec![3, 2], INT32);
1472    /// let n = g.input(t).unwrap();
1473    /// n.set_as_output().unwrap();
1474    /// g.finalize().unwrap();
1475    /// g.set_as_main().unwrap();
1476    /// ```
1477    pub fn set_as_main(&self) -> Result<Graph> {
1478        self.get_context().set_main_graph(self.clone())?;
1479        Ok(self.clone())
1480    }
1481
1482    /// Applies [Context::set_graph_name] to the parent context and `this` graph. Returns the clone of `this`.
1483    ///
1484    /// # Arguments
1485    ///
1486    /// `name` - name of the graph
1487    ///
1488    /// # Returns
1489    ///
1490    /// This graph
1491    ///
1492    /// # Example
1493    ///
1494    /// ```
1495    /// # use ciphercore_base::graphs::create_context;
1496    /// let c = create_context().unwrap();
1497    /// let g = c.create_graph().unwrap();
1498    /// g.set_name("relu").unwrap();
1499    /// ```
1500    pub fn set_name(&self, name: &str) -> Result<Graph> {
1501        self.get_context().set_graph_name(self.clone(), name)?;
1502        Ok(self.clone())
1503    }
1504
1505    /// Applies [Context::get_graph_name] to the parent context and `this` graph.
1506    ///
1507    /// # Example
1508    ///
1509    /// ```
1510    /// # use ciphercore_base::graphs::create_context;
1511    /// let c = create_context().unwrap();
1512    /// let g = c.create_graph().unwrap();
1513    /// g.set_name("relu").unwrap();
1514    /// assert_eq!(g.get_name().unwrap(), "relu".to_owned());
1515    /// ```
1516    pub fn get_name(&self) -> Result<String> {
1517        self.get_context().get_graph_name(self.clone())
1518    }
1519
1520    /// Applies [Context::retrieve_node] to the parent context and `this` graph.
1521    ///
1522    /// # Example
1523    ///
1524    /// ```
1525    /// # use ciphercore_base::graphs::create_context;
1526    /// # use ciphercore_base::data_types::{BIT, scalar_type};
1527    /// let c = create_context().unwrap();
1528    /// let g = c.create_graph().unwrap();
1529    /// let n = g.input(scalar_type(BIT)).unwrap();
1530    /// n.set_name("input_node").unwrap();
1531    /// assert!(n == g.retrieve_node("input_node").unwrap());
1532    /// ```
1533    pub fn retrieve_node(&self, name: &str) -> Result<Node> {
1534        self.get_context().retrieve_node(self.clone(), name)
1535    }
1536
1537    /// Adds an input node to the graph and returns it.
1538    ///
1539    /// During evaluation, input nodes require values to be supplied.
1540    ///
1541    /// # Arguments
1542    ///
1543    /// `input_type` - type of a new input node
1544    ///
1545    /// # Returns
1546    ///
1547    /// New input node
1548    ///
1549    /// # Example
1550    ///
1551    /// ```
1552    /// # use ciphercore_base::graphs::create_context;
1553    /// # use ciphercore_base::data_types::{BIT, scalar_type};
1554    /// let c = create_context().unwrap();
1555    /// let g = c.create_graph().unwrap();
1556    /// let t = scalar_type(BIT);
1557    /// let n = g.input(t).unwrap();
1558    /// ```
1559    pub fn input(&self, input_type: Type) -> Result<Node> {
1560        self.add_node(vec![], vec![], Operation::Input(input_type))
1561    }
1562
1563    /// Adds an node with zeros of given type.
1564    ///
1565    /// Compared to `constant` this node does produce a big value array in serialized graph.
1566    ///
1567    /// # Arguments
1568    ///
1569    /// `t` - node type
1570    ///
1571    /// # Returns
1572    ///
1573    /// New node with zeros of given type.
1574    ///
1575    /// # Example
1576    ///
1577    /// ```
1578    /// # use ciphercore_base::graphs::create_context;
1579    /// # use ciphercore_base::data_types::{UINT8, array_type};
1580    /// let c = create_context().unwrap();
1581    /// let g = c.create_graph().unwrap();
1582    /// let z = g.zeros(array_type(vec![10, 20], UINT8)).unwrap();
1583    /// ```
1584    pub fn zeros(&self, t: Type) -> Result<Node> {
1585        self.add_node(vec![], vec![], Operation::Zeros(t))
1586    }
1587
1588    /// Adds an node with ones of given type.
1589    ///
1590    /// Compared to `constant` this node does produce a big value array in serialized graph.
1591    ///
1592    /// # Arguments
1593    ///
1594    /// `t` - node type
1595    ///
1596    /// # Returns
1597    ///
1598    /// New node with ones of given type.
1599    ///
1600    /// # Example
1601    ///
1602    /// ```
1603    /// # use ciphercore_base::graphs::create_context;
1604    /// # use ciphercore_base::data_types::{UINT8, array_type};
1605    /// let c = create_context().unwrap();
1606    /// let g = c.create_graph().unwrap();
1607    /// let z = g.ones(array_type(vec![10, 20], UINT8)).unwrap();
1608    /// ```
1609    pub fn ones(&self, t: Type) -> Result<Node> {
1610        self.add_node(vec![], vec![], Operation::Ones(t))
1611    }
1612
1613    /// Adds a node that sums two arrays or scalars of the same scalar type elementwise.
1614    ///
1615    /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, adding two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1616    ///
1617    /// # Arguments
1618    ///
1619    /// * `a` - node containing the first term (array or scalar)
1620    /// * `b` - node containing the second term (array or scalar)
1621    ///
1622    /// # Returns
1623    ///
1624    /// New addition node
1625    ///
1626    /// # Example
1627    ///
1628    /// ```
1629    /// # use ciphercore_base::graphs::create_context;
1630    /// # use ciphercore_base::data_types::{BIT, scalar_type};
1631    /// let c = create_context().unwrap();
1632    /// let g = c.create_graph().unwrap();
1633    /// let t = scalar_type(BIT);
1634    /// let n1 = g.input(t.clone()).unwrap();
1635    /// let n2 = g.input(t).unwrap();
1636    /// let n3 = g.add(n1, n2).unwrap();
1637    /// ```
1638    pub fn add(&self, a: Node, b: Node) -> Result<Node> {
1639        self.add_node(vec![a, b], vec![], Operation::Add)
1640    }
1641
1642    /// Adds a node that subtracts two arrays or scalars of the same scalar type elementwise.
1643    ///
1644    /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, subtracting two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1645    ///
1646    /// # Arguments
1647    ///
1648    /// * `a` - node containing the minuend (array or scalar)
1649    /// * `b` - node containing the subtrahend (array or scalar)
1650    ///
1651    /// # Returns
1652    ///
1653    /// New subtraction node
1654    ///
1655    /// # Example
1656    ///
1657    /// ```
1658    /// # use ciphercore_base::graphs::create_context;
1659    /// # use ciphercore_base::data_types::{BIT, scalar_type};
1660    /// let c = create_context().unwrap();
1661    /// let g = c.create_graph().unwrap();
1662    /// let t = scalar_type(BIT);
1663    /// let n1 = g.input(t.clone()).unwrap();
1664    /// let n2 = g.input(t).unwrap();
1665    /// let n3 = g.subtract(n1, n2).unwrap();
1666    /// ```
1667    pub fn subtract(&self, a: Node, b: Node) -> Result<Node> {
1668        self.add_node(vec![a, b], vec![], Operation::Subtract)
1669    }
1670
1671    /// Adds a node that multiplies two arrays or scalars of the same scalar type elementwise.
1672    ///
1673    /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, multiplication of two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1674    ///
1675    /// # Arguments
1676    ///
1677    /// * `a` - node containing the first factor (array or scalar)
1678    /// * `b` - node containing the second factor (array or scalar)
1679    ///
1680    /// # Returns
1681    ///
1682    /// New multiplication node
1683    ///
1684    /// # Example
1685    ///
1686    /// ```
1687    /// # use ciphercore_base::graphs::create_context;
1688    /// # use ciphercore_base::data_types::{BIT, scalar_type};
1689    /// let c = create_context().unwrap();
1690    /// let g = c.create_graph().unwrap();
1691    /// let t = scalar_type(BIT);
1692    /// let n1 = g.input(t.clone()).unwrap();
1693    /// let n2 = g.input(t).unwrap();
1694    /// let n3 = g.multiply(n1, n2).unwrap();
1695    /// ```
1696    pub fn multiply(&self, a: Node, b: Node) -> Result<Node> {
1697        self.add_node(vec![a, b], vec![], Operation::Multiply)
1698    }
1699
1700    /// Adds a node that multiplies an integer array or scalar by a binary array or scalar elementwise.
1701    /// For each integer element, this operation returns this element or zero depending on the corresponding bit element.
1702    ///
1703    /// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)). For example, multiplication of two arrays of shapes `[10,1,7]` and `[8,1]` results in an array of shape `[10,8,7]`.
1704    ///
1705    /// # Arguments
1706    ///
1707    /// * `a` - node containing an integer array or scalar
1708    /// * `b` - node containing a binary array or scalar
1709    ///
1710    /// # Returns
1711    ///
1712    /// New mixed multiplication node
1713    ///
1714    /// # Example
1715    ///
1716    /// ```
1717    /// # use ciphercore_base::graphs::create_context;
1718    /// # use ciphercore_base::data_types::{BIT, INT32, scalar_type};
1719    /// let c = create_context().unwrap();
1720    /// let g = c.create_graph().unwrap();
1721    /// let t1 = scalar_type(INT32);
1722    /// let t2 = scalar_type(BIT);
1723    /// let n1 = g.input(t1).unwrap();
1724    /// let n2 = g.input(t2).unwrap();
1725    /// let n3 = g.mixed_multiply(n1, n2).unwrap();
1726    /// ```
1727    pub fn mixed_multiply(&self, a: Node, b: Node) -> Result<Node> {
1728        self.add_node(vec![a, b], vec![], Operation::MixedMultiply)
1729    }
1730
1731    /// Adds a node that computes the dot product according to [the NumPy rules](https://numpy.org/doc/stable/reference/generated/numpy.dot.html):
1732    /// * if both factors are 1-dimensional arrays, return their inner product;
1733    /// * if both factors are 2-dimensional arrays, return their matrix product;
1734    /// * if one of the factors is scalar, return the result of [multiply](Graph::multiply);
1735    /// * if the first factor is n-dimensional and the second one is 1-dimensional,
1736    /// compute the elementwise multiplication and return the sum over the last axis.
1737    /// * if both factors are n-dimensional (n>2), return the sum product
1738    /// over the last axis of the first factor and the second-to-last axis of the second factor, i.e.
1739    ///
1740    /// `dot(A, B)[i,j,k,m] = sum(A[i,j,:] * B[k,:,m])` (in the NumPy notation).
1741    ///
1742    /// # Arguments
1743    ///
1744    /// * `a` - node containing the first factor (array or scalar)
1745    /// * `b` - node containing the second factor (array or scalar)
1746    ///
1747    /// # Returns
1748    ///
1749    /// New dot product node
1750    ///
1751    /// # Example
1752    ///
1753    /// ```
1754    /// # use ciphercore_base::graphs::create_context;
1755    /// # use ciphercore_base::data_types::{INT32, array_type};
1756    /// let c = create_context().unwrap();
1757    /// let g = c.create_graph().unwrap();
1758    /// let t = array_type(vec![10], INT32);
1759    /// let n1 = g.input(t.clone()).unwrap();
1760    /// let n2 = g.input(t).unwrap();
1761    /// let n3 = g.dot(n1, n2).unwrap();
1762    /// ```
1763    pub fn dot(&self, a: Node, b: Node) -> Result<Node> {
1764        self.add_node(vec![a, b], vec![], Operation::Dot)
1765    }
1766
1767    /// Adds a node that computes the matrix product of two arrays according to [the NumPy rules](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
1768    ///
1769    /// Each array is represented as an array of 2-dimensional matrix elements and this node returns the elementwise product of such matrix arrays.
1770    ///
1771    /// # Arguments
1772    ///
1773    /// * `a` - node containing the first array
1774    /// * `b` - node containing the second array
1775    ///
1776    /// # Returns
1777    ///
1778    /// New matrix product node
1779    ///
1780    /// # Example
1781    ///
1782    /// ```
1783    /// # use ciphercore_base::graphs::create_context;
1784    /// # use ciphercore_base::data_types::{INT32, array_type};
1785    /// let c = create_context().unwrap();
1786    /// let g = c.create_graph().unwrap();
1787    /// let t1 = array_type(vec![2, 3], INT32);
1788    /// let t2 = array_type(vec![3, 2], INT32);
1789    /// let n1 = g.input(t1).unwrap();
1790    /// let n2 = g.input(t2).unwrap();
1791    /// let n3 = g.matmul(n1, n2).unwrap();
1792    /// ```
1793    pub fn matmul(&self, a: Node, b: Node) -> Result<Node> {
1794        self.add_node(vec![a, b], vec![], Operation::Matmul)
1795    }
1796
1797    /// Adds a node that computes the general matrix product of two arrays according to [the ONNX rules](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm) with `alpha = 1`, `beta = 0` and `C = 0`.
1798    ///
1799    /// Each array is represented as an array of 2-dimensional matrix elements and this node returns the elementwise product of such matrix arrays.
1800    /// Each matrix should have at least 2 dimensions.
1801    /// To multiply by 1-dimensional matrices (i.e., vectors), please resort to `matmul` or `dot`.
1802    ///
1803    /// # Arguments
1804    ///
1805    /// * `a` - node containing the first array
1806    /// * `b` - node containing the second array
1807    /// * `transpose_a` - if true, the first array will be transposed
1808    /// * `transpose_b` - if true, the second array will be transposed
1809    ///
1810    /// # Returns
1811    ///
1812    /// New Gemm node
1813    ///
1814    /// # Example
1815    ///
1816    /// ```
1817    /// # use ciphercore_base::graphs::create_context;
1818    /// # use ciphercore_base::data_types::{INT32, array_type};
1819    /// let c = create_context().unwrap();
1820    /// let g = c.create_graph().unwrap();
1821    /// let t1 = array_type(vec![2, 3], INT32);
1822    /// let t2 = array_type(vec![2, 3], INT32);
1823    /// let n1 = g.input(t1).unwrap();
1824    /// let n2 = g.input(t2).unwrap();
1825    /// let n3 = g.gemm(n1, n2, false, true).unwrap();
1826    /// ```
1827    #[doc(hidden)]
1828    pub fn gemm(&self, a: Node, b: Node, transpose_a: bool, transpose_b: bool) -> Result<Node> {
1829        self.add_node(
1830            vec![a, b],
1831            vec![],
1832            Operation::Gemm(transpose_a, transpose_b),
1833        )
1834    }
1835
1836    /// Adds a node that computes a join of a given type on two named tuples along given key headers.
1837    ///
1838    /// Each tuple should consist of arrays having the same number of rows, i.e. the first dimensions of these arrays should be equal.
1839    ///
1840    /// In addition, each named tuple should have a binary array named with NULL_HEADER that contains zeros in rows void of content; otherwise, it contains ones.
1841    /// This column is called the null column.
1842    ///
1843    /// Let `row key` be the bitstring obtained by concatenating data elements for given key headers.
1844    /// **WARNING**: Rows must have unique row keys, except for rows where NULL_HEADER is zero: those rows are ignored.
1845    ///
1846    /// This operation returns:
1847    /// - Inner join: a named tuple containing rows where input tuples have matching row keys.
1848    /// - Left join: a named tuple containing all the rows of the first input tuple merged with the rows of the second input tuple having same row keys.
1849    /// - Union join: a named tuple containing rows of the first input tuple that are not in the inner join and all the rows of the second tuple.
1850    /// In contrast to the SQL union, this operation does not require that input datasets have respective columns of the same type.
1851    /// This means that columns of both datasets are included and filled with zeros where no data can be retrieved.
1852    /// Namely, the rows of the second set in the union join will contain zeros in non-key columns of the first set and vice versa.
1853    /// - Full join: a named tuple containing all the rows of input tuples.
1854    /// If the row key of a row of the first set match with the row key of a row of the second set, they are merged into one.
1855    /// The order of rows goes as follows:
1856    /// 1. the rows of the first set that don't belong to the inner join.
1857    /// 2. all the rows of the second set including those merged with the rows of the first set as in inner join.
1858    /// In this form, full join is computed as `union_join(a, left_join(b, a))`.
1859    ///
1860    /// # Arguments
1861    ///
1862    /// * `a` - node containing the first named tuple
1863    /// * `b` - node containing the second named tuple
1864    /// * `t` - join type (Inner/Left/Union/Full)
1865    /// * `headers` - headers of columns along which the join is performed
1866    ///
1867    /// # Returns
1868    ///
1869    /// New join node
1870    ///
1871    /// # Example
1872    ///
1873    /// ```
1874    /// # use ciphercore_base::graphs::{create_context, JoinType};
1875    /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type};
1876    /// # use ciphercore_base::type_inference::NULL_HEADER;
1877    /// # use std::collections::HashMap;
1878    /// let c = create_context().unwrap();
1879    /// let g = c.create_graph().unwrap();
1880    /// let t1n = array_type(vec![100], BIT);
1881    /// let t11 = array_type(vec![100], INT32);
1882    /// let t12 = array_type(vec![100, 128], BIT);
1883    /// let t13 = array_type(vec![100], INT64);
1884    /// let t2n = array_type(vec![50], BIT);
1885    /// let t21 = array_type(vec![50], INT32);
1886    /// let t22 = array_type(vec![50, 128], BIT);
1887    /// let t23 = array_type(vec![50], UINT8);
1888    /// let t1 = named_tuple_type(vec![
1889    ///     (NULL_HEADER.to_owned(), t1n),
1890    ///     ("ID".to_owned(), t11),
1891    ///     ("Occupation".to_owned(), t12),
1892    ///     ("Revenue".to_owned(), t13),
1893    /// ]);
1894    /// let t2 = named_tuple_type(vec![
1895    ///     (NULL_HEADER.to_owned(), t2n),
1896    ///     ("ID".to_owned(), t21),
1897    ///     ("Job".to_owned(), t22),
1898    ///     ("Age".to_owned(), t23),
1899    /// ]);
1900    /// let n1 = g.input(t1).unwrap();
1901    /// let n2 = g.input(t2).unwrap();
1902    /// let n3 = g.join(n1, n2, JoinType::Inner, HashMap::from([
1903    ///     ("ID".to_owned(), "ID".to_owned()),
1904    ///     ("Occupation".to_owned(), "Job".to_owned()),
1905    /// ])).unwrap();
1906    /// ```
1907    pub fn join(
1908        &self,
1909        a: Node,
1910        b: Node,
1911        t: JoinType,
1912        headers: HashMap<String, String>,
1913    ) -> Result<Node> {
1914        self.add_node(vec![a, b], vec![], Operation::Join(t, headers))
1915    }
1916
1917    /// Adds a node that computes a join of a given type on two named tuples along given key headers.
1918    ///
1919    /// Each tuple should consist of pairs of arrays having the same number of rows, i.e. the first dimensions of these arrays should be equal.
1920    /// Each pair has a binary array that contains zeros in rows where the data array has no content and an array with data.
1921    ///
1922    /// In addition, each named tuple should have a binary array named with NULL_HEADER that contains zeros in rows void of content; otherwise, it contains ones.
1923    /// This column is called the null column.
1924    ///
1925    /// Let `row key` be the bitstring obtained by concatenating data elements for given key headers where all the corresponding mask elements are set to one.
1926    /// **WARNING**: Rows must have unique row keys, except for rows where NULL_HEADER is zero or at least one mask element in given key headers is zero.
1927    /// Rows with zero NULL_HEADER are ignored.
1928    /// We assume that rows with zero mask elements don't match with other rows.
1929    /// Thus, they can't show up in inner and left joins, but can be copied over to the result of union or full joins.
1930    ///
1931    /// This operation returns:
1932    /// - Inner join: a named tuple containing rows where input tuples have matching row keys.
1933    /// - Left join: a named tuple containing all the rows of the first input tuple merged with the rows of the second input tuple having same row keys.
1934    /// - Union join: a named tuple containing rows of the first input tuple that are not in the inner join and all the rows of the second tuple.
1935    /// In contrast to the SQL union, this operation does not require that input datasets have respective columns of the same type.
1936    /// This means that columns of both datasets are included and filled with zeros where no data can be retrieved.
1937    /// Namely, the rows of the second set in the union join will contain zeros in non-key columns of the first set and vice versa.
1938    /// - Full join: a named tuple containing all the rows of input tuples.
1939    /// If the row key of a row of the first set match with the row key of a row of the second set, they are merged into one.
1940    /// The order of rows goes as follows:
1941    /// 1. the rows of the first set that don't belong to the inner join.
1942    /// 2. all the rows of the second set including those merged with the rows of the first set as in inner join.
1943    /// In this form, full join is computed as `union_join(a, left_join(b, a))`.
1944    ///
1945    /// # Arguments
1946    ///
1947    /// * `a` - node containing the first named tuple
1948    /// * `b` - node containing the second named tuple
1949    /// * `t` - join type (Inner/Left/Union/Full)
1950    /// * `headers` - headers of columns along which the join is performed
1951    ///
1952    /// # Returns
1953    ///
1954    /// New join node
1955    ///
1956    /// # Example
1957    ///
1958    /// ```
1959    /// # use ciphercore_base::graphs::{create_context, JoinType};
1960    /// # use ciphercore_base::data_types::{INT32, INT64, UINT8, BIT, array_type, named_tuple_type, tuple_type};
1961    /// # use ciphercore_base::type_inference::NULL_HEADER;
1962    /// # use std::collections::HashMap;
1963    /// let c = create_context().unwrap();
1964    /// let g = c.create_graph().unwrap();
1965    /// let t1n = array_type(vec![100], BIT);
1966    /// let t11 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT32)]);
1967    /// let t12 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100, 128], BIT)]);
1968    /// let t13 = tuple_type(vec![array_type(vec![100], BIT), array_type(vec![100], INT64)]);
1969    /// let t2n = array_type(vec![50], BIT);
1970    /// let t21 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], INT32)]);
1971    /// let t22 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50, 128], BIT)]);
1972    /// let t23 = tuple_type(vec![array_type(vec![50], BIT), array_type(vec![50], UINT8)]);
1973    /// let t1 = named_tuple_type(vec![
1974    ///     (NULL_HEADER.to_owned(), t1n),
1975    ///     ("ID".to_owned(), t11),
1976    ///     ("Occupation".to_owned(), t12),
1977    ///     ("Revenue".to_owned(), t13),
1978    /// ]);
1979    /// let t2 = named_tuple_type(vec![
1980    ///     (NULL_HEADER.to_owned(), t2n),
1981    ///     ("ID".to_owned(), t21),
1982    ///     ("Job".to_owned(), t22),
1983    ///     ("Age".to_owned(), t23),
1984    /// ]);
1985    /// let n1 = g.input(t1).unwrap();
1986    /// let n2 = g.input(t2).unwrap();
1987    /// let n3 = g.join_with_column_masks(n1, n2, JoinType::Inner, HashMap::from([
1988    ///     ("ID".to_owned(), "ID".to_owned()),
1989    ///     ("Occupation".to_owned(), "Job".to_owned()),
1990    /// ])).unwrap();
1991    /// ```
1992    pub fn join_with_column_masks(
1993        &self,
1994        a: Node,
1995        b: Node,
1996        t: JoinType,
1997        headers: HashMap<String, String>,
1998    ) -> Result<Node> {
1999        self.add_node(
2000            vec![a, b],
2001            vec![],
2002            Operation::JoinWithColumnMasks(t, headers),
2003        )
2004    }
2005
2006    /// Adds a node that applies a permutation to the array along the first dimension.
2007    ///
2008    /// # Arguments
2009    ///
2010    /// * `a` - node containing an array to permute.
2011    /// * `p` - node containing a permutation.
2012    ///
2013    /// # Returns
2014    ///
2015    /// New permuted node
2016    ///
2017    /// # Example
2018    ///
2019    /// ```
2020    /// # use ciphercore_base::graphs::create_context;
2021    /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
2022    /// let c = create_context().unwrap();
2023    /// let g = c.create_graph().unwrap();
2024    /// let t = array_type(vec![25, 3], INT32);
2025    /// let a = g.input(t).unwrap();
2026    /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
2027    /// let a = g.apply_permutation(a, p).unwrap();
2028    /// ```
2029    #[doc(hidden)]
2030    pub fn apply_permutation(&self, a: Node, p: Node) -> Result<Node> {
2031        self.add_node(vec![a, p], vec![], Operation::ApplyPermutation(false))
2032    }
2033
2034    /// Adds a node that applies an inverse permutation to the array along the first dimension.
2035    ///
2036    /// # Arguments
2037    ///
2038    /// * `a` - node containing an array to permute.
2039    /// * `p` - node containing a permutation.
2040    ///
2041    /// # Returns
2042    ///
2043    /// New permuted node
2044    ///
2045    /// # Example
2046    ///
2047    /// ```
2048    /// # use ciphercore_base::graphs::create_context;
2049    /// # use ciphercore_base::data_types::{INT32, UINT64, array_type};
2050    /// let c = create_context().unwrap();
2051    /// let g = c.create_graph().unwrap();
2052    /// let t = array_type(vec![25, 3], INT32);
2053    /// let a = g.input(t).unwrap();
2054    /// let p = g.input(array_type(vec![25], UINT64)).unwrap();
2055    /// let a = g.apply_inverse_permutation(a, p).unwrap();
2056    /// ```
2057    #[doc(hidden)]
2058    pub fn apply_inverse_permutation(&self, a: Node, p: Node) -> Result<Node> {
2059        self.add_node(vec![a, p], vec![], Operation::ApplyPermutation(true))
2060    }
2061
2062    /// Adds a node that sorts a table given as named tuple according to the column given by the key argument.
2063    /// The key column must be a 2-d BIT array of shape [n, b], interpreted as bitstrings of length b.
2064    /// Other columns in the named tuple must be arrays of arbitrary type and shape, as long as they
2065    /// share the first dimension: [n, ...].
2066    /// Bitstrings are sorted lexicographically, and the sorting algorithm is stable: preserving relative
2067    /// order of entries in other arrays where the corresponding key entries match.
2068    ///
2069    /// # Arguments
2070    /// * `a` - node containing a named tuple -- arrays to sort.
2071    /// * `key` - name of the field to sort on it, this array must be 2-d of type BIT.
2072    ///
2073    /// # Returns
2074    ///
2075    /// New sorted node
2076    ///
2077    /// # Example
2078    ///
2079    /// ```
2080    /// # use ciphercore_base::graphs::create_context;
2081    /// # use ciphercore_base::data_types::{BIT, INT32, UINT64, array_type, named_tuple_type};
2082    /// let c = create_context().unwrap();
2083    /// let g = c.create_graph().unwrap();
2084    /// let v1 = g.input(array_type(vec![20], INT32)).unwrap();
2085    /// let v2 = g.input(array_type(vec![20, 10, 2], UINT64)).unwrap();
2086    /// let k = g.input(array_type(vec![20, 32], BIT)).unwrap();
2087    /// let a = g.create_named_tuple(vec![("key".to_string(), k), ("value1".to_string(), v1), ("value2".to_string(), v2)]).unwrap();
2088    /// let a = g.sort(a, "key".to_string()).unwrap();
2089    /// ```
2090    pub fn sort(&self, a: Node, key: String) -> Result<Node> {
2091        self.add_node(vec![a], vec![], Operation::Sort(key))
2092    }
2093
2094    /// Adds a node that divides a scalar or each entry of an array by a positive constant integer `scale`.
2095    ///
2096    /// # Arguments
2097    ///
2098    /// * `a` - node containing a scalar or an array
2099    /// * `scale` - positive integer
2100    ///
2101    /// # Returns
2102    ///
2103    /// New truncate node
2104    ///
2105    /// # Example
2106    ///
2107    /// ```
2108    /// # use ciphercore_base::graphs::create_context;
2109    /// # use ciphercore_base::data_types::{INT32, array_type};
2110    /// let c = create_context().unwrap();
2111    /// let g = c.create_graph().unwrap();
2112    /// let t = array_type(vec![2, 3], INT32);
2113    /// let n1 = g.input(t).unwrap();
2114    /// let n2 = g.truncate(n1, 4).unwrap();
2115    /// ```
2116    pub fn truncate(&self, a: Node, scale: u128) -> Result<Node> {
2117        self.add_node(vec![a], vec![], Operation::Truncate(scale))
2118    }
2119
2120    /// Adds a node that computes the sum of entries of an array along given axes (see [numpy.sum](https://numpy.org/doc/stable/reference/generated/numpy.sum.html)).
2121    ///
2122    /// For example, summing the array `[[1000, 200], [30, 4]]` along the first or the second axes results in the arrays `[1030,204]` or `[1200,34]`, respectively. Summing along both axes yields `1234`.
2123    ///
2124    /// # Arguments
2125    ///
2126    /// * `a` - node containing an array
2127    /// * `axes` - indices of the axes of `a`
2128    ///
2129    /// # Returns
2130    ///
2131    /// New sum node
2132    ///
2133    /// # Example
2134    ///
2135    /// ```
2136    /// # use ciphercore_base::graphs::create_context;
2137    /// # use ciphercore_base::data_types::{INT32, array_type};
2138    /// let c = create_context().unwrap();
2139    /// let g = c.create_graph().unwrap();
2140    /// let t = array_type(vec![3, 2, 3], INT32);
2141    /// let axes = vec![1, 0];
2142    /// let n1 = g.input(t).unwrap();
2143    /// let n2 = g.sum(n1, axes).unwrap();
2144    /// ```
2145    pub fn sum(&self, a: Node, axes: ArrayShape) -> Result<Node> {
2146        self.add_node(vec![a], vec![], Operation::Sum(axes))
2147    }
2148
2149    /// Adds a node that computes the cumulative sum of elements along a given axis. (see [numpy.cumsum](https://numpy.org/doc/stable/reference/generated/numpy.cumsum.html)).
2150    ///
2151    /// For example, summing the array `[[1000, 200], [30, 4]]` along the first or the second axes results in the arrays `[[1000, 200], [1030, 204]]` or `[[1000, 1200], [30, 34]]`, respectively.
2152    ///
2153    /// # Arguments
2154    ///
2155    /// * `a` - node containing an array
2156    /// * `axis` - axis along which the cumulative sum is computed
2157    ///
2158    /// # Returns
2159    ///
2160    /// New cumulative sum node
2161    ///
2162    /// # Example
2163    ///
2164    /// ```
2165    /// # use ciphercore_base::graphs::create_context;
2166    /// # use ciphercore_base::data_types::{INT32, array_type};
2167    /// let c = create_context().unwrap();
2168    /// let g = c.create_graph().unwrap();
2169    /// let t = array_type(vec![3, 2], INT32);
2170    /// let n1 = g.input(t).unwrap();
2171    /// let n2 = g.cum_sum(n1, 1).unwrap();
2172    /// ```
2173    pub fn cum_sum(&self, a: Node, axis: u64) -> Result<Node> {
2174        self.add_node(vec![a], vec![], Operation::CumSum(axis))
2175    }
2176
2177    /// Adds a node that permutes an array along given axes (see [numpy.transpose](https://numpy.org/doc/stable/reference/generated/numpy.transpose.html)). This function generalizes matrix transposition.
2178    ///
2179    /// For example, permutation of an array of shape `[a,b,c]` with permutation `[2,0,1]` results in an array of shape `[c,a,b]`.
2180    ///
2181    /// # Arguments
2182    ///
2183    /// * `a` - node containing an array
2184    /// * `axes` - indices of the axes of `a`
2185    ///
2186    /// # Returns
2187    ///
2188    /// New node with permuted axes
2189    ///
2190    /// # Example
2191    ///
2192    /// ```
2193    /// # use ciphercore_base::graphs::create_context;
2194    /// # use ciphercore_base::data_types::{INT32, array_type};
2195    /// let c = create_context().unwrap();
2196    /// let g = c.create_graph().unwrap();
2197    /// let t = array_type(vec![3, 2, 3], INT32);
2198    /// let axes = vec![1, 0, 2];
2199    /// let n1 = g.input(t).unwrap();
2200    /// let n2 = g.permute_axes(n1, axes).unwrap();
2201    /// ```
2202    pub fn permute_axes(&self, a: Node, axes: ArrayShape) -> Result<Node> {
2203        self.add_node(vec![a], vec![], Operation::PermuteAxes(axes))
2204    }
2205
2206    /// Adds a node to the parent graph that inverts a given permutation.
2207    ///
2208    /// An input permutation should be given by a 1-dimensional array of length n, containing unique integers between 0 and n-1.
2209    /// The i-th element of an output array is output[i] = j if input[j] = i.
2210    ///
2211    /// This operation could be realized through [the Scatter operation](https://en.wikipedia.org/wiki/Gather-scatter_(vector_addressing)#Scatter).
2212    /// However, the Scatter operation poses a security risk as the corresponding map should hide empty output positions.
2213    /// This is usually done by padding an input array with dummy values such that its size is equal to the output size.
2214    /// Then, the Scatter map can be turned into a permutation, which can be easily split into a composition of random permutation maps.
2215    /// But permutation maps can be performed by Gather, thus making Scatter unnecessary.
2216    ///
2217    /// **WARNING**: this function should not be used before MPC compilation.
2218    ///
2219    /// # Arguments
2220    ///
2221    /// `a` - node containing an array with permutation.
2222    #[doc(hidden)]
2223    pub fn inverse_permutation(&self, a: Node) -> Result<Node> {
2224        self.add_node(vec![a], vec![], Operation::InversePermutation)
2225    }
2226
2227    /// Adds a node that extracts a sub-array with a given index. This is a special case of [get_slice](Graph::get_slice) and corresponds to single element indexing as in [NumPy](https://numpy.org/doc/stable/user/basics.indexing.html).
2228    ///
2229    /// For example, given an array `A` of shape `[a,b,c,d]`, its subarray `B` of shape `[c,d]` with index `[i,j]` can be extracted as follows
2230    ///
2231    /// `B = A[i,j,:,:]` (in the NumPy notation)
2232    ///
2233    /// # Arguments
2234    ///
2235    /// * `a` - node containing an array
2236    /// * `index` - index of a sub-array
2237    ///
2238    /// # Returns
2239    ///
2240    /// New node containing an extracted sub-array
2241    ///
2242    /// # Example
2243    ///
2244    /// ```
2245    /// # use ciphercore_base::graphs::create_context;
2246    /// # use ciphercore_base::data_types::{INT32, array_type};
2247    /// let c = create_context().unwrap();
2248    /// let g = c.create_graph().unwrap();
2249    /// let t = array_type(vec![3, 2, 3], INT32);
2250    /// let index = vec![2];
2251    /// let n1 = g.input(t).unwrap();
2252    /// let n2 = g.get(n1, index).unwrap();
2253    /// ```
2254    pub fn get(&self, a: Node, index: ArrayShape) -> Result<Node> {
2255        self.add_node(vec![a], vec![], Operation::Get(index))
2256    }
2257
2258    /// Adds a node that extracts a sub-array corresponding to a given slice.
2259    ///
2260    /// Our slicing conventions follow [the NumPy rules](https://numpy.org/doc/stable/user/basics.indexing.html).
2261    ///
2262    /// For example, given an array `A` of shape `[a,b]`, its subarray `B` containing only the last 3 rows of `A` can be extracted as follows
2263    ///
2264    /// `get_slice(A, [-3::])[i,j] = A[a-3+i,j]`.
2265    ///
2266    /// Slices are defined as vectors of [SliceElements](SliceElement) that have 3 possible types:
2267    ///
2268    /// * [SingleIndex(`i64`)](SliceElement::SingleIndex) is used to extract all the elements with a given index in a respective dimension,
2269    /// * [SubArray(`Option<i64>, Option<i64>, Option<i64>`)](SliceElement::SubArray) describes the range of indices that should be extracted over a certain dimension (similar to the `a:b:c` notation in [NumPy](https://numpy.org/doc/stable/user/basics.indexing.html))
2270    /// * [Ellipsis](SliceElement::Ellipsis) describes several consecutive dimensions that must be extracted in full, e.g. the slice `[i,...,j]` can be used to extract all the elements with the index `i` in the first dimension and the index `j` in the last one, while the indices of all the other dimensions have no constraints. See [the NumPy slicing](https://numpy.org/doc/stable/user/basics.indexing.html) for more details.
2271    ///
2272    /// # Arguments
2273    ///
2274    /// * `a` - node containing an array
2275    /// * `slice` - array slice
2276    ///
2277    /// # Returns
2278    ///
2279    /// New node containing an extracted sub-array
2280    ///
2281    /// # Example
2282    ///
2283    /// ```
2284    /// # use ciphercore_base::graphs::{create_context, SliceElement};
2285    /// # use ciphercore_base::data_types::{INT32, array_type};
2286    /// let c = create_context().unwrap();
2287    /// let g = c.create_graph().unwrap();
2288    /// let t = array_type(vec![3, 2, 3], INT32);
2289    /// let slice = vec![SliceElement::Ellipsis, SliceElement::SubArray(None, None, Some(-2))];
2290    /// let n1 = g.input(t).unwrap();
2291    /// let n2 = g.get_slice(n1, slice).unwrap();
2292    /// ```
2293    pub fn get_slice(&self, a: Node, slice: Slice) -> Result<Node> {
2294        self.add_node(vec![a], vec![], Operation::GetSlice(slice))
2295    }
2296
2297    /// Adds a node that reshapes a value to a given compatible type (similar to [numpy.reshape](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.reshape.html?highlight=reshape#numpy.ndarray.reshape), but more general). Specifically,
2298    ///
2299    /// * if the input value is an array, it can be reshaped to any array with the same number of elements;
2300    /// * if the input value in the flattened form contains `n` arrays or scalars, it can be reshaped to any type with the same number of arrays and scalars. Each array can be reshaped as in the above rule.
2301    ///
2302    /// For example, an array of shape `[3,10,5]` can be reshaped to `[2,75]`. A tuple with arrays of shapes `[3,4]`, `[12]`, `[2,6]` can be reshaped to a vector with 3 array elements of shape `[2,2,3]`.
2303    ///
2304    /// # Arguments
2305    ///
2306    /// * `a` - node containing a value
2307    /// * `new_type` - type
2308    ///
2309    /// # Returns
2310    ///
2311    /// New node with a reshaped value
2312    ///
2313    /// # Example
2314    ///
2315    /// ```
2316    /// # use ciphercore_base::graphs::create_context;
2317    /// # use ciphercore_base::data_types::{INT32, array_type};
2318    /// let c = create_context().unwrap();
2319    /// let g = c.create_graph().unwrap();
2320    /// let old_t = array_type(vec![3, 2, 3], INT32);
2321    /// let new_t = array_type(vec![3,6], INT32);
2322    /// let n1 = g.input(old_t).unwrap();
2323    /// let n2 = g.reshape(n1, new_t).unwrap();
2324    /// ```
2325    pub fn reshape(&self, a: Node, new_type: Type) -> Result<Node> {
2326        let size_estimate = get_size_estimation_in_bits(new_type.clone());
2327        if size_estimate.is_err() {
2328            return Err(runtime_error!(
2329                "Trying to add a reshape node with invalid type size: {:?}",
2330                size_estimate
2331            ));
2332        }
2333        if size_estimate? > type_size_limit_constants::MAX_INDIVIDUAL_NODE_SIZE {
2334            return Err(runtime_error!(
2335                "Trying to add a reshape node larger than MAX_INDIVIDUAL_NODE_SIZE"
2336            ));
2337        }
2338        self.add_node(vec![a], vec![], Operation::Reshape(new_type))
2339    }
2340
2341    /// Adds a node creating a random value of a given type.
2342    ///
2343    /// **WARNING**: this function should not be used before MPC compilation.
2344    ///
2345    /// # Arguments
2346    ///
2347    /// `output_type` - type of a constant
2348    ///
2349    /// # Returns
2350    ///
2351    /// New random node
2352    ///
2353    /// # Example
2354    ///
2355    /// ```
2356    /// # use ciphercore_base::graphs::create_context;
2357    /// # use ciphercore_base::data_types::{BIT, scalar_type};
2358    /// let c = create_context().unwrap();
2359    /// let g = c.create_graph().unwrap();
2360    /// let t = scalar_type(BIT);
2361    /// let n = g.random(t).unwrap();
2362    /// ```
2363    #[doc(hidden)]
2364    pub fn random(&self, output_type: Type) -> Result<Node> {
2365        self.add_node(vec![], vec![], Operation::Random(output_type))
2366    }
2367
2368    /// Adds a node creating a random permutation map of a one-dimensional array of length `n`.
2369    ///
2370    /// This operation generates a random array of all 64-bit integers from 0 to n-1 in random order.
2371    ///
2372    /// **WARNING**: this function should not be used before MPC compilation.
2373    ///
2374    /// # Arguments
2375    ///
2376    /// `n` - length of permutation
2377    ///
2378    /// # Returns
2379    ///
2380    /// New random permutation node
2381    #[doc(hidden)]
2382    pub fn random_permutation(&self, n: u64) -> Result<Node> {
2383        self.add_node(vec![], vec![], Operation::RandomPermutation(n))
2384    }
2385
2386    /// Adds a node returning the Cuckoo hash map of an input array of binary strings using provided hash functions.
2387    ///
2388    /// Hash functions are defined as an array of binary matrices.
2389    /// The hash of an input string is a product of one of these matrices and this string.
2390    /// Hence, the last dimension of these matrices should coincide with the length of input strings.
2391    ///
2392    /// Random matrices yield a better success probability of hashing.
2393    ///
2394    /// If the input array has shape `[..., n, b]` and hash matrices are given as an `[h, m, b]`-array,
2395    /// then the hash map is an array of shape `[..., 2^m]`.
2396    /// The hash table element with index `[..., i]` is equal to `j` if the `[..., j]`-th input `b`-bit string is hashed to `i` by some of the given hash functions.
2397    ///
2398    /// The number of hash matrices (the first dimension of hash matrices) must be at least 3.
2399    ///
2400    /// A bigger ratio `m/n` leads to higher success probability (recommended one is `>=2`)    
2401    ///
2402    /// **WARNING**: this function should not be used before MPC compilation.
2403    ///
2404    /// # Arguments
2405    ///
2406    /// - `array` - input array of binary strings of shape [..., n, b]
2407    /// - `hash_matrices` - random binary [h, m, b]-array.
2408    ///
2409    /// # Returns
2410    ///
2411    /// New CuckooHash node
2412    #[doc(hidden)]
2413    pub fn cuckoo_hash(&self, array: Node, hash_matrices: Node) -> Result<Node> {
2414        self.add_node(vec![array, hash_matrices], vec![], Operation::CuckooHash)
2415    }
2416
2417    /// Adds a node that, given an input multidimensional array A, binary one-dimensional array B (first dimension is n in both array) and starting value v, computes the following iteration
2418    ///
2419    /// output[i] = A[i-1] + B[i-1] * output[i-1]
2420    ///
2421    /// where i in {1,...,n} and output[0] = v.
2422    /// This is similar to computing cumulative sums of consecutive elements (segments) of the input array A.
2423    /// The locations of these segments are defined by the binary array B.
2424    ///
2425    /// This iteration is used in the Duplication protocol (see mpc::mpc_psi) and is done locally by one of the computing parties.
2426    ///
2427    /// **WARNING**: this function should not be used before MPC compilation.
2428    ///
2429    /// # Arguments
2430    ///
2431    /// - `input_array` - input array whose rows are summed within the iteration
2432    /// - `binary_array` - binary array indicating whether a row of the input array should be added to a previous row of the output array
2433    /// - `first_row` - first row of the output array
2434    ///
2435    /// # Returns
2436    ///
2437    /// New SegmentCumSum node containing the output array
2438    #[doc(hidden)]
2439    pub fn segment_cumsum(
2440        &self,
2441        input_array: Node,
2442        binary_array: Node,
2443        first_row: Node,
2444    ) -> Result<Node> {
2445        self.add_node(
2446            vec![input_array, binary_array, first_row],
2447            vec![],
2448            Operation::SegmentCumSum,
2449        )
2450    }
2451
2452    /// Adds a node that computes sharding of a given table according to a given sharding config.
2453    /// Sharding config contains names of the columns whose hashed values are used for sharding.
2454    /// The size of each shard (i.e., the number of rows) and the number of shards is given in the sharding config.
2455    /// The number of shards should be smaller than 700.
2456    ///
2457    ///
2458    /// If some resulting shards don't have `shard_size` elements, they're padded with zeros to reach this size.
2459    /// If the size of some shards exceeds `shard_size`, sharding fails.
2460    ///
2461    /// To choose these parameters, consult [the following paper](http://wwwmayr.informatik.tu-muenchen.de/personen/raab/publ/balls.pdf).
2462    /// Note that for large shard sizes and small number of shards, it holds that
2463    ///
2464    /// `shard_size = num_input_rows / num_shards + alpha * sqrt(2 * num_input_rows / num_shards * log(num_shards))`.
2465    ///
2466    /// With `alpha = 2`, it is possible to achieve failure probability 2^(-40) if `num_shards < 700` and `shard_size > 2^17`.
2467    ///
2468    ///
2469    /// Each shard is accompanied by a Boolean mask indicating whether a corresponding row stems from the input table or padded (1 if a row comes from input).
2470    /// The output is given in the form of a tuple of `(mask, shard)`, where `mask` is a binary array and `shard` is a table, i.e., named tuple.
2471    ///
2472    /// **WARNING**: this function cannot be compiled to an MPC protocol.
2473    ///
2474    /// # Arguments
2475    ///
2476    /// - `input_table` - named tuple of arrays containing data for sharding
2477    /// - `shard_config` - parameters of sharding: number of shards, shard size and names of columns that are hashed in sharding
2478    ///
2479    /// # Returns
2480    ///
2481    /// New Shard node containing a tuple of shards
2482    #[doc(hidden)]
2483    pub fn shard(&self, input_table: Node, shard_config: ShardConfig) -> Result<Node> {
2484        self.add_node(vec![input_table], vec![], Operation::Shard(shard_config))
2485    }
2486
2487    /// Adds a node that converts a switching map array into a random tuple of the following components:
2488    /// - a permutation map array with deletion (some indices of this map are uniformly random, see below),
2489    /// - a tuple of duplication map array and duplication bits,
2490    /// - a permutation map array without deletion.
2491    ///
2492    /// The composition of these maps is equal to the input switching map, which is an array containing non-unique indices of some array.
2493    ///
2494    /// To create a permutation with deletion, this operation first groups identical indices of the input map together and shifts other indices accordingly, e.g.
2495    ///
2496    /// [1, 4, 5, 7, 2, 4] -> [1, 4, 4, 5, 7, 2].
2497    ///
2498    /// This can be done by permutation p = [1, 2, 6, 3, 4, 5].
2499    /// Then, it replaces copies with unique random indices not present in the switching map, e.g.
2500    ///
2501    /// [1, 4, 4, 5, 7, 2] -> [1, 4, 3, 5, 7, 2].
2502    ///
2503    ///
2504    /// A duplication map is a tuple of two one-dimensional arrays of length `n`.
2505    /// The first array contains indices from `[0,n]` in the increasing order with possible repetitions.
2506    /// The second array contains only zeros and ones.
2507    /// If its i-th element is zero, it means that the duplication map doesn't change the i-th element of an array it acts upon.
2508    /// If map's i-th element is one, then the map copies the previous element of the result.
2509    /// This rules can be summarized by the following equation
2510    ///
2511    /// duplication_indices[i] = duplication_bits[i] * duplication_indices[i-1] + (1 - duplication_bits[i]) * i.
2512    ///
2513    /// A duplication map is created from the above switching map with grouped indices, replacing the first index occurrence with 0 and other copies with 1, e.g.
2514    ///
2515    ///  [1, 4, 4, 5, 7, 2] -> ([0, 1, 1, 3, 4, 5], [0, 0, 1, 0, 0, 0]).
2516    ///
2517    /// The last permutation is the inverse of the above permutation p, i.e.
2518    ///
2519    /// [1, 2, 4, 5, 6, 3].
2520    ///
2521    /// This operation supports vectorization.
2522    ///
2523    /// **WARNING**: this function should not be used before MPC compilation.
2524    ///
2525    /// # Arguments
2526    ///
2527    /// - `switching_map` - an array of one-dimensional arrays containing non-unique indices of some array of length `n` (usually a simple hash table),
2528    /// - `n` - length of an array that can be mapped by the above switching map.
2529    ///
2530    /// # Returns
2531    ///
2532    /// New DecomposeSwitchingMap node
2533    #[doc(hidden)]
2534    pub fn decompose_switching_map(&self, switching_map: Node, n: u64) -> Result<Node> {
2535        self.add_node(
2536            vec![switching_map],
2537            vec![],
2538            Operation::DecomposeSwitchingMap(n),
2539        )
2540    }
2541
2542    /// Adds a node that converts a Cuckoo hash table to a random permutation.
2543    ///
2544    /// Conversion is done via replacing dummy hash elements by random indices such that the resulting array constitute a permutation.
2545    ///
2546    /// **WARNING**: this function should not be used before MPC compilation.
2547    ///
2548    /// # Arguments
2549    ///
2550    /// `cuckoo_map` - an array containing a Cuckoo hash map with dummy values
2551    ///
2552    /// # Returns
2553    ///
2554    /// New CuckooToPermutation node
2555    #[doc(hidden)]
2556    pub fn cuckoo_to_permutation(&self, cuckoo_map: Node) -> Result<Node> {
2557        self.add_node(vec![cuckoo_map], vec![], Operation::CuckooToPermutation)
2558    }
2559
2560    /// Adds a node that joins a sequence of arrays governed by a given shape.
2561    ///
2562    /// The input arrays should have the same shape or be able to be broadcast to the same shape.
2563    ///
2564    /// For example, stacking 2 arrays of shapes `[2,2]` and `[2,1]` with the outer shape `[2]` works as follows
2565    ///
2566    /// `stack(arrays=[[[1,2],[3,4]], [5,6]], shape=[2]) = [[[1,2],[3,4]], [[5,5], [6,6]]]`
2567    ///
2568    /// # Arguments
2569    ///
2570    /// * `nodes` - vector of nodes containing arrays
2571    /// * `outer_shape` - shape defining how the input arrays are arranged in the resulting array
2572    ///
2573    /// # Returns
2574    ///
2575    /// New stack node
2576    ///
2577    /// # Example
2578    ///
2579    /// ```
2580    /// # use ciphercore_base::graphs::create_context;
2581    /// # use ciphercore_base::data_types::{INT32, array_type};
2582    /// let c = create_context().unwrap();
2583    /// let g = c.create_graph().unwrap();
2584    /// let t1 = array_type(vec![3, 2, 3], INT32);
2585    /// let t2 = array_type(vec![2, 3], INT32);
2586    /// let shape = vec![2];
2587    /// let n1 = g.input(t1).unwrap();
2588    /// let n2 = g.input(t2).unwrap();
2589    /// let n3 = g.stack(vec![n1,n2], shape).unwrap();
2590    /// ```
2591    pub fn stack(&self, nodes: Vec<Node>, outer_shape: ArrayShape) -> Result<Node> {
2592        self.add_node(nodes, vec![], Operation::Stack(outer_shape))
2593    }
2594
2595    /// Adds a node that joins a sequence of arrays along a given axis.
2596    /// This operation is similar to [the NumPy concatenate](https://numpy.org/doc/stable/reference/generated/numpy.concatenate.html).
2597    ///
2598    /// The input arrays should have the same shape except in the given axis.
2599    ///
2600    /// # Arguments
2601    ///
2602    /// * `nodes` - vector of nodes containing arrays
2603    /// * `axis` - axis along which the above arrays are joined
2604    ///
2605    /// # Returns
2606    ///
2607    /// New Concatenate node
2608    ///
2609    /// # Example
2610    ///
2611    /// ```
2612    /// # use ciphercore_base::graphs::create_context;
2613    /// # use ciphercore_base::data_types::{INT32, array_type};
2614    /// let c = create_context().unwrap();
2615    /// let g = c.create_graph().unwrap();
2616    /// let t1 = array_type(vec![3, 2, 3], INT32);
2617    /// let t2 = array_type(vec![3, 2, 10], INT32);
2618    /// let shape = vec![2];
2619    /// let n1 = g.input(t1).unwrap();
2620    /// let n2 = g.input(t2).unwrap();
2621    /// let n3 = g.concatenate(vec![n1,n2], 2).unwrap();
2622    /// ```
2623    pub fn concatenate(&self, nodes: Vec<Node>, axis: u64) -> Result<Node> {
2624        self.add_node(nodes, vec![], Operation::Concatenate(axis))
2625    }
2626
2627    /// Adds a node creating a constant of a given type and value.
2628    ///
2629    /// # Arguments
2630    ///
2631    /// * `output_type` - type of a constant
2632    /// * `value` - value of a constant
2633    ///
2634    /// # Returns
2635    ///
2636    /// New constant node
2637    ///
2638    /// # Example
2639    ///
2640    /// ```
2641    /// # use ciphercore_base::graphs::create_context;
2642    /// # use ciphercore_base::data_types::{BIT, scalar_type};
2643    /// # use ciphercore_base::data_values::Value;
2644    /// let c = create_context().unwrap();
2645    /// let g = c.create_graph().unwrap();
2646    /// let t = scalar_type(BIT);
2647    /// let v = Value::from_scalar(0, BIT).unwrap();
2648    /// let n = g.constant(t, v).unwrap();
2649    /// ```
2650    pub fn constant(&self, output_type: Type, value: Value) -> Result<Node> {
2651        self.add_node(vec![], vec![], Operation::Constant(output_type, value))
2652    }
2653
2654    /// Adds a node converting an integer array or scalar to the binary form.
2655    ///
2656    /// Given an array of shape `[a,b,c]` and scalar type `st`, this node returns an array of shape `[a,b,c,s]` where `s` is the bit size of `st`. For example, an array of shape `[1,2,3]` with `INT32` entries will be converted to a binary array of shape `[1,2,3,32]`.
2657    ///
2658    /// # Arguments
2659    ///
2660    /// `a` - node containing an array or scalar
2661    ///
2662    /// # Returns
2663    ///
2664    /// New node converting an array/scalar to the binary form
2665    ///
2666    /// # Example
2667    ///
2668    /// ```
2669    /// # use ciphercore_base::graphs::create_context;
2670    /// # use ciphercore_base::data_types::{array_type, INT32};
2671    /// let c = create_context().unwrap();
2672    /// let g = c.create_graph().unwrap();
2673    /// let t = array_type(vec![3, 2], INT32);
2674    /// let n1 = g.input(t).unwrap();
2675    /// let n2 = g.a2b(n1).unwrap();
2676    /// ```
2677    pub fn a2b(&self, a: Node) -> Result<Node> {
2678        self.add_node(vec![a], vec![], Operation::A2B)
2679    }
2680
2681    /// Adds a node converting a binary array to an array of a given scalar type.
2682    ///
2683    /// Given a binary array of shape `[a,b,c]` and a scalar type `st` of bit size `c`, this node returns an array of shape `[a,b]` with `st` entries. For example, a binary array of shape `[2,3,32]` can be converted to an array of shape `[2,3]` with `INT32` entries.
2684    ///
2685    /// # Arguments
2686    ///
2687    /// * `a` - node containing an array or scalar
2688    /// * `scalar_type` - scalar type
2689    ///
2690    /// # Returns
2691    ///
2692    /// New node converting an array from the binary form
2693    ///
2694    /// # Example
2695    ///
2696    /// ```
2697    /// # use ciphercore_base::graphs::create_context;
2698    /// # use ciphercore_base::data_types::{BIT, INT32, array_type};
2699    /// let c = create_context().unwrap();
2700    /// let g = c.create_graph().unwrap();
2701    /// let t = array_type(vec![3, 32], BIT);
2702    /// let n1 = g.input(t).unwrap();
2703    /// let n2 = g.b2a(n1, INT32).unwrap();
2704    /// ```
2705    pub fn b2a(&self, a: Node, scalar_type: ScalarType) -> Result<Node> {
2706        self.add_node(vec![a], vec![], Operation::B2A(scalar_type))
2707    }
2708
2709    /// Adds a node that creates a tuple from several (possibly, zero) elements.
2710    ///
2711    /// # Arguments
2712    ///
2713    /// `elements` - vector of nodes
2714    ///
2715    /// # Returns
2716    ///
2717    /// New node with a tuple
2718    ///
2719    /// # Example
2720    ///
2721    /// ```
2722    /// # use ciphercore_base::graphs::create_context;
2723    /// # use ciphercore_base::data_types::{INT32, array_type};
2724    /// let c = create_context().unwrap();
2725    /// let g = c.create_graph().unwrap();
2726    /// let t1 = array_type(vec![3, 2, 3], INT32);
2727    /// let t2 = array_type(vec![2, 3], INT32);
2728    /// let n1 = g.input(t1).unwrap();
2729    /// let n2 = g.input(t2).unwrap();
2730    /// let n3 = g.create_tuple(vec![n1,n2]).unwrap();
2731    /// ```
2732    pub fn create_tuple(&self, elements: Vec<Node>) -> Result<Node> {
2733        self.add_node(elements, vec![], Operation::CreateTuple)
2734    }
2735
2736    /// Adds a node that creates a vector from several (possibly, zero) elements of the same type.
2737    ///
2738    /// # Arguments
2739    ///
2740    /// `elements` - vector of nodes
2741    ///
2742    /// # Returns
2743    ///
2744    /// New node with a created vector
2745    ///
2746    /// # Example
2747    ///
2748    /// ```
2749    /// # use ciphercore_base::graphs::create_context;
2750    /// # use ciphercore_base::data_types::{INT32, array_type};
2751    /// let c = create_context().unwrap();
2752    /// let g = c.create_graph().unwrap();
2753    /// let t = array_type(vec![3, 2, 3], INT32);
2754    /// let n1 = g.input(t.clone()).unwrap();
2755    /// let n2 = g.input(t.clone()).unwrap();
2756    /// let n3 = g.create_vector(t, vec![n1,n2]).unwrap();
2757    /// ```
2758    pub fn create_vector(&self, element_type: Type, elements: Vec<Node>) -> Result<Node> {
2759        self.add_node(elements, vec![], Operation::CreateVector(element_type))
2760    }
2761
2762    /// Adds a node that creates a named tuple from several (possibly, zero) elements.
2763    ///
2764    /// # Arguments
2765    ///
2766    /// `elements` - vector of pairs (node name, node)
2767    ///
2768    /// # Returns
2769    ///
2770    /// New node creating a named tuple
2771    ///
2772    /// # Example
2773    ///
2774    /// ```
2775    /// # use ciphercore_base::graphs::create_context;
2776    /// # use ciphercore_base::data_types::{INT32, array_type};
2777    /// let c = create_context().unwrap();
2778    /// let g = c.create_graph().unwrap();
2779    /// let t1 = array_type(vec![3, 2, 3], INT32);
2780    /// let t2 = array_type(vec![2, 3], INT32);
2781    /// let n1 = g.input(t1).unwrap();
2782    /// let n2 = g.input(t2).unwrap();
2783    /// let n3 = g.create_named_tuple(vec![("node1".to_owned(), n1), ("node2".to_owned(), n2)]).unwrap();
2784    /// ```
2785    pub fn create_named_tuple(&self, elements: Vec<(String, Node)>) -> Result<Node> {
2786        let mut nodes = vec![];
2787        let mut names = vec![];
2788        for (name, node) in elements {
2789            nodes.push(node);
2790            names.push(name);
2791        }
2792        self.add_node(nodes, vec![], Operation::CreateNamedTuple(names))
2793    }
2794
2795    /// Adds a node that extracts an element of a tuple.
2796    ///
2797    /// # Arguments
2798    ///
2799    /// * `tuple` - node containing a tuple
2800    /// * `index` - index of a tuple element between 0 and tuple length minus 1
2801    ///
2802    /// # Returns
2803    ///
2804    /// New node with an extracted element
2805    ///
2806    /// # Example
2807    ///
2808    /// ```
2809    /// # use ciphercore_base::data_types::{INT32, array_type};
2810    /// # use ciphercore_base::graphs::create_context;
2811    /// let c = create_context().unwrap();
2812    /// let g = c.create_graph().unwrap();
2813    /// let t1 = array_type(vec![3, 2, 3], INT32);
2814    /// let t2 = array_type(vec![2, 3], INT32);
2815    /// let n1 = g.input(t1).unwrap();
2816    /// let n2 = g.input(t2).unwrap();
2817    /// let n3 = g.create_tuple(vec![n1, n2]).unwrap();
2818    /// let n4 = g.tuple_get(n3, 1).unwrap();
2819    /// ```
2820    pub fn tuple_get(&self, tuple: Node, index: u64) -> Result<Node> {
2821        self.add_node(vec![tuple], vec![], Operation::TupleGet(index))
2822    }
2823
2824    /// Adds a node that extracts an element of a named tuple.
2825    ///
2826    /// # Arguments
2827    ///
2828    /// * `tuple` - node containing a named tuple
2829    /// * `key` - key of a tuple element
2830    ///
2831    /// # Returns
2832    ///
2833    /// New node extracting a tuple element
2834    ///
2835    /// # Example
2836    ///
2837    /// ```
2838    /// # use ciphercore_base::graphs::create_context;
2839    /// # use ciphercore_base::data_types::{array_type, INT32};
2840    /// let c = create_context().unwrap();
2841    /// let g = c.create_graph().unwrap();
2842    /// let t1 = array_type(vec![3, 2, 3], INT32);
2843    /// let t2 = array_type(vec![2, 3], INT32);
2844    /// let n1 = g.input(t1).unwrap();
2845    /// let n2 = g.input(t2).unwrap();
2846    /// let n3 = g.create_named_tuple(vec![("node1".to_owned(), n1), ("node2".to_owned(), n2)]).unwrap();
2847    /// let n4 = g.named_tuple_get(n3, "node2".to_owned()).unwrap();
2848    /// ```
2849    pub fn named_tuple_get(&self, tuple: Node, key: String) -> Result<Node> {
2850        self.add_node(vec![tuple], vec![], Operation::NamedTupleGet(key))
2851    }
2852
2853    /// Adds a node that extracts an element of a vector.
2854    ///
2855    /// # Arguments
2856    ///
2857    /// * `vec` - node containing a vector
2858    /// * `index` - node containing the index of a tuple element
2859    ///
2860    /// # Returns
2861    ///
2862    /// New node extracting a vector element
2863    ///
2864    /// # Example
2865    ///
2866    /// ```
2867    /// # use ciphercore_base::graphs::create_context;
2868    /// # use ciphercore_base::data_types::{UINT32, INT32, array_type, scalar_type};
2869    /// # use ciphercore_base::data_values::Value;
2870    /// let c = create_context().unwrap();
2871    /// let g = c.create_graph().unwrap();
2872    /// let t = array_type(vec![3, 2, 3], INT32);
2873    /// let n1 = g.input(t.clone()).unwrap();
2874    /// let n2 = g.input(t.clone()).unwrap();
2875    /// let n3 = g.create_vector(t, vec![n1,n2]).unwrap();
2876    /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
2877    /// let n4 = g.vector_get(n3, index).unwrap();
2878    /// ```
2879    pub fn vector_get(&self, vec: Node, index: Node) -> Result<Node> {
2880        self.add_node(vec![vec, index], vec![], Operation::VectorGet)
2881    }
2882
2883    /// Adds a node that takes vectors V<sub>1</sub>(n, t<sub>1</sub>), V<sub>2</sub>(n, t<sub>2</sub>), ..., V<sub>k</sub>(n, t<sub>k</sub>) of the same length and returns a vector V(n, tuple(t<sub>1</sub>, ..., t<sub>k</sub>)) (similar to [zip](https://doc.rust-lang.org/stable/std/iter/fn.zip.html)).
2884    ///
2885    /// # Arguments
2886    ///
2887    /// `nodes` - vector of nodes containing input vectors
2888    ///
2889    /// # Returns
2890    ///
2891    /// New zip node
2892    ///
2893    /// # Example
2894    ///
2895    /// ```
2896    /// # use ciphercore_base::data_types::{INT32, array_type, vector_type};
2897    /// # use ciphercore_base::graphs::create_context;
2898    /// let c = create_context().unwrap();
2899    /// let g = c.create_graph().unwrap();
2900    /// let t = array_type(vec![3, 2, 3], INT32);
2901    /// let vec_t = vector_type(3, t);
2902    /// let n1 = g.input(vec_t.clone()).unwrap();
2903    /// let n2 = g.input(vec_t.clone()).unwrap();
2904    /// let n3 = g.zip(vec![n1,n2]).unwrap();
2905    /// ```
2906    pub fn zip(&self, nodes: Vec<Node>) -> Result<Node> {
2907        self.add_node(nodes, vec![], Operation::Zip)
2908    }
2909
2910    /// Adds a node that creates a vector with `n` copies of a value of a given node.
2911    ///
2912    /// # Arguments
2913    ///
2914    /// * `a` - node containing a value
2915    /// * `n` - number of copies
2916    ///
2917    /// # Returns
2918    ///
2919    /// New repeat node
2920    ///
2921    /// # Example
2922    ///
2923    /// ```
2924    /// # use ciphercore_base::data_types::{INT32, array_type};
2925    /// # use ciphercore_base::graphs::create_context;
2926    /// let c = create_context().unwrap();
2927    /// let g = c.create_graph().unwrap();
2928    /// let t = array_type(vec![3, 2, 3], INT32);
2929    /// let n1 = g.input(t).unwrap();
2930    /// let n2 = g.repeat(n1, 10).unwrap();
2931    /// ```
2932    pub fn repeat(&self, a: Node, n: u64) -> Result<Node> {
2933        self.add_node(vec![a], vec![], Operation::Repeat(n))
2934    }
2935
2936    /// Adds a node that calls another graph with inputs contained in given nodes.
2937    ///
2938    /// The input graph must be finalized and have as many inputs as the number of provided arguments.
2939    ///
2940    /// For example, let `G` be a graph implementing the function `max(x,0)`, then `call(G, [17]) = max(17, 0)`.
2941    ///
2942    /// # Arguments
2943    ///
2944    /// * `graph` - graph with `n` input nodes
2945    /// * `arguments` - vector of `n` nodes
2946    ///
2947    /// # Returns
2948    ///
2949    /// New call node
2950    ///
2951    /// # Example
2952    ///
2953    /// ```
2954    /// # use ciphercore_base::data_types::{INT32, array_type};
2955    /// # use ciphercore_base::graphs::create_context;
2956    /// let c = create_context().unwrap();
2957    ///
2958    /// let g1 = c.create_graph().unwrap();
2959    /// let t = array_type(vec![3, 2, 3], INT32);
2960    /// let n1 = g1.input(t.clone()).unwrap();
2961    /// let n2 = g1.repeat(n1, 10).unwrap();
2962    /// let n3 = g1.vector_to_array(n2).unwrap();
2963    /// n3.set_as_output().unwrap();
2964    /// g1.finalize().unwrap();
2965    ///
2966    /// let g2 = c.create_graph().unwrap();
2967    /// let n4 = g2.input(t).unwrap();
2968    /// let n5 = g2.add(n4.clone(), n4).unwrap();
2969    /// let n6 = g2.call(g1, vec![n5]).unwrap();
2970    /// ```
2971    pub fn call(&self, graph: Graph, arguments: Vec<Node>) -> Result<Node> {
2972        self.add_node(arguments, vec![graph], Operation::Call)
2973    }
2974
2975    /// Adds a node that iteratively computes a given finalized graph on the elements of a given vector and updates the state value accordingly.
2976    ///
2977    /// This node calls another `graph` with 2 input nodes `old_state` and `input` and an output node that returns a [tuple](Type::Tuple) `(new_state, output)`. This graph is used to map the elements of a given vector `V` to another vector `W` as follows:
2978    /// ```text
2979    /// graph(state_0, V[0]) -> (state1, W[0]),
2980    /// graph(state_1, V[1]) -> (state2, W[1]),
2981    /// ...
2982    /// graph(state_k, V[k]) -> (final_state, W[k]).
2983    /// ```
2984    /// The output is a [tuple](Type::Tuple) `(final_state, W)`. The initial state `state_0` should be provided as an argument.
2985    ///
2986    /// This node generalize `map` and `reduce` procedures (see [MapReduce](https://en.wikipedia.org/wiki/MapReduce) for more details).
2987    ///
2988    /// For example, let `G` be a graph implementing the function `max(x,0)` and incrementing `state` if its output is negative, then `iterate(G, 0, [-1,2,0,3,2]) = (1, [0,2,0,3,2])`. The final state is equal to the number of negative values in the input vector.
2989    ///
2990    /// # Arguments
2991    ///
2992    /// * `graph` - graph with 2 input nodes of types T<sub>s</sub> and T<sub>i</sub> and returning a tuple of type (T<sub>s</sub>, T<sub>o</sub>)
2993    /// * `state` - node containing an initial state of type T<sub>s</sub>
2994    /// * `input` - node containing a vector with elements of type T<sub>i</sub>
2995    ///
2996    /// # Returns
2997    ///
2998    /// New iterate node
2999    ///
3000    /// # Example
3001    ///
3002    /// ```
3003    /// # use ciphercore_base::data_types::{INT32, BIT, scalar_type, vector_type};
3004    /// # use ciphercore_base::graphs::create_context;
3005    /// # use ciphercore_base::ops::utils::constant_scalar;
3006    /// let c = create_context().unwrap();
3007    ///
3008    /// let t_s = scalar_type(BIT);
3009    /// let t = scalar_type(INT32);
3010    /// let vec_t = vector_type(10, t.clone());
3011    ///
3012    /// // Graph that outputs 0 at even indices or input value at odd indices.
3013    /// let g1 = c.create_graph().unwrap();
3014    /// {
3015    ///     let old_state = g1.input(t_s.clone()).unwrap();
3016    ///     let input = g1.input(t.clone()).unwrap();
3017    ///     let result = g1.mixed_multiply(input, old_state.clone()).unwrap();
3018    ///     let new_state = g1.add(old_state, constant_scalar(&g1, 1, BIT).unwrap()).unwrap();
3019    ///     let out_tuple = g1.create_tuple(vec![new_state, result]).unwrap();
3020    ///     out_tuple.set_as_output().unwrap();
3021    ///     g1.finalize().unwrap();
3022    /// }
3023    ///
3024    /// let g2 = c.create_graph().unwrap();
3025    /// let initial_state = constant_scalar(&g2, 0, BIT).unwrap();
3026    /// let input_vector = g2.input(vec_t).unwrap();
3027    /// g2.iterate(g1, initial_state, input_vector).unwrap();
3028    /// ```
3029    pub fn iterate(&self, graph: Graph, state: Node, input: Node) -> Result<Node> {
3030        self.add_node(vec![state, input], vec![graph], Operation::Iterate)
3031    }
3032
3033    /// Adds a node converting an array to a vector.
3034    ///
3035    /// Given an array of shape `[a,b,c]`, this node returns a vector of `a` arrays of shape `[b,c]`.
3036    ///
3037    /// # Arguments
3038    ///
3039    /// `a` - node containing an array
3040    ///
3041    /// # Returns
3042    ///
3043    /// New node converting an array to a vector
3044    ///
3045    /// # Example
3046    ///
3047    /// ```
3048    /// # use ciphercore_base::graphs::create_context;
3049    /// # use ciphercore_base::data_types::{array_type, scalar_type, INT32, UINT32};
3050    /// # use ciphercore_base::data_values::Value;
3051    /// let c = create_context().unwrap();
3052    /// let g = c.create_graph().unwrap();
3053    /// let t = array_type(vec![4, 3, 2], INT32);
3054    /// let n1 = g.input(t).unwrap();
3055    /// let n2 = g.array_to_vector(n1).unwrap();
3056    /// let index = g.constant(scalar_type(UINT32), Value::from_scalar(0, UINT32).unwrap()).unwrap();
3057    /// let n3 = g.vector_get(n2.clone(), index).unwrap();
3058    ///
3059    /// assert!(n2.get_type().unwrap().is_vector());
3060    /// assert_eq!(n3.get_type().unwrap().get_shape(), vec![3,2]);
3061    /// ```    
3062    pub fn array_to_vector(&self, a: Node) -> Result<Node> {
3063        self.add_node(vec![a], vec![], Operation::ArrayToVector)
3064    }
3065
3066    /// Adds a node converting a vector to an array.
3067    ///
3068    /// Given a vector of `a` arrays of shape `[b,c]`, this node returns an array of shape `[a,b,c]`.
3069    ///
3070    /// # Arguments
3071    ///
3072    /// `a` - node containing a vector
3073    ///
3074    /// # Returns
3075    ///
3076    /// New node converting a vector to an array
3077    ///
3078    /// # Example
3079    ///
3080    /// ```
3081    /// # use ciphercore_base::graphs::create_context;
3082    /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3083    /// let c = create_context().unwrap();
3084    /// let g = c.create_graph().unwrap();
3085    /// let t = array_type(vec![3, 2], INT32);
3086    /// let vec_t = vector_type(4, t);
3087    /// let n1 = g.input(vec_t).unwrap();
3088    /// let n2 = g.vector_to_array(n1).unwrap();
3089    ///
3090    /// assert!(n2.get_type().unwrap().is_array());
3091    /// assert_eq!(n2.get_type().unwrap().get_shape(), vec![4, 3, 2]);
3092    /// ```
3093    pub fn vector_to_array(&self, a: Node) -> Result<Node> {
3094        self.add_node(vec![a], vec![], Operation::VectorToArray)
3095    }
3096
3097    /// Adds a node creating an array from the elements of an input array indexed by another array along a given axis.
3098    ///
3099    /// Given an input array, this node replaces the dimension `axis` with the dimensions introduced by the indexing array.
3100    ///
3101    /// Indices must be unique to prevent possible duplication of shares/ciphertexts.
3102    /// Such duplicates might cause devastating data leakage.
3103    ///
3104    /// This operation is similar to [the NumPy take operation](https://numpy.org/doc/stable/reference/generated/numpy.take.html).
3105    ///
3106    /// **WARNING**: this function should not be used before MPC compilation.
3107    ///
3108    /// # Arguments
3109    ///
3110    /// `input` - node containing an input array
3111    /// `indices` - node containing indices
3112    /// `axis` - index of the axis along which indices are chosen
3113    ///
3114    /// # Returns
3115    ///
3116    /// New Gather node
3117    #[doc(hidden)]
3118    pub fn gather(&self, input: Node, indices: Node, axis: u64) -> Result<Node> {
3119        self.add_node(vec![input, indices], vec![], Operation::Gather(axis))
3120    }
3121
3122    /// Checks that the graph has an output node and finalizes the graph.
3123    ///
3124    /// After finalization the graph can't be changed.
3125    ///
3126    /// # Returns
3127    ///
3128    /// Finalized graph
3129    ///
3130    /// # Example
3131    ///
3132    /// ```
3133    /// # use ciphercore_base::graphs::create_context;
3134    /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3135    /// let c = create_context().unwrap();
3136    /// let g = c.create_graph().unwrap();
3137    /// let t = array_type(vec![3, 2], INT32);
3138    /// let vec_t = vector_type(4, t);
3139    /// let n1 = g.input(vec_t).unwrap();
3140    /// let n2 = g.vector_to_array(n1).unwrap();
3141    /// n2.set_as_output().unwrap();
3142    /// g.finalize().unwrap();
3143    /// ```
3144    pub fn finalize(&self) -> Result<Graph> {
3145        let output_node = self.body.borrow_mut().output_node.clone();
3146        match output_node {
3147            Some(_) => {
3148                self.body.borrow_mut().finalized = true;
3149                Ok(self.clone())
3150            }
3151            None => Err(runtime_error!("Output node is not set")),
3152        }
3153    }
3154
3155    /// Returns the vector of nodes contained in the graph in order of construction.
3156    ///
3157    /// # Returns
3158    ///
3159    /// Vector of nodes of the graph
3160    pub fn get_nodes(&self) -> Vec<Node> {
3161        self.body.borrow().nodes.clone()
3162    }
3163
3164    /// Promotes a given node to the output node of the parent graph.
3165    ///
3166    /// # Arguments
3167    ///
3168    /// `output_node` - node to be set as output
3169    ///
3170    /// # Example
3171    ///
3172    /// ```
3173    /// # use ciphercore_base::graphs::create_context;
3174    /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3175    /// let c = create_context().unwrap();
3176    /// let g = c.create_graph().unwrap();
3177    /// let t = array_type(vec![3, 2], INT32);
3178    /// let vec_t = vector_type(4, t);
3179    /// let n1 = g.input(vec_t).unwrap();
3180    /// let n2 = g.vector_to_array(n1).unwrap();
3181    /// g.set_output_node(n2).unwrap();
3182    /// g.finalize().unwrap();
3183    /// ```
3184    pub fn set_output_node(&self, output_node: Node) -> Result<()> {
3185        let current_output_node = self.body.borrow().output_node.clone();
3186        match current_output_node {
3187            Some(_) => Err(runtime_error!("Output node is already set")),
3188            None => {
3189                if output_node.get_graph() != *self {
3190                    Err(runtime_error!("Output node has to be from the same graph"))
3191                } else {
3192                    self.body.borrow_mut().output_node = Some(output_node.downgrade());
3193                    Ok(())
3194                }
3195            }
3196        }
3197    }
3198
3199    /// Returns the output node of the graph.
3200    ///
3201    /// # Returns
3202    ///
3203    /// Output node of the graph
3204    pub fn get_output_node(&self) -> Result<Node> {
3205        let current_output_node = self.body.borrow().output_node.clone();
3206        match current_output_node {
3207            Some(output_node) => Ok(output_node.upgrade()),
3208            None => Err(runtime_error!("Output node is not set")),
3209        }
3210    }
3211
3212    /// Returns the ID of the graph.
3213    ///
3214    /// A graph ID is a serial number of a graph between `0` and `n-1` where `n` is the number of graphs in the parent context.
3215    ///
3216    /// # Returns
3217    ///
3218    /// Graph ID
3219    pub fn get_id(&self) -> u64 {
3220        self.body.borrow().id
3221    }
3222
3223    /// Returns the number of the graph nodes.
3224    ///
3225    /// # Returns
3226    ///
3227    /// Number of the graph nodes
3228    pub fn get_num_nodes(&self) -> u64 {
3229        self.body.borrow().nodes.len() as u64
3230    }
3231
3232    /// Returns the node corresponding to a given ID.
3233    ///
3234    /// # Arguments
3235    ///
3236    /// `id` - node ID
3237    ///
3238    /// # Returns
3239    ///
3240    /// Node with a given ID
3241    pub fn get_node_by_id(&self, id: u64) -> Result<Node> {
3242        let nodes = &self.body.borrow().nodes;
3243        if id >= nodes.len() as u64 {
3244            Err(runtime_error!("Invalid id for the node retrieval"))
3245        } else {
3246            Ok(nodes[id as usize].clone())
3247        }
3248    }
3249
3250    /// Returns the context of the graph nodes.
3251    ///
3252    /// # Returns
3253    ///
3254    /// Context of the graph
3255    pub fn get_context(&self) -> Context {
3256        self.body.borrow().context.upgrade()
3257    }
3258
3259    /// Adds a node computing a given custom operation.
3260    ///
3261    /// Custom operations can be created by the user as public structs implementing the [CustomOperationBody](../custom_ops/trait.CustomOperationBody.html).
3262    ///
3263    /// # Arguments
3264    ///
3265    /// * `op` - custom operation
3266    /// * `arguments` - vector of nodes used as input for the custom operation
3267    ///
3268    /// # Returns
3269    ///
3270    /// New custom operation node
3271    ///
3272    /// # Example
3273    ///
3274    /// ```
3275    /// # use ciphercore_base::graphs::create_context;
3276    /// # use ciphercore_base::data_types::{array_type, BIT};
3277    /// # use ciphercore_base::custom_ops::{CustomOperation, Not};
3278    /// let c = create_context().unwrap();
3279    /// let g = c.create_graph().unwrap();
3280    /// let t = array_type(vec![3, 2], BIT);
3281    /// let n1 = g.input(t).unwrap();
3282    /// let n2 = g.custom_op(CustomOperation::new(Not {}), vec![n1]).unwrap();
3283    /// ```
3284    pub fn custom_op(&self, op: CustomOperation, arguments: Vec<Node>) -> Result<Node> {
3285        self.add_node(arguments, vec![], Operation::Custom(op))
3286    }
3287
3288    /// Adds a node which logs its input at runtime, and returns the input.
3289    /// This is intended to be used for debugging.
3290    ///
3291    /// # Arguments
3292    ///
3293    /// * `message` - Informational message to be printed
3294    /// * `input` - Node to be printed
3295    ///
3296    /// # Returns
3297    ///
3298    /// The value of the node.
3299    ///
3300    /// # Example
3301    ///
3302    /// ```
3303    /// # use ciphercore_base::graphs::create_context;
3304    /// # use ciphercore_base::data_types::{array_type, BIT};
3305    /// # use ciphercore_base::custom_ops::{CustomOperation, Not};
3306    /// let c = create_context().unwrap();
3307    /// let g = c.create_graph().unwrap();
3308    /// let t = array_type(vec![3, 2], BIT);
3309    /// let n1 = g.input(t).unwrap();
3310    /// let n2 = g.print("n1:".into(), n1).unwrap();
3311    /// ```
3312    pub fn print(&self, message: String, input: Node) -> Result<Node> {
3313        self.add_node(vec![input], vec![], Operation::Print(message))
3314    }
3315
3316    /// Adds a node which fails the execution at runtime if `condition` is false, and returns the `input` otherwise.
3317    /// This is intended to be used for debugging.
3318    ///
3319    /// # Arguments
3320    ///
3321    /// * `message` - message to be returned for the failed assertion.
3322    /// * `condition` - BIT to be checked in the assertion.
3323    /// * `input` - Node to be returned for pass-through.
3324    ///
3325    /// # Returns
3326    ///
3327    /// The value of the node.
3328    ///
3329    /// # Example
3330    ///
3331    /// ```
3332    /// # use ciphercore_base::graphs::create_context;
3333    /// # use ciphercore_base::data_types::{array_type, scalar_type, BIT};
3334    /// # use ciphercore_base::custom_ops::{CustomOperation, Not};
3335    /// let c = create_context().unwrap();
3336    /// let g = c.create_graph().unwrap();
3337    /// let cond = g.input(scalar_type(BIT)).unwrap();
3338    /// let t = array_type(vec![3, 2], BIT);
3339    /// let n1 = g.input(t).unwrap();
3340    /// let n2 = g.assert("Condition".into(), cond, n1).unwrap();
3341    /// ```
3342    pub fn assert(&self, message: String, condition: Node, input: Node) -> Result<Node> {
3343        self.add_node(vec![condition, input], vec![], Operation::Assert(message))
3344    }
3345}
3346
3347/// Methods which aren't supposed to be imported in Python.
3348impl Graph {
3349    /// Adds an operation node to the graph and returns it.
3350    ///
3351    /// # Arguments
3352    ///
3353    /// * `node_dependencies` - vector of nodes necessary to perform the given operation
3354    /// * `graph_dependencies` - vector of graphs necessary to perform the given operation
3355    /// * `operation` - operation performed by the node
3356    ///
3357    /// # Returns
3358    ///
3359    /// New operation node that gets added
3360    pub fn add_node(
3361        &self,
3362        node_dependencies: Vec<Node>,
3363        graph_dependencies: Vec<Graph>,
3364        operation: Operation,
3365    ) -> Result<Node> {
3366        if self.is_finalized() {
3367            return Err(runtime_error!("Can't add a node to a finalized graph"));
3368        }
3369        for dependency in &node_dependencies {
3370            if dependency.get_graph() != *self
3371                || dependency.get_id() >= self.body.borrow().nodes.len() as u64
3372                || self.body.borrow().nodes[dependency.get_id() as usize] != *dependency
3373            {
3374                return Err(runtime_error!(
3375                    "Can't add a node with invalid node dependencies"
3376                ));
3377            }
3378        }
3379        for dependency in &graph_dependencies {
3380            if !dependency.is_finalized() {
3381                return Err(runtime_error!(
3382                    "Can't add a node with not finilized graph dependency"
3383                ));
3384            }
3385            if dependency.get_id() >= self.get_id() {
3386                return Err(runtime_error!(
3387                    "Can't add a node with graph dependency with bigger id. {} >= {}",
3388                    dependency.get_id(),
3389                    self.get_id()
3390                ));
3391            }
3392            if dependency.get_context() != self.get_context() {
3393                return Err(runtime_error!(
3394                    "Can't add a node with graph dependency from different context"
3395                ));
3396            }
3397        }
3398        let id = self.body.borrow().nodes.len() as u64;
3399        let result = Node {
3400            body: Arc::new(AtomicRefCell::new(NodeBody {
3401                graph: self.downgrade(),
3402                node_dependencies: node_dependencies.iter().map(|n| n.downgrade()).collect(),
3403                graph_dependencies: graph_dependencies.iter().map(|g| g.downgrade()).collect(),
3404                operation,
3405                id,
3406            })),
3407        };
3408        {
3409            let mut cell = self.body.borrow_mut();
3410            cell.nodes.push(result.clone());
3411        }
3412        let mut context_has_type_checker = false;
3413        {
3414            let context = self.get_context();
3415            let mut context_cell = context.body.borrow_mut();
3416            let type_checker = &mut context_cell.type_checker;
3417            if type_checker.is_some() {
3418                context_has_type_checker = true;
3419            }
3420        }
3421        if context_has_type_checker {
3422            let type_checking_result = result.get_type();
3423            if type_checking_result.is_err() {
3424                self.remove_last_node(result)?;
3425                return Err(type_checking_result.expect_err("Should not be here"));
3426            }
3427            let type_result = type_checking_result?;
3428
3429            let size_estimate = get_size_estimation_in_bits(type_result);
3430            if size_estimate.is_err() {
3431                self.remove_last_node(result)?;
3432                return Err(runtime_error!("Trying to add a node with invalid size"));
3433            }
3434            if size_estimate? > type_size_limit_constants::MAX_INDIVIDUAL_NODE_SIZE {
3435                self.remove_last_node(result)?;
3436                return Err(runtime_error!(
3437                    "Trying to add a node larger than MAX_INDIVIDUAL_NODE_SIZE"
3438                ));
3439            }
3440
3441            let context = self.get_context();
3442            let size_checking_result = context.try_update_total_size(result.clone());
3443            if size_checking_result.is_err() {
3444                self.remove_last_node(result)?;
3445                return Err(size_checking_result.expect_err("Should not be here"));
3446            }
3447        }
3448        Ok(result)
3449    }
3450
3451    fn remove_last_node(&self, n: Node) -> Result<()> {
3452        if n.get_graph() != *self {
3453            return Err(runtime_error!(
3454                "The node to be removed from a different graph"
3455            ));
3456        }
3457        {
3458            let cell = self.body.borrow();
3459            if n != *cell
3460                .nodes
3461                .last()
3462                .ok_or_else(|| runtime_error!("Nodes list is empty"))?
3463            {
3464                return Err(runtime_error!(
3465                    "The node to be removed is not the last node"
3466                ));
3467            }
3468        };
3469        let context = self.get_context();
3470        context.unregister_node(n.clone())?;
3471        let mut context_body = context.body.borrow_mut();
3472        if let Some(tc) = &mut context_body.type_checker {
3473            tc.unregister_node(n)?;
3474        }
3475        let mut cell = self.body.borrow_mut();
3476        cell.nodes.pop();
3477        Ok(())
3478    }
3479
3480    pub(crate) fn nop(&self, a: Node) -> Result<Node> {
3481        self.add_node(vec![a], vec![], Operation::NOP)
3482    }
3483
3484    pub(crate) fn prf(&self, key: Node, iv: u64, output_type: Type) -> Result<Node> {
3485        self.add_node(vec![key], vec![], Operation::PRF(iv, output_type))
3486    }
3487
3488    pub(crate) fn permutation_from_prf(&self, key: Node, iv: u64, n: u64) -> Result<Node> {
3489        self.add_node(vec![key], vec![], Operation::PermutationFromPRF(iv, n))
3490    }
3491
3492    pub(super) fn is_finalized(&self) -> bool {
3493        self.body.borrow().finalized
3494    }
3495
3496    pub(super) fn check_finalized(&self) -> Result<()> {
3497        if !self.is_finalized() {
3498            return Err(runtime_error!("Graph is not finalized"));
3499        }
3500        Ok(())
3501    }
3502
3503    fn make_serializable(&self) -> SerializableGraph {
3504        let output_node = match self.get_output_node() {
3505            Ok(n) => Some(n.get_id()),
3506            Err(_) => None,
3507        };
3508        Arc::new(SerializableGraphBody {
3509            finalized: self.is_finalized(),
3510            nodes: self
3511                .get_nodes()
3512                .iter()
3513                .map(|n| n.make_serializable())
3514                .collect(),
3515            output_node,
3516        })
3517    }
3518
3519    fn downgrade(&self) -> WeakGraph {
3520        WeakGraph {
3521            body: Arc::downgrade(&self.body),
3522        }
3523    }
3524
3525    #[doc(hidden)]
3526    pub fn add_annotation(&self, annotation: GraphAnnotation) -> Result<Graph> {
3527        self.get_context().add_graph_annotation(self, annotation)?;
3528        Ok(self.clone())
3529    }
3530
3531    pub fn get_annotations(&self) -> Result<Vec<GraphAnnotation>> {
3532        self.get_context().get_graph_annotations(self.clone())
3533    }
3534
3535    /// Rearrange given input values according to the names and the order of the related input nodes.
3536    ///
3537    /// For example, given a graph with the first input node named 'A' and the second one named 'B' and input values `{'B': v, 'A': w}`, this function returns a vector `[w, v]`.
3538    ///
3539    /// # Arguments
3540    ///
3541    /// `values` - hashmap of values keyed by node names
3542    ///
3543    /// # Returns
3544    ///
3545    /// Vector of values arranged by node names
3546    ///
3547    /// # Example
3548    ///
3549    /// ```
3550    /// # use ciphercore_base::graphs::create_context;
3551    /// # use ciphercore_base::data_types::{BIT, scalar_type};
3552    /// # use std::collections::HashMap;
3553    /// let c = create_context().unwrap();
3554    /// let g = c.create_graph().unwrap();
3555    /// let t = scalar_type(BIT);
3556    /// let n1 = g.input(t.clone()).unwrap();
3557    /// n1.set_name("input1").unwrap();
3558    /// let n2 = g.input(t.clone()).unwrap();
3559    /// n2.set_name("input2").unwrap();
3560    ///
3561    /// let mut input_map = HashMap::new();
3562    /// input_map.insert("input2", 2);
3563    /// input_map.insert("input1", 1);
3564    /// let ordered_input = g.prepare_input_values(input_map).unwrap();
3565    ///
3566    /// assert_eq!(vec![1,2], ordered_input);
3567    /// ```
3568    pub fn prepare_input_values<T: Clone>(&self, values: HashMap<&str, T>) -> Result<Vec<T>> {
3569        self.get_context()
3570            .prepare_input_values(self.clone(), values)
3571    }
3572}
3573type WeakGraphBodyPointer = Weak<AtomicRefCell<GraphBody>>;
3574
3575struct WeakGraph {
3576    body: WeakGraphBodyPointer,
3577}
3578
3579impl WeakGraph {
3580    //upgrade function panics if the the Graph pointer it downgraded from went out of scope
3581    fn upgrade(&self) -> Graph {
3582        Graph {
3583            body: self.body.upgrade().unwrap(),
3584        }
3585    }
3586}
3587impl Clone for WeakGraph {
3588    fn clone(&self) -> Self {
3589        WeakGraph {
3590            body: self.body.clone(),
3591        }
3592    }
3593}
3594
#[doc(hidden)]
/// Various node-related properties which aren't used in the graph building
/// or type inference, but can be used in node expansion or MPC compilation.
///
/// Annotations are kept by the context per node, keyed by `(graph_id, node_id)`,
/// and are preserved through serialization.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NodeAnnotation {
    AssociativeOperation,
    Private,
    /// `(sender_index, receiver_index)`; indices belong to the set `0..PARTIES`.
    Send(u64, u64),
    // NOTE(review): the PRF* variants are consumed by MPC compilation passes outside this file — confirm their exact meaning there.
    PRFMultiplication,
    PRFB2A,
    PRFTruncate,
}
3607
#[doc(hidden)]
/// Various graph-related properties which aren't used in the graph building
/// or type inference, but can be used in node expansion or MPC compilation.
///
/// Annotations are attached via `Graph::add_annotation`, kept by the context per graph ID,
/// and preserved through serialization.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum GraphAnnotation {
    AssociativeOperation,
    // NOTE(review): the *State variants presumably describe the state of an iterated graph — confirm with their consumers.
    OneBitState,
    SmallState,
}
3617
/// Mutable state of a [Context], shared behind [ContextBodyPointer].
struct ContextBody {
    /// Set once the context is finalized; afterwards no new graphs can be created.
    finalized: bool,
    /// All graphs of the context, indexed by graph ID.
    graphs: Vec<Graph>,
    /// Weak handle to the main graph, if one has been set.
    main_graph: Option<WeakGraph>,
    /// graph_id -> name
    graphs_names: HashMap<u64, String>,
    /// name -> graph_id
    graphs_names_inverse: HashMap<String, u64>,
    /// (graph_id, node_id) -> name
    nodes_names: HashMap<(u64, u64), String>,
    /// graph_id -> (name -> node_id)
    nodes_names_inverse: HashMap<u64, HashMap<String, u64>>,
    /// (graph_id, node_id) -> NodeAnnotation's
    nodes_annotations: HashMap<(u64, u64), Vec<NodeAnnotation>>,
    /// (graph_id) -> GraphAnnotation's
    graphs_annotations: HashMap<u64, Vec<GraphAnnotation>>,
    // NOTE(review): presumably the accumulated size of all nodes, maintained by `try_update_total_size` — confirm.
    total_size_nodes: u64,
    /// When `Some`, `Graph::add_node` runs type inference and size checks on every new node.
    type_checker: Option<TypeInferenceWorker>,
}

/// Shared, interior-mutable pointer to a `ContextBody`.
type ContextBodyPointer = Arc<AtomicRefCell<ContextBody>>;
3639
/// A structure that stores a pointer to a computation context that contains related computation graphs.
///
/// Context is a basic object to create computation graphs, arrange data flow between them and keep necessary information about them that is used for optimization, secure compilation and evaluation.
///
/// Context should have a main graph and be finalized in order to evaluate any of its graphs.
///
/// # Rust traits
///
/// [Clone] trait duplicates the pointer, not the underlying context.
///
/// [PartialEq] trait compares pointers, not the related contexts.
///
/// # Example
///
/// ```
/// # #[macro_use] extern crate maplit;
/// # fn main() {
/// # use ciphercore_base::graphs::{Context, create_context};
/// # use ciphercore_base::data_values::Value;
/// # use ciphercore_base::evaluators::random_evaluate;
/// # use ciphercore_base::data_types::{INT32, scalar_type};
/// # use ciphercore_base::errors::Result;
/// let context = || -> Result<Context> {
///     let context = create_context()?;
///     let graph = context.create_graph()?.set_name("main")?;
///     graph
///         .input(scalar_type(INT32))?
///         .set_name("a")?
///         .add(graph
///             .input(scalar_type(INT32))?
///             .set_name("b")?)?
///         .set_as_output()?;
///     graph.finalize()?.set_as_main()?;
///     context.finalize()?;
///     Ok(context)
/// }().unwrap();
///
/// let result = || -> Result<i32> {
///     let g = context.retrieve_graph("main")?;
///     let result = random_evaluate(
///         g.clone(),
///         g.prepare_input_values(
///             hashmap!{
///                 "a" => Value::from_scalar(123, INT32)?,
///                 "b" => Value::from_scalar(654, INT32)?,
///             },
///         )?,
///     )?;
///     let result = result.to_i32(INT32)?;
///     Ok(result)
/// }().unwrap();
///
/// assert_eq!(result, 777);
/// # }
/// ```
#[cfg_attr(feature = "py-binding", struct_wrapper)]
pub struct Context {
    body: ContextBodyPointer,
}
3699
3700impl fmt::Debug for Context {
3701    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3702        fmt::Display::fmt(self, f)
3703    }
3704}
3705
3706impl fmt::Display for Context {
3707    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3708        match serde_json::to_string(&self) {
3709            Ok(s) => write!(f, "{s}"),
3710            Err(_err) => Err(fmt::Error::default()),
3711        }
3712    }
3713}
3714
3715impl fmt::Display for Graph {
3716    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3717        write!(f, "Graph[num_nodes={}]", self.get_num_nodes())
3718    }
3719}
3720
3721impl fmt::Display for Node {
3722    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
3723        write!(
3724            f,
3725            "Node[type={}]",
3726            self.get_type().map_err(|_| fmt::Error::default())?
3727        )
3728    }
3729}
3730
/// Serializable snapshot of a [Context].
///
/// Pointers are replaced by numeric IDs and hash maps by vectors of pairs; the original
/// context is rebuilt from it by `recover_original_context`.
#[derive(Serialize, Deserialize)]
struct SerializableContextBody {
    finalized: bool,
    graphs: Vec<SerializableGraph>,
    /// ID of the main graph, if one was set
    main_graph: Option<u64>,
    /// graph_id -> name
    graphs_names: Vec<(u64, String)>,
    /// (graph_id, node_id) -> name
    nodes_names: Vec<((u64, u64), String)>,
    /// (graph_id, node_id) -> NodeAnnotation's
    nodes_annotations: Vec<((u64, u64), Vec<NodeAnnotation>)>,
    /// (graph_id) -> GraphAnnotation's
    graphs_annotations: Vec<(u64, Vec<GraphAnnotation>)>,
}
3745
3746impl SerializableContextBody {
3747    fn recover_original_graph(
3748        serializable_graph: SerializableGraph,
3749        context: Context,
3750    ) -> Result<Graph> {
3751        let result_graph = context.create_graph()?;
3752        for node in &serializable_graph.nodes {
3753            let mut node_dependencies = vec![];
3754            for id in &node.node_dependencies {
3755                let current_nodes = &result_graph.body.borrow().nodes;
3756                if *id >= current_nodes.len() as u64 {
3757                    return Err(runtime_error!("Non-existent node dependency"));
3758                }
3759                node_dependencies.push(current_nodes[*id as usize].clone());
3760            }
3761            let mut graph_dependencies = vec![];
3762            for id in &node.graph_dependencies {
3763                let context = result_graph.get_context();
3764                let current_graphs = &context.body.borrow().graphs;
3765                if *id >= current_graphs.len() as u64 {
3766                    return Err(runtime_error!("Non-existent graph dependency"));
3767                }
3768                graph_dependencies.push(current_graphs[*id as usize].clone());
3769            }
3770            result_graph.add_node(
3771                node_dependencies,
3772                graph_dependencies,
3773                node.operation.clone(),
3774            )?;
3775        }
3776        if let Some(id) = serializable_graph.output_node {
3777            let rebuilt_output_node = {
3778                let current_nodes = &result_graph.body.borrow().nodes;
3779                if id >= current_nodes.len() as u64 {
3780                    return Err(runtime_error!("Non-existent output node"));
3781                }
3782                current_nodes[id as usize].clone()
3783            };
3784            result_graph.set_output_node(rebuilt_output_node)?;
3785        }
3786        if serializable_graph.finalized {
3787            result_graph.finalize()?;
3788        }
3789        Ok(result_graph)
3790    }
3791
3792    fn recover_original_context(&self) -> Result<Context> {
3793        let result_context = create_context()?;
3794        for graph in &self.graphs {
3795            let _result_graph =
3796                Self::recover_original_graph(graph.clone(), result_context.clone())?;
3797        }
3798        if let Some(id) = self.main_graph {
3799            let rebuilt_main_graph = {
3800                let current_graphs = &result_context.body.borrow().graphs;
3801                if id >= current_graphs.len() as u64 {
3802                    return Err(runtime_error!("Non-existent main graph"));
3803                }
3804                current_graphs[id as usize].clone()
3805            };
3806            result_context.set_main_graph(rebuilt_main_graph)?;
3807        }
3808        for (id, _) in &self.graphs_names {
3809            let current_graphs = &result_context.body.borrow().graphs;
3810            if *id >= current_graphs.len() as u64 {
3811                return Err(runtime_error!("graphs_names contain an invalid ID"));
3812            }
3813        }
3814        for ((graph_id, node_id), _) in &self.nodes_names {
3815            let current_graphs = &result_context.body.borrow().graphs;
3816            if *graph_id >= current_graphs.len() as u64 {
3817                return Err(runtime_error!("nodes_names contain an invalid graph ID"));
3818            }
3819            let current_nodes = &current_graphs[*graph_id as usize].body.borrow().nodes;
3820            if *node_id >= current_nodes.len() as u64 {
3821                return Err(runtime_error!("nodes_names contain an invalid node ID"));
3822            }
3823        }
3824        for (id, name) in &self.graphs_names {
3825            let current_graph = {
3826                let current_graphs = &result_context.body.borrow().graphs;
3827                current_graphs[*id as usize].clone()
3828            };
3829            result_context.set_graph_name(current_graph, name)?;
3830        }
3831        for ((graph_id, node_id), name) in &self.nodes_names {
3832            let current_node = {
3833                let current_graphs = &result_context.body.borrow().graphs;
3834                let current_nodes = &current_graphs[*graph_id as usize].body.borrow().nodes;
3835                current_nodes[*node_id as usize].clone()
3836            };
3837            result_context.set_node_name(current_node, name)?;
3838        }
3839        for (id, annotations) in &self.graphs_annotations {
3840            let current_graph = {
3841                let current_graphs = &result_context.body.borrow().graphs;
3842                current_graphs[*id as usize].clone()
3843            };
3844            for annotation in annotations {
3845                result_context.add_graph_annotation(&current_graph, annotation.clone())?;
3846            }
3847        }
3848        for ((graph_id, node_id), annotations) in &self.nodes_annotations {
3849            let current_node = {
3850                let current_graphs = &result_context.body.borrow().graphs;
3851                let current_nodes = &current_graphs[*graph_id as usize].body.borrow().nodes;
3852                current_nodes[*node_id as usize].clone()
3853            };
3854            for annotation in annotations {
3855                result_context.add_node_annotation(&current_node, annotation.clone())?;
3856            }
3857        }
3858        if self.finalized {
3859            result_context.finalize()?;
3860        }
3861        Ok(result_context)
3862    }
3863}
3864
/// Shared pointer to a `SerializableContextBody`.
///
/// This is the value produced by `Context::make_serializable` and parsed back with
/// `serde_json` when a [Context] is deserialized.
type SerializableContext = Arc<SerializableContextBody>;
3866
impl Clone for Context {
    /// Returns a new [Context] handle that shares the same underlying context body.
    ///
    /// Only the reference-counted pointer is copied. Graphs, names and annotations are not
    /// duplicated, so both handles observe the same state and compare equal via [PartialEq].
    fn clone(&self) -> Self {
        Context {
            body: self.body.clone(),
        }
    }
}
3875
3876impl PartialEq for Context {
3877    /// Tests whether `self` and `other` contexts are equal via comparison of their respective pointers.
3878    ///
3879    /// # Arguments
3880    ///
3881    /// `other` - another [Context] value
3882    ///
3883    /// # Returns
3884    ///
3885    /// `true` if `self` and `other` are equal, `false` otherwise
3886    fn eq(&self, other: &Self) -> bool {
3887        Arc::ptr_eq(&self.body, &other.body)
3888    }
3889}
3890
// Context equality is pointer identity of the shared body, which is reflexive, symmetric and
// transitive, so it is a full equivalence relation.
impl Eq for Context {}
3892
3893/// Public methods which supposed to be imported in Python.
3894#[cfg_attr(feature = "py-binding", impl_wrapper)]
3895impl Context {
3896    /// Creates an empty computation graph in this context.
3897    ///
3898    /// # Returns
3899    ///
3900    /// New computation graph
3901    ///
3902    /// # Example
3903    ///
3904    /// ```
3905    /// # use ciphercore_base::graphs::create_context;
3906    /// let c = create_context().unwrap();
3907    /// let g = c.create_graph().unwrap();
3908    /// ```
3909    pub fn create_graph(&self) -> Result<Graph> {
3910        if self.body.borrow().finalized {
3911            return Err(runtime_error!("Can't add a graph to a finalized context"));
3912        }
3913        let id = self.body.borrow().graphs.len() as u64;
3914        let result = Graph {
3915            body: Arc::new(AtomicRefCell::new(GraphBody {
3916                finalized: false,
3917                nodes: vec![],
3918                output_node: None,
3919                id,
3920                context: self.downgrade(),
3921            })),
3922        };
3923        self.body.borrow_mut().graphs.push(result.clone());
3924        Ok(result)
3925    }
3926
3927    /// Finalizes the context if all its graphs are finalized and the main graph is set.
3928    ///
3929    /// After finalization the context can't be changed.
3930    ///
3931    /// # Returns
3932    ///
3933    /// Finalized context
3934    ///
3935    /// # Example
3936    ///
3937    /// ```
3938    /// # use ciphercore_base::graphs::create_context;
3939    /// # use ciphercore_base::data_types::{array_type, vector_type, INT32};
3940    /// let c = create_context().unwrap();
3941    /// let g = c.create_graph().unwrap();
3942    /// let t = array_type(vec![3, 2], INT32);
3943    /// let vec_t = vector_type(4, t);
3944    /// let n1 = g.input(vec_t).unwrap();
3945    /// let n2 = g.vector_to_array(n1).unwrap();
3946    /// n2.set_as_output().unwrap();
3947    /// g.finalize().unwrap();
3948    /// c.set_main_graph(g).unwrap();
3949    /// c.finalize().unwrap();
3950    /// ```
3951    pub fn finalize(&self) -> Result<Context> {
3952        for graph in self.get_graphs() {
3953            graph.check_finalized()?;
3954        }
3955        let main_graph = self.body.borrow().main_graph.clone();
3956        match main_graph {
3957            Some(_) => {
3958                self.body.borrow_mut().finalized = true;
3959                Ok(self.clone())
3960            }
3961            _ => Err(runtime_error!(
3962                "Can't finalize the context without the main graph"
3963            )),
3964        }
3965    }
3966
3967    /// Promotes a graph to the main one in this context.
3968    ///
3969    /// # Arguments
3970    ///
3971    /// `graph` - graph
3972    ///
3973    /// # Returns
3974    ///
3975    /// This context
3976    ///
3977    /// # Example
3978    ///
3979    /// ```
3980    /// # use ciphercore_base::graphs::create_context;
3981    /// # use ciphercore_base::data_types::{array_type, INT32};
3982    /// let c = create_context().unwrap();
3983    /// let g = c.create_graph().unwrap();
3984    /// let t = array_type(vec![3, 2], INT32);
3985    /// let n = g.input(t).unwrap();
3986    /// n.set_as_output().unwrap();
3987    /// g.finalize().unwrap();
3988    /// c.set_main_graph(g).unwrap();
3989    /// ```
3990    pub fn set_main_graph(&self, graph: Graph) -> Result<Context> {
3991        let current_main_graph = self.body.borrow().main_graph.clone();
3992        match current_main_graph {
3993            Some(_) => Err(runtime_error!("Main graph is already set")),
3994            None => {
3995                if graph.get_context() != *self {
3996                    return Err(runtime_error!("Main graph is from the wrong context"));
3997                }
3998                graph.check_finalized()?;
3999                self.body.borrow_mut().main_graph = Some(graph.downgrade());
4000                Ok(self.clone())
4001            }
4002        }
4003    }
4004
4005    /// Returns the vector of graphs contained in this context in order of creation.
4006    ///
4007    /// # Returns
4008    ///
4009    /// Vector of the graphs in this context
4010    pub fn get_graphs(&self) -> Vec<Graph> {
4011        self.body.borrow().graphs.clone()
4012    }
4013
4014    /// Does nothing if the context is finalized; otherwise returns a runtime error.
4015    ///
4016    /// # Returns
4017    ///
4018    /// Runtime error if this context is not finalized
4019    pub fn check_finalized(&self) -> Result<()> {
4020        if !self.is_finalized() {
4021            return Err(runtime_error!("Context is not finalized"));
4022        }
4023        Ok(())
4024    }
4025
4026    /// Returns the main graph of the context if it is already set.
4027    ///
4028    /// # Returns
4029    ///
4030    /// Main graph of the context
4031    pub fn get_main_graph(&self) -> Result<Graph> {
4032        match &self.body.borrow().main_graph {
4033            Some(g) => Ok(g.upgrade()),
4034            None => Err(runtime_error!("main graph is not set")),
4035        }
4036    }
4037
4038    /// Returns the number of graphs contained in this context.
4039    ///
4040    /// # Returns
4041    ///
4042    /// Number of the graphs in this context
4043    pub fn get_num_graphs(&self) -> u64 {
4044        self.body.borrow().graphs.len() as u64
4045    }
4046
4047    /// Returns the graph contained in this context with a given ID.
4048    ///
4049    /// A graph ID is a serial number of a graph between `0` and `n-1` where `n` is the number of graphs in this context.
4050    ///
4051    /// # Arguments
4052    ///
4053    /// `id` - ID of a graph
4054    ///
4055    /// # Returns
4056    ///
4057    /// Graph with the given ID
4058    pub fn get_graph_by_id(&self, id: u64) -> Result<Graph> {
4059        let graphs = &self.body.borrow().graphs;
4060        if id >= graphs.len() as u64 {
4061            Err(runtime_error!("Invalid id for the graph retrieval"))
4062        } else {
4063            Ok(graphs[id as usize].clone())
4064        }
4065    }
4066
4067    /// Returns the node contained in this context with a given global ID.
4068    ///
4069    /// The global ID of a node is a pair of the node ID and the ID of its parent graph.
4070    ///
4071    /// # Arguments
4072    ///
4073    /// `id` - tuple (graph ID, node ID)
4074    ///
4075    /// # Returns
4076    ///
4077    /// Node with the given global ID
4078    pub fn get_node_by_global_id(&self, id: (u64, u64)) -> Result<Node> {
4079        self.get_graph_by_id(id.0)?.get_node_by_id(id.1)
4080    }
4081
4082    /// Sets the name of a graph.
4083    ///
4084    /// A given name should be unique.
4085    ///
4086    /// # Arguments
4087    ///
4088    /// * `graph` - graph
4089    /// * `name` - name of the graph
4090    ///
4091    /// # Returns
4092    ///
4093    /// This context
4094    ///
4095    /// # Example
4096    ///
4097    /// ```
4098    /// # use ciphercore_base::graphs::create_context;
4099    /// let c = create_context().unwrap();
4100    /// let g = c.create_graph().unwrap();
4101    /// g.set_name("relu").unwrap();
4102    /// ```
4103    pub fn set_graph_name(&self, graph: Graph, name: &str) -> Result<Context> {
4104        if graph.get_context() != *self {
4105            return Err(runtime_error!(
4106                "The graph to be named is in a different context"
4107            ));
4108        }
4109        if self.is_finalized() {
4110            return Err(runtime_error!(
4111                "Can't set a graph name in a finalized context"
4112            ));
4113        }
4114        let id = graph.get_id();
4115        let name_owned = name.to_owned();
4116        let mut cell = self.body.borrow_mut();
4117        if cell.graphs_names.get(&id).is_some() {
4118            return Err(runtime_error!("Can't set the graph name twice"));
4119        }
4120        if cell.graphs_names_inverse.get(name).is_some() {
4121            return Err(runtime_error!("Graph names must be unique"));
4122        }
4123        cell.graphs_names.insert(id, name_owned.clone());
4124        cell.graphs_names_inverse.insert(name_owned, id);
4125        Ok(self.clone())
4126    }
4127
4128    /// Returns the name of a graph.
4129    ///
4130    /// # Arguments
4131    ///
4132    /// `graph` - graph
4133    ///
4134    /// # Returns
4135    ///
4136    /// Name of a given graph
4137    ///
4138    /// # Example
4139    ///
4140    /// ```
4141    /// # use ciphercore_base::graphs::create_context;
4142    /// let c = create_context().unwrap();
4143    /// let g = c.create_graph().unwrap();
4144    /// g.set_name("relu").unwrap();
4145    /// assert_eq!(c.get_graph_name(g).unwrap(), "relu".to_owned());
4146    /// ```
4147    pub fn get_graph_name(&self, graph: Graph) -> Result<String> {
4148        if graph.get_context() != *self {
4149            return Err(runtime_error!("The graph is in a different context"));
4150        }
4151        let cell = self.body.borrow();
4152        Ok(cell
4153            .graphs_names
4154            .get(&graph.get_id())
4155            .ok_or_else(|| runtime_error!("The graph does not have a name assigned"))?
4156            .clone())
4157    }
4158
4159    /// Returns the graph with a given name in this context.
4160    ///
4161    /// # Arguments
4162    ///
4163    /// `name` - graph name
4164    ///
4165    /// # Returns
4166    ///
4167    /// Graph with a given name
4168    ///
4169    /// # Example
4170    ///
4171    /// ```
4172    /// # use ciphercore_base::graphs::create_context;
4173    /// # use ciphercore_base::data_types::{BIT, scalar_type};
4174    /// let c = create_context().unwrap();
4175    /// let g = c.create_graph().unwrap();
4176    /// let n = g.input(scalar_type(BIT)).unwrap();
4177    /// g.set_name("input_graph").unwrap();
4178    /// assert!(g == c.retrieve_graph("input_graph").unwrap());
4179    /// ```
4180    pub fn retrieve_graph(&self, name: &str) -> Result<Graph> {
4181        let cell = self.body.borrow();
4182        let id = cell
4183            .graphs_names_inverse
4184            .get(name)
4185            .ok_or_else(|| runtime_error!("No graph with such a name exists"))?;
4186        let graph = cell.graphs[*id as usize].clone();
4187        Ok(graph)
4188    }
4189
4190    /// Sets the name of a node.
4191    ///
4192    /// A given name should be unique.
4193    ///
4194    /// # Arguments
4195    ///
4196    /// * `node` - node
4197    /// * `name` - name of a node
4198    ///
4199    /// # Returns
4200    ///
4201    /// This context
4202    ///
4203    /// # Example
4204    ///
4205    /// ```
4206    /// # use ciphercore_base::graphs::create_context;
4207    /// # use ciphercore_base::data_types::{scalar_type, BIT};
4208    /// let c = create_context().unwrap();
4209    /// let g = c.create_graph().unwrap();
4210    /// let t = scalar_type(BIT);
4211    /// let n = g.input(t).unwrap();
4212    /// c.set_node_name(n, "XOR").unwrap();
4213    /// ```
4214    pub fn set_node_name(&self, node: Node, name: &str) -> Result<Context> {
4215        if node.get_graph().get_context() != *self {
4216            return Err(runtime_error!(
4217                "The node to be named is in a different context"
4218            ));
4219        }
4220        if self.is_finalized() {
4221            return Err(runtime_error!(
4222                "Can't set a node name in a finalized context"
4223            ));
4224        }
4225        let node_id = node.get_id();
4226        let graph_id = node.get_graph().get_id();
4227        let mut cell = self.body.borrow_mut();
4228        if cell.nodes_names.get(&(graph_id, node_id)).is_some() {
4229            return Err(runtime_error!("Can't set the node name twice"));
4230        }
4231        if cell.nodes_names_inverse.get(&graph_id).is_none() {
4232            cell.nodes_names_inverse.insert(graph_id, HashMap::new());
4233        }
4234        let graph_map_inverse = cell
4235            .nodes_names_inverse
4236            .get_mut(&graph_id)
4237            .expect("Should not be here!");
4238        if graph_map_inverse.get(name).is_some() {
4239            return Err(runtime_error!(
4240                "Node names must be unique (within the graph)"
4241            ));
4242        }
4243        graph_map_inverse.insert(name.to_owned(), node_id);
4244        cell.nodes_names
4245            .insert((graph_id, node_id), name.to_owned());
4246        Ok(self.clone())
4247    }
4248
4249    /// Returns the name of a node.
4250    ///
4251    /// # Arguments
4252    ///
4253    /// `node` - node
4254    ///
4255    /// # Returns
4256    ///
4257    /// Name of a node or None if it doesn't have a name
4258    ///
4259    /// # Example
4260    ///
4261    /// ```
4262    /// # use ciphercore_base::graphs::create_context;
4263    /// # use ciphercore_base::data_types::{scalar_type, BIT};
4264    /// let c = create_context().unwrap();
4265    /// let g = c.create_graph().unwrap();
4266    /// let t = scalar_type(BIT);
4267    /// let n = g.input(t).unwrap();
4268    /// n.set_name("XOR").unwrap();
4269    /// assert_eq!(c.get_node_name(n).unwrap(), Some("XOR".to_owned()));
4270    /// ```
4271    pub fn get_node_name(&self, node: Node) -> Result<Option<String>> {
4272        if node.get_graph().get_context() != *self {
4273            return Err(runtime_error!("The node is in a different context"));
4274        }
4275        let node_id = node.get_id();
4276        let graph_id = node.get_graph().get_id();
4277        let cell = self.body.borrow();
4278        Ok(cell.nodes_names.get(&(graph_id, node_id)).cloned())
4279    }
4280
4281    /// Returns the node with a given name in a given graph.
4282    ///
4283    /// # Arguments
4284    ///
4285    /// * `graph` - graph
4286    /// * `name` - node name
4287    ///
4288    /// # Returns
4289    ///
4290    /// Node with a given name
4291    ///
4292    /// # Example
4293    ///
4294    /// ```
4295    /// # use ciphercore_base::graphs::create_context;
4296    /// # use ciphercore_base::data_types::{BIT, scalar_type};
4297    /// let c = create_context().unwrap();
4298    /// let g = c.create_graph().unwrap();
4299    /// let n = g.input(scalar_type(BIT)).unwrap();
4300    /// n.set_name("input_node").unwrap();
4301    /// assert!(n == c.retrieve_node(g, "input_node").unwrap());
4302    /// ```
4303    pub fn retrieve_node(&self, graph: Graph, name: &str) -> Result<Node> {
4304        if graph.get_context() != *self {
4305            return Err(runtime_error!("The graph is in a different context"));
4306        }
4307        let graph_id = graph.get_id();
4308        let cell = self.body.borrow();
4309        let node_id = cell
4310            .nodes_names_inverse
4311            .get(&graph_id)
4312            .ok_or_else(|| runtime_error!("The graph has no named nodes"))?
4313            .get(name)
4314            .ok_or_else(|| runtime_error!("Node with a given name does not exist"))?;
4315        Ok(graph.body.borrow().nodes[*node_id as usize].clone())
4316    }
4317    /// Check that two given contexts contain the same data, i.e. graphs, nodes, names, parameters.
4318    ///
4319    /// Underlying structures that contain pointers (graphs, nodes) are compared by data they refer to.
4320    ///
4321    /// # Arguments
4322    ///
4323    /// * `context2` - context to compare
4324    ///
4325    /// # Returns
4326    ///
4327    /// `true` if the given contexts contain the same content, otherwise `false`
4328    pub fn deep_equal(&self, context2: Context) -> bool {
4329        contexts_deep_equal(self.clone(), context2)
4330    }
4331}
4332
/// Converts a `HashMap` into a vector of `(key, value)` pairs sorted by key in ascending order.
///
/// `HashMap` iteration order is unspecified, so sorting gives the serialized representation a
/// deterministic order.
///
/// # Arguments
///
/// `map` - map to flatten (consumed)
///
/// # Returns
///
/// All entries of `map`, ordered by key
fn serialize_hashmap<K, V>(map: HashMap<K, V>) -> Vec<(K, V)>
where
    K: Ord,
{
    let mut pairs: Vec<(K, V)> = map.into_iter().collect();
    // Keys of a map are unique, so an unstable sort produces the same order as a stable one
    // while avoiding the extra allocation. Comparing by reference drops the need for `K: Copy`.
    pairs.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
    pairs
}
4341
4342/// Methods which aren't supposed to be imported in Python.
4343impl Context {
4344    pub(super) fn is_finalized(&self) -> bool {
4345        self.body.borrow().finalized
4346    }
4347
4348    fn make_serializable(&self) -> SerializableContext {
4349        let main_graph = match self.get_main_graph() {
4350            Ok(g) => Some(g.get_id()),
4351            Err(_) => None,
4352        };
4353        let cell = self.body.borrow();
4354        Arc::new(SerializableContextBody {
4355            finalized: self.is_finalized(),
4356            graphs: self
4357                .get_graphs()
4358                .iter()
4359                .map(|g| g.make_serializable())
4360                .collect(),
4361            main_graph,
4362            graphs_names: serialize_hashmap(cell.graphs_names.clone()),
4363            nodes_names: serialize_hashmap(cell.nodes_names.clone()),
4364            graphs_annotations: serialize_hashmap(cell.graphs_annotations.clone()),
4365            nodes_annotations: serialize_hashmap(cell.nodes_annotations.clone()),
4366        })
4367    }
4368
4369    fn add_type_checker(&self) -> Result<Context> {
4370        {
4371            let mut cell = self.body.borrow_mut();
4372            if cell.type_checker.is_some() {
4373                return Err(runtime_error!(
4374                    "Type checker associated with the context already exists"
4375                ));
4376            }
4377            cell.type_checker = Some(create_type_inference_worker(self.clone()));
4378        }
4379        for graph in self.get_graphs() {
4380            for node in graph.get_nodes() {
4381                node.get_type()?;
4382            }
4383        }
4384        Ok(self.clone())
4385    }
4386
4387    fn get_total_size_nodes(&self) -> u64 {
4388        self.body.borrow().total_size_nodes
4389    }
4390
4391    fn set_total_size_nodes(&self, size: u64) {
4392        self.body.borrow_mut().total_size_nodes = size;
4393    }
4394
4395    fn try_update_total_size(&self, node: Node) -> Result<()> {
4396        let node_type = match node.get_operation() {
4397            Operation::Input(input_type) => input_type,
4398            Operation::Constant(t, _) => t,
4399            _ => return Ok(()),
4400        };
4401        if !node_type.is_valid() {
4402            return Err(runtime_error!("Node with an invalid type: {:?}", node_type));
4403        }
4404        let new_total_size = self
4405            .get_total_size_nodes()
4406            .checked_add(get_size_estimation_in_bits(node_type)?)
4407            .ok_or_else(|| runtime_error!("add overflow!"))?;
4408        if new_total_size > type_size_limit_constants::MAX_TOTAL_SIZE_NODES {
4409            return Err(runtime_error!(
4410                "Can't add a node: total size of nodes exceeds MAX_TOTAL_SIZE_NODES"
4411            ));
4412        }
4413        self.set_total_size_nodes(new_total_size);
4414        Ok(())
4415    }
4416
4417    fn unregister_node(&self, node: Node) -> Result<()> {
4418        if node.get_graph().get_context() != *self {
4419            return Err(runtime_error!(
4420                "The node to be unregister from  a different context"
4421            ));
4422        }
4423        if self.is_finalized() {
4424            return Err(runtime_error!(
4425                "Can't unregister a node from  a finalized context"
4426            ));
4427        }
4428
4429        let node_id = node.get_id();
4430        let graph_id = node.get_graph().get_id();
4431
4432        let mut cell = self.body.borrow_mut();
4433        let name_option = cell.nodes_names.remove(&(graph_id, node_id));
4434        cell.nodes_annotations.remove(&(graph_id, node_id));
4435        if cell.nodes_names_inverse.get(&graph_id).is_none() {
4436            return Ok(());
4437        }
4438        let graph_map_inverse = cell
4439            .nodes_names_inverse
4440            .get_mut(&graph_id)
4441            .expect("Should not be here!");
4442        if let Some(name) = name_option {
4443            graph_map_inverse.remove(&name);
4444        }
4445        Ok(())
4446    }
4447
4448    fn to_versioned_data(&self) -> Result<VersionedData> {
4449        VersionedData::create_versioned_data(
4450            DATA_VERSION,
4451            serde_json::to_string(&self.make_serializable())?,
4452        )
4453    }
4454    fn prepare_input_values<T: Clone>(
4455        &self,
4456        graph: Graph,
4457        values: HashMap<&str, T>,
4458    ) -> Result<Vec<T>> {
4459        if graph.get_context() != *self {
4460            return Err(runtime_error!("The graph is in a different context"));
4461        }
4462        let graph_id = graph.get_id();
4463        let cell = self.body.borrow();
4464        for node_name in values.keys() {
4465            cell.nodes_names_inverse
4466                .get(&graph_id)
4467                .ok_or_else(|| runtime_error!("Trying to call graph without named nodes"))?
4468                .get(node_name as &str)
4469                .ok_or_else(|| runtime_error!("Input with a given name is not found"))?;
4470        }
4471        let mut result = vec![];
4472        for node in graph.get_nodes() {
4473            if node.get_operation().is_input() {
4474                let node_id = node.get_id();
4475                let node_name = cell
4476                    .nodes_names
4477                    .get(&(graph_id, node_id))
4478                    .ok_or_else(|| runtime_error!("Unnamed input"))?;
4479                let node_value = values
4480                    .get(node_name as &str)
4481                    .ok_or_else(|| runtime_error!("Unspecified input"))?
4482                    .clone();
4483                result.push(node_value);
4484            }
4485        }
4486        Ok(result)
4487    }
4488
4489    pub(super) fn add_node_annotation(
4490        &self,
4491        node: &Node,
4492        annotation: NodeAnnotation,
4493    ) -> Result<Context> {
4494        if node.get_graph().get_context() != *self {
4495            return Err(runtime_error!(
4496                "The node to be annotated is in a different context"
4497            ));
4498        }
4499        if self.is_finalized() {
4500            return Err(runtime_error!(
4501                "Can't add a node annotation in a finalized context"
4502            ));
4503        }
4504        let node_id = node.get_id();
4505        let graph_id = node.get_graph().get_id();
4506        let key = (graph_id, node_id);
4507        let mut cell = self.body.borrow_mut();
4508        let annotations = cell.nodes_annotations.get_mut(&key);
4509        if let Some(annotation_vec) = annotations {
4510            annotation_vec.push(annotation);
4511        } else {
4512            cell.nodes_annotations.insert(key, vec![annotation]);
4513        }
4514        Ok(self.clone())
4515    }
4516
4517    pub(super) fn get_node_annotations(&self, node: Node) -> Result<Vec<NodeAnnotation>> {
4518        if node.get_graph().get_context() != *self {
4519            return Err(runtime_error!("The node is in a different context"));
4520        }
4521        let node_id = node.get_id();
4522        let graph_id = node.get_graph().get_id();
4523        let cell = self.body.borrow();
4524        Ok(cell
4525            .nodes_annotations
4526            .get(&(graph_id, node_id))
4527            .cloned()
4528            .unwrap_or_default())
4529    }
4530
4531    fn add_graph_annotation(&self, graph: &Graph, annotation: GraphAnnotation) -> Result<Context> {
4532        if graph.get_context() != *self {
4533            return Err(runtime_error!(
4534                "The graph to be annotated is in a different context"
4535            ));
4536        }
4537        if self.is_finalized() {
4538            return Err(runtime_error!(
4539                "Can't set a graph annotation in a finalized context"
4540            ));
4541        }
4542        let id = graph.get_id();
4543        let mut cell = self.body.borrow_mut();
4544        let annotations = cell.graphs_annotations.get_mut(&id);
4545        if let Some(annotation_vec) = annotations {
4546            annotation_vec.push(annotation);
4547        } else {
4548            cell.graphs_annotations.insert(id, vec![annotation]);
4549        }
4550        Ok(self.clone())
4551    }
4552
4553    fn get_graph_annotations(&self, graph: Graph) -> Result<Vec<GraphAnnotation>> {
4554        if graph.get_context() != *self {
4555            return Err(runtime_error!("The graph is in a different context"));
4556        }
4557        let cell = self.body.borrow();
4558        Ok(cell
4559            .graphs_annotations
4560            .get(&graph.get_id())
4561            .cloned()
4562            .unwrap_or_default())
4563    }
4564
4565    pub(super) fn downgrade(&self) -> WeakContext {
4566        WeakContext {
4567            body: Arc::downgrade(&self.body),
4568        }
4569    }
4570}
4571
4572impl Serialize for Context {
4573    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
4574    where
4575        S: Serializer,
4576    {
4577        let versioned_context = self
4578            .to_versioned_data()
4579            .expect("Error during conversion from Context into VersionedData");
4580        //VersionedData::from(self.clone());
4581        versioned_context.serialize(serializer)
4582    }
4583}
4584
4585impl<'de> Deserialize<'de> for Context {
4586    fn deserialize<D>(deserializer: D) -> std::result::Result<Context, D::Error>
4587    where
4588        D: Deserializer<'de>,
4589    {
4590        let versioned_context = VersionedData::deserialize(deserializer)?;
4591        if !versioned_context.check_version(DATA_VERSION) {
4592            Err(runtime_error!(
4593                "Context version doesn't match the requirement"
4594            ))
4595            .map_err(serde::de::Error::custom)
4596        } else {
4597            let serializable_context =
4598                serde_json::from_str::<SerializableContext>(versioned_context.get_data_string())
4599                    .expect("Error during deserialization of SerializableContext");
4600            serializable_context
4601                .recover_original_context()
4602                .map_err(serde::de::Error::custom)
4603        }
4604    }
4605}
4606
4607/// In general, `create_unchecked_context()` should not return errors, but
4608/// we still make the result type Result<Context> for uniformity.
4609pub(super) fn create_unchecked_context() -> Result<Context> {
4610    Ok(Context {
4611        body: Arc::new(AtomicRefCell::new(ContextBody {
4612            finalized: false,
4613            graphs: vec![],
4614            main_graph: None,
4615            graphs_names: HashMap::new(),
4616            graphs_names_inverse: HashMap::new(),
4617            nodes_names: HashMap::new(),
4618            nodes_names_inverse: HashMap::new(),
4619            graphs_annotations: HashMap::new(),
4620            nodes_annotations: HashMap::new(),
4621            type_checker: None,
4622            total_size_nodes: 0,
4623        })),
4624    })
4625}
4626
4627/// Creates an empty computation context.
4628///
4629/// # Returns
4630///
4631/// New computation context
4632///
4633/// # Example
4634///
4635/// ```
4636/// # use ciphercore_base::graphs::create_context;
4637/// let c = create_context().unwrap();
4638/// ```
4639#[cfg_attr(feature = "py-binding", fn_wrapper)]
4640pub fn create_context() -> Result<Context> {
4641    let context = create_unchecked_context()?;
4642    context.add_type_checker()?;
4643    Ok(context)
4644}
4645
4646fn graphs_deep_equal(graph1: Graph, graph2: Graph) -> bool {
4647    let graph1_body = graph1.body.borrow();
4648    let graph2_body = graph2.body.borrow();
4649    if graph1_body.finalized != graph2_body.finalized {
4650        return false;
4651    }
4652    if graph1_body.nodes.len() != graph2_body.nodes.len() {
4653        return false;
4654    }
4655    for j in 0..graph1_body.nodes.len() {
4656        let node1 = graph1_body.nodes[j].clone();
4657        let node2 = graph2_body.nodes[j].clone();
4658        let node1_body = node1.body.borrow();
4659        let node2_body = node2.body.borrow();
4660        if node1_body.operation != node2_body.operation {
4661            return false;
4662        }
4663        let node_dependencies1: Vec<u64> = node1_body
4664            .node_dependencies
4665            .iter()
4666            .map(|n| n.upgrade().get_id())
4667            .collect();
4668        let node_dependencies2: Vec<u64> = node2_body
4669            .node_dependencies
4670            .iter()
4671            .map(|n| n.upgrade().get_id())
4672            .collect();
4673        if node_dependencies1 != node_dependencies2 {
4674            return false;
4675        }
4676        let graph_dependencies1: Vec<u64> = node1_body
4677            .graph_dependencies
4678            .iter()
4679            .map(|g| g.upgrade().get_id())
4680            .collect();
4681        let graph_dependencies2: Vec<u64> = node2_body
4682            .graph_dependencies
4683            .iter()
4684            .map(|g| g.upgrade().get_id())
4685            .collect();
4686        if graph_dependencies1 != graph_dependencies2 {
4687            return false;
4688        }
4689    }
4690    if graph1_body
4691        .output_node
4692        .clone()
4693        .map(|n| n.upgrade().get_id())
4694        != graph2_body
4695            .output_node
4696            .clone()
4697            .map(|n| n.upgrade().get_id())
4698    {
4699        return false;
4700    }
4701    true
4702}
4703
4704/// Check that two given contexts contain the same data, i.e. graphs, nodes, names, parameters.
4705///
4706/// Underlying structures that contain pointers (graphs, nodes) are compared by data they refer to.
4707///
4708/// # Arguments
4709///
4710/// * `context1` - first context to compare
4711/// * `context2` - second context to compare
4712///
4713/// # Returns
4714///
4715/// `true` if the given contexts contain the same content, otherwise `false`
4716pub fn contexts_deep_equal(context1: Context, context2: Context) -> bool {
4717    let body1 = context1.body.borrow();
4718    let body2 = context2.body.borrow();
4719    if body1.finalized != body2.finalized {
4720        return false;
4721    }
4722    if body1.graphs_names != body2.graphs_names {
4723        return false;
4724    }
4725    if body1.nodes_names != body2.nodes_names {
4726        return false;
4727    }
4728    if body1.nodes_annotations != body2.nodes_annotations {
4729        return false;
4730    }
4731    if body1.graphs_annotations != body2.graphs_annotations {
4732        return false;
4733    }
4734    if body1.graphs.len() != body2.graphs.len() {
4735        return false;
4736    }
4737    for i in 0..body1.graphs.len() {
4738        if !graphs_deep_equal(body1.graphs[i].clone(), body2.graphs[i].clone()) {
4739            return false;
4740        }
4741    }
4742    body1.main_graph.clone().map(|g| g.upgrade().get_id())
4743        == body2.main_graph.clone().map(|g| g.upgrade().get_id())
4744}
4745
4746// Pass the node name of `in_node` to `out_node` if it is present.
4747pub(crate) fn copy_node_name(in_node: Node, out_node: Node) -> Result<()> {
4748    if let Some(node_name) = in_node.get_name()? {
4749        out_node.set_name(&node_name)?;
4750    }
4751    Ok(())
4752}
4753
/// Non-owning pointer to a context body; it does not keep the body alive.
type WeakContextBodyPointer = Weak<AtomicRefCell<ContextBody>>;

/// Weak (non-owning) handle to a [Context].
///
/// Convert it back to a strong handle with `WeakContext::upgrade`.
pub(super) struct WeakContext {
    body: WeakContextBodyPointer,
}
4759
4760impl WeakContext {
4761    //upgrade function panics if the the Context pointer it downgraded from went out of scope
4762    pub(super) fn upgrade(&self) -> Context {
4763        Context {
4764            body: self.body.upgrade().unwrap(),
4765        }
4766    }
4767}
4768
#[doc(hidden)]
#[cfg(feature = "py-binding")]
#[pyo3::pymethods]
impl PyBindingSliceElement {
    /// Wraps `SliceElement::SingleIndex(ind)`.
    #[staticmethod]
    pub fn from_single_element(ind: i64) -> Self {
        let inner = SliceElement::SingleIndex(ind);
        Self { inner }
    }

    /// Wraps `SliceElement::SubArray(start, end, step)`; every bound is optional.
    #[staticmethod]
    pub fn from_sub_array(start: Option<i64>, end: Option<i64>, step: Option<i64>) -> Self {
        let inner = SliceElement::SubArray(start, end, step);
        Self { inner }
    }

    /// Wraps `SliceElement::Ellipsis`.
    #[staticmethod]
    pub fn from_ellipsis() -> Self {
        let inner = SliceElement::Ellipsis;
        Self { inner }
    }
}
4792
4793pub mod util {
4794    use super::{create_context, Context, Graph, Node};
4795    use crate::errors::Result;
4796
4797    /// Creates a computation context with a single graph within it.
4798    ///
4799    /// The graph is passed to the provided `build_graph_fn` function to
4800    /// specify the computation. The node returned by the function is marked
4801    /// as output.
4802    ///
4803    /// # Returns
4804    ///
4805    /// New computation context
4806    ///
4807    /// # Example
4808    ///
4809    /// ```
4810    /// # use ciphercore_base::graphs::util::simple_context;
4811    /// # use ciphercore_base::data_types::{scalar_type, INT32};
4812    /// let c = simple_context(|g| {
4813    ///     let a = g.input(scalar_type(INT32))?;
4814    ///     let b = g.input(scalar_type(INT32))?;
4815    ///     g.add(a, b)
4816    /// }).unwrap();
4817    /// ```
4818    pub fn simple_context<F>(build_graph_fn: F) -> Result<Context>
4819    where
4820        F: FnOnce(&Graph) -> Result<Node>,
4821    {
4822        let c = create_context()?;
4823        let g = c.create_graph()?;
4824        let out = build_graph_fn(&g)?;
4825        out.set_as_output()?;
4826        g.finalize()?;
4827        g.set_as_main()?;
4828        c.finalize()?;
4829        Ok(c)
4830    }
4831}
4832
4833#[cfg(test)]
4834mod tests {
4835    use super::*;
4836    use crate::data_types::{
4837        array_type, scalar_type, tuple_type, vector_type, BIT, UINT16, UINT64,
4838    };
4839    use crate::inline::inline_ops::InlineConfig;
4840    use crate::mpc::mpc_compiler::{prepare_for_mpc_evaluation, IOStatus};
4841    use crate::version::DATA_VERSION;
4842    use std::panic;
4843    use std::rc::Rc;
4844
    // Smoke test: builds one node for (almost) every graph operation and checks
    // that each builder call returns Ok.
    // NOTE(review): runs in an unchecked context (`create_unchecked_context`). Some calls
    // (e.g. `dot`/`matmul` on scalars, `array_to_vector` on a scalar) presumably succeed
    // only because type checking is not performed eagerly there — confirm.
    #[test]
    fn test_wellformed_cases() {
        let context = create_unchecked_context().unwrap();
        let graph = context.create_graph().unwrap();
        let input1 = graph.input(scalar_type(BIT)).unwrap();
        let input2 = graph.input(scalar_type(BIT)).unwrap();
        // Binary arithmetic on two scalar bit inputs.
        graph.add(input1.clone(), input2.clone()).unwrap();
        graph.subtract(input1.clone(), input2.clone()).unwrap();
        graph.multiply(input1.clone(), input2.clone()).unwrap();
        graph.dot(input1.clone(), input2.clone()).unwrap();
        graph.matmul(input1.clone(), input2.clone()).unwrap();
        graph.truncate(input1.clone(), 123).unwrap();
        // Shape-manipulating operations on a 3-D array input.
        let input3 = graph.input(array_type(vec![10, 20, 30], BIT)).unwrap();
        graph.sum(input3.clone(), vec![1, 2]).unwrap();
        graph.permute_axes(input3.clone(), vec![1, 2, 0]).unwrap();
        graph.get(input3.clone(), vec![1, 2]).unwrap();
        graph
            .reshape(input3.clone(), array_type(vec![20, 300], BIT))
            .unwrap();
        graph.nop(input3.clone()).unwrap();
        // Randomness: a random 128-bit key feeding a PRF.
        let key = graph.random(array_type(vec![128], BIT)).unwrap();
        graph
            .prf(key.clone(), 0, array_type(vec![10, 10], UINT64))
            .unwrap();
        graph
            .stack(vec![input1.clone(), input2.clone()], vec![2, 1])
            .unwrap();
        let c = graph
            .constant(scalar_type(BIT), Value::from_bytes(vec![1]))
            .unwrap();
        // Arithmetic <-> binary conversions.
        let input4 = graph.input(array_type(vec![10, 10], UINT64)).unwrap();
        let bits = graph.a2b(input4.clone()).unwrap();
        graph.b2a(bits.clone(), UINT64).unwrap();
        // Tuple / vector / named-tuple construction and access.
        let t = graph
            .create_tuple(vec![input1.clone(), input2.clone()])
            .unwrap();
        let _v = graph
            .create_vector(scalar_type(BIT), vec![input1.clone(), input2.clone()])
            .unwrap();
        let nt = graph
            .create_named_tuple(vec![
                ("Name".to_owned(), input1.clone()),
                ("Gender".to_owned(), input2.clone()),
            ])
            .unwrap();
        graph.tuple_get(t, 1).unwrap();
        graph.named_tuple_get(nt, "Gender".to_owned()).unwrap();
        let v = graph.repeat(c.clone(), 100).unwrap();
        graph.zip(vec![v.clone(), v.clone(), v.clone()]).unwrap();
        // Index 0 as an 8-byte UINT64 constant, used for `vector_get`.
        let zero = graph
            .constant(scalar_type(UINT64), Value::from_bytes(vec![0; 8]))
            .unwrap();
        graph.vector_get(v, zero).unwrap();
        graph.array_to_vector(input1.clone()).unwrap();
        graph.vector_to_array(input1.clone()).unwrap();
    }
4901
    // Builds a 32-bit ripple-carry adder out of a one-bit adder using `iterate`,
    // then a three-operand adder that `call`s it twice, and checks that the whole
    // context finalizes.
    #[test]
    fn call_iterate_test() {
        let context = create_unchecked_context().unwrap();
        // One-bit full adder: state = carry, input = (a, b), output = (new_carry, sum).
        let single_bit_adder = context.create_graph().unwrap();
        {
            let carry = single_bit_adder.input(scalar_type(BIT)).unwrap();
            let inputs = single_bit_adder
                .input(tuple_type(vec![scalar_type(BIT), scalar_type(BIT)]))
                .unwrap();
            let a = single_bit_adder.tuple_get(inputs.clone(), 0).unwrap();
            let b = single_bit_adder.tuple_get(inputs.clone(), 1).unwrap();
            let ac = single_bit_adder.add(carry.clone(), a.clone()).unwrap();
            let bc = single_bit_adder.add(carry.clone(), b.clone()).unwrap();
            // sum = carry + a + b
            let result = single_bit_adder.add(ac.clone(), b.clone()).unwrap();
            // new_carry = (carry + a) * (carry + b) + carry; assuming mod-2 bit arithmetic
            // this equals majority(a, b, carry).
            let result_carry = single_bit_adder
                .add(
                    single_bit_adder.multiply(ac.clone(), bc.clone()).unwrap(),
                    carry,
                )
                .unwrap();
            // (state, output) pair as returned from an `iterate` body.
            let output = single_bit_adder
                .create_tuple(vec![result_carry.clone(), result.clone()])
                .unwrap();
            single_bit_adder.set_output_node(output).unwrap();
            single_bit_adder.finalize().unwrap();
        }
        // 32-bit adder: iterates the one-bit adder over the zipped bit vectors with
        // initial carry 0 and keeps component 1 (the per-step sum bits).
        let v32 = vector_type(32, scalar_type(BIT));
        let adder = context.create_graph().unwrap();
        {
            let a = adder.input(v32.clone()).unwrap();
            let b = adder.input(v32.clone()).unwrap();
            let azb = adder.zip(vec![a, b]).unwrap();
            let c = adder
                .constant(scalar_type(BIT), Value::from_bytes(vec![0]))
                .unwrap();
            let cr = adder.iterate(single_bit_adder, c, azb).unwrap();
            let r = adder.tuple_get(cr, 1).unwrap();
            adder.set_output_node(r).unwrap();
            adder.finalize().unwrap();
        }
        // Main graph: (a + b) + c via two nested calls to the 32-bit adder.
        let three_adder = context.create_graph().unwrap();
        let a = three_adder.input(v32.clone()).unwrap();
        let b = three_adder.input(v32.clone()).unwrap();
        let c = three_adder.input(v32.clone()).unwrap();
        let result = three_adder
            .call(
                adder.clone(),
                vec![three_adder.call(adder.clone(), vec![a, b]).unwrap(), c],
            )
            .unwrap();
        three_adder.set_output_node(result).unwrap();
        three_adder.finalize().unwrap();
        context.set_main_graph(three_adder).unwrap();
        context.finalize().unwrap();
    }
4957
    // Checks that graph builders reject nodes from other graphs, forged nodes,
    // modifications after finalization, finalizing without an output, and foreign
    // output nodes. Statement order matters: node ids depend on creation order.
    #[test]
    fn test_malformed_graphs() {
        let context = create_unchecked_context().unwrap();
        let graph = context.create_graph().unwrap();
        let graph2 = context.create_graph().unwrap();
        let input1 = graph.input(scalar_type(UINT64)).unwrap();
        let input2 = graph2.input(scalar_type(UINT64)).unwrap();
        // Operands belong to different graphs.
        let e1 = graph.add(input1.clone(), input2.clone());
        assert!(e1.is_err());
        // Forged node claiming id 0 in `graph`; it is not the node actually stored there.
        let fake_node = Node {
            body: Arc::new(AtomicRefCell::new(NodeBody {
                graph: graph.downgrade(),
                node_dependencies: vec![],
                graph_dependencies: vec![],
                operation: Operation::Input(scalar_type(BIT)),
                id: 0,
            })),
        };
        let e2 = graph.add(fake_node.clone(), input1.clone());
        assert!(e2.is_err());
        // Forged node with an id that does not exist in `graph`.
        let fake_node_2 = Node {
            body: Arc::new(AtomicRefCell::new(NodeBody {
                graph: graph.downgrade(),
                node_dependencies: vec![],
                graph_dependencies: vec![],
                operation: Operation::Input(scalar_type(BIT)),
                id: 31337,
            })),
        };
        let e3 = graph.add(fake_node_2.clone(), input1.clone());
        assert!(e3.is_err());
        graph.set_output_node(input1.clone()).unwrap();
        graph.finalize().unwrap();
        // A finalized graph cannot receive new nodes.
        let e4 = graph.add(input1.clone(), input1.clone());
        assert!(e4.is_err());
        // A graph without an output node cannot be finalized.
        let graph3 = context.create_graph().unwrap();
        let e5 = graph3.finalize();
        assert!(e5.is_err());
        // The output node must belong to the graph itself.
        let e6 = graph3.set_output_node(input1);
        assert!(e6.is_err());
    }
4999
    // Checks context-level preconditions: a context needs a finalized main graph
    // before it can be finalized, and only finalized graphs can be main.
    #[test]
    fn test_malformed_contexts() {
        let context = create_unchecked_context().unwrap();
        // Empty context (no graphs, no main graph).
        let e1 = context.finalize();
        assert!(e1.is_err());
        let graph = context.create_graph().unwrap();
        // Graph without an output node.
        let e2 = graph.finalize();
        assert!(e2.is_err());
        graph
            .set_output_node(graph.create_tuple(vec![]).unwrap())
            .unwrap();
        // Graph has an output but is not finalized yet.
        let e4 = context.set_main_graph(graph.clone());
        assert!(e4.is_err());
        graph.finalize().unwrap();
        // All graphs finalized, but no main graph set.
        let e3 = context.finalize();
        assert!(e3.is_err());
        context.set_main_graph(graph.clone()).unwrap();
        context.finalize().unwrap();
    }
5019
    // Checks that `call` and `iterate` reject callee graphs that are not finalized
    // or that belong to a different context.
    #[test]
    fn test_malformed_call_iterate() {
        let context1 = create_unchecked_context().unwrap();
        let graph1 = context1.create_graph().unwrap();
        let output = graph1.create_tuple(vec![]).unwrap();
        graph1.set_output_node(output).unwrap();
        let graph2 = context1.create_graph().unwrap();
        // Callee not finalized yet.
        let e1 = graph2.call(graph1.clone(), vec![]);
        assert!(e1.is_err());
        graph1.finalize().unwrap();
        graph2.call(graph1.clone(), vec![]).unwrap();
        // Callee belongs to another context.
        let context2 = create_unchecked_context().unwrap();
        let graph3 = context2.create_graph().unwrap();
        let e2 = graph3.call(graph1.clone(), vec![]);
        assert!(e2.is_err());
        // Iteration body: two inputs (state, element), output (new_state, output).
        let graph4 = context1.create_graph().unwrap();
        graph4.input(tuple_type(vec![])).unwrap();
        graph4.input(tuple_type(vec![])).unwrap();
        let t = graph4.create_tuple(vec![]).unwrap();
        let tt = graph4.create_tuple(vec![t.clone(), t.clone()]).unwrap();
        graph4.set_output_node(tt).unwrap();
        let graph5 = context1.create_graph().unwrap();
        let es = graph5.create_tuple(vec![]).unwrap();
        let v = graph5
            .repeat(graph5.create_tuple(vec![]).unwrap(), 10)
            .unwrap();
        // Iteration body not finalized yet.
        let e3 = graph5.iterate(graph4.clone(), es.clone(), v.clone());
        assert!(e3.is_err());
        graph4.finalize().unwrap();
        graph5
            .iterate(graph4.clone(), es.clone(), v.clone())
            .unwrap();
        // Iteration body belongs to another context.
        let graph6 = context2.create_graph().unwrap();
        let es = graph6.create_tuple(vec![]).unwrap();
        let v = graph6
            .repeat(graph6.create_tuple(vec![]).unwrap(), 10)
            .unwrap();
        let e4 = graph6.iterate(graph4.clone(), es.clone(), v.clone());
        assert!(e4.is_err());
    }
5060
5061    #[test]
5062    fn test_graph_consistency() {
5063        let context = create_unchecked_context().unwrap();
5064        let graph = context.create_graph().unwrap();
5065        let input1 = graph.input(scalar_type(BIT)).unwrap();
5066        let input2 = graph.input(scalar_type(BIT)).unwrap();
5067        graph.add(input1.clone(), input2.clone()).unwrap();
5068        graph.set_output_node(input1.clone()).unwrap();
5069        graph.finalize().unwrap();
5070        for (i, node) in graph.get_nodes().iter().enumerate() {
5071            assert_eq!(node.get_id(), i as u64);
5072            assert!(graph == node.get_graph());
5073            for dependency in node.get_node_dependencies() {
5074                assert!(dependency.get_id() < node.get_id());
5075            }
5076        }
5077        let operations: Vec<Operation> = graph
5078            .get_nodes()
5079            .iter()
5080            .map(|x| x.get_operation())
5081            .collect();
5082        assert!(operations.len() == 3);
5083        if !operations[0].is_input() {
5084            panic!("Input expected");
5085        }
5086        if !operations[1].is_input() {
5087            panic!("Input expected");
5088        }
5089        match operations[2] {
5090            Operation::Add => {}
5091            _ => {
5092                panic!("Add expected");
5093            }
5094        }
5095    }
5096
    // Checks that a context cannot be finalized while any of its graphs is still
    // unfinalized, even when a finalized main graph has been set.
    #[test]
    fn test_unfinalized_graphs() {
        let context = create_unchecked_context().unwrap();
        let e = context.finalize();
        assert!(e.is_err());
        let graph = context.create_graph().unwrap();
        let graph2 = context.create_graph().unwrap();
        let e = context.finalize();
        assert!(e.is_err());
        let i = graph2.input(scalar_type(BIT)).unwrap();
        graph2.set_output_node(i).unwrap();
        graph2.finalize().unwrap();
        context.set_main_graph(graph2).unwrap();
        // Main graph is finalized, but `graph` is not.
        let e = context.finalize();
        assert!(e.is_err());
        let ii = graph.input(scalar_type(BIT)).unwrap();
        graph.set_output_node(ii).unwrap();
        graph.finalize().unwrap();
        context.finalize().unwrap();
    }
5117
5118    #[test]
5119    fn test_operation_serialization() {
5120        let o = Operation::Constant(scalar_type(BIT), Value::from_bytes(vec![1]));
5121        let se = serde_json::to_string(&o).unwrap();
5122        assert_eq!(
5123            se,
5124            format!("{{\"Constant\":[{{\"Scalar\":\"bit\"}},{{\"version\":{},\"data\":\"{{\\\"body\\\":{{\\\"Bytes\\\":[1]}}}}\"}}]}}", DATA_VERSION)
5125        );
5126        let de = serde_json::from_str::<Operation>(&se).unwrap();
5127        assert_eq!(de, o);
5128    }
5129
    // Returns factories for pairwise-distinct contexts. Each factory builds a fresh
    // context on every call, so calling it twice yields two deep-equal but distinct
    // contexts.
    // ORDER MATTERS: `test_context_serialize` compares the serializations of
    // entries 0, 9 and 13 against lines of the golden file, so do not reorder or
    // insert generators before those positions.
    fn context_generators() -> Vec<Box<dyn Fn() -> Context>> {
        // [0] Empty context.
        let context1 = || {
            let context = create_unchecked_context().unwrap();
            context
        };
        // [1] One graph returning its single input, set as main; context finalized.
        let context2 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph).unwrap();
            context.finalize().unwrap();
            context
        };
        // [2] One empty graph.
        let context3 = || {
            let context = create_unchecked_context().unwrap();
            context.create_graph().unwrap();
            context
        };
        // [3] Like [1], but no main graph and the context is not finalized.
        let context4 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context
        };
        // [4] One unfinalized graph with a single input node.
        let context5 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            graph.input(scalar_type(BIT)).unwrap();
            context
        };
        // [5] One unfinalized graph with a single constant node.
        let context6 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            graph
                .constant(scalar_type(BIT), Value::from_bytes(vec![1]))
                .unwrap();
            context
        };
        // [6] i1 + i2.
        let context7 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i1 = graph.input(scalar_type(BIT)).unwrap();
            let i2 = graph.input(scalar_type(BIT)).unwrap();
            graph.add(i1, i2).unwrap();
            context
        };
        // [7] i2 + i1: differs from [6] only in operand order.
        let context8 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i1 = graph.input(scalar_type(BIT)).unwrap();
            let i2 = graph.input(scalar_type(BIT)).unwrap();
            graph.add(i2, i1).unwrap();
            context
        };
        // [8] Two identical identity graphs; a third graph calls graph1.
        let context9 = || {
            let context = create_unchecked_context().unwrap();
            let graph1 = context.create_graph().unwrap();
            let i1 = graph1.input(scalar_type(BIT)).unwrap();
            graph1.set_output_node(i1).unwrap();
            graph1.finalize().unwrap();
            let graph2 = context.create_graph().unwrap();
            let i2 = graph2.input(scalar_type(BIT)).unwrap();
            graph2.set_output_node(i2).unwrap();
            graph2.finalize().unwrap();
            let graph3 = context.create_graph().unwrap();
            let i = graph3.input(scalar_type(BIT)).unwrap();
            graph3.call(graph1, vec![i]).unwrap();
            context
        };
        // [9] Like [8], but the call targets graph2 (differs only in graph dependency).
        let context10 = || {
            let context = create_unchecked_context().unwrap();
            let graph1 = context.create_graph().unwrap();
            let i1 = graph1.input(scalar_type(BIT)).unwrap();
            graph1.set_output_node(i1).unwrap();
            graph1.finalize().unwrap();
            let graph2 = context.create_graph().unwrap();
            let i2 = graph2.input(scalar_type(BIT)).unwrap();
            graph2.set_output_node(i2).unwrap();
            graph2.finalize().unwrap();
            let graph3 = context.create_graph().unwrap();
            let i = graph3.input(scalar_type(BIT)).unwrap();
            graph3.call(graph2, vec![i]).unwrap();
            context
        };
        // [10] Like [9], plus the call result is set as graph3's output node.
        let context11 = || {
            let context = create_unchecked_context().unwrap();
            let graph1 = context.create_graph().unwrap();
            let i1 = graph1.input(scalar_type(BIT)).unwrap();
            graph1.set_output_node(i1).unwrap();
            graph1.finalize().unwrap();
            let graph2 = context.create_graph().unwrap();
            let i2 = graph2.input(scalar_type(BIT)).unwrap();
            graph2.set_output_node(i2).unwrap();
            graph2.finalize().unwrap();
            let graph3 = context.create_graph().unwrap();
            let i = graph3.input(scalar_type(BIT)).unwrap();
            let o = graph3.call(graph2, vec![i]).unwrap();
            graph3.set_output_node(o).unwrap();
            context
        };
        // [11] Like [1], but the context is not finalized.
        let context12 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph).unwrap();
            context
        };
        // [12] Like [1], plus the main graph is named "main".
        let context13 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph.clone()).unwrap();
            context.set_graph_name(graph, "main").unwrap();
            context.finalize().unwrap();
            context
        };
        // [13] Like [12], plus a node name and graph/node annotations.
        let context14 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let i = graph.input(scalar_type(BIT)).unwrap();
            graph.set_output_node(i.clone()).unwrap();
            graph.finalize().unwrap();
            context.set_main_graph(graph.clone()).unwrap();
            context.set_graph_name(graph.clone(), "main").unwrap();
            context.set_node_name(i.clone(), "input").unwrap();
            context
                .add_graph_annotation(&graph, GraphAnnotation::AssociativeOperation)
                .unwrap();
            context
                .add_node_annotation(&i, NodeAnnotation::AssociativeOperation)
                .unwrap();
            context.finalize().unwrap();
            context
        };
        // [14] Sum of 20 inputs; inputs 1..=19 are named "input_<i>" (the first is unnamed).
        let context15 = || {
            let context = create_unchecked_context().unwrap();
            let graph = context.create_graph().unwrap();
            let mut x = graph.input(scalar_type(BIT)).unwrap();
            for i in 1..20 {
                let y = graph.input(scalar_type(BIT)).unwrap();
                y.set_name(format!("input_{}", i).as_str()).unwrap();
                x = graph.add(x, y).unwrap();
            }
            graph.set_output_node(x).unwrap();
            graph.finalize().unwrap();
            context
        };
        let mut closures: Vec<Box<dyn Fn() -> Context>> = vec![];
        closures.push(Box::new(context1));
        closures.push(Box::new(context2));
        closures.push(Box::new(context3));
        closures.push(Box::new(context4));
        closures.push(Box::new(context5));
        closures.push(Box::new(context6));
        closures.push(Box::new(context7));
        closures.push(Box::new(context8));
        closures.push(Box::new(context9));
        closures.push(Box::new(context10));
        closures.push(Box::new(context11));
        closures.push(Box::new(context12));
        closures.push(Box::new(context13));
        closures.push(Box::new(context14));
        closures.push(Box::new(context15));
        closures
    }
5303
5304    fn test_context_deep_equal_helper_equal<F>(f: F)
5305    where
5306        F: Fn() -> Context,
5307    {
5308        let context1 = f();
5309        let context2 = f();
5310        assert!(context1 != context2);
5311        assert!(contexts_deep_equal(context1, context2));
5312    }
5313
5314    fn test_context_deep_equal_helper_nonequal<F1, F2>(f1: F1, f2: F2)
5315    where
5316        F1: Fn() -> Context,
5317        F2: Fn() -> Context,
5318    {
5319        let context1 = f1();
5320        let context2 = f2();
5321        assert!(context1 != context2);
5322        assert!(!contexts_deep_equal(context1, context2));
5323    }
5324
5325    #[test]
5326    fn test_context_deep_equal() {
5327        let generators = context_generators();
5328        for i in 0..generators.len() {
5329            test_context_deep_equal_helper_equal(&generators[i]);
5330            for j in 0..i {
5331                test_context_deep_equal_helper_nonequal(&generators[i], &generators[j]);
5332            }
5333        }
5334    }
5335
5336    pub fn deserialize_error_lenient(serialized_string: &str, error_msg: &str) {
5337        use std::panic::catch_unwind;
5338        panic::set_hook(Box::new(|_info| {
5339            // See: https://stackoverflow.com/questions/35559267/suppress-panic-output-in-rust-when-using-paniccatch-unwind
5340        }));
5341        let result = catch_unwind(|| serde_json::from_str::<Context>(serialized_string).unwrap());
5342        // This is a (nasty) hack.
5343        // We check whether the returned error contain the expected error message.
5344        use ciphercore_utils::execute_main::extract_panic_message;
5345        if let Err(e) = result {
5346            match extract_panic_message(e) {
5347                Some(msg) => {
5348                    if !msg.contains(error_msg) {
5349                        panic!("Undesirable panic: {}", msg);
5350                    }
5351                }
5352                None => panic!("Panic of unknown type"),
5353            }
5354        } else {
5355            panic!("Expected error not occur")
5356        }
5357    }
5358
5359    use std::{
5360        fs::File,
5361        io::{prelude::*, BufReader},
5362        path::Path,
5363    };
5364
5365    fn lines_from_file(filename: impl AsRef<Path>) -> Vec<String> {
5366        let file = File::open(filename).expect("no such file");
5367        let buf = BufReader::new(file);
5368        buf.lines()
5369            .map(|l| l.expect("Could not parse line"))
5370            .collect()
5371    }
5372
    // Round-trips every generated context through JSON and checks the result against
    // the golden file `src/test_data/version_testcase.txt`. The golden file holds
    // expected serializations of selected contexts and malformed or old-version
    // payloads that must fail to deserialize with specific messages.
    #[test]
    fn test_context_serialize() {
        let generators = context_generators();
        let contexts: Vec<Context> = generators.iter().map(|generator| generator()).collect();
        let serialized_contexts: Vec<String> = contexts
            .iter()
            .map(|context| serde_json::to_string(context).unwrap())
            .collect();
        let deserialized_contexts: Vec<Context> = serialized_contexts
            .iter()
            .map(|serialized_context| serde_json::from_str(serialized_context).unwrap())
            .collect();
        assert_eq!(contexts.len(), deserialized_contexts.len());
        for i in 0..contexts.len() {
            // Deserialization yields a new object that is deep-equal to the original
            // and re-serializes to the same JSON.
            assert!(contexts[i] != deserialized_contexts[i]);
            assert!(contexts_deep_equal(
                contexts[i].clone(),
                deserialized_contexts[i].clone()
            ));
            assert_eq!(
                serialized_contexts[i],
                serde_json::to_string(&deserialized_contexts[i]).unwrap()
            )
        }

        // Read test cases from the golden file (one JSON document per line).
        let test_case = lines_from_file("./src/test_data/version_testcase.txt");
        // Line 0: expected serialization of contexts[0].
        assert_eq!(serde_json::to_string(&contexts[0]).unwrap(), test_case[0]);

        // Expects "Non-existent main graph", caused by the field "\"main_graph\":918276318".
        deserialize_error_lenient(&test_case[1], "Non-existent main graph");
        // Line 2: expected serialization of contexts[9].
        assert_eq!(serde_json::to_string(&contexts[9]).unwrap(), test_case[2]);
        // Expects "Non-existent node dependency", caused by the field "\"node_dependencies\":[918723]".
        deserialize_error_lenient(&test_case[3], "Non-existent node dependency");
        // Expects "Non-existent graph dependency", caused by the field "\"graph_dependencies\":[918723]".
        deserialize_error_lenient(&test_case[4], "Non-existent graph dependency");
        // Line 5: expected serialization of contexts[13].
        assert_eq!(serde_json::to_string(&contexts[13]).unwrap(), test_case[5]);
        // Expects "Non-existent output node", caused by the field "\"output_node\":9817273".
        deserialize_error_lenient(&test_case[6], "Non-existent output node");
        // Expects "graphs_names contain an invalid ID", caused by the field "\"graphs_names\":[[8079123,\"main\"]]".
        deserialize_error_lenient(&test_case[7], "graphs_names contain an invalid ID");
        // Expects "nodes_names contain an invalid graph ID", caused by the field "\"nodes_names\":[[[8079123,0],\"input\"]]".
        deserialize_error_lenient(&test_case[8], "nodes_names contain an invalid graph ID");
        // Expects "nodes_names contain an invalid node ID", caused by the field "\"nodes_names\":[[[0,8079123],\"input\"]]".
        deserialize_error_lenient(&test_case[9], "nodes_names contain an invalid node ID");
        // Expects "Context version doesn't match the requirement", caused by an old version number.
        deserialize_error_lenient(
            &test_case[10],
            "Context version doesn't match the requirement",
        );
        // Expects "Context version doesn't match the requirement", caused by an old version number.
        // The payload itself is unsupported, but the version check must run first and fail
        // before the payload is interpreted.
        deserialize_error_lenient(
            &test_case[11],
            "Context version doesn't match the requirement",
        );
    }
5429
5430    use crate::data_types::INT32;
5431    use crate::data_values::Value;
5432    use crate::evaluators::random_evaluate;
5433    use std::iter::FromIterator;
5434
    /// Exercises naming of graphs and nodes within a context.
    ///
    /// Covers:
    /// - assigning names and looking them up in both directions;
    /// - feeding named inputs into an evaluation;
    /// - rejecting name operations on graphs/nodes from another context;
    /// - rejecting renames after finalization;
    /// - rejecting renames and duplicate names.
    #[test]
    fn test_named_contexts() {
        // Builds a finalized context whose main graph computes `a + b` over two INT32 inputs,
        // checking name lookups both before and after the names are assigned.
        let helper = || -> Result<Context> {
            let context = create_context()?;
            let graph = context.create_graph()?;
            let input_a = graph.input(scalar_type(INT32))?;
            let input_b = graph.input(scalar_type(INT32))?;
            let output = graph.add(input_a.clone(), input_b.clone())?;
            graph.set_output_node(output.clone())?;
            graph.finalize()?;
            context.set_main_graph(graph.clone())?;
            // Nothing is named yet: graph-name queries and name-based retrievals fail,
            // while `get_node_name` reports "no name" as `Ok(None)` rather than an error.
            assert!(context.get_graph_name(graph.clone()).is_err());
            assert!(context.retrieve_graph("main").is_err());
            assert!(context.get_node_name(input_a.clone())?.is_none());
            assert!(context.retrieve_node(graph.clone(), "a").is_err());
            context.set_graph_name(graph.clone(), "main")?;
            context.set_node_name(input_a.clone(), "a")?;
            // "b" has not been assigned yet, so retrieving it must still fail.
            assert!(context.retrieve_node(graph.clone(), "b").is_err());
            context.set_node_name(input_b.clone(), "b")?;
            context.finalize()?;
            // After finalization every assigned name resolves in both directions.
            assert_eq!(context.get_graph_name(graph.clone())?, "main");
            assert_eq!(
                context.get_node_name(input_a.clone())?,
                Some("a".to_owned())
            );
            assert_eq!(
                context.get_node_name(input_b.clone())?,
                Some("b".to_owned())
            );
            assert!(context.retrieve_node(graph.clone(), "a")? == input_a.clone());
            Ok(context)
        };
        let context = helper().unwrap();
        // First checks `prepare_input_values` error cases on an unrelated context.
        // Then evaluates the "main" graph of `context` with a = 123, b = 456 and returns the result.
        let helper2 = |context: Context| -> Result<i32> {
            let other_context = create_context()?;
            let other_graph = other_context.create_graph()?;
            let input = other_graph.input(scalar_type(BIT))?;
            let other_input = other_graph.input(scalar_type(BIT))?;
            // `other_graph` was created by `other_context`, so `context` rejects it.
            assert!(context
                .prepare_input_values::<Value>(other_graph.clone(), HashMap::new())
                .is_err());
            // No input of `other_graph` is named "a" (none is named at all yet).
            assert!(other_context
                .prepare_input_values::<Value>(
                    other_graph.clone(),
                    HashMap::from_iter([("a", Value::from_scalar(123, INT32)?)])
                )
                .is_err());
            other_context.set_node_name(input, "b")?;
            // Still no input named "a".
            assert!(other_context
                .prepare_input_values::<Value>(
                    other_graph.clone(),
                    HashMap::from_iter([("a", Value::from_scalar(123, INT32)?)])
                )
                .is_err());
            // Supplying "b" is still rejected. NOTE(review): presumably because `other_input`
            // is unnamed and/or the INT32 value does not match the BIT input. TODO confirm.
            assert!(other_context
                .prepare_input_values::<Value>(
                    other_graph.clone(),
                    HashMap::from_iter([("b", Value::from_scalar(123, INT32)?)])
                )
                .is_err());
            other_context.set_node_name(other_input, "c")?;
            // Both inputs are now named, but no value is given for "c".
            assert!(other_context
                .prepare_input_values::<Value>(
                    other_graph,
                    HashMap::from_iter([("b", Value::from_scalar(123, INT32)?)])
                )
                .is_err());
            let g = context.retrieve_graph("main")?;
            let result = random_evaluate(
                g.clone(),
                context.prepare_input_values(
                    g.clone(),
                    HashMap::from_iter([
                        ("a", Value::from_scalar(123, INT32)?),
                        ("b", Value::from_scalar(456, INT32)?),
                    ]),
                )?,
            )?;
            let result = result.to_i32(INT32)?;
            Ok(result)
        };
        // 123 + 456
        assert_eq!(helper2(context).unwrap(), 579);
        // Every name operation on `context` involving a graph or node that belongs to
        // a different context must fail.
        let helper3 = |context: Context| -> Result<()> {
            let other_context = create_context()?;
            let other_graph = other_context.create_graph()?;
            let other_node = other_graph.input(scalar_type(BIT))?;
            assert!(context
                .set_graph_name(other_graph.clone(), "outside")
                .is_err());
            assert!(context.get_graph_name(other_graph.clone()).is_err());
            assert!(context
                .set_node_name(other_node.clone(), "outside")
                .is_err());
            assert!(context.get_node_name(other_node.clone()).is_err());
            assert!(context.retrieve_node(other_graph.clone(), "a").is_err());
            Ok(())
        };
        helper3(helper().unwrap()).unwrap();
        // Once the context is finalized, neither graphs nor nodes can be named.
        let helper4 = || -> Result<()> {
            let context = create_context()?;
            let graph = context.create_graph()?;
            let input = graph.input(scalar_type(BIT))?;
            graph.set_output_node(input.clone())?;
            graph.finalize()?;
            context.set_main_graph(graph.clone())?;
            context.finalize()?;
            assert!(context.set_graph_name(graph, "main").is_err());
            assert!(context.set_node_name(input, "input").is_err());
            Ok(())
        };
        helper4().unwrap();
        // On a non-finalized context:
        // - an already-named graph or node cannot be renamed;
        // - a name already taken cannot be assigned to a second graph or a second node.
        let helper5 = || -> Result<()> {
            let context = create_context()?;
            let graph = context.create_graph()?;
            let input = graph.input(scalar_type(BIT))?;
            let other_graph = context.create_graph()?;
            // NOTE(review): despite its name, `other_input` is created in `graph`, not
            // `other_graph`, so the last assertion checks a duplicate node name within the
            // same graph.
            let other_input = graph.input(scalar_type(BIT))?;
            context.set_graph_name(graph.clone(), "main")?;
            assert!(context.set_graph_name(graph, "main3").is_err());
            assert!(context.set_graph_name(other_graph, "main").is_err());
            context.set_node_name(input.clone(), "input")?;
            assert!(context.set_node_name(input, "input3").is_err());
            assert!(context.set_node_name(other_input, "input").is_err());
            Ok(())
        };
        helper5().unwrap();
    }
5562
5563    #[test]
5564    fn test_context_type_checking() {
5565        || -> Result<()> {
5566            let context = create_context()?;
5567            let g = context.create_graph()?;
5568            let i = g.input(tuple_type(vec![]))?;
5569            assert!(g.add(i.clone(), i.clone()).is_err());
5570            // Now checking that the node actually have not gotten added by accident
5571            assert_eq!(g.get_nodes().len(), 1);
5572            Ok(())
5573        }()
5574        .unwrap();
5575    }
5576
5577    fn generate_pair_of_equal_contexts() -> Vec<(Context, Context)> {
5578        let context1 = || -> Result<Context> {
5579            let context = create_unchecked_context()?;
5580            let g = context.create_graph()?;
5581            let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
5582            g.set_output_node(i)?;
5583            g.finalize()?;
5584            g.set_as_main()?;
5585            Ok(context)
5586        }()
5587        .unwrap();
5588        let context2 = || -> Result<Context> {
5589            let context = create_unchecked_context()?;
5590            let g = context.create_graph()?;
5591            let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
5592            i.set_as_output()?;
5593            g.finalize()?;
5594            context.set_main_graph(g)?;
5595            Ok(context)
5596        }()
5597        .unwrap();
5598        let context3 = || -> Result<Context> {
5599            let context = create_unchecked_context()?;
5600            let g = context.create_graph()?;
5601            context.set_graph_name(g, "random graph name")?;
5602            Ok(context)
5603        }()
5604        .unwrap();
5605        let context4 = || -> Result<Context> {
5606            let context = create_unchecked_context()?;
5607            context.create_graph()?.set_name("random graph name")?;
5608            Ok(context)
5609        }()
5610        .unwrap();
5611        let context5 = || -> Result<Context> {
5612            let context = create_unchecked_context()?;
5613            let g = context.create_graph()?;
5614            let i = g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?;
5615            context.set_node_name(i, "random node name")?;
5616            Ok(context)
5617        }()
5618        .unwrap();
5619        let context6 = || -> Result<Context> {
5620            let context = create_unchecked_context()?;
5621            let g = context.create_graph()?;
5622            g.constant(scalar_type(BIT), Value::from_scalar(0, BIT)?)?
5623                .set_name("random node name")?;
5624            Ok(context)
5625        }()
5626        .unwrap();
5627        let context7 = || -> Result<Context> {
5628            let context = create_unchecked_context()?;
5629            let g = context.create_graph()?;
5630            let i1 = g.input(scalar_type(BIT))?;
5631            let i2 = g.input(scalar_type(BIT))?;
5632            g.add(i1.clone(), i2.clone())?;
5633            g.subtract(i1.clone(), i2.clone())?;
5634            g.multiply(i1.clone(), i2.clone())?;
5635            g.dot(i1.clone(), i2.clone())?;
5636            g.matmul(i1.clone(), i2.clone())?;
5637            g.truncate(i1.clone(), 123)?;
5638            g.sum(i1.clone(), vec![1, 4, 7])?;
5639            g.permute_axes(i1.clone(), vec![1, 4, 7])?;
5640            g.get(i1.clone(), vec![1, 4])?;
5641            g.reshape(i1.clone(), array_type(vec![12, 34], BIT))?;
5642            g.nop(i1.clone())?;
5643            g.prf(i1.clone(), 123, scalar_type(BIT))?;
5644            g.a2b(i1.clone())?;
5645            g.b2a(i1.clone(), BIT)?;
5646            g.tuple_get(i1.clone(), 0)?;
5647            g.named_tuple_get(i1.clone(), "field name".to_owned())?;
5648            g.vector_get(i1.clone(), i2)?;
5649            g.array_to_vector(i1.clone())?;
5650            g.vector_to_array(i1.clone())?;
5651            g.repeat(i1.clone(), 123)?;
5652            Ok(context)
5653        }()
5654        .unwrap();
5655        let context8 = || -> Result<Context> {
5656            let context = create_unchecked_context()?;
5657            let g = context.create_graph()?;
5658            let i1 = g.input(scalar_type(BIT))?;
5659            let i2 = g.input(scalar_type(BIT))?;
5660            i1.add(i2.clone())?;
5661            i1.subtract(i2.clone())?;
5662            i1.multiply(i2.clone())?;
5663            i1.dot(i2.clone())?;
5664            i1.matmul(i2.clone())?;
5665            i1.truncate(123)?;
5666            i1.sum(vec![1, 4, 7])?;
5667            i1.permute_axes(vec![1, 4, 7])?;
5668            i1.get(vec![1, 4])?;
5669            i1.reshape(array_type(vec![12, 34], BIT))?;
5670            i1.nop()?;
5671            i1.prf(123, scalar_type(BIT))?;
5672            i1.a2b()?;
5673            i1.b2a(BIT)?;
5674            i1.tuple_get(0)?;
5675            i1.named_tuple_get("field name".to_owned())?;
5676            i1.vector_get(i2)?;
5677            i1.array_to_vector()?;
5678            i1.vector_to_array()?;
5679            i1.repeat(123)?;
5680            Ok(context)
5681        }()
5682        .unwrap();
5683        let result = vec![
5684            (context1, context2),
5685            (context3, context4),
5686            (context5, context6),
5687            (context7, context8),
5688        ];
5689        result
5690    }
5691
5692    #[test]
5693    fn test_node_graph_helpers() {
5694        let pairs_of_contexts = generate_pair_of_equal_contexts();
5695        for (context1, context2) in pairs_of_contexts {
5696            assert!(contexts_deep_equal(context1, context2));
5697        }
5698        || -> Result<()> {
5699            let context = create_context()?;
5700            let g = context.create_graph()?.set_name("graph name")?;
5701            let i = g.input(scalar_type(BIT))?.set_name("node name")?;
5702            assert_eq!(g.get_name()?, "graph name");
5703            assert!(g.retrieve_node("node name")? == i);
5704            assert_eq!(i.get_name()?, Some("node name".to_owned()));
5705            assert_eq!(
5706                g.prepare_input_values(hashmap!("node name" => Value::from_scalar(1, BIT)?))?,
5707                vec![Value::from_scalar(1, BIT)?]
5708            );
5709            Ok(())
5710        }()
5711        .unwrap();
5712    }
5713
5714    #[test]
5715    fn test_operation_fmt_display() {
5716        let test_operation_fmt_display_helper = || -> Result<()> {
5717            let o0 = Rc::new(Operation::Input(scalar_type(UINT16)));
5718            assert_eq!(format!("{}", o0), "Input");
5719            let o1 = Rc::new(Operation::Add);
5720            assert_eq!(format!("{}", o1), "Add");
5721            let o2 = Rc::new(Operation::Truncate(10));
5722            assert_eq!(format!("{}", o2), "Truncate");
5723            let o3 = Rc::new(Operation::Get(vec![10, 20]));
5724            assert_eq!(format!("{}", o3), "Get");
5725            let o4 = Rc::new(Operation::NOP);
5726            assert_eq!(format!("{}", o4), "NOP");
5727            let o5 = Rc::new(Operation::CreateNamedTuple(vec![
5728                "Name".to_string(),
5729                "Address".to_string(),
5730            ]));
5731            assert_eq!(format!("{}", o5), "CreateNamedTuple");
5732            let o6 = Rc::new(Operation::NamedTupleGet("Name".to_string()));
5733            assert_eq!(format!("{}", o6), "NamedTupleGet");
5734            Ok(())
5735        };
5736        test_operation_fmt_display_helper().unwrap();
5737    }
5738
5739    #[test]
5740    fn test_annotations() {
5741        let test_annotations_helper = || -> Result<()> {
5742            let context = create_context()?;
5743            let g = context.create_graph()?;
5744            let i = g.input(scalar_type(BIT))?;
5745            g.add_annotation(GraphAnnotation::AssociativeOperation)?;
5746            i.add_annotation(NodeAnnotation::AssociativeOperation)?;
5747            assert_eq!(
5748                g.get_annotations()?,
5749                vec![GraphAnnotation::AssociativeOperation]
5750            );
5751            assert_eq!(
5752                i.get_annotations()?,
5753                vec![NodeAnnotation::AssociativeOperation]
5754            );
5755            Ok(())
5756        };
5757        test_annotations_helper().unwrap();
5758    }
5759
5760    async fn parallel_get_type(output: Node) -> Result<Type> {
5761        output.get_type()
5762    }
5763
5764    async fn parallel_random_evaluate(graph: Graph, context: Context) -> Result<Value> {
5765        random_evaluate(
5766            graph.clone(),
5767            context.prepare_input_values(
5768                graph,
5769                HashMap::from_iter([
5770                    ("one", Value::from_scalar(123, INT32)?),
5771                    ("two", Value::from_scalar(456, INT32)?),
5772                ]),
5773            )?,
5774        )
5775    }
5776
5777    async fn parallel_prepare_for_mpc_evaluation(
5778        context: Context,
5779        input_party_map: Vec<Vec<IOStatus>>,
5780        output_parties: Vec<Vec<IOStatus>>,
5781        inline_config: InlineConfig,
5782    ) -> Result<Context> {
5783        prepare_for_mpc_evaluation(context, input_party_map, output_parties, inline_config)
5784    }
5785
5786    #[tokio::test(flavor = "multi_thread", worker_threads = 50)]
5787    async fn test_parallel_after_finalize() -> Result<()> {
5788        let context = create_context()?;
5789        let graph = context.create_graph()?;
5790        let input1 = graph.input(scalar_type(INT32))?;
5791        let input2 = graph.input(scalar_type(INT32))?;
5792        let output = graph.add(input1.clone(), input2.clone())?;
5793        graph.set_output_node(output.clone())?;
5794        graph.finalize()?;
5795        context.set_main_graph(graph.clone())?;
5796
5797        context.set_node_name(input1.clone(), "one")?;
5798        context.set_node_name(input2.clone(), "two")?;
5799
5800        context.finalize()?;
5801
5802        assert!(output.clone().get_type().is_ok());
5803
5804        const PAR_ITERS: usize = 2001;
5805
5806        let mut get_type_futures = vec![];
5807        for _ in 0..PAR_ITERS {
5808            get_type_futures.push(parallel_get_type(output.clone()));
5809        }
5810        futures::future::try_join_all(get_type_futures).await?;
5811
5812        let mut get_random_evaluate_futures = vec![];
5813        for _ in 0..PAR_ITERS {
5814            get_random_evaluate_futures
5815                .push(parallel_random_evaluate(graph.clone(), context.clone()))
5816        }
5817        futures::future::try_join_all(get_random_evaluate_futures).await?;
5818
5819        let input_parties = vec![IOStatus::Party(0), IOStatus::Party(1)];
5820        let output_parties = vec![IOStatus::Party(0)];
5821
5822        let mut get_mpc_eval_futures = vec![];
5823        for _ in 0..PAR_ITERS {
5824            get_mpc_eval_futures.push(parallel_prepare_for_mpc_evaluation(
5825                context.clone(),
5826                vec![input_parties.clone()],
5827                vec![output_parties.clone()],
5828                InlineConfig::default(),
5829            ));
5830        }
5831        futures::future::try_join_all(get_mpc_eval_futures).await?;
5832
5833        Ok(())
5834    }
5835}