pub mod providers;
use std::sync::Arc;
use oauth2::TokenResponse;
use providers::{Provider, UserInfo};
use torii_core::error::AuthError;
use torii_core::{Error, NewUser, Plugin, User, UserId, UserManager};
use torii_core::{
events::{Event, EventBus},
storage::OAuthStorage,
};
/// An authorization URL to send the user to, paired with the CSRF state for that request.
///
/// The plugin uses the CSRF state as the key under which it stores the PKCE verifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationUrl {
    /// The provider authorization URL.
    url: String,
    /// CSRF state value associated with `url`.
    csrf_state: String,
}

impl AuthorizationUrl {
    /// Creates an `AuthorizationUrl` from a URL and its CSRF state, copying both strings.
    pub fn new(url: &str, csrf_state: &str) -> Self {
        Self {
            url: url.to_string(),
            csrf_state: csrf_state.to_string(),
        }
    }

    /// Returns the authorization URL.
    pub fn url(&self) -> &str {
        &self.url
    }

    /// Returns the CSRF state for this authorization request.
    pub fn csrf_state(&self) -> &str {
        &self.csrf_state
    }
}
/// OAuth login plugin bound to a single [`Provider`].
///
/// Bundles the provider client, the user manager used to find and create local users,
/// the storage for PKCE verifiers and provider-account links, and an optional event bus.
pub struct OAuthPlugin<M: UserManager, S: OAuthStorage> {
    /// Provider configuration and client; its name is also the plugin name.
    provider: Provider,
    /// Looks up and creates local users.
    user_manager: Arc<M>,
    /// Persists PKCE verifiers and provider-account links.
    oauth_storage: Arc<S>,
    /// Destination for emitted events; `None` means events are not emitted.
    event_bus: Option<EventBus>,
}
impl<M: UserManager, S: OAuthStorage> Plugin for OAuthPlugin<M, S> {
    /// The plugin is identified by the name of the provider it wraps.
    fn name(&self) -> String {
        let provider_name = self.provider.name();
        provider_name.to_string()
    }
}
/// Step-by-step constructor for [`OAuthPlugin`].
///
/// The provider, user manager and storage are required up front; the event bus is optional.
pub struct OAuthPluginBuilder<M: UserManager, S: OAuthStorage> {
    /// Provider the built plugin will authenticate against.
    provider: Provider,
    /// User manager handed to the built plugin.
    user_manager: Arc<M>,
    /// OAuth storage handed to the built plugin.
    oauth_storage: Arc<S>,
    /// Event bus for the built plugin, if one has been configured.
    event_bus: Option<EventBus>,
}
impl<M: UserManager, S: OAuthStorage> OAuthPluginBuilder<M, S> {
    /// Starts a builder from the required dependencies, with no event bus configured.
    pub fn new(provider: Provider, user_manager: Arc<M>, oauth_storage: Arc<S>) -> Self {
        let event_bus = None;
        Self {
            provider,
            user_manager,
            oauth_storage,
            event_bus,
        }
    }

    /// Sets the event bus the built plugin will emit to, replacing any earlier one.
    pub fn event_bus(self, event_bus: EventBus) -> Self {
        Self {
            event_bus: Some(event_bus),
            ..self
        }
    }

    /// Consumes the builder and produces the configured [`OAuthPlugin`].
    pub fn build(self) -> OAuthPlugin<M, S> {
        let Self {
            provider,
            user_manager,
            oauth_storage,
            event_bus,
        } = self;
        OAuthPlugin {
            provider,
            user_manager,
            oauth_storage,
            event_bus,
        }
    }
}
impl<M: UserManager, S: OAuthStorage> OAuthPlugin<M, S> {
    /// Returns a builder seeded with the required dependencies.
    pub fn builder(
        provider: Provider,
        user_manager: Arc<M>,
        oauth_storage: Arc<S>,
    ) -> OAuthPluginBuilder<M, S> {
        OAuthPluginBuilder::new(provider, user_manager, oauth_storage)
    }

    /// Creates a plugin for `provider` with no event bus attached.
    pub fn new(provider: Provider, user_manager: Arc<M>, oauth_storage: Arc<S>) -> Self {
        Self::builder(provider, user_manager, oauth_storage).build()
    }

