sampling2api 0.1.1

Expose MCP client sampling as an Anthropic-compatible Messages API
Documentation
use std::convert::Infallible;

use axum::{
    Json,
    response::{
        IntoResponse, Response,
        sse::{Event, Sse},
    },
};
use futures_util::stream;
use serde_json::{Value, json};

use crate::anthropic::{MessagesResponse, OutputContentBlock};

/// A single server-sent event prepared for the wire.
///
/// `event` is written as the SSE `event:` field and `data` is serialized as
/// JSON into the `data:` field by `messages_response_to_sse_response`.
/// Every frame built by `messages_response_to_sse_frames` carries a `"type"`
/// key in `data` equal to `event`.
#[derive(Clone, Debug, PartialEq)]
pub struct SseFrame {
    /// SSE event name, such as `"message_start"` or `"content_block_delta"`.
    pub event: &'static str,
    /// JSON payload of the event.
    pub data: Value,
}

pub fn messages_response_to_sse_frames(response: &MessagesResponse) -> Vec<SseFrame> {
    let mut frames = Vec::new();

    frames.push(SseFrame {
        event: "message_start",
        data: json!({
            "type": "message_start",
            "message": {
                "id": response.id,
                "type": response.object_type,
                "role": response.role,
                "content": [],
                "model": response.model,
                "stop_reason": Value::Null,
                "stop_sequence": Value::Null,
                "usage": response.usage,
            }
        }),
    });

    for (index, block) in response.content.iter().enumerate() {
        match block {
            OutputContentBlock::Text { text } => {
                frames.push(SseFrame {
                    event: "content_block_start",
                    data: json!({
                        "type": "content_block_start",
                        "index": index,
                        "content_block": {
                            "type": "text",
                            "text": "",
                        }
                    }),
                });
                frames.push(SseFrame {
                    event: "content_block_delta",
                    data: json!({
                        "type": "content_block_delta",
                        "index": index,
                        "delta": {
                            "type": "text_delta",
                            "text": text,
                        }
                    }),
                });
            }
            OutputContentBlock::ToolUse { id, name, input } => {
                frames.push(SseFrame {
                    event: "content_block_start",
                    data: json!({
                        "type": "content_block_start",
                        "index": index,
                        "content_block": {
                            "type": "tool_use",
                            "id": id,
                            "name": name,
                            "input": {},
                        }
                    }),
                });
                frames.push(SseFrame {
                    event: "content_block_delta",
                    data: json!({
                        "type": "content_block_delta",
                        "index": index,
                        "delta": {
                            "type": "input_json_delta",
                            "partial_json": serde_json::to_string(input).expect("tool input should serialize"),
                        }
                    }),
                });
            }
        }

        frames.push(SseFrame {
            event: "content_block_stop",
            data: json!({
                "type": "content_block_stop",
                "index": index,
            }),
        });
    }

    frames.push(SseFrame {
        event: "message_delta",
        data: json!({
            "type": "message_delta",
            "delta": {
                "stop_reason": response.stop_reason,
                "stop_sequence": response.stop_sequence,
            },
            "usage": {
                "output_tokens": response.usage.output_tokens,
            }
        }),
    });
    frames.push(SseFrame {
        event: "message_stop",
        data: json!({
            "type": "message_stop",
        }),
    });

    frames
}

pub fn messages_response_to_sse_response(response: MessagesResponse) -> Response {
    let frames = messages_response_to_sse_frames(&response);
    let stream = stream::iter(frames.into_iter().map(|frame| {
        Ok::<Event, Infallible>(
            Event::default()
                .event(frame.event)
                .json_data(frame.data)
                .expect("SSE frame should serialize"),
        )
    }));

    Sse::new(stream).into_response()
}

/// Returns the complete, non-streamed message serialized as a single JSON
/// body via axum's [`Json`] extractor/response type.
pub fn messages_response_to_json_response(response: MessagesResponse) -> Response {
    let body = Json(response);
    body.into_response()
}

