Skip to main content

pgroles_operator/
context.rs

1//! Shared operator context — database pool cache, metrics, and configuration.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, SystemTime};
6
7use futures::future::{BoxFuture, FutureExt};
8use kube::runtime::events::Recorder;
9use serde::{Deserialize, Serialize};
10
11use sqlx::postgres::{PgPool, PgPoolOptions};
12use tokio::sync::{Mutex, RwLock};
13
14use crate::crd::{ConnectionAuth, ConnectionSpec, SecretKeySelector};
15use crate::observability::OperatorObservability;
16
/// Maximum number of connections each cached database pool may open.
///
/// Reconciliation holds one connection for the session-scoped advisory lock
/// while inspection and apply work runs on the remaining pool capacity, so
/// this must stay at least 2 (enforced at compile time below).
const POOL_MAX_CONNECTIONS: u32 = 5;

/// Bound how long a reconcile waits for a pooled connection before surfacing
/// a transient database connectivity failure.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;

// Compile-time guard: one connection for the advisory lock plus at least one
// for inspection/apply queries.
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);

/// GCP metadata-server endpoint that issues access tokens for the default
/// service account.
const GCP_METADATA_TOKEN_ENDPOINT: &str =
    "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";
/// Scope requested for the source token that authorizes IAM Credentials calls
/// when impersonating another service account.
const GCP_IAM_CREDENTIALS_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
/// A cached pool is rebuilt once its token is within this many seconds of expiry.
const GCP_TOKEN_CACHE_SKEW_SECS: u64 = 300;
/// Lifetime requested for impersonated access tokens.
const GCP_IMPERSONATED_TOKEN_LIFETIME_SECS: u64 = 3600;
/// Per-request timeout for the GCP auth HTTP client.
const GCP_AUTH_HTTP_TIMEOUT_SECS: u64 = 10;
35
36#[derive(Clone)]
37struct CachedPool {
38    resource_version: Option<String>,
39    /// Fingerprint of all referenced secrets' resourceVersions (params mode).
40    secret_fingerprint: Option<String>,
41    /// Expiry for token-backed connection passwords.
42    token_expires_at: Option<SystemTime>,
43    pool: PgPool,
44}
45
/// Result of resolving a connection spec: the URL to connect with plus the
/// expiry of any short-lived token embedded in it.
struct ResolvedConnectionUrl {
    /// Expiry of the token used as the password; `None` for static credentials.
    token_expires_at: Option<SystemTime>,
    /// Connection URL handed to the pool.
    database_url: String,
}

