Skip to main content

pglite_oxide/pglite/
client.rs

1use anyhow::{Context, Result, anyhow, bail};
2use serde_json::Value;
3use std::collections::{HashMap, HashSet};
4use std::fs;
5use std::io;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tempfile::TempDir;
10#[cfg(feature = "extensions")]
11use tokio::io::{AsyncWrite, AsyncWriteExt};
12#[cfg(feature = "extensions")]
13use tokio::runtime::Runtime;
14#[cfg(feature = "extensions")]
15use wasmer_wasix::virtual_net::VirtualTcpSocket;
16#[cfg(feature = "extensions")]
17use wasmer_wasix::virtual_net::tcp_pair::TcpSocketHalfRx;
18
19use crate::pglite::aot;
20#[cfg(feature = "extensions")]
21use crate::pglite::assets;
22use crate::pglite::backend::{BackendOpenKind, BackendSession};
23#[cfg(feature = "extensions")]
24use crate::pglite::base::install_bundled_extension_bytes;
25use crate::pglite::base::{InstallOutcome, PglitePaths, RootLock};
26use crate::pglite::builder::PgliteBuilder;
27use crate::pglite::config::{PostgresConfig, StartupConfig};
28use crate::pglite::data_dir::{DataDirArchiveFormat, dump_pgdata_archive};
29use crate::pglite::errors::PgliteError;
30#[cfg(feature = "extensions")]
31use crate::pglite::extensions::{
32    Extension, by_sql_name, extension_session_setup_sql, extension_setup_sql, resolve_extension_set,
33};
34use crate::pglite::interface::{
35    DataTransferContainer, DescribeQueryParam, DescribeQueryResult, DescribeResultField,
36    ExecProtocolOptions, ExecProtocolResult, ParserMap, QueryOptions, Results, SerializerMap,
37};
38use crate::pglite::parse::{parse_describe_statement_results, parse_results};
39#[cfg(feature = "extensions")]
40use crate::pglite::pg_dump::{PgDumpOptions, PgDumpVirtualSocket, dump_direct_sql};
41#[cfg(feature = "extensions")]
42use crate::pglite::postgres_mod::PostgresMod;
43use crate::pglite::timing;
44use crate::pglite::types::{
45    ArrayTypeInfo, DEFAULT_PARSERS, DEFAULT_SERIALIZERS, TEXT, register_array_type,
46};
47#[cfg(feature = "extensions")]
48use crate::pglite::wire::{FrontendFrameKind, FrontendFrameReader, classify_frontend_message};
49use crate::protocol::messages::{BackendMessage, DatabaseError};
50use crate::protocol::parser::Parser as ProtocolParser;
51use crate::protocol::serializer::{BindConfig, BindValue, PortalTarget, Serialize};
52
/// Shared, thread-safe callback stored for a per-channel listener.
///
/// It receives one `&str` — presumably the notification payload; confirm at the dispatch site.
type ChannelCallback = Arc<dyn Fn(&str) + Send + Sync + 'static>;
/// Shared, thread-safe callback stored for a global listener.
///
/// It receives two `&str` values — presumably channel and payload; confirm at the dispatch site.
type GlobalCallback = Arc<dyn Fn(&str, &str) + Send + Sync + 'static>;
55
/// Handle returned when a per-channel listener is registered.
///
/// It records the channel name as supplied, the key under which the listener
/// is stored, and the listener's numeric id.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ListenerHandle {
    channel: String,
    normalized_channel: String,
    id: u64,
}

impl ListenerHandle {
    /// Channel name as it was supplied at registration.
    pub fn channel(&self) -> &str {
        self.channel.as_str()
    }

    /// Numeric id assigned to the listener at registration.
    pub fn id(&self) -> u64 {
        self.id
    }
}

/// Handle returned when a global notification callback is registered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GlobalListenerHandle {
    id: u64,
}

impl GlobalListenerHandle {
    /// Numeric id assigned to the global callback at registration.
    pub fn id(&self) -> u64 {
        self.id
    }
}
83
/// A per-channel callback paired with the id used to find and remove it.
struct ChannelListener {
    id: u64,
    callback: ChannelCallback,
}

