micro_http/connection/
http_connection.rs1use bytes::Bytes;
2use std::error::Error;
3use std::fmt::{Debug, Display};
4use std::sync::Arc;
5
6use futures::StreamExt;
7use http::header::EXPECT;
8use http::{Response, StatusCode};
9use http_body::Body;
10use http_body_util::{BodyExt, Empty};
11use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
12
13use crate::codec::RequestDecoder;
14use crate::handler::Handler;
15use crate::protocol::body::ReqBody;
16use crate::protocol::{HttpError, Message, ParseError, PayloadItem, PayloadSize, RequestHeader, ResponseHead, SendError};
17
18use crate::connection::message_writer::MessageWriter;
19use tokio_util::codec::FramedRead;
20use tracing::{error, info};
21
22#[derive(Debug)]
36pub struct HttpConnection<R, W>
37where
38 R: Debug,
39 W: Debug,
40{
41 framed_read: FramedRead<R, RequestDecoder>,
42 message_writer: MessageWriter<W>,
43}
44
45impl<R, W> HttpConnection<R, W>
46where
47 R: AsyncRead + Unpin + Debug,
48 W: AsyncWrite + Unpin + Debug,
49{
50 pub fn new(reader: R, writer: W) -> Self {
51 Self {
52 framed_read: FramedRead::with_capacity(reader, RequestDecoder::new(), 8 * 1024),
53 message_writer: MessageWriter::with_capacity(writer, 8 * 1024),
54 }
55 }
56
57 pub async fn process<H>(mut self, mut handler: Arc<H>) -> Result<(), HttpError>
58 where
59 H: Handler,
60 H::RespBody: Body<Data = Bytes> + Unpin,
61 <H::RespBody as Body>::Error: Display,
62 {
63 loop {
64 match self.framed_read.next().await {
65 Some(Ok(Message::Header((header, payload_size)))) => {
66 self.do_process(header, payload_size, &mut handler).await?;
67 }
68
69 Some(Ok(Message::Payload(PayloadItem::Eof))) => continue,
70
71 Some(Ok(Message::Payload(_))) => {
72 error!("error status because chunked has read in do_process");
73 let error_response = build_error_response(StatusCode::BAD_REQUEST);
74 self.do_send_response(error_response).await?;
75 return Err(ParseError::invalid_body("need header while receive body").into());
76 }
77
78 Some(Err(ParseError::Io { source })) => {
79 info!("connection io error: {}", source);
80 return Ok(());
81 }
82
83 Some(Err(e)) => {
84 error!("can't receive next request, cause {}", e);
85 return Err(e.into());
86 }
87
88 None => {
89 info!("can't read more request, break this connection down");
90 return Ok(());
91 }
92 }
93 }
94 }
95
96 async fn do_process<H>(&mut self, header: RequestHeader, payload_size: PayloadSize, handler: &mut Arc<H>) -> Result<(), HttpError>
97 where
98 H: Handler,
99 H::RespBody: Body<Data = Bytes> + Unpin,
100 <H::RespBody as Body>::Error: Display,
101 {
102 if let Some(value) = header.headers().get(EXPECT) {
104 let slice = value.as_bytes();
105 if slice.len() >= 4 && &slice[0..4] == b"100-" {
107 let writer = self.message_writer.get_mut();
108 let _ = writer.write(b"HTTP/1.1 100 Continue\r\n\r\n").await.map_err(SendError::io)?;
110 writer.flush().await.map_err(SendError::io)?;
111 info!("receive expect request header, sent continue response");
113 }
114 }
115
116 let (req_body, maybe_body_sender) = ReqBody::create_req_body(&mut self.framed_read, payload_size);
117 let request = header.body(req_body);
118
119 let response_result = match maybe_body_sender {
120 None => handler.call(request).await,
121 Some(mut body_sender) => {
122 let (handler_result, body_send_result) = tokio::join!(handler.call(request), body_sender.start());
123
124 body_send_result?;
126 handler_result
127 }
128 };
129
130 self.send_response(response_result).await
131 }
132
133 async fn send_response<T, E>(&mut self, response_result: Result<Response<T>, E>) -> Result<(), HttpError>
134 where
135 T: Body + Unpin,
136 T::Error: Display,
137 E: Into<Box<dyn Error + Send + Sync>>,
138 {
139 match response_result {
140 Ok(response) => self.do_send_response(response).await,
141 Err(e) => {
142 error!("handle response error, cause: {}", e.into());
143 let error_response = build_error_response(StatusCode::INTERNAL_SERVER_ERROR);
144 self.do_send_response(error_response).await
145 }
146 }
147 }
148
149 async fn do_send_response<T>(&mut self, response: Response<T>) -> Result<(), HttpError>
150 where
151 T: Body + Unpin,
152 T::Error: Display,
153 {
154 let (header_parts, mut body) = response.into_parts();
155
156 let payload_size = {
157 let size_hint = body.size_hint();
158 match size_hint.exact() {
159 Some(0) => PayloadSize::Empty,
160 Some(length) => PayloadSize::Length(length),
161 None => PayloadSize::Chunked,
162 }
163 };
164
165 let header = Message::<_, T::Data>::Header((ResponseHead::from_parts(header_parts, ()), payload_size));
166
167 self.message_writer.write(header)?;
168
169 loop {
170 match body.frame().await {
171 Some(Ok(frame)) => {
172 let payload_item =
173 frame.into_data().map(PayloadItem::Chunk).map_err(|_e| SendError::invalid_body("resolve body response error"))?;
174
175 self.message_writer
176 .write(Message::Payload(payload_item))
177 .map_err(|_e| SendError::invalid_body("can't send response"))?;
178 }
179 Some(Err(e)) => return Err(SendError::invalid_body(format!("resolve response body error: {e}")).into()),
180 None => {
181 self.message_writer
182 .write(Message::Payload(PayloadItem::<T::Data>::Eof))
184 .map_err(|e| SendError::invalid_body(format!("can't send eof response: {}", e)))?;
185 self.message_writer.flush().await?;
186 self.message_writer.clear_buf();
187 return Ok(());
188 }
189 }
190 }
191 }
192}
193
194fn build_error_response(status_code: StatusCode) -> Response<Empty<Bytes>> {
195 Response::builder().status(status_code).body(Empty::<Bytes>::new()).unwrap()
196}