// micro_http/protocol/body/body_channel.rs

use std::future::{Future, poll_fn};
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use futures::{SinkExt, Stream, StreamExt, channel::mpsc};
use http_body::{Body, Frame, SizeHint};
use tracing::error;

use crate::protocol::{Message, ParseError, PayloadItem, PayloadSize, RequestHeader};

10pub(crate) fn create_body_sender_receiver<S>(body_stream: &mut S, payload_size: PayloadSize) -> (BodySender<S>, BodyReceiver)
11where
12    S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
13{
14    let (signal_sender, signal_receiver) = mpsc::channel(8);
15    let (data_sender, data_receiver) = mpsc::channel(8);
16
17    (BodySender::new(body_stream, signal_receiver, data_sender), BodyReceiver::new(signal_sender, data_receiver, payload_size))
18}
19
/// Demand signal sent from the body receiver to the body sender.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BodyRequestSignal {
    /// Ask the sender to read and forward exactly one more payload item.
    RequestData,
    /// Tell the sender no more data is wanted; the sender then discards the rest of the body.
    // `Enough` is matched by `BodySender::start`, but nothing in this file constructs it, so
    // without this the compiler warns that the variant is never constructed.
    #[allow(dead_code)]
    Enough,
}

26pub(crate) struct BodySender<'conn, S> {
27    payload_stream: &'conn mut S,
28    signal_receiver: mpsc::Receiver<BodyRequestSignal>,
29    data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
30    eof: bool,
31}
32
33impl<'conn, S> BodySender<'conn, S>
34where
35    S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
36{
37    pub fn new(
38        payload_stream: &'conn mut S,
39        signal_receiver: mpsc::Receiver<BodyRequestSignal>,
40        data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
41    ) -> Self {
42        Self { payload_stream, signal_receiver, data_sender, eof: false }
43    }
44
45    pub(crate) async fn start(&mut self) -> Result<(), ParseError> {
46        if self.eof {
47            return Ok(());
48        }
49
50        while let Some(signal) = self.signal_receiver.next().await {
51            match signal {
52                BodyRequestSignal::RequestData => match self.read_data().await {
53                    Ok(payload_item) => {
54                        self.eof = payload_item.is_eof();
55                        if let Err(e) = self.data_sender.send(Ok(payload_item)).await {
56                            error!("failed to send payload body through channel, {}", e);
57                            return Err(ParseError::invalid_body("send body data error"));
58                        }
59
60                        if self.eof {
61                            return Ok(());
62                        }
63                    }
64
65                    Err(e) => {
66                        error!("failed to read data from body stream, {}", e);
67                        if let Err(send_error) = self.data_sender.send(Err(e)).await {
68                            error!("failed to send error through channel, {}", send_error);
69                            return Err(ParseError::invalid_body("failed to send error through channel"));
70                        }
71                        break;
72                    }
73                },
74
75                BodyRequestSignal::Enough => {
76                    break;
77                }
78            }
79        }
80
81        self.skip_data().await
82    }
83
84    pub(crate) async fn read_data(&mut self) -> Result<PayloadItem, ParseError> {
85        match self.payload_stream.next().await {
86            Some(Ok(Message::Payload(payload_item))) => Ok(payload_item),
87            Some(Ok(Message::Header(_))) => {
88                error!("should not receive header in BodySender");
89                Err(ParseError::invalid_body("should not receive header in BodySender"))
90            }
91            Some(Err(e)) => Err(e),
92            None => {
93                error!("should not receive None in BodySender");
94                Err(ParseError::invalid_body("should not receive None in BodySender"))
95            }
96        }
97    }
98
99    pub(crate) async fn skip_data(&mut self) -> Result<(), ParseError> {
100        if self.eof {
101            return Ok(());
102        }
103
104        loop {
105            match self.read_data().await {
106                Ok(payload_item) if payload_item.is_eof() => {
107                    self.eof = true;
108                    return Ok(());
109                }
110                Ok(_payload_item) => {
111                    // drop payload_item
112                }
113                Err(e) => return Err(e),
114            }
115        }
116    }
117}
118
119pub(crate) struct BodyReceiver {
120    signal_sender: mpsc::Sender<BodyRequestSignal>,
121    data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
122    payload_size: PayloadSize,
123}
124
125impl BodyReceiver {
126    pub(crate) fn new(
127        signal_sender: mpsc::Sender<BodyRequestSignal>,
128        data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
129        payload_size: PayloadSize,
130    ) -> Self {
131        Self { signal_sender, data_receiver, payload_size }
132    }
133}
134
135impl BodyReceiver {
136    pub async fn receive_data(&mut self) -> Result<PayloadItem, ParseError> {
137        if let Err(e) = self.signal_sender.send(BodyRequestSignal::RequestData).await {
138            error!("failed to send request_more through channel, {}", e);
139            return Err(ParseError::invalid_body("failed to send signal when receive body data"));
140        }
141
142        self.data_receiver
143            .next()
144            .await
145            .unwrap_or_else(|| Err(ParseError::invalid_body("body stream should not receive None when receive data")))
146    }
147}
148
149impl Body for BodyReceiver {
150    type Data = Bytes;
151    type Error = ParseError;
152
153    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
154        let this = self.get_mut();
155
156        tokio::pin! {
157            let future = this.receive_data();
158        }
159
160        match future.poll(cx) {
161            Poll::Ready(Ok(PayloadItem::Chunk(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))),
162            Poll::Ready(Ok(PayloadItem::Eof)) => Poll::Ready(None),
163            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
164            Poll::Pending => Poll::Pending,
165        }
166    }
167
168    fn size_hint(&self) -> SizeHint {
169        self.payload_size.into()
170    }
171}
172
173impl From<SizeHint> for PayloadSize {
174    fn from(size_hint: SizeHint) -> Self {
175        match size_hint.exact() {
176            Some(0) => PayloadSize::new_empty(),
177            Some(length) => PayloadSize::new_length(length),
178            None => PayloadSize::new_chunked(),
179        }
180    }
181}
182
183impl From<PayloadSize> for SizeHint {
184    fn from(payload_size: PayloadSize) -> Self {
185        match payload_size {
186            PayloadSize::Length(length) => SizeHint::with_exact(length),
187            PayloadSize::Chunked => SizeHint::new(),
188            PayloadSize::Empty => SizeHint::with_exact(0),
189        }
190    }
191}