1use std::fmt;
2use std::io::{Read, Seek, Write};
3use std::mem::MaybeUninit;
4use std::net::Shutdown;
5use std::net::{IpAddr, Ipv4Addr, SocketAddr};
6use std::pin::Pin;
7use std::sync::mpsc::{self, Receiver, SyncSender};
8use std::sync::{Arc, Mutex};
9use std::task::{Context as TaskContext, Poll};
10use std::thread;
11use std::time::{Duration, Instant};
12
13use anyhow::{Context, Result, anyhow, bail};
14use tempfile::TempDir;
15use wasmer::Store;
16use wasmer_types::ModuleHash;
17use wasmer_wasix::runners::wasi::{RuntimeOrEngine, WasiRunner};
18use wasmer_wasix::runtime::task_manager::tokio::TokioTaskManager;
19use wasmer_wasix::virtual_fs::{self, AsyncRead, AsyncSeek, AsyncWrite};
20use wasmer_wasix::virtual_net::tcp_pair::TcpSocketHalf;
21use wasmer_wasix::virtual_net::{
22 self, InterestHandler, NetworkError, SocketStatus, VirtualConnectedSocket, VirtualIoSource,
23 VirtualNetworking, VirtualSocket, VirtualTcpSocket,
24};
25use wasmer_wasix::{LocalNetworking, PluggableRuntime, VirtualFile};
26
27use crate::pglite::sync_host_fs::SyncHostFileSystem;
28use crate::pglite::timing;
29use crate::pglite::{aot, assets};
30
31#[derive(Debug, Clone, PartialEq, Eq)]
33pub struct PgDumpOptions {
34 args: Vec<String>,
35 database: String,
36 username: String,
37}
38
39impl Default for PgDumpOptions {
40 fn default() -> Self {
41 Self {
42 args: Vec::new(),
43 database: "template1".to_owned(),
44 username: "postgres".to_owned(),
45 }
46 }
47}
48
49impl PgDumpOptions {
50 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn arg(mut self, arg: impl Into<String>) -> Self {
56 self.args.push(arg.into());
57 self
58 }
59
60 pub fn args(mut self, args: impl IntoIterator<Item = impl Into<String>>) -> Self {
62 self.args.extend(args.into_iter().map(Into::into));
63 self
64 }
65
66 pub fn database(mut self, database: impl Into<String>) -> Self {
68 self.database = database.into();
69 self
70 }
71
72 pub fn username(mut self, username: impl Into<String>) -> Self {
74 self.username = username.into();
75 self
76 }
77
78 pub(crate) fn validate(&self) -> Result<()> {
79 for (name, value) in [("database", &self.database), ("username", &self.username)] {
80 anyhow::ensure!(
81 !value.is_empty() && !value.contains('\0'),
82 "pg_dump {name} must not be empty or contain NUL bytes"
83 );
84 }
85 for arg in &self.args {
86 anyhow::ensure!(
87 !arg.contains('\0'),
88 "pg_dump argument must not contain NUL bytes"
89 );
90 validate_passthrough_arg(arg)?;
91 }
92 Ok(())
93 }
94
95 pub(crate) fn database_ref(&self) -> &str {
96 &self.database
97 }
98
99 pub(crate) fn username_ref(&self) -> &str {
100 &self.username
101 }
102}
103
104fn validate_passthrough_arg(arg: &str) -> Result<()> {
105 if let Some(flag) = disallowed_pg_dump_flag(arg) {
106 anyhow::bail!(
107 "pg_dump argument '{arg}' conflicts with pglite-oxide's managed {flag}; use PgDumpOptions typed setters where available"
108 );
109 }
110 Ok(())
111}
112
/// Returns a label for the managed pg_dump setting that `arg` would override, or `None` when
/// `arg` does not touch any setting pglite-oxide controls.
///
/// pg_dump parses its command line with `getopt_long`, which accepts more spellings than the
/// literal flags, and all of them are recognised here:
/// - long options may be abbreviated to a prefix (`--user=x` means `--username=x`);
/// - short options that take no argument may be clustered in front of another short option
///   (`-aFc` means `-a -F c`).
///
/// A bare `--` (end of options), a lone `-`, and operands not starting with `-` are never
/// flagged.
fn disallowed_pg_dump_flag(arg: &str) -> Option<&'static str> {
    // Managed long options, without the leading `--`.
    const LONG_FLAGS: &[(&str, &str)] = &[
        ("file", "output file"),
        ("format", "output format"),
        ("host", "host"),
        ("port", "port"),
        ("username", "username"),
        ("dbname", "database"),
        ("jobs", "job count"),
    ];
    // Managed short options.
    const SHORT_FLAGS: &[(char, &str)] = &[
        ('f', "output file"),
        ('F', "output format"),
        ('h', "host"),
        ('p', "port"),
        ('U', "username"),
        ('d', "database"),
        ('j', "job count"),
    ];
    // pg_dump short options that take no argument (from its getopt option string
    // "abBcCd:e:E:f:F:h:j:n:N:Op:RsS:t:T:U:vwWxZ:"). After one of these, getopt keeps
    // scanning the same word for further options.
    const NO_ARGUMENT_SHORT_FLAGS: &str = "abBcCORsvwWx";

    if let Some(long) = arg.strip_prefix("--") {
        let name = long.split_once('=').map_or(long, |(name, _)| name);
        if name.is_empty() {
            // `--` ends option parsing; it does not name an option.
            return None;
        }
        // Any prefix of a managed name either selects that option or is ambiguous (which
        // getopt rejects anyway), so rejecting every prefix is safe.
        return LONG_FLAGS
            .iter()
            .find(|(flag, _)| flag.starts_with(name))
            .map(|(_, label)| *label);
    }

    let cluster = arg.strip_prefix('-')?;
    for option in cluster.chars() {
        if let Some((_, label)) = SHORT_FLAGS.iter().find(|(flag, _)| *flag == option) {
            return Some(*label);
        }
        if !NO_ARGUMENT_SHORT_FLAGS.contains(option) {
            // Either an option whose argument is the rest of this word, or an option pg_dump
            // rejects outright; nothing after it is parsed as another flag.
            return None;
        }
    }
    None
}
149
150pub(crate) fn dump_server_sql(addr: SocketAddr, options: &PgDumpOptions) -> Result<String> {
151 dump_sql_with_networking(addr, options, LocalNetworking::new())
152}
153
154pub(crate) type PgDumpVirtualSocket = TcpSocketHalf;
155
156pub(crate) fn dump_direct_sql<F>(options: &PgDumpOptions, serve: F) -> Result<String>
157where
158 F: FnOnce(PgDumpVirtualSocket) -> Result<()>,
159{
160 options.validate()?;
161 let (socket_tx, socket_rx) = mpsc::sync_channel(1);
162 let networking = DirectPgDumpNetworking::new(socket_tx);
163 let runner_options = options.clone();
164 let runner = thread::spawn(move || {
165 dump_sql_with_networking(DIRECT_PG_DUMP_ADDR, &runner_options, networking)
166 });
167
168 let accepted = receive_direct_pg_dump_socket(&socket_rx, &runner)
169 .context("accept direct pg_dump virtual protocol connection");
170 let serve_result = match accepted {
171 Ok(socket) => serve(socket),
172 Err(err) => Err(err),
173 };
174 let dump_result = runner
175 .join()
176 .map_err(|_| anyhow!("direct pg_dump runner thread panicked"))?;
177
178 match (serve_result, dump_result) {
179 (Ok(()), Ok(sql)) => Ok(sql),
180 (Err(err), Ok(_)) => Err(err),
181 (Ok(()), Err(err)) => Err(err),
182 (Err(err), Err(dump_err)) => {
183 Err(err.context(format!("direct pg_dump runner also failed: {dump_err:#}")))
184 }
185 }
186}
187
188fn dump_sql_with_networking<N>(
189 addr: SocketAddr,
190 options: &PgDumpOptions,
191 networking: N,
192) -> Result<String>
193where
194 N: VirtualNetworking + Sync,
195{
196 options.validate()?;
197 let _phase = timing::phase("pg_dump");
198 let wasm = {
199 let _phase = timing::phase("pg_dump.load_embedded_module");
200 assets::pg_dump_wasm()
201 .ok_or_else(|| anyhow!("WASIX pg_dump asset is not bundled in this build"))?
202 };
203 let engine = aot::headless_engine();
204 let module = {
205 let _phase = timing::phase("pg_dump.load_aot");
206 aot::load_pg_dump_module(&engine)?
207 };
208 let _store = Store::new(engine.clone());
209
210 let fs_root = TempDir::new().context("create pg_dump WASIX filesystem root")?;
211 let runtime = {
212 let _phase = timing::phase("pg_dump.tokio_runtime");
213 tokio::runtime::Builder::new_multi_thread()
214 .enable_all()
215 .build()
216 .context("create Tokio runtime for WASIX pg_dump")?
217 };
218 let (host_fs, wasix_runtime) = {
219 let _phase = timing::phase("pg_dump.wasix_runtime");
220 let _runtime_guard = runtime.enter();
221 let host_fs = SyncHostFileSystem::new(fs_root.path()).with_context(|| {
222 format!(
223 "create host filesystem rooted at {}",
224 fs_root.path().display()
225 )
226 })?;
227 let host_fs = Arc::new(host_fs) as Arc<dyn virtual_fs::FileSystem + Send + Sync>;
228 let mut wasix_runtime = PluggableRuntime::new(Arc::new(TokioTaskManager::new(
229 tokio::runtime::Handle::current(),
230 )));
231 wasix_runtime.set_engine(engine.clone());
232 wasix_runtime.set_networking_implementation(networking);
233 (host_fs, wasix_runtime)
234 };
235
236 let output_path = "/host/out.sql";
237 let port = addr.port().to_string();
238 let host = match addr {
239 SocketAddr::V4(addr) => addr.ip().to_string(),
240 SocketAddr::V6(addr) => addr.ip().to_string(),
241 };
242 let mut args = options.args.clone();
243 args.extend([
244 "-U".to_owned(),
245 options.username.clone(),
246 "-h".to_owned(),
247 host,
248 "-p".to_owned(),
249 port,
250 "--inserts".to_owned(),
251 "-j".to_owned(),
252 "1".to_owned(),
253 "-f".to_owned(),
254 output_path.to_owned(),
255 ]);
256 args.push(options.database.clone());
257
258 let stdout = Arc::new(Mutex::new(Vec::new()));
259 let stderr = Arc::new(Mutex::new(Vec::new()));
260 let mut runner = WasiRunner::new();
261 runner
262 .with_mount("/host".to_owned(), host_fs)
263 .with_current_dir("/")
264 .with_args(args)
265 .with_envs([
266 ("PGUSER", options.username.as_str()),
267 ("PGPASSWORD", "password"),
268 ("PGSSLMODE", "disable"),
269 ])
270 .with_stdout(Box::new(CaptureFile::new(Arc::clone(&stdout))))
271 .with_stderr(Box::new(CaptureFile::new(Arc::clone(&stderr))));
272 {
273 let _phase = timing::phase("pg_dump.run_wasm");
274 runner
275 .run_wasm(
276 RuntimeOrEngine::Runtime(Arc::new(wasix_runtime)),
277 "pg_dump",
278 module,
279 ModuleHash::sha256(wasm),
280 )
281 .map_err(|err| {
282 let stderr =
283 String::from_utf8_lossy(&stderr.lock().expect("stderr capture poisoned"))
284 .trim()
285 .to_owned();
286 if stderr.is_empty() {
287 anyhow!(err)
288 } else {
289 anyhow!("{err}; pg_dump stderr: {stderr}")
290 }
291 })
292 .context("run WASIX pg_dump")?;
293 }
294
295 {
296 let _phase = timing::phase("pg_dump.read_output");
297 match std::fs::read_to_string(fs_root.path().join("out.sql")) {
298 Ok(sql) => Ok(sql),
299 Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
300 let stdout = stdout.lock().expect("stdout capture poisoned");
301 if stdout.is_empty() {
302 Err(err).with_context(|| {
303 format!(
304 "read pg_dump output {}",
305 fs_root.path().join("out.sql").display()
306 )
307 })
308 } else {
309 String::from_utf8(stdout.clone()).context("decode pg_dump stdout as UTF-8")
310 }
311 }
312 Err(err) => Err(err).with_context(|| {
313 format!(
314 "read pg_dump output {}",
315 fs_root.path().join("out.sql").display()
316 )
317 }),
318 }
319 }
320}
321
/// Port pg_dump is pointed at in direct mode; `DirectPgDumpNetworking` only accepts
/// connections to `DIRECT_PG_DUMP_ADDR`, which uses it.
const DIRECT_PG_DUMP_PORT: u16 = 65_432;
/// Local port reported for pg_dump's end of the pair when it did not bind one.
const DIRECT_PG_DUMP_LOCAL_PORT: u16 = 65_431;
/// Buffer size of the in-memory TCP pair (8 MiB).
const DIRECT_PG_DUMP_SOCKET_BUFFER: usize = 8 << 20;
/// Loopback address shared by both ends of the direct-mode socket pair.
const DIRECT_PG_DUMP_LOOPBACK: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
/// Peer address pg_dump connects to in direct mode.
const DIRECT_PG_DUMP_ADDR: SocketAddr = SocketAddr::new(DIRECT_PG_DUMP_LOOPBACK, DIRECT_PG_DUMP_PORT);
/// Local address used for pg_dump's side when it did not request a specific port.
const DIRECT_PG_DUMP_LOCAL_ADDR: SocketAddr =
    SocketAddr::new(DIRECT_PG_DUMP_LOOPBACK, DIRECT_PG_DUMP_LOCAL_PORT);
