use http::{header, HeaderValue, Method, StatusCode};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
use super::shared::{
auth_session_cookies, current_session, error_response, json_response, query_param,
serialize_cookie, unauthorized,
};
use crate::api::{
create_auth_endpoint, parse_request_body, ApiRequest, ApiResponse, AsyncAuthEndpoint,
AuthEndpointOptions, BodyField, BodySchema, JsonSchemaType, OpenApiOperation, PathParams,
};
use crate::auth::oauth::{
generate_oauth_state, handle_oauth_user_info, HandleOAuthUserInfoInput, OAuthAccountInput,
OAuthStateInput, OAuthStateLink, OAuthUserInfo,
};
use crate::db::DbAdapter;
use crate::error::OpenAuthError;
use crate::user::{CreateOAuthAccountInput, DbUserStore};
use openauth_oauth::oauth2::{
OAuth2Tokens, OAuth2UserInfo, SocialAuthorizationCodeRequest, SocialAuthorizationUrlRequest,
SocialIdTokenRequest, SocialOAuthProvider,
};
/// Request body accepted by `POST /sign-in/social`.
///
/// Every field except `provider` is optional; the camelCase spellings listed in
/// the `alias` attributes are accepted in addition to the snake_case names.
#[derive(Debug, Deserialize)]
struct SocialSignInBody {
/// Identifier used to look up the configured social provider.
provider: String,
/// Callback URL stored in the OAuth state; the handler falls back to `/` when absent.
#[serde(default, alias = "callbackURL")]
callback_url: Option<String>,
/// Stored in the OAuth state as the error URL.
#[serde(default, alias = "errorCallbackURL")]
error_callback_url: Option<String>,
/// Stored in the OAuth state as the new-user URL.
#[serde(default, alias = "newUserCallbackURL")]
new_user_callback_url: Option<String>,
/// When `true`, the JSON response reports `redirect: false` and no `Location` header is set.
#[serde(default, alias = "disableRedirect")]
disable_redirect: bool,
/// Scopes forwarded to the provider's authorization URL request.
#[serde(default)]
scopes: Vec<String>,
/// Login hint forwarded to the provider's authorization URL request.
#[serde(default, alias = "loginHint")]
login_hint: Option<String>,
/// Flag recorded in the OAuth state; the callback consults it when deciding whether sign-up is disabled.
#[serde(default, alias = "requestSignUp")]
request_sign_up: bool,
/// Arbitrary JSON recorded in the OAuth state; `null` is stored when absent.
#[serde(default, alias = "additionalData")]
additional_data: Option<Value>,
/// When present, the handler signs in directly with this token instead of building an authorization URL.
#[serde(default, alias = "idToken")]
id_token: Option<IdTokenBody>,
}
/// Request body accepted by `POST /link-social`.
///
/// Every field except `provider` is optional; the camelCase spellings listed in
/// the `alias` attributes are accepted in addition to the snake_case names.
#[derive(Debug, Deserialize)]
struct LinkSocialBody {
/// Identifier used to look up the configured social provider.
provider: String,
/// Callback URL stored in the OAuth state; the handler falls back to `/` when absent.
#[serde(default, alias = "callbackURL")]
callback_url: Option<String>,
/// Stored in the OAuth state as the error URL.
#[serde(default, alias = "errorCallbackURL")]
error_callback_url: Option<String>,
/// When `true`, the JSON response reports `redirect: false` and no `Location` header is set.
#[serde(default, alias = "disableRedirect")]
disable_redirect: bool,
/// Scopes forwarded to the provider's authorization URL request.
#[serde(default)]
scopes: Vec<String>,
/// Flag recorded in the OAuth state alongside the link information.
#[serde(default, alias = "requestSignUp")]
request_sign_up: bool,
/// Arbitrary JSON recorded in the OAuth state; `null` is stored when absent.
#[serde(default, alias = "additionalData")]
additional_data: Option<Value>,
/// When present, the handler links the account directly from this token instead of building an authorization URL.
#[serde(default, alias = "idToken")]
id_token: Option<IdTokenBody>,
}
/// Client-supplied ID token payload used by the direct (non-redirect) sign-in and link flows.
#[derive(Debug, Clone, Deserialize)]
struct IdTokenBody {
/// The ID token itself; passed to the provider for verification and kept as the `id_token` of the derived token set.
token: String,
/// Optional nonce forwarded to the provider's ID token verification.
#[serde(default)]
nonce: Option<String>,
/// Optional access token forwarded to verification and copied into the derived token set.
#[serde(default, alias = "accessToken")]
access_token: Option<String>,
/// Optional refresh token forwarded to verification and copied into the derived token set.
#[serde(default, alias = "refreshToken")]
refresh_token: Option<String>,
/// Scopes forwarded to verification and copied into the derived token set.
#[serde(default)]
scopes: Vec<String>,
/// Optional provider-specific user JSON forwarded to verification and to the user-info lookup.
#[serde(default)]
user: Option<Value>,
}
/// JSON body returned when a flow hands the client a provider authorization URL.
#[derive(Debug, Serialize)]
struct SocialRedirectBody {
/// The provider authorization URL.
url: String,
/// Whether the response also carries a `Location` header pointing at `url`.
redirect: bool,
}
/// JSON body returned after a successful ID token sign-in.
#[derive(Debug, Serialize)]
struct SocialSessionBody {
/// Redirect flag; the ID token sign-in handler always sets this to `false`.
redirect: bool,
/// Token of the newly created session.
token: String,
/// Optional URL; omitted from the serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
url: Option<String>,
/// The signed-in user record.
user: crate::db::User,
}
/// JSON body returned after linking an account via ID token.
#[derive(Debug, Serialize)]
struct LinkStatusBody {
/// URL field; the ID token link handler sends an empty string.
url: String,
/// Redirect flag; the ID token link handler sends `false`.
redirect: bool,
/// Link status; the ID token link handler sends `true` on success.
status: bool,
}
pub(super) fn sign_in_social_endpoint(adapter: Arc<dyn DbAdapter>) -> AsyncAuthEndpoint {
create_auth_endpoint(
"/sign-in/social",
Method::POST,
AuthEndpointOptions::new()
.operation_id("socialSignIn")
.allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
.body_schema(social_sign_in_body_schema())
.openapi(
OpenApiOperation::new("socialSignIn").description("Sign in with a social provider"),
),
move |context, request| {
let adapter = Arc::clone(&adapter);
Box::pin(async move {
let body: SocialSignInBody = parse_request_body(&request)?;
let provider = lookup_provider(context, &body.provider)?;
if let Some(id_token) = body.id_token {
return sign_in_with_id_token(context, adapter.as_ref(), provider, id_token)
.await;
}
let state = generate_oauth_state(
context,
Some(adapter.as_ref()),
OAuthStateInput {
callback_url: body.callback_url.unwrap_or_else(|| "/".to_owned()),
error_url: body.error_callback_url,
new_user_url: body.new_user_callback_url,
request_sign_up: body.request_sign_up,
additional_data: body.additional_data.unwrap_or(Value::Null),
..OAuthStateInput::default()
},
)
.await?;
let url = provider.create_authorization_url(SocialAuthorizationUrlRequest {
state: state.state,
redirect_uri: redirect_uri(context, provider.id()),
code_verifier: Some(state.data.code_verifier),
scopes: body.scopes,
login_hint: body.login_hint,
})?;
redirect_json_response(url.to_string(), !body.disable_redirect)
})
},
)
}
/// Builds the `/callback/:id` endpoint for the given HTTP `method`.
///
/// `POST` requests are answered by `callback_post_redirect`, which bounces the
/// parameters back to the same path as a redirect; every other method is
/// handled by `callback_get`.
pub(super) fn callback_oauth_endpoint(
    method: Method,
    adapter: Arc<dyn DbAdapter>,
) -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new()
        .operation_id("handleOAuthCallback")
        .openapi(
            OpenApiOperation::new("handleOAuthCallback").description("Handle OAuth callback"),
        );
    create_auth_endpoint(
        "/callback/:id",
        method,
        options,
        move |context, request| {
            let adapter = Arc::clone(&adapter);
            Box::pin(async move {
                if request.method() == Method::POST {
                    callback_post_redirect(context, &request)
                } else {
                    callback_get(context, adapter.as_ref(), request).await
                }
            })
        },
    )
}
pub(super) fn link_social_endpoint(adapter: Arc<dyn DbAdapter>) -> AsyncAuthEndpoint {
create_auth_endpoint(
"/link-social",
Method::POST,
AuthEndpointOptions::new()
.operation_id("linkSocialAccount")
.allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
.body_schema(link_social_body_schema())
.openapi(
OpenApiOperation::new("linkSocialAccount").description("Link a social account"),
),
move |context, request| {
let adapter = Arc::clone(&adapter);
Box::pin(async move {
let Some((_session, user, _cookies)) =
current_session(adapter.as_ref(), context, &request).await?
else {
return unauthorized();
};
let body: LinkSocialBody = parse_request_body(&request)?;
let provider = lookup_provider(context, &body.provider)?;
if let Some(id_token) = body.id_token {
return link_with_id_token(
context,
adapter.as_ref(),
provider,
&user,
id_token,
)
.await;
}
let state = generate_oauth_state(
context,
Some(adapter.as_ref()),
OAuthStateInput {
callback_url: body.callback_url.unwrap_or_else(|| "/".to_owned()),
error_url: body.error_callback_url,
link: Some(OAuthStateLink {
user_id: user.id,
email: user.email,
}),
request_sign_up: body.request_sign_up,
additional_data: body.additional_data.unwrap_or(Value::Null),
..OAuthStateInput::default()
},
)
.await?;
let url = provider.create_authorization_url(SocialAuthorizationUrlRequest {
state: state.state,
redirect_uri: redirect_uri(context, provider.id()),
code_verifier: Some(state.data.code_verifier),
scopes: body.scopes,
login_hint: None,
})?;
redirect_json_response(url.to_string(), !body.disable_redirect)
})
},
)
}
/// Signs a user in from a client-supplied ID token, without a browser redirect.
///
/// Steps:
/// 1. ask the provider to verify the token (401 `INVALID_TOKEN` on failure);
/// 2. fetch the provider's user info (401 `FAILED_TO_GET_USER_INFO` when none);
/// 3. hand the result to `handle_oauth_user_info`;
/// 4. on success, set session cookies and return a [`SocialSessionBody`].
///
/// # Errors
///
/// Propagates any `OpenAuthError` raised by the provider, by user-info
/// normalization (e.g. a missing email), by token storage, or by cookie/response
/// construction.
async fn sign_in_with_id_token(
    context: &crate::context::AuthContext,
    adapter: &dyn DbAdapter,
    provider: Arc<dyn SocialOAuthProvider>,
    id_token: IdTokenBody,
) -> Result<ApiResponse, OpenAuthError> {
    if !provider
        .verify_id_token(SocialIdTokenRequest {
            token: id_token.token.clone(),
            nonce: id_token.nonce.clone(),
            access_token: id_token.access_token.clone(),
            refresh_token: id_token.refresh_token.clone(),
            scopes: id_token.scopes.clone(),
            provider_user: id_token.user.clone(),
        })
        .await?
    {
        return error_response(StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid token");
    }
    let tokens = tokens_from_id_token(&id_token);
    // `id_token` is not needed past this point, so its `user` payload is moved
    // rather than cloned (same as `link_with_id_token`).
    let Some(user_info) = provider
        .get_user_info(tokens.clone(), id_token.user)
        .await?
    else {
        return error_response(
            StatusCode::UNAUTHORIZED,
            "FAILED_TO_GET_USER_INFO",
            "Failed to get user info",
        );
    };
    // NOTE(review): unlike the redirect callback, this path only consults
    // `disable_sign_up` and ignores `disable_implicit_sign_up` / `requestSignUp`
    // — confirm that this difference is intentional.
    let result = handle_oauth_user_info(
        context,
        adapter,
        HandleOAuthUserInfoInput {
            user_info: normalize_user_info(&user_info)?,
            account: oauth_account(provider.id(), &user_info, &tokens, context)?,
            disable_sign_up: provider.provider_options().disable_sign_up,
            override_user_info: provider.provider_options().override_user_info_on_sign_in,
            is_trusted_provider: true,
            ..HandleOAuthUserInfoInput::default()
        },
    )
    .await?;
    let Some(data) = result.data else {
        // Build the fallback message lazily so nothing is allocated when the
        // handler reported a specific error.
        return error_response(
            StatusCode::UNAUTHORIZED,
            "OAUTH_LINK_ERROR",
            result
                .error
                .map_or_else(|| "OAuth sign in failed".to_owned(), oauth_user_info_error),
        );
    };
    let cookies = auth_session_cookies(context, &data.session, &data.user, false)?;
    json_response(
        StatusCode::OK,
        &SocialSessionBody {
            redirect: false,
            token: data.session.token,
            url: None,
            user: data.user,
        },
        cookies,
    )
}
async fn callback_get(
context: &crate::context::AuthContext,
adapter: &dyn DbAdapter,
request: ApiRequest,
) -> Result<ApiResponse, OpenAuthError> {
let provider_id = path_param(&request, "id")?;
let provider = lookup_provider(context, provider_id)?;
let default_error_url = format!("{}/error", context.base_url.trim_end_matches('/'));
let state = match query_param(&request, "state") {
Some(state) => state,
None => return redirect(&default_error_url, Vec::new()),
};
let state_data =
match crate::auth::oauth::parse_oauth_state(context, Some(adapter), &state).await {
Ok(data) => data,
Err(_) => return redirect_with_error(&default_error_url, "invalid_state"),
};
let error_url = state_data.error_url.clone().unwrap_or(default_error_url);
if let Some(error) = query_param(&request, "error") {
return redirect_with_error(&error_url, &error);
}
let Some(code) = query_param(&request, "code") else {
return redirect_with_error(&error_url, "no_code");
};
let tokens = match provider
.validate_authorization_code(SocialAuthorizationCodeRequest {
code,
code_verifier: Some(state_data.code_verifier),
redirect_uri: redirect_uri(context, provider.id()),
device_id: query_param(&request, "device_id"),
})
.await
{
Ok(tokens) => tokens,
Err(_) => return redirect_with_error(&error_url, "invalid_code"),
};
let provider_user =
query_param(&request, "user").and_then(|value| serde_json::from_str::<Value>(&value).ok());
let Some(user_info) = provider
.get_user_info(tokens.clone(), provider_user)
.await?
else {
return redirect_with_error(&error_url, "unable_to_get_user_info");
};
if let Some(link) = state_data.link {
link_oauth_account(
context,
adapter,
provider.clone(),
&link,
&user_info,
&tokens,
)
.await?;
return redirect(&state_data.callback_url, Vec::new());
}
let result = handle_oauth_user_info(
context,
adapter,
HandleOAuthUserInfoInput {
user_info: normalize_user_info(&user_info)?,
account: oauth_account(provider.id(), &user_info, &tokens, context)?,
callback_url: Some(state_data.callback_url.clone()),
disable_sign_up: (provider.provider_options().disable_implicit_sign_up
&& !state_data.request_sign_up)
|| provider.provider_options().disable_sign_up,
override_user_info: provider.provider_options().override_user_info_on_sign_in,
is_trusted_provider: true,
},
)
.await?;
let Some(data) = result.data else {
let error = result
.error
.map_or_else(|| "oauth_sign_in_failed".to_owned(), oauth_user_info_error);
return redirect_with_error(&error_url, &error);
};
let cookies = auth_session_cookies(context, &data.session, &data.user, false)?;
let target = if result.is_register {
state_data
.new_user_url
.as_deref()
.unwrap_or(&state_data.callback_url)
} else {
&state_data.callback_url
};
redirect(target, cookies)
}
/// Maps an [`OAuthUserInfoError`](crate::auth::oauth::OAuthUserInfoError) to the
/// snake_case error code used in redirects and JSON error messages.
fn oauth_user_info_error(error: crate::auth::oauth::OAuthUserInfoError) -> String {
    use crate::auth::oauth::OAuthUserInfoError as Failure;

    let code: &'static str = match error {
        Failure::AccountNotLinked => "account_not_linked",
        Failure::SignupDisabled => "signup_disabled",
        Failure::UnableToCreateUser => "unable_to_create_user",
        Failure::UnableToCreateSession => "unable_to_create_session",
        Failure::UnableToLinkAccount => "unable_to_link_account",
    };
    String::from(code)
}
/// Converts a `POST` callback into a redirect to the same callback path.
///
/// The known OAuth parameters are collected from the query string first and
/// from the request body second, form-encoded, and appended to
/// `<base_url>/callback/<provider_id>` as a query string. An empty body is
/// treated as an empty JSON object.
///
/// # Errors
///
/// Fails when the `id` path parameter is missing, when a non-empty body cannot
/// be parsed, or when the redirect response cannot be built.
fn callback_post_redirect(
    context: &crate::context::AuthContext,
    request: &ApiRequest,
) -> Result<ApiResponse, OpenAuthError> {
    const FORWARDED_KEYS: [&str; 6] = [
        "code",
        "error",
        "device_id",
        "error_description",
        "state",
        "user",
    ];

    let provider_id = path_param(request, "id")?;
    let body = if request.body().is_empty() {
        Value::Object(serde_json::Map::new())
    } else {
        parse_request_body::<Value>(request)?
    };
    let query = FORWARDED_KEYS
        .iter()
        .copied()
        .filter_map(|key| {
            query_param(request, key)
                .or_else(|| body_string(&body, key))
                .map(|value| format!("{key}={}", percent_encode(&value)))
        })
        .collect::<Vec<_>>()
        .join("&");
    let base = context.base_url.trim_end_matches('/');
    let target = format!("{base}/callback/{provider_id}?{query}");
    redirect(&target, Vec::new())
}
/// Links a social account to `user` from a client-supplied ID token.
///
/// The token is verified by the provider (401 `INVALID_TOKEN` on failure), the
/// provider's user info is fetched (401 `FAILED_TO_GET_USER_INFO` when none),
/// and the provider email is compared case-insensitively with the user's email.
/// A mismatch is rejected with 401 `LINKING_DIFFERENT_EMAILS_NOT_ALLOWED`
/// unless account linking allows different emails. On success a
/// [`LinkStatusBody`] with `status: true` is returned.
///
/// # Errors
///
/// Propagates provider, normalization, storage and response-construction errors.
async fn link_with_id_token(
    context: &crate::context::AuthContext,
    adapter: &dyn DbAdapter,
    provider: Arc<dyn SocialOAuthProvider>,
    user: &crate::db::User,
    id_token: IdTokenBody,
) -> Result<ApiResponse, OpenAuthError> {
    let verification = SocialIdTokenRequest {
        token: id_token.token.clone(),
        nonce: id_token.nonce.clone(),
        access_token: id_token.access_token.clone(),
        refresh_token: id_token.refresh_token.clone(),
        scopes: id_token.scopes.clone(),
        provider_user: id_token.user.clone(),
    };
    let token_is_valid = provider.verify_id_token(verification).await?;
    if !token_is_valid {
        return error_response(StatusCode::UNAUTHORIZED, "INVALID_TOKEN", "Invalid token");
    }

    let tokens = tokens_from_id_token(&id_token);
    let info = match provider
        .get_user_info(tokens.clone(), id_token.user)
        .await?
    {
        Some(info) => info,
        None => {
            return error_response(
                StatusCode::UNAUTHORIZED,
                "FAILED_TO_GET_USER_INFO",
                "Failed to get user info",
            );
        }
    };

    let normalized = normalize_user_info(&info)?;
    let emails_differ = normalized.email.to_lowercase() != user.email.to_lowercase();
    let linking = &context.options.account.account_linking;
    if emails_differ && !linking.allow_different_emails {
        return error_response(
            StatusCode::UNAUTHORIZED,
            "LINKING_DIFFERENT_EMAILS_NOT_ALLOWED",
            "Account not linked - different emails not allowed",
        );
    }

    let link = OAuthStateLink {
        user_id: user.id.clone(),
        email: user.email.clone(),
    };
    link_oauth_account(context, adapter, provider, &link, &info, &tokens).await?;

    let body = LinkStatusBody {
        url: String::new(),
        redirect: false,
        status: true,
    };
    json_response(StatusCode::OK, &body, Vec::new())
}
/// Stores a provider account for the user described by `link`.
///
/// The provider email must match `link.email` case-insensitively unless account
/// linking allows different emails. When an account with the same provider and
/// provider account id already exists, nothing is written and `Ok(())` is
/// returned.
///
/// # Errors
///
/// Returns `OpenAuthError::Api` on an email mismatch and propagates
/// normalization, lookup, token-storage and insert errors.
async fn link_oauth_account(
    context: &crate::context::AuthContext,
    adapter: &dyn DbAdapter,
    provider: Arc<dyn SocialOAuthProvider>,
    link: &OAuthStateLink,
    info: &OAuth2UserInfo,
    tokens: &OAuth2Tokens,
) -> Result<(), OpenAuthError> {
    use crate::auth::oauth::set_token_util;

    let normalized = normalize_user_info(info)?;
    let emails_differ = normalized.email.to_lowercase() != link.email.to_lowercase();
    let linking = &context.options.account.account_linking;
    if emails_differ && !linking.allow_different_emails {
        return Err(OpenAuthError::Api(
            "OAuth account email does not match linked user".to_owned(),
        ));
    }

    let users = DbUserStore::new(adapter);
    // NOTE(review): an existing account is treated as success without checking
    // which user owns it — confirm that an account already linked to a
    // different user should not be reported as an error here.
    let already_linked = users
        .find_account_by_provider_account(&normalized.id, provider.id())
        .await?
        .is_some();
    if already_linked {
        return Ok(());
    }

    let access_token = set_token_util(tokens.access_token.as_deref(), context)?;
    let refresh_token = set_token_util(tokens.refresh_token.as_deref(), context)?;
    let scope = if tokens.scopes.is_empty() {
        None
    } else {
        Some(tokens.scopes.join(","))
    };
    let new_account = CreateOAuthAccountInput {
        id: None,
        provider_id: provider.id().to_owned(),
        account_id: normalized.id,
        user_id: link.user_id.clone(),
        access_token,
        refresh_token,
        id_token: tokens.id_token.clone(),
        access_token_expires_at: tokens.access_token_expires_at,
        refresh_token_expires_at: tokens.refresh_token_expires_at,
        scope,
    };
    users.link_account(new_account).await?;
    Ok(())
}
/// Resolves `provider_id` to a configured social provider.
///
/// # Errors
///
/// Returns `OpenAuthError::Api` when the context has no provider with that id.
fn lookup_provider(
    context: &crate::context::AuthContext,
    provider_id: &str,
) -> Result<Arc<dyn SocialOAuthProvider>, OpenAuthError> {
    match context.social_provider(provider_id) {
        Some(provider) => Ok(provider),
        None => Err(OpenAuthError::Api(format!(
            "social provider `{provider_id}` was not found"
        ))),
    }
}
/// Converts provider user info into the internal [`OAuthUserInfo`] shape.
///
/// A missing name becomes the empty string; a missing email is an error.
///
/// # Errors
///
/// Returns `OpenAuthError::Api` when the provider did not return an email.
fn normalize_user_info(info: &OAuth2UserInfo) -> Result<OAuthUserInfo, OpenAuthError> {
    let Some(email) = info.email.clone() else {
        return Err(OpenAuthError::Api(
            "OAuth provider did not return an email".to_owned(),
        ));
    };
    let name = info.name.clone().unwrap_or_default();
    Ok(OAuthUserInfo {
        id: info.id.clone(),
        name,
        email,
        image: info.image.clone(),
        email_verified: info.email_verified,
    })
}
/// Builds the account record passed to `handle_oauth_user_info`.
///
/// Access and refresh tokens are run through `set_token_util` before being
/// stored; scopes are joined with `,` and omitted entirely when empty.
///
/// # Errors
///
/// Propagates any error returned by `set_token_util`.
fn oauth_account(
    provider_id: &str,
    info: &OAuth2UserInfo,
    tokens: &OAuth2Tokens,
    context: &crate::context::AuthContext,
) -> Result<OAuthAccountInput, OpenAuthError> {
    use crate::auth::oauth::set_token_util;

    let access_token = set_token_util(tokens.access_token.as_deref(), context)?;
    let refresh_token = set_token_util(tokens.refresh_token.as_deref(), context)?;
    let scope = if tokens.scopes.is_empty() {
        None
    } else {
        Some(tokens.scopes.join(","))
    };
    Ok(OAuthAccountInput {
        provider_id: provider_id.to_owned(),
        account_id: info.id.clone(),
        access_token,
        refresh_token,
        id_token: tokens.id_token.clone(),
        access_token_expires_at: tokens.access_token_expires_at,
        refresh_token_expires_at: tokens.refresh_token_expires_at,
        scope,
    })
}
fn tokens_from_id_token(id_token: &IdTokenBody) -> OAuth2Tokens {
OAuth2Tokens {
access_token: id_token.access_token.clone(),
refresh_token: id_token.refresh_token.clone(),
id_token: Some(id_token.token.clone()),
scopes: id_token.scopes.clone(),
..OAuth2Tokens::default()
}
}
/// Reads the path parameter `name` from the request's [`PathParams`] extension.
///
/// # Errors
///
/// Returns `OpenAuthError::Api` when the extension or the parameter is absent.
fn path_param<'a>(request: &'a ApiRequest, name: &str) -> Result<&'a str, OpenAuthError> {
    let params = request.extensions().get::<PathParams>();
    match params.and_then(|params| params.get(name)) {
        Some(value) => Ok(value),
        None => Err(OpenAuthError::Api(format!("missing path param `{name}`"))),
    }
}
/// Returns the callback URL registered with the provider:
/// `<base_url without trailing slashes>/callback/<provider_id>`.
fn redirect_uri(context: &crate::context::AuthContext, provider_id: &str) -> String {
    let base = context.base_url.trim_end_matches('/');
    format!("{base}/callback/{provider_id}")
}
fn redirect_json_response(url: String, redirect: bool) -> Result<ApiResponse, OpenAuthError> {
let mut response = json_response(
StatusCode::OK,
&SocialRedirectBody {
url: url.clone(),
redirect,
},
Vec::new(),
)?;
if redirect {
response.headers_mut().insert(
header::LOCATION,
HeaderValue::from_str(&url).map_err(|error| OpenAuthError::Api(error.to_string()))?,
);
}
Ok(response)
}
fn redirect(
location: &str,
cookies: Vec<crate::cookies::Cookie>,
) -> Result<ApiResponse, OpenAuthError> {
let mut response = http::Response::builder()
.status(StatusCode::FOUND)
.header(header::LOCATION, location)
.body(Vec::new())
.map_err(|error| OpenAuthError::Api(error.to_string()))?;
for cookie in cookies {
response.headers_mut().append(
header::SET_COOKIE,
HeaderValue::from_str(&serialize_cookie(&cookie))
.map_err(|error| OpenAuthError::Cookie(error.to_string()))?,
);
}
Ok(response)
}
fn redirect_with_error(location: &str, error: &str) -> Result<ApiResponse, OpenAuthError> {
let separator = if location.contains('?') { '&' } else { '?' };
redirect(
&format!("{location}{separator}error={}", percent_encode(error)),
Vec::new(),
)
}
/// Reads `key` from a JSON body as a string.
///
/// Strings are returned as-is; numbers and booleans are rendered with
/// `to_string`. Missing keys and every other JSON type yield `None`.
fn body_string(body: &Value, key: &str) -> Option<String> {
    match body.get(key)? {
        Value::String(text) => Some(text.clone()),
        Value::Number(number) => Some(number.to_string()),
        Value::Bool(flag) => Some(flag.to_string()),
        _ => None,
    }
}
/// Encodes `value` with `url::form_urlencoded::byte_serialize` so it can be
/// embedded in a query string.
fn percent_encode(value: &str) -> String {
    let mut encoded = String::new();
    for chunk in url::form_urlencoded::byte_serialize(value.as_bytes()) {
        encoded.push_str(chunk);
    }
    encoded
}
/// Declares the body schema of `POST /sign-in/social`: a required `provider`
/// string followed by the optional camelCase fields.
fn social_sign_in_body_schema() -> BodySchema {
    let fields = [
        BodyField::new("provider", JsonSchemaType::String),
        BodyField::optional("callbackURL", JsonSchemaType::String),
        BodyField::optional("errorCallbackURL", JsonSchemaType::String),
        BodyField::optional("newUserCallbackURL", JsonSchemaType::String),
        BodyField::optional("disableRedirect", JsonSchemaType::Boolean),
        BodyField::optional("scopes", JsonSchemaType::Array),
        BodyField::optional("loginHint", JsonSchemaType::String),
        BodyField::optional("requestSignUp", JsonSchemaType::Boolean),
        BodyField::optional("additionalData", JsonSchemaType::Object),
        BodyField::optional("idToken", JsonSchemaType::Object),
    ];
    BodySchema::object(fields)
}
/// Declares the body schema of `POST /link-social`: a required `provider`
/// string followed by the optional camelCase fields.
fn link_social_body_schema() -> BodySchema {
    let fields = [
        BodyField::new("provider", JsonSchemaType::String),
        BodyField::optional("callbackURL", JsonSchemaType::String),
        BodyField::optional("errorCallbackURL", JsonSchemaType::String),
        BodyField::optional("disableRedirect", JsonSchemaType::Boolean),
        BodyField::optional("scopes", JsonSchemaType::Array),
        BodyField::optional("requestSignUp", JsonSchemaType::Boolean),
        BodyField::optional("additionalData", JsonSchemaType::Object),
        BodyField::optional("idToken", JsonSchemaType::Object),
    ];
    BodySchema::object(fields)
}