1use crate::backend::SecretBackend;
2use crate::encryptor::KeyEncryptor;
3use crate::rotator::{KeyRotator, SecretRotationBackend};
4use crate::secret_rotation::{InMemorySecretGroup, SecretGroup};
5use crate::syncer::SecretSyncer;
6
7use crate::util::generate_secret;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::task::JoinHandle;
11use tokio_util::sync::CancellationToken;
12
13pub struct SecretManagerHandle {
16 syncer: JoinHandle<()>,
17 rotator: JoinHandle<()>,
18}
19
20impl SecretManagerHandle {
21 pub async fn wait(self) {
23 let _ = tokio::join!(self.syncer, self.rotator);
24 }
25}
26
27pub struct SecretManager<B, E, const V: usize = 256, const S: usize = 32>
51where
52 B: SecretBackend + SecretRotationBackend + Clone,
53 E: KeyEncryptor + Clone,
54{
55 group_id: String,
56 group: Arc<InMemorySecretGroup<V, S>>,
57 backend: B,
58 encryptor: E,
59 rotation_interval: Duration,
60 propagation_delay: Duration,
61 poll_interval: Option<Duration>,
62 generate_key: Arc<dyn Fn() -> [u8; S] + Send + Sync + 'static>,
63}
64
65impl<B, E, const V: usize, const S: usize> SecretManager<B, E, V, S>
66where
67 B: SecretBackend + SecretRotationBackend + Clone,
68 E: KeyEncryptor + Clone,
69{
70 pub fn new(
85 group_id: impl Into<String>,
86 group: Arc<InMemorySecretGroup<V, S>>,
87 backend: B,
88 encryptor: E,
89 rotation_interval: Duration,
90 propagation_delay: Duration,
91 poll_interval: Option<Duration>,
92 generate_key: Option<fn() -> [u8; S]>,
93 ) -> Self {
94 let generate_key = generate_key.unwrap_or(generate_secret::<S>);
95 Self {
96 group_id: group_id.into(),
97 group,
98 backend,
99 encryptor,
100 rotation_interval,
101 propagation_delay,
102 poll_interval,
103 generate_key: Arc::new(generate_key),
104 }
105 }
106
107 pub async fn start(self, token: CancellationToken) -> Result<SecretManagerHandle, <B as SecretBackend>::Error> {
149 let generate_key = Arc::clone(&self.generate_key);
150
151 let mut syncer = SecretSyncer::new(
152 self.group_id.clone(),
153 Arc::clone(&self.group),
154 self.backend.clone(),
155 self.encryptor.clone(),
156 self.rotation_interval,
157 self.poll_interval,
158 );
159
160 let cursor = syncer.initial_load(&token).await?;
161
162 let rotator: KeyRotator<B, E, V, S> = KeyRotator::new(
163 self.group_id,
164 self.backend,
165 self.rotation_interval,
166 self.propagation_delay,
167 self.encryptor,
168 move || (generate_key)(),
169 );
170
171 Ok(SecretManagerHandle {
172 syncer: tokio::spawn(syncer.run(token.clone(), cursor)),
173 rotator: tokio::spawn(rotator.run(token)),
174 })
175 }
176}
177
178impl<B, E, const V: usize, const S: usize> SecretGroup<V, S> for SecretManager<B, E, V, S>
179where
180 B: SecretBackend + SecretRotationBackend + Clone,
181 E: KeyEncryptor + Clone,
182{
183 fn current(&self) -> (u8, [u8; S]) {
184 self.group.current()
185 }
186
187 fn resolve(&self, version: u8) -> Option<[u8; S]> {
188 self.group.resolve(version)
189 }
190}
191
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's imports (`Arc`, `Duration`,
    // `SecretRotationBackend`, ...), so they are not re-imported here.
    use super::*;
    use crate::backend::KeyRecord;
    use crate::encryptor::Encrypted;
    use crate::no_op_encryptor::NoOpEncryptor;
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::Mutex;
    use std::time::SystemTime;

    #[derive(Debug)]
    struct MockError;
    impl std::fmt::Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "mock error")
        }
    }
    impl std::error::Error for MockError {}

    /// Scripted backend. `load_all` always returns `load_response`. `poll_new`
    /// and `latest_key_info` pop queued responses and fall back to "nothing"
    /// once their queues are empty. `try_insert_key` always reports failure.
    #[derive(Clone)]
    struct MockBackend {
        load_response: Vec<KeyRecord>,
        poll_responses: Arc<Mutex<VecDeque<Vec<KeyRecord>>>>,
        latest_responses: Arc<Mutex<VecDeque<Option<(u8, SystemTime)>>>>,
    }

    #[async_trait]
    impl SecretBackend for MockBackend {
        type Error = MockError;
        async fn load_all(&self, _group_id: &str) -> Result<Vec<KeyRecord>, MockError> {
            Ok(self.load_response.clone())
        }
        async fn poll_new(
            &self,
            _group_id: &str,
            _since_time: SystemTime,
            _since_id: i64,
        ) -> Result<Vec<KeyRecord>, MockError> {
            Ok(self
                .poll_responses
                .lock()
                .unwrap()
                .pop_front()
                .unwrap_or_default())
        }
    }

    #[async_trait]
    impl SecretRotationBackend for MockBackend {
        type Error = MockError;
        async fn latest_key_info(
            &self,
            _group_id: &str,
        ) -> Result<Option<(u8, SystemTime)>, MockError> {
            Ok(self.latest_responses.lock().unwrap().pop_front().flatten())
        }
        async fn try_insert_key(
            &self,
            _group_id: &str,
            _expected_version: Option<u8>,
            _new_version: u8,
            _encrypted: &Encrypted,
            _activated_at: SystemTime,
        ) -> Result<bool, MockError> {
            Ok(false)
        }
    }

    #[tokio::test]
    async fn start_hydrates_group_and_returns_ok() {
        // The backend holds version 1 while the group is seeded with version 0
        // and an all-zero key. The assertions below therefore only pass if
        // `start` actually loaded the backend record; the seed values alone
        // cannot satisfy them.
        let backend = MockBackend {
            load_response: vec![KeyRecord {
                id: 1,
                version: 1,
                key_bytes: vec![0xAA; 32],
                nonce: None,
                encryption_key_version: 0,
                activated_at: SystemTime::now() - Duration::from_secs(300),
            }],
            poll_responses: Arc::new(Mutex::new(VecDeque::new())),
            latest_responses: Arc::new(Mutex::new(VecDeque::new())),
        };
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(0, [0u8; 32]));
        let manager = SecretManager::new(
            "test-manager",
            Arc::clone(&group),
            backend,
            NoOpEncryptor,
            Duration::from_secs(3600),
            Duration::from_secs(10),
            None,
            Some(|| [0xFFu8; 32]),
        );
        let token = CancellationToken::new();
        let handle = manager.start(token.clone()).await.expect("start should succeed");
        let (v, k) = group.current();
        assert_eq!(v, 1);
        assert_eq!(k, [0xAAu8; 32]);
        token.cancel();
        handle.wait().await;
    }

    #[test]
    fn manager_implements_secret_group() {
        let backend = MockBackend {
            load_response: vec![],
            poll_responses: Arc::new(Mutex::new(VecDeque::new())),
            latest_responses: Arc::new(Mutex::new(VecDeque::new())),
        };
        let group = Arc::new(InMemorySecretGroup::<256, 32>::new(42, [0xEEu8; 32]));
        let manager = SecretManager::new(
            "test-manager",
            group,
            backend,
            NoOpEncryptor,
            Duration::from_secs(3600),
            Duration::from_secs(10),
            None,
            Some(|| [0u8; 32]),
        );

        let sg: &dyn SecretGroup<256, 32> = &manager;
        let (v, k) = sg.current();
        assert_eq!(v, 42);
        assert_eq!(k, [0xEEu8; 32]);
    }
}