use std::collections::HashMap;
use std::sync::Arc;
use kube::runtime::events::Recorder;
use std::time::Duration;
use sqlx::postgres::{PgPool, PgPoolOptions};
use tokio::sync::{Mutex, RwLock};
use crate::crd::{ConnectionSpec, SecretKeySelector};
use crate::observability::OperatorObservability;
/// Upper bound on connections each cached `PgPool` may open.
const POOL_MAX_CONNECTIONS: u32 = 5;
/// How long `PgPool` waits to acquire a connection before erroring.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;
// Compile-time guard: keep the pool at >= 2 connections (presumably so
// concurrent operations on the same pool cannot starve each other — the
// original rationale is not recorded here; confirm before changing).
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);
/// A pooled database connection plus the Secret metadata it was built from,
/// used by `get_or_create_pool` to detect credential rotation.
#[derive(Clone)]
struct CachedPool {
    /// resourceVersion of the Secret named by `secretRef`, when that form
    /// of the connection spec was used; `None` for params-based specs.
    resource_version: Option<String>,
    /// "name=resourceVersion" fingerprint over every Secret referenced by
    /// `params` (empty string when params use only literals); `None` when
    /// the spec used `secretRef` instead.
    secret_fingerprint: Option<String>,
    pool: PgPool,
}
/// RAII guard for the per-database in-memory reconcile lock: the presence of
/// `key` in `locks` means the lock is held, and dropping the guard removes
/// the entry (see the `Drop` impl).
pub struct DatabaseLockGuard {
    /// Identity of the locked database (the map key).
    key: String,
    /// Shared lock table, also owned by `OperatorContext`.
    locks: Arc<Mutex<HashMap<String, ()>>>,
}
impl Drop for DatabaseLockGuard {
    /// Releases the in-memory lock. `Drop` is synchronous but the lock table
    /// uses an async `tokio::sync::Mutex`, so three strategies are tried in
    /// order: a non-blocking `try_lock`, a spawned background task (when the
    /// guard is dropped inside a tokio runtime), and a blocking lock (when
    /// dropped outside any runtime).
    fn drop(&mut self) {
        if let Ok(mut map) = self.locks.try_lock() {
            // Fast path: nobody else is holding the mutex right now.
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                // Inside a runtime we must not block the executor thread, so
                // defer the removal to a background task. Until that task
                // runs, the lock appears held and try_lock_database returns
                // None for this database.
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                // No runtime (e.g. a synchronous shutdown path): blocking
                // here cannot stall an executor thread. `blocking_lock`
                // would panic inside a runtime, but `try_current` just
                // failed, so we are provably outside one.
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
/// Shared state handed to every reconcile invocation. Cloning is cheap: all
/// mutable state lives behind `Arc`s, so clones share the same caches.
#[derive(Clone)]
pub struct OperatorContext {
    pub kube_client: kube::Client,
    pub event_recorder: Recorder,
    /// Connection pools keyed by `ConnectionSpec::cache_key(namespace)`,
    /// revalidated against Secret resourceVersions before reuse.
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    /// In-memory "one reconcile per database" lock table; entries are
    /// inserted by `try_lock_database` and removed when the guard drops.
    database_locks: Arc<Mutex<HashMap<String, ()>>>,
    pub observability: OperatorObservability,
}
impl OperatorContext {
pub fn new(
kube_client: kube::Client,
observability: OperatorObservability,
event_recorder: Recorder,
) -> Self {
Self {
kube_client,
event_recorder,
pool_cache: Arc::new(RwLock::new(HashMap::new())),
observability,
database_locks: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
let mut locks = self.database_locks.lock().await;
if locks.contains_key(database_identity) {
tracing::info!(
database = %database_identity,
"in-memory database lock contention — another reconcile is in progress"
);
return None;
}
locks.insert(database_identity.to_string(), ());
tracing::debug!(database = %database_identity, "acquired in-memory database lock");
Some(DatabaseLockGuard {
key: database_identity.to_string(),
locks: Arc::clone(&self.database_locks),
})
}
async fn resolve_param(
&self,
namespace: &str,
literal: &Option<String>,
secret: &Option<SecretKeySelector>,
) -> Result<Option<String>, ContextError> {
if let Some(val) = literal {
return Ok(Some(val.clone()));
}
if let Some(sel) = secret {
return Ok(Some(
self.fetch_secret_value(namespace, &sel.name, &sel.key)
.await?,
));
}
Ok(None)
}
pub async fn resolve_connection_url(
&self,
namespace: &str,
connection: &ConnectionSpec,
) -> Result<String, ContextError> {
if let Some(ref secret_ref) = connection.secret_ref {
self.fetch_secret_value(
namespace,
&secret_ref.name,
connection.effective_secret_key(),
)
.await
} else if let Some(ref params) = connection.params {
let host = self
.resolve_param(namespace, ¶ms.host, ¶ms.host_secret)
.await?
.ok_or_else(|| ContextError::EmptyResolvedValue {
field: "host".to_string(),
})?;
if host.trim().is_empty() {
return Err(ContextError::EmptyResolvedValue {
field: "host".to_string(),
});
}
let port_str = params.port.map(|p| p.to_string());
let port = self
.resolve_param(namespace, &port_str, ¶ms.port_secret)
.await?
.unwrap_or_else(|| "5432".to_string());
if port.trim().is_empty() {
return Err(ContextError::EmptyResolvedValue {
field: "port".to_string(),
});
}
let dbname = self
.resolve_param(namespace, ¶ms.dbname, ¶ms.dbname_secret)
.await?
.ok_or_else(|| ContextError::EmptyResolvedValue {
field: "dbname".to_string(),
})?;
if dbname.trim().is_empty() {
return Err(ContextError::EmptyResolvedValue {
field: "dbname".to_string(),
});
}
let username = self
.resolve_param(namespace, ¶ms.username, ¶ms.username_secret)
.await?
.ok_or_else(|| ContextError::EmptyResolvedValue {
field: "username".to_string(),
})?;
if username.trim().is_empty() {
return Err(ContextError::EmptyResolvedValue {
field: "username".to_string(),
});
}
let password = self
.resolve_param(namespace, ¶ms.password, ¶ms.password_secret)
.await?
.ok_or_else(|| ContextError::EmptyResolvedValue {
field: "password".to_string(),
})?;
if password.trim().is_empty() {
return Err(ContextError::EmptyResolvedValue {
field: "password".to_string(),
});
}
use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
let encoded_username = utf8_percent_encode(&username, NON_ALPHANUMERIC).to_string();
let encoded_password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string();
let mut url = format!(
"postgresql://{encoded_username}:{encoded_password}@{host}:{port}/{dbname}"
);
if let Some(ssl_mode) = self
.resolve_param(namespace, ¶ms.ssl_mode, ¶ms.ssl_mode_secret)
.await?
{
if !crate::crd::VALID_SSL_MODES.contains(&ssl_mode.as_str()) {
return Err(ContextError::InvalidResolvedSslMode { value: ssl_mode });
}
url.push_str("?sslmode=");
url.push_str(&ssl_mode);
}
Ok(url)
} else {
Err(ContextError::SecretMissing {
name: "connection".to_string(),
key: "neither secretRef nor params is set".to_string(),
})
}
}
pub async fn get_or_create_pool(
&self,
namespace: &str,
connection: &ConnectionSpec,
) -> Result<PgPool, ContextError> {
let cache_key = connection.cache_key(namespace);
let (resource_version, secret_fingerprint) =
if let Some(ref secret_ref) = connection.secret_ref {
let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
kube::Api::namespaced(self.kube_client.clone(), namespace);
let secret = secrets_api.get(&secret_ref.name).await.map_err(|err| {
ContextError::SecretFetch {
name: secret_ref.name.clone(),
namespace: namespace.to_string(),
source: err,
}
})?;
(secret.metadata.resource_version, None)
} else if connection.params.is_some() {
let mut secret_names = std::collections::BTreeSet::new();
connection.collect_secret_names(&mut secret_names);
if secret_names.is_empty() {
(None, Some(String::new()))
} else {
let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
kube::Api::namespaced(self.kube_client.clone(), namespace);
let mut fingerprint_parts = Vec::new();
for name in &secret_names {
let secret = secrets_api.get(name).await.map_err(|err| {
ContextError::SecretFetch {
name: name.clone(),
namespace: namespace.to_string(),
source: err,
}
})?;
let rv = secret
.metadata
.resource_version
.unwrap_or_else(|| "unknown".to_string());
fingerprint_parts.push(format!("{name}={rv}"));
}
(None, Some(fingerprint_parts.join(",")))
}
} else {
(None, None)
};
{
let cache = self.pool_cache.read().await;
if let Some(cached) = cache.get(&cache_key) {
let version_matches = match (&resource_version, &cached.resource_version) {
(Some(current), Some(cached_rv)) => current == cached_rv,
_ => true,
};
let fingerprint_matches = match (&secret_fingerprint, &cached.secret_fingerprint) {
(Some(current), Some(cached_fp)) => current == cached_fp,
(None, None) => true,
_ => false,
};
if version_matches && fingerprint_matches {
return Ok(cached.pool.clone());
}
}
}
let database_url = self.resolve_connection_url(namespace, connection).await?;
let pool = PgPoolOptions::new()
.max_connections(POOL_MAX_CONNECTIONS)
.acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
.connect(&database_url)
.await
.map_err(|err| ContextError::DatabaseConnect { source: err })?;
{
let mut cache = self.pool_cache.write().await;
cache.insert(
cache_key,
CachedPool {
resource_version,
secret_fingerprint,
pool: pool.clone(),
},
);
}
Ok(pool)
}
pub async fn fetch_secret_value(
&self,
namespace: &str,
secret_name: &str,
secret_key: &str,
) -> Result<String, ContextError> {
let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
kube::Api::namespaced(self.kube_client.clone(), namespace);
let secret =
secrets_api
.get(secret_name)
.await
.map_err(|err| ContextError::SecretFetch {
name: secret_name.to_string(),
namespace: namespace.to_string(),
source: err,
})?;
let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
name: secret_name.to_string(),
key: secret_key.to_string(),
})?;
let value_bytes = data
.get(secret_key)
.ok_or_else(|| ContextError::SecretMissing {
name: secret_name.to_string(),
key: secret_key.to_string(),
})?;
String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
name: secret_name.to_string(),
key: secret_key.to_string(),
})
}
pub async fn evict_pool(&self, namespace: &str, connection: &ConnectionSpec) {
let cache_key = connection.cache_key(namespace);
let mut cache = self.pool_cache.write().await;
cache.remove(&cache_key);
}
}
/// Errors produced while resolving connection details or building pools.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call to read a Secret failed (network error,
    /// NotFound, Forbidden, ...). See `is_secret_fetch_non_transient` for
    /// retry classification.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },
    /// The Secret exists but has no data or lacks the requested key. Also
    /// reused for a non-UTF-8 value and for a connection spec that sets
    /// neither `secretRef` nor `params`.
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },
    /// sqlx failed to establish the initial pool connection.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },
    /// A required connection parameter resolved to nothing or to blank text.
    #[error("connection param \"{field}\" resolved to an empty or whitespace-only value")]
    EmptyResolvedValue { field: String },
    /// `sslMode` resolved to a value outside the accepted libpq modes.
    #[error(
        "connection param sslMode resolved to invalid value \"{value}\" (expected one of: disable, allow, prefer, require, verify-ca, verify-full)"
    )]
    InvalidResolvedSslMode { value: String },
}
impl ContextError {
    /// Classifies a `SecretFetch` failure as permanent.
    ///
    /// Returns `true` only for an API client error (HTTP 4xx other than
    /// 429 Too Many Requests), i.e. a failure that retrying will not fix —
    /// e.g. the Secret does not exist (404) or access is forbidden (403).
    /// Every other error, including server-side 5xx responses, is treated
    /// as transient.
    pub fn is_secret_fetch_non_transient(&self) -> bool {
        match self {
            ContextError::SecretFetch {
                source: kube::Error::Api(response),
                ..
            } => {
                let code = response.code;
                (400..500).contains(&code) && code != 429
            }
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    //! Error-classification and in-memory database-lock tests. The lock
    //! tests drive `DatabaseLockGuard` through `OperatorContextLockHelper`,
    //! a minimal stand-in mirroring `OperatorContext::try_lock_database`
    //! that avoids needing a Kubernetes client.
    use super::*;
    // Documents the expected cache-key layout: namespace/secret-name/key.
    #[test]
    fn pool_cache_key_format() {
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }
    // 404 from the API is a client error: retrying cannot fix it.
    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };
        assert!(error.is_secret_fetch_non_transient());
    }
    // 403 is likewise a client error (RBAC misconfiguration).
    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };
        assert!(error.is_secret_fetch_non_transient());
    }
    // 5xx responses are server-side and may succeed on retry.
    #[test]
    fn secret_fetch_server_error_remains_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };
        assert!(!error.is_secret_fetch_non_transient());
    }
    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }
    // While the first guard is alive, a second lock attempt must fail.
    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }
    // Locks are keyed per database identity, so distinct keys don't contend.
    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }
    // Relies on the Drop impl's try_lock fast path: with nothing else
    // holding the mutex, release happens synchronously at scope exit.
    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };
        {
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
        }
        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }
    // Timing-based: task 1 holds the lock for 50ms; task 2 starts at ~10ms
    // and must observe contention. Exactly one acquisition expected.
    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);
        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });
        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();
        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }
    /// Test double for `OperatorContext`: holds only the lock table so the
    /// locking protocol can be exercised without a kube client.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }
    impl OperatorContextLockHelper {
        /// Mirror of `OperatorContext::try_lock_database` minus the tracing;
        /// keep the two in sync if the locking protocol changes.
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }
    // 50 tasks race on one key behind a barrier; holders sleep so losers
    // observe contention. Exactly one acquisition expected.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }
    // 50 tasks each lock a distinct key; all should succeed.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }
    // Acquire/release churn: each of 20 tasks retries up to 100 times; every
    // task must eventually win exactly once, proving drop reliably releases.
    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                panic!("task failed to acquire lock after 100 retries");
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}