Skip to main content

alterion_ecdh/
keystore.rs

1// SPDX-License-Identifier: GPL-3.0
2use std::sync::Arc;
3use rand_core::OsRng;
4use tokio::sync::RwLock;
5use x25519_dalek::{StaticSecret, PublicKey};
6use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
7use chrono::{DateTime, Duration, Utc};
8use uuid::Uuid;
9use zeroize::Zeroizing;
10use dashmap::DashMap;
11
12/// A single X25519 key pair entry managed by [`KeyStore`].
13///
14/// Derives [`zeroize::ZeroizeOnDrop`] so that `secret` is guaranteed to be zeroed whenever
15/// a `KeyEntry` is dropped — including when it is moved or replaced during key rotation —
16/// regardless of whether the compiler relocated the struct in memory beforehand.
17#[derive(zeroize::ZeroizeOnDrop)]
18pub struct KeyEntry {
19    pub secret:         StaticSecret,
20    #[zeroize(skip)]
21    pub key_id:         String,
22    #[zeroize(skip)]
23    pub public_key_b64: String,
24    #[zeroize(skip)]
25    pub public_key_raw: [u8; 32],
26    #[zeroize(skip)]
27    pub created_at:     DateTime<Utc>,
28    #[zeroize(skip)]
29    pub expires_at:     DateTime<Utc>,
30}
31
32pub struct KeyStore {
33    pub current:  KeyEntry,
34    pub previous: Option<KeyEntry>,
35    /// Pre-warmed entry generated `KEY_WARM_LEAD_SECS` before the next rotation.
36    /// The rotation tick consumes this instead of generating a new key on the hot path.
37    pub next:     Option<KeyEntry>,
38}
39
40/// A single-use ephemeral server key pair created by [`init_handshake`].
41/// Consumed and removed by the first [`ecdh`] call that references its ID.
42struct HandshakeEntry {
43    secret:     StaticSecret,
44    public_raw: [u8; 32],
45    expires_at: DateTime<Utc>,
46}
47
48/// Thread-safe store of pending ephemeral handshake entries, keyed by handshake ID.
49///
50/// Separate from [`KeyStore`] so that handshake writes (use-once removals) never contend
51/// with the read-heavy static key lock.
52#[derive(Clone)]
53pub struct HandshakeStore(Arc<DashMap<String, HandshakeEntry>>);
54
55#[derive(Debug, thiserror::Error)]
56pub enum EcdhError {
57    #[error("key_expired")]
58    KeyExpired,
59    #[error("invalid client public key")]
60    InvalidClientKey,
61    #[error("key generation failed: {0}")]
62    KeyGenerationFailed(String),
63}
64
/// Grace period, in seconds, added on top of the rotation interval when computing a static
/// key's `expires_at`. This is how long a retired key is meant to remain usable.
const KEY_GRACE_SECS: u64 = 300;

/// How many seconds before each rotation the warm-up task pre-generates the next key.
/// `start_rotation` clamps this to `interval_secs - 1`.
const KEY_WARM_LEAD_SECS: u64 = 600;

/// Lifetime, in seconds, of an ephemeral handshake entry that has not been consumed.
const HANDSHAKE_TTL_SECS: i64 = 60;

