use crate::{
core::{
bounds::{BoolBounds, Bounds, CurveBounds, FieldBounds, IsBounds},
expressions::{
circuit::{BaseCircuitId, GeneralCircuitId},
expr::{EvalFailure, EvalValue},
field_expr::InputId,
},
mxe_input::{ArxInput, MxeInput},
},
traits::FromLeBytes,
utils::{
curve_point::CurvePoint,
elliptic_curve::{EDWARDS25519_D, FOUR_INV_MOD_ELL},
field::{BaseField, ScalarField},
},
};
use arcis_internal_expr_macro::Expr;
use core::marker::PhantomData;
use curve25519_dalek_arcium_fork::{field::FieldElement, EdwardsPoint, RistrettoPoint, Scalar};
use ff::{Field, PrimeField};
use serde::{Deserialize, Serialize};
use std::cell::Cell;
/// Miscellaneous expression nodes: outputs of precompiled circuits, conversions between
/// curve points and field-element coordinates, point-encoding bits, and MXE key inputs.
///
/// The type parameters are the payloads held in scalar-field (`T`), base-field (`B`) and
/// curve (`C`) positions; `B` and `C` default to `T`. This file instantiates the enum with
/// `ScalarField`/`BaseField`/`CurvePoint` for plaintext `eval`, with
/// `FieldBounds<ScalarField>`/`FieldBounds<BaseField>`/`CurveBounds` for `bounds`, and with
/// `bool` everywhere for `is_plaintext`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Expr)]
// NOTE(review): presumably silences clippy's shared-affix warning on the variant names
// (several end in `Circuit`) — confirm.
#[allow(clippy::enum_variant_names)]
pub enum OtherExpr<T: Clone, B: Clone = T, C: Clone = T> {
    /// Scalar-field output number `usize` of the general circuit `GeneralCircuitId`,
    /// run on the scalar inputs `Vec<T>` and base-field inputs `Vec<B>`.
    /// Evaluates to zero when the index is past the last scalar output.
    ScalarGeneralCircuit(Vec<T>, Vec<B>, GeneralCircuitId, usize),
    /// Same as `ScalarGeneralCircuit`, but selects from the circuit's base-field outputs.
    BaseGeneralCircuit(Vec<T>, Vec<B>, GeneralCircuitId, usize),
    /// Output number `usize` of the base-field circuit `BaseCircuitId` run on `Vec<B>`;
    /// zero when the index is out of range.
    BaseArithmeticCircuit(Vec<B>, BaseCircuitId, usize),
    /// Coordinate `usize` of a curve point in extended Edwards coordinates normalised to
    /// Z = 1 (0: X/Z, 1: Y/Z, 2: 1, 3: T/Z, larger: 0). `eval` first multiplies the point by
    /// 4 and then by the scalar `FOUR_INV_MOD_ELL` (the code's "prime-order representative").
    PlaintextCurveToExtendedEdwards(C, usize),
    /// Builds a curve point from extended Edwards coordinates `(X, Y, Z, T)`. Inputs that do
    /// not satisfy the curve equation, `X·Y = Z·T`, and `Z ≠ 0` evaluate to the identity.
    CurveFromPlaintextExtendedEdwards(B, B, B, B),
    /// Bit `usize` of the point's `to_bytes()` encoding, counted from the least significant
    /// bit of each byte. `eval` asserts the index is below 256.
    CompressPlaintextPoint(C, usize),
    /// The X, Y, Z coordinates (indices 0..=2) of `PlaintextCurveToExtendedEdwards`;
    /// larger indices evaluate to zero.
    ToProjective(C, usize),
    /// An MXE key input. It has no `T`/`B`/`C` payload; plaintext `eval` uses
    /// `MxeInput::mock_eval`, and it is the one variant not reported as a deterministic
    /// function of its dependencies.
    MxeKey(MxeInput),
}
impl OtherExpr<ScalarField, BaseField, CurvePoint> {
#[allow(non_snake_case)]
pub fn eval(self) -> Result<EvalValue, EvalFailure> {
match self {
OtherExpr::ScalarGeneralCircuit(s, b, c, i) => {
let (res, _) = c.to_circuit().eval(s, b)?;
Ok(if i < res.len() {
EvalValue::Scalar(res[i])
} else {
EvalValue::Scalar(ScalarField::ZERO)
})
}
OtherExpr::BaseGeneralCircuit(s, b, c, i) => {
let (_, res) = c.to_circuit().eval(s, b)?;
Ok(if i < res.len() {
EvalValue::Base(res[i])
} else {
EvalValue::Base(BaseField::ZERO)
})
}
OtherExpr::BaseArithmeticCircuit(v, c, i) => Ok(EvalValue::Base(
c.to_circuit()
.eval(v)?
.get(i)
.cloned()
.unwrap_or(BaseField::ZERO),
)),
OtherExpr::PlaintextCurveToExtendedEdwards(c, i) => {
let edwards_point = c.inner().inner();
let prime_order_repr = Scalar::from_canonical_bytes(FOUR_INV_MOD_ELL).unwrap()
* (Scalar::from(4u8) * edwards_point);
if prime_order_repr.Z() == FieldElement::ZERO {
Err(EvalFailure::ub("Point at infinity"))
} else {
let Z_inv = prime_order_repr.Z().invert();
Ok(match i {
0 => EvalValue::Base(BaseField::from_le_bytes(
(&prime_order_repr.X() * &Z_inv).as_bytes(),
)),
1 => EvalValue::Base(BaseField::from_le_bytes(
(&prime_order_repr.Y() * &Z_inv).as_bytes(),
)),
2 => EvalValue::Base(BaseField::from(1)),
3 => EvalValue::Base(BaseField::from_le_bytes(
(&prime_order_repr.T() * &Z_inv).as_bytes(),
)),
_ => EvalValue::Base(BaseField::ZERO),
})
}
}
OtherExpr::CurveFromPlaintextExtendedEdwards(X, Y, Z, T) => {
if -X * X + Y * Y == Z * Z + BaseField::from_le_bytes(EDWARDS25519_D) * T * T
&& X * Y == Z * T
&& Z != BaseField::ZERO
{
let edwards_point = EdwardsPoint::new_unchecked(
FieldElement::from_bytes(&X.to_le_bytes()),
FieldElement::from_bytes(&Y.to_le_bytes()),
FieldElement::from_bytes(&Z.to_le_bytes()),
FieldElement::from_bytes(&T.to_le_bytes()),
);
let ristretto_point = RistrettoPoint::new_unchecked(
(edwards_point * Scalar::from(2u8)) * Scalar::TWO_INV,
);
Ok(EvalValue::Curve(CurvePoint::new(ristretto_point)))
} else {
Ok(EvalValue::Curve(CurvePoint::identity()))
}
}
OtherExpr::CompressPlaintextPoint(c, i) => {
assert!(i < 256);
let compressed = c.to_bytes();
Ok(EvalValue::Bit(compressed[i / 8] >> (i % 8) & 1u8 == 1u8))
}
OtherExpr::ToProjective(c, idx) => {
if idx < 3 {
OtherExpr::PlaintextCurveToExtendedEdwards(c, idx).eval()
} else {
Ok(EvalValue::Base(BaseField::ZERO))
}
}
OtherExpr::MxeKey(k) => Ok(k.mock_eval()),
}
}
}
impl OtherExpr<bool> {
    /// Reports whether this node is plaintext.
    ///
    /// `MxeKey` defers to `MxeInput::is_plaintext`; every other variant is plaintext exactly
    /// when all of its dependency flags (from `get_deps`) are `true`. Presumably each `bool`
    /// payload marks whether the corresponding dependency is plaintext — confirm against the
    /// `Expr` derive.
    pub fn is_plaintext(&self) -> bool {
        match self {
            OtherExpr::MxeKey(key) => key.is_plaintext(),
            _ => self.get_deps().iter().all(|&dep_is_plain| dep_is_plain),
        }
    }
}
impl<T: Clone, B: Clone, C: Clone> OtherExpr<T, B, C> {
    /// Returns `true` for every variant except `MxeKey`.
    ///
    /// `MxeKey` carries no `T`/`B`/`C` dependency payload and its plaintext `eval` goes
    /// through `MxeInput::mock_eval`, so it is not treated as a function of its deps.
    pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
        let is_mxe_key = matches!(self, OtherExpr::MxeKey(_));
        !is_mxe_key
    }

    /// Always `None`: these nodes never report an `InputId`.
    pub fn get_input(&self) -> Option<InputId> {
        None
    }

    /// Always the empty string, since these nodes have no input name.
    pub fn get_input_name(&self) -> &str {
        ""
    }

    /// Always `None`: there is no input whose optimised-out flag could be tracked.
    pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
        None
    }
}
impl OtherExpr<FieldBounds<ScalarField>, FieldBounds<BaseField>, CurveBounds> {
#[allow(non_snake_case)]
pub fn bounds(self) -> Bounds {
match self {
OtherExpr::ScalarGeneralCircuit(s, b, c, i) => {
let (res, _) = c.to_circuit().bounds(s, b);
Bounds::Scalar(if i < res.len() {
res[i]
} else {
FieldBounds::new(ScalarField::ZERO, ScalarField::ZERO)
})
}
OtherExpr::BaseGeneralCircuit(s, b, c, i) => {
let (_, res) = c.to_circuit().bounds(s, b);
Bounds::Base(if i < res.len() {
res[i]
} else {
FieldBounds::new(BaseField::ZERO, BaseField::ZERO)
})
}
OtherExpr::BaseArithmeticCircuit(v, c, i) => Bounds::Base(
c.to_circuit()
.bounds(v)
.get(i)
.cloned()
.unwrap_or(FieldBounds::new(BaseField::ZERO, BaseField::ZERO)),
),
OtherExpr::PlaintextCurveToExtendedEdwards(c, i) => {
let c_prime_order_repr = c.as_constant().map(|point| {
CurvePoint::new(
Scalar::from_canonical_bytes(FOUR_INV_MOD_ELL).unwrap()
* (Scalar::from(4u8) * point.inner()),
)
});
match (c_prime_order_repr, i) {
(Some(point), _) if point.inner().inner().Z() == FieldElement::ZERO => {
Bounds::Base(FieldBounds::from(BaseField::ZERO))
}
(Some(point), 0) => Bounds::Base(FieldBounds::from(BaseField::from_le_bytes(
(&point.inner().inner().X() * &point.inner().inner().Z().invert())
.as_bytes(),
))),
(Some(point), 1) => Bounds::Base(FieldBounds::from(BaseField::from_le_bytes(
(&point.inner().inner().Y() * &point.inner().inner().Z().invert())
.as_bytes(),
))),
(Some(_), 2) => Bounds::Base(FieldBounds::from(BaseField::from(1))),
(Some(point), 3) => Bounds::Base(FieldBounds::from(BaseField::from_le_bytes(
(&point.inner().inner().T() * &point.inner().inner().Z().invert())
.as_bytes(),
))),
(_, t) if t >= 4 => Bounds::Base(FieldBounds::from(BaseField::ZERO)),
_ => Bounds::Base(FieldBounds::All),
}
}
OtherExpr::CurveFromPlaintextExtendedEdwards(X, Y, Z, T) => {
match (
X.as_constant(),
Y.as_constant(),
Z.as_constant(),
T.as_constant(),
) {
(Some(X), Some(Y), Some(Z), Some(T)) => {
if -X * X + Y * Y
== Z * Z + BaseField::from_le_bytes(EDWARDS25519_D) * T * T
&& X * Y == Z * T
&& Z != BaseField::ZERO
{
let edwards_point = EdwardsPoint::new_unchecked(
FieldElement::from_bytes(&X.to_le_bytes()),
FieldElement::from_bytes(&Y.to_le_bytes()),
FieldElement::from_bytes(&Z.to_le_bytes()),
FieldElement::from_bytes(&T.to_le_bytes()),
);
let ristretto_point = RistrettoPoint::new_unchecked(
(edwards_point * Scalar::from(2u8)) * Scalar::TWO_INV,
);
Bounds::Curve(CurveBounds::Constant(CurvePoint::new(ristretto_point)))
} else {
Bounds::Curve(CurveBounds::Constant(CurvePoint::identity()))
}
}
_ => Bounds::Curve(CurveBounds::All),
}
}
OtherExpr::CompressPlaintextPoint(c, i) => match c.as_constant() {
Some(point) => Bounds::Bit(BoolBounds::from(
point.to_bytes()[i / 8] >> (i % 8) & 1u8 == 1u8,
)),
None => Bounds::Bit(BoolBounds::new(true, true)),
},
OtherExpr::ToProjective(c, idx) => match c.as_constant() {
Some(point) => {
let v = OtherExpr::ToProjective(point, idx).eval();
match v {
Ok(EvalValue::Base(v)) => Bounds::Base(FieldBounds::new(v, v)),
_ => Bounds::Base(FieldBounds::All),
}
}
None => Bounds::Base(FieldBounds::All),
},
OtherExpr::MxeKey(k) => k.bounds(),
}
}
}