1use crate::data_types::{scalar_type, Type, BIT};
5use crate::errors::Result;
6use crate::graphs::{copy_node_name, create_context, Context, Graph, Node, Operation};
7
8use serde::{Deserialize, Serialize};
9
10use petgraph::algo::toposort;
11use petgraph::graph::{DiGraph, NodeIndex};
12
13use std::any::{Any, TypeId};
14use std::collections::{hash_map::DefaultHasher, HashMap};
15use std::fmt::Debug;
16use std::fmt::Write;
17use std::hash::{Hash, Hasher};
18use std::sync::Arc;
19
20#[cfg(feature = "py-binding")]
21use pywrapper_macro::struct_wrapper;
22
#[doc(hidden)]
/// Object-safe equality and hashing for trait objects.
///
/// `Eq` and `Hash` cannot be called through a `dyn` trait object, so
/// `CustomOperationBody` requires this trait as a supertrait instead. It is implemented
/// for every sized `'static + Eq + Hash` type by a blanket impl in this module.
pub trait DynEqHash {
    /// Returns `self` as `&dyn Any`, so that it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Returns `true` if the argument has the same concrete type as `self` and compares
    /// equal to it (as implemented by the blanket impl).
    fn equals(&self, _: &dyn Any) -> bool;
    /// Returns a 64-bit hash of `self` that also depends on its concrete type (as
    /// implemented by the blanket impl).
    fn hash(&self) -> u64;
}
34
35impl<T: 'static + Eq + Hash> DynEqHash for T {
36 fn as_any(&self) -> &dyn Any {
37 self
38 }
39
40 fn equals(&self, other: &dyn Any) -> bool {
41 other.downcast_ref::<T>().map_or(false, |a| self == a)
42 }
43
44 fn hash(&self) -> u64 {
48 let mut h = DefaultHasher::new();
49 Hash::hash(&(TypeId::of::<T>(), self), &mut h);
50 h.finish()
51 }
52}
53
#[typetag::serde(tag = "type")]
/// Implementation of a custom operation.
///
/// Implementors are (de)serialized through `typetag`. The concrete type name is stored
/// under the `"type"` key, e.g. a `CustomOperation` wrapping `Not` serializes as
/// `{"body":{"type":"Not"}}`. Equality and hashing of trait objects come from the
/// `DynEqHash` supertrait.
pub trait CustomOperationBody: 'static + Debug + DynEqHash + Send + Sync {
    /// Builds, inside `context`, the graph that implements this operation for
    /// arguments of the given types, and returns it.
    ///
    /// Implementations may create additional graphs in `context`. The instantiation pass
    /// scans and copies all of them, not only the returned one.
    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph>;

    /// Returns the operation's name. It is used when naming instantiated graphs
    /// (`__<name>::<argument types>`).
    fn get_name(&self) -> String;
}
128
#[derive(Clone, Debug, Deserialize, Serialize)]
/// A custom operation that can be used as the operation of a graph node
/// (`Operation::Custom`).
///
/// Wraps a reference-counted `CustomOperationBody`, so cloning shares the body.
/// Serialized as `{"body":{"type":"<Name>"}}` (plus any fields of the body).
/// Custom-operation nodes are replaced by calls to concrete graphs by
/// `run_instantiation_pass`.
#[cfg_attr(feature = "py-binding", struct_wrapper)]
pub struct CustomOperation {
    // Shared implementation; compared and hashed through `DynEqHash`.
    body: Arc<dyn CustomOperationBody>,
}
159
#[cfg(feature = "py-binding")]
#[pyo3::pymethods]
impl PyBindingCustomOperation {
    /// Deserializes a custom operation from its JSON representation.
    #[new]
    fn new(value: String) -> pyo3::PyResult<Self> {
        match serde_json::from_str::<CustomOperation>(&value) {
            Ok(inner) => Ok(PyBindingCustomOperation { inner }),
            Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(err.to_string())),
        }
    }

    /// Serializes the wrapped custom operation to JSON.
    fn __str__(&self) -> pyo3::PyResult<String> {
        match serde_json::to_string(&self.inner) {
            Ok(json) => Ok(json),
            Err(err) => Err(pyo3::exceptions::PyRuntimeError::new_err(err.to_string())),
        }
    }

    /// Same as `__str__`: the JSON representation.
    fn __repr__(&self) -> pyo3::PyResult<String> {
        self.__str__()
    }
}
177
178impl CustomOperation {
179 pub fn new<T: 'static + CustomOperationBody>(op: T) -> CustomOperation {
202 CustomOperation { body: Arc::new(op) }
203 }
204
205 pub fn get_name(&self) -> String {
211 self.body.get_name()
212 }
213}
214
215impl CustomOperation {
216 #[doc(hidden)]
217 pub fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
218 self.body.instantiate(context, arguments_types)
219 }
220}
221
222impl PartialEq for CustomOperation {
223 fn eq(&self, other: &Self) -> bool {
235 self.body.equals((*other.body).as_any())
236 }
237}
238
239impl Hash for CustomOperation {
240 fn hash<H: Hasher>(&self, state: &mut H) {
246 let hash_value = DynEqHash::hash(self.body.as_ref());
247 state.write_u64(hash_value);
248 }
249}
250
/// Marker impl. `CustomOperation::eq` delegates to `DynEqHash::equals` of the bodies.
/// For types covered by the blanket `DynEqHash` impl, that compares values using the
/// body type's own `Eq`.
impl Eq for CustomOperation {}
252
253#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
280pub struct Not {}
281
282#[typetag::serde]
283impl CustomOperationBody for Not {
284 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
285 if arguments_types.len() != 1 {
286 return Err(runtime_error!("Invalid number of arguments for Not"));
287 }
288 let g = context.create_graph()?;
289 g.input(arguments_types[0].clone())?
290 .add(g.ones(scalar_type(BIT))?)?
291 .set_as_output()?;
292 g.finalize()?;
293 Ok(g)
294 }
295
296 fn get_name(&self) -> String {
297 "Not".to_owned()
298 }
299}
300
301#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
333pub struct Or {}
334
335#[typetag::serde]
336impl CustomOperationBody for Or {
337 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
338 if arguments_types.len() != 2 {
339 return Err(runtime_error!("Invalid number of arguments for Or"));
340 }
341 let g = context.create_graph()?;
342 let i1 = g.input(arguments_types[0].clone())?;
343 let i2 = g.input(arguments_types[1].clone())?;
344 let i1_not = g.custom_op(CustomOperation::new(Not {}), vec![i1])?;
345 let i2_not = g.custom_op(CustomOperation::new(Not {}), vec![i2])?;
346 g.custom_op(CustomOperation::new(Not {}), vec![i1_not.multiply(i2_not)?])?
347 .set_as_output()?;
348 g.finalize()?;
349 Ok(g)
350 }
351
352 fn get_name(&self) -> String {
353 "Or".to_owned()
354 }
355}
356
357#[doc(hidden)]
358#[derive(Default)]
365pub struct ContextMappings {
366 node_mapping: HashMap<Node, Node>,
367 graph_mapping: HashMap<Graph, Graph>,
368}
369
370impl ContextMappings {
371 pub fn contains_graph(&self, graph: Graph) -> bool {
372 self.graph_mapping.contains_key(&graph)
373 }
374
375 pub fn contains_node(&self, node: Node) -> bool {
376 self.node_mapping.contains_key(&node)
377 }
378
379 pub fn get_graph(&self, graph: Graph) -> Graph {
381 self.graph_mapping
382 .get(&graph)
383 .expect("Graph is not found in graph_mapping")
384 .clone()
385 }
386
387 pub fn get_node(&self, node: Node) -> Node {
389 self.node_mapping
390 .get(&node)
391 .expect("Node is not found in node_mapping")
392 .clone()
393 }
394
395 pub fn insert_graph(&mut self, old_graph: Graph, new_graph: Graph) {
397 assert!(
398 self.graph_mapping.insert(old_graph, new_graph).is_none(),
399 "Graph has already been inserted in graph_mapping"
400 );
401 }
402
403 pub fn insert_node(&mut self, old_node: Node, new_node: Node) {
405 assert!(
406 self.node_mapping.insert(old_node, new_node).is_none(),
407 "Node has already been inserted in node_mapping"
408 );
409 }
410
411 pub fn remove_graph(&mut self, old_graph: Graph) {
413 assert!(
414 self.graph_mapping.remove(&old_graph).is_some(),
415 "Graph is not in graph_mapping"
416 );
417 }
418
419 pub fn remove_node(&mut self, old_node: Node) {
421 assert!(
422 self.node_mapping.remove(&old_node).is_some(),
423 "Node is not isn node_mapping"
424 );
425 }
426}
427
428#[doc(hidden)]
429pub struct MappedContext {
430 pub context: Context,
431 pub mappings: ContextMappings,
433}
434
435impl MappedContext {
436 pub fn new(context: Context) -> Self {
437 MappedContext {
438 context,
439 mappings: ContextMappings::default(),
440 }
441 }
442
443 pub fn get_context(&self) -> Context {
444 self.context.clone()
445 }
446}
447
448#[derive(Debug, Clone, PartialEq, Eq, Hash)]
451pub(super) struct Instantiation {
452 pub(super) op: CustomOperation,
453 pub(super) arguments_types: Vec<Type>,
454}
455
456impl Instantiation {
457 fn create_from_node(node: Node) -> Result<Self> {
460 if let Operation::Custom(custom_op) = node.get_operation() {
461 let mut node_dependencies_types = vec![];
462 for dependency in node.get_node_dependencies() {
463 node_dependencies_types.push(dependency.get_type()?);
464 }
465 Ok(Instantiation {
466 op: custom_op,
467 arguments_types: node_dependencies_types,
468 })
469 } else {
470 Err(runtime_error!(
471 "Instantiations can only be created from custom nodes"
472 ))
473 }
474 }
475
476 fn get_name(&self) -> String {
477 let mut name = "__".to_owned();
478 name.push_str(&self.op.get_name());
479 name.push_str("::<");
480 let mut first_argument = true;
481 for t in &self.arguments_types {
482 if first_argument {
483 first_argument = false;
484 } else {
485 name.push_str(", ");
486 }
487 write!(name, "{t}").unwrap();
488 }
489 name.push('>');
490 name
491 }
492}
493
/// Dependency graph of instantiations. An edge `a -> b` means the graphs created for
/// instantiation `b` contain a custom node requiring instantiation `a`, so a
/// topological order lists `a` before `b`.
type InstantiationsGraph = DiGraph<Instantiation, (), usize>;
/// Index of a node in an `InstantiationsGraph`.
type InstantiationsGraphNode = NodeIndex<usize>;
497
#[derive(Default)]
/// Lookup in both directions between instantiations and their nodes in an
/// `InstantiationsGraph`.
struct InstantiationsGraphMapping {
    instantiation_to_node: HashMap<Instantiation, InstantiationsGraphNode>,
    node_to_instantiation: HashMap<InstantiationsGraphNode, Instantiation>,
}
504
505fn get_instantiations_graph_node(
508 instantiation: &Instantiation,
509 instantiations_graph_mapping: &mut InstantiationsGraphMapping,
510 instantiations_graph: &mut InstantiationsGraph,
511) -> (InstantiationsGraphNode, bool) {
512 match instantiations_graph_mapping
513 .instantiation_to_node
514 .get(instantiation)
515 {
516 Some(id) => (*id, true),
517 None => {
518 let new_inode = instantiations_graph.add_node(instantiation.clone());
519 instantiations_graph_mapping
520 .instantiation_to_node
521 .insert(instantiation.clone(), new_inode);
522 instantiations_graph_mapping
523 .node_to_instantiation
524 .insert(new_inode, instantiation.clone());
525 (new_inode, false)
526 }
527 }
528}
529
530fn process_instantiation(
533 instantiation: &Instantiation,
534 instantiations_graph_mapping: &mut InstantiationsGraphMapping,
535 instantiations_graph: &mut InstantiationsGraph,
536) -> Result<()> {
537 let fake_context = create_context()?;
538 let graph = instantiation
539 .op
540 .instantiate(fake_context.clone(), instantiation.arguments_types.clone())?;
541 for fake_graph in fake_context.get_graphs() {
544 for node in fake_graph.get_nodes() {
545 if let Operation::Custom(_) = node.get_operation() {
546 let new_instantiation = Instantiation::create_from_node(node)?;
547 let (node1, already_existed) = get_instantiations_graph_node(
548 &new_instantiation,
549 instantiations_graph_mapping,
550 instantiations_graph,
551 );
552 let (node2, _) = get_instantiations_graph_node(
553 instantiation,
554 instantiations_graph_mapping,
555 instantiations_graph,
556 );
557 instantiations_graph.add_edge(node1, node2, ());
558 if !already_existed {
559 process_instantiation(
560 &new_instantiation,
561 instantiations_graph_mapping,
562 instantiations_graph,
563 )?;
564 }
565 }
566 }
567 }
568 graph.set_as_main()?;
569 fake_context.finalize()?;
570 Ok(())
571}
572
573#[doc(hidden)]
574pub fn run_instantiation_pass(context: Context) -> Result<MappedContext> {
591 let mut needed_instantiations = vec![];
593 for graph in context.get_graphs() {
594 for node in graph.get_nodes() {
595 if let Operation::Custom(_) = node.get_operation() {
596 needed_instantiations.push(Instantiation::create_from_node(node)?);
597 }
598 }
599 }
600 let mut instantiations_graph_mapping = InstantiationsGraphMapping::default();
601 let mut instantiations_graph = InstantiationsGraph::default();
602 for instantiation in needed_instantiations {
603 let (_, already_existed) = get_instantiations_graph_node(
604 &instantiation,
605 &mut instantiations_graph_mapping,
606 &mut instantiations_graph,
607 );
608 if !already_existed {
609 process_instantiation(
610 &instantiation,
611 &mut instantiations_graph_mapping,
612 &mut instantiations_graph,
613 )?;
614 }
615 }
616 let result_context = create_context()?;
618 let glue_context = |glued_instantiations_cache: &HashMap<Instantiation, Graph>,
620 context_to_glue: Context|
621 -> Result<ContextMappings> {
622 let mut mapping = ContextMappings::default();
623 for graph_to_glue in context_to_glue.get_graphs() {
624 let glued_graph = result_context.create_graph()?;
625 for annotation in graph_to_glue.get_annotations()? {
626 glued_graph.add_annotation(annotation)?;
627 }
628 mapping.insert_graph(graph_to_glue.clone(), glued_graph.clone());
629 for node in graph_to_glue.get_nodes() {
630 let node_dependencies = node.get_node_dependencies();
631 let new_node_dependencies: Vec<Node> = node_dependencies
632 .iter()
633 .map(|node| mapping.get_node(node.clone()))
634 .collect();
635 let new_node = match node.get_operation() {
636 Operation::Custom(_) => {
637 let needed_instantiation = Instantiation::create_from_node(node.clone())?;
638 glued_graph.call(
639 glued_instantiations_cache
642 .get(&needed_instantiation)
643 .expect("Should not be here")
644 .clone(),
645 new_node_dependencies,
646 )?
647 }
648 _ => {
649 let graph_dependencies = node.get_graph_dependencies();
650 let new_graph_dependencies: Vec<Graph> = graph_dependencies
651 .iter()
652 .map(|graph| mapping.get_graph(graph.clone()))
653 .collect();
654 glued_graph.add_node(
655 new_node_dependencies,
656 new_graph_dependencies,
657 node.get_operation(),
658 )?
659 }
660 };
661 copy_node_name(node.clone(), new_node.clone())?;
662 let node_annotations = context_to_glue.get_node_annotations(node.clone())?;
663 if !node_annotations.is_empty() {
664 for node_annotation in node_annotations {
665 new_node.add_annotation(node_annotation)?;
666 }
667 }
668 mapping.insert_node(node, new_node);
669 }
670 glued_graph.set_output_node(mapping.get_node(graph_to_glue.get_output_node()?))?;
671 glued_graph.finalize()?;
672 }
673 Ok(mapping)
674 };
675 let mut glued_instantiations_cache = HashMap::<_, Graph>::new();
678 for instantiations_graph_node in toposort(&instantiations_graph, None)
679 .map_err(|_| runtime_error!("Circular dependency among instantiations"))?
680 {
681 let instantiation = instantiations_graph_mapping
682 .node_to_instantiation
683 .get(&instantiations_graph_node)
684 .expect("Should not be here");
685 let fake_context = create_context()?;
686 let g = instantiation
687 .op
688 .instantiate(fake_context.clone(), instantiation.arguments_types.clone())?
689 .set_as_main()?;
690 fake_context.finalize()?;
691 let mapping = glue_context(&glued_instantiations_cache, fake_context)?;
692 let mapped_graph = mapping.get_graph(g);
693 mapped_graph.set_name(&instantiation.get_name())?;
694 glued_instantiations_cache.insert(instantiation.clone(), mapped_graph);
695 }
696 let mut result = MappedContext::new(result_context.clone());
698 result.mappings = glue_context(&glued_instantiations_cache, context.clone())?;
699 result_context.set_main_graph(result.mappings.get_graph(context.get_main_graph()?))?;
700 result_context.finalize()?;
701 Ok(result)
702}
703
#[cfg(test)]
mod tests {

