1use anyhow::{Result, anyhow};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// A single persisted step event belonging to one trace.
///
/// The in-memory and Redis stores keep a trace's envelopes sorted by `step`;
/// the Postgres store keys them by `(trace_id, step)`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistenceEnvelope {
    /// Identifier of the trace this event belongs to.
    pub trace_id: String,
    /// Circuit that produced the trace; a trace is bound to a single circuit.
    pub circuit: String,
    /// Step index of this event within the trace.
    pub step: u64,
    /// Free-form outcome label (the tests use values such as "Next", "Branch", "Fault").
    pub outcome_kind: String,
    /// Event timestamp in milliseconds (presumably since the Unix epoch — confirm with producers).
    pub timestamp_ms: u64,
    /// Optional hash of the step payload.
    pub payload_hash: Option<String>,
}

/// Terminal state recorded for a trace through `PersistenceStore::complete`.
///
/// Every store in this module rejects further `append` calls once a trace has
/// a completion state.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompletionState {
    Success,
    Fault,
    Cancelled,
    Compensated,
}
35
/// Encodes a [`CompletionState`] as the lowercase text stored in the Postgres
/// `completion` column.
#[cfg(feature = "persistence-postgres")]
fn completion_state_to_wire(state: &CompletionState) -> &'static str {
    use CompletionState::*;
    match *state {
        Success => "success",
        Fault => "fault",
        Cancelled => "cancelled",
        Compensated => "compensated",
    }
}

/// Decodes the text stored in the Postgres `completion` column.
///
/// # Errors
/// Returns an error for any value not produced by [`completion_state_to_wire`].
#[cfg(feature = "persistence-postgres")]
fn completion_state_from_wire(value: &str) -> Result<CompletionState> {
    let decoded = match value {
        "success" => CompletionState::Success,
        "fault" => CompletionState::Fault,
        "cancelled" => CompletionState::Cancelled,
        "compensated" => CompletionState::Compensated,
        other => return Err(anyhow!("unknown completion state value: {}", other)),
    };
    Ok(decoded)
}
56
/// Everything a store holds for one trace: its events plus resume/completion state.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistedTrace {
    /// Identifier of the trace.
    pub trace_id: String,
    /// Circuit the trace is bound to (set by the first appended event).
    pub circuit: String,
    /// Appended events, ordered by ascending `step` as returned by the stores here.
    pub events: Vec<PersistenceEnvelope>,
    /// Last step passed to `PersistenceStore::resume`, if any.
    pub resumed_from_step: Option<u64>,
    /// Terminal state set by `PersistenceStore::complete`, if any.
    pub completion: Option<CompletionState>,
}

/// Result of `PersistenceStore::resume`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResumeCursor {
    /// Trace that was resumed.
    pub trace_id: String,
    /// Step to execute next; the stores here return `resume_from_step + 1` (saturating).
    pub next_step: u64,
}
73
/// Newtype carrying a persistence trace identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistenceTraceId(pub String);

impl PersistenceTraceId {
    /// Builds a trace id from anything convertible into an owned `String`.
    pub fn new(value: impl Into<String>) -> Self {
        let owned: String = value.into();
        PersistenceTraceId(owned)
    }
}

/// Boolean flag newtype; presumably toggles automatic trace completion —
/// confirm against the code that reads it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PersistenceAutoComplete(pub bool);
91
/// Input handed to `CompensationHook::compensate`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompensationContext {
    /// Trace whose execution is being compensated.
    pub trace_id: String,
    /// Circuit the trace belongs to.
    pub circuit: String,
    /// Label describing the fault (free-form text; semantics defined by the caller).
    pub fault_kind: String,
    /// Step at which the fault was observed.
    pub fault_step: u64,
    /// Timestamp in milliseconds (presumably Unix epoch — confirm with the caller).
    pub timestamp_ms: u64,
}
104
/// Boolean flag newtype; presumably enables automatic compensation on fault —
/// confirm against the code that reads it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationAutoTrigger(pub bool);

/// Retry settings for running a compensation hook.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationRetryPolicy {
    /// Maximum number of attempts.
    pub max_attempts: u32,
    /// Delay between attempts, in milliseconds.
    pub backoff_ms: u64,
}

