Skip to main content

duroxide_pg/
entra.rs

1//! Microsoft Entra ID (formerly Azure Active Directory) authentication support
2//! for [`PostgresProvider`](crate::PostgresProvider).
3//!
4//! This module exposes [`EntraAuthOptions`] — the configuration type passed to
5//! [`ProviderConfig::entra`](crate::ProviderConfig::entra) — plus the internal
6//! credential abstractions used to fetch and rotate Entra access tokens.
7//!
8//! Azure SDK types (`azure_core::credentials::TokenCredential`,
9//! `azure_identity::ManagedIdentityCredential`, etc.) are intentionally **not
10//! re-exported**. The public surface is just [`EntraAuthOptions`]; everything
11//! else stays internal so we can adapt to upstream churn without a breaking
12//! change.
13
14use std::sync::Arc;
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16
17use anyhow::{Context, Result};
18use async_trait::async_trait;
19use azure_core::credentials::TokenCredential;
20use azure_identity::{
21    DeveloperToolsCredential, ManagedIdentityCredential, WorkloadIdentityCredential,
22};
23
/// Audience (scope) requested by default when acquiring Entra ID access tokens
/// for Azure Database for PostgreSQL Flexible Server in the Azure public
/// cloud.
///
/// Sovereign clouds (Azure US Government, Azure China, etc.) need a different
/// audience; supply it through [`EntraAuthOptions::audience`].
pub const DEFAULT_AUDIENCE: &str = "https://ossrdbms-aad.database.windows.net/.default";

/// Pool `max_connections` used when `DUROXIDE_PG_POOL_MAX` does not provide a
/// value. Mirrors the password-path default in
/// [`PostgresProvider::new_with_schema`](crate::PostgresProvider::new_with_schema).
const DEFAULT_MAX_CONNECTIONS: u32 = 10;

