1use anyhow::{Result, anyhow};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// One recorded execution event of a trace, as accepted by `PersistenceStore::append`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistenceEnvelope {
    /// Identifier of the trace this event belongs to; the stores in this file key their data by it.
    pub trace_id: String,
    /// Circuit name. The stores in this file reject an append whose circuit differs from the one
    /// already recorded for the same `trace_id`.
    pub circuit: String,
    /// Schematic version reported together with this event.
    pub schematic_version: String,
    /// Step number; the stores in this file return events ordered by this value.
    pub step: u64,
    /// Optional node identifier.
    /// NOTE(review): the Postgres backend has no column for it and always loads it back as `None`.
    pub node_id: Option<String>,
    /// Outcome label (the tests in this file use values such as "Next", "Branch" and "Fault").
    pub outcome_kind: String,
    /// Event timestamp; presumably milliseconds since the Unix epoch — confirm with producers.
    pub timestamp_ms: u64,
    /// Optional hash of the payload; how it is computed is not visible in this file.
    pub payload_hash: Option<String>,
    /// Optional JSON payload attached to the event.
    pub payload: Option<serde_json::Value>,
}
29
/// An intervention recorded against a trace via `PersistenceStore::save_intervention`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Intervention {
    /// Name of the node the intervention targets.
    pub target_node: String,
    /// Optional JSON value; presumably replaces the node's payload — confirm with the consumer.
    pub payload_override: Option<serde_json::Value>,
    /// Timestamp of the intervention; the Postgres backend returns interventions ordered by it.
    pub timestamp_ms: u64,
}
37
/// Completion status stored on a trace by `PersistenceStore::complete`.
///
/// Once a trace has a completion state, the stores in this file reject further appends to it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompletionState {
    /// Persisted by the Postgres backend as `"success"`.
    Success,
    /// Persisted by the Postgres backend as `"fault"`.
    Fault,
    /// Persisted by the Postgres backend as `"cancelled"`.
    Cancelled,
    /// Persisted by the Postgres backend as `"compensated"`.
    Compensated,
}
46
/// Returns the lowercase string under which `state` is stored in the Postgres `completion` column.
#[cfg(feature = "persistence-postgres")]
fn completion_state_to_wire(state: &CompletionState) -> &'static str {
    use CompletionState::{Cancelled, Compensated, Fault, Success};

    // Exhaustive on purpose: adding a variant must force a decision about its wire name.
    match *state {
        Success => "success",
        Fault => "fault",
        Cancelled => "cancelled",
        Compensated => "compensated",
    }
}
56
/// Parses the string stored in the Postgres `completion` column back into a [`CompletionState`].
///
/// # Errors
/// Returns an error naming the offending value when it is not one of the four known wire names.
#[cfg(feature = "persistence-postgres")]
fn completion_state_from_wire(value: &str) -> Result<CompletionState> {
    let state = match value {
        "success" => CompletionState::Success,
        "fault" => CompletionState::Fault,
        "cancelled" => CompletionState::Cancelled,
        "compensated" => CompletionState::Compensated,
        other => return Err(anyhow!("unknown completion state value: {}", other)),
    };
    Ok(state)
}
67
/// Everything a store holds for one trace, as returned by `PersistenceStore::load`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistedTrace {
    /// Identifier of the trace.
    pub trace_id: String,
    /// Circuit the trace was first recorded for.
    pub circuit: String,
    /// Schematic version; the stores in this file overwrite it on append with the envelope's value.
    pub schematic_version: String,
    /// Recorded events, kept ordered by `step` by the stores in this file.
    pub events: Vec<PersistenceEnvelope>,
    /// Interventions saved for this trace.
    pub interventions: Vec<Intervention>,
    /// Step passed to the most recent `resume` call, if any.
    pub resumed_from_step: Option<u64>,
    /// Completion state set by `complete`; `None` while the trace still accepts events.
    pub completion: Option<CompletionState>,
}
80
/// Result of `PersistenceStore::resume`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResumeCursor {
    /// Identifier of the resumed trace.
    pub trace_id: String,
    /// The stores in this file set this to the resume step plus one (saturating at `u64::MAX`).
    pub next_step: u64,
}
87
/// Newtype wrapper around a trace identifier string.
///
/// NOTE(review): nothing in this file consumes it; presumably it is carried in a shared
/// context so the runtime can pick up an explicit trace id — confirm with the callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistenceTraceId(pub String);

impl PersistenceTraceId {
    /// Builds a trace id from anything convertible into a `String`.
    pub fn new(value: impl Into<String>) -> Self {
        let id: String = value.into();
        Self(id)
    }
}
99
/// Boolean flag newtype.
///
/// NOTE(review): not read anywhere in this file; presumably toggles whether a trace is
/// completed automatically — confirm with the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PersistenceAutoComplete(pub bool);
105
/// Argument handed to `CompensationHook::compensate`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompensationContext {
    /// Identifier of the trace being compensated.
    pub trace_id: String,
    /// Circuit the trace belongs to.
    pub circuit: String,
    /// Label of the fault that led to compensation; the set of values is not visible in this file.
    pub fault_kind: String,
    /// Step at which the fault was observed.
    pub fault_step: u64,
    /// Timestamp; presumably milliseconds since the Unix epoch — confirm with producers.
    pub timestamp_ms: u64,
}
118
/// Boolean flag newtype.
///
/// NOTE(review): not read anywhere in this file; presumably toggles whether compensation is
/// triggered automatically — confirm with the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationAutoTrigger(pub bool);
124
/// Retry settings for compensation.
///
/// NOTE(review): how the two values are interpreted is decided by the consumer, which is not
/// part of this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationRetryPolicy {
    /// Maximum number of attempts.
    pub max_attempts: u32,
    /// Backoff between attempts, in milliseconds.
    pub backoff_ms: u64,
}

