1use anyhow::{Result, anyhow};
2use async_trait::async_trait;
3use serde::{Deserialize, Serialize};
4use std::collections::{HashMap, HashSet};
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// One persisted step of a circuit execution trace.
///
/// Envelopes are written through [`PersistenceStore::append`] and grouped by `trace_id`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistenceEnvelope {
    /// Identifier of the trace this event belongs to.
    pub trace_id: String,
    /// Circuit that produced the event; stores reject a trace reused for another circuit.
    pub circuit: String,
    /// Step number; stores in this module return a trace's events ordered by it.
    pub step: u64,
    /// Outcome label (this module's tests use values such as "Next", "Branch", "Fault").
    pub outcome_kind: String,
    /// Event timestamp in milliseconds (presumably Unix-epoch based — confirm with producers).
    pub timestamp_ms: u64,
    /// Optional hash of the step payload; stored opaquely by this module.
    pub payload_hash: Option<String>,
}
26
/// Terminal state recorded for a trace via [`PersistenceStore::complete`].
///
/// Once a trace has a completion state, the stores in this module refuse further appends.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompletionState {
    /// The trace finished successfully.
    Success,
    /// The trace ended in a fault.
    Fault,
    /// The trace was cancelled.
    Cancelled,
    /// The trace faulted and compensation was applied.
    Compensated,
}
35
/// Encodes a [`CompletionState`] as the lowercase text stored in the Postgres `completion`
/// column. Inverse of `completion_state_from_wire`.
#[cfg(feature = "persistence-postgres")]
fn completion_state_to_wire(state: &CompletionState) -> &'static str {
    // Kept exhaustive (no `_` arm) so a new variant forces a new wire value here.
    match *state {
        CompletionState::Success => "success",
        CompletionState::Fault => "fault",
        CompletionState::Cancelled => "cancelled",
        CompletionState::Compensated => "compensated",
    }
}
45
/// Decodes the lowercase text stored in the Postgres `completion` column back into a
/// [`CompletionState`]. Inverse of `completion_state_to_wire`.
///
/// # Errors
/// Returns an error for any string that is not one of the four known wire values.
#[cfg(feature = "persistence-postgres")]
fn completion_state_from_wire(value: &str) -> Result<CompletionState> {
    let state = match value {
        "success" => CompletionState::Success,
        "fault" => CompletionState::Fault,
        "cancelled" => CompletionState::Cancelled,
        "compensated" => CompletionState::Compensated,
        unknown => return Err(anyhow!("unknown completion state value: {}", unknown)),
    };
    Ok(state)
}
56
/// Full stored view of a trace, as returned by [`PersistenceStore::load`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PersistedTrace {
    /// Trace identifier.
    pub trace_id: String,
    /// Circuit the trace is bound to (set by the first appended envelope).
    pub circuit: String,
    /// Appended events, ordered by `step`.
    pub events: Vec<PersistenceEnvelope>,
    /// Step most recently passed to [`PersistenceStore::resume`], if any.
    pub resumed_from_step: Option<u64>,
    /// Terminal state, if [`PersistenceStore::complete`] has been called.
    pub completion: Option<CompletionState>,
}
66
/// Result of [`PersistenceStore::resume`]: where execution should continue.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ResumeCursor {
    /// Trace being resumed.
    pub trace_id: String,
    /// Step to execute next; the stores here set it to `resume_from_step + 1` (saturating).
    pub next_step: u64,
}
73
/// Newtype wrapper around a trace identifier string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistenceTraceId(pub String);

impl PersistenceTraceId {
    /// Builds an id from anything convertible into a `String` (e.g. `&str` or `String`).
    pub fn new(value: impl Into<String>) -> Self {
        let id: String = value.into();
        PersistenceTraceId(id)
    }
}
85
/// Boolean newtype flag; presumably toggles automatic trace completion — its consumer is
/// outside this file, confirm semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PersistenceAutoComplete(pub bool);
91
/// Information handed to a [`CompensationHook`] describing the fault to compensate.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CompensationContext {
    /// Trace whose execution faulted.
    pub trace_id: String,
    /// Circuit the trace belongs to.
    pub circuit: String,
    /// Label of the fault (format defined by the caller — confirm).
    pub fault_kind: String,
    /// Step at which the fault occurred.
    pub fault_step: u64,
    /// Timestamp in milliseconds (presumably Unix-epoch based — confirm with callers).
    pub timestamp_ms: u64,
}
104
/// Boolean newtype flag; presumably toggles automatic invocation of compensation on fault —
/// its consumer is outside this file, confirm semantics there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationAutoTrigger(pub bool);
110
/// Retry settings for running a compensation hook.
///
/// The retry loop itself lives outside this file; the field meanings below follow their names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompensationRetryPolicy {
    /// Maximum number of attempts (presumably including the first one — confirm with caller).
    pub max_attempts: u32,
    /// Delay between attempts, in milliseconds.
    pub backoff_ms: u64,
}