/// A GCP bearer token and the absolute instant at which it expires.
#[derive(Clone)]
struct GcpAccessToken {
    expires_at: SystemTime,
    token: String,
}
56
/// Source of GCP access tokens that are used as database passwords when a
/// params-mode connection specifies `auth`.
///
/// Stored on the operator context as `Arc<dyn GcpAccessTokenProvider>`;
/// NOTE(review): the indirection presumably exists so tests can inject a fake
/// provider — confirm.
trait GcpAccessTokenProvider: Send + Sync {
    /// Fetch a fresh access token for `auth`.
    ///
    /// Returns a boxed future borrowing both `self` and `auth` for `'a`; boxing
    /// keeps the method usable through `dyn GcpAccessTokenProvider`.
    fn fetch_token<'a>(
        &'a self,
        auth: &'a ConnectionAuth,
    ) -> BoxFuture<'a, Result<GcpAccessToken, ContextError>>;
}
63
64#[derive(Clone)]
65struct MetadataGcpAccessTokenProvider {
66    client: reqwest::Client,
67}
68
69impl Default for MetadataGcpAccessTokenProvider {
70    fn default() -> Self {
71        Self {
72            client: reqwest::Client::builder()
73                .no_proxy()
74                .timeout(Duration::from_secs(GCP_AUTH_HTTP_TIMEOUT_SECS))
75                .build()
76                .expect("GCP auth HTTP client should build"),
77        }
78    }
79}
80
81impl GcpAccessTokenProvider for MetadataGcpAccessTokenProvider {
82    fn fetch_token<'a>(
83        &'a self,
84        auth: &'a ConnectionAuth,
85    ) -> BoxFuture<'a, Result<GcpAccessToken, ContextError>> {
86        async move {
87            let scope = auth.gcp_scope();
88            if let Some(target) = auth.gcp_impersonate_service_account() {
89                self.fetch_impersonated_access_token(target, scope).await
90            } else {
91                self.fetch_metadata_access_token(scope).await
92            }
93        }
94        .boxed()
95    }
96}
97
98impl MetadataGcpAccessTokenProvider {
99    async fn fetch_metadata_access_token(
100        &self,
101        scope: &str,
102    ) -> Result<GcpAccessToken, ContextError> {
103        let response = self
104            .client
105            .get(GCP_METADATA_TOKEN_ENDPOINT)
106            .header("Metadata-Flavor", "Google")
107            .query(&[("scopes", scope)])
108            .send()
109            .await
110            .map_err(|source| ContextError::GcpAuthHttp {
111                endpoint: "metadata",
112                source,
113            })?;
114
115        let status = response.status();
116        if !status.is_success() {
117            let body = response_body_for_error(response).await;
118            return Err(ContextError::GcpAuthRejected {
119                endpoint: "metadata".to_string(),
120                status: status.as_u16(),
121                body,
122            });
123        }
124
125        let body: MetadataTokenResponse =
126            response
127                .json()
128                .await
129                .map_err(|source| ContextError::GcpAuthHttp {
130                    endpoint: "metadata",
131                    source,
132                })?;
133
134        if body.access_token.trim().is_empty() {
135            return Err(ContextError::GcpAuthInvalidResponse {
136                detail: "metadata token response omitted access_token".to_string(),
137            });
138        }
139        if body.expires_in == 0 {
140            return Err(ContextError::GcpAuthInvalidResponse {
141                detail: "metadata token response had zero expires_in".to_string(),
142            });
143        }
144
145        Ok(GcpAccessToken {
146            token: body.access_token,
147            expires_at: SystemTime::now() + Duration::from_secs(body.expires_in),
148        })
149    }
150
151    async fn fetch_impersonated_access_token(
152        &self,
153        target_service_account: &str,
154        scope: &str,
155    ) -> Result<GcpAccessToken, ContextError> {
156        let source = self
157            .fetch_metadata_access_token(GCP_IAM_CREDENTIALS_SCOPE)
158            .await?;
159        let encoded_target = percent_encoding::utf8_percent_encode(
160            target_service_account,
161            percent_encoding::NON_ALPHANUMERIC,
162        )
163        .to_string();
164        let endpoint = format!(
165            "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{encoded_target}:generateAccessToken"
166        );
167        let request = GenerateAccessTokenRequest {
168            scope: vec![scope.to_string()],
169            lifetime: format!("{GCP_IMPERSONATED_TOKEN_LIFETIME_SECS}s"),
170        };
171
172        let response = self
173            .client
174            .post(&endpoint)
175            .bearer_auth(&source.token)
176            .json(&request)
177            .send()
178            .await
179            .map_err(|source| ContextError::GcpAuthHttp {
180                endpoint: "iamcredentials",
181                source,
182            })?;
183
184        let status = response.status();
185        if !status.is_success() {
186            let body = response_body_for_error(response).await;
187            return Err(ContextError::GcpAuthRejected {
188                endpoint: "iamcredentials".to_string(),
189                status: status.as_u16(),
190                body,
191            });
192        }
193
194        let body: GenerateAccessTokenResponse =
195            response
196                .json()
197                .await
198                .map_err(|source| ContextError::GcpAuthHttp {
199                    endpoint: "iamcredentials",
200                    source,
201                })?;
202
203        if body.access_token.trim().is_empty() {
204            return Err(ContextError::GcpAuthInvalidResponse {
205                detail: "IAMCredentials response omitted accessToken".to_string(),
206            });
207        }
208        let expires_at = parse_google_expire_time(&body.expire_time).ok_or_else(|| {
209            ContextError::GcpAuthInvalidResponse {
210                detail: format!(
211                    "IAMCredentials response had invalid expireTime {:?}",
212                    body.expire_time
213                ),
214            }
215        })?;
216
217        Ok(GcpAccessToken {
218            token: body.access_token,
219            expires_at,
220        })
221    }
222}
223
224#[derive(Deserialize)]
225struct MetadataTokenResponse {
226    access_token: String,
227    expires_in: u64,
228}
229
230#[derive(Serialize)]
231struct GenerateAccessTokenRequest {
232    scope: Vec<String>,
233    lifetime: String,
234}
235
236#[derive(Deserialize)]
237struct GenerateAccessTokenResponse {
238    #[serde(rename = "accessToken")]
239    access_token: String,
240    #[serde(rename = "expireTime")]
241    expire_time: String,
242}
243
244async fn response_body_for_error(response: reqwest::Response) -> String {
245    match response.text().await {
246        Ok(body) => truncate_for_error(body),
247        Err(error) => format!("failed to read error body: {error}"),
248    }
249}
250
/// Cap `body` at 512 bytes for error messages, appending `...` when cut.
///
/// The cut point moves back to the nearest UTF-8 character boundary so the
/// result is always valid; bodies that already fit are returned unchanged.
fn truncate_for_error(mut body: String) -> String {
    const MAX_ERROR_BODY_BYTES: usize = 512;
    if body.len() > MAX_ERROR_BODY_BYTES {
        // Index 0 is always a char boundary, so the fallback is never used.
        let cut = (0..=MAX_ERROR_BODY_BYTES)
            .rev()
            .find(|&idx| body.is_char_boundary(idx))
            .unwrap_or(0);
        body.truncate(cut);
        body.push_str("...");
    }
    body
}
264
265fn parse_google_expire_time(expire_time: &str) -> Option<SystemTime> {
266    expire_time
267        .parse::<jiff::Timestamp>()
268        .ok()
269        .map(SystemTime::from)
270}
271
272fn token_expires_after_skew(expires_at: Option<SystemTime>, now: SystemTime) -> bool {
273    let Some(expires_at) = expires_at else {
274        return true;
275    };
276    let Some(refresh_at) = now.checked_add(Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS)) else {
277        return false;
278    };
279    expires_at > refresh_at
280}
281
282/// Guard returned by [`OperatorContext::try_lock_database`].
283///
284/// Holding this guard prevents other reconcile loops (within the same process)
285/// from starting work on the same database target. The lock is released when
286/// the guard is dropped.
287pub struct DatabaseLockGuard {
288    key: String,
289    locks: Arc<Mutex<HashMap<String, ()>>>,
290}
291
292impl Drop for DatabaseLockGuard {
293    fn drop(&mut self) {
294        // Best-effort removal — `try_lock` avoids blocking the drop.
295        if let Ok(mut map) = self.locks.try_lock() {
296            map.remove(&self.key);
297            tracing::debug!(database = %self.key, "released in-memory database lock");
298        } else {
299            // Spawn a task to clean up if the mutex is currently held.
300            // Use Handle::try_current() so we don't panic when dropped
301            // outside an active Tokio runtime (e.g. during shutdown).
302            let key = self.key.clone();
303            let locks = Arc::clone(&self.locks);
304            if let Ok(handle) = tokio::runtime::Handle::try_current() {
305                handle.spawn(async move {
306                    locks.lock().await.remove(&key);
307                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
308                });
309                tracing::debug!(
310                    database = %self.key,
311                    "deferred in-memory database lock release to background task"
312                );
313            } else {
314                // No runtime available — fall back to synchronous cleanup
315                // via blocking_lock so the entry is still removed.
316                let mut map = self.locks.blocking_lock();
317                map.remove(&key);
318                tracing::debug!(
319                    database = %key,
320                    "released in-memory database lock (fallback sync)"
321                );
322            }
323        }
324    }
325}
326
/// Shared state for the operator, passed to every reconciliation.
///
/// Derives `Clone`; the pool cache, the lock map, and the token provider are
/// held behind `Arc`, so clones share the same underlying state.
#[derive(Clone)]
pub struct OperatorContext {
    /// Kubernetes client for API calls.
    pub kube_client: kube::Client,

