systemprompt-api 0.1.18

HTTP API server and gateway for systemprompt.io OS
Documentation
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use std::sync::Arc;

use systemprompt_analytics::{SessionRepository, ThrottleLevel};
use systemprompt_database::DbPool;
use systemprompt_models::RequestContext;
use systemprompt_models::api::{ApiError, ErrorCode};

/// Middleware state used to throttle requests on a per-session basis.
///
/// Wraps a shared [`SessionRepository`], which is queried for the throttle
/// level of the session attached to each incoming request.
#[derive(Debug, Clone)]
pub struct ThrottleMiddleware {
    // Held behind an `Arc` so that cloning the middleware shares one
    // repository handle instead of requiring the repository itself to be cloned.
    session_repo: Arc<SessionRepository>,
}

impl ThrottleMiddleware {
    pub fn new(db_pool: &DbPool) -> anyhow::Result<Self> {
        Ok(Self {
            session_repo: Arc::new(SessionRepository::new(db_pool)?),
        })
    }

    pub async fn check_throttle(&self, request: Request, next: Next) -> Result<Response, ApiError> {
        let Some(req_ctx) = request.extensions().get::<RequestContext>().cloned() else {
            return Ok(next.run(request).await);
        };

        if !req_ctx.request.is_tracked {
            return Ok(next.run(request).await);
        }

        let throttle_level = self
            .session_repo
            .get_throttle_level(&req_ctx.request.session_id)
            .await
            .unwrap_or_else(|e| {
                tracing::warn!(error = %e, session_id = %req_ctx.request.session_id, "Failed to get throttle level");
                0
            });

        let level = ThrottleLevel::from(throttle_level);

        if !level.allows_requests() {
            let api_error = ApiError::new(
                ErrorCode::RateLimited,
                "Request blocked due to suspicious activity",
            );
            let mut response = api_error.into_response();
            response
                .headers_mut()
                .insert("Retry-After", http::HeaderValue::from_static("3600"));
            response.headers_mut().insert(
                "X-Throttle-Level",
                http::HeaderValue::from_static("blocked"),
            );
            response.headers_mut().insert(
                "X-Throttle-Reason",
                http::HeaderValue::from_static("behavioral_bot_detection"),
            );
            return Ok(response);
        }

        let mut response = next.run(request).await;

        if throttle_level > 0 {
            let level_str = match throttle_level {
                1 => "warning",
                2 => "severe",
                _ => "unknown",
            };

            if let Ok(header_value) = level_str.parse() {
                response
                    .headers_mut()
                    .insert("X-Throttle-Level", header_value);
            }

            let multiplier = level.rate_multiplier();
            if let Ok(header_value) = format!("{multiplier}").parse() {
                response
                    .headers_mut()
                    .insert("X-Rate-Multiplier", header_value);
            }
        }

        Ok(response)
    }
}

/// Free-function entry point that delegates to
/// [`ThrottleMiddleware::check_throttle`].
///
/// Takes the middleware from axum `State`, plus the request and the `Next`
/// handler, and returns whatever `check_throttle` returns.
// NOTE(review): the (State, Request, Next) signature looks intended for
// `axum::middleware::from_fn_with_state` — confirm at the registration site.
pub async fn check_throttle_level(
    axum::extract::State(throttle): axum::extract::State<ThrottleMiddleware>,
    request: Request,
    next: Next,
) -> Result<Response, ApiError> {
    throttle.check_throttle(request, next).await
}