use crate::build_lax_cookie_300;
use crate::cookie_state::{OIDC_STATE_COOKIE, OidcCookieState};
use crate::provider::OidcProvider;
use crate::rauthy_error::RauthyError;
use crate::token_set::OidcTokenSet;
use crate::tokens::claims::{AccessToken, IdToken};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use tracing::error;
#[cfg(feature = "axum")]
pub mod axum;
#[cfg(feature = "actix-web")]
pub mod actix_web;
/// Parameters the provider sends back on the callback redirect (presumably taken from the
/// query string by the framework integration — confirm at the call site).
///
/// `state` is compared against the state held in the OIDC state cookie, and `code` is the
/// authorization code that `oidc_callback` redeems at the token endpoint.
#[derive(Deserialize, Debug)]
pub struct OidcCallbackParams {
    /// Authorization code issued by the provider.
    pub code: String,
    /// State value echoed back by the provider.
    pub state: String,
}
/// Form body sent to the provider's token endpoint to redeem an authorization code
/// (`grant_type` is set to `"authorization_code"` by `try_new`).
///
/// `Debug` is implemented by hand instead of derived so that the client secret, the
/// authorization code and the PKCE verifier are never written to logs in plain text.
#[derive(Serialize)]
struct OidcCodeRequestParams {
    client_id: String,
    client_secret: Option<String>,
    code: String,
    code_verifier: String,
    grant_type: &'static str,
    redirect_uri: String,
}

impl std::fmt::Debug for OidcCodeRequestParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        f.debug_struct("OidcCodeRequestParams")
            .field("client_id", &self.client_id)
            // Keep the Some/None distinction visible without leaking the secret itself.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| REDACTED),
            )
            .field("code", &REDACTED)
            .field("code_verifier", &REDACTED)
            .field("grant_type", &self.grant_type)
            .field("redirect_uri", &self.redirect_uri)
            .finish()
    }
}
/// Whether the OIDC state cookie is built in "insecure" mode.
///
/// Passed through to `build_lax_cookie_300`; presumably `Yes` omits the cookie's `Secure`
/// attribute for non-HTTPS setups — NOTE(review): confirm in that helper.
///
/// `Clone`/`Copy`/`Eq` are derived so this plain flag can be passed by value and reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OidcCookieInsecure {
    Yes,
    No,
}
/// Yes/no flag for whether a redirect status should be set.
///
/// NOTE(review): not used in this file; its exact meaning is presumably defined by the
/// framework integrations (`axum` / `actix_web` submodules) — confirm there.
///
/// `Clone`/`Copy`/`Eq` are derived so this plain flag can be passed by value and reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OidcSetRedirectStatus {
    Yes,
    No,
}
impl OidcCodeRequestParams {
    /// Builds the token-endpoint form body for an `authorization_code` grant.
    ///
    /// `client_id` and the optional client secret are copied from the global
    /// `OidcProvider` configuration; the remaining fields come from the arguments.
    ///
    /// # Errors
    /// Propagates any error returned by `OidcProvider::config()`.
    pub async fn try_new(
        code: String,
        code_verifier: String,
        redirect_uri: String,
    ) -> Result<Self, RauthyError> {
        let provider_cfg = OidcProvider::config()?;
        Ok(Self {
            client_id: provider_cfg.client_id.clone(),
            client_secret: provider_cfg.secret.clone(),
            code,
            code_verifier,
            grant_type: "authorization_code",
            redirect_uri,
        })
    }
}
/// Framework-agnostic login check, compiled only when neither the `axum` nor the
/// `actix-web` feature is enabled.
///
/// Only the presence of `principal` is checked:
/// - `Some(_)` → `Ok(())`.
/// - `None` → a fresh `OidcCookieState` and code challenge are generated, and the result is
///   `Err(Some((location, cookie)))`. `location` is the provider's `auth_url_base` extended
///   with `code_challenge`, `nonce` and `state` parameters. `cookie` is the
///   `OIDC_STATE_COOKIE` whose value comes from `to_encrypted_cookie_value(enc_key)`.
///
/// Returns `Err(None)` if the provider configuration cannot be obtained; the underlying
/// error is discarded.
#[cfg(not(any(feature = "axum", feature = "actix-web")))]
pub async fn validate_principal_generic(
    principal: Option<crate::principal::PrincipalOidc>,
    enc_key: &[u8],
    insecure: OidcCookieInsecure,
) -> Result<(), Option<(String, String)>> {
    if principal.is_some() {
        return Ok(());
    }

    let (cookie_state, challenge) = OidcCookieState::generate();

    let provider_cfg = match OidcProvider::config() {
        Ok(cfg) => cfg,
        Err(_) => return Err(None),
    };
    let base = &provider_cfg.auth_url_base;
    let loc = format!(
        "{base}&code_challenge={challenge}&nonce={}&state={}",
        cookie_state.nonce, cookie_state.state
    );

    let encrypted = cookie_state.to_encrypted_cookie_value(enc_key);
    Err(Some((
        loc,
        build_lax_cookie_300(OIDC_STATE_COOKIE, &encrypted, insecure),
    )))
}
pub async fn oidc_callback(
cookie_state: OidcCookieState,
params: OidcCallbackParams,
insecure: OidcCookieInsecure,
) -> Result<(String, OidcTokenSet, IdToken), RauthyError> {
if params.state != cookie_state.state {
return Err(RauthyError::BadRequest("Bad state"));
}
let (token_uri, redirect_uri) = {
let cfg = OidcProvider::config()?;
let t = cfg.provider.token_endpoint.clone();
let r = cfg.redirect_uri.clone();
(t, r)
};
let req_data = OidcCodeRequestParams::try_new(
params.code.clone(),
cookie_state.pkce_verifier,
redirect_uri,
)
.await?;
let res = OidcProvider::client()
.post(&token_uri)
.form(&req_data)
.send()
.await?;
if res.status().as_u16() >= 300 {
error!("{:?}", res);
let body = res.text().await;
let msg = match body {
Ok(value) => {
error!("raw OIDC provider response: {:?}", value);
value
}
Err(_) => "Internal Error - Bad response status".to_string(),
};
Err(RauthyError::Provider(Cow::from(msg)))
} else {
match res.json::<OidcTokenSet>().await {
Ok(ts) => {
let access_claims = AccessToken::from_token_validated(&ts.access_token).await?;
if ts.id_token.is_none() {
return Err(RauthyError::Provider(Cow::from("ID token is missing")));
}
let id_claims = IdToken::from_token_validated(
ts.id_token.as_deref().unwrap(),
&cookie_state.nonce,
)
.await?;
if access_claims.common.sub.is_none()
|| access_claims.common.sub != id_claims.common.sub
{
return Err(RauthyError::InvalidClaims("Invalid `sub` claims"));
}
let cookie = build_lax_cookie_300(OIDC_STATE_COOKIE, "", insecure);
Ok((cookie, ts, id_claims))
}
Err(err) => {
error!("Deserializing OIDC response to OidcTokenSet: {}", err);
Err(RauthyError::Provider(Cow::from(
"Internal Error - Deserializing OIDC response",
)))
}
}
}
}