// ciphercore_base/ops/fixed_precision/fixed_multiply.rs

//! Multiplication for fixed-precision arithmetic.
2use crate::custom_ops::CustomOperationBody;
3use crate::data_types::{Type, BIT, INT64};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node, SliceElement};
6use crate::ops::utils::{ones_like, reduce_mul, unsqueeze};
7use crate::typed_value::TypedValue;
8use crate::typed_value_operations::TypedValueArrayOperations;
9
10use serde::{Deserialize, Serialize};
11
12use super::fixed_precision_config::FixedPrecisionConfig;
13
14/// Multiplication of numbers in fixed precision.
15///
16/// In particular, given numbers represented as `x / 2^fractional_bits` and `y / 2^fractional_bits`, this operation returns `x * y / 2^fractional_bits`.
17/// This operation supports debug mode, which checks for overflow.
18///
19/// # Custom operation arguments
20///
21/// - Fixed precision config
22///
23/// # Custom operation returns
24///
25/// Node representing the product of the numbers.
26///
27/// # Example
28///
29/// ```
30/// # use ciphercore_base::graphs::create_context;
31/// # use ciphercore_base::data_types::{array_type, INT64};
32/// # use ciphercore_base::custom_ops::{CustomOperation};
33/// # use ciphercore_base::ops::fixed_precision::fixed_multiply::FixedMultiply;
34/// # use ciphercore_base::ops::fixed_precision::fixed_precision_config::FixedPrecisionConfig;
35/// let c = create_context().unwrap();
36/// let g = c.create_graph().unwrap();
37/// let t = array_type(vec![2, 3], INT64);
38/// let a = g.input(t.clone()).unwrap();
39/// let b = g.input(t.clone()).unwrap();
40/// let config = FixedPrecisionConfig {fractional_bits: 10, debug: false};
41/// let res = g.custom_op(CustomOperation::new(FixedMultiply {config}), vec![a, b]).unwrap();
42/// ```
43#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
44pub struct FixedMultiply {
45    pub config: FixedPrecisionConfig,
46}
47
48#[typetag::serde]
49impl CustomOperationBody for FixedMultiply {
50    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
51        if arguments_types.len() != 2 {
52            return Err(runtime_error!("FixedMultiply takes two arguments"));
53        }
54        for arg in arguments_types.iter() {
55            if !arg.is_array() && !arg.is_scalar() {
56                return Err(runtime_error!(
57                    "FixedMultiply expects scalar or array, got {:?}",
58                    arg
59                ));
60            }
61            if arg.get_scalar_type() != INT64 {
62                return Err(runtime_error!("FixedMultiply expects INT64, got {:?}", arg));
63            }
64        }
65
66        let g = context.create_graph()?;
67        let a = g.input(arguments_types[0].clone())?;
68        let b = g.input(arguments_types[1].clone())?;
69        let mut a_times_b = a.multiply(b.clone())?;
70        if self.config.debug {
71            a_times_b = g.assert(
72                "Integer overflow".into(),
73                is_multiplication_safe_from_overflow(a, b)?,
74                a_times_b,
75            )?
76        }
77        let a_times_b_shifted = a_times_b.truncate(self.config.denominator())?;
78        a_times_b_shifted.set_as_output()?;
79        g.finalize()?;
80        Ok(g)
81    }
82
83    fn get_name(&self) -> String {
84        format!("FixedMultiply({})", self.config.fractional_bits)
85    }
86}
87
88/// This function checks whether it is safe to multiply INT64 numbers without overflowing or getting close to it.
89/// The primary use-case is inside `FixedMultiply` for the debug mode, but it can also be used in isolation if needed.
90pub fn is_multiplication_safe_from_overflow(x: Node, y: Node) -> Result<Node> {
91    let x_bits = x.a2b()?;
92    let y_bits = y.a2b()?;
93    // If `x` and `y` were broadcastable before A2B, they will be broadcastable after A2B.
94    // First, let's make sure we don't need to deal with negative numbers.
95    // The way we do it is the following: if the MSB bit is set, we flip _all_ bits. This is not exactly negation in the two-complement form, but is close enough.
96    let msb_x = x_bits.get_slice(vec![
97        SliceElement::Ellipsis,
98        SliceElement::SubArray(Some(-1), None, None),
99    ])?;
100    let x_bits = x_bits.add(msb_x)?;
101    let msb_y = y_bits.get_slice(vec![
102        SliceElement::Ellipsis,
103        SliceElement::SubArray(Some(-1), None, None),
104    ])?;
105    let y_bits = y_bits.add(msb_y)?;
106    // Now, we're checking the following:
107    //   for every pair of indices (i, j) such that i + j >= 56, we check that x[..., i] * y[.., j] == 0.
108    // Why 56? Because this guarantees that the resulting product is below 2**63:
109    // -- let's consider s in 0..55, and i + j = s;
110    // -- max possible contribution for such s is (2 ** s) * (s + 1);
111    // -- summing over s <= 55, we get \sum (2 ** s) * (s + 1) = 3963167672086036481, which happens to be lower than 2 ** 63 - 1.
112    // -- fwiw, for 56, it would be very close to 2 ** 63 - 1, but still lower. For 57, it is already higher.
113    //
114    // First, compute xy[..., i, j] = x[..., i] * y[..., j].
115    let xy_bits = unsqueeze(x_bits, -1)?.multiply(unsqueeze(y_bits, -2)?)?;
116    // Now, mask out pairs of bits such that i + j < 56.
117    let mut mask_arr = ndarray::Array2::zeros((64, 64));
118    for i in 0..64 {
119        for j in 0..64 {
120            if i + j >= 56 {
121                mask_arr[[i, j]] = 1;
122            }
123        }
124    }
125    let mask_tv = TypedValue::from_ndarray(mask_arr.into_dyn(), BIT)?;
126    let g = x.get_graph();
127    let mask = g.constant(mask_tv.t, mask_tv.value)?;
128    let xy_bits = xy_bits.multiply(mask)?;
129    // Now, all that remains is to check that `xy_bits` is zero.
130    // This is surprisingly non-trivial, we need to reduce it across all dimensions.
131    let one = ones_like(xy_bits.clone())?;
132    let not_xy_bits = xy_bits.add(one)?;
133    let mut reduction_result = not_xy_bits;
134    while reduction_result.get_type()?.is_array() {
135        reduction_result = reduce_mul(reduction_result)?;
136    }
137    // `reduction_result` is 1 iff all of `xy_bits` are 0.
138    Ok(reduction_result)
139}
140
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;
    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::evaluators::random_evaluate;
    use crate::graphs::{create_context, Graph, Node};
    use crate::typed_value_operations::{ToNdarray, TypedValueArrayOperations};

    /// Builds a context whose main graph takes inputs typed like `a` and `b` and outputs
    /// `build(graph, input_a, input_b)`. It then instantiates custom ops and evaluates the
    /// graph on `a` and `b`.
    fn evaluate_binary<F>(a: TypedValue, b: TypedValue, build: F) -> Result<TypedValue>
    where
        F: FnOnce(&Graph, Node, Node) -> Result<Node>,
    {
        let context = create_context()?;
        let graph = context.create_graph()?;
        let lhs = graph.input(a.t.clone())?;
        let rhs = graph.input(b.t.clone())?;
        let output = build(&graph, lhs, rhs)?;
        let output_type = output.get_type()?;
        output.set_as_output()?;
        graph.finalize()?;
        graph.set_as_main()?;
        context.finalize()?;
        let instantiated = run_instantiation_pass(context)?;
        let result = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![a.value, b.value],
        )?;
        TypedValue::new(output_type, result)
    }

    /// Evaluates `FixedMultiply` with `config` on `a` and `b`.
    fn multiply_helper(
        a: TypedValue,
        b: TypedValue,
        config: FixedPrecisionConfig,
    ) -> Result<TypedValue> {
        evaluate_binary(a, b, |graph, lhs, rhs| {
            graph.custom_op(
                CustomOperation::new(FixedMultiply { config }),
                vec![lhs, rhs],
            )
        })
    }

    /// Evaluates `FixedMultiply` and flattens the product into a `Vec<i64>`.
    fn multiply_to_vec(
        a: TypedValue,
        b: TypedValue,
        config: FixedPrecisionConfig,
    ) -> Result<Vec<i64>> {
        let product = multiply_helper(a, b, config)?;
        Ok(ToNdarray::<i64>::to_ndarray(&product)?.into_raw_vec())
    }

    /// Evaluates `is_multiplication_safe_from_overflow` on `a` and `b`.
    fn overflow_helper(a: TypedValue, b: TypedValue) -> Result<bool> {
        let verdict = evaluate_binary(a, b, |_, lhs, rhs| {
            is_multiplication_safe_from_overflow(lhs, rhs)
        })?;
        Ok(verdict.to_u64()? > 0)
    }

    #[test]
    fn test_multiply_scalars() -> Result<()> {
        let int_config = FixedPrecisionConfig {
            fractional_bits: 0,
            debug: false,
        };
        let fixed_config = FixedPrecisionConfig {
            fractional_bits: 15,
            debug: false,
        };
        // (config, lhs, rhs, expected stored product)
        let cases = [
            (int_config, 2, 2, 4),
            (int_config, 5, 6, 30),
            (fixed_config, 2 << 15, 2 << 15, 4 << 15),
            (fixed_config, 5 << 15, 6 << 15, 30 << 15),
        ];
        for (config, lhs, rhs, expected) in cases {
            let product = multiply_helper(
                TypedValue::from_scalar(lhs, INT64)?,
                TypedValue::from_scalar(rhs, INT64)?,
                config,
            )?
            .to_u64()?;
            assert_eq!(product, expected);
        }
        Ok(())
    }

    #[test]
    fn test_multiply_negative() -> Result<()> {
        let config = FixedPrecisionConfig {
            fractional_bits: 15,
            debug: false,
        };
        let product = multiply_helper(
            TypedValue::from_scalar(2 << 15, INT64)?,
            TypedValue::from_scalar(-3 << 15, INT64)?,
            config,
        )?
        .to_u64()?;
        assert_eq!(product as i64, -6 << 15);
        Ok(())
    }

    #[test]
    fn test_multiply_arrays() -> Result<()> {
        let config = FixedPrecisionConfig {
            fractional_bits: 15,
            debug: false,
        };
        let two = TypedValue::from_scalar(2 << 15, INT64)?;
        let x = TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?;
        let y = TypedValue::from_ndarray(array![4 << 15, 5 << 15, 6 << 15].into_dyn(), INT64)?;
        assert_eq!(
            multiply_to_vec(two.clone(), x.clone(), config)?,
            vec![2 << 15, 4 << 15, 6 << 15]
        );
        assert_eq!(
            multiply_to_vec(x.clone(), two, config)?,
            vec![2 << 15, 4 << 15, 6 << 15]
        );
        assert_eq!(
            multiply_to_vec(x, y, config)?,
            vec![4 << 15, 10 << 15, 18 << 15]
        );
        Ok(())
    }

    #[test]
    fn test_multiply_broadcast() -> Result<()> {
        let config = FixedPrecisionConfig {
            fractional_bits: 15,
            debug: false,
        };
        let x = TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?;
        let y = TypedValue::from_ndarray(array![2 << 15].into_dyn(), INT64)?;
        assert_eq!(
            multiply_to_vec(x, y, config)?,
            vec![2 << 15, 4 << 15, 6 << 15]
        );
        Ok(())
    }

    #[test]
    fn test_multiply_debug_mode_success() -> Result<()> {
        let debug_config = FixedPrecisionConfig {
            fractional_bits: 15,
            debug: true,
        };
        let product = multiply_helper(
            TypedValue::from_scalar(2 << 15, INT64)?,
            TypedValue::from_scalar(2 << 15, INT64)?,
            debug_config,
        )?
        .to_u64()?;
        assert_eq!(product, 4 << 15);
        Ok(())
    }

    #[test]
    fn test_multiply_debug_mode_fail() -> Result<()> {
        let debug_config = FixedPrecisionConfig {
            fractional_bits: 15,
            debug: true,
        };
        let outcome = multiply_helper(
            TypedValue::from_scalar(1 << 30, INT64)?,
            TypedValue::from_scalar(1 << 30, INT64)?,
            debug_config,
        );
        assert!(outcome.is_err());
        Ok(())
    }

    #[test]
    fn test_overflow_check_success() -> Result<()> {
        let one = TypedValue::from_scalar(1, INT64)?;
        let two = TypedValue::from_scalar(2, INT64)?;
        let small_number = TypedValue::from_scalar(4243, INT64)?;
        let two_to_twenty_five = TypedValue::from_scalar(1 << 25, INT64)?;
        let medium_number = TypedValue::from_scalar(71479832, INT64)?;
        let two_to_thirty = TypedValue::from_scalar(1 << 30, INT64)?;
        let two_to_fifty = TypedValue::from_scalar(1_i64 << 50, INT64)?;
        let minus_two = TypedValue::from_scalar(-2, INT64)?;
        let minus_one = TypedValue::from_scalar(-1, INT64)?;
        let safe_pairs = vec![
            (one.clone(), two.clone()),
            (one.clone(), minus_two.clone()),
            (minus_one.clone(), minus_one),
            (small_number.clone(), small_number.clone()),
            (small_number, medium_number.clone()),
            (two_to_twenty_five.clone(), two_to_twenty_five),
            (medium_number.clone(), medium_number),
            (two, two_to_thirty.clone()),
            (minus_two, two_to_thirty),
            (one, two_to_fifty),
        ];
        for (idx, (lhs, rhs)) in safe_pairs.into_iter().enumerate() {
            assert!(overflow_helper(lhs, rhs)?, "pair #{} should be safe", idx);
        }
        Ok(())
    }

    #[test]
    fn test_overflow_check_fail() -> Result<()> {
        let two_to_twenty_five = TypedValue::from_scalar(1 << 25, INT64)?;
        let two_to_thirty = TypedValue::from_scalar(1 << 30, INT64)?;
        let two_to_fifty = TypedValue::from_scalar(1_i64 << 50, INT64)?;
        let large_number = TypedValue::from_scalar(2363897937439121_i64, INT64)?;
        let minus_two_to_thirty = TypedValue::from_scalar(-1 << 30, INT64)?;
        let unsafe_pairs = vec![
            (two_to_twenty_five, two_to_fifty.clone()),
            (two_to_thirty.clone(), two_to_thirty.clone()),
            (two_to_thirty.clone(), large_number.clone()),
            (large_number.clone(), large_number),
            (minus_two_to_thirty, two_to_thirty.clone()),
            (two_to_thirty, two_to_fifty),
        ];
        for (idx, (lhs, rhs)) in unsafe_pairs.into_iter().enumerate() {
            assert!(!overflow_helper(lhs, rhs)?, "pair #{} should be unsafe", idx);
        }
        Ok(())
    }

    #[test]
    fn test_overflow_check_success_arrays() -> Result<()> {
        let x = TypedValue::from_ndarray(array![1 << 15, 2 << 15, 3 << 15].into_dyn(), INT64)?;
        let scalar_y = TypedValue::from_scalar(2, INT64)?;
        let array_y =
            TypedValue::from_ndarray(array![10 << 15, 20 << 15, 30 << 15].into_dyn(), INT64)?;
        assert!(overflow_helper(x.clone(), scalar_y)?);
        assert!(overflow_helper(x, array_y)?);
        Ok(())
    }

    #[test]
    fn test_overflow_check_fail_arrays() -> Result<()> {
        let x = TypedValue::from_ndarray(array![1 << 25, 1 << 26, 1 << 27].into_dyn(), INT64)?;
        let scalar_y = TypedValue::from_scalar(1 << 30, INT64)?;
        let array_y =
            TypedValue::from_ndarray(array![1 << 28, 1 << 29, 1 << 30].into_dyn(), INT64)?;
        assert!(!overflow_helper(x.clone(), scalar_y)?);
        assert!(!overflow_helper(x, array_y)?);
        Ok(())
    }
}