#![allow(unsafe_code)]
use ark_ff::Field;
use ark_std::vec::*;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::gr1cs::{
field_interner::{FieldInterner, InternedField},
Variable,
};
/// Flat storage for a sequence of linear combinations (LCs).
///
/// The terms of all LCs are stored back-to-back in `coeffs`/`vars`, and LC `i`
/// occupies positions `offsets[i]..offsets[i + 1]` of both vectors.
#[derive(Debug, Clone)]
pub struct LcMap<F: Field> {
    /// Variable of every term, concatenated across all LCs.
    vars: Vec<Variable>,
    /// Interned coefficient of every term, parallel to `vars`.
    coeffs: Vec<InternedField>,
    /// Start offset of each LC followed by the total number of terms. It
    /// always begins with `0`, so `offsets.len() == num_lcs + 1`.
    offsets: Vec<usize>,
    _f: core::marker::PhantomData<F>,
}

impl<F: Field> Default for LcMap<F> {
    /// Returns an empty map; equivalent to [`LcMap::new`].
    ///
    /// This is implemented by hand because a derived `Default` would leave
    /// `offsets` empty. That breaks the leading-`0` invariant that LC indexing
    /// (`num_lcs`, `push`) relies on.
    #[inline(always)]
    fn default() -> Self {
        Self::new()
    }
}
/// Converts an interned linear combination back into `(coefficient, variable)`
/// pairs by resolving every coefficient through `f_interner`.
///
/// # Panics
///
/// Panics if `f_interner.value` cannot resolve one of the coefficients, i.e.
/// the LC was interned with a different interner.
pub(crate) fn to_non_interned_lc<'a, F: Field>(
    lc: impl Iterator<Item = (&'a InternedField, &'a Variable)> + 'a,
    f_interner: &'a FieldInterner<F>,
) -> impl Iterator<Item = (F, Variable)> + 'a {
    lc.map(|(&c, &v)| {
        let coeff = f_interner
            .value(c)
            .expect("coefficient must have been interned by this `FieldInterner`");
        (coeff, v)
    })
}
/// Iterator over the `(coefficient, variable)` terms of one LC stored in an
/// [`LcMap`].
type LcMapIterItem<'lc> = core::iter::Zip<
    core::slice::Iter<'lc, InternedField>,
    core::slice::Iter<'lc, Variable>,
>;

/// Mutable iterator over the variables of one LC stored in an [`LcMap`].
type LcVarsIterMutItem<'lc> = core::slice::IterMut<'lc, Variable>;
// Invariant maintained by `new` and the `push*` methods: `offsets` starts with
// `0`, is non-decreasing, and its last entry equals `coeffs.len()`, which in
// turn equals `vars.len()` (both grow in lockstep and are never truncated).
impl<F: Field> LcMap<F> {
    /// Creates an empty map containing no linear combinations.
    #[inline(always)]
    pub fn new() -> Self {
        Self {
            vars: Vec::new(),
            coeffs: Vec::new(),
            // Leading sentinel: LC `i` spans `offsets[i]..offsets[i + 1]`.
            offsets: vec![0],
            _f: core::marker::PhantomData,
        }
    }

    /// Creates an empty map with storage reserved for `expected_num_lcs` linear
    /// combinations holding `expected_total_lc_size` terms in total.
    #[inline(always)]
    pub fn with_capacity(expected_num_lcs: usize, expected_total_lc_size: usize) -> Self {
        let mut result = Self::new();
        result.vars.reserve(expected_total_lc_size);
        result.coeffs.reserve(expected_total_lc_size);
        // `new` already stored the leading `0`, so one more offset per LC is
        // enough (`reserve` counts *additional* elements).
        result.offsets.reserve(expected_num_lcs);
        result
    }

    /// Appends one linear combination, interning each coefficient in
    /// `f_interner`.
    #[inline(always)]
    pub fn push(
        &mut self,
        v: impl IntoIterator<Item = (F, Variable)>,
        f_interner: &mut FieldInterner<F>,
    ) {
        for (coeff, var) in v {
            self.coeffs.push(f_interner.get_or_intern(coeff));
            self.vars.push(var);
        }
        self.offsets.push(self.coeffs.len());
    }

