use crate::custom_ops::{CustomOperation, CustomOperationBody, Or};
use crate::data_types::{array_type, scalar_type, vector_type, Type, BIT};
use crate::errors::Result;
use crate::graphs::{Context, Graph, GraphAnnotation, SliceElement};
use crate::ops::multiplexer::Mux;
use crate::ops::utils::{pull_out_bits, put_in_bits};
use serde::{Deserialize, Serialize};
/// Custom operation that clips signed integers, given as bit arrays, to the range `[0, 2^k]`.
///
/// The operand's last axis holds the bits of each integer; the operation requires
/// `k <= num_bits - 2` so that `2^k` is representable below the sign bit.
#[derive(Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Clip2K {
    /// Exponent of the upper clipping bound `2^k`.
    pub k: u64,
}
#[typetag::serde]
impl CustomOperationBody for Clip2K {
fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
if arguments_types.len() != 1 {
return Err(runtime_error!("Invalid number of arguments for Clip"));
}
let input_type = arguments_types[0].clone();
if !input_type.is_array() || input_type.get_scalar_type() != BIT {
return Err(runtime_error!("Clip can only be applied to bit arrays"));
}
let input_shape = input_type.get_shape();
let num_bits = input_shape[input_shape.len() - 1];
if self.k >= num_bits - 1 {
return Err(runtime_error!(
"Clip(k) can be applied only whenever k <= num_bits - 2"
));
}
let bit_type = if input_shape.len() == 1 {
scalar_type(BIT)
} else {
array_type(input_shape[0..input_shape.len() - 1].to_vec(), BIT)
};
let aux_or_graph = context.create_graph()?;
let state = aux_or_graph.input(bit_type.clone())?;
let input = aux_or_graph.input(bit_type.clone())?;
let output_state =
aux_or_graph.custom_op(CustomOperation::new(Or {}), vec![state, input])?;
let empty = aux_or_graph.create_tuple(vec![])?;
let output = aux_or_graph.create_tuple(vec![output_state, empty])?;
aux_or_graph.set_output_node(output)?;
aux_or_graph.add_annotation(GraphAnnotation::AssociativeOperation)?;
aux_or_graph.finalize()?;
let g = context.create_graph()?;
let input = g.input(input_type)?;
let input_bits = pull_out_bits(input)?;
let is_negative = input_bits.get(vec![num_bits - 1])?;
let zero_bit = g.zeros(bit_type.clone())?;
let one_bit = g.ones(bit_type.clone())?;
let top_bits = input_bits
.get_slice(vec![SliceElement::SubArray(
Some(self.k as i64),
None,
None,
)])?
.array_to_vector()?;
let is_large_or_negative = g
.iterate(aux_or_graph, zero_bit.clone(), top_bits)?
.tuple_get(0)?;
let clipped_value = g
.create_tuple(vec![
zero_bit.repeat(self.k)?,
g.custom_op(
CustomOperation::new(Mux {}),
vec![is_negative, zero_bit.clone(), one_bit],
)?,
zero_bit.repeat(num_bits - self.k - 1)?,
])?
.reshape(vector_type(num_bits, bit_type))?
.vector_to_array()?;
g.set_output_node(put_in_bits(g.custom_op(
CustomOperation::new(Mux {}),
vec![is_large_or_negative, clipped_value, input_bits],
)?)?)?;
g.finalize()?;
Ok(g)
}
fn get_name(&self) -> String {
format!("Clip({})", self.k)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{array_type, tuple_type, INT32, INT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    #[test]
    fn test_well_behaved() {
        || -> Result<()> {
            let c = simple_context(|g| {
                let i = g.input(array_type(vec![19, 64], BIT))?;
                g.custom_op(CustomOperation::new(Clip2K { k: 10 }), vec![i])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let inputs = Value::from_flattened_array(
                &vec![
                    0,
                    1,
                    -1,
                    2,
                    -2,
                    1023,
                    -1023,
                    1024,
                    -1024,
                    1025,
                    -1025,
                    2048,
                    -2048,
                    2047,
                    -2047,
                    2049,
                    -2049,
                    i64::MIN,
                    i64::MAX,
                ],
                INT64,
            )?;
            let result_v = random_evaluate(mapped_c.get_context().get_main_graph()?, vec![inputs])?
                .to_flattened_array_u64(array_type(vec![19], INT64))?;
            assert_eq!(
                result_v,
                vec![0, 1, 0, 2, 0, 1023, 0, 1024, 0, 1024, 0, 1024, 0, 1024, 0, 1024, 0, 0, 1024]
            );
            Ok(())
        }()
        .unwrap();
        || -> Result<()> {
            let c = simple_context(|g| {
                let i = g.input(array_type(vec![64], BIT))?;
                g.custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![i])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let inputs = Value::from_scalar(123456789, INT64)?;
            let result_v = random_evaluate(mapped_c.get_context().get_main_graph()?, vec![inputs])?
                .to_u64(INT64)?;
            assert_eq!(result_v, 1 << 20);
            Ok(())
        }()
        .unwrap();
        // Largest admissible k for 64-bit integers: k = num_bits - 2.
        || -> Result<()> {
            let c = simple_context(|g| {
                let i = g.input(array_type(vec![4, 64], BIT))?;
                g.custom_op(CustomOperation::new(Clip2K { k: 62 }), vec![i])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let inputs = Value::from_flattened_array(
                &vec![(1i64 << 62) - 1, 1i64 << 62, i64::MAX, -1i64],
                INT64,
            )?;
            let result_v = random_evaluate(mapped_c.get_context().get_main_graph()?, vec![inputs])?
                .to_flattened_array_u64(array_type(vec![4], INT64))?;
            assert_eq!(result_v, vec![(1u64 << 62) - 1, 1 << 62, 1 << 62, 0]);
            Ok(())
        }()
        .unwrap();
    }

    #[test]
    fn test_malformed() {
        || -> Result<()> {
            let c = create_context()?;
            let g = c.create_graph()?;
            let i = g.input(array_type(vec![64], BIT))?;
            let i1 = g.input(array_type(vec![64], INT32))?;
            let i2 = g.input(tuple_type(vec![]))?;
            let i3 = g.input(array_type(vec![1], BIT))?;
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 64 }), vec![i.clone()])
                .is_err());
            // k = num_bits - 1 would place the bound 2^k on the sign bit.
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 63 }), vec![i])
                .is_err());
            // A 1-bit integer has no room for any bound below the sign bit.
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 0 }), vec![i3])
                .is_err());
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![])
                .is_err());
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![i1])
                .is_err());
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![i2])
                .is_err());
            Ok(())
        }()
        .unwrap();
    }
}