entelix_persistence/postgres/
checkpointer.rs1use std::marker::PhantomData;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use entelix_core::ThreadKey;
13use entelix_core::{Error, Result};
14use entelix_graph::{Checkpoint, CheckpointId, Checkpointer};
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18use sqlx::postgres::PgPool;
19use uuid::Uuid;
20
21use crate::error::PersistenceError;
22use crate::postgres::tenant::set_tenant_session;
23use crate::schema_version::SessionSchemaVersion;
24
/// Envelope key under which the serialized state body is stored.
const STATE_KEY: &str = "state";
/// Envelope key under which the schema version of the payload is stored.
const SCHEMA_KEY: &str = "schema_version";
27
28pub struct PostgresCheckpointer<S> {
34 pool: Arc<PgPool>,
35 _phantom: PhantomData<fn() -> S>,
36}
37
38impl<S> PostgresCheckpointer<S>
39where
40 S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
41{
42 pub(crate) fn new(pool: Arc<PgPool>) -> Self {
43 Self {
44 pool,
45 _phantom: PhantomData,
46 }
47 }
48}
49
50#[async_trait]
51impl<S> Checkpointer<S> for PostgresCheckpointer<S>
52where
53 S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
54{
55 async fn put(&self, checkpoint: Checkpoint<S>) -> Result<()> {
56 let envelope = wrap_state(&checkpoint.state).map_err(into_core)?;
57 let parent = checkpoint.parent_id.as_ref().map(|p| *p.as_uuid());
58 let step_i64 = i64::try_from(checkpoint.step).unwrap_or(i64::MAX);
59
60 let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
61 set_tenant_session(&mut *tx, &checkpoint.tenant_id).await?;
62 sqlx::query(
63 r"
64 INSERT INTO checkpoints
65 (tenant_id, thread_id, id, parent_id, step, state, next_node, ts)
66 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
67 ",
68 )
69 .bind(checkpoint.tenant_id.as_str())
70 .bind(&checkpoint.thread_id)
71 .bind(checkpoint.id.as_uuid())
72 .bind(parent)
73 .bind(step_i64)
74 .bind(&envelope)
75 .bind(checkpoint.next_node.as_deref())
76 .bind(checkpoint.timestamp)
77 .execute(&mut *tx)
78 .await
79 .map_err(backend_to_core)?;
80 tx.commit().await.map_err(backend_to_core)?;
81 Ok(())
82 }
83
84 async fn get_latest(&self, key: &ThreadKey) -> Result<Option<Checkpoint<S>>> {
85 let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
86 set_tenant_session(&mut *tx, key.tenant_id()).await?;
87 let row: Option<CheckpointRow> = sqlx::query_as::<_, CheckpointRow>(
88 r"
89 SELECT tenant_id, thread_id, id, parent_id, step, state, next_node, ts
90 FROM checkpoints
91 WHERE tenant_id = $1 AND thread_id = $2
92 ORDER BY step DESC, ts DESC
93 LIMIT 1
94 ",
95 )
96 .bind(key.tenant_id().as_str())
97 .bind(key.thread_id())
98 .fetch_optional(&mut *tx)
99 .await
100 .map_err(backend_to_core)?;
101 tx.commit().await.map_err(backend_to_core)?;
102 row.map(|r| r.try_into_checkpoint::<S>())
103 .transpose()
104 .map_err(into_core)
105 }
106
107 async fn get_by_id(&self, key: &ThreadKey, id: &CheckpointId) -> Result<Option<Checkpoint<S>>> {
108 let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
109 set_tenant_session(&mut *tx, key.tenant_id()).await?;
110 let row: Option<CheckpointRow> = sqlx::query_as::<_, CheckpointRow>(
111 r"
112 SELECT tenant_id, thread_id, id, parent_id, step, state, next_node, ts
113 FROM checkpoints
114 WHERE tenant_id = $1 AND thread_id = $2 AND id = $3
115 ",
116 )
117 .bind(key.tenant_id().as_str())
118 .bind(key.thread_id())
119 .bind(id.as_uuid())
120 .fetch_optional(&mut *tx)
121 .await
122 .map_err(backend_to_core)?;
123 tx.commit().await.map_err(backend_to_core)?;
124 row.map(|r| r.try_into_checkpoint::<S>())
125 .transpose()
126 .map_err(into_core)
127 }
128
129 async fn list_history(&self, key: &ThreadKey, limit: usize) -> Result<Vec<Checkpoint<S>>> {
130 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
131 let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
132 set_tenant_session(&mut *tx, key.tenant_id()).await?;
133 let rows: Vec<CheckpointRow> = sqlx::query_as::<_, CheckpointRow>(
134 r"
135 SELECT tenant_id, thread_id, id, parent_id, step, state, next_node, ts
136 FROM checkpoints
137 WHERE tenant_id = $1 AND thread_id = $2
138 ORDER BY step DESC, ts DESC
139 LIMIT $3
140 ",
141 )
142 .bind(key.tenant_id().as_str())
143 .bind(key.thread_id())
144 .bind(limit_i64)
145 .fetch_all(&mut *tx)
146 .await
147 .map_err(backend_to_core)?;
148 tx.commit().await.map_err(backend_to_core)?;
149 rows.into_iter()
150 .map(CheckpointRow::try_into_checkpoint::<S>)
151 .collect::<std::result::Result<Vec<_>, _>>()
152 .map_err(into_core)
153 }
154
155 async fn update_state(
156 &self,
157 key: &ThreadKey,
158 parent_id: &CheckpointId,
159 new_state: S,
160 ) -> Result<CheckpointId> {
161 let parent = self.get_by_id(key, parent_id).await?.ok_or_else(|| {
162 Error::invalid_request(format!(
163 "PostgresCheckpointer::update_state: parent {} not found in tenant '{}' thread '{}'",
164 parent_id.to_hyphenated_string(),
165 key.tenant_id(),
166 key.thread_id()
167 ))
168 })?;
169 let new_step = parent.step.saturating_add(1);
170 let new_checkpoint = Checkpoint::new(key, new_step, new_state, parent.next_node)
171 .with_parent(parent_id.clone());
172 let new_id = new_checkpoint.id.clone();
173 self.put(new_checkpoint).await?;
174 Ok(new_id)
175 }
176}
177
178#[derive(sqlx::FromRow)]
179struct CheckpointRow {
180 tenant_id: String,
181 thread_id: String,
182 id: Uuid,
183 parent_id: Option<Uuid>,
184 step: i64,
185 state: Value,
186 next_node: Option<String>,
187 ts: DateTime<Utc>,
188}
189
190impl CheckpointRow {
191 fn try_into_checkpoint<S>(self) -> std::result::Result<Checkpoint<S>, PersistenceError>
192 where
193 S: Clone + Send + Sync + DeserializeOwned + 'static,
194 {
195 let state = unwrap_state::<S>(&self.state)?;
196 let tenant = entelix_core::TenantId::try_from(self.tenant_id)
202 .map_err(|e| PersistenceError::Backend(format!("invalid persisted tenant_id: {e}")))?;
203 let key = ThreadKey::new(tenant, self.thread_id);
204 Ok(Checkpoint::from_parts(
205 CheckpointId::from_uuid(self.id),
206 &key,
207 self.parent_id.map(CheckpointId::from_uuid),
208 usize::try_from(self.step).unwrap_or(0),
209 state,
210 self.next_node,
211 self.ts,
212 ))
213 }
214}
215
216fn wrap_state<S: Serialize>(state: &S) -> std::result::Result<Value, PersistenceError> {
217 let body = serde_json::to_value(state)?;
218 Ok(serde_json::json!({
219 SCHEMA_KEY: SessionSchemaVersion::CURRENT,
220 STATE_KEY: body,
221 }))
222}
223
224fn unwrap_state<S: DeserializeOwned>(value: &Value) -> std::result::Result<S, PersistenceError> {
225 let version = value
226 .get(SCHEMA_KEY)
227 .and_then(|v| v.as_u64())
228 .map(|n| u32::try_from(n).unwrap_or(u32::MAX))
229 .map(SessionSchemaVersion)
230 .ok_or_else(|| {
231 PersistenceError::Backend("checkpoint payload lacks schema_version".into())
232 })?;
233 version.validate()?;
234 let body = value
235 .get(STATE_KEY)
236 .ok_or_else(|| PersistenceError::Backend("checkpoint payload lacks state".into()))?;
237 Ok(serde_json::from_value(body.clone())?)
238}
239
240fn backend_to_core(e: sqlx::Error) -> Error {
241 PersistenceError::Backend(e.to_string()).into()
242}
243
244fn into_core(e: PersistenceError) -> Error {
245 e.into()
246}
247
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A row whose stored `tenant_id` is empty must be rejected with a
    /// backend error naming the tenant id, even though its state envelope
    /// is otherwise well-formed.
    #[test]
    fn try_into_checkpoint_rejects_empty_persisted_tenant_id() {
        let envelope = serde_json::json!({
            SCHEMA_KEY: SessionSchemaVersion::CURRENT,
            STATE_KEY: 42,
        });
        let corrupt_row = CheckpointRow {
            tenant_id: String::new(),
            thread_id: "th-1".to_owned(),
            id: Uuid::new_v4(),
            parent_id: None,
            step: 0,
            state: envelope,
            next_node: None,
            ts: chrono::Utc::now(),
        };

        let err = corrupt_row.try_into_checkpoint::<i32>().unwrap_err();

        let names_tenant_id = matches!(
            &err,
            PersistenceError::Backend(msg) if msg.contains("invalid persisted tenant_id")
        );
        assert!(
            names_tenant_id,
            "expected Backend(\"invalid persisted tenant_id …\"), got {err:?}"
        );
    }
}