1use std::fmt::Debug;
2use std::pin::Pin;
3use std::sync::{Arc, LazyLock};
4
5use bytes::{BufMut, BytesMut};
6use futures::Stream;
7use postgres_types::{IsNull, Oid, ToSql, Type};
8
9use crate::error::{ErrorInfo, PgWireResult};
10use crate::messages::data::{
11 DataRow, FieldDescription, RowDescription, FORMAT_CODE_BINARY, FORMAT_CODE_TEXT,
12};
13use crate::messages::response::CommandComplete;
14use crate::types::format::FormatOptions;
15use crate::types::ToSqlText;
16
/// Command tag sent back to the client in a `CommandComplete` message.
///
/// Created with [`Tag::new`] and optionally extended with a row count
/// ([`Tag::with_rows`]) and an OID ([`Tag::with_oid`]).
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct Tag {
    // Command name (e.g. "INSERT"); always the first word of the tag text.
    command: String,
    // Optional OID; only rendered when `rows` is also set.
    oid: Option<Oid>,
    // Optional row count, rendered as the last word of the tag text.
    rows: Option<usize>,
}
23
24impl Tag {
25 pub fn new(command: &str) -> Tag {
26 Tag {
27 command: command.to_owned(),
28 oid: None,
29 rows: None,
30 }
31 }
32
33 pub fn with_rows(mut self, rows: usize) -> Tag {
34 self.rows = Some(rows);
35 self
36 }
37
38 pub fn with_oid(mut self, oid: Oid) -> Tag {
39 self.oid = Some(oid);
40 self
41 }
42}
43
44impl From<Tag> for CommandComplete {
45 fn from(tag: Tag) -> CommandComplete {
46 let tag_string = if let (Some(oid), Some(rows)) = (tag.oid, tag.rows) {
47 format!("{} {oid} {rows}", tag.command)
48 } else if let Some(rows) = tag.rows {
49 format!("{} {rows}", tag.command)
50 } else {
51 tag.command
52 };
53 CommandComplete::new(tag_string)
54 }
55}
56
/// Encoding used for a column or parameter value on the wire.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum FieldFormat {
    /// Text encoding (format code [`FORMAT_CODE_TEXT`]).
    Text,
    /// Binary encoding (format code [`FORMAT_CODE_BINARY`]).
    Binary,
}
63
64impl FieldFormat {
65 pub fn value(&self) -> i16 {
67 match self {
68 Self::Text => FORMAT_CODE_TEXT,
69 Self::Binary => FORMAT_CODE_BINARY,
70 }
71 }
72
73 pub fn from(code: i16) -> Self {
78 if code == FORMAT_CODE_BINARY {
79 FieldFormat::Binary
80 } else {
81 FieldFormat::Text
82 }
83 }
84}
85
86thread_local! {
99 static DEFAULT_FORMAT_OPTIONS: LazyLock<Arc<FormatOptions>> = LazyLock::new(Default::default);
100}
101
/// Description of a single result column.
///
/// The `#[derive(new)]` constructor takes every field except `format_options`.
/// That field is filled from the per-thread `DEFAULT_FORMAT_OPTIONS`; use
/// [`FieldInfo::with_format_options`] to override it.
#[derive(Debug, new, Eq, PartialEq, Clone)]
pub struct FieldInfo {
    name: String,
    // Table id reported in the row description; sent as 0 when `None`.
    table_id: Option<i32>,
    // Column number within that table; sent as 0 when `None`.
    column_id: Option<i16>,
    datatype: Type,
    format: FieldFormat,
    #[new(value = "DEFAULT_FORMAT_OPTIONS.with(|opts| Arc::clone(&*opts))")]
    format_options: Arc<FormatOptions>,
}
112
113impl FieldInfo {
114 pub fn name(&self) -> &str {
115 &self.name
116 }
117
118 pub fn table_id(&self) -> Option<i32> {
119 self.table_id
120 }
121
122 pub fn column_id(&self) -> Option<i16> {
123 self.column_id
124 }
125
126 pub fn datatype(&self) -> &Type {
127 &self.datatype
128 }
129
130 pub fn format(&self) -> FieldFormat {
131 self.format
132 }
133
134 pub fn format_options(&self) -> &Arc<FormatOptions> {
135 &self.format_options
136 }
137
138 pub fn with_format_options(mut self, format_options: Arc<FormatOptions>) -> Self {
139 self.format_options = format_options;
140 self
141 }
142}
143
144impl From<&FieldInfo> for FieldDescription {
145 fn from(fi: &FieldInfo) -> Self {
146 FieldDescription::new(
147 fi.name.clone(), fi.table_id.unwrap_or(0), fi.column_id.unwrap_or(0), fi.datatype.oid(), 0,
153 0,
154 fi.format.value(),
155 )
156 }
157}
158
159impl From<FieldDescription> for FieldInfo {
160 fn from(value: FieldDescription) -> Self {
161 FieldInfo::new(
162 value.name,
163 Some(value.table_id),
164 Some(value.column_id),
165 Type::from_oid(value.type_id).unwrap_or(Type::UNKNOWN),
166 FieldFormat::from(value.format_code),
167 )
168 }
169}
170
171pub(crate) fn into_row_description(fields: &[FieldInfo]) -> RowDescription {
172 RowDescription::new(fields.iter().map(Into::into).collect())
173}
174
/// Type-erased, pinned, `Send` stream of encoded data rows (or errors), as
/// stored in a [`QueryResponse`].
pub type SendableRowStream = Pin<Box<dyn Stream<Item = PgWireResult<DataRow>> + Send>>;
176
/// A result set: its command tag, its column schema, and a stream of rows.
pub struct QueryResponse {
    // Tag reported on completion; `QueryResponse::new` defaults it to "SELECT".
    command_tag: String,
    // Shared column definitions for every row in `data_rows`.
    row_schema: Arc<Vec<FieldInfo>>,
    // Rows (or errors) produced lazily by the stream.
    data_rows: SendableRowStream,
}
182
183impl Debug for QueryResponse {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 f.debug_struct("QueryResponse")
186 .field("command_tag", &self.command_tag)
187 .field("row_schema", &self.row_schema)
188 .finish()
189 }
190}
191
192impl QueryResponse {
193 pub fn new<S>(field_defs: Arc<Vec<FieldInfo>>, row_stream: S) -> QueryResponse
196 where
197 S: Stream<Item = PgWireResult<DataRow>> + Send + 'static,
198 {
199 QueryResponse {
200 command_tag: "SELECT".to_owned(),
201 row_schema: field_defs,
202 data_rows: Box::pin(row_stream),
203 }
204 }
205
206 pub fn command_tag(&self) -> &str {
208 &self.command_tag
209 }
210
211 pub fn set_command_tag(&mut self, command_tag: &str) {
213 command_tag.clone_into(&mut self.command_tag);
214 }
215
216 pub fn row_schema(&self) -> Arc<Vec<FieldInfo>> {
218 self.row_schema.clone()
219 }
220
221 pub fn data_rows(&mut self) -> &mut SendableRowStream {
223 &mut self.data_rows
224 }
225}
226
/// Encodes column values one by one into a [`DataRow`], following a schema.
///
/// The encoder is reusable: [`DataRowEncoder::take_row`] returns the current
/// row and resets the column cursor for the next one.
pub struct DataRowEncoder {
    // Column definitions; `encode_field` reads type/format from `schema[col_index]`.
    schema: Arc<Vec<FieldInfo>>,
    // Length-prefixed column values of the row being built.
    row_buffer: BytesMut,
    // Number of columns encoded so far in the current row.
    col_index: usize,
}
232
233impl DataRowEncoder {
234 pub fn new(fields: Arc<Vec<FieldInfo>>) -> DataRowEncoder {
236 Self {
237 schema: fields,
238 row_buffer: BytesMut::with_capacity(128),
239 col_index: 0,
240 }
241 }
242
243 pub fn encode_field_with_type_and_format<T>(
248 &mut self,
249 value: &T,
250 data_type: &Type,
251 format: FieldFormat,
252 format_options: &FormatOptions,
253 ) -> PgWireResult<()>
254 where
255 T: ToSql + ToSqlText + Sized,
256 {
257 let prev_index = self.row_buffer.len();
259 self.row_buffer.put_i32(-1);
261
262 let is_null = if format == FieldFormat::Text {
263 value.to_sql_text(data_type, &mut self.row_buffer, format_options)?
264 } else {
265 value.to_sql(data_type, &mut self.row_buffer)?
266 };
267
268 if let IsNull::No = is_null {
269 let value_length = self.row_buffer.len() - prev_index - 4;
270 let mut length_bytes = &mut self.row_buffer[prev_index..(prev_index + 4)];
271 length_bytes.put_i32(value_length as i32);
272 }
273
274 self.col_index += 1;
275
276 Ok(())
277 }
278
279 pub fn encode_field<T>(&mut self, value: &T) -> PgWireResult<()>
283 where
284 T: ToSql + ToSqlText + Sized,
285 {
286 let field = &self.schema[self.col_index];
287
288 let data_type = field.datatype().clone();
289 let format = field.format();
290 let format_options = field.format_options().clone();
291
292 self.encode_field_with_type_and_format(value, &data_type, format, format_options.as_ref())
293 }
294
295 #[deprecated(
296 since = "0.37.0",
297 note = "DataRowEncoder is reusable since 0.37, use `take_row() instead`"
298 )]
299 pub fn finish(self) -> PgWireResult<DataRow> {
300 Ok(DataRow::new(self.row_buffer, self.col_index as i16))
301 }
302
303 pub fn take_row(&mut self) -> DataRow {
308 let row = DataRow::new(self.row_buffer.split(), self.col_index as i16);
309 self.col_index = 0;
310 row
311 }
312}
313
/// Common interface of [`DescribeStatementResponse`] and
/// [`DescribePortalResponse`].
pub trait DescribeResponse {
    /// Parameter types, or `None` if this kind of response carries none.
    fn parameters(&self) -> Option<&[Type]>;