impl Default for CompensationRetryPolicy {
    /// One attempt, no backoff.
    fn default() -> Self {
        CompensationRetryPolicy {
            backoff_ms: 0,
            max_attempts: 1,
        }
    }
}
128
/// User-supplied action that undoes or compensates for a faulted trace.
#[async_trait]
pub trait CompensationHook: Send + Sync {
    /// Runs compensation for the fault described by `context`.
    ///
    /// # Errors
    /// Implementations return an error when compensation fails.
    async fn compensate(&self, context: CompensationContext) -> Result<()>;
}
138
139#[derive(Clone)]
141pub struct CompensationHandle {
142 inner: Arc<dyn CompensationHook>,
143}
144
145impl std::fmt::Debug for CompensationHandle {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("CompensationHandle").finish_non_exhaustive()
148 }
149}
150
151impl CompensationHandle {
152 pub fn from_hook<H>(hook: H) -> Self
154 where
155 H: CompensationHook + 'static,
156 {
157 Self {
158 inner: Arc::new(hook),
159 }
160 }
161
162 pub fn from_arc(hook: Arc<dyn CompensationHook>) -> Self {
164 Self { inner: hook }
165 }
166
167 pub fn hook(&self) -> Arc<dyn CompensationHook> {
169 self.inner.clone()
170 }
171}
172
/// Records which compensation keys have already been handled, so compensation can be skipped
/// when a key was already marked.
#[async_trait]
pub trait CompensationIdempotencyStore: Send + Sync {
    /// Returns `true` if `key` has been marked (and, for TTL-backed stores, not yet expired).
    async fn was_compensated(&self, key: &str) -> Result<bool>;
    /// Marks `key` as compensated; the implementations in this file treat repeats as no-ops.
    async fn mark_compensated(&self, key: &str) -> Result<()>;
}
185
186#[derive(Clone)]
188pub struct CompensationIdempotencyHandle {
189 inner: Arc<dyn CompensationIdempotencyStore>,
190}
191
192impl std::fmt::Debug for CompensationIdempotencyHandle {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("CompensationIdempotencyHandle")
195 .finish_non_exhaustive()
196 }
197}
198
199impl CompensationIdempotencyHandle {
200 pub fn from_store<S>(store: S) -> Self
201 where
202 S: CompensationIdempotencyStore + 'static,
203 {
204 Self {
205 inner: Arc::new(store),
206 }
207 }
208
209 pub fn from_arc(store: Arc<dyn CompensationIdempotencyStore>) -> Self {
210 Self { inner: store }
211 }
212
213 pub fn store(&self) -> Arc<dyn CompensationIdempotencyStore> {
214 self.inner.clone()
215 }
216}
217
218#[derive(Debug, Default, Clone)]
220pub struct InMemoryCompensationIdempotencyStore {
221 keys: Arc<RwLock<HashSet<String>>>,
222}
223
224impl InMemoryCompensationIdempotencyStore {
225 pub fn new() -> Self {
226 Self::default()
227 }
228}
229
230#[async_trait]
231impl CompensationIdempotencyStore for InMemoryCompensationIdempotencyStore {
232 async fn was_compensated(&self, key: &str) -> Result<bool> {
233 let guard = self.keys.read().await;
234 Ok(guard.contains(key))
235 }
236
237 async fn mark_compensated(&self, key: &str) -> Result<()> {
238 let mut guard = self.keys.write().await;
239 guard.insert(key.to_string());
240 Ok(())
241 }
242}
243
/// Postgres-backed idempotency store: one row per key in a single table.
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresCompensationIdempotencyStore {
    /// Connection pool used for every query.
    pool: sqlx::Pool<sqlx::Postgres>,
    /// Table name, `<prefix>_compensation_idempotency`; interpolated into SQL unescaped.
    table: String,
}
250
#[cfg(feature = "persistence-postgres")]
impl PostgresCompensationIdempotencyStore {
    /// Creates a store using the default prefix, i.e. the table
    /// `ranvier_persistence_compensation_idempotency`.
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        Self::with_table_prefix(pool, "ranvier_persistence")
    }

    /// Creates a store whose table is named `<prefix>_compensation_idempotency`.
    ///
    /// NOTE(review): the name is spliced into SQL text unescaped; only pass trusted identifiers.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let table = format!("{}_compensation_idempotency", prefix.into());
        Self { pool, table }
    }

    /// Creates the idempotency table if it does not exist yet.
    pub async fn ensure_schema(&self) -> Result<()> {
        let ddl = format!(
            "CREATE TABLE IF NOT EXISTS {} (
                idempotency_key TEXT PRIMARY KEY,
                created_at_ms BIGINT NOT NULL
            )",
            self.table
        );
        sqlx::query(&ddl).execute(&self.pool).await?;
        Ok(())
    }

    /// Deletes keys whose `created_at_ms` is strictly below `cutoff_ms`.
    ///
    /// Returns the number of rows deleted.
    pub async fn purge_older_than_ms(&self, cutoff_ms: i64) -> Result<u64> {
        let sql = format!(
            "DELETE FROM {}
                WHERE created_at_ms < $1",
            self.table
        );
        let outcome = sqlx::query(&sql)
            .bind(cutoff_ms)
            .execute(&self.pool)
            .await?;
        Ok(outcome.rows_affected())
    }
}
295
#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl CompensationIdempotencyStore for PostgresCompensationIdempotencyStore {
    /// Returns `true` when a row for `key` exists in the table.
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        let sql = format!(
            "SELECT 1
                FROM {}
                WHERE idempotency_key = $1
                LIMIT 1",
            self.table
        );
        let hit: Option<i32> = sqlx::query_scalar(&sql)
            .bind(key)
            .fetch_optional(&self.pool)
            .await?;
        Ok(hit.is_some())
    }

    /// Inserts `key` stamped with the current wall-clock time in milliseconds since the Unix
    /// epoch. `ON CONFLICT DO NOTHING` makes repeat marks a no-op and keeps the original stamp.
    ///
    /// # Errors
    /// Fails if the clock is before the epoch, the millisecond count overflows `i64`, or the
    /// insert fails.
    async fn mark_compensated(&self, key: &str) -> Result<()> {
        let sql = format!(
            "INSERT INTO {} (idempotency_key, created_at_ms)
                VALUES ($1, $2)
                ON CONFLICT (idempotency_key) DO NOTHING",
            self.table
        );
        let since_epoch =
            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?;
        let created_at_ms = i64::try_from(since_epoch.as_millis())?;
        sqlx::query(&sql)
            .bind(key)
            .bind(created_at_ms)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
}
332
/// Redis-backed idempotency store: each key is stored as `<key_prefix>:<idempotency_key>`,
/// optionally with a TTL.
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisCompensationIdempotencyStore {
    manager: redis::aio::ConnectionManager,
    key_prefix: String,
    ttl_seconds: Option<u64>,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisCompensationIdempotencyStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The connection manager is omitted; only configuration is printed.
        let mut out = f.debug_struct("RedisCompensationIdempotencyStore");
        out.field("key_prefix", &self.key_prefix);
        out.field("ttl_seconds", &self.ttl_seconds);
        out.finish_non_exhaustive()
    }
}

