Skip to main content

datafusion_sqllogictest/engines/postgres_engine/
mod.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use async_trait::async_trait;
19use bigdecimal::BigDecimal;
20use bytes::Bytes;
21use datafusion::common::runtime::SpawnedTask;
22use futures::{SinkExt, StreamExt};
23use log::{debug, info};
24use sqllogictest::DBOutput;
25/// Postgres engine implementation for sqllogictest.
26use std::path::{Path, PathBuf};
27use std::str::FromStr;
28use std::time::Duration;
29
30use super::conversion::*;
31use crate::engines::currently_executed_sql::CurrentlyExecutingSqlTracker;
32use crate::engines::output::{DFColumnType, DFOutput};
33use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
34use indicatif::ProgressBar;
35use postgres_types::Type;
36use tokio::time::Instant;
37use tokio_postgres::{SimpleQueryMessage, SimpleQueryRow};
38
/// Default connection string; can be overridden by the `PG_URI` environment variable.
const PG_URI: &str = "postgresql://postgres@127.0.0.1/test";
41
42/// DataFusion sql-logicaltest error
43#[derive(Debug, thiserror::Error)]
44pub enum Error {
45    #[error("Postgres error: {0}")]
46    Postgres(#[from] tokio_postgres::error::Error),
47    #[error("Error handling copy command: {0}")]
48    Copy(String),
49}
50
51pub type Result<T, E = Error> = std::result::Result<T, E>;
52
53pub struct Postgres {
54    // None means the connection has been shutdown
55    client: Option<tokio_postgres::Client>,
56    spawned_task: Option<SpawnedTask<()>>,
57    /// Relative test file path
58    relative_path: PathBuf,
59    pb: ProgressBar,
60    currently_executing_sql_tracker: CurrentlyExecutingSqlTracker,
61}
62
63impl Postgres {
64    /// Creates a runner for executing queries against an existing postgres connection.
65    /// `relative_path` is used for display output and to create a postgres schema.
66    ///
67    /// The database connection details can be overridden by the
68    /// `PG_URI` environment variable.
69    ///
70    /// This defaults to
71    ///
72    /// ```text
73    /// PG_URI="postgresql://postgres@127.0.0.1/test"
74    /// ```
75    ///
76    /// See https://docs.rs/tokio-postgres/latest/tokio_postgres/config/struct.Config.html#url for format
77    pub async fn connect(relative_path: PathBuf, pb: ProgressBar) -> Result<Self> {
78        let uri = std::env::var("PG_URI")
79            .map_or_else(|_| PG_URI.to_string(), std::convert::identity);
80
81        info!("Using postgres connection string: {uri}");
82
83        let config = tokio_postgres::Config::from_str(&uri)?;
84
85        // hint to user what the connection string was
86        let res = config.connect(tokio_postgres::NoTls).await;
87        if res.is_err() {
88            eprintln!("Error connecting to postgres using PG_URI={uri}");
89        };
90
91        let (client, connection) = res?;
92
93        let spawned_task = SpawnedTask::spawn(async move {
94            if let Err(e) = connection.await {
95                log::error!("Postgres connection error: {e:?}");
96            }
97        });
98
99        let schema = schema_name(&relative_path);
100
101        // create a new clean schema for running the test
102        debug!("Creating new empty schema '{schema}'");
103        client
104            .execute(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"), &[])
105            .await?;
106
107        client
108            .execute(&format!("CREATE SCHEMA {schema}"), &[])
109            .await?;
110
111        client
112            .execute(&format!("SET search_path TO {schema}"), &[])
113            .await?;
114
115        Ok(Self {
116            client: Some(client),
117            spawned_task: Some(spawned_task),
118            relative_path,
119            pb,
120            currently_executing_sql_tracker: CurrentlyExecutingSqlTracker::default(),
121        })
122    }
123
124    /// Creates a runner for executing queries against an existing postgres connection
125    /// with a tracker for currently executing SQL statements.
126    pub async fn connect_with_tracked_sql(
127        relative_path: PathBuf,
128        pb: ProgressBar,
129        currently_executing_sql_tracker: CurrentlyExecutingSqlTracker,
130    ) -> Result<Self> {
131        let conn = Self::connect(relative_path, pb).await?;
132        Ok(conn.with_currently_executing_sql_tracker(currently_executing_sql_tracker))
133    }
134
135    /// Add a tracker that will track the currently executed SQL statement.
136    ///
137    /// This is useful for logging and debugging purposes.
138    pub fn with_currently_executing_sql_tracker(
139        self,
140        currently_executing_sql_tracker: CurrentlyExecutingSqlTracker,
141    ) -> Self {
142        Self {
143            currently_executing_sql_tracker,
144            ..self
145        }
146    }
147
148    fn get_client(&mut self) -> &mut tokio_postgres::Client {
149        self.client.as_mut().expect("client is shutdown")
150    }
151
152    /// Special COPY command support. "COPY 'filename'" requires the
153    /// server to read the file which may not be possible (maybe it is
154    /// remote or running in some other docker container).
155    ///
156    /// Thus, we rewrite  sql statements like
157    ///
158    /// ```sql
159    /// COPY ... FROM 'filename' ...
160    /// ```
161    ///
162    /// Into
163    ///
164    /// ```sql
165    /// COPY ... FROM STDIN ...
166    /// ```
167    ///
168    /// And read the file locally.
169    async fn run_copy_command(&mut self, sql: &str) -> Result<DFOutput> {
170        let canonical_sql = sql.trim_start().to_ascii_lowercase();
171
172        debug!("Handling COPY command: {sql}");
173
174        // Hacky way to  find the 'filename' in the statement
175        let mut tokens = canonical_sql.split_whitespace().peekable();
176        let mut filename = None;
177
178        // COPY FROM '/opt/data/csv/aggregate_test_100.csv' ...
179        //
180        // into
181        //
182        // COPY FROM STDIN ...
183
184        let mut new_sql = vec![];
185        while let Some(tok) = tokens.next() {
186            new_sql.push(tok);
187            // rewrite FROM <file> to FROM STDIN
188            if tok == "from" {
189                filename = tokens.next();
190                new_sql.push("STDIN");
191            }
192        }
193
194        let filename = filename.map(no_quotes).ok_or_else(|| {
195            Error::Copy(format!("Can not find filename in COPY: {sql}"))
196        })?;
197
198        let new_sql = new_sql.join(" ");
199        debug!("Copying data from file {filename} using sql: {new_sql}");
200
201        // start the COPY command and get location to write data to
202        let tx = self.get_client().transaction().await?;
203        let sink = tx.copy_in(&new_sql).await?;
204        let mut sink = Box::pin(sink);
205
206        // read the input file as a string ans feed it to the copy command
207        let data = std::fs::read_to_string(filename)
208            .map_err(|e| Error::Copy(format!("Error reading {filename}: {e}")))?;
209
210        let mut data_stream = futures::stream::iter(vec![Ok(Bytes::from(data))]).boxed();
211
212        sink.send_all(&mut data_stream).await?;
213        sink.close().await?;
214        tx.commit().await?;
215        Ok(DBOutput::StatementComplete(0))
216    }
217
218    fn update_slow_count(&self) {
219        let msg = self.pb.message();
220        let split: Vec<&str> = msg.split(" ").collect();
221        let mut current_count = 0;
222
223        if split.len() > 2 {
224            // second match will be current slow count
225            current_count += split[2].parse::<i32>().unwrap();
226        }
227
228        current_count += 1;
229
230        self.pb
231            .set_message(format!("{} - {} took > 500 ms", split[0], current_count));
232    }
233}
234
/// Strips every leading and trailing single quote from `t`; quotes in the
/// middle of the string are left alone.
///
/// 'filename' --> filename
fn no_quotes(t: &str) -> &str {
    t.trim_matches('\'')
}
241
/// Derives a postgres schema name from a test file path by keeping only
/// the ASCII alphanumeric characters of the path.
///
/// For example `pg_compat_foo.slt` becomes `pgcompatfooslt`.
///
/// Underscores are removed along with every other non-alphanumeric
/// character, so the result can never start with `pg_`; no separate prefix
/// stripping is needed.
fn schema_name(relative_path: &Path) -> String {
    relative_path
        .to_string_lossy()
        .chars()
        .filter(|ch| ch.is_ascii_alphanumeric())
        .collect()
}
253
254#[async_trait]
255impl sqllogictest::AsyncDB for Postgres {
256    type Error = Error;
257    type ColumnType = DFColumnType;
258
259    async fn run(
260        &mut self,
261        sql: &str,
262    ) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
263        debug!(
264            "[{}] Running query: \"{}\"",
265            self.relative_path.display(),
266            sql
267        );
268
269        let tracked_sql = self.currently_executing_sql_tracker.set_sql(sql);
270
271        let lower_sql = sql.trim_start().to_ascii_lowercase();
272
273        let is_query_sql = {
274            lower_sql.starts_with("select")
275                || lower_sql.starts_with("values")
276                || lower_sql.starts_with("show")
277                || lower_sql.starts_with("with")
278                || lower_sql.starts_with("describe")
279                || ((lower_sql.starts_with("insert")
280                    || lower_sql.starts_with("update")
281                    || lower_sql.starts_with("delete"))
282                    && lower_sql.contains("returning"))
283        };
284
285        if lower_sql.starts_with("copy") {
286            self.pb.inc(1);
287            let result = self.run_copy_command(sql).await;
288            self.currently_executing_sql_tracker.remove_sql(tracked_sql);
289
290            return result;
291        }
292
293        if !is_query_sql {
294            self.get_client().execute(sql, &[]).await?;
295            self.currently_executing_sql_tracker.remove_sql(tracked_sql);
296            self.pb.inc(1);
297            return Ok(DBOutput::StatementComplete(0));
298        }
299        // Use a prepared statement to get the output column types
300        let statement = self.get_client().prepare(sql).await?;
301        let types: Vec<Type> = statement
302            .columns()
303            .iter()
304            .map(|c| c.type_().clone())
305            .collect();
306
307        // Run the actual query using the "simple query" protocol that returns all
308        // rows as text. Doing this avoids having to convert values from the binary
309        // format to strings, which is somewhat tricky for numeric types.
310        // See https://github.com/apache/datafusion/pull/19666#discussion_r2668090587
311        let start = Instant::now();
312        let messages = self.get_client().simple_query(sql).await?;
313        let duration = start.elapsed();
314
315        if duration.gt(&Duration::from_millis(500)) {
316            self.update_slow_count();
317        }
318
319        self.pb.inc(1);
320
321        self.currently_executing_sql_tracker.remove_sql(tracked_sql);
322
323        let rows = convert_rows(&types, &messages);
324
325        if rows.is_empty() && types.is_empty() {
326            Ok(DBOutput::StatementComplete(0))
327        } else {
328            Ok(DBOutput::Rows {
329                types: convert_types(types),
330                rows,
331            })
332        }
333    }
334
335    fn engine_name(&self) -> &str {
336        "postgres"
337    }
338
339    async fn shutdown(&mut self) {
340        if let Some(client) = self.client.take() {
341            drop(client);
342        }
343        if let Some(spawned_task) = self.spawned_task.take() {
344            spawned_task.join().await.ok();
345        }
346    }
347}
348
349fn convert_rows(types: &[Type], messages: &[SimpleQueryMessage]) -> Vec<Vec<String>> {
350    messages
351        .iter()
352        .filter_map(|message| match message {
353            SimpleQueryMessage::Row(row) => Some(row),
354            _ => None,
355        })
356        .map(|row| {
357            types
358                .iter()
359                .enumerate()
360                .map(|(idx, column_type)| cell_to_string(row, column_type, idx))
361                .collect::<Vec<String>>()
362        })
363        .collect::<Vec<_>>()
364}
365
366fn cell_to_string(row: &SimpleQueryRow, column_type: &Type, idx: usize) -> String {
367    // simple_query returns text values, so we parse by Postgres type to keep
368    // normalization aligned with the DataFusion engine output.
369    let value = row.get(idx);
370    match (column_type, value) {
371        (_, None) => NULL_STR.to_string(),
372        (&Type::CHAR, Some(value)) => value
373            .as_bytes()
374            .first()
375            .map(|byte| (*byte as i8).to_string())
376            .unwrap_or_else(|| NULL_STR.to_string()),
377        (&Type::INT2, Some(value)) => value.parse::<i16>().unwrap().to_string(),
378        (&Type::INT4, Some(value)) => value.parse::<i32>().unwrap().to_string(),
379        (&Type::INT8, Some(value)) => value.parse::<i64>().unwrap().to_string(),
380        (&Type::NUMERIC, Some(value)) => {
381            decimal_to_str(BigDecimal::from_str(value).unwrap())
382        }
383        // Parse date/time strings explicitly to avoid locale-specific formatting.
384        (&Type::DATE, Some(value)) => NaiveDate::parse_from_str(value, "%Y-%m-%d")
385            .unwrap()
386            .to_string(),
387        (&Type::TIME, Some(value)) => NaiveTime::parse_from_str(value, "%H:%M:%S%.f")
388            .unwrap()
389            .to_string(),
390        (&Type::TIMESTAMP, Some(value)) => {
391            let parsed = NaiveDateTime::parse_from_str(value, "%Y-%m-%d %H:%M:%S%.f")
392                .or_else(|_| NaiveDateTime::parse_from_str(value, "%Y-%m-%dT%H:%M:%S%.f"))
393                .unwrap();
394            format!("{parsed:?}")
395        }
396        (&Type::BOOL, Some(value)) => {
397            let parsed = match value {
398                "t" | "true" | "TRUE" => true,
399                "f" | "false" | "FALSE" => false,
400                _ => panic!("Unsupported boolean value: {value}"),
401            };
402            bool_to_str(parsed)
403        }
404        (&Type::BPCHAR | &Type::VARCHAR | &Type::TEXT, Some(value)) => {
405            varchar_to_str(value)
406        }
407        (&Type::FLOAT4, Some(value)) => f32_to_str(value.parse::<f32>().unwrap()),
408        (&Type::FLOAT8, Some(value)) => f64_to_str(value.parse::<f64>().unwrap()),
409        (&Type::REGTYPE, Some(value)) => value.to_string(),
410        _ => unimplemented!("Unsupported type: {}", column_type.name()),
411    }
412}
413
414fn convert_types(types: Vec<Type>) -> Vec<DFColumnType> {
415    types
416        .into_iter()
417        .map(|t| match t {
418            Type::BOOL => DFColumnType::Boolean,
419            Type::INT2 | Type::INT4 | Type::INT8 => DFColumnType::Integer,
420            Type::BPCHAR | Type::VARCHAR | Type::TEXT => DFColumnType::Text,
421            Type::FLOAT4 | Type::FLOAT8 | Type::NUMERIC => DFColumnType::Float,
422            Type::DATE | Type::TIME => DFColumnType::DateTime,
423            Type::TIMESTAMP => DFColumnType::Timestamp,
424            _ => DFColumnType::Another,
425        })
426        .collect()
427}