// pgroles_operator/context.rs

1//! Shared operator context — database pool cache, metrics, and configuration.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use kube::runtime::events::Recorder;
7use std::time::Duration;
8
9use sqlx::postgres::{PgPool, PgPoolOptions};
10use tokio::sync::{Mutex, RwLock};
11
12use crate::observability::OperatorObservability;
13
/// Maximum number of connections per cached database pool.
///
/// One connection is held for the session-scoped advisory lock while the
/// reconcile loop performs inspection and apply work on the pool, so the
/// pool must hold at least 2 connections (enforced by the compile-time
/// assertion below).
const POOL_MAX_CONNECTIONS: u32 = 5;

/// Bound how long a reconcile waits for a pooled connection before surfacing
/// a transient database connectivity failure.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;

// Compile-time guard: one advisory-lock connection plus at least one
// worker connection must fit in the pool.
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);
25
/// A database pool cached together with the Secret `resourceVersion` it was
/// built from; a version mismatch invalidates the entry so rotated
/// credentials produce a fresh pool (see `get_or_create_pool`).
#[derive(Clone)]
struct CachedPool {
    /// `metadata.resourceVersion` of the source Secret at connect time.
    resource_version: Option<String>,
    /// The pooled connection handle returned to callers.
    pool: PgPool,
}
31
/// Guard returned by [`OperatorContext::try_lock_database`].
///
/// Holding this guard prevents other reconcile loops (within the same process)
/// from starting work on the same database target. The lock is released when
/// the guard is dropped.
pub struct DatabaseLockGuard {
    /// Database identity used as the key in the shared lock map.
    key: String,
    /// Shared map of held locks; the `Drop` impl removes `key` from it.
    locks: Arc<Mutex<HashMap<String, ()>>>,
}
41
impl Drop for DatabaseLockGuard {
    // Releases the in-memory lock by removing `self.key` from the shared map.
    // Drop cannot be async, so release is attempted in three tiers:
    //   1. non-blocking `try_lock` (common case),
    //   2. deferred removal on a spawned Tokio task when the mutex is busy,
    //   3. synchronous `blocking_lock` when no runtime is available.
    fn drop(&mut self) {
        // Best-effort removal — `try_lock` avoids blocking the drop.
        if let Ok(mut map) = self.locks.try_lock() {
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            // Spawn a task to clean up if the mutex is currently held.
            // Use Handle::try_current() so we don't panic when dropped
            // outside an active Tokio runtime (e.g. during shutdown).
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                // NOTE(review): if the runtime is shutting down, this spawned
                // task may never run, leaving a stale map entry that would
                // block future locks for this database — presumably acceptable
                // because process shutdown follows; confirm this assumption.
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                // No runtime available — fall back to synchronous cleanup
                // via blocking_lock so the entry is still removed.
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
76
/// Shared state for the operator, passed to every reconciliation.
#[derive(Clone)]
pub struct OperatorContext {
    /// Kubernetes client for API calls.
    pub kube_client: kube::Client,

    /// Kubernetes Event recorder for transition-based policy Events.
    pub event_recorder: Recorder,

    /// Cached database connection pools keyed by `"namespace/secret-name/secret-key"`.
    ///
    /// Entries are invalidated when the referenced Secret's `resourceVersion`
    /// changes, so rotated credentials produce a fresh pool.
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    /// In-process per-database reconciliation locks.
    ///
    /// Prevents concurrent reconcile loops from operating on the same database
    /// within a single operator replica. Cross-replica safety is provided by
    /// PostgreSQL advisory locks (see [`crate::advisory`]).
    database_locks: Arc<Mutex<HashMap<String, ()>>>,

    /// Shared health/metrics state.
    pub observability: OperatorObservability,
}
98
99impl OperatorContext {
100    /// Create a new operator context with an empty pool cache.
101    pub fn new(
102        kube_client: kube::Client,
103        observability: OperatorObservability,
104        event_recorder: Recorder,
105    ) -> Self {
106        Self {
107            kube_client,
108            event_recorder,
109            pool_cache: Arc::new(RwLock::new(HashMap::new())),
110            observability,
111            database_locks: Arc::new(Mutex::new(HashMap::new())),
112        }
113    }
114
115    /// Try to acquire the in-process lock for the given database identity.
116    ///
117    /// Returns `Some(guard)` if no other reconcile is in progress for this
118    /// database, `None` if one is already running. The lock is released when
119    /// the guard is dropped.
120    pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
121        let mut locks = self.database_locks.lock().await;
122        if locks.contains_key(database_identity) {
123            tracing::info!(
124                database = %database_identity,
125                "in-memory database lock contention — another reconcile is in progress"
126            );
127            return None;
128        }
129        locks.insert(database_identity.to_string(), ());
130        tracing::debug!(database = %database_identity, "acquired in-memory database lock");
131        Some(DatabaseLockGuard {
132            key: database_identity.to_string(),
133            locks: Arc::clone(&self.database_locks),
134        })
135    }
136
137    /// Get or create a PgPool for the given secret reference.
138    ///
139    /// Reads the `DATABASE_URL` (or custom key) from the referenced Secret,
140    /// and caches the resulting pool for reuse.
141    pub async fn get_or_create_pool(
142        &self,
143        namespace: &str,
144        secret_name: &str,
145        secret_key: &str,
146    ) -> Result<PgPool, ContextError> {
147        let cache_key = format!("{namespace}/{secret_name}/{secret_key}");
148
149        // Fetch secret from k8s API.
150        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
151            kube::Api::namespaced(self.kube_client.clone(), namespace);
152
153        let secret =
154            secrets_api
155                .get(secret_name)
156                .await
157                .map_err(|err| ContextError::SecretFetch {
158                    name: secret_name.to_string(),
159                    namespace: namespace.to_string(),
160                    source: err,
161                })?;
162
163        let resource_version = secret.metadata.resource_version.clone();
164
165        // Check cache after reading the current Secret version.
166        {
167            let cache = self.pool_cache.read().await;
168            if let Some(cached) = cache.get(&cache_key)
169                && cached.resource_version == resource_version
170            {
171                return Ok(cached.pool.clone());
172            }
173        }
174
175        let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
176            name: secret_name.to_string(),
177            key: secret_key.to_string(),
178        })?;
179
180        let url_bytes = data
181            .get(secret_key)
182            .ok_or_else(|| ContextError::SecretMissing {
183                name: secret_name.to_string(),
184                key: secret_key.to_string(),
185            })?;
186
187        let database_url =
188            String::from_utf8(url_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
189                name: secret_name.to_string(),
190                key: secret_key.to_string(),
191            })?;
192
193        // Create pool with explicit sizing. Reconciliation holds one dedicated
194        // connection for PostgreSQL advisory locking and needs additional pool
195        // capacity for inspection/apply queries.
196        let pool = PgPoolOptions::new()
197            .max_connections(POOL_MAX_CONNECTIONS)
198            .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
199            .connect(&database_url)
200            .await
201            .map_err(|err| ContextError::DatabaseConnect { source: err })?;
202
203        // Cache it (write lock).
204        {
205            let mut cache = self.pool_cache.write().await;
206            cache.insert(
207                cache_key,
208                CachedPool {
209                    resource_version,
210                    pool: pool.clone(),
211                },
212            );
213        }
214
215        Ok(pool)
216    }
217
218    /// Fetch a single string value from a Kubernetes Secret.
219    ///
220    /// Used to resolve role passwords from Secret references at reconcile time.
221    pub async fn fetch_secret_value(
222        &self,
223        namespace: &str,
224        secret_name: &str,
225        secret_key: &str,
226    ) -> Result<String, ContextError> {
227        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
228            kube::Api::namespaced(self.kube_client.clone(), namespace);
229
230        let secret =
231            secrets_api
232                .get(secret_name)
233                .await
234                .map_err(|err| ContextError::SecretFetch {
235                    name: secret_name.to_string(),
236                    namespace: namespace.to_string(),
237                    source: err,
238                })?;
239
240        let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
241            name: secret_name.to_string(),
242            key: secret_key.to_string(),
243        })?;
244
245        let value_bytes = data
246            .get(secret_key)
247            .ok_or_else(|| ContextError::SecretMissing {
248                name: secret_name.to_string(),
249                key: secret_key.to_string(),
250            })?;
251
252        String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
253            name: secret_name.to_string(),
254            key: secret_key.to_string(),
255        })
256    }
257
258    /// Remove a cached pool (e.g. when secret changes or CR is deleted).
259    pub async fn evict_pool(&self, namespace: &str, secret_name: &str, secret_key: &str) {
260        let cache_key = format!("{namespace}/{secret_name}/{secret_key}");
261        let mut cache = self.pool_cache.write().await;
262        cache.remove(&cache_key);
263    }
264}
265
/// Errors from operator context operations.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Secret could not be fetched from the Kubernetes API.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },

    /// The Secret exists but does not contain the requested key.
    ///
    /// NOTE(review): callers also map "value is not valid UTF-8" onto this
    /// variant, so the message can be slightly misleading in that case.
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },

    /// Establishing the database connection pool failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },
}
282
283impl ContextError {
284    /// Returns true when a Secret fetch failed due to a non-transient client-side API error.
285    pub fn is_secret_fetch_non_transient(&self) -> bool {
286        matches!(
287            self,
288            ContextError::SecretFetch {
289                source: kube::Error::Api(response),
290                ..
291            } if (400..500).contains(&response.code) && response.code != 429
292        )
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    // Coverage: cache-key formatting, ContextError classification, and the
    // in-process database lock (acquire, contention, release, concurrency).

    #[test]
    fn pool_cache_key_format() {
        // Verify the cache key format is "namespace/secret-name/secret-key"
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_server_error_remains_transient() {
        // 5xx is server-side and must stay retryable.
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };

        assert!(!error.is_secret_fetch_non_transient());
    }

    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };

        {
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
            // guard is dropped here
        }

        // After drop, should be able to acquire again.
        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));

        // Simulate two concurrent reconciles for the same database.
        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);

        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                // Hold the lock briefly.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });

        // Small delay so handle1 is likely first.
        // NOTE(review): ordering relies on sleep timing (10ms vs 50ms hold);
        // the xor assertion below tolerates either task winning the race.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });

        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();

        // Exactly one should succeed.
        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    /// Helper to test locking without a real kube client.
    ///
    /// Mirrors `OperatorContext::try_lock_database` over the same lock map
    /// type — keep the two implementations in sync.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }

    impl OperatorContextLockHelper {
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        // Spawn many tasks all racing to lock the same database.
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                // Synchronize start so all tasks race at the same instant.
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    // Hold lock briefly to let other tasks observe contention.
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // Exactly one task should have acquired the lock.
        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        // Many tasks each locking a different database — all should succeed.
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        // Repeatedly acquire and release the same database lock from many tasks.
        // Each task attempts the lock in a loop until it succeeds, simulating
        // the requeue-after-contention pattern used in the reconciler.
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                // Retry up to 100 times with a small sleep between attempts,
                // simulating the jittered requeue pattern.
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        // Brief simulated work, then guard drops (releasing lock).
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                // Should not reach here in practice — fail the test if we do.
                panic!("task failed to acquire lock after 100 retries");
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}