systemprompt-api 0.9.0

Axum-based HTTP server and API gateway for systemprompt.io AI governance infrastructure. Exposes governed agents, MCP, A2A, and admin endpoints with rate limiting and RBAC.
Documentation
use axum::body::{Body, to_bytes};
use axum::http::{HeaderMap, StatusCode};
use axum::response::Response;
use bytes::Bytes;
use futures_util::TryStreamExt;
use reqwest::Method;
use std::pin::Pin;
use std::str::FromStr;
use std::task::{Context, Poll};
use std::time::Duration;
use systemprompt_models::RequestContext;
use systemprompt_traits::InjectContextHeaders;
use tokio::time::{Instant, Sleep};

pub use super::errors::ProxyError;

#[derive(Debug, Clone, Copy)]
pub struct HeaderInjector;

impl HeaderInjector {
    pub fn inject_context(headers: &mut HeaderMap, req_ctx: &RequestContext) {
        req_ctx.inject_headers(headers);
    }
}

/// Builds URLs that point at backend services behind the proxy.
#[derive(Debug, Clone, Copy)]
pub struct UrlResolver;

impl UrlResolver {
    /// Builds the `http://` URL for `path` on the backend at `host:port`.
    ///
    /// All leading `/` characters are stripped from `path` before it is appended
    /// under the root.
    ///
    /// When `protocol` is `"mcp"`, an empty path (or a path of exactly `mcp`)
    /// resolves to the `/mcp` endpoint. For any other protocol, an empty path
    /// resolves to `/`. The scheme is `http` for every protocol.
    ///
    /// IPv6 literals passed without brackets (e.g. `::1`) are wrapped in `[...]`,
    /// so that the port separator stays unambiguous (RFC 3986 §3.2.2).
    pub fn build_backend_url(protocol: &str, host: &str, port: i32, path: &str) -> String {
        let clean_path = path.trim_start_matches('/');

        // For MCP, an empty or bare `mcp` path maps onto the `/mcp` endpoint.
        let is_mcp_root = clean_path.is_empty() || clean_path == "mcp";
        let resolved_path = if protocol == "mcp" && is_mcp_root {
            "mcp"
        } else {
            clean_path
        };

        if host.contains(':') && !host.starts_with('[') {
            format!("http://[{host}]:{port}/{resolved_path}")
        } else {
            format!("http://{host}:{port}/{resolved_path}")
        }
    }

    /// Appends a raw (already encoded) query string to `url`.
    ///
    /// `None` or an empty query returns `url` unchanged.
    ///
    /// How the query is joined depends on `url`:
    /// - no `?` yet: the query is introduced with `?`;
    /// - already has a query: the new parameters are joined with `&`;
    /// - already ends in `?` or `&`: the query is appended directly.
    ///
    /// The result therefore never contains a second `?`.
    pub fn append_query_params(url: String, query: Option<&str>) -> String {
        let Some(query) = query.filter(|q| !q.is_empty()) else {
            return url;
        };

        let separator = if !url.contains('?') {
            "?"
        } else if url.ends_with('?') || url.ends_with('&') {
            ""
        } else {
            "&"
        };

        // Reuse the owned buffer instead of formatting a fresh String.
        let mut url = url;
        url.push_str(separator);
        url.push_str(query);
        url
    }
}

/// Helpers for turning an incoming axum request into an outgoing `reqwest` request.
#[derive(Debug, Clone, Copy)]
pub struct RequestBuilder;

impl RequestBuilder {
    /// Reads the whole request body into memory.
    ///
    /// # Errors
    ///
    /// Returns the `axum::Error` produced by [`to_bytes`], including the case where
    /// the body is larger than `MAX_BODY_SIZE` (100 MiB).
    pub async fn extract_body(body: Body) -> Result<Vec<u8>, axum::Error> {
        // Caps how much of a single request body is buffered in memory.
        const MAX_BODY_SIZE: usize = 100 * 1024 * 1024;

        let bytes = to_bytes(body, MAX_BODY_SIZE).await?;
        Ok(bytes.to_vec())
    }

    /// Parses an HTTP method name such as `"GET"` or `"POST"`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `method_str` is not a valid method token.
    pub fn parse_method(method_str: &str) -> Result<Method, String> {
        Method::from_str(method_str)
            .map_err(|e| format!("Invalid HTTP method '{}': {}", method_str, e))
    }

