use std::collections::{btree_map::Entry, BTreeMap};
use ff::Field;
use midnight_curves::msm::msm_best;
use midnight_proofs::{
circuit::{Layouter, Value},
plonk::Error,
};
use crate::{
field::AssignedNative,
instructions::PublicInputInstructions,
types::{InnerValue, Instantiable},
verifier::{
types::SelfEmulation,
utils::{
add_bounded_scalars, assign_bounded_scalars, mul_bounded_scalars, AssignedBoundedScalar,
},
},
};
/// Off-circuit multi-scalar multiplication: a list of explicit `(base, scalar)`
/// terms plus scalars for bases that are referred to only by name.
///
/// The named ("fixed") bases themselves are not stored here; they are looked
/// up by name in a caller-provided map when the MSM is evaluated or resolved
/// (see `Msm::eval` and `Msm::resolve_fixed_bases`).
#[derive(Clone, Debug)]
pub struct Msm<S: SelfEmulation> {
/// Explicit bases; `bases[i]` is paired with `scalars[i]`.
bases: Vec<S::C>,
/// Scalars of the explicit terms; the constructors keep this the same length
/// as `bases`.
scalars: Vec<S::F>,
/// Scalar for each named fixed base, keyed (and therefore iterated) by name.
fixed_base_scalars: BTreeMap<String, S::F>,
}
/// In-circuit counterpart of `Msm`: assigned explicit bases, their assigned
/// bounded scalars, and assigned bounded scalars for named fixed bases.
#[derive(Clone, Debug)]
pub struct AssignedMsm<S: SelfEmulation> {
/// Assigned explicit bases; paired index-by-index with `scalars`.
bases: Vec<S::AssignedPoint>,
/// Assigned scalars of the explicit terms. Each carries a `bound` whose bit
/// length `collapse` passes to the in-circuit MSM.
pub(crate) scalars: Vec<AssignedBoundedScalar<S::F>>,
/// Assigned scalars for named fixed bases, keyed by base name.
fixed_base_scalars: BTreeMap<String, AssignedBoundedScalar<S::F>>,
}
impl<S: SelfEmulation> PartialEq for AssignedMsm<S> {
    /// Field-by-field equality: bases first, then explicit scalars, then the
    /// named fixed-base scalars. Stops at the first field that differs.
    fn eq(&self, other: &Self) -> bool {
        let Self {
            bases,
            scalars,
            fixed_base_scalars,
        } = self;
        if bases != &other.bases {
            return false;
        }
        if scalars != &other.scalars {
            return false;
        }
        fixed_base_scalars == &other.fixed_base_scalars
    }
}

