use super::{
super::{group::SmallScalar, variant::Variant, Error},
hash_with_namespace,
};
#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
use commonware_math::algebra::Space;
use commonware_parallel::Strategy;
use rand_core::CryptoRngCore;
/// Binary segment tree over `(public key, signature)` pairs.
///
/// Node `1` is the root and node `k` has children `2k` and `2k + 1`. A leaf stores one input
/// pair; an internal node stores the component-wise sum of its two children. Verifying one
/// node's aggregate therefore tests its whole range with a single `V::verify` call.
struct SegmentTree<V: Variant> {
    /// Number of leaves (input pairs).
    len: usize,
    /// Node storage indexed from 1; slots that are not part of the tree stay `None`.
    tree: Vec<Option<(V::Public, V::Signature)>>,
}

impl<V: Variant> SegmentTree<V> {
    /// Builds the tree over `leaves`, summing every pair of children into their parent.
    ///
    /// An empty slice yields a tree with no storage.
    fn build(leaves: &[(V::Public, V::Signature)]) -> Self {
        let len = leaves.len();
        // 4 * len slots is the customary bound for a midpoint-split, 1-indexed segment tree
        // (and is zero, i.e. no allocation, for an empty input).
        let mut tree = vec![None; 4 * len];
        if len > 0 {
            Self::fill(&mut tree, leaves, 1, 0, len);
        }
        Self { len, tree }
    }

    /// Populates `node`, which covers `leaves[start..end]`, together with all of its descendants.
    fn fill(
        tree: &mut [Option<(V::Public, V::Signature)>],
        leaves: &[(V::Public, V::Signature)],
        node: usize,
        start: usize,
        end: usize,
    ) {
        if end - start == 1 {
            tree[node] = Some(leaves[start]);
            return;
        }
        let mid = start + (end - start) / 2;
        Self::fill(tree, leaves, 2 * node, start, mid);
        Self::fill(tree, leaves, 2 * node + 1, mid, end);
        let (left_pk, left_sig) = tree[2 * node].expect("left child built");
        let (right_pk, right_sig) = tree[2 * node + 1].expect("right child built");
        tree[node] = Some((left_pk + &right_pk, left_sig + &right_sig));
    }

