Skip to main content

pglite_oxide/pglite/
pg_dump.rs

1use std::fmt;
2use std::io::{Read, Seek, Write};
3use std::mem::MaybeUninit;
4use std::net::Shutdown;
5use std::net::{IpAddr, Ipv4Addr, SocketAddr};
6use std::pin::Pin;
7use std::sync::mpsc::{self, Receiver, SyncSender};
8use std::sync::{Arc, Mutex};
9use std::task::{Context as TaskContext, Poll};
10use std::thread;
11use std::time::{Duration, Instant};
12
13use anyhow::{Context, Result, anyhow, bail};
14use tempfile::TempDir;
15use wasmer::Store;
16use wasmer_types::ModuleHash;
17use wasmer_wasix::runners::wasi::{RuntimeOrEngine, WasiRunner};
18use wasmer_wasix::runtime::task_manager::tokio::TokioTaskManager;
19use wasmer_wasix::virtual_fs::{self, AsyncRead, AsyncSeek, AsyncWrite};
20use wasmer_wasix::virtual_net::tcp_pair::TcpSocketHalf;
21use wasmer_wasix::virtual_net::{
22    self, InterestHandler, NetworkError, SocketStatus, VirtualConnectedSocket, VirtualIoSource,
23    VirtualNetworking, VirtualSocket, VirtualTcpSocket,
24};
25use wasmer_wasix::{LocalNetworking, PluggableRuntime, VirtualFile};
26
27use crate::pglite::sync_host_fs::SyncHostFileSystem;
28use crate::pglite::timing;
29use crate::pglite::{aot, assets};
30
/// Options for the bundled WASIX `pg_dump` runner.
///
/// By default the `template1` database is dumped as user `postgres` with no
/// extra arguments. Flags that pglite-oxide manages itself (output file,
/// format, host, port, username, database, job count) are rejected when the
/// options are validated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PgDumpOptions {
    /// Raw pass-through arguments; they precede the managed arguments on the
    /// `pg_dump` command line.
    args: Vec<String>,
    /// Database name, passed as `pg_dump`'s trailing positional argument.
    database: String,
    /// User name, passed as `-U` and via `PGUSER`.
    username: String,
}

