Skip to main content

pglite_oxide/pglite/
server.rs

1use std::net::{SocketAddr, TcpListener};
2#[cfg(unix)]
3use std::os::unix::net::UnixListener;
4use std::path::{Path, PathBuf};
5use std::sync::{
6    Arc,
7    atomic::{AtomicBool, Ordering},
8    mpsc::{Receiver, sync_channel},
9};
10use std::thread::{self, JoinHandle};
11
12use anyhow::{Context, Result, anyhow};
13use tempfile::TempDir;
14
15use crate::pglite::base::{install_into, install_temporary_from_template};
16use crate::pglite::proxy::PgliteProxy;
17
18/// A supervised local PostgreSQL socket backed by one embedded PGlite runtime.
19///
20/// This is the compatibility entry point for code that expects a PostgreSQL URL,
21/// such as `tokio-postgres`, SQLx, or tools that speak the wire protocol. The
22/// server owns one embedded backend, so downstream pools should use a single
23/// connection.
24#[derive(Debug)]
25pub struct PgliteServer {
26    root: PathBuf,
27    _temp_dir: Option<TempDir>,
28    endpoint: ServerEndpoint,
29    shutdown: Arc<AtomicBool>,
30    handle: Option<JoinHandle<Result<()>>>,
31}
32
/// Where a running [`PgliteServer`] is listening.
#[derive(Debug, Clone)]
enum ServerEndpoint {
    /// The bound TCP address, with the OS-assigned port filled in when port `0`
    /// was requested.
    Tcp(SocketAddr),
    /// Filesystem path of the bound Unix-domain socket.
    #[cfg(unix)]
    Unix(PathBuf),
}
39
40impl PgliteServer {
41    /// Build a local PGlite server. The default is a cached temporary database
42    /// served on `127.0.0.1:0`.
43    pub fn builder() -> PgliteServerBuilder {
44        PgliteServerBuilder::new()
45    }
46
47    /// Start a cached temporary database on a random local TCP port.
48    pub fn temporary_tcp() -> Result<Self> {
49        Self::builder().temporary().start()
50    }
51
52    /// Return the root directory used for runtime files and cluster data.
53    pub fn root(&self) -> &Path {
54        &self.root
55    }
56
57    /// Return the bound TCP address, if this server is using TCP.
58    pub fn tcp_addr(&self) -> Option<SocketAddr> {
59        match self.endpoint {
60            ServerEndpoint::Tcp(addr) => Some(addr),
61            #[cfg(unix)]
62            ServerEndpoint::Unix(_) => None,
63        }
64    }
65
66    /// Return the Unix-domain socket path, if this server is using UDS.
67    #[cfg(unix)]
68    pub fn socket_path(&self) -> Option<&Path> {
69        match &self.endpoint {
70            ServerEndpoint::Tcp(_) => None,
71            ServerEndpoint::Unix(path) => Some(path),
72        }
73    }
74
75    /// Return a PostgreSQL connection URI for the local server.
76    pub fn connection_uri(&self) -> String {
77        match &self.endpoint {
78            ServerEndpoint::Tcp(addr) => tcp_connection_uri(*addr),
79            #[cfg(unix)]
80            ServerEndpoint::Unix(path) => {
81                let host = path.parent().unwrap_or_else(|| Path::new("/tmp"));
82                let port = parse_unix_socket_port(path).unwrap_or(5432);
83                format!(
84                    "postgresql://postgres@/template1?host={}&port={}&sslmode=disable",
85                    percent_encode_query_value(&host.display().to_string()),
86                    port
87                )
88            }
89        }
90    }
91
92    /// Request shutdown and wait for the listener thread to exit.
93    ///
94    /// Close database clients before calling this method. The current proxy owns
95    /// one blocking backend connection at a time, so an open client can keep the
96    /// worker thread busy until it disconnects.
97    pub fn shutdown(mut self) -> Result<()> {
98        self.stop()
99    }
100
101    fn stop(&mut self) -> Result<()> {
102        self.shutdown.store(true, Ordering::SeqCst);
103        if let Some(handle) = self.handle.take() {
104            handle
105                .join()
106                .map_err(|_| anyhow!("pglite server thread panicked"))??;
107        }
108        Ok(())
109    }
110}
111
112impl Drop for PgliteServer {
113    fn drop(&mut self) {
114        self.shutdown.store(true, Ordering::SeqCst);
115    }
116}
117
118/// Builder for [`PgliteServer`].
119#[derive(Debug, Clone)]
120pub struct PgliteServerBuilder {
121    root: ServerRoot,
122    endpoint: ServerEndpointConfig,
123}
124
/// Where the builder places the runtime files and cluster data.
#[derive(Debug, Clone)]
enum ServerRoot {
    /// A temporary directory owned by the server. When `template_cache` is
    /// `true`, the directory is cloned from the process-local template cache;
    /// otherwise it is installed from scratch.
    Temporary { template_cache: bool },
    /// A caller-chosen, persistent directory.
    Path(PathBuf),
}
130
/// The endpoint requested through the builder, before anything is bound.
#[derive(Debug, Clone)]
enum ServerEndpointConfig {
    /// A TCP address to bind; port `0` lets the OS choose a free port.
    Tcp(SocketAddr),
    /// A Unix-domain socket path to bind.
    #[cfg(unix)]
    Unix(PathBuf),
}
137
138impl Default for PgliteServerBuilder {
139    fn default() -> Self {
140        Self {
141            root: ServerRoot::Temporary {
142                template_cache: true,
143            },
144            endpoint: ServerEndpointConfig::Tcp(SocketAddr::from(([127, 0, 0, 1], 0))),
145        }
146    }
147}
148
149impl PgliteServerBuilder {
150    /// Create a builder. Defaults to a cached temporary database on
151    /// `127.0.0.1:0`.
152    pub fn new() -> Self {
153        Self::default()
154    }
155
156    /// Serve a persistent database rooted at `root`.
157    pub fn path(mut self, root: impl Into<PathBuf>) -> Self {
158        self.root = ServerRoot::Path(root.into());
159        self
160    }
161
162    /// Serve a temporary database cloned from the process-local template cache.
163    pub fn temporary(mut self) -> Self {
164        self.root = ServerRoot::Temporary {
165            template_cache: true,
166        };
167        self
168    }
169
170    /// Serve a temporary database initialized without the template cache.
171    pub fn fresh_temporary(mut self) -> Self {
172        self.root = ServerRoot::Temporary {
173            template_cache: false,
174        };
175        self
176    }
177
178    /// Bind the server to a TCP address.
179    pub fn tcp(mut self, addr: SocketAddr) -> Self {
180        self.endpoint = ServerEndpointConfig::Tcp(addr);
181        self
182    }
183
184    /// Bind the server to a Unix-domain socket path.
185    #[cfg(unix)]
186    pub fn unix(mut self, path: impl Into<PathBuf>) -> Self {
187        self.endpoint = ServerEndpointConfig::Unix(path.into());
188        self
189    }
190
191    /// Install the runtime if needed, initialize the cluster, and start serving.
192    pub fn start(self) -> Result<PgliteServer> {
193        let (root, temp_dir) = match self.root {
194            ServerRoot::Path(root) => {
195                install_into(&root)?;
196                (root, None)
197            }
198            ServerRoot::Temporary { template_cache } => {
199                if template_cache {
200                    let (root, temp_dir) = prepare_cached_temporary_root()?;
201                    (root, Some(temp_dir))
202                } else {
203                    let temp_dir = TempDir::new().context("create temporary pglite directory")?;
204                    install_into(temp_dir.path())?;
205                    (temp_dir.path().to_path_buf(), Some(temp_dir))
206                }
207            }
208        };
209
210        let shutdown = Arc::new(AtomicBool::new(false));
211        let proxy = PgliteProxy::new(root.clone());
212
213        let (endpoint, handle) = match self.endpoint {
214            ServerEndpointConfig::Tcp(addr) => start_tcp(proxy, addr, shutdown.clone())?,
215            #[cfg(unix)]
216            ServerEndpointConfig::Unix(path) => start_unix(proxy, path, shutdown.clone())?,
217        };
218
219        Ok(PgliteServer {
220            root,
221            _temp_dir: temp_dir,
222            endpoint,
223            shutdown,
224            handle: Some(handle),
225        })
226    }
227}
228
229fn start_tcp(
230    proxy: PgliteProxy,
231    addr: SocketAddr,
232    shutdown: Arc<AtomicBool>,
233) -> Result<(ServerEndpoint, JoinHandle<Result<()>>)> {
234    let listener = TcpListener::bind(addr).context("bind PGlite TCP server")?;
235    let addr = listener.local_addr().context("read PGlite TCP address")?;
236    let (ready_tx, ready_rx) = sync_channel(1);
237    let handle = thread::spawn(move || {
238        proxy.serve_tcp_listener_until_ready(listener, shutdown, Some(ready_tx))
239    });
240    wait_until_ready(&ready_rx)?;
241    Ok((ServerEndpoint::Tcp(addr), handle))
242}
243
/// Format a `postgresql://` URI for a bound TCP address. The URI uses user
/// `postgres`, database `template1`, and `sslmode=disable`.
///
/// A wildcard bind address (`0.0.0.0` or `::`) accepts connections on every
/// interface, but it is not a valid connection *destination* on every platform
/// (notably Windows). It is therefore replaced with the loopback address of the
/// same family. IPv6 hosts are wrapped in brackets, as URIs require.
fn tcp_connection_uri(addr: SocketAddr) -> String {
    use std::net::{Ipv4Addr, Ipv6Addr};

    match addr {
        SocketAddr::V4(addr) => {
            let ip = if addr.ip().is_unspecified() {
                Ipv4Addr::LOCALHOST
            } else {
                *addr.ip()
            };
            format!(
                "postgresql://postgres@{ip}:{}/template1?sslmode=disable",
                addr.port()
            )
        }
        SocketAddr::V6(addr) => {
            let ip = if addr.ip().is_unspecified() {
                Ipv6Addr::LOCALHOST
            } else {
                *addr.ip()
            };
            format!(
                "postgresql://postgres@[{ip}]:{}/template1?sslmode=disable",
                addr.port()
            )
        }
    }
}
262
263fn prepare_cached_temporary_root() -> Result<(PathBuf, TempDir)> {
264    run_blocking("pglite-template-cache", || {
265        let (temp_dir, _outcome) = install_temporary_from_template()?;
266        Ok((temp_dir.path().to_path_buf(), temp_dir))
267    })
268}
269
270fn run_blocking<T, F>(name: &'static str, f: F) -> Result<T>
271where
272    T: Send + 'static,
273    F: FnOnce() -> Result<T> + Send + 'static,
274{
275    thread::Builder::new()
276        .name(name.to_string())
277        .spawn(f)
278        .with_context(|| format!("spawn {name} worker"))?
279        .join()
280        .map_err(|_| anyhow!("{name} worker panicked"))?
281}
282
283#[cfg(unix)]
284fn start_unix(
285    proxy: PgliteProxy,
286    path: PathBuf,
287    shutdown: Arc<AtomicBool>,
288) -> Result<(ServerEndpoint, JoinHandle<Result<()>>)> {
289    if path.exists() {
290        std::fs::remove_file(&path)
291            .with_context(|| format!("remove stale socket {}", path.display()))?;
292    }
293    if let Some(parent) = path.parent() {
294        std::fs::create_dir_all(parent)
295            .with_context(|| format!("create socket directory {}", parent.display()))?;
296    }
297
298    let listener = UnixListener::bind(&path)
299        .with_context(|| format!("bind PGlite Unix socket {}", path.display()))?;
300    let endpoint = ServerEndpoint::Unix(path);
301    let (ready_tx, ready_rx) = sync_channel(1);
302    let handle = thread::spawn(move || {
303        proxy.serve_unix_listener_until_ready(listener, shutdown, Some(ready_tx))
304    });
305    wait_until_ready(&ready_rx)?;
306    Ok((endpoint, handle))
307}
308
309fn wait_until_ready(ready_rx: &Receiver<Result<()>>) -> Result<()> {
310    ready_rx
311        .recv()
312        .context("PGlite server thread exited before reporting readiness")?
313}
314
/// Extract the port from a libpq-style socket file name (`.s.PGSQL.<port>`).
///
/// Returns `None` when the path has no file name, the name is not valid UTF-8,
/// lacks the `.s.PGSQL.` prefix, or the suffix is not a valid `u16`.
#[cfg(unix)]
fn parse_unix_socket_port(path: &Path) -> Option<u16> {
    let file_name = path.file_name().and_then(|name| name.to_str())?;
    let digits = file_name.strip_prefix(".s.PGSQL.")?;
    digits.parse::<u16>().ok()
}
320
/// Percent-encode `value` for use inside a URI query parameter.
///
/// ASCII letters, digits, `-`, `.`, `_`, `~`, and `/` pass through unchanged.
/// Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex
/// digits.
#[cfg(unix)]
fn percent_encode_query_value(value: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(value.len());
    for &byte in value.as_bytes() {
        let passes_through =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'/');
        if passes_through {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    out
}
336
#[cfg(all(test, unix))]
mod tests {
    use std::path::Path;

    use super::{parse_unix_socket_port, percent_encode_query_value};

    #[test]
    fn unix_socket_uri_host_is_query_encoded() {
        assert_eq!(
            percent_encode_query_value("/tmp/Application Support/pglite"),
            "/tmp/Application%20Support/pglite"
        );
    }

    /// Characters that would otherwise end or split a query parameter must be
    /// escaped.
    #[test]
    fn reserved_query_characters_are_escaped() {
        assert_eq!(
            percent_encode_query_value("/a&b=c#d?e"),
            "/a%26b%3Dc%23d%3Fe"
        );
    }

    /// Non-ASCII characters are escaped byte-by-byte from their UTF-8 encoding.
    #[test]
    fn non_ascii_is_escaped_per_utf8_byte() {
        assert_eq!(percent_encode_query_value("/tmp/é"), "/tmp/%C3%A9");
    }

    #[test]
    fn unix_socket_port_is_parsed_from_libpq_file_name() {
        assert_eq!(
            parse_unix_socket_port(Path::new("/tmp/.s.PGSQL.5433")),
            Some(5433)
        );
        assert_eq!(parse_unix_socket_port(Path::new("/tmp/pglite.sock")), None);
        assert_eq!(
            parse_unix_socket_port(Path::new("/tmp/.s.PGSQL.notaport")),
            None
        );
        assert_eq!(
            parse_unix_socket_port(Path::new("/tmp/.s.PGSQL.70000")),
            None
        );
    }
}