systemprompt-api 0.8.0

Axum-based HTTP server and API gateway for systemprompt.io AI governance infrastructure. Exposes governed agents, MCP, A2A, and admin endpoints with rate limiting and RBAC.
Documentation
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use axum::body::Body;
use bytes::{Bytes, BytesMut};
use futures_util::stream::{BoxStream, Stream};

use super::audit::GatewayAudit;
use super::captures::{CapturedToolUse, CapturedUsage};
use super::protocol::canonical::CanonicalContent;
use super::protocol::canonical_response::{
    CanonicalEvent, CanonicalResponse, CanonicalUsage, ContentBlockKind,
};
use super::protocol::inbound::InboundAdapter;

#[derive(Default)]
struct TapState {
    response_id: String,
    served_model: String,
    usage: CanonicalUsage,
    blocks: Vec<BlockAccumulator>,
    final_stop_reason: Option<super::protocol::canonical_response::CanonicalStopReason>,
    final_bytes: BytesMut,
    error: Option<String>,
    finalized: bool,
}

#[derive(Clone)]
enum BlockAccumulator {
    Text(String),
    Thinking {
        text: String,
        signature: Option<String>,
    },
    ToolUse {
        id: String,
        name: String,
        partial: String,
    },
}

pub fn tap(
    upstream: BoxStream<'static, Result<CanonicalEvent, String>>,
    inbound: Arc<dyn InboundAdapter>,
    request_model: String,
    audit: Arc<GatewayAudit>,
) -> Body {
    let state = Arc::new(Mutex::new(TapState::default()));
    let tapped = TappedStream {
        inner: upstream,
        state: Arc::clone(&state),
        inbound,
        request_model,
        audit,
    };
    Body::from_stream(tapped)
}

struct TappedStream {
    inner: BoxStream<'static, Result<CanonicalEvent, String>>,
    state: Arc<Mutex<TapState>>,
    inbound: Arc<dyn InboundAdapter>,
    request_model: String,
    audit: Arc<GatewayAudit>,
}

impl Stream for TappedStream {
    type Item = Result<Bytes, std::io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            match self.inner.as_mut().poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    return self.finalize_on_eof();
                },
                Poll::Ready(Some(Err(e))) => {
                    if let Ok(mut s) = self.state.lock() {
                        s.error = Some(e.clone());
                    }
                    let err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, e);
                    return Poll::Ready(Some(Err(err)));
                },
                Poll::Ready(Some(Ok(event))) => {
                    if let Ok(mut s) = self.state.lock() {
                        accumulate_event(&mut s, &event);
                    }
                    let rendered = self.inbound.render_event(&event, &self.request_model);
                    if let Some(bytes) = rendered {
                        if let Ok(mut s) = self.state.lock() {
                            s.final_bytes.extend_from_slice(&bytes);
                        }
                        return Poll::Ready(Some(Ok(bytes)));
                    }
                },
            }
        }
    }
}

impl TappedStream {
    fn finalize_on_eof(&self) -> Poll<Option<Result<Bytes, std::io::Error>>> {
        let snapshot = self.state.lock().ok().and_then(|mut s| {
            if s.finalized {
                return None;
            }
            s.finalized = true;
            Some(extract_summary(&mut s))
        });

        let Some(summary) = snapshot else {
            return Poll::Ready(None);
        };

        spawn_audit_complete(Arc::clone(&self.audit), self.request_model.clone(), summary);
        Poll::Ready(None)
    }
}

impl Drop for TappedStream {
    fn drop(&mut self) {
        let snapshot = self.state.lock().ok().and_then(|mut s| {
            if s.finalized {
                return None;
            }
            s.finalized = true;
            Some(extract_summary(&mut s))
        });
        let Some(summary) = snapshot else {
            return;
        };

        let audit = Arc::clone(&self.audit);
        tokio::spawn(async move {
            if let Some(model) = summary.served_model.as_deref() {
                audit.set_served_model(model).await;
            }
            if summary.saw_stop && summary.error.is_none() {
                if let Err(e) = audit
                    .complete(
                        summary.usage,
                        summary.tool_calls,
                        &summary.response,
                        &summary.final_bytes,
                    )
                    .await
                {
                    tracing::warn!(error = %e, "drop-path audit complete failed");
                }
            } else {
                let msg = summary
                    .error
                    .unwrap_or_else(|| "stream abandoned by downstream".into());
                if let Err(e) = audit.fail(&msg).await {
                    tracing::warn!(error = %e, "drop-path audit fail failed");
                }
            }
        });
    }
}

struct Summary {
    usage: CapturedUsage,
    tool_calls: Vec<CapturedToolUse>,
    response: CanonicalResponse,
    final_bytes: Bytes,
    served_model: Option<String>,
    error: Option<String>,
    saw_stop: bool,
}

