chie_crypto/
schnorr.rs

1//! Schnorr signatures for simplicity and provable security.
2//!
3//! Schnorr signatures provide:
4//! - Simpler construction than EdDSA with cleaner security proofs
5//! - Provable security under the discrete logarithm assumption
6//! - Native support for threshold signatures
7//! - Batch verification support
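//! - A linear structure that admits signature aggregation via interactive protocols such as MuSig (not yet implemented in this module)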
9//!
10//! # Example
11//! ```
12//! use chie_crypto::schnorr::{SchnorrKeypair, batch_verify};
13//!
14//! // Generate a keypair
15//! let keypair = SchnorrKeypair::generate();
16//! let message = b"Hello, Schnorr!";
17//!
18//! // Sign a message
19//! let signature = keypair.sign(message);
20//!
21//! // Verify the signature
22//! assert!(keypair.verify(message, &signature).is_ok());
23//!
24//! // Batch verification
25//! let items = vec![
26//!     (keypair.public_key(), message.as_slice(), signature),
27//! ];
28//! assert!(batch_verify(&items).is_ok());
29//! ```
30
31use curve25519_dalek::{
32    constants::RISTRETTO_BASEPOINT_TABLE,
33    ristretto::{CompressedRistretto, RistrettoPoint},
34    scalar::Scalar,
35};
36use rand::Rng;
37use serde::{Deserialize, Serialize};
38use thiserror::Error;
39use zeroize::Zeroize;
40
41/// Schnorr signature error types
42#[derive(Error, Debug)]
43pub enum SchnorrError {
44    #[error("Invalid signature")]
45    InvalidSignature,
46    #[error("Invalid public key")]
47    InvalidPublicKey,
48    #[error("Invalid secret key")]
49    InvalidSecretKey,
50    #[error("Batch verification failed")]
51    BatchVerificationFailed,
52    #[error("Empty batch")]
53    EmptyBatch,
54    #[error("Serialization error: {0}")]
55    SerializationError(String),
56}
57
58pub type SchnorrResult<T> = Result<T, SchnorrError>;
59
60/// Schnorr secret key (scalar in the Ristretto group)
61#[derive(Clone, Zeroize)]
62#[zeroize(drop)]
63pub struct SchnorrSecretKey {
64    scalar: Scalar,
65}
66
67impl SchnorrSecretKey {
68    /// Generate a random Schnorr secret key
69    pub fn generate() -> Self {
70        let mut rng = rand::thread_rng();
71        let mut bytes = [0u8; 32];
72        rng.fill(&mut bytes);
73        let scalar = Scalar::from_bytes_mod_order(bytes);
74        Self { scalar }
75    }
76
77    /// Create a Schnorr secret key from bytes
78    pub fn from_bytes(bytes: &[u8; 32]) -> SchnorrResult<Self> {
79        let scalar = Scalar::from_bytes_mod_order(*bytes);
80        Ok(Self { scalar })
81    }
82
83    /// Export secret key to bytes
84    pub fn to_bytes(&self) -> [u8; 32] {
85        self.scalar.to_bytes()
86    }
87
88    /// Derive public key from secret key
89    pub fn public_key(&self) -> SchnorrPublicKey {
90        let point = RISTRETTO_BASEPOINT_TABLE * &self.scalar;
91        SchnorrPublicKey { point }
92    }
93}
94
95/// Schnorr public key (point in the Ristretto group)
96#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
97pub struct SchnorrPublicKey {
98    point: RistrettoPoint,
99}
100
101impl SchnorrPublicKey {
102    /// Create a Schnorr public key from compressed bytes
103    pub fn from_bytes(bytes: &[u8; 32]) -> SchnorrResult<Self> {
104        let compressed =
105            CompressedRistretto::from_slice(bytes).map_err(|_| SchnorrError::InvalidPublicKey)?;
106        let point = compressed
107            .decompress()
108            .ok_or(SchnorrError::InvalidPublicKey)?;
109        Ok(Self { point })
110    }
111
112    /// Export public key to compressed bytes
113    pub fn to_bytes(&self) -> [u8; 32] {
114        self.point.compress().to_bytes()
115    }
116}
117
118/// Schnorr signature (challenge + response)
119///
120/// σ = (c, s) where:
121/// - c = H(R || P || m) is the challenge
122/// - s = k - c*x is the response
123/// - R = k*G is the commitment
124/// - P = x*G is the public key
125#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
126pub struct SchnorrSignature {
127    challenge: Scalar,
128    response: Scalar,
129}
130
131impl SchnorrSignature {
132    /// Create a Schnorr signature from bytes (64 bytes: 32 for challenge + 32 for response)
133    pub fn from_bytes(bytes: &[u8; 64]) -> SchnorrResult<Self> {
134        let mut challenge_bytes = [0u8; 32];
135        let mut response_bytes = [0u8; 32];
136        challenge_bytes.copy_from_slice(&bytes[..32]);
137        response_bytes.copy_from_slice(&bytes[32..]);
138
139        let challenge: Option<Scalar> = Scalar::from_canonical_bytes(challenge_bytes).into();
140        let response: Option<Scalar> = Scalar::from_canonical_bytes(response_bytes).into();
141
142        let challenge = challenge.ok_or(SchnorrError::InvalidSignature)?;
143        let response = response.ok_or(SchnorrError::InvalidSignature)?;
144
145        Ok(Self {
146            challenge,
147            response,
148        })
149    }
150
151    /// Export signature to bytes
152    pub fn to_bytes(&self) -> [u8; 64] {
153        let mut bytes = [0u8; 64];
154        bytes[..32].copy_from_slice(&self.challenge.to_bytes());
155        bytes[32..].copy_from_slice(&self.response.to_bytes());
156        bytes
157    }
158}
159
160/// Schnorr keypair (secret key + public key)
161pub struct SchnorrKeypair {
162    secret_key: SchnorrSecretKey,
163    public_key: SchnorrPublicKey,
164}
165
166impl SchnorrKeypair {
167    /// Generate a random Schnorr keypair
168    pub fn generate() -> Self {
169        let secret_key = SchnorrSecretKey::generate();
170        let public_key = secret_key.public_key();
171        Self {
172            secret_key,
173            public_key,
174        }
175    }
176
177    /// Create a keypair from a secret key
178    pub fn from_secret_key(secret_key: SchnorrSecretKey) -> Self {
179        let public_key = secret_key.public_key();
180        Self {
181            secret_key,
182            public_key,
183        }
184    }
185
186    /// Get the public key
187    pub fn public_key(&self) -> SchnorrPublicKey {
188        self.public_key
189    }
190
191    /// Get a reference to the secret key
192    pub fn secret_key(&self) -> &SchnorrSecretKey {
193        &self.secret_key
194    }
195
196    /// Sign a message using Schnorr signature scheme
197    ///
198    /// 1. Generate random nonce k
199    /// 2. Compute commitment R = k*G
200    /// 3. Compute challenge c = H(R || P || m)
201    /// 4. Compute response s = k - c*x
202    /// 5. Return signature σ = (c, s)
203    pub fn sign(&self, message: &[u8]) -> SchnorrSignature {
204        let mut rng = rand::thread_rng();
205        let mut nonce_bytes = [0u8; 32];
206        rng.fill(&mut nonce_bytes);
207        let nonce = Scalar::from_bytes_mod_order(nonce_bytes);
208
209        // Commitment: R = k*G
210        let commitment = RISTRETTO_BASEPOINT_TABLE * &nonce;
211
212        // Challenge: c = H(R || P || m)
213        let challenge = compute_challenge(&commitment, &self.public_key.point, message);
214
215        // Response: s = k - c*x
216        let response = nonce - (challenge * self.secret_key.scalar);
217
218        SchnorrSignature {
219            challenge,
220            response,
221        }
222    }
223
224    /// Verify a Schnorr signature
225    ///
226    /// Verification checks: R' = s*G + c*P
227    /// Then verifies: c == H(R' || P || m)
228    pub fn verify(&self, message: &[u8], signature: &SchnorrSignature) -> SchnorrResult<()> {
229        verify(&self.public_key, message, signature)
230    }
231}
232
233/// Compute the Schnorr challenge: c = H(R || P || m)
234fn compute_challenge(
235    commitment: &RistrettoPoint,
236    public_key: &RistrettoPoint,
237    message: &[u8],
238) -> Scalar {
239    let mut data = Vec::new();
240    data.extend_from_slice(&commitment.compress().to_bytes());
241    data.extend_from_slice(&public_key.compress().to_bytes());
242    data.extend_from_slice(message);
243
244    let hash = crate::hash::hash(&data);
245    Scalar::from_bytes_mod_order(hash)
246}
247
248/// Verify a Schnorr signature against a public key and message
249pub fn verify(
250    public_key: &SchnorrPublicKey,
251    message: &[u8],
252    signature: &SchnorrSignature,
253) -> SchnorrResult<()> {
254    // Recompute commitment: R' = s*G + c*P
255    let commitment_reconstructed =
256        RISTRETTO_BASEPOINT_TABLE * &signature.response + public_key.point * signature.challenge;
257
258    // Recompute challenge: c' = H(R' || P || m)
259    let challenge_reconstructed =
260        compute_challenge(&commitment_reconstructed, &public_key.point, message);
261
262    // Verify: c == c'
263    if challenge_reconstructed == signature.challenge {
264        Ok(())
265    } else {
266        Err(SchnorrError::InvalidSignature)
267    }
268}
269
270/// Batch verify multiple Schnorr signatures
271///
272/// More efficient than verifying each signature individually using random linear combination.
273///
274/// For each signature (c_i, s_i) with public key P_i and message m_i:
275/// 1. Reconstruct R_i = s_i*G + c_i*P_i
276/// 2. Verify c_i == H(R_i || P_i || m_i)
277/// 3. Use random linear combination to batch verify: Sum(a_i * R_i) == Sum(a_i * s_i)*G + Sum(a_i * c_i * P_i)
278///
279/// This reduces the number of expensive point operations from 2n to approximately n+2.
280pub fn batch_verify(items: &[(SchnorrPublicKey, &[u8], SchnorrSignature)]) -> SchnorrResult<()> {
281    if items.is_empty() {
282        return Err(SchnorrError::EmptyBatch);
283    }
284
285    // For single signature, use regular verification (no overhead)
286    if items.len() == 1 {
287        return verify(&items[0].0, items[0].1, &items[0].2);
288    }
289
290    let mut rng = rand::thread_rng();
291
292    // Step 1: Reconstruct commitments and verify challenges
293    let mut reconstructed_commitments = Vec::with_capacity(items.len());
294
295    for (public_key, message, signature) in items {
296        // Reconstruct commitment: R_i = s_i*G + c_i*P_i
297        let commitment = RISTRETTO_BASEPOINT_TABLE * &signature.response
298            + public_key.point * signature.challenge;
299
300        // Verify challenge: c_i == H(R_i || P_i || m_i)
301        let expected_challenge = compute_challenge(&commitment, &public_key.point, message);
302
303        if expected_challenge != signature.challenge {
304            return Err(SchnorrError::InvalidSignature);
305        }
306
307        reconstructed_commitments.push(commitment);
308    }
309
310    // Step 2: Batch verify using random linear combination
311    // Generate random weights a_i for each signature
312    let weights: Vec<Scalar> = (0..items.len())
313        .map(|_| {
314            let mut bytes = [0u8; 32];
315            rng.fill(&mut bytes);
316            Scalar::from_bytes_mod_order(bytes)
317        })
318        .collect();
319
320    // Compute left side: Sum(a_i * R_i)
321    let mut lhs = RistrettoPoint::default();
322    for (weight, commitment) in weights.iter().zip(reconstructed_commitments.iter()) {
323        lhs += weight * commitment;
324    }
325
326    // Compute right side: Sum(a_i * s_i)*G + Sum(a_i * c_i * P_i)
327    let mut response_sum = Scalar::ZERO;
328    let mut weighted_pubkey_sum = RistrettoPoint::default();
329
330    for (i, (public_key, _, signature)) in items.iter().enumerate() {
331        response_sum += weights[i] * signature.response;
332        weighted_pubkey_sum += (weights[i] * signature.challenge) * public_key.point;
333    }
334
335    let rhs = RISTRETTO_BASEPOINT_TABLE * &response_sum + weighted_pubkey_sum;
336
337    // Verify the batch equation
338    if lhs == rhs {
339        Ok(())
340    } else {
341        Err(SchnorrError::BatchVerificationFailed)
342    }
343}
344
345/// Aggregate multiple Schnorr signatures for the same message
346///
347/// Note: This is different from BLS aggregation - Schnorr aggregation
348/// requires interactive protocols or more complex schemes
349#[allow(dead_code)]
350pub fn aggregate_signatures(_signatures: &[SchnorrSignature]) -> SchnorrResult<SchnorrSignature> {
351    // Schnorr signature aggregation requires MuSig or similar protocols
352    // This is a placeholder for future implementation
353    unimplemented!("Schnorr aggregation requires MuSig protocol")
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Build one `batch_verify` entry: the signer's public key, the message the
    /// entry claims was signed, and the signature to check.
    fn batch_entry<'a>(
        keypair: &SchnorrKeypair,
        claimed_message: &'a [u8],
        signature: SchnorrSignature,
    ) -> (SchnorrPublicKey, &'a [u8], SchnorrSignature) {
        (keypair.public_key(), claimed_message, signature)
    }

    #[test]
    fn test_keypair_generation() {
        // A public key must survive an encode/decode round trip unchanged.
        let original = SchnorrKeypair::generate().public_key();
        let restored = SchnorrPublicKey::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original, restored);
    }

    #[test]
    fn test_sign_and_verify() {
        let signer = SchnorrKeypair::generate();
        let msg = b"Test message for Schnorr signature";
        let sig = signer.sign(msg);
        assert!(signer.verify(msg, &sig).is_ok());
    }

    #[test]
    fn test_verify_wrong_message() {
        let signer = SchnorrKeypair::generate();
        let sig = signer.sign(b"Original message");
        assert!(signer.verify(b"Wrong message", &sig).is_err());
    }

    #[test]
    fn test_verify_wrong_public_key() {
        let (alice, bob) = (SchnorrKeypair::generate(), SchnorrKeypair::generate());
        let msg = b"Test message";
        let sig = alice.sign(msg);
        assert!(verify(&bob.public_key(), msg, &sig).is_err());
    }

    #[test]
    fn test_signature_serialization() {
        let signer = SchnorrKeypair::generate();
        let msg = b"Test message";
        let sig = signer.sign(msg);

        let decoded = SchnorrSignature::from_bytes(&sig.to_bytes()).unwrap();

        assert_eq!(sig, decoded);
        assert!(signer.verify(msg, &decoded).is_ok());
    }

    #[test]
    fn test_deterministic_public_key() {
        let seed = [42u8; 32];
        let derive = || {
            SchnorrSecretKey::from_bytes(&seed)
                .unwrap()
                .public_key()
                .to_bytes()
        };
        assert_eq!(derive(), derive());
    }

    #[test]
    fn test_batch_verify() {
        let msg: &[u8] = b"Batch verification test";
        let signers: Vec<SchnorrKeypair> = (0..3).map(|_| SchnorrKeypair::generate()).collect();

        let items: Vec<_> = signers
            .iter()
            .map(|kp| batch_entry(kp, msg, kp.sign(msg)))
            .collect();

        assert!(batch_verify(&items).is_ok());
    }

    #[test]
    fn test_batch_verify_one_invalid() {
        let msg: &[u8] = b"Batch verification test";
        let wrong: &[u8] = b"Wrong message";
        let signers: Vec<SchnorrKeypair> = (0..3).map(|_| SchnorrKeypair::generate()).collect();

        // The middle signer signs a different message than its entry claims.
        let items: Vec<_> = signers
            .iter()
            .enumerate()
            .map(|(idx, kp)| {
                let actually_signed = if idx == 1 { wrong } else { msg };
                batch_entry(kp, msg, kp.sign(actually_signed))
            })
            .collect();

        assert!(batch_verify(&items).is_err());
    }

    #[test]
    fn test_batch_verify_empty() {
        let no_items: [(SchnorrPublicKey, &[u8], SchnorrSignature); 0] = [];
        assert!(batch_verify(&no_items).is_err());
    }

    #[test]
    fn test_secret_key_serialization() {
        let sk = SchnorrSecretKey::generate();
        let restored = SchnorrSecretKey::from_bytes(&sk.to_bytes()).unwrap();

        assert_eq!(sk.to_bytes(), restored.to_bytes());
        assert_eq!(sk.public_key().to_bytes(), restored.public_key().to_bytes());
    }

    #[test]
    fn test_signature_randomness() {
        let signer = SchnorrKeypair::generate();
        let msg = b"Test message";

        // Each call draws a fresh random nonce, so repeated signatures differ,
        // yet both must still verify.
        let (first, second) = (signer.sign(msg), signer.sign(msg));

        assert_ne!(first, second);
        for sig in [&first, &second] {
            assert!(signer.verify(msg, sig).is_ok());
        }
    }
}