use super::StdbOidcAuthOptions;
use crate::{AUTHORIZATION_CODE_GRANT_TYPE, AUTHORIZATION_ENDPOINT, error::StdbAuthError};
use oauth2::{CsrfToken, PkceCodeChallenge};
use std::collections::BTreeMap;
use url::Url;
/// Result of starting an authorization-code + PKCE login with
/// `build_authorization_request`.
///
/// `authorization_url` is the URL to send the user to. `state` and
/// `pkce_verifier` are intended to be kept until the provider redirects back,
/// where they feed `parse_callback_url` (as `expected_state`) and
/// `authorization_code_token_form` (as `pkce_verifier`).
pub(super) struct StdbOidcAuthorizationRequest {
    /// Authorization endpoint with `response_type`, `client_id`, `redirect_uri`,
    /// `state`, `code_challenge`, `code_challenge_method` and, when set, `scope`
    /// and `prompt` query parameters attached.
    pub(super) authorization_url: Url,
    /// Random anti-CSRF value that was sent as the `state` query parameter.
    pub(super) state: String,
    /// Secret PKCE code verifier. Only its challenge appears in
    /// `authorization_url`; the verifier itself is sent later as `code_verifier`.
    pub(super) pkce_verifier: String,
}
/// Authorization code extracted from a validated callback URL.
pub(super) struct StdbOidcAuthorizationCode {
    /// Value of the callback's `code` query parameter. `parse_callback_url`
    /// rejects blank codes but stores the value as received (untrimmed).
    pub(super) code: String,
}
/// Form parameters for the authorization-code token request.
pub(super) struct StdbOidcTokenRequestForm {
    /// Parameter name -> value. Being a `BTreeMap`, iteration is sorted by name.
    /// NOTE(review): presumably serialized as a form-urlencoded body by the
    /// HTTP caller — confirm at the call site.
    pub(super) params: BTreeMap<String, String>,
}
/// Builds the browser authorization URL for an authorization-code flow with
/// PKCE (S256 challenge).
///
/// A fresh random `state` and PKCE verifier are generated on every call and
/// returned next to the URL. The caller needs `state` to validate the callback
/// and the verifier for the token exchange.
///
/// # Errors
/// Returns `StdbAuthError::InvalidConfig` when `client_id` is blank, or when
/// `redirect_uri` is blank, unparseable, or includes a fragment.
///
/// # Panics
/// Only if the constant `AUTHORIZATION_ENDPOINT` is not a valid URL.
pub(super) fn build_authorization_request(
    options: &StdbOidcAuthOptions,
) -> Result<StdbOidcAuthorizationRequest, StdbAuthError> {
    // Validate caller configuration before generating any secrets.
    let client_id = require_non_empty(&options.client_id, "client_id")?;
    let redirect_uri = validate_redirect_uri(&options.redirect_uri)?;
    let scopes = normalized_scopes(&options.scopes);
    // `scope` is omitted entirely (rather than sent empty) when no scopes remain.
    let scope = (!scopes.is_empty()).then(|| scopes.join(" "));

    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
    let csrf_state = CsrfToken::new_random();

    let mut authorization_url = Url::parse(AUTHORIZATION_ENDPOINT)
        .expect("static SpacetimeAuth authorization endpoint must be valid");
    {
        // The serializer borrows `authorization_url` mutably, so keep it in its
        // own scope before the URL is moved into the result below.
        let mut query = authorization_url.query_pairs_mut();
        query
            .append_pair("response_type", "code")
            .append_pair("client_id", &client_id)
            .append_pair("redirect_uri", redirect_uri.as_str())
            .append_pair("state", csrf_state.secret())
            .append_pair("code_challenge", pkce_challenge.as_str())
            .append_pair("code_challenge_method", pkce_challenge.method().as_str());
        if let Some(scope) = scope.as_deref() {
            query.append_pair("scope", scope);
        }
        if let Some(prompt) = options.prompt.as_param() {
            query.append_pair("prompt", prompt);
        }
    }

    Ok(StdbOidcAuthorizationRequest {
        authorization_url,
        state: csrf_state.into_secret(),
        pkce_verifier: pkce_verifier.into_secret(),
    })
}
pub(super) fn parse_callback_url(
callback_url: &str,
expected_state: &str,
) -> Result<StdbOidcAuthorizationCode, StdbAuthError> {
let callback_url = Url::parse(callback_url).map_err(|error| {
StdbAuthError::InvalidOidcCallback(format!("callback URL is invalid: {error}"))
})?;
let state = query_param(&callback_url, "state").ok_or_else(|| {
StdbAuthError::InvalidOidcCallback("callback is missing `state`".to_string())
})?;
if state != expected_state {
return Err(StdbAuthError::InvalidOidcCallback(
"callback `state` does not match the pending authorization".to_string(),
));
}
if let Some(error) = query_param(&callback_url, "error") {
let description = query_param(&callback_url, "error_description");
return Err(StdbAuthError::Provider(format_provider_error(
&error,
description.as_deref(),
)));
}
let code = query_param(&callback_url, "code").ok_or_else(|| {
StdbAuthError::InvalidOidcCallback("callback is missing `code`".to_string())
})?;
if code.trim().is_empty() {
return Err(StdbAuthError::InvalidOidcCallback(
"callback `code` must not be empty".to_string(),
));
}
Ok(StdbOidcAuthorizationCode { code })
}
/// Builds the form parameters for exchanging an authorization code at the
/// token endpoint.
///
/// The form contains `grant_type`, `code`, `redirect_uri`, `client_id` and
/// `code_verifier`. String inputs are trimmed and must be non-empty.
/// `redirect_uri` is re-validated with `validate_redirect_uri`, exactly as in
/// `build_authorization_request`.
///
/// # Errors
/// Returns `StdbAuthError::InvalidConfig` for the first invalid input, checked
/// in the order `code`, `redirect_uri`, `client_id`, `code_verifier`.
pub(super) fn authorization_code_token_form(
    options: &StdbOidcAuthOptions,
    code: &str,
    pkce_verifier: &str,
) -> Result<StdbOidcTokenRequestForm, StdbAuthError> {
    // Validation order determines which error is reported first.
    let code = require_non_empty(code, "code")?;
    let redirect_uri = validate_redirect_uri(&options.redirect_uri)?.to_string();
    let client_id = require_non_empty(&options.client_id, "client_id")?;
    let code_verifier = require_non_empty(pkce_verifier, "code_verifier")?;

    let params = BTreeMap::from([
        (
            "grant_type".to_string(),
            AUTHORIZATION_CODE_GRANT_TYPE.to_string(),
        ),
        ("code".to_string(), code),
        ("redirect_uri".to_string(), redirect_uri),
        ("client_id".to_string(), client_id),
        ("code_verifier".to_string(), code_verifier),
    ]);
    Ok(StdbOidcTokenRequestForm { params })
}
/// Trims `value` and returns it as an owned `String`.
///
/// # Errors
/// Returns `StdbAuthError::InvalidConfig`, naming `field` in the message, when
/// the input is empty or whitespace-only.
fn require_non_empty(value: &str, field: &'static str) -> Result<String, StdbAuthError> {
    match value.trim() {
        "" => Err(StdbAuthError::InvalidConfig(format!(
            "`{field}` must not be empty"
        ))),
        trimmed => Ok(trimmed.to_owned()),
    }
}
/// Validates the configured redirect URI and returns it parsed.
///
/// The value is trimmed, must be non-empty, must be accepted by `Url::parse`,
/// and must not carry a fragment.
///
/// # Errors
/// Returns `StdbAuthError::InvalidConfig` describing whichever check failed.
fn validate_redirect_uri(redirect_uri: &str) -> Result<Url, StdbAuthError> {
    let trimmed = require_non_empty(redirect_uri, "redirect_uri")?;
    let parsed = Url::parse(&trimmed).map_err(|parse_error| {
        StdbAuthError::InvalidConfig(format!("`redirect_uri` is invalid: {parse_error}"))
    })?;
    if parsed.fragment().is_none() {
        Ok(parsed)
    } else {
        Err(StdbAuthError::InvalidConfig(
            "`redirect_uri` must not include a fragment".to_string(),
        ))
    }
}
/// Normalizes configured scopes into a list of distinct scope tokens.
///
/// OAuth scopes are space-delimited, case-sensitive tokens (RFC 6749 §3.3).
/// Each entry is therefore split on whitespace, which also covers entries such
/// as `"openid  email"`. Blank entries disappear, and exact duplicates are
/// dropped while keeping the first occurrence's position.
fn normalized_scopes(scopes: &[String]) -> Vec<String> {
    let mut seen = std::collections::HashSet::new();
    scopes
        .iter()
        .flat_map(|entry| entry.split_whitespace())
        .filter(|scope| seen.insert(*scope))
        .map(str::to_string)
        .collect()
}
/// Returns the value of the first query pair whose key equals `name`, or
/// `None` when the URL has no such pair.
fn query_param(url: &Url, name: &str) -> Option<String> {
    url.query_pairs()
        .find(|(key, _)| key == name)
        .map(|(_, value)| value.into_owned())
}
/// Renders a provider `error` / `error_description` pair as one message.
///
/// Both parts are trimmed and a blank description is omitted, which gives
/// `"<error>: <description>"` or just `"<error>"`. A blank error code is
/// replaced by `unknown_error`, so the message is never empty.
fn format_provider_error(error: &str, description: Option<&str>) -> String {
    let error = match error.trim() {
        "" => "unknown_error",
        code => code,
    };
    match description.map(str::trim).filter(|text| !text.is_empty()) {
        Some(text) => format!("{error}: {text}"),
        None => error.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oidc::StdbOidcPrompt;

    /// Valid baseline options; individual tests override single fields.
    fn auth_options() -> StdbOidcAuthOptions {
        StdbOidcAuthOptions {
            client_id: "client".to_string(),
            redirect_uri: "http://127.0.0.1:3000/callback".to_string(),
            post_logout_redirect_uri: None,
            scopes: vec!["openid".to_string(), "email".to_string()],
            prompt: StdbOidcPrompt::Login,
        }
    }

    fn query_map(url: &Url) -> BTreeMap<String, String> {
        url.query_pairs()
            .map(|(key, value)| (key.into_owned(), value.into_owned()))
            .collect()
    }

    fn form_map(form: StdbOidcTokenRequestForm) -> BTreeMap<String, String> {
        form.params
    }

    #[test]
    fn authorization_request_contains_oidc_parameters() {
        let request = build_authorization_request(&auth_options())
            .expect("authorization request should be valid");
        let query = query_map(&request.authorization_url);
        assert_eq!(
            request.authorization_url.as_str().split('?').next(),
            Some("https://auth.spacetimedb.com/oidc/auth")
        );
        assert_eq!(query.get("response_type").map(String::as_str), Some("code"));
        assert_eq!(query.get("client_id").map(String::as_str), Some("client"));
        assert_eq!(
            query.get("redirect_uri").map(String::as_str),
            Some("http://127.0.0.1:3000/callback")
        );
        assert_eq!(query.get("scope").map(String::as_str), Some("openid email"));
        assert_eq!(query.get("prompt").map(String::as_str), Some("login"));
        assert_eq!(
            query.get("code_challenge_method").map(String::as_str),
            Some("S256")
        );
        assert_eq!(query.get("state"), Some(&request.state));
        assert!(query.contains_key("code_challenge"));
        assert!(request.pkce_verifier.len() >= 43);
    }

    #[test]
    fn authorization_request_generates_fresh_secrets_each_call() {
        let first = build_authorization_request(&auth_options())
            .expect("authorization request should be valid");
        let second = build_authorization_request(&auth_options())
            .expect("authorization request should be valid");
        assert_ne!(first.state, second.state);
        assert_ne!(first.pkce_verifier, second.pkce_verifier);
    }

    #[test]
    fn authorization_request_omits_scope_when_all_scopes_are_blank() {
        let options = StdbOidcAuthOptions {
            scopes: vec!["   ".to_string(), String::new()],
            ..auth_options()
        };
        let request =
            build_authorization_request(&options).expect("authorization request should be valid");
        assert!(!query_map(&request.authorization_url).contains_key("scope"));
    }

    #[test]
    fn authorization_request_rejects_blank_client_id() {
        let options = StdbOidcAuthOptions {
            client_id: "   ".to_string(),
            ..auth_options()
        };
        assert!(matches!(
            build_authorization_request(&options),
            Err(StdbAuthError::InvalidConfig(_))
        ));
    }

    #[test]
    fn authorization_request_rejects_invalid_redirect_uri() {
        for redirect_uri in ["", "not a url", "http://127.0.0.1:3000/callback#fragment"] {
            let options = StdbOidcAuthOptions {
                redirect_uri: redirect_uri.to_string(),
                ..auth_options()
            };
            assert!(
                matches!(
                    build_authorization_request(&options),
                    Err(StdbAuthError::InvalidConfig(_))
                ),
                "redirect_uri {redirect_uri:?} should be rejected"
            );
        }
    }

    #[test]
    fn callback_url_returns_authorization_code() {
        let callback = parse_callback_url(
            "http://127.0.0.1:3000/callback?code=abc&state=state",
            "state",
        )
        .expect("callback should be valid");
        assert_eq!(callback.code, "abc");
    }

    #[test]
    fn callback_url_rejects_state_mismatch() {
        let result = parse_callback_url(
            "http://127.0.0.1:3000/callback?code=abc&state=other",
            "state",
        );
        assert!(matches!(result, Err(StdbAuthError::InvalidOidcCallback(_))));
    }

    #[test]
    fn callback_url_rejects_missing_state() {
        let result = parse_callback_url("http://127.0.0.1:3000/callback?code=abc", "state");
        assert!(matches!(result, Err(StdbAuthError::InvalidOidcCallback(_))));
    }

    #[test]
    fn callback_url_rejects_missing_or_blank_code() {
        for callback in [
            "http://127.0.0.1:3000/callback?state=state",
            "http://127.0.0.1:3000/callback?code=%20%20&state=state",
        ] {
            assert!(
                matches!(
                    parse_callback_url(callback, "state"),
                    Err(StdbAuthError::InvalidOidcCallback(_))
                ),
                "callback {callback:?} should be rejected"
            );
        }
    }

    #[test]
    fn callback_url_rejects_unparseable_url() {
        let result = parse_callback_url("not a url", "state");
        assert!(matches!(result, Err(StdbAuthError::InvalidOidcCallback(_))));
    }

    #[test]
    fn callback_url_returns_provider_error() {
        let result = parse_callback_url(
            "http://127.0.0.1:3000/callback?error=access_denied&error_description=nope&state=state",
            "state",
        );
        match result {
            Err(StdbAuthError::Provider(message)) => assert_eq!(message, "access_denied: nope"),
            _ => panic!("expected a provider error"),
        }
    }

    #[test]
    fn callback_url_checks_state_before_provider_error() {
        // An error response carrying a foreign `state` must not be surfaced
        // as a provider error for the pending request.
        let result = parse_callback_url(
            "http://127.0.0.1:3000/callback?error=access_denied&state=other",
            "state",
        );
        assert!(matches!(result, Err(StdbAuthError::InvalidOidcCallback(_))));
    }

    #[test]
    fn authorization_code_token_form_contains_required_fields() {
        let form = authorization_code_token_form(&auth_options(), "code", "verifier")
            .expect("authorization-code token form should be valid");
        let form = form_map(form);
        assert_eq!(
            form.get("grant_type").map(String::as_str),
            Some("authorization_code")
        );
        assert_eq!(form.get("code").map(String::as_str), Some("code"));
        assert_eq!(form.get("client_id").map(String::as_str), Some("client"));
        assert_eq!(
            form.get("code_verifier").map(String::as_str),
            Some("verifier")
        );
        assert_eq!(
            form.get("redirect_uri").map(String::as_str),
            Some("http://127.0.0.1:3000/callback")
        );
    }

    #[test]
    fn authorization_code_token_form_trims_inputs() {
        let options = StdbOidcAuthOptions {
            client_id: "  client  ".to_string(),
            ..auth_options()
        };
        let form = form_map(
            authorization_code_token_form(&options, "  code  ", " verifier ")
                .expect("authorization-code token form should be valid"),
        );
        assert_eq!(form.get("code").map(String::as_str), Some("code"));
        assert_eq!(form.get("client_id").map(String::as_str), Some("client"));
        assert_eq!(
            form.get("code_verifier").map(String::as_str),
            Some("verifier")
        );
    }

    #[test]
    fn authorization_code_token_form_rejects_blank_code_or_verifier() {
        assert!(authorization_code_token_form(&auth_options(), "   ", "verifier").is_err());
        assert!(authorization_code_token_form(&auth_options(), "code", "   ").is_err());
    }
}