Skip to main content

ranvier_runtime/
persistence.rs

1use anyhow::{Result, anyhow};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Minimal persisted envelope for Axon execution checkpoints.
///
/// M148 baseline contract fields:
/// - trace
/// - circuit
/// - step
/// - outcome
/// - timestamp
/// - payload hash
///
/// Every store in this module appends envelopes per `trace_id` and returns
/// them ordered by ascending `step`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistenceEnvelope {
    /// Identifier of the trace this event belongs to.
    pub trace_id: String,
    /// Circuit name; stores reject it if it differs from the circuit already
    /// recorded for `trace_id`.
    pub circuit: String,
    /// Ordering key of the event within its trace. The PostgreSQL adapter
    /// rejects values above `i64::MAX`.
    pub step: u64,
    /// Free-form outcome label (e.g. `"Next"`, `"Branch"`, `"Fault"`).
    pub outcome_kind: String,
    /// Event timestamp; presumably Unix epoch milliseconds (not enforced
    /// here). The PostgreSQL adapter rejects values above `i64::MAX`.
    pub timestamp_ms: u64,
    /// Optional hash of the step payload; the hashing scheme is not defined
    /// in this module.
    pub payload_hash: Option<String>,
}
26
/// Final completion state tracked for a persisted trace.
///
/// Once a trace carries any completion state, every store in this module
/// rejects further `append` calls for it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompletionState {
    /// Stored as `"success"` by the PostgreSQL adapter.
    Success,
    /// Stored as `"fault"` by the PostgreSQL adapter.
    Fault,
    /// Stored as `"cancelled"` by the PostgreSQL adapter.
    Cancelled,
    /// Stored as `"compensated"` by the PostgreSQL adapter.
    Compensated,
}
35
/// Encode a [`CompletionState`] as the lowercase text stored in the
/// PostgreSQL `completion` column.
///
/// Inverse of `completion_state_from_wire`.
#[cfg(feature = "persistence-postgres")]
fn completion_state_to_wire(state: &CompletionState) -> &'static str {
    use CompletionState as State;
    match *state {
        State::Success => "success",
        State::Fault => "fault",
        State::Cancelled => "cancelled",
        State::Compensated => "compensated",
    }
}
45
/// Decode a PostgreSQL `completion` column value into a [`CompletionState`].
///
/// # Errors
/// Returns an error for any text not produced by `completion_state_to_wire`.
#[cfg(feature = "persistence-postgres")]
fn completion_state_from_wire(value: &str) -> Result<CompletionState> {
    let state = match value {
        "success" => CompletionState::Success,
        "fault" => CompletionState::Fault,
        "cancelled" => CompletionState::Cancelled,
        "compensated" => CompletionState::Compensated,
        unknown => return Err(anyhow!("unknown completion state value: {}", unknown)),
    };
    Ok(state)
}
56
/// Stored trace state returned from [`PersistenceStore::load`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistedTrace {
    /// Trace identifier (the key the trace is stored under).
    pub trace_id: String,
    /// Circuit taken from the first envelope appended for this trace.
    pub circuit: String,
    /// Appended events, ordered by ascending `step`.
    pub events: Vec<PersistenceEnvelope>,
    /// Value passed to the most recent `resume` call, if any.
    pub resumed_from_step: Option<u64>,
    /// Completion state set by `complete`; `None` while the trace still
    /// accepts new events.
    pub completion: Option<CompletionState>,
}
66
/// Resume cursor returned from [`PersistenceStore::resume`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResumeCursor {
    /// Trace being resumed.
    pub trace_id: String,
    /// Step to execute next. Every adapter in this module sets it to
    /// `resume_from_step + 1` (saturating at `u64::MAX`).
    pub next_step: u64,
}
73
/// Optional trace identifier override for persistence hooks.
///
/// Insert into `Bus` when a stable trace identity is required across process restarts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistenceTraceId(pub String);

impl PersistenceTraceId {
    /// Wrap any string-like value as a persistence trace identifier.
    pub fn new(value: impl Into<String>) -> Self {
        let id: String = value.into();
        PersistenceTraceId(id)
    }
}
85
/// Controls whether runtime execution should call `complete` automatically.
///
/// Default runtime behavior when this resource is absent: `true`.
/// The wrapped flag is the setting itself; the runtime code that reads it is
/// outside this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PersistenceAutoComplete(pub bool);
91
/// Runtime context delivered to compensation hooks.
///
/// The context is intentionally compact so hooks can map it to idempotent
/// compensating actions in domain/infrastructure layers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompensationContext {
    /// Trace that triggered compensation.
    pub trace_id: String,
    /// Circuit the trace belongs to.
    pub circuit: String,
    /// Fault label; presumably mirrors `PersistenceEnvelope::outcome_kind`
    /// of the faulting event (TODO confirm against the runtime caller).
    pub fault_kind: String,
    /// Step at which the fault was observed (populated by the runtime
    /// caller; NOTE(review): confirm semantics there).
    pub fault_step: u64,
    /// Timestamp of the fault; presumably Unix epoch milliseconds.
    pub timestamp_ms: u64,
}
104
/// Controls whether compensation hooks should run automatically on `Fault`.
///
/// Default runtime behavior when this resource is absent: `true`.
/// The wrapped flag is the setting itself; the runtime code that reads it is
/// outside this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationAutoTrigger(pub bool);
110
/// Retry policy for compensation hook execution.
///
/// Defaults to a single attempt (no retry).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationRetryPolicy {
    /// Total number of attempts, including the first one (default `1`).
    pub max_attempts: u32,
    /// Backoff between attempts; the `_ms` suffix denotes milliseconds.
    /// How it is applied is up to the runtime caller.
    pub backoff_ms: u64,
}

