use crate::custom_ops::{CustomOperation, CustomOperationBody};
use crate::data_types::{array_type, Type, BIT};
use crate::errors::Result;
use crate::graphs::{Context, Graph, Node, SliceElement};
use crate::ops::utils::{expand_dims, put_in_bits};
use serde::{Deserialize, Serialize};
use super::utils::{pull_out_bits_pair, validate_arguments_in_broadcast_bit_ops};
/// Custom operation adding two integers given as arrays of `BIT`s, with the bits of each
/// number on the last axis (arguments are checked by `validate_arguments_in_broadcast_bit_ops`).
///
/// The sum wraps around on overflow. When `overflow_bit` is `true`, the output is the tuple
/// `(sum, overflow)`, where `overflow` is the carry out of the most significant bit.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub struct BinaryAdd {
    /// Whether to additionally output the carry out of the most significant bit.
    pub overflow_bit: bool,
}

#[typetag::serde]
impl CustomOperationBody for BinaryAdd {
    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
        validate_arguments_in_broadcast_bit_ops(arguments_types.clone(), &self.get_name())?;
        let (lhs_type, rhs_type) = (arguments_types[0].clone(), arguments_types[1].clone());

        let graph = context.create_graph()?;
        let lhs = graph.input(lhs_type)?;
        let rhs = graph.input(rhs_type)?;
        // `BinaryAddTransposed` works on bits stored along axis 0, so move the bit axis there.
        let (lhs_bits, rhs_bits) = pull_out_bits_pair(lhs, rhs)?;

        let transposed_add = CustomOperation::new(BinaryAddTransposed {
            overflow_bit: self.overflow_bit,
        });
        let sum = graph.custom_op(transposed_add, vec![lhs_bits, rhs_bits])?;

        // Restore the original bit layout on every output component.
        let result = if !self.overflow_bit {
            put_in_bits(sum)?
        } else {
            let sum_bits = put_in_bits(sum.tuple_get(0)?)?;
            let overflow = put_in_bits(sum.tuple_get(1)?)?;
            graph.create_tuple(vec![sum_bits, overflow])?
        };

        result.set_as_output()?;
        graph.finalize()?;
        Ok(graph)
    }

    fn get_name(&self) -> String {
        format!("BinaryAdd(overflow_bit={})", self.overflow_bit)
    }
}
/// Adds two binary integers whose bits lie along axis 0 (the "transposed" layout).
///
/// Both arguments must be `BIT` arrays with the same length of axis 0. The remaining axes
/// are combined by the graph's element-wise `add`/`multiply`. With `overflow_bit` set, the
/// output is the tuple `(sum, carry_out)`.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub(crate) struct BinaryAddTransposed {
    /// Whether to additionally output the carry out of the most significant bit.
    pub overflow_bit: bool,
}

#[typetag::serde]
impl CustomOperationBody for BinaryAddTransposed {
    fn instantiate(&self, context: Context, arguments_types: Vec<Type>) -> Result<Graph> {
        if arguments_types.len() != 2 {
            return Err(runtime_error!("Invalid number of arguments"));
        }
        let lhs_type = &arguments_types[0];
        let rhs_type = &arguments_types[1];

        let (lhs_shape, lhs_scalar, rhs_shape, rhs_scalar) = match (lhs_type, rhs_type) {
            (Type::Array(s0, t0), Type::Array(s1, t1)) => (s0, t0, s1, t1),
            _ => {
                return Err(runtime_error!(
                    "Invalid input argument type, expected Array type"
                ))
            }
        };
        if lhs_shape[0] != rhs_shape[0] {
            return Err(runtime_error!(
                "Input arrays' first dimensions are not the same"
            ));
        }
        if *lhs_scalar != BIT {
            return Err(runtime_error!("Input array [0]'s ScalarType is not BIT"));
        }
        if *rhs_scalar != BIT {
            return Err(runtime_error!("Input array [1]'s ScalarType is not BIT"));
        }

        let graph = context.create_graph()?;
        let lhs = graph.input(lhs_type.clone())?;
        let rhs = graph.input(rhs_type.clone())?;
        // Per-position signals: addition of bits is XOR (propagate), multiplication is AND (generate).
        let propagate = graph.add(lhs.clone(), rhs.clone())?;
        let generate = graph.multiply(lhs, rhs)?;
        let (carries, carry_out) =
            calculate_carry_bits(propagate.clone(), generate, self.overflow_bit)?;
        // Each sum bit is its propagate bit XORed with the carry entering that position.
        let sum = carries.add(propagate)?;

        let result = if let Some(overflow) = carry_out {
            graph.create_tuple(vec![sum, overflow])?
        } else {
            sum
        };
        result.set_as_output()?;
        graph.finalize()?;
        Ok(graph)
    }