    use super::*;

    use crate::data_types::array_type;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::graphs::{contexts_deep_equal, NodeAnnotation};

    // Hashes `custom_op` through its `Hash` impl with a fresh `DefaultHasher`.
    fn get_hash(custom_op: &CustomOperation) -> u64 {
        let mut h = DefaultHasher::new();
        Hash::hash(custom_op, &mut h);
        h.finish()
    }

    // Equality, hashing, JSON round-trips and `Debug` output of `CustomOperation`,
    // plus rejection of unknown operation type names.
    #[test]
    fn test_custom_operation() {
        assert_eq!(CustomOperation::new(Not {}), CustomOperation::new(Not {}));
        assert_eq!(CustomOperation::new(Or {}), CustomOperation::new(Or {}));
        assert!(CustomOperation::new(Not {}) != CustomOperation::new(Or {}));
        assert_eq!(
            get_hash(&CustomOperation::new(Not {})),
            get_hash(&CustomOperation::new(Not {})),
        );
        assert_eq!(
            get_hash(&CustomOperation::new(Or {})),
            get_hash(&CustomOperation::new(Or {})),
        );
        assert!(get_hash(&CustomOperation::new(Or {})) != get_hash(&CustomOperation::new(Not {})),);
        let v = vec![CustomOperation::new(Not {}), CustomOperation::new(Or {})];
        let sers = vec![
            "{\"body\":{\"type\":\"Not\"}}",
            "{\"body\":{\"type\":\"Or\"}}",
        ];
        let debugs = vec![
            "CustomOperation { body: Not }",
            "CustomOperation { body: Or }",
        ];
        for i in 0..v.len() {
            let s = serde_json::to_string(&v[i]).unwrap();
            assert_eq!(s, sers[i]);
            assert_eq!(serde_json::from_str::<CustomOperation>(&s).unwrap(), v[i]);
            assert_eq!(v, v.clone());
            assert_eq!(format!("{:?}", v[i]), debugs[i]);
        }
        // A type name without a registered `CustomOperationBody` must not deserialize.
        assert!(serde_json::from_str::<CustomOperation>(
            "{\"body\":{\"type\":\"InvalidCustomOperation\"}}"
        )
        .is_err());
    }

