use rusqlite::OptionalExtension;
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tracing::{debug, error, instrument, warn};
/// Error type returned by every fallible operation of the disk-backed queue.
#[derive(Debug)]
pub enum DiskQueueError {
/// A SQLite call failed; wraps the underlying `rusqlite` error.
Database(rusqlite::Error),
/// An item could not be encoded with `postcard` before being stored.
Serialization(postcard::Error),
/// A stored payload could not be decoded with `postcard`.
Deserialization(postcard::Error),
/// The requested table name was rejected by validation; carries the offending name.
InvalidTableName(String),
/// A background task did not complete (also used when a Tokio runtime cannot be created).
TaskJoin(String),
/// The mutex guarding a SQLite connection was poisoned.
LockPoisoned(String),
/// A corrupt item could not be copied into the dead letter queue.
DlqWriteFailed(String),
/// The queue is closed. Not constructed anywhere in the visible code.
QueueClosed,
/// An INSERT or DELETE touched a different number of rows than expected.
UnexpectedRowCount(String),
/// A batch is larger than the configured maximum size; carries that maximum.
QueueFull(usize),
}
impl std::fmt::Display for DiskQueueError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Database(e) => write!(f, "Database error: {e}"),
Self::Serialization(e) => write!(f, "Serialization error: {e}"),
Self::Deserialization(e) => write!(f, "Deserialization error: {e}"),
Self::InvalidTableName(s) => write!(f, "Invalid table name: {s}"),
Self::TaskJoin(s) => write!(f, "Internal task error: {s}"),
Self::LockPoisoned(s) => write!(f, "Internal lock poisoned: {s}"),
Self::DlqWriteFailed(s) => write!(f, "Failed to write to dead letter queue: {s}"),
Self::QueueClosed => f.write_str("Queue is closed"),
Self::UnexpectedRowCount(s) => write!(f, "Unexpected row count: {s}"),
Self::QueueFull(n) => write!(f, "Queue is full (max size: {n})"),
}
}
}
impl std::error::Error for DiskQueueError {
    /// Exposes the wrapped `rusqlite` / `postcard` error, if this variant carries one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = match self {
            Self::Database(db_err) => db_err,
            Self::Serialization(codec_err) | Self::Deserialization(codec_err) => codec_err,
            // Every other variant only carries a message (or nothing at all).
            _ => return None,
        };
        Some(cause)
    }
}
impl From<rusqlite::Error> for DiskQueueError {
fn from(e: rusqlite::Error) -> Self {
Self::Database(e)
}
}
/// Shorthand for a `Result` whose error type is [`DiskQueueError`].
pub type Result<T> = std::result::Result<T, DiskQueueError>;
fn lock_conn(
db: &Mutex<rusqlite::Connection>,
) -> Result<std::sync::MutexGuard<'_, rusqlite::Connection>> {
db.lock().map_err(|e| {
error!("Internal lock poisoned: {}", e);
DiskQueueError::LockPoisoned(e.to_string())
})
}
/// Value written to SQLite's `synchronous` pragma when a queue is opened.
///
/// The default is [`DurabilityLevel::Full`].
#[derive(Debug, Clone, Copy, Default)]
pub enum DurabilityLevel {
    Off,
    Normal,
    #[default]
    Full,
    Extra,
}

