use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use uuid::Uuid;
/// Errors produced by queue drivers, the [`Queue`] facade and the [`Worker`].
///
/// Every variant carries a detail string; the `Display` implementation
/// prefixes it with a short category label.
#[derive(Debug)]
pub enum QueueError {
    /// The storage backend reported a failure (connection, query, command).
    Driver(String),
    /// A payload or stored job record could not be (de)serialized as JSON.
    Serialization(String),
    /// No handler is registered for the named job.
    HandlerNotFound(String),
    /// A job's execution failed.
    JobFailed(String),
}

impl std::fmt::Display for QueueError {
    /// Renders the error as `"<category>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Driver(detail) => write!(f, "Queue driver error: {}", detail),
            Self::Serialization(detail) => write!(f, "Queue serialization error: {}", detail),
            Self::HandlerNotFound(job) => write!(f, "No handler registered for job: {}", job),
            Self::JobFailed(detail) => write!(f, "Job execution failed: {}", detail),
        }
    }
}

/// Marker impl so `QueueError` can be boxed as `dyn std::error::Error`;
/// the message comes from the `Display` impl and there is no `source()`.
impl std::error::Error for QueueError {}
/// A job returned by [`QueueDriver::pop`], ready to be handed to a handler.
#[derive(Debug, Clone)]
pub struct QueuedJob {
    /// Job identifier, as supplied to [`QueueDriver::push`].
    pub id: String,
    /// Job name; the [`Worker`] uses it to look up the registered handler.
    pub name: String,
    /// Decoded JSON payload. The built-in SQLite and Redis drivers substitute
    /// `Value::Null` when the stored payload text is not valid JSON.
    pub payload: Value,
    /// Attempt counter as reported by the driver; the built-in drivers include
    /// the current pop in this count.
    pub attempts: u32,
}
/// Storage backend behind a [`Queue`] and its [`Worker`].
///
/// The `Send + Sync` bound lets one driver be shared through an `Arc`
/// between the producer side and the worker's spawned tasks.
#[async_trait]
pub trait QueueDriver: Send + Sync {
    /// Enqueues a job. `payload` is the job's JSON payload, already serialized
    /// to a string (see [`Queue::dispatch`]).
    async fn push(&self, id: &str, job_name: &str, payload: &str) -> Result<(), QueueError>;
    /// Takes the next available job off the queue, or returns `Ok(None)` when
    /// there is nothing to process.
    async fn pop(&self) -> Result<Option<QueuedJob>, QueueError>;
    /// Called after a job's handler succeeds. A driver may implement this as a
    /// no-op (the Redis driver in this file does).
    async fn mark_complete(&self, job_id: &str) -> Result<(), QueueError>;
    /// Called after a job fails, with a textual description of the error. A
    /// driver may implement this as a no-op (the Redis driver in this file does).
    async fn mark_failed(&self, job_id: &str, error: &str) -> Result<(), QueueError>;
    /// Number of jobs currently waiting to be popped.
    async fn pending_count(&self) -> Result<u64, QueueError>;
}
/// [`QueueDriver`] that stores jobs as rows of the SQLite table `rullst_jobs`.
///
/// New rows start with `status = 'pending'`. `pop` moves a row to
/// `'processing'`, `mark_failed` moves it to `'failed'`, and `mark_complete`
/// deletes it.
pub struct SqliteDriver {
    pool: sqlx::SqlitePool,
}

impl SqliteDriver {
    /// Connects to `database_url` and makes sure the `rullst_jobs` table exists.
    ///
    /// # Errors
    ///
    /// Returns [`QueueError::Driver`] if the pool cannot connect or the table
    /// cannot be created.
    pub async fn new(database_url: &str) -> Result<Self, QueueError> {
        // Idempotent schema setup; safe to run against an existing database.
        const CREATE_JOBS_TABLE: &str = r#"CREATE TABLE IF NOT EXISTS rullst_jobs (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
payload TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
error TEXT,
attempts INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)"#;

        let pool = match sqlx::SqlitePool::connect(database_url).await {
            Ok(pool) => pool,
            Err(e) => {
                return Err(QueueError::Driver(format!(
                    "Failed to connect to SQLite: {}",
                    e
                )))
            }
        };

        if let Err(e) = sqlx::query(CREATE_JOBS_TABLE).execute(&pool).await {
            return Err(QueueError::Driver(format!(
                "Failed to create rullst_jobs table: {}",
                e
            )));
        }