/// Equality on `AssignedMsm` is a full equivalence relation.
impl<S: SelfEmulation> Eq for AssignedMsm<S> {}
impl<S: SelfEmulation> Msm<S> {
pub fn new(
bases: &[S::C],
scalars: &[S::F],
fixed_base_scalars: &BTreeMap<String, S::F>,
) -> Self {
assert_eq!(bases.len(), scalars.len());
Msm {
bases: bases.to_vec(),
scalars: scalars.to_vec(),
fixed_base_scalars: fixed_base_scalars.clone(),
}
}
pub fn bases(&self) -> Vec<S::C> {
self.bases.clone()
}
pub fn scalars(&self) -> Vec<S::F> {
self.scalars.clone()
}
pub fn fixed_base_scalars(&self) -> BTreeMap<String, S::F> {
self.fixed_base_scalars.clone()
}
pub fn from_terms(bases: &[S::C], scalars: &[S::F]) -> Self {
assert_eq!(bases.len(), scalars.len());
Msm {
bases: bases.to_vec(),
scalars: scalars.to_vec(),
fixed_base_scalars: BTreeMap::new(),
}
}
pub fn collapse(&mut self) {
let affine_bases: Vec<S::G1Affine> = self.bases.iter().map(|&b| b.into()).collect();
let collapsed_base = msm_best(&self.scalars, &affine_bases);
self.bases = vec![collapsed_base];
self.scalars = vec![S::F::ONE];
}
pub fn resolve_fixed_bases(&mut self, fixed_bases: &BTreeMap<String, S::C>) {
for (name, scalar) in &self.fixed_base_scalars {
let base = fixed_bases.get(name).unwrap_or_else(|| panic!("Base not provided: {name}"));
self.bases.push(*base);
self.scalars.push(*scalar);
}
self.fixed_base_scalars.clear();
}
pub fn eval(&self, fixed_bases: &BTreeMap<String, S::C>) -> S::C {
let mut bases = self.bases.clone();
let mut scalars = self.scalars.clone();
for (key, scalar) in self.fixed_base_scalars.iter() {
let base = fixed_bases.get(key).unwrap_or_else(|| panic!("Base not provided: {key}"));
bases.push(*base);
scalars.push(*scalar);
}
let affine_bases: Vec<S::G1Affine> = bases.iter().map(|&b| b.into()).collect();
msm_best(&scalars, &affine_bases)
}
pub fn accumulate_with_r(&self, other: &Self, r: S::F) -> Self {
let mut acc = self.clone();
acc.bases.extend(other.bases.clone());
acc.scalars.extend(other.scalars.iter().map(|s| *s * r));
for (key, value) in other.fixed_base_scalars.clone() {
let r_times_value = r * value;
acc.fixed_base_scalars
.entry(key)
.and_modify(|e| *e += r_times_value)
.or_insert(r_times_value);
}
acc
}
}
impl<S: SelfEmulation> InnerValue for AssignedMsm<S> {
    type Element = Msm<S>;

    /// Reads back the witness values of all assigned bases and scalars as an
    /// off-circuit `Msm`. The result is known only if every component value
    /// is known.
    fn value(&self) -> Value<Self::Element> {
        let base_values = self.bases.iter().map(|point| point.value());
        let scalar_values = self.scalars.iter().map(|bounded| bounded.scalar.value().copied());
        let named_values = self
            .fixed_base_scalars
            .iter()
            .map(|(label, bounded)| bounded.scalar.value().map(|v| (label.clone(), *v)));

        let bases: Value<Vec<S::C>> = Value::from_iter(base_values);
        let scalars: Value<Vec<S::F>> = Value::from_iter(scalar_values);
        let fixed_base_scalars: Value<BTreeMap<String, S::F>> = Value::from_iter(named_values);

        bases
            .zip(scalars)
            .zip(fixed_base_scalars)
            .map(|((bases, scalars), fixed_base_scalars)| Msm {
                bases,
                scalars,
                fixed_base_scalars,
            })
    }
}
impl<S: SelfEmulation> Instantiable<S::F> for AssignedMsm<S> {
    /// Public-input encoding of an off-circuit MSM, in this order:
    /// 1. the encoding of every explicit base;
    /// 2. the explicit scalars;
    /// 3. the named fixed-base scalars, in ascending name order.
    fn as_public_input(msm: &Msm<S>) -> Vec<S::F> {
        let mut encoded: Vec<S::F> =
            msm.bases.iter().flat_map(S::AssignedPoint::as_public_input).collect();
        encoded.extend(msm.scalars.iter().copied());
        encoded.extend(msm.fixed_base_scalars.values().copied());
        encoded
    }
}
impl<S: SelfEmulation> AssignedMsm<S> {
    /// Splits the public-input encoding of `msm` into two parts.
    ///
    /// The first part is the encoding of every explicit base. The second part
    /// holds the explicit scalars followed by the named fixed-base scalars, in
    /// ascending name order; it is intended for the committed instance.
    pub fn as_public_input_with_committed_scalars(msm: &Msm<S>) -> (Vec<S::F>, Vec<S::F>) {
        let mut committed = msm.scalars.clone();
        committed.extend(msm.fixed_base_scalars.values().copied());

        let normal = msm
            .bases
            .iter()
            .flat_map(S::AssignedPoint::as_public_input)
            .collect::<Vec<_>>();

        (normal, committed)
    }
}
impl<S: SelfEmulation> AssignedMsm<S> {
    /// Returns the in-circuit cells that form this MSM's public-input
    /// encoding, in this order:
    /// 1. every base's cells, as produced by `curve_chip`;
    /// 2. the explicit scalars;
    /// 3. the fixed-base scalars, in ascending name order.
    ///
    /// This matches the order used by the off-circuit
    /// `Instantiable::as_public_input` for `AssignedMsm`.
    ///
    /// # Errors
    /// Propagates any error from `curve_chip.as_public_input`.
    pub(crate) fn in_circuit_as_public_input(
        &self,
        layouter: &mut impl Layouter<S::F>,
        curve_chip: &S::CurveChip,
    ) -> Result<Vec<AssignedNative<S::F>>, Error> {
        let mut cells: Vec<AssignedNative<S::F>> = Vec::new();
        for base in &self.bases {
            cells.extend(curve_chip.as_public_input(layouter, base)?);
        }
        cells.extend(self.scalars.iter().map(|s| s.scalar.clone()));
        cells.extend(self.fixed_base_scalars.values().map(|s| s.scalar.clone()));
        Ok(cells)
    }