    /// Attaches an event bus, replacing any previously attached one.
    pub fn with_event_bus(self, event_bus: EventBus) -> Self {
        Self {
            event_bus: Some(event_bus),
            ..self
        }
    }

    /// Convenience constructor for a Google-backed plugin without an event bus.
    pub fn google(
        client_id: &str,
        client_secret: &str,
        redirect_uri: &str,
        user_manager: Arc<M>,
        oauth_storage: Arc<S>,
    ) -> Self {
        let provider = Provider::google(client_id, client_secret, redirect_uri);
        Self::new(provider, user_manager, oauth_storage)
    }

    /// Convenience constructor for a GitHub-backed plugin without an event bus.
    pub fn github(
        client_id: &str,
        client_secret: &str,
        redirect_uri: &str,
        user_manager: Arc<M>,
        oauth_storage: Arc<S>,
    ) -> Self {
        let provider = Provider::github(client_id, client_secret, redirect_uri);
        Self::new(provider, user_manager, oauth_storage)
    }
}
impl<M, S> OAuthPlugin<M, S>
where
M: UserManager,
S: OAuthStorage,
{
pub async fn get_authorization_url(&self) -> Result<AuthorizationUrl, Error> {
let (authorization_url, pkce_verifier) = self.provider.get_authorization_url()?;
self.oauth_storage
.store_pkce_verifier(
&authorization_url.csrf_state,
&pkce_verifier,
chrono::Duration::minutes(5),
)
.await
.map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
Ok(authorization_url)
}
pub async fn get_or_create_user(&self, email: String, subject: String) -> Result<User, Error> {
let oauth_account = self
.oauth_storage
.get_oauth_account_by_provider_and_subject(self.provider.name(), &subject)
.await
.map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
if let Some(oauth_account) = oauth_account {
tracing::info!(
user_id = ?oauth_account.user_id,
"User already exists in database"
);
let user = self
.user_manager
.get_user(&oauth_account.user_id)
.await?
.ok_or(Error::Auth(AuthError::UserNotFound))?;
return Ok(user);
}
let new_user = NewUser::builder()
.id(UserId::new_random())
.email(email)
.email_verified_at(Some(chrono::Utc::now()))
.build()
.unwrap();
let user = self.user_manager.create_user(&new_user).await?;
self.oauth_storage
.create_oauth_account(self.provider.name(), &subject, &user.id)
.await
.map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
tracing::info!(
user_id = ?user.id,
provider = ?self.provider.name(),
subject = ?subject,
"Successfully created link between user and provider"
);
self.emit_event(&Event::UserCreated(user.clone())).await?;
Ok(user)
}
pub async fn exchange_code(
&self,
code: String,
csrf_state: String,
) -> Result<(User, UserInfo), Error> {
let pkce_verifier = self
.oauth_storage
.get_pkce_verifier(&csrf_state)
.await
.map_err(|_| Error::Auth(AuthError::InvalidCredentials))?
.ok_or(Error::Auth(AuthError::InvalidCredentials))?;
tracing::debug!(
pkce_verifier = ?pkce_verifier,
csrf_state = ?csrf_state,
"Exchanging code for token"
);
let token_response = self.provider.exchange_code(&code, &pkce_verifier).await?;
let access_token = token_response.access_token();
tracing::debug!(
access_token = ?access_token,
"Getting user info"
);
let user_info = self.provider.get_user_info(access_token.secret()).await?;
tracing::debug!(
user_info = ?user_info,
"Got user info"
);
let email = match &user_info {
UserInfo::Google(user_info) => user_info.email.clone(),
UserInfo::Github(user_info) => {
user_info.email.clone().expect("No email found for user")
}
};
let subject = match &user_info {
UserInfo::Google(user_info) => user_info.sub.clone(),
UserInfo::Github(user_info) => user_info.id.to_string(),
};
tracing::debug!(
email = ?email,
subject = ?subject,
"Getting or creating user"
);
let user = self
.get_or_create_user(email, subject)
.await
.map_err(|_| Error::Auth(AuthError::InvalidCredentials))?;
Ok((user, user_info))
}
async fn emit_event(&self, event: &Event) -> Result<(), Error> {
if let Some(event_bus) = &self.event_bus {
event_bus.emit(event).await?;
}
Ok(())
}
}