use std::collections::HashMap;
use std::sync::Arc;
use kube::runtime::events::Recorder;
use std::time::Duration;
use sqlx::postgres::{PgPool, PgPoolOptions};
use tokio::sync::{Mutex, RwLock};
use crate::observability::OperatorObservability;
/// Upper bound on open connections for each cached Postgres pool.
const POOL_MAX_CONNECTIONS: u32 = 5;
/// How long acquiring a connection from a pool may wait before erroring.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;
// Compile-time guard: refuse builds that shrink the pool below two
// connections (presumably so one long-running statement cannot starve all
// other work on the same pool — TODO confirm the original motivation).
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);
/// A Postgres pool cached per Secret, tagged with the Secret's
/// `resourceVersion` at build time so a rotated credential invalidates the
/// cache entry on the next lookup.
#[derive(Clone)]
struct CachedPool {
    /// `resourceVersion` of the Secret the connection URL was read from;
    /// `None` if the API server reported none.
    resource_version: Option<String>,
    /// Shared sqlx pool built from the Secret's connection URL.
    pool: PgPool,
}
/// RAII guard for the process-local, per-database reconcile lock.
/// Dropping the guard removes `key` from the shared lock map; see the
/// `Drop` impl for how release interacts with the tokio runtime.
pub struct DatabaseLockGuard {
    /// Identity of the locked database; the entry in `locks` keyed by this.
    key: String,
    /// Shared lock map; presence of a key means that database is locked.
    locks: Arc<Mutex<HashMap<String, ()>>>,
}
impl Drop for DatabaseLockGuard {
    /// Removes this guard's key from the shared lock map, picking a release
    /// strategy based on what is possible at drop time:
    /// 1. uncontended: `try_lock` succeeds and the key is removed inline;
    /// 2. mutex momentarily held while inside a tokio runtime: spawn a
    ///    background task that awaits the mutex and removes the key;
    /// 3. mutex held with no runtime available: block on the mutex — safe
    ///    because `Handle::try_current()` just failed, so this thread is not
    ///    driving an async executor (where `blocking_lock` would panic).
    fn drop(&mut self) {
        if let Ok(mut map) = self.locks.try_lock() {
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            // Prepare owned copies that can be moved into a background task.
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                // No runtime on this thread: blocking cannot stall an executor.
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
/// Shared state handed to every reconcile invocation: Kubernetes client,
/// event recorder, observability hooks, plus two process-local caches
/// (Postgres pools keyed by Secret coordinates, and per-database reconcile
/// locks). Cheap to clone — the caches are behind `Arc`s.
#[derive(Clone)]
pub struct OperatorContext {
    /// Client used for all Kubernetes API access.
    pub kube_client: kube::Client,
    /// Recorder for publishing Kubernetes Events.
    pub event_recorder: Recorder,
    /// Pools keyed by "namespace/secret-name/secret-key", tagged with the
    /// Secret revision they were built from.
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    /// In-memory reconcile locks; a present key means that database is
    /// currently being reconciled by this process.
    database_locks: Arc<Mutex<HashMap<String, ()>>>,
    /// Metrics/tracing hooks for the operator.
    pub observability: OperatorObservability,
}
impl OperatorContext {
    /// Assembles a fresh context with empty pool and lock caches.
    pub fn new(
        kube_client: kube::Client,
        observability: OperatorObservability,
        event_recorder: Recorder,
    ) -> Self {
        Self {
            kube_client,
            event_recorder,
            pool_cache: Arc::new(RwLock::new(HashMap::new())),
            observability,
            database_locks: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Attempts to acquire the in-memory reconcile lock for
    /// `database_identity`.
    ///
    /// Returns `None` when another reconcile in this process already holds
    /// the lock; the returned guard releases it on drop. The lock is purely
    /// process-local — it does not coordinate across operator replicas.
    pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
        let mut locks = self.database_locks.lock().await;
        if locks.contains_key(database_identity) {
            tracing::info!(
                database = %database_identity,
                "in-memory database lock contention — another reconcile is in progress"
            );
            return None;
        }
        locks.insert(database_identity.to_string(), ());
        tracing::debug!(database = %database_identity, "acquired in-memory database lock");
        Some(DatabaseLockGuard {
            key: database_identity.to_string(),
            locks: Arc::clone(&self.database_locks),
        })
    }

    /// Fetches the named Secret from `namespace`, mapping API failures to
    /// `ContextError::SecretFetch`.
    async fn get_secret(
        &self,
        namespace: &str,
        secret_name: &str,
    ) -> Result<k8s_openapi::api::core::v1::Secret, ContextError> {
        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
            kube::Api::namespaced(self.kube_client.clone(), namespace);
        secrets_api
            .get(secret_name)
            .await
            .map_err(|err| ContextError::SecretFetch {
                name: secret_name.to_string(),
                namespace: namespace.to_string(),
                source: err,
            })
    }

    /// Extracts `secret_key` from an already-fetched Secret and decodes it as
    /// UTF-8. Missing `data`, a missing key, and invalid UTF-8 all map to
    /// `ContextError::SecretMissing`, matching the pre-refactor behavior.
    fn decode_secret_value(
        secret: k8s_openapi::api::core::v1::Secret,
        secret_name: &str,
        secret_key: &str,
    ) -> Result<String, ContextError> {
        let missing = || ContextError::SecretMissing {
            name: secret_name.to_string(),
            key: secret_key.to_string(),
        };
        let data = secret.data.ok_or_else(missing)?;
        let value_bytes = data.get(secret_key).ok_or_else(missing)?;
        String::from_utf8(value_bytes.0.clone()).map_err(|_| missing())
    }

    /// Returns a Postgres pool for the connection URL stored under
    /// `secret_key` in `namespace/secret_name`, reusing a cached pool while
    /// the Secret's `resourceVersion` is unchanged.
    ///
    /// The Secret is re-fetched on every call so a rotated credential is
    /// noticed promptly; only the expensive pool construction is cached.
    /// NOTE(review): a stale pool replaced in the cache is dropped without an
    /// explicit `close()`; sqlx tears connections down as the last clone is
    /// dropped — confirm that is acceptable for long-lived checkouts.
    pub async fn get_or_create_pool(
        &self,
        namespace: &str,
        secret_name: &str,
        secret_key: &str,
    ) -> Result<PgPool, ContextError> {
        let cache_key = format!("{namespace}/{secret_name}/{secret_key}");
        let secret = self.get_secret(namespace, secret_name).await?;
        let resource_version = secret.metadata.resource_version.clone();
        // Fast path: reuse the pool built from this exact Secret revision.
        {
            let cache = self.pool_cache.read().await;
            if let Some(cached) = cache.get(&cache_key)
                && cached.resource_version == resource_version
            {
                return Ok(cached.pool.clone());
            }
        }
        let database_url = Self::decode_secret_value(secret, secret_name, secret_key)?;
        let pool = PgPoolOptions::new()
            .max_connections(POOL_MAX_CONNECTIONS)
            .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
            .connect(&database_url)
            .await
            .map_err(|err| ContextError::DatabaseConnect { source: err })?;
        // Publish (or replace) the cached pool for this Secret revision.
        {
            let mut cache = self.pool_cache.write().await;
            cache.insert(
                cache_key,
                CachedPool {
                    resource_version,
                    pool: pool.clone(),
                },
            );
        }
        Ok(pool)
    }

    /// Fetches `secret_key` from `namespace/secret_name` and returns its
    /// value as a UTF-8 string. Uncached — always hits the API server.
    pub async fn fetch_secret_value(
        &self,
        namespace: &str,
        secret_name: &str,
        secret_key: &str,
    ) -> Result<String, ContextError> {
        let secret = self.get_secret(namespace, secret_name).await?;
        Self::decode_secret_value(secret, secret_name, secret_key)
    }

    /// Drops the cached pool (if any) for the given Secret coordinates so the
    /// next `get_or_create_pool` rebuilds it from scratch.
    pub async fn evict_pool(&self, namespace: &str, secret_name: &str, secret_key: &str) {
        let cache_key = format!("{namespace}/{secret_name}/{secret_key}");
        let mut cache = self.pool_cache.write().await;
        cache.remove(&cache_key);
    }
}
/// Errors produced while resolving Secrets and database pools for a
/// reconcile.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call to read the Secret failed.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },
    /// The Secret exists but has no `data`, lacks the requested key, or the
    /// value under the key is not valid UTF-8.
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },
    /// Connecting the sqlx Postgres pool failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },
}
impl ContextError {
    /// Reports whether this error is a Secret fetch that failed with a
    /// non-retryable client error: any HTTP 4xx status except 429
    /// (Too Many Requests, which is worth retrying). Every other error —
    /// server-side 5xx, transport failures, non-fetch variants — is treated
    /// as transient.
    pub fn is_secret_fetch_non_transient(&self) -> bool {
        let ContextError::SecretFetch {
            source: kube::Error::Api(response),
            ..
        } = self
        else {
            return false;
        };
        let code = response.code;
        (400..500).contains(&code) && code != 429
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the cache-key layout used by `get_or_create_pool`/`evict_pool`.
    #[test]
    fn pool_cache_key_format() {
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    /// 404 from the API server is a client error → non-transient.
    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };
        assert!(error.is_secret_fetch_non_transient());
    }

    /// 403 (RBAC denial) is likewise a client error → non-transient.
    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };
        assert!(error.is_secret_fetch_non_transient());
    }

