1use std::io::{BufReader, Read, Write};
59
60use asupersync::Cx;
61use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse};
62
63use crate::{Codec, CodecError, Transport, TransportError};
64
/// Event names understood by this SSE transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SseEventType {
    /// `endpoint`: its data carries the endpoint URL string.
    Endpoint,
    /// `message`: its data carries one serialized JSON-RPC message.
    Message,
}

impl SseEventType {
    /// Returns the wire name written after `event: `.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match self {
            SseEventType::Endpoint => "endpoint",
            SseEventType::Message => "message",
        }
    }

    /// Parses a wire event name.
    ///
    /// Matching is exact and case-sensitive; any other name (e.g. `ping`, `MESSAGE`,
    /// or the empty string) yields `None`.
    #[must_use]
    pub fn from_str(s: &str) -> Option<Self> {
        match s {
            "endpoint" => Some(SseEventType::Endpoint),
            "message" => Some(SseEventType::Message),
            _ => None,
        }
    }
}

/// A single server-sent event.
#[derive(Debug, Clone)]
pub struct SseEvent {
    /// Value written in the `event:` field.
    pub event_type: SseEventType,
    /// Payload; every line is written as a separate `data:` field.
    pub data: String,
    /// Optional `id:` field, written verbatim.
    pub id: Option<String>,
    /// Optional `retry:` field, in milliseconds.
    pub retry: Option<u64>,
}

impl SseEvent {
    /// Creates an `endpoint` event whose data is `url`.
    #[must_use]
    pub fn endpoint(url: impl Into<String>) -> Self {
        Self {
            event_type: SseEventType::Endpoint,
            data: url.into(),
            id: None,
            retry: None,
        }
    }

    /// Creates a `message` event carrying `data`.
    #[must_use]
    pub fn message(data: impl Into<String>) -> Self {
        Self {
            event_type: SseEventType::Message,
            data: data.into(),
            id: None,
            retry: None,
        }
    }

    /// Sets the `id:` field.
    ///
    /// The id is written verbatim, so it must not contain `\r` or `\n`.
    #[must_use]
    pub fn with_id(mut self, id: impl Into<String>) -> Self {
        self.id = Some(id.into());
        self
    }

    /// Sets the `retry:` field (reconnection delay in milliseconds).
    #[must_use]
    pub fn with_retry(mut self, retry_ms: u64) -> Self {
        self.retry = Some(retry_ms);
        self
    }

    /// Serializes the event in SSE wire format, terminated by a blank line.
    ///
    /// Fields are written in the order `event`, `id` (if set), `retry` (if set), `data`.
    /// The payload is split on every line terminator (`\r\n`, `\r` or `\n`), and each
    /// resulting line, including empty ones, becomes its own `data:` field. A reader that
    /// joins data lines with `\n` therefore recovers the payload, with `\r\n`/`\r`
    /// normalized to `\n`. An empty payload is written as a single empty `data:` field.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut output = Vec::with_capacity(self.data.len() + 64);

        output.extend_from_slice(b"event: ");
        output.extend_from_slice(self.event_type.as_str().as_bytes());
        output.push(b'\n');

        if let Some(id) = &self.id {
            output.extend_from_slice(b"id: ");
            output.extend_from_slice(id.as_bytes());
            output.push(b'\n');
        }

        if let Some(retry) = self.retry {
            output.extend_from_slice(b"retry: ");
            output.extend_from_slice(retry.to_string().as_bytes());
            output.push(b'\n');
        }

