Skip to main content

mcpr_core/
sse.rs

/// Extract the JSON payload from an SSE-formatted response body.
///
/// The body is scanned line by line (LF or CRLF line endings). Every line that
/// starts with the `data:` field name contributes its value with leading
/// whitespace removed; lines whose value is then empty are ignored.
///
/// Returns `Some(payload)` only when exactly one non-empty `data:` line is
/// present. Returns `None` when the body is not valid UTF-8, when it contains
/// no `data:` line at all (e.g. a plain JSON body), or when it contains more
/// than one non-empty `data:` line (several events, or one event whose data is
/// split across several `data:` lines).
pub fn extract_json_from_sse(bytes: &[u8]) -> Option<Vec<u8>> {
    let text = std::str::from_utf8(bytes).ok()?;
    // Lazily yield the non-empty `data:` values. A line with leading
    // whitespace (e.g. "  data: ...") is deliberately not matched: the field
    // name is everything before the first colon, so that is not `data`.
    let mut payloads = text
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(str::trim_start)
        .filter(|data| !data.is_empty());
    let first = payloads.next()?;
    // Stop at the second match instead of collecting every data line.
    if payloads.next().is_some() {
        return None;
    }
    Some(first.as_bytes().to_vec())
}
23
/// Re-wrap a JSON payload as a single SSE event.
///
/// Each line of the payload becomes its own `data: ` line, and the event is
/// terminated by a blank line. A payload without line breaks (e.g. compact
/// JSON) therefore produces exactly `data: <payload>\n\n`.
///
/// Splitting is required because SSE treats LF, CR and CRLF as line
/// terminators. Writing a multi-line payload (such as pretty-printed JSON)
/// after a single `data: ` prefix would leave the continuation lines without
/// the `data` field name, and clients would ignore them. A conforming client
/// joins consecutive `data:` lines back together with `\n`.
pub fn wrap_as_sse(json_bytes: &[u8]) -> Vec<u8> {
    const PREFIX: &[u8] = b"data: ";
    // Exact size for the common single-line case.
    let mut out = Vec::with_capacity(PREFIX.len() + json_bytes.len() + 2);
    let mut rest = json_bytes;
    while let Some(i) = rest.iter().position(|&b| b == b'\n' || b == b'\r') {
        out.extend_from_slice(PREFIX);
        out.extend_from_slice(&rest[..i]);
        out.push(b'\n');
        // Treat CRLF as one line break rather than two.
        let break_len = if rest[i] == b'\r' && rest.get(i + 1) == Some(&b'\n') {
            2
        } else {
            1
        };
        rest = &rest[i + break_len..];
    }
    out.extend_from_slice(PREFIX);
    out.extend_from_slice(rest);
    out.extend_from_slice(b"\n\n");
    out
}
31
/// Split a full upstream URL into `(base, path)`.
///
/// `base` is the scheme (if any) plus the authority (host and optional port).
/// `path` is everything after it, starting at the first `/`, `?` or `#` that
/// follows the authority, or `""` when nothing follows. Concatenating the two
/// halves always reproduces `url`.
///
/// e.g. "http://localhost:9000/mcp" → ("http://localhost:9000", "/mcp")
/// e.g. "http://localhost:9000" → ("http://localhost:9000", "")
/// e.g. "localhost:9000/mcp" → ("localhost:9000", "/mcp")
/// e.g. "http://localhost:9000?token=x" → ("http://localhost:9000", "?token=x")
pub fn split_upstream(url: &str) -> (&str, &str) {
    // RFC 3986 §3.2: the authority is terminated by '/', '?' or '#'.
    let is_authority_end = |c: char| matches!(c, '/' | '?' | '#');
    // Only a "://" that precedes every '/', '?' and '#' can be the scheme
    // separator. A later one belongs to the path or query of a scheme-less
    // URL (e.g. "localhost:9000/cb?next=http://x/y").
    let authority_start = url
        .find("://")
        .filter(|&pos| !url[..pos].contains(is_authority_end))
        .map_or(0, |pos| pos + 3);
    match url[authority_start..].find(is_authority_end) {
        Some(pos) => url.split_at(authority_start + pos),
        None => (url, ""),
    }
}
46
#[cfg(test)]
mod tests {
    use super::*;