    /// Builds a `reqwest` request for `url`, copying the forwardable `headers`.
    ///
    /// The body is attached only when it is non-empty.
    pub fn build_request(
        client: &reqwest::Client,
        method: Method,
        url: &str,
        headers: &HeaderMap,
        body: Vec<u8>,
    ) -> reqwest::RequestBuilder {
        let req_builder = Self::add_headers(client.request(method, url), headers);

        if body.is_empty() {
            req_builder
        } else {
            req_builder.body(body)
        }
    }

    /// Copies every forwardable header from `headers` onto `req_builder`.
    ///
    /// Headers are dropped in three cases:
    /// - the name is listed in `should_skip_header`;
    /// - the value is not representable as `&str` (`HeaderValue::to_str` fails);
    /// - it is an `Authorization` header that carries no credential.
    fn add_headers(
        mut req_builder: reqwest::RequestBuilder,
        headers: &HeaderMap,
    ) -> reqwest::RequestBuilder {
        for (key, value) in headers {
            let key_str = key.as_str();
            if Self::should_skip_header(key_str) {
                continue;
            }

            let Ok(value_str) = value.to_str() else {
                continue;
            };

            let is_authorization = key_str.eq_ignore_ascii_case("authorization");
            if is_authorization && !Self::is_valid_auth_header(value_str) {
                continue;
            }

            req_builder = req_builder.header(key_str, value_str);
        }
        req_builder
    }

    /// Returns `true` for header names that must never be forwarded to the backend.
    ///
    /// `host` is left for the HTTP client to derive from the backend URL.
    /// `x-mcp-proxy-auth` is presumably a proxy-only credential (confirm with its
    /// producer).
    fn should_skip_header(header_name: &str) -> bool {
        // Case-insensitive comparison without allocating a lowercase copy.
        const SKIPPED: [&str; 2] = ["host", "x-mcp-proxy-auth"];
        SKIPPED
            .iter()
            .any(|skipped| header_name.eq_ignore_ascii_case(skipped))
    }

    /// Returns `false` for `Authorization` values that carry no credential.
    ///
    /// That means a bare `Bearer` scheme (any case, surrounding whitespace ignored),
    /// or an empty/whitespace-only value.
    fn is_valid_auth_header(value: &str) -> bool {
        let trimmed = value.trim();
        !(trimmed.is_empty() || trimmed.eq_ignore_ascii_case("bearer"))
    }
}

/// Idle time after which a keepalive is injected into a proxied SSE stream.
const SSE_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(15);
/// An SSE comment line (leading `:`) followed by the blank line ending the event.
const SSE_KEEPALIVE_PAYLOAD: &[u8] = b": keepalive\n\n";

/// Stream adapter that forwards every item of `inner` unchanged.
///
/// Whenever `inner` yields nothing for `keepalive_interval`, it emits
/// `SSE_KEEPALIVE_PAYLOAD` instead.
struct SseKeepaliveStream<S> {
    /// Upstream body stream being forwarded.
    inner: S,
    /// Idle period after which a keepalive chunk is produced.
    keepalive_interval: Duration,
    /// Timer for the current idle period; re-armed each time an item is yielded.
    deadline: Pin<Box<Sleep>>,
}

impl<S> SseKeepaliveStream<S> {
    /// Wraps `inner`, arming the first keepalive `keepalive_interval` from now.
    fn new(inner: S, keepalive_interval: Duration) -> Self {
        let deadline = Box::pin(tokio::time::sleep(keepalive_interval));
        Self {
            inner,
            keepalive_interval,
            deadline,
        }
    }
}

impl<S> futures_util::Stream for SseKeepaliveStream<S>
where
    S: futures_util::Stream<Item = Result<Bytes, std::io::Error>> + Unpin,
{
    type Item = Result<Bytes, std::io::Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Every field is `Unpin` (`S: Unpin`, `Pin<Box<_>>`), so unpinning is sound.
        let this = self.get_mut();

        // Upstream data (or end-of-stream) always wins over a keepalive.
        if let Poll::Ready(next) = Pin::new(&mut this.inner).poll_next(cx) {
            if next.is_some() {
                this.deadline
                    .as_mut()
                    .reset(Instant::now() + this.keepalive_interval);
            }
            return Poll::Ready(next);
        }

        // Upstream is idle. Polling the timer also registers `cx`'s waker with it.
        if this.deadline.as_mut().poll(cx).is_pending() {
            return Poll::Pending;
        }

        this.deadline
            .as_mut()
            .reset(Instant::now() + this.keepalive_interval);
        Poll::Ready(Some(Ok(Bytes::from_static(SSE_KEEPALIVE_PAYLOAD))))
    }
}

