1use std::{
4 cell::RefCell,
5 collections::VecDeque,
6 pin::Pin,
7 rc::{Rc, Weak},
8 task::{Context, Poll, Waker},
9};
10
11use bytes::Bytes;
12use futures_core::Stream;
13
14use crate::error::PayloadError;
15
/// Buffered-byte threshold (32 KiB) at which the payload stops asking the producer for more data.
pub(crate) const MAX_BUFFER_SIZE: usize = 32 * 1024;
18
/// Answer given to the producing side when it asks whether it should keep reading.
#[derive(Debug)]
#[derive(PartialEq, Eq)]
pub enum PayloadStatus {
    /// The buffer has room; the producer should continue feeding data.
    Read,
    /// The buffer is full; the producer should pause until the consumer drains it.
    Pause,
    /// The receiving side is gone; any further data will be discarded.
    Dropped,
}
25
26#[derive(Debug)]
33pub struct Payload {
34 inner: Rc<RefCell<Inner>>,
35}
36
37impl Payload {
38 pub fn create(eof: bool) -> (PayloadSender, Payload) {
44 let shared = Rc::new(RefCell::new(Inner::new(eof)));
45
46 (
47 PayloadSender::new(Rc::downgrade(&shared)),
48 Payload { inner: shared },
49 )
50 }
51
52 pub(crate) fn empty() -> Payload {
54 Payload {
55 inner: Rc::new(RefCell::new(Inner::new(true))),
56 }
57 }
58
59 #[cfg(test)]
61 pub fn len(&self) -> usize {
62 self.inner.borrow().len()
63 }
64
65 #[cfg(test)]
67 pub fn is_empty(&self) -> bool {
68 self.inner.borrow().len() == 0
69 }
70
71 #[inline]
73 pub fn unread_data(&mut self, data: Bytes) {
74 self.inner.borrow_mut().unread_data(data);
75 }
76}
77
78impl Stream for Payload {
79 type Item = Result<Bytes, PayloadError>;
80
81 fn poll_next(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
85 Pin::new(&mut *self.inner.borrow_mut()).poll_next(cx)
86 }
87}
88
89pub struct PayloadSender {
91 inner: Weak<RefCell<Inner>>,
92}
93
94impl PayloadSender {
95 fn new(inner: Weak<RefCell<Inner>>) -> Self {
96 Self { inner }
97 }
98
99 #[inline]
100 pub fn set_error(&mut self, err: PayloadError) {
101 if let Some(shared) = self.inner.upgrade() {
102 shared.borrow_mut().set_error(err)
103 }
104 }
105
106 #[inline]
107 pub fn feed_eof(&mut self) {
108 if let Some(shared) = self.inner.upgrade() {
109 shared.borrow_mut().feed_eof()
110 }
111 }
112
113 #[inline]
114 pub fn feed_data(&mut self, data: Bytes) {
115 if let Some(shared) = self.inner.upgrade() {
116 shared.borrow_mut().feed_data(data)
117 }
118 }
119
120 #[allow(clippy::needless_pass_by_ref_mut)]
121 #[inline]
122 pub fn need_read(&self, cx: &mut Context<'_>) -> PayloadStatus {
123 if let Some(shared) = self.inner.upgrade() {
126 if shared.borrow().need_read {
127 PayloadStatus::Read
128 } else {
129 shared.borrow_mut().register_io(cx);
130 PayloadStatus::Pause
131 }
132 } else {
133 PayloadStatus::Dropped
134 }
135 }
136
137 #[inline]
138 pub fn is_dropped(&self) -> bool {
139 self.inner.strong_count() == 0
140 }
141}
142
143impl Drop for PayloadSender {
144 fn drop(&mut self) {
145 if let Some(shared) = self.inner.upgrade() {
146 shared.borrow_mut().close_sender();
147 }
148 }
149}
150
151#[derive(Debug)]
152struct Inner {
153 len: usize,
154 eof: bool,
155 err: Option<PayloadError>,
156 sender_closed: bool,
157 need_read: bool,
158 items: VecDeque<Bytes>,
159 task: Option<Waker>,
160 io_task: Option<Waker>,
161}
162
163impl Inner {
164 fn new(eof: bool) -> Self {
165 Inner {
166 eof,
167 len: 0,
168 err: None,
169 sender_closed: eof,
170 items: VecDeque::new(),
171 need_read: true,
172 task: None,
173 io_task: None,
174 }
175 }
176
177 fn wake(&mut self) {
179 if let Some(waker) = self.task.take() {
180 waker.wake();
181 }
182 }
183
184 fn wake_io(&mut self) {
186 if let Some(waker) = self.io_task.take() {
187 waker.wake();
188 }
189 }
190
191 fn register(&mut self, cx: &Context<'_>) {
194 if self.task.as_ref().is_none_or(|w| !cx.waker().will_wake(w)) {
195 self.task = Some(cx.waker().clone());
196 }
197 }
198
199 fn register_io(&mut self, cx: &Context<'_>) {
202 if self
203 .io_task
204 .as_ref()
205 .is_none_or(|w| !cx.waker().will_wake(w))
206 {
207 self.io_task = Some(cx.waker().clone());
208 }
209 }
210
211 #[inline]
212 fn set_error(&mut self, err: PayloadError) {
213 self.sender_closed = true;
214 self.err = Some(err);
215 self.wake();
216 }
217
218 fn close_sender(&mut self) {
219 if !self.sender_closed {
220 self.sender_closed = true;
221 self.set_error(PayloadError::Incomplete(None));
222 }
223 }
224
225 #[inline]
226 fn feed_eof(&mut self) {
227 self.sender_closed = true;
228 self.eof = true;
229 self.wake();
230 }
231
232 #[inline]
233 fn feed_data(&mut self, data: Bytes) {
234 self.len += data.len();
235 self.items.push_back(data);
236 self.need_read = self.len < MAX_BUFFER_SIZE;
237 self.wake();
238 }
239
240 #[cfg(test)]
241 fn len(&self) -> usize {
242 self.len
243 }
244
245 fn poll_next(
246 mut self: Pin<&mut Self>,
247 cx: &Context<'_>,
248 ) -> Poll<Option<Result<Bytes, PayloadError>>> {
249 if let Some(data) = self.items.pop_front() {
250 self.len -= data.len();
251 self.need_read = self.len < MAX_BUFFER_SIZE;
252
253 if self.need_read && !self.eof {
254 self.register(cx);
255 }
256 self.wake_io();
257 Poll::Ready(Some(Ok(data)))
258 } else if let Some(err) = self.err.take() {
259 Poll::Ready(Some(Err(err)))
260 } else if self.eof {
261 Poll::Ready(None)
262 } else {
263 self.need_read = true;
264 self.register(cx);
265 self.wake_io();
266 Poll::Pending
267 }
268 }
269
270 fn unread_data(&mut self, data: Bytes) {
271 self.len += data.len();
272 self.items.push_front(data);
273 }
274}
275
#[cfg(test)]
mod tests {
    use std::{task::Poll, time::Duration};

