1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, SystemTime};
6
7use futures::future::{BoxFuture, FutureExt};
8use kube::runtime::events::Recorder;
9use serde::{Deserialize, Serialize};
10
11use sqlx::postgres::{PgPool, PgPoolOptions};
12use tokio::sync::{Mutex, RwLock};
13
14use crate::crd::{ConnectionAuth, ConnectionSpec, SecretKeySelector};
15use crate::observability::OperatorObservability;
16
/// Maximum number of connections each cached `PgPool` may open.
const POOL_MAX_CONNECTIONS: u32 = 5;

/// How long acquiring a connection from a pool may wait before erroring.
const POOL_ACQUIRE_TIMEOUT_SECS: u64 = 10;

// Compile-time guard: the pool size must stay at least 2.
const _: () = assert!(POOL_MAX_CONNECTIONS >= 2);

/// GCP metadata-server endpoint that issues access tokens for the
/// instance's default service account.
const GCP_METADATA_TOKEN_ENDPOINT: &str =
    "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token";
/// OAuth scope used for the token that authorizes IAMCredentials calls.
const GCP_IAM_CREDENTIALS_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";
/// Cached tokens are considered stale this many seconds before real expiry.
const GCP_TOKEN_CACHE_SKEW_SECS: u64 = 300;
/// Lifetime requested for impersonated service-account tokens.
const GCP_IMPERSONATED_TOKEN_LIFETIME_SECS: u64 = 3600;
/// Timeout applied to every GCP auth HTTP request.
const GCP_AUTH_HTTP_TIMEOUT_SECS: u64 = 10;
35
/// A cached database connection pool plus the metadata used to decide
/// whether the cache entry is still valid.
#[derive(Clone)]
struct CachedPool {
    // resourceVersion of the Secret when the `secretRef` form was used.
    resource_version: Option<String>,
    // Combined "name=resourceVersion" fingerprint of all Secrets referenced
    // by the `params` form, when that form was used.
    secret_fingerprint: Option<String>,
    // Expiry of the GCP access token embedded as the password, if any.
    token_expires_at: Option<SystemTime>,
    pool: PgPool,
}
45
/// A fully resolved PostgreSQL connection URL plus the expiry of any
/// short-lived credential embedded in it.
struct ResolvedConnectionUrl {
    database_url: String,
    // Set when the password is a GCP access token; `None` for static secrets.
    token_expires_at: Option<SystemTime>,
}
50
/// A GCP OAuth access token and its absolute expiry time.
#[derive(Clone)]
struct GcpAccessToken {
    token: String,
    expires_at: SystemTime,
}
56
/// Abstraction over fetching GCP access tokens for a connection's auth
/// configuration.
trait GcpAccessTokenProvider: Send + Sync {
    /// Fetch an access token appropriate for `auth` (direct or impersonated).
    fn fetch_token<'a>(
        &'a self,
        auth: &'a ConnectionAuth,
    ) -> BoxFuture<'a, Result<GcpAccessToken, ContextError>>;
}
63
/// Token provider backed by the GCP metadata server (and, when
/// impersonation is configured, the IAMCredentials API).
#[derive(Clone)]
struct MetadataGcpAccessTokenProvider {
    client: reqwest::Client,
}
68
69impl Default for MetadataGcpAccessTokenProvider {
70 fn default() -> Self {
71 Self {
72 client: reqwest::Client::builder()
73 .no_proxy()
74 .timeout(Duration::from_secs(GCP_AUTH_HTTP_TIMEOUT_SECS))
75 .build()
76 .expect("GCP auth HTTP client should build"),
77 }
78 }
79}
80
81impl GcpAccessTokenProvider for MetadataGcpAccessTokenProvider {
82 fn fetch_token<'a>(
83 &'a self,
84 auth: &'a ConnectionAuth,
85 ) -> BoxFuture<'a, Result<GcpAccessToken, ContextError>> {
86 async move {
87 let scope = auth.gcp_scope();
88 if let Some(target) = auth.gcp_impersonate_service_account() {
89 self.fetch_impersonated_access_token(target, scope).await
90 } else {
91 self.fetch_metadata_access_token(scope).await
92 }
93 }
94 .boxed()
95 }
96}
97
impl MetadataGcpAccessTokenProvider {
    /// Fetches an access token for the instance's default service account
    /// from the GCP metadata server, requesting the given OAuth `scope`.
    ///
    /// # Errors
    /// `GcpAuthHttp` on transport/decoding failures, `GcpAuthRejected` on
    /// non-2xx responses, `GcpAuthInvalidResponse` when the body omits the
    /// token or reports a zero `expires_in`.
    async fn fetch_metadata_access_token(
        &self,
        scope: &str,
    ) -> Result<GcpAccessToken, ContextError> {
        // The metadata server requires the "Metadata-Flavor: Google" header
        // and rejects requests without it.
        let response = self
            .client
            .get(GCP_METADATA_TOKEN_ENDPOINT)
            .header("Metadata-Flavor", "Google")
            .query(&[("scopes", scope)])
            .send()
            .await
            .map_err(|source| ContextError::GcpAuthHttp {
                endpoint: "metadata",
                source,
            })?;

        let status = response.status();
        if !status.is_success() {
            // Capture a truncated body so the error is actionable.
            let body = response_body_for_error(response).await;
            return Err(ContextError::GcpAuthRejected {
                endpoint: "metadata".to_string(),
                status: status.as_u16(),
                body,
            });
        }

        let body: MetadataTokenResponse =
            response
                .json()
                .await
                .map_err(|source| ContextError::GcpAuthHttp {
                    endpoint: "metadata",
                    source,
                })?;

        // Defend against syntactically valid but useless responses.
        if body.access_token.trim().is_empty() {
            return Err(ContextError::GcpAuthInvalidResponse {
                detail: "metadata token response omitted access_token".to_string(),
            });
        }
        if body.expires_in == 0 {
            return Err(ContextError::GcpAuthInvalidResponse {
                detail: "metadata token response had zero expires_in".to_string(),
            });
        }

        Ok(GcpAccessToken {
            token: body.access_token,
            // `expires_in` is relative; pin it to an absolute instant now.
            expires_at: SystemTime::now() + Duration::from_secs(body.expires_in),
        })
    }

