ciphercore_base/
custom_ops.rs

1//! Structs and traits necessary to implement custom operations.
2//! A custom operation can be thought of as a polymorphic function, i.e., where the number of inputs and their types can vary.
3//! Two basic examples of custom operations are provided: [Not] and [Or].
4use crate::data_types::{scalar_type, Type, BIT};
5use crate::errors::Result;
6use crate::graphs::{copy_node_name, create_context, Context, Graph, Node, Operation};
7
8use serde::{Deserialize, Serialize};
9
10use petgraph::algo::toposort;
11use petgraph::graph::{DiGraph, NodeIndex};
12
13use std::any::{Any, TypeId};
14use std::collections::{hash_map::DefaultHasher, HashMap};
15use std::fmt::Debug;
16use std::fmt::Write;
17use std::hash::{Hash, Hasher};
18use std::sync::Arc;
19
20#[cfg(feature = "py-binding")]
21use pywrapper_macro::struct_wrapper;
22
#[doc(hidden)]
/// Object-safe helper trait that allows trait objects to be compared for equality and hashed.
///
/// A blanket implementation is provided for every `'static` type implementing [Eq] and [Hash],
/// so custom operation structs get it for free.
/// Based on
/// <https://stackoverflow.com/questions/25339603/how-to-test-for-equality-between-trait-objects>
/// and
/// <https://stackoverflow.com/questions/64838355/how-do-i-create-a-hashmap-with-type-erased-keys>.
pub trait DynEqHash {
    /// Returns `self` as a [`dyn Any`](Any) so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Returns `true` iff `other` has the same concrete type as `self` and compares equal to it.
    fn equals(&self, other: &dyn Any) -> bool;

    /// Returns a 64-bit hash of `self` that also depends on its concrete type.
    fn hash(&self) -> u64;
}

