1use crate::io::{decode_request_body, decode_request_headers, encode_response, BUFFER_CAPACITY};
2use crate::model::header::{InvalidHeaderValue, CONNECTION, CONTENT_TYPE, EXPECT, SERVER};
3use crate::model::request::Builder as RequestBuilder;
4use crate::model::{Body, HeaderValue, Request, Response, StatusCode, Version};
5use std::fmt;
6use std::io::{copy, sink, BufReader, BufWriter, Error, ErrorKind, Result, Write};
7use std::net::{SocketAddr, TcpListener, TcpStream};
8use std::sync::{Arc, Condvar, Mutex};
9use std::thread::{Builder as ThreadBuilder, JoinHandle};
10use std::time::Duration;
11
12#[allow(missing_copy_implementations)]
42pub struct Server {
43 #[allow(clippy::type_complexity)]
44 on_request: Arc<dyn Fn(&mut Request<Body>) -> Response<Body> + Send + Sync + 'static>,
45 socket_addrs: Vec<SocketAddr>,
46 timeout: Option<Duration>,
47 server: Option<HeaderValue>,
48 max_num_thread: Option<usize>,
49}
50
51impl Server {
52 #[inline]
54 pub fn new(
55 on_request: impl Fn(&mut Request<Body>) -> Response<Body> + Send + Sync + 'static,
56 ) -> Self {
57 Self {
58 on_request: Arc::new(on_request),
59 socket_addrs: Vec::new(),
60 timeout: None,
61 server: None,
62 max_num_thread: None,
63 }
64 }
65
66 pub fn bind(mut self, addr: impl Into<SocketAddr>) -> Self {
68 let addr = addr.into();
69 if !self.socket_addrs.contains(&addr) {
70 self.socket_addrs.push(addr);
71 }
72 self
73 }
74
75 #[inline]
77 pub fn with_server_name(
78 mut self,
79 server: impl Into<String>,
80 ) -> std::result::Result<Self, InvalidHeaderValue> {
81 self.server = Some(HeaderValue::try_from(server.into())?);
82 Ok(self)
83 }
84
85 #[inline]
87 pub fn with_global_timeout(mut self, timeout: Duration) -> Self {
88 self.timeout = Some(timeout);
89 self
90 }
91
92 #[inline]
94 pub fn with_max_concurrent_connections(mut self, max_num_thread: usize) -> Self {
95 self.max_num_thread = Some(max_num_thread);
96 self
97 }
98
99 pub fn spawn(self) -> Result<ListeningServer> {
104 let timeout = self.timeout;
105 let thread_limit = self.max_num_thread.map(Semaphore::new);
106 let listener_threads = self.socket_addrs
107 .into_iter()
108 .map(|listener_addr| {
109 let listener = TcpListener::bind(listener_addr)?;
110 let thread_name = format!("{listener_addr}: listener thread of OxHTTP");
111 let thread_limit = thread_limit.clone();
112 let on_request = Arc::clone(&self.on_request);
113 let server = self.server.clone();
114 ThreadBuilder::new().name(thread_name).spawn(move || {
115 for stream in listener.incoming() {
116 match stream {
117 Ok(stream) => {
118 let peer_addr = match stream.peer_addr() {
119 Ok(peer) => peer,
120 Err(error) => {
121 eprintln!("OxHTTP TCP error when attempting to get the peer address: {error}");
122 continue;
123 }
124 };
125 if let Err(error) = stream.set_nodelay(true) {
126 eprintln!("OxHTTP TCP error when attempting to set the TCP_NODELAY option: {error}");
127 }
128 let thread_name = format!("{peer_addr}: responding thread of OxHTTP");
129 let thread_guard = thread_limit.as_ref().map(|s| s.lock());
130 let on_request = Arc::clone(&on_request);
131 let server = server.clone();
132 if let Err(error) = ThreadBuilder::new().name(thread_name).spawn(
133 move || {
134 if let Err(error) =
135 accept_request(stream, &*on_request, timeout, &server)
136 {
137 eprintln!(
138 "OxHTTP TCP error when writing response to {peer_addr}: {error}"
139 )
140 }
141 drop(thread_guard);
142 }
143 ) {
144 eprintln!("OxHTTP thread spawn error: {error}");
145 }
146 }
147 Err(error) => {
148 eprintln!("OxHTTP TCP error when opening stream: {error}");
149 }
150 }
151 }
152 })
153 })
154 .collect::<Result<Vec<_>>>()?;
155 Ok(ListeningServer {
156 threads: listener_threads,
157 })
158 }
159}
160
/// Handle on a running server, holding one listener thread per bound address.
///
/// Dropping it detaches the listener threads (standard `JoinHandle` behavior), so the
/// server keeps running.
pub struct ListeningServer {
    /// Join handles of the listener threads, one per bound address.
    threads: Vec<JoinHandle<()>>,
}
165
166impl ListeningServer {
167 pub fn join(self) -> Result<()> {
169 for thread in self.threads {
170 thread.join().map_err(|e| {
171 Error::other(if let Ok(e) = e.downcast::<&dyn fmt::Display>() {
172 format!("The server thread panicked with error: {e}")
173 } else {
174 "The server thread panicked with an unknown error".into()
175 })
176 })?;
177 }
178 Ok(())
179 }
180}
181
182fn accept_request(
183 mut stream: TcpStream,
184 on_request: &dyn Fn(&mut Request<Body>) -> Response<Body>,
185 timeout: Option<Duration>,
186 server: &Option<HeaderValue>,
187) -> Result<()> {
188 stream.set_read_timeout(timeout)?;
189 stream.set_write_timeout(timeout)?;
190 let mut connection_state = ConnectionState::KeepAlive;
191 while connection_state == ConnectionState::KeepAlive {
192 let mut reader = BufReader::with_capacity(BUFFER_CAPACITY, stream.try_clone()?);
193 let (mut response, new_connection_state) = match decode_request_headers(&mut reader, false)
194 {
195 Ok(request) => {
196 if let Some(expect) = request.headers_ref().unwrap().get(EXPECT).cloned() {
198 if request
199 .version_ref()
200 .map_or(true, |v| *v >= Version::HTTP_11)
201 && expect.as_bytes().eq_ignore_ascii_case(b"100-continue")
202 {
203 stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n")?;
204 read_body_and_build_response(request, reader, on_request)
205 } else {
206 (
207 build_text_response(
208 StatusCode::EXPECTATION_FAILED,
209 format!(
210 "Expect header value '{}' is not supported.",
211 String::from_utf8_lossy(expect.as_ref())
212 ),
213 ),
214 ConnectionState::Close,
215 )
216 }
217 } else {
218 read_body_and_build_response(request, reader, on_request)
219 }
220 }
221 Err(error) => {
222 if error.kind() == ErrorKind::ConnectionAborted {
223 return Ok(()); } else {
225 (build_error(error), ConnectionState::Close)
226 }
227 }
228 };
229 connection_state = new_connection_state;
230
231 if let Some(server) = server {
233 response
234 .headers_mut()
235 .entry(SERVER)
236 .or_insert_with(|| server.clone());
237 }
238
239 stream = encode_response(
240 &mut response,
241 BufWriter::with_capacity(BUFFER_CAPACITY, stream),
242 )?
243 .into_inner()
244 .map_err(|e| e.into_error())?;
245 }
246 Ok(())
247}
248
/// What to do with a connection once the current response has been written.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ConnectionState {
    /// Stop serving the connection.
    Close,
    /// Wait for another request on the same connection.
    KeepAlive,
}
254
255fn read_body_and_build_response(
256 request: RequestBuilder,
257 reader: BufReader<TcpStream>,
258 on_request: &dyn Fn(&mut Request<Body>) -> Response<Body>,
259) -> (Response<Body>, ConnectionState) {
260 match decode_request_body(request, reader) {
261 Ok(mut request) => {
262 let response = on_request(&mut request);
263 if let Err(error) = copy(request.body_mut(), &mut sink()) {
265 (build_error(error), ConnectionState::Close) } else {
267 let connection_state = request
268 .headers()
269 .get(CONNECTION)
270 .and_then(|v| {
271 v.as_bytes()
272 .eq_ignore_ascii_case(b"close")
273 .then_some(ConnectionState::Close)
274 })
275 .unwrap_or_else(|| {
276 if request.version() <= Version::HTTP_10 {
277 ConnectionState::Close
278 } else {
279 ConnectionState::KeepAlive
280 }
281 });
282 (response, connection_state)
283 }
284 }
285 Err(error) => (build_error(error), ConnectionState::Close),
286 }
287}
288
289fn build_error(error: Error) -> Response<Body> {
290 build_text_response(
291 match error.kind() {
292 ErrorKind::TimedOut => StatusCode::REQUEST_TIMEOUT,
293 ErrorKind::InvalidData => StatusCode::BAD_REQUEST,
294 _ => StatusCode::INTERNAL_SERVER_ERROR,
295 },
296 error.to_string(),
297 )
298}
299
300fn build_text_response(status: StatusCode, text: String) -> Response<Body> {
301 Response::builder()
302 .status(status)
303 .header(CONTENT_TYPE, "text/plain; charset=utf-8")
304 .body(Body::from(text))
305 .unwrap()
306}
307
/// A counting semaphore bounding how many connections are served at once.
///
/// Clones are handles to the same counter, so all clones share a single limit.
#[derive(Clone)]
struct Semaphore {
    inner: Arc<InnerSemaphore>,
}

