// micro_http/connection/http_connection.rs

use std::error::Error;
use std::fmt::Display;
use std::sync::Arc;

use bytes::Bytes;
use futures::{SinkExt, StreamExt};
use http::header::EXPECT;
use http::{Response, StatusCode};
use http_body::Body;
use http_body_util::{BodyExt, Empty};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::{FramedRead, FramedWrite};
use tracing::{error, info};

use crate::codec::{RequestDecoder, ResponseEncoder};
use crate::handler::Handler;
use crate::protocol::body::ReqBody;
use crate::protocol::{HttpError, Message, ParseError, PayloadItem, PayloadSize, RequestHeader, ResponseHead, SendError};

22/// An HTTP connection that manages request processing and response streaming
23///
24/// `HttpConnection` handles the full lifecycle of an HTTP connection, including:
25/// - Reading and decoding requests
26/// - Processing request headers and bodies
27/// - Handling expect-continue mechanism
28/// - Streaming responses back to clients
29///
30/// # Type Parameters
31///
32/// * `R`: The async readable stream type
33/// * `W`: The async writable stream type
34///
35pub struct HttpConnection<R, W> {
36    framed_read: FramedRead<R, RequestDecoder>,
37    framed_write: FramedWrite<W, ResponseEncoder>,
38}
39
40impl<R, W> HttpConnection<R, W>
41where
42    R: AsyncRead + Unpin,
43    W: AsyncWrite + Unpin,
44{
45    pub fn new(reader: R, writer: W) -> Self {
46        Self {
47            framed_read: FramedRead::with_capacity(reader, RequestDecoder::new(), 8 * 1024),
48            framed_write: FramedWrite::new(writer, ResponseEncoder::new()),
49        }
50    }
51
52    pub async fn process<H>(mut self, mut handler: Arc<H>) -> Result<(), HttpError>
53    where
54        H: Handler,
55        H::RespBody: Body<Data = Bytes> + Unpin,
56        <H::RespBody as Body>::Error: Display,
57    {
58        loop {
59            match self.framed_read.next().await {
60                Some(Ok(Message::Header((header, payload_size)))) => {
61                    self.do_process(header, payload_size, &mut handler).await?;
62                }
63
64                Some(Ok(Message::Payload(PayloadItem::Eof))) => continue,
65
66                Some(Ok(Message::Payload(_))) => {
67                    error!("error status because chunked has read in do_process");
68                    let error_response = build_error_response(StatusCode::BAD_REQUEST);
69                    self.do_send_response(error_response).await?;
70                    return Err(ParseError::invalid_body("need header while receive body").into());
71                }
72
73                Some(Err(e)) => {
74                    error!("can't receive next request, cause {}", e);
75                    let error_response = build_error_response(StatusCode::BAD_REQUEST);
76                    self.do_send_response(error_response).await?;
77                    return Err(e.into());
78                }
79
80                None => {
81                    info!("cant read more request, break this connection down");
82                    return Ok(());
83                }
84            }
85        }
86    }
87
88    async fn do_process<H>(&mut self, header: RequestHeader, payload_size: PayloadSize, handler: &mut Arc<H>) -> Result<(), HttpError>
89    where
90        H: Handler,
91        H::RespBody: Body<Data = Bytes> + Unpin,
92        <H::RespBody as Body>::Error: Display,
93    {
94        // Check if the request header contains the "Expect: 100-continue" field.
95        if let Some(value) = header.headers().get(EXPECT) {
96            let slice = value.as_bytes();
97            // Verify if the value of the "Expect" field is "100-continue".
98            if slice.len() >= 4 && &slice[0..4] == b"100-" {
99                let writer = self.framed_write.get_mut();
100                // Send a "100 Continue" response to the client.
101                let _ = writer.write(b"HTTP/1.1 100 Continue\r\n\r\n").await.map_err(SendError::io)?;
102                writer.flush().await.map_err(SendError::io)?;
103                // Log the event of sending a "100 Continue" response.
104                info!("receive expect request header, sent continue response");
105            }
106        }
107
108        let (req_body, maybe_body_sender) = ReqBody::create_req_body(&mut self.framed_read, payload_size);
109        let request = header.body(req_body);
110
111        let response_result = match maybe_body_sender {
112            None => handler.call(request).await,
113            Some(mut body_sender) => {
114                let (handler_result, body_send_result) = tokio::join!(handler.call(request), body_sender.start());
115
116                // check if body sender has error
117                body_send_result?;
118                handler_result
119            }
120        };
121
122        self.send_response(response_result).await
123    }
124
125    async fn send_response<T, E>(&mut self, response_result: Result<Response<T>, E>) -> Result<(), HttpError>
126    where
127        T: Body + Unpin,
128        T::Error: Display,
129        E: Into<Box<dyn Error + Send + Sync>>,
130    {
131        match response_result {
132            Ok(response) => self.do_send_response(response).await,
133            Err(e) => {
134                error!("handle response error, cause: {}", e.into());
135                let error_response = build_error_response(StatusCode::INTERNAL_SERVER_ERROR);
136                self.do_send_response(error_response).await
137            }
138        }
139    }
140
141    async fn do_send_response<T>(&mut self, response: Response<T>) -> Result<(), HttpError>
142    where
143        T: Body + Unpin,
144        T::Error: Display,
145    {
146        let (header_parts, mut body) = response.into_parts();
147
148        let payload_size = {
149            let size_hint = body.size_hint();
150            match size_hint.exact() {
151                Some(0) => PayloadSize::Empty,
152                Some(length) => PayloadSize::Length(length),
153                None => PayloadSize::Chunked,
154            }
155        };
156
157        let header = Message::<_, T::Data>::Header((ResponseHead::from_parts(header_parts, ()), payload_size));
158        if !payload_size.is_empty() {
159            self.framed_write.feed(header).await?;
160        } else {
161            // using send instead of feed, because we want to flush the underlying IO
162            // when response only has header, we need to send header,
163            // otherwise, we just feed header to the buffer
164            self.framed_write.send(header).await?;
165        }
166
167        loop {
168            match body.frame().await {
169                Some(Ok(frame)) => {
170                    let payload_item =
171                        frame.into_data().map(PayloadItem::Chunk).map_err(|_e| SendError::invalid_body("resolve body response error"))?;
172
173                    self.framed_write
174                        .send(Message::Payload(payload_item))
175                        .await
176                        .map_err(|_e| SendError::invalid_body("can't send response"))?;
177                }
178                Some(Err(e)) => return Err(SendError::invalid_body(format!("resolve response body error: {e}")).into()),
179                None => {
180                    self.framed_write
181                        // using feed instead of send, because we don't want to flush the underlying IO
182                        .feed(Message::Payload(PayloadItem::<T::Data>::Eof))
183                        .await
184                        .map_err(|e| SendError::invalid_body(format!("can't send eof response: {}", e)))?;
185                    return Ok(());
186                }
187            }
188        }
189    }
190}
191
192fn build_error_response(status_code: StatusCode) -> Response<Empty<Bytes>> {
193    Response::builder().status(status_code).body(Empty::<Bytes>::new()).unwrap()
194}