Skip to main content

alterion_ecdh/
keystore.rs

1// SPDX-License-Identifier: GPL-3.0
2use std::sync::Arc;
3use tokio::sync::RwLock;
4use x25519_dalek::{StaticSecret, PublicKey};
5use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
6use chrono::{DateTime, Duration, Utc};
7use uuid::Uuid;
8use zeroize::Zeroizing;
9use dashmap::DashMap;
10
/// One X25519 server key pair together with its identifier and validity timestamps.
///
/// Built by `generate_entry` and held in the `current`, `previous` and `next` slots of
/// [`KeyStore`].
pub struct KeyEntry {
    /// Identifier clients pass back to select this key (a UUID v4 rendered as a string).
    pub key_id:         String,
    /// Standard-alphabet base64 encoding of `public_key_raw`.
    pub public_key_b64: String,
    /// Raw 32-byte X25519 public key derived from `secret`.
    pub public_key_raw: [u8; 32],
    /// Private half of the key pair.
    pub secret:         StaticSecret,
    /// Time at which the entry was generated.
    pub created_at:     DateTime<Utc>,
    /// Time after which the entry is rejected when it is looked up as the `previous` key.
    /// Initialised by `generate_entry` to generation time + rotation interval +
    /// `KEY_GRACE_SECS`.
    pub expires_at:     DateTime<Utc>,
}
19
/// The set of static server keys: the active one, the one it replaced, and the one queued
/// to replace it.
pub struct KeyStore {
    /// Key currently advertised to clients and tried first by `ecdh`.
    pub current:  KeyEntry,
    /// Key that was `current` before the most recent rotation. `None` until the first
    /// rotation and again after the cleanup tick has pruned it.
    pub previous: Option<KeyEntry>,
    /// Pre-warmed entry generated `KEY_WARM_LEAD_SECS` before the next rotation.
    /// The rotation tick consumes this instead of generating a new key on the hot path.
    pub next:     Option<KeyEntry>,
}
27
/// A single-use ephemeral server key pair created by [`init_handshake`].
/// Consumed and removed by the first [`ecdh_ephemeral`] call that references its ID.
struct HandshakeEntry {
    /// Ephemeral private key used for exactly one Diffie-Hellman exchange.
    secret:     StaticSecret,
    /// Raw 32-byte public key matching `secret`; handed back to the caller of
    /// `ecdh_ephemeral`.
    public_raw: [u8; 32],
    /// Creation time plus `HANDSHAKE_TTL_SECS`; the entry is rejected or pruned after this.
    expires_at: DateTime<Utc>,
}
35
/// Thread-safe store of pending ephemeral handshake entries, keyed by handshake ID.
///
/// Separate from [`KeyStore`] so that handshake writes (use-once removals) never contend
/// with the read-heavy static key lock.
///
/// Cloning is cheap: the map lives behind an `Arc`, so every clone refers to the same
/// underlying `DashMap`.
#[derive(Clone)]
pub struct HandshakeStore(Arc<DashMap<String, HandshakeEntry>>);
42
/// Errors reported by the key-agreement functions in this module.
#[derive(Debug, thiserror::Error)]
pub enum EcdhError {
    /// The referenced key or handshake ID is unknown, already consumed, or past its
    /// `expires_at` time.
    #[error("key_expired")]
    KeyExpired,
    /// The client-supplied public key was rejected.
    #[error("invalid client public key")]
    InvalidClientKey,
    /// A key pair could not be generated; the payload carries a description of the cause.
    #[error("key generation failed: {0}")]
    KeyGenerationFailed(String),
}
52
/// Seconds added on top of the rotation interval when `generate_entry` computes `expires_at`.
const KEY_GRACE_SECS:       u64 = 300;
/// Upper bound, in seconds, on how long before a rotation `start_rotation` pre-generates the
/// next key.
const KEY_WARM_LEAD_SECS:   u64 = 600;
/// Lifetime, in seconds, that `init_handshake` gives each pending handshake entry.
const HANDSHAKE_TTL_SECS:   i64 = 60;
56
57fn generate_entry(interval_secs: u64) -> KeyEntry {
58    let secret     = StaticSecret::random_from_rng(rand::thread_rng());
59    let public_key = PublicKey::from(&secret);
60    let raw        = *public_key.as_bytes();
61    let now        = Utc::now();
62    let secs       = i64::try_from(interval_secs + KEY_GRACE_SECS)
63        .expect("interval overflow");
64    KeyEntry {
65        key_id:         Uuid::new_v4().to_string(),
66        public_key_b64: B64.encode(raw),
67        public_key_raw: raw,
68        secret,
69        created_at:     now,
70        expires_at:     now + Duration::seconds(secs),
71    }
72}
73
74/// Generates an initial X25519 key pair and wraps it in a shared, RwLock-guarded `KeyStore`.
75pub fn init_key_store(interval_secs: u64) -> Arc<RwLock<KeyStore>> {
76    Arc::new(RwLock::new(KeyStore {
77        current:  generate_entry(interval_secs),
78        previous: None,
79        next:     None,
80    }))
81}
82
83/// Creates an empty `HandshakeStore`. Call once at startup and share the handle across all workers.
84pub fn init_handshake_store() -> HandshakeStore {
85    HandshakeStore(Arc::new(DashMap::new()))
86}
87
88/// Generates a fresh ephemeral X25519 key pair, stores the private key in `hs` with a 60-second
89/// TTL, and returns `(handshake_id, base64_public_key)`.
90///
91/// The private key is consumed and deleted on the first matching [`ecdh_ephemeral`] call.
92/// Any entry not consumed within `HANDSHAKE_TTL_SECS` is pruned by [`prune_handshakes`].
93pub fn init_handshake(hs: &HandshakeStore) -> (String, String) {
94    let secret     = StaticSecret::random_from_rng(rand::thread_rng());
95    let public_key = PublicKey::from(&secret);
96    let raw        = *public_key.as_bytes();
97    let id         = format!("hs_{}", Uuid::new_v4());
98    hs.0.insert(id.clone(), HandshakeEntry {
99        secret,
100        public_raw: raw,
101        expires_at: Utc::now() + Duration::seconds(HANDSHAKE_TTL_SECS),
102    });
103    (id, B64.encode(raw))
104}
105
106/// Performs a use-once X25519 ECDH using a handshake entry created by [`init_handshake`].
107/// Removes the entry on success — replaying the same handshake ID returns `EcdhError::KeyExpired`.
108pub async fn ecdh_ephemeral(
109    hs:              &HandshakeStore,
110    handshake_id:    &str,
111    client_pk_bytes: &[u8; 32],
112) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
113    let entry = hs.0.remove(handshake_id)
114        .ok_or(EcdhError::KeyExpired)?;
115    let (_, entry) = entry;
116    if Utc::now() > entry.expires_at {
117        return Err(EcdhError::KeyExpired);
118    }
119    let client_public = PublicKey::from(*client_pk_bytes);
120    let shared        = entry.secret.diffie_hellman(&client_public);
121    Ok((Zeroizing::new(*shared.as_bytes()), entry.public_raw))
122}
123
124/// Removes all expired handshake entries from `hs`. Call periodically (e.g. every 30 seconds).
125pub fn prune_handshakes(hs: &HandshakeStore) {
126    let now = Utc::now();
127    hs.0.retain(|_, v| v.expires_at > now);
128}
129
130/// Spawns two background tasks:
131/// - **Warm-up**: generates the next key `KEY_WARM_LEAD_SECS` before each rotation and stores it
132///   in `KeyStore::next` so key generation never blocks the hot path.
133/// - **Rotation**: swaps `next` (or falls back to a fresh key) into `current` and retires the old
134///   key to `previous` for the grace-window period.
135/// - **Cleanup**: prunes `previous` once its grace window expires, and prunes expired handshake
136///   entries from `hs` every 30 seconds.
137pub fn start_rotation(store: Arc<RwLock<KeyStore>>, interval_secs: u64, hs: HandshakeStore) {
138    let warm_lead = KEY_WARM_LEAD_SECS.min(interval_secs.saturating_sub(1));
139    let warm_offset = interval_secs.saturating_sub(warm_lead);
140
141    let store_warm    = store.clone();
142    let store_rotate  = store.clone();
143
144    tokio::spawn(async move {
145        let mut warm_interval = tokio::time::interval_at(
146            tokio::time::Instant::now() + tokio::time::Duration::from_secs(warm_offset),
147            tokio::time::Duration::from_secs(interval_secs),
148        );
149        loop {
150            warm_interval.tick().await;
151            let next = tokio::task::spawn_blocking(move || generate_entry(interval_secs))
152                .await
153                .expect("key generation panicked");
154            store_warm.write().await.next = Some(next);
155            tracing::debug!("next X25519 key pre-warmed");
156        }
157    });
158
159    tokio::spawn(async move {
160        let mut rotation_interval = tokio::time::interval_at(
161            tokio::time::Instant::now() + tokio::time::Duration::from_secs(interval_secs),
162            tokio::time::Duration::from_secs(interval_secs),
163        );
164        let mut cleanup_interval = tokio::time::interval(
165            tokio::time::Duration::from_secs(30),
166        );
167        loop {
168            tokio::select! {
169                _ = rotation_interval.tick() => {
170                    let mut w = store_rotate.write().await;
171                    let new_entry = w.next.take().unwrap_or_else(|| generate_entry(interval_secs));
172                    let old = std::mem::replace(&mut w.current, new_entry);
173                    w.previous = Some(old);
174                    tracing::info!("X25519 key rotated → {}", w.current.key_id);
175                }
176                _ = cleanup_interval.tick() => {
177                    let needs_cleanup = {
178                        let r = store_rotate.read().await;
179                        r.previous.as_ref().map_or(false, |p| Utc::now() > p.expires_at)
180                    };
181                    if needs_cleanup {
182                        store_rotate.write().await.previous = None;
183                        tracing::debug!("previous X25519 key pruned");
184                    }
185                    prune_handshakes(&hs);
186                }
187            }
188        }
189    });
190}
191
192/// Returns `(key_id, base64_public_key)` for the currently active key.
193pub async fn get_current_public_key(store: &Arc<RwLock<KeyStore>>) -> (String, String) {
194    let guard = store.read().await;
195    (guard.current.key_id.clone(), guard.current.public_key_b64.clone())
196}
197
198/// Performs X25519 ECDH using the server key identified by `key_id` and the client's
199/// ephemeral public key bytes. Returns `(shared_secret, server_public_key_bytes)`.
200///
201/// Falls back to the previous key within its grace window; returns `EcdhError::KeyExpired` otherwise.
202pub async fn ecdh(
203    store:           &Arc<RwLock<KeyStore>>,
204    key_id:          &str,
205    client_pk_bytes: &[u8; 32],
206) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
207    let guard = store.read().await;
208
209    let entry = if guard.current.key_id == key_id {
210        &guard.current
211    } else if let Some(prev) = &guard.previous {
212        if prev.key_id == key_id {
213            if Utc::now() > prev.expires_at {
214                return Err(EcdhError::KeyExpired);
215            }
216            prev
217        } else {
218            return Err(EcdhError::KeyExpired);
219        }
220    } else {
221        return Err(EcdhError::KeyExpired);
222    };
223
224    let client_public  = PublicKey::from(*client_pk_bytes);
225    let shared         = entry.secret.diffie_hellman(&client_public);
226    let server_pub_raw = entry.public_key_raw;
227
228    Ok((Zeroizing::new(*shared.as_bytes()), server_pub_raw))
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly initialised store exposes a non-empty key ID and a 32-byte public key.
    #[tokio::test]
    async fn init_produces_valid_keypair() {
        let store = init_key_store(3600);
        let (key_id, public_b64) = get_current_public_key(&store).await;

        assert!(!key_id.is_empty());
        let decoded = B64.decode(&public_b64).unwrap();
        assert_eq!(decoded.len(), 32);
    }

    /// Client and server derive the same shared secret from each other's public keys.
    #[tokio::test]
    async fn ecdh_roundtrip() {
        let store = init_key_store(3600);
        let (key_id, public_b64) = get_current_public_key(&store).await;
        let server_pub_bytes: [u8; 32] = B64.decode(&public_b64).unwrap().try_into().unwrap();
        let server_public = PublicKey::from(server_pub_bytes);

        // Client half of the exchange.
        let client_secret = StaticSecret::random_from_rng(rand::thread_rng());
        let client_public = PublicKey::from(&client_secret);
        let client_shared = client_secret.diffie_hellman(&server_public);

        // Server half of the exchange.
        let (server_shared, _server_pub) = ecdh(&store, &key_id, client_public.as_bytes())
            .await
            .unwrap();

        assert_eq!(client_shared.as_bytes(), server_shared.as_slice());
    }

    /// A key ID the store has never issued is reported as expired.
    #[tokio::test]
    async fn unknown_key_id_returns_expired() {
        let store = init_key_store(3600);
        let zero_pk = [0u8; 32];

        let outcome = ecdh(&store, "nonexistent", &zero_pk).await;

        assert!(matches!(outcome, Err(EcdhError::KeyExpired)));
    }
}