1use crate::custom_ops::{CustomOperation, CustomOperationBody, Not};
3use crate::data_types::{array_type, scalar_type, ArrayShape, Type, BIT};
4use crate::data_values::Value;
5use crate::errors::Result;
6use crate::graphs::*;
7use crate::ops::utils::pull_out_bits;
8use crate::ops::utils::{expand_dims, validate_arguments_in_broadcast_bit_ops};
9use std::cmp::max;
10
11use serde::{Deserialize, Serialize};
12
13use super::utils::unsqueeze;
14
15#[derive(Clone)]
107struct ComparisonResult {
108 a_equal_b: Node,
109 a: Node,
110}
111
112struct ShrinkResult {
113 shrinked: Option<ComparisonResult>,
114 remainder: Option<ComparisonResult>,
115}
116
117impl ComparisonResult {
118 fn from_a_b(a: Node, b: Node) -> Result<Self> {
119 let graph = a.get_graph();
120 let one = graph.ones(scalar_type(BIT))?;
121
122 let a_equal_b = a.add(b)?.add(one)?;
123 Ok(Self { a_equal_b, a })
124 }
125
126 fn join(&self, rhs: &Self) -> Result<Self> {
133 let graph = &self.a_equal_b.get_graph();
134 let one = graph.ones(scalar_type(BIT))?;
135
136 let a = self
137 .a
138 .multiply(rhs.a_equal_b.clone())?
139 .add(rhs.a.multiply(rhs.a_equal_b.add(one)?)?)?;
140
141 let a_equal_b = self.a_equal_b.multiply(rhs.a_equal_b.clone())?;
142
143 Ok(Self { a_equal_b, a })
144 }
145
146 fn shrink(&self) -> Result<ShrinkResult> {
150 let bit_len = self.a_equal_b.get_type()?.get_shape()[0] as i64;
151 let offset = bit_len % 2;
152 let remainder = if offset == 0 {
153 None
154 } else {
155 Some(Self {
156 a_equal_b: self.a_equal_b.get(vec![0])?,
157 a: self.a.get(vec![0])?,
158 })
159 };
160 let shrinked = if bit_len <= 1 {
161 None
162 } else {
163 let slice0 = self.sub_slice(offset, bit_len)?;
164 let slice1 = self.sub_slice(offset + 1, bit_len)?;
165 Some(slice0.join(&slice1)?)
166 };
167
168 Ok(ShrinkResult {
169 shrinked,
170 remainder,
171 })
172 }
173
174 fn sub_slice(&self, start_offset: i64, bit_len: i64) -> Result<Self> {
176 let get_slice = |node: &Node| {
184 node.get_slice(vec![SliceElement::SubArray(
185 Some(start_offset),
186 Some(bit_len),
187 Some(2),
188 )])
189 };
190 Ok(Self {
191 a_equal_b: get_slice(&self.a_equal_b)?,
192 a: get_slice(&self.a)?,
193 })
194 }
195
196 fn not_a(&self) -> Result<Node> {
197 let graph = self.a_equal_b.get_graph();
198 graph.custom_op(CustomOperation::new(Not {}), vec![self.a.clone()])
199 }
200
201 fn equal(&self) -> Result<Node> {
202 Ok(self.a_equal_b.clone())
203 }
204
205 fn not_equal(&self) -> Result<Node> {
206 let graph = self.a_equal_b.get_graph();
207 graph.custom_op(CustomOperation::new(Not {}), vec![self.equal()?])
208 }
209
210 fn less_than(&self) -> Result<Node> {
211 self.not_a()?.multiply(self.not_equal()?)
212 }
213
214 fn greater_than(&self) -> Result<Node> {
215 self.a.multiply(self.not_equal()?)
216 }
217
218 fn greater_than_equal_to(&self) -> Result<Node> {
219 let graph = self.a_equal_b.get_graph();
220 graph.custom_op(CustomOperation::new(Not {}), vec![self.less_than()?])
221 }
222
223 fn less_than_equal_to(&self) -> Result<Node> {
224 let graph = self.a_equal_b.get_graph();
225 graph.custom_op(CustomOperation::new(Not {}), vec![self.greater_than()?])
226 }
227}
228
229fn build_comparison_graph(a: Node, b: Node) -> Result<ComparisonResult> {
235 let mut to_shrink = ComparisonResult::from_a_b(a, b)?;
236 let mut remainders = vec![];
237 loop {
238 let shrink_res = to_shrink.shrink()?;
239
240 if let Some(remainder) = shrink_res.remainder {
241 remainders.push(remainder);
242 }
243
244 if let Some(shrinked) = shrink_res.shrinked {
245 to_shrink = shrinked;
246 } else {
247 break;
248 }
249 }
250
251 let mut res = remainders[0].clone();
252 for remainder in remainders[1..].iter() {
253 res = res.join(remainder)?;
254 }
255 Ok(res)
256}
257
258fn expand_to_same_dims(a: Node, b: Node) -> Result<(Node, Node)> {
272 let len_a = a.get_type()?.get_shape().len();
273 let len_b = b.get_type()?.get_shape().len();
274 let result_len = max(len_a, len_b);
275 let a = expand_dims(a, &(0..result_len - len_a).collect::<Vec<_>>())?;
276 let b = expand_dims(b, &(0..result_len - len_b).collect::<Vec<_>>())?;
277 Ok((a, b))
278}
279
280pub(super) fn flip_msb(ip: Node) -> Result<Node> {
332 ip.add(get_msb_flip_constant(
333 ip.get_type()?.get_shape(),
334 &ip.get_graph(),
335 )?)
336}
337
338fn get_msb_flip_constant(shape: ArrayShape, g: &Graph) -> Result<Node> {
339 let n = shape[shape.len() - 1] as usize;
340 let mut msb_mask = vec![0; n];
341 msb_mask[n - 1] = 1;
342 let mut msb_mask = g.constant(
343 array_type(vec![n as u64], BIT),
344 Value::from_flattened_array(&msb_mask, BIT)?,
345 )?;
346 while msb_mask.get_type()?.get_shape().len() < shape.len() {
347 msb_mask = unsqueeze(msb_mask, 0)?;
348 }
349 Ok(msb_mask)
350}
351
352fn preprocess_input(signed_comparison: bool, node: Node) -> Result<Node> {
357 let node = if signed_comparison {
358 flip_msb(node)?
359 } else {
360 node
361 };
362 pull_out_bits(node)
363}
364
365fn preprocess_inputs(signed_comparison: bool, a: Node, b: Node) -> Result<(Node, Node)> {
366 let (a, b) = expand_to_same_dims(a, b)?;
367 let a = preprocess_input(signed_comparison, a)?;
368 let b = preprocess_input(signed_comparison, b)?;
369 Ok((a, b))
370}
371
372fn validate_signed_arguments(custom_op_name: &str, arguments_types: Vec<Type>) -> Result<()> {
376 for (arg_id, arg_type) in arguments_types.iter().enumerate() {
377 if *arg_type.get_shape().last().unwrap() < 2 {
378 return Err(runtime_error!(
379 "{custom_op_name}: Signed input{arg_id} has less than 2 bits"
380 ));
381 }
382 }
383 Ok(())
384}
385
386fn instantiate_comparison_custom_op(
390 context: Context,
391 arguments_types: Vec<Type>,
392 signed_comparison: bool,
393 custom_op_name: &str,
394 post_process_result: impl FnOnce(&ComparisonResult) -> Result<Node>,
395) -> Result<Graph> {
396 validate_arguments_in_broadcast_bit_ops(arguments_types.clone(), custom_op_name)?;
397 if signed_comparison {
398 validate_signed_arguments(custom_op_name, arguments_types.clone())?;
399 }
400
401 let graph = context.create_graph()?;
402 let a = graph.input(arguments_types[0].clone())?;
403 let b = graph.input(arguments_types[1].clone())?;
404
405 let (a, b) = preprocess_inputs(signed_comparison, a, b)?;
406 let result = post_process_result(&build_comparison_graph(a, b)?)?;
407
408 graph.set_output_node(result)?;
409 graph.finalize()?;
410 Ok(graph)
411}
412
413#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
449pub struct GreaterThan {
450 pub signed_comparison: bool,
452}
453
454#[typetag::serde]
455impl CustomOperationBody for GreaterThan {
456 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
457 instantiate_comparison_custom_op(
458 context,
459 arguments_types,
460 self.signed_comparison,
461 &self.get_name(),
462 |res| res.greater_than(),
463 )
464 }
465
466 fn get_name(&self) -> String {
467 format!("GreaterThan(signed_comparison={})", self.signed_comparison)
468 }
469}
470
471#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
505pub struct NotEqual {}
506
507#[typetag::serde]
508impl CustomOperationBody for NotEqual {
509 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
510 instantiate_comparison_custom_op(context, arguments_types, false, &self.get_name(), |res| {
511 res.not_equal()
512 })
513 }
514
515 fn get_name(&self) -> String {
516 "NotEqual".to_owned()
517 }
518}
519
520#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
556pub struct LessThan {
557 pub signed_comparison: bool,
559}
560
561#[typetag::serde]
562impl CustomOperationBody for LessThan {
563 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
564 instantiate_comparison_custom_op(
565 context,
566 arguments_types,
567 self.signed_comparison,
568 &self.get_name(),
569 |res| res.less_than(),
570 )
571 }
572
573 fn get_name(&self) -> String {
574 format!("LessThan(signed_comparison={})", self.signed_comparison)
575 }
576}
577
578#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
614pub struct LessThanEqualTo {
615 pub signed_comparison: bool,
617}
618
619#[typetag::serde]
620impl CustomOperationBody for LessThanEqualTo {
621 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
622 instantiate_comparison_custom_op(
623 context,
624 arguments_types,
625 self.signed_comparison,
626 &self.get_name(),
627 |res| res.less_than_equal_to(),
628 )
629 }
630
631 fn get_name(&self) -> String {
632 format!(
633 "LessThanEqualTo(signed_comparison={})",
634 self.signed_comparison
635 )
636 }
637}
638
639#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
675pub struct GreaterThanEqualTo {
676 pub signed_comparison: bool,
678}
679
680#[typetag::serde]
681impl CustomOperationBody for GreaterThanEqualTo {
682 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
683 instantiate_comparison_custom_op(
684 context,
685 arguments_types,
686 self.signed_comparison,
687 &self.get_name(),
688 |res| res.greater_than_equal_to(),
689 )
690 }
691
692 fn get_name(&self) -> String {
693 format!(
694 "GreaterThanEqualTo(signed_comparison={})",
695 self.signed_comparison
696 )
697 }
698}
699
700#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
734pub struct Equal {}
735
736#[typetag::serde]
737impl CustomOperationBody for Equal {
738 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
739 instantiate_comparison_custom_op(context, arguments_types, false, &self.get_name(), |res| {
740 res.equal()
741 })
742 }
743
744 fn get_name(&self) -> String {
745 "Equal".to_owned()
746 }
747}
748
749#[cfg(test)]
750mod tests {
751 use super::*;
752
753 use crate::broadcast::broadcast_shapes;
754 use crate::custom_ops::run_instantiation_pass;
755 use crate::custom_ops::CustomOperation;
756 use crate::data_types::scalar_type;
757 use crate::data_types::tuple_type;
758 use crate::data_types::ArrayShape;
759 use crate::data_types::{
760 array_type, ScalarType, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8,
761 };
762 use crate::data_values::Value;
763 use crate::evaluators::random_evaluate;
764 use crate::graphs::create_context;
765 use crate::inline::inline_common::DepthOptimizationLevel;
766 use crate::inline::inline_ops::inline_operations;
767 use crate::inline::inline_ops::InlineConfig;
768 use crate::inline::inline_ops::InlineMode;
769
770 fn test_unsigned_greater_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
771 let c = create_context()?;
772 let g = c.create_graph()?;
773 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
774 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
775 let o = g.custom_op(
776 CustomOperation::new(GreaterThan {
777 signed_comparison: false,
778 }),
779 vec![i_a, i_b],
780 )?;
781 g.set_output_node(o)?;
782 g.finalize()?;
783 c.set_main_graph(g.clone())?;
784 c.finalize()?;
785 let mapped_c = run_instantiation_pass(c)?;
786 let v_a = Value::from_flattened_array(&a, BIT)?;
787 let v_b = Value::from_flattened_array(&b, BIT)?;
788 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
789 }
790
791 fn test_signed_greater_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
792 let c = create_context()?;
793 let g = c.create_graph()?;
794 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
795 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
796 let o = g.custom_op(
797 CustomOperation::new(GreaterThan {
798 signed_comparison: true,
799 }),
800 vec![i_a, i_b],
801 )?;
802 g.set_output_node(o)?;
803 g.finalize()?;
804 c.set_main_graph(g.clone())?;
805 c.finalize()?;
806 let mapped_c = run_instantiation_pass(c)?;
807 let v_a = Value::from_flattened_array(&a, BIT)?;
808 let v_b = Value::from_flattened_array(&b, BIT)?;
809 let random_val = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?;
810 let op = random_val.to_u8(BIT)?;
811 Ok(op)
812 }
813
814 fn test_not_equal_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
815 let c = create_context()?;
816 let g = c.create_graph()?;
817 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
818 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
819 let o = g.custom_op(CustomOperation::new(NotEqual {}), vec![i_a, i_b])?;
820 g.set_output_node(o)?;
821 g.finalize()?;
822 c.set_main_graph(g.clone())?;
823 c.finalize()?;
824 let mapped_c = run_instantiation_pass(c)?;
825 let v_a = Value::from_flattened_array(&a, BIT)?;
826 let v_b = Value::from_flattened_array(&b, BIT)?;
827 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
828 }
829
830 fn get_u_scalar_type_from_bits(bit_size: u64) -> Result<ScalarType> {
832 match bit_size {
833 8 => Ok(UINT8),
834 16 => Ok(UINT16),
835 32 => Ok(UINT32),
836 64 => Ok(UINT64),
837 _ => Err(runtime_error!("Unsupported bit size")),
838 }
839 }
840
841 fn get_s_scalar_type_from_bits(bit_size: u64) -> Result<ScalarType> {
843 match bit_size {
844 8 => Ok(INT8),
845 16 => Ok(INT16),
846 32 => Ok(INT32),
847 64 => Ok(INT64),
848 _ => Err(runtime_error!("Unsupported bit size")),
849 }
850 }
851
852 fn test_unsigned_comparison_cust_op_for_vec_helper(
859 comparison_op: CustomOperation,
860 a: Vec<u64>,
861 b: Vec<u64>,
862 shape_a: ArrayShape,
863 shape_b: ArrayShape,
864 ) -> Result<Vec<u64>> {
865 let bit_vector_len_a = shape_a[shape_a.len() - 1];
866 let bit_vector_len_b = shape_b[shape_b.len() - 1];
867 let data_scalar_type_a = get_u_scalar_type_from_bits(bit_vector_len_a)?;
868 let data_scalar_type_b = get_u_scalar_type_from_bits(bit_vector_len_b)?;
869
870 let c = create_context()?;
871 let g = c.create_graph()?;
872 let i_va = g.input(array_type(shape_a.clone(), BIT))?;
873 let i_vb = g.input(array_type(shape_b.clone(), BIT))?;
874 let o = g.custom_op(comparison_op.clone(), vec![i_va, i_vb])?;
875 g.set_output_node(o)?;
876 g.finalize()?;
877 c.set_main_graph(g.clone())?;
878 c.finalize()?;
879 let mapped_c = run_instantiation_pass(c)?;
880 let v_a = Value::from_flattened_array(&a, data_scalar_type_a)?;
882 let v_b = Value::from_flattened_array(&b, data_scalar_type_b)?;
883 let broadcasted_output_shape = broadcast_shapes(
884 shape_a[0..(shape_a.len() - 1)].to_vec(),
885 shape_b[0..(shape_b.len() - 1)].to_vec(),
886 )?;
887
888 let result = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?
889 .to_flattened_array_u64(array_type(broadcasted_output_shape, BIT))?;
890 Ok(result)
891 }
892
893 fn test_signed_comparison_cust_op_for_vec_helper(
900 comparison_op: CustomOperation,
901 a: Vec<i64>,
902 b: Vec<i64>,
903 shape_a: ArrayShape,
904 shape_b: ArrayShape,
905 ) -> Result<Vec<u64>> {
906 let bit_vector_len_a = shape_a[shape_a.len() - 1];
907 let bit_vector_len_b = shape_b[shape_b.len() - 1];
908 let data_scalar_type_a = get_s_scalar_type_from_bits(bit_vector_len_a)?;
909 let data_scalar_type_b = get_s_scalar_type_from_bits(bit_vector_len_b)?;
910
911 let c = create_context()?;
912 let g = c.create_graph()?;
913 let i_va = g.input(array_type(shape_a.clone(), BIT))?;
914 let i_vb = g.input(array_type(shape_b.clone(), BIT))?;
915 let o = g.custom_op(comparison_op.clone(), vec![i_va, i_vb])?;
916 g.set_output_node(o)?;
917 g.finalize()?;
918 c.set_main_graph(g.clone())?;
919 c.finalize()?;
920 let mapped_c = run_instantiation_pass(c)?;
921 let v_a = Value::from_flattened_array(&a, data_scalar_type_a)?;
923 let v_b = Value::from_flattened_array(&b, data_scalar_type_b)?;
924 let broadcasted_output_shape = broadcast_shapes(
925 shape_a[0..(shape_a.len() - 1)].to_vec(),
926 shape_b[0..(shape_b.len() - 1)].to_vec(),
927 )?;
928
929 let result = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?
930 .to_flattened_array_u64(array_type(broadcasted_output_shape, BIT))?;
931 Ok(result)
932 }
933
934 fn test_unsigned_less_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
935 let c = create_context()?;
936 let g = c.create_graph()?;
937 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
938 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
939 let o = g.custom_op(
940 CustomOperation::new(LessThan {
941 signed_comparison: false,
942 }),
943 vec![i_a, i_b],
944 )?;
945 g.set_output_node(o)?;
946 g.finalize()?;
947 c.set_main_graph(g.clone())?;
948 c.finalize()?;
949 let mapped_c = run_instantiation_pass(c)?;
950 let v_a = Value::from_flattened_array(&a, BIT)?;
951 let v_b = Value::from_flattened_array(&b, BIT)?;
952 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
953 }
954
955 fn test_signed_less_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
956 let c = create_context()?;
957 let g = c.create_graph()?;
958 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
959 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
960 let o = g.custom_op(
961 CustomOperation::new(LessThan {
962 signed_comparison: true,
963 }),
964 vec![i_a, i_b],
965 )?;
966 g.set_output_node(o)?;
967 g.finalize()?;
968 c.set_main_graph(g.clone())?;
969 c.finalize()?;
970 let mapped_c = run_instantiation_pass(c)?;
971 let v_a = Value::from_flattened_array(&a, BIT)?;
972 let v_b = Value::from_flattened_array(&b, BIT)?;
973 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
974 }
975
976 fn test_unsigned_less_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
977 let c = create_context()?;
978 let g = c.create_graph()?;
979 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
980 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
981 let o = g.custom_op(
982 CustomOperation::new(LessThanEqualTo {
983 signed_comparison: false,
984 }),
985 vec![i_a, i_b],
986 )?;
987 g.set_output_node(o)?;
988 g.finalize()?;
989 c.set_main_graph(g.clone())?;
990 c.finalize()?;
991 let mapped_c = run_instantiation_pass(c)?;
992 let v_a = Value::from_flattened_array(&a, BIT)?;
993 let v_b = Value::from_flattened_array(&b, BIT)?;
994 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
995 }
996
997 fn test_signed_less_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
998 let c = create_context()?;
999 let g = c.create_graph()?;
1000 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1001 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1002 let o = g.custom_op(
1003 CustomOperation::new(LessThanEqualTo {
1004 signed_comparison: true,
1005 }),
1006 vec![i_a, i_b],
1007 )?;
1008 g.set_output_node(o)?;
1009 g.finalize()?;
1010 c.set_main_graph(g.clone())?;
1011 c.finalize()?;
1012 let mapped_c = run_instantiation_pass(c)?;
1013 let v_a = Value::from_flattened_array(&a, BIT)?;
1014 let v_b = Value::from_flattened_array(&b, BIT)?;
1015 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1016 }
1017
1018 fn test_unsigned_greater_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
1019 let c = create_context()?;
1020 let g = c.create_graph()?;
1021 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1022 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1023 let o = g.custom_op(
1024 CustomOperation::new(GreaterThanEqualTo {
1025 signed_comparison: false,
1026 }),
1027 vec![i_a, i_b],
1028 )?;
1029 g.set_output_node(o)?;
1030 g.finalize()?;
1031 c.set_main_graph(g.clone())?;
1032 c.finalize()?;
1033 let mapped_c = run_instantiation_pass(c)?;
1034 let v_a = Value::from_flattened_array(&a, BIT)?;
1035 let v_b = Value::from_flattened_array(&b, BIT)?;
1036 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1037 }
1038
1039 fn test_signed_greater_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
1040 let c = create_context()?;
1041 let g = c.create_graph()?;
1042 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1043 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1044 let o = g.custom_op(
1045 CustomOperation::new(GreaterThanEqualTo {
1046 signed_comparison: true,
1047 }),
1048 vec![i_a, i_b],
1049 )?;
1050 g.set_output_node(o)?;
1051 g.finalize()?;
1052 c.set_main_graph(g.clone())?;
1053 c.finalize()?;
1054 let mapped_c = run_instantiation_pass(c)?;
1055 let v_a = Value::from_flattened_array(&a, BIT)?;
1056 let v_b = Value::from_flattened_array(&b, BIT)?;
1057 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1058 }
1059
1060 fn test_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
1061 let c = create_context()?;
1062 let g = c.create_graph()?;
1063 let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1064 let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1065 let o = g.custom_op(CustomOperation::new(Equal {}), vec![i_a, i_b])?;
1066 g.set_output_node(o)?;
1067 g.finalize()?;
1068 c.set_main_graph(g.clone())?;
1069 c.finalize()?;
1070 let mapped_c = run_instantiation_pass(c)?;
1071 let v_a = Value::from_flattened_array(&a, BIT)?;
1072 let v_b = Value::from_flattened_array(&b, BIT)?;
1073 Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1074 }
1075
1076 #[test]
1077 fn test_greater_than_cust_op() {
1078 || -> Result<()> {
1079 assert_eq!(
1080 test_unsigned_greater_than_cust_op_helper(vec![0], vec![0])?,
1081 0
1082 );
1083 assert_eq!(
1084 test_unsigned_greater_than_cust_op_helper(vec![0], vec![1])?,
1085 0
1086 );
1087 assert_eq!(
1088 test_unsigned_greater_than_cust_op_helper(vec![1], vec![0])?,
1089 1
1090 );
1091 assert_eq!(
1092 test_unsigned_greater_than_cust_op_helper(vec![1], vec![1])?,
1093 0
1094 );
1095 Ok(())
1096 }()
1097 .unwrap();
1098 }
1099
1100 #[test]
1101 fn test_signed_greater_than_cust_op() {
1102 || -> Result<()> {
1103 assert_eq!(
1105 test_signed_greater_than_cust_op_helper(vec![0, 0], vec![0, 0])?,
1106 0
1107 );
1108 assert_eq!(
1109 test_signed_greater_than_cust_op_helper(vec![0, 0], vec![1, 0])?,
1110 0
1111 );
1112 assert_eq!(
1113 test_signed_greater_than_cust_op_helper(vec![1, 0], vec![0, 0])?,
1114 1
1115 );
1116 assert_eq!(
1117 test_signed_greater_than_cust_op_helper(vec![1, 0], vec![1, 0])?,
1118 0
1119 );
1120 assert_eq!(
1122 test_signed_greater_than_cust_op_helper(vec![0, 1], vec![0, 1])?,
1123 0
1124 );
1125 assert_eq!(
1126 test_signed_greater_than_cust_op_helper(vec![0, 1], vec![1, 1])?,
1127 0
1128 );
1129 assert_eq!(
1130 test_signed_greater_than_cust_op_helper(vec![1, 1], vec![0, 1])?,
1131 1
1132 );
1133 assert_eq!(
1134 test_signed_greater_than_cust_op_helper(vec![1, 1], vec![1, 1])?,
1135 0
1136 );
1137 assert_eq!(
1139 test_signed_greater_than_cust_op_helper(vec![0, 1], vec![0, 0])?,
1140 0
1141 );
1142 assert_eq!(
1143 test_signed_greater_than_cust_op_helper(vec![0, 0], vec![0, 1])?,
1144 1
1145 );
1146 assert_eq!(
1147 test_signed_greater_than_cust_op_helper(vec![0, 1], vec![1, 0])?,
1148 0
1149 );
1150 assert_eq!(
1151 test_signed_greater_than_cust_op_helper(vec![0, 0], vec![1, 1])?,
1152 1
1153 );
1154 assert_eq!(
1155 test_signed_greater_than_cust_op_helper(vec![1, 1], vec![0, 0])?,
1156 0
1157 );
1158 assert_eq!(
1159 test_signed_greater_than_cust_op_helper(vec![1, 0], vec![0, 1])?,
1160 1
1161 );
1162 assert_eq!(
1163 test_signed_greater_than_cust_op_helper(vec![1, 1], vec![1, 0])?,
1164 0
1165 );
1166 assert_eq!(
1167 test_signed_greater_than_cust_op_helper(vec![1, 0], vec![1, 1])?,
1168 1
1169 );
1170 Ok(())
1171 }()
1172 .unwrap();
1173 }
1174
1175 #[test]
1176 fn test_unsigned_less_than_cust_op() {
1177 || -> Result<()> {
1178 assert_eq!(test_unsigned_less_than_cust_op_helper(vec![0], vec![0])?, 0);
1179 assert_eq!(test_unsigned_less_than_cust_op_helper(vec![0], vec![1])?, 1);
1180 assert_eq!(test_unsigned_less_than_cust_op_helper(vec![1], vec![0])?, 0);
1181 assert_eq!(test_unsigned_less_than_cust_op_helper(vec![1], vec![1])?, 0);
1182 Ok(())
1183 }()
1184 .unwrap();
1185 }
1186
1187 #[test]
1188 fn test_signed_less_than_cust_op() {
1189 || -> Result<()> {
1190 assert_eq!(
1192 test_signed_less_than_cust_op_helper(vec![0, 0], vec![0, 0])?,
1193 0
1194 );
1195 assert_eq!(
1196 test_signed_less_than_cust_op_helper(vec![0, 0], vec![1, 0])?,
1197 1
1198 );
1199 assert_eq!(
1200 test_signed_less_than_cust_op_helper(vec![1, 0], vec![0, 0])?,
1201 0
1202 );
1203 assert_eq!(
1204 test_signed_less_than_cust_op_helper(vec![1, 0], vec![1, 0])?,
1205 0
1206 );
1207 assert_eq!(
1209 test_signed_less_than_cust_op_helper(vec![0, 1], vec![0, 1])?,
1210 0
1211 );
1212 assert_eq!(
1213 test_signed_less_than_cust_op_helper(vec![0, 1], vec![1, 1])?,
1214 1
1215 );
1216 assert_eq!(
1217 test_signed_less_than_cust_op_helper(vec![1, 1], vec![0, 1])?,
1218 0
1219 );
1220 assert_eq!(
1221 test_signed_less_than_cust_op_helper(vec![1, 1], vec![1, 1])?,
1222 0
1223 );
1224 assert_eq!(
1226 test_signed_less_than_cust_op_helper(vec![0, 1], vec![0, 0])?,
1227 1
1228 );
1229 assert_eq!(
1230 test_signed_less_than_cust_op_helper(vec![0, 0], vec![0, 1])?,
1231 0
1232 );
1233 assert_eq!(
1234 test_signed_less_than_cust_op_helper(vec![0, 1], vec![1, 0])?,
1235 1
1236 );
1237 assert_eq!(
1238 test_signed_less_than_cust_op_helper(vec![0, 0], vec![1, 1])?,
1239 0
1240 );
1241 assert_eq!(
1242 test_signed_less_than_cust_op_helper(vec![1, 1], vec![0, 0])?,
1243 1
1244 );
1245 assert_eq!(
1246 test_signed_less_than_cust_op_helper(vec![1, 0], vec![0, 1])?,
1247 0
1248 );
1249 assert_eq!(
1250 test_signed_less_than_cust_op_helper(vec![1, 1], vec![1, 0])?,
1251 1
1252 );
1253 assert_eq!(
1254 test_signed_less_than_cust_op_helper(vec![1, 0], vec![1, 1])?,
1255 0
1256 );
1257 Ok(())
1258 }()
1259 .unwrap();
1260 }
1261
1262 #[test]
1263 fn test_unsigned_less_than_or_eq_to_cust_op() {
1264 || -> Result<()> {
1265 assert_eq!(
1266 test_unsigned_less_than_equal_to_cust_op_helper(vec![0], vec![0])?,
1267 1
1268 );
1269 assert_eq!(
1270 test_unsigned_less_than_equal_to_cust_op_helper(vec![0], vec![1])?,
1271 1
1272 );
1273 assert_eq!(
1274 test_unsigned_less_than_equal_to_cust_op_helper(vec![1], vec![0])?,
1275 0
1276 );
1277 assert_eq!(
1278 test_unsigned_less_than_equal_to_cust_op_helper(vec![1], vec![1])?,
1279 1
1280 );
1281 Ok(())
1282 }()
1283 .unwrap();
1284 }
1285
1286 #[test]
1287 fn test_signed_less_than_or_eq_to_cust_op() {
1288 || -> Result<()> {
1289 assert_eq!(
1291 test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 0])?,
1292 1
1293 );
1294 assert_eq!(
1295 test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 0])?,
1296 1
1297 );
1298 assert_eq!(
1299 test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 0])?,
1300 0
1301 );
1302 assert_eq!(
1303 test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 0])?,
1304 1
1305 );
1306 assert_eq!(
1308 test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 1])?,
1309 1
1310 );
1311 assert_eq!(
1312 test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 1])?,
1313 1
1314 );
1315 assert_eq!(
1316 test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 1])?,
1317 0
1318 );
1319 assert_eq!(
1320 test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 1])?,
1321 1
1322 );
1323 assert_eq!(
1325 test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 0])?,
1326 1
1327 );
1328 assert_eq!(
1329 test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 1])?,
1330 0
1331 );
1332 assert_eq!(
1333 test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 0])?,
1334 1
1335 );
1336 assert_eq!(
1337 test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 1])?,
1338 0
1339 );
1340 assert_eq!(
1341 test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 0])?,
1342 1
1343 );
1344 assert_eq!(
1345 test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 1])?,
1346 0
1347 );
1348 assert_eq!(
1349 test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 0])?,
1350 1
1351 );
1352 assert_eq!(
1353 test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 1])?,
1354 0
1355 );
1356 Ok(())
1357 }()
1358 .unwrap();
1359 }
1360
1361 #[test]
1362 fn test_unsigned_greater_than_or_eq_to_cust_op() {
1363 || -> Result<()> {
1364 assert_eq!(
1365 test_unsigned_greater_than_equal_to_cust_op_helper(vec![0], vec![0])?,
1366 1
1367 );
1368 assert_eq!(
1369 test_unsigned_greater_than_equal_to_cust_op_helper(vec![0], vec![1])?,
1370 0
1371 );
1372 assert_eq!(
1373 test_unsigned_greater_than_equal_to_cust_op_helper(vec![1], vec![0])?,
1374 1
1375 );
1376 assert_eq!(
1377 test_unsigned_greater_than_equal_to_cust_op_helper(vec![1], vec![1])?,
1378 1
1379 );
1380 Ok(())
1381 }()
1382 .unwrap();
1383 }
1384
1385 #[test]
1386 fn test_signed_greater_than_or_eq_to_cust_op() {
1387 || -> Result<()> {
1388 assert_eq!(
1390 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 0])?,
1391 1
1392 );
1393 assert_eq!(
1394 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 0])?,
1395 0
1396 );
1397 assert_eq!(
1398 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 0])?,
1399 1
1400 );
1401 assert_eq!(
1402 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 0])?,
1403 1
1404 );
1405 assert_eq!(
1407 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 1])?,
1408 1
1409 );
1410 assert_eq!(
1411 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 1])?,
1412 0
1413 );
1414 assert_eq!(
1415 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 1])?,
1416 1
1417 );
1418 assert_eq!(
1419 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 1])?,
1420 1
1421 );
1422 assert_eq!(
1424 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 0])?,
1425 0
1426 );
1427 assert_eq!(
1428 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 1])?,
1429 1
1430 );
1431 assert_eq!(
1432 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 0])?,
1433 0
1434 );
1435 assert_eq!(
1436 test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 1])?,
1437 1
1438 );
1439 assert_eq!(
1440 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 0])?,
1441 0
1442 );
1443 assert_eq!(
1444 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 1])?,
1445 1
1446 );
1447 assert_eq!(
1448 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 0])?,
1449 0
1450 );
1451 assert_eq!(
1452 test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 1])?,
1453 1
1454 );
1455 Ok(())
1456 }()
1457 .unwrap();
1458 }
1459
1460 #[test]
1461 fn test_not_equal_cust_op() {
1462 || -> Result<()> {
1463 assert_eq!(test_not_equal_cust_op_helper(vec![0], vec![0])?, 0);
1464 assert_eq!(test_not_equal_cust_op_helper(vec![0], vec![1])?, 1);
1465 assert_eq!(test_not_equal_cust_op_helper(vec![1], vec![0])?, 1);
1466 assert_eq!(test_not_equal_cust_op_helper(vec![1], vec![1])?, 0);
1467 Ok(())
1468 }()
1469 .unwrap();
1470 }
1471
1472 #[test]
1473 fn test_equal_to_cust_op() {
1474 || -> Result<()> {
1475 assert_eq!(test_equal_to_cust_op_helper(vec![0], vec![0])?, 1);
1476 assert_eq!(test_equal_to_cust_op_helper(vec![0], vec![1])?, 0);
1477 assert_eq!(test_equal_to_cust_op_helper(vec![1], vec![0])?, 0);
1478 assert_eq!(test_equal_to_cust_op_helper(vec![1], vec![1])?, 1);
1479 Ok(())
1480 }()
1481 .unwrap();
1482 }
1483
/// Exhaustively compares every pair of 3-bit values (0..8 x 0..8) through the
/// unsigned comparison helpers and the equality helpers, checking each result
/// against the corresponding native integer comparison.
#[test]
fn test_unsigned_multiple_bit_comparisons_cust_op() {
    let run = || -> Result<()> {
        // Splits a value into its three lowest bits, least significant bit first.
        let low_bits =
            |value: u8| -> Vec<u8> { vec![value & 1, (value & 2) >> 1, (value & 4) >> 2] };

        for i in 0..8 {
            let a = low_bits(i);
            for j in 0..8 {
                let b = low_bits(j);
                assert_eq!(
                    test_unsigned_greater_than_cust_op_helper(a.clone(), b.clone())?,
                    if i > j { 1 } else { 0 }
                );
                assert_eq!(
                    test_unsigned_less_than_cust_op_helper(a.clone(), b.clone())?,
                    if i < j { 1 } else { 0 }
                );
                assert_eq!(
                    test_unsigned_greater_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    if i >= j { 1 } else { 0 }
                );
                assert_eq!(
                    test_unsigned_less_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    if i <= j { 1 } else { 0 }
                );
                assert_eq!(
                    test_not_equal_cust_op_helper(a.clone(), b.clone())?,
                    if i != j { 1 } else { 0 }
                );
                assert_eq!(
                    test_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    if i == j { 1 } else { 0 }
                );
            }
        }
        Ok(())
    };
    run().unwrap();
}
1521
/// Exhaustively compares every pair of 3-bit values through the signed comparison
/// helpers. The reference result uses the 3-bit two's complement reading of each
/// value: values above 3 are mapped to `value - 8`.
#[test]
fn test_signed_multiple_bit_comparisons_cust_op() {
    let run = || -> Result<()> {
        // Splits a value into its three lowest bits, least significant bit first.
        let low_bits =
            |value: u8| -> Vec<u8> { vec![value & 1, (value & 2) >> 1, (value & 4) >> 2] };
        // Signed interpretation of a 3-bit pattern.
        let as_signed = |value: u8| -> i8 {
            if value > 3 {
                value as i8 - 8
            } else {
                value as i8
            }
        };

        for i in 0..8 {
            let a = low_bits(i);
            let s_i = as_signed(i);
            for j in 0..8 {
                let b = low_bits(j);
                let s_j = as_signed(j);
                assert_eq!(
                    test_signed_greater_than_cust_op_helper(a.clone(), b.clone())?,
                    if s_i > s_j { 1 } else { 0 }
                );
                assert_eq!(
                    test_signed_less_than_cust_op_helper(a.clone(), b.clone())?,
                    if s_i < s_j { 1 } else { 0 }
                );
                assert_eq!(
                    test_signed_greater_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    if s_i >= s_j { 1 } else { 0 }
                );
                assert_eq!(
                    test_signed_less_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    if s_i <= s_j { 1 } else { 0 }
                );
            }
        }
        Ok(())
    };
    run().unwrap();
}
1553
/// Verifies that unsigned `GreaterThan` and `NotEqual` reject malformed inputs.
///
/// For each operation, a series of graphs is built with unsuitable input types and
/// `custom_op` is required to return an error; then
/// `test_unsigned_comparison_cust_op_for_vec_helper` is required to fail on a few
/// vector/shape combinations.
#[test]
fn test_unsigned_malformed_basic_cust_ops() {
    let run = || -> Result<()> {
        let ops_under_test = vec![
            CustomOperation::new(GreaterThan {
                signed_comparison: false,
            }),
            CustomOperation::new(NotEqual {}),
        ];
        for cust_op in ops_under_test {
            // Input type lists for which adding the custom operation must fail.
            let rejected_inputs = vec![
                // Three inputs.
                vec![
                    array_type(vec![1], BIT),
                    array_type(vec![1], BIT),
                    array_type(vec![1], BIT),
                ],
                // Scalar first input.
                vec![scalar_type(BIT), array_type(vec![1], BIT)],
                // Tuple second input.
                vec![
                    array_type(vec![1], BIT),
                    tuple_type(vec![array_type(vec![1], BIT)]),
                ],
                // One input whose element type is not BIT.
                vec![array_type(vec![1], INT16), array_type(vec![1], BIT)],
                vec![array_type(vec![1], UINT16), array_type(vec![1], BIT)],
                vec![array_type(vec![1], BIT), array_type(vec![1], INT32)],
                vec![array_type(vec![1], BIT), array_type(vec![1], UINT32)],
                // BIT arrays with differing shapes.
                vec![array_type(vec![1], BIT), array_type(vec![9], BIT)],
                vec![array_type(vec![1], BIT), array_type(vec![1, 2], BIT)],
            ];
            for input_types in rejected_inputs {
                let c = create_context()?;
                let g = c.create_graph()?;
                let mut inputs = vec![];
                for input_type in input_types {
                    inputs.push(g.input(input_type)?);
                }
                assert!(g.custom_op(cust_op.clone(), inputs).is_err());
            }

            // Each entry: (a, b, fourth helper argument, fifth helper argument);
            // the last two are presumably the bit-array shapes of a and b.
            let long_b = vec![76, 20, 70, 249, 217, 190, 43, 83, 33710];
            let rejected_vectors = vec![
                (
                    vec![170, 120, 61, 85],
                    long_b.clone(),
                    vec![2, 2, 16],
                    vec![3, 3, 32],
                ),
                (vec![170], long_b.clone(), vec![2, 2, 16], vec![3, 3, 16]),
                (vec![], long_b, vec![0, 64], vec![3, 3, 64]),
                (vec![170, 200], vec![], vec![2, 1, 64], vec![2, 2, 1, 64]),
            ];
            for (a, b, shape_a, shape_b) in rejected_vectors {
                assert!(test_unsigned_comparison_cust_op_for_vec_helper(
                    cust_op.clone(),
                    a,
                    b,
                    shape_a,
                    shape_b
                )
                .is_err());
            }
        }

        Ok(())
    };
    run().unwrap();
}
1668
/// Verifies that signed `GreaterThan` rejects malformed inputs.
///
/// A series of graphs is built with unsuitable input types and `custom_op` is
/// required to return an error; then `test_signed_comparison_cust_op_for_vec_helper`
/// is required to fail on a few vector/shape combinations.
#[test]
fn test_signed_malformed_basic_cust_ops() {
    let run = || -> Result<()> {
        let ops_under_test = vec![CustomOperation::new(GreaterThan {
            signed_comparison: true,
        })];
        for cust_op in ops_under_test {
            // Input type pairs for which adding the custom operation must fail.
            let rejected_inputs = vec![
                vec![array_type(vec![1], BIT), array_type(vec![1, 1], BIT)],
                vec![array_type(vec![1, 1], BIT), array_type(vec![1], BIT)],
                vec![array_type(vec![1, 64], BIT), array_type(vec![1], BIT)],
                // Scalar first input.
                vec![scalar_type(BIT), array_type(vec![1], BIT)],
                // One input whose element type is not BIT.
                vec![array_type(vec![1], UINT16), array_type(vec![1], BIT)],
                vec![array_type(vec![1], BIT), array_type(vec![1], INT32)],
                vec![array_type(vec![1], BIT), array_type(vec![1], UINT32)],
                // BIT arrays with differing shapes.
                vec![array_type(vec![1], BIT), array_type(vec![9], BIT)],
                vec![array_type(vec![1, 2, 3], BIT), array_type(vec![9], BIT)],
                vec![array_type(vec![1], BIT), array_type(vec![1, 2], BIT)],
            ];
            for input_types in rejected_inputs {
                let c = create_context()?;
                let g = c.create_graph()?;
                let mut inputs = vec![];
                for input_type in input_types {
                    inputs.push(g.input(input_type)?);
                }
                assert!(g.custom_op(cust_op.clone(), inputs).is_err());
            }

            // Each entry: (a, b, fourth helper argument, fifth helper argument);
            // the last two are presumably the bit-array shapes of a and b.
            let rejected_vectors = vec![
                (
                    vec![170, 120, 61, 85],
                    vec![
                        -1176658021,
                        -301476304,
                        788180273,
                        -1085188538,
                        -1358798926,
                        -120286105,
                        -1300942710,
                        -389618936,
                        258721418,
                    ],
                    vec![2, 2, 16],
                    vec![3, 3, 32],
                ),
                (
                    vec![-14735],
                    vec![
                        16490, -10345, -31409, 2787, -15039, 26085, 7881, 32423, -23915,
                    ],
                    vec![2, 2, 16],
                    vec![3, 3, 16],
                ),
                (
                    vec![],
                    vec![
                        -2600362169875399934,
                        6278463339984150730,
                        -2962726308672949899,
                        3404980137287029349,
                    ],
                    vec![0, 64],
                    vec![2, 2, 64],
                ),
                (
                    vec![-2600362169875399934, 6278463339984150730],
                    vec![],
                    vec![2, 1, 64],
                    vec![2, 2, 1, 64],
                ),
            ];
            for (a, b, shape_a, shape_b) in rejected_vectors {
                assert!(test_signed_comparison_cust_op_for_vec_helper(
                    cust_op.clone(),
                    a,
                    b,
                    shape_a,
                    shape_b
                )
                .is_err());
            }
        }

        Ok(())
    };
    run().unwrap();
}
1802
/// Runs the unsigned comparison operations, `NotEqual` and `Equal` through
/// `test_unsigned_comparison_cust_op_for_vec_helper` on a fixed list of inputs and
/// compares every result with its expected output vector.
#[test]
fn test_unsigned_vector_comparisons() {
    let run = || -> Result<()> {
        // Shorthand constructors for the operations under test.
        let gt = || {
            CustomOperation::new(GreaterThan {
                signed_comparison: false,
            })
        };
        let lt = || {
            CustomOperation::new(LessThan {
                signed_comparison: false,
            })
        };
        let ge = || {
            CustomOperation::new(GreaterThanEqualTo {
                signed_comparison: false,
            })
        };
        let le = || {
            CustomOperation::new(LessThanEqualTo {
                signed_comparison: false,
            })
        };
        let ne = || CustomOperation::new(NotEqual {});
        let eq = || CustomOperation::new(Equal {});

        // Values at both ends of the 64-bit unsigned range.
        let extremes = vec![0, 1, 18446744073709551614, 18446744073709551615];

        // Each case: (operation, a, b, fourth helper argument, fifth helper argument,
        // expected output). The fourth and fifth arguments are presumably the
        // bit-array shapes of a and b.
        let cases = vec![
            (
                gt(),
                vec![170, 120, 61, 85],
                vec![
                    76, 20, 70, 249, 217, 190, 43, 83, 33710, 27637, 43918, 38683,
                ],
                vec![2, 2, 64],
                vec![3, 2, 2, 64],
                vec![1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            ),
            (
                gt(),
                vec![170, 120, 61, 85, 75, 149, 50, 54, 8811, 29720, 1009, 33126],
                vec![76, 20, 70, 249, 217, 190],
                vec![2, 3, 2, 32],
                vec![3, 2, 32],
                vec![1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
            ),
            (
                gt(),
                vec![170, 120, 61, 85, 75, 149, 50, 54],
                vec![76, 20, 70, 249],
                vec![2, 2, 2, 16],
                vec![2, 2, 16],
                vec![1, 1, 0, 0, 0, 1, 0, 0],
            ),
            (
                gt(),
                vec![170, 120, 61, 85, 75, 149, 50, 54],
                vec![76, 20, 70, 249, 217, 190, 43, 83],
                vec![2, 2, 2, 64],
                vec![2, 2, 2, 64],
                vec![1, 1, 0, 0, 0, 0, 1, 0],
            ),
            (
                gt(),
                vec![170, 120, 61],
                vec![76, 20, 70],
                vec![3, 64],
                vec![3, 64],
                vec![1, 1, 0],
            ),
            (
                lt(),
                vec![170, 120, 61, 85, 75, 149],
                vec![76, 20, 70],
                vec![2, 3, 64],
                vec![3, 64],
                vec![0, 0, 1, 0, 0, 0],
            ),
            (
                le(),
                vec![170, 120, 61, 85, 75, 70, 50, 54, 8811, 29720, 1009, 33126],
                vec![76, 1009, 70],
                vec![2, 2, 3, 64],
                vec![3, 64],
                vec![0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0],
            ),
            (
                ge(),
                vec![170, 120, 61, 85, 75, 76, 50, 54],
                vec![76],
                vec![2, 2, 2, 64],
                vec![1, 64],
                vec![1, 1, 0, 1, 0, 1, 0, 0],
            ),
            (
                ge(),
                vec![170],
                vec![76],
                vec![1, 64],
                vec![1, 64],
                vec![1],
            ),
            (
                ge(),
                vec![76],
                vec![76, 170],
                vec![1, 64],
                vec![1, 2, 64],
                vec![1, 0],
            ),
            (
                ge(),
                vec![83, 172, 214, 2, 68],
                vec![83],
                vec![5, 8],
                vec![8],
                vec![1, 1, 1, 0, 0],
            ),
            (
                lt(),
                vec![2],
                vec![83, 1, 2, 100],
                vec![1, 32],
                vec![2, 2, 32],
                vec![1, 0, 0, 1],
            ),
            (
                le(),
                vec![83, 2],
                vec![83, 172, 214, 2, 68, 34, 87, 45, 83, 23],
                vec![2, 1, 64],
                vec![2, 5, 64],
                vec![1, 1, 1, 0, 0, 1, 1, 1, 1, 1],
            ),
            (
                ne(),
                vec![83, 2],
                vec![83, 172, 214, 2, 68, 2, 87, 45],
                vec![2, 1, 64],
                vec![2, 4, 64],
                vec![0, 1, 1, 1, 1, 0, 1, 1],
            ),
            (
                ne(),
                vec![4, 2],
                vec![83, 21],
                vec![1, 2, 64],
                vec![2, 1, 64],
                vec![1, 1, 1, 1],
            ),
            (
                ne(),
                vec![247, 170, 249, 162, 102, 243, 61, 203, 125],
                vec![247, 170, 249, 162, 102, 243, 61, 203, 125],
                vec![3, 3, 16],
                vec![3, 3, 16],
                vec![0, 0, 0, 0, 0, 0, 0, 0, 0],
            ),
            (
                eq(),
                vec![83, 2],
                vec![83, 172, 214, 2, 68, 2, 87, 45],
                vec![2, 1, 64],
                vec![2, 4, 64],
                vec![1, 0, 0, 0, 0, 1, 0, 0],
            ),
            (
                eq(),
                vec![4, 2],
                vec![83, 21],
                vec![1, 2, 64],
                vec![2, 1, 64],
                vec![0, 0, 0, 0],
            ),
            (
                eq(),
                vec![180, 16, 62, 141, 122, 217],
                vec![141, 122, 217, 100, 11, 29],
                vec![3, 2, 1, 16],
                vec![1, 2, 3, 16],
                vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            ),
            (
                gt(),
                extremes.clone(),
                extremes.clone(),
                vec![4, 1, 64],
                vec![1, 4, 64],
                vec![0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0],
            ),
            (
                ge(),
                extremes.clone(),
                extremes.clone(),
                vec![4, 1, 64],
                vec![1, 4, 64],
                vec![1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1],
            ),
            (
                ne(),
                extremes.clone(),
                extremes,
                vec![4, 1, 64],
                vec![1, 4, 64],
                vec![0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0],
            ),
        ];

        for (op, a, b, shape_a, shape_b, expected) in cases {
            assert_eq!(
                test_unsigned_comparison_cust_op_for_vec_helper(op, a, b, shape_a, shape_b)?,
                expected
            );
        }

        Ok(())
    };
    run().unwrap();
}
2135
/// Runs the signed comparison operations through
/// `test_signed_comparison_cust_op_for_vec_helper` on a fixed list of inputs and
/// compares every result with its expected output vector.
#[test]
fn test_signed_vector_comparisons() {
    let run = || -> Result<()> {
        // Shorthand constructors for the operations under test.
        let gt = || {
            CustomOperation::new(GreaterThan {
                signed_comparison: true,
            })
        };
        let lt = || {
            CustomOperation::new(LessThan {
                signed_comparison: true,
            })
        };
        let ge = || {
            CustomOperation::new(GreaterThanEqualTo {
                signed_comparison: true,
            })
        };
        let le = || {
            CustomOperation::new(LessThanEqualTo {
                signed_comparison: true,
            })
        };

        // Values at both ends of the 64-bit signed range and around zero,
        // in increasing order.
        let extremes = vec![
            -9223372036854775808,
            -9223372036854775807,
            -1,
            0,
            1,
            9223372036854775806,
            9223372036854775807,
        ];

        // Each case: (operation, a, b, fourth helper argument, fifth helper argument,
        // expected output). The fourth and fifth arguments are presumably the
        // bit-array shapes of a and b.
        let cases = vec![
            (
                gt(),
                extremes.clone(),
                extremes.clone(),
                vec![7, 1, 64],
                vec![1, 7, 64],
                // 49 expected values, written seven per line.
                vec![
                    0, 0, 0, 0, 0, 0, 0, //
                    1, 0, 0, 0, 0, 0, 0, //
                    1, 1, 0, 0, 0, 0, 0, //
                    1, 1, 1, 0, 0, 0, 0, //
                    1, 1, 1, 1, 0, 0, 0, //
                    1, 1, 1, 1, 1, 0, 0, //
                    1, 1, 1, 1, 1, 1, 0,
                ],
            ),
            (
                ge(),
                extremes.clone(),
                extremes,
                vec![7, 1, 64],
                vec![1, 7, 64],
                // 49 expected values, written seven per line.
                vec![
                    1, 0, 0, 0, 0, 0, 0, //
                    1, 1, 0, 0, 0, 0, 0, //
                    1, 1, 1, 0, 0, 0, 0, //
                    1, 1, 1, 1, 0, 0, 0, //
                    1, 1, 1, 1, 1, 0, 0, //
                    1, 1, 1, 1, 1, 1, 0, //
                    1, 1, 1, 1, 1, 1, 1,
                ],
            ),
            (
                gt(),
                vec![-6749, -1885, 7550, 9659],
                vec![
                    9918, 3462, -5690, 3436, 3214, -1733, 6171, 3148, -3534, 8282, -4904, -5976,
                ],
                vec![2, 2, 64],
                vec![3, 2, 2, 64],
                vec![0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1],
            ),
            (
                gt(),
                vec![
                    -48, -9935, -745, 2360, -4597, -5271, 5130, -2632, 3112, 8089, 8293, 6058,
                ],
                vec![2913, 7260, 1388, 6205, 1855, 3246],
                vec![2, 3, 2, 32],
                vec![3, 2, 32],
                vec![0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1],
            ),
            (
                gt(),
                vec![9838, -574, -4181, -8107, -2880, -2866, 2272, 3743],
                vec![626, 4664, 1490, -5118, 7485, 6160, 4221, 2092],
                vec![2, 2, 2, 64],
                vec![2, 2, 2, 64],
                vec![1, 0, 0, 0, 0, 0, 0, 1],
            ),
            (
                lt(),
                vec![-75, 95, -84, 67, -81, 14],
                vec![-78, 21, -66],
                vec![2, 3, 8],
                vec![3, 8],
                vec![0, 0, 1, 0, 1, 0],
            ),
            (
                le(),
                vec![-52, -119, 30, -24, -74, -45, 66, 110, 21, 1, 95, -66],
                vec![33, -78, 39],
                vec![2, 2, 3, 8],
                vec![3, 8],
                vec![1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1],
            ),
            (
                ge(),
                vec![-128, 127, 0, 1, 0, -128, 1, 127],
                vec![-128],
                vec![2, 2, 2, 8],
                vec![1, 8],
                vec![1, 1, 1, 1, 1, 1, 1, 1],
            ),
            (
                gt(),
                vec![-128, 127, 0, 1, 0, -128, 1, 127],
                vec![-128],
                vec![2, 2, 2, 8],
                vec![1, 8],
                vec![0, 1, 1, 1, 1, 0, 1, 1],
            ),
        ];

        for (op, a, b, shape_a, shape_b, expected) in cases {
            assert_eq!(
                test_signed_comparison_cust_op_for_vec_helper(op, a, b, shape_a, shape_b)?,
                expected
            );
        }

        Ok(())
    };
    run().unwrap();
}
2320
/// Checks that every comparison custom operation stays small after inlining.
///
/// For `Equal`, `NotEqual` and the signed and unsigned variants of the four ordering
/// comparisons, a graph comparing two arrays of 64 bits is built, passed through
/// `run_instantiation_pass` and `inline_operations` (depth-optimized inlining at the
/// default level), and the main graph of the result must contain at most 200 nodes.
///
/// # Errors
///
/// Propagates any error returned while building, instantiating or inlining a graph.
#[test]
fn test_comparison_graph_size() -> Result<()> {
    let mut custom_ops = vec![
        CustomOperation::new(Equal {}),
        CustomOperation::new(NotEqual {}),
    ];
    for &signed_comparison in [false, true].iter() {
        custom_ops.extend(vec![
            CustomOperation::new(GreaterThan { signed_comparison }),
            CustomOperation::new(LessThan { signed_comparison }),
            CustomOperation::new(GreaterThanEqualTo { signed_comparison }),
            CustomOperation::new(LessThanEqualTo { signed_comparison }),
        ]);
    }

    for custom_op in custom_ops {
        let c = create_context()?;
        let g = c.create_graph()?;
        let i_a = g.input(array_type(vec![64], BIT))?;
        let i_b = g.input(array_type(vec![64], BIT))?;
        let o = g.custom_op(custom_op, vec![i_a, i_b])?;
        g.set_output_node(o)?;
        g.finalize()?;

        // `g` is not needed afterwards, so it is handed over without cloning.
        c.set_main_graph(g)?;
        c.finalize()?;

        let inline_config = InlineConfig {
            default_mode: InlineMode::DepthOptimized(DepthOptimizationLevel::Default),
            ..Default::default()
        };
        let instantiated_context = run_instantiation_pass(c)?.get_context();
        let inlined_context = inline_operations(instantiated_context, inline_config)?;
        let num_nodes = inlined_context.get_main_graph()?.get_num_nodes();

        assert!(
            num_nodes <= 200,
            "inlined comparison graph has {} nodes, expected at most 200",
            num_nodes
        );
    }
    Ok(())
}
2359}