    // Evaluates the instantiated `Not` on a scalar bit and on a 3x3 bit array.
    #[test]
    fn test_not() {
        || -> Result<()> {
            let c = create_context()?;
            let g = c.create_graph()?;
            let i = g.input(scalar_type(BIT))?;
            let o = g.custom_op(CustomOperation::new(Not {}), vec![i])?;
            g.set_output_node(o)?;
            g.finalize()?;
            c.set_main_graph(g.clone())?;
            c.finalize()?;
            let mapped_c = run_instantiation_pass(c)?;
            for x in vec![0, 1] {
                let result = random_evaluate(
                    mapped_c.mappings.get_graph(g.clone()),
                    vec![Value::from_scalar(x, BIT)?],
                )?;
                let result = result.to_u8(BIT)?;
                assert_eq!(result, !(x != 0) as u8);
            }
            Ok(())
        }()
        .unwrap();
        || -> Result<()> {
            // Array input: every element must be flipped.
            let c = create_context()?;
            let g = c.create_graph()?;
            let i = g.input(array_type(vec![3, 3], BIT))?;
            let o = g.custom_op(CustomOperation::new(Not {}), vec![i])?;
            g.set_output_node(o)?;
            g.finalize()?;
            c.set_main_graph(g.clone())?;
            c.finalize()?;
            let mapped_c = run_instantiation_pass(c)?;
            let result = random_evaluate(
                mapped_c.mappings.get_graph(g.clone()),
                vec![Value::from_flattened_array(
                    &vec![0, 1, 1, 0, 1, 0, 0, 1, 1],
                    BIT,
                )?],
            )?;
            let result = result.to_flattened_array_u64(array_type(vec![3, 3], BIT))?;
            assert_eq!(result, vec![1, 0, 0, 1, 0, 1, 1, 0, 0]);
            Ok(())
        }()
        .unwrap();
    }