/// State shared by every handle of a [`Semaphore`].
struct InnerSemaphore {
    /// Number of permits currently held.
    count: Mutex<usize>,
    /// Maximum number of permits that may be held simultaneously.
    capacity: usize,
    /// Signalled each time a permit is given back.
    condvar: Condvar,
}

impl Semaphore {
    /// Creates a semaphore allowing at most `capacity` simultaneous permits.
    ///
    /// NOTE: with a `capacity` of `0`, [`Semaphore::lock`] never returns.
    fn new(capacity: usize) -> Self {
        let shared = InnerSemaphore {
            count: Mutex::new(0),
            capacity,
            condvar: Condvar::new(),
        };
        Self {
            inner: Arc::new(shared),
        }
    }

    /// Waits until a permit is free, takes it, and returns a guard giving it back on drop.
    ///
    /// # Panics
    /// Panics if the counter's mutex is poisoned.
    fn lock(&self) -> SemaphoreGuard {
        let shared = &self.inner;
        {
            let held = shared.count.lock().unwrap();
            let mut held = shared
                .condvar
                .wait_while(held, |held| *held >= shared.capacity)
                .unwrap();
            *held += 1;
        } // the counter's mutex is released here, before the guard is handed out
        SemaphoreGuard {
            inner: Arc::clone(shared),
        }
    }
}

