// ave_identity/keys/keypair.rs

1//! Generic key pair wrapper for any DSA implementation
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::CryptoError;
6use std::fmt;
7
8use super::{DSA, DSAlgorithm, Ed25519Signer, PublicKey, SignatureIdentifier};
9
10/// Key pair types supported by the system
11#[derive(
12    Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default,
13)]
14pub enum KeyPairAlgorithm {
15    /// Ed25519 elliptic curve signature scheme
16    #[default]
17    Ed25519,
18}
19
20impl From<DSAlgorithm> for KeyPairAlgorithm {
21    fn from(algo: DSAlgorithm) -> Self {
22        match algo {
23            DSAlgorithm::Ed25519 => KeyPairAlgorithm::Ed25519,
24        }
25    }
26}
27
28impl From<KeyPairAlgorithm> for DSAlgorithm {
29    fn from(kp_type: KeyPairAlgorithm) -> Self {
30        match kp_type {
31            KeyPairAlgorithm::Ed25519 => DSAlgorithm::Ed25519,
32        }
33    }
34}
35
36impl KeyPairAlgorithm {
37    /// Generate a new key pair for this algorithm
38    ///
39    /// This is a convenience method that creates a new random key pair
40    /// of the specified algorithm type.
41    ///
42    /// # Example
43    /// ```rust
44    /// use ave_identity::keys::KeyPairAlgorithm;
45    ///
46    /// let algorithm = KeyPairAlgorithm::Ed25519;
47    /// let keypair = algorithm.generate_keypair().unwrap();
48    /// ```
49    pub fn generate_keypair(&self) -> Result<KeyPair, CryptoError> {
50        KeyPair::generate(*self)
51    }
52}
53
54impl fmt::Display for KeyPairAlgorithm {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        match self {
57            KeyPairAlgorithm::Ed25519 => write!(f, "Ed25519"),
58        }
59    }
60}
61
62/// Generic key pair wrapper that can hold any DSA implementation
63///
64/// This provides algorithm-agnostic operations for signing and verification.
65///
66/// Cloning a KeyPair is cheap because the underlying secret keys are stored
67/// in Arc<EncryptedMem>, so only the reference is cloned, not the encrypted data.
68///
69/// # Example
70///
71/// ```rust
72/// use ave_identity::keys::{KeyPair, KeyPairAlgorithm, DSA};
73///
74/// // Generate a key pair
75/// let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).expect("Failed to generate key pair");
76///
77/// let message = b"Hello, World!";
78///
79/// // Sign message using generic interface
80/// let signature = keypair.sign(message).unwrap();
81///
82/// // Get public key
83/// let public_key = keypair.public_key();
84///
85/// // Verify
86/// assert!(public_key.verify(message, &signature).is_ok());
87/// ```
88#[derive(Clone)]
89pub enum KeyPair {
90    Ed25519(Ed25519Signer),
91}
92
93impl KeyPair {
94    /// Generate a new random key pair of the specified type
95    pub fn generate(key_type: KeyPairAlgorithm) -> Result<Self, CryptoError> {
96        match key_type {
97            KeyPairAlgorithm::Ed25519 => {
98                Ed25519Signer::generate().map(KeyPair::Ed25519)
99            }
100        }
101    }
102
103    /// Create key pair from PKCS#8 DER-encoded secret key
104    ///
105    /// This method automatically detects the algorithm from the OID in the DER structure.
106    /// Supported OIDs:
107    /// - Ed25519: 1.3.101.112
108    ///
109    /// # Errors
110    /// - Returns `InvalidDerFormat` if the DER structure is malformed
111    /// - Returns `UnsupportedAlgorithm` if the algorithm OID is not supported
112    /// - Returns `InvalidSecretKey` if the key data is invalid
113    ///
114    /// # Example
115    /// ```no_run
116    /// use ave_identity::keys::KeyPair;
117    ///
118    /// let der_bytes = std::fs::read("private_key.der").unwrap();
119    /// let keypair = KeyPair::from_secret_der(&der_bytes).unwrap();
120    /// ```
121    pub fn from_secret_der(der: &[u8]) -> Result<Self, CryptoError> {
122        use pkcs8::{ObjectIdentifier, PrivateKeyInfo};
123
124        // Parse the DER structure
125        let private_key_info = PrivateKeyInfo::try_from(der)
126            .map_err(|e| CryptoError::InvalidDerFormat(e.to_string()))?;
127
128        // Get the algorithm OID
129        let oid = private_key_info.algorithm.oid;
130
131        // Ed25519 OID: 1.3.101.112
132        const ED25519_OID: ObjectIdentifier =
133            ObjectIdentifier::new_unwrap("1.3.101.112");
134
135        // Match OID to algorithm
136        if oid == ED25519_OID {
137            // Extract the secret key bytes from the OCTET STRING
138            let secret_key = private_key_info.private_key;
139
140            // Ed25519 keys in PKCS#8 are wrapped in an OCTET STRING
141            // The first byte should be 0x04 (OCTET STRING tag), followed by length
142            if secret_key.len() < 2 || secret_key[0] != 0x04 {
143                return Err(CryptoError::InvalidSecretKey(
144                    "Invalid Ed25519 key encoding in DER".to_string(),
145                ));
146            }
147
148            let key_length = secret_key[1] as usize;
149            if secret_key.len() < 2 + key_length {
150                return Err(CryptoError::InvalidSecretKey(
151                    "Truncated Ed25519 key in DER".to_string(),
152                ));
153            }
154
155            let actual_key = &secret_key[2..2 + key_length];
156            Ed25519Signer::from_secret_key(actual_key).map(KeyPair::Ed25519)
157        } else {
158            Err(CryptoError::UnsupportedAlgorithm(format!(
159                "Algorithm with OID {} is not supported",
160                oid
161            )))
162        }
163    }
164
165    /// Create key pair from seed
166    pub fn from_seed(
167        key_type: KeyPairAlgorithm,
168        seed: &[u8; 32],
169    ) -> Result<Self, CryptoError> {
170        match key_type {
171            KeyPairAlgorithm::Ed25519 => {
172                Ed25519Signer::from_seed(seed).map(KeyPair::Ed25519)
173            }
174        }
175    }
176
177    /// Derive key pair from arbitrary data (will be hashed)
178    pub fn derive_from_data(
179        key_type: KeyPairAlgorithm,
180        data: &[u8],
181    ) -> Result<Self, CryptoError> {
182        match key_type {
183            KeyPairAlgorithm::Ed25519 => {
184                Ed25519Signer::derive_from_data(data).map(KeyPair::Ed25519)
185            }
186        }
187    }
188
189    /// Create key pair from secret key bytes
190    ///
191    /// Attempts to auto-detect the algorithm from key length.
192    /// For explicit algorithm selection, use `from_secret_key_with_type`.
193    pub fn from_secret_key(secret_key: &[u8]) -> Result<Self, CryptoError> {
194        // Try to detect algorithm from key length
195        match secret_key.len() {
196            32 | 64 => {
197                Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
198            }
199            _ => Err(CryptoError::InvalidSecretKey(format!(
200                "Unsupported key length: {} bytes",
201                secret_key.len()
202            ))),
203        }
204    }
205
206    /// Create key pair from secret key bytes with explicit type
207    pub fn from_secret_key_with_type(
208        key_type: KeyPairAlgorithm,
209        secret_key: &[u8],
210    ) -> Result<Self, CryptoError> {
211        match key_type {
212            KeyPairAlgorithm::Ed25519 => {
213                Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
214            }
215        }
216    }
217
218    /// Get the key pair type
219    #[inline]
220    pub fn key_type(&self) -> KeyPairAlgorithm {
221        match self {
222            KeyPair::Ed25519(_) => KeyPairAlgorithm::Ed25519,
223        }
224    }
225
226    /// Sign a message using the appropriate algorithm
227    #[inline]
228    pub fn sign(
229        &self,
230        message: &[u8],
231    ) -> Result<SignatureIdentifier, CryptoError> {
232        match self {
233            KeyPair::Ed25519(signer) => signer.sign(message),
234        }
235    }
236
237    /// Get the algorithm used by this key pair
238    #[inline]
239    pub fn algorithm(&self) -> DSAlgorithm {
240        match self {
241            KeyPair::Ed25519(signer) => signer.algorithm(),
242        }
243    }
244
245    /// Get the algorithm identifier
246    #[inline]
247    pub fn algorithm_id(&self) -> u8 {
248        match self {
249            KeyPair::Ed25519(signer) => signer.algorithm_id(),
250        }
251    }
252
253    /// Get the public key bytes
254    #[inline]
255    pub fn public_key_bytes(&self) -> Vec<u8> {
256        match self {
257            KeyPair::Ed25519(signer) => signer.public_key_bytes(),
258        }
259    }
260
261    /// Get the public key as a PublicKey wrapper
262    #[inline]
263    pub fn public_key(&self) -> PublicKey {
264        PublicKey::new(self.algorithm(), self.public_key_bytes())
265            .expect("KeyPair should always have valid public key")
266    }
267
268    /// Get the secret key bytes (if available)
269    #[inline]
270    pub fn secret_key_bytes(&self) -> Result<Vec<u8>, CryptoError> {
271        match self {
272            KeyPair::Ed25519(signer) => signer.secret_key_bytes(),
273        }
274    }
275
276    /// Serialize to bytes (includes algorithm identifier and secret key)
277    ///
278    /// # Warning
279    /// This exposes the secret key. Use with extreme caution.
280    pub fn to_bytes(&self) -> Result<Vec<u8>, CryptoError> {
281        let secret = self.secret_key_bytes()?;
282        let mut result = Vec::with_capacity(1 + secret.len());
283        result.push(self.algorithm_id());
284        result.extend_from_slice(&secret);
285        Ok(result)
286    }
287
288    /// Deserialize from bytes (includes algorithm identifier)
289    pub fn from_bytes(bytes: &[u8]) -> Result<Self, CryptoError> {
290        if bytes.is_empty() {
291            return Err(CryptoError::InvalidSecretKey(
292                "Data too short to contain algorithm identifier".to_string(),
293            ));
294        }
295
296        let id = bytes[0];
297        let algorithm = DSAlgorithm::from_identifier(id)?;
298        let key_type = KeyPairAlgorithm::from(algorithm);
299        let secret_key = &bytes[1..];
300
301        Self::from_secret_key_with_type(key_type, secret_key)
302    }
303
304    /// Serialize to PKCS#8 DER format
305    ///
306    /// This creates a DER-encoded PKCS#8 PrivateKeyInfo structure containing
307    /// the secret key and algorithm identifier.
308    ///
309    /// # Errors
310    /// - Returns `InvalidSecretKey` if the secret key cannot be retrieved
311    ///
312    /// # Example
313    /// ```no_run
314    /// use ave_identity::keys::{KeyPair, KeyPairAlgorithm};
315    ///
316    /// let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
317    /// let der_bytes = keypair.to_secret_der().unwrap();
318    /// std::fs::write("private_key.der", der_bytes).unwrap();
319    /// ```
320    pub fn to_secret_der(&self) -> Result<Vec<u8>, CryptoError> {
321        use pkcs8::{ObjectIdentifier, PrivateKeyInfo, der::Encode};
322
323        const ED25519_OID: ObjectIdentifier =
324            ObjectIdentifier::new_unwrap("1.3.101.112");
325
326        let secret_key_bytes = self.secret_key_bytes()?;
327
328        // Wrap the key in an OCTET STRING (0x04 tag)
329        let mut wrapped_key = Vec::with_capacity(2 + secret_key_bytes.len());
330        wrapped_key.push(0x04); // OCTET STRING tag
331        wrapped_key.push(secret_key_bytes.len() as u8); // length
332        wrapped_key.extend_from_slice(&secret_key_bytes);
333
334        let algorithm_identifier = pkcs8::AlgorithmIdentifierRef {
335            oid: ED25519_OID,
336            parameters: None,
337        };
338
339        let private_key_info = PrivateKeyInfo {
340            algorithm: algorithm_identifier,
341            private_key: &wrapped_key,
342            public_key: None,
343        };
344
345        private_key_info.to_der().map_err(|e| {
346            CryptoError::InvalidSecretKey(format!("DER encoding failed: {}", e))
347        })
348    }
349}
350
351impl Default for KeyPair {
352    fn default() -> Self {
353        KeyPair::Ed25519(Ed25519Signer::default())
354    }
355}
356
357impl fmt::Debug for KeyPair {
358    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359        use crate::common::base64_encoding;
360        f.debug_struct("KeyPair")
361            .field("type", &self.key_type())
362            .field("algorithm", &self.algorithm())
363            .field(
364                "public_key",
365                &base64_encoding::encode(&self.public_key_bytes()),
366            )
367            .finish_non_exhaustive()
368    }
369}
370
371impl fmt::Display for KeyPair {
372    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373        write!(f, "{:?} KeyPair", self.key_type())
374    }
375}
376
377// Implement DSA trait for KeyPair to make it fully interchangeable
378impl DSA for KeyPair {
379    #[inline]
380    fn algorithm_id(&self) -> u8 {
381        KeyPair::algorithm_id(self)
382    }
383
384    #[inline]
385    fn signature_length(&self) -> usize {
386        match self {
387            KeyPair::Ed25519(signer) => signer.signature_length(),
388        }
389    }
390
391    #[inline]
392    fn sign(&self, message: &[u8]) -> Result<SignatureIdentifier, CryptoError> {
393        KeyPair::sign(self, message)
394    }
395
396    #[inline]
397    fn algorithm(&self) -> DSAlgorithm {
398        KeyPair::algorithm(self)
399    }
400
401    #[inline]
402    fn public_key_bytes(&self) -> Vec<u8> {
403        KeyPair::public_key_bytes(self)
404    }
405}
406
// Unit tests for KeyPairAlgorithm and KeyPair: generation, deterministic
// derivation, byte/DER round-trips, DSA trait forwarding and error mapping.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keypair_generate() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        assert_eq!(keypair.algorithm(), DSAlgorithm::Ed25519);
        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(keypair.public_key_bytes().len(), 32);
    }

    #[test]
    fn test_keypair_sign_verify() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";

        let signature = keypair.sign(message).unwrap();
        let public_key = keypair.public_key();

        assert!(public_key.verify(message, &signature).is_ok());
        assert!(public_key.verify(b"Wrong message", &signature).is_err());
    }

    #[test]
    fn test_keypair_from_seed() {
        let seed = [42u8; 32];
        let keypair1 =
            KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &seed).unwrap();
        let keypair2 =
            KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &seed).unwrap();

        // Same seed should produce same keys
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());
    }

    #[test]
    fn test_keypair_derive_from_data() {
        let data = b"my passphrase";
        let keypair1 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, data).unwrap();
        let keypair2 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, data).unwrap();

        // Same data should produce same keys
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());

        // Different data should produce different keys
        let keypair3 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, b"different")
                .unwrap();
        assert_ne!(keypair1.public_key_bytes(), keypair3.public_key_bytes());
    }

    #[test]
    fn test_keypair_serialization() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";

        // Serialize: the first byte is the algorithm identifier.
        let bytes = keypair.to_bytes().unwrap();
        assert_eq!(bytes[0], b'E'); // Ed25519 identifier

        // Deserialize: the same key material must come back.
        let keypair2 = KeyPair::from_bytes(&bytes).unwrap();
        assert_eq!(keypair.public_key_bytes(), keypair2.public_key_bytes());

        // Ed25519 signing is deterministic, so identical keys must give
        // identical signatures.
        let sig1 = keypair.sign(message).unwrap();
        let sig2 = keypair2.sign(message).unwrap();
        assert_eq!(sig1, sig2);

        // Both should verify correctly
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_dsa_trait() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";

        // Use DSA trait methods
        let signature = DSA::sign(&keypair, message).unwrap();
        assert_eq!(DSA::algorithm(&keypair), DSAlgorithm::Ed25519);
        assert_eq!(DSA::algorithm_id(&keypair), b'E');

        // Verify
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &signature).is_ok());
    }

    #[test]
    fn test_keypair_public_key_wrapper() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let public_key = keypair.public_key();

        assert_eq!(public_key.algorithm(), keypair.algorithm());
        assert_eq!(public_key.as_bytes(), &keypair.public_key_bytes()[..]);
    }

    #[test]
    fn test_keypair_from_secret_key_autodetect() {
        let keypair1 = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let secret_bytes = keypair1.secret_key_bytes().unwrap();

        // Auto-detect should work
        let keypair2 = KeyPair::from_secret_key(&secret_bytes).unwrap();

        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());
    }

    #[test]
    fn test_keypair_type_conversion() {
        let kp_type = KeyPairAlgorithm::Ed25519;
        let algo: DSAlgorithm = kp_type.into();
        assert_eq!(algo, DSAlgorithm::Ed25519);

        let kp_type2: KeyPairAlgorithm = algo.into();
        assert_eq!(kp_type, kp_type2);
    }

    #[test]
    fn test_keypair_algorithm_generate() {
        let algorithm = KeyPairAlgorithm::Ed25519;
        let keypair = algorithm.generate_keypair().unwrap();

        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(keypair.algorithm(), DSAlgorithm::Ed25519);

        // Should be able to sign
        let message = b"test";
        let signature = keypair.sign(message).unwrap();
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &signature).is_ok());
    }

    #[test]
    fn test_keypair_algorithm_display() {
        let algorithm = KeyPairAlgorithm::Ed25519;
        assert_eq!(algorithm.to_string(), "Ed25519");
    }

    #[test]
    fn test_default_keypair() {
        let keypair = KeyPair::default();
        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
    }

    #[test]
    fn test_keypair_clone() {
        // Test that cloning works correctly
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let keypair_clone = keypair.clone();

        // Both should have the same public key
        assert_eq!(
            keypair.public_key_bytes(),
            keypair_clone.public_key_bytes()
        );

        // Both should sign the same way
        let message = b"test message";
        let sig1 = keypair.sign(message).unwrap();
        let sig2 = keypair_clone.sign(message).unwrap();

        // Signatures should be identical (deterministic)
        assert_eq!(sig1, sig2);

        // Both signatures should verify
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_der_roundtrip() {
        // Test DER serialization and deserialization
        let keypair1 = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message for DER roundtrip";

        // Serialize to DER
        let der_bytes = keypair1.to_secret_der().unwrap();

        // Verify it starts with DER SEQUENCE tag
        assert_eq!(der_bytes[0], 0x30); // SEQUENCE tag

        // Deserialize from DER
        let keypair2 = KeyPair::from_secret_der(&der_bytes).unwrap();

        // Should have the same public key
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());

        // Same key + deterministic Ed25519 => identical signatures.
        let sig1 = keypair1.sign(message).unwrap();
        let sig2 = keypair2.sign(message).unwrap();
        assert_eq!(sig1, sig2);

        let public_key = keypair1.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_from_der_invalid() {
        // Test error handling for invalid DER data
        let invalid_der = [0x00, 0x01, 0x02];
        let result = KeyPair::from_secret_der(&invalid_der);
        assert!(matches!(result, Err(CryptoError::InvalidDerFormat(_))));
    }

    #[test]
    fn test_keypair_from_der_unsupported_algorithm() {
        // Create a valid DER structure but with an unsupported OID
        use pkcs8::{ObjectIdentifier, PrivateKeyInfo, der::Encode};

        // Use a different OID (e.g., secp256k1: 1.3.132.0.10)
        let unsupported_oid = ObjectIdentifier::new_unwrap("1.3.132.0.10");

        // OCTET STRING tag + length, followed by 32 zero bytes.
        let mut fake_key = vec![0x04, 0x20];
        fake_key.extend_from_slice(&[0u8; 32]);

        let algorithm_identifier = pkcs8::AlgorithmIdentifierRef {
            oid: unsupported_oid,
            parameters: None,
        };

        let private_key_info = PrivateKeyInfo {
            algorithm: algorithm_identifier,
            private_key: &fake_key,
            public_key: None,
        };

        let der_bytes = private_key_info.to_der().unwrap();

        let result = KeyPair::from_secret_der(&der_bytes);
        assert!(matches!(result, Err(CryptoError::UnsupportedAlgorithm(_))));
    }
}