impl Default for CompensationRetryPolicy {
    /// A single attempt with no backoff.
    fn default() -> Self {
        CompensationRetryPolicy {
            backoff_ms: 0,
            max_attempts: 1,
        }
    }
}
128
/// User-supplied compensation callback.
///
/// Receives a [`CompensationContext`]; what triggers the call is decided by
/// code outside this module.
#[async_trait]
pub trait CompensationHook: Send + Sync {
    /// Performs compensation for the given context.
    async fn compensate(&self, context: CompensationContext) -> Result<()>;
}
134
135#[derive(Clone)]
137pub struct CompensationHandle {
138 inner: Arc<dyn CompensationHook>,
139}
140
141impl std::fmt::Debug for CompensationHandle {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 f.debug_struct("CompensationHandle").finish_non_exhaustive()
144 }
145}
146
147impl CompensationHandle {
148 pub fn from_hook<H>(hook: H) -> Self
150 where
151 H: CompensationHook + 'static,
152 {
153 Self {
154 inner: Arc::new(hook),
155 }
156 }
157
158 pub fn from_arc(hook: Arc<dyn CompensationHook>) -> Self {
160 Self { inner: hook }
161 }
162
163 pub fn hook(&self) -> Arc<dyn CompensationHook> {
165 self.inner.clone()
166 }
167}
168
/// Records which compensation keys have already been processed, so a
/// compensation can be skipped when it was done before.
///
/// The tests use keys shaped like `trace:circuit:fault`; the key format is
/// chosen by the caller.
#[async_trait]
pub trait CompensationIdempotencyStore: Send + Sync {
    /// Returns `true` if `key` was previously marked.
    async fn was_compensated(&self, key: &str) -> Result<bool>;
    /// Marks `key` as compensated; the bundled stores treat repeated marks as no-ops.
    async fn mark_compensated(&self, key: &str) -> Result<()>;
}
175
176#[derive(Clone)]
178pub struct CompensationIdempotencyHandle {
179 inner: Arc<dyn CompensationIdempotencyStore>,
180}
181
182impl std::fmt::Debug for CompensationIdempotencyHandle {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("CompensationIdempotencyHandle")
185 .finish_non_exhaustive()
186 }
187}
188
189impl CompensationIdempotencyHandle {
190 pub fn from_store<S>(store: S) -> Self
191 where
192 S: CompensationIdempotencyStore + 'static,
193 {
194 Self {
195 inner: Arc::new(store),
196 }
197 }
198
199 pub fn from_arc(store: Arc<dyn CompensationIdempotencyStore>) -> Self {
200 Self { inner: store }
201 }
202
203 pub fn store(&self) -> Arc<dyn CompensationIdempotencyStore> {
204 self.inner.clone()
205 }
206}
207
208#[derive(Debug, Default, Clone)]
210pub struct InMemoryCompensationIdempotencyStore {
211 keys: Arc<RwLock<HashSet<String>>>,
212}
213
214impl InMemoryCompensationIdempotencyStore {
215 pub fn new() -> Self {
216 Self::default()
217 }
218}
219
220#[async_trait]
221impl CompensationIdempotencyStore for InMemoryCompensationIdempotencyStore {
222 async fn was_compensated(&self, key: &str) -> Result<bool> {
223 let guard = self.keys.read().await;
224 Ok(guard.contains(key))
225 }
226
227 async fn mark_compensated(&self, key: &str) -> Result<()> {
228 let mut guard = self.keys.write().await;
229 guard.insert(key.to_string());
230 Ok(())
231 }
232}
233
/// Postgres-backed idempotency store keeping one row per compensated key.
///
/// NOTE(review): the table name is interpolated into SQL text, so the prefix
/// must come from trusted configuration, never from user input.
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresCompensationIdempotencyStore {
    pool: sqlx::Pool<sqlx::Postgres>,
    table: String,
}

#[cfg(feature = "persistence-postgres")]
impl PostgresCompensationIdempotencyStore {
    /// Uses the default `ranvier_persistence` table prefix.
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        Self::with_table_prefix(pool, "ranvier_persistence")
    }

    /// Uses the table `<prefix>_compensation_idempotency`.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let table = format!("{}_compensation_idempotency", prefix.into());
        Self { pool, table }
    }

    /// Creates the idempotency table if it does not exist.
    pub async fn ensure_schema(&self) -> Result<()> {
        let ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                idempotency_key TEXT PRIMARY KEY,
                created_at_ms BIGINT NOT NULL
            )",
            self.table
        );
        sqlx::query(&ddl).execute(&self.pool).await?;
        Ok(())
    }

    /// Deletes markers whose `created_at_ms` is strictly below `cutoff_ms`
    /// and returns how many rows were removed.
    pub async fn purge_older_than_ms(&self, cutoff_ms: i64) -> Result<u64> {
        let sql = format!(
            "DELETE FROM {}
             WHERE created_at_ms < $1",
            self.table
        );
        let outcome = sqlx::query(&sql)
            .bind(cutoff_ms)
            .execute(&self.pool)
            .await?;
        Ok(outcome.rows_affected())
    }
}
285
#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl CompensationIdempotencyStore for PostgresCompensationIdempotencyStore {
    /// Returns `true` when a row for `key` exists.
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        let sql = format!(
            "SELECT 1
             FROM {}
             WHERE idempotency_key = $1
             LIMIT 1",
            self.table
        );
        let hit: Option<i32> = sqlx::query_scalar(&sql)
            .bind(key)
            .fetch_optional(&self.pool)
            .await?;
        Ok(hit.is_some())
    }

    /// Inserts a row for `key` stamped with the current wall-clock time in
    /// milliseconds; an existing row is left untouched (`ON CONFLICT DO NOTHING`).
    async fn mark_compensated(&self, key: &str) -> Result<()> {
        let sql = format!(
            "INSERT INTO {} (idempotency_key, created_at_ms)
             VALUES ($1, $2)
             ON CONFLICT (idempotency_key) DO NOTHING",
            self.table
        );
        let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
        let created_at_ms = i64::try_from(since_epoch.as_millis())?;
        sqlx::query(&sql)
            .bind(key)
            .bind(created_at_ms)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}