    /// Obtains a token for `target_service_account` by first fetching a
    /// cloud-platform-scoped token from the metadata server, then calling
    /// the IAMCredentials `generateAccessToken` API with it.
    ///
    /// # Errors
    /// Propagates metadata-server errors, plus `GcpAuthHttp`,
    /// `GcpAuthRejected`, and `GcpAuthInvalidResponse` from the
    /// IAMCredentials call itself.
    async fn fetch_impersonated_access_token(
        &self,
        target_service_account: &str,
        scope: &str,
    ) -> Result<GcpAccessToken, ContextError> {
        // The impersonation call is authorized with a cloud-platform-scoped
        // token; the caller-requested scope applies to the generated token.
        let source = self
            .fetch_metadata_access_token(GCP_IAM_CREDENTIALS_SCOPE)
            .await?;
        // Percent-encode the account name so it is safe inside the URL path.
        let encoded_target = percent_encoding::utf8_percent_encode(
            target_service_account,
            percent_encoding::NON_ALPHANUMERIC,
        )
        .to_string();
        let endpoint = format!(
            "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{encoded_target}:generateAccessToken"
        );
        let request = GenerateAccessTokenRequest {
            scope: vec![scope.to_string()],
            lifetime: format!("{GCP_IMPERSONATED_TOKEN_LIFETIME_SECS}s"),
        };

        let response = self
            .client
            .post(&endpoint)
            .bearer_auth(&source.token)
            .json(&request)
            .send()
            .await
            .map_err(|source| ContextError::GcpAuthHttp {
                endpoint: "iamcredentials",
                source,
            })?;

        let status = response.status();
        if !status.is_success() {
            let body = response_body_for_error(response).await;
            return Err(ContextError::GcpAuthRejected {
                endpoint: "iamcredentials".to_string(),
                status: status.as_u16(),
                body,
            });
        }

        let body: GenerateAccessTokenResponse =
            response
                .json()
                .await
                .map_err(|source| ContextError::GcpAuthHttp {
                    endpoint: "iamcredentials",
                    source,
                })?;

        if body.access_token.trim().is_empty() {
            return Err(ContextError::GcpAuthInvalidResponse {
                detail: "IAMCredentials response omitted accessToken".to_string(),
            });
        }
        // `expireTime` is an absolute timestamp; reject anything unparsable.
        let expires_at = parse_google_expire_time(&body.expire_time).ok_or_else(|| {
            ContextError::GcpAuthInvalidResponse {
                detail: format!(
                    "IAMCredentials response had invalid expireTime {:?}",
                    body.expire_time
                ),
            }
        })?;

        Ok(GcpAccessToken {
            token: body.access_token,
            expires_at,
        })
    }
}
223
/// Token response returned by the metadata server's token endpoint.
#[derive(Deserialize)]
struct MetadataTokenResponse {
    access_token: String,
    // Remaining lifetime in seconds (relative, not an absolute timestamp).
    expires_in: u64,
}
229
/// Request body for the IAMCredentials `generateAccessToken` call.
#[derive(Serialize)]
struct GenerateAccessTokenRequest {
    scope: Vec<String>,
    // Requested token lifetime in Google duration format, e.g. "3600s".
    lifetime: String,
}
235
/// Response body of the IAMCredentials `generateAccessToken` call.
#[derive(Deserialize)]
struct GenerateAccessTokenResponse {
    #[serde(rename = "accessToken")]
    access_token: String,
    // Absolute expiry timestamp, parsed by `parse_google_expire_time`.
    #[serde(rename = "expireTime")]
    expire_time: String,
}
243
244async fn response_body_for_error(response: reqwest::Response) -> String {
245 match response.text().await {
246 Ok(body) => truncate_for_error(body),
247 Err(error) => format!("failed to read error body: {error}"),
248 }
249}
250
/// Caps `body` at 512 bytes for error reporting, backing up to the nearest
/// UTF-8 character boundary and appending "..." when truncation occurred.
fn truncate_for_error(body: String) -> String {
    const MAX_ERROR_BODY_BYTES: usize = 512;
    if body.len() <= MAX_ERROR_BODY_BYTES {
        return body;
    }
    // Largest index <= the cap that lies on a char boundary; index 0 is
    // always a boundary, so the search cannot fail.
    let cut = (0..=MAX_ERROR_BODY_BYTES)
        .rev()
        .find(|&i| body.is_char_boundary(i))
        .unwrap_or(0);
    let mut truncated = body;
    truncated.truncate(cut);
    truncated.push_str("...");
    truncated
}
264
265fn parse_google_expire_time(expire_time: &str) -> Option<SystemTime> {
266 expire_time
267 .parse::<jiff::Timestamp>()
268 .ok()
269 .map(SystemTime::from)
270}
271
272fn token_expires_after_skew(expires_at: Option<SystemTime>, now: SystemTime) -> bool {
273 let Some(expires_at) = expires_at else {
274 return true;
275 };
276 let Some(refresh_at) = now.checked_add(Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS)) else {
277 return false;
278 };
279 expires_at > refresh_at
280}
281
/// RAII guard for an in-memory, per-database reconcile lock.
///
/// Dropping the guard removes `key` from the shared lock map (see the
/// `Drop` impl), allowing the next reconcile of that database to proceed.
pub struct DatabaseLockGuard {
    key: String,
    locks: Arc<Mutex<HashMap<String, ()>>>,
}
291
impl Drop for DatabaseLockGuard {
    /// Releases the lock by removing this guard's key from the shared map.
    ///
    /// `Drop` cannot be async, so three strategies are tried in order:
    /// an uncontended `try_lock`, then a spawned background task (when a
    /// Tokio runtime is available), and finally a blocking lock (when no
    /// runtime is present, so blocking cannot stall an executor thread).
    fn drop(&mut self) {
        if let Ok(mut map) = self.locks.try_lock() {
            // Fast path: the mutex was free; release synchronously.
            map.remove(&self.key);
            tracing::debug!(database = %self.key, "released in-memory database lock");
        } else {
            let key = self.key.clone();
            let locks = Arc::clone(&self.locks);
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                // Contended while on a runtime: defer the removal to a task
                // instead of blocking the executor.
                handle.spawn(async move {
                    locks.lock().await.remove(&key);
                    tracing::debug!(database = %key, "released in-memory database lock (deferred)");
                });
                tracing::debug!(
                    database = %self.key,
                    "deferred in-memory database lock release to background task"
                );
            } else {
                // No runtime in scope; `blocking_lock` is safe here.
                let mut map = self.locks.blocking_lock();
                map.remove(&key);
                tracing::debug!(
                    database = %key,
                    "released in-memory database lock (fallback sync)"
                );
            }
        }
    }
}
326
/// Shared, cloneable state handed to each reconcile invocation.
#[derive(Clone)]
pub struct OperatorContext {
    /// Kubernetes API client used for Secret lookups.
    pub kube_client: kube::Client,

