ciphercore_base/ops/
long_division.rs

//! Long division (quotient and remainder) for signed and unsigned bitstrings whose lengths are powers of two.
2use crate::broadcast::broadcast_shapes;
3use crate::custom_ops::{CustomOperation, CustomOperationBody, Not};
4use crate::data_types::{array_type, scalar_type, tuple_type, ArrayShape, Type, BIT};
5use crate::errors::Result;
6use crate::graphs::{Context, Graph, Node, SliceElement};
7use crate::ops::multiplexer::Mux;
8use crate::ops::utils::unsqueeze;
9
10use serde::{Deserialize, Serialize};
11
12use super::adder::{BinaryAdd, BinaryAddTransposed};
13use super::comparisons::Equal;
14use super::utils::{prepend_dims, pull_out_bits_pair, put_in_bits};
15
16/// A structure that defines the custom operation LongDivision that computes the quotient and
17/// remainder of `dividend` / `divisor`, such that:
18///   quotient * divisor + remainder == dividend
19///
20/// # Custom operation arguments
21///
22/// - Node containing the dividend as a two's complement length-n bitstring.
23/// - Node containing the divisor as a two's complement length-n bitstring.
24///
25/// Only `n` which are powers of two are supported.
26///
27/// # Custom operation returns
28///
29/// Node containing the (quotient, remainder) tuple. Both quotient and remainder are bitstrings with
30/// lengths equal to dividend and divisor, respectively.
31///
32/// # Example
33/// ```
34/// # use ciphercore_base::graphs::util::simple_context;
35/// # use ciphercore_base::data_types::{array_type, INT32};
36/// # use ciphercore_base::custom_ops::{CustomOperation};
37/// # use ciphercore_base::ops::long_division::LongDivision;
38/// let t = array_type(vec![10, 25], INT32);
39/// let c = simple_context(|g| {
40///     let input_dividends = g.input(t.clone())?;
41///     let input_divisors = g.input(t.clone())?;
42///     let binary_dividends = input_dividends.a2b()?;
43///     let binary_divisors = input_divisors.a2b()?;
44///     let result = g.custom_op(
45///         CustomOperation::new(LongDivision {
46///             signed: true,
47///         }),
48///         vec![binary_dividends, binary_divisors],
49///     )?;
50///     let quotient = result.tuple_get(0)?.b2a(INT32)?;
51///     let remainder = result.tuple_get(1)?.b2a(INT32)?;
52///     g.create_tuple(vec![quotient, remainder])
53/// }).unwrap();
54/// ```
55#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
56pub struct LongDivision {
57    pub signed: bool,
58}
59
60#[typetag::serde]
61impl CustomOperationBody for LongDivision {
62    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
63        if arguments_types.len() != 2 {
64            return Err(runtime_error!(
65                "Invalid number of arguments for LongDivision, given {}, expected 2",
66                arguments_types.len()
67            ));
68        }
69
70        let dividend_type = arguments_types[0].clone();
71        let divisor_type = arguments_types[1].clone();
72        if dividend_type.get_scalar_type() != BIT {
73            return Err(runtime_error!(
74                "Invalid scalar types for LongDivision: dividend scalar type {}, expected BIT",
75                dividend_type.get_scalar_type()
76            ));
77        }
78        if divisor_type.get_scalar_type() != BIT {
79            return Err(runtime_error!(
80                "Invalid scalar types for LongDivision: divisor scalar type {}, expected BIT",
81                dividend_type.get_scalar_type()
82            ));
83        }
84        if !divisor_type.is_array() {
85            return Err(runtime_error!("Divisor in LongDivision must be an array"));
86        }
87        if !dividend_type.is_array() {
88            return Err(runtime_error!("Dividend in LongDivision must be an array"));
89        }
90        let types = Types::new(dividend_type, divisor_type)?;
91        let g_iterate = single_iteration_graph(&context, types.clone())?;
92        let g = context.create_graph()?;
93        let dividend = g.input(types.divident_type.clone())?;
94        let divisor = g.input(types.divisor_type.clone())?;
95
96        // We compute abs(dividend) / abs(divisor) first, and adjust the results at the end.
97        let (dividend_is_negative, abs_dividend) = abs(dividend, self.signed)?;
98        let (divisor_is_negative, abs_divisor) = abs(divisor, self.signed)?;
99        let negative_abs_divisor = negative(abs_divisor.clone())?;
100        // Pull out dividend bits as first dimesion, and reverse them as we want to process bits
101        // starting with the most significant bit. We also pull out divisor bits as it's more
102        // efficient to work with them in this form.
103        let (dividend_pulled_bits, negative_abs_divisor_pulled_bits) =
104            pull_out_bits_pair(abs_dividend, negative_abs_divisor)?;
105
106        let dividend_pulled_bits =
107            dividend_pulled_bits.get_slice(vec![SliceElement::SubArray(None, None, Some(-1))])?;
108
109        // Iterate single bit computation over all dividend bits.
110        let state = g.create_tuple(vec![
111            g.zeros(types.remainder_pulled_bits_type.clone())?,
112            broadcast(
113                negative_abs_divisor_pulled_bits,
114                types.remainder_pulled_bits_type,
115            )?,
116        ])?;
117        let result = g.iterate(g_iterate, state, dividend_pulled_bits.array_to_vector()?)?;
118        let remainder = put_in_bits(result.tuple_get(0)?.tuple_get(0)?)?;
119        let quotient_pulled_bits = result.tuple_get(1)?.vector_to_array()?;
120
121        // Reverse quotient bits, and put bits back into the last dimension.
122        let quotient_pulled_bits =
123            quotient_pulled_bits.get_slice(vec![SliceElement::SubArray(None, None, Some(-1))])?;
124        let quotient = put_in_bits(quotient_pulled_bits)?;
125
126        let (quotient, remainder) = if self.signed {
127            adjust_negative(
128                quotient,
129                remainder,
130                abs_divisor,
131                dividend_is_negative,
132                divisor_is_negative,
133            )?
134        } else {
135            (quotient, remainder)
136        };
137        let output = g.create_tuple(vec![quotient, remainder])?;
138        output.set_as_output()?;
139        g.finalize()?;
140        Ok(g)
141    }
142
143    fn get_name(&self) -> String {
144        format!("LongDivision(signed={})", self.signed)
145    }
146}
147
148#[derive(Debug, Clone)]
149struct Types {
150    divident_type: Type,
151    divisor_type: Type,
152    remainder_pulled_bits_type: Type,
153    quotient_pulled_bit_type: Type,
154    dividend_no_bits_type: Type,
155    quotient_no_bits_type: Type,
156}
157
158impl Types {
159    fn new(divident_type: Type, divisor_type: Type) -> Result<Self> {
160        let (dividend_no_bits_shape, _dividend_bits) = pop_last_dim(divident_type.get_dimensions());
161        let (divisor_no_bits_shape, divisor_bits) = pop_last_dim(divisor_type.get_dimensions());
162        let output_no_bits_shape =
163            broadcast_shapes(dividend_no_bits_shape.clone(), divisor_no_bits_shape)?;
164        let dividend_no_bits_shape =
165            prepend_dims(dividend_no_bits_shape, output_no_bits_shape.len())?;
166        let remainder_pulled_bits_shape =
167            [vec![divisor_bits], output_no_bits_shape.clone()].concat();
168        let quotient_pulled_bit_shape = [vec![1], output_no_bits_shape.clone()].concat();
169        let quotient_no_bits_shape = output_no_bits_shape;
170        Ok(Self {
171            divident_type,
172            divisor_type,
173            remainder_pulled_bits_type: array_type(remainder_pulled_bits_shape, BIT),
174            quotient_pulled_bit_type: array_type(quotient_pulled_bit_shape, BIT),
175            dividend_no_bits_type: array_type(dividend_no_bits_shape, BIT),
176            quotient_no_bits_type: array_type(quotient_no_bits_shape, BIT),
177        })
178    }
179}
180
181fn broadcast(node: Node, want_type: Type) -> Result<Node> {
182    let g = node.get_graph();
183    if node.get_type()? == want_type {
184        Ok(node)
185    } else {
186        g.zeros(want_type)?.add(node)
187    }
188}
189
190fn single_iteration_graph(context: &Context, types: Types) -> Result<Graph> {
191    // In the state we store (remainder, abs(divisor), -abs(divisor)), with bits dimension pulled
192    // out to the outermost level.
193    // The remainder is updated in each iteration, while the other two are just passed through all
194    // iterations.
195    let state_type = tuple_type(vec![
196        types.remainder_pulled_bits_type.clone(),
197        types.remainder_pulled_bits_type.clone(),
198    ]);
199
200    let g = context.create_graph()?;
201    // Prepare inputs.
202    let old_state = g.input(state_type)?;
203    let next_dividend_bit = g.input(types.dividend_no_bits_type.clone())?;
204    let remainder = old_state.tuple_get(0)?;
205    let minus_divisor = old_state.tuple_get(1)?;
206
207    // Get rid of the most-significant bit of the remainder, and append the next_dividend_bit.
208    let remainder = remainder.get_slice(vec![SliceElement::SubArray(None, Some(-1), None)])?;
209    // Broadcast the next_dividend_bit before concatenation, if needed.
210    let next_dividend_bit = broadcast(next_dividend_bit, types.quotient_pulled_bit_type.clone())?;
211    let remainder = g.concatenate(vec![next_dividend_bit, remainder], 0)?;
212
213    // Compute the new remainder.
214    let remainder_minus_divisor_with_carry = g.custom_op(
215        CustomOperation::new(BinaryAddTransposed { overflow_bit: true }),
216        vec![remainder.clone(), minus_divisor.clone()],
217    )?;
218    let next_quotient_bit = remainder_minus_divisor_with_carry.tuple_get(1)?;
219    let remainder_minus_divisor = remainder_minus_divisor_with_carry.tuple_get(0)?;
220    let new_remainder = g.custom_op(
221        CustomOperation::new(Mux {}),
222        vec![
223            next_quotient_bit.clone(),
224            remainder_minus_divisor,
225            remainder,
226        ],
227    )?;
228
229    let new_state = g.create_tuple(vec![new_remainder, minus_divisor])?;
230    let output = g.create_tuple(vec![
231        new_state,
232        next_quotient_bit.reshape(types.quotient_no_bits_type)?,
233    ])?;
234    output.set_as_output()?;
235    g.finalize()?;
236    Ok(g)
237}
238
239fn adjust_negative(
240    quotient: Node,
241    remainder: Node,
242    abs_divisor: Node,
243    dividend_is_negative: Node,
244    divisor_is_negative: Node,
245) -> Result<(Node, Node)> {
246    // We compute the quotient and remainder using the same logic as numpy's // and %.
247    let g = quotient.get_graph();
248    let result_is_negative = dividend_is_negative.add(divisor_is_negative.clone())?;
249    let remainder_bits = pop_last_dim(remainder.get_type()?.get_dimensions()).1;
250    let remainder_is_zero = unsqueeze(
251        g.custom_op(
252            CustomOperation::new(Equal {}),
253            vec![
254                remainder.clone(),
255                g.zeros(array_type(vec![remainder_bits], BIT))?,
256            ],
257        )?,
258        -1,
259    )?;
260    // quotient = (-quotient if remainder is 0 else -quotient-1)
261    //            if result_is_negative else quotient
262    let inverted_quotient = invert_bits(quotient.clone())?; // a.k.a (-quotient-1)
263    let negative_quotient = add_one(inverted_quotient.clone())?;
264    let quotient = g.custom_op(
265        CustomOperation::new(Mux {}),
266        vec![
267            result_is_negative.clone(),
268            g.custom_op(
269                CustomOperation::new(Mux {}),
270                vec![
271                    remainder_is_zero.clone(),
272                    negative_quotient,
273                    inverted_quotient,
274                ],
275            )?,
276            quotient,
277        ],
278    )?;
279    // positive_remainder = 0 if remainder is 0
280    //                      else abs(divisor) - remainder if result_is_negative else remainder
281    let positive_remainder = g.custom_op(
282        CustomOperation::new(Mux {}),
283        vec![
284            remainder_is_zero,
285            remainder.clone(),
286            g.custom_op(
287                CustomOperation::new(Mux {}),
288                vec![
289                    result_is_negative,
290                    g.custom_op(
291                        CustomOperation::new(BinaryAdd {
292                            overflow_bit: false,
293                        }),
294                        vec![abs_divisor, negative(remainder.clone())?],
295                    )?,
296                    remainder,
297                ],
298            )?,
299        ],
300    )?;
301    // If the divisor is negative, we need to return negative divisor, to satisfy:
302    //   dividend = divisor * quotient + remainder
303    let remainder = g.custom_op(
304        CustomOperation::new(Mux {}),
305        vec![
306            divisor_is_negative,
307            negative(positive_remainder.clone())?,
308            positive_remainder,
309        ],
310    )?;
311    Ok((quotient, remainder))
312}
313
314// Returns the array type with last dimension removed.
315fn pop_last_dim(shape: ArrayShape) -> (ArrayShape, u64) {
316    let last = shape[shape.len() - 1];
317    (shape[..shape.len() - 1].to_vec(), last)
318}
319
320fn add_one(binary_num: Node) -> Result<Node> {
321    let dims = binary_num.get_type()?.get_dimensions();
322    let bits = dims[dims.len() - 1];
323    let g = binary_num.get_graph();
324    let binary_one = g.concatenate(
325        vec![
326            g.ones(array_type(vec![1], BIT))?,
327            g.zeros(array_type(vec![bits - 1], BIT))?,
328        ],
329        0,
330    )?;
331    g.custom_op(
332        CustomOperation::new(BinaryAdd {
333            overflow_bit: false,
334        }),
335        vec![binary_num, binary_one],
336    )
337}
338
339fn invert_bits(binary_num: Node) -> Result<Node> {
340    let g = binary_num.get_graph();
341    g.custom_op(CustomOperation::new(Not {}), vec![binary_num])
342}
343
344// Returns -binary_num, i.e. two's complement of a given number.
345fn negative(binary_num: Node) -> Result<Node> {
346    add_one(invert_bits(binary_num)?)
347}
348
349// Returns 1 where signed number is negative, and 0 elsewhere.
350fn is_negative(binary_num: Node) -> Result<Node> {
351    binary_num.get_slice(vec![
352        SliceElement::Ellipsis,
353        SliceElement::SubArray(Some(-1), None, None),
354    ])
355}
356
357// Returns (is_negative(binary_num), abs(binary_num)).
358fn abs(binary_num: Node, is_signed: bool) -> Result<(Node, Node)> {
359    let g = binary_num.get_graph();
360    if is_signed {
361        let num_is_negative = is_negative(binary_num.clone())?;
362        let abs = g.custom_op(
363            CustomOperation::new(Mux {}),
364            vec![
365                num_is_negative.clone(),
366                negative(binary_num.clone())?,
367                binary_num,
368            ],
369        )?;
370        Ok((num_is_negative, abs))
371    } else {
372        Ok((g.zeros(scalar_type(BIT))?, binary_num))
373    }
374}
375
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;
    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{array_type, ScalarType, INT32, INT64, INT8, UINT8};
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::util::simple_context;
    use crate::typed_value::TypedValue;
    use crate::typed_value_operations::TypedValueArrayOperations;

    #[test]
    fn test_long_division_i32_i8() -> Result<()> {
        let (dividends, divisors, want_q, want_r) = unzip::<i32, i8>(vec![
            (55557, 5, 11111, 2),
            (-55557, 5, -11112, 3),
            (55557, -5, -11112, -3),
            (-55557, -5, 11111, -2),
            (2147483647, 64, 33554431, 63),
            (-2147483648, 64, -33554432, 0),
            (2147483647, 1, 2147483647, 0),
            (-2147483648, 1, -2147483648, 0),
            (-2147483648, -1, -2147483648, 0), // quotient should be positive, but overflows.
            (1, 5, 0, 1),
            (-1, 5, -1, 4),
            (0, 1, 0, 0),
            (0, -1, 0, 0),
            (0, 0, 0, 0), // Division by zero happens to evaluate to this value.
        ]);
        let (q, r) = long_division_helper(dividends.clone(), divisors.clone(), INT32, INT8)?;
        assert_eq!(q.value.to_flattened_array_i32(q.t)?, want_q);
        assert_eq!(r.value.to_flattened_array_i8(r.t)?, want_r);
        Ok(())
    }

    #[test]
    fn test_long_division_u8_u8() -> Result<()> {
        let (dividends, divisors, want_q, want_r) = unzip::<u8, u8>(vec![
            (255, 1, 255, 0),
            (51, 2, 25, 1),
            (85, 6, 14, 1),
            (75, 4, 18, 3),
            (161, 5, 32, 1),
            (173, 6, 28, 5),
            (78, 2, 39, 0),
            (235, 43, 5, 20),
            (244, 228, 1, 16),
            (98, 65, 1, 33),
            (35, 6, 5, 5),
            (187, 249, 0, 187),
            (209, 94, 2, 21),
            (196, 179, 1, 17),
            (112, 213, 0, 112),
            (129, 70, 1, 59),
            (223, 125, 1, 98),
            (0, 1, 0, 0),
            (0, 0, 0, 0), // Division by zero happens to evaluate to this value.
        ]);
        let (q, r) = long_division_helper(dividends.clone(), divisors.clone(), UINT8, UINT8)?;
        assert_eq!(q.value.to_flattened_array_u8(q.t)?, want_q);
        assert_eq!(r.value.to_flattened_array_u8(r.t)?, want_r);
        Ok(())
    }

    #[test]
    fn test_long_division_i64_i64() -> Result<()> {
        let (dividends, divisors, want_q, want_r) = unzip::<i64, i64>(vec![
            (9223372036854775807, 1, 9223372036854775807, 0),
            (-9223372036854775808, 1, -9223372036854775808, 0),
            (-9223372036854775808, -1, -9223372036854775808, 0), // quotient should be positive, but overflows.
            (9223372036854775807, 9223372036854775807, 1, 0),
            (-9223372036854775808, -9223372036854775808, 1, 0),
            // MIN / MAX floors to -2 with remainder MAX - 1.
            (
                -9223372036854775808,
                9223372036854775807,
                -2,
                9223372036854775806,
            ),
            (3391070024636615284, 243545908, 13923740507, 102919928),
            (3982195138714201679, -589530672, -6754856580, -156820081),
            (-8836348637758589809, 111540404, -79221056415, 77301851),
            (-2780817202823147876, -882478846, 3151143186, -461104520),
        ]);
        let (q, r) = long_division_helper(dividends.clone(), divisors.clone(), INT64, INT64)?;
        assert_eq!(q.value.to_flattened_array_i64(q.t)?, want_q);
        assert_eq!(r.value.to_flattened_array_i64(r.t)?, want_r);
        Ok(())
    }

    #[test]
    fn test_broadcast_divisor() -> Result<()> {
        let x = TypedValue::from_ndarray(array![[7, 8, 9], [-7, -8, -9]].into_dyn(), INT8)?;
        let y = TypedValue::from_ndarray(array![3].into_dyn(), INT8)?;
        let c = simple_context(|g| {
            let x = g.input(x.t.clone())?.a2b()?;
            let y = g.input(y.t.clone())?.a2b()?;
            let z = g.custom_op(
                CustomOperation::new(LongDivision { signed: true }),
                vec![x, y],
            )?;
            let q = z.tuple_get(0)?.b2a(INT8)?;
            let r = z.tuple_get(1)?.b2a(INT8)?;
            g.create_tuple(vec![q, r])
        })?;
        let c = run_instantiation_pass(c)?.context;
        let g = c.get_main_graph()?;
        let z = random_evaluate(g, vec![x.value, y.value])?.to_vector()?;
        let q_t = array_type(vec![2, 3], INT8);
        let r_t = array_type(vec![2, 3], INT8);
        assert_eq!(z[0].to_flattened_array_i8(q_t)?, [2, 2, 3, -3, -3, -3]);
        assert_eq!(z[1].to_flattened_array_i8(r_t)?, [1, 2, 0, 2, 1, 0]);
        Ok(())
    }

    #[test]
    fn test_broadcast_dividend() -> Result<()> {
        let x = TypedValue::from_ndarray(array![10].into_dyn(), INT8)?;
        let y = TypedValue::from_ndarray(array![[1, 2, 3], [-1, -2, -3]].into_dyn(), INT8)?;
        let c = simple_context(|g| {
            let x = g.input(x.t.clone())?.a2b()?;
            let y = g.input(y.t.clone())?.a2b()?;
            let z = g.custom_op(
                CustomOperation::new(LongDivision { signed: true }),
                vec![x, y],
            )?;
            let q = z.tuple_get(0)?.b2a(INT8)?;
            let r = z.tuple_get(1)?.b2a(INT8)?;
            g.create_tuple(vec![q, r])
        })?;
        let c = run_instantiation_pass(c)?.context;
        let g = c.get_main_graph()?;
        let z = random_evaluate(g, vec![x.value, y.value])?.to_vector()?;
        let q_t = array_type(vec![2, 3], INT8);
        let r_t = array_type(vec![2, 3], INT8);
        assert_eq!(z[0].to_flattened_array_i8(q_t)?, [10, 5, 3, -10, -5, -4]);
        assert_eq!(z[1].to_flattened_array_i8(r_t)?, [0, 0, 1, 0, 0, -2]);
        Ok(())
    }

    fn unzip<A, B>(rows: Vec<(i64, i64, A, B)>) -> (Vec<i64>, Vec<i64>, Vec<A>, Vec<B>) {
        let mut dividends = vec![];
        let mut divisors = vec![];
        let mut quotients = vec![];
        let mut remainders = vec![];
        for (dividend, divisor, quotient, remainder) in rows {
            dividends.push(dividend);
            divisors.push(divisor);
            quotients.push(quotient);
            remainders.push(remainder);
        }
        (dividends, divisors, quotients, remainders)
    }

    fn long_division_helper(
        dividends: Vec<i64>,
        divisors: Vec<i64>,
        dividend_st: ScalarType,
        divisor_st: ScalarType,
    ) -> Result<(TypedValue, TypedValue)> {
        let n = dividends.len();
        if n != divisors.len() {
            return Err(runtime_error!("dividends and divisors length mismatch"));
        }
        if dividend_st.is_signed() != divisor_st.is_signed() {
            return Err(runtime_error!("dividends and divisors signed mismatch"));
        }
        let dividends_t = array_type(vec![n as u64], dividend_st);
        let divisors_t = array_type(vec![n as u64], divisor_st);
        let c = simple_context(|g| {
            let input_dividends = g.input(dividends_t.clone())?;
            let input_divisors = g.input(divisors_t.clone())?;
            let binary_dividends = input_dividends.a2b()?;
            let binary_divisors = input_divisors.a2b()?;
            let result = g.custom_op(
                CustomOperation::new(LongDivision {
                    signed: dividend_st.is_signed(),
                }),
                vec![binary_dividends, binary_divisors],
            )?;
            let quotient = result.tuple_get(0)?.b2a(dividend_st)?;
            let remainder = result.tuple_get(1)?.b2a(divisor_st)?;
            g.create_tuple(vec![quotient, remainder])
        })?;
        let c = run_instantiation_pass(c)?.context;
        let g = c.get_main_graph()?;
        let result = random_evaluate(
            g,
            vec![
                Value::from_flattened_array(&dividends, dividend_st)?,
                Value::from_flattened_array(&divisors, divisor_st)?,
            ],
        )?
        .to_vector()?;
        Ok((
            TypedValue {
                value: result[0].clone(),
                t: dividends_t,
                name: None,
            },
            TypedValue {
                value: result[1].clone(),
                t: divisors_t,
                name: None,
            },
        ))
    }
}