/// A global callback paired with the id used to find and remove it.
struct GlobalListener {
    id: u64,
    callback: GlobalCallback,
}
93
/// Primary entry point for interacting with the embedded Postgres runtime.
pub struct Pglite {
    /// Backend session that protocol traffic, shutdown, and restart are delegated to.
    backend: BackendSession,
    /// Temporary directory handed over via `attach_temp_dir`; held so it lives as long as the instance.
    _temp_dir: Option<TempDir>,
    /// Root lock handed over via `attach_root_lock`; cleared after a successful close.
    _root_lock: Option<RootLock>,
    /// Protocol parser (its use is outside the visible part of this file).
    parser: ProtocolParser,
    /// Value serializers, initialised from `DEFAULT_SERIALIZERS`.
    serializers: SerializerMap,
    /// Value parsers, initialised from `DEFAULT_PARSERS`.
    parsers: ParserMap,
    /// Presumably type OIDs whose array-type lookup already failed — TODO confirm at the use site.
    array_type_lookup_misses: HashSet<i32>,
    /// Checked before physical snapshots, which refuse to run while this is `true`.
    in_transaction: bool,
    /// Set to `true` at construction; cleared on close or when a backend restart fails.
    ready: bool,
    /// `true` only while `close_backend` is running.
    closing: bool,
    /// Set once the instance has been closed or a backend restart has failed.
    closed: bool,
    /// Presumably tracks whether blob input was supplied — maintained outside the visible code.
    blob_input_provided: bool,
    /// Per-channel listeners keyed by channel name.
    notify_listeners: HashMap<String, Vec<ChannelListener>>,
    /// Callbacks registered through `on_notification`.
    global_notify_listeners: Vec<GlobalListener>,
    /// Id handed to the next per-channel listener (starts at 1, wraps on overflow).
    next_listener_id: u64,
    /// Id handed to the next global listener (starts at 1, wraps on overflow).
    next_global_listener_id: u64,
}
113
114impl Pglite {
    /// Create a builder for opening persistent or temporary PGlite databases.
    ///
    /// Returns `PgliteBuilder::new()`. The convenience constructors `open`, `open_app`,
    /// and `temporary` all start from this builder.
    pub fn builder() -> PgliteBuilder {
        PgliteBuilder::new()
    }
119
120    /// Open a persistent PGlite database rooted at `root`, installing and initializing it if needed.
121    pub fn open(root: impl AsRef<Path>) -> Result<Self> {
122        Self::builder().path(root.as_ref().to_path_buf()).open()
123    }
124
    /// Open a persistent PGlite database under the platform data directory for `app_id`.
    ///
    /// `app_id` is passed unchanged to the builder's `app_id` setter; the meaning of the
    /// three components is defined by the builder.
    pub fn open_app(app_id: (&str, &str, &str)) -> Result<Self> {
        Self::builder().app_id(app_id).open()
    }
129
    /// Create an ephemeral PGlite database whose files are removed when the instance is dropped.
    ///
    /// Shorthand for `Pglite::builder().temporary().open()`.
    pub fn temporary() -> Result<Self> {
        Self::builder().temporary().open()
    }
134
135    /// Warm the runtime module and bundled AOT artifact cache without opening a database.
136    pub fn preload() -> Result<()> {
137        let (temp_dir, paths) = {
138            let _phase = timing::phase("preload.tempdir");
139            PglitePaths::with_temp_dir()?
140        };
141        {
142            let _phase = timing::phase("preload.runtime_module");
143            crate::pglite::base::preload_runtime_module(&paths)?;
144        }
145        {
146            let _phase = timing::phase("preload.aot_runtime");
147            aot::preload_runtime_artifact()?;
148        }
149        drop(temp_dir);
150        Ok(())
151    }
152
153    /// Warm bundled extension artifacts without permanently opening a database.
154    #[cfg(feature = "extensions")]
155    pub fn preload_extensions(extensions: impl IntoIterator<Item = Extension>) -> Result<()> {
156        Self::preload()?;
157        let extensions = extensions.into_iter().collect::<Vec<_>>();
158        for extension in resolve_extension_set(&extensions)? {
159            let bytes = assets::extension_archive(extension.sql_name()).ok_or_else(|| {
160                anyhow!(
161                    "extension asset '{}' is not bundled in this pglite-oxide build",
162                    extension.sql_name()
163                )
164            })?;
165            let (temp_dir, paths) = {
166                let _phase = timing::phase("preload.extension_tempdir");
167                PglitePaths::with_temp_dir()?
168            };
169            {
170                let _phase = timing::phase("preload.extension_runtime_module");
171                crate::pglite::base::preload_runtime_module(&paths)?;
172            }
173            {
174                let _phase = timing::phase("preload.extension_archive_install");
175                install_bundled_extension_bytes(&paths, extension.sql_name(), bytes)?;
176            }
177            {
178                let _phase = timing::phase("preload.extension_side_module");
179                PostgresMod::preload_extension_module_from_paths(&paths, extension)?;
180            }
181            {
182                let _phase = timing::phase("preload.extension_aot");
183                aot::preload_extension_artifact(extension)?;
184            }
185            drop(temp_dir);
186        }
187        Ok(())
188    }
189
190    /// Create a new Pglite instance backed by the provided runtime paths.
191    #[doc(hidden)]
192    pub fn new(paths: PglitePaths) -> Result<Self> {
193        let outcome = crate::pglite::base::prepare_database_root(
194            paths,
195            crate::pglite::base::RootPrepareOptions::template(),
196        )?;
197        Self::new_prepared(outcome)
198    }
199
    /// Open an instance from an already prepared root, using the default
    /// `PostgresConfig` and `StartupConfig`.
    pub(crate) fn new_prepared(outcome: InstallOutcome) -> Result<Self> {
        Self::new_prepared_with_config(outcome, PostgresConfig::default(), StartupConfig::default())
    }
203
204    pub(crate) fn new_prepared_with_config(
205        outcome: InstallOutcome,
206        postgres_config: PostgresConfig,
207        startup_config: StartupConfig,
208    ) -> Result<Self> {
209        let _phase = timing::phase("pglite.open");
210        let session_startup_config = startup_config.clone();
211        let backend = BackendSession::open(
212            outcome,
213            postgres_config,
214            startup_config,
215            BackendOpenKind::Direct,
216        )?;
217
218        let mut instance = {
219            let _phase = timing::phase("pglite.client_struct_init");
220            Self {
221                backend,
222                _temp_dir: None,
223                _root_lock: None,
224                parser: ProtocolParser::new(),
225                serializers: DEFAULT_SERIALIZERS.clone(),
226                parsers: DEFAULT_PARSERS.clone(),
227                array_type_lookup_misses: HashSet::new(),
228                in_transaction: false,
229                ready: true,
230                closing: false,
231                closed: false,
232                blob_input_provided: false,
233                notify_listeners: HashMap::new(),
234                global_notify_listeners: Vec::new(),
235                next_listener_id: 1,
236                next_global_listener_id: 1,
237            }
238        };
239
240        if session_startup_config.username != "postgres" {
241            let sql = format!(
242                "SET ROLE {}",
243                crate::pglite::templating::quote_identifier(&session_startup_config.username)
244            );
245            instance
246                .exec(&sql, None)
247                .with_context(|| format!("set startup role {}", session_startup_config.username))?;
248        }
249
250        Ok(instance)
251    }
252
253    /// Install and enable a bundled Postgres extension.
254    #[cfg(feature = "extensions")]
255    pub fn enable_extension(&mut self, extension: Extension) -> Result<()> {
256        let _phase = timing::phase("extension.enable");
257        let bytes = assets::extension_archive(extension.sql_name()).ok_or_else(|| {
258            anyhow!(
259                "extension asset '{}' is not bundled in this pglite-oxide build",
260                extension.sql_name()
261            )
262        })?;
263        install_bundled_extension_bytes(self.paths(), extension.sql_name(), bytes)?;
264        self.backend.preload_extension_module(extension)?;
265        for sql in extension_setup_sql(extension) {
266            self.exec(&sql, None)?;
267        }
268        Ok(())
269    }
270
271    #[cfg(feature = "extensions")]
272    pub(crate) fn enable_preinstalled_extension(&mut self, extension: Extension) -> Result<()> {
273        let _phase = timing::phase("extension.enable_preinstalled");
274        self.backend.preload_installed_extension(extension)?;
275        for sql in extension_session_setup_sql(extension) {
276            self.exec(&sql, None)?;
277        }
278        Ok(())
279    }
280
    /// Refresh direct API array parser and serializer registrations.
    ///
    /// This mirrors upstream PGlite's `refreshArrayTypes()` escape hatch. Most
    /// applications should not need it because built-in arrays are registered
    /// statically and runtime custom arrays are discovered lazily when possible.
    ///
    /// # Errors
    /// Fails if the instance is not ready or if the internal refresh fails.
    pub fn refresh_array_types(&mut self) -> Result<()> {
        self.check_ready()?;
        self.refresh_array_types_internal()
    }
290
    /// Execute a SQL query using the extended protocol.
    ///
    /// `params` are the bind values for the statement; `options` falls back to
    /// `QueryOptions::default()` when `None`.
    ///
    /// # Errors
    /// Fails if the instance is not ready or if the query itself fails.
    pub fn query(
        &mut self,
        sql: &str,
        params: &[Value],
        options: Option<&QueryOptions>,
    ) -> Result<Results> {
        // Readiness is checked here only; the internal variant skips the check.
        self.check_ready()?;

        self.query_internal(sql, params, options)
    }
302
303    fn query_internal(
304        &mut self,
305        sql: &str,
306        params: &[Value],
307        options: Option<&QueryOptions>,
308    ) -> Result<Results> {
309        let default_options = QueryOptions::default();
310        let query_opts = options.unwrap_or(&default_options);
311
312        self.handle_blob_input(query_opts.blob.as_ref())?;
313
314        let params_snapshot: Vec<Value> = params.to_vec();
315        let options_snapshot = options.cloned();
316        let mut collected_messages: Vec<BackendMessage> = Vec::new();
317
318        let mut exec_opts = ExecProtocolOptions::no_sync();
319        exec_opts.on_notice = query_opts.on_notice.clone();
320        exec_opts.data_transfer_container = query_opts.data_transfer_container;
321
322        let result: Result<()> = (|| {
323            let param_types = if query_opts.param_types.is_empty() {
324                &[] as &[i32]
325            } else {
326                &query_opts.param_types
327            };
328
329            let mut messages = {
330                let _phase = timing::phase("client.query.parse_describe");
331                self.parse_and_describe(sql, param_types, exec_opts.clone())?
332            };
333            let mut data_type_ids = parse_describe_statement_results(&messages);
334            if self.ensure_array_types_for_bind_values(params, &data_type_ids, query_opts)? {
335                messages = {
336                    let _phase = timing::phase("client.query.parse_describe_after_array_register");
337                    self.parse_and_describe(sql, param_types, exec_opts.clone())?
338                };
339                data_type_ids = parse_describe_statement_results(&messages);
340            }
341            collected_messages.extend(messages);
342            let bind_values = {
343                let _phase = timing::phase("client.query.prepare_bind_values");
344                self.prepare_bind_values(params, &data_type_ids, query_opts)?
345            };
346            let bind_config = BindConfig {
347                values: bind_values,
348                ..Default::default()
349            };
350            let execute_batch = {
351                let _phase = timing::phase("client.query.serialize_execute");
352                let mut execute_batch = Vec::new();
353                execute_batch.extend(Serialize::bind(&bind_config));
354                execute_batch.extend(Serialize::describe(&PortalTarget::new('P', None)));
355                execute_batch.extend(Serialize::execute(None));
356                execute_batch.extend(Serialize::sync());
357                execute_batch
358            };
359            let ExecProtocolResult { messages, .. } = {
360                let _phase = timing::phase("client.query.execute_roundtrip");
361                self.exec_protocol(&execute_batch, exec_opts.clone())?
362            };
363            collected_messages.extend(messages);
364
365            Ok(())
366        })();
367
368        if let Err(err) = result {
369            match err.downcast::<DatabaseError>() {
370                Ok(db_err) => {
371                    let enriched = PgliteError::new(db_err, sql, params_snapshot, options_snapshot);
372                    return Err(enriched.into());
373                }
374                Err(err) => {
375                    return Err(err.context(format!("failed to execute extended query: {sql}")));
376                }
377            }
378        }
379
380        {
381            let _phase = timing::phase("client.query.finish");
382            self.finish_query(collected_messages, options)
383        }
384    }
385
386    /// Return `true` if the instance is ready for new work.
387    pub fn is_ready(&self) -> bool {
388        self.ready && !self.closing && !self.closed
389    }
390
    /// Return the host-side runtime and data-directory paths backing this instance.
    ///
    /// The paths are borrowed from the backend session.
    #[doc(hidden)]
    pub fn paths(&self) -> &PglitePaths {
        self.backend.paths()
    }
396
    /// Return debug-build bridge allocation/free counters for ownership tests.
    ///
    /// Only compiled with `debug_assertions`. The pair comes straight from the backend;
    /// presumably `(allocations, frees)` — confirm against the backend implementation.
    #[doc(hidden)]
    #[cfg(debug_assertions)]
    pub fn guest_bridge_allocation_counts(&self) -> (u64, u64) {
        self.backend.guest_bridge_allocation_counts()
    }
403
    /// Dump the physical PGDATA directory to a gzipped tar archive.
    ///
    /// The archive is intended to be loaded back into pglite-oxide/PGlite with
    /// the same PostgreSQL/PGlite version. Use [`dump_sql`](Self::dump_sql) for
    /// logical backups across versions.
    ///
    /// Equivalent to `dump_data_dir_with_format(DataDirArchiveFormat::TarGz)`.
    pub fn dump_data_dir(&mut self) -> Result<Vec<u8>> {
        self.dump_data_dir_with_format(DataDirArchiveFormat::TarGz)
    }
412
    /// Dump the physical PGDATA directory with the selected archive format.
    ///
    /// The backend is checkpointed, shut down while the archive is built, and then
    /// restarted.
    ///
    /// # Errors
    /// Fails if the instance is not ready, if a direct transaction is active, or if
    /// checkpointing, shutdown, archiving, or restart fails.
    pub fn dump_data_dir_with_format(&mut self, format: DataDirArchiveFormat) -> Result<Vec<u8>> {
        self.check_ready()?;
        self.archive_quiesced_pgdata("dump PGDATA archive", format)
    }
418
    /// Clone this database into a new temporary [`Pglite`] instance.
    ///
    /// Dumps PGDATA as a `DataDirArchiveFormat::Tar` archive and loads it into a
    /// temporary builder. With the `extensions` feature enabled, the bundled extensions
    /// found in this database are passed to the builder as well.
    pub fn try_clone(&mut self) -> Result<Self> {
        // Query the extension list before dumping, since the dump restarts the backend.
        #[cfg(feature = "extensions")]
        let extensions = self.bundled_extensions_in_database()?;
        let archive = self.dump_data_dir_with_format(DataDirArchiveFormat::Tar)?;
        let builder = Self::builder().temporary().load_data_dir_archive(archive);
        #[cfg(feature = "extensions")]
        let builder = builder.extensions(extensions);
        builder.open()
    }
429
    /// Run the bundled WASIX `pg_dump` against this database and return SQL text.
    ///
    /// # Errors
    /// Fails if the instance is not ready, if `options` do not validate, if a direct
    /// transaction is active, or if the checkpoint or the dump itself fails.
    #[cfg(feature = "extensions")]
    pub fn dump_sql(&mut self, options: PgDumpOptions) -> Result<String> {
        self.check_ready()?;
        options.validate()?;
        // The checkpoint helper also rejects the call while a transaction is open.
        self.checkpoint_backend_for_physical_snapshot("direct pg_dump")?;
        self.dump_sql_via_direct_protocol(&options)
    }
438
439    /// Run the bundled WASIX `pg_dump` and return UTF-8 SQL bytes.
440    #[cfg(feature = "extensions")]
441    pub fn dump_bytes(&mut self, options: PgDumpOptions) -> Result<Vec<u8>> {
442        Ok(self.dump_sql(options)?.into_bytes())
443    }
444
445    fn checkpoint_backend_for_physical_snapshot(&mut self, operation: &'static str) -> Result<()> {
446        if self.in_transaction {
447            bail!("{operation} cannot run while a direct transaction is active");
448        }
449        self.exec("CHECKPOINT", None)
450            .with_context(|| format!("checkpoint before {operation}"))?;
451        Ok(())
452    }
453
    /// Build a PGDATA archive while the backend is shut down, then bring the backend back.
    ///
    /// Sequence: checkpoint, shut the backend down, build the archive, restart the
    /// backend, and restore session state. The restart is attempted even when building
    /// the archive failed.
    ///
    /// # Errors
    /// Returns the archive error, the restart error, or the archive error annotated
    /// with the restart error when both fail. Whenever the restart fails, the instance
    /// is marked not ready and closed.
    fn archive_quiesced_pgdata(
        &mut self,
        operation: &'static str,
        format: DataDirArchiveFormat,
    ) -> Result<Vec<u8>> {
        self.checkpoint_backend_for_physical_snapshot(operation)?;
        self.backend
            .shutdown()
            .with_context(|| format!("quiesce backend before {operation}"))?;

        // Deliberately not propagated with `?`: the backend must be restarted first.
        let archive = dump_pgdata_archive(
            &self.backend.paths().pgdata,
            self.backend.pgdata_template_root(),
            format,
        )
        .with_context(|| format!("materialize physical PGDATA archive for {operation}"));
        let restart = self
            .backend
            .restart()
            .and_then(|_| self.restore_session_state_after_backend_restart())
            .with_context(|| format!("restart backend after {operation}"));

        match (archive, restart) {
            (Ok(archive), Ok(())) => Ok(archive),
            (Err(err), Ok(())) => Err(err),
            (Ok(_), Err(err)) => {
                // Restart failed, so the instance is marked unusable.
                self.ready = false;
                self.closed = true;
                Err(err)
            }
            (Err(err), Err(restart_err)) => {
                self.ready = false;
                self.closed = true;
                // The archive error stays primary; the restart failure is attached as context.
                Err(err.context(format!(
                    "backend restart after failed {operation} also failed: {restart_err:#}"
                )))
            }
        }
    }
493
494    fn restore_session_state_after_backend_restart(&mut self) -> Result<()> {
495        let username = self.backend.startup_config().username.clone();
496        if username != "postgres" {
497            let sql = format!(
498                "SET ROLE {}",
499                crate::pglite::templating::quote_identifier(&username)
500            );
501            self.exec(&sql, None).with_context(|| {
502                format!("restore startup role {username} after backend restart")
503            })?;
504        }
505
506        let channels = self
507            .notify_listeners
508            .iter()
509            .filter(|(_, listeners)| !listeners.is_empty())
510            .map(|(channel, _)| channel.clone())
511            .collect::<Vec<_>>();
512        for channel in channels {
513            let quoted_channel = crate::pglite::templating::quote_identifier(&channel);
514            self.exec_internal(&format!("LISTEN {quoted_channel}"), None)
515                .with_context(|| format!("restore LISTEN {channel} after backend restart"))?;
516        }
517        Ok(())
518    }
519
520    #[cfg(feature = "extensions")]
521    fn dump_sql_via_direct_protocol(&mut self, options: &PgDumpOptions) -> Result<String> {
522        ensure_direct_pg_dump_options_match_session(self.backend.startup_config(), options)?;
523        let result = dump_direct_sql(options, |socket| self.serve_direct_pg_dump_protocol(socket));
524        let cleanup_result = self.cleanup_after_direct_pg_dump_session();
525
526        match (result, cleanup_result) {
527            (Ok(sql), Ok(())) => Ok(sql),
528            (Err(err), Ok(())) => Err(err),
529            (Ok(_), Err(err)) => Err(err),
530            (Err(err), Err(cleanup_err)) => Err(err.context(format!(
531                "direct pg_dump cleanup also failed: {cleanup_err:#}"
532            ))),
533        }
534    }
535
536    #[cfg(feature = "extensions")]
537    fn cleanup_after_direct_pg_dump_session(&mut self) -> Result<()> {
538        self.exec("DEALLOCATE ALL; SET search_path TO DEFAULT;", None)
539            .context("reset direct pg_dump session state")?;
540        Ok(())
541    }
542
    /// Serve the frontend side of a `pg_dump` connection from this instance's backend.
    ///
    /// Reads frontend frames from `socket` in a loop and answers each by kind: SSL/GSS
    /// requests are refused with `N`, startup packets get a startup response, and
    /// ordinary protocol messages are streamed through the backend with the output
    /// written back chunk by chunk. The loop ends on end of stream, a cancel request, a
    /// terminate message, or a rejected startup.
    ///
    /// # Errors
    /// Fails if the runtime cannot be built, on any socket read/write/flush failure, on
    /// malformed frames, or if the backend rejects a message.
    #[cfg(feature = "extensions")]
    fn serve_direct_pg_dump_protocol(&mut self, mut socket: PgDumpVirtualSocket) -> Result<()> {
        // Best effort: a failure to set nodelay is ignored.
        let _ = socket.set_nodelay(true);
        let (mut socket_tx, mut socket_rx) = socket.split();
        // Single-threaded runtime passed to the socket read/write/flush helpers.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .context("create direct pg_dump virtual socket runtime")?;
        let mut reader = FrontendFrameReader::default();
        let mut buffer = [0u8; 64 * 1024];
        loop {
            let read = read_direct_pg_dump_socket(&runtime, &mut socket_rx, &mut buffer)
                .context("read direct pg_dump protocol socket")?;
            // A zero-length read means the peer closed the stream.
            if read == 0 {
                return Ok(());
            }
            for message in reader.push(&buffer[..read])? {
                match classify_frontend_message(&message)? {
                    FrontendFrameKind::SslOrGssRequest => {
                        write_direct_pg_dump_socket(&runtime, &mut socket_tx, b"N")
                            .context("write direct pg_dump SSL refusal")?;
                    }
                    FrontendFrameKind::CancelRequest | FrontendFrameKind::Terminate => {
                        return Ok(());
                    }
                    FrontendFrameKind::Startup => {
                        // Reuse the existing session's startup response when there is one.
                        if let Some(response) = self.backend.existing_startup_response() {
                            write_direct_pg_dump_socket(&runtime, &mut socket_tx, &response)
                                .context("write direct pg_dump existing startup response")?;
                        } else {
                            let response = self.backend.startup_with_packet(&message)?;
                            write_direct_pg_dump_socket(&runtime, &mut socket_tx, &response.output)
                                .context("write direct pg_dump startup response")?;
                            // NOTE(review): this return skips the flush at the bottom of the
                            // loop — confirm the write helper does not buffer the rejection.
                            if !response.accepted {
                                return Ok(());
                            }
                        }
                    }
                    FrontendFrameKind::Protocol => {
                        self.exec_protocol_raw_stream(
                            &message,
                            ExecProtocolOptions::no_sync(),
                            |chunk| {
                                write_direct_pg_dump_socket(&runtime, &mut socket_tx, chunk)
                                    .context("write direct pg_dump backend protocol chunk")?;
                                Ok(())
                            },
                        )?;
                    }
                }
            }
            // Flush once per read, after every frame in the buffer has been handled.
            flush_direct_pg_dump_socket(&runtime, &mut socket_tx)
                .context("flush direct pg_dump socket")?;
        }
    }
598
599    #[cfg(feature = "extensions")]
600    fn bundled_extensions_in_database(&mut self) -> Result<Vec<Extension>> {
601        let results = self.query(
602            "SELECT extname FROM pg_catalog.pg_extension ORDER BY extname",
603            &[],
604            None,
605        )?;
606        let extensions = results
607            .rows
608            .iter()
609            .filter_map(|row| row.get("extname"))
610            .filter_map(|value| value.as_str())
611            .filter_map(by_sql_name)
612            .collect();
613        Ok(extensions)
614    }
615
    /// Store `temp_dir` on the instance, replacing any previously attached one.
    pub(crate) fn attach_temp_dir(&mut self, temp_dir: TempDir) {
        self._temp_dir = Some(temp_dir);
    }
619
    /// Store `root_lock` on the instance, replacing any previously attached one.
    ///
    /// The lock is cleared again after a successful close.
    pub(crate) fn attach_root_lock(&mut self, root_lock: RootLock) {
        self._root_lock = Some(root_lock);
    }
623
    /// Return `true` if the instance has already been closed.
    ///
    /// The flag is also set when a backend restart fails during a PGDATA dump.
    pub fn is_closed(&self) -> bool {
        self.closed
    }
628
    /// Shut down the embedded Postgres runtime.
    ///
    /// Calling this on an already closed instance returns `Ok(())` without doing anything.
    ///
    /// # Errors
    /// Fails if backend shutdown or the filesystem sync fails; the instance then stays open.
    pub fn close(&mut self) -> Result<()> {
        self.close_backend()
    }
633
634    fn close_backend(&mut self) -> Result<()> {
635        if self.closed {
636            return Ok(());
637        }
638        if self.closing {
639            bail!("Pglite is closing");
640        }
641
642        self.closing = true;
643        let result = (|| {
644            self.backend.shutdown()?;
645            self.sync_to_fs()
646        })();
647
648        self.closing = false;
649        if result.is_ok() {
650            self.closed = true;
651            self.ready = false;
652            self.notify_listeners.clear();
653            self.global_notify_listeners.clear();
654            self._root_lock = None;
655        }
656        result
657    }
658
    /// Crate-internal alias for closing the backend; behaves exactly like [`Pglite::close`].
    #[cfg(feature = "extensions")]
    pub(crate) fn close_for_template_cache(&mut self) -> Result<()> {
        self.close_backend()
    }
663
    /// Execute a simple SQL statement that may contain multiple commands.
    ///
    /// `options` falls back to `QueryOptions::default()` when `None`.
    ///
    /// # Errors
    /// Fails if the instance is not ready or if execution fails.
    pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
        // Readiness is checked here only; the internal variant skips the check.
        self.check_ready()?;

        self.exec_internal(sql, options)
    }
