use std::net::{SocketAddr, TcpListener, TcpStream};
#[cfg(unix)]
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::{Path, PathBuf};
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
mpsc::{Receiver, sync_channel},
};
use std::thread::{self, JoinHandle};
use anyhow::{Context, Result, anyhow};
use tempfile::TempDir;
use crate::pglite::base::{PreparedRoot, RootLock, RootPlan, RootSource, RootTarget, prepare_root};
use crate::pglite::config::{PostgresConfig, StartupConfig};
#[cfg(feature = "extensions")]
use crate::pglite::extensions::{Extension, resolve_extension_set};
use crate::pglite::interface::DebugLevel;
#[cfg(feature = "extensions")]
use crate::pglite::pg_dump::{PgDumpOptions, dump_server_sql};
use crate::pglite::proxy::PgliteProxy;
use crate::pglite::timing;
/// A running PGlite server: a background thread serving a `PgliteProxy` on a TCP or
/// Unix-socket listener.
///
/// Dropping the value stops the server (see the `Drop` impl). Call [`PgliteServer::shutdown`]
/// to stop it explicitly and observe errors returned by the server thread.
///
/// Field order matters. Fields are dropped in declaration order after `Drop::drop` has joined
/// the thread, so `_temp_dir` is released before `_root_lock`.
#[derive(Debug)]
pub struct PgliteServer {
    /// Data directory served by this instance.
    root: PathBuf,
    /// Held only for its `Drop`: keeps a temporary root (if one was created) alive while the
    /// server runs.
    _temp_dir: Option<TempDir>,
    /// Held only for its `Drop`. Presumably guards the root against concurrent use —
    /// confirm in `crate::pglite::base`.
    _root_lock: Option<RootLock>,
    /// Endpoint actually bound (for TCP, the port is resolved if `0` was requested).
    endpoint: ServerEndpoint,
    /// Startup settings; `username` and `database` are used to build connection URIs.
    startup_config: StartupConfig,
    /// Shared with the server thread; set by `stop` to request that the thread exit.
    shutdown: Arc<AtomicBool>,
    /// Server thread handle; taken (left as `None`) by `stop` before joining.
    handle: Option<JoinHandle<Result<()>>>,
}
/// Endpoint a started server is listening on.
#[derive(Debug, Clone)]
enum ServerEndpoint {
    /// Bound TCP address, as reported by `TcpListener::local_addr` after binding.
    Tcp(SocketAddr),
    /// Path of the Unix socket file the listener was bound to.
    #[cfg(unix)]
    Unix(PathBuf),
}
impl PgliteServer {
    /// Returns a [`PgliteServerBuilder`] with default settings.
    pub fn builder() -> PgliteServerBuilder {
        PgliteServerBuilder::new()
    }

    /// Starts a server on a temporary root (template cache enabled) with the builder's default
    /// endpoint (TCP on `127.0.0.1` with an OS-assigned port).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`PgliteServerBuilder::start`].
    pub fn temporary_tcp() -> Result<Self> {
        Self::builder().temporary().start()
    }

    /// Path of the data directory this server serves.
    pub fn root(&self) -> &Path {
        &self.root
    }

    /// Bound TCP address (with the real port if `0` was requested), or `None` for a
    /// Unix-socket server.
    pub fn tcp_addr(&self) -> Option<SocketAddr> {
        match self.endpoint {
            ServerEndpoint::Tcp(addr) => Some(addr),
            #[cfg(unix)]
            ServerEndpoint::Unix(_) => None,
        }
    }

    /// Unix socket path, or `None` for a TCP server.
    #[cfg(unix)]
    pub fn socket_path(&self) -> Option<&Path> {
        match &self.endpoint {
            ServerEndpoint::Tcp(_) => None,
            ServerEndpoint::Unix(path) => Some(path),
        }
    }

