oxhttp/
server.rs

1use crate::io::{decode_request_body, decode_request_headers, encode_response, BUFFER_CAPACITY};
2use crate::model::header::{InvalidHeaderValue, CONNECTION, CONTENT_TYPE, EXPECT, SERVER};
3use crate::model::request::Builder as RequestBuilder;
4use crate::model::{Body, HeaderValue, Request, Response, StatusCode, Version};
5use std::fmt;
6use std::io::{copy, sink, BufReader, BufWriter, Error, ErrorKind, Result, Write};
7use std::net::{SocketAddr, TcpListener, TcpStream};
8use std::sync::{Arc, Condvar, Mutex};
9use std::thread::{Builder as ThreadBuilder, JoinHandle};
10use std::time::Duration;
11
12/// An HTTP server.
13///
14/// It uses a very simple threading mechanism: a new thread is started on each connection and closed when the client connection is closed.
15/// To avoid crashes it is possible to set an upper bound to the number of concurrent connections using the [`Server::with_max_concurrent_connections`] function.
16///
17/// ```no_run
18/// use std::net::{Ipv4Addr, Ipv6Addr};
19/// use oxhttp::Server;
20/// use oxhttp::model::{Body, Response, StatusCode};
21/// use std::time::Duration;
22///
23/// // Builds a new server that returns a 404 everywhere except for "/" where it returns the body 'home'
24/// let mut server = Server::new(|request| {
25///     if request.uri().path() == "/" {
26///         Response::builder().body(Body::from("home")).unwrap()
27///     } else {
28///         Response::builder().status(StatusCode::NOT_FOUND).body(Body::empty()).unwrap()
29///     }
30/// });
31/// // We bind the server to localhost on both IPv4 and v6
32/// server = server.bind((Ipv4Addr::LOCALHOST, 8080)).bind((Ipv6Addr::LOCALHOST, 8080));
33/// // Raise a timeout error if the client does not respond after 10s.
34/// server = server.with_global_timeout(Duration::from_secs(10));
35/// // Limits the number of concurrent connections to 128.
36/// server = server.with_max_concurrent_connections(128);
37/// // We spawn the server and block on it
38/// server.spawn()?.join()?;
39/// # Result::<_,Box<dyn std::error::Error>>::Ok(())
40/// ```
41#[allow(missing_copy_implementations)]
42pub struct Server {
43    #[allow(clippy::type_complexity)]
44    on_request: Arc<dyn Fn(&mut Request<Body>) -> Response<Body> + Send + Sync + 'static>,
45    socket_addrs: Vec<SocketAddr>,
46    timeout: Option<Duration>,
47    server: Option<HeaderValue>,
48    max_num_thread: Option<usize>,
49}
50
51impl Server {
52    /// Builds the server using the given `on_request` method that builds a `Response` from a given `Request`.
53    #[inline]
54    pub fn new(
55        on_request: impl Fn(&mut Request<Body>) -> Response<Body> + Send + Sync + 'static,
56    ) -> Self {
57        Self {
58            on_request: Arc::new(on_request),
59            socket_addrs: Vec::new(),
60            timeout: None,
61            server: None,
62            max_num_thread: None,
63        }
64    }
65
66    /// Ask the server to listen to a given socket when spawned.
67    pub fn bind(mut self, addr: impl Into<SocketAddr>) -> Self {
68        let addr = addr.into();
69        if !self.socket_addrs.contains(&addr) {
70            self.socket_addrs.push(addr);
71        }
72        self
73    }
74
75    /// Sets the default value for the [`Server`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.server) header.
76    #[inline]
77    pub fn with_server_name(
78        mut self,
79        server: impl Into<String>,
80    ) -> std::result::Result<Self, InvalidHeaderValue> {
81        self.server = Some(HeaderValue::try_from(server.into())?);
82        Ok(self)
83    }
84
85    /// Sets the global timeout value (applies to both read and write).
86    #[inline]
87    pub fn with_global_timeout(mut self, timeout: Duration) -> Self {
88        self.timeout = Some(timeout);
89        self
90    }
91
92    /// Sets the number maximum number of threads this server can spawn.
93    #[inline]
94    pub fn with_max_concurrent_connections(mut self, max_num_thread: usize) -> Self {
95        self.max_num_thread = Some(max_num_thread);
96        self
97    }
98
99    /// Spawns the server by listening to the given addresses.
100    ///
101    /// Note that this is not blocking.
102    /// To wait for the server to terminate indefinitely, call [`join`](ListeningServer::join) on the result.
103    pub fn spawn(self) -> Result<ListeningServer> {
104        let timeout = self.timeout;
105        let thread_limit = self.max_num_thread.map(Semaphore::new);
106        let listener_threads = self.socket_addrs
107                .into_iter()
108                .map(|listener_addr| {
109                    let listener = TcpListener::bind(listener_addr)?;
110                    let thread_name = format!("{listener_addr}: listener thread of OxHTTP");
111                    let thread_limit = thread_limit.clone();
112                    let on_request = Arc::clone(&self.on_request);
113                    let server = self.server.clone();
114                    ThreadBuilder::new().name(thread_name).spawn(move || {
115                        for stream in listener.incoming() {
116                            match stream {
117                                Ok(stream) => {
118                                    let peer_addr = match stream.peer_addr() {
119                                        Ok(peer) => peer,
120                                        Err(error) => {
121                                            eprintln!("OxHTTP TCP error when attempting to get the peer address: {error}");
122                                            continue;
123                                        }
124                                    };
125                                    if let Err(error) = stream.set_nodelay(true) {
126                                        eprintln!("OxHTTP TCP error when attempting to set the TCP_NODELAY option: {error}");
127                                    }
128                                    let thread_name = format!("{peer_addr}: responding thread of OxHTTP");
129                                    let thread_guard = thread_limit.as_ref().map(|s| s.lock());
130                                    let on_request = Arc::clone(&on_request);
131                                    let server = server.clone();
132                                    if let Err(error) = ThreadBuilder::new().name(thread_name).spawn(
133                                        move || {
134                                            if let Err(error) =
135                                                accept_request(stream, &*on_request, timeout, &server)
136                                            {
137                                                eprintln!(
138                                                    "OxHTTP TCP error when writing response to {peer_addr}: {error}"
139                                                )
140                                            }
141                                            drop(thread_guard);
142                                        }
143                                    ) {
144                                        eprintln!("OxHTTP thread spawn error: {error}");
145                                    }
146                                }
147                                Err(error) => {
148                                    eprintln!("OxHTTP TCP error when opening stream: {error}");
149                                }
150                            }
151                        }
152                    })
153                })
154                .collect::<Result<Vec<_>>>()?;
155        Ok(ListeningServer {
156            threads: listener_threads,
157        })
158    }
159}
160
/// Handle to a running server created by [`Server::spawn`].
///
/// Dropping the handle does not stop the server: the listener threads are detached and keep
/// running.
pub struct ListeningServer {
    threads: Vec<JoinHandle<()>>,
}

