khttp/server/mod.rs

1use crate::parser::Request;
2use crate::router::RouteParams;
3use crate::threadpool::{Task, ThreadPool};
4use crate::{
5    BodyReader, Headers, HttpParsingError, HttpPrinter, Method, RequestUri, Router, Status,
6};
7use std::cell::RefCell;
8use std::io::{self, Read};
9use std::mem::MaybeUninit;
10use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
11use std::sync::Arc;
12
13mod builder;
14mod epoll;
15pub use builder::ServerBuilder;
16
17pub type RouteFn = dyn for<'req, 's> Fn(RequestContext<'req>, &mut ResponseHandle<'s>) -> io::Result<()>
18    + Send
19    + Sync;
20
21pub type ConnectionSetupHookFn =
22    dyn Fn(io::Result<(TcpStream, SocketAddr)>) -> ConnectionSetupAction + Send + Sync;
23
24pub type ConnectionTeardownHookFn = dyn Fn(TcpStream, io::Result<()>) + Send + Sync;
25
26pub type PreRoutingHookFn = dyn for<'req, 's> Fn(&mut Request<'req>, &mut ResponseHandle<'s>) -> PreRoutingAction
27    + Send
28    + Sync;
29
/// Request-handling configuration shared (through an `Arc`) with every connection worker.
struct HandlerConfig {
    /// Maps a request method and path to the route handler that serves it.
    router: Router<Box<RouteFn>>,
    /// Optional hook run after the request head is parsed, before routing.
    pre_routing_hook: Option<Box<PreRoutingHookFn>>,
    /// Optional hook given the stream and result once a connection served by
    /// `serve`/`serve_threaded` is finished.
    connection_teardown_hook: Option<Box<ConnectionTeardownHookFn>>,
    /// Size in bytes of the per-thread buffer a request head must fit into; heads that do
    /// not fit are answered with status 431.
    max_request_head: usize,
}
36
/// A configured HTTP server; obtain one through [`Server::builder`].
pub struct Server {
    /// Addresses the listener binds to when serving.
    bind_addrs: Vec<SocketAddr>,
    /// Number of worker threads in the pool used by `serve`.
    thread_count: usize,
    /// Optional hook that decides what to do with each accepted connection.
    connection_setup_hook: Option<Box<ConnectionSetupHookFn>>,
    /// Handler configuration cloned (as an `Arc`) into every connection worker.
    handler_config: Arc<HandlerConfig>,
    // Never read in this file; presumably reserved for the `epoll` backend
    // (TODO confirm), which is why dead-code warnings are silenced here.
    #[allow(dead_code)]
    epoll_queue_max_events: usize,
}
45
/// Decision returned by the connection-setup hook for each result of `accept`.
///
/// Derives `Debug` (public API types should be debuggable; `TcpStream` is `Debug`).
#[derive(Debug)]
pub enum ConnectionSetupAction {
    /// Serve this stream.
    Proceed(TcpStream),
    /// Skip this connection and keep accepting new ones.
    Drop,
    /// Leave the accept loop; `serve` / `serve_threaded` then return `Ok(())`.
    StopAccepting,
}
51
/// What request handling should do after the pre-routing hook has run.
///
/// A plain fieldless enum, so it derives the common value traits (`Debug`, `Copy`, `Eq`, ...)
/// to be easy to log, compare, and pass around.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreRoutingAction {
    /// Continue with route matching and the matched route handler.
    Proceed,
    /// Skip routing for this request (the hook is expected to have dealt with it).
    Drop,
}
56
57impl Server {
58    pub fn builder<A: ToSocketAddrs>(addr: A) -> io::Result<ServerBuilder> {
59        ServerBuilder::new(addr)
60    }
61}
62
63impl Server {
64    pub fn bind_addrs(&self) -> &Vec<SocketAddr> {
65        &self.bind_addrs
66    }
67
68    pub fn threads(&self) -> usize {
69        self.thread_count
70    }
71
72    pub fn serve(self) -> io::Result<()> {
73        struct PoolJob(TcpStream, Arc<HandlerConfig>);
74
75        impl Task for PoolJob {
76            #[inline]
77            fn run(self) {
78                let result = handle_connection(&self.0, &self.1);
79                if let Some(hook) = &self.1.connection_teardown_hook {
80                    (hook)(self.0, result);
81                }
82            }
83        }
84
85        let listener = TcpListener::bind(&*self.bind_addrs)?;
86        let pool: ThreadPool<PoolJob> = ThreadPool::new(self.thread_count);
87
88        loop {
89            let conn = listener.accept();
90
91            let stream = match &self.connection_setup_hook {
92                Some(hook) => match (hook)(conn) {
93                    ConnectionSetupAction::Proceed(stream) => stream,
94                    ConnectionSetupAction::Drop => continue,
95                    ConnectionSetupAction::StopAccepting => break,
96                },
97                None => match conn {
98                    Ok((stream, _)) => stream,
99                    Err(_) => continue,
100                },
101            };
102
103            pool.execute(PoolJob(stream, Arc::clone(&self.handler_config)));
104        }
105        Ok(())
106    }
107
108    pub fn serve_threaded(self) -> io::Result<()> {
109        let listener = TcpListener::bind(&*self.bind_addrs)?;
110
111        loop {
112            let conn = listener.accept();
113
114            let stream = match &self.connection_setup_hook {
115                Some(hook) => match (hook)(conn) {
116                    ConnectionSetupAction::Proceed(stream) => stream,
117                    ConnectionSetupAction::Drop => continue,
118                    ConnectionSetupAction::StopAccepting => break,
119                },
120                None => match conn {
121                    Ok((stream, _)) => stream,
122                    Err(_) => continue,
123                },
124            };
125            let config = Arc::clone(&self.handler_config);
126
127            std::thread::spawn(move || {
128                let result = handle_connection(&stream, &config);
129                if let Some(hook) = &config.connection_teardown_hook {
130                    (hook)(stream, result);
131                }
132            });
133        }
134        Ok(())
135    }
136
137    pub fn handle(&self, stream: &TcpStream) -> io::Result<()> {
138        handle_connection(stream, &self.handler_config)
139    }
140}
141
/// Writes responses to one connection's stream and tracks whether the connection may
/// serve another request.
pub struct ResponseHandle<'s> {
    /// The connection that response bytes are written to.
    stream: &'s TcpStream,
    /// Starts `true`. It is cleared when a response is sent with headers reporting
    /// `is_connection_close()`, and is read after the route returns.
    keep_alive: bool,
}
146
147impl<'s> ResponseHandle<'s> {
148    fn new(stream: &'s TcpStream) -> Self {
149        ResponseHandle {
150            stream,
151            keep_alive: true,
152        }
153    }
154
155    pub fn ok<B: AsRef<[u8]>>(&mut self, headers: &Headers, body: B) -> io::Result<()> {
156        self.send(&Status::OK, headers, body)
157    }
158
159    pub fn send<B: AsRef<[u8]>>(
160        &mut self,
161        status: &Status,
162        headers: &Headers,
163        body: B,
164    ) -> io::Result<()> {
165        if headers.is_connection_close() {
166            self.keep_alive = false;
167        }
168        HttpPrinter::write_response_bytes(self.stream, status, headers, body.as_ref())
169    }
170
171    pub fn ok0(&mut self, headers: &Headers) -> io::Result<()> {
172        self.send0(&Status::OK, headers)
173    }
174
175    pub fn send0(&mut self, status: &Status, headers: &Headers) -> io::Result<()> {
176        if headers.is_connection_close() {
177            self.keep_alive = false;
178        }
179        HttpPrinter::write_response_empty(self.stream, status, headers)
180    }
181
182    pub fn okr<R: Read>(&mut self, headers: &Headers, body: R) -> io::Result<()> {
183        self.sendr(&Status::OK, headers, body)
184    }
185
186    pub fn sendr<R: Read>(
187        &mut self,
188        status: &Status,
189        headers: &Headers,
190        body: R,
191    ) -> io::Result<()> {
192        if headers.is_connection_close() {
193            self.keep_alive = false;
194        }
195        HttpPrinter::write_response(self.stream, status, headers, body)
196    }
197
198    pub fn send_100_continue(&mut self) -> io::Result<()> {
199        HttpPrinter::write_100_continue(self.stream)
200    }
201
202    pub fn send_417_expectation_failed(&mut self) -> io::Result<()> {
203        HttpPrinter::write_417_expectation_failed(self.stream)
204    }
205
206    pub fn get_stream(&self) -> &TcpStream {
207        self.stream
208    }
209}
210
/// Everything a route handler receives about the current request.
pub struct RequestContext<'r> {
    /// Request method.
    pub method: Method,
    /// Parsed request target.
    pub uri: &'r RequestUri<'r>,
    /// Request headers.
    pub headers: Headers<'r>,
    /// Parameters captured by the matched route.
    pub params: &'r RouteParams<'r, 'r>,
    /// HTTP version as reported by the parser (exact encoding not visible here — TODO confirm).
    pub http_version: u8,
    /// Body reader built from the bytes buffered after the head plus the stream itself
    /// (presumably buffered bytes are yielded first — verify against `BodyReader`).
    body: BodyReader<'r, &'r TcpStream>,
}
219
220impl<'r> RequestContext<'r> {
221    pub fn body(&mut self) -> &mut BodyReader<'r, &'r TcpStream> {
222        &mut self.body
223    }
224
225    pub fn get_stream(&self) -> &TcpStream {
226        self.body.inner()
227    }
228
229    pub fn into_parts(
230        self,
231    ) -> (
232        Method,
233        &'r RequestUri<'r>,
234        Headers<'r>,
235        &'r RouteParams<'r, 'r>,
236        u8,
237        BodyReader<'r, &'r TcpStream>,
238    ) {
239        (
240            self.method,
241            self.uri,
242            self.headers,
243            self.params,
244            self.http_version,
245            self.body,
246        )
247    }
248}
249
250fn handle_connection(stream: &TcpStream, config: &Arc<HandlerConfig>) -> io::Result<()> {
251    let mut response = ResponseHandle::new(stream);
252
253    loop {
254        let keep_alive = handle_one_request(stream, &mut response, config)?;
255        if !keep_alive {
256            return Ok(());
257        }
258    }
259}
260
/// Capacity initially reserved for each thread's request buffer (4 KiB).
const DEFAULT_REQUEST_BUFFER_SIZE: usize = 4 * 1024;

