use crate::error::{AuthError, Result};
use async_trait::async_trait;
/// Authentication mechanisms a provider flow can offer.
///
/// The type is `Copy`, so it is passed by value throughout the flow traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
    /// OAuth-based authentication.
    OAuth,
    /// Authentication with an API key.
    ApiKey,
}
/// Status reported by an authentication flow after one of its steps.
///
/// NOTE(review): the payload descriptions below are inferred from the variant
/// and field names only; confirm them against the flow implementations.
#[derive(Debug, Clone)]
pub enum AuthProgress {
    /// The flow is waiting for input; the string is presumably a prompt.
    NeedInput(String),
    /// The flow is still running; the string is presumably a status message.
    InProgress(String),
    /// The flow finished.
    Complete,
    /// The flow failed; the string is presumably a description of the failure.
    Error(String),
    /// An OAuth flow was started.
    OAuthStarted {
        /// URL associated with the started OAuth flow (presumably the page
        /// the user must open to authorize; verify against callers).
        auth_url: String,
    },
}
/// A statically typed authentication flow for one provider.
///
/// Each implementation chooses its own [`State`](Self::State) type. A value of
/// that type is produced by [`start_auth`](Self::start_auth) and then handed
/// back to [`get_initial_progress`](Self::get_initial_progress) and
/// [`handle_input`](Self::handle_input).
#[async_trait]
pub trait AuthenticationFlow: Send + Sync {
    /// State carried between the steps of a single authentication attempt.
    type State: Send + Sync;

    /// Returns the authentication methods this flow offers.
    fn available_methods(&self) -> Vec<AuthMethod>;

    /// Begins an authentication attempt with `method` and returns its state.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when the attempt
    /// cannot be started.
    async fn start_auth(&self, method: AuthMethod) -> Result<Self::State>;

    /// Reports the first [`AuthProgress`] for an attempt described by `state`
    /// and `method`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn get_initial_progress(
        &self,
        state: &Self::State,
        method: AuthMethod,
    ) -> Result<AuthProgress>;

    /// Feeds `input` into the attempt, possibly mutating `state`, and reports
    /// the resulting [`AuthProgress`].
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn handle_input(&self, state: &mut Self::State, input: &str) -> Result<AuthProgress>;

    /// Reports whether the flow currently considers itself authenticated.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports while checking.
    async fn is_authenticated(&self) -> Result<bool>;

    /// Returns the name of the provider this flow authenticates against.
    fn provider_name(&self) -> String;
}
/// Type-erased counterpart of [`AuthenticationFlow`].
///
/// The methods mirror [`AuthenticationFlow`], but the per-attempt state is
/// passed as `Box<dyn Any + Send + Sync>` instead of an associated type, so
/// the trait carries no associated types of its own.
#[async_trait]
pub trait DynAuthenticationFlow: Send + Sync {
    /// Returns the authentication methods this flow offers.
    fn available_methods(&self) -> Vec<AuthMethod>;

    /// Begins an authentication attempt with `method` and returns its state
    /// as a type-erased box.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when the attempt
    /// cannot be started.
    async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>>;

    /// Reports the first [`AuthProgress`] for an attempt described by the
    /// type-erased `state` and `method`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn get_initial_progress(
        &self,
        state: &Box<dyn std::any::Any + Send + Sync>,
        method: AuthMethod,
    ) -> Result<AuthProgress>;

    /// Feeds `input` into the attempt, possibly mutating the type-erased
    /// `state`, and reports the resulting [`AuthProgress`].
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn handle_input(
        &self,
        state: &mut Box<dyn std::any::Any + Send + Sync>,
        input: &str,
    ) -> Result<AuthProgress>;

    /// Reports whether the flow currently considers itself authenticated.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports while checking.
    async fn is_authenticated(&self) -> Result<bool>;

    /// Returns the name of the provider this flow authenticates against.
    fn provider_name(&self) -> String;
}
/// Adapter that holds a concrete flow so it can be exposed through
/// [`DynAuthenticationFlow`].
///
/// The type itself carries no trait bound: the `impl` blocks that actually
/// need `T: AuthenticationFlow` declare it themselves, which keeps the bound
/// from leaking into every signature that merely names this type.
pub struct AuthFlowWrapper<T> {
    /// The wrapped flow that calls are forwarded to.
    inner: T,
}
impl<T: AuthenticationFlow> AuthFlowWrapper<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}
}
#[async_trait]
impl<T: AuthenticationFlow + 'static> DynAuthenticationFlow for AuthFlowWrapper<T>
where
T::State: 'static,
{
fn available_methods(&self) -> Vec<AuthMethod> {
self.inner.available_methods()
}
async fn start_auth(&self, method: AuthMethod) -> Result<Box<dyn std::any::Any + Send + Sync>> {
let state = self.inner.start_auth(method).await?;
Ok(Box::new(state))
}
async fn get_initial_progress(
&self,
state: &Box<dyn std::any::Any + Send + Sync>,
method: AuthMethod,
) -> Result<AuthProgress> {
let concrete_state = state
.downcast_ref::<T::State>()
.ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
self.inner
.get_initial_progress(concrete_state, method)
.await
}
async fn handle_input(
&self,
state: &mut Box<dyn std::any::Any + Send + Sync>,
input: &str,
) -> Result<AuthProgress> {
let concrete_state = state
.downcast_mut::<T::State>()
.ok_or_else(|| AuthError::InvalidResponse("Invalid state type".to_string()))?;
self.inner.handle_input(concrete_state, input).await
}
async fn is_authenticated(&self) -> Result<bool> {
self.inner.is_authenticated().await
}
fn provider_name(&self) -> String {
self.inner.provider_name()
}
}