    /// Returns the (tree-relative) indices of leaves whose pair fails `V::verify` against `hm`.
    ///
    /// Only subtrees whose aggregate fails are descended into. When `root_invalid` is true the
    /// caller already knows the root aggregate fails, so the root check is skipped and the
    /// search starts at the root's children (or reports leaf `0` directly for a one-leaf tree).
    /// Indices are returned in discovery order, not sorted.
    fn verify(&self, hm: &V::Signature, root_invalid: bool) -> Vec<usize> {
        let mut invalid = Vec::new();
        let mut pending = match (self.len, root_invalid) {
            (0, _) => return invalid,
            (1, true) => {
                invalid.push(0);
                return invalid;
            }
            (n, true) => {
                let mid = n / 2;
                vec![(2usize, 0usize, mid), (3, mid, n)]
            }
            (n, false) => vec![(1usize, 0usize, n)],
        };
        while let Some((node, start, end)) = pending.pop() {
            let (pk, sig) = self.tree[node].expect("node exists");
            if V::verify(&pk, hm, &sig).is_ok() {
                // The whole range is valid; nothing to report below this node.
                continue;
            }
            if start + 1 == end {
                invalid.push(start);
                continue;
            }
            let mid = start + (end - start) / 2;
            pending.push((2 * node, start, mid));
            pending.push((2 * node + 1, mid, end));
        }
        invalid
    }
}
/// Returns the ascending indices of the entries in `entries` whose `(public, signature)` pair
/// fails `V::verify` against `hm`.
///
/// The slice is split into contiguous chunks, one per unit of `strategy.parallelism_hint()`.
/// Each chunk is searched with its own [`SegmentTree`].
///
/// `aggregate_invalid` tells this function that the caller already saw the aggregate over all
/// of `entries` fail. It is only used when everything fits in a single chunk, where it lets the
/// tree skip re-checking its root. Per-chunk roots in the parallel path are always checked.
fn bisect<V: Variant>(
    entries: &[(V::Public, V::Signature)],
    hm: &V::Signature,
    aggregate_invalid: bool,
    strategy: &impl Strategy,
) -> Vec<usize> {
    if entries.is_empty() {
        return Vec::new();
    }
    // Treat a hint of zero as sequential rather than dividing by zero in `div_ceil` below.
    let par_hint = strategy.parallelism_hint().max(1);
    let chunk_size = entries.len().div_ceil(par_hint);
    if entries.len() <= chunk_size {
        let mut out = SegmentTree::<V>::build(entries).verify(hm, aggregate_invalid);
        out.sort_unstable();
        return out;
    }
    let mut out = strategy.fold(
        entries.chunks(chunk_size).enumerate(),
        // The identity may be invoked once per fold partition, and failing entries are
        // normally few. Start empty instead of reserving `entries.len()` slots every time.
        Vec::new,
        |mut acc, (i, chunk)| {
            // Shift chunk-relative indices back to positions within `entries`.
            let offset = i * chunk_size;
            acc.extend(
                SegmentTree::<V>::build(chunk)
                    .verify(hm, false)
                    .into_iter()
                    .map(|j| offset + j),
            );
            acc
        },
        |mut acc_l, mut acc_r| {
            acc_l.append(&mut acc_r);
            acc_l
        },
    );
    // Partial results are merged in no particular order; callers get ascending indices.
    out.sort_unstable();
    out
}
/// Batch-verifies signatures from many signers over the same `message` (under `namespace`).
///
/// Returns the ascending indices of the entries whose signature does not verify. An empty
/// vector means the batch passed.
///
/// Every entry is weighted by its own random scalar drawn from `rng`, and a single aggregate
/// check is performed first. Only if that aggregate fails are the weighted entries bisected to
/// locate the offending ones.
pub fn verify_same_message<R, V>(
    rng: &mut R,
    namespace: &[u8],
    message: &[u8],
    entries: &[(V::Public, V::Signature)],
    par: &impl Strategy,
) -> Vec<usize>
where
    R: CryptoRngCore,
    V: Variant,
{
    if entries.is_empty() {
        return Vec::new();
    }
    let hm = hash_with_namespace::<V>(V::MESSAGE, namespace, message);

    // One random weight per entry, drawn in entry order.
    let weights: Vec<SmallScalar> = entries
        .iter()
        .map(|_| SmallScalar::random(&mut *rng))
        .collect();
    let (publics, signatures): (Vec<V::Public>, Vec<V::Signature>) =
        entries.iter().cloned().unzip();

    // Fast path: check the weighted aggregate of the whole batch at once.
    let (agg_public, agg_signature) = par.join(
        || V::Public::msm(&publics, &weights, par),
        || V::Signature::msm(&signatures, &weights, par),
    );
    let all_valid = V::verify(&agg_public, &hm, &agg_signature).is_ok();
    if all_valid {
        return Vec::new();
    }

    // Slow path: search the weighted entries. The aggregate is already known to fail, which
    // is what `bisect`'s `true` argument communicates.
    let weighted = par.map_collect_vec(
        publics.iter().zip(signatures.iter()).zip(weights.iter()),
        |((&public, &signature), weight)| (public * weight, signature * weight),
    );
    bisect::<V>(&weighted, &hm, true, par)
}
/// Batch-verifies many signatures produced by the single signer `public`.
///
/// Each entry is `(namespace, message, signature)`. Every message is hashed, then each hash
/// and its signature are weighted by a per-entry random scalar drawn from `rng`. The weighted
/// sums are checked with one `V::verify` call. An empty batch returns `Ok(())`.
///
/// # Errors
///
/// Returns the error from `V::verify` when the combined check fails, which indicates that at
/// least one signature in the batch is invalid. It does not say which one.
pub fn verify_same_signer<'a, R, V, I>(
    rng: &mut R,
    public: &V::Public,
    entries: I,
    strategy: &impl Strategy,
) -> Result<(), Error>
where
    R: CryptoRngCore,
    V: Variant,
    I: IntoIterator<Item = &'a (&'a [u8], &'a [u8], V::Signature)>,
{
    let batch: Vec<&(&[u8], &[u8], V::Signature)> = entries.into_iter().collect();
    if batch.is_empty() {
        return Ok(());
    }

    // One random weight per entry, drawn in entry order.
    let weights: Vec<SmallScalar> = batch
        .iter()
        .map(|_| SmallScalar::random(&mut *rng))
        .collect();
    let hashes: Vec<V::Signature> = strategy.map_collect_vec(batch.iter(), |(ns, msg, _)| {
        hash_with_namespace::<V>(V::MESSAGE, ns, msg)
    });
    let signatures: Vec<V::Signature> = batch.iter().map(|entry| entry.2).collect();

    // Both weighted sums use the same weights, so one pairing check covers the whole batch.
    let (combined_hash, combined_signature) = strategy.join(
        || V::Signature::msm(&hashes, &weights, strategy),
        || V::Signature::msm(&signatures, &weights, strategy),
    );
    V::verify(public, &combined_hash, &combined_signature)
}
#[cfg(test)]
mod tests {
    use super::{
        super::{
            super::group::Scalar, aggregate, hash_with_namespace, keypair, sign_message,
            verify_message,
        },
        *,
    };
    use crate::bls12381::primitives::variant::{MinPk, MinSig};
    use commonware_math::algebra::{CryptoGroup, Random};
    use commonware_parallel::{Rayon, Sequential};
    use commonware_utils::{test_rng, NZUsize};