impl ListeningServer {
    /// Join the server threads and wait for them indefinitely except in case of crash.
    ///
    /// # Errors
    ///
    /// Returns an error when a listener thread panicked. The error message includes the panic
    /// message when the panic payload is a string.
    pub fn join(self) -> Result<()> {
        for thread in self.threads {
            thread.join().map_err(|payload| {
                // `panic!` payloads are a `&'static str` for literal messages and a `String` for
                // formatted ones. A `&dyn Display` downcast never matches either of them.
                let message: Option<&dyn fmt::Display> =
                    if let Some(message) = payload.downcast_ref::<&str>() {
                        Some(message as &dyn fmt::Display)
                    } else if let Some(message) = payload.downcast_ref::<String>() {
                        Some(message as &dyn fmt::Display)
                    } else {
                        None
                    };
                Error::other(match message {
                    Some(message) => format!("The server thread panicked with error: {message}"),
                    None => "The server thread panicked with an unknown error".into(),
                })
            })?;
        }
        Ok(())
    }
}
181
182fn accept_request(
183    mut stream: TcpStream,
184    on_request: &dyn Fn(&mut Request<Body>) -> Response<Body>,
185    timeout: Option<Duration>,
186    server: &Option<HeaderValue>,
187) -> Result<()> {
188    stream.set_read_timeout(timeout)?;
189    stream.set_write_timeout(timeout)?;
190    let mut connection_state = ConnectionState::KeepAlive;
191    while connection_state == ConnectionState::KeepAlive {
192        let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stream.try_clone()?);
193        let (mut response, new_connection_state) = match decode_request_headers(&mut reader, false)
194        {
195            Ok(request) => {
196                // Handles Expect header
197                if let Some(expect) = request.headers_ref().unwrap().get(EXPECT).cloned() {
198                    if request
199                        .version_ref()
200                        .map_or(true, |v| *v >= Version::HTTP_11)
201                        && expect.as_bytes().eq_ignore_ascii_case(b"100-continue")
202                    {
203                        stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n")?;
204                        read_body_and_build_response(request, reader, on_request)
205                    } else {
206                        (
207                            build_text_response(
208                                StatusCode::EXPECTATION_FAILED,
209                                format!(
210                                    "Expect header value '{}' is not supported.",
211                                    String::from_utf8_lossy(expect.as_ref())
212                                ),
213                            ),
214                            ConnectionState::Close,
215                        )
216                    }
217                } else {
218                    read_body_and_build_response(request, reader, on_request)
219                }
220            }
221            Err(error) => {
222                if error.kind() == ErrorKind::ConnectionAborted {
223                    return Ok(()); // The client is disconnected. Let's ignore this error and do not try to write an answer that won't be received.
224                } else {
225                    (build_error(error), ConnectionState::Close)
226                }
227            }
228        };
229        connection_state = new_connection_state;
230
231        // Additional headers
232        if let Some(server) = server {
233            response
234                .headers_mut()
235                .entry(SERVER)
236                .or_insert_with(|| server.clone());
237        }
238
239        stream = encode_response(
240            &mut response,
241            BufWriter::with_capacity(BUFFER_CAPACITY, stream),
242        )?
243        .into_inner()
244        .map_err(|e| e.into_error())?;
245    }
246    Ok(())
247}
248
/// What happens to a connection once the current response has been written.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ConnectionState {
    /// Stop serving this connection.
    Close,
    /// Wait for another request on the same connection.
    KeepAlive,
}
254
255fn read_body_and_build_response(
256    request: RequestBuilder,
257    reader: BufReader<TcpStream>,
258    on_request: &dyn Fn(&mut Request<Body>) -> Response<Body>,
259) -> (Response<Body>, ConnectionState) {
260    match decode_request_body(request, reader) {
261        Ok(mut request) => {
262            let response = on_request(&mut request);
263            // We make sure to finish reading the body
264            if let Err(error) = copy(request.body_mut(), &mut sink()) {
265                (build_error(error), ConnectionState::Close) // TODO: ignore?
266            } else {
267                let connection_state = request
268                    .headers()
269                    .get(CONNECTION)
270                    .and_then(|v| {
271                        v.as_bytes()
272                            .eq_ignore_ascii_case(b"close")
273                            .then_some(ConnectionState::Close)
274                    })
275                    .unwrap_or_else(|| {
276                        if request.version() <= Version::HTTP_10 {
277                            ConnectionState::Close
278                        } else {
279                            ConnectionState::KeepAlive
280                        }
281                    });
282                (response, connection_state)
283            }
284        }
285        Err(error) => (build_error(error), ConnectionState::Close),
286    }
287}
288
289fn build_error(error: Error) -> Response<Body> {
290    build_text_response(
291        match error.kind() {
292            ErrorKind::TimedOut => StatusCode::REQUEST_TIMEOUT,
293            ErrorKind::InvalidData => StatusCode::BAD_REQUEST,
294            _ => StatusCode::INTERNAL_SERVER_ERROR,
295        },
296        error.to_string(),
297    )
298}
299
300fn build_text_response(status: StatusCode, text: String) -> Response<Body> {
301    Response::builder()
302        .status(status)
303        .header(CONTENT_TYPE, "text/plain; charset=utf-8")
304        .body(Body::from(text))
305        .unwrap()
306}
307
/// Minimal counting semaphore bounding how many connections are served at once.
///
/// [`Semaphore::lock`] blocks until a slot is free and returns a [`SemaphoreGuard`] that gives
/// the slot back when dropped. Clones share the same slots.
#[derive(Clone)]
struct Semaphore {
    inner: Arc<InnerSemaphore>,
}

