ciphercore_base/ops/
comparisons.rs

1//! Various comparison functions for signed and unsigned integers including greater-than, less-than, greater-than-equal-to, less-than-equal-to, equal, not-equal.
2use crate::custom_ops::{CustomOperation, CustomOperationBody, Not};
3use crate::data_types::{array_type, scalar_type, ArrayShape, Type, BIT};
4use crate::data_values::Value;
5use crate::errors::Result;
6use crate::graphs::*;
7use crate::ops::utils::pull_out_bits;
8use crate::ops::utils::{expand_dims, validate_arguments_in_broadcast_bit_ops};
9use std::cmp::max;
10
11use serde::{Deserialize, Serialize};
12
13use super::utils::unsqueeze;
14
/// All comparison operations are built on top of a general comparison graph.
///
/// For each pair of numbers being compared, we need to distinguish 3 possible
/// outcomes:
/// - the first number is smaller
/// - the numbers are equal
/// - the first number is larger
///
/// To represent them we use 2 bits, which gives 4 possible states. Four states
/// don't map nicely onto 3 outcomes, so one state is always redundant.
///
/// There are several possible ways to choose what those two bits represent. We
/// chose one that may look a little confusing, but which minimizes the
/// number of network rounds needed to compute the final result.
///
/// Let's say we want to compare numbers `a` and `b`. We can find a pair of bits
/// `(a', b')` such that the comparison result for `a` and `b` is the same as the
/// result for `a'` and `b'`:
/// - if `a < b`, `(a', b')` must be `(0, 1)`.
/// - if `a > b`, `(a', b')` must be `(1, 0)`.
/// - if `a == b`, `(a', b')` could be `(0, 0)` or `(1, 1)`.
///
/// This way of representing the state is nice, because when single bits are
/// compared we can use `a` and `b` themselves without any additional modifications.
///
/// In our algorithm we store the pair `(a' == b', a')` (the fields `a_equal_b` and `a`),
/// because it is easier to recompute, but still doesn't require network
/// communication for initialization.
///
/// When those two bits are known, it is possible to compute the results of all
/// other comparison operations.
///
/// # How is this graph computed?
///
/// - First, if signed numbers are compared, the most significant bit of every number
/// is inverted. It turns out that if we invert that bit and then compare the numbers as
/// usual unsigned numbers, the result is correct. See the [`flip_msb`] documentation
/// for more details about this fact.
///
/// - Second, the bits are pulled out (`pull_out_bits`) so that the innermost dimension,
/// which stores the bits of the numbers, becomes the outermost one. After that, each
/// component of the array corresponds to one specific bit.
///
/// - Third, we generate an instance of `ComparisonResult` for the individual bits.
///
/// - Finally, we iteratively shrink the `ComparisonResult` to get the result for whole
/// numbers instead of the result for each separate bit. On each iteration, we split the
/// result into odd-indexed and even-indexed components and combine them. If the number
/// of components is not divisible by two, we cut off the first component and later join
/// it to the accumulated result.
///
/// # Example
///
/// Let's compare two 5-bit unsigned integers 15 (= 8 + 4 + 2 + 1) and 20 (= 16 + 4).
///
/// | bits     | 0 | 1 | 2 | 3 | 4 |
/// |----------|---|---|---|---|---|
/// |    a     | 1 | 1 | 1 | 1 | 0 |
/// |    b     | 0 | 0 | 1 | 0 | 1 |
/// | a' == b' | 0 | 0 | 1 | 0 | 0 |
///
/// We have 5 components, so the 0-th one is saved to `res`, while the 1st and 2nd
/// are joined, as well as the 3rd and 4th. When components are joined, higher bits
/// have priority: if we already found that `a != b` based on the higher bits, we use
/// that result. Otherwise, we use the result from the lower bits.
///
/// | components | res | 1..2 | 3..4 |
/// |------------|-----|------|------|
/// |     a'     |  1  |   1  |   0  |
/// |  a' == b'  |  0  |   0  |   0  |
///
/// The number of components is now divisible by two, so we only join 1..2 and 3..4
/// and don't change `res`. We already know the result in the group 3..4,
/// so it is simply copied.
///
/// | components | res | 1..4 |
/// |------------|-----|------|
/// |     a'     |  1  |   0  |
/// |  a' == b'  |  0  |   0  |
///
/// Only one component is left, so it is joined to `res`.
///
/// | components | res |
/// |------------|-----|
/// |     a'     |  0  |
/// |  a' == b'  |  0  |
///
/// We know that `a' == b'` is `false` and `a'` is `0`. Based on that we can
/// compute the result of any comparison function.
///
/// For example, if we want to know whether `a < b`, we can compute `(not a') and (not (a' == b'))`.
///
#[derive(Clone)]
struct ComparisonResult {
    /// Bit `a' == b'` for every position: 1 if the compared (groups of) bits are equal.
    a_equal_b: Node,
    /// Bit `a'` for every position; only meaningful where `a_equal_b` is 0.
    a: Node,
}
111
/// Output of one step of `ComparisonResult::shrink`.
struct ShrinkResult {
    /// Pairwise join of the remaining components; `None` when at most one
    /// component was left to shrink.
    shrinked: Option<ComparisonResult>,
    /// The first (least significant) component, split off when the number of
    /// components is odd; `None` otherwise.
    remainder: Option<ComparisonResult>,
}
116
117impl ComparisonResult {
118    fn from_a_b(a: Node, b: Node) -> Result<Self> {
119        let graph = a.get_graph();
120        let one = graph.ones(scalar_type(BIT))?;
121
122        let a_equal_b = a.add(b)?.add(one)?;
123        Ok(Self { a_equal_b, a })
124    }
125
126    /// `rhs` has higher priority. So if `rhs.a_is_smaller` or `rhs.b_is_smaller`
127    /// is set to 1 for a specific position, this value is used. Otherwise, values
128    /// from `self` are used.
129    ///
130    /// Multiplication depth of formulas here is 1, which is important for performance
131    /// reasons.
132    fn join(&self, rhs: &Self) -> Result<Self> {
133        let graph = &self.a_equal_b.get_graph();
134        let one = graph.ones(scalar_type(BIT))?;
135
136        let a = self
137            .a
138            .multiply(rhs.a_equal_b.clone())?
139            .add(rhs.a.multiply(rhs.a_equal_b.add(one)?)?)?;
140
141        let a_equal_b = self.a_equal_b.multiply(rhs.a_equal_b.clone())?;
142
143        Ok(Self { a_equal_b, a })
144    }
145
146    /// Joins even-indexed and odd-indexed values
147    /// If the number of elements is not divisible by two,
148    /// the first element is returned in `ShrinkResult.remainder`.
149    fn shrink(&self) -> Result<ShrinkResult> {
150        let bit_len = self.a_equal_b.get_type()?.get_shape()[0] as i64;
151        let offset = bit_len % 2;
152        let remainder = if offset == 0 {
153            None
154        } else {
155            Some(Self {
156                a_equal_b: self.a_equal_b.get(vec![0])?,
157                a: self.a.get(vec![0])?,
158            })
159        };
160        let shrinked = if bit_len <= 1 {
161            None
162        } else {
163            let slice0 = self.sub_slice(offset, bit_len)?;
164            let slice1 = self.sub_slice(offset + 1, bit_len)?;
165            Some(slice0.join(&slice1)?)
166        };
167
168        Ok(ShrinkResult {
169            shrinked,
170            remainder,
171        })
172    }
173
174    /// Returns every second element starting from `start_offset`
175    fn sub_slice(&self, start_offset: i64, bit_len: i64) -> Result<Self> {
176        // TODO: at some point this could become the slowest part, as getting
177        // every second element is not efficient. If we have an efficient way to
178        // reorder elements of the array at the beginning, we potentially could
179        // do it an a way, such that later all splits will have the form
180        // [0..len/2] and [len/2..len].
181        // But currently the slowest part is `permuteAxes` and there is no good
182        // way of permutating the array, so not optimizing it right now.
183        let get_slice = |node: &Node| {
184            node.get_slice(vec![SliceElement::SubArray(
185                Some(start_offset),
186                Some(bit_len),
187                Some(2),
188            )])
189        };
190        Ok(Self {
191            a_equal_b: get_slice(&self.a_equal_b)?,
192            a: get_slice(&self.a)?,
193        })
194    }
195
196    fn not_a(&self) -> Result<Node> {
197        let graph = self.a_equal_b.get_graph();
198        graph.custom_op(CustomOperation::new(Not {}), vec![self.a.clone()])
199    }
200
201    fn equal(&self) -> Result<Node> {
202        Ok(self.a_equal_b.clone())
203    }
204
205    fn not_equal(&self) -> Result<Node> {
206        let graph = self.a_equal_b.get_graph();
207        graph.custom_op(CustomOperation::new(Not {}), vec![self.equal()?])
208    }
209
210    fn less_than(&self) -> Result<Node> {
211        self.not_a()?.multiply(self.not_equal()?)
212    }
213
214    fn greater_than(&self) -> Result<Node> {
215        self.a.multiply(self.not_equal()?)
216    }
217
218    fn greater_than_equal_to(&self) -> Result<Node> {
219        let graph = self.a_equal_b.get_graph();
220        graph.custom_op(CustomOperation::new(Not {}), vec![self.less_than()?])
221    }
222
223    fn less_than_equal_to(&self) -> Result<Node> {
224        let graph = self.a_equal_b.get_graph();
225        graph.custom_op(CustomOperation::new(Not {}), vec![self.greater_than()?])
226    }
227}
228
229/// See [`ComparisonResult`].
230///
231/// `a` and `b` should have type `Array` with bits pulled out to the outermost dimension.
232/// Inputs are interpreted as unsigned numbers. The number of bits should be the same
233/// in `a` and `b`.
234fn build_comparison_graph(a: Node, b: Node) -> Result<ComparisonResult> {
235    let mut to_shrink = ComparisonResult::from_a_b(a, b)?;
236    let mut remainders = vec![];
237    loop {
238        let shrink_res = to_shrink.shrink()?;
239
240        if let Some(remainder) = shrink_res.remainder {
241            remainders.push(remainder);
242        }
243
244        if let Some(shrinked) = shrink_res.shrinked {
245            to_shrink = shrinked;
246        } else {
247            break;
248        }
249    }
250
251    let mut res = remainders[0].clone();
252    for remainder in remainders[1..].iter() {
253        res = res.join(remainder)?;
254    }
255    Ok(res)
256}
257
258/// As we support broadcasting in comparison arguments, we need to make
259/// sure they are still broadcastable after bits are pulled out to the
260/// outermost dimension.
261///
262/// Say, we want to compare array of size `[2, 3, 64]` and array of size
263/// `[3, 64]`. This is a valid operation because `[3, 64]` could be broadcasted
264/// to `[2, 3, 64]`.
265///
266/// After pulling out bits, we get shapes `[64, 2, 3]` and `[64, 3]`, which are not
267/// broadcastable anymore.
268///
269/// To fix this, we convert `[3, 64]` into `[1, 3, 64]` first. After pulling out
270/// bits, shape is `[64, 1, 3]`, which could be broadcasted to `[64, 2, 3]`.
271fn expand_to_same_dims(a: Node, b: Node) -> Result<(Node, Node)> {
272    let len_a = a.get_type()?.get_shape().len();
273    let len_b = b.get_type()?.get_shape().len();
274    let result_len = max(len_a, len_b);
275    let a = expand_dims(a, &(0..result_len - len_a).collect::<Vec<_>>())?;
276    let b = expand_dims(b, &(0..result_len - len_b).collect::<Vec<_>>())?;
277    Ok((a, b))
278}
279
280/// - This function flips all values of the input Array's last component,
281/// which correspond to MSB bit (after `pull_out_bits`), to enable the
282/// signed comparisons.
283///
284/// - Why bit flip is sufficient for obtaining signed comparisons given
285/// unsigned comparison functionality? Please see below example:
286///
287///
288/// |sign bit MSB|  b1|  b0| unsigned value|   signed value|
289/// |------------|----|----|---------------|---------------|
290/// |           0|   0|   0|              0|              0|
291/// |           0|   0|   1|              1|              1|
292/// |           0|   1|   0|              2|              2|
293/// |           0|   1|   1|              3|              3|
294/// |           1|   0|   0|              4|             -4|
295/// |           1|   0|   1|              5|             -3|
296/// |           1|   1|   0|              6|             -2|
297/// |           1|   1|   1|              7|             -1|
298/// --------------------------------------------------------
299///
300/// - From the table, it can be seen that simply flipping the Most Significant Bit
301/// followed by doing unsigned comparison operation can provide the result achieved
302/// by performing the signed operation before the flipping.
303///
304/// - e.g. For both positive inputs,
305/// (2, 3)->(010, 011)-FlipMSB->(110, 111)-unsignedGreaterThan-> false
306/// unsigned comparison over them gives signed result
307///
308/// - e.g. For positive and negative inputs,
309/// (3, -4)->(011, 100)-FlipMSB->(111, 000)-unsignedGreaterThan-> true
310///
311/// - e.g. For negative and positive inputs,
312/// (-3, 3)->(101, 011)-FlipMSB->(001, 111)-unsignedGreaterThan-> false
313///
314/// - e.g. For both negative inputs,
315/// (-3 > -4) -> (101>100) -flipMSB-> (001>000)-unsignedGreaterThan-> true
316///
317/// - Once the MSB bit is flipped, and reattached, unsigned operations can be
318/// done on signed, MSB flipped inputs to enable signed comparisons
319///
320/// There is also another way of thinking about why bit flip is enough.
321///
322/// Let `A` be an unsigned integer with the following binary representation
323/// `A = | a_(n-1) | ... | a_0 |`.
324/// Let `a` be a signed integer with the following binary representations
325/// `a = | sign_bit | a_(n-1) | ... | a_0 | = A - sign_bit * 2^n`.
326/// Flipping the sign bit and recasting to unsigned results in a shift by `2^n`, i.e.
327/// `flip_msb(a) = A + (1 - sign_bit) * 2^n = a + 2^n`.
328///
329/// Here we xor the MSB bit with 1 to flip it. It is more efficient than do slice + concat.
330/// We rely on broadcasting to avoid the huge constants in graph.
331pub(super) fn flip_msb(ip: Node) -> Result<Node> {
332    ip.add(get_msb_flip_constant(
333        ip.get_type()?.get_shape(),
334        &ip.get_graph(),
335    )?)
336}
337
338fn get_msb_flip_constant(shape: ArrayShape, g: &Graph) -> Result<Node> {
339    let n = shape[shape.len() - 1] as usize;
340    let mut msb_mask = vec![0; n];
341    msb_mask[n - 1] = 1;
342    let mut msb_mask = g.constant(
343        array_type(vec![n as u64], BIT),
344        Value::from_flattened_array(&msb_mask, BIT)?,
345    )?;
346    while msb_mask.get_type()?.get_shape().len() < shape.len() {
347        msb_mask = unsqueeze(msb_mask, 0)?;
348    }
349    Ok(msb_mask)
350}
351
352/// This function pulls out bits to outermost dimension, and flips MSB for
353/// signed comparisons.
354///
355/// See [`flip_msb`] and [`ComparisonResult`] for details.
356fn preprocess_input(signed_comparison: bool, node: Node) -> Result<Node> {
357    let node = if signed_comparison {
358        flip_msb(node)?
359    } else {
360        node
361    };
362    pull_out_bits(node)
363}
364
365fn preprocess_inputs(signed_comparison: bool, a: Node, b: Node) -> Result<(Node, Node)> {
366    let (a, b) = expand_to_same_dims(a, b)?;
367    let a = preprocess_input(signed_comparison, a)?;
368    let b = preprocess_input(signed_comparison, b)?;
369    Ok((a, b))
370}
371
372/// This function validates if the `arguments_types` are suitable for the
373/// intended signed custom operation. E.g. there should be at least `2` bits
374/// i.e. ( magnitude + sign )
375fn validate_signed_arguments(custom_op_name: &str, arguments_types: Vec<Type>) -> Result<()> {
376    for (arg_id, arg_type) in arguments_types.iter().enumerate() {
377        if *arg_type.get_shape().last().unwrap() < 2 {
378            return Err(runtime_error!(
379                "{custom_op_name}: Signed input{arg_id} has less than 2 bits"
380            ));
381        }
382    }
383    Ok(())
384}
385
386/// This function first builds a generic comparison graph, and then
387/// calls `post_process_result` to obtain the final result.
388/// This functions handles pre-processing of input types to support vectorized inputs.
389fn instantiate_comparison_custom_op(
390    context: Context,
391    arguments_types: Vec<Type>,
392    signed_comparison: bool,
393    custom_op_name: &str,
394    post_process_result: impl FnOnce(&ComparisonResult) -> Result<Node>,
395) -> Result<Graph> {
396    validate_arguments_in_broadcast_bit_ops(arguments_types.clone(), custom_op_name)?;
397    if signed_comparison {
398        validate_signed_arguments(custom_op_name, arguments_types.clone())?;
399    }
400
401    let graph = context.create_graph()?;
402    let a = graph.input(arguments_types[0].clone())?;
403    let b = graph.input(arguments_types[1].clone())?;
404
405    let (a, b) = preprocess_inputs(signed_comparison, a, b)?;
406    let result = post_process_result(&build_comparison_graph(a, b)?)?;
407
408    graph.set_output_node(result)?;
409    graph.finalize()?;
410    Ok(graph)
411}
412
/// A structure that defines the custom operation GreaterThan that compares arrays of binary strings elementwise as follows:
///
/// If a and b are two bitstrings, then GreaterThan(a,b) = 1 if a > b and 0 otherwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2]`.
///
/// To compare signed numbers, `signed_comparison` should be set to `true`. In that case the
/// last bit of every bitstring is treated as the sign bit, and inputs must have at least 2 bits.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New GreaterThan node
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::comparisons::GreaterThan;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(GreaterThan {signed_comparison: false}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct GreaterThan {
    /// Boolean value indicating whether the input bitstrings represent signed integers
    pub signed_comparison: bool,
}
453
454#[typetag::serde]
455impl CustomOperationBody for GreaterThan {
456    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
457        instantiate_comparison_custom_op(
458            context,
459            arguments_types,
460            self.signed_comparison,
461            &self.get_name(),
462            |res| res.greater_than(),
463        )
464    }
465
466    fn get_name(&self) -> String {
467        format!("GreaterThan(signed_comparison={})", self.signed_comparison)
468    }
469}
470
/// A structure that defines the custom operation NotEqual that compares arrays of binary strings elementwise as follows:
///
/// If a and b are two bitstrings, then NotEqual(a,b) = 1 if a != b and 0 otherwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2]`.
///
/// Inequality of bitstrings does not depend on whether they encode signed or unsigned
/// integers, so this operation has no `signed_comparison` flag.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New NotEqual node
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::comparisons::NotEqual;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(NotEqual {}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct NotEqual {}
506
507#[typetag::serde]
508impl CustomOperationBody for NotEqual {
509    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
510        instantiate_comparison_custom_op(context, arguments_types, false, &self.get_name(), |res| {
511            res.not_equal()
512        })
513    }
514
515    fn get_name(&self) -> String {
516        "NotEqual".to_owned()
517    }
518}
519
/// A structure that defines the custom operation LessThan that compares arrays of binary strings elementwise as follows:
///
/// If a and b are two bitstrings, then LessThan(a,b) = 1 if a < b and 0 otherwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2]`.
///
/// To compare signed numbers, `signed_comparison` should be set to `true`. In that case the
/// last bit of every bitstring is treated as the sign bit, and inputs must have at least 2 bits.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New LessThan node
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::comparisons::LessThan;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(LessThan {signed_comparison: true}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LessThan {
    /// Boolean value indicating whether the input bitstrings represent signed integers
    pub signed_comparison: bool,
}
560
561#[typetag::serde]
562impl CustomOperationBody for LessThan {
563    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
564        instantiate_comparison_custom_op(
565            context,
566            arguments_types,
567            self.signed_comparison,
568            &self.get_name(),
569            |res| res.less_than(),
570        )
571    }
572
573    fn get_name(&self) -> String {
574        format!("LessThan(signed_comparison={})", self.signed_comparison)
575    }
576}
577
/// A structure that defines the custom operation LessThanEqualTo that compares arrays of binary strings elementwise as follows:
///
/// If a and b are two bitstrings, then LessThanEqualTo(a,b) = 1 if a <= b and 0 otherwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2]`.
///
/// To compare signed numbers, `signed_comparison` should be set to `true`. In that case the
/// last bit of every bitstring is treated as the sign bit, and inputs must have at least 2 bits.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New LessThanEqualTo node
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::comparisons::LessThanEqualTo;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(LessThanEqualTo {signed_comparison: true}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LessThanEqualTo {
    /// Boolean value indicating whether the input bitstrings represent signed integers
    pub signed_comparison: bool,
}
618
619#[typetag::serde]
620impl CustomOperationBody for LessThanEqualTo {
621    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
622        instantiate_comparison_custom_op(
623            context,
624            arguments_types,
625            self.signed_comparison,
626            &self.get_name(),
627            |res| res.less_than_equal_to(),
628        )
629    }
630
631    fn get_name(&self) -> String {
632        format!(
633            "LessThanEqualTo(signed_comparison={})",
634            self.signed_comparison
635        )
636    }
637}
638
/// A structure that defines the custom operation GreaterThanEqualTo that compares arrays of binary strings elementwise as follows:
///
/// If a and b are two bitstrings, then GreaterThanEqualTo(a,b) = 1 if a >= b and 0 otherwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2]`.
///
/// To compare signed numbers, `signed_comparison` should be set to `true`. In that case the
/// last bit of every bitstring is treated as the sign bit, and inputs must have at least 2 bits.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New GreaterThanEqualTo node
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::comparisons::GreaterThanEqualTo;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(GreaterThanEqualTo {signed_comparison: true}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct GreaterThanEqualTo {
    /// Boolean value indicating whether the input bitstrings represent signed integers
    pub signed_comparison: bool,
}
679
680#[typetag::serde]
681impl CustomOperationBody for GreaterThanEqualTo {
682    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
683        instantiate_comparison_custom_op(
684            context,
685            arguments_types,
686            self.signed_comparison,
687            &self.get_name(),
688            |res| res.greater_than_equal_to(),
689        )
690    }
691
692    fn get_name(&self) -> String {
693        format!(
694            "GreaterThanEqualTo(signed_comparison={})",
695            self.signed_comparison
696        )
697    }
698}
699
/// A structure that defines the custom operation Equal that compares arrays of binary strings elementwise as follows:
///
/// If a and b are two bitstrings, then Equal(a,b) = 1 if a = b and 0 otherwise.
///
/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2]`.
///
/// Equality of bitstrings does not depend on whether they encode signed or unsigned
/// integers, so this operation has no `signed_comparison` flag.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array or scalar
/// - Node containing a binary array or scalar
///
/// # Custom operation returns
///
/// New Equal node
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::comparisons::Equal;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 3], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.input(t.clone()).unwrap();
/// let n3 = g.custom_op(CustomOperation::new(Equal {}), vec![n1, n2]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Equal {}
735
736#[typetag::serde]
737impl CustomOperationBody for Equal {
738    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
739        instantiate_comparison_custom_op(context, arguments_types, false, &self.get_name(), |res| {
740            res.equal()
741        })
742    }
743
744    fn get_name(&self) -> String {
745        "Equal".to_owned()
746    }
747}
748
749#[cfg(test)]
750mod tests {
751    use super::*;
752
753    use crate::broadcast::broadcast_shapes;
754    use crate::custom_ops::run_instantiation_pass;
755    use crate::custom_ops::CustomOperation;
756    use crate::data_types::scalar_type;
757    use crate::data_types::tuple_type;
758    use crate::data_types::ArrayShape;
759    use crate::data_types::{
760        array_type, ScalarType, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8,
761    };
762    use crate::data_values::Value;
763    use crate::evaluators::random_evaluate;
764    use crate::graphs::create_context;
765    use crate::inline::inline_common::DepthOptimizationLevel;
766    use crate::inline::inline_ops::inline_operations;
767    use crate::inline::inline_ops::InlineConfig;
768    use crate::inline::inline_ops::InlineMode;
769
770    fn test_unsigned_greater_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
771        let c = create_context()?;
772        let g = c.create_graph()?;
773        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
774        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
775        let o = g.custom_op(
776            CustomOperation::new(GreaterThan {
777                signed_comparison: false,
778            }),
779            vec![i_a, i_b],
780        )?;
781        g.set_output_node(o)?;
782        g.finalize()?;
783        c.set_main_graph(g.clone())?;
784        c.finalize()?;
785        let mapped_c = run_instantiation_pass(c)?;
786        let v_a = Value::from_flattened_array(&a, BIT)?;
787        let v_b = Value::from_flattened_array(&b, BIT)?;
788        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
789    }
790
791    fn test_signed_greater_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
792        let c = create_context()?;
793        let g = c.create_graph()?;
794        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
795        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
796        let o = g.custom_op(
797            CustomOperation::new(GreaterThan {
798                signed_comparison: true,
799            }),
800            vec![i_a, i_b],
801        )?;
802        g.set_output_node(o)?;
803        g.finalize()?;
804        c.set_main_graph(g.clone())?;
805        c.finalize()?;
806        let mapped_c = run_instantiation_pass(c)?;
807        let v_a = Value::from_flattened_array(&a, BIT)?;
808        let v_b = Value::from_flattened_array(&b, BIT)?;
809        let random_val = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?;
810        let op = random_val.to_u8(BIT)?;
811        Ok(op)
812    }
813
814    fn test_not_equal_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
815        let c = create_context()?;
816        let g = c.create_graph()?;
817        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
818        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
819        let o = g.custom_op(CustomOperation::new(NotEqual {}), vec![i_a, i_b])?;
820        g.set_output_node(o)?;
821        g.finalize()?;
822        c.set_main_graph(g.clone())?;
823        c.finalize()?;
824        let mapped_c = run_instantiation_pass(c)?;
825        let v_a = Value::from_flattened_array(&a, BIT)?;
826        let v_b = Value::from_flattened_array(&b, BIT)?;
827        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
828    }
829
830    /// Given supported bit size returns unsigned ScalarType
831    fn get_u_scalar_type_from_bits(bit_size: u64) -> Result<ScalarType> {
832        match bit_size {
833            8 => Ok(UINT8),
834            16 => Ok(UINT16),
835            32 => Ok(UINT32),
836            64 => Ok(UINT64),
837            _ => Err(runtime_error!("Unsupported bit size")),
838        }
839    }
840
841    /// Given supported bit size returns signed ScalarType
842    fn get_s_scalar_type_from_bits(bit_size: u64) -> Result<ScalarType> {
843        match bit_size {
844            8 => Ok(INT8),
845            16 => Ok(INT16),
846            32 => Ok(INT32),
847            64 => Ok(INT64),
848            _ => Err(runtime_error!("Unsupported bit size")),
849        }
850    }
851
852    /// Inputs:
853    /// comparison_op
854    /// a: input assumed to pass as Vec<64>
855    /// b: input assumed to pass as Vec<64>
856    /// shape_a: intended shape for a within graph
857    /// shape_b: intended shape for b within graph
858    fn test_unsigned_comparison_cust_op_for_vec_helper(
859        comparison_op: CustomOperation,
860        a: Vec<u64>,
861        b: Vec<u64>,
862        shape_a: ArrayShape,
863        shape_b: ArrayShape,
864    ) -> Result<Vec<u64>> {
865        let bit_vector_len_a = shape_a[shape_a.len() - 1];
866        let bit_vector_len_b = shape_b[shape_b.len() - 1];
867        let data_scalar_type_a = get_u_scalar_type_from_bits(bit_vector_len_a)?;
868        let data_scalar_type_b = get_u_scalar_type_from_bits(bit_vector_len_b)?;
869
870        let c = create_context()?;
871        let g = c.create_graph()?;
872        let i_va = g.input(array_type(shape_a.clone(), BIT))?;
873        let i_vb = g.input(array_type(shape_b.clone(), BIT))?;
874        let o = g.custom_op(comparison_op.clone(), vec![i_va, i_vb])?;
875        g.set_output_node(o)?;
876        g.finalize()?;
877        c.set_main_graph(g.clone())?;
878        c.finalize()?;
879        let mapped_c = run_instantiation_pass(c)?;
880        // Restructure the input data
881        let v_a = Value::from_flattened_array(&a, data_scalar_type_a)?;
882        let v_b = Value::from_flattened_array(&b, data_scalar_type_b)?;
883        let broadcasted_output_shape = broadcast_shapes(
884            shape_a[0..(shape_a.len() - 1)].to_vec(),
885            shape_b[0..(shape_b.len() - 1)].to_vec(),
886        )?;
887
888        let result = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?
889            .to_flattened_array_u64(array_type(broadcasted_output_shape, BIT))?;
890        Ok(result)
891    }
892
893    /// Inputs:
894    /// comparison_op
895    /// a: input assumed to pass as Vec<64>
896    /// b: input assumed to pass as Vec<64>
897    /// shape_a: intended shape for a within graph
898    /// shape_b: intended shape for b within graph
899    fn test_signed_comparison_cust_op_for_vec_helper(
900        comparison_op: CustomOperation,
901        a: Vec<i64>,
902        b: Vec<i64>,
903        shape_a: ArrayShape,
904        shape_b: ArrayShape,
905    ) -> Result<Vec<u64>> {
906        let bit_vector_len_a = shape_a[shape_a.len() - 1];
907        let bit_vector_len_b = shape_b[shape_b.len() - 1];
908        let data_scalar_type_a = get_s_scalar_type_from_bits(bit_vector_len_a)?;
909        let data_scalar_type_b = get_s_scalar_type_from_bits(bit_vector_len_b)?;
910
911        let c = create_context()?;
912        let g = c.create_graph()?;
913        let i_va = g.input(array_type(shape_a.clone(), BIT))?;
914        let i_vb = g.input(array_type(shape_b.clone(), BIT))?;
915        let o = g.custom_op(comparison_op.clone(), vec![i_va, i_vb])?;
916        g.set_output_node(o)?;
917        g.finalize()?;
918        c.set_main_graph(g.clone())?;
919        c.finalize()?;
920        let mapped_c = run_instantiation_pass(c)?;
921        // Restructure the input data
922        let v_a = Value::from_flattened_array(&a, data_scalar_type_a)?;
923        let v_b = Value::from_flattened_array(&b, data_scalar_type_b)?;
924        let broadcasted_output_shape = broadcast_shapes(
925            shape_a[0..(shape_a.len() - 1)].to_vec(),
926            shape_b[0..(shape_b.len() - 1)].to_vec(),
927        )?;
928
929        let result = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?
930            .to_flattened_array_u64(array_type(broadcasted_output_shape, BIT))?;
931        Ok(result)
932    }
933
934    fn test_unsigned_less_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
935        let c = create_context()?;
936        let g = c.create_graph()?;
937        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
938        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
939        let o = g.custom_op(
940            CustomOperation::new(LessThan {
941                signed_comparison: false,
942            }),
943            vec![i_a, i_b],
944        )?;
945        g.set_output_node(o)?;
946        g.finalize()?;
947        c.set_main_graph(g.clone())?;
948        c.finalize()?;
949        let mapped_c = run_instantiation_pass(c)?;
950        let v_a = Value::from_flattened_array(&a, BIT)?;
951        let v_b = Value::from_flattened_array(&b, BIT)?;
952        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
953    }
954
955    fn test_signed_less_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
956        let c = create_context()?;
957        let g = c.create_graph()?;
958        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
959        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
960        let o = g.custom_op(
961            CustomOperation::new(LessThan {
962                signed_comparison: true,
963            }),
964            vec![i_a, i_b],
965        )?;
966        g.set_output_node(o)?;
967        g.finalize()?;
968        c.set_main_graph(g.clone())?;
969        c.finalize()?;
970        let mapped_c = run_instantiation_pass(c)?;
971        let v_a = Value::from_flattened_array(&a, BIT)?;
972        let v_b = Value::from_flattened_array(&b, BIT)?;
973        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
974    }
975
976    fn test_unsigned_less_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
977        let c = create_context()?;
978        let g = c.create_graph()?;
979        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
980        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
981        let o = g.custom_op(
982            CustomOperation::new(LessThanEqualTo {
983                signed_comparison: false,
984            }),
985            vec![i_a, i_b],
986        )?;
987        g.set_output_node(o)?;
988        g.finalize()?;
989        c.set_main_graph(g.clone())?;
990        c.finalize()?;
991        let mapped_c = run_instantiation_pass(c)?;
992        let v_a = Value::from_flattened_array(&a, BIT)?;
993        let v_b = Value::from_flattened_array(&b, BIT)?;
994        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
995    }
996
997    fn test_signed_less_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
998        let c = create_context()?;
999        let g = c.create_graph()?;
1000        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1001        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1002        let o = g.custom_op(
1003            CustomOperation::new(LessThanEqualTo {
1004                signed_comparison: true,
1005            }),
1006            vec![i_a, i_b],
1007        )?;
1008        g.set_output_node(o)?;
1009        g.finalize()?;
1010        c.set_main_graph(g.clone())?;
1011        c.finalize()?;
1012        let mapped_c = run_instantiation_pass(c)?;
1013        let v_a = Value::from_flattened_array(&a, BIT)?;
1014        let v_b = Value::from_flattened_array(&b, BIT)?;
1015        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1016    }
1017
1018    fn test_unsigned_greater_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
1019        let c = create_context()?;
1020        let g = c.create_graph()?;
1021        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1022        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1023        let o = g.custom_op(
1024            CustomOperation::new(GreaterThanEqualTo {
1025                signed_comparison: false,
1026            }),
1027            vec![i_a, i_b],
1028        )?;
1029        g.set_output_node(o)?;
1030        g.finalize()?;
1031        c.set_main_graph(g.clone())?;
1032        c.finalize()?;
1033        let mapped_c = run_instantiation_pass(c)?;
1034        let v_a = Value::from_flattened_array(&a, BIT)?;
1035        let v_b = Value::from_flattened_array(&b, BIT)?;
1036        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1037    }
1038
1039    fn test_signed_greater_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
1040        let c = create_context()?;
1041        let g = c.create_graph()?;
1042        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1043        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1044        let o = g.custom_op(
1045            CustomOperation::new(GreaterThanEqualTo {
1046                signed_comparison: true,
1047            }),
1048            vec![i_a, i_b],
1049        )?;
1050        g.set_output_node(o)?;
1051        g.finalize()?;
1052        c.set_main_graph(g.clone())?;
1053        c.finalize()?;
1054        let mapped_c = run_instantiation_pass(c)?;
1055        let v_a = Value::from_flattened_array(&a, BIT)?;
1056        let v_b = Value::from_flattened_array(&b, BIT)?;
1057        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1058    }
1059
1060    fn test_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
1061        let c = create_context()?;
1062        let g = c.create_graph()?;
1063        let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
1064        let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
1065        let o = g.custom_op(CustomOperation::new(Equal {}), vec![i_a, i_b])?;
1066        g.set_output_node(o)?;
1067        g.finalize()?;
1068        c.set_main_graph(g.clone())?;
1069        c.finalize()?;
1070        let mapped_c = run_instantiation_pass(c)?;
1071        let v_a = Value::from_flattened_array(&a, BIT)?;
1072        let v_b = Value::from_flattened_array(&b, BIT)?;
1073        Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
1074    }
1075
1076    #[test]
1077    fn test_greater_than_cust_op() {
1078        || -> Result<()> {
1079            assert_eq!(
1080                test_unsigned_greater_than_cust_op_helper(vec![0], vec![0])?,
1081                0
1082            );
1083            assert_eq!(
1084                test_unsigned_greater_than_cust_op_helper(vec![0], vec![1])?,
1085                0
1086            );
1087            assert_eq!(
1088                test_unsigned_greater_than_cust_op_helper(vec![1], vec![0])?,
1089                1
1090            );
1091            assert_eq!(
1092                test_unsigned_greater_than_cust_op_helper(vec![1], vec![1])?,
1093                0
1094            );
1095            Ok(())
1096        }()
1097        .unwrap();
1098    }
1099
1100    #[test]
1101    fn test_signed_greater_than_cust_op() {
1102        || -> Result<()> {
1103            // for signed positive values
1104            assert_eq!(
1105                test_signed_greater_than_cust_op_helper(vec![0, 0], vec![0, 0])?,
1106                0
1107            );
1108            assert_eq!(
1109                test_signed_greater_than_cust_op_helper(vec![0, 0], vec![1, 0])?,
1110                0
1111            );
1112            assert_eq!(
1113                test_signed_greater_than_cust_op_helper(vec![1, 0], vec![0, 0])?,
1114                1
1115            );
1116            assert_eq!(
1117                test_signed_greater_than_cust_op_helper(vec![1, 0], vec![1, 0])?,
1118                0
1119            );
1120            // for signed negative values
1121            assert_eq!(
1122                test_signed_greater_than_cust_op_helper(vec![0, 1], vec![0, 1])?,
1123                0
1124            );
1125            assert_eq!(
1126                test_signed_greater_than_cust_op_helper(vec![0, 1], vec![1, 1])?,
1127                0
1128            );
1129            assert_eq!(
1130                test_signed_greater_than_cust_op_helper(vec![1, 1], vec![0, 1])?,
1131                1
1132            );
1133            assert_eq!(
1134                test_signed_greater_than_cust_op_helper(vec![1, 1], vec![1, 1])?,
1135                0
1136            );
1137            // mixture of values
1138            assert_eq!(
1139                test_signed_greater_than_cust_op_helper(vec![0, 1], vec![0, 0])?,
1140                0
1141            );
1142            assert_eq!(
1143                test_signed_greater_than_cust_op_helper(vec![0, 0], vec![0, 1])?,
1144                1
1145            );
1146            assert_eq!(
1147                test_signed_greater_than_cust_op_helper(vec![0, 1], vec![1, 0])?,
1148                0
1149            );
1150            assert_eq!(
1151                test_signed_greater_than_cust_op_helper(vec![0, 0], vec![1, 1])?,
1152                1
1153            );
1154            assert_eq!(
1155                test_signed_greater_than_cust_op_helper(vec![1, 1], vec![0, 0])?,
1156                0
1157            );
1158            assert_eq!(
1159                test_signed_greater_than_cust_op_helper(vec![1, 0], vec![0, 1])?,
1160                1
1161            );
1162            assert_eq!(
1163                test_signed_greater_than_cust_op_helper(vec![1, 1], vec![1, 0])?,
1164                0
1165            );
1166            assert_eq!(
1167                test_signed_greater_than_cust_op_helper(vec![1, 0], vec![1, 1])?,
1168                1
1169            );
1170            Ok(())
1171        }()
1172        .unwrap();
1173    }
1174
1175    #[test]
1176    fn test_unsigned_less_than_cust_op() {
1177        || -> Result<()> {
1178            assert_eq!(test_unsigned_less_than_cust_op_helper(vec![0], vec![0])?, 0);
1179            assert_eq!(test_unsigned_less_than_cust_op_helper(vec![0], vec![1])?, 1);
1180            assert_eq!(test_unsigned_less_than_cust_op_helper(vec![1], vec![0])?, 0);
1181            assert_eq!(test_unsigned_less_than_cust_op_helper(vec![1], vec![1])?, 0);
1182            Ok(())
1183        }()
1184        .unwrap();
1185    }
1186
1187    #[test]
1188    fn test_signed_less_than_cust_op() {
1189        || -> Result<()> {
1190            // for signed positive values
1191            assert_eq!(
1192                test_signed_less_than_cust_op_helper(vec![0, 0], vec![0, 0])?,
1193                0
1194            );
1195            assert_eq!(
1196                test_signed_less_than_cust_op_helper(vec![0, 0], vec![1, 0])?,
1197                1
1198            );
1199            assert_eq!(
1200                test_signed_less_than_cust_op_helper(vec![1, 0], vec![0, 0])?,
1201                0
1202            );
1203            assert_eq!(
1204                test_signed_less_than_cust_op_helper(vec![1, 0], vec![1, 0])?,
1205                0
1206            );
1207            // for signed negative less
1208            assert_eq!(
1209                test_signed_less_than_cust_op_helper(vec![0, 1], vec![0, 1])?,
1210                0
1211            );
1212            assert_eq!(
1213                test_signed_less_than_cust_op_helper(vec![0, 1], vec![1, 1])?,
1214                1
1215            );
1216            assert_eq!(
1217                test_signed_less_than_cust_op_helper(vec![1, 1], vec![0, 1])?,
1218                0
1219            );
1220            assert_eq!(
1221                test_signed_less_than_cust_op_helper(vec![1, 1], vec![1, 1])?,
1222                0
1223            );
1224            // mixture of valuesless
1225            assert_eq!(
1226                test_signed_less_than_cust_op_helper(vec![0, 1], vec![0, 0])?,
1227                1
1228            );
1229            assert_eq!(
1230                test_signed_less_than_cust_op_helper(vec![0, 0], vec![0, 1])?,
1231                0
1232            );
1233            assert_eq!(
1234                test_signed_less_than_cust_op_helper(vec![0, 1], vec![1, 0])?,
1235                1
1236            );
1237            assert_eq!(
1238                test_signed_less_than_cust_op_helper(vec![0, 0], vec![1, 1])?,
1239                0
1240            );
1241            assert_eq!(
1242                test_signed_less_than_cust_op_helper(vec![1, 1], vec![0, 0])?,
1243                1
1244            );
1245            assert_eq!(
1246                test_signed_less_than_cust_op_helper(vec![1, 0], vec![0, 1])?,
1247                0
1248            );
1249            assert_eq!(
1250                test_signed_less_than_cust_op_helper(vec![1, 1], vec![1, 0])?,
1251                1
1252            );
1253            assert_eq!(
1254                test_signed_less_than_cust_op_helper(vec![1, 0], vec![1, 1])?,
1255                0
1256            );
1257            Ok(())
1258        }()
1259        .unwrap();
1260    }
1261
1262    #[test]
1263    fn test_unsigned_less_than_or_eq_to_cust_op() {
1264        || -> Result<()> {
1265            assert_eq!(
1266                test_unsigned_less_than_equal_to_cust_op_helper(vec![0], vec![0])?,
1267                1
1268            );
1269            assert_eq!(
1270                test_unsigned_less_than_equal_to_cust_op_helper(vec![0], vec![1])?,
1271                1
1272            );
1273            assert_eq!(
1274                test_unsigned_less_than_equal_to_cust_op_helper(vec![1], vec![0])?,
1275                0
1276            );
1277            assert_eq!(
1278                test_unsigned_less_than_equal_to_cust_op_helper(vec![1], vec![1])?,
1279                1
1280            );
1281            Ok(())
1282        }()
1283        .unwrap();
1284    }
1285
1286    #[test]
1287    fn test_signed_less_than_or_eq_to_cust_op() {
1288        || -> Result<()> {
1289            // for signed positive values
1290            assert_eq!(
1291                test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 0])?,
1292                1
1293            );
1294            assert_eq!(
1295                test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 0])?,
1296                1
1297            );
1298            assert_eq!(
1299                test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 0])?,
1300                0
1301            );
1302            assert_eq!(
1303                test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 0])?,
1304                1
1305            );
1306            // for signed negative less
1307            assert_eq!(
1308                test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 1])?,
1309                1
1310            );
1311            assert_eq!(
1312                test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 1])?,
1313                1
1314            );
1315            assert_eq!(
1316                test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 1])?,
1317                0
1318            );
1319            assert_eq!(
1320                test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 1])?,
1321                1
1322            );
1323            // mixture of valuesless
1324            assert_eq!(
1325                test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 0])?,
1326                1
1327            );
1328            assert_eq!(
1329                test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 1])?,
1330                0
1331            );
1332            assert_eq!(
1333                test_signed_less_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 0])?,
1334                1
1335            );
1336            assert_eq!(
1337                test_signed_less_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 1])?,
1338                0
1339            );
1340            assert_eq!(
1341                test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 0])?,
1342                1
1343            );
1344            assert_eq!(
1345                test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 1])?,
1346                0
1347            );
1348            assert_eq!(
1349                test_signed_less_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 0])?,
1350                1
1351            );
1352            assert_eq!(
1353                test_signed_less_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 1])?,
1354                0
1355            );
1356            Ok(())
1357        }()
1358        .unwrap();
1359    }
1360
1361    #[test]
1362    fn test_unsigned_greater_than_or_eq_to_cust_op() {
1363        || -> Result<()> {
1364            assert_eq!(
1365                test_unsigned_greater_than_equal_to_cust_op_helper(vec![0], vec![0])?,
1366                1
1367            );
1368            assert_eq!(
1369                test_unsigned_greater_than_equal_to_cust_op_helper(vec![0], vec![1])?,
1370                0
1371            );
1372            assert_eq!(
1373                test_unsigned_greater_than_equal_to_cust_op_helper(vec![1], vec![0])?,
1374                1
1375            );
1376            assert_eq!(
1377                test_unsigned_greater_than_equal_to_cust_op_helper(vec![1], vec![1])?,
1378                1
1379            );
1380            Ok(())
1381        }()
1382        .unwrap();
1383    }
1384
1385    #[test]
1386    fn test_signed_greater_than_or_eq_to_cust_op() {
1387        || -> Result<()> {
1388            // for signed positive values
1389            assert_eq!(
1390                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 0])?,
1391                1
1392            );
1393            assert_eq!(
1394                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 0])?,
1395                0
1396            );
1397            assert_eq!(
1398                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 0])?,
1399                1
1400            );
1401            assert_eq!(
1402                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 0])?,
1403                1
1404            );
1405            // for signed negative values
1406            assert_eq!(
1407                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 1])?,
1408                1
1409            );
1410            assert_eq!(
1411                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 1])?,
1412                0
1413            );
1414            assert_eq!(
1415                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 1])?,
1416                1
1417            );
1418            assert_eq!(
1419                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 1])?,
1420                1
1421            );
1422            // mixture of values
1423            assert_eq!(
1424                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![0, 0])?,
1425                0
1426            );
1427            assert_eq!(
1428                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![0, 1])?,
1429                1
1430            );
1431            assert_eq!(
1432                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 1], vec![1, 0])?,
1433                0
1434            );
1435            assert_eq!(
1436                test_signed_greater_than_equal_to_cust_op_helper(vec![0, 0], vec![1, 1])?,
1437                1
1438            );
1439            assert_eq!(
1440                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![0, 0])?,
1441                0
1442            );
1443            assert_eq!(
1444                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![0, 1])?,
1445                1
1446            );
1447            assert_eq!(
1448                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 1], vec![1, 0])?,
1449                0
1450            );
1451            assert_eq!(
1452                test_signed_greater_than_equal_to_cust_op_helper(vec![1, 0], vec![1, 1])?,
1453                1
1454            );
1455            Ok(())
1456        }()
1457        .unwrap();
1458    }
1459
1460    #[test]
1461    fn test_not_equal_cust_op() {
1462        || -> Result<()> {
1463            assert_eq!(test_not_equal_cust_op_helper(vec![0], vec![0])?, 0);
1464            assert_eq!(test_not_equal_cust_op_helper(vec![0], vec![1])?, 1);
1465            assert_eq!(test_not_equal_cust_op_helper(vec![1], vec![0])?, 1);
1466            assert_eq!(test_not_equal_cust_op_helper(vec![1], vec![1])?, 0);
1467            Ok(())
1468        }()
1469        .unwrap();
1470    }
1471
1472    #[test]
1473    fn test_equal_to_cust_op() {
1474        || -> Result<()> {
1475            assert_eq!(test_equal_to_cust_op_helper(vec![0], vec![0])?, 1);
1476            assert_eq!(test_equal_to_cust_op_helper(vec![0], vec![1])?, 0);
1477            assert_eq!(test_equal_to_cust_op_helper(vec![1], vec![0])?, 0);
1478            assert_eq!(test_equal_to_cust_op_helper(vec![1], vec![1])?, 1);
1479            Ok(())
1480        }()
1481        .unwrap();
1482    }
1483
1484    #[test]
1485    fn test_unsigned_multiple_bit_comparisons_cust_op() {
1486        || -> Result<()> {
1487            for i in 0..8 {
1488                for j in 0..8 {
1489                    let a: Vec<u8> = vec![i & 1, (i & 2) >> 1, (i & 4) >> 2];
1490                    let b: Vec<u8> = vec![j & 1, (j & 2) >> 1, (j & 4) >> 2];
1491                    assert_eq!(
1492                        test_unsigned_greater_than_cust_op_helper(a.clone(), b.clone())?,
1493                        if i > j { 1 } else { 0 }
1494                    );
1495                    assert_eq!(
1496                        test_unsigned_less_than_cust_op_helper(a.clone(), b.clone())?,
1497                        if i < j { 1 } else { 0 }
1498                    );
1499                    assert_eq!(
1500                        test_unsigned_greater_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
1501                        if i >= j { 1 } else { 0 }
1502                    );
1503                    assert_eq!(
1504                        test_unsigned_less_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
1505                        if i <= j { 1 } else { 0 }
1506                    );
1507                    assert_eq!(
1508                        test_not_equal_cust_op_helper(a.clone(), b.clone())?,
1509                        if i != j { 1 } else { 0 }
1510                    );
1511                    assert_eq!(
1512                        test_equal_to_cust_op_helper(a.clone(), b.clone())?,
1513                        if i == j { 1 } else { 0 }
1514                    );
1515                }
1516            }
1517            Ok(())
1518        }()
1519        .unwrap();
1520    }
1521
1522    #[test]
1523    fn test_signed_multiple_bit_comparisons_cust_op() {
1524        || -> Result<()> {
1525            for i in 0..8 {
1526                for j in 0..8 {
1527                    let a: Vec<u8> = vec![i & 1, (i & 2) >> 1, (i & 4) >> 2];
1528                    let b: Vec<u8> = vec![j & 1, (j & 2) >> 1, (j & 4) >> 2];
1529                    let s_i = if i > 3 { i as i8 - 8 } else { i as i8 };
1530                    let s_j = if j > 3 { j as i8 - 8 } else { j as i8 };
1531                    assert_eq!(
1532                        test_signed_greater_than_cust_op_helper(a.clone(), b.clone())?,
1533                        if s_i > s_j { 1 } else { 0 }
1534                    );
1535                    assert_eq!(
1536                        test_signed_less_than_cust_op_helper(a.clone(), b.clone())?,
1537                        if s_i < s_j { 1 } else { 0 }
1538                    );
1539                    assert_eq!(
1540                        test_signed_greater_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
1541                        if s_i >= s_j { 1 } else { 0 }
1542                    );
1543                    assert_eq!(
1544                        test_signed_less_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
1545                        if s_i <= s_j { 1 } else { 0 }
1546                    );
1547                }
1548            }
1549            Ok(())
1550        }()
1551        .unwrap();
1552    }
1553
1554    #[test]
1555    fn test_unsigned_malformed_basic_cust_ops() {
1556        || -> Result<()> {
1557            let cust_ops = vec![
1558                CustomOperation::new(GreaterThan {
1559                    signed_comparison: false,
1560                }),
1561                CustomOperation::new(NotEqual {}),
1562            ];
1563            for cust_op in cust_ops.into_iter() {
1564                let c = create_context()?;
1565                let g = c.create_graph()?;
1566                let i_a = g.input(array_type(vec![1], BIT))?;
1567                let i_b = g.input(array_type(vec![1], BIT))?;
1568                let i_c = g.input(array_type(vec![1], BIT))?;
1569                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b, i_c]).is_err());
1570
1571                let c = create_context()?;
1572                let g = c.create_graph()?;
1573                let i_a = g.input(scalar_type(BIT))?;
1574                let i_b = g.input(array_type(vec![1], BIT))?;
1575                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1576
1577                let c = create_context()?;
1578                let g = c.create_graph()?;
1579                let i_a = g.input(array_type(vec![1], BIT))?;
1580                let i_b = g.input(tuple_type(vec![array_type(vec![1], BIT)]))?;
1581                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1582
1583                let c = create_context()?;
1584                let g = c.create_graph()?;
1585                let i_a = g.input(array_type(vec![1], INT16))?;
1586                let i_b = g.input(array_type(vec![1], BIT))?;
1587                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1588
1589                let c = create_context()?;
1590                let g = c.create_graph()?;
1591                let i_a = g.input(array_type(vec![1], UINT16))?;
1592                let i_b = g.input(array_type(vec![1], BIT))?;
1593                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1594
1595                let c = create_context()?;
1596                let g = c.create_graph()?;
1597                let i_a = g.input(array_type(vec![1], BIT))?;
1598                let i_b = g.input(array_type(vec![1], INT32))?;
1599                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1600
1601                let c = create_context()?;
1602                let g = c.create_graph()?;
1603                let i_a = g.input(array_type(vec![1], BIT))?;
1604                let i_b = g.input(array_type(vec![1], UINT32))?;
1605                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1606
1607                let c = create_context()?;
1608                let g = c.create_graph()?;
1609                let i_a = g.input(array_type(vec![1], BIT))?;
1610                let i_b = g.input(array_type(vec![9], BIT))?;
1611                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1612
1613                let c = create_context()?;
1614                let g = c.create_graph()?;
1615                let i_a = g.input(array_type(vec![1], BIT))?;
1616                let i_b = g.input(array_type(vec![1, 2], BIT))?;
1617                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1618
1619                let v_a = vec![170, 120, 61, 85];
1620                let v_b = vec![76, 20, 70, 249, 217, 190, 43, 83, 33710];
1621                assert!(test_unsigned_comparison_cust_op_for_vec_helper(
1622                    cust_op.clone(),
1623                    v_a.clone(),
1624                    v_b.clone(),
1625                    vec![2, 2, 16],
1626                    vec![3, 3, 32]
1627                )
1628                .is_err());
1629
1630                let v_a = vec![170];
1631                let v_b = vec![76, 20, 70, 249, 217, 190, 43, 83, 33710];
1632                assert!(test_unsigned_comparison_cust_op_for_vec_helper(
1633                    cust_op.clone(),
1634                    v_a.clone(),
1635                    v_b.clone(),
1636                    vec![2, 2, 16],
1637                    vec![3, 3, 16]
1638                )
1639                .is_err());
1640
1641                let v_a = vec![];
1642                let v_b = vec![76, 20, 70, 249, 217, 190, 43, 83, 33710];
1643                assert!(test_unsigned_comparison_cust_op_for_vec_helper(
1644                    cust_op.clone(),
1645                    v_a.clone(),
1646                    v_b.clone(),
1647                    vec![0, 64],
1648                    vec![3, 3, 64]
1649                )
1650                .is_err());
1651
1652                let v_a = vec![170, 200];
1653                let v_b = vec![];
1654                assert!(test_unsigned_comparison_cust_op_for_vec_helper(
1655                    cust_op.clone(),
1656                    v_a.clone(),
1657                    v_b.clone(),
1658                    vec![2, 1, 64],
1659                    vec![2, 2, 1, 64]
1660                )
1661                .is_err());
1662            }
1663
1664            Ok(())
1665        }()
1666        .unwrap();
1667    }
1668
1669    #[test]
1670    fn test_signed_malformed_basic_cust_ops() {
1671        || -> Result<()> {
1672            let cust_ops = vec![CustomOperation::new(GreaterThan {
1673                signed_comparison: true,
1674            })];
1675            for cust_op in cust_ops.into_iter() {
1676                let c = create_context()?;
1677                let g = c.create_graph()?;
1678                let i_a = g.input(array_type(vec![1], BIT))?;
1679                let i_b = g.input(array_type(vec![1, 1], BIT))?;
1680                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1681
1682                let c = create_context()?;
1683                let g = c.create_graph()?;
1684                let i_a = g.input(array_type(vec![1, 1], BIT))?;
1685                let i_b = g.input(array_type(vec![1], BIT))?;
1686                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1687
1688                let c = create_context()?;
1689                let g = c.create_graph()?;
1690                let i_a = g.input(array_type(vec![1, 64], BIT))?;
1691                let i_b = g.input(array_type(vec![1], BIT))?;
1692                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1693
1694                let c = create_context()?;
1695                let g = c.create_graph()?;
1696                let i_a = g.input(scalar_type(BIT))?;
1697                let i_b = g.input(array_type(vec![1], BIT))?;
1698                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1699
1700                let c = create_context()?;
1701                let g = c.create_graph()?;
1702                let i_a = g.input(array_type(vec![1], UINT16))?;
1703                let i_b = g.input(array_type(vec![1], BIT))?;
1704                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1705
1706                let c = create_context()?;
1707                let g = c.create_graph()?;
1708                let i_a = g.input(array_type(vec![1], BIT))?;
1709                let i_b = g.input(array_type(vec![1], INT32))?;
1710                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1711
1712                let c = create_context()?;
1713                let g = c.create_graph()?;
1714                let i_a = g.input(array_type(vec![1], BIT))?;
1715                let i_b = g.input(array_type(vec![1], UINT32))?;
1716                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1717
1718                let c = create_context()?;
1719                let g = c.create_graph()?;
1720                let i_a = g.input(array_type(vec![1], BIT))?;
1721                let i_b = g.input(array_type(vec![9], BIT))?;
1722                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1723
1724                let c = create_context()?;
1725                let g = c.create_graph()?;
1726                let i_a = g.input(array_type(vec![1, 2, 3], BIT))?;
1727                let i_b = g.input(array_type(vec![9], BIT))?;
1728                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1729
1730                let c = create_context()?;
1731                let g = c.create_graph()?;
1732                let i_a = g.input(array_type(vec![1], BIT))?;
1733                let i_b = g.input(array_type(vec![1, 2], BIT))?;
1734                assert!(g.custom_op(cust_op.clone(), vec![i_a, i_b]).is_err());
1735
1736                let v_a = vec![170, 120, 61, 85];
1737                let v_b = vec![
1738                    -1176658021,
1739                    -301476304,
1740                    788180273,
1741                    -1085188538,
1742                    -1358798926,
1743                    -120286105,
1744                    -1300942710,
1745                    -389618936,
1746                    258721418,
1747                ];
1748                assert!(test_signed_comparison_cust_op_for_vec_helper(
1749                    cust_op.clone(),
1750                    v_a.clone(),
1751                    v_b.clone(),
1752                    vec![2, 2, 16],
1753                    vec![3, 3, 32]
1754                )
1755                .is_err());
1756
1757                let v_a = vec![-14735];
1758                let v_b = vec![
1759                    16490, -10345, -31409, 2787, -15039, 26085, 7881, 32423, -23915,
1760                ];
1761                assert!(test_signed_comparison_cust_op_for_vec_helper(
1762                    cust_op.clone(),
1763                    v_a.clone(),
1764                    v_b.clone(),
1765                    vec![2, 2, 16],
1766                    vec![3, 3, 16]
1767                )
1768                .is_err());
1769
1770                let v_a = vec![];
1771                let v_b = vec![
1772                    -2600362169875399934,
1773                    6278463339984150730,
1774                    -2962726308672949899,
1775                    3404980137287029349,
1776                ];
1777                assert!(test_signed_comparison_cust_op_for_vec_helper(
1778                    cust_op.clone(),
1779                    v_a.clone(),
1780                    v_b.clone(),
1781                    vec![0, 64],
1782                    vec![2, 2, 64]
1783                )
1784                .is_err());
1785
1786                let v_a = vec![-2600362169875399934, 6278463339984150730];
1787                let v_b = vec![];
1788                assert!(test_signed_comparison_cust_op_for_vec_helper(
1789                    cust_op.clone(),
1790                    v_a.clone(),
1791                    v_b.clone(),
1792                    vec![2, 1, 64],
1793                    vec![2, 2, 1, 64]
1794                )
1795                .is_err());
1796            }
1797
1798            Ok(())
1799        }()
1800        .unwrap();
1801    }
1802
1803    #[test]
1804    fn test_unsigned_vector_comparisons() {
1805        || -> Result<()> {
1806            let mut v_a = vec![170, 120, 61, 85];
1807            let mut v_b = vec![
1808                76, 20, 70, 249, 217, 190, 43, 83, 33710, 27637, 43918, 38683,
1809            ];
1810            assert_eq!(
1811                test_unsigned_comparison_cust_op_for_vec_helper(
1812                    CustomOperation::new(GreaterThan {
1813                        signed_comparison: false
1814                    }),
1815                    v_a.clone(),
1816                    v_b.clone(),
1817                    vec![2, 2, 64],
1818                    vec![3, 2, 2, 64],
1819                )?,
1820                vec![1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]
1821            );
1822
1823            v_a = vec![170, 120, 61, 85, 75, 149, 50, 54, 8811, 29720, 1009, 33126];
1824            v_b = vec![76, 20, 70, 249, 217, 190];
1825            assert_eq!(
1826                test_unsigned_comparison_cust_op_for_vec_helper(
1827                    CustomOperation::new(GreaterThan {
1828                        signed_comparison: false
1829                    }),
1830                    v_a.clone(),
1831                    v_b.clone(),
1832                    vec![2, 3, 2, 32],
1833                    vec![3, 2, 32],
1834                )?,
1835                vec![1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
1836            );
1837
1838            v_a = vec![170, 120, 61, 85, 75, 149, 50, 54];
1839            v_b = vec![76, 20, 70, 249];
1840            assert_eq!(
1841                test_unsigned_comparison_cust_op_for_vec_helper(
1842                    CustomOperation::new(GreaterThan {
1843                        signed_comparison: false
1844                    }),
1845                    v_a.clone(),
1846                    v_b.clone(),
1847                    vec![2, 2, 2, 16],
1848                    vec![2, 2, 16],
1849                )?,
1850                vec![1, 1, 0, 0, 0, 1, 0, 0]
1851            );
1852
1853            v_a = vec![170, 120, 61, 85, 75, 149, 50, 54];
1854            v_b = vec![76, 20, 70, 249, 217, 190, 43, 83];
1855            assert_eq!(
1856                test_unsigned_comparison_cust_op_for_vec_helper(
1857                    CustomOperation::new(GreaterThan {
1858                        signed_comparison: false
1859                    }),
1860                    v_a.clone(),
1861                    v_b.clone(),
1862                    vec![2, 2, 2, 64],
1863                    vec![2, 2, 2, 64],
1864                )?,
1865                vec![1, 1, 0, 0, 0, 0, 1, 0]
1866            );
1867
1868            v_a = vec![170, 120, 61];
1869            v_b = vec![76, 20, 70];
1870            assert_eq!(
1871                test_unsigned_comparison_cust_op_for_vec_helper(
1872                    CustomOperation::new(GreaterThan {
1873                        signed_comparison: false
1874                    }),
1875                    v_a.clone(),
1876                    v_b.clone(),
1877                    vec![3, 64],
1878                    vec![3, 64],
1879                )?,
1880                vec![1, 1, 0]
1881            );
1882
1883            v_a = vec![170, 120, 61, 85, 75, 149];
1884            v_b = vec![76, 20, 70];
1885            assert_eq!(
1886                test_unsigned_comparison_cust_op_for_vec_helper(
1887                    CustomOperation::new(LessThan {
1888                        signed_comparison: false
1889                    }),
1890                    v_a.clone(),
1891                    v_b.clone(),
1892                    vec![2, 3, 64],
1893                    vec![3, 64],
1894                )?,
1895                vec![0, 0, 1, 0, 0, 0]
1896            );
1897
1898            v_a = vec![170, 120, 61, 85, 75, 70, 50, 54, 8811, 29720, 1009, 33126];
1899            v_b = vec![76, 1009, 70];
1900            assert_eq!(
1901                test_unsigned_comparison_cust_op_for_vec_helper(
1902                    CustomOperation::new(LessThanEqualTo {
1903                        signed_comparison: false
1904                    }),
1905                    v_a.clone(),
1906                    v_b.clone(),
1907                    vec![2, 2, 3, 64],
1908                    vec![3, 64],
1909                )?,
1910                vec![0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0]
1911            );
1912
1913            v_a = vec![170, 120, 61, 85, 75, 76, 50, 54];
1914            v_b = vec![76];
1915            assert_eq!(
1916                test_unsigned_comparison_cust_op_for_vec_helper(
1917                    CustomOperation::new(GreaterThanEqualTo {
1918                        signed_comparison: false
1919                    }),
1920                    v_a.clone(),
1921                    v_b.clone(),
1922                    vec![2, 2, 2, 64],
1923                    vec![1, 64],
1924                )?,
1925                vec![1, 1, 0, 1, 0, 1, 0, 0]
1926            );
1927
1928            v_a = vec![170];
1929            v_b = vec![76];
1930            assert_eq!(
1931                test_unsigned_comparison_cust_op_for_vec_helper(
1932                    CustomOperation::new(GreaterThanEqualTo {
1933                        signed_comparison: false
1934                    }),
1935                    v_a.clone(),
1936                    v_b.clone(),
1937                    vec![1, 64],
1938                    vec![1, 64],
1939                )?,
1940                vec![1]
1941            );
1942
1943            v_a = vec![76];
1944            v_b = vec![76, 170];
1945            assert_eq!(
1946                test_unsigned_comparison_cust_op_for_vec_helper(
1947                    CustomOperation::new(GreaterThanEqualTo {
1948                        signed_comparison: false
1949                    }),
1950                    v_a.clone(),
1951                    v_b.clone(),
1952                    vec![1, 64],
1953                    vec![1, 2, 64],
1954                )?,
1955                vec![1, 0]
1956            );
1957
1958            let v_a = vec![83, 172, 214, 2, 68];
1959            let v_b = vec![83];
1960            assert_eq!(
1961                test_unsigned_comparison_cust_op_for_vec_helper(
1962                    CustomOperation::new(GreaterThanEqualTo {
1963                        signed_comparison: false
1964                    }),
1965                    v_a,
1966                    v_b,
1967                    vec![5, 8],
1968                    vec![8]
1969                )?,
1970                vec![1, 1, 1, 0, 0]
1971            );
1972
1973            let v_a = vec![2];
1974            let v_b = vec![83, 1, 2, 100];
1975            assert_eq!(
1976                test_unsigned_comparison_cust_op_for_vec_helper(
1977                    CustomOperation::new(LessThan {
1978                        signed_comparison: false
1979                    }),
1980                    v_a,
1981                    v_b,
1982                    vec![1, 32],
1983                    vec![2, 2, 32]
1984                )?,
1985                vec![1, 0, 0, 1]
1986            );
1987
1988            let v_a = vec![83, 2];
1989            let v_b = vec![83, 172, 214, 2, 68, 34, 87, 45, 83, 23];
1990            assert_eq!(
1991                test_unsigned_comparison_cust_op_for_vec_helper(
1992                    CustomOperation::new(LessThanEqualTo {
1993                        signed_comparison: false
1994                    }),
1995                    v_a,
1996                    v_b,
1997                    vec![2, 1, 64],
1998                    vec![2, 5, 64]
1999                )?,
2000                vec![1, 1, 1, 0, 0, 1, 1, 1, 1, 1]
2001            );
2002
2003            let v_a = vec![83, 2];
2004            let v_b = vec![83, 172, 214, 2, 68, 2, 87, 45];
2005            assert_eq!(
2006                test_unsigned_comparison_cust_op_for_vec_helper(
2007                    CustomOperation::new(NotEqual {}),
2008                    v_a,
2009                    v_b,
2010                    vec![2, 1, 64],
2011                    vec![2, 4, 64]
2012                )?,
2013                vec![0, 1, 1, 1, 1, 0, 1, 1]
2014            );
2015
2016            let v_a = vec![4, 2];
2017            let v_b = vec![83, 21];
2018            assert_eq!(
2019                test_unsigned_comparison_cust_op_for_vec_helper(
2020                    CustomOperation::new(NotEqual {}),
2021                    v_a,
2022                    v_b,
2023                    vec![1, 2, 64],
2024                    vec![2, 1, 64]
2025                )?,
2026                vec![1, 1, 1, 1]
2027            );
2028
2029            let v_a = vec![247, 170, 249, 162, 102, 243, 61, 203, 125];
2030            let v_b = vec![247, 170, 249, 162, 102, 243, 61, 203, 125];
2031            assert_eq!(
2032                test_unsigned_comparison_cust_op_for_vec_helper(
2033                    CustomOperation::new(NotEqual {}),
2034                    v_a,
2035                    v_b,
2036                    vec![3, 3, 16],
2037                    vec![3, 3, 16]
2038                )?,
2039                vec![0, 0, 0, 0, 0, 0, 0, 0, 0]
2040            );
2041
2042            let v_a = vec![83, 2];
2043            let v_b = vec![83, 172, 214, 2, 68, 2, 87, 45];
2044            assert_eq!(
2045                test_unsigned_comparison_cust_op_for_vec_helper(
2046                    CustomOperation::new(Equal {}),
2047                    v_a,
2048                    v_b,
2049                    vec![2, 1, 64],
2050                    vec![2, 4, 64]
2051                )?,
2052                vec![1, 0, 0, 0, 0, 1, 0, 0]
2053            );
2054
2055            let v_a = vec![4, 2];
2056            let v_b = vec![83, 21];
2057            assert_eq!(
2058                test_unsigned_comparison_cust_op_for_vec_helper(
2059                    CustomOperation::new(Equal {}),
2060                    v_a,
2061                    v_b,
2062                    vec![1, 2, 64],
2063                    vec![2, 1, 64]
2064                )?,
2065                vec![0, 0, 0, 0]
2066            );
2067
2068            let v_a = vec![180, 16, 62, 141, 122, 217];
2069            let v_b = vec![141, 122, 217, 100, 11, 29];
2070            assert_eq!(
2071                test_unsigned_comparison_cust_op_for_vec_helper(
2072                    CustomOperation::new(Equal {}),
2073                    v_a,
2074                    v_b,
2075                    vec![3, 2, 1, 16],
2076                    vec![1, 2, 3, 16]
2077                )?,
2078                vec![
2079                    0, 0, 0, // 180==[141, 122, 217]
2080                    0, 0, 0, // 16==[100, 11, 29]
2081                    0, 0, 0, // 62==[141, 122, 217]
2082                    0, 0, 0, // 141==[100, 11, 29]
2083                    0, 1, 0, // 122==[141, 122, 217]
2084                    0, 0, 0 // 217==[100, 11, 29]
2085                ]
2086            );
2087
2088            let v_a = vec![0, 1, 18446744073709551614, 18446744073709551615];
2089            let v_b = vec![0, 1, 18446744073709551614, 18446744073709551615];
2090            assert_eq!(
2091                test_unsigned_comparison_cust_op_for_vec_helper(
2092                    CustomOperation::new(GreaterThan {
2093                        signed_comparison: false
2094                    }),
2095                    v_a.clone(),
2096                    v_b.clone(),
2097                    vec![4, 1, 64],
2098                    vec![1, 4, 64],
2099                )?,
2100                vec![0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0]
2101            );
2102
2103            let v_a = vec![0, 1, 18446744073709551614, 18446744073709551615];
2104            let v_b = vec![0, 1, 18446744073709551614, 18446744073709551615];
2105            assert_eq!(
2106                test_unsigned_comparison_cust_op_for_vec_helper(
2107                    CustomOperation::new(GreaterThanEqualTo {
2108                        signed_comparison: false
2109                    }),
2110                    v_a.clone(),
2111                    v_b.clone(),
2112                    vec![4, 1, 64],
2113                    vec![1, 4, 64],
2114                )?,
2115                vec![1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1]
2116            );
2117
2118            let v_a = vec![0, 1, 18446744073709551614, 18446744073709551615];
2119            let v_b = vec![0, 1, 18446744073709551614, 18446744073709551615];
2120            assert_eq!(
2121                test_unsigned_comparison_cust_op_for_vec_helper(
2122                    CustomOperation::new(NotEqual {}),
2123                    v_a.clone(),
2124                    v_b.clone(),
2125                    vec![4, 1, 64],
2126                    vec![1, 4, 64],
2127                )?,
2128                vec![0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
2129            );
2130
2131            Ok(())
2132        }()
2133        .unwrap();
2134    }
2135
2136    #[test]
2137    fn test_signed_vector_comparisons() {
2138        || -> Result<()> {
2139            let v_a = vec![
2140                -9223372036854775808,
2141                -9223372036854775807,
2142                -1,
2143                0,
2144                1,
2145                9223372036854775806,
2146                9223372036854775807,
2147            ];
2148            let v_b = vec![
2149                -9223372036854775808,
2150                -9223372036854775807,
2151                -1,
2152                0,
2153                1,
2154                9223372036854775806,
2155                9223372036854775807,
2156            ];
2157            assert_eq!(
2158                test_signed_comparison_cust_op_for_vec_helper(
2159                    CustomOperation::new(GreaterThan {
2160                        signed_comparison: true
2161                    }),
2162                    v_a.clone(),
2163                    v_b.clone(),
2164                    vec![7, 1, 64],
2165                    vec![1, 7, 64],
2166                )?,
2167                vec![
2168                    0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0,
2169                    0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0
2170                ]
2171            );
2172
2173            let v_a = vec![
2174                -9223372036854775808,
2175                -9223372036854775807,
2176                -1,
2177                0,
2178                1,
2179                9223372036854775806,
2180                9223372036854775807,
2181            ];
2182            let v_b = vec![
2183                -9223372036854775808,
2184                -9223372036854775807,
2185                -1,
2186                0,
2187                1,
2188                9223372036854775806,
2189                9223372036854775807,
2190            ];
2191            assert_eq!(
2192                test_signed_comparison_cust_op_for_vec_helper(
2193                    CustomOperation::new(GreaterThanEqualTo {
2194                        signed_comparison: true
2195                    }),
2196                    v_a.clone(),
2197                    v_b.clone(),
2198                    vec![7, 1, 64],
2199                    vec![1, 7, 64],
2200                )?,
2201                vec![
2202                    1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0,
2203                    0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1
2204                ]
2205            );
2206
2207            let mut v_a = vec![-6749, -1885, 7550, 9659];
2208            let mut v_b = vec![
2209                9918, 3462, -5690, 3436, 3214, -1733, 6171, 3148, -3534, 8282, -4904, -5976,
2210            ];
2211            assert_eq!(
2212                test_signed_comparison_cust_op_for_vec_helper(
2213                    CustomOperation::new(GreaterThan {
2214                        signed_comparison: true
2215                    }),
2216                    v_a.clone(),
2217                    v_b.clone(),
2218                    vec![2, 2, 64],
2219                    vec![3, 2, 2, 64],
2220                )?,
2221                vec![0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
2222            );
2223
2224            v_a = vec![
2225                -48, -9935, -745, 2360, -4597, -5271, 5130, -2632, 3112, 8089, 8293, 6058,
2226            ];
2227            v_b = vec![2913, 7260, 1388, 6205, 1855, 3246];
2228            assert_eq!(
2229                test_signed_comparison_cust_op_for_vec_helper(
2230                    CustomOperation::new(GreaterThan {
2231                        signed_comparison: true
2232                    }),
2233                    v_a.clone(),
2234                    v_b.clone(),
2235                    vec![2, 3, 2, 32],
2236                    vec![3, 2, 32],
2237                )?,
2238                vec![0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1]
2239            );
2240
2241            v_a = vec![9838, -574, -4181, -8107, -2880, -2866, 2272, 3743];
2242            v_b = vec![626, 4664, 1490, -5118, 7485, 6160, 4221, 2092];
2243            assert_eq!(
2244                test_signed_comparison_cust_op_for_vec_helper(
2245                    CustomOperation::new(GreaterThan {
2246                        signed_comparison: true
2247                    }),
2248                    v_a.clone(),
2249                    v_b.clone(),
2250                    vec![2, 2, 2, 64],
2251                    vec![2, 2, 2, 64],
2252                )?,
2253                vec![1, 0, 0, 0, 0, 0, 0, 1]
2254            );
2255
2256            v_a = vec![-75, 95, -84, 67, -81, 14];
2257            v_b = vec![-78, 21, -66];
2258            assert_eq!(
2259                test_signed_comparison_cust_op_for_vec_helper(
2260                    CustomOperation::new(LessThan {
2261                        signed_comparison: true
2262                    }),
2263                    v_a.clone(),
2264                    v_b.clone(),
2265                    vec![2, 3, 8],
2266                    vec![3, 8],
2267                )?,
2268                vec![0, 0, 1, 0, 1, 0]
2269            );
2270
2271            v_a = vec![-52, -119, 30, -24, -74, -45, 66, 110, 21, 1, 95, -66];
2272            v_b = vec![33, -78, 39];
2273            assert_eq!(
2274                test_signed_comparison_cust_op_for_vec_helper(
2275                    CustomOperation::new(LessThanEqualTo {
2276                        signed_comparison: true
2277                    }),
2278                    v_a.clone(),
2279                    v_b.clone(),
2280                    vec![2, 2, 3, 8],
2281                    vec![3, 8],
2282                )?,
2283                vec![1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1]
2284            );
2285
2286            v_a = vec![-128, 127, 0, 1, 0, -128, 1, 127];
2287            v_b = vec![-128];
2288            assert_eq!(
2289                test_signed_comparison_cust_op_for_vec_helper(
2290                    CustomOperation::new(GreaterThanEqualTo {
2291                        signed_comparison: true
2292                    }),
2293                    v_a.clone(),
2294                    v_b.clone(),
2295                    vec![2, 2, 2, 8],
2296                    vec![1, 8],
2297                )?,
2298                vec![1, 1, 1, 1, 1, 1, 1, 1]
2299            );
2300
2301            v_a = vec![-128, 127, 0, 1, 0, -128, 1, 127];
2302            v_b = vec![-128];
2303            assert_eq!(
2304                test_signed_comparison_cust_op_for_vec_helper(
2305                    CustomOperation::new(GreaterThan {
2306                        signed_comparison: true
2307                    }),
2308                    v_a.clone(),
2309                    v_b.clone(),
2310                    vec![2, 2, 2, 8],
2311                    vec![1, 8],
2312                )?,
2313                vec![0, 1, 1, 1, 1, 0, 1, 1]
2314            );
2315
2316            Ok(())
2317        }()
2318        .unwrap();
2319    }
2320
2321    #[test]
2322    fn test_comparison_graph_size() -> Result<()> {
2323        let mut custom_ops = vec![];
2324        custom_ops.push(CustomOperation::new(Equal {}));
2325        custom_ops.push(CustomOperation::new(NotEqual {}));
2326        for &signed_comparison in [false, true].iter() {
2327            custom_ops.push(CustomOperation::new(GreaterThan { signed_comparison }));
2328            custom_ops.push(CustomOperation::new(LessThan { signed_comparison }));
2329            custom_ops.push(CustomOperation::new(GreaterThanEqualTo {
2330                signed_comparison,
2331            }));
2332            custom_ops.push(CustomOperation::new(LessThanEqualTo { signed_comparison }));
2333        }
2334
2335        for custom_op in custom_ops.into_iter() {
2336            let c = create_context()?;
2337            let g = c.create_graph()?;
2338            let i_a = g.input(array_type(vec![64], BIT))?;
2339            let i_b = g.input(array_type(vec![64], BIT))?;
2340            let o = g.custom_op(custom_op, vec![i_a, i_b])?;
2341            g.set_output_node(o)?;
2342            g.finalize()?;
2343
2344            c.set_main_graph(g.clone())?;
2345            c.finalize()?;
2346
2347            let inline_config = InlineConfig {
2348                default_mode: InlineMode::DepthOptimized(DepthOptimizationLevel::Default),
2349                ..Default::default()
2350            };
2351            let instantiated_context = run_instantiation_pass(c)?.get_context();
2352            let inlined_context = inline_operations(instantiated_context, inline_config.clone())?;
2353            let num_nodes = inlined_context.get_main_graph()?.get_num_nodes();
2354
2355            assert!(num_nodes <= 200);
2356        }
2357        Ok(())
2358    }
2359}