    fn verify_same_signer_correct<V: Variant>() {
        let mut rng = test_rng();
        let (private, public) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let messages: &[(&[u8], &[u8])] = &[
            (namespace, b"Message 1"),
            (namespace, b"Message 2"),
            (namespace, b"Message 3"),
        ];
        let entries: Vec<_> = messages
            .iter()
            .map(|(ns, msg)| (*ns, *msg, sign_message::<V>(&private, ns, msg)))
            .collect();
        verify_same_signer::<_, V, _>(&mut rng, &public, &entries, &Sequential)
            .expect("valid signatures should be accepted");
        let strategy = Rayon::new(NZUsize!(4)).unwrap();
        verify_same_signer::<_, V, _>(&mut rng, &public, &entries, &strategy)
            .expect("valid signatures should be accepted with parallel strategy");
    }

    #[test]
    fn test_verify_same_signer_correct() {
        verify_same_signer_correct::<MinPk>();
        verify_same_signer_correct::<MinSig>();
    }

    fn verify_same_signer_wrong_signature<V: Variant>() {
        let mut rng = test_rng();
        let (private, public) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let messages: &[(&[u8], &[u8])] = &[
            (namespace, b"Message 1"),
            (namespace, b"Message 2"),
            (namespace, b"Message 3"),
        ];
        let mut entries: Vec<_> = messages
            .iter()
            .map(|(ns, msg)| (*ns, *msg, sign_message::<V>(&private, ns, msg)))
            .collect();
        let random_scalar = Scalar::random(&mut rng);
        entries[1].2 += &(V::Signature::generator() * &random_scalar);
        let result = verify_same_signer::<_, V, _>(&mut rng, &public, &entries, &Sequential);
        assert!(result.is_err(), "corrupted signature should be rejected");
    }

    #[test]
    fn test_verify_same_signer_wrong_signature() {
        verify_same_signer_wrong_signature::<MinPk>();
        verify_same_signer_wrong_signature::<MinSig>();
    }

