use crate::storage::DbPool;
use async_trait::async_trait;
use sea_query::{Asterisk, Condition, Expr, Func, Query};
use serde::Deserialize;
use crate::errors::OrionError;
use crate::storage::models::{self, Trace};
use crate::storage::repositories::workflows::PaginatedResult;
use crate::storage::{build_sqlx, schema::Traces};
/// Query parameters accepted when listing traces.
///
/// Every field is optional; `TraceFilter::default()` applies no filtering and
/// leaves paging/sorting to the repository defaults.
#[derive(Default, Debug, Deserialize)]
pub struct TraceFilter {
    /// Exact-match filter on the trace status column.
    pub status: Option<String>,
    /// Exact-match filter on the channel column.
    pub channel: Option<String>,
    /// Exact-match filter on the mode column.
    pub mode: Option<String>,
    /// Requested page size; normalised by `helpers::clamp_pagination` in the SQL repository.
    pub limit: Option<i64>,
    /// Requested row offset; normalised by `helpers::clamp_pagination` in the SQL repository.
    pub offset: Option<i64>,
    /// Sort column: `updated_at`, `status`, `channel` or `mode`; anything else
    /// (including `None`) sorts by `created_at` in the SQL repository.
    pub sort_by: Option<String>,
    /// Sort direction, interpreted by `helpers::parse_sort_order`
    /// (NOTE(review): accepted spellings live in that helper — confirm there).
    pub sort_order: Option<String>,
}
/// One already-finished trace to insert via `TraceRepository::store_completed_batch`.
#[derive(Clone, Debug)]
pub struct TraceCompletedRow {
    /// Channel the trace belongs to.
    pub channel: String,
    /// Execution mode recorded for the trace.
    pub mode: String,
    /// Optional serialized input payload.
    pub input_json: Option<String>,
    /// Serialized result payload.
    pub result_json: String,
    /// Execution duration in milliseconds.
    pub duration_ms: f64,
    /// Optional serialized per-task trace.
    pub task_trace_json: Option<String>,
}
/// Result payload to attach to an existing trace via `TraceRepository::set_result_batch`.
#[derive(Clone, Debug)]
pub struct TraceResultRow {
    /// Id of the trace to update.
    pub id: String,
    /// Serialized result payload.
    pub result_json: String,
    /// Execution duration in milliseconds.
    pub duration_ms: f64,
    /// Optional serialized per-task trace.
    pub task_trace_json: Option<String>,
}
#[async_trait]
pub trait TraceRepository: Send + Sync {
async fn create_pending(
&self,
channel: &str,
mode: &str,
input_json: Option<&str>,
) -> Result<Trace, OrionError>;
async fn get_by_id(&self, id: &str) -> Result<Trace, OrionError>;
async fn update_status(
&self,
id: &str,
status: &str,
error_message: Option<&str>,
) -> Result<Trace, OrionError>;
async fn set_result(
&self,
id: &str,
result_json: &str,
duration_ms: f64,
task_trace_json: Option<&str>,
) -> Result<(), OrionError>;
async fn store_completed(
&self,
channel: &str,
mode: &str,
input_json: Option<&str>,
result_json: &str,
duration_ms: f64,
task_trace_json: Option<&str>,
) -> Result<String, OrionError>;
async fn store_completed_batch(
&self,
rows: &[TraceCompletedRow],
) -> Result<Vec<String>, OrionError> {
let mut ids = Vec::with_capacity(rows.len());
for row in rows {
ids.push(
self.store_completed(
&row.channel,
&row.mode,
row.input_json.as_deref(),
&row.result_json,
row.duration_ms,
row.task_trace_json.as_deref(),
)
.await?,
);
}
Ok(ids)
}
async fn set_result_batch(&self, rows: &[TraceResultRow]) -> Result<(), OrionError> {
for row in rows {
self.set_result(
&row.id,
&row.result_json,
row.duration_ms,
row.task_trace_json.as_deref(),
)
.await?;
}
Ok(())
}
async fn list_paginated(
&self,
filter: &TraceFilter,
) -> Result<PaginatedResult<Trace>, OrionError>;
async fn delete_older_than(&self, hours: u64) -> Result<u64, OrionError>;
}
/// SQL-backed [`TraceRepository`] that runs every query through a shared [`DbPool`].
pub struct SqlTraceRepository {
    /// Pool used for all trace queries and transactions.
    pool: DbPool,
}

