use super::*;
impl<E: Environment, I: IntegerType, M: Magnitude> Pow<Integer<E, M>> for Integer<E, I> {
type Output = Integer<E, I>;
#[inline]
fn pow(self, other: Integer<E, M>) -> Self::Output {
self.pow_checked(&other)
}
}
impl<E: Environment, I: IntegerType, M: Magnitude> Pow<&Integer<E, M>> for Integer<E, I> {
type Output = Integer<E, I>;
#[inline]
fn pow(self, other: &Integer<E, M>) -> Self::Output {
self.pow_checked(other)
}
}
impl<E: Environment, I: IntegerType, M: Magnitude> PowChecked<Integer<E, M>> for Integer<E, I> {
type Output = Self;
/// Returns `self` raised to the power `other`, rejecting overflow.
///
/// When both operands are constant, the power is computed directly on the ejected
/// values and the environment halts if `checked_pow` reports an overflow. Otherwise
/// the result is built with square-and-multiply over the bits of `other`; an overflow
/// in a multiply step that is selected by a set exponent bit is rejected through
/// `E::assert_eq`, while the squaring step relies on `mul_checked`.
#[inline]
fn pow_checked(&self, other: &Integer<E, M>) -> Self::Output {
// Both operands are constants: evaluate eagerly instead of building the loop below.
if self.is_constant() && other.is_constant() {
// NOTE(review): the `unwrap` assumes every `Magnitude` value fits in a `u32` — confirm.
match self.eject_value().checked_pow(&other.eject_value().to_u32().unwrap()) {
Some(value) => Integer::new(Mode::Constant, console::Integer::new(value)),
None => E::halt("Integer overflow on exponentiation of two constants"),
}
} else {
// Accumulator for square-and-multiply, starting from one.
let mut result = Self::one();
// Walk the exponent bits in reverse of their stored order
// (most-significant first, going by the `bits_le` naming).
for bit in other.bits_le.iter().rev() {
// Square the accumulator on every step, using checked multiplication.
result = result.mul_checked(&result);
// Compute `result * self` together with an overflow flag; whether the value is
// actually kept is decided by `bit` at the end of the iteration.
let result_times_self = if I::is_signed() {
// Signed case: multiply the wrapped absolute values and restore the sign afterwards.
let (product, carry) = Self::mul_with_carry(&(&result).abs_wrapped(), &self.abs_wrapped());
// Any set carry bit means the magnitude product did not fit in the integer's bits.
let carry_bits_nonzero = carry.iter().fold(Boolean::constant(false), |a, b| a | b);
// Equal sign bits mean the true product is non-negative.
let operands_same_sign = &result.msb().is_equal(self.msb());
// A non-negative product must not have its most-significant bit set.
let positive_product_overflows = operands_same_sign & product.msb();
// A negative product may have a magnitude of at most `2^(BITS - 1)`:
// either the top bit is clear, or it is set and all lower bits are zero.
let negative_product_underflows = {
let lower_product_bits_nonzero = product.bits_le[..(I::BITS as usize - 1)]
.iter()
.fold(Boolean::constant(false), |a, b| a | b);
let negative_product_lt_or_eq_signed_min =
!product.msb() | (product.msb() & !lower_product_bits_nonzero);
!operands_same_sign & !negative_product_lt_or_eq_signed_min
};
let overflow = carry_bits_nonzero | positive_product_overflows | negative_product_underflows;
// Overflow is only an error when this exponent bit selects the product.
E::assert_eq(overflow & bit, E::zero());
// Keep the magnitude for same-sign operands; otherwise negate it (`!product + 1`).
Self::ternary(operands_same_sign, &product, &(!&product).add_wrapped(&Self::one()))
} else {
// Unsigned case: the product overflows exactly when any carry bit is set.
let (product, carry) = Self::mul_with_carry(&result, self);
let overflow = carry.iter().fold(Boolean::constant(false), |a, b| a | b);
// Overflow is only an error when this exponent bit selects the product.
E::assert_eq(overflow & bit, E::zero());
product
};
// Take the multiplied value when the exponent bit is set, else keep the squared accumulator.
result = Self::ternary(bit, &result_times_self, &result);
}
result
}
}
}
impl<E: Environment, I: IntegerType, M: Magnitude> Metrics<dyn PowChecked<Integer<E, M>, Output = Integer<E, I>>>
    for Integer<E, I>
{
    type Case = (Mode, Mode);

    /// Returns the expected `Count` for `pow_checked` given the modes of the base and the exponent.
    ///
    /// Two constants use a fixed count that depends only on `I::BITS`. Every other
    /// combination of modes is charged `2 * M::BITS` wrapped multiplications (as
    /// reported by the `MulWrapped` metrics for the same case) plus a fixed
    /// `I::BITS`-dependent term.
    fn count(case: &Self::Case) -> Count {
        if matches!(case, (Mode::Constant, Mode::Constant)) {
            Count::is(I::BITS, 0, 0, 0)
        } else {
            // All remaining mode combinations share the same formula.
            let per_multiplication = count!(Integer<E, I>, MulWrapped<Integer<E, I>, Output=Integer<E, I>>, case);
            (2 * M::BITS * per_multiplication) + Count::is(2 * I::BITS, 0, I::BITS, I::BITS)
        }
    }
}
impl<E: Environment, I: IntegerType, M: Magnitude> OutputMode<dyn PowChecked<Integer<E, M>, Output = Integer<E, I>>>
    for Integer<E, I>
{
    type Case = (Mode, CircuitType<Integer<E, M>>);

    /// Returns the mode of the output of `pow_checked` for the given base mode and exponent.
    ///
    /// - Constant base and constant exponent yield a constant.
    /// - A non-constant base with a constant exponent yields a constant only when
    ///   the exponent is zero; any other constant exponent yields a private output.
    /// - Every other combination yields a private output.
    ///
    /// Halts if the exponent reports `Mode::Constant` but is not supplied as
    /// `CircuitType::Constant`, because its value is needed to decide the mode.
    fn output_mode(case: &Self::Case) -> Mode {
        let exponent = &case.1;
        match (case.0, exponent.mode()) {
            (Mode::Constant, Mode::Constant) => Mode::Constant,
            // The constant/constant pair is handled above, so the exponent is non-constant here.
            (Mode::Constant, _) => Mode::Private,
            (_, Mode::Constant) => match exponent {
                // Only a zero exponent keeps the output constant for a non-constant base.
                CircuitType::Constant(constant) => {
                    if constant.eject_value().is_zero() {
                        Mode::Constant
                    } else {
                        Mode::Private
                    }
                }
                _ => E::halt("The constant is required for the output mode of `pow_checked` with a constant."),
            },
            (_, _) => Mode::Private,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utilities::*;
    use snarkvm_circuit_environment::Circuit;
    use std::{ops::RangeInclusive, panic::RefUnwindSafe};

    /// Number of randomized rounds executed by `run_test`.
    const ITERATIONS: u64 = 4;

    /// Checks `pow_checked` for one `(first, second)` pair under the given modes.
    ///
    /// The expected outcome comes from the console-level `checked_pow`:
    /// - on success, the circuit result must match and the scope must be satisfied;
    /// - on overflow with two constants, the operation must halt;
    /// - on overflow otherwise, the scope must end up unsatisfied.
    ///
    /// The circuit is reset afterwards so consecutive checks do not interfere.
    fn check_pow<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(
        name: &str,
        first: console::Integer<<Circuit as Environment>::Network, I>,
        second: console::Integer<<Circuit as Environment>::Network, M>,
        mode_a: Mode,
        mode_b: Mode,
    ) {
        let a = Integer::<Circuit, I>::new(mode_a, first);
        let b = Integer::<Circuit, M>::new(mode_b, second);

        let outcome = first.checked_pow(&second.to_u32().unwrap());
        if let Some(expected) = outcome {
            // No overflow: the circuit must reproduce the console result.
            Circuit::scope(name, || {
                let candidate = a.pow_checked(&b);
                assert_eq!(expected, *candidate.eject_value());
                assert_eq!(console::Integer::new(expected), candidate.eject_value());
                assert!(Circuit::is_satisfied_in_scope(), "(is_satisfied_in_scope)");
            });
        } else if matches!((mode_a, mode_b), (Mode::Constant, Mode::Constant)) {
            // Overflow on two constants is reported by halting.
            check_operation_halts(&a, &b, Integer::pow_checked);
        } else {
            // Overflow with a non-constant operand must leave the scope unsatisfied.
            Circuit::scope(name, || {
                let _candidate = a.pow_checked(&b);
                assert!(!Circuit::is_satisfied_in_scope(), "(!is_satisfied_in_scope)");
            });
        }

        Circuit::reset();
    }

    /// Runs randomized and boundary checks of `pow_checked` for the given pair of modes.
    fn run_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode) {
        // Fixed exponents reused by every round and by the boundary checks.
        let zero = console::Integer::<<Circuit as Environment>::Network, M>::zero();
        let one = console::Integer::<<Circuit as Environment>::Network, M>::one();
        let two = one + one;
        let three = two + one;

        let mut rng = TestRng::default();

        for i in 0..ITERATIONS {
            let first = Uniform::rand(&mut rng);
            let second = Uniform::rand(&mut rng);

            // Random base with a random exponent.
            let name = format!("Pow: {} ** {} {}", mode_a, mode_b, i);
            check_pow::<I, M>(&name, first, second, mode_a, mode_b);

            // Random base with the small fixed exponents 0, 1, 2 and 3.
            let name = format!("Pow Zero: {} ** {} {}", mode_a, mode_b, i);
            check_pow::<I, M>(&name, first, zero, mode_a, mode_b);

            let name = format!("Pow One: {} ** {} {}", mode_a, mode_b, i);
            check_pow::<I, M>(&name, first, one, mode_a, mode_b);

            let name = format!("Square: {} ** {} {}", mode_a, mode_b, i);
            check_pow::<I, M>(&name, first, two, mode_a, mode_b);

            let name = format!("Cube: {} ** {} {}", mode_a, mode_b, i);
            check_pow::<I, M>(&name, first, three, mode_a, mode_b);
        }

        // Boundary bases and exponents.
        check_pow::<I, M>("MAX ** MAX", console::Integer::MAX, console::Integer::MAX, mode_a, mode_b);
        check_pow::<I, M>("MIN ** 0", console::Integer::MIN, zero, mode_a, mode_b);
        check_pow::<I, M>("MAX ** 0", console::Integer::MAX, zero, mode_a, mode_b);
        check_pow::<I, M>("MIN ** 1", console::Integer::MIN, one, mode_a, mode_b);
        check_pow::<I, M>("MAX ** 1", console::Integer::MAX, one, mode_a, mode_b);
    }

    /// Checks `pow_checked` for every `(base, exponent)` pair representable by `I` and `M`.
    fn run_exhaustive_test<I: IntegerType + RefUnwindSafe, M: Magnitude + RefUnwindSafe>(mode_a: Mode, mode_b: Mode)
    where
        RangeInclusive<I>: Iterator<Item = I>,
        RangeInclusive<M>: Iterator<Item = M>,
    {
        for raw_base in I::MIN..=I::MAX {
            for raw_exponent in M::MIN..=M::MAX {
                let base = console::Integer::<_, I>::new(raw_base);
                let exponent = console::Integer::<_, M>::new(raw_exponent);

                let name = format!("Pow: ({} ** {})", base, exponent);
                check_pow::<I, M>(&name, base, exponent, mode_a, mode_b);
            }
        }
    }

    // Signed bases with every supported exponent width.
    test_integer_binary!(run_test, i8, u8, pow);
    test_integer_binary!(run_test, i8, u16, pow);
    test_integer_binary!(run_test, i8, u32, pow);

    test_integer_binary!(run_test, i16, u8, pow);
    test_integer_binary!(run_test, i16, u16, pow);
    test_integer_binary!(run_test, i16, u32, pow);

    test_integer_binary!(run_test, i32, u8, pow);
    test_integer_binary!(run_test, i32, u16, pow);
    test_integer_binary!(run_test, i32, u32, pow);

    test_integer_binary!(run_test, i64, u8, pow);
    test_integer_binary!(run_test, i64, u16, pow);
    test_integer_binary!(run_test, i64, u32, pow);

    test_integer_binary!(run_test, i128, u8, pow);
    test_integer_binary!(run_test, i128, u16, pow);
    test_integer_binary!(run_test, i128, u32, pow);

    // Unsigned bases with every supported exponent width.
    test_integer_binary!(run_test, u8, u8, pow);
    test_integer_binary!(run_test, u8, u16, pow);
    test_integer_binary!(run_test, u8, u32, pow);

    test_integer_binary!(run_test, u16, u8, pow);
    test_integer_binary!(run_test, u16, u16, pow);
    test_integer_binary!(run_test, u16, u32, pow);

    test_integer_binary!(run_test, u32, u8, pow);
    test_integer_binary!(run_test, u32, u16, pow);
    test_integer_binary!(run_test, u32, u32, pow);

    test_integer_binary!(run_test, u64, u8, pow);
    test_integer_binary!(run_test, u64, u16, pow);
    test_integer_binary!(run_test, u64, u32, pow);

    test_integer_binary!(run_test, u128, u8, pow);
    test_integer_binary!(run_test, u128, u16, pow);
    test_integer_binary!(run_test, u128, u32, pow);

    // Exhaustive sweeps are ignored by default.
    test_integer_binary!(#[ignore], run_exhaustive_test, u8, u8, pow, exhaustive);
    test_integer_binary!(#[ignore], run_exhaustive_test, i8, u8, pow, exhaustive);
}