Skip to main content

clsag/
signature.rs

1use crate::clsag::calc_aggregation_coefficients;
2use crate::constants::BASEPOINT;
3use crate::member::compute_challenge_ring;
4use crate::transcript::TranscriptProtocol;
5use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
6use curve25519_dalek::scalar::Scalar;
7use curve25519_dalek::traits::VartimeMultiscalarMul;
8use merlin::Transcript;
9use sha2::Sha512;
10
/// A CLSAG ring signature as consumed by `Signature::verify` and
/// `Signature::optimised_verify`.
#[derive(Debug)]
pub struct Signature {
    /// Starting challenge. Verification recomputes the challenge once per
    /// ring member and accepts only if the final value equals this one.
    pub challenge: Scalar,
    /// One response scalar per ring member; verification pairs them with the
    /// rows of the supplied public key matrix by index.
    pub responses: Vec<Scalar>,
    /// Compressed key images. They feed the aggregation coefficients and are
    /// decompressed by `optimised_verify`.
    pub key_images: Vec<CompressedRistretto>,
}
17
/// Errors that can occur while verifying a [`Signature`].
///
/// `Debug` is derived so that `Result<_, Error>` can be used with `unwrap`,
/// `expect` and the assertion macros.
#[derive(Debug)]
pub enum Error {
    /// The supplied public keys do not have the shape the signature requires,
    /// e.g. the number of public key sets differs from the number of
    /// responses in the signature.
    IncorrectNumOfPubKeys,
    /// One of the supplied key images could not be decompressed.
    BadKeyImages,
    /// The challenge recomputed during verification differs from the
    /// challenge stored in the signature.
    ChallengeMismatch,
    /// A point could not be decompressed.
    BadPoint,
    /// An error from the `member` module, carried as its formatted message.
    MemberError(String),
}
31
32impl From<crate::member::Error> for Error {
33    fn from(e: crate::member::Error) -> Error {
34        let err_string = format!(" underlying member error {:?}", e);
35        Error::MemberError(err_string)
36    }
37}
38
39impl Signature {
40    pub fn verify(
41        &self,
42        public_keys: &mut Vec<Vec<CompressedRistretto>>,
43        msg: &[u8],
44    ) -> Result<(), Error> {
45        // Skip subgroup check as ristretto points have co-factor 1.
46
47        let num_responses = self.responses.len();
48        let num_pubkey_sets = public_keys.len();
49
50        // -- Check that we have the correct amount of public keys
51        if num_pubkey_sets != num_responses {
52            return Err(Error::IncorrectNumOfPubKeys);
53        }
54
55        let pubkey_matrix_bytes: Vec<u8> = self.pubkeys_to_bytes(public_keys);
56
57        // Calculate aggregation co-efficients
58        let agg_coeffs = calc_aggregation_coefficients(&pubkey_matrix_bytes, &self.key_images, msg);
59
60        let mut challenge = self.challenge.clone();
61        for (pub_keys, response) in public_keys.iter().zip(self.responses.iter()) {
62            let first_pubkey = pub_keys[0];
63            let hashed_pubkey = RistrettoPoint::hash_from_bytes::<Sha512>(first_pubkey.as_bytes());
64            challenge = compute_challenge_ring(
65                pub_keys,
66                &challenge,
67                &self.key_images,
68                response,
69                &agg_coeffs,
70                &hashed_pubkey,
71                &pubkey_matrix_bytes,
72            );
73        }
74
75        if self.challenge != challenge {
76            return Err(Error::ChallengeMismatch);
77        }
78
79        Ok(())
80    }
81
82    pub fn optimised_verify(
83        &self,
84        public_keys: &mut Vec<Vec<CompressedRistretto>>,
85        msg: &[u8],
86    ) -> Result<(), Error> {
87        // Skip subgroup check as ristretto points have co-factor 1.
88
89        let num_responses = self.responses.len();
90        let num_pubkey_sets = public_keys.len();
91
92        // -- Check that we have the correct amount of public keys
93        if num_pubkey_sets != num_responses {
94            return Err(Error::IncorrectNumOfPubKeys);
95        }
96
97        // Calculate all response * BASEPOINT
98        let response_points: Vec<RistrettoPoint> = self
99            .responses
100            .iter()
101            .map(|response| response * BASEPOINT)
102            .collect();
103
104        // calculate all response * H(signingKeys)
105        let response_hashed_points: Vec<RistrettoPoint> = self
106            .responses
107            .iter()
108            .zip(public_keys.iter())
109            .map(|(response, pub_keys)| {
110                let first_pubkey = pub_keys[0];
111                let hashed_pubkey =
112                    RistrettoPoint::hash_from_bytes::<Sha512>(first_pubkey.as_bytes());
113
114                response * hashed_pubkey
115            })
116            .collect();
117
118        // compute the public key bytes
119        let pubkey_matrix_bytes = self.pubkeys_to_bytes(public_keys);
120
121        // Calculate aggregation co-efficients
122        let agg_coeffs = calc_aggregation_coefficients(&pubkey_matrix_bytes, &self.key_images, msg);
123
124        let mut challenge = self.challenge.clone();
125
126        for ((resp_point, resp_hashed_point), pub_keys) in response_points
127            .iter()
128            .zip(response_hashed_points.iter())
129            .zip(public_keys.iter())
130        {
131            let challenge_agg_coeffs: Vec<Scalar> =
132                agg_coeffs.iter().map(|ac| ac * &challenge).collect();
133
134            let mut l_i = RistrettoPoint::optional_multiscalar_mul(
135                &challenge_agg_coeffs,
136                pub_keys.iter().map(|pt| pt.decompress()),
137            )
138            .ok_or(Error::BadPoint)?;
139            l_i = l_i + resp_point;
140
141            let mut r_i = RistrettoPoint::optional_multiscalar_mul(
142                &challenge_agg_coeffs,
143                self.key_images.iter().map(|pt| pt.decompress()),
144            )
145            .ok_or(Error::BadPoint)?;
146            r_i = r_i + resp_hashed_point;
147
148            let mut transcript = Transcript::new(b"clsag");
149            transcript.append_message(b"", &pubkey_matrix_bytes);
150            transcript.append_point(b"", &l_i);
151            transcript.append_point(b"", &r_i);
152
153            challenge = transcript.challenge_scalar(b"");
154        }
155
156        if challenge != self.challenge {
157            return Err(Error::ChallengeMismatch);
158        }
159
160        Ok(())
161    }
162
163    fn pubkeys_to_bytes(&self, pubkey_matrix: &Vec<Vec<CompressedRistretto>>) -> Vec<u8> {
164        let mut bytes: Vec<u8> =
165            Vec::with_capacity(self.key_images.len() * self.responses.len() * 64);
166        for i in 0..pubkey_matrix.len() {
167            let pubkey_bytes: Vec<u8> = pubkey_matrix[i]
168                .iter()
169                .map(|pubkey| pubkey.to_bytes().to_vec())
170                .flatten()
171                .collect();
172            bytes.extend(pubkey_bytes);
173        }
174        bytes
175    }
176}
177
// Tests and benchmarks for signature verification. The benchmarks use the
// `test` crate's `Bencher`.
// NOTE(review): every test here calls `optimised_verify`; `Signature::verify`
// is not exercised in this module - consider adding coverage.
#[cfg(test)]
mod test {
    extern crate test;
    use test::Bencher;

