use crate::{
core::{
bounds::{CurveBounds, FieldBounds, IsBounds},
expressions::{
expr::EvalFailure,
macro_uses::{BoundWrap, ConstantBoundSampler},
InputKind,
},
},
utils::{curve_point::CurvePoint, field::ScalarField, ignore_for_equality::IgnoreForEquality},
};
use arcis_internal_expr_macro::Expr;
use primitives::algebra::{
elliptic_curve::{Curve as AsyncMPCCurve, Curve25519Ristretto},
field::subfield_element::SubfieldElement,
};
use serde::{Deserialize, Serialize};
use std::{cell::Cell, marker::PhantomData, rc::Rc};
/// Identifier stored in a [`CurveExpr::Input`] node and returned by `CurveExpr::get_input`.
pub type InputId = usize;
/// Metadata carried (behind an `Rc`) by every input node of a [`CurveExpr`].
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct InputInfo {
/// How the input is supplied; mapped to a circuit input descriptor by `to_async_mpc_input`.
pub kind: InputKind,
/// Human-readable name, returned by `CurveExpr::get_input_name`.
pub name: String,
/// Flag exposed via `CurveExpr::get_is_input_already_optimized_out`. It is a `Cell`, so it can
/// be updated through a shared `Rc<InputInfo>`. Wrapped in `IgnoreForEquality` — presumably so
/// it does not take part in equality/hashing; confirm against that type's impls.
pub has_already_been_found_unused: IgnoreForEquality<Cell<bool>>,
}
impl InputInfo {
    /// Translates this input's [`InputKind`] into the matching async-MPC circuit input
    /// descriptor over Curve25519 Ristretto.
    ///
    /// * `Secret` becomes an unbatched point `Share`.
    /// * `SecretFromPlayer(i)` becomes an unbatched point `SecretPlaintext` whose `inputer` is `i`.
    /// * `Plaintext` becomes a plaintext point input.
    pub fn to_async_mpc_input(&self) -> core_utils::circuit::Input<Curve25519Ristretto> {
        use core_utils::circuit::{AlgebraicType, Batched, Input, PointPlaintext};

        match self.kind {
            // NOTE(review): the meaning of the `1` passed to `PointPlaintext::Input` is defined
            // by core_utils — confirm there.
            InputKind::Plaintext => Input::Point(PointPlaintext::<Curve25519Ristretto>::Input(1)),
            InputKind::SecretFromPlayer(inputer) => Input::SecretPlaintext {
                inputer,
                algebraic_type: AlgebraicType::Point,
                batched: Batched::No,
            },
            InputKind::Secret => Input::Share {
                algebraic_type: AlgebraicType::Point,
                batched: Batched::No,
            },
        }
    }
}
impl Default for InputInfo {
    /// Builds a plaintext input named `"_"` whose "already found unused" flag starts cleared.
    fn default() -> Self {
        Self {
            kind: InputKind::Plaintext,
            name: String::from("_"),
            has_already_been_found_unused: IgnoreForEquality(Cell::new(false)),
        }
    }
}
impl InputInfo {
    /// Reports whether this input is plaintext, deferring entirely to its [`InputKind`].
    pub fn is_plaintext(&self) -> bool {
        let kind = &self.kind;
        kind.is_plaintext()
    }
}
impl From<InputKind> for InputInfo {
fn from(value: InputKind) -> Self {
InputInfo {
kind: value,
..InputInfo::default()
}
}
}
/// One node of a curve-point expression.
///
/// `C` is the type used for point-valued operands and `S` the type used for scalar-valued
/// operands (defaults to `C`). The impls below instantiate them as `bool` (plaintext-ness),
/// concrete `CurvePoint`/`ScalarField` values, or bounds. The `Expr` derive presumably generates
/// the `get_deps` / `ApplyCurveExpr` machinery used below — confirm in `arcis_internal_expr_macro`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, Expr)]
pub enum CurveExpr<C: Clone, S: Clone = C> {
/// Reference to an input, identified by id and carrying its metadata; `eval` rejects it.
Input(InputId, Rc<InputInfo>),
/// Sum of two point operands.
Add(C, C),
/// Negation of a point operand.
Neg(C),
/// Scalar operand times point operand.
Mul(S, C),
/// Reveal of a point operand; `eval` returns the operand unchanged and `is_plaintext` is `true`.
Reveal(C),
/// Constant curve point.
Val(CurvePoint),
}
impl CurveExpr<bool> {
    /// Decides whether this node's value is plaintext, given whether each of its dependencies is.
    ///
    /// * `Input` nodes are plaintext exactly when their [`InputInfo`] says so.
    /// * `Reveal` and `Val` nodes are always plaintext.
    /// * Arithmetic nodes (`Add`, `Neg`, `Mul`) are plaintext only if all their dependencies are.
    pub fn is_plaintext(&self) -> bool {
        match self {
            CurveExpr::Input(_, info) => info.is_plaintext(),
            CurveExpr::Reveal(_) | CurveExpr::Val(_) => true,
            // Variants are listed explicitly rather than with a `_` wildcard. Adding a new
            // `CurveExpr` variant then fails to compile here, forcing a deliberate decision
            // instead of silently inheriting the "all deps plaintext" rule.
            CurveExpr::Add(..) | CurveExpr::Neg(..) | CurveExpr::Mul(..) => {
                self.get_deps().iter().all(|x| *x)
            }
        }
    }
}
impl<T: Clone, S: Clone> CurveExpr<T, S> {
    /// Returns `true` for every node except an `Input` leaf. Those nodes' values are produced
    /// by evaluating them from their dependencies.
    pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
        self.get_input().is_none()
    }

    /// Returns the input id when this node is an `Input`, otherwise `None`.
    pub fn get_input(&self) -> Option<InputId> {
        if let CurveExpr::Input(id, _) = self {
            Some(*id)
        } else {
            None
        }
    }

    /// Returns the input's name when this node is an `Input`, otherwise the empty string.
    pub fn get_input_name(&self) -> &str {
        if let CurveExpr::Input(_, info) = self {
            info.name.as_str()
        } else {
            ""
        }
    }

    /// Returns the input's `has_already_been_found_unused` flag when this node is an `Input`,
    /// otherwise `None`.
    pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
        if let CurveExpr::Input(_, info) = self {
            Some(&info.has_already_been_found_unused.0)
        } else {
            None
        }
    }
}
impl CurveExpr<CurvePoint, ScalarField> {
    /// Evaluates one node whose operands are already concrete values.
    ///
    /// `Reveal` and `Val` yield their point unchanged. `Neg`, `Add` and `Mul` apply the
    /// corresponding curve operation; for `Mul`, the scalar is first converted to the curve's
    /// scalar type.
    ///
    /// # Errors
    /// Returns an [`EvalFailure`] for `Input` nodes, which cannot be evaluated here.
    ///
    /// # Panics
    /// Panics if the scalar's little-endian bytes are rejected by `from_canonical_bytes`.
    pub fn eval(self) -> Result<CurvePoint, EvalFailure> {
        let point: CurvePoint = match self {
            CurveExpr::Val(constant) => constant,
            CurveExpr::Reveal(revealed) => revealed,
            CurveExpr::Neg(operand) => -operand,
            CurveExpr::Add(lhs, rhs) => lhs + rhs,
            CurveExpr::Mul(scalar, base) => {
                let le_bytes = scalar.to_le_bytes();
                let canonical =
                    <Curve25519Ristretto as AsyncMPCCurve>::Scalar::from_canonical_bytes(le_bytes)
                        .unwrap();
                let factor = SubfieldElement::new(canonical);
                base * factor
            }
            CurveExpr::Input(_, _) => EvalFailure::err_imp("Input not evaluable here")?,
        };
        Ok(point)
    }
}
impl CurveExpr<CurveBounds, FieldBounds<ScalarField>> {
    /// Propagates bounds through one node, given the bounds of its operands.
    ///
    /// * `Input` nodes are unconstrained: `CurveBounds::All`.
    /// * If any operand bound is empty, the result is `CurveBounds::Empty`.
    /// * If every operand bound is a single constant, the node is evaluated on those constants
    ///   and the result is `CurveBounds::Constant`.
    /// * Otherwise the result is `CurveBounds::All`.
    ///
    /// # Panics
    /// Panics if evaluating the all-constant node fails.
    pub fn bounds(self) -> CurveBounds {
        if matches!(self, CurveExpr::Input(..)) {
            return CurveBounds::All;
        }
        let operand_bounds = ApplyCurveExpr::apply(&mut BoundWrap, self.clone()).get_deps();
        if operand_bounds.iter().any(|b| b.is_empty()) {
            return CurveBounds::Empty;
        }
        let every_operand_constant = operand_bounds.iter().all(|b| b.as_constant().is_some());
        if !every_operand_constant {
            return CurveBounds::All;
        }
        let constants = ApplyCurveExpr::apply(&mut ConstantBoundSampler, self);
        CurveBounds::Constant(constants.eval().expect("Can't evaluate constant CurveExpr"))
    }
}