alterion_ecdh/
keystore.rs1use std::sync::Arc;
3use rand_core::OsRng;
4use tokio::sync::RwLock;
5use x25519_dalek::{StaticSecret, PublicKey};
6use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
7use chrono::{DateTime, Duration, Utc};
8use uuid::Uuid;
9use zeroize::Zeroizing;
10use dashmap::DashMap;
11
12pub struct KeyEntry {
13 pub key_id: String,
14 pub public_key_b64: String,
15 pub public_key_raw: [u8; 32],
16 pub secret: StaticSecret,
17 pub created_at: DateTime<Utc>,
18 pub expires_at: DateTime<Utc>,
19}
20
21pub struct KeyStore {
22 pub current: KeyEntry,
23 pub previous: Option<KeyEntry>,
24 pub next: Option<KeyEntry>,
27}
28
29struct HandshakeEntry {
32 secret: StaticSecret,
33 public_raw: [u8; 32],
34 expires_at: DateTime<Utc>,
35}
36
37#[derive(Clone)]
42pub struct HandshakeStore(Arc<DashMap<String, HandshakeEntry>>);
43
44#[derive(Debug, thiserror::Error)]
45pub enum EcdhError {
46 #[error("key_expired")]
47 KeyExpired,
48 #[error("invalid client public key")]
49 InvalidClientKey,
50 #[error("key generation failed: {0}")]
51 KeyGenerationFailed(String),
52}
53
/// Extra lifetime (seconds) a key gets beyond one rotation interval, so clients
/// that fetched it just before a rotation can still complete an exchange.
const KEY_GRACE_SECS: u64 = 5 * 60;
/// Upper bound (seconds) on how far ahead of a rotation the next key is generated.
const KEY_WARM_LEAD_SECS: u64 = 10 * 60;
/// Lifetime (seconds) of an ephemeral handshake created by `init_handshake`.
const HANDSHAKE_TTL_SECS: i64 = 60;
57
58fn generate_entry(interval_secs: u64) -> KeyEntry {
59 let secret = StaticSecret::random_from_rng(OsRng);
60 let public_key = PublicKey::from(&secret);
61 let raw = *public_key.as_bytes();
62 let now = Utc::now();
63 let secs = i64::try_from(interval_secs + KEY_GRACE_SECS)
64 .expect("interval overflow");
65 KeyEntry {
66 key_id: Uuid::new_v4().to_string(),
67 public_key_b64: B64.encode(raw),
68 public_key_raw: raw,
69 secret,
70 created_at: now,
71 expires_at: now + Duration::seconds(secs),
72 }
73}
74
75pub fn init_key_store(interval_secs: u64) -> Arc<RwLock<KeyStore>> {
77 Arc::new(RwLock::new(KeyStore {
78 current: generate_entry(interval_secs),
79 previous: None,
80 next: None,
81 }))
82}
83
84pub fn init_handshake_store() -> HandshakeStore {
86 HandshakeStore(Arc::new(DashMap::new()))
87}
88
89pub fn init_handshake(hs: &HandshakeStore) -> (String, String) {
95 let secret = StaticSecret::random_from_rng(OsRng);
96 let public_key = PublicKey::from(&secret);
97 let raw = *public_key.as_bytes();
98 let id = format!("hs_{}", Uuid::new_v4());
99 hs.0.insert(id.clone(), HandshakeEntry {
100 secret,
101 public_raw: raw,
102 expires_at: Utc::now() + Duration::seconds(HANDSHAKE_TTL_SECS),
103 });
104 (id, B64.encode(raw))
105}
106
107pub async fn ecdh_ephemeral(
110 hs: &HandshakeStore,
111 handshake_id: &str,
112 client_pk_bytes: &[u8; 32],
113) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
114 let entry = hs.0.remove(handshake_id)
115 .ok_or(EcdhError::KeyExpired)?;
116 let (_, entry) = entry;
117 if Utc::now() > entry.expires_at {
118 return Err(EcdhError::KeyExpired);
119 }
120 let client_public = PublicKey::from(*client_pk_bytes);
121 let shared = entry.secret.diffie_hellman(&client_public);
122 Ok((Zeroizing::new(*shared.as_bytes()), entry.public_raw))
123}
124
125pub fn prune_handshakes(hs: &HandshakeStore) {
127 let now = Utc::now();
128 hs.0.retain(|_, v| v.expires_at > now);
129}
130
131pub fn start_rotation(store: Arc<RwLock<KeyStore>>, interval_secs: u64, hs: HandshakeStore) {
139 let warm_lead = KEY_WARM_LEAD_SECS.min(interval_secs.saturating_sub(1));
140 let warm_offset = interval_secs.saturating_sub(warm_lead);
141
142 let store_warm = store.clone();
143 let store_rotate = store.clone();
144
145 tokio::spawn(async move {
146 let mut warm_interval = tokio::time::interval_at(
147 tokio::time::Instant::now() + tokio::time::Duration::from_secs(warm_offset),
148 tokio::time::Duration::from_secs(interval_secs),
149 );
150 loop {
151 warm_interval.tick().await;
152 let next = tokio::task::spawn_blocking(move || generate_entry(interval_secs))
153 .await
154 .expect("key generation panicked");
155 store_warm.write().await.next = Some(next);
156 tracing::debug!("next X25519 key pre-warmed");
157 }
158 });
159
160 tokio::spawn(async move {
161 let mut rotation_interval = tokio::time::interval_at(
162 tokio::time::Instant::now() + tokio::time::Duration::from_secs(interval_secs),
163 tokio::time::Duration::from_secs(interval_secs),
164 );
165 let mut cleanup_interval = tokio::time::interval(
166 tokio::time::Duration::from_secs(30),
167 );
168 loop {
169 tokio::select! {
170 _ = rotation_interval.tick() => {
171 let mut w = store_rotate.write().await;
172 let new_entry = w.next.take().unwrap_or_else(|| generate_entry(interval_secs));
173 let old = std::mem::replace(&mut w.current, new_entry);
174 w.previous = Some(old);
175 tracing::info!("X25519 key rotated → {}", w.current.key_id);
176 }
177 _ = cleanup_interval.tick() => {
178 let needs_cleanup = {
179 let r = store_rotate.read().await;
180 r.previous.as_ref().map_or(false, |p| Utc::now() > p.expires_at)
181 };
182 if needs_cleanup {
183 store_rotate.write().await.previous = None;
184 tracing::debug!("previous X25519 key pruned");
185 }
186 prune_handshakes(&hs);
187 }
188 }
189 }
190 });
191}
192
193pub async fn get_current_public_key(store: &Arc<RwLock<KeyStore>>) -> (String, String) {
195 let guard = store.read().await;
196 (guard.current.key_id.clone(), guard.current.public_key_b64.clone())
197}
198
199pub async fn ecdh(
204 store: &Arc<RwLock<KeyStore>>,
205 key_id: &str,
206 client_pk_bytes: &[u8; 32],
207) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
208 let guard = store.read().await;
209
210 let entry = if guard.current.key_id == key_id {
211 &guard.current
212 } else if let Some(prev) = &guard.previous {
213 if prev.key_id == key_id {
214 if Utc::now() > prev.expires_at {
215 return Err(EcdhError::KeyExpired);
216 }
217 prev
218 } else {
219 return Err(EcdhError::KeyExpired);
220 }
221 } else {
222 return Err(EcdhError::KeyExpired);
223 };
224
225 let client_public = PublicKey::from(*client_pk_bytes);
226 let shared = entry.secret.diffie_hellman(&client_public);
227 let server_pub_raw = entry.public_key_raw;
228
229 Ok((Zeroizing::new(*shared.as_bytes()), server_pub_raw))
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the client half of an X25519 exchange against a base64 server key.
    ///
    /// Returns `(client_public_key, client_side_shared_secret)`.
    fn client_exchange(server_pub_b64: &str) -> ([u8; 32], [u8; 32]) {
        let server_pub: [u8; 32] = B64.decode(server_pub_b64).unwrap().try_into().unwrap();
        let client_secret = StaticSecret::random_from_rng(OsRng);
        let client_public = *PublicKey::from(&client_secret).as_bytes();
        let shared = *client_secret
            .diffie_hellman(&PublicKey::from(server_pub))
            .as_bytes();
        (client_public, shared)
    }

    #[tokio::test]
    async fn init_produces_valid_keypair() {
        let store = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        assert!(!key_id.is_empty());
        let bytes = B64.decode(&b64).unwrap();
        assert_eq!(bytes.len(), 32);
    }

    #[tokio::test]
    async fn ecdh_roundtrip() {
        let store = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        let (client_pk, client_shared) = client_exchange(&b64);

        let (server_shared, _) = ecdh(&store, &key_id, &client_pk).await.unwrap();

        assert_eq!(client_shared, *server_shared);
    }

    #[tokio::test]
    async fn unknown_key_id_returns_expired() {
        let store = init_key_store(3600);
        let fake_pk = [0u8; 32];
        let result = ecdh(&store, "nonexistent", &fake_pk).await;
        assert!(matches!(result, Err(EcdhError::KeyExpired)));
    }

    #[tokio::test]
    async fn previous_key_accepted_until_it_expires() {
        let store = init_key_store(3600);
        let (old_id, old_b64) = get_current_public_key(&store).await;
        {
            let mut w = store.write().await;
            let retired = std::mem::replace(&mut w.current, generate_entry(3600));
            w.previous = Some(retired);
        }

        let (client_pk, client_shared) = client_exchange(&old_b64);
        let (server_shared, _) = ecdh(&store, &old_id, &client_pk).await.unwrap();
        assert_eq!(client_shared, *server_shared);

        store.write().await.previous.as_mut().unwrap().expires_at =
            Utc::now() - Duration::seconds(1);
        let result = ecdh(&store, &old_id, &client_pk).await;
        assert!(matches!(result, Err(EcdhError::KeyExpired)));
    }

    #[tokio::test]
    async fn ephemeral_handshake_roundtrip_is_single_use() {
        let hs = init_handshake_store();
        let (hs_id, server_b64) = init_handshake(&hs);
        let (client_pk, client_shared) = client_exchange(&server_b64);

        let (server_shared, server_pub) = ecdh_ephemeral(&hs, &hs_id, &client_pk).await.unwrap();
        assert_eq!(client_shared, *server_shared);
        assert_eq!(B64.encode(server_pub), server_b64);

        let replay = ecdh_ephemeral(&hs, &hs_id, &client_pk).await;
        assert!(matches!(replay, Err(EcdhError::KeyExpired)));
    }

    #[test]
    fn prune_removes_expired_handshakes() {
        let hs = init_handshake_store();
        let (hs_id, _) = init_handshake(&hs);
        hs.0.get_mut(&hs_id).unwrap().expires_at = Utc::now() - Duration::seconds(1);

        prune_handshakes(&hs);

        assert!(hs.0.is_empty());
    }
}