ciphercore_base/ops/
inverse_sqrt.rs

1//! Inverse square root approximation via [the Newton-Raphson method](https://en.wikipedia.org/wiki/Newton%27s_method#Square_root).
2use crate::custom_ops::{CustomOperation, CustomOperationBody, Or};
3use crate::data_types::{array_type, scalar_type, Type, BIT, INT64, UINT64};
4use crate::data_values::Value;
5use crate::errors::Result;
6use crate::graphs::{Context, Graph, GraphAnnotation};
7use crate::ops::utils::{pull_out_bits, put_in_bits};
8
9use serde::{Deserialize, Serialize};
10
11use super::utils::{constant_scalar, multiply_fixed_point, single_bit_to_arithmetic};
12
13/// A structure that defines the custom operation InverseSqrt that computes an approximate inverse square root using Newton iterations.
14///
15/// In particular, this operation computes an approximation of 2<sup>denominator_cap_2k</sup> / sqrt(input).
16///
17/// Input must be of the scalar type UINT64/INT64 and be in (0, 2<sup>2 * denominator_cap_2k - 1</sup>) range.
18/// The input is also assumed to be small enough (less than 2<sup>21</sup>), otherwise integer overflows
19/// are possible, yielding incorrect results.
20/// In case of INT64 type, negative inputs yield undefined behavior.
21///
22/// Optionally, an initial approximation for the Newton iterations can be provided.
23/// In this case, the operation might be faster and of lower depth, however, it must be guaranteed that
24/// 2<sup>2 * denominator_cap_2k - 2</sup> <= input * initial_approximation <= 2<sup>2 * denominator_cap_2k</sup>.
25///
26/// The following formula for the Newton iterations is used:
27///   x_{i + 1} = x_i * (3 / 2 - d / 2 * x_i * x_i).
28///
29/// # Custom operation arguments
30///
31/// - Node containing an unsigned 64-bit array or scalar to compute the inverse square root
32/// - (optional) Node containing an array or scalar that serves as an initial approximation of the Newton iterations
33///
34/// # Custom operation returns
35///
36/// New InverseSqrt node
37///
38/// # Example
39///
40/// ```
41/// # use ciphercore_base::graphs::create_context;
42/// # use ciphercore_base::data_types::{scalar_type, array_type, UINT64};
43/// # use ciphercore_base::custom_ops::{CustomOperation};
44/// # use ciphercore_base::ops::inverse_sqrt::InverseSqrt;
45/// let c = create_context().unwrap();
46/// let g = c.create_graph().unwrap();
47/// let t = array_type(vec![2, 3], UINT64);
48/// let n1 = g.input(t.clone()).unwrap();
49/// let guess_n = g.input(t.clone()).unwrap();
50/// let n2 = g.custom_op(CustomOperation::new(InverseSqrt {iterations: 10, denominator_cap_2k: 4}), vec![n1, guess_n]).unwrap();
51///
52// TODO: generalize to other types.
53#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
54pub struct InverseSqrt {
55    /// Number of iterations of the Newton-Raphson algorithm
56    pub iterations: u64,
57    /// Number of output bits that are approximated
58    pub denominator_cap_2k: u64,
59}
60
61#[typetag::serde]
62impl CustomOperationBody for InverseSqrt {
63    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
64        if arguments_types.len() != 1 && arguments_types.len() != 2 {
65            return Err(runtime_error!(
66                "Invalid number of arguments for InverseSqrt"
67            ));
68        }
69        let t = arguments_types[0].clone();
70        if !t.is_scalar() && !t.is_array() {
71            return Err(runtime_error!(
72                "Divisor in InverseSqrt must be a scalar or an array"
73            ));
74        }
75        let sc = t.get_scalar_type();
76        if sc != UINT64 && sc != INT64 {
77            return Err(runtime_error!(
78                "Divisor in InverseSqrt must consist of UINT64's or INT64's"
79            ));
80        }
81        let has_initial_approximation = arguments_types.len() == 2;
82        if has_initial_approximation {
83            let divisor_t = arguments_types[1].clone();
84            if divisor_t != t {
85                return Err(runtime_error!(
86                    "Divisor and initial approximation must have the same type."
87                ));
88            }
89        }
90        if self.denominator_cap_2k > 31 {
91            return Err(runtime_error!("denominator_cap_2k is too large."));
92        }
93
94        if self.denominator_cap_2k <= 1 {
95            return Err(runtime_error!("denominator_cap_2k is too small."));
96        }
97
98        let bit_type = if t.is_scalar() {
99            scalar_type(BIT)
100        } else {
101            array_type(t.get_shape(), BIT)
102        };
103
104        // Graph for identifying highest one bit.
105        let g_highest_one_bit = context.create_graph()?;
106        {
107            let input_state = g_highest_one_bit.input(bit_type.clone())?;
108            let input_bit = g_highest_one_bit.input(bit_type.clone())?;
109
110            let one = g_highest_one_bit.ones(scalar_type(BIT))?;
111            let not_input_state = one.add(input_state.clone())?;
112            // If input state is 1, then the highest bit has been already encountered.
113            // All other bits can be set to zero.
114            let output = not_input_state.multiply(input_bit)?;
115            // new_state is equal to input_state OR input_bit
116            // Hence, input state becomes and stays 1 once the highest bit has been encountered.
117            let new_state = input_state.add(output.clone())?;
118            let output_tuple = g_highest_one_bit.create_tuple(vec![new_state, output])?;
119            output_tuple.set_as_output()?;
120        }
121        g_highest_one_bit.add_annotation(GraphAnnotation::AssociativeOperation)?;
122        g_highest_one_bit.finalize()?;
123
124        let g = context.create_graph()?;
125        let divisor = g.input(t.clone())?;
126        let mut approximation = if has_initial_approximation {
127            g.input(t)?
128        } else if self.denominator_cap_2k == 0 {
129            let two = constant_scalar(&g, 2, sc)?;
130            g.zeros(t)?.add(two)?
131        } else {
132            let divisor_bits = pull_out_bits(divisor.a2b()?)?.array_to_vector()?;
133            let mut divisor_bits_reversed = vec![];
134            for i in 0..self.denominator_cap_2k {
135                // We group pairs of consecutive bits together for the purpose of the initial approximation.
136                // Namely, consider divisor to have digits (d_0, ..., d_31) in base-4. Then, if d_k is the highest
137                // non-zero digit, our approximation will be 2 ** (cap - k).
138                // Indeed, 4 ** k <= divisor < 4 ** (k + 1), so 2 ** (-k - 1) < 1 / sqrt(divisor) < 2 ** -k.
139                let index1 = constant_scalar(&g, 2 * self.denominator_cap_2k - 2 * i - 1, UINT64)?;
140                let index2 = constant_scalar(&g, 2 * self.denominator_cap_2k - 2 * i - 2, UINT64)?;
141                let bit1 = divisor_bits.vector_get(index1)?;
142                let bit2 = divisor_bits.vector_get(index2)?;
143                let bit = g.custom_op(CustomOperation::new(Or {}), vec![bit1, bit2])?;
144                divisor_bits_reversed.push(bit);
145            }
146            let zero = g.zeros(bit_type.clone())?;
147            let highest_one_bit_binary = g
148                .iterate(
149                    g_highest_one_bit,
150                    zero,
151                    g.create_vector(bit_type, divisor_bits_reversed)?,
152                )?
153                .tuple_get(1)?
154                .vector_to_array()?;
155            let highest_one_bit = single_bit_to_arithmetic(highest_one_bit_binary, sc)?;
156            let first_approximation_bits = put_in_bits(highest_one_bit)?;
157            let mut powers_of_two = vec![];
158            for i in 0..self.denominator_cap_2k {
159                powers_of_two.push(1u64 << i);
160            }
161            let powers_of_two_node = g.constant(
162                array_type(vec![self.denominator_cap_2k], sc),
163                Value::from_flattened_array(&powers_of_two, sc)?,
164            )?;
165            first_approximation_bits.dot(powers_of_two_node)?
166        };
167        // Now, we do Newton approximation for computing 1 / sqrt(x), where x = divisor / (2 ** cap).
168        // We use F(t) = 1 / (t ** 2) - d;
169        // The formula for the Newton method is x_{i + 1} = x_i * (3 / 2 - d / 2 * x_i * x_i).
170        let three_halves = constant_scalar(&g, 3 << (self.denominator_cap_2k - 1), sc)?;
171        for _ in 0..self.iterations {
172            let x = approximation;
173            // We have two terms: 3/2 and divisor * x * x / 2. Since x is multiplied by
174            // 2 ** denominator_cap_2k, we should normalize the second term before subtracting from the first one.
175            let ax2 = divisor.clone().multiply(x.clone())?.multiply(x.clone())?;
176            let ax2_norm = g.truncate(ax2, 1 << (self.denominator_cap_2k + 1))?;
177
178            let mult = three_halves.subtract(ax2_norm)?;
179            approximation = multiply_fixed_point(mult, x, self.denominator_cap_2k)?;
180        }
181        approximation.set_as_output()?;
182        g.finalize()?;
183        Ok(g)
184    }
185
186    fn get_name(&self) -> String {
187        format!(
188            "InverseSqrt(iterations={}, cap=2**{})",
189            self.iterations, self.denominator_cap_2k
190        )
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_types::ScalarType;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::inline::inline_common::DepthOptimizationLevel;
    use crate::inline::inline_ops::inline_operations;
    use crate::inline::inline_ops::InlineConfig;
    use crate::inline::inline_ops::InlineMode;
    use crate::mpc::mpc_compiler::prepare_for_mpc_evaluation;
    use crate::mpc::mpc_compiler::IOStatus;

    /// The InverseSqrt configuration shared by all tests (5 iterations, scale 2^10).
    fn inverse_sqrt_op() -> CustomOperation {
        CustomOperation::new(InverseSqrt {
            iterations: 5,
            denominator_cap_2k: 10,
        })
    }

    /// Asserts that `actual` is within 1 of floor(2^10 / sqrt(divisor)).
    fn assert_close(actual: u64, divisor: u64) {
        let expected = (1024.0 / (divisor as f64).powf(0.5)) as i64;
        let diff = (actual as i64 - expected).abs();
        assert!(
            diff <= 1,
            "divisor {}: got {}, expected {} +- 1",
            divisor,
            actual,
            expected
        );
    }

    /// Evaluates InverseSqrt on a single scalar, optionally seeding the Newton iterations.
    fn scalar_helper(
        divisor: u64,
        initial_approximation: Option<u64>,
        st: ScalarType,
    ) -> Result<u64> {
        let c = simple_context(|g| {
            let mut args = vec![g.input(scalar_type(st))?];
            if let Some(approx) = initial_approximation {
                args.push(constant_scalar(&g, approx, st)?);
            }
            g.custom_op(inverse_sqrt_op(), args)
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_scalar(divisor, st)?],
        )?;
        if st == UINT64 {
            result.to_u64(st)
        } else {
            let res = result.to_i64(st)?;
            assert!(res >= 0);
            Ok(res as u64)
        }
    }

    /// Evaluates InverseSqrt element-wise on a one-dimensional array.
    fn array_helper(divisor: Vec<u64>, st: ScalarType) -> Result<Vec<u64>> {
        let array_t = array_type(vec![divisor.len() as u64], st);
        let c = simple_context(|g| {
            let i = g.input(array_t.clone())?;
            g.custom_op(inverse_sqrt_op(), vec![i])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&divisor, st)?],
        )?;
        result.to_flattened_array_u64(array_t)
    }

    #[test]
    fn test_inverse_sqrt_scalar() {
        for i in [1, 2, 3, 123, 300, 500, 700] {
            for st in [UINT64, INT64] {
                assert_close(scalar_helper(i, None, st).unwrap(), i);
            }
        }
    }

    #[test]
    fn test_inverse_sqrt_array() {
        let arr = [23, 32, 57, 71, 183, 555];
        let div1 = array_helper(arr.to_vec(), UINT64).unwrap();
        let div2 = array_helper(arr.to_vec(), INT64).unwrap();
        for ((&d, &r1), &r2) in arr.iter().zip(&div1).zip(&div2) {
            assert_close(r1, d);
            assert_close(r2, d);
        }
    }

    #[test]
    fn test_inverse_sqrt_with_initial_guess() {
        for i in [1, 2, 3, 123, 300, 500, 700] {
            // Smallest power of two g with 4 * i * g^2 >= 2^20, so that i * g^2 satisfies
            // the documented precondition for cap = 10.
            let mut initial_guess = 1;
            while initial_guess * initial_guess * i * 4 < 1024 * 1024 {
                initial_guess *= 2;
            }
            for st in [UINT64, INT64] {
                assert_close(scalar_helper(i, Some(initial_guess), st).unwrap(), i);
            }
        }
    }

    #[test]
    fn test_inverse_sqrt_negative_values_nothing_bad() {
        for i in [-1, -100, -1000] {
            scalar_helper(i as u64, None, INT64).unwrap();
        }
    }

    #[test]
    fn test_inverse_sqrt_compiles_end2end() -> Result<()> {
        let c = simple_context(|g| {
            let i = g.input(scalar_type(INT64))?;
            g.custom_op(inverse_sqrt_op(), vec![i])
        })?;
        let inline_config = InlineConfig {
            default_mode: InlineMode::DepthOptimized(DepthOptimizationLevel::Default),
            ..Default::default()
        };
        let instantiated_context = run_instantiation_pass(c)?.get_context();
        let inlined_context = inline_operations(instantiated_context, inline_config.clone())?;
        let _unused = prepare_for_mpc_evaluation(
            inlined_context,
            vec![vec![IOStatus::Shared]],
            vec![vec![]],
            inline_config,
        )?;
        Ok(())
    }
}