systemprompt-api 0.1.18

HTTP API server and gateway for systemprompt.io OS
Documentation
use axum::body::{Body, to_bytes};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use bytes::Bytes;
use futures_util::TryStreamExt;
use reqwest::Method;
use std::pin::Pin;
use std::str::FromStr;
use std::task::{Context, Poll};
use std::time::Duration;
use systemprompt_models::RequestContext;
use systemprompt_models::api::{ApiError, ErrorCode};
use systemprompt_traits::InjectContextHeaders;
use thiserror::Error;
use tokio::time::{Instant, Sleep};

/// Errors produced while proxying a request to a backend service.
///
/// Each variant maps to an HTTP status through [`ProxyError::to_status_code`]. The `Display`
/// text generated from the `#[error]` attributes is used as the message when the error is
/// rendered as an API error response.
#[derive(Debug, Error)]
pub enum ProxyError {
    /// The requested service is not present in the service inventory.
    #[error("Service '{service}' not found in inventory")]
    ServiceNotFound { service: String },

    /// The service is known but is not running; `status` holds its reported status.
    #[error("Service '{service}' is not running (status: {status})")]
    ServiceNotRunning { service: String, status: String },

    /// A `reqwest` error occurred while contacting the backend at `url`.
    #[error("Failed to connect to {service} at {url}: {source}")]
    ConnectionFailed {
        service: String,
        url: String,
        #[source]
        source: reqwest::Error,
    },

    /// The request to the backend timed out.
    #[error("Request to {service} timed out")]
    Timeout { service: String },

    /// The backend's response was rejected as invalid; `reason` describes why.
    #[error("Invalid response from {service}: {reason}")]
    InvalidResponse { service: String, reason: String },

    /// The backend URL for `service` could not be built.
    #[error("Failed to build URL for {service}: {reason}")]
    UrlConstructionFailed { service: String, reason: String },

    /// Reading the incoming request body failed.
    #[error("Failed to extract request body: {source}")]
    BodyExtractionFailed {
        #[source]
        source: axum::Error,
    },

    /// The incoming request's HTTP method could not be parsed.
    #[error("Invalid HTTP method: {reason}")]
    InvalidMethod { reason: String },

    /// Looking up `service` in the database failed.
    #[error("Database error when looking up service '{service}': {source}")]
    DatabaseError {
        service: String,
        #[source]
        source: anyhow::Error,
    },

    /// Authentication is required to access `service`.
    #[error("Authentication required for service '{service}'")]
    AuthenticationRequired { service: String },

    /// A pre-built OAuth challenge response. Its own status is reported as the status code,
    /// and the response itself is sent to the client unchanged.
    #[error("OAuth challenge response")]
    AuthChallenge(Response<Body>),

    /// Access to `service` is forbidden for the caller.
    #[error("Access forbidden for service '{service}'")]
    Forbidden { service: String },

    /// The request context needed for proxying is missing (reported as 401 Unauthorized).
    #[error("Missing request context: {message}")]
    MissingContext { message: String },
}

impl ProxyError {
    /// Returns the HTTP status code this error is reported with.
    ///
    /// Client-side problems map to 4xx, backend availability problems to 502/503/504, and
    /// internal failures to 500. An [`ProxyError::AuthChallenge`] reports the status of the
    /// response it carries.
    pub fn to_status_code(&self) -> StatusCode {
        match self {
            // The challenge already knows which status it must be sent with.
            Self::AuthChallenge(challenge) => challenge.status(),

            // 4xx: the request itself cannot be served.
            Self::ServiceNotFound { .. } => StatusCode::NOT_FOUND,
            Self::Forbidden { .. } => StatusCode::FORBIDDEN,
            Self::MissingContext { .. } | Self::AuthenticationRequired { .. } => {
                StatusCode::UNAUTHORIZED
            },
            Self::InvalidMethod { .. } | Self::BodyExtractionFailed { .. } => {
                StatusCode::BAD_REQUEST
            },

            // 5xx: the backend is unavailable or misbehaving.
            Self::ServiceNotRunning { .. } => StatusCode::SERVICE_UNAVAILABLE,
            Self::Timeout { .. } => StatusCode::GATEWAY_TIMEOUT,
            Self::InvalidResponse { .. } | Self::ConnectionFailed { .. } => {
                StatusCode::BAD_GATEWAY
            },

            // 5xx: the proxy itself failed.
            Self::DatabaseError { .. } | Self::UrlConstructionFailed { .. } => {
                StatusCode::INTERNAL_SERVER_ERROR
            },
        }
    }
}

