1use std::collections::HashMap;
4use std::sync::Arc;
5
6use kube::runtime::events::Recorder;
7use std::time::Duration;
8
9use sqlx::postgres::{PgPool, PgPoolOptions};
10use tokio::sync::{Mutex, RwLock};
11
12use crate::crd::{ConnectionSpec, SecretKeySelector};
13use crate::observability::OperatorObservability;
14
/// Upper bound on connections held by each cached `PgPool`.
const POOL_MAX_CONNECTIONS: u32 = 5;

/// How long `Pool::acquire` waits for a free connection before failing
/// (applied via `acquire_timeout` in `get_or_create_pool`).
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;

// Compile-time guard: the pool must allow at least two connections.
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);
26
/// A pooled Postgres connection together with the Secret metadata it was
/// built from, so later lookups can detect stale credentials.
#[derive(Clone)]
struct CachedPool {
    // `resourceVersion` of the referenced Secret at pool-creation time
    // (secretRef-based connections); `None` for params-based connections.
    resource_version: Option<String>,
    // "name=resourceVersion,..." fingerprint over every Secret referenced by
    // connection params (empty string when params reference no Secrets);
    // `None` for secretRef-based connections.
    secret_fingerprint: Option<String>,
    pool: PgPool,
}
34
/// RAII guard representing process-local, exclusive ownership of reconcile
/// work for one database. Dropping the guard removes the entry from `locks`,
/// releasing the lock (see the `Drop` impl).
pub struct DatabaseLockGuard {
    // Database identity used as the map key.
    key: String,
    // Shared lock map; presence of `key` means "locked".
    locks: Arc<Mutex<HashMap<String, ()>>>,
}
44
impl Drop for DatabaseLockGuard {
    fn drop(&mut self) {
        // Fast path: the mutex is currently free, release synchronously.
        if let Ok(mut map) = self.locks.try_lock() {
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            // The mutex is held elsewhere and we cannot `.await` inside
            // `Drop`, so fall back to one of two deferred strategies.
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                // On a runtime thread: release from a background task.
                // NOTE(review): if the runtime is shutting down, the spawned
                // task may never run and the lock entry would leak — confirm
                // this is acceptable during operator shutdown.
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                // Not on a runtime thread (Handle::try_current failed), so
                // `blocking_lock` is safe here — it panics only when called
                // from within an async execution context.
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
79
/// Shared state handed to every reconcile invocation of the operator.
/// Cloning is cheap: all shared members are behind `Arc`s.
#[derive(Clone)]
pub struct OperatorContext {
    /// Client used for all Kubernetes API access (e.g. reading Secrets).
    pub kube_client: kube::Client,

    /// Recorder for publishing Kubernetes Events.
    pub event_recorder: Recorder,

    // PgPools keyed by `ConnectionSpec::cache_key(namespace)`; entries carry
    // Secret metadata so stale credentials trigger a rebuild
    // (see `get_or_create_pool`).
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    // In-process per-database locks; see `try_lock_database`.
    database_locks: Arc<Mutex<HashMap<String, ()>>>,

    /// Observability handle (defined in `crate::observability`).
    pub observability: OperatorObservability,
}
101
impl OperatorContext {
    /// Constructs a context with an empty pool cache and an empty lock map.
    pub fn new(
        kube_client: kube::Client,
        observability: OperatorObservability,
        event_recorder: Recorder,
    ) -> Self {
        Self {
            kube_client,
            event_recorder,
            pool_cache: Arc::new(RwLock::new(HashMap::new())),
            observability,
            database_locks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Attempts to take the in-memory lock for `database_identity`.
    ///
    /// Returns `Some(guard)` on success; dropping the guard releases the
    /// lock. Returns `None` when another reconcile currently holds it —
    /// callers are expected to requeue rather than wait.
    pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
        let mut locks = self.database_locks.lock().await;
        if locks.contains_key(database_identity) {
            tracing::info!(
                database = %database_identity,
                "in-memory database lock contention — another reconcile is in progress"
            );
            return None;
        }
        locks.insert(database_identity.to_string(), ());
        tracing::debug!(database = %database_identity, "acquired in-memory database lock");
        Some(DatabaseLockGuard {
            key: database_identity.to_string(),
            locks: Arc::clone(&self.database_locks),
        })
    }

    /// Resolves a single connection parameter: an inline literal wins when
    /// present, otherwise the referenced Secret key is fetched.
    /// `Ok(None)` means neither source was provided.
    async fn resolve_param(
        &self,
        namespace: &str,
        literal: &Option<String>,
        secret: &Option<SecretKeySelector>,
    ) -> Result<Option<String>, ContextError> {
        if let Some(val) = literal {
            return Ok(Some(val.clone()));
        }
        if let Some(sel) = secret {
            return Ok(Some(
                self.fetch_secret_value(namespace, &sel.name, &sel.key)
                    .await?,
            ));
        }
        Ok(None)
    }

    /// Builds a `postgresql://` connection URL for `connection`.
    ///
    /// Two modes, checked in order:
    /// 1. `secretRef`: the whole URL is read verbatim from the Secret.
    /// 2. `params`: host/port/dbname/username/password are resolved
    ///    individually (literal or Secret-backed), validated non-empty,
    ///    credentials percent-encoded, and an optional validated `sslmode`
    ///    query parameter is appended.
    ///
    /// # Errors
    /// `SecretFetch`/`SecretMissing` from Secret resolution,
    /// `EmptyResolvedValue` for blank required params,
    /// `InvalidResolvedSslMode` for an unrecognized sslMode, and
    /// `SecretMissing` (repurposed) when neither mode is configured.
    pub async fn resolve_connection_url(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<String, ContextError> {
        if let Some(ref secret_ref) = connection.secret_ref {
            self.fetch_secret_value(
                namespace,
                &secret_ref.name,
                connection.effective_secret_key(),
            )
            .await
        } else if let Some(ref params) = connection.params {
            // host: required, must be non-blank.
            let host = self
                .resolve_param(namespace, &params.host, &params.host_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "host".to_string(),
                })?;
            if host.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "host".to_string(),
                });
            }

            // port: optional, defaults to Postgres' standard 5432.
            let port_str = params.port.map(|p| p.to_string());
            let port = self
                .resolve_param(namespace, &port_str, &params.port_secret)
                .await?
                .unwrap_or_else(|| "5432".to_string());
            if port.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "port".to_string(),
                });
            }

