ciphercore_base/ops/
clip.rs

//! Clip function that returns a given value if it lies inside the interval [0, 2<sup>k</sup>] and clips values outside this interval to the nearest endpoint.
2use crate::custom_ops::{CustomOperation, CustomOperationBody, Or};
3use crate::data_types::{array_type, scalar_type, vector_type, Type, BIT};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, GraphAnnotation, SliceElement};
6use crate::ops::multiplexer::Mux;
7use crate::ops::utils::{pull_out_bits, put_in_bits};
8
9use serde::{Deserialize, Serialize};
10
11/// A structure that defines the custom operation Clip2K that computes elementwise the following clipping function:
12/// - 0 if input <= 0,
13/// - input if 0 < input < 2<sup>k</sup>,
14/// - 2<sup>k</sup> if >= 2<sup>k</sup>.
15///
16/// This function is an approximation of [the sigmoid function](https://en.wikipedia.org/wiki/Sigmoid_function).
17///
18/// An array of length-n bitstrings is accepted as input. These bitstrings are interpreted as signed integers.
19///
20/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
21///
22/// # Custom operation arguments
23///
24/// - Node containing a binary array
25///
26/// # Custom operation returns
27///
28/// New Clip2K node
29///
30/// # Example
31///
32/// ```
33/// # use ciphercore_base::graphs::create_context;
34/// # use ciphercore_base::data_types::{array_type, BIT};
35/// # use ciphercore_base::custom_ops::{CustomOperation};
36/// # use ciphercore_base::ops::clip::Clip2K;
37/// let c = create_context().unwrap();
38/// let g = c.create_graph().unwrap();
39/// let t = array_type(vec![2, 16], BIT);
40/// let n1 = g.input(t.clone()).unwrap();
41/// let n2 = g.custom_op(CustomOperation::new(Clip2K {k: 4}), vec![n1]).unwrap();
42/// ```
43#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
44pub struct Clip2K {
45    /// 2<sup>k</sup> is the upper threshold of clipping
46    pub k: u64,
47}
48
49#[typetag::serde]
50impl CustomOperationBody for Clip2K {
51    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
52        if arguments_types.len() != 1 {
53            return Err(runtime_error!("Invalid number of arguments for Clip"));
54        }
55        let input_type = arguments_types[0].clone();
56        if !input_type.is_array() || input_type.get_scalar_type() != BIT {
57            return Err(runtime_error!("Clip can only be applied to bit arrays"));
58        }
59        let input_shape = input_type.get_shape();
60        let num_bits = input_shape[input_shape.len() - 1];
61        if self.k >= num_bits - 1 {
62            return Err(runtime_error!(
63                "Clip(k) can be applied only whenever k <= num_bits - 2"
64            ));
65        }
66        let bit_type = if input_shape.len() == 1 {
67            scalar_type(BIT)
68        } else {
69            array_type(input_shape[0..input_shape.len() - 1].to_vec(), BIT)
70        };
71        let aux_or_graph = context.create_graph()?;
72        let state = aux_or_graph.input(bit_type.clone())?;
73        let input = aux_or_graph.input(bit_type.clone())?;
74        let output_state =
75            aux_or_graph.custom_op(CustomOperation::new(Or {}), vec![state, input])?;
76        let empty = aux_or_graph.create_tuple(vec![])?;
77        let output = aux_or_graph.create_tuple(vec![output_state, empty])?;
78        aux_or_graph.set_output_node(output)?;
79        aux_or_graph.add_annotation(GraphAnnotation::AssociativeOperation)?;
80        aux_or_graph.finalize()?;
81        let g = context.create_graph()?;
82        let input = g.input(input_type)?;
83        let input_bits = pull_out_bits(input)?;
84        let is_negative = input_bits.get(vec![num_bits - 1])?;
85        let zero_bit = g.zeros(bit_type.clone())?;
86        let one_bit = g.ones(bit_type.clone())?;
87        let top_bits = input_bits
88            .get_slice(vec![SliceElement::SubArray(
89                Some(self.k as i64),
90                None,
91                None,
92            )])?
93            .array_to_vector()?;
94        let is_large_or_negative = g
95            .iterate(aux_or_graph, zero_bit.clone(), top_bits)?
96            .tuple_get(0)?;
97        // clipped_value = if is_negative then 0, else 2^k
98        // obtained by concatenating a bunch of zeros,
99        // zero or one, then bunch of zeros again
100        let clipped_value = g
101            .create_tuple(vec![
102                zero_bit.repeat(self.k)?,
103                g.custom_op(
104                    CustomOperation::new(Mux {}),
105                    vec![is_negative, zero_bit.clone(), one_bit],
106                )?,
107                zero_bit.repeat(num_bits - self.k - 1)?,
108            ])?
109            .reshape(vector_type(num_bits, bit_type))?
110            .vector_to_array()?;
111        g.set_output_node(put_in_bits(g.custom_op(
112            CustomOperation::new(Mux {}),
113            vec![is_large_or_negative, clipped_value, input_bits],
114        )?)?)?;
115        g.finalize()?;
116        Ok(g)
117    }
118
119    fn get_name(&self) -> String {
120        format!("Clip({})", self.k)
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{array_type, tuple_type, INT32, INT64};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    #[test]
    fn test_well_behaved() {
        || -> Result<()> {
            let c = simple_context(|g| {
                let i = g.input(array_type(vec![19, 64], BIT))?;
                g.custom_op(CustomOperation::new(Clip2K { k: 10 }), vec![i])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let inputs = Value::from_flattened_array(
                &vec![
                    0,
                    1,
                    -1,
                    2,
                    -2,
                    1023,
                    -1023,
                    1024,
                    -1024,
                    1025,
                    -1025,
                    2048,
                    -2048,
                    2047,
                    -2047,
                    2049,
                    -2049,
                    i64::MIN,
                    i64::MAX,
                ],
                INT64,
            )?;
            let result_v = random_evaluate(mapped_c.get_context().get_main_graph()?, vec![inputs])?
                .to_flattened_array_u64(array_type(vec![19], INT64))?;
            assert_eq!(
                result_v,
                vec![0, 1, 0, 2, 0, 1023, 0, 1024, 0, 1024, 0, 1024, 0, 1024, 0, 1024, 0, 0, 1024]
            );
            Ok(())
        }()
        .unwrap();
        || -> Result<()> {
            let c = simple_context(|g| {
                let i = g.input(array_type(vec![64], BIT))?;
                g.custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![i])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let inputs = Value::from_scalar(123456789, INT64)?;
            let result_v = random_evaluate(mapped_c.get_context().get_main_graph()?, vec![inputs])?
                .to_u64(INT64)?;
            assert_eq!(result_v, 1 << 20);
            Ok(())
        }()
        .unwrap();
    }

    /// Exercises the largest accepted threshold, k = num_bits - 2 (62 for 64-bit inputs).
    #[test]
    fn test_largest_k() {
        || -> Result<()> {
            let c = simple_context(|g| {
                let i = g.input(array_type(vec![6, 64], BIT))?;
                g.custom_op(CustomOperation::new(Clip2K { k: 62 }), vec![i])
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let inputs = Value::from_flattened_array(
                &vec![0, (1i64 << 62) - 1, 1i64 << 62, i64::MAX, -1, i64::MIN],
                INT64,
            )?;
            let result_v = random_evaluate(mapped_c.get_context().get_main_graph()?, vec![inputs])?
                .to_flattened_array_u64(array_type(vec![6], INT64))?;
            assert_eq!(result_v, vec![0, (1 << 62) - 1, 1 << 62, 1 << 62, 0, 0]);
            Ok(())
        }()
        .unwrap();
    }

    #[test]
    fn test_malformed() {
        || -> Result<()> {
            let c = create_context()?;
            let g = c.create_graph()?;
            let i = g.input(array_type(vec![64], BIT))?;
            let i1 = g.input(array_type(vec![64], INT32))?;
            let i2 = g.input(tuple_type(vec![]))?;
            let i3 = g.input(array_type(vec![1], BIT))?;
            // k = num_bits - 1 is the smallest rejected threshold.
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 63 }), vec![i.clone()])
                .is_err());
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 64 }), vec![i])
                .is_err());
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![])
                .is_err());
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![i1])
                .is_err());
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![i2])
                .is_err());
            // A single bit leaves no room for any threshold, not even k = 0.
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k: 0 }), vec![i3])
                .is_err());
            Ok(())
        }()
        .unwrap();
    }
}