1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
use std::pin::Pin;
use std::task::{ready, Context, Poll};

use bytes::Bytes;

use futures::channel::{mpsc, oneshot};
use futures::{FutureExt, SinkExt, Stream, StreamExt};

use http_body::{Body, Frame};
use tracing::{error, info};

use crate::protocol::{Message, ParseError, PayloadItem, RequestHeader};

/// Request body handle that pulls payload items on demand.
///
/// Each frame is obtained by pushing a fresh `oneshot::Sender` through `signal`
/// and then waiting on the matching `oneshot::Receiver` for the reply
/// (see the `Body` implementation for this type).
pub struct ReqBody {
    /// Channel used to ask the paired `ReqBodySender` for the next payload item;
    /// every message is the reply slot the item should be delivered to.
    signal: mpsc::Sender<oneshot::Sender<PayloadItem>>,
    /// Reply slot of the request currently in flight, or `None` when no request
    /// is outstanding.
    receiving: Option<oneshot::Receiver<PayloadItem>>,
}

impl ReqBody {
    fn new(signal: mpsc::Sender<oneshot::Sender<PayloadItem>>) -> Self {
        Self { signal, receiving: None }
    }

    pub fn body_channel<S>(payload_stream: &mut S) -> (ReqBody, ReqBodySender<S>)
    where
        S: Stream + Unpin,
    {
        let (tx, receiver) = mpsc::channel(16);

        let req_body = ReqBody::new(tx);

        let body_sender = ReqBodySender { payload_stream, receiver, eof: false };

        (req_body, body_sender)
    }
}

/// Producer half of the request-body channel.
///
/// Holds a mutable borrow of the message stream for the lifetime `'conn` and
/// answers the requests that arrive on `receiver`.
pub struct ReqBodySender<'conn, S>
where
    S: Stream + Unpin,
{
    /// Source of parsed messages the payload items are read from.
    payload_stream: &'conn mut S,
    /// Incoming requests; each one carries the reply slot for a single payload item.
    receiver: mpsc::Receiver<oneshot::Sender<PayloadItem>>,
    /// Set once a payload item reporting `is_eof()` has been read from `payload_stream`.
    eof: bool,
}

impl<'conn, S> ReqBodySender<'conn, S>
where
    S: Stream<Item = Result<Message<RequestHeader>, ParseError>> + Unpin,
{
    /// Answers body requests until the payload is exhausted or nobody is
    /// listening any more.
    ///
    /// For every reply slot received on the request channel, one message is read
    /// from the payload stream and forwarded to that slot.
    ///
    /// Returns `Ok(())` when:
    /// * a payload item reporting `is_eof()` has been read (the `eof` flag is set), or
    /// * the consumer went away: the request channel is closed, or the reply
    ///   slot's receiver was dropped. In that case `eof` may still be `false`,
    ///   so the caller can drain the remainder with [`Self::skip_body`].
    ///
    /// # Errors
    ///
    /// * the stream yields a header while a body item is expected,
    /// * the stream yields a parse error (propagated unchanged),
    /// * the stream ends before an EOF payload item was seen.
    pub async fn send_body(&mut self) -> Result<(), ParseError> {
        while !self.eof {
            // `None` means every request sender has been dropped, i.e. the body
            // consumer is gone. Return instead of polling the terminated channel
            // again, which would spin without ever yielding to the executor.
            let sender = match self.receiver.next().await {
                Some(sender) => sender,
                None => return Ok(()),
            };

            match self.payload_stream.next().await {
                Some(Ok(Message::Payload(payload_item))) => {
                    // Record EOF before handing the item over, so the flag is
                    // correct even if the delivery below fails.
                    if payload_item.is_eof() {
                        self.eof = true;
                    }

                    // A failed send only means the consumer dropped its reply
                    // slot; that is not a protocol error and must not panic.
                    if sender.send(payload_item).is_err() {
                        return Ok(());
                    }
                }

                Some(Ok(Message::Header(_header))) => {
                    error!("received header from receive body phase");
                    return Err(ParseError::invalid_body("received header from receive body phase"));
                }

                Some(Err(e)) => {
                    return Err(e);
                }

                None => {
                    error!("cant read body");
                    return Err(ParseError::invalid_body("cant read body"));
                }
            }
        }

        Ok(())
    }

    /// Reads and discards payload items until one reports `is_eof()`.
    ///
    /// Does nothing when `eof` is already set. The lengths of the discarded
    /// items that expose bytes are summed and logged once EOF is reached, if the
    /// total is non-zero. The loop also stops, without setting `eof` and without
    /// logging, as soon as the stream ends or yields anything other than a
    /// successfully parsed payload item.
    pub async fn skip_body(&mut self) {
        if self.eof {
            return;
        }

        let mut size: usize = 0;
        while let Some(Ok(Message::Payload(payload_item))) = self.payload_stream.next().await {
            if payload_item.is_eof() {
                self.eof = true;
                if size > 0 {
                    info!(size = size, "skip request body");
                }
                break;
            }

            if let Some(bytes) = payload_item.as_bytes() {
                size += bytes.len();
            }
        }
    }
}

impl Body for ReqBody {
    type Data = Bytes;
    type Error = ParseError;

    /// Polls for the next body frame.
    ///
    /// Works as a two-step state machine:
    /// 1. With no request in flight, wait for capacity on the request channel,
    ///    push a new reply slot through it and remember the receiving end.
    /// 2. With a request in flight, wait for the reply: a chunk becomes a data
    ///    frame, an EOF item ends the body (`None`), and a cancelled reply slot
    ///    is reported as an invalid-body error.
    ///
    /// Failures of the request channel are also reported as invalid-body errors.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        loop {
            // Step 2: a request is outstanding, so wait for its answer.
            if let Some(pending_reply) = self.receiving.as_mut() {
                let outcome = ready!(pending_reply.poll_unpin(cx));

                // The reply slot is spent whatever the outcome was.
                self.receiving = None;

                return Poll::Ready(match outcome {
                    Ok(PayloadItem::Chunk(bytes)) => Some(Ok(Frame::data(bytes))),
                    Ok(PayloadItem::Eof) => None,
                    Err(_) => Some(Err(ParseError::invalid_body("parse body canceled"))),
                });
            }

            // Step 1: ask for the next payload item.
            if let Err(e) = ready!(self.signal.poll_ready_unpin(cx)) {
                return Poll::Ready(Some(Err(ParseError::invalid_body(e))));
            }

            let (reply_tx, reply_rx) = oneshot::channel();
            if let Err(e) = self.signal.start_send(reply_tx) {
                return Poll::Ready(Some(Err(ParseError::invalid_body(e))));
            }

            // Loop around to poll the freshly registered reply slot right away.
            self.receiving = Some(reply_rx);
        }
    }
}