use crate::custom_ops::{CustomOperation, CustomOperationBody, Not};
use crate::data_types::{array_type, scalar_type, ArrayShape, Type, BIT};
use crate::data_values::Value;
use crate::errors::Result;
use crate::graphs::*;
use crate::ops::utils::pull_out_bits;
use crate::ops::utils::{expand_dims, validate_arguments_in_broadcast_bit_ops};
use std::cmp::max;
use serde::{Deserialize, Serialize};
use super::utils::unsqueeze;
/// Intermediate state of a bitwise comparison of two inputs `a` and `b`.
///
/// Both nodes live in the same graph and are combined pairwise (see `join`)
/// until a single result per compared element remains.
#[derive(Clone)]
struct ComparisonResult {
// Built in `from_a_b` as `a + b + 1` and merged in `join` by multiplication;
// presumably (addition on BIT being modulo 2) this is 1 exactly where the
// covered bits of `a` and `b` agree.
a_equal_b: Node,
// Bits of `a`; after `join`, the `a` bit of the higher-priority (right-hand)
// operand wherever that operand reports a difference, otherwise the bit
// carried over from the left-hand operand.
a: Node,
}
/// Outcome of one `ComparisonResult::shrink` step.
struct ShrinkResult {
// Pairwise-joined state with half as many entries along the first dimension;
// `None` once at most one entry is left.
shrinked: Option<ComparisonResult>,
// Entry at index 0 that had no partner because the length was odd;
// `None` for even lengths.
remainder: Option<ComparisonResult>,
}
impl ComparisonResult {
    /// Creates the initial comparison state from the two input nodes.
    ///
    /// `a_equal_b` is computed as `a + b + 1`; `a` is stored unchanged.
    fn from_a_b(a: Node, b: Node) -> Result<Self> {
        let one = a.get_graph().ones(scalar_type(BIT))?;
        let a_equal_b = a.add(b)?.add(one)?;
        Ok(Self { a_equal_b, a })
    }

    /// Merges `self` with `rhs`, where `rhs` takes priority.
    ///
    /// The resulting `a` is `self.a * rhs.a_equal_b + rhs.a * (rhs.a_equal_b + 1)`,
    /// i.e. `rhs.a` is selected where `rhs` reports a difference and `self.a`
    /// otherwise. The resulting `a_equal_b` is the product of both flags.
    fn join(&self, rhs: &Self) -> Result<Self> {
        let one = self.a_equal_b.get_graph().ones(scalar_type(BIT))?;
        let kept_from_self = self.a.multiply(rhs.a_equal_b.clone())?;
        let rhs_differs = rhs.a_equal_b.add(one)?;
        let taken_from_rhs = rhs.a.multiply(rhs_differs)?;
        Ok(Self {
            // Field initializers run in source order: `a` first, then `a_equal_b`.
            a: kept_from_self.add(taken_from_rhs)?,
            a_equal_b: self.a_equal_b.multiply(rhs.a_equal_b.clone())?,
        })
    }

    /// Halves the first dimension by joining neighbouring entries.
    ///
    /// For an odd length the entry at index 0 is split off as `remainder` and
    /// the rest is paired up; in each pair the higher index takes priority.
    /// `shrinked` is `None` when the length is at most 1.
    fn shrink(&self) -> Result<ShrinkResult> {
        let bit_len = self.a_equal_b.get_type()?.get_shape()[0] as i64;
        let offset = bit_len % 2;
        let remainder = match offset {
            0 => None,
            _ => Some(Self {
                a_equal_b: self.a_equal_b.get(vec![0])?,
                a: self.a.get(vec![0])?,
            }),
        };
        let shrinked = if bit_len > 1 {
            let lower = self.sub_slice(offset, bit_len)?;
            let upper = self.sub_slice(offset + 1, bit_len)?;
            Some(lower.join(&upper)?)
        } else {
            None
        };
        Ok(ShrinkResult {
            shrinked,
            remainder,
        })
    }

    /// Takes every second entry of both nodes along the first dimension,
    /// starting at `start_offset` and stopping before `bit_len`.
    fn sub_slice(&self, start_offset: i64, bit_len: i64) -> Result<Self> {
        let every_other = || {
            vec![SliceElement::SubArray(
                Some(start_offset),
                Some(bit_len),
                Some(2),
            )]
        };
        Ok(Self {
            a_equal_b: self.a_equal_b.get_slice(every_other())?,
            a: self.a.get_slice(every_other())?,
        })
    }

    /// Applies the `Not` custom operation to `node` in this result's graph.
    fn negate(&self, node: Node) -> Result<Node> {
        self.a_equal_b
            .get_graph()
            .custom_op(CustomOperation::new(Not {}), vec![node])
    }

    /// Returns the negation of the stored `a` node.
    fn not_a(&self) -> Result<Node> {
        self.negate(self.a.clone())
    }

    /// Returns the equality flag node.
    fn equal(&self) -> Result<Node> {
        Ok(self.a_equal_b.clone())
    }

    /// Returns the negated equality flag.
    fn not_equal(&self) -> Result<Node> {
        self.negate(self.equal()?)
    }

    /// Returns `not(a) * not_equal`.
    fn less_than(&self) -> Result<Node> {
        let a_is_zero = self.not_a()?;
        let differs = self.not_equal()?;
        a_is_zero.multiply(differs)
    }

    /// Returns `a * not_equal`.
    fn greater_than(&self) -> Result<Node> {
        let differs = self.not_equal()?;
        self.a.multiply(differs)
    }

    /// Returns the negation of `less_than`.
    fn greater_than_equal_to(&self) -> Result<Node> {
        self.negate(self.less_than()?)
    }

