1use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_persistence::{Journal, JournalError, PersistentRepr};
7use fred::prelude::*;
8
9use crate::codec::StoredRepr;
10use crate::config::RedisConfig;
11
12pub struct RedisJournal {
13 client: Pool,
14 cfg: RedisConfig,
15}
16
17impl RedisJournal {
18 pub async fn connect(cfg: RedisConfig) -> Result<Arc<Self>, JournalError> {
20 let mut builder = Builder::from_config(Config::from_url(&cfg.url).map_err(JournalError::backend)?);
21 let pool = builder
22 .set_policy(ReconnectPolicy::new_constant(0, 500))
23 .build_pool(cfg.pool_size)
24 .map_err(JournalError::backend)?;
25 pool.init().await.map_err(JournalError::backend)?;
26 Ok(Arc::new(Self { client: pool, cfg }))
27 }
28
29 pub fn from_pool(pool: Pool, cfg: RedisConfig) -> Arc<Self> {
30 Arc::new(Self { client: pool, cfg })
31 }
32
33 pub fn config(&self) -> &RedisConfig {
34 &self.cfg
35 }
36
37 pub fn client(&self) -> &Pool {
38 &self.client
39 }
40}
41
42fn encode(repr: &PersistentRepr) -> Result<String, JournalError> {
43 serde_json::to_string(&StoredRepr::from(repr)).map_err(JournalError::backend)
44}
45
46fn decode(raw: &str) -> Result<PersistentRepr, JournalError> {
47 let stored: StoredRepr = serde_json::from_str(raw).map_err(JournalError::backend)?;
48 Ok(stored.into_repr())
49}
50
51#[async_trait]
52impl Journal for RedisJournal {
53 async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
54 if messages.is_empty() {
55 return Ok(());
56 }
57 let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
58 std::collections::BTreeMap::new();
59 for m in messages {
60 by_pid.entry(m.persistence_id.clone()).or_default().push(m);
61 }
62
63 for (pid, batch) in by_pid {
64 let key = self.cfg.journal_key(&pid);
65 let current: i64 = self.client.zcard(&key).await.map_err(JournalError::backend)?;
66 for (expected, msg) in (current as u64 + 1..).zip(batch.iter()) {
67 if msg.sequence_nr != expected {
68 return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
69 }
70 }
71 let tx = self.client.next().multi();
72 for msg in &batch {
73 let payload = encode(msg)?;
74 let _: () = tx
75 .zadd(
76 &key,
77 Some(SetOptions::NX),
78 None,
79 false,
80 false,
81 (msg.sequence_nr as f64, payload.clone()),
82 )
83 .await
84 .map_err(JournalError::backend)?;
85 for tag in &msg.tags {
86 let tag_key = self.cfg.tag_key(tag);
87 let member = format!("{}:{}", msg.persistence_id, msg.sequence_nr);
88 let _: () = tx
89 .zadd(
90 &tag_key,
91 Some(SetOptions::NX),
92 None,
93 false,
94 false,
95 (msg.sequence_nr as f64, member),
96 )
97 .await
98 .map_err(JournalError::backend)?;
99 }
100 }
101 let _: () = tx.exec(true).await.map_err(JournalError::backend)?;
102 }
103 Ok(())
104 }
105
106 async fn delete_messages_to(
107 &self,
108 persistence_id: &str,
109 to_sequence_nr: u64,
110 ) -> Result<(), JournalError> {
111 let key = self.cfg.journal_key(persistence_id);
112 let members: Vec<String> = self
113 .client
114 .zrangebyscore(&key, 0.0, to_sequence_nr as f64, false, None)
115 .await
116 .map_err(JournalError::backend)?;
117 for raw in members {
118 let mut repr = decode(&raw)?;
119 repr.deleted = true;
120 let new_payload = encode(&repr)?;
121 let _: () = self
122 .client
123 .zadd(&key, Some(SetOptions::XX), None, false, false, (repr.sequence_nr as f64, new_payload))
124 .await
125 .map_err(JournalError::backend)?;
126 let _: () = self.client.zrem(&key, raw).await.map_err(JournalError::backend)?;
127 }
128 Ok(())
129 }
130
131 async fn replay_messages(
132 &self,
133 persistence_id: &str,
134 from: u64,
135 to: u64,
136 max: u64,
137 ) -> Result<Vec<PersistentRepr>, JournalError> {
138 let key = self.cfg.journal_key(persistence_id);
139 let limit = if max > i64::MAX as u64 { None } else { Some((0i64, max as i64)) };
140 let members: Vec<String> = self
141 .client
142 .zrangebyscore(&key, from as f64, to as f64, false, limit)
143 .await
144 .map_err(JournalError::backend)?;
145 let mut out = Vec::with_capacity(members.len());
146 for raw in members {
147 let repr = decode(&raw)?;
148 if !repr.deleted {
149 out.push(repr);
150 }
151 }
152 Ok(out)
153 }
154
155 async fn highest_sequence_nr(&self, persistence_id: &str, _from: u64) -> Result<u64, JournalError> {
156 let key = self.cfg.journal_key(persistence_id);
157 let members: Vec<(String, f64)> =
158 self.client.zrange(&key, -1, -1, None, false, None, true).await.map_err(JournalError::backend)?;
159 Ok(members.into_iter().next().map(|(_, s)| s as u64).unwrap_or(0))
160 }
161
162 async fn events_by_tag(
163 &self,
164 tag: &str,
165 from_offset: u64,
166 max: u64,
167 ) -> Result<Vec<PersistentRepr>, JournalError> {
168 let key = self.cfg.tag_key(tag);
169 let limit = if max > i64::MAX as u64 { None } else { Some((0i64, max as i64)) };
170 let entries: Vec<String> = self
171 .client
172 .zrangebyscore(&key, from_offset as f64, f64::INFINITY, false, limit)
173 .await
174 .map_err(JournalError::backend)?;
175 let mut out = Vec::new();
176 for entry in entries {
177 let (pid, _, seq) = match entry.rsplit_once(':') {
178 Some((p, s)) => (p.to_string(), entry.as_str(), s.parse::<u64>().unwrap_or(0)),
179 None => continue,
180 };
181 let journal_key = self.cfg.journal_key(&pid);
182 let members: Vec<String> = self
183 .client
184 .zrangebyscore(&journal_key, seq as f64, seq as f64, false, None)
185 .await
186 .map_err(JournalError::backend)?;
187 for raw in members {
188 let repr = decode(&raw)?;
189 if !repr.deleted {
190 out.push(repr);
191 }
192 }
193 }
194 Ok(out)
195 }
196}