// mcpr_core/proxy/sse.rs

/// Extract the JSON payload from an SSE-formatted response body.
///
/// Returns `Some(bytes)` only when `bytes` is valid UTF-8 and contains exactly one
/// non-empty `data:` line, i.e. a line beginning with `data:` at column 0. Whitespace
/// following `data:` is trimmed from the payload. Returns `None` for non-UTF-8 input,
/// bodies with no `data:` line, and bodies with two or more non-empty `data:` lines.
pub fn extract_json_from_sse(bytes: &[u8]) -> Option<Vec<u8>> {
    let text = std::str::from_utf8(bytes).ok()?;

    // Fast rejection: skip the line scan when no line could possibly start with `data:`.
    let looks_like_sse = text.trim_start().starts_with("data:") || text.contains("\ndata:");
    if !looks_like_sse {
        return None;
    }

    let mut payloads = text
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(str::trim_start)
        .filter(|payload| !payload.is_empty());

    // Accept exactly one payload: a first item must exist and a second must not.
    match (payloads.next(), payloads.next()) {
        (Some(only), None) => Some(only.as_bytes().to_vec()),
        _ => None,
    }
}

/// Re-wrap JSON bytes into a single SSE event.
///
/// Every line of `json_bytes` is emitted as its own `data: ` line, as the SSE format
/// requires. A bare line break inside the payload would otherwise end the `data` field
/// early: the following text would be parsed as an unknown field and dropped, or, after a
/// blank line, would start a new event. `\r\n`, `\r` and `\n` are all treated as line
/// breaks, matching SSE's own line terminators. The event is closed with a blank line.
///
/// Single-line input (e.g. compact JSON) produces exactly `data: <json>\n\n`.
pub fn wrap_as_sse(json_bytes: &[u8]) -> Vec<u8> {
    // Room for the "data: " prefix and "\n\n" terminator in the common single-line case.
    let mut out = Vec::with_capacity(json_bytes.len() + 8);
    let mut rest = json_bytes;
    loop {
        out.extend_from_slice(b"data: ");
        match rest.iter().position(|&b| b == b'\n' || b == b'\r') {
            Some(i) => {
                out.extend_from_slice(&rest[..i]);
                out.push(b'\n');
                // Consume "\r\n" as a single line break.
                let break_len = if rest[i] == b'\r' && rest.get(i + 1) == Some(&b'\n') {
                    2
                } else {
                    1
                };
                rest = &rest[i + break_len..];
            }
            None => {
                out.extend_from_slice(rest);
                out.extend_from_slice(b"\n\n");
                return out;
            }
        }
    }
}

/// Split a full upstream URL into `(base, path)`.
///
/// `base` is the scheme plus authority (host and optional port). `path` is everything from
/// the first `/`, `?` or `#` after the authority onward, or `""` when nothing follows the
/// authority. Concatenating `base` and `path` always reproduces `url` exactly.
///
/// `://` is recognised as the scheme separator only when it appears before any `/`, `?`
/// or `#`. This keeps a URL embedded in a query string (e.g. `?next=http://x/y`) from being
/// mistaken for the scheme.
///
/// e.g. "http://localhost:9000/mcp" → ("http://localhost:9000", "/mcp")
/// e.g. "http://localhost:9000" → ("http://localhost:9000", "")
/// e.g. "http://localhost:9000?key=v" → ("http://localhost:9000", "?key=v")
pub fn split_upstream(url: &str) -> (&str, &str) {
    // Per RFC 3986 the authority ends at the first '/', '?' or '#'.
    let is_delimiter = |c: char| matches!(c, '/' | '?' | '#');
    let authority_start = match url.find("://") {
        Some(pos) if !url[..pos].contains(is_delimiter) => pos + 3,
        _ => 0,
    };
    match url[authority_start..].find(is_delimiter) {
        Some(pos) => url.split_at(authority_start + pos),
        None => (url, ""),
    }
}

// Test names follow a `function__scenario` convention; the double underscore trips the
// `non_snake_case` lint, hence the allow below.
#[cfg(test)]
#[allow(non_snake_case)]
mod tests {
    use super::*;