impl Default for CompensationRetryPolicy {
    /// One attempt and no backoff, i.e. retries are disabled.
    fn default() -> Self {
        const SINGLE_ATTEMPT: u32 = 1;
        CompensationRetryPolicy {
            backoff_ms: 0,
            max_attempts: SINGLE_ATTEMPT,
        }
    }
}
128
/// Compensation hook contract for irreversible side effects.
///
/// Receives a [`CompensationContext`] describing the faulted trace. When and
/// how often it is invoked is decided by the runtime, not by this module.
/// Implementations should presumably be idempotent (the context docs point
/// that way; TODO confirm against the runtime caller).
#[async_trait]
pub trait CompensationHook: Send + Sync {
    /// Run the compensating action for `context`; an `Err` reports failure.
    async fn compensate(&self, context: CompensationContext) -> Result<()>;
}
134
135/// Bus-insertable compensation hook handle used by runtime execution hooks.
136#[derive(Clone)]
137pub struct CompensationHandle {
138    inner: Arc<dyn CompensationHook>,
139}
140
141impl std::fmt::Debug for CompensationHandle {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        f.debug_struct("CompensationHandle").finish_non_exhaustive()
144    }
145}
146
147impl CompensationHandle {
148    /// Create a handle from a concrete compensation hook implementation.
149    pub fn from_hook<H>(hook: H) -> Self
150    where
151        H: CompensationHook + 'static,
152    {
153        Self {
154            inner: Arc::new(hook),
155        }
156    }
157
158    /// Create a handle from an existing trait-object Arc.
159    pub fn from_arc(hook: Arc<dyn CompensationHook>) -> Self {
160        Self { inner: hook }
161    }
162
163    /// Access the shared compensation hook.
164    pub fn hook(&self) -> Arc<dyn CompensationHook> {
165        self.inner.clone()
166    }
167}
168
/// Idempotency store contract for compensation execution deduplication.
///
/// Adapters in this module treat `key` as an opaque string; its format is
/// chosen by the caller.
#[async_trait]
pub trait CompensationIdempotencyStore: Send + Sync {
    /// Return `true` if `key` was previously recorded via `mark_compensated`
    /// (and, for stores with expiry, has not expired yet).
    async fn was_compensated(&self, key: &str) -> Result<bool>;
    /// Record `key` as compensated. No adapter in this module treats marking
    /// an already-recorded key as an error.
    async fn mark_compensated(&self, key: &str) -> Result<()>;
}
175
176/// Bus-insertable idempotency handle for compensation hooks.
177#[derive(Clone)]
178pub struct CompensationIdempotencyHandle {
179    inner: Arc<dyn CompensationIdempotencyStore>,
180}
181
182impl std::fmt::Debug for CompensationIdempotencyHandle {
183    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        f.debug_struct("CompensationIdempotencyHandle")
185            .finish_non_exhaustive()
186    }
187}
188
189impl CompensationIdempotencyHandle {
190    pub fn from_store<S>(store: S) -> Self
191    where
192        S: CompensationIdempotencyStore + 'static,
193    {
194        Self {
195            inner: Arc::new(store),
196        }
197    }
198
199    pub fn from_arc(store: Arc<dyn CompensationIdempotencyStore>) -> Self {
200        Self { inner: store }
201    }
202
203    pub fn store(&self) -> Arc<dyn CompensationIdempotencyStore> {
204        self.inner.clone()
205    }
206}
207
208/// In-memory idempotency store for compensation deduplication.
209#[derive(Debug, Default, Clone)]
210pub struct InMemoryCompensationIdempotencyStore {
211    keys: Arc<RwLock<HashSet<String>>>,
212}
213
214impl InMemoryCompensationIdempotencyStore {
215    pub fn new() -> Self {
216        Self::default()
217    }
218}
219
220#[async_trait]
221impl CompensationIdempotencyStore for InMemoryCompensationIdempotencyStore {
222    async fn was_compensated(&self, key: &str) -> Result<bool> {
223        let guard = self.keys.read().await;
224        Ok(guard.contains(key))
225    }
226
227    async fn mark_compensated(&self, key: &str) -> Result<()> {
228        let mut guard = self.keys.write().await;
229        guard.insert(key.to_string());
230        Ok(())
231    }
232}
233
/// PostgreSQL-backed compensation idempotency store.
///
/// Keys live in one table named `{prefix}_compensation_idempotency`
/// (default prefix `ranvier_persistence`), one row per key plus the wall-clock
/// time in milliseconds at which it was first marked.
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresCompensationIdempotencyStore {
    pool: sqlx::Pool<sqlx::Postgres>,
    table: String,
}

