use crate::{
core::{
bounds::FieldBounds,
circuits::{
boolean::{
boolean_value::BooleanValue,
byte::Byte,
utils::{addition_circuit, CircuitType},
},
f64::utils::F64,
traits::f64_circuit::F64Circuit,
},
expressions::expr::EvalFailure,
global_value::value::FieldValue,
},
traits::{FromLeBits, GetBit},
utils::{field::BaseField, used_field::UsedField},
};
use core::panic;
use ff::Field;
/// Little-endian bits of `-1023` in 12-bit two's complement, i.e. `4096 - 1023 = 3073`.
///
/// `F64Mul::mul` adds this pattern to one zero-extended 11-bit biased exponent before adding
/// the other one. The sum of two biased exponents carries the IEEE-754 double bias (1023)
/// twice; in 12-bit wrapping arithmetic, adding 3073 is the same as subtracting 1023, which
/// removes one copy of the bias.
#[allow(dead_code)]
const EXPONENT_OFFSET: [bool; 12] = {
    let encoded: u16 = 4096 - 1023;
    let mut bits = [false; 12];
    let mut i = 0;
    while i < 12 {
        bits[i] = (encoded >> i) & 1 == 1;
        i += 1;
    }
    bits
};
/// Parameterless description of the `F64` multiplication circuit.
///
/// The circuit takes exactly two `F64` operands and produces their product; the type itself
/// holds no state.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct F64Mul;
impl F64Mul {
    /// Builds the circuit computing the product of two `F64` values.
    ///
    /// The implicit leading one is unconditionally prepended to both mantissas, and there is
    /// no special handling of zero, subnormal, infinite or NaN operands. The result is
    /// therefore only meaningful for normal operands whose product is again a normal number.
    /// The product significand is truncated (low bits are dropped) rather than rounded to
    /// nearest. The top bit of the 12-bit exponent sum is discarded without being checked,
    /// so exponent overflow/underflow is not detected here.
    ///
    /// # Panics
    ///
    /// Panics if the exponent vector left after removing its top bit does not have exactly
    /// 11 elements.
    #[allow(dead_code)]
    pub fn mul(lhs: F64, rhs: F64) -> F64 {
        // The sign of a product is the XOR of the operand signs.
        let sign_res = lhs.sign ^ rhs.sign;
        // Prepend the implicit leading one: each significand becomes the 53-bit integer
        // 2^52 + mantissa, which lies in [2^52, 2^53).
        let mantissa_lhs = FieldValue::from(BaseField::power_of_two(52)) + lhs.mantissa;
        let mantissa_rhs = FieldValue::from(BaseField::power_of_two(52)) + rhs.mantissa;
        // The exact product lies in [2^104, 2^106), so 106 bits describe it completely.
        let mantissa_prod = mantissa_lhs * mantissa_rhs;
        let mantissa_prod_bits = (0..106)
            .map(|i| mantissa_prod.get_bit(i, false))
            .collect::<Vec<BooleanValue>>();
        // Bit 105 (the last one) is set iff the product is >= 2^105, i.e. the normalized
        // significand product is in [2, 4) and the exponent must be bumped by one.
        let is_full_length = *mantissa_prod_bits.last().unwrap();
        let is_not_full_length = !is_full_length;
        // For a 105-bit product, bit 52 is the extra low bit that must be kept once the
        // top bits are shifted left by one below.
        let lsb_bit = mantissa_prod_bits[52];
        let lsb_offset = is_not_full_length & lsb_bit;
        // Bits 53..=105 of the product, i.e. floor(prod / 2^53).
        let mantissa_before_correction = FieldValue::<BaseField>::from_le_bits(
            mantissa_prod_bits
                .into_iter()
                .skip(53)
                .collect::<Vec<BooleanValue>>(),
            false,
        );
        // Little-endian bits [is_full_length, is_not_full_length] encode 1 when the product
        // is full length and 2 otherwise, i.e. a conditional left shift by one.
        let correction_term =
            FieldValue::<BaseField>::from_le_bits(vec![is_full_length, is_not_full_length], false);
        // Either floor(prod / 2^53) (full length) or 2 * floor(prod / 2^53) + bit 52
        // = floor(prod / 2^52) (not full length). Both lie in [2^52, 2^53); all lower bits
        // are dropped, so this truncates.
        let mantissa = mantissa_before_correction * correction_term
            + FieldValue::<BaseField>::from(lsb_offset);
        // Remove the implicit leading one again (adding what is presumably -2^52) and attach
        // the stored-mantissa range [0, 2^52 - 1].
        let mantissa_res = (mantissa + FieldValue::from(BaseField::negative_power_of_two(52)))
            .with_bounds(FieldBounds::new(
                BaseField::ZERO,
                BaseField::power_of_two(52) - BaseField::ONE,
            ));
        // Zero-extend both (presumably 11-bit) biased exponents by one bit so that they line
        // up with the 12-bit EXPONENT_OFFSET and the sum has a spare top bit.
        let mut exponent_lhs = lhs.exponent.to_vec();
        exponent_lhs.push(BooleanValue::from(false));
        let mut exponent_rhs = rhs.exponent.to_vec();
        exponent_rhs.push(BooleanValue::from(false));
        // exponent_lhs + EXPONENT_OFFSET (3073, i.e. -1023 in 12-bit two's complement).
        // `is_full_length` is passed in what is presumably the carry-in position, which
        // bumps the exponent by one when the significand product is in [2, 4).
        let tmp = addition_circuit(
            exponent_lhs,
            EXPONENT_OFFSET
                .into_iter()
                .map(BooleanValue::from)
                .collect::<Vec<BooleanValue>>(),
            is_full_length,
            CircuitType::default(),
        );
        // + exponent_rhs, giving e_lhs + e_rhs - 1023 + is_full_length.
        // NOTE(review): assumes addition_circuit returns a result as wide as its operands
        // (wrapping mod 2^12) — TODO confirm.
        let mut exponent_res = addition_circuit(
            tmp,
            exponent_rhs,
            BooleanValue::from(false),
            CircuitType::default(),
        );
        // Drop the top (12th) bit and ignore it. NOTE(review): despite its name this is not
        // a NaN flag; under the wrapping assumption above it is set when the biased result
        // exponent falls outside 0..2048 (overflow, or underflow wrapping around).
        let _is_nan = exponent_res.pop();
        // The remaining bits must form the 11-bit exponent expected by F64::new.
        let exponent_res = exponent_res
            .try_into()
            .unwrap_or_else(|v: Vec<BooleanValue>| {
                panic!("Expected a Vec of length 11 (found {})", v.len())
            });
        F64::new(sign_res, exponent_res, mantissa_res)
    }
}
impl F64Circuit for F64Mul {
    /// Reference semantics: native `f64` multiplication of the two inputs.
    ///
    /// Reports undefined behaviour (via `EvalFailure::err_ub`) unless the biased exponents
    /// of both operands, and their sum minus the bias 1023, all lie in `1..2046`.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not contain exactly two values.
    fn eval(&self, x: Vec<f64>) -> Result<Vec<f64>, EvalFailure> {
        let (lhs, rhs) = match x.as_slice() {
            &[lhs, rhs] => (lhs, rhs),
            _ => panic!("F64Mul expects input Vec of length 2"),
        };

        // Reads bits 52..63 of the little-endian bit decomposition, i.e. the 11-bit biased
        // exponent field (assuming `Byte::to_vec` lists bits least-significant first).
        let biased_exponent = |value: f64| -> i16 {
            let bits: Vec<bool> = value
                .to_le_bytes()
                .into_iter()
                .flat_map(|byte| Byte::from(byte).to_vec())
                .collect();
            bits[52..63]
                .iter()
                .rev()
                .fold(0i16, |acc, &bit| (acc << 1) | i16::from(bit))
        };

        let exponent_lhs = biased_exponent(lhs);
        let exponent_rhs = biased_exponent(rhs);
        // Product exponent before the possible +1 from significand normalization.
        let exponent_res = exponent_lhs + exponent_rhs - 1023;

        // NOTE(review): `1..2046` also rejects the largest finite biased exponent (2046) for
        // the operands; confirm whether that strictness is intended.
        let in_range = |e: &i16| (1..2046).contains(e);
        if [exponent_lhs, exponent_rhs, exponent_res].iter().all(in_range) {
            Ok(vec![lhs * rhs])
        } else {
            EvalFailure::err_ub("inputs or product out of range")
        }
    }

    /// Relative tolerance of 2^-52 (`f64::EPSILON`); presumably this absorbs the circuit
    /// truncating the significand where native multiplication rounds to nearest.
    fn rtol(&self) -> f64 {
        f64::EPSILON
    }

    /// Feeds the two input `F64` values into [`F64Mul::mul`].
    ///
    /// # Panics
    ///
    /// Panics if `vals` does not contain exactly two values.
    fn run(&self, vals: Vec<F64>) -> Vec<F64> {
        match vals.as_slice() {
            [lhs, rhs] => vec![F64Mul::mul(lhs.clone(), rhs.clone())],
            _ => panic!("F64Mul expects input Vec of length 2"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::circuits::traits::f64_circuit::tests::TestedF64Circuit;
    use rand::Rng;

    /// `F64Mul` has no randomized description and always consumes exactly two operands.
    impl TestedF64Circuit for F64Mul {
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            F64Mul
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            2
        }
    }

    /// Runs the shared `TestedF64Circuit` harness against `F64Mul`.
    #[test]
    fn tested_f64_mul() {
        F64Mul::test(4, 16)
    }
}