329
330struct DirectPgDumpNetworking {
331 socket_tx: Mutex<Option<SyncSender<PgDumpVirtualSocket>>>,
332}
333
334impl DirectPgDumpNetworking {
335 fn new(socket_tx: SyncSender<PgDumpVirtualSocket>) -> Self {
336 Self {
337 socket_tx: Mutex::new(Some(socket_tx)),
338 }
339 }
340}
341
342impl fmt::Debug for DirectPgDumpNetworking {
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 f.debug_struct("DirectPgDumpNetworking")
345 .finish_non_exhaustive()
346 }
347}
348
/// Only two operations are overridden:
/// - TCP connects to `DIRECT_PG_DUMP_ADDR`, which return an in-memory socket pair;
/// - resolution of the loopback names `localhost` and `127.0.0.1`.
#[async_trait::async_trait]
impl VirtualNetworking for DirectPgDumpNetworking {
    /// Accepts a single connection to `DIRECT_PG_DUMP_ADDR`.
    ///
    /// Creates an in-memory TCP pair, sends the host half through the stored channel, and
    /// returns the guest half to pg_dump.
    ///
    /// Errors:
    /// - `ConnectionRefused` for any other peer, or once the one-shot sender has been taken;
    /// - `IOError` if the sender mutex is poisoned;
    /// - `ConnectionAborted` if the receiving side of the channel is gone.
    async fn connect_tcp(
        &self,
        addr: SocketAddr,
        peer: SocketAddr,
    ) -> virtual_net::Result<Box<dyn VirtualTcpSocket + Sync>> {
        if peer != DIRECT_PG_DUMP_ADDR {
            return Err(NetworkError::ConnectionRefused);
        }

        // Taking the sender makes this a one-shot: later connects find `None` and are refused.
        let sender = self
            .socket_tx
            .lock()
            .map_err(|_| NetworkError::IOError)?
            .take()
            .ok_or(NetworkError::ConnectionRefused)?;
        // Port 0 means the caller did not pick a local port; report a fixed one instead.
        let local = if addr.port() == 0 {
            DIRECT_PG_DUMP_LOCAL_ADDR
        } else {
            addr
        };
        let (guest, host) = TcpSocketHalf::channel(DIRECT_PG_DUMP_SOCKET_BUFFER, local, peer);
        sender
            .send(host)
            .map_err(|_| NetworkError::ConnectionAborted)?;
        Ok(Box::new(DirectPgDumpTcpSocket {
            inner: guest,
            first_write_ready_probe: true,
        }))
    }

