#![cfg(feature = "runtime-tokio")]
use std::any::Any;
use std::fmt;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use chrono::{DateTime, Local, TimeZone, Utc};
use futures::StreamExt;
use graphile_worker_database::{
row_mapping, Database, DatabaseDriver, DbCell, DbError, DbExecutor, DbParams, DbRow,
DbTransaction, DbValue, NotificationStream, TransactionDriver,
};
use serde_json::json;
#[cfg(feature = "driver-sqlx")]
use graphile_worker_database::sqlx::SqlxDatabase;
#[cfg(feature = "driver-tokio-postgres")]
use graphile_worker_database::tokio_postgres::TokioPostgresDatabase;
/// Connection string for the PostgreSQL instance the integration tests run against.
///
/// # Panics
///
/// Panics when the `DATABASE_URL` environment variable is unset or not valid Unicode.
fn database_url() -> String {
    match std::env::var("DATABASE_URL") {
        Ok(url) => url,
        Err(err) => panic!("DATABASE_URL must be set: {:?}", err),
    }
}
/// Fixed instant (2026-05-09T10:11:12Z) shared by tests that round-trip timestamps.
fn timestamp() -> DateTime<Utc> {
    let Some(instant) = Utc.with_ymd_and_hms(2026, 5, 9, 10, 11, 12).single() else {
        panic!("valid timestamp");
    };
    instant
}
/// Builds a name of the form `{prefix}_{nanoseconds since the Unix epoch}`.
///
/// Used for notification channels and application names so that separate test
/// runs pick distinct names.
fn unique_channel(prefix: &str) -> String {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system time should be after unix epoch");
    let mut name = String::from(prefix);
    name.push('_');
    name.push_str(&elapsed.as_nanos().to_string());
    name
}
/// `DbRow::try_get` decodes every supported cell kind, maps NULL to `None`,
/// and reports missing columns and type mismatches with descriptive errors.
#[test]
fn db_row_decodes_supported_cells_and_errors() {
    let expected_ts = timestamp();
    let row = row_mapping::cells([
        ("bool", DbCell::Bool(true)),
        ("i16", DbCell::I16(16)),
        ("i32", DbCell::I32(32)),
        ("i64", DbCell::I64(64)),
        ("json", DbCell::Json(json!({ "ok": true }))),
        ("text", DbCell::Text(String::from("value"))),
        ("timestamp", DbCell::TimestampTz(expected_ts)),
        ("null", DbCell::Null),
    ]);

    // Non-nullable decoding of each supported cell kind.
    let flag: bool = row.try_get("bool").unwrap();
    assert!(flag);
    let small: i16 = row.try_get("i16").unwrap();
    assert_eq!(small, 16);
    let medium: i32 = row.try_get("i32").unwrap();
    assert_eq!(medium, 32);
    let large: i64 = row.try_get("i64").unwrap();
    assert_eq!(large, 64);
    let document: serde_json::Value = row.try_get("json").unwrap();
    assert_eq!(document, json!({ "ok": true }));
    let text: String = row.try_get("text").unwrap();
    assert_eq!(text, "value");

    // The same timestamptz cell decodes as both UTC and local time.
    let utc: DateTime<Utc> = row.try_get("timestamp").unwrap();
    assert_eq!(utc, expected_ts);
    let local: DateTime<Local> = row.try_get("timestamp").unwrap();
    assert_eq!(local.with_timezone(&Utc), expected_ts);

    // `Option<T>` maps NULL to `None` and wraps present values in `Some`.
    let null_flag: Option<bool> = row.try_get("null").unwrap();
    assert_eq!(null_flag, None);
    let some_text: Option<String> = row.try_get("text").unwrap();
    assert_eq!(some_text.as_deref(), Some("value"));

    // Error reporting.
    let missing: Result<bool, _> = row.try_get("missing");
    assert!(missing.unwrap_err().to_string().contains("was not present"));
    let mismatched: Result<i32, _> = row.try_get("text");
    assert!(mismatched
        .unwrap_err()
        .to_string()
        .contains("could not be decoded as i32"));

    // Errors can carry a SQLSTATE-style code.
    let coded = DbError::with_code("duplicate key", "23505");
    assert_eq!(coded.code(), Some("23505"));
}
/// `DbParams` keeps its values, in insertion order, whether it is built
/// incrementally with `push` or converted from a `Vec<DbValue>`.
#[test]
fn db_params_preserve_values() {
    let mut params = DbParams::new();
    params.push(DbValue::Bool(true));
    params.push(DbValue::Text("hello".to_string()));
    assert_eq!(params.values().len(), 2);
    // A length check alone passes even if values were dropped, replaced or
    // reordered, so also pin the exact contents and order.
    assert!(
        matches!(
            &params.values()[..],
            [DbValue::Bool(true), DbValue::Text(text)] if text == "hello"
        ),
        "pushed values should be preserved in insertion order"
    );

    let from_vec = DbParams::from(vec![DbValue::I32(42)]);
    assert_eq!(from_vec.values().len(), 1);
    assert!(
        matches!(&from_vec.values()[..], [DbValue::I32(42)]),
        "values converted from a Vec should be preserved"
    );
}
/// In-memory `DatabaseDriver` used to check that `Database` forwards calls.
///
/// `execute` always reports 7 affected rows, `fetch_all` returns a copy of
/// `rows`, `begin` hands out a [`MockTransaction`], and `listen` returns
/// `Ok(None)`.
#[derive(Debug)]
struct MockDriver {
    /// Rows returned by every `fetch_all` call.
    rows: Vec<DbRow>,
}

