Skip to main content

entelix_persistence/redis/
checkpointer.rs

//! `RedisCheckpointer<S>` — [`Checkpointer<S>`] over Redis sorted
//! sets scored by `step`. A companion HASH provides O(1) lookup by
//! checkpoint id. Keys partition by `(tenant_id, thread_id)` per
//! Invariant 11 — cross-tenant reads are not constructible from
//! this surface.
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use entelix_core::ThreadKey;
12use entelix_core::{Error, Result};
13use entelix_graph::{Checkpoint, CheckpointId, Checkpointer};
14use redis::aio::ConnectionManager;
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18
19use crate::error::PersistenceError;
20use crate::schema_version::SessionSchemaVersion;
21
22/// Redis-backed [`Checkpointer<S>`].
23pub struct RedisCheckpointer<S> {
24    manager: Arc<ConnectionManager>,
25    _phantom: PhantomData<fn() -> S>,
26}
27
28impl<S> RedisCheckpointer<S>
29where
30    S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
31{
32    pub(crate) fn new(manager: Arc<ConnectionManager>) -> Self {
33        Self {
34            manager,
35            _phantom: PhantomData,
36        }
37    }
38}
39
40fn zset_key(key: &ThreadKey) -> String {
41    format!("entelix:cp:{}:{}:bystep", key.tenant_id(), key.thread_id())
42}
43
44fn hash_key(key: &ThreadKey) -> String {
45    format!("entelix:cp:{}:{}:byid", key.tenant_id(), key.thread_id())
46}
47
48#[async_trait]
49impl<S> Checkpointer<S> for RedisCheckpointer<S>
50where
51    S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
52{
53    async fn put(&self, checkpoint: Checkpoint<S>) -> Result<()> {
54        let key = checkpoint.key();
55        let envelope = wrap_envelope(&checkpoint).map_err(into_core)?;
56        let id_str = checkpoint.id.to_hyphenated_string();
57        let mut conn = (*self.manager).clone();
58        let step_score = i64::try_from(checkpoint.step).unwrap_or(i64::MAX) as f64;
59        // Two-step write — Redis pipeline keeps the round-trip minimal.
60        redis::pipe()
61            .atomic()
62            .zadd(zset_key(&key), &id_str, step_score)
63            .hset(hash_key(&key), &id_str, envelope.to_string())
64            .query_async::<()>(&mut conn)
65            .await
66            .map_err(backend_to_core)?;
67        Ok(())
68    }
69
70    async fn get_latest(&self, key: &ThreadKey) -> Result<Option<Checkpoint<S>>> {
71        let mut conn = (*self.manager).clone();
72        let ids: Vec<String> = redis::cmd("ZREVRANGE")
73            .arg(zset_key(key))
74            .arg(0)
75            .arg(0)
76            .query_async(&mut conn)
77            .await
78            .map_err(backend_to_core)?;
79        let Some(id) = ids.into_iter().next() else {
80            return Ok(None);
81        };
82        load_by_id::<S>(&mut conn, key, &id).await
83    }
84
85    async fn get_by_id(&self, key: &ThreadKey, id: &CheckpointId) -> Result<Option<Checkpoint<S>>> {
86        let mut conn = (*self.manager).clone();
87        load_by_id::<S>(&mut conn, key, &id.to_hyphenated_string()).await
88    }
89
90    async fn list_history(&self, key: &ThreadKey, limit: usize) -> Result<Vec<Checkpoint<S>>> {
91        let mut conn = (*self.manager).clone();
92        let stop = if limit == 0 || limit == usize::MAX {
93            -1isize
94        } else {
95            isize::try_from(limit.saturating_sub(1)).unwrap_or(isize::MAX)
96        };
97        let ids: Vec<String> = redis::cmd("ZREVRANGE")
98            .arg(zset_key(key))
99            .arg(0)
100            .arg(stop)
101            .query_async(&mut conn)
102            .await
103            .map_err(backend_to_core)?;
104        let mut out = Vec::with_capacity(ids.len());
105        for id in ids {
106            if let Some(cp) = load_by_id::<S>(&mut conn, key, &id).await? {
107                out.push(cp);
108            }
109        }
110        Ok(out)
111    }
112
113    async fn update_state(
114        &self,
115        key: &ThreadKey,
116        parent_id: &CheckpointId,
117        new_state: S,
118    ) -> Result<CheckpointId> {
119        let parent = self.get_by_id(key, parent_id).await?.ok_or_else(|| {
120            Error::invalid_request(format!(
121                "RedisCheckpointer::update_state: parent {} not found in tenant '{}' thread '{}'",
122                parent_id.to_hyphenated_string(),
123                key.tenant_id(),
124                key.thread_id()
125            ))
126        })?;
127        let new_step = parent.step.saturating_add(1);
128        let new_checkpoint = Checkpoint::new(key, new_step, new_state, parent.next_node)
129            .with_parent(parent_id.clone());
130        let new_id = new_checkpoint.id.clone();
131        self.put(new_checkpoint).await?;
132        Ok(new_id)
133    }
134}
135
136async fn load_by_id<S>(
137    conn: &mut ConnectionManager,
138    key: &ThreadKey,
139    id: &str,
140) -> Result<Option<Checkpoint<S>>>
141where
142    S: Clone + Send + Sync + DeserializeOwned + 'static,
143{
144    let raw: Option<String> = redis::cmd("HGET")
145        .arg(hash_key(key))
146        .arg(id)
147        .query_async(conn)
148        .await
149        .map_err(backend_to_core)?;
150    let Some(raw) = raw else { return Ok(None) };
151    let value: Value = serde_json::from_str(&raw).map_err(Error::Serde)?;
152    let cp = unwrap_envelope::<S>(&value).map_err(into_core)?;
153    Ok(Some(cp))
154}
155
156fn wrap_envelope<S>(cp: &Checkpoint<S>) -> std::result::Result<Value, PersistenceError>
157where
158    S: Clone + Send + Sync + Serialize + 'static,
159{
160    let body = serde_json::json!({
161        "id": cp.id,
162        "tenant_id": cp.tenant_id,
163        "thread_id": cp.thread_id,
164        "parent_id": cp.parent_id,
165        "step": cp.step,
166        "state": serde_json::to_value(&cp.state)?,
167        "next_node": cp.next_node,
168        "timestamp": cp.timestamp,
169    });
170    Ok(serde_json::json!({
171        "schema_version": SessionSchemaVersion::CURRENT,
172        "body": body,
173    }))
174}
175
176fn unwrap_envelope<S>(value: &Value) -> std::result::Result<Checkpoint<S>, PersistenceError>
177where
178    S: Clone + Send + Sync + DeserializeOwned + 'static,
179{
180    let version = value
181        .get("schema_version")
182        .and_then(|v| v.as_u64())
183        .map(|n| u32::try_from(n).unwrap_or(u32::MAX))
184        .map(SessionSchemaVersion)
185        .ok_or_else(|| {
186            PersistenceError::Backend("checkpoint envelope lacks schema_version".into())
187        })?;
188    version.validate()?;
189    let body = value
190        .get("body")
191        .ok_or_else(|| PersistenceError::Backend("checkpoint envelope lacks body".into()))?;
192    let id: CheckpointId = serde_json::from_value(
193        body.get("id")
194            .cloned()
195            .ok_or_else(|| PersistenceError::Backend("checkpoint missing id".into()))?,
196    )?;
197    // Persistence-layer row hydration validates the `tenant_id`
198    // through `TenantId::try_from`; an empty value (which would
199    // otherwise produce a tenantless `Checkpoint`) surfaces as a
200    // typed error rather than a constructed instance.
201    let tenant_id_str = body
202        .get("tenant_id")
203        .and_then(|v| v.as_str())
204        .ok_or_else(|| PersistenceError::Backend("checkpoint missing tenant_id".into()))?;
205    let tenant_id = entelix_core::TenantId::try_from(tenant_id_str)
206        .map_err(|e| PersistenceError::Backend(format!("invalid persisted tenant_id: {e}")))?;
207    let thread_id: String = body
208        .get("thread_id")
209        .and_then(|v| v.as_str())
210        .ok_or_else(|| PersistenceError::Backend("checkpoint missing thread_id".into()))?
211        .to_owned();
212    let parent_id = match body.get("parent_id") {
213        Some(Value::Null) | None => None,
214        Some(v) => Some(serde_json::from_value::<CheckpointId>(v.clone())?),
215    };
216    let step = body
217        .get("step")
218        .and_then(|v| v.as_u64())
219        .and_then(|n| usize::try_from(n).ok())
220        .ok_or_else(|| PersistenceError::Backend("checkpoint missing step".into()))?;
221    let state: S = body
222        .get("state")
223        .map(|s| serde_json::from_value(s.clone()))
224        .ok_or_else(|| PersistenceError::Backend("checkpoint missing state".into()))??;
225    let next_node = body
226        .get("next_node")
227        .and_then(|v| v.as_str())
228        .map(ToOwned::to_owned);
229    let timestamp = body
230        .get("timestamp")
231        .map(|v| serde_json::from_value::<chrono::DateTime<chrono::Utc>>(v.clone()))
232        .ok_or_else(|| PersistenceError::Backend("checkpoint missing timestamp".into()))??;
233    let key = ThreadKey::new(tenant_id, thread_id);
234    Ok(Checkpoint::from_parts(
235        id, &key, parent_id, step, state, next_node, timestamp,
236    ))
237}
238
239fn backend_to_core(e: redis::RedisError) -> Error {
240    PersistenceError::Backend(e.to_string()).into()
241}
242
243fn into_core(e: PersistenceError) -> Error {
244    e.into()
245}
246
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Redis-side mirror of the Postgres hydration regression. A
    /// stored envelope whose `body.tenant_id` is empty must not
    /// hydrate into a `Checkpoint`: the tenant-id validation at the
    /// deserialise boundary rejects it with
    /// `PersistenceError::Backend("invalid persisted tenant_id …")`
    /// instead of constructing a tenantless `Checkpoint` whose
    /// downstream key comparison would silently mis-route
    /// (invariant 11).
    #[test]
    fn unwrap_envelope_rejects_empty_persisted_tenant_id() {
        // Well-formed in every field except the empty tenant id.
        let body = serde_json::json!({
            "id": CheckpointId::from_uuid(Uuid::new_v4()),
            "tenant_id": "",
            "thread_id": "th-1",
            "parent_id": serde_json::Value::Null,
            "step": 0u64,
            "state": 42,
            "next_node": serde_json::Value::Null,
            "timestamp": chrono::Utc::now(),
        });
        let envelope = serde_json::json!({
            "schema_version": SessionSchemaVersion::CURRENT,
            "body": body,
        });

        let err = unwrap_envelope::<i32>(&envelope).unwrap_err();

        let is_tenant_rejection = match &err {
            PersistenceError::Backend(m) => m.contains("invalid persisted tenant_id"),
            _ => false,
        };
        assert!(
            is_tenant_rejection,
            "expected Backend(\"invalid persisted tenant_id …\"), got {err:?}"
        );
    }
}