use pgmq::pg_ext::VisibilityTimeoutOffset;
use pgmq::types::{ARCHIVE_PREFIX, PGMQ_SCHEMA, QUEUE_PREFIX};
use pgmq::util::connect;
use rand::Rng;
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Postgres, Row};
use std::env;
/// Rewrites the tail of a connection string: everything from the last `/` onward
/// (the database name plus any `?query` suffix) is replaced by `replacement`, which
/// should therefore carry its own leading `/`.
///
/// A string containing no `/` is returned unchanged.
fn replace_db_string(s: &str, replacement: &str) -> String {
    s.rsplit_once('/')
        .map_or_else(|| s.to_owned(), |(head, _)| format!("{head}{replacement}"))
}
/// Prepares a freshly created queue named `qname` in the dedicated `pgmq_ext_test` database.
///
/// Connects to `DATABASE_URL` (falling back to a local `postgres` database) and attempts
/// `CREATE DATABASE pgmq_ext_test`, ignoring any error (e.g. the database already
/// existing). It then connects to that database, installs pgmq, drops any previous
/// `qname` queue (ignoring errors) and creates it again.
///
/// # Panics
/// If either connection fails or creating the queue returns an error.
async fn init_queue_ext(qname: &str) -> pgmq::PGMQueueExt {
    let default_url = "postgres://postgres:postgres@localhost:5432/postgres";
    let db_url = env::var("DATABASE_URL").unwrap_or_else(|_| default_url.to_owned());

    // Handle on the configured database; only used to create the test database.
    let admin = pgmq::PGMQueueExt::new(db_url.clone(), 2)
        .await
        .expect("failed to connect to postgres");
    // Deliberately best-effort: an error here (such as the database already existing
    // from an earlier run) is ignored.
    let _ = sqlx::query("CREATE DATABASE pgmq_ext_test;")
        .execute(&admin.connection)
        .await;

    let test_db_url = replace_db_string(&db_url, "/pgmq_ext_test");
    let queue = pgmq::PGMQueueExt::new(test_db_url, 2)
        .await
        .expect("failed to connect to test db");
    install_pgmq(&queue).await;

    // Start from a clean slate; any error (e.g. the queue not existing yet) is ignored.
    let _ = queue.drop_queue(qname).await;
    let q_success = queue.create(qname).await;
    println!("q_success: {q_success:?}");
    assert!(q_success.is_ok());
    queue
}
/// Payload used by most tests: a fixed `foo` tag plus a random `num`.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct MyMessage {
    foo: String,
    num: u64,
}

impl Default for MyMessage {
    /// Builds a message with `foo = "bar"` and `num` drawn from `0..100`.
    fn default() -> Self {
        let num = rand::thread_rng().gen_range(0..100);
        Self {
            foo: String::from("bar"),
            num,
        }
    }
}
/// A payload whose shape differs from `MyMessage` (single `yolo` field).
/// NOTE(review): not referenced by any test visible here — confirm whether it is still needed.
#[derive(Debug, Serialize, Deserialize)]
struct YoloMessage {
    yolo: String,
}
/// Returns `count(*)` for the table `{PGMQ_SCHEMA}.{table}`.
///
/// # Panics
/// If the query fails (for example because the table does not exist); the panic message
/// names the table so the failing test is easy to diagnose.
async fn count_table_rows(table: &str, connection: &Pool<Postgres>) -> i64 {
    let row_ct_query = format!("SELECT count(*) as ct FROM {PGMQ_SCHEMA}.{table}");
    sqlx::query(&row_ct_query)
        .fetch_one(connection)
        .await
        .unwrap_or_else(|e| panic!("failed to count rows in {PGMQ_SCHEMA}.{table}: {e}"))
        .get::<i64, usize>(0)
}

/// Number of rows in the queue table backing `qname`.
///
/// # Panics
/// If the count query fails.
async fn rowcount(qname: &str, connection: &Pool<Postgres>) -> i64 {
    count_table_rows(&format!("{QUEUE_PREFIX}_{qname}"), connection).await
}