impl<T: 'static + Eq + Hash> DynEqHash for T {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn equals(&self, other: &dyn Any) -> bool {
        // Values of a different concrete type are never equal.
        match other.downcast_ref::<T>() {
            Some(same_type) => self == same_type,
            None => false,
        }
    }

    /// Hashes the pair (type identifier of `T`, value) with the `DefaultHasher` of `hash_map`,
    /// which seems to implement SipHash (<https://github.com/veorq/SipHash/>).
    /// Including the `TypeId` keeps equal-looking values of different types apart.
    fn hash(&self) -> u64 {
        let mut hasher = DefaultHasher::new();
        let keyed_value = (TypeId::of::<T>(), self);
        Hash::hash(&keyed_value, &mut hasher);
        hasher.finish()
    }
}
53
54/// A trait that must be implemented by any custom operation struct.
55///
56/// Only structures satisfying this trait can be used to create [CustomOperation].
57///
58/// Any structure implementing this trait must also implement the following traits:
59/// - [Debug],
60/// - [Serialize],
61/// - [Deserialize],
62/// - [Eq],
63/// - [PartialEq],
64/// - [Hash]
65///
66/// # Example
67/// This is the actual implementation of the custom operation [Not].
68/// ```
69/// use serde::{Deserialize, Serialize};
70/// # use ciphercore_base::data_types::{BIT, scalar_type, Type};
71/// # use ciphercore_base::data_values::Value;
72/// # use ciphercore_base::graphs::{Context, Graph};
73/// # use ciphercore_base::custom_ops::{CustomOperationBody};
74/// # use ciphercore_base::errors::Result;
75/// # use ciphercore_base::runtime_error;
76///
77/// #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
78/// pub struct Not {}
79/// #[typetag::serde] // requires the typetag crate
80/// impl CustomOperationBody for Not {
81///    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
82///        if arguments_types.len() != 1 {
83///           return Err(runtime_error!("Invalid number of arguments for Not"));
84///        }
85///        let g = context.create_graph()?;
86///        g.input(arguments_types[0].clone())?
87///         .add(g.constant(scalar_type(BIT), Value::from_scalar(1, BIT)?)?)?
88///         .set_as_output()?;
89///        g.finalize()?;
90///        Ok(g)
91///    }
92///    fn get_name(&self) -> String {
93///        "Not".to_owned()
94///    }
95/// }
96/// ```
97///
98// This is used to add a concrete type tag when serializing
99// Any `impl CustomOperationBody` should have
100// #[typetag::serde] before it
101#[typetag::serde(tag = "type")]
102pub trait CustomOperationBody: 'static + Debug + DynEqHash + Send + Sync {
103    /// Defines the logic of a custom operation.
104    ///
105    /// This function must create a graph in a given context computing a custom operation.
106    /// Note that that the number of inputs and their types can vary.
107    /// This function should describe the logic of the custom operation for all acceptable cases and return an error otherwise.
108    ///
109    /// # Arguments
110    ///
111    /// * `context` - context where a graph computing a custom operation should be created
112    /// * `arguments_types` - vector of input types of a custom operation
113    ///
114    /// # Returns
115    ///
116    /// New graph computing a custom operation
117    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph>;
118
119    /// Specifies and returns the name of this custom operation.
120    ///
121    /// The name must be unique among all the implemented custom operations.
122    ///
123    /// # Returns
124    ///
125    /// Name of this custom operation
126    fn get_name(&self) -> String;
127}
128
129/// A structure that stores a pointer to a custom operation.
130///
131/// A custom operation can be thought of as a polymorphic function, i.e., where the number of inputs and their types can vary.
132///
133/// Any struct can be a custom operation if it implements the [CustomOperationBody] trait.
134/// Then any such struct can be used to create a [CustomOperation] object that can be added to a computation graph with [Graph::custom_op].
135///
136/// # Rust crates
137///
138/// [Clone] trait duplicates the pointer, not the underlying custom operation.
139///
140/// [PartialEq] trait compares the related custom operations, not just pointer.
141///
142/// # Example
143///
144/// ```
145/// # use ciphercore_base::graphs::create_context;
146/// # use ciphercore_base::data_types::{array_type, BIT};
147/// # use ciphercore_base::custom_ops::{CustomOperation, Not};
148/// let c = create_context().unwrap();
149/// let g = c.create_graph().unwrap();
150/// let t = array_type(vec![3, 2], BIT);
151/// let n1 = g.input(t).unwrap();
152/// let n2 = g.custom_op(CustomOperation::new(Not {}), vec![n1]).unwrap();
153/// ```
154#[derive(Clone, Debug, Deserialize, Serialize)]
155#[cfg_attr(feature = "py-binding", struct_wrapper)]
156pub struct CustomOperation {
157    body: Arc<dyn CustomOperationBody>,
158}
159
#[cfg(feature = "py-binding")]
#[pyo3::pymethods]
impl PyBindingCustomOperation {
    // Builds a custom operation from its JSON serialization; JSON errors become RuntimeError.
    #[new]
    fn new(value: String) -> pyo3::PyResult<Self> {
        match serde_json::from_str::<CustomOperation>(&value) {
            Ok(inner) => Ok(PyBindingCustomOperation { inner }),
            Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(err.to_string())),
        }
    }

    // The string form of a custom operation is its JSON serialization.
    fn __str__(&self) -> pyo3::PyResult<String> {
        let json = serde_json::to_string(&self.inner);
        json.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    // `repr` intentionally matches `str`.
    fn __repr__(&self) -> pyo3::PyResult<String> {
        self.__str__()
    }
}
177
178impl CustomOperation {
179    /// Creates a new custom operation that can be added to a computation graph via [Graph::custom_op].
180    ///
181    /// # Arguments
182    ///
183    /// `op` - struct that implements the [CustomOperationBody] trait
184    ///
185    /// # Returns
186    ///
187    /// New custom operation
188    ///
189    /// # Example
190    ///
191    /// ```
192    /// # use ciphercore_base::graphs::create_context;
193    /// # use ciphercore_base::data_types::{array_type, BIT};
194    /// # use ciphercore_base::custom_ops::{CustomOperation, Not};
195    /// let c = create_context().unwrap();
196    /// let g = c.create_graph().unwrap();
197    /// let t = array_type(vec![3, 2], BIT);
198    /// let n1 = g.input(t).unwrap();
199    /// let n2 = g.custom_op(CustomOperation::new(Not {}), vec![n1]).unwrap();
200    /// ```
201    pub fn new<T: 'static + CustomOperationBody>(op: T) -> CustomOperation {
202        CustomOperation { body: Arc::new(op) }
203    }
204
205    /// Returns the name of the underlying custom operation by calling [CustomOperationBody::get_name].
206    ///
207    /// # Returns
208    ///
209    /// Name of this custom operation
210    pub fn get_name(&self) -> String {
211        self.body.get_name()
212    }
213}
214
215impl CustomOperation {
216    #[doc(hidden)]
217    pub fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
218        self.body.instantiate(context, arguments_types)
219    }
220}
221
222impl PartialEq for CustomOperation {
223    /// Tests whether `self` and `other` custom operations are equal.
224    ///
225    /// The underlying custom operation structs are compared via the [Eq] trait.
226    ///
227    /// # Arguments
228    ///
229    /// `other` - another [CustomOperation]
230    ///
231    /// # Returns
232    ///
233    /// `true` if the pointers in `self` and `other` point to equal custom operations, `false` otherwise
234    fn eq(&self, other: &Self) -> bool {
235        self.body.equals((*other.body).as_any())
236    }
237}
238
239impl Hash for CustomOperation {
240    /// Hashes the custom operation pointer.
241    ///
242    /// # Arguments
243    ///
244    /// `state` - state of a hash function that is changed after hashing the custom operation
245    fn hash<H: Hasher>(&self, state: &mut H) {
246        let hash_value = DynEqHash::hash(self.body.as_ref());
247        state.write_u64(hash_value);
248    }
249}
250
251impl Eq for CustomOperation {}
252
253/// A structure that defines the custom operation Not that inverts elementwise a binary array or scalar (individual bit).
254///
255/// This operation accepts only a binary array or scalar as input.
256///
257/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
258///
259/// # Custom operation arguments
260///
261/// Node containing a binary array or scalar
262///
263/// # Custom operation returns
264///
265/// New Not node
266///
267/// # Example
268///
269/// ```
270/// # use ciphercore_base::graphs::create_context;
271/// # use ciphercore_base::data_types::{scalar_type, BIT};
272/// # use ciphercore_base::custom_ops::{CustomOperation, Not};
273/// let c = create_context().unwrap();
274/// let g = c.create_graph().unwrap();
275/// let t = scalar_type(BIT);
276/// let n1 = g.input(t.clone()).unwrap();
277/// let n2 = g.custom_op(CustomOperation::new(Not {}), vec![n1]).unwrap();
278/// ```
279#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
280pub struct Not {}
281
282#[typetag::serde]
283impl CustomOperationBody for Not {
284    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
285        if arguments_types.len() != 1 {
286            return Err(runtime_error!("Invalid number of arguments for Not"));
287        }
288        let g = context.create_graph()?;
289        g.input(arguments_types[0].clone())?
290            .add(g.ones(scalar_type(BIT))?)?
291            .set_as_output()?;
292        g.finalize()?;
293        Ok(g)
294    }
295
296    fn get_name(&self) -> String {
297        "Not".to_owned()
298    }
299}
300
301/// A structure that defines the custom operation Or that is equivalent to the binary Or applied elementwise.
302///
303/// This operation accepts only binary arrays or scalars as input.
304///
305/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
306/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2,3]`.
307///
308/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
309///
310/// # Custom operation arguments
311///
312/// - Node containing a binary array or scalar
313/// - Node containing a binary array or scalar
314///
315/// # Custom operation returns
316///
317/// New Or node
318///
319/// # Example
320///
321/// ```
322/// # use ciphercore_base::graphs::create_context;
323/// # use ciphercore_base::data_types::{scalar_type, BIT};
324/// # use ciphercore_base::custom_ops::{CustomOperation, Or};
325/// let c = create_context().unwrap();
326/// let g = c.create_graph().unwrap();
327/// let t = scalar_type(BIT);
328/// let n1 = g.input(t.clone()).unwrap();
329/// let n2 = g.input(t.clone()).unwrap();
330/// let n3 = g.custom_op(CustomOperation::new(Or {}), vec![n1, n2]).unwrap();
331/// ```
332#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
333pub struct Or {}
334
335#[typetag::serde]
336impl CustomOperationBody for Or {
337    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
338        if arguments_types.len() != 2 {
339            return Err(runtime_error!("Invalid number of arguments for Or"));
340        }
341        let g = context.create_graph()?;
342        let i1 = g.input(arguments_types[0].clone())?;
343        let i2 = g.input(arguments_types[1].clone())?;
344        let i1_not = g.custom_op(CustomOperation::new(Not {}), vec![i1])?;
345        let i2_not = g.custom_op(CustomOperation::new(Not {}), vec![i2])?;
346        g.custom_op(CustomOperation::new(Not {}), vec![i1_not.multiply(i2_not)?])?
347            .set_as_output()?;
348        g.finalize()?;
349        Ok(g)
350    }
351
352    fn get_name(&self) -> String {
353        "Or".to_owned()
354    }
355}
356
357#[doc(hidden)]
358/// Data structure for storing maps between two contexts.
359/// Can be used conveniently to glue one context into another.
360/// Note that `Node` and `Graph` are hashed as pointers,
361/// so in general it's not a good idea to enumerate the entries
362/// of the maps using iterators, since it's not deterministic.
363/// Instead, one should sort by IDs, and then enumerate.
364#[derive(Default)]
365pub struct ContextMappings {
366    node_mapping: HashMap<Node, Node>,
367    graph_mapping: HashMap<Graph, Graph>,
368}
369
370impl ContextMappings {
371    pub fn contains_graph(&self, graph: Graph) -> bool {
372        self.graph_mapping.contains_key(&graph)
373    }
374
375    pub fn contains_node(&self, node: Node) -> bool {
376        self.node_mapping.contains_key(&node)
377    }
378
379    /// Panics if `graph` is not in `graph_mapping`
380    pub fn get_graph(&self, graph: Graph) -> Graph {
381        self.graph_mapping
382            .get(&graph)
383            .expect("Graph is not found in graph_mapping")
384            .clone()
385    }
386
387    /// Panics if `node` is not in `node_mapping`
388    pub fn get_node(&self, node: Node) -> Node {
389        self.node_mapping
390            .get(&node)
391            .expect("Node is not found in node_mapping")
392            .clone()
393    }
394
395    /// Panics if `old_graph` has already been inserted
396    pub fn insert_graph(&mut self, old_graph: Graph, new_graph: Graph) {
397        assert!(
398            self.graph_mapping.insert(old_graph, new_graph).is_none(),
399            "Graph has already been inserted in graph_mapping"
400        );
401    }
402
403    /// Panics if `old_node` has already been inserted
404    pub fn insert_node(&mut self, old_node: Node, new_node: Node) {
405        assert!(
406            self.node_mapping.insert(old_node, new_node).is_none(),
407            "Node has already been inserted in node_mapping"
408        );
409    }
410
411    /// Panics if `old_graph` is not inserted
412    pub fn remove_graph(&mut self, old_graph: Graph) {
413        assert!(
414            self.graph_mapping.remove(&old_graph).is_some(),
415            "Graph is not in graph_mapping"
416        );
417    }
418
419    /// Panics if `old_node` is not inserted
420    pub fn remove_node(&mut self, old_node: Node) {
421        assert!(
422            self.node_mapping.remove(&old_node).is_some(),
423            "Node is not isn node_mapping"
424        );
425    }
426}
427
428#[doc(hidden)]
429pub struct MappedContext {
430    pub context: Context,
431    // old -> new mappings
432    pub mappings: ContextMappings,
433}
434
435impl MappedContext {
436    pub fn new(context: Context) -> Self {
437        MappedContext {
438            context,
439            mappings: ContextMappings::default(),
440        }
441    }
442
443    pub fn get_context(&self) -> Context {
444        self.context.clone()
445    }
446}
447
448/// An instantiation is given by a custom operation
449/// and types of the arguments.
450#[derive(Debug, Clone, PartialEq, Eq, Hash)]
451pub(super) struct Instantiation {
452    pub(super) op: CustomOperation,
453    pub(super) arguments_types: Vec<Type>,
454}
455
456impl Instantiation {
457    /// `create_from_node` assumes that `node` carries a custom operation,
458    /// and the ambient context is type checked.
459    fn create_from_node(node: Node) -> Result<Self> {
460        if let Operation::Custom(custom_op) = node.get_operation() {
461            let mut node_dependencies_types = vec![];
462            for dependency in node.get_node_dependencies() {
463                node_dependencies_types.push(dependency.get_type()?);
464            }
465            Ok(Instantiation {
466                op: custom_op,
467                arguments_types: node_dependencies_types,
468            })
469        } else {
470            Err(runtime_error!(
471                "Instantiations can only be created from custom nodes"
472            ))
473        }
474    }
475
476    fn get_name(&self) -> String {
477        let mut name = "__".to_owned();
478        name.push_str(&self.op.get_name());
479        name.push_str("::<");
480        let mut first_argument = true;
481        for t in &self.arguments_types {
482            if first_argument {
483                first_argument = false;
484            } else {
485                name.push_str(", ");
486            }
487            write!(name, "{t}").unwrap();
488        }
489        name.push('>');
490        name
491    }
492}
493
494/// Data structures for instantiation graph.
495type InstantiationsGraph = DiGraph<Instantiation, (), usize>;
496type InstantiationsGraphNode = NodeIndex<usize>;
497
498/// Mapping between the graph of instantiations of actual instantiations.
499#[derive(Default)]
500struct InstantiationsGraphMapping {
501    instantiation_to_node: HashMap<Instantiation, InstantiationsGraphNode>,
502    node_to_instantiation: HashMap<InstantiationsGraphNode, Instantiation>,
503}
504
505/// Retrieves a node of the instantiations graph
506/// (and adds a fresh one if necessary).
507fn get_instantiations_graph_node(
508    instantiation: &Instantiation,
509    instantiations_graph_mapping: &mut InstantiationsGraphMapping,
510    instantiations_graph: &mut InstantiationsGraph,
511) -> (InstantiationsGraphNode, bool) {
512    match instantiations_graph_mapping
513        .instantiation_to_node
514        .get(instantiation)
515    {
516        Some(id) => (*id, true),
517        None => {
518            let new_inode = instantiations_graph.add_node(instantiation.clone());
519            instantiations_graph_mapping
520                .instantiation_to_node
521                .insert(instantiation.clone(), new_inode);
522            instantiations_graph_mapping
523                .node_to_instantiation
524                .insert(new_inode, instantiation.clone());
525            (new_inode, false)
526        }
527    }
528}
529
530/// Recursive function to find all the necessary instantiations
531/// and build a graph of those (appropriately caching things).
532fn process_instantiation(
533    instantiation: &Instantiation,
534    instantiations_graph_mapping: &mut InstantiationsGraphMapping,
535    instantiations_graph: &mut InstantiationsGraph,
536) -> Result<()> {
537    let fake_context = create_context()?;
538    let graph = instantiation
539        .op
540        .instantiate(fake_context.clone(), instantiation.arguments_types.clone())?;
541    // `instantiate()` may potentially create some auxiliary graphs, which we now need to process
542    // TODO: add a test that check that this is properly done
543    for fake_graph in fake_context.get_graphs() {
544        for node in fake_graph.get_nodes() {
545            if let Operation::Custom(_) = node.get_operation() {
546                let new_instantiation = Instantiation::create_from_node(node)?;
547                let (node1, already_existed) = get_instantiations_graph_node(
548                    &new_instantiation,
549                    instantiations_graph_mapping,
550                    instantiations_graph,
551                );
552                let (node2, _) = get_instantiations_graph_node(
553                    instantiation,
554                    instantiations_graph_mapping,
555                    instantiations_graph,
556                );
557                instantiations_graph.add_edge(node1, node2, ());
558                if !already_existed {
559                    process_instantiation(
560                        &new_instantiation,
561                        instantiations_graph_mapping,
562                        instantiations_graph,
563                    )?;
564                }
565            }
566        }
567    }
568    graph.set_as_main()?;
569    fake_context.finalize()?;
570    Ok(())
571}
572
573#[doc(hidden)]
574/// In order to instantiate all the custom operations in a given context,
575/// we do the following. First, we build a graph of instantiations as follows.
576/// As a seed set, we use the instantiated custom operations in the original context.
577/// Then, starting from this seed set, we detect custom operations necessary further
578/// down the road, and instantiate them recursively etc. If we deal with instantiation
579/// we already encountered, we stop.
580///
581/// After we built the graph, we sort it topologically. Finally, we glue
582/// all the necessary instantiations into the resulting context followed by the
583/// original context.
584///
585/// For instance, if we want to instantiate Or from inputs of type
586/// `Type::Array(vec![1, 7], BIT)` and `Type::Array(vec![3, 7], BIT)`,
587/// then we need Not for types `Type::Array(vec![1, 7], BIT)` once
588/// and for `Type::Array(vec![3, 7], BIT)` twice. But the latter will be instantiated
589/// only once due to caching.
590pub fn run_instantiation_pass(context: Context) -> Result<MappedContext> {
591    /* Build a graph of instantiations */
592    let mut needed_instantiations = vec![];
593    for graph in context.get_graphs() {
594        for node in graph.get_nodes() {
595            if let Operation::Custom(_) = node.get_operation() {
596                needed_instantiations.push(Instantiation::create_from_node(node)?);
597            }
598        }
599    }
600    let mut instantiations_graph_mapping = InstantiationsGraphMapping::default();
601    let mut instantiations_graph = InstantiationsGraph::default();
602    for instantiation in needed_instantiations {
603        let (_, already_existed) = get_instantiations_graph_node(
604            &instantiation,
605            &mut instantiations_graph_mapping,
606            &mut instantiations_graph,
607        );
608        if !already_existed {
609            process_instantiation(
610                &instantiation,
611                &mut instantiations_graph_mapping,
612                &mut instantiations_graph,
613            )?;
614        }
615    }
616    /* =============================== */
617    let result_context = create_context()?;
618    // Glues a given context into the final one
619    let glue_context = |glued_instantiations_cache: &HashMap<Instantiation, Graph>,
620                        context_to_glue: Context|
621     -> Result<ContextMappings> {
622        let mut mapping = ContextMappings::default();
623        for graph_to_glue in context_to_glue.get_graphs() {
624            let glued_graph = result_context.create_graph()?;
625            for annotation in graph_to_glue.get_annotations()? {
626                glued_graph.add_annotation(annotation)?;
627            }
628            mapping.insert_graph(graph_to_glue.clone(), glued_graph.clone());
629            for node in graph_to_glue.get_nodes() {
630                let node_dependencies = node.get_node_dependencies();
631                let new_node_dependencies: Vec<Node> = node_dependencies
632                    .iter()
633                    .map(|node| mapping.get_node(node.clone()))
634                    .collect();
635                let new_node = match node.get_operation() {
636                    Operation::Custom(_) => {
637                        let needed_instantiation = Instantiation::create_from_node(node.clone())?;
638                        glued_graph.call(
639                            // Retrieve a needed instantiation from the cache,
640                            // which should be glued before.
641                            glued_instantiations_cache
642                                .get(&needed_instantiation)
643                                .expect("Should not be here")
644                                .clone(),
645                            new_node_dependencies,
646                        )?
647                    }
648                    _ => {
649                        let graph_dependencies = node.get_graph_dependencies();
650                        let new_graph_dependencies: Vec<Graph> = graph_dependencies
651                            .iter()
652                            .map(|graph| mapping.get_graph(graph.clone()))
653                            .collect();
654                        glued_graph.add_node(
655                            new_node_dependencies,
656                            new_graph_dependencies,
657                            node.get_operation(),
658                        )?
659                    }
660                };
661                copy_node_name(node.clone(), new_node.clone())?;
662                let node_annotations = context_to_glue.get_node_annotations(node.clone())?;
663                if !node_annotations.is_empty() {
664                    for node_annotation in node_annotations {
665                        new_node.add_annotation(node_annotation)?;
666                    }
667                }
668                mapping.insert_node(node, new_node);
669            }
670            glued_graph.set_output_node(mapping.get_node(graph_to_glue.get_output_node()?))?;
671            glued_graph.finalize()?;
672        }
673        Ok(mapping)
674    };
675    // Glue necessary instantiations in the order of toposort of
676    // the instantiations graph, and add them to the cache.
677    let mut glued_instantiations_cache = HashMap::<_, Graph>::new();
678    for instantiations_graph_node in toposort(&instantiations_graph, None)
679        .map_err(|_| runtime_error!("Circular dependency among instantiations"))?
680    {
681        let instantiation = instantiations_graph_mapping
682            .node_to_instantiation
683            .get(&instantiations_graph_node)
684            .expect("Should not be here");
685        let fake_context = create_context()?;
686        let g = instantiation
687            .op
688            .instantiate(fake_context.clone(), instantiation.arguments_types.clone())?
689            .set_as_main()?;
690        fake_context.finalize()?;
691        let mapping = glue_context(&glued_instantiations_cache, fake_context)?;
692        let mapped_graph = mapping.get_graph(g);
693        mapped_graph.set_name(&instantiation.get_name())?;
694        glued_instantiations_cache.insert(instantiation.clone(), mapped_graph);
695    }
696    // Glue the final context.
697    let mut result = MappedContext::new(result_context.clone());
698    result.mappings = glue_context(&glued_instantiations_cache, context.clone())?;
699    result_context.set_main_graph(result.mappings.get_graph(context.get_main_graph()?))?;
700    result_context.finalize()?;
701    Ok(result)
702}
703
704#[cfg(test)]
705mod tests {
706
707    use super::*;
708
709    use crate::data_types::array_type;
710    use crate::data_values::Value;
711    use crate::evaluators::random_evaluate;
712    use crate::graphs::util::simple_context;
713    use crate::graphs::{contexts_deep_equal, NodeAnnotation};
714
715    fn get_hash(custom_op: &CustomOperation) -> u64 {
716        let mut h = DefaultHasher::new();
717        Hash::hash(custom_op, &mut h);
718        h.finish()
719    }
720
721    #[test]
722    fn test_custom_operation() {
723        assert_eq!(CustomOperation::new(Not {}), CustomOperation::new(Not {}));
724        assert_eq!(CustomOperation::new(Or {}), CustomOperation::new(Or {}));
725        assert!(CustomOperation::new(Not {}) != CustomOperation::new(Or {}));
726        assert_eq!(
727            get_hash(&CustomOperation::new(Not {})),
728            get_hash(&CustomOperation::new(Not {})),
729        );
730        assert_eq!(
731            get_hash(&CustomOperation::new(Or {})),
732            get_hash(&CustomOperation::new(Or {})),
733        );
734        assert!(get_hash(&CustomOperation::new(Or {})) != get_hash(&CustomOperation::new(Not {})),);
735        let v = vec![CustomOperation::new(Not {}), CustomOperation::new(Or {})];
736        let sers = vec![
737            "{\"body\":{\"type\":\"Not\"}}",
738            "{\"body\":{\"type\":\"Or\"}}",
739        ];
740        let debugs = vec![
741            "CustomOperation { body: Not }",
742            "CustomOperation { body: Or }",
743        ];
744        for i in 0..v.len() {
745            let s = serde_json::to_string(&v[i]).unwrap();
746            assert_eq!(s, sers[i]);
747            assert_eq!(serde_json::from_str::<CustomOperation>(&s).unwrap(), v[i]);
748            assert_eq!(v, v.clone());
749            assert_eq!(format!("{:?}", v[i]), debugs[i]);
750        }
751        assert!(serde_json::from_str::<CustomOperation>(
752            "{\"body\":{\"type\":\"InvalidCustomOperation\"}}"
753        )
754        .is_err());
755    }
756
757    #[test]
758    fn test_not() {
759        || -> Result<()> {
760            let c = create_context()?;
761            let g = c.create_graph()?;
762            let i = g.input(scalar_type(BIT))?;
763            let o = g.custom_op(CustomOperation::new(Not {}), vec![i])?;
764            g.set_output_node(o)?;
765            g.finalize()?;
766            c.set_main_graph(g.clone())?;
767            c.finalize()?;
768            let mapped_c = run_instantiation_pass(c)?;
769            for x in vec![0, 1] {
770                let result = random_evaluate(
771                    mapped_c.mappings.get_graph(g.clone()),
772                    vec![Value::from_scalar(x, BIT)?],
773                )?;
774                let result = result.to_u8(BIT)?;
775                assert_eq!(result, !(x != 0) as u8);
776            }
777            Ok(())
778        }()
779        .unwrap();
780        // Test broadcasting
781        || -> Result<()> {
782            let c = create_context()?;
783            let g = c.create_graph()?;
784            let i = g.input(array_type(vec![3, 3], BIT))?;
785            let o = g.custom_op(CustomOperation::new(Not {}), vec![i])?;
786            g.set_output_node(o)?;
787            g.finalize()?;
788            c.set_main_graph(g.clone())?;
789            c.finalize()?;
790            let mapped_c = run_instantiation_pass(c)?;
791            let result = random_evaluate(
792                mapped_c.mappings.get_graph(g.clone()),
793                vec![Value::from_flattened_array(
794                    &vec![0, 1, 1, 0, 1, 0, 0, 1, 1],
795                    BIT,
796                )?],
797            )?;
798            let result = result.to_flattened_array_u64(array_type(vec![3, 3], BIT))?;
799            assert_eq!(result, vec![1, 0, 0, 1, 0, 1, 1, 0, 0]);
800            Ok(())
801        }()
802        .unwrap();
803    }
804
805    #[test]
806    fn test_or() {
807        || -> Result<()> {
808            let c = create_context()?;
809            let g = c.create_graph()?;
810            let i1 = g.input(scalar_type(BIT))?;
811            let i2 = g.input(scalar_type(BIT))?;
812            let o = g.custom_op(CustomOperation::new(Or {}), vec![i1, i2])?;
813            g.set_output_node(o)?;
814            g.finalize()?;
815            c.set_main_graph(g.clone())?;
816            c.finalize()?;
817            let mapped_c = run_instantiation_pass(c)?;
818            for x in vec![0, 1] {
819                for y in vec![0, 1] {
820                    let result = random_evaluate(
821                        mapped_c.mappings.get_graph(g.clone()),
822                        vec![Value::from_scalar(x, BIT)?, Value::from_scalar(y, BIT)?],
823                    )?;
824                    let result = result.to_u8(BIT)?;
825                    assert_eq!(result, ((x != 0) || (y != 0)) as u8);
826                }
827            }
828            Ok(())
829        }()
830        .unwrap();
831    }
832
833    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
834    struct A {}
835
836    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
837    struct B {}
838
839    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
840    struct C {}
841
842    #[typetag::serde]
843    impl CustomOperationBody for A {
844        fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
845            let g = context.create_graph()?;
846            g.custom_op(
847                CustomOperation::new(B {}),
848                vec![g.input(arguments_types[0].clone())?],
849            )?
850            .set_as_output()?;
851            g.finalize()?;
852            Ok(g)
853        }
854
855        fn get_name(&self) -> String {
856            "A".to_owned()
857        }
858    }
859
860    #[typetag::serde]
861    impl CustomOperationBody for B {
862        fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
863            let g = context.create_graph()?;
864            let i = g.input(arguments_types[0].clone())?;
865            g.set_output_node(i)?;
866            g.finalize()?;
867            let fake_g = context.create_graph()?;
868            let i = fake_g.input(scalar_type(BIT))?;
869            fake_g.set_output_node(i)?;
870            fake_g.finalize()?;
871            Ok(g)
872        }
873
874        fn get_name(&self) -> String {
875            "B".to_owned()
876        }
877    }
878
879    #[typetag::serde]
880    impl CustomOperationBody for C {
881        fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
882            let g = context.create_graph()?;
883            let mut inputs = vec![];
884            for t in &arguments_types {
885                inputs.push(g.input(t.clone())?);
886            }
887            let o = if arguments_types.len() == 1 {
888                inputs[0].clone()
889            } else {
890                let node = g.create_tuple(vec![
891                    g.custom_op(
892                        CustomOperation::new(C {}),
893                        inputs[0..inputs.len() / 2].to_vec(),
894                    )?,
895                    g.custom_op(
896                        CustomOperation::new(C {}),
897                        inputs[inputs.len() / 2..inputs.len()].to_vec(),
898                    )?,
899                ])?;
900                context.add_node_annotation(&node, NodeAnnotation::AssociativeOperation)?;
901                node
902            };
903            g.set_output_node(o)?;
904            g.finalize()?;
905            Ok(g)
906        }
907
908        fn get_name(&self) -> String {
909            "C".to_owned()
910        }
911    }
912
913    #[test]
914    fn test_instantiation_pass() {
915        || -> Result<()> {
916            let c = simple_context(|g| {
917                let i = g.input(scalar_type(BIT))?;
918                let o = g.custom_op(CustomOperation::new(A {}), vec![i])?;
919                o.set_name("A")?;
920                Ok(o)
921            })?;
922
923            let processed_c = run_instantiation_pass(c)?.context;
924
925            let expected_c = create_context()?;
926            let g2 = expected_c.create_graph()?;
927            let i = g2.input(scalar_type(BIT))?;
928            g2.set_output_node(i)?;
929            g2.set_name("__B::<bit>")?;
930            g2.finalize()?;
931            let g1 = expected_c.create_graph()?;
932            let i = g1.input(scalar_type(BIT))?;
933            g1.set_output_node(i)?;
934            g1.finalize()?;
935            let g3 = expected_c.create_graph()?;
936            let i = g3.input(scalar_type(BIT))?;
937            let o = g3.call(g2, vec![i])?;
938            g3.set_output_node(o)?;
939            g3.set_name("__A::<bit>")?;
940            g3.finalize()?;
941            let g4 = expected_c.create_graph()?;
942            let i = g4.input(scalar_type(BIT))?;
943            let o = g4.call(g3, vec![i])?;
944            o.set_name("A")?;
945            g4.set_output_node(o)?;
946            g4.finalize()?;
947            expected_c.set_main_graph(g4)?;
948            expected_c.finalize()?;
949            assert!(contexts_deep_equal(expected_c, processed_c));
950            Ok(())
951        }()
952        .unwrap();
953
954        || -> Result<()> {
955            let c = create_context()?;
956            let sub_g = c.create_graph()?;
957            let i = sub_g.input(scalar_type(BIT))?;
958            sub_g.set_output_node(i)?;
959            sub_g.finalize()?;
960            let g = c.create_graph()?;
961            let i = g.input(scalar_type(BIT))?;
962            let ii = g.call(sub_g, vec![i])?;
963            let o = g.custom_op(CustomOperation::new(B {}), vec![ii])?;
964            o.set_name("B")?;
965            g.set_output_node(o)?;
966            g.finalize()?;
967            c.set_main_graph(g)?;
968            c.finalize()?;
969
970            let processed_c = run_instantiation_pass(c)?.context;
971
972            let expected_c = create_context()?;
973            let g1 = expected_c.create_graph()?;
974            let i = g1.input(scalar_type(BIT))?;
975            g1.set_output_node(i)?;
976            g1.set_name("__B::<bit>")?;
977            g1.finalize()?;
978            let g3 = expected_c.create_graph()?;
979            let i = g3.input(scalar_type(BIT))?;
980            g3.set_output_node(i)?;
981            g3.finalize()?;
982            let g2 = expected_c.create_graph()?;
983            let i = g2.input(scalar_type(BIT))?;
984            g2.set_output_node(i)?;
985            g2.finalize()?;
986            let g4 = expected_c.create_graph()?;
987            let i = g4.input(scalar_type(BIT))?;
988            let o = g4.call(g2, vec![i])?;
989            let oo = g4.call(g1, vec![o])?;
990            oo.set_name("B")?;
991            g4.set_output_node(oo)?;
992            g4.finalize()?;
993            expected_c.set_main_graph(g4)?;
994            expected_c.finalize()?;
995            assert!(contexts_deep_equal(expected_c, processed_c));
996            Ok(())
997        }()
998        .unwrap();
999
1000        // Checking that `run_instantiation_pass` is deterministic
1001        || -> Result<()> {
1002            let generate_context = || -> Result<Context> {
1003                simple_context(|g| {
1004                    let i1 = g.input(array_type(vec![1, 5], BIT))?;
1005                    let i2 = g.input(array_type(vec![7, 5], BIT))?;
1006                    let i3 = g.input(array_type(vec![4, 3], BIT))?;
1007                    let i4 = g.input(array_type(vec![2, 3], BIT))?;
1008                    g.custom_op(CustomOperation::new(C {}), vec![i1, i2, i3, i4])
1009                })
1010            };
1011            let mut contexts = vec![];
1012            for _ in 0..10 {
1013                contexts.push(generate_context()?);
1014            }
1015            let mut instantiated_contexts = vec![];
1016            for context in contexts {
1017                instantiated_contexts.push(run_instantiation_pass(context)?.context);
1018            }
1019            for i in 0..instantiated_contexts.len() {
1020                assert!(contexts_deep_equal(
1021                    instantiated_contexts[0].clone(),
1022                    instantiated_contexts[i].clone()
1023                ));
1024            }
1025            Ok(())
1026        }()
1027        .unwrap();
1028
1029        // Checking that `run_instantiation_pass` copies node annotations
1030        || -> Result<()> {
1031            let context = simple_context(|g| {
1032                let i1 = g.input(array_type(vec![1, 5], BIT))?;
1033                let i2 = g.input(array_type(vec![7, 5], BIT))?;
1034                let i3 = g.input(array_type(vec![4, 3], BIT))?;
1035                let i4 = g.input(array_type(vec![2, 3], BIT))?;
1036                g.custom_op(CustomOperation::new(C {}), vec![i1, i2, i3, i4])
1037            })?;
1038            let new_context = run_instantiation_pass(context)?.context;
1039            assert_eq!(
1040                new_context
1041                    .get_node_annotations(new_context.get_graphs()[6].get_output_node()?)?
1042                    .len(),
1043                1
1044            );
1045            Ok(())
1046        }()
1047        .unwrap();
1048
1049        // Check `run_instantiation_pass` for Not
1050        || -> Result<()> {
1051            let c = simple_context(|g| {
1052                let i1 = g.input(array_type(vec![5], BIT))?;
1053                g.custom_op(CustomOperation::new(Not {}), vec![i1])
1054            })?;
1055            let mapped_c = run_instantiation_pass(c)?;
1056            let expected_c = create_context()?;
1057            let not_g = expected_c.create_graph()?;
1058            let i = not_g.input(array_type(vec![5], BIT))?;
1059            let c = not_g.ones(scalar_type(BIT))?;
1060            let o = not_g.add(i, c)?;
1061            not_g.set_output_node(o)?;
1062            not_g.set_name("__Not::<bit[5]>")?;
1063            not_g.finalize()?;
1064            let g = expected_c.create_graph()?;
1065            let i = g.input(array_type(vec![5], BIT))?;
1066            let o = g.call(not_g, vec![i])?;
1067            g.set_output_node(o)?;
1068            g.finalize()?;
1069            expected_c.set_main_graph(g)?;
1070            expected_c.finalize()?;
1071            assert!(contexts_deep_equal(mapped_c.context, expected_c));
1072            Ok(())
1073        }()
1074        .unwrap();
1075
1076        // Check `run_instantiation_pass` for Or
1077        || -> Result<()> {
1078            let c = simple_context(|g| {
1079                let i1 = g.input(array_type(vec![5], BIT))?;
1080                let i2 = g.input(array_type(vec![3, 5], BIT))?;
1081                g.custom_op(CustomOperation::new(Or {}), vec![i1, i2])
1082            })?;
1083            let mapped_c = run_instantiation_pass(c)?;
1084            let expected_c = create_context()?;
1085            let not_g_2 = expected_c.create_graph()?;
1086            let i = not_g_2.input(array_type(vec![3, 5], BIT))?;
1087            let c = not_g_2.ones(scalar_type(BIT))?;
1088            let o = not_g_2.add(i, c)?;
1089            not_g_2.set_output_node(o)?;
1090            not_g_2.set_name("__Not::<bit[3, 5]>")?;
1091            not_g_2.finalize()?;
1092            let not_g = expected_c.create_graph()?;
1093            let i = not_g.input(array_type(vec![5], BIT))?;
1094            let c = not_g.ones(scalar_type(BIT))?;
1095            let o = not_g.add(i, c)?;
1096            not_g.set_output_node(o)?;
1097            not_g.set_name("__Not::<bit[5]>")?;
1098            not_g.finalize()?;
1099            let or_g = expected_c.create_graph()?;
1100            let i1 = or_g.input(array_type(vec![5], BIT))?;
1101            let i2 = or_g.input(array_type(vec![3, 5], BIT))?;
1102            let i1_not = or_g.call(not_g, vec![i1])?;
1103            let i2_not = or_g.call(not_g_2.clone(), vec![i2])?;
1104            let i1_not_and_i2_not = or_g.multiply(i1_not, i2_not)?;
1105            let o = or_g.call(not_g_2, vec![i1_not_and_i2_not])?;
1106            or_g.set_output_node(o)?;
1107            or_g.set_name("__Or::<bit[5], bit[3, 5]>")?;
1108            or_g.finalize()?;
1109            let g = expected_c.create_graph()?;
1110            let i1 = g.input(array_type(vec![5], BIT))?;
1111            let i2 = g.input(array_type(vec![3, 5], BIT))?;
1112            let o = g.call(or_g, vec![i1, i2])?;
1113            g.set_output_node(o)?;
1114            g.finalize()?;
1115            expected_c.set_main_graph(g)?;
1116            expected_c.finalize()?;
1117            assert!(contexts_deep_equal(mapped_c.context, expected_c));
1118            Ok(())
1119        }()
1120        .unwrap();
1121    }
1122}