670
671    fn exec_internal(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
672        let options_snapshot = options.cloned();
673        let default_options = QueryOptions::default();
674        let exec_opts_ref = options.unwrap_or(&default_options);
675        let mut exec_opts = ExecProtocolOptions::no_sync();
676        exec_opts.on_notice = exec_opts_ref.on_notice.clone();
677        exec_opts.data_transfer_container = exec_opts_ref.data_transfer_container;
678
679        self.handle_blob_input(exec_opts_ref.blob.as_ref())?;
680
681        let mut collected_messages: Vec<BackendMessage> = Vec::new();
682
683        let message = Serialize::query(sql);
684        let ExecProtocolResult { messages, .. } = match self.exec_protocol(&message, exec_opts) {
685            Ok(result) => result,
686            Err(err) => match err.downcast::<DatabaseError>() {
687                Ok(db_err) => {
688                    let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
689                    return Err(enriched.into());
690                }
691                Err(err) => {
692                    return Err(err.context(format!("failed to execute simple query: {sql}")));
693                }
694            },
695        };
696        collected_messages.extend(messages);
697
698        self.finish_exec(collected_messages, options)
699    }
700
701    /// Register a listener for `LISTEN channel`. Returns a handle that can be used to unlisten.
702    pub fn listen<F>(&mut self, channel: &str, callback: F) -> Result<ListenerHandle>
703    where
704        F: Fn(&str) + Send + Sync + 'static,
705    {
706        self.check_ready()?;
707
708        let quoted_channel = crate::pglite::templating::quote_identifier(channel);
709        let normalized = channel.to_string();
710        let should_listen = match self.notify_listeners.get(&normalized) {
711            Some(existing) => existing.is_empty(),
712            None => true,
713        };
714
715        if should_listen {
716            self.exec_internal(&format!("LISTEN {quoted_channel}"), None)?;
717        }
718
719        let callback: ChannelCallback = Arc::new(callback);
720        let entry = self.notify_listeners.entry(normalized.clone()).or_default();
721        let id = self.next_listener_id;
722        self.next_listener_id = self.next_listener_id.wrapping_add(1);
723        entry.push(ChannelListener { id, callback });
724
725        Ok(ListenerHandle {
726            channel: channel.to_string(),
727            normalized_channel: normalized,
728            id,
729        })
730    }
731
732    /// Remove a listener corresponding to the provided handle.
733    pub fn unlisten(&mut self, handle: ListenerHandle) -> Result<()> {
734        if let Some(listeners) = self.notify_listeners.get_mut(&handle.normalized_channel) {
735            listeners.retain(|listener| listener.id != handle.id);
736            if listeners.is_empty() {
737                self.notify_listeners.remove(&handle.normalized_channel);
738                let quoted_channel = crate::pglite::templating::quote_identifier(&handle.channel);
739                self.exec_internal(&format!("UNLISTEN {quoted_channel}"), None)?;
740            }
741        }
742        Ok(())
743    }
744
745    /// Remove all listeners for the specified channel.
746    pub fn unlisten_channel(&mut self, channel: &str) -> Result<()> {
747        let quoted_channel = crate::pglite::templating::quote_identifier(channel);
748        let normalized = channel.to_string();
749        if self.notify_listeners.remove(&normalized).is_some() {
750            self.exec_internal(&format!("UNLISTEN {quoted_channel}"), None)?;
751        }
752        Ok(())
753    }
754
755    /// Register a global notification callback.
756    pub fn on_notification<F>(&mut self, callback: F) -> GlobalListenerHandle
757    where
758        F: Fn(&str, &str) + Send + Sync + 'static,
759    {
760        let id = self.next_global_listener_id;
761        self.next_global_listener_id = self.next_global_listener_id.wrapping_add(1);
762        let callback: GlobalCallback = Arc::new(callback);
763        self.global_notify_listeners
764            .push(GlobalListener { id, callback });
765        GlobalListenerHandle { id }
766    }
767
768    /// Deregister a previously registered global notification callback.
769    pub fn off_notification(&mut self, handle: GlobalListenerHandle) {
770        self.global_notify_listeners
771            .retain(|listener| listener.id != handle.id);
772    }
773
774    /// Describe the parameter and result metadata for a SQL query.
775    pub fn describe_query(
776        &mut self,
777        sql: &str,
778        options: Option<&QueryOptions>,
779    ) -> Result<DescribeQueryResult> {
780        self.check_ready()?;
781
782        let default_options = QueryOptions::default();
783        let query_opts = options.unwrap_or(&default_options);
784
785        let options_snapshot = options.cloned();
786        let mut exec_opts = ExecProtocolOptions::no_sync();
787        exec_opts.on_notice = query_opts.on_notice.clone();
788        exec_opts.data_transfer_container = query_opts.data_transfer_container;
789
790        let mut describe_messages: Vec<BackendMessage> = Vec::new();
791
792        let result: Result<()> = (|| {
793            let param_types = if query_opts.param_types.is_empty() {
794                &[] as &[i32]
795            } else {
796                &query_opts.param_types
797            };
798
799            let mut describe_batch = Vec::new();
800            describe_batch.extend(Serialize::parse(None, sql, param_types));
801            describe_batch.extend(Serialize::describe(&PortalTarget::new('S', None)));
802            describe_batch.extend(Serialize::sync());
803            let ExecProtocolResult { messages, .. } =
804                self.exec_protocol(&describe_batch, exec_opts.clone())?;
805            if !messages
806                .iter()
807                .any(|message| matches!(message, BackendMessage::ParseComplete { .. }))
808            {
809                bail!("extended query parse did not complete");
810            }
811            describe_messages.extend(messages);
812
813            Ok(())
814        })();
815
816        if let Err(err) = result {
817            match err.downcast::<DatabaseError>() {
818                Ok(db_err) => {
819                    let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
820                    return Err(enriched.into());
821                }
822                Err(err) => {
823                    return Err(err.context(format!("failed to describe query: {sql}")));
824                }
825            }
826        }
827
828        let param_type_ids = parse_describe_statement_results(&describe_messages);
829        self.ensure_array_types_for_oids(param_type_ids.iter().copied(), Some(query_opts))?;
830        let result_type_ids = describe_messages
831            .iter()
832            .filter_map(|msg| match msg {
833                BackendMessage::RowDescription(desc) => Some(desc),
834                _ => None,
835            })
836            .flat_map(|desc| desc.fields.iter().map(|field| field.data_type_id))
837            .collect::<Vec<_>>();
838        self.ensure_array_types_for_oids(result_type_ids.iter().copied(), Some(query_opts))?;
839
840        let query_params = param_type_ids
841            .into_iter()
842            .map(|oid| DescribeQueryParam {
843                data_type_id: oid,
844                serializer: self.serializers.get(&oid).cloned(),
845            })
846            .collect();
847
848        let result_fields = describe_messages
849            .iter()
850            .find_map(|msg| match msg {
851                BackendMessage::RowDescription(desc) => Some(
852                    desc.fields
853                        .iter()
854                        .map(|field| DescribeResultField {
855                            name: field.name.clone(),
856                            data_type_id: field.data_type_id,
857                            parser: self.parsers.get(&field.data_type_id).cloned(),
858                        })
859                        .collect::<Vec<_>>(),
860                ),
861                _ => None,
862            })
863            .unwrap_or_default();
864
865        Ok(DescribeQueryResult {
866            query_params,
867            result_fields,
868        })
869    }
870
871    /// Run a closure within an SQL transaction (`BEGIN .. COMMIT/ROLLBACK`).
872    pub fn transaction<F, T>(&mut self, mut callback: F) -> Result<T>
873    where
874        F: FnMut(&mut Transaction<'_>) -> Result<T>,
875    {
876        self.check_ready()?;
877
878        // Begin transaction
879        self.run_exec_command("BEGIN")?;
880        self.in_transaction = true;
881
882        let mut tx = Transaction::new(self);
883        let callback_result = callback(&mut tx);
884
885        let txn_result = match callback_result {
886            Ok(value) => {
887                if !tx.closed {
888                    tx.commit_internal()?;
889                }
890                Ok(value)
891            }
892            Err(err) => {
893                if !tx.closed {
894                    tx.rollback_internal()?;
895                }
896                Err(err)
897            }
898        };
899
900        self.in_transaction = false;
901        txn_result
902    }
903
904    /// Flush runtime writes to the underlying filesystem.
905    ///
906    /// The WASIX backend uses host-mounted files and PostgreSQL's own fsync/WAL
907    /// behavior for durability. Adding an unconditional host directory
908    /// `sync_all` after every direct query is both expensive and weaker than the
909    /// database's file-level fsyncs, so the Rust-level hook remains a no-op.
910    pub fn sync_to_fs(&mut self) -> Result<()> {
911        Ok(())
912    }
913
914    fn prepare_bind_values(
915        &self,
916        params: &[Value],
917        data_type_ids: &[i32],
918        options: &QueryOptions,
919    ) -> Result<Vec<BindValue>> {
920        if params.is_empty() {
921            return Ok(Vec::new());
922        }
923
924        let mut values = Vec::with_capacity(params.len());
925        let overrides = if options.serializers.is_empty() {
926            None
927        } else {
928            Some(&options.serializers)
929        };
930
931        for (idx, value) in params.iter().enumerate() {
932            if value.is_null() {
933                values.push(BindValue::Null);
934                continue;
935            }
936
937            let oid = data_type_ids.get(idx).copied().unwrap_or(TEXT);
938            let serializer = overrides
939                .and_then(|map| map.get(&oid))
940                .or_else(|| self.serializers.get(&oid));
941
942            let serialized = match serializer {
943                Some(func) => func(value).with_context(|| {
944                    format!("failed to serialize parameter {idx} using OID {oid}")
945                })?,
946                None => self.default_serialize_value(value),
947            };
948
949            values.push(BindValue::Text(serialized));
950        }
951
952        Ok(values)
953    }
954
955    fn parse_and_describe(
956        &mut self,
957        sql: &str,
958        param_types: &[i32],
959        exec_opts: ExecProtocolOptions,
960    ) -> Result<Vec<BackendMessage>> {
961        let mut prepare_batch = Vec::new();
962        prepare_batch.extend(Serialize::parse(None, sql, param_types));
963        prepare_batch.extend(Serialize::describe(&PortalTarget::new('S', None)));
964        prepare_batch.extend(Serialize::sync());
965        let ExecProtocolResult { messages, .. } = self.exec_protocol(&prepare_batch, exec_opts)?;
966        if !messages
967            .iter()
968            .any(|message| matches!(message, BackendMessage::ParseComplete { .. }))
969        {
970            bail!("extended query parse did not complete");
971        }
972        Ok(messages)
973    }
974
975    fn default_serialize_value(&self, value: &Value) -> String {
976        Self::default_serialize_value_static(value)
977    }
978
979    pub(crate) fn default_serialize_value_static(value: &Value) -> String {
980        match value {
981            Value::String(s) => s.clone(),
982            Value::Number(num) => num.to_string(),
983            Value::Bool(flag) => {
984                if *flag {
985                    "t".to_string()
986                } else {
987                    "f".to_string()
988                }
989            }
990            _ => value.to_string(),
991        }
992    }
993
994    fn finish_query(
995        &mut self,
996        messages: Vec<BackendMessage>,
997        options: Option<&QueryOptions>,
998    ) -> Result<Results> {
999        let blob = {
1000            let _phase = timing::phase("client.finish.blob_read");
1001            self.get_written_blob()?
1002        };
1003        {
1004            let _phase = timing::phase("client.finish.blob_cleanup");
1005            self.cleanup_blob()?;
1006        }
1007        if !self.in_transaction {
1008            let _phase = timing::phase("client.finish.sync_to_fs");
1009            self.sync_to_fs()?;
1010        }
1011        {
1012            let _phase = timing::phase("client.finish.ensure_array_types");
1013            self.ensure_array_types_for_result_messages(&messages, options)?;
1014        }
1015        let parsed = {
1016            let _phase = timing::phase("client.finish.parse_results");
1017            parse_results(&messages, &self.parsers, options, blob)
1018        };
1019        parsed
1020            .into_iter()
1021            .next()
1022            .ok_or_else(|| anyhow!("query returned no result sets"))
1023    }
1024
1025    fn finish_exec(
1026        &mut self,
1027        messages: Vec<BackendMessage>,
1028        options: Option<&QueryOptions>,
1029    ) -> Result<Vec<Results>> {
1030        let blob = {
1031            let _phase = timing::phase("client.finish.blob_read");
1032            self.get_written_blob()?
1033        };
1034        {
1035            let _phase = timing::phase("client.finish.blob_cleanup");
1036            self.cleanup_blob()?;
1037        }
1038        if !self.in_transaction {
1039            let _phase = timing::phase("client.finish.sync_to_fs");
1040            self.sync_to_fs()?;
1041        }
1042        {
1043            let _phase = timing::phase("client.finish.ensure_array_types");
1044            self.ensure_array_types_for_result_messages(&messages, options)?;
1045        }
1046        let parsed = {
1047            let _phase = timing::phase("client.finish.parse_results");
1048            parse_results(&messages, &self.parsers, options, blob)
1049        };
1050        Ok(parsed)
1051    }
1052
1053    /// Execute raw PostgreSQL frontend protocol bytes and parse backend
1054    /// protocol messages.
1055    pub fn exec_protocol(
1056        &mut self,
1057        message: &[u8],
1058        options: ExecProtocolOptions,
1059    ) -> Result<ExecProtocolResult> {
1060        let ExecProtocolOptions {
1061            sync_to_fs,
1062            throw_on_error,
1063            on_notice,
1064            data_transfer_container,
1065        } = options;
1066
1067        let data = {
1068            let _phase = timing::phase("client.protocol_roundtrip");
1069            self.exec_protocol_raw_inner(message, sync_to_fs, data_transfer_container)?
1070        };
1071
1072        let mut messages = Vec::new();
1073        let on_notice_cb = on_notice.clone();
1074        let parse_result = {
1075            let _phase = timing::phase("client.protocol_parse");
1076            self.parser.parse(&data, |msg| {
1077                if let BackendMessage::Error(db_err) = &msg
1078                    && throw_on_error
1079                {
1080                    return Err(anyhow!(db_err.clone()));
1081                }
1082                if let Some(callback) = on_notice_cb.as_ref()
1083                    && let BackendMessage::Notice(notice) = &msg
1084                {
1085                    callback(notice);
1086                }
1087                messages.push(msg);
1088                Ok(())
1089            })
1090        };
1091        if let Err(err) = parse_result {
1092            match err.downcast::<DatabaseError>() {
1093                Ok(db_err) => {
1094                    self.parser = ProtocolParser::new();
1095                    return Err(anyhow!(db_err));
1096                }
1097                Err(err) => return Err(err),
1098            }
1099        }
1100
1101        for message in &messages {
1102            if let BackendMessage::Notification(note) = message {
1103                if let Some(listeners) = self.notify_listeners.get(&note.channel) {
1104                    for listener in listeners {
1105                        (listener.callback)(&note.payload);
1106                    }
1107                }
1108                for listener in &self.global_notify_listeners {
1109                    (listener.callback)(&note.channel, &note.payload);
1110                }
1111            }
1112        }
1113
1114        Ok(ExecProtocolResult { data, messages })
1115    }
1116
1117    /// Execute raw PostgreSQL frontend protocol bytes and return raw backend
1118    /// protocol bytes.
1119    pub fn exec_protocol_raw(
1120        &mut self,
1121        message: &[u8],
1122        options: ExecProtocolOptions,
1123    ) -> Result<Vec<u8>> {
1124        self.exec_protocol_raw_inner(message, options.sync_to_fs, options.data_transfer_container)
1125    }
1126
1127    /// Execute raw protocol bytes and pass the returned backend bytes to
1128    /// `on_data`.
1129    pub fn exec_protocol_raw_stream<F>(
1130        &mut self,
1131        message: &[u8],
1132        options: ExecProtocolOptions,
1133        mut on_data: F,
1134    ) -> Result<()>
1135    where
1136        F: FnMut(&[u8]) -> Result<()>,
1137    {
1138        self.backend.send_framed_raw_stream(
1139            message,
1140            options.data_transfer_container,
1141            &mut on_data,
1142        )?;
1143        if options.sync_to_fs {
1144            let _phase = timing::phase("client.protocol_stream_sync_to_fs");
1145            self.sync_to_fs()?;
1146        }
1147        Ok(())
1148    }
1149
1150    fn exec_protocol_raw_inner(
1151        &mut self,
1152        message: &[u8],
1153        sync_to_fs: bool,
1154        data_transfer_container: Option<DataTransferContainer>,
1155    ) -> Result<Vec<u8>> {
1156        let data = {
1157            let _phase = timing::phase("client.protocol_transport_send");
1158            self.backend
1159                .send_buffered(message, data_transfer_container)?
1160        };
1161        if sync_to_fs {
1162            let _phase = timing::phase("client.protocol_sync_to_fs");
1163            self.sync_to_fs()?;
1164        }
1165        Ok(data)
1166    }
1167
1168    fn ensure_array_types_for_bind_values(
1169        &mut self,
1170        params: &[Value],
1171        data_type_ids: &[i32],
1172        options: &QueryOptions,
1173    ) -> Result<bool> {
1174        let mut registered = false;
1175        for (idx, value) in params.iter().enumerate() {
1176            if !value.is_array() {
1177                continue;
1178            }
1179            let oid = data_type_ids.get(idx).copied().unwrap_or(TEXT);
1180            if options.serializers.contains_key(&oid) || self.serializers.contains_key(&oid) {
1181                continue;
1182            }
1183            registered |= self.try_register_array_type_by_array_oid(oid)?;
1184        }
1185        Ok(registered)
1186    }
1187
1188    fn ensure_array_types_for_result_messages(
1189        &mut self,
1190        messages: &[BackendMessage],
1191        options: Option<&QueryOptions>,
1192    ) -> Result<()> {
1193        let oids = messages
1194            .iter()
1195            .filter_map(|msg| match msg {
1196                BackendMessage::RowDescription(desc) => Some(desc),
1197                _ => None,
1198            })
1199            .flat_map(|desc| desc.fields.iter().map(|field| field.data_type_id))
1200            .collect::<Vec<_>>();
1201        self.ensure_array_types_for_oids(oids, options)
1202    }
1203
1204    fn ensure_array_types_for_oids(
1205        &mut self,
1206        oids: impl IntoIterator<Item = i32>,
1207        options: Option<&QueryOptions>,
1208    ) -> Result<()> {
1209        for oid in oids {
1210            if oid <= 0 || self.parsers.contains_key(&oid) {
1211                continue;
1212            }
1213            if options.is_some_and(|options| options.parsers.contains_key(&oid)) {
1214                continue;
1215            }
1216            self.try_register_array_type_by_array_oid(oid)?;
1217        }
1218        Ok(())
1219    }
1220
1221    fn refresh_array_types_internal(&mut self) -> Result<()> {
1222        let sql = "
1223            SELECT e.oid, a.oid AS typarray, e.typdelim::text AS typdelim
1224            FROM pg_catalog.pg_type a
1225            JOIN pg_catalog.pg_type e ON e.oid = a.typelem
1226            WHERE a.typcategory = 'A'
1227              AND a.typelem <> 0
1228            ORDER BY e.oid
1229        ";
1230        let results = {
1231            let _phase = timing::phase("pglite.array_type_catalog_query");
1232            self.exec_internal(sql, None)?
1233        };
1234        let result_set = results
1235            .into_iter()
1236            .next()
1237            .ok_or_else(|| anyhow!("array type discovery returned no results"))?;
1238
1239        {
1240            let _phase = timing::phase("pglite.array_type_register");
1241            for row in result_set.rows {
1242                if let Some(info) = array_type_info_from_row(&row) {
1243                    self.register_array_type(info);
1244                }
1245            }
1246        }
1247        Ok(())
1248    }
1249
1250    fn try_register_array_type_by_array_oid(&mut self, array_oid: i32) -> Result<bool> {
1251        if array_oid <= 0
1252            || self.parsers.contains_key(&array_oid)
1253            || self.array_type_lookup_misses.contains(&array_oid)
1254        {
1255            return Ok(false);
1256        }
1257
1258        let sql = format!(
1259            "SELECT e.oid, a.oid AS typarray, e.typdelim::text AS typdelim \
1260             FROM pg_catalog.pg_type a \
1261             JOIN pg_catalog.pg_type e ON e.oid = a.typelem \
1262             WHERE a.oid = {array_oid}::oid \
1263               AND a.typcategory = 'A' \
1264               AND a.typelem <> 0"
1265        );
1266        let results = {
1267            let _phase = timing::phase("pglite.array_type_targeted_lookup");
1268            self.exec_internal(&sql, None)?
1269        };
1270        let Some(result_set) = results.into_iter().next() else {
1271            self.array_type_lookup_misses.insert(array_oid);
1272            return Ok(false);
1273        };
1274        let Some(row) = result_set.rows.into_iter().next() else {
1275            self.array_type_lookup_misses.insert(array_oid);
1276            return Ok(false);
1277        };
1278        let Some(info) = array_type_info_from_row(&row) else {
1279            self.array_type_lookup_misses.insert(array_oid);
1280            return Ok(false);
1281        };
1282
1283        self.register_array_type(info);
1284        Ok(true)
1285    }
1286
1287    fn register_array_type(&mut self, info: ArrayTypeInfo) {
1288        register_array_type(&mut self.parsers, &mut self.serializers, info);
1289        self.array_type_lookup_misses.remove(&info.array_oid);
1290    }
1291
1292    fn run_exec_command(&mut self, sql: &str) -> Result<()> {
1293        self.exec_internal(sql, None).map(|_| ())
1294    }
1295
1296    fn handle_blob_input(&mut self, blob: Option<&Vec<u8>>) -> Result<()> {
1297        let path = self.dev_blob_path();
1298        if let Some(bytes) = blob {
1299            if let Some(parent) = path.parent() {
1300                fs::create_dir_all(parent).with_context(|| {
1301                    format!("failed to create blob directory {}", parent.display())
1302                })?;
1303            }
1304            fs::write(&path, bytes)
1305                .with_context(|| format!("write blob input to {}", path.display()))?;
1306            self.blob_input_provided = true;
1307        } else {
1308            self.blob_input_provided = false;
1309            let _ = fs::remove_file(&path);
1310        }
1311        Ok(())
1312    }
1313
1314    fn dev_blob_path(&self) -> PathBuf {
1315        self.backend.paths().runtime_root().join("dev/blob")
1316    }
1317
1318    fn cleanup_blob(&mut self) -> Result<()> {
1319        Ok(())
1320    }
1321
1322    fn get_written_blob(&mut self) -> Result<Option<Vec<u8>>> {
1323        let path = self.dev_blob_path();
1324
1325        if self.blob_input_provided {
1326            self.blob_input_provided = false;
1327            let _ = fs::remove_file(&path);
1328            return Ok(None);
1329        }
1330
1331        match fs::read(&path) {
1332            Ok(data) => {
1333                self.blob_input_provided = false;
1334                let _ = fs::remove_file(&path);
1335                if data.is_empty() {
1336                    Ok(None)
1337                } else {
1338                    Ok(Some(data))
1339                }
1340            }
1341            Err(err) => {
1342                if err.kind() == io::ErrorKind::NotFound {
1343                    self.blob_input_provided = false;
1344                    Ok(None)
1345                } else {
1346                    Err(err).with_context(|| format!("read blob output from {}", path.display()))
1347                }
1348            }
1349        }
1350    }
1351
1352    fn check_ready(&self) -> Result<()> {
1353        if self.closing {
1354            bail!("Pglite instance is closing");
1355        }
1356        if self.closed {
1357            bail!("Pglite instance is closed");
1358        }
1359        if !self.ready {
1360            bail!("Pglite instance is not ready");
1361        }
1362        Ok(())
1363    }
1364}
1365
1366impl Drop for Pglite {
1367    fn drop(&mut self) {
1368        if !self.closed {
1369            let _ = self.close();
1370        }
1371    }
1372}
1373
#[cfg(feature = "extensions")]
/// Check that direct `pg_dump` options target the already-open session.
///
/// # Errors
///
/// Fails when the requested database or the requested user differs from the
/// one in `startup_config`; the database is checked first.
fn ensure_direct_pg_dump_options_match_session(
    startup_config: &StartupConfig,
    options: &PgDumpOptions,
) -> Result<()> {
    let requested_database = options.database_ref();
    if requested_database != startup_config.database {
        bail!(
            "direct pg_dump runs against the already-open embedded backend database '{}'; requested database '{}' would require a separate server connection",
            startup_config.database,
            requested_database
        );
    }

    let requested_user = options.username_ref();
    if requested_user != startup_config.username {
        bail!(
            "direct pg_dump runs through the already-open embedded backend user '{}'; requested user '{}' would require a separate server connection",
            startup_config.username,
            requested_user
        );
    }

    Ok(())
}
1395
#[cfg(feature = "extensions")]
/// Blocking read from the direct `pg_dump` virtual socket.
///
/// Waits on `runtime` until the reader has buffered data (or reports an
/// error), copies as much as fits into `buffer`, marks those bytes as
/// consumed and returns how many were copied.
///
/// # Errors
///
/// Returns the reader's error with a
/// `read direct pg_dump virtual socket` context.
fn read_direct_pg_dump_socket(
    runtime: &Runtime,
    reader: &mut TcpSocketHalfRx,
    buffer: &mut [u8],
) -> Result<usize> {
    use std::task::Poll;

    runtime
        .block_on(async {
            std::future::poll_fn(|cx| {
                let copied = match reader.poll_fill_buf(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                    Poll::Ready(Ok(available)) => {
                        // Never copy more than the caller's buffer can hold.
                        let count = available.len().min(buffer.len());
                        buffer[..count].copy_from_slice(&available[..count]);
                        count
                    }
                };
                // Consume only after the borrow of the reader's buffer has ended.
                reader.consume(copied);
                Poll::Ready(Ok(copied))
            })
            .await
        })
        .context("read direct pg_dump virtual socket")
}
1421
#[cfg(feature = "extensions")]
/// Blocking write of all `bytes` to the direct `pg_dump` virtual socket.
///
/// # Errors
///
/// Returns the writer's error with a
/// `write direct pg_dump virtual socket` context.
fn write_direct_pg_dump_socket(
    runtime: &Runtime,
    writer: &mut (impl AsyncWrite + Unpin),
    bytes: &[u8],
) -> Result<()> {
    let write_all = writer.write_all(bytes);
    runtime
        .block_on(write_all)
        .context("write direct pg_dump virtual socket")
}
1432
#[cfg(feature = "extensions")]
/// Blocking flush of the direct `pg_dump` virtual socket writer.
///
/// # Errors
///
/// Returns the writer's error with a
/// `flush direct pg_dump virtual socket` context.
fn flush_direct_pg_dump_socket(
    runtime: &Runtime,
    writer: &mut (impl AsyncWrite + Unpin),
) -> Result<()> {
    let flush = writer.flush();
    runtime
        .block_on(flush)
        .context("flush direct pg_dump virtual socket")
}
1442
1443fn value_to_i32(value: Option<&Value>) -> Option<i32> {
1444    match value? {
1445        Value::Number(number) => number.as_i64().map(|value| value as i32),
1446        Value::String(string) => string.parse::<i32>().ok(),
1447        _ => None,
1448    }
1449}
1450
1451fn value_to_char(value: Option<&Value>) -> Option<char> {
1452    match value? {
1453        Value::String(string) => string.chars().next(),
1454        _ => None,
1455    }
1456}
1457
1458fn array_type_info_from_row(row: &Value) -> Option<ArrayTypeInfo> {
1459    let Value::Object(map) = row else {
1460        return None;
1461    };
1462    let element_oid = value_to_i32(map.get("oid"))?;
1463    let array_oid = value_to_i32(map.get("typarray"))?;
1464    if element_oid == 0 || array_oid == 0 {
1465        return None;
1466    }
1467    let delimiter = value_to_char(map.get("typdelim")).unwrap_or(',');
1468    Some(ArrayTypeInfo::new(element_oid, array_oid, delimiter))
1469}
1470
1471/// Transaction handle used within [`Pglite::transaction`].
1472pub struct Transaction<'a> {
1473    client: &'a mut Pglite,
1474    closed: bool,
1475}
1476
1477impl<'a> Transaction<'a> {
1478    fn new(client: &'a mut Pglite) -> Self {
1479        Self {
1480            client,
1481            closed: false,
1482        }
1483    }
1484
1485    fn commit_internal(&mut self) -> Result<()> {
1486        self.ensure_open()?;
1487        self.client.exec_internal("COMMIT", None)?;
1488        self.closed = true;
1489        Ok(())
1490    }
1491
1492    fn rollback_internal(&mut self) -> Result<()> {
1493        self.ensure_open()?;
1494        self.client.exec_internal("ROLLBACK", None)?;
1495        self.closed = true;
1496        Ok(())
1497    }
1498
1499    fn ensure_open(&self) -> Result<()> {
1500        if self.closed {
1501            bail!("transaction is already closed");
1502        }
1503        Ok(())
1504    }
1505
1506    pub fn query(
1507        &mut self,
1508        sql: &str,
1509        params: &[Value],
1510        options: Option<&QueryOptions>,
1511    ) -> Result<Results> {
1512        self.ensure_open()?;
1513        self.client.query_internal(sql, params, options)
1514    }
1515
1516    pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
1517        self.ensure_open()?;
1518        self.client.exec_internal(sql, options)
1519    }
1520
1521    pub fn refresh_array_types(&mut self) -> Result<()> {
1522        self.ensure_open()?;
1523        self.client.refresh_array_types_internal()
1524    }
1525
1526    pub fn commit(&mut self) -> Result<()> {
1527        self.commit_internal()
1528    }
1529
1530    pub fn rollback(&mut self) -> Result<()> {
1531        self.rollback_internal()
1532    }
1533
1534    pub fn is_closed(&self) -> bool {
1535        self.closed
1536    }
1537
1538    pub fn closed(&self) -> bool {
1539        self.closed
1540    }
1541}