use crate::prelude::*;
use crate::traits::{KeyImageGen, Link, Sign, Verify};
use curve25519_dalek::constants;
use curve25519_dalek::ristretto::RistrettoPoint;
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::MultiscalarMul;
use digest::generic_array::typenum::U64;
use digest::Digest;
use rand_core::{CryptoRng, RngCore};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// An MLSAG ring signature over the Ristretto group.
///
/// Conceptually the ring is a matrix of public keys: each row is one ring
/// member and each column is one key "layer". `Sign::sign` fills this struct
/// and `Verify::verify` / `Link::link` consume it.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct MLSAG {
/// Challenge scalar for ring row 0; verification starts from this value and
/// must arrive back at it after walking every row.
pub challenge: Scalar,
/// Response scalars, indexed as `responses[row][column]`.
pub responses: Vec<Vec<Scalar>>,
/// Public keys, indexed as `ring[row][column]`. As produced by `sign`, this
/// includes the signer's own public keys at the signer's row.
pub ring: Vec<Vec<RistrettoPoint>>,
/// One key image per column, derived from the signer's secret keys; these
/// are what `link` compares between signatures.
pub key_images: Vec<RistrettoPoint>,
}
impl KeyImageGen<Vec<Scalar>, Vec<RistrettoPoint>> for MLSAG {
    /// Derives one key image per secret key.
    ///
    /// For each secret scalar `k` the image is `k * Hp(k * G)`, where `G` is the
    /// Ristretto basepoint and `Hp` hashes the compressed public key to a group
    /// element using `Hash`. The output has the same length and order as `ks`.
    fn generate_key_image<Hash: Digest<OutputSize = U64> + Clone + Default>(
        ks: Vec<Scalar>,
    ) -> Vec<RistrettoPoint> {
        ks.iter()
            .map(|secret| {
                // Public key matching this secret.
                let public = secret * constants::RISTRETTO_BASEPOINT_POINT;
                // Hash the compressed public key onto the curve.
                let image_base = RistrettoPoint::from_hash(
                    Hash::default().chain_update(public.compress().as_bytes()),
                );
                *secret * image_base
            })
            .collect()
    }
}
impl Sign<Vec<Scalar>, Vec<Vec<RistrettoPoint>>> for MLSAG {
fn sign<
Hash: Digest<OutputSize = U64> + Clone + Default,
CSPRNG: CryptoRng + RngCore + Default,
>(
ks: Vec<Scalar>,
mut ring: Vec<Vec<RistrettoPoint>>,
secret_index: usize,
message: &[u8],
) -> MLSAG {
let mut csprng = CSPRNG::default();
let nr = ring.len() + 1;
let nc = ring[0].len();
let k_points: Vec<RistrettoPoint> = ks
.iter()
.map(|k| k * constants::RISTRETTO_BASEPOINT_POINT)
.collect();
let key_images: Vec<RistrettoPoint> = MLSAG::generate_key_image::<Hash>(ks.clone());
ring.insert(secret_index, k_points.clone());
let a: Vec<Scalar> = (0..nc).map(|_| Scalar::random(&mut csprng)).collect();
let mut rs: Vec<Vec<Scalar>> = (0..nr)
.map(|_| (0..nc).map(|_| Scalar::random(&mut csprng)).collect())
.collect();
let mut cs: Vec<Scalar> = (0..nr).map(|_| Scalar::ZERO).collect();
let mut message_hash = Hash::default();
message_hash.update(message);
let mut hashes: Vec<Hash> = (0..nr).map(|_| message_hash.clone()).collect();
for j in 0..nc {
hashes[(secret_index + 1) % nr].update(
(a[j] * constants::RISTRETTO_BASEPOINT_POINT)
.compress()
.as_bytes(),
);
hashes[(secret_index + 1) % nr].update(
(a[j]
* RistrettoPoint::from_hash(
Hash::default().chain_update(k_points[j].compress().as_bytes()),
))
.compress()
.as_bytes(),
);
}
cs[(secret_index + 1) % nr] = Scalar::from_hash(hashes[(secret_index + 1) % nr].clone());
let mut i = (secret_index + 1) % nr;
loop {
for (j, key_image) in key_images.iter().enumerate().take(nc) {
hashes[(i + 1) % nr].update(
RistrettoPoint::multiscalar_mul(
&[rs[i % nr][j], cs[i % nr]],
&[constants::RISTRETTO_BASEPOINT_POINT, ring[i % nr][j]],
)
.compress()
.as_bytes(),
);
hashes[(i + 1) % nr].update(
RistrettoPoint::multiscalar_mul(
&[rs[i % nr][j], cs[i % nr]],
&[
RistrettoPoint::from_hash(
Hash::default().chain_update(ring[i % nr][j].compress().as_bytes()),
),
*key_image,
],
)
.compress()
.as_bytes(),
);
}
cs[(i + 1) % nr] = Scalar::from_hash(hashes[(i + 1) % nr].clone());
if (secret_index >= 1 && i % nr == (secret_index - 1) % nr)
|| (secret_index == 0 && i % nr == nr - 1)
{
break;
} else {
i = (i + 1) % nr;
}
}
for j in 0..nc {
rs[secret_index][j] = a[j] - (cs[secret_index] * ks[j]);
}
MLSAG {
challenge: cs[0],
responses: rs,
ring,
key_images,
}
}
}
impl Verify for MLSAG {
    /// Checks that `signature` is a valid MLSAG signature on `message`.
    ///
    /// Starting from `signature.challenge` (the challenge for row 0), each ring
    /// row is used to recompute the next row's challenge; the signature is
    /// valid when the chain arrives back at the starting challenge.
    ///
    /// Returns `false`, never panics, for structurally malformed signatures:
    /// an empty ring, no key images, or any ring/response row whose length
    /// differs from the number of key images. The shape check matters for
    /// soundness: without it a signature with no key images skips every
    /// column, so the challenge chain depends only on `message` and anyone
    /// could forge it. Key images beyond the ring's column count would also
    /// be accepted unproven and then take part in `Link::link`.
    fn verify<Hash: Digest<OutputSize = U64> + Clone + Default>(
        signature: MLSAG,
        message: &[u8],
    ) -> bool {
        let nr = signature.ring.len();
        let nc = signature.key_images.len();
        let well_formed = nr > 0
            && nc > 0
            && signature.responses.len() == nr
            && signature.ring.iter().all(|row| row.len() == nc)
            && signature.responses.iter().all(|row| row.len() == nc);
        if !well_formed {
            return false;
        }

        // Every challenge hash starts from the same message prefix.
        let mut message_hash = Hash::default();
        message_hash.update(message);

        let mut reconstructed_c: Scalar = signature.challenge;
        for (members, responses) in signature.ring.iter().zip(&signature.responses) {
            let mut h = message_hash.clone();
            for ((member, response), key_image) in members
                .iter()
                .zip(responses)
                .zip(&signature.key_images)
            {
                let scalars = [*response, reconstructed_c];
                // r * G + c * K
                h.update(
                    RistrettoPoint::multiscalar_mul(
                        &scalars,
                        &[constants::RISTRETTO_BASEPOINT_POINT, *member],
                    )
                    .compress()
                    .as_bytes(),
                );
                // r * Hp(K) + c * I
                let member_hash_point = RistrettoPoint::from_hash(
                    Hash::default().chain_update(member.compress().as_bytes()),
                );
                h.update(
                    RistrettoPoint::multiscalar_mul(&scalars, &[member_hash_point, *key_image])
                        .compress()
                        .as_bytes(),
                );
            }
            reconstructed_c = Scalar::from_hash(h);
        }
        signature.challenge == reconstructed_c
    }
}
impl Link for MLSAG {
    /// Reports whether two signatures share at least one key image, i.e.
    /// whether some secret key was used to produce both of them.
    ///
    /// Only images from *different* signatures are compared: a signature that
    /// happens to repeat a key image internally is not thereby linked to an
    /// unrelated signature.
    fn link(signature_1: MLSAG, signature_2: MLSAG) -> bool {
        // Sort the first signature's compressed images once so that each image
        // of the second signature can be looked up by binary search.
        let mut images_1: Vec<[u8; 32]> = signature_1
            .key_images
            .iter()
            .map(|image| image.compress().to_bytes())
            .collect();
        images_1.sort_unstable();
        signature_2
            .key_images
            .iter()
            .any(|image| images_1.binary_search(&image.compress().to_bytes()).is_ok())
    }
}
#[cfg(test)]
#[cfg(feature = "std")]
mod test {
    extern crate blake2;
    extern crate rand;
    extern crate sha2;
    extern crate sha3;
    use super::*;
    use blake2::Blake2b512;
    use curve25519_dalek::ristretto::RistrettoPoint;
    use curve25519_dalek::scalar::Scalar;
    use rand::rngs::OsRng;
    use sha2::Sha512;
    use sha3::Keccak512;

