1use std::sync::Arc;
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16
17use anyhow::{Context, Result};
18use async_trait::async_trait;
19use azure_core::credentials::TokenCredential;
20use azure_identity::{
21 DeveloperToolsCredential, ManagedIdentityCredential, WorkloadIdentityCredential,
22};
23
/// Default OAuth scope requested for Azure Database for PostgreSQL tokens (Azure public
/// cloud). Override with `EntraAuthOptions::audience` for other clouds.
pub const DEFAULT_AUDIENCE: &str = "https://ossrdbms-aad.database.windows.net/.default";

/// Default pool size when neither the builder nor a valid `DUROXIDE_PG_POOL_MAX`
/// overrides it.
const DEFAULT_MAX_CONNECTIONS: u32 = 10;

/// Default pool acquire timeout.
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(30);

/// Default token refresh period: 1 200 s, i.e. 20 minutes.
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(1_200);
49
50#[derive(Clone, Debug)]
68pub struct EntraAuthOptions {
69 audience: String,
70 max_connections: u32,
71 acquire_timeout: Duration,
72 refresh_interval: Duration,
73}
74
75impl Default for EntraAuthOptions {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl EntraAuthOptions {
82 pub fn new() -> Self {
84 let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
85 .ok()
86 .and_then(|s| s.parse::<u32>().ok())
87 .unwrap_or(DEFAULT_MAX_CONNECTIONS);
88 Self {
89 audience: DEFAULT_AUDIENCE.to_string(),
90 max_connections,
91 acquire_timeout: DEFAULT_ACQUIRE_TIMEOUT,
92 refresh_interval: DEFAULT_REFRESH_INTERVAL,
93 }
94 }
95
96 pub fn audience(mut self, audience: impl Into<String>) -> Self {
100 self.audience = audience.into();
101 self
102 }
103
104 pub fn max_connections(mut self, max: u32) -> Self {
110 self.max_connections = max.max(1);
111 self
112 }
113
114 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
116 self.acquire_timeout = timeout;
117 self
118 }
119
120 pub fn refresh_interval(mut self, interval: Duration) -> Self {
124 self.refresh_interval = interval;
125 self
126 }
127
128 pub(crate) fn audience_str(&self) -> &str {
130 &self.audience
131 }
132
133 pub(crate) fn max_connections_value(&self) -> u32 {
135 self.max_connections
136 }
137
138 pub(crate) fn acquire_timeout_value(&self) -> Duration {
140 self.acquire_timeout
141 }
142
143 pub(crate) fn refresh_interval_value(&self) -> Duration {
145 self.refresh_interval
146 }
147
148 pub(crate) fn default_token_source(&self) -> Result<Arc<dyn TokenSource>> {
154 let credential =
155 build_default_chained_credential().context("Entra credential resolution failed")?;
156 Ok(Arc::new(AzureIdentityTokenSource::new(credential)))
157 }
158}
159
/// An Entra access token together with the instant it stops being valid.
///
/// `Debug` is written by hand so that the bearer secret never ends up in logs.
#[derive(Clone)]
pub(crate) struct EntraToken {
    /// Raw bearer token. Treat it as a credential and never log it.
    pub(crate) secret: String,
    /// Instant after which the token must not be used.
    pub(crate) expires_at: SystemTime,
}

impl std::fmt::Debug for EntraToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        let mut out = f.debug_struct("EntraToken");
        out.field("secret", &REDACTED);
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}

