1use crate::custom_ops::{CustomOperation, CustomOperationBody};
3use crate::data_types::{array_type, Type, BIT};
4use crate::errors::Result;
5use crate::graphs::{Context, Graph, Node, SliceElement};
6use crate::ops::utils::{expand_dims, put_in_bits};
7
8use serde::{Deserialize, Serialize};
9
10use super::utils::{pull_out_bits_pair, validate_arguments_in_broadcast_bit_ops};
11
12#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
52pub struct BinaryAdd {
53 pub overflow_bit: bool,
54}
55
56#[typetag::serde]
57impl CustomOperationBody for BinaryAdd {
58 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
59 validate_arguments_in_broadcast_bit_ops(arguments_types.clone(), &self.get_name())?;
60 let input_type0 = arguments_types[0].clone();
61 let input_type1 = arguments_types[1].clone();
62
63 let g = context.create_graph()?;
65 let (input0, input1) = pull_out_bits_pair(g.input(input_type0)?, g.input(input_type1)?)?;
66 let added = g.custom_op(
67 CustomOperation::new(BinaryAddTransposed {
68 overflow_bit: self.overflow_bit,
69 }),
70 vec![input0, input1],
71 )?;
72 let output = if self.overflow_bit {
73 g.create_tuple(vec![
74 put_in_bits(added.tuple_get(0)?)?,
75 put_in_bits(added.tuple_get(1)?)?,
76 ])?
77 } else {
78 put_in_bits(added)?
79 };
80 output.set_as_output()?;
81 g.finalize()?;
82 Ok(g)
83 }
84
85 fn get_name(&self) -> String {
86 format!("BinaryAdd(overflow_bit={})", self.overflow_bit)
87 }
88}
89
90#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
93pub(crate) struct BinaryAddTransposed {
94 pub overflow_bit: bool,
95}
96
97#[typetag::serde]
98impl CustomOperationBody for BinaryAddTransposed {
99 fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
100 if arguments_types.len() != 2 {
101 return Err(runtime_error!("Invalid number of arguments"));
102 }
103 match (&arguments_types[0], &arguments_types[1]) {
104 (Type::Array(shape0, scalar_type0), Type::Array(shape1, scalar_type1)) => {
105 if shape0[0] != shape1[0] {
106 return Err(runtime_error!(
107 "Input arrays' first dimensions are not the same"
108 ));
109 }
110 if *scalar_type0 != BIT {
111 return Err(runtime_error!("Input array [0]'s ScalarType is not BIT"));
112 }
113 if *scalar_type1 != BIT {
114 return Err(runtime_error!("Input array [1]'s ScalarType is not BIT"));
115 }
116 }
117 _ => {
118 return Err(runtime_error!(
119 "Invalid input argument type, expected Array type"
120 ));
121 }
122 }
123
124 let input_type0 = arguments_types[0].clone();
125 let input_type1 = arguments_types[1].clone();
126
127 let g = context.create_graph()?;
129 let input0 = g.input(input_type0)?;
130 let input1 = g.input(input_type1)?;
131 let xor_bits = g.add(input0.clone(), input1.clone())?;
133 let and_bits = g.multiply(input0, input1)?;
135
136 let (carries, overflow_bit) =
137 calculate_carry_bits(xor_bits.clone(), and_bits, self.overflow_bit)?;
138 let added = carries.add(xor_bits)?;
140 let output = match overflow_bit {
141 Some(overflow_bit) => g.create_tuple(vec![added, overflow_bit])?,
142 None => added,
143 };
144 output.set_as_output()?;
145 g.finalize()?;
146 Ok(g)
147 }
148
149 fn get_name(&self) -> String {
150 format!("BinaryAddTransposed(overflow_bit={})", self.overflow_bit)
151 }
152}
153
154#[derive(Clone)]
158struct CarryNode {
159 propagate: Node,
160 generate: Node,
161}
162
163impl CarryNode {
164 fn bit_len(&self) -> Result<u64> {
165 Ok(self.propagate.get_type()?.get_shape()[0])
166 }
167
168 fn shrink(&self, overflow_bit: bool) -> Result<CarryNode> {
169 let bit_len = self.bit_len()? as i64;
170
171 let next_lvl_bits = if overflow_bit {
172 bit_len / 2
173 } else {
174 (bit_len - 1) / 2
175 };
176 let use_bits = next_lvl_bits * 2;
177 let lower = self.sub_slice(0, use_bits)?;
178 let higher = self.sub_slice(1, use_bits)?;
179
180 lower.join(&higher)
181 }
182
183 fn join(&self, rhs: &Self) -> Result<Self> {
185 let propagate = self.propagate.multiply(rhs.propagate.clone())?;
186 let generate = rhs
187 .generate
188 .add(rhs.propagate.multiply(self.generate.clone())?)?;
189 Ok(Self {
190 propagate,
191 generate,
192 })
193 }
194
195 fn sub_slice(&self, start_offset: i64, bit_len: i64) -> Result<Self> {
197 let get_slice = |node: &Node| {
198 node.get_slice(vec![SliceElement::SubArray(
199 Some(start_offset),
200 Some(bit_len),
201 Some(2),
202 )])
203 };
204 Ok(Self {
205 propagate: get_slice(&self.propagate)?,
206 generate: get_slice(&self.generate)?,
207 })
208 }
209
210 fn apply(&self, prev_carry: Node) -> Result<Node> {
211 self.generate.add(self.propagate.multiply(prev_carry)?)
212 }
213}
214
215fn interleave(first: Node, second: Node) -> Result<Node> {
219 let first = expand_dims(first, &[0])?;
220 let second = expand_dims(second, &[0])?;
221 let graph = first.get_graph();
222 let joined = graph.concatenate(vec![first, second], 0)?;
223 let mut axes: Vec<_> = (0..joined.get_type()?.get_shape().len() as u64).collect();
224 axes.swap(0, 1);
225 let joined = joined.permute_axes(axes)?;
226 let mut shape = joined.get_type()?.get_shape();
227 shape[0] *= 2;
228 shape.remove(1);
229 let scalar = joined.get_type()?.get_scalar_type();
230 joined.reshape(array_type(shape, scalar))
231}
232
/// Computes the carry flowing into every bit position of a binary addition.
///
/// `propagate_bits` and `generate_bits` hold the per-bit propagate (`a XOR b`) and generate
/// (`a AND b`) signals, with the bit axis first and index 0 as the least significant bit.
/// They are combined in a parallel-prefix fashion. The up-sweep repeatedly merges neighbouring
/// groups with [`CarryNode::shrink`], so the number of levels is logarithmic in the bit length.
/// The down-sweep then pushes the carries back down one level at a time.
///
/// Returns `(carries, overflow)`. `carries` has the same shape as `propagate_bits`; element `i`
/// along the first axis is the carry into bit `i`, and bit 0 always receives 0. `overflow` is
/// `Some(carry out of the most significant bit)` when `overflow_bit` is `true`, else `None`.
///
/// # Errors
///
/// Returns a runtime error if the bit length is not a power of two (this includes 0), or if any
/// graph operation fails.
fn calculate_carry_bits(
    propagate_bits: Node,
    generate_bits: Node,
    overflow_bit: bool,
) -> Result<(Node, Option<Node>)> {
    let graph = propagate_bits.get_graph();

    // `nodes[k]` is level k of the carry tree; level 0 holds single bits.
    let mut nodes = vec![CarryNode {
        propagate: propagate_bits,
        generate: generate_bits,
    }];
    let bit_len = nodes[0].bit_len()?;
    if !bit_len.is_power_of_two() {
        return Err(runtime_error!("BinaryAdd only supports numbers with number of bits, which is a power of 2. {} bits provided.", bit_len));
    }
    // The carry into bit 0 is zero; this one-row array seeds the down-sweep.
    let mut shape = nodes[0].propagate.get_type()?.get_shape();
    shape[0] = 1;
    let mut carries = graph.zeros(array_type(shape, BIT))?;
    if !overflow_bit && bit_len == 1 {
        // A single bit with no overflow requested: the only carry is the zero carry-in.
        return Ok((carries, None));
    }
    // Up-sweep. For 2 bits without overflow, shrinking would form (2 - 1) / 2 = 0 pairs, so
    // level 0 is processed directly by the down-sweep instead.
    if overflow_bit || bit_len > 2 {
        while nodes.last().unwrap().bit_len()? > 1 {
            let last = nodes.last().unwrap();
            nodes.push(last.shrink(overflow_bit)?);
        }
    }

    // Down-sweep, from the top level towards level 0.
    let mut node_rev_iter = nodes.iter().rev();
    let overflow_bit = if overflow_bit {
        // With overflow every level pairs all groups, so the top level is one group spanning
        // all bits; applying it to the zero carry-in yields the carry out of the top bit.
        let root_node = node_rev_iter.next().unwrap();
        Some(root_node.apply(carries.clone())?)
    } else {
        None
    };
    for node in node_rev_iter {
        // Invariant: `carries` holds the carries into the even-indexed groups of `node`.
        // Applying those groups yields the carries into the odd-indexed groups, and
        // interleaving both gives the carries for every group of `node`. These are
        // exactly the even-indexed carries needed by the level below.
        let lower = node.sub_slice(0, node.bit_len()? as i64)?;
        let new_carries = lower.apply(carries.clone())?;
        carries = interleave(carries, new_carries)?;
    }

    Ok((carries, overflow_bit))
}
330
#[cfg(test)]
mod tests {
    use super::*;

    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{
        array_type, tuple_type, ScalarType, INT16, INT64, UINT16, UINT32, UINT64, UINT8,
    };
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    /// Builds a context whose main graph feeds two BIT inputs of the given shapes into
    /// `BinaryAdd` with the requested overflow flag.
    fn binary_add_context(
        lhs_shape: Vec<u64>,
        rhs_shape: Vec<u64>,
        overflow_bit: bool,
    ) -> Result<Context> {
        simple_context(|g| {
            let lhs = g.input(array_type(lhs_shape.clone(), BIT))?;
            let rhs = g.input(array_type(rhs_shape.clone(), BIT))?;
            g.custom_op(
                CustomOperation::new(BinaryAdd { overflow_bit }),
                vec![lhs, rhs],
            )
        })
    }