    // ── SSE extraction ──

    #[test]
    fn extract_json_from_sse__single_event() {
        let input = b"data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\n\n";
        let result = extract_json_from_sse(input).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(parsed["jsonrpc"], "2.0");
    }

    #[test]
    fn extract_json_from_sse__indented_data_line_returns_none() {
        let input = b"  data: {\"id\":1}\n\n";
        assert!(extract_json_from_sse(input).is_none());
    }

    #[test]
    fn extract_json_from_sse__non_sse_input() {
        let input = b"{\"jsonrpc\":\"2.0\",\"id\":1}";
        assert!(extract_json_from_sse(input).is_none());
    }

    #[test]
    fn extract_json_from_sse__multiple_events_returns_none() {
        let input = b"data: {\"id\":1}\n\ndata: {\"id\":2}\n\n";
        assert!(extract_json_from_sse(input).is_none());
    }

    #[test]
    fn extract_json_from_sse__empty_data_skipped() {
        let input = b"data: \ndata: {\"id\":1}\n\n";
        let result = extract_json_from_sse(input);
        assert!(result.is_some());
    }

    #[test]
    fn extract_json_from_sse__only_empty_data_returns_none() {
        let input = b"data: \n\n";
        assert!(extract_json_from_sse(input).is_none());
    }

    #[test]
    fn extract_json_from_sse__crlf_line_endings() {
        let input = b"data: {\"id\":1}\r\n\r\n";
        assert_eq!(extract_json_from_sse(input).unwrap(), b"{\"id\":1}");
    }

    #[test]
    fn extract_json_from_sse__no_space_after_colon() {
        let input = b"data:{\"id\":1}\n\n";
        assert_eq!(extract_json_from_sse(input).unwrap(), b"{\"id\":1}");
    }

    #[test]
    fn extract_json_from_sse__ignores_other_fields() {
        let input = b"event: message\nid: 7\ndata: {\"id\":7}\n\n";
        assert_eq!(extract_json_from_sse(input).unwrap(), b"{\"id\":7}");
    }

    #[test]
    fn extract_json_from_sse__non_utf8_returns_none() {
        let input: &[u8] = &[b'd', b'a', b't', b'a', b':', 0xff, b'\n'];
        assert!(extract_json_from_sse(input).is_none());
    }

    // ── SSE wrapping ──

    #[test]
    fn wrap_as_sse__correct_format() {
        let json = b"{\"id\":1}";
        let wrapped = wrap_as_sse(json);
        assert_eq!(wrapped, b"data: {\"id\":1}\n\n");
    }

    #[test]
    fn sse__roundtrip() {
        let original = b"{\"jsonrpc\":\"2.0\",\"id\":42,\"result\":{\"content\":[]}}";
        let wrapped = wrap_as_sse(original);
        let extracted = extract_json_from_sse(&wrapped).unwrap();
        assert_eq!(extracted, original);
    }

    // ── split_upstream ──

    #[test]
    fn split_upstream__with_path() {
        let (base, path) = split_upstream("http://localhost:9000/mcp");
        assert_eq!(base, "http://localhost:9000");
        assert_eq!(path, "/mcp");
    }

    #[test]
    fn split_upstream__no_path() {
        let (base, path) = split_upstream("http://localhost:9000");
        assert_eq!(base, "http://localhost:9000");
        assert_eq!(path, "");
    }

    #[test]
    fn split_upstream__deep_path() {
        let (base, path) = split_upstream("https://api.example.com/v1/mcp");
        assert_eq!(base, "https://api.example.com");
        assert_eq!(path, "/v1/mcp");
    }

    #[test]
    fn split_upstream__trailing_slash() {
        let (base, path) = split_upstream("http://localhost:9000/");
        assert_eq!(base, "http://localhost:9000");
        assert_eq!(path, "/");
    }

    #[test]
    fn split_upstream__no_scheme() {
        let (base, path) = split_upstream("localhost:9000/mcp");
        assert_eq!(base, "localhost:9000");
        assert_eq!(path, "/mcp");
    }
}