entelix_persistence/redis/
checkpointer.rs1use std::marker::PhantomData;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use entelix_core::ThreadKey;
12use entelix_core::{Error, Result};
13use entelix_graph::{Checkpoint, CheckpointId, Checkpointer};
14use redis::aio::ConnectionManager;
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18
19use crate::error::PersistenceError;
20use crate::schema_version::SessionSchemaVersion;
21
/// Redis-backed [`Checkpointer`].
///
/// Each thread's checkpoints live under two Redis keys:
/// - a sorted set (`entelix:cp:<tenant>:<thread>:bystep`) of checkpoint ids
///   scored by step;
/// - a hash (`entelix:cp:<tenant>:<thread>:byid`) mapping each id to its JSON
///   envelope.
pub struct RedisCheckpointer<S> {
    /// Shared connection manager; operations clone it to obtain a handle.
    manager: Arc<ConnectionManager>,
    // `fn() -> S` ties the type to `S` without storing an `S`, so the
    // checkpointer's auto traits (`Send`/`Sync`) do not depend on `S`.
    _phantom: PhantomData<fn() -> S>,
}
27
28impl<S> RedisCheckpointer<S>
29where
30 S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
31{
32 pub(crate) fn new(manager: Arc<ConnectionManager>) -> Self {
33 Self {
34 manager,
35 _phantom: PhantomData,
36 }
37 }
38}
39
40fn zset_key(key: &ThreadKey) -> String {
41 format!("entelix:cp:{}:{}:bystep", key.tenant_id(), key.thread_id())
42}
43
44fn hash_key(key: &ThreadKey) -> String {
45 format!("entelix:cp:{}:{}:byid", key.tenant_id(), key.thread_id())
46}
47
48#[async_trait]
49impl<S> Checkpointer<S> for RedisCheckpointer<S>
50where
51 S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
52{
53 async fn put(&self, checkpoint: Checkpoint<S>) -> Result<()> {
54 let key = checkpoint.key();
55 let envelope = wrap_envelope(&checkpoint).map_err(into_core)?;
56 let id_str = checkpoint.id.to_hyphenated_string();
57 let mut conn = (*self.manager).clone();
58 let step_score = i64::try_from(checkpoint.step).unwrap_or(i64::MAX) as f64;
59 redis::pipe()
61 .atomic()
62 .zadd(zset_key(&key), &id_str, step_score)
63 .hset(hash_key(&key), &id_str, envelope.to_string())
64 .query_async::<()>(&mut conn)
65 .await
66 .map_err(backend_to_core)?;
67 Ok(())
68 }
69
70 async fn get_latest(&self, key: &ThreadKey) -> Result<Option<Checkpoint<S>>> {
71 let mut conn = (*self.manager).clone();
72 let ids: Vec<String> = redis::cmd("ZREVRANGE")
73 .arg(zset_key(key))
74 .arg(0)
75 .arg(0)
76 .query_async(&mut conn)
77 .await
78 .map_err(backend_to_core)?;
79 let Some(id) = ids.into_iter().next() else {
80 return Ok(None);
81 };
82 load_by_id::<S>(&mut conn, key, &id).await
83 }
84
85 async fn get_by_id(&self, key: &ThreadKey, id: &CheckpointId) -> Result<Option<Checkpoint<S>>> {
86 let mut conn = (*self.manager).clone();
87 load_by_id::<S>(&mut conn, key, &id.to_hyphenated_string()).await
88 }
89
90 async fn list_history(&self, key: &ThreadKey, limit: usize) -> Result<Vec<Checkpoint<S>>> {
91 let mut conn = (*self.manager).clone();
92 let stop = if limit == 0 || limit == usize::MAX {
93 -1isize
94 } else {
95 isize::try_from(limit.saturating_sub(1)).unwrap_or(isize::MAX)
96 };
97 let ids: Vec<String> = redis::cmd("ZREVRANGE")
98 .arg(zset_key(key))
99 .arg(0)
100 .arg(stop)
101 .query_async(&mut conn)
102 .await
103 .map_err(backend_to_core)?;
104 let mut out = Vec::with_capacity(ids.len());
105 for id in ids {
106 if let Some(cp) = load_by_id::<S>(&mut conn, key, &id).await? {
107 out.push(cp);
108 }
109 }
110 Ok(out)
111 }
112
113 async fn update_state(
114 &self,
115 key: &ThreadKey,
116 parent_id: &CheckpointId,
117 new_state: S,
118 ) -> Result<CheckpointId> {
119 let parent = self.get_by_id(key, parent_id).await?.ok_or_else(|| {
120 Error::invalid_request(format!(
121 "RedisCheckpointer::update_state: parent {} not found in tenant '{}' thread '{}'",
122 parent_id.to_hyphenated_string(),
123 key.tenant_id(),
124 key.thread_id()
125 ))
126 })?;
127 let new_step = parent.step.saturating_add(1);
128 let new_checkpoint = Checkpoint::new(key, new_step, new_state, parent.next_node)
129 .with_parent(parent_id.clone());
130 let new_id = new_checkpoint.id.clone();
131 self.put(new_checkpoint).await?;
132 Ok(new_id)
133 }
134}
135
136async fn load_by_id<S>(
137 conn: &mut ConnectionManager,
138 key: &ThreadKey,
139 id: &str,
140) -> Result<Option<Checkpoint<S>>>
141where
142 S: Clone + Send + Sync + DeserializeOwned + 'static,
143{
144 let raw: Option<String> = redis::cmd("HGET")
145 .arg(hash_key(key))
146 .arg(id)
147 .query_async(conn)
148 .await
149 .map_err(backend_to_core)?;
150 let Some(raw) = raw else { return Ok(None) };
151 let value: Value = serde_json::from_str(&raw).map_err(Error::Serde)?;
152 let cp = unwrap_envelope::<S>(&value).map_err(into_core)?;
153 Ok(Some(cp))
154}
155
156fn wrap_envelope<S>(cp: &Checkpoint<S>) -> std::result::Result<Value, PersistenceError>
157where
158 S: Clone + Send + Sync + Serialize + 'static,
159{
160 let body = serde_json::json!({
161 "id": cp.id,
162 "tenant_id": cp.tenant_id,
163 "thread_id": cp.thread_id,
164 "parent_id": cp.parent_id,
165 "step": cp.step,
166 "state": serde_json::to_value(&cp.state)?,
167 "next_node": cp.next_node,
168 "timestamp": cp.timestamp,
169 });
170 Ok(serde_json::json!({
171 "schema_version": SessionSchemaVersion::CURRENT,
172 "body": body,
173 }))
174}
175
176fn unwrap_envelope<S>(value: &Value) -> std::result::Result<Checkpoint<S>, PersistenceError>
177where
178 S: Clone + Send + Sync + DeserializeOwned + 'static,
179{
180 let version = value
181 .get("schema_version")
182 .and_then(|v| v.as_u64())
183 .map(|n| u32::try_from(n).unwrap_or(u32::MAX))
184 .map(SessionSchemaVersion)
185 .ok_or_else(|| {
186 PersistenceError::Backend("checkpoint envelope lacks schema_version".into())
187 })?;
188 version.validate()?;
189 let body = value
190 .get("body")
191 .ok_or_else(|| PersistenceError::Backend("checkpoint envelope lacks body".into()))?;
192 let id: CheckpointId = serde_json::from_value(
193 body.get("id")
194 .cloned()
195 .ok_or_else(|| PersistenceError::Backend("checkpoint missing id".into()))?,
196 )?;
197 let tenant_id_str = body
202 .get("tenant_id")
203 .and_then(|v| v.as_str())
204 .ok_or_else(|| PersistenceError::Backend("checkpoint missing tenant_id".into()))?;
205 let tenant_id = entelix_core::TenantId::try_from(tenant_id_str)
206 .map_err(|e| PersistenceError::Backend(format!("invalid persisted tenant_id: {e}")))?;
207 let thread_id: String = body
208 .get("thread_id")
209 .and_then(|v| v.as_str())
210 .ok_or_else(|| PersistenceError::Backend("checkpoint missing thread_id".into()))?
211 .to_owned();
212 let parent_id = match body.get("parent_id") {
213 Some(Value::Null) | None => None,
214 Some(v) => Some(serde_json::from_value::<CheckpointId>(v.clone())?),
215 };
216 let step = body
217 .get("step")
218 .and_then(|v| v.as_u64())
219 .and_then(|n| usize::try_from(n).ok())
220 .ok_or_else(|| PersistenceError::Backend("checkpoint missing step".into()))?;
221 let state: S = body
222 .get("state")
223 .map(|s| serde_json::from_value(s.clone()))
224 .ok_or_else(|| PersistenceError::Backend("checkpoint missing state".into()))??;
225 let next_node = body
226 .get("next_node")
227 .and_then(|v| v.as_str())
228 .map(ToOwned::to_owned);
229 let timestamp = body
230 .get("timestamp")
231 .map(|v| serde_json::from_value::<chrono::DateTime<chrono::Utc>>(v.clone()))
232 .ok_or_else(|| PersistenceError::Backend("checkpoint missing timestamp".into()))??;
233 let key = ThreadKey::new(tenant_id, thread_id);
234 Ok(Checkpoint::from_parts(
235 id, &key, parent_id, step, state, next_node, timestamp,
236 ))
237}
238
239fn backend_to_core(e: redis::RedisError) -> Error {
240 PersistenceError::Backend(e.to_string()).into()
241}
242
243fn into_core(e: PersistenceError) -> Error {
244 e.into()
245}
246
#[cfg(test)]
// Tests unwrap freely: a panic is the desired failure mode here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Builds a `Checkpoint<i32>` with a parent and a next node, so every
    /// optional envelope field carries a non-null value.
    fn sample_checkpoint() -> Checkpoint<i32> {
        let Ok(tenant) = entelix_core::TenantId::try_from("tenant-a") else {
            panic!("\"tenant-a\" should be a valid tenant id");
        };
        let key = ThreadKey::new(tenant, "th-1".to_owned());
        Checkpoint::new(&key, 3, 42, Some("next".to_owned()))
            .with_parent(CheckpointId::from_uuid(Uuid::new_v4()))
    }

    #[test]
    fn envelope_round_trips_every_field() {
        let original = sample_checkpoint();
        let envelope = wrap_envelope(&original).unwrap();
        let restored = unwrap_envelope::<i32>(&envelope).unwrap();
        // Re-wrapping compares all persisted fields (including parent, next
        // node and timestamp) without needing `PartialEq` on `Checkpoint`.
        assert_eq!(wrap_envelope(&restored).unwrap(), envelope);
    }

    #[test]
    fn unwrap_envelope_requires_schema_version() {
        let mut envelope = wrap_envelope(&sample_checkpoint()).unwrap();
        envelope.as_object_mut().unwrap().remove("schema_version");
        let err = unwrap_envelope::<i32>(&envelope).unwrap_err();
        assert!(
            matches!(err, PersistenceError::Backend(ref m) if m.contains("schema_version")),
            "expected Backend(\"… schema_version\"), got {err:?}"
        );
    }

    #[test]
    fn unwrap_envelope_rejects_empty_persisted_tenant_id() {
        let envelope = serde_json::json!({
            "schema_version": SessionSchemaVersion::CURRENT,
            "body": {
                "id": CheckpointId::from_uuid(Uuid::new_v4()),
                "tenant_id": "",
                "thread_id": "th-1",
                "parent_id": serde_json::Value::Null,
                "step": 0u64,
                "state": 42,
                "next_node": serde_json::Value::Null,
                "timestamp": chrono::Utc::now(),
            }
        });
        let err = unwrap_envelope::<i32>(&envelope).unwrap_err();
        assert!(
            matches!(err, PersistenceError::Backend(ref m) if m.contains("invalid persisted tenant_id")),
            "expected Backend(\"invalid persisted tenant_id …\"), got {err:?}"
        );
    }
}