    /// Server-side 5xx must stay retryable (transient).
    #[test]
    fn secret_fetch_server_error_remains_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };
        assert!(!error.is_secret_fetch_non_transient());
    }

    /// An unheld lock can be acquired.
    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    /// A second acquisition on the same database fails while the first guard
    /// is alive.
    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    /// Locks on distinct databases do not contend.
    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    /// Dropping a guard releases the lock (inline path of `Drop` — the map
    /// mutex is free here, so `try_lock` removal succeeds synchronously).
    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };
        {
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
        }
        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    /// Two tasks racing for the same database: exactly one wins. The 10 ms
    /// stagger makes task 1 acquire first; task 2 then contends during task
    /// 1's 50 ms hold window.
    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);
        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });
        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();
        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    /// Minimal stand-in for `OperatorContext` exposing only the lock map, so
    /// locking behavior is testable without constructing a `kube::Client`.
    /// `try_lock` mirrors `OperatorContext::try_lock_database` minus tracing
    /// — NOTE(review): keep the two in sync if the locking scheme changes.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }
    impl OperatorContextLockHelper {
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    /// 50 tasks released by a barrier race for one database; exactly one may
    /// hold the lock (the winner sleeps so losers observe contention).
    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    /// 50 tasks each locking a distinct database must all succeed.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    /// 20 tasks retry-loop on one database; each must eventually acquire and
    /// release, proving guard drop frees the lock under sustained contention.
    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                panic!("task failed to acquire lock after 100 retries");
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}