use chrono::{Duration, Utc};
use pgmq::{
errors::PgmqError,
types::{Message, ARCHIVE_PREFIX, PGMQ_SCHEMA, QUEUE_PREFIX},
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{Pool, Postgres, Row};
use std::env;
/// Connects to Postgres and returns a handle on which queue `qname` has just been
/// dropped and re-created.
///
/// The connection string comes from `DATABASE_URL`, falling back to a local
/// `postgres` superuser database. Panics if connecting, dropping, or creating fails.
async fn init_queue(qname: &str) -> pgmq::PGMQueue {
    const DEFAULT_DB_URL: &str = "postgres://postgres:postgres@localhost:5432/postgres";
    let db_url = match env::var("DATABASE_URL") {
        Ok(url) => url,
        Err(_) => DEFAULT_DB_URL.to_owned(),
    };
    let queue = pgmq::PGMQueue::new(db_url)
        .await
        .expect("failed to connect to postgres");

    // Remove whatever an earlier run left behind so every test starts empty.
    queue.destroy(qname).await.unwrap();

    // Random pause of 0..1000 ms between drop and create.
    // NOTE(review): presumably meant to stagger tests that run concurrently — confirm.
    let pause = std::time::Duration::from_millis(rand::thread_rng().gen_range(0..1000));
    tokio::time::sleep(pause).await;

    let created = queue.create(qname).await;
    println!("q_success: {created:?}");
    assert!(created.is_ok());
    queue
}
/// Typed test payload with a string field `foo` and a numeric field `num`.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct MyMessage {
    foo: String,
    num: u64,
}

