use crate::broker_core::MysqlBroker;
use async_trait::async_trait;
use celers_core::{Broker, BrokerMessage, CelersError, Result, SerializedTask, TaskId};
use serde_json::json;
use sqlx::Row;
use std::sync::atomic::Ordering;
use uuid::Uuid;
#[cfg(feature = "metrics")]
use celers_metrics::{TASKS_ENQUEUED_BY_TYPE, TASKS_ENQUEUED_TOTAL};
#[async_trait]
impl Broker for MysqlBroker {
async fn enqueue(&self, task: SerializedTask) -> Result<TaskId> {
let task_id = task.metadata.id;
let mut db_metadata = json!({
"queue": self.queue_name,
"enqueued_at": chrono::Utc::now().to_rfc3339(),
});
if let Ok(task_meta) = serde_json::to_value(&task.metadata) {
if let Some(obj) = db_metadata.as_object_mut() {
if let Some(meta_obj) = task_meta.as_object() {
for (k, v) in meta_obj {
obj.insert(k.clone(), v.clone());
}
}
}
}
sqlx::query(
r#"
INSERT INTO celers_tasks
(id, task_name, payload, state, priority, max_retries, metadata, created_at, scheduled_at)
VALUES (?, ?, ?, 'pending', ?, ?, ?, NOW(), NOW())
"#,
)
.bind(task_id.to_string())
.bind(&task.metadata.name)
.bind(&task.payload)
.bind(task.metadata.priority)
.bind(task.metadata.max_retries as i32)
.bind(serde_json::to_string(&db_metadata).unwrap_or_else(|_| "{}".to_string()))
.execute(&self.pool)
.await
.map_err(|e| CelersError::Other(format!("Failed to enqueue task: {}", e)))?;
#[cfg(feature = "metrics")]
{
TASKS_ENQUEUED_TOTAL.inc();
TASKS_ENQUEUED_BY_TYPE
.with_label_values(&[&task.metadata.name])
.inc();
}
Ok(task_id)
}
/// Claims the next runnable task, if any, and marks it `processing`.
///
/// Returns `Ok(None)` without touching the database while the broker is paused,
/// and `Ok(None)` (after rolling back) when no `pending` row with
/// `scheduled_at <= NOW()` could be selected. Candidates are ordered by
/// `priority` descending, then `created_at` ascending. Selection and the state
/// change run in a single transaction, and the row is selected with
/// `FOR UPDATE SKIP LOCKED`.
///
/// The returned message's `receipt_handle` is the row's `retry_count` as read
/// before this call increments it.
///
/// # Errors
///
/// Returns [`CelersError::Other`] when the transaction cannot be started,
/// committed or rolled back, when either statement fails, or when the stored
/// id is not a valid UUID.
async fn dequeue(&self) -> Result<Option<BrokerMessage>> {
    if self.paused.load(Ordering::SeqCst) {
        return Ok(None);
    }

    let mut tx = self
        .pool
        .begin()
        .await
        .map_err(|e| CelersError::Other(format!("Failed to begin transaction: {}", e)))?;

    // MySQL's SELECT grammar places the locking clause after LIMIT; with the
    // clauses the other way round the statement is rejected as a syntax error.
    let candidate = sqlx::query(
        r#"
SELECT id, task_name, payload, retry_count
FROM celers_tasks
WHERE state = 'pending'
AND scheduled_at <= NOW()
ORDER BY priority DESC, created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
"#,
    )
    .fetch_optional(&mut *tx)
    .await
    .map_err(|e| CelersError::Other(format!("Failed to dequeue task: {}", e)))?;

    let row = match candidate {
        Some(row) => row,
        None => {
            tx.rollback().await.map_err(|e| {
                CelersError::Other(format!("Failed to rollback transaction: {}", e))
            })?;
            return Ok(None);
        }
    };

    let task_id_str: String = row.get("id");
    let task_id = Uuid::parse_str(&task_id_str)
        .map_err(|e| CelersError::Other(format!("Invalid UUID: {}", e)))?;
    let task_name: String = row.get("task_name");
    let payload: Vec<u8> = row.get("payload");
    let retry_count: i32 = row.get("retry_count");

    sqlx::query(
        r#"
UPDATE celers_tasks
SET state = 'processing',
started_at = NOW(),
retry_count = retry_count + 1
WHERE id = ?
"#,
    )
    .bind(&task_id_str)
    .execute(&mut *tx)
    .await
    .map_err(|e| CelersError::Other(format!("Failed to mark task as processing: {}", e)))?;

    tx.commit()
        .await
        .map_err(|e| CelersError::Other(format!("Failed to commit transaction: {}", e)))?;

    // Carry the stored id on the returned task so that a later `ack`/`reject`
    // keyed on `task.metadata.id` addresses the row that was just claimed.
    // NOTE(review): assumes `TaskId` is `uuid::Uuid` and that
    // `SerializedTask::new` does not otherwise know the stored id — confirm.
    let mut task = SerializedTask::new(task_name, payload);
    task.metadata.id = task_id;

    Ok(Some(BrokerMessage {
        task,
        receipt_handle: Some(retry_count.to_string()),
    }))
}
/// Marks the row for `task_id` as `completed` and stamps `completed_at` with
/// the database's `NOW()`.
///
/// The receipt handle is not consulted; the row is addressed by id alone.
///
/// # Errors
///
/// Returns [`CelersError::Other`] when the `UPDATE` fails.
async fn ack(&self, task_id: &TaskId, _receipt_handle: Option<&str>) -> Result<()> {
    const MARK_COMPLETED: &str = r#"
UPDATE celers_tasks
SET state = 'completed',
completed_at = NOW()
WHERE id = ?
"#;

    let id_param = task_id.to_string();
    let outcome = sqlx::query(MARK_COMPLETED)
        .bind(id_param)
        .execute(&self.pool)
        .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(CelersError::Other(format!("Failed to ack task: {}", e))),
    }
}
async fn reject(
&self,
task_id: &TaskId,
_receipt_handle: Option<&str>,
requeue: bool,
) -> Result<()> {
if requeue {
let row = sqlx::query(
r#"
SELECT retry_count, max_retries
FROM celers_tasks
WHERE id = ?
"#,
)
.bind(task_id.to_string())
.fetch_one(&self.pool)
.await
.map_err(|e| CelersError::Other(format!("Failed to fetch task: {}", e)))?;
let retry_count: i32 = row.get("retry_count");
let max_retries: i32 = row.get("max_retries");
if retry_count >= max_retries {
self.move_to_dlq(task_id).await?;
} else {
let backoff_seconds = 2_i64.pow(retry_count as u32).min(3600);
sqlx::query(
r#"
UPDATE celers_tasks
SET state = 'pending',
scheduled_at = DATE_ADD(NOW(), INTERVAL ? SECOND),
started_at = NULL,
worker_id = NULL
WHERE id = ?
"#,
)
.bind(backoff_seconds)
.bind(task_id.to_string())
.execute(&self.pool)
.await
.map_err(|e| CelersError::Other(format!("Failed to requeue task: {}", e)))?;
}
} else {
sqlx::query(
r#"
UPDATE celers_tasks
SET state = 'failed',
completed_at = NOW()
WHERE id = ?
"#,
)
.bind(task_id.to_string())
.execute(&self.pool)
.await
.map_err(|e| CelersError::Other(format!("Failed to mark task as failed: {}", e)))?;
}
Ok(())
}
/// Counts the rows of `celers_tasks` whose state is `pending`.
///
/// The count is not filtered on `scheduled_at`, so tasks scheduled for the
/// future are included.
///
/// # Errors
///
/// Returns [`CelersError::Other`] when the query fails.
async fn queue_size(&self) -> Result<usize> {
    const COUNT_PENDING: &str = r#"
SELECT COUNT(*) as count
FROM celers_tasks
WHERE state = 'pending'
"#;

    let row = match sqlx::query(COUNT_PENDING).fetch_one(&self.pool).await {
        Ok(row) => row,
        Err(e) => {
            return Err(CelersError::Other(format!(
                "Failed to get queue size: {}",
                e
            )))
        }
    };

    let pending: i64 = row.get("count");
    Ok(pending as usize)
}
/// Cancels a task that is still `pending` or `processing`.
///
/// Sets the row's state to `cancelled` and `completed_at` to `NOW()`. Returns
/// `true` when a row was updated and `false` when no row matched (unknown id,
/// or a task in any other state).
///
/// # Errors
///
/// Returns [`CelersError::Other`] when the `UPDATE` fails.
async fn cancel(&self, task_id: &TaskId) -> Result<bool> {
    const CANCEL_ACTIVE: &str = r#"
UPDATE celers_tasks
SET state = 'cancelled',
completed_at = NOW()
WHERE id = ? AND state IN ('pending', 'processing')
"#;

    let outcome = sqlx::query(CANCEL_ACTIVE)
        .bind(task_id.to_string())
        .execute(&self.pool)
        .await;

    match outcome {
        Ok(done) => Ok(done.rows_affected() != 0),
        Err(e) => Err(CelersError::Other(format!("Failed to cancel task: {}", e))),
    }
}
async fn enqueue_at(&self, task: SerializedTask, execute_at: i64) -> Result<TaskId> {
let task_id = task.metadata.id;
let mut db_metadata = json!({
"queue": self.queue_name,
"enqueued_at": chrono::Utc::now().to_rfc3339(),
"scheduled_for": execute_at,
});
if let Ok(task_meta) = serde_json::to_value(&task.metadata) {
if let Some(obj) = db_metadata.as_object_mut() {
if let Some(meta_obj) = task_meta.as_object() {
for (k, v) in meta_obj {
obj.insert(k.clone(), v.clone());
}
}
}
}
let scheduled_at = chrono::DateTime::from_timestamp(execute_at, 0)
.ok_or_else(|| CelersError::Other("Invalid timestamp".to_string()))?
.format("%Y-%m-%d %H:%M:%S")
.to_string();
sqlx::query(
r#"
INSERT INTO celers_tasks
(id, task_name, payload, state, priority, max_retries, metadata, created_at, scheduled_at)
VALUES (?, ?, ?, 'pending', ?, ?, ?, NOW(), ?)
"#,
)
.bind(task_id.to_string())
.bind(&task.metadata.name)
.bind(&task.payload)
.bind(task.metadata.priority)
.bind(task.metadata.max_retries as i32)
.bind(serde_json::to_string(&db_metadata).unwrap_or_else(|_| "{}".to_string()))
.bind(scheduled_at)
.execute(&self.pool)
.await
.map_err(|e| CelersError::Other(format!("Failed to enqueue delayed task: {}", e)))?;
#[cfg(feature = "metrics")]
{
TASKS_ENQUEUED_TOTAL.inc();
TASKS_ENQUEUED_BY_TYPE
.with_label_values(&[&task.metadata.name])
.inc();
}
Ok(task_id)
}
async fn enqueue_after(&self, task: SerializedTask, delay_secs: u64) -> Result<TaskId> {
let task_id = task.metadata.id;
let mut db_metadata = json!({
"queue": self.queue_name,
"enqueued_at": chrono::Utc::now().to_rfc3339(),
"delay_seconds": delay_secs,
});
if let Ok(task_meta) = serde_json::to_value(&task.metadata) {
if let Some(obj) = db_metadata.as_object_mut() {
if let Some(meta_obj) = task_meta.as_object() {
for (k, v) in meta_obj {
obj.insert(k.clone(), v.clone());
}
}
}
}
sqlx::query(
r#"
INSERT INTO celers_tasks
(id, task_name, payload, state, priority, max_retries, metadata, created_at, scheduled_at)
VALUES (?, ?, ?, 'pending', ?, ?, ?, NOW(), DATE_ADD(NOW(), INTERVAL ? SECOND))
"#,
)
.bind(task_id.to_string())
.bind(&task.metadata.name)
.bind(&task.payload)
.bind(task.metadata.priority)
.bind(task.metadata.max_retries as i32)
.bind(serde_json::to_string(&db_metadata).unwrap_or_else(|_| "{}".to_string()))
.bind(delay_secs as i64)
.execute(&self.pool)
.await
.map_err(|e| CelersError::Other(format!("Failed to enqueue delayed task: {}", e)))?;
#[cfg(feature = "metrics")]
{
TASKS_ENQUEUED_TOTAL.inc();
TASKS_ENQUEUED_BY_TYPE
.with_label_values(&[&task.metadata.name])
.inc();
}
Ok(task_id)
}
/// Enqueues several tasks at once by delegating to `enqueue_batch_impl`,
/// returning whatever ids or error that call produces.
async fn enqueue_batch(&self, tasks: Vec<SerializedTask>) -> Result<Vec<TaskId>> {
    let enqueued_ids = self.enqueue_batch_impl(tasks).await?;
    Ok(enqueued_ids)
}
/// Dequeues a batch of messages by delegating to `dequeue_batch_impl` with the
/// requested `count`, returning whatever messages or error that call produces.
async fn dequeue_batch(&self, count: usize) -> Result<Vec<BrokerMessage>> {
    let messages = self.dequeue_batch_impl(count).await?;
    Ok(messages)
}
/// Marks every task in `tasks` as `completed` with a single `UPDATE ... IN (...)`.
///
/// Only the id of each pair is used; the optional receipt handle is ignored.
/// An empty slice returns `Ok(())` without issuing a query.
///
/// # Errors
///
/// Returns [`CelersError::Other`] when the `UPDATE` fails.
async fn ack_batch(&self, tasks: &[(TaskId, Option<String>)]) -> Result<()> {
    if tasks.is_empty() {
        return Ok(());
    }

    // One `?` per task id, joined into the IN (...) list.
    let placeholders = vec!["?"; tasks.len()].join(", ");
    let query_str = format!(
        r#"
UPDATE celers_tasks
SET state = 'completed',
completed_at = NOW()
WHERE id IN ({})
"#,
        placeholders
    );

    let mut statement = sqlx::query(&query_str);
    for (id, _) in tasks {
        statement = statement.bind(id.to_string());
    }

    match statement.execute(&self.pool).await {
        Ok(_) => Ok(()),
        Err(e) => Err(CelersError::Other(format!(
            "Failed to batch ack tasks: {}",
            e
        ))),
    }
}
}