#[cfg(feature = "persistence-redis")]
impl RedisCompensationIdempotencyStore {
    /// Connects to `url` and uses the default prefix `ranvier:compensation:idempotency`,
    /// with no TTL.
    pub async fn connect(url: &str) -> Result<Self> {
        let manager = redis::aio::ConnectionManager::new(redis::Client::open(url)?).await?;
        Ok(Self::with_prefix(manager, "ranvier:compensation:idempotency"))
    }

    /// Uses an existing connection manager with a custom key prefix and no TTL.
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        RedisCompensationIdempotencyStore {
            manager,
            key_prefix: key_prefix.into(),
            ttl_seconds: None,
        }
    }

    /// Like [`Self::with_prefix`], but marks expire after `ttl_seconds`.
    pub fn with_prefix_and_ttl(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
        ttl_seconds: u64,
    ) -> Self {
        Self {
            ttl_seconds: Some(ttl_seconds),
            ..Self::with_prefix(manager, key_prefix)
        }
    }

    /// Builds the full Redis key for an idempotency key.
    fn key(&self, idempotency_key: &str) -> String {
        format!("{}:{}", self.key_prefix, idempotency_key)
    }
}
391
#[cfg(feature = "persistence-redis")]
#[async_trait]
impl CompensationIdempotencyStore for RedisCompensationIdempotencyStore {
    /// Returns `true` when the prefixed key currently exists (expired keys read as absent).
    async fn was_compensated(&self, key: &str) -> Result<bool> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let exists: bool = conn.exists(self.key(key)).await?;
        Ok(exists)
    }

    /// Marks `key` as compensated.
    ///
    /// Issues a single `SET <key> 1 NX [EX <ttl>]`, so the write and its expiry are applied
    /// atomically. A separate `SETNX` + `EXPIRE` pair could leave a mark with no TTL (one that
    /// never expires) if the connection failed between the two commands. `NX` leaves an
    /// existing mark, and its remaining TTL, untouched, so repeated marks are no-ops.
    ///
    /// # Errors
    /// Returns an error if the configured TTL is zero (Redis rejects `EX 0`, and a zero TTL
    /// would make the mark useless) or if the Redis command fails.
    async fn mark_compensated(&self, key: &str) -> Result<()> {
        let mut conn = self.manager.clone();
        let mut cmd = redis::cmd("SET");
        cmd.arg(self.key(key)).arg("1").arg("NX");
        if let Some(ttl_seconds) = self.ttl_seconds {
            if ttl_seconds == 0 {
                return Err(anyhow!(
                    "compensation idempotency ttl_seconds must be greater than zero"
                ));
            }
            cmd.arg("EX").arg(ttl_seconds);
        }
        // Reply is "OK" when written and nil when the key already existed; both are success.
        let _: redis::Value = cmd.query_async(&mut conn).await?;
        Ok(())
    }
}
416
/// Durable storage for circuit execution traces.
#[async_trait]
pub trait PersistenceStore: Send + Sync {
    /// Appends an event, creating the trace on first use. The implementations in this file
    /// reject events for a trace bound to another circuit or for a completed trace.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()>;
    /// Loads a trace, or `Ok(None)` when it does not exist.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>>;
    /// Records `resume_from_step` on an existing trace and returns the next step to run.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor>;
    /// Sets the terminal state of an existing trace.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()>;
}
444
445#[derive(Clone)]
447pub struct PersistenceHandle {
448 inner: Arc<dyn PersistenceStore>,
449}
450
451impl std::fmt::Debug for PersistenceHandle {
452 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
453 f.debug_struct("PersistenceHandle").finish_non_exhaustive()
454 }
455}
456
457impl PersistenceHandle {
458 pub fn from_store<S>(store: S) -> Self
460 where
461 S: PersistenceStore + 'static,
462 {
463 Self {
464 inner: Arc::new(store),
465 }
466 }
467
468 pub fn from_arc(store: Arc<dyn PersistenceStore>) -> Self {
470 Self { inner: store }
471 }
472
473 pub fn store(&self) -> Arc<dyn PersistenceStore> {
475 self.inner.clone()
476 }
477}
478
479#[derive(Debug, Default, Clone)]
481pub struct InMemoryPersistenceStore {
482 inner: Arc<RwLock<HashMap<String, PersistedTrace>>>,
483}
484
485impl InMemoryPersistenceStore {
486 pub fn new() -> Self {
487 Self::default()
488 }
489}
490
491#[async_trait]
492impl PersistenceStore for InMemoryPersistenceStore {
493 async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
494 let mut guard = self.inner.write().await;
495 let entry = guard
496 .entry(envelope.trace_id.clone())
497 .or_insert_with(|| PersistedTrace {
498 trace_id: envelope.trace_id.clone(),
499 circuit: envelope.circuit.clone(),
500 events: Vec::new(),
501 resumed_from_step: None,
502 completion: None,
503 });
504
505 if entry.circuit != envelope.circuit {
506 return Err(anyhow!(
507 "trace_id {} already exists for circuit {}, got {}",
508 envelope.trace_id,
509 entry.circuit,
510 envelope.circuit
511 ));
512 }
513 if entry.completion.is_some() {
514 return Err(anyhow!(
515 "trace_id {} is already completed and cannot accept new events",
516 envelope.trace_id
517 ));
518 }
519 entry.events.push(envelope);
520 entry.events.sort_by_key(|e| e.step);
521 Ok(())
522 }
523
524 async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
525 let guard = self.inner.read().await;
526 Ok(guard.get(trace_id).cloned())
527 }
528
529 async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
530 let mut guard = self.inner.write().await;
531 let trace = guard
532 .get_mut(trace_id)
533 .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
534 trace.resumed_from_step = Some(resume_from_step);
535 Ok(ResumeCursor {
536 trace_id: trace_id.to_string(),
537 next_step: resume_from_step.saturating_add(1),
538 })
539 }
540
541 async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
542 let mut guard = self.inner.write().await;
543 let trace = guard
544 .get_mut(trace_id)
545 .ok_or_else(|| anyhow!("trace_id {} not found", trace_id))?;
546 trace.completion = Some(completion);
547 Ok(())
548 }
549}
550
/// Postgres-backed [`PersistenceStore`] using one per-trace state table and one events table.
#[cfg(feature = "persistence-postgres")]
#[derive(Debug, Clone)]
pub struct PostgresPersistenceStore {
    /// Connection pool used for every query.
    pool: sqlx::Pool<sqlx::Postgres>,
    /// Events table, `<prefix>_events`; interpolated into SQL unescaped.
    events_table: String,
    /// Per-trace state table, `<prefix>_state`; interpolated into SQL unescaped.
    state_table: String,
}
558
/// Raw row of the events table; integers are `i64` because Postgres `BIGINT` is signed.
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresEventRow {
    trace_id: String,
    circuit: String,
    step: i64,
    outcome_kind: String,
    timestamp_ms: i64,
    payload_hash: Option<String>,
}
569
/// Raw row of the state table; `completion` holds the lowercase wire form of
/// [`CompletionState`].
#[cfg(feature = "persistence-postgres")]
#[derive(sqlx::FromRow)]
struct PostgresStateRow {
    trace_id: String,
    circuit: String,
    resumed_from_step: Option<i64>,
    completion: Option<String>,
}
578
#[cfg(feature = "persistence-postgres")]
impl PostgresPersistenceStore {
    /// Creates a store with the default prefix (`ranvier_persistence_events` /
    /// `ranvier_persistence_state`).
    pub fn new(pool: sqlx::Pool<sqlx::Postgres>) -> Self {
        Self::with_table_prefix(pool, "ranvier_persistence")
    }