impl SqlTraceRepository {
    /// Creates a repository that issues its queries against `pool`.
    pub fn new(pool: DbPool) -> Self {
        SqlTraceRepository { pool }
    }
}
#[async_trait]
impl TraceRepository for SqlTraceRepository {
async fn create_pending(
&self,
channel: &str,
mode: &str,
input_json: Option<&str>,
) -> Result<Trace, OrionError> {
crate::metrics::timed_db_op("traces.create_pending", async {
let id = uuid::Uuid::new_v4().to_string();
let input_val = super::helpers::optional_string_value(input_json);
let (sql, values) = build_sqlx(
Query::insert()
.into_table(Traces::Table)
.columns([
Traces::Id,
Traces::Status,
Traces::Channel,
Traces::Mode,
Traces::InputJson,
])
.values_panic([
Expr::val(id.as_str()).into(),
Expr::val("pending").into(),
Expr::val(channel).into(),
Expr::val(mode).into(),
Expr::val(input_val).into(),
]),
);
self.pool.execute_query(&sql, values).await?;
self.get_by_id(&id).await
})
.await
}
async fn get_by_id(&self, id: &str) -> Result<Trace, OrionError> {
crate::metrics::timed_db_op("traces.get_by_id", async {
let (sql, values) = build_sqlx(
Query::select()
.column(Asterisk)
.from(Traces::Table)
.and_where(Expr::col(Traces::Id).eq(id)),
);
self.pool
.fetch_optional_as::<Trace>(&sql, values)
.await?
.ok_or_else(|| OrionError::NotFound(format!("Trace '{id}' not found")))
})
.await
}
async fn update_status(
&self,
id: &str,
status: &str,
error_message: Option<&str>,
) -> Result<Trace, OrionError> {
crate::metrics::timed_db_op("traces.update_status", async {
let now = chrono::Utc::now().naive_utc().to_string();
let (started_at, completed_at) = if status == models::TRACE_STATUS_RUNNING {
(Some(now), None)
} else if status == models::TRACE_STATUS_COMPLETED
|| status == models::TRACE_STATUS_FAILED
{
(None, Some(now))
} else {
(None, None)
};
let mut update = Query::update();
update.table(Traces::Table).value(Traces::Status, status);
if let Some(err) = error_message {
update.value(Traces::ErrorMessage, err);
}
if let Some(ref sa) = started_at {
update.value(Traces::StartedAt, sa.as_str());
}
if let Some(ref ca) = completed_at {
update.value(Traces::CompletedAt, ca.as_str());
}
update.and_where(Expr::col(Traces::Id).eq(id));
let (sql, values) = build_sqlx(&mut update);
self.pool.execute_query(&sql, values).await?;
self.get_by_id(id).await
})
.await
}
async fn set_result(
&self,
id: &str,
result_json: &str,
duration_ms: f64,
task_trace_json: Option<&str>,
) -> Result<(), OrionError> {
crate::metrics::timed_db_op("traces.set_result", async {
let task_trace_val = super::helpers::optional_string_value(task_trace_json);
let (sql, values) = build_sqlx(
Query::update()
.table(Traces::Table)
.value(Traces::ResultJson, result_json)
.value(Traces::DurationMs, duration_ms)
.value(Traces::TaskTraceJson, task_trace_val)
.and_where(Expr::col(Traces::Id).eq(id)),
);
self.pool.execute_query(&sql, values).await?;
Ok(())
})
.await
}
async fn store_completed(
&self,
channel: &str,
mode: &str,
input_json: Option<&str>,
result_json: &str,
duration_ms: f64,
task_trace_json: Option<&str>,
) -> Result<String, OrionError> {
crate::metrics::timed_db_op("traces.store_completed", async {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().naive_utc().to_string();
let input_val = super::helpers::optional_string_value(input_json);
let task_trace_val = super::helpers::optional_string_value(task_trace_json);
let (sql, values) = build_sqlx(
Query::insert()
.into_table(Traces::Table)
.columns([
Traces::Id,
Traces::Status,
Traces::Channel,
Traces::Mode,
Traces::InputJson,
Traces::ResultJson,
Traces::DurationMs,
Traces::StartedAt,
Traces::CompletedAt,
Traces::TaskTraceJson,
])
.values_panic([
Expr::val(id.as_str()).into(),
Expr::val("completed").into(),
Expr::val(channel).into(),
Expr::val(mode).into(),
Expr::val(input_val).into(),
Expr::val(result_json).into(),
Expr::val(duration_ms).into(),
Expr::val(now.as_str()).into(),
Expr::val(now.as_str()).into(),
Expr::val(task_trace_val).into(),
]),
);
self.pool.execute_query(&sql, values).await?;
Ok(id)
})
.await
}
async fn store_completed_batch(
&self,
rows: &[TraceCompletedRow],
) -> Result<Vec<String>, OrionError> {
if rows.is_empty() {
return Ok(Vec::new());
}
crate::metrics::timed_db_op("traces.store_completed_batch", async {
let now = chrono::Utc::now().naive_utc().to_string();
let mut ids = Vec::with_capacity(rows.len());
let mut insert = Query::insert();
insert.into_table(Traces::Table).columns([
Traces::Id,
Traces::Status,
Traces::Channel,
Traces::Mode,
Traces::InputJson,
Traces::ResultJson,
Traces::DurationMs,
Traces::StartedAt,
Traces::CompletedAt,
Traces::TaskTraceJson,
]);
for row in rows {
let id = uuid::Uuid::new_v4().to_string();
let input_val = super::helpers::optional_string_value(row.input_json.as_deref());
let task_trace_val =
super::helpers::optional_string_value(row.task_trace_json.as_deref());
insert.values_panic([
Expr::val(id.as_str()).into(),
Expr::val("completed").into(),
Expr::val(row.channel.as_str()).into(),
Expr::val(row.mode.as_str()).into(),
Expr::val(input_val).into(),
Expr::val(row.result_json.as_str()).into(),
Expr::val(row.duration_ms).into(),
Expr::val(now.as_str()).into(),
Expr::val(now.as_str()).into(),
Expr::val(task_trace_val).into(),
]);
ids.push(id);
}
let (sql, values) = build_sqlx(&mut insert);
self.pool.execute_query(&sql, values).await?;
crate::metrics::record_trace_persistence_batch_size(rows.len());
Ok(ids)
})
.await
}
async fn set_result_batch(&self, rows: &[TraceResultRow]) -> Result<(), OrionError> {
if rows.is_empty() {
return Ok(());
}
crate::metrics::timed_db_op("traces.set_result_batch", async {
let mut tx = self.pool.begin_tx().await.map_err(OrionError::Storage)?;
for row in rows {
let task_trace_val =
super::helpers::optional_string_value(row.task_trace_json.as_deref());
let (sql, values) = build_sqlx(
Query::update()
.table(Traces::Table)
.value(Traces::ResultJson, row.result_json.as_str())
.value(Traces::DurationMs, row.duration_ms)
.value(Traces::TaskTraceJson, task_trace_val)
.and_where(Expr::col(Traces::Id).eq(row.id.as_str())),
);
tx.execute_query(&sql, values).await?;
}
tx.commit().await.map_err(OrionError::Storage)?;
crate::metrics::record_trace_persistence_batch_size(rows.len());
Ok(())
})
.await
}
async fn list_paginated(
&self,
filter: &TraceFilter,
) -> Result<PaginatedResult<Trace>, OrionError> {
crate::metrics::timed_db_op("traces.list_paginated", async {
let (limit, offset) = super::helpers::clamp_pagination(filter.limit, filter.offset);
let mut cond = Condition::all();
if let Some(ref status) = filter.status {
cond = cond.add(Expr::col(Traces::Status).eq(status.as_str()));
}
if let Some(ref channel) = filter.channel {
cond = cond.add(Expr::col(Traces::Channel).eq(channel.as_str()));
}
if let Some(ref mode) = filter.mode {
cond = cond.add(Expr::col(Traces::Mode).eq(mode.as_str()));
}
let (sql, values) = build_sqlx(
Query::select()
.expr(Func::count(Expr::col(Asterisk)))
.from(Traces::Table)
.cond_where(cond.clone()),
);
let (total,): (i64,) = self.pool.fetch_one_as::<(i64,)>(&sql, values).await?;
let sort_iden = match filter.sort_by.as_deref() {
Some("updated_at") => Traces::UpdatedAt,
Some("status") => Traces::Status,
Some("channel") => Traces::Channel,
Some("mode") => Traces::Mode,
_ => Traces::CreatedAt,
};
let order = super::helpers::parse_sort_order(filter.sort_order.as_deref());
let (sql, values) = build_sqlx(
Query::select()
.column(Asterisk)
.from(Traces::Table)
.cond_where(cond)
.order_by(sort_iden, order)
.limit(limit as u64)
.offset(offset as u64),
);
let data = self.pool.fetch_all_as::<Trace>(&sql, values).await?;
Ok(PaginatedResult {
data,
total,
limit,
offset,
})
})
.await
}
async fn delete_older_than(&self, hours: u64) -> Result<u64, OrionError> {
crate::metrics::timed_db_op("traces.delete_older_than", async {
let cutoff = chrono::Utc::now()
.naive_utc()
.checked_sub_signed(chrono::Duration::hours(hours as i64))
.unwrap_or(chrono::NaiveDateTime::MIN)
.to_string();
let (sql, values) = build_sqlx(
Query::delete()
.from_table(Traces::Table)
.and_where(Expr::col(Traces::CreatedAt).lt(&cutoff))
.and_where(Expr::col(Traces::Status).is_in(["completed", "failed"])),
);
let rows_affected = self.pool.execute_query(&sql, values).await?;
Ok(rows_affected)
})
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a single-connection in-memory SQLite pool for one test.
    async fn test_pool() -> crate::storage::DbPool {
        let config = crate::config::StorageConfig {
            url: "sqlite::memory:".to_string(),
            max_connections: 1,
            ..Default::default()
        };
        crate::storage::init_pool(&config).await.expect("test")
    }

    /// Rewrites `created_at` of trace `id` to `hours` hours before now with raw
    /// SQL, bypassing the repository, so that age-based cleanup can be tested.
    async fn backdate_created_at(pool: &crate::storage::DbPool, id: &str, hours: i64) {
        let old_time = chrono::Utc::now()
            .naive_utc()
            .checked_sub_signed(chrono::Duration::hours(hours))
            .expect("test")
            .to_string();
        match pool {
            crate::storage::DbPool::Sqlite(p) => {
                sqlx::query("UPDATE traces SET created_at = ? WHERE id = ?")
                    .bind(&old_time)
                    .bind(id)
                    .execute(p)
                    .await
                    .expect("test");
            }
            _ => unreachable!("Test requires SQLite"),
        }
    }

    #[tokio::test]
    async fn test_delete_older_than_removes_old_completed_traces() {
        let pool = test_pool().await;
        let repo = SqlTraceRepository::new(pool.clone());

        let old_id = repo
            .store_completed("orders", "sync", None, r#"{"ok":true}"#, 10.0, None)
            .await
            .expect("test");
        backdate_created_at(&pool, &old_id, 100).await;

        repo.store_completed("orders", "sync", None, r#"{"ok":true}"#, 5.0, None)
            .await
            .expect("test");

        let deleted = repo.delete_older_than(72).await.expect("test");
        assert_eq!(deleted, 1);

        let remaining = repo
            .list_paginated(&TraceFilter::default())
            .await
            .expect("test");
        assert_eq!(remaining.total, 1);
    }

    #[tokio::test]
    async fn test_delete_older_than_preserves_pending_traces() {
        let pool = test_pool().await;
        let repo = SqlTraceRepository::new(pool.clone());

        let pending = repo
            .create_pending("orders", "async", None)
            .await
            .expect("test");
        backdate_created_at(&pool, &pending.id, 200).await;

        let deleted = repo.delete_older_than(72).await.expect("test");
        assert_eq!(deleted, 0);
    }
}