//! systemprompt-api 0.2.0
//!
//! HTTP API server and gateway for systemprompt.io OS.
use anyhow::Result;
use axum::Json;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Redirect};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

use crate::routes::oauth::extractors::OAuthRepo;
use systemprompt_identifiers::{AuthorizationCode, ClientId, UserId};
use systemprompt_oauth::OAuthState;
use systemprompt_oauth::repository::{AuthCodeParams, OAuthRepository};
use systemprompt_oauth::services::webauthn::WebAuthnManager;
use systemprompt_oauth::services::{generate_secure_token, is_browser_request};

/// Query parameters accepted by [`handle_webauthn_complete`].
///
/// `user_id` and `auth_token` identify the WebAuthn authentication being completed.
/// The remaining fields mirror OAuth authorization-request parameters. `client_id`
/// and `redirect_uri` are rejected by the handler when missing, even though they
/// are `Option`s here.
#[derive(Debug, Deserialize)]
pub struct WebAuthnCompleteQuery {
    /// User the caller claims to be; must equal the user verified via `auth_token`.
    pub user_id: UserId,
    /// Token from a completed WebAuthn authentication, passed to
    /// `consume_verified_authentication`. Required by the handler.
    pub auth_token: Option<String>,
    /// NOTE(review): not read by the handlers in this file.
    pub response_type: Option<String>,
    /// OAuth client the authorization code is issued to. Required by the handler.
    pub client_id: Option<ClientId>,
    /// Redirect target bound to the stored code and used for browser redirects.
    /// Required by the handler.
    pub redirect_uri: Option<String>,
    /// Requested scope; when absent, a default derived from the repository's
    /// default roles is stored instead.
    pub scope: Option<String>,
    /// Opaque client state echoed back on success (empty values are not echoed).
    pub state: Option<String>,
    /// PKCE code challenge to bind to the authorization code.
    pub code_challenge: Option<String>,
    /// PKCE code challenge method accompanying `code_challenge`.
    pub code_challenge_method: Option<String>,
    /// NOTE(review): not read by the handlers in this file.
    pub response_mode: Option<String>,
    /// Resource indicator (presumably RFC 8707) stored with the code when provided.
    pub resource: Option<String>,
}

/// JSON error body returned by [`handle_webauthn_complete`], using OAuth-style
/// `error` / `error_description` fields.
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteError {
    /// Machine-readable error code such as `invalid_request`, `access_denied` or
    /// `server_error`.
    pub error: String,
    /// Human-readable explanation of the failure.
    pub error_description: String,
}

#[allow(unused_qualifications)]
pub async fn handle_webauthn_complete(
    headers: HeaderMap,
    Query(params): Query<WebAuthnCompleteQuery>,
    State(state): State<OAuthState>,
    OAuthRepo(repo): OAuthRepo,
) -> impl IntoResponse {
    let auth_token = match &params.auth_token {
        Some(token) => token.clone(),
        None => {
            return (
                StatusCode::BAD_REQUEST,
                Json(WebAuthnCompleteError {
                    error: "invalid_request".to_string(),
                    error_description: "Missing auth_token parameter".to_string(),
                }),
            )
                .into_response();
        },
    };

    let user_provider = state.user_provider();
    let webauthn_service =
        match WebAuthnManager::get_or_create_service(repo.clone(), Arc::clone(user_provider)).await
        {
            Ok(service) => service,
            Err(e) => {
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(WebAuthnCompleteError {
                        error: "server_error".to_string(),
                        error_description: format!("WebAuthn service initialization failed: {e}"),
                    }),
                )
                    .into_response();
            },
        };

    let Ok(verified_user_id) = webauthn_service
        .consume_verified_authentication(&auth_token)
        .await
    else {
        return (
            StatusCode::UNAUTHORIZED,
            Json(WebAuthnCompleteError {
                error: "access_denied".to_string(),
                error_description: "Invalid or expired authentication token".to_string(),
            }),
        )
            .into_response();
    };

    if params.user_id != verified_user_id {
        tracing::warn!(
            claimed_user_id = %params.user_id,
            verified_user_id = %verified_user_id,
            "WebAuthn complete user_id mismatch"
        );
        return (
            StatusCode::UNAUTHORIZED,
            Json(WebAuthnCompleteError {
                error: "access_denied".to_string(),
                error_description: "User identity verification failed".to_string(),
            }),
        )
            .into_response();
    }

    if params.client_id.is_none() {
        return (
            StatusCode::BAD_REQUEST,
            Json(WebAuthnCompleteError {
                error: "invalid_request".to_string(),
                error_description: "Missing client_id parameter".to_string(),
            }),
        )
            .into_response();
    }

    let Some(redirect_uri) = &params.redirect_uri else {
        return (
            StatusCode::BAD_REQUEST,
            Json(WebAuthnCompleteError {
                error: "invalid_request".to_string(),
                error_description: "Missing redirect_uri parameter".to_string(),
            }),
        )
            .into_response();
    };

    match user_provider.find_by_id(&verified_user_id).await {
        Ok(Some(_)) => {
            let authorization_code = generate_secure_token("auth_code");

            match store_authorization_code(&repo, &authorization_code, &params).await {
                Ok(()) => {
                    create_successful_response(&headers, redirect_uri, &authorization_code, &params)
                },
                Err(error) => (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(WebAuthnCompleteError {
                        error: "server_error".to_string(),
                        error_description: error.to_string(),
                    }),
                )
                    .into_response(),
            }
        },
        Ok(None) => (
            StatusCode::UNAUTHORIZED,
            Json(WebAuthnCompleteError {
                error: "access_denied".to_string(),
                error_description: "User not found".to_string(),
            }),
        )
            .into_response(),
        Err(error) => {
            let status_code = if error.to_string().contains("User not found") {
                StatusCode::UNAUTHORIZED
            } else {
                StatusCode::INTERNAL_SERVER_ERROR
            };

            let error_type = if status_code == StatusCode::UNAUTHORIZED {
                "access_denied"
            } else {
                "server_error"
            };

            (
                status_code,
                Json(WebAuthnCompleteError {
                    error: error_type.to_string(),
                    error_description: error.to_string(),
                }),
            )
                .into_response()
        },
    }
}

