Skip to main content

pglite_oxide/pglite/
client.rs

1use anyhow::{Context, Result, anyhow, bail};
2use serde_json::Value;
3use std::collections::HashMap;
4use std::fs;
5use std::io;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9use tempfile::TempDir;
10
11use crate::pglite::base::PglitePaths;
12use crate::pglite::builder::PgliteBuilder;
13use crate::pglite::errors::PgliteError;
14use crate::pglite::interface::{
15    DataTransferContainer, DescribeQueryParam, DescribeQueryResult, DescribeResultField,
16    ExecProtocolOptions, ExecProtocolResult, ParserMap, QueryOptions, Results, Serializer,
17    SerializerMap, TypeParser,
18};
19use crate::pglite::parse::{parse_describe_statement_results, parse_results};
20use crate::pglite::postgres_mod::PostgresMod;
21use crate::pglite::transport::Transport;
22use crate::pglite::types::{
23    DEFAULT_PARSERS, DEFAULT_SERIALIZERS, TEXT, parse_array_text, serialize_array_value,
24};
25use crate::protocol::messages::{BackendMessage, DatabaseError};
26use crate::protocol::parser::Parser as ProtocolParser;
27use crate::protocol::serializer::{BindConfig, BindValue, PortalTarget, Serialize};
28
/// Callback registered for a single channel; it is invoked with the notification payload.
type ChannelCallback = Arc<dyn Fn(&str) + Send + Sync + 'static>;
/// Callback registered for every channel; it is invoked with `(channel, payload)`.
type GlobalCallback = Arc<dyn Fn(&str, &str) + Send + Sync + 'static>;
31
/// Handle returned by [`Pglite::listen`]; pass it to [`Pglite::unlisten`] to remove
/// the listener it identifies.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ListenerHandle {
    // Channel name exactly as the caller supplied it (reused verbatim in `UNLISTEN`).
    channel: String,
    // Result of `to_postgres_name(channel)`; the key into `Pglite::notify_listeners`.
    normalized_channel: String,
    // Identifier taken from `Pglite::next_listener_id` at registration time.
    id: u64,
}
38
/// Handle returned by [`Pglite::on_notification`]; pass it to
/// [`Pglite::off_notification`] to deregister the callback.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GlobalListenerHandle {
    // Identifier taken from `Pglite::next_global_listener_id` at registration time.
    id: u64,
}
43
44impl ListenerHandle {
45    pub fn channel(&self) -> &str {
46        &self.channel
47    }
48
49    pub fn id(&self) -> u64 {
50        self.id
51    }
52}
53
54impl GlobalListenerHandle {
55    pub fn id(&self) -> u64 {
56        self.id
57    }
58}
59
/// A per-channel listener stored in `Pglite::notify_listeners`.
struct ChannelListener {
    // Matches `ListenerHandle::id`; used by `unlisten` to find this entry.
    id: u64,
    // Invoked with the payload of each notification dispatched to the channel.
    callback: ChannelCallback,
}
64
/// A listener stored in `Pglite::global_notify_listeners`, called for every notification.
struct GlobalListener {
    // Matches `GlobalListenerHandle::id`; used by `off_notification` to find this entry.
    id: u64,
    // Invoked with `(channel, payload)` of each notification.
    callback: GlobalCallback,
}
69
/// Primary entry point for interacting with the embedded Postgres runtime.
///
/// An instance is closed explicitly with [`Pglite::close`] or implicitly on drop.
pub struct Pglite {
    // Runtime wrapper; also the source of the host-side paths (`pg.paths()`).
    pg: PostgresMod,
    // Set via `attach_temp_dir`; held only so the directory lives as long as the instance.
    _temp_dir: Option<TempDir>,
    // Sends raw protocol messages to `pg` and returns the raw response bytes.
    transport: Transport,
    // Decodes response bytes into `BackendMessage`s; replaced after a database error.
    parser: ProtocolParser,
    // Parameter serializers keyed by type OID (defaults plus discovered array types).
    serializers: SerializerMap,
    // Result parsers keyed by type OID (defaults plus discovered array types).
    parsers: ParserMap,
    // True once `init_array_types` has run (or while it is running).
    array_types_initialized: bool,
    // While true, `finish_query`/`finish_exec` skip the per-statement `sync_to_fs`.
    in_transaction: bool,
    // Lifecycle flags consulted by `check_ready` / `is_ready`.
    ready: bool,
    closing: bool,
    closed: bool,
    // True when the current statement was given an input blob via `QueryOptions::blob`.
    blob_input_provided: bool,
    // Per-channel listeners keyed by the normalized channel name.
    notify_listeners: HashMap<String, Vec<ChannelListener>>,
    // Listeners invoked for every notification regardless of channel.
    global_notify_listeners: Vec<GlobalListener>,
    // Next identifiers to hand out; both advance with wrapping addition.
    next_listener_id: u64,
    next_global_listener_id: u64,
}
89
90impl Pglite {
91    /// Create a builder for opening persistent or temporary PGlite databases.
92    pub fn builder() -> PgliteBuilder {
93        PgliteBuilder::new()
94    }
95
96    /// Open a persistent PGlite database rooted at `root`, installing and initializing it if needed.
97    pub fn open(root: impl AsRef<Path>) -> Result<Self> {
98        Self::builder().path(root.as_ref().to_path_buf()).open()
99    }
100
101    /// Open a persistent PGlite database under the platform data directory for `app_id`.
102    pub fn open_app(app_id: (&str, &str, &str)) -> Result<Self> {
103        Self::builder().app_id(app_id).open()
104    }
105
106    /// Create an ephemeral PGlite database whose files are removed when the instance is dropped.
107    pub fn temporary() -> Result<Self> {
108        Self::builder().temporary().open()
109    }
110
111    /// Create a new Pglite instance backed by the provided runtime paths.
112    #[doc(hidden)]
113    pub fn new(paths: PglitePaths) -> Result<Self> {
114        let mut pg = PostgresMod::new(paths)?;
115        pg.ensure_cluster()?;
116        let transport = Transport::prepare(&mut pg)?;
117
118        let mut instance = Self {
119            pg,
120            _temp_dir: None,
121            transport,
122            parser: ProtocolParser::new(),
123            serializers: DEFAULT_SERIALIZERS.clone(),
124            parsers: DEFAULT_PARSERS.clone(),
125            array_types_initialized: false,
126            in_transaction: false,
127            ready: true,
128            closing: false,
129            closed: false,
130            blob_input_provided: false,
131            notify_listeners: HashMap::new(),
132            global_notify_listeners: Vec::new(),
133            next_listener_id: 1,
134            next_global_listener_id: 1,
135        };
136
137        instance.exec_internal("SET search_path TO public;", None)?;
138        instance.init_array_types(true)?;
139        Ok(instance)
140    }
141
142    /// Execute a SQL query using the extended protocol.
143    pub fn query(
144        &mut self,
145        sql: &str,
146        params: &[Value],
147        options: Option<&QueryOptions>,
148    ) -> Result<Results> {
149        self.check_ready()?;
150        self.init_array_types(false)?;
151
152        self.query_internal(sql, params, options)
153    }
154
155    fn query_internal(
156        &mut self,
157        sql: &str,
158        params: &[Value],
159        options: Option<&QueryOptions>,
160    ) -> Result<Results> {
161        let default_options = QueryOptions::default();
162        let query_opts = options.unwrap_or(&default_options);
163
164        self.handle_blob_input(query_opts.blob.as_ref())?;
165
166        let params_snapshot: Vec<Value> = params.to_vec();
167        let options_snapshot = options.cloned();
168        let mut collected_messages: Vec<BackendMessage> = Vec::new();
169
170        let mut exec_opts = ExecProtocolOptions::no_sync();
171        exec_opts.on_notice = query_opts.on_notice.clone();
172        exec_opts.data_transfer_container = query_opts.data_transfer_container;
173
174        let result: Result<()> = (|| {
175            let param_types = if query_opts.param_types.is_empty() {
176                &[] as &[i32]
177            } else {
178                &query_opts.param_types
179            };
180
181            let parse_msg = Serialize::parse(None, sql, param_types);
182            let ExecProtocolResult { messages } =
183                self.exec_protocol(&parse_msg, exec_opts.clone())?;
184            collected_messages.extend(messages);
185
186            let describe_msg = Serialize::describe(&PortalTarget::new('S', None));
187            let ExecProtocolResult { messages } =
188                self.exec_protocol(&describe_msg, exec_opts.clone())?;
189            let data_type_ids = parse_describe_statement_results(&messages);
190            collected_messages.extend(messages);
191
192            let bind_values = self.prepare_bind_values(params, &data_type_ids, query_opts)?;
193            let bind_config = BindConfig {
194                values: bind_values,
195                ..Default::default()
196            };
197            let bind_msg = Serialize::bind(&bind_config);
198            let ExecProtocolResult { messages } =
199                self.exec_protocol(&bind_msg, exec_opts.clone())?;
200            collected_messages.extend(messages);
201
202            let describe_portal = Serialize::describe(&PortalTarget::new('P', None));
203            let ExecProtocolResult { messages } =
204                self.exec_protocol(&describe_portal, exec_opts.clone())?;
205            collected_messages.extend(messages);
206
207            let exec_msg = Serialize::execute(None);
208            let ExecProtocolResult { messages } =
209                self.exec_protocol(&exec_msg, exec_opts.clone())?;
210            collected_messages.extend(messages);
211
212            Ok(())
213        })();
214
215        match self.exec_protocol(&Serialize::sync(), exec_opts.clone()) {
216            Ok(ExecProtocolResult { messages }) => collected_messages.extend(messages),
217            Err(err) if result.is_ok() => {
218                return Err(err.context(format!("failed to synchronize extended query: {sql}")));
219            }
220            Err(_) => {}
221        }
222
223        if let Err(err) = result {
224            match err.downcast::<DatabaseError>() {
225                Ok(db_err) => {
226                    let enriched = PgliteError::new(db_err, sql, params_snapshot, options_snapshot);
227                    return Err(enriched.into());
228                }
229                Err(err) => {
230                    return Err(err.context(format!("failed to execute extended query: {sql}")));
231                }
232            }
233        }
234
235        self.finish_query(collected_messages, options)
236    }
237
238    /// Return `true` if the instance is ready for new work.
239    pub fn is_ready(&self) -> bool {
240        self.ready && !self.closing && !self.closed
241    }
242
243    /// Return the host-side runtime and data-directory paths backing this instance.
244    #[doc(hidden)]
245    pub fn paths(&self) -> &PglitePaths {
246        self.pg.paths()
247    }
248
249    pub(crate) fn attach_temp_dir(&mut self, temp_dir: TempDir) {
250        self._temp_dir = Some(temp_dir);
251    }
252
253    /// Return `true` if the instance has already been closed.
254    pub fn is_closed(&self) -> bool {
255        self.closed
256    }
257
258    /// Shut down the embedded Postgres runtime.
259    pub fn close(&mut self) -> Result<()> {
260        if self.closed {
261            return Ok(());
262        }
263        if self.closing {
264            bail!("Pglite is closing");
265        }
266
267        self.closing = true;
268        let result = {
269            let options = ExecProtocolOptions {
270                throw_on_error: false,
271                sync_to_fs: false,
272                ..ExecProtocolOptions::default()
273            };
274
275            let end_message = Serialize::end();
276            let _ = self.exec_protocol(&end_message, options);
277            self.sync_to_fs()
278        };
279
280        self.closing = false;
281        if result.is_ok() {
282            self.closed = true;
283            self.ready = false;
284            self.notify_listeners.clear();
285            self.global_notify_listeners.clear();
286        }
287        result
288    }
289
290    /// Execute a simple SQL statement that may contain multiple commands.
291    pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
292        self.check_ready()?;
293        self.init_array_types(false)?;
294
295        self.exec_internal(sql, options)
296    }
297
298    fn exec_internal(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
299        let options_snapshot = options.cloned();
300        let default_options = QueryOptions::default();
301        let exec_opts_ref = options.unwrap_or(&default_options);
302        let mut exec_opts = ExecProtocolOptions::no_sync();
303        exec_opts.on_notice = exec_opts_ref.on_notice.clone();
304        exec_opts.data_transfer_container = exec_opts_ref.data_transfer_container;
305
306        self.handle_blob_input(exec_opts_ref.blob.as_ref())?;
307
308        let mut collected_messages: Vec<BackendMessage> = Vec::new();
309
310        let result: Result<()> = (|| {
311            let message = Serialize::query(sql);
312            let ExecProtocolResult { messages } =
313                self.exec_protocol(&message, exec_opts.clone())?;
314            collected_messages.extend(messages);
315            Ok(())
316        })();
317
318        match self.exec_protocol(&Serialize::sync(), exec_opts.clone()) {
319            Ok(ExecProtocolResult { messages }) => collected_messages.extend(messages),
320            Err(err) if result.is_ok() => {
321                return Err(err.context(format!("failed to synchronize simple query: {sql}")));
322            }
323            Err(_) => {}
324        }
325
326        if let Err(err) = result {
327            match err.downcast::<DatabaseError>() {
328                Ok(db_err) => {
329                    let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
330                    return Err(enriched.into());
331                }
332                Err(err) => {
333                    return Err(err.context(format!("failed to execute simple query: {sql}")));
334                }
335            }
336        }
337
338        self.finish_exec(collected_messages, options)
339    }
340
341    /// Register a listener for `LISTEN channel`. Returns a handle that can be used to unlisten.
342    pub fn listen<F>(&mut self, channel: &str, callback: F) -> Result<ListenerHandle>
343    where
344        F: Fn(&str) + Send + Sync + 'static,
345    {
346        self.check_ready()?;
347        self.init_array_types(false)?;
348
349        let normalized = to_postgres_name(channel);
350        let should_listen = match self.notify_listeners.get(&normalized) {
351            Some(existing) => existing.is_empty(),
352            None => true,
353        };
354
355        if should_listen {
356            self.exec_internal(&format!("LISTEN {}", channel), None)?;
357        }
358
359        let callback: ChannelCallback = Arc::new(callback);
360        let entry = self.notify_listeners.entry(normalized.clone()).or_default();
361        let id = self.next_listener_id;
362        self.next_listener_id = self.next_listener_id.wrapping_add(1);
363        entry.push(ChannelListener { id, callback });
364
365        Ok(ListenerHandle {
366            channel: channel.to_string(),
367            normalized_channel: normalized,
368            id,
369        })
370    }
371
372    /// Remove a listener corresponding to the provided handle.
373    pub fn unlisten(&mut self, handle: ListenerHandle) -> Result<()> {
374        if let Some(listeners) = self.notify_listeners.get_mut(&handle.normalized_channel) {
375            listeners.retain(|listener| listener.id != handle.id);
376            if listeners.is_empty() {
377                self.notify_listeners.remove(&handle.normalized_channel);
378                self.exec_internal(&format!("UNLISTEN {}", handle.channel), None)?;
379            }
380        }
381        Ok(())
382    }
383
384    /// Remove all listeners for the specified channel.
385    pub fn unlisten_channel(&mut self, channel: &str) -> Result<()> {
386        let normalized = to_postgres_name(channel);
387        if self.notify_listeners.remove(&normalized).is_some() {
388            self.exec_internal(&format!("UNLISTEN {}", channel), None)?;
389        }
390        Ok(())
391    }
392
393    /// Register a global notification callback.
394    pub fn on_notification<F>(&mut self, callback: F) -> GlobalListenerHandle
395    where
396        F: Fn(&str, &str) + Send + Sync + 'static,
397    {
398        let id = self.next_global_listener_id;
399        self.next_global_listener_id = self.next_global_listener_id.wrapping_add(1);
400        let callback: GlobalCallback = Arc::new(callback);
401        self.global_notify_listeners
402            .push(GlobalListener { id, callback });
403        GlobalListenerHandle { id }
404    }
405
406    /// Deregister a previously registered global notification callback.
407    pub fn off_notification(&mut self, handle: GlobalListenerHandle) {
408        self.global_notify_listeners
409            .retain(|listener| listener.id != handle.id);
410    }
411
412    /// Describe the parameter and result metadata for a SQL query.
413    pub fn describe_query(
414        &mut self,
415        sql: &str,
416        options: Option<&QueryOptions>,
417    ) -> Result<DescribeQueryResult> {
418        self.check_ready()?;
419        self.init_array_types(false)?;
420
421        let default_options = QueryOptions::default();
422        let query_opts = options.unwrap_or(&default_options);
423
424        let options_snapshot = options.cloned();
425        let mut exec_opts = ExecProtocolOptions::no_sync();
426        exec_opts.on_notice = query_opts.on_notice.clone();
427        exec_opts.data_transfer_container = query_opts.data_transfer_container;
428
429        let mut describe_messages: Vec<BackendMessage> = Vec::new();
430
431        let result: Result<()> = (|| {
432            let param_types = if query_opts.param_types.is_empty() {
433                &[] as &[i32]
434            } else {
435                &query_opts.param_types
436            };
437
438            let parse_msg = Serialize::parse(None, sql, param_types);
439            // Ignore returned messages; we just need to ensure the statement parses.
440            let _ = self.exec_protocol(&parse_msg, exec_opts.clone())?;
441
442            let describe_msg = Serialize::describe(&PortalTarget::new('S', None));
443            let ExecProtocolResult { messages } =
444                self.exec_protocol(&describe_msg, exec_opts.clone())?;
445            describe_messages.extend(messages);
446
447            Ok(())
448        })();
449
450        match self.exec_protocol(&Serialize::sync(), exec_opts.clone()) {
451            Ok(ExecProtocolResult { messages }) => describe_messages.extend(messages),
452            Err(err) if result.is_ok() => {
453                return Err(err.context(format!("failed to synchronize describe query: {sql}")));
454            }
455            Err(_) => {}
456        }
457
458        if let Err(err) = result {
459            match err.downcast::<DatabaseError>() {
460                Ok(db_err) => {
461                    let enriched = PgliteError::new(db_err, sql, Vec::new(), options_snapshot);
462                    return Err(enriched.into());
463                }
464                Err(err) => {
465                    return Err(err.context(format!("failed to describe query: {sql}")));
466                }
467            }
468        }
469
470        let param_type_ids = parse_describe_statement_results(&describe_messages);
471        let query_params = param_type_ids
472            .into_iter()
473            .map(|oid| DescribeQueryParam {
474                data_type_id: oid,
475                serializer: self.serializers.get(&oid).cloned(),
476            })
477            .collect();
478
479        let result_fields = describe_messages
480            .iter()
481            .find_map(|msg| match msg {
482                BackendMessage::RowDescription(desc) => Some(
483                    desc.fields
484                        .iter()
485                        .map(|field| DescribeResultField {
486                            name: field.name.clone(),
487                            data_type_id: field.data_type_id,
488                            parser: self.parsers.get(&field.data_type_id).cloned(),
489                        })
490                        .collect::<Vec<_>>(),
491                ),
492                _ => None,
493            })
494            .unwrap_or_default();
495
496        Ok(DescribeQueryResult {
497            query_params,
498            result_fields,
499        })
500    }
501
502    /// Run a closure within an SQL transaction (`BEGIN .. COMMIT/ROLLBACK`).
503    pub fn transaction<F, T>(&mut self, mut callback: F) -> Result<T>
504    where
505        F: FnMut(&mut Transaction<'_>) -> Result<T>,
506    {
507        self.check_ready()?;
508        self.init_array_types(false)?;
509
510        // Begin transaction
511        self.run_exec_command("BEGIN")?;
512        self.in_transaction = true;
513
514        let mut tx = Transaction::new(self);
515        let callback_result = callback(&mut tx);
516
517        let txn_result = match callback_result {
518            Ok(value) => {
519                if !tx.closed {
520                    tx.commit_internal()?;
521                }
522                Ok(value)
523            }
524            Err(err) => {
525                if !tx.closed {
526                    tx.rollback_internal()?;
527                }
528                Err(err)
529            }
530        };
531
532        self.in_transaction = false;
533        txn_result
534    }
535
536    /// Flush runtime writes to the underlying filesystem. Currently a no-op on the host.
537    pub fn sync_to_fs(&mut self) -> Result<()> {
538        let mount_root = self.pg.paths().mount_root();
539        if let Ok(file) = std::fs::OpenOptions::new().read(true).open(mount_root) {
540            let _ = file.sync_all();
541        }
542        let data_root = mount_root.join("pglite");
543        if let Ok(file) = std::fs::OpenOptions::new().read(true).open(&data_root) {
544            let _ = file.sync_all();
545        }
546        Ok(())
547    }
548
549    fn prepare_bind_values(
550        &self,
551        params: &[Value],
552        data_type_ids: &[i32],
553        options: &QueryOptions,
554    ) -> Result<Vec<BindValue>> {
555        if params.is_empty() {
556            return Ok(Vec::new());
557        }
558
559        let mut values = Vec::with_capacity(params.len());
560        let overrides = if options.serializers.is_empty() {
561            None
562        } else {
563            Some(&options.serializers)
564        };
565
566        for (idx, value) in params.iter().enumerate() {
567            if value.is_null() {
568                values.push(BindValue::Null);
569                continue;
570            }
571
572            let oid = data_type_ids.get(idx).copied().unwrap_or(TEXT);
573            let serializer = overrides
574                .and_then(|map| map.get(&oid))
575                .or_else(|| self.serializers.get(&oid));
576
577            let serialized = match serializer {
578                Some(func) => func(value).with_context(|| {
579                    format!("failed to serialize parameter {idx} using OID {oid}")
580                })?,
581                None => self.default_serialize_value(value),
582            };
583
584            values.push(BindValue::Text(serialized));
585        }
586
587        Ok(values)
588    }
589
590    fn default_serialize_value(&self, value: &Value) -> String {
591        Self::default_serialize_value_static(value)
592    }
593
594    pub(crate) fn default_serialize_value_static(value: &Value) -> String {
595        match value {
596            Value::String(s) => s.clone(),
597            Value::Number(num) => num.to_string(),
598            Value::Bool(flag) => {
599                if *flag {
600                    "t".to_string()
601                } else {
602                    "f".to_string()
603                }
604            }
605            _ => value.to_string(),
606        }
607    }
608
609    fn finish_query(
610        &mut self,
611        messages: Vec<BackendMessage>,
612        options: Option<&QueryOptions>,
613    ) -> Result<Results> {
614        let blob = self.get_written_blob()?;
615        self.cleanup_blob()?;
616        if !self.in_transaction {
617            self.sync_to_fs()?;
618        }
619        let parsed = parse_results(&messages, &self.parsers, options, blob);
620        parsed
621            .into_iter()
622            .next()
623            .ok_or_else(|| anyhow!("query returned no result sets"))
624    }
625
626    fn finish_exec(
627        &mut self,
628        messages: Vec<BackendMessage>,
629        options: Option<&QueryOptions>,
630    ) -> Result<Vec<Results>> {
631        let blob = self.get_written_blob()?;
632        self.cleanup_blob()?;
633        if !self.in_transaction {
634            self.sync_to_fs()?;
635        }
636        Ok(parse_results(&messages, &self.parsers, options, blob))
637    }
638
639    fn exec_protocol(
640        &mut self,
641        message: &[u8],
642        options: ExecProtocolOptions,
643    ) -> Result<ExecProtocolResult> {
644        let ExecProtocolOptions {
645            sync_to_fs,
646            throw_on_error,
647            on_notice,
648            data_transfer_container,
649        } = options;
650
651        let data = self.exec_protocol_raw(message, sync_to_fs, data_transfer_container)?;
652
653        let mut messages = Vec::new();
654        let on_notice_cb = on_notice.clone();
655        if let Err(err) = self.parser.parse(&data, |msg| {
656            if let BackendMessage::Error(db_err) = &msg
657                && throw_on_error
658            {
659                return Err(anyhow!(db_err.clone()));
660            }
661            if let Some(callback) = on_notice_cb.as_ref()
662                && let BackendMessage::Notice(notice) = &msg
663            {
664                callback(notice);
665            }
666            messages.push(msg);
667            Ok(())
668        }) {
669            match err.downcast::<DatabaseError>() {
670                Ok(db_err) => {
671                    self.parser = ProtocolParser::new();
672                    return Err(anyhow!(db_err));
673                }
674                Err(err) => return Err(err),
675            }
676        }
677
678        for message in &messages {
679            if let BackendMessage::Notification(note) = message {
680                let key = to_postgres_name(&note.channel);
681                if let Some(listeners) = self.notify_listeners.get(&key) {
682                    for listener in listeners {
683                        (listener.callback)(&note.payload);
684                    }
685                }
686                for listener in &self.global_notify_listeners {
687                    (listener.callback)(&note.channel, &note.payload);
688                }
689            }
690        }
691
692        Ok(ExecProtocolResult { messages })
693    }
694
695    fn exec_protocol_raw(
696        &mut self,
697        message: &[u8],
698        sync_to_fs: bool,
699        data_transfer_container: Option<DataTransferContainer>,
700    ) -> Result<Vec<u8>> {
701        let data = self
702            .transport
703            .send(&mut self.pg, message, data_transfer_container)?;
704        if sync_to_fs {
705            self.sync_to_fs()?;
706        }
707        Ok(data)
708    }
709
710    fn init_array_types(&mut self, force: bool) -> Result<()> {
711        if self.array_types_initialized && !force {
712            return Ok(());
713        }
714
715        let prev = self.array_types_initialized;
716        self.array_types_initialized = true;
717
718        let result: Result<()> = {
719            let sql = "
720                SELECT b.oid, b.typarray
721                FROM pg_catalog.pg_type a
722                LEFT JOIN pg_catalog.pg_type b ON b.oid = a.typelem
723                WHERE a.typcategory = 'A'
724                GROUP BY b.oid, b.typarray
725                ORDER BY b.oid
726            ";
727            let results = self.exec(sql, None)?;
728            let result_set = results
729                .into_iter()
730                .next()
731                .ok_or_else(|| anyhow!("array type discovery returned no results"))?;
732
733            for row in result_set.rows {
734                let map = match row {
735                    Value::Object(map) => map,
736                    _ => continue,
737                };
738                let element_oid = value_to_i32(map.get("oid")).unwrap_or(0);
739                let array_oid = value_to_i32(map.get("typarray")).unwrap_or(0);
740
741                if element_oid == 0 || array_oid == 0 {
742                    continue;
743                }
744
745                let element_parser = self.parsers.get(&element_oid).cloned();
746                let element_serializer = self.serializers.get(&element_oid).cloned();
747
748                let parser_clone = element_parser.clone();
749                let array_parser: TypeParser = Arc::new(move |text: &str, _| {
750                    parse_array_text(text, parser_clone.clone(), element_oid, array_oid)
751                });
752                self.parsers.insert(array_oid, array_parser);
753
754                let serializer_clone = element_serializer.clone();
755                let array_serializer: Serializer = Arc::new(move |value: &Value| {
756                    serialize_array_value(value, serializer_clone.clone(), array_oid)
757                });
758                self.serializers.insert(array_oid, array_serializer);
759            }
760            Ok(())
761        };
762
763        if let Err(err) = result {
764            self.array_types_initialized = prev;
765            Err(err)
766        } else {
767            Ok(())
768        }
769    }
770
771    fn run_exec_command(&mut self, sql: &str) -> Result<()> {
772        self.exec_internal(sql, None).map(|_| ())
773    }
774
775    fn handle_blob_input(&mut self, blob: Option<&Vec<u8>>) -> Result<()> {
776        let path = self.dev_blob_path();
777        if let Some(bytes) = blob {
778            if let Some(parent) = path.parent() {
779                fs::create_dir_all(parent).with_context(|| {
780                    format!("failed to create blob directory {}", parent.display())
781                })?;
782            }
783            fs::write(&path, bytes)
784                .with_context(|| format!("write blob input to {}", path.display()))?;
785            self.blob_input_provided = true;
786        } else {
787            self.blob_input_provided = false;
788            let _ = fs::remove_file(&path);
789        }
790        Ok(())
791    }
792
793    fn dev_blob_path(&self) -> PathBuf {
794        self.pg.paths().pgroot.join("dev/blob")
795    }
796
    /// Hook called after a statement's blob has been collected.
    ///
    /// Currently does nothing and always succeeds; the blob file itself is
    /// removed by `get_written_blob`.
    fn cleanup_blob(&mut self) -> Result<()> {
        Ok(())
    }
800
801    fn get_written_blob(&mut self) -> Result<Option<Vec<u8>>> {
802        let path = self.dev_blob_path();
803
804        if self.blob_input_provided {
805            self.blob_input_provided = false;
806            let _ = fs::remove_file(&path);
807            return Ok(None);
808        }
809
810        match fs::read(&path) {
811            Ok(data) => {
812                self.blob_input_provided = false;
813                let _ = fs::remove_file(&path);
814                if data.is_empty() {
815                    Ok(None)
816                } else {
817                    Ok(Some(data))
818                }
819            }
820            Err(err) => {
821                if err.kind() == io::ErrorKind::NotFound {
822                    self.blob_input_provided = false;
823                    Ok(None)
824                } else {
825                    Err(err).with_context(|| format!("read blob output from {}", path.display()))
826                }
827            }
828        }
829    }
830
831    fn check_ready(&self) -> Result<()> {
832        if self.closing {
833            bail!("Pglite instance is closing");
834        }
835        if self.closed {
836            bail!("Pglite instance is closed");
837        }
838        if !self.ready {
839            bail!("Pglite instance is not ready");
840        }
841        Ok(())
842    }
843}
844
845impl Drop for Pglite {
846    fn drop(&mut self) {
847        if !self.closed {
848            let _ = self.close();
849        }
850    }
851}
852
/// Normalise a channel identifier, as written in SQL, to the name Postgres stores.
///
/// - A quoted identifier (`"Name"`) keeps its case; the surrounding quotes are
///   removed and each doubled quote (`""`) inside it stands for one literal `"`.
/// - Anything else is treated as an unquoted identifier and folded to lower
///   case. Only ASCII letters are folded, which matches how PostgreSQL folds
///   identifiers in a multibyte encoding such as UTF-8 (assumed to be the
///   encoding of the embedded database — TODO confirm).
///
/// A lone `"` is not a quoted identifier and is returned unchanged.
fn to_postgres_name(input: &str) -> String {
    let quoted_body = input
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'));

    match quoted_body {
        Some(body) => body.replace("\"\"", "\""),
        None => input.to_ascii_lowercase(),
    }
}
860
861fn value_to_i32(value: Option<&Value>) -> Option<i32> {
862    match value? {
863        Value::Number(number) => number.as_i64().map(|value| value as i32),
864        Value::String(string) => string.parse::<i32>().ok(),
865        _ => None,
866    }
867}
868
/// Transaction handle used within [`Pglite::transaction`].
///
/// Statements run through the handle go to the same client that opened the
/// transaction.
pub struct Transaction<'a> {
    // Exclusive borrow of the client that issued `BEGIN`.
    client: &'a mut Pglite,
    // Set once `COMMIT` or `ROLLBACK` has succeeded; further use is rejected.
    closed: bool,
}
874
875impl<'a> Transaction<'a> {
876    fn new(client: &'a mut Pglite) -> Self {
877        Self {
878            client,
879            closed: false,
880        }
881    }
882
883    fn commit_internal(&mut self) -> Result<()> {
884        self.ensure_open()?;
885        self.client.exec_internal("COMMIT", None)?;
886        self.closed = true;
887        Ok(())
888    }
889
890    fn rollback_internal(&mut self) -> Result<()> {
891        self.ensure_open()?;
892        self.client.exec_internal("ROLLBACK", None)?;
893        self.closed = true;
894        Ok(())
895    }
896
897    fn ensure_open(&self) -> Result<()> {
898        if self.closed {
899            bail!("transaction is already closed");
900        }
901        Ok(())
902    }
903
904    pub fn query(
905        &mut self,
906        sql: &str,
907        params: &[Value],
908        options: Option<&QueryOptions>,
909    ) -> Result<Results> {
910        self.ensure_open()?;
911        self.client.query_internal(sql, params, options)
912    }
913
914    pub fn exec(&mut self, sql: &str, options: Option<&QueryOptions>) -> Result<Vec<Results>> {
915        self.ensure_open()?;
916        self.client.exec_internal(sql, options)
917    }
918
919    pub fn commit(&mut self) -> Result<()> {
920        self.commit_internal()
921    }
922
923    pub fn rollback(&mut self) -> Result<()> {
924        self.rollback_internal()
925    }
926
927    pub fn is_closed(&self) -> bool {
928        self.closed
929    }
930
931    pub fn closed(&self) -> bool {
932        self.closed
933    }
934}