322
/// Redis-backed idempotency store: one key per marker, optionally with a TTL.
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisCompensationIdempotencyStore {
    manager: redis::aio::ConnectionManager,
    key_prefix: String,
    ttl_seconds: Option<u64>,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisCompensationIdempotencyStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The connection manager is omitted; only configuration is printed.
        let mut out = f.debug_struct("RedisCompensationIdempotencyStore");
        out.field("key_prefix", &self.key_prefix);
        out.field("ttl_seconds", &self.ttl_seconds);
        out.finish_non_exhaustive()
    }
}

#[cfg(feature = "persistence-redis")]
impl RedisCompensationIdempotencyStore {
    /// Connects to `url` and uses the default `ranvier:compensation:idempotency` prefix, no TTL.
    pub async fn connect(url: &str) -> Result<Self> {
        let client = redis::Client::open(url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self::with_prefix(manager, "ranvier:compensation:idempotency"))
    }

    /// Builds a store over an existing connection with a custom prefix and no TTL.
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        Self {
            key_prefix: key_prefix.into(),
            ttl_seconds: None,
            manager,
        }
    }

    /// Like [`Self::with_prefix`], but markers expire after `ttl_seconds`.
    pub fn with_prefix_and_ttl(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
        ttl_seconds: u64,
    ) -> Self {
        let mut store = Self::with_prefix(manager, key_prefix);
        store.ttl_seconds = Some(ttl_seconds);
        store
    }

    /// Full Redis key for an idempotency key: `<prefix>:<idempotency_key>`.
    fn key(&self, idempotency_key: &str) -> String {
        [self.key_prefix.as_str(), idempotency_key].join(":")
    }
}
381
#[cfg(feature = "persistence-redis")]
#[async_trait]
impl CompensationIdempotencyStore for RedisCompensationIdempotencyStore {
    /// Returns `true` when the prefixed marker key exists in Redis.
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let exists: bool = conn.exists(self.key(key)).await?;
        Ok(exists)
    }

    /// Marks `key` as compensated; an existing marker is left untouched.
    ///
    /// The marker and its expiry are written with one `SET <key> 1 NX [EX <ttl>]`
    /// command. A separate `SETNX` + `EXPIRE` pair could leave a marker without
    /// a TTL forever if the second command failed or the process stopped in
    /// between. A retry would then skip `EXPIRE`, because the key already existed.
    async fn mark_compensated(&self, key: &str) -> Result<()> {
        let redis_key = self.key(key);
        let mut cmd = redis::cmd("SET");
        cmd.arg(&redis_key).arg("1").arg("NX");
        match self.ttl_seconds {
            // Redis rejects `EX 0`. A zero TTL would expire the marker
            // immediately anyway, so nothing durable is recorded.
            Some(0) => return Ok(()),
            Some(ttl_seconds) => {
                cmd.arg("EX").arg(ttl_seconds);
            }
            None => {}
        }
        let mut conn = self.manager.clone();
        // Reply is `OK` when written, nil when the key already existed. Both are
        // success; server errors still surface as `Err`.
        let _reply: redis::Value = cmd.query_async(&mut conn).await?;
        Ok(())
    }
}
406
/// Storage backend for trace events and their resume/completion state.
///
/// This module provides in-memory, Postgres (`persistence-postgres`) and
/// Redis (`persistence-redis`) implementations.
#[async_trait]
pub trait PersistenceStore: Send + Sync {
    /// Appends one event to its trace, creating the trace on first use.
    ///
    /// The bundled implementations reject an event whose `circuit` differs from
    /// the trace's circuit, and any event for a trace that already has a
    /// completion state.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()>;
    /// Returns the stored trace, or `None` when `trace_id` is unknown.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>>;
    /// Records `resume_from_step` on the trace and returns a cursor.
    ///
    /// In the bundled implementations, the cursor's `next_step` is
    /// `resume_from_step + 1` (saturating), and an unknown trace is an error.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor>;
    /// Sets the trace's completion state; an unknown trace is an error in the bundled implementations.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()>;
}
417
418#[derive(Clone)]
420pub struct PersistenceHandle {
421 inner: Arc<dyn PersistenceStore>,
422}
423
424impl std::fmt::Debug for PersistenceHandle {
425 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426 f.debug_struct("PersistenceHandle").finish_non_exhaustive()
427 }
428}
429
430impl PersistenceHandle {
431 pub fn from_store<S>(store: S) -> Self
433 where
434 S: PersistenceStore + 'static,
435 {
436 Self {
437 inner: Arc::new(store),
438 }
439 }
440
441 pub fn from_arc(store: Arc<dyn PersistenceStore>) -> Self {
443 Self { inner: store }
444 }
445
446 pub fn store(&self) -> Arc<dyn PersistenceStore> {
448 self.inner.clone()
449 }
450}
451
452#[derive(Debug, Default, Clone)]
454pub struct InMemoryPersistenceStore {
455 inner: Arc<RwLock<HashMap<String, PersistedTrace>>>,
456}
457
458impl InMemoryPersistenceStore {
459 pub fn new() -> Self {
460 Self::default()
461 }
462}
463
464#[async_trait]
465impl PersistenceStore for InMemoryPersistenceStore {
466 async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
467 let mut guard = self.inner.write().await;
468 let entry = guard
469 .entry(envelope.trace_id.clone())
470 .or_insert_with(|| PersistedTrace {
471 trace_id: envelope.trace_id.clone(),
472 circuit: envelope.circuit.clone(),
473 events: Vec::new(),
474 resumed_from_step: None,
475 completion: None,
476 });
477
478 if entry.circuit != envelope.circuit {
479 return Err(anyhow!(
480 "trace_id {} already exists for circuit {}, got {}",
481 envelope.trace_id,
482 entry.circuit,
483 envelope.circuit
484 ));
485 }
486 if entry.completion.is_some() {
487 return Err(anyhow!(
488 "trace_id {} is already completed and cannot accept new events",
489 envelope.trace_id
490 ));
491 }
492 entry.events.push(envelope);
493 entry.events.sort_by_key(|e| e.step);
494 Ok(())
495 }
496
497 async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
498 let guard = self.inner.read().await;
499 Ok(guard.get(trace_id).cloned())
500 }
501
502 async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
503 let mut guard = self.inner.write().await;
504 let trace = guard
505 .get_mut(trace_id)
506 .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
507 trace.resumed_from_step = Some(resume_from_step);
508 Ok(ResumeCursor {
509 trace_id: trace_id.to_string(),
510 next_step: resume_from_step.saturating_add(1),
511 })
512 }
513
514 async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
515 let mut guard = self.inner.write().await;
516 let trace = guard
517 .get_mut(trace_id)
518 .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
519 trace.completion = Some(completion);
520 Ok(())
521 }
522}
523
/// Postgres-backed [`PersistenceStore`] using a per-trace state table and an
/// events table keyed by `(trace_id, step)`.
///
/// NOTE(review): table names are interpolated into SQL text, so the prefix
/// must come from trusted configuration.
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresPersistenceStore {
    pool: sqlx::Pool<sqlx::Postgres>,
    events_table: String,
    state_table: String,
}

