use anyhow::Result;
use axum::Json;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Redirect};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::routes::oauth::extractors::OAuthRepo;
use systemprompt_identifiers::{AuthorizationCode, ClientId, UserId};
use systemprompt_oauth::OAuthState;
use systemprompt_oauth::repository::{AuthCodeParams, OAuthRepository};
use systemprompt_oauth::services::webauthn::WebAuthnManager;
use systemprompt_oauth::services::{generate_secure_token, is_browser_request};
/// Query parameters accepted by `handle_webauthn_complete`.
///
/// Apart from `user_id` and `auth_token`, the fields appear to carry the
/// original OAuth authorization-request parameters through the WebAuthn step
/// (NOTE(review): inferred from the field names — confirm against the caller).
#[derive(Debug, Deserialize)]
pub struct WebAuthnCompleteQuery {
/// User the caller claims to be; the handler rejects the request unless this
/// equals the user returned by `consume_verified_authentication`.
pub user_id: UserId,
/// Token from the WebAuthn verification step, passed to
/// `consume_verified_authentication`. Required by the handler.
pub auth_token: Option<String>,
/// Deserialized but not read by any code in this file.
pub response_type: Option<String>,
/// Required by the handler; stored with the authorization code and echoed
/// back in the success response.
pub client_id: Option<ClientId>,
/// Required by the handler; stored with the authorization code and used as
/// the redirect target for browser requests.
pub redirect_uri: Option<String>,
/// Requested scope. When absent, `store_authorization_code` falls back to the
/// space-joined default roles, or `"user"` if there are none.
pub scope: Option<String>,
/// Opaque client state, echoed back in the success response when non-empty.
pub state: Option<String>,
/// PKCE code challenge, forwarded to the authorization-code builder by
/// `store_authorization_code`.
pub code_challenge: Option<String>,
/// PKCE challenge method, forwarded together with `code_challenge`.
pub code_challenge_method: Option<String>,
/// Deserialized but not read by any code in this file.
pub response_mode: Option<String>,
/// Attached to the stored code via `with_resource`; presumably an RFC 8707
/// resource indicator — confirm.
pub resource: Option<String>,
}
/// JSON body returned by `handle_webauthn_complete` on failure.
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteError {
/// OAuth-style error code; this file emits `"invalid_request"`,
/// `"access_denied"` and `"server_error"`.
pub error: String,
/// Human-readable explanation of the failure.
pub error_description: String,
}
#[allow(unused_qualifications)]
pub async fn handle_webauthn_complete(
headers: HeaderMap,
Query(params): Query<WebAuthnCompleteQuery>,
State(state): State<OAuthState>,
OAuthRepo(repo): OAuthRepo,
) -> impl IntoResponse {
let auth_token = match ¶ms.auth_token {
Some(token) => token.clone(),
None => {
return (
StatusCode::BAD_REQUEST,
Json(WebAuthnCompleteError {
error: "invalid_request".to_string(),
error_description: "Missing auth_token parameter".to_string(),
}),
)
.into_response();
},
};
let user_provider = state.user_provider();
let webauthn_service =
match WebAuthnManager::get_or_create_service(repo.clone(), Arc::clone(user_provider)).await
{
Ok(service) => service,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(WebAuthnCompleteError {
error: "server_error".to_string(),
error_description: format!("WebAuthn service initialization failed: {e}"),
}),
)
.into_response();
},
};
let Ok(verified_user_id) = webauthn_service
.consume_verified_authentication(&auth_token)
.await
else {
return (
StatusCode::UNAUTHORIZED,
Json(WebAuthnCompleteError {
error: "access_denied".to_string(),
error_description: "Invalid or expired authentication token".to_string(),
}),
)
.into_response();
};
if params.user_id != verified_user_id {
tracing::warn!(
claimed_user_id = %params.user_id,
verified_user_id = %verified_user_id,
"WebAuthn complete user_id mismatch"
);
return (
StatusCode::UNAUTHORIZED,
Json(WebAuthnCompleteError {
error: "access_denied".to_string(),
error_description: "User identity verification failed".to_string(),
}),
)
.into_response();
}
if params.client_id.is_none() {
return (
StatusCode::BAD_REQUEST,
Json(WebAuthnCompleteError {
error: "invalid_request".to_string(),
error_description: "Missing client_id parameter".to_string(),
}),
)
.into_response();
}
let Some(redirect_uri) = ¶ms.redirect_uri else {
return (
StatusCode::BAD_REQUEST,
Json(WebAuthnCompleteError {
error: "invalid_request".to_string(),
error_description: "Missing redirect_uri parameter".to_string(),
}),
)
.into_response();
};
match user_provider.find_by_id(&verified_user_id).await {
Ok(Some(_)) => {
let authorization_code = generate_secure_token("auth_code");
match store_authorization_code(&repo, &authorization_code, ¶ms).await {
Ok(()) => {
create_successful_response(&headers, redirect_uri, &authorization_code, ¶ms)
},
Err(error) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(WebAuthnCompleteError {
error: "server_error".to_string(),
error_description: error.to_string(),
}),
)
.into_response(),
}
},
Ok(None) => (
StatusCode::UNAUTHORIZED,
Json(WebAuthnCompleteError {
error: "access_denied".to_string(),
error_description: "User not found".to_string(),
}),
)
.into_response(),
Err(error) => {
let status_code = if error.to_string().contains("User not found") {
StatusCode::UNAUTHORIZED
} else {
StatusCode::INTERNAL_SERVER_ERROR
};
let error_type = if status_code == StatusCode::UNAUTHORIZED {
"access_denied"
} else {
"server_error"
};
(
status_code,
Json(WebAuthnCompleteError {
error: error_type.to_string(),
error_description: error.to_string(),
}),
)
.into_response()
},
}
}
/// Persists a newly generated authorization code for the WebAuthn-verified user.
///
/// `code_str` is stored with the `client_id`, `user_id`, `redirect_uri` and
/// scope taken from `query`. When `scope` is absent it falls back to the
/// space-joined `OAuthRepository::get_default_roles()`, or `"user"` if that list
/// is empty. A non-empty `code_challenge` is stored as a PKCE challenge, and a
/// non-empty `resource` is attached via `with_resource`.
///
/// # Errors
/// Returns an error if `client_id` or `redirect_uri` is missing from `query`
/// (the handler validates both beforehand) or if the repository write fails.
async fn store_authorization_code(
    repo: &OAuthRepository,
    code_str: &str,
    query: &WebAuthnCompleteQuery,
) -> Result<()> {
    let client_id = query
        .client_id
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("client_id is required"))?;
    let redirect_uri = query
        .redirect_uri
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("redirect_uri is required"))?;
    let scope = query.scope.as_ref().map_or_else(
        || {
            let default_roles = OAuthRepository::get_default_roles();
            if default_roles.is_empty() {
                "user".to_string()
            } else {
                default_roles.join(" ")
            }
        },
        Clone::clone,
    );

    let code = AuthorizationCode::new(code_str);
    let mut builder =
        AuthCodeParams::builder(&code, client_id, &query.user_id, redirect_uri, &scope);

    // PKCE (RFC 7636). Empty strings are treated as absent for both fields.
    // A challenge sent without a method must not be dropped — that would
    // silently issue a code with no PKCE protection. Per RFC 7636 §4.3 the
    // method defaults to "plain" when not present.
    if let Some(challenge) = query.code_challenge.as_deref().filter(|s| !s.is_empty()) {
        let method = query
            .code_challenge_method
            .as_deref()
            .filter(|s| !s.is_empty())
            .unwrap_or("plain");
        builder = builder.with_pkce(challenge, method);
    }

    if let Some(resource) = query.resource.as_deref().filter(|s| !s.is_empty()) {
        builder = builder.with_resource(resource);
    }

    repo.store_authorization_code(builder.build()).await
}
/// JSON body returned by `create_successful_response` to non-browser clients.
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteResponse {
/// The newly issued authorization code.
pub authorization_code: String,
/// The request's `state`, or an empty string when it was absent or empty.
pub state: String,
/// The `redirect_uri` the code was issued for.
pub redirect_uri: String,
/// The requesting client; an empty `ClientId` if the request had none.
pub client_id: ClientId,
}
/// Builds the success response carrying the new authorization code.
///
/// Browser requests (per `is_browser_request`) get a redirect to
/// `redirect_uri` with `code`, `client_id` (when present) and `state` (when
/// non-empty) appended as percent-encoded query parameters. All other clients
/// get a JSON `WebAuthnCompleteResponse`.
// NOTE(review): `params.response_mode` (e.g. `fragment`, `form_post`) is not
// honoured; the redirect always uses the query component.
fn create_successful_response(
    headers: &HeaderMap,
    redirect_uri: &str,
    authorization_code: &str,
    params: &WebAuthnCompleteQuery,
) -> axum::response::Response {
    let state = params.state.as_deref().filter(|s| !s.is_empty());

    if is_browser_request(headers) {
        let mut target = redirect_uri.to_string();
        append_query_param(&mut target, "code", authorization_code);
        if let Some(client_id) = params.client_id.as_ref() {
            append_query_param(&mut target, "client_id", client_id.as_str());
        }
        if let Some(state_val) = state {
            append_query_param(&mut target, "state", state_val);
        }
        Redirect::to(&target).into_response()
    } else {
        let response_data = WebAuthnCompleteResponse {
            authorization_code: authorization_code.to_string(),
            state: state.unwrap_or("").to_string(),
            redirect_uri: redirect_uri.to_string(),
            client_id: params
                .client_id
                .clone()
                .unwrap_or_else(|| ClientId::new("")),
        };
        Json(response_data).into_response()
    }
}

/// Appends `key=value` (value percent-encoded) to `url`'s query string.
///
/// A registered redirect URI may already contain a query component (RFC 6749
/// §3.1.2 requires it to be retained), so this continues with `&` in that case
/// instead of starting a second, malformed `?`.
fn append_query_param(url: &mut String, key: &str, value: &str) {
    let separator = if url.contains('?') { '&' } else { '?' };
    url.push(separator);
    url.push_str(key);
    url.push('=');
    url.push_str(&urlencoding::encode(value));
}