ciphercore_base/ops/
adder.rs

1//! Binary adder that adds two bitstrings.
2use crate::custom_ops::{CustomOperation, CustomOperationBody};
3use crate::data_types::{array_type, Type, BIT};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node, SliceElement};
6use crate::ops::utils::{expand_dims, put_in_bits};
7
8use serde::{Deserialize, Serialize};
9
10use super::utils::{pull_out_bits_pair, validate_arguments_in_broadcast_bit_ops};
11
12/// A structure that defines the custom operation BinaryAdd that implements the binary adder.
13///
14/// The binary adder takes two arrays of length-n bitstrings and returns the elementwise binary sum of these arrays.
15/// If overflow_bit is true, the output is a tuple (sum, overflow_bit) instead.
16///
17/// Only `n` which are powers of two are supported.
18///
19/// The last dimension of both inputs must be the same; it defines the length of input bitstrings.
20/// If input shapes are different, the broadcasting rules are applied (see [the NumPy broadcasting rules](https://numpy.org/doc/stable/user/basics.broadcasting.html)).
21/// For example, if input arrays are of shapes `[2,3]`, and `[1,3]`, the resulting array has shape `[2,3]`.
22///
23/// This operation is needed for conversion between arithmetic and boolean additive MPC shares
24/// (i.e. A2B and B2A operations in MPC).
25///
26/// To use this and other custom operations in computation graphs, see [Graph::custom_op].
27///
28/// # Custom operation arguments
29///
30/// - Node containing a binary array or scalar
31/// - Node containing a binary array or scalar
32///
33/// # Custom operation returns
34///
35/// New BinaryAdd node
36///
37/// # Example
38///
39/// ```
40/// # use ciphercore_base::graphs::create_context;
41/// # use ciphercore_base::data_types::{array_type, BIT};
42/// # use ciphercore_base::custom_ops::{CustomOperation};
43/// # use ciphercore_base::ops::adder::BinaryAdd;
44/// let c = create_context().unwrap();
45/// let g = c.create_graph().unwrap();
46/// let t = array_type(vec![2, 4], BIT);
47/// let n1 = g.input(t.clone()).unwrap();
48/// let n2 = g.input(t.clone()).unwrap();
49/// let n3 = g.custom_op(CustomOperation::new(BinaryAdd { overflow_bit: false }), vec![n1, n2]).unwrap();
50/// ```
51#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
52pub struct BinaryAdd {
53    pub overflow_bit: bool,
54}
55
56#[typetag::serde]
57impl CustomOperationBody for BinaryAdd {
58    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
59        validate_arguments_in_broadcast_bit_ops(arguments_types.clone(), &self.get_name())?;
60        let input_type0 = arguments_types[0].clone();
61        let input_type1 = arguments_types[1].clone();
62
63        // Adder input consists of two binary strings x and y
64        let g = context.create_graph()?;
65        let (input0, input1) = pull_out_bits_pair(g.input(input_type0)?, g.input(input_type1)?)?;
66        let added = g.custom_op(
67            CustomOperation::new(BinaryAddTransposed {
68                overflow_bit: self.overflow_bit,
69            }),
70            vec![input0, input1],
71        )?;
72        let output = if self.overflow_bit {
73            g.create_tuple(vec![
74                put_in_bits(added.tuple_get(0)?)?,
75                put_in_bits(added.tuple_get(1)?)?,
76            ])?
77        } else {
78            put_in_bits(added)?
79        };
80        output.set_as_output()?;
81        g.finalize()?;
82        Ok(g)
83    }
84
85    fn get_name(&self) -> String {
86        format!("BinaryAdd(overflow_bit={})", self.overflow_bit)
87    }
88}
89
90// Same as BinaryAdd, but expect that the first dimension is bits.
91// This is a performance optimization, it's easier to operate on the first dimension.
92#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
93pub(crate) struct BinaryAddTransposed {
94    pub overflow_bit: bool,
95}
96
97#[typetag::serde]
98impl CustomOperationBody for BinaryAddTransposed {
99    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
100        if arguments_types.len() != 2 {
101            return Err(runtime_error!("Invalid number of arguments"));
102        }
103        match (&arguments_types[0], &arguments_types[1]) {
104            (Type::Array(shape0, scalar_type0), Type::Array(shape1, scalar_type1)) => {
105                if shape0[0] != shape1[0] {
106                    return Err(runtime_error!(
107                        "Input arrays' first dimensions are not the same"
108                    ));
109                }
110                if *scalar_type0 != BIT {
111                    return Err(runtime_error!("Input array [0]'s ScalarType is not BIT"));
112                }
113                if *scalar_type1 != BIT {
114                    return Err(runtime_error!("Input array [1]'s ScalarType is not BIT"));
115                }
116            }
117            _ => {
118                return Err(runtime_error!(
119                    "Invalid input argument type, expected Array type"
120                ));
121            }
122        }
123
124        let input_type0 = arguments_types[0].clone();
125        let input_type1 = arguments_types[1].clone();
126
127        // Adder input consists of two binary strings x and y
128        let g = context.create_graph()?;
129        let input0 = g.input(input_type0)?;
130        let input1 = g.input(input_type1)?;
131        // Compute "propagate" bits x_i XOR y_i
132        let xor_bits = g.add(input0.clone(), input1.clone())?;
133        // Compute "generate" bits x_i AND y_i
134        let and_bits = g.multiply(input0, input1)?;
135
136        let (carries, overflow_bit) =
137            calculate_carry_bits(xor_bits.clone(), and_bits, self.overflow_bit)?;
138        // The last step is to XOR carries with "propagate" bits
139        let added = carries.add(xor_bits)?;
140        let output = match overflow_bit {
141            Some(overflow_bit) => g.create_tuple(vec![added, overflow_bit])?,
142            None => added,
143        };
144        output.set_as_output()?;
145        g.finalize()?;
146        Ok(g)
147    }
148
149    fn get_name(&self) -> String {
150        format!("BinaryAddTransposed(overflow_bit={})", self.overflow_bit)
151    }
152}
153
154/// Actual carry is calculated as `generate + propagate * prev_carry`
155///
156/// `(propagate, generate) = (1, 1)` is impossible state.
157#[derive(Clone)]
158struct CarryNode {
159    propagate: Node,
160    generate: Node,
161}
162
163impl CarryNode {
164    fn bit_len(&self) -> Result<u64> {
165        Ok(self.propagate.get_type()?.get_shape()[0])
166    }
167
168    fn shrink(&self, overflow_bit: bool) -> Result<CarryNode> {
169        let bit_len = self.bit_len()? as i64;
170
171        let next_lvl_bits = if overflow_bit {
172            bit_len / 2
173        } else {
174            (bit_len - 1) / 2
175        };
176        let use_bits = next_lvl_bits * 2;
177        let lower = self.sub_slice(0, use_bits)?;
178        let higher = self.sub_slice(1, use_bits)?;
179
180        lower.join(&higher)
181    }
182
183    /// assumes `bit_len` is the same for `self` and `rhs`
184    fn join(&self, rhs: &Self) -> Result<Self> {
185        let propagate = self.propagate.multiply(rhs.propagate.clone())?;
186        let generate = rhs
187            .generate
188            .add(rhs.propagate.multiply(self.generate.clone())?)?;
189        Ok(Self {
190            propagate,
191            generate,
192        })
193    }
194
195    /// Returns every second element starting from `start_offset`
196    fn sub_slice(&self, start_offset: i64, bit_len: i64) -> Result<Self> {
197        let get_slice = |node: &Node| {
198            node.get_slice(vec![SliceElement::SubArray(
199                Some(start_offset),
200                Some(bit_len),
201                Some(2),
202            )])
203        };
204        Ok(Self {
205            propagate: get_slice(&self.propagate)?,
206            generate: get_slice(&self.generate)?,
207        })
208    }
209
210    fn apply(&self, prev_carry: Node) -> Result<Node> {
211        self.generate.add(self.propagate.multiply(prev_carry)?)
212    }
213}
214
215/// Takes arrays `[a1, a2, ..., a_n]` and `[b1, b2, ..., b_n]`
216///
217/// Returns `[a1, b1, a2, b2, ..., a_n, b_n]`
218fn interleave(first: Node, second: Node) -> Result<Node> {
219    let first = expand_dims(first, &[0])?;
220    let second = expand_dims(second, &[0])?;
221    let graph = first.get_graph();
222    let joined = graph.concatenate(vec![first, second], 0)?;
223    let mut axes: Vec<_> = (0..joined.get_type()?.get_shape().len() as u64).collect();
224    axes.swap(0, 1);
225    let joined = joined.permute_axes(axes)?;
226    let mut shape = joined.get_type()?.get_shape();
227    shape[0] *= 2;
228    shape.remove(1);
229    let scalar = joined.get_type()?.get_scalar_type();
230    joined.reshape(array_type(shape, scalar))
231}
232
233/// This function generates a graph for the "segment tree" to calculate
234/// carry bits.
235///
236/// It assumes both `propagate_bits` and `generate_bits` are arrays
237/// with bits dimension pulled out to the outermost level.
238/// It also assumes the number of bits is a power of two.
239///
240/// Each node of the segment tree labeled `ij` stores two bits `(propagate, generate)` that
241/// are used to compute carry[j+1] given carry[i], as `generate + propagate * carry[i]`.
242///
243/// The overall multiplicative depth of generated segment tree is `2*log(bits)`.
244/// First, we generate nodes for each separate bit.
245/// Then, we join neighboring nodes, until we have a single node.
246///
247/// When the top node is calculated, we go top-down and push carry bits to the lower layers.
248/// In the implementation, bits from nodes on the same tree layer are stored together in a
249/// single CarryNode in the nodes[] array.
250///
251/// # Example
252/// Let's say we have 8 bits. First, we create a node for each bit:
253/// ```text
254/// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |    <- stored in `nodes[0]`
255///   \   /   \   /   \   /   \   /
256///    01      23      45      67        <- stored in `nodes[1]`
257///      \    /          \    /
258///        03              47            <- stored in `nodes[2]`
259///           \          /  
260///                07                    <- stored in `nodes[3]`
261/// ```
262///
263/// Then we calculate carry bits iterating over layers from bottom to top. We know
264/// that `carry[0] = 0`.
265///
266/// Based on `nodes[3]` we can calculate the overflow bit: carry[n].
267///
268/// Based on `nodes[2]` we calculate `carry[4]`, interleave with `carry[0]`,
269/// and get `{carry[0], carry[4]}`.
270///
271/// Then we use `nodes[1]` to generate `{carry[2], carry[6]}`, interleave with
272/// previous result, and get `{carry[0], carry[2], carry[4], carry[6]}`. Note
273/// that to do this we only need half of the values from `nodes[1]`
274/// (we don't need nodes `23` or `67`).
275///
276/// And finally using (half of) `nodes[0]` we can calculate all odd indexes carries.
277///
278/// As an optimization, if overflow_bit=false, we remove the last node from each layer
279/// (except layer 0) as they're not needed.
280///
281/// Returns (carry[0..n-1], Some(carry[n])) tuple if overflow_bit=true, or
282/// (carry[0..n-1], None) tuple if overflow_bit=false,
283/// representing carry bits and the overflow bit.
284fn calculate_carry_bits(
285    propagate_bits: Node,
286    generate_bits: Node,
287    overflow_bit: bool,
288) -> Result<(Node, Option<Node>)> {
289    let graph = propagate_bits.get_graph();
290
291    let mut nodes = vec![CarryNode {
292        propagate: propagate_bits,
293        generate: generate_bits,
294    }];
295    let bit_len = nodes[0].bit_len()?;
296    if !bit_len.is_power_of_two() {
297        return Err(runtime_error!("BinaryAdd only supports numbers with number of bits, which is a power of 2. {} bits provided.", bit_len));
298    }
299    let mut shape = nodes[0].propagate.get_type()?.get_shape();
300    shape[0] = 1;
301    let mut carries = graph.zeros(array_type(shape, BIT))?;
302    // Two special cases for the overflow_bit=false optimization:
303    // If the input is 1 bit long: ignore input and just return zero:
304    if !overflow_bit && bit_len == 1 {
305        return Ok((carries, None));
306    }
307    // If the input is 2 bits long: we must not join the two nodes on the initial layer.
308    if overflow_bit || bit_len > 2 {
309        while nodes.last().unwrap().bit_len()? > 1 {
310            let last = nodes.last().unwrap();
311            nodes.push(last.shrink(overflow_bit)?);
312        }
313    }
314
315    let mut node_rev_iter = nodes.iter().rev();
316    let overflow_bit = if overflow_bit {
317        let root_node = node_rev_iter.next().unwrap();
318        Some(root_node.apply(carries.clone())?)
319    } else {
320        None
321    };
322    for node in node_rev_iter {
323        let lower = node.sub_slice(0, node.bit_len()? as i64)?;
324        let new_carries = lower.apply(carries.clone())?;
325        carries = interleave(carries, new_carries)?;
326    }
327
328    Ok((carries, overflow_bit))
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{
        array_type, tuple_type, ScalarType, INT16, INT64, UINT16, UINT32, UINT64, UINT8,
    };
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    // Builds a graph applying `BinaryAdd { overflow_bit }` to two bit arrays of shapes
    // `shape0` and `shape1`, instantiates custom operations, and evaluates it on `inputs`.
    fn evaluate_adder(
        shape0: Vec<u64>,
        shape1: Vec<u64>,
        overflow_bit: bool,
        inputs: Vec<Value>,
    ) -> Result<Value> {
        let c = simple_context(|g| {
            let lhs = g.input(array_type(shape0.clone(), BIT))?;
            let rhs = g.input(array_type(shape1.clone(), BIT))?;
            g.custom_op(
                CustomOperation::new(BinaryAdd { overflow_bit }),
                vec![lhs, rhs],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        random_evaluate(mapped_c.get_context().get_main_graph()?, inputs)
    }

    // Checks that the adder (without overflow bit) matches wrapping addition for `st`.
    fn test_helper(first: u64, second: u64, st: ScalarType) -> Result<()> {
        let bits = st.size_in_bits();
        let mask = (1u128 << bits) - 1;
        let (first, second) = ((first as u128) & mask, (second as u128) & mask);

        let c = simple_context(|g| {
            let bit_string = array_type(vec![bits], BIT);
            let lhs = g.input(bit_string.clone())?;
            let rhs = g.input(bit_string)?;
            let sum = g.custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false,
                }),
                vec![lhs, rhs],
            )?;
            assert_eq!(
                sum.get_type()?.get_dimensions(),
                vec![bits],
                "{first} + {second} with {bits} bits"
            );
            Ok(sum)
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let operands = vec![
            Value::from_scalar(first, st)?,
            Value::from_scalar(second, st)?,
        ];
        let got = random_evaluate(mapped_c.get_context().get_main_graph()?, operands)?
            .to_u128(st)?;

        let want = first.wrapping_add(second) & mask;
        assert_eq!(got, want, "{first} + {second} with {bits} bits");
        Ok(())
    }

    #[test]
    fn test_random_inputs() -> Result<()> {
        let random_numbers = [0, 1, 3, 4, 10, 100500, 123456, 787788];
        for st in [BIT, UINT8, UINT16, UINT32, UINT64] {
            for x in random_numbers {
                for y in random_numbers {
                    test_helper(x, y, st)?;
                }
            }
        }
        Ok(())
    }

    // Returns `(sum, overflow_bit)` computed by the adder with `overflow_bit: true`.
    fn add_with_overflow_helper(first: u64, second: u64, st: ScalarType) -> Result<(u64, u64)> {
        let bits = st.size_in_bits();
        let inputs = vec![
            Value::from_scalar(first, st)?,
            Value::from_scalar(second, st)?,
        ];
        let outputs = evaluate_adder(vec![bits], vec![bits], true, inputs)?.to_vector()?;
        Ok((outputs[0].to_u64(st)?, outputs[1].to_u64(BIT)?))
    }

    #[test]
    fn test_add_with_overflow_bit() -> Result<()> {
        let cases = [
            (0, 0, BIT, 0, 0),
            (0, 1, BIT, 1, 0),
            (1, 0, BIT, 1, 0),
            (1, 1, BIT, 0, 1),
            (127, 128, UINT8, 255, 0),
            (127, 129, UINT8, 0, 1),
            (128, 128, UINT8, 0, 1),
            (255, 255, UINT8, 254, 1),
            (1234, 4321, UINT16, 5555, 0),
            (12345, 54321, UINT16, 1130, 1),
            (12345, 54321, UINT32, 66666, 0),
            (2000000000, 2000000000, UINT32, 4000000000, 0),
            (2000000000, 3000000000, UINT32, 705032704, 1),
            (u64::MAX, u64::MAX, UINT64, u64::MAX - 1, 1),
        ];
        for (first, second, st, want_sum, want_overflow) in cases {
            let (got_sum, got_overflow) = add_with_overflow_helper(first, second, st)?;
            assert_eq!(got_sum, want_sum, "{first} + {second}");
            assert_eq!(got_overflow, want_overflow, "{first} + {second}");
        }
        Ok(())
    }

    #[test]
    fn test_well_behaved() -> Result<()> {
        // Broadcasting a [1, 16] operand against a [5, 16] one.
        let lhs = Value::from_flattened_array(&vec![0, 1023, -1023, i16::MIN, i16::MAX], INT16)?;
        let rhs = Value::from_flattened_array(&vec![1024], INT16)?;
        let sums = evaluate_adder(vec![5, 16], vec![1, 16], false, vec![lhs, rhs])?
            .to_flattened_array_u64(array_type(vec![5], INT16))?;
        let expected = vec![
            1024,
            2047,
            1,
            (i16::MIN + 1024) as u64,
            i16::MAX.wrapping_add(1024) as u64,
        ];
        assert_eq!(sums, expected);

        // 64-bit operands with opposite signs.
        let operands = vec![
            Value::from_scalar(123456790, INT64)?,
            Value::from_scalar(-123456789, INT64)?,
        ];
        let sum = evaluate_adder(vec![64], vec![64], false, operands)?.to_u64(INT64)?;
        assert_eq!(sum, 1);
        Ok(())
    }

    #[test]
    fn test_malformed() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        let bits64 = g.input(array_type(vec![64], BIT))?;
        let int16s = g.input(array_type(vec![64], INT16))?;
        let empty_tuple = g.input(tuple_type(vec![]))?;
        let bits32 = g.input(array_type(vec![32], BIT))?;
        let bits31 = g.input(array_type(vec![31], BIT))?;

        // Wrong arity, non-BIT scalars, non-array type, mismatched lengths,
        // and a length that is not a power of two.
        let malformed_arguments = [
            vec![bits64.clone()],
            vec![bits64.clone(), int16s.clone()],
            vec![int16s, bits64.clone()],
            vec![empty_tuple],
            vec![bits64, bits32],
            vec![bits31.clone(), bits31],
        ];
        for arguments in malformed_arguments {
            let adder = CustomOperation::new(BinaryAdd {
                overflow_bit: false,
            });
            assert!(g.custom_op(adder, arguments).is_err());
        }
        Ok(())
    }
}