    /// Creates a store using `<prefix>_events` and `<prefix>_state`.
    ///
    /// NOTE(review): table names are spliced into SQL text unescaped; only pass trusted
    /// identifiers.
    pub fn with_table_prefix(pool: sqlx::Pool<sqlx::Postgres>, prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        let events_table = format!("{}_events", prefix);
        let state_table = format!("{}_state", prefix);
        Self {
            pool,
            events_table,
            state_table,
        }
    }

    /// Creates the state table, then the events table, if they do not exist yet.
    /// Stops at the first failing statement.
    pub async fn ensure_schema(&self) -> Result<()> {
        let statements = [
            format!(
                "CREATE TABLE IF NOT EXISTS {} (
                    trace_id TEXT PRIMARY KEY,
                    circuit TEXT NOT NULL,
                    resumed_from_step BIGINT NULL,
                    completion TEXT NULL
                )",
                self.state_table
            ),
            format!(
                "CREATE TABLE IF NOT EXISTS {} (
                    trace_id TEXT NOT NULL,
                    circuit TEXT NOT NULL,
                    step BIGINT NOT NULL,
                    outcome_kind TEXT NOT NULL,
                    timestamp_ms BIGINT NOT NULL,
                    payload_hash TEXT NULL,
                    PRIMARY KEY (trace_id, step)
                )",
                self.events_table
            ),
        ];
        for ddl in &statements {
            sqlx::query(ddl).execute(&self.pool).await?;
        }
        Ok(())
    }
}
629
#[cfg(feature = "persistence-postgres")]
#[async_trait]
impl PersistenceStore for PostgresPersistenceStore {
    /// Appends `envelope`, creating the trace's state row on first use.
    ///
    /// Everything runs in one transaction:
    /// - the state-row upsert;
    /// - the circuit/completion check, which takes a `FOR UPDATE` row lock;
    /// - the event insert.
    ///
    /// A concurrent [`PersistenceStore::complete`] therefore cannot land between the check and
    /// the insert. A failure rolls back the state row as well, because the transaction is
    /// dropped without commit.
    ///
    /// # Errors
    /// Fails if `step`/`timestamp_ms` exceed `i64::MAX`, or if the trace is bound to another
    /// circuit. Also fails if the trace is completed, if a `(trace_id, step)` row already exists
    /// (primary-key violation), or on any database error.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        // Convert first so an out-of-range value fails before anything is written.
        let step_i64 = i64::try_from(envelope.step)?;
        let ts_i64 = i64::try_from(envelope.timestamp_ms)?;