    use crate::tests_helper::*;
    use rand::seq::SliceRandom;
    use rand::thread_rng;

    #[test]
    fn test_verify() {
        let num_keys = 1;
        let num_decoys = 1;
        let msg = b"hello world";

        let mut clsag = generate_clsag_with(num_decoys, num_keys);
        clsag.add_member(generate_signer(num_keys));
        let sig = clsag.sign(msg).unwrap();
        let mut pub_keys = clsag.public_keys();

        // The verifier must serialise the key matrix exactly like the signer.
        let expected_pubkey_bytes = clsag.public_keys_bytes();
        let have_pubkey_bytes = sig.pubkeys_to_bytes(&pub_keys);

        assert_eq!(expected_pubkey_bytes, have_pubkey_bytes);
        assert!(sig.optimised_verify(&mut pub_keys, msg).is_ok());
    }

    #[test]
    fn test_verify_fail_shuffle_keys() {
        let num_keys = 2;
        let num_decoys = 11;
        let msg = b"hello world";

        let mut clsag = generate_clsag_with(num_decoys, num_keys);
        clsag.add_member(generate_signer(num_keys));
        let sig = clsag.sign(msg).unwrap();
        let mut pub_keys = clsag.public_keys();

        // shuffle public key ordering
        pub_keys.shuffle(&mut thread_rng());
        assert!(sig.optimised_verify(&mut pub_keys, msg).is_err());
    }

    #[test]
    fn test_verify_fail_incorrect_num_keys() {
        let num_keys = 2;
        let num_decoys = 11;
        let msg = b"hello world";

        let mut clsag = generate_clsag_with(num_decoys, num_keys);
        clsag.add_member(generate_signer(num_keys));
        let sig = clsag.sign(msg).unwrap();
        let mut pub_keys = clsag.public_keys();

        // Add extra key
        let extra_key = generate_rand_compressed_points(num_keys);
        pub_keys.push(extra_key);
        assert!(sig.optimised_verify(&mut pub_keys, msg).is_err());

        // remove the extra key and test should pass
        pub_keys.pop();
        assert!(sig.optimised_verify(&mut pub_keys, msg).is_ok());

        // remove another key and tests should fail
        pub_keys.pop();
        assert!(sig.optimised_verify(&mut pub_keys, msg).is_err());
    }

    // Generates a benchmark that verifies a signature over a ring of
    // `$num_decoys` decoys plus one signer, each holding `$num_keys` keys.
    macro_rules! param_bench_verify {
        ($func_name: ident,$num_keys:expr, $num_decoys :expr) => {
            #[bench]
            fn $func_name(b: &mut Bencher) {
                let num_keys = $num_keys;
                let num_decoys = $num_decoys;
                let msg = b"hello world";

                let mut clsag = generate_clsag_with(num_decoys, num_keys);
                clsag.add_member(generate_signer(num_keys));
                let sig = clsag.sign(msg).unwrap();
                let mut pub_keys = clsag.public_keys();

                b.iter(|| sig.optimised_verify(&mut pub_keys, msg));
            }
        };
    }

    // The number in each name is the ring size: decoys + the one signer.
    param_bench_verify!(bench_verify_2, 2, 1);
    param_bench_verify!(bench_verify_4, 2, 3);
    param_bench_verify!(bench_verify_6, 2, 5);
    param_bench_verify!(bench_verify_8, 2, 7);
    param_bench_verify!(bench_verify_11, 2, 10);
    param_bench_verify!(bench_verify_16, 2, 15);
    param_bench_verify!(bench_verify_32, 2, 31);
    param_bench_verify!(bench_verify_64, 2, 63);
    param_bench_verify!(bench_verify_128, 2, 127);
    param_bench_verify!(bench_verify_256, 2, 255);
    param_bench_verify!(bench_verify_512, 2, 511);
}