    /// Resolves only `localhost` and `127.0.0.1` (both to IPv4 loopback); any other name
    /// fails with `AddressNotAvailable`.
    async fn resolve(
        &self,
        host: &str,
        _port: Option<u16>,
        _dns_server: Option<IpAddr>,
    ) -> virtual_net::Result<Vec<IpAddr>> {
        match host {
            "localhost" | "127.0.0.1" => Ok(vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]),
            _ => Err(NetworkError::AddressNotAvailable),
        }
    }
}
393
/// Guest-side socket handed to pg_dump in direct mode.
///
/// Delegates to the inner in-memory `TcpSocketHalf`, except that the very first
/// `poll_write_ready` call is answered without consulting it (see `first_write_ready_probe`).
#[derive(Debug)]
struct DirectPgDumpTcpSocket {
    /// Guest half of the pair created in `DirectPgDumpNetworking::connect_tcp`.
    inner: TcpSocketHalf,
    /// While `true`, the next `poll_write_ready` reports ready immediately; cleared after
    /// that first call.
    first_write_ready_probe: bool,
}

impl VirtualIoSource for DirectPgDumpTcpSocket {
    fn remove_handler(&mut self) {
        self.inner.remove_handler();
    }

    fn poll_read_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<virtual_net::Result<usize>> {
        self.inner.poll_read_ready(cx)
    }

