#![doc(html_root_url = "https://docs.rs/pgmq/")]
use serde::{Deserialize, Serialize};
use sqlx::error::Error;
use sqlx::postgres::PgRow;
use sqlx::types::chrono::Utc;
use sqlx::{Pool, Postgres, Row};
pub mod errors;
pub mod pg_ext;
pub mod types;
pub mod util;
mod query;
pub use errors::PgmqError;
pub use pg_ext::PGMQueueExt;
pub use types::Message;
use std::time::Duration;
/// Client for a PGMQ queue, holding the Postgres connection pool that all
/// queue operations run against.
#[derive(Clone, Debug)]
pub struct PGMQueue {
/// Connection string the pool was built from; left empty ("") when the
/// client is constructed from an existing pool via `new_with_pool`.
pub url: String,
/// sqlx Postgres connection pool shared by all operations on this client.
pub connection: Pool<Postgres>,
}
impl PGMQueue {
/// Build a queue client by connecting to Postgres at `url`.
///
/// The pool is created through `util::connect` (the `5` is passed straight
/// through to it — presumably a pool-size or timeout knob; see `util`).
///
/// # Errors
/// Returns `PgmqError` when the connection cannot be established.
pub async fn new(url: String) -> Result<Self, PgmqError> {
    let pool = util::connect(&url, 5).await?;
    Ok(Self {
        connection: pool,
        url,
    })
}
/// Wrap an already-constructed connection pool in a queue client.
///
/// No connection string is available in this path, so `url` is left empty.
pub async fn new_with_pool(pool: Pool<Postgres>) -> Self {
    Self {
        url: String::new(),
        connection: pool,
    }
}
/// Create a (logged) queue named `queue_name`.
///
/// All setup statements produced by `query::init_queue_client_only` run
/// inside a single transaction, so the queue is created atomically.
///
/// # Errors
/// Returns `PgmqError` if statement generation or execution fails; the
/// transaction is rolled back on drop in that case.
pub async fn create(&self, queue_name: &str) -> Result<(), PgmqError> {
    let statements = query::init_queue_client_only(queue_name, false)?;
    let mut tx = self.connection.begin().await?;
    for statement in &statements {
        sqlx::query(statement).execute(&mut *tx).await?;
    }
    tx.commit().await?;
    Ok(())
}
/// Create an unlogged queue named `queue_name` (same as `create`, but the
/// statements are generated with the unlogged flag set).
///
/// Runs all setup statements inside one transaction.
///
/// # Errors
/// Returns `PgmqError` if statement generation or execution fails.
pub async fn create_unlogged(&self, queue_name: &str) -> Result<(), PgmqError> {
    let statements = query::init_queue_client_only(queue_name, true)?;
    let mut tx = self.connection.begin().await?;
    for statement in &statements {
        sqlx::query(statement).execute(&mut *tx).await?;
    }
    tx.commit().await?;
    Ok(())
}
/// Tear down the queue named `queue_name`, executing every statement from
/// `query::destroy_queue_client_only` inside a single transaction.
///
/// # Errors
/// Returns `PgmqError` if statement generation or execution fails.
pub async fn destroy(&self, queue_name: &str) -> Result<(), PgmqError> {
    let statements = query::destroy_queue_client_only(queue_name)?;
    let mut tx = self.connection.begin().await?;
    for statement in &statements {
        sqlx::query(statement).execute(&mut *tx).await?;
    }
    tx.commit().await?;
    Ok(())
}
/// Enqueue a single serializable message with no delay (delay = 0) and
/// return the `msg_id` the database assigned to it.
///
/// # Errors
/// Returns `PgmqError` if the enqueue SQL cannot be built or executed.
pub async fn send<T: Serialize>(
    &self,
    queue_name: &str,
    message: &T,
) -> Result<i64, PgmqError> {
    let sql = query::enqueue(queue_name, 1, &0)?;
    let payload = serde_json::json!(&message);
    let row: PgRow = sqlx::query(&sql)
        .bind(payload)
        .fetch_one(&self.connection)
        .await?;
    Ok(row.get("msg_id"))
}
/// Enqueue a single serializable message with a `delay` (passed through to
/// `query::enqueue`; presumably seconds — see the query module) and return
/// the assigned `msg_id`.
///
/// # Errors
/// Returns `PgmqError` if the enqueue SQL cannot be built or executed.
pub async fn send_delay<T: Serialize>(
    &self,
    queue_name: &str,
    message: &T,
    delay: u64,
) -> Result<i64, PgmqError> {
    let sql = query::enqueue(queue_name, 1, &delay)?;
    let payload = serde_json::json!(&message);
    let row: PgRow = sqlx::query(&sql)
        .bind(payload)
        .fetch_one(&self.connection)
        .await?;
    Ok(row.get("msg_id"))
}
/// Enqueue every message in `messages` with a single statement (one bind
/// placeholder per message, delay = 0) and return the assigned `msg_id`s
/// in the order the database returned them.
///
/// # Errors
/// Returns `PgmqError` if the enqueue SQL cannot be built or executed.
pub async fn send_batch<T: Serialize>(
    &self,
    queue_name: &str,
    messages: &[T],
) -> Result<Vec<i64>, PgmqError> {
    let query = query::enqueue(queue_name, messages.len(), &0)?;
    let mut q = sqlx::query(&query);
    for msg in messages {
        q = q.bind(serde_json::json!(msg));
    }
    let rows: Vec<PgRow> = q.fetch_all(&self.connection).await?;
    // Collect through a sized iterator instead of pushing into an empty
    // Vec, so the result allocates once.
    let msg_ids = rows
        .iter()
        .map(|row| row.get::<i64, _>("msg_id"))
        .collect();
    Ok(msg_ids)
}
/// Read (at most) one message from the queue.
///
/// `vt` overrides the visibility timeout passed to `query::read`; when
/// `None`, `types::VT_DEFAULT` is used. The batch size is pinned to
/// `types::READ_LIMIT_DEFAULT`.
///
/// Returns `Ok(None)` when no message is available.
///
/// # Errors
/// Returns `PgmqError` if the read SQL cannot be built or executed.
pub async fn read<T: for<'de> Deserialize<'de>>(
    &self,
    queue_name: &str,
    vt: Option<i32>,
) -> Result<Option<Message<T>>, PgmqError> {
    let vt_ = vt.unwrap_or(types::VT_DEFAULT);
    let sql = query::read(queue_name, vt_, types::READ_LIMIT_DEFAULT)?;
    let message = util::fetch_one_message::<T>(&sql, &self.connection).await?;
    Ok(message)
}
/// Read up to `num_msgs` messages from the queue in one call.
///
/// `vt` overrides the visibility timeout passed to `query::read`; when
/// `None`, `types::VT_DEFAULT` is used. Returns `Ok(None)` when the queue
/// yields no rows.
///
/// # Errors
/// Returns `PgmqError` if the read SQL cannot be built or executed, or if
/// a row's payload fails to deserialize.
pub async fn read_batch<T: for<'de> Deserialize<'de>>(
    &self,
    queue_name: &str,
    vt: Option<i32>,
    num_msgs: i32,
) -> Result<Option<Vec<Message<T>>>, PgmqError> {
    let vt_ = vt.unwrap_or(types::VT_DEFAULT);
    let sql = query::read(queue_name, vt_, num_msgs)?;
    let messages = fetch_messages::<T>(&sql, &self.connection).await?;
    Ok(messages)
}
/// Read up to `max_batch_size` messages, polling until either a non-empty
/// batch arrives or `poll_timeout` elapses.
///
/// `vt`, `poll_timeout`, and `poll_interval` fall back to the defaults in
/// `types` when `None`. Between empty reads the task sleeps for
/// `poll_interval` (async sleep, so the executor is not blocked).
///
/// Returns `Ok(None)` when the timeout expires without any messages.
///
/// # Errors
/// Returns `PgmqError` if the read SQL cannot be built or executed, or if
/// a row's payload fails to deserialize.
pub async fn read_batch_with_poll<T: for<'de> Deserialize<'de>>(
    &self,
    queue_name: &str,
    vt: Option<i32>,
    max_batch_size: i32,
    poll_timeout: Option<Duration>,
    poll_interval: Option<Duration>,
) -> Result<Option<Vec<Message<T>>>, PgmqError> {
    let vt_ = vt.unwrap_or(types::VT_DEFAULT);
    let poll_timeout_ = poll_timeout.unwrap_or(types::POLL_TIMEOUT_DEFAULT);
    let poll_interval_ = poll_interval.unwrap_or(types::POLL_INTERVAL_DEFAULT);
    // The SQL depends only on loop-invariant inputs — build it once instead
    // of re-generating (and re-allocating) it on every poll iteration.
    let query = query::read(queue_name, vt_, max_batch_size)?;
    let start_time = std::time::Instant::now();
    loop {
        let messages = fetch_messages::<T>(&query, &self.connection).await?;
        match messages {
            Some(m) => {
                break Ok(Some(m));
            }
            None => {
                if start_time.elapsed() < poll_timeout_ {
                    tokio::time::sleep(poll_interval_).await;
                } else {
                    break Ok(None);
                }
            }
        }
    }
}
pub async fn delete(&self, queue_name: &str, msg_id: i64) -> Result<u64, PgmqError> {
let query = &query::delete_batch(queue_name)?;
let row = sqlx::query(query)
.bind(vec![msg_id])
.execute(&self.connection)
.await?;
let num_deleted = row.rows_affected();
Ok(num_deleted)
}
/// Delete every message in `msg_ids`; returns the number of rows removed.
///
/// # Errors
/// Returns `PgmqError` if the delete SQL cannot be built or executed.
pub async fn delete_batch(&self, queue_name: &str, msg_ids: &[i64]) -> Result<u64, PgmqError> {
    let sql = query::delete_batch(queue_name)?;
    let outcome = sqlx::query(&sql)
        .bind(msg_ids)
        .execute(&self.connection)
        .await?;
    Ok(outcome.rows_affected())
}
/// Remove all messages from the queue; returns the number of rows deleted.
///
/// # Errors
/// Returns `PgmqError` if the purge SQL cannot be built or executed.
pub async fn purge(&self, queue_name: &str) -> Result<u64, PgmqError> {
    let sql = query::purge_queue(queue_name)?;
    let outcome = sqlx::query(&sql).execute(&self.connection).await?;
    Ok(outcome.rows_affected())
}
/// Archive one message by id (a one-element `archive_batch`); returns the
/// number of rows affected (0 or 1).
pub async fn archive(&self, queue_name: &str, msg_id: i64) -> Result<u64, PgmqError> {
self.archive_batch(queue_name, &[msg_id]).await
}
/// Archive every message in `msg_ids` (SQL generated by
/// `query::archive_batch`); returns the number of rows affected.
///
/// # Errors
/// Returns `PgmqError` if the archive SQL cannot be built or executed.
pub async fn archive_batch(&self, queue_name: &str, msg_ids: &[i64]) -> Result<u64, PgmqError> {
    let sql = query::archive_batch(queue_name)?;
    let outcome = sqlx::query(&sql)
        .bind(msg_ids)
        .execute(&self.connection)
        .await?;
    Ok(outcome.rows_affected())
}
/// Read and remove (at most) one message in a single operation.
///
/// Returns `Ok(None)` when the queue is empty.
///
/// # Errors
/// Returns `PgmqError` if the pop SQL cannot be built or executed.
pub async fn pop<T: for<'de> Deserialize<'de>>(
    &self,
    queue_name: &str,
) -> Result<Option<Message<T>>, PgmqError> {
    let sql = query::pop(queue_name)?;
    let message = util::fetch_one_message::<T>(&sql, &self.connection).await?;
    Ok(message)
}
/// Set the visibility time of message `msg_id` to the absolute timestamp
/// `vt`, returning the updated message (or `Ok(None)` if no row matched).
///
/// # Errors
/// Returns `PgmqError` if the update SQL cannot be built or executed.
pub async fn set_vt<T: for<'de> Deserialize<'de>>(
    &self,
    queue_name: &str,
    msg_id: i64,
    vt: chrono::DateTime<Utc>,
) -> Result<Option<Message<T>>, PgmqError> {
    let sql = query::set_vt(queue_name, msg_id, vt)?;
    let updated_message = util::fetch_one_message::<T>(&sql, &self.connection).await?;
    Ok(updated_message)
}
}
/// Execute `query` and deserialize every returned row into a `Message<T>`.
///
/// Returns `Ok(None)` when the query yields no rows, `Ok(Some(..))`
/// otherwise. If any row's `message` column fails to deserialize into `T`,
/// the whole call fails with `PgmqError::JsonParsingError`.
///
/// # Errors
/// Returns `PgmqError` on query failure or JSON deserialization failure.
async fn fetch_messages<T: for<'de> Deserialize<'de>>(
    query: &str,
    connection: &Pool<Postgres>,
) -> Result<Option<Vec<Message<T>>>, PgmqError> {
    // `?` converts the sqlx error via From, replacing the former
    // match-and-rewrap (`Err(e) => Err(e)?`) around the whole body.
    let rows: Vec<PgRow> = sqlx::query(query).fetch_all(connection).await?;
    if rows.is_empty() {
        return Ok(None);
    }
    // Row count is known up front — allocate the output once.
    let mut messages: Vec<Message<T>> = Vec::with_capacity(rows.len());
    for row in rows.iter() {
        let raw_msg = row.get("message");
        // map_err + ? replaces the redundant `if let Err / else if let Ok`
        // pair on the same Result.
        let parsed_msg =
            serde_json::from_value::<T>(raw_msg).map_err(PgmqError::JsonParsingError)?;
        messages.push(Message {
            msg_id: row.get("msg_id"),
            vt: row.get("vt"),
            read_ct: row.get("read_ct"),
            enqueued_at: row.get("enqueued_at"),
            message: parsed_msg,
        });
    }
    Ok(Some(messages))
}