impl Default for PgDumpOptions {
    fn default() -> Self {
        let database = String::from("template1");
        let username = String::from("postgres");
        Self {
            args: Vec::default(),
            database,
            username,
        }
    }
}
48
49impl PgDumpOptions {
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Add one raw `pg_dump` argument.
55    pub fn arg(mut self, arg: impl Into<String>) -> Self {
56        self.args.push(arg.into());
57        self
58    }
59
60    /// Add raw `pg_dump` arguments.
61    pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
62        self.args.extend(args.into_iter().map(Into::into));
63        self
64    }
65
66    /// Select the database to dump.
67    pub fn database(mut self, database: impl Into<String>) -> Self {
68        self.database = database.into();
69        self
70    }
71
72    /// Select the user passed to `pg_dump`.
73    pub fn username(mut self, username: impl Into<String>) -> Self {
74        self.username = username.into();
75        self
76    }
77
78    pub(crate) fn validate(&self) -> Result<()> {
79        for (name, value) in [("database", &self.database), ("username", &self.username)] {
80            anyhow::ensure!(
81                !value.is_empty() && !value.contains('\0'),
82                "pg_dump {name} must not be empty or contain NUL bytes"
83            );
84        }
85        for arg in &self.args {
86            anyhow::ensure!(
87                !arg.contains('\0'),
88                "pg_dump argument must not contain NUL bytes"
89            );
90            validate_passthrough_arg(arg)?;
91        }
92        Ok(())
93    }
94
95    pub(crate) fn database_ref(&self) -> &str {
96        &self.database
97    }
98
99    pub(crate) fn username_ref(&self) -> &str {
100        &self.username
101    }
102}
103
104fn validate_passthrough_arg(arg: &str) -> Result<()> {
105    if let Some(flag) = disallowed_pg_dump_flag(arg) {
106        anyhow::bail!(
107            "pg_dump argument '{arg}' conflicts with pglite-oxide's managed {flag}; use PgDumpOptions typed setters where available"
108        );
109    }
110    Ok(())
111}
112
/// Returns the label of the pglite-oxide-managed setting that `arg` would
/// override, or `None` when `arg` is safe to pass through to `pg_dump`.
///
/// Long options are matched in three forms:
/// - exactly (`--file`),
/// - with an attached value (`--file=out.sql`),
/// - abbreviated (`--fi=out.sql`, `--user=x`).
///
/// Abbreviations matter because `getopt_long` implementations such as glibc's
/// accept any unambiguous prefix of a long option name. A prefix that is
/// ambiguous between a managed and an unmanaged option is rejected too;
/// `pg_dump` would refuse it anyway.
///
/// A short-option cluster (`-vf out.sql`, `-sh host`) is scanned character by
/// character. Scanning stops at the first option that takes a value, because
/// the rest of the token is that option's argument (`-thost` is a table
/// pattern, not `-h`).
fn disallowed_pg_dump_flag(arg: &str) -> Option<&'static str> {
    const LONG_FLAGS: &[(&str, &str)] = &[
        ("file", "output file"),
        ("format", "output format"),
        ("host", "host"),
        ("port", "port"),
        ("username", "username"),
        ("dbname", "database"),
        ("jobs", "job count"),
    ];
    const SHORT_FLAGS: &[(char, &str)] = &[
        ('f', "output file"),
        ('F', "output format"),
        ('h', "host"),
        ('p', "port"),
        ('U', "username"),
        ('d', "database"),
        ('j', "job count"),
    ];
    // Short options that take a value, taken from pg_dump's getopt string
    // ("abBcCd:e:E:f:F:h:j:n:N:Op:RsS:t:T:U:vwWxZ:"). Unknown letters are
    // treated as value-less flags, so scanning errs on the side of rejecting.
    const SHORT_FLAGS_WITH_VALUE: &str = "deEfFhjnNpStTUZ";

    if let Some(long) = arg.strip_prefix("--") {
        let name = long.split_once('=').map_or(long, |(name, _)| name);
        // A bare `--` ends option parsing; it is not itself a managed flag.
        if name.is_empty() {
            return None;
        }
        return LONG_FLAGS
            .iter()
            .find(|(flag, _)| flag.starts_with(name))
            .map(|&(_, label)| label);
    }

    let cluster = arg.strip_prefix('-')?;
    for option in cluster.chars() {
        if let Some(&(_, label)) = SHORT_FLAGS.iter().find(|&&(flag, _)| flag == option) {
            return Some(label);
        }
        if SHORT_FLAGS_WITH_VALUE.contains(option) {
            break;
        }
    }
    None
}
149
150pub(crate) fn dump_server_sql(addr: SocketAddr, options: &PgDumpOptions) -> Result<String> {
151    dump_sql_with_networking(addr, options, LocalNetworking::new())
152}
153
154pub(crate) type PgDumpVirtualSocket = TcpSocketHalf;
155
156pub(crate) fn dump_direct_sql<F>(options: &PgDumpOptions, serve: F) -> Result<String>
157where
158    F: FnOnce(PgDumpVirtualSocket) -> Result<()>,
159{
160    options.validate()?;
161    let (socket_tx, socket_rx) = mpsc::sync_channel(1);
162    let networking = DirectPgDumpNetworking::new(socket_tx);
163    let runner_options = options.clone();
164    let runner = thread::spawn(move || {
165        dump_sql_with_networking(DIRECT_PG_DUMP_ADDR, &runner_options, networking)
166    });
167
168    let accepted = receive_direct_pg_dump_socket(&socket_rx, &runner)
169        .context("accept direct pg_dump virtual protocol connection");
170    let serve_result = match accepted {
171        Ok(socket) => serve(socket),
172        Err(err) => Err(err),
173    };
174    let dump_result = runner
175        .join()
176        .map_err(|_| anyhow!("direct pg_dump runner thread panicked"))?;
177
178    match (serve_result, dump_result) {
179        (Ok(()), Ok(sql)) => Ok(sql),
180        (Err(err), Ok(_)) => Err(err),
181        (Ok(()), Err(err)) => Err(err),
182        (Err(err), Err(dump_err)) => {
183            Err(err.context(format!("direct pg_dump runner also failed: {dump_err:#}")))
184        }
185    }
186}
187
188fn dump_sql_with_networking<N>(
189    addr: SocketAddr,
190    options: &PgDumpOptions,
191    networking: N,
192) -> Result<String>
193where
194    N: VirtualNetworking + Sync,
195{
196    options.validate()?;
197    let _phase = timing::phase("pg_dump");
198    let wasm = {
199        let _phase = timing::phase("pg_dump.load_embedded_module");
200        assets::pg_dump_wasm()
201            .ok_or_else(|| anyhow!("WASIX pg_dump asset is not bundled in this build"))?
202    };
203    let engine = aot::headless_engine();
204    let module = {
205        let _phase = timing::phase("pg_dump.load_aot");
206        aot::load_pg_dump_module(&engine)?
207    };
208    let _store = Store::new(engine.clone());
209
210    let fs_root = TempDir::new().context("create pg_dump WASIX filesystem root")?;
211    let runtime = {
212        let _phase = timing::phase("pg_dump.tokio_runtime");
213        tokio::runtime::Builder::new_multi_thread()
214            .enable_all()
215            .build()
216            .context("create Tokio runtime for WASIX pg_dump")?
217    };
218    let (host_fs, wasix_runtime) = {
219        let _phase = timing::phase("pg_dump.wasix_runtime");
220        let _runtime_guard = runtime.enter();
221        let host_fs = SyncHostFileSystem::new(fs_root.path()).with_context(|| {
222            format!(
223                "create host filesystem rooted at {}",
224                fs_root.path().display()
225            )
226        })?;
227        let host_fs = Arc::new(host_fs) as Arc<dyn virtual_fs::FileSystem + Send + Sync>;
228        let mut wasix_runtime = PluggableRuntime::new(Arc::new(TokioTaskManager::new(
229            tokio::runtime::Handle::current(),
230        )));
231        wasix_runtime.set_engine(engine.clone());
232        wasix_runtime.set_networking_implementation(networking);
233        (host_fs, wasix_runtime)
234    };
235
236    let output_path = "/host/out.sql";
237    let port = addr.port().to_string();
238    let host = match addr {
239        SocketAddr::V4(addr) => addr.ip().to_string(),
240        SocketAddr::V6(addr) => addr.ip().to_string(),
241    };
242    let mut args = options.args.clone();
243    args.extend([
244        "-U".to_owned(),
245        options.username.clone(),
246        "-h".to_owned(),
247        host,
248        "-p".to_owned(),
249        port,
250        "--inserts".to_owned(),
251        "-j".to_owned(),
252        "1".to_owned(),
253        "-f".to_owned(),
254        output_path.to_owned(),
255    ]);
256    args.push(options.database.clone());
257
258    let stdout = Arc::new(Mutex::new(Vec::new()));
259    let stderr = Arc::new(Mutex::new(Vec::new()));
260    let mut runner = WasiRunner::new();
261    runner
262        .with_mount("/host".to_owned(), host_fs)
263        .with_current_dir("/")
264        .with_args(args)
265        .with_envs([
266            ("PGUSER", options.username.as_str()),
267            ("PGPASSWORD", "password"),
268            ("PGSSLMODE", "disable"),
269        ])
270        .with_stdout(Box::new(CaptureFile::new(Arc::clone(&stdout))))
271        .with_stderr(Box::new(CaptureFile::new(Arc::clone(&stderr))));
272    {
273        let _phase = timing::phase("pg_dump.run_wasm");
274        runner
275            .run_wasm(
276                RuntimeOrEngine::Runtime(Arc::new(wasix_runtime)),
277                "pg_dump",
278                module,
279                ModuleHash::sha256(wasm),
280            )
281            .map_err(|err| {
282                let stderr =
283                    String::from_utf8_lossy(&stderr.lock().expect("stderr capture poisoned"))
284                        .trim()
285                        .to_owned();
286                if stderr.is_empty() {
287                    anyhow!(err)
288                } else {
289                    anyhow!("{err}; pg_dump stderr: {stderr}")
290                }
291            })
292            .context("run WASIX pg_dump")?;
293    }
294
295    {
296        let _phase = timing::phase("pg_dump.read_output");
297        match std::fs::read_to_string(fs_root.path().join("out.sql")) {
298            Ok(sql) => Ok(sql),
299            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
300                let stdout = stdout.lock().expect("stdout capture poisoned");
301                if stdout.is_empty() {
302                    Err(err).with_context(|| {
303                        format!(
304                            "read pg_dump output {}",
305                            fs_root.path().join("out.sql").display()
306                        )
307                    })
308                } else {
309                    String::from_utf8(stdout.clone()).context("decode pg_dump stdout as UTF-8")
310                }
311            }
312            Err(err) => Err(err).with_context(|| {
313                format!(
314                    "read pg_dump output {}",
315                    fs_root.path().join("out.sql").display()
316                )
317            }),
318        }
319    }
320}
321
322const DIRECT_PG_DUMP_PORT: u16 = 65_432;
323const DIRECT_PG_DUMP_SOCKET_BUFFER: usize = 8 * 1024 * 1024;
324const DIRECT_PG_DUMP_LOCAL_PORT: u16 = 65_431;
325const DIRECT_PG_DUMP_ADDR: SocketAddr =
326    SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), DIRECT_PG_DUMP_PORT);
327const DIRECT_PG_DUMP_LOCAL_ADDR: SocketAddr =
328    SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), DIRECT_PG_DUMP_LOCAL_PORT);
329
330struct DirectPgDumpNetworking {
331    socket_tx: Mutex<Option<SyncSender<PgDumpVirtualSocket>>>,
332}
333
334impl DirectPgDumpNetworking {
335    fn new(socket_tx: SyncSender<PgDumpVirtualSocket>) -> Self {
336        Self {
337            socket_tx: Mutex::new(Some(socket_tx)),
338        }
339    }
340}
341
342impl fmt::Debug for DirectPgDumpNetworking {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        f.debug_struct("DirectPgDumpNetworking")
345            .finish_non_exhaustive()
346    }
347}
348
349#[async_trait::async_trait]
350impl VirtualNetworking for DirectPgDumpNetworking {
351    async fn connect_tcp(
352        &self,
353        addr: SocketAddr,
354        peer: SocketAddr,
355    ) -> virtual_net::Result<Box<dyn VirtualTcpSocket + Sync>> {
356        if peer != DIRECT_PG_DUMP_ADDR {
357            return Err(NetworkError::ConnectionRefused);
358        }
359
360        let sender = self
361            .socket_tx
362            .lock()
363            .map_err(|_| NetworkError::IOError)?
364            .take()
365            .ok_or(NetworkError::ConnectionRefused)?;
366        let local = if addr.port() == 0 {
367            DIRECT_PG_DUMP_LOCAL_ADDR
368        } else {
369            addr
370        };
371        let (guest, host) = TcpSocketHalf::channel(DIRECT_PG_DUMP_SOCKET_BUFFER, local, peer);
372        sender
373            .send(host)
374            .map_err(|_| NetworkError::ConnectionAborted)?;
375        Ok(Box::new(DirectPgDumpTcpSocket {
376            inner: guest,
377            first_write_ready_probe: true,
378        }))
379    }
380
381    async fn resolve(
382        &self,
383        host: &str,
384        _port: Option<u16>,
385        _dns_server: Option<IpAddr>,
386    ) -> virtual_net::Result<Vec<IpAddr>> {
387        match host {
388            "localhost" | "127.0.0.1" => Ok(vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]),
389            _ => Err(NetworkError::AddressNotAvailable),
390        }
391    }
392}
393
394#[derive(Debug)]
395struct DirectPgDumpTcpSocket {
396    inner: TcpSocketHalf,
397    // WASIX probes writability once while completing a blocking connect.
398    // `TcpSocketHalf` suppresses an immediate second write-ready poll until a
399    // write happens, but libpq polls again before its first StartupMessage.
400    // Keep the adapter level-triggered for that connect-to-first-write handoff.
401    first_write_ready_probe: bool,
402}
403
404impl VirtualIoSource for DirectPgDumpTcpSocket {
405    fn remove_handler(&mut self) {
406        self.inner.remove_handler();
407    }
408
409    fn poll_read_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<virtual_net::Result<usize>> {
410        self.inner.poll_read_ready(cx)
411    }
412
413    fn poll_write_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<virtual_net::Result<usize>> {
414        if self.first_write_ready_probe {
415            self.first_write_ready_probe = false;
416            return Poll::Ready(Ok(self.inner.send_buf_size().unwrap_or(1).max(1)));
417        }
418        self.inner.poll_write_ready(cx)
419    }
420}
421
422impl VirtualSocket for DirectPgDumpTcpSocket {
423    fn set_ttl(&mut self, ttl: u32) -> virtual_net::Result<()> {
424        self.inner.set_ttl(ttl)
425    }
426
427    fn ttl(&self) -> virtual_net::Result<u32> {
428        self.inner.ttl()
429    }
430
431    fn addr_local(&self) -> virtual_net::Result<SocketAddr> {
432        self.inner.addr_local()
433    }
434
435    fn status(&self) -> virtual_net::Result<SocketStatus> {
436        self.inner.status()
437    }
438
439    fn set_handler(
440        &mut self,
441        handler: Box<dyn InterestHandler + Send + Sync>,
442    ) -> virtual_net::Result<()> {
443        self.inner.set_handler(handler)
444    }
445}
446
447impl VirtualConnectedSocket for DirectPgDumpTcpSocket {
448    fn set_linger(&mut self, linger: Option<Duration>) -> virtual_net::Result<()> {
449        self.inner.set_linger(linger)
450    }
451
452    fn linger(&self) -> virtual_net::Result<Option<Duration>> {
453        self.inner.linger()
454    }
455
456    fn try_send(&mut self, data: &[u8]) -> virtual_net::Result<usize> {
457        self.inner.try_send(data)
458    }
459
460    fn try_flush(&mut self) -> virtual_net::Result<()> {
461        self.inner.try_flush()
462    }
463
464    fn close(&mut self) -> virtual_net::Result<()> {
465        self.inner.close()
466    }
467
468    fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> virtual_net::Result<usize> {
469        self.inner.try_recv(buf, peek)
470    }
471}
472
473impl VirtualTcpSocket for DirectPgDumpTcpSocket {
474    fn set_recv_buf_size(&mut self, size: usize) -> virtual_net::Result<()> {
475        self.inner.set_recv_buf_size(size)
476    }
477
478    fn recv_buf_size(&self) -> virtual_net::Result<usize> {
479        self.inner.recv_buf_size()
480    }
481
482    fn set_send_buf_size(&mut self, size: usize) -> virtual_net::Result<()> {
483        self.inner.set_send_buf_size(size)
484    }
485
486    fn send_buf_size(&self) -> virtual_net::Result<usize> {
487        self.inner.send_buf_size()
488    }
489
490    fn set_nodelay(&mut self, reuse: bool) -> virtual_net::Result<()> {
491        self.inner.set_nodelay(reuse)
492    }
493
494    fn nodelay(&self) -> virtual_net::Result<bool> {
495        self.inner.nodelay()
496    }
497
498    fn set_keepalive(&mut self, keepalive: bool) -> virtual_net::Result<()> {
499        self.inner.set_keepalive(keepalive)
500    }
501
502    fn keepalive(&self) -> virtual_net::Result<bool> {
503        self.inner.keepalive()
504    }
505
506    fn set_dontroute(&mut self, keepalive: bool) -> virtual_net::Result<()> {
507        self.inner.set_dontroute(keepalive)
508    }
509
510    fn dontroute(&self) -> virtual_net::Result<bool> {
511        self.inner.dontroute()
512    }
513
514    fn addr_peer(&self) -> virtual_net::Result<SocketAddr> {
515        self.inner.addr_peer()
516    }
517
518    fn shutdown(&mut self, how: Shutdown) -> virtual_net::Result<()> {
519        self.inner.shutdown(how)
520    }
521
522    fn is_closed(&self) -> bool {
523        self.inner.is_closed()
524    }
525}
526
527fn receive_direct_pg_dump_socket(
528    socket_rx: &Receiver<PgDumpVirtualSocket>,
529    runner: &thread::JoinHandle<Result<String>>,
530) -> Result<PgDumpVirtualSocket> {
531    let started = Instant::now();
532    loop {
533        match socket_rx.recv_timeout(Duration::from_millis(5)) {
534            Ok(socket) => return Ok(socket),
535            Err(mpsc::RecvTimeoutError::Timeout) => {
536                if runner.is_finished() {
537                    bail!("pg_dump exited before opening the direct virtual protocol connection");
538                }
539                if started.elapsed() > Duration::from_secs(30) {
540                    bail!(
541                        "timed out waiting for pg_dump to open the direct virtual protocol connection"
542                    );
543                }
544            }
545            Err(mpsc::RecvTimeoutError::Disconnected) => {
546                bail!("pg_dump direct virtual networking channel closed before connect")
547            }
548        }
549    }
550}
551
/// In-memory sink used as the WASIX guest's stdout or stderr.
///
/// Every write is appended to a shared buffer. The host keeps its own handle to
/// that buffer and can inspect what was printed.
#[derive(Debug)]
struct CaptureFile {
    buffer: Arc<Mutex<Vec<u8>>>,
}