#[cfg(feature = "persistence-postgres")]
impl PostgresCompensationIdempotencyStore {
    /// Create a PostgreSQL-backed compensation idempotency store using the
    /// default `ranvier_persistence` table prefix.
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        Self::with_table_prefix(pool, "ranvier_persistence")
    }

    /// Create with custom table prefix.
    ///
    /// The derived table name is interpolated directly into SQL text, so
    /// `prefix` must come from trusted configuration, never from user input.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        let table = format!("{}_compensation_idempotency", prefix);
        Self { pool, table }
    }

    /// Initialize adapter table when absent.
    pub async fn ensure_schema(&self) -> Result<()> {
        let ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                idempotency_key TEXT PRIMARY KEY,
                created_at_ms BIGINT NOT NULL
            )",
            self.table
        );
        sqlx::query(&ddl).execute(&self.pool).await?;
        Ok(())
    }

    /// Remove stale idempotency rows older than `cutoff_ms` (unix epoch ms).
    ///
    /// Returns the number of deleted rows.
    pub async fn purge_older_than_ms(&self, cutoff_ms: i64) -> Result<u64> {
        let sql = format!(
            "DELETE FROM {}
             WHERE created_at_ms < $1",
            self.table
        );
        let outcome = sqlx::query(&sql)
            .bind(cutoff_ms)
            .execute(&self.pool)
            .await?;
        Ok(outcome.rows_affected())
    }
}

#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl CompensationIdempotencyStore for PostgresCompensationIdempotencyStore {
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        let sql = format!(
            "SELECT 1
             FROM {}
             WHERE idempotency_key = $1
             LIMIT 1",
            self.table
        );
        // `SELECT 1` yields an INT4 (decoded as `i32`); only its presence matters.
        let hit: Option<i32> = sqlx::query_scalar(&sql)
            .bind(key)
            .fetch_optional(&self.pool)
            .await?;
        Ok(hit.is_some())
    }

    async fn mark_compensated(&self, key: &str) -> Result<()> {
        let sql = format!(
            "INSERT INTO {} (idempotency_key, created_at_ms)
             VALUES ($1, $2)
             ON CONFLICT (idempotency_key) DO NOTHING",
            self.table
        );
        // Column is BIGINT, so the u128 millisecond count is narrowed with a checked conversion.
        let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
        let created_at_ms = i64::try_from(since_epoch.as_millis())?;
        sqlx::query(&sql)
            .bind(key)
            .bind(created_at_ms)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}
322
/// Redis-backed compensation idempotency store.
///
/// Each idempotency key is stored as `"{key_prefix}:{key}"` with value `"1"`,
/// optionally with a TTL so entries expire on their own.
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisCompensationIdempotencyStore {
    manager: redis::aio::ConnectionManager,
    key_prefix: String,
    ttl_seconds: Option<u64>,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisCompensationIdempotencyStore {
    // The connection manager is omitted; only configuration is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisCompensationIdempotencyStore")
            .field("key_prefix", &self.key_prefix)
            .field("ttl_seconds", &self.ttl_seconds)
            .finish_non_exhaustive()
    }
}

#[cfg(feature = "persistence-redis")]
impl RedisCompensationIdempotencyStore {
    /// Connect using Redis connection URL.
    ///
    /// Uses the default `ranvier:compensation:idempotency` prefix and no TTL.
    pub async fn connect(url: &str) -> Result<Self> {
        let client = redis::Client::open(url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self {
            manager,
            key_prefix: "ranvier:compensation:idempotency".to_string(),
            ttl_seconds: None,
        })
    }

    /// Build from an existing connection manager with a custom key prefix and
    /// no TTL (marked keys never expire).
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        Self {
            manager,
            key_prefix: key_prefix.into(),
            ttl_seconds: None,
        }
    }

    /// Build with a custom key prefix; marked keys expire after `ttl_seconds`.
    ///
    /// `ttl_seconds` must be greater than zero. Redis rejects `SET ... EX 0`,
    /// so a zero TTL makes `mark_compensated` return an error.
    pub fn with_prefix_and_ttl(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
        ttl_seconds: u64,
    ) -> Self {
        Self {
            manager,
            key_prefix: key_prefix.into(),
            ttl_seconds: Some(ttl_seconds),
        }
    }

    /// Full Redis key for an idempotency key: `"{key_prefix}:{idempotency_key}"`.
    fn key(&self, idempotency_key: &str) -> String {
        format!("{}:{}", self.key_prefix, idempotency_key)
    }
}