fn extract_summary(state: &mut TapState) -> Summary {
    let response = build_response(state);
    let usage = CapturedUsage {
        input_tokens: state.usage.input_tokens,
        output_tokens: state.usage.output_tokens,
    };
    let tool_calls = response
        .content
        .iter()
        .filter_map(|c| {
            if let CanonicalContent::ToolUse { id, name, input } = c {
                Some(CapturedToolUse {
                    ai_tool_call_id: id.clone(),
                    tool_name: name.clone(),
                    tool_input: serde_json::to_string(input).unwrap_or_default(),
                })
            } else {
                None
            }
        })
        .collect();
    let final_bytes = std::mem::take(&mut state.final_bytes).freeze();
    let served_model = if state.served_model.is_empty() {
        None
    } else {
        Some(state.served_model.clone())
    };
    Summary {
        usage,
        tool_calls,
        response,
        final_bytes,
        served_model,
        error: state.error.clone(),
        saw_stop: state.final_stop_reason.is_some(),
    }
}

fn build_response(state: &TapState) -> CanonicalResponse {
    let content = state
        .blocks
        .iter()
        .map(|b| match b {
            BlockAccumulator::Text(t) => CanonicalContent::Text(t.clone()),
            BlockAccumulator::Thinking { text, signature } => CanonicalContent::Thinking {
                text: text.clone(),
                signature: signature.clone(),
            },
            BlockAccumulator::ToolUse { id, name, partial } => CanonicalContent::ToolUse {
                id: id.clone(),
                name: name.clone(),
                input: serde_json::from_str(partial)
                    .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
            },
        })
        .collect();
    CanonicalResponse {
        id: state.response_id.clone(),
        model: state.served_model.clone(),
        content,
        stop_reason: state.final_stop_reason,
        usage: state.usage,
    }
}

fn accumulate_event(state: &mut TapState, event: &CanonicalEvent) {
    match event {
        CanonicalEvent::MessageStart { id, model, usage } => {
            state.response_id.clone_from(id);
            if !model.is_empty() {
                state.served_model.clone_from(model);
            }
            state.usage = *usage;
        },
        CanonicalEvent::ContentBlockStart { index, block } => {
            let slot = match block {
                ContentBlockKind::Text => BlockAccumulator::Text(String::new()),
                ContentBlockKind::Thinking { signature } => BlockAccumulator::Thinking {
                    text: String::new(),
                    signature: signature.clone(),
                },
                ContentBlockKind::ToolUse { id, name } => BlockAccumulator::ToolUse {
                    id: id.clone(),
                    name: name.clone(),
                    partial: String::new(),
                },
            };
            let idx = *index as usize;
            while state.blocks.len() <= idx {
                state.blocks.push(BlockAccumulator::Text(String::new()));
            }
            state.blocks[idx] = slot;
        },
        CanonicalEvent::TextDelta { index, text } => {
            if let Some(BlockAccumulator::Text(buf)) = state.blocks.get_mut(*index as usize) {
                buf.push_str(text);
            }
        },
        CanonicalEvent::ThinkingDelta { index, text } => {
            if let Some(BlockAccumulator::Thinking { text: buf, .. }) =
                state.blocks.get_mut(*index as usize)
            {
                buf.push_str(text);
            }
        },
        CanonicalEvent::ToolUseDelta {
            index,
            partial_json,
        } => {
            if let Some(BlockAccumulator::ToolUse { partial, .. }) =
                state.blocks.get_mut(*index as usize)
            {
                partial.push_str(partial_json);
            }
        },
        CanonicalEvent::ContentBlockStop { .. } => {},
        CanonicalEvent::UsageDelta(u) => {
            if u.input_tokens > 0 {
                state.usage.input_tokens = u.input_tokens;
            }
            if u.output_tokens > 0 {
                state.usage.output_tokens = u.output_tokens;
            }
        },
        CanonicalEvent::MessageStop { stop_reason, .. } => {
            state.final_stop_reason = stop_reason.or(Some(
                super::protocol::canonical_response::CanonicalStopReason::EndTurn,
            ));
        },
        CanonicalEvent::Error(msg) => {
            state.error = Some(msg.clone());
        },
    }
}

fn spawn_audit_complete(audit: Arc<GatewayAudit>, _request_model: String, summary: Summary) {
    tokio::spawn(async move {
        if let Some(model) = summary.served_model.as_deref() {
            audit.set_served_model(model).await;
        }
        if let Some(err) = summary.error {
            if let Err(e) = audit.fail(&err).await {
                tracing::warn!(error = %e, "stream audit fail failed");
            }
        } else if let Err(e) = audit
            .complete(
                summary.usage,
                summary.tool_calls,
                &summary.response,
                &summary.final_bytes,
            )
            .await
        {
            tracing::warn!(error = %e, "stream audit complete failed");
        }
    });
}