systemprompt-api 0.1.18

HTTP API server and gateway for systemprompt.io OS
Documentation
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use std::sync::Arc;
use systemprompt_security::HeaderExtractor;
use tracing::Instrument;

use super::extractors::ContextExtractor;
use super::requirements::ContextRequirement;
use systemprompt_identifiers::{AgentName, ContextId, TraceId};
use systemprompt_models::api::ApiError;
use systemprompt_models::execution::context::{ContextExtractionError, RequestContext};

#[derive(Debug, Clone)]
pub struct ContextMiddleware<E> {
    extractor: Arc<E>,
    auth_level: ContextRequirement,
}

impl<E> ContextMiddleware<E> {
    pub fn new(extractor: E) -> Self {
        Self {
            extractor: Arc::new(extractor),
            auth_level: ContextRequirement::default(),
        }
    }

    pub fn public(extractor: E) -> Self {
        Self {
            extractor: Arc::new(extractor),
            auth_level: ContextRequirement::None,
        }
    }

    pub fn user_only(extractor: E) -> Self {
        Self {
            extractor: Arc::new(extractor),
            auth_level: ContextRequirement::UserOnly,
        }
    }

    pub fn full(extractor: E) -> Self {
        Self {
            extractor: Arc::new(extractor),
            auth_level: ContextRequirement::UserWithContext,
        }
    }

    pub fn mcp(extractor: E) -> Self {
        Self {
            extractor: Arc::new(extractor),
            auth_level: ContextRequirement::McpWithHeaders,
        }
    }

    fn error_to_api_error(error: &ContextExtractionError) -> ApiError {
        match error {
            ContextExtractionError::MissingAuthHeader => {
                ApiError::unauthorized("Missing Authorization header")
            },
            ContextExtractionError::InvalidToken(_) => {
                ApiError::unauthorized("Invalid or expired JWT token")
            },
            ContextExtractionError::UserNotFound(_) => {
                ApiError::unauthorized("User no longer exists")
            },
            ContextExtractionError::MissingSessionId => {
                ApiError::bad_request("JWT missing required 'session_id' claim")
            },
            ContextExtractionError::MissingUserId => {
                ApiError::bad_request("JWT missing required 'sub' claim")
            },
            ContextExtractionError::MissingContextId => ApiError::bad_request(
                "Missing required 'x-context-id' header (for MCP routes) or contextId in body \
                 (for A2A routes)",
            ),
            ContextExtractionError::MissingHeader(header) => {
                ApiError::bad_request(format!("Missing required header: {header}"))
            },
            ContextExtractionError::InvalidHeaderValue { header, reason } => {
                ApiError::bad_request(format!("Invalid header {header}: {reason}"))
            },
            ContextExtractionError::InvalidUserId(reason) => {
                ApiError::bad_request(format!("Invalid user_id: {reason}"))
            },
            ContextExtractionError::DatabaseError(_) => {
                ApiError::internal_error("Internal server error")
            },
            ContextExtractionError::ForbiddenHeader { header, reason } => {
                ApiError::bad_request(format!(
                    "Header '{header}' is not allowed: {reason}. Use JWT authentication instead."
                ))
            },
        }
    }

    fn log_error_response(
        error: &ContextExtractionError,
        trace_id: &TraceId,
        path: &str,
        method: &str,
    ) -> Response {
        let _span = tracing::error_span!(
            "context_extraction_error",
            trace_id = %trace_id,
            path = %path,
            method = %method,
        )
        .entered();

        match error {
            ContextExtractionError::DatabaseError(e) => {
                tracing::error!(
                    error = %e,
                    error_type = "database",
                    "Context extraction failed due to database error"
                );
            },
            ContextExtractionError::InvalidToken(reason) => {
                tracing::warn!(
                    reason = %reason,
                    error_type = "invalid_token",
                    "Context extraction failed: invalid token"
                );
            },
            ContextExtractionError::UserNotFound(user_id) => {
                tracing::warn!(
                    user_id = %user_id,
                    error_type = "user_not_found",
                    "Context extraction failed: user not found"
                );
            },
            _ => {
                tracing::warn!(
                    error = %error,
                    error_type = "context_extraction",
                    "Context extraction failed"
                );
            },
        }

        Self::error_to_api_error(error)
            .with_trace_id(trace_id.as_str())
            .with_path(path)
            .into_response()
    }
}

