use alloc::format;
use alloc::vec::Vec;
use core::iter;
use core::marker::PhantomData;
use crate::errors::{Error, InvalidInstance};
use crate::group::msm::MultiScalarMul;
use ff::Field;
use group::prime::PrimeGroup;
mod convert;
mod ops;
mod canonical;
pub use canonical::CanonicalLinearRelation;
/// Handle to a scalar variable, identified by its position in the scalar assignment.
///
/// The index is used to read the variable's value out of the scalar slice passed to
/// evaluation routines.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct ScalarVar<G>(usize, PhantomData<G>);

impl<G> ScalarVar<G> {
    /// Returns the position of this variable in the scalar assignment.
    pub fn index(&self) -> usize {
        let Self(position, _) = self;
        *position
    }
}

// Only the index is hashed, so `Hash` is available for every `G` (no `G: Hash` bound),
// and it stays consistent with equality, which also only depends on the index.
impl<G> core::hash::Hash for ScalarVar<G> {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        core::hash::Hash::hash(&self.index(), state)
    }
}
/// Handle to a group-element variable, identified by its index.
///
/// The index selects the variable's slot in a [`GroupMap`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct GroupVar<G>(usize, PhantomData<G>);

impl<G> GroupVar<G> {
    /// Returns the index of this group variable.
    pub fn index(&self) -> usize {
        let Self(position, _) = *self;
        position
    }
}

// Hash by index only: no `G: Hash` bound is needed, and hashing agrees with equality.
impl<G> core::hash::Hash for GroupVar<G> {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        core::hash::Hash::hash(&self.index(), state)
    }
}
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum ScalarTerm<G> {
Var(ScalarVar<G>),
Unit,
}
impl<G: PrimeGroup> ScalarTerm<G> {
fn value(self, scalars: &[G::Scalar]) -> G::Scalar {
match self {
Self::Var(var) => scalars[var.0],
Self::Unit => G::Scalar::ONE,
}
}
}
/// A single term `scalar · elem`: a scalar factor applied to a group variable.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct Term<G> {
    /// Scalar factor of the term; [`ScalarTerm::Unit`] stands for the constant one.
    scalar: ScalarTerm<G>,
    /// Group variable the scalar factor is applied to.
    elem: GroupVar<G>,
}
/// An item paired with a constant weight.
///
/// In a [`LinearCombination`], the weight is multiplied with the value of the term's scalar
/// when the combination is evaluated.
#[derive(Copy, Clone, Debug)]
pub struct Weighted<T, F> {
    /// The weighted item.
    pub term: T,
    /// Constant factor applied to `term`.
    pub weight: F,
}
/// An ordered list of items to be added together (for example the weighted terms of a
/// [`LinearCombination`]).
#[derive(Clone, Debug)]
pub struct Sum<T>(Vec<T>);

impl<T> Sum<T> {
    /// Returns the items of the sum, in the order they were collected.
    pub fn terms(&self) -> &[T] {
        self.0.as_slice()
    }
}

/// Lets `iter.sum::<Sum<T>>()` collect items into a `Sum`, preserving their order.
impl<T> core::iter::Sum<T> for Sum<T> {
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = T>,
    {
        let items: Vec<T> = iter.collect();
        Sum(items)
    }
}
/// A linear combination of group variables: a sum of [`Term`]s, each scaled by a constant
/// weight taken from the group's scalar field.
pub type LinearCombination<G> = Sum<Weighted<Term<G>, <G as group::Group>::Scalar>>;
impl<G: PrimeGroup + MultiScalarMul> LinearMap<G> {
    /// Evaluates every linear combination of the map on the scalar assignment `scalars`.
    ///
    /// For each combination, the coefficient of a term is `value(scalar) * weight`, where
    /// [`ScalarTerm::Unit`] evaluates to one and a scalar variable reads `scalars[index]`.
    /// The coefficients and the assigned group elements are then combined with a single
    /// multi-scalar multiplication. One group element is returned per combination, in order.
    ///
    /// # Errors
    /// Returns [`InvalidInstance`] if a combination refers to a group variable that has no
    /// assigned element.
    ///
    /// # Panics
    /// Panics if a referenced scalar variable's index is out of bounds for `scalars`.
    fn map(&self, scalars: &[G::Scalar]) -> Result<Vec<G>, InvalidInstance> {
        self.linear_combinations
            .iter()
            .map(|lc| -> Result<G, InvalidInstance> {
                let coefficients: Vec<G::Scalar> = lc
                    .terms()
                    .iter()
                    .map(|weighted| weighted.term.scalar.value(scalars) * weighted.weight)
                    .collect();
                // Short-circuit on the first unassigned group variable.
                let elements: Vec<G> = lc
                    .terms()
                    .iter()
                    .map(|weighted| self.group_elements.get(weighted.term.elem))
                    .collect::<Result<_, _>>()?;
                Ok(G::msm(&coefficients, &elements))
            })
            .collect()
    }
}
/// A partial assignment of concrete group elements to group variables.
///
/// Slot `i` holds the element assigned to the group variable with index `i`, or `None` if
/// that variable has not been assigned yet.
#[derive(Clone, Debug)]
pub struct GroupMap<G>(Vec<Option<G>>);