        let mut tx = self.pool.begin().await?;

        let insert_state = format!(
            "INSERT INTO {} (trace_id, circuit, resumed_from_step, completion)
                VALUES ($1, $2, NULL, NULL)
                ON CONFLICT (trace_id) DO NOTHING",
            self.state_table
        );
        sqlx::query(&insert_state)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .execute(&mut *tx)
            .await?;

        // Lock the state row so circuit/completion cannot change until this tx ends.
        let read_state = format!(
            "SELECT circuit, completion FROM {} WHERE trace_id = $1 FOR UPDATE",
            self.state_table
        );
        let state: Option<(String, Option<String>)> = sqlx::query_as(&read_state)
            .bind(&envelope.trace_id)
            .fetch_optional(&mut *tx)
            .await?;
        let Some((existing_circuit, completion)) = state else {
            return Err(anyhow!(
                "trace_id {} state row missing after upsert",
                envelope.trace_id
            ));
        };
        if existing_circuit != envelope.circuit {
            return Err(anyhow!(
                "trace_id {} already exists for circuit {}, got {}",
                envelope.trace_id,
                existing_circuit,
                envelope.circuit
            ));
        }
        if completion.is_some() {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        let insert_event = format!(
            "INSERT INTO {} (trace_id, circuit, step, outcome_kind, timestamp_ms, payload_hash)
                VALUES ($1, $2, $3, $4, $5, $6)",
            self.events_table
        );
        sqlx::query(&insert_event)
            .bind(&envelope.trace_id)
            .bind(&envelope.circuit)
            .bind(step_i64)
            .bind(&envelope.outcome_kind)
            .bind(ts_i64)
            .bind(&envelope.payload_hash)
            .execute(&mut *tx)
            .await?;

        tx.commit().await?;
        Ok(())
    }

    /// Loads the state row and its events (ordered by `step`), or `Ok(None)` if no state row
    /// exists.
    ///
    /// # Errors
    /// Fails on database errors, on negative stored integers, or on an unknown completion
    /// value.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        let state_query = format!(
            "SELECT trace_id, circuit, resumed_from_step, completion
                FROM {}
                WHERE trace_id = $1",
            self.state_table
        );
        let Some(state): Option<PostgresStateRow> = sqlx::query_as(&state_query)
            .bind(trace_id)
            .fetch_optional(&self.pool)
            .await?
        else {
            return Ok(None);
        };

        let events_query = format!(
            "SELECT trace_id, circuit, step, outcome_kind, timestamp_ms, payload_hash
                FROM {}
                WHERE trace_id = $1
                ORDER BY step ASC",
            self.events_table
        );
        let rows: Vec<PostgresEventRow> = sqlx::query_as(&events_query)
            .bind(trace_id)
            .fetch_all(&self.pool)
            .await?;

        let events = rows
            .into_iter()
            .map(|row| -> Result<PersistenceEnvelope> {
                Ok(PersistenceEnvelope {
                    trace_id: row.trace_id,
                    circuit: row.circuit,
                    step: u64::try_from(row.step)?,
                    outcome_kind: row.outcome_kind,
                    timestamp_ms: u64::try_from(row.timestamp_ms)?,
                    payload_hash: row.payload_hash,
                })
            })
            .collect::<Result<Vec<_>>>()?;

        let completion = state
            .completion
            .as_deref()
            .map(completion_state_from_wire)
            .transpose()?;

        Ok(Some(PersistedTrace {
            trace_id: state.trace_id,
            circuit: state.circuit,
            events,
            resumed_from_step: state.resumed_from_step.map(u64::try_from).transpose()?,
            completion,
        }))
    }

    /// Records `resume_from_step` on the state row and returns the following step.
    ///
    /// # Errors
    /// Fails if the trace does not exist, the step exceeds `i64::MAX`, or on database errors.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let query = format!(
            "UPDATE {}
                SET resumed_from_step = $2
                WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(i64::try_from(resume_from_step)?)
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(ResumeCursor {
            trace_id: trace_id.to_string(),
            next_step: resume_from_step.saturating_add(1),
        })
    }

    /// Stores the terminal state (overwriting any previous one).
    ///
    /// # Errors
    /// Fails if the trace does not exist or on database errors.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let query = format!(
            "UPDATE {}
                SET completion = $2
                WHERE trace_id = $1",
            self.state_table
        );
        let rows = sqlx::query(&query)
            .bind(trace_id)
            .bind(completion_state_to_wire(&completion))
            .execute(&self.pool)
            .await?
            .rows_affected();
        if rows == 0 {
            return Err(anyhow!("trace_id {} not found", trace_id));
        }
        Ok(())
    }
}
789
/// Redis-backed [`PersistenceStore`]: each trace is one JSON string at `<key_prefix>:<trace_id>`.
#[cfg(feature = "persistence-redis")]
#[derive(Clone)]
pub struct RedisPersistenceStore {
    manager: redis::aio::ConnectionManager,
    key_prefix: String,
}