    /// The first call reports ready with the send buffer size (falling back to 1, and never
    /// less than 1) without polling `inner`; later calls delegate to `inner`.
    ///
    /// NOTE(review): presumably works around the initial write-readiness poll on a freshly
    /// created pair not resolving — confirm the reason before changing this.
    fn poll_write_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<virtual_net::Result<usize>> {
        if self.first_write_ready_probe {
            self.first_write_ready_probe = false;
            return Poll::Ready(Ok(self.inner.send_buf_size().unwrap_or(1).max(1)));
        }
        self.inner.poll_write_ready(cx)
    }
}
421
422impl VirtualSocket for DirectPgDumpTcpSocket {
423 fn set_ttl(&mut self, ttl: u32) -> virtual_net::Result<()> {
424 self.inner.set_ttl(ttl)
425 }
426
427 fn ttl(&self) -> virtual_net::Result<u32> {
428 self.inner.ttl()
429 }
430
431 fn addr_local(&self) -> virtual_net::Result<SocketAddr> {
432 self.inner.addr_local()
433 }
434
435 fn status(&self) -> virtual_net::Result<SocketStatus> {
436 self.inner.status()
437 }
438
439 fn set_handler(
440 &mut self,
441 handler: Box<dyn InterestHandler + Send + Sync>,
442 ) -> virtual_net::Result<()> {
443 self.inner.set_handler(handler)
444 }
445}
446
447impl VirtualConnectedSocket for DirectPgDumpTcpSocket {
448 fn set_linger(&mut self, linger: Option<Duration>) -> virtual_net::Result<()> {
449 self.inner.set_linger(linger)
450 }
451
452 fn linger(&self) -> virtual_net::Result<Option<Duration>> {
453 self.inner.linger()
454 }
455
456 fn try_send(&mut self, data: &[u8]) -> virtual_net::Result<usize> {
457 self.inner.try_send(data)
458 }
459
460 fn try_flush(&mut self) -> virtual_net::Result<()> {
461 self.inner.try_flush()
462 }
463
464 fn close(&mut self) -> virtual_net::Result<()> {
465 self.inner.close()
466 }
467
468 fn try_recv(&mut self, buf: &mut [MaybeUninit<u8>], peek: bool) -> virtual_net::Result<usize> {
469 self.inner.try_recv(buf, peek)
470 }
471}
472
473impl VirtualTcpSocket for DirectPgDumpTcpSocket {
474 fn set_recv_buf_size(&mut self, size: usize) -> virtual_net::Result<()> {
475 self.inner.set_recv_buf_size(size)
476 }
477
478 fn recv_buf_size(&self) -> virtual_net::Result<usize> {
479 self.inner.recv_buf_size()
480 }
481
482 fn set_send_buf_size(&mut self, size: usize) -> virtual_net::Result<()> {
483 self.inner.set_send_buf_size(size)
484 }
485
486 fn send_buf_size(&self) -> virtual_net::Result<usize> {
487 self.inner.send_buf_size()
488 }
489
490 fn set_nodelay(&mut self, reuse: bool) -> virtual_net::Result<()> {
491 self.inner.set_nodelay(reuse)
492 }
493
494 fn nodelay(&self) -> virtual_net::Result<bool> {
495 self.inner.nodelay()
496 }
497
498 fn set_keepalive(&mut self, keepalive: bool) -> virtual_net::Result<()> {
499 self.inner.set_keepalive(keepalive)
500 }
501
502 fn keepalive(&self) -> virtual_net::Result<bool> {
503 self.inner.keepalive()
504 }
505
506 fn set_dontroute(&mut self, keepalive: bool) -> virtual_net::Result<()> {
507 self.inner.set_dontroute(keepalive)
508 }
509
510 fn dontroute(&self) -> virtual_net::Result<bool> {
511 self.inner.dontroute()
512 }
513
514 fn addr_peer(&self) -> virtual_net::Result<SocketAddr> {
515 self.inner.addr_peer()
516 }
517
518 fn shutdown(&mut self, how: Shutdown) -> virtual_net::Result<()> {
519 self.inner.shutdown(how)
520 }
521
522 fn is_closed(&self) -> bool {
523 self.inner.is_closed()
524 }
525}
526
527fn receive_direct_pg_dump_socket(
528 socket_rx: &Receiver<PgDumpVirtualSocket>,
529 runner: &thread::JoinHandle<Result<String>>,
530) -> Result<PgDumpVirtualSocket> {
531 let started = Instant::now();
532 loop {
533 match socket_rx.recv_timeout(Duration::from_millis(5)) {
534 Ok(socket) => return Ok(socket),
535 Err(mpsc::RecvTimeoutError::Timeout) => {
536 if runner.is_finished() {
537 bail!("pg_dump exited before opening the direct virtual protocol connection");
538 }
539 if started.elapsed() > Duration::from_secs(30) {
540 bail!(
541 "timed out waiting for pg_dump to open the direct virtual protocol connection"
542 );
543 }
544 }
545 Err(mpsc::RecvTimeoutError::Disconnected) => {
546 bail!("pg_dump direct virtual networking channel closed before connect")
547 }
548 }
549 }
550}
551
/// In-memory sink used for pg_dump's stdout/stderr: every write is appended to a shared
/// buffer that the caller inspects after the run.
#[derive(Debug)]
struct CaptureFile {
    /// Captured bytes, shared with whoever created this file.
    buffer: Arc<Mutex<Vec<u8>>>,
}

