1use crate::custom_ops::{CustomOperation, CustomOperationBody, Or};
3use crate::data_types::{array_type, scalar_type, vector_type, Type, BIT};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, GraphAnnotation, SliceElement};
6use crate::ops::multiplexer::Mux;
7use crate::ops::utils::{pull_out_bits, put_in_bits};
8
9use serde::{Deserialize, Serialize};
10
11#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
44pub struct Clip2K {
45 pub k: u64,
47}
48
49#[typetag::serde]
50impl CustomOperationBody for Clip2K {
51 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
52 if arguments_types.len() != 1 {
53 return Err(runtime_error!("Invalid number of arguments for Clip"));
54 }
55 let input_type = arguments_types[0].clone();
56 if !input_type.is_array() || input_type.get_scalar_type() != BIT {
57 return Err(runtime_error!("Clip can only be applied to bit arrays"));
58 }
59 let input_shape = input_type.get_shape();
60 let num_bits = input_shape[input_shape.len() - 1];
61 if self.k >= num_bits - 1 {
62 return Err(runtime_error!(
63 "Clip(k) can be applied only whenever k <= num_bits - 2"
64 ));
65 }
66 let bit_type = if input_shape.len() == 1 {
67 scalar_type(BIT)
68 } else {
69 array_type(input_shape[0..input_shape.len() - 1].to_vec(), BIT)
70 };
71 let aux_or_graph = context.create_graph()?;
72 let state = aux_or_graph.input(bit_type.clone())?;
73 let input = aux_or_graph.input(bit_type.clone())?;
74 let output_state =
75 aux_or_graph.custom_op(CustomOperation::new(Or {}), vec![state, input])?;
76 let empty = aux_or_graph.create_tuple(vec![])?;
77 let output = aux_or_graph.create_tuple(vec![output_state, empty])?;
78 aux_or_graph.set_output_node(output)?;
79 aux_or_graph.add_annotation(GraphAnnotation::AssociativeOperation)?;
80 aux_or_graph.finalize()?;
81 let g = context.create_graph()?;
82 let input = g.input(input_type)?;
83 let input_bits = pull_out_bits(input)?;
84 let is_negative = input_bits.get(vec![num_bits - 1])?;
85 let zero_bit = g.zeros(bit_type.clone())?;
86 let one_bit = g.ones(bit_type.clone())?;
87 let top_bits = input_bits
88 .get_slice(vec![SliceElement::SubArray(
89 Some(self.k as i64),
90 None,
91 None,
92 )])?
93 .array_to_vector()?;
94 let is_large_or_negative = g
95 .iterate(aux_or_graph, zero_bit.clone(), top_bits)?
96 .tuple_get(0)?;
97 let clipped_value = g
101 .create_tuple(vec![
102 zero_bit.repeat(self.k)?,
103 g.custom_op(
104 CustomOperation::new(Mux {}),
105 vec![is_negative, zero_bit.clone(), one_bit],
106 )?,
107 zero_bit.repeat(num_bits - self.k - 1)?,
108 ])?
109 .reshape(vector_type(num_bits, bit_type))?
110 .vector_to_array()?;
111 g.set_output_node(put_in_bits(g.custom_op(
112 CustomOperation::new(Mux {}),
113 vec![is_large_or_negative, clipped_value, input_bits],
114 )?)?)?;
115 g.finalize()?;
116 Ok(g)
117 }
118
119 fn get_name(&self) -> String {
120 format!("Clip({})", self.k)
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{array_type, tuple_type, INT32, INT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    /// Applies `Clip2K { k }` to one bit-array input of shape `shape`, runs the instantiation
    /// pass and evaluates the resulting main graph on `input`.
    fn evaluate_clip(shape: Vec<u64>, k: u64, input: Value) -> Result<Value> {
        let context = simple_context(|g| {
            let bits = g.input(array_type(shape.clone(), BIT))?;
            g.custom_op(CustomOperation::new(Clip2K { k }), vec![bits])
        })?;
        let instantiated = run_instantiation_pass(context)?;
        random_evaluate(instantiated.get_context().get_main_graph()?, vec![input])
    }

    /// Clips a batch of 64-bit signed values with `k = 10` (threshold 1024).
    fn check_batch_clipping() -> Result<()> {
        let cases: [(i64, u64); 19] = [
            (0, 0),
            (1, 1),
            (-1, 0),
            (2, 2),
            (-2, 0),
            (1023, 1023),
            (-1023, 0),
            (1024, 1024),
            (-1024, 0),
            (1025, 1024),
            (-1025, 0),
            (2048, 1024),
            (-2048, 0),
            (2047, 1024),
            (-2047, 0),
            (2049, 1024),
            (-2049, 0),
            (i64::MIN, 0),
            (i64::MAX, 1024),
        ];
        let inputs: Vec<i64> = cases.iter().map(|&(x, _)| x).collect();
        let expected: Vec<u64> = cases.iter().map(|&(_, y)| y).collect();
        let count = cases.len() as u64;
        let outputs = evaluate_clip(
            vec![count, 64],
            10,
            Value::from_flattened_array(&inputs, INT64)?,
        )?
        .to_flattened_array_u64(array_type(vec![count], INT64))?;
        assert_eq!(outputs, expected);
        Ok(())
    }

    /// Clips a single scalar well above the `k = 20` threshold.
    fn check_scalar_clipping() -> Result<()> {
        let output = evaluate_clip(vec![64], 20, Value::from_scalar(123456789, INT64)?)?
            .to_u64(INT64)?;
        assert_eq!(output, 1 << 20);
        Ok(())
    }

    /// Every (k, arguments) pair below must be rejected at graph-construction time.
    fn check_rejected_arguments() -> Result<()> {
        let context = create_context()?;
        let g = context.create_graph()?;
        let bits = g.input(array_type(vec![64], BIT))?;
        let ints = g.input(array_type(vec![64], INT32))?;
        let empty_tuple = g.input(tuple_type(vec![]))?;
        let rejected = vec![
            (64, vec![bits]),
            (20, vec![]),
            (20, vec![ints]),
            (20, vec![empty_tuple]),
        ];
        for (k, arguments) in rejected {
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k }), arguments)
                .is_err());
        }
        Ok(())
    }

    #[test]
    fn test_well_behaved() {
        check_batch_clipping().unwrap();
        check_scalar_clipping().unwrap();
    }

    #[test]
    fn test_malformed() {
        check_rejected_arguments().unwrap();
    }
}