//! HTTP response builders (buffered and Server-Sent Events) built on forge-guardrails 0.1.2,
//! "Foundation types for an LLM-agent workflow framework".

use anyllm_translate::anthropic::streaming::StreamEvent;
use axum::body::{Body, Bytes};
use axum::http::{header, HeaderMap, HeaderName, HeaderValue, StatusCode};
use axum::response::{IntoResponse, Response};
use forge_guardrails::{AnthropicEventStream, HTTPServer, OpenAiEventStream};
use futures_util::StreamExt;
use std::io;
use tokio::sync::OwnedMutexGuard;

pub(crate) fn build_response(status: u16, content_type: &str, body: String) -> Response {
    let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let mut response = (status_code, body).into_response();
    if !content_type.is_empty() {
        if let Ok(value) = HeaderValue::from_str(content_type) {
            response.headers_mut().insert(header::CONTENT_TYPE, value);
        }
    }
    for (name, value) in HTTPServer::cors_headers() {
        if let Some(header_name) = cors_header_name(name) {
            response
                .headers_mut()
                .insert(header_name, HeaderValue::from_static(value));
        }
    }
    response
}

pub(crate) fn build_openai_sse_response(
    events: OpenAiEventStream,
    guard: Option<OwnedMutexGuard<()>>,
) -> Response {
    let mut response = (
        StatusCode::OK,
        Body::from_stream(openai_sse_bytes_stream(events, guard)),
    )
        .into_response();
    insert_sse_headers(response.headers_mut());
    insert_cors_headers(response.headers_mut());
    response
}

pub(crate) fn build_anthropic_sse_response(
    events: AnthropicEventStream,
    guard: Option<OwnedMutexGuard<()>>,
) -> Response {
    let mut response = (
        StatusCode::OK,
        Body::from_stream(anthropic_sse_bytes_stream(events, guard)),
    )
        .into_response();
    insert_sse_headers(response.headers_mut());
    insert_cors_headers(response.headers_mut());
    response
}

/// Copies each header from [`HTTPServer::cors_headers`] into `headers`.
///
/// Entries whose name `cors_header_name` cannot map are skipped. Existing values
/// for the same header are replaced.
fn insert_cors_headers(headers: &mut HeaderMap) {
    for (raw_name, raw_value) in HTTPServer::cors_headers() {
        let Some(mapped_name) = cors_header_name(raw_name) else {
            continue;
        };
        headers.insert(mapped_name, HeaderValue::from_static(raw_value));
    }
}

/// Sets the fixed headers every Server-Sent Events response carries.
///
/// These are the `text/event-stream` content type, `Cache-Control: no-cache`, and
/// `X-Accel-Buffering: no`. The last is intended to stop reverse proxies such as
/// nginx from buffering the stream. Existing values for these headers are replaced.
fn insert_sse_headers(headers: &mut HeaderMap) {
    let fixed = [
        (header::CONTENT_TYPE, "text/event-stream"),
        (header::CACHE_CONTROL, "no-cache"),
        (HeaderName::from_static("x-accel-buffering"), "no"),
    ];
    for (name, value) in fixed {
        headers.insert(name, HeaderValue::from_static(value));
    }
}

/// Encodes OpenAI-format events as SSE `data:` frames.
///
/// Each successful item is rendered with `Display` as `data: <value>\n\n`. When the
/// upstream stream finishes cleanly, a final `data: [DONE]\n\n` frame is emitted. The
/// first upstream error is yielded as an `io::Error` and ends the stream without the
/// `[DONE]` frame. `guard` is kept alive until the returned stream is dropped.
fn openai_sse_bytes_stream(
    mut events: OpenAiEventStream,
    guard: Option<OwnedMutexGuard<()>>,
) -> impl futures_core::Stream<Item = Result<Bytes, io::Error>> + Send + 'static {
    async_stream::stream! {
        // Bind the guard inside the generator so it is released only with the stream.
        let _held = guard;
        loop {
            let payload = match events.next().await {
                None => break,
                Some(Err(failure)) => {
                    yield Err(io::Error::other(failure.to_string()));
                    return;
                }
                Some(Ok(payload)) => payload,
            };
            yield Ok(Bytes::from(format!("data: {}\n\n", payload)));
        }
        yield Ok(Bytes::from_static(b"data: [DONE]\n\n"));
    }
}

