1use serde::{Deserialize, Serialize};
4
5use crate::error::CryptoError;
6use std::fmt;
7
8use super::{DSA, DSAlgorithm, Ed25519Signer, PublicKey, SignatureIdentifier};
9
10#[derive(
12 Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default,
13)]
14pub enum KeyPairAlgorithm {
15 #[default]
17 Ed25519,
18}
19
20impl From<DSAlgorithm> for KeyPairAlgorithm {
21 fn from(algo: DSAlgorithm) -> Self {
22 match algo {
23 DSAlgorithm::Ed25519 => Self::Ed25519,
24 }
25 }
26}
27
28impl From<KeyPairAlgorithm> for DSAlgorithm {
29 fn from(kp_type: KeyPairAlgorithm) -> Self {
30 match kp_type {
31 KeyPairAlgorithm::Ed25519 => Self::Ed25519,
32 }
33 }
34}
35
36impl KeyPairAlgorithm {
37 pub fn generate_keypair(&self) -> Result<KeyPair, CryptoError> {
39 KeyPair::generate(*self)
40 }
41}
42
43impl fmt::Display for KeyPairAlgorithm {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 match self {
46 Self::Ed25519 => write!(f, "Ed25519"),
47 }
48 }
49}
50
51#[derive(Clone)]
55pub enum KeyPair {
56 Ed25519(Ed25519Signer),
57}
58
59impl KeyPair {
60 pub fn generate(key_type: KeyPairAlgorithm) -> Result<Self, CryptoError> {
62 match key_type {
63 KeyPairAlgorithm::Ed25519 => {
64 Ed25519Signer::generate().map(KeyPair::Ed25519)
65 }
66 }
67 }
68
69 pub fn from_secret_der(der: &[u8]) -> Result<Self, CryptoError> {
73 use pkcs8::{ObjectIdentifier, PrivateKeyInfo};
74
75 let private_key_info = PrivateKeyInfo::try_from(der)
77 .map_err(|e| CryptoError::InvalidDerFormat(e.to_string()))?;
78
79 let oid = private_key_info.algorithm.oid;
81
82 const ED25519_OID: ObjectIdentifier =
84 ObjectIdentifier::new_unwrap("1.3.101.112");
85
86 if oid == ED25519_OID {
88 let secret_key = private_key_info.private_key;
90
91 if secret_key.len() < 2 || secret_key[0] != 0x04 {
94 return Err(CryptoError::InvalidSecretKey(
95 "Invalid Ed25519 key encoding in DER".to_string(),
96 ));
97 }
98
99 let key_length = secret_key[1] as usize;
100 if secret_key.len() < 2 + key_length {
101 return Err(CryptoError::InvalidSecretKey(
102 "Truncated Ed25519 key in DER".to_string(),
103 ));
104 }
105
106 let actual_key = &secret_key[2..2 + key_length];
107 Ed25519Signer::from_secret_key(actual_key).map(KeyPair::Ed25519)
108 } else {
109 Err(CryptoError::UnsupportedAlgorithm(format!(
110 "Algorithm with OID {} is not supported",
111 oid
112 )))
113 }
114 }
115
116 pub fn from_seed(
118 key_type: KeyPairAlgorithm,
119 seed: &[u8; 32],
120 ) -> Result<Self, CryptoError> {
121 match key_type {
122 KeyPairAlgorithm::Ed25519 => {
123 Ed25519Signer::from_seed(seed).map(KeyPair::Ed25519)
124 }
125 }
126 }
127
128 pub fn derive_from_data(
130 key_type: KeyPairAlgorithm,
131 data: &[u8],
132 ) -> Result<Self, CryptoError> {
133 match key_type {
134 KeyPairAlgorithm::Ed25519 => {
135 Ed25519Signer::derive_from_data(data).map(KeyPair::Ed25519)
136 }
137 }
138 }
139
140 pub fn from_secret_key(secret_key: &[u8]) -> Result<Self, CryptoError> {
144 match secret_key.len() {
146 32 | 64 => {
147 Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
148 }
149 _ => Err(CryptoError::InvalidSecretKey(format!(
150 "Unsupported key length: {} bytes",
151 secret_key.len()
152 ))),
153 }
154 }
155
156 pub fn from_secret_key_with_type(
158 key_type: KeyPairAlgorithm,
159 secret_key: &[u8],
160 ) -> Result<Self, CryptoError> {
161 match key_type {
162 KeyPairAlgorithm::Ed25519 => {
163 Ed25519Signer::from_secret_key(secret_key).map(KeyPair::Ed25519)
164 }
165 }
166 }
167
168 #[inline]
170 pub const fn key_type(&self) -> KeyPairAlgorithm {
171 match self {
172 Self::Ed25519(_) => KeyPairAlgorithm::Ed25519,
173 }
174 }
175
176 #[inline]
178 pub fn sign(
179 &self,
180 message: &[u8],
181 ) -> Result<SignatureIdentifier, CryptoError> {
182 match self {
183 Self::Ed25519(signer) => signer.sign(message),
184 }
185 }
186
187 #[inline]
189 pub fn algorithm(&self) -> DSAlgorithm {
190 match self {
191 Self::Ed25519(signer) => signer.algorithm(),
192 }
193 }
194
195 #[inline]
197 pub fn algorithm_id(&self) -> u8 {
198 match self {
199 Self::Ed25519(signer) => signer.algorithm_id(),
200 }
201 }
202
203 #[inline]
205 pub fn public_key_bytes(&self) -> Vec<u8> {
206 match self {
207 Self::Ed25519(signer) => signer.public_key_bytes(),
208 }
209 }
210
211 #[inline]
213 pub fn public_key(&self) -> PublicKey {
214 PublicKey::new(self.algorithm(), self.public_key_bytes())
215 .expect("KeyPair should always have valid public key")
216 }
217
218 #[inline]
220 pub fn secret_key_bytes(&self) -> Result<Vec<u8>, CryptoError> {
221 match self {
222 Self::Ed25519(signer) => signer.secret_key_bytes(),
223 }
224 }
225
226 pub fn to_bytes(&self) -> Result<Vec<u8>, CryptoError> {
230 let secret = self.secret_key_bytes()?;
231 let mut result = Vec::with_capacity(1 + secret.len());
232 result.push(self.algorithm_id());
233 result.extend_from_slice(&secret);
234 Ok(result)
235 }
236
237 pub fn from_bytes(bytes: &[u8]) -> Result<Self, CryptoError> {
239 if bytes.is_empty() {
240 return Err(CryptoError::InvalidSecretKey(
241 "Data too short to contain algorithm identifier".to_string(),
242 ));
243 }
244
245 let id = bytes[0];
246 let algorithm = DSAlgorithm::from_identifier(id)?;
247 let key_type = KeyPairAlgorithm::from(algorithm);
248 let secret_key = &bytes[1..];
249
250 Self::from_secret_key_with_type(key_type, secret_key)
251 }
252
253 pub fn to_secret_der(&self) -> Result<Vec<u8>, CryptoError> {
255 use pkcs8::{ObjectIdentifier, PrivateKeyInfo, der::Encode};
256
257 const ED25519_OID: ObjectIdentifier =
258 ObjectIdentifier::new_unwrap("1.3.101.112");
259
260 let secret_key_bytes = self.secret_key_bytes()?;
261
262 let mut wrapped_key = Vec::with_capacity(2 + secret_key_bytes.len());
264 wrapped_key.push(0x04); wrapped_key.push(secret_key_bytes.len() as u8); wrapped_key.extend_from_slice(&secret_key_bytes);
267
268 let algorithm_identifier = pkcs8::AlgorithmIdentifierRef {
269 oid: ED25519_OID,
270 parameters: None,
271 };
272
273 let private_key_info = PrivateKeyInfo {
274 algorithm: algorithm_identifier,
275 private_key: &wrapped_key,
276 public_key: None,
277 };
278
279 private_key_info.to_der().map_err(|e| {
280 CryptoError::InvalidSecretKey(format!("DER encoding failed: {}", e))
281 })
282 }
283}
284
285impl Default for KeyPair {
286 fn default() -> Self {
287 Self::Ed25519(Ed25519Signer::default())
288 }
289}
290
291impl fmt::Debug for KeyPair {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 use crate::common::base64_encoding;
294 f.debug_struct("KeyPair")
295 .field("type", &self.key_type())
296 .field("algorithm", &self.algorithm())
297 .field(
298 "public_key",
299 &base64_encoding::encode(&self.public_key_bytes()),
300 )
301 .finish_non_exhaustive()
302 }
303}
304
305impl fmt::Display for KeyPair {
306 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307 write!(f, "{:?} KeyPair", self.key_type())
308 }
309}
310
311impl DSA for KeyPair {
313 #[inline]
314 fn algorithm_id(&self) -> u8 {
315 Self::algorithm_id(self)
316 }
317
318 #[inline]
319 fn signature_length(&self) -> usize {
320 match self {
321 Self::Ed25519(signer) => signer.signature_length(),
322 }
323 }
324
325 #[inline]
326 fn sign(&self, message: &[u8]) -> Result<SignatureIdentifier, CryptoError> {
327 Self::sign(self, message)
328 }
329
330 #[inline]
331 fn algorithm(&self) -> DSAlgorithm {
332 Self::algorithm(self)
333 }
334
335 #[inline]
336 fn public_key_bytes(&self) -> Vec<u8> {
337 Self::public_key_bytes(self)
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a fresh Ed25519 key pair, panicking if generation fails.
    fn ed25519() -> KeyPair {
        KeyPair::generate(KeyPairAlgorithm::Ed25519).unwrap()
    }

    #[test]
    fn test_keypair_generate() {
        let kp = ed25519();

        assert_eq!(kp.algorithm(), DSAlgorithm::Ed25519);
        assert_eq!(kp.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(kp.public_key_bytes().len(), 32);
    }

    #[test]
    fn test_keypair_sign_verify() {
        let kp = ed25519();
        let verifier = kp.public_key();
        let msg = b"Test message";

        let sig = kp.sign(msg).unwrap();

        assert!(verifier.verify(msg, &sig).is_ok());
        assert!(verifier.verify(b"Wrong message", &sig).is_err());
    }

    #[test]
    fn test_keypair_from_seed() {
        let seed = [42u8; 32];
        let from_seed =
            || KeyPair::from_seed(KeyPairAlgorithm::Ed25519, &seed).unwrap();

        // The same seed must always produce the same key.
        assert_eq!(
            from_seed().public_key_bytes(),
            from_seed().public_key_bytes()
        );
    }

    #[test]
    fn test_keypair_derive_from_data() {
        fn derived_public_key(data: &[u8]) -> Vec<u8> {
            KeyPair::derive_from_data(KeyPairAlgorithm::Ed25519, data)
                .unwrap()
                .public_key_bytes()
        }

        let passphrase = b"my passphrase";
        let first = derived_public_key(passphrase);
        let second = derived_public_key(passphrase);
        assert_eq!(first, second);

        let other = derived_public_key(b"different");
        assert_ne!(first, other);
    }

    #[test]
    fn test_keypair_serialization() {
        let original = ed25519();
        let msg = b"Test message";

        let encoded = original.to_bytes().unwrap();
        // The first byte is the algorithm identifier ('E' for Ed25519).
        assert_eq!(encoded[0], b'E');
        let restored = KeyPair::from_bytes(&encoded).unwrap();

        let verifier = original.public_key();
        for kp in [&original, &restored] {
            let sig = kp.sign(msg).unwrap();
            assert!(verifier.verify(msg, &sig).is_ok());
        }
    }

    #[test]
    fn test_keypair_dsa_trait() {
        let kp = ed25519();
        let msg = b"Test message";

        assert_eq!(DSA::algorithm(&kp), DSAlgorithm::Ed25519);
        assert_eq!(DSA::algorithm_id(&kp), b'E');

        let sig = DSA::sign(&kp, msg).unwrap();
        assert!(kp.public_key().verify(msg, &sig).is_ok());
    }

    #[test]
    fn test_keypair_public_key_wrapper() {
        let kp = ed25519();
        let wrapped = kp.public_key();
        let raw = kp.public_key_bytes();

        assert_eq!(wrapped.algorithm(), kp.algorithm());
        assert_eq!(wrapped.as_bytes(), raw.as_slice());
    }

    #[test]
    fn test_keypair_from_secret_key_autodetect() {
        let original = ed25519();
        let secret = original.secret_key_bytes().unwrap();

        let detected = KeyPair::from_secret_key(&secret).unwrap();

        assert_eq!(original.public_key_bytes(), detected.public_key_bytes());
    }

    #[test]
    fn test_keypair_type_conversion() {
        let kp_type = KeyPairAlgorithm::Ed25519;

        let algo = DSAlgorithm::from(kp_type);
        assert_eq!(algo, DSAlgorithm::Ed25519);

        assert_eq!(KeyPairAlgorithm::from(algo), kp_type);
    }

    #[test]
    fn test_keypair_algorithm_generate() {
        let kp = KeyPairAlgorithm::Ed25519.generate_keypair().unwrap();

        assert_eq!(kp.key_type(), KeyPairAlgorithm::Ed25519);
        assert_eq!(kp.algorithm(), DSAlgorithm::Ed25519);

        let msg = b"test";
        let sig = kp.sign(msg).unwrap();
        assert!(kp.public_key().verify(msg, &sig).is_ok());
    }

    #[test]
    fn test_keypair_algorithm_display() {
        assert_eq!(KeyPairAlgorithm::Ed25519.to_string(), "Ed25519");
    }

    #[test]
    fn test_default_keypair() {
        assert_eq!(KeyPair::default().key_type(), KeyPairAlgorithm::Ed25519);
    }

    #[test]
    fn test_keypair_clone() {
        let kp = ed25519();
        let duplicate = kp.clone();

        assert_eq!(kp.public_key_bytes(), duplicate.public_key_bytes());

        let msg = b"test message";
        let sig_a = kp.sign(msg).unwrap();
        let sig_b = duplicate.sign(msg).unwrap();

        // Ed25519 signing is deterministic, so identical keys sign identically.
        assert_eq!(sig_a, sig_b);

        let verifier = kp.public_key();
        assert!(verifier.verify(msg, &sig_a).is_ok());
        assert!(verifier.verify(msg, &sig_b).is_ok());
    }

    #[test]
    fn test_keypair_der_roundtrip() {
        let original = ed25519();
        let msg = b"Test message for DER roundtrip";

        let der = original.to_secret_der().unwrap();
        // A PKCS#8 document is a DER SEQUENCE (tag 0x30).
        assert_eq!(der[0], 0x30);

        let restored = KeyPair::from_secret_der(&der).unwrap();
        assert_eq!(original.public_key_bytes(), restored.public_key_bytes());

        let verifier = original.public_key();
        for kp in [&original, &restored] {
            let sig = kp.sign(msg).unwrap();
            assert!(verifier.verify(msg, &sig).is_ok());
        }
    }

    #[test]
    fn test_keypair_from_der_invalid() {
        let garbage: [u8; 3] = [0x00, 0x01, 0x02];

        let err = KeyPair::from_secret_der(&garbage).unwrap_err();

        assert!(matches!(err, CryptoError::InvalidDerFormat(_)));
    }

    #[test]
    fn test_keypair_from_der_unsupported_algorithm() {
        use pkcs8::{der::Encode, ObjectIdentifier, PrivateKeyInfo};

        // 1.3.132.0.10 is the secp256k1 OID, which KeyPair does not support.
        let secp256k1 = ObjectIdentifier::new_unwrap("1.3.132.0.10");

        // Well-formed inner OCTET STRING: tag 0x04, length 32, 32 zero bytes.
        let mut fake_key = vec![0x04, 0x20];
        fake_key.extend_from_slice(&[0u8; 32]);

        let info = PrivateKeyInfo {
            algorithm: pkcs8::AlgorithmIdentifierRef {
                oid: secp256k1,
                parameters: None,
            },
            private_key: &fake_key,
            public_key: None,
        };
        let der = info.to_der().unwrap();

        let err = KeyPair::from_secret_der(&der).unwrap_err();

        assert!(matches!(err, CryptoError::UnsupportedAlgorithm(_)));
    }
}