1use std::net::{SocketAddr, TcpListener};
2#[cfg(unix)]
3use std::os::unix::net::UnixListener;
4use std::path::{Path, PathBuf};
5use std::sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8 mpsc::{Receiver, sync_channel},
9};
10use std::thread::{self, JoinHandle};
11
12use anyhow::{Context, Result, anyhow};
13use tempfile::TempDir;
14
15use crate::pglite::base::{install_into, install_temporary_from_template};
16use crate::pglite::proxy::PgliteProxy;
17
/// Handle to a PGlite server running on a background thread.
///
/// Created by [`PgliteServerBuilder::start`]. Call [`PgliteServer::shutdown`]
/// to stop the server and observe any error returned by its thread.
#[derive(Debug)]
pub struct PgliteServer {
    // Directory containing the PGlite installation served by this instance.
    root: PathBuf,
    // Owns the temporary root when one was created; `None` for caller-supplied paths.
    // Fields drop in declaration order, so this is dropped before `handle`.
    _temp_dir: Option<TempDir>,
    // Where clients connect: the bound TCP address or the Unix socket path.
    endpoint: ServerEndpoint,
    // Shared with the server thread; storing `true` asks it to stop.
    shutdown: Arc<AtomicBool>,
    // Server thread handle; `stop` takes it (leaving `None`) and joins it.
    handle: Option<JoinHandle<Result<()>>>,
}
32
/// Endpoint a running server is actually listening on.
#[derive(Debug, Clone)]
enum ServerEndpoint {
    /// Address read back from the bound listener (the real port when `0` was requested).
    Tcp(SocketAddr),
    /// Filesystem path of the bound Unix-domain socket.
    #[cfg(unix)]
    Unix(PathBuf),
}
39
40impl PgliteServer {
41 pub fn builder() -> PgliteServerBuilder {
44 PgliteServerBuilder::new()
45 }
46
47 pub fn temporary_tcp() -> Result<Self> {
49 Self::builder().temporary().start()
50 }
51
52 pub fn root(&self) -> &Path {
54 &self.root
55 }
56
57 pub fn tcp_addr(&self) -> Option<SocketAddr> {
59 match self.endpoint {
60 ServerEndpoint::Tcp(addr) => Some(addr),
61 #[cfg(unix)]
62 ServerEndpoint::Unix(_) => None,
63 }
64 }
65
66 #[cfg(unix)]
68 pub fn socket_path(&self) -> Option<&Path> {
69 match &self.endpoint {
70 ServerEndpoint::Tcp(_) => None,
71 ServerEndpoint::Unix(path) => Some(path),
72 }
73 }
74
75 pub fn connection_uri(&self) -> String {
77 match &self.endpoint {
78 ServerEndpoint::Tcp(addr) => tcp_connection_uri(*addr),
79 #[cfg(unix)]
80 ServerEndpoint::Unix(path) => {
81 let host = path.parent().unwrap_or_else(|| Path::new("/tmp"));
82 let port = parse_unix_socket_port(path).unwrap_or(5432);
83 format!(
84 "postgresql://postgres@/template1?host={}&port={}&sslmode=disable",
85 host.display(),
86 port
87 )
88 }
89 }
90 }
91
92 pub fn shutdown(mut self) -> Result<()> {
98 self.stop()
99 }
100
101 fn stop(&mut self) -> Result<()> {
102 self.shutdown.store(true, Ordering::SeqCst);
103 if let Some(handle) = self.handle.take() {
104 handle
105 .join()
106 .map_err(|_| anyhow!("pglite server thread panicked"))??;
107 }
108 Ok(())
109 }
110}
111
112impl Drop for PgliteServer {
113 fn drop(&mut self) {
114 self.shutdown.store(true, Ordering::SeqCst);
115 }
116}
117
/// Configures and starts a [`PgliteServer`].
///
/// By default the server uses a temporary root populated from the template
/// cache. It is served over TCP on `127.0.0.1` with port `0`, so the OS picks
/// the port.
#[derive(Debug, Clone)]
pub struct PgliteServerBuilder {
    // Where the PGlite installation lives and how `start` prepares it.
    root: ServerRoot,
    // Which listener `start` binds.
    endpoint: ServerEndpointConfig,
}
124
/// How [`PgliteServerBuilder::start`] obtains the server's root directory.
#[derive(Debug, Clone)]
enum ServerRoot {
    /// A new temporary directory owned (and eventually deleted) with the server.
    /// With `template_cache`, it is created by `install_temporary_from_template`
    /// on a worker thread. Otherwise an empty temp dir is filled by `install_into`.
    Temporary { template_cache: bool },
    /// A caller-owned directory: `install_into` runs on it and the server does
    /// not delete it.
    Path(PathBuf),
}
130
/// Listener requested through the builder, before anything is bound.
#[derive(Debug, Clone)]
enum ServerEndpointConfig {
    /// Address to bind; port `0` lets the OS choose.
    Tcp(SocketAddr),
    /// Socket path to bind; missing parent directories are created first.
    #[cfg(unix)]
    Unix(PathBuf),
}
137
138impl Default for PgliteServerBuilder {
139 fn default() -> Self {
140 Self {
141 root: ServerRoot::Temporary {
142 template_cache: true,
143 },
144 endpoint: ServerEndpointConfig::Tcp(SocketAddr::from(([127, 0, 0, 1], 0))),
145 }
146 }
147}
148
149impl PgliteServerBuilder {
150 pub fn new() -> Self {
153 Self::default()
154 }
155
156 pub fn path(mut self, root: impl Into<PathBuf>) -> Self {
158 self.root = ServerRoot::Path(root.into());
159 self
160 }
161
162 pub fn temporary(mut self) -> Self {
164 self.root = ServerRoot::Temporary {
165 template_cache: true,
166 };
167 self
168 }
169
170 pub fn fresh_temporary(mut self) -> Self {
172 self.root = ServerRoot::Temporary {
173 template_cache: false,
174 };
175 self
176 }
177
178 pub fn tcp(mut self, addr: SocketAddr) -> Self {
180 self.endpoint = ServerEndpointConfig::Tcp(addr);
181 self
182 }
183
184 #[cfg(unix)]
186 pub fn unix(mut self, path: impl Into<PathBuf>) -> Self {
187 self.endpoint = ServerEndpointConfig::Unix(path.into());
188 self
189 }
190
191 pub fn start(self) -> Result<PgliteServer> {
193 let (root, temp_dir) = match self.root {
194 ServerRoot::Path(root) => {
195 install_into(&root)?;
196 (root, None)
197 }
198 ServerRoot::Temporary { template_cache } => {
199 if template_cache {
200 let (root, temp_dir) = prepare_cached_temporary_root()?;
201 (root, Some(temp_dir))
202 } else {
203 let temp_dir = TempDir::new().context("create temporary pglite directory")?;
204 install_into(temp_dir.path())?;
205 (temp_dir.path().to_path_buf(), Some(temp_dir))
206 }
207 }
208 };
209
210 let shutdown = Arc::new(AtomicBool::new(false));
211 let proxy = PgliteProxy::new(root.clone());
212
213 let (endpoint, handle) = match self.endpoint {
214 ServerEndpointConfig::Tcp(addr) => start_tcp(proxy, addr, shutdown.clone())?,
215 #[cfg(unix)]
216 ServerEndpointConfig::Unix(path) => start_unix(proxy, path, shutdown.clone())?,
217 };
218
219 Ok(PgliteServer {
220 root,
221 _temp_dir: temp_dir,
222 endpoint,
223 shutdown,
224 handle: Some(handle),
225 })
226 }
227}
228
229fn start_tcp(
230 proxy: PgliteProxy,
231 addr: SocketAddr,
232 shutdown: Arc<AtomicBool>,
233) -> Result<(ServerEndpoint, JoinHandle<Result<()>>)> {
234 let listener = TcpListener::bind(addr).context("bind PGlite TCP server")?;
235 let addr = listener.local_addr().context("read PGlite TCP address")?;
236 let (ready_tx, ready_rx) = sync_channel(1);
237 let handle = thread::spawn(move || {
238 proxy.serve_tcp_listener_until_ready(listener, shutdown, Some(ready_tx))
239 });
240 wait_until_ready(&ready_rx)?;
241 Ok((ServerEndpoint::Tcp(addr), handle))
242}
243
/// Formats a connection URI for user `postgres` and database `template1` over TCP.
///
/// IPv6 addresses are bracketed. Only the IP and port are used, so an IPv6
/// socket address's flow info and scope id are omitted.
fn tcp_connection_uri(addr: SocketAddr) -> String {
    let host = match addr {
        SocketAddr::V4(v4) => v4.ip().to_string(),
        SocketAddr::V6(v6) => format!("[{}]", v6.ip()),
    };
    format!(
        "postgresql://postgres@{}:{}/template1?sslmode=disable",
        host,
        addr.port()
    )
}
262
263fn prepare_cached_temporary_root() -> Result<(PathBuf, TempDir)> {
264 run_blocking("pglite-template-cache", || {
265 let (temp_dir, _outcome) = install_temporary_from_template()?;
266 Ok((temp_dir.path().to_path_buf(), temp_dir))
267 })
268}
269
270fn run_blocking<T, F>(name: &'static str, f: F) -> Result<T>
271where
272 T: Send + 'static,
273 F: FnOnce() -> Result<T> + Send + 'static,
274{
275 thread::Builder::new()
276 .name(name.to_string())
277 .spawn(f)
278 .with_context(|| format!("spawn {name} worker"))?
279 .join()
280 .map_err(|_| anyhow!("{name} worker panicked"))?
281}
282
283#[cfg(unix)]
284fn start_unix(
285 proxy: PgliteProxy,
286 path: PathBuf,
287 shutdown: Arc<AtomicBool>,
288) -> Result<(ServerEndpoint, JoinHandle<Result<()>>)> {
289 if path.exists() {
290 std::fs::remove_file(&path)
291 .with_context(|| format!("remove stale socket {}", path.display()))?;
292 }
293 if let Some(parent) = path.parent() {
294 std::fs::create_dir_all(parent)
295 .with_context(|| format!("create socket directory {}", parent.display()))?;
296 }
297
298 let listener = UnixListener::bind(&path)
299 .with_context(|| format!("bind PGlite Unix socket {}", path.display()))?;
300 let endpoint = ServerEndpoint::Unix(path);
301 let (ready_tx, ready_rx) = sync_channel(1);
302 let handle = thread::spawn(move || {
303 proxy.serve_unix_listener_until_ready(listener, shutdown, Some(ready_tx))
304 });
305 wait_until_ready(&ready_rx)?;
306 Ok((endpoint, handle))
307}
308
309fn wait_until_ready(ready_rx: &Receiver<Result<()>>) -> Result<()> {
310 ready_rx
311 .recv()
312 .context("PGlite server thread exited before reporting readiness")?
313}
314
/// Extracts the port from a PostgreSQL-style socket file name `.s.PGSQL.<port>`.
///
/// Returns `None` when the path has no UTF-8 file name, lacks the prefix, or the
/// suffix is not a valid `u16`.
#[cfg(unix)]
fn parse_unix_socket_port(path: &Path) -> Option<u16> {
    path.file_name()
        .and_then(|name| name.to_str())
        .and_then(|name| name.strip_prefix(".s.PGSQL."))
        .and_then(|digits| digits.parse::<u16>().ok())
}