/// Encodes Anthropic-format stream events as SSE frames (`event:` + `data:` lines).
///
/// Each event becomes one frame built by `push_anthropic_sse_event`. No terminator
/// frame is added after the upstream stream ends. The first upstream error is
/// yielded as an `io::Error` and ends the stream. `guard` is kept alive until the
/// returned stream is dropped.
fn anthropic_sse_bytes_stream(
    mut events: AnthropicEventStream,
    guard: Option<OwnedMutexGuard<()>>,
) -> impl futures_core::Stream<Item = Result<Bytes, io::Error>> + Send + 'static {
    async_stream::stream! {
        // Bind the guard inside the generator so it is released only with the stream.
        let _held = guard;
        loop {
            let next_event = match events.next().await {
                None => break,
                Some(Err(failure)) => {
                    yield Err(io::Error::other(failure.to_string()));
                    return;
                }
                Some(Ok(next_event)) => next_event,
            };
            let mut frame = String::new();
            push_anthropic_sse_event(&mut frame, &next_event);
            yield Ok(Bytes::from(frame));
        }
    }
}

/// Appends one SSE frame for `event` to `body`.
///
/// The frame has the form `event: <name>\ndata: <json>\n\n`. If `event` fails to
/// serialize, `{}` is written as the data payload instead.
fn push_anthropic_sse_event(body: &mut String, event: &StreamEvent) {
    let json = serde_json::to_string(event).unwrap_or_else(|_| String::from("{}"));
    body.extend([
        "event: ",
        anthropic_event_name(event),
        "\n",
        "data: ",
        json.as_str(),
        "\n\n",
    ]);
}

fn anthropic_event_name(event: &StreamEvent) -> &'static str {
    match event {
        StreamEvent::MessageStart { .. } => "message_start",
        StreamEvent::ContentBlockStart { .. } => "content_block_start",
        StreamEvent::ContentBlockDelta { .. } => "content_block_delta",
        StreamEvent::ContentBlockStop { .. } => "content_block_stop",
        StreamEvent::MessageDelta { .. } => "message_delta",
        StreamEvent::MessageStop { .. } => "message_stop",
        StreamEvent::Ping { .. } => "ping",
        StreamEvent::Error { .. } => "error",
        StreamEvent::Unknown => "unknown",
    }
}

/// Converts a header name reported by [`HTTPServer::cors_headers`] into a [`HeaderName`].
///
/// Any syntactically valid HTTP header name is accepted. `HeaderName::from_bytes`
/// normalizes it to lower case, so `"Access-Control-Allow-Origin"` and
/// `"access-control-allow-origin"` yield the same header. Returns `None` only when
/// `name` is not a valid header name.
fn cors_header_name(name: &str) -> Option<HeaderName> {
    // Parse rather than whitelist. A fixed list of exact, title-cased names would
    // silently drop any other header the server advertises, and any differently-cased
    // spelling of a known one, from every response.
    HeaderName::from_bytes(name.as_bytes()).ok()
}

#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::header;

    /// Looks up header `name` on `response`, panicking if it is missing.
    fn header_value<'a>(response: &'a Response, name: &str) -> &'a HeaderValue {
        response.headers().get(name).unwrap()
    }

    #[test]
    fn build_response_sets_content_type() {
        let response = build_response(200, "application/json", "{}".to_string());
        assert_eq!(
            header_value(&response, header::CONTENT_TYPE.as_str()),
            "application/json"
        );
    }

    #[test]
    fn build_response_omits_empty_content_type() {
        let response = build_response(204, "", String::new());
        assert_eq!(response.status(), StatusCode::NO_CONTENT);
    }

    #[test]
    fn build_response_sets_cors_headers() {
        let response = build_response(200, "application/json", "{}".to_string());
        let expected = [
            ("access-control-allow-origin", "*"),
            ("access-control-allow-methods", "GET, POST, OPTIONS"),
            ("access-control-allow-headers", "Content-Type, Authorization"),
        ];
        for (name, value) in expected {
            assert_eq!(header_value(&response, name), value);
        }
    }
}