    /// Recorder for emitting Kubernetes Events.
    pub event_recorder: Recorder,

    // Connection pools keyed by `ConnectionSpec::cache_key(namespace)`;
    // entries are revalidated against Secret versions and token expiry.
    pool_cache: Arc<RwLock<HashMap<String, CachedPool>>>,
    // In-memory "locks": a key present in the map means that database is
    // currently being reconciled.
    database_locks: Arc<Mutex<HashMap<String, ()>>>,

    /// Observability handles (see `OperatorObservability`).
    pub observability: OperatorObservability,

    // Pluggable GCP token source; the metadata-server implementation by
    // default.
    gcp_token_provider: Arc<dyn GcpAccessTokenProvider>,
}
351
impl OperatorContext {
    /// Builds a context with empty pool/lock caches and the default
    /// metadata-server-backed GCP token provider.
    pub fn new(
        kube_client: kube::Client,
        observability: OperatorObservability,
        event_recorder: Recorder,
    ) -> Self {
        Self {
            kube_client,
            event_recorder,
            pool_cache: Arc::new(RwLock::new(HashMap::new())),
            observability,
            database_locks: Arc::new(Mutex::new(HashMap::new())),
            gcp_token_provider: Arc::new(MetadataGcpAccessTokenProvider::default()),
        }
    }

