// micro_http/connection/http_connection.rs

1use std::error::Error;
2use std::fmt::Display;
3use bytes::Bytes;
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use http::header::EXPECT;
8use http::{Response, StatusCode};
9use http_body::Body;
10use http_body_util::{BodyExt, Empty};
11use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
12
13use crate::codec::{RequestDecoder, ResponseEncoder};
14use crate::handler::Handler;
15use crate::protocol::body::ReqBody;
16use crate::protocol::{HttpError, Message, ParseError, PayloadItem, PayloadSize, RequestHeader, ResponseHead, SendError};
17
18use tokio_util::codec::{FramedRead, FramedWrite};
19use tracing::{error, info};
20
21/// An HTTP connection that manages request processing and response streaming
22///
23/// `HttpConnection` handles the full lifecycle of an HTTP connection, including:
24/// - Reading and decoding requests
25/// - Processing request headers and bodies
26/// - Handling expect-continue mechanism
27/// - Streaming responses back to clients
28///
29/// # Type Parameters
30///
31/// * `R`: The async readable stream type
32/// * `W`: The async writable stream type
33///
34pub struct HttpConnection<R, W> {
35    framed_read: FramedRead<R, RequestDecoder>,
36    framed_write: FramedWrite<W, ResponseEncoder>,
37}
38
39impl<R, W> HttpConnection<R, W>
40where
41    R: AsyncRead + Unpin,
42    W: AsyncWrite + Unpin,
43{
44    pub fn new(reader: R, writer: W) -> Self {
45        Self {
46            framed_read: FramedRead::with_capacity(reader, RequestDecoder::new(), 8 * 1024),
47            framed_write: FramedWrite::new(writer, ResponseEncoder::new()),
48        }
49    }
50
51    pub async fn process<H>(mut self, mut handler: Arc<H>) -> Result<(), HttpError>
52    where
53        H: Handler,
54        H::RespBody: Body<Data = Bytes> + Unpin,
55        <H::RespBody as Body>::Error: Display,
56    {
57        loop {
58            match self.framed_read.next().await {
59                Some(Ok(Message::Header((header, payload_size)))) => {
60                    self.do_process(header, payload_size, &mut handler).await?;
61                }
62
63                Some(Ok(Message::Payload(PayloadItem::Eof))) => continue,
64
65                Some(Ok(Message::Payload(_))) => {
66                    error!("error status because chunked has read in do_process");
67                    let error_response = build_error_response(StatusCode::BAD_REQUEST);
68                    self.do_send_response(error_response).await?;
69                    return Err(ParseError::invalid_body("need header while receive body").into());
70                }
71
72                Some(Err(ParseError::Io { source})) => {
73                    info!("connection io error: {}", source);
74                    return Ok(());
75                }
76
77                Some(Err(e)) => {
78                    error!("can't receive next request, cause {}", e);
79                    return Err(e.into());
80                }
81
82                None => {
83                    info!("can't read more request, break this connection down");
84                    return Ok(());
85                }
86            }
87        }
88    }
89
90    async fn do_process<H>(&mut self, header: RequestHeader, payload_size: PayloadSize, handler: &mut Arc<H>) -> Result<(), HttpError>
91    where
92        H: Handler,
93        H::RespBody: Body<Data = Bytes> + Unpin,
94        <H::RespBody as Body>::Error: Display,
95    {
96        // Check if the request header contains the "Expect: 100-continue" field.
97        if let Some(value) = header.headers().get(EXPECT) {
98            let slice = value.as_bytes();
99            // Verify if the value of the "Expect" field is "100-continue".
100            if slice.len() >= 4 && &slice[0..4] == b"100-" {
101                let writer = self.framed_write.get_mut();
102                // Send a "100 Continue" response to the client.
103                let _ = writer.write(b"HTTP/1.1 100 Continue\r\n\r\n").await.map_err(SendError::io)?;
104                writer.flush().await.map_err(SendError::io)?;
105                // Log the event of sending a "100 Continue" response.
106                info!("receive expect request header, sent continue response");
107            }
108        }
109
110        let (req_body, maybe_body_sender) = ReqBody::create_req_body(&mut self.framed_read, payload_size);
111        let request = header.body(req_body);
112
113        let response_result = match maybe_body_sender {
114            None => handler.call(request).await,
115            Some(mut body_sender) => {
116                let (handler_result, body_send_result) = tokio::join!(handler.call(request), body_sender.start());
117
118                // check if body sender has error
119                body_send_result?;
120                handler_result
121            }
122        };
123
124        self.send_response(response_result).await
125    }
126
127    async fn send_response<T, E>(&mut self, response_result: Result<Response<T>, E>) -> Result<(), HttpError>
128    where
129        T: Body + Unpin,
130        T::Error: Display,
131        E: Into<Box<dyn Error + Send + Sync>>,
132    {
133        match response_result {
134            Ok(response) => self.do_send_response(response).await,
135            Err(e) => {
136                error!("handle response error, cause: {}", e.into());
137                let error_response = build_error_response(StatusCode::INTERNAL_SERVER_ERROR);
138                self.do_send_response(error_response).await
139            }
140        }
141    }
142
143    async fn do_send_response<T>(&mut self, response: Response<T>) -> Result<(), HttpError>
144    where
145        T: Body + Unpin,
146        T::Error: Display,
147    {
148        let (header_parts, mut body) = response.into_parts();
149
150        let payload_size = {
151            let size_hint = body.size_hint();
152            match size_hint.exact() {
153                Some(0) => PayloadSize::Empty,
154                Some(length) => PayloadSize::Length(length),
155                None => PayloadSize::Chunked,
156            }
157        };
158
159        let header = Message::<_, T::Data>::Header((ResponseHead::from_parts(header_parts, ()), payload_size));
160        if !payload_size.is_empty() {
161            self.framed_write.feed(header).await?;
162        } else {
163            // using send instead of feed, because we want to flush the underlying IO
164            // when response only has header, we need to send header,
165            // otherwise, we just feed header to the buffer
166            self.framed_write.send(header).await?;
167        }
168
169        loop {
170            match body.frame().await {
171                Some(Ok(frame)) => {
172                    let payload_item =
173                        frame.into_data().map(PayloadItem::Chunk).map_err(|_e| SendError::invalid_body("resolve body response error"))?;
174
175                    self.framed_write
176                        .send(Message::Payload(payload_item))
177                        .await
178                        .map_err(|_e| SendError::invalid_body("can't send response"))?;
179                }
180                Some(Err(e)) => return Err(SendError::invalid_body(format!("resolve response body error: {e}")).into()),
181                None => {
182                    self.framed_write
183                        // using feed instead of send, because we don't want to flush the underlying IO
184                        .feed(Message::Payload(PayloadItem::<T::Data>::Eof))
185                        .await
186                        .map_err(|e| SendError::invalid_body(format!("can't send eof response: {}", e)))?;
187                    return Ok(());
188                }
189            }
190        }
191    }
192}
193
194fn build_error_response(status_code: StatusCode) -> Response<Empty<Bytes>> {
195    Response::builder().status(status_code).body(Empty::<Bytes>::new()).unwrap()
196}