    // Evaluates the instantiated `Or` on all four combinations of scalar bits.
    #[test]
    fn test_or() {
        || -> Result<()> {
            let c = create_context()?;
            let g = c.create_graph()?;
            let i1 = g.input(scalar_type(BIT))?;
            let i2 = g.input(scalar_type(BIT))?;
            let o = g.custom_op(CustomOperation::new(Or {}), vec![i1, i2])?;
            g.set_output_node(o)?;
            g.finalize()?;
            c.set_main_graph(g.clone())?;
            c.finalize()?;
            let mapped_c = run_instantiation_pass(c)?;
            for x in vec![0, 1] {
                for y in vec![0, 1] {
                    let result = random_evaluate(
                        mapped_c.mappings.get_graph(g.clone()),
                        vec![Value::from_scalar(x, BIT)?, Value::from_scalar(y, BIT)?],
                    )?;
                    let result = result.to_u8(BIT)?;
                    assert_eq!(result, ((x != 0) || (y != 0)) as u8);
                }
            }
            Ok(())
        }()
        .unwrap();
    }

    // Test-only custom operations used to check the structure produced by the pass.
    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
    struct A {}

    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
    struct B {}

    #[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
    struct C {}

    // `A` forwards its first input through a `B` custom node.
    #[typetag::serde]
    impl CustomOperationBody for A {
        fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
            let g = context.create_graph()?;
            g.custom_op(
                CustomOperation::new(B {}),
                vec![g.input(arguments_types[0].clone())?],
            )?
            .set_as_output()?;
            g.finalize()?;
            Ok(g)
        }

