alterion_ecdh/
keystore.rs1use std::sync::Arc;
3use rand_core::OsRng;
4use tokio::sync::RwLock;
5use x25519_dalek::{StaticSecret, PublicKey};
6use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
7use chrono::{DateTime, Duration, Utc};
8use uuid::Uuid;
9use zeroize::Zeroizing;
10use dashmap::DashMap;
11
12#[derive(zeroize::ZeroizeOnDrop)]
18pub struct KeyEntry {
19 pub secret: StaticSecret,
20 #[zeroize(skip)]
21 pub key_id: String,
22 #[zeroize(skip)]
23 pub public_key_b64: String,
24 #[zeroize(skip)]
25 pub public_key_raw: [u8; 32],
26 #[zeroize(skip)]
27 pub created_at: DateTime<Utc>,
28 #[zeroize(skip)]
29 pub expires_at: DateTime<Utc>,
30}
31
32pub struct KeyStore {
33 pub current: KeyEntry,
34 pub previous: Option<KeyEntry>,
35 pub next: Option<KeyEntry>,
38}
39
40struct HandshakeEntry {
43 secret: StaticSecret,
44 public_raw: [u8; 32],
45 expires_at: DateTime<Utc>,
46}
47
48#[derive(Clone)]
53pub struct HandshakeStore(Arc<DashMap<String, HandshakeEntry>>);
54
55#[derive(Debug, thiserror::Error)]
56pub enum EcdhError {
57 #[error("key_expired")]
58 KeyExpired,
59 #[error("invalid client public key")]
60 InvalidClientKey,
61 #[error("key generation failed: {0}")]
62 KeyGenerationFailed(String),
63}
64
/// Extra lifetime (seconds) added on top of the rotation interval for every key
/// (see `generate_entry`), so a key demoted to `previous` keeps working for a while.
const KEY_GRACE_SECS: u64 = 5 * 60;
/// How long (seconds) before each rotation the next key is generated in the background;
/// `start_rotation` caps it below the rotation interval.
const KEY_WARM_LEAD_SECS: u64 = 10 * 60;
/// Lifetime (seconds) of an ephemeral handshake created by `init_handshake`.
const HANDSHAKE_TTL_SECS: i64 = 60;
/// Maximum number of pending ephemeral handshakes; `init_handshake` refuses new ones at this size.
pub const MAX_HANDSHAKES: usize = 10_000;
72
73fn generate_entry(interval_secs: u64) -> KeyEntry {
74 let secret = StaticSecret::random_from_rng(OsRng);
75 let public_key = PublicKey::from(&secret);
76 let raw = *public_key.as_bytes();
77 let now = Utc::now();
78 let secs = i64::try_from(interval_secs + KEY_GRACE_SECS)
79 .expect("interval overflow");
80 KeyEntry {
81 key_id: Uuid::new_v4().to_string(),
82 public_key_b64: B64.encode(raw),
83 public_key_raw: raw,
84 secret,
85 created_at: now,
86 expires_at: now + Duration::seconds(secs),
87 }
88}
89
90pub fn init_key_store(interval_secs: u64) -> Arc<RwLock<KeyStore>> {
92 Arc::new(RwLock::new(KeyStore {
93 current: generate_entry(interval_secs),
94 previous: None,
95 next: None,
96 }))
97}
98
99pub fn init_handshake_store() -> HandshakeStore {
101 HandshakeStore(Arc::new(DashMap::new()))
102}
103
104pub fn init_handshake(hs: &HandshakeStore) -> Option<(String, String)> {
111 if hs.0.len() >= MAX_HANDSHAKES {
112 tracing::warn!("alterion-ecdh: handshake store at capacity — rejecting new handshake");
113 return None;
114 }
115 let secret = StaticSecret::random_from_rng(OsRng);
116 let public_key = PublicKey::from(&secret);
117 let raw = *public_key.as_bytes();
118 let id = format!("hs_{}", Uuid::new_v4());
119 hs.0.insert(id.clone(), HandshakeEntry {
120 secret,
121 public_raw: raw,
122 expires_at: Utc::now() + Duration::seconds(HANDSHAKE_TTL_SECS),
123 });
124 Some((id, B64.encode(raw)))
125}
126
127pub async fn ecdh_ephemeral(
130 hs: &HandshakeStore,
131 handshake_id: &str,
132 client_pk_bytes: &[u8; 32],
133) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
134 let entry = hs.0.remove(handshake_id)
135 .ok_or(EcdhError::KeyExpired)?;
136 let (_, entry) = entry;
137 if Utc::now() > entry.expires_at {
138 return Err(EcdhError::KeyExpired);
139 }
140 let client_public = PublicKey::from(*client_pk_bytes);
141 let shared = entry.secret.diffie_hellman(&client_public);
142 Ok((Zeroizing::new(*shared.as_bytes()), entry.public_raw))
143}
144
145pub fn prune_handshakes(hs: &HandshakeStore) {
147 let now = Utc::now();
148 hs.0.retain(|_, v| v.expires_at > now);
149}
150
151pub fn start_rotation(store: Arc<RwLock<KeyStore>>, interval_secs: u64, hs: HandshakeStore) {
159 let warm_lead = KEY_WARM_LEAD_SECS.min(interval_secs.saturating_sub(1));
160 let warm_offset = interval_secs.saturating_sub(warm_lead);
161
162 let store_warm = store.clone();
163 let store_rotate = store.clone();
164
165 tokio::spawn(async move {
166 let mut warm_interval = tokio::time::interval_at(
167 tokio::time::Instant::now() + tokio::time::Duration::from_secs(warm_offset),
168 tokio::time::Duration::from_secs(interval_secs),
169 );
170 loop {
171 warm_interval.tick().await;
172 let next = tokio::task::spawn_blocking(move || generate_entry(interval_secs))
173 .await
174 .expect("key generation panicked");
175 store_warm.write().await.next = Some(next);
176 tracing::debug!("next X25519 key pre-warmed");
177 }
178 });
179
180 tokio::spawn(async move {
181 let mut rotation_interval = tokio::time::interval_at(
182 tokio::time::Instant::now() + tokio::time::Duration::from_secs(interval_secs),
183 tokio::time::Duration::from_secs(interval_secs),
184 );
185 let mut cleanup_interval = tokio::time::interval(
186 tokio::time::Duration::from_secs(30),
187 );
188 loop {
189 tokio::select! {
190 _ = rotation_interval.tick() => {
191 let mut w = store_rotate.write().await;
192 let new_entry = w.next.take().unwrap_or_else(|| generate_entry(interval_secs));
193 let old = std::mem::replace(&mut w.current, new_entry);
194 w.previous = Some(old);
195 tracing::info!("X25519 key rotated → {}", w.current.key_id);
196 }
197 _ = cleanup_interval.tick() => {
198 let needs_cleanup = {
199 let r = store_rotate.read().await;
200 r.previous.as_ref().map_or(false, |p| Utc::now() > p.expires_at)
201 };
202 if needs_cleanup {
203 store_rotate.write().await.previous = None;
204 tracing::debug!("previous X25519 key pruned");
205 }
206 prune_handshakes(&hs);
207 }
208 }
209 }
210 });
211}
212
213pub async fn get_current_public_key(store: &Arc<RwLock<KeyStore>>) -> (String, String) {
215 let guard = store.read().await;
216 (guard.current.key_id.clone(), guard.current.public_key_b64.clone())
217}
218
219pub async fn ecdh(
224 store: &Arc<RwLock<KeyStore>>,
225 key_id: &str,
226 client_pk_bytes: &[u8; 32],
227) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
228 let guard = store.read().await;
229
230 let entry = if guard.current.key_id == key_id {
231 &guard.current
232 } else if let Some(prev) = &guard.previous {
233 if prev.key_id == key_id {
234 if Utc::now() > prev.expires_at {
235 return Err(EcdhError::KeyExpired);
236 }
237 prev
238 } else {
239 return Err(EcdhError::KeyExpired);
240 }
241 } else {
242 return Err(EcdhError::KeyExpired);
243 };
244
245 let client_public = PublicKey::from(*client_pk_bytes);
246 let shared = entry.secret.diffie_hellman(&client_public);
247 let server_pub_raw = entry.public_key_raw;
248
249 Ok((Zeroizing::new(*shared.as_bytes()), server_pub_raw))
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly initialised store advertises a non-empty id and a 32-byte public key.
    #[tokio::test]
    async fn init_produces_valid_keypair() {
        let store = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        assert!(!key_id.is_empty());
        let bytes = B64.decode(&b64).unwrap();
        assert_eq!(bytes.len(), 32);
    }

    /// Client and server derive the same secret against the current key.
    #[tokio::test]
    async fn ecdh_roundtrip() {
        let store = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        let server_pub_bytes: [u8; 32] = B64.decode(&b64).unwrap().try_into().unwrap();

        let client_secret = StaticSecret::random_from_rng(OsRng);
        let client_public = PublicKey::from(&client_secret);
        let client_shared = client_secret.diffie_hellman(&PublicKey::from(server_pub_bytes));

        let (server_shared, _) = ecdh(&store, &key_id, client_public.as_bytes()).await.unwrap();

        assert_eq!(client_shared.as_bytes(), server_shared.as_slice());
    }

    /// An id matching no stored key is reported as expired.
    #[tokio::test]
    async fn unknown_key_id_returns_expired() {
        let store = init_key_store(3600);
        let fake_pk = [0u8; 32];
        let result = ecdh(&store, "nonexistent", &fake_pk).await;
        assert!(matches!(result, Err(EcdhError::KeyExpired)));
    }

    /// After a (manual) rotation, the demoted key keeps working while it is unexpired.
    #[tokio::test]
    async fn previous_key_still_accepted_after_rotation() {
        let store = init_key_store(3600);
        let (old_id, old_b64) = get_current_public_key(&store).await;
        {
            let mut w = store.write().await;
            let fresh = generate_entry(3600);
            let old = std::mem::replace(&mut w.current, fresh);
            w.previous = Some(old);
        }

        let old_pub: [u8; 32] = B64.decode(&old_b64).unwrap().try_into().unwrap();
        let client_secret = StaticSecret::random_from_rng(OsRng);
        let client_public = PublicKey::from(&client_secret);
        let expected = client_secret.diffie_hellman(&PublicKey::from(old_pub));

        let (shared, server_raw) = ecdh(&store, &old_id, client_public.as_bytes())
            .await
            .unwrap();
        assert_eq!(expected.as_bytes(), shared.as_slice());
        assert_eq!(server_raw, old_pub);
    }

    /// An ephemeral handshake yields the matching secret exactly once.
    #[tokio::test]
    async fn ephemeral_handshake_roundtrip_is_single_use() {
        let hs = init_handshake_store();
        let (id, b64) = init_handshake(&hs).expect("fresh store is not at capacity");
        let server_pub: [u8; 32] = B64.decode(&b64).unwrap().try_into().unwrap();

        let client_secret = StaticSecret::random_from_rng(OsRng);
        let client_public = PublicKey::from(&client_secret);
        let expected = client_secret.diffie_hellman(&PublicKey::from(server_pub));

        let (shared, server_raw) = ecdh_ephemeral(&hs, &id, client_public.as_bytes())
            .await
            .unwrap();
        assert_eq!(expected.as_bytes(), shared.as_slice());
        assert_eq!(server_raw, server_pub);

        let reused = ecdh_ephemeral(&hs, &id, client_public.as_bytes()).await;
        assert!(matches!(reused, Err(EcdhError::KeyExpired)));
    }
}