1use std::marker::PhantomData;
40use std::sync::Arc;
41
42use async_trait::async_trait;
43use chrono::{DateTime, Utc};
44use entelix_core::{ExecutionContext, Result};
45use entelix_memory::{Direction, EdgeId, GraphHop, GraphMemory, Namespace, NodeId};
46use serde::Serialize;
47use serde::de::DeserializeOwned;
48use serde_json::Value;
49use sqlx::postgres::{PgPool, PgPoolOptions};
50
51use crate::error::{PgGraphMemoryError, PgGraphMemoryResult};
52use crate::migration::bootstrap;
53use crate::tenant::set_tenant_session;
54
/// Nodes table name used when the builder is not given one via `with_nodes_table`.
const DEFAULT_NODES_TABLE: &str = "graph_nodes";
/// Edges table name used when the builder is not given one via `with_edges_table`.
const DEFAULT_EDGES_TABLE: &str = "graph_edges";
57
/// Postgres-backed [`GraphMemory`] store.
///
/// Node payloads (`N`) and edge payloads (`E`) are serialized with `serde_json` and stored in
/// two tables whose names are fixed when the store is built. Every operation opens its own
/// transaction and calls `set_tenant_session` for the namespace's tenant before querying.
pub struct PgGraphMemory<N, E> {
    /// Shared connection pool; cloning the store only bumps this reference count.
    pool: Arc<PgPool>,
    /// Nodes table name, interpolated into node queries with `format!`.
    nodes_table: Arc<str>,
    /// Edges table name, interpolated into edge queries with `format!`.
    edges_table: Arc<str>,
    /// Binds `N`/`E` to the type without storing values of them; the `fn() -> (N, E)` form
    /// keeps `Send`/`Sync` independent of `N` and `E`.
    _phantom: PhantomData<fn() -> (N, E)>,
}
68
69impl<N, E> Clone for PgGraphMemory<N, E> {
70 fn clone(&self) -> Self {
71 Self {
72 pool: Arc::clone(&self.pool),
73 nodes_table: Arc::clone(&self.nodes_table),
74 edges_table: Arc::clone(&self.edges_table),
75 _phantom: PhantomData,
76 }
77 }
78}
79
80impl<N, E> std::fmt::Debug for PgGraphMemory<N, E> {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 f.debug_struct("PgGraphMemory")
83 .field("nodes_table", &self.nodes_table)
84 .field("edges_table", &self.edges_table)
85 .finish_non_exhaustive()
86 }
87}
88
89impl<N, E> PgGraphMemory<N, E> {
90 pub fn builder() -> PgGraphMemoryBuilder<N, E> {
93 PgGraphMemoryBuilder::new()
94 }
95
96 pub async fn list_nodes(
110 &self,
111 ns: &Namespace,
112 limit: usize,
113 offset: usize,
114 ) -> Result<Vec<NodeId>> {
115 let sql = format!(
116 "SELECT id FROM {} \
117 WHERE namespace_key = $1 \
118 ORDER BY id ASC \
119 LIMIT $2 OFFSET $3",
120 self.nodes_table
121 );
122 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
123 let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
124 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
125 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
126 let rows: Vec<(String,)> = sqlx::query_as(&sql)
127 .bind(ns.render())
128 .bind(limit_i64)
129 .bind(offset_i64)
130 .fetch_all(&mut *tx)
131 .await
132 .map_err(into_core_sqlx)?;
133 tx.commit().await.map_err(into_core_sqlx)?;
134 Ok(rows
135 .into_iter()
136 .map(|(id,)| NodeId::from_string(id))
137 .collect())
138 }
139
140 pub async fn list_node_records(
144 &self,
145 ns: &Namespace,
146 limit: usize,
147 offset: usize,
148 ) -> Result<Vec<(NodeId, N)>>
149 where
150 N: DeserializeOwned,
151 {
152 let sql = format!(
153 "SELECT id, payload FROM {} \
154 WHERE namespace_key = $1 \
155 ORDER BY id ASC \
156 LIMIT $2 OFFSET $3",
157 self.nodes_table
158 );
159 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
160 let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
161 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
162 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
163 let rows: Vec<(String, Value)> = sqlx::query_as(&sql)
164 .bind(ns.render())
165 .bind(limit_i64)
166 .bind(offset_i64)
167 .fetch_all(&mut *tx)
168 .await
169 .map_err(into_core_sqlx)?;
170 tx.commit().await.map_err(into_core_sqlx)?;
171 rows.into_iter()
172 .map(|(id, payload)| {
173 let node: N = serde_json::from_value(payload).map_err(into_core_codec)?;
174 Ok((NodeId::from_string(id), node))
175 })
176 .collect()
177 }
178
179 pub async fn list_edges(
181 &self,
182 ns: &Namespace,
183 limit: usize,
184 offset: usize,
185 ) -> Result<Vec<EdgeId>> {
186 let sql = format!(
187 "SELECT id FROM {} \
188 WHERE namespace_key = $1 \
189 ORDER BY id ASC \
190 LIMIT $2 OFFSET $3",
191 self.edges_table
192 );
193 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
194 let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
195 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
196 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
197 let rows: Vec<(String,)> = sqlx::query_as(&sql)
198 .bind(ns.render())
199 .bind(limit_i64)
200 .bind(offset_i64)
201 .fetch_all(&mut *tx)
202 .await
203 .map_err(into_core_sqlx)?;
204 tx.commit().await.map_err(into_core_sqlx)?;
205 Ok(rows
206 .into_iter()
207 .map(|(id,)| EdgeId::from_string(id))
208 .collect())
209 }
210
211 pub async fn list_edge_records(
214 &self,
215 ns: &Namespace,
216 limit: usize,
217 offset: usize,
218 ) -> Result<Vec<GraphHop<E>>>
219 where
220 E: DeserializeOwned,
221 {
222 let sql = format!(
223 "SELECT id, from_node, to_node, payload, ts FROM {} \
224 WHERE namespace_key = $1 \
225 ORDER BY id ASC \
226 LIMIT $2 OFFSET $3",
227 self.edges_table
228 );
229 let limit_i64 = i64::try_from(limit).unwrap_or(i64::MAX);
230 let offset_i64 = i64::try_from(offset).unwrap_or(i64::MAX);
231 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
232 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
233 let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
234 .bind(ns.render())
235 .bind(limit_i64)
236 .bind(offset_i64)
237 .fetch_all(&mut *tx)
238 .await
239 .map_err(into_core_sqlx)?;
240 tx.commit().await.map_err(into_core_sqlx)?;
241 rows.into_iter()
242 .map(|(id, fr, to_n, payload, ts)| {
243 let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
244 Ok(GraphHop::new(
245 EdgeId::from_string(id),
246 NodeId::from_string(fr),
247 NodeId::from_string(to_n),
248 edge,
249 ts,
250 ))
251 })
252 .collect()
253 }
254
255 pub async fn prune_orphan_nodes(&self, ns: &Namespace) -> Result<usize> {
260 let sql = format!(
261 "DELETE FROM {nodes} \
262 WHERE namespace_key = $1 \
263 AND id NOT IN ( \
264 SELECT from_node FROM {edges} WHERE namespace_key = $1 \
265 UNION \
266 SELECT to_node FROM {edges} WHERE namespace_key = $1 \
267 )",
268 nodes = self.nodes_table,
269 edges = self.edges_table
270 );
271 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
272 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
273 let result = sqlx::query(&sql)
274 .bind(ns.render())
275 .execute(&mut *tx)
276 .await
277 .map_err(into_core_sqlx)?;
278 tx.commit().await.map_err(into_core_sqlx)?;
279 Ok(usize::try_from(result.rows_affected()).unwrap_or(usize::MAX))
280 }
281}
282
/// Builder for [`PgGraphMemory`]; finish with `build().await`.
///
/// Exactly one connection source (`with_pool` or `with_connection_string`) must be set.
#[must_use]
pub struct PgGraphMemoryBuilder<N, E> {
    /// Connection string to open a new pool from; mutually exclusive with `pool`.
    url: Option<String>,
    /// Existing pool to reuse; mutually exclusive with `url`.
    pool: Option<Arc<PgPool>>,
    /// Nodes table name (`new` defaults it to `DEFAULT_NODES_TABLE`).
    nodes_table: String,
    /// Edges table name (`new` defaults it to `DEFAULT_EDGES_TABLE`).
    edges_table: String,
    /// When `true`, `build` runs `bootstrap` on the two tables; `new` sets it to `true`.
    auto_migrate: bool,
    /// Carries `N`/`E` through to the built store without owning values of them.
    _phantom: PhantomData<fn() -> (N, E)>,
}
294
295impl<N, E> Default for PgGraphMemoryBuilder<N, E> {
296 fn default() -> Self {
297 Self::new()
298 }
299}
300
301impl<N, E> PgGraphMemoryBuilder<N, E> {
302 pub fn new() -> Self {
304 Self {
305 url: None,
306 pool: None,
307 nodes_table: DEFAULT_NODES_TABLE.to_owned(),
308 edges_table: DEFAULT_EDGES_TABLE.to_owned(),
309 auto_migrate: true,
310 _phantom: PhantomData,
311 }
312 }
313
314 pub fn with_connection_string(mut self, url: impl Into<String>) -> Self {
318 self.url = Some(url.into());
319 self
320 }
321
322 pub fn with_pool(mut self, pool: Arc<PgPool>) -> Self {
326 self.pool = Some(pool);
327 self
328 }
329
330 pub fn with_nodes_table(mut self, name: impl Into<String>) -> Self {
332 self.nodes_table = name.into();
333 self
334 }
335
336 pub fn with_edges_table(mut self, name: impl Into<String>) -> Self {
338 self.edges_table = name.into();
339 self
340 }
341
342 pub const fn with_auto_migrate(mut self, on: bool) -> Self {
345 self.auto_migrate = on;
346 self
347 }
348
349 pub async fn build(self) -> PgGraphMemoryResult<PgGraphMemory<N, E>> {
352 let pool = match (self.pool, self.url) {
353 (Some(_), Some(_)) => {
354 return Err(PgGraphMemoryError::Config(
355 "set with_pool() OR with_connection_string(), not both".into(),
356 ));
357 }
358 (Some(pool), None) => pool,
359 (None, Some(url)) => Arc::new(PgPoolOptions::new().connect(&url).await?),
360 (None, None) => {
361 return Err(PgGraphMemoryError::Config(
362 "with_pool() or with_connection_string() is required".into(),
363 ));
364 }
365 };
366 if self.auto_migrate {
367 bootstrap(&pool, &self.nodes_table, &self.edges_table).await?;
368 }
369 Ok(PgGraphMemory {
370 pool,
371 nodes_table: Arc::from(self.nodes_table),
372 edges_table: Arc::from(self.edges_table),
373 _phantom: PhantomData,
374 })
375 }
376}
377
378#[async_trait]
379impl<N, E> GraphMemory<N, E> for PgGraphMemory<N, E>
380where
381 N: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
382 E: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
383{
384 async fn add_node(&self, _ctx: &ExecutionContext, ns: &Namespace, node: N) -> Result<NodeId> {
385 let id = NodeId::new();
386 let payload = serde_json::to_value(&node).map_err(into_core_codec)?;
387 let sql = format!(
388 "INSERT INTO {} (tenant_id, namespace_key, id, payload) \
389 VALUES ($1, $2, $3, $4)",
390 self.nodes_table
391 );
392 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
393 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
394 sqlx::query(&sql)
395 .bind(ns.tenant_id().as_str())
396 .bind(ns.render())
397 .bind(id.as_str())
398 .bind(&payload)
399 .execute(&mut *tx)
400 .await
401 .map_err(into_core_sqlx)?;
402 tx.commit().await.map_err(into_core_sqlx)?;
403 Ok(id)
404 }
405
406 async fn add_edge(
407 &self,
408 _ctx: &ExecutionContext,
409 ns: &Namespace,
410 from: &NodeId,
411 to: &NodeId,
412 edge: E,
413 timestamp: DateTime<Utc>,
414 ) -> Result<EdgeId> {
415 let id = EdgeId::new();
416 let payload = serde_json::to_value(&edge).map_err(into_core_codec)?;
417 let sql = format!(
418 "INSERT INTO {} (tenant_id, namespace_key, id, from_node, to_node, payload, ts) \
419 VALUES ($1, $2, $3, $4, $5, $6, $7)",
420 self.edges_table
421 );
422 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
423 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
424 sqlx::query(&sql)
425 .bind(ns.tenant_id().as_str())
426 .bind(ns.render())
427 .bind(id.as_str())
428 .bind(from.as_str())
429 .bind(to.as_str())
430 .bind(&payload)
431 .bind(timestamp)
432 .execute(&mut *tx)
433 .await
434 .map_err(into_core_sqlx)?;
435 tx.commit().await.map_err(into_core_sqlx)?;
436 Ok(id)
437 }
438
439 async fn add_edges_batch(
440 &self,
441 _ctx: &ExecutionContext,
442 ns: &Namespace,
443 edges: Vec<(NodeId, NodeId, E, DateTime<Utc>)>,
444 ) -> Result<Vec<EdgeId>> {
445 if edges.is_empty() {
446 return Ok(Vec::new());
447 }
448 let count = edges.len();
453 let mut ids: Vec<EdgeId> = Vec::with_capacity(count);
454 let mut id_strings: Vec<String> = Vec::with_capacity(count);
455 let mut from_strings: Vec<String> = Vec::with_capacity(count);
456 let mut to_strings: Vec<String> = Vec::with_capacity(count);
457 let mut payloads: Vec<Value> = Vec::with_capacity(count);
458 let mut timestamps: Vec<DateTime<Utc>> = Vec::with_capacity(count);
459 for (from, to, payload, ts) in edges {
460 let id = EdgeId::new();
461 id_strings.push(id.as_str().to_owned());
462 from_strings.push(from.as_str().to_owned());
463 to_strings.push(to.as_str().to_owned());
464 payloads.push(serde_json::to_value(&payload).map_err(into_core_codec)?);
465 timestamps.push(ts);
466 ids.push(id);
467 }
468 let sql = format!(
469 "INSERT INTO {} (tenant_id, namespace_key, id, from_node, to_node, payload, ts) \
470 SELECT $1, $2, e.id, e.from_node, e.to_node, e.payload, e.ts \
471 FROM UNNEST($3::TEXT[], $4::TEXT[], $5::TEXT[], $6::JSONB[], $7::TIMESTAMPTZ[]) \
472 AS e(id, from_node, to_node, payload, ts)",
473 self.edges_table
474 );
475 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
476 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
477 sqlx::query(&sql)
478 .bind(ns.tenant_id().as_str())
479 .bind(ns.render())
480 .bind(&id_strings)
481 .bind(&from_strings)
482 .bind(&to_strings)
483 .bind(&payloads)
484 .bind(×tamps)
485 .execute(&mut *tx)
486 .await
487 .map_err(into_core_sqlx)?;
488 tx.commit().await.map_err(into_core_sqlx)?;
489 Ok(ids)
490 }
491
492 async fn get_node(
493 &self,
494 _ctx: &ExecutionContext,
495 ns: &Namespace,
496 id: &NodeId,
497 ) -> Result<Option<N>> {
498 let sql = format!(
499 "SELECT payload FROM {} WHERE namespace_key = $1 AND id = $2",
500 self.nodes_table
501 );
502 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
503 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
504 let row: Option<(Value,)> = sqlx::query_as(&sql)
505 .bind(ns.render())
506 .bind(id.as_str())
507 .fetch_optional(&mut *tx)
508 .await
509 .map_err(into_core_sqlx)?;
510 tx.commit().await.map_err(into_core_sqlx)?;
511 row.map(|(p,)| serde_json::from_value(p).map_err(into_core_codec))
512 .transpose()
513 }
514
515 async fn get_edge(
516 &self,
517 _ctx: &ExecutionContext,
518 ns: &Namespace,
519 edge_id: &EdgeId,
520 ) -> Result<Option<GraphHop<E>>> {
521 let sql = format!(
522 "SELECT from_node, to_node, payload, ts FROM {} \
523 WHERE namespace_key = $1 AND id = $2",
524 self.edges_table
525 );
526 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
527 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
528 let row: Option<(String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
529 .bind(ns.render())
530 .bind(edge_id.as_str())
531 .fetch_optional(&mut *tx)
532 .await
533 .map_err(into_core_sqlx)?;
534 tx.commit().await.map_err(into_core_sqlx)?;
535 row.map(|(fr, to_n, payload, ts)| {
536 let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
537 Ok(GraphHop::new(
538 edge_id.clone(),
539 NodeId::from_string(fr),
540 NodeId::from_string(to_n),
541 edge,
542 ts,
543 ))
544 })
545 .transpose()
546 }
547
548 async fn neighbors(
549 &self,
550 _ctx: &ExecutionContext,
551 ns: &Namespace,
552 node: &NodeId,
553 direction: Direction,
554 ) -> Result<Vec<(EdgeId, NodeId, E)>> {
555 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
556 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
557 let rows = fetch_neighbours(&mut *tx, &self.edges_table, ns, node, direction).await?;
558 tx.commit().await.map_err(into_core_sqlx)?;
559 rows.into_iter()
560 .map(|row| {
561 let payload: E = serde_json::from_value(row.payload).map_err(into_core_codec)?;
562 Ok((row.id, row.neighbour, payload))
563 })
564 .collect()
565 }
566
567 async fn traverse(
568 &self,
569 _ctx: &ExecutionContext,
570 ns: &Namespace,
571 start: &NodeId,
572 direction: Direction,
573 max_depth: usize,
574 ) -> Result<Vec<GraphHop<E>>> {
575 traverse_recursive(self, ns, start, direction, max_depth).await
576 }
577
578 async fn find_path(
579 &self,
580 _ctx: &ExecutionContext,
581 ns: &Namespace,
582 from: &NodeId,
583 to: &NodeId,
584 direction: Direction,
585 max_depth: usize,
586 ) -> Result<Option<Vec<GraphHop<E>>>> {
587 if from == to {
588 return Ok(Some(Vec::new()));
589 }
590 find_path_recursive(self, ns, from, to, direction, max_depth).await
591 }
592
593 async fn temporal_filter(
594 &self,
595 _ctx: &ExecutionContext,
596 ns: &Namespace,
597 from: DateTime<Utc>,
598 to: DateTime<Utc>,
599 ) -> Result<Vec<GraphHop<E>>> {
600 let sql = format!(
601 "SELECT id, from_node, to_node, payload, ts \
602 FROM {} \
603 WHERE namespace_key = $1 AND ts >= $2 AND ts < $3 \
604 ORDER BY ts ASC",
605 self.edges_table
606 );
607 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
608 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
609 let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
610 .bind(ns.render())
611 .bind(from)
612 .bind(to)
613 .fetch_all(&mut *tx)
614 .await
615 .map_err(into_core_sqlx)?;
616 tx.commit().await.map_err(into_core_sqlx)?;
617 rows.into_iter()
618 .map(|(id, fr, to_n, payload, ts)| {
619 let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
620 Ok(GraphHop::new(
621 EdgeId::from_string(id),
622 NodeId::from_string(fr),
623 NodeId::from_string(to_n),
624 edge,
625 ts,
626 ))
627 })
628 .collect()
629 }
630
631 async fn node_count(&self, _ctx: &ExecutionContext, ns: &Namespace) -> Result<usize> {
632 let sql = format!(
633 "SELECT COUNT(*) FROM {} WHERE namespace_key = $1",
634 self.nodes_table
635 );
636 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
637 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
638 let row: (i64,) = sqlx::query_as(&sql)
639 .bind(ns.render())
640 .fetch_one(&mut *tx)
641 .await
642 .map_err(into_core_sqlx)?;
643 tx.commit().await.map_err(into_core_sqlx)?;
644 Ok(usize::try_from(row.0.max(0)).unwrap_or(usize::MAX))
645 }
646
647 async fn edge_count(&self, _ctx: &ExecutionContext, ns: &Namespace) -> Result<usize> {
648 let sql = format!(
649 "SELECT COUNT(*) FROM {} WHERE namespace_key = $1",
650 self.edges_table
651 );
652 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
653 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
654 let row: (i64,) = sqlx::query_as(&sql)
655 .bind(ns.render())
656 .fetch_one(&mut *tx)
657 .await
658 .map_err(into_core_sqlx)?;
659 tx.commit().await.map_err(into_core_sqlx)?;
660 Ok(usize::try_from(row.0.max(0)).unwrap_or(usize::MAX))
661 }
662
663 async fn delete_edge(
664 &self,
665 _ctx: &ExecutionContext,
666 ns: &Namespace,
667 edge_id: &EdgeId,
668 ) -> Result<()> {
669 let sql = format!(
670 "DELETE FROM {} WHERE namespace_key = $1 AND id = $2",
671 self.edges_table
672 );
673 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
674 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
675 sqlx::query(&sql)
676 .bind(ns.render())
677 .bind(edge_id.as_str())
678 .execute(&mut *tx)
679 .await
680 .map_err(into_core_sqlx)?;
681 tx.commit().await.map_err(into_core_sqlx)?;
682 Ok(())
683 }
684
685 async fn delete_node(
686 &self,
687 _ctx: &ExecutionContext,
688 ns: &Namespace,
689 node_id: &NodeId,
690 ) -> Result<usize> {
691 let edges_sql = format!(
696 "DELETE FROM {} \
697 WHERE namespace_key = $1 AND (from_node = $2 OR to_node = $2)",
698 self.edges_table
699 );
700 let nodes_sql = format!(
701 "DELETE FROM {} WHERE namespace_key = $1 AND id = $2",
702 self.nodes_table
703 );
704 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
705 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
706 let edge_result = sqlx::query(&edges_sql)
707 .bind(ns.render())
708 .bind(node_id.as_str())
709 .execute(&mut *tx)
710 .await
711 .map_err(into_core_sqlx)?;
712 sqlx::query(&nodes_sql)
713 .bind(ns.render())
714 .bind(node_id.as_str())
715 .execute(&mut *tx)
716 .await
717 .map_err(into_core_sqlx)?;
718 tx.commit().await.map_err(into_core_sqlx)?;
719 Ok(usize::try_from(edge_result.rows_affected()).unwrap_or(usize::MAX))
720 }
721
722 async fn prune_older_than(
723 &self,
724 _ctx: &ExecutionContext,
725 ns: &Namespace,
726 ttl: std::time::Duration,
727 ) -> Result<usize> {
728 let cutoff = Utc::now() - chrono::Duration::from_std(ttl).unwrap_or(chrono::Duration::MAX);
732 let sql = format!(
733 "DELETE FROM {} WHERE namespace_key = $1 AND ts < $2",
734 self.edges_table
735 );
736 let mut tx = self.pool.begin().await.map_err(into_core_sqlx)?;
737 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
738 let result = sqlx::query(&sql)
739 .bind(ns.render())
740 .bind(cutoff)
741 .execute(&mut *tx)
742 .await
743 .map_err(into_core_sqlx)?;
744 tx.commit().await.map_err(into_core_sqlx)?;
745 Ok(usize::try_from(result.rows_affected()).unwrap_or(usize::MAX))
746 }
747}
748
/// One row of the single-hop neighbour query, before the payload is deserialized.
struct NeighbourRow {
    /// Id of the edge connecting the queried node to `neighbour`.
    id: EdgeId,
    /// The node at the other end of the edge (selected via `flat_next_node`).
    neighbour: NodeId,
    /// Raw JSON edge payload; callers deserialize it into their edge type.
    payload: Value,
}
757
758async fn fetch_neighbours<'e, E>(
759 executor: E,
760 edges_table: &str,
761 ns: &Namespace,
762 node: &NodeId,
763 direction: Direction,
764) -> Result<Vec<NeighbourRow>>
765where
766 E: sqlx::Executor<'e, Database = sqlx::Postgres>,
767{
768 let dir = direction_sql(direction)?;
769 let sql = format!(
770 "SELECT id, {next_node} AS neighbour, payload \
771 FROM {edges_table} \
772 WHERE namespace_key = $1 AND {join_pred}",
773 next_node = dir.flat_next_node,
774 join_pred = dir.flat_join_predicate,
775 );
776 let rows: Vec<(String, String, Value)> = sqlx::query_as(&sql)
777 .bind(ns.render())
778 .bind(node.as_str())
779 .fetch_all(executor)
780 .await
781 .map_err(into_core_sqlx)?;
782 Ok(rows
783 .into_iter()
784 .map(|(id, neighbour, payload)| NeighbourRow {
785 id: EdgeId::from_string(id),
786 neighbour: NodeId::from_string(neighbour),
787 payload,
788 })
789 .collect())
790}
791
/// SQL fragments that specialise the graph queries for one [`Direction`].
struct DirectionSql {
    /// Join condition in the recursive CTEs, where `e` is the edges table and `w.frontier`
    /// is the node currently being expanded.
    recursive_join_predicate: &'static str,
    /// Expression selecting the node reached by following `e` from `w.frontier`.
    recursive_next_node: &'static str,
    /// `WHERE` condition for the single-hop query, where `$2` is the anchor node id.
    flat_join_predicate: &'static str,
    /// Column or expression selecting the neighbour in the single-hop query.
    flat_next_node: &'static str,
}
802
803fn direction_sql(direction: Direction) -> Result<DirectionSql> {
808 match direction {
809 Direction::Outgoing => Ok(DirectionSql {
810 recursive_join_predicate: "e.from_node = w.frontier",
811 recursive_next_node: "e.to_node",
812 flat_join_predicate: "from_node = $2",
813 flat_next_node: "to_node",
814 }),
815 Direction::Incoming => Ok(DirectionSql {
816 recursive_join_predicate: "e.to_node = w.frontier",
817 recursive_next_node: "e.from_node",
818 flat_join_predicate: "to_node = $2",
819 flat_next_node: "from_node",
820 }),
821 Direction::Both => Ok(DirectionSql {
822 recursive_join_predicate: "(e.from_node = w.frontier OR e.to_node = w.frontier)",
823 recursive_next_node: "CASE WHEN e.from_node = w.frontier THEN e.to_node ELSE e.from_node END",
824 flat_join_predicate: "(from_node = $2 OR to_node = $2)",
825 flat_next_node: "CASE WHEN from_node = $2 THEN to_node ELSE from_node END",
826 }),
827 other => Err(entelix_core::error::Error::invalid_request(format!(
828 "PgGraphMemory: unsupported Direction variant {other:?}"
829 ))),
830 }
831}
832
833async fn traverse_recursive<N, E>(
839 g: &PgGraphMemory<N, E>,
840 ns: &Namespace,
841 start: &NodeId,
842 direction: Direction,
843 max_depth: usize,
844) -> Result<Vec<GraphHop<E>>>
845where
846 N: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
847 E: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
848{
849 let dir = direction_sql(direction)?;
850 let max_depth_i32 = i32::try_from(max_depth).unwrap_or(i32::MAX);
851 let sql = format!(
852 "WITH RECURSIVE walk(edge_id, edge_from, edge_to, payload, ts, frontier, depth, visited) AS ( \
853 SELECT NULL::TEXT, NULL::TEXT, NULL::TEXT, NULL::JSONB, NULL::TIMESTAMPTZ, \
854 $2::TEXT, 0, ARRAY[$2::TEXT] \
855 UNION ALL \
856 SELECT e.id, e.from_node, e.to_node, e.payload, e.ts, \
857 {next_node}, w.depth + 1, w.visited || ({next_node}) \
858 FROM walk w \
859 JOIN {edges_table} e ON e.namespace_key = $1 AND {join_pred} \
860 WHERE w.depth < $3 AND NOT (({next_node}) = ANY(w.visited)) \
861 ), \
862 ranked AS ( \
863 SELECT edge_id, edge_from, edge_to, payload, ts, depth, \
864 ROW_NUMBER() OVER ( \
865 PARTITION BY frontier ORDER BY depth ASC, edge_id ASC \
866 ) AS rn \
867 FROM walk WHERE depth > 0 \
868 ) \
869 SELECT edge_id, edge_from, edge_to, payload, ts \
870 FROM ranked WHERE rn = 1 \
871 ORDER BY depth ASC, edge_id ASC",
872 edges_table = g.edges_table,
873 join_pred = dir.recursive_join_predicate,
874 next_node = dir.recursive_next_node,
875 );
876 let mut tx = g.pool.begin().await.map_err(into_core_sqlx)?;
877 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
878 let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
879 .bind(ns.render())
880 .bind(start.as_str())
881 .bind(max_depth_i32)
882 .fetch_all(&mut *tx)
883 .await
884 .map_err(into_core_sqlx)?;
885 tx.commit().await.map_err(into_core_sqlx)?;
886 rows.into_iter()
887 .map(|(eid, fr, to_n, payload, ts)| {
888 let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
889 Ok(GraphHop::new(
890 EdgeId::from_string(eid),
891 NodeId::from_string(fr),
892 NodeId::from_string(to_n),
893 edge,
894 ts,
895 ))
896 })
897 .collect()
898}
899
900async fn find_path_recursive<N, E>(
908 g: &PgGraphMemory<N, E>,
909 ns: &Namespace,
910 from: &NodeId,
911 to: &NodeId,
912 direction: Direction,
913 max_depth: usize,
914) -> Result<Option<Vec<GraphHop<E>>>>
915where
916 N: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
917 E: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
918{
919 let dir = direction_sql(direction)?;
920 let max_depth_i32 = i32::try_from(max_depth).unwrap_or(i32::MAX);
921 let sql = format!(
922 "WITH RECURSIVE walk(frontier, depth, visited, edge_path) AS ( \
923 SELECT $2::TEXT, 0, ARRAY[$2::TEXT]::TEXT[], ARRAY[]::TEXT[] \
924 UNION ALL \
925 SELECT {next_node}, w.depth + 1, \
926 w.visited || ({next_node}), w.edge_path || e.id \
927 FROM walk w \
928 JOIN {edges_table} e ON e.namespace_key = $1 AND {join_pred} \
929 WHERE w.depth < $4 AND w.frontier <> $3 \
930 AND NOT (({next_node}) = ANY(w.visited)) \
931 ), \
932 shortest AS ( \
933 SELECT edge_path FROM walk \
934 WHERE frontier = $3 AND depth > 0 \
935 ORDER BY depth ASC LIMIT 1 \
936 ), \
937 unrolled AS ( \
938 SELECT u.eid, u.ord \
939 FROM shortest s, unnest(s.edge_path) WITH ORDINALITY AS u(eid, ord) \
940 ) \
941 SELECT e.id, e.from_node, e.to_node, e.payload, e.ts \
942 FROM unrolled u \
943 JOIN {edges_table} e ON e.namespace_key = $1 AND e.id = u.eid \
944 ORDER BY u.ord ASC",
945 edges_table = g.edges_table,
946 join_pred = dir.recursive_join_predicate,
947 next_node = dir.recursive_next_node,
948 );
949 let mut tx = g.pool.begin().await.map_err(into_core_sqlx)?;
950 set_tenant_session(&mut *tx, ns.tenant_id()).await?;
951 let rows: Vec<(String, String, String, Value, DateTime<Utc>)> = sqlx::query_as(&sql)
952 .bind(ns.render())
953 .bind(from.as_str())
954 .bind(to.as_str())
955 .bind(max_depth_i32)
956 .fetch_all(&mut *tx)
957 .await
958 .map_err(into_core_sqlx)?;
959 tx.commit().await.map_err(into_core_sqlx)?;
960 if rows.is_empty() {
961 return Ok(None);
962 }
963 let hops: Vec<GraphHop<E>> = rows
964 .into_iter()
965 .map(|(eid, fr, to_n, payload, ts)| {
966 let edge: E = serde_json::from_value(payload).map_err(into_core_codec)?;
967 Ok(GraphHop::new(
968 EdgeId::from_string(eid),
969 NodeId::from_string(fr),
970 NodeId::from_string(to_n),
971 edge,
972 ts,
973 ))
974 })
975 .collect::<Result<_>>()?;
976 Ok(Some(hops))
977}
978
979fn into_core_sqlx(e: sqlx::Error) -> entelix_core::error::Error {
980 PgGraphMemoryError::from(e).into()
981}
982
983fn into_core_codec(e: serde_json::Error) -> entelix_core::error::Error {
984 PgGraphMemoryError::from(e).into()
985}