    /// Result column descriptions.
    fn fields(&self) -> &[FieldInfo];

    /// Builds an empty response.
    fn no_data() -> Self;

    /// Returns `true` if the response describes nothing.
    fn is_no_data(&self) -> bool;
}
327
/// Description of a prepared statement: its parameter types and result columns.
#[non_exhaustive]
#[derive(Debug, new)]
pub struct DescribeStatementResponse {
    pub parameters: Vec<Type>,
    pub fields: Vec<FieldInfo>,
}
335
336impl DescribeResponse for DescribeStatementResponse {
337 fn parameters(&self) -> Option<&[Type]> {
338 Some(self.parameters.as_ref())
339 }
340
341 fn fields(&self) -> &[FieldInfo] {
342 &self.fields
343 }
344
345 fn no_data() -> Self {
348 DescribeStatementResponse {
349 parameters: vec![],
350 fields: vec![],
351 }
352 }
353
354 fn is_no_data(&self) -> bool {
356 self.parameters.is_empty() && self.fields.is_empty()
357 }
358}
359
/// Description of a portal: its result columns only.
#[non_exhaustive]
#[derive(Debug, new)]
pub struct DescribePortalResponse {
    pub fields: Vec<FieldInfo>,
}
366
367impl DescribeResponse for DescribePortalResponse {
368 fn parameters(&self) -> Option<&[Type]> {
369 None
370 }
371
372 fn fields(&self) -> &[FieldInfo] {
373 &self.fields
374 }
375
376 fn no_data() -> Self {
379 DescribePortalResponse { fields: vec![] }
380 }
381
382 fn is_no_data(&self) -> bool {
384 self.fields.is_empty()
385 }
386}
387
/// Parameters of a COPY sub-protocol response.
#[non_exhaustive]
#[derive(Debug, new)]
pub struct CopyResponse {
    // NOTE(review): presumably the overall COPY format (0 = text, 1 = binary
    // per the PostgreSQL protocol) — confirm.
    pub format: i8,
    // Number of columns being copied.
    pub columns: usize,
    // Per-column format codes; presumably one entry per column — confirm.
    pub column_formats: Vec<i16>,
}
396
/// Possible outcomes of handling a query.
#[derive(Debug)]
pub enum Response {
    /// The query string was empty.
    EmptyQuery,
    /// A result set with its schema and row stream.
    Query(QueryResponse),
    /// A command that completed without a result set, with its tag.
    Execution(Tag),
    /// A statement that started a transaction, with its tag.
    TransactionStart(Tag),
    /// A statement that ended a transaction, with its tag.
    TransactionEnd(Tag),
    /// An error; boxed, presumably to keep `Response` small.
    Error(Box<ErrorInfo>),
    /// Start of a COPY FROM STDIN-style transfer (client to server) — confirm.
    CopyIn(CopyResponse),
    /// Start of a COPY TO STDOUT-style transfer (server to client) — confirm.
    CopyOut(CopyResponse),
    /// Start of a bidirectional COPY transfer — confirm.
    CopyBoth(CopyResponse),
}
420
#[cfg(test)]
mod test {

