1use crate::data_types::{array_type, scalar_size_in_bits, scalar_type, vector_type, BIT};
2use crate::data_values::Value;
3use crate::errors::Result;
4use crate::graphs::{Node, SliceElement};
5use crate::ops::utils::pull_out_bits;
6
/// Configuration of a piecewise-linear (PWL) approximation built by [`create_approximation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PWLConfig {
    /// Base-2 logarithm of the number of equal-width buckets covering `[left, right]`.
    /// `create_approximation` requires it to be positive and smaller than the bit width of the
    /// input scalar type minus 2.
    pub log_buckets: u64,
    /// If `true`, the segment used for inputs left of the interval is constant instead of
    /// extrapolating the approximated function linearly.
    pub flatten_left: bool,
    /// If `true`, the segment used for inputs right of the interval is constant instead of
    /// extrapolating the approximated function linearly.
    pub flatten_right: bool,
}
12
13pub fn create_approximation<F>(
22 x: Node,
23 f: F,
24 left: f32,
25 right: f32,
26 precision: u64,
27 config: PWLConfig,
28) -> Result<Node>
29where
30 F: Fn(f32) -> f32,
31{
32 let st = x.get_type()?.get_scalar_type();
33 if !st.is_signed() {
34 return Err(runtime_error!("Only signed types are supported"));
35 }
36 if right <= left {
37 return Err(runtime_error!(
38 "Interval boundaries [left, right] should satisfy right > left"
39 ));
40 }
41 let log_buckets = config.log_buckets;
42 if log_buckets == 0 {
43 return Err(runtime_error!("log_buckets should be positive"));
44 }
45 let bit_len = scalar_size_in_bits(st);
46 if log_buckets >= bit_len - 2 {
47 return Err(runtime_error!("Too many approximation buckets"));
48 }
49
50 let g = x.get_graph();
51
52 let scale = 1 << log_buckets;
53 let mut xs = vec![];
54 let mut ys = vec![];
55 for i in -1..(scale + 2) {
56 let x = left + (right - left) * (i as f32) / (scale as f32);
57 let y = f(x);
58 xs.push((x * ((1 << precision) as f32)) as i64);
59 ys.push((y * ((1 << precision) as f32)) as i64);
60 }
61 let mut alphas = vec![];
63 let mut betas = vec![];
64 for i in 1..xs.len() {
65 let x0 = xs[i - 1];
66 let x1 = xs[i];
67 let y0 = ys[i - 1];
68 let y1 = ys[i];
69 let c = ((y1 - y0) << precision) / (x1 - x0);
70 alphas.push(c);
71 betas.push(y0 - ((c * x0) >> precision));
72 }
73 if config.flatten_left {
74 alphas[0] = 0;
75 betas[0] = ys[0];
76 }
77 if config.flatten_right {
78 let n = alphas.len() - 1;
79 alphas[n] = 0;
80 betas[n] = ys[ys.len() - 2];
81 }
82
83 let alphas_arr = g.constant(
84 array_type(vec![alphas.len() as u64], st),
85 Value::from_flattened_array(&alphas, st)?,
86 )?;
87 let betas_arr = g.constant(
88 array_type(vec![betas.len() as u64], st),
89 Value::from_flattened_array(&betas, st)?,
90 )?;
91 let mut x_shape = x.get_type()?.get_dimensions();
93 x_shape.push(1);
94 let expanded_x = x.reshape(array_type(x_shape, st))?;
95 let all_vals = expanded_x
96 .multiply(alphas_arr)?
97 .truncate(1 << precision)?
98 .add(betas_arr)?;
99 let mut perm: Vec<u64> = (0..all_vals.get_type()?.get_shape().len())
101 .map(|x| x as u64)
102 .collect();
103 perm.rotate_right(1);
104 let potential_vals = all_vals.permute_axes(perm)?;
105
106 let left_fp = (left * ((1 << precision) as f32)) as i64;
107 let left_node = g.constant(scalar_type(st), Value::from_scalar(left_fp, st)?)?;
108 let shifted_x = x.subtract(left_node)?;
110 let divisor = if log_buckets <= precision {
113 ((right - left) * ((1 << (precision - log_buckets)) as f32)) as u128
114 } else {
115 ((right - left) as u128) >> (log_buckets - precision)
116 };
117 let scaled_x = shifted_x.truncate(divisor)?;
118
119 let bits = pull_out_bits(scaled_x.a2b()?)?;
120 let msb = bits.get(vec![bit_len - 1])?;
122 let high_bits = bits.get_slice(vec![
123 SliceElement::SubArray(Some(log_buckets as i64), Some((bit_len - 1) as i64), None),
124 SliceElement::Ellipsis,
125 ])?;
126 let low_bits = bits.get_slice(vec![
127 SliceElement::SubArray(Some(0), Some(log_buckets as i64), None),
128 SliceElement::Ellipsis,
129 ])?;
130
131 let main_result = tree_retrieve(
133 low_bits,
134 potential_vals.get_slice(vec![
135 SliceElement::SubArray(Some(1), Some(alphas.len() as i64 - 1), None),
136 SliceElement::Ellipsis,
137 ])?,
138 )?;
139 let left_result = potential_vals.get(vec![0])?;
140 let right_result = potential_vals.get(vec![alphas.len() as u64 - 1])?;
141
142 let is_left = msb.clone();
145 let one = g.ones(scalar_type(BIT))?;
146 let is_right = any_bit_set(high_bits)?.multiply(msb.add(one.clone())?)?;
147 let is_main = is_left.add(one.clone())?.multiply(is_right.add(one)?)?;
148
149 let main_result_masked = main_result.mixed_multiply(is_main)?;
150 let left_result_masked = left_result.mixed_multiply(is_left)?;
151 let right_result_masked = right_result.mixed_multiply(is_right)?;
152
153 let result = main_result_masked
155 .add(left_result_masked)?
156 .add(right_result_masked)?;
157
158 Ok(result)
159}
160
161fn tree_retrieve(bits: Node, vals: Node) -> Result<Node> {
165 let n = bits.get_type()?.get_shape()[0];
166 let num_vals = vals.get_type()?.get_shape()[0];
167 if num_vals != (1 << n) {
168 return Err(runtime_error!(
169 "Logic error: number of tree leaves should be equal to 2 ** depth"
170 ));
171 }
172 let mut data = vals;
173 for bit_index in 0..n {
174 let bit = bits.get(vec![bit_index])?;
175 let len = 1 << (n - bit_index);
176 let even = data.get_slice(vec![
177 SliceElement::SubArray(Some(0), Some(len - 1), Some(2)),
178 SliceElement::Ellipsis,
179 ])?;
180 let odd = data.get_slice(vec![
181 SliceElement::SubArray(Some(1), Some(len), Some(2)),
182 SliceElement::Ellipsis,
183 ])?;
184 data = odd.subtract(even.clone())?.mixed_multiply(bit)?.add(even)?;
185 }
186 data.get(vec![0])
187}
188
189fn any_bit_set(bits: Node) -> Result<Node> {
192 let g = bits.get_graph();
193 let one = g.ones(scalar_type(BIT))?;
194 let mut unset = bits.add(one.clone())?;
195 while unset.get_type()?.get_shape()[0] > 1 {
196 let n = unset.get_type()?.get_shape()[0];
197 let k = n / 2;
198 let half1 = unset.get_slice(vec![
199 SliceElement::SubArray(Some(0), Some(k as i64), None),
200 SliceElement::Ellipsis,
201 ])?;
202 let half2 = unset.get_slice(vec![
203 SliceElement::SubArray(Some(k as i64), Some(2 * k as i64), None),
204 SliceElement::Ellipsis,
205 ])?;
206 let reduced = half1.multiply(half2)?;
207 unset = if n % 2 == 0 {
208 reduced
209 } else {
210 let last_element = unset.get(vec![n - 1])?;
211 let elements = reduced.array_to_vector()?;
212 let joined_elements = g.create_tuple(vec![elements, last_element.clone()])?;
213 joined_elements
214 .reshape(vector_type(k + 1, last_element.get_type()?))?
215 .vector_to_array()?
216 };
217 }
218 unset.get(vec![0])?.add(one)
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::INT64;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Fractional bits of the fixed-point encoding used throughout these tests.
    const PRECISION: u64 = 10;

    /// Encodes a real number as fixed point with `PRECISION` fractional bits.
    fn encode(v: f32) -> i64 {
        (v * ((1u64 << PRECISION) as f32)) as i64
    }

    /// Decodes a fixed-point number with `PRECISION` fractional bits.
    fn decode(v: i64) -> f32 {
        (v as f32) / ((1u64 << PRECISION) as f32)
    }

    /// Inputs -2.5, -2.4, ..., 2.4: index `i` maps to `(i - 25) * 0.1`.
    fn sample_points() -> Vec<f32> {
        (0..50).map(|i| ((i - 25) as f32) * 0.1).collect()
    }

    /// Configuration with 2^4 buckets, optionally flattened on both sides.
    fn config(flatten: bool) -> PWLConfig {
        PWLConfig {
            log_buckets: 4,
            flatten_left: flatten,
            flatten_right: flatten,
        }
    }

    /// Approximates `x^2` on [-2, 2] and evaluates the approximation on a single scalar.
    fn scalar_helper(arg: f32, conf: PWLConfig) -> Result<f32> {
        let context = simple_context(|g| {
            let input = g.input(scalar_type(INT64))?;
            create_approximation(input, |x| x * x, -2.0, 2.0, PRECISION, conf)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_scalar(encode(arg), INT64)?],
        )?;
        Ok(decode(output.to_i64(INT64)?))
    }

    /// Approximates `x^2` on [-2, 2] and evaluates it on an array of the given shape.
    fn array_helper(arg: Vec<f32>, shape: Vec<u64>, conf: PWLConfig) -> Result<Vec<f32>> {
        let array_t = array_type(shape, INT64);
        let context = simple_context(|g| {
            let input = g.input(array_t.clone())?;
            create_approximation(input, |x| x * x, -2.0, 2.0, PRECISION, conf)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let encoded: Vec<i64> = arg.into_iter().map(encode).collect();
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&encoded, INT64)?],
        )?;
        Ok(output
            .to_flattened_array_i64(array_t)?
            .into_iter()
            .map(decode)
            .collect())
    }

    /// Checks `ys` against `x^2`, with a looser bound outside of the interval [-2, 2].
    fn assert_close_to_square(xs: &[f32], ys: &[f32]) {
        for (&x, &y) in xs.iter().zip(ys) {
            let tolerance = if x < -2.0 || x > 2.0 { 0.2 } else { 0.02 };
            assert!((x * x - y).abs() < tolerance);
        }
    }

    #[test]
    fn pwl_simple_test() -> Result<()> {
        for (i, x) in sample_points().into_iter().enumerate() {
            let y = scalar_helper(x, config(false))?;
            // The first and last few samples lie outside [-2, 2] and are extrapolated.
            let tolerance = if i < 5 || i > 45 { 0.2 } else { 0.02 };
            assert!((x * x - y).abs() < tolerance);
        }
        Ok(())
    }

    #[test]
    fn pwl_flat_sides_test() -> Result<()> {
        let mut plateau: Option<f32> = None;
        for (i, x) in sample_points().into_iter().enumerate() {
            let y = scalar_helper(x, config(true))?;
            if (3..=45).contains(&i) {
                assert!((y - x * x).abs() < 0.1);
                plateau = None;
            } else {
                // Consecutive samples on a flattened side must share the same value.
                if let Some(expected) = plateau {
                    assert!((y - expected).abs() < 0.02);
                }
                plateau = Some(y);
            }
        }
        Ok(())
    }

    #[test]
    fn pwl_array_test() -> Result<()> {
        let xs = sample_points();
        let shape = vec![xs.len() as u64];
        let ys = array_helper(xs.clone(), shape, config(false))?;
        assert_close_to_square(&xs, &ys);
        Ok(())
    }

    #[test]
    fn pwl_array2d_test() -> Result<()> {
        let xs = sample_points();
        let shape = vec![(xs.len() / 10) as u64, 10];
        let ys = array_helper(xs.clone(), shape, config(false))?;
        assert_close_to_square(&xs, &ys);
        Ok(())
    }
}