#[cfg(feature = "persistence-redis")]
#[async_trait]
impl CompensationIdempotencyStore for RedisCompensationIdempotencyStore {
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let exists: bool = conn.exists(self.key(key)).await?;
        Ok(exists)
    }

    /// Record `key` if it is not already present.
    ///
    /// With a TTL configured, this issues a single `SET key 1 NX EX ttl`, so the
    /// key and its expiry are written atomically. A separate `SETNX` + `EXPIRE`
    /// pair could leave a key without any TTL (never expiring) if the second
    /// command failed or the process died in between. An existing key is left
    /// untouched, including its remaining TTL.
    async fn mark_compensated(&self, key: &str) -> Result<()> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let redis_key = self.key(key);
        match self.ttl_seconds {
            Some(ttl_seconds) => {
                // Reply is `OK` when the key was created and nil when it already
                // existed; both count as success for idempotency purposes.
                let _: redis::Value = redis::cmd("SET")
                    .arg(redis_key.as_str())
                    .arg("1")
                    .arg("NX")
                    .arg("EX")
                    .arg(ttl_seconds)
                    .query_async(&mut conn)
                    .await?;
            }
            None => {
                let _: bool = conn.set_nx(&redis_key, "1").await?;
            }
        }
        Ok(())
    }
}
406
/// Persistence abstraction draft for long-running workflow recovery.
///
/// This is intentionally minimal and marked experimental while M148 is active.
///
/// Behavior shared by the adapters in this module (in-memory, PostgreSQL, Redis):
/// - `append` creates the trace on first use. It rejects envelopes whose
///   circuit differs from the trace's recorded circuit, and envelopes for a
///   trace that is already completed.
/// - `load` returns `Ok(None)` for unknown traces.
/// - `resume` and `complete` return an error for unknown traces.
#[async_trait]
pub trait PersistenceStore: Send + Sync {
    /// Record one checkpoint event for `envelope.trace_id`.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()>;
    /// Load the stored state for `trace_id`, with events ordered by step.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>>;
    /// Record `resume_from_step` on the trace and return the cursor for the next step.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor>;
    /// Set the trace's completion state; subsequent `append` calls are rejected.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()>;
}
417
418/// Bus-insertable persistence handle used by runtime execution hooks.
419#[derive(Clone)]
420pub struct PersistenceHandle {
421    inner: Arc<dyn PersistenceStore>,
422}
423
424impl std::fmt::Debug for PersistenceHandle {
425    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426        f.debug_struct("PersistenceHandle").finish_non_exhaustive()
427    }
428}
429
430impl PersistenceHandle {
431    /// Create a handle from a concrete store implementation.
432    pub fn from_store<S>(store: S) -> Self
433    where
434        S: PersistenceStore + 'static,
435    {
436        Self {
437            inner: Arc::new(store),
438        }
439    }
440
441    /// Create a handle from an existing trait-object Arc.
442    pub fn from_arc(store: Arc<dyn PersistenceStore>) -> Self {
443        Self { inner: store }
444    }
445
446    /// Access the shared persistence store.
447    pub fn store(&self) -> Arc<dyn PersistenceStore> {
448        self.inner.clone()
449    }
450}
451
452/// In-memory reference adapter for local testing and contract validation.
453#[derive(Debug, Default, Clone)]
454pub struct InMemoryPersistenceStore {
455    inner: Arc<RwLock<HashMap<String, PersistedTrace>>>,
456}
457
458impl InMemoryPersistenceStore {
459    pub fn new() -> Self {
460        Self::default()
461    }
462}
463
464#[async_trait]
465impl PersistenceStore for InMemoryPersistenceStore {
466    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
467        let mut guard = self.inner.write().await;
468        let entry = guard
469            .entry(envelope.trace_id.clone())
470            .or_insert_with(|| PersistedTrace {
471                trace_id: envelope.trace_id.clone(),
472                circuit: envelope.circuit.clone(),
473                events: Vec::new(),
474                resumed_from_step: None,
475                completion: None,
476            });
477
478        if entry.circuit != envelope.circuit {
479            return Err(anyhow!(
480                "trace_id {} already exists for circuit {}, got {}",
481                envelope.trace_id,
482                entry.circuit,
483                envelope.circuit
484            ));
485        }
486        if entry.completion.is_some() {
487            return Err(anyhow!(
488                "trace_id {} is already completed and cannot accept new events",
489                envelope.trace_id
490            ));
491        }
492        entry.events.push(envelope);
493        entry.events.sort_by_key(|e| e.step);
494        Ok(())
495    }
496
497    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
498        let guard = self.inner.read().await;
499        Ok(guard.get(trace_id).cloned())
500    }
501
502    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
503        let mut guard = self.inner.write().await;
504        let trace = guard
505            .get_mut(trace_id)
506            .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
507        trace.resumed_from_step = Some(resume_from_step);
508        Ok(ResumeCursor {
509            trace_id: trace_id.to_string(),
510            next_step: resume_from_step.saturating_add(1),
511        })
512    }
513
514    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
515        let mut guard = self.inner.write().await;
516        let trace = guard
517            .get_mut(trace_id)
518            .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
519        trace.completion = Some(completion);
520        Ok(())
521    }
522}
523
/// PostgreSQL-backed [`PersistenceStore`].
///
/// Uses two tables derived from a prefix: `{prefix}_events` holds one row per
/// appended envelope, with primary key `(trace_id, step)`, so a duplicate step
/// fails to insert. `{prefix}_state` holds one row per trace.
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresPersistenceStore {
    pool: sqlx::Pool<sqlx::Postgres>,
    /// Events table name (`{prefix}_events`).
    events_table: String,
    /// Per-trace state table name (`{prefix}_state`).
    state_table: String,
}

/// Row shape of the events table as read back by `load`.
///
/// Integer columns are BIGINT, hence `i64` here; `load` converts them to `u64`.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresEventRow {
    trace_id: String,
    circuit: String,
    step: i64,
    outcome_kind: String,
    timestamp_ms: i64,
    payload_hash: Option<String>,
}

