dbs_uhttp/
server.rs

1// Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::io::{ErrorKind, Read, Write};
6use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
7use std::os::unix::net::{UnixListener, UnixStream};
8use std::path::Path;
9use std::thread::sleep;
10use std::time::Duration;
11
12use mio::unix::SourceFd;
13use mio::{Events, Interest, Poll, Token};
14
15use crate::common::sock_ctrl_msg::ScmSocket;
16use crate::common::{Body, ConnectionError, ServerError, Version};
17use crate::connection::HttpConnection;
18use crate::request::Request;
19use crate::response::{Response, StatusCode};
20
/// Canned reply sent to a client that connects while the server already holds
/// `MAX_CONNECTIONS` connections. The JSON body is exactly 40 bytes long,
/// matching the advertised `Content-Length`.
static SERVER_FULL_ERROR_MESSAGE: &[u8] = concat!(
    "HTTP/1.1 503\r\n",
    "Server: Firecracker API\r\n",
    "Connection: close\r\n",
    "Content-Length: 40\r\n",
    "\r\n",
    "{ \"error\": \"Too many open connections\" }",
)
.as_bytes();
/// Upper bound on simultaneously open client connections.
#[cfg(target_os = "linux")]
const MAX_CONNECTIONS: usize = 256;
/// Upper bound on simultaneously open client connections.
#[cfg(not(target_os = "linux"))]
const MAX_CONNECTIONS: usize = 10;
/// Capacity of the event buffer handed to each `poll` call.
const MAX_EVENTS: usize = 64;
/// Default upper bound, in bytes, on a request payload (50 KiB).
pub(crate) const MAX_PAYLOAD_SIZE: usize = 50 * 1024;
32
33type Result<T> = std::result::Result<T, ServerError>;
34
35/// Wrapper over `Request` which adds an identification token.
36#[derive(Debug)]
37pub struct ServerRequest {
38    /// Inner request.
39    pub request: Request,
40    /// Identification token.
41    id: Token,
42}
43
44impl ServerRequest {
45    /// Creates a new `ServerRequest` object from an existing `Request`,
46    /// adding an identification token.
47    pub fn new(request: Request, id: Token) -> Self {
48        Self { request, id }
49    }
50
51    /// Returns a reference to the inner request.
52    pub fn inner(&self) -> &Request {
53        &self.request
54    }
55
56    /// Calls the function provided on the inner request to obtain the response.
57    /// The response is then wrapped in a `ServerResponse`.
58    ///
59    /// Returns a `ServerResponse` ready for yielding to the server
60    pub fn process<F>(&self, mut callable: F) -> ServerResponse
61    where
62        F: FnMut(&Request) -> Response,
63    {
64        let http_response = callable(self.inner());
65        ServerResponse::new(http_response, self.id)
66    }
67}
68
69/// Wrapper over `Response` which adds an identification token.
70#[derive(Debug)]
71pub struct ServerResponse {
72    /// Inner response.
73    response: Response,
74    /// Identification token.
75    id: Token,
76}
77
78impl ServerResponse {
79    fn new(response: Response, id: Token) -> Self {
80        Self { response, id }
81    }
82}
83
/// Describes the state of the connection as far as data exchange
/// on the stream is concerned.
///
/// The server switches a connection's poll interest to follow this state:
/// `AwaitingIncoming` goes with readable interest, `AwaitingOutgoing` with
/// writable interest. `Closed` marks the connection for removal once it is
/// safe to drop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
enum ClientConnectionState {
    /// Waiting for bytes from the client.
    AwaitingIncoming,
    /// Holding bytes that still have to be written to the client.
    AwaitingOutgoing,
    /// The connection is closed (or failed) and will be dropped.
    Closed,
}
92
93/// Wrapper over `HttpConnection` which keeps track of yielded
94/// requests and absorbed responses.
95struct ClientConnection<T> {
96    /// The `HttpConnection` object which handles data exchange.
97    connection: HttpConnection<T>,
98    /// The state of the connection in the `epoll` structure.
99    state: ClientConnectionState,
100    /// Represents the difference between yielded requests and
101    /// absorbed responses.
102    /// This has to be `0` if we want to drop the connection.
103    in_flight_response_count: u32,
104}
105
106// impl<T: Read + Write + ScmSocket> ClientConnection<T> {
107impl<T: Read + Write + ScmSocket> ClientConnection<T> {
108    fn new(connection: HttpConnection<T>) -> Self {
109        Self {
110            connection,
111            state: ClientConnectionState::AwaitingIncoming,
112            in_flight_response_count: 0,
113        }
114    }
115
116    fn read(&mut self) -> Result<Vec<Request>> {
117        // Data came into the connection.
118        let mut parsed_requests = vec![];
119        let mut retry_limit = 32;
120        'out: loop {
121            match self.connection.try_read() {
122                Err(ConnectionError::ConnectionClosed) => {
123                    // Connection timeout.
124                    self.state = ClientConnectionState::Closed;
125                    // We don't want to propagate this to the server and we will
126                    // return no requests and wait for the connection to become
127                    // safe to drop.
128                    return Ok(vec![]);
129                }
130                Err(ConnectionError::StreamReadError(inner)) => {
131                    #[cfg(target_os = "linux")]
132                    if inner.errno() == libc::EAGAIN && retry_limit > 0 {
133                        sleep(Duration::from_micros(20));
134                        retry_limit -= 1;
135                        continue;
136                    }
137
138                    // Reading from the connection failed.
139                    // We should try to write an error message regardless.
140                    let mut internal_error_response =
141                        Response::new(Version::Http11, StatusCode::InternalServerError);
142                    internal_error_response.set_body(Body::new(inner.to_string()));
143                    self.connection.enqueue_response(internal_error_response);
144                    break;
145                }
146                Err(ConnectionError::ParseError(inner)) => {
147                    // An error occurred while parsing the read bytes.
148                    // Check if there are any valid parsed requests in the queue.
149                    while let Some(_discarded_request) = self.connection.pop_parsed_request() {}
150
151                    // Send an error response for the request that gave us the error.
152                    let mut error_response = Response::new(Version::Http11, StatusCode::BadRequest);
153                    error_response.set_body(Body::new(format!(
154                    "{{ \"error\": \"{}\nAll previous unanswered requests will be dropped.\" }}",
155                    inner
156                )));
157                    self.connection.enqueue_response(error_response);
158                    break;
159                }
160                Err(ConnectionError::InvalidWrite) | Err(ConnectionError::StreamWriteError(_)) => {
161                    // This is unreachable because `HttpConnection::try_read()` cannot return this error variant.
162                    unreachable!();
163                }
164                Ok(()) => {
165                    if self.connection.has_parsed_requests() {
166                        while let Some(request) = self.connection.pop_parsed_request() {
167                            // Add all valid requests to `parsed_requests`.
168                            parsed_requests.push(request);
169                        }
170                        break 'out;
171                    }
172                }
173            }
174        }
175        self.in_flight_response_count = self
176            .in_flight_response_count
177            .checked_add(parsed_requests.len() as u32)
178            .ok_or(ServerError::Overflow)?;
179        // If the state of the connection has changed, we need to update
180        // the event set in the `epoll` structure.
181        if self.connection.pending_write() {
182            self.state = ClientConnectionState::AwaitingOutgoing;
183        }
184
185        Ok(parsed_requests)
186    }
187
188    fn write(&mut self) -> Result<()> {
189        while self.state != ClientConnectionState::Closed {
190            match self.connection.try_write() {
191                Err(ConnectionError::ConnectionClosed)
192                | Err(ConnectionError::StreamWriteError(_)) => {
193                    // Writing to the stream failed so it will be removed.
194                    self.state = ClientConnectionState::Closed;
195                }
196                Err(ConnectionError::InvalidWrite) => {
197                    // A `try_write` call was performed on a connection that has nothing
198                    // to write.
199                    return Err(ServerError::ConnectionError(ConnectionError::InvalidWrite));
200                }
201                _ => {
202                    // Check if we still have bytes to write for this connection.
203                    if !self.connection.pending_write() {
204                        self.state = ClientConnectionState::AwaitingIncoming;
205                        break;
206                    }
207                }
208            }
209        }
210        Ok(())
211    }
212
213    fn enqueue_response(&mut self, response: Response) -> Result<()> {
214        if self.state != ClientConnectionState::Closed {
215            self.connection.enqueue_response(response);
216        }
217        self.in_flight_response_count = self
218            .in_flight_response_count
219            .checked_sub(1)
220            .ok_or(ServerError::Underflow)?;
221        Ok(())
222    }
223
224    /// Discards all pending writes from the inner connection.
225    fn clear_write_buffer(&mut self) {
226        self.connection.clear_write_buffer();
227    }
228
229    // Returns `true` if the connection is closed and safe to drop.
230    fn is_done(&self) -> bool {
231        self.state == ClientConnectionState::Closed
232            && !self.connection.pending_write()
233            && self.in_flight_response_count == 0
234    }
235
236    // Make the connection as closed.
237    fn close(&mut self) {
238        self.clear_write_buffer();
239        self.state = ClientConnectionState::Closed;
240        //self.in_flight_response_count = 0;
241    }
242}
243
244/// HTTP Server implementation using Unix Domain Sockets and `EPOLL` to
245/// handle multiple connections on the same thread.
246///
247/// The function that handles incoming connections, parses incoming
248/// requests and sends responses for awaiting requests is `requests`.
249/// It can be called in a loop, which will render the thread that the
250/// server runs on incapable of performing other operations, or it can
251/// be used in another `EPOLL` structure, as it provides its `epoll`,
252/// which is a wrapper over the file descriptor of the epoll structure
253/// used within the server, and it can be added to another one using
254/// the `EPOLLIN` flag. Whenever there is a notification on that fd,
255/// `requests` should be called once.
256///
257/// # Example
258///
259/// ## Starting and running the server
260///
261/// ```
262/// use dbs_uhttp::{HttpServer, Response, StatusCode};
263///
264/// let path_to_socket = "/tmp/example.sock";
265/// std::fs::remove_file(path_to_socket).unwrap_or_default();
266///
267/// // Start the server.
268/// let mut server = HttpServer::new(path_to_socket).unwrap();
269/// server.start_server().unwrap();
270///
271/// // Connect a client to the server so it doesn't block in our example.
272/// let mut socket = std::os::unix::net::UnixStream::connect(path_to_socket).unwrap();
273///
274/// // Server loop processing requests.
275/// loop {
276///     for request in server.requests().unwrap() {
277///         let response = request.process(|request| {
278///             // Your code here.
279///             Response::new(request.http_version(), StatusCode::NoContent)
280///         });
281///         server.respond(response);
282///     }
283///     // Break this example loop.
284///     break;
285/// }
286/// ```
287pub struct HttpServer {
288    /// Socket on which we listen for new connections.
289    socket: UnixListener,
290    /// Server's epoll instance.
291    poll: Poll,
292    /// Holds the token-connection pairs of the server.
293    /// Each connection has an associated identification token, which is
294    /// the file descriptor of the underlying stream.
295    /// We use the file descriptor of the stream as the key for mapping
296    /// connections because the 1-to-1 relation is guaranteed by the OS.
297    connections: HashMap<Token, ClientConnection<UnixStream>>,
298    /// Payload max size
299    payload_max_size: usize,
300}
301
302impl HttpServer {
303    /// Constructor for `HttpServer`.
304    ///
305    /// Returns the newly formed `HttpServer`.
306    ///
307    /// # Errors
308    /// Returns an `IOError` when binding or `epoll::create` fails.
309    pub fn new<P: AsRef<Path>>(path_to_socket: P) -> Result<Self> {
310        let socket = UnixListener::bind(path_to_socket).map_err(ServerError::IOError)?;
311        Self::new_from_socket(socket)
312    }
313
314    /// Constructor for `HttpServer`.
315    ///
316    /// Note that this function requires the socket_fd to be solely owned
317    /// and not be associated with another File in the caller as it uses
318    /// the unsafe `UnixListener::from_raw_fd method`.
319    ///
320    /// Returns the newly formed `HttpServer`.
321    ///
322    /// # Errors
323    /// Returns an `IOError` when `epoll::create` fails.
324    pub fn new_from_fd(socket_fd: RawFd) -> Result<Self> {
325        let socket = unsafe { UnixListener::from_raw_fd(socket_fd) };
326        Self::new_from_socket(socket)
327    }
328
329    fn new_from_socket(socket: UnixListener) -> Result<Self> {
330        // as mio use edge trigger epoll under the hook, we should set nonblocking on socket
331        // otherwise we will miss event in some cases
332        socket.set_nonblocking(true).map_err(ServerError::IOError)?;
333        let poll = Poll::new().map_err(ServerError::IOError)?;
334        Ok(HttpServer {
335            socket,
336            poll,
337            connections: HashMap::new(),
338            payload_max_size: MAX_PAYLOAD_SIZE,
339        })
340    }
341
342    /// This function sets the limit for PUT/PATCH requests. It overwrites the
343    /// default limit of 0.05MiB with the one allowed by server.
344    pub fn set_payload_max_size(&mut self, request_payload_max_size: usize) {
345        self.payload_max_size = request_payload_max_size;
346    }
347
348    /// Starts the HTTP Server.
349    pub fn start_server(&mut self) -> Result<()> {
350        // Add the socket on which we listen for new connections to the
351        // `epoll` structure.
352        Self::epoll_add(
353            &self.poll,
354            Token(self.socket.as_raw_fd() as usize),
355            self.socket.as_raw_fd(),
356        )
357    }
358
359    /// poll event use mio poll method, and handle Interrupted error explicitly
360    fn poll_events(&mut self, events: &mut Events) -> Result<()> {
361        loop {
362            if let Err(e) = self.poll.poll(events, None) {
363                if e.kind() == ErrorKind::Interrupted || e.kind() == ErrorKind::WouldBlock {
364                    continue;
365                }
366                return Err(ServerError::IOError(e));
367            }
368            return Ok(());
369        }
370    }
371
372    /// This function is responsible for the data exchange with the clients and should
373    /// be called when we are either notified through `epoll` that we need to exchange
374    /// data with at least a client or when we don't need to perform any other operations
375    /// on this thread and we can afford to call it in a loop.
376    ///
377    /// Note that this function will block the current thread if there are no notifications
378    /// to be handled by the server.
379    ///
380    /// Returns a collection of complete and valid requests to be processed by the user
381    /// of the server. Once processed, responses should be sent using `enqueue_responses()`.
382    ///
383    /// # Errors
384    /// `IOError` is returned when `read`, `write` or `epoll::ctl` operations fail.
385    /// `ServerFull` is returned when a client is trying to connect to the server, but
386    /// full capacity has already been reached.
387    /// `InvalidWrite` is returned when the server attempted to perform a write operation
388    /// on a connection on which it is not possible.
389    pub fn requests(&mut self) -> Result<Vec<ServerRequest>> {
390        let mut parsed_requests: Vec<ServerRequest> = vec![];
391        let mut events = mio::Events::with_capacity(MAX_EVENTS);
392        // This is a wrapper over the syscall `epoll_wait` and it will block the
393        // current thread until at least one event is received.
394        // The received notifications will then populate the `events` array with
395        // `event_count` elements, where 1 <= event_count <= MAX_EVENTS.
396        self.poll_events(&mut events)?;
397
398        // We use `take()` on the iterator over `events` as, even though only
399        // `events_count` events have been inserted into `events`, the size of
400        // the array is still `MAX_EVENTS`, so we discard empty elements
401        // at the end of the array.
402        for e in events.iter() {
403            // Check the file descriptor which produced the notification `e`.
404            // It could be that we have a new connection, or one of our open
405            // connections is ready to exchange data with a client.
406            match e.token() {
407                Token(fd) if fd == self.socket.as_raw_fd() as usize => {
408                    match self.handle_new_connection() {
409                        Err(ServerError::ServerFull) => {
410                            self.socket
411                                .accept()
412                                .map_err(ServerError::IOError)
413                                .and_then(move |(mut stream, _)| {
414                                    stream
415                                        .write(SERVER_FULL_ERROR_MESSAGE)
416                                        .map_err(ServerError::IOError)
417                                })?;
418                        }
419                        // An internal error will compromise any in-flight requests.
420                        Err(error) => return Err(error),
421                        Ok(()) => {}
422                    }
423                }
424                t => {
425                    let client_connection = self.connections.get_mut(&t).unwrap();
426                    // If we receive a hang up on a connection, we clear the write buffer and set
427                    // the connection state to closed to mark it ready for removal from the
428                    // connections map, which will gracefully close the socket.
429                    // The connection is also marked for removal when encountering `EPOLLERR`,
430                    // since this is an "error condition happened on the associated file
431                    // descriptor", according to the `epoll_ctl` man page.
432                    if e.is_error() || e.is_read_closed() || e.is_write_closed() {
433                        client_connection.close();
434                        continue;
435                    }
436
437                    if e.is_readable() {
438                        // We have bytes to read from this connection.
439                        // If our `read` yields `Request` objects, we wrap them with an ID before
440                        // handing them to the user.
441                        parsed_requests.append(
442                            &mut client_connection
443                                .read()?
444                                .into_iter()
445                                .map(|request| ServerRequest::new(request, e.token()))
446                                .collect(),
447                        );
448
449                        // If the connection was incoming before we read and we now have to write
450                        // either an error message or an `expect` response, we change its `epoll`
451                        // event set to notify us when the stream is ready for writing.
452                        if client_connection.state == ClientConnectionState::AwaitingOutgoing {
453                            Self::epoll_mod(
454                                &self.poll,
455                                client_connection.connection.as_raw_fd(),
456                                t,
457                                Interest::WRITABLE,
458                                // epoll::EventSet::OUT | epoll::EventSet::READ_HANG_UP,
459                            )?;
460                        }
461                    } else if e.is_writable() {
462                        // We have bytes to write on this connection.
463                        client_connection.write()?;
464                        // If the connection was outgoing before we tried to write the responses
465                        // and we don't have any more responses to write, we change the `epoll`
466                        // event set to notify us when we have bytes to read from the stream.
467                        if client_connection.state == ClientConnectionState::AwaitingIncoming {
468                            Self::epoll_mod(
469                                &self.poll,
470                                client_connection.connection.as_raw_fd(),
471                                t,
472                                Interest::READABLE,
473                            )?;
474                        }
475                    }
476                }
477            }
478        }
479
480        // Remove dead connections.
481        let epoll = &self.poll;
482        self.connections.retain(|_token, client_connection| {
483            if client_connection.is_done() {
484                // The rawfd should have been registered to the epoll fd.
485                Self::epoll_del(epoll, client_connection.connection.as_raw_fd()).unwrap();
486                false
487            } else {
488                true
489            }
490        });
491
492        Ok(parsed_requests)
493    }
494
495    /// This function is responsible with flushing any remaining outgoing
496    /// requests on the server.
497    ///
498    /// Note that this function can block the thread on write, since the
499    /// operation is blocking.
500    pub fn flush_outgoing_writes(&mut self) {
501        for (_, connection) in self.connections.iter_mut() {
502            while connection.state == ClientConnectionState::AwaitingOutgoing {
503                if let Err(e) = connection.write() {
504                    if let ServerError::ConnectionError(ConnectionError::InvalidWrite) = e {
505                        // Nothing is logged since an InvalidWrite means we have successfully
506                        // flushed the connection
507                    }
508                    break;
509                }
510            }
511        }
512    }
513
514    /// The file descriptor of the `epoll` structure can enable the server to become
515    /// a non-blocking structure in an application.
516    ///
517    /// Returns a reference to the instance of the server's internal `Poll` structure.
518    pub fn epoll(&self) -> &Poll {
519        &self.poll
520    }
521
522    /// Enqueues the provided responses in the outgoing connection.
523    ///
524    /// # Errors
525    /// `IOError` is returned when an `epoll::ctl` operation fails.
526    pub fn enqueue_responses(&mut self, responses: Vec<ServerResponse>) -> Result<()> {
527        for response in responses {
528            self.respond(response)?;
529        }
530
531        Ok(())
532    }
533
534    /// Adds the provided response to the outgoing buffer in the corresponding connection.
535    ///
536    /// # Errors
537    /// `IOError` is returned when an `epoll::ctl` operation fails.
538    /// `Underflow` is returned when `enqueue_response` fails.
539    pub fn respond(&mut self, response: ServerResponse) -> Result<()> {
540        if let Some(client_connection) = self.connections.get_mut(&response.id) {
541            // If the connection was incoming before we enqueue the response, we change its
542            // `epoll` event set to notify us when the stream is ready for writing.
543            if ClientConnectionState::AwaitingIncoming == client_connection.state {
544                client_connection.state = ClientConnectionState::AwaitingOutgoing;
545                Self::epoll_mod(
546                    &self.poll,
547                    client_connection.connection.as_raw_fd(),
548                    response.id,
549                    Interest::WRITABLE,
550                    // epoll::EventSet::OUT | epoll::EventSet::READ_HANG_UP,
551                )?;
552            }
553            client_connection.enqueue_response(response.response)?;
554        }
555        Ok(())
556    }
557
558    /// Accepts a new incoming connection and adds it to the `epoll` notification structure.
559    ///
560    /// # Errors
561    /// `IOError` is returned when socket or epoll operations fail.
562    /// `ServerFull` is returned if server full capacity has been reached.
563    fn handle_new_connection(&mut self) -> Result<()> {
564        if self.connections.len() == MAX_CONNECTIONS {
565            // If we want a replacement policy for connections
566            // this is where we will have it.
567            return Err(ServerError::ServerFull);
568        }
569        loop {
570            if let Err(e) = self
571                .socket
572                .accept()
573                .and_then(|(stream, _)| stream.set_nonblocking(true).map(|_| stream))
574                .and_then(|stream| {
575                    let raw_fd = stream.as_raw_fd();
576                    let token = Token(raw_fd as usize);
577                    self.poll.registry().register(
578                        &mut SourceFd(&raw_fd),
579                        token,
580                        Interest::READABLE,
581                    )?;
582                    let mut conn = HttpConnection::new(stream);
583                    conn.set_payload_max_size(self.payload_max_size);
584                    self.connections.insert(token, ClientConnection::new(conn));
585                    Ok(())
586                })
587            {
588                if e.kind() == ErrorKind::Interrupted {
589                    continue;
590                }
591                if e.kind() == ErrorKind::WouldBlock {
592                    break;
593                }
594                return Err(ServerError::IOError(e));
595            }
596        }
597        Ok(())
598    }
599
600    /// Changes the event type for a connection to either listen for incoming bytes
601    /// or for when the stream is ready for writing.
602    ///
603    /// # Errors
604    /// `IOError` is returned when an `EPOLL_CTL_MOD` control operation fails.
605    fn epoll_mod(epoll: &Poll, stream_fd: RawFd, token: Token, evset: Interest) -> Result<()> {
606        epoll
607            .registry()
608            .reregister(&mut SourceFd(&stream_fd), token, evset)
609            .map_err(ServerError::IOError)
610    }
611
612    /// Adds a stream to the `epoll` notification structure with the `EPOLLIN` event set.
613    ///
614    /// # Errors
615    /// `IOError` is returned when an `EPOLL_CTL_ADD` control operation fails.
616    fn epoll_add(poll: &Poll, token: Token, stream_fd: RawFd) -> Result<()> {
617        poll.registry()
618            .register(&mut SourceFd(&stream_fd), token, Interest::READABLE)
619            .map_err(ServerError::IOError)
620    }
621
622    /// Removes a stream to the `epoll` notification structure.
623    fn epoll_del(poll: &Poll, stream_fd: RawFd) -> Result<()> {
624        poll.registry()
625            .deregister(&mut SourceFd(&stream_fd))
626            .map_err(ServerError::IOError)
627    }
628}
629
630#[cfg(test)]
631mod tests {
632    use super::*;
633    use std::io::{Read, Write};
634    use std::net::Shutdown;
635    use std::os::unix::net::UnixStream;
636
637    use crate::common::Body;
638    use vmm_sys_util::tempfile::TempFile;
639
640    fn get_temp_socket_file() -> TempFile {
641        let mut path_to_socket = TempFile::new().unwrap();
642        path_to_socket.remove().unwrap();
643        path_to_socket
644    }
645
646    #[test]
647    fn test_wait_one_connection() {
648        let path_to_socket = get_temp_socket_file();
649
650        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
651        server.start_server().unwrap();
652
653        // Test one incoming connection.
654        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
655        assert!(server.requests().unwrap().is_empty());
656
657        socket
658            .write_all(
659                b"PATCH /machine-config HTTP/1.1\r\n\
660                         Content-Length: 13\r\n\
661                         Content-Type: application/json\r\n\r\nwhatever body",
662            )
663            .unwrap();
664
665        let mut req_vec = server.requests().unwrap();
666        let server_request = req_vec.remove(0);
667
668        server
669            .respond(server_request.process(|_request| {
670                let mut response = Response::new(Version::Http11, StatusCode::OK);
671                let response_body = b"response body";
672                response.set_body(Body::new(response_body.to_vec()));
673                response
674            }))
675            .unwrap();
676        assert!(server.requests().unwrap().is_empty());
677
678        let mut buf: [u8; 1024] = [0; 1024];
679        assert!(socket.read(&mut buf[..]).unwrap() > 0);
680    }
681
682    #[test]
683    fn test_large_payload() {
684        let path_to_socket = get_temp_socket_file();
685
686        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
687        server.start_server().unwrap();
688
689        // Test one incoming connection.
690        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
691        assert!(server.requests().unwrap().is_empty());
692
693        let mut packets = String::from(
694            "PATCH /machine-config HTTP/1.1\r\n\
695                         Content-Length: 1028\r\n\
696                         Content-Type: application/json\r\n\r\n",
697        );
698        for i in 0..1028 {
699            packets.push_str(&i.to_string());
700        }
701
702        socket.write_all(packets.as_bytes()).unwrap();
703
704        let mut req_vec = server.requests().unwrap();
705        let server_request = req_vec.remove(0);
706
707        server
708            .respond(server_request.process(|_request| {
709                let mut response = Response::new(Version::Http11, StatusCode::OK);
710                let response_body = b"response body";
711                response.set_body(Body::new(response_body.to_vec()));
712                response
713            }))
714            .unwrap();
715        assert!(server.requests().unwrap().is_empty());
716
717        let mut buf: [u8; 1024] = [0; 1024];
718        assert!(socket.read(&mut buf[..]).unwrap() > 0);
719    }
720
721    #[test]
722    fn test_connection_size_limit_exceeded() {
723        let path_to_socket = get_temp_socket_file();
724
725        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
726        server.start_server().unwrap();
727
728        // Test one incoming connection.
729        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
730        assert!(server.requests().unwrap().is_empty());
731
732        socket
733            .write_all(
734                b"PATCH /machine-config HTTP/1.1\r\n\
735                         Content-Length: 51201\r\n\
736                         Content-Type: application/json\r\n\r\naaaaa",
737            )
738            .unwrap();
739        assert!(server.requests().unwrap().is_empty());
740        assert!(server.requests().unwrap().is_empty());
741        let mut buf: [u8; 265] = [0; 265];
742        assert!(socket.read(&mut buf[..]).unwrap() > 0);
743        let error_message = b"HTTP/1.1 400 \r\n\
744                              Server: Firecracker API\r\n\
745                              Connection: keep-alive\r\n\
746                              Content-Type: application/json\r\n\
747                              Content-Length: 149\r\n\r\n{ \"error\": \"\
748                              Request payload with size 51201 is larger than \
749                              the limit of 51200 allowed by server.\nAll \
750                              previous unanswered requests will be dropped.";
751        assert_eq!(&buf[..], &error_message[..]);
752    }
753
754    #[test]
755    fn test_set_payload_size() {
756        let path_to_socket = get_temp_socket_file();
757
758        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
759        server.start_server().unwrap();
760        server.set_payload_max_size(4);
761
762        // Test one incoming connection.
763        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
764        assert!(server.requests().unwrap().is_empty());
765
766        socket
767            .write_all(
768                b"PATCH /machine-config HTTP/1.1\r\n\
769                         Content-Length: 5\r\n\
770                         Content-Type: application/json\r\n\r\naaaaa",
771            )
772            .unwrap();
773        assert!(server.requests().unwrap().is_empty());
774        assert!(server.requests().unwrap().is_empty());
775        let mut buf: [u8; 260] = [0; 260];
776        assert!(socket.read(&mut buf[..]).unwrap() > 0);
777        let error_message = b"HTTP/1.1 400 \r\n\
778                              Server: Firecracker API\r\n\
779                              Connection: keep-alive\r\n\
780                              Content-Type: application/json\r\n\
781                              Content-Length: 141\r\n\r\n{ \"error\": \"\
782                              Request payload with size 5 is larger than the \
783                              limit of 4 allowed by server.\nAll previous \
784                              unanswered requests will be dropped.\" }";
785        assert_eq!(&buf[..], &error_message[..]);
786    }
787
788    #[test]
789    fn test_wait_one_fd_connection() {
790        use std::os::unix::io::IntoRawFd;
791        let path_to_socket = get_temp_socket_file();
792
793        let socket_listener = UnixListener::bind(path_to_socket.as_path()).unwrap();
794        let socket_fd = socket_listener.into_raw_fd();
795
796        let mut server = HttpServer::new_from_fd(socket_fd).unwrap();
797        server.start_server().unwrap();
798
799        // Test one incoming connection.
800        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
801        assert!(server.requests().unwrap().is_empty());
802
803        socket
804            .write_all(
805                b"PATCH /machine-config HTTP/1.1\r\n\
806                         Content-Length: 13\r\n\
807                         Content-Type: application/json\r\n\r\nwhatever body",
808            )
809            .unwrap();
810
811        let mut req_vec = server.requests().unwrap();
812        let server_request = req_vec.remove(0);
813
814        server
815            .respond(server_request.process(|request| {
816                assert_eq!(
817                    std::str::from_utf8(&request.body.as_ref().unwrap().body).unwrap(),
818                    "whatever body"
819                );
820                let mut response = Response::new(Version::Http11, StatusCode::OK);
821                let response_body = b"response body";
822                response.set_body(Body::new(response_body.to_vec()));
823                response
824            }))
825            .unwrap();
826        assert!(server.requests().unwrap().is_empty());
827
828        let mut buf: [u8; 1024] = [0; 1024];
829        assert!(socket.read(&mut buf[..]).unwrap() > 0);
830        assert!(String::from_utf8_lossy(&buf).contains("response body"));
831    }
832
833    #[test]
834    fn test_wait_concurrent_connections() {
835        let path_to_socket = get_temp_socket_file();
836
837        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
838        server.start_server().unwrap();
839
840        // Test two concurrent connections.
841        let mut first_socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
842        assert!(server.requests().unwrap().is_empty());
843
844        first_socket
845            .write_all(
846                b"PATCH /machine-config HTTP/1.1\r\n\
847                               Content-Length: 13\r\n\
848                               Content-Type: application/json\r\n\r\nwhatever body",
849            )
850            .unwrap();
851        let mut second_socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
852
853        let mut req_vec = server.requests().unwrap();
854        let server_request = req_vec.remove(0);
855
856        server
857            .respond(server_request.process(|_request| {
858                let mut response = Response::new(Version::Http11, StatusCode::OK);
859                let response_body = b"response body";
860                response.set_body(Body::new(response_body.to_vec()));
861                response
862            }))
863            .unwrap();
864        second_socket
865            .write_all(
866                b"GET /machine-config HTTP/1.1\r\n\
867                                Content-Type: application/json\r\n\r\n",
868            )
869            .unwrap();
870
871        let mut req_vec = server.requests().unwrap();
872        let second_server_request = req_vec.remove(0);
873
874        assert_eq!(
875            second_server_request.request,
876            Request::try_from(
877                b"GET /machine-config HTTP/1.1\r\n\
878            Content-Type: application/json\r\n\r\n",
879                None
880            )
881            .unwrap()
882        );
883
884        let mut buf: [u8; 1024] = [0; 1024];
885        assert!(first_socket.read(&mut buf[..]).unwrap() > 0);
886        first_socket.shutdown(std::net::Shutdown::Both).unwrap();
887
888        server
889            .respond(second_server_request.process(|_request| {
890                let mut response = Response::new(Version::Http11, StatusCode::OK);
891                let response_body = b"response second body";
892                response.set_body(Body::new(response_body.to_vec()));
893                response
894            }))
895            .unwrap();
896
897        assert!(server.requests().unwrap().is_empty());
898        let mut buf: [u8; 1024] = [0; 1024];
899        assert!(second_socket.read(&mut buf[..]).unwrap() > 0);
900        second_socket.shutdown(std::net::Shutdown::Both).unwrap();
901        assert!(server.requests().unwrap().is_empty());
902    }
903
904    #[test]
905    fn test_wait_expect_connection() {
906        let path_to_socket = get_temp_socket_file();
907
908        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
909        server.start_server().unwrap();
910
911        // Test one incoming connection with `Expect: 100-continue`.
912        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
913        assert!(server.requests().unwrap().is_empty());
914
915        socket
916            .write_all(
917                b"PATCH /machine-config HTTP/1.1\r\n\
918                         Content-Length: 13\r\n\
919                         Expect: 100-continue\r\n\r\n",
920            )
921            .unwrap();
922        // `wait` on server to receive what the client set on the socket.
923        // This will set the stream direction to `Outgoing`, as we need to send a `100 CONTINUE` response.
924        let req_vec = server.requests().unwrap();
925        assert!(req_vec.is_empty());
926        // Another `wait`, this time to send the response.
927        // Will be called because of an `EPOLLOUT` notification.
928        let req_vec = server.requests().unwrap();
929        assert!(req_vec.is_empty());
930        let mut buf: [u8; 1024] = [0; 1024];
931        assert!(socket.read(&mut buf[..]).unwrap() > 0);
932
933        socket.write_all(b"whatever body").unwrap();
934        let mut req_vec = server.requests().unwrap();
935        let server_request = req_vec.remove(0);
936
937        server
938            .respond(server_request.process(|_request| {
939                let mut response = Response::new(Version::Http11, StatusCode::OK);
940                let response_body = b"response body";
941                response.set_body(Body::new(response_body.to_vec()));
942                response
943            }))
944            .unwrap();
945
946        let req_vec = server.requests().unwrap();
947        assert!(req_vec.is_empty());
948
949        let mut buf: [u8; 1024] = [0; 1024];
950        assert!(socket.read(&mut buf[..]).unwrap() > 0);
951    }
952
953    #[test]
954    fn test_wait_many_connections() {
955        let path_to_socket = get_temp_socket_file();
956
957        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
958        server.start_server().unwrap();
959
960        let mut sockets: Vec<UnixStream> = Vec::with_capacity(MAX_CONNECTIONS + 1);
961        for _ in 0..MAX_CONNECTIONS {
962            sockets.push(UnixStream::connect(path_to_socket.as_path()).unwrap());
963            assert!(server.requests().unwrap().is_empty());
964        }
965
966        sockets.push(UnixStream::connect(path_to_socket.as_path()).unwrap());
967        assert!(server.requests().unwrap().is_empty());
968        let mut buf: [u8; 120] = [0; 120];
969        sockets[MAX_CONNECTIONS].read_exact(&mut buf).unwrap();
970        assert_eq!(&buf[..], SERVER_FULL_ERROR_MESSAGE);
971        assert_eq!(server.connections.len(), MAX_CONNECTIONS);
972        {
973            // Drop this stream.
974            let _refused_stream = sockets.pop().unwrap();
975        }
976        assert_eq!(server.connections.len(), MAX_CONNECTIONS);
977
978        // Check that the server detects a connection shutdown.
979        let sock: &UnixStream = sockets.get(0).unwrap();
980        sock.shutdown(Shutdown::Both).unwrap();
981        assert!(server.requests().unwrap().is_empty());
982        // Server should drop a closed connection.
983        assert_eq!(server.connections.len(), MAX_CONNECTIONS - 1);
984
985        // Close the backing FD of this connection by dropping
986        // it out of scope.
987        {
988            // Enforce the drop call on the stream
989            let _sock = sockets.pop().unwrap();
990        }
991        assert!(server.requests().unwrap().is_empty());
992        // Server should drop a closed connection.
993        assert_eq!(server.connections.len(), MAX_CONNECTIONS - 2);
994
995        let sock: &UnixStream = sockets.get(1).unwrap();
996        // Close both the read and write sides of the socket
997        // separately and check that the server detects it.
998        sock.shutdown(Shutdown::Read).unwrap();
999        sock.shutdown(Shutdown::Write).unwrap();
1000        assert!(server.requests().unwrap().is_empty());
1001        // Server should drop a closed connection.
1002        assert_eq!(server.connections.len(), MAX_CONNECTIONS - 3);
1003    }
1004
1005    #[test]
1006    fn test_wait_parse_error() {
1007        let path_to_socket = get_temp_socket_file();
1008
1009        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
1010        server.start_server().unwrap();
1011
1012        // Test one incoming connection.
1013        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
1014        socket.set_nonblocking(true).unwrap();
1015        assert!(server.requests().unwrap().is_empty());
1016
1017        socket
1018            .write_all(
1019                b"PATCH /machine-config HTTP/1.1\r\n\
1020                         Content-Length: alpha\r\n\
1021                         Content-Type: application/json\r\n\r\nwhatever body",
1022            )
1023            .unwrap();
1024
1025        assert!(server.requests().unwrap().is_empty());
1026        assert!(server.requests().unwrap().is_empty());
1027        let mut buf: [u8; 255] = [0; 255];
1028        assert!(socket.read(&mut buf[..]).unwrap() > 0);
1029        let error_message = b"HTTP/1.1 400 \r\n\
1030                              Server: Firecracker API\r\n\
1031                              Connection: keep-alive\r\n\
1032                              Content-Type: application/json\r\n\
1033                              Content-Length: 136\r\n\r\n{ \"error\": \"Invalid header. \
1034                              Reason: Invalid value. Key:Content-Length; Value: alpha\nAll previous unanswered requests will be dropped.\" }";
1035        assert_eq!(&buf[..], &error_message[..]);
1036
1037        socket
1038            .write_all(
1039                b"PATCH /machine-config HTTP/1.1\r\n\
1040                         Content-Length: alpha\r\n\
1041                         Content-Type: application/json\r\n\r\nwhatever body",
1042            )
1043            .unwrap();
1044        assert!(server.requests().unwrap().is_empty());
1045        assert!(server.requests().unwrap().is_empty());
1046    }
1047
1048    #[test]
1049    fn test_wait_in_flight_responses() {
1050        let path_to_socket = get_temp_socket_file();
1051
1052        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
1053        server.start_server().unwrap();
1054
1055        // Test a connection dropped and then a new one appearing
1056        // before the user had a chance to send the response to the
1057        // first one.
1058        let mut first_socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
1059        assert!(server.requests().unwrap().is_empty());
1060
1061        first_socket
1062            .write_all(
1063                b"PATCH /machine-config HTTP/1.1\r\n\
1064                               Content-Length: 13\r\n\
1065                               Content-Type: application/json\r\n\r\nwhatever body",
1066            )
1067            .unwrap();
1068
1069        let mut req_vec = server.requests().unwrap();
1070        let server_request = req_vec.remove(0);
1071
1072        first_socket.shutdown(std::net::Shutdown::Both).unwrap();
1073        assert!(server.requests().unwrap().is_empty());
1074        assert_eq!(server.connections.len(), 1);
1075        let mut second_socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
1076        second_socket.set_nonblocking(true).unwrap();
1077        assert!(server.requests().unwrap().is_empty());
1078        assert_eq!(server.connections.len(), 2);
1079
1080        server
1081            .enqueue_responses(vec![server_request.process(|_request| {
1082                let mut response = Response::new(Version::Http11, StatusCode::OK);
1083                let response_body = b"response body";
1084                response.set_body(Body::new(response_body.to_vec()));
1085                response
1086            })])
1087            .unwrap();
1088
1089        // assert!(server.requests().unwrap().is_empty());
1090        assert_eq!(server.connections.len(), 2);
1091        let mut buf: [u8; 1024] = [0; 1024];
1092        assert!(second_socket.read(&mut buf[..]).is_err());
1093
1094        second_socket
1095            .write_all(
1096                b"GET /machine-config HTTP/1.1\r\n\
1097                                Content-Type: application/json\r\n\r\n",
1098            )
1099            .unwrap();
1100
1101        let mut req_vec = server.requests().unwrap();
1102        let second_server_request = req_vec.remove(0);
1103
1104        assert_eq!(server.connections.len(), 1);
1105        assert_eq!(
1106            second_server_request.request,
1107            Request::try_from(
1108                b"GET /machine-config HTTP/1.1\r\n\
1109            Content-Type: application/json\r\n\r\n",
1110                None
1111            )
1112            .unwrap()
1113        );
1114
1115        server
1116            .respond(second_server_request.process(|_request| {
1117                let mut response = Response::new(Version::Http11, StatusCode::OK);
1118                let response_body = b"response second body";
1119                response.set_body(Body::new(response_body.to_vec()));
1120                response
1121            }))
1122            .unwrap();
1123
1124        assert!(server.requests().unwrap().is_empty());
1125        let mut buf: [u8; 1024] = [0; 1024];
1126        assert!(second_socket.read(&mut buf[..]).unwrap() > 0);
1127        second_socket.shutdown(std::net::Shutdown::Both).unwrap();
1128        assert!(server.requests().is_ok());
1129    }
1130
1131    #[test]
1132    fn test_wait_two_messages() {
1133        let path_to_socket = get_temp_socket_file();
1134
1135        let mut server = HttpServer::new(path_to_socket.as_path()).unwrap();
1136        server.start_server().unwrap();
1137
1138        // Test one incoming connection.
1139        let mut socket = UnixStream::connect(path_to_socket.as_path()).unwrap();
1140        assert!(server.requests().unwrap().is_empty());
1141
1142        socket
1143            .write_all(
1144                b"PATCH /machine-config HTTP/1.1\r\n\
1145                         Content-Length: 13\r\n\
1146                         Content-Type: application/json\r\n\r\nwhatever body",
1147            )
1148            .unwrap();
1149
1150        let mut req_vec = server.requests().unwrap();
1151        let server_request = req_vec.remove(0);
1152
1153        socket
1154            .write_all(
1155                b"PATCH /machine-config HTTP/1.1\r\n\
1156                         Content-Length: 13\r\n\
1157                         Content-Type: application/json\r\n\r\nwhatever body",
1158            )
1159            .unwrap();
1160
1161        server
1162            .respond(server_request.process(|_request| {
1163                let mut response = Response::new(Version::Http11, StatusCode::OK);
1164                let response_body = b"response body";
1165                response.set_body(Body::new(response_body.to_vec()));
1166                response
1167            }))
1168            .unwrap();
1169        assert!(server.requests().unwrap().is_empty());
1170
1171        let mut buf: [u8; 1024] = [0; 1024];
1172        assert!(socket.read(&mut buf[..]).unwrap() > 0);
1173
1174        let mut req_vec = server.requests().unwrap();
1175        let server_request = req_vec.remove(0);
1176        server
1177            .respond(server_request.process(|_request| {
1178                let mut response = Response::new(Version::Http11, StatusCode::OK);
1179                let response_body = b"response body";
1180                response.set_body(Body::new(response_body.to_vec()));
1181                response
1182            }))
1183            .unwrap();
1184        assert!(server.requests().unwrap().is_empty());
1185
1186        let mut buf: [u8; 1024] = [0; 1024];
1187        assert!(socket.read(&mut buf[..]).unwrap() > 0);
1188    }
1189}