1use serde::{Deserialize, Serialize};
4
5use crate::error::CryptoError;
6use std::fmt;
7
8use super::{DSA, DSAlgorithm, Ed25519Signer, PublicKey, SignatureIdentifier};
9
10#[derive(
12 Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default,
13)]
14pub enum KeyPairAlgorithm {
15 #[default]
17 Ed25519,
18}
19
20impl From<DSAlgorithm> for KeyPairAlgorithm {
21 fn from(algo: DSAlgorithm) -> Self {
22 match algo {
23 DSAlgorithm::Ed25519 => KeyPairAlgorithm::Ed25519,
24 }
25 }
26}
27
28impl From<KeyPairAlgorithm> for DSAlgorithm {
29 fn from(kp_type: KeyPairAlgorithm) -> Self {
30 match kp_type {
31 KeyPairAlgorithm::Ed25519 => DSAlgorithm::Ed25519,
32 }
33 }
34}
35
36impl KeyPairAlgorithm {
37 pub fn generate_keypair(&self) -> Result<KeyPair, CryptoError> {
50 KeyPair::generate(*self)
51 }
52}
53
54impl fmt::Display for KeyPairAlgorithm {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 match self {
57 KeyPairAlgorithm::Ed25519 => write!(f, "Ed25519"),
58 }
59 }
60}
61
62#[derive(Clone)]
89pub enum KeyPair {
90 Ed25519(Ed25519Signer),
91}
92
93impl KeyPair {
94 pub fn generate(key_type: KeyPairAlgorithm) -> Result<Self, CryptoError> {
96 match key_type {
97 KeyPairAlgorithm::Ed25519 => {
98 Ed25519Signer::generate().map(KeyPair::Ed25519)
99 }
100 }
101 }
102
103 pub fn from_secret_der(der: &[u8]) -> Result<Self, CryptoError> {
122 use pkcs8::{ObjectIdentifier, PrivateKeyInfo};
123
124 let private_key_info = PrivateKeyInfo::try_from(der)
126 .map_err(|e| CryptoError::InvalidDerFormat(e.to_string()))?;
127
128 let oid = private_key_info.algorithm.oid;
130
131 const ED25519_OID: ObjectIdentifier =
133 ObjectIdentifier::new_unwrap("1.3.101.112");
134
135 if oid == ED25519_OID {
137 let secret_key = private_key_info.private_key;
139
140 if secret_key.len() < 2 || secret_key[0] != 0x04 {
143 return Err(CryptoError::InvalidSecretKey(
144 "Invalid Ed25519 key encoding in DER".to_string(),
145 ));
146 }
147
148 let key_length = secret_key[1] as usize;
149 if secret_key.len() < 2 + key_length {
150 return Err(CryptoError::InvalidSecretKey(
151 "Truncated Ed25519 key in DER".to_string(),
152 ));
153 }
154
155 let actual_key = &secret_key[2..2 + key_length];
156 Ed25519Signer::from_secret_key(actual_key).map(KeyPair::Ed25519)
157 } else {
158 Err(CryptoError::UnsupportedAlgorithm(format!(
159 "Algorithm with OID {} is not supported",
160 oid
161 )))
162 }
163 }
164
165 pub fn from_seed(
167 key_type: KeyPairAlgorithm,
168 seed: &[u8; 32],
169 ) -> Result<Self, CryptoError> {
170 match key_type {
171 KeyPairAlgorithm::Ed25519 => {
172 Ed25519Signer::from_seed(seed).map(KeyPair::Ed25519)
173 }
174 }
175 }
176
177 pub fn derive_from_data(
179 key_type: KeyPairAlgorithm,
180 data: &[u8],
181 ) -> Result<Self, CryptoError> {
182 match key_type {
183 KeyPairAlgorithm::Ed25519 => {
184 Ed25519Signer::derive_from_data(data).map(KeyPair::Ed25519)
185 }
186 }
187 }
188
189 pub fn from_secret_key(secret_key: &[u8]) -> Result<Self, CryptoError> {
194 match secret_key.len() {
196 32 | 64 => {
197 Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
198 }
199 _ => Err(CryptoError::InvalidSecretKey(format!(
200 "Unsupported key length: {} bytes",
201 secret_key.len()
202 ))),
203 }
204 }
205
206 pub fn from_secret_key_with_type(
208 key_type: KeyPairAlgorithm,
209 secret_key: &[u8],
210 ) -> Result<Self, CryptoError> {
211 match key_type {
212 KeyPairAlgorithm::Ed25519 => {
213 Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
214 }
215 }
216 }
217
218 #[inline]
220 pub fn key_type(&self) -> KeyPairAlgorithm {
221 match self {
222 KeyPair::Ed25519(_) => KeyPairAlgorithm::Ed25519,
223 }
224 }
225
226 #[inline]
228 pub fn sign(
229 &self,
230 message: &[u8],
231 ) -> Result<SignatureIdentifier, CryptoError> {
232 match self {
233 KeyPair::Ed25519(signer) => signer.sign(message),
234 }
235 }
236
237 #[inline]
239 pub fn algorithm(&self) -> DSAlgorithm {
240 match self {
241 KeyPair::Ed25519(signer) => signer.algorithm(),
242 }
243 }
244
245 #[inline]
247 pub fn algorithm_id(&self) -> u8 {
248 match self {
249 KeyPair::Ed25519(signer) => signer.algorithm_id(),
250 }
251 }
252
253 #[inline]
255 pub fn public_key_bytes(&self) -> Vec<u8> {
256 match self {
257 KeyPair::Ed25519(signer) => signer.public_key_bytes(),
258 }
259 }
260
261 #[inline]
263 pub fn public_key(&self) -> PublicKey {
264 PublicKey::new(self.algorithm(), self.public_key_bytes())
265 .expect("KeyPair should always have valid public key")
266 }
267
268 #[inline]
270 pub fn secret_key_bytes(&self) -> Result<Vec<u8>, CryptoError> {
271 match self {
272 KeyPair::Ed25519(signer) => signer.secret_key_bytes(),
273 }
274 }
275
276 pub fn to_bytes(&self) -> Result<Vec<u8>, CryptoError> {
281 let secret = self.secret_key_bytes()?;
282 let mut result = Vec::with_capacity(1 + secret.len());
283 result.push(self.algorithm_id());
284 result.extend_from_slice(&secret);
285 Ok(result)
286 }
287
288 pub fn from_bytes(bytes: &[u8]) -> Result<Self, CryptoError> {
290 if bytes.is_empty() {
291 return Err(CryptoError::InvalidSecretKey(
292 "Data too short to contain algorithm identifier".to_string(),
293 ));
294 }
295
296 let id = bytes[0];
297 let algorithm = DSAlgorithm::from_identifier(id)?;
298 let key_type = KeyPairAlgorithm::from(algorithm);
299 let secret_key = &bytes[1..];
300
301 Self::from_secret_key_with_type(key_type, secret_key)
302 }
303
304 pub fn to_secret_der(&self) -> Result<Vec<u8>, CryptoError> {
321 use pkcs8::{ObjectIdentifier, PrivateKeyInfo, der::Encode};
322
323 const ED25519_OID: ObjectIdentifier =
324 ObjectIdentifier::new_unwrap("1.3.101.112");
325
326 let secret_key_bytes = self.secret_key_bytes()?;
327
328 let mut wrapped_key = Vec::with_capacity(2 + secret_key_bytes.len());
330 wrapped_key.push(0x04); wrapped_key.push(secret_key_bytes.len() as u8); wrapped_key.extend_from_slice(&secret_key_bytes);
333
334 let algorithm_identifier = pkcs8::AlgorithmIdentifierRef {
335 oid: ED25519_OID,
336 parameters: None,
337 };
338
339 let private_key_info = PrivateKeyInfo {
340 algorithm: algorithm_identifier,
341 private_key: &wrapped_key,
342 public_key: None,
343 };
344
345 private_key_info.to_der().map_err(|e| {
346 CryptoError::InvalidSecretKey(format!("DER encoding failed: {}", e))
347 })
348 }
349}
350
351impl Default for KeyPair {
352 fn default() -> Self {
353 KeyPair::Ed25519(Ed25519Signer::default())
354 }
355}
356
357impl fmt::Debug for KeyPair {
358 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
359 use crate::common::base64_encoding;
360 f.debug_struct("KeyPair")
361 .field("type", &self.key_type())
362 .field("algorithm", &self.algorithm())
363 .field(
364 "public_key",
365 &base64_encoding::encode(&self.public_key_bytes()),
366 )
367 .finish_non_exhaustive()
368 }
369}
370
371impl fmt::Display for KeyPair {
372 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373 write!(f, "{:?} KeyPair", self.key_type())
374 }
375}
376
377impl DSA for KeyPair {
379 #[inline]
380 fn algorithm_id(&self) -> u8 {
381 KeyPair::algorithm_id(self)
382 }
383
384 #[inline]
385 fn signature_length(&self) -> usize {
386 match self {
387 KeyPair::Ed25519(signer) => signer.signature_length(),
388 }
389 }
390
391 #[inline]
392 fn sign(&self, message: &[u8]) -> Result<SignatureIdentifier, CryptoError> {
393 KeyPair::sign(self, message)
394 }
395
396 #[inline]
397 fn algorithm(&self) -> DSAlgorithm {
398 KeyPair::algorithm(self)
399 }
400
401 #[inline]
402 fn public_key_bytes(&self) -> Vec<u8> {
403 KeyPair::public_key_bytes(self)
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use pkcs8::{der::Encode, ObjectIdentifier, PrivateKeyInfo};

    /// Dotted OID of Ed25519 (`id-Ed25519`, RFC 8410) for hand-built input.
    const ED25519_OID: &str = "1.3.101.112";

    /// DER-encodes a PKCS#8 `PrivateKeyInfo` with the given algorithm OID and
    /// raw `privateKey` contents (no algorithm parameters, no public key).
    fn encode_pkcs8(oid: &str, private_key: &[u8]) -> Vec<u8> {
        let private_key_info = PrivateKeyInfo {
            algorithm: pkcs8::AlgorithmIdentifierRef {
                oid: ObjectIdentifier::new_unwrap(oid),
                parameters: None,
            },
            private_key,
            public_key: None,
        };
        private_key_info.to_der().unwrap()
    }

    #[test]
    fn test_keypair_generate() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        assert_eq!(keypair.algorithm(), DSAlgorithm::Ed25519);
        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(keypair.public_key_bytes().len(), 32);
    }

    #[test]
    fn test_keypair_sign_verify() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";

        let signature = keypair.sign(message).unwrap();
        let public_key = keypair.public_key();

        assert!(public_key.verify(message, &signature).is_ok());
        assert!(public_key.verify(b"Wrong message", &signature).is_err());
    }

    #[test]
    fn test_keypair_from_seed() {
        let seed = [42u8; 32];
        let keypair1 =
            KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &seed).unwrap();
        let keypair2 =
            KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &seed).unwrap();
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());

        let other =
            KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &[7u8; 32]).unwrap();
        assert_ne!(keypair1.public_key_bytes(), other.public_key_bytes());
    }

    #[test]
    fn test_keypair_derive_from_data() {
        let data = b"my passphrase";
        let keypair1 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, data).unwrap();
        let keypair2 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, data).unwrap();
        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());

        let keypair3 =
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, b"different")
                .unwrap();
        assert_ne!(keypair1.public_key_bytes(), keypair3.public_key_bytes());
    }

    #[test]
    fn test_keypair_serialization() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";

        let bytes = keypair.to_bytes().unwrap();
        assert_eq!(bytes[0], b'E');
        let keypair2 = KeyPair::from_bytes(&bytes).unwrap();

        let sig1 = keypair.sign(message).unwrap();
        let sig2 = keypair2.sign(message).unwrap();

        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_from_bytes_empty() {
        let empty: &[u8] = &[];
        let result = KeyPair::from_bytes(empty);
        assert!(matches!(result, Err(CryptoError::InvalidSecretKey(_))));
    }

    #[test]
    fn test_keypair_dsa_trait() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message";

        let signature = DSA::sign(&keypair, message).unwrap();
        assert_eq!(DSA::algorithm(&keypair), DSAlgorithm::Ed25519);
        assert_eq!(DSA::algorithm_id(&keypair), b'E');

        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &signature).is_ok());
    }

    #[test]
    fn test_keypair_public_key_wrapper() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let public_key = keypair.public_key();

        assert_eq!(public_key.algorithm(), keypair.algorithm());
        assert_eq!(public_key.as_bytes(), &keypair.public_key_bytes()[..]);
    }

    #[test]
    fn test_keypair_from_secret_key_autodetect() {
        let keypair1 = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let secret_bytes = keypair1.secret_key_bytes().unwrap();

        let keypair2 = KeyPair::from_secret_key(&secret_bytes).unwrap();

        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());
    }

    #[test]
    fn test_keypair_from_secret_key_invalid_length() {
        let result = KeyPair::from_secret_key(&[0u8; 16]);
        assert!(matches!(result, Err(CryptoError::InvalidSecretKey(_))));
    }

    #[test]
    fn test_keypair_type_conversion() {
        let kp_type = KeyPairAlgorithm::Ed25519;
        let algo: DSAlgorithm = kp_type.into();
        assert_eq!(algo, DSAlgorithm::Ed25519);

        let kp_type2: KeyPairAlgorithm = algo.into();
        assert_eq!(kp_type, kp_type2);
    }

    #[test]
    fn test_keypair_algorithm_generate() {
        let algorithm = KeyPairAlgorithm::Ed25519;
        let keypair = algorithm.generate_keypair().unwrap();

        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(keypair.algorithm(), DSAlgorithm::Ed25519);

        let message = b"test";
        let signature = keypair.sign(message).unwrap();
        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &signature).is_ok());
    }

    #[test]
    fn test_keypair_algorithm_display() {
        let algorithm = KeyPairAlgorithm::Ed25519;
        assert_eq!(algorithm.to_string(), "Ed25519");
    }

    #[test]
    fn test_keypair_display_and_debug() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        assert_eq!(keypair.to_string(), "Ed25519 KeyPair");

        let debug = format!("{:?}", keypair);
        assert!(debug.starts_with("KeyPair {"));
        assert!(debug.contains("public_key"));
    }

    #[test]
    fn test_default_keypair() {
        let keypair = KeyPair::default();
        assert_eq!(keypair.key_type(), KeyPairAlgorithm::Ed25519);
    }

    #[test]
    fn test_keypair_clone() {
        let keypair = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let keypair_clone = keypair.clone();

        assert_eq!(
            keypair.public_key_bytes(),
            keypair_clone.public_key_bytes()
        );

        let message = b"test message";
        let sig1 = keypair.sign(message).unwrap();
        let sig2 = keypair_clone.sign(message).unwrap();

        // Ed25519 signing is deterministic, so clones sign identically.
        assert_eq!(sig1, sig2);

        let public_key = keypair.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_der_roundtrip() {
        let keypair1 = KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap();
        let message = b"Test message for DER roundtrip";

        let der_bytes = keypair1.to_secret_der().unwrap();
        // Outer structure must be a DER SEQUENCE.
        assert_eq!(der_bytes[0], 0x30);
        let keypair2 = KeyPair::from_secret_der(&der_bytes).unwrap();

        assert_eq!(keypair1.public_key_bytes(), keypair2.public_key_bytes());

        let sig1 = keypair1.sign(message).unwrap();
        let sig2 = keypair2.sign(message).unwrap();

        let public_key = keypair1.public_key();
        assert!(public_key.verify(message, &sig1).is_ok());
        assert!(public_key.verify(message, &sig2).is_ok());
    }

    #[test]
    fn test_keypair_from_der_invalid() {
        let invalid_der = vec![0x00, 0x01, 0x02];
        let result = KeyPair::from_secret_der(&invalid_der);
        assert!(matches!(result, Err(CryptoError::InvalidDerFormat(_))));
    }

    #[test]
    fn test_keypair_from_der_unsupported_algorithm() {
        // secp256k1 OID: a well-formed document for an algorithm we reject.
        let mut fake_key = vec![0x04, 0x20];
        fake_key.extend_from_slice(&[0u8; 32]);
        let der_bytes = encode_pkcs8("1.3.132.0.10", &fake_key);

        let result = KeyPair::from_secret_der(&der_bytes);
        assert!(matches!(result, Err(CryptoError::UnsupportedAlgorithm(_))));
    }

    #[test]
    fn test_keypair_from_der_rejects_non_octet_string_key() {
        // 0x05 is the NULL tag, not OCTET STRING.
        let mut inner = vec![0x05, 0x20];
        inner.extend_from_slice(&[0u8; 32]);
        let der_bytes = encode_pkcs8(ED25519_OID, &inner);

        let result = KeyPair::from_secret_der(&der_bytes);
        assert!(matches!(result, Err(CryptoError::InvalidSecretKey(_))));
    }

    #[test]
    fn test_keypair_from_der_rejects_truncated_key() {
        // Declares 32 bytes but only carries 16.
        let mut inner = vec![0x04, 0x20];
        inner.extend_from_slice(&[0u8; 16]);
        let der_bytes = encode_pkcs8(ED25519_OID, &inner);

        let result = KeyPair::from_secret_der(&der_bytes);
        assert!(matches!(result, Err(CryptoError::InvalidSecretKey(_))));
    }
}