/// Row shape of the state table as read back by `load`.
///
/// `completion` holds the text produced by `completion_state_to_wire`.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresStateRow {
    trace_id: String,
    circuit: String,
    resumed_from_step: Option<i64>,
    completion: Option<String>,
}
551
#[cfg(feature = "persistence-postgres")]
impl PostgresPersistenceStore {
    /// Create a PostgreSQL-backed persistence store with the default
    /// `ranvier_persistence` table prefix.
    ///
    /// This is an alpha adapter intended for M148 validation.
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        Self::with_table_prefix(pool, "ranvier_persistence")
    }

    /// Create with custom table prefix (`{prefix}_events`, `{prefix}_state`).
    ///
    /// Table names are interpolated directly into SQL text, so `prefix` must
    /// come from trusted configuration, never from user input.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        let events_table = format!("{}_events", prefix);
        let state_table = format!("{}_state", prefix);
        Self {
            pool,
            events_table,
            state_table,
        }
    }

    /// Initialize adapter tables when absent (state table first, then events).
    pub async fn ensure_schema(&self) -> Result<()> {
        let state_ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                trace_id TEXT PRIMARY KEY,
                circuit TEXT NOT NULL,
                resumed_from_step BIGINT NULL,
                completion TEXT NULL
            )",
            self.state_table
        );
        let events_ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                trace_id TEXT NOT NULL,
                circuit TEXT NOT NULL,
                step BIGINT NOT NULL,
                outcome_kind TEXT NOT NULL,
                timestamp_ms BIGINT NOT NULL,
                payload_hash TEXT NULL,
                PRIMARY KEY (trace_id, step)
            )",
            self.events_table
        );
        for ddl in [state_ddl, events_ddl] {
            sqlx::query(&ddl).execute(&self.pool).await?;
        }
        Ok(())
    }
}
600
#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl PersistenceStore for PostgresPersistenceStore {
    /// Append one event inside a single database transaction.
    ///
    /// The state row is upserted and then locked with `SELECT ... FOR UPDATE`.
    /// A concurrent `complete` for the same trace therefore cannot commit
    /// between the completion check and the event insert. Any failure rolls
    /// back the whole append, including a freshly created state row.
    ///
    /// # Errors
    /// - `step` or `timestamp_ms` exceeds `i64::MAX`;
    /// - the trace exists for a different circuit, or is already completed;
    /// - an event with the same `(trace_id, step)` already exists
    ///   (primary-key violation);
    /// - any database error.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        // Validate numeric ranges before doing any database work.
        let step_i64 = i64::try_from(envelope.step)?;
        let ts_i64 = i64::try_from(envelope.timestamp_ms)?;

        // Dropping `tx` without `commit` (early return or `?`) rolls it back.
        let mut tx = self.pool.begin().await?;

        let insert_state = format!(
            "INSERT INTO {} (trace_id, circuit, resumed_from_step, completion)
             VALUES ($1, $2, NULL, NULL)
             ON CONFLICT (trace_id) DO NOTHING",
            self.state_table
        );
        sqlx::query(&insert_state)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .execute(&mut *tx)
            .await?;

        // One query for both checks; the row lock blocks concurrent `complete`
        // (an UPDATE on the same row) until this transaction ends.
        let lock_state = format!(
            "SELECT circuit, completion
             FROM {}
             WHERE trace_id = $1
             FOR UPDATE",
            self.state_table
        );
        let state: Option<(String, Option<String>)> = sqlx::query_as(&lock_state)
            .bind(&envelope.trace_id)
            .fetch_optional(&mut *tx)
            .await?;
        let Some((existing_circuit, completion)) = state else {
            return Err(anyhow!(
                "trace_id {} state row disappeared during append",
                envelope.trace_id
            ));
        };
        if existing_circuit != envelope.circuit {
            return Err(anyhow!(
                "trace_id {} already exists for another circuit ({}, got {})",
                envelope.trace_id,
                existing_circuit,
                envelope.circuit
            ));
        }
        if completion.is_some() {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        let insert_event = format!(
            "INSERT INTO {} (trace_id, circuit, step, outcome_kind, timestamp_ms, payload_hash)
             VALUES ($1, $2, $3, $4, $5, $6)",
            self.events_table
        );
        sqlx::query(&insert_event)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .bind(step_i64)
            .bind(&envelope.outcome_kind)
            .bind(ts_i64)
            .bind(&envelope.payload_hash)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(())
    }

    /// Load the state row and all events (ordered by step) for `trace_id`.
    ///
    /// Returns `Ok(None)` when no state row exists.
    ///
    /// # Errors
    /// Returns an error on database failure, on negative stored integers, or on
    /// an unrecognized `completion` value.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        let state_query = format!(
            "SELECT trace_id, circuit, resumed_from_step, completion
             FROM {}
             WHERE trace_id = $1",
            self.state_table
        );
        let Some(state): Option<PostgresStateRow> = sqlx::query_as(&state_query)
            .bind(trace_id)
            .fetch_optional(&self.pool)
            .await?
        else {
            return Ok(None);
        };

        let events_query = format!(
            "SELECT trace_id, circuit, step, outcome_kind, timestamp_ms, payload_hash
             FROM {}
             WHERE trace_id = $1
             ORDER BY step ASC",
            self.events_table
        );
        let rows: Vec<PostgresEventRow> = sqlx::query_as(&events_query)
            .bind(trace_id)
            .fetch_all(&self.pool)
            .await?;

        let mut events = Vec::with_capacity(rows.len());
        for row in rows {
            events.push(PersistenceEnvelope {
                trace_id: row.trace_id,
                circuit: row.circuit,
                step: u64::try_from(row.step)?,
                outcome_kind: row.outcome_kind,
                timestamp_ms: u64::try_from(row.timestamp_ms)?,
                payload_hash: row.payload_hash,
            });
        }

        let completion = state
            .completion
            .as_deref()
            .map(completion_state_from_wire)
            .transpose()?;

        Ok(Some(PersistedTrace {
            trace_id: state.trace_id,
            circuit: state.circuit,
            events,
            resumed_from_step: state.resumed_from_step.map(u64::try_from).transpose()?,
            completion,
        }))
    }

    /// Record `resume_from_step` on the state row and return the next-step cursor.
    ///
    /// # Errors
    /// Returns an error if `trace_id` is unknown or `resume_from_step` exceeds `i64::MAX`.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let query = format!(
            "UPDATE {}
             SET resumed_from_step = $2
             WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(i64::try_from(resume_from_step)?)
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(ResumeCursor {
            trace_id: trace_id.to_string(),
            next_step: resume_from_step.saturating_add(1),
        })
    }

    /// Set (or overwrite) the trace's completion state.
    ///
    /// # Errors
    /// Returns an error if `trace_id` is unknown.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let query = format!(
            "UPDATE {}
             SET completion = $2
             WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(completion_state_to_wire(&completion))
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(())
    }
}
760
/// Redis-backed persistence store.
///
/// Each trace is stored as one JSON-serialized [`PersistedTrace`] value under
/// the key `"{key_prefix}:{trace_id}"`.
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisPersistenceStore {
    manager: redis::aio::ConnectionManager,
    key_prefix: String,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisPersistenceStore {
    // The connection manager is omitted; only the key prefix is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisPersistenceStore")
            .field("key_prefix", &self.key_prefix)
            .finish_non_exhaustive()
    }
}

