use std::sync::Arc;
use tokio::sync::Mutex;
use crate::connection::auth::{AuthType, Credentials, ScopeCredentials, TokenAuth};
use crate::connection::client::DatabaseClient;
use crate::error::{Result, SurqlError};
/// A token together with the authentication type recorded for it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TokenState {
    /// The token returned by (or supplied to) the client.
    pub token: TokenAuth,
    /// The authentication type associated with `token`.
    pub auth_type: AuthType,
}
/// Caches the token produced by the most recent successful auth call.
///
/// Cloning is cheap: every clone holds the same `Arc`, so state written
/// through one handle is observed through all of them.
#[derive(Clone, Debug, Default)]
pub struct AuthManager {
    inner: Arc<AuthManagerInner>,
}
/// Shared interior of [`AuthManager`].
#[derive(Default, Debug)]
struct AuthManagerInner {
    /// Cached token state; `None` means nothing is cached. Guarded by a
    /// `tokio::sync::Mutex`, so it is locked with `.lock().await`.
    state: Mutex<Option<TokenState>>,
}
impl AuthManager {
pub fn new() -> Self {
Self::default()
}
pub async fn signin<C: Credentials + ?Sized>(
&self,
client: &DatabaseClient,
creds: &C,
) -> Result<TokenAuth> {
let auth_type = creds.auth_type();
let token = client.signin(creds).await?;
*self.inner.state.lock().await = Some(TokenState {
token: token.clone(),
auth_type,
});
Ok(token)
}
pub async fn signup(
&self,
client: &DatabaseClient,
creds: &ScopeCredentials,
) -> Result<TokenAuth> {
let token = client.signup(creds).await?;
*self.inner.state.lock().await = Some(TokenState {
token: token.clone(),
auth_type: AuthType::Scope,
});
Ok(token)
}
pub async fn authenticate(&self, client: &DatabaseClient, token: &str) -> Result<()> {
client.authenticate(token).await?;
let mut slot = self.inner.state.lock().await;
let preserved = slot.as_ref().map_or(AuthType::Scope, |s| s.auth_type);
*slot = Some(TokenState {
token: TokenAuth::new(token.to_owned()),
auth_type: preserved,
});
Ok(())
}
pub async fn invalidate(&self, client: &DatabaseClient) -> Result<()> {
client.invalidate().await?;
*self.inner.state.lock().await = None;
Ok(())
}
pub async fn refresh(&self, client: &DatabaseClient) -> Result<TokenAuth> {
let cached = self
.current_token()
.await
.ok_or_else(|| SurqlError::Context {
reason: "no cached token to refresh".into(),
})?;
client.authenticate(&cached.token).await?;
Ok(cached)
}
pub async fn current_token(&self) -> Option<TokenAuth> {
self.inner
.state
.lock()
.await
.as_ref()
.map(|s| s.token.clone())
}
pub async fn auth_type(&self) -> Option<AuthType> {
self.inner.state.lock().await.as_ref().map(|s| s.auth_type)
}
pub async fn is_authenticated(&self) -> bool {
self.inner.state.lock().await.is_some()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::auth::RootCredentials;
    use crate::connection::config::ConnectionConfig;

    /// Builds a client from the default configuration without connecting it.
    fn unconnected_client() -> DatabaseClient {
        DatabaseClient::new(ConnectionConfig::default()).expect("default config is valid")
    }

    #[tokio::test]
    async fn default_manager_is_empty() {
        let manager = AuthManager::new();
        assert!(!manager.is_authenticated().await);
        assert!(manager.current_token().await.is_none());
        assert!(manager.auth_type().await.is_none());
    }

    #[tokio::test]
    async fn refresh_without_token_errors() {
        let manager = AuthManager::new();
        let client = unconnected_client();
        let outcome = manager.refresh(&client).await;
        assert!(matches!(outcome, Err(SurqlError::Context { .. })));
    }

    #[tokio::test]
    async fn signin_against_disconnected_client_errors() {
        let manager = AuthManager::new();
        let client = unconnected_client();
        let creds = RootCredentials::new("root", "root");
        let outcome = manager.signin(&client, &creds).await;
        assert!(matches!(outcome, Err(SurqlError::Connection { .. })));
        // A failed sign-in must not leave anything cached.
        assert!(!manager.is_authenticated().await);
    }
}