    /// Same as [`push`](Self::push), but takes the terms by reference.
    #[inline(always)]
    pub fn push_by_ref<'a>(
        &mut self,
        v: impl IntoIterator<Item = &'a (F, Variable)>,
        f_interner: &mut FieldInterner<F>,
    ) {
        self.push(v.into_iter().copied(), f_interner);
    }

    /// Iterates over the stored LCs in insertion order, yielding the
    /// `(coefficient, variable)` terms of each one.
    #[inline(always)]
    pub fn iter(&self) -> impl Iterator<Item = LcMapIterItem<'_>> {
        self.offsets.windows(2).map(|w| {
            // SAFETY: each window is `[offsets[i], offsets[i + 1]]`; by the
            // invariant above both are ordered and bounded by
            // `coeffs.len() == vars.len()`.
            unsafe { windowed_access(w, &self.coeffs, &self.vars) }
        })
    }

    /// Parallel version of [`iter`](Self::iter).
    #[cfg(feature = "parallel")]
    #[inline(always)]
    pub fn par_iter(&self) -> impl ParallelIterator<Item = LcMapIterItem<'_>> {
        self.offsets.par_windows(2).map(|w| {
            // SAFETY: same argument as in `iter`.
            unsafe { windowed_access(w, &self.coeffs, &self.vars) }
        })
    }

    /// Iterates over the stored LCs, yielding a mutable iterator over the
    /// variables of each one. Coefficients are not exposed.
    #[inline(always)]
    pub fn lc_vars_iter_mut(&mut self) -> impl Iterator<Item = LcVarsIterMutItem<'_>> {
        LcVarsIterMut::new(self)
    }

    /// Parallel version of [`lc_vars_iter_mut`](Self::lc_vars_iter_mut).
    #[cfg(feature = "parallel")]
    #[inline(always)]
    pub fn lc_vars_par_iter_mut(&mut self) -> LcVarsParIterMut<'_> {
        LcVarsParIterMut::new(self)
    }

    /// Number of linear combinations stored.
    ///
    /// Saturates at `0` rather than underflowing if `offsets` is ever empty.
    #[inline(always)]
    pub fn num_lcs(&self) -> usize {
        self.offsets.len().saturating_sub(1)
    }

    /// Total number of terms across all stored linear combinations.
    #[inline(always)]
    pub fn total_lc_size(&self) -> usize {
        self.vars.len()
    }

    /// Returns the terms of the `idx`-th LC, or `None` if `idx` is out of range.
    #[inline(always)]
    pub fn get(&self, idx: usize) -> Option<LcMapIterItem<'_>> {
        if idx >= self.num_lcs() {
            return cold();
        }
        // SAFETY: `idx < num_lcs() == offsets.len() - 1`, so `idx + 2` cannot
        // overflow and `idx..idx + 2` is in bounds of `offsets`. The two offsets
        // it selects are ordered and bounded by `coeffs.len() == vars.len()`
        // (invariant above).
        unsafe {
            Some(windowed_access(
                self.offsets.get_unchecked(idx..idx + 2),
                &self.coeffs,
                &self.vars,
            ))
        }
    }
}
/// Returns the `(coefficient, variable)` pairs stored at positions
/// `w[0]..w[1]` of `coeffs` and `vars`, without bounds checks.
///
/// # Safety
///
/// The caller must guarantee that `w` has exactly two entries, that
/// `w[0] <= w[1]`, and that `w[1] <= coeffs.len()` and `w[1] <= vars.len()`.
#[inline(always)]
unsafe fn windowed_access<'a>(
    w: &'a [usize],
    coeffs: &'a [InternedField],
    vars: &'a [Variable],
) -> LcMapIterItem<'a> {
    debug_assert!(w.len() == 2, "Expected a slice of length 2");
    debug_assert!(w[0] <= w[1], "Expected w[0] <= w[1]");
    debug_assert!(w[1] <= coeffs.len(), "Expected w[1] <= coeffs.len()");
    // `vars` is sliced unchecked as well, so its bound must be checked too.
    debug_assert!(w[1] <= vars.len(), "Expected w[1] <= vars.len()");
    // SAFETY: the caller upholds the contract documented above.
    unsafe {
        let start = *w.get_unchecked(0);
        let end = *w.get_unchecked(1);
        coeffs
            .get_unchecked(start..end)
            .iter()
            .zip(vars.get_unchecked(start..end))
    }
}
/// Detaches the first `w[1] - w[0]` variables from the front of `*vars` and
/// returns a mutable iterator over them. `*vars` is left pointing at the rest.
///
/// Only the *length* of the window is used: the window's variables are assumed
/// to start at the current front of `*vars`.
///
/// # Safety
///
/// The caller must guarantee that `w` has exactly two entries, that
/// `w[0] <= w[1]`, and that `w[1] - w[0] <= vars.len()`.
#[inline(always)]
unsafe fn windowed_access_mut<'b>(
    w: &[usize],
    vars: &mut &'b mut [Variable],
) -> LcVarsIterMutItem<'b> {
    debug_assert!(w.len() == 2, "Expected a slice of length 2");
    debug_assert!(w[0] <= w[1], "Expected w[0] <= w[1]");
    debug_assert!(w[1] - w[0] <= vars.len(), "`w[1] - w[0] > vars.len()`");
    // SAFETY: the caller guarantees `w` has two ordered entries whose
    // difference does not exceed the remaining length of `*vars`.
    let (front, rest) = unsafe {
        let lc_len = w.get_unchecked(1) - w.get_unchecked(0);
        core::mem::take(vars).split_at_mut_unchecked(lc_len)
    };
    *vars = rest;
    front.iter_mut()
}
/// Produces the `None` that [`LcMap::get`] returns for out-of-range indices.
///
/// This is a separate `#[cold]` function so the optimizer treats the miss path
/// as unlikely.
#[cold]
fn cold<'a>() -> Option<LcMapIterItem<'a>> {
    Option::None
}
/// Iterator yielding, for each linear combination in an [`LcMap`] (in
/// insertion order), a mutable iterator over that LC's variables.
pub struct LcVarsIterMut<'a> {
    /// Variables not yet handed out; each call to `next` detaches a prefix.
    vars: &'a mut [Variable],
    /// Consecutive `[start, end]` offset pairs, one per remaining LC.
    offsets: core::slice::Windows<'a, usize>,
}

