use super::helpers::normalize_issuer;
use crate::authn::service::AuthnService;
use crate::authn::store::{FactorStore, IdentityStore};
use crate::session::extractor::AuthSession;
impl<I, F> AuthnService<I, F>
where
I: IdentityStore,
F: FactorStore<Error = I::Error>,
{
/// Starts an OAuth login ceremony against the provider registered as
/// `provider_name`.
///
/// On success returns the authorization URL together with the CSRF state
/// value that was generated for this ceremony.
///
/// The steps, in order:
/// 1. Look the provider up in `self.oauth_providers`.
/// 2. If the provider has a FAPI configuration, inspect the session's
///    `PAR_INFLIGHT` marker. When that marker is an object whose
///    `"expires_at"` string parses as RFC 3339 and lies in the future
///    (according to `self.clock`), the call is refused. A missing or
///    malformed marker does not block the call.
/// 3. Clear any previous ceremony state via `clear_oauth_state`.
/// 4. Build the authorization URL (`build_auth_url_par` for providers with
///    a FAPI configuration, `build_auth_url` otherwise). On the PAR path a
///    fresh in-flight marker is written right after the URL was built.
/// 5. Persist the ceremony state (PKCE verifier, CSRF state, nonce,
///    provider name and issuer) in the session.
///
/// # Errors
///
/// * `OAuthError::UnknownProvider` when no provider is registered under
///   `provider_name`.
/// * `OAuthError::CsrfMismatch` when an unexpired PAR in-flight marker is
///   present (step 2).
/// * Any error returned by `build_auth_url_par` / `build_auth_url`. Note
///   that the previous ceremony state has already been cleared by then.
#[tracing::instrument(skip(self, options, session))]
pub async fn begin_oauth_login(
    &self,
    provider_name: &str,
    options: &axess_factors::oauth::OAuthLoginOptions,
    session: &AuthSession,
) -> Result<(url::Url, String), axess_factors::oauth::OAuthError> {
    use axess_factors::oauth::{OAuthError, keys as oauth_keys};

    let Some(provider) = self.oauth_providers.get(provider_name) else {
        return Err(OAuthError::UnknownProvider(provider_name.to_string()));
    };

    // Providers with a FAPI configuration go through the PAR flow.
    let uses_par = provider.fapi_config().is_some();

    if uses_par
        && let Some(serde_json::Value::Object(fields)) =
            session.get_custom(oauth_keys::PAR_INFLIGHT).await
    {
        // Only a well-formed, still-unexpired marker blocks a new begin;
        // anything else falls through and is overwritten below.
        let window_still_open = fields
            .get("expires_at")
            .and_then(|raw| raw.as_str())
            .and_then(|raw| chrono::DateTime::parse_from_rfc3339(raw).ok())
            .is_some_and(|deadline| self.clock.now() < deadline.with_timezone(&chrono::Utc));

        if window_still_open {
            tracing::warn!(
                provider = %provider_name,
                "refusing PAR begin while a previous request_uri is still in its \
                 single-use window; clear the session ceremony state and retry"
            );
            return Err(OAuthError::CsrfMismatch);
        }
    }

    // Drop leftovers from an earlier ceremony before starting a new one.
    self.clear_oauth_state(session).await;

    let (auth_url, csrf_state, nonce, pkce_verifier) = if uses_par {
        let built = provider.build_auth_url_par(options).await?;
        // Record the in-flight marker only once the PAR URL was built.
        self.stash_par_inflight_marker(session, provider.ceremony_timeout())
            .await;
        built
    } else {
        provider.build_auth_url(options)?
    };

    self.stash_oauth_ceremony_state(
        session,
        pkce_verifier,
        &csrf_state,
        nonce,
        provider_name,
        provider.issuer(),
    )
    .await;

    Ok((auth_url, csrf_state))
}
/// Starts an OAuth login ceremony like [`Self::begin_oauth_login`] and
/// additionally records the tenant the login is expected to resolve to.
///
/// The tenant is written to the session under `EXPECTED_TENANT` only
/// after `begin_oauth_login` has succeeded, so it is stored after that
/// call has cleared and re-populated the ceremony state.
///
/// Returns the same `(authorization URL, CSRF state)` pair as
/// `begin_oauth_login`.
///
/// # Errors
///
/// Propagates every error of `begin_oauth_login`; in that case the
/// expected tenant is not written.
pub async fn begin_oauth_login_in_tenant(
    &self,
    provider_name: &str,
    options: &axess_factors::oauth::OAuthLoginOptions,
    expected_tenant: &crate::authn::ids::TenantId,
    session: &AuthSession,
) -> Result<(url::Url, String), axess_factors::oauth::OAuthError> {
    use axess_factors::oauth::types::keys as oauth_keys;

    let result = self
        .begin_oauth_login(provider_name, options, session)
        .await?;

    // One `to_string()` is enough; a second call would only clone the
    // freshly built `String`.
    session
        .set_custom(
            oauth_keys::EXPECTED_TENANT,
            serde_json::Value::String(expected_tenant.to_string()),
        )
        .await;

    Ok(result)
}
async fn stash_oauth_ceremony_state(
&self,
session: &AuthSession,
pkce_verifier: String,
csrf_state: &str,
nonce: String,
provider_name: &str,
provider_issuer: Option<&str>,
) {
use axess_factors::oauth::types::keys as oauth_keys;
let str_val = |s: String| serde_json::Value::String(s);
session
.set_custom(oauth_keys::PKCE_VERIFIER, str_val(pkce_verifier))
.await;
session
.set_custom(oauth_keys::CSRF_STATE, str_val(csrf_state.to_string()))
.await;
session.set_custom(oauth_keys::NONCE, str_val(nonce)).await;
session
.set_custom(oauth_keys::PROVIDER, str_val(provider_name.to_string()))
.await;
session
.set_custom(oauth_keys::STARTED, str_val(self.clock.now().to_rfc3339()))
.await;
if let Some(issuer) = provider_issuer {
session
.set_custom(
oauth_keys::PROVIDER_ISSUER,
str_val(normalize_issuer(issuer)),
)
.await;
}
}
/// Stores the PAR in-flight marker in the session.
///
/// The marker is a JSON object with a single `"expires_at"` entry holding
/// an RFC 3339 timestamp. The expiry is `self.clock.now()` plus
/// `ceremony_timeout`, with the timeout capped at 600 seconds. Should the
/// capped timeout fail to convert into a `chrono::Duration` (not expected
/// for a value of at most 600 seconds), 90 seconds are used instead.
async fn stash_par_inflight_marker(
    &self,
    session: &AuthSession,
    ceremony_timeout: std::time::Duration,
) {
    use axess_factors::oauth::types::keys as oauth_keys;

    /// Upper bound for how long the marker may stay valid.
    const MAX_MARKER_LIFETIME: std::time::Duration = std::time::Duration::from_secs(600);
    /// Lifetime in seconds used when the duration conversion fails.
    const FALLBACK_LIFETIME_SECS: i64 = 90;

    let capped = std::cmp::min(ceremony_timeout, MAX_MARKER_LIFETIME);
    let window = match chrono::Duration::from_std(capped) {
        Ok(converted) => converted,
        Err(_) => chrono::Duration::seconds(FALLBACK_LIFETIME_SECS),
    };
    let deadline = self.clock.now() + window;

    let marker: serde_json::Map<String, serde_json::Value> = [(
        "expires_at".to_string(),
        serde_json::Value::String(deadline.to_rfc3339()),
    )]
    .into_iter()
    .collect();

    session
        .set_custom(oauth_keys::PAR_INFLIGHT, serde_json::Value::Object(marker))
        .await;
}
}