1use curve25519_dalek::{
32 constants::RISTRETTO_BASEPOINT_TABLE,
33 ristretto::{CompressedRistretto, RistrettoPoint},
34 scalar::Scalar,
35};
36use rand::Rng;
37use serde::{Deserialize, Serialize};
38use thiserror::Error;
39use zeroize::Zeroize;
40
41#[derive(Error, Debug)]
43pub enum SchnorrError {
44 #[error("Invalid signature")]
45 InvalidSignature,
46 #[error("Invalid public key")]
47 InvalidPublicKey,
48 #[error("Invalid secret key")]
49 InvalidSecretKey,
50 #[error("Batch verification failed")]
51 BatchVerificationFailed,
52 #[error("Empty batch")]
53 EmptyBatch,
54 #[error("Serialization error: {0}")]
55 SerializationError(String),
56}
57
58pub type SchnorrResult<T> = Result<T, SchnorrError>;
59
60#[derive(Clone, Zeroize)]
62#[zeroize(drop)]
63pub struct SchnorrSecretKey {
64 scalar: Scalar,
65}
66
67impl SchnorrSecretKey {
68 pub fn generate() -> Self {
70 let mut rng = rand::thread_rng();
71 let mut bytes = [0u8; 32];
72 rng.fill(&mut bytes);
73 let scalar = Scalar::from_bytes_mod_order(bytes);
74 Self { scalar }
75 }
76
77 pub fn from_bytes(bytes: &[u8; 32]) -> SchnorrResult<Self> {
79 let scalar = Scalar::from_bytes_mod_order(*bytes);
80 Ok(Self { scalar })
81 }
82
83 pub fn to_bytes(&self) -> [u8; 32] {
85 self.scalar.to_bytes()
86 }
87
88 pub fn public_key(&self) -> SchnorrPublicKey {
90 let point = RISTRETTO_BASEPOINT_TABLE * &self.scalar;
91 SchnorrPublicKey { point }
92 }
93}
94
95#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
97pub struct SchnorrPublicKey {
98 point: RistrettoPoint,
99}
100
101impl SchnorrPublicKey {
102 pub fn from_bytes(bytes: &[u8; 32]) -> SchnorrResult<Self> {
104 let compressed =
105 CompressedRistretto::from_slice(bytes).map_err(|_| SchnorrError::InvalidPublicKey)?;
106 let point = compressed
107 .decompress()
108 .ok_or(SchnorrError::InvalidPublicKey)?;
109 Ok(Self { point })
110 }
111
112 pub fn to_bytes(&self) -> [u8; 32] {
114 self.point.compress().to_bytes()
115 }
116}
117
118#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
126pub struct SchnorrSignature {
127 challenge: Scalar,
128 response: Scalar,
129}
130
131impl SchnorrSignature {
132 pub fn from_bytes(bytes: &[u8; 64]) -> SchnorrResult<Self> {
134 let mut challenge_bytes = [0u8; 32];
135 let mut response_bytes = [0u8; 32];
136 challenge_bytes.copy_from_slice(&bytes[..32]);
137 response_bytes.copy_from_slice(&bytes[32..]);
138
139 let challenge: Option<Scalar> = Scalar::from_canonical_bytes(challenge_bytes).into();
140 let response: Option<Scalar> = Scalar::from_canonical_bytes(response_bytes).into();
141
142 let challenge = challenge.ok_or(SchnorrError::InvalidSignature)?;
143 let response = response.ok_or(SchnorrError::InvalidSignature)?;
144
145 Ok(Self {
146 challenge,
147 response,
148 })
149 }
150
151 pub fn to_bytes(&self) -> [u8; 64] {
153 let mut bytes = [0u8; 64];
154 bytes[..32].copy_from_slice(&self.challenge.to_bytes());
155 bytes[32..].copy_from_slice(&self.response.to_bytes());
156 bytes
157 }
158}
159
160pub struct SchnorrKeypair {
162 secret_key: SchnorrSecretKey,
163 public_key: SchnorrPublicKey,
164}
165
166impl SchnorrKeypair {
167 pub fn generate() -> Self {
169 let secret_key = SchnorrSecretKey::generate();
170 let public_key = secret_key.public_key();
171 Self {
172 secret_key,
173 public_key,
174 }
175 }
176
177 pub fn from_secret_key(secret_key: SchnorrSecretKey) -> Self {
179 let public_key = secret_key.public_key();
180 Self {
181 secret_key,
182 public_key,
183 }
184 }
185
186 pub fn public_key(&self) -> SchnorrPublicKey {
188 self.public_key
189 }
190
191 pub fn secret_key(&self) -> &SchnorrSecretKey {
193 &self.secret_key
194 }
195
196 pub fn sign(&self, message: &[u8]) -> SchnorrSignature {
204 let mut rng = rand::thread_rng();
205 let mut nonce_bytes = [0u8; 32];
206 rng.fill(&mut nonce_bytes);
207 let nonce = Scalar::from_bytes_mod_order(nonce_bytes);
208
209 let commitment = RISTRETTO_BASEPOINT_TABLE * &nonce;
211
212 let challenge = compute_challenge(&commitment, &self.public_key.point, message);
214
215 let response = nonce - (challenge * self.secret_key.scalar);
217
218 SchnorrSignature {
219 challenge,
220 response,
221 }
222 }
223
224 pub fn verify(&self, message: &[u8], signature: &SchnorrSignature) -> SchnorrResult<()> {
229 verify(&self.public_key, message, signature)
230 }
231}
232
233fn compute_challenge(
235 commitment: &RistrettoPoint,
236 public_key: &RistrettoPoint,
237 message: &[u8],
238) -> Scalar {
239 let mut data = Vec::new();
240 data.extend_from_slice(&commitment.compress().to_bytes());
241 data.extend_from_slice(&public_key.compress().to_bytes());
242 data.extend_from_slice(message);
243
244 let hash = crate::hash::hash(&data);
245 Scalar::from_bytes_mod_order(hash)
246}
247
248pub fn verify(
250 public_key: &SchnorrPublicKey,
251 message: &[u8],
252 signature: &SchnorrSignature,
253) -> SchnorrResult<()> {
254 let commitment_reconstructed =
256 RISTRETTO_BASEPOINT_TABLE * &signature.response + public_key.point * signature.challenge;
257
258 let challenge_reconstructed =
260 compute_challenge(&commitment_reconstructed, &public_key.point, message);
261
262 if challenge_reconstructed == signature.challenge {
264 Ok(())
265 } else {
266 Err(SchnorrError::InvalidSignature)
267 }
268}
269
270pub fn batch_verify(items: &[(SchnorrPublicKey, &[u8], SchnorrSignature)]) -> SchnorrResult<()> {
281 if items.is_empty() {
282 return Err(SchnorrError::EmptyBatch);
283 }
284
285 if items.len() == 1 {
287 return verify(&items[0].0, items[0].1, &items[0].2);
288 }
289
290 let mut rng = rand::thread_rng();
291
292 let mut reconstructed_commitments = Vec::with_capacity(items.len());
294
295 for (public_key, message, signature) in items {
296 let commitment = RISTRETTO_BASEPOINT_TABLE * &signature.response
298 + public_key.point * signature.challenge;
299
300 let expected_challenge = compute_challenge(&commitment, &public_key.point, message);
302
303 if expected_challenge != signature.challenge {
304 return Err(SchnorrError::InvalidSignature);
305 }
306
307 reconstructed_commitments.push(commitment);
308 }
309
310 let weights: Vec<Scalar> = (0..items.len())
313 .map(|_| {
314 let mut bytes = [0u8; 32];
315 rng.fill(&mut bytes);
316 Scalar::from_bytes_mod_order(bytes)
317 })
318 .collect();
319
320 let mut lhs = RistrettoPoint::default();
322 for (weight, commitment) in weights.iter().zip(reconstructed_commitments.iter()) {
323 lhs += weight * commitment;
324 }
325
326 let mut response_sum = Scalar::ZERO;
328 let mut weighted_pubkey_sum = RistrettoPoint::default();
329
330 for (i, (public_key, _, signature)) in items.iter().enumerate() {
331 response_sum += weights[i] * signature.response;
332 weighted_pubkey_sum += (weights[i] * signature.challenge) * public_key.point;
333 }
334
335 let rhs = RISTRETTO_BASEPOINT_TABLE * &response_sum + weighted_pubkey_sum;
336
337 if lhs == rhs {
339 Ok(())
340 } else {
341 Err(SchnorrError::BatchVerificationFailed)
342 }
343}
344
345#[allow(dead_code)]
350pub fn aggregate_signatures(_signatures: &[SchnorrSignature]) -> SchnorrResult<SchnorrSignature> {
351 unimplemented!("Schnorr aggregation requires MuSig protocol")
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keypair_generation() {
        let keypair = SchnorrKeypair::generate();
        let pk = keypair.public_key();

        // A public key must survive an encode/decode round trip unchanged.
        let pk2 = SchnorrPublicKey::from_bytes(&pk.to_bytes()).unwrap();
        assert_eq!(pk, pk2);
    }

    #[test]
    fn test_public_key_rejects_invalid_encoding() {
        // All-0xff is not a canonical Ristretto encoding, so decompression fails.
        assert!(matches!(
            SchnorrPublicKey::from_bytes(&[0xff; 32]),
            Err(SchnorrError::InvalidPublicKey)
        ));
    }

    #[test]
    fn test_sign_and_verify() {
        let keypair = SchnorrKeypair::generate();
        let message = b"Test message for Schnorr signature";

        let signature = keypair.sign(message);
        assert!(keypair.verify(message, &signature).is_ok());
    }

    #[test]
    fn test_verify_wrong_message() {
        let keypair = SchnorrKeypair::generate();
        let message = b"Original message";
        let wrong_message = b"Wrong message";

        let signature = keypair.sign(message);
        assert!(matches!(
            keypair.verify(wrong_message, &signature),
            Err(SchnorrError::InvalidSignature)
        ));
    }

    #[test]
    fn test_verify_wrong_public_key() {
        let keypair1 = SchnorrKeypair::generate();
        let keypair2 = SchnorrKeypair::generate();
        let message = b"Test message";

        let signature = keypair1.sign(message);
        assert!(matches!(
            verify(&keypair2.public_key(), message, &signature),
            Err(SchnorrError::InvalidSignature)
        ));
    }

    #[test]
    fn test_verify_tampered_response() {
        let keypair = SchnorrKeypair::generate();
        let message = b"Test message";

        let mut signature = keypair.sign(message);
        signature.response += Scalar::ONE;
        assert!(matches!(
            keypair.verify(message, &signature),
            Err(SchnorrError::InvalidSignature)
        ));
    }

    #[test]
    fn test_signature_serialization() {
        let keypair = SchnorrKeypair::generate();
        let message = b"Test message";

        let signature = keypair.sign(message);
        let signature2 = SchnorrSignature::from_bytes(&signature.to_bytes()).unwrap();

        assert_eq!(signature, signature2);
        assert!(keypair.verify(message, &signature2).is_ok());
    }

    #[test]
    fn test_signature_rejects_non_canonical_bytes() {
        // Scalars with the high bit set are never canonical.
        assert!(matches!(
            SchnorrSignature::from_bytes(&[0xff; 64]),
            Err(SchnorrError::InvalidSignature)
        ));
    }

    #[test]
    fn test_deterministic_public_key() {
        let sk_bytes = [42u8; 32];
        let sk1 = SchnorrSecretKey::from_bytes(&sk_bytes).unwrap();
        let sk2 = SchnorrSecretKey::from_bytes(&sk_bytes).unwrap();

        assert_eq!(sk1.public_key().to_bytes(), sk2.public_key().to_bytes());
    }

    #[test]
    fn test_batch_verify() {
        let keypair1 = SchnorrKeypair::generate();
        let keypair2 = SchnorrKeypair::generate();
        let keypair3 = SchnorrKeypair::generate();

        let message = b"Batch verification test";

        let items = [
            (keypair1.public_key(), message.as_slice(), keypair1.sign(message)),
            (keypair2.public_key(), message.as_slice(), keypair2.sign(message)),
            (keypair3.public_key(), message.as_slice(), keypair3.sign(message)),
        ];

        assert!(batch_verify(&items).is_ok());
    }

    #[test]
    fn test_batch_verify_single_item() {
        let keypair = SchnorrKeypair::generate();
        let message = b"Single item batch";

        let items = [(keypair.public_key(), message.as_slice(), keypair.sign(message))];
        assert!(batch_verify(&items).is_ok());
    }

    #[test]
    fn test_batch_verify_one_invalid() {
        let keypair1 = SchnorrKeypair::generate();
        let keypair2 = SchnorrKeypair::generate();
        let keypair3 = SchnorrKeypair::generate();

        let message = b"Batch verification test";
        let wrong_message = b"Wrong message";

        let sig1 = keypair1.sign(message);
        let sig2 = keypair2.sign(wrong_message);
        let sig3 = keypair3.sign(message);

        let items = [
            (keypair1.public_key(), message.as_slice(), sig1),
            (keypair2.public_key(), message.as_slice(), sig2),
            (keypair3.public_key(), message.as_slice(), sig3),
        ];

        assert!(matches!(
            batch_verify(&items),
            Err(SchnorrError::InvalidSignature)
        ));
    }

    #[test]
    fn test_batch_verify_empty() {
        let items: Vec<(SchnorrPublicKey, &[u8], SchnorrSignature)> = Vec::new();
        assert!(matches!(
            batch_verify(&items),
            Err(SchnorrError::EmptyBatch)
        ));
    }

    #[test]
    fn test_secret_key_serialization() {
        let sk = SchnorrSecretKey::generate();
        let sk2 = SchnorrSecretKey::from_bytes(&sk.to_bytes()).unwrap();

        assert_eq!(sk.to_bytes(), sk2.to_bytes());
        assert_eq!(sk.public_key().to_bytes(), sk2.public_key().to_bytes());
    }

    #[test]
    fn test_signature_randomness() {
        let keypair = SchnorrKeypair::generate();
        let message = b"Test message";

        // Fresh randomness goes into every nonce, so equal inputs give distinct signatures.
        let sig1 = keypair.sign(message);
        let sig2 = keypair.sign(message);

        assert_ne!(sig1, sig2);
        assert!(keypair.verify(message, &sig1).is_ok());
        assert!(keypair.verify(message, &sig2).is_ok());
    }
}