datafusion_sqllogictest/engines/postgres_engine/
mod.rs1use async_trait::async_trait;
19use bigdecimal::BigDecimal;
20use bytes::Bytes;
21use datafusion::common::runtime::SpawnedTask;
22use futures::{SinkExt, StreamExt};
23use log::{debug, info};
24use sqllogictest::DBOutput;
25use std::path::{Path, PathBuf};
27use std::str::FromStr;
28use std::time::Duration;
29
30use super::conversion::*;
31use crate::engines::currently_executed_sql::CurrentlyExecutingSqlTracker;
32use crate::engines::output::{DFColumnType, DFOutput};
33use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
34use indicatif::ProgressBar;
35use postgres_types::Type;
36use tokio::time::Instant;
37use tokio_postgres::{SimpleQueryMessage, SimpleQueryRow};
38
/// Fallback connection string, used when the `PG_URI` environment variable is not set.
const PG_URI: &str = "postgresql://postgres@127.0.0.1/test";
41
42#[derive(Debug, thiserror::Error)]
44pub enum Error {
45 #[error("Postgres error: {0}")]
46 Postgres(#[from] tokio_postgres::error::Error),
47 #[error("Error handling copy command: {0}")]
48 Copy(String),
49}
50
51pub type Result<T, E = Error> = std::result::Result<T, E>;
52
53pub struct Postgres {
54 client: Option<tokio_postgres::Client>,
56 spawned_task: Option<SpawnedTask<()>>,
57 relative_path: PathBuf,
59 pb: ProgressBar,
60 currently_executing_sql_tracker: CurrentlyExecutingSqlTracker,
61}
62
63impl Postgres {
64 pub async fn connect(relative_path: PathBuf, pb: ProgressBar) -> Result<Self> {
78 let uri = std::env::var("PG_URI")
79 .map_or_else(|_| PG_URI.to_string(), std::convert::identity);
80
81 info!("Using postgres connection string: {uri}");
82
83 let config = tokio_postgres::Config::from_str(&uri)?;
84
85 let res = config.connect(tokio_postgres::NoTls).await;
87 if res.is_err() {
88 eprintln!("Error connecting to postgres using PG_URI={uri}");
89 };
90
91 let (client, connection) = res?;
92
93 let spawned_task = SpawnedTask::spawn(async move {
94 if let Err(e) = connection.await {
95 log::error!("Postgres connection error: {e:?}");
96 }
97 });
98
99 let schema = schema_name(&relative_path);
100
101 debug!("Creating new empty schema '{schema}'");
103 client
104 .execute(&format!("DROP SCHEMA IF EXISTS {schema} CASCADE"), &[])
105 .await?;
106
107 client
108 .execute(&format!("CREATE SCHEMA {schema}"), &[])
109 .await?;
110
111 client
112 .execute(&format!("SET search_path TO {schema}"), &[])
113 .await?;
114
115 Ok(Self {
116 client: Some(client),
117 spawned_task: Some(spawned_task),
118 relative_path,
119 pb,
120 currently_executing_sql_tracker: CurrentlyExecutingSqlTracker::default(),
121 })
122 }
123
124 pub async fn connect_with_tracked_sql(
127 relative_path: PathBuf,
128 pb: ProgressBar,
129 currently_executing_sql_tracker: CurrentlyExecutingSqlTracker,
130 ) -> Result<Self> {
131 let conn = Self::connect(relative_path, pb).await?;
132 Ok(conn.with_currently_executing_sql_tracker(currently_executing_sql_tracker))
133 }
134
135 pub fn with_currently_executing_sql_tracker(
139 self,
140 currently_executing_sql_tracker: CurrentlyExecutingSqlTracker,
141 ) -> Self {
142 Self {
143 currently_executing_sql_tracker,
144 ..self
145 }
146 }
147
148 fn get_client(&mut self) -> &mut tokio_postgres::Client {
149 self.client.as_mut().expect("client is shutdown")
150 }
151
152 async fn run_copy_command(&mut self, sql: &str) -> Result<DFOutput> {
170 let canonical_sql = sql.trim_start().to_ascii_lowercase();
171
172 debug!("Handling COPY command: {sql}");
173
174 let mut tokens = canonical_sql.split_whitespace().peekable();
176 let mut filename = None;
177
178 let mut new_sql = vec![];
185 while let Some(tok) = tokens.next() {
186 new_sql.push(tok);
187 if tok == "from" {
189 filename = tokens.next();
190 new_sql.push("STDIN");
191 }
192 }
193
194 let filename = filename.map(no_quotes).ok_or_else(|| {
195 Error::Copy(format!("Can not find filename in COPY: {sql}"))
196 })?;
197
198 let new_sql = new_sql.join(" ");
199 debug!("Copying data from file {filename} using sql: {new_sql}");
200
201 let tx = self.get_client().transaction().await?;
203 let sink = tx.copy_in(&new_sql).await?;
204 let mut sink = Box::pin(sink);
205
206 let data = std::fs::read_to_string(filename)
208 .map_err(|e| Error::Copy(format!("Error reading {filename}: {e}")))?;
209
210 let mut data_stream = futures::stream::iter(vec![Ok(Bytes::from(data))]).boxed();
211
212 sink.send_all(&mut data_stream).await?;
213 sink.close().await?;
214 tx.commit().await?;
215 Ok(DBOutput::StatementComplete(0))
216 }
217
218 fn update_slow_count(&self) {
219 let msg = self.pb.message();
220 let split: Vec<&str> = msg.split(" ").collect();
221 let mut current_count = 0;
222
223 if split.len() > 2 {
224 current_count += split[2].parse::<i32>().unwrap();
226 }
227
228 current_count += 1;
229
230 self.pb
231 .set_message(format!("{} - {} took > 500 ms", split[0], current_count));
232 }
233}
234
/// Strips all leading and trailing single quotes, e.g. `'data.csv'` → `data.csv`.
fn no_quotes(t: &str) -> &str {
    t.trim_matches('\'')
}
241
/// Derives a Postgres schema name from a test file path.
///
/// Only ASCII alphanumeric characters are kept, e.g. `foo/bar.slt` →
/// `foobarslt`. This makes the name safe to interpolate unquoted into SQL.
///
/// An unquoted identifier must start with a letter or `_`. A name that is
/// empty or starts with a digit is therefore prefixed with `_`; otherwise
/// `CREATE SCHEMA` would fail with a syntax error.
fn schema_name(relative_path: &Path) -> String {
    // Note: a former `.trim_start_matches("pg_")` step was dead code, because
    // `_` never survives this filter.
    let name: String = relative_path
        .to_string_lossy()
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .collect();

    if name.starts_with(|c: char| c.is_ascii_alphabetic()) {
        name
    } else {
        format!("_{name}")
    }
}
253
254#[async_trait]
255impl sqllogictest::AsyncDB for Postgres {
256 type Error = Error;
257 type ColumnType = DFColumnType;
258
259 async fn run(
260 &mut self,
261 sql: &str,
262 ) -> Result<DBOutput<Self::ColumnType>, Self::Error> {
263 debug!(
264 "[{}] Running query: \"{}\"",
265 self.relative_path.display(),
266 sql
267 );
268
269 let tracked_sql = self.currently_executing_sql_tracker.set_sql(sql);
270
271 let lower_sql = sql.trim_start().to_ascii_lowercase();
272
273 let is_query_sql = {
274 lower_sql.starts_with("select")
275 || lower_sql.starts_with("values")
276 || lower_sql.starts_with("show")
277 || lower_sql.starts_with("with")
278 || lower_sql.starts_with("describe")
279 || ((lower_sql.starts_with("insert")
280 || lower_sql.starts_with("update")
281 || lower_sql.starts_with("delete"))
282 && lower_sql.contains("returning"))
283 };
284
285 if lower_sql.starts_with("copy") {
286 self.pb.inc(1);
287 let result = self.run_copy_command(sql).await;
288 self.currently_executing_sql_tracker.remove_sql(tracked_sql);
289
290 return result;
291 }
292
293 if !is_query_sql {
294 self.get_client().execute(sql, &[]).await?;
295 self.currently_executing_sql_tracker.remove_sql(tracked_sql);
296 self.pb.inc(1);
297 return Ok(DBOutput::StatementComplete(0));
298 }
299 let statement = self.get_client().prepare(sql).await?;
301 let types: Vec<Type> = statement
302 .columns()
303 .iter()
304 .map(|c| c.type_().clone())
305 .collect();
306
307 let start = Instant::now();
312 let messages = self.get_client().simple_query(sql).await?;
313 let duration = start.elapsed();
314
315 if duration.gt(&Duration::from_millis(500)) {
316 self.update_slow_count();
317 }
318
319 self.pb.inc(1);
320
321 self.currently_executing_sql_tracker.remove_sql(tracked_sql);
322
323 let rows = convert_rows(&types, &messages);
324
325 if rows.is_empty() && types.is_empty() {
326 Ok(DBOutput::StatementComplete(0))
327 } else {
328 Ok(DBOutput::Rows {
329 types: convert_types(types),
330 rows,
331 })
332 }
333 }
334
335 fn engine_name(&self) -> &str {
336 "postgres"
337 }
338
339 async fn shutdown(&mut self) {
340 if let Some(client) = self.client.take() {
341 drop(client);
342 }
343 if let Some(spawned_task) = self.spawned_task.take() {
344 spawned_task.join().await.ok();
345 }
346 }
347}
348
349fn convert_rows(types: &[Type], messages: &[SimpleQueryMessage]) -> Vec<Vec<String>> {
350 messages
351 .iter()
352 .filter_map(|message| match message {
353 SimpleQueryMessage::Row(row) => Some(row),
354 _ => None,
355 })
356 .map(|row| {
357 types
358 .iter()
359 .enumerate()
360 .map(|(idx, column_type)| cell_to_string(row, column_type, idx))
361 .collect::<Vec<String>>()
362 })
363 .collect::<Vec<_>>()
364}
365
366fn cell_to_string(row: &SimpleQueryRow, column_type: &Type, idx: usize) -> String {
367 let value = row.get(idx);
370 match (column_type, value) {
371 (_, None) => NULL_STR.to_string(),
372 (&Type::CHAR, Some(value)) => value
373 .as_bytes()
374 .first()
375 .map(|byte| (*byte as i8).to_string())
376 .unwrap_or_else(|| NULL_STR.to_string()),
377 (&Type::INT2, Some(value)) => value.parse::<i16>().unwrap().to_string(),
378 (&Type::INT4, Some(value)) => value.parse::<i32>().unwrap().to_string(),
379 (&Type::INT8, Some(value)) => value.parse::<i64>().unwrap().to_string(),
380 (&Type::NUMERIC, Some(value)) => {
381 decimal_to_str(BigDecimal::from_str(value).unwrap())
382 }
383 (&Type::DATE, Some(value)) => NaiveDate::parse_from_str(value, "%Y-%m-%d")
385 .unwrap()
386 .to_string(),
387 (&Type::TIME, Some(value)) => NaiveTime::parse_from_str(value, "%H:%M:%S%.f")
388 .unwrap()
389 .to_string(),
390 (&Type::TIMESTAMP, Some(value)) => {
391 let parsed = NaiveDateTime::parse_from_str(value, "%Y-%m-%d %H:%M:%S%.f")
392 .or_else(|_| NaiveDateTime::parse_from_str(value, "%Y-%m-%dT%H:%M:%S%.f"))
393 .unwrap();
394 format!("{parsed:?}")
395 }
396 (&Type::BOOL, Some(value)) => {
397 let parsed = match value {
398 "t" | "true" | "TRUE" => true,
399 "f" | "false" | "FALSE" => false,
400 _ => panic!("Unsupported boolean value: {value}"),
401 };
402 bool_to_str(parsed)
403 }
404 (&Type::BPCHAR | &Type::VARCHAR | &Type::TEXT, Some(value)) => {
405 varchar_to_str(value)
406 }
407 (&Type::FLOAT4, Some(value)) => f32_to_str(value.parse::<f32>().unwrap()),
408 (&Type::FLOAT8, Some(value)) => f64_to_str(value.parse::<f64>().unwrap()),
409 (&Type::REGTYPE, Some(value)) => value.to_string(),
410 _ => unimplemented!("Unsupported type: {}", column_type.name()),
411 }
412}
413
414fn convert_types(types: Vec<Type>) -> Vec<DFColumnType> {
415 types
416 .into_iter()
417 .map(|t| match t {
418 Type::BOOL => DFColumnType::Boolean,
419 Type::INT2 | Type::INT4 | Type::INT8 => DFColumnType::Integer,
420 Type::BPCHAR | Type::VARCHAR | Type::TEXT => DFColumnType::Text,
421 Type::FLOAT4 | Type::FLOAT8 | Type::NUMERIC => DFColumnType::Float,
422 Type::DATE | Type::TIME => DFColumnType::DateTime,
423 Type::TIMESTAMP => DFColumnType::Timestamp,
424 _ => DFColumnType::Another,
425 })
426 .collect()
427}