1use crate::backend::{EPOCH_CURSOR, SecretBackend};
2use crate::encryptor::{Encrypted, KeyEncryptor};
3use crate::secret_rotation::InMemorySecretGroup;
4use std::collections::HashMap;
5use std::hash::{Hash, Hasher};
6use std::collections::hash_map::DefaultHasher;
7use std::sync::Arc;
8use std::time::{Duration, SystemTime};
9use tokio_util::sync::CancellationToken;
10use tracing::{error, info};
11
12const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(5);
17const ROTATION_POLL_BUFFER: Duration = Duration::from_secs(2);
18
/// Converts an optional byte vector into a fixed 12-byte nonce.
///
/// Returns `None` when no bytes were supplied or when the supplied vector is
/// not exactly 12 bytes long.
fn to_nonce(v: Option<Vec<u8>>) -> Option<[u8; 12]> {
    let raw = v?;
    <[u8; 12]>::try_from(raw).ok()
}
29
30fn payload_hash(enc: &Encrypted) -> u64 {
31 let mut h = DefaultHasher::new();
32 enc.ciphertext.hash(&mut h);
33 enc.nonce.hash(&mut h);
34 enc.key_version.hash(&mut h);
35 h.finish()
36}
37
38pub struct SecretSyncer<B: SecretBackend, E: KeyEncryptor + Clone, const V: usize = 256, const S: usize = 32> {
100 group_id: String,
101 secret: Arc<InMemorySecretGroup<V, S>>,
102 backend: B,
103 encryptor: E,
104 rotation_interval: Duration,
105 poll_interval: Duration,
106 seen_hashes: HashMap<u8, u64>,
107}
108
109impl<B: SecretBackend, E: KeyEncryptor + Clone, const V: usize, const S: usize> SecretSyncer<B, E, V, S> {
110 pub fn new(
122 group_id: impl Into<String>,
123 secret: Arc<InMemorySecretGroup<V, S>>,
124 backend: B,
125 encryptor: E,
126 rotation_interval: Duration,
127 poll_interval: Option<Duration>,
128 ) -> Self {
129 Self {
130 group_id: group_id.into(),
131 secret,
132 backend,
133 encryptor,
134 rotation_interval,
135 poll_interval: poll_interval.unwrap_or(DEFAULT_POLL_INTERVAL),
136 seen_hashes: HashMap::new(),
137 }
138 }
139
140 pub async fn initial_load(
153 &mut self,
154 token: &CancellationToken,
155 ) -> Result<(SystemTime, i64), B::Error> {
156 let records = self.backend.load_all(&self.group_id).await?;
157 let count = records.len();
158 let mut max_time = EPOCH_CURSOR;
159 let mut max_id = 0i64;
160 let mut latest_active_version: Option<u8> = None;
161 let mut latest_active_at = EPOCH_CURSOR;
162
163 let now = SystemTime::now();
164
165 for record in records {
166 if (record.activated_at, record.id) > (max_time, max_id) {
167 max_time = record.activated_at;
168 max_id = record.id;
169 }
170
171 if (record.version as usize) >= V {
172 error!(
173 group_id = %self.group_id,
174 version = record.version,
175 ring_size = V,
176 "SecretSyncer: version exceeds ring buffer size, skipping"
177 );
178 continue;
179 }
180
181 let enc = Encrypted {
182 ciphertext: record.key_bytes,
183 nonce: to_nonce(record.nonce),
184 key_version: record.encryption_key_version,
185 };
186 let hash = payload_hash(&enc);
187
188 if self.seen_hashes.get(&record.version) == Some(&hash) {
189 if record.activated_at <= now {
191 if record.activated_at >= latest_active_at {
192 latest_active_at = record.activated_at;
193 latest_active_version = Some(record.version);
194 }
195 }
196 continue;
197 }
198
199 match self.encryptor.decrypt(&enc).await {
200 Ok(bytes) => {
201 if let Ok(key) = <[u8; S]>::try_from(bytes) {
202 self.secret.store_key(record.version, key);
203 self.seen_hashes.insert(record.version, hash);
204 if record.activated_at <= now {
205 if record.activated_at >= latest_active_at {
206 latest_active_at = record.activated_at;
207 latest_active_version = Some(record.version);
208 }
209 } else {
210 self.schedule_promotion(record.version, record.activated_at, token.clone());
211 }
212 }
213 }
214 Err(e) => {
215 error!(
216 group_id = %self.group_id,
217 version = record.version,
218 error = %e,
219 "SecretSyncer: decryption failed during initial load"
220 );
221 }
222 }
223 }
224
225 if let Some(v) = latest_active_version {
226 self.secret.promote(v);
227 }
228
229 info!(group_id = %self.group_id, count, "SecretSyncer initial load complete");
230 Ok((max_time, max_id))
231 }
232
233 pub async fn run(mut self, token: CancellationToken, mut cursor: (SystemTime, i64)) {
240 loop {
241 let now = SystemTime::now();
242 let next_expected = cursor.0.checked_add(self.rotation_interval).unwrap_or(now);
243
244 let sleep_dur = next_expected
245 .duration_since(now)
246 .ok()
247 .map(|d| d + ROTATION_POLL_BUFFER)
248 .filter(|&smart| smart < self.poll_interval)
249 .unwrap_or(self.poll_interval);
250
251 tokio::select! {
252 biased;
253 _ = token.cancelled() => {
254 info!(group_id = %self.group_id, "SecretSyncer shutting down");
255 break;
256 }
257 _ = tokio::time::sleep(sleep_dur) => {
258 match self.backend.poll_new(&self.group_id, cursor.0, cursor.1).await {
259 Ok(records) => {
260 for record in records {
261 if (record.activated_at, record.id) > cursor {
262 cursor = (record.activated_at, record.id);
263 }
264 if (record.version as usize) >= V {
265 error!(
266 group_id = %self.group_id,
267 version = record.version,
268 ring_size = V,
269 "SecretSyncer: version exceeds ring buffer size, skipping"
270 );
271 continue;
272 }
273 let enc = Encrypted {
274 ciphertext: record.key_bytes,
275 nonce: to_nonce(record.nonce),
276 key_version: record.encryption_key_version,
277 };
278 let hash = payload_hash(&enc);
279 if self.seen_hashes.get(&record.version) == Some(&hash) {
280 continue;
281 }
282 match self.encryptor.decrypt(&enc).await {
283 Ok(bytes) => {
284 if let Ok(key) = <[u8; S]>::try_from(bytes) {
285 self.secret.store_key(record.version, key);
286 self.seen_hashes.insert(record.version, hash);
287 let now = SystemTime::now();
288 if record.activated_at <= now {
289 self.secret.promote(record.version);
290 } else {
291 self.schedule_promotion(record.version, record.activated_at, token.clone());
292 }
293 }
294 }
295 Err(e) => {
296 error!(
297 group_id = %self.group_id,
298 version = record.version,
299 error = %e,
300 "SecretSyncer: decryption failed during poll"
301 );
302 }
303 }
304 }
305 }
306 Err(e) => {
307 error!(group_id = %self.group_id, error = %e, "SecretSyncer poll failed");
308 if self.sleep_or_cancel(Duration::from_secs(30), &token).await { break; }
309 }
310 }
311 }
312 }
313 }
314 }
315
316 fn schedule_promotion(&self, version: u8, activated_at: SystemTime, token: CancellationToken) {
317 let secret = Arc::clone(&self.secret);
318 tokio::spawn(async move {
319 if let Ok(sleep_dur) = activated_at.duration_since(SystemTime::now()) {
320 tokio::select! {
321 biased;
322 _ = token.cancelled() => return,
323 _ = tokio::time::sleep(sleep_dur) => {}
324 }
325 }
326 secret.promote(version);
327 });
328 }
329
330 async fn sleep_or_cancel(&self, duration: Duration, token: &CancellationToken) -> bool {
331 tokio::select! {
332 biased;
333 _ = token.cancelled() => true,
334 _ = tokio::time::sleep(duration) => false,
335 }
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::KeyRecord;
    use crate::encryptor::Encrypted;
    use crate::no_op_encryptor::NoOpEncryptor;
    use crate::secret_rotation::SecretGroup;
    use crate::encryptor::EncryptorError;
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::Mutex;

    /// Error type returned by [`MockBackend`].
    #[derive(Debug)]
    struct MockError;
    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "mock error")
        }
    }
    impl std::error::Error for MockError {}

    /// Backend double: `load_all` always returns `load_response`, while
    /// `poll_new` pops scripted responses from a queue shared between clones.
    #[derive(Clone)]
    struct MockBackend {
        load_response: Vec<KeyRecord>,
        poll_responses: Arc<Mutex<VecDeque<Result<Vec<KeyRecord>, MockError>>>>,
    }

    impl MockBackend {
        /// Creates a backend whose `load_all` yields `records` and whose poll queue is empty.
        fn with_load(records: Vec<KeyRecord>) -> Self {
            Self {
                load_response: records,
                poll_responses: Arc::new(Mutex::new(VecDeque::new())),
            }
        }

        /// Queues a successful poll response.
        fn push_poll(&self, records: Vec<KeyRecord>) {
            self.poll_responses.lock().unwrap().push_back(Ok(records));
        }

        /// Queues a failing poll response.
        fn push_poll_err(&self) {
            self.poll_responses.lock().unwrap().push_back(Err(MockError));
        }
    }

    #[async_trait]
    impl SecretBackend for MockBackend {
        type Error = MockError;
        async fn load_all(&self, _group_id: &str) -> Result<Vec<KeyRecord>, MockError> {
            Ok(self.load_response.clone())
        }
        async fn poll_new(
            &self,
            _group_id: &str,
            _since_time: SystemTime,
            _since_id: i64,
        ) -> Result<Vec<KeyRecord>, MockError> {
            // An exhausted queue behaves like a backend with nothing new.
            self.poll_responses
                .lock()
                .unwrap()
                .pop_front()
                .unwrap_or(Ok(vec![]))
        }
    }

    /// Pass-through encryptor that counts `decrypt` calls; the counter is
    /// shared between clones.
    #[derive(Clone)]
    struct CountingEncryptor {
        decrypt_calls: Arc<Mutex<usize>>,
    }

    impl CountingEncryptor {
        fn new() -> Self {
            Self { decrypt_calls: Arc::new(Mutex::new(0)) }
        }
        /// Number of `decrypt` calls observed so far.
        fn decrypt_calls(&self) -> usize {
            *self.decrypt_calls.lock().unwrap()
        }
    }

    #[async_trait]
    impl KeyEncryptor for CountingEncryptor {
        async fn encrypt(&self, plaintext: &[u8]) -> Result<Encrypted, EncryptorError> {
            Ok(Encrypted { ciphertext: plaintext.to_vec(), nonce: None, key_version: 0 })
        }
        async fn decrypt(&self, enc: &Encrypted) -> Result<Vec<u8>, EncryptorError> {
            *self.decrypt_calls.lock().unwrap() += 1;
            Ok(enc.ciphertext.clone())
        }
    }

    /// Builds a record whose key is 32 copies of `fill`, with no nonce and
    /// encryption key version 0.
    fn rec(id: i64, version: u8, fill: u8, activated_at: SystemTime) -> KeyRecord {
        KeyRecord {
            id,
            version,
            key_bytes: vec![fill; 32],
            nonce: None,
            encryption_key_version: 0,
            activated_at,
        }
    }

    /// Builds a 256-slot / 32-byte syncer with a one-hour rotation interval
    /// and a 10 ms poll interval.
    fn make_syncer<E: KeyEncryptor + Clone>(
        backend: MockBackend,
        group: Arc<InMemorySecretGroup<256, 32>>,
        enc: E,
    ) -> SecretSyncer<MockBackend, E, 256, 32> {
        SecretSyncer::new(
            "test-syncer",
            group,
            backend,
            enc,
            Duration::from_secs(3600),
            Some(Duration::from_millis(10)),
        )
    }

    #[tokio::test]
    async fn initial_load_applies_all_keys_and_promotes_latest_active() {
        let now = SystemTime::now();
        let backend = MockBackend::with_load(vec![
            rec(1, 0, 0xAA, now - Duration::from_secs(600)),
            rec(2, 1, 0xBB, now - Duration::from_secs(300)),
        ]);
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        syncer.initial_load(&CancellationToken::new()).await.unwrap();
        let (v, _) = group.current();
        assert_eq!(v, 1);
    }

    #[tokio::test]
    async fn initial_load_empty_returns_epoch_cursor() {
        let backend = MockBackend::with_load(vec![]);
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        let (t, id) = syncer.initial_load(&CancellationToken::new()).await.unwrap();
        assert_eq!(t, EPOCH_CURSOR);
        assert_eq!(id, 0);
    }

    #[tokio::test]
    async fn initial_load_returns_max_cursor() {
        let t0 = SystemTime::now() - Duration::from_secs(60);
        let t1 = SystemTime::now();
        let backend = MockBackend::with_load(vec![
            rec(10, 0, 0xAA, t0),
            rec(20, 1, 0xBB, t1),
        ]);
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        let (t, id) = syncer.initial_load(&CancellationToken::new()).await.unwrap();
        assert_eq!(id, 20);
        // Exact equality: a `duration_since(..).unwrap_or_default()` check would
        // also pass for any cursor time earlier than `t1`, hiding a regression.
        assert_eq!(t, t1, "cursor time must be the newest record's activation time");
    }

    #[tokio::test]
    async fn initial_load_stores_future_key_but_does_not_promote_it() {
        tokio::time::pause();
        let future_at = SystemTime::now() + Duration::from_secs(30);
        let backend = MockBackend::with_load(vec![rec(1, 1, 0xCC, future_at)]);
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0xFFu8; 32]));
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        syncer.initial_load(&CancellationToken::new()).await.unwrap();

        assert_eq!(group.resolve(1), Some([0xCC; 32]));
        assert_eq!(group.current().0, 0, "current must still be the initial version");
    }

    #[tokio::test]
    async fn initial_load_future_key_promoted_after_activation_time() {
        tokio::time::pause();
        let future_at = SystemTime::now() + Duration::from_secs(10);
        let backend = MockBackend::with_load(vec![rec(1, 1, 0xCC, future_at)]);
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0xFFu8; 32]));
        let token = CancellationToken::new();
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        syncer.initial_load(&token).await.unwrap();

        // Let the spawned promotion task start, move the paused clock past the
        // activation time, then let the task finish.
        tokio::task::yield_now().await;
        tokio::time::advance(Duration::from_secs(11)).await;
        tokio::task::yield_now().await;

        assert_eq!(group.current().0, 1, "key must be promoted after activation time elapses");
    }

    #[tokio::test]
    async fn initial_load_skips_version_out_of_ring_range() {
        let now = SystemTime::now() - Duration::from_secs(1);
        let backend = MockBackend::with_load(vec![
            rec(1, 0, 0xAA, now),
            // Version 4 is out of range for a ring of size 4.
            rec(2, 4, 0xBB, now),
        ]);
        let group = Arc::new(InMemorySecretGroup::<4, 32>::new(0, [0u8; 32]));
        let mut syncer: SecretSyncer<MockBackend, NoOpEncryptor, 4, 32> = SecretSyncer::new(
            "test-syncer",
            Arc::clone(&group),
            backend,
            NoOpEncryptor,
            Duration::from_secs(3600),
            None,
        );
        syncer.initial_load(&CancellationToken::new()).await.unwrap();

        assert_eq!(group.current().0, 0);
        assert!(group.resolve(0).is_some());
    }

    #[tokio::test]
    async fn initial_load_dedup_skips_decrypt_on_repeated_load() {
        let now = SystemTime::now() - Duration::from_secs(60);
        let backend = MockBackend::with_load(vec![rec(1, 0, 0xAA, now)]);
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let enc = CountingEncryptor::new();
        let mut syncer = make_syncer(backend, Arc::clone(&group), enc.clone());

        syncer.initial_load(&CancellationToken::new()).await.unwrap();
        assert_eq!(enc.decrypt_calls(), 1);

        // Same payload again: the fingerprint matches, so no second decrypt.
        syncer.initial_load(&CancellationToken::new()).await.unwrap();
        assert_eq!(enc.decrypt_calls(), 1, "dedup should skip decrypt for unchanged payload");
    }

    #[tokio::test]
    async fn run_exits_on_cancellation() {
        let backend = MockBackend::with_load(vec![]);
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        let cursor = syncer.initial_load(&CancellationToken::new()).await.unwrap();

        let token = CancellationToken::new();
        let handle = tokio::spawn(syncer.run(token.clone(), cursor));
        token.cancel();
        tokio::time::timeout(Duration::from_millis(200), handle)
            .await
            .expect("run must exit promptly after cancellation")
            .unwrap();
    }

    #[tokio::test]
    async fn run_applies_polled_keys_and_promotes() {
        tokio::time::pause();
        let backend = MockBackend::with_load(vec![]);
        let poll_handle = backend.clone();
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        let cursor = syncer.initial_load(&CancellationToken::new()).await.unwrap();

        let past = SystemTime::now() - Duration::from_secs(5);
        poll_handle.push_poll(vec![rec(1, 1, 0xBB, past)]);

        let token = CancellationToken::new();
        let handle = tokio::spawn(syncer.run(token.clone(), cursor));

        // Advance past one 10 ms poll interval so the queued response is consumed.
        tokio::task::yield_now().await;
        tokio::time::advance(Duration::from_millis(20)).await;
        tokio::task::yield_now().await;

        assert_eq!(group.current().0, 1);
        assert_eq!(group.resolve(1), Some([0xBB; 32]));

        token.cancel();
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn run_poll_error_retries_and_eventually_recovers() {
        tokio::time::pause();
        let backend = MockBackend::with_load(vec![]);
        let poll_handle = backend.clone();
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let mut syncer = make_syncer(backend, Arc::clone(&group), NoOpEncryptor);
        let cursor = syncer.initial_load(&CancellationToken::new()).await.unwrap();

        let past = SystemTime::now() - Duration::from_secs(5);
        // First poll fails, second poll delivers the key.
        poll_handle.push_poll_err();
        poll_handle.push_poll(vec![rec(1, 1, 0xBB, past)]);

        let token = CancellationToken::new();
        let handle = tokio::spawn(syncer.run(token.clone(), cursor));

        // First poll interval elapses -> failing poll.
        tokio::task::yield_now().await;
        tokio::time::advance(Duration::from_millis(15)).await;
        tokio::task::yield_now().await;
        // Sit out the 30 s error back-off.
        tokio::time::advance(Duration::from_secs(31)).await;
        tokio::task::yield_now().await;
        // Next poll interval elapses -> successful poll.
        tokio::time::advance(Duration::from_millis(15)).await;
        tokio::task::yield_now().await;
        tokio::task::yield_now().await;

        assert_eq!(group.current().0, 1, "must recover and apply key after error back-off");

        token.cancel();
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn run_dedup_skips_repeated_poll_records() {
        tokio::time::pause();
        let backend = MockBackend::with_load(vec![]);
        let poll_handle = backend.clone();
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let enc = CountingEncryptor::new();
        let mut syncer = make_syncer(backend, Arc::clone(&group), enc.clone());
        let cursor = syncer.initial_load(&CancellationToken::new()).await.unwrap();

        let past = SystemTime::now() - Duration::from_secs(5);
        // The same record is delivered by two consecutive polls.
        poll_handle.push_poll(vec![rec(1, 1, 0xBB, past)]);
        poll_handle.push_poll(vec![rec(1, 1, 0xBB, past)]);

        let token = CancellationToken::new();
        let handle = tokio::spawn(syncer.run(token.clone(), cursor));

        tokio::task::yield_now().await;
        tokio::time::advance(Duration::from_millis(25)).await;
        tokio::task::yield_now().await;

        assert_eq!(enc.decrypt_calls(), 1, "second identical poll record must be deduped");

        token.cancel();
        handle.await.unwrap();
    }
}