use crate::{
core::{
bounds::{below_power_of_two, BoolBounds, FieldBounds, IsBounds},
expressions::{expr::EvalFailure, field_expr::InputId},
},
traits::FromLeBits,
utils::{unique_id::UniqueId, used_field::UsedField},
};
use arcis_internal_expr_macro::Expr;
use serde::{Deserialize, Serialize};
use std::{cell::Cell, marker::PhantomData, ops::Add};
/// Identifier carried by a [`ConversionExpr::EdaBit`] expression.
///
/// A newtype around [`UniqueId`]; equality and hashing are the derived ones,
/// i.e. they delegate to the wrapped id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EdaBitId(UniqueId);
impl EdaBitId {
pub fn new() -> Self {
EdaBitId(UniqueId::new())
}
}
impl Default for EdaBitId {
fn default() -> Self {
EdaBitId::new()
}
}
/// Expressions converting between bit-level and scalar (field-level) values.
///
/// Type parameters:
/// * `F` - the field in use.
/// * `T` - payload type of a scalar operand.
/// * `B` - payload type of a bit operand (defaults to `T`).
/// * `E` - payload type of an operand derived from an edaBit (defaults to `T`).
///
/// The payload types vary with the analysis being run: this file instantiates
/// them with concrete values (`F` / `bool`) for evaluation, with `bool` flags
/// for the plaintext check, and with `FieldBounds<F>` / `BoolBounds` for
/// bounds propagation.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Expr)]
pub enum ConversionExpr<F: UsedField, T: Clone, B: Clone = T, E: Clone = T> {
/// Extracts one bit of a scalar: `(scalar, bit index, signed)`.
/// Evaluation calls `signed_bit` when the flag is `true`, `unsigned_bit` otherwise.
BitNumToBit(T, usize, bool),
/// Rebuilds a scalar from bits: `(bits, signed)`.
/// Evaluation calls `from_le_bits`; bounds propagation gives the last bit a
/// negative weight when `signed` is `true`.
BitToBitNum(Vec<B>, bool),
/// An edaBit: `(id, size, marker)`. It cannot be evaluated by `eval`, and its
/// bounds are `below_power_of_two(size)`.
EdaBit(EdaBitId, usize, PhantomData<F>),
/// Extracts one bit of an edaBit-derived operand: `(operand, bit index)`.
/// Evaluation calls `unsigned_bit`.
BitFromEdaBit(E, usize),
/// Converts a bit into a scalar via `F::from`.
ScalarFromPlaintextBit(B),
/// Passes an edaBit-derived operand through as a scalar.
ScalarFromEdaBit(E),
}
/// Result of evaluating a [`ConversionExpr`]: either a single bit or a scalar
/// of the field `F`.
pub enum ConversionValue<F: UsedField> {
/// A boolean (bit) result.
Bit(bool),
/// A field-element result.
Scalar(F),
}
impl<F: UsedField + FromLeBits<bool>> ConversionExpr<F, F, bool> {
pub fn eval(self) -> Result<ConversionValue<F>, EvalFailure> {
use ConversionExpr::*;
use ConversionValue::*;
let val = match self {
BitNumToBit(e, c, b) => {
if b {
Bit(e.signed_bit(c))
} else {
Bit(e.unsigned_bit(c))
}
}
BitToBitNum(v, signed) => Scalar(<F>::from_le_bits(v, signed)),
EdaBit(_, _, _) => EvalFailure::err_imp("edaBit not evaluable here")?,
BitFromEdaBit(e, k) => Bit(e.unsigned_bit(k)),
ScalarFromEdaBit(e) => Scalar(e),
ScalarFromPlaintextBit(b) => Scalar(F::from(b)),
};
Ok(val)
}
}
impl<F: UsedField> ConversionExpr<F, bool, bool> {
    /// Tells whether the expression is plaintext, given one flag per operand.
    ///
    /// An `EdaBit` is never plaintext. Any other variant is plaintext exactly
    /// when every flag returned by `get_deps` is `true`.
    pub fn is_plaintext(&self) -> bool {
        if matches!(self, ConversionExpr::EdaBit(..)) {
            return false;
        }
        self.get_deps().iter().all(|&dep_is_plaintext| dep_is_plaintext)
    }
}
impl<F: UsedField, T: Clone, B: Clone> ConversionExpr<F, T, B> {
pub fn is_boolean(&self) -> bool {
use ConversionExpr::*;
match self {
BitNumToBit(_, _, _) => true,
BitToBitNum(_, _) => false,
EdaBit(_, _, _) => false,
BitFromEdaBit(_, _) => true,
ScalarFromPlaintextBit(_) => false,
ScalarFromEdaBit(_) => false,
}
}
pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
!matches!(self, ConversionExpr::EdaBit(..))
}
pub fn get_input(&self) -> Option<InputId> {
None
}
pub fn get_input_name(&self) -> &str {
""
}
pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
None
}
}
/// Bounds on the result of a [`ConversionExpr`]: bit-valued results carry
/// [`BoolBounds`], scalar-valued results carry [`FieldBounds`].
pub enum ConversionBounds<F: UsedField> {
/// Bounds of a boolean (bit) result.
Bit(BoolBounds),
/// Bounds of a field-element result.
Scalar(FieldBounds<F>),
}
/// Bit bounds used when a bit cannot be narrowed to a constant.
///
/// NOTE(review): presumably `BoolBounds::new(true, true)` means "both values
/// are possible" - confirm against `BoolBounds::new`.
fn all_bit_bools<F: UsedField>() -> ConversionBounds<F> {
    let unconstrained = BoolBounds::new(true, true);
    ConversionBounds::Bit(unconstrained)
}
/// Scalar bounds built from `FieldBounds::new(F::ZERO, F::ONE)`, used for a
/// bit that was converted to a scalar and whose value is not constant.
fn all_scalar_bools<F: UsedField>() -> ConversionBounds<F> {
    let zero_to_one = FieldBounds::new(F::ZERO, F::ONE);
    ConversionBounds::Scalar(zero_to_one)
}
impl<F: UsedField> From<bool> for ConversionBounds<F> {
    /// Wraps `BoolBounds::from(b)` in the `Bit` variant.
    fn from(b: bool) -> Self {
        let bit_bounds = BoolBounds::from(b);
        Self::Bit(bit_bounds)
    }
}
impl<F: UsedField> From<F> for ConversionBounds<F> {
    /// Wraps `FieldBounds::from(f)` in the `Scalar` variant.
    fn from(f: F) -> Self {
        let scalar_bounds = FieldBounds::from(f);
        Self::Scalar(scalar_bounds)
    }
}
impl<F: UsedField> ConversionExpr<F, FieldBounds<F>, BoolBounds> {
pub fn bounds(self) -> ConversionBounds<F> {
use ConversionBounds::*;
use ConversionExpr::*;
match self {
BitNumToBit(b, idx, signed) => {
let (min, max) = b.min_and_max(signed);
if max - min >= F::power_of_two(idx) {
all_bit_bools()
} else {
let (b1, b2) = b.bits_in_pos(idx, signed);
if b1 != b2 {
all_bit_bools()
} else {
b1.into()
}
}
}
BitToBitNum(v, signed) => {
let n = v.len();
Scalar(
v.into_iter()
.enumerate()
.map(|(i, bounds)| {
let is_negative = signed && i + 1 == n;
bounds.multiply_by_power_of_two(i, is_negative)
})
.fold(FieldBounds::from(F::ZERO), Add::add),
)
}
EdaBit(_, u, _) => Scalar(below_power_of_two(u)),
BitFromEdaBit(b, k) => {
if b.unsigned_max() >= F::power_of_two(k) {
all_bit_bools()
} else {
Bit(false.into())
}
}
ScalarFromPlaintextBit(b) => {
if let Some(v) = b.as_constant() {
ConversionBounds::from(F::from(v))
} else {
all_scalar_bools()
}
}
ScalarFromEdaBit(b) => Scalar(b),
}
}
}