ciphercore_base/ops/
taylor_exponent.rs

1//! Exp(x) approximation relying on Taylor series expansion.
2use crate::custom_ops::{CustomOperation, CustomOperationBody};
3use crate::data_types::{array_type, scalar_type, vector_type, Type, BIT, INT64};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, SliceElement};
6use crate::ops::utils::{pull_out_bits, put_in_bits};
7
8use serde::{Deserialize, Serialize};
9
10use super::comparisons::GreaterThanEqualTo;
11use super::utils::{constant_scalar, multiply_fixed_point, zeros_like};
12
13/// A structure that defines the custom operation TaylorExponent that computes an approximate exp(x / (2 ** fixed_precision)) * (2 ** fixed_precision) using Taylor expansion.
14///
15/// Note that Taylor expansion correcly approximates exp(x) only for positive x, so we have to do A2B to get the MSB.
16/// Since we're doing A2B anyway, we can compute exp(integer_part(x)) and exp(fractional_part(x)) separately, computing the former directly from bits, and using Taylor expansion for the latter, getting better precision.
17/// See [the Keller-Sun paper, Algorithm 2](https://eprint.iacr.org/2022/933.pdf) for more details.
18///
19/// So far this operation supports only INT64 scalar type.
20///
21/// # Custom operation arguments
22///
23/// - Node containing a signed 64-bit array or scalar to compute the exponent
24///
25/// # Custom operation returns
26///
27/// New TaylorExponent node
28///
29/// # Example
30///
31/// ```
32/// # use ciphercore_base::graphs::create_context;
33/// # use ciphercore_base::data_types::{scalar_type, array_type, INT64};
34/// # use ciphercore_base::custom_ops::{CustomOperation};
35/// # use ciphercore_base::ops::taylor_exponent::TaylorExponent;
36/// let c = create_context().unwrap();
37/// let g = c.create_graph().unwrap();
38/// let t = array_type(vec![2, 3], INT64);
39/// let x = g.input(t.clone()).unwrap();
40/// let n2 = g.custom_op(CustomOperation::new(TaylorExponent {taylor_terms: 5, fixed_precision_points: 4}), vec![x]).unwrap();
41///
42// TODO: generalize to other types.
43#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
44pub struct TaylorExponent {
45    /// Number of terms from the Taylor expansion to consider (5 is typically enough).
46    pub taylor_terms: u64,
47    /// Assume that we're operating in fixed precision arithmetic with denominator 2 ** fixed_precision_points.
48    pub fixed_precision_points: u64,
49}
50
51#[typetag::serde]
52impl CustomOperationBody for TaylorExponent {
53    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
54        if arguments_types.len() != 1 {
55            return Err(runtime_error!(
56                "Invalid number of arguments for TaylorExponent"
57            ));
58        }
59        let t = arguments_types[0].clone();
60        if !t.is_scalar() && !t.is_array() {
61            return Err(runtime_error!(
62                "Argument in TaylorExponent must be a scalar or an array"
63            ));
64        }
65        let sc = t.get_scalar_type();
66        if sc != INT64 {
67            return Err(runtime_error!(
68                "Argument in TaylorExponent must consist of INT64's"
69            ));
70        }
71        if self.fixed_precision_points > 15 {
72            return Err(runtime_error!("fixed_precision_points is too large."));
73        }
74
75        let bit_type = if t.is_scalar() {
76            scalar_type(BIT)
77        } else {
78            array_type(t.get_shape(), BIT)
79        };
80
81        let g = context.create_graph()?;
82        let arg = g.input(t.clone())?;
83        // Below, we compute 2 ** (arg / ln(2)) rather than exp(arg).
84        // `x` is arg * ln(2).
85        let one_over_ln2_int = (((1 << self.fixed_precision_points) as f64) / 2.0_f64.ln()) as u64;
86        let one_over_ln2 = constant_scalar(&g, one_over_ln2_int, sc)?;
87        let x = multiply_fixed_point(arg, one_over_ln2, self.fixed_precision_points)?;
88
89        let binary_x = x.a2b()?;
90        let x_bits = pull_out_bits(binary_x.clone())?;
91        let msb = x_bits.get(vec![63])?;
92
93        // STAGE 1: compute exp(integer part of the argument).
94        // Note that if we're looking at the int part, we're computing the product of 2 ** (2 ** k) if k'th bit is 1.
95        // Since we work with 31-bit fixed-point arithmetic, the exponent is limited from above by 31 - fixed_precision_points.
96        let max_exp_bits = (31f64 - self.fixed_precision_points as f64).log2().ceil() as u64;
97        let one = g.ones(t)?;
98        let mut exp_integer = one.clone();
99        for i in self.fixed_precision_points..self.fixed_precision_points + max_exp_bits {
100            let bit = x_bits.get(vec![i])?;
101            let j = i - self.fixed_precision_points;
102            let p2 = constant_scalar(&g, 1_u64 << (1_u64 << j), sc)?;
103            // `term` is 1 if bit is not set, and 2 ** (2 ** j) otherwise.
104            let term = p2
105                .subtract(one.clone())?
106                .mixed_multiply(bit.clone())?
107                .add(one.clone())?;
108            // TODO: this can be optimized to be depth-3 rather than depth-5.
109            exp_integer = exp_integer.multiply(term)?;
110        }
111
112        // STAGE 2: compute exp(fractional part of the argument).
113        // Extract fractional part.
114        let exp_fractional = if self.fixed_precision_points == 0 {
115            one
116        } else {
117            let bits_after_point = x_bits.get_slice(vec![
118                SliceElement::SubArray(Some(0), Some(self.fixed_precision_points as i64), None),
119                SliceElement::Ellipsis,
120            ])?;
121            let mut bits_before_point_shape = x_bits.get_type()?.get_shape();
122            bits_before_point_shape[0] = 64 - self.fixed_precision_points;
123            let zero_bits_before_point = g.zeros(array_type(bits_before_point_shape, BIT))?;
124            let stacked_frac_bits = g.create_tuple(vec![
125                bits_after_point.array_to_vector()?,
126                zero_bits_before_point.array_to_vector()?,
127            ])?;
128            let stacked_type = vector_type(64, bit_type);
129            let x_frac = put_in_bits(stacked_frac_bits.reshape(stacked_type)?.vector_to_array()?)?
130                .b2a(sc)?;
131
132            // Now, we want 2 ** x = exp(x * ln(2)) = \sum_i (ln(2) * x) ** i / i!
133            let mut exp_fractional = zeros_like(x_frac.clone())?;
134            let mut coef = constant_scalar(&g, 1 << self.fixed_precision_points, sc)?;
135            let ln2_int = (2_f64.ln() * ((1 << self.fixed_precision_points) as f64)) as u64;
136            let ln2 = constant_scalar(&g, ln2_int, sc)?;
137            let y = multiply_fixed_point(x_frac, ln2, self.fixed_precision_points)?;
138            for i in 0..self.taylor_terms {
139                exp_fractional = exp_fractional.add(coef.clone())?;
140                if i < self.taylor_terms - 1 {
141                    coef = coef.multiply(y.clone())?;
142                    // We need to divide it by i + 1, and by 2 ** fixed_precision_points, so we combine the two.
143                    coef = coef.truncate((i as u128 + 1) << self.fixed_precision_points)?;
144                }
145            }
146            exp_fractional
147        };
148
149        // STAGE 3: combine the answers, and do exp(-x) if x was negative.
150        // No truncation here, since exp_integer is a normal int number, not a fixed-precision one.
151        let exp = exp_fractional.multiply(exp_integer)?;
152        // If x < 0, then it can be represented as x = -2^max_exp_bits + integer_bits + fractional_bits
153        // exp is equal to 2^(integer_bits + fractional_bits).
154        // Thus, truncation by 2^(2^max_exp_bits) changes the sign of the exponent.
155        let one_over_exp = exp.truncate(1u128 << (1u64 << max_exp_bits))?;
156        // Our maximal precision is 15, leading to minimum value around 3e-5. Exp(-10) is 4e-5
157        // If x is smaller than -10, return 0.
158        let upper_bound_for_inversion =
159            constant_scalar(&g, (-10) * (1 << self.fixed_precision_points), sc)?.a2b()?;
160        let inversion_overflow_bit = g.custom_op(
161            CustomOperation::new(GreaterThanEqualTo {
162                signed_comparison: true,
163            }),
164            vec![binary_x, upper_bound_for_inversion],
165        )?;
166        let mut result = exp.add(one_over_exp.subtract(exp.clone())?.mixed_multiply(msb)?)?;
167        result = result.mixed_multiply(inversion_overflow_bit)?;
168        result.set_as_output()?;
169        g.finalize()?;
170        Ok(g)
171    }
172
173    fn get_name(&self) -> String {
174        format!(
175            "TaylorExponent(taylor_terms={}, fixed_precision_denom=2**{})",
176            self.taylor_terms, self.fixed_precision_points
177        )
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Number of Taylor terms used by every test below.
    const TAYLOR_TERMS: u64 = 5;

    /// Evaluates TaylorExponent on a single INT64 scalar with the given fixed precision.
    fn scalar_helper(arg: i64, precision: u64) -> Result<i64> {
        let c = simple_context(|g| {
            let i = g.input(scalar_type(INT64))?;
            g.custom_op(
                CustomOperation::new(TaylorExponent {
                    taylor_terms: TAYLOR_TERMS,
                    fixed_precision_points: precision,
                }),
                vec![i],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_scalar(arg, INT64)?],
        )?;
        let res = result.to_i64(INT64)?;
        Ok(res)
    }

    /// Evaluates TaylorExponent (precision 10) on a 1-D INT64 array.
    fn array_helper(arg: Vec<i64>) -> Result<Vec<i64>> {
        let array_t = array_type(vec![arg.len() as u64], INT64);
        let c = simple_context(|g| {
            let i = g.input(array_t.clone())?;
            g.custom_op(
                CustomOperation::new(TaylorExponent {
                    taylor_terms: TAYLOR_TERMS,
                    fixed_precision_points: 10,
                }),
                vec![i],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&arg, INT64)?],
        )?;
        result.to_flattened_array_i64(array_t)
    }

    /// Floating-point reference for exp(arg / 2^precision) * 2^precision, truncated toward zero.
    fn reference_exp(arg: i64, precision: u64) -> i64 {
        let denom = (1u64 << precision) as f64;
        (((arg as f64) / denom).exp() * denom) as i64
    }

    /// Asserts that `actual` is within 1% relative error of `expected` (with +1 slack near zero).
    fn assert_close(expected: i64, actual: i64) {
        let relative_error =
            ((expected - actual).abs() as f64) / (1.0 + f64::max(expected as f64, actual as f64));
        assert!(
            relative_error <= 0.01,
            "expected {}, got {} (relative error {})",
            expected,
            actual,
            relative_error
        );
    }

    #[test]
    fn test_exp_scalar() {
        for arg in [-10000, -1000, -100, -1, 0, 1, 100, 1000, 10000] {
            assert_close(reference_exp(arg, 10), scalar_helper(arg, 10).unwrap());
        }
    }

    #[test]
    fn test_exp_array() {
        let arr = vec![23, 32, 57, 1271, 183, 555, -23, -32, -57, -1271, -183, -555];
        let res = array_helper(arr.clone()).unwrap();
        assert_eq!(res.len(), arr.len());
        for (&arg, &actual) in arr.iter().zip(res.iter()) {
            assert_close(reference_exp(arg, 10), actual);
        }
    }

    #[test]
    fn test_exp_integer() {
        for i in [0, 1, 2, 3, 5] {
            // With zero precision, ln(2) = 1, so it'll compute 2**i instead of exp(i).
            assert_eq!(scalar_helper(i, 0).unwrap(), 1_i64 << i);
        }
    }

    #[test]
    fn test_exp_of_zero_is_one() {
        // exp(0) = 1, i.e. exactly 2^precision in fixed point, for every supported precision.
        for precision in [0, 4, 10, 15] {
            assert_eq!(scalar_helper(0, precision).unwrap(), 1_i64 << precision);
        }
    }

    #[test]
    fn test_precision_too_large() {
        assert!(scalar_helper(0, 16).is_err());
    }
}