Skip to main content

actix_http/h1/
payload.rs

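//! Payload stream: a buffered stream of byte chunks with a receiving half ([`Payload`]) and a sending half ([`PayloadSender`]).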
2
3use std::{
4    cell::RefCell,
5    collections::VecDeque,
6    pin::Pin,
7    rc::{Rc, Weak},
8    task::{Context, Poll, Waker},
9};
10
11use bytes::Bytes;
12use futures_core::Stream;
13
14use crate::error::PayloadError;
15
/// Number of buffered bytes (32 KiB) at which the receiver stops asking the sender for more data.
pub(crate) const MAX_BUFFER_SIZE: usize = 32 * 1024;
18
/// Answer returned by `PayloadSender::need_read`, telling the sender whether to keep feeding data.
#[derive(Debug, PartialEq, Eq)]
pub enum PayloadStatus {
    /// The receiver is asking for more data.
    Read,
    /// The receiver's buffer is full; the caller's waker was registered so it can resume later.
    Pause,
    /// The receiving `Payload` no longer exists.
    Dropped,
}
25
26/// Buffered stream of bytes chunks
27///
28/// Payload stores chunks in a vector. First chunk can be received with `poll_next`. Payload does
29/// not notify current task when new data is available.
30///
31/// Payload can be used as `Response` body stream.
32#[derive(Debug)]
33pub struct Payload {
34    inner: Rc<RefCell<Inner>>,
35}
36
37impl Payload {
38    /// Creates a payload stream.
39    ///
40    /// This method construct two objects responsible for bytes stream generation:
41    /// - `PayloadSender` - *Sender* side of the stream
42    /// - `Payload` - *Receiver* side of the stream
43    pub fn create(eof: bool) -> (PayloadSender, Payload) {
44        let shared = Rc::new(RefCell::new(Inner::new(eof)));
45
46        (
47            PayloadSender::new(Rc::downgrade(&shared)),
48            Payload { inner: shared },
49        )
50    }
51
52    /// Creates an empty payload.
53    pub(crate) fn empty() -> Payload {
54        Payload {
55            inner: Rc::new(RefCell::new(Inner::new(true))),
56        }
57    }
58
59    /// Length of the data in this payload
60    #[cfg(test)]
61    pub fn len(&self) -> usize {
62        self.inner.borrow().len()
63    }
64
65    /// Is payload empty
66    #[cfg(test)]
67    pub fn is_empty(&self) -> bool {
68        self.inner.borrow().len() == 0
69    }
70
71    /// Put unused data back to payload
72    #[inline]
73    pub fn unread_data(&mut self, data: Bytes) {
74        self.inner.borrow_mut().unread_data(data);
75    }
76}
77
78impl Stream for Payload {
79    type Item = Result<Bytes, PayloadError>;
80
81    fn poll_next(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84    ) -> Poll<Option<Result<Bytes, PayloadError>>> {
85        Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx)
86    }
87}
88
89/// Sender part of the payload stream
90pub struct PayloadSender {
91    inner: Weak<RefCell<Inner>>,
92}
93
94impl PayloadSender {
95    fn new(inner: Weak<RefCell<Inner>>) -> Self {
96        Self { inner }
97    }
98
99    #[inline]
100    pub fn set_error(&mut self, err: PayloadError) {
101        if let Some(shared) = self.inner.upgrade() {
102            shared.borrow_mut().set_error(err)
103        }
104    }
105
106    #[inline]
107    pub fn feed_eof(&mut self) {
108        if let Some(shared) = self.inner.upgrade() {
109            shared.borrow_mut().feed_eof()
110        }
111    }
112
113    #[inline]
114    pub fn feed_data(&mut self, data: Bytes) {
115        if let Some(shared) = self.inner.upgrade() {
116            shared.borrow_mut().feed_data(data)
117        }
118    }
119
120    #[allow(clippy::needless_pass_by_ref_mut)]
121    #[inline]
122    pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus {
123        // we check need_read only if Payload (other side) is alive,
124        // otherwise always return true (consume payload)
125        if let Some(shared) = self.inner.upgrade() {
126            if shared.borrow().need_read {
127                PayloadStatus::Read
128            } else {
129                shared.borrow_mut().register_io(cx);
130                PayloadStatus::Pause
131            }
132        } else {
133            PayloadStatus::Dropped
134        }
135    }
136
137    #[inline]
138    pub fn is_dropped(&self) -> bool {
139        self.inner.strong_count() == 0
140    }
141}
142
143#[derive(Debug)]
144struct Inner {
145    len: usize,
146    eof: bool,
147    err: Option<PayloadError>,
148    need_read: bool,
149    items: VecDeque<Bytes>,
150    task: Option<Waker>,
151    io_task: Option<Waker>,
152}
153
154impl Inner {
155    fn new(eof: bool) -> Self {
156        Inner {
157            eof,
158            len: 0,
159            err: None,
160            items: VecDeque::new(),
161            need_read: true,
162            task: None,
163            io_task: None,
164        }
165    }
166
167    /// Wake up future waiting for payload data to be available.
168    fn wake(&mut self) {
169        if let Some(waker) = self.task.take() {
170            waker.wake();
171        }
172    }
173
174    /// Wake up future feeding data to Payload.
175    fn wake_io(&mut self) {
176        if let Some(waker) = self.io_task.take() {
177            waker.wake();
178        }
179    }
180
181    /// Register future waiting data from payload.
182    /// Waker would be used in `Inner::wake`
183    fn register(&mut self, cx: &Context<'_>) {
184        if self.task.as_ref().is_none_or(|w| !cx.waker().will_wake(w)) {
185            self.task = Some(cx.waker().clone());
186        }
187    }
188
189    // Register future feeding data to payload.
190    /// Waker would be used in `Inner::wake_io`
191    fn register_io(&mut self, cx: &Context<'_>) {
192        if self
193            .io_task
194            .as_ref()
195            .is_none_or(|w| !cx.waker().will_wake(w))
196        {
197            self.io_task = Some(cx.waker().clone());
198        }
199    }
200
201    #[inline]
202    fn set_error(&mut self, err: PayloadError) {
203        self.err = Some(err);
204        self.wake();
205    }
206
207    #[inline]
208    fn feed_eof(&mut self) {
209        self.eof = true;
210        self.wake();
211    }
212
213    #[inline]
214    fn feed_data(&mut self, data: Bytes) {
215        self.len += data.len();
216        self.items.push_back(data);
217        self.need_read = self.len < MAX_BUFFER_SIZE;
218        self.wake();
219    }
220
221    #[cfg(test)]
222    fn len(&self) -> usize {
223        self.len
224    }
225
226    fn poll_next(
227        mut self: Pin<&mut Self>,
228        cx: &Context<'_>,
229    ) -> Poll<Option<Result<Bytes, PayloadError>>> {
230        if let Some(data) = self.items.pop_front() {
231            self.len -= data.len();
232            self.need_read = self.len < MAX_BUFFER_SIZE;
233
234            if self.need_read && !self.eof {
235                self.register(cx);
236            }
237            self.wake_io();
238            Poll::Ready(Some(Ok(data)))
239        } else if let Some(err) = self.err.take() {
240            Poll::Ready(Some(Err(err)))
241        } else if self.eof {
242            Poll::Ready(None)
243        } else {
244            self.need_read = true;
245            self.register(cx);
246            self.wake_io();
247            Poll::Pending
248        }
249    }
250
251    fn unread_data(&mut self, data: Bytes) {
252        self.len += data.len();
253        self.items.push_front(data);
254    }
255}
256
#[cfg(test)]
mod tests {
    use std::{task::Poll, time::Duration};

    use actix_rt::time::timeout;
    use actix_utils::future::poll_fn;
    use futures_util::{FutureExt, StreamExt};
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use tokio::sync::oneshot;

    use super::*;

    assert_impl_all!(Payload: Unpin);
    assert_not_impl_any!(Payload: Send, Sync);

    assert_impl_all!(Inner: Unpin, Send, Sync);

    /// Upper bound on how long a woken task may take to finish before the test fails.
    const WAKE_TIMEOUT: Duration = Duration::from_secs(2);

    /// Spawns a task that:
    /// 1. polls `payload` once, which registers its waker;
    /// 2. signals through the returned receiver that it is ready;
    /// 3. parks until the sender wakes it;
    /// 4. checks the next item against `expected`.
    ///
    /// In `expected`, `Some(Ok(_))` means data, `Some(Err(_))` an error, and `None` end-of-stream.
    fn prepare_waking_test(
        mut payload: Payload,
        expected: Option<Result<(), ()>>,
    ) -> (oneshot::Receiver<()>, actix_rt::task::JoinHandle<()>) {
        let (tx, rx) = oneshot::channel();

        let handle = actix_rt::spawn(async move {
            // Make sure to poll once to set the waker
            poll_fn(|cx| {
                assert!(payload.poll_next_unpin(cx).is_pending());
                Poll::Ready(())
            })
            .await;
            tx.send(()).unwrap();

            // actix-rt is single-threaded, so this won't race with `rx.await`
            let mut pend_once = false;
            poll_fn(|_| {
                if pend_once {
                    Poll::Ready(())
                } else {
                    // Return pending without storing wakers, we already did on the previous
                    // `poll_fn`, now this task will only continue if the `sender` wakes us
                    pend_once = true;
                    Poll::Pending
                }
            })
            .await;

            let got = payload.next().now_or_never().unwrap();
            match expected {
                Some(Ok(_)) => assert!(got.unwrap().is_ok()),
                Some(Err(_)) => assert!(got.unwrap().is_err()),
                None => assert!(got.is_none()),
            }
        });
        (rx, handle)
    }

    #[actix_rt::test]
    async fn wake_on_data() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, Some(Ok(())));

        rx.await.unwrap();
        sender.feed_data(Bytes::from("data"));
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    #[actix_rt::test]
    async fn wake_on_error() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, Some(Err(())));

        rx.await.unwrap();
        sender.set_error(PayloadError::Incomplete(None));
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    #[actix_rt::test]
    async fn wake_on_eof() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, None);

        rx.await.unwrap();
        sender.feed_eof();
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    #[actix_rt::test]
    async fn need_read_applies_back_pressure() {
        let (mut sender, mut payload) = Payload::create(false);

        // A fresh receiver asks for data.
        let status = poll_fn(|cx| Poll::Ready(sender.need_read(cx))).await;
        assert_eq!(status, PayloadStatus::Read);

        // Filling the buffer up to the limit pauses the sender.
        sender.feed_data(Bytes::from(vec![0u8; MAX_BUFFER_SIZE]));
        assert_eq!(payload.len(), MAX_BUFFER_SIZE);
        let status = poll_fn(|cx| Poll::Ready(sender.need_read(cx))).await;
        assert_eq!(status, PayloadStatus::Pause);

        // Draining the buffer lets the sender resume.
        let chunk = payload.next().await.unwrap().unwrap();
        assert_eq!(chunk.len(), MAX_BUFFER_SIZE);
        let status = poll_fn(|cx| Poll::Ready(sender.need_read(cx))).await;
        assert_eq!(status, PayloadStatus::Read);

        // Dropping the receiver is reported to the sender.
        drop(payload);
        assert!(sender.is_dropped());
        let status = poll_fn(|cx| Poll::Ready(sender.need_read(cx))).await;
        assert_eq!(status, PayloadStatus::Dropped);
    }

    #[actix_rt::test]
    async fn test_unread_data() {
        let (_, mut payload) = Payload::create(false);

        payload.unread_data(Bytes::from("data"));
        assert!(!payload.is_empty());
        assert_eq!(payload.len(), 4);

        assert_eq!(
            Bytes::from("data"),
            poll_fn(|cx| Pin::new(&mut payload).poll_next(cx))
                .await
                .unwrap()
                .unwrap()
        );
    }
}