impl CaptureFile {
    /// Wraps `buffer` so that writes to this file append to it.
    fn new(buffer: Arc<Mutex<Vec<u8>>>) -> Self {
        CaptureFile { buffer }
    }
}
562
563impl VirtualFile for CaptureFile {
564 fn last_accessed(&self) -> u64 {
565 0
566 }
567
568 fn last_modified(&self) -> u64 {
569 0
570 }
571
572 fn created_time(&self) -> u64 {
573 0
574 }
575
576 fn size(&self) -> u64 {
577 self.buffer.lock().expect("capture lock poisoned").len() as u64
578 }
579
580 fn set_len(&mut self, _new_size: u64) -> Result<(), wasmer_wasix::FsError> {
581 Err(wasmer_wasix::FsError::PermissionDenied)
582 }
583
584 fn unlink(&mut self) -> Result<(), wasmer_wasix::FsError> {
585 Ok(())
586 }
587
588 fn poll_read_ready(
589 self: Pin<&mut Self>,
590 _cx: &mut TaskContext<'_>,
591 ) -> Poll<std::io::Result<usize>> {
592 Poll::Ready(Ok(0))
593 }
594
595 fn poll_write_ready(
596 self: Pin<&mut Self>,
597 _cx: &mut TaskContext<'_>,
598 ) -> Poll<std::io::Result<usize>> {
599 Poll::Ready(Ok(8192))
600 }
601}
602
603impl AsyncRead for CaptureFile {
604 fn poll_read(
605 self: Pin<&mut Self>,
606 _cx: &mut TaskContext<'_>,
607 _buf: &mut tokio::io::ReadBuf<'_>,
608 ) -> Poll<std::io::Result<()>> {
609 Poll::Ready(Ok(()))
610 }
611}
612
613impl AsyncWrite for CaptureFile {
614 fn poll_write(
615 mut self: Pin<&mut Self>,
616 _cx: &mut TaskContext<'_>,
617 buf: &[u8],
618 ) -> Poll<std::io::Result<usize>> {
619 Poll::Ready(self.write(buf))
620 }
621
622 fn poll_flush(self: Pin<&mut Self>, _cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
623 Poll::Ready(Ok(()))
624 }
625
626 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
627 Poll::Ready(Ok(()))
628 }
629}
630
631impl AsyncSeek for CaptureFile {
632 fn start_seek(self: Pin<&mut Self>, _position: std::io::SeekFrom) -> std::io::Result<()> {
633 Ok(())
634 }
635
636 fn poll_complete(
637 self: Pin<&mut Self>,
638 _cx: &mut TaskContext<'_>,
639 ) -> Poll<std::io::Result<u64>> {
640 Poll::Ready(Ok(0))
641 }
642}
643
644impl Read for CaptureFile {
645 fn read(&mut self, _buf: &mut [u8]) -> std::io::Result<usize> {
646 Ok(0)
647 }
648}
649
650impl Write for CaptureFile {
651 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
652 self.buffer
653 .lock()
654 .expect("capture lock poisoned")
655 .extend_from_slice(buf);
656 Ok(buf.len())
657 }
658
659 fn flush(&mut self) -> std::io::Result<()> {
660 Ok(())
661 }
662}
663
664impl Seek for CaptureFile {
665 fn seek(&mut self, _pos: std::io::SeekFrom) -> std::io::Result<u64> {
666 Ok(0)
667 }
668}
669
// Integration tests for WASIX pg_dump. They need the `extensions` feature because some cases
// dump and restore the `vector` extension.
#[cfg(all(test, feature = "extensions"))]
mod tests {
    use super::*;
    use crate::pglite::Pglite;
    use crate::pglite::extensions;
    use crate::pglite::server::PgliteServer;
    use serde_json::json;
    use sqlx::{Connection, Executor, Row};