    /// Kubernetes Event recorder for transition-based policy Events.
    pub event_recorder: Recorder,

    /// Cached database connection pools, keyed by `ConnectionSpec::cache_key`
    /// (historically documented as `"namespace/secret-name/secret-key"`;
    /// NOTE(review): params-mode keys may use a different shape — confirm in
    /// `crate::crd`).
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    /// In-process per-database reconciliation locks.
    ///
    /// Prevents concurrent reconcile loops from operating on the same database
    /// within a single operator replica. Cross-replica safety is provided by
    /// PostgreSQL advisory locks (see [`crate::advisory`]).
    ///
    /// The map is used as a set: a key being present means its lock is held.
    database_locks: Arc<Mutex<HashMap<String, ()>>>,

    /// Shared health/metrics state.
    pub observability: OperatorObservability,

    /// Fetches short-lived provider-backed database passwords.
    gcp_token_provider: Arc<dyn GcpAccessTokenProvider>,
}
351
352impl OperatorContext {
353    /// Create a new operator context with an empty pool cache.
354    pub fn new(
355        kube_client: kube::Client,
356        observability: OperatorObservability,
357        event_recorder: Recorder,
358    ) -> Self {
359        Self {
360            kube_client,
361            event_recorder,
362            pool_cache: Arc::new(RwLock::new(HashMap::new())),
363            observability,
364            database_locks: Arc::new(Mutex::new(HashMap::new())),
365            gcp_token_provider: Arc::new(MetadataGcpAccessTokenProvider::default()),
366        }
367    }
368
369    /// Try to acquire the in-process lock for the given database identity.
370    ///
371    /// Returns `Some(guard)` if no other reconcile is in progress for this
372    /// database, `None` if one is already running. The lock is released when
373    /// the guard is dropped.
374    pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
375        let mut locks = self.database_locks.lock().await;
376        if locks.contains_key(database_identity) {
377            tracing::info!(
378                database = %database_identity,
379                "in-memory database lock contention — another reconcile is in progress"
380            );
381            return None;
382        }
383        locks.insert(database_identity.to_string(), ());
384        tracing::debug!(database = %database_identity, "acquired in-memory database lock");
385        Some(DatabaseLockGuard {
386            key: database_identity.to_string(),
387            locks: Arc::clone(&self.database_locks),
388        })
389    }
390
391    /// Resolve a param from either its literal value or a Secret reference.
392    ///
393    /// Returns `Ok(Some(value))` if one is set, `Ok(None)` if neither is set.
394    async fn resolve_param(
395        &self,
396        namespace: &str,
397        literal: &Option<String>,
398        secret: &Option<SecretKeySelector>,
399    ) -> Result<Option<String>, ContextError> {
400        if let Some(val) = literal {
401            return Ok(Some(val.clone()));
402        }
403        if let Some(sel) = secret {
404            return Ok(Some(
405                self.fetch_secret_value(namespace, &sel.name, &sel.key)
406                    .await?,
407            ));
408        }
409        Ok(None)
410    }
411
412    /// Resolve a [`ConnectionSpec`] into a PostgreSQL connection URL string.
413    ///
414    /// - **URL mode** (`secret_ref` is Some): reads the Secret key as a connection URL.
415    /// - **Params mode** (`params` is Some): resolves each field and constructs a URL.
416    pub async fn resolve_connection_url(
417        &self,
418        namespace: &str,
419        connection: &ConnectionSpec,
420    ) -> Result<String, ContextError> {
421        Ok(self
422            .resolve_connection_url_with_metadata(namespace, connection)
423            .await?
424            .database_url)
425    }
426
427    async fn resolve_connection_url_with_metadata(
428        &self,
429        namespace: &str,
430        connection: &ConnectionSpec,
431    ) -> Result<ResolvedConnectionUrl, ContextError> {
432        if let Some(ref secret_ref) = connection.secret_ref {
433            // URL mode — read the full connection URL from the Secret.
434            let database_url = self
435                .fetch_secret_value(
436                    namespace,
437                    &secret_ref.name,
438                    connection.effective_secret_key(),
439                )
440                .await?;
441            Ok(ResolvedConnectionUrl {
442                database_url,
443                token_expires_at: None,
444            })
445        } else if let Some(ref params) = connection.params {
446            // Params mode — resolve each field and build the URL.
447            let host = self
448                .resolve_param(namespace, &params.host, &params.host_secret)
449                .await?
450                .ok_or_else(|| ContextError::EmptyResolvedValue {
451                    field: "host".to_string(),
452                })?;
453            if host.trim().is_empty() {
454                return Err(ContextError::EmptyResolvedValue {
455                    field: "host".to_string(),
456                });
457            }
458
459            let port_str = params.port.map(|p| p.to_string());
460            let port = self
461                .resolve_param(namespace, &port_str, &params.port_secret)
462                .await?
463                .unwrap_or_else(|| "5432".to_string());
464            if port.trim().is_empty() {
465                return Err(ContextError::EmptyResolvedValue {
466                    field: "port".to_string(),
467                });
468            }
469
470            let dbname = self
471                .resolve_param(namespace, &params.dbname, &params.dbname_secret)
472                .await?
473                .ok_or_else(|| ContextError::EmptyResolvedValue {
474                    field: "dbname".to_string(),
475                })?;
476            if dbname.trim().is_empty() {
477                return Err(ContextError::EmptyResolvedValue {
478                    field: "dbname".to_string(),
479                });
480            }
481
482            let username = self
483                .resolve_param(namespace, &params.username, &params.username_secret)
484                .await?
485                .ok_or_else(|| ContextError::EmptyResolvedValue {
486                    field: "username".to_string(),
487                })?;
488            if username.trim().is_empty() {
489                return Err(ContextError::EmptyResolvedValue {
490                    field: "username".to_string(),
491                });
492            }
493
494            let (password, token_expires_at) = if let Some(auth) = &params.auth {
495                let token = self.gcp_token_provider.fetch_token(auth).await?;
496                (token.token, Some(token.expires_at))
497            } else {
498                let password = self
499                    .resolve_param(namespace, &params.password, &params.password_secret)
500                    .await?
501                    .ok_or_else(|| ContextError::EmptyResolvedValue {
502                        field: "password".to_string(),
503                    })?;
504                (password, None)
505            };
506            if password.trim().is_empty() {
507                return Err(ContextError::EmptyResolvedValue {
508                    field: "password".to_string(),
509                });
510            }
511
512            use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
513            let encoded_username = utf8_percent_encode(&username, NON_ALPHANUMERIC).to_string();
514            let encoded_password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string();
515
516            let mut url = format!(
517                "postgresql://{encoded_username}:{encoded_password}@{host}:{port}/{dbname}"
518            );
519
520            let ssl_mode = self
521                .resolve_param(namespace, &params.ssl_mode, &params.ssl_mode_secret)
522                .await?
523                .or_else(|| params.auth.as_ref().map(|_| "require".to_string()));
524            if let Some(ssl_mode) = ssl_mode {
525                // Validate sslMode at runtime — CRD validation only catches
526                // literal values; a secret ref could resolve to anything.
527                if !crate::crd::VALID_SSL_MODES.contains(&ssl_mode.as_str()) {
528                    return Err(ContextError::InvalidResolvedSslMode { value: ssl_mode });
529                }
530                url.push_str("?sslmode=");
531                url.push_str(&ssl_mode);
532            }
533
534            Ok(ResolvedConnectionUrl {
535                database_url: url,
536                token_expires_at,
537            })
538        } else {
539            Err(ContextError::SecretMissing {
540                name: "connection".to_string(),
541                key: "neither secretRef nor params is set".to_string(),
542            })
543        }
544    }
545
546    /// Get or create a PgPool for the given connection spec.
547    ///
548    /// Resolves the connection URL from the referenced Secret(s),
549    /// and caches the resulting pool for reuse.
550    pub async fn get_or_create_pool(
551        &self,
552        namespace: &str,
553        connection: &ConnectionSpec,
554    ) -> Result<PgPool, ContextError> {
555        let cache_key = connection.cache_key(namespace);
556
557        // For URL mode, we can do resource-version-based cache invalidation.
558        // For params mode, compute a fingerprint from all referenced secrets'
559        // resourceVersions so that secret rotations invalidate the cache.
560        let (resource_version, secret_fingerprint) =
561            if let Some(ref secret_ref) = connection.secret_ref {
562                let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
563                    kube::Api::namespaced(self.kube_client.clone(), namespace);
564                let secret = secrets_api.get(&secret_ref.name).await.map_err(|err| {
565                    ContextError::SecretFetch {
566                        name: secret_ref.name.clone(),
567                        namespace: namespace.to_string(),
568                        source: err,
569                    }
570                })?;
571                (secret.metadata.resource_version, None)
572            } else if connection.params.is_some() {
573                // Params mode — collect all referenced secret names and fetch their
574                // resourceVersions to build a fingerprint.
575                let mut secret_names = std::collections::BTreeSet::new();
576                connection.collect_secret_names(&mut secret_names);
577
578                if secret_names.is_empty() {
579                    // All values are literals — no secrets to watch.
580                    (None, Some(String::new()))
581                } else {
582                    let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
583                        kube::Api::namespaced(self.kube_client.clone(), namespace);
584                    let mut fingerprint_parts = Vec::new();
585                    for name in &secret_names {
586                        let secret = secrets_api.get(name).await.map_err(|err| {
587                            ContextError::SecretFetch {
588                                name: name.clone(),
589                                namespace: namespace.to_string(),
590                                source: err,
591                            }
592                        })?;
593                        let rv = secret
594                            .metadata
595                            .resource_version
596                            .unwrap_or_else(|| "unknown".to_string());
597                        fingerprint_parts.push(format!("{name}={rv}"));
598                    }
599                    (None, Some(fingerprint_parts.join(",")))
600                }
601            } else {
602                (None, None)
603            };
604
605        // Check cache.
606        {
607            let cache = self.pool_cache.read().await;
608            if let Some(cached) = cache.get(&cache_key) {
609                // URL mode: reuse if the Secret's resource_version matches.
610                // Params mode: reuse if the secret fingerprint matches.
611                let version_matches = match (&resource_version, &cached.resource_version) {
612                    (Some(current), Some(cached_rv)) => current == cached_rv,
613                    _ => true,
614                };
615                let fingerprint_matches = match (&secret_fingerprint, &cached.secret_fingerprint) {
616                    (Some(current), Some(cached_fp)) => current == cached_fp,
617                    (None, None) => true,
618                    _ => false,
619                };
620                let token_fresh =
621                    token_expires_after_skew(cached.token_expires_at, SystemTime::now());
622                if version_matches && fingerprint_matches && token_fresh {
623                    return Ok(cached.pool.clone());
624                }
625            }
626        }
627
628        let resolved = self
629            .resolve_connection_url_with_metadata(namespace, connection)
630            .await?;
631
632        // Create pool with explicit sizing. Reconciliation holds one dedicated
633        // connection for PostgreSQL advisory locking and needs additional pool
634        // capacity for inspection/apply queries.
635        let pool = PgPoolOptions::new()
636            .max_connections(POOL_MAX_CONNECTIONS)
637            .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
638            .connect(&resolved.database_url)
639            .await
640            .map_err(|err| ContextError::DatabaseConnect { source: err })?;
641
642        // Cache it (write lock).
643        {
644            let mut cache = self.pool_cache.write().await;
645            cache.insert(
646                cache_key,
647                CachedPool {
648                    resource_version,
649                    secret_fingerprint,
650                    token_expires_at: resolved.token_expires_at,
651                    pool: pool.clone(),
652                },
653            );
654        }
655
656        Ok(pool)
657    }
658
659    /// Fetch a single string value from a Kubernetes Secret.
660    ///
661    /// Used to resolve role passwords from Secret references at reconcile time.
662    pub async fn fetch_secret_value(
663        &self,
664        namespace: &str,
665        secret_name: &str,
666        secret_key: &str,
667    ) -> Result<String, ContextError> {
668        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
669            kube::Api::namespaced(self.kube_client.clone(), namespace);
670
671        let secret =
672            secrets_api
673                .get(secret_name)
674                .await
675                .map_err(|err| ContextError::SecretFetch {
676                    name: secret_name.to_string(),
677                    namespace: namespace.to_string(),
678                    source: err,
679                })?;
680
681        let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
682            name: secret_name.to_string(),
683            key: secret_key.to_string(),
684        })?;
685
686        let value_bytes = data
687            .get(secret_key)
688            .ok_or_else(|| ContextError::SecretMissing {
689                name: secret_name.to_string(),
690                key: secret_key.to_string(),
691            })?;
692
693        String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
694            name: secret_name.to_string(),
695            key: secret_key.to_string(),
696        })
697    }
698
699    /// Remove a cached pool (e.g. when secret changes or CR is deleted).
700    pub async fn evict_pool(&self, namespace: &str, connection: &ConnectionSpec) {
701        let cache_key = connection.cache_key(namespace);
702        let mut cache = self.pool_cache.write().await;
703        cache.remove(&cache_key);
704    }
705}
706
/// Errors from operator context operations.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call that reads a Secret failed.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },

    /// A required Secret value is unavailable.
    ///
    /// Raised when the Secret has no data, lacks `key`, or its value is not
    /// valid UTF-8. It is also reused (with `name` = `"connection"`) when a
    /// connection spec sets neither `secretRef` nor `params`.
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },

    /// Creating the PostgreSQL connection pool failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },

    /// A params-mode connection field was missing or blank after resolution.
    #[error("connection param \"{field}\" resolved to an empty or whitespace-only value")]
    EmptyResolvedValue { field: String },

    /// `sslMode` resolved (e.g. from a Secret, which CRD validation cannot
    /// check) to a value outside `crate::crd::VALID_SSL_MODES`.
    #[error(
        "connection param sslMode resolved to invalid value \"{value}\" (expected one of: disable, allow, prefer, require, verify-ca, verify-full)"
    )]
    InvalidResolvedSslMode { value: String },

    /// Sending a request to, or decoding the JSON from, a GCP auth endpoint
    /// (`"metadata"` or `"iamcredentials"`) failed.
    #[error("failed to fetch GCP auth token from {endpoint}: {source}")]
    GcpAuthHttp {
        endpoint: &'static str,
        source: reqwest::Error,
    },

    /// A GCP auth endpoint answered with a non-success HTTP status; `body` is
    /// the response body, truncated for the message.
    #[error("GCP auth token endpoint {endpoint} returned HTTP {status}: {body}")]
    GcpAuthRejected {
        endpoint: String,
        status: u16,
        body: String,
    },

    /// A GCP auth response decoded but carried an unusable token or expiry.
    #[error("GCP auth token response was invalid: {detail}")]
    GcpAuthInvalidResponse { detail: String },
}
747
748impl ContextError {
749    /// Returns true when a Secret fetch failed due to a non-transient client-side API error.
750    pub fn is_secret_fetch_non_transient(&self) -> bool {
751        matches!(
752            self,
753            ContextError::SecretFetch {
754                source: kube::Error::Api(response),
755                ..
756            } if (400..500).contains(&response.code) && response.code != 429
757        )
758    }
759
760    pub fn is_gcp_auth_non_transient(&self) -> bool {
761        matches!(
762            self,
763            ContextError::GcpAuthRejected { status, .. }
764                if (400..500).contains(status) && *status != 429
765        ) || matches!(self, ContextError::GcpAuthInvalidResponse { .. })
766    }
767}
768
#[cfg(test)]
mod tests {
    //! Unit tests for error classification, GCP token helpers, and the
    //! per-database reconcile lock.
    //!
    //! The concurrent lock tests coordinate tasks with barriers instead of
    //! sleeps, so which task holds the lock while another one tries is fixed
    //! by construction rather than by scheduler timing.