/// Number of rows in the archive table backing `qname`.
///
/// # Panics
/// If the count query fails.
async fn archive_rowcount(qname: &str, connection: &Pool<Postgres>) -> i64 {
    count_table_rows(&format!("{ARCHIVE_PREFIX}_{qname}"), connection).await
}
/// Installs pgmq into the database behind `queue` and returns `true` on success.
///
/// With the `install-sql-embedded` feature, the crate's embedded SQL is executed via
/// `install_sql_from_embedded`. Otherwise `queue.init()` is called.
///
/// # Panics
/// If installation fails.
async fn install_pgmq(queue: &pgmq::PGMQueueExt) -> bool {
    // The two cfg predicates must be exact complements so that exactly one `result`
    // binding is compiled in for every feature combination. Mismatched predicates either
    // leave `result` undefined or let `init()` shadow (and run after) the embedded install.
    #[cfg(feature = "install-sql-embedded")]
    let result = queue.install_sql_from_embedded().await.map(|_| true);
    #[cfg(not(feature = "install-sql-embedded"))]
    let result = queue.init().await;
    result.expect("failed to init pgmq")
}
/// Creates a queue, checks that `list_queues` reports it, drops it, and checks that it is
/// no longer listed.
#[tokio::test]
async fn test_ext_create_list_drop() {
    let suffix = rand::thread_rng().gen_range(0..100000);
    let test_queue = format!("test_ext_create_list_drop_{}", suffix);
    let queue = init_queue_ext(&test_queue).await;

    let listed = queue
        .list_queues()
        .await
        .expect("error listing queues")
        .expect("test queue was not created");
    assert!(listed.iter().any(|meta| meta.queue_name == test_queue));

    queue
        .drop_queue(&test_queue)
        .await
        .expect("error dropping queue");

    // After the drop the listing may come back as `None`, which counts as "not listed".
    let remaining = queue
        .list_queues()
        .await
        .expect("error listing queues")
        .unwrap_or_default();
    assert!(remaining.iter().all(|meta| meta.queue_name != test_queue));
}
/// Shared body for the `test_ext_send_read_delete_*` tests, generic over any type that
/// converts into [`VisibilityTimeoutOffset`].
///
/// Sends one message and walks it through the visibility lifecycle, then checks `delete`
/// on a second message. The offsets are passed, in order, as the visibility timeout for:
/// 1. `offset1` – the first `read`, which must return the message;
/// 2. `offset2` – a second `read`, which must return nothing;
/// 3. `offset3` – `read_batch_with_poll`, which must take more than 1 s and still return the
///    message within its 6 s poll timeout;
/// 4. `offset4` – `set_vt` on that message;
/// 5. `offset5` – a final `read`, which must return the message again.
///
/// All callers in this file pass the equivalent of 5, 2, 5, 0 and 1. These are seconds for
/// the `chrono`/`std` duration variants; the bare integers are presumably seconds too.
async fn test_ext_send_read_delete_core<T: Into<VisibilityTimeoutOffset>>(
    offset1: T,
    offset2: T,
    offset3: T,
    offset4: T,
    offset5: T,
) {
    let test_queue = format!(
        "test_ext_send_read_delete_{}",
        rand::thread_rng().gen_range(0..100000)
    );
    let queue = init_queue_ext(&test_queue).await;
    let msg = MyMessage::default();
    // The queue was just (re)created by `init_queue_ext`, so it must start empty.
    let num_rows_queue = rowcount(&test_queue, &queue.connection).await;
    assert_eq!(num_rows_queue, 0);
    let msg_id = queue.send(&test_queue, &msg).await.unwrap();
    assert!(msg_id >= 1);
    // First read: the message comes back with its id and payload intact.
    let read_message = queue
        .read::<MyMessage>(&test_queue, offset1)
        .await
        .expect("error reading message");
    assert!(read_message.is_some());
    let read_message = read_message.unwrap();
    assert_eq!(read_message.msg_id, msg_id);
    assert_eq!(read_message.message, msg);
    // Read again right away: the message is expected to still be hidden by the first read.
    let read_message = queue
        .read::<MyMessage>(&test_queue, offset2)
        .await
        .expect("error reading message");
    assert!(read_message.is_none());
    // Poll (timeout 6 s) until the message is returned; the elapsed-time assertion checks
    // that the call had to wait instead of returning immediately.
    let start_poll = std::time::Instant::now();
    let read_with_poll = queue
        .read_batch_with_poll::<MyMessage>(
            &test_queue,
            offset3,
            1,
            Some(std::time::Duration::from_secs(6)),
            None,
        )
        .await
        .expect("error reading message")
        .expect("no message");
    let poll_duration = start_poll.elapsed();
    assert!(poll_duration.as_millis() > 1000);
    assert_eq!(read_with_poll.len(), 1);
    assert_eq!(read_with_poll[0].msg_id, msg_id);
    // Override the visibility timeout; with the callers' `offset4` of 0 the message should
    // become readable straight away for the final read below.
    let _vt_set = queue
        .set_vt::<MyMessage>(&test_queue, msg_id, offset4)
        .await
        .expect("failed to set VT");
    let read_message = queue
        .read::<MyMessage>(&test_queue, offset5)
        .await
        .expect("error reading message")
        .expect("expected a message");
    assert_eq!(read_message.msg_id, msg_id);
    // A second message exercises `delete`: `true` when it removes the message, `false`
    // when the same id is deleted again.
    let msg_id_del = queue.send(&test_queue, &msg).await.unwrap();
    let deleted = queue
        .delete(&test_queue, msg_id_del)
        .await
        .expect("failed to delete");
    assert!(deleted);
    let deleted = queue
        .delete(&test_queue, msg_id_del)
        .await
        .expect("failed to delete");
    assert!(!deleted);
}
/// Runs the send/read/delete lifecycle with plain `i32` offsets.
#[tokio::test]
async fn test_ext_send_read_delete_i32() {
    test_ext_send_read_delete_core::<i32>(5, 2, 5, 0, 1).await;
}

