micro_http/protocol/body/
body_channel.rs1use crate::protocol::{Message, ParseError, PayloadItem, PayloadSize, RequestHeader};
2use bytes::Bytes;
3use futures::{SinkExt, Stream, StreamExt, channel::mpsc};
4use http_body::{Body, Frame, SizeHint};
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tracing::error;
9
10pub(crate) fn create_body_sender_receiver<S>(body_stream: &mut S, payload_size: PayloadSize) -> (BodySender<S>, BodyReceiver)
11where
12 S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
13{
14 let (signal_sender, signal_receiver) = mpsc::channel(8);
15 let (data_sender, data_receiver) = mpsc::channel(8);
16
17 (BodySender::new(body_stream, signal_receiver, data_sender), BodyReceiver::new(signal_sender, data_receiver, payload_size))
18}
19
/// Requests sent from a [`BodyReceiver`] to its paired [`BodySender`].
pub(crate) enum BodyRequestSignal {
    /// Read the next payload item from the stream and send it back.
    RequestData,
    /// Stop forwarding data; the sender then skips the remainder of the body.
    // Nothing in this file constructs this variant, hence the lint allowance; presumably kept
    // so a receiver can end the transfer early — confirm before removing.
    #[allow(dead_code)]
    Enough,
}
25
26pub(crate) struct BodySender<'conn, S> {
27 payload_stream: &'conn mut S,
28 signal_receiver: mpsc::Receiver<BodyRequestSignal>,
29 data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
30 eof: bool,
31}
32
33impl<'conn, S> BodySender<'conn, S>
34where
35 S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
36{
37 pub fn new(
38 payload_stream: &'conn mut S,
39 signal_receiver: mpsc::Receiver<BodyRequestSignal>,
40 data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
41 ) -> Self {
42 Self { payload_stream, signal_receiver, data_sender, eof: false }
43 }
44
45 pub(crate) async fn start(&mut self) -> Result<(), ParseError> {
46 if self.eof {
47 return Ok(());
48 }
49
50 while let Some(signal) = self.signal_receiver.next().await {
51 match signal {
52 BodyRequestSignal::RequestData => match self.read_data().await {
53 Ok(payload_item) => {
54 self.eof = payload_item.is_eof();
55 if let Err(e) = self.data_sender.send(Ok(payload_item)).await {
56 error!("failed to send payload body through channel, {}", e);
57 return Err(ParseError::invalid_body("send body data error"));
58 }
59
60 if self.eof {
61 return Ok(());
62 }
63 }
64
65 Err(e) => {
66 error!("failed to read data from body stream, {}", e);
67 if let Err(send_error) = self.data_sender.send(Err(e)).await {
68 error!("failed to send error through channel, {}", send_error);
69 return Err(ParseError::invalid_body("failed to send error through channel"));
70 }
71 break;
72 }
73 },
74
75 BodyRequestSignal::Enough => {
76 break;
77 }
78 }
79 }
80
81 self.skip_data().await
82 }
83
84 pub(crate) async fn read_data(&mut self) -> Result<PayloadItem, ParseError> {
85 match self.payload_stream.next().await {
86 Some(Ok(Message::Payload(payload_item))) => Ok(payload_item),
87 Some(Ok(Message::Header(_))) => {
88 error!("should not receive header in BodySender");
89 Err(ParseError::invalid_body("should not receive header in BodySender"))
90 }
91 Some(Err(e)) => Err(e),
92 None => {
93 error!("should not receive None in BodySender");
94 Err(ParseError::invalid_body("should not receive None in BodySender"))
95 }
96 }
97 }
98
99 pub(crate) async fn skip_data(&mut self) -> Result<(), ParseError> {
100 if self.eof {
101 return Ok(());
102 }
103
104 loop {
105 match self.read_data().await {
106 Ok(payload_item) if payload_item.is_eof() => {
107 self.eof = true;
108 return Ok(());
109 }
110 Ok(_payload_item) => {
111 }
113 Err(e) => return Err(e),
114 }
115 }
116 }
117}
118
119pub(crate) struct BodyReceiver {
120 signal_sender: mpsc::Sender<BodyRequestSignal>,
121 data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
122 payload_size: PayloadSize,
123}
124
125impl BodyReceiver {
126 pub(crate) fn new(
127 signal_sender: mpsc::Sender<BodyRequestSignal>,
128 data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
129 payload_size: PayloadSize,
130 ) -> Self {
131 Self { signal_sender, data_receiver, payload_size }
132 }
133}
134
135impl BodyReceiver {
136 pub async fn receive_data(&mut self) -> Result<PayloadItem, ParseError> {
137 if let Err(e) = self.signal_sender.send(BodyRequestSignal::RequestData).await {
138 error!("failed to send request_more through channel, {}", e);
139 return Err(ParseError::invalid_body("failed to send signal when receive body data"));
140 }
141
142 self.data_receiver
143 .next()
144 .await
145 .unwrap_or_else(|| Err(ParseError::invalid_body("body stream should not receive None when receive data")))
146 }
147}
148
149impl Body for BodyReceiver {
150 type Data = Bytes;
151 type Error = ParseError;
152
153 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
154 let this = self.get_mut();
155
156 tokio::pin! {
157 let future = this.receive_data();
158 }
159
160 match future.poll(cx) {
161 Poll::Ready(Ok(PayloadItem::Chunk(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))),
162 Poll::Ready(Ok(PayloadItem::Eof)) => Poll::Ready(None),
163 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
164 Poll::Pending => Poll::Pending,
165 }
166 }
167
168 fn size_hint(&self) -> SizeHint {
169 self.payload_size.into()
170 }
171}
172
173impl From<SizeHint> for PayloadSize {
174 fn from(size_hint: SizeHint) -> Self {
175 match size_hint.exact() {
176 Some(0) => PayloadSize::new_empty(),
177 Some(length) => PayloadSize::new_length(length),
178 None => PayloadSize::new_chunked(),
179 }
180 }
181}
182
183impl From<PayloadSize> for SizeHint {
184 fn from(payload_size: PayloadSize) -> Self {
185 match payload_size {
186 PayloadSize::Length(length) => SizeHint::with_exact(length),
187 PayloadSize::Chunked => SizeHint::new(),
188 PayloadSize::Empty => SizeHint::with_exact(0),
189 }
190 }
191}