use crate::data_types::{array_type, scalar_size_in_bits, scalar_type, vector_type, BIT};
use crate::data_values::Value;
use crate::errors::Result;
use crate::graphs::{Node, SliceElement};
use crate::ops::utils::pull_out_bits;
/// Configuration of a piecewise-linear approximation built by [`create_approximation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PWLConfig {
    /// Base-2 logarithm of the number of equal-width buckets the interval `[left, right]`
    /// is split into; must be positive.
    pub log_buckets: u64,
    /// When set, the segment used for inputs left of the interval is replaced by a constant
    /// instead of being linearly extrapolated.
    pub flatten_left: bool,
    /// When set, the segment used for inputs right of the interval is replaced by a constant
    /// instead of being linearly extrapolated.
    pub flatten_right: bool,
}
/// Builds a piecewise-linear (PWL) approximation of `f` applied elementwise to `x`.
///
/// `x` must have a signed scalar type and is interpreted as a fixed-point number with
/// `precision` fractional bits. The interval `[left, right]` is split into
/// `2^config.log_buckets` equal buckets of width `step = (right - left) / 2^log_buckets`.
/// `f` is sampled at every bucket boundary plus one extra point on each side
/// (`left - step` and `right + step`), and each pair of consecutive samples defines one
/// linear segment. Every element of `x` is evaluated on all segments. The element's bucket
/// index then selects which value is returned:
/// * a negative bucket index selects the outermost left segment (linearly extrapolated, or
///   the constant `f(left - step)` when `config.flatten_left` is set);
/// * a bucket index of at least `2^log_buckets` selects the outermost right segment
///   (linearly extrapolated, or the constant `f(right)` when `config.flatten_right` is set);
/// * otherwise the segment of the containing bucket is used.
///
/// # Errors
/// Returns an error if any of the following holds:
/// * the scalar type is unsigned;
/// * `right <= left`;
/// * `log_buckets` is zero or too large for the scalar type;
/// * `precision` is not smaller than the bit size of the scalar type;
/// * buckets are too narrow to be told apart at this precision;
/// * the fixed-point coefficients do not fit into 64 bits.
///
/// All of these checks run before any node is added to the graph. Errors from graph
/// construction are propagated.
pub fn create_approximation<F>(
    x: Node,
    f: F,
    left: f32,
    right: f32,
    precision: u64,
    config: PWLConfig,
) -> Result<Node>
where
    F: Fn(f32) -> f32,
{
    let st = x.get_type()?.get_scalar_type();
    if !st.is_signed() {
        return Err(runtime_error!("Only signed types are supported"));
    }
    if right <= left {
        return Err(runtime_error!(
            "Interval boundaries [left, right] should satisfy right > left"
        ));
    }
    let log_buckets = config.log_buckets;
    if log_buckets == 0 {
        return Err(runtime_error!("log_buckets should be positive"));
    }
    let bit_len = scalar_size_in_bits(st);
    if log_buckets >= bit_len - 2 {
        return Err(runtime_error!("Too many approximation buckets"));
    }
    // The fixed-point factor 2^precision must at least fit into the scalar type's width.
    if precision >= bit_len {
        return Err(runtime_error!("Precision is too large for the input type"));
    }

    // Divisor mapping the fixed-point offset (x - left) to a bucket index:
    // (x - left) * 2^precision / divisor == (x - left) * 2^log_buckets / (right - left),
    // up to rounding of the divisor to an integer.
    let divisor = if log_buckets <= precision {
        // `1u64`, not an untyped literal: the latter is `i32` and overflows for large shifts.
        ((right - left) * ((1u64 << (precision - log_buckets)) as f32)) as u128
    } else {
        ((right - left) as u128) >> (log_buckets - precision)
    };
    if divisor == 0 {
        return Err(runtime_error!(
            "Approximation buckets are too narrow for the given precision"
        ));
    }

    // Real -> fixed-point conversion. Computed from a `u64` shift because an untyped
    // `1 << precision` defaults to `i32` and breaks for `precision >= 31`.
    let fp_scale = (1u64 << precision) as f32;
    let to_fixed = |v: f32| (v * fp_scale) as i64;

    // Sample f on the bucket grid extended by one point on each side of [left, right]:
    // grid indices -1..=num_buckets+1, i.e. num_buckets + 3 points.
    let num_buckets: i64 = 1 << log_buckets;
    let (xs, ys): (Vec<i64>, Vec<i64>) = (-1..=num_buckets + 1)
        .map(|i| {
            let sample_x = left + (right - left) * (i as f32) / (num_buckets as f32);
            (to_fixed(sample_x), to_fixed(f(sample_x)))
        })
        .unzip();

    // Segment k joins samples k and k+1; segment 0 lies left of `left` and the last one
    // lies right of `right`.
    let (mut alphas, mut betas) = fit_pwl_segments(&xs, &ys, precision)?;
    if config.flatten_left {
        alphas[0] = 0;
        betas[0] = ys[0];
    }
    if config.flatten_right {
        let last = alphas.len() - 1;
        alphas[last] = 0;
        // ys[ys.len() - 2] is the sample taken at `right`.
        betas[last] = ys[ys.len() - 2];
    }
    let num_segments = alphas.len() as u64;

    let g = x.get_graph();
    let alphas_arr = g.constant(
        array_type(vec![num_segments], st),
        Value::from_flattened_array(&alphas, st)?,
    )?;
    let betas_arr = g.constant(
        array_type(vec![num_segments], st),
        Value::from_flattened_array(&betas, st)?,
    )?;

    // Evaluate every segment for every element: a trailing axis of size 1 on x broadcasts
    // against the per-segment coefficient arrays.
    let mut x_shape = x.get_type()?.get_dimensions();
    x_shape.push(1);
    let expanded_x = x.reshape(array_type(x_shape, st))?;
    let all_vals = expanded_x
        .multiply(alphas_arr)?
        .truncate(1 << precision)?
        .add(betas_arr)?;
    // Move the segment axis (last) to the front so individual segments can be indexed.
    let mut perm: Vec<u64> = (0..all_vals.get_type()?.get_shape().len())
        .map(|axis| axis as u64)
        .collect();
    perm.rotate_right(1);
    let potential_vals = all_vals.permute_axes(perm)?;

    // Compute the bucket index of every element and decompose it into bits.
    let left_fp = to_fixed(left);
    let left_node = g.constant(scalar_type(st), Value::from_scalar(left_fp, st)?)?;
    let shifted_x = x.subtract(left_node)?;
    let scaled_x = shifted_x.truncate(divisor)?;
    let bits = pull_out_bits(scaled_x.a2b()?)?;
    // Sign bit: set when the bucket index is negative.
    let msb = bits.get(vec![bit_len - 1])?;
    // Bits above the bucket range (excluding the sign): any set means index >= 2^log_buckets.
    let high_bits = bits.get_slice(vec![
        SliceElement::SubArray(Some(log_buckets as i64), Some((bit_len - 1) as i64), None),
        SliceElement::Ellipsis,
    ])?;
    let low_bits = bits.get_slice(vec![
        SliceElement::SubArray(Some(0), Some(log_buckets as i64), None),
        SliceElement::Ellipsis,
    ])?;

    // In-range result: pick one of the 2^log_buckets inner segments by the low bits.
    let main_result = tree_retrieve(
        low_bits,
        potential_vals.get_slice(vec![
            SliceElement::SubArray(Some(1), Some(num_segments as i64 - 1), None),
            SliceElement::Ellipsis,
        ])?,
    )?;
    let left_result = potential_vals.get(vec![0])?;
    let right_result = potential_vals.get(vec![num_segments - 1])?;

    // Exactly one of is_left / is_right / is_main is 1 for every element.
    let is_left = msb.clone();
    let one = g.ones(scalar_type(BIT))?;
    let is_right = any_bit_set(high_bits)?.multiply(msb.add(one.clone())?)?;
    let is_main = is_left.add(one.clone())?.multiply(is_right.add(one)?)?;
    let main_result_masked = main_result.mixed_multiply(is_main)?;
    let left_result_masked = left_result.mixed_multiply(is_left)?;
    let right_result_masked = right_result.mixed_multiply(is_right)?;
    main_result_masked
        .add(left_result_masked)?
        .add(right_result_masked)
}

