1#![allow(clippy::type_complexity)]
2
3use crate::{
4 CborEncoding, ClientRequest, ClientResponse, Encoding, FromResponse, IntoRequest, JsonEncoding,
5 ServerFnError,
6};
7use axum::extract::{FromRequest, Request};
8use axum_core::response::IntoResponse;
9use bytes::{Buf as _, Bytes};
10use dioxus_fullstack_core::{HttpError, RequestError};
11use futures::{Stream, StreamExt};
12#[cfg(feature = "server")]
13use futures_channel::mpsc::UnboundedSender;
14use headers::{ContentType, Header};
15use send_wrapper::SendWrapper;
16use serde::{de::DeserializeOwned, Serialize};
17use std::{future::Future, marker::PhantomData, pin::Pin};
18
/// A stream of text chunks.
///
/// Sent with `Content-Type: text/plain; charset=utf-8`; when received as a request body
/// the `content-type` must start with `text/plain`.
pub type TextStream = Streaming<String>;

/// A stream of raw byte chunks, sent with `Content-Type: application/octet-stream`.
pub type ByteStream = Streaming<Bytes>;

/// A stream of `T` values, each serialized with [`JsonEncoding`] into its own
/// length-prefixed frame.
pub type JsonStream<T> = Streaming<T, JsonEncoding>;

/// A stream of `T` values, each serialized with [`CborEncoding`] into its own
/// length-prefixed frame.
pub type CborStream<T> = Streaming<T, CborEncoding>;

/// A stream of byte buffers sent through the framed [`CborEncoding`] transport
/// (one frame per buffer) instead of as a raw `application/octet-stream` body.
pub type ChunkedByteStream = Streaming<Bytes, CborEncoding>;

/// A stream of strings sent through the framed [`CborEncoding`] transport
/// (one frame per string) instead of as a raw `text/plain` body.
pub type ChunkedTextStream = Streaming<String, CborEncoding>;
76
/// A stream of `T` items that can be sent or received as an HTTP request or response body.
///
/// `E` selects the wire format. With the default `()`, the `String` and `Bytes` impls in
/// this file send the items as a raw body. With an [`Encoding`] type, each item is
/// serialized into its own length-prefixed frame.
pub struct Streaming<T = String, E = ()> {
    // The items, boxed and type-erased so every constructor produces the same type.
    stream: Pin<Box<dyn Stream<Item = Result<T, StreamingError>> + Send>>,
    // Zero-sized marker for the wire format; no `E` value is stored.
    encoding: PhantomData<E>,
}
104
/// Error yielded as an item of a [`Streaming`] stream.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq, Hash)]
pub enum StreamingError {
    /// The streaming request was interrupted.
    #[error("The streaming request was interrupted")]
    Interrupted,

    /// A received chunk or frame could not be decoded, e.g. invalid UTF-8 in a text
    /// stream or a malformed/undecodable frame in an encoded stream.
    #[error("The stream failed to decode a chunk")]
    Decoding,

