// xitca_http/h1/body.rs

1use core::{
2    cell::{RefCell, RefMut},
3    fmt,
4    future::poll_fn,
5    ops::DerefMut,
6    pin::Pin,
7    task::{Context, Poll, Waker},
8};
9
10use std::{collections::VecDeque, io, rc::Rc};
11
12use futures_core::stream::Stream;
13
14use crate::bytes::Bytes;
15
/// Backpressure threshold for a channel-backed request body: 32 KiB.
///
/// Once this many bytes are queued and not yet consumed, the sender's readiness check stops
/// resolving until the consumer drains the queue.
pub(crate) const MAX_BUFFER_SIZE: usize = 32 * 1024;
18
19/// Buffered stream of request body chunk.
20///
21/// impl [Stream] trait to produce chunk as [Bytes] type in async manner.
22pub struct RequestBody(Option<Body>);
23
24impl fmt::Debug for RequestBody {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        f.write_str("RequestBody")
27    }
28}
29
30type Body = Pin<Box<dyn Stream<Item = io::Result<Bytes>>>>;
31
32impl Default for RequestBody {
33    fn default() -> Self {
34        Self::none()
35    }
36}
37
38impl RequestBody {
39    // an async spsc channel where sender is used to push data and popped from RequestBody.
40    pub(super) fn channel(eof: bool) -> (BodySender, Self) {
41        if eof {
42            (ChannelBody::none(), RequestBody::none())
43        } else {
44            let body = ChannelBody::stream();
45            (body.clone(), RequestBody::stream(body))
46        }
47    }
48
49    pub(super) fn stream<S>(stream: S) -> Self
50    where
51        S: Stream<Item = io::Result<Bytes>> + 'static,
52    {
53        RequestBody(Some(Box::pin(stream)))
54    }
55
56    pub(super) fn none() -> Self {
57        Self(None)
58    }
59}
60
61impl Stream for RequestBody {
62    type Item = io::Result<Bytes>;
63
64    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Bytes>>> {
65        match self.get_mut().0 {
66            Some(ref mut body) => body.as_mut().poll_next(cx),
67            None => Poll::Ready(None),
68        }
69    }
70}
71
72impl From<RequestBody> for crate::body::RequestBody {
73    fn from(body: RequestBody) -> Self {
74        Self::H1(body)
75    }
76}
77
78/// Sender part of the payload stream
79pub(super) type BodySender = ChannelBody;
80
81#[derive(Clone)]
82pub(super) struct ChannelBody(Option<Rc<RefCell<Inner>>>);
83
84impl Stream for ChannelBody {
85    type Item = io::Result<Bytes>;
86
87    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<Bytes>>> {
88        match self.get_mut().0 {
89            Some(ref body) => body.borrow_mut().poll_next_unpin(cx),
90            None => Poll::Ready(None),
91        }
92    }
93}
94
95// TODO: rework early eof error handling.
96impl Drop for ChannelBody {
97    fn drop(&mut self) {
98        if let Some(mut inner) = self.try_inner() {
99            if !inner.eof {
100                inner.feed_error(io::ErrorKind::UnexpectedEof.into());
101            }
102        }
103    }
104}
105
106impl ChannelBody {
107    fn stream() -> Self {
108        Self(Some(Default::default()))
109    }
110
111    fn none() -> Self {
112        Self(None)
113    }
114
115    // try to get a mutable reference of inner and ignore RequestBody::None variant.
116    fn try_inner(&mut self) -> Option<RefMut<'_, Inner>> {
117        self.try_inner_on_none_with(|| {})
118    }
119
120    // try to get a mutable reference of inner and panic on RequestBody::None variant.
121    // this is a runtime check for internal optimization to avoid unnecessary operations.
122    // public api must not be able to trigger this panic.
123    fn try_inner_infallible(&mut self) -> Option<RefMut<'_, Inner>> {
124        self.try_inner_on_none_with(|| panic!("No Request Body found. Do not waste operation on Sender."))
125    }
126
127    fn try_inner_on_none_with<F>(&mut self, func: F) -> Option<RefMut<'_, Inner>>
128    where
129        F: FnOnce(),
130    {
131        match self.0 {
132            Some(ref inner) => {
133                // request body is a shared pointer between only two owners and no weak reference.
134                debug_assert!(Rc::strong_count(inner) <= 2);
135                debug_assert_eq!(Rc::weak_count(inner), 0);
136                (Rc::strong_count(inner) != 1).then_some(inner.borrow_mut())
137            }
138            None => {
139                func();
140                None
141            }
142        }
143    }
144
145    pub(super) fn feed_error(&mut self, e: io::Error) {
146        if let Some(mut inner) = self.try_inner_infallible() {
147            inner.feed_error(e);
148        }
149    }
150
151    pub(super) fn feed_eof(&mut self) {
152        if let Some(mut inner) = self.try_inner_infallible() {
153            inner.feed_eof();
154        }
155    }
156
157    pub(super) fn feed_data(&mut self, data: Bytes) {
158        if let Some(mut inner) = self.try_inner_infallible() {
159            inner.feed_data(data);
160        }
161    }
162
163    pub(super) fn ready(&mut self) -> impl Future<Output = io::Result<()>> + '_ {
164        self.ready_with(|inner| !inner.backpressure())
165    }
166
167    // Lazily wait until RequestBody is already polled.
168    // For specific use case body must not be eagerly polled.
169    // For example: Request with Expect: Continue header.
170    pub(super) fn wait_for_poll(&mut self) -> impl Future<Output = io::Result<()>> + '_ {
171        self.ready_with(|inner| inner.waiting())
172    }
173
174    async fn ready_with<F>(&mut self, func: F) -> io::Result<()>
175    where
176        F: Fn(&mut Inner) -> bool,
177    {
178        poll_fn(|cx| {
179            // Check only if Payload (other side) is alive, Otherwise always return io error.
180            match self.try_inner_infallible() {
181                Some(mut inner) => {
182                    if func(inner.deref_mut()) {
183                        Poll::Ready(Ok(()))
184                    } else {
185                        // when payload is not ready register current task waker and wait.
186                        inner.register_io(cx);
187                        Poll::Pending
188                    }
189                }
190                None => Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
191            }
192        })
193        .await
194    }
195}
196
197#[derive(Debug, Default)]
198struct Inner {
199    eof: bool,
200    len: usize,
201    err: Option<io::Error>,
202    items: VecDeque<Bytes>,
203    task: Option<Waker>,
204    io_task: Option<Waker>,
205}
206
207impl Inner {
208    /// Wake up future waiting for payload data to be available.
209    fn wake(&mut self) {
210        if let Some(waker) = self.task.take() {
211            waker.wake();
212        }
213    }
214
215    /// Wake up future feeding data to Payload.
216    fn wake_io(&mut self) {
217        if let Some(waker) = self.io_task.take() {
218            waker.wake();
219        }
220    }
221
222    /// true when a future is waiting for payload data.
223    fn waiting(&self) -> bool {
224        self.task.is_some()
225    }
226
227    /// Register future waiting data from payload.
228    /// Waker would be used in `Inner::wake`
229    fn register(&mut self, cx: &Context<'_>) {
230        if self.task.as_ref().map(|w| !cx.waker().will_wake(w)).unwrap_or(true) {
231            self.task = Some(cx.waker().clone());
232        }
233    }
234
235    // Register future feeding data to payload.
236    /// Waker would be used in `Inner::wake_io`
237    fn register_io(&mut self, cx: &Context<'_>) {
238        if self.io_task.as_ref().map(|w| !cx.waker().will_wake(w)).unwrap_or(true) {
239            self.io_task = Some(cx.waker().clone());
240        }
241    }
242
243    fn feed_error(&mut self, err: io::Error) {
244        self.err = Some(err);
245        self.wake();
246    }
247
248    fn feed_eof(&mut self) {
249        self.eof = true;
250        self.wake();
251    }
252
253    fn feed_data(&mut self, data: Bytes) {
254        self.len += data.len();
255        self.items.push_back(data);
256        self.wake();
257    }
258
259    fn backpressure(&self) -> bool {
260        self.len >= MAX_BUFFER_SIZE
261    }
262
263    fn poll_next_unpin(&mut self, cx: &Context<'_>) -> Poll<Option<io::Result<Bytes>>> {
264        if let Some(data) = self.items.pop_front() {
265            self.len -= data.len();
266            Poll::Ready(Some(Ok(data)))
267        } else if let Some(err) = self.err.take() {
268            Poll::Ready(Some(Err(err)))
269        } else if self.eof {
270            Poll::Ready(None)
271        } else {
272            self.register(cx);
273            self.wake_io();
274            Poll::Pending
275        }
276    }
277}