use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use anyhow::{Context, Result};
use async_trait::async_trait;
use azure_core::credentials::TokenCredential;
use azure_identity::{
DeveloperToolsCredential, ManagedIdentityCredential, WorkloadIdentityCredential,
};
/// Default Entra scope requested for database tokens (the public-cloud
/// `ossrdbms-aad` resource). Sovereign clouds override it via
/// `EntraAuthOptions::audience`.
pub const DEFAULT_AUDIENCE: &str = "https://ossrdbms-aad.database.windows.net/.default";
/// Pool size used when neither the builder nor `DUROXIDE_PG_POOL_MAX` sets one.
const DEFAULT_MAX_CONNECTIONS: u32 = 10;
/// Default value of the `acquire_timeout` pool tunable.
const DEFAULT_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(30);
/// Default value of the `refresh_interval` tunable: twenty minutes.
const DEFAULT_REFRESH_INTERVAL: Duration = Duration::from_secs(1_200);
/// Builder-style configuration for Entra (Azure AD) token authentication.
///
/// Construct with `EntraAuthOptions::new()` (or `Default`) and chain the
/// setter methods to override individual values.
#[derive(Debug, Clone)]
pub struct EntraAuthOptions {
    /// Entra scope requested when fetching database tokens.
    audience: String,
    /// Upper bound on pooled connections (the builder setter clamps it to >= 1).
    max_connections: u32,
    /// Pool tunable: connection-acquire timeout (consumed outside this file).
    acquire_timeout: Duration,
    /// Presumably the proactive token-refresh period — consumer not visible here.
    refresh_interval: Duration,
}
impl Default for EntraAuthOptions {
fn default() -> Self {
Self::new()
}
}
impl EntraAuthOptions {
pub fn new() -> Self {
let max_connections = std::env::var("DUROXIDE_PG_POOL_MAX")
.ok()
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(DEFAULT_MAX_CONNECTIONS);
Self {
audience: DEFAULT_AUDIENCE.to_string(),
max_connections,
acquire_timeout: DEFAULT_ACQUIRE_TIMEOUT,
refresh_interval: DEFAULT_REFRESH_INTERVAL,
}
}
pub fn audience(mut self, audience: impl Into<String>) -> Self {
self.audience = audience.into();
self
}
pub fn max_connections(mut self, max: u32) -> Self {
self.max_connections = max.max(1);
self
}
pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
self.acquire_timeout = timeout;
self
}
pub fn refresh_interval(mut self, interval: Duration) -> Self {
self.refresh_interval = interval;
self
}
pub(crate) fn audience_str(&self) -> &str {
&self.audience
}
pub(crate) fn max_connections_value(&self) -> u32 {
self.max_connections
}
pub(crate) fn acquire_timeout_value(&self) -> Duration {
self.acquire_timeout
}
pub(crate) fn refresh_interval_value(&self) -> Duration {
self.refresh_interval
}
pub(crate) fn default_token_source(&self) -> Result<Arc<dyn TokenSource>> {
let credential =
build_default_chained_credential().context("Entra credential resolution failed")?;
Ok(Arc::new(AzureIdentityTokenSource::new(credential)))
}
}
/// An Entra access token together with the instant it expires.
///
/// `Debug` is implemented by hand so the bearer secret is never printed.
#[derive(Clone)]
pub(crate) struct EntraToken {
    /// Raw bearer token.
    pub(crate) secret: String,
    /// Expiry instant of `secret`.
    pub(crate) expires_at: SystemTime,
}

impl EntraToken {
    /// Pairs a bearer `secret` with its `expires_at` instant.
    pub(crate) fn new(secret: String, expires_at: SystemTime) -> Self {
        EntraToken { secret, expires_at }
    }
}

impl std::fmt::Debug for EntraToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let EntraToken { expires_at, .. } = self;
        let mut out = f.debug_struct("EntraToken");
        // Emit a placeholder instead of the secret so logs stay safe.
        out.field("secret", &"<redacted>");
        out.field("expires_at", expires_at);
        out.finish()
    }
}
/// Anything that can mint Entra access tokens for a set of scopes.
///
/// Production code uses [`AzureIdentityTokenSource`]; tests substitute a
/// scripted fake so token acquisition can run without Azure.
#[async_trait]
pub(crate) trait TokenSource: Sync + Send {
    /// Fetches a token valid for `scopes`.
    ///
    /// # Errors
    /// Implementation-defined; returned when no usable token could be obtained.
    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken>;
}
/// [`TokenSource`] backed by an `azure_core` [`TokenCredential`], such as the
/// chained default credential built in this module.
pub(crate) struct AzureIdentityTokenSource {
    /// Credential consulted on every token request.
    credential: Arc<dyn TokenCredential>,
}