impl From<ProxyError> for StatusCode {
    fn from(error: ProxyError) -> Self {
        error.to_status_code()
    }
}

impl IntoResponse for ProxyError {
    fn into_response(self) -> Response {
        match self {
            Self::AuthChallenge(response) => response.into_response(),
            ref error => {
                let status = error.to_status_code();
                let error_type = match &self {
                    Self::ServiceNotFound { .. } => "service_not_found",
                    Self::ServiceNotRunning { .. } => "service_not_running",
                    Self::ConnectionFailed { .. } => "connection_failed",
                    Self::Timeout { .. } => "timeout",
                    Self::InvalidResponse { .. } => "invalid_response",
                    Self::UrlConstructionFailed { .. } => "url_construction_failed",
                    Self::BodyExtractionFailed { .. } => "body_extraction_failed",
                    Self::InvalidMethod { .. } => "invalid_method",
                    Self::DatabaseError { .. } => "database_error",
                    Self::AuthenticationRequired { .. } => "authentication_required",
                    Self::AuthChallenge(_) => "auth_challenge",
                    Self::Forbidden { .. } => "forbidden",
                    Self::MissingContext { .. } => "missing_context",
                };

                if status.is_server_error() {
                    tracing::error!(
                        error_type = %error_type,
                        status_code = %status.as_u16(),
                        error = %self,
                        "Proxy server error"
                    );
                } else if status.is_client_error() {
                    tracing::warn!(
                        error_type = %error_type,
                        status_code = %status.as_u16(),
                        error = %self,
                        "Proxy client error"
                    );
                }

                let message = self.to_string();
                let api_error = match status {
                    StatusCode::NOT_FOUND => ApiError::not_found(message),
                    StatusCode::UNAUTHORIZED => ApiError::unauthorized(message),
                    StatusCode::FORBIDDEN => ApiError::forbidden(message),
                    StatusCode::BAD_REQUEST => ApiError::bad_request(message),
                    StatusCode::SERVICE_UNAVAILABLE
                    | StatusCode::BAD_GATEWAY
                    | StatusCode::GATEWAY_TIMEOUT => {
                        ApiError::new(ErrorCode::ServiceUnavailable, message)
                    },
                    _ => ApiError::internal_error(message),
                };
                api_error.into_response()
            },
        }
    }
}

/// Adapter that injects a [`RequestContext`] into a header map.
#[derive(Debug, Clone, Copy)]
pub struct HeaderInjector;

impl HeaderInjector {
    /// Writes the headers derived from `req_ctx` into `headers`.
    ///
    /// Delegates to [`InjectContextHeaders::inject_headers`]. Which header names and values
    /// are written is defined by that trait implementation, not here.
    pub fn inject_context(headers: &mut HeaderMap, req_ctx: &RequestContext) {
        req_ctx.inject_headers(headers);
    }
}

/// Builds URLs for requests forwarded to backend services.
#[derive(Debug, Clone, Copy)]
pub struct UrlResolver;

impl UrlResolver {
    /// Builds the `http://` URL for `path` on the backend listening at `host:port`.
    ///
    /// Leading `/` characters are stripped from `path` before it is appended. For the `"mcp"`
    /// protocol, an empty path and the bare path `mcp` both resolve to `/mcp`. For any other
    /// protocol, an empty path resolves to `/`. A bare IPv6 literal such as `::1` is wrapped
    /// in brackets (`[::1]`) so that its colons are not confused with the port separator.
    pub fn build_backend_url(protocol: &str, host: &str, port: i32, path: &str) -> String {
        let clean_path = path.trim_start_matches('/');

        // RFC 3986 §3.2.2: an IPv6 literal in a URL authority must be enclosed in brackets,
        // otherwise `::1:8080` cannot be split into host and port.
        let host: std::borrow::Cow<'_, str> = if host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{host}]").into()
        } else {
            host.into()
        };

        // MCP servers expose their endpoint at `/mcp`, so the root maps there.
        let target = if protocol == "mcp" && (clean_path.is_empty() || clean_path == "mcp") {
            "mcp"
        } else {
            clean_path
        };

        format!("http://{host}:{port}/{target}")
    }

    /// Appends `query` to `url` as its query string.
    ///
    /// `None` and an empty query leave `url` unchanged. Otherwise `?{query}` is appended
    /// verbatim, so a leading `?` in `query` would be duplicated.
    pub fn append_query_params(url: String, query: Option<&str>) -> String {
        match query {
            Some(q) if !q.is_empty() => format!("{url}?{q}"),
            _ => url,
        }
    }
}