        // `str::lines` would drop a trailing empty line ("a\n" -> ["a"]) and would not
        // split on a lone '\r', which SSE parsers treat as a line break (the rest of the
        // payload would then be parsed as fields). Normalize every terminator to '\n'
        // (allocating only when a '\r' is present) and split on it.
        let data: std::borrow::Cow<'_, str> = if self.data.contains('\r') {
            std::borrow::Cow::Owned(self.data.replace("\r\n", "\n").replace('\r', "\n"))
        } else {
            std::borrow::Cow::Borrowed(self.data.as_str())
        };
        for line in data.split('\n') {
            output.extend_from_slice(b"data: ");
            output.extend_from_slice(line.as_bytes());
            output.push(b'\n');
        }

        // The blank line terminates the event.
        output.push(b'\n');

        output
    }
}
195
196pub struct SseWriter<W> {
217 writer: W,
218 codec: Codec,
219 event_counter: u64,
220}
221
222impl<W: Write> SseWriter<W> {
223 #[must_use]
225 pub fn new(writer: W) -> Self {
226 Self {
227 writer,
228 codec: Codec::new(),
229 event_counter: 0,
230 }
231 }
232
233 pub fn write_event(&mut self, cx: &Cx, event: &SseEvent) -> Result<(), TransportError> {
239 if cx.is_cancel_requested() {
240 return Err(TransportError::Cancelled);
241 }
242
243 let bytes = event.to_bytes();
244 self.writer.write_all(&bytes)?;
245 self.writer.flush()?;
246 Ok(())
247 }
248
249 pub fn write_endpoint(&mut self, cx: &Cx, url: &str) -> Result<(), TransportError> {
253 let event = SseEvent::endpoint(url);
254 self.write_event(cx, &event)
255 }
256
257 pub fn write_message(
259 &mut self,
260 cx: &Cx,
261 message: &JsonRpcMessage,
262 ) -> Result<(), TransportError> {
263 if cx.is_cancel_requested() {
264 return Err(TransportError::Cancelled);
265 }
266
267 let json = match message {
268 JsonRpcMessage::Request(req) => {
269 serde_json::to_string(req).map_err(CodecError::Json)?
271 }
272 JsonRpcMessage::Response(resp) => {
273 serde_json::to_string(resp).map_err(CodecError::Json)?
274 }
275 };
276
277 self.event_counter += 1;
278 let event = SseEvent::message(json).with_id(self.event_counter.to_string());
279 self.write_event(cx, &event)
280 }
281
282 pub fn write_response(
284 &mut self,
285 cx: &Cx,
286 response: &JsonRpcResponse,
287 ) -> Result<(), TransportError> {
288 self.write_message(cx, &JsonRpcMessage::Response(response.clone()))
289 }
290
291 pub fn write_request(
296 &mut self,
297 cx: &Cx,
298 request: &JsonRpcRequest,
299 ) -> Result<(), TransportError> {
300 self.write_message(cx, &JsonRpcMessage::Request(request.clone()))
301 }
302
303 pub fn write_comment(&mut self, cx: &Cx, comment: &str) -> Result<(), TransportError> {
308 if cx.is_cancel_requested() {
309 return Err(TransportError::Cancelled);
310 }
311
312 self.writer.write_all(b": ")?;
314 self.writer.write_all(comment.as_bytes())?;
315 self.writer.write_all(b"\n")?;
316 self.writer.flush()?;
317 Ok(())
318 }
319
320 pub fn keep_alive(&mut self, cx: &Cx) -> Result<(), TransportError> {
322 self.write_comment(cx, "keep-alive")
323 }
324
325 pub fn inner(&self) -> &W {
327 &self.writer
328 }
329
330 pub fn inner_mut(&mut self) -> &mut W {
332 &mut self.writer
333 }
334
335 pub fn into_inner(self) -> W {
337 self.writer
338 }
339}
340
/// Default upper bound, in bytes (64 KiB), on a single SSE line, including its `\n`.
const MAX_SSE_LINE_SIZE: usize = 1 << 16;
347
/// Parses server-sent events from an underlying [`Read`].
pub struct SseReader<R> {
    /// Largest accepted line, in bytes, including its trailing `\n`.
    max_line_size: usize,
    /// Reused buffer holding the line currently being parsed.
    line_buffer: String,
    /// Buffered source of the event stream.
    reader: BufReader<R>,
}
370
371impl<R: Read> SseReader<R> {
372 #[must_use]
374 pub fn new(reader: R) -> Self {
375 Self {
376 reader: BufReader::new(reader),
377 line_buffer: String::with_capacity(4096),
378 max_line_size: MAX_SSE_LINE_SIZE,
379 }
380 }
381
382 fn read_line_bounded(&mut self) -> Result<usize, std::io::Error> {
392 use std::io::BufRead;
393
394 let mut total_read = 0;
395 loop {
396 let available = self.reader.fill_buf()?;
397 if available.is_empty() {
398 return Ok(total_read);
400 }
401
402 let newline_pos = available.iter().position(|&b| b == b'\n');
404 let bytes_to_consume = match newline_pos {
405 Some(pos) => pos + 1, None => available.len(),
407 };
408
409 if self.line_buffer.len() + bytes_to_consume > self.max_line_size {
411 return Err(std::io::Error::new(
412 std::io::ErrorKind::InvalidData,
413 format!(
414 "SSE line exceeds maximum size of {} bytes",
415 self.max_line_size
416 ),
417 ));
418 }
419
420 let chunk = &available[..bytes_to_consume];
422 let chunk_str = std::str::from_utf8(chunk).map_err(|e| {
423 std::io::Error::new(
424 std::io::ErrorKind::InvalidData,
425 format!("Invalid UTF-8: {e}"),
426 )
427 })?;
428 self.line_buffer.push_str(chunk_str);
429 total_read += bytes_to_consume;
430
431 self.reader.consume(bytes_to_consume);
432
433 if newline_pos.is_some() {
434 return Ok(total_read);
436 }
437 }
438 }
439
440 pub fn read_event(&mut self, cx: &Cx) -> Result<Option<SseEvent>, TransportError> {
452 const MAX_EVENT_DATA_SIZE: usize = 1024 * 1024;
454
455 if cx.is_cancel_requested() {
456 return Err(TransportError::Cancelled);
457 }
458
459 let mut event_type: Option<SseEventType> = None;
460 let mut unknown_event = false;
461 let mut data_lines: Vec<String> = Vec::new();
462 let mut total_data_size: usize = 0;
463 let mut event_id: Option<String> = None;
464 let mut retry: Option<u64> = None;
465
466 loop {
467 self.line_buffer.clear();
468 let bytes_read = self.read_line_bounded()?;
469
470 if bytes_read == 0 {
471 return Ok(None);
473 }
474
475 if cx.is_cancel_requested() {
477 return Err(TransportError::Cancelled);
478 }
479
480 let line = self
481 .line_buffer
482 .trim_end_matches(|c| c == '\n' || c == '\r');
483
484 if line.is_empty() {
486 if unknown_event {
487 event_type = None;
488 data_lines.clear();
489 total_data_size = 0;
490 event_id = None;
491 retry = None;
492 unknown_event = false;
493 continue;
494 }
495 if event_type.is_some() || !data_lines.is_empty() {
496 let data = data_lines.join("\n");
498 return Ok(Some(SseEvent {
499 event_type: event_type.unwrap_or(SseEventType::Message),
500 data,
501 id: event_id,
502 retry,
503 }));
504 }
505 continue;
507 }
508
509 if line.starts_with(':') {
511 continue;
512 }
513
514 if let Some((field, value)) = line.split_once(':') {
516 let value = value.strip_prefix(' ').unwrap_or(value);
518
519 match field {
520 "event" => {
521 if let Some(parsed) = SseEventType::from_str(value) {
522 event_type = Some(parsed);
523 unknown_event = false;
524 } else {
525 event_type = None;
526 unknown_event = true;
527 }
528 }
529 "data" => {
530 total_data_size = total_data_size.saturating_add(value.len() + 1); if total_data_size > MAX_EVENT_DATA_SIZE {
533 return Err(TransportError::Io(std::io::Error::new(
534 std::io::ErrorKind::InvalidData,
535 format!(
536 "SSE event data exceeds maximum size of {} bytes",
537 MAX_EVENT_DATA_SIZE
538 ),
539 )));
540 }
541 data_lines.push(value.to_string());
542 }
543 "id" => {
544 event_id = Some(value.to_string());
545 }
546 "retry" => {
547 retry = value.parse().ok();
548 }
549 _ => {
550 }
552 }
553 }
554 }
555 }
556
557 pub fn read_message(&mut self, cx: &Cx) -> Result<Option<JsonRpcMessage>, TransportError> {
567 loop {
568 match self.read_event(cx)? {
569 Some(event) => {
570 if event.event_type == SseEventType::Message {
571 let message: JsonRpcMessage = serde_json::from_str(&event.data)
572 .map_err(|e| TransportError::Codec(CodecError::Json(e)))?;
573 return Ok(Some(message));
574 }
575 continue;
577 }
578 None => return Ok(None),
579 }
580 }
581 }
582
583 pub fn read_endpoint(&mut self, cx: &Cx) -> Result<Option<String>, TransportError> {
593 loop {
594 match self.read_event(cx)? {
595 Some(event) => {
596 if event.event_type == SseEventType::Endpoint {
597 return Ok(Some(event.data));
598 }
599 continue;
601 }
602 None => return Ok(None),
603 }
604 }
605 }
606
607 pub fn inner(&self) -> &BufReader<R> {
609 &self.reader
610 }
611}
612
613pub struct SseServerTransport<W, R> {
653 writer: SseWriter<W>,
654 request_source: R,
656 endpoint_sent: bool,
657 endpoint_url: String,
658}
659
660impl<W: Write, R: Iterator<Item = JsonRpcRequest>> SseServerTransport<W, R> {
661 #[must_use]
669 pub fn new(writer: W, request_source: R, endpoint_url: impl Into<String>) -> Self {
670 Self {
671 writer: SseWriter::new(writer),
672 request_source,
673 endpoint_sent: false,
674 endpoint_url: endpoint_url.into(),
675 }
676 }
677
678 fn ensure_endpoint_sent(&mut self, cx: &Cx) -> Result<(), TransportError> {
680 if !self.endpoint_sent {
681 self.writer.write_endpoint(cx, &self.endpoint_url)?;
682 self.endpoint_sent = true;
683 }
684 Ok(())
685 }
686}
687
688impl<W: Write, R: Iterator<Item = JsonRpcRequest>> Transport for SseServerTransport<W, R> {
689 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
690 self.ensure_endpoint_sent(cx)?;
691 self.writer.write_message(cx, message)
692 }
693
694 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
695 if cx.is_cancel_requested() {
696 return Err(TransportError::Cancelled);
697 }
698
699 match self.request_source.next() {
701 Some(request) => Ok(JsonRpcMessage::Request(request)),
702 None => Err(TransportError::Closed),
703 }
704 }
705
706 fn close(&mut self) -> Result<(), TransportError> {
707 self.writer.inner_mut().flush()?;
710 Ok(())
711 }
712}
713
714pub struct SseClientTransport<R, W> {
741 reader: SseReader<R>,
742 request_sink: W,
744 codec: Codec,
745}
746
747impl<R: Read, W: Write> SseClientTransport<R, W> {
748 #[must_use]
755 pub fn new(reader: R, request_sink: W) -> Self {
756 Self {
757 reader: SseReader::new(reader),
758 request_sink,
759 codec: Codec::new(),
760 }
761 }
762
763 pub fn read_endpoint(&mut self, cx: &Cx) -> Result<Option<String>, TransportError> {
767 self.reader.read_endpoint(cx)
768 }
769}
770
771impl<R: Read, W: Write> Transport for SseClientTransport<R, W> {
772 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
773 if cx.is_cancel_requested() {
774 return Err(TransportError::Cancelled);
775 }
776
777 let bytes = match message {
779 JsonRpcMessage::Request(req) => self.codec.encode_request(req)?,
780 JsonRpcMessage::Response(resp) => self.codec.encode_response(resp)?,
781 };
782
783 self.request_sink.write_all(&bytes)?;
784 self.request_sink.flush()?;
785 Ok(())
786 }
787
788 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
789 match self.reader.read_message(cx)? {
790 Some(message) => Ok(message),
791 None => Err(TransportError::Closed),
792 }
793 }
794
795 fn close(&mut self) -> Result<(), TransportError> {
796 self.request_sink.flush()?;
797 Ok(())
798 }
799}
800
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Wraps `input` in an in-memory `SseReader`.
    fn reader_for(input: &[u8]) -> SseReader<Cursor<Vec<u8>>> {
        SseReader::new(Cursor::new(input.to_vec()))
    }

    /// Builds a successful JSON-RPC response with the given result and id.
    fn ok_response(result: serde_json::Value, id: fastmcp_protocol::RequestId) -> JsonRpcResponse {
        JsonRpcResponse {
            jsonrpc: std::borrow::Cow::Borrowed(fastmcp_protocol::JSONRPC_VERSION),
            result: Some(result),
            error: None,
            id: Some(id),
        }
    }

    #[test]
    fn test_sse_event_endpoint() {
        let event = SseEvent::endpoint("http://localhost:8080/messages");
        let output = String::from_utf8(event.to_bytes()).unwrap();

        assert!(output.contains("event: endpoint\n"));
        assert!(output.contains("data: http://localhost:8080/messages\n"));
        // A blank line terminates every event.
        assert!(output.ends_with("\n\n"));
    }

    #[test]
    fn test_sse_event_message() {
        let event = SseEvent::message(r#"{"jsonrpc":"2.0","id":1}"#).with_id("42");
        let output = String::from_utf8(event.to_bytes()).unwrap();

        assert!(output.contains("event: message\n"));
        assert!(output.contains("id: 42\n"));
        assert!(output.contains(r#"data: {"jsonrpc":"2.0","id":1}"#));
    }

    #[test]
    fn test_sse_event_with_retry() {
        let event = SseEvent::message("test").with_retry(5000);
        let output = String::from_utf8(event.to_bytes()).unwrap();

        assert!(output.contains("retry: 5000\n"));
    }

    #[test]
    fn test_sse_event_multiline_data() {
        let event = SseEvent::message("line1\nline2\nline3");
        let output = String::from_utf8(event.to_bytes()).unwrap();

        assert!(output.contains("data: line1\n"));
        assert!(output.contains("data: line2\n"));
        assert!(output.contains("data: line3\n"));
    }

    #[test]
    fn test_sse_reader_simple_event() {
        let mut sse_reader = reader_for(b"event: message\ndata: hello\n\n");
        let cx = Cx::for_testing();

        let event = sse_reader.read_event(&cx).unwrap().unwrap();

        assert_eq!(event.event_type, SseEventType::Message);
        assert_eq!(event.data, "hello");
    }

    #[test]
    fn test_sse_reader_with_id() {
        let mut sse_reader = reader_for(b"event: message\nid: 42\ndata: test\n\n");
        let cx = Cx::for_testing();

        let event = sse_reader.read_event(&cx).unwrap().unwrap();

        assert_eq!(event.id, Some("42".to_string()));
    }

    #[test]
    fn test_sse_reader_multiline_data() {
        let mut sse_reader = reader_for(b"event: message\ndata: line1\ndata: line2\n\n");
        let cx = Cx::for_testing();

        let event = sse_reader.read_event(&cx).unwrap().unwrap();

        assert_eq!(event.data, "line1\nline2");
    }

    #[test]
    fn test_sse_reader_skips_comments() {
        let mut sse_reader = reader_for(b": this is a comment\nevent: message\ndata: hello\n\n");
        let cx = Cx::for_testing();

        let event = sse_reader.read_event(&cx).unwrap().unwrap();

        assert_eq!(event.data, "hello");
    }

    #[test]
    fn test_sse_reader_skips_unknown_events() {
        let mut sse_reader = reader_for(
            b"event: ping\ndata: keep-alive\n\n\
              event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":1}\n\n",
        );
        let cx = Cx::for_testing();

        let message = sse_reader.read_message(&cx).unwrap().unwrap();

        match message {
            JsonRpcMessage::Request(req) => assert_eq!(req.method, "ping"),
            _ => panic!("Expected request"),
        }
    }

    #[test]
    fn test_sse_reader_eof() {
        let mut sse_reader = reader_for(b"");
        let cx = Cx::for_testing();

        let result = sse_reader.read_event(&cx).unwrap();

        assert!(result.is_none());
    }

    #[test]
    fn test_sse_reader_endpoint_event() {
        let mut sse_reader = reader_for(b"event: endpoint\ndata: http://localhost/post\n\n");
        let cx = Cx::for_testing();

        let url = sse_reader.read_endpoint(&cx).unwrap().unwrap();

        assert_eq!(url, "http://localhost/post");
    }

    #[test]
    fn test_sse_writer_endpoint() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();

        writer
            .write_endpoint(&cx, "http://localhost:8080/messages")
            .unwrap();

        let output = String::from_utf8(writer.into_inner()).unwrap();
        assert!(output.contains("event: endpoint\n"));
        assert!(output.contains("data: http://localhost:8080/messages\n"));
    }

    #[test]
    fn test_sse_writer_keep_alive() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();

        writer.keep_alive(&cx).unwrap();

        let output = String::from_utf8(writer.into_inner()).unwrap();
        assert_eq!(output, ": keep-alive\n");
    }

    #[test]
    fn test_sse_roundtrip() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();
        let message = JsonRpcMessage::Response(ok_response(
            serde_json::json!({"status": "ok"}),
            fastmcp_protocol::RequestId::Number(1),
        ));

        writer.write_message(&cx, &message).unwrap();
        let written = writer.into_inner();

        let mut reader = SseReader::new(Cursor::new(written));
        let read_message = reader.read_message(&cx).unwrap().unwrap();

        match read_message {
            JsonRpcMessage::Response(resp) => {
                assert_eq!(resp.result, Some(serde_json::json!({"status": "ok"})));
            }
            _ => panic!("Expected response"),
        }
    }

    #[test]
    fn test_sse_reader_cancellation() {
        let mut sse_reader = reader_for(b"event: message\ndata: hello\n\n");
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let result = sse_reader.read_event(&cx);
        assert!(matches!(result, Err(TransportError::Cancelled)));
    }

    #[test]
    fn test_sse_writer_cancellation() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let result = writer.write_endpoint(&cx, "http://test");
        assert!(matches!(result, Err(TransportError::Cancelled)));
        // Cancellation is checked before anything is written.
        assert!(writer.inner().is_empty());
    }

    #[test]
    fn e2e_sse_connection_establishment() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();

        writer
            .write_endpoint(&cx, "http://localhost:8080/mcp/messages")
            .unwrap();

        let output = String::from_utf8(writer.into_inner()).unwrap();

        assert!(output.starts_with("event: endpoint\n"));
        assert!(output.contains("data: http://localhost:8080/mcp/messages\n"));
        assert!(output.ends_with("\n\n"));
    }

    #[test]
    fn e2e_sse_event_stream_sequence() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();

        writer.write_endpoint(&cx, "http://localhost/post").unwrap();

        for i in 1..=3 {
            let response = ok_response(
                serde_json::json!({"count": i}),
                fastmcp_protocol::RequestId::Number(i),
            );
            writer.write_response(&cx, &response).unwrap();
        }

        writer.keep_alive(&cx).unwrap();

        let output = String::from_utf8(writer.into_inner()).unwrap();

        assert!(output.contains("event: endpoint\n"));
        assert!(output.contains("event: message\n"));
        assert!(output.contains("id: 1\n"));
        assert!(output.contains("id: 2\n"));
        assert!(output.contains("id: 3\n"));
        assert!(output.contains(": keep-alive\n"));
    }

    #[test]
    fn e2e_sse_resumability_with_last_event_id() {
        let mut sse_reader = reader_for(
            b"\
            event: message\n\
            id: 100\n\
            data: {\"jsonrpc\":\"2.0\",\"result\":{\"n\":1},\"id\":1}\n\
            \n\
            event: message\n\
            id: 101\n\
            data: {\"jsonrpc\":\"2.0\",\"result\":{\"n\":2},\"id\":2}\n\
            \n\
            event: message\n\
            id: 102\n\
            data: {\"jsonrpc\":\"2.0\",\"result\":{\"n\":3},\"id\":3}\n\
            \n",
        );
        let cx = Cx::for_testing();

        let mut event_ids = Vec::new();
        while let Some(event) = sse_reader.read_event(&cx).unwrap() {
            if let Some(id) = event.id {
                event_ids.push(id);
            }
        }

        assert_eq!(event_ids, vec!["100", "101", "102"]);
        // The last id seen is what a client would send back as Last-Event-ID.
        assert_eq!(event_ids.last().unwrap(), "102");
    }

    #[test]
    fn e2e_sse_graceful_disconnect_on_eof() {
        let mut sse_reader = reader_for(
            b"\
            event: message\n\
            data: {\"jsonrpc\":\"2.0\",\"method\":\"test\"}\n\
            \n",
        );
        let cx = Cx::for_testing();

        assert!(sse_reader.read_event(&cx).unwrap().is_some());
        // The stream is exhausted: EOF is reported as `None`, not an error.
        assert!(sse_reader.read_event(&cx).unwrap().is_none());
    }

    #[test]
    fn e2e_sse_server_transport_flow() {
        let requests = vec![
            JsonRpcRequest::new("initialize", None, 1i64),
            JsonRpcRequest::new("tools/list", None, 2i64),
        ];
        let mut transport =
            SseServerTransport::new(Vec::new(), requests.into_iter(), "http://localhost/post");
        let cx = Cx::for_testing();

        match transport.recv(&cx).unwrap() {
            JsonRpcMessage::Request(req) => assert_eq!(req.method, "initialize"),
            _ => panic!("Expected request"),
        }

        let response = ok_response(
            serde_json::json!({"capabilities": {}}),
            fastmcp_protocol::RequestId::Number(1),
        );
        transport
            .send(&cx, &JsonRpcMessage::Response(response))
            .unwrap();

        match transport.recv(&cx).unwrap() {
            JsonRpcMessage::Request(req) => assert_eq!(req.method, "tools/list"),
            _ => panic!("Expected request"),
        }

        let result = transport.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));

        // The endpoint event must precede the first message.
        let output = String::from_utf8(transport.writer.inner().clone()).unwrap();
        assert!(output.starts_with("event: endpoint\ndata: http://localhost/post\n\n"));
        assert!(output.contains("event: message\n"));
    }

    #[test]
    fn e2e_sse_client_transport_flow() {
        let reader = Cursor::new(
            b"\
            event: endpoint\n\
            data: http://localhost/post\n\
            \n\
            event: message\n\
            data: {\"jsonrpc\":\"2.0\",\"result\":{\"tools\":[]},\"id\":1}\n\
            \n"
            .to_vec(),
        );
        let mut request_buffer = Vec::new();

        {
            let mut transport = SseClientTransport::new(reader, &mut request_buffer);
            let cx = Cx::for_testing();

            let endpoint = transport.read_endpoint(&cx).unwrap().unwrap();
            assert_eq!(endpoint, "http://localhost/post");

            let request = JsonRpcRequest::new("tools/list", None, 1i64);
            transport
                .send(&cx, &JsonRpcMessage::Request(request))
                .unwrap();

            match transport.recv(&cx).unwrap() {
                JsonRpcMessage::Response(_) => {}
                _ => panic!("Expected response"),
            }
        }

        let sent = String::from_utf8(request_buffer).unwrap();
        assert!(sent.contains("\"method\":\"tools/list\""));
    }

    #[test]
    fn e2e_sse_event_with_retry() {
        let mut sse_reader = reader_for(
            b"\
            event: message\n\
            id: 1\n\
            retry: 5000\n\
            data: test\n\
            \n",
        );
        let cx = Cx::for_testing();

        let event = sse_reader.read_event(&cx).unwrap().unwrap();
        assert_eq!(event.retry, Some(5000));
        assert_eq!(event.id, Some("1".to_string()));
    }

    #[test]
    fn e2e_sse_multiple_data_lines() {
        let mut sse_reader = reader_for(
            b"\
            event: message\n\
            data: {\n\
            data: \"jsonrpc\": \"2.0\",\n\
            data: \"result\": {\"key\": \"value\"},\n\
            data: \"id\": 1\n\
            data: }\n\
            \n",
        );
        let cx = Cx::for_testing();

        let event = sse_reader.read_event(&cx).unwrap().unwrap();

        assert!(event.data.contains("\"jsonrpc\""));
        assert!(event.data.contains("\"result\""));

        // The joined lines form a valid JSON document.
        let parsed: serde_json::Value = serde_json::from_str(&event.data).unwrap();
        assert_eq!(parsed.get("id"), Some(&serde_json::json!(1)));
    }

    #[test]
    fn e2e_sse_unicode_in_events() {
        let input = "event: message\ndata: {\"text\":\"Hello 世界 🌍\"}\n\n";
        let mut sse_reader = reader_for(input.as_bytes());
        let cx = Cx::for_testing();

        let event = sse_reader.read_event(&cx).unwrap().unwrap();
        assert!(event.data.contains("世界"));
        assert!(event.data.contains("🌍"));
    }

    #[test]
    fn sse_event_type_as_str_round_trip() {
        for ty in [SseEventType::Endpoint, SseEventType::Message] {
            let parsed = SseEventType::from_str(ty.as_str()).unwrap();
            assert_eq!(parsed, ty);
        }
    }

    #[test]
    fn sse_event_type_from_str_unknown_returns_none() {
        assert!(SseEventType::from_str("ping").is_none());
        assert!(SseEventType::from_str("").is_none());
        assert!(SseEventType::from_str("MESSAGE").is_none());
    }

    #[test]
    fn sse_event_empty_data_serialization() {
        let output = String::from_utf8(SseEvent::message("").to_bytes()).unwrap();

        assert_eq!(output, "event: message\ndata: \n\n");
    }

    #[test]
    fn sse_writer_event_counter_auto_increments() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();

        for _ in 0..3 {
            let msg = JsonRpcMessage::Response(ok_response(
                serde_json::json!(null),
                fastmcp_protocol::RequestId::Number(1),
            ));
            writer.write_message(&cx, &msg).unwrap();
        }

        let output = String::from_utf8(writer.into_inner()).unwrap();
        let events: Vec<&str> = output.split("\n\n").filter(|s| !s.is_empty()).collect();
        assert_eq!(events.len(), 3);
        assert!(events[0].contains("id: 1\n"));
        assert!(events[1].contains("id: 2\n"));
        assert!(events[2].contains("id: 3\n"));
    }

    #[test]
    fn sse_writer_inner_and_inner_mut_accessors() {
        let mut writer = SseWriter::new(Vec::<u8>::new());

        assert!(writer.inner().is_empty());

        writer.inner_mut().extend_from_slice(b"raw");
        assert_eq!(writer.inner().len(), 3);
    }

    #[test]
    fn sse_writer_write_comment_custom_text() {
        let mut writer = SseWriter::new(Vec::new());
        let cx = Cx::for_testing();

        writer.write_comment(&cx, "hello world").unwrap();

        let output = String::from_utf8(writer.into_inner()).unwrap();
        assert_eq!(output, ": hello world\n");
    }

    #[test]
    fn sse_reader_read_endpoint_skips_message_events() {
        let mut sse_reader = reader_for(
            b"event: message\ndata: {\"jsonrpc\":\"2.0\",\"method\":\"ping\"}\n\n\
              event: endpoint\ndata: http://localhost/post\n\n",
        );
        let cx = Cx::for_testing();

        let url = sse_reader.read_endpoint(&cx).unwrap().unwrap();
        assert_eq!(url, "http://localhost/post");
    }

    #[test]
    fn sse_server_transport_close_flushes() {
        let requests: Vec<JsonRpcRequest> = vec![];
        let mut transport =
            SseServerTransport::new(Vec::new(), requests.into_iter(), "http://localhost/post");

        transport.close().unwrap();
        // `close` only flushes; the endpoint event is written lazily by `send`.
        assert!(transport.writer.inner().is_empty());
    }

    #[test]
    fn sse_client_transport_send_cancelled() {
        let reader = Cursor::new(Vec::new());
        let mut request_buffer = Vec::new();

        let mut transport = SseClientTransport::new(reader, &mut request_buffer);
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = transport.send(&cx, &JsonRpcMessage::Request(request));
        assert!(matches!(result, Err(TransportError::Cancelled)));

        drop(transport);
        // Nothing may be written once cancellation was requested.
        assert!(request_buffer.is_empty());
    }
}