// pgroles_operator/context.rs

1//! Shared operator context — database pool cache, metrics, and configuration.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use kube::runtime::events::Recorder;
7use std::time::Duration;
8
9use sqlx::postgres::{PgPool, PgPoolOptions};
10use tokio::sync::{Mutex, RwLock};
11
12use crate::crd::{ConnectionSpec, SecretKeySelector};
13use crate::observability::OperatorObservability;
14
/// Maximum number of connections each cached pool may open.
///
/// One connection is held for the session-scoped advisory lock while the
/// reconcile loop performs inspection and apply work on the pool, so the
/// pool must allow at least two connections (enforced by the compile-time
/// assert below).
const POOL_MAX_CONNECTIONS: u32 = 5;

/// Bound how long a reconcile waits for a pooled connection before surfacing
/// a transient database connectivity failure.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;

// Compile-time guard: the advisory-lock connection plus at least one
// worker connection must fit in the pool.
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);
26
/// A cached [`PgPool`] plus the Secret metadata it was built from.
///
/// The metadata lets [`OperatorContext::get_or_create_pool`] decide whether
/// the pool can be reused or must be rebuilt after a secret rotation.
#[derive(Clone)]
struct CachedPool {
    /// resourceVersion of the connection-URL Secret (URL mode only).
    resource_version: Option<String>,
    /// Fingerprint of all referenced secrets' resourceVersions (params mode).
    secret_fingerprint: Option<String>,
    /// The live connection pool handed out to reconcilers.
    pool: PgPool,
}
34
35/// Guard returned by [`OperatorContext::try_lock_database`].
36///
37/// Holding this guard prevents other reconcile loops (within the same process)
38/// from starting work on the same database target. The lock is released when
39/// the guard is dropped.
40pub struct DatabaseLockGuard {
41    key: String,
42    locks: Arc<Mutex<HashMap<String, ()>>>,
43}
44
impl Drop for DatabaseLockGuard {
    fn drop(&mut self) {
        // Fast path: `try_lock` avoids blocking inside `drop`.
        if let Ok(mut map) = self.locks.try_lock() {
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            // Mutex is currently held elsewhere — defer the removal.
            // Use Handle::try_current() so we don't panic when dropped
            // outside an active Tokio runtime (e.g. during shutdown).
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                // NOTE(review): until this spawned task runs, try_lock_database
                // for the same key still reports contention — presumably an
                // accepted trade-off with the reconciler's requeue; confirm.
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                // No runtime available — fall back to synchronous cleanup
                // via blocking_lock so the entry is still removed.
                // Safe here: blocking_lock only panics when called from within
                // a runtime, and try_current() just told us there is none.
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
79
/// Shared state for the operator, passed to every reconciliation.
///
/// The derived `Clone` shares the pool cache and lock table through `Arc`s,
/// so every clone observes the same cached pools and in-process locks.
#[derive(Clone)]
pub struct OperatorContext {
    /// Kubernetes client for API calls.
    pub kube_client: kube::Client,

    /// Kubernetes Event recorder for transition-based policy Events.
    pub event_recorder: Recorder,

    /// Cached database connection pools keyed by `"namespace/secret-name/secret-key"`.
    /// (Key is produced by `ConnectionSpec::cache_key` — the params-mode key
    /// shape is defined there; TODO confirm it differs from URL mode.)
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    /// In-process per-database reconciliation locks.
    ///
    /// Prevents concurrent reconcile loops from operating on the same database
    /// within a single operator replica. Cross-replica safety is provided by
    /// PostgreSQL advisory locks (see [`crate::advisory`]).
    database_locks: Arc<Mutex<HashMap<String, ()>>>,

    /// Shared health/metrics state.
    pub observability: OperatorObservability,
}
101
102impl OperatorContext {
103    /// Create a new operator context with an empty pool cache.
104    pub fn new(
105        kube_client: kube::Client,
106        observability: OperatorObservability,
107        event_recorder: Recorder,
108    ) -> Self {
109        Self {
110            kube_client,
111            event_recorder,
112            pool_cache: Arc::new(RwLock::new(HashMap::new())),
113            observability,
114            database_locks: Arc::new(Mutex::new(HashMap::new())),
115        }
116    }
117
118    /// Try to acquire the in-process lock for the given database identity.
119    ///
120    /// Returns `Some(guard)` if no other reconcile is in progress for this
121    /// database, `None` if one is already running. The lock is released when
122    /// the guard is dropped.
123    pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
124        let mut locks = self.database_locks.lock().await;
125        if locks.contains_key(database_identity) {
126            tracing::info!(
127                database = %database_identity,
128                "in-memory database lock contention — another reconcile is in progress"
129            );
130            return None;
131        }
132        locks.insert(database_identity.to_string(), ());
133        tracing::debug!(database = %database_identity, "acquired in-memory database lock");
134        Some(DatabaseLockGuard {
135            key: database_identity.to_string(),
136            locks: Arc::clone(&self.database_locks),
137        })
138    }
139
140    /// Resolve a param from either its literal value or a Secret reference.
141    ///
142    /// Returns `Ok(Some(value))` if one is set, `Ok(None)` if neither is set.
143    async fn resolve_param(
144        &self,
145        namespace: &str,
146        literal: &Option<String>,
147        secret: &Option<SecretKeySelector>,
148    ) -> Result<Option<String>, ContextError> {
149        if let Some(val) = literal {
150            return Ok(Some(val.clone()));
151        }
152        if let Some(sel) = secret {
153            return Ok(Some(
154                self.fetch_secret_value(namespace, &sel.name, &sel.key)
155                    .await?,
156            ));
157        }
158        Ok(None)
159    }
160
161    /// Resolve a [`ConnectionSpec`] into a PostgreSQL connection URL string.
162    ///
163    /// - **URL mode** (`secret_ref` is Some): reads the Secret key as a connection URL.
164    /// - **Params mode** (`params` is Some): resolves each field and constructs a URL.
165    pub async fn resolve_connection_url(
166        &self,
167        namespace: &str,
168        connection: &ConnectionSpec,
169    ) -> Result<String, ContextError> {
170        if let Some(ref secret_ref) = connection.secret_ref {
171            // URL mode — read the full connection URL from the Secret.
172            self.fetch_secret_value(
173                namespace,
174                &secret_ref.name,
175                connection.effective_secret_key(),
176            )
177            .await
178        } else if let Some(ref params) = connection.params {
179            // Params mode — resolve each field and build the URL.
180            let host = self
181                .resolve_param(namespace, &params.host, &params.host_secret)
182                .await?
183                .ok_or_else(|| ContextError::EmptyResolvedValue {
184                    field: "host".to_string(),
185                })?;
186            if host.trim().is_empty() {
187                return Err(ContextError::EmptyResolvedValue {
188                    field: "host".to_string(),
189                });
190            }
191
192            let port_str = params.port.map(|p| p.to_string());
193            let port = self
194                .resolve_param(namespace, &port_str, &params.port_secret)
195                .await?
196                .unwrap_or_else(|| "5432".to_string());
197            if port.trim().is_empty() {
198                return Err(ContextError::EmptyResolvedValue {
199                    field: "port".to_string(),
200                });
201            }
202
203            let dbname = self
204                .resolve_param(namespace, &params.dbname, &params.dbname_secret)
205                .await?
206                .ok_or_else(|| ContextError::EmptyResolvedValue {
207                    field: "dbname".to_string(),
208                })?;
209            if dbname.trim().is_empty() {
210                return Err(ContextError::EmptyResolvedValue {
211                    field: "dbname".to_string(),
212                });
213            }
214
215            let username = self
216                .resolve_param(namespace, &params.username, &params.username_secret)
217                .await?
218                .ok_or_else(|| ContextError::EmptyResolvedValue {
219                    field: "username".to_string(),
220                })?;
221            if username.trim().is_empty() {
222                return Err(ContextError::EmptyResolvedValue {
223                    field: "username".to_string(),
224                });
225            }
226
227            let password = self
228                .resolve_param(namespace, &params.password, &params.password_secret)
229                .await?
230                .ok_or_else(|| ContextError::EmptyResolvedValue {
231                    field: "password".to_string(),
232                })?;
233            if password.trim().is_empty() {
234                return Err(ContextError::EmptyResolvedValue {
235                    field: "password".to_string(),
236                });
237            }
238
239            use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
240            let encoded_username = utf8_percent_encode(&username, NON_ALPHANUMERIC).to_string();
241            let encoded_password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string();
242
243            let mut url = format!(
244                "postgresql://{encoded_username}:{encoded_password}@{host}:{port}/{dbname}"
245            );
246
247            if let Some(ssl_mode) = self
248                .resolve_param(namespace, &params.ssl_mode, &params.ssl_mode_secret)
249                .await?
250            {
251                // Validate sslMode at runtime — CRD validation only catches
252                // literal values; a secret ref could resolve to anything.
253                if !crate::crd::VALID_SSL_MODES.contains(&ssl_mode.as_str()) {
254                    return Err(ContextError::InvalidResolvedSslMode { value: ssl_mode });
255                }
256                url.push_str("?sslmode=");
257                url.push_str(&ssl_mode);
258            }
259
260            Ok(url)
261        } else {
262            Err(ContextError::SecretMissing {
263                name: "connection".to_string(),
264                key: "neither secretRef nor params is set".to_string(),
265            })
266        }
267    }
268
    /// Get or create a PgPool for the given connection spec.
    ///
    /// Resolves the connection URL from the referenced Secret(s),
    /// and caches the resulting pool for reuse.
    ///
    /// Cache invalidation:
    /// - URL mode: the cached pool is reused while the connection Secret's
    ///   resourceVersion is unchanged.
    /// - Params mode: reused while the fingerprint over all referenced
    ///   secrets' resourceVersions is unchanged.
    ///
    /// # Errors
    ///
    /// [`ContextError::SecretFetch`] when a referenced Secret cannot be read,
    /// any error from [`Self::resolve_connection_url`], and
    /// [`ContextError::DatabaseConnect`] when the pool cannot be established.
    pub async fn get_or_create_pool(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<PgPool, ContextError> {
        let cache_key = connection.cache_key(namespace);

        // For URL mode, we can do resource-version-based cache invalidation.
        // For params mode, compute a fingerprint from all referenced secrets'
        // resourceVersions so that secret rotations invalidate the cache.
        let (resource_version, secret_fingerprint) =
            if let Some(ref secret_ref) = connection.secret_ref {
                let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                    kube::Api::namespaced(self.kube_client.clone(), namespace);
                let secret = secrets_api.get(&secret_ref.name).await.map_err(|err| {
                    ContextError::SecretFetch {
                        name: secret_ref.name.clone(),
                        namespace: namespace.to_string(),
                        source: err,
                    }
                })?;
                (secret.metadata.resource_version, None)
            } else if connection.params.is_some() {
                // Params mode — collect all referenced secret names and fetch their
                // resourceVersions to build a fingerprint.
                let mut secret_names = std::collections::BTreeSet::new();
                connection.collect_secret_names(&mut secret_names);

                if secret_names.is_empty() {
                    // All values are literals — no secrets to watch.
                    (None, Some(String::new()))
                } else {
                    let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                        kube::Api::namespaced(self.kube_client.clone(), namespace);
                    let mut fingerprint_parts = Vec::new();
                    // BTreeSet iteration is sorted, so the fingerprint string
                    // is deterministic across reconciles.
                    for name in &secret_names {
                        let secret = secrets_api.get(name).await.map_err(|err| {
                            ContextError::SecretFetch {
                                name: name.clone(),
                                namespace: namespace.to_string(),
                                source: err,
                            }
                        })?;
                        let rv = secret
                            .metadata
                            .resource_version
                            .unwrap_or_else(|| "unknown".to_string());
                        fingerprint_parts.push(format!("{name}={rv}"));
                    }
                    (None, Some(fingerprint_parts.join(",")))
                }
            } else {
                (None, None)
            };

        // Check cache.
        {
            let cache = self.pool_cache.read().await;
            if let Some(cached) = cache.get(&cache_key) {
                // URL mode: reuse if the Secret's resource_version matches.
                // Params mode: reuse if the secret fingerprint matches.
                let version_matches = match (&resource_version, &cached.resource_version) {
                    (Some(current), Some(cached_rv)) => current == cached_rv,
                    _ => true,
                };
                let fingerprint_matches = match (&secret_fingerprint, &cached.secret_fingerprint) {
                    (Some(current), Some(cached_fp)) => current == cached_fp,
                    (None, None) => true,
                    _ => false,
                };
                if version_matches && fingerprint_matches {
                    return Ok(cached.pool.clone());
                }
            }
        }

        let database_url = self.resolve_connection_url(namespace, connection).await?;

        // Create pool with explicit sizing. Reconciliation holds one dedicated
        // connection for PostgreSQL advisory locking and needs additional pool
        // capacity for inspection/apply queries.
        //
        // NOTE(review): two concurrent callers that both miss the cache will
        // each build a pool; the later insert below overwrites the earlier
        // entry, and the replaced pool is never explicitly closed — presumably
        // its connections linger until all clones drop. Confirm acceptable.
        let pool = PgPoolOptions::new()
            .max_connections(POOL_MAX_CONNECTIONS)
            .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
            .connect(&database_url)
            .await
            .map_err(|err| ContextError::DatabaseConnect { source: err })?;

        // Cache it (write lock).
        {
            let mut cache = self.pool_cache.write().await;
            cache.insert(
                cache_key,
                CachedPool {
                    resource_version,
                    secret_fingerprint,
                    pool: pool.clone(),
                },
            );
        }

        Ok(pool)
    }
376
377    /// Fetch a single string value from a Kubernetes Secret.
378    ///
379    /// Used to resolve role passwords from Secret references at reconcile time.
380    pub async fn fetch_secret_value(
381        &self,
382        namespace: &str,
383        secret_name: &str,
384        secret_key: &str,
385    ) -> Result<String, ContextError> {
386        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
387            kube::Api::namespaced(self.kube_client.clone(), namespace);
388
389        let secret =
390            secrets_api
391                .get(secret_name)
392                .await
393                .map_err(|err| ContextError::SecretFetch {
394                    name: secret_name.to_string(),
395                    namespace: namespace.to_string(),
396                    source: err,
397                })?;
398
399        let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
400            name: secret_name.to_string(),
401            key: secret_key.to_string(),
402        })?;
403
404        let value_bytes = data
405            .get(secret_key)
406            .ok_or_else(|| ContextError::SecretMissing {
407                name: secret_name.to_string(),
408                key: secret_key.to_string(),
409            })?;
410
411        String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
412            name: secret_name.to_string(),
413            key: secret_key.to_string(),
414        })
415    }
416
417    /// Remove a cached pool (e.g. when secret changes or CR is deleted).
418    pub async fn evict_pool(&self, namespace: &str, connection: &ConnectionSpec) {
419        let cache_key = connection.cache_key(namespace);
420        let mut cache = self.pool_cache.write().await;
421        cache.remove(&cache_key);
422    }
423}
424
/// Errors from operator context operations.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call to read a Secret failed.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },

    /// A Secret exists but the requested value is unavailable. Also raised
    /// when the Secret has no data at all, when the value is not valid UTF-8,
    /// and by `resolve_connection_url` when neither `secretRef` nor `params`
    /// is configured.
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },

    /// Establishing the PostgreSQL connection pool failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },

    /// A required connection param was unset, empty, or whitespace-only.
    #[error("connection param \"{field}\" resolved to an empty or whitespace-only value")]
    EmptyResolvedValue { field: String },

    /// `sslMode` resolved (possibly via a Secret) to a value outside the
    /// allowed libpq set.
    #[error(
        "connection param sslMode resolved to invalid value \"{value}\" (expected one of: disable, allow, prefer, require, verify-ca, verify-full)"
    )]
    InvalidResolvedSslMode { value: String },
}
449
450impl ContextError {
451    /// Returns true when a Secret fetch failed due to a non-transient client-side API error.
452    pub fn is_secret_fetch_non_transient(&self) -> bool {
453        matches!(
454            self,
455            ContextError::SecretFetch {
456                source: kube::Error::Api(response),
457                ..
458            } if (400..500).contains(&response.code) && response.code != 429
459        )
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pool_cache_key_format() {
        // Verify the cache key format is "namespace/secret-name/secret-key"
        // NOTE(review): this only formats local literals and never calls
        // ConnectionSpec::cache_key, so it cannot detect a regression in the
        // real key format — consider constructing a ConnectionSpec instead.
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        // 404 NotFound — retrying without operator intervention cannot help.
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        // 403 Forbidden — an RBAC problem, not a transient failure.
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_server_error_remains_transient() {
        // 5xx server errors are retryable, so they must classify as transient.
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };

        assert!(!error.is_secret_fetch_non_transient());
    }

    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        // `_guard1` stays alive for the whole test, so the second attempt
        // must observe contention.
        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };

        {
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
            // guard is dropped here
        }

        // After drop, should be able to acquire again.
        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));

        // Simulate two concurrent reconciles for the same database.
        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);

        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                // Hold the lock briefly.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });

        // Small delay so handle1 is likely first.
        // NOTE(review): timing-based — if the scheduler delays handle1 until
        // after handle2 has acquired AND dropped its guard, both could
        // acquire and the XOR below would fail. A Barrier or explicit
        // handoff would make the ordering deterministic.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });

        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();

        // Exactly one should succeed.
        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    /// Helper to test locking without a real kube client.
    ///
    /// `try_lock` below duplicates the logic of
    /// `OperatorContext::try_lock_database` (minus logging); keep the two in
    /// sync if the production locking logic changes.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }

    impl OperatorContextLockHelper {
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        // Spawn many tasks all racing to lock the same database.
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                // Synchronize start so all tasks race at the same instant.
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    // Hold lock briefly to let other tasks observe contention.
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // Exactly one task should have acquired the lock.
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        // Many tasks each locking a different database — all should succeed.
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        // Repeatedly acquire and release the same database lock from many tasks.
        // Each task attempts the lock in a loop until it succeeds, simulating
        // the requeue-after-contention pattern used in the reconciler.
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                // Retry up to 100 times with a small sleep between attempts,
                // simulating the jittered requeue pattern.
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        // Brief simulated work, then guard drops (releasing lock).
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                // Should not reach here in practice — fail the test if we do.
                panic!("task failed to acquire lock after 100 retries");
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}