/// Persists `code_str` as an authorization code bound to the request's client, user,
/// redirect URI, scope and, when supplied, PKCE challenge and resource.
///
/// Empty `scope`, `code_challenge` and `resource` values are treated as absent. An
/// absent scope falls back to the repository's default roles joined by spaces, or
/// `"user"` when there are none. When a PKCE challenge is present without a method,
/// the method defaults to `"plain"` as specified by RFC 7636 §4.3. It is not silently
/// dropped, because dropping it would let the code be redeemed without a verifier.
///
/// # Errors
/// Fails if `client_id` or `redirect_uri` is missing, or if the repository cannot
/// store the code.
async fn store_authorization_code(
    repo: &OAuthRepository,
    code_str: &str,
    query: &WebAuthnCompleteQuery,
) -> Result<()> {
    let client_id = query
        .client_id
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("client_id is required"))?;
    let redirect_uri = query
        .redirect_uri
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("redirect_uri is required"))?;

    let scope = match query.scope.as_deref().filter(|s| !s.trim().is_empty()) {
        Some(requested) => requested.to_string(),
        None => {
            let default_roles = OAuthRepository::get_default_roles();
            if default_roles.is_empty() {
                "user".to_string()
            } else {
                default_roles.join(" ")
            }
        },
    };

    let code = AuthorizationCode::new(code_str);

    let mut builder =
        AuthCodeParams::builder(&code, client_id, &query.user_id, redirect_uri, &scope);

    // An empty challenge carries no PKCE binding (e.g. a frontend forwarding blank
    // params), so only a non-empty challenge enables PKCE.
    if let Some(challenge) = query.code_challenge.as_deref().filter(|s| !s.is_empty()) {
        let method = query
            .code_challenge_method
            .as_deref()
            .filter(|s| !s.is_empty())
            .unwrap_or("plain");
        builder = builder.with_pkce(challenge, method);
    }

    if let Some(resource) = query.resource.as_deref().filter(|s| !s.is_empty()) {
        builder = builder.with_resource(resource);
    }

    repo.store_authorization_code(builder.build()).await
}

/// JSON success body returned by `create_successful_response` to clients that are
/// not treated as browsers (browsers are redirected instead).
#[derive(Debug, Serialize)]
pub struct WebAuthnCompleteResponse {
    /// Newly issued authorization code.
    pub authorization_code: String,
    /// Echoed client `state`, or an empty string when none (or an empty one) was sent.
    pub state: String,
    /// Redirect URI the code is bound to.
    pub redirect_uri: String,
    /// Client the code was issued to; an empty id if the request carried none.
    pub client_id: ClientId,
}

/// Builds the success response once an authorization code has been stored.
///
/// For browser requests (per `is_browser_request`), this redirects to `redirect_uri`.
/// The redirect appends `code`, `client_id` (when present) and a non-empty `state`
/// as percent-encoded query parameters. Any query component already present on
/// `redirect_uri` is preserved, as RFC 6749 §3.1.2 requires.
///
/// Other clients receive a JSON [`WebAuthnCompleteResponse`].
fn create_successful_response(
    headers: &HeaderMap,
    redirect_uri: &str,
    authorization_code: &str,
    params: &WebAuthnCompleteQuery,
) -> axum::response::Response {
    let state = params.state.as_deref().filter(|s| !s.is_empty());

    if is_browser_request(headers) {
        let mut query = vec![format!("code={}", urlencoding::encode(authorization_code))];

        if let Some(client_id_val) = params.client_id.as_ref() {
            query.push(format!(
                "client_id={}",
                urlencoding::encode(client_id_val.as_str())
            ));
        }

        if let Some(state_val) = state {
            query.push(format!("state={}", urlencoding::encode(state_val)));
        }

        let target = format!(
            "{redirect_uri}{}{}",
            redirect_query_separator(redirect_uri),
            query.join("&")
        );
        Redirect::to(&target).into_response()
    } else {
        let response_data = WebAuthnCompleteResponse {
            authorization_code: authorization_code.to_string(),
            state: state.unwrap_or("").to_string(),
            redirect_uri: redirect_uri.to_string(),
            client_id: params
                .client_id
                .clone()
                .unwrap_or_else(|| ClientId::new("")),
        };

        Json(response_data).into_response()
    }
}

/// Returns the text needed between `uri` and appended query parameters:
/// - `"?"` when `uri` has no query yet;
/// - `""` when it already ends in `?` or `&`;
/// - `"&"` otherwise.
fn redirect_query_separator(uri: &str) -> &'static str {
    if !uri.contains('?') {
        "?"
    } else if uri.ends_with('?') || uri.ends_with('&') {
        ""
    } else {
        "&"
    }
}