impl CaptureFile {
    /// Wraps `buffer`; the caller retains another `Arc` to read the capture.
    fn new(buffer: Arc<Mutex<Vec<u8>>>) -> Self {
        CaptureFile { buffer }
    }
}
562
563impl VirtualFile for CaptureFile {
564    fn last_accessed(&self) -> u64 {
565        0
566    }
567
568    fn last_modified(&self) -> u64 {
569        0
570    }
571
572    fn created_time(&self) -> u64 {
573        0
574    }
575
576    fn size(&self) -> u64 {
577        self.buffer.lock().expect("capture lock poisoned").len() as u64
578    }
579
580    fn set_len(&mut self, _new_size: u64) -> Result<(), wasmer_wasix::FsError> {
581        Err(wasmer_wasix::FsError::PermissionDenied)
582    }
583
584    fn unlink(&mut self) -> Result<(), wasmer_wasix::FsError> {
585        Ok(())
586    }
587
588    fn poll_read_ready(
589        self: Pin<&mut Self>,
590        _cx: &mut TaskContext<'_>,
591    ) -> Poll<std::io::Result<usize>> {
592        Poll::Ready(Ok(0))
593    }
594
595    fn poll_write_ready(
596        self: Pin<&mut Self>,
597        _cx: &mut TaskContext<'_>,
598    ) -> Poll<std::io::Result<usize>> {
599        Poll::Ready(Ok(8192))
600    }
601}
602
603impl AsyncRead for CaptureFile {
604    fn poll_read(
605        self: Pin<&mut Self>,
606        _cx: &mut TaskContext<'_>,
607        _buf: &mut tokio::io::ReadBuf<'_>,
608    ) -> Poll<std::io::Result<()>> {
609        Poll::Ready(Ok(()))
610    }
611}
612
613impl AsyncWrite for CaptureFile {
614    fn poll_write(
615        mut self: Pin<&mut Self>,
616        _cx: &mut TaskContext<'_>,
617        buf: &[u8],
618    ) -> Poll<std::io::Result<usize>> {
619        Poll::Ready(self.write(buf))
620    }
621
622    fn poll_flush(self: Pin<&mut Self>, _cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
623        Poll::Ready(Ok(()))
624    }
625
626    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
627        Poll::Ready(Ok(()))
628    }
629}
630
631impl AsyncSeek for CaptureFile {
632    fn start_seek(self: Pin<&mut Self>, _position: std::io::SeekFrom) -> std::io::Result<()> {
633        Ok(())
634    }
635
636    fn poll_complete(
637        self: Pin<&mut Self>,
638        _cx: &mut TaskContext<'_>,
639    ) -> Poll<std::io::Result<u64>> {
640        Poll::Ready(Ok(0))
641    }
642}
643
644impl Read for CaptureFile {
645    fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
646        Ok(0)
647    }
648}
649
650impl Write for CaptureFile {
651    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
652        self.buffer
653            .lock()
654            .expect("capture lock poisoned")
655            .extend_from_slice(buf);
656        Ok(buf.len())
657    }
658
659    fn flush(&mut self) -> std::io::Result<()> {
660        Ok(())
661    }
662}
663
664impl Seek for CaptureFile {
665    fn seek(&mut self, _pos: std::io::SeekFrom) -> std::io::Result<u64> {
666        Ok(0)
667    }
668}
669
#[cfg(all(test, feature = "extensions"))]
mod tests {
    use super::*;
    use crate::pglite::Pglite;
    use crate::pglite::extensions;
    use crate::pglite::server::PgliteServer;
    use serde_json::json;
    use sqlx::{Connection, Executor, Row};

