use crate::custom_ops::{CustomOperation, CustomOperationBody};
use crate::data_types::{array_type, scalar_type, vector_type, Type, BIT, INT64};
use crate::errors::Result;
use crate::graphs::{Context, Graph, SliceElement};
use crate::ops::utils::{pull_out_bits, put_in_bits};
use serde::{Deserialize, Serialize};
use super::comparisons::GreaterThanEqualTo;
use super::utils::{constant_scalar, multiply_fixed_point, zeros_like};
/// Custom operation computing the exponential function `e^x` of a fixed-point
/// `INT64` scalar or array.
///
/// Inputs and outputs are fixed-point numbers with denominator
/// `2^fixed_precision_points`; the fractional part of the exponent is
/// approximated with a truncated Taylor series.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct TaylorExponent {
    /// Number of Taylor-series terms (`y^0/0!`, `y^1/1!`, ...) summed for the
    /// fractional part of the exponent.
    pub taylor_terms: u64,
    /// Binary fixed-point precision: values are scaled by `2^fixed_precision_points`.
    /// `instantiate` rejects values above 15.
    pub fixed_precision_points: u64,
}
#[typetag::serde]
impl CustomOperationBody for TaylorExponent {
    /// Builds a graph computing `e^a` for a fixed-point `INT64` scalar or array `a`.
    ///
    /// Input and output are fixed-point numbers scaled by `2^fixed_precision_points`.
    /// The exponent is rewritten as `e^a = 2^x` with `x = a / ln 2`:
    /// * the integer part of `x` becomes a power of two by multiplying the
    ///   factors `2^(2^j)` selected by its bits;
    /// * the fractional part `f` of `x` contributes `2^f = e^(f ln 2)`, computed
    ///   with the first `taylor_terms` terms of the Taylor series.
    ///
    /// Only `max_exp_bits = ceil(log2(31 - fixed_precision_points))` integer bits
    /// of `x` are read, so the integer part is taken modulo `2^max_exp_bits`.
    /// Negative inputs rely on this wrap-around: their result is additionally
    /// divided by `2^(2^max_exp_bits)`. Inputs with `x < -10` yield 0.
    ///
    /// # Errors
    ///
    /// Returns an error if there is not exactly one argument, if the argument is
    /// not a scalar or an array of `INT64`, if `fixed_precision_points > 15`, or
    /// if `taylor_terms == 0` while `fixed_precision_points > 0`.
    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
        if arguments_types.len() != 1 {
            return Err(runtime_error!(
                "Invalid number of arguments for TaylorExponent"
            ));
        }
        let t = arguments_types[0].clone();
        if !t.is_scalar() && !t.is_array() {
            return Err(runtime_error!(
                "Argument in TaylorExponent must be a scalar or an array"
            ));
        }
        let sc = t.get_scalar_type();
        if sc != INT64 {
            return Err(runtime_error!(
                "Argument in TaylorExponent must consist of INT64's"
            ));
        }
        if self.fixed_precision_points > 15 {
            return Err(runtime_error!("fixed_precision_points is too large."));
        }
        // With fractional bits present, an empty Taylor series would leave the
        // fractional factor at zero and make every output zero. When
        // fixed_precision_points == 0 the series is never built, so any
        // taylor_terms value (including 0) stays accepted there.
        if self.taylor_terms == 0 && self.fixed_precision_points > 0 {
            return Err(runtime_error!(
                "taylor_terms must be positive when fixed_precision_points is positive."
            ));
        }
        let bit_type = if t.is_scalar() {
            scalar_type(BIT)
        } else {
            array_type(t.get_shape(), BIT)
        };
        let g = context.create_graph()?;
        let arg = g.input(t.clone())?;