impl DbExecutor for MockDriver {
    fn execute<'a>(
        &'a self,
        _query: &'a str,
        _bindings: DbParams,
    ) -> graphile_worker_database::BoxFuture<'a, Result<u64, DbError>> {
        const AFFECTED_ROWS: u64 = 7;
        Box::pin(async move { Ok(AFFECTED_ROWS) })
    }

    fn fetch_all<'a>(
        &'a self,
        _query: &'a str,
        _bindings: DbParams,
    ) -> graphile_worker_database::BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
        Box::pin(async move {
            let snapshot = self.rows.to_vec();
            Ok(snapshot)
        })
    }
}

impl DatabaseDriver for MockDriver {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn begin<'a>(
        &'a self,
    ) -> graphile_worker_database::BoxFuture<'a, Result<DbTransaction, DbError>> {
        Box::pin(async move { Ok(DbTransaction::new(Box::new(MockTransaction))) })
    }

    fn listen<'a>(
        &'a self,
        _name: &'a str,
    ) -> graphile_worker_database::BoxFuture<'a, Result<Option<NotificationStream>, DbError>> {
        Box::pin(async move { Ok(None) })
    }
}
/// Transaction handed out by [`MockDriver::begin`].
///
/// `execute` reports 3 affected rows, `fetch_all` yields one row whose `value`
/// column is 99, and `commit` always succeeds.
struct MockTransaction;

impl DbExecutor for MockTransaction {
    fn execute<'a>(
        &'a self,
        _query: &'a str,
        _bindings: DbParams,
    ) -> graphile_worker_database::BoxFuture<'a, Result<u64, DbError>> {
        const AFFECTED_ROWS: u64 = 3;
        Box::pin(async move { Ok(AFFECTED_ROWS) })
    }

    fn fetch_all<'a>(
        &'a self,
        _query: &'a str,
        _bindings: DbParams,
    ) -> graphile_worker_database::BoxFuture<'a, Result<Vec<DbRow>, DbError>> {
        Box::pin(async move {
            let row = row_mapping::cells([("value", DbCell::I32(99))]);
            Ok(vec![row])
        })
    }
}