    /// Each spelling of a managed flag must fail validation.
    #[test]
    fn pg_dump_options_reject_managed_args() {
        let managed = [
            "-f",
            "-f/tmp/out.sql",
            "--file",
            "--file=/tmp/out.sql",
            "-F",
            "-Fc",
            "--format",
            "--format=custom",
            "-h",
            "-hlocalhost",
            "--host=localhost",
            "-p",
            "-p5432",
            "--port=5432",
            "-U",
            "-Upostgres",
            "--username=postgres",
            "-d",
            "-dpostgres",
            "--dbname=postgres",
            "-j",
            "-j2",
            "--jobs=2",
        ];
        for arg in managed {
            let rejection = PgDumpOptions::new().arg(arg).validate();
            let err = rejection.expect_err("managed pg_dump arg should be rejected");
            let message = err.to_string();
            assert!(
                message.contains("conflicts with pglite-oxide"),
                "unexpected error for {arg}: {err:#}"
            );
        }
    }

    /// Arguments that only shape the dump's contents pass validation.
    #[test]
    fn pg_dump_options_allow_dump_shaping_args() -> Result<()> {
        let shaping = [
            "--schema-only",
            "--quote-all-identifiers",
            "-n",
            "public",
            "-t",
            "dump_items",
        ];
        let options = PgDumpOptions::new().args(shaping);
        options.validate()
    }

