use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::{
boolean::{boolean_value::BooleanValue, utils::decoder_circuit},
traits::arithmetic_circuit::ArithmeticCircuit,
},
expressions::{expr::EvalFailure, field_expr::FieldExpr},
global_value::value::FieldValue,
},
traits::{GetBit, Select},
utils::used_field::UsedField,
};
use num_traits::ToPrimitive;
/// One-hot encodes `x` into exactly `len` boolean wires.
///
/// For an in-range value `x` in `0..len`, wire `x` is `true` and all other wires are `false`
/// (assuming `decoder_circuit` sets output `k` iff its input bits encode `k`). The bounds of `x`
/// are used to keep the circuit small:
/// - if `x` can never lie in `0..len`, every wire is the constant `false`;
/// - if `x` is a constant, a single constant `true` wire is placed at position `x`;
/// - otherwise only the low bits needed to distinguish the reachable in-range values are
///   decoded.
///
/// Because bits are truncated, an out-of-range `x` may alias onto an in-range wire. Callers
/// presumably rely on out-of-range indexing being undefined behaviour (`Index::eval` reports
/// it as such).
///
/// Returns an empty vector when `len == 0`.
pub fn one_hot_encode<F: ActuallyUsedField>(x: FieldValue<F>, len: usize) -> Vec<BooleanValue> {
    // Nothing to encode. This also keeps the `len - 1` computations below from underflowing.
    if len == 0 {
        return Vec::new();
    }
    let bounds = x.bounds();
    let (min, max) = bounds.min_and_max(true);
    if max.is_lt_zero() || (min.is_ge_zero() && F::from(len as u64) <= min) {
        return vec![BooleanValue::from(false); len];
    }
    // `ohe` encodes the reachable window of values. When `min >= 0` it is relative to `min`;
    // otherwise it starts at 0.
    let ohe = if min.eq(&max) {
        // Constant index (necessarily 0 <= min < len here).
        vec![BooleanValue::from(true)]
    } else if min.is_ge_zero() {
        // Shift so the smallest reachable value maps to wire 0 of the window.
        let x_sub = x - FieldValue::from(min);
        let min_usize = min.to_unsigned_number().to_usize().unwrap();
        // Largest shifted value that can land inside `0..len`.
        let gap = (len - 1 - min_usize).min(
            (max - min)
                .to_unsigned_number()
                .to_usize()
                .unwrap_or(usize::MAX),
        );
        let bin_size = gap.max(1).ilog2() as usize + 1;
        let bits = (0..bin_size)
            .map(|i| x_sub.get_bit(i, false))
            .collect::<Vec<BooleanValue>>();
        decoder_circuit(bits)
            .into_iter()
            .take(gap + 1)
            .collect::<Vec<BooleanValue>>()
    } else {
        // `x` may be negative: decompose it as a signed value and keep only the low bits
        // needed to address `0..=min(len - 1, max)`.
        let bin_size = bounds.bin_size(true);
        let bits = (0..bin_size)
            .map(|i| x.get_bit(i, true))
            .collect::<Vec<BooleanValue>>();
        let n_bits_to_keep = F::from((len - 1) as u64)
            .min(max, false)
            .to_unsigned_number()
            .to_usize()
            .unwrap()
            .max(1)
            .ilog2() as usize
            + 1;
        decoder_circuit(
            bits.into_iter()
                .take(n_bits_to_keep)
                .collect::<Vec<BooleanValue>>(),
        )
        .into_iter()
        .take(
            F::from(len as u64)
                .min(max + F::ONE, false)
                .to_unsigned_number()
                .to_usize()
                .unwrap(),
        )
        .collect::<Vec<BooleanValue>>()
    };
    // Values below a non-negative `min` are unreachable, so the window starts at wire `min`.
    let mut res = if min.is_ge_zero() {
        let offset = min.to_unsigned_number().to_usize().unwrap();
        let mut padded = vec![BooleanValue::from(false); offset];
        padded.extend(ohe);
        padded
    } else {
        ohe
    };
    // Every branch above yields at most `len` wires. Pad the unreachable tail with `false`.
    debug_assert!(res.len() <= len);
    res.resize(len, BooleanValue::from(false));
    res
}
/// Circuit reading `container[index]` where `index` is a circuit value.
#[derive(Clone, Debug, Default)]
pub struct Index;
impl Index {
    /// Returns `container[index]`.
    ///
    /// The signed bounds of `index` are checked first so that easy cases skip the one-hot
    /// selection:
    /// - an index that can never be in `0..container.len()` yields the constant zero;
    /// - an index whose signed maximum is zero yields the first element;
    /// - an index whose signed minimum is the last position yields the last element.
    ///
    /// Otherwise the result is the sum over `i` of `[index == i] * container[i]`, built from
    /// [`one_hot_encode`]. `container` should be non-empty: the fast paths read `container[0]`
    /// and compute `container.len() - 1`.
    pub fn index<F: ActuallyUsedField>(
        container: Vec<FieldValue<F>>,
        index: FieldValue<F>,
    ) -> FieldValue<F> {
        let index_bounds = index.bounds();
        let lowest = index_bounds.signed_min();
        let highest = index_bounds.signed_max();
        let never_in_range = highest.is_lt_zero()
            || (lowest.is_ge_zero() && lowest >= F::from(container.len() as u64));
        if never_in_range {
            return FieldValue::from(F::ZERO);
        }
        if highest == F::ZERO {
            return container[0];
        }
        if lowest == F::from(container.len() as u64 - 1) {
            return *container.last().unwrap();
        }
        let selectors = one_hot_encode(index, container.len());
        let zero = FieldValue::<F>::from(0);
        // Unselected elements contribute zero, so the sum is the selected element.
        container
            .into_iter()
            .zip(selectors)
            .map(|(element, selected)| selected.select(element, zero))
            .fold(zero, |acc, term| acc + term)
    }
}
impl<F: UsedField> ArithmeticCircuit<F> for Index {
fn eval(&self, mut x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
let index = x.pop().unwrap();
if index.is_lt_zero() || (index - F::from((x.len() - 1) as u64)).is_gt_zero() {
return EvalFailure::err_ub("index out of range");
}
Ok(vec![x[index.to_unsigned_number().to_usize().unwrap()]])
}
fn bounds(&self, mut bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
let bounds_index = bounds.pop().unwrap();
let signed_min = bounds_index.signed_min();
let signed_max = bounds_index.signed_max();
let res_bounds = if signed_max.is_lt_zero()
|| (signed_min.is_ge_zero() && signed_min >= F::from(bounds.len() as u64))
{
FieldBounds::new(F::ZERO, F::ZERO)
} else if signed_max == F::ZERO {
bounds[0]
} else if signed_min == F::from((bounds.len() - 1) as u64) {
*bounds.last().unwrap()
} else {
let init_bounds =
if signed_min.is_lt_zero() || signed_max > F::from((bounds.len() - 1) as u64) {
FieldBounds::new(F::ZERO, F::ZERO)
} else {
FieldBounds::Empty
};
bounds
.into_iter()
.fold(init_bounds, |acc, bds| acc.union(bds))
};
vec![res_bounds]
}
fn run(&self, mut vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
where
F: ActuallyUsedField,
{
let index = vals.pop().unwrap();
vec![Self::index(vals, index)]
}
}
/// Circuit computing `container[index] = op(container[index], value)` for a circuit-valued
/// `index`.
#[derive(Clone, Debug)]
pub struct IndexOpAssign<F: ActuallyUsedField> {
    /// Update expression. Leaf `false` stands for the current element and leaf `true` for the
    /// assigned value.
    op: FieldExpr<F, bool>,
}
impl<F: ActuallyUsedField> IndexOpAssign<F> {
    /// Creates the circuit from its update expression `op`.
    // NOTE(review): in this file `new` is only called from the tests; the allow presumably
    // silences dead-code warnings in non-test builds. Confirm before removing.
    #[allow(unused)]
    pub fn new(op: FieldExpr<F, bool>) -> Self {
        Self { op }
    }
}
impl<F: ActuallyUsedField> IndexOpAssign<F> {
    /// Returns `container` with the element at `index` replaced by
    /// `op(container[index], value)`.
    ///
    /// The signed bounds of `index` are checked first:
    /// - if `index` can never be in `0..container.len()`, `container` is returned unchanged;
    /// - if its signed maximum is zero, only the first element is updated;
    /// - if its signed minimum is the last position, only the last element is updated.
    ///
    /// Otherwise every element is conditionally replaced through a one-hot selector. When `op`
    /// is a top-level `Add`/`Sub`, `op` is evaluated per element. For any other `op`, the
    /// selected element is first read with [`Index::index`] and `op` is evaluated only once.
    pub fn index_op_assign(
        &self,
        container: Vec<FieldValue<F>>,
        index: FieldValue<F>,
        value: FieldValue<F>,
    ) -> Vec<FieldValue<F>> {
        // Evaluates `op` with leaf `false` bound to `current` and leaf `true` bound to `value`.
        let combine = |current: FieldValue<F>| {
            let operands = [current, value];
            FieldValue::new(self.op.clone().apply(|leaf| operands[leaf as usize]))
        };
        let index_bounds = index.bounds();
        let lowest = index_bounds.signed_min();
        let highest = index_bounds.signed_max();
        if highest.is_lt_zero()
            || (lowest.is_ge_zero() && lowest >= F::from(container.len() as u64))
        {
            return container;
        }
        let mut updated = container;
        if highest == F::ZERO {
            updated[0] = combine(updated[0]);
            return updated;
        }
        let len = updated.len();
        if lowest == F::from(len as u64 - 1) {
            updated[len - 1] = combine(updated[len - 1]);
            return updated;
        }
        let selectors = one_hot_encode(index, len);
        if matches!(self.op, FieldExpr::Add(..) | FieldExpr::Sub(..)) {
            // Evaluate `op` against every element and keep the result only where selected.
            updated
                .into_iter()
                .zip(selectors)
                .map(|(current, selected)| selected.select(combine(current), current))
                .collect()
        } else {
            // Read the selected element once, evaluate `op` once, then scatter the result.
            let replacement = combine(Index::index(updated.clone(), index));
            updated
                .into_iter()
                .zip(selectors)
                .map(|(current, selected)| selected.select(replacement, current))
                .collect()
        }
    }
}
impl<F: ActuallyUsedField> ArithmeticCircuit<F> for IndexOpAssign<F> {
    /// Plain evaluation. `x` is `[container..., value, index]`; returns the container with
    /// `container[index] = op(container[index], value)`.
    ///
    /// # Errors
    /// Returns an undefined-behaviour failure when `index` is negative or not smaller than the
    /// container length. Also propagates any failure from evaluating `op`.
    fn eval(&self, mut x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        let index = x.pop().unwrap();
        let value = x.pop().unwrap();
        // Reject negative indices explicitly, as `Index::eval` does, instead of relying on how
        // the field ordering places them relative to `x.len()`. Without this, a negative index
        // that slipped through would panic in the `to_usize().unwrap()` / `x[index]` below
        // rather than being reported as UB.
        if index.is_lt_zero() || index >= F::from(x.len() as u64) {
            return EvalFailure::err_ub("index out of range");
        }
        let index = index.to_unsigned_number().to_usize().unwrap();
        let arr = [x[index], value];
        x[index] = self.op.clone().apply(|i| arr[i as usize]).eval()?;
        Ok(x)
    }
    /// Output bounds for `[container bounds..., value bounds, index bounds]`, following the
    /// same case split as `IndexOpAssign::index_op_assign`.
    fn bounds(&self, mut bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let bounds_index = bounds.pop().unwrap();
        let bounds_value = bounds.pop().unwrap();
        let signed_min = bounds_index.signed_min();
        let signed_max = bounds_index.signed_max();
        // Bounds of `op(current, value)` for a given bound on the current element.
        let op_bounds = |current: FieldBounds<F>| {
            let arr = [current, bounds_value];
            self.op.clone().apply(|i| arr[i as usize]).bounds()
        };
        if signed_max.is_lt_zero()
            || (signed_min.is_ge_zero() && signed_min >= F::from(bounds.len() as u64))
        {
            bounds
        } else if signed_max == F::ZERO {
            let mut res = bounds;
            res[0] = op_bounds(res[0]);
            res
        } else if signed_min == F::from(bounds.len() as u64 - 1) {
            let len = bounds.len();
            let mut res = bounds;
            res[len - 1] = op_bounds(res[len - 1]);
            res
        } else {
            let len = bounds.len();
            let bit_bounds = FieldBounds::from((F::ZERO, F::ONE));
            // Each element either keeps its old bounds or takes the updated ones.
            let mut res_bounds = bounds
                .into_iter()
                .map(|x_bounds| {
                    FieldExpr::bounds(FieldExpr::Where(bit_bounds, op_bounds(x_bounds), x_bounds))
                })
                .collect::<Vec<FieldBounds<F>>>();
            // Same widening as `Index::bounds`: include zero whenever the index may fall
            // outside the container.
            if signed_min.is_lt_zero() || (signed_max - F::from((len - 1) as u64)).is_gt_zero() {
                let zero = FieldBounds::new(F::ZERO, F::ZERO);
                for element in &mut res_bounds {
                    *element = element.union(zero);
                }
            }
            res_bounds
        }
    }
    /// Builds the circuit. `vals` is `[container..., value, index]`.
    fn run(&self, mut vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        let index = vals.pop().unwrap();
        let value = vals.pop().unwrap();
        self.index_op_assign(vals, index, value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{
            circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
            expressions::field_expr::expr_lincomb,
        },
        utils::field::ScalarField,
    };
    use rand::Rng;

    /// Draws the number of circuit inputs shared by both generators. The base is 9, or 300
    /// with probability 1/8, and it grows in steps of 3 while a 7/8 coin keeps succeeding.
    fn gen_input_count<R: Rng + ?Sized>(rng: &mut R) -> usize {
        let base = if rng.gen_bool(0.125) { 300 } else { 9 };
        let mut extra_steps = 0;
        while rng.gen_bool(0.875) {
            extra_steps += 1;
        }
        base + 3 * extra_steps
    }

    impl TestedArithmeticCircuit<ScalarField> for Index {
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            Self
        }
        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
            gen_input_count(rng)
        }
    }

    impl TestedArithmeticCircuit<ScalarField> for IndexOpAssign<ScalarField> {
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            // Leaf `false` is the current element, leaf `true` the assigned value.
            let op: FieldExpr<ScalarField, bool> = match rng.gen_range(0..6) {
                0 => expr_lincomb!((true, 1); 0),
                1 => FieldExpr::Add(false, true),
                2 => FieldExpr::Sub(false, true),
                3 => FieldExpr::Mul(false, true),
                4 => FieldExpr::Div(false, true),
                5 => FieldExpr::Rem(false, true),
                _ => unreachable!("variant more than 5 should not be possible"),
            };
            Self::new(op)
        }
        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
            gen_input_count(rng)
        }
    }

    #[test]
    fn tested_index() {
        Index::test(64, 4)
    }

    #[test]
    fn tested_index_op_assign() {
        IndexOpAssign::test(64, 4)
    }
}