/// One row of the events table; `BIGINT` columns decode as `i64`.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresEventRow {
    trace_id: String,
    circuit: String,
    step: i64,
    outcome_kind: String,
    timestamp_ms: i64,
    payload_hash: Option<String>,
}

/// One row of the state table.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresStateRow {
    trace_id: String,
    circuit: String,
    resumed_from_step: Option<i64>,
    completion: Option<String>,
}

#[cfg(feature = "persistence-postgres")]
impl PostgresPersistenceStore {
    /// Uses the default `ranvier_persistence` table prefix.
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        Self::with_table_prefix(pool, "ranvier_persistence")
    }

    /// Uses the tables `<prefix>_events` and `<prefix>_state`.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        let events_table = format!("{}_events", prefix);
        let state_table = format!("{}_state", prefix);
        Self {
            pool,
            events_table,
            state_table,
        }
    }

    /// Creates the state table, then the events table, if they do not exist.
    pub async fn ensure_schema(&self) -> Result<()> {
        let statements = [
            format!(
                "CREATE TABLE IF NOT EXISTS {} (
                    trace_id TEXT PRIMARY KEY,
                    circuit TEXT NOT NULL,
                    resumed_from_step BIGINT NULL,
                    completion TEXT NULL
                )",
                self.state_table
            ),
            format!(
                "CREATE TABLE IF NOT EXISTS {} (
                    trace_id TEXT NOT NULL,
                    circuit TEXT NOT NULL,
                    step BIGINT NOT NULL,
                    outcome_kind TEXT NOT NULL,
                    timestamp_ms BIGINT NOT NULL,
                    payload_hash TEXT NULL,
                    PRIMARY KEY (trace_id, step)
                )",
                self.events_table
            ),
        ];
        for ddl in &statements {
            sqlx::query(ddl).execute(&self.pool).await?;
        }
        Ok(())
    }
}
600
#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl PersistenceStore for PostgresPersistenceStore {
    /// Appends one event, creating the trace's state row on first use.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - `step` or `timestamp_ms` does not fit in `BIGINT` (checked before any write);
    /// - the trace is bound to another circuit;
    /// - the trace is already completed;
    /// - any database error, including a duplicate `(trace_id, step)`.
    ///
    /// NOTE(review): the statements run outside a transaction, so a concurrent
    /// `complete` can still land between the state check and the event insert.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        // Validate conversions first, so an invalid envelope never leaves an
        // orphan state row behind.
        let step_i64 = i64::try_from(envelope.step)?;
        let ts_i64 = i64::try_from(envelope.timestamp_ms)?;

        let insert_state = format!(
            "INSERT INTO {} (trace_id, circuit, resumed_from_step, completion)
             VALUES ($1, $2, NULL, NULL)
             ON CONFLICT (trace_id) DO NOTHING",
            self.state_table
        );
        sqlx::query(&insert_state)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .execute(&self.pool)
            .await?;

        // Read circuit and completion together, so both checks see the same row.
        let read_state = format!(
            "SELECT circuit, completion FROM {} WHERE trace_id = $1",
            self.state_table
        );
        let state: Option<(String, Option<String>)> = sqlx::query_as(&read_state)
            .bind(&envelope.trace_id)
            .fetch_optional(&self.pool)
            .await?;
        let Some((existing_circuit, completion)) = state else {
            return Err(anyhow!(
                "trace_id {} state row not found after upsert",
                envelope.trace_id
            ));
        };
        if existing_circuit != envelope.circuit {
            return Err(anyhow!(
                "trace_id {} already exists for circuit {}, got {}",
                envelope.trace_id,
                existing_circuit,
                envelope.circuit
            ));
        }
        if completion.is_some() {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        let insert_event = format!(
            "INSERT INTO {} (trace_id, circuit, step, outcome_kind, timestamp_ms, payload_hash)
             VALUES ($1, $2, $3, $4, $5, $6)",
            self.events_table
        );
        sqlx::query(&insert_event)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .bind(step_i64)
            .bind(&envelope.outcome_kind)
            .bind(ts_i64)
            .bind(&envelope.payload_hash)
            .execute(&self.pool)
            .await?;
        Ok(())
    }

    /// Loads the state row and all events (ascending `step`) for `trace_id`.
    ///
    /// # Errors
    /// Fails on database errors, negative stored `BIGINT` values, or an
    /// unrecognised `completion` value.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        let state_query = format!(
            "SELECT trace_id, circuit, resumed_from_step, completion
             FROM {}
             WHERE trace_id = $1",
            self.state_table
        );
        let Some(state): Option<PostgresStateRow> = sqlx::query_as(&state_query)
            .bind(trace_id)
            .fetch_optional(&self.pool)
            .await?
        else {
            return Ok(None);
        };

        let events_query = format!(
            "SELECT trace_id, circuit, step, outcome_kind, timestamp_ms, payload_hash
             FROM {}
             WHERE trace_id = $1
             ORDER BY step ASC",
            self.events_table
        );
        let rows: Vec<PostgresEventRow> = sqlx::query_as(&events_query)
            .bind(trace_id)
            .fetch_all(&self.pool)
            .await?;

        // BIGINT is signed: reject negative values instead of wrapping them.
        let events = rows
            .into_iter()
            .map(|row| -> Result<PersistenceEnvelope> {
                Ok(PersistenceEnvelope {
                    trace_id: row.trace_id,
                    circuit: row.circuit,
                    step: u64::try_from(row.step)?,
                    outcome_kind: row.outcome_kind,
                    timestamp_ms: u64::try_from(row.timestamp_ms)?,
                    payload_hash: row.payload_hash,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        let completion = state
            .completion
            .as_deref()
            .map(completion_state_from_wire)
            .transpose()?;

        Ok(Some(PersistedTrace {
            trace_id: state.trace_id,
            circuit: state.circuit,
            events,
            resumed_from_step: state.resumed_from_step.map(u64::try_from).transpose()?,
            completion,
        }))
    }

    /// Stores the resume point; errors if the trace has no state row.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let query = format!(
            "UPDATE {}
             SET resumed_from_step = $2
             WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(i64::try_from(resume_from_step)?)
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(ResumeCursor {
            trace_id: trace_id.to_string(),
            next_step: resume_from_step.saturating_add(1),
        })
    }

    /// Sets (or overwrites) the completion state; errors if the trace has no state row.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let query = format!(
            "UPDATE {}
             SET completion = $2
             WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(completion_state_to_wire(&completion))
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(())
    }
}
760
/// Redis-backed [`PersistenceStore`] storing each trace as one JSON string value.
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisPersistenceStore {
    manager: redis::aio::ConnectionManager,
    key_prefix: String,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisPersistenceStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The connection manager is omitted; only the key prefix is printed.
        let mut out = f.debug_struct("RedisPersistenceStore");
        out.field("key_prefix", &self.key_prefix);
        out.finish_non_exhaustive()
    }
}