impl<G: PrimeGroup> GroupMap<G> {
    /// Assigns `element` to `var`, growing the map with unassigned slots if needed.
    ///
    /// # Panics
    /// Panics if `var` already holds an element different from `element`. Re-assigning the
    /// same element is allowed.
    pub fn assign_element(&mut self, var: GroupVar<G>, element: G) {
        let slot = var.0;
        match self.0.get(slot).copied() {
            // Out of range: pad with empty slots so that `slot` becomes addressable.
            None => self.0.resize(slot + 1, None),
            Some(Some(existing)) => assert_eq!(
                existing, element,
                "conflicting assignments for var {var:?}"
            ),
            Some(None) => {}
        }
        self.0[slot] = Some(element);
    }

    /// Assigns every `(variable, element)` pair in order, as [`Self::assign_element`] does.
    pub fn assign_elements(&mut self, assignments: impl IntoIterator<Item = (GroupVar<G>, G)>) {
        assignments
            .into_iter()
            .for_each(|(var, element)| self.assign_element(var, element));
    }

    /// Returns the element assigned to `var`.
    ///
    /// # Errors
    /// Returns [`InvalidInstance`] if `var` is beyond the map or its slot is unassigned.
    pub fn get(&self, var: GroupVar<G>) -> Result<G, InvalidInstance> {
        self.0.get(var.0).copied().flatten().ok_or_else(|| {
            InvalidInstance::new(format!("unassigned group variable {}", var.0))
        })
    }

    /// Consumes the map, yielding every slot paired with its group variable, in index order.
    // NOTE(review): an `IntoIterator` impl would have to name the iterator type. This method
    // returns an opaque `impl Iterator`, which is presumably why the lint is allowed.
    #[allow(clippy::should_implement_trait)]
    pub fn into_iter(self) -> impl Iterator<Item = (GroupVar<G>, Option<G>)> {
        (0..).map(|index| GroupVar(index, PhantomData)).zip(self.0)
    }

    /// Iterates over every slot by reference, paired with its group variable, in index order.
    pub fn iter(&self) -> impl Iterator<Item = (GroupVar<G>, Option<&G>)> {
        (0..)
            .map(|index| GroupVar(index, PhantomData))
            .zip(self.0.iter().map(Option::as_ref))
    }

    /// Appends `element` in a new slot at the end and returns the variable for that slot.
    ///
    /// Only this map is modified. Allocation counters kept elsewhere (such as
    /// `LinearMap::num_elements`) are not updated.
    pub fn push(&mut self, element: G) -> GroupVar<G> {
        let var = GroupVar(self.0.len(), PhantomData);
        self.0.push(Some(element));
        var
    }

    /// Returns the number of slots, assigned or not.
    pub fn len(&self) -> usize {
        Vec::len(&self.0)
    }

    /// Returns `true` if the map has no slots at all.
    pub fn is_empty(&self) -> bool {
        Vec::is_empty(&self.0)
    }
}
impl<G> Default for GroupMap<G> {
fn default() -> Self {
Self(Vec::default())
}
}
/// Builds a map from `(variable, element)` pairs, assigning them in iteration order.
///
/// # Panics
/// Panics if the same variable is given two different elements (see
/// [`GroupMap::assign_element`]).
impl<G: PrimeGroup> FromIterator<(GroupVar<G>, G)> for GroupMap<G> {
    fn from_iter<T: IntoIterator<Item = (GroupVar<G>, G)>>(iter: T) -> Self {
        let mut map = Self::default();
        map.assign_elements(iter);
        map
    }
}
/// A list of linear combinations over group variables, together with the group elements
/// assigned to those variables and counters of allocated variables.
#[derive(Clone, Debug)]
pub struct LinearMap<G: PrimeGroup> {
    /// One linear combination per constraint, in insertion order.
    pub linear_combinations: Vec<LinearCombination<G>>,
    /// Concrete group elements assigned to group variables so far.
    pub group_elements: GroupMap<G>,
    /// Number of scalar variables allocated so far (incremented by
    /// `LinearRelation::allocate_scalar`).
    pub num_scalars: usize,
    /// Number of group variables allocated so far (incremented by
    /// `LinearRelation::allocate_element`). It may differ from `group_elements.len()`, which
    /// only grows when elements are assigned.
    pub num_elements: usize,
}