impl TransactionDriver for MockTransaction {
    fn commit(
        self: Box<Self>,
    ) -> graphile_worker_database::BoxFuture<'static, Result<(), DbError>> {
        Box::pin(async move { Ok(()) })
    }
}
/// `Database` forwards calls to the wrapped driver and to the transaction it
/// begins.
///
/// `MockDriver` only implements `execute` and `fetch_all`, so the `fetch_one` /
/// `fetch_optional` results checked here come from helpers layered on top of
/// `fetch_all`. That includes the "exactly one row" error when no rows exist.
#[tokio::test]
async fn database_wrapper_delegates_to_inner_driver() {
    let single_row = vec![row_mapping::cells([("value", DbCell::I32(1))])];
    let database = Database::new(MockDriver { rows: single_row });
    let cloned = Database::from(&database);

    let debug_repr = format!("{database:?}");
    assert!(debug_repr.contains("Database"));
    assert!(cloned.downcast_ref::<MockDriver>().is_some());

    let one: i32 = database
        .fetch_one("select value", DbParams::new())
        .await
        .unwrap()
        .try_get("value")
        .unwrap();
    assert_eq!(one, 1);

    let maybe_row = database
        .fetch_optional("select value", DbParams::new())
        .await
        .unwrap();
    let optional: i32 = maybe_row.expect("one row").try_get("value").unwrap();
    assert_eq!(optional, 1);

    let stream = database.listen("events").await.unwrap();
    assert!(stream.is_none());

    let affected = database.execute("select 1", DbParams::new()).await.unwrap();
    assert_eq!(affected, 7);

    // Calls on the transaction must reach `MockTransaction`, not `MockDriver`.
    let tx = database.begin().await.unwrap();
    let in_tx: i32 = tx
        .fetch_one("select value", DbParams::new())
        .await
        .unwrap()
        .try_get("value")
        .unwrap();
    assert_eq!(in_tx, 99);
    let tx_affected = tx.execute("select 1", DbParams::new()).await.unwrap();
    assert_eq!(tx_affected, 3);
    tx.commit().await.unwrap();

    // With no rows: `fetch_optional` yields `None` and `fetch_one` errors.
    let empty_database = Database::new(MockDriver { rows: Vec::new() });
    let nothing = empty_database
        .fetch_optional("select value", DbParams::new())
        .await
        .unwrap();
    assert!(nothing.is_none());
    let error = empty_database
        .fetch_one("select value", DbParams::new())
        .await
        .unwrap_err();
    assert!(error.to_string().contains("exactly one row"));
}
async fn exercise_executor(executor: &impl DbExecutor) {
let ts = timestamp();
let rows = executor
.fetch_all(
r#"
select
$1::bool as bool_value,
$2::bool as bool_null,
$3::int2 as i16_value,
$4::int2 as i16_null,
$5::int4 as i32_value,
$6::int4 as i32_null,
$7::int8 as i64_value,
$8::int8 as i64_null,
$9::jsonb as json_value,
$10::jsonb as json_null,
$11::text as text_value,
$12::text as text_null,
array_length($13::text[], 1)::int4 as text_array_len,
array_length($14::text[], 1)::int4 as text_array_null_len,
array_length($15::int4[], 1)::int4 as i32_array_len,
array_length($16::int8[], 1)::int4 as i64_array_len,
$17::timestamptz as timestamp_value,
$18::timestamptz as timestamp_null
"#,
DbParams::from(vec![
DbValue::Bool(true),
DbValue::BoolOpt(None),
DbValue::I16(16),
DbValue::I16Opt(None),
DbValue::I32(32),
DbValue::I32Opt(None),
DbValue::I64(64),
DbValue::I64Opt(None),
DbValue::Json(json!({ "driver": "covered" })),
DbValue::JsonOpt(None),
DbValue::Text("hello".to_string()),
DbValue::TextOpt(None),
DbValue::TextArray(vec!["a".to_string(), "b".to_string()]),
DbValue::TextArrayOpt(None),
DbValue::I32Array(vec![1, 2, 3]),
DbValue::I64Array(vec![4, 5]),
DbValue::TimestampTz(ts),
DbValue::TimestampTzOpt(None),
]),
)
.await
.unwrap();
let row = rows.first().expect("one row");
assert!(row.try_get::<bool>("bool_value").unwrap());
assert_eq!(row.try_get::<Option<bool>>("bool_null").unwrap(), None);
assert_eq!(row.try_get::<i16>("i16_value").unwrap(), 16);
assert_eq!(row.try_get::<Option<i16>>("i16_null").unwrap(), None);
assert_eq!(row.try_get::<i32>("i32_value").unwrap(), 32);
assert_eq!(row.try_get::<Option<i32>>("i32_null").unwrap(), None);
assert_eq!(row.try_get::<i64>("i64_value").unwrap(), 64);
assert_eq!(row.try_get::<Option<i64>>("i64_null").unwrap(), None);
assert_eq!(
row.try_get::<serde_json::Value>("json_value").unwrap(),
json!({ "driver": "covered" })
);
assert_eq!(
row.try_get::<Option<serde_json::Value>>("json_null")
.unwrap(),
None
);
assert_eq!(row.try_get::<String>("text_value").unwrap(), "hello");
assert_eq!(row.try_get::<Option<String>>("text_null").unwrap(), None);
assert_eq!(row.try_get::<i32>("text_array_len").unwrap(), 2);
assert_eq!(
row.try_get::<Option<i32>>("text_array_null_len").unwrap(),
None
);
assert_eq!(row.try_get::<i32>("i32_array_len").unwrap(), 3);
assert_eq!(row.try_get::<i32>("i64_array_len").unwrap(), 2);
assert_eq!(row.try_get::<DateTime<Utc>>("timestamp_value").unwrap(), ts);
assert_eq!(
row.try_get::<Option<DateTime<Utc>>>("timestamp_null")
.unwrap(),
None
);
let fetched = executor
.fetch_optional("select 123::int4 as value", DbParams::new())
.await
.unwrap()
.expect("one row");
assert_eq!(fetched.try_get::<i32>("value").unwrap(), 123);
let first = executor
.fetch_optional(
"select value::int4 from unnest(array[1, 2]) as values(value)",
DbParams::new(),
)
.await
.unwrap()
.expect("one row");
assert_eq!(first.try_get::<i32>("value").unwrap(), 1);
let none = executor
.fetch_optional("select 123::int4 as value where false", DbParams::new())
.await
.unwrap();
assert!(none.is_none());
assert!(executor
.fetch_one("select 123::int4 as value where false", DbParams::new())
.await
.unwrap_err()
.to_string()
.contains("exactly one row"));
assert!(executor
.fetch_all("select 1.5::numeric as unsupported", DbParams::new())
.await
.unwrap_err()
.to_string()
.contains("unsupported PostgreSQL result type"));
executor
.execute("select $1::int4", DbParams::from(vec![DbValue::I32(1)]))
.await
.unwrap();
}
async fn exercise_database(database: &Database) {
exercise_executor(database).await;
let tx = database.begin().await.unwrap();
tx.execute(
"create temp table database_driver_contract(value int4) on commit drop",
DbParams::new(),
)
.await
.unwrap();
tx.execute(
"insert into database_driver_contract(value) values ($1)",
DbParams::from(vec![DbValue::I32(42)]),
)
.await
.unwrap();
assert_eq!(
tx.fetch_one(
"select value from database_driver_contract",
DbParams::new()
)
.await
.unwrap()
.try_get::<i32>("value")
.unwrap(),
42
);
tx.commit().await.unwrap();
let tx = database.begin().await.unwrap();
tx.execute("select 1", DbParams::new()).await.unwrap();
drop(tx);
tokio::time::sleep(Duration::from_millis(20)).await;
}
async fn exercise_listen(database: &Database, channel: &str) {
let mut stream = database
.listen(channel)
.await
.unwrap()
.expect("driver should support notifications");
database
.execute(
&format!("notify {channel}, 'driver-payload'"),
DbParams::new(),
)
.await
.unwrap();
let notification = tokio::time::timeout(Duration::from_secs(2), stream.next())
.await
.expect("notification should arrive")
.expect("notification stream should stay open")
.unwrap();
assert_eq!(notification.channel, channel);
assert_eq!(notification.payload, "driver-payload");
}
/// Polls `pg_stat_activity` until a matching backend has finished running
/// `listen_query`, then returns that backend's PID.
///
/// A backend matches when its `application_name` is `application_name`.
/// `excluded_pid` lets a caller wait for a *different* backend, for example
/// after terminating the previous one. If several backends match, the most
/// recently started one wins.
///
/// # Panics
///
/// Panics if the polling query fails, or if no matching backend shows up
/// within five seconds.
#[cfg(feature = "driver-tokio-postgres")]
async fn wait_for_tokio_postgres_listener_pid(
    database: &Database,
    application_name: &str,
    listen_query: &str,
    excluded_pid: Option<i32>,
) -> i32 {
    let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
    loop {
        // For an active backend, `pg_stat_activity.query` shows the statement
        // still running; for an idle one it shows the last finished statement.
        // Accept only idle backends so the LISTEN has completed before callers
        // send a NOTIFY. A backend caught mid-LISTEN could otherwise miss it.
        let row = database
            .fetch_optional(
                r#"
                select pid::int4 as pid
                from pg_stat_activity
                where application_name = $1
                  and query = $2
                  and state = 'idle'
                  and ($3::int4 is null or pid <> $3)
                order by backend_start desc
                limit 1
                "#,
                DbParams::from(vec![
                    DbValue::Text(application_name.to_string()),
                    DbValue::Text(listen_query.to_string()),
                    DbValue::I32Opt(excluded_pid),
                ]),
            )
            .await
            .unwrap();
        if let Some(row) = row {
            return row.try_get::<i32>("pid").unwrap();
        }
        assert!(
            tokio::time::Instant::now() < deadline,
            "listener backend should reconnect before timeout"
        );
        tokio::time::sleep(Duration::from_millis(20)).await;
    }
}
/// Sends `payload` on `channel` by calling `pg_notify` with bound parameters.
///
/// # Panics
///
/// Panics if the statement fails.
#[cfg(feature = "driver-tokio-postgres")]
async fn notify(database: &Database, channel: &str, payload: &str) {
    let args: DbParams = vec![
        DbValue::Text(channel.to_owned()),
        DbValue::Text(payload.to_owned()),
    ]
    .into();
    database
        .execute("select pg_notify($1, $2)", args)
        .await
        .unwrap();
}
/// Waits up to five seconds for the next item on `stream` and asserts that it
/// was delivered on `channel` with `payload`.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the wait times out;
/// - the stream ends;
/// - the stream yields an error;
/// - the channel or payload differs from the expected value.
#[cfg(feature = "driver-tokio-postgres")]
async fn expect_notification(stream: &mut NotificationStream, channel: &str, payload: &str) {
    let next = tokio::time::timeout(Duration::from_secs(5), stream.next()).await;
    let received = next
        .expect("notification should arrive")
        .expect("notification stream should stay open")
        .unwrap();
    assert_eq!(received.channel, channel);
    assert_eq!(received.payload, payload);
}
/// The sqlx driver satisfies the shared contract in three setups: a `Database`
/// built from a `SqlxDatabase`, from an owned `PgPool`, and from a borrowed
/// `PgPool`.
#[cfg(feature = "driver-sqlx")]
#[tokio::test]
async fn sqlx_driver_satisfies_database_contract() {
    let pool = ::sqlx::PgPool::connect(&database_url()).await.unwrap();
    let sqlx_database = SqlxDatabase::new(pool.clone());
    let rendered = fmt::format(format_args!("{sqlx_database:?}"));
    assert!(rendered.contains("SqlxDatabase"));
    assert!(!sqlx_database.pool().is_closed());

    let database = Database::from(sqlx_database.clone());
    assert!(database.downcast_ref::<SqlxDatabase>().is_some());
    exercise_database(&database).await;
    exercise_listen(&database, &unique_channel("database_driver_sqlx")).await;

    // Both pool conversions must yield a fully working `Database`.
    for wrapped in [Database::from(pool.clone()), Database::from(&pool)] {
        exercise_database(&wrapped).await;
    }
}
/// The tokio-postgres driver satisfies the shared contract when built from a
/// URL. A driver built from a bare pool reports notifications as unsupported
/// but still passes the executor contract.
#[cfg(feature = "driver-tokio-postgres")]
#[tokio::test]
async fn tokio_postgres_driver_satisfies_database_contract() {
    let tokio_database = TokioPostgresDatabase::from_url(&database_url(), 4).unwrap();
    let rendered = fmt::format(format_args!("{tokio_database:?}"));
    assert!(rendered.contains("TokioPostgresDatabase"));

    let database = Database::from(tokio_database.clone());
    assert!(database.downcast_ref::<TokioPostgresDatabase>().is_some());
    exercise_database(&database).await;
    exercise_listen(&database, &unique_channel("database_driver_tokio")).await;

    // NOTE(review): presumably `listen` is unsupported here because a driver
    // built only from a pool has no connection config for a dedicated
    // listener — confirm.
    let pool = tokio_database.pool().clone();
    let pool_database = TokioPostgresDatabase::new(pool.clone());
    let listener = pool_database.listen("without_config").await.unwrap();
    assert!(listener.is_none());

    let from_pool = Database::from(pool);
    exercise_executor(&from_pool).await;
}
/// After the listener's backend is terminated, the tokio-postgres driver
/// re-issues the LISTEN on a new backend and keeps delivering on the same
/// stream.
#[cfg(feature = "driver-tokio-postgres")]
#[tokio::test]
async fn tokio_postgres_listener_reconnects_after_connection_loss() {
    let channel = unique_channel("database_driver_tokio_reconnect");
    // A unique application_name makes the listener's backend findable in
    // pg_stat_activity.
    let application_name = unique_channel("database_driver_tokio_listener");
    let listen_query = format!(r#"LISTEN "{channel}""#);

    let mut config: tokio_postgres::Config = database_url().parse().unwrap();
    config.application_name(&application_name);
    let database = Database::from(TokioPostgresDatabase::from_config(config, 4).unwrap());

    let mut stream = database
        .listen(&channel)
        .await
        .unwrap()
        .expect("driver should support notifications");
    let original_pid =
        wait_for_tokio_postgres_listener_pid(&database, &application_name, &listen_query, None)
            .await;
    notify(&database, &channel, "before-reconnect").await;
    expect_notification(&mut stream, &channel, "before-reconnect").await;

    // Kill the listener's backend to force a reconnect.
    let terminate_args: DbParams = vec![DbValue::I32(original_pid)].into();
    let terminated: bool = database
        .fetch_one("select pg_terminate_backend($1) as terminated", terminate_args)
        .await
        .unwrap()
        .try_get("terminated")
        .unwrap();
    assert!(terminated);

    // Wait until a backend other than the terminated one has re-issued the LISTEN.
    wait_for_tokio_postgres_listener_pid(
        &database,
        &application_name,
        &listen_query,
        Some(original_pid),
    )
    .await;
    notify(&database, &channel, "after-reconnect").await;
    expect_notification(&mut stream, &channel, "after-reconnect").await;
}