1use anyhow::{Context, Result, anyhow, bail};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::fs;
5use std::io;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tempfile::TempDir;
10
11use crate::pglite::base::PglitePaths;
12use crate::pglite::builder::PgliteBuilder;
13use crate::pglite::errors::PgliteError;
14use crate::pglite::interface::{
15 DataTransferContainer, DescribeQueryParam, DescribeQueryResult, DescribeResultField,
16 ExecProtocolOptions, ExecProtocolResult, ParserMap, QueryOptions, Results, Serializer,
17 SerializerMap, TypeParser,
18};
19use crate::pglite::parse::{parse_describe_statement_results, parse_results};
20use crate::pglite::postgres_mod::PostgresMod;
21use crate::pglite::transport::Transport;
22use crate::pglite::types::{
23 DEFAULT_PARSERS, DEFAULT_SERIALIZERS, TEXT, parse_array_text, serialize_array_value,
24};
25use crate::protocol::messages::{BackendMessage, DatabaseError};
26use crate::protocol::parser::Parser as ProtocolParser;
27use crate::protocol::serializer::{BindConfig, BindValue, PortalTarget, Serialize};
28
/// Shared callback for a single channel; it is invoked with the notification payload.
29type ChannelCallback = Arc<dyn Fn(&str) + Send + Sync + 'static>;
/// Shared callback for all channels; it is invoked with the channel name and the payload.
30type GlobalCallback = Arc<dyn Fn(&str, &str) + Send + Sync + 'static>;
31
/// Handle returned by `Pglite::listen`; pass it to `Pglite::unlisten` to remove
/// that one callback registration.
32#[derive(Debug, Clone, PartialEq, Eq, Hash)]
33pub struct ListenerHandle {
/// Channel name exactly as it was passed to `listen`.
34 channel: String,
/// Output of `to_postgres_name` for `channel`; used as the key of the listener map.
35 normalized_channel: String,
/// Identifier taken from the instance's per-channel listener counter.
36 id: u64,
37}
38
/// Handle returned by `Pglite::on_notification`; pass it to
/// `Pglite::off_notification` to remove that global callback.
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
40pub struct GlobalListenerHandle {
/// Identifier taken from the instance's global listener counter.
41 id: u64,
42}
43
44impl ListenerHandle {
45 pub fn channel(&self) -> &str {
46 &self.channel
47 }
48
49 pub fn id(&self) -> u64 {
50 self.id
51 }
52}
53
54impl GlobalListenerHandle {
55 pub fn id(&self) -> u64 {
56 self.id
57 }
58}
59
/// One callback registered for a specific channel.
60struct ChannelListener {
/// Identifier matched against `ListenerHandle::id` when unlistening.
61 id: u64,
/// Callback invoked with the payload of each notification on the channel.
62 callback: ChannelCallback,
63}
64
/// One callback registered for notifications on every channel.
65struct GlobalListener {
/// Identifier matched against `GlobalListenerHandle::id` when unregistering.
66 id: u64,
/// Callback invoked with the channel name and payload of each notification.
67 callback: GlobalCallback,
68}
69
/// Client for one embedded Postgres instance: owns the backend module, the
/// message transport, the type (de)serializer tables and the notification
/// listener registry.
70pub struct Pglite {
/// Backend module created from the `PglitePaths` given to `new`.
72 pg: PostgresMod,
/// Temporary directory stored by `attach_temp_dir`. It is never read in this
/// file; presumably it is held so the directory outlives the instance (confirm
/// against the builder).
73 _temp_dir: Option<TempDir>,
/// Transport used to send protocol messages to `pg`.
74 transport: Transport,
/// Parser that turns raw backend bytes into `BackendMessage`s.
75 parser: ProtocolParser,
/// Parameter serializers keyed by type OID.
76 serializers: SerializerMap,
/// Result-value parsers keyed by type OID.
77 parsers: ParserMap,
/// Whether array parsers/serializers have been registered by `init_array_types`.
78 array_types_initialized: bool,
/// True while `transaction` is running; `finish_query`/`finish_exec` skip
/// `sync_to_fs` while it is set.
79 in_transaction: bool,
/// Set to true in `new` and to false by a successful `close`.
80 ready: bool,
/// True only while `close` is executing.
81 closing: bool,
/// True once `close` has completed successfully.
82 closed: bool,
/// True when the current statement was given a blob input via `handle_blob_input`.
83 blob_input_provided: bool,
/// Per-channel callbacks keyed by the normalized channel name.
84 notify_listeners: HashMap<String, Vec<ChannelListener>>,
/// Callbacks invoked for notifications on any channel.
85 global_notify_listeners: Vec<GlobalListener>,
/// Next identifier handed out by `listen`.
86 next_listener_id: u64,
/// Next identifier handed out by `on_notification`.
87 next_global_listener_id: u64,
88}
89
90impl Pglite {
/// Returns a fresh `PgliteBuilder` (via `PgliteBuilder::new()`) for configuring
/// and opening an instance.
91 pub fn builder() -> PgliteBuilder {
93 PgliteBuilder::new()
94 }
95
96 pub fn open(root: impl AsRef<Path>) -> Result<Self> {
98 Self::builder().path(root.as_ref().to_path_buf()).open()
99 }
100
/// Opens an instance through the builder, forwarding the `app_id` triple to
/// `PgliteBuilder::app_id`.
// NOTE(review): the meaning of the three strings is defined by the builder and
// is not visible here.
101 pub fn open_app(app_id: (&str, &str, &str)) -> Result<Self> {
103 Self::builder().app_id(app_id).open()
104 }
105
/// Opens an instance through the builder after calling `PgliteBuilder::temporary`.
// NOTE(review): presumably backed by a temp dir handed over via
// `attach_temp_dir` — confirm in the builder.
106 pub fn temporary() -> Result<Self> {
108 Self::builder().temporary().open()
109 }
110
111 #[doc(hidden)]
113 pub fn new(paths: PglitePaths) -> Result<Self> {
114 let mut pg = PostgresMod::new(paths)?;
115 pg.ensure_cluster()?;
116 let transport = Transport::prepare(&mut pg)?;
117
118 let mut instance = Self {
119 pg,
120 _temp_dir: None,
121 transport,
122 parser: ProtocolParser::new(),
123 serializers: DEFAULT_SERIALIZERS.clone(),
124 parsers: DEFAULT_PARSERS.clone(),
125 array_types_initialized: false,
126 in_transaction: false,
127 ready: true,
128 closing: false,
129 closed: false,
130 blob_input_provided: false,
131 notify_listeners: HashMap::new(),
132 global_notify_listeners: Vec::new(),
133 next_listener_id: 1,
134 next_global_listener_id: 1,
135 };
136
137 instance.exec_internal("SET search_path TO public;", None)?;
138 instance.init_array_types(true)?;
139 Ok(instance)
140 }
141
/// Runs a single parameterized statement and returns its first result set.
///
/// Checks that the instance is usable, makes sure array types are registered,
/// then delegates to `query_internal`.
///
/// # Errors
/// Fails if the instance is closing, closed or not ready, or if the query fails.
142 pub fn query(
144 &mut self,
145 sql: &str,
146 params: &[Value],
147 options: Option<&QueryOptions>,
148 ) -> Result<Results> {
149 self.check_ready()?;
150 self.init_array_types(false)?;
151
152 self.query_internal(sql, params, options)
153 }
154
/// Runs `sql` with `params` through the extended-query message sequence.
///
/// Sends Parse, Describe (statement), Bind, Describe (portal) and Execute, each
/// through `exec_protocol`, then always sends Sync. All backend messages are
/// collected and handed to `finish_query`.
///
/// # Errors
/// A `DatabaseError` from the pipeline is wrapped in a `PgliteError` carrying
/// the SQL, a copy of the parameters and the options. Any other error gets a
/// context message naming the SQL.
155 fn query_internal(
156 &mut self,
157 sql: &str,
158 params: &[Value],
159 options: Option<&QueryOptions>,
160 ) -> Result<Results> {
161 let default_options = QueryOptions::default();
162 let query_opts = options.unwrap_or(&default_options);
163
164 self.handle_blob_input(query_opts.blob.as_ref())?;
165
// Copies kept so a database error can be enriched after the pipeline has run.
166 let params_snapshot: Vec<Value> = params.to_vec();
167 let options_snapshot = options.cloned();
168 let mut collected_messages: Vec<BackendMessage> = Vec::new();
169
170 let mut exec_opts = ExecProtocolOptions::no_sync();
171 exec_opts.on_notice = query_opts.on_notice.clone();
172 exec_opts.data_transfer_container = query_opts.data_transfer_container;
173
// The pipeline runs inside a closure so that `?` stops the pipeline without
// skipping the Sync that follows.
174 let result: Result<()> = (|| {
175 let param_types = if query_opts.param_types.is_empty() {
176 &[] as &[i32]
177 } else {
178 &query_opts.param_types
179 };
180
181 let parse_msg = Serialize::parse(None, sql, param_types);
182 let ExecProtocolResult { messages } =
183 self.exec_protocol(&parse_msg, exec_opts.clone())?;
184 collected_messages.extend(messages);
185
186 let describe_msg = Serialize::describe(&PortalTarget::new('S', None));
187 let ExecProtocolResult { messages } =
188 self.exec_protocol(&describe_msg, exec_opts.clone())?;
// Parameter type OIDs reported by the statement description drive serialization.
189 let data_type_ids = parse_describe_statement_results(&messages);
190 collected_messages.extend(messages);
191
192 let bind_values = self.prepare_bind_values(params, &data_type_ids, query_opts)?;
193 let bind_config = BindConfig {
194 values: bind_values,
195 ..Default::default()
196 };
197 let bind_msg = Serialize::bind(&bind_config);
198 let ExecProtocolResult { messages } =
199 self.exec_protocol(&bind_msg, exec_opts.clone())?;
200 collected_messages.extend(messages);
201
202 let describe_portal = Serialize::describe(&PortalTarget::new('P', None));
203 let ExecProtocolResult { messages } =
204 self.exec_protocol(&describe_portal, exec_opts.clone())?;
205 collected_messages.extend(messages);
206
207 let exec_msg = Serialize::execute(None);
208 let ExecProtocolResult { messages } =
209 self.exec_protocol(&exec_msg, exec_opts.clone())?;
210 collected_messages.extend(messages);
211
212 Ok(())
213 })();
214
// Sync is sent whether or not the pipeline succeeded. A Sync failure is only
// reported when the pipeline itself succeeded; otherwise the pipeline error wins.
215 match self.exec_protocol(&Serialize::sync(), exec_opts.clone()) {
216 Ok(ExecProtocolResult { messages }) => collected_messages.extend(messages),
217 Err(err) if result.is_ok() => {
218 return Err(err.context(format!("failed to synchronize extended query: {sql}")));
219 }
220 Err(_) => {}
221 }
222
223 if let Err(err) = result {
224 match err.downcast::<DatabaseError>() {
225 Ok(db_err) => {
226 let enriched = PgliteError::new(db_err, sql, params_snapshot, options_snapshot);
227 return Err(enriched.into());
228 }
229 Err(err) => {
230 return Err(err.context(format!("failed to execute extended query: {sql}")));
231 }
232 }
233 }
234
235 self.finish_query(collected_messages, options)
236 }
237
238 pub fn is_ready(&self) -> bool {
240 self.ready && !self.closing && !self.closed
241 }
242
/// Returns the `PglitePaths` held by the backend module.
243 #[doc(hidden)]
245 pub fn paths(&self) -> &PglitePaths {
246 self.pg.paths()
247 }
248
/// Stores `temp_dir` on the instance, replacing any previously stored one.
// NOTE(review): presumably this ties the directory's lifetime to the instance —
// confirm against the caller.
249 pub(crate) fn attach_temp_dir(&mut self, temp_dir: TempDir) {
250 self._temp_dir = Some(temp_dir);
251 }
252
/// Reports whether `close` has completed successfully on this instance.
253 pub fn is_closed(&self) -> bool {
255 self.closed
256 }
257
258 pub fn close(&mut self) -> Result<()> {
260 if self.closed {
261 return Ok(());
262 }
263 if self.closing {
264 bail!("Pglite is closing");
265 }
266
267 self.closing = true;
268 let result = {
269 let options = ExecProtocolOptions {
270 throw_on_error: false,
271 sync_to_fs: false,
272 ..ExecProtocolOptions::default()
273 };
274
275 let end_message = Serialize::end();
276 let _ = self.exec_protocol(&end_message, options);
277 self.sync_to_fs()
278 };
279
280 self.closing = false;
281 if result.is_ok() {
282 self.closed = true;
283 self.ready = false;
284 self.notify_listeners.clear();
285 self.global_notify_listeners.clear();
286 }
287 result
288 }
289
/// Runs `sql` via the simple-query path and returns every result set produced.
///
/// Checks that the instance is usable, makes sure array types are registered,
/// then delegates to `exec_internal`.
///
/// # Errors
/// Fails if the instance is closing, closed or not ready, or if execution fails.
290 pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
292 self.check_ready()?;
293 self.init_array_types(false)?;
294
295 self.exec_internal(sql, options)
296 }
297
/// Sends `sql` as a single simple-query message, then always sends Sync, and
/// passes all collected backend messages to `finish_exec`.
///
/// # Errors
/// A `DatabaseError` is wrapped in a `PgliteError` carrying the SQL, an empty
/// parameter list and the options. Any other error gets a context message
/// naming the SQL.
298 fn exec_internal(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
299 let options_snapshot = options.cloned();
300 let default_options = QueryOptions::default();
301 let exec_opts_ref = options.unwrap_or(&default_options);
302 let mut exec_opts = ExecProtocolOptions::no_sync();
303 exec_opts.on_notice = exec_opts_ref.on_notice.clone();
304 exec_opts.data_transfer_container = exec_opts_ref.data_transfer_container;
305
306 self.handle_blob_input(exec_opts_ref.blob.as_ref())?;
307
308 let mut collected_messages: Vec<BackendMessage> = Vec::new();
309
// The query runs inside a closure so that `?` does not skip the Sync below.
310 let result: Result<()> = (|| {
311 let message = Serialize::query(sql);
312 let ExecProtocolResult { messages } =
313 self.exec_protocol(&message, exec_opts.clone())?;
314 collected_messages.extend(messages);
315 Ok(())
316 })();
317
// Sync is sent whether or not the query succeeded. A Sync failure is only
// reported when the query itself succeeded; otherwise the query error wins.
318 match self.exec_protocol(&Serialize::sync(), exec_opts.clone()) {
319 Ok(ExecProtocolResult { messages }) => collected_messages.extend(messages),
320 Err(err) if result.is_ok() => {
321 return Err(err.context(format!("failed to synchronize simple query: {sql}")));
322 }
323 Err(_) => {}
324 }
325
326 if let Err(err) = result {
327 match err.downcast::<DatabaseError>() {
328 Ok(db_err) => {
329 let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
330 return Err(enriched.into());
331 }
332 Err(err) => {
333 return Err(err.context(format!("failed to execute simple query: {sql}")));
334 }
335 }
336 }
337
338 self.finish_exec(collected_messages, options)
339 }
340
341 pub fn listen<F>(&mut self, channel: &str, callback: F) -> Result<ListenerHandle>
343 where
344 F: Fn(&str) + Send + Sync + 'static,
345 {
346 self.check_ready()?;
347 self.init_array_types(false)?;
348
349 let normalized = to_postgres_name(channel);
350 let should_listen = match self.notify_listeners.get(&normalized) {
351 Some(existing) => existing.is_empty(),
352 None => true,
353 };
354
355 if should_listen {
356 self.exec_internal(&format!("LISTEN {}", channel), None)?;
357 }
358
359 let callback: ChannelCallback = Arc::new(callback);
360 let entry = self.notify_listeners.entry(normalized.clone()).or_default();
361 let id = self.next_listener_id;
362 self.next_listener_id = self.next_listener_id.wrapping_add(1);
363 entry.push(ChannelListener { id, callback });
364
365 Ok(ListenerHandle {
366 channel: channel.to_string(),
367 normalized_channel: normalized,
368 id,
369 })
370 }
371
372 pub fn unlisten(&mut self, handle: ListenerHandle) -> Result<()> {
374 if let Some(listeners) = self.notify_listeners.get_mut(&handle.normalized_channel) {
375 listeners.retain(|listener| listener.id != handle.id);
376 if listeners.is_empty() {
377 self.notify_listeners.remove(&handle.normalized_channel);
378 self.exec_internal(&format!("UNLISTEN {}", handle.channel), None)?;
379 }
380 }
381 Ok(())
382 }
383
384 pub fn unlisten_channel(&mut self, channel: &str) -> Result<()> {
386 let normalized = to_postgres_name(channel);
387 if self.notify_listeners.remove(&normalized).is_some() {
388 self.exec_internal(&format!("UNLISTEN {}", channel), None)?;
389 }
390 Ok(())
391 }
392
393 pub fn on_notification<F>(&mut self, callback: F) -> GlobalListenerHandle
395 where
396 F: Fn(&str, &str) + Send + Sync + 'static,
397 {
398 let id = self.next_global_listener_id;
399 self.next_global_listener_id = self.next_global_listener_id.wrapping_add(1);
400 let callback: GlobalCallback = Arc::new(callback);
401 self.global_notify_listeners
402 .push(GlobalListener { id, callback });
403 GlobalListenerHandle { id }
404 }
405
406 pub fn off_notification(&mut self, handle: GlobalListenerHandle) {
408 self.global_notify_listeners
409 .retain(|listener| listener.id != handle.id);
410 }
411
/// Describes `sql` without executing it.
///
/// Sends Parse and Describe (statement), then always Sync. The result pairs
/// each parameter type OID with the instance serializer registered for it, and
/// each field of the first `RowDescription` message with the instance parser
/// registered for its type (empty when there is no row description).
///
/// # Errors
/// Fails if the instance is not usable. A `DatabaseError` is wrapped in a
/// `PgliteError` carrying the SQL and options; any other error gets a context
/// message naming the SQL.
412 pub fn describe_query(
414 &mut self,
415 sql: &str,
416 options: Option<&QueryOptions>,
417 ) -> Result<DescribeQueryResult> {
418 self.check_ready()?;
419 self.init_array_types(false)?;
420
421 let default_options = QueryOptions::default();
422 let query_opts = options.unwrap_or(&default_options);
423
424 let options_snapshot = options.cloned();
425 let mut exec_opts = ExecProtocolOptions::no_sync();
426 exec_opts.on_notice = query_opts.on_notice.clone();
427 exec_opts.data_transfer_container = query_opts.data_transfer_container;
428
429 let mut describe_messages: Vec<BackendMessage> = Vec::new();
430
// Parse + Describe run inside a closure so that `?` does not skip the Sync below.
431 let result: Result<()> = (|| {
432 let param_types = if query_opts.param_types.is_empty() {
433 &[] as &[i32]
434 } else {
435 &query_opts.param_types
436 };
437
438 let parse_msg = Serialize::parse(None, sql, param_types);
// The messages returned for Parse are not needed for the description.
439 let _ = self.exec_protocol(&parse_msg, exec_opts.clone())?;
441
442 let describe_msg = Serialize::describe(&PortalTarget::new('S', None));
443 let ExecProtocolResult { messages } =
444 self.exec_protocol(&describe_msg, exec_opts.clone())?;
445 describe_messages.extend(messages);
446
447 Ok(())
448 })();
449
// A Sync failure is only reported when Parse/Describe succeeded.
450 match self.exec_protocol(&Serialize::sync(), exec_opts.clone()) {
451 Ok(ExecProtocolResult { messages }) => describe_messages.extend(messages),
452 Err(err) if result.is_ok() => {
453 return Err(err.context(format!("failed to synchronize describe query: {sql}")));
454 }
455 Err(_) => {}
456 }
457
458 if let Err(err) = result {
459 match err.downcast::<DatabaseError>() {
460 Ok(db_err) => {
461 let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
462 return Err(enriched.into());
463 }
464 Err(err) => {
465 return Err(err.context(format!("failed to describe query: {sql}")));
466 }
467 }
468 }
469
470 let param_type_ids = parse_describe_statement_results(&describe_messages);
471 let query_params = param_type_ids
472 .into_iter()
473 .map(|oid| DescribeQueryParam {
474 data_type_id: oid,
475 serializer: self.serializers.get(&oid).cloned(),
476 })
477 .collect();
478
479 let result_fields = describe_messages
480 .iter()
481 .find_map(|msg| match msg {
482 BackendMessage::RowDescription(desc) => Some(
483 desc.fields
484 .iter()
485 .map(|field| DescribeResultField {
486 name: field.name.clone(),
487 data_type_id: field.data_type_id,
488 parser: self.parsers.get(&field.data_type_id).cloned(),
489 })
490 .collect::<Vec<_>>(),
491 ),
492 _ => None,
493 })
494 .unwrap_or_default();
495
496 Ok(DescribeQueryResult {
497 query_params,
498 result_fields,
499 })
500 }
501
502 pub fn transaction<F, T>(&mut self, mut callback: F) -> Result<T>
504 where
505 F: FnMut(&mut Transaction<'_>) -> Result<T>,
506 {
507 self.check_ready()?;
508 self.init_array_types(false)?;
509
510 self.run_exec_command("BEGIN")?;
512 self.in_transaction = true;
513
514 let mut tx = Transaction::new(self);
515 let callback_result = callback(&mut tx);
516
517 let txn_result = match callback_result {
518 Ok(value) => {
519 if !tx.closed {
520 tx.commit_internal()?;
521 }
522 Ok(value)
523 }
524 Err(err) => {
525 if !tx.closed {
526 tx.rollback_internal()?;
527 }
528 Err(err)
529 }
530 };
531
532 self.in_transaction = false;
533 txn_result
534 }
535
536 pub fn sync_to_fs(&mut self) -> Result<()> {
538 let mount_root = self.pg.paths().mount_root();
539 if let Ok(file) = std::fs::OpenOptions::new().read(true).open(mount_root) {
540 let _ = file.sync_all();
541 }
542 let data_root = mount_root.join("pglite");
543 if let Ok(file) = std::fs::OpenOptions::new().read(true).open(&data_root) {
544 let _ = file.sync_all();
545 }
546 Ok(())
547 }
548
549 fn prepare_bind_values(
550 &self,
551 params: &[Value],
552 data_type_ids: &[i32],
553 options: &QueryOptions,
554 ) -> Result<Vec<BindValue>> {
555 if params.is_empty() {
556 return Ok(Vec::new());
557 }
558
559 let mut values = Vec::with_capacity(params.len());
560 let overrides = if options.serializers.is_empty() {
561 None
562 } else {
563 Some(&options.serializers)
564 };
565
566 for (idx, value) in params.iter().enumerate() {
567 if value.is_null() {
568 values.push(BindValue::Null);
569 continue;
570 }
571
572 let oid = data_type_ids.get(idx).copied().unwrap_or(TEXT);
573 let serializer = overrides
574 .and_then(|map| map.get(&oid))
575 .or_else(|| self.serializers.get(&oid));
576
577 let serialized = match serializer {
578 Some(func) => func(value).with_context(|| {
579 format!("failed to serialize parameter {idx} using OID {oid}")
580 })?,
581 None => self.default_serialize_value(value),
582 };
583
584 values.push(BindValue::Text(serialized));
585 }
586
587 Ok(values)
588 }
589
/// Method form of `default_serialize_value_static`; it does not use `self`.
590 fn default_serialize_value(&self, value: &Value) -> String {
591 Self::default_serialize_value_static(value)
592 }
593
594 pub(crate) fn default_serialize_value_static(value: &Value) -> String {
595 match value {
596 Value::String(s) => s.clone(),
597 Value::Number(num) => num.to_string(),
598 Value::Bool(flag) => {
599 if *flag {
600 "t".to_string()
601 } else {
602 "f".to_string()
603 }
604 }
605 _ => value.to_string(),
606 }
607 }
608
609 fn finish_query(
610 &mut self,
611 messages: Vec<BackendMessage>,
612 options: Option<&QueryOptions>,
613 ) -> Result<Results> {
614 let blob = self.get_written_blob()?;
615 self.cleanup_blob()?;
616 if !self.in_transaction {
617 self.sync_to_fs()?;
618 }
619 let parsed = parse_results(&messages, &self.parsers, options, blob);
620 parsed
621 .into_iter()
622 .next()
623 .ok_or_else(|| anyhow!("query returned no result sets"))
624 }
625
626 fn finish_exec(
627 &mut self,
628 messages: Vec<BackendMessage>,
629 options: Option<&QueryOptions>,
630 ) -> Result<Vec<Results>> {
631 let blob = self.get_written_blob()?;
632 self.cleanup_blob()?;
633 if !self.in_transaction {
634 self.sync_to_fs()?;
635 }
636 Ok(parse_results(&messages, &self.parsers, options, blob))
637 }
638
639 fn exec_protocol(
640 &mut self,
641 message: &[u8],
642 options: ExecProtocolOptions,
643 ) -> Result<ExecProtocolResult> {
644 let ExecProtocolOptions {
645 sync_to_fs,
646 throw_on_error,
647 on_notice,
648 data_transfer_container,
649 } = options;
650
651 let data = self.exec_protocol_raw(message, sync_to_fs, data_transfer_container)?;
652
653 let mut messages = Vec::new();
654 let on_notice_cb = on_notice.clone();
655 if let Err(err) = self.parser.parse(&data, |msg| {
656 if let BackendMessage::Error(db_err) = &msg
657 && throw_on_error
658 {
659 return Err(anyhow!(db_err.clone()));
660 }
661 if let Some(callback) = on_notice_cb.as_ref()
662 && let BackendMessage::Notice(notice) = &msg
663 {
664 callback(notice);
665 }
666 messages.push(msg);
667 Ok(())
668 }) {
669 match err.downcast::<DatabaseError>() {
670 Ok(db_err) => {
671 self.parser = ProtocolParser::new();
672 return Err(anyhow!(db_err));
673 }
674 Err(err) => return Err(err),
675 }
676 }
677
678 for message in &messages {
679 if let BackendMessage::Notification(note) = message {
680 let key = to_postgres_name(¬e.channel);
681 if let Some(listeners) = self.notify_listeners.get(&key) {
682 for listener in listeners {
683 (listener.callback)(¬e.payload);
684 }
685 }
686 for listener in &self.global_notify_listeners {
687 (listener.callback)(¬e.channel, ¬e.payload);
688 }
689 }
690 }
691
692 Ok(ExecProtocolResult { messages })
693 }
694
695 fn exec_protocol_raw(
696 &mut self,
697 message: &[u8],
698 sync_to_fs: bool,
699 data_transfer_container: Option<DataTransferContainer>,
700 ) -> Result<Vec<u8>> {
701 let data = self
702 .transport
703 .send(&mut self.pg, message, data_transfer_container)?;
704 if sync_to_fs {
705 self.sync_to_fs()?;
706 }
707 Ok(data)
708 }
709
710 fn init_array_types(&mut self, force: bool) -> Result<()> {
711 if self.array_types_initialized && !force {
712 return Ok(());
713 }
714
715 let prev = self.array_types_initialized;
716 self.array_types_initialized = true;
717
718 let result: Result<()> = {
719 let sql = "
720 SELECT b.oid, b.typarray
721 FROM pg_catalog.pg_type a
722 LEFT JOIN pg_catalog.pg_type b ON b.oid = a.typelem
723 WHERE a.typcategory = 'A'
724 GROUP BY b.oid, b.typarray
725 ORDER BY b.oid
726 ";
727 let results = self.exec(sql, None)?;
728 let result_set = results
729 .into_iter()
730 .next()
731 .ok_or_else(|| anyhow!("array type discovery returned no results"))?;
732
733 for row in result_set.rows {
734 let map = match row {
735 Value::Object(map) => map,
736 _ => continue,
737 };
738 let element_oid = value_to_i32(map.get("oid")).unwrap_or(0);
739 let array_oid = value_to_i32(map.get("typarray")).unwrap_or(0);
740
741 if element_oid == 0 || array_oid == 0 {
742 continue;
743 }
744
745 let element_parser = self.parsers.get(&element_oid).cloned();
746 let element_serializer = self.serializers.get(&element_oid).cloned();
747
748 let parser_clone = element_parser.clone();
749 let array_parser: TypeParser = Arc::new(move |text: &str, _| {
750 parse_array_text(text, parser_clone.clone(), element_oid, array_oid)
751 });
752 self.parsers.insert(array_oid, array_parser);
753
754 let serializer_clone = element_serializer.clone();
755 let array_serializer: Serializer = Arc::new(move |value: &Value| {
756 serialize_array_value(value, serializer_clone.clone(), array_oid)
757 });
758 self.serializers.insert(array_oid, array_serializer);
759 }
760 Ok(())
761 };
762
763 if let Err(err) = result {
764 self.array_types_initialized = prev;
765 Err(err)
766 } else {
767 Ok(())
768 }
769 }
770
771 fn run_exec_command(&mut self, sql: &str) -> Result<()> {
772 self.exec_internal(sql, None).map(|_| ())
773 }
774
775 fn handle_blob_input(&mut self, blob: Option<&Vec<u8>>) -> Result<()> {
776 let path = self.dev_blob_path();
777 if let Some(bytes) = blob {
778 if let Some(parent) = path.parent() {
779 fs::create_dir_all(parent).with_context(|| {
780 format!("failed to create blob directory {}", parent.display())
781 })?;
782 }
783 fs::write(&path, bytes)
784 .with_context(|| format!("write blob input to {}", path.display()))?;
785 self.blob_input_provided = true;
786 } else {
787 self.blob_input_provided = false;
788 let _ = fs::remove_file(&path);
789 }
790 Ok(())
791 }
792
/// Returns the blob exchange path: `dev/blob` under `pgroot`.
793 fn dev_blob_path(&self) -> PathBuf {
794 self.pg.paths().pgroot.join("dev/blob")
795 }
796
/// Currently a no-op that always returns `Ok(())`; the blob file itself is
/// removed in `get_written_blob` and `handle_blob_input`.
797 fn cleanup_blob(&mut self) -> Result<()> {
798 Ok(())
799 }
800
801 fn get_written_blob(&mut self) -> Result<Option<Vec<u8>>> {
802 let path = self.dev_blob_path();
803
804 if self.blob_input_provided {
805 self.blob_input_provided = false;
806 let _ = fs::remove_file(&path);
807 return Ok(None);
808 }
809
810 match fs::read(&path) {
811 Ok(data) => {
812 self.blob_input_provided = false;
813 let _ = fs::remove_file(&path);
814 if data.is_empty() {
815 Ok(None)
816 } else {
817 Ok(Some(data))
818 }
819 }
820 Err(err) => {
821 if err.kind() == io::ErrorKind::NotFound {
822 self.blob_input_provided = false;
823 Ok(None)
824 } else {
825 Err(err).with_context(|| format!("read blob output from {}", path.display()))
826 }
827 }
828 }
829 }
830
831 fn check_ready(&self) -> Result<()> {
832 if self.closing {
833 bail!("Pglite instance is closing");
834 }
835 if self.closed {
836 bail!("Pglite instance is closed");
837 }
838 if !self.ready {
839 bail!("Pglite instance is not ready");
840 }
841 Ok(())
842 }
843}
844
845impl Drop for Pglite {
846 fn drop(&mut self) {
847 if !self.closed {
848 let _ = self.close();
849 }
850 }
851}
852
/// Normalizes a channel name as written by the caller.
///
/// A name wrapped in double quotes is taken literally with the outer quotes
/// removed. Any other name is lowercased.
fn to_postgres_name(input: &str) -> String {
    let quoted_inner = input
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match quoted_inner {
        Some(inner) => inner.to_owned(),
        None => input.to_lowercase(),
    }
}
860
861fn value_to_i32(value: Option<&Value>) -> Option<i32> {
862 match value? {
863 Value::Number(number) => number.as_i64().map(|value| value as i32),
864 Value::String(string) => string.parse::<i32>().ok(),
865 _ => None,
866 }
867}
868
/// Handle passed to the `Pglite::transaction` callback for running statements
/// inside the open transaction.
869pub struct Transaction<'a> {
/// Instance on which the statements are executed.
871 client: &'a mut Pglite,
/// True once a `COMMIT` or `ROLLBACK` issued through this handle has succeeded.
872 closed: bool,
873}
874
875impl<'a> Transaction<'a> {
876 fn new(client: &'a mut Pglite) -> Self {
877 Self {
878 client,
879 closed: false,
880 }
881 }
882
883 fn commit_internal(&mut self) -> Result<()> {
884 self.ensure_open()?;
885 self.client.exec_internal("COMMIT", None)?;
886 self.closed = true;
887 Ok(())
888 }
889
890 fn rollback_internal(&mut self) -> Result<()> {
891 self.ensure_open()?;
892 self.client.exec_internal("ROLLBACK", None)?;
893 self.closed = true;
894 Ok(())
895 }
896
897 fn ensure_open(&self) -> Result<()> {
898 if self.closed {
899 bail!("transaction is already closed");
900 }
901 Ok(())
902 }
903
904 pub fn query(
905 &mut self,
906 sql: &str,
907 params: &[Value],
908 options: Option<&QueryOptions>,
909 ) -> Result<Results> {
910 self.ensure_open()?;
911 self.client.query_internal(sql, params, options)
912 }
913
914 pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
915 self.ensure_open()?;
916 self.client.exec_internal(sql, options)
917 }
918
919 pub fn commit(&mut self) -> Result<()> {
920 self.commit_internal()
921 }
922
923 pub fn rollback(&mut self) -> Result<()> {
924 self.rollback_internal()
925 }
926
927 pub fn is_closed(&self) -> bool {
928 self.closed
929 }
930
931 pub fn closed(&self) -> bool {
932 self.closed
933 }
934}