use anyhow::Result;
use axum::Json;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Redirect};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::routes::oauth::extractors::OAuthRepo;
use systemprompt_identifiers::{AuthorizationCode, ClientId, UserId};
use systemprompt_oauth::OAuthState;
use systemprompt_oauth::repository::{AuthCodeParams, OAuthRepository};
use systemprompt_oauth::services::webauthn::WebAuthnRegistry;
use systemprompt_oauth::services::{generate_secure_token, is_browser_request};
/// Query parameters accepted by the WebAuthn completion endpoint.
///
/// `user_id` and `auth_token` identify the user who finished WebAuthn authentication; the
/// remaining fields look like the OAuth authorization-request parameters carried through the
/// flow (presumably forwarded from the authorize endpoint — confirm with the caller).
#[derive(Debug, Deserialize)]
pub struct WebAuthnCompleteQuery {
    /// User id claimed by the caller; must equal the user id bound to `auth_token`.
    pub user_id: UserId,
    /// Token proving WebAuthn authentication; required, and passed to
    /// `consume_verified_authentication` during verification.
    pub auth_token: Option<String>,
    /// OAuth `response_type`. NOTE(review): not read anywhere in this file — confirm it is
    /// validated upstream.
    pub response_type: Option<String>,
    /// OAuth client identifier; required.
    pub client_id: Option<ClientId>,
    /// URI the authorization code is delivered to; required.
    pub redirect_uri: Option<String>,
    /// Requested scope; when absent the repository's default roles are used instead.
    pub scope: Option<String>,
    /// Opaque client state, echoed back on success when non-empty.
    pub state: Option<String>,
    /// PKCE code challenge; bound to the authorization code by `store_authorization_code`.
    pub code_challenge: Option<String>,
    /// PKCE challenge method accompanying `code_challenge`.
    pub code_challenge_method: Option<String>,
    /// OAuth `response_mode`. NOTE(review): not read anywhere in this file.
    pub response_mode: Option<String>,
    /// Resource indicator, stored with the authorization code when present.
    pub resource: Option<String>,
}
/// JSON body of every error response produced by the WebAuthn completion endpoint.
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteError {
    /// Machine-readable error code, e.g. `invalid_request`, `access_denied`, `server_error`.
    pub error: String,
    /// Human-readable explanation of the failure.
    pub error_description: String,
}
/// Builds a JSON error response of the form `{ "error": ..., "error_description": ... }`.
///
/// `status` becomes the HTTP status code, `error` the machine-readable error code and
/// `description` the human-readable detail.
fn error_response(
    status: StatusCode,
    error: &str,
    description: impl Into<String>,
) -> axum::response::Response {
    let body = WebAuthnCompleteError {
        error: String::from(error),
        error_description: description.into(),
    };
    (status, Json(body)).into_response()
}
async fn verify_completion(
params: &WebAuthnCompleteQuery,
state: &OAuthState,
repo: &OAuthRepository,
) -> Result<(UserId, String), axum::response::Response> {
let Some(auth_token) = ¶ms.auth_token else {
return Err(error_response(
StatusCode::BAD_REQUEST,
"invalid_request",
"Missing auth_token parameter",
));
};
let webauthn_service =
WebAuthnRegistry::get_or_create_service(repo.clone(), Arc::clone(state.user_provider()))
.await
.map_err(|e| {
error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
format!("WebAuthn service initialization failed: {e}"),
)
})?;
let verified_user_id = webauthn_service
.consume_verified_authentication(auth_token)
.await
.map_err(|_| {
error_response(
StatusCode::UNAUTHORIZED,
"access_denied",
"Invalid or expired authentication token",
)
})?;
if params.user_id != verified_user_id {
tracing::warn!(
claimed_user_id = %params.user_id,
verified_user_id = %verified_user_id,
"WebAuthn complete user_id mismatch"
);
return Err(error_response(
StatusCode::UNAUTHORIZED,
"access_denied",
"User identity verification failed",
));
}
if params.client_id.is_none() {
return Err(error_response(
StatusCode::BAD_REQUEST,
"invalid_request",
"Missing client_id parameter",
));
}
let Some(redirect_uri) = ¶ms.redirect_uri else {
return Err(error_response(
StatusCode::BAD_REQUEST,
"invalid_request",
"Missing redirect_uri parameter",
));
};
Ok((verified_user_id, redirect_uri.clone()))
}
#[allow(unused_qualifications)]
pub async fn handle_webauthn_complete(
headers: HeaderMap,
Query(params): Query<WebAuthnCompleteQuery>,
State(state): State<OAuthState>,
OAuthRepo(repo): OAuthRepo,
) -> impl IntoResponse {
let (verified_user_id, redirect_uri) = match verify_completion(¶ms, &state, &repo).await {
Ok(verified) => verified,
Err(response) => return response,
};
match state.user_provider().find_by_id(&verified_user_id).await {
Ok(Some(_)) => {
let authorization_code = generate_secure_token("auth_code");
match store_authorization_code(&repo, &authorization_code, ¶ms).await {
Ok(()) => create_successful_response(
&headers,
&redirect_uri,
&authorization_code,
¶ms,
),
Err(error) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
error.to_string(),
),
}
},
Ok(None) => error_response(StatusCode::UNAUTHORIZED, "access_denied", "User not found"),
Err(error) => {
let is_not_found = error.to_string().contains("User not found");
let status_code = if is_not_found {
StatusCode::UNAUTHORIZED
} else {
StatusCode::INTERNAL_SERVER_ERROR
};
let error_type = if is_not_found {
"access_denied"
} else {
"server_error"
};
error_response(status_code, error_type, error.to_string())
},
}
}
/// Persists a freshly generated authorization code for the WebAuthn-verified user.
///
/// `code_str` is the opaque code value. The client id, redirect URI, scope, PKCE challenge and
/// resource indicator are taken from `query`.
///
/// When `scope` is absent, the repository's default roles are used, joined with spaces. If
/// there are no default roles, `"user"` is used.
///
/// PKCE: a supplied `code_challenge` is always bound to the code. If `code_challenge_method`
/// is missing or empty, the method defaults to `"plain"` as RFC 7636 §4.3 specifies. Without
/// this default the challenge would be silently discarded, and the code could later be
/// redeemed without a verifier.
///
/// # Errors
///
/// Fails if `client_id` or `redirect_uri` is missing, or if the repository write fails.
async fn store_authorization_code(
    repo: &OAuthRepository,
    code_str: &str,
    query: &WebAuthnCompleteQuery,
) -> Result<()> {
    let client_id = query
        .client_id
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("client_id is required"))?;
    let redirect_uri = query
        .redirect_uri
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("redirect_uri is required"))?;

    let scope = query.scope.as_ref().map_or_else(
        || {
            let default_roles = OAuthRepository::get_default_roles();
            if default_roles.is_empty() {
                "user".to_string()
            } else {
                default_roles.join(" ")
            }
        },
        Clone::clone,
    );

    let code = AuthorizationCode::new(code_str);
    let mut builder =
        AuthCodeParams::builder(&code, client_id, &query.user_id, redirect_uri, &scope);

    if let Some(challenge) = query.code_challenge.as_deref() {
        let method = query
            .code_challenge_method
            .as_deref()
            .filter(|s| !s.is_empty())
            .unwrap_or("plain");
        builder = builder.with_pkce(challenge, method);
    }
    if let Some(resource) = query.resource.as_deref() {
        builder = builder.with_resource(resource);
    }

    repo.store_authorization_code(builder.build())
        .await
        .map_err(Into::into)
}
/// JSON body returned to non-browser clients after a successful WebAuthn completion.
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteResponse {
    /// Newly issued authorization code.
    pub authorization_code: String,
    /// Client-supplied state, or an empty string when none (or an empty one) was supplied.
    pub state: String,
    /// Redirect URI the code was issued for.
    pub redirect_uri: String,
    /// Client the code was issued to; an empty id stands in for an absent value.
    pub client_id: ClientId,
}
/// Builds the success response that delivers a newly issued authorization code.
///
/// Browser requests, as judged by `is_browser_request`, are redirected to `redirect_uri`. The
/// redirect carries `code`, `client_id` (when present) and `state` (when non-empty) as
/// URL-encoded query parameters. Any query component already in `redirect_uri` is kept, as
/// RFC 6749 §3.1.2 requires.
///
/// Other clients receive a JSON `WebAuthnCompleteResponse`. An empty `state` or `client_id`
/// stands in for an absent value.
fn create_successful_response(
    headers: &HeaderMap,
    redirect_uri: &str,
    authorization_code: &str,
    params: &WebAuthnCompleteQuery,
) -> axum::response::Response {
    use std::fmt::Write as _;

    let state = params.state.as_deref().filter(|s| !s.is_empty());

    if is_browser_request(headers) {
        // Start a query with `?` only when redirect_uri has none; otherwise append with `&`.
        // Skip the delimiter if the URI already ends in `?` or `&`, to avoid doubling it.
        let separator = if !redirect_uri.contains('?') {
            "?"
        } else if redirect_uri.ends_with('?') || redirect_uri.ends_with('&') {
            ""
        } else {
            "&"
        };
        let mut target = format!(
            "{redirect_uri}{separator}code={}",
            urlencoding::encode(authorization_code)
        );
        // `write!` into a `String` cannot fail, so its `fmt::Result` is discarded.
        if let Some(client_id) = params.client_id.as_ref() {
            let _ = write!(
                target,
                "&client_id={}",
                urlencoding::encode(client_id.as_str())
            );
        }
        if let Some(state) = state {
            let _ = write!(target, "&state={}", urlencoding::encode(state));
        }
        Redirect::to(&target).into_response()
    } else {
        let response_data = WebAuthnCompleteResponse {
            authorization_code: authorization_code.to_string(),
            state: state.unwrap_or("").to_string(),
            redirect_uri: redirect_uri.to_string(),
            client_id: params
                .client_id
                .clone()
                .unwrap_or_else(|| ClientId::new("")),
        };
        Json(response_data).into_response()
    }
}