1use crate::parser::Request;
2use crate::router::RouteParams;
3use crate::threadpool::{Task, ThreadPool};
4use crate::{
5 BodyReader, Headers, HttpParsingError, HttpPrinter, Method, RequestUri, Router, Status,
6};
7use std::cell::RefCell;
8use std::io::{self, Read};
9use std::mem::MaybeUninit;
10use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
11use std::sync::Arc;
12
13mod builder;
14mod epoll;
15pub use builder::ServerBuilder;
16
17pub type RouteFn = dyn for<'req, 's> Fn(RequestContext<'req>, &mut ResponseHandle<'s>) -> io::Result<()>
18 + Send
19 + Sync;
20
21pub type ConnectionSetupHookFn =
22 dyn Fn(io::Result<(TcpStream, SocketAddr)>) -> ConnectionSetupAction + Send + Sync;
23
/// Signature of the hook run once a connection has been fully handled.
///
/// The hook takes ownership of the connection's [`TcpStream`] together with the
/// result that connection handling finished with.
pub type ConnectionTeardownHookFn = dyn Fn(TcpStream, io::Result<()>) + Send + Sync;
25
26pub type PreRoutingHookFn = dyn for<'req, 's> Fn(&mut Request<'req>, &mut ResponseHandle<'s>) -> PreRoutingAction
27 + Send
28 + Sync;
29
30struct HandlerConfig {
31 router: Router<Box<RouteFn>>,
32 pre_routing_hook: Option<Box<PreRoutingHookFn>>,
33 connection_teardown_hook: Option<Box<ConnectionTeardownHookFn>>,
34 max_request_head: usize,
35}
36
37pub struct Server {
38 bind_addrs: Vec<SocketAddr>,
39 thread_count: usize,
40 connection_setup_hook: Option<Box<ConnectionSetupHookFn>>,
41 handler_config: Arc<HandlerConfig>,
42 #[allow(dead_code)]
43 epoll_queue_max_events: usize,
44}
45
/// Decision returned by a connection-setup hook for one accept result.
#[derive(Debug)]
pub enum ConnectionSetupAction {
    /// Serve the given stream.
    Proceed(TcpStream),
    /// Skip this accept result and keep accepting.
    Drop,
    /// Leave the accept loop; the serving function then returns `Ok(())`.
    StopAccepting,
}
51
/// Decision returned by a pre-routing hook for one parsed request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreRoutingAction {
    /// Continue to route matching and invoke the matched handler.
    Proceed,
    /// Skip routing for this request; the connection stays open only if the
    /// response handle's keep-alive flag is still set.
    Drop,
}
56
57impl Server {
58 pub fn builder<A: ToSocketAddrs>(addr: A) -> io::Result<ServerBuilder> {
59 ServerBuilder::new(addr)
60 }
61}
62
63impl Server {
64 pub fn bind_addrs(&self) -> &Vec<SocketAddr> {
65 &self.bind_addrs
66 }
67
68 pub fn threads(&self) -> usize {
69 self.thread_count
70 }
71
72 pub fn serve(self) -> io::Result<()> {
73 struct PoolJob(TcpStream, Arc<HandlerConfig>);
74
75 impl Task for PoolJob {
76 #[inline]
77 fn run(self) {
78 let result = handle_connection(&self.0, &self.1);
79 if let Some(hook) = &self.1.connection_teardown_hook {
80 (hook)(self.0, result);
81 }
82 }
83 }
84
85 let listener = TcpListener::bind(&*self.bind_addrs)?;
86 let pool: ThreadPool<PoolJob> = ThreadPool::new(self.thread_count);
87
88 loop {
89 let conn = listener.accept();
90
91 let stream = match &self.connection_setup_hook {
92 Some(hook) => match (hook)(conn) {
93 ConnectionSetupAction::Proceed(stream) => stream,
94 ConnectionSetupAction::Drop => continue,
95 ConnectionSetupAction::StopAccepting => break,
96 },
97 None => match conn {
98 Ok((stream, _)) => stream,
99 Err(_) => continue,
100 },
101 };
102
103 pool.execute(PoolJob(stream, Arc::clone(&self.handler_config)));
104 }
105 Ok(())
106 }
107
108 pub fn serve_threaded(self) -> io::Result<()> {
109 let listener = TcpListener::bind(&*self.bind_addrs)?;
110
111 loop {
112 let conn = listener.accept();
113
114 let stream = match &self.connection_setup_hook {
115 Some(hook) => match (hook)(conn) {
116 ConnectionSetupAction::Proceed(stream) => stream,
117 ConnectionSetupAction::Drop => continue,
118 ConnectionSetupAction::StopAccepting => break,
119 },
120 None => match conn {
121 Ok((stream, _)) => stream,
122 Err(_) => continue,
123 },
124 };
125 let config = Arc::clone(&self.handler_config);
126
127 std::thread::spawn(move || {
128 let result = handle_connection(&stream, &config);
129 if let Some(hook) = &config.connection_teardown_hook {
130 (hook)(stream, result);
131 }
132 });
133 }
134 Ok(())
135 }
136
137 pub fn handle(&self, stream: &TcpStream) -> io::Result<()> {
138 handle_connection(stream, &self.handler_config)
139 }
140}
141
/// Write side of one connection, handed to hooks and route handlers.
#[derive(Debug)]
pub struct ResponseHandle<'s> {
    /// Stream every response is written to.
    stream: &'s TcpStream,
    /// Stays `true` until a response is sent whose headers announce
    /// connection close.
    keep_alive: bool,
}
146
147impl<'s> ResponseHandle<'s> {
148 fn new(stream: &'s TcpStream) -> Self {
149 ResponseHandle {
150 stream,
151 keep_alive: true,
152 }
153 }
154
155 pub fn ok<B: AsRef<[u8]>>(&mut self, headers: &Headers, body: B) -> io::Result<()> {
156 self.send(&Status::OK, headers, body)
157 }
158
159 pub fn send<B: AsRef<[u8]>>(
160 &mut self,
161 status: &Status,
162 headers: &Headers,
163 body: B,
164 ) -> io::Result<()> {
165 if headers.is_connection_close() {
166 self.keep_alive = false;
167 }
168 HttpPrinter::write_response_bytes(self.stream, status, headers, body.as_ref())
169 }
170
171 pub fn ok0(&mut self, headers: &Headers) -> io::Result<()> {
172 self.send0(&Status::OK, headers)
173 }
174
175 pub fn send0(&mut self, status: &Status, headers: &Headers) -> io::Result<()> {
176 if headers.is_connection_close() {
177 self.keep_alive = false;
178 }
179 HttpPrinter::write_response_empty(self.stream, status, headers)
180 }
181
182 pub fn okr<R: Read>(&mut self, headers: &Headers, body: R) -> io::Result<()> {
183 self.sendr(&Status::OK, headers, body)
184 }
185
186 pub fn sendr<R: Read>(
187 &mut self,
188 status: &Status,
189 headers: &Headers,
190 body: R,
191 ) -> io::Result<()> {
192 if headers.is_connection_close() {
193 self.keep_alive = false;
194 }
195 HttpPrinter::write_response(self.stream, status, headers, body)
196 }
197
198 pub fn send_100_continue(&mut self) -> io::Result<()> {
199 HttpPrinter::write_100_continue(self.stream)
200 }
201
202 pub fn send_417_expectation_failed(&mut self) -> io::Result<()> {
203 HttpPrinter::write_417_expectation_failed(self.stream)
204 }
205
206 pub fn get_stream(&self) -> &TcpStream {
207 self.stream
208 }
209}
210
211pub struct RequestContext<'r> {
212 pub method: Method,
213 pub uri: &'r RequestUri<'r>,
214 pub headers: Headers<'r>,
215 pub params: &'r RouteParams<'r, 'r>,
216 pub http_version: u8,
217 body: BodyReader<'r, &'r TcpStream>,
218}
219
220impl<'r> RequestContext<'r> {
221 pub fn body(&mut self) -> &mut BodyReader<'r, &'r TcpStream> {
222 &mut self.body
223 }
224
225 pub fn get_stream(&self) -> &TcpStream {
226 self.body.inner()
227 }
228
229 pub fn into_parts(
230 self,
231 ) -> (
232 Method,
233 &'r RequestUri<'r>,
234 Headers<'r>,
235 &'r RouteParams<'r, 'r>,
236 u8,
237 BodyReader<'r, &'r TcpStream>,
238 ) {
239 (
240 self.method,
241 self.uri,
242 self.headers,
243 self.params,
244 self.http_version,
245 self.body,
246 )
247 }
248}
249
250fn handle_connection(stream: &TcpStream, config: &Arc<HandlerConfig>) -> io::Result<()> {
251 let mut response = ResponseHandle::new(stream);
252
253 loop {
254 let keep_alive = handle_one_request(stream, &mut response, config)?;
255 if !keep_alive {
256 return Ok(());
257 }
258 }
259}
260
/// Initial capacity, in bytes, reserved for each thread's request buffer.
const DEFAULT_REQUEST_BUFFER_SIZE: usize = 4096;

