// ciphercore_base/ops/pwl/approx_exponent.rs

//! Exp(x) piecewise-linear approximation.
2use crate::custom_ops::CustomOperationBody;
3use crate::data_types::{Type, INT64};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph};
6
7use serde::{Deserialize, Serialize};
8
9use super::approx_pointwise::{create_approximation, PWLConfig};
10
11/// A structure that defines the custom operation ApproxExponent that computes an approximate exp(x / (2 ** precision)) * (2 ** precision) using piecewise-linear approximation.
12///
13/// So far this operation supports only INT64 scalar type.
14///
15/// # Custom operation arguments
16///
17/// - Node containing a signed 64-bit array or scalar to compute the exponent
18///
19/// # Custom operation returns
20///
21/// New ApproxExponent node
22///
23/// # Example
24///
25/// ```
26/// # use ciphercore_base::graphs::create_context;
27/// # use ciphercore_base::data_types::{scalar_type, array_type, INT64};
28/// # use ciphercore_base::custom_ops::{CustomOperation};
29/// # use ciphercore_base::ops::pwl::approx_exponent::ApproxExponent;
30/// let c = create_context().unwrap();
31/// let g = c.create_graph().unwrap();
32/// let t = array_type(vec![3], INT64);
33/// let x = g.input(t.clone()).unwrap();
34/// let n = g.custom_op(CustomOperation::new(ApproxExponent {precision: 4}), vec![x]).unwrap();
35///
36// TODO: generalize to other types.
37#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
38pub struct ApproxExponent {
39    /// Assume that we're operating in fixed precision arithmetic with denominator 2 ** precision.
40    pub precision: u64,
41}
42
43#[typetag::serde]
44impl CustomOperationBody for ApproxExponent {
45    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
46        if arguments_types.len() != 1 {
47            return Err(runtime_error!(
48                "Invalid number of arguments for ApproxExponent"
49            ));
50        }
51        let t = arguments_types[0].clone();
52        if !t.is_scalar() && !t.is_array() {
53            return Err(runtime_error!(
54                "Argument in ApproxExponent must be a scalar or an array"
55            ));
56        }
57        let sc = t.get_scalar_type();
58        if sc != INT64 {
59            return Err(runtime_error!(
60                "Argument in ApproxExponent must consist of INT64's"
61            ));
62        }
63        if self.precision > 30 || self.precision == 0 {
64            return Err(runtime_error!("`precision` should be in range [1, 30]."));
65        }
66
67        let g = context.create_graph()?;
68        let arg = g.input(t)?;
69        // Choice of parameters:
70        // -- left/right: our typical use-case is precision=15, leading to minimum value around 3e-5. Exp(-10) is 4e-5, so right=-left=10 is a reasonable choice with our precision;
71        // -- log_buckets: we look at max relative difference to the real exponent. It looks as follows (note that this usually happens around -10, for higher values, it is more accurate):
72        //    log_buckets=4 => 36%,
73        //    log_buckets=5 => 21%,
74        //    log_buckets=6 => 22%,
75        //    log_buckets=7 => 22%.
76        // Note that it is not monotonic due to numerical issues for very low values. From this table, the best value is 5.
77        // -- flatten_left/flatten_right: exponent is flat on the left, so we replicate this in our approximation (we don't want to go to 0 and below).
78        let result = create_approximation(
79            arg,
80            |x| x.exp(),
81            -10.0,
82            10.0,
83            self.precision,
84            PWLConfig {
85                log_buckets: 5,
86                flatten_left: true,
87                flatten_right: false,
88            },
89        )?;
90        result.set_as_output()?;
91        g.finalize()?;
92        Ok(g)
93    }
94
95    fn get_name(&self) -> String {
96        format!("ApproxExponent(scaling_factor=2**{})", self.precision)
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_types::array_type;
    use crate::data_types::scalar_type;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Maximum tolerated relative deviation of the approximation from the exact exponent.
    const TOLERANCE: f64 = 0.05;

    /// Exact fixed-point reference: exp(x / 2**precision) * 2**precision, truncated to i64.
    fn reference_exp(x: i64, precision: u64) -> i64 {
        let scale = (1u64 << precision) as f64;
        ((x as f64 / scale).exp() * scale) as i64
    }

    /// Relative error; the `1.0 +` in the denominator keeps it finite when both values are 0.
    fn relative_error(expected: i64, actual: i64) -> f64 {
        ((expected - actual).abs() as f64) / (1.0 + f64::max(expected as f64, actual as f64))
    }

    /// Instantiates ApproxExponent on a scalar input and evaluates it on `arg`.
    fn scalar_helper(arg: i64, precision: u64) -> Result<i64> {
        let c = simple_context(|g| {
            let i = g.input(scalar_type(INT64))?;
            g.custom_op(CustomOperation::new(ApproxExponent { precision }), vec![i])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_scalar(arg, INT64)?],
        )?;
        result.to_i64(INT64)
    }

    /// Instantiates ApproxExponent on a 1-D array input and evaluates it on `arg`.
    fn array_helper(arg: Vec<i64>, precision: u64) -> Result<Vec<i64>> {
        let array_t = array_type(vec![arg.len() as u64], INT64);
        let c = simple_context(|g| {
            let i = g.input(array_t.clone())?;
            g.custom_op(CustomOperation::new(ApproxExponent { precision }), vec![i])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&arg, INT64)?],
        )?;
        result.to_flattened_array_i64(array_t)
    }

    #[test]
    fn test_approx_exp_scalar() {
        let precision = 10;
        for x in [-10000, -1000, -100, -1, 0, 1, 100, 1000, 10000] {
            let expected = reference_exp(x, precision);
            let actual = scalar_helper(x, precision).unwrap();
            let error = relative_error(expected, actual);
            assert!(
                error <= TOLERANCE,
                "x={}: expected {}, got {} (relative error {})",
                x,
                expected,
                actual,
                error
            );
        }
    }

    #[test]
    fn test_approx_exp_array() {
        let precision = 10;
        let inputs = vec![23, 32, 57, 1271, 183, 555, -23, -32, -57, -1271, -183, -555];
        let outputs = array_helper(inputs.clone(), precision).unwrap();
        // `zip` below would silently truncate on a length mismatch, so check it first.
        assert_eq!(outputs.len(), inputs.len());
        for (&x, &actual) in inputs.iter().zip(outputs.iter()) {
            let expected = reference_exp(x, precision);
            let error = relative_error(expected, actual);
            assert!(
                error <= TOLERANCE,
                "x={}: expected {}, got {} (relative error {})",
                x,
                expected,
                actual,
                error
            );
        }
    }
}