// ciphercore_base/ops/pwl/approx_pointwise.rs

1use crate::data_types::{array_type, scalar_size_in_bits, scalar_type, vector_type, BIT};
2use crate::data_values::Value;
3use crate::errors::Result;
4use crate::graphs::{Node, SliceElement};
5use crate::ops::utils::pull_out_bits;
6
/// Configuration of a piecewise-linear approximation built by [`create_approximation`].
///
/// The type is `Copy`, so one configuration can be reused for several approximations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PWLConfig {
    /// Base-2 logarithm of the number of equal-width buckets covering `[left, right]`.
    pub log_buckets: u64,
    /// If `true`, the leftmost (outside) piece is a constant; otherwise it is linear.
    pub flatten_left: bool,
    /// If `true`, the rightmost (outside) piece is a constant; otherwise it is linear.
    pub flatten_right: bool,
}
12
13/// This helper approximates any given function with a piecewise-linear approximation.
14/// It is assumed that we're operating in fixed-precision arithmetic with `precision` bits after point.
15/// We're approximating the function on the segment [left, right], using 2 ** `config.log_buckets` equally-distanced control points.
16/// Note that the function has to be defined everywhere, even outside the segment.
17/// Behavior outside of the segment is determined by `config.flatten_left` and `config.flatten_right` (they control whether it will be approximated with a constant or linearly).
18///
19/// Some notes on implementation. It is optimized for performance, vectorizing operations whenever possible. It uses `Rounds(A2B) + max((config.log_buckets + 1) * Rounds(MixedMultiply), 6) + 2` network rounds.
20/// It is possible to remove the `config.log_buckets` at expense of more compute, but it is probably not worth it. It is also possible to slightly improve compute at expense of 7 more rounds.
21pub fn create_approximation<F>(
22    x: Node,
23    f: F,
24    left: f32,
25    right: f32,
26    precision: u64,
27    config: PWLConfig,
28) -> Result<Node>
29where
30    F: Fn(f32) -> f32,
31{
32    let st = x.get_type()?.get_scalar_type();
33    if !st.is_signed() {
34        return Err(runtime_error!("Only signed types are supported"));
35    }
36    if right <= left {
37        return Err(runtime_error!(
38            "Interval boundaries [left, right] should satisfy right > left"
39        ));
40    }
41    let log_buckets = config.log_buckets;
42    if log_buckets == 0 {
43        return Err(runtime_error!("log_buckets should be positive"));
44    }
45    let bit_len = scalar_size_in_bits(st);
46    if log_buckets >= bit_len - 2 {
47        return Err(runtime_error!("Too many approximation buckets"));
48    }
49
50    let g = x.get_graph();
51
52    let scale = 1 << log_buckets;
53    let mut xs = vec![];
54    let mut ys = vec![];
55    for i in -1..(scale + 2) {
56        let x = left + (right - left) * (i as f32) / (scale as f32);
57        let y = f(x);
58        xs.push((x * ((1 << precision) as f32)) as i64);
59        ys.push((y * ((1 << precision) as f32)) as i64);
60    }
61    // For each segment, the approximation is f_i(x) = alpha[i] * x + beta[i].
62    let mut alphas = vec![];
63    let mut betas = vec![];
64    for i in 1..xs.len() {
65        let x0 = xs[i - 1];
66        let x1 = xs[i];
67        let y0 = ys[i - 1];
68        let y1 = ys[i];
69        let c = ((y1 - y0) << precision) / (x1 - x0);
70        alphas.push(c);
71        betas.push(y0 - ((c * x0) >> precision));
72    }
73    if config.flatten_left {
74        alphas[0] = 0;
75        betas[0] = ys[0];
76    }
77    if config.flatten_right {
78        let n = alphas.len() - 1;
79        alphas[n] = 0;
80        betas[n] = ys[ys.len() - 2];
81    }
82
83    let alphas_arr = g.constant(
84        array_type(vec![alphas.len() as u64], st),
85        Value::from_flattened_array(&alphas, st)?,
86    )?;
87    let betas_arr = g.constant(
88        array_type(vec![betas.len() as u64], st),
89        Value::from_flattened_array(&betas, st)?,
90    )?;
91    // We compute potential values for all segments with broadcasting.
92    let mut x_shape = x.get_type()?.get_dimensions();
93    x_shape.push(1);
94    let expanded_x = x.reshape(array_type(x_shape, st))?;
95    let all_vals = expanded_x
96        .multiply(alphas_arr)?
97        .truncate(1 << precision)?
98        .add(betas_arr)?;
99    // Bring the dimension with segments to the front.
100    let mut perm: Vec<u64> = (0..all_vals.get_type()?.get_shape().len())
101        .map(|x| x as u64)
102        .collect();
103    perm.rotate_right(1);
104    let potential_vals = all_vals.permute_axes(perm)?;
105
106    let left_fp = (left * ((1 << precision) as f32)) as i64;
107    let left_node = g.constant(scalar_type(st), Value::from_scalar(left_fp, st)?)?;
108    // We want to linearly transform `x` so that `left` becomes 0, and `right` becomes `scale` (as integer).
109    let shifted_x = x.subtract(left_node)?;
110    // We need to find `divisor` so that `(right - left) / divisor` = scale.
111    // So divisor is `(right - left) / scale` (with the fixed-precision multiplier).
112    let divisor = if log_buckets <= precision {
113        ((right - left) * ((1 << (precision - log_buckets)) as f32)) as u128
114    } else {
115        ((right - left) as u128) >> (log_buckets - precision)
116    };
117    let scaled_x = shifted_x.truncate(divisor)?;
118
119    let bits = pull_out_bits(scaled_x.a2b()?)?;
120    // Now we're interested in the last `log_buckets` of clipped x, so we have to check msb and higher bits to see if we're outside of the interval.
121    let msb = bits.get(vec![bit_len - 1])?;
122    let high_bits = bits.get_slice(vec![
123        SliceElement::SubArray(Some(log_buckets as i64), Some((bit_len - 1) as i64), None),
124        SliceElement::Ellipsis,
125    ])?;
126    let low_bits = bits.get_slice(vec![
127        SliceElement::SubArray(Some(0), Some(log_buckets as i64), None),
128        SliceElement::Ellipsis,
129    ])?;
130
131    // NOTE: we could re-arrange potential_vals in the form potential_vals'[i] = potential_vals[reversed_bits(i)], which would make vectorization more efficient (but the code will be more arcane).
132    let main_result = tree_retrieve(
133        low_bits,
134        potential_vals.get_slice(vec![
135            SliceElement::SubArray(Some(1), Some(alphas.len() as i64 - 1), None),
136            SliceElement::Ellipsis,
137        ])?,
138    )?;
139    let left_result = potential_vals.get(vec![0])?;
140    let right_result = potential_vals.get(vec![alphas.len() as u64 - 1])?;
141
142    // Let's handle left/right cases. Left is easy - `x - r` should be negative, so we look at MSB.
143    // Right is a bit more tricky. We have to check that any bit higher than `log_buckets` is set and MSB is not set.
144    let is_left = msb.clone();
145    let one = g.ones(scalar_type(BIT))?;
146    let is_right = any_bit_set(high_bits)?.multiply(msb.add(one.clone())?)?;
147    let is_main = is_left.add(one.clone())?.multiply(is_right.add(one)?)?;
148
149    let main_result_masked = main_result.mixed_multiply(is_main)?;
150    let left_result_masked = left_result.mixed_multiply(is_left)?;
151    let right_result_masked = right_result.mixed_multiply(is_right)?;
152
153    // TODO: can left/right cases be handled more efficiently?
154    let result = main_result_masked
155        .add(left_result_masked)?
156        .add(right_result_masked)?;
157
158    Ok(result)
159}
160
161/// Note that the following function essentially accesses an item in array `val` by an index determined by the bits in `bits`.
162/// Rather than doing top-down dfs like we do in plaintext, this is doing down-up updates. We'll walk the whole tree
163/// either way, but down-up is more efficient in terms of graph size and vectorization.
164fn tree_retrieve(bits: Node, vals: Node) -> Result<Node> {
165    let n = bits.get_type()?.get_shape()[0];
166    let num_vals = vals.get_type()?.get_shape()[0];
167    if num_vals != (1 << n) {
168        return Err(runtime_error!(
169            "Logic error: number of tree leaves should be equal to 2 ** depth"
170        ));
171    }
172    let mut data = vals;
173    for bit_index in 0..n {
174        let bit = bits.get(vec![bit_index])?;
175        let len = 1 << (n - bit_index);
176        let even = data.get_slice(vec![
177            SliceElement::SubArray(Some(0), Some(len - 1), Some(2)),
178            SliceElement::Ellipsis,
179        ])?;
180        let odd = data.get_slice(vec![
181            SliceElement::SubArray(Some(1), Some(len), Some(2)),
182            SliceElement::Ellipsis,
183        ])?;
184        data = odd.subtract(even.clone())?.mixed_multiply(bit)?.add(even)?;
185    }
186    data.get(vec![0])
187}
188
189/// Efficient check whether any of the bits are set.
190/// Utilizes the fact that AND is not only associative but also commutative, making the graph more vectorized.
191fn any_bit_set(bits: Node) -> Result<Node> {
192    let g = bits.get_graph();
193    let one = g.ones(scalar_type(BIT))?;
194    let mut unset = bits.add(one.clone())?;
195    while unset.get_type()?.get_shape()[0] > 1 {
196        let n = unset.get_type()?.get_shape()[0];
197        let k = n / 2;
198        let half1 = unset.get_slice(vec![
199            SliceElement::SubArray(Some(0), Some(k as i64), None),
200            SliceElement::Ellipsis,
201        ])?;
202        let half2 = unset.get_slice(vec![
203            SliceElement::SubArray(Some(k as i64), Some(2 * k as i64), None),
204            SliceElement::Ellipsis,
205        ])?;
206        let reduced = half1.multiply(half2)?;
207        unset = if n % 2 == 0 {
208            reduced
209        } else {
210            let last_element = unset.get(vec![n - 1])?;
211            let elements = reduced.array_to_vector()?;
212            let joined_elements = g.create_tuple(vec![elements, last_element.clone()])?;
213            joined_elements
214                .reshape(vector_type(k + 1, last_element.get_type()?))?
215                .vector_to_array()?
216        };
217    }
218    unset.get(vec![0])?.add(one)
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::run_instantiation_pass;
    use crate::data_types::INT64;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;

    /// Number of fractional bits used by every test below.
    const PRECISION: u64 = 10;

    /// Encodes a real number as a fixed-point integer with `PRECISION` fractional bits.
    fn to_fixed(v: f32) -> i64 {
        (v * ((1 << PRECISION) as f32)) as i64
    }

    /// Decodes a fixed-point integer with `PRECISION` fractional bits.
    fn from_fixed(v: i64) -> f32 {
        (v as f32) / ((1 << PRECISION) as f32)
    }

    /// Sample grid -2.5, -2.4, ..., 2.4 (50 points), partly outside [-2, 2].
    fn sample_points() -> Vec<f32> {
        (0..50).map(|i| ((i - 25) as f32) * 0.1).collect()
    }

    /// 16 buckets with linear continuation on both sides.
    fn linear_sides() -> PWLConfig {
        PWLConfig {
            log_buckets: 4,
            flatten_left: false,
            flatten_right: false,
        }
    }

    /// Checks `ys` against x^2, with a looser tolerance outside of [-2, 2].
    fn assert_close_to_square(xs: &[f32], ys: &[f32]) {
        for (&x, &y) in xs.iter().zip(ys) {
            let tolerance = if !(-2.0..=2.0).contains(&x) { 0.2 } else { 0.02 };
            assert!((x * x - y).abs() < tolerance);
        }
    }

    /// Approximates x^2 on [-2, 2] for a single INT64 input and returns the decoded output.
    fn scalar_helper(arg: f32, conf: PWLConfig) -> Result<f32> {
        let context = simple_context(|g| {
            let input = g.input(scalar_type(INT64))?;
            create_approximation(input, |x| x * x, -2.0, 2.0, PRECISION, conf)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_scalar(to_fixed(arg), INT64)?],
        )?;
        Ok(from_fixed(output.to_i64(INT64)?))
    }

    /// Approximates x^2 on [-2, 2] for an INT64 array of the given shape and decodes the output.
    fn array_helper(arg: Vec<f32>, shape: Vec<u64>, conf: PWLConfig) -> Result<Vec<f32>> {
        let array_t = array_type(shape, INT64);
        let context = simple_context(|g| {
            let input = g.input(array_t.clone())?;
            create_approximation(input, |x| x * x, -2.0, 2.0, PRECISION, conf)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let encoded: Vec<i64> = arg.into_iter().map(to_fixed).collect();
        let output = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![Value::from_flattened_array(&encoded, INT64)?],
        )?;
        Ok(output
            .to_flattened_array_i64(array_t)?
            .into_iter()
            .map(from_fixed)
            .collect())
    }

    #[test]
    fn pwl_simple_test() -> Result<()> {
        for i in 0..50 {
            let x = ((i - 25) as f32) * 0.1;
            let y = scalar_helper(x, linear_sides())?;
            // The first and last samples lie outside [-2, 2], where extrapolation is coarser.
            let tolerance = if i < 5 || i > 45 { 0.2 } else { 0.02 };
            assert!((x * x - y).abs() < tolerance);
        }
        Ok(())
    }

    #[test]
    fn pwl_flat_sides_test() -> Result<()> {
        let flat_sides = PWLConfig {
            log_buckets: 4,
            flatten_left: true,
            flatten_right: true,
        };
        let mut prev: Option<f32> = None;
        for i in 0..50 {
            let x = ((i - 25) as f32) * 0.1;
            let y = scalar_helper(x, flat_sides)?;
            if (3..=45).contains(&i) {
                // Near the interval, the approximation tracks x^2.
                assert!((y - x * x).abs() < 0.1);
                prev = None;
            } else if let Some(expected) = prev.replace(y) {
                // Outside, consecutive samples must agree (flat continuation).
                assert!((y - expected).abs() < 0.02);
            }
        }
        Ok(())
    }

    #[test]
    fn pwl_array_test() -> Result<()> {
        let xs = sample_points();
        let ys = array_helper(xs.clone(), vec![xs.len() as u64], linear_sides())?;
        assert_close_to_square(&xs, &ys);
        Ok(())
    }

    #[test]
    fn pwl_array2d_test() -> Result<()> {
        let xs = sample_points();
        let ys = array_helper(
            xs.clone(),
            vec![(xs.len() / 10) as u64, 10],
            linear_sides(),
        )?;
        assert_close_to_square(&xs, &ys);
        Ok(())
    }
}