Skip to main content

entelix_persistence/postgres/
checkpointer.rs

1//! `PostgresCheckpointer<S>` — `entelix_graph::Checkpointer<S>` over
2//! the `checkpoints` table. Every read/write partitions by
3//! `(tenant_id, thread_id)` per Invariant 11 — the trait surface
4//! supplies a `&ThreadKey` so cross-tenant reads are not even
5//! constructible from this backend.
6
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use entelix_core::ThreadKey;
13use entelix_core::{Error, Result};
14use entelix_graph::{Checkpoint, CheckpointId, Checkpointer};
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use serde_json::Value;
18use sqlx::postgres::PgPool;
19use uuid::Uuid;
20
21use crate::error::PersistenceError;
22use crate::postgres::tenant::set_tenant_session;
23use crate::schema_version::SessionSchemaVersion;
24
/// Envelope key holding the [`SessionSchemaVersion`] stamp.
const SCHEMA_KEY: &str = "schema_version";
/// Envelope key holding the serialised checkpoint state.
const STATE_KEY: &str = "state";
27
28/// Postgres-backed [`Checkpointer<S>`].
29///
30/// State payloads are stamped with [`SessionSchemaVersion`] before
31/// serialisation, so a downgrade that can't read the format fails
32/// loudly instead of silently corrupting the row.
33pub struct PostgresCheckpointer<S> {
34    pool: Arc<PgPool>,
35    _phantom: PhantomData<fn() -> S>,
36}
37
38impl<S> PostgresCheckpointer<S>
39where
40    S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
41{
42    pub(crate) fn new(pool: Arc<PgPool>) -> Self {
43        Self {
44            pool,
45            _phantom: PhantomData,
46        }
47    }
48}
49
50#[async_trait]
51impl<S> Checkpointer<S> for PostgresCheckpointer<S>
52where
53    S: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
54{
55    async fn put(&self, checkpoint: Checkpoint<S>) -> Result<()> {
56        let envelope = wrap_state(&checkpoint.state).map_err(into_core)?;
57        let parent = checkpoint.parent_id.as_ref().map(|p| *p.as_uuid());
58        let step_i64 = i64::try_from(checkpoint.step).unwrap_or(i64::MAX);
59
60        let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
61        set_tenant_session(&mut *tx, &checkpoint.tenant_id).await?;
62        sqlx::query(
63            r"
64            INSERT INTO checkpoints
65                (tenant_id, thread_id, id, parent_id, step, state, next_node, ts)
66            VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
67            ",
68        )
69        .bind(checkpoint.tenant_id.as_str())
70        .bind(&checkpoint.thread_id)
71        .bind(checkpoint.id.as_uuid())
72        .bind(parent)
73        .bind(step_i64)
74        .bind(&envelope)
75        .bind(checkpoint.next_node.as_deref())
76        .bind(checkpoint.timestamp)
77        .execute(&mut *tx)
78        .await
79        .map_err(backend_to_core)?;
80        tx.commit().await.map_err(backend_to_core)?;
81        Ok(())
82    }
83
84    async fn get_latest(&self, key: &ThreadKey) -> Result<Option<Checkpoint<S>>> {
85        let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
86        set_tenant_session(&mut *tx, key.tenant_id()).await?;
87        let row: Option<CheckpointRow> = sqlx::query_as::<_, CheckpointRow>(
88            r"
89            SELECT tenant_id, thread_id, id, parent_id, step, state, next_node, ts
90            FROM checkpoints
91            WHERE tenant_id = $1 AND thread_id = $2
92            ORDER BY step DESC, ts DESC
93            LIMIT 1
94            ",
95        )
96        .bind(key.tenant_id().as_str())
97        .bind(key.thread_id())
98        .fetch_optional(&mut *tx)
99        .await
100        .map_err(backend_to_core)?;
101        tx.commit().await.map_err(backend_to_core)?;
102        row.map(|r| r.try_into_checkpoint::<S>())
103            .transpose()
104            .map_err(into_core)
105    }
106
107    async fn get_by_id(&self, key: &ThreadKey, id: &CheckpointId) -> Result<Option<Checkpoint<S>>> {
108        let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
109        set_tenant_session(&mut *tx, key.tenant_id()).await?;
110        let row: Option<CheckpointRow> = sqlx::query_as::<_, CheckpointRow>(
111            r"
112            SELECT tenant_id, thread_id, id, parent_id, step, state, next_node, ts
113            FROM checkpoints
114            WHERE tenant_id = $1 AND thread_id = $2 AND id = $3
115            ",
116        )
117        .bind(key.tenant_id().as_str())
118        .bind(key.thread_id())
119        .bind(id.as_uuid())
120        .fetch_optional(&mut *tx)
121        .await
122        .map_err(backend_to_core)?;
123        tx.commit().await.map_err(backend_to_core)?;
124        row.map(|r| r.try_into_checkpoint::<S>())
125            .transpose()
126            .map_err(into_core)
127    }
128
129    async fn list_history(&self, key: &ThreadKey, limit: usize) -> Result<Vec<Checkpoint<S>>> {
130        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
131        let mut tx = self.pool.begin().await.map_err(backend_to_core)?;
132        set_tenant_session(&mut *tx, key.tenant_id()).await?;
133        let rows: Vec<CheckpointRow> = sqlx::query_as::<_, CheckpointRow>(
134            r"
135            SELECT tenant_id, thread_id, id, parent_id, step, state, next_node, ts
136            FROM checkpoints
137            WHERE tenant_id = $1 AND thread_id = $2
138            ORDER BY step DESC, ts DESC
139            LIMIT $3
140            ",
141        )
142        .bind(key.tenant_id().as_str())
143        .bind(key.thread_id())
144        .bind(limit_i64)
145        .fetch_all(&mut *tx)
146        .await
147        .map_err(backend_to_core)?;
148        tx.commit().await.map_err(backend_to_core)?;
149        rows.into_iter()
150            .map(CheckpointRow::try_into_checkpoint::<S>)
151            .collect::<std::result::Result<Vec<_>, _>>()
152            .map_err(into_core)
153    }
154
155    async fn update_state(
156        &self,
157        key: &ThreadKey,
158        parent_id: &CheckpointId,
159        new_state: S,
160    ) -> Result<CheckpointId> {
161        let parent = self.get_by_id(key, parent_id).await?.ok_or_else(|| {
162            Error::invalid_request(format!(
163                "PostgresCheckpointer::update_state: parent {} not found in tenant '{}' thread '{}'",
164                parent_id.to_hyphenated_string(),
165                key.tenant_id(),
166                key.thread_id()
167            ))
168        })?;
169        let new_step = parent.step.saturating_add(1);
170        let new_checkpoint = Checkpoint::new(key, new_step, new_state, parent.next_node)
171            .with_parent(parent_id.clone());
172        let new_id = new_checkpoint.id.clone();
173        self.put(new_checkpoint).await?;
174        Ok(new_id)
175    }
176}
177
178#[derive(sqlx::FromRow)]
179struct CheckpointRow {
180    tenant_id: String,
181    thread_id: String,
182    id: Uuid,
183    parent_id: Option<Uuid>,
184    step: i64,
185    state: Value,
186    next_node: Option<String>,
187    ts: DateTime<Utc>,
188}
189
190impl CheckpointRow {
191    fn try_into_checkpoint<S>(self) -> std::result::Result<Checkpoint<S>, PersistenceError>
192    where
193        S: Clone + Send + Sync + DeserializeOwned + 'static,
194    {
195        let state = unwrap_state::<S>(&self.state)?;
196        // Persistence-layer row hydration runs the validating
197        // `TenantId::try_from`; an empty `tenant_id` column (which
198        // would otherwise produce a tenantless `Checkpoint` whose
199        // RLS policy comparison silently mis-routes) surfaces as
200        // `Error::InvalidRequest` rather than a constructed value.
201        let tenant = entelix_core::TenantId::try_from(self.tenant_id)
202            .map_err(|e| PersistenceError::Backend(format!("invalid persisted tenant_id: {e}")))?;
203        let key = ThreadKey::new(tenant, self.thread_id);
204        Ok(Checkpoint::from_parts(
205            CheckpointId::from_uuid(self.id),
206            &key,
207            self.parent_id.map(CheckpointId::from_uuid),
208            usize::try_from(self.step).unwrap_or(0),
209            state,
210            self.next_node,
211            self.ts,
212        ))
213    }
214}
215
216fn wrap_state<S: Serialize>(state: &S) -> std::result::Result<Value, PersistenceError> {
217    let body = serde_json::to_value(state)?;
218    Ok(serde_json::json!({
219        SCHEMA_KEY: SessionSchemaVersion::CURRENT,
220        STATE_KEY: body,
221    }))
222}
223
224fn unwrap_state<S: DeserializeOwned>(value: &Value) -> std::result::Result<S, PersistenceError> {
225    let version = value
226        .get(SCHEMA_KEY)
227        .and_then(|v| v.as_u64())
228        .map(|n| u32::try_from(n).unwrap_or(u32::MAX))
229        .map(SessionSchemaVersion)
230        .ok_or_else(|| {
231            PersistenceError::Backend("checkpoint payload lacks schema_version".into())
232        })?;
233    version.validate()?;
234    let body = value
235        .get(STATE_KEY)
236        .ok_or_else(|| PersistenceError::Backend("checkpoint payload lacks state".into()))?;
237    Ok(serde_json::from_value(body.clone())?)
238}
239
240fn backend_to_core(e: sqlx::Error) -> Error {
241    PersistenceError::Backend(e.to_string()).into()
242}
243
244fn into_core(e: PersistenceError) -> Error {
245    e.into()
246}
247
#[cfg(test)]
#[allow(clippy::unwrap_used)] // tests may panic on unexpected results
mod tests {
    use super::*;

