alterion_ecdh/
keystore.rs1use std::sync::Arc;
3use tokio::sync::RwLock;
4use x25519_dalek::{StaticSecret, PublicKey};
5use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
6use chrono::{DateTime, Duration, Utc};
7use uuid::Uuid;
8use zeroize::Zeroizing;
9use dashmap::DashMap;
10
11pub struct KeyEntry {
12 pub key_id: String,
13 pub public_key_b64: String,
14 pub public_key_raw: [u8; 32],
15 pub secret: StaticSecret,
16 pub created_at: DateTime<Utc>,
17 pub expires_at: DateTime<Utc>,
18}
19
20pub struct KeyStore {
21 pub current: KeyEntry,
22 pub previous: Option<KeyEntry>,
23 pub next: Option<KeyEntry>,
26}
27
28struct HandshakeEntry {
31 secret: StaticSecret,
32 public_raw: [u8; 32],
33 expires_at: DateTime<Utc>,
34}
35
36#[derive(Clone)]
41pub struct HandshakeStore(Arc<DashMap<String, HandshakeEntry>>);
42
43#[derive(Debug, thiserror::Error)]
44pub enum EcdhError {
45 #[error("key_expired")]
46 KeyExpired,
47 #[error("invalid client public key")]
48 InvalidClientKey,
49 #[error("key generation failed: {0}")]
50 KeyGenerationFailed(String),
51}
52
/// Seconds added on top of the rotation interval when computing a key's `expires_at`.
const KEY_GRACE_SECS: u64 = 5 * 60;
/// How long before each rotation the next key is pre-generated (capped below the interval).
const KEY_WARM_LEAD_SECS: u64 = 10 * 60;
/// Lifetime of an ephemeral handshake, in seconds (`i64` to feed `chrono::Duration::seconds`).
const HANDSHAKE_TTL_SECS: i64 = 60;
56
57fn generate_entry(interval_secs: u64) -> KeyEntry {
58 let secret = StaticSecret::random_from_rng(rand::thread_rng());
59 let public_key = PublicKey::from(&secret);
60 let raw = *public_key.as_bytes();
61 let now = Utc::now();
62 let secs = i64::try_from(interval_secs + KEY_GRACE_SECS)
63 .expect("interval overflow");
64 KeyEntry {
65 key_id: Uuid::new_v4().to_string(),
66 public_key_b64: B64.encode(raw),
67 public_key_raw: raw,
68 secret,
69 created_at: now,
70 expires_at: now + Duration::seconds(secs),
71 }
72}
73
74pub fn init_key_store(interval_secs: u64) -> Arc<RwLock<KeyStore>> {
76 Arc::new(RwLock::new(KeyStore {
77 current: generate_entry(interval_secs),
78 previous: None,
79 next: None,
80 }))
81}
82
83pub fn init_handshake_store() -> HandshakeStore {
85 HandshakeStore(Arc::new(DashMap::new()))
86}
87
88pub fn init_handshake(hs: &HandshakeStore) -> (String, String) {
94 let secret = StaticSecret::random_from_rng(rand::thread_rng());
95 let public_key = PublicKey::from(&secret);
96 let raw = *public_key.as_bytes();
97 let id = format!("hs_{}", Uuid::new_v4());
98 hs.0.insert(id.clone(), HandshakeEntry {
99 secret,
100 public_raw: raw,
101 expires_at: Utc::now() + Duration::seconds(HANDSHAKE_TTL_SECS),
102 });
103 (id, B64.encode(raw))
104}
105
106pub async fn ecdh_ephemeral(
109 hs: &HandshakeStore,
110 handshake_id: &str,
111 client_pk_bytes: &[u8; 32],
112) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
113 let entry = hs.0.remove(handshake_id)
114 .ok_or(EcdhError::KeyExpired)?;
115 let (_, entry) = entry;
116 if Utc::now() > entry.expires_at {
117 return Err(EcdhError::KeyExpired);
118 }
119 let client_public = PublicKey::from(*client_pk_bytes);
120 let shared = entry.secret.diffie_hellman(&client_public);
121 Ok((Zeroizing::new(*shared.as_bytes()), entry.public_raw))
122}
123
124pub fn prune_handshakes(hs: &HandshakeStore) {
126 let now = Utc::now();
127 hs.0.retain(|_, v| v.expires_at > now);
128}
129
130pub fn start_rotation(store: Arc<RwLock<KeyStore>>, interval_secs: u64, hs: HandshakeStore) {
138 let warm_lead = KEY_WARM_LEAD_SECS.min(interval_secs.saturating_sub(1));
139 let warm_offset = interval_secs.saturating_sub(warm_lead);
140
141 let store_warm = store.clone();
142 let store_rotate = store.clone();
143
144 tokio::spawn(async move {
145 let mut warm_interval = tokio::time::interval_at(
146 tokio::time::Instant::now() + tokio::time::Duration::from_secs(warm_offset),
147 tokio::time::Duration::from_secs(interval_secs),
148 );
149 loop {
150 warm_interval.tick().await;
151 let next = tokio::task::spawn_blocking(move || generate_entry(interval_secs))
152 .await
153 .expect("key generation panicked");
154 store_warm.write().await.next = Some(next);
155 tracing::debug!("next X25519 key pre-warmed");
156 }
157 });
158
159 tokio::spawn(async move {
160 let mut rotation_interval = tokio::time::interval_at(
161 tokio::time::Instant::now() + tokio::time::Duration::from_secs(interval_secs),
162 tokio::time::Duration::from_secs(interval_secs),
163 );
164 let mut cleanup_interval = tokio::time::interval(
165 tokio::time::Duration::from_secs(30),
166 );
167 loop {
168 tokio::select! {
169 _ = rotation_interval.tick() => {
170 let mut w = store_rotate.write().await;
171 let new_entry = w.next.take().unwrap_or_else(|| generate_entry(interval_secs));
172 let old = std::mem::replace(&mut w.current, new_entry);
173 w.previous = Some(old);
174 tracing::info!("X25519 key rotated → {}", w.current.key_id);
175 }
176 _ = cleanup_interval.tick() => {
177 let needs_cleanup = {
178 let r = store_rotate.read().await;
179 r.previous.as_ref().map_or(false, |p| Utc::now() > p.expires_at)
180 };
181 if needs_cleanup {
182 store_rotate.write().await.previous = None;
183 tracing::debug!("previous X25519 key pruned");
184 }
185 prune_handshakes(&hs);
186 }
187 }
188 }
189 });
190}
191
192pub async fn get_current_public_key(store: &Arc<RwLock<KeyStore>>) -> (String, String) {
194 let guard = store.read().await;
195 (guard.current.key_id.clone(), guard.current.public_key_b64.clone())
196}
197
198pub async fn ecdh(
203 store: &Arc<RwLock<KeyStore>>,
204 key_id: &str,
205 client_pk_bytes: &[u8; 32],
206) -> Result<(Zeroizing<[u8; 32]>, [u8; 32]), EcdhError> {
207 let guard = store.read().await;
208
209 let entry = if guard.current.key_id == key_id {
210 &guard.current
211 } else if let Some(prev) = &guard.previous {
212 if prev.key_id == key_id {
213 if Utc::now() > prev.expires_at {
214 return Err(EcdhError::KeyExpired);
215 }
216 prev
217 } else {
218 return Err(EcdhError::KeyExpired);
219 }
220 } else {
221 return Err(EcdhError::KeyExpired);
222 };
223
224 let client_public = PublicKey::from(*client_pk_bytes);
225 let shared = entry.secret.diffie_hellman(&client_public);
226 let server_pub_raw = entry.public_key_raw;
227
228 Ok((Zeroizing::new(*shared.as_bytes()), server_pub_raw))
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a base64 public key into its raw 32-byte form, failing the test otherwise.
    fn decode_key(b64: &str) -> [u8; 32] {
        B64.decode(b64)
            .expect("public key is valid base64")
            .try_into()
            .expect("public key is 32 bytes")
    }

    #[tokio::test]
    async fn init_produces_valid_keypair() {
        let store = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        assert!(!key_id.is_empty());
        let decoded = B64.decode(b64.as_bytes()).expect("valid base64");
        assert_eq!(decoded.len(), 32);
    }

    #[tokio::test]
    async fn ecdh_roundtrip() {
        let store = init_key_store(3600);
        let (key_id, b64) = get_current_public_key(&store).await;
        let server_public = PublicKey::from(decode_key(&b64));

        let client_secret = StaticSecret::random_from_rng(rand::thread_rng());
        let client_public = PublicKey::from(&client_secret);
        let expected = client_secret.diffie_hellman(&server_public);

        let (actual, _server_raw) = ecdh(&store, &key_id, client_public.as_bytes())
            .await
            .expect("ecdh with the current key succeeds");

        assert_eq!(expected.as_bytes(), actual.as_slice());
    }

    #[tokio::test]
    async fn unknown_key_id_returns_expired() {
        let store = init_key_store(3600);
        let outcome = ecdh(&store, "nonexistent", &[0u8; 32]).await;
        assert!(matches!(outcome, Err(EcdhError::KeyExpired)));
    }
}