Skip to main content

duroxide_pg/
entra.rs

1//! Microsoft Entra ID (formerly Azure Active Directory) authentication support
2//! for [`PostgresProvider`](crate::PostgresProvider).
3//!
//! This module exposes [`EntraAuthOptions`] — the configuration type passed to
//! `PostgresProvider::new_with_entra` and `PostgresProvider::new_with_schema_and_entra`
//! — plus the internal credential abstractions used to fetch and rotate Entra
//! access tokens.
8//!
//! Azure SDK types (`azure_core::credentials::TokenCredential`,
//! `azure_identity::ManagedIdentityCredential`, etc.) are intentionally **not
//! re-exported**. The public surface is limited to [`EntraAuthOptions`] and the
//! [`DEFAULT_AUDIENCE`] constant; everything else stays internal so we can adapt
//! to upstream churn without a breaking change.
14
15use std::sync::Arc;
16use std::time::{Duration, SystemTime, UNIX_EPOCH};
17
18use anyhow::{Context, Result};
19use async_trait::async_trait;
20use azure_core::credentials::TokenCredential;
21use azure_identity::{
22    DeveloperToolsCredential, ManagedIdentityCredential, WorkloadIdentityCredential,
23};
24
/// Token audience (scope) requested from Entra ID for Azure Database for
/// PostgreSQL Flexible Server in the Azure public cloud.
///
/// Sovereign clouds (Azure US Government, Azure China, etc.) need a different
/// audience; set it with [`EntraAuthOptions::audience`].
pub const DEFAULT_AUDIENCE: &str = "https://ossrdbms-aad.database.windows.net/.default";

/// Pool `max_connections` used when `DUROXIDE_PG_POOL_MAX` is unset or does
/// not parse as a `u32`. Kept equal to the password-path default in
/// [`PostgresProvider::new_with_schema`](crate::PostgresProvider::new_with_schema).
const DEFAULT_MAX_CONNECTIONS: u32 = 10;

/// Pool connection-acquire timeout; the same value the password path uses.
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(30);