    fn rejects_malleability<V: Variant>() {
        let mut rng = test_rng();
        let (private, public) = keypair::<_, V>(&mut rng);
        let namespace = b"test";
        let msg1: &[u8] = b"message 1";
        let msg2: &[u8] = b"message 2";
        let sig1 = sign_message::<V>(&private, namespace, msg1);
        let sig2 = sign_message::<V>(&private, namespace, msg2);
        verify_message::<V>(&public, namespace, msg1, &sig1).expect("sig1 should be valid");
        verify_message::<V>(&public, namespace, msg2, &sig2).expect("sig2 should be valid");
        // Shift one signature down and the other up by the same amount: each becomes invalid,
        // but their plain (unweighted) sum is unchanged.
        let random_scalar = Scalar::random(&mut rng);
        let tweak = V::Signature::generator() * &random_scalar;
        let forged_sig1 = sig1 - &tweak;
        let forged_sig2 = sig2 + &tweak;
        assert!(
            verify_message::<V>(&public, namespace, msg1, &forged_sig1).is_err(),
            "forged sig1 should be invalid individually"
        );
        assert!(
            verify_message::<V>(&public, namespace, msg2, &forged_sig2).is_err(),
            "forged sig2 should be invalid individually"
        );
        let forged_agg = aggregate::combine_signatures::<V, _>(&[forged_sig1, forged_sig2]);
        let valid_agg = aggregate::combine_signatures::<V, _>(&[sig1, sig2]);
        assert_eq!(forged_agg, valid_agg, "aggregates should be equal");
        let hm1 = hash_with_namespace::<V>(V::MESSAGE, namespace, msg1);
        let hm2 = hash_with_namespace::<V>(V::MESSAGE, namespace, msg2);
        let hm_sum = hm1 + &hm2;
        V::verify(&public, &hm_sum, forged_agg.inner())
            .expect("naive aggregate verification accepts forged aggregate");
        let forged_entries: Vec<(&[u8], &[u8], _)> = vec![
            (namespace, msg1, forged_sig1),
            (namespace, msg2, forged_sig2),
        ];
        let result = verify_same_signer::<_, V, _>(&mut rng, &public, &forged_entries, &Sequential);
        assert!(
            result.is_err(),
            "batch verification should reject forged signatures"
        );
        let valid_entries: Vec<(&[u8], &[u8], _)> =
            vec![(namespace, msg1, sig1), (namespace, msg2, sig2)];
        verify_same_signer::<_, V, _>(&mut rng, &public, &valid_entries, &Sequential)
            .expect("batch verification should accept valid signatures");
    }

    #[test]
    fn test_rejects_malleability() {
        rejects_malleability::<MinPk>();
        rejects_malleability::<MinSig>();
    }

    fn verify_same_message_identifies_invalid<V: Variant>() {
        let mut rng = test_rng();
        let namespace = b"test";
        let message: &[u8] = b"message";
        let mut entries: Vec<(V::Public, V::Signature)> = (0..10)
            .map(|_| {
                let (private, public) = keypair::<_, V>(&mut rng);
                (public, sign_message::<V>(&private, namespace, message))
            })
            .collect();
        let strategy = Rayon::new(NZUsize!(4)).unwrap();

        // Empty and fully valid batches report nothing.
        assert!(
            verify_same_message::<_, V>(&mut rng, namespace, message, &[], &Sequential)
                .is_empty()
        );
        assert!(
            verify_same_message::<_, V>(&mut rng, namespace, message, &entries, &Sequential)
                .is_empty()
        );
        assert!(
            verify_same_message::<_, V>(&mut rng, namespace, message, &entries, &strategy)
                .is_empty()
        );

        // Corrupt two signatures. The sequential (single-tree) and parallel (chunked) paths
        // must both report exactly those indices, in ascending order.
        for index in [3, 7] {
            let random_scalar = Scalar::random(&mut rng);
            entries[index].1 += &(V::Signature::generator() * &random_scalar);
        }
        assert_eq!(
            verify_same_message::<_, V>(&mut rng, namespace, message, &entries, &Sequential),
            vec![3, 7]
        );
        assert_eq!(
            verify_same_message::<_, V>(&mut rng, namespace, message, &entries, &strategy),
            vec![3, 7]
        );

        // A lone corrupted entry exercises the one-leaf path.
        assert_eq!(
            verify_same_message::<_, V>(&mut rng, namespace, message, &entries[3..4], &Sequential),
            vec![0]
        );
    }

    #[test]
    fn test_verify_same_message_identifies_invalid() {
        verify_same_message_identifies_invalid::<MinPk>();
        verify_same_message_identifies_invalid::<MinSig>();
    }
}