    /// Attempts to take the in-memory reconcile lock for `database_identity`.
    ///
    /// Returns `None` when another reconcile already holds the lock. The
    /// returned guard releases the lock when dropped.
    pub async fn try_lock_database(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
        let mut locks = self.database_locks.lock().await;
        if locks.contains_key(database_identity) {
            tracing::info!(
                database = %database_identity,
                "in-memory database lock contention — another reconcile is in progress"
            );
            return None;
        }
        locks.insert(database_identity.to_string(), ());
        tracing::debug!(database = %database_identity, "acquired in-memory database lock");
        Some(DatabaseLockGuard {
            key: database_identity.to_string(),
            locks: Arc::clone(&self.database_locks),
        })
    }

    /// Resolves one connection parameter: a literal value wins over a
    /// Secret reference; `Ok(None)` when neither is configured.
    async fn resolve_param(
        &self,
        namespace: &str,
        literal: &Option<String>,
        secret: &Option<SecretKeySelector>,
    ) -> Result<Option<String>, ContextError> {
        if let Some(val) = literal {
            return Ok(Some(val.clone()));
        }
        if let Some(sel) = secret {
            return Ok(Some(
                self.fetch_secret_value(namespace, &sel.name, &sel.key)
                    .await?,
            ));
        }
        Ok(None)
    }

    /// Resolves the connection spec to a PostgreSQL URL, discarding the
    /// token-expiry metadata tracked internally.
    pub async fn resolve_connection_url(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<String, ContextError> {
        Ok(self
            .resolve_connection_url_with_metadata(namespace, connection)
            .await?
            .database_url)
    }

    /// Resolves the connection spec to a URL plus the expiry of any GCP
    /// token used as the password.
    ///
    /// Resolution order: a `secretRef` (whole URL stored in a Secret) wins;
    /// otherwise individual `params` are assembled into a URL; if neither
    /// is present a `SecretMissing` error is returned.
    async fn resolve_connection_url_with_metadata(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<ResolvedConnectionUrl, ContextError> {
        if let Some(ref secret_ref) = connection.secret_ref {
            // The Secret stores a complete URL; no token expiry to track.
            let database_url = self
                .fetch_secret_value(
                    namespace,
                    &secret_ref.name,
                    connection.effective_secret_key(),
                )
                .await?;
            Ok(ResolvedConnectionUrl {
                database_url,
                token_expires_at: None,
            })
        } else if let Some(ref params) = connection.params {
            // Each piece may come from a literal or a Secret; reject empty
            // or whitespace-only values early.
            let host = self
                .resolve_param(namespace, &params.host, &params.host_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "host".to_string(),
                })?;
            if host.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "host".to_string(),
                });
            }