        // x = arg / ln(2) in the same fixed-point scale, so that e^arg = 2^x.
        let one_over_ln2_int =
            (((1 << self.fixed_precision_points) as f64) / 2.0_f64.ln()) as u64;
        let one_over_ln2 = constant_scalar(&g, one_over_ln2_int, sc)?;
        let x = multiply_fixed_point(arg, one_over_ln2, self.fixed_precision_points)?;
        let binary_x = x.a2b()?;
        // Bits of x along the leading axis: x_bits[i] is bit i of x, bit 63 being
        // the sign bit and bits [0, fixed_precision_points) the fractional bits.
        let x_bits = pull_out_bits(binary_x.clone())?;
        let msb = x_bits.get(vec![63])?;
        // Number of integer bits of x that are read.
        let max_exp_bits = (31f64 - self.fixed_precision_points as f64).log2().ceil() as u64;
        let one = g.ones(t)?;

        // exp_integer = 2^(floor(x) mod 2^max_exp_bits): product over integer bits j
        // of ((2^(2^j) - 1) * bit_j + 1), i.e. 2^(2^j) when the bit is set, else 1.
        let mut exp_integer = one.clone();
        for i in self.fixed_precision_points..self.fixed_precision_points + max_exp_bits {
            let bit = x_bits.get(vec![i])?;
            let j = i - self.fixed_precision_points;
            let p2 = constant_scalar(&g, 1_u64 << (1_u64 << j), sc)?;
            let term = p2
                .subtract(one.clone())?
                .mixed_multiply(bit.clone())?
                .add(one.clone())?;
            exp_integer = exp_integer.multiply(term)?;
        }

        // exp_fractional = 2^(x - floor(x)) in fixed point.
        let exp_fractional = if self.fixed_precision_points == 0 {
            // No fractional bits: the factor is exactly 1 (scale 2^0).
            one
        } else {
            // Rebuild x - floor(x) as an integer: keep the low fixed_precision_points
            // bits of x and replace the remaining 64 - fixed_precision_points bits by 0.
            let bits_after_point = x_bits.get_slice(vec![
                SliceElement::SubArray(Some(0), Some(self.fixed_precision_points as i64), None),
                SliceElement::Ellipsis,
            ])?;
            let mut bits_before_point_shape = x_bits.get_type()?.get_shape();
            bits_before_point_shape[0] = 64 - self.fixed_precision_points;
            let zero_bits_before_point = g.zeros(array_type(bits_before_point_shape, BIT))?;
            let stacked_frac_bits = g.create_tuple(vec![
                bits_after_point.array_to_vector()?,
                zero_bits_before_point.array_to_vector()?,
            ])?;
            let stacked_type = vector_type(64, bit_type);
            let x_frac = put_in_bits(stacked_frac_bits.reshape(stacked_type)?.vector_to_array()?)?
                .b2a(sc)?;
            // Truncated Taylor series of e^y with y = (x - floor(x)) * ln(2):
            // sum of y^k / k! for k < taylor_terms. Each coefficient is the previous
            // one times y, divided by (k + 1) * 2^fixed_precision_points; the extra
            // power of two rescales the raw fixed-point product.
            let mut exp_fractional = zeros_like(x_frac.clone())?;
            let mut coef = constant_scalar(&g, 1 << self.fixed_precision_points, sc)?;
            let ln2_int = (2_f64.ln() * ((1 << self.fixed_precision_points) as f64)) as u64;
            let ln2 = constant_scalar(&g, ln2_int, sc)?;
            let y = multiply_fixed_point(x_frac, ln2, self.fixed_precision_points)?;
            for i in 0..self.taylor_terms {
                exp_fractional = exp_fractional.add(coef.clone())?;
                if i < self.taylor_terms - 1 {
                    coef = coef.multiply(y.clone())?;
                    coef = coef.truncate((i as u128 + 1) << self.fixed_precision_points)?;
                }
            }
            exp_fractional
        };