fn create_request_span(ctx: &RequestContext) -> tracing::Span {
    tracing::info_span!(
        "request",
        user_id = %ctx.user_id(),
        session_id = %ctx.session_id(),
        trace_id = %ctx.trace_id(),
        context_id = %ctx.context_id(),
    )
}

impl<E: ContextExtractor> ContextMiddleware<E> {
    pub async fn handle(&self, request: Request, next: Next) -> Response {
        let requirement = request
            .extensions()
            .get::<ContextRequirement>()
            .copied()
            .unwrap_or(self.auth_level);

        if request.extensions().get::<RequestContext>().is_some()
            && self.auth_level == ContextRequirement::None
        {
            return next.run(request).await;
        }

        match requirement {
            ContextRequirement::None => self.handle_none_requirement(request, next).await,
            ContextRequirement::UserOnly => self.handle_user_only(request, next).await,
            ContextRequirement::UserWithContext => {
                self.handle_user_with_context(request, next).await
            },
            ContextRequirement::McpWithHeaders => self.handle_mcp_with_headers(request, next).await,
        }
    }

    async fn handle_none_requirement(&self, mut request: Request, next: Next) -> Response {
        let headers = request.headers();
        let mut req_ctx = if let Some(ctx) = request.extensions().get::<RequestContext>() {
            ctx.clone()
        } else {
            return ApiError::internal_error(
                "Middleware configuration error: SessionMiddleware must run before \
                 ContextMiddleware",
            )
            .into_response();
        };

        if let Some(context_id) = headers.get("x-context-id") {
            if let Ok(id) = context_id.to_str() {
                req_ctx.execution.context_id = ContextId::new(id.to_string());
            }
        }

        if let Some(agent_name) = headers.get("x-agent-name") {
            if let Ok(name) = agent_name.to_str() {
                req_ctx.execution.agent_name = AgentName::new(name.to_string());
            }
        }

        let span = create_request_span(&req_ctx);
        request.extensions_mut().insert(req_ctx);
        next.run(request).instrument(span).await
    }

    async fn handle_user_only(&self, mut request: Request, next: Next) -> Response {
        let trace_id = HeaderExtractor::extract_trace_id(request.headers());
        let path = request.uri().path().to_string();
        let method = request.method().to_string();

        match self.extractor.extract_user_only(request.headers()).await {
            Ok(context) => {
                let span = create_request_span(&context);
                request.extensions_mut().insert(context);
                next.run(request).instrument(span).await
            },
            Err(e) => Self::log_error_response(&e, &trace_id, &path, &method),
        }
    }

    async fn handle_user_with_context(&self, request: Request, next: Next) -> Response {
        let trace_id = HeaderExtractor::extract_trace_id(request.headers());
        let path = request.uri().path().to_string();
        let method = request.method().to_string();

        match self.extractor.extract_from_request(request).await {
            Ok((context, reconstructed_request)) => {
                let span = create_request_span(&context);
                let mut req = reconstructed_request;
                req.extensions_mut().insert(context);
                next.run(req).instrument(span).await
            },
            Err(e) => Self::log_error_response(&e, &trace_id, &path, &method),
        }
    }

    async fn handle_mcp_with_headers(&self, request: Request, next: Next) -> Response {
        let trace_id = HeaderExtractor::extract_trace_id(request.headers());
        let path = request.uri().path().to_string();
        let method = request.method().to_string();

        match self.extractor.extract_from_headers(request.headers()).await {
            Ok(context) => {
                let span = create_request_span(&context);
                let mut req = request;
                req.extensions_mut().insert(context);
                next.run(req).instrument(span).await
            },
            Err(e) => {
                let fallback_ctx = request.extensions().get::<RequestContext>().cloned();
                #[allow(clippy::single_match_else)]
                match fallback_ctx {
                    Some(ctx) => {
                        tracing::debug!(
                            error = %e,
                            trace_id = %trace_id,
                            "MCP header extraction failed, using session context"
                        );
                        let span = create_request_span(&ctx);
                        next.run(request).instrument(span).await
                    },
                    None => {
                        tracing::error!(
                            trace_id = %trace_id,
                            path = %path,
                            method = %method,
                            "Middleware configuration error: SessionMiddleware must run before ContextMiddleware"
                        );
                        ApiError::internal_error("Middleware configuration error")
                            .with_trace_id(trace_id.as_str())
                            .with_path(&path)
                            .into_response()
                    },
                }
            },
        }
    }
}