/// Converts backend `reqwest` responses into axum responses streamed to the client.
#[derive(Debug, Clone, Copy)]
pub struct ResponseHandler;

impl ResponseHandler {
    /// Relays a backend `reqwest::Response` to the client as an axum response.
    ///
    /// - Status: copied from the backend.
    /// - Upstream headers: copied when `should_preserve_header` allows them and their
    ///   value is representable as `&str`. For SSE responses, `content-length` is also
    ///   dropped, because injected keepalive bytes would make it wrong.
    /// - Proxy headers: `connection: keep-alive` and exactly one `cache-control` value
    ///   are added. That value is `no-cache`, or `no-cache, no-transform` for SSE.
    ///   SSE responses also get `x-accel-buffering: no`.
    /// - Body: streamed. `text/event-stream` bodies are wrapped in
    ///   `SseKeepaliveStream`, so idle streams receive periodic keepalive comments.
    ///
    /// # Errors
    ///
    /// Returns a message if the status code cannot be converted, or if the response
    /// cannot be assembled.
    pub fn build_response(response: reqwest::Response) -> Result<Response<Body>, String> {
        let status_code = response.status().as_u16();
        let axum_status = StatusCode::from_u16(status_code)
            .map_err(|e| format!("Invalid status code {}: {}", status_code, e))?;

        // Media types are case-insensitive, so normalise before matching.
        let is_sse = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .is_some_and(|ct| ct.to_ascii_lowercase().contains("text/event-stream"));

        let mut axum_response = Response::builder().status(axum_status);

        // Headers are copied while `response` is still borrowed. The builder converts
        // each value immediately, so the header map does not need to be cloned before
        // `bytes_stream` consumes the response below.
        for (key, value) in response.headers() {
            let key_str = key.as_str();
            if !Self::should_preserve_header(key_str) {
                continue;
            }
            // Keepalive injection changes the body length, so drop content-length.
            if is_sse && key_str.eq_ignore_ascii_case("content-length") {
                continue;
            }
            if let Ok(value_str) = value.to_str() {
                axum_response = axum_response.header(key_str, value_str);
            }
        }

        // `Builder::header` appends, so each proxy header is emitted exactly once.
        let cache_control = if is_sse {
            "no-cache, no-transform"
        } else {
            "no-cache"
        };
        axum_response = axum_response
            .header("connection", "keep-alive")
            .header("cache-control", cache_control);

        if is_sse {
            axum_response = axum_response.header("x-accel-buffering", "no");
        }

        let stream = response.bytes_stream().map_err(std::io::Error::other);
        let body = if is_sse {
            Body::from_stream(SseKeepaliveStream::new(stream, SSE_KEEPALIVE_INTERVAL))
        } else {
            Body::from_stream(stream)
        };

        axum_response
            .body(body)
            .map_err(|e| format!("Failed to build response body: {}", e))
    }

    /// Returns `false` for upstream response headers that must not reach the client.
    ///
    /// The dropped set is:
    /// - request-side credentials and `host`;
    /// - hop-by-hop fields, which a proxy must not forward (RFC 9110 §7.6.1). Dropping
    ///   `connection` also avoids conflicting with the proxy's own
    ///   `connection: keep-alive`.
    ///
    /// Comparison is ASCII case-insensitive.
    fn should_preserve_header(key: &str) -> bool {
        const NEVER_FORWARDED: [&str; 10] = [
            "host",
            "authorization",
            "proxy-authorization",
            "upgrade",
            "te",
            "connection",
            "keep-alive",
            "transfer-encoding",
            "proxy-authenticate",
            "proxy-connection",
        ];
        !NEVER_FORWARDED
            .iter()
            .any(|name| key.eq_ignore_ascii_case(name))
    }
}