use std::iter;
use ark_ec::CurveGroup;
use itertools::Itertools;
use crate::{algebra::AuthenticatedScalarResult, MpcFabric};
/// Computes the prefix products of a sequence of authenticated, secret-shared values.
///
/// Given shared inputs `[a_0, a_1, ..., a_{n-1}]`, returns shared outputs
/// `[a_0, a_0 * a_1, ..., a_0 * a_1 * ... * a_{n-1}]`; the `k`th output is the
/// product of the first `k + 1` inputs.
///
/// Each input is blinded in a telescoping manner with `n + 1` random inverse pairs
/// `(b_i, b_i^{-1})` drawn from the fabric:
///
/// ```text
/// d_i = b_i^{-1} * a_i * b_{i+1}
/// d_0 * d_1 * ... * d_k = b_0^{-1} * (a_0 * ... * a_k) * b_{k+1}
/// ```
///
/// The `d_i` are opened, their running products are computed on the public
/// values, and the blinding is stripped by multiplying by the shared `b_0` and
/// the shared `b_{k+1}^{-1}`. The number of batched fabric operations does not
/// depend on `n`.
///
/// Privacy note: if some input `a_i` is zero, the opened `d_i` is zero as well,
/// which reveals that fact. The blinding only hides nonzero inputs.
///
/// An empty `values` slice yields an empty result and draws no randomness.
pub fn prefix_product<C: CurveGroup>(
    values: &[AuthenticatedScalarResult<C>],
    fabric: &MpcFabric<C>,
) -> Vec<AuthenticatedScalarResult<C>> {
    // The running-product seed below indexes `blinded_open[0]`, which would
    // panic on empty input; there is nothing to compute, so bail out first
    if values.is_empty() {
        return Vec::new();
    }

    let n = values.len();
    // b_0 .. b_n: the last input a_{n-1} is right-blinded by b_n
    let (b_values, b_inv_values) = fabric.random_inverse_pairs(n + 1);

    // Telescoping blind: d_i = b_i^{-1} * a_i * b_{i+1}
    let partial_blind = AuthenticatedScalarResult::batch_mul(&b_inv_values[..n], values);
    let blinded = AuthenticatedScalarResult::batch_mul(&partial_blind, &b_values[1..]);

    // Open the blinded values.
    // NOTE(review): only the `value` of each authenticated opening is kept; confirm
    // that the MAC check on these openings is enforced somewhere, otherwise a
    // tampered opening would propagate into the outputs undetected
    let blinded_open = AuthenticatedScalarResult::open_authenticated_batch(&blinded)
        .into_iter()
        .map(|v| v.value)
        .collect_vec();

    let b0_repeat = iter::repeat(b_values[0].clone()).take(n).collect_vec();

    // Running products of the public blinded values:
    // prefixes[k] = d_0 * ... * d_k = b_0^{-1} * (a_0 * ... * a_k) * b_{k+1}
    let mut prefixes = Vec::with_capacity(n);
    let mut prefix = blinded_open[0].clone();
    prefixes.push(prefix.clone());
    for blinded_term in blinded_open[1..].iter() {
        prefix = prefix * blinded_term;
        prefixes.push(prefix.clone());
    }

    // Unblind: b_0 * prefixes[k] * b_{k+1}^{-1} = a_0 * ... * a_k
    let right_hand_terms = &b_inv_values[1..];
    let partial_unblind = AuthenticatedScalarResult::batch_mul_public(&b0_repeat, &prefixes);
    AuthenticatedScalarResult::batch_mul(&partial_unblind, right_hand_terms)
}
#[cfg(test)]
mod test {
    use futures::future;
    use itertools::Itertools;
    use rand::thread_rng;

    use crate::{
        algebra::{AuthenticatedScalarResult, Scalar},
        gadgets::prefix_product,
        test_helpers::execute_mock_mpc,
        PARTY0,
    };

    /// Tests the prefix product gadget on random values shared by party 0,
    /// checking the opened outputs of *both* parties against a locally
    /// computed expectation
    #[tokio::test]
    async fn test_prefix_prod() {
        const N: usize = 10;
        let mut rng = thread_rng();
        let values = (0..N).map(|_| Scalar::random(&mut rng)).collect_vec();

        // Expected output: the running products a_0, a_0 * a_1, ...
        let mut expected_res = vec![values[0]];
        let mut product = values[0];
        for val in values[1..].iter() {
            product *= *val;
            expected_res.push(product);
        }

        let (party0_res, party1_res) = execute_mock_mpc(|fabric| {
            let values = values.clone();
            async move {
                let allocated_values = fabric.batch_share_scalar(values, PARTY0);
                let res = prefix_product(&allocated_values, &fabric);
                let res_open = AuthenticatedScalarResult::open_authenticated_batch(&res);
                future::join_all(res_open).await
            }
        })
        .await;

        // Both parties open the same shared outputs, so both must agree with the
        // expectation; any opening error (e.g. a failed MAC check) fails the test
        let party0_res = party0_res.into_iter().collect::<Result<Vec<_>, _>>().unwrap();
        let party1_res = party1_res.into_iter().collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(party0_res, expected_res);
        assert_eq!(party1_res, expected_res);
    }
}