1use crate::error::{Error, Result};
7use bytes::Bytes;
8use futures::stream::Stream;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::io::AsyncRead;
12
13pub struct StreamingBody {
15 inner: Pin<Box<dyn Stream<Item = Result<Bytes>> + Send>>,
16}
17
18impl StreamingBody {
19 pub fn new<S>(stream: S) -> Self
21 where
22 S: Stream<Item = Result<Bytes>> + Send + 'static,
23 {
24 Self {
25 inner: Box::pin(stream),
26 }
27 }
28
29 pub fn from_reader<R>(reader: R, chunk_size: usize) -> Self
31 where
32 R: AsyncRead + Send + Unpin + 'static,
33 {
34 let stream = ReaderStream::new(reader, chunk_size);
35 Self::new(stream)
36 }
37
38 pub async fn next_chunk(&mut self) -> Option<Result<Bytes>> {
40 use futures::StreamExt;
41 self.inner.next().await
42 }
43
44 pub async fn collect(mut self) -> Result<Bytes> {
46 let mut chunks = Vec::new();
47 while let Some(chunk) = self.next_chunk().await {
48 chunks.push(chunk?);
49 }
50
51 let total_len: usize = chunks.iter().map(|c| c.len()).sum();
52 let mut result = Vec::with_capacity(total_len);
53 for chunk in chunks {
54 result.extend_from_slice(&chunk);
55 }
56
57 Ok(Bytes::from(result))
58 }
59}
60
/// Adapts an `AsyncRead` into a stream of `Bytes` chunks.
struct ReaderStream<R> {
    /// The source reader; `None` once it has hit end-of-file or failed.
    reader: Option<R>,
    /// Requested number of bytes per read.
    chunk_size: usize,
}
66
67impl<R> ReaderStream<R> {
68 fn new(reader: R, chunk_size: usize) -> Self {
69 Self {
70 reader: Some(reader),
71 chunk_size,
72 }
73 }
74}
75
76impl<R: AsyncRead + Unpin + Send> Stream for ReaderStream<R> {
77 type Item = Result<Bytes>;
78
79 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
80 let chunk_size = self.chunk_size;
81
82 if let Some(reader) = &mut self.reader {
83 let mut buf = vec![0u8; chunk_size];
84 let mut read_buf = tokio::io::ReadBuf::new(&mut buf);
85
86 match Pin::new(reader).poll_read(cx, &mut read_buf) {
87 Poll::Ready(Ok(())) => {
88 let filled = read_buf.filled().len();
89 if filled == 0 {
90 self.reader = None;
91 Poll::Ready(None)
92 } else {
93 buf.truncate(filled);
94 Poll::Ready(Some(Ok(Bytes::from(buf))))
95 }
96 }
97 Poll::Ready(Err(e)) => {
98 self.reader = None;
99 Poll::Ready(Some(Err(Error::BodyReadError { source: e })))
100 }
101 Poll::Pending => Poll::Pending,
102 }
103 } else {
104 Poll::Ready(None)
105 }
106 }
107}
108
109pub struct ChunkedEncoder {
111 inner: StreamingBody,
112}
113
114impl ChunkedEncoder {
115 pub fn new(body: StreamingBody) -> Self {
117 Self { inner: body }
118 }
119
120 pub async fn next_encoded_chunk(&mut self) -> Option<Result<Bytes>> {
122 match self.inner.next_chunk().await {
123 Some(Ok(chunk)) => {
124 if chunk.is_empty() {
125 Some(Ok(Bytes::from("0\r\n\r\n")))
127 } else {
128 let size_hex = format!("{:x}\r\n", chunk.len());
130 let mut encoded = Vec::with_capacity(size_hex.len() + chunk.len() + 2);
131 encoded.extend_from_slice(size_hex.as_bytes());
132 encoded.extend_from_slice(&chunk);
133 encoded.extend_from_slice(b"\r\n");
134 Some(Ok(Bytes::from(encoded)))
135 }
136 }
137 Some(Err(e)) => Some(Err(e)),
138 None => Some(Ok(Bytes::from("0\r\n\r\n"))), }
140 }
141}
142
143pub struct SseStream {
145 inner: StreamingBody,
146}
147
148impl SseStream {
149 pub fn new(body: StreamingBody) -> Self {
151 Self { inner: body }
152 }
153
154 pub fn event(event_type: &str, data: &str) -> Bytes {
156 let mut event = String::new();
157 if !event_type.is_empty() {
158 event.push_str(&format!("event: {}\n", event_type));
159 }
160 for line in data.lines() {
161 event.push_str(&format!("data: {}\n", line));
162 }
163 event.push('\n');
164 Bytes::from(event)
165 }
166
167 pub fn comment(text: &str) -> Bytes {
169 Bytes::from(format!(": {}\n\n", text))
170 }
171
172 pub fn retry(milliseconds: u64) -> Bytes {
174 Bytes::from(format!("retry: {}\n\n", milliseconds))
175 }
176
177 pub async fn next_message(&mut self) -> Option<Result<Bytes>> {
179 self.inner.next_chunk().await
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Builds a body whose stream yields each of `parts` as one successful chunk.
    fn body_of(parts: &[&'static str]) -> StreamingBody {
        let chunks: Vec<_> = parts.iter().map(|part| Ok(Bytes::from(*part))).collect();
        StreamingBody::new(stream::iter(chunks))
    }

    #[tokio::test]
    async fn test_streaming_body_from_vec() {
        let mut body = body_of(&["Hello, ", "World!"]);

        for expected in &["Hello, ", "World!"] {
            let chunk = body
                .next_chunk()
                .await
                .expect("stream ended early")
                .expect("chunk was an error");
            assert_eq!(chunk, Bytes::from(*expected));
        }

        assert!(body.next_chunk().await.is_none());
    }

    #[tokio::test]
    async fn test_streaming_body_collect() {
        let collected = body_of(&["Hello, ", "World!"])
            .collect()
            .await
            .expect("collect failed");
        assert_eq!(collected, Bytes::from("Hello, World!"));
    }

    #[test]
    fn test_sse_event() {
        assert_eq!(
            SseStream::event("message", "Hello, World!"),
            Bytes::from("event: message\ndata: Hello, World!\n\n")
        );
    }

    #[test]
    fn test_sse_event_multiline() {
        assert_eq!(
            SseStream::event("update", "Line 1\nLine 2\nLine 3"),
            Bytes::from("event: update\ndata: Line 1\ndata: Line 2\ndata: Line 3\n\n")
        );
    }

    #[test]
    fn test_sse_comment() {
        assert_eq!(
            SseStream::comment("Keep alive"),
            Bytes::from(": Keep alive\n\n")
        );
    }

    #[test]
    fn test_sse_retry() {
        assert_eq!(SseStream::retry(3000), Bytes::from("retry: 3000\n\n"));
    }
}