        // 2^x in fixed point, valid for x >= 0.
        let exp = exp_fractional.multiply(exp_integer)?;
        // For x < 0 the integer bits encode floor(x) + 2^max_exp_bits (as long as
        // floor(x) >= -2^max_exp_bits), so dividing by 2^(2^max_exp_bits) yields 2^x.
        let exp_if_negative = exp.truncate(1u128 << (1u64 << max_exp_bits))?;
        // Inputs with x < -10 are mapped to 0. Since fixed_precision_points <= 15
        // gives max_exp_bits >= 4, this also keeps floor(x) within the range where
        // the wrap-around above is valid.
        let lower_bound_for_inversion =
            constant_scalar(&g, (-10) * (1 << self.fixed_precision_points), sc)?.a2b()?;
        let inversion_overflow_bit = g.custom_op(
            CustomOperation::new(GreaterThanEqualTo {
                signed_comparison: true,
            }),
            vec![binary_x, lower_bound_for_inversion],
        )?;
        // Select exp for x >= 0 and exp_if_negative for x < 0, then zero out x < -10.
        let mut result = exp.add(exp_if_negative.subtract(exp.clone())?.mixed_multiply(msb)?)?;
        result = result.mixed_multiply(inversion_overflow_bit)?;
        result.set_as_output()?;
        g.finalize()?;
        Ok(g)
    }

    /// Returns a name encoding both parameters, e.g.
    /// `TaylorExponent(taylor_terms=5, fixed_precision_denom=2**10)`.
    fn get_name(&self) -> String {
        format!(
            "TaylorExponent(taylor_terms={}, fixed_precision_denom=2**{})",
            self.taylor_terms, self.fixed_precision_points
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::custom_ops::CustomOperation;
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Evaluates `TaylorExponent` (5 Taylor terms) on a single INT64 scalar with
    /// the given fixed-point precision.
    fn scalar_helper(arg: i64, precision: u64) -> Result<i64> {
        let c = simple_context(|g| {
            let i = g.input(scalar_type(INT64))?;
            g.custom_op(
                CustomOperation::new(TaylorExponent {
                    taylor_terms: 5,
                    fixed_precision_points: precision,
                }),
                vec![i],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_scalar(arg, INT64)?],
        )?;
        let res = result.to_i64(INT64)?;
        Ok(res)
    }

    /// Evaluates `TaylorExponent` (5 Taylor terms, precision 10) element-wise on
    /// a one-dimensional INT64 array.
    fn array_helper(arg: Vec<i64>) -> Result<Vec<i64>> {
        let array_t = array_type(vec![arg.len() as u64], INT64);
        let c = simple_context(|g| {
            let i = g.input(array_t.clone())?;
            g.custom_op(
                CustomOperation::new(TaylorExponent {
                    taylor_terms: 5,
                    fixed_precision_points: 10,
                }),
                vec![i],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let result = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&arg, INT64)?],
        )?;
        result.to_flattened_array_i64(array_t)
    }

    /// Relative error between two fixed-point results; the `1.0 +` keeps the
    /// ratio finite when both values are zero.
    fn relative_error(expected: i64, actual: i64) -> f64 {
        ((expected - actual).abs() as f64) / (1.0 + f64::max(expected as f64, actual as f64))
    }

    #[test]
    fn test_exp_scalar() {
        for i in [-10000, -1000, -100, -1, 0, 1, 100, 1000, 10000] {
            let expected = (((i as f64) / 1024.0).exp() * 1024.0) as i64;
            let actual = scalar_helper(i, 10).unwrap();
            assert!(
                relative_error(expected, actual) <= 0.01,
                "input {}: expected {}, got {}",
                i,
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_exp_array() {
        let arr = vec![23, 32, 57, 1271, 183, 555, -23, -32, -57, -1271, -183, -555];
        let res = array_helper(arr.clone()).unwrap();
        assert_eq!(res.len(), arr.len());
        for (&input, &actual) in arr.iter().zip(res.iter()) {
            let expected = (((input as f64) / 1024.0).exp() * 1024.0) as i64;
            assert!(
                relative_error(expected, actual) <= 0.01,
                "input {}: expected {}, got {}",
                input,
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_exp_integer() {
        for i in [0, 1, 2, 3, 5] {
            let expected: i64 = 1 << i;
            let actual = scalar_helper(i, 0).unwrap();
            assert_eq!(actual, expected, "exp of integer input {} is wrong", i);
        }
    }
}