impl DurabilityLevel {
    /// Returns the upper-case keyword handed to the `synchronous` pragma.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Extra => "EXTRA",
            Self::Full => "FULL",
            Self::Normal => "NORMAL",
            Self::Off => "OFF",
        }
    }
}
/// SQL statements for one table, formatted once when the queue is opened so the
/// table name does not have to be spliced in on every call.
#[derive(Debug)]
struct CachedQueries {
// `INSERT INTO <table> (data) VALUES (?)`
insert_sql: String,
// Selects the single oldest row (ordered by `created_at`, then `id`).
select_sql: String,
// `DELETE FROM <table> WHERE id = ?`
delete_sql: String,
// `SELECT COUNT(*) FROM <table>`
count_sql: String,
// `DELETE FROM <table>` with no filter (removes every row).
clear_sql: String,
}
/// A queue of `T` values persisted in a SQLite table, with a second SQLite
/// database used as a dead letter queue for rows that fail to deserialize.
pub struct DiskBackedQueue<T> {
// Connection to the main queue database, shared with blocking worker tasks.
db: Arc<Mutex<rusqlite::Connection>>,
// Connection to the separate dead letter queue database.
dlq_db: Arc<Mutex<rusqlite::Connection>>,
// Pre-formatted SQL for the main table.
queries: CachedQueries,
// Pre-formatted INSERT for the `<table>_dlq` table.
dlq_insert_sql: String,
// Validated table name; still needed for log messages and ad-hoc SQL.
table_name: String,
// Optional upper bound on the number of queued rows; `None` means unbounded.
max_size: Option<usize>,
// Ties the queue to its item type without storing a `T`.
_phantom: PhantomData<T>,
}
impl<T> std::fmt::Debug for DiskBackedQueue<T> {
    /// Prints the table name, size limit and cached SQL; the connections are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DiskBackedQueue");
        builder.field("table_name", &self.table_name);
        builder.field("max_size", &self.max_size);
        builder.field("queries", &self.queries);
        builder.finish_non_exhaustive()
    }
}
impl<T> DiskBackedQueue<T>
where
T: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
{
/// Opens (or creates) a queue at `db_path` using the default [`DurabilityLevel`].
///
/// See `with_durability` for the validation and setup that is performed.
#[instrument(skip_all, fields(db_path = %db_path.as_ref().display(), table_name = %table_name))]
pub async fn new<P: AsRef<Path>>(
    db_path: P,
    table_name: String,
    max_size: Option<usize>,
) -> Result<Self> {
    let durability = DurabilityLevel::default();
    Self::with_durability(db_path, table_name, max_size, durability).await
}
/// Opens (creating if necessary) the queue database at `db_path` together with its
/// dead letter database, and prepares the table `table_name` in each.
///
/// The main connection is switched to WAL journaling, its `synchronous` pragma is set
/// from `durability`, foreign keys are enabled and a 5 second busy timeout is installed.
/// The dead letter database is opened at `db_path.with_extension("dlq.db")` and holds a
/// `{table_name}_dlq` table. `max_size` is stored and enforced later by the send methods.
///
/// NOTE(review): all SQLite work in this function runs inline on the calling task rather
/// than through `spawn_blocking`.
///
/// # Errors
///
/// * [`DiskQueueError::InvalidTableName`] if `table_name` is empty, longer than 128 bytes,
///   contains anything other than alphanumerics and `_`, begins with an ASCII digit, or
///   begins with the reserved `sqlite_` prefix.
/// * [`DiskQueueError::Database`] if opening or configuring either database fails.
#[instrument(skip_all, fields(db_path = %db_path.as_ref().display(), table_name = %table_name, durability = ?durability))]
pub async fn with_durability<P: AsRef<Path>>(
    db_path: P,
    table_name: String,
    max_size: Option<usize>,
    durability: DurabilityLevel,
) -> Result<Self> {
    // The name is spliced into SQL text with `format!`, so it has to be a safe bare
    // identifier before it gets anywhere near the database.
    if table_name.is_empty() || table_name.len() > 128 {
        return Err(DiskQueueError::InvalidTableName(table_name));
    }
    if !table_name.chars().all(|c| c.is_alphanumeric() || c == '_') {
        return Err(DiskQueueError::InvalidTableName(table_name));
    }
    // SQLite cannot parse an unquoted identifier that starts with a digit and refuses to
    // create tables whose name starts with `sqlite_`. Reject both here so the caller gets
    // `InvalidTableName` instead of an opaque database error from `CREATE TABLE`.
    let starts_with_digit = table_name.starts_with(|c: char| c.is_ascii_digit());
    let uses_reserved_prefix = table_name
        .get(..7)
        .is_some_and(|prefix| prefix.eq_ignore_ascii_case("sqlite_"));
    if starts_with_digit || uses_reserved_prefix {
        return Err(DiskQueueError::InvalidTableName(table_name));
    }

    // --- main queue database ---
    let conn = rusqlite::Connection::open(&db_path).map_err(|e| {
        error!(
            "Failed to open SQLite database at {}: {}",
            db_path.as_ref().display(),
            e
        );
        e
    })?;
    conn.pragma_update(None, "journal_mode", "WAL")
        .map_err(|e| {
            error!("Failed to enable WAL mode: {}", e);
            e
        })?;
    conn.pragma_update(None, "synchronous", durability.as_str())
        .map_err(|e| {
            error!(
                "Failed to set synchronous mode to {}: {}",
                durability.as_str(),
                e
            );
            e
        })?;
    debug!(
        "Configured SQLite with WAL mode and synchronous={}",
        durability.as_str()
    );
    conn.pragma_update(None, "foreign_keys", true)
        .map_err(|e| {
            error!("Failed to enable foreign keys: {}", e);
            e
        })?;
    conn.busy_timeout(Duration::from_secs(5)).map_err(|e| {
        error!("Failed to set busy timeout: {}", e);
        e
    })?;
    let create_table_sql = format!(
        "CREATE TABLE IF NOT EXISTS {table_name} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
data BLOB NOT NULL,
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
)"
    );
    conn.execute(&create_table_sql, []).map_err(|e| {
        error!("Failed to create table {}: {}", table_name, e);
        e
    })?;
    // Index matching the `ORDER BY created_at, id` used when dequeuing.
    let create_index_sql = format!(
        "CREATE INDEX IF NOT EXISTS idx_{table_name}_created_at ON {table_name} (created_at, id)"
    );
    conn.execute(&create_index_sql, []).map_err(|e| {
        error!("Failed to create index for table {}: {}", table_name, e);
        e
    })?;
    debug!(
        "Successfully initialized disk-backed queue table: {}",
        table_name
    );

    // --- dead letter queue database ---
    let dlq_path = db_path.as_ref().with_extension("dlq.db");
    let dlq_conn = rusqlite::Connection::open(&dlq_path).map_err(|e| {
        error!(
            "Failed to open DLQ database at {}: {}",
            dlq_path.display(),
            e
        );
        e
    })?;
    dlq_conn
        .pragma_update(None, "journal_mode", "WAL")
        .map_err(|e| {
            error!("Failed to enable WAL mode for DLQ: {}", e);
            e
        })?;
    dlq_conn
        .pragma_update(None, "synchronous", durability.as_str())
        .map_err(|e| {
            error!(
                "Failed to set synchronous mode for DLQ to {}: {}",
                durability.as_str(),
                e
            );
            e
        })?;
    dlq_conn.busy_timeout(Duration::from_secs(5)).map_err(|e| {
        error!("Failed to set busy timeout for DLQ: {}", e);
        e
    })?;
    let dlq_table_sql = format!(
        "CREATE TABLE IF NOT EXISTS {table_name}_dlq (
id INTEGER PRIMARY KEY AUTOINCREMENT,
original_id INTEGER,
data BLOB NOT NULL,
error_message TEXT NOT NULL,
created_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now')),
moved_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
)"
    );
    dlq_conn.execute(&dlq_table_sql, []).map_err(|e| {
        error!("Failed to create DLQ table {}_dlq: {}", table_name, e);
        e
    })?;
    debug!(
        "Successfully initialized dead letter queue for table: {}",
        table_name
    );

    // Format the per-table SQL once; every operation reuses these strings.
    let queries = CachedQueries {
        insert_sql: format!("INSERT INTO {table_name} (data) VALUES (?)"),
        select_sql: format!(
            "SELECT id, data FROM {table_name} ORDER BY created_at ASC, id ASC LIMIT 1"
        ),
        delete_sql: format!("DELETE FROM {table_name} WHERE id = ?"),
        count_sql: format!("SELECT COUNT(*) FROM {table_name}"),
        clear_sql: format!("DELETE FROM {table_name}"),
    };
    let dlq_insert_sql = format!(
        "INSERT INTO {table_name}_dlq (original_id, data, error_message) VALUES (?, ?, ?)"
    );
    Ok(Self {
        db: Arc::new(Mutex::new(conn)),
        dlq_db: Arc::new(Mutex::new(dlq_conn)),
        queries,
        dlq_insert_sql,
        table_name,
        max_size,
        _phantom: PhantomData,
    })
}
/// Serializes `item` with `postcard` and appends it to the queue table.
///
/// When `max_size` is set and the table already holds that many rows, the call does not
/// fail: the blocking worker sleeps and retries until space is available.
///
/// # Errors
///
/// Returns `Serialization` if encoding fails, `LockPoisoned` / `Database` /
/// `UnexpectedRowCount` for storage failures, and `TaskJoin` if the blocking task
/// does not complete.
#[instrument(skip_all, fields(table_name = %self.table_name))]
pub async fn send(&self, item: T) -> Result<()> {
// Encode before touching the database so a bad item never opens a transaction.
let serialized = postcard::to_allocvec(&item).map_err(|e| {
error!("Failed to serialize item for queue: {}", e);
DiskQueueError::Serialization(e)
})?;
// The blocking closure must be 'static, so it gets owned copies of what it needs.
let db = self.db.clone();
let insert_sql = self.queries.insert_sql.clone();
let count_sql = self.queries.count_sql.clone();
let max_size = self.max_size;
let table_name = self.table_name.clone();
tokio::task::spawn_blocking(move || {
// Retry delay while the queue is full: starts at 10ms, doubles, capped at 1s.
let mut backoff = Duration::from_millis(10);
loop {
let mut conn = lock_conn(&db)?;
// IMMEDIATE makes the size check and the insert one atomic write transaction.
let tx = conn
.transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
.map_err(|e| {
error!(
"Failed to start transaction for table {}: {}",
table_name, e
);
DiskQueueError::Database(e)
})?;
if let Some(max) = max_size {
let count: i64 = tx
.query_row(&count_sql, [], |row| row.get(0))
.map_err(DiskQueueError::Database)?;
if usize::try_from(count).unwrap_or(usize::MAX) >= max {
// Release the transaction and the mutex before sleeping so that
// receivers can drain the queue in the meantime.
drop(tx);
drop(conn);
warn!(
table_name = %table_name,
current_size = count,
max_size = max,
"Queue is full, waiting for space..."
);
std::thread::sleep(backoff);
backoff = std::cmp::min(backoff * 2, Duration::from_secs(1));
continue;
}
}
let rows_affected = tx.execute(&insert_sql, [&serialized]).map_err(|e| {
error!(
"Failed to insert item into queue table {}: {}",
table_name, e
);
DiskQueueError::Database(e)
})?;
// Returning here drops `tx` without a commit, so the insert is not persisted.
if rows_affected != 1 {
error!("Expected to insert 1 row, but inserted {}", rows_affected);
return Err(DiskQueueError::UnexpectedRowCount(format!(
"Insert affected {rows_affected} rows instead of 1"
)));
}
tx.commit().map_err(|e| {
error!(
"Failed to commit insert for table {}: {}",
table_name, e
);
DiskQueueError::Database(e)
})?;
debug!("Successfully enqueued item to disk queue");
return Ok(());
}
})
.await
// A join failure is reported as TaskJoin; the trailing `?` unwraps the outer layer
// and yields the closure's own Result as the return value.
.map_err(|e| DiskQueueError::TaskJoin(e.to_string()))?
}
/// Serializes every element of `items` and inserts them in a single transaction.
///
/// An empty batch is a no-op. When `max_size` is set, a batch larger than the limit is
/// rejected immediately; a batch that merely does not fit *right now* makes the blocking
/// worker sleep and retry until enough space is available.
///
/// # Errors
///
/// Returns `QueueFull` if the batch can never fit, `Serialization` if any item fails to
/// encode (nothing is written in that case), `LockPoisoned` / `Database` for storage
/// failures, and `TaskJoin` if the blocking task does not complete.
#[instrument(skip_all, fields(table_name = %self.table_name, batch_size = items.len()))]
pub async fn send_batch(&self, items: Vec<T>) -> Result<()> {
    if items.is_empty() {
        return Ok(());
    }
    // A batch bigger than the whole queue would wait forever, so fail fast instead.
    if let Some(max) = self.max_size
        && items.len() > max
    {
        return Err(DiskQueueError::QueueFull(max));
    }
    // Encode everything up front: one bad item aborts before any row is written.
    let mut serialized_items = Vec::with_capacity(items.len());
    for item in items {
        let serialized = postcard::to_allocvec(&item).map_err(|e| {
            error!("Failed to serialize item for batch queue: {}", e);
            DiskQueueError::Serialization(e)
        })?;
        serialized_items.push(serialized);
    }
    let db = self.db.clone();
    let insert_sql = self.queries.insert_sql.clone();
    let count_sql = self.queries.count_sql.clone();
    let max_size = self.max_size;
    let table_name = self.table_name.clone();
    let batch_size = serialized_items.len();
    tokio::task::spawn_blocking(move || {
        // Retry delay while the queue is too full: starts at 10ms, doubles, capped at 1s.
        let mut backoff = Duration::from_millis(10);
        loop {
            let mut conn = lock_conn(&db)?;
            let tx = conn
                .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
                .map_err(|e| {
                    error!(
                        "Failed to start transaction for batch insert on table {}: {}",
                        table_name, e
                    );
                    DiskQueueError::Database(e)
                })?;
            if let Some(max) = max_size {
                let count: i64 = tx
                    .query_row(&count_sql, [], |row| row.get(0))
                    .map_err(DiskQueueError::Database)?;
                // `saturating_add` keeps the `usize::MAX` fallback from overflowing
                // (a panic in debug builds, a wrap-around to "fits" in release builds).
                let projected = usize::try_from(count)
                    .unwrap_or(usize::MAX)
                    .saturating_add(batch_size);
                if projected > max {
                    // Release the transaction and the mutex before sleeping so that
                    // receivers can drain the queue in the meantime.
                    drop(tx);
                    drop(conn);
                    warn!(
                        table_name = %table_name,
                        current_size = count,
                        max_size = max,
                        batch_size = batch_size,
                        "Queue is full, waiting for space for batch..."
                    );
                    std::thread::sleep(backoff);
                    backoff = std::cmp::min(backoff * 2, Duration::from_secs(1));
                    continue;
                }
            }
            {
                // Scoped so the statement's borrow of `tx` ends before the commit.
                let mut stmt = tx.prepare_cached(&insert_sql).map_err(|e| {
                    error!(
                        "Failed to prepare INSERT for table {}: {}",
                        table_name, e
                    );
                    DiskQueueError::Database(e)
                })?;
                for serialized in &serialized_items {
                    stmt.execute([serialized]).map_err(|e| {
                        error!(
                            "Failed to insert item into queue table {} during batch: {}",
                            table_name, e
                        );
                        DiskQueueError::Database(e)
                    })?;
                }
            }
            tx.commit().map_err(|e| {
                error!(
                    "Failed to commit batch transaction for table {}: {}",
                    table_name, e
                );
                DiskQueueError::Database(e)
            })?;
            debug!(
                "Successfully enqueued {} items to disk queue in batch",
                batch_size
            );
            return Ok(());
        }
    })
    .await
    .map_err(|e| DiskQueueError::TaskJoin(e.to_string()))?
}
/// Removes and returns the oldest item, or `Ok(None)` when the queue is empty.
///
/// "Oldest" is the first row ordered by `created_at`, then `id`. The row is read and
/// deleted inside one transaction.
///
/// If the stored payload cannot be decoded it is copied into the dead letter queue,
/// removed from the main table, and the call returns `Deserialization`; the caller can
/// simply call `recv` again for the next item. If the copy into the dead letter queue
/// fails, the row stays in the main table and `DlqWriteFailed` is returned.
///
/// # Errors
///
/// `Deserialization`, `DlqWriteFailed`, `LockPoisoned`, `Database`,
/// `UnexpectedRowCount`, or `TaskJoin` if the blocking task does not complete.
#[instrument(skip_all, fields(table_name = %self.table_name))]
pub async fn recv(&self) -> Result<Option<T>> {
    let db = self.db.clone();
    let dlq_db = self.dlq_db.clone();
    let select_sql = self.queries.select_sql.clone();
    let delete_sql = self.queries.delete_sql.clone();
    let dlq_insert_sql = self.dlq_insert_sql.clone();
    let table_name = self.table_name.clone();
    tokio::task::spawn_blocking(move || {
        let mut conn = lock_conn(&db)?;
        // This is a read-then-delete, so take the write lock up front (as `send` does).
        // A deferred transaction would start as a reader and could fail with SQLITE_BUSY
        // when upgrading to a writer if another connection wrote in between, and two
        // connections could both select the same head row.
        let tx = conn
            .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
            .map_err(|e| {
                error!(
                    "Failed to start transaction for table {}: {}",
                    table_name, e
                );
                DiskQueueError::Database(e)
            })?;
        let result: Option<(i64, Vec<u8>)> = tx
            .query_row(&select_sql, [], |row| Ok((row.get(0)?, row.get(1)?)))
            .optional()
            .map_err(|e| {
                error!(
                    "Failed to execute SELECT query on table {}: {}",
                    table_name, e
                );
                DiskQueueError::Database(e)
            })?;
        if let Some((id, data)) = result {
            let item: T = match postcard::from_bytes::<T>(&data) {
                Ok(item) => item,
                Err(e) => {
                    error!(
                        "Failed to deserialize item from queue (id {}): {}. Moving to DLQ.",
                        id, e
                    );
                    // Copy the raw bytes to the DLQ first; the main row is only deleted
                    // once that copy succeeded, so a corrupt payload is never lost.
                    let dlq_result = match dlq_db.lock() {
                        Ok(dlq_conn) => dlq_conn
                            .execute(
                                &dlq_insert_sql,
                                rusqlite::params![id, &data, e.to_string()],
                            )
                            .map_err(|dlq_err| dlq_err.to_string()),
                        Err(poison_err) => Err(format!("DLQ mutex poisoned: {poison_err}")),
                    };
                    if let Err(dlq_msg) = dlq_result {
                        error!(
                            "Failed to write corrupt item {} from table {} to DLQ: {}. Item left in main queue.",
                            id, table_name, dlq_msg
                        );
                        // Dropping `tx` without a commit leaves the main table untouched.
                        return Err(DiskQueueError::DlqWriteFailed(dlq_msg));
                    }
                    tx.execute(&delete_sql, [&id]).map_err(|err| {
                        error!(
                            "Failed to delete item {} from table {}: {}",
                            id, table_name, err
                        );
                        DiskQueueError::Database(err)
                    })?;
                    tx.commit().map_err(|err| {
                        error!(
                            "Failed to commit transaction after DLQ move for table {}: {}",
                            table_name, err
                        );
                        DiskQueueError::Database(err)
                    })?;
                    // The corrupt row is gone from the main queue, but the caller still
                    // has to learn that this call produced no item.
                    return Err(DiskQueueError::Deserialization(e));
                }
            };
            let rows_deleted = tx.execute(&delete_sql, [&id]).map_err(|e| {
                error!(
                    "Failed to delete item {} from table {}: {}",
                    id, table_name, e
                );
                DiskQueueError::Database(e)
            })?;
            if rows_deleted != 1 {
                error!(
                    "Expected to delete 1 row, but deleted {} rows for id {}",
                    rows_deleted, id
                );
                return Err(DiskQueueError::UnexpectedRowCount(format!(
                    "Delete affected {rows_deleted} rows instead of 1 for id {id}"
                )));
            }
            tx.commit().map_err(|e| {
                error!(
                    "Failed to commit transaction for table {}: {}",
                    table_name, e
                );
                DiskQueueError::Database(e)
            })?;
            debug!("Successfully dequeued item from disk queue");
            Ok(Some(item))
        } else {
            Ok(None)
        }
    })
    .await
    .map_err(|e| DiskQueueError::TaskJoin(e.to_string()))?
}
#[instrument(skip_all, fields(table_name = %self.table_name, limit = limit))]
pub async fn recv_batch(&self, limit: usize) -> Result<Vec<T>> {
if limit == 0 {
return Ok(Vec::new());
}
let db = self.db.clone();
let dlq_db = self.dlq_db.clone();
let table_name = self.table_name.clone();
let dlq_insert_sql = self.dlq_insert_sql.clone();
tokio::task::spawn_blocking(move || {
let mut conn = lock_conn(&db)?;
let tx = conn.transaction().map_err(|e| {
error!("Failed to start transaction for table {}: {}", table_name, e);
DiskQueueError::Database(e)
})?;
let select_batch_sql = format!(
"SELECT id, data FROM {table_name} ORDER BY created_at ASC, id ASC LIMIT ?"
);
let mut stmt = tx.prepare_cached(&select_batch_sql).map_err(|e| {
error!("Failed to prepare SELECT statement for table {}: {}", table_name, e);
DiskQueueError::Database(e)
})?;
let limit_param = i64::try_from(limit).unwrap_or(i64::MAX);
let rows = stmt
.query_map([limit_param], |row| {
Ok((row.get::<_, i64>(0)?, row.get::<_, Vec<u8>>(1)?))
})
.map_err(|e| {
error!("Failed to execute SELECT query on table {}: {}", table_name, e);
DiskQueueError::Database(e)
})?;
let mut items = Vec::new();
let mut ids_to_delete = Vec::new();
let mut dlq_failures: usize = 0;
for row_result in rows {
let (id, data) = row_result.map_err(|e| {
error!("Failed to read row from table {}: {}", table_name, e);
DiskQueueError::Database(e)
})?;
match postcard::from_bytes::<T>(&data) {
Ok(item) => {
items.push(item);
ids_to_delete.push(id);
}
Err(e) => {
error!(
"Failed to deserialize item from queue (id {}): {}. Moving to DLQ.",
id, e
);
let dlq_ok = match dlq_db.lock() {
Ok(dlq_conn) => match dlq_conn.execute(
&dlq_insert_sql,
rusqlite::params![id, &data, e.to_string()],
) {
Ok(_) => true,
Err(dlq_err) => {
error!(
"Failed to insert corrupt item {} into DLQ for table {}: {}. Leaving item in main queue.",
id, table_name, dlq_err
);
false
}
},
Err(poison_err) => {
error!(
"DLQ mutex poisoned while handling corrupt item {}: {}. Leaving item in main queue.",
id, poison_err
);
false
}
};
if dlq_ok {
ids_to_delete.push(id);
} else {
dlq_failures += 1;
}
}
}
}
drop(stmt);
if !ids_to_delete.is_empty() {
const BATCH_DELETE_LIMIT: usize = 900;
for chunk in ids_to_delete.chunks(BATCH_DELETE_LIMIT) {
let delete_batch_sql = format!(
"DELETE FROM {} WHERE id IN ({})",
table_name,
chunk
.iter()
.map(|_| "?")
.collect::<Vec<_>>()
.join(",")
);
let params: Vec<&dyn rusqlite::ToSql> = chunk
.iter()
.map(|id| id as &dyn rusqlite::ToSql)
.collect();
let rows_deleted = tx
.execute(&delete_batch_sql, params.as_slice())
.map_err(|e| {
error!("Failed to delete items from table {}: {}", table_name, e);
DiskQueueError::Database(e)
})?;
if rows_deleted != chunk.len() {
error!(
"Expected to delete {} rows, but deleted {} rows",
chunk.len(),
rows_deleted
);
return Err(DiskQueueError::UnexpectedRowCount(format!(
"Delete affected {rows_deleted} rows instead of {}",
chunk.len()
)));
}
}
}
tx.commit().map_err(|e| {
error!("Failed to commit batch transaction for table {}: {}", table_name, e);
DiskQueueError::Database(e)
})?;
if dlq_failures > 0 && items.is_empty() {
return Err(DiskQueueError::DlqWriteFailed(format!(
"{dlq_failures} corrupt item(s) could not be moved to DLQ"
)));
}
if dlq_failures > 0 {
warn!(
table_name = %table_name,
dlq_failures = dlq_failures,
"Some corrupt items could not be moved to DLQ; they remain in the main queue"
);
}
debug!("Successfully dequeued {} items from disk queue in batch", items.len());
Ok(items)
})
.await
.map_err(|e| DiskQueueError::TaskJoin(e.to_string()))?
}
/// Returns the number of rows currently stored in the queue table.
///
/// # Errors
///
/// `LockPoisoned`, `Database`, or `TaskJoin` if the blocking task does not complete.
#[instrument(skip_all, fields(table_name = %self.table_name))]
pub async fn len(&self) -> Result<usize> {
    let connection = self.db.clone();
    let sql = self.queries.count_sql.clone();
    let joined = tokio::task::spawn_blocking(move || -> Result<usize> {
        let guard = lock_conn(&connection)?;
        let rows: i64 = guard
            .query_row(&sql, [], |row| row.get(0))
            .map_err(DiskQueueError::Database)?;
        // COUNT(*) is an i64; clamp rather than fail if it does not fit in usize.
        Ok(usize::try_from(rows).unwrap_or(usize::MAX))
    })
    .await;
    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(DiskQueueError::TaskJoin(join_err.to_string())),
    }
}
/// Returns `true` when the queue table holds no rows.
///
/// Uses an `EXISTS` probe rather than counting every row.
///
/// # Errors
///
/// `LockPoisoned`, `Database`, or `TaskJoin` if the blocking task does not complete.
#[instrument(skip_all, fields(table_name = %self.table_name))]
pub async fn is_empty(&self) -> Result<bool> {
    let connection = self.db.clone();
    let probe_sql = format!(
        "SELECT EXISTS(SELECT 1 FROM {} LIMIT 1)",
        self.table_name
    );
    let joined = tokio::task::spawn_blocking(move || -> Result<bool> {
        let guard = lock_conn(&connection)?;
        let has_row: i64 = guard
            .query_row(&probe_sql, [], |row| row.get(0))
            .map_err(DiskQueueError::Database)?;
        Ok(has_row == 0)
    })
    .await;
    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(DiskQueueError::TaskJoin(join_err.to_string())),
    }
}
/// Deletes every row from the queue table.
///
/// # Errors
///
/// `LockPoisoned`, `Database`, or `TaskJoin` if the blocking task does not complete.
#[instrument(skip_all, fields(table_name = %self.table_name))]
pub async fn clear(&self) -> Result<()> {
    let connection = self.db.clone();
    let sql = self.queries.clear_sql.clone();
    let joined = tokio::task::spawn_blocking(move || -> Result<()> {
        let guard = lock_conn(&connection)?;
        guard.execute(&sql, []).map_err(DiskQueueError::Database)?;
        debug!("Cleared disk queue");
        Ok(())
    })
    .await;
    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(DiskQueueError::TaskJoin(join_err.to_string())),
    }
}
}
/// Sending half of a disk-backed channel; holds a shared handle to the queue.
#[derive(Debug)]
pub struct DiskBackedSender<T> {
// Shared with the matching receiver and with any clones of this sender.
queue: std::sync::Arc<DiskBackedQueue<T>>,
}
impl<T> DiskBackedSender<T>
where
T: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
{
pub async fn send(&self, item: T) -> Result<()> {
self.queue.send(item).await
}
pub async fn send_batch(&self, items: Vec<T>) -> Result<()> {
self.queue.send_batch(items).await
}
pub fn blocking_send(&self, item: T) -> Result<()> {
match tokio::runtime::Handle::try_current() {
Ok(handle) if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => {
tokio::task::block_in_place(|| handle.block_on(self.send(item)))
}
_ => {
let runtime = tokio::runtime::Runtime::new().map_err(|e| {
error!("Failed to create Tokio runtime: {}", e);
DiskQueueError::TaskJoin(format!("Runtime creation failed: {e}"))
})?;
runtime.block_on(self.send(item))
}
}
}
}
impl<T> Clone for DiskBackedSender<T> {
    /// Produces another sender for the same queue; only the `Arc` is cloned, so no
    /// `T: Clone` bound is needed.
    fn clone(&self) -> Self {
        let queue = std::sync::Arc::clone(&self.queue);
        Self { queue }
    }
}
/// Receiving half of a disk-backed channel; holds a shared handle to the queue.
#[derive(Debug)]
pub struct DiskBackedReceiver<T> {
// Shared with the sender(s) created alongside this receiver.
queue: std::sync::Arc<DiskBackedQueue<T>>,
}
impl<T> DiskBackedReceiver<T>
where
    T: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
{
    /// Dequeues the oldest item, or `Ok(None)` if the queue is empty.
    pub async fn recv(&mut self) -> Result<Option<T>> {
        let queue = &self.queue;
        queue.recv().await
    }

    /// Dequeues up to `limit` items, oldest first.
    pub async fn recv_batch(&mut self, limit: usize) -> Result<Vec<T>> {
        let queue = &self.queue;
        queue.recv_batch(limit).await
    }

    /// Number of items currently stored in the queue.
    pub async fn len(&self) -> Result<usize> {
        let queue = &self.queue;
        queue.len().await
    }

    /// Whether the queue currently holds no items.
    pub async fn is_empty(&self) -> Result<bool> {
        let queue = &self.queue;
        queue.is_empty().await
    }
}
/// Creates a sender/receiver pair backed by one queue at `db_path`, using the default
/// [`DurabilityLevel`].
pub async fn disk_backed_channel<T, P: AsRef<Path>>(
    db_path: P,
    table_name: String,
    max_size: Option<usize>,
) -> Result<(DiskBackedSender<T>, DiskBackedReceiver<T>)>
where
    T: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
{
    let durability = DurabilityLevel::default();
    disk_backed_channel_with_durability(db_path, table_name, max_size, durability).await
}
/// Creates a sender/receiver pair backed by one queue at `db_path`, opened with the
/// given `durability`. Both halves share the same underlying [`DiskBackedQueue`].
pub async fn disk_backed_channel_with_durability<T, P: AsRef<Path>>(
    db_path: P,
    table_name: String,
    max_size: Option<usize>,
    durability: DurabilityLevel,
) -> Result<(DiskBackedSender<T>, DiskBackedReceiver<T>)>
where
    T: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
{
    let shared = std::sync::Arc::new(
        DiskBackedQueue::with_durability(db_path, table_name, max_size, durability).await?,
    );
    Ok((
        DiskBackedSender {
            queue: std::sync::Arc::clone(&shared),
        },
        DiskBackedReceiver { queue: shared },
    ))
}
#[cfg(test)]
mod tests {
use super::*;
use serde::{Deserialize, Serialize};
use tempfile::NamedTempFile;
/// Minimal serializable payload used as the queue item type throughout these tests.
#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
struct TestMessage {
// Numeric tag the tests use to check ordering.
id: u64,
// Free-form body; also used to build large payloads.
content: String,
}
/// Round-trips two messages and checks `len` / `is_empty` / `recv` at every step.
#[tokio::test]
async fn test_disk_backed_queue_basic() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "test_queue".to_string(), None)
        .await
        .unwrap();

    // A freshly created queue is empty.
    assert!(queue.is_empty().await.unwrap());
    assert_eq!(queue.len().await.unwrap(), 0);
    assert!(queue.recv().await.unwrap().is_none());

    let first = TestMessage {
        id: 1,
        content: "Hello".to_string(),
    };
    let second = TestMessage {
        id: 2,
        content: "World".to_string(),
    };
    for outgoing in [&first, &second] {
        queue.send(outgoing.clone()).await.unwrap();
    }
    assert!(!queue.is_empty().await.unwrap());
    assert_eq!(queue.len().await.unwrap(), 2);

    // Items come back in send order, and the length shrinks by one each time.
    for (expected, remaining) in [(&first, 1), (&second, 0)] {
        let got = queue.recv().await.unwrap().unwrap();
        assert_eq!(got, *expected);
        assert_eq!(queue.len().await.unwrap(), remaining);
    }

    assert!(queue.is_empty().await.unwrap());
    assert!(queue.recv().await.unwrap().is_none());
}
/// A message sent through the channel API arrives unchanged at the receiver.
#[tokio::test]
async fn test_disk_backed_channel() {
    let db_file = NamedTempFile::new().unwrap();
    let channel = disk_backed_channel::<TestMessage, _>(
        db_file.path(),
        "test_channel".to_string(),
        None,
    )
    .await;
    let (sender, mut receiver) = channel.unwrap();

    let outgoing = TestMessage {
        id: 42,
        content: "Channel test".to_string(),
    };
    sender.send(outgoing.clone()).await.unwrap();

    let incoming = receiver.recv().await.unwrap().unwrap();
    assert_eq!(incoming, outgoing);
}
/// An item written by one queue instance is visible to a second instance opened on the
/// same file after the first one is dropped.
#[tokio::test]
async fn test_persistence() {
    let db_file = NamedTempFile::new().unwrap();
    let db_path = db_file.path().to_path_buf();
    let stored = TestMessage {
        id: 99,
        content: "Persistent message".to_string(),
    };

    // First instance: write, then drop at the end of the scope.
    {
        let writer = DiskBackedQueue::new(&db_path, "persistence_test".to_string(), None)
            .await
            .unwrap();
        writer.send(stored.clone()).await.unwrap();
    }

    // Second instance: the item must still be there.
    {
        let reader: DiskBackedQueue<TestMessage> =
            DiskBackedQueue::new(&db_path, "persistence_test".to_string(), None)
                .await
                .unwrap();
        assert_eq!(reader.len().await.unwrap(), 1);
        let restored = reader.recv().await.unwrap().unwrap();
        assert_eq!(restored, stored);
    }
}
/// Overwriting a stored payload with garbage makes `recv` report a deserialization error.
#[tokio::test]
async fn test_error_handling_database_corruption() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "test_queue".to_string(), None)
        .await
        .unwrap();
    queue
        .send(TestMessage {
            id: 1,
            content: "Test".to_string(),
        })
        .await
        .unwrap();

    // Corrupt the row behind the queue's back; the guard is released at scope end.
    {
        let garbage_data: Vec<u8> = vec![0xFF, 0xFE, 0xFD, 0xFC];
        let raw_conn = queue.db.lock().unwrap();
        raw_conn
            .execute(
                "UPDATE test_queue SET data = ? WHERE id = 1",
                [&garbage_data],
            )
            .unwrap();
    }

    let result = queue.recv().await;
    assert!(result.is_err(), "Expected error, got: {:?}", result);
    assert!(
        matches!(result, Err(DiskQueueError::Deserialization(_))),
        "Expected Deserialization error, got: {:?}",
        result
    );
}
/// Ten sender tasks enqueue ten messages each while five receiver tasks drain the
/// queue; the test passes once all 100 messages have been collected.
#[tokio::test]
async fn test_concurrent_access() {
let temp_file = NamedTempFile::new().unwrap();
let queue = std::sync::Arc::new(
DiskBackedQueue::new(temp_file.path(), "concurrent_test".to_string(), None)
.await
.unwrap(),
);
// Producers: ids are i * 10 + j, so every message id in 0..100 is sent exactly once.
let mut send_handles = vec![];
for i in 0..10 {
let queue_clone = queue.clone();
let handle = tokio::spawn(async move {
for j in 0..10 {
let msg = TestMessage {
id: i * 10 + j,
content: format!("Message from sender {i}, iteration {j}"),
};
queue_clone.send(msg).await.unwrap();
}
});
send_handles.push(handle);
}
// Consumers: poll the queue, backing off briefly whenever it is empty.
let mut recv_handles = vec![];
let received_messages = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
for _ in 0..5 {
let queue_clone = queue.clone();
let messages_clone = received_messages.clone();
let handle = tokio::spawn(async move {
loop {
match queue_clone.recv().await {
Ok(Some(msg)) => {
messages_clone.lock().await.push(msg);
}
Ok(None) => {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
// NOTE(review): a receive error silently ends this consumer; the timeout
// below is what turns that into a test failure.
Err(_) => break,
}
if messages_clone.lock().await.len() >= 100 {
break;
}
}
});
recv_handles.push(handle);
}
for handle in send_handles {
handle.await.unwrap();
}
// Wait (up to 10s) for the consumers to collect everything.
let start = tokio::time::Instant::now();
let timeout = tokio::time::Duration::from_secs(10);
loop {
let count = received_messages.lock().await.len();
if count >= 100 {
break;
}
if start.elapsed() > timeout {
panic!(
"Timeout: Only received {} messages after {:?}",
count, timeout
);
}
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
let messages = received_messages.lock().await;
assert_eq!(messages.len(), 100);
// Consumers may still be sleeping in their poll loop; stop them explicitly.
for handle in recv_handles {
handle.abort();
}
}
/// Repeated receives on an empty queue keep returning `None` and leave it empty.
#[tokio::test]
async fn test_empty_queue_operations() {
    let db_file = NamedTempFile::new().unwrap();
    let queue: DiskBackedQueue<TestMessage> =
        DiskBackedQueue::new(db_file.path(), "empty_test".to_string(), None)
            .await
            .unwrap();

    let mut attempt = 0;
    while attempt < 5 {
        let received = queue.recv().await.unwrap();
        assert!(received.is_none());
        attempt += 1;
    }

    assert!(queue.is_empty().await.unwrap());
    assert_eq!(queue.len().await.unwrap(), 0);
}
/// A 1 MiB payload survives a round trip through the queue intact.
#[tokio::test]
async fn test_large_messages() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "large_test".to_string(), None)
        .await
        .unwrap();

    let payload = "x".repeat(1024 * 1024);
    let outgoing = TestMessage {
        id: 42,
        content: payload.clone(),
    };
    queue.send(outgoing.clone()).await.unwrap();

    let incoming = queue.recv().await.unwrap().unwrap();
    assert_eq!(incoming.id, 42);
    assert_eq!(incoming.content.len(), 1024 * 1024);
    assert_eq!(incoming.content, payload);
}
/// One hundred messages sent one by one are received in exactly the same order.
#[tokio::test]
async fn test_fifo_ordering() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "fifo_test".to_string(), None)
        .await
        .unwrap();

    let mut messages: Vec<TestMessage> = Vec::new();
    for i in 0..100 {
        messages.push(TestMessage {
            id: i,
            content: format!("Message {i}"),
        });
    }

    for outgoing in messages.iter() {
        queue.send(outgoing.clone()).await.unwrap();
    }
    for expected in messages.iter() {
        let incoming = queue.recv().await.unwrap().unwrap();
        assert_eq!(incoming, *expected);
    }

    assert!(queue.recv().await.unwrap().is_none());
}
#[tokio::test]
async fn test_invalid_table_name() {
let temp_file = NamedTempFile::new().unwrap();
let _result = DiskBackedQueue::<TestMessage>::new(
temp_file.path(),
"invalid-table-name-with-dashes".to_string(),
None,
)
.await;
let result2 = DiskBackedQueue::<TestMessage>::new(
temp_file.path(),
"".to_string(), None,
)
.await;
assert!(result2.is_err());
}
/// `blocking_send` called from inside a multi-threaded runtime delivers the item.
#[tokio::test(flavor = "multi_thread")]
async fn test_blocking_send() {
    let db_file = NamedTempFile::new().unwrap();
    let channel = disk_backed_channel::<TestMessage, _>(
        db_file.path(),
        "blocking_test".to_string(),
        None,
    )
    .await;
    let (sender, mut receiver) = channel.unwrap();

    let outgoing = TestMessage {
        id: 123,
        content: "Blocking test".to_string(),
    };
    sender.blocking_send(outgoing.clone()).unwrap();

    let incoming = receiver.recv().await.unwrap().unwrap();
    assert_eq!(incoming, outgoing);
}
/// Opening a queue leaves a non-empty regular file at the database path.
#[tokio::test]
async fn test_database_file_permissions() {
    let db_file = NamedTempFile::new().unwrap();
    let db_path = db_file.path().to_path_buf();
    // Keep the queue alive while the file is inspected.
    let _queue = DiskBackedQueue::<TestMessage>::new(&db_path, "perm_test".to_string(), None)
        .await
        .unwrap();

    assert!(db_path.exists());
    let file_info = std::fs::metadata(&db_path).unwrap();
    assert!(file_info.is_file());
    assert!(file_info.len() > 0);
}
/// Creating a channel also creates the sibling `*.dlq.db` dead letter database file.
#[tokio::test]
async fn test_dlq_file_created() {
    let db_file = NamedTempFile::new().unwrap();
    let db_path = db_file.path();

    // The channel halves are not needed; only the side effect on disk is checked.
    let channel = disk_backed_channel::<TestMessage, _>(db_path, "test".to_string(), None)
        .await
        .unwrap();
    drop(channel);

    let expected_dlq = db_path.with_extension("dlq.db");
    assert!(expected_dlq.exists());
}
/// A corrupt row is moved to the dead letter queue: `recv` reports the decoding error,
/// the DLQ table gains one row, and the main queue ends up empty.
#[tokio::test]
async fn test_transaction_rollback_on_corruption() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "test_queue".to_string(), None)
        .await
        .unwrap();
    queue
        .send(TestMessage {
            id: 1,
            content: "Valid".to_string(),
        })
        .await
        .unwrap();

    // Corrupt the row behind the queue's back; the guard is released at scope end.
    {
        let garbage_data: Vec<u8> = vec![0xFF, 0xFE, 0xFD, 0xFC];
        let raw_conn = queue.db.lock().unwrap();
        raw_conn
            .execute(
                "UPDATE test_queue SET data = ? WHERE id = 1",
                [&garbage_data],
            )
            .unwrap();
    }

    let outcome = queue.recv().await;
    assert!(outcome.is_err());
    assert!(matches!(outcome, Err(DiskQueueError::Deserialization(_))));

    // Inspect the DLQ through an independent connection.
    let dlq_file = db_file.path().with_extension("dlq.db");
    let inspector = rusqlite::Connection::open(&dlq_file).unwrap();
    let dlq_rows: i64 = inspector
        .query_row("SELECT COUNT(*) FROM test_queue_dlq", [], |row| row.get(0))
        .unwrap();
    assert_eq!(dlq_rows, 1);
    assert_eq!(queue.len().await.unwrap(), 0);
}
/// With `max_size = 2`, a third `send` stays pending until a receive frees a slot.
///
/// NOTE(review): relies on fixed 200ms sleeps, so it is timing-sensitive.
#[tokio::test]
async fn test_max_size_blocks() {
let temp_file = NamedTempFile::new().unwrap();
let (tx, mut rx) = disk_backed_channel::<TestMessage, _>(
temp_file.path(),
"test".to_string(),
Some(2), )
.await
.unwrap();
let msg1 = TestMessage {
id: 1,
content: "Message 1".to_string(),
};
let msg2 = TestMessage {
id: 2,
content: "Message 2".to_string(),
};
let msg3 = TestMessage {
id: 3,
content: "Message 3".to_string(),
};
// Fill the queue to its limit.
tx.send(msg1.clone()).await.unwrap();
tx.send(msg2.clone()).await.unwrap();
// The third send runs in its own task so the test can observe it while it waits.
let tx_clone = tx.clone();
let handle = tokio::spawn(async move { tx_clone.send(msg3).await });
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
assert!(!handle.is_finished());
// Freeing one slot lets the pending send complete on one of its next retries.
let received = rx.recv().await.unwrap().unwrap();
assert_eq!(received, msg1);
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
assert!(handle.is_finished());
handle.await.unwrap().unwrap();
assert_eq!(rx.len().await.unwrap(), 2);
}
/// A batch of 100 messages is stored completely and read back in order.
#[tokio::test]
async fn test_send_batch() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "batch_test".to_string(), None)
        .await
        .unwrap();

    let mut messages: Vec<TestMessage> = Vec::new();
    for i in 0..100 {
        messages.push(TestMessage {
            id: i,
            content: format!("Batch message {}", i),
        });
    }

    queue.send_batch(messages.clone()).await.unwrap();
    assert_eq!(queue.len().await.unwrap(), 100);

    for expected in messages.iter() {
        let incoming = queue.recv().await.unwrap().unwrap();
        assert_eq!(incoming, *expected);
    }
    assert!(queue.is_empty().await.unwrap());
}
/// `recv_batch` returns at most `limit` items, in order, and an empty vector once the
/// queue has been drained.
#[tokio::test]
async fn test_recv_batch() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "batch_recv_test".to_string(), None)
        .await
        .unwrap();

    let mut next_id = 0;
    while next_id < 100 {
        queue
            .send(TestMessage {
                id: next_id,
                content: format!("Message {}", next_id),
            })
            .await
            .unwrap();
        next_id += 1;
    }

    let first_quarter = queue.recv_batch(25).await.unwrap();
    assert_eq!(first_quarter.len(), 25);
    assert_eq!(first_quarter[0].id, 0);
    assert_eq!(first_quarter[24].id, 24);

    let second_quarter = queue.recv_batch(25).await.unwrap();
    assert_eq!(second_quarter.len(), 25);
    assert_eq!(second_quarter[0].id, 25);
    assert_eq!(second_quarter[24].id, 49);

    // Asking for more than is left returns just the remainder.
    let remainder = queue.recv_batch(100).await.unwrap();
    assert_eq!(remainder.len(), 50);
    assert_eq!(remainder[0].id, 50);
    assert_eq!(remainder[49].id, 99);

    assert!(queue.is_empty().await.unwrap());
    let nothing_left = queue.recv_batch(10).await.unwrap();
    assert!(nothing_left.is_empty());
}
/// Batch send and batch receive work through the sender/receiver wrappers.
#[tokio::test]
async fn test_batch_with_channel_api() {
    let db_file = NamedTempFile::new().unwrap();
    let channel = disk_backed_channel::<TestMessage, _>(
        db_file.path(),
        "batch_channel_test".to_string(),
        None,
    )
    .await;
    let (tx, mut rx) = channel.unwrap();

    let mut messages: Vec<TestMessage> = Vec::new();
    for i in 0..50 {
        messages.push(TestMessage {
            id: i,
            content: format!("Batch {}", i),
        });
    }

    tx.send_batch(messages.clone()).await.unwrap();

    let received = rx.recv_batch(50).await.unwrap();
    assert_eq!(received.len(), 50);
    assert_eq!(received, messages);
}
/// Compares the wall-clock time of sending 1000 messages one `send` at a time
/// against a single `send_batch` call, and requires the batch path to take
/// less than one fifth of the per-message time.
///
/// Only queue writes are timed: the input vectors are prepared before each
/// timer starts, so cloning cost does not skew either measurement.
#[tokio::test]
async fn test_batch_performance_comparison() {
    let temp_file1 = NamedTempFile::new().unwrap();
    let temp_file2 = NamedTempFile::new().unwrap();
    let queue_single = DiskBackedQueue::new(temp_file1.path(), "single".to_string(), None)
        .await
        .unwrap();
    let queue_batch = DiskBackedQueue::new(temp_file2.path(), "batch".to_string(), None)
        .await
        .unwrap();

    let message_count = 1000;
    let messages: Vec<TestMessage> = (0..message_count)
        .map(|i| TestMessage {
            id: i,
            content: format!("Perf test {}", i),
        })
        .collect();
    let expected_len = messages.len();

    // Clone before starting the clock so the copy is not billed to `send`.
    let single_input = messages.clone();
    let start = std::time::Instant::now();
    for msg in single_input {
        queue_single.send(msg).await.unwrap();
    }
    let single_duration = start.elapsed();

    // `messages` is not needed afterwards, so it is moved instead of cloned.
    let start = std::time::Instant::now();
    queue_batch.send_batch(messages).await.unwrap();
    let batch_duration = start.elapsed();

    // Make sure both timed paths actually stored every message.
    assert_eq!(queue_single.len().await.unwrap(), expected_len);
    assert_eq!(queue_batch.len().await.unwrap(), expected_len);

    println!(
        "Single send: {:?}, Batch send: {:?}, Speedup: {:.2}x",
        single_duration,
        batch_duration,
        single_duration.as_secs_f64() / batch_duration.as_secs_f64()
    );
    assert!(
        batch_duration < single_duration / 5,
        "batch send ({batch_duration:?}) was not at least 5x faster than single sends ({single_duration:?})"
    );
}
/// Stores 10 valid messages, overwrites one stored payload with bytes that are
/// not a valid encoding, and expects `recv_batch` to return the other 9 while
/// exactly one row shows up in the dead-letter table of the `.dlq.db` file.
#[tokio::test]
async fn test_batch_with_corrupted_data() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "batch_corrupt_test".to_string(), None)
        .await
        .unwrap();

    for n in 0..10 {
        queue
            .send(TestMessage {
                id: n,
                content: format!("Valid {}", n),
            })
            .await
            .unwrap();
    }

    // Corrupt a single row directly through the underlying connection. The
    // guard lives only inside this block so it is released before the next await.
    {
        let conn = queue.db.lock().unwrap();
        let garbage_data: Vec<u8> = vec![0xFF, 0xFE, 0xFD, 0xFC];
        conn.execute(
            "UPDATE batch_corrupt_test SET data = ? WHERE id = 6",
            [&garbage_data],
        )
        .unwrap();
    }

    let survivors = queue.recv_batch(10).await.unwrap();
    assert_eq!(survivors.len(), 9);

    // The dead-letter database sits next to the main file with a `dlq.db` extension.
    let dlq_conn =
        rusqlite::Connection::open(db_file.path().with_extension("dlq.db")).unwrap();
    let dead_letters: i64 = dlq_conn
        .query_row("SELECT COUNT(*) FROM batch_corrupt_test_dlq", [], |row| {
            row.get(0)
        })
        .unwrap();
    assert_eq!(dead_letters, 1);
}
/// Stress test for the queue's size limit under concurrent use.
///
/// Runs on a 4-worker multi-thread runtime with three kinds of tasks sharing
/// one queue created with `Some(MAX)`:
/// - a sampler that repeatedly reads `len()` and asserts it never exceeds `MAX`,
/// - `SENDERS` tasks that each send `PER_SENDER` messages,
/// - a drainer that receives until every sent message has been consumed.
///
/// The test passes when all tasks finish without panicking and the queue ends empty.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_max_size_never_exceeded_under_contention() {
// Size limit handed to `DiskBackedQueue::new`; the sampler checks `len()` against it.
const MAX: usize = 10;
// Number of concurrent sender tasks.
const SENDERS: usize = 50;
// Messages sent by each sender task (SENDERS * PER_SENDER = 250 in total, far above MAX).
const PER_SENDER: usize = 5;
let temp_file = NamedTempFile::new().unwrap();
let queue = std::sync::Arc::new(
DiskBackedQueue::<TestMessage>::new(
temp_file.path(),
"race_test".to_string(),
Some(MAX),
)
.await
.unwrap(),
);
// Flag used to tell the sampler loop to exit once senders and drainer are done.
let stop = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
// Sampler: polls the queue length every 2 ms until `stop` is set and panics
// (inside its own task) if the length is ever observed above MAX.
let sampler = {
let queue = queue.clone();
let stop = stop.clone();
tokio::spawn(async move {
while !stop.load(std::sync::atomic::Ordering::Relaxed) {
let len = queue.len().await.unwrap();
assert!(
len <= MAX,
"Queue length {len} exceeded max_size {MAX} — race fix regressed"
);
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
}
})
};
// Senders: each task sends PER_SENDER messages with a unique id.
// Every send is unwrapped, so any send error fails the test.
// NOTE(review): presumably `send` waits while the queue is full rather than
// returning `QueueFull` — confirm against the `send` implementation.
let mut send_handles = Vec::new();
for sender_id in 0..SENDERS {
let queue = queue.clone();
send_handles.push(tokio::spawn(async move {
for j in 0..PER_SENDER {
queue
.send(TestMessage {
id: (sender_id * PER_SENDER + j) as u64,
content: String::new(),
})
.await
.unwrap();
}
}));
}
// Drainer: receives until it has seen every message the senders will produce.
// When `recv` reports no message it backs off for 1 ms before retrying.
let drainer = {
let queue = queue.clone();
tokio::spawn(async move {
let total = SENDERS * PER_SENDER;
let mut received = 0;
while received < total {
if queue.recv().await.unwrap().is_some() {
received += 1;
} else {
tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
}
}
})
};
// Wait for all senders first, then for the drainer; `unwrap` re-raises a task panic here.
for h in send_handles {
h.await.unwrap();
}
drainer.await.unwrap();
// Only now stop the sampler; awaiting it surfaces any failed length assertion.
stop.store(true, std::sync::atomic::Ordering::Relaxed);
sampler.await.unwrap();
// Everything sent was received, so nothing may be left behind.
assert!(queue.is_empty().await.unwrap());
}
/// Calls `blocking_send` from a plain OS thread while the test itself runs
/// under `#[tokio::test]` with no flavor override, then checks that the
/// message can be received through the async side of the channel.
#[tokio::test]
async fn test_blocking_send_on_current_thread_runtime() {
    let db_file = NamedTempFile::new().unwrap();
    let (tx, mut rx) = disk_backed_channel::<TestMessage, _>(
        db_file.path(),
        "blocking_ct_test".to_string(),
        None,
    )
    .await
    .unwrap();

    let expected = TestMessage {
        id: 7,
        content: "from current-thread".to_string(),
    };

    // The send happens on a separate OS thread, using its own sender clone.
    let sender = tx.clone();
    let payload = expected.clone();
    let worker = std::thread::spawn(move || sender.blocking_send(payload));

    // Yield once to the runtime before joining the thread.
    tokio::task::yield_now().await;
    let send_outcome = worker.join().expect("thread panicked");
    send_outcome.unwrap();

    let got = rx.recv().await.unwrap().unwrap();
    assert_eq!(got, expected);
}
/// With a queue limited to 5 entries, a 10-message batch must be rejected with
/// `QueueFull(5)` and leave the queue empty, while a 3-message batch sent
/// afterwards must be accepted.
#[tokio::test]
async fn test_send_batch_rejects_oversized_batch() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "oversize_test".to_string(), Some(5))
        .await
        .unwrap();

    let mut too_many: Vec<TestMessage> = Vec::with_capacity(10);
    for n in 0..10 {
        too_many.push(TestMessage {
            id: n,
            content: String::new(),
        });
    }

    let result = queue.send_batch(too_many).await;
    assert!(
        matches!(result, Err(DiskQueueError::QueueFull(5))),
        "Expected QueueFull(5), got: {result:?}"
    );
    // A rejected batch must not leave any of its messages behind.
    assert!(queue.is_empty().await.unwrap());

    // Three identical messages fit under the limit.
    let template = TestMessage {
        id: 0,
        content: String::new(),
    };
    let fits: Vec<TestMessage> = std::iter::repeat(template).take(3).collect();
    queue.send_batch(fits).await.unwrap();
    assert_eq!(queue.len().await.unwrap(), 3);
}
/// Exercises `clear` on an empty queue and on a queue holding 25 messages,
/// then checks the queue still accepts and returns a message afterwards.
#[tokio::test]
async fn test_clear_empties_queue_and_remains_usable() {
    let db_file = NamedTempFile::new().unwrap();
    let queue = DiskBackedQueue::new(db_file.path(), "clear_test".to_string(), None)
        .await
        .unwrap();

    // Clearing a queue that is already empty must succeed.
    queue.clear().await.unwrap();
    assert!(queue.is_empty().await.unwrap());

    for i in 0..25 {
        let msg = TestMessage {
            id: i,
            content: format!("msg {i}"),
        };
        queue.send(msg).await.unwrap();
    }
    assert_eq!(queue.len().await.unwrap(), 25);

    // After clearing, every view of the queue must report it as empty.
    queue.clear().await.unwrap();
    assert!(queue.is_empty().await.unwrap());
    assert_eq!(queue.len().await.unwrap(), 0);
    assert!(queue.recv().await.unwrap().is_none());

    // The queue must keep working after a clear.
    let after = TestMessage {
        id: 999,
        content: "after-clear".to_string(),
    };
    queue.send(after.clone()).await.unwrap();
    let got = queue.recv().await.unwrap().unwrap();
    assert_eq!(got, after);
}
}