dioxus_fullstack/payloads/
stream.rs

1#![allow(clippy::type_complexity)]
2
3use crate::{
4    CborEncoding, ClientRequest, ClientResponse, Encoding, FromResponse, IntoRequest, JsonEncoding,
5    ServerFnError,
6};
7use axum::extract::{FromRequest, Request};
8use axum_core::response::IntoResponse;
9use bytes::{Buf as _, Bytes};
10use dioxus_fullstack_core::{HttpError, RequestError};
11use futures::{Stream, StreamExt};
12#[cfg(feature = "server")]
13use futures_channel::mpsc::UnboundedSender;
14use headers::{ContentType, Header};
15use send_wrapper::SendWrapper;
16use serde::{de::DeserializeOwned, Serialize};
17use std::{future::Future, marker::PhantomData, pin::Pin};
18
19/// A stream of text data.
20///
21/// # Chunking
22///
23/// Note that strings sent by the server might not arrive in the same chunking as they were sent.
24///
25/// This is because the underlying transport layer (HTTP/2 or HTTP/3) may choose to split or combine
26/// chunks for efficiency.
27///
28/// If you need to preserve individual string boundaries, consider using `ChunkedTextStream` or another
29/// encoding that preserves chunk boundaries.
30pub type TextStream = Streaming<String>;
31
32/// A stream of binary data.
33///
34/// # Chunking
35///
36/// Note that bytes sent by the server might not arrive in the same chunking as they were sent.
37/// This is because the underlying transport layer (HTTP/2 or HTTP/3) may choose to split or combine
38/// chunks for efficiency.
39///
40/// If you need to preserve individual byte boundaries, consider using `ChunkedByteStream` or another
41/// encoding that preserves chunk boundaries.
42pub type ByteStream = Streaming<Bytes>;
43
44/// A stream of JSON-encoded data.
45///
46/// # Chunking
47///
48/// Normally, it's not possible to stream JSON over HTTP because browsers are free to re-chunk
49/// data as they see fit. However, this implementation manually frames each JSON as if it were an unmasked
50/// websocket message.
51///
52/// If you need to send a stream of JSON data without framing, consider using TextStream instead and
53/// manually handling JSON buffering.
54pub type JsonStream<T> = Streaming<T, JsonEncoding>;
55
56/// A stream of Cbor-encoded data.
57///
58/// # Chunking
59///
60/// Normally, it's not possible to stream JSON over HTTP because browsers are free to re-chunk
61/// data as they see fit. However, this implementation manually frames each item as if it were an unmasked
62/// websocket message.
63pub type CborStream<T> = Streaming<T, CborEncoding>;
64
65/// A stream of manually chunked binary data.
66///
67/// This encoding preserves chunk boundaries by framing each chunk with its length, using Websocket
68/// Framing.
69pub type ChunkedByteStream = Streaming<Bytes, CborEncoding>;
70
71/// A stream of manually chunked text data.
72///
73/// This encoding preserves chunk boundaries by framing each chunk with its length, using Websocket
74/// Framing.
75pub type ChunkedTextStream = Streaming<String, CborEncoding>;
76
77/// A streaming payload.
78///
79/// ## Frames and Chunking
80///
81/// The streaming payload sends and receives data in discrete chunks or "frames". The size is converted
82/// to hex and sent before each chunk, followed by a CRLF, the chunk data, and another CRLF.
83///
84/// This mimics actual HTTP chunked transfer encoding, but allows us to define our own framing
85/// protocol on top of it.
86///
87/// Arbitrary bytes can be encoded between these frames, but the frames do come with some overhead.
88///
89/// ## Browser Support for Streaming Input
90///
91/// Browser fetch requests do not currently support full request duplexing, which
92/// means that that they do not begin handling responses until the full request has been sent.
93///
94/// This means that if you use a streaming input encoding, the input stream needs to
95/// end before the output will begin.
96///
97/// Streaming requests are only allowed over HTTP2 or HTTP3.
98///
99/// Also note that not all browsers support streaming bodies to servers.
100pub struct Streaming<T = String, E = ()> {
101    stream: Pin<Box<dyn Stream<Item = Result<T, StreamingError>> + Send>>,
102    encoding: PhantomData<E>,
103}
104
105#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq, Hash)]
106pub enum StreamingError {
107    /// The streaming request was interrupted and could not be completed.
108    #[error("The streaming request was interrupted")]
109    Interrupted,
110
111    /// The stream failed to decode a chunk - possibly due to invalid data or version mismatch.
112    #[error("The stream failed to decode a chunk")]
113    Decoding,
114
115    /// The stream failed to connect or encountered an error.
116    #[error("The streaming request failed")]
117    Failed,
118}
119
120impl<T: 'static + Send, E> Streaming<T, E> {
121    /// Creates a new stream from the given stream.
122    pub fn new(value: impl Stream<Item = T> + Send + 'static) -> Self {
123        // Box and pin the incoming stream and store as a trait object
124        Self {
125            stream: Box::pin(value.map(|item| Ok(item)))
126                as Pin<Box<dyn Stream<Item = Result<T, StreamingError>> + Send>>,
127            encoding: PhantomData,
128        }
129    }
130
131    /// Spawns a new task that produces items for the stream.
132    ///
133    /// The callback is provided an `UnboundedSender` that can be used to send items to the stream.
134    #[cfg(feature = "server")]
135    pub fn spawn<F>(callback: impl FnOnce(UnboundedSender<T>) -> F + Send + 'static) -> Self
136    where
137        F: Future<Output = ()> + 'static,
138        T: Send,
139    {
140        let (tx, rx) = futures_channel::mpsc::unbounded();
141
142        crate::spawn_platform(move || callback(tx));
143
144        Self::new(rx)
145    }
146
147    /// Returns the next item in the stream, or `None` if the stream has ended.
148    pub async fn next(&mut self) -> Option<Result<T, StreamingError>> {
149        self.stream.as_mut().next().await
150    }
151
152    /// Consumes the wrapper, returning the inner stream.
153    pub fn into_inner(self) -> impl Stream<Item = Result<T, StreamingError>> + Send {
154        self.stream
155    }
156
157    /// Creates a streaming payload from an existing stream of bytes.
158    ///
159    /// This uses the internal framing mechanism to decode the stream into items of type `T`.
160    fn from_bytes(stream: impl Stream<Item = Result<T, StreamingError>> + Send + 'static) -> Self {
161        Self {
162            stream: Box::pin(stream),
163            encoding: PhantomData,
164        }
165    }
166}
167
168impl<S, U> From<S> for TextStream
169where
170    S: Stream<Item = U> + Send + 'static,
171    U: Into<String>,
172{
173    fn from(value: S) -> Self {
174        Self::new(value.map(|data| data.into()))
175    }
176}
177
178impl<S, E> From<S> for ByteStream
179where
180    S: Stream<Item = Result<Bytes, E>> + Send + 'static,
181{
182    fn from(value: S) -> Self {
183        Self {
184            stream: Box::pin(value.map(|data| data.map_err(|_| StreamingError::Failed))),
185            encoding: PhantomData,
186        }
187    }
188}
189
190impl<T, S, U, E> From<S> for Streaming<T, E>
191where
192    S: Stream<Item = U> + Send + 'static,
193    U: Into<T>,
194    T: 'static + Send,
195    E: Encoding,
196{
197    fn from(value: S) -> Self {
198        Self::from_bytes(value.map(|data| Ok(data.into())))
199    }
200}
201
202impl IntoResponse for Streaming<String> {
203    fn into_response(self) -> axum_core::response::Response {
204        axum::response::Response::builder()
205            .header("Content-Type", "text/plain; charset=utf-8")
206            .body(axum::body::Body::from_stream(self.stream))
207            .unwrap()
208    }
209}
210
211impl IntoResponse for Streaming<Bytes> {
212    fn into_response(self) -> axum_core::response::Response {
213        axum::response::Response::builder()
214            .header("Content-Type", "application/octet-stream")
215            .body(axum::body::Body::from_stream(self.stream))
216            .unwrap()
217    }
218}
219
220impl<T: DeserializeOwned + Serialize + 'static, E: Encoding> IntoResponse for Streaming<T, E> {
221    fn into_response(self) -> axum_core::response::Response {
222        let res = self.stream.map(|r| match r {
223            Ok(res) => match encode_stream_frame::<T, E>(res) {
224                Some(bytes) => Ok(bytes),
225                None => Err(StreamingError::Failed),
226            },
227            Err(_err) => Err(StreamingError::Failed),
228        });
229
230        axum::response::Response::builder()
231            .header("Content-Type", E::stream_content_type())
232            .body(axum::body::Body::from_stream(res))
233            .unwrap()
234    }
235}
236
237impl FromResponse for Streaming<String> {
238    fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
239        SendWrapper::new(async move {
240            let client_stream = Box::pin(res.bytes_stream().map(|byte| match byte {
241                Ok(bytes) => match String::from_utf8(bytes.to_vec()) {
242                    Ok(string) => Ok(string),
243                    Err(_) => Err(StreamingError::Decoding),
244                },
245                Err(_) => Err(StreamingError::Failed),
246            }));
247
248            Ok(Self {
249                stream: client_stream,
250                encoding: PhantomData,
251            })
252        })
253    }
254}
255
256impl FromResponse for Streaming<Bytes> {
257    fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
258        async move {
259            let client_stream = Box::pin(SendWrapper::new(res.bytes_stream().map(
260                |byte| match byte {
261                    Ok(bytes) => Ok(bytes),
262                    Err(_) => Err(StreamingError::Failed),
263                },
264            )));
265
266            Ok(Self {
267                stream: client_stream,
268                encoding: PhantomData,
269            })
270        }
271    }
272}
273
274impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding> FromResponse
275    for Streaming<T, E>
276{
277    fn from_response(res: ClientResponse) -> impl Future<Output = Result<Self, ServerFnError>> {
278        SendWrapper::new(async move {
279            Ok(Self {
280                stream: byte_stream_to_client_stream::<E, _, _, _>(res.bytes_stream()),
281                encoding: PhantomData,
282            })
283        })
284    }
285}
286
287impl<S> FromRequest<S> for Streaming<String> {
288    type Rejection = ServerFnError;
289
290    fn from_request(
291        req: Request,
292        _state: &S,
293    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
294        async move {
295            let (parts, body) = req.into_parts();
296            let content_type = parts
297                .headers
298                .get("content-type")
299                .and_then(|v| v.to_str().ok())
300                .unwrap_or("");
301
302            if !content_type.starts_with("text/plain") {
303                HttpError::bad_request("Invalid content type")?;
304            }
305
306            let stream = body.into_data_stream();
307
308            Ok(Self {
309                stream: Box::pin(stream.map(|byte| match byte {
310                    Ok(bytes) => match String::from_utf8(bytes.to_vec()) {
311                        Ok(string) => Ok(string),
312                        Err(_) => Err(StreamingError::Decoding),
313                    },
314                    Err(_) => Err(StreamingError::Failed),
315                })),
316                encoding: PhantomData,
317            })
318        }
319    }
320}
321
322impl<S> FromRequest<S> for ByteStream {
323    type Rejection = ServerFnError;
324
325    fn from_request(
326        req: Request,
327        _state: &S,
328    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
329        async move {
330            let (parts, body) = req.into_parts();
331            let content_type = parts
332                .headers
333                .get("content-type")
334                .and_then(|v| v.to_str().ok())
335                .unwrap_or("");
336
337            if !content_type.starts_with("application/octet-stream") {
338                HttpError::bad_request("Invalid content type")?;
339            }
340
341            let stream = body.into_data_stream();
342
343            Ok(Self {
344                stream: Box::pin(stream.map(|byte| match byte {
345                    Ok(bytes) => Ok(bytes),
346                    Err(_) => Err(StreamingError::Failed),
347                })),
348                encoding: PhantomData,
349            })
350        }
351    }
352}
353
354impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding, S> FromRequest<S>
355    for Streaming<T, E>
356{
357    type Rejection = ServerFnError;
358
359    fn from_request(
360        req: Request,
361        _state: &S,
362    ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
363        async move {
364            let (parts, body) = req.into_parts();
365            let content_type = parts
366                .headers
367                .get("content-type")
368                .and_then(|v| v.to_str().ok())
369                .unwrap_or("");
370
371            if !content_type.starts_with(E::stream_content_type()) {
372                HttpError::bad_request("Invalid content type")?;
373            }
374
375            let stream = body.into_data_stream();
376
377            Ok(Self {
378                stream: byte_stream_to_client_stream::<E, _, _, _>(stream),
379                encoding: PhantomData,
380            })
381        }
382    }
383}
384
385impl IntoRequest for Streaming<String> {
386    fn into_request(
387        self,
388        builder: ClientRequest,
389    ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
390        async move {
391            builder
392                .header("Content-Type", "text/plain; charset=utf-8")?
393                .send_body_stream(self.stream.map(|e| e.map(Bytes::from)))
394                .await
395        }
396    }
397}
398
399impl IntoRequest for ByteStream {
400    fn into_request(
401        self,
402        builder: ClientRequest,
403    ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
404        async move {
405            builder
406                .header(ContentType::name(), "application/octet-stream")?
407                .send_body_stream(self.stream)
408                .await
409        }
410    }
411}
412
413impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding> IntoRequest
414    for Streaming<T, E>
415{
416    fn into_request(
417        self,
418        builder: ClientRequest,
419    ) -> impl Future<Output = Result<ClientResponse, RequestError>> + 'static {
420        async move {
421            builder
422                .header("Content-Type", E::stream_content_type())?
423                .send_body_stream(self.stream.map(|r| {
424                    r.and_then(|item| {
425                        encode_stream_frame::<T, E>(item).ok_or(StreamingError::Failed)
426                    })
427                }))
428                .await
429        }
430    }
431}
432
433impl<T> std::fmt::Debug for Streaming<T> {
434    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435        f.debug_tuple("Streaming").finish()
436    }
437}
438
439impl<T, E: Encoding> std::fmt::Debug for Streaming<T, E> {
440    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
441        f.debug_struct("Streaming")
442            .field("encoding", &std::any::type_name::<E>())
443            .finish()
444    }
445}
446
447/// This function encodes a single frame of a streaming payload using the specified encoding.
448///
449/// The resulting `Bytes` object is encoded as a websocket frame, so you can send it over a streaming
450/// HTTP response or even a websocket connection.
451///
452/// Note that the packet is not masked, as it is assumed to be sent over a trusted connection.
453pub fn encode_stream_frame<T: Serialize, E: Encoding>(data: T) -> Option<Bytes> {
454    // We use full advantage of `BytesMut` here, writing a maximally full frame and then shrinking it
455    // down to size at the end.
456    //
457    // Also note we don't do any masking over this data since it's not going over an untrusted
458    // network like a websocket would.
459    //
460    // We allocate 10 extra bytes to account for framing overhead, which we'll shrink after
461    let mut bytes = vec![0u8; 10];
462
463    E::encode(data, &mut bytes)?;
464
465    let len = (bytes.len() - 10) as u64;
466    let opcode = 0x82; // FIN + binary opcode
467
468    // Write the header directly into the allocated space.
469    let offset = if len <= 125 {
470        bytes[8] = opcode;
471        bytes[9] = len as u8;
472        8
473    } else if len <= u16::MAX as u64 {
474        bytes[6] = opcode;
475        bytes[7] = 126;
476        let len_bytes = (len as u16).to_be_bytes();
477        bytes[8] = len_bytes[0];
478        bytes[9] = len_bytes[1];
479        6
480    } else {
481        bytes[0] = opcode;
482        bytes[1] = 127;
483        bytes[2..10].copy_from_slice(&len.to_be_bytes());
484        0
485    };
486
487    // Shrink down to the actual used size - is zero copy!
488    Some(Bytes::from(bytes).slice(offset..))
489}
490
491fn byte_stream_to_client_stream<E, T, S, E1>(
492    stream: S,
493) -> Pin<Box<dyn Stream<Item = Result<T, StreamingError>> + Send>>
494where
495    S: Stream<Item = Result<Bytes, E1>> + 'static + Send,
496    E: Encoding,
497    T: DeserializeOwned + 'static,
498{
499    Box::pin(stream.flat_map(|bytes| {
500        enum DecodeIteratorState {
501            Empty,
502            Failed,
503            Checked(Bytes),
504            UnChecked(Bytes),
505        }
506
507        let mut state = match bytes {
508            Ok(bytes) => DecodeIteratorState::UnChecked(bytes),
509            Err(_) => DecodeIteratorState::Failed,
510        };
511
512        futures::stream::iter(std::iter::from_fn(move || {
513            match std::mem::replace(&mut state, DecodeIteratorState::Empty) {
514                DecodeIteratorState::Empty => None,
515                DecodeIteratorState::Failed => Some(Err(StreamingError::Failed)),
516                DecodeIteratorState::Checked(mut bytes) => {
517                    let r = decode_stream_frame_multi::<T, E>(&mut bytes);
518                    if r.is_some() {
519                        state = DecodeIteratorState::Checked(bytes)
520                    }
521                    r
522                }
523                DecodeIteratorState::UnChecked(mut bytes) => {
524                    let r = decode_stream_frame_multi::<T, E>(&mut bytes);
525                    if r.is_some() {
526                        state = DecodeIteratorState::Checked(bytes);
527                        r
528                    } else {
529                        Some(Err(StreamingError::Decoding))
530                    }
531                }
532            }
533        }))
534    }))
535}
536
537/// Decode a websocket-framed streaming payload produced by [`encode_stream_frame`].
538///
539/// This function returns `None` if the frame is invalid or cannot be decoded.
540///
541/// It cannot handle masked frames, as those are not produced by our encoding function.
542pub fn decode_stream_frame<T, E>(mut frame: Bytes) -> Option<T>
543where
544    E: Encoding,
545    T: DeserializeOwned,
546{
547    decode_stream_frame_multi::<T, E>(&mut frame).and_then(|r| r.ok())
548}
549
550/// Decode one value and advance the bytes pointer
551///
552/// If the frame is empty return None.
553///
554/// Otherwise, if the initial opcode is not the one expected for binary stream
555/// or the frame is not large enough return error StreamingError::Decoding
556fn decode_stream_frame_multi<T, E>(frame: &mut Bytes) -> Option<Result<T, StreamingError>>
557where
558    E: Encoding,
559    T: DeserializeOwned,
560{
561    let (offset, payload_len) = match offset_payload_len(frame)? {
562        Ok(r) => r,
563        Err(e) => return Some(Err(e)),
564    };
565
566    let r = E::decode(frame.slice(offset..offset + payload_len));
567    frame.advance(offset + payload_len);
568    r.map(|r| Ok(r))
569}
570
571/// Compute (offset,len) for decoding data
572fn offset_payload_len(frame: &Bytes) -> Option<Result<(usize, usize), StreamingError>> {
573    let data = frame.as_ref();
574
575    if data.is_empty() {
576        return None;
577    }
578
579    if data.len() < 2 {
580        return Some(Err(StreamingError::Decoding));
581    }
582
583    let first = data[0];
584    let second = data[1];
585
586    // Require FIN with binary opcode and no RSV bits
587    let fin = first & 0x80 != 0;
588    let opcode = first & 0x0F;
589    let rsv = first & 0x70;
590    if !fin || opcode != 0x02 || rsv != 0 {
591        return Some(Err(StreamingError::Decoding));
592    }
593
594    // Mask bit must be zero for our framing
595    if second & 0x80 != 0 {
596        return Some(Err(StreamingError::Decoding));
597    }
598
599    let mut offset = 2usize;
600    let mut payload_len = (second & 0x7F) as usize;
601
602    if payload_len == 126 {
603        if data.len() < offset + 2 {
604            return Some(Err(StreamingError::Decoding));
605        }
606
607        payload_len = u16::from_be_bytes([data[offset], data[offset + 1]]) as usize;
608        offset += 2;
609    } else if payload_len == 127 {
610        if data.len() < offset + 8 {
611            return Some(Err(StreamingError::Decoding));
612        }
613
614        let mut len_bytes = [0u8; 8];
615        len_bytes.copy_from_slice(&data[offset..offset + 8]);
616        let len_u64 = u64::from_be_bytes(len_bytes);
617
618        if len_u64 > usize::MAX as u64 {
619            return Some(Err(StreamingError::Decoding));
620        }
621
622        payload_len = len_u64 as usize;
623        offset += 8;
624    }
625
626    if data.len() < offset + payload_len {
627        return Some(Err(StreamingError::Decoding));
628    }
629    Some(Ok((offset, payload_len)))
630}