impl Default for CompensationRetryPolicy {
    /// Defaults to a single attempt with a zero backoff.
    fn default() -> Self {
        let (max_attempts, backoff_ms) = (1, 0);
        Self {
            max_attempts,
            backoff_ms,
        }
    }
}
142
/// Callback invoked to compensate a trace.
///
/// Implementors must be `Send + Sync` so the hook can be shared behind an `Arc`.
#[async_trait]
pub trait CompensationHook: Send + Sync {
    /// Runs the compensation described by `context`, returning an error when it fails.
    async fn compensate(&self, context: CompensationContext) -> Result<()>;
}
152
153#[derive(Clone)]
155pub struct CompensationHandle {
156 inner: Arc<dyn CompensationHook>,
157}
158
159impl std::fmt::Debug for CompensationHandle {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.debug_struct("CompensationHandle").finish_non_exhaustive()
162 }
163}
164
165impl CompensationHandle {
166 pub fn from_hook<H>(hook: H) -> Self
168 where
169 H: CompensationHook + 'static,
170 {
171 Self {
172 inner: Arc::new(hook),
173 }
174 }
175
176 pub fn from_arc(hook: Arc<dyn CompensationHook>) -> Self {
178 Self { inner: hook }
179 }
180
181 pub fn hook(&self) -> Arc<dyn CompensationHook> {
183 self.inner.clone()
184 }
185}
186
/// Records which compensation keys have already been handled.
///
/// The key format is chosen by the caller (the tests in this file use
/// `"<trace>:<circuit>:<fault>"`-shaped strings).
#[async_trait]
pub trait CompensationIdempotencyStore: Send + Sync {
    /// Returns `true` when `key` has previously been marked via [`Self::mark_compensated`].
    async fn was_compensated(&self, key: &str) -> Result<bool>;
    /// Marks `key` as compensated; the implementations in this file treat repeated marks as a no-op.
    async fn mark_compensated(&self, key: &str) -> Result<()>;
}
199
200#[derive(Clone)]
202pub struct CompensationIdempotencyHandle {
203 inner: Arc<dyn CompensationIdempotencyStore>,
204}
205
206impl std::fmt::Debug for CompensationIdempotencyHandle {
207 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208 f.debug_struct("CompensationIdempotencyHandle")
209 .finish_non_exhaustive()
210 }
211}
212
213impl CompensationIdempotencyHandle {
214 pub fn from_store<S>(store: S) -> Self
215 where
216 S: CompensationIdempotencyStore + 'static,
217 {
218 Self {
219 inner: Arc::new(store),
220 }
221 }
222
223 pub fn from_arc(store: Arc<dyn CompensationIdempotencyStore>) -> Self {
224 Self { inner: store }
225 }
226
227 pub fn store(&self) -> Arc<dyn CompensationIdempotencyStore> {
228 self.inner.clone()
229 }
230}
231
232#[derive(Debug, Default, Clone)]
234pub struct InMemoryCompensationIdempotencyStore {
235 keys: Arc<RwLock<HashSet<String>>>,
236}
237
238impl InMemoryCompensationIdempotencyStore {
239 pub fn new() -> Self {
240 Self::default()
241 }
242}
243
244#[async_trait]
245impl CompensationIdempotencyStore for InMemoryCompensationIdempotencyStore {
246 async fn was_compensated(&self, key: &str) -> Result<bool> {
247 let guard = self.keys.read().await;
248 Ok(guard.contains(key))
249 }
250
251 async fn mark_compensated(&self, key: &str) -> Result<()> {
252 let mut guard = self.keys.write().await;
253 guard.insert(key.to_string());
254 Ok(())
255 }
256}
257
/// Postgres-backed [`CompensationIdempotencyStore`].
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresCompensationIdempotencyStore {
    /// Connection pool used for every query.
    pool: sqlx::Pool<sqlx::Postgres>,
    /// Table name, interpolated verbatim into the SQL text — it must come from trusted configuration.
    table: String,
}
264
#[cfg(feature = "persistence-postgres")]
impl PostgresCompensationIdempotencyStore {
    /// Creates a store using the default `ranvier_persistence` table prefix.
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        let default_prefix = "ranvier_persistence";
        Self::with_table_prefix(pool, default_prefix)
    }

    /// Creates a store whose table is named `<prefix>_compensation_idempotency`.
    ///
    /// The prefix is spliced into SQL text unescaped, so it must not come from untrusted input.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let mut table: String = prefix.into();
        table.push_str("_compensation_idempotency");
        Self { pool, table }
    }

    /// Creates the idempotency table when it does not exist yet.
    pub async fn ensure_schema(&self) -> Result<()> {
        let ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                idempotency_key TEXT PRIMARY KEY,
                created_at_ms BIGINT NOT NULL
            )",
            self.table
        );
        let statement = sqlx::query(&ddl);
        statement.execute(&self.pool).await?;
        Ok(())
    }

    /// Deletes every key created strictly before `cutoff_ms` and returns how many rows were removed.
    pub async fn purge_older_than_ms(&self, cutoff_ms: i64) -> Result<u64> {
        let sql = format!(
            "DELETE FROM {}
             WHERE created_at_ms < $1",
            self.table
        );
        let outcome = sqlx::query(&sql)
            .bind(cutoff_ms)
            .execute(&self.pool)
            .await?;
        Ok(outcome.rows_affected())
    }
}
309
#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl CompensationIdempotencyStore for PostgresCompensationIdempotencyStore {
    /// Returns `true` when a row with this key exists.
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        let sql = format!(
            "SELECT 1
             FROM {}
             WHERE idempotency_key = $1
             LIMIT 1",
            self.table
        );
        let found: Option<i32> = sqlx::query_scalar(&sql)
            .bind(key)
            .fetch_optional(&self.pool)
            .await?;
        Ok(found.is_some())
    }

    /// Inserts the key stamped with the current wall-clock time; an existing key is left untouched.
    async fn mark_compensated(&self, key: &str) -> Result<()> {
        let sql = format!(
            "INSERT INTO {} (idempotency_key, created_at_ms)
             VALUES ($1, $2)
             ON CONFLICT (idempotency_key) DO NOTHING",
            self.table
        );
        let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
        // `as_millis` yields a u128; the column is BIGINT, so narrow with an explicit check.
        let created_at_ms = i64::try_from(since_epoch.as_millis())?;
        sqlx::query(&sql)
            .bind(key)
            .bind(created_at_ms)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}