    /// Dumps a seeded TCP server in several modes and checks that the server
    /// stays usable. Then restores the full dump into a fresh database.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pg_dump_round_trip_plain_sql() -> Result<()> {
        let server = PgliteServer::temporary_tcp()?;
        let mut seed_conn = sqlx::PgConnection::connect(&server.database_url())
            .await
            .context("connect to PGlite server")?;
        seed_conn.execute(
            "CREATE TABLE dump_items(id INTEGER PRIMARY KEY, value TEXT);
             CREATE INDEX dump_items_value_idx ON dump_items(value);
             CREATE SEQUENCE dump_items_seq START WITH 10;
             CREATE VIEW dump_item_values AS SELECT value FROM dump_items;
             INSERT INTO dump_items(id, value) VALUES (1, 'alpha'), (2, 'beta');
             SELECT nextval('dump_items_seq');",
        )
        .await
        .context("seed pg_dump source data")?;
        drop(seed_conn);

        // Each dump runs on a blocking thread and hands the server back afterwards.
        let (server, full_dump) = tokio::task::spawn_blocking(move || -> Result<_> {
            let sql = server.dump_sql(PgDumpOptions::default())?;
            Ok((server, sql))
        })
        .await
        .context("join pg_dump task")??;

        assert!(full_dump.contains("PostgreSQL database dump"));
        assert!(
            full_dump.contains("CREATE TABLE public.dump_items"),
            "dump did not contain dump_items table DDL:\n{full_dump}"
        );
        assert!(full_dump.contains("CREATE INDEX dump_items_value_idx"));
        assert!(full_dump.contains("CREATE SEQUENCE public.dump_items_seq"));
        assert!(full_dump.contains("CREATE VIEW public.dump_item_values"));
        assert!(full_dump.contains("INSERT INTO"));

