use std::{
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
sync::{Arc, atomic::AtomicBool},
};
use crate::Error;
#[cfg(feature = "auth-poll")]
pub mod poll;
#[cfg(feature = "auth-username-password")]
pub mod username_password;
/// Authentication strategy attached to an API.
///
/// The set of available variants depends on the enabled cargo features
/// (`auth-poll`, `auth-username-password`); [`Auth::None`] always exists.
#[derive(Clone, Debug)]
pub enum Auth {
    /// Authentication backed by a [`poll::PollAuth`] (feature `auth-poll`).
    #[cfg(feature = "auth-poll")]
    Poll(poll::PollAuth),
    /// Authentication backed by a [`username_password::UsernamePasswordAuth`]
    /// (feature `auth-username-password`).
    #[cfg(feature = "auth-username-password")]
    UsernamePassword(username_password::UsernamePasswordAuth),
    /// No authentication is configured.
    None,
}
/// Converts an optional auth value: `Some(x)` becomes `x.into()` and `None`
/// becomes [`Auth::None`].
impl<T> From<Option<T>> for Auth
where
    T: Into<Self>,
{
    fn from(value: Option<T>) -> Self {
        match value {
            Some(inner) => inner.into(),
            None => Self::None,
        }
    }
}
/// Variant accessors shared by [`Auth`] and types wrapping it.
///
/// Each `as_*` method borrows the matching variant's payload and each `into_*`
/// method consumes `self` to return it; both yield `None` for other variants.
/// Methods only exist when their cargo feature is enabled.
pub trait AuthExt {
    /// Borrows the poll payload, if present.
    #[cfg(feature = "auth-poll")]
    fn as_poll(&self) -> Option<&poll::PollAuth>;

    /// Consumes `self`, returning the poll payload if present.
    #[cfg(feature = "auth-poll")]
    fn into_poll(self) -> Option<poll::PollAuth>;

    /// Borrows the username/password payload, if present.
    #[cfg(feature = "auth-username-password")]
    fn as_username_password(&self) -> Option<&username_password::UsernamePasswordAuth>;

    /// Consumes `self`, returning the username/password payload if present.
    #[cfg(feature = "auth-username-password")]
    fn into_username_password(self) -> Option<username_password::UsernamePasswordAuth>;
}
/// Inherent forwarding methods so callers need not import [`AuthExt`].
impl Auth {
    /// Borrows the [`poll::PollAuth`] if this is [`Auth::Poll`].
    #[cfg(feature = "auth-poll")]
    #[must_use]
    pub fn as_poll(&self) -> Option<&poll::PollAuth> {
        AuthExt::as_poll(self)
    }

    /// Consumes `self`, returning the [`poll::PollAuth`] if this is [`Auth::Poll`].
    #[cfg(feature = "auth-poll")]
    #[must_use]
    pub fn into_poll(self) -> Option<poll::PollAuth> {
        AuthExt::into_poll(self)
    }

    /// Borrows the username/password payload if this is [`Auth::UsernamePassword`].
    #[cfg(feature = "auth-username-password")]
    #[must_use]
    pub fn as_username_password(&self) -> Option<&username_password::UsernamePasswordAuth> {
        AuthExt::as_username_password(self)
    }