    /// Builds a `postgresql://` connection URI for this server.
    ///
    /// For Unix sockets, the socket's directory is passed as the `host` query parameter.
    /// The port is parsed from a `.s.PGSQL.<port>` file name, falling back to 5432.
    /// A relative socket path is resolved against the current working directory at the
    /// time of the call.
    pub fn connection_uri(&self) -> String {
        match &self.endpoint {
            ServerEndpoint::Tcp(addr) => tcp_connection_uri(*addr, &self.startup_config),
            #[cfg(unix)]
            ServerEndpoint::Unix(path) => {
                // Clients only treat `host` as a socket directory when it is an absolute path.
                // A relative value would be interpreted as a TCP host name, and a bare file name
                // would yield an empty host. Hence the relative path is anchored to the cwd.
                let socket = if path.is_relative() {
                    std::env::current_dir()
                        .map(|cwd| cwd.join(path))
                        .unwrap_or_else(|_| path.clone())
                } else {
                    path.clone()
                };
                let host = socket.parent().unwrap_or_else(|| Path::new("/tmp"));
                let port = parse_unix_socket_port(path).unwrap_or(5432);
                format!(
                    "postgresql://{}@/{}?host={}&port={}&sslmode=disable",
                    self.startup_config.username,
                    self.startup_config.database,
                    percent_encode_query_value(&host.display().to_string()),
                    port
                )
            }
        }
    }

    /// Alias for [`PgliteServer::connection_uri`].
    pub fn database_url(&self) -> String {
        self.connection_uri()
    }

    /// Dumps the served database as SQL via `dump_server_sql`.
    ///
    /// # Errors
    ///
    /// Fails for Unix-socket servers (only TCP endpoints are supported) or if the dump fails.
    #[cfg(feature = "extensions")]
    pub fn dump_sql(&self, options: PgDumpOptions) -> Result<String> {
        let addr = self
            .tcp_addr()
            .context("pg_dump currently requires a TCP PgliteServer endpoint")?;
        dump_server_sql(addr, &options)
    }

    /// Same as [`PgliteServer::dump_sql`], returning the SQL text as bytes.
    #[cfg(feature = "extensions")]
    pub fn dump_bytes(&self, options: PgDumpOptions) -> Result<Vec<u8>> {
        Ok(self.dump_sql(options)?.into_bytes())
    }

    /// Stops the server and joins its thread.
    ///
    /// # Errors
    ///
    /// Returns the error the server thread finished with, or an error if it panicked.
    pub fn shutdown(mut self) -> Result<()> {
        self.stop()
    }

