use anyhow::anyhow;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sqlx::Row;
use uuid::Uuid;
use crate::error::{FusilladeError, Result};
use crate::request::RequestId;
use crate::response_step::{
CreateStepInput, ResponseStep, ResponseStepStore, StepId, StepKind, StepState,
};
pub use sqlx_pool_router::PoolProvider;
pub struct PostgresResponseStepManager<P: PoolProvider> {
pools: P,
}
impl<P: PoolProvider> PostgresResponseStepManager<P> {
pub fn new(pools: P) -> Self {
Self { pools }
}
}
fn step_from_row(row: &sqlx::postgres::PgRow) -> Result<ResponseStep> {
let kind_str: String = row.get("step_kind");
let kind = StepKind::parse(&kind_str).ok_or_else(|| {
FusilladeError::Other(anyhow!("Unknown step_kind in response_steps: {}", kind_str))
})?;
let state_str: String = row.get("state");
let state = StepState::parse(&state_str).ok_or_else(|| {
FusilladeError::Other(anyhow!("Unknown state in response_steps: {}", state_str))
})?;
Ok(ResponseStep {
id: StepId(row.get("id")),
request_id: RequestId(row.get("request_id")),
prev_step_id: row.get::<Option<Uuid>, _>("prev_step_id").map(StepId),
parent_step_id: row.get::<Option<Uuid>, _>("parent_step_id").map(StepId),
step_kind: kind,
step_sequence: row.get("step_sequence"),
request_payload: row.get("request_payload"),
response_payload: row.get::<Option<serde_json::Value>, _>("response_payload"),
state,
started_at: row.get::<Option<DateTime<Utc>>, _>("started_at"),
completed_at: row.get::<Option<DateTime<Utc>>, _>("completed_at"),
failed_at: row.get::<Option<DateTime<Utc>>, _>("failed_at"),
canceled_at: row.get::<Option<DateTime<Utc>>, _>("canceled_at"),
retry_attempt: row.get("retry_attempt"),
error: row.get::<Option<serde_json::Value>, _>("error"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
})
}
const STEP_COLUMNS: &str = "id, request_id, prev_step_id, parent_step_id, step_kind, step_sequence, \
request_payload, response_payload, state, started_at, completed_at, failed_at, \
canceled_at, retry_attempt, error, created_at, updated_at";
async fn fetch_state(pool: &sqlx::PgPool, id: StepId) -> Result<Option<String>> {
sqlx::query("SELECT state FROM response_steps WHERE id = $1")
.bind(id.0)
.fetch_optional(pool)
.await
.map(|opt| opt.map(|row| row.get::<String, _>("state")))
.map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch step state: {}", e)))
}
#[async_trait]
impl<P: PoolProvider> ResponseStepStore for PostgresResponseStepManager<P> {
async fn create_step(&self, input: CreateStepInput) -> Result<StepId> {
let pool = self.pools.write();
let id = input.id.unwrap_or_else(Uuid::new_v4);
sqlx::query(
"INSERT INTO response_steps \
(id, request_id, prev_step_id, parent_step_id, step_kind, step_sequence, request_payload) \
VALUES ($1, $2, $3, $4, $5, $6, $7)",
)
.bind(id)
.bind(input.request_id.0)
.bind(input.prev_step_id.map(|s| s.0))
.bind(input.parent_step_id.map(|s| s.0))
.bind(input.step_kind.as_str())
.bind(input.step_sequence)
.bind(&input.request_payload)
.execute(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to insert response_step: {}", e)))?;
Ok(StepId(id))
}
async fn get_step(&self, id: StepId) -> Result<Option<ResponseStep>> {
let pool = self.pools.write();
let query = format!("SELECT {} FROM response_steps WHERE id = $1", STEP_COLUMNS);
let row = sqlx::query(&query)
.bind(id.0)
.fetch_optional(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to fetch response_step: {}", e)))?;
row.as_ref().map(step_from_row).transpose()
}
async fn list_chain(&self, request_id: RequestId) -> Result<Vec<ResponseStep>> {
let pool = self.pools.write();
let query = format!(
"SELECT {} FROM response_steps WHERE request_id = $1 ORDER BY step_sequence ASC",
STEP_COLUMNS
);
let rows = sqlx::query(&query)
.bind(request_id.0)
.fetch_all(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to list response_steps: {}", e)))?;
rows.iter().map(step_from_row).collect()
}
async fn list_scope(
&self,
request_id: RequestId,
scope_parent: Option<StepId>,
) -> Result<Vec<ResponseStep>> {
let pool = self.pools.write();
let query = format!(
"SELECT {} FROM response_steps \
WHERE request_id = $1 AND parent_step_id IS NOT DISTINCT FROM $2 \
ORDER BY step_sequence ASC",
STEP_COLUMNS
);
let rows = sqlx::query(&query)
.bind(request_id.0)
.bind(scope_parent.map(|s| s.0))
.fetch_all(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to list response_steps: {}", e)))?;
rows.iter().map(step_from_row).collect()
}
async fn mark_step_processing(&self, id: StepId) -> Result<()> {
let pool = self.pools.write();
let result = sqlx::query(
"UPDATE response_steps \
SET state = 'processing', started_at = NOW(), updated_at = NOW() \
WHERE id = $1 AND state = 'pending'",
)
.bind(id.0)
.execute(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to mark step as processing: {}", e)))?;
if result.rows_affected() == 0 {
if fetch_state(pool, id).await?.is_none() {
return Err(FusilladeError::Other(anyhow!(
"response_step not found: {}",
id
)));
}
}
Ok(())
}
async fn complete_step(&self, id: StepId, response: serde_json::Value) -> Result<()> {
let pool = self.pools.write();
let result = sqlx::query(
"UPDATE response_steps \
SET state = 'completed', \
response_payload = $2, \
completed_at = NOW(), \
updated_at = NOW() \
WHERE id = $1 AND state IN ('pending', 'processing')",
)
.bind(id.0)
.bind(&response)
.execute(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to complete response_step: {}", e)))?;
if result.rows_affected() == 0 {
return Err(match fetch_state(pool, id).await? {
Some(state) => FusilladeError::Other(anyhow!(
"response_step {} not in completable state (current: {})",
id,
state
)),
None => FusilladeError::Other(anyhow!("response_step not found: {}", id)),
});
}
Ok(())
}
async fn fail_step(&self, id: StepId, error: serde_json::Value) -> Result<()> {
let pool = self.pools.write();
let result = sqlx::query(
"UPDATE response_steps \
SET state = 'failed', \
error = $2, \
failed_at = NOW(), \
updated_at = NOW() \
WHERE id = $1 AND state IN ('pending', 'processing')",
)
.bind(id.0)
.bind(&error)
.execute(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to fail response_step: {}", e)))?;
if result.rows_affected() == 0 {
return Err(match fetch_state(pool, id).await? {
Some(state) => FusilladeError::Other(anyhow!(
"response_step {} not in failable state (current: {})",
id,
state
)),
None => FusilladeError::Other(anyhow!("response_step not found: {}", id)),
});
}
Ok(())
}
async fn cancel_step(&self, id: StepId) -> Result<()> {
let pool = self.pools.write();
let result = sqlx::query(
"UPDATE response_steps \
SET state = 'canceled', \
canceled_at = NOW(), \
updated_at = NOW() \
WHERE id = $1 AND state IN ('pending', 'processing')",
)
.bind(id.0)
.execute(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to cancel response_step: {}", e)))?;
if result.rows_affected() == 0 {
return Err(match fetch_state(pool, id).await? {
Some(state) => FusilladeError::Other(anyhow!(
"response_step {} not in cancelable state (current: {})",
id,
state
)),
None => FusilladeError::Other(anyhow!("response_step not found: {}", id)),
});
}
Ok(())
}
async fn requeue_step_for_retry(&self, id: StepId) -> Result<()> {
let pool = self.pools.write();
let result = sqlx::query(
"UPDATE response_steps \
SET state = 'pending', \
retry_attempt = retry_attempt + 1, \
started_at = NULL, \
updated_at = NOW() \
WHERE id = $1 AND state = 'processing'",
)
.bind(id.0)
.execute(pool)
.await
.map_err(|e| FusilladeError::Other(anyhow!("Failed to requeue response_step: {}", e)))?;
if result.rows_affected() == 0 {
return Err(match fetch_state(pool, id).await? {
Some(state) => FusilladeError::Other(anyhow!(
"response_step {} not in retryable state (current: {})",
id,
state
)),
None => FusilladeError::Other(anyhow!("response_step not found: {}", id)),
});
}
Ok(())
}
}