Skip to main content

volt_client_grpc/
crypto.rs

1//! Cryptographic utilities for the Volt client.
2//!
3//! This module provides cryptographic operations including:
4//! - Key generation and management (Ed25519)
5//! - Signing and verification
6//! - Key exchange (X25519)
7//! - AES encryption/decryption (GCM and CBC modes)
8//! - HKDF key derivation
9//! - JWT token generation
10
11use crate::error::{Result, VoltError};
12use aes::cipher::{block_padding::Pkcs7, BlockDecryptMut, BlockEncryptMut, KeyIvInit};
13use aes_gcm::{
14    aead::{Aead, KeyInit},
15    Aes256Gcm, Nonce,
16};
17use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
18use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
19use hkdf::Hkdf;
20use rand::rngs::OsRng;
21use sha2::Sha256;
22use x25519_dalek::{EphemeralSecret, PublicKey as X25519PublicKey, StaticSecret};
23
/// Size in bytes of an AES-256 key (256 bits).
pub const AES_KEY_LENGTH: usize = 256 / 8;
/// Size in bytes of the AES-GCM nonce (96 bits).
pub const AES_NONCE_LENGTH: usize = 96 / 8;
/// Size in bytes of the AES-CBC initialization vector (one 128-bit AES block).
pub const AES_CBC_IV_LENGTH: usize = 128 / 8;