    use super::*;

    #[test]
    fn test_command_complete() {
        // Row count only.
        let rows_only: CommandComplete = Tag::new("INSERT").with_rows(100).into();
        assert_eq!(rows_only.tag, "INSERT 100");

        // OID followed by row count.
        let with_oid: CommandComplete = Tag::new("INSERT").with_oid(0).with_rows(100).into();
        assert_eq!(with_oid.tag, "INSERT 0 100");
    }

    #[test]
    #[cfg(feature = "pg-type-chrono")]
    fn test_data_row_encoder() {
        use std::time::SystemTime;

        let text_field = |name: &str, ty: Type| {
            FieldInfo::new(name.to_owned(), None, None, ty, FieldFormat::Text)
        };
        let schema = Arc::new(vec![
            text_field("id", Type::INT4),
            text_field("name", Type::VARCHAR),
            text_field("ts", Type::TIMESTAMP),
        ]);

        let now = SystemTime::now();
        let mut encoder = DataRowEncoder::new(schema);
        encoder.encode_field(&2001).unwrap();
        encoder.encode_field(&"udev").unwrap();
        encoder.encode_field(&now).unwrap();

        let row = encoder.take_row();
        assert_eq!(row.field_count, 3);

        // Each text column is a 4-byte length followed by its text bytes.
        let mut expected = BytesMut::new();
        for text in ["2001", "udev"] {
            expected.put_i32(text.len() as i32);
            expected.put_slice(text.as_bytes());
        }
        expected.put_i32(26);
        let _ = now.to_sql_text(&Type::TIMESTAMP, &mut expected, &FormatOptions::default());

        assert_eq!(row.data, expected);
    }
}