/// Upper bound on the time between background token refreshes.
///
/// This is a *ceiling*, not a fixed period: the refresh task in `provider.rs`
/// schedules refreshes from each token's `expires_at` minus a safety margin.
///
/// Entra access tokens for Azure Database for PostgreSQL typically live for
/// about an hour, so a twenty-minute ceiling keeps the connect options fresh
/// well before expiry while keeping background traffic low.
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(1_200); // 20 minutes
50
/// Configuration for connecting to Azure Database for PostgreSQL using
/// Microsoft Entra ID authentication.
///
/// Construct with [`EntraAuthOptions::new`] (sensible Azure-public-cloud
/// defaults) and customize via the chainable mutators.
///
/// Options are plain data, so they can be compared with `==` (useful for
/// asserting configuration in tests or detecting configuration changes).
///
/// # Example
///
/// ```rust,no_run
/// use duroxide_pg::entra::EntraAuthOptions;
/// use std::time::Duration;
///
/// let opts = EntraAuthOptions::new()
///     .max_connections(20)
///     .acquire_timeout(Duration::from_secs(45));
/// # let _ = opts;
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EntraAuthOptions {
    /// Token audience/scope requested from Entra ID.
    audience: String,
    /// Maximum pool size. The [`max_connections`](Self::max_connections)
    /// mutator never stores `0`.
    max_connections: u32,
    /// How long the pool waits for a free connection before failing.
    acquire_timeout: Duration,
    /// Ceiling on the time between background token refreshes.
    refresh_interval: Duration,
}
75
76impl Default for EntraAuthOptions {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl EntraAuthOptions {
83    /// Create options with Azure public cloud defaults.
84    pub fn new() -> Self {
85        let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
86            .ok()
87            .and_then(|s| s.parse::<u32>().ok())
88            .unwrap_or(DEFAULT_MAX_CONNECTIONS);
89        Self {
90            audience: DEFAULT_AUDIENCE.to_string(),
91            max_connections,
92            acquire_timeout: DEFAULT_ACQUIRE_TIMEOUT,
93            refresh_interval: DEFAULT_REFRESH_INTERVAL,
94        }
95    }
96
97    /// Override the token audience/scope. Required for sovereign clouds
98    /// (e.g., Azure US Government uses
99    /// `https://ossrdbms-aad.database.usgovcloudapi.net/.default`).
100    pub fn audience(mut self, audience: impl Into<String>) -> Self {
101        self.audience = audience.into();
102        self
103    }
104
105    /// Override the pool's maximum connection count.
106    ///
107    /// A value of `0` is silently clamped to `1` because the pool's
108    /// hardcoded `min_connections(1)` would otherwise reject the
109    /// configuration at runtime (SF-G).
110    pub fn max_connections(mut self, max: u32) -> Self {
111        self.max_connections = max.max(1);
112        self
113    }
114
115    /// Override the pool's connection acquisition timeout.
116    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
117        self.acquire_timeout = timeout;
118        self
119    }
120
121    /// Override the upper bound on time between refresh attempts. The
122    /// background refresh task may refresh sooner than this interval if a
123    /// token's `expires_at` is approaching.
124    pub fn refresh_interval(mut self, interval: Duration) -> Self {
125        self.refresh_interval = interval;
126        self
127    }
128
129    /// Internal: read accessor for the audience.
130    pub(crate) fn audience_str(&self) -> &str {
131        &self.audience
132    }
133
134    /// Internal: read accessor for max pool connections.
135    pub(crate) fn max_connections_value(&self) -> u32 {
136        self.max_connections
137    }
138
139    /// Internal: read accessor for acquire timeout.
140    pub(crate) fn acquire_timeout_value(&self) -> Duration {
141        self.acquire_timeout
142    }
143
144    /// Internal: read accessor for refresh-interval ceiling.
145    pub(crate) fn refresh_interval_value(&self) -> Duration {
146        self.refresh_interval
147    }
148
149    /// Construct the default [`TokenSource`] (the [`AzureIdentityTokenSource`]
150    /// wrapping the chained credential).
151    ///
152    /// Surfaces a descriptive error if the underlying Azure SDK fails to build
153    /// any credential in the chain.
154    pub(crate) fn default_token_source(&self) -> Result<Arc<dyn TokenSource>> {
155        let credential =
156            build_default_chained_credential().context("Entra credential resolution failed")?;
157        Ok(Arc::new(AzureIdentityTokenSource::new(credential)))
158    }
159}
160
/// An Entra access token together with the wall-clock instant it expires.
///
/// `Debug` is implemented by hand so the bearer secret is never printed.
/// The type is `pub(crate)` and never leaves the crate. Without the custom
/// impl, a stray `{:?}` in a log line or a panic message inside the crate
/// could still leak the token.
#[derive(Clone)]
pub(crate) struct EntraToken {
    pub(crate) secret: String,
    pub(crate) expires_at: SystemTime,
}

impl std::fmt::Debug for EntraToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The expiry is shown as-is; the secret is replaced by a placeholder.
        let mut out = f.debug_struct("EntraToken");
        out.field("secret", &"<redacted>");
        out.field("expires_at", &self.expires_at);
        out.finish()
    }
}