#[cfg(feature = "persistence-redis")]
impl RedisPersistenceStore {
    /// Connect using Redis connection URL, with the default `ranvier:persistence` prefix.
    ///
    /// Example: `redis://127.0.0.1:6379`
    pub async fn connect(url: &str) -> Result<Self> {
        let client = redis::Client::open(url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self::with_prefix(manager, "ranvier:persistence"))
    }

    /// Build from an existing connection manager with a custom key prefix.
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        RedisPersistenceStore {
            key_prefix: key_prefix.into(),
            manager,
        }
    }

    /// Full Redis key for a trace: `"{key_prefix}:{trace_id}"`.
    fn key(&self, trace_id: &str) -> String {
        format!("{}:{}", self.key_prefix, trace_id)
    }

    /// Serialize `trace` to JSON and overwrite its key unconditionally.
    async fn write_trace(&self, trace: &PersistedTrace) -> Result<()> {
        use redis::AsyncCommands;
        let redis_key = self.key(&trace.trace_id);
        let json = serde_json::to_string(trace)?;
        let mut conn = self.manager.clone();
        conn.set::<_, _, ()>(redis_key, json).await?;
        Ok(())
    }
}
814
// NOTE(review): every mutating method below is an unguarded read-modify-write
// (`GET` via `load`, then `SET` via `write_trace`). Two concurrent writers on
// the same trace can overwrite each other's changes (lost update). Serialize
// writers per trace, or move to a Lua-script based update if concurrent writers
// are expected; `WATCH`/`MULTI` is unsafe on a shared multiplexed
// `ConnectionManager`.
#[cfg(feature = "persistence-redis")]
#[async_trait]
impl PersistenceStore for RedisPersistenceStore {
    /// Load the trace (or start a new one), validate, add the event, and write it back.
    ///
    /// # Errors
    /// - the trace already exists under a different circuit;
    /// - the trace has already been completed;
    /// - Redis or JSON (de)serialization failure.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        let stored = self.load(&envelope.trace_id).await?;
        let mut trace = match stored {
            Some(existing) => existing,
            None => PersistedTrace {
                trace_id: envelope.trace_id.clone(),
                circuit: envelope.circuit.clone(),
                events: Vec::new(),
                resumed_from_step: None,
                completion: None,
            },
        };

        if trace.circuit != envelope.circuit {
            return Err(anyhow!(
                "trace_id {} already exists for circuit {}, got {}",
                envelope.trace_id,
                trace.circuit,
                envelope.circuit
            ));
        }
        if trace.completion.is_some() {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        // Stable sort keeps arrival order among equal steps.
        trace.events.push(envelope);
        trace.events.sort_by_key(|event| event.step);
        self.write_trace(&trace).await
    }

    /// Fetch and deserialize the trace JSON; `None` when the key is absent.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        use redis::AsyncCommands;
        let redis_key = self.key(trace_id);
        let mut conn = self.manager.clone();
        let raw: Option<String> = conn.get(redis_key).await?;
        match raw {
            Some(json) => Ok(Some(serde_json::from_str::<PersistedTrace>(&json)?)),
            None => Ok(None),
        }
    }

    /// Record `resume_from_step` on the stored trace and return the next-step cursor.
    ///
    /// # Errors
    /// Returns an error if `trace_id` is unknown.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.resumed_from_step = Some(resume_from_step);
        self.write_trace(&trace).await?;
        Ok(ResumeCursor {
            trace_id: trace_id.to_string(),
            next_step: resume_from_step.saturating_add(1),
        })
    }

    /// Set (or overwrite) the stored trace's completion state.
    ///
    /// # Errors
    /// Returns an error if `trace_id` is unknown.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.completion = Some(completion);
        self.write_trace(&trace).await
    }
}
885
#[cfg(test)]
mod tests {
    //! Unit tests for the in-memory stores, plus opt-in integration tests for the
    //! Postgres and Redis backends. Backend tests are compiled only when their
    //! feature is enabled and return early (pass) when the corresponding
    //! `RANVIER_PERSISTENCE_*_URL` environment variable is not set.

