use std::future::Future;
use crate::{AuthError, AuthStrategy, ServiceToken};
/// Adapts a plain callable into an [`AuthStrategy`].
///
/// The callable is stored as-is and invoked once per `get_token` call made on
/// `&AuthStrategyFn`; the future it returns is handed straight back to the
/// caller, so nothing is cached between calls. On non-wasm32 targets the
/// callable must be `Send + Sync` and its future `Send`; on wasm32 neither
/// bound is required.
#[derive(Clone)]
pub struct AuthStrategyFn<F> {
    /// Callable invoked on every `get_token` call to produce the token future.
    get_token: F,
}

// Hand-written rather than derived: a derive would require `F: Debug`, which
// closures (the usual payload) never implement. The callable is elided.
impl<F> std::fmt::Debug for AuthStrategyFn<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthStrategyFn").finish_non_exhaustive()
    }
}

impl<F> AuthStrategyFn<F> {
    /// Wraps `get_token` so it can be used as an [`AuthStrategy`] via `&AuthStrategyFn`.
    ///
    /// No bounds are checked here. They are enforced only when the strategy is
    /// used through its `AuthStrategy` implementation.
    #[must_use]
    pub const fn new(get_token: F) -> Self {
        Self { get_token }
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl<F, Fut> AuthStrategy for &AuthStrategyFn<F>
where
F: Fn() -> Fut + Send + Sync,
Fut: Future<Output = Result<ServiceToken, AuthError>> + Send,
{
fn get_token(self) -> impl Future<Output = Result<ServiceToken, AuthError>> + Send {
(self.get_token)()
}
}
/// wasm32 `AuthStrategy` for a borrowed [`AuthStrategyFn`].
///
/// Behaves like the native implementation, re-invoking the stored callable on
/// every call, but places no `Send`/`Sync` requirements on the callable or on
/// the future it returns.
#[cfg(target_arch = "wasm32")]
impl<F, Fut> AuthStrategy for &AuthStrategyFn<F>
where
    Fut: Future<Output = Result<ServiceToken, AuthError>>,
    F: Fn() -> Fut,
{
    fn get_token(self) -> impl Future<Output = Result<ServiceToken, AuthError>> {
        // Destructuring through the reference borrows the stored callable.
        let AuthStrategyFn { get_token: produce } = self;
        produce()
    }
}
#[cfg(test)]
// `unwrap`/`unwrap_err` are deliberate in tests: panicking is how a test reports failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use super::*;
    use crate::SecretToken;

    /// Builds a `ServiceToken` fixture whose secret is the given JWT string.
    fn dummy_service_token(jwt: &str) -> ServiceToken {
        let secret = SecretToken::new(String::from(jwt));
        ServiceToken::new(secret)
    }

    /// The wrapped closure must be re-run on every `get_token` call, not memoised.
    #[tokio::test]
    async fn closure_runs_on_each_get_token_call() {
        let counter = Arc::new(AtomicUsize::new(0));

        // The closure gets its own handle so `counter` stays available for the final check.
        let strategy = {
            let closure_counter = Arc::clone(&counter);
            AuthStrategyFn::new(move || {
                // Every produced future needs an owned handle of its own.
                let future_counter = Arc::clone(&closure_counter);
                async move {
                    let index = future_counter.fetch_add(1, Ordering::SeqCst);
                    Ok(dummy_service_token(&format!("jwt-{index}")))
                }
            })
        };

        let first = (&strategy).get_token().await.unwrap();
        assert_eq!(
            first.as_str(),
            "jwt-0",
            "first call should yield the first token the closure produced"
        );

        let second = (&strategy).get_token().await.unwrap();
        assert_eq!(
            second.as_str(),
            "jwt-1",
            "second call should re-invoke the closure"
        );

        assert_eq!(
            counter.load(Ordering::SeqCst),
            2,
            "closure should have fired exactly twice"
        );
    }

    /// Errors returned by the closure must reach the caller untouched.
    #[tokio::test]
    async fn closure_errors_propagate_unchanged() {
        let strategy = AuthStrategyFn::new(|| async { Err(AuthError::AccessDenied) });
        let outcome = (&strategy).get_token().await;
        let err = outcome.unwrap_err();
        assert!(
            matches!(err, AuthError::AccessDenied),
            "AccessDenied from the closure should surface verbatim, got: {err:?}"
        );
    }
}