use crate::{
core::{
circuits::boolean::{
boolean_array::BooleanArray,
boolean_value::BooleanValue,
utils::{addition_circuit, CircuitType},
},
global_value::value::FieldValue,
},
traits::{FromLeBits, GreaterEqual, RandomBit, Reveal, Select},
types::{ArcisArray, DOUBLE_PRECISION_MANTISSA},
utils::{
crypto::rescue_prime_hash::RescuePrimeHash,
elliptic_curve::{AffineEdwardsPoint, ProjectiveEdwardsPoint},
field::ScalarField,
used_field::UsedField,
},
ArcisField,
ArcisValue,
};
use rayon::prelude::*;
use std::{iter::successors, time::Instant};
/// Giant-step bit count: giant steps are taken for the 2^A + 1 shifts `0..=2^A`.
const A: usize = 5;
/// Baby-step bit count: baby steps cover the 2^B offsets `0..2^B`.
const B: usize = 5;
/// Bit length of the discrete logarithms exercised by the tests (they sample x < 2^N).
/// Allowed to be dead because, in this file, it is only referenced from the test module.
#[allow(dead_code)]
const N: usize = A + B;
/// Errors reported by the discrete logarithm computation.
#[derive(Debug)]
pub enum DiscreteLogarithmFailure {
    /// The giant-step and baby-step hashes did not contain exactly one usable match;
    /// the payload describes what was found.
    MatchError(String),
}

impl std::fmt::Display for DiscreteLogarithmFailure {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DiscreteLogarithmFailure::MatchError(message) => {
                write!(f, "match error: {}", message)
            }
        }
    }
}

