use async_trait::async_trait;
use celers_core::event::{Event, EventEmitter};
use celers_core::event_persistence::EventPersister;
use chrono::{DateTime, Utc};
use sqlx::{PgPool, Row};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use crate::{BackendError, Result};
/// Tuning knobs for the database-backed event persister.
#[derive(Debug, Clone)]
pub struct DbEventPersisterConfig {
    /// Buffered-event count at which `emit`/`emit_batch` trigger a flush.
    pub batch_size: usize,
    /// Intended period between flushes.
    ///
    /// NOTE(review): nothing in this module reads this value; confirm whether a
    /// caller drives periodic flushing with it.
    pub flush_interval: Duration,
    /// When `false`, emitted events are discarded instead of buffered.
    pub enabled: bool,
}
impl Default for DbEventPersisterConfig {
    /// Stock settings: flush once 100 events are buffered, a 5-second flush
    /// interval, and persistence switched on.
    fn default() -> Self {
        let batch_size = 100;
        let flush_interval = Duration::from_secs(5);
        let enabled = true;
        Self {
            batch_size,
            flush_interval,
            enabled,
        }
    }
}
impl DbEventPersisterConfig {
    /// Returns this configuration with `batch_size` replaced; other fields are kept.
    #[must_use]
    pub fn with_batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Returns this configuration with `flush_interval` set to `interval`; other
    /// fields are kept.
    #[must_use]
    pub fn with_flush_interval(self, interval: Duration) -> Self {
        Self {
            flush_interval: interval,
            ..self
        }
    }

    /// Returns this configuration with `enabled` replaced; other fields are kept.
    #[must_use]
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
/// Buffers celers events in memory and writes them to PostgreSQL in batches.
///
/// Buffered events reach the `celers_events` table once the buffer holds
/// `config.batch_size` events, or when a query, count or explicit flush call
/// flushes it first.
///
/// NOTE(review): no `Drop` impl in this module flushes remaining events; callers
/// presumably need to flush before shutdown — confirm.
pub struct DbEventPersister {
    /// Connection pool used for migrations, inserts and queries.
    pool: PgPool,
    /// Batching and enable/disable settings.
    config: DbEventPersisterConfig,
    /// Events emitted but not yet written to the database.
    buffer: Arc<Mutex<Vec<Event>>>,
}
impl DbEventPersister {
    /// Creates a persister that writes to `pool` with the given `config`.
    ///
    /// No database work happens here: the buffer starts empty and the schema is
    /// not created — call [`DbEventPersister::migrate`] for that.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok`.
    pub async fn new(pool: PgPool, config: DbEventPersisterConfig) -> Result<Self> {
        Ok(Self {
            pool,
            config,
            buffer: Arc::new(Mutex::new(Vec::new())),
        })
    }

    /// Creates the `celers_events` table and its indexes if they do not exist.
    ///
    /// Each DDL statement is sent on its own. `sqlx::query` goes through
    /// PostgreSQL's extended (prepared-statement) protocol, which rejects a
    /// string containing several commands. Every statement is `IF NOT EXISTS`,
    /// so the migration can be re-run safely, including after a partial failure.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::Connection`] if any statement fails.
    pub async fn migrate(&self) -> Result<()> {
        const STATEMENTS: [&str; 4] = [
            r#"
            CREATE TABLE IF NOT EXISTS celers_events (
                id BIGSERIAL PRIMARY KEY,
                event_type VARCHAR(64) NOT NULL,
                task_id UUID,
                worker VARCHAR(255),
                timestamp TIMESTAMPTZ NOT NULL,
                payload JSONB NOT NULL,
                created_at TIMESTAMPTZ DEFAULT NOW()
            )
            "#,
            "CREATE INDEX IF NOT EXISTS idx_celers_events_type ON celers_events(event_type)",
            "CREATE INDEX IF NOT EXISTS idx_celers_events_task_id ON celers_events(task_id)",
            "CREATE INDEX IF NOT EXISTS idx_celers_events_timestamp ON celers_events(timestamp)",
        ];
        for statement in STATEMENTS {
            sqlx::query(statement)
                .execute(&self.pool)
                .await
                .map_err(|e| {
                    BackendError::Connection(format!("Failed to run event migrations: {}", e))
                })?;
        }
        Ok(())
    }

