systemprompt-api 0.3.0

Axum-based HTTP server and API gateway for systemprompt.io AI governance infrastructure. Exposes governed agents, MCP, A2A, and admin endpoints with rate limiting and RBAC.
Documentation
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use axum::body::Body;
use bytes::{Bytes, BytesMut};
use futures_util::stream::Stream;
use serde_json::Value;

use super::audit::{CapturedToolUse, CapturedUsage, GatewayAudit};

/// Mutable bookkeeping accumulated while a response stream is being tapped.
#[derive(Default)]
struct TapState {
    /// Bytes received from upstream that have not yet been split off as a
    /// complete, blank-line-terminated SSE frame.
    sse_buffer: Vec<u8>,
    /// Copy of every byte received from upstream; handed to the audit sink as
    /// the response body once the stream ends.
    response_buffer: BytesMut,
    /// Most recent `input_tokens` value seen in a usage object.
    input_tokens: u32,
    /// Most recent `output_tokens` value seen in a usage object.
    output_tokens: u32,
    /// `tool_use` content blocks that have started but not yet received a
    /// `content_block_stop` event.
    tool_uses_in_progress: Vec<PartialToolUse>,
    /// Tool uses ready to be reported to the audit sink.
    tool_uses_done: Vec<CapturedToolUse>,
    /// Display text of an upstream stream error, if one was observed.
    error: Option<String>,
}

/// A `tool_use` content block whose input JSON is still being streamed.
#[derive(Default, Debug)]
struct PartialToolUse {
    /// Value of the event's `index` field, or `-1` when the event carried none.
    index: i64,
    /// `id` of the content block (empty when the field is missing).
    id: String,
    /// `name` of the content block (empty when the field is missing).
    name: String,
    /// Concatenation of every `partial_json` fragment received so far.
    input_json: String,
}

/// Wraps `upstream` in an auditing pass-through and returns it as an axum
/// [`Body`].
///
/// Every chunk produced by `upstream` is forwarded unchanged. On the way
/// through, the wrapper keeps a copy of the bytes and decodes SSE events from
/// them to collect token usage and tool calls; the outcome of the stream is
/// reported to `audit` from a spawned Tokio task.
///
/// * `upstream` - byte stream of the upstream response.
/// * `audit` - audit sink that receives the completion or failure record.
pub fn tap<S>(upstream: S, audit: Arc<GatewayAudit>) -> Body
where
    S: Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
{
    Body::from_stream(TappedStream {
        inner: Box::pin(upstream),
        // The stream is the only holder of the tap state, so the new `Arc` is
        // moved straight in rather than cloned and dropped.
        state: Arc::new(Mutex::new(TapState::default())),
        audit,
    })
}

/// Stream adapter built by [`tap`]: yields the upstream items unchanged while
/// recording what passes through for auditing.
struct TappedStream<S> {
    /// Upstream byte stream, boxed and pinned so it can be polled in place.
    inner: Pin<Box<S>>,
    /// Observations accumulated from the chunks seen so far.
    state: Arc<Mutex<TapState>>,
    /// Audit sink that receives the outcome of the stream.
    audit: Arc<GatewayAudit>,
}

impl<S> Stream for TappedStream<S>
where
    S: Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
{
    type Item = Result<Bytes, std::io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.inner.as_mut().poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Ok(bytes))) => {
                if let Ok(mut s) = self.state.lock() {
                    s.response_buffer.extend_from_slice(&bytes);
                    s.sse_buffer.extend_from_slice(&bytes);
                    drain_sse(&mut s);
                }
                Poll::Ready(Some(Ok(bytes)))
            },
            Poll::Ready(Some(Err(e))) => {
                if let Ok(mut s) = self.state.lock() {
                    s.error = Some(e.to_string());
                }
                Poll::Ready(Some(Err(e)))
            },
            Poll::Ready(None) => {
                let (usage, tools, body, error) = {
                    let mut s = self
                        .state
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner);
                    finalize_partials(&mut s);
                    (
                        CapturedUsage {
                            input_tokens: s.input_tokens,
                            output_tokens: s.output_tokens,
                        },
                        std::mem::take(&mut s.tool_uses_done),
                        std::mem::take(&mut s.response_buffer).freeze(),
                        s.error.take(),
                    )
                };

                let audit = Arc::clone(&self.audit);
                tokio::spawn(async move {
                    if let Some(err) = error {
                        let _ = audit.fail(&err).await;
                    } else if let Err(e) = audit.complete(usage, tools, &body).await {
                        tracing::warn!(error = %e, "stream audit complete failed");
                    }
                });
                Poll::Ready(None)
            },
        }
    }
}

/// Removes every complete SSE frame from `state.sse_buffer` and applies the
/// JSON payload of each `data` line to `state`.
///
/// A frame ends at a blank line, written either as `\n\n` or as `\r\n\r\n`.
/// Bytes after the last complete frame stay in the buffer until more input
/// arrives. `data` lines holding `[DONE]` or text that is not valid JSON are
/// skipped.
fn drain_sse(state: &mut TapState) {
    while let Some(end) = sse_frame_end(&state.sse_buffer) {
        // Copy the frame out so `state` can be mutated while it is parsed.
        let frame_bytes: Vec<u8> = state.sse_buffer.drain(..end).collect();
        let frame = String::from_utf8_lossy(&frame_bytes);
        for line in frame.lines() {
            let Some(rest) = line.strip_prefix("data:") else {
                continue;
            };
            // The SSE format makes the single space after the colon optional.
            let data = rest.strip_prefix(' ').unwrap_or(rest);
            if data.trim() == "[DONE]" {
                continue;
            }
            if let Ok(json) = serde_json::from_str::<Value>(data) {
                handle_sse_event(state, &json);
            }
        }
    }
}