/// State shared by all clones of a [`Semaphore`] and by its outstanding guards.
struct InnerSemaphore {
    /// Number of guards currently alive.
    count: Mutex<usize>,
    /// Maximum number of guards alive at the same time (always at least 1).
    capacity: usize,
    /// Notified each time a guard is dropped.
    condvar: Condvar,
}

impl Semaphore {
    /// Creates a semaphore allowing `capacity` simultaneous guards.
    ///
    /// A capacity of 0 is raised to 1: with no slot at all, every call to
    /// [`lock`](Self::lock) would block forever.
    fn new(capacity: usize) -> Self {
        Self {
            inner: Arc::new(InnerSemaphore {
                count: Mutex::new(0),
                capacity: capacity.max(1),
                condvar: Condvar::new(),
            }),
        }
    }

    /// Blocks until a slot is free, takes it and returns the guard releasing it.
    #[must_use]
    fn lock(&self) -> SemaphoreGuard {
        let data = &self.inner;
        let mut count = data
            .condvar
            .wait_while(data.count.lock().unwrap(), |count| *count >= data.capacity)
            .unwrap();
        *count += 1;
        drop(count);
        SemaphoreGuard {
            inner: Arc::clone(&self.inner),
        }
    }
}

/// Slot taken from a [`Semaphore`]; dropping it frees the slot and wakes one waiter.
struct SemaphoreGuard {
    inner: Arc<InnerSemaphore>,
}

