Skip to main content

atomr_persistence_redis/
journal.rs

1//! Journal implementation on Redis sorted sets.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_persistence::{Journal, JournalError, PersistentRepr};
7use fred::prelude::*;
8
9use crate::codec::StoredRepr;
10use crate::config::RedisConfig;
11
12pub struct RedisJournal {
13    client: Pool,
14    cfg: RedisConfig,
15}
16
17impl RedisJournal {
18    /// Connect to Redis using `cfg.url` and return a ready journal.
19    pub async fn connect(cfg: RedisConfig) -> Result<Arc<Self>, JournalError> {
20        let mut builder = Builder::from_config(Config::from_url(&cfg.url).map_err(JournalError::backend)?);
21        let pool = builder
22            .set_policy(ReconnectPolicy::new_constant(0, 500))
23            .build_pool(cfg.pool_size)
24            .map_err(JournalError::backend)?;
25        pool.init().await.map_err(JournalError::backend)?;
26        Ok(Arc::new(Self { client: pool, cfg }))
27    }
28
29    pub fn from_pool(pool: Pool, cfg: RedisConfig) -> Arc<Self> {
30        Arc::new(Self { client: pool, cfg })
31    }
32
33    pub fn config(&self) -> &RedisConfig {
34        &self.cfg
35    }
36
37    pub fn client(&self) -> &Pool {
38        &self.client
39    }
40}
41
42fn encode(repr: &PersistentRepr) -> Result<String, JournalError> {
43    serde_json::to_string(&StoredRepr::from(repr)).map_err(JournalError::backend)
44}
45
46fn decode(raw: &str) -> Result<PersistentRepr, JournalError> {
47    let stored: StoredRepr = serde_json::from_str(raw).map_err(JournalError::backend)?;
48    Ok(stored.into_repr())
49}
50
51#[async_trait]
52impl Journal for RedisJournal {
53    async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
54        if messages.is_empty() {
55            return Ok(());
56        }
57        let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
58            std::collections::BTreeMap::new();
59        for m in messages {
60            by_pid.entry(m.persistence_id.clone()).or_default().push(m);
61        }
62
63        for (pid, batch) in by_pid {
64            let key = self.cfg.journal_key(&pid);
65            let current: i64 = self.client.zcard(&key).await.map_err(JournalError::backend)?;
66            for (expected, msg) in (current as u64 + 1..).zip(batch.iter()) {
67                if msg.sequence_nr != expected {
68                    return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
69                }
70            }
71            let tx = self.client.next().multi();
72            for msg in &batch {
73                let payload = encode(msg)?;
74                let _: () = tx
75                    .zadd(
76                        &key,
77                        Some(SetOptions::NX),
78                        None,
79                        false,
80                        false,
81                        (msg.sequence_nr as f64, payload.clone()),
82                    )
83                    .await
84                    .map_err(JournalError::backend)?;
85                for tag in &msg.tags {
86                    let tag_key = self.cfg.tag_key(tag);
87                    let member = format!("{}:{}", msg.persistence_id, msg.sequence_nr);
88                    let _: () = tx
89                        .zadd(
90                            &tag_key,
91                            Some(SetOptions::NX),
92                            None,
93                            false,
94                            false,
95                            (msg.sequence_nr as f64, member),
96                        )
97                        .await
98                        .map_err(JournalError::backend)?;
99                }
100            }
101            let _: () = tx.exec(true).await.map_err(JournalError::backend)?;
102        }
103        Ok(())
104    }
105
106    async fn delete_messages_to(
107        &self,
108        persistence_id: &str,
109        to_sequence_nr: u64,
110    ) -> Result<(), JournalError> {
111        let key = self.cfg.journal_key(persistence_id);
112        let members: Vec<String> = self
113            .client
114            .zrangebyscore(&key, 0.0, to_sequence_nr as f64, false, None)
115            .await
116            .map_err(JournalError::backend)?;
117        for raw in members {
118            let mut repr = decode(&raw)?;
119            repr.deleted = true;
120            let new_payload = encode(&repr)?;
121            let _: () = self
122                .client
123                .zadd(&key, Some(SetOptions::XX), None, false, false, (repr.sequence_nr as f64, new_payload))
124                .await
125                .map_err(JournalError::backend)?;
126            let _: () = self.client.zrem(&key, raw).await.map_err(JournalError::backend)?;
127        }
128        Ok(())
129    }
130
131    async fn replay_messages(
132        &self,
133        persistence_id: &str,
134        from: u64,
135        to: u64,
136        max: u64,
137    ) -> Result<Vec<PersistentRepr>, JournalError> {
138        let key = self.cfg.journal_key(persistence_id);
139        let limit = if max > i64::MAX as u64 { None } else { Some((0i64, max as i64)) };
140        let members: Vec<String> = self
141            .client
142            .zrangebyscore(&key, from as f64, to as f64, false, limit)
143            .await
144            .map_err(JournalError::backend)?;
145        let mut out = Vec::with_capacity(members.len());
146        for raw in members {
147            let repr = decode(&raw)?;
148            if !repr.deleted {
149                out.push(repr);
150            }
151        }
152        Ok(out)
153    }
154
155    async fn highest_sequence_nr(&self, persistence_id: &str, _from: u64) -> Result<u64, JournalError> {
156        let key = self.cfg.journal_key(persistence_id);
157        let members: Vec<(String, f64)> =
158            self.client.zrange(&key, -1, -1, None, false, None, true).await.map_err(JournalError::backend)?;
159        Ok(members.into_iter().next().map(|(_, s)| s as u64).unwrap_or(0))
160    }
161
162    async fn events_by_tag(
163        &self,
164        tag: &str,
165        from_offset: u64,
166        max: u64,
167    ) -> Result<Vec<PersistentRepr>, JournalError> {
168        let key = self.cfg.tag_key(tag);
169        let limit = if max > i64::MAX as u64 { None } else { Some((0i64, max as i64)) };
170        let entries: Vec<String> = self
171            .client
172            .zrangebyscore(&key, from_offset as f64, f64::INFINITY, false, limit)
173            .await
174            .map_err(JournalError::backend)?;
175        let mut out = Vec::new();
176        for entry in entries {
177            let (pid, _, seq) = match entry.rsplit_once(':') {
178                Some((p, s)) => (p.to_string(), entry.as_str(), s.parse::<u64>().unwrap_or(0)),
179                None => continue,
180            };
181            let journal_key = self.cfg.journal_key(&pid);
182            let members: Vec<String> = self
183                .client
184                .zrangebyscore(&journal_key, seq as f64, seq as f64, false, None)
185                .await
186                .map_err(JournalError::backend)?;
187            for raw in members {
188                let repr = decode(&raw)?;
189                if !repr.deleted {
190                    out.push(repr);
191                }
192            }
193        }
194        Ok(out)
195    }
196}