/// Runs the send/read/delete lifecycle with plain `i64` offsets.
#[tokio::test]
async fn test_ext_send_read_delete_i64() {
    test_ext_send_read_delete_core::<i64>(5, 2, 5, 0, 1).await;
}

/// Runs the send/read/delete lifecycle with plain `u32` offsets.
#[tokio::test]
async fn test_ext_send_read_delete_u32() {
    test_ext_send_read_delete_core::<u32>(5, 2, 5, 0, 1).await;
}

/// Runs the send/read/delete lifecycle with plain `u64` offsets.
#[tokio::test]
async fn test_ext_send_read_delete_u64() {
    test_ext_send_read_delete_core::<u64>(5, 2, 5, 0, 1).await;
}

/// Runs the send/read/delete lifecycle with `chrono::Duration` offsets.
#[tokio::test]
async fn test_ext_send_read_delete_chrono() {
    let secs = chrono::Duration::seconds;
    test_ext_send_read_delete_core(secs(5), secs(2), secs(5), secs(0), secs(1)).await;
}

/// Runs the send/read/delete lifecycle with `std::time::Duration` offsets.
#[tokio::test]
async fn test_ext_send_read_delete_std() {
    let secs = std::time::Duration::from_secs;
    test_ext_send_read_delete_core(secs(5), secs(2), secs(5), secs(0), secs(1)).await;
}

/// Runs the send/read/delete lifecycle with `VisibilityTimeoutOffset` values directly.
#[tokio::test]
async fn test_ext_send_read_delete_vt_offset() {
    let offsets = (
        VisibilityTimeoutOffset::seconds(5),
        VisibilityTimeoutOffset::seconds(2),
        VisibilityTimeoutOffset::seconds(5),
        VisibilityTimeoutOffset::seconds(0),
        VisibilityTimeoutOffset::seconds(1),
    );
    test_ext_send_read_delete_core(offsets.0, offsets.1, offsets.2, offsets.3, offsets.4).await;
}
/// Shared body for the `test_ext_send_delay_*` tests.
///
/// A message sent with `send_delay(.., delay)` must not be returned by an immediate read.
/// It must be returned by a read issued after sleeping five seconds. Every caller passes a
/// delay of 5, so the sleep matches the delay.
async fn test_ext_send_delay_core(delay: impl Into<VisibilityTimeoutOffset>) {
    let suffix = rand::thread_rng().gen_range(0..100000);
    let test_queue = format!("test_ext_send_delay_{}", suffix);
    let queue = init_queue_ext(&test_queue).await;
    let read_vt = 1;

    queue
        .send_delay(&test_queue, &MyMessage::default(), delay)
        .await
        .unwrap();

    let before_delay = queue.read::<MyMessage>(&test_queue, read_vt).await.unwrap();
    assert!(before_delay.is_none());

    tokio::time::sleep(std::time::Duration::from_secs(5)).await;

    let after_delay = queue.read::<MyMessage>(&test_queue, read_vt).await.unwrap();
    assert!(after_delay.is_some());
}
/// Delayed send with an `i32` delay.
#[tokio::test]
async fn test_ext_send_delay_i32() {
    let delay: i32 = 5;
    test_ext_send_delay_core(delay).await;
}