    /// Signals the server thread to exit, wakes its listener, and joins it.
    ///
    /// Idempotent: once the thread has been joined (e.g. `shutdown()` followed by `Drop`),
    /// later calls return `Ok(())` without touching the network. Waking again would open a
    /// connection to an address this server no longer owns.
    fn stop(&mut self) -> Result<()> {
        let handle = match self.handle.take() {
            Some(handle) => handle,
            None => return Ok(()),
        };
        self.shutdown.store(true, Ordering::SeqCst);
        {
            let _phase = timing::phase("server.shutdown_wake");
            wake_listener(&self.endpoint);
        }
        let _phase = timing::phase("server.thread_join");
        handle
            .join()
            .map_err(|_| anyhow!("pglite server thread panicked"))?
    }
}
impl Drop for PgliteServer {
    /// Best-effort shutdown for servers that were not stopped explicitly.
    ///
    /// `drop` cannot return an error, so a failed shutdown is logged as a warning instead.
    fn drop(&mut self) {
        match self.stop() {
            Ok(()) => {}
            Err(err) => tracing::warn!("pglite server shutdown during drop failed: {err:#}"),
        }
    }
}
/// Configuration for a [`PgliteServer`]; call [`PgliteServerBuilder::start`] to launch it.
///
/// Obtain one via [`PgliteServer::builder`] or [`PgliteServerBuilder::new`].
#[derive(Debug, Clone)]
pub struct PgliteServerBuilder {
    /// Where the data directory comes from (temporary directory or caller-supplied path).
    root: ServerRoot,
    /// Endpoint to listen on.
    endpoint: ServerEndpointConfig,
    /// Postgres settings; validated by `start`.
    postgres_config: PostgresConfig,
    /// Startup settings (user, database, debug level, extra args); validated by `start`.
    startup_config: StartupConfig,
    /// Extensions requested by the caller; resolved by `start`.
    #[cfg(feature = "extensions")]
    extensions: Vec<Extension>,
}
/// Source of the server's data directory.
#[derive(Debug, Clone)]
enum ServerRoot {
    /// A temporary directory. `template_cache: true` prepares it from `RootSource::Template`.
    /// `false` prepares it with `RootSource::FreshInitdb`.
    Temporary { template_cache: bool },
    /// A caller-supplied directory, prepared from `RootSource::Template`.
    Path(PathBuf),
}
/// Endpoint requested through the builder.
///
/// Unlike `ServerEndpoint`, a TCP address here may still carry port `0`. The actual port is
/// only known after the listener has been bound in `start`.
#[derive(Debug, Clone)]
enum ServerEndpointConfig {
    /// TCP address to bind.
    Tcp(SocketAddr),
    /// Unix socket path to bind.
    #[cfg(unix)]
    Unix(PathBuf),
}
impl Default for PgliteServerBuilder {
    /// Default builder settings:
    ///
    /// - root: temporary directory with the template cache enabled;
    /// - endpoint: TCP on `127.0.0.1`, port `0` (the OS picks the port at bind time);
    /// - Postgres and startup configuration: their own defaults;
    /// - extensions (when that feature is enabled): none.
    fn default() -> Self {
        let loopback_any_port = SocketAddr::from((std::net::Ipv4Addr::LOCALHOST, 0));
        Self {
            root: ServerRoot::Temporary {
                template_cache: true,
            },
            endpoint: ServerEndpointConfig::Tcp(loopback_any_port),
            postgres_config: Default::default(),
            startup_config: Default::default(),
            #[cfg(feature = "extensions")]
            extensions: Vec::default(),
        }
    }
}
impl PgliteServerBuilder {
    /// Creates a builder with default settings: a temporary root with the template cache
    /// enabled, and a TCP listener on `127.0.0.1` with port `0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Serves the data directory at `root` instead of a temporary one.
    ///
    /// The directory is prepared from `RootSource::Template` when the server starts.
    pub fn path(mut self, root: impl Into<PathBuf>) -> Self {
        self.root = ServerRoot::Path(root.into());
        self
    }

    /// Uses a temporary root prepared from `RootSource::Template` (`template_cache: true`).
    /// This is the default.
    pub fn temporary(mut self) -> Self {
        self.root = ServerRoot::Temporary {
            template_cache: true,
        };
        self
    }

    /// Uses a temporary root prepared with `RootSource::FreshInitdb` (`template_cache: false`).
    pub fn fresh_temporary(mut self) -> Self {
        self.root = ServerRoot::Temporary {
            template_cache: false,
        };
        self
    }

    /// Listens on the given TCP address.
    ///
    /// With port `0`, the port actually bound is available afterwards via
    /// `PgliteServer::tcp_addr`.
    pub fn tcp(mut self, addr: SocketAddr) -> Self {
        self.endpoint = ServerEndpointConfig::Tcp(addr);
        self
    }

    /// Listens on a Unix-domain socket at `path` instead of TCP.
    ///
    /// `PgliteServer::connection_uri` derives the port from a `.s.PGSQL.<port>` file name,
    /// falling back to 5432 for other names.
    #[cfg(unix)]
    pub fn unix(mut self, path: impl Into<PathBuf>) -> Self {
        self.endpoint = ServerEndpointConfig::Unix(path.into());
        self
    }

    /// Adds one Postgres configuration setting; all settings are validated by `start`.
    pub fn postgres_config(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.postgres_config.insert(name, value);
        self
    }

    /// Adds several Postgres configuration settings, in iteration order.
    pub fn postgres_configs<K, V>(mut self, settings: impl IntoIterator<Item = (K, V)>) -> Self
    where
        K: Into<String>,
        V: Into<String>,
    {
        for (name, value) in settings {
            self.postgres_config.insert(name, value);
        }
        self
    }

    /// Sets the user name passed to the server and used in connection URIs.
    pub fn username(mut self, username: impl Into<String>) -> Self {
        self.startup_config.username = username.into();
        self
    }