// Hand-written instead of `#[derive(Default)]`: the derive would add a `G: Default` bound,
// although an empty map stores no `G`. This mirrors the manual `Default` of `GroupMap`.
impl<G: PrimeGroup> Default for LinearMap<G> {
    fn default() -> Self {
        Self::new()
    }
}
impl<G: PrimeGroup> LinearMap<G> {
pub fn new() -> Self {
Self {
linear_combinations: Vec::new(),
group_elements: GroupMap::default(),
num_scalars: 0,
num_elements: 0,
}
}
pub fn num_constraints(&self) -> usize {
self.linear_combinations.len()
}
pub fn append(&mut self, lc: LinearCombination<G>) {
self.linear_combinations.push(lc);
}
pub fn evaluate(&self, scalars: &[G::Scalar]) -> Result<Vec<G>, Error>
where
G: MultiScalarMul,
{
self.linear_combinations
.iter()
.map(|lc| {
let weighted_coefficients =
lc.0.iter()
.map(|weighted| weighted.term.scalar.value(scalars) * weighted.weight)
.collect::<Vec<_>>();
let elements =
lc.0.iter()
.map(|weighted| self.group_elements.get(weighted.term.elem))
.collect::<Result<Vec<_>, _>>()?;
Ok(G::msm(&weighted_coefficients, &elements))
})
.collect()
}
}
/// A system of linear equations over group elements.
///
/// Equation `i` reads `image[i] = linear_map.linear_combinations[i]`.
/// [`Self::append_equation`] keeps both lists index-aligned.
#[derive(Clone, Debug)]
pub struct LinearRelation<G: PrimeGroup> {
    /// Right-hand sides (one linear combination per equation) plus variable bookkeeping.
    pub linear_map: LinearMap<G>,
    /// Left-hand-side group variable of each equation.
    pub image: Vec<GroupVar<G>>,
}

// Hand-written instead of `#[derive(Default)]`: the derive would add a `G: Default` bound,
// although an empty relation stores no `G`. This mirrors the manual `Default` of `GroupMap`.
impl<G: PrimeGroup> Default for LinearRelation<G> {
    fn default() -> Self {
        Self::new()
    }
}
impl<G: PrimeGroup> LinearRelation<G> {
    /// Creates an empty relation: no equations and no allocated variables.
    pub fn new() -> Self {
        let linear_map = LinearMap::new();
        Self {
            linear_map,
            image: Vec::new(),
        }
    }

    /// Adds the equation `lhs = rhs`.
    ///
    /// `rhs` becomes a new constraint of the linear map and `lhs` is recorded as its image
    /// variable, at the same position.
    pub fn append_equation(&mut self, lhs: GroupVar<G>, rhs: impl Into<LinearCombination<G>>) {
        let constraint: LinearCombination<G> = rhs.into();
        self.linear_map.append(constraint);
        self.image.push(lhs);
    }

    /// Allocates a fresh group variable `v`, adds the equation `v = rhs`, and returns `v`.
    pub fn allocate_eq(&mut self, rhs: impl Into<LinearCombination<G>>) -> GroupVar<G> {
        let lhs = self.allocate_element();
        self.append_equation(lhs, rhs);
        lhs
    }

    /// Allocates a new scalar variable. Its index is the number of scalars allocated before it.
    pub fn allocate_scalar(&mut self) -> ScalarVar<G> {
        let index = self.linear_map.num_scalars;
        self.linear_map.num_scalars = index + 1;
        ScalarVar(index, PhantomData)
    }

    /// Allocates `N` consecutive scalar variables and returns them as an array.
    pub fn allocate_scalars<const N: usize>(&mut self) -> [ScalarVar<G>; N] {
        // Placeholder contents; every slot is overwritten below.
        let mut vars = [ScalarVar(usize::MAX, PhantomData); N];
        vars.iter_mut()
            .for_each(|slot| *slot = self.allocate_scalar());
        vars
    }