/// Returns the offset just past the first frame terminator in `buf`, for
/// either LF (`\n\n`) or CRLF (`\r\n\r\n`) line endings.
fn sse_frame_end(buf: &[u8]) -> Option<usize> {
    let lf_end = find_double_newline(buf).map(|pos| pos + 2);
    let crlf_end = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|pos| pos + 4);
    match (lf_end, crlf_end) {
        // The terminator that ends first is the one that starts first.
        (Some(lf), Some(crlf)) => Some(lf.min(crlf)),
        (lf, crlf) => lf.or(crlf),
    }
}

/// Returns the index of the first `\n` that is immediately followed by
/// another `\n`, or `None` when `buf` contains no such pair.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    buf.iter()
        .zip(buf.iter().skip(1))
        .position(|(&first, &second)| first == b'\n' && second == b'\n')
}

/// Folds one decoded SSE JSON event into `state`.
///
/// Dispatches on the event's `type` field:
/// * `message_start` / `message_delta` - update the token counters from the
///   event's usage object (`message.usage` and `usage` respectively).
/// * `content_block_start` - begin tracking a block whose type is `tool_use`.
/// * `content_block_delta` - append an `input_json_delta` fragment to the
///   tracked block with the same index.
/// * `content_block_stop` - move the tracked block with the same index to the
///   finished list.
///
/// Events without a string `type`, and every other type, are ignored.
fn handle_sse_event(state: &mut TapState, event: &Value) {
    let Some(kind) = event.get("type").and_then(Value::as_str) else {
        return;
    };
    // Events that carry no `index` all share the placeholder -1.
    let block_index = || event.get("index").and_then(Value::as_i64).unwrap_or(-1);

    match kind {
        "message_start" => {
            if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
                record_usage(state, usage);
            }
        },
        "message_delta" => {
            if let Some(usage) = event.get("usage") {
                record_usage(state, usage);
            }
        },
        "content_block_start" => {
            let index = block_index();
            let Some(block) = event.get("content_block") else {
                return;
            };
            if block.get("type").and_then(Value::as_str) != Some("tool_use") {
                return;
            }
            let text_field = |key: &str| {
                block
                    .get(key)
                    .and_then(Value::as_str)
                    .unwrap_or("")
                    .to_string()
            };
            state.tool_uses_in_progress.push(PartialToolUse {
                index,
                id: text_field("id"),
                name: text_field("name"),
                input_json: String::new(),
            });
        },
        "content_block_delta" => {
            let index = block_index();
            let Some(delta) = event.get("delta") else {
                return;
            };
            if delta.get("type").and_then(Value::as_str) != Some("input_json_delta") {
                return;
            }
            let Some(fragment) = delta.get("partial_json").and_then(Value::as_str) else {
                return;
            };
            if let Some(pending) = state
                .tool_uses_in_progress
                .iter_mut()
                .find(|p| p.index == index)
            {
                pending.input_json.push_str(fragment);
            }
        },
        "content_block_stop" => {
            let index = block_index();
            if let Some(pos) = state
                .tool_uses_in_progress
                .iter()
                .position(|p| p.index == index)
            {
                // `remove` (not `swap_remove`) keeps the remaining blocks in
                // arrival order.
                let done = state.tool_uses_in_progress.remove(pos);
                state.tool_uses_done.push(CapturedToolUse {
                    ai_tool_call_id: done.id,
                    tool_name: done.name,
                    // A tool call streamed without arguments is recorded as
                    // an empty JSON object.
                    tool_input: if done.input_json.is_empty() {
                        "{}".to_string()
                    } else {
                        done.input_json
                    },
                });
            }
        },
        _ => {},
    }
}

/// Copies whichever of `input_tokens` / `output_tokens` are present in
/// `usage` into `state`; counters whose key is absent keep their value.
fn record_usage(state: &mut TapState, usage: &Value) {
    if let Some(count) = usage_tokens(usage, "input_tokens") {
        state.input_tokens = count;
    }
    if let Some(count) = usage_tokens(usage, "output_tokens") {
        state.output_tokens = count;
    }
}

/// Reads `usage[key]` as an unsigned count, saturating at `u32::MAX` rather
/// than silently wrapping when the value does not fit in 32 bits.
fn usage_tokens(usage: &Value, key: &str) -> Option<u32> {
    let raw = usage.get(key).and_then(Value::as_u64)?;
    Some(<u32 as std::convert::TryFrom<u64>>::try_from(raw).unwrap_or(u32::MAX))
}

/// Moves every tool use still in progress onto the finished list, in order,
/// leaving `tool_uses_in_progress` empty.
///
/// An entry that never received any input JSON is recorded with `{}` as its
/// input.
fn finalize_partials(state: &mut TapState) {
    let unfinished = state.tool_uses_in_progress.drain(..);
    state.tool_uses_done.extend(unfinished.map(|partial| {
        let tool_input = match partial.input_json.is_empty() {
            true => "{}".to_string(),
            false => partial.input_json,
        };
        CapturedToolUse {
            ai_tool_call_id: partial.id,
            tool_name: partial.name,
            tool_input,
        }
    }));
}