    /// Consumes `self`, returning the username/password payload if this is
    /// [`Auth::UsernamePassword`].
    #[cfg(feature = "auth-username-password")]
    #[must_use]
    pub fn into_username_password(self) -> Option<username_password::UsernamePasswordAuth> {
        AuthExt::into_username_password(self)
    }
}
/// Variant extraction for [`Auth`]: each accessor yields the payload only when
/// `self` is the matching variant.
impl AuthExt for Auth {
    #[cfg(feature = "auth-poll")]
    fn as_poll(&self) -> Option<&poll::PollAuth> {
        if let Self::Poll(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    #[cfg(feature = "auth-poll")]
    fn into_poll(self) -> Option<poll::PollAuth> {
        if let Self::Poll(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    #[cfg(feature = "auth-username-password")]
    fn as_username_password(&self) -> Option<&username_password::UsernamePasswordAuth> {
        if let Self::UsernamePassword(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    #[cfg(feature = "auth-username-password")]
    fn into_username_password(self) -> Option<username_password::UsernamePasswordAuth> {
        if let Self::UsernamePassword(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
/// Shared async callback that re-checks stored credentials.
///
/// [`ApiAuth::validate_credentials`] stores the `Ok` value it resolves to as the
/// new logged-in state; an `Err` is propagated to the caller.
type ValidateCredentialsFn = Arc<
    dyn Fn() -> Pin<
            Box<dyn Future<Output = Result<bool, Box<dyn std::error::Error + Send>>> + Send>,
        > + Send
        + Sync,
>;

/// Builder for [`ApiAuth`].
///
/// An auth strategy must be chosen via [`Self::with_auth`], [`Self::auth`] or
/// [`Self::without_auth`] before calling [`Self::build`].
#[derive(Clone)]
pub struct ApiAuthBuilder {
    auth: Option<Auth>,
    logged_in: Option<bool>,
    validate_credentials: Option<ValidateCredentialsFn>,
}

impl std::fmt::Debug for ApiAuthBuilder {
    // The validator closure is not `Debug`, so it is omitted via
    // `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiAuthBuilder")
            .field("auth", &self.auth)
            .field("logged_in", &self.logged_in)
            .finish_non_exhaustive()
    }
}

impl Default for ApiAuthBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl ApiAuthBuilder {
    /// Creates an empty builder: no auth strategy, logged-in state unset
    /// (treated as `false` by [`Self::build`]) and no credential validator.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            auth: None,
            logged_in: None,
            validate_credentials: None,
        }
    }

    /// Selects [`Auth::None`] as the auth strategy.
    #[must_use]
    // Overwriting `self.auth` drops the previous value; whether that is permitted in
    // a `const fn` depends on which `Auth` variants are compiled in, so this stays
    // non-const regardless of what the lint suggests for a given feature set.
    #[allow(clippy::missing_const_for_fn)]
    pub fn without_auth(mut self) -> Self {
        self.auth = Some(Auth::None);
        self
    }

    /// Sets the auth strategy (by value, for chaining).
    #[must_use]
    pub fn with_auth(mut self, auth: impl Into<Auth>) -> Self {
        self.auth = Some(auth.into());
        self
    }

    /// Sets the auth strategy in place.
    pub fn auth(&mut self, auth: impl Into<Auth>) -> &mut Self {
        self.auth = Some(auth.into());
        self
    }

    /// Sets the initial logged-in state of the built [`ApiAuth`].
    #[must_use]
    pub const fn with_logged_in(mut self, logged_in: bool) -> Self {
        self.logged_in = Some(logged_in);
        self
    }

    /// Registers the callback used by [`ApiAuth::validate_credentials`].
    #[must_use]
    pub fn with_validate_credentials<
        Fut: Future<Output = Result<bool, Box<dyn std::error::Error + Send>>> + Send + 'static,
        Func: Fn() -> Fut + Send + Sync + 'static,
    >(
        mut self,
        validate_credentials: Func,
    ) -> Self {
        // Box each produced future so every callback fits the single erased type.
        self.validate_credentials = Some(Arc::new(move || Box::pin(validate_credentials())));
        self
    }

    /// Finalizes the builder into an [`ApiAuth`].
    ///
    /// The logged-in flag defaults to `false` when [`Self::with_logged_in`] was
    /// not called.
    ///
    /// # Panics
    ///
    /// Panics if no auth strategy was configured (none of [`Self::with_auth`],
    /// [`Self::auth`] or [`Self::without_auth`] was called).
    #[must_use]
    pub fn build(self) -> ApiAuth {
        let auth = self.auth.expect(
            "ApiAuthBuilder::build called without an auth strategy; \
             call with_auth, auth or without_auth first",
        );
        ApiAuth {
            logged_in: Arc::new(AtomicBool::new(self.logged_in.unwrap_or(false))),
            auth,
            validate_credentials: self.validate_credentials,
        }
    }
}

/// Authentication state for an API: the configured [`Auth`] strategy, a
/// logged-in flag and an optional credential validator.
///
/// The logged-in flag lives behind an [`Arc`], so clones of an `ApiAuth` share
/// (and observe updates to) the same flag.
#[derive(Clone)]
pub struct ApiAuth {
    logged_in: Arc<AtomicBool>,
    auth: Auth,
    validate_credentials: Option<ValidateCredentialsFn>,
}
impl std::fmt::Debug for ApiAuth {
    // Shows the logged-in flag and auth strategy; the validator closure is not
    // `Debug`, so the output is marked non-exhaustive instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ApiAuth");
        out.field("logged_in", &self.logged_in);
        out.field("auth", &self.auth);
        out.finish_non_exhaustive()
    }
}
impl ApiAuth {
    /// Returns a fresh [`ApiAuthBuilder`].
    #[must_use]
    pub const fn builder() -> ApiAuthBuilder {
        ApiAuthBuilder::new()
    }

    /// Returns the current logged-in state.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the public signature.
    // Nothing is awaited here; the `async` signature appears to be kept for API
    // stability, hence the `unused_async` allow.
    #[allow(clippy::unused_async)]
    pub async fn is_logged_in(&self) -> Result<bool, Error> {
        Ok(self.logged_in.load(std::sync::atomic::Ordering::SeqCst))
    }

    /// Overwrites the logged-in state (shared with every clone of this value).
    pub fn set_logged_in(&self, logged_in: bool) {
        self.logged_in
            .store(logged_in, std::sync::atomic::Ordering::SeqCst);
    }

    /// Runs the configured credential validator and records its verdict.
    ///
    /// Returns the validator's result (also stored as the logged-in state). When
    /// no validator was configured, returns `Ok(false)` and leaves the logged-in
    /// state untouched.
    ///
    /// # Errors
    ///
    /// Propagates the validator's error, after marking the API as logged out.
    pub async fn validate_credentials(&self) -> Result<bool, Box<dyn std::error::Error + Send>> {
        let Some(validator) = &self.validate_credentials else {
            return Ok(false);
        };
        match validator().await {
            Ok(valid) => {
                self.set_logged_in(valid);
                // Previously this fell through to `Ok(false)`, discarding `valid`.
                Ok(valid)
            }
            Err(e) => {
                self.set_logged_in(false);
                Err(e)
            }
        }
    }

    /// Performs a login attempt with `func`, which receives the configured
    /// [`Auth`]; the resulting flag is stored as the logged-in state and returned.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by `func`'s future; the logged-in state is
    /// left unchanged in that case.
    pub async fn attempt_login<
        Fut: Future<Output = Result<bool, Box<dyn std::error::Error + Send>>> + Send + 'static,
        Func: Fn(&Auth) -> Fut + Send + Sync + 'static,
    >(
        &self,
        func: Func,
    ) -> Result<bool, Box<dyn std::error::Error + Send>> {
        let logged_in = func(&self.auth).await?;
        self.set_logged_in(logged_in);
        Ok(logged_in)
    }

    /// Borrows the poll payload of the wrapped [`Auth`], if any.
    #[cfg(feature = "auth-poll")]
    #[must_use]
    pub fn as_poll(&self) -> Option<&poll::PollAuth> {
        <Self as AuthExt>::as_poll(self)
    }

    /// Consumes `self`, returning the poll payload of the wrapped [`Auth`], if any.
    #[cfg(feature = "auth-poll")]
    #[must_use]
    pub fn into_poll(self) -> Option<poll::PollAuth> {
        <Self as AuthExt>::into_poll(self)
    }

    /// Borrows the username/password payload of the wrapped [`Auth`], if any.
    #[cfg(feature = "auth-username-password")]
    #[must_use]
    pub fn as_username_password(&self) -> Option<&username_password::UsernamePasswordAuth> {
        <Self as AuthExt>::as_username_password(self)
    }

    /// Consumes `self`, returning the username/password payload of the wrapped
    /// [`Auth`], if any.
    #[cfg(feature = "auth-username-password")]
    #[must_use]
    pub fn into_username_password(self) -> Option<username_password::UsernamePasswordAuth> {
        <Self as AuthExt>::into_username_password(self)
    }
}
/// Delegates every accessor to the wrapped [`Auth`].
impl AuthExt for ApiAuth {
    #[cfg(feature = "auth-poll")]
    fn as_poll(&self) -> Option<&poll::PollAuth> {
        AuthExt::as_poll(&self.auth)
    }

    #[cfg(feature = "auth-poll")]
    fn into_poll(self) -> Option<poll::PollAuth> {
        AuthExt::into_poll(self.auth)
    }

    #[cfg(feature = "auth-username-password")]
    fn as_username_password(&self) -> Option<&username_password::UsernamePasswordAuth> {
        AuthExt::as_username_password(&self.auth)
    }

    #[cfg(feature = "auth-username-password")]
    fn into_username_password(self) -> Option<username_password::UsernamePasswordAuth> {
        AuthExt::into_username_password(self.auth)
    }
}
/// Lets an [`ApiAuth`] be used wherever a `&Auth` is expected.
impl Deref for ApiAuth {
    type Target = Auth;

    fn deref(&self) -> &Self::Target {
        let Self { auth, .. } = self;
        auth
    }
}
/// Grants mutable access to (and allows replacing) the wrapped [`Auth`].
impl DerefMut for ApiAuth {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { auth, .. } = self;
        auth
    }
}
#[cfg(test)]
mod test {
// Unit tests for `Auth`, `ApiAuthBuilder` and `ApiAuth`. Feature-gated tests
// only run when the corresponding `auth-*` feature is enabled.
use super::{ApiAuth, Auth};
// --- Builder and logged-in state ---
#[test_log::test(switchy_async::test)]
async fn api_auth_builder_builds_with_no_auth() {
let auth = ApiAuth::builder().without_auth().build();
assert!(matches!(*auth, Auth::None));
}
#[test_log::test(switchy_async::test)]
async fn api_auth_builder_sets_logged_in_state() {
let auth = ApiAuth::builder()
.without_auth()
.with_logged_in(true)
.build();
let is_logged_in = auth.is_logged_in().await.unwrap();
assert!(is_logged_in);
}
#[test_log::test(switchy_async::test)]
async fn api_auth_set_logged_in_updates_state() {
let auth = ApiAuth::builder()
.without_auth()
.with_logged_in(false)
.build();
assert!(!auth.is_logged_in().await.unwrap());
auth.set_logged_in(true);
assert!(auth.is_logged_in().await.unwrap());
auth.set_logged_in(false);
assert!(!auth.is_logged_in().await.unwrap());
}
// --- validate_credentials ---
#[test_log::test(switchy_async::test)]
async fn api_auth_validate_credentials_returns_false_when_no_validator() {
let auth = ApiAuth::builder().without_auth().build();
let result = auth.validate_credentials().await.unwrap();
assert!(!result);
}
#[test_log::test(switchy_async::test)]
async fn api_auth_validate_credentials_calls_validator_and_updates_state() {
let auth = ApiAuth::builder()
.without_auth()
.with_validate_credentials(|| async { Ok(true) })
.build();
assert!(!auth.is_logged_in().await.unwrap());
auth.validate_credentials().await.unwrap();
assert!(auth.is_logged_in().await.unwrap());
}
#[test_log::test(switchy_async::test)]
async fn api_auth_validate_credentials_sets_logged_out_on_error() {
let auth = ApiAuth::builder()
.without_auth()
.with_logged_in(true)
.with_validate_credentials(|| async {
Err(Box::new(std::io::Error::other("validation failed"))
as Box<dyn std::error::Error + Send>)
})
.build();
assert!(auth.is_logged_in().await.unwrap());
let result = auth.validate_credentials().await;
assert!(result.is_err());
assert!(!auth.is_logged_in().await.unwrap());
}
// --- attempt_login ---
#[test_log::test(switchy_async::test)]
async fn api_auth_attempt_login_updates_logged_in_state_on_success() {
let auth = ApiAuth::builder().without_auth().build();
let result = auth.attempt_login(|_| async { Ok(true) }).await.unwrap();
assert!(result);
assert!(auth.is_logged_in().await.unwrap());
}
#[test_log::test(switchy_async::test)]
async fn api_auth_attempt_login_sets_logged_out_on_failure() {
let auth = ApiAuth::builder()
.without_auth()
.with_logged_in(true)
.build();
let result = auth.attempt_login(|_| async { Ok(false) }).await.unwrap();
assert!(!result);
assert!(!auth.is_logged_in().await.unwrap());
}
#[test_log::test(switchy_async::test)]
async fn api_auth_attempt_login_propagates_error() {
let auth = ApiAuth::builder().without_auth().build();
let result = auth
.attempt_login(|_| async {
Err(Box::new(std::io::Error::other("login failed"))
as Box<dyn std::error::Error + Send>)
})
.await;
assert!(result.is_err());
}
// --- Conversions and builder mutation ---
#[test_log::test]
fn auth_from_option_none_converts_to_auth_none() {
let auth: Auth = None::<Auth>.into();
assert!(matches!(auth, Auth::None));
}
#[test_log::test]
fn api_auth_builder_auth_mutable_method_sets_auth() {
let mut builder = super::ApiAuthBuilder::new();
builder.auth(Auth::None);
let api_auth = builder.build();
assert!(matches!(*api_auth, Auth::None));
}
// --- Variant accessors on `Auth` (feature-gated) ---
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn auth_as_poll_returns_some_for_poll_variant() {
use super::poll::PollAuth;
let poll = PollAuth::new();
let auth = Auth::Poll(poll);
assert!(auth.as_poll().is_some());
}
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn auth_as_poll_returns_none_for_other_variants() {
let auth = Auth::None;
assert!(auth.as_poll().is_none());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn auth_as_username_password_returns_some_for_username_password_variant() {
use super::username_password::UsernamePasswordAuth;
let up_auth = UsernamePasswordAuth::builder()
.with_handler(|_u, _p| async { Ok(true) })
.build()
.unwrap();
let auth = Auth::UsernamePassword(up_auth);
assert!(auth.as_username_password().is_some());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn auth_as_username_password_returns_none_for_other_variants() {
let auth = Auth::None;
assert!(auth.as_username_password().is_none());
}
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn auth_into_poll_returns_some_for_poll_variant() {
use super::poll::PollAuth;
let poll = PollAuth::new();
let auth = Auth::Poll(poll);
assert!(auth.into_poll().is_some());
}
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn auth_into_poll_returns_none_for_other_variants() {
let auth = Auth::None;
assert!(auth.into_poll().is_none());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn auth_into_username_password_returns_some_for_username_password_variant() {
use super::username_password::UsernamePasswordAuth;
let up_auth = UsernamePasswordAuth::builder()
.with_handler(|_u, _p| async { Ok(true) })
.build()
.unwrap();
let auth = Auth::UsernamePassword(up_auth);
assert!(auth.into_username_password().is_some());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn auth_into_username_password_returns_none_for_other_variants() {
let auth = Auth::None;
assert!(auth.into_username_password().is_none());
}
// --- Variant accessors on `ApiAuth` (feature-gated) ---
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn api_auth_into_poll_returns_some_for_poll_variant() {
use super::poll::PollAuth;
let poll = PollAuth::new();
let api_auth = ApiAuth::builder().with_auth(poll).build();
assert!(api_auth.into_poll().is_some());
}
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn api_auth_into_poll_returns_none_for_other_variants() {
let api_auth = ApiAuth::builder().without_auth().build();
assert!(api_auth.into_poll().is_none());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn api_auth_into_username_password_returns_some_for_username_password_variant() {
use super::username_password::UsernamePasswordAuth;
let up_auth = UsernamePasswordAuth::builder()
.with_handler(|_u, _p| async { Ok(true) })
.build()
.unwrap();
let api_auth = ApiAuth::builder().with_auth(up_auth).build();
assert!(api_auth.into_username_password().is_some());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn api_auth_into_username_password_returns_none_for_other_variants() {
let api_auth = ApiAuth::builder().without_auth().build();
assert!(api_auth.into_username_password().is_none());
}
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn api_auth_as_poll_returns_some_for_poll_variant() {
use super::poll::PollAuth;
let poll = PollAuth::new();
let api_auth = ApiAuth::builder().with_auth(poll).build();
assert!(api_auth.as_poll().is_some());
}
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn api_auth_as_poll_returns_none_for_other_variants() {
let api_auth = ApiAuth::builder().without_auth().build();
assert!(api_auth.as_poll().is_none());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn api_auth_as_username_password_returns_some_for_username_password_variant() {
use super::username_password::UsernamePasswordAuth;
let up_auth = UsernamePasswordAuth::builder()
.with_handler(|_u, _p| async { Ok(true) })
.build()
.unwrap();
let api_auth = ApiAuth::builder().with_auth(up_auth).build();
assert!(api_auth.as_username_password().is_some());
}
#[cfg(feature = "auth-username-password")]
#[test_log::test]
fn api_auth_as_username_password_returns_none_for_other_variants() {
let api_auth = ApiAuth::builder().without_auth().build();
assert!(api_auth.as_username_password().is_none());
}
// `Some(Auth)` goes through the blanket `From<Option<T>>` with `T = Auth`.
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn auth_from_option_some_converts_to_wrapped_auth() {
use super::poll::PollAuth;
let poll = PollAuth::new();
let auth: Auth = Some(Auth::Poll(poll)).into();
assert!(matches!(auth, Auth::Poll(_)));
}
// --- Deref / DerefMut ---
#[test_log::test]
fn api_auth_deref_returns_inner_auth() {
let api_auth = ApiAuth::builder().without_auth().build();
let auth_ref: &Auth = &api_auth;
assert!(matches!(auth_ref, Auth::None));
}
#[cfg(feature = "auth-poll")]
#[test_log::test]
fn api_auth_deref_mut_allows_modifying_inner_auth() {
use super::poll::PollAuth;
let mut api_auth = ApiAuth::builder().without_auth().build();
assert!(matches!(*api_auth, Auth::None));
*api_auth = Auth::Poll(PollAuth::new());
assert!(matches!(*api_auth, Auth::Poll(_)));
}
// A validator returning `Ok(false)` must log the API out.
#[test_log::test(switchy_async::test)]
async fn api_auth_validate_credentials_sets_logged_in_to_false_when_validator_returns_false() {
let auth = ApiAuth::builder()
.without_auth()
.with_logged_in(true)
.with_validate_credentials(|| async { Ok(false) })
.build();
assert!(auth.is_logged_in().await.unwrap());
auth.validate_credentials().await.unwrap();
assert!(!auth.is_logged_in().await.unwrap());
}
}