/// Upper bound on pending handshake entries, protecting against memory exhaustion from a
/// handshake flood. When it is reached, [`init_handshake`] returns `None`; the ECDH endpoint
/// should then answer with 503 Service Unavailable.
pub const MAX_HANDSHAKES: usize = 10_000;
72
73fn generate_entry(interval_secs: u64) -> KeyEntry {
74    let secret     = StaticSecret::random_from_rng(OsRng);
75    let public_key = PublicKey::from(&secret);
76    let raw        = *public_key.as_bytes();
77    let now        = Utc::now();
78    let secs       = i64::try_from(interval_secs + KEY_GRACE_SECS)
79        .expect("interval overflow");
80    KeyEntry {
81        key_id:         Uuid::new_v4().to_string(),
82        public_key_b64: B64.encode(raw),
83        public_key_raw: raw,
84        secret,
85        created_at:     now,
86        expires_at:     now + Duration::seconds(secs),
87    }
88}
89
90/// Generates an initial X25519 key pair and wraps it in a shared, RwLock-guarded `KeyStore`.
91pub fn init_key_store(interval_secs: u64) -> Arc<RwLock<KeyStore>> {
92    Arc::new(RwLock::new(KeyStore {
93        current:  generate_entry(interval_secs),
94        previous: None,
95        next:     None,
96    }))
97}
98
99/// Creates an empty `HandshakeStore`. Call once at startup and share the handle across all workers.
100pub fn init_handshake_store() -> HandshakeStore {
101    HandshakeStore(Arc::new(DashMap::new()))
102}
103
104/// Generates a fresh ephemeral X25519 key pair, stores the private key in `hs` with a 60-second
105/// TTL, and returns `Some((handshake_id, base64_public_key))`.
106///
107/// Returns `None` when the store is at capacity ([`MAX_HANDSHAKES`]); callers should respond with
108/// 503 Service Unavailable. The private key is consumed on the first matching [`ecdh_ephemeral`]
109/// call. Any entry not consumed within `HANDSHAKE_TTL_SECS` is pruned by [`prune_handshakes`].
110pub fn init_handshake(hs: &HandshakeStore) -> Option<(String, String)> {
111    if hs.0.len() >= MAX_HANDSHAKES {
112        tracing::warn!("alterion-ecdh: handshake store at capacity — rejecting new handshake");
113        return None;
114    }
115    let secret     = StaticSecret::random_from_rng(OsRng);
116    let public_key = PublicKey::from(&secret);
117    let raw        = *public_key.as_bytes();
118    let id         = format!("hs_{}", Uuid::new_v4());
119    hs.0.insert(id.clone(), HandshakeEntry {
120        secret,
121        public_raw: raw,
122        expires_at: Utc::now() + Duration::seconds(HANDSHAKE_TTL_SECS),
123    });
124    Some((id, B64.encode(raw)))
125}
126
127/// Performs a use-once X25519 ECDH using a handshake entry created by [`init_handshake`].
128/// Removes the entry on success — replaying the same handshake ID returns `EcdhError::KeyExpired`.
129pub async fn ecdh_ephemeral(
130    hs:              &HandshakeStore,
131    handshake_id:    &str,
132    client_pk_bytes: &[u8; 32],
133) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
134    let entry = hs.0.remove(handshake_id)
135        .ok_or(EcdhError::KeyExpired)?;
136    let (_, entry) = entry;
137    if Utc::now() > entry.expires_at {
138        return Err(EcdhError::KeyExpired);
139    }
140    let client_public = PublicKey::from(*client_pk_bytes);
141    let shared        = entry.secret.diffie_hellman(&client_public);
142    Ok((Zeroizing::new(*shared.as_bytes()), entry.public_raw))
143}
144
145/// Removes all expired handshake entries from `hs`. Call periodically (e.g. every 30 seconds).
146pub fn prune_handshakes(hs: &HandshakeStore) {
147    let now = Utc::now();
148    hs.0.retain(|_, v| v.expires_at > now);
149}
150
151/// Spawns two background tasks:
152/// - **Warm-up**: generates the next key `KEY_WARM_LEAD_SECS` before each rotation and stores it
153///   in `KeyStore::next` so key generation never blocks the hot path.
154/// - **Rotation**: swaps `next` (or falls back to a fresh key) into `current` and retires the old
155///   key to `previous` for the grace-window period.
156/// - **Cleanup**: prunes `previous` once its grace window expires, and prunes expired handshake
157///   entries from `hs` every 30 seconds.
158pub fn start_rotation(store: Arc<RwLock<KeyStore>>, interval_secs: u64, hs: HandshakeStore) {
159    let warm_lead = KEY_WARM_LEAD_SECS.min(interval_secs.saturating_sub(1));
160    let warm_offset = interval_secs.saturating_sub(warm_lead);
161
162    let store_warm    = store.clone();
163    let store_rotate  = store.clone();
164
165    tokio::spawn(async move {
166        let mut warm_interval = tokio::time::interval_at(
167            tokio::time::Instant::now() + tokio::time::Duration::from_secs(warm_offset),
168            tokio::time::Duration::from_secs(interval_secs),
169        );
170        loop {
171            warm_interval.tick().await;
172            let next = tokio::task::spawn_blocking(move || generate_entry(interval_secs))
173                .await
174                .expect("key generation panicked");
175            store_warm.write().await.next = Some(next);
176            tracing::debug!("next X25519 key pre-warmed");
177        }
178    });
179
180    tokio::spawn(async move {
181        let mut rotation_interval = tokio::time::interval_at(
182            tokio::time::Instant::now() + tokio::time::Duration::from_secs(interval_secs),
183            tokio::time::Duration::from_secs(interval_secs),
184        );
185        let mut cleanup_interval = tokio::time::interval(
186            tokio::time::Duration::from_secs(30),
187        );
188        loop {
189            tokio::select! {
190                _ = rotation_interval.tick() => {
191                    let mut w = store_rotate.write().await;
192                    let new_entry = w.next.take().unwrap_or_else(|| generate_entry(interval_secs));
193                    let old = std::mem::replace(&mut w.current, new_entry);
194                    w.previous = Some(old);
195                    tracing::info!("X25519 key rotated → {}", w.current.key_id);
196                }
197                _ = cleanup_interval.tick() => {
198                    let needs_cleanup = {
199                        let r = store_rotate.read().await;
200                        r.previous.as_ref().map_or(false, |p| Utc::now() > p.expires_at)
201                    };
202                    if needs_cleanup {
203                        store_rotate.write().await.previous = None;
204                        tracing::debug!("previous X25519 key pruned");
205                    }
206                    prune_handshakes(&hs);
207                }
208            }
209        }
210    });
211}
212
213/// Returns `(key_id, base64_public_key)` for the currently active key.
214pub async fn get_current_public_key(store: &Arc<RwLock<KeyStore>>) -> (String, String) {
215    let guard = store.read().await;
216    (guard.current.key_id.clone(), guard.current.public_key_b64.clone())
217}
218
219/// Performs X25519 ECDH using the server key identified by `key_id` and the client's
220/// ephemeral public key bytes. Returns `(shared_secret, server_public_key_bytes)`.
221///
222/// Falls back to the previous key within its grace window; returns `EcdhError::KeyExpired` otherwise.
223pub async fn ecdh(
224    store:           &Arc<RwLock<KeyStore>>,
225    key_id:          &str,
226    client_pk_bytes: &[u8; 32],
227) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
228    let guard = store.read().await;
229
230    let entry = if guard.current.key_id == key_id {
231        &guard.current
232    } else if let Some(prev) = &guard.previous {
233        if prev.key_id == key_id {
234            if Utc::now() > prev.expires_at {
235                return Err(EcdhError::KeyExpired);
236            }
237            prev
238        } else {
239            return Err(EcdhError::KeyExpired);
240        }
241    } else {
242        return Err(EcdhError::KeyExpired);
243    };
244
245    let client_public  = PublicKey::from(*client_pk_bytes);
246    let shared         = entry.secret.diffie_hellman(&client_public);
247    let server_pub_raw = entry.public_key_raw;
248
249    Ok((Zeroizing::new(*shared.as_bytes()), server_pub_raw))
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the client half of an X25519 exchange against `server_pub`. Returns the client's
    /// public key bytes and the shared secret the client derives.
    fn client_exchange(server_pub: [u8; 32]) -> ([u8; 32], [u8; 32]) {
        let client_secret = StaticSecret::random_from_rng(OsRng);
        let client_public = *PublicKey::from(&client_secret).as_bytes();
        let shared = *client_secret.diffie_hellman(&PublicKey::from(server_pub)).as_bytes();
        (client_public, shared)
    }

    /// Decodes a base64 public key as produced by this module into raw bytes.
    fn decode_pub(b64: &str) -> [u8; 32] {
        B64.decode(b64).unwrap().try_into().unwrap()
    }

    #[tokio::test]
    async fn init_produces_valid_keypair() {
        let store         = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        assert!(!key_id.is_empty());
        let bytes = B64.decode(&b64).unwrap();
        assert_eq!(bytes.len(), 32);
    }

    #[tokio::test]
    async fn ecdh_roundtrip() {
        let store = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        let (client_pub, client_shared) = client_exchange(decode_pub(&b64));

        let (server_shared, _) = ecdh(&store, &key_id, &client_pub).await.unwrap();
        assert_eq!(client_shared, *server_shared);
    }

    #[tokio::test]
    async fn unknown_key_id_returns_expired() {
        let store = init_key_store(3600);
        let fake_pk = [0u8; 32];
        let result = ecdh(&store, "nonexistent", &fake_pk).await;
        assert!(matches!(result, Err(EcdhError::KeyExpired)));
    }

    #[tokio::test]
    async fn previous_key_honoured_until_expiry() {
        let store = init_key_store(3600);
        let (old_id, old_b64) = get_current_public_key(&store).await;
        let old_pub = decode_pub(&old_b64);

        // Simulate a rotation: retire the current key to `previous`.
        {
            let mut w = store.write().await;
            let retired = std::mem::replace(&mut w.current, generate_entry(3600));
            w.previous = Some(retired);
        }

        let (client_pub, client_shared) = client_exchange(old_pub);
        let (server_shared, server_pub) = ecdh(&store, &old_id, &client_pub).await.unwrap();
        assert_eq!(client_shared, *server_shared);
        assert_eq!(server_pub, old_pub);

        // Once its expiry has passed, the retired key is refused.
        {
            let mut w = store.write().await;
            let prev = w.previous.as_mut().expect("previous key was just set");
            prev.expires_at = Utc::now() - Duration::seconds(1);
        }
        let result = ecdh(&store, &old_id, &client_pub).await;
        assert!(matches!(result, Err(EcdhError::KeyExpired)));
    }

    #[tokio::test]
    async fn handshake_roundtrip_is_single_use() {
        let hs = init_handshake_store();
        let (id, b64) = init_handshake(&hs).expect("empty store has capacity");
        assert!(id.starts_with("hs_"));
        let server_pub = decode_pub(&b64);

        let (client_pub, client_shared) = client_exchange(server_pub);
        let (server_shared, returned_pub) =
            ecdh_ephemeral(&hs, &id, &client_pub).await.unwrap();
        assert_eq!(client_shared, *server_shared);
        assert_eq!(returned_pub, server_pub);

        // The entry was consumed by the first call, so replaying the ID fails.
        let replay = ecdh_ephemeral(&hs, &id, &client_pub).await;
        assert!(matches!(replay, Err(EcdhError::KeyExpired)));
    }

    #[test]
    fn prune_drops_only_expired_handshakes() {
        let hs = init_handshake_store();
        let (live_id, _) = init_handshake(&hs).expect("empty store has capacity");

        let secret     = StaticSecret::random_from_rng(OsRng);
        let public_raw = *PublicKey::from(&secret).as_bytes();
        hs.0.insert(
            "hs_stale".to_string(),
            HandshakeEntry { secret, public_raw, expires_at: Utc::now() - Duration::seconds(1) },
        );

        prune_handshakes(&hs);
        assert!(hs.0.contains_key(&live_id));
        assert!(!hs.0.contains_key("hs_stale"));
    }
}
285        let fake_pk = [0u8; 32];
286        let result = ecdh(&store, "nonexistent", &fake_pk).await;
287        assert!(matches!(result, Err(EcdhError::KeyExpired)));
288    }
289}