// systemprompt-api 0.2.0
//
// HTTP API server and gateway for systemprompt.io OS
// Documentation
use axum::Json;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use systemprompt_identifiers::{ChallengeId, UserId};
use systemprompt_oauth::OAuthState;
use systemprompt_oauth::services::webauthn::{FinishRegistrationParams, WebAuthnManager};
use tracing::instrument;
use webauthn_rs::prelude::*;

use crate::routes::oauth::extractors::OAuthRepo;

use super::RegisterError;

/// JSON body for the final step of passkey (WebAuthn) registration.
///
/// The handler checks it with `FinishRegisterRequest::validate` before any WebAuthn work is done.
#[derive(Debug, Deserialize)]
pub struct FinishRegisterRequest {
    /// Identifier of the registration challenge being completed; must not be blank.
    pub challenge_id: ChallengeId,
    /// Requested username; must be non-blank, length-limited, and contain only alphanumeric
    /// characters, `_` or `-` (see `validate`).
    pub username: String,
    /// Email address for the new account; must pass the crate's email validation.
    pub email: String,
    /// Optional display name, forwarded to the registration parameters when present.
    pub full_name: Option<String>,
    /// Passkey registration credential returned by the client.
    pub credential: RegisterPublicKeyCredential,
    /// Optional existing session ID. After a successful registration, data belonging to this
    /// session's user is migrated to the new user via the analytics provider (best effort).
    /// A missing field deserializes as `None`.
    #[serde(default)]
    pub session_id: Option<String>,
}

impl FinishRegisterRequest {
    /// Maximum username length, counted in Unicode scalar values (`char`s), not bytes.
    const MAX_USERNAME_CHARS: usize = 50;

    /// Checks the request fields before any WebAuthn work is attempted.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message, used as the `error_description` of a
    /// `400 invalid_request` response, for the first rule that fails:
    /// - `challenge_id` is blank;
    /// - `username` is blank;
    /// - `username` is longer than `MAX_USERNAME_CHARS` characters;
    /// - `username` contains anything other than alphanumeric characters, `_` or `-`;
    /// - `email` is rejected by `crate::services::validation::is_valid_email`.
    fn validate(&self) -> Result<(), String> {
        if self.challenge_id.as_str().trim().is_empty() {
            return Err("Challenge ID is required".to_string());
        }

        if self.username.trim().is_empty() {
            return Err("Username is required and cannot be empty".to_string());
        }

        // Count characters rather than bytes: the charset check below accepts non-ASCII
        // alphanumerics, which occupy several bytes in UTF-8, so a byte limit would reject
        // such usernames far below the advertised character limit.
        if self.username.chars().count() > Self::MAX_USERNAME_CHARS {
            return Err(format!(
                "Username must be at most {} characters",
                Self::MAX_USERNAME_CHARS
            ));
        }

        // NOTE(review): `is_alphanumeric` accepts any Unicode letter/digit, which permits
        // look-alike (homoglyph) usernames; confirm this is the intended policy.
        if !self
            .username
            .chars()
            .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
        {
            return Err(
                "Username can only contain letters, numbers, underscores, and hyphens".to_string(),
            );
        }

        if !crate::services::validation::is_valid_email(&self.email) {
            return Err("Email must be a valid email address".to_string());
        }

        Ok(())
    }
}

/// JSON body returned when passkey registration completes successfully.
#[derive(Debug, Serialize)]
pub struct FinishRegisterResponse {
    /// ID of the newly created user.
    pub user_id: UserId,
    /// Set to `true` by `finish_register`; failures are reported with `RegisterError` instead.
    pub success: bool,
}

