1use std::{
4 cell::RefCell,
5 collections::VecDeque,
6 pin::Pin,
7 rc::{Rc, Weak},
8 task::{Context, Poll, Waker},
9};
10
11use bytes::Bytes;
12use futures_core::Stream;
13
14use crate::error::PayloadError;
15
/// Once this many bytes are buffered, the payload stops asking its sender for more data.
pub(crate) const MAX_BUFFER_SIZE: usize = 32 * 1024;
18
/// Answer given by [`PayloadSender::need_read`] telling the producer what to do next.
#[derive(Debug, Eq, PartialEq)]
pub enum PayloadStatus {
    /// The buffer wants more data; keep reading.
    Read,
    /// The buffer is not requesting data; the producer's waker has been registered and
    /// reading should pause until it is woken.
    Pause,
    /// The consuming [`Payload`] no longer exists.
    Dropped,
}
25
26#[derive(Debug)]
33pub struct Payload {
34 inner: Rc<RefCell<Inner>>,
35}
36
37impl Payload {
38 pub fn create(eof: bool) -> (PayloadSender, Payload) {
44 let shared = Rc::new(RefCell::new(Inner::new(eof)));
45
46 (
47 PayloadSender::new(Rc::downgrade(&shared)),
48 Payload { inner: shared },
49 )
50 }
51
52 pub(crate) fn empty() -> Payload {
54 Payload {
55 inner: Rc::new(RefCell::new(Inner::new(true))),
56 }
57 }
58
59 #[cfg(test)]
61 pub fn len(&self) -> usize {
62 self.inner.borrow().len()
63 }
64
65 #[cfg(test)]
67 pub fn is_empty(&self) -> bool {
68 self.inner.borrow().len() == 0
69 }
70
71 #[inline]
73 pub fn unread_data(&mut self, data: Bytes) {
74 self.inner.borrow_mut().unread_data(data);
75 }
76}
77
78impl Stream for Payload {
79 type Item = Result<Bytes, PayloadError>;
80
81 fn poll_next(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
85 Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx)
86 }
87}
88
89pub struct PayloadSender {
91 inner: Weak<RefCell<Inner>>,
92}
93
94impl PayloadSender {
95 fn new(inner: Weak<RefCell<Inner>>) -> Self {
96 Self { inner }
97 }
98
99 #[inline]
100 pub fn set_error(&mut self, err: PayloadError) {
101 if let Some(shared) = self.inner.upgrade() {
102 shared.borrow_mut().set_error(err)
103 }
104 }
105
106 #[inline]
107 pub fn feed_eof(&mut self) {
108 if let Some(shared) = self.inner.upgrade() {
109 shared.borrow_mut().feed_eof()
110 }
111 }
112
113 #[inline]
114 pub fn feed_data(&mut self, data: Bytes) {
115 if let Some(shared) = self.inner.upgrade() {
116 shared.borrow_mut().feed_data(data)
117 }
118 }
119
120 #[allow(clippy::needless_pass_by_ref_mut)]
121 #[inline]
122 pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus {
123 if let Some(shared) = self.inner.upgrade() {
126 if shared.borrow().need_read {
127 PayloadStatus::Read
128 } else {
129 shared.borrow_mut().register_io(cx);
130 PayloadStatus::Pause
131 }
132 } else {
133 PayloadStatus::Dropped
134 }
135 }
136
137 #[inline]
138 pub fn is_dropped(&self) -> bool {
139 self.inner.strong_count() == 0
140 }
141}
142
143#[derive(Debug)]
144struct Inner {
145 len: usize,
146 eof: bool,
147 err: Option<PayloadError>,
148 need_read: bool,
149 items: VecDeque<Bytes>,
150 task: Option<Waker>,
151 io_task: Option<Waker>,
152}
153
154impl Inner {
155 fn new(eof: bool) -> Self {
156 Inner {
157 eof,
158 len: 0,
159 err: None,
160 items: VecDeque::new(),
161 need_read: true,
162 task: None,
163 io_task: None,
164 }
165 }
166
167 fn wake(&mut self) {
169 if let Some(waker) = self.task.take() {
170 waker.wake();
171 }
172 }
173
174 fn wake_io(&mut self) {
176 if let Some(waker) = self.io_task.take() {
177 waker.wake();
178 }
179 }
180
181 fn register(&mut self, cx: &Context<'_>) {
184 if self.task.as_ref().is_none_or(|w| !cx.waker().will_wake(w)) {
185 self.task = Some(cx.waker().clone());
186 }
187 }
188
189 fn register_io(&mut self, cx: &Context<'_>) {
192 if self
193 .io_task
194 .as_ref()
195 .is_none_or(|w| !cx.waker().will_wake(w))
196 {
197 self.io_task = Some(cx.waker().clone());
198 }
199 }
200
201 #[inline]
202 fn set_error(&mut self, err: PayloadError) {
203 self.err = Some(err);
204 self.wake();
205 }
206
207 #[inline]
208 fn feed_eof(&mut self) {
209 self.eof = true;
210 self.wake();
211 }
212
213 #[inline]
214 fn feed_data(&mut self, data: Bytes) {
215 self.len += data.len();
216 self.items.push_back(data);
217 self.need_read = self.len < MAX_BUFFER_SIZE;
218 self.wake();
219 }
220
221 #[cfg(test)]
222 fn len(&self) -> usize {
223 self.len
224 }
225
226 fn poll_next(
227 mut self: Pin<&mut Self>,
228 cx: &Context<'_>,
229 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
230 if let Some(data) = self.items.pop_front() {
231 self.len -= data.len();
232 self.need_read = self.len < MAX_BUFFER_SIZE;
233
234 if self.need_read && !self.eof {
235 self.register(cx);
236 }
237 self.wake_io();
238 Poll::Ready(Some(Ok(data)))
239 } else if let Some(err) = self.err.take() {
240 Poll::Ready(Some(Err(err)))
241 } else if self.eof {
242 Poll::Ready(None)
243 } else {
244 self.need_read = true;
245 self.register(cx);
246 self.wake_io();
247 Poll::Pending
248 }
249 }
250
251 fn unread_data(&mut self, data: Bytes) {
252 self.len += data.len();
253 self.items.push_front(data);
254 }
255}
256
#[cfg(test)]
mod tests {
    use std::{task::Poll, time::Duration};

