micro_http/connection/
http_connection.rs1use std::error::Error;
2use std::fmt::Display;
3
4use bytes::Bytes;
5use std::sync::Arc;
6
7use futures::{SinkExt, StreamExt};
8use http::header::EXPECT;
9use http::{Response, StatusCode};
10use http_body::Body;
11use http_body_util::{BodyExt, Empty};
12use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
13
14use crate::codec::{RequestDecoder, ResponseEncoder};
15use crate::handler::Handler;
16use crate::protocol::body::ReqBody;
17use crate::protocol::{HttpError, Message, ParseError, PayloadItem, PayloadSize, RequestHeader, ResponseHead, SendError};
18
19use tokio_util::codec::{FramedRead, FramedWrite};
20use tracing::{error, info};
21
22pub struct HttpConnection<R, W> {
36 framed_read: FramedRead<R, RequestDecoder>,
37 framed_write: FramedWrite<W, ResponseEncoder>,
38}
39
40impl<R, W> HttpConnection<R, W>
41where
42 R: AsyncRead + Unpin,
43 W: AsyncWrite + Unpin,
44{
45 pub fn new(reader: R, writer: W) -> Self {
46 Self {
47 framed_read: FramedRead::with_capacity(reader, RequestDecoder::new(), 8 * 1024),
48 framed_write: FramedWrite::new(writer, ResponseEncoder::new()),
49 }
50 }
51
52 pub async fn process<H>(mut self, mut handler: Arc<H>) -> Result<(), HttpError>
53 where
54 H: Handler,
55 H::RespBody: Body<Data = Bytes> + Unpin,
56 <H::RespBody as Body>::Error: Display,
57 {
58 loop {
59 match self.framed_read.next().await {
60 Some(Ok(Message::Header((header, payload_size)))) => {
61 self.do_process(header, payload_size, &mut handler).await?;
62 }
63
64 Some(Ok(Message::Payload(PayloadItem::Eof))) => continue,
65
66 Some(Ok(Message::Payload(_))) => {
67 error!("error status because chunked has read in do_process");
68 let error_response = build_error_response(StatusCode::BAD_REQUEST);
69 self.do_send_response(error_response).await?;
70 return Err(ParseError::invalid_body("need header while receive body").into());
71 }
72
73 Some(Err(e)) => {
74 error!("can't receive next request, cause {}", e);
75 let error_response = build_error_response(StatusCode::BAD_REQUEST);
76 self.do_send_response(error_response).await?;
77 return Err(e.into());
78 }
79
80 None => {
81 info!("cant read more request, break this connection down");
82 return Ok(());
83 }
84 }
85 }
86 }
87
88 async fn do_process<H>(&mut self, header: RequestHeader, payload_size: PayloadSize, handler: &mut Arc<H>) -> Result<(), HttpError>
89 where
90 H: Handler,
91 H::RespBody: Body<Data = Bytes> + Unpin,
92 <H::RespBody as Body>::Error: Display,
93 {
94 if let Some(value) = header.headers().get(EXPECT) {
96 let slice = value.as_bytes();
97 if slice.len() >= 4 && &slice[0..4] == b"100-" {
99 let writer = self.framed_write.get_mut();
100 let _ = writer.write(b"HTTP/1.1 100 Continue\r\n\r\n").await.map_err(SendError::io)?;
102 writer.flush().await.map_err(SendError::io)?;
103 info!("receive expect request header, sent continue response");
105 }
106 }
107
108 let (req_body, maybe_body_sender) = ReqBody::create_req_body(&mut self.framed_read, payload_size);
109 let request = header.body(req_body);
110
111 let response_result = match maybe_body_sender {
112 None => handler.call(request).await,
113 Some(mut body_sender) => {
114 let (handler_result, body_send_result) = tokio::join!(handler.call(request), body_sender.start());
115
116 body_send_result?;
118 handler_result
119 }
120 };
121
122 self.send_response(response_result).await
123 }
124
125 async fn send_response<T, E>(&mut self, response_result: Result<Response<T>, E>) -> Result<(), HttpError>
126 where
127 T: Body + Unpin,
128 T::Error: Display,
129 E: Into<Box<dyn Error + Send + Sync>>,
130 {
131 match response_result {
132 Ok(response) => self.do_send_response(response).await,
133 Err(e) => {
134 error!("handle response error, cause: {}", e.into());
135 let error_response = build_error_response(StatusCode::INTERNAL_SERVER_ERROR);
136 self.do_send_response(error_response).await
137 }
138 }
139 }
140
141 async fn do_send_response<T>(&mut self, response: Response<T>) -> Result<(), HttpError>
142 where
143 T: Body + Unpin,
144 T::Error: Display,
145 {
146 let (header_parts, mut body) = response.into_parts();
147
148 let payload_size = {
149 let size_hint = body.size_hint();
150 match size_hint.exact() {
151 Some(0) => PayloadSize::Empty,
152 Some(length) => PayloadSize::Length(length),
153 None => PayloadSize::Chunked,
154 }
155 };
156
157 let header = Message::<_, T::Data>::Header((ResponseHead::from_parts(header_parts, ()), payload_size));
158 if !payload_size.is_empty() {
159 self.framed_write.feed(header).await?;
160 } else {
161 self.framed_write.send(header).await?;
165 }
166
167 loop {
168 match body.frame().await {
169 Some(Ok(frame)) => {
170 let payload_item =
171 frame.into_data().map(PayloadItem::Chunk).map_err(|_e| SendError::invalid_body("resolve body response error"))?;
172
173 self.framed_write
174 .send(Message::Payload(payload_item))
175 .await
176 .map_err(|_e| SendError::invalid_body("can't send response"))?;
177 }
178 Some(Err(e)) => return Err(SendError::invalid_body(format!("resolve response body error: {e}")).into()),
179 None => {
180 self.framed_write
181 .feed(Message::Payload(PayloadItem::<T::Data>::Eof))
183 .await
184 .map_err(|e| SendError::invalid_body(format!("can't send eof response: {}", e)))?;
185 return Ok(());
186 }
187 }
188 }
189 }
190}
191
192fn build_error_response(status_code: StatusCode) -> Response<Empty<Bytes>> {
193 Response::builder().status(status_code).body(Empty::<Bytes>::new()).unwrap()
194}