1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
//! Clip function that returns a given value if it is inside of the interval [0,2<sup>k</sup>] and clips values outside this interval to its edges.
use crate::custom_ops::{CustomOperation, CustomOperationBody, Or};
use crate::data_types::{array_type, get_size_in_bits, scalar_type, vector_type, Type, BIT};
use crate::data_values::Value;
use crate::errors::Result;
use crate::graphs::{Context, Graph, GraphAnnotation, SliceElement};
use crate::ops::multiplexer::Mux;
use crate::ops::utils::{pull_out_bits, put_in_bits};

use serde::{Deserialize, Serialize};

/// A structure that defines the custom operation Clip2K that computes elementwise the following clipping function:
/// - 0 if input <= 0,
/// - input if 0 < input < 2<sup>k</sup>,
/// - 2<sup>k</sup> if input >= 2<sup>k</sup>.
///
/// This function is an approximation of [the sigmoid function](https://en.wikipedia.org/wiki/Sigmoid_function).
///
/// An array of length-n bitstrings is accepted as input. These bitstrings are interpreted as signed integers:
/// n is the size of the last dimension of the input array, and the bit at index n-1 is treated as the sign bit.
///
/// Instantiation fails unless exactly one argument is given, that argument is an array of bits,
/// and k <= n - 2.
///
/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
///
/// # Custom operation arguments
///
/// - Node containing a binary array
///
/// # Custom operation returns
///
/// New Clip2K node
///
/// # Example
///
/// ```
/// # use ciphercore_base::graphs::create_context;
/// # use ciphercore_base::data_types::{array_type, BIT};
/// # use ciphercore_base::custom_ops::{CustomOperation};
/// # use ciphercore_base::ops::clip::Clip2K;
/// let c = create_context().unwrap();
/// let g = c.create_graph().unwrap();
/// let t = array_type(vec![2, 16], BIT);
/// let n1 = g.input(t.clone()).unwrap();
/// let n2 = g.custom_op(CustomOperation::new(Clip2K {k: 4}), vec![n1]).unwrap();
/// ```
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct Clip2K {
    /// 2<sup>k</sup> is the upper threshold of clipping; must satisfy k <= n - 2 for n-bit inputs
    pub k: u64,
}

