1use crate::body::message_body::MessageBody;
4use crate::error::{Error, ServerError};
5use crate::handlers::Handlers;
6use crate::request::{HttpPayload, HttpRequest};
7use crate::response::HttpResponse;
8use crate::route::HandlerFn;
9use crate::server::builder::HttpServerBuilder;
10use crate::state::State;
11use http::HeaderMap;
12use std::fmt::Debug;
13use tokio::io;
14use tokio::io::BufReader;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::ToSocketAddrs;
17use tower::layer::util::Identity;
18use tower::{Layer, Service, ServiceBuilder, ServiceExt};
19#[cfg(feature = "trace")]
20use tracing::trace;
21use tracing::{debug, error, info};
22
23pub mod builder;
24mod test;
25
26pub struct HttpServer<L>
30where
31 L: Layer<HandlerFn> + Clone + Send + 'static,
32{
33 listener: tokio::net::TcpListener,
34 handlers: Handlers,
35 app_state: State,
36 service_builder: ServiceBuilder<L>,
37}
38
39impl<L> HttpServer<L>
40where
41 L: Layer<HandlerFn> + Clone + Send + 'static,
42 L::Service: Service<(HttpRequest, HttpPayload), Response = HttpResponse, Error = Error>
43 + Send
44 + 'static,
45 <L::Service as Service<(HttpRequest, HttpPayload)>>::Future: Send + 'static,
46{
47 #[cfg_attr(
48 feature = "trace",
49 tracing::instrument(level = "trace", skip(service_builder))
50 )]
51 pub(crate) async fn new<T: ToSocketAddrs + Debug>(
55 addr: T,
56 handlers: Handlers,
57 app_state: State,
58 service_builder: ServiceBuilder<L>,
59 ) -> io::Result<Self> {
60 let listener = tokio::net::TcpListener::bind(addr).await?;
61
62 #[cfg(feature = "trace")]
63 trace!("Server Bound to {}", listener.local_addr()?);
64
65 Ok(Self {
66 listener,
67 handlers,
68 app_state,
69 service_builder,
70 })
71 }
72
73 pub fn builder<T: ToSocketAddrs + Default + Debug + Clone>() -> HttpServerBuilder<T, Identity> {
75 HttpServerBuilder::<T, Identity>::new()
76 }
77
78 pub async fn serve(self) -> Result<(), ServerError> {
80 info!("Listening on {}", self.listener.local_addr()?);
81 loop {
82 match self.listener.accept().await {
83 Ok((stream, socket)) => {
84 #[cfg(feature = "trace")]
85 trace!("Accepted connection from {}", socket);
86 self.accept_connection(stream, socket)?;
87 }
88 Err(err) => {
89 error!("Failed to accept connection: {}", err);
90 continue;
91 }
92 }
93 }
94 }
95
96 fn accept_connection(
100 &self,
101 stream: tokio::net::TcpStream,
102 socket: std::net::SocketAddr,
103 ) -> Result<(), ServerError> {
104 let handlers = self.handlers.clone();
105 let state = self.app_state.clone();
106 let service_builder = self.service_builder.clone();
107
108 tokio::spawn(async move {
109 if let Err(e) = Self::handle_connection(
110 stream,
111 #[cfg(feature = "trace")]
112 socket,
113 handlers,
114 state,
115 service_builder,
116 )
117 .await
118 {
119 error!("Error handling connection from {}: {:?}", socket, e);
120 }
121 });
122
123 Ok(())
124 }
125
126 #[cfg_attr(feature = "trace", tracing::instrument(level = "trace", skip_all))]
127 async fn handle_connection(
129 stream: tokio::net::TcpStream,
130 #[cfg(feature = "trace")] socket: std::net::SocketAddr,
131 handlers: Handlers,
132 state: State,
133 service_builder: ServiceBuilder<L>,
134 ) -> Result<(), ServerError> {
135 #[cfg(feature = "trace")]
136 trace!("Accepted connection from {}", socket);
137
138 let mut reader = BufReader::new(stream);
139
140 let request_buffer = match Self::read_request(&mut reader).await {
141 Ok(buffer) => buffer,
142 Err(e) => {
143 error!("Failed to read request: {}", e);
144 return Err(e);
145 }
146 };
147
148 let (mut request, payload) = match HttpRequest::from_bytes(&request_buffer) {
149 Ok(req) => req,
150 Err(e) => {
151 error!("Failed to parse request: {}", e);
152 return Err(e);
153 }
154 };
155
156 request.data = state;
157
158 #[cfg(feature = "trace")]
159 trace!("Request: {:?}", request);
160
161 let handler = handlers.get_handler(request.method(), request.uri().path());
162
163 request.params_mut().extend(handler.1.clone());
164
165 let mut service = service_builder.service(handler.handler());
166
167 match service.ready().await {
168 Ok(_) => {}
169 Err(e) => {
170 error!("Failed to construct service: {}", e);
171 return Err(ServerError::ServiceConstructionFailed);
172 }
173 };
174
175 let response = service.call((request, payload)).await.unwrap_or_else(|e| {
176 error!("Failed to process request: {}", e);
177 e.error_response()
178 });
179
180 Self::send_response(reader, response).await
181 }
182
183 #[cfg_attr(feature = "trace", tracing::instrument(level = "trace", skip(reader)))]
184 async fn send_response(
186 reader: BufReader<tokio::net::TcpStream>,
187 mut response: HttpResponse,
188 ) -> Result<(), ServerError> {
189 let content_length = response
190 .body
191 .clone()
192 .try_into_bytes()
193 .unwrap_or_default()
194 .len() as u64;
195
196 Self::insert_content_length(response.headers_mut(), content_length);
197
198 response
199 .headers_mut()
200 .insert("Connection", "close".parse()?);
201
202 let response_bytes = response.to_bytes()?;
203
204 let mut stream = reader.into_inner();
205 stream.write_all(&response_bytes).await?;
206 stream.flush().await?;
207
208 Ok(())
209 }
210
211 fn insert_content_length(headers: &mut HeaderMap, content_length: u64) {
212 headers.insert(
213 "Content-Length",
214 content_length.to_string().parse().unwrap(),
215 );
216 }
217
218 #[cfg_attr(feature = "trace", tracing::instrument(level = "trace", skip(reader)))]
219 async fn read_request(
221 reader: &mut BufReader<tokio::net::TcpStream>,
222 ) -> Result<Vec<u8>, ServerError> {
223 let mut request_buffer = Vec::new();
224 let mut headers_read = false;
225 let mut content_length = 0;
226
227 loop {
228 let mut buf = [0; 1024];
229 let n = reader.read(&mut buf).await?;
230
231 if n == 0 {
232 debug!("Connection closed by the client.");
233 return Err(ServerError::ConnectionClosed);
234 }
235
236 request_buffer.extend_from_slice(&buf[..n]);
237
238 if !headers_read {
239 if let Some(headers_end) = Self::find_headers_end(&request_buffer) {
240 headers_read = true;
241
242 let headers = &request_buffer[..headers_end];
243 let headers_str = String::from_utf8_lossy(headers);
244
245 for line in headers_str.lines() {
246 if line.to_lowercase().starts_with("content-length:") {
247 if let Some(length_str) = line.split(':').nth(1) {
248 content_length = length_str.trim().parse::<usize>().unwrap_or(0);
249 }
250 }
251 }
252
253 let body_bytes_read = request_buffer.len() - headers_end;
254 if body_bytes_read >= content_length {
255 break;
256 }
257 }
258 } else {
259 let total_bytes = request_buffer.len();
260 let headers_end = Self::find_headers_end(&request_buffer).unwrap_or(0);
261 let body_bytes_read = total_bytes - headers_end;
262
263 if body_bytes_read >= content_length {
264 break;
265 }
266 }
267 }
268
269 Ok(request_buffer)
270 }
271
272 #[inline]
273 fn find_headers_end(buffer: &[u8]) -> Option<usize> {
275 buffer
276 .windows(4)
277 .position(|window| window == b"\r\n\r\n")
278 .map(|pos| pos + 4)
279 }
280}