        Ok(Self { pool })
    }
}
#[async_trait]
impl QueueDriver for SqliteDriver {
    /// Inserts the job as a new row. `status` and `attempts` take their column
    /// defaults (`'pending'` and `0`) from the table definition.
    async fn push(&self, id: &str, job_name: &str, payload: &str) -> Result<(), QueueError> {
        sqlx::query("INSERT INTO rullst_jobs (id, name, payload) VALUES (?, ?, ?)")
            .bind(id)
            .bind(job_name)
            .bind(payload)
            .execute(&self.pool)
            .await
            .map_err(|e| QueueError::Driver(format!("Failed to push job: {}", e)))?;
        Ok(())
    }

    /// Claims the oldest pending job in one `UPDATE ... RETURNING` statement.
    ///
    /// The statement moves the row to `'processing'` and bumps `attempts`.
    /// A payload that is not valid JSON is returned as `Value::Null` rather than
    /// as an error, because the row has already been claimed at that point.
    async fn pop(&self) -> Result<Option<QueuedJob>, QueueError> {
        // `created_at` comes from `datetime('now')`, which has one-second
        // resolution. Jobs pushed within the same second therefore tie on it.
        // SQLite gives each new row a rowid larger than any existing row, so
        // `rowid` breaks the tie and keeps dequeue order FIFO.
        let row: Option<(String, String, String, i32)> = sqlx::query_as(
            r#"UPDATE rullst_jobs
               SET status = 'processing', attempts = attempts + 1, updated_at = datetime('now')
               WHERE id = (
                   SELECT id FROM rullst_jobs
                   WHERE status = 'pending'
                   ORDER BY created_at ASC, rowid ASC
                   LIMIT 1
               )
               RETURNING id, name, payload, attempts"#,
        )
        .fetch_optional(&self.pool)
        .await
        .map_err(|e| QueueError::Driver(format!("Failed to pop job: {}", e)))?;

        Ok(row.map(|(id, name, payload_str, attempts)| {
            let payload = serde_json::from_str(&payload_str).unwrap_or(Value::Null);
            QueuedJob {
                id,
                name,
                payload,
                // `attempts` starts at 0 and is only ever incremented, so it is non-negative.
                attempts: attempts as u32,
            }
        }))
    }

    /// Removes the finished job's row entirely.
    async fn mark_complete(&self, job_id: &str) -> Result<(), QueueError> {
        sqlx::query("DELETE FROM rullst_jobs WHERE id = ?")
            .bind(job_id)
            .execute(&self.pool)
            .await
            .map_err(|e| QueueError::Driver(format!("Failed to mark job complete: {}", e)))?;
        Ok(())
    }

    /// Moves the job to `'failed'` and stores `error`. `pop` only selects
    /// `'pending'` rows, so failed jobs are not picked up again.
    async fn mark_failed(&self, job_id: &str, error: &str) -> Result<(), QueueError> {
        sqlx::query(
            "UPDATE rullst_jobs SET status = 'failed', error = ?, updated_at = datetime('now') WHERE id = ?",
        )
        .bind(error)
        .bind(job_id)
        .execute(&self.pool)
        .await
        .map_err(|e| QueueError::Driver(format!("Failed to mark job failed: {}", e)))?;
        Ok(())
    }

    /// Counts rows still in `'pending'`; processing and failed jobs are excluded.
    async fn pending_count(&self) -> Result<u64, QueueError> {
        let (count,): (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM rullst_jobs WHERE status = 'pending'")
                .fetch_one(&self.pool)
                .await
                .map_err(|e| QueueError::Driver(format!("Failed to count pending jobs: {}", e)))?;
        // COUNT(*) is never negative.
        Ok(count as u64)
    }
}
/// Redis-backed queue driver, compiled only with the `queue-redis` feature.
#[cfg(feature = "queue-redis")]
pub mod redis_driver {
    use super::*;

    /// [`QueueDriver`] that keeps jobs as JSON strings in a single Redis list.
    ///
    /// Jobs are appended with `RPUSH` and taken with `LPOP`, so a popped job is
    /// removed from Redis immediately. `mark_complete` and `mark_failed` do
    /// nothing, which means failure details are not persisted.
    pub struct RedisDriver {
        client: redis::Client,
        queue_key: String,
    }

    impl RedisDriver {
        /// Builds a driver for `redis_url` that uses the list key
        /// `rullst:queue:default`.
        ///
        /// # Errors
        ///
        /// Returns [`QueueError::Driver`] if `redis::Client::open` rejects the URL.
        pub fn new(redis_url: &str) -> Result<Self, QueueError> {
            redis::Client::open(redis_url)
                .map(|client| Self {
                    client,
                    queue_key: "rullst:queue:default".to_string(),
                })
                .map_err(|e| QueueError::Driver(format!("Failed to connect to Redis: {}", e)))
        }