/// Pool acquire timeout (30 seconds). Mirrors the password path.
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Refresh-interval *ceiling* (20 minutes).
///
/// The actual refresh schedule is driven by each token's `expires_at` minus a
/// safety margin (see Phase 2 of the implementation plan / the refresh task in
/// `provider.rs`); this value only bounds how long the task may wait.
///
/// Entra access tokens for Azure Database for PostgreSQL typically expire
/// after one hour, so a twenty-minute ceiling keeps connect options fresh well
/// before expiry under normal conditions while keeping background chatter
/// minimal.
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(1_200);
49
/// Settings for connecting to Azure Database for PostgreSQL with Microsoft
/// Entra ID authentication.
///
/// Start from [`EntraAuthOptions::new`] (Azure-public-cloud defaults) and
/// adjust individual values with the chainable, by-value setters.
///
/// # Example
///
/// ```rust,no_run
/// use duroxide_pg::entra::EntraAuthOptions;
/// use std::time::Duration;
///
/// let opts = EntraAuthOptions::new()
///     .max_connections(20)
///     .acquire_timeout(Duration::from_secs(45));
/// # let _ = opts;
/// ```
#[derive(Debug, Clone)]
pub struct EntraAuthOptions {
    /// Token audience/scope requested from the credential.
    audience: String,
    /// Upper bound on the number of pooled connections.
    max_connections: u32,
    /// How long a caller may wait to acquire a pooled connection.
    acquire_timeout: Duration,
    /// Ceiling on the time between token refresh attempts.
    refresh_interval: Duration,
}
74
75impl Default for EntraAuthOptions {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl EntraAuthOptions {
82    /// Create options with Azure public cloud defaults.
83    pub fn new() -> Self {
84        let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
85            .ok()
86            .and_then(|s| s.parse::<u32>().ok())
87            .unwrap_or(DEFAULT_MAX_CONNECTIONS);
88        Self {
89            audience: DEFAULT_AUDIENCE.to_string(),
90            max_connections,
91            acquire_timeout: DEFAULT_ACQUIRE_TIMEOUT,
92            refresh_interval: DEFAULT_REFRESH_INTERVAL,
93        }
94    }
95
96    /// Override the token audience/scope. Required for sovereign clouds
97    /// (e.g., Azure US Government uses
98    /// `https://ossrdbms-aad.database.usgovcloudapi.net/.default`).
99    pub fn audience(mut self, audience: impl Into<String>) -> Self {
100        self.audience = audience.into();
101        self
102    }
103
104    /// Override the pool's maximum connection count.
105    ///
106    /// A value of `0` is silently clamped to `1` because the pool's
107    /// hardcoded `min_connections(1)` would otherwise reject the
108    /// configuration at runtime (SF-G).
109    pub fn max_connections(mut self, max: u32) -> Self {
110        self.max_connections = max.max(1);
111        self
112    }
113
114    /// Override the pool's connection acquisition timeout.
115    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
116        self.acquire_timeout = timeout;
117        self
118    }
119
120    /// Override the upper bound on time between refresh attempts. The
121    /// background refresh task may refresh sooner than this interval if a
122    /// token's `expires_at` is approaching.
123    pub fn refresh_interval(mut self, interval: Duration) -> Self {
124        self.refresh_interval = interval;
125        self
126    }
127
128    /// Internal: read accessor for the audience.
129    pub(crate) fn audience_str(&self) -> &str {
130        &self.audience
131    }
132
133    /// Internal: read accessor for max pool connections.
134    pub(crate) fn max_connections_value(&self) -> u32 {
135        self.max_connections
136    }
137
138    /// Internal: read accessor for acquire timeout.
139    pub(crate) fn acquire_timeout_value(&self) -> Duration {
140        self.acquire_timeout
141    }
142
143    /// Internal: read accessor for refresh-interval ceiling.
144    pub(crate) fn refresh_interval_value(&self) -> Duration {
145        self.refresh_interval
146    }
147
148    /// Construct the default [`TokenSource`] (the [`AzureIdentityTokenSource`]
149    /// wrapping the chained credential).
150    ///
151    /// Surfaces a descriptive error if the underlying Azure SDK fails to build
152    /// any credential in the chain.
153    pub(crate) fn default_token_source(&self) -> Result<Arc<dyn TokenSource>> {
154        let credential =
155            build_default_chained_credential().context("Entra credential resolution failed")?;
156        Ok(Arc::new(AzureIdentityTokenSource::new(credential)))
157    }
158}
159
/// An Entra access token together with the wall-clock instant at which it
/// expires.
///
/// `Debug` is implemented by hand so the bearer secret is never printed. The
/// type is `pub(crate)` and so never leaves the crate, but a panic backtrace
/// or a `?token` formatter inside the crate could otherwise leak the secret
/// into logs.
#[derive(Clone)]
pub(crate) struct EntraToken {
    /// The bearer token string. Treat as a credential.
    pub(crate) secret: String,
    /// Wall-clock expiry of `secret`.
    pub(crate) expires_at: SystemTime,
}

impl std::fmt::Debug for EntraToken {
    /// Formats like a derived `Debug`, except that `secret` is replaced by the
    /// literal `"<redacted>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut rendered = f.debug_struct("EntraToken");
        rendered.field("secret", &"<redacted>");
        rendered.field("expires_at", &self.expires_at);
        rendered.finish()
    }
}

