// micro_http/protocol/body/body_channel.rs

1use crate::protocol::{Message, ParseError, PayloadItem, PayloadSize, RequestHeader};
2use bytes::Bytes;
3use futures::{SinkExt, Stream, StreamExt, channel::mpsc};
4use http_body::{Body, Frame, SizeHint};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tracing::error;
8
9pub(crate) fn create_body_sender_receiver<S>(body_stream: &mut S, payload_size: PayloadSize) -> (BodySender<'_, S>, BodyReceiver)
10where
11    S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
12{
13    let (signal_sender, signal_receiver) = mpsc::channel(8);
14    let (data_sender, data_receiver) = mpsc::channel(8);
15
16    (BodySender::new(body_stream, signal_receiver, data_sender), BodyReceiver::new(signal_sender, data_receiver, payload_size))
17}
18
/// Signals a `BodyReceiver` sends to its paired `BodySender`.
pub(crate) enum BodyRequestSignal {
    /// Read one more payload item and send it back.
    RequestData,
    /// Stop serving data; the sender then reads and discards the rest of the body.
    // NOTE(review): not constructed in this file, hence the lint allowance — confirm
    // whether any caller uses it before removing.
    #[allow(dead_code)]
    Enough,
}
24
25pub(crate) struct BodySender<'conn, S> {
26    payload_stream: &'conn mut S,
27    signal_receiver: mpsc::Receiver<BodyRequestSignal>,
28    data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
29    eof: bool,
30}
31
32impl<'conn, S> BodySender<'conn, S>
33where
34    S: Stream<Item = Result<Message<(RequestHeader, PayloadSize)>, ParseError>> + Unpin,
35{
36    pub fn new(
37        payload_stream: &'conn mut S,
38        signal_receiver: mpsc::Receiver<BodyRequestSignal>,
39        data_sender: mpsc::Sender<Result<PayloadItem, ParseError>>,
40    ) -> Self {
41        Self { payload_stream, signal_receiver, data_sender, eof: false }
42    }
43
44    pub(crate) async fn start(&mut self) -> Result<(), ParseError> {
45        if self.eof {
46            return Ok(());
47        }
48
49        while let Some(signal) = self.signal_receiver.next().await {
50            match signal {
51                BodyRequestSignal::RequestData => match self.read_data().await {
52                    Ok(payload_item) => {
53                        self.eof = payload_item.is_eof();
54                        if let Err(e) = self.data_sender.send(Ok(payload_item)).await {
55                            error!("failed to send payload body through channel, {}", e);
56                            return Err(ParseError::invalid_body("send body data error"));
57                        }
58
59                        if self.eof {
60                            return Ok(());
61                        }
62                    }
63
64                    Err(e) => {
65                        error!("failed to read data from body stream, {}", e);
66                        if let Err(send_error) = self.data_sender.send(Err(e)).await {
67                            error!("failed to send error through channel, {}", send_error);
68                            return Err(ParseError::invalid_body("failed to send error through channel"));
69                        }
70                        break;
71                    }
72                },
73
74                BodyRequestSignal::Enough => {
75                    break;
76                }
77            }
78        }
79
80        self.skip_data().await
81    }
82
83    pub(crate) async fn read_data(&mut self) -> Result<PayloadItem, ParseError> {
84        match self.payload_stream.next().await {
85            Some(Ok(Message::Payload(payload_item))) => Ok(payload_item),
86            Some(Ok(Message::Header(_))) => {
87                error!("should not receive header in BodySender");
88                Err(ParseError::invalid_body("should not receive header in BodySender"))
89            }
90            Some(Err(e)) => Err(e),
91            None => {
92                error!("should not receive None in BodySender");
93                Err(ParseError::invalid_body("should not receive None in BodySender"))
94            }
95        }
96    }
97
98    pub(crate) async fn skip_data(&mut self) -> Result<(), ParseError> {
99        if self.eof {
100            return Ok(());
101        }
102
103        loop {
104            match self.read_data().await {
105                Ok(payload_item) if payload_item.is_eof() => {
106                    self.eof = true;
107                    return Ok(());
108                }
109                Ok(_payload_item) => {
110                    // drop payload_item
111                }
112                Err(e) => return Err(e),
113            }
114        }
115    }
116}
117
118#[derive(Debug)]
119pub(crate) struct BodyReceiver {
120    signal_sender: mpsc::Sender<BodyRequestSignal>,
121    data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
122    payload_size: PayloadSize,
123}
124
125impl BodyReceiver {
126    pub(crate) fn new(
127        signal_sender: mpsc::Sender<BodyRequestSignal>,
128        data_receiver: mpsc::Receiver<Result<PayloadItem, ParseError>>,
129        payload_size: PayloadSize,
130    ) -> Self {
131        Self { signal_sender, data_receiver, payload_size }
132    }
133}
134
135impl BodyReceiver {
136    pub async fn receive_data(&mut self) -> Result<PayloadItem, ParseError> {
137        if let Err(e) = self.signal_sender.send(BodyRequestSignal::RequestData).await {
138            error!("failed to send request_more through channel, {}", e);
139            return Err(ParseError::invalid_body("failed to send signal when receive body data"));
140        }
141
142        self.data_receiver
143            .next()
144            .await
145            .unwrap_or_else(|| Err(ParseError::invalid_body("body stream should not receive None when receive data")))
146    }
147}
148
149impl Body for BodyReceiver {
150    type Data = Bytes;
151    type Error = ParseError;
152
153    fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
154        let this = self.get_mut();
155
156        tokio::pin! {
157            let future = this.receive_data();
158        }
159
160        match future.poll(cx) {
161            Poll::Ready(Ok(PayloadItem::Chunk(bytes))) => Poll::Ready(Some(Ok(Frame::data(bytes)))),
162            Poll::Ready(Ok(PayloadItem::Eof)) => Poll::Ready(None),
163            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
164            Poll::Pending => Poll::Pending,
165        }
166    }
167
168    fn size_hint(&self) -> SizeHint {
169        self.payload_size.into()
170    }
171}
172
173impl From<SizeHint> for PayloadSize {
174    fn from(size_hint: SizeHint) -> Self {
175        match size_hint.exact() {
176            Some(0) => PayloadSize::new_empty(),
177            Some(length) => PayloadSize::new_length(length),
178            None => PayloadSize::new_chunked(),
179        }
180    }
181}
182
183impl From<PayloadSize> for SizeHint {
184    fn from(payload_size: PayloadSize) -> Self {
185        match payload_size {
186            PayloadSize::Length(length) => SizeHint::with_exact(length),
187            PayloadSize::Chunked => SizeHint::new(),
188            PayloadSize::Empty => SizeHint::with_exact(0),
189        }
190    }
191}