    /// Constrains every base and every scalar (explicit, then named in
    /// ascending name order) as a public input.
    ///
    /// # Errors
    /// Propagates the first error returned by either chip.
    pub(crate) fn constrain_as_public_input(
        &self,
        layouter: &mut impl Layouter<S::F>,
        curve_chip: &S::CurveChip,
        scalar_chip: &S::ScalarChip,
    ) -> Result<(), Error> {
        self.bases
            .iter()
            .try_for_each(|base| curve_chip.constrain_as_public_input(layouter, base))?;
        self.scalars
            .iter()
            .try_for_each(|s| scalar_chip.constrain_as_public_input(layouter, &s.scalar))?;
        self.fixed_base_scalars
            .values()
            .try_for_each(|s| scalar_chip.constrain_as_public_input(layouter, &s.scalar))
    }

    /// Like `constrain_as_public_input`, but the scalars (explicit, then named
    /// in ascending name order) go through
    /// `S::constrain_scalar_as_committed_public_input` instead of the plain
    /// public-input path. Bases are still constrained as ordinary public inputs.
    ///
    /// # Errors
    /// Propagates the first error returned by either chip.
    pub(crate) fn constrain_as_public_input_with_committed_scalars(
        &self,
        layouter: &mut impl Layouter<S::F>,
        curve_chip: &S::CurveChip,
        scalar_chip: &S::ScalarChip,
    ) -> Result<(), Error> {
        self.bases
            .iter()
            .try_for_each(|base| curve_chip.constrain_as_public_input(layouter, base))?;
        self.scalars.iter().try_for_each(|s| {
            S::constrain_scalar_as_committed_public_input(layouter, scalar_chip, &s.scalar)
        })?;
        self.fixed_base_scalars.values().try_for_each(|s| {
            S::constrain_scalar_as_committed_public_input(layouter, scalar_chip, &s.scalar)
        })
    }
}
impl<S: SelfEmulation> AssignedMsm<S> {
    /// Assigns an off-circuit `Msm` witness into the circuit.
    ///
    /// * `len`: number of explicit `(base, scalar)` terms to assign.
    /// * `fixed_base_names`: names of the fixed bases whose scalars are
    ///   assigned. They are sorted internally so they line up with the
    ///   ascending key order of the witness's fixed-base map.
    /// * `msm_value`: the witness. It may be unknown, in which case every
    ///   assigned value is unknown.
    ///
    /// Bases are assigned with `S::assign_without_subgroup_check`.
    ///
    /// # Panics
    /// * If `fixed_base_names` contains a duplicate name.
    /// * If the witness is known and its fixed-base names differ from
    ///   `fixed_base_names`. Without this check, the scalars would be silently
    ///   attached to the wrong names.
    /// * Presumably (via `transpose_vec`) if the witness is known and its
    ///   lengths differ from `len` / `fixed_base_names.len()`. TODO: confirm
    ///   against the `Value::transpose_vec` docs.
    ///
    /// # Errors
    /// Propagates any error from assigning points or scalars.
    pub fn assign(
        layouter: &mut impl Layouter<S::F>,
        curve_chip: &S::CurveChip,
        scalar_chip: &S::ScalarChip,
        len: usize,
        fixed_base_names: &[String],
        msm_value: Value<Msm<S>>,
    ) -> Result<Self, Error> {
        // Sort so the names follow the same (ascending) order as the keys of
        // the witness `BTreeMap`, which is the order its scalars come out in.
        let mut fixed_base_names = fixed_base_names.to_vec();
        fixed_base_names.sort();
        // Duplicates would silently collapse in the resulting map.
        assert!(
            fixed_base_names.windows(2).all(|w| w[0] != w[1]),
            "Duplicate fixed base name in {fixed_base_names:?}"
        );

        let bases_val = msm_value.as_ref().map(|msm| msm.bases.clone()).transpose_vec(len);
        let scalars_val = msm_value.as_ref().map(|msm| msm.scalars.clone()).transpose_vec(len);
        let fixed_base_scalars_val = msm_value
            .as_ref()
            .map(|msm| {
                // The scalars below are paired with `fixed_base_names` purely by
                // position, so the key sets must coincide exactly.
                assert!(
                    msm.fixed_base_scalars.keys().eq(fixed_base_names.iter()),
                    "Fixed base names {:?} do not match the witness fixed bases {:?}",
                    fixed_base_names,
                    msm.fixed_base_scalars.keys().collect::<Vec<_>>()
                );
                msm.fixed_base_scalars.values().copied().collect::<Vec<_>>()
            })
            .transpose_vec(fixed_base_names.len());

        let bases = bases_val
            .iter()
            .map(|p| S::assign_without_subgroup_check(layouter, curve_chip, *p))
            .collect::<Result<Vec<_>, Error>>()?;
        let scalars = assign_bounded_scalars(layouter, scalar_chip, &scalars_val)?;
        let fixed_base_scalars: BTreeMap<String, AssignedBoundedScalar<S::F>> = {
            let assigned = assign_bounded_scalars(layouter, scalar_chip, &fixed_base_scalars_val)?;
            fixed_base_names.into_iter().zip(assigned).collect()
        };

        Ok(AssignedMsm {
            scalars,
            bases,
            fixed_base_scalars,
        })
    }