/// Helpers for turning an incoming axum request into an outgoing `reqwest` request.
#[derive(Debug, Clone, Copy)]
pub struct RequestBuilder;

impl RequestBuilder {
    /// Headers never forwarded to the backend: the proxy's own `host` and auth header, plus the
    /// hop-by-hop headers that describe only the client connection (RFC 9110 §7.6.1).
    /// `transfer-encoding` in particular must not be forwarded, because the body is re-sent
    /// from a fully buffered copy.
    const SKIPPED_HEADERS: [&'static str; 9] = [
        "host",
        "x-mcp-proxy-auth",
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];

    /// Buffers the entire request body into memory.
    ///
    /// # Errors
    /// Returns the `axum::Error` produced by [`to_bytes`] when the body cannot be read or is
    /// larger than `MAX_BODY_SIZE` (100 MiB).
    pub async fn extract_body(body: Body) -> Result<Vec<u8>, axum::Error> {
        const MAX_BODY_SIZE: usize = 100 * 1024 * 1024;

        to_bytes(body, MAX_BODY_SIZE)
            .await
            .map(|bytes| bytes.to_vec())
    }

    /// Parses an HTTP method name.
    ///
    /// # Errors
    /// Returns a message naming the offending input when `method_str` is not a valid method.
    pub fn parse_method(method_str: &str) -> Result<Method, String> {
        Method::from_str(method_str)
            .map_err(|e| format!("Invalid HTTP method '{}': {}", method_str, e))
    }

    /// Creates the outgoing request with the filtered incoming `headers` and, when non-empty,
    /// the buffered `body`.
    ///
    /// # Errors
    /// Currently never returns `Err`; the `Result` is part of the public signature.
    pub fn build_request(
        client: &reqwest::Client,
        method: Method,
        url: &str,
        headers: &HeaderMap,
        body: Vec<u8>,
    ) -> Result<reqwest::RequestBuilder, StatusCode> {
        let mut req_builder = client.request(method, url);
        req_builder = Self::add_headers(req_builder, headers);

        if !body.is_empty() {
            req_builder = req_builder.body(body);
        }

        Ok(req_builder)
    }

    /// Copies the forwardable incoming headers onto `req_builder`.
    fn add_headers(
        mut req_builder: reqwest::RequestBuilder,
        headers: &HeaderMap,
    ) -> reqwest::RequestBuilder {
        for (key, value) in headers {
            // Values that are not visible ASCII cannot be read as `&str` and are dropped.
            let Ok(value_str) = value.to_str() else {
                continue;
            };
            let key_str = key.as_str();
            if Self::should_skip_header(key_str) {
                continue;
            }
            // A bare `Bearer` scheme with no token is dropped rather than forwarded.
            let is_auth = key_str.eq_ignore_ascii_case("authorization");
            if is_auth && !Self::is_valid_auth_header(value_str) {
                continue;
            }
            req_builder = req_builder.header(key_str, value_str);
        }
        req_builder
    }

    /// Returns `true` for headers listed in `SKIPPED_HEADERS` (compared case-insensitively).
    fn should_skip_header(header_name: &str) -> bool {
        Self::SKIPPED_HEADERS
            .iter()
            .any(|skipped| header_name.eq_ignore_ascii_case(skipped))
    }

    /// Returns `false` when an `Authorization` value is just the word `Bearer` without a token.
    fn is_valid_auth_header(value: &str) -> bool {
        !value.trim().eq_ignore_ascii_case("bearer")
    }
}

/// How long an SSE stream may stay silent before a keepalive comment is injected.
const SSE_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(15);
/// SSE comment line followed by a blank line, sent to keep idle connections open.
const SSE_KEEPALIVE_PAYLOAD: &[u8] = b": keepalive\n\n";

/// Wraps an SSE byte stream and emits [`SSE_KEEPALIVE_PAYLOAD`] whenever the inner stream
/// has produced nothing for `keepalive_interval`.
///
/// Upstream chunks are not aligned to SSE events. The keepalive ends with a blank line, which
/// dispatches an event, so writing it after a partially forwarded event would split or
/// corrupt that event. Keepalives are therefore only injected at event boundaries.
struct SseKeepaliveStream<S> {
    inner: S,
    keepalive_interval: Duration,
    deadline: Pin<Box<Sleep>>,
    /// Last few bytes forwarded from `inner`, used to tell whether we are between events.
    tail: Vec<u8>,
}