/// Fits a fixed-point segment `y = alpha * x / 2^precision + beta` through every pair of
/// consecutive samples `(xs[k], ys[k])`, `(xs[k + 1], ys[k + 1])`.
///
/// Slopes are rounded toward zero (integer division), and intercepts use an arithmetic
/// right shift. Arithmetic is done in `i128`, so coincident sample points or out-of-range
/// coefficients produce an error instead of a panic.
fn fit_pwl_segments(xs: &[i64], ys: &[i64], precision: u64) -> Result<(Vec<i64>, Vec<i64>)> {
    let fp_one = 1i128 << precision;
    let narrow = |v: i128| -> Result<i64> {
        if v < i128::from(i64::MIN) || v > i128::from(i64::MAX) {
            Err(runtime_error!(
                "Approximation coefficients do not fit into 64 bits"
            ))
        } else {
            Ok(v as i64)
        }
    };
    let segments = xs.len().saturating_sub(1);
    let mut alphas = Vec::with_capacity(segments);
    let mut betas = Vec::with_capacity(segments);
    for (x_pair, y_pair) in xs.windows(2).zip(ys.windows(2)) {
        let (x0, x1) = (i128::from(x_pair[0]), i128::from(x_pair[1]));
        let (y0, y1) = (i128::from(y_pair[0]), i128::from(y_pair[1]));
        if x1 == x0 {
            // Adjacent grid points collapsed to the same fixed-point value.
            return Err(runtime_error!(
                "Approximation buckets are too narrow for the given precision"
            ));
        }
        let rise = (y1 - y0).checked_mul(fp_one).ok_or_else(|| {
            runtime_error!("Approximation coefficients do not fit into 64 bits")
        })?;
        let alpha = narrow(rise / (x1 - x0))?;
        // Both factors fit into i64, so their product cannot overflow i128.
        let beta = narrow(y0 - ((i128::from(alpha) * x0) >> precision))?;
        alphas.push(alpha);
        betas.push(beta);
    }
    Ok((alphas, betas))
}
/// Selects `vals[index]`, where `index` is encoded by `bits` with the least significant
/// bit first.
///
/// `bits` is indexed along axis 0 by bit position, and `vals` must hold exactly
/// `2^depth` entries along axis 0, where `depth` is the size of `bits` along axis 0.
/// Each round halves the candidates without branching. For every neighbouring pair
/// `(a, b)` it keeps `a + bit * (b - a)`.
///
/// # Errors
/// Returns an error if the number of entries in `vals` is not `2^depth`; graph
/// construction errors are propagated.
fn tree_retrieve(bits: Node, vals: Node) -> Result<Node> {
    let depth = bits.get_type()?.get_shape()[0];
    let leaves = vals.get_type()?.get_shape()[0];
    if leaves != (1 << depth) {
        return Err(runtime_error!(
            "Logic error: number of tree leaves should be equal to 2 ** depth"
        ));
    }
    let root = (0..depth).try_fold(vals, |level, bit_index| -> Result<Node> {
        let selector = bits.get(vec![bit_index])?;
        // Number of candidates still alive at this level.
        let width: i64 = 1 << (depth - bit_index);
        // Candidates chosen when the current bit is 0 (even positions) ...
        let on_zero = level.get_slice(vec![
            SliceElement::SubArray(Some(0), Some(width - 1), Some(2)),
            SliceElement::Ellipsis,
        ])?;
        // ... and when it is 1 (odd positions).
        let on_one = level.get_slice(vec![
            SliceElement::SubArray(Some(1), Some(width), Some(2)),
            SliceElement::Ellipsis,
        ])?;
        on_one
            .subtract(on_zero.clone())?
            .mixed_multiply(selector)?
            .add(on_zero)
    })?;
    root.get(vec![0])
}
/// Returns a single bit that is 1 iff at least one entry of `bits` along axis 0 is 1.
///
/// Relies on mod-2 bit arithmetic (`add` with one acts as NOT, `multiply` as AND). It
/// computes `NOT(AND_i NOT(bits[i]))` with a balanced reduction: each round ANDs the first
/// half of the remaining entries with the second half, carrying an unpaired last entry
/// over to the next round.
fn any_bit_set(bits: Node) -> Result<Node> {
    let graph = bits.get_graph();
    let one = graph.ones(scalar_type(BIT))?;
    // acc[i] == 1 iff every input bit folded into position i is 0.
    let mut acc = bits.add(one.clone())?;
    loop {
        let len = acc.get_type()?.get_shape()[0];
        if len <= 1 {
            break;
        }
        let half = len / 2;
        let front = acc.get_slice(vec![
            SliceElement::SubArray(Some(0), Some(half as i64), None),
            SliceElement::Ellipsis,
        ])?;
        let back = acc.get_slice(vec![
            SliceElement::SubArray(Some(half as i64), Some((2 * half) as i64), None),
            SliceElement::Ellipsis,
        ])?;
        let paired = front.multiply(back)?;
        acc = if len % 2 == 1 {
            // Odd length: the last entry has no partner, so append it to the products.
            let leftover = acc.get(vec![len - 1])?;
            let products = paired.array_to_vector()?;
            graph
                .create_tuple(vec![products, leftover.clone()])?
                .reshape(vector_type(half + 1, leftover.get_type()?))?
                .vector_to_array()?
        } else {
            paired
        };
    }
    acc.get(vec![0])?.add(one)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::INT64;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Number of fractional bits of the fixed-point encoding used by every test.
    const PRECISION: u64 = 10;
    /// `2^PRECISION` as a float, for converting between real and fixed-point values.
    const SCALE: f32 = 1024.0;

    fn to_fixed(value: f32) -> i64 {
        (value * SCALE) as i64
    }

    fn from_fixed(value: i64) -> f32 {
        (value as f32) / SCALE
    }

    /// 50 evenly spaced inputs from -2.5 to 2.4 in steps of 0.1.
    fn sample_points() -> Vec<f32> {
        (0..50).map(|i| ((i - 25) as f32) * 0.1).collect()
    }

    /// 16 buckets, with both tails either extrapolated or flattened.
    fn config(flatten: bool) -> PWLConfig {
        PWLConfig {
            log_buckets: 4,
            flatten_left: flatten,
            flatten_right: flatten,
        }
    }

    /// Approximates `x^2` on [-2, 2] for a single INT64 input and decodes the result.
    fn scalar_helper(arg: f32, conf: PWLConfig) -> Result<f32> {
        let context = simple_context(|g| {
            let input = g.input(scalar_type(INT64))?;
            create_approximation(input, |v| v * v, -2.0, 2.0, PRECISION, conf)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_scalar(to_fixed(arg), INT64)?],
        )?;
        Ok(from_fixed(output.to_i64(INT64)?))
    }

    /// Like `scalar_helper`, but for an INT64 array of the given shape (row-major values).
    fn array_helper(arg: Vec<f32>, shape: Vec<u64>, conf: PWLConfig) -> Result<Vec<f32>> {
        let array_t = array_type(shape, INT64);
        let context = simple_context(|g| {
            let input = g.input(array_t.clone())?;
            create_approximation(input, |v| v * v, -2.0, 2.0, PRECISION, conf)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let encoded: Vec<i64> = arg.into_iter().map(to_fixed).collect();
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&encoded, INT64)?],
        )?;
        Ok(output
            .to_flattened_array_i64(array_t)?
            .into_iter()
            .map(from_fixed)
            .collect())
    }

    /// Asserts `ys[i]` is close to `xs[i]^2`, with a looser bound outside [-2, 2].
    fn check_squares(xs: &[f32], ys: &[f32]) {
        for (&x, &y) in xs.iter().zip(ys) {
            let tolerance = if (-2.0..=2.0).contains(&x) { 0.02 } else { 0.2 };
            assert!((x * x - y).abs() < tolerance);
        }
    }

    #[test]
    fn pwl_simple_test() -> Result<()> {
        for (i, x) in sample_points().into_iter().enumerate() {
            let y = scalar_helper(x, config(false))?;
            let tolerance = if (5..=45).contains(&i) { 0.02 } else { 0.2 };
            assert!((x * x - y).abs() < tolerance);
        }
        Ok(())
    }

    #[test]
    fn pwl_flat_sides_test() -> Result<()> {
        // Middle samples must track x^2; within each flattened tail, consecutive samples
        // must agree with each other.
        let mut prev: Option<f32> = None;
        for (i, x) in sample_points().into_iter().enumerate() {
            let y = scalar_helper(x, config(true))?;
            if (3..=45).contains(&i) {
                assert!((y - x * x).abs() < 0.1);
                prev = None;
                continue;
            }
            if let Some(expected) = prev {
                assert!((y - expected).abs() < 0.02);
            }
            prev = Some(y);
        }
        Ok(())
    }

    #[test]
    fn pwl_array_test() -> Result<()> {
        let xs = sample_points();
        let ys = array_helper(xs.clone(), vec![xs.len() as u64], config(false))?;
        check_squares(&xs, &ys);
        Ok(())
    }

    #[test]
    fn pwl_array2d_test() -> Result<()> {
        let xs = sample_points();
        let ys = array_helper(
            xs.clone(),
            vec![(xs.len() / 10) as u64, 10],
            config(false),
        )?;
        check_squares(&xs, &ys);
        Ok(())
    }
}