thread_local! {
    /// Per-thread scratch buffer that `read_request` reads request heads into.
    /// It starts empty with `DEFAULT_REQUEST_BUFFER_SIZE` bytes reserved.
    static REQUEST_BUFFER: RefCell<Vec<MaybeUninit<u8>>> =
        RefCell::new(Vec::with_capacity(DEFAULT_REQUEST_BUFFER_SIZE));
}
266
267/// Read request head into a thread-local uninitialized buffer and parse it.
268/// Thread-local storage is used since each thread handles exactly one request at once.
269fn read_request<'a>(
270    mut stream: &TcpStream,
271    max_size: usize,
272) -> Result<(&'a [u8], Request<'a>), ReadRequestError> {
273    use std::slice::{from_raw_parts, from_raw_parts_mut};
274    use ReadRequestError::*;
275
276    REQUEST_BUFFER.with(|cell| {
277        let mut vec = cell.borrow_mut();
278
279        if vec.len() != max_size {
280            vec.resize_with(max_size, MaybeUninit::uninit);
281        }
282
283        let ptr = vec.as_mut_ptr() as *mut u8;
284        let mut filled = 0;
285
286        loop {
287            if filled == max_size {
288                return Err(RequestHeadTooLarge);
289            }
290
291            // SAFETY: ptr.add(filled) is within bounds; read() will init this tail region
292            let tail = unsafe { from_raw_parts_mut(ptr.add(filled), max_size - filled) };
293
294            let n = match stream.read(tail) {
295                Ok(0) => return Err(ReadEof),
296                Ok(n) => n,
297                Err(_) => return Err(IOError),
298            };
299            filled += n;
300
301            // SAFETY: only the prefix [..filled] has been written (initialized) by read()
302            let buf = unsafe { from_raw_parts(ptr as *const u8, filled) };
303
304            match Request::parse(buf) {
305                Ok(req) => return Ok((buf, req)),
306                Err(HttpParsingError::UnexpectedEof) => continue, // need more bytes, keep reading
307                Err(_) => return Err(InvalidRequestHead),         // malformed request head
308            }
309        }
310    })
311}
312
/// Why `read_request` could not produce a parsed request head.
///
/// A plain fieldless enum, so it derives the usual value traits for logging and comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadRequestError {
    /// The buffer limit was reached without a complete request head.
    RequestHeadTooLarge,
    /// The parser rejected the received bytes.
    InvalidRequestHead,
    /// The peer closed the connection (a read returned 0) before a full head arrived.
    ReadEof,
    /// Reading from the stream failed.
    IOError,
}
319
320/// Returns "keep-alive" (whether to keep the connection alive for the next request).
321fn handle_one_request(
322    stream: &TcpStream,
323    response: &mut ResponseHandle<'_>,
324    config: &HandlerConfig,
325) -> io::Result<bool> {
326    let (buf, mut request) = match read_request(stream, config.max_request_head) {
327        Ok((buf, req)) => (buf, req),
328        Err(ReadRequestError::InvalidRequestHead) => {
329            response.send0(&Status::BAD_REQUEST, Headers::close())?;
330            return Ok(false);
331        }
332        Err(ReadRequestError::RequestHeadTooLarge) => {
333            response.send0(&Status::of(431), Headers::close())?;
334            return Ok(false);
335        }
336        Err(_) => return Ok(false), // silently drop connection on eof / io-error
337    };
338
339    if let Some(hook) = &config.pre_routing_hook {
340        match (hook)(&mut request, response) {
341            PreRoutingAction::Proceed => {}
342            PreRoutingAction::Drop => return Ok(response.keep_alive),
343        }
344    }
345
346    let matched_route = config
347        .router
348        .match_route(&request.method, request.uri.path());
349
350    let body = BodyReader::from_request(&buf[request.buf_offset..], stream, &request.headers);
351    let ctx = RequestContext {
352        method: request.method,
353        headers: request.headers,
354        uri: &request.uri,
355        http_version: request.http_version,
356        params: &matched_route.params,
357        body,
358    };
359
360    let client_requested_close = ctx.headers.is_connection_close();
361    (matched_route.route)(ctx, response)?;
362    if client_requested_close {
363        return Ok(false);
364    }
365    Ok(response.keep_alive)
366}