impl<S> SseKeepaliveStream<S> {
    /// Length of the longest terminator pattern checked in `at_event_boundary` (`\n\r\n`).
    const TAIL_LEN: usize = 3;

    fn new(inner: S, keepalive_interval: Duration) -> Self {
        Self {
            inner,
            keepalive_interval,
            deadline: Box::pin(tokio::time::sleep(keepalive_interval)),
            tail: Vec::with_capacity(2 * Self::TAIL_LEN),
        }
    }

    /// Remembers the trailing bytes of a chunk that was forwarded downstream.
    fn remember_forwarded(&mut self, chunk: &[u8]) {
        let keep_from = chunk.len().saturating_sub(Self::TAIL_LEN);
        self.tail.extend_from_slice(&chunk[keep_from..]);
        let excess = self.tail.len().saturating_sub(Self::TAIL_LEN);
        self.tail.drain(..excess);
    }

    /// Returns `true` when nothing has been forwarded yet, or when the forwarded bytes end in
    /// a blank line (LF LF, CR CR, or LF CRLF, which also covers CRLF CRLF).
    ///
    /// Ambiguous endings are treated as mid-event, so a keepalive is skipped rather than risk
    /// corrupting an event.
    fn at_event_boundary(&self) -> bool {
        self.tail.is_empty()
            || self.tail.ends_with(b"\n\n")
            || self.tail.ends_with(b"\r\r")
            || self.tail.ends_with(b"\n\r\n")
    }
}

impl<S> futures_util::Stream for SseKeepaliveStream<S>
where
    S: futures_util::Stream<Item = Result<Bytes, std::io::Error>> + Unpin,
{
    type Item = Result<Bytes, std::io::Error>;

    /// Forwards items from the inner stream, resetting the idle timer on each one. When the
    /// timer fires, a keepalive is emitted if the stream is between events; otherwise that
    /// keepalive is skipped and the timer is re-armed.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let interval = self.keepalive_interval;
        match Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some(item)) => {
                if let Ok(chunk) = &item {
                    self.remember_forwarded(chunk);
                }
                // Real traffic counts as activity, so push the keepalive deadline back.
                self.deadline.as_mut().reset(Instant::now() + interval);
                Poll::Ready(Some(item))
            },
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => {
                if self.deadline.as_mut().poll(cx).is_pending() {
                    return Poll::Pending;
                }
                self.deadline.as_mut().reset(Instant::now() + interval);
                if self.at_event_boundary() {
                    return Poll::Ready(Some(Ok(Bytes::from_static(SSE_KEEPALIVE_PAYLOAD))));
                }
                // Mid-event: skip this keepalive. Poll the re-armed timer so this task is
                // woken when it next fires; if it is already due, yield and retry.
                if self.deadline.as_mut().poll(cx).is_ready() {
                    cx.waker().wake_by_ref();
                }
                Poll::Pending
            },
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub struct ResponseHandler;

impl ResponseHandler {
    pub fn build_response(response: reqwest::Response) -> Result<Response<Body>, String> {
        let status_code = response.status().as_u16();
        let axum_status = StatusCode::from_u16(status_code)
            .map_err(|e| format!("Invalid status code {}: {}", status_code, e))?;

        let response_headers = response.headers().clone();
        let is_sse = response_headers
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .is_some_and(|ct| ct.contains("text/event-stream"));

        let stream = response.bytes_stream().map_err(std::io::Error::other);
        let body = if is_sse {
            let keepalive_stream = SseKeepaliveStream::new(stream, SSE_KEEPALIVE_INTERVAL);
            Body::from_stream(keepalive_stream)
        } else {
            Body::from_stream(stream)
        };

        let mut axum_response = Response::builder().status(axum_status);

        for (key, value) in &response_headers {
            let key_str = key.as_str();
            if let Ok(value_str) = value.to_str() {
                if Self::should_preserve_header(key_str) {
                    axum_response = axum_response.header(key_str, value_str);
                }
            }
        }

        axum_response = axum_response
            .header("connection", "keep-alive")
            .header("cache-control", "no-cache");

        if is_sse {
            axum_response = axum_response
                .header("x-accel-buffering", "no")
                .header("cache-control", "no-cache, no-transform");
        }

        axum_response
            .body(body)
            .map_err(|e| format!("Failed to build response body: {}", e))
    }

    fn should_preserve_header(key: &str) -> bool {
        match key.to_lowercase().as_str() {
            "host" | "authorization" | "proxy-authorization" | "upgrade" | "te" => false,
            header if header.starts_with("x-mcp-") => true,
            _ => true,
        }
    }
}