use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::Error;
pub type CredentialsFuture<'a> = Pin<Box<dyn Future<Output = Result<String, Error>> + Send + 'a>>;
pub trait CredentialsProvider: Send + Sync {
fn get_token(&self) -> CredentialsFuture<'_>;
fn refresh_hint(&self) -> Option<std::time::Duration> {
None
}
fn supports_refresh(&self) -> bool {
false
}
}
/// Lets a reference-counted provider (including `Arc<dyn CredentialsProvider>`)
/// be used as a provider itself.
///
/// Every method is forwarded unchanged to the shared inner value.
impl<T> CredentialsProvider for Arc<T>
where
    T: CredentialsProvider + ?Sized,
{
    fn get_token(&self) -> CredentialsFuture<'_> {
        // `&Arc<T>` deref-coerces to `&T`, so this dispatches to the inner
        // provider rather than recursing into this impl.
        T::get_token(self)
    }

    fn refresh_hint(&self) -> Option<std::time::Duration> {
        T::refresh_hint(self)
    }

    fn supports_refresh(&self) -> bool {
        T::supports_refresh(self)
    }
}
/// Lets a boxed provider (including `Box<dyn CredentialsProvider>`) be used
/// as a provider itself.
///
/// Every method is forwarded unchanged to the boxed value.
impl<T> CredentialsProvider for Box<T>
where
    T: CredentialsProvider + ?Sized,
{
    fn get_token(&self) -> CredentialsFuture<'_> {
        // `&Box<T>` deref-coerces to `&T`, so this dispatches to the inner
        // provider rather than recursing into this impl.
        T::get_token(self)
    }

    fn refresh_hint(&self) -> Option<std::time::Duration> {
        T::refresh_hint(self)
    }

    fn supports_refresh(&self) -> bool {
        T::supports_refresh(self)
    }
}
/// A [`CredentialsProvider`] that hands out one fixed token on every request.
///
/// The token is stored as an `Arc<str>`, so cloning the provider only bumps a
/// reference count instead of copying the string.
///
/// `Debug` is implemented by hand so the token (a credential) never ends up in
/// logs or panic messages; it renders as
/// `StaticTokenProvider { token: "<redacted>" }`.
#[derive(Clone)]
pub struct StaticTokenProvider {
    token: Arc<str>,
}

impl std::fmt::Debug for StaticTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately do not print `self.token`: it is a secret.
        f.debug_struct("StaticTokenProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}

impl StaticTokenProvider {
    /// Creates a provider that returns `token` from every
    /// [`CredentialsProvider::get_token`] call.
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: Arc::from(token.into()),
        }
    }
}
impl CredentialsProvider for StaticTokenProvider {
    /// Yields an owned copy of the configured token; this future never fails.
    fn get_token(&self) -> CredentialsFuture<'_> {
        // `CredentialsFuture<'_>` may borrow `self`, so the future can read the
        // token directly instead of cloning the `Arc` (an atomic inc/dec per
        // call). `String::from(&str)` copies directly, avoiding the `Display`
        // formatting path that `to_string()` on `Arc<str>` goes through.
        Box::pin(async move { Ok(String::from(&*self.token)) })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Refresh hint reported by [`CustomProvider`].
    const CUSTOM_REFRESH_HINT: Duration = Duration::from_secs(300);

    #[tokio::test]
    async fn test_static_token_provider() {
        let provider = StaticTokenProvider::new("test_token");
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "test_token");
    }

    #[tokio::test]
    async fn test_static_token_provider_accepts_owned_string() {
        let provider = StaticTokenProvider::new(String::from("owned_token"));
        assert_eq!(provider.get_token().await.unwrap(), "owned_token");
    }

    #[tokio::test]
    async fn test_static_token_provider_multiple_calls() {
        let provider = StaticTokenProvider::new("consistent");
        let token1 = provider.get_token().await.unwrap();
        let token2 = provider.get_token().await.unwrap();
        assert_eq!(token1, token2);
    }

    #[test]
    fn test_static_token_provider_defaults() {
        let provider = StaticTokenProvider::new("token");
        assert!(provider.refresh_hint().is_none());
        assert!(!provider.supports_refresh());
    }

    #[tokio::test]
    async fn test_arc_provider() {
        let provider: Arc<dyn CredentialsProvider> =
            Arc::new(StaticTokenProvider::new("arc_token"));
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "arc_token");
    }

    #[tokio::test]
    async fn test_box_provider() {
        let provider: Box<dyn CredentialsProvider> =
            Box::new(StaticTokenProvider::new("box_token"));
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "box_token");
    }

    /// Test provider that opts in to refresh support and whose tokens encode
    /// how many times `get_token` has been called.
    struct CustomProvider {
        counter: AtomicU32,
    }

    impl CustomProvider {
        fn new() -> Self {
            Self {
                counter: AtomicU32::new(0),
            }
        }
    }

    impl CredentialsProvider for CustomProvider {
        fn get_token(&self) -> CredentialsFuture<'_> {
            // The counter is bumped when the future is created, not when it is polled.
            let count = self.counter.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Ok(format!("token_{}", count)) })
        }

        fn supports_refresh(&self) -> bool {
            true
        }

        fn refresh_hint(&self) -> Option<Duration> {
            Some(CUSTOM_REFRESH_HINT)
        }
    }

    #[tokio::test]
    async fn test_custom_provider() {
        let provider = CustomProvider::new();
        assert!(provider.supports_refresh());
        assert_eq!(provider.refresh_hint(), Some(CUSTOM_REFRESH_HINT));
        let token1 = provider.get_token().await.unwrap();
        let token2 = provider.get_token().await.unwrap();
        assert_eq!(token1, "token_0");
        assert_eq!(token2, "token_1");
    }

    #[tokio::test]
    async fn test_arc_provider_delegations() {
        let provider: Arc<dyn CredentialsProvider> = Arc::new(CustomProvider::new());
        assert!(provider.supports_refresh());
        assert_eq!(provider.refresh_hint(), Some(CUSTOM_REFRESH_HINT));
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "token_0");
    }

    #[tokio::test]
    async fn test_arc_clones_share_provider_state() {
        // Both handles must delegate to the same inner provider instance.
        let provider: Arc<dyn CredentialsProvider> = Arc::new(CustomProvider::new());
        let shared = Arc::clone(&provider);
        assert_eq!(provider.get_token().await.unwrap(), "token_0");
        assert_eq!(shared.get_token().await.unwrap(), "token_1");
    }

    #[tokio::test]
    async fn test_box_provider_delegations() {
        let provider: Box<dyn CredentialsProvider> = Box::new(CustomProvider::new());
        assert!(provider.supports_refresh());
        assert_eq!(provider.refresh_hint(), Some(CUSTOM_REFRESH_HINT));
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "token_0");
    }

    #[test]
    fn test_static_token_provider_debug() {
        let provider = StaticTokenProvider::new("test");
        let debug = format!("{:?}", provider);
        assert!(debug.contains("StaticTokenProvider"));
    }

    #[tokio::test]
    async fn test_static_token_provider_clone() {
        let provider = StaticTokenProvider::new("clone_test");
        let cloned = provider.clone();
        // Compare the actual tokens: Debug output is not a reliable proxy for
        // equality, since it may legitimately hide the secret.
        assert_eq!(cloned.get_token().await.unwrap(), "clone_test");
        assert_eq!(
            provider.get_token().await.unwrap(),
            cloned.get_token().await.unwrap()
        );
    }
}