            // Port falls back to PostgreSQL's default 5432.
            let port_str = params.port.map(|p| p.to_string());
            let port = self
                .resolve_param(namespace, &port_str, &params.port_secret)
                .await?
                .unwrap_or_else(|| "5432".to_string());
            if port.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "port".to_string(),
                });
            }

            let dbname = self
                .resolve_param(namespace, &params.dbname, &params.dbname_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "dbname".to_string(),
                })?;
            if dbname.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "dbname".to_string(),
                });
            }

            let username = self
                .resolve_param(namespace, &params.username, &params.username_secret)
                .await?
                .ok_or_else(|| ContextError::EmptyResolvedValue {
                    field: "username".to_string(),
                })?;
            if username.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "username".to_string(),
                });
            }

            // GCP auth replaces a static password with a short-lived access
            // token, whose expiry is propagated for cache invalidation.
            let (password, token_expires_at) = if let Some(auth) = &params.auth {
                let token = self.gcp_token_provider.fetch_token(auth).await?;
                (token.token, Some(token.expires_at))
            } else {
                let password = self
                    .resolve_param(namespace, &params.password, &params.password_secret)
                    .await?
                    .ok_or_else(|| ContextError::EmptyResolvedValue {
                        field: "password".to_string(),
                    })?;
                (password, None)
            };
            if password.trim().is_empty() {
                return Err(ContextError::EmptyResolvedValue {
                    field: "password".to_string(),
                });
            }

            // Credentials may contain URL-reserved characters; encode them.
            use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
            let encoded_username = utf8_percent_encode(&username, NON_ALPHANUMERIC).to_string();
            let encoded_password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string();

            let mut url = format!(
                "postgresql://{encoded_username}:{encoded_password}@{host}:{port}/{dbname}"
            );

            // GCP-auth connections default to sslmode=require when no
            // explicit sslMode was configured.
            let ssl_mode = self
                .resolve_param(namespace, &params.ssl_mode, &params.ssl_mode_secret)
                .await?
                .or_else(|| params.auth.as_ref().map(|_| "require".to_string()));
            if let Some(ssl_mode) = ssl_mode {
                if !crate::crd::VALID_SSL_MODES.contains(&ssl_mode.as_str()) {
                    return Err(ContextError::InvalidResolvedSslMode { value: ssl_mode });
                }
                url.push_str("?sslmode=");
                url.push_str(&ssl_mode);
            }

            Ok(ResolvedConnectionUrl {
                database_url: url,
                token_expires_at,
            })
        } else {
            Err(ContextError::SecretMissing {
                name: "connection".to_string(),
                key: "neither secretRef nor params is set".to_string(),
            })
        }
    }

    /// Returns a connection pool for `connection`, reusing a cached pool
    /// when the backing Secrets are unchanged and any embedded GCP token is
    /// still fresh; otherwise connects anew and replaces the cache entry.
    pub async fn get_or_create_pool(
        &self,
        namespace: &str,
        connection: &ConnectionSpec,
    ) -> Result<PgPool, ContextError> {
        let cache_key = connection.cache_key(namespace);

        // Gather the cache-validity inputs: either the single secretRef's
        // resourceVersion, or a fingerprint over every Secret the params
        // reference.
        let (resource_version, secret_fingerprint) =
            if let Some(ref secret_ref) = connection.secret_ref {
                let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                    kube::Api::namespaced(self.kube_client.clone(), namespace);
                let secret = secrets_api.get(&secret_ref.name).await.map_err(|err| {
                    ContextError::SecretFetch {
                        name: secret_ref.name.clone(),
                        namespace: namespace.to_string(),
                        source: err,
                    }
                })?;
                (secret.metadata.resource_version, None)
            } else if connection.params.is_some() {
                // BTreeSet gives a stable ordering so the fingerprint is
                // deterministic across calls.
                let mut secret_names = std::collections::BTreeSet::new();
                connection.collect_secret_names(&mut secret_names);

                if secret_names.is_empty() {
                    // All params are literals; an empty fingerprint still
                    // participates in the cache-validity comparison.
                    (None, Some(String::new()))
                } else {
                    let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
                        kube::Api::namespaced(self.kube_client.clone(), namespace);
                    let mut fingerprint_parts = Vec::new();
                    for name in &secret_names {
                        let secret = secrets_api.get(name).await.map_err(|err| {
                            ContextError::SecretFetch {
                                name: name.clone(),
                                namespace: namespace.to_string(),
                                source: err,
                            }
                        })?;
                        let rv = secret
                            .metadata
                            .resource_version
                            .unwrap_or_else(|| "unknown".to_string());
                        fingerprint_parts.push(format!("{name}={rv}"));
                    }
                    (None, Some(fingerprint_parts.join(",")))
                }
            } else {
                (None, None)
            };

        {
            // Read-lock only; a stale entry is simply left in place and
            // overwritten after reconnecting below.
            let cache = self.pool_cache.read().await;
            if let Some(cached) = cache.get(&cache_key) {
                // A missing resourceVersion on either side counts as a match.
                let version_matches = match (&resource_version, &cached.resource_version) {
                    (Some(current), Some(cached_rv)) => current == cached_rv,
                    _ => true,
                };
                // Fingerprints must match exactly (including both-absent).
                let fingerprint_matches = match (&secret_fingerprint, &cached.secret_fingerprint) {
                    (Some(current), Some(cached_fp)) => current == cached_fp,
                    (None, None) => true,
                    _ => false,
                };
                let token_fresh =
                    token_expires_after_skew(cached.token_expires_at, SystemTime::now());
                if version_matches && fingerprint_matches && token_fresh {
                    return Ok(cached.pool.clone());
                }
            }
        }

        let resolved = self
            .resolve_connection_url_with_metadata(namespace, connection)
            .await?;

        let pool = PgPoolOptions::new()
            .max_connections(POOL_MAX_CONNECTIONS)
            .acquire_timeout(Duration::from_secs(POOL_ACQUIRE_TIMEOUT_SECS))
            .connect(&resolved.database_url)
            .await
            .map_err(|err| ContextError::DatabaseConnect { source: err })?;

        {
            // NOTE(review): concurrent cache misses can each build a pool;
            // the last insert wins and the earlier entry is replaced.
            let mut cache = self.pool_cache.write().await;
            cache.insert(
                cache_key,
                CachedPool {
                    resource_version,
                    secret_fingerprint,
                    token_expires_at: resolved.token_expires_at,
                    pool: pool.clone(),
                },
            );
        }

        Ok(pool)
    }

    /// Fetches `secret_key` from the named Secret and decodes it as UTF-8.
    ///
    /// # Errors
    /// `SecretFetch` when the API call fails; `SecretMissing` when the
    /// Secret has no data, lacks the key, or the value is not valid UTF-8.
    pub async fn fetch_secret_value(
        &self,
        namespace: &str,
        secret_name: &str,
        secret_key: &str,
    ) -> Result<String, ContextError> {
        let secrets_api: kube::Api<k8s_openapi::api::core::v1::Secret> =
            kube::Api::namespaced(self.kube_client.clone(), namespace);

        let secret =
            secrets_api
                .get(secret_name)
                .await
                .map_err(|err| ContextError::SecretFetch {
                    name: secret_name.to_string(),
                    namespace: namespace.to_string(),
                    source: err,
                })?;

        let data = secret.data.ok_or_else(|| ContextError::SecretMissing {
            name: secret_name.to_string(),
            key: secret_key.to_string(),
        })?;

        let value_bytes = data
            .get(secret_key)
            .ok_or_else(|| ContextError::SecretMissing {
                name: secret_name.to_string(),
                key: secret_key.to_string(),
            })?;

        // Non-UTF-8 data is reported as a missing key rather than a
        // distinct error variant.
        String::from_utf8(value_bytes.0.clone()).map_err(|_| ContextError::SecretMissing {
            name: secret_name.to_string(),
            key: secret_key.to_string(),
        })
    }

    /// Drops the cached pool for this connection so the next
    /// `get_or_create_pool` call reconnects from scratch.
    pub async fn evict_pool(&self, namespace: &str, connection: &ConnectionSpec) {
        let cache_key = connection.cache_key(namespace);
        let mut cache = self.pool_cache.write().await;
        cache.remove(&cache_key);
    }
}
706
/// Errors produced while resolving connections, fetching Secrets, and
/// authenticating against GCP.
#[derive(Debug, thiserror::Error)]
pub enum ContextError {
    /// The Kubernetes API call to read a Secret failed.
    #[error("failed to fetch Secret {namespace}/{name}: {source}")]
    SecretFetch {
        name: String,
        namespace: String,
        source: kube::Error,
    },

