use super::models::{
AuthorizeFlowQuery, CallbackRequest, OAuth2ErrorResponse, OAuth2TokenResponse, OAuthCodeState,
OAuthExtensionSpec, OAuthExtensionStatus, OAuthState, TokenRequest,
};
use crate::db::{extensions as db_extensions, projects as db_projects};
use crate::server::state::AppState;
use axum::{
extract::{Path, Query, State},
http::{header, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Redirect, Response},
Json,
};
use base64::Engine;
use chrono::Utc;
use std::sync::Arc;
use tracing::{debug, error, info, warn};
use url::Url;
/// Subset of an OpenID Connect discovery document (`<issuer>/.well-known/openid-configuration`)
/// that this module reads.
///
/// Every field is optional so that documents omitting an endpoint still deserialize; callers
/// (e.g. `resolve_oauth_endpoints`) decide whether a missing value is an error.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
struct OidcDiscoveryDocument {
/// The provider's authorization endpoint, if advertised.
authorization_endpoint: Option<String>,
/// The provider's token endpoint, if advertised.
token_endpoint: Option<String>,
/// The provider's JWKS URI, if advertised.
jwks_uri: Option<String>,
}
/// Upstream OAuth endpoints for an extension, after merging explicitly configured spec values
/// (which take precedence) with OIDC discovery.
#[derive(Debug, Clone)]
struct ResolvedEndpoints {
/// URL the user agent is redirected to in order to authorize.
authorization_endpoint: String,
/// URL that authorization codes are exchanged at for tokens.
token_endpoint: String,
}
async fn fetch_oidc_discovery(issuer_url: &str) -> Result<OidcDiscoveryDocument, String> {
let discovery_url = format!(
"{}/.well-known/openid-configuration",
issuer_url.trim_end_matches('/')
);
let http_client = reqwest::Client::new();
let response = http_client.get(&discovery_url).send().await.map_err(|e| {
error!(
"Failed to fetch OIDC discovery from {}: {:?}",
discovery_url, e
);
format!("Failed to fetch OIDC discovery: {}", e)
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unable to read error response".to_string());
error!(
"OIDC discovery failed with status {}: {}",
status, error_text
);
return Err(format!("OIDC discovery failed: {}", status));
}
response.json().await.map_err(|e| {
error!("Failed to parse OIDC discovery response: {:?}", e);
format!("Failed to parse OIDC discovery: {}", e)
})
}
/// Determines the upstream authorization and token endpoints for an OAuth extension.
///
/// When the spec configures both endpoints they are used directly and no network request is
/// made. Otherwise the issuer's OIDC discovery document is fetched once, and each endpoint is
/// taken from the spec when configured, falling back to the discovered value.
///
/// # Errors
/// Propagates discovery failures, and reports which endpoint is missing when neither the spec
/// nor discovery supplies it.
async fn resolve_oauth_endpoints(spec: &OAuthExtensionSpec) -> Result<ResolvedEndpoints, String> {
    match (&spec.authorization_endpoint, &spec.token_endpoint) {
        (Some(configured_auth), Some(configured_token)) => Ok(ResolvedEndpoints {
            authorization_endpoint: configured_auth.clone(),
            token_endpoint: configured_token.clone(),
        }),
        (configured_auth, configured_token) => {
            let discovered = fetch_oidc_discovery(&spec.issuer_url).await?;
            let authorization_endpoint = match configured_auth {
                Some(value) => value.clone(),
                None => discovered.authorization_endpoint.ok_or_else(|| {
                    "No authorization_endpoint in spec or OIDC discovery".to_string()
                })?,
            };
            let token_endpoint = match configured_token {
                Some(value) => value.clone(),
                None => discovered
                    .token_endpoint
                    .ok_or_else(|| "No token_endpoint in spec or OIDC discovery".to_string())?,
            };
            Ok(ResolvedEndpoints {
                authorization_endpoint,
                token_endpoint,
            })
        }
    }
}
/// Returns a fresh 64-character lowercase-hex token encoding 32 bytes drawn from
/// `rand::thread_rng()`.
///
/// Used both as the OAuth `state` value sent upstream and as Rise-issued authorization codes.
fn generate_state_token() -> String {
    use rand::Rng;
    const TOKEN_BYTES: usize = 32;
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut rng = rand::thread_rng();
    let mut token = String::with_capacity(TOKEN_BYTES * 2);
    for _ in 0..TOKEN_BYTES {
        let byte: u8 = rng.gen();
        // High nibble first, matching `{:02x}` formatting.
        token.push(HEX_DIGITS[usize::from(byte >> 4)] as char);
        token.push(HEX_DIGITS[usize::from(byte & 0x0f)] as char);
    }
    token
}
/// Generates a 128-character PKCE `code_verifier` for Rise's own upstream exchange.
///
/// Each character is drawn uniformly from the RFC 7636 unreserved set
/// (`A-Z a-z 0-9 - . _ ~`) using `rand::thread_rng()`.
fn generate_code_verifier() -> String {
    use rand::Rng;
    const UNRESERVED: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";
    const VERIFIER_LEN: usize = 128;

    let mut rng = rand::thread_rng();
    let mut verifier = String::with_capacity(VERIFIER_LEN);
    for _ in 0..VERIFIER_LEN {
        let pick = rng.gen_range(0..UNRESERVED.len());
        verifier.push(char::from(UNRESERVED[pick]));
    }
    verifier
}
/// Derives the PKCE `S256` code challenge for `verifier`: the SHA-256 digest of its bytes,
/// encoded as unpadded base64url (always 43 characters).
fn generate_code_challenge(verifier: &str) -> String {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(verifier.as_bytes());
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest)
}
async fn validate_cors_origin(
pool: &sqlx::PgPool,
origin: &str,
project: &crate::db::models::Project,
rise_public_url: &str,
deployment_backend: &Arc<dyn crate::server::deployment::controller::DeploymentBackend>,
) -> Option<String> {
let origin_url = match Url::parse(origin) {
Ok(url) => url,
Err(_) => return None,
};
if let Some(host) = origin_url.host_str() {
if host == "localhost" || host == "127.0.0.1" {
return Some(origin.to_string());
}
}
if let Ok(rise_url) = Url::parse(rise_public_url) {
if origin_url.host() == rise_url.host()
&& origin_url.port() == rise_url.port()
&& origin_url.scheme() == rise_url.scheme()
{
return Some(origin.to_string());
}
}
let all_deployments =
match crate::db::deployments::get_active_deployments_for_project(pool, project.id).await {
Ok(deployments) => deployments,
Err(e) => {
warn!(
"Failed to fetch active deployments for project {}: {:?}",
project.name, e
);
return None;
}
};
for deployment in &all_deployments {
match deployment_backend
.get_deployment_urls(deployment, project)
.await
{
Ok(urls) => {
if !urls.default_url.is_empty() {
if let Ok(deployment_url) = Url::parse(&urls.default_url) {
if origin_url.host() == deployment_url.host()
&& origin_url.port() == deployment_url.port()
&& origin_url.scheme() == deployment_url.scheme()
{
return Some(origin.to_string());
}
}
}
for custom_url in &urls.custom_domain_urls {
if let Ok(custom_domain_url) = Url::parse(custom_url) {
if origin_url.host() == custom_domain_url.host()
&& origin_url.port() == custom_domain_url.port()
&& origin_url.scheme() == custom_domain_url.scheme()
{
return Some(origin.to_string());
}
}
}
}
Err(e) => {
warn!(
"Failed to get deployment URLs for deployment {}: {:?}",
deployment.deployment_id, e
);
}
}
}
None
}
/// Builds the CORS response headers that the token endpoint attaches for `origin`.
///
/// The headers are:
/// * `Access-Control-Allow-Origin: <origin>`, omitted if `origin` is not a valid header value;
/// * `Access-Control-Allow-Methods: POST, OPTIONS`;
/// * `Access-Control-Allow-Headers: Content-Type`;
/// * `Access-Control-Max-Age: 86400`, so the preflight is cached for 24 h;
/// * `Vary: Origin`.
fn cors_headers(origin: &str) -> HeaderMap {
    let mut headers = HeaderMap::new();
    if let Ok(origin_value) = HeaderValue::from_str(origin) {
        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin_value);
    }
    headers.insert(
        header::ACCESS_CONTROL_ALLOW_METHODS,
        HeaderValue::from_static("POST, OPTIONS"),
    );
    headers.insert(
        header::ACCESS_CONTROL_ALLOW_HEADERS,
        HeaderValue::from_static("Content-Type"),
    );
    headers.insert(
        header::ACCESS_CONTROL_MAX_AGE,
        HeaderValue::from_static("86400"),
    );
    // The allowed origin is reflected from the request, so the response differs per Origin.
    // Without `Vary: Origin`, a shared cache could serve one origin's CORS headers to another.
    headers.insert(header::VARY, HeaderValue::from_static("Origin"));
    headers
}
async fn validate_redirect_uri(
pool: &sqlx::PgPool,
redirect_uri: &str,
project: &crate::db::models::Project,
rise_public_url: &str,
deployment_backend: &Arc<dyn crate::server::deployment::controller::DeploymentBackend>,
) -> Result<(), String> {
let redirect_url =
Url::parse(redirect_uri).map_err(|e| format!("Invalid redirect URI: {}", e))?;
if let Some(host) = redirect_url.host_str() {
if host == "localhost" || host == "127.0.0.1" {
return Ok(());
}
}
if redirect_uri.starts_with(rise_public_url) {
return Ok(());
}
let all_deployments =
match crate::db::deployments::get_active_deployments_for_project(pool, project.id).await {
Ok(deployments) => deployments,
Err(e) => {
warn!(
"Failed to fetch active deployments for project {}: {:?}",
project.name, e
);
vec![]
}
};
for deployment in &all_deployments {
match deployment_backend
.get_deployment_urls(deployment, project)
.await
{
Ok(urls) => {
if !urls.default_url.is_empty() && redirect_uri.starts_with(&urls.default_url) {
return Ok(());
}
for custom_url in &urls.custom_domain_urls {
if redirect_uri.starts_with(custom_url) {
return Ok(());
}
}
}
Err(e) => {
warn!(
"Failed to get deployment URLs for deployment {}: {:?}",
deployment.deployment_id, e
);
}
}
}
Err(format!(
"Invalid redirect URI: not authorized for this project. Allowed: localhost, URLs starting with Rise public URL ({}), or any active deployment URL",
rise_public_url
))
}
/// Starts an OAuth authorization flow for the OAuth extension `extension_name` of
/// `project_name`.
///
/// Steps:
/// 1. Validate the optional client PKCE parameters (`code_challenge`,
///    `code_challenge_method`, which defaults to `S256`).
/// 2. Validate the optional `redirect_uri`, or derive a default one from the Rise public URL:
///    * host `api.<domain>` becomes `<project>.<domain>`;
///    * `localhost`/`127.0.0.1` becomes `<project>.apps.rise.local:8080`;
///    * any other host `<h>` becomes `<project>.<h>`.
/// 3. Resolve the upstream endpoints.
/// 4. Store an [`OAuthState`] under a fresh state token. It holds Rise's own PKCE verifier,
///    the final redirect URI, the application's `state` and the client's PKCE challenge.
/// 5. Redirect the user agent to the upstream authorization endpoint with Rise's callback URI,
///    an S256 challenge and the state token.
///
/// # Errors
/// * `404` for an unknown project or extension.
/// * `400` for a non-OAuth extension or invalid client parameters.
/// * `500` for invalid configuration or when the upstream endpoints cannot be resolved.
pub async fn authorize(
    State(state): State<AppState>,
    Path((project_name, extension_name)): Path<(String, String)>,
    Query(req): Query<AuthorizeFlowQuery>,
) -> Result<Response, (StatusCode, String)> {
    debug!(
        "OAuth authorize request for project={}, extension={}",
        project_name, extension_name
    );
    let project = db_projects::find_by_name(&state.db_pool, &project_name)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Project not found".to_string()))?;
    let extension =
        db_extensions::find_by_project_and_name(&state.db_pool, project.id, &extension_name)
            .await
            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
            .ok_or((
                StatusCode::NOT_FOUND,
                "OAuth extension not configured".to_string(),
            ))?;
    if extension.extension_type != "oauth" {
        return Err((
            StatusCode::BAD_REQUEST,
            format!(
                "Extension '{}' is not an OAuth extension (type: {})",
                extension_name, extension.extension_type
            ),
        ));
    }
    let spec: OAuthExtensionSpec = serde_json::from_value(extension.spec.clone()).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Invalid spec: {}", e),
        )
    })?;

    if let Some(ref code_challenge) = req.code_challenge {
        if code_challenge.len() < 43 || code_challenge.len() > 128 {
            return Err((
                StatusCode::BAD_REQUEST,
                "code_challenge must be 43-128 characters".to_string(),
            ));
        }
        let method = req.code_challenge_method.as_deref().unwrap_or("S256");
        if method != "S256" && method != "plain" {
            return Err((
                StatusCode::BAD_REQUEST,
                format!(
                    "Unsupported code_challenge_method '{}'. Only 'S256' and 'plain' are supported.",
                    method
                ),
            ));
        }
        // RFC 7636: an S256 challenge is unpadded base64url, but a `plain` challenge is the
        // verifier itself and may use any unreserved character, including '.' and '~'.
        let is_plain = method == "plain";
        let valid_chars = code_challenge.chars().all(|c| {
            c.is_ascii_alphanumeric()
                || c == '-'
                || c == '_'
                || (is_plain && (c == '.' || c == '~'))
        });
        if !valid_chars {
            let message = if is_plain {
                "code_challenge contains invalid characters (must be unreserved: A-Z a-z 0-9 - . _ ~)"
            } else {
                "code_challenge contains invalid characters (must be base64url)"
            };
            return Err((StatusCode::BAD_REQUEST, message.to_string()));
        }
    }

    let api_url = Url::parse(&state.public_url).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Invalid API URL configuration: {}", e),
        )
    })?;
    let api_host = api_url.host_str().ok_or((
        StatusCode::INTERNAL_SERVER_ERROR,
        "Missing host in API URL".to_string(),
    ))?;

    let final_redirect_uri = if let Some(ref uri) = req.redirect_uri {
        validate_redirect_uri(
            &state.db_pool,
            uri,
            &project,
            &state.public_url,
            &state.deployment_backend,
        )
        .await
        .map_err(|e| (StatusCode::BAD_REQUEST, e))?;
        uri.clone()
    } else {
        let is_loopback = api_host == "localhost" || api_host == "127.0.0.1";
        let project_host = if let Some(base_domain) = api_host.strip_prefix("api.") {
            format!("{}.{}", project_name, base_domain)
        } else if is_loopback {
            format!("{}.apps.rise.local", project_name)
        } else {
            format!("{}.{}", project_name, api_host)
        };
        if is_loopback {
            format!("{}://{}:8080/", api_url.scheme(), project_host)
        } else {
            format!("{}://{}/", api_url.scheme(), project_host)
        }
    };

    // Rise's own callback URI; `callback` rebuilds exactly the same string for the token
    // exchange (RFC 6749 §4.1.3 requires the two to be identical).
    let callback_uri = match api_url.port() {
        Some(port) => format!(
            "{}://{}:{}/oidc/{}/{}/callback",
            api_url.scheme(),
            api_host,
            port,
            project_name,
            extension_name
        ),
        None => format!(
            "{}://{}/oidc/{}/{}/callback",
            api_url.scheme(),
            api_host,
            project_name,
            extension_name
        ),
    };

    // Resolve the upstream endpoints (possibly via network discovery) before minting any state.
    // A failure here then does not leave an orphaned entry in the state store.
    let endpoints = resolve_oauth_endpoints(&spec).await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to resolve OAuth endpoints: {}", e),
        )
    })?;
    let mut auth_url = Url::parse(&endpoints.authorization_endpoint).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Invalid authorization endpoint: {}", e),
        )
    })?;

    let state_token = generate_state_token();
    let code_verifier = generate_code_verifier();
    let code_challenge = generate_code_challenge(&code_verifier);
    {
        let mut query = auth_url.query_pairs_mut();
        query
            .append_pair("client_id", &spec.client_id)
            .append_pair("redirect_uri", &callback_uri)
            .append_pair("response_type", "code");
        // `scope` is optional, and an empty value does not match the RFC 6749 §3.3 grammar.
        if !spec.scopes.is_empty() {
            query.append_pair("scope", &spec.scopes.join(" "));
        }
        query
            .append_pair("state", &state_token)
            .append_pair("code_challenge", &code_challenge)
            .append_pair("code_challenge_method", "S256");
    }

    let oauth_state = OAuthState {
        redirect_uri: Some(final_redirect_uri),
        application_state: req.state,
        project_name: project_name.clone(),
        extension_name: extension_name.clone(),
        code_verifier,
        created_at: Utc::now(),
        client_code_challenge: req.code_challenge,
        client_code_challenge_method: req.code_challenge_method,
    };
    state
        .oauth_state_store
        .insert(state_token, oauth_state)
        .await;

    debug!("Redirecting to OAuth provider: {}", auth_url.as_str());
    Ok(Redirect::to(auth_url.as_str()).into_response())
}
pub async fn callback(
State(state): State<AppState>,
Path((project_name, extension_name)): Path<(String, String)>,
Query(req): Query<CallbackRequest>,
) -> Result<Response, (StatusCode, String)> {
debug!(
"OAuth callback for project={}, extension={}",
project_name, extension_name
);
let oauth_state = state
.oauth_state_store
.get(&req.state)
.await
.ok_or((StatusCode::BAD_REQUEST, "Invalid state token".to_string()))?;
if oauth_state.project_name != project_name || oauth_state.extension_name != extension_name {
return Err((StatusCode::BAD_REQUEST, "State mismatch".to_string()));
}
let final_redirect_uri = oauth_state.redirect_uri.clone().ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Missing redirect URI in state".to_string(),
))?;
let is_test_flow = final_redirect_uri.starts_with(&state.public_url);
debug!(
"OAuth callback: flow_type={}, final_redirect_uri={}",
if is_test_flow { "test" } else { "real" },
final_redirect_uri
);
let project = db_projects::find_by_name(&state.db_pool, &project_name)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Project not found".to_string()))?;
let extension =
db_extensions::find_by_project_and_name(&state.db_pool, project.id, &extension_name)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((
StatusCode::NOT_FOUND,
"OAuth extension not configured".to_string(),
))?;
let spec: OAuthExtensionSpec = serde_json::from_value(extension.spec.clone()).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Invalid spec: {}", e),
)
})?;
let encryption_provider = state.encryption_provider.as_ref().ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Encryption provider not configured".to_string(),
))?;
use super::provider::{OAuthProvider, OAuthProviderConfig};
let oauth_provider = OAuthProvider::new(OAuthProviderConfig {
db_pool: state.db_pool.clone(),
encryption_provider: encryption_provider.clone(),
http_client: reqwest::Client::new(),
api_domain: state.public_url.clone(),
});
let client_secret = oauth_provider
.resolve_oauth_client_secret(project.id, &spec)
.await
.map_err(|e| {
error!("Failed to resolve OAuth client secret: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to resolve OAuth client secret: {}", e),
)
})?;
let api_url = Url::parse(&state.public_url).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Invalid API URL configuration: {}", e),
)
})?;
let api_host = api_url.host_str().ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Missing host in API URL".to_string(),
))?;
let redirect_uri = if let Some(port) = api_url.port() {
format!(
"{}://{}:{}/oidc/{}/{}/callback",
api_url.scheme(),
api_host,
port,
project_name,
extension_name
)
} else {
format!(
"{}://{}/oidc/{}/{}/callback",
api_url.scheme(),
api_host,
project_name,
extension_name
)
};
let endpoints = resolve_oauth_endpoints(&spec).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to resolve OAuth endpoints: {}", e),
)
})?;
let http_client = reqwest::Client::new();
let response = http_client
.post(&endpoints.token_endpoint)
.header("Accept", "application/json")
.form(&[
("grant_type", "authorization_code"),
("code", &req.code),
("client_id", &spec.client_id),
("client_secret", &client_secret),
("redirect_uri", &redirect_uri),
("code_verifier", &oauth_state.code_verifier),
])
.send()
.await
.map_err(|e| {
error!("Token exchange request failed: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Token exchange request failed: {}", e),
)
})?;
if is_test_flow {
info!(
"Processing test OAuth flow for project {} extension {}",
project_name, extension_name
);
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unable to read error response".to_string());
error!(
"Token exchange failed with status {}: {}",
status, error_text
);
let mut ext_status: OAuthExtensionStatus =
serde_json::from_value(extension.status.clone()).unwrap_or_default();
ext_status.error = Some(format!(
"Token exchange failed with status {}: {}",
status, error_text
));
ext_status.auth_verified = false;
db_extensions::update_status(
&state.db_pool,
project.id,
&extension_name,
&serde_json::to_value(&ext_status).unwrap(),
)
.await
.map_err(|e| {
warn!("Failed to update extension status: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to update status: {}", e),
)
})?;
state.oauth_state_store.invalidate(&req.state).await;
let mut redirect_url = Url::parse(&final_redirect_uri).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Invalid redirect URI: {}", e),
)
})?;
redirect_url
.query_pairs_mut()
.append_pair("error", "oauth_token_exchange_failed");
return Ok(Redirect::to(redirect_url.as_str()).into_response());
}
let mut ext_status: OAuthExtensionStatus =
serde_json::from_value(extension.status.clone()).unwrap_or_default();
ext_status.redirect_uri = Some(redirect_uri);
ext_status.configured_at = Some(Utc::now());
ext_status.auth_verified = true;
ext_status.error = None;
db_extensions::update_status(
&state.db_pool,
project.id,
&extension_name,
&serde_json::to_value(&ext_status).unwrap(),
)
.await
.map_err(|e| {
warn!("Failed to update extension status: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to update status: {}", e),
)
})?;
state.oauth_state_store.invalidate(&req.state).await;
info!(
"Completed test OAuth flow for project {} extension {}",
project_name, extension_name
);
let redirect_url = Url::parse(&final_redirect_uri).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Invalid redirect URI: {}", e),
)
})?;
Ok(Redirect::to(redirect_url.as_str()).into_response())
} else {
info!(
"Processing real OAuth flow for project {} extension {}",
project_name, extension_name
);
let status_code = response.status().as_u16();
let content_type = response
.headers()
.get(header::CONTENT_TYPE)
.and_then(|h| h.to_str().ok())
.unwrap_or("application/json")
.to_string();
let response_body = response.bytes().await.map_err(|e| {
error!("Failed to read token response body: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to read token response".to_string(),
)
})?;
let encryption_provider = state.encryption_provider.as_ref().ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"Encryption provider not configured".to_string(),
))?;
let token_response_encrypted = encryption_provider
.encrypt(&String::from_utf8_lossy(&response_body))
.await
.map_err(|e| {
error!("Failed to encrypt token response: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to encrypt token response".to_string(),
)
})?;
debug!(
"Cached token response: status={}, content_type={}",
status_code, content_type
);
let mut ext_status: OAuthExtensionStatus =
serde_json::from_value(extension.status.clone()).unwrap_or_default();
ext_status.redirect_uri = Some(redirect_uri);
ext_status.configured_at = Some(Utc::now());
ext_status.auth_verified = true;
ext_status.error = None;
db_extensions::update_status(
&state.db_pool,
project.id,
&extension_name,
&serde_json::to_value(&ext_status).unwrap(),
)
.await
.map_err(|e| {
warn!("Failed to update extension status: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to update status: {}", e),
)
})?;
state.oauth_state_store.invalidate(&req.state).await;
let authorization_code = generate_state_token();
let code_state = OAuthCodeState {
project_id: project.id,
extension_name: extension_name.clone(),
created_at: Utc::now(),
redirect_uri: oauth_state.redirect_uri.clone(),
code_challenge: oauth_state.client_code_challenge.clone(),
code_challenge_method: oauth_state.client_code_challenge_method.clone(),
token_response_encrypted,
content_type,
status_code,
};
state
.oauth_code_store
.insert(authorization_code.clone(), code_state)
.await;
let mut redirect_url = Url::parse(&final_redirect_uri).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Invalid redirect URI: {}", e),
)
})?;
redirect_url
.query_pairs_mut()
.append_pair("code", &authorization_code);
if let Some(app_state) = oauth_state.application_state {
redirect_url
.query_pairs_mut()
.append_pair("state", &app_state);
}
info!(
"Generated authorization code for project {} extension {}",
project_name, extension_name
);
info!(
"OAuth callback complete: redirecting to {}",
redirect_url.as_str()
);
Ok(Redirect::to(redirect_url.as_str()).into_response())
}
}
/// Checks a client-supplied PKCE `code_verifier` against the `code_challenge` recorded at
/// authorization time.
///
/// Supported methods:
/// * `S256`: the challenge must equal unpadded base64url(SHA-256(verifier)).
/// * `plain`: the challenge must equal the verifier.
///
/// Any other method yields `false`. The final comparison is constant-time, to avoid leaking
/// how much of the challenge matched.
fn validate_pkce(code_verifier: &str, code_challenge: &str, code_challenge_method: &str) -> bool {
    use sha2::{Digest, Sha256};
    use subtle::ConstantTimeEq;

    let expected: Vec<u8> = if code_challenge_method == "S256" {
        let digest = Sha256::digest(code_verifier.as_bytes());
        base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(digest)
            .into_bytes()
    } else if code_challenge_method == "plain" {
        code_verifier.as_bytes().to_vec()
    } else {
        return false;
    };
    expected.as_slice().ct_eq(code_challenge.as_bytes()).into()
}
/// Builds an RFC 6749 §5.2-style error response pair for the token endpoint.
///
/// Status mapping:
/// * `invalid_client`: `401 Unauthorized`.
/// * `invalid_request`, `invalid_grant`, `unsupported_grant_type`: `400 Bad Request`.
/// * Any other code (e.g. `server_error`): `500 Internal Server Error`.
fn oauth2_error(
    error: &str,
    description: Option<String>,
) -> (StatusCode, Json<OAuth2ErrorResponse>) {
    let status_code = match error {
        "invalid_client" => StatusCode::UNAUTHORIZED,
        "invalid_request" | "invalid_grant" | "unsupported_grant_type" => StatusCode::BAD_REQUEST,
        _ => StatusCode::INTERNAL_SERVER_ERROR,
    };
    let body = OAuth2ErrorResponse {
        error: error.to_owned(),
        error_description: description,
    };
    (status_code, Json(body))
}
pub async fn token_endpoint(
State(state): State<AppState>,
Path((project_name, extension_name)): Path<(String, String)>,
headers: axum::http::HeaderMap,
body: String,
) -> Response {
debug!(
"Token endpoint request for project={}, extension={}",
project_name, extension_name
);
let origin = headers
.get(header::ORIGIN)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let result = token_endpoint_inner(&state, &project_name, &extension_name, &headers, body).await;
let validated_cors_headers = if let Some(ref origin_str) = origin {
if let Ok(Some(project)) = db_projects::find_by_name(&state.db_pool, &project_name).await {
validate_cors_origin(
&state.db_pool,
origin_str,
&project,
&state.public_url,
&state.deployment_backend,
)
.await
.map(|allowed| cors_headers(&allowed))
} else {
None
}
} else {
None
};
match result {
Ok(mut response) => {
if let Some(cors) = validated_cors_headers {
response.headers_mut().extend(cors);
}
response
}
Err((status, error_json)) => {
let mut response = (status, error_json).into_response();
if let Some(cors) = validated_cors_headers {
response.headers_mut().extend(cors);
} else if let Some(origin_str) = origin {
response.headers_mut().extend(cors_headers(&origin_str));
}
response
}
}
}
async fn token_endpoint_inner(
state: &AppState,
project_name: &str,
extension_name: &str,
headers: &axum::http::HeaderMap,
body: String,
) -> Result<Response, (StatusCode, Json<OAuth2ErrorResponse>)> {
let content_type = headers
.get(header::CONTENT_TYPE)
.and_then(|h| h.to_str().ok())
.unwrap_or("application/x-www-form-urlencoded");
let req: TokenRequest = if content_type.contains("application/json") {
serde_json::from_str(&body).map_err(|e| {
oauth2_error(
"invalid_request",
Some(format!("Invalid JSON request body: {}", e)),
)
})?
} else {
serde_urlencoded::from_str(&body).map_err(|e| {
oauth2_error(
"invalid_request",
Some(format!("Invalid form-urlencoded request body: {}", e)),
)
})?
};
let project = db_projects::find_by_name(&state.db_pool, project_name)
.await
.map_err(|e| {
error!("Database error: {:?}", e);
oauth2_error("server_error", Some("Internal server error".to_string()))
})?
.ok_or_else(|| oauth2_error("invalid_request", Some("Project not found".to_string())))?;
let extension =
db_extensions::find_by_project_and_name(&state.db_pool, project.id, extension_name)
.await
.map_err(|e| {
error!("Database error: {:?}", e);
oauth2_error("server_error", Some("Internal server error".to_string()))
})?
.ok_or_else(|| {
oauth2_error(
"invalid_request",
Some("OAuth extension not configured".to_string()),
)
})?;
if extension.extension_type != "oauth" {
return Err(oauth2_error(
"invalid_request",
Some(format!(
"Extension '{}' is not an OAuth extension",
extension_name
)),
));
}
let spec: OAuthExtensionSpec = serde_json::from_value(extension.spec.clone()).map_err(|e| {
error!("Invalid extension spec: {:?}", e);
oauth2_error(
"server_error",
Some("Invalid extension configuration".to_string()),
)
})?;
let status: OAuthExtensionStatus =
serde_json::from_value(extension.status.clone()).map_err(|e| {
error!("Invalid extension status: {:?}", e);
oauth2_error("server_error", Some("Invalid extension status".to_string()))
})?;
let rise_client_id = status.rise_client_id.as_ref().ok_or_else(|| {
error!("Rise client ID not configured for extension");
oauth2_error(
"server_error",
Some("OAuth extension not fully configured".to_string()),
)
})?;
if &req.client_id != rise_client_id {
return Err(oauth2_error(
"invalid_client",
Some("Invalid client_id".to_string()),
));
}
let has_client_secret = req.client_secret.is_some();
let has_code_verifier = req.code_verifier.is_some();
match req.grant_type.as_str() {
"authorization_code" => {
if !has_client_secret && !has_code_verifier {
return Err(oauth2_error(
"invalid_request",
Some("Missing client authentication: provide either client_secret (confidential clients) or code_verifier (public clients with PKCE)".to_string()),
));
}
if has_client_secret && has_code_verifier {
return Err(oauth2_error(
"invalid_request",
Some("Client authentication methods are mutually exclusive: provide either client_secret (confidential clients) or code_verifier (public clients), not both".to_string()),
));
}
}
"refresh_token" => {
if has_code_verifier {
return Err(oauth2_error(
"invalid_request",
Some("code_verifier not supported for refresh_token grant (PKCE is only for authorization_code)".to_string()),
));
}
}
_ => {
}
}
if let Some(ref client_secret) = req.client_secret {
let stored_secret = status.rise_client_secret.as_ref().ok_or_else(|| {
error!("Rise client secret not configured for extension");
oauth2_error(
"server_error",
Some("OAuth extension not fully configured".to_string()),
)
})?;
use subtle::ConstantTimeEq;
let is_valid: bool = client_secret
.as_bytes()
.ct_eq(stored_secret.as_bytes())
.into();
if !is_valid {
return Err(oauth2_error(
"invalid_client",
Some("Invalid client_secret".to_string()),
));
}
}
match req.grant_type.as_str() {
"authorization_code" => {
handle_authorization_code_grant(state.clone(), project, extension_name.to_string(), req)
.await
}
"refresh_token" => {
let json_response = handle_refresh_token_grant(
state.clone(),
project,
extension_name.to_string(),
spec,
req,
)
.await?;
use axum::body::Body;
let response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(serde_json::to_vec(&json_response.0).unwrap()))
.unwrap();
Ok(response)
}
_ => Err(oauth2_error(
"unsupported_grant_type",
Some(format!("Unsupported grant_type: {}", req.grant_type)),
)),
}
}
async fn handle_authorization_code_grant(
state: AppState,
project: crate::db::models::Project,
extension_name: String,
req: TokenRequest,
) -> Result<Response, (StatusCode, Json<OAuth2ErrorResponse>)> {
if req.client_secret.is_some() && req.code_verifier.is_some() {
return Err(oauth2_error(
"invalid_request",
Some("Authentication methods must be mutually exclusive".to_string()),
));
}
let code = req.code.ok_or_else(|| {
oauth2_error(
"invalid_request",
Some("Missing required parameter: code".to_string()),
)
})?;
let code_state = state.oauth_code_store.remove(&code).await.ok_or_else(|| {
oauth2_error(
"invalid_grant",
Some("Invalid or expired authorization code".to_string()),
)
})?;
if code_state.project_id != project.id || code_state.extension_name != extension_name {
return Err(oauth2_error(
"invalid_grant",
Some("Authorization code mismatch".to_string()),
));
}
if let Some(ref stored_redirect_uri) = code_state.redirect_uri {
match &req.redirect_uri {
Some(req_redirect_uri) if req_redirect_uri == stored_redirect_uri => {
}
Some(_) => {
return Err(oauth2_error(
"invalid_grant",
Some("redirect_uri does not match authorization request".to_string()),
));
}
None => {
return Err(oauth2_error(
"invalid_request",
Some("redirect_uri required (was provided during authorization)".to_string()),
));
}
}
} else if req.redirect_uri.is_some() {
return Err(oauth2_error(
"invalid_request",
Some("redirect_uri was not provided during authorization".to_string()),
));
}
if req.code_verifier.is_some() && code_state.code_challenge.is_none() {
return Err(oauth2_error(
"invalid_request",
Some("code_verifier requires prior code_challenge during authorization".to_string()),
));
}
if let Some(ref code_challenge) = code_state.code_challenge {
let code_verifier = req.code_verifier.ok_or_else(|| {
oauth2_error(
"invalid_request",
Some("Missing code_verifier for PKCE flow".to_string()),
)
})?;
if code_verifier.len() < 43 || code_verifier.len() > 128 {
return Err(oauth2_error(
"invalid_request",
Some("code_verifier must be 43-128 characters".to_string()),
));
}
if !code_verifier
.chars()
.all(|c| c.is_ascii_alphanumeric() || "-._~".contains(c))
{
return Err(oauth2_error(
"invalid_request",
Some("code_verifier contains invalid characters".to_string()),
));
}
let code_challenge_method = code_state
.code_challenge_method
.as_deref()
.unwrap_or("S256");
if !validate_pkce(&code_verifier, code_challenge, code_challenge_method) {
return Err(oauth2_error(
"invalid_grant",
Some("PKCE validation failed".to_string()),
));
}
debug!("PKCE validation successful");
}
let encryption_provider = state.encryption_provider.as_ref().ok_or_else(|| {
error!("Encryption provider not configured");
oauth2_error("server_error", Some("Internal server error".to_string()))
})?;
let token_response_body = encryption_provider
.decrypt(&code_state.token_response_encrypted)
.await
.map_err(|e| {
error!("Failed to decrypt token response: {:?}", e);
oauth2_error("server_error", Some("Internal server error".to_string()))
})?;
info!(
"Authorization code grant successful for project {} extension {}",
project.name, extension_name
);
use axum::body::Body;
let response = Response::builder()
.status(StatusCode::from_u16(code_state.status_code).unwrap_or(StatusCode::OK))
.header(header::CONTENT_TYPE, code_state.content_type)
.body(Body::from(token_response_body))
.unwrap();
Ok(response)
}
/// Completes the `refresh_token` grant by refreshing with the upstream provider.
///
/// Resolves the upstream client secret for the extension and asks the provider to refresh
/// `req.refresh_token`. It then returns the upstream tokens. The `scope` field is the
/// extension's configured scopes joined by spaces, or omitted when none are configured.
///
/// # Errors
/// * `invalid_request` when `refresh_token` is missing.
/// * `server_error` when the encryption provider is absent or the client secret cannot be
///   resolved.
/// * `invalid_grant` when the upstream refresh fails.
async fn handle_refresh_token_grant(
    state: AppState,
    project: crate::db::models::Project,
    extension_name: String,
    spec: OAuthExtensionSpec,
    req: TokenRequest,
) -> Result<Json<OAuth2TokenResponse>, (StatusCode, Json<OAuth2ErrorResponse>)> {
    use super::provider::{OAuthProvider, OAuthProviderConfig};

    let refresh_token = match req.refresh_token {
        Some(token) => token,
        None => {
            return Err(oauth2_error(
                "invalid_request",
                Some("Missing required parameter: refresh_token".to_string()),
            ))
        }
    };
    let encryption_provider = match state.encryption_provider.clone() {
        Some(provider) => provider,
        None => {
            error!("Encryption provider not configured");
            return Err(oauth2_error(
                "server_error",
                Some("Internal server error".to_string()),
            ));
        }
    };
    let oauth_provider = OAuthProvider::new(OAuthProviderConfig {
        db_pool: state.db_pool.clone(),
        encryption_provider,
        http_client: reqwest::Client::new(),
        api_domain: state.public_url.clone(),
    });

    let client_secret = match oauth_provider
        .resolve_oauth_client_secret(project.id, &spec)
        .await
    {
        Ok(secret) => secret,
        Err(e) => {
            error!("Failed to resolve OAuth client secret: {:?}", e);
            return Err(oauth2_error(
                "server_error",
                Some("Internal server error".to_string()),
            ));
        }
    };
    let upstream = match oauth_provider
        .refresh_token(&spec, &client_secret, &refresh_token)
        .await
    {
        Ok(tokens) => tokens,
        Err(e) => {
            error!("Failed to refresh token with upstream provider: {:?}", e);
            return Err(oauth2_error(
                "invalid_grant",
                Some("Failed to refresh token".to_string()),
            ));
        }
    };

    let scope = (!spec.scopes.is_empty()).then(|| spec.scopes.join(" "));
    info!(
        "Refresh token grant successful for project {} extension {}",
        project.name, extension_name
    );
    Ok(Json(OAuth2TokenResponse {
        access_token: upstream.access_token,
        token_type: upstream.token_type,
        expires_in: upstream.expires_in,
        refresh_token: upstream.refresh_token,
        scope,
        id_token: upstream.id_token,
    }))
}
/// Answers CORS preflight (`OPTIONS`) requests for the token endpoint.
///
/// Responses:
/// * No `Origin` header, or one that is not a valid visible-ASCII string: `204 No Content`
///   without CORS headers.
/// * Project not found, or the lookup failed: `403 Forbidden`.
/// * Origin accepted by [`validate_cors_origin`]: `204 No Content` with the headers from
///   [`cors_headers`].
/// * Any other origin: `403 Forbidden`, logged at debug level.
pub async fn token_endpoint_options(
    State(state): State<AppState>,
    Path((project_name, _extension_name)): Path<(String, String)>,
    headers: axum::http::HeaderMap,
) -> Response {
    let origin = match headers.get(header::ORIGIN).map(|value| value.to_str()) {
        Some(Ok(origin)) => origin,
        _ => return StatusCode::NO_CONTENT.into_response(),
    };
    let project = match db_projects::find_by_name(&state.db_pool, &project_name).await {
        Ok(Some(project)) => project,
        Ok(None) | Err(_) => return StatusCode::FORBIDDEN.into_response(),
    };
    let allowed = validate_cors_origin(
        &state.db_pool,
        origin,
        &project,
        &state.public_url,
        &state.deployment_backend,
    )
    .await;
    if let Some(allowed_origin) = allowed {
        return (StatusCode::NO_CONTENT, cors_headers(&allowed_origin)).into_response();
    }
    debug!(
        "CORS origin '{}' not allowed for project '{}'",
        origin, project_name
    );
    StatusCode::FORBIDDEN.into_response()
}
pub async fn oidc_discovery(
State(state): State<AppState>,
Path((project_name, extension_name)): Path<(String, String)>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
debug!(
"OIDC discovery request for project={}, extension={}",
project_name, extension_name
);
let project = db_projects::find_by_name(&state.db_pool, &project_name)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Project not found".to_string()))?;
let extension =
db_extensions::find_by_project_and_name(&state.db_pool, project.id, &extension_name)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((
StatusCode::NOT_FOUND,
"OAuth extension not configured".to_string(),
))?;
if extension.extension_type != "oauth" {
return Err((
StatusCode::BAD_REQUEST,
format!(
"Extension '{}' is not an OAuth extension (type: {})",
extension_name, extension.extension_type
),
));
}
let spec: OAuthExtensionSpec = serde_json::from_value(extension.spec.clone()).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Invalid spec: {}", e),
)
})?;
let rise_oidc_base = format!(
"{}/oidc/{}/{}",
state.public_url.trim_end_matches('/'),
project_name,
extension_name
);
if spec.authorization_endpoint.is_some() && spec.token_endpoint.is_some() {
debug!(
"Both authorization_endpoint and token_endpoint in spec - synthesizing discovery document for {}/{}",
project_name, extension_name
);
let discovery = serde_json::json!({
"issuer": rise_oidc_base,
"authorization_endpoint": format!("{}/authorize", rise_oidc_base),
"token_endpoint": format!("{}/token", rise_oidc_base),
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256", "plain"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"]
});
info!(
"Returning synthesized OIDC discovery for {}/{} (non-OIDC OAuth 2.0 provider)",
project_name, extension_name
);
return Ok((
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
Json(discovery),
));
}
let upstream_issuer = &spec.issuer_url;
let discovery_result = fetch_oidc_discovery(upstream_issuer).await;
match discovery_result {
Ok(upstream_discovery) => {
debug!(
"Fetched upstream OIDC discovery for {}/{}",
project_name, extension_name
);
let mut discovery = serde_json::json!({
"issuer": rise_oidc_base,
"authorization_endpoint": format!("{}/authorize", rise_oidc_base),
"token_endpoint": format!("{}/token", rise_oidc_base),
"jwks_uri": format!("{}/jwks", rise_oidc_base),
});
if let Ok(upstream_json) = serde_json::to_value(&upstream_discovery) {
if let Some(upstream_obj) = upstream_json.as_object() {
if let Some(discovery_obj) = discovery.as_object_mut() {
for (key, value) in upstream_obj {
if key != "issuer"
&& key != "authorization_endpoint"
&& key != "token_endpoint"
&& key != "jwks_uri"
{
discovery_obj.insert(key.clone(), value.clone());
}
}
}
}
}
info!(
"Returning OIDC discovery for {}/{} with Rise OIDC base: {}",
project_name, extension_name, rise_oidc_base
);
Ok((
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
Json(discovery),
))
}
Err(e) => {
debug!(
"OIDC discovery failed for {}/{}: {}",
project_name, extension_name, e
);
if spec.authorization_endpoint.is_some() {
warn!(
"OIDC discovery failed for {}/{} - synthesizing from spec (fallback for non-OIDC provider)",
project_name, extension_name
);
let discovery = serde_json::json!({
"issuer": rise_oidc_base,
"authorization_endpoint": format!("{}/authorize", rise_oidc_base),
"token_endpoint": format!("{}/token", rise_oidc_base),
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256", "plain"],
"token_endpoint_auth_methods_supported": ["client_secret_post", "none"]
});
info!(
"Returning synthesized OIDC discovery for {}/{} (fallback)",
project_name, extension_name
);
Ok((
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
Json(discovery),
))
} else {
error!(
"OIDC discovery failed and no authorization_endpoint in spec for {}/{}",
project_name, extension_name
);
Err((
StatusCode::BAD_GATEWAY,
format!(
"OIDC discovery failed and no authorization_endpoint configured: {}",
e
),
))
}
}
}
}
pub async fn oidc_jwks(
State(state): State<AppState>,
Path((project_name, extension_name)): Path<(String, String)>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
debug!(
"OIDC JWKS request for project={}, extension={}",
project_name, extension_name
);
let project = db_projects::find_by_name(&state.db_pool, &project_name)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Project not found".to_string()))?;
let extension =
db_extensions::find_by_project_and_name(&state.db_pool, project.id, &extension_name)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((
StatusCode::NOT_FOUND,
"OAuth extension not configured".to_string(),
))?;
if extension.extension_type != "oauth" {
return Err((
StatusCode::BAD_REQUEST,
format!(
"Extension '{}' is not an OAuth extension (type: {})",
extension_name, extension.extension_type
),
));
}
let spec: OAuthExtensionSpec = serde_json::from_value(extension.spec.clone()).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Invalid spec: {}", e),
)
})?;
let discovery_result = fetch_oidc_discovery(&spec.issuer_url).await;
match discovery_result {
Ok(discovery) => {
let jwks_uri = discovery.jwks_uri.ok_or((
StatusCode::INTERNAL_SERVER_ERROR,
"No jwks_uri in OIDC discovery".to_string(),
))?;
let http_client = reqwest::Client::new();
let jwks_response = http_client.get(&jwks_uri).send().await.map_err(|e| {
error!("Failed to fetch JWKS from {}: {:?}", jwks_uri, e);
(
StatusCode::BAD_GATEWAY,
format!("Failed to fetch JWKS: {}", e),
)
})?;
if !jwks_response.status().is_success() {
let status = jwks_response.status();
return Err((
StatusCode::BAD_GATEWAY,
format!("Upstream JWKS fetch failed: {}", status),
));
}
let jwks: serde_json::Value = jwks_response.json().await.map_err(|e| {
error!("Failed to parse JWKS response: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to parse JWKS".to_string(),
)
})?;
info!(
"Returning JWKS for {}/{} from upstream: {}",
project_name, extension_name, jwks_uri
);
Ok((
StatusCode::OK,
[(header::CONTENT_TYPE, "application/json")],
Json(jwks),
))
}
Err(e) => {
warn!(
"JWKS not available for {}/{}: upstream OIDC discovery failed: {}",
project_name, extension_name, e
);
Err((
StatusCode::NOT_IMPLEMENTED,
"JWKS endpoint not available: upstream provider does not support OIDC discovery. \
JWKS is only available for OIDC-compliant providers (e.g., Google, Dex). \
Plain OAuth 2.0 providers (e.g., GitHub) do not provide public keys via JWKS."
.to_string(),
))
}
}
}