Skip to main content

astrid_crypto/
keypair.rs

//! Ed25519 key pairs with secure memory handling.
//!
//! Provides key generation, signing, and verification for:
//! - Runtime identity (signs audit entries, capability tokens)
//! - User identity verification (optional user signing keys)
6
7use ed25519_dalek::{Signer, SigningKey, VerifyingKey};
8use rand::rngs::OsRng;
9use serde::{Deserialize, Serialize};
10use zeroize::{Zeroize, ZeroizeOnDrop};
11
12use crate::error::{CryptoError, CryptoResult};
13use crate::signature::Signature;
14
15/// An Ed25519 key pair with secure memory handling.
16///
17/// The secret key is zeroized on drop to prevent leaking sensitive material.
18#[derive(ZeroizeOnDrop)]
19pub struct KeyPair {
20    #[zeroize(skip)] // VerifyingKey doesn't implement Zeroize
21    verifying_key: VerifyingKey,
22    signing_key: SigningKey,
23}
24
25impl KeyPair {
26    /// Generate a new random key pair.
27    #[must_use]
28    pub fn generate() -> Self {
29        let signing_key = SigningKey::generate(&mut OsRng);
30        let verifying_key = signing_key.verifying_key();
31        Self {
32            verifying_key,
33            signing_key,
34        }
35    }
36
37    /// Create from a secret key (32 bytes).
38    ///
39    /// # Errors
40    ///
41    /// Returns [`CryptoError::InvalidKeyLength`] if the slice is not exactly 32 bytes.
42    pub fn from_secret_key(bytes: &[u8]) -> CryptoResult<Self> {
43        if bytes.len() != 32 {
44            return Err(CryptoError::InvalidKeyLength {
45                expected: 32,
46                actual: bytes.len(),
47            });
48        }
49
50        let mut secret = [0u8; 32];
51        secret.copy_from_slice(bytes);
52
53        let signing_key = SigningKey::from_bytes(&secret);
54        let verifying_key = signing_key.verifying_key();
55
56        // Zeroize the temporary buffer
57        secret.zeroize();
58
59        Ok(Self {
60            verifying_key,
61            signing_key,
62        })
63    }
64
65    /// Get the public key bytes (32 bytes).
66    #[must_use]
67    pub fn public_key_bytes(&self) -> &[u8; 32] {
68        self.verifying_key.as_bytes()
69    }
70
71    /// Get a short key ID (first 8 bytes of public key).
72    ///
73    /// Useful for identifying keys in logs without exposing the full key.
74    #[must_use]
75    pub fn key_id(&self) -> [u8; 8] {
76        let mut id = [0u8; 8];
77        id.copy_from_slice(&self.public_key_bytes()[..8]);
78        id
79    }
80
81    /// Get the key ID as a hex string.
82    #[must_use]
83    pub fn key_id_hex(&self) -> String {
84        hex::encode(self.key_id())
85    }
86
87    /// Sign a message.
88    #[must_use]
89    pub fn sign(&self, message: &[u8]) -> Signature {
90        let sig = self.signing_key.sign(message);
91        Signature::from(sig)
92    }
93
94    /// Verify a signature (convenience method using our public key).
95    ///
96    /// # Errors
97    ///
98    /// Returns [`CryptoError::SignatureVerificationFailed`] if verification fails.
99    pub fn verify(&self, message: &[u8], signature: &Signature) -> CryptoResult<()> {
100        signature.verify(message, self.public_key_bytes())
101    }
102
103    /// Export the public key for serialization.
104    #[must_use]
105    pub fn export_public_key(&self) -> PublicKey {
106        PublicKey::from_bytes(*self.public_key_bytes())
107    }
108
109    /// Export the secret key bytes (careful - sensitive!).
110    ///
111    /// This should only be used for secure storage.
112    #[must_use]
113    pub fn secret_key_bytes(&self) -> [u8; 32] {
114        self.signing_key.to_bytes()
115    }
116}
117
118impl std::fmt::Debug for KeyPair {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("KeyPair")
121            .field("key_id", &self.key_id_hex())
122            .finish_non_exhaustive()
123    }
124}
125
/// The 32-byte public half of an Ed25519 key pair.
///
/// It contains no secret material, so it may be freely copied, compared,
/// hashed, shared, and serialized.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct PublicKey([u8; 32]);
129
130impl PublicKey {
131    /// Create from raw bytes.
132    #[must_use]
133    pub const fn from_bytes(bytes: [u8; 32]) -> Self {
134        Self(bytes)
135    }
136
137    /// Try to create from a slice.
138    ///
139    /// # Errors
140    ///
141    /// Returns [`CryptoError::InvalidKeyLength`] if the slice is not exactly 32 bytes.
142    pub fn try_from_slice(slice: &[u8]) -> CryptoResult<Self> {
143        if slice.len() != 32 {
144            return Err(CryptoError::InvalidKeyLength {
145                expected: 32,
146                actual: slice.len(),
147            });
148        }
149        let mut bytes = [0u8; 32];
150        bytes.copy_from_slice(slice);
151        Ok(Self(bytes))
152    }
153
154    /// Get the raw bytes.
155    #[must_use]
156    pub const fn as_bytes(&self) -> &[u8; 32] {
157        &self.0
158    }
159
160    /// Get a short key ID (first 8 bytes).
161    #[must_use]
162    pub fn key_id(&self) -> [u8; 8] {
163        let mut id = [0u8; 8];
164        id.copy_from_slice(&self.0[..8]);
165        id
166    }
167
168    /// Get the key ID as a hex string.
169    #[must_use]
170    pub fn key_id_hex(&self) -> String {
171        hex::encode(self.key_id())
172    }
173
174    /// Encode as hex string.
175    #[must_use]
176    pub fn to_hex(&self) -> String {
177        hex::encode(self.0)
178    }
179
180    /// Decode from hex string.
181    ///
182    /// # Errors
183    ///
184    /// Returns an error if the string is not valid hex or not 32 bytes.
185    pub fn from_hex(s: &str) -> CryptoResult<Self> {
186        let bytes = hex::decode(s).map_err(|_| CryptoError::InvalidHexEncoding)?;
187        Self::try_from_slice(&bytes)
188    }
189
190    /// Encode as base64 string.
191    #[must_use]
192    pub fn to_base64(&self) -> String {
193        use base64::Engine;
194        base64::engine::general_purpose::STANDARD.encode(self.0)
195    }
196
197    /// Decode from base64 string.
198    ///
199    /// # Errors
200    ///
201    /// Returns an error if the string is not valid base64 or not 32 bytes.
202    pub fn from_base64(s: &str) -> CryptoResult<Self> {
203        use base64::Engine;
204        let bytes = base64::engine::general_purpose::STANDARD
205            .decode(s)
206            .map_err(|_| CryptoError::InvalidBase64Encoding)?;
207        Self::try_from_slice(&bytes)
208    }
209
210    /// Verify a signature against this public key.
211    ///
212    /// # Errors
213    ///
214    /// Returns [`CryptoError::SignatureVerificationFailed`] if verification fails.
215    pub fn verify(&self, message: &[u8], signature: &Signature) -> CryptoResult<()> {
216        signature.verify(message, &self.0)
217    }
218}
219
220impl std::fmt::Debug for PublicKey {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        write!(f, "PublicKey({})", self.key_id_hex())
223    }
224}
225
226impl std::fmt::Display for PublicKey {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        write!(f, "{}", self.to_hex())
229    }
230}
231
232impl Serialize for PublicKey {
233    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
234    where
235        S: serde::Serializer,
236    {
237        serializer.serialize_str(&self.to_base64())
238    }
239}
240
241impl<'de> Deserialize<'de> for PublicKey {
242    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
243    where
244        D: serde::Deserializer<'de>,
245    {
246        let s = String::deserialize(deserializer)?;
247        Self::from_base64(&s).map_err(serde::de::Error::custom)
248    }
249}
250
251impl From<[u8; 32]> for PublicKey {
252    fn from(bytes: [u8; 32]) -> Self {
253        Self(bytes)
254    }
255}
256
257impl From<PublicKey> for [u8; 32] {
258    fn from(pk: PublicKey) -> Self {
259        pk.0
260    }
261}
262
263impl AsRef<[u8]> for PublicKey {
264    fn as_ref(&self) -> &[u8] {
265        &self.0
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keypair_generation() {
        let kp1 = KeyPair::generate();
        let kp2 = KeyPair::generate();

        // Different keypairs have different public keys
        assert_ne!(kp1.public_key_bytes(), kp2.public_key_bytes());
    }

    #[test]
    fn test_keypair_from_secret() {
        let original = KeyPair::generate();
        let secret = original.secret_key_bytes();

        let restored = KeyPair::from_secret_key(&secret).unwrap();

        assert_eq!(original.public_key_bytes(), restored.public_key_bytes());
    }

    #[test]
    fn test_sign_verify() {
        let keypair = KeyPair::generate();
        let message = b"hello world";

        let signature = keypair.sign(message);
        assert!(keypair.verify(message, &signature).is_ok());

        // Wrong message fails
        assert!(keypair.verify(b"wrong", &signature).is_err());
    }

    #[test]
    fn test_verify_rejects_foreign_key() {
        // A signature made by one key must not verify under another key.
        let signer = KeyPair::generate();
        let other = KeyPair::generate();
        let message = b"hello world";

        let signature = signer.sign(message);
        assert!(other.verify(message, &signature).is_err());
        assert!(other.export_public_key().verify(message, &signature).is_err());
    }

    #[test]
    fn test_key_id() {
        let keypair = KeyPair::generate();
        let key_id = keypair.key_id();

        // Key ID is first 8 bytes of public key
        assert_eq!(&key_id[..], &keypair.public_key_bytes()[..8]);

        // Hex encoding works
        let hex_id = keypair.key_id_hex();
        assert_eq!(hex_id.len(), 16); // 8 bytes = 16 hex chars
    }

    #[test]
    fn test_debug_does_not_leak_secret() {
        // The Debug impl must show only the key ID, never the secret bytes.
        let keypair = KeyPair::generate();
        let rendered = format!("{:?}", keypair);
        let secret_hex = hex::encode(keypair.secret_key_bytes());

        assert!(!rendered.contains(&secret_hex));
        assert!(rendered.contains(&keypair.key_id_hex()));
    }

    #[test]
    fn test_public_key_encoding() {
        let keypair = KeyPair::generate();
        let pk = keypair.export_public_key();

        // Hex roundtrip
        let hex = pk.to_hex();
        let decoded = PublicKey::from_hex(&hex).unwrap();
        assert_eq!(pk, decoded);

        // Base64 roundtrip
        let b64 = pk.to_base64();
        let decoded = PublicKey::from_base64(&b64).unwrap();
        assert_eq!(pk, decoded);
    }

    #[test]
    fn test_public_key_formatting_and_conversions() {
        let pk = KeyPair::generate().export_public_key();

        // Display is the full hex; Debug shows only the short key ID.
        assert_eq!(pk.to_string(), pk.to_hex());
        assert_eq!(format!("{:?}", pk), format!("PublicKey({})", pk.key_id_hex()));

        // Raw-array conversions round-trip, and AsRef exposes the same bytes.
        let raw: [u8; 32] = pk.into();
        assert_eq!(PublicKey::from(raw), pk);
        let view: &[u8] = pk.as_ref();
        assert_eq!(view, &raw[..]);
    }

    #[test]
    fn test_public_key_invalid_length() {
        assert!(matches!(
            PublicKey::try_from_slice(&[0u8; 31]),
            Err(CryptoError::InvalidKeyLength {
                expected: 32,
                actual: 31
            })
        ));
        assert!(matches!(
            PublicKey::try_from_slice(&[0u8; 33]),
            Err(CryptoError::InvalidKeyLength {
                expected: 32,
                actual: 33
            })
        ));
    }

    #[test]
    fn test_public_key_invalid_hex() {
        // Not hex at all.
        assert!(matches!(
            PublicKey::from_hex("zz"),
            Err(CryptoError::InvalidHexEncoding)
        ));
        // Valid hex, but only 2 bytes long.
        assert!(matches!(
            PublicKey::from_hex("abcd"),
            Err(CryptoError::InvalidKeyLength { .. })
        ));
    }

    #[test]
    fn test_public_key_invalid_base64() {
        // Characters outside the base64 alphabet.
        assert!(matches!(
            PublicKey::from_base64("not base64!"),
            Err(CryptoError::InvalidBase64Encoding)
        ));
        // Valid base64, but decodes to only 3 bytes.
        assert!(matches!(
            PublicKey::from_base64("AAAA"),
            Err(CryptoError::InvalidKeyLength { actual: 3, .. })
        ));
    }

    #[test]
    fn test_public_key_verify() {
        let keypair = KeyPair::generate();
        let pk = keypair.export_public_key();
        let message = b"test";

        let sig = keypair.sign(message);
        assert!(pk.verify(message, &sig).is_ok());
    }

    #[test]
    fn test_invalid_key_length() {
        let result = KeyPair::from_secret_key(&[0u8; 31]);
        assert!(matches!(result, Err(CryptoError::InvalidKeyLength { .. })));

        // Too long is rejected as well, not silently truncated.
        let result = KeyPair::from_secret_key(&[0u8; 33]);
        assert!(matches!(
            result,
            Err(CryptoError::InvalidKeyLength {
                expected: 32,
                actual: 33
            })
        ));
    }
}