#[typetag::serde]
impl CustomOperationBody for Clip2K {
    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
        if arguments_types.len() != 1 {
            return Err(runtime_error!("Invalid number of arguments for Clip"));
        }
        let input_type = arguments_types[0].clone();
        if !input_type.is_array() || input_type.get_scalar_type() != BIT {
            return Err(runtime_error!("Clip can only be applied to bit arrays"));
        }
        let input_shape = input_type.get_shape();
        let num_bits = input_shape[input_shape.len() - 1];
        if self.k >= num_bits - 1 {
            return Err(runtime_error!(
                "Clip(k) can be applied only whenever k <= num_bits - 2"
            ));
        }
        let bit_type = if input_shape.len() == 1 {
            scalar_type(BIT)
        } else {
            array_type(input_shape[0..input_shape.len() - 1].to_vec(), BIT)
        };
        let aux_or_graph = context.create_graph()?;
        let state = aux_or_graph.input(bit_type.clone())?;
        let input = aux_or_graph.input(bit_type.clone())?;
        let output_state =
            aux_or_graph.custom_op(CustomOperation::new(Or {}), vec![state, input])?;
        let empty = aux_or_graph.create_tuple(vec![])?;
        let output = aux_or_graph.create_tuple(vec![output_state, empty])?;
        aux_or_graph.set_output_node(output)?;
        aux_or_graph.add_annotation(GraphAnnotation::AssociativeOperation)?;
        aux_or_graph.finalize()?;
        let g = context.create_graph()?;
        let input = g.input(input_type)?;
        let input_bits = pull_out_bits(input)?;
        let is_negative = input_bits.get(vec![num_bits - 1])?;
        let zero_bit = g.constant(bit_type.clone(), Value::zero_of_type(bit_type.clone()))?;
        let all_ones = Value::from_bytes(vec![
            255u8;
            ((get_size_in_bits(bit_type.clone())? + 7) / 8)
                as usize
        ]);
        let one_bit = g.constant(bit_type.clone(), all_ones)?;
        let top_bits = input_bits
            .get_slice(vec![SliceElement::SubArray(
                Some(self.k as i64),
                None,
                None,
            )])?
            .array_to_vector()?;
        let is_large_or_negative = g
            .iterate(aux_or_graph, zero_bit.clone(), top_bits)?
            .tuple_get(0)?;
        // clipped_value = if is_negative then 0, else 2^k
        // obtained by concatenating a bunch of zeros,
        // zero or one, then bunch of zeros again
        let clipped_value = g
            .create_tuple(vec![
                zero_bit.repeat(self.k)?,
                g.custom_op(
                    CustomOperation::new(Mux {}),
                    vec![is_negative, zero_bit.clone(), one_bit],
                )?,
                zero_bit.repeat(num_bits - self.k - 1)?,
            ])?
            .reshape(vector_type(num_bits, bit_type))?
            .vector_to_array()?;
        g.set_output_node(put_in_bits(g.custom_op(
            CustomOperation::new(Mux {}),
            vec![is_large_or_negative, clipped_value, input_bits],
        )?)?)?;
        g.finalize()?;
        Ok(g)
    }

    fn get_name(&self) -> String {
        format!("Clip({})", self.k)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{array_type, tuple_type, INT32, INT64};
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;

    /// Clips a batch of 64-bit integers with k = 10 and compares against the expected values.
    fn check_batch_clip() -> Result<()> {
        // (input, expected output) pairs around zero, 2^10 and the extremes of i64.
        let cases: [(i64, u64); 19] = [
            (0, 0),
            (1, 1),
            (-1, 0),
            (2, 2),
            (-2, 0),
            (1023, 1023),
            (-1023, 0),
            (1024, 1024),
            (-1024, 0),
            (1025, 1024),
            (-1025, 0),
            (2048, 1024),
            (-2048, 0),
            (2047, 1024),
            (-2047, 0),
            (2049, 1024),
            (-2049, 0),
            (i64::MIN, 0),
            (i64::MAX, 1024),
        ];
        let (raw_inputs, expected): (Vec<i64>, Vec<u64>) = cases.iter().copied().unzip();
        let batch = cases.len() as u64;

        let c = create_context()?;
        let g = c.create_graph()?;
        let i = g.input(array_type(vec![batch, 64], BIT))?;
        let o = g.custom_op(CustomOperation::new(Clip2K { k: 10 }), vec![i])?;
        g.set_output_node(o)?;
        g.finalize()?;
        c.set_main_graph(g)?;
        c.finalize()?;
        let mapped_c = run_instantiation_pass(c)?;

        let inputs = Value::from_flattened_array(&raw_inputs, INT64)?;
        let main_graph = mapped_c.get_context().get_main_graph()?;
        let result_v = random_evaluate(main_graph, vec![inputs])?
            .to_flattened_array_u64(array_type(vec![batch], INT64))?;
        assert_eq!(result_v, expected);
        Ok(())
    }

    /// Clips a single 64-bit integer (one-dimensional bit array) with k = 20.
    fn check_scalar_clip() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        let i = g.input(array_type(vec![64], BIT))?;
        let o = g.custom_op(CustomOperation::new(Clip2K { k: 20 }), vec![i])?;
        g.set_output_node(o)?;
        g.finalize()?;
        c.set_main_graph(g)?;
        c.finalize()?;
        let mapped_c = run_instantiation_pass(c)?;

        let inputs = Value::from_scalar(123456789, INT64)?;
        let main_graph = mapped_c.get_context().get_main_graph()?;
        let result_v = random_evaluate(main_graph, vec![inputs])?.to_u64(INT64)?;
        assert_eq!(result_v, 1 << 20);
        Ok(())
    }

    #[test]
    fn test_well_behaved() {
        check_batch_clip().unwrap();
        check_scalar_clip().unwrap();
    }

    /// Verifies that invalid k, a wrong argument count and non-bit-array inputs are rejected.
    fn check_rejections() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        let i = g.input(array_type(vec![64], BIT))?;
        let i1 = g.input(array_type(vec![64], INT32))?;
        let i2 = g.input(tuple_type(vec![]))?;

        // Each entry is (k, arguments) for a call that must fail.
        let rejected = vec![
            (64, vec![i]),
            (20, vec![]),
            (20, vec![i1]),
            (20, vec![i2]),
        ];
        for (k, arguments) in rejected {
            assert!(g
                .custom_op(CustomOperation::new(Clip2K { k }), arguments)
                .is_err());
        }
        Ok(())
    }

    #[test]
    fn test_malformed() {
        check_rejections().unwrap();
    }
}