346
/// Redis-backed [`CompensationIdempotencyStore`].
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisCompensationIdempotencyStore {
    /// Connection manager; cloned for each operation.
    manager: redis::aio::ConnectionManager,
    /// Prefix joined to every idempotency key with a `:` separator.
    key_prefix: String,
    /// Optional lifetime of a marker in seconds; `None` means markers are written without a TTL.
    ttl_seconds: Option<u64>,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisCompensationIdempotencyStore {
    /// Shows the prefix and TTL and leaves the connection manager out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("RedisCompensationIdempotencyStore");
        repr.field("key_prefix", &self.key_prefix);
        repr.field("ttl_seconds", &self.ttl_seconds);
        repr.finish_non_exhaustive()
    }
}
364
#[cfg(feature = "persistence-redis")]
impl RedisCompensationIdempotencyStore {
    /// Connects to `url` and uses the default `ranvier:compensation:idempotency` prefix without a TTL.
    pub async fn connect(url: &str) -> Result<Self> {
        let client = redis::Client::open(url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self::with_prefix(
            manager,
            "ranvier:compensation:idempotency",
        ))
    }

    /// Builds a store on an existing connection manager with a custom prefix and no TTL.
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        let key_prefix = key_prefix.into();
        Self {
            manager,
            key_prefix,
            ttl_seconds: None,
        }
    }

    /// Builds a store on an existing connection manager with a custom prefix and a marker TTL.
    pub fn with_prefix_and_ttl(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
        ttl_seconds: u64,
    ) -> Self {
        let key_prefix = key_prefix.into();
        Self {
            manager,
            key_prefix,
            ttl_seconds: Some(ttl_seconds),
        }
    }

    /// Produces the Redis key `<prefix>:<idempotency_key>`.
    fn key(&self, idempotency_key: &str) -> String {
        [self.key_prefix.as_str(), idempotency_key].join(":")
    }
}
405
#[cfg(feature = "persistence-redis")]
#[async_trait]
impl CompensationIdempotencyStore for RedisCompensationIdempotencyStore {
    /// Returns `true` when the marker key for `key` exists in Redis.
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let exists: bool = conn.exists(self.key(key)).await?;
        Ok(exists)
    }

    /// Writes the marker for `key` unless it already exists.
    ///
    /// When a TTL is configured, the marker and its expiry are written by a single
    /// `SET key 1 NX EX ttl` command. Issuing `SETNX` and `EXPIRE` separately would leave a
    /// marker without any expiry if the connection or the process fails between the two.
    ///
    /// # Errors
    /// Fails when the TTL does not fit in an `i64` (checked before anything is written) or
    /// when Redis returns an error.
    async fn mark_compensated(&self, key: &str) -> Result<()> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let redis_key = self.key(key);

        match self.ttl_seconds {
            None => {
                let _: bool = conn.set_nx(&redis_key, "1").await?;
            }
            Some(0) => {
                // `SET ... EX 0` is rejected by Redis. The former SETNX + `EXPIRE key 0`
                // sequence removed a freshly written marker immediately and never touched an
                // existing one, so writing nothing has the same net effect.
            }
            Some(ttl_seconds) => {
                let ttl_i64 = i64::try_from(ttl_seconds)?;
                // The reply is "OK" when the key was written and nil when it already existed;
                // both mean the marker is present, so the value is discarded.
                let _: redis::Value = redis::cmd("SET")
                    .arg(&redis_key)
                    .arg("1")
                    .arg("NX")
                    .arg("EX")
                    .arg(ttl_i64)
                    .query_async(&mut conn)
                    .await?;
            }
        }
        Ok(())
    }
}
428
/// Storage backend for trace events, resume cursors, completion state and interventions.
///
/// The behaviors described below are those of the implementations in this file.
#[async_trait]
pub trait PersistenceStore: Send + Sync {
    /// Records one event. Creates the trace on first use; fails when the trace belongs to a
    /// different circuit or has already been completed.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()>;
    /// Returns everything stored for `trace_id`, or `None` when the trace is unknown.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>>;
    /// Records `resume_from_step` on the trace and returns a cursor pointing at the following
    /// step; fails when the trace is unknown.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor>;
    /// Stores the completion state on the trace; fails when the trace is unknown.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()>;
    /// Attaches an intervention to an existing trace.
    async fn save_intervention(&self, trace_id: &str, intervention: Intervention) -> Result<()>;
}
457
458#[derive(Clone)]
460pub struct PersistenceHandle {
461 inner: Arc<dyn PersistenceStore>,
462}
463
464impl std::fmt::Debug for PersistenceHandle {
465 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
466 f.debug_struct("PersistenceHandle").finish_non_exhaustive()
467 }
468}
469
470impl PersistenceHandle {
471 pub fn from_store<S>(store: S) -> Self
473 where
474 S: PersistenceStore + 'static,
475 {
476 Self {
477 inner: Arc::new(store),
478 }
479 }
480
481 pub fn from_arc(store: Arc<dyn PersistenceStore>) -> Self {
483 Self { inner: store }
484 }
485
486 pub fn store(&self) -> Arc<dyn PersistenceStore> {
488 self.inner.clone()
489 }
490}
491
492#[derive(Debug, Default, Clone)]
494pub struct InMemoryPersistenceStore {
495 inner: Arc<RwLock<HashMap<String, PersistedTrace>>>,
496}
497
498impl InMemoryPersistenceStore {
499 pub fn new() -> Self {
500 Self::default()
501 }
502}
503
504#[async_trait]
505impl PersistenceStore for InMemoryPersistenceStore {
506 async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
507 let mut guard = self.inner.write().await;
508 let entry = guard
509 .entry(envelope.trace_id.clone())
510 .or_insert_with(|| PersistedTrace {
511 trace_id: envelope.trace_id.clone(),
512 circuit: envelope.circuit.clone(),
513 schematic_version: envelope.schematic_version.clone(),
514 events: Vec::new(),
515 interventions: Vec::new(),
516 resumed_from_step: None,
517 completion: None,
518 });
519
520 entry.schematic_version = envelope.schematic_version.clone();
521
522 if entry.circuit != envelope.circuit {
523 return Err(anyhow!(
524 "trace_id {} already exists for circuit {}, got {}",
525 envelope.trace_id,
526 entry.circuit,
527 envelope.circuit
528 ));
529 }
530 if entry.completion.is_some() {
531 return Err(anyhow!(
532 "trace_id {} is already completed and cannot accept new events",
533 envelope.trace_id
534 ));
535 }
536 entry.events.push(envelope);
537 entry.events.sort_by_key(|e| e.step);
538 Ok(())
539 }
540
541 async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
542 let guard = self.inner.read().await;
543 Ok(guard.get(trace_id).cloned())
544 }
545
546 async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
547 let mut guard = self.inner.write().await;
548 let trace = guard
549 .get_mut(trace_id)
550 .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
551 trace.resumed_from_step = Some(resume_from_step);
552 Ok(ResumeCursor {
553 trace_id: trace_id.to_string(),
554 next_step: resume_from_step.saturating_add(1),
555 })
556 }
557
558 async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
559 let mut guard = self.inner.write().await;
560 let trace = guard
561 .get_mut(trace_id)
562 .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
563 trace.completion = Some(completion);
564 Ok(())
565 }
566
567 async fn save_intervention(&self, trace_id: &str, intervention: Intervention) -> Result<()> {
568 let mut guard = self.inner.write().await;
569 let trace = guard
570 .get_mut(trace_id)
571 .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
572
573 trace.interventions.push(intervention);
576 Ok(())
577 }
578}
579
/// Postgres-backed [`PersistenceStore`] spread over three tables sharing one name prefix.
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresPersistenceStore {
    /// Connection pool used for every query.
    pool: sqlx::Pool<sqlx::Postgres>,
    /// `<prefix>_events`: one row per recorded event.
    events_table: String,
    /// `<prefix>_state`: one row per trace (circuit, version, resume step, completion).
    state_table: String,
    /// `<prefix>_interventions`: one row per saved intervention.
    interventions_table: String,
}
588
/// Row shape of the events table as selected by `load`.
///
/// NOTE(review): there is no `node_id` column, so `PersistenceEnvelope::node_id` does not
/// survive a round trip through this backend.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresEventRow {
    trace_id: String,
    circuit: String,
    schematic_version: String,
    // BIGINT in the table; converted back to `u64` when loading.
    step: i64,
    outcome_kind: String,
    // BIGINT in the table; converted back to `u64` when loading.
    timestamp_ms: i64,
    payload_hash: Option<String>,
    payload: Option<serde_json::Value>,
}
601
/// Row shape of the state table as selected by `load`.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresStateRow {
    trace_id: String,
    circuit: String,
    schematic_version: String,
    // NULL until `resume` has been called for the trace.
    resumed_from_step: Option<i64>,
    // NULL until `complete` has been called; otherwise one of the completion wire names.
    completion: Option<String>,
}
611
/// Row shape of the interventions table as selected by `load`.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresInterventionRow {
    // `FromRow` looks columns up by field name. The field carries a leading underscore only
    // because it is never read, so it has to be mapped back to the real `trace_id` column;
    // without the rename, decoding any intervention row fails with a column-not-found error.
    #[sqlx(rename = "trace_id")]
    _trace_id: String,
    target_node: String,
    payload_override: Option<serde_json::Value>,
    // BIGINT in the table; converted back to `u64` when loading.
    timestamp_ms: i64,
}
620
#[cfg(feature = "persistence-postgres")]
impl PostgresPersistenceStore {
    /// Creates a store using the default `ranvier_persistence` table prefix.
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        let default_prefix = "ranvier_persistence";
        Self::with_table_prefix(pool, default_prefix)
    }

    /// Creates a store whose tables are `<prefix>_events`, `<prefix>_state` and
    /// `<prefix>_interventions`.
    ///
    /// The prefix is spliced into SQL text unescaped, so it must not come from untrusted input.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        let table = |suffix: &str| format!("{}_{}", prefix, suffix);
        Self {
            events_table: table("events"),
            state_table: table("state"),
            interventions_table: table("interventions"),
            pool,
        }
    }

    /// Creates the three tables when they do not exist yet.
    ///
    /// The state table is created first because the interventions table references it.
    pub async fn ensure_schema(&self) -> Result<()> {
        let create_state = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                trace_id TEXT PRIMARY KEY,
                circuit TEXT NOT NULL,
                schematic_version TEXT NOT NULL,
                resumed_from_step BIGINT NULL,
                completion TEXT NULL
            )",
            self.state_table
        );
        let create_events = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                trace_id TEXT NOT NULL,
                circuit TEXT NOT NULL,
                schematic_version TEXT NOT NULL,
                step BIGINT NOT NULL,
                outcome_kind TEXT NOT NULL,
                timestamp_ms BIGINT NOT NULL,
                payload_hash TEXT NULL,
                payload JSONB NULL,
                PRIMARY KEY (trace_id, step)
            )",
            self.events_table
        );
        let create_interventions = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                trace_id TEXT NOT NULL,
                target_node TEXT NOT NULL,
                payload_override JSONB NULL,
                timestamp_ms BIGINT NOT NULL,
                FOREIGN KEY (trace_id) REFERENCES {} (trace_id)
            )",
            self.interventions_table, self.state_table
        );

        // Executed strictly in this order; the first failure aborts the remaining statements.
        let statements = [create_state, create_events, create_interventions];
        for ddl in &statements {
            sqlx::query(ddl).execute(&self.pool).await?;
        }
        Ok(())
    }
}
690
#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl PersistenceStore for PostgresPersistenceStore {
    /// Records `envelope`, creating the state row for the trace on first use.
    ///
    /// # Errors
    /// Fails when `step` or `timestamp_ms` do not fit in a BIGINT, when the trace already
    /// exists for a different circuit, when the trace is completed, or on any database error
    /// (including a duplicate `(trace_id, step)`).
    ///
    /// NOTE(review): the statements below are not wrapped in a transaction, so a concurrent
    /// `complete` can still slip in between the check and the event insert.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        // Convert up front so an out-of-range value is rejected before any row is written.
        let step_i64 = i64::try_from(envelope.step)?;
        let ts_i64 = i64::try_from(envelope.timestamp_ms)?;

        // The conflict branch only refreshes the schematic version when this append is going
        // to be accepted; a rejected append must not modify the stored trace.
        let upsert_state = format!(
            "INSERT INTO {} AS existing (trace_id, circuit, schematic_version, resumed_from_step, completion)
             VALUES ($1, $2, $3, NULL, NULL)
             ON CONFLICT (trace_id) DO UPDATE SET schematic_version = EXCLUDED.schematic_version
             WHERE existing.circuit = EXCLUDED.circuit AND existing.completion IS NULL",
            self.state_table
        );
        sqlx::query(&upsert_state)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .bind(&envelope.schematic_version)
            .execute(&self.pool)
            .await?;

        // One read returns both values needed for validation.
        let read_state = format!(
            "SELECT circuit, completion FROM {} WHERE trace_id = $1",
            self.state_table
        );
        let current: Option<(String, Option<String>)> = sqlx::query_as(&read_state)
            .bind(&envelope.trace_id)
            .fetch_optional(&self.pool)
            .await?;

        let stored_circuit = current.as_ref().map(|(circuit, _)| circuit.as_str());
        if stored_circuit != Some(envelope.circuit.as_str()) {
            return Err(anyhow!(
                "trace_id {} already exists for another circuit",
                envelope.trace_id
            ));
        }
        let already_completed = current
            .as_ref()
            .is_some_and(|(_, completion)| completion.is_some());
        if already_completed {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        let insert_event = format!(
            "INSERT INTO {} (trace_id, circuit, schematic_version, step, outcome_kind, timestamp_ms, payload_hash, payload)
             VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
            self.events_table
        );
        sqlx::query(&insert_event)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .bind(&envelope.schematic_version)
            .bind(step_i64)
            .bind(&envelope.outcome_kind)
            .bind(ts_i64)
            .bind(&envelope.payload_hash)
            .bind(&envelope.payload)
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Loads the state row, the events (ordered by step) and the interventions (ordered by
    /// timestamp) of a trace; returns `None` when no state row exists.
    ///
    /// # Errors
    /// Fails on database errors, on negative stored integers, and on an unknown completion
    /// value.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        let state_query = format!(
            "SELECT trace_id, circuit, schematic_version, resumed_from_step, completion
             FROM {}
             WHERE trace_id = $1",
            self.state_table
        );
        let Some(state): Option<PostgresStateRow> = sqlx::query_as(&state_query)
            .bind(trace_id)
            .fetch_optional(&self.pool)
            .await?
        else {
            return Ok(None);
        };

        let events_query = format!(
            "SELECT trace_id, circuit, schematic_version, step, outcome_kind, timestamp_ms, payload_hash, payload
             FROM {}
             WHERE trace_id = $1
             ORDER BY step ASC",
            self.events_table
        );
        let rows: Vec<PostgresEventRow> = sqlx::query_as(&events_query)
            .bind(trace_id)
            .fetch_all(&self.pool)
            .await?;

        let events = rows
            .into_iter()
            .map(|row| -> Result<PersistenceEnvelope> {
                Ok(PersistenceEnvelope {
                    trace_id: row.trace_id,
                    circuit: row.circuit,
                    schematic_version: row.schematic_version,
                    step: u64::try_from(row.step)?,
                    // NOTE(review): the events table has no node_id column, so the value
                    // given to `append` is not recoverable here.
                    node_id: None,
                    outcome_kind: row.outcome_kind,
                    timestamp_ms: u64::try_from(row.timestamp_ms)?,
                    payload_hash: row.payload_hash,
                    payload: row.payload,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        let completion = state
            .completion
            .as_deref()
            .map(completion_state_from_wire)
            .transpose()?;

        let interventions_query = format!(
            "SELECT trace_id, target_node, payload_override, timestamp_ms
             FROM {}
             WHERE trace_id = $1
             ORDER BY timestamp_ms ASC",
            self.interventions_table
        );
        let intervention_rows: Vec<PostgresInterventionRow> = sqlx::query_as(&interventions_query)
            .bind(trace_id)
            .fetch_all(&self.pool)
            .await?;

        let interventions = intervention_rows
            .into_iter()
            .map(|row| -> Result<Intervention> {
                Ok(Intervention {
                    target_node: row.target_node,
                    payload_override: row.payload_override,
                    timestamp_ms: u64::try_from(row.timestamp_ms)?,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(Some(PersistedTrace {
            trace_id: state.trace_id,
            circuit: state.circuit,
            schematic_version: state.schematic_version,
            events,
            interventions,
            resumed_from_step: state.resumed_from_step.map(u64::try_from).transpose()?,
            completion,
        }))
    }

    /// Stores the resume step on the state row and returns a cursor for the step after it.
    ///
    /// # Errors
    /// Fails when the step does not fit in a BIGINT or when no state row matches.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let query = format!(
            "UPDATE {}
             SET resumed_from_step = $2
             WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(i64::try_from(resume_from_step)?)
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(ResumeCursor {
            trace_id: trace_id.to_string(),
            next_step: resume_from_step.saturating_add(1),
        })
    }

    /// Stores the completion state on the state row, replacing any earlier one.
    ///
    /// # Errors
    /// Fails when no state row matches.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let query = format!(
            "UPDATE {}
             SET completion = $2
             WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(completion_state_to_wire(&completion))
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(())
    }

    /// Inserts one intervention row for the trace.
    ///
    /// # Errors
    /// Fails when the timestamp does not fit in a BIGINT or on any database error; the
    /// foreign key on `trace_id` makes the insert fail for a trace without a state row.
    async fn save_intervention(&self, trace_id: &str, intervention: Intervention) -> Result<()> {
        let ts_i64 = i64::try_from(intervention.timestamp_ms)?;
        let query = format!(
            "INSERT INTO {} (trace_id, target_node, payload_override, timestamp_ms)
             VALUES ($1, $2, $3, $4)",
            self.interventions_table
        );
        sqlx::query(&query)
            .bind(trace_id)
            .bind(&intervention.target_node)
            .bind(&intervention.payload_override)
            .bind(ts_i64)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}
896
/// Redis-backed [`PersistenceStore`] that keeps each trace as one JSON string value.
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisPersistenceStore {
    /// Connection manager; cloned for each operation.
    manager: redis::aio::ConnectionManager,
    /// Prefix joined to every trace id with a `:` separator.
    key_prefix: String,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisPersistenceStore {
    /// Shows the key prefix and leaves the connection manager out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("RedisPersistenceStore");
        repr.field("key_prefix", &self.key_prefix);
        repr.finish_non_exhaustive()
    }
}
912
#[cfg(feature = "persistence-redis")]
impl RedisPersistenceStore {
    /// Connects to `url` and uses the default `ranvier:persistence` key prefix.
    pub async fn connect(url: &str) -> Result<Self> {
        let client = redis::Client::open(url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self::with_prefix(manager, "ranvier:persistence"))
    }

    /// Builds a store on an existing connection manager with a custom key prefix.
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        let key_prefix = key_prefix.into();
        Self {
            manager,
            key_prefix,
        }
    }

    /// Produces the Redis key `<prefix>:<trace_id>`.
    fn key(&self, trace_id: &str) -> String {
        [self.key_prefix.as_str(), trace_id].join(":")
    }

    /// Serializes the whole trace to JSON and overwrites the value stored under its key.
    async fn write_trace(&self, trace: &PersistedTrace) -> Result<()> {
        use redis::AsyncCommands;
        let serialized = serde_json::to_string(trace)?;
        let redis_key = self.key(&trace.trace_id);
        let mut conn = self.manager.clone();
        conn.set::<_, _, ()>(redis_key, serialized).await?;
        Ok(())
    }
}
950
// NOTE(review): every mutating method below is a read-modify-write made of a separate GET and
// SET, so concurrent writers to the same trace can overwrite each other's changes.
#[cfg(feature = "persistence-redis")]
#[async_trait]
impl PersistenceStore for RedisPersistenceStore {
    /// Loads the trace (or starts a new one), validates the envelope, and writes the trace back
    /// with the event added. Nothing is written when validation fails.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        let stored = self.load(&envelope.trace_id).await?;
        let mut trace = match stored {
            Some(existing) => existing,
            None => PersistedTrace {
                trace_id: envelope.trace_id.clone(),
                circuit: envelope.circuit.clone(),
                schematic_version: envelope.schematic_version.clone(),
                events: Vec::new(),
                interventions: Vec::new(),
                resumed_from_step: None,
                completion: None,
            },
        };

        if trace.circuit != envelope.circuit {
            return Err(anyhow!(
                "trace_id {} already exists for circuit {}, got {}",
                envelope.trace_id,
                trace.circuit,
                envelope.circuit
            ));
        }
        if trace.completion.is_some() {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        // The trace follows the schematic version of the most recently accepted event.
        trace.schematic_version = envelope.schematic_version.clone();
        trace.events.push(envelope);
        trace.events.sort_by_key(|event| event.step);
        self.write_trace(&trace).await
    }

    /// Fetches and deserializes the JSON value of the trace; `None` when the key is absent.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let raw: Option<String> = conn.get(self.key(trace_id)).await?;
        match raw {
            Some(json) => {
                let trace = serde_json::from_str::<PersistedTrace>(&json)?;
                Ok(Some(trace))
            }
            None => Ok(None),
        }
    }

    /// Records the resume step and returns a cursor pointing at the step after it.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.resumed_from_step = Some(resume_from_step);
        self.write_trace(&trace).await?;

        let next_step = resume_from_step.saturating_add(1);
        Ok(ResumeCursor {
            trace_id: trace_id.to_string(),
            next_step,
        })
    }

    /// Stores the completion state on the trace, replacing any earlier one.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.completion = Some(completion);
        self.write_trace(&trace).await
    }

    /// Appends the intervention to the trace's list and writes the trace back.
    async fn save_intervention(&self, trace_id: &str, intervention: Intervention) -> Result<()> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.interventions.push(intervention);
        self.write_trace(&trace).await
    }
}
1035
1036#[cfg(test)]
1037mod tests {
1038 use super::*;
1039 #[cfg(any(feature = "persistence-postgres", feature = "persistence-redis"))]
1040 use uuid::Uuid;
1041
1042 fn envelope(step: u64, outcome_kind: &str) -> PersistenceEnvelope {
1043 PersistenceEnvelope {
1044 trace_id: "trace-1".to_string(),
1045 circuit: "OrderCircuit".to_string(),
1046 schematic_version: "1.0".to_string(),
1047 step,
1048 node_id: None,
1049 outcome_kind: outcome_kind.to_string(),
1050 timestamp_ms: 1_700_000_000_000 + step,
1051 payload_hash: Some(format!("hash-{}", step)),
1052 payload: None,
1053 }
1054 }
1055
1056 #[tokio::test]
1057 async fn append_and_load_roundtrip() {
1058 let store = InMemoryPersistenceStore::new();
1059 store.append(envelope(1, "Next")).await.unwrap();
1060 store.append(envelope(2, "Branch")).await.unwrap();
1061
1062 let loaded = store.load("trace-1").await.unwrap().unwrap();
1063 assert_eq!(loaded.trace_id, "trace-1");
1064 assert_eq!(loaded.circuit, "OrderCircuit");
1065 assert_eq!(loaded.events.len(), 2);
1066 assert_eq!(loaded.events[0].step, 1);
1067 assert_eq!(loaded.events[1].outcome_kind, "Branch");
1068 assert_eq!(loaded.completion, None);
1069 }
1070
1071 #[tokio::test]
1072 async fn resume_records_cursor() {
1073 let store = InMemoryPersistenceStore::new();
1074 store.append(envelope(3, "Fault")).await.unwrap();
1075
1076 let cursor = store.resume("trace-1", 3).await.unwrap();
1077 assert_eq!(
1078 cursor,
1079 ResumeCursor {
1080 trace_id: "trace-1".to_string(),
1081 next_step: 4
1082 }
1083 );
1084
1085 let loaded = store.load("trace-1").await.unwrap().unwrap();
1086 assert_eq!(loaded.resumed_from_step, Some(3));
1087 }
1088
1089 #[tokio::test]
1090 async fn complete_marks_trace_and_blocks_append() {
1091 let store = InMemoryPersistenceStore::new();
1092 store.append(envelope(1, "Next")).await.unwrap();
1093 store
1094 .complete("trace-1", CompletionState::Success)
1095 .await
1096 .unwrap();
1097
1098 let loaded = store.load("trace-1").await.unwrap().unwrap();
1099 assert_eq!(loaded.completion, Some(CompletionState::Success));
1100
1101 let err = store.append(envelope(2, "Next")).await.unwrap_err();
1102 assert!(
1103 err.to_string()
1104 .contains("is already completed and cannot accept new events")
1105 );
1106 }
1107
1108 #[tokio::test]
1109 async fn append_rejects_cross_circuit_trace_reuse() {
1110 let store = InMemoryPersistenceStore::new();
1111 store.append(envelope(1, "Next")).await.unwrap();
1112
1113 let mut invalid = envelope(2, "Next");
1114 invalid.circuit = "AnotherCircuit".to_string();
1115 let err = store.append(invalid).await.unwrap_err();
1116 assert!(
1117 err.to_string()
1118 .contains("already exists for circuit OrderCircuit")
1119 );
1120 }
1121
1122 #[tokio::test]
1123 async fn in_memory_compensation_idempotency_roundtrip() {
1124 let store = InMemoryCompensationIdempotencyStore::new();
1125 let key = "trace-a:OrderFlow:Fault";
1126
1127 assert!(!store.was_compensated(key).await.unwrap());
1128 store.mark_compensated(key).await.unwrap();
1129 assert!(store.was_compensated(key).await.unwrap());
1130 }
1131
1132 #[cfg(feature = "persistence-postgres")]
1133 #[tokio::test]
1134 async fn postgres_store_roundtrip_when_configured() {
1135 let url = match std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") {
1136 Ok(value) => value,
1137 Err(_) => return,
1138 };
1139
1140 let pool = sqlx::postgres::PgPoolOptions::new()
1141 .max_connections(5)
1142 .connect(&url)
1143 .await
1144 .unwrap();
1145 let table_prefix = format!("ranvier_persistence_test_{}", Uuid::new_v4().simple());
1146 let store = PostgresPersistenceStore::with_table_prefix(pool.clone(), table_prefix.clone());
1147 store.ensure_schema().await.unwrap();
1148
1149 let trace_id = format!("trace-{}", Uuid::new_v4().simple());
1150 let circuit = "PgCircuit".to_string();
1151
1152 let mut first = envelope(1, "Next");
1153 first.trace_id = trace_id.clone();
1154 first.circuit = circuit.clone();
1155 store.append(first).await.unwrap();
1156
1157 let mut second = envelope(2, "Branch");
1158 second.trace_id = trace_id.clone();
1159 second.circuit = circuit.clone();
1160 store.append(second).await.unwrap();
1161
1162 let cursor = store.resume(&trace_id, 2).await.unwrap();
1163 assert_eq!(cursor.next_step, 3);
1164
1165 store
1166 .complete(&trace_id, CompletionState::Compensated)
1167 .await
1168 .unwrap();
1169
1170 let loaded = store.load(&trace_id).await.unwrap().unwrap();
1171 assert_eq!(loaded.trace_id, trace_id);
1172 assert_eq!(loaded.circuit, circuit);
1173 assert_eq!(loaded.events.len(), 2);
1174 assert_eq!(loaded.resumed_from_step, Some(2));
1175 assert_eq!(loaded.completion, Some(CompletionState::Compensated));
1176
1177 let drop_events = format!("DROP TABLE IF EXISTS {}", store.events_table);
1178 let drop_state = format!("DROP TABLE IF EXISTS {}", store.state_table);
1179 sqlx::query(&drop_events).execute(&pool).await.unwrap();
1180 sqlx::query(&drop_state).execute(&pool).await.unwrap();
1181 }
1182
/// Round-trips a two-event trace through a live Redis-backed persistence store.
///
/// Returns immediately (so the test passes without doing anything) when
/// `RANVIER_PERSISTENCE_REDIS_URL` cannot be read from the environment.
/// The Redis key written here is deleted only after every assertion has
/// passed; a failing assertion panics before the cleanup is reached.
#[cfg(feature = "persistence-redis")]
#[tokio::test]
async fn redis_store_roundtrip_when_configured() {
    use redis::AsyncCommands;

    let Ok(redis_url) = std::env::var("RANVIER_PERSISTENCE_REDIS_URL") else {
        return;
    };

    // Connect once, then rebuild the store on the same manager under a
    // freshly generated UUID-based prefix.
    let bootstrap = RedisPersistenceStore::connect(&redis_url).await.unwrap();
    let namespace = format!("ranvier:persistence:test:{}", Uuid::new_v4().simple());
    let store = RedisPersistenceStore::with_prefix(bootstrap.manager.clone(), namespace);

    let trace_id = format!("trace-{}", Uuid::new_v4().simple());
    let circuit = String::from("RedisCircuit");

    // Append step 1 ("Next") followed by step 2 ("Fault") for the same trace.
    for (step, kind) in [(1, "Next"), (2, "Fault")] {
        let mut event = envelope(step, kind);
        event.trace_id = trace_id.clone();
        event.circuit = circuit.clone();
        store.append(event).await.unwrap();
    }

    // Resuming from step 2 must hand back step 3 as the next one to run.
    let resume_cursor = store.resume(&trace_id, 2).await.unwrap();
    assert_eq!(resume_cursor.next_step, 3);

    store
        .complete(&trace_id, CompletionState::Fault)
        .await
        .unwrap();

    // `load` yields a Result wrapping an Option: both layers must be present.
    let persisted = store.load(&trace_id).await.unwrap().unwrap();
    assert_eq!(persisted.trace_id, trace_id);
    assert_eq!(persisted.circuit, circuit);
    assert_eq!(persisted.events.len(), 2);
    assert_eq!(persisted.resumed_from_step, Some(2));
    assert_eq!(persisted.completion, Some(CompletionState::Fault));

    // Cleanup: remove the key this trace was stored under.
    let mut connection = store.manager.clone();
    let _: () = connection.del(store.key(&trace_id)).await.unwrap();
}
1228
/// Checks mark/lookup behaviour of the Postgres compensation idempotency
/// store against a live database.
///
/// Returns immediately when `RANVIER_PERSISTENCE_POSTGRES_URL` cannot be read
/// from the environment. The per-run table is dropped only after all
/// assertions have passed.
#[cfg(feature = "persistence-postgres")]
#[tokio::test]
async fn postgres_compensation_idempotency_roundtrip_when_configured() {
    let Ok(database_url) = std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") else {
        return;
    };

    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url)
        .await
        .unwrap();

    // A UUID suffix gives every run its own table name.
    let run_suffix = Uuid::new_v4().simple();
    let table_prefix = format!("ranvier_compensation_idempotency_test_{}", run_suffix);
    let store =
        PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
    store.ensure_schema().await.unwrap();

    let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());

    // Unknown key starts out as not compensated.
    assert!(!store.was_compensated(&key).await.unwrap());

    // Mark twice: the second pass verifies that re-marking the same key
    // succeeds and the key still reads back as compensated.
    for _pass in 0..2 {
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
    }

    // Cleanup: drop the table created for this run.
    let teardown_sql = format!("DROP TABLE IF EXISTS {}", store.table);
    sqlx::query(&teardown_sql).execute(&pool).await.unwrap();
}
1260
/// Verifies that purging the Postgres compensation idempotency store removes
/// a row whose `created_at_ms` was forced to 0 while a just-written row
/// survives.
///
/// Returns immediately when `RANVIER_PERSISTENCE_POSTGRES_URL` cannot be read
/// from the environment. The per-run table is dropped only after all
/// assertions have passed.
#[cfg(feature = "persistence-postgres")]
#[tokio::test]
async fn postgres_compensation_idempotency_purge_when_configured() {
    let Ok(database_url) = std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") else {
        return;
    };

    let pool = sqlx::postgres::PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url)
        .await
        .unwrap();

    // A UUID suffix gives every run its own table name.
    let run_suffix = Uuid::new_v4().simple();
    let table_prefix = format!(
        "ranvier_compensation_idempotency_purge_test_{}",
        run_suffix
    );
    let store =
        PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
    store.ensure_schema().await.unwrap();

    let stale_key = format!("stale-{}", Uuid::new_v4().simple());
    let fresh_key = format!("fresh-{}", Uuid::new_v4().simple());
    for key in [&stale_key, &fresh_key] {
        store.mark_compensated(key).await.unwrap();
    }

    // Backdate only the stale row by rewriting its creation timestamp to 0.
    let backdate_sql = format!(
        "UPDATE {}
         SET created_at_ms = 0
         WHERE idempotency_key = $1",
        store.table
    );
    sqlx::query(&backdate_sql)
        .bind(&stale_key)
        .execute(&pool)
        .await
        .unwrap();

    // NOTE(review): the exact meaning of the `1` argument (cutoff vs. age) is
    // defined by `purge_older_than_ms`, which is not visible here — confirm.
    let purged_rows = store.purge_older_than_ms(1).await.unwrap();
    assert!(purged_rows >= 1);
    assert!(!store.was_compensated(&stale_key).await.unwrap());
    assert!(store.was_compensated(&fresh_key).await.unwrap());

    // Cleanup: drop the table created for this run.
    let teardown_sql = format!("DROP TABLE IF EXISTS {}", store.table);
    sqlx::query(&teardown_sql).execute(&pool).await.unwrap();
}
1307
/// Checks mark/lookup behaviour of the Redis compensation idempotency store
/// against a live Redis instance.
///
/// Returns immediately when `RANVIER_PERSISTENCE_REDIS_URL` cannot be read
/// from the environment. The Redis key is deleted only after all assertions
/// have passed.
#[cfg(feature = "persistence-redis")]
#[tokio::test]
async fn redis_compensation_idempotency_roundtrip_when_configured() {
    use redis::AsyncCommands;

    let Ok(redis_url) = std::env::var("RANVIER_PERSISTENCE_REDIS_URL") else {
        return;
    };

    // Connect once, then rebuild the store on the same manager under a
    // freshly generated UUID-based prefix.
    let bootstrap = RedisCompensationIdempotencyStore::connect(&redis_url)
        .await
        .unwrap();
    let namespace = format!(
        "ranvier:compensation:idempotency:test:{}",
        Uuid::new_v4().simple()
    );
    let store =
        RedisCompensationIdempotencyStore::with_prefix(bootstrap.manager.clone(), namespace);
    let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());

    // Unknown key starts out as not compensated.
    assert!(!store.was_compensated(&key).await.unwrap());

    // Mark twice: the second pass verifies that re-marking the same key
    // succeeds and the key still reads back as compensated.
    for _pass in 0..2 {
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
    }

    // Cleanup: remove the Redis key derived from the idempotency key.
    let redis_key = store.key(&key);
    let mut connection = store.manager.clone();
    let _: () = connection.del(redis_key).await.unwrap();
}
1336
/// Verifies that a key marked in a TTL-configured Redis compensation
/// idempotency store stops reading back as compensated after a two-second wait.
///
/// Returns immediately when `RANVIER_PERSISTENCE_REDIS_URL` cannot be read
/// from the environment. No explicit cleanup is performed by this test.
#[cfg(feature = "persistence-redis")]
#[tokio::test]
async fn redis_compensation_idempotency_ttl_when_configured() {
    let Ok(redis_url) = std::env::var("RANVIER_PERSISTENCE_REDIS_URL") else {
        return;
    };

    let bootstrap = RedisCompensationIdempotencyStore::connect(&redis_url)
        .await
        .unwrap();
    let namespace = format!(
        "ranvier:compensation:idempotency:ttl:test:{}",
        Uuid::new_v4().simple()
    );
    // NOTE(review): the trailing `1` is presumably the TTL in seconds —
    // confirm against `with_prefix_and_ttl`.
    let store = RedisCompensationIdempotencyStore::with_prefix_and_ttl(
        bootstrap.manager.clone(),
        namespace,
        1,
    );
    let key = format!("ttl-{}", Uuid::new_v4().simple());

    // Before marking: absent. Right after marking: present.
    assert!(!store.was_compensated(&key).await.unwrap());
    store.mark_compensated(&key).await.unwrap();
    assert!(store.was_compensated(&key).await.unwrap());

    // Wait two seconds, then the key must no longer be reported.
    let wait = tokio::time::Duration::from_secs(2);
    tokio::time::sleep(wait).await;
    assert!(!store.was_compensated(&key).await.unwrap());
}
1363}