    /// The MSM with no terms.
    pub fn empty() -> Self {
        Self {
            scalars: vec![],
            bases: vec![],
            fixed_base_scalars: BTreeMap::new(),
        }
    }

    /// An MSM with the single explicit term `scalar * base`.
    pub fn from_term(scalar: &AssignedBoundedScalar<S::F>, base: &S::AssignedPoint) -> Self {
        Self {
            scalars: vec![scalar.clone()],
            bases: vec![base.clone()],
            fixed_base_scalars: BTreeMap::new(),
        }
    }

    /// An MSM with a single term on the fixed base named `base_name`.
    pub fn from_fixed_term(scalar: &AssignedBoundedScalar<S::F>, base_name: &str) -> Self {
        Self {
            scalars: vec![],
            bases: vec![],
            fixed_base_scalars: [(base_name.to_string(), scalar.clone())].into_iter().collect(),
        }
    }

    /// Appends the explicit term `scalar * base`.
    pub fn add_term(&mut self, scalar: &AssignedBoundedScalar<S::F>, base: &S::AssignedPoint) {
        self.scalars.push(scalar.clone());
        self.bases.push(base.clone());
    }

    /// Adds `other` to `self` in place.
    ///
    /// Explicit terms of `other` are appended. Fixed-base scalars that share
    /// a name are added in-circuit; new names are copied over.
    ///
    /// # Errors
    /// Propagates any error from `add_bounded_scalars`.
    pub fn add_msm(
        &mut self,
        layouter: &mut impl Layouter<S::F>,
        scalar_chip: &S::ScalarChip,
        other: &Self,
    ) -> Result<(), Error> {
        self.scalars.extend(other.scalars.iter().cloned());
        self.bases.extend(other.bases.iter().cloned());
        for (key, value) in &other.fixed_base_scalars {
            match self.fixed_base_scalars.entry(key.clone()) {
                Entry::Occupied(mut occ) => {
                    let sum = add_bounded_scalars(layouter, scalar_chip, occ.get(), value)?;
                    *occ.get_mut() = sum;
                }
                Entry::Vacant(vac) => {
                    vac.insert(value.clone());
                }
            }
        }
        Ok(())
    }