    /// Writes every buffered event to `celers_events` inside one transaction.
    ///
    /// The buffer is drained while holding the lock, and the lock is released
    /// before any database I/O, so concurrent emitters append to a fresh buffer.
    /// Each row stores the event type, task id, worker hostname and timestamp as
    /// columns, plus the whole event serialized to JSON in `payload`.
    ///
    /// NOTE(review): the drained events are not restored when serialization,
    /// an insert or the commit fails — the batch is dropped and only the error
    /// is returned. Re-queuing would avoid that loss, but a permanent error
    /// (e.g. an over-long column value) would then fail every later flush.
    ///
    /// # Errors
    ///
    /// Returns [`BackendError::Serialization`] if an event cannot be converted
    /// to JSON, and [`BackendError::Connection`] if beginning the transaction,
    /// an insert, or the commit fails.
    async fn flush_buffer(&self) -> Result<()> {
        let events = {
            let mut buf = self.buffer.lock().await;
            if buf.is_empty() {
                return Ok(());
            }
            std::mem::take(&mut *buf)
        };
        let mut tx =
            self.pool.begin().await.map_err(|e| {
                BackendError::Connection(format!("Failed to begin transaction: {}", e))
            })?;
        for event in &events {
            let event_type = event.event_type();
            let task_id = event.task_id();
            let worker = event.hostname().map(|s| s.to_string());
            let timestamp = event.timestamp();
            let payload = serde_json::to_value(event).map_err(|e| {
                BackendError::Serialization(format!("Failed to serialize event: {}", e))
            })?;
            sqlx::query(
                r#"
                INSERT INTO celers_events (event_type, task_id, worker, timestamp, payload)
                VALUES ($1, $2, $3, $4, $5)
                "#,
            )
            .bind(event_type)
            .bind(task_id)
            .bind(&worker)
            .bind(timestamp)
            .bind(&payload)
            .execute(&mut *tx)
            .await
            .map_err(|e| BackendError::Connection(format!("Failed to insert event: {}", e)))?;
        }
        tx.commit().await.map_err(|e| {
            BackendError::Connection(format!("Failed to commit event batch: {}", e))
        })?;
        Ok(())
    }

    /// Returns the connection pool this persister writes to.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns how many events are currently buffered and not yet written.
    pub async fn buffer_len(&self) -> usize {
        self.buffer.lock().await.len()
    }
}
#[async_trait]
impl EventEmitter for DbEventPersister {
    /// Buffers one event; see `buffer_and_maybe_flush` for the flush and
    /// enabled rules.
    async fn emit(&self, event: Event) -> celers_core::Result<()> {
        self.buffer_and_maybe_flush(std::iter::once(event)).await
    }

    /// Buffers a group of events; see `buffer_and_maybe_flush` for the flush
    /// and enabled rules.
    async fn emit_batch(&self, events: Vec<Event>) -> celers_core::Result<()> {
        self.buffer_and_maybe_flush(events).await
    }

