use std::convert::TryFrom;

use cess_hashers::Domain;
use num_bigint::BigUint;
use num_traits::cast::ToPrimitive;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// Parameters for deriving challenges: a layer count and the number of
/// challenges derived per partition.
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`, so the field
/// names (and, for order-sensitive formats, their order) are part of its
/// serialized form — do not rename or reorder them casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerChallenges {
    /// Layer count; returned by `layers()` but not read by `derive` /
    /// `derive_internal`.
    layers: usize,
    /// Number of challenges `derive` produces for each partition; returned by
    /// `challenges_count_all()`.
    max_count: usize,
}
impl LayerChallenges {
pub const fn new(layers: usize, max_count: usize) -> Self {
LayerChallenges { layers, max_count }
}
pub fn layers(&self) -> usize {
self.layers
}
pub fn challenges_count_all(&self) -> usize {
self.max_count
}
pub fn derive<D: Domain>(
&self,
leaves: usize,
replica_id: &D,
seed: &[u8; 32],
k: u8,
) -> Vec<usize> {
self.derive_internal(self.challenges_count_all(), leaves, replica_id, seed, k)
}
pub fn derive_internal<D: Domain>(
&self,
challenges_count: usize,
leaves: usize,
replica_id: &D,
seed: &[u8; 32],
k: u8,
) -> Vec<usize> {
assert!(leaves > 2, "Too few leaves: {}", leaves);
(0..challenges_count)
.map(|i| {
let j: u32 = ((challenges_count * k as usize) + i) as u32;
let hash = Sha256::new()
.chain(replica_id.into_bytes())
.chain(seed)
.chain(&j.to_le_bytes())
.finalize();
let big_challenge = BigUint::from_bytes_le(hash.as_ref());
let big_mod_challenge = big_challenge % (leaves - 1);
let big_mod_challenge = big_mod_challenge
.to_usize()
.expect("`big_mod_challenge` exceeds size of `usize`");
big_mod_challenge + 1
})
.collect()
}
}
/// Challenge-count requirement; the derived `Default` sets
/// `minimum_challenges` to 0.
#[derive(Default, Debug)]
pub struct ChallengeRequirements {
    /// Required minimum number of challenges (presumably enforced by callers
    /// outside this file — not read here).
    pub minimum_challenges: usize,
}
#[cfg(test)]
mod test {
    use super::*;

    use std::collections::HashSet;

    use cess_hashers::sha256::Sha256Domain;
    use rand::{thread_rng, Rng};

    #[test]
    fn test_calculate_fixed_challenges() {
        let layer_challenges = LayerChallenges::new(10, 333);
        let expected: usize = 333;
        assert_eq!(expected, layer_challenges.challenges_count_all());
        assert_eq!(10, layer_challenges.layers());
    }

    /// Challenges derived across all partitions should almost always be
    /// distinct.
    ///
    /// Each trial draws a fresh replica id and seed, so trials are
    /// independent. With 1000 challenges over 2^30 leaves, a single trial
    /// collides with probability of roughly 5e-4, so a couple of colliding
    /// trials are tolerated.
    #[test]
    fn challenge_derivation() {
        let n = 200;
        let layers = 100;
        let trials = 100;
        let challenges = LayerChallenges::new(layers, n);
        let leaves = 1 << 30;
        let partitions = 5;
        let total_challenges = partitions * n;
        let rng = &mut thread_rng();

        let mut trials_with_duplicates = 0;
        for _trial in 0..trials {
            // Fresh inputs per trial: with fixed inputs every trial would
            // derive identical challenges and the tolerance would be moot.
            let replica_id: Sha256Domain = Sha256Domain::random(rng);
            let seed: [u8; 32] = rng.gen();

            let unique: HashSet<usize> = (0..partitions)
                .flat_map(|k| challenges.derive(leaves, &replica_id, &seed, k as u8))
                .collect();
            if unique.len() < total_challenges {
                trials_with_duplicates += 1;
            }
        }
        assert!(
            trials_with_duplicates < 3,
            "{} trials produced duplicate challenges",
            trials_with_duplicates
        );
    }

    /// Deriving `n * partitions` challenges in partition 0 must equal
    /// concatenating `n` challenges from each of partitions `0..partitions`.
    #[test]
    fn challenge_partition_equivalence() {
        let n = 40;
        let leaves = 1 << 30;
        let partitions = 5;
        let layers = 100;
        let trials = 100;
        let total_challenges = n * partitions;
        let rng = &mut thread_rng();

        for _trial in 0..trials {
            let replica_id: Sha256Domain = Sha256Domain::random(rng);
            let seed: [u8; 32] = rng.gen();

            let one_partition_challenges = LayerChallenges::new(layers, total_challenges)
                .derive(leaves, &replica_id, &seed, 0);
            let many_partition_challenges = (0..partitions)
                .flat_map(|k| {
                    LayerChallenges::new(layers, n).derive(leaves, &replica_id, &seed, k as u8)
                })
                .collect::<Vec<_>>();
            assert_eq!(one_partition_challenges, many_partition_challenges);
        }
    }
}