1use std::sync::Arc;
16use std::time::{Duration, SystemTime, UNIX_EPOCH};
17
18use anyhow::{Context, Result};
19use async_trait::async_trait;
20use azure_core::credentials::TokenCredential;
21use azure_identity::{
22 DeveloperToolsCredential, ManagedIdentityCredential, WorkloadIdentityCredential,
23};
24
/// OAuth scope requested for Azure Database for PostgreSQL Entra tokens in the
/// public Azure cloud. Sovereign clouds use a different host and can be
/// selected through `EntraAuthOptions::audience`.
pub const DEFAULT_AUDIENCE: &str = "https://ossrdbms-aad.database.windows.net/.default";

/// Pool size used when neither the builder nor `DUROXIDE_PG_POOL_MAX` overrides it.
const DEFAULT_MAX_CONNECTIONS: u32 = 10;

/// Initial value of `EntraAuthOptions::acquire_timeout`.
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(30);

/// Initial value of `EntraAuthOptions::refresh_interval` (twenty minutes).
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 20);
50
51#[derive(Clone, Debug)]
69pub struct EntraAuthOptions {
70 audience: String,
71 max_connections: u32,
72 acquire_timeout: Duration,
73 refresh_interval: Duration,
74}
75
76impl Default for EntraAuthOptions {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl EntraAuthOptions {
83 pub fn new() -> Self {
85 let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
86 .ok()
87 .and_then(|s| s.parse::<u32>().ok())
88 .unwrap_or(DEFAULT_MAX_CONNECTIONS);
89 Self {
90 audience: DEFAULT_AUDIENCE.to_string(),
91 max_connections,
92 acquire_timeout: DEFAULT_ACQUIRE_TIMEOUT,
93 refresh_interval: DEFAULT_REFRESH_INTERVAL,
94 }
95 }
96
97 pub fn audience(mut self, audience: impl Into<String>) -> Self {
101 self.audience = audience.into();
102 self
103 }
104
105 pub fn max_connections(mut self, max: u32) -> Self {
111 self.max_connections = max.max(1);
112 self
113 }
114
115 pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
117 self.acquire_timeout = timeout;
118 self
119 }
120
121 pub fn refresh_interval(mut self, interval: Duration) -> Self {
125 self.refresh_interval = interval;
126 self
127 }
128
129 pub(crate) fn audience_str(&self) -> &str {
131 &self.audience
132 }
133
134 pub(crate) fn max_connections_value(&self) -> u32 {
136 self.max_connections
137 }
138
139 pub(crate) fn acquire_timeout_value(&self) -> Duration {
141 self.acquire_timeout
142 }
143
144 pub(crate) fn refresh_interval_value(&self) -> Duration {
146 self.refresh_interval
147 }
148
149 pub(crate) fn default_token_source(&self) -> Result<Arc<dyn TokenSource>> {
155 let credential =
156 build_default_chained_credential().context("Entra credential resolution failed")?;
157 Ok(Arc::new(AzureIdentityTokenSource::new(credential)))
158 }
159}
160
/// An Entra access token together with its absolute expiry time.
///
/// `Debug` is written by hand so the bearer secret never shows up in logs or
/// panic messages; only the expiry is printed.
#[derive(Clone)]
pub(crate) struct EntraToken {
    /// Raw access-token string as returned by the token source.
    pub(crate) secret: String,
    /// Wall-clock instant at which the token stops being valid.
    pub(crate) expires_at: SystemTime,
}

impl std::fmt::Debug for EntraToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        let mut out = f.debug_struct("EntraToken");
        out.field("secret", &REDACTED);
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}

