// httptest/server_pool.rs
1use crate::Server;
2use once_cell::sync::OnceCell;
3use std::marker::PhantomData;
4use std::ops::{Deref, DerefMut};
5use std::sync::Mutex;
6
7/// A pool of shared servers.
8///
9/// The typical way to use this library is to create a new Server for each test
10/// using [Server::run](struct.Server.html#method.run). This way each test
/// contains its own independent state. However for very large test suites it
12/// may be beneficial to share servers between test runs. This is typically only
13/// desirable when running into limits because the test framework spins up an
14/// independent thread for each test and system wide resources (TCP ports) may
15/// become scarce. In those cases you can opt into using a shared ServerPool that
16/// will create a maximum of N servers that tests can share. Invoking
17/// [get_server](#method.get_server) on the pool will return a
18/// [ServerHandle](struct.ServerHandle.html) that deref's to a
19/// [Server](struct.Server.html). For the life of the
20/// [ServerHandle](struct.ServerHandle.html) it has unique access to this server
21/// instance. When the handle is dropped the server expectations are asserted and
22/// cleared and the server is returned back into the
23/// [ServerPool](struct.ServerPool.html) for use by another test.
24///
25/// Example:
26///
27/// ```
/// # use httptest::{ServerPool, Expectation, matchers::any, responders::status_code};
29/// // Create a server pool that will create at most 99 servers.
30/// static SERVER_POOL: ServerPool = ServerPool::new(99);
31///
32/// #[test]
33/// fn test_one() {
34/// let server = SERVER_POOL.get_server();
35/// server.expect(Expectation::matching(any()).respond_with(status_code(200)));
36/// // invoke http requests to server.
37///
38/// // server will assert expectations are met on drop.
39/// }
40///
41/// #[test]
42/// fn test_two() {
43/// let server = SERVER_POOL.get_server();
44/// server.expect(Expectation::matching(any()).respond_with(status_code(200)));
45/// // invoke http requests to server.
46///
47/// // server will assert expectations are met on drop.
48/// }
49/// ```
50
/// A pool of running servers.
///
/// Field 0 is the lazily-initialized inner pool (created on first
/// `get_server` call); field 1 is the maximum number of servers the pool
/// may ever create.
#[derive(Debug)]
pub struct ServerPool(OnceCell<InnerPool>, usize);
54
55impl ServerPool {
56 /// Create a new pool of servers.
57 ///
58 /// `max_servers` is the maximum number of servers that will be created.
59 /// servers are created on-demand when `get_server` is invoked.
60 pub const fn new(max_servers: usize) -> Self {
61 ServerPool(OnceCell::new(), max_servers)
62 }
63
64 /// Get the next available server from the pool.
65 pub fn get_server(&self) -> ServerHandle {
66 self.0.get_or_init(|| InnerPool::new(self.1)).get_server()
67 }
68}
69
#[allow(clippy::mutex_atomic)]
#[derive(Debug)]
struct InnerPool {
    // Count of servers created so far. A Mutex (not an atomic) is used
    // deliberately so `get_server` can check the count against capacity and
    // increment it in one critical section — hence the clippy allow above.
    servers_created: Mutex<usize>,
    // Idle servers are returned to the pool through this bounded channel;
    // the channel's capacity doubles as the pool's maximum server count.
    servers_tx: crossbeam_channel::Sender<Server>,
    servers_rx: crossbeam_channel::Receiver<Server>,
}
77
78#[allow(clippy::mutex_atomic)]
79impl InnerPool {
80 fn new(max_capacity: usize) -> Self {
81 assert!(max_capacity > 0);
82 let (servers_tx, servers_rx) = crossbeam_channel::bounded(max_capacity);
83 InnerPool {
84 servers_created: Mutex::new(0),
85 servers_tx,
86 servers_rx,
87 }
88 }
89
90 fn get_server(&self) -> ServerHandle {
91 if let Ok(server) = self.servers_rx.try_recv() {
92 return ServerHandle {
93 servers_tx: self.servers_tx.clone(),
94 server: Some(server),
95 lifetime_marker: PhantomData,
96 };
97 }
98 {
99 let mut servers_created = self.servers_created.lock().expect("poisoned mutex");
100 if *servers_created < self.servers_tx.capacity().unwrap() {
101 *servers_created += 1;
102 return ServerHandle {
103 servers_tx: self.servers_tx.clone(),
104 server: Some(Server::run()),
105 lifetime_marker: PhantomData,
106 };
107 }
108 }
109 ServerHandle {
110 servers_tx: self.servers_tx.clone(),
111 server: Some(
112 self.servers_rx
113 .recv()
114 .expect("all senders unexpectedly dropped"),
115 ),
116 lifetime_marker: PhantomData,
117 }
118 }
119}
120
121#[allow(clippy::mutex_atomic)]
122impl Drop for InnerPool {
123 fn drop(&mut self) {
124 // wait for all created servers to get returned to the pool.
125 let servers_created = self.servers_created.lock().expect("poisoned mutex");
126 for _ in 0..*servers_created {
127 self.servers_rx
128 .recv()
129 .expect("all senders unexpectedly dropped");
130 }
131 }
132}
133
/// A handle to a server. Expectations are asserted and cleared when the
/// handle is dropped, and the server is then returned to its pool.
#[derive(Debug)]
pub struct ServerHandle<'a> {
    // Channel used to return the server to its pool when this handle drops.
    servers_tx: crossbeam_channel::Sender<Server>,
    // Always `Some` until `drop`, which takes the server out to return it.
    server: Option<Server>,

    // We add a lifetime to the ServerHandle just to restrict the ownership
    // beyond what the implementation currently allows. The public facing API
    // will appear that the ServerHandle is borrowed from a Pool, which enforces
    // the desired behavior. No ServerHandle should outlive the underlying Pool.
    // The current implementation passes around owned values so there is no real
    // borrowing constraint, but a future implementation may do something more
    // efficient.
    lifetime_marker: PhantomData<&'a ()>,
}
149
150impl Deref for ServerHandle<'_> {
151 type Target = Server;
152
153 fn deref(&self) -> &Server {
154 self.server.as_ref().unwrap()
155 }
156}
157
158impl DerefMut for ServerHandle<'_> {
159 fn deref_mut(&mut self) -> &mut Server {
160 self.server.as_mut().unwrap()
161 }
162}
163
impl Drop for ServerHandle<'_> {
    // On drop: verify the server's expectations, clear its state, and send
    // it back to the pool for reuse by another test.
    fn drop(&mut self) {
        // Take the server out (leaving None) so ownership can be moved back
        // into the pool's channel.
        let mut server = self.server.take().unwrap();
        // NOTE(review): if verify_and_clear panics (unmet expectations), the
        // send below never runs, the server is lost from the pool, and
        // InnerPool::drop would block forever waiting for its return —
        // acceptable for test code, but worth knowing.
        server.verify_and_clear();
        self.servers_tx
            .send(server)
            .expect("all receivers unexpectedly dropped");
    }
}
173
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_SERVERS: usize = 5;
    static POOL: ServerPool = ServerPool::new(MAX_SERVERS);

    /// Spawn more threads than the pool's capacity and verify that the number
    /// of concurrently held server handles never exceeds MAX_SERVERS.
    #[test]
    fn test_max_threads() {
        use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
        let concurrent_server_handles = AtomicUsize::new(0);
        let desired_concurrency_reached = std::sync::Barrier::new(MAX_SERVERS);
        crossbeam_utils::thread::scope(|s| {
            for _ in 0..10 {
                s.spawn(|_| {
                    let _server = POOL.get_server();

                    // Ensure that we've reached the desired number of concurrent servers.
                    desired_concurrency_reached.wait();

                    // Ensure that we have not exceeded the desired number of concurrent
                    // servers. `prev_value` is the handle count *before* this one was
                    // added, so `prev_value >= MAX_SERVERS` means this handle pushes the
                    // concurrent total past the limit. (The previous `>` comparison was
                    // off by one and silently tolerated MAX_SERVERS + 1 handles.)
                    let prev_value = concurrent_server_handles.fetch_add(1, SeqCst);
                    if prev_value >= MAX_SERVERS {
                        panic!("too many concurrent server handles: {}", prev_value + 1);
                    }
                    std::thread::sleep(std::time::Duration::from_millis(500));
                    concurrent_server_handles.fetch_sub(1, SeqCst);
                });
            }
        })
        .unwrap();
    }
}
206}