impl EntraToken {
    /// Bundle a raw token `secret` with its `expires_at` instant.
    pub(crate) fn new(secret: String, expires_at: SystemTime) -> Self {
        EntraToken { secret, expires_at }
    }
}
187
188/// Internal seam for fetching Entra tokens. Kept `pub(crate)` so the public
189/// surface of `duroxide-pg` does not bind callers to a specific `azure_*`
190/// version (mitigation for spec Risk #2). Tests in this crate inject fake
191/// implementations via this trait.
192#[async_trait]
193pub(crate) trait TokenSource: Send + Sync {
194    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken>;
195}
196
197/// Production [`TokenSource`] that delegates to an
198/// [`azure_core::credentials::TokenCredential`].
199pub(crate) struct AzureIdentityTokenSource {
200    credential: Arc<dyn TokenCredential>,
201}
202
203impl AzureIdentityTokenSource {
204    pub(crate) fn new(credential: Arc<dyn TokenCredential>) -> Self {
205        Self { credential }
206    }
207}
208
209#[async_trait]
210impl TokenSource for AzureIdentityTokenSource {
211    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
212        let access = self
213            .credential
214            .get_token(scopes, None)
215            .await
216            .map_err(|e| anyhow::anyhow!("Entra token acquisition failed: {e}"))?;
217
218        let expires_at = offset_datetime_to_system_time(access.expires_on);
219        validate_token_freshness(
220            SystemTime::now(),
221            expires_at,
222            crate::provider::ENTRA_REFRESH_SAFETY_MARGIN,
223        )?;
224        Ok(EntraToken::new(
225            access.token.secret().to_string(),
226            expires_at,
227        ))
228    }
229}
230
231/// Convert an `azure_core::time::OffsetDateTime` (from `time` crate) into a
232/// `SystemTime`. Treats pre-epoch values as `UNIX_EPOCH` (effectively immediate
233/// expiry).
234fn offset_datetime_to_system_time(t: azure_core::time::OffsetDateTime) -> SystemTime {
235    let seconds = t.unix_timestamp();
236    if seconds < 0 {
237        return UNIX_EPOCH;
238    }
239    UNIX_EPOCH
240        .checked_add(Duration::from_secs(seconds as u64))
241        .unwrap_or(UNIX_EPOCH)
242}
243
244/// Pure helper: validate that a freshly-issued token's `expires_at` leaves
245/// at least `margin` of useful lifetime relative to `now`. Defends against
246/// upstream returning an already-expired (or near-expired) token, which
247/// would otherwise produce an immediate connection-storm of 28xxx errors
248/// at the pool boundary (SF-D).
249///
250/// Returns `Ok(())` if the token is fresh enough, or an error explaining
251/// the skew/staleness on rejection.
252pub(crate) fn validate_token_freshness(
253    now: SystemTime,
254    expires_at: SystemTime,
255    margin: Duration,
256) -> Result<()> {
257    let cutoff = now
258        .checked_add(margin)
259        .ok_or_else(|| anyhow::anyhow!("clock arithmetic overflow validating token freshness"))?;
260    if expires_at <= cutoff {
261        let secs_remaining = expires_at
262            .duration_since(now)
263            .map(|d| d.as_secs() as i64)
264            .unwrap_or_else(|e| -(e.duration().as_secs() as i64));
265        anyhow::bail!(
266            "Entra token rejected: expires_at is too close to now \
267             (remaining={}s, required margin={}s). Possible upstream SDK \
268             bug, clock skew on the credential issuer, or stale cached \
269             token.",
270            secs_remaining,
271            margin.as_secs(),
272        );
273    }
274    Ok(())
275}
276
277/// Build the default chained credential used when the caller does not provide
278/// an explicit token source. The chain mimics the spirit of upstream
279/// `DefaultAzureCredential` (which is not present in `azure_identity = 0.35`):
280///
281/// 1. `WorkloadIdentityCredential` — federated tokens for AKS Workload
282///    Identity, GitHub OIDC, etc. (only succeeds when the corresponding env
283///    vars are set).
284/// 2. `ManagedIdentityCredential` — IMDS for Azure VMs, App Service,
285///    Container Apps, Container Instances, Functions.
286/// 3. `DeveloperToolsCredential` — Azure CLI (`az login`) and Azure Developer
287///    CLI (`azd auth login`) for local development.
288///
289/// Returns an `azure_core::Error` from any failed inner constructor; the
290/// caller wraps it into an `anyhow::Error`.
291fn build_default_chained_credential() -> azure_core::Result<Arc<dyn TokenCredential>> {
292    let mut sources: Vec<(&'static str, Arc<dyn TokenCredential>)> = Vec::new();
293    // WorkloadIdentityCredential::new only succeeds when AZURE_FEDERATED_TOKEN_FILE
294    // (and friends) are set. Skip silently otherwise so the chain still works
295    // on a developer laptop. Note: leaving these env vars set on a NON-AKS host
296    // can cause the chain to spend per-refresh latency budget here before
297    // falling through; see Docs.md "Troubleshooting" for details.
298    if let Ok(workload) = WorkloadIdentityCredential::new(None) {
299        sources.push(("WorkloadIdentityCredential", workload));
300    }
301    sources.push((
302        "ManagedIdentityCredential",
303        ManagedIdentityCredential::new(None)?,
304    ));
305    sources.push((
306        "DeveloperToolsCredential",
307        DeveloperToolsCredential::new(None)?,
308    ));
309    Ok(Arc::new(ChainedCredential::new(sources)))
310}
311
312/// Tiny `TokenCredential` chain — tries each source in order, returns the
313/// first that succeeds, aggregates errors otherwise. Modeled after upstream
314/// [`DeveloperToolsCredential`]'s chaining behavior.
315///
316/// Sources are stored alongside a static class-name so the first successful
317/// source can be logged at INFO — useful in production to confirm the
318/// expected principal class is being used (e.g. AKS pods should show
319/// `WorkloadIdentityCredential`, not `DeveloperToolsCredential`).
320///
321/// The INFO log fires **once per `ChainedCredential` instance** (gated via
322/// `OnceLock`) so a long-running process emits exactly one credential-class
323/// disclosure per provider, regardless of how many refreshes occur (SF-A).
324struct ChainedCredential {
325    sources: Vec<(&'static str, Arc<dyn TokenCredential>)>,
326    logged_first_success: std::sync::OnceLock<()>,
327}
328
329impl ChainedCredential {
330    fn new(sources: Vec<(&'static str, Arc<dyn TokenCredential>)>) -> Self {
331        Self {
332            sources,
333            logged_first_success: std::sync::OnceLock::new(),
334        }
335    }
336}
337
338impl std::fmt::Debug for ChainedCredential {
339    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340        f.write_str("ChainedCredential")
341    }
342}
343
344#[async_trait]
345impl TokenCredential for ChainedCredential {
346    async fn get_token(
347        &self,
348        scopes: &[&str],
349        options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
350    ) -> azure_core::Result<azure_core::credentials::AccessToken> {
351        let mut errors: Vec<String> = Vec::new();
352        for (name, source) in &self.sources {
353            match source.get_token(scopes, options.clone()).await {
354                Ok(token) => {
355                    if self.logged_first_success.set(()).is_ok() {
356                        tracing::info!(
357                            target: "duroxide::providers::postgres",
358                            credential = %name,
359                            "Entra credential chain: token acquired (first success on this instance)",
360                        );
361                    }
362                    return Ok(token);
363                }
364                Err(e) => errors.push(format!("{name}: {e}")),
365            }
366        }
367        Err(azure_core::Error::with_message_fn(
368            azure_core::error::ErrorKind::Credential,
369            move || {
370                format!(
371                    "All chained Entra credentials failed to acquire a token:\n  - {}",
372                    errors.join("\n  - ")
373                )
374            },
375        ))
376    }
377}
378
#[cfg(test)]
pub(crate) mod test_support {
    //! Fixtures shared by the `entra` and `provider` unit tests.
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Scripted [`TokenSource`] used across the `entra` and `provider` unit
    /// tests.
    ///
    /// It hands out queued tokens in order, or always fails if built with
    /// `always_failing`. It also records the scopes passed to every call.
    pub(crate) struct RecordingFakeTokenSource {
        /// Tokens still to be handed out, front first.
        scripted: Mutex<VecDeque<EntraToken>>,
        /// Scopes passed to each `fetch_token` call, in call order.
        recorded_scopes: Mutex<Vec<Vec<String>>>,
        /// Number of `fetch_token` calls, failing ones included.
        call_count: AtomicUsize,
        /// When set, every call fails with this message.
        fail_with: Option<String>,
    }

    impl RecordingFakeTokenSource {
        /// Shared constructor behind `with_tokens` and `always_failing`.
        fn build(tokens: Vec<EntraToken>, fail_with: Option<String>) -> Arc<Self> {
            Arc::new(Self {
                scripted: Mutex::new(VecDeque::from(tokens)),
                recorded_scopes: Mutex::new(Vec::new()),
                call_count: AtomicUsize::new(0),
                fail_with,
            })
        }

        /// A source that yields `tokens` one per call, then errors once exhausted.
        pub(crate) fn with_tokens(tokens: Vec<EntraToken>) -> Arc<Self> {
            Self::build(tokens, None)
        }

        /// A source whose every call fails with `message`.
        pub(crate) fn always_failing(message: impl Into<String>) -> Arc<Self> {
            Self::build(Vec::new(), Some(message.into()))
        }

        /// How many times `fetch_token` has been called.
        pub(crate) fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Snapshot of the scopes recorded so far, one entry per call.
        pub(crate) fn recorded_scopes(&self) -> Vec<Vec<String>> {
            self.recorded_scopes.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl TokenSource for RecordingFakeTokenSource {
        async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
            // Count and record first, so failing calls are observable too.
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let owned: Vec<String> = scopes.iter().map(|s| s.to_string()).collect();
            self.recorded_scopes.lock().unwrap().push(owned);

            if let Some(msg) = &self.fail_with {
                return Err(anyhow::anyhow!("{msg}"));
            }

            self.scripted
                .lock()
                .unwrap()
                .pop_front()
                .ok_or_else(|| anyhow::anyhow!("RecordingFakeTokenSource: script exhausted"))
        }
    }

    /// Build a token with the given `secret` that expires `expires_in_secs`
    /// seconds from now.
    pub(crate) fn token(secret: &str, expires_in_secs: u64) -> EntraToken {
        let expires_at = SystemTime::now() + Duration::from_secs(expires_in_secs);
        EntraToken::new(secret.to_owned(), expires_at)
    }
}
452
#[cfg(test)]
mod tests {
    use super::test_support::*;
    use super::*;

    #[test]
    fn defaults_match_password_path() {
        let opts = EntraAuthOptions::new();
        assert_eq!(opts.audience_str(), DEFAULT_AUDIENCE);
        // `new()` honours DUROXIDE_PG_POOL_MAX, so the hardcoded default can
        // only be asserted when the variable is absent. Otherwise this test
        // would fail on any host that exports it for the password path.
        if std::env::var_os("DUROXIDE_PG_POOL_MAX").is_none() {
            assert_eq!(opts.max_connections_value(), 10);
        }
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(30));
        assert_eq!(opts.refresh_interval_value(), DEFAULT_REFRESH_INTERVAL);
    }

    #[test]
    fn audience_override_round_trips() {
        let opts = EntraAuthOptions::new()
            .audience("https://ossrdbms-aad.database.usgovcloudapi.net/.default");
        assert_eq!(
            opts.audience_str(),
            "https://ossrdbms-aad.database.usgovcloudapi.net/.default"
        );
    }

    #[test]
    fn pool_tunables_round_trip() {
        let opts = EntraAuthOptions::new()
            .max_connections(5)
            .acquire_timeout(Duration::from_secs(45))
            .refresh_interval(Duration::from_secs(120));
        assert_eq!(opts.max_connections_value(), 5);
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(45));
        assert_eq!(opts.refresh_interval_value(), Duration::from_secs(120));
    }

    #[tokio::test]
    async fn fake_token_source_returns_scripted_tokens() {
        let source = RecordingFakeTokenSource::with_tokens(vec![
            token("first", 3600),
            token("second", 3600),
        ]);
        let t1 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        let t2 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        assert_eq!(t1.secret, "first");
        assert_eq!(t2.secret, "second");
        assert_eq!(source.call_count(), 2);
        assert_eq!(
            source.recorded_scopes(),
            vec![
                vec![DEFAULT_AUDIENCE.to_string()],
                vec![DEFAULT_AUDIENCE.to_string()]
            ]
        );
    }

    #[tokio::test]
    async fn fake_token_source_propagates_failures() {
        let source = RecordingFakeTokenSource::always_failing("simulated");
        let err = source
            .fetch_token(&[DEFAULT_AUDIENCE])
            .await
            .expect_err("must fail");
        assert!(err.to_string().contains("simulated"));
        assert_eq!(source.call_count(), 1);
    }

    #[test]
    fn offset_datetime_conversion_handles_negative() {
        // OffsetDateTime::UNIX_EPOCH is the canonical zero; pre-epoch values
        // should clamp to UNIX_EPOCH rather than panic / underflow.
        let pre_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH - azure_core::time::Duration::seconds(60);
        assert_eq!(offset_datetime_to_system_time(pre_epoch), UNIX_EPOCH);

        let post_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH + azure_core::time::Duration::seconds(120);
        let converted = offset_datetime_to_system_time(post_epoch);
        assert_eq!(converted, UNIX_EPOCH + Duration::from_secs(120));
    }

    /// Minimal `TokenCredential` stub for chain ordering tests.
    #[derive(Debug)]
    struct StubCred {
        ok: bool,
        label: &'static str,
    }

    #[async_trait]
    impl TokenCredential for StubCred {
        async fn get_token(
            &self,
            _scopes: &[&str],
            _options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
        ) -> azure_core::Result<azure_core::credentials::AccessToken> {
            if self.ok {
                Ok(azure_core::credentials::AccessToken::new(
                    azure_core::credentials::Secret::new(format!("token-from-{}", self.label)),
                    azure_core::time::OffsetDateTime::UNIX_EPOCH
                        + azure_core::time::Duration::seconds(3_700_000_000),
                ))
            } else {
                Err(azure_core::Error::with_message(
                    azure_core::error::ErrorKind::Credential,
                    format!("{} failed", self.label),
                ))
            }
        }
    }

    #[tokio::test]
    async fn chained_credential_returns_first_success_in_chain_order() {
        let chain = ChainedCredential::new(vec![
            (
                "Failing",
                Arc::new(StubCred {
                    ok: false,
                    label: "Failing",
                }),
            ),
            (
                "Winner",
                Arc::new(StubCred {
                    ok: true,
                    label: "Winner",
                }),
            ),
            (
                "ShouldNotBeCalled",
                Arc::new(StubCred {
                    ok: true,
                    label: "ShouldNotBeCalled",
                }),
            ),
        ]);
        let token = chain.get_token(&["aud"], None).await.unwrap();
        assert_eq!(token.token.secret(), "token-from-Winner");
    }

    #[tokio::test]
    async fn chained_credential_aggregates_class_names_in_failure_message() {
        let chain = ChainedCredential::new(vec![
            (
                "Workload",
                Arc::new(StubCred {
                    ok: false,
                    label: "WorkloadIdentity",
                }),
            ),
            (
                "Managed",
                Arc::new(StubCred {
                    ok: false,
                    label: "ManagedIdentity",
                }),
            ),
            (
                "Dev",
                Arc::new(StubCred {
                    ok: false,
                    label: "DeveloperTools",
                }),
            ),
        ]);
        let err = chain.get_token(&["aud"], None).await.expect_err("all fail");
        let msg = format!("{err}");
        // Each failure is reported as "<class name>: <error>". Matching the
        // "name: " prefix tells the class name apart from the stub's label
        // (which also contains e.g. "Workload"). It also lets us check that
        // failures are listed in chain order.
        let positions: Vec<usize> = ["Workload: ", "Managed: ", "Dev: "]
            .iter()
            .map(|needle| {
                msg.find(needle)
                    .unwrap_or_else(|| panic!("missing {needle:?} in {msg}"))
            })
            .collect();
        assert!(
            positions.windows(2).all(|w| w[0] < w[1]),
            "failures out of chain order: {msg}"
        );
    }

    #[tokio::test]
    async fn chained_credential_with_no_sources_fails() {
        let chain = ChainedCredential::new(Vec::new());
        let err = chain
            .get_token(&["aud"], None)
            .await
            .expect_err("empty chain must fail");
        let msg = format!("{err}");
        assert!(msg.contains("All chained Entra credentials failed"), "{msg}");
        assert!(chain.logged_first_success.get().is_none());
    }

    /// SF-A: INFO log fires only on the first successful chain call per
    /// instance, not on every refresh. We can't easily intercept `tracing`
    /// output from a unit test without pulling in a subscriber crate, so
    /// we instead pin the *gating mechanism* directly: the OnceLock must
    /// be empty before the first call and populated after, and stay
    /// populated after subsequent calls.
    #[tokio::test]
    async fn chained_credential_logs_first_success_only_once() {
        let chain = ChainedCredential::new(vec![(
            "Winner",
            Arc::new(StubCred {
                ok: true,
                label: "Winner",
            }),
        )]);
        assert!(
            chain.logged_first_success.get().is_none(),
            "should start unset"
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(
            chain.logged_first_success.get().is_some(),
            "OnceLock must be populated after first success",
        );
        // Subsequent calls re-succeed but must NOT re-arm or re-log. The
        // observable signal here is that `set` would have returned `Err`
        // had we tried again — but get_token already swallows that. We
        // verify by checking the OnceLock is still populated and the
        // call still succeeds.
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(chain.logged_first_success.get().is_some());
    }

    // SF-D: validate_token_freshness rejects stale/expired tokens.
    #[test]
    fn validate_token_freshness_rejects_already_expired_token() {
        let now = SystemTime::now();
        let expires_at = now - Duration::from_secs(10); // 10s in the past
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(60))
            .expect_err("must reject");
        let msg = format!("{err}");
        assert!(msg.contains("too close to now"), "{msg}");
        assert!(msg.contains("clock skew"), "{msg}");
    }

    #[test]
    fn validate_token_freshness_rejects_token_within_safety_margin() {
        let now = SystemTime::now();
        // Token expires in 60s but margin requires 5min — must reject.
        let expires_at = now + Duration::from_secs(60);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect_err("must reject");
        assert!(format!("{err}").contains("too close to now"));
    }

    #[test]
    fn validate_token_freshness_accepts_fresh_token() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(3600);
        validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect("must accept fresh token");
    }

    #[test]
    fn validate_token_freshness_rejects_at_exact_cutoff() {
        // Boundary: a token expiring exactly at now+margin must be rejected.
        // Acceptance requires a strict `expires_at > now + margin`, because a
        // token whose lifetime is exactly the margin has no useful working
        // window.
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
        let margin = Duration::from_secs(60);
        let expires_at = now + margin;
        validate_token_freshness(now, expires_at, margin).expect_err("must reject at exact cutoff");
    }

    // SF-G: max_connections(0) clamps to 1 to satisfy the
    // hardcoded min_connections(1) invariant on the pool builder.
    #[test]
    fn max_connections_zero_is_clamped_to_one() {
        let opts = EntraAuthOptions::new().max_connections(0);
        assert_eq!(opts.max_connections_value(), 1);
    }

    #[test]
    fn max_connections_one_is_preserved() {
        let opts = EntraAuthOptions::new().max_connections(1);
        assert_eq!(opts.max_connections_value(), 1);
    }
}