use async_trait::async_trait;
use openidconnect::{
CsrfToken, Nonce, PkceCodeChallenge, Scope, core::CoreAuthenticationFlow,
url::form_urlencoded::parse,
};
use serde::Deserialize;
use shield::{
Action, ActionMethod, Form, Input, InputAddon, InputType, InputTypeHidden, InputTypeSubmit,
InputValue, MethodSession, Provider, Request, Response, ResponseType, SessionAction,
ShieldError, SignInAction, erased_action,
};
use url::Url;
use crate::{
options::OidcOptions,
provider::{OidcProvider, OidcProviderPkceCodeChallenge},
session::OidcSession,
};
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SignInData {
pub redirect_origin: Url,
pub redirect_url: Option<String>,
}
pub struct OidcSignInAction {
options: OidcOptions,
}
impl OidcSignInAction {
pub fn new(options: OidcOptions) -> Self {
Self { options }
}
}
#[async_trait]
impl Action<OidcProvider, OidcSession> for OidcSignInAction {
fn id(&self) -> String {
SignInAction::id()
}
fn name(&self) -> String {
SignInAction::name()
}
fn openapi_summary(&self) -> &'static str {
"Sign in with OpenID Connect"
}
fn openapi_description(&self) -> &'static str {
"Sign in with OpenID Connect."
}
fn method(&self) -> ActionMethod {
ActionMethod::Post
}
async fn forms(&self, provider: OidcProvider) -> Result<Vec<Form>, ShieldError> {
Ok(vec![Form {
inputs: vec![
Input {
name: "redirectOrigin".to_owned(),
label: None,
r#type: InputType::Hidden(InputTypeHidden::default()),
value: Some(InputValue::Origin),
addon_start: None,
addon_end: None,
},
Input {
name: "redirectUrl".to_owned(),
label: None,
r#type: InputType::Hidden(InputTypeHidden::default()),
value: Some(InputValue::Query {
key: "redirectUrl".to_owned(),
}),
addon_start: None,
addon_end: None,
},
Input {
name: "submit".to_owned(),
label: None,
r#type: InputType::Submit(InputTypeSubmit::default()),
value: Some(InputValue::String {
value: format!("Sign in with {}", provider.name()),
}),
addon_start: provider
.icon_url
.as_ref()
.map(|icon_url| InputAddon::Image {
alt: format!("{} logo", provider.name()),
src: icon_url.clone(),
}),
addon_end: None,
},
],
}])
}
async fn call(
&self,
provider: OidcProvider,
_session: &MethodSession<OidcSession>,
request: Request,
) -> Result<Response, ShieldError> {
let data = serde_json::from_value::<SignInData>(request.form_data)
.map_err(|err| ShieldError::Validation(err.to_string()))?;
let redirect_url = data
.redirect_url
.map(|redirect_url| data.redirect_origin.join(&redirect_url))
.unwrap_or_else(|| data.redirect_origin.join(&self.options.sign_in_redirect))
.map_err(|err| ShieldError::Validation(format!("redirect URL parse error: {err}")))?;
if let Some(redirect_origins) = &self.options.redirect_origins {
let redirect_origin = Url::parse(&redirect_url.origin().ascii_serialization())
.map_err(|err| {
ShieldError::Validation(format!("redirect origin parse error: {err}"))
})?;
if !redirect_origins.contains(&redirect_origin) {
return Err(ShieldError::Validation(format!(
"redirect origin `{redirect_origin}` not allowed"
)));
}
}
if let Some(redirect_patterns) = &self.options.redirect_patterns {
let redirect_url_str = redirect_url.to_string();
if !redirect_patterns
.iter()
.any(|pattern| pattern.is_match(&redirect_url_str))
{
return Err(ShieldError::Validation(format!(
"redirect URL `{redirect_url}` not allowed"
)));
}
}
let client = provider.oidc_client().await?;
let mut authorization_request = client.authorize_url(
CoreAuthenticationFlow::AuthorizationCode,
CsrfToken::new_random,
Nonce::new_random,
);
let pkce_code_challenge = match provider.pkce_code_challenge {
OidcProviderPkceCodeChallenge::None => None,
OidcProviderPkceCodeChallenge::Plain => Some(PkceCodeChallenge::new_random_plain()),
OidcProviderPkceCodeChallenge::S256 => Some(PkceCodeChallenge::new_random_sha256()),
};
if let Some((pkce_code_challenge, _)) = &pkce_code_challenge {
authorization_request =
authorization_request.set_pkce_challenge(pkce_code_challenge.clone());
}
if let Some(scopes) = provider.scopes {
authorization_request =
authorization_request.add_scopes(scopes.into_iter().map(Scope::new));
}
if let Some(authorization_url_params) = provider.authorization_url_params {
let params = parse(authorization_url_params.trim_start_matches('?').as_bytes());
for (name, value) in params {
authorization_request =
authorization_request.add_extra_param(name.into_owned(), value.into_owned());
}
}
let (auth_url, csrf_token, nonce) = authorization_request.url();
Ok(Response::new(ResponseType::Redirect(auth_url.to_string()))
.session_action(SessionAction::unauthenticate())
.session_action(SessionAction::data(OidcSession {
redirect_url: Some(redirect_url),
csrf: Some(csrf_token.secret().clone()),
nonce: Some(nonce.secret().clone()),
pkce_verifier: pkce_code_challenge
.map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()),
oidc_connection_id: None,
})?))
}
}
erased_action!(OidcSignInAction);