impl EntraToken {
    /// Pairs a token secret with its expiry instant.
    pub(crate) fn new(secret: String, expires_at: SystemTime) -> Self {
        EntraToken { secret, expires_at }
    }
}
186
/// Something that can hand out Entra access tokens.
///
/// Production code uses [`AzureIdentityTokenSource`]. Tests substitute the scripted
/// `test_support::RecordingFakeTokenSource`, so no Azure credential is contacted.
#[async_trait]
pub(crate) trait TokenSource: Send + Sync {
    /// Fetches a token for `scopes`.
    ///
    /// # Errors
    /// Returns an error when no token can be obtained. The Azure-backed implementation
    /// also rejects tokens that expire too soon (see [`validate_token_freshness`]).
    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken>;
}
195
196pub(crate) struct AzureIdentityTokenSource {
199 credential: Arc<dyn TokenCredential>,
200}
201
202impl AzureIdentityTokenSource {
203 pub(crate) fn new(credential: Arc<dyn TokenCredential>) -> Self {
204 Self { credential }
205 }
206}
207
208#[async_trait]
209impl TokenSource for AzureIdentityTokenSource {
210 async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
211 let access = self
212 .credential
213 .get_token(scopes, None)
214 .await
215 .map_err(|e| anyhow::anyhow!("Entra token acquisition failed: {e}"))?;
216
217 let expires_at = offset_datetime_to_system_time(access.expires_on);
218 validate_token_freshness(
219 SystemTime::now(),
220 expires_at,
221 crate::provider::ENTRA_REFRESH_SAFETY_MARGIN,
222 )?;
223 Ok(EntraToken::new(
224 access.token.secret().to_string(),
225 expires_at,
226 ))
227 }
228}
229
230fn offset_datetime_to_system_time(t: azure_core::time::OffsetDateTime) -> SystemTime {
234 let seconds = t.unix_timestamp();
235 if seconds < 0 {
236 return UNIX_EPOCH;
237 }
238 UNIX_EPOCH
239 .checked_add(Duration::from_secs(seconds as u64))
240 .unwrap_or(UNIX_EPOCH)
241}
242
243pub(crate) fn validate_token_freshness(
252 now: SystemTime,
253 expires_at: SystemTime,
254 margin: Duration,
255) -> Result<()> {
256 let cutoff = now
257 .checked_add(margin)
258 .ok_or_else(|| anyhow::anyhow!("clock arithmetic overflow validating token freshness"))?;
259 if expires_at <= cutoff {
260 let secs_remaining = expires_at
261 .duration_since(now)
262 .map(|d| d.as_secs() as i64)
263 .unwrap_or_else(|e| -(e.duration().as_secs() as i64));
264 anyhow::bail!(
265 "Entra token rejected: expires_at is too close to now \
266 (remaining={}s, required margin={}s). Possible upstream SDK \
267 bug, clock skew on the credential issuer, or stale cached \
268 token.",
269 secs_remaining,
270 margin.as_secs(),
271 );
272 }
273 Ok(())
274}
275
276fn build_default_chained_credential() -> azure_core::Result<Arc<dyn TokenCredential>> {
291 let mut sources: Vec<(&'static str, Arc<dyn TokenCredential>)> = Vec::new();
292 if let Ok(workload) = WorkloadIdentityCredential::new(None) {
298 sources.push(("WorkloadIdentityCredential", workload));
299 }
300 sources.push((
301 "ManagedIdentityCredential",
302 ManagedIdentityCredential::new(None)?,
303 ));
304 sources.push((
305 "DeveloperToolsCredential",
306 DeveloperToolsCredential::new(None)?,
307 ));
308 Ok(Arc::new(ChainedCredential::new(sources)))
309}
310
311struct ChainedCredential {
324 sources: Vec<(&'static str, Arc<dyn TokenCredential>)>,
325 logged_first_success: std::sync::OnceLock<()>,
326}
327
328impl ChainedCredential {
329 fn new(sources: Vec<(&'static str, Arc<dyn TokenCredential>)>) -> Self {
330 Self {
331 sources,
332 logged_first_success: std::sync::OnceLock::new(),
333 }
334 }
335}
336
337impl std::fmt::Debug for ChainedCredential {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 f.write_str("ChainedCredential")
340 }
341}
342
343#[async_trait]
344impl TokenCredential for ChainedCredential {
345 async fn get_token(
346 &self,
347 scopes: &[&str],
348 options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
349 ) -> azure_core::Result<azure_core::credentials::AccessToken> {
350 let mut errors: Vec<String> = Vec::new();
351 for (name, source) in &self.sources {
352 match source.get_token(scopes, options.clone()).await {
353 Ok(token) => {
354 if self.logged_first_success.set(()).is_ok() {
355 tracing::info!(
356 target: "duroxide::providers::postgres",
357 credential = %name,
358 "Entra credential chain: token acquired (first success on this instance)",
359 );
360 }
361 return Ok(token);
362 }
363 Err(e) => errors.push(format!("{name}: {e}")),
364 }
365 }
366 Err(azure_core::Error::with_message_fn(
367 azure_core::error::ErrorKind::Credential,
368 move || {
369 format!(
370 "All chained Entra credentials failed to acquire a token:\n - {}",
371 errors.join("\n - ")
372 )
373 },
374 ))
375 }
376}
377
#[cfg(test)]
pub(crate) mod test_support {
    //! Test doubles for the `TokenSource` abstraction.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Fake `TokenSource` that records every call and replays a fixed script.
    ///
    /// It is built either from a list of tokens or as an always-failing source. A
    /// token-list source hands tokens out in order, then errors once the list is
    /// exhausted.
    pub(crate) struct RecordingFakeTokenSource {
        /// Remaining tokens, stored back-to-front so `pop` yields the next one.
        scripted: Mutex<Vec<EntraToken>>,
        /// Scopes passed to each `fetch_token` call, in call order.
        recorded_scopes: Mutex<Vec<Vec<String>>>,
        /// Number of `fetch_token` calls so far.
        call_count: AtomicUsize,
        /// When set, every call fails with this message.
        fail_with: Option<String>,
    }

    impl RecordingFakeTokenSource {
        /// Shared constructor behind the public factory functions.
        fn build(mut script: Vec<EntraToken>, fail_with: Option<String>) -> Arc<Self> {
            // Reverse once so each call can take the next token from the end.
            script.reverse();
            Arc::new(RecordingFakeTokenSource {
                scripted: Mutex::new(script),
                recorded_scopes: Mutex::new(Vec::new()),
                call_count: AtomicUsize::new(0),
                fail_with,
            })
        }

        /// Source that returns `tokens` in order, one per call.
        pub(crate) fn with_tokens(tokens: Vec<EntraToken>) -> Arc<Self> {
            Self::build(tokens, None)
        }

        /// Source whose every call fails with `message`.
        pub(crate) fn always_failing(message: impl Into<String>) -> Arc<Self> {
            Self::build(Vec::new(), Some(message.into()))
        }

        /// Number of `fetch_token` calls made so far.
        pub(crate) fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Snapshot of the scopes passed to each call, in call order.
        pub(crate) fn recorded_scopes(&self) -> Vec<Vec<String>> {
            let guard = self.recorded_scopes.lock().unwrap();
            guard.clone()
        }
    }

    #[async_trait]
    impl TokenSource for RecordingFakeTokenSource {
        async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let requested: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
            self.recorded_scopes.lock().unwrap().push(requested);

            if let Some(msg) = self.fail_with.as_ref() {
                return Err(anyhow::anyhow!("{msg}"));
            }

            match self.scripted.lock().unwrap().pop() {
                Some(next) => Ok(next),
                None => Err(anyhow::anyhow!(
                    "RecordingFakeTokenSource: script exhausted"
                )),
            }
        }
    }

    /// Builds a token holding `secret` that expires `expires_in_secs` seconds from now.
    pub(crate) fn token(secret: &str, expires_in_secs: u64) -> EntraToken {
        let expires_at = SystemTime::now() + Duration::from_secs(expires_in_secs);
        EntraToken::new(secret.to_owned(), expires_at)
    }
}
451
#[cfg(test)]
mod tests {
    use super::test_support::*;
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn defaults_match_password_path() {
        let opts = EntraAuthOptions::new();
        assert_eq!(opts.audience_str(), DEFAULT_AUDIENCE);
        // `new()` honours DUROXIDE_PG_POOL_MAX. Only pin the built-in pool size when the
        // variable is absent, so the test does not fail in environments that set it.
        if std::env::var_os("DUROXIDE_PG_POOL_MAX").is_none() {
            assert_eq!(opts.max_connections_value(), 10);
        }
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(30));
        assert_eq!(opts.refresh_interval_value(), DEFAULT_REFRESH_INTERVAL);
    }

    #[test]
    fn audience_override_round_trips() {
        let opts = EntraAuthOptions::new()
            .audience("https://ossrdbms-aad.database.usgovcloudapi.net/.default");
        assert_eq!(
            opts.audience_str(),
            "https://ossrdbms-aad.database.usgovcloudapi.net/.default"
        );
    }

    #[test]
    fn pool_tunables_round_trip() {
        let opts = EntraAuthOptions::new()
            .max_connections(5)
            .acquire_timeout(Duration::from_secs(45))
            .refresh_interval(Duration::from_secs(120));
        assert_eq!(opts.max_connections_value(), 5);
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(45));
        assert_eq!(opts.refresh_interval_value(), Duration::from_secs(120));
    }

    #[test]
    fn entra_token_debug_redacts_secret() {
        let tok = EntraToken::new("super-secret".to_string(), UNIX_EPOCH);
        let rendered = format!("{tok:?}");
        assert!(!rendered.contains("super-secret"), "{rendered}");
        assert!(rendered.contains("<redacted>"), "{rendered}");
    }

    #[tokio::test]
    async fn fake_token_source_returns_scripted_tokens() {
        let source = RecordingFakeTokenSource::with_tokens(vec![
            token("first", 3600),
            token("second", 3600),
        ]);
        let t1 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        let t2 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        assert_eq!(t1.secret, "first");
        assert_eq!(t2.secret, "second");
        assert_eq!(source.call_count(), 2);
        assert_eq!(
            source.recorded_scopes(),
            vec![
                vec![DEFAULT_AUDIENCE.to_string()],
                vec![DEFAULT_AUDIENCE.to_string()]
            ]
        );
    }

    #[tokio::test]
    async fn fake_token_source_propagates_failures() {
        let source = RecordingFakeTokenSource::always_failing("simulated");
        let err = source
            .fetch_token(&[DEFAULT_AUDIENCE])
            .await
            .expect_err("must fail");
        assert!(err.to_string().contains("simulated"));
        assert_eq!(source.call_count(), 1);
    }

    #[test]
    fn offset_datetime_conversion_handles_negative() {
        let pre_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH - azure_core::time::Duration::seconds(60);
        assert_eq!(offset_datetime_to_system_time(pre_epoch), UNIX_EPOCH);

        let post_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH + azure_core::time::Duration::seconds(120);
        let converted = offset_datetime_to_system_time(post_epoch);
        assert_eq!(converted, UNIX_EPOCH + Duration::from_secs(120));
    }

    /// Stub credential that always succeeds or always fails, counting its calls.
    #[derive(Debug)]
    struct StubCred {
        ok: bool,
        label: &'static str,
        calls: AtomicUsize,
    }

    impl StubCred {
        fn new(ok: bool, label: &'static str) -> Arc<Self> {
            Arc::new(Self {
                ok,
                label,
                calls: AtomicUsize::new(0),
            })
        }

        /// Number of `get_token` calls this stub has received.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    /// Type-erases a stub so it can be placed in a `ChainedCredential` source list while
    /// the test keeps its own handle for call-count assertions.
    fn erased(stub: &Arc<StubCred>) -> Arc<dyn TokenCredential> {
        stub.clone()
    }

    #[async_trait]
    impl TokenCredential for StubCred {
        async fn get_token(
            &self,
            _scopes: &[&str],
            _options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
        ) -> azure_core::Result<azure_core::credentials::AccessToken> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if self.ok {
                Ok(azure_core::credentials::AccessToken::new(
                    azure_core::credentials::Secret::new(format!("token-from-{}", self.label)),
                    azure_core::time::OffsetDateTime::UNIX_EPOCH
                        + azure_core::time::Duration::seconds(3_700_000_000),
                ))
            } else {
                Err(azure_core::Error::with_message(
                    azure_core::error::ErrorKind::Credential,
                    format!("{} failed", self.label),
                ))
            }
        }
    }

    #[tokio::test]
    async fn chained_credential_returns_first_success_in_chain_order() {
        let failing = StubCred::new(false, "Failing");
        let winner = StubCred::new(true, "Winner");
        let skipped = StubCred::new(true, "ShouldNotBeCalled");
        let chain = ChainedCredential::new(vec![
            ("Failing", erased(&failing)),
            ("Winner", erased(&winner)),
            ("ShouldNotBeCalled", erased(&skipped)),
        ]);
        let token = chain.get_token(&["aud"], None).await.unwrap();
        assert_eq!(token.token.secret(), "token-from-Winner");
        assert_eq!(failing.calls(), 1);
        assert_eq!(winner.calls(), 1);
        assert_eq!(skipped.calls(), 0, "chain must stop at the first success");
    }

    #[tokio::test]
    async fn chained_credential_aggregates_class_names_in_failure_message() {
        let workload = StubCred::new(false, "WorkloadIdentity");
        let managed = StubCred::new(false, "ManagedIdentity");
        let dev = StubCred::new(false, "DeveloperTools");
        let chain = ChainedCredential::new(vec![
            ("Workload", erased(&workload)),
            ("Managed", erased(&managed)),
            ("Dev", erased(&dev)),
        ]);
        let err = chain.get_token(&["aud"], None).await.expect_err("all fail");
        let msg = format!("{err}");
        assert!(msg.contains("Workload"), "{msg}");
        assert!(msg.contains("Managed"), "{msg}");
        assert!(msg.contains("Dev"), "{msg}");
        // Every credential is attempted exactly once before giving up.
        assert_eq!(workload.calls(), 1);
        assert_eq!(managed.calls(), 1);
        assert_eq!(dev.calls(), 1);
    }

    #[tokio::test]
    async fn chained_credential_with_no_sources_fails() {
        let chain = ChainedCredential::new(Vec::new());
        let err = chain.get_token(&["aud"], None).await.expect_err("empty chain");
        let msg = format!("{err}");
        assert!(
            msg.contains("All chained Entra credentials failed"),
            "{msg}"
        );
    }

    // The "first success" info log is guarded by a OnceLock; verify the guard is set by
    // the first successful call and stays set (so the log cannot fire again).
    #[tokio::test]
    async fn chained_credential_logs_first_success_only_once() {
        let winner = StubCred::new(true, "Winner");
        let chain = ChainedCredential::new(vec![("Winner", erased(&winner))]);
        assert!(
            chain.logged_first_success.get().is_none(),
            "should start unset"
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(
            chain.logged_first_success.get().is_some(),
            "OnceLock must be populated after first success",
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(chain.logged_first_success.get().is_some());
        assert_eq!(winner.calls(), 2);
    }

    #[test]
    fn validate_token_freshness_rejects_already_expired_token() {
        let now = SystemTime::now();
        let expires_at = now - Duration::from_secs(10);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(60))
            .expect_err("must reject");
        let msg = format!("{err}");
        assert!(msg.contains("too close to now"), "{msg}");
        assert!(msg.contains("clock skew"), "{msg}");
    }

    #[test]
    fn validate_token_freshness_rejects_token_within_safety_margin() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(60);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect_err("must reject");
        assert!(format!("{err}").contains("too close to now"));
    }

    #[test]
    fn validate_token_freshness_accepts_fresh_token() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(3600);
        validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect("must accept fresh token");
    }

    #[test]
    fn validate_token_freshness_rejects_at_exact_cutoff() {
        // Fixed instant so `now + margin` is exact and the boundary is deterministic.
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
        let margin = Duration::from_secs(60);
        let expires_at = now + margin;
        validate_token_freshness(now, expires_at, margin).expect_err("must reject at exact cutoff");
    }

    #[test]
    fn max_connections_zero_is_clamped_to_one() {
        let opts = EntraAuthOptions::new().max_connections(0);
        assert_eq!(opts.max_connections_value(), 1);
    }

    #[test]
    fn max_connections_one_is_preserved() {
        let opts = EntraAuthOptions::new().max_connections(1);
        assert_eq!(opts.max_connections_value(), 1);
    }
}