impl AzureIdentityTokenSource {
    /// Wraps `credential` so it can be used wherever a [`TokenSource`] is expected.
    pub(crate) fn new(credential: Arc<dyn TokenCredential>) -> Self {
        AzureIdentityTokenSource { credential }
    }
}
#[async_trait]
impl TokenSource for AzureIdentityTokenSource {
    /// Requests a token for `scopes` (with no extra request options). The token
    /// is rejected when it expires within
    /// `crate::provider::ENTRA_REFRESH_SAFETY_MARGIN` of now.
    async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
        let acquired = self.credential.get_token(scopes, None).await;
        let access = match acquired {
            Ok(access) => access,
            Err(e) => return Err(anyhow::anyhow!("Entra token acquisition failed: {e}")),
        };
        let expires_at = offset_datetime_to_system_time(access.expires_on);
        // Guard against tokens that are already stale by the time we get them.
        let now = SystemTime::now();
        validate_token_freshness(now, expires_at, crate::provider::ENTRA_REFRESH_SAFETY_MARGIN)?;
        let secret = access.token.secret().to_string();
        Ok(EntraToken::new(secret, expires_at))
    }
}
/// Converts an `azure_core` timestamp into a [`SystemTime`] at whole-second
/// precision (the sub-second part is dropped).
///
/// Pre-epoch instants, and instants that `SystemTime` cannot represent,
/// collapse to `UNIX_EPOCH`.
fn offset_datetime_to_system_time(t: azure_core::time::OffsetDateTime) -> SystemTime {
    u64::try_from(t.unix_timestamp())
        .ok()
        .and_then(|secs| UNIX_EPOCH.checked_add(Duration::from_secs(secs)))
        .unwrap_or(UNIX_EPOCH)
}
pub(crate) fn validate_token_freshness(
now: SystemTime,
expires_at: SystemTime,
margin: Duration,
) -> Result<()> {
let cutoff = now
.checked_add(margin)
.ok_or_else(|| anyhow::anyhow!("clock arithmetic overflow validating token freshness"))?;
if expires_at <= cutoff {
let secs_remaining = expires_at
.duration_since(now)
.map(|d| d.as_secs() as i64)
.unwrap_or_else(|e| -(e.duration().as_secs() as i64));
anyhow::bail!(
"Entra token rejected: expires_at is too close to now \
(remaining={}s, required margin={}s). Possible upstream SDK \
bug, clock skew on the credential issuer, or stale cached \
token.",
secs_remaining,
margin.as_secs(),
);
}
Ok(())
}
/// Builds the default credential chain, tried in this order:
/// workload identity → managed identity → developer tools.
///
/// Workload identity is included only when its constructor succeeds.
/// Presumably that is when its environment configuration is present — not
/// verified here. A construction failure of either of the other two
/// credentials aborts the whole build.
fn build_default_chained_credential() -> azure_core::Result<Arc<dyn TokenCredential>> {
    let workload = WorkloadIdentityCredential::new(None)
        .ok()
        .map(|cred| ("WorkloadIdentityCredential", cred as Arc<dyn TokenCredential>));
    let managed: Arc<dyn TokenCredential> = ManagedIdentityCredential::new(None)?;
    let developer: Arc<dyn TokenCredential> = DeveloperToolsCredential::new(None)?;
    let sources: Vec<(&'static str, Arc<dyn TokenCredential>)> = workload
        .into_iter()
        .chain([
            ("ManagedIdentityCredential", managed),
            ("DeveloperToolsCredential", developer),
        ])
        .collect();
    Ok(Arc::new(ChainedCredential::new(sources)))
}
/// A [`TokenCredential`] that consults an ordered list of credentials and
/// returns the first token any of them produces.
struct ChainedCredential {
    /// `(display name, credential)` pairs, tried front to back.
    sources: Vec<(&'static str, Arc<dyn TokenCredential>)>,
    /// Populated on the first successful acquisition, so the "which credential
    /// won" message is logged at `info` only once per instance.
    logged_first_success: std::sync::OnceLock<()>,
}

impl ChainedCredential {
    /// Creates a chain that tries `sources` in the order given.
    fn new(sources: Vec<(&'static str, Arc<dyn TokenCredential>)>) -> Self {
        let logged_first_success = std::sync::OnceLock::new();
        ChainedCredential {
            sources,
            logged_first_success,
        }
    }
}

impl std::fmt::Debug for ChainedCredential {
    /// Prints only the type name; the wrapped credentials are not listed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "ChainedCredential")
    }
}
#[async_trait]
impl TokenCredential for ChainedCredential {
async fn get_token(
&self,
scopes: &[&str],
options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
) -> azure_core::Result<azure_core::credentials::AccessToken> {
let mut errors: Vec<String> = Vec::new();
for (name, source) in &self.sources {
match source.get_token(scopes, options.clone()).await {
Ok(token) => {
if self.logged_first_success.set(()).is_ok() {
tracing::info!(
target: "duroxide::providers::postgres",
credential = %name,
"Entra credential chain: token acquired (first success on this instance)",
);
}
return Ok(token);
}
Err(e) => errors.push(format!("{name}: {e}")),
}
}
Err(azure_core::Error::with_message_fn(
azure_core::error::ErrorKind::Credential,
move || {
format!(
"All chained Entra credentials failed to acquire a token:\n - {}",
errors.join("\n - ")
)
},
))
}
}
#[cfg(test)]
pub(crate) mod test_support {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// In-memory [`TokenSource`] for tests.
    ///
    /// It either hands out pre-scripted tokens in FIFO order or always fails
    /// with a fixed message. In both modes it records how many times it was
    /// called and with which scopes.
    pub(crate) struct RecordingFakeTokenSource {
        /// Tokens still to be returned, front first.
        scripted: Mutex<Vec<EntraToken>>,
        /// Scopes from every `fetch_token` call, in call order.
        recorded_scopes: Mutex<Vec<Vec<String>>>,
        /// Number of `fetch_token` calls, failed ones included.
        call_count: AtomicUsize,
        /// When set, every call fails with this message.
        fail_with: Option<String>,
    }

    impl RecordingFakeTokenSource {
        /// Shared constructor for the two public flavours.
        fn build(scripted: Vec<EntraToken>, fail_with: Option<String>) -> Arc<Self> {
            Arc::new(Self {
                scripted: Mutex::new(scripted),
                recorded_scopes: Mutex::new(Vec::new()),
                call_count: AtomicUsize::new(0),
                fail_with,
            })
        }

        /// Fake that returns `tokens` one per call, then reports exhaustion.
        pub(crate) fn with_tokens(tokens: Vec<EntraToken>) -> Arc<Self> {
            Self::build(tokens, None)
        }

        /// Fake whose every call fails with `message`.
        pub(crate) fn always_failing(message: impl Into<String>) -> Arc<Self> {
            Self::build(Vec::new(), Some(message.into()))
        }

        /// How many times `fetch_token` has been invoked.
        pub(crate) fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }

        /// Snapshot of the scopes passed to each call so far.
        pub(crate) fn recorded_scopes(&self) -> Vec<Vec<String>> {
            let guard = self.recorded_scopes.lock().unwrap();
            guard.clone()
        }
    }

    #[async_trait]
    impl TokenSource for RecordingFakeTokenSource {
        async fn fetch_token(&self, scopes: &[&str]) -> Result<EntraToken> {
            // Bookkeeping happens before any failure so tests can count failed calls.
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let owned: Vec<String> = scopes.iter().map(|scope| String::from(*scope)).collect();
            self.recorded_scopes.lock().unwrap().push(owned);
            match &self.fail_with {
                Some(msg) => Err(anyhow::anyhow!("{msg}")),
                None => {
                    let mut remaining = self.scripted.lock().unwrap();
                    if remaining.is_empty() {
                        Err(anyhow::anyhow!(
                            "RecordingFakeTokenSource: script exhausted"
                        ))
                    } else {
                        Ok(remaining.remove(0))
                    }
                }
            }
        }
    }

    /// Convenience constructor: a token named `secret` expiring
    /// `expires_in_secs` seconds from now.
    pub(crate) fn token(secret: &str, expires_in_secs: u64) -> EntraToken {
        let expires_at = SystemTime::now() + Duration::from_secs(expires_in_secs);
        EntraToken::new(secret.to_owned(), expires_at)
    }
}
#[cfg(test)]
mod tests {
    use super::test_support::*;
    use super::*;

    #[test]
    fn defaults_match_password_path() {
        let opts = EntraAuthOptions::new();
        assert_eq!(opts.audience_str(), DEFAULT_AUDIENCE);
        // `DUROXIDE_PG_POOL_MAX` may be set in the environment running the
        // tests; only pin the compiled-in pool default when it is absent.
        if std::env::var_os("DUROXIDE_PG_POOL_MAX").is_none() {
            assert_eq!(opts.max_connections_value(), 10);
        }
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(30));
        assert_eq!(opts.refresh_interval_value(), DEFAULT_REFRESH_INTERVAL);
    }

    #[test]
    fn default_matches_new() {
        let from_default = EntraAuthOptions::default();
        let from_new = EntraAuthOptions::new();
        assert_eq!(from_default.audience_str(), from_new.audience_str());
        assert_eq!(
            from_default.max_connections_value(),
            from_new.max_connections_value()
        );
        assert_eq!(
            from_default.acquire_timeout_value(),
            from_new.acquire_timeout_value()
        );
        assert_eq!(
            from_default.refresh_interval_value(),
            from_new.refresh_interval_value()
        );
    }

    #[test]
    fn audience_override_round_trips() {
        let opts = EntraAuthOptions::new()
            .audience("https://ossrdbms-aad.database.usgovcloudapi.net/.default");
        assert_eq!(
            opts.audience_str(),
            "https://ossrdbms-aad.database.usgovcloudapi.net/.default"
        );
    }

    #[test]
    fn pool_tunables_round_trip() {
        let opts = EntraAuthOptions::new()
            .max_connections(5)
            .acquire_timeout(Duration::from_secs(45))
            .refresh_interval(Duration::from_secs(120));
        assert_eq!(opts.max_connections_value(), 5);
        assert_eq!(opts.acquire_timeout_value(), Duration::from_secs(45));
        assert_eq!(opts.refresh_interval_value(), Duration::from_secs(120));
    }

    #[tokio::test]
    async fn fake_token_source_returns_scripted_tokens() {
        let source = RecordingFakeTokenSource::with_tokens(vec![
            token("first", 3600),
            token("second", 3600),
        ]);
        let t1 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        let t2 = source.fetch_token(&[DEFAULT_AUDIENCE]).await.unwrap();
        assert_eq!(t1.secret, "first");
        assert_eq!(t2.secret, "second");
        assert_eq!(source.call_count(), 2);
        assert_eq!(
            source.recorded_scopes(),
            vec![
                vec![DEFAULT_AUDIENCE.to_string()],
                vec![DEFAULT_AUDIENCE.to_string()]
            ]
        );
    }

    #[tokio::test]
    async fn fake_token_source_propagates_failures() {
        let source = RecordingFakeTokenSource::always_failing("simulated");
        let err = source
            .fetch_token(&[DEFAULT_AUDIENCE])
            .await
            .expect_err("must fail");
        assert!(err.to_string().contains("simulated"));
        assert_eq!(source.call_count(), 1);
    }

    #[tokio::test]
    async fn fake_token_source_reports_exhausted_script() {
        let source = RecordingFakeTokenSource::with_tokens(Vec::new());
        let err = source
            .fetch_token(&[DEFAULT_AUDIENCE])
            .await
            .expect_err("empty script must fail");
        assert!(err.to_string().contains("script exhausted"), "{err}");
        assert_eq!(source.call_count(), 1);
    }

    #[test]
    fn offset_datetime_conversion_handles_negative() {
        let pre_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH - azure_core::time::Duration::seconds(60);
        assert_eq!(offset_datetime_to_system_time(pre_epoch), UNIX_EPOCH);
        let post_epoch =
            azure_core::time::OffsetDateTime::UNIX_EPOCH + azure_core::time::Duration::seconds(120);
        let converted = offset_datetime_to_system_time(post_epoch);
        assert_eq!(converted, UNIX_EPOCH + Duration::from_secs(120));
    }

    /// Credential stub that either succeeds with `token-from-<label>` or fails
    /// with `<label> failed`.
    #[derive(Debug)]
    struct StubCred {
        ok: bool,
        label: &'static str,
    }

    #[async_trait]
    impl TokenCredential for StubCred {
        async fn get_token(
            &self,
            _scopes: &[&str],
            _options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
        ) -> azure_core::Result<azure_core::credentials::AccessToken> {
            if self.ok {
                Ok(azure_core::credentials::AccessToken::new(
                    azure_core::credentials::Secret::new(format!("token-from-{}", self.label)),
                    azure_core::time::OffsetDateTime::UNIX_EPOCH
                        + azure_core::time::Duration::seconds(3_700_000_000),
                ))
            } else {
                Err(azure_core::Error::with_message(
                    azure_core::error::ErrorKind::Credential,
                    format!("{} failed", self.label),
                ))
            }
        }
    }

    #[tokio::test]
    async fn chained_credential_returns_first_success_in_chain_order() {
        let chain = ChainedCredential::new(vec![
            (
                "Failing",
                Arc::new(StubCred {
                    ok: false,
                    label: "Failing",
                }),
            ),
            (
                "Winner",
                Arc::new(StubCred {
                    ok: true,
                    label: "Winner",
                }),
            ),
            (
                "ShouldNotBeCalled",
                Arc::new(StubCred {
                    ok: true,
                    label: "ShouldNotBeCalled",
                }),
            ),
        ]);
        let token = chain.get_token(&["aud"], None).await.unwrap();
        assert_eq!(token.token.secret(), "token-from-Winner");
    }

    #[tokio::test]
    async fn chained_credential_aggregates_class_names_in_failure_message() {
        let chain = ChainedCredential::new(vec![
            (
                "Workload",
                Arc::new(StubCred {
                    ok: false,
                    label: "WorkloadIdentity",
                }),
            ),
            (
                "Managed",
                Arc::new(StubCred {
                    ok: false,
                    label: "ManagedIdentity",
                }),
            ),
            (
                "Dev",
                Arc::new(StubCred {
                    ok: false,
                    label: "DeveloperTools",
                }),
            ),
        ]);
        let err = chain.get_token(&["aud"], None).await.expect_err("all fail");
        let msg = format!("{err}");
        assert!(msg.contains("Workload"), "{msg}");
        assert!(msg.contains("Managed"), "{msg}");
        assert!(msg.contains("Dev"), "{msg}");
    }

    #[tokio::test]
    async fn chained_credential_logs_first_success_only_once() {
        let chain = ChainedCredential::new(vec![(
            "Winner",
            Arc::new(StubCred {
                ok: true,
                label: "Winner",
            }),
        )]);
        assert!(
            chain.logged_first_success.get().is_none(),
            "should start unset"
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(
            chain.logged_first_success.get().is_some(),
            "OnceLock must be populated after first success",
        );
        let _ = chain.get_token(&["aud"], None).await.unwrap();
        assert!(chain.logged_first_success.get().is_some());
    }

    #[test]
    fn validate_token_freshness_rejects_already_expired_token() {
        let now = SystemTime::now();
        let expires_at = now - Duration::from_secs(10);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(60))
            .expect_err("must reject");
        let msg = format!("{err}");
        assert!(msg.contains("too close to now"), "{msg}");
        assert!(msg.contains("clock skew"), "{msg}");
        // Already-expired tokens report a negative remaining lifetime.
        assert!(msg.contains("remaining=-10s"), "{msg}");
    }

    #[test]
    fn validate_token_freshness_rejects_token_within_safety_margin() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(60);
        let err = validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect_err("must reject");
        assert!(format!("{err}").contains("too close to now"));
    }

    #[test]
    fn validate_token_freshness_accepts_fresh_token() {
        let now = SystemTime::now();
        let expires_at = now + Duration::from_secs(3600);
        validate_token_freshness(now, expires_at, Duration::from_secs(5 * 60))
            .expect("must accept fresh token");
    }

    #[test]
    fn validate_token_freshness_rejects_at_exact_cutoff() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
        let margin = Duration::from_secs(60);
        let expires_at = now + margin;
        let err = validate_token_freshness(now, expires_at, margin)
            .expect_err("must reject at exact cutoff");
        let msg = format!("{err}");
        assert!(msg.contains("remaining=60s"), "{msg}");
        assert!(msg.contains("required margin=60s"), "{msg}");
    }

    #[test]
    fn validate_token_freshness_accepts_just_past_cutoff() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
        let margin = Duration::from_secs(60);
        let expires_at = now + margin + Duration::from_secs(1);
        validate_token_freshness(now, expires_at, margin)
            .expect("must accept one second past the cutoff");
    }

    #[test]
    fn max_connections_zero_is_clamped_to_one() {
        let opts = EntraAuthOptions::new().max_connections(0);
        assert_eq!(opts.max_connections_value(), 1);
    }

    #[test]
    fn max_connections_one_is_preserved() {
        let opts = EntraAuthOptions::new().max_connections(1);
        assert_eq!(opts.max_connections_value(), 1);
    }
}