    /// Ring size (rows, including the signer) used by every test.
    const NR: usize = 2;
    /// Number of keys per ring member (columns) used by every test.
    const NC: usize = 2;
    /// Row the signer occupies in the final ring.
    const SECRET_INDEX: usize = 1;

    /// Generates `columns` fresh random secret keys.
    fn random_secrets(columns: usize) -> Vec<Scalar> {
        let mut csprng = OsRng::default();
        (0..columns).map(|_| Scalar::random(&mut csprng)).collect()
    }

    /// Generates `rows` decoy ring members with `columns` random public keys each.
    fn random_decoys(rows: usize, columns: usize) -> Vec<Vec<RistrettoPoint>> {
        let mut csprng = OsRng::default();
        (0..rows)
            .map(|_| {
                (0..columns)
                    .map(|_| RistrettoPoint::random(&mut csprng))
                    .collect()
            })
            .collect()
    }

    /// Signs `message` with `Hash`, then checks that the signature verifies for
    /// that message and is rejected for a different one.
    fn assert_round_trip<Hash: Digest<OutputSize = U64> + Clone + Default>(
        ks: &[Scalar],
        decoys: &[Vec<RistrettoPoint>],
        message: &[u8],
    ) {
        let signature =
            MLSAG::sign::<Hash, OsRng>(ks.to_vec(), decoys.to_vec(), SECRET_INDEX, message);
        assert!(MLSAG::verify::<Hash>(signature.clone(), message));
        assert!(!MLSAG::verify::<Hash>(
            signature,
            b"This is not the signed message"
        ));
    }