    use super::*;
    #[cfg(any(feature = "persistence-postgres", feature = "persistence-redis"))]
    use uuid::Uuid;

    /// Builds an envelope for trace `trace-1` on circuit `OrderCircuit` at `step`.
    fn envelope(step: u64, outcome_kind: &str) -> PersistenceEnvelope {
        PersistenceEnvelope {
            trace_id: "trace-1".to_string(),
            circuit: "OrderCircuit".to_string(),
            step,
            outcome_kind: outcome_kind.to_string(),
            timestamp_ms: 1_700_000_000_000 + step,
            payload_hash: Some(format!("hash-{}", step)),
        }
    }

    /// Like [`envelope`], but bound to an explicit trace id and circuit name.
    #[cfg(any(feature = "persistence-postgres", feature = "persistence-redis"))]
    fn envelope_for(
        trace_id: &str,
        circuit: &str,
        step: u64,
        outcome_kind: &str,
    ) -> PersistenceEnvelope {
        PersistenceEnvelope {
            trace_id: trace_id.to_string(),
            circuit: circuit.to_string(),
            ..envelope(step, outcome_kind)
        }
    }

    /// Postgres URL for integration tests; `None` means "skip the test".
    #[cfg(feature = "persistence-postgres")]
    fn postgres_url() -> Option<String> {
        std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL").ok()
    }

    /// Redis URL for integration tests; `None` means "skip the test".
    #[cfg(feature = "persistence-redis")]
    fn redis_url() -> Option<String> {
        std::env::var("RANVIER_PERSISTENCE_REDIS_URL").ok()
    }