    /// Returns the negation of `greater_than`.
    fn less_than_equal_to(&self) -> Result<Node> {
        self.negate(self.greater_than()?)
    }
}
fn build_comparison_graph(a: Node, b: Node) -> Result<ComparisonResult> {
let mut to_shrink = ComparisonResult::from_a_b(a, b)?;
let mut remainders = vec![];
loop {
let shrink_res = to_shrink.shrink()?;
if let Some(remainder) = shrink_res.remainder {
remainders.push(remainder);
}
if let Some(shrinked) = shrink_res.shrinked {
to_shrink = shrinked;
} else {
break;
}
}
let mut res = remainders[0].clone();
for remainder in remainders[1..].iter() {
res = res.join(remainder)?;
}
Ok(res)
}
/// Gives both nodes the same number of dimensions by calling `expand_dims`
/// on the lower-rank node with axes `0..(rank difference)`.
fn expand_to_same_dims(a: Node, b: Node) -> Result<(Node, Node)> {
    let rank_a = a.get_type()?.get_shape().len();
    let rank_b = b.get_type()?.get_shape().len();
    let target_rank = max(rank_a, rank_b);
    let pad_to_target = |node: Node, rank: usize| -> Result<Node> {
        let new_axes = (0..target_rank - rank).collect::<Vec<_>>();
        expand_dims(node, &new_axes)
    };
    let a = pad_to_target(a, rank_a)?;
    let b = pad_to_target(b, rank_b)?;
    Ok((a, b))
}
/// Adds to `ip` a constant mask that is 1 only at the last index of the last
/// dimension, which flips that bit when addition on BIT is modulo 2.
pub(super) fn flip_msb(ip: Node) -> Result<Node> {
    let shape = ip.get_type()?.get_shape();
    let graph = ip.get_graph();
    let msb_mask = get_msb_flip_constant(shape, &graph)?;
    ip.add(msb_mask)
}
fn get_msb_flip_constant(shape: ArrayShape, g: &Graph) -> Result<Node> {
let n = shape[shape.len() - 1] as usize;
let mut msb_mask = vec![0; n];
msb_mask[n - 1] = 1;
let mut msb_mask = g.constant(
array_type(vec![n as u64], BIT),
Value::from_flattened_array(&msb_mask, BIT)?,
)?;
while msb_mask.get_type()?.get_shape().len() < shape.len() {
msb_mask = unsqueeze(msb_mask, 0)?;
}
Ok(msb_mask)
}
/// Prepares one operand: flips its most significant bit for signed
/// comparisons, then passes it through `pull_out_bits`.
fn preprocess_input(signed_comparison: bool, node: Node) -> Result<Node> {
    if signed_comparison {
        pull_out_bits(flip_msb(node)?)
    } else {
        pull_out_bits(node)
    }
}
/// Brings both operands to the same rank and preprocesses each of them
/// (first `a`, then `b`) with `preprocess_input`.
fn preprocess_inputs(signed_comparison: bool, a: Node, b: Node) -> Result<(Node, Node)> {
    let (lhs, rhs) = expand_to_same_dims(a, b)?;
    let lhs = preprocess_input(signed_comparison, lhs)?;
    let rhs = preprocess_input(signed_comparison, rhs)?;
    Ok((lhs, rhs))
}
/// Checks that every argument has at least 2 entries in its last dimension.
///
/// Returns a runtime error naming the first offending argument index.
/// Panics if an argument's shape is empty.
fn validate_signed_arguments(custom_op_name: &str, arguments_types: Vec<Type>) -> Result<()> {
    let first_too_short = arguments_types
        .iter()
        .position(|arg_type| *arg_type.get_shape().last().unwrap() < 2);
    match first_too_short {
        Some(arg_id) => Err(runtime_error!(
            "{custom_op_name}: Signed input{arg_id} has less than 2 bits"
        )),
        None => Ok(()),
    }
}
/// Shared instantiation logic for all comparison custom operations.
///
/// Validates the argument types, creates a graph with two inputs, preprocesses
/// them, builds the comparison state and lets `post_process_result` pick the
/// output node (e.g. "greater than" or "equal") from that state.
///
/// # Errors
///
/// Fails if argument validation fails or any graph operation returns an error.
fn instantiate_comparison_custom_op(
    context: Context,
    arguments_types: Vec<Type>,
    signed_comparison: bool,
    custom_op_name: &str,
    post_process_result: impl FnOnce(&ComparisonResult) -> Result<Node>,
) -> Result<Graph> {
    validate_arguments_in_broadcast_bit_ops(arguments_types.clone(), custom_op_name)?;
    if signed_comparison {
        validate_signed_arguments(custom_op_name, arguments_types.clone())?;
    }
    let graph = context.create_graph()?;
    let lhs_input = graph.input(arguments_types[0].clone())?;
    let rhs_input = graph.input(arguments_types[1].clone())?;
    let (lhs_bits, rhs_bits) = preprocess_inputs(signed_comparison, lhs_input, rhs_input)?;
    let comparison = build_comparison_graph(lhs_bits, rhs_bits)?;
    let output = post_process_result(&comparison)?;
    graph.set_output_node(output)?;
    graph.finalize()?;
    Ok(graph)
}
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct GreaterThan {
pub signed_comparison: bool,
}
#[typetag::serde]
impl CustomOperationBody for GreaterThan {
fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
instantiate_comparison_custom_op(
context,
arguments_types,
self.signed_comparison,
&self.get_name(),
|res| res.greater_than(),
)
}
fn get_name(&self) -> String {
format!("GreaterThan(signed_comparison={})", self.signed_comparison)
}
}
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct NotEqual {}
#[typetag::serde]
impl CustomOperationBody for NotEqual {
fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
instantiate_comparison_custom_op(context, arguments_types, false, &self.get_name(), |res| {
res.not_equal()
})
}
fn get_name(&self) -> String {
"NotEqual".to_owned()
}
}
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LessThan {
pub signed_comparison: bool,
}
#[typetag::serde]
impl CustomOperationBody for LessThan {
fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
instantiate_comparison_custom_op(
context,
arguments_types,
self.signed_comparison,
&self.get_name(),
|res| res.less_than(),
)
}
fn get_name(&self) -> String {
format!("LessThan(signed_comparison={})", self.signed_comparison)
}
}
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct LessThanEqualTo {
pub signed_comparison: bool,
}
#[typetag::serde]
impl CustomOperationBody for LessThanEqualTo {
fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
instantiate_comparison_custom_op(
context,
arguments_types,
self.signed_comparison,
&self.get_name(),
|res| res.less_than_equal_to(),
)
}
fn get_name(&self) -> String {
format!(
"LessThanEqualTo(signed_comparison={})",
self.signed_comparison
)
}
}
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct GreaterThanEqualTo {
pub signed_comparison: bool,
}
#[typetag::serde]
impl CustomOperationBody for GreaterThanEqualTo {
fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
instantiate_comparison_custom_op(
context,
arguments_types,
self.signed_comparison,
&self.get_name(),
|res| res.greater_than_equal_to(),
)
}
fn get_name(&self) -> String {
format!(
"GreaterThanEqualTo(signed_comparison={})",
self.signed_comparison
)
}
}
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Equal {}
#[typetag::serde]
impl CustomOperationBody for Equal {
fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
instantiate_comparison_custom_op(context, arguments_types, false, &self.get_name(), |res| {
res.equal()
})
}
fn get_name(&self) -> String {
"Equal".to_owned()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::broadcast::broadcast_shapes;
use crate::custom_ops::run_instantiation_pass;
use crate::custom_ops::CustomOperation;
use crate::data_types::scalar_type;
use crate::data_types::tuple_type;
use crate::data_types::ArrayShape;
use crate::data_types::{
array_type, ScalarType, INT16, INT32, INT64, INT8, UINT16, UINT32, UINT64, UINT8,
};
use crate::data_values::Value;
use crate::evaluators::random_evaluate;
use crate::graphs::create_context;
use crate::inline::inline_common::DepthOptimizationLevel;
use crate::inline::inline_ops::inline_operations;
use crate::inline::inline_ops::InlineConfig;
use crate::inline::inline_ops::InlineMode;
fn test_unsigned_greater_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(GreaterThan {
signed_comparison: false,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn test_signed_greater_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(GreaterThan {
signed_comparison: true,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
let random_val = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?;
let op = random_val.to_u8(BIT)?;
Ok(op)
}
fn test_not_equal_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(CustomOperation::new(NotEqual {}), vec![i_a, i_b])?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn get_u_scalar_type_from_bits(bit_size: u64) -> Result<ScalarType> {
match bit_size {
8 => Ok(UINT8),
16 => Ok(UINT16),
32 => Ok(UINT32),
64 => Ok(UINT64),
_ => Err(runtime_error!("Unsupported bit size")),
}
}
fn get_s_scalar_type_from_bits(bit_size: u64) -> Result<ScalarType> {
match bit_size {
8 => Ok(INT8),
16 => Ok(INT16),
32 => Ok(INT32),
64 => Ok(INT64),
_ => Err(runtime_error!("Unsupported bit size")),
}
}
fn test_unsigned_comparison_cust_op_for_vec_helper(
comparison_op: CustomOperation,
a: Vec<u64>,
b: Vec<u64>,
shape_a: ArrayShape,
shape_b: ArrayShape,
) -> Result<Vec<u64>> {
let bit_vector_len_a = shape_a[shape_a.len() - 1];
let bit_vector_len_b = shape_b[shape_b.len() - 1];
let data_scalar_type_a = get_u_scalar_type_from_bits(bit_vector_len_a)?;
let data_scalar_type_b = get_u_scalar_type_from_bits(bit_vector_len_b)?;
let c = create_context()?;
let g = c.create_graph()?;
let i_va = g.input(array_type(shape_a.clone(), BIT))?;
let i_vb = g.input(array_type(shape_b.clone(), BIT))?;
let o = g.custom_op(comparison_op.clone(), vec![i_va, i_vb])?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, data_scalar_type_a)?;
let v_b = Value::from_flattened_array(&b, data_scalar_type_b)?;
let broadcasted_output_shape = broadcast_shapes(
shape_a[0..(shape_a.len() - 1)].to_vec(),
shape_b[0..(shape_b.len() - 1)].to_vec(),
)?;
let result = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?
.to_flattened_array_u64(array_type(broadcasted_output_shape, BIT))?;
Ok(result)
}
fn test_signed_comparison_cust_op_for_vec_helper(
comparison_op: CustomOperation,
a: Vec<i64>,
b: Vec<i64>,
shape_a: ArrayShape,
shape_b: ArrayShape,
) -> Result<Vec<u64>> {
let bit_vector_len_a = shape_a[shape_a.len() - 1];
let bit_vector_len_b = shape_b[shape_b.len() - 1];
let data_scalar_type_a = get_s_scalar_type_from_bits(bit_vector_len_a)?;
let data_scalar_type_b = get_s_scalar_type_from_bits(bit_vector_len_b)?;
let c = create_context()?;
let g = c.create_graph()?;
let i_va = g.input(array_type(shape_a.clone(), BIT))?;
let i_vb = g.input(array_type(shape_b.clone(), BIT))?;
let o = g.custom_op(comparison_op.clone(), vec![i_va, i_vb])?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, data_scalar_type_a)?;
let v_b = Value::from_flattened_array(&b, data_scalar_type_b)?;
let broadcasted_output_shape = broadcast_shapes(
shape_a[0..(shape_a.len() - 1)].to_vec(),
shape_b[0..(shape_b.len() - 1)].to_vec(),
)?;
let result = random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?
.to_flattened_array_u64(array_type(broadcasted_output_shape, BIT))?;
Ok(result)
}
fn test_unsigned_less_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(LessThan {
signed_comparison: false,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn test_signed_less_than_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(LessThan {
signed_comparison: true,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn test_unsigned_less_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(LessThanEqualTo {
signed_comparison: false,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn test_signed_less_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(LessThanEqualTo {
signed_comparison: true,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn test_unsigned_greater_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(GreaterThanEqualTo {
signed_comparison: false,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn test_signed_greater_than_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(
CustomOperation::new(GreaterThanEqualTo {
signed_comparison: true,
}),
vec![i_a, i_b],
)?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
fn test_equal_to_cust_op_helper(a: Vec<u8>, b: Vec<u8>) -> Result<u8> {
let c = create_context()?;
let g = c.create_graph()?;
let i_a = g.input(array_type(vec![a.len() as u64], BIT))?;
let i_b = g.input(array_type(vec![b.len() as u64], BIT))?;
let o = g.custom_op(CustomOperation::new(Equal {}), vec![i_a, i_b])?;
g.set_output_node(o)?;
g.finalize()?;
c.set_main_graph(g.clone())?;
c.finalize()?;
let mapped_c = run_instantiation_pass(c)?;
let v_a = Value::from_flattened_array(&a, BIT)?;
let v_b = Value::from_flattened_array(&b, BIT)?;
Ok(random_evaluate(mapped_c.mappings.get_graph(g), vec![v_a, v_b])?.to_u8(BIT)?)
}
#[test]
fn test_greater_than_cust_op() {
    // (a, b, expected a > b) for every pair of one-bit unsigned inputs.
    let cases: &[(u8, u8, u8)] = &[(0, 0, 0), (0, 1, 0), (1, 0, 1), (1, 1, 0)];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_unsigned_greater_than_cust_op_helper(vec![a], vec![b])?,
                expected,
                "a = {a}, b = {b}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_signed_greater_than_cust_op() {
    // (a, b, expected a > b) for two-bit signed inputs; the last entry of each
    // operand is the bit flipped by the signed comparison.
    let cases: &[([u8; 2], [u8; 2], u8)] = &[
        ([0, 0], [0, 0], 0),
        ([0, 0], [1, 0], 0),
        ([1, 0], [0, 0], 1),
        ([1, 0], [1, 0], 0),
        ([0, 1], [0, 1], 0),
        ([0, 1], [1, 1], 0),
        ([1, 1], [0, 1], 1),
        ([1, 1], [1, 1], 0),
        ([0, 1], [0, 0], 0),
        ([0, 0], [0, 1], 1),
        ([0, 1], [1, 0], 0),
        ([0, 0], [1, 1], 1),
        ([1, 1], [0, 0], 0),
        ([1, 0], [0, 1], 1),
        ([1, 1], [1, 0], 0),
        ([1, 0], [1, 1], 1),
    ];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_signed_greater_than_cust_op_helper(a.to_vec(), b.to_vec())?,
                expected,
                "a = {a:?}, b = {b:?}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_unsigned_less_than_cust_op() {
    // (a, b, expected a < b) for every pair of one-bit unsigned inputs.
    let cases: &[(u8, u8, u8)] = &[(0, 0, 0), (0, 1, 1), (1, 0, 0), (1, 1, 0)];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_unsigned_less_than_cust_op_helper(vec![a], vec![b])?,
                expected,
                "a = {a}, b = {b}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_signed_less_than_cust_op() {
    // (a, b, expected a < b) for two-bit signed inputs.
    let cases: &[([u8; 2], [u8; 2], u8)] = &[
        ([0, 0], [0, 0], 0),
        ([0, 0], [1, 0], 1),
        ([1, 0], [0, 0], 0),
        ([1, 0], [1, 0], 0),
        ([0, 1], [0, 1], 0),
        ([0, 1], [1, 1], 1),
        ([1, 1], [0, 1], 0),
        ([1, 1], [1, 1], 0),
        ([0, 1], [0, 0], 1),
        ([0, 0], [0, 1], 0),
        ([0, 1], [1, 0], 1),
        ([0, 0], [1, 1], 0),
        ([1, 1], [0, 0], 1),
        ([1, 0], [0, 1], 0),
        ([1, 1], [1, 0], 1),
        ([1, 0], [1, 1], 0),
    ];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_signed_less_than_cust_op_helper(a.to_vec(), b.to_vec())?,
                expected,
                "a = {a:?}, b = {b:?}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_unsigned_less_than_or_eq_to_cust_op() {
    // (a, b, expected a <= b) for every pair of one-bit unsigned inputs.
    let cases: &[(u8, u8, u8)] = &[(0, 0, 1), (0, 1, 1), (1, 0, 0), (1, 1, 1)];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_unsigned_less_than_equal_to_cust_op_helper(vec![a], vec![b])?,
                expected,
                "a = {a}, b = {b}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_signed_less_than_or_eq_to_cust_op() {
    // (a, b, expected a <= b) for two-bit signed inputs.
    let cases: &[([u8; 2], [u8; 2], u8)] = &[
        ([0, 0], [0, 0], 1),
        ([0, 0], [1, 0], 1),
        ([1, 0], [0, 0], 0),
        ([1, 0], [1, 0], 1),
        ([0, 1], [0, 1], 1),
        ([0, 1], [1, 1], 1),
        ([1, 1], [0, 1], 0),
        ([1, 1], [1, 1], 1),
        ([0, 1], [0, 0], 1),
        ([0, 0], [0, 1], 0),
        ([0, 1], [1, 0], 1),
        ([0, 0], [1, 1], 0),
        ([1, 1], [0, 0], 1),
        ([1, 0], [0, 1], 0),
        ([1, 1], [1, 0], 1),
        ([1, 0], [1, 1], 0),
    ];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_signed_less_than_equal_to_cust_op_helper(a.to_vec(), b.to_vec())?,
                expected,
                "a = {a:?}, b = {b:?}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_unsigned_greater_than_or_eq_to_cust_op() {
    // (a, b, expected a >= b) for every pair of one-bit unsigned inputs.
    let cases: &[(u8, u8, u8)] = &[(0, 0, 1), (0, 1, 0), (1, 0, 1), (1, 1, 1)];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_unsigned_greater_than_equal_to_cust_op_helper(vec![a], vec![b])?,
                expected,
                "a = {a}, b = {b}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_signed_greater_than_or_eq_to_cust_op() {
    // (a, b, expected a >= b) for two-bit signed inputs.
    let cases: &[([u8; 2], [u8; 2], u8)] = &[
        ([0, 0], [0, 0], 1),
        ([0, 0], [1, 0], 0),
        ([1, 0], [0, 0], 1),
        ([1, 0], [1, 0], 1),
        ([0, 1], [0, 1], 1),
        ([0, 1], [1, 1], 0),
        ([1, 1], [0, 1], 1),
        ([1, 1], [1, 1], 1),
        ([0, 1], [0, 0], 0),
        ([0, 0], [0, 1], 1),
        ([0, 1], [1, 0], 0),
        ([0, 0], [1, 1], 1),
        ([1, 1], [0, 0], 0),
        ([1, 0], [0, 1], 1),
        ([1, 1], [1, 0], 0),
        ([1, 0], [1, 1], 1),
    ];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_signed_greater_than_equal_to_cust_op_helper(a.to_vec(), b.to_vec())?,
                expected,
                "a = {a:?}, b = {b:?}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_not_equal_cust_op() {
    // (a, b, expected a != b) for every pair of one-bit inputs.
    let cases: &[(u8, u8, u8)] = &[(0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0)];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_not_equal_cust_op_helper(vec![a], vec![b])?,
                expected,
                "a = {a}, b = {b}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_equal_to_cust_op() {
    // (a, b, expected a == b) for every pair of one-bit inputs.
    let cases: &[(u8, u8, u8)] = &[(0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 1)];
    || -> Result<()> {
        for &(a, b, expected) in cases {
            assert_eq!(
                test_equal_to_cust_op_helper(vec![a], vec![b])?,
                expected,
                "a = {a}, b = {b}"
            );
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_unsigned_multiple_bit_comparisons_cust_op() {
    // Three-bit decomposition of a value in 0..8, lowest bit first.
    let to_bits = |v: u8| -> Vec<u8> { vec![v & 1, (v & 2) >> 1, (v & 4) >> 2] };
    || -> Result<()> {
        for i in 0..8u8 {
            for j in 0..8u8 {
                let a = to_bits(i);
                let b = to_bits(j);
                assert_eq!(
                    test_unsigned_greater_than_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(i > j)
                );
                assert_eq!(
                    test_unsigned_less_than_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(i < j)
                );
                assert_eq!(
                    test_unsigned_greater_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(i >= j)
                );
                assert_eq!(
                    test_unsigned_less_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(i <= j)
                );
                assert_eq!(
                    test_not_equal_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(i != j)
                );
                assert_eq!(
                    test_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(i == j)
                );
            }
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_signed_multiple_bit_comparisons_cust_op() {
    // Three-bit decomposition of a value in 0..8, lowest bit first.
    let to_bits = |v: u8| -> Vec<u8> { vec![v & 1, (v & 2) >> 1, (v & 4) >> 2] };
    // Signed interpretation of the same three bits: values above 3 wrap to negatives.
    let to_signed = |v: u8| -> i8 {
        if v > 3 {
            v as i8 - 8
        } else {
            v as i8
        }
    };
    || -> Result<()> {
        for i in 0..8u8 {
            for j in 0..8u8 {
                let a = to_bits(i);
                let b = to_bits(j);
                let s_i = to_signed(i);
                let s_j = to_signed(j);
                assert_eq!(
                    test_signed_greater_than_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(s_i > s_j)
                );
                assert_eq!(
                    test_signed_less_than_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(s_i < s_j)
                );
                assert_eq!(
                    test_signed_greater_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(s_i >= s_j)
                );
                assert_eq!(
                    test_signed_less_than_equal_to_cust_op_helper(a.clone(), b.clone())?,
                    u8::from(s_i <= s_j)
                );
            }
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_unsigned_malformed_basic_cust_ops() {
    || -> Result<()> {
        let cust_ops = vec![
            CustomOperation::new(GreaterThan {
                signed_comparison: false,
            }),
            CustomOperation::new(NotEqual {}),
        ];
        for cust_op in cust_ops.into_iter() {
            // Creates a fresh graph with one input per given type and checks
            // that attaching the custom operation to those inputs fails.
            let assert_rejects_inputs = |input_types: Vec<Type>| -> Result<()> {
                let c = create_context()?;
                let g = c.create_graph()?;
                let inputs = input_types
                    .into_iter()
                    .map(|input_type| g.input(input_type))
                    .collect::<Result<Vec<_>>>()?;
                assert!(g.custom_op(cust_op.clone(), inputs).is_err());
                Ok(())
            };

            // Wrong number of arguments.
            assert_rejects_inputs(vec![
                array_type(vec![1], BIT),
                array_type(vec![1], BIT),
                array_type(vec![1], BIT),
            ])?;
            // Scalar and tuple arguments.
            assert_rejects_inputs(vec![scalar_type(BIT), array_type(vec![1], BIT)])?;
            assert_rejects_inputs(vec![
                array_type(vec![1], BIT),
                tuple_type(vec![array_type(vec![1], BIT)]),
            ])?;
            // Non-BIT scalar types on either side.
            assert_rejects_inputs(vec![array_type(vec![1], INT16), array_type(vec![1], BIT)])?;
            assert_rejects_inputs(vec![array_type(vec![1], UINT16), array_type(vec![1], BIT)])?;
            assert_rejects_inputs(vec![array_type(vec![1], BIT), array_type(vec![1], INT32)])?;
            assert_rejects_inputs(vec![array_type(vec![1], BIT), array_type(vec![1], UINT32)])?;
            // Mismatching shapes.
            assert_rejects_inputs(vec![array_type(vec![1], BIT), array_type(vec![9], BIT)])?;
            assert_rejects_inputs(vec![array_type(vec![1], BIT), array_type(vec![1, 2], BIT)])?;

            // Vector inputs whose shapes or data do not fit together.
            let nine_values: Vec<u64> = vec![76, 20, 70, 249, 217, 190, 43, 83, 33710];
            let malformed_vec_cases: Vec<(Vec<u64>, Vec<u64>, ArrayShape, ArrayShape)> = vec![
                (
                    vec![170, 120, 61, 85],
                    nine_values.clone(),
                    vec![2, 2, 16],
                    vec![3, 3, 32],
                ),
                (
                    vec![170],
                    nine_values.clone(),
                    vec![2, 2, 16],
                    vec![3, 3, 16],
                ),
                (vec![], nine_values, vec![0, 64], vec![3, 3, 64]),
                (vec![170, 200], vec![], vec![2, 1, 64], vec![2, 2, 1, 64]),
            ];
            for (v_a, v_b, shape_a, shape_b) in malformed_vec_cases {
                assert!(test_unsigned_comparison_cust_op_for_vec_helper(
                    cust_op.clone(),
                    v_a,
                    v_b,
                    shape_a,
                    shape_b
                )
                .is_err());
            }
        }
        Ok(())
    }()
    .unwrap();
}
#[test]
fn test_signed_malformed_basic_cust_ops() {
    // Every input combination below is expected to be refused by a signed comparison:
    // first when the operation is attached to a graph, then through the vector helper.
    let run = || -> Result<()> {
        let signed_ops = vec![CustomOperation::new(GreaterThan {
            signed_comparison: true,
        })];
        for cust_op in signed_ops {
            // Builds a fresh context and graph with two inputs of the given types and
            // reports whether attaching `cust_op` to them returns an error.
            let is_rejected = |type_a, type_b| -> Result<bool> {
                let context = create_context()?;
                let graph = context.create_graph()?;
                let input_a = graph.input(type_a)?;
                let input_b = graph.input(type_b)?;
                Ok(graph
                    .custom_op(cust_op.clone(), vec![input_a, input_b])
                    .is_err())
            };

            // The assertions stay at the call sites so a failure points at the exact pair.
            assert!(is_rejected(
                array_type(vec![1], BIT),
                array_type(vec![1, 1], BIT)
            )?);
            assert!(is_rejected(
                array_type(vec![1, 1], BIT),
                array_type(vec![1], BIT)
            )?);
            assert!(is_rejected(
                array_type(vec![1, 64], BIT),
                array_type(vec![1], BIT)
            )?);
            assert!(is_rejected(scalar_type(BIT), array_type(vec![1], BIT))?);
            assert!(is_rejected(
                array_type(vec![1], UINT16),
                array_type(vec![1], BIT)
            )?);
            assert!(is_rejected(
                array_type(vec![1], BIT),
                array_type(vec![1], INT32)
            )?);
            assert!(is_rejected(
                array_type(vec![1], BIT),
                array_type(vec![1], UINT32)
            )?);
            assert!(is_rejected(
                array_type(vec![1], BIT),
                array_type(vec![9], BIT)
            )?);
            assert!(is_rejected(
                array_type(vec![1, 2, 3], BIT),
                array_type(vec![9], BIT)
            )?);
            assert!(is_rejected(
                array_type(vec![1], BIT),
                array_type(vec![1, 2], BIT)
            )?);

            // Four values shaped [2, 2, 16] against nine values shaped [3, 3, 32].
            assert!(test_signed_comparison_cust_op_for_vec_helper(
                cust_op.clone(),
                vec![170, 120, 61, 85],
                vec![
                    -1176658021,
                    -301476304,
                    788180273,
                    -1085188538,
                    -1358798926,
                    -120286105,
                    -1300942710,
                    -389618936,
                    258721418,
                ],
                vec![2, 2, 16],
                vec![3, 3, 32],
            )
            .is_err());

            // A single value supplied for shape [2, 2, 16] against nine values shaped [3, 3, 16].
            assert!(test_signed_comparison_cust_op_for_vec_helper(
                cust_op.clone(),
                vec![-14735],
                vec![16490, -10345, -31409, 2787, -15039, 26085, 7881, 32423, -23915],
                vec![2, 2, 16],
                vec![3, 3, 16],
            )
            .is_err());

            // Empty first operand with shape [0, 64].
            assert!(test_signed_comparison_cust_op_for_vec_helper(
                cust_op.clone(),
                vec![],
                vec![
                    -2600362169875399934,
                    6278463339984150730,
                    -2962726308672949899,
                    3404980137287029349,
                ],
                vec![0, 64],
                vec![2, 2, 64],
            )
            .is_err());

            // Empty second operand supplied for shape [2, 2, 1, 64].
            assert!(test_signed_comparison_cust_op_for_vec_helper(
                cust_op.clone(),
                vec![-2600362169875399934, 6278463339984150730],
                vec![],
                vec![2, 1, 64],
                vec![2, 2, 1, 64],
            )
            .is_err());
        }
        Ok(())
    };
    run().unwrap();
}
#[test]
fn test_unsigned_vector_comparisons() {
    // Checks unsigned comparison operations on concrete vectors. Every case passes
    // (operation, values of a, values of b, shape of a, shape of b) to the vector helper
    // and compares the flattened result with the expected 0/1 vector.
    let run = || -> Result<()> {
        let compare = test_unsigned_comparison_cust_op_for_vec_helper;

        // Constructors for the operations under test (unsigned flavour where applicable).
        let greater_than = || {
            CustomOperation::new(GreaterThan {
                signed_comparison: false,
            })
        };
        let less_than = || {
            CustomOperation::new(LessThan {
                signed_comparison: false,
            })
        };
        let greater_or_equal = || {
            CustomOperation::new(GreaterThanEqualTo {
                signed_comparison: false,
            })
        };
        let less_or_equal = || {
            CustomOperation::new(LessThanEqualTo {
                signed_comparison: false,
            })
        };
        let equal = || CustomOperation::new(Equal {});
        let not_equal = || CustomOperation::new(NotEqual {});

        // --- GreaterThan ---
        assert_eq!(
            compare(
                greater_than(),
                vec![170, 120, 61, 85],
                vec![76, 20, 70, 249, 217, 190, 43, 83, 33710, 27637, 43918, 38683],
                vec![2, 2, 64],
                vec![3, 2, 2, 64],
            )?,
            vec![1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]
        );
        assert_eq!(
            compare(
                greater_than(),
                vec![170, 120, 61, 85, 75, 149, 50, 54, 8811, 29720, 1009, 33126],
                vec![76, 20, 70, 249, 217, 190],
                vec![2, 3, 2, 32],
                vec![3, 2, 32],
            )?,
            vec![1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
        );
        assert_eq!(
            compare(
                greater_than(),
                vec![170, 120, 61, 85, 75, 149, 50, 54],
                vec![76, 20, 70, 249],
                vec![2, 2, 2, 16],
                vec![2, 2, 16],
            )?,
            vec![1, 1, 0, 0, 0, 1, 0, 0]
        );
        assert_eq!(
            compare(
                greater_than(),
                vec![170, 120, 61, 85, 75, 149, 50, 54],
                vec![76, 20, 70, 249, 217, 190, 43, 83],
                vec![2, 2, 2, 64],
                vec![2, 2, 2, 64],
            )?,
            vec![1, 1, 0, 0, 0, 0, 1, 0]
        );
        assert_eq!(
            compare(
                greater_than(),
                vec![170, 120, 61],
                vec![76, 20, 70],
                vec![3, 64],
                vec![3, 64],
            )?,
            vec![1, 1, 0]
        );

        // --- LessThan / LessThanEqualTo / GreaterThanEqualTo with differing shapes ---
        assert_eq!(
            compare(
                less_than(),
                vec![170, 120, 61, 85, 75, 149],
                vec![76, 20, 70],
                vec![2, 3, 64],
                vec![3, 64],
            )?,
            vec![0, 0, 1, 0, 0, 0]
        );
        assert_eq!(
            compare(
                less_or_equal(),
                vec![170, 120, 61, 85, 75, 70, 50, 54, 8811, 29720, 1009, 33126],
                vec![76, 1009, 70],
                vec![2, 2, 3, 64],
                vec![3, 64],
            )?,
            vec![0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0]
        );
        assert_eq!(
            compare(
                greater_or_equal(),
                vec![170, 120, 61, 85, 75, 76, 50, 54],
                vec![76],
                vec![2, 2, 2, 64],
                vec![1, 64],
            )?,
            vec![1, 1, 0, 1, 0, 1, 0, 0]
        );
        assert_eq!(
            compare(
                greater_or_equal(),
                vec![170],
                vec![76],
                vec![1, 64],
                vec![1, 64],
            )?,
            vec![1]
        );
        assert_eq!(
            compare(
                greater_or_equal(),
                vec![76],
                vec![76, 170],
                vec![1, 64],
                vec![1, 2, 64],
            )?,
            vec![1, 0]
        );
        assert_eq!(
            compare(
                greater_or_equal(),
                vec![83, 172, 214, 2, 68],
                vec![83],
                vec![5, 8],
                vec![8],
            )?,
            vec![1, 1, 1, 0, 0]
        );
        assert_eq!(
            compare(
                less_than(),
                vec![2],
                vec![83, 1, 2, 100],
                vec![1, 32],
                vec![2, 2, 32],
            )?,
            vec![1, 0, 0, 1]
        );
        assert_eq!(
            compare(
                less_or_equal(),
                vec![83, 2],
                vec![83, 172, 214, 2, 68, 34, 87, 45, 83, 23],
                vec![2, 1, 64],
                vec![2, 5, 64],
            )?,
            vec![1, 1, 1, 0, 0, 1, 1, 1, 1, 1]
        );

        // --- NotEqual ---
        assert_eq!(
            compare(
                not_equal(),
                vec![83, 2],
                vec![83, 172, 214, 2, 68, 2, 87, 45],
                vec![2, 1, 64],
                vec![2, 4, 64],
            )?,
            vec![0, 1, 1, 1, 1, 0, 1, 1]
        );
        assert_eq!(
            compare(
                not_equal(),
                vec![4, 2],
                vec![83, 21],
                vec![1, 2, 64],
                vec![2, 1, 64],
            )?,
            vec![1, 1, 1, 1]
        );
        assert_eq!(
            compare(
                not_equal(),
                vec![247, 170, 249, 162, 102, 243, 61, 203, 125],
                vec![247, 170, 249, 162, 102, 243, 61, 203, 125],
                vec![3, 3, 16],
                vec![3, 3, 16],
            )?,
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0]
        );

        // --- Equal ---
        assert_eq!(
            compare(
                equal(),
                vec![83, 2],
                vec![83, 172, 214, 2, 68, 2, 87, 45],
                vec![2, 1, 64],
                vec![2, 4, 64],
            )?,
            vec![1, 0, 0, 0, 0, 1, 0, 0]
        );
        assert_eq!(
            compare(
                equal(),
                vec![4, 2],
                vec![83, 21],
                vec![1, 2, 64],
                vec![2, 1, 64],
            )?,
            vec![0, 0, 0, 0]
        );
        assert_eq!(
            compare(
                equal(),
                vec![180, 16, 62, 141, 122, 217],
                vec![141, 122, 217, 100, 11, 29],
                vec![3, 2, 1, 16],
                vec![1, 2, 3, 16],
            )?,
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
        );

        // --- Smallest and largest 64-bit unsigned values, every element of a against
        // every element of b (4 x 4 results) ---
        let extremes = vec![0, 1, 18446744073709551614, 18446744073709551615];
        assert_eq!(
            compare(
                greater_than(),
                extremes.clone(),
                extremes.clone(),
                vec![4, 1, 64],
                vec![1, 4, 64],
            )?,
            vec![0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0]
        );
        assert_eq!(
            compare(
                greater_or_equal(),
                extremes.clone(),
                extremes.clone(),
                vec![4, 1, 64],
                vec![1, 4, 64],
            )?,
            vec![1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1]
        );
        assert_eq!(
            compare(
                not_equal(),
                extremes.clone(),
                extremes.clone(),
                vec![4, 1, 64],
                vec![1, 4, 64],
            )?,
            vec![0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
        );
        Ok(())
    };
    run().unwrap();
}
#[test]
fn test_signed_vector_comparisons() {
    // Checks signed comparison operations on concrete vectors. Every case passes
    // (operation, values of a, values of b, shape of a, shape of b) to the vector helper
    // and compares the flattened result with the expected 0/1 vector.
    let run = || -> Result<()> {
        let compare = test_signed_comparison_cust_op_for_vec_helper;

        // Constructors for the signed flavour of each operation under test.
        let greater_than = || {
            CustomOperation::new(GreaterThan {
                signed_comparison: true,
            })
        };
        let less_than = || {
            CustomOperation::new(LessThan {
                signed_comparison: true,
            })
        };
        let greater_or_equal = || {
            CustomOperation::new(GreaterThanEqualTo {
                signed_comparison: true,
            })
        };
        let less_or_equal = || {
            CustomOperation::new(LessThanEqualTo {
                signed_comparison: true,
            })
        };

        // Values around both ends of the 64-bit signed range and around zero, in
        // increasing order. Shapes [7, 1, 64] and [1, 7, 64] yield 7 x 7 results, written
        // below as one row per element of a.
        let extremes = vec![
            -9223372036854775808,
            -9223372036854775807,
            -1,
            0,
            1,
            9223372036854775806,
            9223372036854775807,
        ];
        assert_eq!(
            compare(
                greater_than(),
                extremes.clone(),
                extremes.clone(),
                vec![7, 1, 64],
                vec![1, 7, 64],
            )?,
            vec![
                0, 0, 0, 0, 0, 0, 0, //
                1, 0, 0, 0, 0, 0, 0, //
                1, 1, 0, 0, 0, 0, 0, //
                1, 1, 1, 0, 0, 0, 0, //
                1, 1, 1, 1, 0, 0, 0, //
                1, 1, 1, 1, 1, 0, 0, //
                1, 1, 1, 1, 1, 1, 0,
            ]
        );
        assert_eq!(
            compare(
                greater_or_equal(),
                extremes.clone(),
                extremes.clone(),
                vec![7, 1, 64],
                vec![1, 7, 64],
            )?,
            vec![
                1, 0, 0, 0, 0, 0, 0, //
                1, 1, 0, 0, 0, 0, 0, //
                1, 1, 1, 0, 0, 0, 0, //
                1, 1, 1, 1, 0, 0, 0, //
                1, 1, 1, 1, 1, 0, 0, //
                1, 1, 1, 1, 1, 1, 0, //
                1, 1, 1, 1, 1, 1, 1,
            ]
        );

        // --- GreaterThan on mixed-sign values ---
        assert_eq!(
            compare(
                greater_than(),
                vec![-6749, -1885, 7550, 9659],
                vec![9918, 3462, -5690, 3436, 3214, -1733, 6171, 3148, -3534, 8282, -4904, -5976],
                vec![2, 2, 64],
                vec![3, 2, 2, 64],
            )?,
            vec![0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
        );
        assert_eq!(
            compare(
                greater_than(),
                vec![-48, -9935, -745, 2360, -4597, -5271, 5130, -2632, 3112, 8089, 8293, 6058],
                vec![2913, 7260, 1388, 6205, 1855, 3246],
                vec![2, 3, 2, 32],
                vec![3, 2, 32],
            )?,
            vec![0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1]
        );
        assert_eq!(
            compare(
                greater_than(),
                vec![9838, -574, -4181, -8107, -2880, -2866, 2272, 3743],
                vec![626, 4664, 1490, -5118, 7485, 6160, 4221, 2092],
                vec![2, 2, 2, 64],
                vec![2, 2, 2, 64],
            )?,
            vec![1, 0, 0, 0, 0, 0, 0, 1]
        );

        // --- LessThan / LessThanEqualTo with a trailing dimension of 8 ---
        assert_eq!(
            compare(
                less_than(),
                vec![-75, 95, -84, 67, -81, 14],
                vec![-78, 21, -66],
                vec![2, 3, 8],
                vec![3, 8],
            )?,
            vec![0, 0, 1, 0, 1, 0]
        );
        assert_eq!(
            compare(
                less_or_equal(),
                vec![-52, -119, 30, -24, -74, -45, 66, 110, 21, 1, 95, -66],
                vec![33, -78, 39],
                vec![2, 2, 3, 8],
                vec![3, 8],
            )?,
            vec![1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1]
        );

        // --- Everything against -128 ---
        assert_eq!(
            compare(
                greater_or_equal(),
                vec![-128, 127, 0, 1, 0, -128, 1, 127],
                vec![-128],
                vec![2, 2, 2, 8],
                vec![1, 8],
            )?,
            vec![1, 1, 1, 1, 1, 1, 1, 1]
        );
        assert_eq!(
            compare(
                greater_than(),
                vec![-128, 127, 0, 1, 0, -128, 1, 127],
                vec![-128],
                vec![2, 2, 2, 8],
                vec![1, 8],
            )?,
            vec![0, 1, 1, 1, 1, 0, 1, 1]
        );
        Ok(())
    };
    run().unwrap();
}
#[test]
fn test_comparison_graph_size() -> Result<()> {
    // For every comparison operation, builds a graph comparing two 64-element bit arrays,
    // runs the instantiation pass and depth-optimized inlining, and checks that the
    // resulting main graph has at most 200 nodes.
    let mut custom_ops = vec![
        CustomOperation::new(Equal {}),
        CustomOperation::new(NotEqual {}),
    ];
    for signed_comparison in [false, true].iter().copied() {
        custom_ops.extend(vec![
            CustomOperation::new(GreaterThan { signed_comparison }),
            CustomOperation::new(LessThan { signed_comparison }),
            CustomOperation::new(GreaterThanEqualTo { signed_comparison }),
            CustomOperation::new(LessThanEqualTo { signed_comparison }),
        ]);
    }

    // The inlining configuration is the same for every operation; each run gets a clone.
    let inline_config = InlineConfig {
        default_mode: InlineMode::DepthOptimized(DepthOptimizationLevel::Default),
        ..Default::default()
    };

    for custom_op in custom_ops {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let operands = vec![
            graph.input(array_type(vec![64], BIT))?,
            graph.input(array_type(vec![64], BIT))?,
        ];
        let output = graph.custom_op(custom_op, operands)?;
        graph.set_output_node(output)?;
        graph.finalize()?;
        context.set_main_graph(graph.clone())?;
        context.finalize()?;

        let instantiated = run_instantiation_pass(context)?.get_context();
        let inlined = inline_operations(instantiated, inline_config.clone())?;
        let node_count = inlined.get_main_graph()?.get_num_nodes();
        assert!(node_count <= 200);
    }
    Ok(())
}
}