1use crate::custom_ops::{CustomOperation, CustomOperationBody};
3use crate::data_types::{array_type, scalar_type, vector_type, Type, BIT, INT64};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, SliceElement};
6use crate::ops::utils::{pull_out_bits, put_in_bits};
7
8use serde::{Deserialize, Serialize};
9
10use super::comparisons::GreaterThanEqualTo;
11use super::utils::{constant_scalar, multiply_fixed_point, zeros_like};
12
13#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
44pub struct TaylorExponent {
45 pub taylor_terms: u64,
47 pub fixed_precision_points: u64,
49}
50
51#[typetag::serde]
52impl CustomOperationBody for TaylorExponent {
53 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
54 if arguments_types.len() != 1 {
55 return Err(runtime_error!(
56 "Invalid number of arguments for TaylorExponent"
57 ));
58 }
59 let t = arguments_types[0].clone();
60 if !t.is_scalar() && !t.is_array() {
61 return Err(runtime_error!(
62 "Argument in TaylorExponent must be a scalar or an array"
63 ));
64 }
65 let sc = t.get_scalar_type();
66 if sc != INT64 {
67 return Err(runtime_error!(
68 "Argument in TaylorExponent must consist of INT64's"
69 ));
70 }
71 if self.fixed_precision_points > 15 {
72 return Err(runtime_error!("fixed_precision_points is too large."));
73 }
74
75 let bit_type = if t.is_scalar() {
76 scalar_type(BIT)
77 } else {
78 array_type(t.get_shape(), BIT)
79 };
80
81 let g = context.create_graph()?;
82 let arg = g.input(t.clone())?;
83 let one_over_ln2_int = (((1 << self.fixed_precision_points) as f64) / 2.0_f64.ln()) as u64;
86 let one_over_ln2 = constant_scalar(&g, one_over_ln2_int, sc)?;
87 let x = multiply_fixed_point(arg, one_over_ln2, self.fixed_precision_points)?;
88
89 let binary_x = x.a2b()?;
90 let x_bits = pull_out_bits(binary_x.clone())?;
91 let msb = x_bits.get(vec![63])?;
92
93 let max_exp_bits = (31f64 - self.fixed_precision_points as f64).log2().ceil() as u64;
97 let one = g.ones(t)?;
98 let mut exp_integer = one.clone();
99 for i in self.fixed_precision_points..self.fixed_precision_points + max_exp_bits {
100 let bit = x_bits.get(vec![i])?;
101 let j = i - self.fixed_precision_points;
102 let p2 = constant_scalar(&g, 1_u64 << (1_u64 << j), sc)?;
103 let term = p2
105 .subtract(one.clone())?
106 .mixed_multiply(bit.clone())?
107 .add(one.clone())?;
108 exp_integer = exp_integer.multiply(term)?;
110 }
111
112 let exp_fractional = if self.fixed_precision_points == 0 {
115 one
116 } else {
117 let bits_after_point = x_bits.get_slice(vec![
118 SliceElement::SubArray(Some(0), Some(self.fixed_precision_points as i64), None),
119 SliceElement::Ellipsis,
120 ])?;
121 let mut bits_before_point_shape = x_bits.get_type()?.get_shape();
122 bits_before_point_shape[0] = 64 - self.fixed_precision_points;
123 let zero_bits_before_point = g.zeros(array_type(bits_before_point_shape, BIT))?;
124 let stacked_frac_bits = g.create_tuple(vec![
125 bits_after_point.array_to_vector()?,
126 zero_bits_before_point.array_to_vector()?,
127 ])?;
128 let stacked_type = vector_type(64, bit_type);
129 let x_frac = put_in_bits(stacked_frac_bits.reshape(stacked_type)?.vector_to_array()?)?
130 .b2a(sc)?;
131
132 let mut exp_fractional = zeros_like(x_frac.clone())?;
134 let mut coef = constant_scalar(&g, 1 << self.fixed_precision_points, sc)?;
135 let ln2_int = (2_f64.ln() * ((1 << self.fixed_precision_points) as f64)) as u64;
136 let ln2 = constant_scalar(&g, ln2_int, sc)?;
137 let y = multiply_fixed_point(x_frac, ln2, self.fixed_precision_points)?;
138 for i in 0..self.taylor_terms {
139 exp_fractional = exp_fractional.add(coef.clone())?;
140 if i < self.taylor_terms - 1 {
141 coef = coef.multiply(y.clone())?;
142 coef = coef.truncate((i as u128 + 1) << self.fixed_precision_points)?;
144 }
145 }
146 exp_fractional
147 };
148
149 let exp = exp_fractional.multiply(exp_integer)?;
152 let one_over_exp = exp.truncate(1u128 << (1u64 << max_exp_bits))?;
156 let upper_bound_for_inversion =
159 constant_scalar(&g, (-10) * (1 << self.fixed_precision_points), sc)?.a2b()?;
160 let inversion_overflow_bit = g.custom_op(
161 CustomOperation::new(GreaterThanEqualTo {
162 signed_comparison: true,
163 }),
164 vec![binary_x, upper_bound_for_inversion],
165 )?;
166 let mut result = exp.add(one_over_exp.subtract(exp.clone())?.mixed_multiply(msb)?)?;
167 result = result.mixed_multiply(inversion_overflow_bit)?;
168 result.set_as_output()?;
169 g.finalize()?;
170 Ok(g)
171 }
172
173 fn get_name(&self) -> String {
174 format!(
175 "TaylorExponent(taylor_terms={}, fixed_precision_denom=2**{})",
176 self.taylor_terms, self.fixed_precision_points
177 )
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Evaluates `TaylorExponent` (5 Taylor terms) on a single INT64 scalar.
    fn scalar_helper(arg: i64, precision: u64) -> Result<i64> {
        let c = simple_context(|g| {
            let i = g.input(scalar_type(INT64))?;
            g.custom_op(
                CustomOperation::new(TaylorExponent {
                    taylor_terms: 5,
                    fixed_precision_points: precision,
                }),
                vec![i],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_scalar(arg, INT64)?],
        )?;
        let res = result.to_i64(INT64)?;
        Ok(res)
    }

    /// Evaluates `TaylorExponent` (5 Taylor terms, 10 fractional bits) on a 1-D INT64 array.
    fn array_helper(arg: Vec<i64>) -> Result<Vec<i64>> {
        let array_t = array_type(vec![arg.len() as u64], INT64);
        let c = simple_context(|g| {
            let i = g.input(array_t.clone())?;
            g.custom_op(
                CustomOperation::new(TaylorExponent {
                    taylor_terms: 5,
                    fixed_precision_points: 10,
                }),
                vec![i],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&arg, INT64)?],
        )?;
        result.to_flattened_array_i64(array_t)
    }

    /// Reference `exp(x / 2^10) * 2^10`, truncated to an integer.
    fn expected_exp(x: i64) -> i64 {
        (((x as f64) / 1024.0).exp() * 1024.0) as i64
    }

    /// Asserts `actual` is within 1% relative error of `expected`.
    /// The `1.0 +` in the denominator keeps the bound meaningful near zero.
    fn assert_close(input: i64, expected: i64, actual: i64) {
        let relative_error = ((expected - actual).abs() as f64)
            / (1.0 + f64::max(expected as f64, actual as f64));
        assert!(
            relative_error <= 0.01,
            "exp mismatch for input {}: expected {}, got {} (relative error {})",
            input,
            expected,
            actual,
            relative_error
        );
    }

    #[test]
    fn test_exp_scalar() {
        for i in [-10000, -1000, -100, -1, 0, 1, 100, 1000, 10000] {
            let actual = scalar_helper(i, 10).unwrap();
            assert_close(i, expected_exp(i), actual);
        }
    }

    #[test]
    fn test_exp_array() {
        let arr = vec![23, 32, 57, 1271, 183, 555, -23, -32, -57, -1271, -183, -555];
        let res = array_helper(arr.clone()).unwrap();
        assert_eq!(res.len(), arr.len());
        for (&x, &actual) in arr.iter().zip(res.iter()) {
            assert_close(x, expected_exp(x), actual);
        }
    }

    #[test]
    fn test_exp_integer() {
        // With 0 fractional bits, 1/ln 2 truncates to 1, so the op computes 2^x exactly.
        for i in [0, 1, 2, 3, 5] {
            let expected: i64 = 1 << i;
            assert_eq!(scalar_helper(i, 0).unwrap(), expected, "input {}", i);
        }
    }
}