// Lets callers propagate the failure with `?` into `Box<dyn Error>` / error-chaining types.
impl std::error::Error for DiscreteLogarithmFailure {}
/// Stateless namespace grouping the phases of the secret-shared baby-step giant-step
/// discrete logarithm: `pre_processing`, then `online_phase_1` (MPC), `online_phase_2`
/// (plaintext matching) and `online_phase_3` (MPC).
#[derive(Debug, Clone)]
pub struct DiscreteLogarithm;
impl DiscreteLogarithm {
/// Pre-processing phase; it does not depend on the target point P.
///
/// Returns, in order:
/// - the public (revealed) salted hashes of the y-coordinates of the giant steps
///   `((i_star + iota) mod (2^A + 1) + 1) * 2^B * G + R` for `iota` in `0..=2^A`
///   (a vector of length `2^A + 1`),
/// - the secret `i_star`, (approximately) uniform in `0..=2^A`,
/// - the secret `j_star`, uniform in `0..2^B`,
/// - the secret salt used for hashing,
/// - the secret point `j_star * G + R`.
#[allow(dead_code, non_snake_case, clippy::type_complexity)]
pub fn pre_processing() -> (
    ArcisArray<{ (1usize << A) + 1 }>,
    ArcisValue,
    ArcisValue,
    ArcisValue,
    ProjectiveEdwardsPoint<ArcisValue>,
) {
    // 1. generate i_star and j_star
    // To generate a uniformly random element mod 2^a + 1 we generate a random bit
    // b ~ Bernoulli(1/(2^a + 1)) and rand, a uniformly random element mod 2^a.
    // If b == 1 we return 2^a else we return rand.
    // b is obtained by comparing a uniformly random DOUBLE_PRECISION_MANTISSA-bit value with
    // p = floor(2^DOUBLE_PRECISION_MANTISSA / (2^a + 1)), so P(b = 1) differs from
    // 1/(2^a + 1) by less than 2^-DOUBLE_PRECISION_MANTISSA.
    let p = (1u64 << DOUBLE_PRECISION_MANTISSA) / ((1 << A) + 1);
    let random_val = ArcisValue::from_le_bits(
        (0..DOUBLE_PRECISION_MANTISSA)
            .map(|_| BooleanValue::random())
            .collect::<Vec<BooleanValue>>(),
        false,
    );
    let b = random_val.lt(ArcisValue::from(p));
    // 2^a as a little-endian (a + 1)-bit vector
    let mut two_pow_a = vec![BooleanValue::from(false); A];
    two_pow_a.push(BooleanValue::from(true));
    // a uniformly random a-bit value, zero-extended to a + 1 bits
    let mut rand = (0..A)
        .map(|_| BooleanValue::random())
        .collect::<Vec<BooleanValue>>();
    rand.push(BooleanValue::from(false));
    let i_star = b.select(two_pow_a, rand);
    let j_star = (0..B)
        .map(|_| BooleanValue::random())
        .collect::<Vec<BooleanValue>>();

    // 2. compute the vector i_star + iota mod 2^a + 1
    // (the "1 +" from the paper is added in step 3 below)

    // Adds the (a + 1)-bit little-endian vectors `i_star` and `iota` element-wise and reduces
    // the sum modulo 2^a + 1 with one conditional subtraction; this presumes both inputs are
    // at most 2^a, as they are for the values built in this function.
    fn cyclic_shift(
        mut i_star: Vec<BooleanArray<{ (1usize << A) + 1 }>>,
        mut iota: Vec<BooleanArray<{ (1usize << A) + 1 }>>,
    ) -> Vec<BooleanArray<{ (1usize << A) + 1 }>> {
        let a = i_star.len() - 1;
        // for the below addition_circuit, i_star and iota must be a + 2 bits long
        i_star.push(BooleanArray::from(false));
        iota.push(BooleanArray::from(false));
        let mut sum = addition_circuit(
            i_star,
            iota,
            BooleanArray::from(false),
            CircuitType::default(),
        );
        // subtract 2^a + 1 from the sum
        // in two's complement, -(2^a + 1) = \bar{2^a}, where \bar denotes bitwise negation;
        // hence, on an (a + 2)-bit window, -(2^a + 1) is written 1..101 (lsb-to-msb):
        // a ones, then a zero, then a one
        let mut neg = vec![BooleanArray::from(true); a];
        neg.extend([BooleanArray::from(false), BooleanArray::from(true)]);
        let mut sum_corrected = addition_circuit(
            sum.clone(),
            neg,
            BooleanArray::from(false),
            CircuitType::default(),
        );
        // Drop the top bit of sum: sum is only selected when sum < 2^a + 1, so it then fits
        // in a + 1 bits.
        let _ = sum.pop();
        // The top bit of the (a + 2)-bit difference is its sign (presuming addition_circuit
        // returns as many bits as its inputs).
        let sign = sum_corrected
            .pop()
            .expect("sum_corrected has a + 2 bits, so it is never empty");
        // negative difference => the sum was already reduced
        sign.select(sum, sum_corrected)
    }

    // iota[i][k] is bit i of k for k in 0..=2^a: the little-endian bit decomposition of the
    // public shifts 0, 1, ..., 2^a, laid out as a + 1 vectors of length 2^a + 1.
    let iota = (0..A + 1)
        .map(|i| {
            BooleanArray::from(
                TryInto::<[BooleanValue; (1usize << A) + 1]>::try_into(
                    (0..((1usize << A) + 1))
                        .map(|iota| BooleanValue::from((iota >> i) & 1usize == 1usize))
                        .collect::<Vec<BooleanValue>>(),
                )
                .unwrap_or_else(|v: Vec<BooleanValue>| {
                    panic!(
                        "Expected a Vec of length {} (found {})",
                        (1usize << A) + 1,
                        v.len()
                    )
                }),
            )
        })
        .collect::<Vec<BooleanArray<{ (1usize << A) + 1 }>>>();
    // shifted is a vector of length a + 1 of BooleanArrays of length 2^a + 1
    // (i_star is broadcast to every lane before the addition)
    let shifted = cyclic_shift(
        i_star
            .iter()
            .cloned()
            .map(BooleanArray::<{ (1usize << A) + 1 }>::from)
            .collect::<Vec<BooleanArray<{ (1usize << A) + 1 }>>>(),
        iota,
    );

    // 3. compute GS_i_star = (shifted + 1) * 2^b * G element-wise
    let G = ProjectiveEdwardsPoint::<ArcisValue>::generator();
    // bit string of 2^b: b zeros followed by a one; this assumes mul_str reads bits lsb first
    // (NOTE(review): confirm against mul_str)
    let two_pow_b = "0".repeat(B) + "1";
    let two_pow_b_G = G.mul_str(&two_pow_b);
    // broadcast 2^b * G to a vector of length 2^a + 1
    let two_pow_b_G_vec = ProjectiveEdwardsPoint::new(
        (
            ArcisArray::<{ (1usize << A) + 1 }>::from(two_pow_b_G.X),
            ArcisArray::<{ (1usize << A) + 1 }>::from(two_pow_b_G.Y),
            ArcisArray::<{ (1usize << A) + 1 }>::from(two_pow_b_G.Z),
        ),
        two_pow_b_G.is_on_curve,
        two_pow_b_G.is_ell_torsion,
    );
    // we first perform a vector-wise multiplication between shifted and 2^b * G, and then
    // add 2^b * G (for the "1 +" from step 2 of the paper)
    let GS_i_star = two_pow_b_G_vec.mul_bits(shifted) + two_pow_b_G_vec;

    // 4. generate R = r * G for a uniformly random scalar r
    let R = FieldValue::<ScalarField>::random() * G;
    let R_vec = ProjectiveEdwardsPoint::new(
        (
            ArcisArray::from(R.X),
            ArcisArray::from(R.Y),
            ArcisArray::from(R.Z),
        ),
        R.is_on_curve,
        R.is_ell_torsion,
    );

    // 5. compute GS_i_star_R
    // given that we work with the Edwards model we hash the y-coordinate
    let GS_i_star_R = (GS_i_star + R_vec).to_affine().y;

    // 6. GS_i_star_R_prime: salted hash of every giant step, then revealed
    let salt = ArcisValue::random();
    let hasher = RescuePrimeHash::new();
    let GS_i_star_R_prime =
        hasher.digest(vec![ArcisArray::from(salt), GS_i_star_R])[0].reveal();

    // 7. compute j_star * G + R
    let j_star_G_plus_R = ProjectiveEdwardsPoint::mul_bits_generator(j_star.clone()) + R;
    (
        GS_i_star_R_prime,
        ArcisValue::from_le_bits(i_star, false),
        ArcisValue::from_le_bits(j_star, false),
        salt,
        j_star_G_plus_R,
    )
}
/// The first online phase takes:
/// - the secret P
/// - the secret salt
/// - the secret j_star * G + R
/// and returns:
/// - the public BS_j_star_R_prime
///
/// It reveals P_star = P + j_star * G + R, builds the 2^B baby steps P_star + k * G for
/// k in 0..2^B, and hashes their y-coordinates together with the salt before revealing them.
/// Progress messages and per-step timings are printed to stdout.
///
/// # Panics
/// Panics if `P` is not flagged as on the curve and in the ell-torsion subgroup.
#[allow(dead_code, non_snake_case)]
pub fn online_phase_1(
    P: AffineEdwardsPoint<ArcisValue>,
    salt: ArcisValue,
    j_star_G_plus_R: ProjectiveEdwardsPoint<ArcisValue>,
) -> ArcisArray<{ 1usize << B }> {
    // Collects one coordinate of every point into an ArcisArray of length `LEN`.
    fn coordinate_array<const LEN: usize>(
        points: &[ProjectiveEdwardsPoint<ArcisValue>],
        pick: impl Fn(&ProjectiveEdwardsPoint<ArcisValue>) -> ArcisValue,
    ) -> ArcisArray<LEN> {
        let picked: Vec<ArcisValue> = points.iter().map(pick).collect();
        let fixed: [ArcisValue; LEN] = picked.try_into().unwrap_or_else(|v: Vec<ArcisValue>| {
            panic!("Expected a Vec of length {} (found {})", LEN, v.len())
        });
        ArcisArray::from(fixed)
    }

    assert!(P.is_on_curve && P.is_ell_torsion);
    let clock = Instant::now();
    let generator = ProjectiveEdwardsPoint::<ArcisValue>::generator();

    // 1. compute P_star and open
    println!(" computing P_star");
    let P_star = (P.to_projective() + j_star_G_plus_R).reveal();
    let after_P_star = clock.elapsed();
    println!(" time: {:?}", after_P_star);

    // 2. compute BS_j_star_R: P_star, P_star + G, ..., P_star + (2^B - 1) G
    //    (exactly 2^B - 1 point additions)
    println!(" adding muls of G");
    let mut baby_steps: Vec<ProjectiveEdwardsPoint<ArcisValue>> = Vec::with_capacity(1 << B);
    let mut current = P_star;
    baby_steps.push(current);
    for _ in 1..(1usize << B) {
        current = current + generator;
        baby_steps.push(current);
    }
    let xs: ArcisArray<{ 1usize << B }> = coordinate_array(&baby_steps, |point| point.X);
    let ys: ArcisArray<{ 1usize << B }> = coordinate_array(&baby_steps, |point| point.Y);
    let zs: ArcisArray<{ 1usize << B }> = coordinate_array(&baby_steps, |point| point.Z);
    let baby_steps_vec =
        ProjectiveEdwardsPoint::new((xs, ys, zs), P.is_on_curve, P.is_ell_torsion);
    let after_baby_steps = clock.elapsed();
    println!(" time: {:?}", after_baby_steps - after_P_star);

    println!(" converting to affine coordinates");
    let BS_j_star_R = baby_steps_vec.to_affine().y;
    let after_affine = clock.elapsed();
    println!(" time: {:?}", after_affine - after_baby_steps);

    // 3. compute BS_j_star_R_prime: salted hash of every baby step, then revealed
    println!(" hashing");
    let hasher = RescuePrimeHash::new();
    let BS_j_star_R_prime =
        hasher.digest(vec![ArcisArray::from(salt), BS_j_star_R])[0].reveal();
    println!(" time: {:?}", clock.elapsed() - after_affine);
    BS_j_star_R_prime
}
/// The second online phase takes:
/// - the public GS_i_star_R_prime (giant-step hashes, as produced by `pre_processing`)
/// - the public BS_j_star_R_prime (baby-step hashes, as produced by `online_phase_1`)
/// and returns:
/// - the public matching indices i_prime and j_prime such that
///   GS_i_star_R_prime[i_prime] = BS_j_star_R_prime[j_prime]
///
/// Runs in plaintext. The union of both lists is sorted by hash so that equal hashes become
/// adjacent. Every entry is tagged with the list it came from, so decoding the match does not
/// depend on how the sort orders equal keys.
///
/// # Errors
/// Returns [`DiscreteLogarithmFailure::MatchError`] unless there is exactly one pair of
/// adjacent equal hashes and that pair consists of one giant step and one baby step.
#[allow(dead_code, non_snake_case)]
pub fn online_phase_2(
    GS_i_star_R_prime: Vec<ArcisField>,
    BS_j_star_R_prime: Vec<ArcisField>,
) -> Result<(usize, usize), DiscreteLogarithmFailure> {
    // Which input list an entry of the merged vector came from.
    #[derive(Clone, Copy, Debug)]
    enum Step {
        Giant,
        Baby,
    }

    // 4. find match
    println!(" finding matches");
    let giant_steps = GS_i_star_R_prime
        .into_iter()
        .enumerate()
        .map(|(index, hash)| (hash, Step::Giant, index));
    let baby_steps = BS_j_star_R_prime
        .into_iter()
        .enumerate()
        .map(|(index, hash)| (hash, Step::Baby, index));
    let mut all_steps = giant_steps
        .chain(baby_steps)
        .collect::<Vec<(ArcisField, Step, usize)>>();
    all_steps.par_sort_by_key(|&(hash, _, _)| hash);

    // A correct run has exactly one pair of consecutive entries with equal hashes.
    let matches = all_steps
        .windows(2)
        .filter(|w| w[0].0 == w[1].0)
        .collect::<Vec<&[(ArcisField, Step, usize)]>>();
    let decoded = match matches.as_slice() {
        [window] => match (window[0], window[1]) {
            ((_, Step::Giant, i_prime), (_, Step::Baby, j_prime))
            | ((_, Step::Baby, j_prime), (_, Step::Giant, i_prime)) => Some((i_prime, j_prime)),
            // two entries from the same list collided: not a usable match
            _ => None,
        },
        _ => None,
    };
    decoded.ok_or_else(|| {
        DiscreteLogarithmFailure::MatchError(format!(
            "Baby-steps and giant-steps are expected to have exactly one match (found {:?})",
            matches
        ))
    })
}
/// The third online phase takes:
/// - the secret i_star, j_star
/// - the public i_prime, j_prime
/// and returns:
/// - the secret x = (i + 1) 2^b - j
///
/// where i = (i_prime + i_star) mod (2^a + 1) and j = j_prime + j_star.
#[allow(dead_code, non_snake_case)]
pub fn online_phase_3(
    i_star: ArcisValue,
    j_star: ArcisValue,
    i_prime: usize,
    j_prime: usize,
) -> ArcisValue {
    // 5. reduce i_prime + i_star modulo 2^a + 1 with one conditional subtraction; this
    //    presumes both summands are at most 2^a, as produced by the earlier phases.
    let unreduced = ArcisValue::from(i_prime as u64) + i_star;
    let modulus = ArcisField::power_of_two(A) + ArcisField::from(1);
    let reduced = unreduced - ArcisValue::from(modulus);
    // a negative difference means the sum was already below the modulus
    let i = reduced
        .signed_lt(ArcisValue::from(0))
        .select(unreduced, reduced);
    let j = ArcisValue::from(j_prime as u64) + j_star;
    // 6. x = (i + 1) * 2^b - j
    let i_plus_one = i + ArcisValue::from(1);
    i_plus_one * ArcisValue::from(ArcisField::power_of_two(B)) - j
}
}
// End-to-end tests for the discrete logarithm protocol.
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::boolean::byte::Byte,
expressions::{expr::EvalValue, field_expr::FieldExpr, InputKind},
global_value::global_expr_store::with_local_expr_store_as_global,
ir_builder::{ExprStore, IRBuilder},
};
use rand::Rng;
/// End-to-end check of the protocol with the module's A, B and N:
/// evaluates the pre-processing circuit, samples a random N-bit x and computes P = x * G in
/// plaintext, runs online phases 1 and 3 as circuits and phase 2 in plaintext, and asserts
/// that the recovered x equals the sampled one.
/// Statement order matters: every `eval` and `rng.gen()` draws from the same `rng`.
#[test]
#[allow(non_snake_case)]
fn test_discrete_log() {
let rng = &mut crate::utils::test_rng::get();
// pre-processing circuit
println!("\nstarting pre-processing (MPC computation, operations carried out on vectors of length {})", 1 << A);
let pre_processing_time = Instant::now();
let mut expr_store = IRBuilder::new(true);
let pre_processing_data_ids = with_local_expr_store_as_global(
|| {
let data = DiscreteLogarithm::pre_processing();
// Output order: the 2^A + 1 revealed giant-step hashes, then i_star, j_star, salt,
// and the X, Y, Z coordinates of j_star * G + R.
let mut data_ids = data
.0
.into_iter()
.map(|val| val.get_id())
.collect::<Vec<usize>>();
data_ids.push(data.1.get_id());
data_ids.push(data.2.get_id());
data_ids.push(data.3.get_id());
data_ids.push(data.4.X.get_id());
data_ids.push(data.4.Y.get_id());
data_ids.push(data.4.Z.get_id());
data_ids
},
&mut expr_store,
);
let ir = expr_store.into_ir(pre_processing_data_ids);
// the pre-processing circuit has no inputs, hence the empty input map
let pre_processing_data = ir
.eval(rng, &mut vec![].into_iter().enumerate().collect())
.map(|x| {
x.into_iter()
.map(ArcisField::eval_value_to_field)
.collect::<Vec<ArcisField>>()
})
.unwrap();
// split the flat output back into the giant-step hashes and the remaining secrets
let GS_i_star_R_prime = pre_processing_data
.iter()
.copied()
.take((1 << A) + 1)
.collect::<Vec<ArcisField>>();
let pre_processing_data = pre_processing_data
.into_iter()
.skip((1 << A) + 1)
.collect::<Vec<ArcisField>>();
let i_star = pre_processing_data[0];
let j_star = pre_processing_data[1];
let salt = pre_processing_data[2];
let j_star_G_plus_R = (
pre_processing_data[3],
pre_processing_data[4],
pre_processing_data[5],
);
println!(" time: {:?}", pre_processing_time.elapsed());
// generate x (N bits)
let mut x_bits_expected = [false; N];
for bit in x_bits_expected.iter_mut() {
*bit = rng.gen();
}
// compute P = x * G
let P = ProjectiveEdwardsPoint::<ArcisField>::mul_bits_generator(x_bits_expected.to_vec())
.to_affine();
// online phase 1
println!("\nstarting online phase 1 (MPC computation, operations carried out on vectors of length {})", 1 << B);
let mut expr_store = IRBuilder::new(true);
// six secret field inputs: P.x, P.y, salt and the projective coordinates of j_star * G + R
let online_phase_1_input_ids = (0..6)
.map(|i| {
expr_store.push_field(FieldExpr::Input(
i,
FieldBounds::<ArcisField>::All.as_input_info(InputKind::Secret),
))
})
.collect::<Vec<usize>>();
let mut online_phase_1_inputs_map = [
P.x,
P.y,
salt,
j_star_G_plus_R.0,
j_star_G_plus_R.1,
j_star_G_plus_R.2,
]
.into_iter()
.map(EvalValue::Base)
.enumerate()
.collect();
let online_phase_1_output_ids = with_local_expr_store_as_global(
|| {
// P is a multiple of the generator; both flags are set so that the assertion
// at the start of online_phase_1 passes
let P = AffineEdwardsPoint::new(
(
ArcisValue::from_id(online_phase_1_input_ids[0]),
ArcisValue::from_id(online_phase_1_input_ids[1]),
),
true,
true,
);
let salt = ArcisValue::from_id(online_phase_1_input_ids[2]);
let j_star_G_plus_R = ProjectiveEdwardsPoint::new(
(
ArcisValue::from_id(online_phase_1_input_ids[3]),
ArcisValue::from_id(online_phase_1_input_ids[4]),
ArcisValue::from_id(online_phase_1_input_ids[5]),
),
true,
true,
);
let BS_j_star_R_prime = DiscreteLogarithm::online_phase_1(P, salt, j_star_G_plus_R);
BS_j_star_R_prime
.into_iter()
.map(|val| val.get_id())
.collect::<Vec<usize>>()
},
&mut expr_store,
);
let ir = expr_store.into_ir(online_phase_1_output_ids);
let online_phase_1_output = ir
.eval(rng, &mut online_phase_1_inputs_map)
.map(|x| {
x.into_iter()
.map(ArcisField::eval_value_to_field)
.collect::<Vec<ArcisField>>()
})
.unwrap();
let BS_j_star_R_prime = online_phase_1_output;
// online phase 2
println!(
"\nstarting online phase 2 (plaintext computation, operations carried out on vectors of length {})",
(1 << A) + 1 + (1 << B)
);
let online_phase_2_time = Instant::now();
let (i_prime, j_prime) =
DiscreteLogarithm::online_phase_2(GS_i_star_R_prime, BS_j_star_R_prime).unwrap();
println!(" time: {:?}", online_phase_2_time.elapsed());
// online phase 3
println!("\nstarting online phase 3 (MPC computation)");
let online_phase_3_time = Instant::now();
let mut expr_store = IRBuilder::new(true);
// two secret field inputs: i_star and j_star; i_prime and j_prime are public constants
let online_phase_3_input_ids = (0..2)
.map(|i| {
expr_store.push_field(FieldExpr::Input(
i,
FieldBounds::<ArcisField>::All.as_input_info(InputKind::Secret),
))
})
.collect::<Vec<usize>>();
let mut online_phase_3_inputs_map = [i_star, j_star]
.into_iter()
.map(EvalValue::Base)
.enumerate()
.collect();
let online_phase_3_output_ids = with_local_expr_store_as_global(
|| {
let i_star = ArcisValue::from_id(online_phase_3_input_ids[0]);
let j_star = ArcisValue::from_id(online_phase_3_input_ids[1]);
let x = DiscreteLogarithm::online_phase_3(i_star, j_star, i_prime, j_prime);
vec![x.get_id()]
},
&mut expr_store,
);
let ir = expr_store.into_ir(online_phase_3_output_ids);
let online_phase_3_output = ir
.eval(rng, &mut online_phase_3_inputs_map)
.map(|x| {
x.into_iter()
.map(ArcisField::eval_value_to_field)
.collect::<Vec<ArcisField>>()
})
.unwrap();
println!(" time: {:?}\n", online_phase_3_time.elapsed());
// Decode the recovered x: little-endian bytes -> bits, keeping the low N bits
// (assumes Byte::to_vec yields bits lsb first, matching x_bits_expected's order).
let bytes = online_phase_3_output[0].to_le_bytes();
let x_bits = bytes
.into_iter()
.flat_map(|byte| Byte::from(byte).to_vec())
.take(N)
.collect::<Vec<bool>>();
assert_eq!(x_bits, x_bits_expected);
}
}