        fn get_name(&self) -> String {
            "A".to_owned()
        }
    }

    // `B` is the identity on its first input. It also creates an unrelated extra graph in
    // the same context, to check that auxiliary graphs of an instantiation are copied too.
    #[typetag::serde]
    impl CustomOperationBody for B {
        fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
            let g = context.create_graph()?;
            let i = g.input(arguments_types[0].clone())?;
            g.set_output_node(i)?;
            g.finalize()?;
            let fake_g = context.create_graph()?;
            let i = fake_g.input(scalar_type(BIT))?;
            fake_g.set_output_node(i)?;
            fake_g.finalize()?;
            Ok(g)
        }

        fn get_name(&self) -> String {
            "B".to_owned()
        }
    }

    // `C` returns a single argument as is. Otherwise it splits the arguments in halves,
    // applies `C` recursively to each half, and packs both results into a tuple that is
    // annotated as an associative operation.
    #[typetag::serde]
    impl CustomOperationBody for C {
        fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
            let g = context.create_graph()?;
            let mut inputs = vec![];
            for t in &arguments_types {
                inputs.push(g.input(t.clone())?);
            }
            let o = if arguments_types.len() == 1 {
                inputs[0].clone()
            } else {
                let node = g.create_tuple(vec![
                    g.custom_op(
                        CustomOperation::new(C {}),
                        inputs[0..inputs.len() / 2].to_vec(),
                    )?,
                    g.custom_op(
                        CustomOperation::new(C {}),
                        inputs[inputs.len() / 2..inputs.len()].to_vec(),
                    )?,
                ])?;
                context.add_node_annotation(&node, NodeAnnotation::AssociativeOperation)?;
                node
            };
            g.set_output_node(o)?;
            g.finalize()?;
            Ok(g)
        }

        fn get_name(&self) -> String {
            "C".to_owned()
        }
    }

    // Structural checks of contexts produced by `run_instantiation_pass`.
    #[test]
    fn test_instantiation_pass() {
        // `A` calls `B`. Expected graph order: `B`'s body, `B`'s extra graph, `A`'s body
        // (calling `B`), then the main graph calling `A`. The node name "A" must be kept.
        || -> Result<()> {
            let c = simple_context(|g| {
                let i = g.input(scalar_type(BIT))?;
                let o = g.custom_op(CustomOperation::new(A {}), vec![i])?;
                o.set_name("A")?;
                Ok(o)
            })?;

            let processed_c = run_instantiation_pass(c)?.context;

            let expected_c = create_context()?;
            let g2 = expected_c.create_graph()?;
            let i = g2.input(scalar_type(BIT))?;
            g2.set_output_node(i)?;
            g2.set_name("__B::<bit>")?;
            g2.finalize()?;
            let g1 = expected_c.create_graph()?;
            let i = g1.input(scalar_type(BIT))?;
            g1.set_output_node(i)?;
            g1.finalize()?;
            let g3 = expected_c.create_graph()?;
            let i = g3.input(scalar_type(BIT))?;
            let o = g3.call(g2, vec![i])?;
            g3.set_output_node(o)?;
            g3.set_name("__A::<bit>")?;
            g3.finalize()?;
            let g4 = expected_c.create_graph()?;
            let i = g4.input(scalar_type(BIT))?;
            let o = g4.call(g3, vec![i])?;
            o.set_name("A")?;
            g4.set_output_node(o)?;
            g4.finalize()?;
            expected_c.set_main_graph(g4)?;
            expected_c.finalize()?;
            assert!(contexts_deep_equal(expected_c, processed_c));
            Ok(())
        }()
        .unwrap();

        // A regular `call` in the input context followed by a custom node. Instantiation
        // graphs are glued first, then the input context's graphs (sub-graph, then main).
        || -> Result<()> {
            let c = create_context()?;
            let sub_g = c.create_graph()?;
            let i = sub_g.input(scalar_type(BIT))?;
            sub_g.set_output_node(i)?;
            sub_g.finalize()?;
            let g = c.create_graph()?;
            let i = g.input(scalar_type(BIT))?;
            let ii = g.call(sub_g, vec![i])?;
            let o = g.custom_op(CustomOperation::new(B {}), vec![ii])?;
            o.set_name("B")?;
            g.set_output_node(o)?;
            g.finalize()?;
            c.set_main_graph(g)?;
            c.finalize()?;

            let processed_c = run_instantiation_pass(c)?.context;

            let expected_c = create_context()?;
            let g1 = expected_c.create_graph()?;
            let i = g1.input(scalar_type(BIT))?;
            g1.set_output_node(i)?;
            g1.set_name("__B::<bit>")?;
            g1.finalize()?;
            let g3 = expected_c.create_graph()?;
            let i = g3.input(scalar_type(BIT))?;
            g3.set_output_node(i)?;
            g3.finalize()?;
            let g2 = expected_c.create_graph()?;
            let i = g2.input(scalar_type(BIT))?;
            g2.set_output_node(i)?;
            g2.finalize()?;
            let g4 = expected_c.create_graph()?;
            let i = g4.input(scalar_type(BIT))?;
            let o = g4.call(g2, vec![i])?;
            let oo = g4.call(g1, vec![o])?;
            oo.set_name("B")?;
            g4.set_output_node(oo)?;
            g4.finalize()?;
            expected_c.set_main_graph(g4)?;
            expected_c.finalize()?;
            assert!(contexts_deep_equal(expected_c, processed_c));
            Ok(())
        }()
        .unwrap();

        // The pass is run on ten identically generated contexts with many recursive
        // `C` instantiations; all results must be deeply equal.
        || -> Result<()> {
            let generate_context = || -> Result<Context> {
                simple_context(|g| {
                    let i1 = g.input(array_type(vec![1, 5], BIT))?;
                    let i2 = g.input(array_type(vec![7, 5], BIT))?;
                    let i3 = g.input(array_type(vec![4, 3], BIT))?;
                    let i4 = g.input(array_type(vec![2, 3], BIT))?;
                    g.custom_op(CustomOperation::new(C {}), vec![i1, i2, i3, i4])
                })
            };
            let mut contexts = vec![];
            for _ in 0..10 {
                contexts.push(generate_context()?);
            }
            let mut instantiated_contexts = vec![];
            for context in contexts {
                instantiated_contexts.push(run_instantiation_pass(context)?.context);
            }
            for i in 0..instantiated_contexts.len() {
                assert!(contexts_deep_equal(
                    instantiated_contexts[0].clone(),
                    instantiated_contexts[i].clone()
                ));
            }
            Ok(())
        }()
        .unwrap();

        // Node annotations survive the pass. Seven `C` instantiations are created and the
        // four-argument one depends on all others, so it is glued last among them (index 6).
        // Its output tuple must still carry one annotation.
        || -> Result<()> {
            let context = simple_context(|g| {
                let i1 = g.input(array_type(vec![1, 5], BIT))?;
                let i2 = g.input(array_type(vec![7, 5], BIT))?;
                let i3 = g.input(array_type(vec![4, 3], BIT))?;
                let i4 = g.input(array_type(vec![2, 3], BIT))?;
                g.custom_op(CustomOperation::new(C {}), vec![i1, i2, i3, i4])
            })?;
            let new_context = run_instantiation_pass(context)?.context;
            assert_eq!(
                new_context
                    .get_node_annotations(new_context.get_graphs()[6].get_output_node()?)?
                    .len(),
                1
            );
            Ok(())
        }()
        .unwrap();

        // Structure of the instantiated `Not` on a `bit[5]` input.
        || -> Result<()> {
            let c = simple_context(|g| {
                let i1 = g.input(array_type(vec![5], BIT))?;
                g.custom_op(CustomOperation::new(Not {}), vec![i1])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let expected_c = create_context()?;
            let not_g = expected_c.create_graph()?;
            let i = not_g.input(array_type(vec![5], BIT))?;
            let c = not_g.ones(scalar_type(BIT))?;
            let o = not_g.add(i, c)?;
            not_g.set_output_node(o)?;
            not_g.set_name("__Not::<bit[5]>")?;
            not_g.finalize()?;
            let g = expected_c.create_graph()?;
            let i = g.input(array_type(vec![5], BIT))?;
            let o = g.call(not_g, vec![i])?;
            g.set_output_node(o)?;
            g.finalize()?;
            expected_c.set_main_graph(g)?;
            expected_c.finalize()?;
            assert!(contexts_deep_equal(mapped_c.context, expected_c));
            Ok(())
        }()
        .unwrap();

        // Structure of the instantiated `Or` on `bit[5]` and `bit[3, 5]` inputs. The
        // product has type `bit[3, 5]`, so the graph for `Not::<bit[3, 5]>` is shared by the
        // second negation and the final one.
        || -> Result<()> {
            let c = simple_context(|g| {
                let i1 = g.input(array_type(vec![5], BIT))?;
                let i2 = g.input(array_type(vec![3, 5], BIT))?;
                g.custom_op(CustomOperation::new(Or {}), vec![i1, i2])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let expected_c = create_context()?;
            let not_g_2 = expected_c.create_graph()?;
            let i = not_g_2.input(array_type(vec![3, 5], BIT))?;
            let c = not_g_2.ones(scalar_type(BIT))?;
            let o = not_g_2.add(i, c)?;
            not_g_2.set_output_node(o)?;
            not_g_2.set_name("__Not::<bit[3, 5]>")?;
            not_g_2.finalize()?;
            let not_g = expected_c.create_graph()?;
            let i = not_g.input(array_type(vec![5], BIT))?;
            let c = not_g.ones(scalar_type(BIT))?;
            let o = not_g.add(i, c)?;
            not_g.set_output_node(o)?;
            not_g.set_name("__Not::<bit[5]>")?;
            not_g.finalize()?;
            let or_g = expected_c.create_graph()?;
            let i1 = or_g.input(array_type(vec![5], BIT))?;
            let i2 = or_g.input(array_type(vec![3, 5], BIT))?;
            let i1_not = or_g.call(not_g, vec![i1])?;
            let i2_not = or_g.call(not_g_2.clone(), vec![i2])?;
            let i1_not_and_i2_not = or_g.multiply(i1_not, i2_not)?;
            let o = or_g.call(not_g_2, vec![i1_not_and_i2_not])?;
            or_g.set_output_node(o)?;
            or_g.set_name("__Or::<bit[5], bit[3, 5]>")?;
            or_g.finalize()?;
            let g = expected_c.create_graph()?;
            let i1 = g.input(array_type(vec![5], BIT))?;
            let i2 = g.input(array_type(vec![3, 5], BIT))?;
            let o = g.call(or_g, vec![i1, i2])?;
            g.set_output_node(o)?;
            g.finalize()?;
            expected_c.set_main_graph(g)?;
            expected_c.finalize()?;
            assert!(contexts_deep_equal(mapped_c.context, expected_c));
            Ok(())
        }()
        .unwrap();
    }
}