// micro_http/connection/http_connection.rs

1use bytes::Bytes;
2use std::error::Error;
3use std::fmt::{Debug, Display};
4use std::sync::Arc;
5
6use futures::StreamExt;
7use http::header::EXPECT;
8use http::{Response, StatusCode};
9use http_body::Body;
10use http_body_util::{BodyExt, Empty};
11use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
12
13use crate::codec::RequestDecoder;
14use crate::handler::Handler;
15use crate::protocol::body::ReqBody;
16use crate::protocol::{HttpError, Message, ParseError, PayloadItem, PayloadSize, RequestHeader, ResponseHead, SendError};
17
18use crate::connection::message_writer::MessageWriter;
19use tokio_util::codec::FramedRead;
20use tracing::{error, info};
21
22/// An HTTP connection that manages request processing and response streaming
23///
24/// `HttpConnection` handles the full lifecycle of an HTTP connection, including:
25/// - Reading and decoding requests
26/// - Processing request headers and bodies
27/// - Handling expect-continue mechanism
28/// - Streaming responses back to clients
29///
30/// # Type Parameters
31///
32/// * `R`: The async readable stream type
33/// * `W`: The async writable stream type
34///
35#[derive(Debug)]
36pub struct HttpConnection<R, W>
37where
38    R: Debug,
39    W: Debug,
40{
41    framed_read: FramedRead<R, RequestDecoder>,
42    message_writer: MessageWriter<W>,
43}
44
45impl<R, W> HttpConnection<R, W>
46where
47    R: AsyncRead + Unpin + Debug,
48    W: AsyncWrite + Unpin + Debug,
49{
50    pub fn new(reader: R, writer: W) -> Self {
51        Self {
52            framed_read: FramedRead::with_capacity(reader, RequestDecoder::new(), 8 * 1024),
53            message_writer: MessageWriter::with_capacity(writer, 8 * 1024),
54        }
55    }
56
57    pub async fn process<H>(mut self, mut handler: Arc<H>) -> Result<(), HttpError>
58    where
59        H: Handler,
60        H::RespBody: Body<Data = Bytes> + Unpin,
61        <H::RespBody as Body>::Error: Display,
62    {
63        loop {
64            match self.framed_read.next().await {
65                Some(Ok(Message::Header((header, payload_size)))) => {
66                    self.do_process(header, payload_size, &mut handler).await?;
67                }
68
69                Some(Ok(Message::Payload(PayloadItem::Eof))) => continue,
70
71                Some(Ok(Message::Payload(_))) => {
72                    error!("error status because chunked has read in do_process");
73                    let error_response = build_error_response(StatusCode::BAD_REQUEST);
74                    self.do_send_response(error_response).await?;
75                    return Err(ParseError::invalid_body("need header while receive body").into());
76                }
77
78                Some(Err(ParseError::Io { source })) => {
79                    info!("connection io error: {}", source);
80                    return Ok(());
81                }
82
83                Some(Err(e)) => {
84                    error!("can't receive next request, cause {}", e);
85                    return Err(e.into());
86                }
87
88                None => {
89                    info!("can't read more request, break this connection down");
90                    return Ok(());
91                }
92            }
93        }
94    }
95
96    async fn do_process<H>(&mut self, header: RequestHeader, payload_size: PayloadSize, handler: &mut Arc<H>) -> Result<(), HttpError>
97    where
98        H: Handler,
99        H::RespBody: Body<Data = Bytes> + Unpin,
100        <H::RespBody as Body>::Error: Display,
101    {
102        // Check if the request header contains the "Expect: 100-continue" field.
103        if let Some(value) = header.headers().get(EXPECT) {
104            let slice = value.as_bytes();
105            // Verify if the value of the "Expect" field is "100-continue".
106            if slice.len() >= 4 && &slice[0..4] == b"100-" {
107                let writer = self.message_writer.get_mut();
108                // Send a "100 Continue" response to the client.
109                let _ = writer.write(b"HTTP/1.1 100 Continue\r\n\r\n").await.map_err(SendError::io)?;
110                writer.flush().await.map_err(SendError::io)?;
111                // Log the event of sending a "100 Continue" response.
112                info!("receive expect request header, sent continue response");
113            }
114        }
115
116        let (req_body, maybe_body_sender) = ReqBody::create_req_body(&mut self.framed_read, payload_size);
117        let request = header.body(req_body);
118
119        let response_result = match maybe_body_sender {
120            None => handler.call(request).await,
121            Some(mut body_sender) => {
122                let (handler_result, body_send_result) = tokio::join!(handler.call(request), body_sender.start());
123
124                // check if body sender has error
125                body_send_result?;
126                handler_result
127            }
128        };
129
130        self.send_response(response_result).await
131    }
132
133    async fn send_response<T, E>(&mut self, response_result: Result<Response<T>, E>) -> Result<(), HttpError>
134    where
135        T: Body + Unpin,
136        T::Error: Display,
137        E: Into<Box<dyn Error + Send + Sync>>,
138    {
139        match response_result {
140            Ok(response) => self.do_send_response(response).await,
141            Err(e) => {
142                error!("handle response error, cause: {}", e.into());
143                let error_response = build_error_response(StatusCode::INTERNAL_SERVER_ERROR);
144                self.do_send_response(error_response).await
145            }
146        }
147    }
148
149    async fn do_send_response<T>(&mut self, response: Response<T>) -> Result<(), HttpError>
150    where
151        T: Body + Unpin,
152        T::Error: Display,
153    {
154        let (header_parts, mut body) = response.into_parts();
155
156        let payload_size = {
157            let size_hint = body.size_hint();
158            match size_hint.exact() {
159                Some(0) => PayloadSize::Empty,
160                Some(length) => PayloadSize::Length(length),
161                None => PayloadSize::Chunked,
162            }
163        };
164
165        let header = Message::<_, T::Data>::Header((ResponseHead::from_parts(header_parts, ()), payload_size));
166
167        self.message_writer.write(header)?;
168
169        loop {
170            match body.frame().await {
171                Some(Ok(frame)) => {
172                    let payload_item =
173                        frame.into_data().map(PayloadItem::Chunk).map_err(|_e| SendError::invalid_body("resolve body response error"))?;
174
175                    self.message_writer
176                        .write(Message::Payload(payload_item))
177                        .map_err(|_e| SendError::invalid_body("can't send response"))?;
178                }
179                Some(Err(e)) => return Err(SendError::invalid_body(format!("resolve response body error: {e}")).into()),
180                None => {
181                    self.message_writer
182                        // using feed instead of send, because we don't want to flush the underlying IO
183                        .write(Message::Payload(PayloadItem::<T::Data>::Eof))
184                        .map_err(|e| SendError::invalid_body(format!("can't send eof response: {}", e)))?;
185                    self.message_writer.flush().await?;
186                    self.message_writer.clear_buf();
187                    return Ok(());
188                }
189            }
190        }
191    }
192}
193
194fn build_error_response(status_code: StatusCode) -> Response<Empty<Bytes>> {
195    Response::builder().status(status_code).body(Empty::<Bytes>::new()).unwrap()
196}