impl Drop for SemaphoreGuard {
    fn drop(&mut self) {
        let data = &self.inner;
        *data.count.lock().unwrap() -= 1;
        data.condvar.notify_one();
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;
    use std::net::{Ipv4Addr, Ipv6Addr};
    use std::thread::sleep;

    // Two requests over one connection: a keep-alive GET of "/" answered with "home", then a
    // POST to an unknown path using `expect: 100-continue` and `connection:close`, answered with
    // an interim "100 Continue" followed by a 404.
    #[test]
    fn test_regular_http_operations() -> Result<()> {
        test_server("localhost", 9999, [
            "GET / HTTP/1.1\nhost: localhost:9999\n\n",
            "POST /foo HTTP/1.1\nhost: localhost:9999\nexpect: 100-continue\nconnection:close\ncontent-length:4\n\nabcd",
        ], [
            "HTTP/1.1 200 OK\r\nserver: OxHTTP/1.0\r\ncontent-length: 4\r\n\r\nhome",
            "HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 404 Not Found\r\nserver: OxHTTP/1.0\r\ncontent-length: 0\r\n\r\n"
        ])
    }

    // A header line without a colon ("foo") must be rejected with a 400 whose body is the
    // error message. Connects over IPv6 loopback.
    #[test]
    fn test_bad_request() -> Result<()> {
        test_server(
            "::1", 9998,
            ["GET / HTTP/1.1\nhost: localhost:9999\nfoo\n\n"],
            ["HTTP/1.1 400 Bad Request\r\ncontent-type: text/plain; charset=utf-8\r\nserver: OxHTTP/1.0\r\ncontent-length: 19\r\n\r\ninvalid header name"],
        )
    }

    // An `expect` value other than `100-continue` must be answered with 417 Expectation Failed.
    // Connects over IPv4 loopback.
    #[test]
    fn test_bad_expect() -> Result<()> {
        test_server(
            "127.0.0.1", 9997,
            ["GET / HTTP/1.1\nhost: localhost:9999\nexpect: bad\n\n"],
            ["HTTP/1.1 417 Expectation Failed\r\ncontent-type: text/plain; charset=utf-8\r\nserver: OxHTTP/1.0\r\ncontent-length: 43\r\n\r\nExpect header value 'bad' is not supported."],
        )
    }

    // Starts a server on `server_port` (IPv4 and IPv6 loopback) answering "home" on "/" and 404
    // elsewhere, then sends the requests one by one over a single connection to `request_host`
    // and compares the bytes read back with the paired expected response.
    // The spawned server is never joined or stopped, so each test uses a distinct port.
    fn test_server(
        request_host: &'static str,
        server_port: u16,
        requests: impl IntoIterator<Item = &'static str>,
        responses: impl IntoIterator<Item = &'static str>,
    ) -> Result<()> {
        Server::new(|request| {
            if request.uri().path() == "/" {
                Response::builder().body(Body::from("home")).unwrap()
            } else {
                Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .body(Body::empty())
                    .unwrap()
            }
        })
        .bind((Ipv4Addr::LOCALHOST, server_port))
        .bind((Ipv6Addr::LOCALHOST, server_port))
        .with_server_name("OxHTTP/1.0")
        .unwrap()
        .with_global_timeout(Duration::from_secs(1))
        .spawn()?;
        sleep(Duration::from_millis(100)); // Makes sure the server is up
        let mut stream = TcpStream::connect((request_host, server_port))?;
        for (request, response) in requests.into_iter().zip(responses) {
            stream.write_all(request.as_bytes())?;
            // Read exactly as many bytes as the expected response so the comparison is exact.
            let mut output = vec![b'\0'; response.len()];
            stream.read_exact(&mut output)?;
            assert_eq!(String::from_utf8(output).unwrap(), response);
        }
        Ok(())
    }

    // Opens 128 connections, each sending one request before any response is read, against a
    // server limited to 2 concurrent connections; every connection must still get its response.
    // NOTE(review): served connections are keep-alive, so they presumably keep their slot until
    // the 1s read timeout fires; this test may therefore take on the order of a minute — confirm.
    #[test]
    fn test_thread_limit() -> Result<()> {
        let server_port = 9996;
        let request = b"GET / HTTP/1.1\nhost: localhost:9999\n\n";
        let response = b"HTTP/1.1 200 OK\r\nserver: OxHTTP/1.0\r\ncontent-length: 4\r\n\r\nhome";
        Server::new(|_| Response::builder().body(Body::from("home")).unwrap())
            .bind((Ipv4Addr::LOCALHOST, server_port))
            .bind((Ipv6Addr::LOCALHOST, server_port))
            .with_server_name("OxHTTP/1.0")
            .unwrap()
            .with_global_timeout(Duration::from_secs(1))
            .with_max_concurrent_connections(2)
            .spawn()?;
        sleep(Duration::from_millis(100)); // Makes sure the server is up
        let streams = (0..128)
            .map(|_| {
                let mut stream = TcpStream::connect(("localhost", server_port))?;
                stream.write_all(request)?;
                Ok(stream)
            })
            .collect::<Result<Vec<_>>>()?;
        for mut stream in streams {
            let mut output = vec![b'\0'; response.len()];
            stream.read_exact(&mut output)?;
            assert_eq!(output, response);
        }
        Ok(())
    }
}