/// Delayed send with an `i64` delay.
#[tokio::test]
async fn test_ext_send_delay_i64() {
    let delay: i64 = 5;
    test_ext_send_delay_core(delay).await;
}

/// Delayed send with a `u32` delay.
#[tokio::test]
async fn test_ext_send_delay_u32() {
    let delay: u32 = 5;
    test_ext_send_delay_core(delay).await;
}

/// Delayed send with a `u64` delay.
#[tokio::test]
async fn test_ext_send_delay_u64() {
    let delay: u64 = 5;
    test_ext_send_delay_core(delay).await;
}

/// Delayed send with a `chrono::Duration` delay.
#[tokio::test]
async fn test_ext_send_delay_chrono() {
    let delay = chrono::Duration::seconds(5);
    test_ext_send_delay_core(delay).await;
}

/// Delayed send with a `std::time::Duration` delay.
#[tokio::test]
async fn test_ext_send_delay_std() {
    let delay = std::time::Duration::from_secs(5);
    test_ext_send_delay_core(delay).await;
}

/// Delayed send with a `VisibilityTimeoutOffset` delay.
#[tokio::test]
async fn test_ext_send_delay_vt_offset() {
    let delay = VisibilityTimeoutOffset::seconds(5);
    test_ext_send_delay_core(delay).await;
}
/// Sends one message and checks that `pop` returns the same payload.
#[tokio::test]
async fn test_ext_send_pop() {
    let suffix = rand::thread_rng().gen_range(0..100000);
    let test_queue = format!("test_ext_send_pop_{}", suffix);
    let queue = init_queue_ext(&test_queue).await;

    let sent = MyMessage::default();
    queue.send(&test_queue, &sent).await.unwrap();

    let pop_result = queue.pop::<MyMessage>(&test_queue).await;
    let popped = pop_result
        .expect("failed to pop")
        .expect("no message to pop");
    assert_eq!(popped.message, sent);
}
/// Sends one message and checks that `archive` reports archiving it.
#[tokio::test]
async fn test_ext_send_archive() {
    let suffix = rand::thread_rng().gen_range(0..100000);
    let test_queue = format!("test_ext_send_archive_{}", suffix);
    let queue = init_queue_ext(&test_queue).await;

    let payload = MyMessage::default();
    let msg_id = queue.send(&test_queue, &payload).await.unwrap();

    let archive_result = queue.archive(&test_queue, msg_id).await;
    let archived = archive_result.expect("failed to archive");
    assert!(archived);
}
/// Sends three messages, archives them in one call, and checks that the queue table is
/// empty and the archive table holds all three.
#[tokio::test]
async fn test_ext_archive_batch() {
    let suffix = rand::thread_rng().gen_range(0..100000);
    let test_queue = format!("test_ext_archive_batch_{}", suffix);
    let queue = init_queue_ext(&test_queue).await;

    let msg = MyMessage::default();
    let mut msg_ids = Vec::with_capacity(3);
    for _ in 0..3 {
        msg_ids.push(queue.send(&test_queue, &msg).await.unwrap());
    }

    let archive_result = queue
        .archive_batch(&test_queue, msg_ids.as_slice())
        .await
        .expect("archive batch error");

    let left_in_queue = rowcount(&test_queue, &queue.connection).await;
    assert_eq!(left_in_queue, 0);
    assert_eq!(archive_result, 3);

    let in_archive = archive_rowcount(&test_queue, &queue.connection).await;
    assert_eq!(in_archive, 3);
}
/// Sends three messages, deletes them in one call, and checks that the queue table is empty.
#[tokio::test]
async fn test_ext_delete_batch() {
    let suffix = rand::thread_rng().gen_range(0..100000);
    let test_queue = format!("test_ext_delete_batch{}", suffix);
    let queue = init_queue_ext(&test_queue).await;

    let msg = MyMessage::default();
    let mut msg_ids = Vec::with_capacity(3);
    for _ in 0..3 {
        msg_ids.push(queue.send(&test_queue, &msg).await.unwrap());
    }

    let delete_result = queue
        .delete_batch(&test_queue, msg_ids.as_slice())
        .await
        .expect("delete batch error");

    let left_in_queue = rowcount(&test_queue, &queue.connection).await;
    assert_eq!(left_in_queue, 0);
    assert_eq!(delete_result, 3);
}
/// Sends three messages, purges the queue, and checks the reported count and the empty table.
#[tokio::test]
async fn test_ext_purge_queue() {
    let suffix = rand::thread_rng().gen_range(0..100000);
    let test_queue = format!("test_ext_purge_queue{}", suffix);
    let queue = init_queue_ext(&test_queue).await;

    let msg = MyMessage::default();
    for _ in 0..3 {
        queue.send(&test_queue, &msg).await.unwrap();
    }

    let purged_count = queue
        .purge_queue(&test_queue)
        .await
        .expect("purge queue error");
    assert_eq!(purged_count, 3);

    let left_in_queue = rowcount(&test_queue, &queue.connection).await;
    assert_eq!(left_in_queue, 0);
}
/// Checks that `create_partitioned` reports `true` for a new queue and `false` when the
/// same queue is created a second time.
#[tokio::test]
async fn test_pgmq_init() {
    let test_queue = format!(
        "test_ext_init_queue{}",
        rand::thread_rng().gen_range(0..100000)
    );
    let queue = init_queue_ext(&test_queue).await;
    let _ = sqlx::query("CREATE EXTENSION IF NOT EXISTS pg_partman")
        .execute(&queue.connection)
        .await
        .expect("failed to create extension");
    // Nothing in this file drops the test database or this queue, so a name can survive
    // from an earlier run. Use the same wide suffix range as the other tests (the old
    // `0..100` made collisions likely) and clear any leftover first. Otherwise the
    // first create can legitimately report `false`.
    let qname = format!("test_dup_{}", rand::thread_rng().gen_range(0..100000));
    let _ = queue.drop_queue(&qname).await;
    let created = queue
        .create_partitioned(&qname)
        .await
        .expect("failed attempting to create queue");
    assert!(created, "did not create queue");
    let created = queue
        .create_partitioned(&qname)
        .await
        .expect("failed attempting to create the duplicate queue");
    assert!(!created, "failed to detect duplicate queue");
}
/// Checks that `create_with_cxn` runs inside a caller-supplied transaction.
///
/// A queue created in a committed transaction must be listed afterwards. One created in a
/// rolled-back transaction must not be.
#[tokio::test]
async fn test_create_txn() {
    let setup_queue = format!("_q_{}", rand::thread_rng().gen_range(0..100000));
    // One initialisation is enough. Transactions are opened on the same pool the queue
    // handle uses; cloning a sqlx `Pool` only clones a handle to the shared pool.
    let queue = init_queue_ext(&setup_queue).await;
    let pool = queue.connection.clone();

    let mut tx = pool.begin().await.expect("failed to start transaction");
    let q = format!(
        "test_create_txn_{}",
        rand::thread_rng().gen_range(0..100000)
    );
    queue
        .create_with_cxn(&q, &mut *tx)
        .await
        .expect("failed to create queue in txn");
    tx.commit().await.expect("failed to commit txn");
    let q_names = queue
        .list_queues()
        .await
        .expect("error listing queues")
        .expect("test queue was not created")
        .iter()
        .map(|q| q.queue_name.clone())
        .collect::<Vec<_>>();
    assert!(q_names.contains(&q), "failed to find created queue");

    let mut tx = pool.begin().await.expect("failed to start transaction");
    let q_rollback = format!(
        "test_create_txn_rb_{}",
        rand::thread_rng().gen_range(0..100000)
    );
    queue
        .create_with_cxn(&q_rollback, &mut *tx)
        .await
        .expect("failed to create queue in txn");
    tx.rollback().await.expect("failed to rollback txn");
    let q_names = queue
        .list_queues()
        .await
        .expect("error listing queues")
        .expect("test queue was not created")
        .iter()
        .map(|q| q.queue_name.clone())
        .collect::<Vec<_>>();
    assert!(
        !q_names.contains(&q_rollback),
        "found queue that should not exist"
    );
}
/// "Bring your own pool": builds a `PGMQueueExt` from an existing pool, installs pgmq, and
/// checks that `create` reports `true` for a new queue and `false` for a duplicate.
#[tokio::test]
async fn test_byop() {
    let setup_queue = format!("test_byop_{}", rand::thread_rng().gen_range(0..100000));
    let setup = init_queue_ext(&setup_queue).await;
    let pool = setup.connection;
    let queue = pgmq::PGMQueueExt::new_with_pool(pool).await;
    let init = install_pgmq(&queue).await;
    assert!(init, "failed to create extension");

    let test_queue = format!("test_byop_{}", rand::thread_rng().gen_range(0..100000));
    // `test_queue` shares its prefix with `setup_queue`, and nothing in this file drops
    // the test database. The name may therefore already exist (this run or an earlier
    // one). Clear it so the first `create` really creates; errors are ignored.
    let _ = queue.drop_queue(&test_queue).await;
    let created = queue
        .create(&test_queue)
        .await
        .expect("failed to create queue");
    assert!(created, "failed to create queue_{test_queue}");
    let created = queue
        .create(&test_queue)
        .await
        .expect("failed execute create queue");
    assert!(!created, "failed to detect duplicate queue");
}
#[tokio::test]
async fn test_transactional() {
let test_queue = format!("test_tx_{}", rand::thread_rng().gen_range(0..100000));
let db_url = env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/postgres".to_owned());
let pool_0 = connect(&db_url, 2)
.await
.expect("failed to connect to postgres");
let pool_1 = connect(&db_url, 2)
.await
.expect("failed to connect to postgres");
let queue = pgmq::PGMQueueExt::new_with_pool(pool_0.clone()).await;
let init = install_pgmq(&queue).await;
assert!(init, "failed to create extension");
let created = queue
.create_with_cxn(&test_queue, &pool_0)
.await
.expect("failed to create queue");
assert!(created);
let mut tx = pool_0.begin().await.expect("failed to start transaction");
let sent_msg = queue
.send_with_cxn(&test_queue, &MyMessage::default(), &mut *tx)
.await
.expect("failed to send message");
assert_eq!(sent_msg, 1);
let query = format!("SELECT count(*) FROM pgmq.q_{test_queue}");
let rows = sqlx::query(&query)
.fetch_one(&pool_1)
.await
.expect("failed to fetch row")
.get::<i64, usize>(0);
assert_eq!(rows, 0);
tx.commit().await.expect("failed to commit transaction");
let rows = sqlx::query(&query)
.fetch_one(&pool_1)
.await
.expect("failed to fetch row")
.get::<i64, usize>(0);
assert_eq!(rows, 1);
}
/// Runs two `create_with_cxn` calls for the same queue name concurrently on separate
/// connections. Both must succeed, and exactly one must report that it created the queue.
#[tokio::test]
async fn test_create_queue_race_condition() {
    // A prefix of its own, so the name can never coincide with a `test_transactional`
    // queue. That test may run concurrently against the same database.
    let queue_name = format!("test_race_{}", rand::thread_rng().gen_range(0..100000));
    let db_url = env::var("DATABASE_URL")
        .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/postgres".to_owned());
    let pool = connect(&db_url, 2)
        .await
        .expect("failed to connect to postgres");
    let queue = pgmq::PGMQueueExt::new_with_pool(pool).await;
    let init = install_pgmq(&queue).await;
    assert!(init, "failed to create extension");
    // Clear any leftover from an earlier run; otherwise both creates would report `false`.
    // This must happen before both connections are checked out below. The pool was
    // presumably capped at two connections by `connect(.., 2)`, so the drop would
    // otherwise block waiting for a free one.
    let _ = queue.drop_queue(&queue_name).await;
    let mut conn1 = queue.connection.acquire().await.unwrap();
    let mut conn2 = queue.connection.acquire().await.unwrap();
    let (result1, result2) = tokio::try_join!(
        queue.create_with_cxn(&queue_name, &mut conn1),
        queue.create_with_cxn(&queue_name, &mut conn2)
    )
    .unwrap();
    assert_ne!(result1, result2);
    // Return both connections to the pool before the best-effort cleanup.
    drop(conn1);
    drop(conn2);
    let _ = queue.drop_queue(&queue_name).await;
}