    /// Builds a row that differs from a valid one only in `tenant_id`.
    ///
    /// The state envelope is well-formed, so hydration gets past
    /// `unwrap_state` and reaches tenant validation.
    fn row_with_tenant(tenant_id: &str) -> CheckpointRow {
        CheckpointRow {
            tenant_id: tenant_id.to_owned(),
            thread_id: "th-1".to_owned(),
            id: Uuid::new_v4(),
            parent_id: None,
            step: 0,
            state: serde_json::json!({
                SCHEMA_KEY: SessionSchemaVersion::CURRENT,
                STATE_KEY: 42,
            }),
            next_node: None,
            ts: chrono::Utc::now(),
        }
    }

    /// Row hydration must run the `TenantId` validator on the persisted
    /// column. A row with an empty `tenant_id` may come from a
    /// misconfigured admin script or a corrupted backup. It must not
    /// become a tenantless `Checkpoint`, whose RLS-filter comparison
    /// would then run against `''` and silently widen the result set
    /// (Invariant 11).
    #[test]
    fn try_into_checkpoint_rejects_empty_persisted_tenant_id() {
        let err = row_with_tenant("").try_into_checkpoint::<i32>().unwrap_err();
        let is_tenant_error = matches!(
            err,
            PersistenceError::Backend(ref m) if m.contains("invalid persisted tenant_id")
        );
        assert!(
            is_tenant_error,
            "expected Backend(\"invalid persisted tenant_id …\"), got {err:?}"
        );
    }
}