ciphercore_base/ops/pwl/
approx_exponent.rs1use crate::custom_ops::CustomOperationBody;
3use crate::data_types::{Type, INT64};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph};
6
7use serde::{Deserialize, Serialize};
8
9use super::approx_pointwise::{create_approximation, PWLConfig};
10
11#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
38pub struct ApproxExponent {
39 pub precision: u64,
41}
42
43#[typetag::serde]
44impl CustomOperationBody for ApproxExponent {
45 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
46 if arguments_types.len() != 1 {
47 return Err(runtime_error!(
48 "Invalid number of arguments for ApproxExponent"
49 ));
50 }
51 let t = arguments_types[0].clone();
52 if !t.is_scalar() && !t.is_array() {
53 return Err(runtime_error!(
54 "Argument in ApproxExponent must be a scalar or an array"
55 ));
56 }
57 let sc = t.get_scalar_type();
58 if sc != INT64 {
59 return Err(runtime_error!(
60 "Argument in ApproxExponent must consist of INT64's"
61 ));
62 }
63 if self.precision > 30 || self.precision == 0 {
64 return Err(runtime_error!("`precision` should be in range [1, 30]."));
65 }
66
67 let g = context.create_graph()?;
68 let arg = g.input(t)?;
69 let result = create_approximation(
79 arg,
80 |x| x.exp(),
81 -10.0,
82 10.0,
83 self.precision,
84 PWLConfig {
85 log_buckets: 5,
86 flatten_left: true,
87 flatten_right: false,
88 },
89 )?;
90 result.set_as_output()?;
91 g.finalize()?;
92 Ok(g)
93 }
94
95 fn get_name(&self) -> String {
96 format!("ApproxExponent(scaling_factor=2**{})", self.precision)
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_types::array_type;
    use crate::data_types::scalar_type;
    use crate::data_types::INT32;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Maximum tolerated relative error between the approximation and `e^x`.
    const MAX_RELATIVE_ERROR: f64 = 0.05;

    /// Evaluates `ApproxExponent { precision }` on a single INT64 scalar.
    fn scalar_helper(arg: i64, precision: u64) -> Result<i64> {
        let c = simple_context(|g| {
            let i = g.input(scalar_type(INT64))?;
            g.custom_op(CustomOperation::new(ApproxExponent { precision }), vec![i])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_scalar(arg, INT64)?],
        )?;
        let res = result.to_i64(INT64)?;
        Ok(res)
    }

    /// Evaluates `ApproxExponent { precision }` element-wise on a 1-D INT64 array.
    fn array_helper(arg: &[i64], precision: u64) -> Result<Vec<i64>> {
        let array_t = array_type(vec![arg.len() as u64], INT64);
        let c = simple_context(|g| {
            let i = g.input(array_t.clone())?;
            g.custom_op(CustomOperation::new(ApproxExponent { precision }), vec![i])
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(arg, INT64)?],
        )?;
        result.to_flattened_array_i64(array_t)
    }

    /// Asserts that `actual` is within `MAX_RELATIVE_ERROR` of the exact
    /// fixed-point value `e^(arg / 2^precision) * 2^precision`.
    fn assert_close(arg: i64, actual: i64, precision: u64) {
        let scale = (1u64 << precision) as f64;
        let expected = ((arg as f64 / scale).exp() * scale) as i64;
        // `1.0 +` avoids dividing by zero when the expected value truncates to 0.
        let relative_error =
            (expected - actual).abs() as f64 / (1.0 + f64::max(expected as f64, actual as f64));
        assert!(
            relative_error <= MAX_RELATIVE_ERROR,
            "input {}: expected {}, got {} (relative error {})",
            arg,
            expected,
            actual,
            relative_error
        );
    }

    #[test]
    fn test_approx_exp_scalar() {
        let precision = 10;
        for i in [-10000, -1000, -100, -1, 0, 1, 100, 1000, 10000] {
            let actual = scalar_helper(i, precision).unwrap();
            assert_close(i, actual, precision);
        }
    }

    #[test]
    fn test_approx_exp_array() {
        let precision = 10;
        let arr: &[i64] = &[23, 32, 57, 1271, 183, 555, -23, -32, -57, -1271, -183, -555];
        let res = array_helper(arr, precision).unwrap();
        assert_eq!(res.len(), arr.len());
        for (&arg, &actual) in arr.iter().zip(&res) {
            assert_close(arg, actual, precision);
        }
    }

    #[test]
    fn test_approx_exp_rejects_invalid_precision() {
        assert!(scalar_helper(1, 0).is_err());
        assert!(scalar_helper(1, 31).is_err());
    }

    #[test]
    fn test_approx_exp_rejects_non_int64_input() {
        let instantiated = simple_context(|g| {
            let i = g.input(scalar_type(INT32))?;
            g.custom_op(
                CustomOperation::new(ApproxExponent { precision: 10 }),
                vec![i],
            )
        })
        .and_then(run_instantiation_pass);
        assert!(instantiated.is_err());
    }
}