    /// Allocates `n` consecutive scalar variables and returns them as a vector.
    pub fn allocate_scalars_vec(&mut self, n: usize) -> Vec<ScalarVar<G>> {
        iter::repeat_with(|| self.allocate_scalar()).take(n).collect()
    }

    /// Allocates a new, unassigned group variable. Its index is the number of group variables
    /// allocated before it.
    pub fn allocate_element(&mut self) -> GroupVar<G> {
        let index = self.linear_map.num_elements;
        self.linear_map.num_elements = index + 1;
        GroupVar(index, PhantomData)
    }

    /// Allocates a new group variable and assigns `element` to it.
    pub fn allocate_element_with(&mut self, element: G) -> GroupVar<G> {
        let var = self.allocate_element();
        self.linear_map.group_elements.assign_element(var, element);
        var
    }

    /// Allocates `N` consecutive, unassigned group variables and returns them as an array.
    pub fn allocate_elements<const N: usize>(&mut self) -> [GroupVar<G>; N] {
        // Placeholder contents; every slot is overwritten below.
        let mut vars = [GroupVar(usize::MAX, PhantomData); N];
        vars.iter_mut()
            .for_each(|slot| *slot = self.allocate_element());
        vars
    }

    /// Allocates `n` consecutive, unassigned group variables and returns them as a vector.
    pub fn allocate_elements_vec(&mut self, n: usize) -> Vec<GroupVar<G>> {
        iter::repeat_with(|| self.allocate_element()).take(n).collect()
    }

    /// Allocates one group variable per entry of `elements`, assigning each its element.
    /// Returns the variables in the same order.
    pub fn allocate_elements_with(&mut self, elements: &[G]) -> Vec<GroupVar<G>> {
        let mut vars = Vec::with_capacity(elements.len());
        for &element in elements {
            vars.push(self.allocate_element_with(element));
        }
        vars
    }

    /// Assigns `element` to the group variable `var`.
    ///
    /// # Panics
    /// Panics if `var` already holds a different element.
    pub fn set_element(&mut self, var: GroupVar<G>, element: G) {
        let assignments = &mut self.linear_map.group_elements;
        assignments.assign_element(var, element)
    }

    /// Assigns every `(variable, element)` pair in order.
    ///
    /// # Panics
    /// Panics if a variable already holds a different element.
    pub fn set_elements(&mut self, assignments: impl IntoIterator<Item = (GroupVar<G>, G)>) {
        let map = &mut self.linear_map.group_elements;
        map.assign_elements(assignments)
    }

    /// Evaluates the linear map on `scalars` and assigns each result to the matching image
    /// variable.
    ///
    /// # Errors
    /// Returns an error, with nothing assigned, if a combination refers to an unassigned
    /// group variable.
    ///
    /// # Panics
    /// Panics if the number of constraints differs from the number of image variables. Also
    /// panics if an image variable already holds a different element, or if a scalar
    /// variable's index is out of bounds for `scalars`.
    pub fn compute_image(&mut self, scalars: &[G::Scalar]) -> Result<(), Error>
    where
        G: MultiScalarMul,
    {
        assert!(
            self.linear_map.num_constraints() == self.image.len(),
            "invalid LinearRelation: different number of constraints and image variables"
        );
        let values = self.linear_map.map(scalars)?;
        self.linear_map
            .group_elements
            .assign_elements(self.image.iter().copied().zip(values));
        Ok(())
    }

    /// Returns the group elements currently assigned to the image variables, in equation order.
    ///
    /// # Errors
    /// Returns [`InvalidInstance`] for the first image variable that has no assigned element.
    pub fn image(&self) -> Result<Vec<G>, InvalidInstance> {
        let mut values = Vec::with_capacity(self.image.len());
        for &var in &self.image {
            values.push(self.linear_map.group_elements.get(var)?);
        }
        Ok(values)
    }

    /// Converts this relation into its [`CanonicalLinearRelation`] form.
    ///
    /// The conversion is the `TryFrom`/`TryInto` implementation defined elsewhere
    /// (NOTE(review): presumably in the `canonical` module).
    pub fn canonical(&self) -> Result<CanonicalLinearRelation<G>, InvalidInstance>
    where
        G: MultiScalarMul,
    {
        self.try_into()
    }
}