    /// The Secret exists but lacks the requested key (also used for missing
    /// data and non-UTF-8 values).
    #[error("Secret \"{name}\" does not contain key \"{key}\"")]
    SecretMissing { name: String, key: String },

    /// Establishing the PostgreSQL connection pool failed.
    #[error("failed to connect to database: {source}")]
    DatabaseConnect { source: sqlx::Error },

    /// A connection parameter resolved to an empty/whitespace-only string.
    #[error("connection param \"{field}\" resolved to an empty or whitespace-only value")]
    EmptyResolvedValue { field: String },

    /// A resolved sslMode is not one of the values libpq accepts.
    #[error(
        "connection param sslMode resolved to invalid value \"{value}\" (expected one of: disable, allow, prefer, require, verify-ca, verify-full)"
    )]
    InvalidResolvedSslMode { value: String },

    /// Transport-level failure talking to a GCP auth endpoint.
    #[error("failed to fetch GCP auth token from {endpoint}: {source}")]
    GcpAuthHttp {
        endpoint: &'static str,
        source: reqwest::Error,
    },

    /// A GCP auth endpoint answered with a non-2xx status.
    #[error("GCP auth token endpoint {endpoint} returned HTTP {status}: {body}")]
    GcpAuthRejected {
        endpoint: String,
        status: u16,
        body: String,
    },

    /// A GCP auth endpoint answered 2xx but the body was unusable.
    #[error("GCP auth token response was invalid: {detail}")]
    GcpAuthInvalidResponse { detail: String },
}
747
748impl ContextError {
749 pub fn is_secret_fetch_non_transient(&self) -> bool {
751 matches!(
752 self,
753 ContextError::SecretFetch {
754 source: kube::Error::Api(response),
755 ..
756 } if (400..500).contains(&response.code) && response.code != 429
757 )
758 }
759
760 pub fn is_gcp_auth_non_transient(&self) -> bool {
761 matches!(
762 self,
763 ContextError::GcpAuthRejected { status, .. }
764 if (400..500).contains(status) && *status != 429
765 ) || matches!(self, ContextError::GcpAuthInvalidResponse { .. })
766 }
767}
768
/// Unit tests covering error-transience classification, token-expiry
/// helpers, error-body truncation, and the in-memory database lock.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pool_cache_key_format() {
        // Documents the expected "namespace/secret/key" cache-key shape.
        let key = format!("{}/{}/{}", "prod", "pg-credentials", "DATABASE_URL");
        assert_eq!(key, "prod/pg-credentials/DATABASE_URL");
    }

    #[test]
    fn secret_fetch_not_found_is_non_transient() {
        // 404 is a client error: retrying cannot make the Secret appear.
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("secrets \"db-credentials\" not found", "NotFound")
                    .with_code(404)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_forbidden_is_non_transient() {
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("forbidden", "Forbidden")
                    .with_code(403)
                    .boxed(),
            ),
        };

        assert!(error.is_secret_fetch_non_transient());
    }

    #[test]
    fn secret_fetch_server_error_remains_transient() {
        // 5xx responses may clear up on their own; keep retrying.
        let error = ContextError::SecretFetch {
            name: "db-credentials".into(),
            namespace: "default".into(),
            source: kube::Error::Api(
                kube::core::Status::failure("internal error", "InternalError")
                    .with_code(500)
                    .boxed(),
            ),
        };

        assert!(!error.is_secret_fetch_non_transient());
    }

    #[test]
    fn gcp_auth_client_error_is_non_transient() {
        let error = ContextError::GcpAuthRejected {
            endpoint: "metadata".into(),
            status: 403,
            body: "forbidden".into(),
        };

        assert!(error.is_gcp_auth_non_transient());
    }

    #[test]
    fn gcp_auth_rate_limit_remains_transient() {
        // 429 is the one 4xx that should still be retried.
        let error = ContextError::GcpAuthRejected {
            endpoint: "metadata".into(),
            status: 429,
            body: "rate limited".into(),
        };

        assert!(!error.is_gcp_auth_non_transient());
    }

    #[test]
    fn token_expiry_uses_five_minute_refresh_skew() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000);
        // One second past the skew window: still fresh.
        assert!(token_expires_after_skew(
            Some(now + Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS + 1)),
            now
        ));
        // Exactly at the window edge: considered stale (strict comparison).
        assert!(!token_expires_after_skew(
            Some(now + Duration::from_secs(GCP_TOKEN_CACHE_SKEW_SECS)),
            now
        ));
    }

    #[test]
    fn parse_google_expire_time_accepts_rfc3339() {
        let parsed =
            parse_google_expire_time("2026-05-14T02:30:00Z").expect("expireTime should parse");
        assert_eq!(
            parsed
                .duration_since(SystemTime::UNIX_EPOCH)
                .unwrap()
                .as_secs(),
            1_778_725_800
        );
    }

    #[test]
    fn truncate_for_error_keeps_utf8_boundary() {
        // 300 two-byte chars = 600 bytes; truncation must not split a char.
        let body = "é".repeat(300);
        let truncated = truncate_for_error(body);

        assert!(truncated.ends_with("..."));
        assert!(truncated.is_char_boundary(truncated.len() - 3));
    }

    #[tokio::test]
    async fn try_lock_database_acquires_when_free() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };
        let guard = ctx.try_lock("db-a").await;
        assert!(guard.is_some(), "should acquire lock on free database");
    }

    #[tokio::test]
    async fn try_lock_database_contention_returns_none() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let _guard1 = ctx
            .try_lock("db-a")
            .await
            .expect("first lock should succeed");
        let guard2 = ctx.try_lock("db-a").await;
        assert!(guard2.is_none(), "second lock on same database should fail");
    }

    #[tokio::test]
    async fn try_lock_database_different_databases_independent() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: locks,
        };

        let guard_a = ctx.try_lock("db-a").await;
        let guard_b = ctx.try_lock("db-b").await;
        assert!(guard_a.is_some(), "lock on db-a should succeed");
        assert!(
            guard_b.is_some(),
            "lock on db-b should succeed (different database)"
        );
    }

    #[tokio::test]
    async fn try_lock_database_released_after_drop() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let ctx = OperatorContextLockHelper {
            database_locks: Arc::clone(&locks),
        };

        {
            // Guard dropped at the end of this scope releases the lock.
            let _guard = ctx.try_lock("db-a").await.expect("should acquire");
        }

        let guard2 = ctx.try_lock("db-a").await;
        assert!(
            guard2.is_some(),
            "should re-acquire after previous guard dropped"
        );
    }

    #[tokio::test]
    async fn try_lock_database_concurrent_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));

        let locks1 = Arc::clone(&locks);
        let locks2 = Arc::clone(&locks);

        let handle1 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks1,
            };
            let guard = ctx.try_lock("shared-db").await;
            if guard.is_some() {
                // Hold the lock long enough for the second task to contend.
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
            guard.is_some()
        });

        // Give task 1 a head start so it acquires first.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let handle2 = tokio::spawn(async move {
            let ctx = OperatorContextLockHelper {
                database_locks: locks2,
            };
            let guard = ctx.try_lock("shared-db").await;
            guard.is_some()
        });

        let (r1, r2) = tokio::join!(handle1, handle2);
        let acquired1 = r1.unwrap();
        let acquired2 = r2.unwrap();

        assert!(
            acquired1 ^ acquired2,
            "exactly one of two concurrent locks should succeed: got ({acquired1}, {acquired2})"
        );
    }

    /// Minimal stand-in for `OperatorContext` exposing only the lock map,
    /// so lock tests avoid constructing kube clients and recorders.
    struct OperatorContextLockHelper {
        database_locks: Arc<Mutex<HashMap<String, ()>>>,
    }

    impl OperatorContextLockHelper {
        // Mirrors `OperatorContext::try_lock_database` without the logging.
        async fn try_lock(&self, database_identity: &str) -> Option<DatabaseLockGuard> {
            let mut locks = self.database_locks.lock().await;
            if locks.contains_key(database_identity) {
                return None;
            }
            locks.insert(database_identity.to_string(), ());
            Some(DatabaseLockGuard {
                key: database_identity.to_string(),
                locks: Arc::clone(&self.database_locks),
            })
        }
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_same_db() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        // Barrier maximizes the chance all tasks race at the same instant.
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                let guard = ctx.try_lock("contested-db").await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    // Hold briefly so the other tasks observe contention.
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, 1,
            "exactly one of {concurrency} concurrent tasks should acquire the lock, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_high_concurrency_different_dbs() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 50;
        let acquired_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for i in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&acquired_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                let ctx = OperatorContextLockHelper {
                    database_locks: locks_clone,
                };
                // Distinct key per task: no contention expected.
                let db_name = format!("db-{i}");
                let guard = ctx.try_lock(&db_name).await;
                if guard.is_some() {
                    count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(5)).await;
                }
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = acquired_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks locking different dbs should succeed, got {total}"
        );
    }

    #[tokio::test]
    async fn try_lock_database_acquire_release_cycle_under_contention() {
        let locks: Arc<Mutex<HashMap<String, ()>>> = Arc::new(Mutex::new(HashMap::new()));
        let concurrency = 20;
        let success_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let barrier = Arc::new(tokio::sync::Barrier::new(concurrency));

        let mut handles = Vec::with_capacity(concurrency);
        for _ in 0..concurrency {
            let locks_clone = Arc::clone(&locks);
            let count = Arc::clone(&success_count);
            let bar = Arc::clone(&barrier);
            handles.push(tokio::spawn(async move {
                bar.wait().await;
                // Retry loop: each task keeps attempting until it gets one
                // successful acquire/release cycle.
                for _ in 0..100 {
                    let ctx = OperatorContextLockHelper {
                        database_locks: Arc::clone(&locks_clone),
                    };
                    if let Some(_guard) = ctx.try_lock("shared-db").await {
                        count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                        return;
                    }
                    tokio::time::sleep(std::time::Duration::from_millis(1)).await;
                }
                panic!("task failed to acquire lock after 100 retries");
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let total = success_count.load(std::sync::atomic::Ordering::SeqCst);
        assert_eq!(
            total, concurrency,
            "all {concurrency} tasks should eventually acquire the lock"
        );
    }
}