mod flow;
mod support;
use http::Method;
use serde_json::Value;
use super::shared::{sensitive_session, unauthorized};
use crate::api::{
create_auth_endpoint, parse_request_body, AsyncAuthEndpoint, AuthEndpointOptions,
OpenApiOperation,
};
use crate::auth::oauth::{generate_oauth_state, OAuthStateInput, OAuthStateLink};
use crate::cookies::Cookie;
use rustauth_oauth::oauth2::SocialAuthorizationUrlRequest;
use flow::{
callback_get, callback_post_redirect, link_with_id_token, lookup_provider,
sign_in_with_id_token,
};
use support::{
link_social_body_schema, redirect_json_response, redirect_uri, social_sign_in_body_schema,
LinkSocialBody, SocialSignInBody,
};
/// Builds the `POST /sign-in/social` endpoint (OpenAPI operation `socialSignIn`).
///
/// All request handling is shared with `sign_in_social_oauth_endpoint`; this
/// function only supplies the route path and OpenAPI metadata.
pub(super) fn sign_in_social_endpoint() -> AsyncAuthEndpoint {
    const PATH: &str = "/sign-in/social";
    const OPERATION_ID: &str = "socialSignIn";
    const DESCRIPTION: &str = "Sign in with a social provider";
    sign_in_social_oauth_endpoint(PATH, OPERATION_ID, DESCRIPTION)
}
/// Builds a `POST` endpoint at `path` that starts a social-provider sign-in.
///
/// The body is accepted as form-encoded or JSON and parsed into
/// [`SocialSignInBody`]. The provider named in the body is resolved with
/// `lookup_provider`. Then one of two paths runs:
/// - If the body carries an `id_token`, sign-in is delegated to
///   `sign_in_with_id_token`.
/// - Otherwise an OAuth state is generated. The callback URL defaults to
///   `"/"` and missing `additional_data` becomes `Value::Null`. The provider's
///   authorization URL is built with that state, the PKCE code verifier, the
///   requested scopes and the login hint.
///
/// On the second path the response comes from `redirect_json_response`. A
/// redirect is requested unless `disable_redirect` is set, and the OAuth state
/// cookie is attached.
///
/// `operation_id` and `description` feed both the endpoint options and its
/// OpenAPI operation.
fn sign_in_social_oauth_endpoint(
    path: &'static str,
    operation_id: &'static str,
    description: &'static str,
) -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new()
        .operation_id(operation_id)
        .allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
        .body_schema(social_sign_in_body_schema())
        .openapi(OpenApiOperation::new(operation_id).description(description));

    create_auth_endpoint(path, Method::POST, options, move |context, request| async move {
        let body: SocialSignInBody = parse_request_body(&request)?;
        let provider = lookup_provider(&context, &body.provider)?;
        // Both branches below need the adapter, and it is resolved only after
        // the provider lookup succeeded.
        let adapter = context.adapter_ref()?;

        // A client-supplied ID token completes sign-in without a redirect.
        if let Some(id_token) = body.id_token {
            return sign_in_with_id_token(&context, adapter, provider, id_token).await;
        }

        let state_input = OAuthStateInput {
            callback_url: body.callback_url.unwrap_or_else(|| "/".to_owned()),
            error_url: body.error_callback_url,
            new_user_url: body.new_user_callback_url,
            request_sign_up: body.request_sign_up,
            additional_data: body.additional_data.unwrap_or(Value::Null),
            ..OAuthStateInput::default()
        };
        let state = generate_oauth_state(&context, Some(adapter), state_input).await?;

        let authorization_url = provider.create_authorization_url(SocialAuthorizationUrlRequest {
            state: state.state,
            redirect_uri: redirect_uri(&context, &request, provider.id()),
            code_verifier: Some(state.data.code_verifier),
            scopes: body.scopes,
            login_hint: body.login_hint,
        })?;

        let state_cookie = oauth_state_cookie(&context, &state.data.oauth_state);
        let should_redirect = !body.disable_redirect;
        redirect_json_response(authorization_url.to_string(), should_redirect, vec![state_cookie])
    })
}
/// Builds the `/callback/:id` OAuth callback endpoint for the given HTTP `method`.
///
/// When registered for `POST`, the endpoint options bypass origin security.
/// NOTE(review): presumably this is because providers that use a form-post
/// response mode submit cross-site; confirm.
///
/// At request time:
/// - A `POST` request is handed to `callback_post_redirect`.
/// - Any other method is processed by `callback_get` with the context's adapter.
pub(super) fn callback_oauth_endpoint(method: Method) -> AsyncAuthEndpoint {
    const OPERATION_ID: &str = "handleOAuthCallback";

    let base_options = AuthEndpointOptions::new()
        .operation_id(OPERATION_ID)
        .openapi(OpenApiOperation::new(OPERATION_ID).description("Handle OAuth callback"));
    let options = if method == Method::POST {
        base_options.bypass_origin_security()
    } else {
        base_options
    };

    create_auth_endpoint("/callback/:id", method, options, move |context, request| async move {
        if request.method() != Method::POST {
            return callback_get(&context, context.adapter_ref()?, request).await;
        }
        callback_post_redirect(&context, &request)
    })
}
/// Builds the `POST /link-social` endpoint, which links a social-provider
/// account to the currently authenticated user.
///
/// Authentication and parsing:
/// - A session must be returned by `sensitive_session`; otherwise the handler
///   returns `unauthorized()`.
/// - The body is accepted as form-encoded or JSON and parsed into
///   [`LinkSocialBody`].
///
/// Then one of two paths runs:
/// - If the body carries an `id_token`, linking is delegated to
///   `link_with_id_token`.
/// - Otherwise an OAuth state is generated that carries the user's id and
///   email as link data. The callback URL defaults to `"/"` and missing
///   `additional_data` becomes `Value::Null`. The provider's authorization URL
///   is built with that state and the PKCE code verifier; no login hint is sent.
///
/// On the second path the response comes from `redirect_json_response`. A
/// redirect is requested unless `disable_redirect` is set, and the OAuth state
/// cookie is attached.
pub(super) fn link_social_endpoint() -> AsyncAuthEndpoint {
    let options = AuthEndpointOptions::new()
        .operation_id("linkSocialAccount")
        .allowed_media_types(["application/x-www-form-urlencoded", "application/json"])
        .body_schema(link_social_body_schema())
        .openapi(OpenApiOperation::new("linkSocialAccount").description("Link a social account"));

    create_auth_endpoint("/link-social", Method::POST, options, move |context, request| async move {
        // Linking is only allowed for a caller with a (sensitive) session.
        let Some((_session, user, _cookies)) = sensitive_session(&context, &request).await?
        else {
            return unauthorized();
        };

        let body: LinkSocialBody = parse_request_body(&request)?;
        let provider = lookup_provider(&context, &body.provider)?;
        // Both branches below need the adapter, and it is resolved only after
        // the provider lookup succeeded.
        let adapter = context.adapter_ref()?;

        // A client-supplied ID token links the account without a redirect.
        if let Some(id_token) = body.id_token {
            return link_with_id_token(&context, adapter, provider, &user, id_token).await;
        }

        let link_target = OAuthStateLink {
            user_id: user.id,
            email: user.email,
        };
        let state_input = OAuthStateInput {
            callback_url: body.callback_url.unwrap_or_else(|| "/".to_owned()),
            error_url: body.error_callback_url,
            link: Some(link_target),
            request_sign_up: body.request_sign_up,
            additional_data: body.additional_data.unwrap_or(Value::Null),
            ..OAuthStateInput::default()
        };
        let state = generate_oauth_state(&context, Some(adapter), state_input).await?;

        let authorization_url = provider.create_authorization_url(SocialAuthorizationUrlRequest {
            state: state.state,
            redirect_uri: redirect_uri(&context, &request, provider.id()),
            code_verifier: Some(state.data.code_verifier),
            scopes: body.scopes,
            login_hint: None,
        })?;

        let state_cookie = oauth_state_cookie(&context, &state.data.oauth_state);
        let should_redirect = !body.disable_redirect;
        redirect_json_response(authorization_url.to_string(), should_redirect, vec![state_cookie])
    })
}
/// Builds the cookie that carries `oauth_state` in the response.
///
/// The cookie's name and attributes are cloned from the configured
/// `context.auth_cookies.oauth_state` cookie. Its value is an owned copy of
/// `oauth_state`.
fn oauth_state_cookie(context: &crate::context::AuthContext, oauth_state: &str) -> Cookie {
    let template = &context.auth_cookies.oauth_state;
    Cookie {
        name: template.name.clone(),
        value: String::from(oauth_state),
        attributes: template.attributes.clone(),
    }
}