    /// Every literal spelling of a managed flag (separate, attached-value and `--flag=value`
    /// forms) must fail validation with the "conflicts with pglite-oxide" error.
    #[test]
    fn pg_dump_options_reject_managed_args() {
        for arg in [
            "-f",
            "-f/tmp/out.sql",
            "--file",
            "--file=/tmp/out.sql",
            "-F",
            "-Fc",
            "--format",
            "--format=custom",
            "-h",
            "-hlocalhost",
            "--host=localhost",
            "-p",
            "-p5432",
            "--port=5432",
            "-U",
            "-Upostgres",
            "--username=postgres",
            "-d",
            "-dpostgres",
            "--dbname=postgres",
            "-j",
            "-j2",
            "--jobs=2",
        ] {
            let err = PgDumpOptions::new()
                .arg(arg)
                .validate()
                .expect_err("managed pg_dump arg should be rejected");
            assert!(
                err.to_string().contains("conflicts with pglite-oxide"),
                "unexpected error for {arg}: {err:#}"
            );
        }
    }

    /// Arguments that only shape the dump contents (schema/table filters, quoting) stay
    /// allowed.
    #[test]
    fn pg_dump_options_allow_dump_shaping_args() -> Result<()> {
        PgDumpOptions::new()
            .args([
                "--schema-only",
                "--quote-all-identifiers",
                "-n",
                "public",
                "-t",
                "dump_items",
            ])
            .validate()
    }