    /// Checks that `BinaryAdd` without an overflow bit computes `first + second` modulo
    /// `2^bits`, and that the output keeps the input bit shape.
    fn test_helper(first: u64, second: u64, st: ScalarType) -> Result<()> {
        let bits = st.size_in_bits();
        let mask = (1u128 << bits) - 1;
        let first = u128::from(first) & mask;
        let second = u128::from(second) & mask;

        let context = simple_context(|g| {
            let lhs = g.input(array_type(vec![bits], BIT))?;
            let rhs = g.input(array_type(vec![bits], BIT))?;
            let sum = g.custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false,
                }),
                vec![lhs, rhs],
            )?;
            assert_eq!(
                sum.get_type()?.get_dimensions(),
                vec![bits],
                "{first} + {second} with {bits} bits"
            );
            Ok(sum)
        })?;
        let instantiated = run_instantiation_pass(context)?;
        let inputs = vec![
            Value::from_scalar(first, st)?,
            Value::from_scalar(second, st)?,
        ];
        let actual = random_evaluate(instantiated.get_context().get_main_graph()?, inputs)?
            .to_u128(st)?;

        let expected = first.wrapping_add(second) & mask;
        assert_eq!(actual, expected, "{first} + {second} with {bits} bits");
        Ok(())
    }

    #[test]
    fn test_random_inputs() -> Result<()> {
        let samples: [u64; 8] = [0, 1, 3, 4, 10, 100500, 123456, 787788];
        for st in [BIT, UINT8, UINT16, UINT32, UINT64] {
            for x in samples {
                for y in samples {
                    test_helper(x, y, st)?;
                }
            }
        }
        Ok(())
    }

    /// Evaluates `BinaryAdd` with an overflow bit and returns `(sum, overflow)`.
    fn add_with_overflow_helper(first: u64, second: u64, st: ScalarType) -> Result<(u64, u64)> {
        let bits = st.size_in_bits();
        let instantiated = run_instantiation_pass(binary_add_context(
            vec![bits],
            vec![bits],
            true,
        )?)?;
        let outputs = random_evaluate(
            instantiated.get_context().get_main_graph()?,
            vec![
                Value::from_scalar(first, st)?,
                Value::from_scalar(second, st)?,
            ],
        )?
        .to_vector()?;
        let sum = outputs[0].to_u64(st)?;
        let overflow = outputs[1].to_u64(BIT)?;
        Ok((sum, overflow))
    }

    #[test]
    fn test_add_with_overflow_bit() -> Result<()> {
        let cases = [
            (0, 0, BIT, 0, 0),
            (0, 1, BIT, 1, 0),
            (1, 0, BIT, 1, 0),
            (1, 1, BIT, 0, 1),
            (127, 128, UINT8, 255, 0),
            (127, 129, UINT8, 0, 1),
            (128, 128, UINT8, 0, 1),
            (255, 255, UINT8, 254, 1),
            (1234, 4321, UINT16, 5555, 0),
            (12345, 54321, UINT16, 1130, 1),
            (12345, 54321, UINT32, 66666, 0),
            (2000000000, 2000000000, UINT32, 4000000000, 0),
            (2000000000, 3000000000, UINT32, 705032704, 1),
            (u64::MAX, u64::MAX, UINT64, u64::MAX - 1, 1),
        ];
        for (first, second, st, want_sum, want_overflow) in cases {
            let (got_sum, got_overflow) = add_with_overflow_helper(first, second, st)?;
            assert_eq!(got_sum, want_sum, "{first} + {second}");
            assert_eq!(got_overflow, want_overflow, "{first} + {second}");
        }
        Ok(())
    }

    #[test]
    fn test_well_behaved() -> Result<()> {
        // Broadcasting: five 16-bit numbers plus a single 16-bit number.
        let broadcast_ctx =
            run_instantiation_pass(binary_add_context(vec![5, 16], vec![1, 16], false)?)?;
        let many = Value::from_flattened_array(&vec![0, 1023, -1023, i16::MIN, i16::MAX], INT16)?;
        let one = Value::from_flattened_array(&vec![1024], INT16)?;
        let broadcast_sums = random_evaluate(
            broadcast_ctx.get_context().get_main_graph()?,
            vec![many, one],
        )?
        .to_flattened_array_u64(array_type(vec![5], INT16))?;
        assert_eq!(
            broadcast_sums,
            vec![
                1024,
                2047,
                1,
                (i16::MIN + 1024) as u64,
                (i16::MAX.wrapping_add(1024)) as u64,
            ]
        );

        // Signed 64-bit operands: 123456790 + (-123456789) == 1.
        let wide_ctx = run_instantiation_pass(binary_add_context(vec![64], vec![64], false)?)?;
        let wide_sum = random_evaluate(
            wide_ctx.get_context().get_main_graph()?,
            vec![
                Value::from_scalar(123456790, INT64)?,
                Value::from_scalar(-123456789, INT64)?,
            ],
        )?
        .to_u64(INT64)?;
        assert_eq!(wide_sum, 1);
        Ok(())
    }

    #[test]
    fn test_malformed() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        let bits64 = g.input(array_type(vec![64], BIT))?;
        let int16 = g.input(array_type(vec![64], INT16))?;
        let empty_tuple = g.input(tuple_type(vec![]))?;
        let bits32 = g.input(array_type(vec![32], BIT))?;
        let bits31 = g.input(array_type(vec![31], BIT))?;

        // Wrong arity, non-BIT scalars, non-array types, mismatched and non-power-of-two widths.
        let bad_argument_lists = vec![
            vec![bits64.clone()],
            vec![bits64.clone(), int16.clone()],
            vec![int16, bits64.clone()],
            vec![empty_tuple],
            vec![bits64, bits32],
            vec![bits31.clone(), bits31],
        ];
        for arguments in bad_argument_lists {
            let outcome = g.custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false,
                }),
                arguments,
            );
            assert!(outcome.is_err());
        }
        Ok(())
    }
}