systemprompt-api 0.1.18

HTTP API server and gateway for systemprompt.io OS
Documentation
use axum::Json;
use axum::extract::{Query, State};
use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
use axum::response::IntoResponse;
use serde::Deserialize;
use std::sync::Arc;
use systemprompt_oauth::OAuthState;
use systemprompt_oauth::repository::OAuthRepository;
use systemprompt_oauth::services::webauthn::WebAuthnManager;
use tracing::instrument;

use super::RegisterError;

/// Query-string parameters accepted by the `start_register` handler.
///
/// The handler calls `validate` on this value before doing any other work.
#[derive(Debug, Deserialize)]
pub struct StartRegisterQuery {
    /// Requested username; must pass the checks in `validate`.
    pub username: String,
    /// Email address supplied for the registration; must pass the checks in `validate`.
    pub email: String,
    /// Optional full name, forwarded as-is to the WebAuthn registration call.
    pub full_name: Option<String>,
}

impl StartRegisterQuery {
    /// Maximum accepted username length, counted in characters (not bytes).
    const MAX_USERNAME_CHARS: usize = 50;

    /// Checks the username and email supplied by the client.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the first rule that failed:
    /// - username is empty or whitespace-only,
    /// - username is longer than [`Self::MAX_USERNAME_CHARS`] characters,
    /// - username contains anything other than alphanumerics, `_` or `-`,
    /// - email is empty or whitespace-only,
    /// - email is not of the shape `local@domain.tld`.
    fn validate(&self) -> Result<(), String> {
        if self.username.trim().is_empty() {
            return Err("Username is required and cannot be empty".to_string());
        }

        // Count characters rather than bytes: `is_alphanumeric` below admits
        // multi-byte Unicode letters, and the limit is expressed in characters.
        if self.username.chars().count() > Self::MAX_USERNAME_CHARS {
            return Err(format!(
                "Username must be at most {} characters",
                Self::MAX_USERNAME_CHARS
            ));
        }

        if !self
            .username
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
        {
            return Err(
                "Username can only contain letters, numbers, underscores, and hyphens".to_string(),
            );
        }

        if self.email.trim().is_empty() {
            return Err("Email is required and cannot be empty".to_string());
        }

        if !Self::is_plausible_email(&self.email) {
            return Err("Email must be a valid email address".to_string());
        }

        Ok(())
    }

    /// Structural sanity check for an email address.
    ///
    /// This is deliberately not a full RFC 5322 parser; it only rejects input
    /// that cannot be an ordinary `local@domain.tld` address: whitespace
    /// anywhere, a missing or repeated `@`, an empty local part, or a domain
    /// without a dot or with an empty label (leading, trailing or doubled dot).
    fn is_plausible_email(email: &str) -> bool {
        if email.chars().any(char::is_whitespace) {
            return false;
        }

        match email.split_once('@') {
            Some((local, domain)) => {
                !local.is_empty()
                    && !domain.contains('@')
                    && domain.contains('.')
                    && domain.split('.').all(|label| !label.is_empty())
            },
            None => false,
        }
    }
}

#[allow(unused_qualifications)]
#[instrument(skip(state, params), fields(username = %params.username, email = %params.email))]
pub async fn start_register(
    Query(params): Query<StartRegisterQuery>,
    State(state): State<OAuthState>,
) -> impl IntoResponse {
    if let Err(validation_error) = params.validate() {
        return (
            StatusCode::BAD_REQUEST,
            Json(RegisterError {
                error: "invalid_request".to_string(),
                error_description: validation_error,
            }),
        )
            .into_response();
    }

    let oauth_repo = match OAuthRepository::new(state.db_pool()) {
        Ok(r) => r,
        Err(e) => {
            return (
                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
                axum::Json(serde_json::json!({"error": "server_error", "error_description": format!("Repository initialization failed: {}", e)})),
            ).into_response();
        },
    };
    let user_provider = Arc::clone(state.user_provider());

    let webauthn_service =
        match WebAuthnManager::get_or_create_service(oauth_repo, user_provider).await {
            Ok(service) => service,
            Err(e) => {
                tracing::error!(error = %e, "Failed to initialize WebAuthn");
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(RegisterError {
                        error: "server_error".to_string(),
                        error_description: format!("Failed to initialize WebAuthn: {e}"),
                    }),
                )
                    .into_response();
            },
        };

    match webauthn_service
        .start_registration(&params.username, &params.email, params.full_name.as_deref())
        .await
    {
        Ok((challenge, challenge_id)) => {
            let mut challenge_json = match serde_json::to_value(&challenge) {
                Ok(json) => json,
                Err(e) => {
                    return (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(RegisterError {
                            error: "server_error".to_string(),
                            error_description: format!("Failed to serialize challenge: {e}"),
                        }),
                    )
                        .into_response();
                },
            };

            if let Some(public_key) = challenge_json.get_mut("publicKey") {
                if let Some(authenticator_selection) = public_key.get_mut("authenticatorSelection")
                {
                    if let Some(obj) = authenticator_selection.as_object_mut() {
                        obj.remove("authenticatorAttachment");
                    }
                }
            }

            let mut headers = HeaderMap::new();
            let header_value = HeaderValue::from_str(&challenge_id).map_err(|e| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(RegisterError {
                        error: "server_error".to_string(),
                        error_description: format!("Invalid challenge ID format: {e}"),
                    }),
                )
                    .into_response()
            });

            match header_value {
                Ok(val) => {
                    headers.insert(HeaderName::from_static("x-challenge-id"), val);
                    (StatusCode::OK, headers, Json(challenge_json)).into_response()
                },
                Err(response) => response,
            }
        },
        Err(e) => (
            StatusCode::BAD_REQUEST,
            Json(RegisterError {
                error: "registration_failed".to_string(),
                error_description: e.to_string(),
            }),
        )
            .into_response(),
    }
}