#![cfg(feature = "pgwire")]
use std::net::{IpAddr, TcpListener};
use std::sync::Arc;
use std::time::Duration;
use arrow::array::{Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use tempfile::TempDir;
use datapress_core::config::{
AppConfig, DatasetConfig, IndexConfig, PgwireConfig, ServerConfig, SourceConfig, SourceKind,
};
use datapress_datafusion::pgwire::{serve_pgwire, spawn_pgwire};
use datapress_datafusion::store::Store;
/// Username placed in every `PgwireConfig` and client connection string in this file.
const USER: &str = "datapress";
/// Password configured on the server and sent by the client in the password-auth tests.
const PASSWORD: &str = "s3cr3t-pw";
/// Writes the `people` fixture to `path` as a single-batch Parquet file.
///
/// The file has two non-nullable columns, `id` (`Int64`) and `name` (`Utf8`),
/// and three rows: `(1, "Anna")`, `(3, "Cara")`, `(4, "Dan")`.
///
/// Panics if the batch cannot be built or the file cannot be created/written.
fn write_people_parquet(path: &std::path::Path) {
    let fields = vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let ids = Int64Array::from(vec![1_i64, 3, 4]);
    let names = StringArray::from(vec!["Anna", "Cara", "Dan"]);
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(ids), Arc::new(names)],
    )
    .unwrap();

    let out = std::fs::File::create(path).unwrap();
    let mut parquet_writer = ArrowWriter::try_new(out, schema, None).unwrap();
    parquet_writer.write(&batch).unwrap();
    // `close` finalises the file; its result is unwrapped so write errors are not lost.
    parquet_writer.close().unwrap();
}
/// Builds a [`Store`] that serves one Parquet-backed dataset named `people`.
///
/// `location` is the Parquet file location handed to the dataset's source
/// config. The dataset is configured with `dict_encode: true`, `lazy: false`
/// and no explicit columns; every other configuration section is left at its
/// `Default` value.
///
/// Panics with `"Store::load"` if the store fails to load.
async fn make_people_store(location: &str) -> Store {
    let people = DatasetConfig {
        name: "people".into(),
        source: SourceConfig {
            kind: SourceKind::Parquet,
            location: location.to_owned(),
        },
        s3: None,
        index: IndexConfig::default(),
        columns: Vec::new(),
        dict_encode: true,
        lazy: false,
        predicate_filter: Default::default(),
        projection_filter: Default::default(),
    };

    let app_config = AppConfig {
        server: ServerConfig::default(),
        docs: datapress_core::config::DocsConfig::default(),
        swagger: datapress_core::config::SwaggerConfig::default(),
        auth: datapress_core::config::AuthConfig::default(),
        metrics: datapress_core::config::MetricsConfig::default(),
        explorer: datapress_core::config::ExplorerConfig::default(),
        sql: datapress_core::config::SqlConfig::default(),
        datafusion: datapress_core::config::DataFusionConfig::default(),
        datasets: vec![people],
    };

    let loaded = Store::load(&app_config).await;
    loaded.expect("Store::load")
}
/// Asks the OS for an ephemeral TCP port on `127.0.0.1` and returns its number.
///
/// The probe listener is dropped before returning, so the port is only
/// *probably* still free when the caller binds it; the retry loops in the
/// tests tolerate a slow server start but not a stolen port.
///
/// Panics if the bind or the local-address lookup fails.
fn free_port() -> u16 {
    // Port 0 lets the OS pick any currently unused port.
    let probe = TcpListener::bind("127.0.0.1:0").unwrap();
    let bound = probe.local_addr().unwrap();
    bound.port()
}
/// Returns the IPv4 loopback address (`127.0.0.1`) used as the server listen address.
fn loopback() -> IpAddr {
    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
}
/// End-to-end test of the pgwire endpoint with password authentication and no TLS.
///
/// Starts `serve_pgwire` on a spawned task, connects with `tokio_postgres`
/// using the correct credentials, and checks:
/// - simple, ordered and parameterised queries against the `people` dataset,
/// - that `pg_catalog.pg_type` and `information_schema.tables` are queryable,
/// - that `current_schema()` reports `public`,
/// - that session-maintenance statements are accepted and leave the session usable,
/// - that a connection with a wrong password is rejected.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pgwire_password_auth_queries() {
    // The temp dir must outlive the store, which reads the Parquet file from it.
    let dir = TempDir::new().unwrap();
    let parquet = dir.path().join("people.parquet");
    write_people_parquet(&parquet);
    let store = make_people_store(parquet.to_str().unwrap()).await;

    let port = free_port();
    let cfg = PgwireConfig {
        enabled: true,
        listen: loopback(),
        port,
        username: USER.into(),
        password: Some(PASSWORD.into()),
        tls_cert: None,
        tls_key: None,
    };
    let ctx = store.session_context().clone();
    let server = tokio::spawn(async move {
        let _ = serve_pgwire(ctx, cfg).await;
    });

    let good = format!("host=127.0.0.1 port={port} user={USER} password={PASSWORD} dbname=datapress");
    // The server task starts asynchronously, so poll for up to 50 x 100 ms.
    // The last connection error is kept so a persistent failure (for example
    // an authentication or protocol error) is reported instead of being
    // hidden behind a generic "not reachable" message.
    let client = {
        let mut connected = None;
        let mut last_err = None;
        for _ in 0..50 {
            match tokio_postgres::connect(&good, tokio_postgres::NoTls).await {
                Ok((client, connection)) => {
                    // The connection future drives the socket; it must be polled
                    // for the client to make progress.
                    tokio::spawn(async move {
                        let _ = connection.await;
                    });
                    connected = Some(client);
                    break;
                }
                Err(e) => {
                    last_err = Some(e);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }
        connected.unwrap_or_else(|| {
            panic!("pgwire server did not become reachable; last error: {last_err:?}")
        })
    };

    // Aggregate over the dataset.
    let row = client
        .query_one("SELECT count(*) AS c FROM people", &[])
        .await
        .expect("count query");
    let c: i64 = row.get("c");
    assert_eq!(c, 3);

    // Ordered projection: the first row must be the smallest id.
    let rows = client
        .query("SELECT id, name FROM people ORDER BY id", &[])
        .await
        .expect("select query");
    assert_eq!(rows.len(), 3);
    let id0: i64 = rows[0].get("id");
    let name0: &str = rows[0].get("name");
    assert_eq!(id0, 1);
    assert_eq!(name0, "Anna");

    // Parameterised query with a bound `$1` value.
    let rows = client
        .query("SELECT name FROM people WHERE id = $1", &[&3_i64])
        .await
        .expect("prepared query");
    assert_eq!(rows.len(), 1);
    let name: &str = rows[0].get("name");
    assert_eq!(name, "Cara");

    // Catalog tables that PostgreSQL clients commonly introspect.
    let row = client
        .query_one("SELECT count(*) AS c FROM pg_catalog.pg_type", &[])
        .await
        .expect("pg_catalog.pg_type query");
    let pg_type_rows: i64 = row.get("c");
    assert!(
        pg_type_rows > 0,
        "pg_catalog.pg_type should be populated, got {pg_type_rows}"
    );

    let rows = client
        .query("SELECT table_name FROM information_schema.tables", &[])
        .await
        .expect("information_schema.tables query");
    let table_names: Vec<String> = rows
        .iter()
        .map(|r| r.get::<_, String>("table_name"))
        .collect();
    assert!(
        table_names.iter().any(|t| t == "people"),
        "information_schema.tables should include 'people', got {table_names:?}"
    );

    let row = client
        .query_one("SELECT current_schema() AS s", &[])
        .await
        .expect("current_schema query");
    let schema: &str = row.get("s");
    assert_eq!(schema, "public");

    // Session-maintenance statements must be accepted without error.
    for stmt in ["DISCARD ALL", "RESET ALL", "DEALLOCATE ALL", "UNLISTEN *"] {
        client
            .batch_execute(stmt)
            .await
            .unwrap_or_else(|e| panic!("session-maintenance statement `{stmt}` failed: {e}"));
    }

    // The session must still answer queries after the reset statements.
    let row = client
        .query_one("SELECT count(*) AS c FROM people", &[])
        .await
        .expect("query after session reset");
    let c: i64 = row.get("c");
    assert_eq!(c, 3);

    // A string literal that merely contains the text `DISCARD ALL` must come
    // back as ordinary data.
    let row = client
        .query_one("SELECT 'DISCARD ALL' AS s", &[])
        .await
        .expect("string-literal query");
    let literal: &str = row.get("s");
    assert_eq!(literal, "DISCARD ALL");

    // Wrong password: the connection attempt must fail.
    let bad = format!("host=127.0.0.1 port={port} user={USER} password=wrong-pw dbname=datapress");
    let denied = tokio_postgres::connect(&bad, tokio_postgres::NoTls).await;
    assert!(
        denied.is_err(),
        "connection with an incorrect password must be rejected"
    );

    server.abort();
    drop(store);
}
#[derive(Debug)]
struct NoCertVerify;
impl rustls::client::danger::ServerCertVerifier for NoCertVerify {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
rustls::crypto::ring::default_provider()
.signature_verification_algorithms
.supported_schemes()
}
}
/// End-to-end test of the pgwire endpoint with password authentication over TLS.
///
/// Generates a self-signed certificate for `localhost` / `127.0.0.1`, writes
/// the certificate and key as PEM files, starts `serve_pgwire` with them, then
/// connects with `sslmode=require` and runs a count query over the encrypted
/// connection. The client uses [`NoCertVerify`], so the certificate itself is
/// not validated.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pgwire_password_auth_over_tls() {
    // The temp dir holds the Parquet fixture and the PEM files; it must
    // outlive both the store and the server.
    let dir = TempDir::new().unwrap();
    let parquet = dir.path().join("people.parquet");
    write_people_parquet(&parquet);
    let store = make_people_store(parquet.to_str().unwrap()).await;

    let issued =
        rcgen::generate_simple_self_signed(vec!["localhost".into(), "127.0.0.1".into()]).unwrap();
    let cert_path = dir.path().join("server.crt");
    let key_path = dir.path().join("server.key");
    std::fs::write(&cert_path, issued.cert.pem()).unwrap();
    std::fs::write(&key_path, issued.key_pair.serialize_pem()).unwrap();

    let port = free_port();
    let cfg = PgwireConfig {
        enabled: true,
        listen: loopback(),
        port,
        username: USER.into(),
        password: Some(PASSWORD.into()),
        tls_cert: Some(cert_path),
        tls_key: Some(key_path),
    };
    let ctx = store.session_context().clone();
    let server = tokio::spawn(async move {
        let _ = serve_pgwire(ctx, cfg).await;
    });

    // Client-side TLS: explicit `ring` provider, certificate checks disabled.
    let tls_config =
        rustls::ClientConfig::builder_with_provider(Arc::new(rustls::crypto::ring::default_provider()))
            .with_safe_default_protocol_versions()
            .unwrap()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(NoCertVerify))
            .with_no_client_auth();
    let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);

    let conn_str =
        format!("host=127.0.0.1 port={port} user={USER} password={PASSWORD} dbname=datapress sslmode=require");
    // The server task starts asynchronously, so poll for up to 50 x 100 ms.
    // The last connection error is kept so a persistent failure (for example
    // a TLS handshake or authentication error) is reported instead of being
    // hidden behind a generic "not reachable" message.
    let client = {
        let mut connected = None;
        let mut last_err = None;
        for _ in 0..50 {
            match tokio_postgres::connect(&conn_str, tls.clone()).await {
                Ok((client, connection)) => {
                    // The connection future drives the socket; it must be polled
                    // for the client to make progress.
                    tokio::spawn(async move {
                        let _ = connection.await;
                    });
                    connected = Some(client);
                    break;
                }
                Err(e) => {
                    last_err = Some(e);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }
        connected.unwrap_or_else(|| {
            panic!("pgwire TLS server did not become reachable; last error: {last_err:?}")
        })
    };

    let row = client
        .query_one("SELECT count(*) AS c FROM people", &[])
        .await
        .expect("count query over TLS");
    let c: i64 = row.get("c");
    assert_eq!(c, 3);

    server.abort();
    drop(store);
}
/// End-to-end test of `spawn_pgwire` with no password and no TLS.
///
/// Starts the server through `spawn_pgwire` instead of running `serve_pgwire`
/// on a task of the test runtime, connects without a password, and checks a
/// filtered count query plus a `pg_catalog.pg_type` query. The client, the
/// server handle and the store are then dropped in that order.
///
/// NOTE(review): judging by the test name, dropping the handle returned by
/// `spawn_pgwire` is expected to stop the server; this test does not assert
/// that the port is released afterwards.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn pgwire_dedicated_runtime_serves_and_stops() {
    // The temp dir must outlive the store, which reads the Parquet file from it.
    let dir = TempDir::new().unwrap();
    let parquet = dir.path().join("people.parquet");
    write_people_parquet(&parquet);
    let store = make_people_store(parquet.to_str().unwrap()).await;

    let port = free_port();
    let cfg = PgwireConfig {
        enabled: true,
        listen: loopback(),
        port,
        username: USER.into(),
        password: None,
        tls_cert: None,
        tls_key: None,
    };
    let ctx = store.session_context().clone();
    let server = spawn_pgwire(ctx, cfg).expect("spawn pgwire runtime");

    let conn = format!("host=127.0.0.1 port={port} user={USER} dbname=datapress");
    // The server starts asynchronously, so poll for up to 50 x 100 ms.
    // The last connection error is kept so a persistent failure is reported
    // instead of being hidden behind a generic "not reachable" message.
    let client = {
        let mut connected = None;
        let mut last_err = None;
        for _ in 0..50 {
            match tokio_postgres::connect(&conn, tokio_postgres::NoTls).await {
                Ok((client, connection)) => {
                    // The connection future drives the socket; it must be polled
                    // for the client to make progress.
                    tokio::spawn(async move {
                        let _ = connection.await;
                    });
                    connected = Some(client);
                    break;
                }
                Err(e) => {
                    last_err = Some(e);
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }
        connected.unwrap_or_else(|| {
            panic!(
                "pgwire dedicated-runtime server did not become reachable; last error: {last_err:?}"
            )
        })
    };

    let row = client
        .query_one("SELECT count(*) AS c FROM people WHERE id > 0", &[])
        .await
        .expect("count query on dedicated runtime");
    let c: i64 = row.get("c");
    assert_eq!(c, 3);

    let row = client
        .query_one("SELECT count(*) AS c FROM pg_catalog.pg_type", &[])
        .await
        .expect("pg_catalog.pg_type query on dedicated runtime");
    let pg_type_rows: i64 = row.get("c");
    assert!(pg_type_rows > 0, "pg_catalog.pg_type should be populated");

    // Explicit drop order: client first, then the server handle, then the store.
    drop(client);
    drop(server);
    drop(store);
}