impl<'a> LcVarsIterMut<'a> {
    /// Creates an iterator over the variables of every LC in `lc_map`.
    #[inline(always)]
    pub fn new<F: Field>(lc_map: &'a mut LcMap<F>) -> Self {
        Self {
            vars: &mut lc_map.vars,
            offsets: lc_map.offsets.windows(2),
        }
    }
}

impl<'a> Iterator for LcVarsIterMut<'a> {
    type Item = LcVarsIterMutItem<'a>;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let w = self.offsets.next()?;
        // SAFETY: windows are visited in order, and every previous call removed
        // exactly `prev_end - prev_start` variables from the front of
        // `self.vars`. The map's offsets are non-decreasing and the last one does
        // not exceed the original `vars.len()`, so at least `w[1] - w[0]`
        // variables remain.
        Some(unsafe { windowed_access_mut(w, &mut self.vars) })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.offsets.len();
        (len, Some(len))
    }
}

// `size_hint` is exact: one item per remaining offset window.
impl ExactSizeIterator for LcVarsIterMut<'_> {}

// `Windows` is fused, so once `next` returns `None` it keeps doing so.
impl core::iter::FusedIterator for LcVarsIterMut<'_> {}
/// Parallel iterator yielding, for each linear combination in an [`LcMap`], a
/// mutable iterator over that LC's variables.
#[cfg(feature = "parallel")]
pub struct LcVarsParIterMut<'a> {
    vars: &'a mut [Variable],
    offsets: &'a [usize],
}

