use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use futures::future::{BoxFuture, FutureExt};
use kube::runtime::events::Recorder;
use serde::{Deserialize, Serialize};
use sqlx::postgres::{PgPool, PgPoolOptions};
use tokio::sync::{Mutex, RwLock};
use crate::crd::{ConnectionAuth, ConnectionSpec, SecretKeySelector};
use crate::observability::OperatorObservability;
/// Upper bound on connections held by each cached Postgres pool.
const POOL_MAX_CONNECTIONS: u32 = 5;
/// How long `Pool::acquire` may wait for a free connection before erroring.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;
// Compile-time guard: the pool must allow at least 2 connections.
// NOTE(review): rationale is not documented here — presumably one connection
// may be pinned (e.g. for locking) while another does work; confirm.
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);
/// GCE metadata server endpoint that mints OAuth access tokens for the
/// node's default service account.
const GCP_METADATA_TOKEN_ENDPOINT: &str =
    "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";
/// Scope requested for the source token used to call the IAM Credentials API.
const GCP_IAM_CREDENTIALS_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
/// Refresh cached tokens this many seconds before their actual expiry.
const GCP_TOKEN_CACHE_SKEW_SECS: u64 = 300;
/// Lifetime requested for impersonated access tokens (one hour).
const GCP_IMPERSONATED_TOKEN_LIFETIME_SECS: u64 = 3600;
/// Timeout applied to every GCP auth HTTP request.
const GCP_AUTH_HTTP_TIMEOUT_SECS: u64 = 10;
/// A cached database pool plus the freshness markers it was built from.
#[derive(Clone)]
struct CachedPool {
    // resourceVersion of the Secret backing a `secretRef` connection;
    // `None` for params-style connections.
    resource_version: Option<String>,
    // Joined "name=resourceVersion" fingerprint over all Secrets referenced
    // by a params-style connection; `None` for `secretRef` connections.
    secret_fingerprint: Option<String>,
    // Expiry of the GCP access token embedded in the pool's connection URL,
    // if GCP auth was used; `None` for static credentials.
    token_expires_at: Option<SystemTime>,
    // The live sqlx connection pool.
    pool: PgPool,
}
/// A fully resolved connection URL plus the expiry of any embedded token.
struct ResolvedConnectionUrl {
    // Complete `postgresql://...` URL ready to hand to sqlx.
    database_url: String,
    // Expiry of the GCP token used as the password; `None` when the
    // credentials are static.
    token_expires_at: Option<SystemTime>,
}
/// An OAuth access token together with its absolute expiry time.
#[derive(Clone)]
struct GcpAccessToken {
    // The raw bearer token string.
    token: String,
    // Absolute instant at which the token stops being valid.
    expires_at: SystemTime,
}
/// Abstraction over fetching a GCP access token for a `ConnectionAuth` spec.
trait GcpAccessTokenProvider: Send + Sync {
    /// Fetches a token satisfying `auth`. Returns a boxed future so the
    /// trait remains object-safe (it is stored as `Arc<dyn …>` in the
    /// operator context).
    fn fetch_token<'a>(
        &'a self,
        auth: &'a ConnectionAuth,
    ) -> BoxFuture<'a, Result<GcpAccessToken, ContextError>>;
}
/// Token provider backed by the GCE metadata server, optionally followed by
/// an IAM Credentials impersonation exchange.
#[derive(Clone)]
struct MetadataGcpAccessTokenProvider {
    // Dedicated HTTP client for auth calls (built with no proxy and a short
    // timeout; see the `Default` impl).
    client: reqwest::Client,
}
// Builds the provider with a purpose-made HTTP client: proxies are bypassed
// (the metadata server is reached directly) and a short timeout keeps auth
// from stalling reconciles.
impl Default for MetadataGcpAccessTokenProvider {
    fn default() -> Self {
        let client = reqwest::Client::builder()
            .no_proxy()
            .timeout(Duration::from_secs(GCP_AUTH_HTTP_TIMEOUT_SECS))
            .build()
            // The builder only fails on TLS/backend misconfiguration, which
            // would be a bug in this binary — panic loudly at startup.
            .expect("GCP auth HTTP client should build");
        Self { client }
    }
}
impl GcpAccessTokenProvider for MetadataGcpAccessTokenProvider {
    // Dispatches on the auth spec: with a target service account configured
    // the impersonation flow runs, otherwise the metadata server is queried
    // directly for the requested scope.
    fn fetch_token<'a>(
        &'a self,
        auth: &'a ConnectionAuth,
    ) -> BoxFuture<'a, Result<GcpAccessToken, ContextError>> {
        async move {
            let scope = auth.gcp_scope();
            match auth.gcp_impersonate_service_account() {
                Some(target) => self.fetch_impersonated_access_token(target, scope).await,
                None => self.fetch_metadata_access_token(scope).await,
            }
        }
        .boxed()
    }
}
impl MetadataGcpAccessTokenProvider {
    /// Requests an access token for `scope` from the GCE metadata server.
    ///
    /// Validates that the response carries a non-empty `access_token` and a
    /// non-zero `expires_in` before converting the relative expiry into an
    /// absolute `SystemTime`.
    async fn fetch_metadata_access_token(
        &self,
        scope: &str,
    ) -> Result<GcpAccessToken, ContextError> {
        let response = self
            .client
            .get(GCP_METADATA_TOKEN_ENDPOINT)
            // Required marker header; the metadata server rejects requests
            // without it.
            .header("Metadata-Flavor", "Google")
            .query(&[("scopes", scope)])
            .send()
            .await
            .map_err(|source| ContextError::GcpAuthHttp {
                endpoint: "metadata",
                source,
            })?;
        let status = response.status();
        if !status.is_success() {
            // Surface a (truncated) copy of the body so the rejection reason
            // is visible in operator logs/conditions.
            let body = response_body_for_error(response).await;
            return Err(ContextError::GcpAuthRejected {
                endpoint: "metadata".to_string(),
                status: status.as_u16(),
                body,
            });
        }
        let body: MetadataTokenResponse =
            response
                .json()
                .await
                .map_err(|source| ContextError::GcpAuthHttp {
                    endpoint: "metadata",
                    source,
                })?;
        // Defensive checks: a blank token or zero lifetime would otherwise
        // poison the pool cache with an unusable credential.
        if body.access_token.trim().is_empty() {
            return Err(ContextError::GcpAuthInvalidResponse {
                detail: "metadata token response omitted access_token".to_string(),
            });
        }
        if body.expires_in == 0 {
            return Err(ContextError::GcpAuthInvalidResponse {
                detail: "metadata token response had zero expires_in".to_string(),
            });
        }
        Ok(GcpAccessToken {
            token: body.access_token,
            expires_at: SystemTime::now() + Duration::from_secs(body.expires_in),
        })
    }

