Skip to main content

alterion_ecdh/
keystore.rs

1// SPDX-License-Identifier: GPL-3.0
2use std::sync::Arc;
3use rand_core::OsRng;
4use tokio::sync::RwLock;
5use x25519_dalek::{StaticSecret, PublicKey};
6use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
7use chrono::{DateTime, Duration, Utc};
8use uuid::Uuid;
9use zeroize::Zeroizing;
10use dashmap::DashMap;
11
12pub struct KeyEntry {
13    pub key_id:         String,
14    pub public_key_b64: String,
15    pub public_key_raw: [u8; 32],
16    pub secret:         StaticSecret,
17    pub created_at:     DateTime<Utc>,
18    pub expires_at:     DateTime<Utc>,
19}
20
21pub struct KeyStore {
22    pub current:  KeyEntry,
23    pub previous: Option<KeyEntry>,
24    /// Pre-warmed entry generated `KEY_WARM_LEAD_SECS` before the next rotation.
25    /// The rotation tick consumes this instead of generating a new key on the hot path.
26    pub next:     Option<KeyEntry>,
27}
28
29/// A single-use ephemeral server key pair created by [`init_handshake`].
30/// Consumed and removed by the first [`ecdh`] call that references its ID.
31struct HandshakeEntry {
32    secret:     StaticSecret,
33    public_raw: [u8; 32],
34    expires_at: DateTime<Utc>,
35}
36
37/// Thread-safe store of pending ephemeral handshake entries, keyed by handshake ID.
38///
39/// Separate from [`KeyStore`] so that handshake writes (use-once removals) never contend
40/// with the read-heavy static key lock.
41#[derive(Clone)]
42pub struct HandshakeStore(Arc<DashMap<String, HandshakeEntry>>);
43
44#[derive(Debug, thiserror::Error)]
45pub enum EcdhError {
46    #[error("key_expired")]
47    KeyExpired,
48    #[error("invalid client public key")]
49    InvalidClientKey,
50    #[error("key generation failed: {0}")]
51    KeyGenerationFailed(String),
52}
53
/// Seconds added on top of the rotation interval when computing a key's `expires_at`.
const KEY_GRACE_SECS: u64 = 300;
/// Upper bound, in seconds, on how far ahead of a rotation the next key is generated.
const KEY_WARM_LEAD_SECS: u64 = 600;
/// Lifetime, in seconds, of an unconsumed handshake entry.
const HANDSHAKE_TTL_SECS: i64 = 60;
57
58fn generate_entry(interval_secs: u64) -> KeyEntry {
59    let secret     = StaticSecret::random_from_rng(OsRng);
60    let public_key = PublicKey::from(&secret);
61    let raw        = *public_key.as_bytes();
62    let now        = Utc::now();
63    let secs       = i64::try_from(interval_secs + KEY_GRACE_SECS)
64        .expect("interval overflow");
65    KeyEntry {
66        key_id:         Uuid::new_v4().to_string(),
67        public_key_b64: B64.encode(raw),
68        public_key_raw: raw,
69        secret,
70        created_at:     now,
71        expires_at:     now + Duration::seconds(secs),
72    }
73}
74
75/// Generates an initial X25519 key pair and wraps it in a shared, RwLock-guarded `KeyStore`.
76pub fn init_key_store(interval_secs: u64) -> Arc<RwLock<KeyStore>> {
77    Arc::new(RwLock::new(KeyStore {
78        current:  generate_entry(interval_secs),
79        previous: None,
80        next:     None,
81    }))
82}
83
84/// Creates an empty `HandshakeStore`. Call once at startup and share the handle across all workers.
85pub fn init_handshake_store() -> HandshakeStore {
86    HandshakeStore(Arc::new(DashMap::new()))
87}
88
89/// Generates a fresh ephemeral X25519 key pair, stores the private key in `hs` with a 60-second
90/// TTL, and returns `(handshake_id, base64_public_key)`.
91///
92/// The private key is consumed and deleted on the first matching [`ecdh_ephemeral`] call.
93/// Any entry not consumed within `HANDSHAKE_TTL_SECS` is pruned by [`prune_handshakes`].
94pub fn init_handshake(hs: &HandshakeStore) -> (String, String) {
95    let secret     = StaticSecret::random_from_rng(OsRng);
96    let public_key = PublicKey::from(&secret);
97    let raw        = *public_key.as_bytes();
98    let id         = format!("hs_{}", Uuid::new_v4());
99    hs.0.insert(id.clone(), HandshakeEntry {
100        secret,
101        public_raw: raw,
102        expires_at: Utc::now() + Duration::seconds(HANDSHAKE_TTL_SECS),
103    });
104    (id, B64.encode(raw))
105}
106
107/// Performs a use-once X25519 ECDH using a handshake entry created by [`init_handshake`].
108/// Removes the entry on success — replaying the same handshake ID returns `EcdhError::KeyExpired`.
109pub async fn ecdh_ephemeral(
110    hs:              &HandshakeStore,
111    handshake_id:    &str,
112    client_pk_bytes: &[u8; 32],
113) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
114    let entry = hs.0.remove(handshake_id)
115        .ok_or(EcdhError::KeyExpired)?;
116    let (_, entry) = entry;
117    if Utc::now() > entry.expires_at {
118        return Err(EcdhError::KeyExpired);
119    }
120    let client_public = PublicKey::from(*client_pk_bytes);
121    let shared        = entry.secret.diffie_hellman(&client_public);
122    Ok((Zeroizing::new(*shared.as_bytes()), entry.public_raw))
123}
124
125/// Removes all expired handshake entries from `hs`. Call periodically (e.g. every 30 seconds).
126pub fn prune_handshakes(hs: &HandshakeStore) {
127    let now = Utc::now();
128    hs.0.retain(|_, v| v.expires_at > now);
129}
130
131/// Spawns two background tasks:
132/// - **Warm-up**: generates the next key `KEY_WARM_LEAD_SECS` before each rotation and stores it
133///   in `KeyStore::next` so key generation never blocks the hot path.
134/// - **Rotation**: swaps `next` (or falls back to a fresh key) into `current` and retires the old
135///   key to `previous` for the grace-window period.
136/// - **Cleanup**: prunes `previous` once its grace window expires, and prunes expired handshake
137///   entries from `hs` every 30 seconds.
138pub fn start_rotation(store: Arc<RwLock<KeyStore>>, interval_secs: u64, hs: HandshakeStore) {
139    let warm_lead = KEY_WARM_LEAD_SECS.min(interval_secs.saturating_sub(1));
140    let warm_offset = interval_secs.saturating_sub(warm_lead);
141
142    let store_warm    = store.clone();
143    let store_rotate  = store.clone();
144
145    tokio::spawn(async move {
146        let mut warm_interval = tokio::time::interval_at(
147            tokio::time::Instant::now() + tokio::time::Duration::from_secs(warm_offset),
148            tokio::time::Duration::from_secs(interval_secs),
149        );
150        loop {
151            warm_interval.tick().await;
152            let next = tokio::task::spawn_blocking(move || generate_entry(interval_secs))
153                .await
154                .expect("key generation panicked");
155            store_warm.write().await.next = Some(next);
156            tracing::debug!("next X25519 key pre-warmed");
157        }
158    });
159
160    tokio::spawn(async move {
161        let mut rotation_interval = tokio::time::interval_at(
162            tokio::time::Instant::now() + tokio::time::Duration::from_secs(interval_secs),
163            tokio::time::Duration::from_secs(interval_secs),
164        );
165        let mut cleanup_interval = tokio::time::interval(
166            tokio::time::Duration::from_secs(30),
167        );
168        loop {
169            tokio::select! {
170                _ = rotation_interval.tick() => {
171                    let mut w = store_rotate.write().await;
172                    let new_entry = w.next.take().unwrap_or_else(|| generate_entry(interval_secs));
173                    let old = std::mem::replace(&mut w.current, new_entry);
174                    w.previous = Some(old);
175                    tracing::info!("X25519 key rotated → {}", w.current.key_id);
176                }
177                _ = cleanup_interval.tick() => {
178                    let needs_cleanup = {
179                        let r = store_rotate.read().await;
180                        r.previous.as_ref().map_or(false, |p| Utc::now() > p.expires_at)
181                    };
182                    if needs_cleanup {
183                        store_rotate.write().await.previous = None;
184                        tracing::debug!("previous X25519 key pruned");
185                    }
186                    prune_handshakes(&hs);
187                }
188            }
189        }
190    });
191}
192
193/// Returns `(key_id, base64_public_key)` for the currently active key.
194pub async fn get_current_public_key(store: &Arc<RwLock<KeyStore>>) -> (String, String) {
195    let guard = store.read().await;
196    (guard.current.key_id.clone(), guard.current.public_key_b64.clone())
197}
198
199/// Performs X25519 ECDH using the server key identified by `key_id` and the client's
200/// ephemeral public key bytes. Returns `(shared_secret, server_public_key_bytes)`.
201///
202/// Falls back to the previous key within its grace window; returns `EcdhError::KeyExpired` otherwise.
203pub async fn ecdh(
204    store:           &Arc<RwLock<KeyStore>>,
205    key_id:          &str,
206    client_pk_bytes: &[u8; 32],
207) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
208    let guard = store.read().await;
209
210    let entry = if guard.current.key_id == key_id {
211        &guard.current
212    } else if let Some(prev) = &guard.previous {
213        if prev.key_id == key_id {
214            if Utc::now() > prev.expires_at {
215                return Err(EcdhError::KeyExpired);
216            }
217            prev
218        } else {
219            return Err(EcdhError::KeyExpired);
220        }
221    } else {
222        return Err(EcdhError::KeyExpired);
223    };
224
225    let client_public  = PublicKey::from(*client_pk_bytes);
226    let shared         = entry.secret.diffie_hellman(&client_public);
227    let server_pub_raw = entry.public_key_raw;
228
229    Ok((Zeroizing::new(*shared.as_bytes()), server_pub_raw))
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly initialised store exposes a non-empty key ID and a 32-byte public key.
    #[tokio::test]
    async fn init_produces_valid_keypair() {
        let store = init_key_store(3600);
        let (key_id, public_b64) = get_current_public_key(&store).await;

        assert!(!key_id.is_empty());
        let decoded = B64.decode(&public_b64).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    /// Client and server derive the same shared secret from each other's public keys.
    #[tokio::test]
    async fn ecdh_roundtrip() {
        let store = init_key_store(3600);
        let (key_id, public_b64) = get_current_public_key(&store).await;
        let server_pub_bytes: [u8; 32] = B64.decode(&public_b64).unwrap().try_into().unwrap();
        let server_public = PublicKey::from(server_pub_bytes);

        // Client side: fresh key pair, DH against the advertised server key.
        let client_secret = StaticSecret::random_from_rng(OsRng);
        let client_public = PublicKey::from(&client_secret);
        let client_shared = client_secret.diffie_hellman(&server_public);

        // Server side: DH against the client's public key via the store.
        let (server_shared, _) = ecdh(&store, &key_id, client_public.as_bytes())
            .await
            .unwrap();

        assert_eq!(client_shared.as_bytes(), server_shared.as_slice());
    }

    /// A key ID that matches no stored key is reported as `KeyExpired`.
    #[tokio::test]
    async fn unknown_key_id_returns_expired() {
        let store = init_key_store(3600);
        let fake_pk = [0u8; 32];

        let outcome = ecdh(&store, "nonexistent", &fake_pk).await;

        assert!(matches!(outcome, Err(EcdhError::KeyExpired)));
    }
}