    // ── SSE extraction ──

    #[test]
    fn extract_json_from_sse_single_event() {
        let input = b"data: {\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}\n\n";
        let result = extract_json_from_sse(input).unwrap();
        let parsed: serde_json::Value = serde_json::from_slice(&result).unwrap();
        assert_eq!(parsed["jsonrpc"], "2.0");
    }

    #[test]
    fn extract_json_from_sse_with_leading_whitespace_returns_none() {
        // "  data" is a different field name, so there is no data line at all.
        let input = b"  data: {\"id\":1}\n\n";
        assert!(extract_json_from_sse(input).is_none());
    }

    #[test]
    fn extract_json_from_sse_not_sse() {
        let input = b"{\"jsonrpc\":\"2.0\",\"id\":1}";
        assert!(extract_json_from_sse(input).is_none());
    }

    #[test]
    fn extract_json_from_sse_multiple_events_returns_none() {
        let input = b"data: {\"id\":1}\n\ndata: {\"id\":2}\n\n";
        assert!(extract_json_from_sse(input).is_none());
    }

    #[test]
    fn extract_json_from_sse_empty_data_skipped() {
        let input = b"data: \ndata: {\"id\":1}\n\n";
        // Pin the payload, not just its presence: the empty data line must
        // not leak into the result.
        assert_eq!(
            extract_json_from_sse(input).as_deref(),
            Some(&b"{\"id\":1}"[..])
        );
    }

    #[test]
    fn extract_json_from_sse_crlf_with_other_fields() {
        let input = b"event: message\r\nid: 7\r\ndata: {\"id\":1}\r\n\r\n";
        assert_eq!(
            extract_json_from_sse(input).as_deref(),
            Some(&b"{\"id\":1}"[..])
        );
    }

    #[test]
    fn extract_json_from_sse_invalid_utf8_returns_none() {
        assert!(extract_json_from_sse(b"data: \xff\n\n").is_none());
    }

    // ── SSE wrapping ──

    #[test]
    fn wrap_as_sse_format() {
        let json = b"{\"id\":1}";
        let wrapped = wrap_as_sse(json);
        assert_eq!(wrapped, b"data: {\"id\":1}\n\n");
    }

    #[test]
    fn wrap_as_sse_empty_payload() {
        assert_eq!(wrap_as_sse(b""), b"data: \n\n");
    }

    #[test]
    fn sse_roundtrip() {
        let original = b"{\"jsonrpc\":\"2.0\",\"id\":42,\"result\":{\"content\":[]}}";
        let wrapped = wrap_as_sse(original);
        let extracted = extract_json_from_sse(&wrapped).unwrap();
        assert_eq!(extracted, original);
    }

    // ── split_upstream ──

    #[test]
    fn split_upstream_with_path() {
        let (base, path) = split_upstream("http://localhost:9000/mcp");
        assert_eq!(base, "http://localhost:9000");
        assert_eq!(path, "/mcp");
    }

    #[test]
    fn split_upstream_no_path() {
        let (base, path) = split_upstream("http://localhost:9000");
        assert_eq!(base, "http://localhost:9000");
        assert_eq!(path, "");
    }

    #[test]
    fn split_upstream_deep_path() {
        let (base, path) = split_upstream("https://api.example.com/v1/mcp");
        assert_eq!(base, "https://api.example.com");
        assert_eq!(path, "/v1/mcp");
    }

    #[test]
    fn split_upstream_trailing_slash() {
        let (base, path) = split_upstream("http://localhost:9000/");
        assert_eq!(base, "http://localhost:9000");
        assert_eq!(path, "/");
    }

    #[test]
    fn split_upstream_without_scheme() {
        let (base, path) = split_upstream("localhost:9000/mcp");
        assert_eq!(base, "localhost:9000");
        assert_eq!(path, "/mcp");
    }
}