#[cfg(feature = "persistence-redis")]
impl std::fmt::Debug for RedisPersistenceStore {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The connection manager is omitted; only the key prefix is printed.
        let mut out = f.debug_struct("RedisPersistenceStore");
        out.field("key_prefix", &self.key_prefix);
        out.finish_non_exhaustive()
    }
}
805
#[cfg(feature = "persistence-redis")]
impl RedisPersistenceStore {
    /// Connects to `url` and uses the default prefix `ranvier:persistence`.
    pub async fn connect(url: &str) -> Result<Self> {
        let manager = redis::aio::ConnectionManager::new(redis::Client::open(url)?).await?;
        Ok(Self::with_prefix(manager, "ranvier:persistence"))
    }

    /// Uses an existing connection manager with a custom key prefix.
    pub fn with_prefix(
        manager: redis::aio::ConnectionManager,
        key_prefix: impl Into<String>,
    ) -> Self {
        RedisPersistenceStore {
            manager,
            key_prefix: key_prefix.into(),
        }
    }

    /// Builds the Redis key holding a trace.
    fn key(&self, trace_id: &str) -> String {
        format!("{}:{}", self.key_prefix, trace_id)
    }

    /// Serializes `trace` to JSON and overwrites its key (no TTL).
    async fn write_trace(&self, trace: &PersistedTrace) -> Result<()> {
        use redis::AsyncCommands;
        let payload = serde_json::to_string(trace)?;
        let mut conn = self.manager.clone();
        let _: () = conn.set(self.key(&trace.trace_id), payload).await?;
        Ok(())
    }
}
843
// NOTE(review): every mutating method below is a non-atomic load → modify → write of the
// whole JSON document. Concurrent writers to the same trace can overwrite each other
// (lost events, or an append slipping past `complete`). Consider WATCH/MULTI or a Lua script.
#[cfg(feature = "persistence-redis")]
#[async_trait]
impl PersistenceStore for RedisPersistenceStore {
    /// Appends `envelope`, creating the trace on first use, then rewrites the document.
    ///
    /// # Errors
    /// Fails if the trace is bound to another circuit, is completed, or on Redis/JSON errors.
    async fn append(&self, envelope: PersistenceEnvelope) -> Result<()> {
        let existing = self.load(&envelope.trace_id).await?;
        let mut trace = match existing {
            Some(found) => found,
            None => PersistedTrace {
                trace_id: envelope.trace_id.clone(),
                circuit: envelope.circuit.clone(),
                events: Vec::new(),
                resumed_from_step: None,
                completion: None,
            },
        };

        if trace.circuit != envelope.circuit {
            return Err(anyhow!(
                "trace_id {} already exists for circuit {}, got {}",
                envelope.trace_id,
                trace.circuit,
                envelope.circuit
            ));
        }
        if trace.completion.is_some() {
            return Err(anyhow!(
                "trace_id {} is already completed and cannot accept new events",
                envelope.trace_id
            ));
        }

        // Full re-sort (not a sorted insert) because stored JSON is not guaranteed sorted.
        trace.events.push(envelope);
        trace.events.sort_by_key(|event| event.step);
        self.write_trace(&trace).await
    }

    /// Reads and deserializes the trace document, or `Ok(None)` if the key is absent.
    async fn load(&self, trace_id: &str) -> Result<Option<PersistedTrace>> {
        use redis::AsyncCommands;
        let mut conn = self.manager.clone();
        let payload: Option<String> = conn.get(self.key(trace_id)).await?;
        match payload {
            Some(raw) => Ok(Some(serde_json::from_str::<PersistedTrace>(&raw)?)),
            None => Ok(None),
        }
    }

    /// Records `resume_from_step` and returns a cursor at the following step.
    ///
    /// # Errors
    /// Fails if the trace does not exist or on Redis/JSON errors.
    async fn resume(&self, trace_id: &str, resume_from_step: u64) -> Result<ResumeCursor> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.resumed_from_step = Some(resume_from_step);
        self.write_trace(&trace).await?;
        Ok(ResumeCursor {
            trace_id: trace_id.to_owned(),
            next_step: resume_from_step.saturating_add(1),
        })
    }

    /// Sets the terminal state (overwriting any previous one).
    ///
    /// # Errors
    /// Fails if the trace does not exist or on Redis/JSON errors.
    async fn complete(&self, trace_id: &str, completion: CompletionState) -> Result<()> {
        let Some(mut trace) = self.load(trace_id).await? else {
            return Err(anyhow!("trace_id {} not found", trace_id));
        };
        trace.completion = Some(completion);
        self.write_trace(&trace).await
    }
}
914
#[cfg(test)]
mod tests {
    use super::*;
    // Uuid isolates table names / key prefixes per run in the backend tests.
    #[cfg(any(feature = "persistence-postgres", feature = "persistence-redis"))]
    use uuid::Uuid;

