use super::{CredentialContext, Credentials, IdentityProvider};
use anyhow::Result;
use async_trait::async_trait;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Heap-allocated, `Send` future that resolves to the outcome of one credential lookup.
type CredentialFuture = Pin<Box<dyn Future<Output = Result<Credentials>> + Send>>;

/// Type-erased credential callback.
///
/// It borrows the [`CredentialContext`] of a request and returns a boxed future yielding
/// the [`Credentials`] (or an error). The `Send + Sync` bounds allow a single callback to
/// be shared between threads.
type AsyncCredentialCallback = dyn Fn(&CredentialContext) -> CredentialFuture + Send + Sync;
/// Identity provider whose credentials come from an application-supplied callback.
///
/// The callback lives behind an [`Arc`], so cloning the provider only bumps a reference
/// count: every clone invokes the very same callback.
#[derive(Clone)]
pub struct ApplicationIdentityProvider {
    /// Shared, type-erased callback invoked for each credential request.
    callback: Arc<AsyncCredentialCallback>,
}
impl ApplicationIdentityProvider {
pub fn new<F, Fut>(callback: F) -> Self
where
F: Fn(&CredentialContext) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Credentials>> + Send + 'static,
{
let cb: Arc<AsyncCredentialCallback> =
Arc::new(move |ctx| Box::pin(callback(ctx)) as Pin<Box<_>>);
Self { callback: cb }
}
pub fn new_sync<F>(callback: F) -> Self
where
F: Fn(&CredentialContext) -> Result<Credentials> + Send + Sync + 'static,
{
let cb: Arc<AsyncCredentialCallback> = Arc::new(move |ctx| {
let result = callback(ctx);
Box::pin(async move { result })
});
Self { callback: cb }
}
}
/// Opaque `Debug` output: only the type name is printed, followed by the `..` marker for
/// omitted fields, because the stored callback is a closure with nothing printable.
impl std::fmt::Debug for ApplicationIdentityProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut opaque = f.debug_struct("ApplicationIdentityProvider");
        opaque.finish_non_exhaustive()
    }
}
#[async_trait]
impl IdentityProvider for ApplicationIdentityProvider {
async fn get_credentials(&self, context: &CredentialContext) -> Result<Credentials> {
(self.callback)(context).await
}
fn clone_box(&self) -> Box<dyn IdentityProvider> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ApplicationIdentityProvider`: both constructors, context
    //! pass-through, clone sharing, error propagation and the opaque `Debug` output.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A synchronous callback's username/password pair is returned as-is.
    #[tokio::test]
    async fn sync_closure_returns_username_password() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::UsernamePassword {
                username: "alice".into(),
                password: "s3cret".into(),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(
            creds,
            Credentials::UsernamePassword {
                username: "alice".into(),
                password: "s3cret".into(),
            }
        );
    }

    /// An asynchronous callback's token is returned once its future resolves.
    #[tokio::test]
    async fn async_closure_returns_token() {
        let provider = ApplicationIdentityProvider::new(|_ctx| async {
            Ok(Credentials::Token {
                username: "svc".into(),
                token: "abc.def.ghi".into(),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(
            creds,
            Credentials::Token {
                username: "svc".into(),
                token: "abc.def.ghi".into(),
            }
        );
    }

    /// Properties set on the context are visible to the callback.
    #[tokio::test]
    async fn callback_observes_credential_context() {
        let provider = ApplicationIdentityProvider::new_sync(|ctx| {
            let host = ctx.get("hostname").unwrap_or("none").to_string();
            let port = ctx.get("port").unwrap_or("0").to_string();
            Ok(Credentials::UsernamePassword {
                username: format!("user@{host}:{port}"),
                password: "pw".into(),
            })
        });
        let ctx = CredentialContext::new()
            .with_property("hostname", "db.example.com")
            .with_property("port", "5432");

        let creds = provider.get_credentials(&ctx).await.unwrap();

        assert_eq!(
            creds,
            Credentials::UsernamePassword {
                username: "user@db.example.com:5432".into(),
                password: "pw".into(),
            }
        );
    }

    /// `clone_box` yields a provider backed by the same callback: one call through each
    /// handle bumps the shared counter twice.
    #[tokio::test]
    async fn clone_box_shares_underlying_callback() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_cb = calls.clone();
        let provider = ApplicationIdentityProvider::new_sync(move |_ctx| {
            calls_for_cb.fetch_add(1, Ordering::SeqCst);
            Ok(Credentials::UsernamePassword {
                username: "u".into(),
                password: "p".into(),
            })
        });
        let cloned = provider.clone_box();

        provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();
        cloned
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// An error returned by the callback reaches the caller with its message intact.
    #[tokio::test]
    async fn callback_error_is_propagated() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Err(anyhow::anyhow!("auth backend unavailable"))
        });

        let err = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap_err();

        assert!(err.to_string().contains("auth backend unavailable"));
    }

    /// Certificate credentials, including the optional username, pass through unchanged.
    #[tokio::test]
    async fn sync_closure_returns_certificate() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::Certificate {
                cert_pem: "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----".into(),
                key_pem: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----".into(),
                username: Some("cert-user".into()),
            })
        });

        let creds = provider
            .get_credentials(&CredentialContext::new())
            .await
            .unwrap();

        assert!(creds.is_certificate());
        assert_eq!(
            creds,
            Credentials::Certificate {
                cert_pem: "-----BEGIN CERTIFICATE-----\nMIIB...\n-----END CERTIFICATE-----".into(),
                key_pem: "-----BEGIN PRIVATE KEY-----\nMIIE...\n-----END PRIVATE KEY-----".into(),
                username: Some("cert-user".into()),
            }
        );
    }

    /// `Debug` output names the type but reveals nothing the callback could produce.
    ///
    /// Nothing is awaited here, so this is a plain synchronous test and needs no runtime.
    #[test]
    fn debug_impl_does_not_leak_callback_state() {
        let provider = ApplicationIdentityProvider::new_sync(|_ctx| {
            Ok(Credentials::UsernamePassword {
                username: "should-not-appear".into(),
                password: "super-secret".into(),
            })
        });

        let formatted = format!("{provider:?}");

        assert!(formatted.contains("ApplicationIdentityProvider"));
        assert!(!formatted.contains("super-secret"));
        assert!(!formatted.contains("should-not-appear"));
    }
}