    use super::*;

    use std::sync::atomic::{AtomicUsize, Ordering};

    use tokio::sync::Barrier;

    /// Shared table of locked database identities used by the lock helper.
    type LockTable = Arc<Mutex<HashMap<String, ()>>>;

    /// Creates an empty lock table.
    fn new_lock_table() -> LockTable {
        Arc::new(Mutex::new(HashMap::new()))
    }

    /// Builds a `ContextError::SecretFetch` for secret `default/db-credentials`
    /// wrapping a Kubernetes API failure with the given message, reason and
    /// HTTP status code.
    fn secret_fetch_error(message: &str, reason: &str, code: u16) -> ContextError {
        ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure(message, reason)
                    .with_code(code)
                    .boxed(),
            ),
        }
    }

    /// Helper to test locking without a real kube client.
    ///
    /// NOTE(review): this re-implements the check-then-insert locking logic
    /// instead of calling the operator's own method, so these tests only stay
    /// meaningful while the two are kept in sync.
    struct OperatorContextLockHelper {
        database_locks: LockTable,
    }

    impl OperatorContextLockHelper {
        /// Creates a helper that shares `database_locks` with other helpers.
        fn new(database_locks: &LockTable) -> Self {
            Self {
                database_locks: Arc::clone(database_locks),
            }
        }

        /// Returns a guard if `database_identity` is not already in the table,
        /// recording it as locked; returns `None` if it is already present.
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    #[test]
    fn pool_cache_key_format() {
        // Verify the cache key format is "namespace/secret-name/secret-key".
        // NOTE(review): this only exercises `format!`, not the code that builds
        // the real cache key; consider asserting against that builder instead.
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        let error = secret_fetch_error("secrets \"db-credentials\" not found", "NotFound", 404);
        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = secret_fetch_error("forbidden", "Forbidden", 403);
        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_server_error_remains_transient() {
        let error = secret_fetch_error("internal error", "InternalError", 500);
        assert!(!error.is_secret_fetch_non_transient());
    }

    #[test]
    fn gcp_auth_client_error_is_non_transient() {
        let error = ContextError::GcpAuthRejected {
            endpoint: "metadata".into(),
            status: 403,
            body: "forbidden".into(),
        };

        assert!(error.is_gcp_auth_non_transient());
    }

    #[test]
    fn gcp_auth_rate_limit_remains_transient() {
        let error = ContextError::GcpAuthRejected {
            endpoint: "metadata".into(),
            status: 429,
            body: "rate limited".into(),
        };

        assert!(!error.is_gcp_auth_non_transient());
    }

    #[test]
    fn token_expiry_uses_five_minute_refresh_skew() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000);
        // One second beyond the skew window counts as expiring after the skew...
        assert!(token_expires_after_skew(
            Some(now + Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS + 1)),
            now
        ));
        // ...while landing exactly on the skew boundary does not.
        assert!(!token_expires_after_skew(
            Some(now + Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS)),
            now
        ));
    }

    #[test]
    fn parse_google_expire_time_accepts_rfc3339() {
        let parsed =
            parse_google_expire_time("2026-05-14T02:30:00Z").expect("expireTime should parse");
        // 2026-05-14T02:30:00Z is 20_587 days plus 9_000 seconds after the epoch.
        assert_eq!(
            parsed
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            1_778_725_800
        );
    }

    #[test]
    fn truncate_for_error_keeps_utf8_boundary() {
        // Every character is two bytes, so a naive byte-count cut would have a
        // 50% chance of landing in the middle of a character.
        let body = "é".repeat(300);
        let original_len = body.len();
        let truncated = truncate_for_error(body);

        let kept = truncated
            .strip_suffix("...")
            .expect("long bodies should be truncated with a trailing ellipsis");
        // A `String` is always valid UTF-8, so checking `is_char_boundary` on it
        // proves nothing. What matters is that no character was split or
        // replaced (e.g. by lossy re-decoding of a byte-level cut).
        assert!(
            kept.chars().all(|c| c == 'é'),
            "truncated prefix must contain only whole original characters: {kept:?}"
        );
        assert!(
            kept.len() < original_len,
            "truncated prefix must be shorter than the original body"
        );
    }

    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let ctx = OperatorContextLockHelper::new(&new_lock_table());
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let ctx = OperatorContextLockHelper::new(&new_lock_table());

        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let ctx = OperatorContextLockHelper::new(&new_lock_table());

        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let ctx = OperatorContextLockHelper::new(&new_lock_table());

        let guard = ctx.try_lock("db-a").await.expect("should acquire");
        drop(guard);

        // After drop, should be able to acquire again.
        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        // Simulate two concurrent reconciles for the same database. Each task
        // keeps whatever guard it obtained until both have attempted the lock,
        // so the winner cannot release before the loser tries.
        let locks = new_lock_table();
        let attempted = Arc::new(Barrier::new(2));

        let spawn_reconcile = || {
            let ctx = OperatorContextLockHelper::new(&locks);
            let attempted = Arc::clone(&attempted);
            tokio::spawn(async move {
                let guard = ctx.try_lock("shared-db").await;
                attempted.wait().await;
                guard.is_some()
            })
        };

        let (r1, r2) = tokio::join!(spawn_reconcile(), spawn_reconcile());
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();

        // Exactly one should succeed.
        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        // Spawn many tasks all racing to lock the same database.
        let locks = new_lock_table();
        let concurrency: usize = 50;
        let acquired_count = Arc::new(AtomicUsize::new(0));
        // `start` releases every task at once; `attempted` keeps the winner's
        // guard alive until every task has tried, so no task can acquire the
        // lock after the winner releases it.
        let start = Arc::new(Barrier::new(concurrency));
        let attempted = Arc::new(Barrier::new(concurrency));

        let handles: Vec<_> = (0..concurrency)
            .map(|_| {
                let ctx = OperatorContextLockHelper::new(&locks);
                let count = Arc::clone(&acquired_count);
                let start = Arc::clone(&start);
                let attempted = Arc::clone(&attempted);
                tokio::spawn(async move {
                    start.wait().await;
                    let guard = ctx.try_lock("contested-db").await;
                    if guard.is_some() {
                        count.fetch_add(1, Ordering::SeqCst);
                    }
                    attempted.wait().await;
                    drop(guard);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        // Exactly one task should have acquired the lock.
        let total = acquired_count.load(Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        // Many tasks each locking a different database — all should succeed.
        let locks = new_lock_table();
        let concurrency: usize = 50;
        let acquired_count = Arc::new(AtomicUsize::new(0));
        // `held` guarantees every guard is alive at the same moment, proving
        // that locks on distinct databases truly coexist.
        let start = Arc::new(Barrier::new(concurrency));
        let held = Arc::new(Barrier::new(concurrency));

        let handles: Vec<_> = (0..concurrency)
            .map(|i| {
                let ctx = OperatorContextLockHelper::new(&locks);
                let count = Arc::clone(&acquired_count);
                let start = Arc::clone(&start);
                let held = Arc::clone(&held);
                tokio::spawn(async move {
                    start.wait().await;
                    let db_name = format!("db-{i}");
                    let guard = ctx.try_lock(&db_name).await;
                    if guard.is_some() {
                        count.fetch_add(1, Ordering::SeqCst);
                    }
                    held.wait().await;
                    drop(guard);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let total = acquired_count.load(Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        // Repeatedly acquire and release the same database lock from many tasks.
        // Each task attempts the lock in a loop until it succeeds, simulating
        // the requeue-after-contention pattern used in the reconciler.
        let locks = new_lock_table();
        let concurrency: usize = 20;
        let success_count = Arc::new(AtomicUsize::new(0));
        let start = Arc::new(Barrier::new(concurrency));

        let handles: Vec<_> = (0..concurrency)
            .map(|_| {
                let ctx = OperatorContextLockHelper::new(&locks);
                let count = Arc::clone(&success_count);
                let start = Arc::clone(&start);
                tokio::spawn(async move {
                    start.wait().await;
                    // Retry up to 100 times with a small sleep between attempts,
                    // simulating the jittered requeue pattern.
                    for _ in 0..100 {
                        if let Some(_guard) = ctx.try_lock("shared-db").await {
                            count.fetch_add(1, Ordering::SeqCst);
                            // Brief simulated work, then the guard drops (releasing the lock).
                            tokio::time::sleep(Duration::from_millis(1)).await;
                            return;
                        }
                        tokio::time::sleep(Duration::from_millis(1)).await;
                    }
                    // Should not reach here in practice — fail the test if we do.
                    panic!("task failed to acquire lock after 100 retries");
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let total = success_count.load(Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}