    fn get_name(&self) -> String {
        format!("BinaryAddTransposed(overflow_bit={})", self.overflow_bit)
    }
}
/// Carry-lookahead signals for a sequence of bit groups, stored along axis 0 of both nodes.
///
/// Initially each group is a single bit position. `shrink` merges adjacent groups to build
/// the next level.
#[derive(Clone)]
struct CarryNode {
    /// Whether a carry entering the group is passed through to its output.
    propagate: Node,
    /// Whether the group produces a carry regardless of the incoming carry.
    generate: Node,
}

impl CarryNode {
    /// Number of groups, i.e. the length of axis 0.
    fn bit_len(&self) -> Result<u64> {
        let shape = self.propagate.get_type()?.get_shape();
        Ok(shape[0])
    }

    /// Builds the next level by merging groups pairwise as `(0, 1), (2, 3), ...`.
    ///
    /// With `overflow_bit`, `len / 2` pairs are formed. Otherwise only `(len - 1) / 2` pairs
    /// are formed, so an even-length level drops its last pair (whose carry-out would only
    /// feed the overflow).
    fn shrink(&self, overflow_bit: bool) -> Result<CarryNode> {
        let len = self.bit_len()? as i64;
        let pairs = if overflow_bit { len / 2 } else { (len - 1) / 2 };
        let end = 2 * pairs;
        let evens = self.sub_slice(0, end)?;
        let odds = self.sub_slice(1, end)?;
        evens.join(&odds)
    }

    /// Merges `self` (the lower group) with `rhs` (the group directly above it).
    fn join(&self, rhs: &Self) -> Result<Self> {
        let propagate = self.propagate.multiply(rhs.propagate.clone())?;
        let carried_through = rhs.propagate.multiply(self.generate.clone())?;
        let generate = rhs.generate.add(carried_through)?;
        Ok(CarryNode {
            propagate,
            generate,
        })
    }

    /// Keeps every second group along axis 0, from index `start` up to (but excluding) `end`.
    fn sub_slice(&self, start: i64, end: i64) -> Result<Self> {
        let every_other = || vec![SliceElement::SubArray(Some(start), Some(end), Some(2))];
        Ok(CarryNode {
            propagate: self.propagate.get_slice(every_other())?,
            generate: self.generate.get_slice(every_other())?,
        })
    }