impl EntraToken {
    /// Bundle a token secret with its expiry time.
    pub(crate) fn new(secret: String, expires_at: SystemTime) -> Self {
        EntraToken { secret, expires_at }
    }
}
186
187/// Internal seam for fetching Entra tokens. Kept `pub(crate)` so the public
188/// surface of `duroxide-pg` does not bind callers to a specific `azure_*`
189/// version (mitigation for spec Risk #2). Tests in this crate inject fake
190/// implementations via this trait.
191#[async_trait]
192pub(crate) trait TokenSource: Send + Sync {
193    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken>;
194}
195
196/// Production [`TokenSource`] that delegates to an
197/// [`azure_core::credentials::TokenCredential`].
198pub(crate) struct AzureIdentityTokenSource {
199    credential: Arc<dyn TokenCredential>,
200}
201
202impl AzureIdentityTokenSource {
203    pub(crate) fn new(credential: Arc<dyn TokenCredential>) -> Self {
204        Self { credential }
205    }
206}
207
208#[async_trait]
209impl TokenSource for AzureIdentityTokenSource {
210    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
211        let access = self
212            .credential
213            .get_token(scopes, None)
214            .await
215            .map_err(|e| anyhow::anyhow!("Entra token acquisition failed: {e}"))?;
216
217        let expires_at = offset_datetime_to_system_time(access.expires_on);
218        validate_token_freshness(
219            SystemTime::now(),
220            expires_at,
221            crate::provider::ENTRA_REFRESH_SAFETY_MARGIN,
222        )?;
223        Ok(EntraToken::new(
224            access.token.secret().to_string(),
225            expires_at,
226        ))
227    }
228}
229
230/// Convert an `azure_core::time::OffsetDateTime` (from `time` crate) into a
231/// `SystemTime`. Treats pre-epoch values as `UNIX_EPOCH` (effectively immediate
232/// expiry).
233fn offset_datetime_to_system_time(t: azure_core::time::OffsetDateTime) -> SystemTime {
234    let seconds = t.unix_timestamp();
235    if seconds < 0 {
236        return UNIX_EPOCH;
237    }
238    UNIX_EPOCH
239        .checked_add(Duration::from_secs(seconds as u64))
240        .unwrap_or(UNIX_EPOCH)
241}
242
243/// Pure helper: validate that a freshly-issued token's `expires_at` leaves
244/// at least `margin` of useful lifetime relative to `now`. Defends against
245/// upstream returning an already-expired (or near-expired) token, which
246/// would otherwise produce an immediate connection-storm of 28xxx errors
247/// at the pool boundary (SF-D).
248///
249/// Returns `Ok(())` if the token is fresh enough, or an error explaining
250/// the skew/staleness on rejection.
251pub(crate) fn validate_token_freshness(
252    now: SystemTime,
253    expires_at: SystemTime,
254    margin: Duration,
255) -> Result<()> {
256    let cutoff = now
257        .checked_add(margin)
258        .ok_or_else(|| anyhow::anyhow!("clock arithmetic overflow validating token freshness"))?;
259    if expires_at <= cutoff {
260        let secs_remaining = expires_at
261            .duration_since(now)
262            .map(|d| d.as_secs() as i64)
263            .unwrap_or_else(|e| -(e.duration().as_secs() as i64));
264        anyhow::bail!(
265            "Entra token rejected: expires_at is too close to now \
266             (remaining={}s, required margin={}s). Possible upstream SDK \
267             bug, clock skew on the credential issuer, or stale cached \
268             token.",
269            secs_remaining,
270            margin.as_secs(),
271        );
272    }
273    Ok(())
274}
275
276/// Build the default chained credential used when the caller does not provide
277/// an explicit token source. The chain mimics the spirit of upstream
278/// `DefaultAzureCredential` (which is not present in `azure_identity = 0.35`):
279///
280/// 1. `WorkloadIdentityCredential` — federated tokens for AKS Workload
281///    Identity, GitHub OIDC, etc. (only succeeds when the corresponding env
282///    vars are set).
283/// 2. `ManagedIdentityCredential` — IMDS for Azure VMs, App Service,
284///    Container Apps, Container Instances, Functions.
285/// 3. `DeveloperToolsCredential` — Azure CLI (`az login`) and Azure Developer
286///    CLI (`azd auth login`) for local development.
287///
288/// Returns an `azure_core::Error` from any failed inner constructor; the
289/// caller wraps it into an `anyhow::Error`.
290fn build_default_chained_credential() -> azure_core::Result<Arc<dyn TokenCredential>> {
291    let mut sources: Vec<(&'static str, Arc<dyn TokenCredential>)> = Vec::new();
292    // WorkloadIdentityCredential::new only succeeds when AZURE_FEDERATED_TOKEN_FILE
293    // (and friends) are set. Skip silently otherwise so the chain still works
294    // on a developer laptop. Note: leaving these env vars set on a NON-AKS host
295    // can cause the chain to spend per-refresh latency budget here before
296    // falling through; see Docs.md "Troubleshooting" for details.
297    if let Ok(workload) = WorkloadIdentityCredential::new(None) {
298        sources.push(("WorkloadIdentityCredential", workload));
299    }
300    sources.push((
301        "ManagedIdentityCredential",
302        ManagedIdentityCredential::new(None)?,
303    ));
304    sources.push((
305        "DeveloperToolsCredential",
306        DeveloperToolsCredential::new(None)?,
307    ));
308    Ok(Arc::new(ChainedCredential::new(sources)))
309}
310
311/// Tiny `TokenCredential` chain — tries each source in order, returns the
312/// first that succeeds, aggregates errors otherwise. Modeled after upstream
313/// [`DeveloperToolsCredential`]'s chaining behavior.
314///
315/// Sources are stored alongside a static class-name so the first successful
316/// source can be logged at INFO — useful in production to confirm the
317/// expected principal class is being used (e.g. AKS pods should show
318/// `WorkloadIdentityCredential`, not `DeveloperToolsCredential`).
319///
320/// The INFO log fires **once per `ChainedCredential` instance** (gated via
321/// `OnceLock`) so a long-running process emits exactly one credential-class
322/// disclosure per provider, regardless of how many refreshes occur (SF-A).
323struct ChainedCredential {
324    sources: Vec<(&'static str, Arc<dyn TokenCredential>)>,
325    logged_first_success: std::sync::OnceLock<()>,
326}
327
328impl ChainedCredential {
329    fn new(sources: Vec<(&'static str, Arc<dyn TokenCredential>)>) -> Self {
330        Self {
331            sources,
332            logged_first_success: std::sync::OnceLock::new(),
333        }
334    }
335}
336
337impl std::fmt::Debug for ChainedCredential {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        f.write_str("ChainedCredential")
340    }
341}
342
343#[async_trait]
344impl TokenCredential for ChainedCredential {
345    async fn get_token(
346        &self,
347        scopes: &[&str],
348        options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
349    ) -> azure_core::Result<azure_core::credentials::AccessToken> {
350        let mut errors: Vec<String> = Vec::new();
351        for (name, source) in &self.sources {
352            match source.get_token(scopes, options.clone()).await {
353                Ok(token) => {
354                    if self.logged_first_success.set(()).is_ok() {
355                        tracing::info!(
356                            target: "duroxide::providers::postgres",
357                            credential = %name,
358                            "Entra credential chain: token acquired (first success on this instance)",
359                        );
360                    }
361                    return Ok(token);
362                }
363                Err(e) => errors.push(format!("{name}: {e}")),
364            }
365        }
366        Err(azure_core::Error::with_message_fn(
367            azure_core::error::ErrorKind::Credential,
368            move || {
369                format!(
370                    "All chained Entra credentials failed to acquire a token:\n  - {}",
371                    errors.join("\n  - ")
372                )
373            },
374        ))
375    }
376}
377
#[cfg(test)]
pub(crate) mod test_support {
    //! Test fixtures shared across crate unit tests (entra and provider).
    use super::*;
    use std::collections::VecDeque;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// `TokenSource` fixture used across `entra` and `provider` unit tests.
    /// Hands out scripted tokens in order and records the scopes passed to
    /// each call.
    pub(crate) struct RecordingFakeTokenSource {
        /// Tokens still to be returned, front first.
        scripted: Mutex<VecDeque<EntraToken>>,
        /// Scopes of every `fetch_token` call, in call order.
        recorded_scopes: Mutex<Vec<Vec<String>>>,
        /// Number of `fetch_token` calls, including failed ones.
        call_count: AtomicUsize,
        /// When set, every call fails with this message.
        fail_with: Option<String>,
    }

    impl RecordingFakeTokenSource {
        /// Fixture that returns `tokens` one per call, in order, and then
        /// fails with "script exhausted".
        pub(crate) fn with_tokens(tokens: Vec<EntraToken>) -> Arc<Self> {
            Arc::new(Self {
                scripted: Mutex::new(VecDeque::from(tokens)),
                recorded_scopes: Mutex::new(Vec::new()),
                call_count: AtomicUsize::new(0),
                fail_with: None,
            })
        }

        /// Fixture whose every call fails with `message`.
        pub(crate) fn always_failing(message: impl Into<String>) -> Arc<Self> {
            Arc::new(Self {
                scripted: Mutex::new(VecDeque::new()),
                recorded_scopes: Mutex::new(Vec::new()),
                call_count: AtomicUsize::new(0),
                fail_with: Some(message.into()),
            })
        }

        /// Number of `fetch_token` calls made so far.
        pub(crate) fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Snapshot of the scopes recorded for each call so far.
        pub(crate) fn recorded_scopes(&self) -> Vec<Vec<String>> {
            self.recorded_scopes.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl TokenSource for RecordingFakeTokenSource {
        /// Counts and records the call, then returns the configured failure
        /// or the next scripted token.
        async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            self.recorded_scopes
                .lock()
                .unwrap()
                .push(scopes.iter().map(|s| s.to_string()).collect());
            if let Some(msg) = &self.fail_with {
                return Err(anyhow::anyhow!("{msg}"));
            }
            // Pop in FIFO order; the guard is released at the end of this
            // statement.
            let next = self.scripted.lock().unwrap().pop_front();
            next.ok_or_else(|| anyhow::anyhow!("RecordingFakeTokenSource: script exhausted"))
        }
    }

    /// Token with the given `secret` that expires `expires_in_secs` seconds
    /// from the moment this function is called.
    pub(crate) fn token(secret: &str, expires_in_secs: u64) -> EntraToken {
        EntraToken::new(
            secret.to_string(),
            SystemTime::now() + Duration::from_secs(expires_in_secs),
        )
    }
}
451
#[cfg(test)]
mod tests {
    //! Unit tests for the Entra options, token validation and credential
    //! chain.
    use super::test_support::*;
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn defaults_match_password_path() {
        let opts = EntraAuthOptions::new();
        assert_eq!(opts.audience_str(), DEFAULT_AUDIENCE);
        // `new()` seeds `max_connections` from DUROXIDE_PG_POOL_MAX, so the
        // built-in default can only be pinned when that variable is absent
        // from the test environment.
        if std::env::var_os("DUROXIDE_PG_POOL_MAX").is_none() {
            assert_eq!(opts.max_connections_value(), 10);
        }
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(30));
        assert_eq!(opts.refresh_interval_value(), DEFAULT_REFRESH_INTERVAL);
    }

    #[test]
    fn audience_override_round_trips() {
        let opts = EntraAuthOptions::new()
            .audience("https://ossrdbms-aad.database.usgovcloudapi.net/.default");
        assert_eq!(
            opts.audience_str(),
            "https://ossrdbms-aad.database.usgovcloudapi.net/.default"
        );
    }

    #[test]
    fn pool_tunables_round_trip() {
        let opts = EntraAuthOptions::new()
            .max_connections(5)
            .acquire_timeout(Duration::from_secs(45))
            .refresh_interval(Duration::from_secs(120));
        assert_eq!(opts.max_connections_value(), 5);
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(45));
        assert_eq!(opts.refresh_interval_value(), Duration::from_secs(120));
    }

    #[tokio::test]
    async fn fake_token_source_returns_scripted_tokens() {
        let source = RecordingFakeTokenSource::with_tokens(vec![
            token("first", 3600),
            token("second", 3600),
        ]);
        let t1 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        let t2 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        assert_eq!(t1.secret, "first");
        assert_eq!(t2.secret, "second");
        assert_eq!(source.call_count(), 2);
        assert_eq!(
            source.recorded_scopes(),
            vec![
                vec![DEFAULT_AUDIENCE.to_string()],
                vec![DEFAULT_AUDIENCE.to_string()]
            ]
        );
    }

    #[tokio::test]
    async fn fake_token_source_propagates_failures() {
        let source = RecordingFakeTokenSource::always_failing("simulated");
        let err = source
            .fetch_token(&[DEFAULT_AUDIENCE])
            .await
            .expect_err("must fail");
        assert!(err.to_string().contains("simulated"));
        assert_eq!(source.call_count(), 1);
    }

    #[test]
    fn offset_datetime_conversion_handles_negative() {
        // OffsetDateTime::UNIX_EPOCH is the canonical zero; pre-epoch values
        // should clamp to UNIX_EPOCH rather than panic / underflow.
        let pre_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH - azure_core::time::Duration::seconds(60);
        assert_eq!(offset_datetime_to_system_time(pre_epoch), UNIX_EPOCH);

        let post_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH + azure_core::time::Duration::seconds(120);
        let converted = offset_datetime_to_system_time(post_epoch);
        assert_eq!(converted, UNIX_EPOCH + Duration::from_secs(120));
    }

    /// Minimal `TokenCredential` stub for chain tests. Counts how often it is
    /// asked for a token so tests can verify the chain's short-circuiting.
    #[derive(Debug)]
    struct StubCred {
        ok: bool,
        label: &'static str,
        calls: AtomicUsize,
    }

    impl StubCred {
        /// Stub that succeeds (`ok == true`) or fails on every call.
        fn new(ok: bool, label: &'static str) -> Arc<Self> {
            Arc::new(Self {
                ok,
                label,
                calls: AtomicUsize::new(0),
            })
        }

        /// Number of `get_token` calls received so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    /// Pair a stub with a chain class name, erasing it to the trait object
    /// the chain stores while the test keeps its own typed handle.
    fn chain_entry(
        name: &'static str,
        cred: &Arc<StubCred>,
    ) -> (&'static str, Arc<dyn TokenCredential>) {
        (name, Arc::clone(cred) as Arc<dyn TokenCredential>)
    }

    #[async_trait]
    impl TokenCredential for StubCred {
        async fn get_token(
            &self,
            _scopes: &[&str],
            _options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
        ) -> azure_core::Result<azure_core::credentials::AccessToken> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if self.ok {
                Ok(azure_core::credentials::AccessToken::new(
                    azure_core::credentials::Secret::new(format!("token-from-{}", self.label)),
                    azure_core::time::OffsetDateTime::UNIX_EPOCH
                        + azure_core::time::Duration::seconds(3_700_000_000),
                ))
            } else {
                Err(azure_core::Error::with_message(
                    azure_core::error::ErrorKind::Credential,
                    format!("{} failed", self.label),
                ))
            }
        }
    }

    #[tokio::test]
    async fn chained_credential_returns_first_success_in_chain_order() {
        let failing = StubCred::new(false, "Failing");
        let winner = StubCred::new(true, "Winner");
        let untouched = StubCred::new(true, "ShouldNotBeCalled");
        let chain = ChainedCredential::new(vec![
            chain_entry("Failing", &failing),
            chain_entry("Winner", &winner),
            chain_entry("ShouldNotBeCalled", &untouched),
        ]);
        let token = chain.get_token(&["aud"], None).await.unwrap();
        assert_eq!(token.token.secret(), "token-from-Winner");
        assert_eq!(failing.calls(), 1, "failing source must be tried first");
        assert_eq!(winner.calls(), 1);
        assert_eq!(
            untouched.calls(),
            0,
            "sources after the first success must not be consulted"
        );
    }

    #[tokio::test]
    async fn chained_credential_aggregates_class_names_in_failure_message() {
        let workload = StubCred::new(false, "WorkloadIdentity");
        let managed = StubCred::new(false, "ManagedIdentity");
        let dev = StubCred::new(false, "DeveloperTools");
        let chain = ChainedCredential::new(vec![
            chain_entry("Workload", &workload),
            chain_entry("Managed", &managed),
            chain_entry("Dev", &dev),
        ]);
        let err = chain.get_token(&["aud"], None).await.expect_err("all fail");
        let msg = format!("{err}");
        assert!(msg.contains("Workload"), "{msg}");
        assert!(msg.contains("Managed"), "{msg}");
        assert!(msg.contains("Dev"), "{msg}");
        // Every source must have been attempted before giving up.
        assert_eq!(workload.calls(), 1);
        assert_eq!(managed.calls(), 1);
        assert_eq!(dev.calls(), 1);
    }

    /// SF-A: INFO log fires only on the first successful chain call per
    /// instance, not on every refresh. We can't easily intercept `tracing`
    /// output from a unit test without pulling in a subscriber crate, so
    /// we instead pin the *gating mechanism* directly: the OnceLock must
    /// be empty before the first call and populated after, and stay
    /// populated after subsequent calls.
    #[tokio::test]
    async fn chained_credential_logs_first_success_only_once() {
        let winner = StubCred::new(true, "Winner");
        let chain = ChainedCredential::new(vec![chain_entry("Winner", &winner)]);
        assert!(
            chain.logged_first_success.get().is_none(),
            "should start unset"
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(
            chain.logged_first_success.get().is_some(),
            "OnceLock must be populated after first success",
        );
        // A second call succeeds again; the gate stays populated, so the
        // `set` inside `get_token` fails and the log branch is skipped.
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(chain.logged_first_success.get().is_some());
        assert_eq!(winner.calls(), 2);
    }

    // SF-D: validate_token_freshness rejects stale/expired tokens.
    #[test]
    fn validate_token_freshness_rejects_already_expired_token() {
        let now = SystemTime::now();
        let expires_at = now - Duration::from_secs(10); // 10s in the past
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(60))
            .expect_err("must reject");
        let msg = format!("{err}");
        assert!(msg.contains("too close to now"), "{msg}");
        assert!(msg.contains("clock skew"), "{msg}");
    }

    #[test]
    fn validate_token_freshness_rejects_token_within_safety_margin() {
        let now = SystemTime::now();
        // Token expires in 60s but margin requires 5min — must reject.
        let expires_at = now + Duration::from_secs(60);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect_err("must reject");
        assert!(format!("{err}").contains("too close to now"));
    }

    #[test]
    fn validate_token_freshness_accepts_fresh_token() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(3600);
        validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect("must accept fresh token");
    }

    #[test]
    fn validate_token_freshness_rejects_at_exact_cutoff() {
        // Boundary: a token expiring exactly at now+margin is rejected.
        // Acceptance requires `expires_at` to be strictly later than the
        // cutoff, because a token whose lifetime is exactly the margin has
        // no useful working window.
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
        let margin = Duration::from_secs(60);
        let expires_at = now + margin;
        validate_token_freshness(now, expires_at, margin).expect_err("must reject at exact cutoff");
    }

    // SF-G: max_connections(0) clamps to 1 to satisfy the
    // hardcoded min_connections(1) invariant on the pool builder.
    #[test]
    fn max_connections_zero_is_clamped_to_one() {
        let opts = EntraAuthOptions::new().max_connections(0);
        assert_eq!(opts.max_connections_value(), 1);
    }

    #[test]
    fn max_connections_one_is_preserved() {
        let opts = EntraAuthOptions::new().max_connections(1);
        assert_eq!(opts.max_connections_value(), 1);
    }
}