            // dbname: required, must be non-blank.
            let dbname = self
                .resolve_param(namespace, &params.dbname, &params.dbname_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "dbname".to_string(),
                })?;
            if dbname.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "dbname".to_string(),
                });
            }

            // username: required, must be non-blank.
            let username = self
                .resolve_param(namespace, &params.username, &params.username_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "username".to_string(),
                })?;
            if username.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "username".to_string(),
                });
            }

            // password: required, must be non-blank.
            let password = self
                .resolve_param(namespace, &params.password, &params.password_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "password".to_string(),
                })?;
            if password.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "password".to_string(),
                });
            }

            // Percent-encode credentials so characters like '@', ':' or '/'
            // in usernames/passwords cannot corrupt the URL structure.
            use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
            let encoded_username = utf8_percent_encode(&username, NON_ALPHANUMERIC).to_string();
            let encoded_password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string();

            let mut url = format!(
                "postgresql://{encoded_username}:{encoded_password}@{host}:{port}/{dbname}"
            );

            // Optional sslmode, validated against the CRD's allow-list
            // before being appended as a query parameter.
            if let Some(ssl_mode) = self
                .resolve_param(namespace, &params.ssl_mode, &params.ssl_mode_secret)
                .await?
            {
                if !crate::crd::VALID_SSL_MODES.contains(&ssl_mode.as_str()) {
                    return Err(ContextError::InvalidResolvedSslMode { value: ssl_mode });
                }
                url.push_str("?sslmode=");
                url.push_str(&ssl_mode);
            }

            Ok(url)
        } else {
            // Misconfigured spec: no way to resolve a URL at all.
            Err(ContextError::SecretMissing {
                name: "connection".to_string(),
                key: "neither secretRef nor params is set".to_string(),
            })
        }
    }

    /// Returns a Postgres pool for `connection`, reusing a cached pool when
    /// the backing Secret(s) have not changed since it was built.
    ///
    /// Staleness detection:
    /// - secretRef connections: compare the Secret's `resourceVersion`.
    /// - params connections: compare a fingerprint built from the
    ///   resourceVersions of every referenced Secret (empty string when the
    ///   params reference no Secrets).
    ///
    /// NOTE(review): two concurrent callers can both miss the read-locked
    /// cache check and each build a pool; the later insert wins and the
    /// earlier pool is simply dropped with its clones. Confirm this
    /// duplicate-connect window is acceptable.
    pub async fn get_or_create_pool(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<PgPool, ContextError> {
        let cache_key = connection.cache_key(namespace);

        // Gather the freshness markers for the cache comparison below.
        let (resource_version, secret_fingerprint) =
            if let Some(ref secret_ref) = connection.secret_ref {
                let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                    kube::Api::namespaced(self.kube_client.clone(), namespace);
                let secret = secrets_api.get(&secret_ref.name).await.map_err(|err| {
                    ContextError::SecretFetch {
                        name: secret_ref.name.clone(),
                        namespace: namespace.to_string(),
                        source: err,
                    }
                })?;
                (secret.metadata.resource_version, None)
            } else if connection.params.is_some() {
                // BTreeSet gives a stable iteration order, so the
                // fingerprint string is deterministic across calls.
                let mut secret_names = std::collections::BTreeSet::new();
                connection.collect_secret_names(&mut secret_names);

                if secret_names.is_empty() {
                    // Pure-literal params: fingerprint is the empty string.
                    (None, Some(String::new()))
                } else {
                    let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                        kube::Api::namespaced(self.kube_client.clone(), namespace);
                    let mut fingerprint_parts = Vec::new();
                    for name in &secret_names {
                        let secret = secrets_api.get(name).await.map_err(|err| {
                            ContextError::SecretFetch {
                                name: name.clone(),
                                namespace: namespace.to_string(),
                                source: err,
                            }
                        })?;
                        let rv = secret
                            .metadata
                            .resource_version
                            .unwrap_or_else(|| "unknown".to_string());
                        fingerprint_parts.push(format!("{name}={rv}"));
                    }
                    (None, Some(fingerprint_parts.join(",")))
                }
            } else {
                (None, None)
            };

        // Fast path under the read lock: reuse a cached pool when both
        // freshness markers still match.
        {
            let cache = self.pool_cache.read().await;
            if let Some(cached) = cache.get(&cache_key) {
                // Missing versions on either side are treated as a match;
                // only a concrete mismatch forces a rebuild.
                let version_matches = match (&resource_version, &cached.resource_version) {
                    (Some(current), Some(cached_rv)) => current == cached_rv,
                    _ => true,
                };
                // Fingerprints are stricter: a Some/None mix never matches.
                let fingerprint_matches = match (&secret_fingerprint, &cached.secret_fingerprint) {
                    (Some(current), Some(cached_fp)) => current == cached_fp,
                    (None, None) => true,
                    _ => false,
                };
                if version_matches && fingerprint_matches {
                    return Ok(cached.pool.clone());
                }
            }
        }

        // Cache miss or stale entry: resolve credentials and connect fresh.
        let database_url = self.resolve_connection_url(namespace, connection).await?;

        let pool = PgPoolOptions::new()
            .max_connections(POOL_MAX_CONNECTIONS)
            .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
            .connect(&database_url)
            .await
            .map_err(|err| ContextError::DatabaseConnect { source: err })?;

        // Publish (or replace) the cache entry under the write lock.
        {
            let mut cache = self.pool_cache.write().await;
            cache.insert(
                cache_key,
                CachedPool {
                    resource_version,
                    secret_fingerprint,
                    pool: pool.clone(),
                },
            );
        }

        Ok(pool)
    }

    /// Reads key `secret_key` from Secret `secret_name` in `namespace` and
    /// returns the value as UTF-8 text.
    ///
    /// # Errors
    /// - `SecretFetch` when the Kubernetes API call fails.
    /// - `SecretMissing` when the Secret has no data, lacks the key, or the
    ///   value is not valid UTF-8 (the UTF-8 failure reuses this variant).
    pub async fn fetch_secret_value(
        &self,
        namespace: &str,
        secret_name: &str,
        secret_key: &str,
    ) -> Result<String, ContextError> {
        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
            kube::Api::namespaced(self.kube_client.clone(), namespace);

        let secret =
            secrets_api
                .get(secret_name)
                .await
                .map_err(|err| ContextError::SecretFetch {
                    name: secret_name.to_string(),
                    namespace: namespace.to_string(),
                    source: err,
                })?;

        let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
            name: secret_name.to_string(),
            key: secret_key.to_string(),
        })?;

        let value_bytes = data
            .get(secret_key)
            .ok_or_else(|| ContextError::SecretMissing {
                name: secret_name.to_string(),
                key: secret_key.to_string(),
            })?;

        // ByteString payload must decode as UTF-8 to be usable as a URL or
        // connection parameter.
        String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
            name: secret_name.to_string(),
            key: secret_key.to_string(),
        })
    }

    /// Removes the cached pool for `connection`, forcing the next
    /// `get_or_create_pool` call to reconnect. The pool itself is not
    /// explicitly closed here — it goes away once all clones are dropped.
    pub async fn evict_pool(&self, namespace: &str, connection: &ConnectionSpec) {
        let cache_key = connection.cache_key(namespace);
        let mut cache = self.pool_cache.write().await;
        cache.remove(&cache_key);
    }
}
424
/// Errors produced while resolving connection details and building pools.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call to read a Secret failed.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },

    /// The Secret exists but lacks the requested key (also used for a
    /// Secret without data, non-UTF-8 values, and a spec with neither
    /// secretRef nor params — see the call sites).
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },

    /// Connecting to Postgres via sqlx failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },

    /// A required connection param resolved to a blank value.
    #[error("connection param \"{field}\" resolved to an empty or whitespace-only value")]
    EmptyResolvedValue { field: String },

    /// The resolved sslMode is not in `crate::crd::VALID_SSL_MODES`.
    #[error(
        "connection param sslMode resolved to invalid value \"{value}\" (expected one of: disable, allow, prefer, require, verify-ca, verify-full)"
    )]
    InvalidResolvedSslMode { value: String },
}
449
450impl ContextError {
451 pub fn is_secret_fetch_non_transient(&self) -> bool {
453 matches!(
454 self,
455 ContextError::SecretFetch {
456 source: kube::Error::Api(response),
457 ..
458 } if (400..500).contains(&response.code) && response.code != 429
459 )
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the "<namespace>/<name>/<key>" shape assumed for pool cache keys.
    // NOTE(review): this builds the string locally rather than calling
    // ConnectionSpec::cache_key — keep it in sync with crate::crd.
    #[test]
    fn pool_cache_key_format() {
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    // 404 on Secret fetch is a client error: retrying cannot help.
    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    // 403 is likewise non-transient: RBAC must change before a retry works.
    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    // 5xx responses are server-side and should remain retryable.
    #[test]
    fn secret_fetch_server_error_remains_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };

        assert!(!error.is_secret_fetch_non_transient());
    }

    // A free database must be lockable on the first attempt.
    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    // Holding a guard must cause a second lock on the same key to fail.
    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    // Locks are keyed per database: distinct keys do not contend.
    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    // Dropping the guard (end of inner scope) must free the key for reuse.
    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };

        {
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
        }

        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    // Two tasks racing for the same key: exactly one may win. The sleeps
    // stagger the attempts so task 2 tries while task 1 still holds the lock.
    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));

        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);

        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                // Hold the lock long enough for task 2's attempt to overlap.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });

        // Give task 1 a head start before launching the contender.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });

        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();

        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    // Minimal stand-in mirroring OperatorContext::try_lock_database so the
    // locking logic can be exercised without a kube client or recorder.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }

    impl OperatorContextLockHelper {
        // Same insert-if-absent protocol as the real try_lock_database,
        // minus the tracing.
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    // 50 tasks hit the same key simultaneously (barrier-released); winners
    // hold their guard past every other attempt, so exactly one succeeds.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                // All tasks start their attempt at (roughly) the same time.
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    // Keep the guard alive while the losers fail.
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    // 50 tasks each locking a distinct key must all succeed.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                // Per-task key: no contention between tasks.
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    // Acquire/release churn: every task retries until it wins the lock once,
    // proving guards are reliably released (including the deferred-release
    // path in Drop) so all tasks eventually succeed.
    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                // Retry with a short backoff; 100 attempts is generous for
                // 20 tasks each holding the lock ~1ms.
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                panic!("task failed to acquire lock after 100 retries");
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}