#[cfg(feature = "persistence-redis")]
impl RedisPersistenceStore {
    /// Connects to `url` using the default `ranvier:persistence` key prefix.
    pub async fn connect(url: &str) -> Result<Self> {
        let client = redis::Client::open(url)?;
        let manager = redis::aio::ConnectionManager::new(client).await?;
        Ok(Self::with_prefix(manager, "ranvier:persistence"))
    }

    /// Builds a store over an existing connection with a custom key prefix.
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        Self {
            key_prefix: key_prefix.into(),
            manager,
        }
    }

    /// Redis key holding a trace: `<prefix>:<trace_id>`.
    fn key(&self, trace_id: &str) -> String {
        let mut key = String::with_capacity(self.key_prefix.len() + 1 + trace_id.len());
        key.push_str(&self.key_prefix);
        key.push(':');
        key.push_str(trace_id);
        key
    }

    /// Serialises the whole trace to JSON and overwrites its key.
    async fn write_trace(&self, trace: &PersistedTrace) -> Result<()> {
        use redis::AsyncCommands;
        let target = self.key(&trace.trace_id);
        let json = serde_json::to_string(trace)?;
        let mut conn = self.manager.clone();
        conn.set::<_, _, ()>(target, json).await?;
        Ok(())
    }
}
814
/// Every mutating method here is a read-modify-write of the whole JSON value
/// (`GET` then `SET`), and the two commands are not atomic.
/// NOTE(review): concurrent writers to the same trace can lose updates.
#[cfg(feature = "persistence-redis")]
#[async_trait]
impl PersistenceStore for RedisPersistenceStore {
    /// Appends an event and rewrites the trace, keeping events sorted by `step`.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        let mut trace = match self.load(&envelope.trace_id).await? {
            Some(existing) => existing,
            None => PersistedTrace {
                trace_id: envelope.trace_id.clone(),
                circuit: envelope.circuit.clone(),
                events: Vec::new(),
                resumed_from_step: None,
                completion: None,
            },
        };

        if trace.circuit != envelope.circuit {
            return Err(anyhow!(
                "trace_id {} already exists for circuit {}, got {}",
                envelope.trace_id,
                trace.circuit,
                envelope.circuit
            ));
        }
        if trace.completion.is_some() {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        trace.events.push(envelope);
        // Stable sort: events with equal steps keep their stored order.
        trace.events.sort_by(|left, right| left.step.cmp(&right.step));
        self.write_trace(&trace).await
    }

    /// Reads and deserialises the trace's JSON value, if the key exists.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        use redis::AsyncCommands;
        let source = self.key(trace_id);
        let mut conn = self.manager.clone();
        let payload: Option<String> = conn.get(source).await?;
        match payload {
            Some(raw) => Ok(Some(serde_json::from_str::<PersistedTrace>(&raw)?)),
            None => Ok(None),
        }
    }

    /// Stores the resume point; errors if the trace does not exist.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.resumed_from_step = Some(resume_from_step);
        self.write_trace(&trace).await?;
        let next_step = resume_from_step.saturating_add(1);
        Ok(ResumeCursor {
            trace_id: trace_id.to_string(),
            next_step,
        })
    }

    /// Sets (or overwrites) the completion state; errors if the trace does not exist.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.completion = Some(completion);
        self.write_trace(&trace).await
    }
}
885
// Unit tests for the in-memory stores, plus integration tests for the
// Postgres/Redis backends. The backend tests run only when their feature is
// enabled AND the matching RANVIER_PERSISTENCE_*_URL env var is set; otherwise
// they return early and pass.
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(feature = "persistence-postgres", feature = "persistence-redis"))]
    use uuid::Uuid;

    // Builds an envelope for "trace-1"/"OrderCircuit" whose timestamp and hash
    // are derived from `step`.
    fn envelope(step: u64, outcome_kind: &str) -> PersistenceEnvelope {
        PersistenceEnvelope {
            trace_id: "trace-1".to_string(),
            circuit: "OrderCircuit".to_string(),
            step,
            outcome_kind: outcome_kind.to_string(),
            timestamp_ms: 1_700_000_000_000 + step,
            payload_hash: Some(format!("hash-{}", step)),
        }
    }

    #[tokio::test]
    async fn append_and_load_roundtrip() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();
        store.append(envelope(2, "Branch")).await.unwrap();

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, "trace-1");
        assert_eq!(loaded.circuit, "OrderCircuit");
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.events[0].step, 1);
        assert_eq!(loaded.events[1].outcome_kind, "Branch");
        assert_eq!(loaded.completion, None);
    }

    #[tokio::test]
    async fn resume_records_cursor() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(3, "Fault")).await.unwrap();

        // The cursor points at the step after the resume point.
        let cursor = store.resume("trace-1", 3).await.unwrap();
        assert_eq!(
            cursor,
            ResumeCursor {
                trace_id: "trace-1".to_string(),
                next_step: 4
            }
        );

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.resumed_from_step, Some(3));
    }

    #[tokio::test]
    async fn complete_marks_trace_and_blocks_append() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();
        store
            .complete("trace-1", CompletionState::Success)
            .await
            .unwrap();

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.completion, Some(CompletionState::Success));

        // Completed traces are sealed against further events.
        let err = store.append(envelope(2, "Next")).await.unwrap_err();
        assert!(
            err.to_string()
                .contains("is already completed and cannot accept new events")
        );
    }

    #[tokio::test]
    async fn append_rejects_cross_circuit_trace_reuse() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();

        let mut invalid = envelope(2, "Next");
        invalid.circuit = "AnotherCircuit".to_string();
        let err = store.append(invalid).await.unwrap_err();
        assert!(
            err.to_string()
                .contains("already exists for circuit OrderCircuit")
        );
    }

    #[tokio::test]
    async fn in_memory_compensation_idempotency_roundtrip() {
        let store = InMemoryCompensationIdempotencyStore::new();
        let key = "trace-a:OrderFlow:Fault";

        assert!(!store.was_compensated(key).await.unwrap());
        store.mark_compensated(key).await.unwrap();
        assert!(store.was_compensated(key).await.unwrap());
    }

    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_store_roundtrip_when_configured() {
        // Skipped silently when no database URL is configured.
        let url = match std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&url)
            .await
            .unwrap();
        // A unique table prefix isolates this run from other test runs.
        let table_prefix = format!("ranvier_persistence_test_{}", Uuid::new_v4().simple());
        let store = PostgresPersistenceStore::with_table_prefix(pool.clone(), table_prefix.clone());
        store.ensure_schema().await.unwrap();

        let trace_id = format!("trace-{}", Uuid::new_v4().simple());
        let circuit = "PgCircuit".to_string();

        let mut first = envelope(1, "Next");
        first.trace_id = trace_id.clone();
        first.circuit = circuit.clone();
        store.append(first).await.unwrap();

        let mut second = envelope(2, "Branch");
        second.trace_id = trace_id.clone();
        second.circuit = circuit.clone();
        store.append(second).await.unwrap();

        let cursor = store.resume(&trace_id, 2).await.unwrap();
        assert_eq!(cursor.next_step, 3);

        store
            .complete(&trace_id, CompletionState::Compensated)
            .await
            .unwrap();

        let loaded = store.load(&trace_id).await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, trace_id);
        assert_eq!(loaded.circuit, circuit);
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.resumed_from_step, Some(2));
        assert_eq!(loaded.completion, Some(CompletionState::Compensated));

        // Cleanup: drop the per-run tables (skipped if an assertion above panicked).
        let drop_events = format!("DROP TABLE IF EXISTS {}", store.events_table);
        let drop_state = format!("DROP TABLE IF EXISTS {}", store.state_table);
        sqlx::query(&drop_events).execute(&pool).await.unwrap();
        sqlx::query(&drop_state).execute(&pool).await.unwrap();
    }

    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_store_roundtrip_when_configured() {
        // Skipped silently when no Redis URL is configured.
        let url = match std::env::var("RANVIER_PERSISTENCE_REDIS_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        // Reuse the connection with a unique prefix so runs do not collide.
        let base = RedisPersistenceStore::connect(&url).await.unwrap();
        let prefix = format!("ranvier:persistence:test:{}", Uuid::new_v4().simple());
        let store = RedisPersistenceStore::with_prefix(base.manager.clone(), prefix);

        let trace_id = format!("trace-{}", Uuid::new_v4().simple());
        let circuit = "RedisCircuit".to_string();

        let mut first = envelope(1, "Next");
        first.trace_id = trace_id.clone();
        first.circuit = circuit.clone();
        store.append(first).await.unwrap();

        let mut second = envelope(2, "Fault");
        second.trace_id = trace_id.clone();
        second.circuit = circuit.clone();
        store.append(second).await.unwrap();

        let cursor = store.resume(&trace_id, 2).await.unwrap();
        assert_eq!(cursor.next_step, 3);

        store
            .complete(&trace_id, CompletionState::Fault)
            .await
            .unwrap();

        let loaded = store.load(&trace_id).await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, trace_id);
        assert_eq!(loaded.circuit, circuit);
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.resumed_from_step, Some(2));
        assert_eq!(loaded.completion, Some(CompletionState::Fault));

        // Cleanup: remove the trace key written by this run.
        use redis::AsyncCommands;
        let key = store.key(&trace_id);
        let mut conn = store.manager.clone();
        let _: () = conn.del(key).await.unwrap();
    }

    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_compensation_idempotency_roundtrip_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&url)
            .await
            .unwrap();
        let table_prefix = format!(
            "ranvier_compensation_idempotency_test_{}",
            Uuid::new_v4().simple()
        );
        let store =
            PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
        store.ensure_schema().await.unwrap();

        let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());
        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
        // Marking twice must succeed (ON CONFLICT DO NOTHING).
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        let drop_table = format!("DROP TABLE IF EXISTS {}", store.table);
        sqlx::query(&drop_table).execute(&pool).await.unwrap();
    }

    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_compensation_idempotency_purge_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&url)
            .await
            .unwrap();
        let table_prefix = format!(
            "ranvier_compensation_idempotency_purge_test_{}",
            Uuid::new_v4().simple()
        );
        let store =
            PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
        store.ensure_schema().await.unwrap();

        let stale_key = format!("stale-{}", Uuid::new_v4().simple());
        let fresh_key = format!("fresh-{}", Uuid::new_v4().simple());
        store.mark_compensated(&stale_key).await.unwrap();
        store.mark_compensated(&fresh_key).await.unwrap();

        // Backdate one marker to epoch 0 so the cutoff of 1 ms purges only it.
        let force_stale_query = format!(
            "UPDATE {}
             SET created_at_ms = 0
             WHERE idempotency_key = $1",
            store.table
        );
        sqlx::query(&force_stale_query)
            .bind(&stale_key)
            .execute(&pool)
            .await
            .unwrap();

        let purged = store.purge_older_than_ms(1).await.unwrap();
        assert!(purged >= 1);
        assert!(!store.was_compensated(&stale_key).await.unwrap());
        assert!(store.was_compensated(&fresh_key).await.unwrap());

        let drop_table = format!("DROP TABLE IF EXISTS {}", store.table);
        sqlx::query(&drop_table).execute(&pool).await.unwrap();
    }

    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_compensation_idempotency_roundtrip_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_REDIS_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let base = RedisCompensationIdempotencyStore::connect(&url)
            .await
            .unwrap();
        let prefix = format!(
            "ranvier:compensation:idempotency:test:{}",
            Uuid::new_v4().simple()
        );
        let store = RedisCompensationIdempotencyStore::with_prefix(base.manager.clone(), prefix);
        let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());

        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
        // A second mark on an existing key must still succeed.
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        use redis::AsyncCommands;
        let mut conn = store.manager.clone();
        let _: () = conn.del(store.key(&key)).await.unwrap();
    }

    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_compensation_idempotency_ttl_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_REDIS_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let base = RedisCompensationIdempotencyStore::connect(&url)
            .await
            .unwrap();
        let prefix = format!(
            "ranvier:compensation:idempotency:ttl:test:{}",
            Uuid::new_v4().simple()
        );
        // One-second TTL; the key needs no explicit cleanup because it expires.
        let store =
            RedisCompensationIdempotencyStore::with_prefix_and_ttl(base.manager.clone(), prefix, 1);
        let key = format!("ttl-{}", Uuid::new_v4().simple());

        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        // Wait past the TTL and confirm the marker has expired.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(!store.was_compensated(&key).await.unwrap());
    }
}