/// A permit taken from a [`Semaphore`]; dropping it returns the permit and wakes one waiter.
struct SemaphoreGuard {
    inner: Arc<InnerSemaphore>,
}

impl Drop for SemaphoreGuard {
    fn drop(&mut self) {
        let shared = &self.inner;
        {
            let mut held = shared.count.lock().unwrap();
            *held -= 1;
        }
        shared.condvar.notify_one();
    }
}
354
// End-to-end tests: each test starts a real server on a fixed local port and exchanges raw
// HTTP bytes with it over TCP.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Read;
    use std::net::{Ipv4Addr, Ipv6Addr};
    use std::thread::sleep;

    // Two requests on the same connection: a plain GET, then a POST carrying
    // `expect: 100-continue` (so the interim `100 Continue` must precede the final response)
    // and `connection:close`.
    #[test]
    fn test_regular_http_operations() -> Result<()> {
        test_server("localhost", 9999, [
            "GET / HTTP/1.1\nhost: localhost:9999\n\n",
            "POST /foo HTTP/1.1\nhost: localhost:9999\nexpect: 100-continue\nconnection:close\ncontent-length:4\n\nabcd",
        ], [
            "HTTP/1.1 200 OK\r\nserver: OxHTTP/1.0\r\ncontent-length: 4\r\n\r\nhome",
            "HTTP/1.1 100 Continue\r\n\r\nHTTP/1.1 404 Not Found\r\nserver: OxHTTP/1.0\r\ncontent-length: 0\r\n\r\n"
        ])
    }

    // A header line without a colon (`foo`) must be rejected with `400 Bad Request`;
    // the request goes over the IPv6 loopback.
    #[test]
    fn test_bad_request() -> Result<()> {
        test_server(
            "::1", 9998,
            ["GET / HTTP/1.1\nhost: localhost:9999\nfoo\n\n"],
            ["HTTP/1.1 400 Bad Request\r\ncontent-type: text/plain; charset=utf-8\r\nserver: OxHTTP/1.0\r\ncontent-length: 19\r\n\r\ninvalid header name"],
        )
    }

    // An `Expect` value other than `100-continue` must be answered with `417 Expectation Failed`.
    #[test]
    fn test_bad_expect() -> Result<()> {
        test_server(
            "127.0.0.1", 9997,
            ["GET / HTTP/1.1\nhost: localhost:9999\nexpect: bad\n\n"],
            ["HTTP/1.1 417 Expectation Failed\r\ncontent-type: text/plain; charset=utf-8\r\nserver: OxHTTP/1.0\r\ncontent-length: 43\r\n\r\nExpect header value 'bad' is not supported."],
        )
    }

    // Starts a server on `server_port` (IPv4 and IPv6 loopback) that answers `/` with `home`
    // and anything else with 404. Then, over ONE connection to `request_host`, it sends each
    // request and checks that exactly the paired response bytes come back.
    // The `ListeningServer` returned by `spawn` is dropped, which detaches the listener
    // threads; the server keeps running for the rest of the test process (presumably why each
    // test uses its own port).
    fn test_server(
        request_host: &'static str,
        server_port: u16,
        requests: impl IntoIterator<Item = &'static str>,
        responses: impl IntoIterator<Item = &'static str>,
    ) -> Result<()> {
        Server::new(|request| {
            if request.uri().path() == "/" {
                Response::builder().body(Body::from("home")).unwrap()
            } else {
                Response::builder()
                    .status(StatusCode::NOT_FOUND)
                    .body(Body::empty())
                    .unwrap()
            }
        })
        .bind((Ipv4Addr::LOCALHOST, server_port))
        .bind((Ipv6Addr::LOCALHOST, server_port))
        .with_server_name("OxHTTP/1.0")
        .unwrap()
        .with_global_timeout(Duration::from_secs(1))
        .spawn()?;
        // NOTE(review): `spawn` binds the listeners before returning, so this sleep looks
        // unnecessary; kept as-is.
        sleep(Duration::from_millis(100));
        let mut stream = TcpStream::connect((request_host, server_port))?;
        for (request, response) in requests.into_iter().zip(responses) {
            stream.write_all(request.as_bytes())?;
            let mut output = vec![b'\0'; response.len()];
            stream.read_exact(&mut output)?;
            assert_eq!(String::from_utf8(output).unwrap(), response);
        }
        Ok(())
    }

    // Opens 128 connections against a server limited to 2 concurrent connections and checks
    // that every one of them is eventually answered.
    // NOTE(review): the client never closes its connections, so each connection presumably
    // holds one of the 2 slots until the 1 s read timeout fires. That would make this test
    // take on the order of a minute; confirm.
    #[test]
    fn test_thread_limit() -> Result<()> {
        let server_port = 9996;
        let request = b"GET / HTTP/1.1\nhost: localhost:9999\n\n";
        let response = b"HTTP/1.1 200 OK\r\nserver: OxHTTP/1.0\r\ncontent-length: 4\r\n\r\nhome";
        Server::new(|_| Response::builder().body(Body::from("home")).unwrap())
            .bind((Ipv4Addr::LOCALHOST, server_port))
            .bind((Ipv6Addr::LOCALHOST, server_port))
            .with_server_name("OxHTTP/1.0")
            .unwrap()
            .with_global_timeout(Duration::from_secs(1))
            .with_max_concurrent_connections(2)
            .spawn()?;
        sleep(Duration::from_millis(100));
        // Send every request up front, then read the responses in connection order.
        let streams = (0..128)
            .map(|_| {
                let mut stream = TcpStream::connect(("localhost", server_port))?;
                stream.write_all(request)?;
                Ok(stream)
            })
            .collect::<Result<Vec<_>>>()?;
        for mut stream in streams {
            let mut output = vec![b'\0'; response.len()];
            stream.read_exact(&mut output)?;
            assert_eq!(output, response);
        }
        Ok(())
    }
}