/// Salt handed to HKDF when deriving the relay encryption key.
///
/// The ASCII bytes of this string are used directly as the salt.
pub const RELAY_HKDF_SALT: &str = "a06e10d13fa4445a";
/// Info/context string handed to HKDF when deriving the relay encryption key.
pub const RELAY_HKDF_INFO: &str = "tdx-volt-encryption-key-derivation";
35
36// Type aliases for AES-CBC
37type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
38type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
39
40/// A signing key pair (Ed25519)
41#[derive(Clone)]
42pub struct SigningKeyPair {
43    signing_key: SigningKey,
44}
45
46impl SigningKeyPair {
47    /// Generate a new random signing key pair
48    pub fn generate() -> Self {
49        let signing_key = SigningKey::generate(&mut OsRng);
50        Self { signing_key }
51    }
52
53    /// Create from a private key in PEM format
54    pub fn from_pem(pem: &str) -> Result<Self> {
55        let pem_contents = pem
56            .lines()
57            .filter(|line| !line.starts_with("-----"))
58            .collect::<String>();
59
60        let key_bytes = BASE64
61            .decode(&pem_contents)
62            .map_err(|e| VoltError::key(format!("Invalid PEM encoding: {}", e)))?;
63
64        // Handle both raw 32-byte keys and PKCS#8 encoded keys
65        let secret_bytes = if key_bytes.len() == 32 {
66            key_bytes
67        } else if key_bytes.len() > 32 {
68            // Extract the last 32 bytes (typical for PKCS#8 where 16-byte header + 32-byte key)
69            key_bytes[key_bytes.len() - 32..].to_vec()
70        } else {
71            return Err(VoltError::key("Invalid key length"));
72        };
73
74        let secret_array: [u8; 32] = secret_bytes
75            .try_into()
76            .map_err(|_| VoltError::key("Invalid key length"))?;
77
78        let signing_key = SigningKey::from_bytes(&secret_array);
79        Ok(Self { signing_key })
80    }
81
82    /// Export private key to simple PEM format (just the raw key bytes)
83    pub fn private_key_pem(&self) -> String {
84        let encoded = BASE64.encode(self.signing_key.to_bytes());
85        format!(
86            "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----",
87            encoded
88        )
89    }
90
91    /// Export private key to PKCS#8 PEM format (for JWT signing with jsonwebtoken)
92    pub fn private_key_pkcs8_pem(&self) -> String {
93        // Ed25519 PKCS#8 structure:
94        // SEQUENCE {
95        //   INTEGER 0 (version)
96        //   SEQUENCE {
97        //     OBJECT IDENTIFIER 1.3.101.112 (Ed25519)
98        //   }
99        //   OCTET STRING containing OCTET STRING of private key
100        // }
101
102        // Ed25519 OID: 1.3.101.112
103        // DER prefix for PKCS#8 Ed25519 private key
104        let pkcs8_prefix: [u8; 16] = [
105            0x30, 0x2e, // SEQUENCE, length 46
106            0x02, 0x01, 0x00, // INTEGER 0 (version)
107            0x30, 0x05, // SEQUENCE, length 5
108            0x06, 0x03, 0x2b, 0x65, 0x70, // OID 1.3.101.112 (Ed25519)
109            0x04, 0x22, // OCTET STRING, length 34
110            0x04, 0x20, // OCTET STRING, length 32 (the actual key)
111        ];
112
113        let mut pkcs8_der = Vec::with_capacity(48);
114        pkcs8_der.extend_from_slice(&pkcs8_prefix);
115        pkcs8_der.extend_from_slice(self.signing_key.as_bytes());
116
117        let encoded = BASE64.encode(&pkcs8_der);
118        format!(
119            "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----",
120            encoded
121        )
122    }
123
124    /// Export public key to PEM format
125    pub fn public_key_pem(&self) -> String {
126        let encoded = BASE64.encode(self.signing_key.verifying_key().to_bytes());
127        format!(
128            "-----BEGIN PUBLIC KEY-----\n{}\n-----END PUBLIC KEY-----",
129            encoded
130        )
131    }
132
133    /// Get the verifying (public) key
134    pub fn verifying_key(&self) -> VerifyingKey {
135        self.signing_key.verifying_key()
136    }
137
138    /// Sign data and return signature as bytes
139    pub fn sign(&self, data: &[u8]) -> Vec<u8> {
140        let signature = self.signing_key.sign(data);
141        signature.to_bytes().to_vec()
142    }
143
144    /// Sign data and return signature as base64
145    pub fn sign_base64(&self, data: &[u8]) -> String {
146        BASE64.encode(self.sign(data))
147    }
148
149    /// Get raw private key bytes
150    pub fn secret_bytes(&self) -> &[u8; 32] {
151        self.signing_key.as_bytes()
152    }
153}
154
155/// Verify a signature
156pub fn verify_signature(
157    public_key: &VerifyingKey,
158    message: &[u8],
159    signature: &[u8],
160) -> Result<bool> {
161    let sig_array: [u8; 64] = signature
162        .try_into()
163        .map_err(|_| VoltError::crypto("Invalid signature length"))?;
164    let signature = Signature::from_bytes(&sig_array);
165
166    Ok(public_key.verify(message, &signature).is_ok())
167}
168
169/// Parse a public key from PEM format
170pub fn public_key_from_pem(pem: &str) -> Result<VerifyingKey> {
171    let pem_contents = pem
172        .lines()
173        .filter(|line| !line.starts_with("-----"))
174        .collect::<String>();
175
176    let key_bytes = BASE64
177        .decode(&pem_contents)
178        .map_err(|e| VoltError::key(format!("Invalid PEM encoding: {}", e)))?;
179
180    // Handle both raw 32-byte keys and encoded keys
181    let public_bytes = if key_bytes.len() == 32 {
182        key_bytes
183    } else if key_bytes.len() > 32 {
184        // Extract the last 32 bytes
185        key_bytes[key_bytes.len() - 32..].to_vec()
186    } else {
187        return Err(VoltError::key("Invalid public key length"));
188    };
189
190    let key_array: [u8; 32] = public_bytes
191        .try_into()
192        .map_err(|_| VoltError::key("Invalid key length"))?;
193
194    VerifyingKey::from_bytes(&key_array)
195        .map_err(|e| VoltError::key(format!("Invalid public key: {}", e)))
196}
197
198/// Calculate fingerprint from a public key
199pub fn fingerprint_from_key(public_key: &VerifyingKey) -> String {
200    use ring::digest::{digest, SHA256};
201    let hash = digest(&SHA256, public_key.as_bytes());
202    BASE64.encode(hash.as_ref())
203}
204
205/// X25519 key exchange
206pub struct KeyExchange {
207    secret: Option<EphemeralSecret>,
208    public_key: X25519PublicKey,
209}
210
211impl std::fmt::Debug for KeyExchange {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        f.debug_struct("KeyExchange")
214            .field("public_key", &self.public_key_bytes())
215            .finish()
216    }
217}
218
219impl Clone for KeyExchange {
220    fn clone(&self) -> Self {
221        // Generate a new key exchange since we can't clone the secret
222        Self::new()
223    }
224}
225
226impl KeyExchange {
227    /// Generate a new key exchange pair
228    pub fn new() -> Self {
229        let secret = EphemeralSecret::random_from_rng(OsRng);
230        let public_key = X25519PublicKey::from(&secret);
231        Self {
232            secret: Some(secret),
233            public_key,
234        }
235    }
236
237    /// Get our public key bytes
238    pub fn public_key_bytes(&self) -> [u8; 32] {
239        *self.public_key.as_bytes()
240    }
241
242    /// Get our public key as PEM
243    pub fn public_key_pem(&self) -> String {
244        let encoded = BASE64.encode(self.public_key.as_bytes());
245        format!(
246            "-----BEGIN PUBLIC KEY-----\n{}\n-----END PUBLIC KEY-----",
247            encoded
248        )
249    }
250
251    /// Derive shared secret from peer's public key
252    /// Note: This consumes the secret, so it can only be called once
253    pub fn derive_shared_key(&mut self, peer_public_key: &[u8; 32]) -> Result<[u8; 32]> {
254        let secret = self
255            .secret
256            .take()
257            .ok_or_else(|| VoltError::crypto("Key exchange already consumed"))?;
258        let peer_key = X25519PublicKey::from(*peer_public_key);
259        let shared_secret = secret.diffie_hellman(&peer_key);
260        Ok(*shared_secret.as_bytes())
261    }
262}
263
264impl Default for KeyExchange {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270/// AES-256-GCM encryption key
271pub struct AesKey {
272    key: [u8; AES_KEY_LENGTH],
273    iv: [u8; AES_NONCE_LENGTH],
274}
275
276impl AesKey {
277    /// Create a new random AES key
278    pub fn generate() -> Self {
279        let mut key = [0u8; AES_KEY_LENGTH];
280        let mut iv = [0u8; AES_NONCE_LENGTH];
281
282        use rand::RngCore;
283        OsRng.fill_bytes(&mut key);
284        OsRng.fill_bytes(&mut iv);
285
286        Self { key, iv }
287    }
288
289    /// Create from existing key and IV
290    pub fn from_bytes(key: [u8; AES_KEY_LENGTH], iv: [u8; AES_NONCE_LENGTH]) -> Self {
291        Self { key, iv }
292    }
293
294    /// Get the IV/nonce
295    pub fn iv(&self) -> &[u8; AES_NONCE_LENGTH] {
296        &self.iv
297    }
298
299    /// Get the key bytes
300    pub fn key(&self) -> &[u8; AES_KEY_LENGTH] {
301        &self.key
302    }
303}
304
305/// Encrypt data using AES-256-GCM
306pub fn aes_encrypt(
307    key: &[u8; AES_KEY_LENGTH],
308    iv: &[u8; AES_NONCE_LENGTH],
309    plaintext: &[u8],
310) -> Result<Vec<u8>> {
311    let cipher =
312        Aes256Gcm::new_from_slice(key).map_err(|e| VoltError::EncryptionError(e.to_string()))?;
313
314    let nonce = Nonce::from_slice(iv);
315
316    cipher
317        .encrypt(nonce, plaintext)
318        .map_err(|e| VoltError::EncryptionError(e.to_string()))
319}
320
321/// Decrypt data using AES-256-GCM
322pub fn aes_decrypt(
323    key: &[u8; AES_KEY_LENGTH],
324    iv: &[u8; AES_NONCE_LENGTH],
325    ciphertext: &[u8],
326) -> Result<Vec<u8>> {
327    let cipher =
328        Aes256Gcm::new_from_slice(key).map_err(|e| VoltError::DecryptionError(e.to_string()))?;
329
330    let nonce = Nonce::from_slice(iv);
331
332    cipher
333        .decrypt(nonce, ciphertext)
334        .map_err(|e| VoltError::DecryptionError(e.to_string()))
335}
336
337/// Generate random bytes
338pub fn random_bytes(len: usize) -> Vec<u8> {
339    let mut bytes = vec![0u8; len];
340    use rand::RngCore;
341    OsRng.fill_bytes(&mut bytes);
342    bytes
343}
344
345/// Convert bytes to base64
346pub fn to_base64(data: &[u8]) -> String {
347    BASE64.encode(data)
348}
349
350/// Convert base64 to bytes
351pub fn from_base64(data: &str) -> Result<Vec<u8>> {
352    BASE64.decode(data).map_err(VoltError::from)
353}
354
/// Return the base64 body of a PEM document with its boundary lines removed.
///
/// Each line is trimmed of surrounding whitespace before being inspected. This
/// handles indented PEM blocks (e.g. embedded in configuration files), whose
/// `-----BEGIN/END …-----` lines would otherwise slip through. The remaining
/// lines are concatenated without separators. `str::lines` already drops the
/// `\n` / `\r\n` terminators.
pub fn strip_pem_headers(pem: &str) -> String {
    pem.lines()
        .map(str::trim)
        .filter(|line| !line.starts_with("-----"))
        .collect()
}
361
362/// Format data as PEM with given label
363pub fn format_pem(data: &[u8], label: &str) -> String {
364    let encoded = BASE64.encode(data);
365    format!(
366        "-----BEGIN {}-----\n{}\n-----END {}-----",
367        label, encoded, label
368    )
369}
370
371// ==============================================================================
372// AES-CBC encryption/decryption (for relay protocol)
373// ==============================================================================
374
375/// Encrypt data using AES-256-CBC with PKCS7 padding
376///
377/// This is used by the relay protocol for encrypted communication.
378pub fn aes_cbc_encrypt(
379    key: &[u8; AES_KEY_LENGTH],
380    iv: &[u8; AES_CBC_IV_LENGTH],
381    plaintext: &[u8],
382) -> Result<Vec<u8>> {
383    // Calculate output size: plaintext + padding (up to 16 bytes)
384    let block_size = 16;
385    let padded_len = ((plaintext.len() / block_size) + 1) * block_size;
386    let mut buffer = vec![0u8; padded_len];
387    buffer[..plaintext.len()].copy_from_slice(plaintext);
388
389    let cipher = Aes256CbcEnc::new_from_slices(key, iv)
390        .map_err(|e| VoltError::EncryptionError(format!("Invalid key/IV: {}", e)))?;
391
392    let ciphertext = cipher
393        .encrypt_padded_mut::<Pkcs7>(&mut buffer, plaintext.len())
394        .map_err(|e| VoltError::EncryptionError(format!("Encryption failed: {:?}", e)))?;
395
396    Ok(ciphertext.to_vec())
397}
398
399/// Decrypt data using AES-256-CBC with PKCS7 padding
400///
401/// This is used by the relay protocol for encrypted communication.
402pub fn aes_cbc_decrypt(
403    key: &[u8; AES_KEY_LENGTH],
404    iv: &[u8; AES_CBC_IV_LENGTH],
405    ciphertext: &[u8],
406) -> Result<Vec<u8>> {
407    let mut buffer = ciphertext.to_vec();
408
409    let cipher = Aes256CbcDec::new_from_slices(key, iv)
410        .map_err(|e| VoltError::DecryptionError(format!("Invalid key/IV: {}", e)))?;
411
412    let plaintext = cipher
413        .decrypt_padded_mut::<Pkcs7>(&mut buffer)
414        .map_err(|e| VoltError::DecryptionError(format!("Decryption failed: {:?}", e)))?;
415
416    Ok(plaintext.to_vec())
417}
418
419/// Generate a random 16-byte IV for AES-CBC
420pub fn random_iv() -> [u8; AES_CBC_IV_LENGTH] {
421    let mut iv = [0u8; AES_CBC_IV_LENGTH];
422    use rand::RngCore;
423    OsRng.fill_bytes(&mut iv);
424    iv
425}
426
427// ==============================================================================
428// HKDF key derivation (for relay protocol)
429// ==============================================================================
430
431/// Derive a shared encryption key using HKDF-SHA256
432///
433/// This combines the X25519 shared secret with HKDF to derive the final AES key.
434/// Uses the relay-specific salt and info strings.
435pub fn derive_relay_key(shared_secret: &[u8; 32]) -> Result<[u8; AES_KEY_LENGTH]> {
436    let hk = Hkdf::<Sha256>::new(Some(RELAY_HKDF_SALT.as_bytes()), shared_secret);
437
438    let mut derived_key = [0u8; AES_KEY_LENGTH];
439    hk.expand(RELAY_HKDF_INFO.as_bytes(), &mut derived_key)
440        .map_err(|e| VoltError::crypto(format!("HKDF expansion failed: {}", e)))?;
441
442    Ok(derived_key)
443}
444
445// ==============================================================================
446// X25519 key exchange with static secret support (for relay protocol)
447// ==============================================================================
448
449/// X25519 key exchange with a static (reusable) secret
450///
451/// Unlike `KeyExchange` which uses ephemeral secrets that can only be used once,
452/// this struct uses a static secret that can be reused for multiple key derivations.
453pub struct StaticKeyExchange {
454    secret: StaticSecret,
455    public_key: X25519PublicKey,
456}
457
458impl std::fmt::Debug for StaticKeyExchange {
459    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460        f.debug_struct("StaticKeyExchange")
461            .field("public_key", &self.public_key_bytes())
462            .finish()
463    }
464}
465
466impl StaticKeyExchange {
467    /// Generate a new key exchange pair
468    pub fn new() -> Self {
469        let secret = StaticSecret::random_from_rng(OsRng);
470        let public_key = X25519PublicKey::from(&secret);
471        Self { secret, public_key }
472    }
473
474    /// Create from existing secret bytes
475    pub fn from_secret(secret_bytes: [u8; 32]) -> Self {
476        let secret = StaticSecret::from(secret_bytes);
477        let public_key = X25519PublicKey::from(&secret);
478        Self { secret, public_key }
479    }
480
481    /// Get our public key bytes
482    pub fn public_key_bytes(&self) -> [u8; 32] {
483        *self.public_key.as_bytes()
484    }
485
486    /// Get our public key as base64
487    pub fn public_key_base64(&self) -> String {
488        BASE64.encode(self.public_key.as_bytes())
489    }
490
491    /// Derive shared secret from peer's public key
492    /// This can be called multiple times (unlike ephemeral key exchange)
493    pub fn derive_shared_secret(&self, peer_public_key: &[u8; 32]) -> [u8; 32] {
494        let peer_key = X25519PublicKey::from(*peer_public_key);
495        let shared_secret = self.secret.diffie_hellman(&peer_key);
496        *shared_secret.as_bytes()
497    }
498
499    /// Derive a relay-compatible encryption key from peer's public key
500    /// This performs X25519 ECDH followed by HKDF key derivation
501    pub fn derive_relay_encryption_key(
502        &self,
503        peer_public_key: &[u8; 32],
504    ) -> Result<[u8; AES_KEY_LENGTH]> {
505        let shared_secret = self.derive_shared_secret(peer_public_key);
506        derive_relay_key(&shared_secret)
507    }
508}
509
510impl Default for StaticKeyExchange {
511    fn default() -> Self {
512        Self::new()
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_signing_key_generation() {
        let key = SigningKeyPair::generate();
        let private_pem = key.private_key_pem();
        let public_pem = key.public_key_pem();

        assert!(private_pem.contains("BEGIN PRIVATE KEY"));
        assert!(public_pem.contains("BEGIN PUBLIC KEY"));
    }

    #[test]
    fn test_private_key_pem_roundtrip() {
        let key = SigningKeyPair::generate();
        let restored = SigningKeyPair::from_pem(&key.private_key_pem()).unwrap();
        assert_eq!(restored.secret_bytes(), key.secret_bytes());
    }

    #[test]
    fn test_private_key_pkcs8_pem_roundtrip() {
        let key = SigningKeyPair::generate();
        let restored = SigningKeyPair::from_pem(&key.private_key_pkcs8_pem()).unwrap();
        assert_eq!(restored.secret_bytes(), key.secret_bytes());
    }

    #[test]
    fn test_public_key_pem_roundtrip() {
        let key = SigningKeyPair::generate();
        let parsed = public_key_from_pem(&key.public_key_pem()).unwrap();
        assert_eq!(parsed.to_bytes(), key.verifying_key().to_bytes());
    }

    #[test]
    fn test_sign_and_verify() {
        let key = SigningKeyPair::generate();
        let message = b"Hello, World!";

        let signature = key.sign(message);
        let verified = verify_signature(&key.verifying_key(), message, &signature).unwrap();

        assert!(verified);
    }

    #[test]
    fn test_verify_rejects_tampered_message() {
        let key = SigningKeyPair::generate();
        let signature = key.sign(b"original");
        let verified = verify_signature(&key.verifying_key(), b"tampered", &signature).unwrap();
        assert!(!verified);
    }

    #[test]
    fn test_verify_rejects_bad_signature_length() {
        let key = SigningKeyPair::generate();
        assert!(verify_signature(&key.verifying_key(), b"msg", &[0u8; 10]).is_err());
    }

    #[test]
    fn test_fingerprint_is_deterministic() {
        let key = SigningKeyPair::generate();
        let verifying_key = key.verifying_key();
        let fingerprint = fingerprint_from_key(&verifying_key);

        assert_eq!(fingerprint, fingerprint_from_key(&verifying_key));
        // Padded base64 of a 32-byte SHA-256 digest is 44 characters.
        assert_eq!(fingerprint.len(), 44);
    }

    #[test]
    fn test_aes_encrypt_decrypt() {
        let aes_key = AesKey::generate();
        let plaintext = b"Secret message!";

        let ciphertext = aes_encrypt(aes_key.key(), aes_key.iv(), plaintext).unwrap();
        let decrypted = aes_decrypt(aes_key.key(), aes_key.iv(), &ciphertext).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_aes_decrypt_rejects_tampered_ciphertext() {
        let aes_key = AesKey::generate();
        let mut ciphertext =
            aes_encrypt(aes_key.key(), aes_key.iv(), b"Secret message!").unwrap();
        ciphertext[0] ^= 0x01;

        assert!(aes_decrypt(aes_key.key(), aes_key.iv(), &ciphertext).is_err());
    }

    #[test]
    fn test_key_exchange() {
        let mut alice = KeyExchange::new();
        let mut bob = KeyExchange::new();

        let alice_public = alice.public_key_bytes();
        let bob_public = bob.public_key_bytes();

        let alice_shared = alice.derive_shared_key(&bob_public).unwrap();
        let bob_shared = bob.derive_shared_key(&alice_public).unwrap();

        assert_eq!(alice_shared, bob_shared);
    }

    #[test]
    fn test_key_exchange_is_single_use() {
        let mut alice = KeyExchange::new();
        let bob = KeyExchange::new();
        let bob_public = bob.public_key_bytes();

        assert!(alice.derive_shared_key(&bob_public).is_ok());
        assert!(alice.derive_shared_key(&bob_public).is_err());
    }

    #[test]
    fn test_aes_cbc_encrypt_decrypt() {
        let key = [0u8; AES_KEY_LENGTH];
        let iv = [0u8; AES_CBC_IV_LENGTH];
        let plaintext = b"Secret message for relay!";

        let ciphertext = aes_cbc_encrypt(&key, &iv, plaintext).unwrap();
        let decrypted = aes_cbc_decrypt(&key, &iv, &ciphertext).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_aes_cbc_ciphertext_is_block_padded() {
        let key = [7u8; AES_KEY_LENGTH];
        let iv = [9u8; AES_CBC_IV_LENGTH];

        // PKCS#7 always adds between 1 and 16 bytes.
        assert_eq!(aes_cbc_encrypt(&key, &iv, b"").unwrap().len(), 16);
        assert_eq!(aes_cbc_encrypt(&key, &iv, &[0u8; 15]).unwrap().len(), 16);
        assert_eq!(aes_cbc_encrypt(&key, &iv, &[0u8; 16]).unwrap().len(), 32);

        let empty = aes_cbc_encrypt(&key, &iv, b"").unwrap();
        assert!(aes_cbc_decrypt(&key, &iv, &empty).unwrap().is_empty());
    }

    #[test]
    fn test_aes_cbc_with_random_iv() {
        let key = random_bytes(AES_KEY_LENGTH).try_into().unwrap();
        let iv = random_iv();
        let plaintext = b"Another secret message with random IV!";

        let ciphertext = aes_cbc_encrypt(&key, &iv, plaintext).unwrap();
        let decrypted = aes_cbc_decrypt(&key, &iv, &ciphertext).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_format_pem_roundtrip() {
        let data: Vec<u8> = (0..100u8).collect();
        let pem = format_pem(&data, "TEST DATA");

        assert!(pem.starts_with("-----BEGIN TEST DATA-----\n"));
        assert!(pem.ends_with("\n-----END TEST DATA-----"));
        assert_eq!(from_base64(&strip_pem_headers(&pem)).unwrap(), data);
    }

    #[test]
    fn test_static_key_exchange() {
        let alice = StaticKeyExchange::new();
        let bob = StaticKeyExchange::new();

        let alice_public = alice.public_key_bytes();
        let bob_public = bob.public_key_bytes();

        // Static key exchange can derive multiple times
        let alice_shared1 = alice.derive_shared_secret(&bob_public);
        let alice_shared2 = alice.derive_shared_secret(&bob_public);
        let bob_shared = bob.derive_shared_secret(&alice_public);

        assert_eq!(alice_shared1, alice_shared2);
        assert_eq!(alice_shared1, bob_shared);
    }

    #[test]
    fn test_derive_relay_key() {
        let alice = StaticKeyExchange::new();
        let bob = StaticKeyExchange::new();

        let alice_public = alice.public_key_bytes();
        let bob_public = bob.public_key_bytes();

        let alice_key = alice.derive_relay_encryption_key(&bob_public).unwrap();
        let bob_key = bob.derive_relay_encryption_key(&alice_public).unwrap();

        assert_eq!(alice_key, bob_key);
        assert_eq!(alice_key.len(), AES_KEY_LENGTH);
    }

    #[test]
    fn test_full_relay_encryption_flow() {
        // Simulate the full relay encryption flow
        let client = StaticKeyExchange::new();
        let server = StaticKeyExchange::new();

        // Exchange public keys
        let client_public = client.public_key_bytes();
        let server_public = server.public_key_bytes();

        // Derive encryption keys
        let client_key = client.derive_relay_encryption_key(&server_public).unwrap();
        let server_key = server.derive_relay_encryption_key(&client_public).unwrap();

        // Keys should match
        assert_eq!(client_key, server_key);

        // Encrypt a message from client to server
        let iv = random_iv();
        let message = b"Hello from client!";
        let encrypted = aes_cbc_encrypt(&client_key, &iv, message).unwrap();

        // Server decrypts
        let decrypted = aes_cbc_decrypt(&server_key, &iv, &encrypted).unwrap();
        assert_eq!(message.as_slice(), decrypted.as_slice());
    }
}