    /// Exchanges the node's own metadata token for a short-lived token
    /// impersonating `target_service_account`, via the IAM Credentials
    /// `generateAccessToken` API.
    async fn fetch_impersonated_access_token(
        &self,
        target_service_account: &str,
        scope: &str,
    ) -> Result<GcpAccessToken, ContextError> {
        // The impersonation call itself is authorized with a
        // cloud-platform-scoped token from the metadata server.
        let source = self
            .fetch_metadata_access_token(GCP_IAM_CREDENTIALS_SCOPE)
            .await?;
        // Percent-encode the account so it is safe to splice into the URL path.
        let encoded_target = percent_encoding::utf8_percent_encode(
            target_service_account,
            percent_encoding::NON_ALPHANUMERIC,
        )
        .to_string();
        // "projects/-" lets the API infer the project from the service account.
        let endpoint = format!(
            "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{encoded_target}:generateAccessToken"
        );
        let request = GenerateAccessTokenRequest {
            scope: vec![scope.to_string()],
            lifetime: format!("{GCP_IMPERSONATED_TOKEN_LIFETIME_SECS}s"),
        };
        let response = self
            .client
            .post(&endpoint)
            .bearer_auth(&source.token)
            .json(&request)
            .send()
            .await
            .map_err(|source| ContextError::GcpAuthHttp {
                endpoint: "iamcredentials",
                source,
            })?;
        let status = response.status();
        if !status.is_success() {
            let body = response_body_for_error(response).await;
            return Err(ContextError::GcpAuthRejected {
                endpoint: "iamcredentials".to_string(),
                status: status.as_u16(),
                body,
            });
        }
        let body: GenerateAccessTokenResponse =
            response
                .json()
                .await
                .map_err(|source| ContextError::GcpAuthHttp {
                    endpoint: "iamcredentials",
                    source,
                })?;
        if body.access_token.trim().is_empty() {
            return Err(ContextError::GcpAuthInvalidResponse {
                detail: "IAMCredentials response omitted accessToken".to_string(),
            });
        }
        // Unlike the metadata endpoint, this API returns an absolute RFC 3339
        // expireTime rather than a relative lifetime.
        let expires_at = parse_google_expire_time(&body.expire_time).ok_or_else(|| {
            ContextError::GcpAuthInvalidResponse {
                detail: format!(
                    "IAMCredentials response had invalid expireTime {:?}",
                    body.expire_time
                ),
            }
        })?;
        Ok(GcpAccessToken {
            token: body.access_token,
            expires_at,
        })
    }
}
/// Wire format of the metadata server's token response (snake_case fields,
/// relative `expires_in` lifetime in seconds).
#[derive(Deserialize)]
struct MetadataTokenResponse {
    access_token: String,
    expires_in: u64,
}
/// Request body for the IAM Credentials `generateAccessToken` call.
#[derive(Serialize)]
struct GenerateAccessTokenRequest {
    // OAuth scopes the impersonated token should carry.
    scope: Vec<String>,
    // Requested lifetime as a seconds string, e.g. "3600s".
    lifetime: String,
}
/// Wire format of the IAM Credentials response (camelCase fields, absolute
/// RFC 3339 `expireTime`).
#[derive(Deserialize)]
struct GenerateAccessTokenResponse {
    #[serde(rename = "accessToken")]
    access_token: String,
    #[serde(rename = "expireTime")]
    expire_time: String,
}
async fn response_body_for_error(response: reqwest::Response) -> String {
match response.text().await {
Ok(body) => truncate_for_error(body),
Err(error) => format!("failed to read error body: {error}"),
}
}
/// Caps an error body at 512 bytes, backing up to the nearest UTF-8 char
/// boundary and appending "..." only when something was actually cut off.
fn truncate_for_error(mut body: String) -> String {
    const MAX_ERROR_BODY_BYTES: usize = 512;
    if body.len() > MAX_ERROR_BODY_BYTES {
        // Scan backwards from the byte limit to the first char boundary so
        // `truncate` cannot panic mid-codepoint. Index 0 is always a
        // boundary, so the search cannot fail.
        let end = (0..=MAX_ERROR_BODY_BYTES)
            .rev()
            .find(|&i| body.is_char_boundary(i))
            .unwrap_or(0);
        body.truncate(end);
        body.push_str("...");
    }
    body
}
/// Parses an RFC 3339 `expireTime` string (as returned by Google APIs) into
/// a `SystemTime`, yielding `None` on any parse failure.
fn parse_google_expire_time(expire_time: &str) -> Option<SystemTime> {
    match expire_time.parse::<jiff::Timestamp>() {
        Ok(timestamp) => Some(SystemTime::from(timestamp)),
        Err(_) => None,
    }
}
/// Reports whether a cached credential is still comfortably valid at `now`.
///
/// `None` means the credential has no expiry (static secrets) and is always
/// considered fresh. Otherwise the expiry must lie strictly beyond `now`
/// plus the refresh skew; overflow while computing that deadline forces a
/// refresh by returning `false`.
fn token_expires_after_skew(expires_at: Option<SystemTime>, now: SystemTime) -> bool {
    match expires_at {
        None => true,
        Some(expires_at) => now
            .checked_add(Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS))
            .map_or(false, |refresh_at| expires_at > refresh_at),
    }
}
/// RAII guard for the per-database in-memory reconcile lock; removes its
/// entry from the shared lock map when dropped.
pub struct DatabaseLockGuard {
    // Database identity this guard holds the lock for.
    key: String,
    // The shared lock map the entry must be removed from on drop.
    locks: Arc<Mutex<HashMap<String, ()>>>,
}
impl Drop for DatabaseLockGuard {
    // Releases the lock entry. `Drop` cannot await, so three strategies are
    // tried in order:
    // 1. `try_lock` — the common, uncontended case: release synchronously.
    // 2. If the map is momentarily contended and we are inside a Tokio
    //    runtime, spawn a background task to do the async removal.
    // 3. Outside any runtime, fall back to `blocking_lock`, which is safe
    //    there because no executor thread can be stalled by it.
    // NOTE(review): in case 2, if the runtime is shutting down the spawned
    // task may never run, leaving the entry in the map — confirm this is an
    // acceptable leak at process teardown.
    fn drop(&mut self) {
        if let Ok(mut map) = self.locks.try_lock() {
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
/// Shared state handed to every reconcile invocation.
#[derive(Clone)]
pub struct OperatorContext {
    pub kube_client: kube::Client,
    pub event_recorder: Recorder,
    // Connection pools keyed by `ConnectionSpec::cache_key(namespace)`.
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    // Database identities currently held by an in-flight reconcile; see
    // `try_lock_database`.
    database_locks: Arc<Mutex<HashMap<String, ()>>>,
    pub observability: OperatorObservability,
    // Pluggable token source; production wiring uses the metadata-backed
    // provider (see `OperatorContext::new`).
    gcp_token_provider: Arc<dyn GcpAccessTokenProvider>,
}
impl OperatorContext {
    /// Assembles the shared reconcile context with empty pool/lock caches
    /// and the default metadata-server-backed GCP token provider.
    pub fn new(
        kube_client: kube::Client,
        observability: OperatorObservability,
        event_recorder: Recorder,
    ) -> Self {
        Self {
            kube_client,
            event_recorder,
            pool_cache: Arc::new(RwLock::new(HashMap::new())),
            observability,
            database_locks: Arc::new(Mutex::new(HashMap::new())),
            gcp_token_provider: Arc::new(MetadataGcpAccessTokenProvider::default()),
        }
    }

    /// Attempts to take the in-process lock for `database_identity` without
    /// waiting. Returns `None` when another reconcile already holds it; the
    /// returned guard releases the lock on drop.
    pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
        let mut locks = self.database_locks.lock().await;
        if locks.contains_key(database_identity) {
            tracing::info!(
                database = %database_identity,
                "in-memory database lock contention — another reconcile is in progress"
            );
            return None;
        }
        locks.insert(database_identity.to_string(), ());
        tracing::debug!(database = %database_identity, "acquired in-memory database lock");
        Some(DatabaseLockGuard {
            key: database_identity.to_string(),
            locks: Arc::clone(&self.database_locks),
        })
    }

    /// Resolves a single connection parameter: an inline literal wins over a
    /// Secret reference; `Ok(None)` when neither source is configured.
    async fn resolve_param(
        &self,
        namespace: &str,
        literal: &Option<String>,
        secret: &Option<SecretKeySelector>,
    ) -> Result<Option<String>, ContextError> {
        if let Some(val) = literal {
            return Ok(Some(val.clone()));
        }
        if let Some(sel) = secret {
            return Ok(Some(
                self.fetch_secret_value(namespace, &sel.name, &sel.key)
                    .await?,
            ));
        }
        Ok(None)
    }

    /// Convenience wrapper returning only the resolved URL, discarding the
    /// token-expiry metadata used internally by the pool cache.
    pub async fn resolve_connection_url(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<String, ContextError> {
        Ok(self
            .resolve_connection_url_with_metadata(namespace, connection)
            .await?
            .database_url)
    }

    /// Builds a `postgresql://` URL from either a `secretRef` (URL stored
    /// verbatim in a Secret) or individual `params`, fetching a GCP access
    /// token as the password when `params.auth` is set.
    ///
    /// Errors if neither mode is configured, if any required field resolves
    /// empty, or if the resolved sslMode is not one of the known values.
    async fn resolve_connection_url_with_metadata(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<ResolvedConnectionUrl, ContextError> {
        if let Some(ref secret_ref) = connection.secret_ref {
            // secretRef mode: the whole URL lives in one Secret key.
            let database_url = self
                .fetch_secret_value(
                    namespace,
                    &secret_ref.name,
                    connection.effective_secret_key(),
                )
                .await?;
            Ok(ResolvedConnectionUrl {
                database_url,
                token_expires_at: None,
            })
        } else if let Some(ref params) = connection.params {
            let host = self
                .resolve_param(namespace, &params.host, &params.host_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "host".to_string(),
                })?;
            if host.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "host".to_string(),
                });
            }
            // Port may come from a literal number, a Secret, or default 5432.
            let port_str = params.port.map(|p| p.to_string());
            let port = self
                .resolve_param(namespace, &port_str, &params.port_secret)
                .await?
                .unwrap_or_else(|| "5432".to_string());
            if port.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "port".to_string(),
                });
            }
            let dbname = self
                .resolve_param(namespace, &params.dbname, &params.dbname_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "dbname".to_string(),
                })?;
            if dbname.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "dbname".to_string(),
                });
            }
            let username = self
                .resolve_param(namespace, &params.username, &params.username_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "username".to_string(),
                })?;
            if username.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "username".to_string(),
                });
            }
            // GCP auth substitutes a fresh access token for the password and
            // records its expiry so the pool cache can refresh in time.
            let (password, token_expires_at) = if let Some(auth) = &params.auth {
                let token = self.gcp_token_provider.fetch_token(auth).await?;
                (token.token, Some(token.expires_at))
            } else {
                let password = self
                    .resolve_param(namespace, &params.password, &params.password_secret)
                    .await?
                    .ok_or_else(|| ContextError::EmptyResolvedValue {
                        field: "password".to_string(),
                    })?;
                (password, None)
            };
            if password.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "password".to_string(),
                });
            }
            use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
            // NOTE(review): only username and password are percent-encoded;
            // host, port and dbname are spliced in verbatim, which assumes
            // they contain no URL-special characters — confirm upstream
            // validation covers this.
            let encoded_username = utf8_percent_encode(&username, NON_ALPHANUMERIC).to_string();
            let encoded_password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string();
            let mut url = format!(
                "postgresql://{encoded_username}:{encoded_password}@{host}:{port}/{dbname}"
            );
            // sslMode defaults to "require" whenever GCP auth is in play.
            let ssl_mode = self
                .resolve_param(namespace, &params.ssl_mode, &params.ssl_mode_secret)
                .await?
                .or_else(|| params.auth.as_ref().map(|_| "require".to_string()));
            if let Some(ssl_mode) = ssl_mode {
                if !crate::crd::VALID_SSL_MODES.contains(&ssl_mode.as_str()) {
                    return Err(ContextError::InvalidResolvedSslMode { value: ssl_mode });
                }
                url.push_str("?sslmode=");
                url.push_str(&ssl_mode);
            }
            Ok(ResolvedConnectionUrl {
                database_url: url,
                token_expires_at,
            })
        } else {
            Err(ContextError::SecretMissing {
                name: "connection".to_string(),
                key: "neither secretRef nor params is set".to_string(),
            })
        }
    }

    /// Returns a cached pool when its backing Secrets are unchanged and any
    /// embedded token is still fresh; otherwise resolves the URL anew,
    /// connects, and replaces the cache entry.
    pub async fn get_or_create_pool(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<PgPool, ContextError> {
        let cache_key = connection.cache_key(namespace);
        // Compute the current freshness markers up front: a single Secret
        // resourceVersion for secretRef mode, or a fingerprint over every
        // referenced Secret for params mode.
        let (resource_version, secret_fingerprint) =
            if let Some(ref secret_ref) = connection.secret_ref {
                let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                    kube::Api::namespaced(self.kube_client.clone(), namespace);
                let secret = secrets_api.get(&secret_ref.name).await.map_err(|err| {
                    ContextError::SecretFetch {
                        name: secret_ref.name.clone(),
                        namespace: namespace.to_string(),
                        source: err,
                    }
                })?;
                (secret.metadata.resource_version, None)
            } else if connection.params.is_some() {
                // BTreeSet gives deterministic ordering, so the fingerprint
                // is stable across calls.
                let mut secret_names = std::collections::BTreeSet::new();
                connection.collect_secret_names(&mut secret_names);
                if secret_names.is_empty() {
                    // Literal-only params: fingerprint is empty but present,
                    // so it still participates in cache matching.
                    (None, Some(String::new()))
                } else {
                    let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                        kube::Api::namespaced(self.kube_client.clone(), namespace);
                    let mut fingerprint_parts = Vec::new();
                    for name in &secret_names {
                        let secret = secrets_api.get(name).await.map_err(|err| {
                            ContextError::SecretFetch {
                                name: name.clone(),
                                namespace: namespace.to_string(),
                                source: err,
                            }
                        })?;
                        let rv = secret
                            .metadata
                            .resource_version
                            .unwrap_or_else(|| "unknown".to_string());
                        fingerprint_parts.push(format!("{name}={rv}"));
                    }
                    (None, Some(fingerprint_parts.join(",")))
                }
            } else {
                (None, None)
            };
        // Fast path: reuse the cached pool under a read lock when all
        // freshness checks pass.
        {
            let cache = self.pool_cache.read().await;
            if let Some(cached) = cache.get(&cache_key) {
                // Missing resourceVersion on either side is treated as a
                // match (version tracking not applicable for that mode).
                let version_matches = match (&resource_version, &cached.resource_version) {
                    (Some(current), Some(cached_rv)) => current == cached_rv,
                    _ => true,
                };
                // Fingerprints must agree exactly; a mode mismatch
                // (Some vs None) invalidates the entry.
                let fingerprint_matches = match (&secret_fingerprint, &cached.secret_fingerprint) {
                    (Some(current), Some(cached_fp)) => current == cached_fp,
                    (None, None) => true,
                    _ => false,
                };
                let token_fresh =
                    token_expires_after_skew(cached.token_expires_at, SystemTime::now());
                if version_matches && fingerprint_matches && token_fresh {
                    return Ok(cached.pool.clone());
                }
            }
        }
        // Slow path: resolve credentials and establish a fresh pool.
        let resolved = self
            .resolve_connection_url_with_metadata(namespace, connection)
            .await?;
        let pool = PgPoolOptions::new()
            .max_connections(POOL_MAX_CONNECTIONS)
            .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
            .connect(&resolved.database_url)
            .await
            .map_err(|err| ContextError::DatabaseConnect { source: err })?;
        {
            let mut cache = self.pool_cache.write().await;
            cache.insert(
                cache_key,
                CachedPool {
                    resource_version,
                    secret_fingerprint,
                    token_expires_at: resolved.token_expires_at,
                    pool: pool.clone(),
                },
            );
        }
        Ok(pool)
    }

    /// Fetches `secret_key` from the named Secret and decodes it as UTF-8.
    ///
    /// All failure shapes (missing Secret data, missing key, non-UTF-8
    /// bytes) collapse into `SecretMissing` so callers get one error to
    /// report.
    pub async fn fetch_secret_value(
        &self,
        namespace: &str,
        secret_name: &str,
        secret_key: &str,
    ) -> Result<String, ContextError> {
        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
            kube::Api::namespaced(self.kube_client.clone(), namespace);
        let secret =
            secrets_api
                .get(secret_name)
                .await
                .map_err(|err| ContextError::SecretFetch {
                    name: secret_name.to_string(),
                    namespace: namespace.to_string(),
                    source: err,
                })?;
        let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
            name: secret_name.to_string(),
            key: secret_key.to_string(),
        })?;
        let value_bytes = data
            .get(secret_key)
            .ok_or_else(|| ContextError::SecretMissing {
                name: secret_name.to_string(),
                key: secret_key.to_string(),
            })?;
        String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
            name: secret_name.to_string(),
            key: secret_key.to_string(),
        })
    }

    /// Drops the cached pool for this connection (if any) so the next
    /// `get_or_create_pool` reconnects from scratch.
    pub async fn evict_pool(&self, namespace: &str, connection: &ConnectionSpec) {
        let cache_key = connection.cache_key(namespace);
        let mut cache = self.pool_cache.write().await;
        cache.remove(&cache_key);
    }
}
/// Errors produced while resolving connection details and building pools.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call to read a Secret failed.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },
    /// The Secret exists but the expected key (or usable data) is absent.
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },
    /// Establishing the Postgres connection pool failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },
    /// A required connection param resolved to empty/whitespace.
    #[error("connection param \"{field}\" resolved to an empty or whitespace-only value")]
    EmptyResolvedValue { field: String },
    /// The resolved sslMode is not one of the accepted libpq values.
    #[error(
        "connection param sslMode resolved to invalid value \"{value}\" (expected one of: disable, allow, prefer, require, verify-ca, verify-full)"
    )]
    InvalidResolvedSslMode { value: String },
    /// The HTTP request to a GCP auth endpoint failed at the transport level.
    #[error("failed to fetch GCP auth token from {endpoint}: {source}")]
    GcpAuthHttp {
        endpoint: &'static str,
        source: reqwest::Error,
    },
    /// A GCP auth endpoint answered with a non-success HTTP status.
    #[error("GCP auth token endpoint {endpoint} returned HTTP {status}: {body}")]
    GcpAuthRejected {
        endpoint: String,
        status: u16,
        body: String,
    },
    /// A GCP auth endpoint answered 2xx but the payload was unusable.
    #[error("GCP auth token response was invalid: {detail}")]
    GcpAuthInvalidResponse { detail: String },
}
impl ContextError {
    /// True when a Secret fetch failed with a client-side Kubernetes API
    /// error (4xx other than 429) that retrying will not fix.
    pub fn is_secret_fetch_non_transient(&self) -> bool {
        match self {
            ContextError::SecretFetch {
                source: kube::Error::Api(response),
                ..
            } => (400..500).contains(&response.code) && response.code != 429,
            _ => false,
        }
    }

