1use anyhow::{Context, Result, anyhow, bail};
2use serde_json::Value;
3use std::collections::{HashMap, HashSet};
4use std::fs;
5use std::io;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tempfile::TempDir;
10#[cfg(feature = "extensions")]
11use tokio::io::{AsyncWrite, AsyncWriteExt};
12#[cfg(feature = "extensions")]
13use tokio::runtime::Runtime;
14#[cfg(feature = "extensions")]
15use wasmer_wasix::virtual_net::VirtualTcpSocket;
16#[cfg(feature = "extensions")]
17use wasmer_wasix::virtual_net::tcp_pair::TcpSocketHalfRx;
18
19use crate::pglite::aot;
20#[cfg(feature = "extensions")]
21use crate::pglite::assets;
22use crate::pglite::backend::{BackendOpenKind, BackendSession};
23#[cfg(feature = "extensions")]
24use crate::pglite::base::install_bundled_extension_bytes;
25use crate::pglite::base::{InstallOutcome, PglitePaths, RootLock};
26use crate::pglite::builder::PgliteBuilder;
27use crate::pglite::config::{PostgresConfig, StartupConfig};
28use crate::pglite::data_dir::{DataDirArchiveFormat, dump_pgdata_archive};
29use crate::pglite::errors::PgliteError;
30#[cfg(feature = "extensions")]
31use crate::pglite::extensions::{
32 Extension, by_sql_name, extension_session_setup_sql, extension_setup_sql, resolve_extension_set,
33};
34use crate::pglite::interface::{
35 DataTransferContainer, DescribeQueryParam, DescribeQueryResult, DescribeResultField,
36 ExecProtocolOptions, ExecProtocolResult, ParserMap, QueryOptions, Results, SerializerMap,
37};
38use crate::pglite::parse::{parse_describe_statement_results, parse_results};
39#[cfg(feature = "extensions")]
40use crate::pglite::pg_dump::{PgDumpOptions, PgDumpVirtualSocket, dump_direct_sql};
41#[cfg(feature = "extensions")]
42use crate::pglite::postgres_mod::PostgresMod;
43use crate::pglite::timing;
44use crate::pglite::types::{
45 ArrayTypeInfo, DEFAULT_PARSERS, DEFAULT_SERIALIZERS, TEXT, register_array_type,
46};
47#[cfg(feature = "extensions")]
48use crate::pglite::wire::{FrontendFrameKind, FrontendFrameReader, classify_frontend_message};
49use crate::protocol::messages::{BackendMessage, DatabaseError};
50use crate::protocol::parser::Parser as ProtocolParser;
51use crate::protocol::serializer::{BindConfig, BindValue, PortalTarget, Serialize};
52
/// Per-channel notification callback; invoked with the notification payload.
type ChannelCallback = Arc<dyn Fn(&str) + Send + Sync + 'static>;
/// Catch-all notification callback; invoked with `(channel, payload)`.
type GlobalCallback = Arc<dyn Fn(&str, &str) + Send + Sync + 'static>;
55
/// Token returned by `Pglite::listen`; pass it to `Pglite::unlisten` to remove
/// exactly that callback.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ListenerHandle {
    // Channel name as supplied by the caller; used to build `UNLISTEN`.
    channel: String,
    // Key into the client's per-channel listener map.
    normalized_channel: String,
    // Identifier of the callback within that channel's listener list.
    id: u64,
}

/// Token returned by `Pglite::on_notification`; pass it to
/// `Pglite::off_notification` to remove that callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GlobalListenerHandle {
    // Identifier of the callback within the global listener list.
    id: u64,
}

impl ListenerHandle {
    /// Channel name this listener was registered for, as given by the caller.
    pub fn channel(&self) -> &str {
        self.channel.as_str()
    }

    /// Numeric identifier assigned to this listener at registration.
    pub fn id(&self) -> u64 {
        let listener_id = self.id;
        listener_id
    }
}

impl GlobalListenerHandle {
    /// Numeric identifier assigned to this global listener at registration.
    pub fn id(&self) -> u64 {
        let listener_id = self.id;
        listener_id
    }
}
83
/// A callback registered for one channel, tagged with the id that was handed
/// out in its `ListenerHandle` (used by `unlisten` to find it again).
struct ChannelListener {
    id: u64,
    callback: ChannelCallback,
}

