Skip to main content

entelix_graphmemory_pg/
store.rs

1//! `PgGraphMemory<N, E>` — `entelix_memory::GraphMemory<N, E>`
2//! over Postgres + JSONB payload columns.
3//!
4//! ## Storage shape
5//!
6//! Two tables:
7//!
8//! - **nodes** (`graph_nodes` by default): `(namespace_key, id,
9//!   payload)` with composite PK `(namespace_key, id)`.
10//! - **edges** (`graph_edges` by default): `(namespace_key, id,
11//!   from_node, to_node, payload, ts)` with composite PK
12//!   `(namespace_key, id)`, plus covering indexes on
13//!   `(namespace_key, from_node)` / `(namespace_key, to_node)` /
14//!   `(namespace_key, ts)`.
15//!
16//! Every read / write rides a `WHERE namespace_key = $1` anchor —
17//! invariant 11 / F2 demands structural tenant isolation, and the
18//! composite PK doubles as the B-tree index that anchor relies
19//! on.
20//!
21//! ## Traversal model
22//!
23//! `traverse` and `find_path` issue a single `WITH RECURSIVE` query
24//! per call — Postgres expands the BFS server-side, returning the
25//! visited hops (or the reconstructed shortest path) in one
26//! round-trip regardless of `max_depth`. The recursive CTE carries
27//! a per-row `visited` array for cycle prevention, and `find_path`
28//! additionally accumulates an `edge_path` array that the outer
29//! query unrolls and rejoins to the edges table to reconstruct the
30//! full hop sequence in BFS order.
31//!
32//! ## Schema-as-code escape hatch
33//!
34//! Operators that own the schema externally (DBA-managed, IaC,
35//! migration pipeline) call [`PgGraphMemoryBuilder::with_auto_migrate`]
36//! with `false` — the builder skips table / index creation,
37//! trusting the operator to have stamped them.
38
39use std::marker::PhantomData;
40use std::sync::Arc;
41
42use async_trait::async_trait;
43use chrono::{DateTime, Utc};
44use entelix_core::{ExecutionContext, Result};
45use entelix_memory::{Direction, EdgeId, GraphHop, GraphMemory, Namespace, NodeId};
46use serde::Serialize;
47use serde::de::DeserializeOwned;
48use serde_json::Value;
49use sqlx::postgres::{PgPool, PgPoolOptions};
50
51use crate::error::{PgGraphMemoryError, PgGraphMemoryResult};
52use crate::migration::bootstrap;
53use crate::tenant::set_tenant_session;
54
/// Nodes table name a fresh builder starts with.
const DEFAULT_NODES_TABLE: &str = "graph_nodes";
/// Edges table name a fresh builder starts with.
const DEFAULT_EDGES_TABLE: &str = "graph_edges";
57
58/// Postgres-backed [`GraphMemory<N, E>`].
59///
60/// Cheap to clone — internal state is an `Arc<PgPool>` plus two
61/// owned table-name strings.
62pub struct PgGraphMemory<N, E> {
63    pool: Arc<PgPool>,
64    nodes_table: Arc<str>,
65    edges_table: Arc<str>,
66    _phantom: PhantomData<fn() -> (N, E)>,
67}
68
69impl<N, E> Clone for PgGraphMemory<N, E> {
70    fn clone(&self) -> Self {
71        Self {
72            pool: Arc::clone(&self.pool),
73            nodes_table: Arc::clone(&self.nodes_table),
74            edges_table: Arc::clone(&self.edges_table),
75            _phantom: PhantomData,
76        }
77    }
78}
79
80impl<N, E> std::fmt::Debug for PgGraphMemory<N, E> {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.debug_struct("PgGraphMemory")
83            .field("nodes_table", &self.nodes_table)
84            .field("edges_table", &self.edges_table)
85            .finish_non_exhaustive()
86    }
87}
88
89impl<N, E> PgGraphMemory<N, E> {
90    /// Start a fluent builder. `connection_string` is the only
91    /// required field; everything else has a sensible default.
92    pub fn builder() -> PgGraphMemoryBuilder<N, E> {
93        PgGraphMemoryBuilder::new()
94    }
95
96    // ── Backend-specific admin / migration surface ────────────────────
97    //
98    // The methods below are *not* part of the `GraphMemory` trait — they
99    // are operator-side enumeration / cleanup paths the SDK delegates to
100    // the backend type because (a) the trait would otherwise grow a
101    // surface every backend must re-implement (silent no-op risk per
102    // invariant 15), and (b) these are admin/migration concerns the
103    // *operator* runs, not the *agent*. Callers that hold a concrete
104    // `PgGraphMemory<N, E>` reach for these directly; trait-erased
105    // call sites do not see them.
106
107    /// Paginated node-id enumeration, ascending by id (UUID v7
108    /// mint-time order). Operator-side admin / migration path.
109    pub async fn list_nodes(
110        &self,
111        ns: &Namespace,
112        limit: usize,
113        offset: usize,
114    ) -> Result<Vec<NodeId>> {
115        let sql = format!(
116            "SELECT id FROM {} \
117             WHERE namespace_key = $1 \
118             ORDER BY id ASC \
119             LIMIT $2 OFFSET $3",
120            self.nodes_table
121        );
122        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
123        let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
124        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
125        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
126        let rows: Vec<(String,)> = sqlx::query_as(&sql)
127            .bind(ns.render())
128            .bind(limit_i64)
129            .bind(offset_i64)
130            .fetch_all(&mut *tx)
131            .await
132            .map_err(into_core_sqlx)?;
133        tx.commit().await.map_err(into_core_sqlx)?;
134        Ok(rows
135            .into_iter()
136            .map(|(id,)| NodeId::from_string(id))
137            .collect())
138    }
139
140    /// Paginated `(NodeId, N)` enumeration — single round-trip
141    /// versus `list_nodes` + per-id `node()`. Operator-side
142    /// bulk-export path.
143    pub async fn list_node_records(
144        &self,
145        ns: &Namespace,
146        limit: usize,
147        offset: usize,
148    ) -> Result<Vec<(NodeId, N)>>
149    where
150        N: DeserializeOwned,
151    {
152        let sql = format!(
153            "SELECT id, payload FROM {} \
154             WHERE namespace_key = $1 \
155             ORDER BY id ASC \
156             LIMIT $2 OFFSET $3",
157            self.nodes_table
158        );
159        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
160        let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
161        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
162        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
163        let rows: Vec<(String, Value)> = sqlx::query_as(&sql)
164            .bind(ns.render())
165            .bind(limit_i64)
166            .bind(offset_i64)
167            .fetch_all(&mut *tx)
168            .await
169            .map_err(into_core_sqlx)?;
170        tx.commit().await.map_err(into_core_sqlx)?;
171        rows.into_iter()
172            .map(|(id, payload)| {
173                let node: N = serde_json::from_value(payload).map_err(into_core_codec)?;
174                Ok((NodeId::from_string(id), node))
175            })
176            .collect()
177    }
178
179    /// Paginated edge-id enumeration. Operator-side migration path.
180    pub async fn list_edges(
181        &self,
182        ns: &Namespace,
183        limit: usize,
184        offset: usize,
185    ) -> Result<Vec<EdgeId>> {
186        let sql = format!(
187            "SELECT id FROM {} \
188             WHERE namespace_key = $1 \
189             ORDER BY id ASC \
190             LIMIT $2 OFFSET $3",
191            self.edges_table
192        );
193        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
194        let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
195        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
196        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
197        let rows: Vec<(String,)> = sqlx::query_as(&sql)
198            .bind(ns.render())
199            .bind(limit_i64)
200            .bind(offset_i64)
201            .fetch_all(&mut *tx)
202            .await
203            .map_err(into_core_sqlx)?;
204        tx.commit().await.map_err(into_core_sqlx)?;
205        Ok(rows
206            .into_iter()
207            .map(|(id,)| EdgeId::from_string(id))
208            .collect())
209    }
210
211    /// Paginated `GraphHop<E>` enumeration — full structural body
212    /// in one round-trip. Operator-side bulk-export path.
213    pub async fn list_edge_records(
214        &self,
215        ns: &Namespace,
216        limit: usize,
217        offset: usize,
218    ) -> Result<Vec<GraphHop<E>>>
219    where
220        E: DeserializeOwned,
221    {
222        let sql = format!(
223            "SELECT id, from_node, to_node, payload, ts FROM {} \
224             WHERE namespace_key = $1 \
225             ORDER BY id ASC \
226             LIMIT $2 OFFSET $3",
227            self.edges_table
228        );
229        let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
230        let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
231        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
232        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
233        let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
234            .bind(ns.render())
235            .bind(limit_i64)
236            .bind(offset_i64)
237            .fetch_all(&mut *tx)
238            .await
239            .map_err(into_core_sqlx)?;
240        tx.commit().await.map_err(into_core_sqlx)?;
241        rows.into_iter()
242            .map(|(id, fr, to_n, payload, ts)| {
243                let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
244                Ok(GraphHop::new(
245                    EdgeId::from_string(id),
246                    NodeId::from_string(fr),
247                    NodeId::from_string(to_n),
248                    edge,
249                    ts,
250                ))
251            })
252            .collect()
253    }
254
255    /// Drop every node with no incident edge — single SQL anti-join
256    /// against the edges table. Two-phase prune companion to
257    /// `prune_older_than`: the edge sweep leaves orphans that this
258    /// call cleans up. Operator-side admin path.
259    pub async fn prune_orphan_nodes(&self, ns: &Namespace) -> Result<usize> {
260        let sql = format!(
261            "DELETE FROM {nodes} \
262             WHERE namespace_key = $1 \
263               AND id NOT IN ( \
264                   SELECT from_node FROM {edges} WHERE namespace_key = $1 \
265                   UNION \
266                   SELECT to_node FROM {edges} WHERE namespace_key = $1 \
267               )",
268            nodes = self.nodes_table,
269            edges = self.edges_table
270        );
271        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
272        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
273        let result = sqlx::query(&sql)
274            .bind(ns.render())
275            .execute(&mut *tx)
276            .await
277            .map_err(into_core_sqlx)?;
278        tx.commit().await.map_err(into_core_sqlx)?;
279        Ok(usize::try_from(result.rows_affected()).unwrap_or(usize::MAX))
280    }
281}
282
283/// Fluent builder for [`PgGraphMemory`]. Use
284/// [`PgGraphMemory::builder`].
285#[must_use]
286pub struct PgGraphMemoryBuilder<N, E> {
287    url: Option<String>,
288    pool: Option<Arc<PgPool>>,
289    nodes_table: String,
290    edges_table: String,
291    auto_migrate: bool,
292    _phantom: PhantomData<fn() -> (N, E)>,
293}
294
295impl<N, E> Default for PgGraphMemoryBuilder<N, E> {
296    fn default() -> Self {
297        Self::new()
298    }
299}
300
301impl<N, E> PgGraphMemoryBuilder<N, E> {
302    /// Empty builder.
303    pub fn new() -> Self {
304        Self {
305            url: None,
306            pool: None,
307            nodes_table: DEFAULT_NODES_TABLE.to_owned(),
308            edges_table: DEFAULT_EDGES_TABLE.to_owned(),
309            auto_migrate: true,
310            _phantom: PhantomData,
311        }
312    }
313
314    /// Postgres connection string. Mutually exclusive with
315    /// [`Self::with_pool`] — the builder rejects construction if
316    /// both are set.
317    pub fn with_connection_string(mut self, url: impl Into<String>) -> Self {
318        self.url = Some(url.into());
319        self
320    }
321
322    /// Reuse an existing pool — useful when the operator already
323    /// manages a `PgPool` for other persistence layers and wants
324    /// `PgGraphMemory` to share it.
325    pub fn with_pool(mut self, pool: Arc<PgPool>) -> Self {
326        self.pool = Some(pool);
327        self
328    }
329
330    /// Override the nodes table name (default `graph_nodes`).
331    pub fn with_nodes_table(mut self, name: impl Into<String>) -> Self {
332        self.nodes_table = name.into();
333        self
334    }
335
336    /// Override the edges table name (default `graph_edges`).
337    pub fn with_edges_table(mut self, name: impl Into<String>) -> Self {
338        self.edges_table = name.into();
339        self
340    }
341
342    /// Toggle the idempotent schema bootstrap. Default `true`;
343    /// set to `false` when the schema is owned externally.
344    pub const fn with_auto_migrate(mut self, on: bool) -> Self {
345        self.auto_migrate = on;
346        self
347    }
348
349    /// Open the pool (if needed), run the migration (if enabled),
350    /// and return the configured backend.
351    pub async fn build(self) -> PgGraphMemoryResult<PgGraphMemory<N, E>> {
352        let pool = match (self.pool, self.url) {
353            (Some(_), Some(_)) => {
354                return Err(PgGraphMemoryError::Config(
355                    "set with_pool() OR with_connection_string(), not both".into(),
356                ));
357            }
358            (Some(pool), None) => pool,
359            (None, Some(url)) => Arc::new(PgPoolOptions::new().connect(&url).await?),
360            (None, None) => {
361                return Err(PgGraphMemoryError::Config(
362                    "with_pool() or with_connection_string() is required".into(),
363                ));
364            }
365        };
366        if self.auto_migrate {
367            bootstrap(&pool, &self.nodes_table, &self.edges_table).await?;
368        }
369        Ok(PgGraphMemory {
370            pool,
371            nodes_table: Arc::from(self.nodes_table),
372            edges_table: Arc::from(self.edges_table),
373            _phantom: PhantomData,
374        })
375    }
376}
377
378#[async_trait]
379impl<N, E> GraphMemory<N, E> for PgGraphMemory<N, E>
380where
381    N: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
382    E: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
383{
384    async fn add_node(&self, _ctx: &ExecutionContext, ns: &Namespace, node: N) -> Result<NodeId> {
385        let id = NodeId::new();
386        let payload = serde_json::to_value(&node).map_err(into_core_codec)?;
387        let sql = format!(
388            "INSERT INTO {} (tenant_id, namespace_key, id, payload) \
389             VALUES ($1, $2, $3, $4)",
390            self.nodes_table
391        );
392        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
393        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
394        sqlx::query(&sql)
395            .bind(ns.tenant_id().as_str())
396            .bind(ns.render())
397            .bind(id.as_str())
398            .bind(&payload)
399            .execute(&mut *tx)
400            .await
401            .map_err(into_core_sqlx)?;
402        tx.commit().await.map_err(into_core_sqlx)?;
403        Ok(id)
404    }
405
406    async fn add_edge(
407        &self,
408        _ctx: &ExecutionContext,
409        ns: &Namespace,
410        from: &NodeId,
411        to: &NodeId,
412        edge: E,
413        timestamp: DateTime<Utc>,
414    ) -> Result<EdgeId> {
415        let id = EdgeId::new();
416        let payload = serde_json::to_value(&edge).map_err(into_core_codec)?;
417        let sql = format!(
418            "INSERT INTO {} (tenant_id, namespace_key, id, from_node, to_node, payload, ts) \
419             VALUES ($1, $2, $3, $4, $5, $6, $7)",
420            self.edges_table
421        );
422        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
423        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
424        sqlx::query(&sql)
425            .bind(ns.tenant_id().as_str())
426            .bind(ns.render())
427            .bind(id.as_str())
428            .bind(from.as_str())
429            .bind(to.as_str())
430            .bind(&payload)
431            .bind(timestamp)
432            .execute(&mut *tx)
433            .await
434            .map_err(into_core_sqlx)?;
435        tx.commit().await.map_err(into_core_sqlx)?;
436        Ok(id)
437    }
438
439    async fn add_edges_batch(
440        &self,
441        _ctx: &ExecutionContext,
442        ns: &Namespace,
443        edges: Vec<(NodeId, NodeId, E, DateTime<Utc>)>,
444    ) -> Result<Vec<EdgeId>> {
445        if edges.is_empty() {
446            return Ok(Vec::new());
447        }
448        // Pre-allocate per-column arrays. Postgres' UNNEST takes one
449        // array per column and zips them row-wise — N edges become
450        // one INSERT … SELECT FROM UNNEST(…), one round-trip
451        // regardless of N.
452        let count = edges.len();
453        let mut ids: Vec<EdgeId> = Vec::with_capacity(count);
454        let mut id_strings: Vec<String> = Vec::with_capacity(count);
455        let mut from_strings: Vec<String> = Vec::with_capacity(count);
456        let mut to_strings: Vec<String> = Vec::with_capacity(count);
457        let mut payloads: Vec<Value> = Vec::with_capacity(count);
458        let mut timestamps: Vec<DateTime<Utc>> = Vec::with_capacity(count);
459        for (from, to, payload, ts) in edges {
460            let id = EdgeId::new();
461            id_strings.push(id.as_str().to_owned());
462            from_strings.push(from.as_str().to_owned());
463            to_strings.push(to.as_str().to_owned());
464            payloads.push(serde_json::to_value(&payload).map_err(into_core_codec)?);
465            timestamps.push(ts);
466            ids.push(id);
467        }
468        let sql = format!(
469            "INSERT INTO {} (tenant_id, namespace_key, id, from_node, to_node, payload, ts) \
470             SELECT $1, $2, e.id, e.from_node, e.to_node, e.payload, e.ts \
471             FROM UNNEST($3::TEXT[], $4::TEXT[], $5::TEXT[], $6::JSONB[], $7::TIMESTAMPTZ[]) \
472                  AS e(id, from_node, to_node, payload, ts)",
473            self.edges_table
474        );
475        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
476        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
477        sqlx::query(&sql)
478            .bind(ns.tenant_id().as_str())
479            .bind(ns.render())
480            .bind(&id_strings)
481            .bind(&from_strings)
482            .bind(&to_strings)
483            .bind(&payloads)
484            .bind(&timestamps)
485            .execute(&mut *tx)
486            .await
487            .map_err(into_core_sqlx)?;
488        tx.commit().await.map_err(into_core_sqlx)?;
489        Ok(ids)
490    }
491
492    async fn get_node(
493        &self,
494        _ctx: &ExecutionContext,
495        ns: &Namespace,
496        id: &NodeId,
497    ) -> Result<Option<N>> {
498        let sql = format!(
499            "SELECT payload FROM {} WHERE namespace_key = $1 AND id = $2",
500            self.nodes_table
501        );
502        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
503        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
504        let row: Option<(Value,)> = sqlx::query_as(&sql)
505            .bind(ns.render())
506            .bind(id.as_str())
507            .fetch_optional(&mut *tx)
508            .await
509            .map_err(into_core_sqlx)?;
510        tx.commit().await.map_err(into_core_sqlx)?;
511        row.map(|(p,)| serde_json::from_value(p).map_err(into_core_codec))
512            .transpose()
513    }
514
515    async fn get_edge(
516        &self,
517        _ctx: &ExecutionContext,
518        ns: &Namespace,
519        edge_id: &EdgeId,
520    ) -> Result<Option<GraphHop<E>>> {
521        let sql = format!(
522            "SELECT from_node, to_node, payload, ts FROM {} \
523             WHERE namespace_key = $1 AND id = $2",
524            self.edges_table
525        );
526        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
527        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
528        let row: Option<(String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
529            .bind(ns.render())
530            .bind(edge_id.as_str())
531            .fetch_optional(&mut *tx)
532            .await
533            .map_err(into_core_sqlx)?;
534        tx.commit().await.map_err(into_core_sqlx)?;
535        row.map(|(fr, to_n, payload, ts)| {
536            let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
537            Ok(GraphHop::new(
538                edge_id.clone(),
539                NodeId::from_string(fr),
540                NodeId::from_string(to_n),
541                edge,
542                ts,
543            ))
544        })
545        .transpose()
546    }
547
548    async fn neighbors(
549        &self,
550        _ctx: &ExecutionContext,
551        ns: &Namespace,
552        node: &NodeId,
553        direction: Direction,
554    ) -> Result<Vec<(EdgeId, NodeId, E)>> {
555        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
556        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
557        let rows = fetch_neighbours(&mut *tx, &self.edges_table, ns, node, direction).await?;
558        tx.commit().await.map_err(into_core_sqlx)?;
559        rows.into_iter()
560            .map(|row| {
561                let payload: E = serde_json::from_value(row.payload).map_err(into_core_codec)?;
562                Ok((row.id, row.neighbour, payload))
563            })
564            .collect()
565    }
566
567    async fn traverse(
568        &self,
569        _ctx: &ExecutionContext,
570        ns: &Namespace,
571        start: &NodeId,
572        direction: Direction,
573        max_depth: usize,
574    ) -> Result<Vec<GraphHop<E>>> {
575        traverse_recursive(self, ns, start, direction, max_depth).await
576    }
577
578    async fn find_path(
579        &self,
580        _ctx: &ExecutionContext,
581        ns: &Namespace,
582        from: &NodeId,
583        to: &NodeId,
584        direction: Direction,
585        max_depth: usize,
586    ) -> Result<Option<Vec<GraphHop<E>>>> {
587        if from == to {
588            return Ok(Some(Vec::new()));
589        }
590        find_path_recursive(self, ns, from, to, direction, max_depth).await
591    }
592
593    async fn temporal_filter(
594        &self,
595        _ctx: &ExecutionContext,
596        ns: &Namespace,
597        from: DateTime<Utc>,
598        to: DateTime<Utc>,
599    ) -> Result<Vec<GraphHop<E>>> {
600        let sql = format!(
601            "SELECT id, from_node, to_node, payload, ts \
602             FROM {} \
603             WHERE namespace_key = $1 AND ts >= $2 AND ts < $3 \
604             ORDER BY ts ASC",
605            self.edges_table
606        );
607        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
608        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
609        let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
610            .bind(ns.render())
611            .bind(from)
612            .bind(to)
613            .fetch_all(&mut *tx)
614            .await
615            .map_err(into_core_sqlx)?;
616        tx.commit().await.map_err(into_core_sqlx)?;
617        rows.into_iter()
618            .map(|(id, fr, to_n, payload, ts)| {
619                let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
620                Ok(GraphHop::new(
621                    EdgeId::from_string(id),
622                    NodeId::from_string(fr),
623                    NodeId::from_string(to_n),
624                    edge,
625                    ts,
626                ))
627            })
628            .collect()
629    }
630
631    async fn node_count(&self, _ctx: &ExecutionContext, ns: &Namespace) -> Result<usize> {
632        let sql = format!(
633            "SELECT COUNT(*) FROM {} WHERE namespace_key = $1",
634            self.nodes_table
635        );
636        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
637        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
638        let row: (i64,) = sqlx::query_as(&sql)
639            .bind(ns.render())
640            .fetch_one(&mut *tx)
641            .await
642            .map_err(into_core_sqlx)?;
643        tx.commit().await.map_err(into_core_sqlx)?;
644        Ok(usize::try_from(row.0.max(0)).unwrap_or(usize::MAX))
645    }
646
647    async fn edge_count(&self, _ctx: &ExecutionContext, ns: &Namespace) -> Result<usize> {
648        let sql = format!(
649            "SELECT COUNT(*) FROM {} WHERE namespace_key = $1",
650            self.edges_table
651        );
652        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
653        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
654        let row: (i64,) = sqlx::query_as(&sql)
655            .bind(ns.render())
656            .fetch_one(&mut *tx)
657            .await
658            .map_err(into_core_sqlx)?;
659        tx.commit().await.map_err(into_core_sqlx)?;
660        Ok(usize::try_from(row.0.max(0)).unwrap_or(usize::MAX))
661    }
662
663    async fn delete_edge(
664        &self,
665        _ctx: &ExecutionContext,
666        ns: &Namespace,
667        edge_id: &EdgeId,
668    ) -> Result<()> {
669        let sql = format!(
670            "DELETE FROM {} WHERE namespace_key = $1 AND id = $2",
671            self.edges_table
672        );
673        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
674        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
675        sqlx::query(&sql)
676            .bind(ns.render())
677            .bind(edge_id.as_str())
678            .execute(&mut *tx)
679            .await
680            .map_err(into_core_sqlx)?;
681        tx.commit().await.map_err(into_core_sqlx)?;
682        Ok(())
683    }
684
685    async fn delete_node(
686        &self,
687        _ctx: &ExecutionContext,
688        ns: &Namespace,
689        node_id: &NodeId,
690    ) -> Result<usize> {
691        // Cascade — drop incident edges first, then the node
692        // itself, in one tenant-stamped transaction so a
693        // concurrent reader never sees a half-applied state
694        // (an edge whose endpoint node is gone, or vice versa).
695        let edges_sql = format!(
696            "DELETE FROM {} \
697             WHERE namespace_key = $1 AND (from_node = $2 OR to_node = $2)",
698            self.edges_table
699        );
700        let nodes_sql = format!(
701            "DELETE FROM {} WHERE namespace_key = $1 AND id = $2",
702            self.nodes_table
703        );
704        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
705        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
706        let edge_result = sqlx::query(&edges_sql)
707            .bind(ns.render())
708            .bind(node_id.as_str())
709            .execute(&mut *tx)
710            .await
711            .map_err(into_core_sqlx)?;
712        sqlx::query(&nodes_sql)
713            .bind(ns.render())
714            .bind(node_id.as_str())
715            .execute(&mut *tx)
716            .await
717            .map_err(into_core_sqlx)?;
718        tx.commit().await.map_err(into_core_sqlx)?;
719        Ok(usize::try_from(edge_result.rows_affected()).unwrap_or(usize::MAX))
720    }
721
722    async fn prune_older_than(
723        &self,
724        _ctx: &ExecutionContext,
725        ns: &Namespace,
726        ttl: std::time::Duration,
727    ) -> Result<usize> {
728        // chrono::Duration is signed and uses i64 nanoseconds; for
729        // pathological ttls (above i64::MAX seconds) saturate to
730        // chrono::Duration::MAX so the cutoff stays in the past.
731        let cutoff = Utc::now() - chrono::Duration::from_std(ttl).unwrap_or(chrono::Duration::MAX);
732        let sql = format!(
733            "DELETE FROM {} WHERE namespace_key = $1 AND ts < $2",
734            self.edges_table
735        );
736        let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
737        set_tenant_session(&mut *tx, ns.tenant_id()).await?;
738        let result = sqlx::query(&sql)
739            .bind(ns.render())
740            .bind(cutoff)
741            .execute(&mut *tx)
742            .await
743            .map_err(into_core_sqlx)?;
744        tx.commit().await.map_err(into_core_sqlx)?;
745        Ok(usize::try_from(result.rows_affected()).unwrap_or(usize::MAX))
746    }
747}
748
749/// One row decoded from the `neighbours` projection. `neighbour`
750/// is whichever endpoint *isn't* the queried node (for
751/// `Direction::Both`, we project both — see [`fetch_neighbours`]).
752struct NeighbourRow {
753    id: EdgeId,
754    neighbour: NodeId,
755    payload: Value,
756}
757
758async fn fetch_neighbours<'e, E>(
759    executor: E,
760    edges_table: &str,
761    ns: &Namespace,
762    node: &NodeId,
763    direction: Direction,
764) -> Result<Vec<NeighbourRow>>
765where
766    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
767{
768    let dir = direction_sql(direction)?;
769    let sql = format!(
770        "SELECT id, {next_node} AS neighbour, payload \
771         FROM {edges_table} \
772         WHERE namespace_key = $1 AND {join_pred}",
773        next_node = dir.flat_next_node,
774        join_pred = dir.flat_join_predicate,
775    );
776    let rows: Vec<(String, String, Value)> = sqlx::query_as(&sql)
777        .bind(ns.render())
778        .bind(node.as_str())
779        .fetch_all(executor)
780        .await
781        .map_err(into_core_sqlx)?;
782    Ok(rows
783        .into_iter()
784        .map(|(id, neighbour, payload)| NeighbourRow {
785            id: EdgeId::from_string(id),
786            neighbour: NodeId::from_string(neighbour),
787            payload,
788        })
789        .collect())
790}
791
/// Direction-dependent SQL fragments, one set per [`Direction`].
///
/// The `recursive_*` pair is written against `w.frontier` (the
/// current row of the recursive CTE) and the edge alias `e`; the
/// `flat_*` pair is written against the bound parameter `$2` (the
/// seed node) for one-shot neighbour lookups.
struct DirectionSql {
    /// Predicate joining an edge `e` to the CTE row `w`.
    recursive_join_predicate: &'static str,
    /// Expression yielding the node reached from `w.frontier` over `e`.
    recursive_next_node: &'static str,
    /// Predicate selecting the edges that touch the node bound as `$2`.
    flat_join_predicate: &'static str,
    /// Expression yielding the node reached from `$2` over the edge.
    flat_next_node: &'static str,
}
802
803/// `Direction` is `#[non_exhaustive]`. A future variant added
804/// upstream surfaces as a typed encode-time rejection so the
805/// backend never silently approximates an unknown traversal
806/// semantic with one of the existing arms (invariant #15).
807fn direction_sql(direction: Direction) -> Result<DirectionSql> {
808    match direction {
809        Direction::Outgoing => Ok(DirectionSql {
810            recursive_join_predicate: "e.from_node = w.frontier",
811            recursive_next_node: "e.to_node",
812            flat_join_predicate: "from_node = $2",
813            flat_next_node: "to_node",
814        }),
815        Direction::Incoming => Ok(DirectionSql {
816            recursive_join_predicate: "e.to_node = w.frontier",
817            recursive_next_node: "e.from_node",
818            flat_join_predicate: "to_node = $2",
819            flat_next_node: "from_node",
820        }),
821        Direction::Both => Ok(DirectionSql {
822            recursive_join_predicate: "(e.from_node = w.frontier OR e.to_node = w.frontier)",
823            recursive_next_node: "CASE WHEN e.from_node = w.frontier THEN e.to_node ELSE e.from_node END",
824            flat_join_predicate: "(from_node = $2 OR to_node = $2)",
825            flat_next_node: "CASE WHEN from_node = $2 THEN to_node ELSE from_node END",
826        }),
827        other => Err(entelix_core::error::Error::invalid_request(format!(
828            "PgGraphMemory: unsupported Direction variant {other:?}"
829        ))),
830    }
831}
832
833/// Single-round-trip BFS via `WITH RECURSIVE`. Postgres expands
834/// the frontier server-side, carrying a `visited` array per row
835/// for cycle prevention. The outer query keeps the first arrival
836/// at each destination node (BFS shortest-depth wins), preserving
837/// the dedupe semantics of the original layer-by-layer Rust BFS.
838async fn traverse_recursive<N, E>(
839    g: &PgGraphMemory<N, E>,
840    ns: &Namespace,
841    start: &NodeId,
842    direction: Direction,
843    max_depth: usize,
844) -> Result<Vec<GraphHop<E>>>
845where
846    N: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
847    E: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
848{
849    let dir = direction_sql(direction)?;
850    let max_depth_i32 = i32::try_from(max_depth).unwrap_or(i32::MAX);
851    let sql = format!(
852        "WITH RECURSIVE walk(edge_id, edge_from, edge_to, payload, ts, frontier, depth, visited) AS ( \
853            SELECT NULL::TEXT, NULL::TEXT, NULL::TEXT, NULL::JSONB, NULL::TIMESTAMPTZ, \
854                   $2::TEXT, 0, ARRAY[$2::TEXT] \
855            UNION ALL \
856            SELECT e.id, e.from_node, e.to_node, e.payload, e.ts, \
857                   {next_node}, w.depth + 1, w.visited || ({next_node}) \
858            FROM walk w \
859            JOIN {edges_table} e ON e.namespace_key = $1 AND {join_pred} \
860            WHERE w.depth < $3 AND NOT (({next_node}) = ANY(w.visited)) \
861        ), \
862        ranked AS ( \
863            SELECT edge_id, edge_from, edge_to, payload, ts, depth, \
864                   ROW_NUMBER() OVER ( \
865                       PARTITION BY frontier ORDER BY depth ASC, edge_id ASC \
866                   ) AS rn \
867            FROM walk WHERE depth > 0 \
868        ) \
869        SELECT edge_id, edge_from, edge_to, payload, ts \
870        FROM ranked WHERE rn = 1 \
871        ORDER BY depth ASC, edge_id ASC",
872        edges_table = g.edges_table,
873        join_pred = dir.recursive_join_predicate,
874        next_node = dir.recursive_next_node,
875    );
876    let mut tx = g.pool.begin().await.map_err(into_core_sqlx)?;
877    set_tenant_session(&mut *tx, ns.tenant_id()).await?;
878    let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
879        .bind(ns.render())
880        .bind(start.as_str())
881        .bind(max_depth_i32)
882        .fetch_all(&mut *tx)
883        .await
884        .map_err(into_core_sqlx)?;
885    tx.commit().await.map_err(into_core_sqlx)?;
886    rows.into_iter()
887        .map(|(eid, fr, to_n, payload, ts)| {
888            let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
889            Ok(GraphHop::new(
890                EdgeId::from_string(eid),
891                NodeId::from_string(fr),
892                NodeId::from_string(to_n),
893                edge,
894                ts,
895            ))
896        })
897        .collect()
898}
899
900/// Single-round-trip shortest-path via `WITH RECURSIVE`. The
901/// recursive CTE accumulates an `edge_path` array per row; the
902/// outer query picks the shortest path that reached `to` and
903/// rejoins the unrolled edge ids back to the edges table for
904/// payload + endpoint reconstruction. Returns `Ok(None)` when no
905/// path within `max_depth` exists; caller handles `from == to` as
906/// `Ok(Some(Vec::new()))` ahead of this call.
907async fn find_path_recursive<N, E>(
908    g: &PgGraphMemory<N, E>,
909    ns: &Namespace,
910    from: &NodeId,
911    to: &NodeId,
912    direction: Direction,
913    max_depth: usize,
914) -> Result<Option<Vec<GraphHop<E>>>>
915where
916    N: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
917    E: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
918{
919    let dir = direction_sql(direction)?;
920    let max_depth_i32 = i32::try_from(max_depth).unwrap_or(i32::MAX);
921    let sql = format!(
922        "WITH RECURSIVE walk(frontier, depth, visited, edge_path) AS ( \
923            SELECT $2::TEXT, 0, ARRAY[$2::TEXT]::TEXT[], ARRAY[]::TEXT[] \
924            UNION ALL \
925            SELECT {next_node}, w.depth + 1, \
926                   w.visited || ({next_node}), w.edge_path || e.id \
927            FROM walk w \
928            JOIN {edges_table} e ON e.namespace_key = $1 AND {join_pred} \
929            WHERE w.depth < $4 AND w.frontier <> $3 \
930              AND NOT (({next_node}) = ANY(w.visited)) \
931        ), \
932        shortest AS ( \
933            SELECT edge_path FROM walk \
934            WHERE frontier = $3 AND depth > 0 \
935            ORDER BY depth ASC LIMIT 1 \
936        ), \
937        unrolled AS ( \
938            SELECT u.eid, u.ord \
939            FROM shortest s, unnest(s.edge_path) WITH ORDINALITY AS u(eid, ord) \
940        ) \
941        SELECT e.id, e.from_node, e.to_node, e.payload, e.ts \
942        FROM unrolled u \
943        JOIN {edges_table} e ON e.namespace_key = $1 AND e.id = u.eid \
944        ORDER BY u.ord ASC",
945        edges_table = g.edges_table,
946        join_pred = dir.recursive_join_predicate,
947        next_node = dir.recursive_next_node,
948    );
949    let mut tx = g.pool.begin().await.map_err(into_core_sqlx)?;
950    set_tenant_session(&mut *tx, ns.tenant_id()).await?;
951    let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
952        .bind(ns.render())
953        .bind(from.as_str())
954        .bind(to.as_str())
955        .bind(max_depth_i32)
956        .fetch_all(&mut *tx)
957        .await
958        .map_err(into_core_sqlx)?;
959    tx.commit().await.map_err(into_core_sqlx)?;
960    if rows.is_empty() {
961        return Ok(None);
962    }
963    let hops: Vec<GraphHop<E>> = rows
964        .into_iter()
965        .map(|(eid, fr, to_n, payload, ts)| {
966            let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
967            Ok(GraphHop::new(
968                EdgeId::from_string(eid),
969                NodeId::from_string(fr),
970                NodeId::from_string(to_n),
971                edge,
972                ts,
973            ))
974        })
975        .collect::<Result<_>>()?;
976    Ok(Some(hops))
977}
978
979fn into_core_sqlx(e: sqlx::Error) -> entelix_core::error::Error {
980    PgGraphMemoryError::from(e).into()
981}
982
983fn into_core_codec(e: serde_json::Error) -> entelix_core::error::Error {
984    PgGraphMemoryError::from(e).into()
985}