        /// Opens a multiplexed async connection for one driver operation.
        async fn connection(&self) -> Result<redis::aio::MultiplexedConnection, QueueError> {
            self.client
                .get_multiplexed_async_connection()
                .await
                .map_err(|e| QueueError::Driver(format!("Redis connection failed: {}", e)))
        }
    }

    #[async_trait]
    impl QueueDriver for RedisDriver {
        /// Appends `{id, name, payload, attempts: 0}` as a JSON string to the list.
        /// The payload is stored as a string, not as nested JSON.
        async fn push(&self, id: &str, job_name: &str, payload: &str) -> Result<(), QueueError> {
            let mut con = self.connection().await?;
            let entry = serde_json::json!({
                "id": id,
                "name": job_name,
                "payload": payload,
                "attempts": 0
            })
            .to_string();
            redis::cmd("RPUSH")
                .arg(&self.queue_key)
                .arg(entry)
                .query_async::<i64>(&mut con)
                .await
                .map(|_| ())
                .map_err(|e| QueueError::Driver(format!("Failed to push to Redis: {}", e)))
        }

        /// Pops the head of the list and decodes it.
        ///
        /// A missing `id` or `name` becomes `""`. A missing payload is read as
        /// `"{}"`, and an unparsable payload becomes `Value::Null`. The returned
        /// `attempts` is the stored counter plus one.
        async fn pop(&self) -> Result<Option<QueuedJob>, QueueError> {
            let mut con = self.connection().await?;
            let popped: Option<String> = redis::cmd("LPOP")
                .arg(&self.queue_key)
                .query_async(&mut con)
                .await
                .map_err(|e| QueueError::Driver(format!("Failed to pop from Redis: {}", e)))?;

            let raw = match popped {
                Some(raw) => raw,
                None => return Ok(None),
            };

            let entry: serde_json::Value =
                serde_json::from_str(&raw).map_err(|e| QueueError::Serialization(e.to_string()))?;
            let text_field = |key: &str| entry[key].as_str().unwrap_or("").to_string();

            let payload = serde_json::from_str(entry["payload"].as_str().unwrap_or("{}"))
                .unwrap_or(Value::Null);
            let stored_attempts = entry["attempts"].as_u64().unwrap_or(0) as u32;

            Ok(Some(QueuedJob {
                id: text_field("id"),
                name: text_field("name"),
                payload,
                attempts: stored_attempts + 1,
            }))
        }

        /// No-op: the job already left Redis when it was popped.
        async fn mark_complete(&self, _job_id: &str) -> Result<(), QueueError> {
            Ok(())
        }

        /// No-op: the failure is not recorded in Redis.
        async fn mark_failed(&self, _job_id: &str, _error: &str) -> Result<(), QueueError> {
            Ok(())
        }

        /// Returns the current length of the job list (`LLEN`).
        async fn pending_count(&self) -> Result<u64, QueueError> {
            let mut con = self.connection().await?;
            redis::cmd("LLEN")
                .arg(&self.queue_key)
                .query_async::<i64>(&mut con)
                .await
                .map(|len| len as u64)
                .map_err(|e| QueueError::Driver(format!("Failed to get queue length: {}", e)))
        }
    }
}
/// Producer-side handle that enqueues jobs through a shared [`QueueDriver`].
pub struct Queue {
    driver: Arc<Box<dyn QueueDriver>>,
}

impl Queue {
    /// Creates a queue backed by SQLite at `database_url` (see [`SqliteDriver::new`]).
    ///
    /// # Errors
    ///
    /// Propagates the driver's connection or schema-creation error.
    pub async fn sqlite(database_url: &str) -> Result<Self, QueueError> {
        SqliteDriver::new(database_url)
            .await
            .map(|driver| Self::custom(Box::new(driver)))
    }

    /// Creates a queue backed by Redis at `redis_url`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `RedisDriver::new`.
    #[cfg(feature = "queue-redis")]
    pub fn redis(redis_url: &str) -> Result<Self, QueueError> {
        redis_driver::RedisDriver::new(redis_url).map(|driver| Self::custom(Box::new(driver)))
    }

    /// Wraps any user-supplied driver.
    pub fn custom(driver: Box<dyn QueueDriver>) -> Self {
        Self {
            driver: Arc::new(driver),
        }
    }

    /// Serializes `payload` to JSON and pushes it under a fresh v4 UUID, then
    /// returns that UUID as the job id.
    ///
    /// # Errors
    ///
    /// [`QueueError::Serialization`] if `payload` cannot be serialized, or
    /// whatever error the driver's `push` returns.
    pub async fn dispatch(&self, job_name: &str, payload: Value) -> Result<String, QueueError> {
        let job_id = Uuid::new_v4().to_string();
        let serialized = match serde_json::to_string(&payload) {
            Ok(text) => text,
            Err(e) => return Err(QueueError::Serialization(e.to_string())),
        };
        self.driver.push(&job_id, job_name, &serialized).await?;
        Ok(job_id)
    }