#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::anthropic::{MessagesResponse, OutputContentBlock, Usage};

    /// Builds a response with fixed metadata around the given content blocks.
    fn sample_response(content: Vec<OutputContentBlock>, stop_reason: &str) -> MessagesResponse {
        MessagesResponse {
            id: "msg_1".to_string(),
            object_type: "message".to_string(),
            role: "assistant".to_string(),
            content,
            model: "mock-model".to_string(),
            stop_reason: Some(stop_reason.to_string()),
            stop_sequence: None,
            usage: Usage::default(),
        }
    }

    /// Collects the SSE event names in emission order.
    fn event_names(frames: &[SseFrame]) -> Vec<&'static str> {
        frames.iter().map(|frame| frame.event).collect()
    }

    #[test]
    fn streaming_expands_text_and_tool_use_into_standard_events() {
        let response = sample_response(
            vec![
                OutputContentBlock::Text {
                    text: "Hello".to_string(),
                },
                OutputContentBlock::ToolUse {
                    id: "toolu_1".to_string(),
                    name: "lookup_weather".to_string(),
                    input: json!({"city": "Paris"}),
                },
            ],
            "tool_use",
        );

        let frames = messages_response_to_sse_frames(&response);

        assert_eq!(
            event_names(&frames),
            vec![
                "message_start",
                "content_block_start",
                "content_block_delta",
                "content_block_stop",
                "content_block_start",
                "content_block_delta",
                "content_block_stop",
                "message_delta",
                "message_stop",
            ]
        );

        // Every payload's "type" mirrors its SSE event name.
        for frame in &frames {
            assert_eq!(frame.data["type"], frame.event);
        }

        // message_start carries the envelope, with no content or stop info yet.
        let message = &frames[0].data["message"];
        assert_eq!(message["id"], "msg_1");
        assert_eq!(message["type"], "message");
        assert_eq!(message["role"], "assistant");
        assert_eq!(message["model"], "mock-model");
        assert_eq!(message["content"], json!([]));
        assert!(message["stop_reason"].is_null());
        assert!(message["stop_sequence"].is_null());

        // Content-block frames are tagged with the block's position.
        let expected_indices = [0usize, 0, 0, 1, 1, 1];
        for (frame, expected) in frames[1..7].iter().zip(expected_indices.iter()) {
            assert_eq!(frame.data["index"], json!(*expected));
        }

        // Text block: empty shell on start, full text in the single delta.
        assert_eq!(
            frames[1].data["content_block"],
            json!({"type": "text", "text": ""})
        );
        assert_eq!(
            frames[2].data["delta"],
            json!({"type": "text_delta", "text": "Hello"})
        );

        // Tool-use block: id/name on start with empty input, input as JSON text.
        assert_eq!(
            frames[4].data["content_block"],
            json!({
                "type": "tool_use",
                "id": "toolu_1",
                "name": "lookup_weather",
                "input": {},
            })
        );
        assert_eq!(frames[5].data["delta"]["type"], "input_json_delta");
        assert_eq!(
            frames[5].data["delta"]["partial_json"],
            "{\"city\":\"Paris\"}"
        );

        // message_delta reports the final stop information.
        assert_eq!(frames[7].data["delta"]["stop_reason"], "tool_use");
        assert!(frames[7].data["delta"]["stop_sequence"].is_null());
    }

    #[test]
    fn streaming_response_without_content_emits_only_message_events() {
        let frames = messages_response_to_sse_frames(&sample_response(Vec::new(), "end_turn"));

        assert_eq!(
            event_names(&frames),
            vec!["message_start", "message_delta", "message_stop"]
        );
        assert_eq!(frames[0].data["message"]["content"], json!([]));
        assert_eq!(frames[1].data["delta"]["stop_reason"], "end_turn");
    }
}