    /// Carry leaving each group given the carry entering it: `generate XOR (propagate AND carry)`.
    fn apply(&self, prev_carry: Node) -> Result<Node> {
        let passed_through = self.propagate.multiply(prev_carry)?;
        self.generate.add(passed_through)
    }
}
/// Interleaves two nodes of equal shape along axis 0, producing `[f0, s0, f1, s1, ...]`.
fn interleave(first: Node, second: Node) -> Result<Node> {
    let first_row = expand_dims(first, &[0])?;
    let second_row = expand_dims(second, &[0])?;
    let graph = first_row.get_graph();
    // Shape [2, n, ...]: row 0 is `first`, row 1 is `second`.
    let stacked = graph.concatenate(vec![first_row, second_row], 0)?;

    // Swap the two leading axes to get [n, 2, ...]; row-major order now alternates inputs.
    let rank = stacked.get_type()?.get_shape().len() as u64;
    let mut permutation = vec![1, 0];
    permutation.extend(2..rank);
    let paired = stacked.permute_axes(permutation)?;

    // Fold the pair axis into axis 0: [n, 2, ...] -> [2n, ...].
    let scalar = paired.get_type()?.get_scalar_type();
    let mut dims = paired.get_type()?.get_shape();
    let pair_width = dims.remove(1);
    dims[0] *= pair_width;
    paired.reshape(array_type(dims, scalar))
}
/// Computes the carry entering every bit position from per-bit propagate/generate signals.
///
/// Bits lie along axis 0, with index 0 least significant (its incoming carry is zero). The
/// up-sweep repeatedly merges adjacent groups with `CarryNode::shrink`. The down-sweep then
/// pushes known carries back through the levels and doubles their count each time via
/// `interleave`.
///
/// Returns one carry per bit position and, when `overflow_bit` is set, the carry out of the
/// most significant bit.
///
/// # Errors
/// Fails when the number of bits is not a power of two.
fn calculate_carry_bits(
    propagate_bits: Node,
    generate_bits: Node,
    overflow_bit: bool,
) -> Result<(Node, Option<Node>)> {
    let graph = propagate_bits.get_graph();
    let leaves = CarryNode {
        propagate: propagate_bits,
        generate: generate_bits,
    };
    let bit_len = leaves.bit_len()?;
    if !bit_len.is_power_of_two() {
        return Err(runtime_error!("BinaryAdd only supports numbers with number of bits, which is a power of 2. {} bits provided.", bit_len));
    }

    // The carry into bit 0 is zero; it seeds the down-sweep.
    let mut seed_shape = leaves.propagate.get_type()?.get_shape();
    seed_shape[0] = 1;
    let initial_carry = graph.zeros(array_type(seed_shape, BIT))?;
    if bit_len == 1 && !overflow_bit {
        return Ok((initial_carry, None));
    }

    // Up-sweep. With 2 bits and no overflow bit, the leaf level alone is enough.
    let mut levels = vec![leaves];
    if overflow_bit || bit_len > 2 {
        loop {
            let top = levels.last().expect("levels always holds the leaf level");
            if top.bit_len()? <= 1 {
                break;
            }
            let next = top.shrink(overflow_bit)?;
            levels.push(next);
        }
    }

    let mut top_down = levels.iter().rev();
    // The root group spans every bit, so its carry-out is the overflow bit.
    let carry_out = if overflow_bit {
        let root = top_down.next().expect("levels always holds the leaf level");
        Some(root.apply(initial_carry.clone())?)
    } else {
        None
    };

    // Down-sweep: the known carries enter the even-indexed groups of each lower level;
    // applying those groups yields the carries for the positions in between.
    let carries = top_down.try_fold(initial_carry, |known, level| -> Result<Node> {
        let entry_groups = level.sub_slice(0, level.bit_len()? as i64)?;
        let derived = entry_groups.apply(known.clone())?;
        interleave(known, derived)
    })?;
    Ok((carries, carry_out))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::custom_ops::{run_instantiation_pass, CustomOperation};
    use crate::data_types::{
        array_type, tuple_type, ScalarType, INT16, INT64, UINT16, UINT32, UINT64, UINT8,
    };
    use crate::data_values::Value;
    use crate::evaluators::random_evaluate;
    use crate::graphs::create_context;
    use crate::graphs::util::simple_context;

    /// Adds `first` and `second` (truncated to the width of `st`) with `BinaryAdd` (no
    /// overflow bit) and compares the result against wrapping integer addition.
    fn test_helper(first: u64, second: u64, st: ScalarType) -> Result<()> {
        let bits = st.size_in_bits();
        let mask = (1u128 << bits) - 1;
        let first = (first as u128) & mask;
        let second = (second as u128) & mask;
        let c = simple_context(|g| {
            let i1 = g.input(array_type(vec![bits], BIT))?;
            let i2 = g.input(array_type(vec![bits], BIT))?;
            let o = g.custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false,
                }),
                vec![i1, i2],
            )?;
            assert_eq!(
                o.get_type()?.get_dimensions(),
                vec![bits],
                "{first} + {second} with {bits} bits"
            );
            Ok(o)
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let input0 = Value::from_scalar(first, st)?;
        let input1 = Value::from_scalar(second, st)?;
        let result_v = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![input0, input1],
        )?
        .to_u128(st)?;
        let expected_result = first.wrapping_add(second) & mask;
        assert_eq!(
            result_v, expected_result,
            "{first} + {second} with {bits} bits"
        );
        Ok(())
    }

    #[test]
    fn test_random_inputs() -> Result<()> {
        let random_numbers = [0, 1, 3, 4, 10, 100500, 123456, 787788];
        for st in [BIT, UINT8, UINT16, UINT32, UINT64] {
            for &x in random_numbers.iter() {
                for &y in random_numbers.iter() {
                    test_helper(x, y, st)?;
                }
            }
        }
        Ok(())
    }

    /// Evaluates `BinaryAdd { overflow_bit: true }` on two scalars of type `st` and
    /// returns `(sum, overflow_bit)`.
    fn add_with_overflow_helper(first: u64, second: u64, st: ScalarType) -> Result<(u64, u64)> {
        let bits = st.size_in_bits();
        let c = simple_context(|g| {
            let i1 = g.input(array_type(vec![bits], BIT))?;
            let i2 = g.input(array_type(vec![bits], BIT))?;
            g.custom_op(
                CustomOperation::new(BinaryAdd { overflow_bit: true }),
                vec![i1, i2],
            )
        })?;
        let mapped_c = run_instantiation_pass(c)?;
        let input0 = Value::from_scalar(first, st)?;
        let input1 = Value::from_scalar(second, st)?;
        let results = random_evaluate(
            mapped_c.get_context().get_main_graph()?,
            vec![input0, input1],
        )?
        .to_vector()?;
        Ok((results[0].to_u64(st)?, results[1].to_u64(BIT)?))
    }

    #[test]
    fn test_add_with_overflow_bit() -> Result<()> {
        for (first, second, st, want_sum, want_overflow) in [
            (0, 0, BIT, 0, 0),
            (0, 1, BIT, 1, 0),
            (1, 0, BIT, 1, 0),
            (1, 1, BIT, 0, 1),
            (127, 128, UINT8, 255, 0),
            (127, 129, UINT8, 0, 1),
            (128, 128, UINT8, 0, 1),
            (255, 255, UINT8, 254, 1),
            (1234, 4321, UINT16, 5555, 0),
            (12345, 54321, UINT16, 1130, 1),
            (12345, 54321, UINT32, 66666, 0),
            (2000000000, 2000000000, UINT32, 4000000000, 0),
            (2000000000, 3000000000, UINT32, 705032704, 1),
            (u64::MAX, u64::MAX, UINT64, u64::MAX - 1, 1),
        ] {
            let (got_sum, got_overflow) = add_with_overflow_helper(first, second, st)?;
            assert_eq!(got_sum, want_sum, "{first} + {second}");
            assert_eq!(got_overflow, want_overflow, "{first} + {second}");
        }
        Ok(())
    }

    /// Exhaustively checks 2-bit operands. This width hits the special case in
    /// `calculate_carry_bits` where the carry tree is skipped (no overflow bit) or has a
    /// single level (overflow bit), and no 2-bit scalar type exists to reach it otherwise.
    #[test]
    fn test_two_bit_arrays() -> Result<()> {
        for overflow_bit in [false, true] {
            let c = simple_context(|g| {
                let i1 = g.input(array_type(vec![2], BIT))?;
                let i2 = g.input(array_type(vec![2], BIT))?;
                g.custom_op(
                    CustomOperation::new(BinaryAdd { overflow_bit }),
                    vec![i1, i2],
                )
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            for a in 0..4u64 {
                for b in 0..4u64 {
                    // Index 0 holds the least significant bit.
                    let input0 = Value::from_flattened_array(&vec![a & 1, a >> 1], BIT)?;
                    let input1 = Value::from_flattened_array(&vec![b & 1, b >> 1], BIT)?;
                    let result = random_evaluate(
                        mapped_c.get_context().get_main_graph()?,
                        vec![input0, input1],
                    )?;
                    let sum = a + b;
                    let want_bits = vec![sum & 1, (sum >> 1) & 1];
                    if overflow_bit {
                        let parts = result.to_vector()?;
                        assert_eq!(
                            parts[0].to_flattened_array_u64(array_type(vec![2], BIT))?,
                            want_bits,
                            "{a} + {b} with overflow bit"
                        );
                        assert_eq!(parts[1].to_u64(BIT)?, sum >> 2, "{a} + {b} overflow");
                    } else {
                        assert_eq!(
                            result.to_flattened_array_u64(array_type(vec![2], BIT))?,
                            want_bits,
                            "{a} + {b}"
                        );
                    }
                }
            }
        }
        Ok(())
    }

    #[test]
    fn test_well_behaved() -> Result<()> {
        {
            // Broadcasting: five 16-bit values plus a single 16-bit value.
            let c = simple_context(|g| {
                let i1 = g.input(array_type(vec![5, 16], BIT))?;
                let i2 = g.input(array_type(vec![1, 16], BIT))?;
                g.custom_op(
                    CustomOperation::new(BinaryAdd {
                        overflow_bit: false,
                    }),
                    vec![i1, i2],
                )
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let inputs1 =
                Value::from_flattened_array(&vec![0, 1023, -1023, i16::MIN, i16::MAX], INT16)?;
            let inputs2 = Value::from_flattened_array(&vec![1024], INT16)?;
            let result_v = random_evaluate(
                mapped_c.get_context().get_main_graph()?,
                vec![inputs1, inputs2],
            )?
            .to_flattened_array_u64(array_type(vec![5], INT16))?;
            assert_eq!(
                result_v,
                vec![
                    1024,
                    2047,
                    1,
                    (i16::MIN + 1024) as u64,
                    (i16::MAX.wrapping_add(1024)) as u64,
                ]
            );
        }
        {
            // Signed addition via two's complement on 64 bits.
            let c = simple_context(|g| {
                let i1 = g.input(array_type(vec![64], BIT))?;
                let i2 = g.input(array_type(vec![64], BIT))?;
                g.custom_op(
                    CustomOperation::new(BinaryAdd {
                        overflow_bit: false,
                    }),
                    vec![i1, i2],
                )
            })?;
            let mapped_c = run_instantiation_pass(c)?;
            let input0 = Value::from_scalar(123456790, INT64)?;
            let input1 = Value::from_scalar(-123456789, INT64)?;
            let result_v = random_evaluate(
                mapped_c.get_context().get_main_graph()?,
                vec![input0, input1],
            )?
            .to_u64(INT64)?;
            assert_eq!(result_v, 1);
        }
        Ok(())
    }

    #[test]
    fn test_malformed() -> Result<()> {
        let c = create_context()?;
        let g = c.create_graph()?;
        let i = g.input(array_type(vec![64], BIT))?;
        let i1 = g.input(array_type(vec![64], INT16))?;
        let i2 = g.input(tuple_type(vec![]))?;
        let i3 = g.input(array_type(vec![32], BIT))?;
        // 31 bits is not a power of two.
        let i4 = g.input(array_type(vec![31], BIT))?;
        assert!(g
            .custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false
                }),
                vec![i.clone()]
            )
            .is_err());
        assert!(g
            .custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false
                }),
                vec![i.clone(), i1.clone()]
            )
            .is_err());
        assert!(g
            .custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false
                }),
                vec![i1.clone(), i.clone()]
            )
            .is_err());
        assert!(g
            .custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false
                }),
                vec![i2]
            )
            .is_err());
        assert!(g
            .custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false
                }),
                vec![i.clone(), i3]
            )
            .is_err());
        assert!(g
            .custom_op(
                CustomOperation::new(BinaryAdd {
                    overflow_bit: false
                }),
                vec![i4.clone(), i4]
            )
            .is_err());
        Ok(())
    }
}