    #[test]
    fn mlsag() {
        let ks = random_secrets(NC);
        let ring = random_decoys(NR - 1, NC);
        let message: &[u8] = b"This is the message";

        assert_round_trip::<Sha512>(&ks, &ring, message);
        assert_round_trip::<Keccak512>(&ks, &ring, message);
        assert_round_trip::<Blake2b512>(&ks, &ring, message);

        // Two signatures made with the same secret keys must link, even when
        // the decoys and the message differ.
        let another_ring = random_decoys(NR - 1, NC);
        let another_message: &[u8] = b"This is another message";
        let signature_1 = MLSAG::sign::<Blake2b512, OsRng>(
            ks.clone(),
            another_ring,
            SECRET_INDEX,
            another_message,
        );
        let signature_2 =
            MLSAG::sign::<Blake2b512, OsRng>(ks.clone(), ring.clone(), SECRET_INDEX, message);
        assert!(MLSAG::link(signature_1, signature_2));
    }

    #[test]
    fn mlsag_does_not_link_unrelated_signers() {
        let ring = random_decoys(NR - 1, NC);
        let message: &[u8] = b"This is the message";

        // Independent secret keys give independent key images.
        let signature_1 = MLSAG::sign::<Blake2b512, OsRng>(
            random_secrets(NC),
            ring.clone(),
            SECRET_INDEX,
            message,
        );
        let signature_2 =
            MLSAG::sign::<Blake2b512, OsRng>(random_secrets(NC), ring, SECRET_INDEX, message);
        assert!(!MLSAG::link(signature_1, signature_2));
    }
}