    /// Number of jobs the driver reports as waiting.
    pub async fn pending_count(&self) -> Result<u64, QueueError> {
        self.driver.pending_count().await
    }

    /// Shares this queue's driver, e.g. with a [`Worker`].
    pub(crate) fn driver_ref(&self) -> Arc<Box<dyn QueueDriver>> {
        Arc::clone(&self.driver)
    }
}
type JobHandler = Box<
dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send>>
+ Send
+ Sync,
>;
pub struct Worker {
driver: Arc<Box<dyn QueueDriver>>,
handlers: HashMap<String, Arc<JobHandler>>,
poll_interval_ms: u64,
}
impl Worker {
pub fn new(queue: &Queue) -> Self {
Self {
driver: queue.driver_ref(),
handlers: HashMap::new(),
poll_interval_ms: 1000,
}
}
pub fn poll_interval(mut self, ms: u64) -> Self {
self.poll_interval_ms = ms;
self
}
pub fn register<F, Fut>(&mut self, name: &str, handler: F) -> &mut Self
where
F: Fn(Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send + 'static,
{
let boxed: JobHandler = Box::new(move |payload| Box::pin(handler(payload)));
self.handlers.insert(name.to_string(), Arc::new(boxed));
self
}
pub fn run(&self) {
let driver = Arc::clone(&self.driver);
let handlers = self.handlers.clone();
let poll_interval = self.poll_interval_ms;
tokio::spawn(async move {
println!("🔄 Rullst Worker started. Polling every {}ms...", poll_interval);
loop {
match driver.pop().await {
Ok(Some(job)) => {
if let Some(handler) = handlers.get(&job.name) {
let handler = Arc::clone(handler);
let driver = Arc::clone(&driver);
let job_id = job.id.clone();
let job_name = job.name.clone();
tokio::spawn(async move {
match handler(job.payload).await {
Ok(()) => {
let _ = driver.mark_complete(&job_id).await;
}
Err(e) => {
eprintln!("❌ Job '{}' ({}) failed: {}", job_name, job_id, e);
let _ = driver.mark_failed(&job_id, &e.to_string()).await;
}
}
});
} else {
eprintln!("⚠️ No handler registered for job: {}", job.name);
let _ = driver.mark_failed(&job.id, "No handler registered").await;
}
}
Ok(None) => {
}
Err(e) => {
eprintln!("❌ Queue poll error: {}", e);
}
}
tokio::time::sleep(tokio::time::Duration::from_millis(poll_interval)).await;
}
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a driver on a brand-new in-memory SQLite database.
    async fn in_memory_driver() -> SqliteDriver {
        SqliteDriver::new("sqlite::memory:").await.unwrap()
    }

    #[tokio::test]
    async fn test_sqlite_queue_push_pop() {
        let queue = Queue::sqlite("sqlite::memory:").await.unwrap();
        let payload = serde_json::json!({"key": "value"});
        let dispatched_id = queue.dispatch("test_job", payload).await.unwrap();

        assert!(!dispatched_id.is_empty());
        assert_eq!(queue.pending_count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_sqlite_queue_pop_returns_correct_job() {
        let driver = in_memory_driver().await;
        let fixtures = [
            ("job-1", "send_email", r#"{"to":"a@b.com"}"#),
            ("job-2", "process_image", r#"{"path":"/img.png"}"#),
        ];
        for (id, name, body) in fixtures {
            driver.push(id, name, body).await.unwrap();
        }

        let first = driver.pop().await.unwrap().unwrap();
        assert_eq!(first.id, "job-1");
        assert_eq!(first.name, "send_email");
        assert_eq!(first.payload["to"], "a@b.com");
        driver.mark_complete("job-1").await.unwrap();

        let second = driver.pop().await.unwrap().unwrap();
        assert_eq!(second.id, "job-2");
        assert_eq!(second.name, "process_image");
    }

    #[tokio::test]
    async fn test_sqlite_queue_mark_failed() {
        let driver = in_memory_driver().await;
        driver.push("fail-job", "bad_job", "{}").await.unwrap();

        let claimed = driver.pop().await.unwrap().unwrap();
        driver
            .mark_failed(&claimed.id, "Something went wrong")
            .await
            .unwrap();

        assert_eq!(driver.pending_count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_sqlite_queue_empty_pop() {
        let driver = in_memory_driver().await;
        assert!(driver.pop().await.unwrap().is_none());
    }
}