    /// Replaces all explicit terms with a single term: the point computed
    /// in-circuit by `S::msm`, with an assigned scalar of one. Each scalar's
    /// bound bit-length is forwarded to `S::msm`. Fixed-base scalars are left
    /// untouched.
    ///
    /// # Errors
    /// Propagates any error from `S::msm` or from assigning the constant one.
    pub fn collapse(
        &mut self,
        layouter: &mut impl Layouter<S::F>,
        curve_chip: &S::CurveChip,
        scalar_chip: &S::ScalarChip,
    ) -> Result<(), Error> {
        let scalars = self
            .scalars
            .iter()
            .map(|s| (s.scalar.clone(), s.bound.bits() as usize))
            .collect::<Vec<_>>();
        let collapsed_base = S::msm(layouter, curve_chip, &scalars, &self.bases)?;
        self.bases = vec![collapsed_base];
        self.scalars = vec![AssignedBoundedScalar::one(layouter, scalar_chip)?];
        Ok(())
    }

    /// Turns every named fixed-base term into an explicit term, using the
    /// assigned base of the same name from `fixed_bases`, then clears the
    /// named scalars.
    ///
    /// # Panics
    /// If a name has no entry in `fixed_bases`.
    pub fn resolve_fixed_bases(&mut self, fixed_bases: &BTreeMap<String, S::AssignedPoint>) {
        for (name, scalar) in &self.fixed_base_scalars {
            let base = fixed_bases
                .get(name)
                .unwrap_or_else(|| panic!("Fixed base not provided: {name}"));
            self.bases.push(base.clone());
            self.scalars.push(scalar.clone());
        }
        self.fixed_base_scalars.clear();
    }

    /// Multiplies every scalar (explicit and named) by `r` in-circuit.
    ///
    /// # Errors
    /// Propagates any error from `mul_bounded_scalars`.
    pub fn scale(
        &mut self,
        layouter: &mut impl Layouter<S::F>,
        scalar_chip: &S::ScalarChip,
        r: &AssignedBoundedScalar<S::F>,
    ) -> Result<(), Error> {
        self.scalars = self
            .scalars
            .iter()
            .map(|s| mul_bounded_scalars(layouter, scalar_chip, s, r))
            .collect::<Result<Vec<_>, Error>>()?;
        for s in self.fixed_base_scalars.values_mut() {
            *s = mul_bounded_scalars(layouter, scalar_chip, s, r)?;
        }
        Ok(())
    }

    /// Returns `self + r * other`, computed in-circuit. Neither input is
    /// modified.
    ///
    /// # Errors
    /// Propagates any error from `scale` or `add_msm`.
    pub fn accumulate_with_r(
        &self,
        layouter: &mut impl Layouter<S::F>,
        scalar_chip: &S::ScalarChip,
        other: &Self,
        r: &AssignedBoundedScalar<S::F>,
    ) -> Result<Self, Error> {
        let mut other = other.clone();
        other.scale(layouter, scalar_chip, r)?;
        let mut acc = self.clone();
        acc.add_msm(layouter, scalar_chip, &other)?;
        Ok(acc)
    }
}