    use actix_rt::time::timeout;
    use actix_utils::future::poll_fn;
    use futures_util::{FutureExt, StreamExt};
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use tokio::sync::oneshot;

    use super::*;

    // Compile-time checks: `Payload` holds an `Rc`, so it must be neither `Send` nor `Sync`,
    // while `Inner` on its own is expected to be `Send + Sync`.
    assert_impl_all!(Payload: Unpin);
    assert_not_impl_any!(Payload: Send, Sync);

    assert_impl_all!(Inner: Unpin, Send, Sync);

    /// How long a waking test waits for the spawned task to finish before failing.
    const WAKE_TIMEOUT: Duration = Duration::from_secs(2);

    /// Spawns a task that parks on `payload` and then checks what it yields once woken.
    ///
    /// The task:
    /// 1. polls `payload` once and asserts it is pending, which stores the task's waker in
    ///    the payload;
    /// 2. signals through the returned receiver that step 1 is done;
    /// 3. returns `Pending` once without registering any waker, so it should only resume
    ///    when the waker stored in step 1 is woken (e.g. by the sender);
    /// 4. takes the next item, which must be immediately available, and checks it against
    ///    `expected`: `Some(Ok(_))` for data, `Some(Err(_))` for an error, `None` for EOF.
    ///
    /// Returns the readiness receiver and the spawned task's join handle.
    fn prepare_waking_test(
        mut payload: Payload,
        expected: Option<Result<(), ()>>,
    ) -> (oneshot::Receiver<()>, actix_rt::task::JoinHandle<()>) {
        let (tx, rx) = oneshot::channel();

        let handle = actix_rt::spawn(async move {
            // Poll once so the (empty, unfinished) payload registers this task's waker.
            poll_fn(|cx| {
                assert!(payload.poll_next_unpin(cx).is_pending());
                Poll::Ready(())
            })
            .await;
            tx.send(()).unwrap();

            let mut pend_once = false;
            // Yield exactly once; the context is ignored, so no new waker is stored here.
            poll_fn(|_| {
                if pend_once {
                    Poll::Ready(())
                } else {
                    // Returning `Pending` without storing a waker: the task resumes only
                    // if the waker registered by the payload above gets woken.
                    pend_once = true;
                    Poll::Pending
                }
            })
            .await;

            // `now_or_never` yields `None` if the item is not ready yet; `unwrap` turns that
            // into a test failure.
            let got = payload.next().now_or_never().unwrap();
            match expected {
                Some(Ok(_)) => assert!(got.unwrap().is_ok()),
                Some(Err(_)) => assert!(got.unwrap().is_err()),
                None => assert!(got.is_none()),
            }
        });
        (rx, handle)
    }

    /// `set_error` must wake a consumer parked on an empty payload.
    #[actix_rt::test]
    async fn wake_on_error() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, Some(Err(())));

        // Wait until the spawned task has registered its waker before injecting the error.
        rx.await.unwrap();
        sender.set_error(PayloadError::Incomplete(None));
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    /// `feed_eof` must wake a consumer parked on an empty payload.
    #[actix_rt::test]
    async fn wake_on_eof() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, None);

        // Wait until the spawned task has registered its waker before signalling EOF.
        rx.await.unwrap();
        sender.feed_eof();
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    /// Data pushed back with `unread_data` is counted in `len` and yielded by the next poll.
    #[actix_rt::test]
    async fn test_unread_data() {
        let (_, mut payload) = Payload::create(false);

        payload.unread_data(Bytes::from("data"));
        assert!(!payload.is_empty());
        assert_eq!(payload.len(), 4);

        assert_eq!(
            Bytes::from("data"),
            poll_fn(|cx| Pin::new(&mut payload).poll_next(cx))
                .await
                .unwrap()
                .unwrap()
        );
    }
}