    /// Reports the `enabled` flag from the configuration.
    fn is_enabled(&self) -> bool {
        self.config.enabled
    }
}

impl DbEventPersister {
    /// Shared path for `emit` and `emit_batch`.
    ///
    /// Returns `Ok(())` without buffering when the persister is disabled.
    /// Otherwise it appends `incoming` to the buffer. If the buffer then holds
    /// at least `batch_size` events, it flushes to the database. A flush failure
    /// is reported as `CelersError::Other`.
    async fn buffer_and_maybe_flush<I>(&self, incoming: I) -> celers_core::Result<()>
    where
        I: IntoIterator<Item = Event> + Send,
    {
        if !self.config.enabled {
            return Ok(());
        }
        let threshold_reached = {
            let mut pending = self.buffer.lock().await;
            pending.extend(incoming);
            pending.len() >= self.config.batch_size
        };
        if threshold_reached {
            self.flush_buffer().await.map_err(|err| {
                celers_core::CelersError::Other(format!("DB event flush failed: {}", err))
            })?;
        }
        Ok(())
    }
}
#[async_trait]
impl EventPersister for DbEventPersister {
async fn query_events(
&self,
from: DateTime<Utc>,
to: DateTime<Utc>,
event_type_filter: Option<&str>,
) -> celers_core::Result<Vec<Event>> {
self.flush_buffer().await.map_err(|e| {
celers_core::CelersError::Other(format!("DB event flush failed: {}", e))
})?;
let rows = if let Some(et) = event_type_filter {
sqlx::query(
r#"
SELECT payload FROM celers_events
WHERE timestamp >= $1 AND timestamp <= $2 AND event_type = $3
ORDER BY timestamp ASC
"#,
)
.bind(from)
.bind(to)
.bind(et)
.fetch_all(&self.pool)
.await
} else {
sqlx::query(
r#"
SELECT payload FROM celers_events
WHERE timestamp >= $1 AND timestamp <= $2
ORDER BY timestamp ASC
"#,
)
.bind(from)
.bind(to)
.fetch_all(&self.pool)
.await
}
.map_err(|e| celers_core::CelersError::Other(format!("Failed to query events: {}", e)))?;
let mut events = Vec::with_capacity(rows.len());
for row in &rows {
let payload: serde_json::Value = row.get("payload");
match serde_json::from_value::<Event>(payload) {
Ok(event) => events.push(event),
Err(e) => {
tracing::warn!("Failed to deserialize event from DB: {}", e);
}
}
}
Ok(events)
}
async fn count_events(&self, event_type: Option<&str>) -> celers_core::Result<u64> {
self.flush_buffer().await.map_err(|e| {
celers_core::CelersError::Other(format!("DB event flush failed: {}", e))
})?;
let count: i64 = if let Some(et) = event_type {
let row = sqlx::query("SELECT COUNT(*) FROM celers_events WHERE event_type = $1")
.bind(et)
.fetch_one(&self.pool)
.await
.map_err(|e| {
celers_core::CelersError::Other(format!("Failed to count events: {}", e))
})?;
row.get(0)
} else {
let row = sqlx::query("SELECT COUNT(*) FROM celers_events")
.fetch_one(&self.pool)
.await
.map_err(|e| {
celers_core::CelersError::Other(format!("Failed to count events: {}", e))
})?;
row.get(0)
};
Ok(count as u64)
}
async fn cleanup(&self, older_than: chrono::Duration) -> celers_core::Result<u64> {
let cutoff = Utc::now().checked_sub_signed(older_than).ok_or_else(|| {
celers_core::CelersError::Other("Invalid duration for cleanup cutoff".to_string())
})?;
let result = sqlx::query("DELETE FROM celers_events WHERE timestamp < $1")
.bind(cutoff)
.execute(&self.pool)
.await
.map_err(|e| {
celers_core::CelersError::Other(format!("Failed to cleanup events: {}", e))
})?;
Ok(result.rows_affected())
}
async fn flush(&self) -> celers_core::Result<()> {
self.flush_buffer()
.await
.map_err(|e| celers_core::CelersError::Other(format!("DB event flush failed: {}", e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_db_persister_config_defaults() {
        let config = DbEventPersisterConfig::default();
        assert_eq!(config.batch_size, 100);
        assert_eq!(config.flush_interval, Duration::from_secs(5));
        assert!(config.enabled);
    }

    /// Exercises the `with_*` builders. Purely synchronous, so no async runtime
    /// is needed.
    #[test]
    fn test_db_persister_config_builders() {
        let config = DbEventPersisterConfig::default()
            .with_batch_size(1000)
            .with_enabled(true);
        assert_eq!(config.batch_size, 1000);
        assert!(config.enabled);

        let config2 = DbEventPersisterConfig::default()
            .with_enabled(false)
            .with_flush_interval(Duration::from_secs(10));
        assert!(!config2.enabled);
        assert_eq!(config2.flush_interval, Duration::from_secs(10));
    }

    /// Each builder must change only its own field.
    #[test]
    fn test_db_persister_config_builders_preserve_other_fields() {
        let sized = DbEventPersisterConfig::default().with_batch_size(7);
        assert_eq!(sized.batch_size, 7);
        assert_eq!(sized.flush_interval, Duration::from_secs(5));
        assert!(sized.enabled);

        let disabled = DbEventPersisterConfig::default().with_enabled(false);
        assert_eq!(disabled.batch_size, 100);
        assert_eq!(disabled.flush_interval, Duration::from_secs(5));
        assert!(!disabled.enabled);

        let slow = DbEventPersisterConfig::default().with_flush_interval(Duration::from_millis(1));
        assert_eq!(slow.batch_size, 100);
        assert_eq!(slow.flush_interval, Duration::from_millis(1));
        assert!(slow.enabled);
    }
}