        let (server, schema_only) = tokio::task::spawn_blocking(move || -> Result<_> {
            let sql = server.dump_sql(PgDumpOptions::new().arg("--schema-only"))?;
            Ok((server, sql))
        })
        .await
        .context("join schema-only pg_dump task")??;
        assert!(schema_only.contains("CREATE TABLE public.dump_items"));
        assert!(
            !schema_only.contains("INSERT INTO public.dump_items"),
            "schema-only dump unexpectedly contained data:\n{schema_only}"
        );

        let (server, quoted) = tokio::task::spawn_blocking(move || -> Result<_> {
            let sql = server.dump_sql(PgDumpOptions::new().arg("--quote-all-identifiers"))?;
            Ok((server, sql))
        })
        .await
        .context("join quoted pg_dump task")??;
        assert!(quoted.contains("CREATE TABLE \"public\".\"dump_items\""));
        assert!(quoted.contains("INSERT INTO \"public\".\"dump_items\""));

        let mut after_dump = sqlx::PgConnection::connect(&server.database_url())
            .await
            .context("reconnect after pg_dump")?;
        let counted = sqlx::query("SELECT count(*)::int4 AS count FROM public.dump_items")
            .fetch_one(&mut after_dump)
            .await
            .context("server should remain usable after pg_dump")?;
        assert_eq!(counted.try_get::<i32, _>("count")?, 2);
        after_dump.close().await?;

