just_common/transport/
sse.rs1use std::{
13 fmt,
14 pin::Pin,
15 task::{Context, Poll},
16};
17
18use async_stream::try_stream;
19use futures_core::Stream;
20use futures_util::StreamExt;
21use reqwest::header::CONTENT_TYPE;
22use serde::de::DeserializeOwned;
23
24use crate::error::TransportError;
25
26type BoxedJsonStream<T> = Pin<Box<dyn Stream<Item = Result<T, TransportError>> + Send>>;
27
28pub struct JsonEventStream<T> {
30 inner: BoxedJsonStream<T>,
31}
32
33impl<T> JsonEventStream<T>
34where
35 T: DeserializeOwned + Send + 'static,
36{
37 pub fn from_response(response: reqwest::Response) -> Result<Self, TransportError> {
39 ensure_event_stream(&response)?;
40
41 let stream = try_stream! {
42 let mut bytes = response.bytes_stream();
43 let mut buffer = Vec::new();
44 let mut done = false;
45
46 while let Some(chunk) = bytes.next().await {
47 let chunk = chunk.map_err(TransportError::Transport)?;
48 buffer.extend_from_slice(&chunk);
49
50 while let Some((event_end, consumed)) = split_event(&buffer) {
51 let event = buffer[..event_end].to_vec();
52 buffer.drain(..consumed);
53
54 match parse_event::<T>(&event)? {
55 ParsedEvent::Done => {
56 done = true;
57 break;
58 }
59 ParsedEvent::Skip => {}
60 ParsedEvent::Chunk(chunk) => yield chunk,
61 }
62 }
63
64 if done {
65 break;
66 }
67 }
68
69 if !done && !buffer.iter().all(u8::is_ascii_whitespace) {
70 match parse_event::<T>(&buffer)? {
71 ParsedEvent::Done | ParsedEvent::Skip => {}
72 ParsedEvent::Chunk(chunk) => yield chunk,
73 }
74 }
75 };
76
77 Ok(Self {
78 inner: Box::pin(stream),
79 })
80 }
81}
82
83impl<T> fmt::Debug for JsonEventStream<T> {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 f.debug_struct("JsonEventStream").finish_non_exhaustive()
86 }
87}
88
89impl<T> Stream for JsonEventStream<T> {
90 type Item = Result<T, TransportError>;
91
92 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
93 unsafe { self.map_unchecked_mut(|stream| &mut stream.inner) }.poll_next(cx)
95 }
96}
97
/// Outcome of parsing a single raw SSE event block.
#[derive(Debug)]
enum ParsedEvent<T> {
    /// The joined `data:` payload was exactly `[DONE]`, so reading should stop.
    Done,
    /// Nothing to decode: the event was blank or had no `data:` lines.
    Skip,
    /// The joined `data:` payload, deserialized from JSON.
    Chunk(T),
}
104
105fn ensure_event_stream(response: &reqwest::Response) -> Result<(), TransportError> {
106 let Some(content_type) = response.headers().get(CONTENT_TYPE) else {
107 return Err(TransportError::InvalidResponse(
108 "streaming response was missing content-type".to_owned(),
109 ));
110 };
111
112 let content_type = content_type.to_str().map_err(|_| {
113 TransportError::InvalidResponse(
114 "streaming response content-type was not valid UTF-8".to_owned(),
115 )
116 })?;
117
118 if !content_type.starts_with("text/event-stream") {
119 return Err(TransportError::InvalidResponse(format!(
120 "expected text/event-stream response, got {content_type}"
121 )));
122 }
123
124 Ok(())
125}
126
/// Finds the first SSE event terminator (a blank line) in `buffer`.
///
/// Returns `Some((event_end, consumed))`:
/// - `event_end` is the length of the event text before the terminator.
/// - `consumed` is that length plus the terminator itself (`\r\n\r\n` or `\n\n`).
///
/// At any position, `\r\n\r\n` takes precedence over `\n\n`. Returns `None` while
/// no complete terminator is present.
fn split_event(buffer: &[u8]) -> Option<(usize, usize)> {
    const DELIMITERS: [&[u8]; 2] = [b"\r\n\r\n", b"\n\n"];

    (0..buffer.len()).find_map(|start| {
        let rest = &buffer[start..];
        DELIMITERS
            .iter()
            .copied()
            .find(|delimiter| rest.starts_with(delimiter))
            .map(|delimiter| (start, start + delimiter.len()))
    })
}
144
145fn parse_event<T>(raw_event: &[u8]) -> Result<ParsedEvent<T>, TransportError>
146where
147 T: DeserializeOwned,
148{
149 if raw_event.is_empty() || raw_event.iter().all(u8::is_ascii_whitespace) {
150 return Ok(ParsedEvent::Skip);
151 }
152
153 let event = String::from_utf8(raw_event.to_vec()).map_err(TransportError::Utf8)?;
154 let mut data_lines = Vec::new();
155
156 for line in event.lines() {
157 let line = line.trim_end_matches('\r');
158
159 if line.starts_with(':') {
160 continue;
161 }
162
163 if let Some(data) = line.strip_prefix("data:") {
164 data_lines.push(data.trim_start());
165 }
166 }
167
168 if data_lines.is_empty() {
169 return Ok(ParsedEvent::Skip);
170 }
171
172 let payload = data_lines.join("\n");
173
174 if payload == "[DONE]" {
175 return Ok(ParsedEvent::Done);
176 }
177
178 let chunk = serde_json::from_str(&payload).map_err(|source| TransportError::Deserialize {
179 source,
180 body: payload,
181 })?;
182
183 Ok(ParsedEvent::Chunk(chunk))
184}