    /// Sets the database name passed to the server and used in connection URIs.
    pub fn database(mut self, database: impl Into<String>) -> Self {
        self.startup_config.database = database.into();
        self
    }

    /// Sets the startup debug level.
    pub fn debug_level(mut self, level: DebugLevel) -> Self {
        self.startup_config.debug_level = Some(level);
        self
    }

    /// Sets `StartupConfig::relaxed_durability`. See that type for the exact trade-off.
    pub fn relaxed_durability(mut self, enabled: bool) -> Self {
        self.startup_config.relaxed_durability = enabled;
        self
    }

    /// Appends one extra startup argument.
    pub fn startup_arg(mut self, arg: impl Into<String>) -> Self {
        self.startup_config.extra_args.push(arg.into());
        self
    }

    /// Appends extra startup arguments, in iteration order.
    pub fn startup_args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.startup_config
            .extra_args
            .extend(args.into_iter().map(Into::into));
        self
    }

    /// Requests one extension; the requested set is resolved by `start`.
    #[cfg(feature = "extensions")]
    pub fn extension(mut self, extension: Extension) -> Self {
        self.extensions.push(extension);
        self
    }

    /// Requests several extensions; the requested set is resolved by `start`.
    #[cfg(feature = "extensions")]
    pub fn extensions(mut self, extensions: impl IntoIterator<Item = Extension>) -> Self {
        self.extensions.extend(extensions);
        self
    }

    /// Validates the configuration, prepares the data directory, and launches the server
    /// thread on the configured endpoint.
    ///
    /// Returns once the server thread has reported readiness.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the Postgres or startup configuration is invalid;
    /// - with the `extensions` feature, the extension set cannot be resolved;
    /// - the root cannot be prepared;
    /// - the listener cannot be bound;
    /// - the server does not become ready.
    pub fn start(self) -> Result<PgliteServer> {
        self.postgres_config.validate()?;
        self.startup_config.validate()?;
        #[cfg(feature = "extensions")]
        let extensions = resolve_extension_set(&self.extensions)?;
        let postgres_config = self.postgres_config.clone();
        let startup_config = self.startup_config.clone();
        let prepared_root = {
            let _phase = timing::phase("server.root_prepare");
            match self.root {
                ServerRoot::Path(root) => {
                    let _phase = timing::phase("server.root_prepare.path");
                    // Caller-supplied roots always start from the template and are prepared on
                    // the calling thread.
                    let plan = RootPlan::new(RootTarget::Path(root), RootSource::Template);
                    #[cfg(feature = "extensions")]
                    let plan = plan.with_extensions(extensions.clone(), postgres_config.clone());
                    prepare_root(plan)?
                }
                ServerRoot::Temporary { template_cache } => {
                    let source = if template_cache {
                        RootSource::Template
                    } else {
                        RootSource::FreshInitdb
                    };
                    // Distinct phase names so cached and fresh preparation are timed separately.
                    let phase = if template_cache {
                        "server.root_prepare.temporary_cached"
                    } else {
                        "server.root_prepare.temporary_fresh"
                    };
                    let _phase = timing::phase(phase);
                    let plan = RootPlan::new(RootTarget::Temporary, source);
                    #[cfg(feature = "extensions")]
                    let plan = plan.with_extensions(extensions.clone(), postgres_config.clone());
                    // Temporary roots are prepared on a dedicated "pglite-template-cache"
                    // worker thread. NOTE(review): why this differs from the Path branch is not
                    // visible here. Confirm before changing it.
                    run_blocking("pglite-template-cache", move || prepare_root(plan))?
                }
            }
        };
        // `temp_dir` and `root_lock` move into the returned server so they live as long as it.
        let PreparedRoot {
            root,
            temp_dir,
            root_lock,
            outcome,
        } = prepared_root;
        // Shared with the server thread; `PgliteServer::stop` sets it.
        let shutdown = Arc::new(AtomicBool::new(false));
        let proxy = {
            let _phase = timing::phase("server.proxy_create");
            PgliteProxy::new(root.clone()).with_prepared_root(outcome)
        };
        let proxy = proxy
            .with_postgres_config(postgres_config)
            .with_startup_config(startup_config.clone());
        #[cfg(feature = "extensions")]
        let proxy = proxy.with_extensions(extensions);
        let (endpoint, handle) = match self.endpoint {
            ServerEndpointConfig::Tcp(addr) => start_tcp(proxy, addr, shutdown.clone())?,
            #[cfg(unix)]
            ServerEndpointConfig::Unix(path) => start_unix(proxy, path, shutdown.clone())?,
        };
        Ok(PgliteServer {
            root,
            _temp_dir: temp_dir,
            _root_lock: root_lock,
            endpoint,
            startup_config,
            shutdown,
            handle: Some(handle),
        })
    }
}
/// Binds a TCP listener at `addr` and serves `proxy` on it from a new background thread.
///
/// Returns once the thread reports readiness through a one-slot channel. The returned endpoint
/// holds the address reported by `local_addr`, so a requested port `0` is replaced by the port
/// actually bound.
///
/// # Errors
///
/// Fails in any of these cases:
/// - binding fails, or reading the local address fails;
/// - the thread reports a startup error;
/// - the thread exits before reporting readiness.
fn start_tcp(
    proxy: PgliteProxy,
    addr: SocketAddr,
    shutdown: Arc<AtomicBool>,
) -> Result<(ServerEndpoint, JoinHandle<Result<()>>)> {
    let listener = {
        let _phase = timing::phase("server.tcp_bind");
        TcpListener::bind(addr).context("bind PGlite TCP server")?
    };
    // Re-read the address so an ephemeral-port request (`:0`) resolves to the real port.
    let addr = {
        let _phase = timing::phase("server.tcp_local_addr");
        listener.local_addr().context("read PGlite TCP address")?
    };
    let (ready_tx, ready_rx) = sync_channel(1);
    // Hand the caller's timing recorder to the server thread via `timing::with_recorder`.
    let recorder = timing::current_recorder();
    let handle = {
        let _phase = timing::phase("server.thread_spawn");
        thread::spawn(move || {
            timing::with_recorder(recorder, || {
                proxy.serve_tcp_listener_until_ready(listener, shutdown, Some(ready_tx))
            })
        })
    };
    {
        let _phase = timing::phase("server.wait_ready");
        // NOTE(review): on error, `handle` is dropped here. That detaches the thread without
        // joining it, so anything it returns later is not observed.
        wait_until_ready(&ready_rx)?;
    }
    Ok((ServerEndpoint::Tcp(addr), handle))
}
/// Builds a `postgresql://` URI that a client on this host can use to reach a TCP listener.
///
/// IPv6 hosts are bracketed as URIs require.
///
/// If the listener is bound to a wildcard address (`0.0.0.0` or `::`), the URI names the
/// matching loopback address instead. The wildcard is valid for binding but is not a portable
/// connect target (Windows, for example, refuses to connect to it).
///
/// NOTE(review): `username` and `database` are inserted verbatim. Characters such as `@`, `/`
/// or `?` in them would produce an unparsable URI unless `StartupConfig::validate` rejects
/// them — confirm.
fn tcp_connection_uri(addr: SocketAddr, startup: &StartupConfig) -> String {
    let host = match addr {
        SocketAddr::V4(v4) if v4.ip().is_unspecified() => {
            std::net::Ipv4Addr::LOCALHOST.to_string()
        }
        SocketAddr::V4(v4) => v4.ip().to_string(),
        SocketAddr::V6(v6) if v6.ip().is_unspecified() => {
            format!("[{}]", std::net::Ipv6Addr::LOCALHOST)
        }
        SocketAddr::V6(v6) => format!("[{}]", v6.ip()),
    };
    format!(
        "postgresql://{}@{}:{}/{}?sslmode=disable",
        startup.username,
        host,
        addr.port(),
        startup.database
    )
}
/// Runs `f` to completion on a freshly spawned thread named `name` and returns its result.
///
/// The caller's timing recorder is handed to the worker via `timing::with_recorder`.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the thread cannot be spawned;
/// - the worker panics;
/// - `f` itself returns an error (propagated unchanged).
fn run_blocking<T, F>(name: &'static str, f: F) -> Result<T>
where
    T: Send + 'static,
    F: FnOnce() -> Result<T> + Send + 'static,
{
    let recorder = timing::current_recorder();
    let worker = thread::Builder::new()
        .name(name.to_owned())
        .spawn(move || timing::with_recorder(recorder, f))
        .with_context(|| format!("spawn {name} worker"))?;
    match worker.join() {
        Ok(outcome) => outcome,
        Err(_) => Err(anyhow!("{name} worker panicked")),
    }
}
#[cfg(unix)]
fn start_unix(
proxy: PgliteProxy,
path: PathBuf,
shutdown: Arc<AtomicBool>,
) -> Result<(ServerEndpoint, JoinHandle<Result<()>>)> {
{
let _phase = timing::phase("server.unix_prepare_path");
if path.exists() {
std::fs::remove_file(&path)
.with_context(|| format!("remove stale socket {}", path.display()))?;
}
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("create socket directory {}", parent.display()))?;
}
}
let listener = {
let _phase = timing::phase("server.unix_bind");
UnixListener::bind(&path)
.with_context(|| format!("bind PGlite Unix socket {}", path.display()))?
};
let endpoint = ServerEndpoint::Unix(path);
let (ready_tx, ready_rx) = sync_channel(1);
let recorder = timing::current_recorder();
let handle = {
let _phase = timing::phase("server.thread_spawn");
thread::spawn(move || {
timing::with_recorder(recorder, || {
proxy.serve_unix_listener_until_ready(listener, shutdown, Some(ready_tx))
})
})
};
{
let _phase = timing::phase("server.wait_ready");
wait_until_ready(&ready_rx)?;
}
Ok((endpoint, handle))
}
/// Blocks until the server thread sends its readiness status, then returns that status.
///
/// # Errors
///
/// Returns the error the thread reported. If the sending side was dropped without reporting,
/// returns a contextual error saying the thread exited before becoming ready.
fn wait_until_ready(ready_rx: &Receiver<Result<()>>) -> Result<()> {
    ready_rx
        .recv()
        .context("PGlite server thread exited before reporting readiness")
        .and_then(|status| status)
}
fn wake_listener(endpoint: &ServerEndpoint) {
match endpoint {
ServerEndpoint::Tcp(addr) => {
let _ = TcpStream::connect(addr);
}
#[cfg(unix)]
ServerEndpoint::Unix(path) => {
let _ = UnixStream::connect(path);
}
}
}
/// Extracts the port from a socket file name of the form `.s.PGSQL.<port>`.
///
/// Returns `None` in any of these cases:
/// - the path has no file name, or the name is not valid UTF-8;
/// - the prefix is missing;
/// - the suffix is not a valid `u16`.
#[cfg(unix)]
fn parse_unix_socket_port(path: &Path) -> Option<u16> {
    path.file_name()
        .and_then(|file_name| file_name.to_str())
        .and_then(|file_name| file_name.strip_prefix(".s.PGSQL."))
        .and_then(|digits| digits.parse().ok())
}
/// Percent-encodes `value` for use as a URI query-parameter value.
///
/// ASCII letters, digits, `-`, `.`, `_`, `~` and `/` are kept as-is. Every other byte of the
/// UTF-8 encoding becomes `%XX` with uppercase hex digits.
#[cfg(unix)]
fn percent_encode_query_value(value: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let kept_literal =
        |b: u8| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~' | b'/');

    value
        .bytes()
        .fold(String::with_capacity(value.len()), |mut out, b| {
            if kept_literal(b) {
                out.push(char::from(b));
            } else {
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(b >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(b & 0x0F)]));
            }
            out
        })
}
#[cfg(all(test, unix))]
mod tests {
    use super::percent_encode_query_value;

    /// A socket directory containing a space must appear percent-encoded in the URI's
    /// `host` query parameter, while path separators stay readable.
    #[test]
    fn unix_socket_uri_host_is_query_encoded() {
        let socket_dir = "/tmp/Application Support/pglite";
        let encoded = percent_encode_query_value(socket_dir);
        assert_eq!(encoded, "/tmp/Application%20Support/pglite");
    }
}