impl EntraToken {
    /// Bundles a secret with its expiry; no validation is performed here.
    pub(crate) fn new(secret: String, expires_at: SystemTime) -> Self {
        EntraToken { expires_at, secret }
    }
}
187
/// Something that can mint Entra access tokens for a set of OAuth scopes.
///
/// Production code uses [`AzureIdentityTokenSource`]; tests substitute the
/// scripted `RecordingFakeTokenSource` from `test_support`.
#[async_trait]
pub(crate) trait TokenSource: Send + Sync {
    /// Fetches a token for `scopes`.
    ///
    /// # Errors
    /// Returns an error if no token could be obtained. The Azure-backed
    /// implementation also errors when the token expires too soon to be
    /// useful.
    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken>;
}
196
197pub(crate) struct AzureIdentityTokenSource {
200 credential: Arc<dyn TokenCredential>,
201}
202
203impl AzureIdentityTokenSource {
204 pub(crate) fn new(credential: Arc<dyn TokenCredential>) -> Self {
205 Self { credential }
206 }
207}
208
209#[async_trait]
210impl TokenSource for AzureIdentityTokenSource {
211 async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
212 let access = self
213 .credential
214 .get_token(scopes, None)
215 .await
216 .map_err(|e| anyhow::anyhow!("Entra token acquisition failed: {e}"))?;
217
218 let expires_at = offset_datetime_to_system_time(access.expires_on);
219 validate_token_freshness(
220 SystemTime::now(),
221 expires_at,
222 crate::provider::ENTRA_REFRESH_SAFETY_MARGIN,
223 )?;
224 Ok(EntraToken::new(
225 access.token.secret().to_string(),
226 expires_at,
227 ))
228 }
229}
230
231fn offset_datetime_to_system_time(t: azure_core::time::OffsetDateTime) -> SystemTime {
235 let seconds = t.unix_timestamp();
236 if seconds < 0 {
237 return UNIX_EPOCH;
238 }
239 UNIX_EPOCH
240 .checked_add(Duration::from_secs(seconds as u64))
241 .unwrap_or(UNIX_EPOCH)
242}
243
244pub(crate) fn validate_token_freshness(
253 now: SystemTime,
254 expires_at: SystemTime,
255 margin: Duration,
256) -> Result<()> {
257 let cutoff = now
258 .checked_add(margin)
259 .ok_or_else(|| anyhow::anyhow!("clock arithmetic overflow validating token freshness"))?;
260 if expires_at <= cutoff {
261 let secs_remaining = expires_at
262 .duration_since(now)
263 .map(|d| d.as_secs() as i64)
264 .unwrap_or_else(|e| -(e.duration().as_secs() as i64));
265 anyhow::bail!(
266 "Entra token rejected: expires_at is too close to now \
267 (remaining={}s, required margin={}s). Possible upstream SDK \
268 bug, clock skew on the credential issuer, or stale cached \
269 token.",
270 secs_remaining,
271 margin.as_secs(),
272 );
273 }
274 Ok(())
275}
276
277fn build_default_chained_credential() -> azure_core::Result<Arc<dyn TokenCredential>> {
292 let mut sources: Vec<(&'static str, Arc<dyn TokenCredential>)> = Vec::new();
293 if let Ok(workload) = WorkloadIdentityCredential::new(None) {
299 sources.push(("WorkloadIdentityCredential", workload));
300 }
301 sources.push((
302 "ManagedIdentityCredential",
303 ManagedIdentityCredential::new(None)?,
304 ));
305 sources.push((
306 "DeveloperToolsCredential",
307 DeveloperToolsCredential::new(None)?,
308 ));
309 Ok(Arc::new(ChainedCredential::new(sources)))
310}
311
312struct ChainedCredential {
325 sources: Vec<(&'static str, Arc<dyn TokenCredential>)>,
326 logged_first_success: std::sync::OnceLock<()>,
327}
328
329impl ChainedCredential {
330 fn new(sources: Vec<(&'static str, Arc<dyn TokenCredential>)>) -> Self {
331 Self {
332 sources,
333 logged_first_success: std::sync::OnceLock::new(),
334 }
335 }
336}
337
338impl std::fmt::Debug for ChainedCredential {
339 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340 f.write_str("ChainedCredential")
341 }
342}
343
344#[async_trait]
345impl TokenCredential for ChainedCredential {
346 async fn get_token(
347 &self,
348 scopes: &[&str],
349 options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
350 ) -> azure_core::Result<azure_core::credentials::AccessToken> {
351 let mut errors: Vec<String> = Vec::new();
352 for (name, source) in &self.sources {
353 match source.get_token(scopes, options.clone()).await {
354 Ok(token) => {
355 if self.logged_first_success.set(()).is_ok() {
356 tracing::info!(
357 target: "duroxide::providers::postgres",
358 credential = %name,
359 "Entra credential chain: token acquired (first success on this instance)",
360 );
361 }
362 return Ok(token);
363 }
364 Err(e) => errors.push(format!("{name}: {e}")),
365 }
366 }
367 Err(azure_core::Error::with_message_fn(
368 azure_core::error::ErrorKind::Credential,
369 move || {
370 format!(
371 "All chained Entra credentials failed to acquire a token:\n - {}",
372 errors.join("\n - ")
373 )
374 },
375 ))
376 }
377}
378
#[cfg(test)]
pub(crate) mod test_support {
    //! Test doubles for [`TokenSource`].

    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Scripted [`TokenSource`] that records every call it receives.
    ///
    /// It runs in one of two modes:
    /// - It hands out a fixed queue of tokens, one per call in order, and
    ///   fails once the queue is empty.
    /// - It fails every call with a fixed message.
    ///
    /// Every call, successful or not, bumps the call counter and records the
    /// requested scopes.
    pub(crate) struct RecordingFakeTokenSource {
        /// Tokens still to be returned; the next one is at the front.
        scripted: Mutex<VecDeque<EntraToken>>,
        /// Scopes passed to each call, in call order.
        recorded_scopes: Mutex<Vec<Vec<String>>>,
        /// Number of `fetch_token` calls so far.
        call_count: AtomicUsize,
        /// When set, every call fails with this message instead of using the script.
        fail_with: Option<String>,
    }

    impl RecordingFakeTokenSource {
        /// Shared constructor behind the two public flavours.
        fn build(tokens: Vec<EntraToken>, fail_with: Option<String>) -> Arc<Self> {
            Arc::new(Self {
                scripted: Mutex::new(VecDeque::from(tokens)),
                recorded_scopes: Mutex::new(Vec::new()),
                call_count: AtomicUsize::new(0),
                fail_with,
            })
        }

        /// Fake that returns `tokens` in order, one per call.
        pub(crate) fn with_tokens(tokens: Vec<EntraToken>) -> Arc<Self> {
            Self::build(tokens, None)
        }

        /// Fake whose every call fails with `message`.
        pub(crate) fn always_failing(message: impl Into<String>) -> Arc<Self> {
            Self::build(Vec::new(), Some(message.into()))
        }

        /// Number of `fetch_token` calls made so far.
        pub(crate) fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Snapshot of the scopes passed to each call, in call order.
        pub(crate) fn recorded_scopes(&self) -> Vec<Vec<String>> {
            self.recorded_scopes.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl TokenSource for RecordingFakeTokenSource {
        async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            self.recorded_scopes
                .lock()
                .unwrap()
                .push(scopes.iter().map(|s| s.to_string()).collect());
            if let Some(msg) = &self.fail_with {
                return Err(anyhow::anyhow!("{msg}"));
            }
            // O(1) front-pop; a Vec would shift every remaining token.
            self.scripted
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| anyhow::anyhow!("RecordingFakeTokenSource: script exhausted"))
        }
    }

    /// Builds a token named `secret` that expires `expires_in_secs` from now.
    pub(crate) fn token(secret: &str, expires_in_secs: u64) -> EntraToken {
        EntraToken::new(
            secret.to_string(),
            SystemTime::now() + Duration::from_secs(expires_in_secs),
        )
    }
}
452
#[cfg(test)]
mod tests {
    use super::test_support::*;
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn defaults_match_password_path() {
        let opts = EntraAuthOptions::new();
        assert_eq!(opts.audience_str(), DEFAULT_AUDIENCE);
        // The pool size can be overridden from the environment; only pin the
        // default when no override is present so the test is not env-dependent.
        if std::env::var_os("DUROXIDE_PG_POOL_MAX").is_none() {
            assert_eq!(opts.max_connections_value(), 10);
        }
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(30));
        assert_eq!(opts.refresh_interval_value(), DEFAULT_REFRESH_INTERVAL);
    }

    #[test]
    fn audience_override_round_trips() {
        let opts = EntraAuthOptions::new()
            .audience("https://ossrdbms-aad.database.usgovcloudapi.net/.default");
        assert_eq!(
            opts.audience_str(),
            "https://ossrdbms-aad.database.usgovcloudapi.net/.default"
        );
    }

    #[test]
    fn pool_tunables_round_trip() {
        let opts = EntraAuthOptions::new()
            .max_connections(5)
            .acquire_timeout(Duration::from_secs(45))
            .refresh_interval(Duration::from_secs(120));
        assert_eq!(opts.max_connections_value(), 5);
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(45));
        assert_eq!(opts.refresh_interval_value(), Duration::from_secs(120));
    }

    #[tokio::test]
    async fn fake_token_source_returns_scripted_tokens() {
        let source = RecordingFakeTokenSource::with_tokens(vec![
            token("first", 3600),
            token("second", 3600),
        ]);
        let t1 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        let t2 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        assert_eq!(t1.secret, "first");
        assert_eq!(t2.secret, "second");
        assert_eq!(source.call_count(), 2);
        assert_eq!(
            source.recorded_scopes(),
            vec![
                vec![DEFAULT_AUDIENCE.to_string()],
                vec![DEFAULT_AUDIENCE.to_string()]
            ]
        );
    }

    #[tokio::test]
    async fn fake_token_source_propagates_failures() {
        let source = RecordingFakeTokenSource::always_failing("simulated");
        let err = source
            .fetch_token(&[DEFAULT_AUDIENCE])
            .await
            .expect_err("must fail");
        assert!(err.to_string().contains("simulated"));
        assert_eq!(source.call_count(), 1);
    }

    #[test]
    fn offset_datetime_conversion_handles_negative() {
        let pre_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH - azure_core::time::Duration::seconds(60);
        assert_eq!(offset_datetime_to_system_time(pre_epoch), UNIX_EPOCH);

        let post_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH + azure_core::time::Duration::seconds(120);
        let converted = offset_datetime_to_system_time(post_epoch);
        assert_eq!(converted, UNIX_EPOCH + Duration::from_secs(120));
    }

    /// Credential stub that either returns a far-future token or fails with
    /// `"<label> failed"`, counting how often it was asked.
    #[derive(Debug)]
    struct StubCred {
        ok: bool,
        label: &'static str,
        calls: AtomicUsize,
    }

    impl StubCred {
        fn new(ok: bool, label: &'static str) -> Arc<Self> {
            Arc::new(Self {
                ok,
                label,
                calls: AtomicUsize::new(0),
            })
        }

        /// Number of `get_token` calls this stub has received.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    /// Upcasts a stub for insertion into a chain while the caller keeps the
    /// concrete handle for call-count assertions.
    fn as_credential(stub: &Arc<StubCred>) -> Arc<dyn TokenCredential> {
        stub.clone()
    }

    #[async_trait]
    impl TokenCredential for StubCred {
        async fn get_token(
            &self,
            _scopes: &[&str],
            _options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
        ) -> azure_core::Result<azure_core::credentials::AccessToken> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if self.ok {
                Ok(azure_core::credentials::AccessToken::new(
                    azure_core::credentials::Secret::new(format!("token-from-{}", self.label)),
                    azure_core::time::OffsetDateTime::UNIX_EPOCH
                        + azure_core::time::Duration::seconds(3_700_000_000),
                ))
            } else {
                Err(azure_core::Error::with_message(
                    azure_core::error::ErrorKind::Credential,
                    format!("{} failed", self.label),
                ))
            }
        }
    }

    #[tokio::test]
    async fn chained_credential_returns_first_success_in_chain_order() {
        let failing = StubCred::new(false, "Failing");
        let winner = StubCred::new(true, "Winner");
        let bystander = StubCred::new(true, "ShouldNotBeCalled");
        let chain = ChainedCredential::new(vec![
            ("Failing", as_credential(&failing)),
            ("Winner", as_credential(&winner)),
            ("ShouldNotBeCalled", as_credential(&bystander)),
        ]);
        let token = chain.get_token(&["aud"], None).await.unwrap();
        assert_eq!(token.token.secret(), "token-from-Winner");
        assert_eq!(failing.calls(), 1, "earlier failing source must be tried");
        assert_eq!(winner.calls(), 1);
        assert_eq!(bystander.calls(), 0, "chain must stop at the first success");
    }

    #[tokio::test]
    async fn chained_credential_aggregates_class_names_in_failure_message() {
        let workload = StubCred::new(false, "WorkloadIdentity");
        let managed = StubCred::new(false, "ManagedIdentity");
        let dev = StubCred::new(false, "DeveloperTools");
        let chain = ChainedCredential::new(vec![
            ("Workload", as_credential(&workload)),
            ("Managed", as_credential(&managed)),
            ("Dev", as_credential(&dev)),
        ]);
        let err = chain.get_token(&["aud"], None).await.expect_err("all fail");
        let msg = format!("{err}");
        // Chain names are followed by ": ", which distinguishes them from the
        // stub labels that merely contain the same substrings.
        assert!(msg.contains("Workload: "), "{msg}");
        assert!(msg.contains("Managed: "), "{msg}");
        assert!(msg.contains("Dev: "), "{msg}");
        // Each source's own error text must survive aggregation.
        assert!(msg.contains("WorkloadIdentity failed"), "{msg}");
        assert!(msg.contains("ManagedIdentity failed"), "{msg}");
        assert!(msg.contains("DeveloperTools failed"), "{msg}");
        assert_eq!((workload.calls(), managed.calls(), dev.calls()), (1, 1, 1));
    }

    #[tokio::test]
    async fn chained_credential_logs_first_success_only_once() {
        let winner = StubCred::new(true, "Winner");
        let chain = ChainedCredential::new(vec![("Winner", as_credential(&winner))]);
        assert!(
            chain.logged_first_success.get().is_none(),
            "should start unset"
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(
            chain.logged_first_success.get().is_some(),
            "OnceLock must be populated after first success",
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(chain.logged_first_success.get().is_some());
        assert_eq!(winner.calls(), 2);
    }

    #[test]
    fn validate_token_freshness_rejects_already_expired_token() {
        let now = SystemTime::now();
        let expires_at = now - Duration::from_secs(10);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(60))
            .expect_err("must reject");
        let msg = format!("{err}");
        assert!(msg.contains("too close to now"), "{msg}");
        assert!(msg.contains("clock skew"), "{msg}");
    }

    #[test]
    fn validate_token_freshness_rejects_token_within_safety_margin() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(60);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect_err("must reject");
        assert!(format!("{err}").contains("too close to now"));
    }

    #[test]
    fn validate_token_freshness_accepts_fresh_token() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(3600);
        validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect("must accept fresh token");
    }

    #[test]
    fn validate_token_freshness_rejects_at_exact_cutoff() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
        let margin = Duration::from_secs(60);
        let expires_at = now + margin;
        validate_token_freshness(now, expires_at, margin).expect_err("must reject at exact cutoff");
    }

    #[test]
    fn max_connections_zero_is_clamped_to_one() {
        let opts = EntraAuthOptions::new().max_connections(0);
        assert_eq!(opts.max_connections_value(), 1);
    }

    #[test]
    fn max_connections_one_is_preserved() {
        let opts = EntraAuthOptions::new().max_connections(1);
        assert_eq!(opts.max_connections_value(), 1);
    }
}