    /// Dumps a TCP-served PGlite database with a table, index, sequence, view and rows.
    ///
    /// The test also checks that:
    /// - `--schema-only` and `--quote-all-identifiers` are honoured;
    /// - the source server stays usable after the dumps;
    /// - the plain SQL restores into a fresh instance.
    ///
    /// `dump_sql` is blocking, so it runs inside `spawn_blocking` and the server is moved in
    /// and handed back each time.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pg_dump_round_trip_plain_sql() -> Result<()> {
        let server = PgliteServer::temporary_tcp()?;
        let mut conn = sqlx::PgConnection::connect(&server.database_url())
            .await
            .context("connect to PGlite server")?;
        conn.execute(
            "CREATE TABLE dump_items(id INTEGER PRIMARY KEY, value TEXT);
 CREATE INDEX dump_items_value_idx ON dump_items(value);
 CREATE SEQUENCE dump_items_seq START WITH 10;
 CREATE VIEW dump_item_values AS SELECT value FROM dump_items;
 INSERT INTO dump_items(id, value) VALUES (1, 'alpha'), (2, 'beta');
 SELECT nextval('dump_items_seq');",
        )
        .await
        .context("seed pg_dump source data")?;
        drop(conn);

        let (server, dump) = tokio::task::spawn_blocking(move || -> Result<_> {
            let dump = server.dump_sql(PgDumpOptions::default())?;
            Ok((server, dump))
        })
        .await
        .context("join pg_dump task")??;

        assert!(dump.contains("PostgreSQL database dump"));
        assert!(
            dump.contains("CREATE TABLE public.dump_items"),
            "dump did not contain dump_items table DDL:\n{dump}"
        );
        assert!(dump.contains("CREATE INDEX dump_items_value_idx"));
        assert!(dump.contains("CREATE SEQUENCE public.dump_items_seq"));
        assert!(dump.contains("CREATE VIEW public.dump_item_values"));
        assert!(dump.contains("INSERT INTO"));

        let (server, schema_only) = tokio::task::spawn_blocking(move || -> Result<_> {
            let dump = server.dump_sql(PgDumpOptions::new().arg("--schema-only"))?;
            Ok((server, dump))
        })
        .await
        .context("join schema-only pg_dump task")??;
        assert!(schema_only.contains("CREATE TABLE public.dump_items"));
        assert!(
            !schema_only.contains("INSERT INTO public.dump_items"),
            "schema-only dump unexpectedly contained data:\n{schema_only}"
        );

        let (server, quoted) = tokio::task::spawn_blocking(move || -> Result<_> {
            let dump = server.dump_sql(PgDumpOptions::new().arg("--quote-all-identifiers"))?;
            Ok((server, dump))
        })
        .await
        .context("join quoted pg_dump task")??;
        assert!(quoted.contains("CREATE TABLE \"public\".\"dump_items\""));
        assert!(quoted.contains("INSERT INTO \"public\".\"dump_items\""));

        let mut usable = sqlx::PgConnection::connect(&server.database_url())
            .await
            .context("reconnect after pg_dump")?;
        let row = sqlx::query("SELECT count(*)::int4 AS count FROM public.dump_items")
            .fetch_one(&mut usable)
            .await
            .context("server should remain usable after pg_dump")?;
        assert_eq!(row.try_get::<i32, _>("count")?, 2);
        usable.close().await?;

        server.shutdown()?;

