use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::{FieldBounds, IsBounds},
circuits::{
arithmetic::sqrt,
key_recovery::utils::reed_solomon::KeyRecoveryReedSolomonFinal,
},
expressions::{circuit::ArithmeticCircuitId, expr::EvalFailure, InputKind},
},
traits::{Invert, Pow},
utils::{
ignore_for_equality::IgnoreForEquality,
number::Number,
unique_id::UniqueId,
used_field::UsedField,
},
};
use arcis_internal_expr_macro::Expr;
use core_utils::key_recovery::{MXE_KEY_RECOVERY_D, MXE_KEY_RECOVERY_N};
use serde::{Deserialize, Serialize};
use std::{cell::Cell, marker::PhantomData, rc::Rc};
/// Index identifying an input; it is the first field of [`FieldExpr::Input`].
pub type InputId = usize;
/// Identifier of a player.
/// NOTE(review): not referenced in this part of the file; presumably indexes an MPC party —
/// confirm against its users.
pub type PlayerId = u16;
/// Metadata attached to every [`FieldExpr::Input`] node.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct InputInfo<F: UsedField> {
    /// Kind of the input; [`InputInfo::is_plaintext`] delegates to it.
    pub kind: InputKind,
    /// Lower end of the declared range (bounds analysis uses `FieldBounds::new(min, max)`).
    pub min: F,
    /// Upper end of the declared range.
    pub max: F,
    /// Name reported by `FieldExpr::get_input_name`.
    pub name: String,
    /// Flag exposed through `FieldExpr::get_is_input_already_optimized_out`; the `Cell` lets it
    /// be set through the shared `Rc<InputInfo>` held by the node.
    /// NOTE(review): the `IgnoreForEquality` wrapper presumably keeps this flag out of the
    /// derived `PartialEq`/`Hash` — confirm.
    pub has_already_been_found_unused: IgnoreForEquality<Cell<bool>>,
}
impl<F: UsedField> Default for InputInfo<F> {
    /// A plaintext input named `"_"` with declared range `[0, 1]`, not yet flagged as unused.
    fn default() -> Self {
        Self {
            kind: InputKind::Plaintext,
            min: F::ZERO,
            max: F::ONE,
            name: String::from("_"),
            has_already_been_found_unused: IgnoreForEquality(Cell::new(false)),
        }
    }
}
impl<F: UsedField> InputInfo<F> {
    /// Returns `true` when this input's kind is plaintext (delegates to the `InputKind`).
    pub fn is_plaintext(&self) -> bool {
        let kind = &self.kind;
        kind.is_plaintext()
    }
}
impl<F: UsedField> From<InputKind> for InputInfo<F> {
    /// Builds an input of the given kind whose declared range is `[0, -1]`, i.e. it reaches up
    /// to the largest field element. The name and the "unused" flag come from `Default`.
    fn from(kind: InputKind) -> Self {
        let largest_element = -F::ONE;
        Self {
            kind,
            min: F::ZERO,
            max: largest_element,
            ..Self::default()
        }
    }
}
/// Identifier distinguishing one `FieldExpr::RandomVal` node from another.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RandomValId(UniqueId);
impl RandomValId {
    /// Creates an identifier backed by a new `UniqueId` (presumably distinct from all others).
    pub fn new() -> RandomValId {
        let id = UniqueId::new();
        Self(id)
    }
}
impl Default for RandomValId {
    /// Same as [`RandomValId::new`]; each default value gets its own `UniqueId`.
    fn default() -> Self {
        RandomValId::new()
    }
}
// An expression node over field elements of type `F`.
//
// `T` is the type of ordinary dependencies, `C` the type of the `Where` condition, and `P` the
// type of the operand of `Rem`, `Div` and `FieldInverse`; `C` and `P` default to `T`.
// In this file the enum is instantiated as:
// - `FieldExpr<F, F>`, evaluated by `eval`;
// - `FieldExpr<F, FieldBounds<F>>`, analysed by `bounds`;
// - `FieldExpr<F, bool>`, queried by `is_plaintext`.
// NOTE(review): plain `//` comments are used here (not `///`) so the `Expr` derive macro
// receives no additional attributes; switch to doc comments once the macro is confirmed to
// ignore `#[doc]` attributes.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Expr)]
pub enum FieldExpr<F: UsedField, T: Clone, C: Clone = T, P: Clone = T> {
    // Input number `InputId` together with its metadata (kind, declared range, name).
    Input(InputId, Rc<InputInfo<F>>),
    // Field addition.
    Add(T, T),
    // Field subtraction.
    Sub(T, T),
    // Field multiplication.
    Mul(T, T),
    // `sum(dep * factor) + constant`.
    LinComb(Vec<(T, F)>, F),
    // `lhs > rhs` as 0/1; the flag requests a signed comparison.
    Gt(T, T, bool),
    // `lhs >= rhs` as 0/1; the flag requests a signed comparison.
    Ge(T, T, bool),
    // Unsigned remainder; a zero modulus is undefined behavior.
    Rem(T, P),
    // Evaluates to its argument; `is_plaintext` treats the result as plaintext.
    Reveal(T),
    // Constant.
    Val(F),
    // `(condition, if_one, if_zero)`; a condition other than 0/1 is undefined behavior.
    Where(C, T, T),
    // Equality test as 0/1.
    Equal(T, T),
    // Field negation.
    Neg(T),
    // Absolute value (as defined by `F::abs`).
    Abs(T),
    // Unsigned representative shifted right by the given number of bits.
    LogicalRightShift(T, usize),
    // Keep the given number of low bits of the signed value; the flag requests a signed result.
    KeepLsBits(T, usize, bool),
    // Unsigned Euclidean quotient; a zero divisor is undefined behavior.
    Div(T, P),
    // Asserts during evaluation that the value lies in the bounds; bounds analysis intersects.
    Bounds(T, FieldBounds<F>),
    // Field inverse.
    FieldInverse(P),
    // `(inputs, circuit, output_index)`: one output of a referenced arithmetic circuit.
    SubCircuit(Vec<T>, ArithmeticCircuitId, usize),
    // Random value; not evaluable by `eval`, unbounded for `bounds`.
    RandomVal(RandomValId),
    // Square root; a non-residue input is undefined behavior.
    Sqrt(T),
    // `(base, exponent, flag)`; the flag is ignored by `eval` and `bounds`.
    Pow(T, Number, bool),
    // The value must be (unsigned) strictly below `2^n`.
    Cap(T, usize),
    // `(d_minus_one, syndromes, i)`: entry `i` of the key-recovery error computation.
    KeyRecoveryComputeErrors(T, Vec<T>, usize),
}
impl<F: UsedField> FieldExpr<F, bool> {
pub fn is_plaintext(&self) -> bool {
match self {
FieldExpr::Input(_, info) => info.is_plaintext(),
FieldExpr::Reveal(_) => true,
FieldExpr::Val(_) => true,
FieldExpr::RandomVal(_) => false,
_ => self.get_deps().iter().all(|x: &bool| *x),
}
}
}
impl<F: UsedField, T: Clone, C: Clone, P: Clone> FieldExpr<F, T, C, P> {
    /// Returns `false` for `Input`, `RandomVal`, `Sqrt` and `Cap`, and `true` for every other
    /// node.
    ///
    /// `Input` and `RandomVal` carry no dependencies to compute from.
    /// NOTE(review): the reason `Sqrt` and `Cap` are excluded is not visible here — confirm.
    pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
        let excluded = matches!(
            self,
            FieldExpr::Input(..) | FieldExpr::RandomVal(..) | FieldExpr::Sqrt(..) | FieldExpr::Cap(..)
        );
        !excluded
    }

    /// The input id of an `Input` node; `None` for any other node.
    pub fn get_input(&self) -> Option<InputId> {
        if let FieldExpr::Input(input_id, _) = self {
            Some(*input_id)
        } else {
            None
        }
    }

    /// The name of an `Input` node; the empty string for any other node.
    pub fn get_input_name(&self) -> &str {
        if let FieldExpr::Input(_, input_info) = self {
            input_info.name.as_str()
        } else {
            ""
        }
    }

    /// For an `Input` node, the `has_already_been_found_unused` flag of its metadata;
    /// `None` for any other node.
    pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
        if let FieldExpr::Input(_, input_info) = self {
            Some(&input_info.has_already_been_found_unused.0)
        } else {
            None
        }
    }
}
impl<F: ActuallyUsedField> FieldExpr<F, F> {
    /// Evaluates this node once every dependency has been replaced by its field value.
    ///
    /// Comparisons and `Equal` produce `0`/`1`. `Rem`, `Div` and `LogicalRightShift` operate on
    /// the unsigned representatives of their operands.
    ///
    /// # Errors
    ///
    /// Returns an [`EvalFailure`] in these cases:
    /// - the node is not evaluable here (`Input`, `RandomVal`);
    /// - the operation is undefined behavior:
    ///   - a `Where` condition other than 0/1;
    ///   - modulo or division by zero;
    ///   - `Sqrt` of a non-residue;
    ///   - a failed `Cap`;
    ///   - key-recovery parameters out of range;
    /// - a `Bounds` assertion does not hold;
    /// - evaluating a `SubCircuit` fails.
    pub fn eval(self) -> Result<F, EvalFailure> {
        use FieldExpr::*;
        let val: F = match self {
            Input(_, _) => EvalFailure::err_imp("Input not evaluable here")?,
            Add(e1, e2) => e1 + e2,
            // Sum of `dep * factor`, plus the constant term.
            LinComb(vec, c) => vec.into_iter().map(|(e, factor)| e * factor).sum::<F>() + c,
            Val(v) => v,
            Mul(e1, e2) => e1 * e2,
            Gt(e1, e2, signed) => {
                // For a signed comparison both operands are shifted by `-TWO_INV`.
                // Assuming `TWO_INV` is (p+1)/2, the upper ("negative") half of the field then
                // lands below the lower half. The `>` on field elements (presumably comparing
                // unsigned representatives) therefore orders the operands as signed integers.
                let offset = if signed { F::TWO_INV } else { F::ZERO };
                (e1 - offset > e2 - offset).into()
            }
            Ge(e1, e2, signed) => {
                // Same signed-comparison shift as `Gt`.
                let offset = if signed { F::TWO_INV } else { F::ZERO };
                (e1 - offset >= e2 - offset).into()
            }
            Where(e1, e2, e3) => {
                // Select `e2` on 1 and `e3` on 0; any other condition value is UB.
                if e1 == F::ONE {
                    e2
                } else if e1 == F::ZERO {
                    e3
                } else {
                    EvalFailure::err_ub("Where condition input should be 1 or 0")?
                }
            }
            Equal(e1, e2) => (e1 == e2).into(),
            Rem(e1, e2) => {
                if e2 == F::ZERO {
                    EvalFailure::err_ub("Modulo by 0")?
                } else {
                    // Unsigned Euclidean remainder: e1 - (e1 / e2) * e2.
                    let div = e1.unsigned_euclidean_division(e2);
                    e1 - div * e2
                }
            }
            Reveal(e) => e,
            Abs(e) => e.abs(),
            LogicalRightShift(e, s) => (e.to_unsigned_number() >> s).into(),
            KeepLsBits(e, s, signed_output) => {
                // Mask the signed value down to its low `s` bits.
                let mut temp = e.to_signed_number() & (Number::power_of_two(s) - 1);
                // For a signed output, a set bit `s - 1` means a negative value: subtract `2^s`.
                if signed_output && s >= 1 && temp >= Number::power_of_two(s - 1) {
                    temp = temp - Number::power_of_two(s);
                }
                temp.into()
            }
            Div(e1, e2) => {
                if e2 == F::ZERO {
                    EvalFailure::err_ub("Division by 0")?
                } else {
                    e1.unsigned_euclidean_division(e2)
                }
            }
            Bounds(e, b) => {
                // Runtime check that the value lies in the declared bounds.
                if b.contains(e) {
                    e
                } else {
                    EvalFailure::err_bounds(format!("Bounds input {e:?} should be in {b:?}"))?
                }
            }
            Sub(e1, e2) => e1 - e2,
            Neg(e) => -e,
            FieldInverse(e) => {
                e.invert(true)
            }
            // Evaluate the referenced circuit on `v` and take output `i` (zero if absent).
            SubCircuit(v, c, i) => c.to_circuit().eval(v)?.get(i).cloned().unwrap_or(F::ZERO),
            RandomVal(_) => EvalFailure::err_imp("RandomVal not evaluable here")?,
            Sqrt(v) => {
                // `sqrt` reports whether `v` has a square root, together with the root.
                let (is_real, res) = sqrt::<F, bool, F>(v, false);
                if is_real {
                    res
                } else {
                    EvalFailure::err_ub("Non-quadratic residue.")?
                }
            }
            // The third field of `Pow` is not used by evaluation.
            Pow(v, e, _) => {
                v.pow(&e, true)
            }
            Cap(x, n) => {
                // The value must be strictly below 2^n.
                if x < F::power_of_two(n) {
                    x
                } else {
                    EvalFailure::err_ub("Input of capping failed")?
                }
            }
            KeyRecoveryComputeErrors(d_minus_one, syndromes, i) => {
                // Reject out-of-range parameters before decoding. These checks are statements,
                // so they return the error directly instead of using `?`.
                if d_minus_one.ge(&F::from(MXE_KEY_RECOVERY_D as u64)) {
                    return EvalFailure::err_ub("d_minus_one too large");
                }
                if i >= MXE_KEY_RECOVERY_N {
                    return EvalFailure::err_ub("i too large");
                }
                KeyRecoveryReedSolomonFinal::compute_errors_field::<MXE_KEY_RECOVERY_N, F>(
                    d_minus_one,
                    syndromes,
                )[i]
            }
        };
        Ok(val)
    }
}
/// Bounds of `Equal(x, y)` for `x` in `b1` and `y` in `b2`.
///
/// - Two singleton bounds decide the answer exactly.
/// - Disjoint bounds can never be equal, giving `{0}`.
/// - Anything else may go either way, giving `[0, 1]`.
fn equal_bounds<F: UsedField>(b1: FieldBounds<F>, b2: FieldBounds<F>) -> FieldBounds<F> {
    if let (Some(lhs), Some(rhs)) = (b1.as_constant(), b2.as_constant()) {
        let outcome = if lhs == rhs { F::ONE } else { F::ZERO };
        return FieldBounds::from(outcome);
    }
    let overlap = b1.inter(b2);
    if overlap.is_empty() {
        FieldBounds::from(F::ZERO)
    } else {
        FieldBounds::new(F::ZERO, F::ONE)
    }
}
/// Bounds of the unsigned Euclidean quotient `x / y` for `x` in `b1` and `y` in `b2`.
///
/// The smallest quotient pairs the smallest dividend with the largest divisor; the largest
/// quotient pairs the largest dividend with the smallest divisor. Division by zero is
/// undefined, so each divisor endpoint is clamped to at least 1.
pub fn div_bounds<F: UsedField>(b1: FieldBounds<F>, b2: FieldBounds<F>) -> FieldBounds<F> {
    let (dividend_lo, dividend_hi) = b1.to_unsigned_number_pair();
    let (divisor_lo, divisor_hi) = b2.to_unsigned_number_pair();
    let smallest_quotient = dividend_lo / divisor_hi.max(1.into());
    let largest_quotient = dividend_hi / divisor_lo.max(1.into());
    FieldBounds::new(smallest_quotient.into(), largest_quotient.into())
}
/// Bounds of the unsigned remainder `x % y` for `x` in `b1` and `y` in `b2`.
///
/// Modulo by zero is undefined behavior during evaluation; here it is reported as `{0}`.
fn rem_bounds<F: UsedField>(b1: FieldBounds<F>, b2: FieldBounds<F>) -> FieldBounds<F> {
    match (b1.as_constant(), b2.as_constant()) {
        // Constant zero modulus: undefined, reported as the constant 0.
        (Some(_), Some(modulus)) if modulus == F::ZERO => FieldBounds::from(F::ZERO),
        // Both operands known: compute the unsigned remainder exactly.
        (Some(dividend), Some(modulus)) => {
            let remainder: F =
                (dividend.to_unsigned_number() % modulus.to_unsigned_number()).into();
            FieldBounds::from(remainder)
        }
        // Otherwise the remainder is below the largest possible modulus.
        _ => {
            let largest_modulus = b2.unsigned_max();
            if largest_modulus == F::ZERO {
                FieldBounds::from(F::ZERO)
            } else {
                FieldBounds::new(F::ZERO, largest_modulus - F::ONE)
            }
        }
    }
}
/// Bounds of `x >> c` for `x` in `b`.
///
/// The endpoints of `b` are read as signed or unsigned numbers depending on `signed`, then
/// each is shifted right by `c` bits. Shifting preserves order, so the shifted endpoints bound
/// the result.
pub fn shr_bounds<F: UsedField>(b: FieldBounds<F>, c: usize, signed: bool) -> FieldBounds<F> {
    let (lower, upper) = if !signed {
        b.to_unsigned_number_pair()
    } else {
        b.to_signed_number_pair()
    };
    let shifted_lower = lower >> c;
    let shifted_upper = upper >> c;
    FieldBounds::new(shifted_lower.into(), shifted_upper.into())
}
/// Bounds of `KeepLsBits(x, c, signed_output)` for `x` ranging over `b`.
///
/// `x` is read as a signed value, matching `eval`, which masks `to_signed_number()`.
///
/// Cases:
/// - `c == 0` keeps no bits, so the result is exactly `{0}`.
/// - The signed range of `b` spans fewer than `2^c` values, and the kept bits of its two
///   endpoints come out in increasing order. Then no wrap-around modulo `2^c` happens inside
///   the range. The kept bits grow with `x`, so the endpoints' images bound the result.
/// - Otherwise the full output range is returned: `[-2^(c-1), 2^(c-1) - 1]` for a signed
///   output, `[0, 2^c - 1]` for an unsigned one.
pub fn keep_ls_bounds<F: ActuallyUsedField>(
    b: FieldBounds<F>,
    c: usize,
    signed_output: bool,
) -> FieldBounds<F> {
    if c == 0 {
        return FieldBounds::new(F::ZERO, F::ZERO);
    }
    // Reuse the evaluator so that bounds and evaluation cannot disagree on the bit semantics.
    // The `KeepLsBits` arm of `eval` never returns an error.
    let keep_low_bits = |x: F| {
        FieldExpr::KeepLsBits(x, c, signed_output)
            .eval()
            .expect("KeepLsBits evaluation always succeeds.")
    };
    let (min_b, max_b) = b.min_and_max(true);
    if max_b - min_b < F::power_of_two(c) {
        let min_res = keep_low_bits(min_b);
        let max_res = keep_low_bits(max_b);
        // A decreasing pair means the range wrapped modulo 2^c; fall back to the full range.
        if (max_res - min_res).is_ge_zero() {
            return FieldBounds::new(min_res, max_res);
        }
    }
    if signed_output {
        FieldBounds::new(
            F::negative_power_of_two(c - 1),
            F::power_of_two(c - 1) - F::ONE,
        )
    } else {
        FieldBounds::new(F::ZERO, F::power_of_two(c) - F::ONE)
    }
}
impl<F: ActuallyUsedField> FieldExpr<F, FieldBounds<F>> {
    /// Bounds on the values this node can take, given bounds on each of its dependencies.
    ///
    /// `FieldBounds::All` is returned when nothing tighter is known.
    ///
    /// # Panics
    ///
    /// For `KeyRecoveryComputeErrors` whose inputs are all constant, panics (via `assert!`) in
    /// two cases:
    /// - `d_minus_one` is not below `MXE_KEY_RECOVERY_D`;
    /// - `i` is not below `MXE_KEY_RECOVERY_N`.
    pub fn bounds(self) -> FieldBounds<F> {
        use FieldExpr::*;
        match self {
            Input(_, info) => FieldBounds::new(info.min, info.max),
            Add(lhs, rhs) => lhs + rhs,
            Sub(lhs, rhs) => {
                let negated = -rhs;
                lhs + negated
            }
            Mul(lhs, rhs) => lhs * rhs,
            LinComb(terms, constant) => terms
                .into_iter()
                .fold(FieldBounds::from(constant), |acc, (term, factor)| {
                    acc + term * factor
                }),
            Gt(lhs, rhs, signed) => {
                let (lhs_lo, lhs_hi) = lhs.min_and_max(signed);
                let (rhs_lo, rhs_hi) = rhs.min_and_max(signed);
                // The result is smallest when lhs is smallest and rhs is largest, and
                // largest in the opposite case.
                let lowest = Gt(lhs_lo, rhs_hi, signed)
                    .eval()
                    .expect("Comparisons cannot fail.");
                let highest = Gt(lhs_hi, rhs_lo, signed)
                    .eval()
                    .expect("Comparisons cannot fail.");
                FieldBounds::new(lowest, highest)
            }
            Ge(lhs, rhs, signed) => {
                let (lhs_lo, lhs_hi) = lhs.min_and_max(signed);
                let (rhs_lo, rhs_hi) = rhs.min_and_max(signed);
                let lowest = Ge(lhs_lo, rhs_hi, signed)
                    .eval()
                    .expect("Comparisons cannot fail.");
                let highest = Ge(lhs_hi, rhs_lo, signed)
                    .eval()
                    .expect("Comparisons cannot fail.");
                FieldBounds::new(lowest, highest)
            }
            Rem(dividend, modulus) => rem_bounds(dividend, modulus),
            Reveal(inner) => inner,
            Val(constant) => FieldBounds::from(constant),
            Where(cond, if_true, if_false) => {
                // A condition that can be neither 1 nor 0 would be UB at evaluation time,
                // so such a node gets empty bounds.
                match (cond.contains(F::ONE), cond.contains(F::ZERO)) {
                    (true, true) => if_true.union(if_false),
                    (true, false) => if_true,
                    (false, true) => if_false,
                    (false, false) => FieldBounds::Empty,
                }
            }
            Equal(lhs, rhs) => equal_bounds(lhs, rhs),
            Neg(inner) => -inner,
            Abs(inner) => {
                let smallest = inner.min_abs();
                let largest = inner.max_abs();
                FieldBounds::new(smallest, largest)
            }
            LogicalRightShift(inner, shift) => shr_bounds(inner, shift, false),
            KeepLsBits(inner, bits, signed_output) => keep_ls_bounds(inner, bits, signed_output),
            Div(dividend, divisor) => div_bounds(dividend, divisor),
            Bounds(inner, constraint) => inner.inter(constraint),
            FieldInverse(inner) => match inner.as_constant() {
                Some(x) => FieldBounds::from(x.invert(true)),
                None => FieldBounds::All,
            },
            SubCircuit(inputs, circuit_id, output_idx) => circuit_id
                .to_circuit()
                .bounds(inputs)
                .get(output_idx)
                .cloned()
                .unwrap_or(FieldBounds::new(F::ZERO, F::ZERO)),
            RandomVal(_) => FieldBounds::All,
            Sqrt(inner) => match inner.as_constant() {
                Some(x) => {
                    let (_, root) = sqrt::<F, bool, F>(x, false);
                    FieldBounds::from(root)
                }
                None => FieldBounds::All,
            },
            Pow(base, exponent, _) => match base.as_constant() {
                Some(x) => FieldBounds::from(x.pow(&exponent, true)),
                None => FieldBounds::All,
            },
            Cap(inner, bits) => {
                let cap = F::power_of_two(bits) - F::ONE;
                if inner.unsigned_max() <= cap {
                    inner
                } else {
                    FieldBounds::new(F::ZERO, cap)
                }
            }
            KeyRecoveryComputeErrors(d_minus_one, syndromes, i) => {
                // Only fully-constant inputs can be folded; otherwise nothing is known.
                let constants: Option<Vec<F>> = std::iter::once(d_minus_one)
                    .chain(syndromes)
                    .map(|b| b.as_constant())
                    .collect();
                match constants {
                    Some(mut values) => {
                        // `values` is never empty: it always starts with `d_minus_one`.
                        let d_minus_one = values.remove(0);
                        let syndromes = values;
                        assert!(d_minus_one.lt(&F::from(MXE_KEY_RECOVERY_D as u64)));
                        assert!(i < MXE_KEY_RECOVERY_N);
                        let errors = KeyRecoveryReedSolomonFinal::compute_errors_field::<
                            MXE_KEY_RECOVERY_N,
                            F,
                        >(d_minus_one, syndromes);
                        FieldBounds::from(errors[i])
                    }
                    None => FieldBounds::All,
                }
            }
        }
    }
}
/// Builds linear-combination terms, converting every factor with `.into()`.
///
/// - `expr_lincomb!((e1, f1), (e2, f2))` expands to the bare vector
///   `vec![(e1, f1.into()), (e2, f2.into())]`.
/// - `expr_lincomb!((e1, f1), (e2, f2); c)` expands to
///   `FieldExpr::LinComb(vec![...], c.into())`.
///
/// The expansion names `FieldExpr` unqualified, so it must be in scope at the call site.
macro_rules! expr_lincomb {
    // Terms only: the `(expr, factor)` vector, not wrapped in a `FieldExpr`.
    ($(($e:expr, $factor:expr)),*) => (vec![$(($e, $factor.into())),*]);
    // Terms followed by `; constant`: a complete `FieldExpr::LinComb` node.
    ($(($e:expr, $factor:expr)),*; $c: expr) => (FieldExpr::LinComb(vec![$(($e, $factor.into())),*], $c.into()));
}
// Re-export the macro by path so other modules of this crate can `use` it.
pub(crate) use expr_lincomb;
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::utils::number::Number;
    use rand::Rng;
    impl<F: UsedField> InputInfo<F> {
        /// Random secret-input metadata.
        ///
        /// The declared range `[min, max]` is drawn inside `[lower, upper - 1]`.
        pub fn generate<R: Rng + ?Sized>(
            rng: &mut R,
            lower: &Number,
            upper: &Number,
        ) -> Rc<InputInfo<F>> {
            let lo: F = lower.clone().into();
            let hi: F = (upper - 1).into();
            let first = F::gen_inclusive_range(rng, lo, hi);
            let second = F::gen_inclusive_range(rng, lo, hi);
            // Order the two draws by their offset from `lo` rather than by raw value,
            // presumably so the ordering stays right if the range straddles the modulus.
            let (min, max) = if second - lo < first - lo {
                (second, first)
            } else {
                (first, second)
            };
            let info = InputInfo {
                kind: InputKind::Secret,
                min,
                max,
                ..InputInfo::default()
            };
            Rc::new(info)
        }
    }
}