#[instrument(skip(state, oauth_repo, request), fields(challenge_id = %request.challenge_id, username = %request.username))]
pub async fn finish_register(
    State(state): State<OAuthState>,
    OAuthRepo(oauth_repo): OAuthRepo,
    Json(request): Json<FinishRegisterRequest>,
) -> impl IntoResponse {
    if let Err(validation_error) = request.validate() {
        return (
            StatusCode::BAD_REQUEST,
            Json(RegisterError {
                error: "invalid_request".to_string(),
                error_description: validation_error,
            }),
        )
            .into_response();
    }

    let user_provider = Arc::clone(state.user_provider());

    let webauthn_service =
        match WebAuthnManager::get_or_create_service(oauth_repo, user_provider).await {
            Ok(service) => service,
            Err(e) => {
                tracing::error!(error = %e, "Failed to initialize WebAuthn");
                return (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(RegisterError {
                        error: "server_error".to_string(),
                        error_description: format!("Failed to initialize WebAuthn: {e}"),
                    }),
                )
                    .into_response();
            },
        };

    let mut builder = FinishRegistrationParams::builder(
        request.challenge_id.as_str(),
        &request.username,
        &request.email,
        &request.credential,
    );
    if let Some(ref name) = request.full_name {
        builder = builder.with_full_name(name);
    }

    match webauthn_service.finish_registration(builder.build()).await {
        Ok(user_id) => {
            if let Some(publisher) = state.event_publisher() {
                publisher.publish_user_event(systemprompt_traits::UserEvent::UserCreated {
                    user_id: user_id.as_str().to_string(),
                });
            }

            if let Some(session_id_str) = &request.session_id {
                use systemprompt_identifiers::SessionId;

                let session_id = SessionId::new(session_id_str.clone());
                let analytics_provider = state.analytics_provider();

                match analytics_provider.find_session_by_id(&session_id).await {
                    Ok(Some(session)) => {
                        if let Some(old_user_id) = session.user_id {
                            let new_user_id = user_id.clone();

                            match analytics_provider
                                .migrate_user_sessions(&old_user_id, &new_user_id)
                                .await
                            {
                                Ok(count) => {
                                    tracing::info!(
                                        session_id = %session_id,
                                        old_user_id = %old_user_id,
                                        new_user_id = %new_user_id,
                                        records_migrated = count,
                                        "Successfully migrated user data"
                                    );
                                },
                                Err(e) => {
                                    tracing::error!(
                                        error = %e,
                                        session_id = %session_id,
                                        old_user_id = %old_user_id,
                                        new_user_id = %new_user_id,
                                        "Failed to migrate session"
                                    );
                                },
                            }
                        }
                    },
                    Ok(None) => {
                        tracing::warn!(session_id = %session_id, "Session not found for migration");
                    },
                    Err(e) => {
                        tracing::error!(
                            error = %e,
                            session_id = %session_id,
                            "Failed to retrieve session for migration"
                        );
                    },
                }
            }

            (
                StatusCode::OK,
                Json(FinishRegisterResponse {
                    user_id,
                    success: true,
                }),
            )
                .into_response()
        },
        Err(e) => {
            let error_msg = e.to_string();
            let (status, error_code, description) = if error_msg.contains("username_already_taken")
            {
                (
                    StatusCode::CONFLICT,
                    "username_unavailable",
                    "Username is already taken. Please choose a different username.".to_string(),
                )
            } else if error_msg.contains("email_already_registered") {
                (
                    StatusCode::CONFLICT,
                    "email_exists",
                    "An account with this email already exists.".to_string(),
                )
            } else if error_msg.contains("Registration state not found") {
                (
                    StatusCode::BAD_REQUEST,
                    "expired_challenge",
                    "Registration challenge has expired. Please start the registration process \
                     again."
                        .to_string(),
                )
            } else if error_msg.contains("finish_passkey_registration")
                || error_msg.contains("verification")
                || error_msg.contains("attestation")
            {
                (
                    StatusCode::BAD_REQUEST,
                    "invalid_credential",
                    "WebAuthn verification failed. Please ensure your authenticator and browser \
                     are compatible."
                        .to_string(),
                )
            } else {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "registration_failed",
                    format!("Registration failed: {error_msg}"),
                )
            };

            (
                status,
                Json(RegisterError {
                    error: error_code.to_string(),
                    error_description: description,
                }),
            )
                .into_response()
        },
    }
}