impl Default for MyMessage {
    /// Builds a message whose `foo` is `"bar"` and whose `num` is drawn from `0..100`.
    fn default() -> Self {
        let foo = String::from("bar");
        let num = rand::thread_rng().gen_range(0..100);
        Self { foo, num }
    }
}
/// A payload shape that does not match `MyMessage` (it requires a `yolo` field).
/// Reading a `MyMessage` as this type is expected to fail JSON parsing.
#[derive(Debug, Serialize, Deserialize)]
struct YoloMessage {
    yolo: String,
}
/// Counts the rows currently stored in the queue table for `qname`.
///
/// Panics if the query fails, e.g. because the table does not exist.
async fn rowcount(qname: &str, connection: &Pool<Postgres>) -> i64 {
    let sql = format!("SELECT count(*) as ct FROM {PGMQ_SCHEMA}.{QUEUE_PREFIX}_{qname}");
    let row = sqlx::query(&sql)
        .fetch_one(connection)
        .await
        .unwrap();
    row.get::<i64, usize>(0)
}
/// Counts the rows currently stored in the archive table for `qname`.
///
/// Panics if the query fails, e.g. because the table does not exist.
async fn archive_rowcount(qname: &str, connection: &Pool<Postgres>) -> i64 {
    let sql = format!("SELECT count(*) as ct FROM {PGMQ_SCHEMA}.{ARCHIVE_PREFIX}_{qname}");
    let row = sqlx::query(&sql)
        .fetch_one(connection)
        .await
        .unwrap();
    row.get::<i64, usize>(0)
}
/// Like `rowcount`, but returns the query failure as a `PgmqError` instead of
/// panicking. This lets callers assert that the queue table is gone.
async fn fallible_rowcount(qname: &str, connection: &Pool<Postgres>) -> Result<i64, PgmqError> {
    let sql = format!("SELECT count(*) as ct FROM {PGMQ_SCHEMA}.{QUEUE_PREFIX}_{qname}");
    let row = sqlx::query(&sql).fetch_one(connection).await?;
    Ok(row.get::<i64, usize>(0))
}
/// Like `archive_rowcount`, but returns the query failure as a `PgmqError`
/// instead of panicking. This lets callers assert that the archive table is gone.
async fn fallible_archive_rowcount(
    qname: &str,
    connection: &Pool<Postgres>,
) -> Result<i64, PgmqError> {
    let sql = format!("SELECT count(*) as ct FROM {PGMQ_SCHEMA}.{ARCHIVE_PREFIX}_{qname}");
    let row = sqlx::query(&sql).fetch_one(connection).await?;
    Ok(row.get::<i64, usize>(0))
}
/// Walks one message through send -> read -> hidden -> visible again -> delete.
#[tokio::test]
async fn test_lifecycle() {
    let qname = "test_queue_0".to_owned();
    let queue = init_queue(&qname).await;

    // `init_queue` already created the queue; creating it a second time must succeed too.
    assert!(queue.create(&qname).await.is_ok());
    assert_eq!(rowcount(&qname, &queue.connection).await, 0);

    let payload = serde_json::json!({ "foo": "bar" });
    let sent_id = queue.send(&qname, &payload).await.unwrap();
    assert_eq!(sent_id, 1);
    assert_eq!(rowcount(&qname, &queue.connection).await, 1);

    let vt = 2;
    let first_read = queue
        .read::<Value>(&qname, Some(vt))
        .await
        .unwrap()
        .unwrap();
    assert_eq!(first_read.msg_id, 1);

    // While the visibility timeout is running the message is hidden but still stored.
    let hidden = queue.read::<Value>(&qname, Some(vt)).await.unwrap();
    assert!(hidden.is_none());
    assert_eq!(rowcount(&qname, &queue.connection).await, 1);

    // After the timeout elapses the same message can be read again.
    tokio::time::sleep(std::time::Duration::from_secs(vt as u64)).await;
    let second_read = queue
        .read::<Value>(&qname, Some(vt))
        .await
        .unwrap()
        .unwrap();
    assert_eq!(second_read.msg_id, 1);

    let deleted = queue.delete(&qname, first_read.msg_id).await.unwrap();
    assert_eq!(deleted, 1);
    let after_delete = queue.read::<Value>(&qname, Some(vt)).await.unwrap();
    assert!(after_delete.is_none());
    assert_eq!(rowcount(&qname, &queue.connection).await, 0);
}
/// Messages are handed out in send order, both on first delivery and after their
/// visibility timeouts expire.
#[tokio::test]
async fn test_fifo() {
    let qname = "test_fifo_queue".to_owned();
    let queue = init_queue(&qname).await;
    let payload = serde_json::json!({ "foo": "bar1" });

    for expected_id in 1..=3 {
        let id = queue.send(&qname, &payload).await.unwrap();
        assert_eq!(id, expected_id);
    }

    let vt: i32 = 1;
    let first = queue
        .read::<Value>(&qname, Some(vt))
        .await
        .unwrap()
        .unwrap();
    let second = queue
        .read::<Value>(&qname, Some(vt))
        .await
        .unwrap()
        .unwrap();
    assert_eq!(second.msg_id, 2);
    assert_eq!(first.msg_id, 1);

    // Wait twice the visibility timeout so messages 1 and 2 are visible again.
    let wait = std::time::Duration::from_secs(vt as u64) * 2;
    tokio::time::sleep(wait).await;

    let mut redelivered = Vec::new();
    for _ in 0..3 {
        let msg = queue
            .read::<Value>(&qname, Some(vt))
            .await
            .unwrap()
            .unwrap();
        redelivered.push(msg.msg_id);
    }
    assert_eq!(redelivered, vec![1, 2, 3]);
}
/// A delayed message is stored immediately but cannot be read until the delay
/// has passed.
#[tokio::test]
async fn test_send_delay() {
    let vt: i32 = 1;
    let qname = "test_send_delay_queue".to_owned();
    let queue = init_queue(&qname).await;
    assert!(queue.create(&qname).await.is_ok());
    assert_eq!(rowcount(&qname, &queue.connection).await, 0);

    let payload = serde_json::json!({ "foo": "bar" });
    let _delayed_id = queue.send_delay(&qname, &payload, 5).await.unwrap();

    // The row exists right away, yet readers do not see it.
    assert_eq!(rowcount(&qname, &queue.connection).await, 1);
    let too_early = queue.read::<Value>(&qname, Some(vt)).await.unwrap();
    assert!(too_early.is_none());

    // Once the 5-second delay has elapsed the message becomes readable.
    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    let on_time = queue.read::<Value>(&qname, Some(vt)).await.unwrap();
    assert!(on_time.is_some());
}
/// `read_batch_with_poll` returns visible messages straight away. It waits within
/// the poll timeout for hidden messages to reappear, and returns `None` when the
/// timeout passes with nothing visible.
#[tokio::test]
async fn test_read_batch_with_poll() {
    let qname = "test_read_batch_with_poll".to_owned();
    let queue = init_queue(&qname).await;
    let payload = serde_json::json!({ "foo": "bar1" });
    for expected_id in 1..=3 {
        let id = queue.send(&qname, &payload).await.unwrap();
        assert_eq!(id, expected_id);
    }

    let long_poll = Some(std::time::Duration::from_secs(6));

    // All three messages are visible, so they come back in one batch.
    let first_batch = queue
        .read_batch_with_poll::<Value>(&qname, Some(5), 5, long_poll, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(first_batch.len(), 3);

    // The messages are now hidden for 5s; the 6s poll should wait for them to return.
    let started = std::time::Instant::now();
    let second_batch = queue
        .read_batch_with_poll::<Value>(&qname, Some(5), 5, long_poll, None)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(second_batch.len(), 3);
    assert!(started.elapsed() > std::time::Duration::from_secs(3));

    // They are hidden again, and a 1s poll gives up empty-handed.
    let short_poll = Some(std::time::Duration::from_secs(1));
    let third_batch = queue
        .read_batch_with_poll::<Value>(&qname, Some(3), 5, short_poll, None)
        .await
        .unwrap();
    assert!(third_batch.is_none());
}
/// `read_batch` returns up to `num_msgs` messages in send order and hides all of
/// them for the visibility timeout, without removing them from the table.
#[tokio::test]
async fn test_read_batch() {
    let test_queue = "test_read_batch".to_owned();
    let queue = init_queue(&test_queue).await;
    let msg = serde_json::json!({
        "foo": "bar1"
    });
    let msg_id1 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg_id1, 1);
    let msg_id2 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg_id2, 2);
    let msg_id3 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg_id3, 3);
    let vt: i32 = 1;
    let num_msgs = 3;
    let batch = queue
        .read_batch::<Value>(&test_queue, Some(vt), num_msgs)
        .await
        .unwrap()
        .unwrap();
    // Compare the whole id list, not each element in turn: this also fails when the
    // batch is shorter than expected, which a per-element loop would accept.
    let batch_ids: Vec<i64> = batch.iter().map(|m| m.msg_id).collect();
    assert_eq!(batch_ids, vec![1, 2, 3]);
    // Every message in the batch is now hidden...
    let msg = queue.read::<Value>(&test_queue, Some(vt)).await.unwrap();
    assert!(msg.is_none());
    // ...but still stored.
    let num_rows = rowcount(&test_queue, &queue.connection).await;
    assert_eq!(num_rows, 3);
}
/// `send_batch` enqueues every message in one call and assigns consecutive ids.
/// It is exercised with both untyped JSON values and a typed, serialisable struct.
#[tokio::test]
async fn test_send_batch() {
    let test_queue = "test_send_batch".to_owned();
    let queue = init_queue(&test_queue).await;
    let msgs = vec![
        serde_json::json!({"foo": "bar1"}),
        serde_json::json!({"foo": "bar2"}),
        serde_json::json!({"foo": "bar3"}),
    ];
    let msg_ids = queue
        .send_batch(&test_queue, &msgs)
        .await
        .expect("Failed to enqueue messages");
    // The queue was freshly re-created by `init_queue`, so ids start at 1.
    assert_eq!(msg_ids, vec![1, 2, 3]);

    // Local typed payload. It has a distinct name so it does not shadow the
    // file-level `MyMessage`.
    #[derive(Serialize, Debug, Deserialize)]
    struct BatchMessage {
        foo: String,
    }
    let msgs2 = vec![
        BatchMessage {
            foo: "bar1".to_owned(),
        },
        BatchMessage {
            foo: "bar2".to_owned(),
        },
        BatchMessage {
            foo: "bar3".to_owned(),
        },
    ];
    let msg_ids2 = queue
        .send_batch(&test_queue, &msgs2)
        .await
        .expect("Failed to enqueue messages");
    // The second batch continues the id sequence.
    assert_eq!(msg_ids2, vec![4, 5, 6]);

    let vt: i32 = 1;
    let num_msgs = 3;
    let batch = queue
        .read_batch::<Value>(&test_queue, Some(vt), num_msgs)
        .await
        .unwrap()
        .unwrap();
    // The oldest three messages come back first, and the batch must not be short.
    let batch_ids: Vec<i64> = batch.iter().map(|m| m.msg_id).collect();
    assert_eq!(batch_ids, vec![1, 2, 3]);
}
/// `delete_batch` removes exactly the listed messages and leaves the others in
/// place. It returns the number of messages deleted.
#[tokio::test]
async fn test_delete_batch() {
    let test_queue = "test_delete_batch".to_owned();
    let queue = init_queue(&test_queue).await;
    let vt: i32 = 1;
    // Ids of the messages bracketing the batch; these must survive the first delete.
    let mut msg_id_first_last: Vec<i64> = Vec::new();
    let msg_first = serde_json::json!({
        "foo": "first"
    });
    let msg_id1 = queue.send(&test_queue, &msg_first).await.unwrap();
    assert_eq!(msg_id1, 1);
    msg_id_first_last.push(msg_id1);
    let msgs = vec![
        serde_json::json!({"foo": "bar1"}),
        serde_json::json!({"foo": "bar2"}),
        serde_json::json!({"foo": "bar3"}),
    ];
    let msg_ids = queue
        .send_batch(&test_queue, &msgs)
        .await
        .expect("Failed to enqueue messages");
    // The batch follows message 1, so it receives ids 2, 3 and 4.
    assert_eq!(msg_ids, vec![2, 3, 4]);
    let msg_last = serde_json::json!({
        "foo": "last"
    });
    let msg_id2 = queue.send(&test_queue, &msg_last).await.unwrap();
    assert_eq!(msg_id2, 5);
    msg_id_first_last.push(msg_id2);

    // Delete only the batch; the count is compared as text so it does not depend
    // on the integer type `delete_batch` returns.
    let del = queue
        .delete_batch(&test_queue, &msg_ids)
        .await
        .expect("Failed to delete messages from queue");
    assert_eq!(del.to_string(), msg_ids.len().to_string());

    // Only the first and last messages remain, still in send order.
    let first = queue
        .read::<Value>(&test_queue, Some(vt))
        .await
        .expect("Failed to read message");
    assert_eq!(first.unwrap().msg_id, msg_id1);
    let last = queue
        .read::<Value>(&test_queue, Some(vt))
        .await
        .expect("Failed to read message");
    assert_eq!(last.unwrap().msg_id, msg_id2);

    let del_first_last = queue
        .delete_batch(&test_queue, &msg_id_first_last)
        .await
        .expect("Failed to delete messages from queue");
    assert_eq!(
        del_first_last.to_string(),
        msg_id_first_last.len().to_string()
    );
    let msg = queue.read::<Value>(&test_queue, Some(vt)).await.unwrap();
    assert!(msg.is_none());
}
/// Round-trips payloads between typed structs and untyped `serde_json::Value`s in
/// every combination: a message sent as one representation is read back as the
/// same or the other. Each message is deleted after it is read.
#[tokio::test]
async fn test_serde() {
    let mut rng = rand::thread_rng();
    let test_queue = "test_ser_queue".to_owned();
    let queue = init_queue(&test_queue).await;

    // struct -> struct
    let msg = MyMessage {
        foo: "bar".to_owned(),
        num: rng.gen_range(0..100000),
    };
    let msg1 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg1, 1);
    let msg_read = queue
        .read::<MyMessage>(&test_queue, Some(30_i32))
        .await
        .unwrap()
        .unwrap();
    // Unwrap rather than discard the result, so a failed delete fails the test.
    queue.delete(&test_queue, msg_read.msg_id).await.unwrap();
    assert_eq!(msg_read.message.num, msg.num);

    // Value -> Value
    let msg = serde_json::json!({
        "foo": "bar",
        "num": rng.gen_range(0..100000)
    });
    let msg2 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg2, 2);
    let msg_read = queue
        .read::<Value>(&test_queue, Some(30_i32))
        .await
        .unwrap()
        .unwrap();
    queue.delete(&test_queue, msg_read.msg_id).await.unwrap();
    assert_eq!(msg_read.message["num"], msg["num"]);
    assert_eq!(msg_read.message["foo"], msg["foo"]);

    // Value -> struct
    let msg = serde_json::json!({
        "foo": "bar",
        "num": rng.gen_range(0..100000)
    });
    let msg3 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg3, 3);
    let msg_read = queue
        .read::<MyMessage>(&test_queue, Some(30_i32))
        .await
        .unwrap()
        .unwrap();
    queue.delete(&test_queue, msg_read.msg_id).await.unwrap();
    assert_eq!(msg_read.message.foo, msg["foo"].to_owned());
    assert_eq!(msg_read.message.num, msg["num"].as_u64().unwrap());

    // struct -> Value
    let msg = MyMessage {
        foo: "bar".to_owned(),
        num: rng.gen_range(0..100000),
    };
    let msg4 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg4, 4);
    let msg_read = queue
        .read::<Value>(&test_queue, Some(30_i32))
        .await
        .unwrap()
        .unwrap();
    queue.delete(&test_queue, msg_read.msg_id).await.unwrap();
    assert_eq!(msg_read.message["foo"].to_owned(), msg.foo);
    assert_eq!(msg_read.message["num"].as_u64().unwrap(), msg.num);

    // Value -> default `Message` payload type (inferred from the annotation)
    let msg = serde_json::json!( {
        "foo": "bar".to_owned(),
        "num": rng.gen_range(0..100000),
    });
    let msg5 = queue.send(&test_queue, &msg).await.unwrap();
    assert_eq!(msg5, 5);
    let msg_read: Message = queue
        .read(&test_queue, Some(30_i32))
        .await
        .unwrap()
        .unwrap();
    queue.delete(&test_queue, msg_read.msg_id).await.unwrap();
    assert_eq!(msg_read.message["foo"].to_owned(), msg["foo"].to_owned());
    assert_eq!(
        msg_read.message["num"].as_u64().unwrap(),
        msg["num"].as_u64().unwrap()
    );
}
/// `pop` returns the next message and removes it from the queue table.
#[tokio::test]
async fn test_pop() {
    let qname = "test_pop_queue".to_owned();
    let queue = init_queue(&qname).await;
    let payload = MyMessage::default();
    let sent_id = queue.send(&qname, &payload).await.unwrap();
    assert_eq!(sent_id, 1);

    let popped = queue.pop::<MyMessage>(&qname).await.unwrap().unwrap();
    assert_eq!(popped.msg_id, 1);

    // Nothing is left behind after the pop.
    assert_eq!(rowcount(&qname, &queue.connection).await, 0);
}
/// `archive` moves a message from the queue table into the archive table.
#[tokio::test]
async fn test_archive() {
    let qname = "test_archive_queue".to_owned();
    let queue = init_queue(&qname).await;
    let payload = MyMessage::default();
    let sent_id = queue.send(&qname, &payload).await.unwrap();
    assert_eq!(sent_id, 1);

    let moved = queue.archive(&qname, sent_id).await.unwrap();
    assert_eq!(moved, 1);

    assert_eq!(rowcount(&qname, &queue.connection).await, 0);
    assert_eq!(archive_rowcount(&qname, &queue.connection).await, 1);
}
/// `archive_batch` moves several messages into the archive table in one call.
#[tokio::test]
async fn test_archive_batch() {
    let qname = "test_archive_batch_queue".to_owned();
    let queue = init_queue(&qname).await;
    let payload = MyMessage::default();

    let mut sent_ids = Vec::with_capacity(3);
    for _ in 0..3 {
        sent_ids.push(queue.send(&qname, &payload).await.unwrap());
    }

    let moved = queue.archive_batch(&qname, &sent_ids).await.unwrap();
    assert_eq!(moved, 3);

    assert_eq!(rowcount(&qname, &queue.connection).await, 0);
    assert_eq!(archive_rowcount(&qname, &queue.connection).await, 3);
}
#[tokio::test]
async fn test_database_error_modes() {
let db_url = env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/postgres".to_owned());
let queue = pgmq::PGMQueue::new(db_url)
.await
.expect("failed to connect to postgres");
let msg_id = queue.send("doesNotExist", &"foo").await;
assert!(msg_id.is_err());
let read_msg = queue.read::<Message>("doesNotExist", Some(10_i32)).await;
assert!(read_msg.is_err());
let queue = pgmq::PGMQueue::new("postgres://DNE:5432".to_owned()).await;
match queue {
Err(e) => {
if let PgmqError::UrlParsingError { .. } = e {
} else {
panic!("expected a url parsing error, got {e:?}");
}
}
_ => panic!("expected a url parsing error, got {read_msg:?}"),
}
let queue = pgmq::PGMQueue::new("postgres://user:pass@badhost:5432".to_owned()).await;
match queue {
Err(e) => {
if let PgmqError::DatabaseError { .. } = e {
} else {
panic!("expected a db error, got {e:?}");
}
}
_ => panic!("expected a db error, got {read_msg:?}"),
}
}
/// `purge` deletes every message in the queue and reports how many it removed.
#[tokio::test]
async fn test_purge() {
    let qname = format!("test_purge_{}", rand::thread_rng().gen_range(0..100000));
    let queue = init_queue(&qname).await;
    let payload = MyMessage::default();
    for _ in 0..3 {
        let _ = queue.send(&qname, &payload).await.unwrap();
    }

    let purged = queue.purge(&qname).await.expect("purge queue error");
    assert_eq!(purged, 3);
    assert_eq!(rowcount(&qname, &queue.connection).await, 0);
}
/// Reading a message into an incompatible payload type yields `JsonParsingError`.
#[tokio::test]
async fn test_parsing_error_modes() {
    let qname = "test_parsing_queue".to_owned();
    let queue = init_queue(&qname).await;
    let payload = MyMessage::default();
    let _ = queue.send(&qname, &payload).await.unwrap();

    match queue.read::<YoloMessage>(&qname, Some(10_i32)).await {
        Err(PgmqError::JsonParsingError { .. }) => {}
        Err(e) => panic!("expected a parse error, got {e:?}"),
        other => panic!("expected a parse error, got {other:?}"),
    }
}
/// Checks that `destroy` drops the queue table, the archive table, and the
/// queue's metadata row.
#[tokio::test]
async fn test_destroy() {
    let test_queue = "test_destroy_queue".to_owned();
    let queue = init_queue(&test_queue).await;
    let msg = MyMessage::default();
    let msg1 = queue.send(&test_queue, &msg).await.unwrap();
    let msg2 = queue.send(&test_queue, &msg).await.unwrap();
    // Archive one message so both the queue and the archive table hold a row.
    let _ = queue.archive(&test_queue, msg1).await.unwrap();
    let read: Message = queue
        .read(&test_queue, Some(30_i32))
        .await
        .unwrap()
        .unwrap();
    assert_eq!(read.msg_id, msg2);

    queue.destroy(&test_queue).await.unwrap();

    // Both tables are gone, so querying them must fail.
    let queue_table = fallible_rowcount(&test_queue, &queue.connection).await;
    assert!(queue_table.is_err());
    let archive_table = fallible_archive_rowcount(&test_queue, &queue.connection).await;
    assert!(archive_table.is_err());

    // No metadata row may remain. The queue name is bound as a query parameter
    // rather than spliced into the SQL text. The local is named `meta_rows` so it
    // does not shadow the `rowcount` helper.
    let pgmq_meta_query =
        format!("SELECT count(*) as ct FROM {PGMQ_SCHEMA}.meta WHERE queue_name = $1");
    let meta_rows = sqlx::query(&pgmq_meta_query)
        .bind(test_queue.as_str())
        .fetch_one(&queue.connection)
        .await
        .unwrap()
        .get::<i64, usize>(0);
    assert_eq!(meta_rows, 0);
}
/// `set_vt` can hide a message by pushing its visibility time far into the future
/// and can bring that time back to the present. The message stays in the queue
/// table throughout.
#[tokio::test]
async fn test_set_vt() {
    let test_queue = "test_set_vt_queue".to_owned();
    let queue = init_queue(&test_queue).await;
    let msg = MyMessage::default();
    let num_rows_queue = rowcount(&test_queue, &queue.connection).await;
    assert_eq!(num_rows_queue, 0);
    let msg_id = queue.send(&test_queue, &msg).await.unwrap();

    // With a zero-second visibility timeout the message can be read back-to-back.
    let read: Message = queue.read(&test_queue, Some(0_i32)).await.unwrap().unwrap();
    assert_eq!(read.msg_id, msg_id);
    let read: Message = queue.read(&test_queue, Some(0_i32)).await.unwrap().unwrap();
    assert_eq!(read.msg_id, msg_id);

    // Push the visibility time 24 hours out: the message is hidden from readers.
    let utc_24h_from_now = Utc::now() + Duration::hours(24);
    let _ = queue
        .set_vt::<MyMessage>(&test_queue, msg_id, utc_24h_from_now)
        .await
        .unwrap();
    let read: Option<Message> = queue.read(&test_queue, Some(0_i32)).await.unwrap();
    assert!(read.is_none());

    // Reset the visibility time to now; the message must still be stored.
    let now = Utc::now();
    let _ = queue
        .set_vt::<MyMessage>(&test_queue, msg_id, now)
        .await
        .unwrap();
    let num_rows_queue = rowcount(&test_queue, &queue.connection).await;
    assert_eq!(num_rows_queue, 1);
}