1use crate::broadcast::broadcast_shapes;
3use crate::custom_ops::{CustomOperation, CustomOperationBody, Not};
4use crate::data_types::{array_type, scalar_type, tuple_type, ArrayShape, Type, BIT};
5use crate::errors::Result;
6use crate::graphs::{Context, Graph, Node, SliceElement};
7use crate::ops::multiplexer::Mux;
8use crate::ops::utils::unsqueeze;
9
10use serde::{Deserialize, Serialize};
11
12use super::adder::{BinaryAdd, BinaryAddTransposed};
13use super::comparisons::Equal;
14use super::utils::{prepend_dims, pull_out_bits_pair, put_in_bits};
15
16#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
56pub struct LongDivision {
57 pub signed: bool,
58}
59
60#[typetag::serde]
61impl CustomOperationBody for LongDivision {
62 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
63 if arguments_types.len() != 2 {
64 return Err(runtime_error!(
65 "Invalid number of arguments for LongDivision, given {}, expected 2",
66 arguments_types.len()
67 ));
68 }
69
70 let dividend_type = arguments_types[0].clone();
71 let divisor_type = arguments_types[1].clone();
72 if dividend_type.get_scalar_type() != BIT {
73 return Err(runtime_error!(
74 "Invalid scalar types for LongDivision: dividend scalar type {}, expected BIT",
75 dividend_type.get_scalar_type()
76 ));
77 }
78 if divisor_type.get_scalar_type() != BIT {
79 return Err(runtime_error!(
80 "Invalid scalar types for LongDivision: divisor scalar type {}, expected BIT",
81 dividend_type.get_scalar_type()
82 ));
83 }
84 if !divisor_type.is_array() {
85 return Err(runtime_error!("Divisor in LongDivision must be an array"));
86 }
87 if !dividend_type.is_array() {
88 return Err(runtime_error!("Dividend in LongDivision must be an array"));
89 }
90 let types = Types::new(dividend_type, divisor_type)?;
91 let g_iterate = single_iteration_graph(&context, types.clone())?;
92 let g = context.create_graph()?;
93 let dividend = g.input(types.divident_type.clone())?;
94 let divisor = g.input(types.divisor_type.clone())?;
95
96 let (dividend_is_negative, abs_dividend) = abs(dividend, self.signed)?;
98 let (divisor_is_negative, abs_divisor) = abs(divisor, self.signed)?;
99 let negative_abs_divisor = negative(abs_divisor.clone())?;
100 let (dividend_pulled_bits, negative_abs_divisor_pulled_bits) =
104 pull_out_bits_pair(abs_dividend, negative_abs_divisor)?;
105
106 let dividend_pulled_bits =
107 dividend_pulled_bits.get_slice(vec![SliceElement::SubArray(None, None, Some(-1))])?;
108
109 let state = g.create_tuple(vec![
111 g.zeros(types.remainder_pulled_bits_type.clone())?,
112 broadcast(
113 negative_abs_divisor_pulled_bits,
114 types.remainder_pulled_bits_type,
115 )?,
116 ])?;
117 let result = g.iterate(g_iterate, state, dividend_pulled_bits.array_to_vector()?)?;
118 let remainder = put_in_bits(result.tuple_get(0)?.tuple_get(0)?)?;
119 let quotient_pulled_bits = result.tuple_get(1)?.vector_to_array()?;
120
121 let quotient_pulled_bits =
123 quotient_pulled_bits.get_slice(vec![SliceElement::SubArray(None, None, Some(-1))])?;
124 let quotient = put_in_bits(quotient_pulled_bits)?;
125
126 let (quotient, remainder) = if self.signed {
127 adjust_negative(
128 quotient,
129 remainder,
130 abs_divisor,
131 dividend_is_negative,
132 divisor_is_negative,
133 )?
134 } else {
135 (quotient, remainder)
136 };
137 let output = g.create_tuple(vec![quotient, remainder])?;
138 output.set_as_output()?;
139 g.finalize()?;
140 Ok(g)
141 }
142
143 fn get_name(&self) -> String {
144 format!("LongDivision(signed={})", self.signed)
145 }
146}
147
148#[derive(Debug, Clone)]
149struct Types {
150 divident_type: Type,
151 divisor_type: Type,
152 remainder_pulled_bits_type: Type,
153 quotient_pulled_bit_type: Type,
154 dividend_no_bits_type: Type,
155 quotient_no_bits_type: Type,
156}
157
158impl Types {
159 fn new(divident_type: Type, divisor_type: Type) -> Result<Self> {
160 let (dividend_no_bits_shape, _dividend_bits) = pop_last_dim(divident_type.get_dimensions());
161 let (divisor_no_bits_shape, divisor_bits) = pop_last_dim(divisor_type.get_dimensions());
162 let output_no_bits_shape =
163 broadcast_shapes(dividend_no_bits_shape.clone(), divisor_no_bits_shape)?;
164 let dividend_no_bits_shape =
165 prepend_dims(dividend_no_bits_shape, output_no_bits_shape.len())?;
166 let remainder_pulled_bits_shape =
167 [vec![divisor_bits], output_no_bits_shape.clone()].concat();
168 let quotient_pulled_bit_shape = [vec![1], output_no_bits_shape.clone()].concat();
169 let quotient_no_bits_shape = output_no_bits_shape;
170 Ok(Self {
171 divident_type,
172 divisor_type,
173 remainder_pulled_bits_type: array_type(remainder_pulled_bits_shape, BIT),
174 quotient_pulled_bit_type: array_type(quotient_pulled_bit_shape, BIT),
175 dividend_no_bits_type: array_type(dividend_no_bits_shape, BIT),
176 quotient_no_bits_type: array_type(quotient_no_bits_shape, BIT),
177 })
178 }
179}
180
181fn broadcast(node: Node, want_type: Type) -> Result<Node> {
182 let g = node.get_graph();
183 if node.get_type()? == want_type {
184 Ok(node)
185 } else {
186 g.zeros(want_type)?.add(node)
187 }
188}
189
190fn single_iteration_graph(context: &Context, types: Types) -> Result<Graph> {
191 let state_type = tuple_type(vec![
196 types.remainder_pulled_bits_type.clone(),
197 types.remainder_pulled_bits_type.clone(),
198 ]);
199
200 let g = context.create_graph()?;
201 let old_state = g.input(state_type)?;
203 let next_dividend_bit = g.input(types.dividend_no_bits_type.clone())?;
204 let remainder = old_state.tuple_get(0)?;
205 let minus_divisor = old_state.tuple_get(1)?;
206
207 let remainder = remainder.get_slice(vec![SliceElement::SubArray(None, Some(-1), None)])?;
209 let next_dividend_bit = broadcast(next_dividend_bit, types.quotient_pulled_bit_type.clone())?;
211 let remainder = g.concatenate(vec![next_dividend_bit, remainder], 0)?;
212
213 let remainder_minus_divisor_with_carry = g.custom_op(
215 CustomOperation::new(BinaryAddTransposed { overflow_bit: true }),
216 vec![remainder.clone(), minus_divisor.clone()],
217 )?;
218 let next_quotient_bit = remainder_minus_divisor_with_carry.tuple_get(1)?;
219 let remainder_minus_divisor = remainder_minus_divisor_with_carry.tuple_get(0)?;
220 let new_remainder = g.custom_op(
221 CustomOperation::new(Mux {}),
222 vec![
223 next_quotient_bit.clone(),
224 remainder_minus_divisor,
225 remainder,
226 ],
227 )?;
228
229 let new_state = g.create_tuple(vec![new_remainder, minus_divisor])?;
230 let output = g.create_tuple(vec![
231 new_state,
232 next_quotient_bit.reshape(types.quotient_no_bits_type)?,
233 ])?;
234 output.set_as_output()?;
235 g.finalize()?;
236 Ok(g)
237}
238
239fn adjust_negative(
240 quotient: Node,
241 remainder: Node,
242 abs_divisor: Node,
243 dividend_is_negative: Node,
244 divisor_is_negative: Node,
245) -> Result<(Node, Node)> {
246 let g = quotient.get_graph();
248 let result_is_negative = dividend_is_negative.add(divisor_is_negative.clone())?;
249 let remainder_bits = pop_last_dim(remainder.get_type()?.get_dimensions()).1;
250 let remainder_is_zero = unsqueeze(
251 g.custom_op(
252 CustomOperation::new(Equal {}),
253 vec![
254 remainder.clone(),
255 g.zeros(array_type(vec![remainder_bits], BIT))?,
256 ],
257 )?,
258 -1,
259 )?;
260 let inverted_quotient = invert_bits(quotient.clone())?; let negative_quotient = add_one(inverted_quotient.clone())?;
264 let quotient = g.custom_op(
265 CustomOperation::new(Mux {}),
266 vec![
267 result_is_negative.clone(),
268 g.custom_op(
269 CustomOperation::new(Mux {}),
270 vec![
271 remainder_is_zero.clone(),
272 negative_quotient,
273 inverted_quotient,
274 ],
275 )?,
276 quotient,
277 ],
278 )?;
279 let positive_remainder = g.custom_op(
282 CustomOperation::new(Mux {}),
283 vec![
284 remainder_is_zero,
285 remainder.clone(),
286 g.custom_op(
287 CustomOperation::new(Mux {}),
288 vec![
289 result_is_negative,
290 g.custom_op(
291 CustomOperation::new(BinaryAdd {
292 overflow_bit: false,
293 }),
294 vec![abs_divisor, negative(remainder.clone())?],
295 )?,
296 remainder,
297 ],
298 )?,
299 ],
300 )?;
301 let remainder = g.custom_op(
304 CustomOperation::new(Mux {}),
305 vec![
306 divisor_is_negative,
307 negative(positive_remainder.clone())?,
308 positive_remainder,
309 ],
310 )?;
311 Ok((quotient, remainder))
312}
313
314fn pop_last_dim(shape: ArrayShape) -> (ArrayShape, u64) {
316 let last = shape[shape.len() - 1];
317 (shape[..shape.len() - 1].to_vec(), last)
318}
319
320fn add_one(binary_num: Node) -> Result<Node> {
321 let dims = binary_num.get_type()?.get_dimensions();
322 let bits = dims[dims.len() - 1];
323 let g = binary_num.get_graph();
324 let binary_one = g.concatenate(
325 vec![
326 g.ones(array_type(vec![1], BIT))?,
327 g.zeros(array_type(vec![bits - 1], BIT))?,
328 ],
329 0,
330 )?;
331 g.custom_op(
332 CustomOperation::new(BinaryAdd {
333 overflow_bit: false,
334 }),
335 vec![binary_num, binary_one],
336 )
337}
338
339fn invert_bits(binary_num: Node) -> Result<Node> {
340 let g = binary_num.get_graph();
341 g.custom_op(CustomOperation::new(Not {}), vec![binary_num])
342}
343
344fn negative(binary_num: Node) -> Result<Node> {
346 add_one(invert_bits(binary_num)?)
347}
348
349fn is_negative(binary_num: Node) -> Result<Node> {
351 binary_num.get_slice(vec![
352 SliceElement::Ellipsis,
353 SliceElement::SubArray(Some(-1), None, None),
354 ])
355}
356
357fn abs(binary_num: Node, is_signed: bool) -> Result<(Node, Node)> {
359 let g = binary_num.get_graph();
360 if is_signed {
361 let num_is_negative = is_negative(binary_num.clone())?;
362 let abs = g.custom_op(
363 CustomOperation::new(Mux {}),
364 vec![
365 num_is_negative.clone(),
366 negative(binary_num.clone())?,
367 binary_num,
368 ],
369 )?;
370 Ok((num_is_negative, abs))
371 } else {
372 Ok((g.zeros(scalar_type(BIT))?, binary_num))
373 }
374}
375
376#[cfg(test)]
377mod tests {
378 use ndarray::array;
379
380 use super::*;
381 use crate::custom_ops::{run_instantiation_pass, CustomOperation};
382 use crate::data_types::{array_type, ScalarType, INT32, INT64, INT8, UINT8};
383 use crate::data_values::Value;
384 use crate::evaluators::random_evaluate;
385 use crate::graphs::util::simple_context;
386 use crate::typed_value::TypedValue;
387 use crate::typed_value_operations::TypedValueArrayOperations;
388
389 #[test]
390 fn test_long_division_i32_i8() -> Result<()> {
391 let (dividends, divisors, want_q, want_r) = unzip::<i32, i8>(vec![
392 (55557, 5, 11111, 2),
393 (-55557, 5, -11112, 3),
394 (55557, -5, -11112, -3),
395 (-55557, -5, 11111, -2),
396 (2147483647, 64, 33554431, 63),
397 (-2147483648, 64, -33554432, 0),
398 (2147483647, 1, 2147483647, 0),
399 (-2147483648, 1, -2147483648, 0),
400 (-2147483648, -1, -2147483648, 0), (1, 5, 0, 1),
402 (-1, 5, -1, 4),
403 (0, 1, 0, 0),
404 (0, -1, 0, 0),
405 (0, 0, 0, 0), ]);
407 let (q, r) = long_division_helper(dividends.clone(), divisors.clone(), INT32, INT8)?;
408 assert_eq!(q.value.to_flattened_array_i32(q.t)?, want_q);
409 assert_eq!(r.value.to_flattened_array_i8(r.t)?, want_r);
410 Ok(())
411 }
412
413 #[test]
414 fn test_long_division_u8_u8() -> Result<()> {
415 let (dividends, divisors, want_q, want_r) = unzip::<u8, u8>(vec![
416 (255, 1, 255, 0),
417 (51, 2, 25, 1),
418 (85, 6, 14, 1),
419 (75, 4, 18, 3),
420 (161, 5, 32, 1),
421 (173, 6, 28, 5),
422 (78, 2, 39, 0),
423 (235, 43, 5, 20),
424 (244, 228, 1, 16),
425 (98, 65, 1, 33),
426 (35, 6, 5, 5),
427 (187, 249, 0, 187),
428 (209, 94, 2, 21),
429 (196, 179, 1, 17),
430 (112, 213, 0, 112),
431 (129, 70, 1, 59),
432 (223, 125, 1, 98),
433 (0, 1, 0, 0),
434 (0, 0, 0, 0), ]);
436 let (q, r) = long_division_helper(dividends.clone(), divisors.clone(), UINT8, UINT8)?;
437 assert_eq!(q.value.to_flattened_array_u8(q.t)?, want_q);
438 assert_eq!(r.value.to_flattened_array_u8(r.t)?, want_r);
439 Ok(())
440 }
441
442 #[test]
443 fn test_long_division_i64_i64() -> Result<()> {
444 let (dividends, divisors, want_q, want_r) = unzip::<i64, i64>(vec![
445 (9223372036854775807, 1, 9223372036854775807, 0),
446 (-9223372036854775808, 1, -9223372036854775808, 0),
447 (-9223372036854775808, -1, -9223372036854775808, 0), (9223372036854775807, 9223372036854775807, 1, 0),
449 (-9223372036854775808, -9223372036854775808, 1, 0),
450 (-9223372036854775808, -9223372036854775808, 1, 0),
451 (3391070024636615284, 243545908, 13923740507, 102919928),
452 (3982195138714201679, -589530672, -6754856580, -156820081),
453 (-8836348637758589809, 111540404, -79221056415, 77301851),
454 (-2780817202823147876, -882478846, 3151143186, -461104520),
455 ]);
456 let (q, r) = long_division_helper(dividends.clone(), divisors.clone(), INT64, INT64)?;
457 assert_eq!(q.value.to_flattened_array_i64(q.t)?, want_q);
458 assert_eq!(r.value.to_flattened_array_i64(r.t)?, want_r);
459 Ok(())
460 }
461
462 #[test]
463 fn test_broadcast_divisor() -> Result<()> {
464 let x = TypedValue::from_ndarray(array![[7, 8, 9], [-7, -8, -9]].into_dyn(), INT8)?;
465 let y = TypedValue::from_ndarray(array![3].into_dyn(), INT8)?;
466 let c = simple_context(|g| {
467 let x = g.input(x.t.clone())?.a2b()?;
468 let y = g.input(y.t.clone())?.a2b()?;
469 let z = g.custom_op(
470 CustomOperation::new(LongDivision { signed: true }),
471 vec![x, y],
472 )?;
473 let q = z.tuple_get(0)?.b2a(INT8)?;
474 let r = z.tuple_get(1)?.b2a(INT8)?;
475 g.create_tuple(vec![q, r])
476 })?;
477 let c = run_instantiation_pass(c)?.context;
478 let g = c.get_main_graph()?;
479 let z = random_evaluate(g, vec![x.value, y.value])?.to_vector()?;
480 let r_t = array_type(vec![2, 3], INT8);
481 let q_t = array_type(vec![2, 3], INT8);
482 assert_eq!(z[0].to_flattened_array_i8(r_t)?, [2, 2, 3, -3, -3, -3]);
483 assert_eq!(z[1].to_flattened_array_i8(q_t)?, [1, 2, 0, 2, 1, 0]);
484 Ok(())
485 }
486
487 #[test]
488 fn test_broadcast_dividend() -> Result<()> {
489 let x = TypedValue::from_ndarray(array![10].into_dyn(), INT8)?;
490 let y = TypedValue::from_ndarray(array![[1, 2, 3], [-1, -2, -3]].into_dyn(), INT8)?;
491 let c = simple_context(|g| {
492 let x = g.input(x.t.clone())?.a2b()?;
493 let y = g.input(y.t.clone())?.a2b()?;
494 let z = g.custom_op(
495 CustomOperation::new(LongDivision { signed: true }),
496 vec![x, y],
497 )?;
498 let q = z.tuple_get(0)?.b2a(INT8)?;
499 let r = z.tuple_get(1)?.b2a(INT8)?;
500 g.create_tuple(vec![q, r])
501 })?;
502 let c = run_instantiation_pass(c)?.context;
503 let g = c.get_main_graph()?;
504 let z = random_evaluate(g, vec![x.value, y.value])?.to_vector()?;
505 let r_t = array_type(vec![2, 3], INT8);
506 let q_t = array_type(vec![2, 3], INT8);
507 assert_eq!(z[0].to_flattened_array_i8(r_t)?, [10, 5, 3, -10, -5, -4]);
508 assert_eq!(z[1].to_flattened_array_i8(q_t)?, [0, 0, 1, 0, 0, -2]);
509 Ok(())
510 }
511
512 fn unzip<A, B>(rows: Vec<(i64, i64, A, B)>) -> (Vec<i64>, Vec<i64>, Vec<A>, Vec<B>) {
513 let mut dividends = vec![];
514 let mut divisors = vec![];
515 let mut quotients = vec![];
516 let mut remainders = vec![];
517 for (dividend, divisor, quotient, remainder) in rows {
518 dividends.push(dividend);
519 divisors.push(divisor);
520 quotients.push(quotient);
521 remainders.push(remainder);
522 }
523 (dividends, divisors, quotients, remainders)
524 }
525
526 fn long_division_helper(
527 dividends: Vec<i64>,
528 divisors: Vec<i64>,
529 dividend_st: ScalarType,
530 divisor_st: ScalarType,
531 ) -> Result<(TypedValue, TypedValue)> {
532 let n = dividends.len();
533 if n != divisors.len() {
534 return Err(runtime_error!("dividends and divisors length mismatch"));
535 }
536 if dividend_st.is_signed() != divisor_st.is_signed() {
537 return Err(runtime_error!("dividends and divisors signed mismatch"));
538 }
539 let dividends_t = array_type(vec![n as u64], dividend_st);
540 let divisors_t = array_type(vec![n as u64], divisor_st);
541 let c = simple_context(|g| {
542 let input_dividends = g.input(dividends_t.clone())?;
543 let input_divisors = g.input(divisors_t.clone())?;
544 let binary_dividends = input_dividends.a2b()?;
545 let binary_divisors = input_divisors.a2b()?;
546 let result = g.custom_op(
547 CustomOperation::new(LongDivision {
548 signed: dividend_st.is_signed(),
549 }),
550 vec![binary_dividends, binary_divisors],
551 )?;
552 let quotient = result.tuple_get(0)?.b2a(dividend_st)?;
553 let remainder = result.tuple_get(1)?.b2a(divisor_st)?;
554 g.create_tuple(vec![quotient, remainder])
555 })?;
556 let c = run_instantiation_pass(c)?.context;
557 let g = c.get_main_graph()?;
558 let result = random_evaluate(
559 g,
560 vec![
561 Value::from_flattened_array(÷nds, dividend_st)?,
562 Value::from_flattened_array(&divisors, divisor_st)?,
563 ],
564 )?
565 .to_vector()?;
566 Ok((
567 TypedValue {
568 value: result[0].clone(),
569 t: dividends_t,
570 name: None,
571 },
572 TypedValue {
573 value: result[1].clone(),
574 t: divisors_t,
575 name: None,
576 },
577 ))
578 }
579}