    use actix_rt::time::timeout;
    use actix_utils::future::poll_fn;
    use futures_util::{FutureExt, StreamExt};
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use tokio::sync::oneshot;

    use super::*;

    // `Payload` wraps an `Rc`, so it must not be `Send`/`Sync`.
    assert_impl_all!(Payload: Unpin);
    assert_not_impl_any!(Payload: Send, Sync);

    assert_impl_all!(Inner: Unpin, Send, Sync);

    /// Upper bound on how long a waking test waits for the spawned task to finish.
    const WAKE_TIMEOUT: Duration = Duration::from_secs(2);

    /// Spawns a task that checks the payload wakes its consumer.
    ///
    /// The task polls `payload` once, which must be pending and stores the task's waker. It
    /// then signals the returned receiver and suspends. After resuming, it asserts that the
    /// next item is immediately available and matches `expected`:
    /// - `Some(Ok(_))`: an `Ok` chunk
    /// - `Some(Err(_))`: an error
    /// - `None`: end of stream
    fn prepare_waking_test(
        mut payload: Payload,
        expected: Option<Result<(), ()>>,
    ) -> (oneshot::Receiver<()>, actix_rt::task::JoinHandle<()>) {
        let (tx, rx) = oneshot::channel();

        let handle = actix_rt::spawn(async move {
            // Poll once so the payload stores this task's waker.
            poll_fn(|cx| {
                assert!(payload.poll_next_unpin(cx).is_pending());
                Poll::Ready(())
            })
            .await;
            tx.send(()).unwrap();

            let mut pend_once = false;
            poll_fn(|_| {
                if pend_once {
                    Poll::Ready(())
                } else {
                    pend_once = true;
                    // Return `Pending` without registering a waker here. The only stored
                    // waker is the one held by the payload, so this task resumes only if
                    // the sender side wakes it.
                    Poll::Pending
                }
            })
            .await;

            // `now_or_never().unwrap()` panics if the item is not ready right away.
            let got = payload.next().now_or_never().unwrap();
            match expected {
                Some(Ok(_)) => assert!(got.unwrap().is_ok()),
                Some(Err(_)) => assert!(got.unwrap().is_err()),
                None => assert!(got.is_none()),
            }
        });
        (rx, handle)
    }

    /// `set_error` must wake a consumer blocked on an empty payload.
    #[actix_rt::test]
    async fn wake_on_error() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, Some(Err(())));

        rx.await.unwrap();
        sender.set_error(PayloadError::Incomplete(None));
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    /// `feed_eof` must wake a consumer, which then sees end-of-stream.
    #[actix_rt::test]
    async fn wake_on_eof() {
        let (mut sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, None);

        rx.await.unwrap();
        sender.feed_eof();
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    /// Dropping an unfinished sender must wake the consumer with an error.
    #[actix_rt::test]
    async fn wake_on_sender_drop() {
        let (sender, payload) = Payload::create(false);
        let (rx, handle) = prepare_waking_test(payload, Some(Err(())));

        rx.await.unwrap();
        drop(sender);
        timeout(WAKE_TIMEOUT, handle).await.unwrap().unwrap();
    }

    /// Unread data is counted in `len` and is returned by the next poll.
    #[actix_rt::test]
    async fn test_unread_data() {
        let (_, mut payload) = Payload::create(false);

        payload.unread_data(Bytes::from("data"));
        assert!(!payload.is_empty());
        assert_eq!(payload.len(), 4);

        assert_eq!(
            Bytes::from("data"),
            poll_fn(|cx| Pin::new(&mut payload).poll_next(cx))
                .await
                .unwrap()
                .unwrap()
        );
    }
}