    /// True when GCP auth failed in a way retrying will not fix: a non-429
    /// 4xx rejection, or a structurally invalid token response.
    pub fn is_gcp_auth_non_transient(&self) -> bool {
        match self {
            ContextError::GcpAuthRejected { status, .. } => {
                (400..500).contains(status) && *status != 429
            }
            ContextError::GcpAuthInvalidResponse { .. } => true,
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- cache key / error classification -----------------------------------

    // Documents the expected "{namespace}/{secret}/{key}" shape of pool
    // cache keys.
    #[test]
    fn pool_cache_key_format() {
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };
        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };
        assert!(error.is_secret_fetch_non_transient());
    }

    // 5xx should stay retryable.
    #[test]
    fn secret_fetch_server_error_remains_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };
        assert!(!error.is_secret_fetch_non_transient());
    }

    #[test]
    fn gcp_auth_client_error_is_non_transient() {
        let error = ContextError::GcpAuthRejected {
            endpoint: "metadata".into(),
            status: 403,
            body: "forbidden".into(),
        };
        assert!(error.is_gcp_auth_non_transient());
    }

    // 429 is the one 4xx that should remain retryable.
    #[test]
    fn gcp_auth_rate_limit_remains_transient() {
        let error = ContextError::GcpAuthRejected {
            endpoint: "metadata".into(),
            status: 429,
            body: "rate limited".into(),
        };
        assert!(!error.is_gcp_auth_non_transient());
    }

    // --- token expiry / parsing helpers -------------------------------------

    // Expiry exactly at the skew boundary must NOT count as fresh (strict >).
    #[test]
    fn token_expiry_uses_five_minute_refresh_skew() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000);
        assert!(token_expires_after_skew(
            Some(now + Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS + 1)),
            now
        ));
        assert!(!token_expires_after_skew(
            Some(now + Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS)),
            now
        ));
    }

    #[test]
    fn parse_google_expire_time_accepts_rfc3339() {
        let parsed =
            parse_google_expire_time("2026-05-14T02:30:00Z").expect("expireTime should parse");
        assert_eq!(
            parsed
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            1_778_725_800
        );
    }

    // Truncation must never split a multi-byte codepoint.
    #[test]
    fn truncate_for_error_keeps_utf8_boundary() {
        let body = "é".repeat(300);
        let truncated = truncate_for_error(body);
        assert!(truncated.ends_with("..."));
        assert!(truncated.is_char_boundary(truncated.len() - 3));
    }

    // --- in-memory database lock behavior -----------------------------------

    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    // Exercises the Drop impl's release path.
    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };
        {
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
        }
        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    // Two tasks race for the same key; the sleep keeps the first holder alive
    // long enough for the second attempt to observe contention.
    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);
        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });
        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();
        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    // Minimal stand-in for OperatorContext that carries only the lock map,
    // so lock semantics can be tested without a kube client.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }
    impl OperatorContextLockHelper {
        // Mirrors `OperatorContext::try_lock_database` minus the tracing.
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    // Barrier-synchronized stampede on one key: exactly one winner expected.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    // Distinct keys must not contend with each other.
    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    // Each task retries until it wins once; verifies the lock cycles cleanly
    // through repeated acquire/release under contention.
    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));
        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                panic!("task failed to acquire lock after 100 retries");
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}