Skip to main content

pgwire/api/
results.rs

1use std::fmt::Debug;
2use std::pin::Pin;
3use std::sync::{Arc, LazyLock};
4
5use bytes::{BufMut, Bytes, BytesMut};
6use futures::{Stream, StreamExt, future, stream};
7use postgres_types::{IsNull, Oid, ToSql, Type};
8
9use crate::error::{ErrorInfo, PgWireError, PgWireResult};
10use crate::messages::copy::CopyData;
11use crate::messages::data::{
12    DataRow, FORMAT_CODE_BINARY, FORMAT_CODE_TEXT, FieldDescription, RowDescription,
13};
14use crate::messages::response::CommandComplete;
15use crate::types::ToSqlText;
16use crate::types::format::FormatOptions;
17use smol_str::SmolStr;
18
19/// Command completion tag for a query response.
20#[derive(Debug, Eq, PartialEq, Clone)]
21pub struct Tag {
22    command: String,
23    oid: Option<Oid>,
24    rows: Option<usize>,
25}
26
27impl Tag {
28    /// Create a new tag with the given command name.
29    pub fn new(command: &str) -> Tag {
30        Tag {
31            command: command.to_owned(),
32            oid: None,
33            rows: None,
34        }
35    }
36
37    /// Set the number of rows affected.
38    pub fn with_rows(mut self, rows: usize) -> Tag {
39        self.rows = Some(rows);
40        self
41    }
42
43    /// Set the OID of the inserted row.
44    pub fn with_oid(mut self, oid: Oid) -> Tag {
45        self.oid = Some(oid);
46        self
47    }
48}
49
50impl From<Tag> for CommandComplete {
51    fn from(tag: Tag) -> CommandComplete {
52        let tag_string = if let (Some(oid), Some(rows)) = (tag.oid, tag.rows) {
53            format!("{} {oid} {rows}", tag.command)
54        } else if let Some(rows) = tag.rows {
55            format!("{} {rows}", tag.command)
56        } else {
57            tag.command
58        };
59        CommandComplete::new(tag_string)
60    }
61}
62
63/// Describe encoding of a data field.
64#[derive(Debug, Eq, PartialEq, Clone, Copy)]
65pub enum FieldFormat {
66    Text,
67    Binary,
68}
69
70impl FieldFormat {
71    /// Get format code for the encoding.
72    pub fn value(&self) -> i16 {
73        match self {
74            Self::Text => FORMAT_CODE_TEXT,
75            Self::Binary => FORMAT_CODE_BINARY,
76        }
77    }
78
79    /// Parse FieldFormat from format code.
80    ///
81    /// 0 for text format, 1 for binary format. If the input is neither 0 nor 1,
82    /// here we return text as default value.
83    pub fn from(code: i16) -> Self {
84        if code == FORMAT_CODE_BINARY {
85            FieldFormat::Binary
86        } else {
87            FieldFormat::Text
88        }
89    }
90}
91
92/// Options for COPY text format.
93#[derive(Debug, Clone, Eq, PartialEq)]
94pub struct CopyTextOptions {
95    pub delimiter: SmolStr,
96    pub null_string: SmolStr,
97}
98
99impl Default for CopyTextOptions {
100    fn default() -> Self {
101        Self {
102            delimiter: "\t".into(),
103            null_string: "\\N".into(),
104        }
105    }
106}
107
108/// Options for COPY CSV format.
109#[derive(Debug, Clone, Eq, PartialEq)]
110pub struct CopyCsvOptions {
111    pub delimiter: SmolStr,
112    pub quote: SmolStr,
113    pub escape: SmolStr,
114    pub null_string: SmolStr,
115    pub force_quote: Vec<usize>,
116}
117
118impl Default for CopyCsvOptions {
119    fn default() -> Self {
120        Self {
121            delimiter: ",".into(),
122            quote: "\"".into(),
123            escape: "\"".into(),
124            null_string: "".into(),
125            force_quote: vec![],
126        }
127    }
128}
129
130// Default format options that are cloned in `FieldInfo::new` to avoid `Arc` allocation.
131//
132// Using thread-local storage avoids contention when multiple threads concurrently
133// clone the same `Arc<FormatOptions>` in `DataRowEncoder::encode_field`. Each thread
134// now clones its own thread-local instance rather than contending for a shared
135// global instance.
136//
137// This can be made a regular static if we remove format options cloning from
138// `DataRowEncoder::encode_field`.
139//
140// The issue with contention was observed in `examples/bench` benchmark:
141// https://github.com/sunng87/pgwire/pull/366#discussion_r2621917771
142thread_local! {
143    static DEFAULT_FORMAT_OPTIONS: LazyLock<Arc<FormatOptions>> = LazyLock::new(Default::default);
144}
145
146/// Metadata for a single field (column) in a query result.
147#[derive(Debug, new, Eq, PartialEq, Clone)]
148pub struct FieldInfo {
149    name: String,
150    table_id: Option<i32>,
151    column_id: Option<i16>,
152    datatype: Type,
153    format: FieldFormat,
154    #[new(value = "DEFAULT_FORMAT_OPTIONS.with(|opts| Arc::clone(&*opts))")]
155    format_options: Arc<FormatOptions>,
156}
157
158impl FieldInfo {
159    /// Get the field name.
160    pub fn name(&self) -> &str {
161        &self.name
162    }
163
164    /// Get the source table OID, if any.
165    pub fn table_id(&self) -> Option<i32> {
166        self.table_id
167    }
168
169    /// Get the column number within the source table, if any.
170    pub fn column_id(&self) -> Option<i16> {
171        self.column_id
172    }
173
174    /// Get the PostgreSQL type of this field.
175    pub fn datatype(&self) -> &Type {
176        &self.datatype
177    }
178
179    /// Get the field encoding format (text or binary).
180    pub fn format(&self) -> FieldFormat {
181        self.format
182    }
183
184    /// Get the format options for text encoding.
185    pub fn format_options(&self) -> &Arc<FormatOptions> {
186        &self.format_options
187    }
188
189    /// Set custom format options for text encoding.
190    pub fn with_format_options(mut self, format_options: Arc<FormatOptions>) -> Self {
191        self.format_options = format_options;
192        self
193    }
194}
195
196impl From<&FieldInfo> for FieldDescription {
197    fn from(fi: &FieldInfo) -> Self {
198        FieldDescription::new(
199            fi.name.clone(),           // name
200            fi.table_id.unwrap_or(0),  // table_id
201            fi.column_id.unwrap_or(0), // column_id
202            fi.datatype.oid(),         // type_id
203            // TODO: type size and modifier
204            0,
205            0,
206            fi.format.value(),
207        )
208    }
209}
210
211impl From<FieldDescription> for FieldInfo {
212    fn from(value: FieldDescription) -> Self {
213        FieldInfo::new(
214            value.name,
215            Some(value.table_id),
216            Some(value.column_id),
217            Type::from_oid(value.type_id).unwrap_or(Type::UNKNOWN),
218            FieldFormat::from(value.format_code),
219        )
220    }
221}
222
223pub(crate) fn into_row_description(fields: &[FieldInfo]) -> RowDescription {
224    RowDescription::new(fields.iter().map(Into::into).collect())
225}
226
227/// Type alias for a boxed, pinned, sendable stream of data rows.
228pub type SendableRowStream = Pin<Box<dyn Stream<Item = PgWireResult<DataRow>> + Send>>;
229
230/// Type alias for a boxed, pinned, sendable stream of copy data.
231pub type SendableCopyDataStream = Pin<Box<dyn Stream<Item = PgWireResult<CopyData>> + Send>>;
232
233/// Response containing row data for a SELECT-style query.
234#[non_exhaustive]
235pub struct QueryResponse {
236    pub command_tag: String,
237    pub row_schema: Arc<Vec<FieldInfo>>,
238    pub data_rows: SendableRowStream,
239}
240
241impl Debug for QueryResponse {
242    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
243        f.debug_struct("QueryResponse")
244            .field("command_tag", &self.command_tag)
245            .field("row_schema", &self.row_schema)
246            .finish()
247    }
248}
249
250impl QueryResponse {
251    /// Create `QueryResponse` from column schemas and stream of data row.
252    /// Sets "SELECT" as the command tag.
253    pub fn new<S>(field_defs: Arc<Vec<FieldInfo>>, row_stream: S) -> QueryResponse
254    where
255        S: Stream<Item = PgWireResult<DataRow>> + Send + 'static,
256    {
257        QueryResponse {
258            command_tag: "SELECT".to_owned(),
259            row_schema: field_defs,
260            data_rows: Box::pin(row_stream),
261        }
262    }
263
264    /// Get the command tag
265    pub fn command_tag(&self) -> &str {
266        &self.command_tag
267    }
268
269    /// Set the command tag
270    pub fn set_command_tag(&mut self, command_tag: &str) {
271        command_tag.clone_into(&mut self.command_tag);
272    }
273
274    /// Get schema of columns
275    pub fn row_schema(&self) -> Arc<Vec<FieldInfo>> {
276        self.row_schema.clone()
277    }
278
279    /// Get access to data rows stream
280    pub fn data_rows(&mut self) -> &mut SendableRowStream {
281        &mut self.data_rows
282    }
283}
284
285/// Encoder for building `DataRow` messages field by field.
286pub struct DataRowEncoder {
287    schema: Arc<Vec<FieldInfo>>,
288    row_buffer: BytesMut,
289    col_index: usize,
290}
291
292impl DataRowEncoder {
293    /// New DataRowEncoder from schema of column
294    pub fn new(fields: Arc<Vec<FieldInfo>>) -> DataRowEncoder {
295        Self {
296            schema: fields,
297            row_buffer: BytesMut::with_capacity(128),
298            col_index: 0,
299        }
300    }
301
302    /// Encode value with custom type and format
303    ///
304    /// This encode function ignores data type and format information from
305    /// schema of this encoder.
306    pub fn encode_field_with_type_and_format<T>(
307        &mut self,
308        value: &T,
309        data_type: &Type,
310        format: FieldFormat,
311        format_options: &FormatOptions,
312    ) -> PgWireResult<()>
313    where
314        T: ToSql + ToSqlText + Sized,
315    {
316        // remember the position of the 4-byte length field
317        let prev_index = self.row_buffer.len();
318        // write value length as -1 ahead of time
319        self.row_buffer.put_i32(-1);
320
321        let is_null = if format == FieldFormat::Text {
322            value.to_sql_text(data_type, &mut self.row_buffer, format_options)?
323        } else {
324            value.to_sql(data_type, &mut self.row_buffer)?
325        };
326
327        if let IsNull::No = is_null {
328            let value_length = self.row_buffer.len() - prev_index - 4;
329            let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
330            length_bytes.put_i32(value_length as i32);
331        }
332
333        self.col_index += 1;
334
335        Ok(())
336    }
337
338    /// Encode value using type and format, defined by schema
339    ///
340    /// Panic when encoding more columns than provided as schema.
341    pub fn encode_field<T>(&mut self, value: &T) -> PgWireResult<()>
342    where
343        T: ToSql + ToSqlText + Sized,
344    {
345        let field = &self.schema[self.col_index];
346
347        let data_type = field.datatype().clone();
348        let format = field.format();
349        let format_options = field.format_options().clone();
350
351        self.encode_field_with_type_and_format(value, &data_type, format, format_options.as_ref())
352    }
353
354    #[deprecated(
355        since = "0.37.0",
356        note = "DataRowEncoder is reusable since 0.37, use `take_row() instead`"
357    )]
358    pub fn finish(self) -> PgWireResult<DataRow> {
359        Ok(DataRow::new(self.row_buffer, self.col_index as i16))
360    }
361
362    /// Takes the current row from the encoder, resetting the encoder for reuse.
363    ///
364    /// This method splits the inner buffer, taking the current row data and leaving the
365    /// encoder with an empty buffer (but retaining the capacity) enabling buffer reuse.
366    pub fn take_row(&mut self) -> DataRow {
367        let row = DataRow::new(self.row_buffer.split(), self.col_index as i16);
368        self.col_index = 0;
369        row
370    }
371}
372
373/// Internal COPY format representation.
374#[derive(Debug, Clone, Eq, PartialEq)]
375enum CopyFormat {
376    Binary,
377    Text {
378        delimiter: SmolStr,
379        null_string: SmolStr,
380    },
381    Csv {
382        delimiter: SmolStr,
383        quote: SmolStr,
384        escape: SmolStr,
385        null_string: SmolStr,
386        force_quote: Vec<usize>,
387    },
388}
389
390/// Encoder for COPY operations.
391///
392/// This encoder produces CopyData messages for PGCOPY binary, text, and CSV formats.
393pub struct CopyEncoder {
394    schema: Arc<Vec<FieldInfo>>,
395    buffer: BytesMut,
396    format: CopyFormat,
397    col_index: usize,
398    header_written: bool,
399}
400
401impl CopyEncoder {
402    /// Create a new binary format COPY encoder.
403    pub fn new_binary(schema: Arc<Vec<FieldInfo>>) -> Self {
404        Self {
405            schema,
406            buffer: BytesMut::with_capacity(128),
407            format: CopyFormat::Binary,
408            col_index: 0,
409            header_written: false,
410        }
411    }
412
413    /// Create a new text format COPY encoder.
414    pub fn new_text(schema: Arc<Vec<FieldInfo>>, options: CopyTextOptions) -> Self {
415        Self {
416            schema,
417            buffer: BytesMut::with_capacity(128),
418            format: CopyFormat::Text {
419                delimiter: options.delimiter,
420                null_string: options.null_string,
421            },
422            col_index: 0,
423            header_written: false,
424        }
425    }
426
427    /// Create a new CSV format COPY encoder.
428    pub fn new_csv(schema: Arc<Vec<FieldInfo>>, options: CopyCsvOptions) -> Self {
429        Self {
430            schema,
431            buffer: BytesMut::with_capacity(128),
432            format: CopyFormat::Csv {
433                delimiter: options.delimiter,
434                quote: options.quote,
435                escape: options.escape,
436                null_string: options.null_string,
437                force_quote: options.force_quote,
438            },
439            col_index: 0,
440            header_written: false,
441        }
442    }
443
444    /// Encode a field value.
445    ///
446    /// This method uses the type and format information from the schema.
447    pub fn encode_field<T>(&mut self, value: &T) -> PgWireResult<()>
448    where
449        T: ToSql + ToSqlText + Sized,
450    {
451        let datatype = self.schema[self.col_index].datatype().clone();
452        let col_index = self.col_index;
453        let num_fields = self.schema.len();
454
455        match &self.format {
456            CopyFormat::Binary => self.encode_field_binary(value, &datatype)?,
457            CopyFormat::Text { .. } => {
458                let is_last = col_index == num_fields - 1;
459                self.encode_field_text(value, &datatype, is_last)?;
460            }
461            CopyFormat::Csv { .. } => {
462                let is_last = col_index == num_fields - 1;
463                self.encode_field_csv(value, &datatype, is_last)?;
464            }
465        }
466
467        self.col_index += 1;
468        Ok(())
469    }
470
471    /// Encode a field in binary format (same as DataRow encoding).
472    fn encode_field_binary<T>(&mut self, value: &T, datatype: &Type) -> PgWireResult<()>
473    where
474        T: ToSql + ToSqlText,
475    {
476        let prev_index = self.buffer.len();
477        self.buffer.put_i32(-1);
478
479        let is_null = value.to_sql(datatype, &mut self.buffer)?;
480
481        if let IsNull::No = is_null {
482            let value_length = self.buffer.len() - prev_index - 4;
483            let mut length_bytes = &mut self.buffer[prev_index..(prev_index + 4)];
484            length_bytes.put_i32(value_length as i32);
485        }
486
487        Ok(())
488    }
489
490    /// Encode a field in text format.
491    fn encode_field_text<T>(
492        &mut self,
493        value: &T,
494        datatype: &Type,
495        is_last: bool,
496    ) -> PgWireResult<()>
497    where
498        T: ToSqlText,
499    {
500        if let CopyFormat::Text {
501            delimiter,
502            null_string,
503        } = &self.format
504        {
505            let mut temp_buffer = BytesMut::new();
506            let is_null =
507                value.to_sql_text(datatype, &mut temp_buffer, &FormatOptions::default())?;
508
509            if let IsNull::Yes = is_null {
510                self.buffer.put_slice(null_string.as_bytes());
511            } else {
512                // Backslash escape special characters
513                for &byte in temp_buffer.as_ref() {
514                    match byte {
515                        b'\n' => {
516                            self.buffer.put_slice(b"\\n");
517                        }
518                        b'\r' => {
519                            self.buffer.put_slice(b"\\r");
520                        }
521                        b'\t' => {
522                            self.buffer.put_slice(b"\\t");
523                        }
524                        b'\\' => {
525                            self.buffer.put_slice(b"\\\\");
526                        }
527                        _b if byte == delimiter.as_bytes()[0] => {
528                            self.buffer.put_u8(b'\\');
529                            self.buffer.put_u8(byte);
530                        }
531                        _ => {
532                            self.buffer.put_u8(byte);
533                        }
534                    }
535                }
536            }
537
538            // Add delimiter between fields
539            if !is_last {
540                self.buffer.put_slice(delimiter.as_bytes());
541            }
542
543            Ok(())
544        } else {
545            Err(PgWireError::IoError(std::io::Error::new(
546                std::io::ErrorKind::InvalidInput,
547                "Text format expected",
548            )))
549        }
550    }
551
552    /// Encode a field in CSV format.
553    fn encode_field_csv<T>(&mut self, value: &T, datatype: &Type, is_last: bool) -> PgWireResult<()>
554    where
555        T: ToSqlText,
556    {
557        if let CopyFormat::Csv {
558            delimiter,
559            quote,
560            null_string,
561            force_quote,
562            escape: _,
563        } = &self.format
564        {
565            let col_index = self.col_index;
566            let mut temp_buffer = BytesMut::new();
567            let is_null =
568                value.to_sql_text(datatype, &mut temp_buffer, &FormatOptions::default())?;
569
570            let delimiter_byte = delimiter.as_bytes()[0];
571            let quote_byte = quote.as_bytes()[0];
572            let null_string_bytes = null_string.as_bytes();
573
574            let should_quote = force_quote.contains(&col_index)
575                || match is_null {
576                    IsNull::Yes => false, // NULL values are never quoted in CSV (handled by null_string)
577                    IsNull::No => {
578                        let data = temp_buffer.as_ref();
579                        data.contains(&delimiter_byte)
580                            || data.contains(&quote_byte)
581                            || data.contains(&b'\n')
582                            || data.contains(&b'\r')
583                            || (!null_string_bytes.is_empty()
584                                && data
585                                    .windows(null_string_bytes.len())
586                                    .any(|w| w == null_string_bytes))
587                    }
588                };
589
590            if let IsNull::Yes = is_null {
591                self.buffer.put_slice(null_string_bytes);
592            } else if should_quote {
593                self.buffer.put_u8(quote_byte);
594
595                for &byte in temp_buffer.as_ref() {
596                    if byte == quote_byte {
597                        // Double the quote character
598                        self.buffer.put_u8(byte);
599                    }
600                    self.buffer.put_u8(byte);
601                }
602
603                self.buffer.put_u8(quote_byte);
604            } else {
605                self.buffer.put_slice(temp_buffer.as_ref());
606            }
607
608            // Add delimiter between fields
609            if !is_last {
610                self.buffer.put_slice(delimiter.as_bytes());
611            }
612
613            Ok(())
614        } else {
615            Err(PgWireError::IoError(std::io::Error::new(
616                std::io::ErrorKind::InvalidInput,
617                "CSV format expected",
618            )))
619        }
620    }
621
622    /// Take the current row as a CopyData message.
623    ///
624    /// For binary format: first call includes PGCOPY header.
625    /// For text/CSV format: each call returns one row with a trailing newline.
626    pub fn take_copy(&mut self) -> CopyData {
627        match &self.format {
628            CopyFormat::Binary => {
629                if !self.header_written {
630                    // Prepend header to field data
631                    let field_data = self.buffer.split();
632                    self.write_pgcop_header();
633                    self.buffer.put_i16(self.schema.len() as i16);
634                    self.buffer.extend_from_slice(&field_data);
635                    self.header_written = true;
636                } else {
637                    // Prepend field count before field data
638                    let field_data = self.buffer.split();
639                    self.buffer.put_i16(self.schema.len() as i16);
640                    self.buffer.extend_from_slice(&field_data);
641                }
642            }
643            CopyFormat::Text { .. } | CopyFormat::Csv { .. } => {
644                // Add newline at end of row
645                self.buffer.put_u8(b'\n');
646            }
647        }
648
649        self.col_index = 0;
650        CopyData::new(self.buffer.split().freeze())
651    }
652
653    /// Finish the COPY operation of binary format.
654    ///
655    /// For binary format: returns trailer (-1).
656    /// Note that this trailer is automatically appended to stream if you use
657    /// `CopyResponse` API.
658    pub fn finish_copy_binary() -> CopyData {
659        CopyData::new(Bytes::from_static(&[0xFF, 0xFF]))
660    }
661
662    /// Write PGCOPY binary header.
663    fn write_pgcop_header(&mut self) {
664        self.buffer.put_slice(b"PGCOPY\n\xFF\r\n\x00");
665        self.buffer.put_i32(0); // Flags (no OIDs)
666        self.buffer.put_i32(0); // Header extension length
667    }
668}
669
670/// Get response data for a `Describe` command
671pub trait DescribeResponse {
672    /// Get parameter types for the described statement.
673    fn parameters(&self) -> Option<&[Type]>;
674
675    /// Get result field descriptions.
676    fn fields(&self) -> &[FieldInfo];
677
678    /// Create an no_data instance of `DescribeResponse`. This is typically used
679    /// when client tries to describe an empty query.
680    fn no_data() -> Self;
681
682    /// Return true if the `DescribeResponse` is empty/nodata
683    fn is_no_data(&self) -> bool;
684}
685
686/// Response for frontend describe statement requests.
687#[non_exhaustive]
688#[derive(Debug, new)]
689pub struct DescribeStatementResponse {
690    pub parameters: Vec<Type>,
691    pub fields: Vec<FieldInfo>,
692}
693
694impl DescribeResponse for DescribeStatementResponse {
695    fn parameters(&self) -> Option<&[Type]> {
696        Some(self.parameters.as_ref())
697    }
698
699    fn fields(&self) -> &[FieldInfo] {
700        &self.fields
701    }
702
703    /// Create an no_data instance of `DescribeStatementResponse`. This is typically used
704    /// when client tries to describe an empty query.
705    fn no_data() -> Self {
706        DescribeStatementResponse {
707            parameters: vec![],
708            fields: vec![],
709        }
710    }
711
712    /// Return true if the `DescribeStatementResponse` is empty/nodata
713    fn is_no_data(&self) -> bool {
714        self.parameters.is_empty() && self.fields.is_empty()
715    }
716}
717
718/// Response for frontend describe portal requests.
719#[non_exhaustive]
720#[derive(Debug, new)]
721pub struct DescribePortalResponse {
722    pub fields: Vec<FieldInfo>,
723}
724
725impl DescribeResponse for DescribePortalResponse {
726    fn parameters(&self) -> Option<&[Type]> {
727        None
728    }
729
730    fn fields(&self) -> &[FieldInfo] {
731        &self.fields
732    }
733
734    /// Create an no_data instance of `DescribePortalResponse`. This is typically used
735    /// when client tries to describe an empty query.
736    fn no_data() -> Self {
737        DescribePortalResponse { fields: vec![] }
738    }
739
740    /// Return true if the `DescribePortalResponse` is empty/nodata
741    fn is_no_data(&self) -> bool {
742        self.fields.is_empty()
743    }
744}
745
746/// Response for copy operations
747#[non_exhaustive]
748pub struct CopyResponse {
749    pub format: i8,
750    pub columns: usize,
751    pub data_stream: SendableCopyDataStream,
752}
753
754impl std::fmt::Debug for CopyResponse {
755    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
756        f.debug_struct("CopyResponse")
757            .field("format", &self.format)
758            .field("columns", &self.columns)
759            .finish()
760    }
761}
762
763impl CopyResponse {
764    /// Create a new copy response. Binary format automatically appends a trailer.
765    pub fn new<S>(format: i8, columns: usize, data_stream: S) -> CopyResponse
766    where
767        S: Stream<Item = PgWireResult<CopyData>> + Send + 'static,
768    {
769        if format == 1 {
770            let data_stream = data_stream.chain(stream::once(future::ready(Ok(
771                CopyEncoder::finish_copy_binary(),
772            ))));
773            CopyResponse {
774                format,
775                columns,
776                data_stream: Box::pin(data_stream),
777            }
778        } else {
779            CopyResponse {
780                format,
781                columns,
782                data_stream: Box::pin(data_stream),
783            }
784        }
785    }
786
787    /// Get mutable access to the underlying copy data stream.
788    pub fn data_stream(&mut self) -> &mut SendableCopyDataStream {
789        &mut self.data_stream
790    }
791
792    /// Get the format code for each column.
793    pub fn column_formats(&self) -> Vec<i16> {
794        (0..self.columns).map(|_| self.format as i16).collect()
795    }
796}
797
798/// Query response types:
799///
800/// * Query: the response contains data rows
801/// * Execution: response for ddl/dml execution
802/// * Error: error response
803/// * EmptyQuery: when client sends an empty query
804/// * TransactionStart: indicate previous statement just started a transaction
805/// * TransactionEnd: indicate previous statement just ended a transaction
806/// * CopyIn: response for a copy-in request
807/// * CopyOut: response for a copy-out request
808/// * CopuBoth: response for a copy-both request
809#[derive(Debug)]
810pub enum Response {
811    EmptyQuery,
812    Query(QueryResponse),
813    Execution(Tag),
814    TransactionStart(Tag),
815    TransactionEnd(Tag),
816    Error(Box<ErrorInfo>),
817    CopyIn(CopyResponse),
818    CopyOut(CopyResponse),
819    CopyBoth(CopyResponse),
820}
821
822#[cfg(test)]
823mod test {
824
825    use super::*;
826
827    #[test]
828    fn test_command_complete() {
829        let tag = Tag::new("INSERT").with_rows(100);
830        let cc = CommandComplete::from(tag);
831
832        assert_eq!(cc.tag, "INSERT 100");
833
834        let tag = Tag::new("INSERT").with_oid(0).with_rows(100);
835        let cc = CommandComplete::from(tag);
836
837        assert_eq!(cc.tag, "INSERT 0 100");
838    }
839
840    #[test]
841    #[cfg(feature = "pg-type-chrono")]
842    fn test_data_row_encoder() {
843        use std::time::SystemTime;
844
845        let schema = Arc::new(vec![
846            FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text),
847            FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text),
848            FieldInfo::new("ts".into(), None, None, Type::TIMESTAMP, FieldFormat::Text),
849        ]);
850        let now = SystemTime::now();
851        let mut encoder = DataRowEncoder::new(schema);
852        encoder.encode_field(&2001).unwrap();
853        encoder.encode_field(&"udev").unwrap();
854        encoder.encode_field(&now).unwrap();
855
856        let row = encoder.take_row();
857
858        assert_eq!(row.field_count, 3);
859
860        let mut expected = BytesMut::new();
861        expected.put_i32(4);
862        expected.put_slice("2001".as_bytes());
863        expected.put_i32(4);
864        expected.put_slice("udev".as_bytes());
865        expected.put_i32(26);
866        let _ = now.to_sql_text(&Type::TIMESTAMP, &mut expected, &FormatOptions::default());
867        assert_eq!(row.data, expected);
868    }
869
870    #[test]
871    fn test_copy_text_options_default() {
872        let opts = CopyTextOptions::default();
873        assert_eq!(opts.delimiter, "\t");
874        assert_eq!(opts.null_string, "\\N");
875    }
876
877    #[test]
878    fn test_copy_csv_options_default() {
879        let opts = CopyCsvOptions::default();
880        assert_eq!(opts.delimiter, ",");
881        assert_eq!(opts.quote, "\"");
882        assert_eq!(opts.escape, "\"");
883        assert_eq!(opts.null_string, "");
884        assert!(opts.force_quote.is_empty());
885    }
886
887    #[test]
888    fn test_copy_binary_header() {
889        let schema = Arc::new(vec![FieldInfo::new(
890            "id".into(),
891            None,
892            None,
893            Type::INT4,
894            FieldFormat::Binary,
895        )]);
896        let mut encoder = CopyEncoder::new_binary(schema.clone());
897
898        // First take_copy should include header
899        encoder.encode_field(&42).unwrap();
900        let copy_data = encoder.take_copy();
901
902        let data = copy_data.data.as_ref();
903        assert_eq!(&data[0..11], b"PGCOPY\n\xFF\r\n\0");
904
905        // Check flags (4 bytes, no OIDs = 0)
906        assert_eq!(&data[11..15], &[0x00, 0x00, 0x00, 0x00]);
907
908        // Check extension length (4 bytes, no extensions = 0)
909        assert_eq!(&data[15..19], &[0x00, 0x00, 0x00, 0x00]);
910
911        // Check field count (2 bytes)
912        assert_eq!(&data[19..21], &[0x00, 0x01]); // 1 field
913
914        // Check field length (4 bytes)
915        assert_eq!(&data[21..25], &[0x00, 0x00, 0x00, 0x04]); // 4 bytes
916
917        // Check field value (42 in network byte order)
918        assert_eq!(&data[25..29], &[0x00, 0x00, 0x00, 0x2A]);
919    }
920
921    #[test]
922    fn test_copy_binary_trailer() {
923        let copy_data = CopyEncoder::finish_copy_binary();
924        let data = copy_data.data.as_ref();
925
926        // Trailer is -1 as i16 (0xFFFF in network byte order)
927        assert_eq!(data, &[0xFF, 0xFF]);
928    }
929
930    #[test]
931    fn test_copy_text_default_delimiter() {
932        let schema = Arc::new(vec![
933            FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text),
934            FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text),
935        ]);
936        let mut encoder = CopyEncoder::new_text(schema, CopyTextOptions::default());
937
938        encoder.encode_field(&1).unwrap();
939        encoder.encode_field(&"Alice").unwrap();
940        let copy_data = encoder.take_copy();
941
942        // Expected: "1\tAlice\n"
943        assert_eq!(copy_data.data.as_ref(), b"1\tAlice\n");
944    }
945
946    #[test]
947    fn test_copy_text_custom_delimiter() {
948        let schema = Arc::new(vec![
949            FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text),
950            FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text),
951        ]);
952        let mut encoder = CopyEncoder::new_text(
953            schema,
954            CopyTextOptions {
955                delimiter: "|".into(),
956                null_string: "\\N".into(),
957            },
958        );
959
960        encoder.encode_field(&1).unwrap();
961        encoder.encode_field(&"Alice").unwrap();
962        let copy_data = encoder.take_copy();
963
964        // Expected: "1|Alice\n"
965        assert_eq!(copy_data.data.as_ref(), b"1|Alice\n");
966    }
967
968    #[test]
969    fn test_copy_text_null_handling() {
970        let schema = Arc::new(vec![
971            FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text),
972            FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text),
973        ]);
974        let mut encoder = CopyEncoder::new_text(schema, CopyTextOptions::default());
975
976        encoder.encode_field(&1).unwrap();
977        encoder.encode_field(&None::<String>).unwrap();
978        let copy_data = encoder.take_copy();
979
980        // Expected: "1\t\\N\n"
981        assert_eq!(copy_data.data.as_ref(), b"1\t\\N\n");
982    }
983
984    #[test]
985    fn test_copy_text_backslash_escaping() {
986        let schema = Arc::new(vec![FieldInfo::new(
987            "value".into(),
988            None,
989            None,
990            Type::VARCHAR,
991            FieldFormat::Text,
992        )]);
993        let mut encoder = CopyEncoder::new_text(schema, CopyTextOptions::default());
994
995        encoder.encode_field(&"a\nb\tc\rd\\e").unwrap();
996        let copy_data = encoder.take_copy();
997
998        // Expected: "a\\nb\\tc\\rd\\\\e\n"
999        assert_eq!(copy_data.data.as_ref(), b"a\\nb\\tc\\rd\\\\e\n");
1000    }
1001
1002    #[test]
1003    fn test_copy_csv_default() {
1004        let schema = Arc::new(vec![
1005            FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text),
1006            FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text),
1007        ]);
1008        let mut encoder = CopyEncoder::new_csv(schema, CopyCsvOptions::default());
1009
1010        encoder.encode_field(&1).unwrap();
1011        encoder.encode_field(&"Alice").unwrap();
1012        let copy_data = encoder.take_copy();
1013
1014        // Expected: "1,Alice\n"
1015        assert_eq!(copy_data.data.as_ref(), b"1,Alice\n");
1016    }
1017
1018    #[test]
1019    fn test_copy_csv_quoting() {
1020        let schema = Arc::new(vec![FieldInfo::new(
1021            "value".into(),
1022            None,
1023            None,
1024            Type::VARCHAR,
1025            FieldFormat::Text,
1026        )]);
1027        let mut encoder = CopyEncoder::new_csv(schema, CopyCsvOptions::default());
1028
1029        encoder.encode_field(&"a,b\"c\nd").unwrap();
1030        let copy_data = encoder.take_copy();
1031
1032        // Should be quoted because it contains comma and newline
1033        assert_eq!(copy_data.data.as_ref(), b"\"a,b\"\"c\nd\"\n");
1034    }
1035
1036    #[test]
1037    fn test_copy_csv_force_quote() {
1038        let schema = Arc::new(vec![
1039            FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text),
1040            FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text),
1041        ]);
1042        let mut encoder = CopyEncoder::new_csv(
1043            schema,
1044            CopyCsvOptions {
1045                force_quote: vec![1],
1046                ..Default::default()
1047            },
1048        );
1049
1050        encoder.encode_field(&1).unwrap();
1051        encoder.encode_field(&"Alice").unwrap();
1052        let copy_data = encoder.take_copy();
1053
1054        // Expected: "1,\"Alice\"\n" - second column force quoted
1055        assert_eq!(copy_data.data.as_ref(), b"1,\"Alice\"\n");
1056    }
1057
1058    #[test]
1059    fn test_copy_binary_multiple_rows() {
1060        let schema = Arc::new(vec![
1061            FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Binary),
1062            FieldInfo::new(
1063                "name".into(),
1064                None,
1065                None,
1066                Type::VARCHAR,
1067                FieldFormat::Binary,
1068            ),
1069        ]);
1070        let mut encoder = CopyEncoder::new_binary(schema);
1071
1072        // First row
1073        encoder.encode_field(&1i32).unwrap();
1074        encoder.encode_field(&"Alice".to_string()).unwrap();
1075        let copy_data1 = encoder.take_copy();
1076
1077        // Second row
1078        encoder.encode_field(&2i32).unwrap();
1079        encoder.encode_field(&"Bob".to_string()).unwrap();
1080        let copy_data2 = encoder.take_copy();
1081
1082        // Verify first row format
1083        let data1 = copy_data1.data.as_ref();
1084
1085        // Header is 19 bytes, then field count (2 bytes)
1086        assert_eq!(&data1[19..21], &[0x00, 0x02]); // 2 fields
1087
1088        // Verify second row format
1089        let data2 = copy_data2.data.as_ref();
1090
1091        // Field count should be at the beginning (no header on second row)
1092        assert_eq!(&data2[0..2], &[0x00, 0x02]); // 2 fields
1093    }
1094}