    /// Opens a small connection pool against the configured test database.
    #[cfg(feature = "persistence-postgres")]
    async fn connect_postgres(url: &str) -> sqlx::postgres::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(url)
            .await
            .unwrap()
    }

    /// Returns a per-run unique table prefix of the form `ranvier_<label>_<16 hex>`.
    ///
    /// Postgres silently truncates identifiers longer than 63 bytes
    /// (NAMEDATALEN - 1). A full 32-hex UUID appended to a long label already
    /// overflows that limit. Truncation would drop uniqueness bits, and it could
    /// collide with any suffixed relation names the store derives from the prefix.
    /// 16 hex chars (64 random bits) are ample for test isolation.
    #[cfg(feature = "persistence-postgres")]
    fn unique_table_prefix(label: &str) -> String {
        let id = Uuid::new_v4().simple().to_string();
        format!("ranvier_{}_{}", label, &id[..16])
    }

    /// Drops a test table created by `ensure_schema`, if it exists.
    #[cfg(feature = "persistence-postgres")]
    async fn drop_postgres_table(pool: &sqlx::postgres::PgPool, table: impl std::fmt::Display) {
        let statement = format!("DROP TABLE IF EXISTS {}", table);
        sqlx::query(&statement).execute(pool).await.unwrap();
    }

    #[tokio::test]
    async fn append_and_load_roundtrip() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();
        store.append(envelope(2, "Branch")).await.unwrap();

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, "trace-1");
        assert_eq!(loaded.circuit, "OrderCircuit");
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.events[0].step, 1);
        assert_eq!(loaded.events[1].outcome_kind, "Branch");
        assert_eq!(loaded.completion, None);
    }

    #[tokio::test]
    async fn resume_records_cursor() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(3, "Fault")).await.unwrap();

        let cursor = store.resume("trace-1", 3).await.unwrap();
        assert_eq!(
            cursor,
            ResumeCursor {
                trace_id: "trace-1".to_string(),
                next_step: 4
            }
        );

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.resumed_from_step, Some(3));
    }

    #[tokio::test]
    async fn complete_marks_trace_and_blocks_append() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();
        store
            .complete("trace-1", CompletionState::Success)
            .await
            .unwrap();

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.completion, Some(CompletionState::Success));

        let err = store.append(envelope(2, "Next")).await.unwrap_err();
        assert!(
            err.to_string()
                .contains("is already completed and cannot accept new events")
        );
    }

    #[tokio::test]
    async fn append_rejects_cross_circuit_trace_reuse() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();

        let mut invalid = envelope(2, "Next");
        invalid.circuit = "AnotherCircuit".to_string();
        let err = store.append(invalid).await.unwrap_err();
        assert!(
            err.to_string()
                .contains("already exists for circuit OrderCircuit")
        );
    }

    #[tokio::test]
    async fn in_memory_compensation_idempotency_roundtrip() {
        let store = InMemoryCompensationIdempotencyStore::new();
        let key = "trace-a:OrderFlow:Fault";

        assert!(!store.was_compensated(key).await.unwrap());
        store.mark_compensated(key).await.unwrap();
        assert!(store.was_compensated(key).await.unwrap());
    }

    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_store_roundtrip_when_configured() {
        let Some(url) = postgres_url() else {
            return;
        };

        let pool = connect_postgres(&url).await;
        let store = PostgresPersistenceStore::with_table_prefix(
            pool.clone(),
            unique_table_prefix("persistence_test"),
        );
        store.ensure_schema().await.unwrap();

        let trace_id = format!("trace-{}", Uuid::new_v4().simple());
        let circuit = "PgCircuit";

        store
            .append(envelope_for(&trace_id, circuit, 1, "Next"))
            .await
            .unwrap();
        store
            .append(envelope_for(&trace_id, circuit, 2, "Branch"))
            .await
            .unwrap();

        let cursor = store.resume(&trace_id, 2).await.unwrap();
        assert_eq!(cursor.next_step, 3);

        store
            .complete(&trace_id, CompletionState::Compensated)
            .await
            .unwrap();

        let loaded = store.load(&trace_id).await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, trace_id);
        assert_eq!(loaded.circuit, circuit);
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.resumed_from_step, Some(2));
        assert_eq!(loaded.completion, Some(CompletionState::Compensated));

        drop_postgres_table(&pool, &store.events_table).await;
        drop_postgres_table(&pool, &store.state_table).await;
    }

    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_store_roundtrip_when_configured() {
        use redis::AsyncCommands;

        let Some(url) = redis_url() else {
            return;
        };

        let base = RedisPersistenceStore::connect(&url).await.unwrap();
        let prefix = format!("ranvier:persistence:test:{}", Uuid::new_v4().simple());
        let store = RedisPersistenceStore::with_prefix(base.manager.clone(), prefix);

        let trace_id = format!("trace-{}", Uuid::new_v4().simple());
        let circuit = "RedisCircuit";

        store
            .append(envelope_for(&trace_id, circuit, 1, "Next"))
            .await
            .unwrap();
        store
            .append(envelope_for(&trace_id, circuit, 2, "Fault"))
            .await
            .unwrap();

        let cursor = store.resume(&trace_id, 2).await.unwrap();
        assert_eq!(cursor.next_step, 3);

        store
            .complete(&trace_id, CompletionState::Fault)
            .await
            .unwrap();

        let loaded = store.load(&trace_id).await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, trace_id);
        assert_eq!(loaded.circuit, circuit);
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.resumed_from_step, Some(2));
        assert_eq!(loaded.completion, Some(CompletionState::Fault));

        let key = store.key(&trace_id);
        let mut conn = store.manager.clone();
        let _: () = conn.del(key).await.unwrap();
    }

    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_compensation_idempotency_roundtrip_when_configured() {
        let Some(url) = postgres_url() else {
            return;
        };

        let pool = connect_postgres(&url).await;
        let table_prefix = unique_table_prefix("comp_idem_test");
        let store =
            PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
        store.ensure_schema().await.unwrap();

        let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());
        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
        // Marking twice must be a no-op rather than an error.
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        drop_postgres_table(&pool, &store.table).await;
    }

    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_compensation_idempotency_purge_when_configured() {
        let Some(url) = postgres_url() else {
            return;
        };

        let pool = connect_postgres(&url).await;
        let table_prefix = unique_table_prefix("comp_idem_purge_test");
        let store =
            PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
        store.ensure_schema().await.unwrap();

        let stale_key = format!("stale-{}", Uuid::new_v4().simple());
        let fresh_key = format!("fresh-{}", Uuid::new_v4().simple());
        store.mark_compensated(&stale_key).await.unwrap();
        store.mark_compensated(&fresh_key).await.unwrap();

        // Backdate only the stale row so the purge below can distinguish the two.
        let force_stale_query = format!(
            "UPDATE {}
             SET created_at_ms = 0
             WHERE idempotency_key = $1",
            store.table
        );
        sqlx::query(&force_stale_query)
            .bind(&stale_key)
            .execute(&pool)
            .await
            .unwrap();

        let purged = store.purge_older_than_ms(1).await.unwrap();
        assert!(purged >= 1);
        assert!(!store.was_compensated(&stale_key).await.unwrap());
        assert!(store.was_compensated(&fresh_key).await.unwrap());

        drop_postgres_table(&pool, &store.table).await;
    }

    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_compensation_idempotency_roundtrip_when_configured() {
        use redis::AsyncCommands;

        let Some(url) = redis_url() else {
            return;
        };

        let base = RedisCompensationIdempotencyStore::connect(&url)
            .await
            .unwrap();
        let prefix = format!(
            "ranvier:compensation:idempotency:test:{}",
            Uuid::new_v4().simple()
        );
        let store = RedisCompensationIdempotencyStore::with_prefix(base.manager.clone(), prefix);
        let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());

        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
        // Marking twice must be a no-op rather than an error.
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        let mut conn = store.manager.clone();
        let _: () = conn.del(store.key(&key)).await.unwrap();
    }

    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_compensation_idempotency_ttl_when_configured() {
        let Some(url) = redis_url() else {
            return;
        };

        let base = RedisCompensationIdempotencyStore::connect(&url)
            .await
            .unwrap();
        let prefix = format!(
            "ranvier:compensation:idempotency:ttl:test:{}",
            Uuid::new_v4().simple()
        );
        let store =
            RedisCompensationIdempotencyStore::with_prefix_and_ttl(base.manager.clone(), prefix, 1);
        let key = format!("ttl-{}", Uuid::new_v4().simple());

        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        // Wait past the 1-second TTL; the key expires on its own, so no cleanup is needed.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(!store.was_compensated(&key).await.unwrap());
    }
}