    /// The underlying body reported an error, or an item could not be encoded.
    #[error("The streaming request failed")]
    Failed,
}
119
120impl<T: 'static + Send, E> Streaming<T, E> {
121 pub fn new(value: impl Stream<Item = T> + Send + 'static) -> Self {
123 Self {
125 stream: Box::pin(value.map(|item| Ok(item)))
126 as Pin<Box<dyn Stream<Item = Result<T, StreamingError>> + Send>>,
127 encoding: PhantomData,
128 }
129 }
130
131 #[cfg(feature = "server")]
135 pub fn spawn<F>(callback: impl FnOnce(UnboundedSender<T>) -> F + Send + 'static) -> Self
136 where
137 F: Future<Output = ()> + 'static,
138 T: Send,
139 {
140 let (tx, rx) = futures_channel::mpsc::unbounded();
141
142 crate::spawn_platform(move || callback(tx));
143
144 Self::new(rx)
145 }
146
147 pub async fn next(&mut self) -> Option<Result<T, StreamingError>> {
149 self.stream.as_mut().next().await
150 }
151
152 pub fn into_inner(self) -> impl Stream<Item = Result<T, StreamingError>> + Send {
154 self.stream
155 }
156
157 fn from_bytes(stream: impl Stream<Item = Result<T, StreamingError>> + Send + 'static) -> Self {
161 Self {
162 stream: Box::pin(stream),
163 encoding: PhantomData,
164 }
165 }
166}
167
168impl<S, U> From<S> for TextStream
169where
170 S: Stream<Item = U> + Send + 'static,
171 U: Into<String>,
172{
173 fn from(value: S) -> Self {
174 Self::new(value.map(|data| data.into()))
175 }
176}
177
178impl<S, E> From<S> for ByteStream
179where
180 S: Stream<Item = Result<Bytes, E>> + Send + 'static,
181{
182 fn from(value: S) -> Self {
183 Self {
184 stream: Box::pin(value.map(|data| data.map_err(|_| StreamingError::Failed))),
185 encoding: PhantomData,
186 }
187 }
188}
189
190impl<T, S, U, E> From<S> for Streaming<T, E>
191where
192 S: Stream<Item = U> + Send + 'static,
193 U: Into<T>,
194 T: 'static + Send,
195 E: Encoding,
196{
197 fn from(value: S) -> Self {
198 Self::from_bytes(value.map(|data| Ok(data.into())))
199 }
200}
201
202impl IntoResponse for Streaming<String> {
203 fn into_response(self) -> axum_core::response::Response {
204 axum::response::Response::builder()
205 .header("Content-Type", "text/plain; charset=utf-8")
206 .body(axum::body::Body::from_stream(self.stream))
207 .unwrap()
208 }
209}
210
211impl IntoResponse for Streaming<Bytes> {
212 fn into_response(self) -> axum_core::response::Response {
213 axum::response::Response::builder()
214 .header("Content-Type", "application/octet-stream")
215 .body(axum::body::Body::from_stream(self.stream))
216 .unwrap()
217 }
218}
219
220impl<T: DeserializeOwned + Serialize + 'static, E: Encoding> IntoResponse for Streaming<T, E> {
221 fn into_response(self) -> axum_core::response::Response {
222 let res = self.stream.map(|r| match r {
223 Ok(res) => match encode_stream_frame::<T, E>(res) {
224 Some(bytes) => Ok(bytes),
225 None => Err(StreamingError::Failed),
226 },
227 Err(_err) => Err(StreamingError::Failed),
228 });
229
230 axum::response::Response::builder()
231 .header("Content-Type", E::stream_content_type())
232 .body(axum::body::Body::from_stream(res))
233 .unwrap()
234 }
235}
236
237impl FromResponse for Streaming<String> {
238 fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
239 SendWrapper::new(async move {
240 let client_stream = Box::pin(res.bytes_stream().map(|byte| match byte {
241 Ok(bytes) => match String::from_utf8(bytes.to_vec()) {
242 Ok(string) => Ok(string),
243 Err(_) => Err(StreamingError::Decoding),
244 },
245 Err(_) => Err(StreamingError::Failed),
246 }));
247
248 Ok(Self {
249 stream: client_stream,
250 encoding: PhantomData,
251 })
252 })
253 }
254}
255
256impl FromResponse for Streaming<Bytes> {
257 fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
258 async move {
259 let client_stream = Box::pin(SendWrapper::new(res.bytes_stream().map(
260 |byte| match byte {
261 Ok(bytes) => Ok(bytes),
262 Err(_) => Err(StreamingError::Failed),
263 },
264 )));
265
266 Ok(Self {
267 stream: client_stream,
268 encoding: PhantomData,
269 })
270 }
271 }
272}
273
274impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding> FromResponse
275 for Streaming<T, E>
276{
277 fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
278 SendWrapper::new(async move {
279 Ok(Self {
280 stream: byte_stream_to_client_stream::<E, _, _, _>(res.bytes_stream()),
281 encoding: PhantomData,
282 })
283 })
284 }
285}
286
287impl<S> FromRequest<S> for Streaming<String> {
288 type Rejection = ServerFnError;
289
290 fn from_request(
291 req: Request,
292 _state: &S,
293 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
294 async move {
295 let (parts, body) = req.into_parts();
296 let content_type = parts
297 .headers
298 .get("content-type")
299 .and_then(|v| v.to_str().ok())
300 .unwrap_or("");
301
302 if !content_type.starts_with("text/plain") {
303 HttpError::bad_request("Invalid content type")?;
304 }
305
306 let stream = body.into_data_stream();
307
308 Ok(Self {
309 stream: Box::pin(stream.map(|byte| match byte {
310 Ok(bytes) => match String::from_utf8(bytes.to_vec()) {
311 Ok(string) => Ok(string),
312 Err(_) => Err(StreamingError::Decoding),
313 },
314 Err(_) => Err(StreamingError::Failed),
315 })),
316 encoding: PhantomData,
317 })
318 }
319 }
320}
321
322impl<S> FromRequest<S> for ByteStream {
323 type Rejection = ServerFnError;
324
325 fn from_request(
326 req: Request,
327 _state: &S,
328 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
329 async move {
330 let (parts, body) = req.into_parts();
331 let content_type = parts
332 .headers
333 .get("content-type")
334 .and_then(|v| v.to_str().ok())
335 .unwrap_or("");
336
337 if !content_type.starts_with("application/octet-stream") {
338 HttpError::bad_request("Invalid content type")?;
339 }
340
341 let stream = body.into_data_stream();
342
343 Ok(Self {
344 stream: Box::pin(stream.map(|byte| match byte {
345 Ok(bytes) => Ok(bytes),
346 Err(_) => Err(StreamingError::Failed),
347 })),
348 encoding: PhantomData,
349 })
350 }
351 }
352}
353
354impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding, S> FromRequest<S>
355 for Streaming<T, E>
356{
357 type Rejection = ServerFnError;
358
359 fn from_request(
360 req: Request,
361 _state: &S,
362 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
363 async move {
364 let (parts, body) = req.into_parts();
365 let content_type = parts
366 .headers
367 .get("content-type")
368 .and_then(|v| v.to_str().ok())
369 .unwrap_or("");
370
371 if !content_type.starts_with(E::stream_content_type()) {
372 HttpError::bad_request("Invalid content type")?;
373 }
374
375 let stream = body.into_data_stream();
376
377 Ok(Self {
378 stream: byte_stream_to_client_stream::<E, _, _, _>(stream),
379 encoding: PhantomData,
380 })
381 }
382 }
383}
384
385impl IntoRequest for Streaming<String> {
386 fn into_request(
387 self,
388 builder: ClientRequest,
389 ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
390 async move {
391 builder
392 .header("Content-Type", "text/plain; charset=utf-8")?
393 .send_body_stream(self.stream.map(|e| e.map(Bytes::from)))
394 .await
395 }
396 }
397}
398
399impl IntoRequest for ByteStream {
400 fn into_request(
401 self,
402 builder: ClientRequest,
403 ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
404 async move {
405 builder
406 .header(ContentType::name(), "application/octet-stream")?
407 .send_body_stream(self.stream)
408 .await
409 }
410 }
411}
412
413impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding> IntoRequest
414 for Streaming<T, E>
415{
416 fn into_request(
417 self,
418 builder: ClientRequest,
419 ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
420 async move {
421 builder
422 .header("Content-Type", E::stream_content_type())?
423 .send_body_stream(self.stream.map(|r| {
424 r.and_then(|item| {
425 encode_stream_frame::<T, E>(item).ok_or(StreamingError::Failed)
426 })
427 }))
428 .await
429 }
430 }
431}
432
433impl<T> std::fmt::Debug for Streaming<T> {
434 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435 f.debug_tuple("Streaming").finish()
436 }
437}
438
439impl<T, E: Encoding> std::fmt::Debug for Streaming<T, E> {
440 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441 f.debug_struct("Streaming")
442 .field("encoding", &std::any::type_name::<E>())
443 .finish()
444 }
445}
446
447pub fn encode_stream_frame<T: Serialize, E: Encoding>(data: T) -> Option<Bytes> {
454 let mut bytes = vec![0u8; 10];
462
463 E::encode(data, &mut bytes)?;
464
465 let len = (bytes.len() - 10) as u64;
466 let opcode = 0x82; let offset = if len <= 125 {
470 bytes[8] = opcode;
471 bytes[9] = len as u8;
472 8
473 } else if len <= u16::MAX as u64 {
474 bytes[6] = opcode;
475 bytes[7] = 126;
476 let len_bytes = (len as u16).to_be_bytes();
477 bytes[8] = len_bytes[0];
478 bytes[9] = len_bytes[1];
479 6
480 } else {
481 bytes[0] = opcode;
482 bytes[1] = 127;
483 bytes[2..10].copy_from_slice(&len.to_be_bytes());
484 0
485 };
486
487 Some(Bytes::from(bytes).slice(offset..))
489}
490
491fn byte_stream_to_client_stream<E, T, S, E1>(
492 stream: S,
493) -> Pin<Box<dyn Stream<Item = Result<T, StreamingError>> + Send>>
494where
495 S: Stream<Item = Result<Bytes, E1>> + 'static + Send,
496 E: Encoding,
497 T: DeserializeOwned + 'static,
498{
499 Box::pin(stream.flat_map(|bytes| {
500 enum DecodeIteratorState {
501 Empty,
502 Failed,
503 Checked(Bytes),
504 UnChecked(Bytes),
505 }
506
507 let mut state = match bytes {
508 Ok(bytes) => DecodeIteratorState::UnChecked(bytes),
509 Err(_) => DecodeIteratorState::Failed,
510 };
511
512 futures::stream::iter(std::iter::from_fn(move || {
513 match std::mem::replace(&mut state, DecodeIteratorState::Empty) {
514 DecodeIteratorState::Empty => None,
515 DecodeIteratorState::Failed => Some(Err(StreamingError::Failed)),
516 DecodeIteratorState::Checked(mut bytes) => {
517 let r = decode_stream_frame_multi::<T, E>(&mut bytes);
518 if r.is_some() {
519 state = DecodeIteratorState::Checked(bytes)
520 }
521 r
522 }
523 DecodeIteratorState::UnChecked(mut bytes) => {
524 let r = decode_stream_frame_multi::<T, E>(&mut bytes);
525 if r.is_some() {
526 state = DecodeIteratorState::Checked(bytes);
527 r
528 } else {
529 Some(Err(StreamingError::Decoding))
530 }
531 }
532 }
533 }))
534 }))
535}
536
537pub fn decode_stream_frame<T, E>(mut frame: Bytes) -> Option<T>
543where
544 E: Encoding,
545 T: DeserializeOwned,
546{
547 decode_stream_frame_multi::<T, E>(&mut frame).and_then(|r| r.ok())
548}
549
550fn decode_stream_frame_multi<T, E>(frame: &mut Bytes) -> Option<Result<T, StreamingError>>
557where
558 E: Encoding,
559 T: DeserializeOwned,
560{
561 let (offset, payload_len) = match offset_payload_len(frame)? {
562 Ok(r) => r,
563 Err(e) => return Some(Err(e)),
564 };
565
566 let r = E::decode(frame.slice(offset..offset + payload_len));
567 frame.advance(offset + payload_len);
568 r.map(|r| Ok(r))
569}
570
571fn offset_payload_len(frame: &Bytes) -> Option<Result<(usize, usize), StreamingError>> {
573 let data = frame.as_ref();
574
575 if data.is_empty() {
576 return None;
577 }
578
579 if data.len() < 2 {
580 return Some(Err(StreamingError::Decoding));
581 }
582
583 let first = data[0];
584 let second = data[1];
585
586 let fin = first & 0x80 != 0;
588 let opcode = first & 0x0F;
589 let rsv = first & 0x70;
590 if !fin || opcode != 0x02 || rsv != 0 {
591 return Some(Err(StreamingError::Decoding));
592 }
593
594 if second & 0x80 != 0 {
596 return Some(Err(StreamingError::Decoding));
597 }
598
599 let mut offset = 2usize;
600 let mut payload_len = (second & 0x7F) as usize;
601
602 if payload_len == 126 {
603 if data.len() < offset + 2 {
604 return Some(Err(StreamingError::Decoding));
605 }
606
607 payload_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
608 offset += 2;
609 } else if payload_len == 127 {
610 if data.len() < offset + 8 {
611 return Some(Err(StreamingError::Decoding));
612 }
613
614 let mut len_bytes = [0u8; 8];
615 len_bytes.copy_from_slice(&data[offset..offset + 8]);
616 let len_u64 = u64::from_be_bytes(len_bytes);
617
618 if len_u64 > usize::MAX as u64 {
619 return Some(Err(StreamingError::Decoding));
620 }
621
622 payload_len = len_u64 as usize;
623 offset += 8;
624 }
625
626 if data.len() < offset + payload_len {
627 return Some(Err(StreamingError::Decoding));
628 }
629 Some(Ok((offset, payload_len)))
630}