#[cfg(feature = "parallel")]
impl<'a> LcVarsParIterMut<'a> {
    /// Creates a parallel iterator over the variables of every LC in `lc_map`.
    #[inline(always)]
    pub fn new<F: Field>(lc_map: &'a mut LcMap<F>) -> Self {
        Self {
            vars: &mut lc_map.vars,
            offsets: &lc_map.offsets,
        }
    }
}

#[cfg(feature = "parallel")]
impl<'a> ParallelIterator for LcVarsParIterMut<'a> {
    type Item = LcVarsIterMutItem<'a>;

    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::UnindexedConsumer<Self::Item>,
    {
        rayon::iter::plumbing::bridge(self, consumer)
    }

    // The length is known exactly, so let rayon use its indexed fast paths.
    fn opt_len(&self) -> Option<usize> {
        Some(IndexedParallelIterator::len(self))
    }
}

#[cfg(feature = "parallel")]
impl<'a> IndexedParallelIterator for LcVarsParIterMut<'a> {
    fn len(&self) -> usize {
        self.offsets.len().saturating_sub(1)
    }

    fn drive<C>(self, consumer: C) -> C::Result
    where
        C: rayon::iter::plumbing::Consumer<Self::Item>,
    {
        rayon::iter::plumbing::bridge(self, consumer)
    }

    fn with_producer<CB>(self, callback: CB) -> CB::Output
    where
        CB: rayon::iter::plumbing::ProducerCallback<Self::Item>,
    {
        // A splittable piece of the iterator.
        //
        // `offsets` holds *absolute* positions into the map's full `vars`
        // vector, whereas `vars` is rebased: `vars[0]` corresponds to absolute
        // position `offsets[0]`, and `vars.len() >= offsets[last] - offsets[0]`.
        struct Producer<'a> {
            vars: &'a mut [Variable],
            offsets: &'a [usize],
        }

        impl<'a> rayon::iter::plumbing::Producer for Producer<'a> {
            type Item = LcVarsIterMutItem<'a>;
            type IntoIter = std::vec::IntoIter<LcVarsIterMutItem<'a>>;

            fn into_iter(self) -> Self::IntoIter {
                let Producer { mut vars, offsets } = self;
                offsets
                    .windows(2)
                    // SAFETY: windows are consumed in order from the front of
                    // the rebased `vars`, whose length is at least
                    // `offsets[last] - offsets[0]`, i.e. the sum of all window
                    // lengths. So every window fits in what remains.
                    .map(|w| unsafe { windowed_access_mut(w, &mut vars) })
                    .collect::<Vec<_>>()
                    .into_iter()
            }

            fn split_at(self, index: usize) -> (Self, Self) {
                let Producer { vars, offsets } = self;
                // `vars` is rebased to `offsets[0]`, so the split point must be
                // relative to it. Using the absolute `offsets[index]` here breaks
                // as soon as a right-hand half (with `offsets[0] != 0`) is split
                // again.
                let split_point = offsets[index] - offsets[0];
                let (left_vars, right_vars) = vars.split_at_mut(split_point);
                let left = Producer {
                    vars: left_vars,
                    offsets: &offsets[..=index],
                };
                let right = Producer {
                    vars: right_vars,
                    offsets: &offsets[index..],
                };
                (left, right)
            }
        }