        server.shutdown()?;

        tokio::task::spawn_blocking(move || -> Result<()> {
            let mut restored = Pglite::builder().temporary().open()?;
            restored
                .exec(&full_dump, None)
                .context("restore pg_dump SQL")?;
            let by_id = restored.query(
                "SELECT value FROM public.dump_items WHERE id = $1",
                &[json!(2)],
                None,
            )?;
            let value = by_id
                .rows
                .first()
                .and_then(|row| row.get("value"))
                .cloned();
            assert_eq!(value, Some(json!("beta")));
            let view_rows = restored.query(
                "SELECT count(*)::int AS count FROM public.dump_item_values",
                &[],
                None,
            )?;
            assert_eq!(view_rows.rows[0]["count"], json!(2));
            let next = restored.query(
                "SELECT nextval('public.dump_items_seq')::int AS next_value",
                &[],
                None,
            )?;
            assert_eq!(next.rows[0]["next_value"], json!(11));
            restored.close()?;
            Ok(())
        })
        .await
        .context("join restore task")??;
        Ok(())
    }

    /// A server with the `vector` extension dumps and restores vector data.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pg_dump_round_trip_vector_extension() -> Result<()> {
        let server = PgliteServer::builder()
            .temporary()
            .extension(extensions::VECTOR)
            .start()?;
        let mut seed_conn = sqlx::PgConnection::connect(&server.database_url())
            .await
            .context("connect to extension-enabled PGlite server")?;
        seed_conn.execute(
            "CREATE TABLE vector_dump_items(id INTEGER PRIMARY KEY, embedding vector(3));
             INSERT INTO vector_dump_items(id, embedding) VALUES (1, '[1,2,3]');",
        )
        .await
        .context("seed vector pg_dump source data")?;
        drop(seed_conn);

        let (server, vector_dump) = tokio::task::spawn_blocking(move || -> Result<_> {
            let sql = server.dump_sql(PgDumpOptions::default())?;
            Ok((server, sql))
        })
        .await
        .context("join vector pg_dump task")??;
        server.shutdown()?;

        assert!(
            vector_dump.contains("CREATE EXTENSION IF NOT EXISTS vector"),
            "dump did not contain vector extension DDL:\n{vector_dump}"
        );
        assert!(vector_dump.contains("CREATE TABLE public.vector_dump_items"));
        assert!(vector_dump.contains("'[1,2,3]'"));

        tokio::task::spawn_blocking(move || -> Result<()> {
            let mut restored = Pglite::builder()
                .temporary()
                .extension(extensions::VECTOR)
                .open()?;
            restored
                .exec(&vector_dump, None)
                .context("restore vector dump SQL")?;
            let nearest = restored.query(
                "SELECT embedding <-> '[1,2,4]'::vector AS distance \
                 FROM public.vector_dump_items WHERE id = $1",
                &[json!(1)],
                None,
            )?;
            let distance = nearest
                .rows
                .first()
                .and_then(|row| row.get("distance"))
                .and_then(|value| value.as_f64());
            assert_eq!(distance, Some(1.0));
            restored.close()?;
            Ok(())
        })
        .await
        .context("join vector restore task")??;
        Ok(())
    }

    /// The embedded (direct) API rejects switching databases and otherwise
    /// produces a dump that restores cleanly.
    #[test]
    fn direct_pg_dump_public_api_round_trip() -> Result<()> {
        let mut source = Pglite::temporary()?;
        source.exec("CREATE TABLE direct_dump_items(value TEXT)", None)?;
        source.exec("INSERT INTO direct_dump_items VALUES ('alpha')", None)?;

        let mismatch = source
            .dump_sql(PgDumpOptions::new().database("other_database"))
            .expect_err("direct pg_dump should reject database switching");
        assert!(
            mismatch
                .to_string()
                .contains("already-open embedded backend database"),
            "unexpected direct pg_dump database mismatch error: {mismatch:#}"
        );

        let direct_dump = source.dump_sql(PgDumpOptions::new())?;
        assert!(direct_dump.contains("CREATE TABLE public.direct_dump_items"));
        assert!(direct_dump.contains("INSERT INTO"));
        let still_usable = source.query(
            "SELECT count(*)::int AS count FROM direct_dump_items",
            &[],
            None,
        )?;
        assert_eq!(still_usable.rows[0]["count"], json!(1));

        let mut restored = Pglite::temporary()?;
        restored.exec(&direct_dump, None)?;
        let values = restored.query("SELECT value FROM public.direct_dump_items", &[], None)?;
        assert_eq!(values.rows[0]["value"], json!("alpha"));

        restored.close()?;
        source.close()?;
        Ok(())
    }

    /// The direct API also round-trips `vector` extension data.
    #[test]
    fn direct_pg_dump_round_trip_vector_extension() -> Result<()> {
        let mut source = Pglite::builder()
            .temporary()
            .extension(extensions::VECTOR)
            .open()?;
        source.exec(
            "CREATE TABLE direct_vector_dump_items(id INTEGER PRIMARY KEY, embedding vector(3));
             INSERT INTO direct_vector_dump_items(id, embedding) VALUES (1, '[1,2,3]');",
            None,
        )?;

        let direct_dump = source.dump_sql(PgDumpOptions::new())?;
        assert!(direct_dump.contains("CREATE EXTENSION IF NOT EXISTS vector"));
        assert!(direct_dump.contains("CREATE TABLE public.direct_vector_dump_items"));

        let mut restored = Pglite::builder()
            .temporary()
            .extension(extensions::VECTOR)
            .open()?;
        restored.exec(&direct_dump, None)?;
        let nearest = restored.query(
            "SELECT embedding <-> '[1,2,4]'::vector AS distance \
             FROM public.direct_vector_dump_items WHERE id = $1",
            &[json!(1)],
            None,
        )?;
        assert_eq!(nearest.rows[0]["distance"], json!(1.0));

        restored.close()?;
        source.close()?;
        Ok(())
    }
}