thread_local! {
    /// Per-thread scratch buffer that request heads are read into.
    ///
    /// It starts empty (capacity only); `read_request` resizes it to the
    /// configured maximum head size before reading.
    static REQUEST_BUFFER: RefCell<Vec<MaybeUninit<u8>>> =
        RefCell::from(Vec::with_capacity(DEFAULT_REQUEST_BUFFER_SIZE));
}
266
267fn read_request<'a>(
270 mut stream: &TcpStream,
271 max_size: usize,
272) -> Result<(&'a [u8], Request<'a>), ReadRequestError> {
273 use std::slice::{from_raw_parts, from_raw_parts_mut};
274 use ReadRequestError::*;
275
276 REQUEST_BUFFER.with(|cell| {
277 let mut vec = cell.borrow_mut();
278
279 if vec.len() != max_size {
280 vec.resize_with(max_size, MaybeUninit::uninit);
281 }
282
283 let ptr = vec.as_mut_ptr() as *mut u8;
284 let mut filled = 0;
285
286 loop {
287 if filled == max_size {
288 return Err(RequestHeadTooLarge);
289 }
290
291 let tail = unsafe { from_raw_parts_mut(ptr.add(filled), max_size - filled) };
293
294 let n = match stream.read(tail) {
295 Ok(0) => return Err(ReadEof),
296 Ok(n) => n,
297 Err(_) => return Err(IOError),
298 };
299 filled += n;
300
301 let buf = unsafe { from_raw_parts(ptr as *const u8, filled) };
303
304 match Request::parse(buf) {
305 Ok(req) => return Ok((buf, req)),
306 Err(HttpParsingError::UnexpectedEof) => continue, Err(_) => return Err(InvalidRequestHead), }
309 }
310 })
311}
312
/// Reasons `read_request` can fail to produce a parsed request head.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadRequestError {
    /// The head did not fit into the configured maximum number of bytes.
    RequestHeadTooLarge,
    /// The buffered bytes could not be parsed as a request head.
    InvalidRequestHead,
    /// The stream reported end-of-file before a complete head was read.
    ReadEof,
    /// Reading from the stream failed.
    IOError,
}
319
320fn handle_one_request(
322 stream: &TcpStream,
323 response: &mut ResponseHandle<'_>,
324 config: &HandlerConfig,
325) -> io::Result<bool> {
326 let (buf, mut request) = match read_request(stream, config.max_request_head) {
327 Ok((buf, req)) => (buf, req),
328 Err(ReadRequestError::InvalidRequestHead) => {
329 response.send0(&Status::BAD_REQUEST, Headers::close())?;
330 return Ok(false);
331 }
332 Err(ReadRequestError::RequestHeadTooLarge) => {
333 response.send0(&Status::of(431), Headers::close())?;
334 return Ok(false);
335 }
336 Err(_) => return Ok(false), };
338
339 if let Some(hook) = &config.pre_routing_hook {
340 match (hook)(&mut request, response) {
341 PreRoutingAction::Proceed => {}
342 PreRoutingAction::Drop => return Ok(response.keep_alive),
343 }
344 }
345
346 let matched_route = config
347 .router
348 .match_route(&request.method, request.uri.path());
349
350 let body = BodyReader::from_request(&buf[request.buf_offset..], stream, &request.headers);
351 let ctx = RequestContext {
352 method: request.method,
353 headers: request.headers,
354 uri: &request.uri,
355 http_version: request.http_version,
356 params: &matched_route.params,
357 body,
358 };
359
360 let client_requested_close = ctx.headers.is_connection_close();
361 (matched_route.route)(ctx, response)?;
362 if client_requested_close {
363 return Ok(false);
364 }
365 Ok(response.keep_alive)
366}