        // Restore the plain dump into a fresh embedded instance and check data, view and the
        // sequence position (nextval was called once after START WITH 10, so the next is 11).
        tokio::task::spawn_blocking(move || -> Result<()> {
            let mut restored = Pglite::builder().temporary().open()?;
            restored.exec(&dump, None).context("restore pg_dump SQL")?;
            let result = restored.query(
                "SELECT value FROM public.dump_items WHERE id = $1",
                &[json!(2)],
                None,
            )?;
            let value = result
                .rows
                .first()
                .and_then(|row| row.get("value"))
                .cloned();
            assert_eq!(value, Some(json!("beta")));
            let view = restored.query(
                "SELECT count(*)::int AS count FROM public.dump_item_values",
                &[],
                None,
            )?;
            assert_eq!(view.rows[0]["count"], json!(2));
            let sequence = restored.query(
                "SELECT nextval('public.dump_items_seq')::int AS next_value",
                &[],
                None,
            )?;
            assert_eq!(sequence.rows[0]["next_value"], json!(11));
            restored.close()?;
            Ok(())
        })
        .await
        .context("join restore task")??;
        Ok(())
    }

    /// Dumps a TCP-served database that uses the `vector` extension and restores it into a
    /// fresh instance with the extension enabled.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pg_dump_round_trip_vector_extension() -> Result<()> {
        let server = PgliteServer::builder()
            .temporary()
            .extension(extensions::VECTOR)
            .start()?;
        let mut conn = sqlx::PgConnection::connect(&server.database_url())
            .await
            .context("connect to extension-enabled PGlite server")?;
        conn.execute(
            "CREATE TABLE vector_dump_items(id INTEGER PRIMARY KEY, embedding vector(3));
 INSERT INTO vector_dump_items(id, embedding) VALUES (1, '[1,2,3]');",
        )
        .await
        .context("seed vector pg_dump source data")?;
        drop(conn);

        let (server, dump) = tokio::task::spawn_blocking(move || -> Result<_> {
            let dump = server.dump_sql(PgDumpOptions::default())?;
            Ok((server, dump))
        })
        .await
        .context("join vector pg_dump task")??;
        server.shutdown()?;

        assert!(
            dump.contains("CREATE EXTENSION IF NOT EXISTS vector"),
            "dump did not contain vector extension DDL:\n{dump}"
        );
        assert!(dump.contains("CREATE TABLE public.vector_dump_items"));
        assert!(dump.contains("'[1,2,3]'"));

        tokio::task::spawn_blocking(move || -> Result<()> {
            let mut restored = Pglite::builder()
                .temporary()
                .extension(extensions::VECTOR)
                .open()?;
            restored
                .exec(&dump, None)
                .context("restore vector dump SQL")?;
            // L2 distance between [1,2,3] and [1,2,4] is exactly 1.0.
            let result = restored.query(
                "SELECT embedding <-> '[1,2,4]'::vector AS distance \
                 FROM public.vector_dump_items WHERE id = $1",
                &[json!(1)],
                None,
            )?;
            let distance = result
                .rows
                .first()
                .and_then(|row| row.get("distance"))
                .and_then(|value| value.as_f64());
            assert_eq!(distance, Some(1.0));
            restored.close()?;
            Ok(())
        })
        .await
        .context("join vector restore task")??;
        Ok(())
    }

    /// Direct (in-process) dump through the public `Pglite::dump_sql` API.
    ///
    /// The test checks that:
    /// - switching to another database is rejected;
    /// - the source handle remains usable after dumping;
    /// - the dump restores.
    #[test]
    fn direct_pg_dump_public_api_round_trip() -> Result<()> {
        let mut db = Pglite::temporary()?;
        db.exec("CREATE TABLE direct_dump_items(value TEXT)", None)?;
        db.exec("INSERT INTO direct_dump_items VALUES ('alpha')", None)?;

        let mismatched_database = db
            .dump_sql(PgDumpOptions::new().database("other_database"))
            .expect_err("direct pg_dump should reject database switching");
        assert!(
            mismatched_database
                .to_string()
                .contains("already-open embedded backend database"),
            "unexpected direct pg_dump database mismatch error: {mismatched_database:#}"
        );

        let dump = db.dump_sql(PgDumpOptions::new())?;
        assert!(dump.contains("CREATE TABLE public.direct_dump_items"));
        assert!(dump.contains("INSERT INTO"));
        let source_still_usable = db.query(
            "SELECT count(*)::int AS count FROM direct_dump_items",
            &[],
            None,
        )?;
        assert_eq!(source_still_usable.rows[0]["count"], json!(1));

        let mut restored = Pglite::temporary()?;
        restored.exec(&dump, None)?;
        let result = restored.query("SELECT value FROM public.direct_dump_items", &[], None)?;
        assert_eq!(result.rows[0]["value"], json!("alpha"));

        restored.close()?;
        db.close()?;
        Ok(())
    }

    /// Direct (in-process) dump of a database using the `vector` extension, restored into a
    /// fresh extension-enabled instance.
    #[test]
    fn direct_pg_dump_round_trip_vector_extension() -> Result<()> {
        let mut db = Pglite::builder()
            .temporary()
            .extension(extensions::VECTOR)
            .open()?;
        db.exec(
            "CREATE TABLE direct_vector_dump_items(id INTEGER PRIMARY KEY, embedding vector(3));
 INSERT INTO direct_vector_dump_items(id, embedding) VALUES (1, '[1,2,3]');",
            None,
        )?;

        let dump = db.dump_sql(PgDumpOptions::new())?;
        assert!(dump.contains("CREATE EXTENSION IF NOT EXISTS vector"));
        assert!(dump.contains("CREATE TABLE public.direct_vector_dump_items"));

        let mut restored = Pglite::builder()
            .temporary()
            .extension(extensions::VECTOR)
            .open()?;
        restored.exec(&dump, None)?;
        let result = restored.query(
            "SELECT embedding <-> '[1,2,4]'::vector AS distance \
             FROM public.direct_vector_dump_items WHERE id = $1",
            &[json!(1)],
            None,
        )?;
        assert_eq!(result.rows[0]["distance"], json!(1.0));

        restored.close()?;
        db.close()?;
        Ok(())
    }
}