micro_http/protocol/body/
body_channel.rs1use crate::protocol::{Message, ParseError, PayloadItem, PayloadSize, RequestHeader};
2use bytes::Bytes;
3use futures::{SinkExt, Stream, StreamExt, channel::mpsc};
4use http_body::{Body, Frame, SizeHint};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tracing::error;
8
9pub(crate) fn create_body_sender_receiver<S>(body_stream: &mut S, payload_size: PayloadSize) -> (BodySender<'_, S>, BodyReceiver)
10where
11 S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
12{
13 let (signal_sender, signal_receiver) = mpsc::channel(8);
14 let (data_sender, data_receiver) = mpsc::channel(8);
15
16 (BodySender::new(body_stream, signal_receiver, data_sender), BodyReceiver::new(signal_sender, data_receiver, payload_size))
17}
18
/// Control messages sent from a [`BodyReceiver`] to its paired [`BodySender`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum BodyRequestSignal {
    /// Ask the sender to read the next payload item and forward it.
    RequestData,
    /// Tell the sender that no more data is wanted. It stops serving requests and drains the
    /// rest of the body.
    // `BodySender::start` handles this variant, but nothing appears to construct it yet,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    Enough,
}
24
25pub(crate) struct BodySender<'conn, S> {
26 payload_stream: &'conn mut S,
27 signal_receiver: mpsc::Receiver<BodyRequestSignal>,
28 data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
29 eof: bool,
30}
31
32impl<'conn, S> BodySender<'conn, S>
33where
34 S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
35{
36 pub fn new(
37 payload_stream: &'conn mut S,
38 signal_receiver: mpsc::Receiver<BodyRequestSignal>,
39 data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
40 ) -> Self {
41 Self { payload_stream, signal_receiver, data_sender, eof: false }
42 }
43
44 pub(crate) async fn start(&mut self) -> Result<(), ParseError> {
45 if self.eof {
46 return Ok(());
47 }
48
49 while let Some(signal) = self.signal_receiver.next().await {
50 match signal {
51 BodyRequestSignal::RequestData => match self.read_data().await {
52 Ok(payload_item) => {
53 self.eof = payload_item.is_eof();
54 if let Err(e) = self.data_sender.send(Ok(payload_item)).await {
55 error!("failed to send payload body through channel, {}", e);
56 return Err(ParseError::invalid_body("send body data error"));
57 }
58
59 if self.eof {
60 return Ok(());
61 }
62 }
63
64 Err(e) => {
65 error!("failed to read data from body stream, {}", e);
66 if let Err(send_error) = self.data_sender.send(Err(e)).await {
67 error!("failed to send error through channel, {}", send_error);
68 return Err(ParseError::invalid_body("failed to send error through channel"));
69 }
70 break;
71 }
72 },
73
74 BodyRequestSignal::Enough => {
75 break;
76 }
77 }
78 }
79
80 self.skip_data().await
81 }
82
83 pub(crate) async fn read_data(&mut self) -> Result<PayloadItem, ParseError> {
84 match self.payload_stream.next().await {
85 Some(Ok(Message::Payload(payload_item))) => Ok(payload_item),
86 Some(Ok(Message::Header(_))) => {
87 error!("should not receive header in BodySender");
88 Err(ParseError::invalid_body("should not receive header in BodySender"))
89 }
90 Some(Err(e)) => Err(e),
91 None => {
92 error!("should not receive None in BodySender");
93 Err(ParseError::invalid_body("should not receive None in BodySender"))
94 }
95 }
96 }
97
98 pub(crate) async fn skip_data(&mut self) -> Result<(), ParseError> {
99 if self.eof {
100 return Ok(());
101 }
102
103 loop {
104 match self.read_data().await {
105 Ok(payload_item) if payload_item.is_eof() => {
106 self.eof = true;
107 return Ok(());
108 }
109 Ok(_payload_item) => {
110 }
112 Err(e) => return Err(e),
113 }
114 }
115 }
116}
117
118#[derive(Debug)]
119pub(crate) struct BodyReceiver {
120 signal_sender: mpsc::Sender<BodyRequestSignal>,
121 data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
122 payload_size: PayloadSize,
123}
124
125impl BodyReceiver {
126 pub(crate) fn new(
127 signal_sender: mpsc::Sender<BodyRequestSignal>,
128 data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
129 payload_size: PayloadSize,
130 ) -> Self {
131 Self { signal_sender, data_receiver, payload_size }
132 }
133}
134
135impl BodyReceiver {
136 pub async fn receive_data(&mut self) -> Result<PayloadItem, ParseError> {
137 if let Err(e) = self.signal_sender.send(BodyRequestSignal::RequestData).await {
138 error!("failed to send request_more through channel, {}", e);
139 return Err(ParseError::invalid_body("failed to send signal when receive body data"));
140 }
141
142 self.data_receiver
143 .next()
144 .await
145 .unwrap_or_else(|| Err(ParseError::invalid_body("body stream should not receive None when receive data")))
146 }
147}
148
149impl Body for BodyReceiver {
150 type Data = Bytes;
151 type Error = ParseError;
152
153 fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
154 let this = self.get_mut();
155
156 tokio::pin! {
157 let future = this.receive_data();
158 }
159
160 match future.poll(cx) {
161 Poll::Ready(Ok(PayloadItem::Chunk(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))),
162 Poll::Ready(Ok(PayloadItem::Eof)) => Poll::Ready(None),
163 Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
164 Poll::Pending => Poll::Pending,
165 }
166 }
167
168 fn size_hint(&self) -> SizeHint {
169 self.payload_size.into()
170 }
171}
172
173impl From<SizeHint> for PayloadSize {
174 fn from(size_hint: SizeHint) -> Self {
175 match size_hint.exact() {
176 Some(0) => PayloadSize::new_empty(),
177 Some(length) => PayloadSize::new_length(length),
178 None => PayloadSize::new_chunked(),
179 }
180 }
181}
182
183impl From<PayloadSize> for SizeHint {
184 fn from(payload_size: PayloadSize) -> Self {
185 match payload_size {
186 PayloadSize::Length(length) => SizeHint::with_exact(length),
187 PayloadSize::Chunked => SizeHint::new(),
188 PayloadSize::Empty => SizeHint::with_exact(0),
189 }
190 }
191}