    /// Builds an envelope for trace "trace-1" on "OrderCircuit" at `step`.
    fn envelope(step: u64, outcome_kind: &str) -> PersistenceEnvelope {
        PersistenceEnvelope {
            trace_id: "trace-1".to_string(),
            circuit: "OrderCircuit".to_string(),
            step,
            outcome_kind: outcome_kind.to_string(),
            timestamp_ms: 1_700_000_000_000 + step,
            payload_hash: Some(format!("hash-{}", step)),
        }
    }

    #[tokio::test]
    async fn append_and_load_roundtrip() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();
        store.append(envelope(2, "Branch")).await.unwrap();

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, "trace-1");
        assert_eq!(loaded.circuit, "OrderCircuit");
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.events[0].step, 1);
        assert_eq!(loaded.events[1].outcome_kind, "Branch");
        assert_eq!(loaded.completion, None);
    }

    #[tokio::test]
    async fn resume_records_cursor() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(3, "Fault")).await.unwrap();

        // next_step is resume_from_step + 1.
        let cursor = store.resume("trace-1", 3).await.unwrap();
        assert_eq!(
            cursor,
            ResumeCursor {
                trace_id: "trace-1".to_string(),
                next_step: 4
            }
        );

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.resumed_from_step, Some(3));
    }

    #[tokio::test]
    async fn complete_marks_trace_and_blocks_append() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();
        store
            .complete("trace-1", CompletionState::Success)
            .await
            .unwrap();

        let loaded = store.load("trace-1").await.unwrap().unwrap();
        assert_eq!(loaded.completion, Some(CompletionState::Success));

        // A completed trace must reject further events.
        let err = store.append(envelope(2, "Next")).await.unwrap_err();
        assert!(
            err.to_string()
                .contains("is already completed and cannot accept new events")
        );
    }

    #[tokio::test]
    async fn append_rejects_cross_circuit_trace_reuse() {
        let store = InMemoryPersistenceStore::new();
        store.append(envelope(1, "Next")).await.unwrap();

        // Same trace_id, different circuit: must be rejected.
        let mut invalid = envelope(2, "Next");
        invalid.circuit = "AnotherCircuit".to_string();
        let err = store.append(invalid).await.unwrap_err();
        assert!(
            err.to_string()
                .contains("already exists for circuit OrderCircuit")
        );
    }

    #[tokio::test]
    async fn in_memory_compensation_idempotency_roundtrip() {
        let store = InMemoryCompensationIdempotencyStore::new();
        let key = "trace-a:OrderFlow:Fault";

        assert!(!store.was_compensated(key).await.unwrap());
        store.mark_compensated(key).await.unwrap();
        assert!(store.was_compensated(key).await.unwrap());
    }

    // Returns early (passes vacuously) unless RANVIER_PERSISTENCE_POSTGRES_URL is set.
    // NOTE(review): tables are dropped only if every assertion passes; a failure leaks them.
    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_store_roundtrip_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&url)
            .await
            .unwrap();
        let table_prefix = format!("ranvier_persistence_test_{}", Uuid::new_v4().simple());
        let store = PostgresPersistenceStore::with_table_prefix(pool.clone(), table_prefix.clone());
        store.ensure_schema().await.unwrap();

        let trace_id = format!("trace-{}", Uuid::new_v4().simple());
        let circuit = "PgCircuit".to_string();

        let mut first = envelope(1, "Next");
        first.trace_id = trace_id.clone();
        first.circuit = circuit.clone();
        store.append(first).await.unwrap();

        let mut second = envelope(2, "Branch");
        second.trace_id = trace_id.clone();
        second.circuit = circuit.clone();
        store.append(second).await.unwrap();

        let cursor = store.resume(&trace_id, 2).await.unwrap();
        assert_eq!(cursor.next_step, 3);

        store
            .complete(&trace_id, CompletionState::Compensated)
            .await
            .unwrap();

        let loaded = store.load(&trace_id).await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, trace_id);
        assert_eq!(loaded.circuit, circuit);
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.resumed_from_step, Some(2));
        assert_eq!(loaded.completion, Some(CompletionState::Compensated));

        let drop_events = format!("DROP TABLE IF EXISTS {}", store.events_table);
        let drop_state = format!("DROP TABLE IF EXISTS {}", store.state_table);
        sqlx::query(&drop_events).execute(&pool).await.unwrap();
        sqlx::query(&drop_state).execute(&pool).await.unwrap();
    }

    // Returns early (passes vacuously) unless RANVIER_PERSISTENCE_REDIS_URL is set.
    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_store_roundtrip_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_REDIS_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let base = RedisPersistenceStore::connect(&url).await.unwrap();
        let prefix = format!("ranvier:persistence:test:{}", Uuid::new_v4().simple());
        let store = RedisPersistenceStore::with_prefix(base.manager.clone(), prefix);

        let trace_id = format!("trace-{}", Uuid::new_v4().simple());
        let circuit = "RedisCircuit".to_string();

        let mut first = envelope(1, "Next");
        first.trace_id = trace_id.clone();
        first.circuit = circuit.clone();
        store.append(first).await.unwrap();

        let mut second = envelope(2, "Fault");
        second.trace_id = trace_id.clone();
        second.circuit = circuit.clone();
        store.append(second).await.unwrap();

        let cursor = store.resume(&trace_id, 2).await.unwrap();
        assert_eq!(cursor.next_step, 3);

        store
            .complete(&trace_id, CompletionState::Fault)
            .await
            .unwrap();

        let loaded = store.load(&trace_id).await.unwrap().unwrap();
        assert_eq!(loaded.trace_id, trace_id);
        assert_eq!(loaded.circuit, circuit);
        assert_eq!(loaded.events.len(), 2);
        assert_eq!(loaded.resumed_from_step, Some(2));
        assert_eq!(loaded.completion, Some(CompletionState::Fault));

        // Cleanup: remove the trace key written by this test.
        use redis::AsyncCommands;
        let key = store.key(&trace_id);
        let mut conn = store.manager.clone();
        let _: () = conn.del(key).await.unwrap();
    }

    // Returns early unless RANVIER_PERSISTENCE_POSTGRES_URL is set; also checks that a repeat
    // mark is accepted.
    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_compensation_idempotency_roundtrip_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&url)
            .await
            .unwrap();
        let table_prefix = format!(
            "ranvier_compensation_idempotency_test_{}",
            Uuid::new_v4().simple()
        );
        let store =
            PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
        store.ensure_schema().await.unwrap();

        let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());
        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        let drop_table = format!("DROP TABLE IF EXISTS {}", store.table);
        sqlx::query(&drop_table).execute(&pool).await.unwrap();
    }

    // Returns early unless RANVIER_PERSISTENCE_POSTGRES_URL is set. Forces one key's
    // created_at_ms to 0 so only it falls below the purge cutoff of 1.
    #[cfg(feature = "persistence-postgres")]
    #[tokio::test]
    async fn postgres_compensation_idempotency_purge_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_POSTGRES_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(5)
            .connect(&url)
            .await
            .unwrap();
        let table_prefix = format!(
            "ranvier_compensation_idempotency_purge_test_{}",
            Uuid::new_v4().simple()
        );
        let store =
            PostgresCompensationIdempotencyStore::with_table_prefix(pool.clone(), &table_prefix);
        store.ensure_schema().await.unwrap();

        let stale_key = format!("stale-{}", Uuid::new_v4().simple());
        let fresh_key = format!("fresh-{}", Uuid::new_v4().simple());
        store.mark_compensated(&stale_key).await.unwrap();
        store.mark_compensated(&fresh_key).await.unwrap();

        let force_stale_query = format!(
            "UPDATE {}
                SET created_at_ms = 0
                WHERE idempotency_key = $1",
            store.table
        );
        sqlx::query(&force_stale_query)
            .bind(&stale_key)
            .execute(&pool)
            .await
            .unwrap();

        let purged = store.purge_older_than_ms(1).await.unwrap();
        assert!(purged >= 1);
        assert!(!store.was_compensated(&stale_key).await.unwrap());
        assert!(store.was_compensated(&fresh_key).await.unwrap());

        let drop_table = format!("DROP TABLE IF EXISTS {}", store.table);
        sqlx::query(&drop_table).execute(&pool).await.unwrap();
    }

    // Returns early unless RANVIER_PERSISTENCE_REDIS_URL is set; also checks that a repeat
    // mark is accepted.
    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_compensation_idempotency_roundtrip_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_REDIS_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let base = RedisCompensationIdempotencyStore::connect(&url)
            .await
            .unwrap();
        let prefix = format!(
            "ranvier:compensation:idempotency:test:{}",
            Uuid::new_v4().simple()
        );
        let store = RedisCompensationIdempotencyStore::with_prefix(base.manager.clone(), prefix);
        let key = format!("trace-{}:OrderFlow:Fault", Uuid::new_v4().simple());

        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        use redis::AsyncCommands;
        let mut conn = store.manager.clone();
        let _: () = conn.del(store.key(&key)).await.unwrap();
    }

    // Returns early unless RANVIER_PERSISTENCE_REDIS_URL is set. Uses a 1s TTL and sleeps 2s,
    // so this test takes ~2s; the expired key needs no explicit cleanup.
    #[cfg(feature = "persistence-redis")]
    #[tokio::test]
    async fn redis_compensation_idempotency_ttl_when_configured() {
        let url = match std::env::var("RANVIER_PERSISTENCE_REDIS_URL") {
            Ok(value) => value,
            Err(_) => return,
        };

        let base = RedisCompensationIdempotencyStore::connect(&url)
            .await
            .unwrap();
        let prefix = format!(
            "ranvier:compensation:idempotency:ttl:test:{}",
            Uuid::new_v4().simple()
        );
        let store =
            RedisCompensationIdempotencyStore::with_prefix_and_ttl(base.manager.clone(), prefix, 1);
        let key = format!("ttl-{}", Uuid::new_v4().simple());

        assert!(!store.was_compensated(&key).await.unwrap());
        store.mark_compensated(&key).await.unwrap();
        assert!(store.was_compensated(&key).await.unwrap());

        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(!store.was_compensated(&key).await.unwrap());
    }
}