        callback.callback(Producer {
            vars: self.vars,
            offsets: self.offsets,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ark_test_curves::bls12_381::Fr;

    #[test]
    #[cfg(feature = "parallel")]
    fn test_lc_map_par_iter_mut() {
        let mut interner = FieldInterner::<Fr>::new();
        let mut lcmap = LcMap::<Fr>::new();
        lcmap.push(
            [
                (1u8.into(), Variable::One),
                (2u8.into(), Variable::instance(2)),
            ],
            &mut interner,
        );
        lcmap.push(
            [
                (3u8.into(), Variable::witness(4)),
                (4u8.into(), Variable::instance(4)),
            ],
            &mut interner,
        );
        lcmap.lc_vars_par_iter_mut().for_each(|chunk| {
            for v in chunk {
                if v.is_instance() {
                    *v = Variable::instance(v.index().unwrap() + 1);
                }
            }
        });
        let flattened: Vec<_> = lcmap
            .iter()
            .flat_map(|chunk| to_non_interned_lc(chunk, &interner))
            .collect();
        let expected = vec![
            (1u8.into(), Variable::One),
            (2u8.into(), Variable::instance(3)),
            (3u8.into(), Variable::witness(4)),
            (4u8.into(), Variable::instance(5)),
        ];
        assert_eq!(flattened, expected);
    }

    /// Regression test: the parallel producer must rebase split points when a
    /// right-hand half is split again. With 8 worker threads rayon splits the
    /// 64 LCs several levels deep, which exercises that path.
    #[test]
    #[cfg(feature = "parallel")]
    fn test_lc_map_par_iter_mut_nested_splits() {
        const NUM_LCS: usize = 64;
        // Varying LC lengths so that split points fall at irregular offsets.
        let lc_len = |i: usize| i % 5 + 1;

        let mut interner = FieldInterner::<Fr>::new();
        let mut lcmap = LcMap::<Fr>::new();
        for i in 0..NUM_LCS {
            lcmap.push(
                (0..lc_len(i)).map(|j| (Fr::from(j as u64 + 1), Variable::instance(i + 1))),
                &mut interner,
            );
        }

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(8)
            .build()
            .unwrap();
        pool.install(|| {
            lcmap
                .lc_vars_par_iter_mut()
                .enumerate()
                .for_each(|(i, chunk)| {
                    assert_eq!(chunk.len(), lc_len(i));
                    for v in chunk {
                        *v = Variable::instance(v.index().unwrap() + 1);
                    }
                });
        });

        assert_eq!(lcmap.num_lcs(), NUM_LCS);
        for (i, lc) in lcmap.iter().enumerate() {
            assert_eq!(lc.len(), lc_len(i));
            for (_, v) in lc {
                assert_eq!(*v, Variable::instance(i + 2));
            }
        }
    }

    #[test]
    fn test_lc_map_iter_mut() {
        let mut interner = FieldInterner::<Fr>::new();
        let mut lcmap = LcMap::<Fr>::new();
        lcmap.push(
            [
                (1u8.into(), Variable::One),
                (2u8.into(), Variable::instance(2)),
            ],
            &mut interner,
        );
        lcmap.push(
            [
                (3u8.into(), Variable::witness(4)),
                (4u8.into(), Variable::instance(4)),
            ],
            &mut interner,
        );
        lcmap.lc_vars_iter_mut().for_each(|chunk| {
            for v in chunk {
                if v.is_instance() {
                    *v = Variable::instance(v.index().unwrap() + 1);
                }
            }
        });
        let flattened: Vec<_> = lcmap
            .iter()
            .flat_map(|chunk| to_non_interned_lc(chunk, &interner))
            .collect();
        let expected = vec![
            (1u8.into(), Variable::One),
            (2u8.into(), Variable::instance(3)),
            (3u8.into(), Variable::witness(4)),
            (4u8.into(), Variable::instance(5)),
        ];
        assert_eq!(flattened, expected);
    }

    #[test]
    fn test_lc_map_get() {
        let mut interner = FieldInterner::<Fr>::new();
        let mut lcmap = LcMap::<Fr>::new();
        assert!(lcmap.get(0).is_none());

        lcmap.push([(5u8.into(), Variable::witness(1))], &mut interner);
        lcmap.push(core::iter::empty(), &mut interner);

        let first: Vec<_> = to_non_interned_lc(lcmap.get(0).unwrap(), &interner).collect();
        let expected: Vec<(Fr, Variable)> = vec![(5u8.into(), Variable::witness(1))];
        assert_eq!(first, expected);
        assert_eq!(lcmap.get(1).unwrap().len(), 0);
        assert!(lcmap.get(2).is_none());
        assert!(lcmap.get(usize::MAX).is_none());
    }

    #[test]
    fn test_lc_vars_iter_mut_size_hint_empty() {
        let mut lcmap = LcMap::<Fr>::new();
        let mut it = lcmap.lc_vars_iter_mut();
        let (lower, upper) = it.size_hint();
        assert_eq!((lower, upper), (0, Some(0)));
        assert!(it.next().is_none());
    }
}