/// A callback registered for notifications on every channel, tagged with the
/// id handed out in its `GlobalListenerHandle`.
struct GlobalListener {
    id: u64,
    callback: GlobalCallback,
}
93
/// Embedded PGlite client: owns a backend session plus client-side protocol
/// state (message parser, per-OID serializers/parsers) and LISTEN/NOTIFY
/// callbacks.
pub struct Pglite {
    // Backend session every protocol message is sent to.
    backend: BackendSession,
    // Held only so a temporary root outlives the client (see `attach_temp_dir`).
    _temp_dir: Option<TempDir>,
    // Lock on the database root; dropped in `close_backend` on success.
    _root_lock: Option<RootLock>,
    // Decoder for backend messages; replaced with a fresh one after a
    // `DatabaseError` in `exec_protocol`.
    parser: ProtocolParser,
    // OID -> serializer used when binding query parameters.
    serializers: SerializerMap,
    // OID -> parser used when decoding result values.
    parsers: ParserMap,
    // NOTE(review): presumably caches array OIDs whose type lookup found nothing; confirm.
    array_type_lookup_misses: HashSet<i32>,
    // Set while `transaction` has an open BEGIN; skips per-statement
    // `sync_to_fs` and makes physical snapshots refuse to run.
    in_transaction: bool,
    // Lifecycle flags; `is_ready` requires `ready && !closing && !closed`.
    ready: bool,
    closing: bool,
    closed: bool,
    // NOTE(review): presumably records whether blob input was supplied for the
    // current statement; confirm against `handle_blob_input`.
    blob_input_provided: bool,
    // Per-channel callbacks keyed by channel name.
    notify_listeners: HashMap<String, Vec<ChannelListener>>,
    // Callbacks that receive notifications from every channel.
    global_notify_listeners: Vec<GlobalListener>,
    // Next id handed out by `listen` and `on_notification` respectively.
    next_listener_id: u64,
    next_global_listener_id: u64,
}
113
114impl Pglite {
115 pub fn builder() -> PgliteBuilder {
117 PgliteBuilder::new()
118 }
119
120 pub fn open(root: impl AsRef<Path>) -> Result<Self> {
122 Self::builder().path(root.as_ref().to_path_buf()).open()
123 }
124
125 pub fn open_app(app_id: (&str, &str, &str)) -> Result<Self> {
127 Self::builder().app_id(app_id).open()
128 }
129
130 pub fn temporary() -> Result<Self> {
132 Self::builder().temporary().open()
133 }
134
135 pub fn preload() -> Result<()> {
137 let (temp_dir, paths) = {
138 let _phase = timing::phase("preload.tempdir");
139 PglitePaths::with_temp_dir()?
140 };
141 {
142 let _phase = timing::phase("preload.runtime_module");
143 crate::pglite::base::preload_runtime_module(&paths)?;
144 }
145 {
146 let _phase = timing::phase("preload.aot_runtime");
147 aot::preload_runtime_artifact()?;
148 }
149 drop(temp_dir);
150 Ok(())
151 }
152
153 #[cfg(feature = "extensions")]
155 pub fn preload_extensions(extensions: impl IntoIterator<Item = Extension>) -> Result<()> {
156 Self::preload()?;
157 let extensions = extensions.into_iter().collect::<Vec<_>>();
158 for extension in resolve_extension_set(&extensions)? {
159 let bytes = assets::extension_archive(extension.sql_name()).ok_or_else(|| {
160 anyhow!(
161 "extension asset '{}' is not bundled in this pglite-oxide build",
162 extension.sql_name()
163 )
164 })?;
165 let (temp_dir, paths) = {
166 let _phase = timing::phase("preload.extension_tempdir");
167 PglitePaths::with_temp_dir()?
168 };
169 {
170 let _phase = timing::phase("preload.extension_runtime_module");
171 crate::pglite::base::preload_runtime_module(&paths)?;
172 }
173 {
174 let _phase = timing::phase("preload.extension_archive_install");
175 install_bundled_extension_bytes(&paths, extension.sql_name(), bytes)?;
176 }
177 {
178 let _phase = timing::phase("preload.extension_side_module");
179 PostgresMod::preload_extension_module_from_paths(&paths, extension)?;
180 }
181 {
182 let _phase = timing::phase("preload.extension_aot");
183 aot::preload_extension_artifact(extension)?;
184 }
185 drop(temp_dir);
186 }
187 Ok(())
188 }
189
190 #[doc(hidden)]
192 pub fn new(paths: PglitePaths) -> Result<Self> {
193 let outcome = crate::pglite::base::prepare_database_root(
194 paths,
195 crate::pglite::base::RootPrepareOptions::template(),
196 )?;
197 Self::new_prepared(outcome)
198 }
199
200 pub(crate) fn new_prepared(outcome: InstallOutcome) -> Result<Self> {
201 Self::new_prepared_with_config(outcome, PostgresConfig::default(), StartupConfig::default())
202 }
203
204 pub(crate) fn new_prepared_with_config(
205 outcome: InstallOutcome,
206 postgres_config: PostgresConfig,
207 startup_config: StartupConfig,
208 ) -> Result<Self> {
209 let _phase = timing::phase("pglite.open");
210 let session_startup_config = startup_config.clone();
211 let backend = BackendSession::open(
212 outcome,
213 postgres_config,
214 startup_config,
215 BackendOpenKind::Direct,
216 )?;
217
218 let mut instance = {
219 let _phase = timing::phase("pglite.client_struct_init");
220 Self {
221 backend,
222 _temp_dir: None,
223 _root_lock: None,
224 parser: ProtocolParser::new(),
225 serializers: DEFAULT_SERIALIZERS.clone(),
226 parsers: DEFAULT_PARSERS.clone(),
227 array_type_lookup_misses: HashSet::new(),
228 in_transaction: false,
229 ready: true,
230 closing: false,
231 closed: false,
232 blob_input_provided: false,
233 notify_listeners: HashMap::new(),
234 global_notify_listeners: Vec::new(),
235 next_listener_id: 1,
236 next_global_listener_id: 1,
237 }
238 };
239
240 if session_startup_config.username != "postgres" {
241 let sql = format!(
242 "SET ROLE {}",
243 crate::pglite::templating::quote_identifier(&session_startup_config.username)
244 );
245 instance
246 .exec(&sql, None)
247 .with_context(|| format!("set startup role {}", session_startup_config.username))?;
248 }
249
250 Ok(instance)
251 }
252
253 #[cfg(feature = "extensions")]
255 pub fn enable_extension(&mut self, extension: Extension) -> Result<()> {
256 let _phase = timing::phase("extension.enable");
257 let bytes = assets::extension_archive(extension.sql_name()).ok_or_else(|| {
258 anyhow!(
259 "extension asset '{}' is not bundled in this pglite-oxide build",
260 extension.sql_name()
261 )
262 })?;
263 install_bundled_extension_bytes(self.paths(), extension.sql_name(), bytes)?;
264 self.backend.preload_extension_module(extension)?;
265 for sql in extension_setup_sql(extension) {
266 self.exec(&sql, None)?;
267 }
268 Ok(())
269 }
270
271 #[cfg(feature = "extensions")]
272 pub(crate) fn enable_preinstalled_extension(&mut self, extension: Extension) -> Result<()> {
273 let _phase = timing::phase("extension.enable_preinstalled");
274 self.backend.preload_installed_extension(extension)?;
275 for sql in extension_session_setup_sql(extension) {
276 self.exec(&sql, None)?;
277 }
278 Ok(())
279 }
280
281 pub fn refresh_array_types(&mut self) -> Result<()> {
287 self.check_ready()?;
288 self.refresh_array_types_internal()
289 }
290
291 pub fn query(
293 &mut self,
294 sql: &str,
295 params: &[Value],
296 options: Option<&QueryOptions>,
297 ) -> Result<Results> {
298 self.check_ready()?;
299
300 self.query_internal(sql, params, options)
301 }
302
303 fn query_internal(
304 &mut self,
305 sql: &str,
306 params: &[Value],
307 options: Option<&QueryOptions>,
308 ) -> Result<Results> {
309 let default_options = QueryOptions::default();
310 let query_opts = options.unwrap_or(&default_options);
311
312 self.handle_blob_input(query_opts.blob.as_ref())?;
313
314 let params_snapshot: Vec<Value> = params.to_vec();
315 let options_snapshot = options.cloned();
316 let mut collected_messages: Vec<BackendMessage> = Vec::new();
317
318 let mut exec_opts = ExecProtocolOptions::no_sync();
319 exec_opts.on_notice = query_opts.on_notice.clone();
320 exec_opts.data_transfer_container = query_opts.data_transfer_container;
321
322 let result: Result<()> = (|| {
323 let param_types = if query_opts.param_types.is_empty() {
324 &[] as &[i32]
325 } else {
326 &query_opts.param_types
327 };
328
329 let mut messages = {
330 let _phase = timing::phase("client.query.parse_describe");
331 self.parse_and_describe(sql, param_types, exec_opts.clone())?
332 };
333 let mut data_type_ids = parse_describe_statement_results(&messages);
334 if self.ensure_array_types_for_bind_values(params, &data_type_ids, query_opts)? {
335 messages = {
336 let _phase = timing::phase("client.query.parse_describe_after_array_register");
337 self.parse_and_describe(sql, param_types, exec_opts.clone())?
338 };
339 data_type_ids = parse_describe_statement_results(&messages);
340 }
341 collected_messages.extend(messages);
342 let bind_values = {
343 let _phase = timing::phase("client.query.prepare_bind_values");
344 self.prepare_bind_values(params, &data_type_ids, query_opts)?
345 };
346 let bind_config = BindConfig {
347 values: bind_values,
348 ..Default::default()
349 };
350 let execute_batch = {
351 let _phase = timing::phase("client.query.serialize_execute");
352 let mut execute_batch = Vec::new();
353 execute_batch.extend(Serialize::bind(&bind_config));
354 execute_batch.extend(Serialize::describe(&PortalTarget::new('P', None)));
355 execute_batch.extend(Serialize::execute(None));
356 execute_batch.extend(Serialize::sync());
357 execute_batch
358 };
359 let ExecProtocolResult { messages, .. } = {
360 let _phase = timing::phase("client.query.execute_roundtrip");
361 self.exec_protocol(&execute_batch, exec_opts.clone())?
362 };
363 collected_messages.extend(messages);
364
365 Ok(())
366 })();
367
368 if let Err(err) = result {
369 match err.downcast::<DatabaseError>() {
370 Ok(db_err) => {
371 let enriched = PgliteError::new(db_err, sql, params_snapshot, options_snapshot);
372 return Err(enriched.into());
373 }
374 Err(err) => {
375 return Err(err.context(format!("failed to execute extended query: {sql}")));
376 }
377 }
378 }
379
380 {
381 let _phase = timing::phase("client.query.finish");
382 self.finish_query(collected_messages, options)
383 }
384 }
385
386 pub fn is_ready(&self) -> bool {
388 self.ready && !self.closing && !self.closed
389 }
390
391 #[doc(hidden)]
393 pub fn paths(&self) -> &PglitePaths {
394 self.backend.paths()
395 }
396
397 #[doc(hidden)]
399 #[cfg(debug_assertions)]
400 pub fn guest_bridge_allocation_counts(&self) -> (u64, u64) {
401 self.backend.guest_bridge_allocation_counts()
402 }
403
404 pub fn dump_data_dir(&mut self) -> Result<Vec<u8>> {
410 self.dump_data_dir_with_format(DataDirArchiveFormat::TarGz)
411 }
412
413 pub fn dump_data_dir_with_format(&mut self, format: DataDirArchiveFormat) -> Result<Vec<u8>> {
415 self.check_ready()?;
416 self.archive_quiesced_pgdata("dump PGDATA archive", format)
417 }
418
419 pub fn try_clone(&mut self) -> Result<Self> {
421 #[cfg(feature = "extensions")]
422 let extensions = self.bundled_extensions_in_database()?;
423 let archive = self.dump_data_dir_with_format(DataDirArchiveFormat::Tar)?;
424 let builder = Self::builder().temporary().load_data_dir_archive(archive);
425 #[cfg(feature = "extensions")]
426 let builder = builder.extensions(extensions);
427 builder.open()
428 }
429
430 #[cfg(feature = "extensions")]
432 pub fn dump_sql(&mut self, options: PgDumpOptions) -> Result<String> {
433 self.check_ready()?;
434 options.validate()?;
435 self.checkpoint_backend_for_physical_snapshot("direct pg_dump")?;
436 self.dump_sql_via_direct_protocol(&options)
437 }
438
439 #[cfg(feature = "extensions")]
441 pub fn dump_bytes(&mut self, options: PgDumpOptions) -> Result<Vec<u8>> {
442 Ok(self.dump_sql(options)?.into_bytes())
443 }
444
445 fn checkpoint_backend_for_physical_snapshot(&mut self, operation: &'static str) -> Result<()> {
446 if self.in_transaction {
447 bail!("{operation} cannot run while a direct transaction is active");
448 }
449 self.exec("CHECKPOINT", None)
450 .with_context(|| format!("checkpoint before {operation}"))?;
451 Ok(())
452 }
453
    /// Produces a physical PGDATA archive while the backend is stopped, then
    /// brings the backend back up.
    ///
    /// Steps:
    /// 1. CHECKPOINT; this is refused inside a transaction.
    /// 2. Shut the backend down.
    /// 3. Archive `pgdata`, relative to the template root, in `format`.
    /// 4. Restart the backend and restore session state (startup role and
    ///    LISTEN subscriptions).
    ///
    /// The restart in step 4 is attempted whether or not archiving succeeded.
    /// If the restart fails, the client is marked closed and not ready.
    ///
    /// # Errors
    /// Returns the archive error, the restart error, or, when both fail, the
    /// archive error with the restart failure attached as context.
    fn archive_quiesced_pgdata(
        &mut self,
        operation: &'static str,
        format: DataDirArchiveFormat,
    ) -> Result<Vec<u8>> {
        self.checkpoint_backend_for_physical_snapshot(operation)?;
        // NOTE(review): if shutdown fails we return here without attempting a
        // restart and leave `ready`/`closed` untouched; confirm the backend is
        // still usable in that case.
        self.backend
            .shutdown()
            .with_context(|| format!("quiesce backend before {operation}"))?;

        // Results are kept (not `?`-propagated) so the restart below always runs.
        let archive = dump_pgdata_archive(
            &self.backend.paths().pgdata,
            self.backend.pgdata_template_root(),
            format,
        )
        .with_context(|| format!("materialize physical PGDATA archive for {operation}"));
        let restart = self
            .backend
            .restart()
            .and_then(|_| self.restore_session_state_after_backend_restart())
            .with_context(|| format!("restart backend after {operation}"));

        match (archive, restart) {
            (Ok(archive), Ok(())) => Ok(archive),
            (Err(err), Ok(())) => Err(err),
            // Without a successful restart the client cannot serve further
            // commands, so it is marked closed.
            (Ok(_), Err(err)) => {
                self.ready = false;
                self.closed = true;
                Err(err)
            }
            (Err(err), Err(restart_err)) => {
                self.ready = false;
                self.closed = true;
                Err(err.context(format!(
                    "backend restart after failed {operation} also failed: {restart_err:#}"
                )))
            }
        }
    }
493
494 fn restore_session_state_after_backend_restart(&mut self) -> Result<()> {
495 let username = self.backend.startup_config().username.clone();
496 if username != "postgres" {
497 let sql = format!(
498 "SET ROLE {}",
499 crate::pglite::templating::quote_identifier(&username)
500 );
501 self.exec(&sql, None).with_context(|| {
502 format!("restore startup role {username} after backend restart")
503 })?;
504 }
505
506 let channels = self
507 .notify_listeners
508 .iter()
509 .filter(|(_, listeners)| !listeners.is_empty())
510 .map(|(channel, _)| channel.clone())
511 .collect::<Vec<_>>();
512 for channel in channels {
513 let quoted_channel = crate::pglite::templating::quote_identifier(&channel);
514 self.exec_internal(&format!("LISTEN {quoted_channel}"), None)
515 .with_context(|| format!("restore LISTEN {channel} after backend restart"))?;
516 }
517 Ok(())
518 }
519
520 #[cfg(feature = "extensions")]
521 fn dump_sql_via_direct_protocol(&mut self, options: &PgDumpOptions) -> Result<String> {
522 ensure_direct_pg_dump_options_match_session(self.backend.startup_config(), options)?;
523 let result = dump_direct_sql(options, |socket| self.serve_direct_pg_dump_protocol(socket));
524 let cleanup_result = self.cleanup_after_direct_pg_dump_session();
525
526 match (result, cleanup_result) {
527 (Ok(sql), Ok(())) => Ok(sql),
528 (Err(err), Ok(())) => Err(err),
529 (Ok(_), Err(err)) => Err(err),
530 (Err(err), Err(cleanup_err)) => Err(err.context(format!(
531 "direct pg_dump cleanup also failed: {cleanup_err:#}"
532 ))),
533 }
534 }
535
536 #[cfg(feature = "extensions")]
537 fn cleanup_after_direct_pg_dump_session(&mut self) -> Result<()> {
538 self.exec("DEALLOCATE ALL; SET search_path TO DEFAULT;", None)
539 .context("reset direct pg_dump session state")?;
540 Ok(())
541 }
542
    #[cfg(feature = "extensions")]
    /// Serves one direct `pg_dump` session over an in-process virtual socket.
    ///
    /// Frontend frames read from `socket` are handled by kind:
    /// - SSL/GSS requests are refused with a single `N` byte.
    /// - Cancel/Terminate end the session.
    /// - The startup packet is answered with the backend's existing startup
    ///   response when one exists; otherwise it goes through
    ///   `startup_with_packet`, and the session ends if that response is not
    ///   accepted.
    /// - All other protocol messages are streamed to the backend, and its
    ///   reply chunks are written straight back to the socket.
    ///
    /// Returns `Ok(())` when a read yields zero bytes or one of the terminating
    /// conditions above is reached.
    fn serve_direct_pg_dump_protocol(&mut self, mut socket: PgDumpVirtualSocket) -> Result<()> {
        // Best-effort: the result of disabling Nagle is deliberately ignored.
        let _ = socket.set_nodelay(true);
        let (mut socket_tx, mut socket_rx) = socket.split();
        // Single-threaded runtime handed to the read/write/flush helpers,
        // presumably to drive the async socket halves from this sync method.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .context("create direct pg_dump virtual socket runtime")?;
        let mut reader = FrontendFrameReader::default();
        let mut buffer = [0u8; 64 * 1024];
        loop {
            let read = read_direct_pg_dump_socket(&runtime, &mut socket_rx, &mut buffer)
                .context("read direct pg_dump protocol socket")?;
            // A zero-length read is treated as the end of the session.
            if read == 0 {
                return Ok(());
            }
            // `reader` reassembles complete frames across reads.
            for message in reader.push(&buffer[..read])? {
                match classify_frontend_message(&message)? {
                    FrontendFrameKind::SslOrGssRequest => {
                        write_direct_pg_dump_socket(&runtime, &mut socket_tx, b"N")
                            .context("write direct pg_dump SSL refusal")?;
                    }
                    FrontendFrameKind::CancelRequest | FrontendFrameKind::Terminate => {
                        return Ok(());
                    }
                    FrontendFrameKind::Startup => {
                        if let Some(response) = self.backend.existing_startup_response() {
                            write_direct_pg_dump_socket(&runtime, &mut socket_tx, &response)
                                .context("write direct pg_dump existing startup response")?;
                        } else {
                            let response = self.backend.startup_with_packet(&message)?;
                            write_direct_pg_dump_socket(&runtime, &mut socket_tx, &response.output)
                                .context("write direct pg_dump startup response")?;
                            if !response.accepted {
                                return Ok(());
                            }
                        }
                    }
                    FrontendFrameKind::Protocol => {
                        self.exec_protocol_raw_stream(
                            &message,
                            ExecProtocolOptions::no_sync(),
                            |chunk| {
                                write_direct_pg_dump_socket(&runtime, &mut socket_tx, chunk)
                                    .context("write direct pg_dump backend protocol chunk")?;
                                Ok(())
                            },
                        )?;
                    }
                }
            }
            // Flush once per batch of frames so replies reach pg_dump.
            flush_direct_pg_dump_socket(&runtime, &mut socket_tx)
                .context("flush direct pg_dump socket")?;
        }
    }
598
599 #[cfg(feature = "extensions")]
600 fn bundled_extensions_in_database(&mut self) -> Result<Vec<Extension>> {
601 let results = self.query(
602 "SELECT extname FROM pg_catalog.pg_extension ORDER BY extname",
603 &[],
604 None,
605 )?;
606 let extensions = results
607 .rows
608 .iter()
609 .filter_map(|row| row.get("extname"))
610 .filter_map(|value| value.as_str())
611 .filter_map(by_sql_name)
612 .collect();
613 Ok(extensions)
614 }
615
616 pub(crate) fn attach_temp_dir(&mut self, temp_dir: TempDir) {
617 self._temp_dir = Some(temp_dir);
618 }
619
620 pub(crate) fn attach_root_lock(&mut self, root_lock: RootLock) {
621 self._root_lock = Some(root_lock);
622 }
623
624 pub fn is_closed(&self) -> bool {
626 self.closed
627 }
628
629 pub fn close(&mut self) -> Result<()> {
631 self.close_backend()
632 }
633
634 fn close_backend(&mut self) -> Result<()> {
635 if self.closed {
636 return Ok(());
637 }
638 if self.closing {
639 bail!("Pglite is closing");
640 }
641
642 self.closing = true;
643 let result = (|| {
644 self.backend.shutdown()?;
645 self.sync_to_fs()
646 })();
647
648 self.closing = false;
649 if result.is_ok() {
650 self.closed = true;
651 self.ready = false;
652 self.notify_listeners.clear();
653 self.global_notify_listeners.clear();
654 self._root_lock = None;
655 }
656 result
657 }
658
659 #[cfg(feature = "extensions")]
660 pub(crate) fn close_for_template_cache(&mut self) -> Result<()> {
661 self.close_backend()
662 }
663
664 pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
666 self.check_ready()?;
667
668 self.exec_internal(sql, options)
669 }
670
671 fn exec_internal(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
672 let options_snapshot = options.cloned();
673 let default_options = QueryOptions::default();
674 let exec_opts_ref = options.unwrap_or(&default_options);
675 let mut exec_opts = ExecProtocolOptions::no_sync();
676 exec_opts.on_notice = exec_opts_ref.on_notice.clone();
677 exec_opts.data_transfer_container = exec_opts_ref.data_transfer_container;
678
679 self.handle_blob_input(exec_opts_ref.blob.as_ref())?;
680
681 let mut collected_messages: Vec<BackendMessage> = Vec::new();
682
683 let message = Serialize::query(sql);
684 let ExecProtocolResult { messages, .. } = match self.exec_protocol(&message, exec_opts) {
685 Ok(result) => result,
686 Err(err) => match err.downcast::<DatabaseError>() {
687 Ok(db_err) => {
688 let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
689 return Err(enriched.into());
690 }
691 Err(err) => {
692 return Err(err.context(format!("failed to execute simple query: {sql}")));
693 }
694 },
695 };
696 collected_messages.extend(messages);
697
698 self.finish_exec(collected_messages, options)
699 }
700
701 pub fn listen<F>(&mut self, channel: &str, callback: F) -> Result<ListenerHandle>
703 where
704 F: Fn(&str) + Send + Sync + 'static,
705 {
706 self.check_ready()?;
707
708 let quoted_channel = crate::pglite::templating::quote_identifier(channel);
709 let normalized = channel.to_string();
710 let should_listen = match self.notify_listeners.get(&normalized) {
711 Some(existing) => existing.is_empty(),
712 None => true,
713 };
714
715 if should_listen {
716 self.exec_internal(&format!("LISTEN {quoted_channel}"), None)?;
717 }
718
719 let callback: ChannelCallback = Arc::new(callback);
720 let entry = self.notify_listeners.entry(normalized.clone()).or_default();
721 let id = self.next_listener_id;
722 self.next_listener_id = self.next_listener_id.wrapping_add(1);
723 entry.push(ChannelListener { id, callback });
724
725 Ok(ListenerHandle {
726 channel: channel.to_string(),
727 normalized_channel: normalized,
728 id,
729 })
730 }
731
732 pub fn unlisten(&mut self, handle: ListenerHandle) -> Result<()> {
734 if let Some(listeners) = self.notify_listeners.get_mut(&handle.normalized_channel) {
735 listeners.retain(|listener| listener.id != handle.id);
736 if listeners.is_empty() {
737 self.notify_listeners.remove(&handle.normalized_channel);
738 let quoted_channel = crate::pglite::templating::quote_identifier(&handle.channel);
739 self.exec_internal(&format!("UNLISTEN {quoted_channel}"), None)?;
740 }
741 }
742 Ok(())
743 }
744
745 pub fn unlisten_channel(&mut self, channel: &str) -> Result<()> {
747 let quoted_channel = crate::pglite::templating::quote_identifier(channel);
748 let normalized = channel.to_string();
749 if self.notify_listeners.remove(&normalized).is_some() {
750 self.exec_internal(&format!("UNLISTEN {quoted_channel}"), None)?;
751 }
752 Ok(())
753 }
754
755 pub fn on_notification<F>(&mut self, callback: F) -> GlobalListenerHandle
757 where
758 F: Fn(&str, &str) + Send + Sync + 'static,
759 {
760 let id = self.next_global_listener_id;
761 self.next_global_listener_id = self.next_global_listener_id.wrapping_add(1);
762 let callback: GlobalCallback = Arc::new(callback);
763 self.global_notify_listeners
764 .push(GlobalListener { id, callback });
765 GlobalListenerHandle { id }
766 }
767
768 pub fn off_notification(&mut self, handle: GlobalListenerHandle) {
770 self.global_notify_listeners
771 .retain(|listener| listener.id != handle.id);
772 }
773
774 pub fn describe_query(
776 &mut self,
777 sql: &str,
778 options: Option<&QueryOptions>,
779 ) -> Result<DescribeQueryResult> {
780 self.check_ready()?;
781
782 let default_options = QueryOptions::default();
783 let query_opts = options.unwrap_or(&default_options);
784
785 let options_snapshot = options.cloned();
786 let mut exec_opts = ExecProtocolOptions::no_sync();
787 exec_opts.on_notice = query_opts.on_notice.clone();
788 exec_opts.data_transfer_container = query_opts.data_transfer_container;
789
790 let mut describe_messages: Vec<BackendMessage> = Vec::new();
791
792 let result: Result<()> = (|| {
793 let param_types = if query_opts.param_types.is_empty() {
794 &[] as &[i32]
795 } else {
796 &query_opts.param_types
797 };
798
799 let mut describe_batch = Vec::new();
800 describe_batch.extend(Serialize::parse(None, sql, param_types));
801 describe_batch.extend(Serialize::describe(&PortalTarget::new('S', None)));
802 describe_batch.extend(Serialize::sync());
803 let ExecProtocolResult { messages, .. } =
804 self.exec_protocol(&describe_batch, exec_opts.clone())?;
805 if !messages
806 .iter()
807 .any(|message| matches!(message, BackendMessage::ParseComplete { .. }))
808 {
809 bail!("extended query parse did not complete");
810 }
811 describe_messages.extend(messages);
812
813 Ok(())
814 })();
815
816 if let Err(err) = result {
817 match err.downcast::<DatabaseError>() {
818 Ok(db_err) => {
819 let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
820 return Err(enriched.into());
821 }
822 Err(err) => {
823 return Err(err.context(format!("failed to describe query: {sql}")));
824 }
825 }
826 }
827
828 let param_type_ids = parse_describe_statement_results(&describe_messages);
829 self.ensure_array_types_for_oids(param_type_ids.iter().copied(), Some(query_opts))?;
830 let result_type_ids = describe_messages
831 .iter()
832 .filter_map(|msg| match msg {
833 BackendMessage::RowDescription(desc) => Some(desc),
834 _ => None,
835 })
836 .flat_map(|desc| desc.fields.iter().map(|field| field.data_type_id))
837 .collect::<Vec<_>>();
838 self.ensure_array_types_for_oids(result_type_ids.iter().copied(), Some(query_opts))?;
839
840 let query_params = param_type_ids
841 .into_iter()
842 .map(|oid| DescribeQueryParam {
843 data_type_id: oid,
844 serializer: self.serializers.get(&oid).cloned(),
845 })
846 .collect();
847
848 let result_fields = describe_messages
849 .iter()
850 .find_map(|msg| match msg {
851 BackendMessage::RowDescription(desc) => Some(
852 desc.fields
853 .iter()
854 .map(|field| DescribeResultField {
855 name: field.name.clone(),
856 data_type_id: field.data_type_id,
857 parser: self.parsers.get(&field.data_type_id).cloned(),
858 })
859 .collect::<Vec<_>>(),
860 ),
861 _ => None,
862 })
863 .unwrap_or_default();
864
865 Ok(DescribeQueryResult {
866 query_params,
867 result_fields,
868 })
869 }
870
871 pub fn transaction<F, T>(&mut self, mut callback: F) -> Result<T>
873 where
874 F: FnMut(&mut Transaction<'_>) -> Result<T>,
875 {
876 self.check_ready()?;
877
878 self.run_exec_command("BEGIN")?;
880 self.in_transaction = true;
881
882 let mut tx = Transaction::new(self);
883 let callback_result = callback(&mut tx);
884
885 let txn_result = match callback_result {
886 Ok(value) => {
887 if !tx.closed {
888 tx.commit_internal()?;
889 }
890 Ok(value)
891 }
892 Err(err) => {
893 if !tx.closed {
894 tx.rollback_internal()?;
895 }
896 Err(err)
897 }
898 };
899
900 self.in_transaction = false;
901 txn_result
902 }
903
904 pub fn sync_to_fs(&mut self) -> Result<()> {
911 Ok(())
912 }
913
914 fn prepare_bind_values(
915 &self,
916 params: &[Value],
917 data_type_ids: &[i32],
918 options: &QueryOptions,
919 ) -> Result<Vec<BindValue>> {
920 if params.is_empty() {
921 return Ok(Vec::new());
922 }
923
924 let mut values = Vec::with_capacity(params.len());
925 let overrides = if options.serializers.is_empty() {
926 None
927 } else {
928 Some(&options.serializers)
929 };
930
931 for (idx, value) in params.iter().enumerate() {
932 if value.is_null() {
933 values.push(BindValue::Null);
934 continue;
935 }
936
937 let oid = data_type_ids.get(idx).copied().unwrap_or(TEXT);
938 let serializer = overrides
939 .and_then(|map| map.get(&oid))
940 .or_else(|| self.serializers.get(&oid));
941
942 let serialized = match serializer {
943 Some(func) => func(value).with_context(|| {
944 format!("failed to serialize parameter {idx} using OID {oid}")
945 })?,
946 None => self.default_serialize_value(value),
947 };
948
949 values.push(BindValue::Text(serialized));
950 }
951
952 Ok(values)
953 }
954
955 fn parse_and_describe(
956 &mut self,
957 sql: &str,
958 param_types: &[i32],
959 exec_opts: ExecProtocolOptions,
960 ) -> Result<Vec<BackendMessage>> {
961 let mut prepare_batch = Vec::new();
962 prepare_batch.extend(Serialize::parse(None, sql, param_types));
963 prepare_batch.extend(Serialize::describe(&PortalTarget::new('S', None)));
964 prepare_batch.extend(Serialize::sync());
965 let ExecProtocolResult { messages, .. } = self.exec_protocol(&prepare_batch, exec_opts)?;
966 if !messages
967 .iter()
968 .any(|message| matches!(message, BackendMessage::ParseComplete { .. }))
969 {
970 bail!("extended query parse did not complete");
971 }
972 Ok(messages)
973 }
974
975 fn default_serialize_value(&self, value: &Value) -> String {
976 Self::default_serialize_value_static(value)
977 }
978
979 pub(crate) fn default_serialize_value_static(value: &Value) -> String {
980 match value {
981 Value::String(s) => s.clone(),
982 Value::Number(num) => num.to_string(),
983 Value::Bool(flag) => {
984 if *flag {
985 "t".to_string()
986 } else {
987 "f".to_string()
988 }
989 }
990 _ => value.to_string(),
991 }
992 }
993
994 fn finish_query(
995 &mut self,
996 messages: Vec<BackendMessage>,
997 options: Option<&QueryOptions>,
998 ) -> Result<Results> {
999 let blob = {
1000 let _phase = timing::phase("client.finish.blob_read");
1001 self.get_written_blob()?
1002 };
1003 {
1004 let _phase = timing::phase("client.finish.blob_cleanup");
1005 self.cleanup_blob()?;
1006 }
1007 if !self.in_transaction {
1008 let _phase = timing::phase("client.finish.sync_to_fs");
1009 self.sync_to_fs()?;
1010 }
1011 {
1012 let _phase = timing::phase("client.finish.ensure_array_types");
1013 self.ensure_array_types_for_result_messages(&messages, options)?;
1014 }
1015 let parsed = {
1016 let _phase = timing::phase("client.finish.parse_results");
1017 parse_results(&messages, &self.parsers, options, blob)
1018 };
1019 parsed
1020 .into_iter()
1021 .next()
1022 .ok_or_else(|| anyhow!("query returned no result sets"))
1023 }
1024
1025 fn finish_exec(
1026 &mut self,
1027 messages: Vec<BackendMessage>,
1028 options: Option<&QueryOptions>,
1029 ) -> Result<Vec<Results>> {
1030 let blob = {
1031 let _phase = timing::phase("client.finish.blob_read");
1032 self.get_written_blob()?
1033 };
1034 {
1035 let _phase = timing::phase("client.finish.blob_cleanup");
1036 self.cleanup_blob()?;
1037 }
1038 if !self.in_transaction {
1039 let _phase = timing::phase("client.finish.sync_to_fs");
1040 self.sync_to_fs()?;
1041 }
1042 {
1043 let _phase = timing::phase("client.finish.ensure_array_types");
1044 self.ensure_array_types_for_result_messages(&messages, options)?;
1045 }
1046 let parsed = {
1047 let _phase = timing::phase("client.finish.parse_results");
1048 parse_results(&messages, &self.parsers, options, blob)
1049 };
1050 Ok(parsed)
1051 }
1052
1053 pub fn exec_protocol(
1056 &mut self,
1057 message: &[u8],
1058 options: ExecProtocolOptions,
1059 ) -> Result<ExecProtocolResult> {
1060 let ExecProtocolOptions {
1061 sync_to_fs,
1062 throw_on_error,
1063 on_notice,
1064 data_transfer_container,
1065 } = options;
1066
1067 let data = {
1068 let _phase = timing::phase("client.protocol_roundtrip");
1069 self.exec_protocol_raw_inner(message, sync_to_fs, data_transfer_container)?
1070 };
1071
1072 let mut messages = Vec::new();
1073 let on_notice_cb = on_notice.clone();
1074 let parse_result = {
1075 let _phase = timing::phase("client.protocol_parse");
1076 self.parser.parse(&data, |msg| {
1077 if let BackendMessage::Error(db_err) = &msg
1078 && throw_on_error
1079 {
1080 return Err(anyhow!(db_err.clone()));
1081 }
1082 if let Some(callback) = on_notice_cb.as_ref()
1083 && let BackendMessage::Notice(notice) = &msg
1084 {
1085 callback(notice);
1086 }
1087 messages.push(msg);
1088 Ok(())
1089 })
1090 };
1091 if let Err(err) = parse_result {
1092 match err.downcast::<DatabaseError>() {
1093 Ok(db_err) => {
1094 self.parser = ProtocolParser::new();
1095 return Err(anyhow!(db_err));
1096 }
1097 Err(err) => return Err(err),
1098 }
1099 }
1100
1101 for message in &messages {
1102 if let BackendMessage::Notification(note) = message {
1103 if let Some(listeners) = self.notify_listeners.get(¬e.channel) {
1104 for listener in listeners {
1105 (listener.callback)(¬e.payload);
1106 }
1107 }
1108 for listener in &self.global_notify_listeners {
1109 (listener.callback)(¬e.channel, ¬e.payload);
1110 }
1111 }
1112 }
1113
1114 Ok(ExecProtocolResult { data, messages })
1115 }
1116
1117 pub fn exec_protocol_raw(
1120 &mut self,
1121 message: &[u8],
1122 options: ExecProtocolOptions,
1123 ) -> Result<Vec<u8>> {
1124 self.exec_protocol_raw_inner(message, options.sync_to_fs, options.data_transfer_container)
1125 }
1126
1127 pub fn exec_protocol_raw_stream<F>(
1130 &mut self,
1131 message: &[u8],
1132 options: ExecProtocolOptions,
1133 mut on_data: F,
1134 ) -> Result<()>
1135 where
1136 F: FnMut(&[u8]) -> Result<()>,
1137 {
1138 self.backend.send_framed_raw_stream(
1139 message,
1140 options.data_transfer_container,
1141 &mut on_data,
1142 )?;
1143 if options.sync_to_fs {
1144 let _phase = timing::phase("client.protocol_stream_sync_to_fs");
1145 self.sync_to_fs()?;
1146 }
1147 Ok(())
1148 }
1149
1150 fn exec_protocol_raw_inner(
1151 &mut self,
1152 message: &[u8],
1153 sync_to_fs: bool,
1154 data_transfer_container: Option<DataTransferContainer>,
1155 ) -> Result<Vec<u8>> {
1156 let data = {
1157 let _phase = timing::phase("client.protocol_transport_send");
1158 self.backend
1159 .send_buffered(message, data_transfer_container)?
1160 };
1161 if sync_to_fs {
1162 let _phase = timing::phase("client.protocol_sync_to_fs");
1163 self.sync_to_fs()?;
1164 }
1165 Ok(data)
1166 }
1167
1168 fn ensure_array_types_for_bind_values(
1169 &mut self,
1170 params: &[Value],
1171 data_type_ids: &[i32],
1172 options: &QueryOptions,
1173 ) -> Result<bool> {
1174 let mut registered = false;
1175 for (idx, value) in params.iter().enumerate() {
1176 if !value.is_array() {
1177 continue;
1178 }
1179 let oid = data_type_ids.get(idx).copied().unwrap_or(TEXT);
1180 if options.serializers.contains_key(&oid) || self.serializers.contains_key(&oid) {
1181 continue;
1182 }
1183 registered |= self.try_register_array_type_by_array_oid(oid)?;
1184 }
1185 Ok(registered)
1186 }
1187
1188 fn ensure_array_types_for_result_messages(
1189 &mut self,
1190 messages: &[BackendMessage],
1191 options: Option<&QueryOptions>,
1192 ) -> Result<()> {
1193 let oids = messages
1194 .iter()
1195 .filter_map(|msg| match msg {
1196 BackendMessage::RowDescription(desc) => Some(desc),
1197 _ => None,
1198 })
1199 .flat_map(|desc| desc.fields.iter().map(|field| field.data_type_id))
1200 .collect::<Vec<_>>();
1201 self.ensure_array_types_for_oids(oids, options)
1202 }
1203
1204 fn ensure_array_types_for_oids(
1205 &mut self,
1206 oids: impl IntoIterator<Item = i32>,
1207 options: Option<&QueryOptions>,
1208 ) -> Result<()> {
1209 for oid in oids {
1210 if oid <= 0 || self.parsers.contains_key(&oid) {
1211 continue;
1212 }
1213 if options.is_some_and(|options| options.parsers.contains_key(&oid)) {
1214 continue;
1215 }
1216 self.try_register_array_type_by_array_oid(oid)?;
1217 }
1218 Ok(())
1219 }
1220
1221 fn refresh_array_types_internal(&mut self) -> Result<()> {
1222 let sql = "
1223 SELECT e.oid, a.oid AS typarray, e.typdelim::text AS typdelim
1224 FROM pg_catalog.pg_type a
1225 JOIN pg_catalog.pg_type e ON e.oid = a.typelem
1226 WHERE a.typcategory = 'A'
1227 AND a.typelem <> 0
1228 ORDER BY e.oid
1229 ";
1230 let results = {
1231 let _phase = timing::phase("pglite.array_type_catalog_query");
1232 self.exec_internal(sql, None)?
1233 };
1234 let result_set = results
1235 .into_iter()
1236 .next()
1237 .ok_or_else(|| anyhow!("array type discovery returned no results"))?;
1238
1239 {
1240 let _phase = timing::phase("pglite.array_type_register");
1241 for row in result_set.rows {
1242 if let Some(info) = array_type_info_from_row(&row) {
1243 self.register_array_type(info);
1244 }
1245 }
1246 }
1247 Ok(())
1248 }
1249
1250 fn try_register_array_type_by_array_oid(&mut self, array_oid: i32) -> Result<bool> {
1251 if array_oid <= 0
1252 || self.parsers.contains_key(&array_oid)
1253 || self.array_type_lookup_misses.contains(&array_oid)
1254 {
1255 return Ok(false);
1256 }
1257
1258 let sql = format!(
1259 "SELECT e.oid, a.oid AS typarray, e.typdelim::text AS typdelim \
1260 FROM pg_catalog.pg_type a \
1261 JOIN pg_catalog.pg_type e ON e.oid = a.typelem \
1262 WHERE a.oid = {array_oid}::oid \
1263 AND a.typcategory = 'A' \
1264 AND a.typelem <> 0"
1265 );
1266 let results = {
1267 let _phase = timing::phase("pglite.array_type_targeted_lookup");
1268 self.exec_internal(&sql, None)?
1269 };
1270 let Some(result_set) = results.into_iter().next() else {
1271 self.array_type_lookup_misses.insert(array_oid);
1272 return Ok(false);
1273 };
1274 let Some(row) = result_set.rows.into_iter().next() else {
1275 self.array_type_lookup_misses.insert(array_oid);
1276 return Ok(false);
1277 };
1278 let Some(info) = array_type_info_from_row(&row) else {
1279 self.array_type_lookup_misses.insert(array_oid);
1280 return Ok(false);
1281 };
1282
1283 self.register_array_type(info);
1284 Ok(true)
1285 }
1286
1287 fn register_array_type(&mut self, info: ArrayTypeInfo) {
1288 register_array_type(&mut self.parsers, &mut self.serializers, info);
1289 self.array_type_lookup_misses.remove(&info.array_oid);
1290 }
1291
1292 fn run_exec_command(&mut self, sql: &str) -> Result<()> {
1293 self.exec_internal(sql, None).map(|_| ())
1294 }
1295
1296 fn handle_blob_input(&mut self, blob: Option<&Vec<u8>>) -> Result<()> {
1297 let path = self.dev_blob_path();
1298 if let Some(bytes) = blob {
1299 if let Some(parent) = path.parent() {
1300 fs::create_dir_all(parent).with_context(|| {
1301 format!("failed to create blob directory {}", parent.display())
1302 })?;
1303 }
1304 fs::write(&path, bytes)
1305 .with_context(|| format!("write blob input to {}", path.display()))?;
1306 self.blob_input_provided = true;
1307 } else {
1308 self.blob_input_provided = false;
1309 let _ = fs::remove_file(&path);
1310 }
1311 Ok(())
1312 }
1313
1314 fn dev_blob_path(&self) -> PathBuf {
1315 self.backend.paths().runtime_root().join("dev/blob")
1316 }
1317
1318 fn cleanup_blob(&mut self) -> Result<()> {
1319 Ok(())
1320 }
1321
1322 fn get_written_blob(&mut self) -> Result<Option<Vec<u8>>> {
1323 let path = self.dev_blob_path();
1324
1325 if self.blob_input_provided {
1326 self.blob_input_provided = false;
1327 let _ = fs::remove_file(&path);
1328 return Ok(None);
1329 }
1330
1331 match fs::read(&path) {
1332 Ok(data) => {
1333 self.blob_input_provided = false;
1334 let _ = fs::remove_file(&path);
1335 if data.is_empty() {
1336 Ok(None)
1337 } else {
1338 Ok(Some(data))
1339 }
1340 }
1341 Err(err) => {
1342 if err.kind() == io::ErrorKind::NotFound {
1343 self.blob_input_provided = false;
1344 Ok(None)
1345 } else {
1346 Err(err).with_context(|| format!("read blob output from {}", path.display()))
1347 }
1348 }
1349 }
1350 }
1351
1352 fn check_ready(&self) -> Result<()> {
1353 if self.closing {
1354 bail!("Pglite instance is closing");
1355 }
1356 if self.closed {
1357 bail!("Pglite instance is closed");
1358 }
1359 if !self.ready {
1360 bail!("Pglite instance is not ready");
1361 }
1362 Ok(())
1363 }
1364}
1365
1366impl Drop for Pglite {
1367 fn drop(&mut self) {
1368 if !self.closed {
1369 let _ = self.close();
1370 }
1371 }
1372}
1373
/// Rejects `pg_dump` options that name a database or user other than the one
/// the embedded backend session was started with.
///
/// Direct `pg_dump` reuses the already-open backend, so it can only dump the
/// session's own database as the session's own user.
///
/// # Errors
/// Returns an error naming both the session value and the requested value. The
/// database is checked before the username.
#[cfg(feature = "extensions")]
fn ensure_direct_pg_dump_options_match_session(
    startup_config: &StartupConfig,
    options: &PgDumpOptions,
) -> Result<()> {
    let requested_database = options.database_ref();
    if requested_database != startup_config.database {
        bail!(
            "direct pg_dump runs against the already-open embedded backend database '{}'; requested database '{}' would require a separate server connection",
            startup_config.database,
            requested_database
        );
    }

    let requested_user = options.username_ref();
    if requested_user != startup_config.username {
        bail!(
            "direct pg_dump runs through the already-open embedded backend user '{}'; requested user '{}' would require a separate server connection",
            startup_config.username,
            requested_user
        );
    }
    Ok(())
}
1395
/// Blocks on `runtime` until the pg_dump virtual socket yields buffered data,
/// then copies as much of it as fits into `buffer`.
///
/// Only the copied bytes are consumed from the socket; anything that did not
/// fit stays buffered for the next call. Returns the number of bytes copied.
/// This is `0` if `buffer` is empty or if `poll_fill_buf` produced an empty
/// slice (conventionally end-of-stream; NOTE(review): confirm for `TcpSocketHalfRx`).
///
/// # Errors
/// Socket read errors, with added context.
#[cfg(feature = "extensions")]
fn read_direct_pg_dump_socket(
    runtime: &Runtime,
    reader: &mut TcpSocketHalfRx,
    buffer: &mut [u8],
) -> Result<usize> {
    let fill_once = std::future::poll_fn(|cx| {
        let copied = match std::task::ready!(reader.poll_fill_buf(cx)) {
            Ok(available) => {
                let len = available.len().min(buffer.len());
                buffer[..len].copy_from_slice(&available[..len]);
                len
            }
            Err(err) => return std::task::Poll::Ready(Err(err)),
        };
        reader.consume(copied);
        std::task::Poll::Ready(Ok(copied))
    });
    runtime
        .block_on(fill_once)
        .context("read direct pg_dump virtual socket")
}
1421
/// Writes all of `bytes` to the pg_dump virtual socket, blocking on `runtime`
/// until `write_all` completes.
#[cfg(feature = "extensions")]
fn write_direct_pg_dump_socket(
    runtime: &Runtime,
    writer: &mut (impl AsyncWrite + Unpin),
    bytes: &[u8],
) -> Result<()> {
    let pending_write = writer.write_all(bytes);
    runtime
        .block_on(pending_write)
        .context("write direct pg_dump virtual socket")
}
1432
/// Flushes the pg_dump virtual socket, blocking on `runtime` until done.
#[cfg(feature = "extensions")]
fn flush_direct_pg_dump_socket(
    runtime: &Runtime,
    writer: &mut (impl AsyncWrite + Unpin),
) -> Result<()> {
    let pending_flush = writer.flush();
    runtime
        .block_on(pending_flush)
        .context("flush direct pg_dump virtual socket")
}
1442
1443fn value_to_i32(value: Option<&Value>) -> Option<i32> {
1444 match value? {
1445 Value::Number(number) => number.as_i64().map(|value| value as i32),
1446 Value::String(string) => string.parse::<i32>().ok(),
1447 _ => None,
1448 }
1449}
1450
1451fn value_to_char(value: Option<&Value>) -> Option<char> {
1452 match value? {
1453 Value::String(string) => string.chars().next(),
1454 _ => None,
1455 }
1456}
1457
1458fn array_type_info_from_row(row: &Value) -> Option<ArrayTypeInfo> {
1459 let Value::Object(map) = row else {
1460 return None;
1461 };
1462 let element_oid = value_to_i32(map.get("oid"))?;
1463 let array_oid = value_to_i32(map.get("typarray"))?;
1464 if element_oid == 0 || array_oid == 0 {
1465 return None;
1466 }
1467 let delimiter = value_to_char(map.get("typdelim")).unwrap_or(',');
1468 Some(ArrayTypeInfo::new(element_oid, array_oid, delimiter))
1469}
1470
/// Handle for running statements inside a transaction on a [`Pglite`] client.
///
/// The handle borrows the client mutably for its whole lifetime. Once a COMMIT
/// or ROLLBACK issued through it succeeds, the handle is marked closed and
/// further operations are rejected.
pub struct Transaction<'a> {
    // Client that every statement of this transaction is executed on.
    client: &'a mut Pglite,
    // Set to `true` after a successful COMMIT or ROLLBACK through this handle.
    closed: bool,
}
1476
1477impl<'a> Transaction<'a> {
1478 fn new(client: &'a mut Pglite) -> Self {
1479 Self {
1480 client,
1481 closed: false,
1482 }
1483 }
1484
1485 fn commit_internal(&mut self) -> Result<()> {
1486 self.ensure_open()?;
1487 self.client.exec_internal("COMMIT", None)?;
1488 self.closed = true;
1489 Ok(())
1490 }
1491
1492 fn rollback_internal(&mut self) -> Result<()> {
1493 self.ensure_open()?;
1494 self.client.exec_internal("ROLLBACK", None)?;
1495 self.closed = true;
1496 Ok(())
1497 }
1498
1499 fn ensure_open(&self) -> Result<()> {
1500 if self.closed {
1501 bail!("transaction is already closed");
1502 }
1503 Ok(())
1504 }
1505
1506 pub fn query(
1507 &mut self,
1508 sql: &str,
1509 params: &[Value],
1510 options: Option<&QueryOptions>,
1511 ) -> Result<Results> {
1512 self.ensure_open()?;
1513 self.client.query_internal(sql, params, options)
1514 }
1515
1516 pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
1517 self.ensure_open()?;
1518 self.client.exec_internal(sql, options)
1519 }
1520
1521 pub fn refresh_array_types(&mut self) -> Result<()> {
1522 self.ensure_open()?;
1523 self.client.refresh_array_types_internal()
1524 }
1525
1526 pub fn commit(&mut self) -> Result<()> {
1527 self.commit_internal()
1528 }
1529
1530 pub fn rollback(&mut self) -> Result<()> {
1531 self.rollback_internal()
1532 }
1533
1534 pub fn is_closed(&self) -> bool {
1535 self.closed
1536 }
1537
1538 pub fn closed(&self) -> bool {
1539 self.closed
1540 }
1541}