1use std::collections::HashMap;
4use std::sync::Arc;
5
6use kube::runtime::events::Recorder;
7use std::time::Duration;
8
9use sqlx::postgres::{PgPool, PgPoolOptions};
10use tokio::sync::{Mutex, RwLock};
11
12use crate::observability::OperatorObservability;
13
/// Maximum connections per cached pool. Kept small: each watched Secret gets
/// its own pool, so this bounds per-database fan-out from the operator.
const POOL_MAX_CONNECTIONS: u32 = 5;

/// How long `Pool::acquire` waits for a free connection before erroring.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;

// Compile-time guard: the operator needs at least 2 connections per pool
// (e.g. so one long-running statement cannot starve health/ancillary queries).
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);
/// A cached connection pool tagged with the resourceVersion of the Secret it
/// was built from, so staleness can be detected when the Secret changes.
#[derive(Clone)]
struct CachedPool {
    // resourceVersion of the source Secret at connect time; `None` if the
    // API server did not report one.
    resource_version: Option<String>,
    // The live sqlx pool; clones share the same underlying connections.
    pool: PgPool,
}
31
/// RAII guard for the in-memory per-database reconcile lock.
///
/// Created by `OperatorContext::try_lock_database`; removing `key` from the
/// shared map on drop releases the lock (see the `Drop` impl).
pub struct DatabaseLockGuard {
    // Map key this guard owns (the database identity string).
    key: String,
    // Shared lock table; `()` values — the map is used as a set.
    locks: Arc<Mutex<HashMap<String, ()>>>,
}
41
impl Drop for DatabaseLockGuard {
    /// Release the in-memory lock. `drop` cannot `.await`, so three paths:
    /// 1. Fast path: `try_lock` succeeds — remove the key synchronously.
    /// 2. Contended inside a tokio runtime — spawn a task to remove the key
    ///    once the mutex frees up (release is slightly deferred; a concurrent
    ///    `try_lock_database` in that window sees contention and backs off).
    /// 3. Contended with no runtime available — `blocking_lock`, which is
    ///    safe here precisely because `Handle::try_current()` failed (we are
    ///    not on an async worker thread).
    fn drop(&mut self) {
        if let Ok(mut map) = self.locks.try_lock() {
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
76
/// Shared state handed to every reconcile invocation.
///
/// Cloning is cheap: the caches are behind `Arc`s, so all clones observe the
/// same pool cache and lock table.
#[derive(Clone)]
pub struct OperatorContext {
    /// Client for talking to the Kubernetes API server.
    pub kube_client: kube::Client,

    /// Recorder for emitting Kubernetes Events from reconcilers.
    pub event_recorder: Recorder,

    // Pools keyed by "namespace/secret_name/secret_key"; entries are reused
    // while the source Secret's resourceVersion is unchanged.
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    // Set of database identities currently being reconciled (map-as-set);
    // guarded handout via `try_lock_database`.
    database_locks: Arc<Mutex<HashMap<String, ()>>>,

    /// Metrics/tracing handles for the operator.
    pub observability: OperatorObservability,
}
98
99impl OperatorContext {
100 pub fn new(
102 kube_client: kube::Client,
103 observability: OperatorObservability,
104 event_recorder: Recorder,
105 ) -> Self {
106 Self {
107 kube_client,
108 event_recorder,
109 pool_cache: Arc::new(RwLock::new(HashMap::new())),
110 observability,
111 database_locks: Arc::new(Mutex::new(HashMap::new())),
112 }
113 }
114
115 pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
121 let mut locks = self.database_locks.lock().await;
122 if locks.contains_key(database_identity) {
123 tracing::info!(
124 database = %database_identity,
125 "in-memory database lock contention — another reconcile is in progress"
126 );
127 return None;
128 }
129 locks.insert(database_identity.to_string(), ());
130 tracing::debug!(database = %database_identity, "acquired in-memory database lock");
131 Some(DatabaseLockGuard {
132 key: database_identity.to_string(),
133 locks: Arc::clone(&self.database_locks),
134 })
135 }
136
137 pub async fn get_or_create_pool(
142 &self,
143 namespace: &str,
144 secret_name: &str,
145 secret_key: &str,
146 ) -> Result<PgPool, ContextError> {
147 let cache_key = format!("{namespace}/{secret_name}/{secret_key}");
148
149 let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
151 kube::Api::namespaced(self.kube_client.clone(), namespace);
152
153 let secret =
154 secrets_api
155 .get(secret_name)
156 .await
157 .map_err(|err| ContextError::SecretFetch {
158 name: secret_name.to_string(),
159 namespace: namespace.to_string(),
160 source: err,
161 })?;
162
163 let resource_version = secret.metadata.resource_version.clone();
164
165 {
167 let cache = self.pool_cache.read().await;
168 if let Some(cached) = cache.get(&cache_key)
169 && cached.resource_version == resource_version
170 {
171 return Ok(cached.pool.clone());
172 }
173 }
174
175 let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
176 name: secret_name.to_string(),
177 key: secret_key.to_string(),
178 })?;
179
180 let url_bytes = data
181 .get(secret_key)
182 .ok_or_else(|| ContextError::SecretMissing {
183 name: secret_name.to_string(),
184 key: secret_key.to_string(),
185 })?;
186
187 let database_url =
188 String::from_utf8(url_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
189 name: secret_name.to_string(),
190 key: secret_key.to_string(),
191 })?;
192
193 let pool = PgPoolOptions::new()
197 .max_connections(POOL_MAX_CONNECTIONS)
198 .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
199 .connect(&database_url)
200 .await
201 .map_err(|err| ContextError::DatabaseConnect { source: err })?;
202
203 {
205 let mut cache = self.pool_cache.write().await;
206 cache.insert(
207 cache_key,
208 CachedPool {
209 resource_version,
210 pool: pool.clone(),
211 },
212 );
213 }
214
215 Ok(pool)
216 }
217
218 pub async fn fetch_secret_value(
222 &self,
223 namespace: &str,
224 secret_name: &str,
225 secret_key: &str,
226 ) -> Result<String, ContextError> {
227 let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
228 kube::Api::namespaced(self.kube_client.clone(), namespace);
229
230 let secret =
231 secrets_api
232 .get(secret_name)
233 .await
234 .map_err(|err| ContextError::SecretFetch {
235 name: secret_name.to_string(),
236 namespace: namespace.to_string(),
237 source: err,
238 })?;
239
240 let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
241 name: secret_name.to_string(),
242 key: secret_key.to_string(),
243 })?;
244
245 let value_bytes = data
246 .get(secret_key)
247 .ok_or_else(|| ContextError::SecretMissing {
248 name: secret_name.to_string(),
249 key: secret_key.to_string(),
250 })?;
251
252 String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
253 name: secret_name.to_string(),
254 key: secret_key.to_string(),
255 })
256 }
257
258 pub async fn evict_pool(&self, namespace: &str, secret_name: &str, secret_key: &str) {
260 let cache_key = format!("{namespace}/{secret_name}/{secret_key}");
261 let mut cache = self.pool_cache.write().await;
262 cache.remove(&cache_key);
263 }
264}
265
/// Errors produced while resolving Secrets and database connections.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call to read the Secret failed.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },

    /// The Secret exists but has no data, lacks the requested key, or the
    /// value is not valid UTF-8.
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },

    /// Establishing the sqlx connection pool failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },
}
282
283impl ContextError {
284 pub fn is_secret_fetch_non_transient(&self) -> bool {
286 matches!(
287 self,
288 ContextError::SecretFetch {
289 source: kube::Error::Api(response),
290 ..
291 } if (400..500).contains(&response.code) && response.code != 429
292 )
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    // The cache key format used by get_or_create_pool/evict_pool.
    #[test]
    fn pool_cache_key_format() {
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    // 404 from the API server is a client error -> non-transient.
    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    // 403 (RBAC denial) is likewise non-transient.
    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    // 5xx server errors should remain transient (retryable).
    #[test]
    fn secret_fetch_server_error_remains_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };

        assert!(!error.is_secret_fetch_non_transient());
    }

    // An uncontended database should lock on the first attempt.
    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    // A second lock attempt on the same database, while the first guard is
    // alive, must fail.
    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    // Locks on distinct database identities must not interfere.
    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    // Dropping the guard (end of inner scope) must make the lock available
    // again; exercises the Drop impl's try_lock fast path.
    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };

        {
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
        }

        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    // Two tasks racing for the same database: exactly one may win. The 10 ms
    // head start lets task 1 acquire before task 2 attempts.
    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));

        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);

        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });

        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();

        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    // Minimal stand-in mirroring OperatorContext::try_lock_database, so the
    // locking protocol can be tested without constructing a kube client.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }

    impl OperatorContextLockHelper {
        // Same insert-if-absent protocol as the production method, minus the
        // tracing calls.
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    // 50 tasks released simultaneously (via barrier) against one database:
    // exactly one acquisition may succeed.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    // 50 tasks, each targeting a distinct database: all must succeed.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    // 20 tasks repeatedly retrying one database; every task must eventually
    // get its turn, proving drop-release works under sustained contention
    // (including the deferred-release path in the guard's Drop impl).
    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                panic!("task failed to acquire lock after 100 retries");
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}