use anyhow::Result;
use axum::Json;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Redirect};
use serde::{Deserialize, Serialize};
use systemprompt_identifiers::{AuthorizationCode, ClientId, UserId};
use systemprompt_oauth::OAuthState;
use systemprompt_oauth::repository::{AuthCodeParams, OAuthRepository};
use systemprompt_oauth::services::{generate_secure_token, is_browser_request};
/// Query-string parameters accepted by the WebAuthn completion endpoint.
///
/// Only `user_id` is mandatory at deserialization time; every other field is
/// optional here and validated (where needed) by the handler.
#[derive(Debug, Deserialize)]
pub struct WebAuthnCompleteQuery {
    /// Identifier of the user the authorization code is issued for.
    pub user_id: String,
    /// Accepted from the query string; not read by the code in this file.
    pub response_type: Option<String>,
    /// OAuth client identifier. The handler rejects the request when absent.
    pub client_id: Option<String>,
    /// Where the caller is sent back to. The handler rejects the request when absent.
    pub redirect_uri: Option<String>,
    /// Requested scope; when absent the repository's default roles are used.
    pub scope: Option<String>,
    /// Opaque client state, echoed back on success when non-empty.
    pub state: Option<String>,
    /// PKCE code challenge; only stored together with a non-empty method.
    pub code_challenge: Option<String>,
    /// PKCE code challenge method; an empty string is treated as absent.
    pub code_challenge_method: Option<String>,
    /// Accepted from the query string; not read by the code in this file.
    pub response_mode: Option<String>,
    /// Optional resource value attached to the stored authorization code.
    pub resource: Option<String>,
}
/// JSON error body returned by the WebAuthn completion endpoint.
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteError {
    /// Short error code, e.g. `invalid_request`, `access_denied` or `server_error`.
    pub error: String,
    /// Human-readable explanation of the failure.
    pub error_description: String,
}
#[allow(unused_qualifications)]
pub async fn handle_webauthn_complete(
headers: HeaderMap,
Query(params): Query<WebAuthnCompleteQuery>,
State(state): State<OAuthState>,
) -> impl IntoResponse {
let repo = match OAuthRepository::new(state.db_pool()) {
Ok(r) => r,
Err(e) => {
return (
axum::http::StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
).into_response();
},
};
if params.client_id.is_none() {
return (
StatusCode::BAD_REQUEST,
Json(WebAuthnCompleteError {
error: "invalid_request".to_string(),
error_description: "Missing client_id parameter".to_string(),
}),
)
.into_response();
}
let Some(redirect_uri) = ¶ms.redirect_uri else {
return (
StatusCode::BAD_REQUEST,
Json(WebAuthnCompleteError {
error: "invalid_request".to_string(),
error_description: "Missing redirect_uri parameter".to_string(),
}),
)
.into_response();
};
let user_provider = state.user_provider();
match user_provider.find_by_id(¶ms.user_id).await {
Ok(Some(_)) => {
let authorization_code = generate_secure_token("auth_code");
match store_authorization_code(&repo, &authorization_code, ¶ms).await {
Ok(()) => {
create_successful_response(&headers, redirect_uri, &authorization_code, ¶ms)
},
Err(error) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(WebAuthnCompleteError {
error: "server_error".to_string(),
error_description: error.to_string(),
}),
)
.into_response(),
}
},
Ok(None) => (
StatusCode::UNAUTHORIZED,
Json(WebAuthnCompleteError {
error: "access_denied".to_string(),
error_description: "User not found".to_string(),
}),
)
.into_response(),
Err(error) => {
let status_code = if error.to_string().contains("User not found") {
StatusCode::UNAUTHORIZED
} else {
StatusCode::INTERNAL_SERVER_ERROR
};
let error_type = if status_code == StatusCode::UNAUTHORIZED {
"access_denied"
} else {
"server_error"
};
(
status_code,
Json(WebAuthnCompleteError {
error: error_type.to_string(),
error_description: error.to_string(),
}),
)
.into_response()
},
}
}
/// Persists `code_str` as an authorization code for the client and user named
/// in `query`.
///
/// The stored scope is the requested one when present; otherwise it is the
/// repository's default roles joined by spaces, or `"user"` when there are no
/// default roles. PKCE data is attached only when both a challenge and a
/// non-empty challenge method were supplied, and the resource only when present.
///
/// # Errors
///
/// Returns an error when `client_id` or `redirect_uri` is missing from `query`,
/// or when the repository fails to store the code.
async fn store_authorization_code(
    repo: &OAuthRepository,
    code_str: &str,
    query: &WebAuthnCompleteQuery,
) -> Result<()> {
    let Some(client_id_str) = query.client_id.as_ref() else {
        return Err(anyhow::anyhow!("client_id is required"));
    };
    let Some(redirect_uri) = query.redirect_uri.as_ref() else {
        return Err(anyhow::anyhow!("redirect_uri is required"));
    };

    let scope = match query.scope.as_ref() {
        Some(requested) => requested.clone(),
        None => {
            let default_roles = OAuthRepository::get_default_roles();
            if default_roles.is_empty() {
                "user".to_string()
            } else {
                default_roles.join(" ")
            }
        },
    };

    let code = AuthorizationCode::new(code_str);
    let client_id = ClientId::new(client_id_str);
    let user_id = UserId::new(&query.user_id);
    let builder = AuthCodeParams::builder(&code, &client_id, &user_id, redirect_uri, &scope);

    // An empty challenge method counts as "not supplied".
    let pkce_method = query
        .code_challenge_method
        .as_deref()
        .filter(|method| !method.is_empty());
    let builder = match query.code_challenge.as_deref().zip(pkce_method) {
        Some((challenge, method)) => builder.with_pkce(challenge, method),
        None => builder,
    };

    let builder = match query.resource.as_deref() {
        Some(resource) => builder.with_resource(resource),
        None => builder,
    };

    repo.store_authorization_code(builder.build()).await
}
/// JSON success body returned to non-browser callers.
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteResponse {
    /// The newly issued authorization code.
    pub authorization_code: String,
    /// The caller's `state` value, or an empty string when none was supplied.
    pub state: String,
    /// The redirect URI the request was made with.
    pub redirect_uri: String,
    /// The caller's `client_id`, or an empty string when none was supplied.
    pub client_id: String,
}
/// Builds the success reply once an authorization code has been stored.
///
/// Browser requests (as decided by `is_browser_request`) are redirected to
/// `redirect_uri` with `code`, and — when available — `client_id` and a
/// non-empty `state` appended as query parameters. All other callers receive a
/// JSON [`WebAuthnCompleteResponse`] carrying permissive CORS headers.
fn create_successful_response(
    headers: &HeaderMap,
    redirect_uri: &str,
    authorization_code: &str,
    params: &WebAuthnCompleteQuery,
) -> axum::response::Response {
    // An empty `state` is treated the same as a missing one.
    let state = params.state.as_deref().filter(|s| !s.is_empty());
    let client_id = params.client_id.as_deref();

    if is_browser_request(headers) {
        let target = build_redirect_target(redirect_uri, authorization_code, client_id, state);
        return Redirect::to(&target).into_response();
    }

    let response_data = WebAuthnCompleteResponse {
        authorization_code: authorization_code.to_string(),
        state: state.unwrap_or("").to_string(),
        redirect_uri: redirect_uri.to_string(),
        client_id: client_id.unwrap_or("").to_string(),
    };
    let mut response = Json(response_data).into_response();

    // NOTE(review): a wildcard origin on a response that carries an
    // authorization code lets any web origin read it — confirm this is intended.
    let response_headers = response.headers_mut();
    response_headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
    response_headers.insert(
        "access-control-allow-methods",
        HeaderValue::from_static("GET, POST, OPTIONS"),
    );
    response_headers.insert(
        "access-control-allow-headers",
        HeaderValue::from_static("content-type, authorization"),
    );
    response
}

/// Appends the authorization response parameters to `redirect_uri`.
///
/// A query component already present on the redirect URI is retained: the new
/// parameters are joined with `&` instead of starting a second `?`. Every
/// parameter value is percent-encoded.
fn build_redirect_target(
    redirect_uri: &str,
    authorization_code: &str,
    client_id: Option<&str>,
    state: Option<&str>,
) -> String {
    let separator = if redirect_uri.ends_with('?') || redirect_uri.ends_with('&') {
        ""
    } else if redirect_uri.contains('?') {
        "&"
    } else {
        "?"
    };

    let mut target = String::from(redirect_uri);
    target.push_str(separator);
    target.push_str("code=");
    target.push_str(&urlencoding::encode(authorization_code));
    if let Some(client_id) = client_id {
        target.push_str("&client_id=");
        target.push_str(&urlencoding::encode(client_id));
    }
    if let Some(state) = state {
        target.push_str("&state=");
        target.push_str(&urlencoding::encode(state));
    }
    target
}