use crate::utils::{
curve_point::CurvePoint,
field::{BaseField, ScalarField},
};
use serde::{Deserialize, Serialize};
/// A value tagged with one of four domains: bit, scalar, base or curve.
///
/// Every variant has its own payload type parameter, so the enum can carry
/// different kinds of data per domain. Instantiating all four parameters with
/// `()` leaves only the domain tag, which is what `to_domain` produces.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DomainElement<T, U, V, W> {
    /// Bit domain; the `Domain` impl for `bool` wraps into this variant.
    Bit(T),
    /// Scalar domain; the `Domain` impl for `ScalarField` wraps into this variant.
    Scalar(U),
    /// Base domain; the `Domain` impl for `BaseField` wraps into this variant.
    Base(V),
    /// Curve domain; the `Domain` impl for `CurvePoint` wraps into this variant.
    Curve(W),
}
impl<T, U, V, W> DomainElement<T, U, V, W> {
    /// Reports whether `self` and `other` are the same variant.
    ///
    /// Only the variant is compared; payloads are never inspected, so the two
    /// elements may use entirely different payload types.
    pub fn same_domain<T_, U_, V_, W_>(&self, other: &DomainElement<T_, U_, V_, W_>) -> bool {
        // Matching on `self` alone keeps the match exhaustive without a
        // catch-all arm, so a new variant cannot be silently ignored here.
        match self {
            Self::Bit(_) => matches!(other, DomainElement::Bit(_)),
            Self::Scalar(_) => matches!(other, DomainElement::Scalar(_)),
            Self::Base(_) => matches!(other, DomainElement::Base(_)),
            Self::Curve(_) => matches!(other, DomainElement::Curve(_)),
        }
    }

    /// Returns the variant of `self` with its payload replaced by `()`.
    pub fn to_domain(&self) -> DomainElement<(), (), (), ()> {
        match self {
            Self::Bit(_) => DomainElement::Bit(()),
            Self::Scalar(_) => DomainElement::Scalar(()),
            Self::Base(_) => DomainElement::Base(()),
            Self::Curve(_) => DomainElement::Curve(()),
        }
    }
}
/// Ties an implementing type to one variant of [`DomainElement`].
///
/// `T`, `U`, `V` and `W` are the payload types of the element's four variants.
/// The implementations in this file each pick one of them as
/// [`Domain::DomainType`] and wrap into / unwrap from the matching variant.
pub trait Domain<T, U, V, W> {
    /// Payload type handled by this implementor.
    type DomainType;

    /// Wraps `a` into a [`DomainElement`].
    fn wrap(a: Self::DomainType) -> DomainElement<T, U, V, W>;

    /// Tries to extract this implementor's payload from `a`, yielding `None`
    /// when it cannot.
    fn try_unwrap(a: DomainElement<T, U, V, W>) -> Option<Self::DomainType>;

    /// Extracts this implementor's payload from `a`.
    ///
    /// # Panics
    ///
    /// Panics when [`Domain::try_unwrap`] returns `None`.
    #[inline(always)]
    fn unwrap(a: DomainElement<T, U, V, W>) -> Self::DomainType {
        match Self::try_unwrap(a) {
            Some(payload) => payload,
            None => panic!("Can't unwrap DomainElement, wrong type"),
        }
    }
}
/// Associates `bool` with the `Bit` variant, whose payload type is `T`.
impl<T, U, V, W> Domain<T, U, V, W> for bool {
    type DomainType = T;

    /// Wraps `a` as `DomainElement::Bit`.
    #[inline(always)]
    fn wrap(a: Self::DomainType) -> DomainElement<T, U, V, W> {
        DomainElement::Bit(a)
    }

    /// Returns the payload of a `Bit` element and `None` for every other variant.
    #[inline(always)]
    fn try_unwrap(a: DomainElement<T, U, V, W>) -> Option<Self::DomainType> {
        match a {
            DomainElement::Bit(bit) => Some(bit),
            // Every other variant is listed explicitly so the match stays exhaustive.
            DomainElement::Scalar(_) | DomainElement::Base(_) | DomainElement::Curve(_) => None,
        }
    }
}
/// Associates `ScalarField` with the `Scalar` variant, whose payload type is `U`.
impl<T, U, V, W> Domain<T, U, V, W> for ScalarField {
    type DomainType = U;

    /// Wraps `a` as `DomainElement::Scalar`.
    #[inline(always)]
    fn wrap(a: Self::DomainType) -> DomainElement<T, U, V, W> {
        DomainElement::Scalar(a)
    }

    /// Returns the payload of a `Scalar` element and `None` for every other variant.
    #[inline(always)]
    fn try_unwrap(a: DomainElement<T, U, V, W>) -> Option<Self::DomainType> {
        match a {
            DomainElement::Scalar(scalar) => Some(scalar),
            // Every other variant is listed explicitly so the match stays exhaustive.
            DomainElement::Bit(_) | DomainElement::Base(_) | DomainElement::Curve(_) => None,
        }
    }
}
/// Associates `CurvePoint` with the `Curve` variant, whose payload type is `W`.
impl<T, U, V, W> Domain<T, U, V, W> for CurvePoint {
    type DomainType = W;

    /// Wraps `a` as `DomainElement::Curve`.
    // Same inlining hint as the `bool` and `ScalarField` implementations.
    #[inline(always)]
    fn wrap(a: Self::DomainType) -> DomainElement<T, U, V, W> {
        DomainElement::Curve(a)
    }

    /// Returns the payload of a `Curve` element and `None` for every other variant.
    #[inline(always)]
    fn try_unwrap(a: DomainElement<T, U, V, W>) -> Option<Self::DomainType> {
        match a {
            DomainElement::Curve(point) => Some(point),
            // Every other variant is listed explicitly so the match stays exhaustive.
            DomainElement::Bit(_) | DomainElement::Scalar(_) | DomainElement::Base(_) => None,
        }
    }
}
/// Associates `BaseField` with the `Base` variant, whose payload type is `V`.
impl<T, U, V, W> Domain<T, U, V, W> for BaseField {
    type DomainType = V;

    /// Wraps `a` as `DomainElement::Base`.
    // Same inlining hint as the `bool` and `ScalarField` implementations.
    #[inline(always)]
    fn wrap(a: Self::DomainType) -> DomainElement<T, U, V, W> {
        DomainElement::Base(a)
    }

    /// Returns the payload of a `Base` element and `None` for every other variant.
    #[inline(always)]
    fn try_unwrap(a: DomainElement<T, U, V, W>) -> Option<Self::DomainType> {
        match a {
            DomainElement::Base(base) => Some(base),
            // Every other variant is listed explicitly so the match stays exhaustive.
            DomainElement::Bit(_) | DomainElement::Scalar(_) | DomainElement::Curve(_) => None,
        }
    }
}