use async_trait::async_trait;
use chrono::{DateTime, Utc};
use sea_orm::sea_query::{Alias, Cond, Expr, Iden, OnConflict, Query};
use sea_orm::{ConnectionTrait, Database, DatabaseBackend, DatabaseConnection, Statement};
use sea_orm_migration::MigratorTrait;
use serde_json::{Value, json};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::fmt;
use tracing::debug;
use crate::error::{GraphDBError, GraphDBResult};
use crate::traits::GraphDBTrait;
use crate::types::{EdgeData, GraphNode, NodeData};
/// Maximum number of rows placed into a single multi-row `INSERT` by the batch
/// upsert methods (`add_nodes_raw`, `add_edges`).
const BATCH_SIZE: usize = 100;
/// Attribute names accepted as keys by `get_filtered_graph_data`. The attribute name
/// is interpolated into the SQL text there, so anything outside this list is rejected.
const ALLOWED_FILTER_ATTRS: &[&str] = &["id", "name", "type"];
/// Identifiers for the `graph_node` table and its columns, for use with the
/// sea-query statement builders.
#[derive(Clone, Copy)]
enum GNode {
    /// The `graph_node` table itself.
    Table,
    /// `id` column.
    Id,
    /// `name` column.
    Name,
    /// `type` column.
    Type,
    /// `properties` column.
    Properties,
    /// `created_at` column.
    CreatedAt,
    /// `updated_at` column.
    UpdatedAt,
}

impl Iden for GNode {
    /// Writes the unquoted SQL name of this identifier into `s`.
    #[allow(
        clippy::expect_used,
        reason = "writing a static &str into the fmt::Write sink is infallible"
    )]
    fn unquoted(&self, s: &mut dyn fmt::Write) {
        let ident = match *self {
            Self::Table => "graph_node",
            Self::Id => "id",
            Self::Name => "name",
            Self::Type => "type",
            Self::Properties => "properties",
            Self::CreatedAt => "created_at",
            Self::UpdatedAt => "updated_at",
        };
        s.write_str(ident).expect("write to string cannot fail");
    }
}
/// Identifiers for the `graph_edge` table and its columns, for use with the
/// sea-query statement builders.
#[derive(Clone, Copy)]
enum GEdge {
    /// The `graph_edge` table itself.
    Table,
    /// `source_id` column.
    SourceId,
    /// `target_id` column.
    TargetId,
    /// `relationship_name` column.
    RelationshipName,
    /// `properties` column.
    Properties,
    /// `created_at` column.
    CreatedAt,
    /// `updated_at` column.
    UpdatedAt,
}

impl Iden for GEdge {
    /// Writes the unquoted SQL name of this identifier into `s`.
    #[allow(
        clippy::expect_used,
        reason = "writing a static &str into the fmt::Write sink is infallible"
    )]
    fn unquoted(&self, s: &mut dyn fmt::Write) {
        let ident = match *self {
            Self::Table => "graph_edge",
            Self::SourceId => "source_id",
            Self::TargetId => "target_id",
            Self::RelationshipName => "relationship_name",
            Self::Properties => "properties",
            Self::CreatedAt => "created_at",
            Self::UpdatedAt => "updated_at",
        };
        s.write_str(ident).expect("write to string cannot fail");
    }
}
/// Column values for one `graph_node` row, as produced by
/// `PgGraphAdapter::serialize_node_to_row` and written by `add_nodes_raw`.
struct NodeRow {
// Value of the node's "id" key; empty string when absent or not a string.
id: String,
// Value of the node's "name" key; empty string when absent or not a string.
name: String,
// Value of the node's "type" key (falling back to "data_type" when "type" is absent).
node_type: String,
// JSON object holding every input key that does not map onto a dedicated column.
properties: Value,
// Parsed from the node's RFC 3339 "created_at" string, else the serialization time.
created_at: DateTime<Utc>,
// Parsed from the node's RFC 3339 "updated_at" string, else the serialization time.
updated_at: DateTime<Utc>,
}
/// Graph storage adapter that keeps nodes and edges in the relational tables
/// `graph_node` and `graph_edge`, accessed through a sea-orm connection.
pub struct PgGraphAdapter {
// Connection used for every query issued by this adapter.
db: DatabaseConnection,
}
impl PgGraphAdapter {
pub async fn new(database_url: &str) -> GraphDBResult<Self> {
let db = Database::connect(database_url)
.await
.map_err(|e| GraphDBError::ConnectionError(format!("PgGraph connect failed: {e}")))?;
migrator::Migrator::up(&db, None).await.map_err(|e| {
GraphDBError::InitializationError(format!("PgGraph migration failed: {e}"))
})?;
debug!("PgGraphAdapter initialised");
Ok(Self { db })
}
/// Wraps an already-open connection, applying the graph schema migrations first.
///
/// # Errors
/// `GraphDBError::InitializationError` when the migrations fail.
pub async fn from_connection(db: DatabaseConnection) -> GraphDBResult<Self> {
    let migrated = migrator::Migrator::up(&db, None).await;
    match migrated {
        Ok(()) => Ok(Self { db }),
        Err(e) => Err(GraphDBError::InitializationError(format!(
            "PgGraph migration failed: {e}"
        ))),
    }
}
/// Renders a sea-query statement into an executable `Statement` using the
/// backend reported by the adapter's connection.
fn build<S: sea_orm::StatementBuilder>(&self, query: &S) -> Statement {
    let backend = self.db.get_database_backend();
    backend.build(query)
}
fn serialize_node_to_row(node: &Value) -> GraphDBResult<NodeRow> {
let obj = node
.as_object()
.ok_or_else(|| GraphDBError::NodeError("Expected JSON object for node".into()))?;
let id = obj
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let name = obj
.get("name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let node_type = obj
.get("type")
.or_else(|| obj.get("data_type"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let now = Utc::now();
let created_at = obj
.get("created_at")
.and_then(|v| v.as_str())
.and_then(|s| DateTime::parse_from_rfc3339(s).ok())
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or(now);
let updated_at = obj
.get("updated_at")
.and_then(|v| v.as_str())
.and_then(|s| DateTime::parse_from_rfc3339(s).ok())
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or(now);
let core_keys = [
"id",
"name",
"type",
"data_type",
"created_at",
"updated_at",
];
let extra: serde_json::Map<String, Value> = obj
.iter()
.filter(|(k, _)| !core_keys.contains(&k.as_str()))
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
Ok(NodeRow {
id,
name,
node_type,
properties: Value::Object(extra),
created_at,
updated_at,
})
}
/// Converts a result row exposing `id`, `name`, `type` and `properties`
/// columns into a `NodeData` map.
///
/// The three core columns are inserted first; keys from the `properties`
/// object are inserted afterwards. An unreadable, NULL or non-object
/// `properties` column is treated as "no extra properties".
///
/// # Errors
/// `GraphDBError::QueryError` when `id`, `name` or `type` cannot be read as a string.
fn parse_node_row(row: &sea_orm::QueryResult) -> GraphDBResult<NodeData> {
    let required = |column: &str| -> GraphDBResult<String> {
        row.try_get::<String>("", column)
            .map_err(|e| GraphDBError::QueryError(format!("missing {column} column: {e}")))
    };

    let id = required("id")?;
    let name = required("name")?;
    let node_type = required("type")?;

    let mut data = NodeData::new();
    data.insert(Cow::Borrowed("id"), json!(id));
    data.insert(Cow::Borrowed("name"), json!(name));
    data.insert(Cow::Borrowed("type"), json!(node_type));

    if let Ok(Some(Value::Object(extra))) = row.try_get::<Option<Value>>("", "properties") {
        for (key, value) in extra {
            data.insert(Cow::Owned(key), value);
        }
    }
    Ok(data)
}
/// Parses an edge from a row that uses the `graph_edge` table's own column names.
fn parse_edge_row(row: &sea_orm::QueryResult) -> GraphDBResult<EdgeData> {
    let (source, target, relationship, properties) = (
        "source_id",
        "target_id",
        "relationship_name",
        "properties",
    );
    Self::parse_edge_row_cols(row, source, target, relationship, properties)
}
fn parse_edge_row_cols(
row: &sea_orm::QueryResult,
src_col: &str,
tgt_col: &str,
rel_col: &str,
props_col: &str,
) -> GraphDBResult<EdgeData> {
let source_id: String = row.try_get("", src_col).unwrap_or_default();
let target_id: String = row.try_get("", tgt_col).unwrap_or_default();
let rel_name: String = row.try_get("", rel_col).unwrap_or_default();
let props: Option<Value> = row.try_get("", props_col).unwrap_or(None);
let props_map = match props {
Some(Value::Object(m)) => m.into_iter().map(|(k, v)| (Cow::Owned(k), v)).collect(),
_ => HashMap::new(),
};
Ok((source_id, target_id, rel_name, props_map))
}
}
#[async_trait]
impl GraphDBTrait for PgGraphAdapter {
/// Applies any pending graph schema migrations.
///
/// # Errors
/// `GraphDBError::InitializationError` when the migrations fail.
async fn initialize(&self) -> GraphDBResult<()> {
    migrator::Migrator::up(&self.db, None)
        .await
        .map_err(|e| GraphDBError::InitializationError(format!("PgGraph migration failed: {e}")))
}
/// Reports whether the `graph_node` table contains no rows.
///
/// # Errors
/// `GraphDBError::QueryError` when the probe query fails.
async fn is_empty(&self) -> GraphDBResult<bool> {
    // `SELECT 1 FROM graph_node LIMIT 1` — a row comes back only if a node exists.
    let mut probe = Query::select();
    probe.expr(Expr::val(1)).from(GNode::Table).limit(1);

    let first = self
        .db
        .query_one(self.build(&probe))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;
    Ok(first.is_none())
}
/// Raw queries are not supported by this backend; always returns
/// `GraphDBError::QueryError`. Both arguments are ignored.
async fn query(
    &self,
    _query: &str,
    _params: Option<HashMap<Cow<'static, str>, Value>>,
) -> GraphDBResult<Vec<Vec<Value>>> {
    const UNSUPPORTED: &str = "The PostgreSQL graph backend does not support raw Cypher queries. \
        Use a graph-native backend (Ladybug, Neo4j) for raw query support, \
        or use the typed adapter methods (add_nodes, get_neighbors, etc.).";
    Err(GraphDBError::QueryError(UNSUPPORTED.to_owned()))
}
/// Removes every node and edge by truncating both graph tables.
///
/// # Errors
/// `GraphDBError::QueryError` when the `TRUNCATE` statement fails.
async fn delete_graph(&self) -> GraphDBResult<()> {
    let outcome = self
        .db
        .execute_unprepared("TRUNCATE graph_edge, graph_node CASCADE")
        .await;
    match outcome {
        Ok(_) => Ok(()),
        Err(e) => Err(GraphDBError::QueryError(format!(
            "Failed to truncate graph: {e}"
        ))),
    }
}
/// Reports whether a node with the given `node_id` exists.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or its result cannot be read.
async fn has_node(&self, node_id: &str) -> GraphDBResult<bool> {
    let mut matching = Query::select();
    matching
        .expr(Expr::val(1))
        .from(GNode::Table)
        .and_where(Expr::col(GNode::Id).eq(node_id));

    // SELECT EXISTS(<matching>) AS ex
    let mut probe = Query::select();
    probe.expr_as(Expr::exists(matching), Alias::new("ex"));

    let answer = self
        .db
        .query_one(self.build(&probe))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let Some(answer) = answer else {
        return Ok(false);
    };
    answer
        .try_get::<bool>("", "ex")
        .map_err(|e| GraphDBError::QueryError(e.to_string()))
}
/// Upserts a single JSON node by delegating to `add_nodes_raw`.
async fn add_node_raw(&self, node: Value) -> GraphDBResult<()> {
    let single = vec![node];
    self.add_nodes_raw(single).await
}
/// Upserts a batch of JSON nodes into `graph_node`.
///
/// Nodes sharing an `id` within `nodes` collapse to the last occurrence.
/// Rows are written in chunks of `BATCH_SIZE`; on an `id` conflict the
/// existing row's `name`, `type` and `properties` are overwritten and
/// `updated_at` is set to the database's current timestamp.
///
/// # Errors
/// * `GraphDBError::NodeError` when an input is not a JSON object or an insert fails.
async fn add_nodes_raw(&self, nodes: Vec<Value>) -> GraphDBResult<()> {
    if nodes.is_empty() {
        return Ok(());
    }

    // Last occurrence wins for duplicate ids.
    let mut by_id: HashMap<String, NodeRow> = HashMap::new();
    for raw in &nodes {
        let parsed = Self::serialize_node_to_row(raw)?;
        by_id.insert(parsed.id.clone(), parsed);
    }
    let unique: Vec<NodeRow> = by_id.into_values().collect();

    let conflict = OnConflict::column(GNode::Id)
        .update_columns([GNode::Name, GNode::Type, GNode::Properties])
        .value(GNode::UpdatedAt, Expr::current_timestamp())
        .to_owned();

    for batch in unique.chunks(BATCH_SIZE) {
        let mut stmt = Query::insert();
        stmt.into_table(GNode::Table).columns([
            GNode::Id,
            GNode::Name,
            GNode::Type,
            GNode::Properties,
            GNode::CreatedAt,
            GNode::UpdatedAt,
        ]);
        for node in batch {
            stmt.values_panic([
                node.id.clone().into(),
                node.name.clone().into(),
                node.node_type.clone().into(),
                node.properties.clone().into(),
                node.created_at.into(),
                node.updated_at.into(),
            ]);
        }
        stmt.on_conflict(conflict.clone());

        self.db
            .execute(self.build(&stmt))
            .await
            .map_err(|e| GraphDBError::NodeError(format!("Failed to upsert nodes: {e}")))?;
    }
    Ok(())
}
/// Deletes the node with the given `node_id`, if present.
///
/// # Errors
/// `GraphDBError::NodeError` when the delete statement fails.
async fn delete_node(&self, node_id: &str) -> GraphDBResult<()> {
    let mut removal = Query::delete();
    removal
        .from_table(GNode::Table)
        .and_where(Expr::col(GNode::Id).eq(node_id));

    match self.db.execute(self.build(&removal)).await {
        Ok(_) => Ok(()),
        Err(e) => Err(GraphDBError::NodeError(format!(
            "Failed to delete node: {e}"
        ))),
    }
}
/// Deletes every node whose id appears in `node_ids`. An empty slice is a no-op.
///
/// # Errors
/// `GraphDBError::NodeError` when the delete statement fails.
async fn delete_nodes(&self, node_ids: &[String]) -> GraphDBResult<()> {
    if node_ids.is_empty() {
        return Ok(());
    }

    let mut removal = Query::delete();
    removal
        .from_table(GNode::Table)
        .and_where(Expr::col(GNode::Id).is_in(node_ids.iter().map(String::as_str)));

    match self.db.execute(self.build(&removal)).await {
        Ok(_) => Ok(()),
        Err(e) => Err(GraphDBError::NodeError(format!(
            "Failed to delete nodes: {e}"
        ))),
    }
}
/// Fetches the node with the given `node_id`, or `None` when no such row exists.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or the row cannot be parsed.
async fn get_node(&self, node_id: &str) -> GraphDBResult<Option<NodeData>> {
    let mut select = Query::select();
    select
        .columns([GNode::Id, GNode::Name, GNode::Type, GNode::Properties])
        .from(GNode::Table)
        .and_where(Expr::col(GNode::Id).eq(node_id));

    let found = self
        .db
        .query_one(self.build(&select))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    found.as_ref().map(Self::parse_node_row).transpose()
}
/// Fetches all nodes whose id appears in `node_ids`.
///
/// Only rows that exist are returned; an empty slice yields an empty vector
/// without querying.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or a row cannot be parsed.
async fn get_nodes(&self, node_ids: &[String]) -> GraphDBResult<Vec<NodeData>> {
    if node_ids.is_empty() {
        return Ok(Vec::new());
    }

    let mut select = Query::select();
    select
        .columns([GNode::Id, GNode::Name, GNode::Type, GNode::Properties])
        .from(GNode::Table)
        .and_where(Expr::col(GNode::Id).is_in(node_ids.iter().map(String::as_str)));

    let rows = self
        .db
        .query_all(self.build(&select))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let mut nodes = Vec::with_capacity(rows.len());
    for row in &rows {
        nodes.push(Self::parse_node_row(row)?);
    }
    Ok(nodes)
}
/// Reports whether an edge with exactly this source, target and relationship
/// name exists.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or its result cannot be read.
async fn has_edge(
    &self,
    source_id: &str,
    target_id: &str,
    relationship_name: &str,
) -> GraphDBResult<bool> {
    let mut matching = Query::select();
    matching
        .expr(Expr::val(1))
        .from(GEdge::Table)
        .and_where(Expr::col(GEdge::SourceId).eq(source_id))
        .and_where(Expr::col(GEdge::TargetId).eq(target_id))
        .and_where(Expr::col(GEdge::RelationshipName).eq(relationship_name));

    // SELECT EXISTS(<matching>) AS ex
    let mut probe = Query::select();
    probe.expr_as(Expr::exists(matching), Alias::new("ex"));

    let answer = self
        .db
        .query_one(self.build(&probe))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let Some(answer) = answer else {
        return Ok(false);
    };
    answer
        .try_get::<bool>("", "ex")
        .map_err(|e| GraphDBError::QueryError(e.to_string()))
}
/// Returns the subset of `edges` (in input order, with their input properties)
/// whose `(source, target, relationship)` triple exists in `graph_edge`.
///
/// All triples are checked in one query by binding three parallel text arrays.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or a result row cannot be read.
async fn has_edges(&self, edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>> {
    const EXISTING_SQL: &str = "SELECT v.s, v.t, v.r \
        FROM unnest($1::text[], $2::text[], $3::text[]) AS v(s, t, r) \
        WHERE EXISTS ( \
        SELECT 1 FROM graph_edge e \
        WHERE e.source_id = v.s \
        AND e.target_id = v.t \
        AND e.relationship_name = v.r \
        )";

    if edges.is_empty() {
        return Ok(Vec::new());
    }

    let mut sources: Vec<String> = Vec::with_capacity(edges.len());
    let mut targets: Vec<String> = Vec::with_capacity(edges.len());
    let mut rels: Vec<String> = Vec::with_capacity(edges.len());
    for edge in edges {
        sources.push(edge.0.clone());
        targets.push(edge.1.clone());
        rels.push(edge.2.clone());
    }

    let rows = self
        .db
        .query_all(Statement::from_sql_and_values(
            DatabaseBackend::Postgres,
            EXISTING_SQL,
            [sources.into(), targets.into(), rels.into()],
        ))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let mut present: HashSet<(String, String, String)> = HashSet::with_capacity(rows.len());
    for row in &rows {
        let column = |name: &str| -> GraphDBResult<String> {
            row.try_get::<String>("", name)
                .map_err(|e| GraphDBError::QueryError(e.to_string()))
        };
        let source = column("s")?;
        let target = column("t")?;
        let relationship = column("r")?;
        present.insert((source, target, relationship));
    }

    let mut found = Vec::new();
    for edge in edges {
        let key = (edge.0.clone(), edge.1.clone(), edge.2.clone());
        if present.contains(&key) {
            found.push(edge.clone());
        }
    }
    Ok(found)
}
/// Upserts a single edge by delegating to `add_edges`; `None` properties are
/// stored as an empty map.
async fn add_edge(
    &self,
    source_id: &str,
    target_id: &str,
    relationship_name: &str,
    properties: Option<HashMap<Cow<'static, str>, Value>>,
) -> GraphDBResult<()> {
    let edge: EdgeData = (
        source_id.to_owned(),
        target_id.to_owned(),
        relationship_name.to_owned(),
        properties.unwrap_or_default(),
    );
    self.add_edges(&[edge]).await
}
/// Upserts a batch of edges into `graph_edge`.
///
/// Edges sharing a `(source, target, relationship)` triple within `edges`
/// collapse to the last occurrence. Rows are written in chunks of
/// `BATCH_SIZE`; on a key conflict the existing row's `properties` are
/// overwritten and `updated_at` is set to the database's current timestamp.
///
/// # Errors
/// * `GraphDBError::SerializationError` when an edge's properties cannot be
///   converted to JSON.
/// * `GraphDBError::EdgeError` when an insert fails.
async fn add_edges(&self, edges: &[EdgeData]) -> GraphDBResult<()> {
    if edges.is_empty() {
        return Ok(());
    }

    // A single timestamp is used for every row of this call.
    let stamp = Utc::now();

    // Last occurrence wins for duplicate keys.
    let mut latest: HashMap<(String, String, String), &EdgeData> = HashMap::new();
    for edge in edges {
        let key = (edge.0.clone(), edge.1.clone(), edge.2.clone());
        latest.insert(key, edge);
    }
    let unique: Vec<&EdgeData> = latest.into_values().collect();

    let conflict =
        OnConflict::columns([GEdge::SourceId, GEdge::TargetId, GEdge::RelationshipName])
            .update_column(GEdge::Properties)
            .value(GEdge::UpdatedAt, Expr::current_timestamp())
            .to_owned();

    for batch in unique.chunks(BATCH_SIZE) {
        let mut stmt = Query::insert();
        stmt.into_table(GEdge::Table).columns([
            GEdge::SourceId,
            GEdge::TargetId,
            GEdge::RelationshipName,
            GEdge::Properties,
            GEdge::CreatedAt,
            GEdge::UpdatedAt,
        ]);
        for edge in batch {
            let props_json =
                serde_json::to_value(&edge.3).map_err(GraphDBError::SerializationError)?;
            stmt.values_panic([
                edge.0.clone().into(),
                edge.1.clone().into(),
                edge.2.clone().into(),
                props_json.into(),
                stamp.into(),
                stamp.into(),
            ]);
        }
        stmt.on_conflict(conflict.clone());

        self.db
            .execute(self.build(&stmt))
            .await
            .map_err(|e| GraphDBError::EdgeError(format!("Failed to upsert edges: {e}")))?;
    }
    Ok(())
}
/// Returns every edge that has `node_id` as its source or its target.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or a row cannot be parsed.
async fn get_edges(&self, node_id: &str) -> GraphDBResult<Vec<EdgeData>> {
    let touches_node = Cond::any()
        .add(Expr::col(GEdge::SourceId).eq(node_id))
        .add(Expr::col(GEdge::TargetId).eq(node_id));

    let mut select = Query::select();
    select
        .columns([
            GEdge::SourceId,
            GEdge::TargetId,
            GEdge::RelationshipName,
            GEdge::Properties,
        ])
        .from(GEdge::Table)
        .cond_where(touches_node);

    let rows = self
        .db
        .query_all(self.build(&select))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let mut edges = Vec::with_capacity(rows.len());
    for row in &rows {
        edges.push(Self::parse_edge_row(row)?);
    }
    Ok(edges)
}
/// Returns the distinct nodes found at the other end of any edge that has
/// `node_id` as its source or target.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or a row cannot be parsed.
async fn get_neighbors(&self, node_id: &str) -> GraphDBResult<Vec<NodeData>> {
    const NEIGHBOR_SQL: &str = "SELECT DISTINCT m.id, m.name, m.type, m.properties \
        FROM graph_edge e \
        JOIN graph_node m ON m.id = CASE \
        WHEN e.source_id = $1 THEN e.target_id \
        ELSE e.source_id \
        END \
        WHERE e.source_id = $1 OR e.target_id = $1";

    let statement = Statement::from_sql_and_values(
        DatabaseBackend::Postgres,
        NEIGHBOR_SQL,
        [node_id.into()],
    );
    let rows = self
        .db
        .query_all(statement)
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let mut neighbors = Vec::with_capacity(rows.len());
    for row in &rows {
        neighbors.push(Self::parse_node_row(row)?);
    }
    Ok(neighbors)
}
async fn get_connections(
&self,
node_id: &str,
) -> GraphDBResult<Vec<(NodeData, HashMap<Cow<'static, str>, Value>, NodeData)>> {
let rows = self
.db
.query_all(Statement::from_sql_and_values(
DatabaseBackend::Postgres,
"SELECT \
n.id AS src_id, n.name AS src_name, n.type AS src_type, n.properties AS src_props, \
e.relationship_name, e.properties AS edge_props, \
m.id AS tgt_id, m.name AS tgt_name, m.type AS tgt_type, m.properties AS tgt_props \
FROM graph_edge e \
JOIN graph_node n ON n.id = e.source_id \
JOIN graph_node m ON m.id = e.target_id \
WHERE e.source_id = $1 OR e.target_id = $1",
[node_id.into()],
))
.await
.map_err(|e| GraphDBError::QueryError(e.to_string()))?;
let mut connections = Vec::new();
for row in &rows {
let mut source = NodeData::new();
let src_id: String = row.try_get("", "src_id").unwrap_or_default();
let src_name: String = row.try_get("", "src_name").unwrap_or_default();
let src_type: String = row.try_get("", "src_type").unwrap_or_default();
let src_props: Option<Value> = row.try_get("", "src_props").unwrap_or(None);
source.insert(Cow::Borrowed("id"), json!(src_id));
source.insert(Cow::Borrowed("name"), json!(src_name));
source.insert(Cow::Borrowed("type"), json!(src_type));
if let Some(Value::Object(extra)) = src_props {
for (k, v) in extra {
source.insert(Cow::Owned(k), v);
}
}
let mut edge_props_map: HashMap<Cow<'static, str>, Value> = HashMap::new();
let rel_name: String = row.try_get("", "relationship_name").unwrap_or_default();
edge_props_map.insert(Cow::Borrowed("relationship_name"), json!(rel_name));
let edge_props_raw: Option<Value> = row.try_get("", "edge_props").unwrap_or(None);
if let Some(Value::Object(extra)) = edge_props_raw {
for (k, v) in extra {
edge_props_map.insert(Cow::Owned(k), v);
}
}
let mut target = NodeData::new();
let tgt_id: String = row.try_get("", "tgt_id").unwrap_or_default();
let tgt_name: String = row.try_get("", "tgt_name").unwrap_or_default();
let tgt_type: String = row.try_get("", "tgt_type").unwrap_or_default();
let tgt_props: Option<Value> = row.try_get("", "tgt_props").unwrap_or(None);
target.insert(Cow::Borrowed("id"), json!(tgt_id));
target.insert(Cow::Borrowed("name"), json!(tgt_name));
target.insert(Cow::Borrowed("type"), json!(tgt_type));
if let Some(Value::Object(extra)) = tgt_props {
for (k, v) in extra {
target.insert(Cow::Owned(k), v);
}
}
connections.push((source, edge_props_map, target));
}
Ok(connections)
}
/// Loads the whole graph: every node as `(id, data)` and every edge.
///
/// # Errors
/// `GraphDBError::QueryError` when either query fails or a row cannot be parsed.
async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
    let mut node_select = Query::select();
    node_select
        .columns([GNode::Id, GNode::Name, GNode::Type, GNode::Properties])
        .from(GNode::Table);
    let node_rows = self
        .db
        .query_all(self.build(&node_select))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let nodes = node_rows
        .iter()
        .map(|row| -> GraphDBResult<GraphNode> {
            let data = Self::parse_node_row(row)?;
            let id = data
                .get("id")
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string();
            Ok((id, data))
        })
        .collect::<GraphDBResult<Vec<GraphNode>>>()?;

    let mut edge_select = Query::select();
    edge_select
        .columns([
            GEdge::SourceId,
            GEdge::TargetId,
            GEdge::RelationshipName,
            GEdge::Properties,
        ])
        .from(GEdge::Table);
    let edge_rows = self
        .db
        .query_all(self.build(&edge_select))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let edges = edge_rows
        .iter()
        .map(Self::parse_edge_row)
        .collect::<GraphDBResult<Vec<EdgeData>>>()?;

    Ok((nodes, edges))
}
async fn get_graph_metrics(
&self,
include_optional: bool,
) -> GraphDBResult<HashMap<Cow<'static, str>, Value>> {
let mut metrics = HashMap::new();
let n_row = self
.db
.query_one(Statement::from_string(
DatabaseBackend::Postgres,
"SELECT count(*) AS cnt FROM graph_node".to_string(),
))
.await
.map_err(|e| GraphDBError::QueryError(e.to_string()))?;
let num_nodes: i64 = n_row
.as_ref()
.and_then(|r| r.try_get("", "cnt").ok())
.unwrap_or(0);
let e_row = self
.db
.query_one(Statement::from_string(
DatabaseBackend::Postgres,
"SELECT count(*) AS cnt FROM graph_edge".to_string(),
))
.await
.map_err(|e| GraphDBError::QueryError(e.to_string()))?;
let num_edges: i64 = e_row
.as_ref()
.and_then(|r| r.try_get("", "cnt").ok())
.unwrap_or(0);
metrics.insert(Cow::Borrowed("node_count"), json!(num_nodes));
metrics.insert(Cow::Borrowed("edge_count"), json!(num_edges));
let mean_degree = if num_nodes > 0 {
(2.0 * num_edges as f64) / num_nodes as f64
} else {
0.0
};
let edge_density = if num_nodes > 1 {
num_edges as f64 / (num_nodes as f64 * (num_nodes as f64 - 1.0))
} else {
0.0
};
metrics.insert(Cow::Borrowed("mean_degree"), json!(mean_degree));
metrics.insert(Cow::Borrowed("edge_density"), json!(edge_density));
let comp_rows = self
.db
.query_all(Statement::from_string(
DatabaseBackend::Postgres,
"WITH RECURSIVE component AS ( \
SELECT id AS node_id, id AS comp_root FROM graph_node \
UNION \
SELECT CASE WHEN e.source_id = c.node_id THEN e.target_id ELSE e.source_id END, \
c.comp_root \
FROM component c \
JOIN graph_edge e ON e.source_id = c.node_id OR e.target_id = c.node_id \
), \
node_comp AS ( \
SELECT node_id, MIN(comp_root) AS comp_id FROM component GROUP BY node_id \
) \
SELECT comp_id, count(*) AS sz FROM node_comp GROUP BY comp_id ORDER BY sz DESC"
.to_string(),
))
.await
.map_err(|e| GraphDBError::QueryError(e.to_string()))?;
let component_sizes: Vec<Value> = comp_rows
.iter()
.filter_map(|r| {
let sz: i64 = r.try_get("", "sz").ok()?;
Some(json!(sz))
})
.collect();
let num_components = component_sizes.len();
metrics.insert(
Cow::Borrowed("num_connected_components"),
json!(num_components),
);
metrics.insert(
Cow::Borrowed("sizes_of_connected_components"),
Value::Array(component_sizes),
);
if include_optional {
let sl_row = self
.db
.query_one(Statement::from_string(
DatabaseBackend::Postgres,
"SELECT count(*) AS cnt FROM graph_edge WHERE source_id = target_id"
.to_string(),
))
.await
.map_err(|e| GraphDBError::QueryError(e.to_string()))?;
let num_selfloops: i64 = sl_row
.as_ref()
.and_then(|r| r.try_get("", "cnt").ok())
.unwrap_or(0);
metrics.insert(Cow::Borrowed("num_selfloops"), json!(num_selfloops));
}
Ok(metrics)
}
/// Returns the nodes matching every attribute filter, plus the edges whose
/// source and target are both among those nodes.
///
/// Each entry maps an attribute (one of `ALLOWED_FILTER_ATTRS`) to its accepted
/// values; entries are combined with `AND`. Entries with no values are skipped,
/// and if no usable filter remains the whole graph is returned. Non-string
/// filter values are compared using their JSON text form.
///
/// # Errors
/// `GraphDBError::QueryError` for an attribute outside the allow-list, a failed
/// query, or a row that cannot be parsed.
async fn get_filtered_graph_data(
    &self,
    attribute_filters: &HashMap<Cow<'static, str>, Vec<Value>>,
) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
    if attribute_filters.is_empty() {
        return self.get_graph_data().await;
    }

    let mut predicates: Vec<String> = Vec::new();
    let mut binds: Vec<sea_orm::Value> = Vec::new();
    for (attr, accepted) in attribute_filters {
        if accepted.is_empty() {
            continue;
        }
        // The attribute name is spliced into the SQL text, so it must be allow-listed.
        if !ALLOWED_FILTER_ATTRS.contains(&attr.as_ref()) {
            return Err(GraphDBError::QueryError(format!(
                "Invalid filter attribute: {attr:?}. Allowed: {ALLOWED_FILTER_ATTRS:?}"
            )));
        }

        let mut slots: Vec<String> = Vec::with_capacity(accepted.len());
        for candidate in accepted {
            let text = match candidate.as_str() {
                Some(s) => String::from(s),
                None => candidate.to_string(),
            };
            binds.push(text.into());
            // Placeholders are 1-based, so the new length is this value's index.
            slots.push(format!("${}", binds.len()));
        }
        predicates.push(format!("n.{attr} IN ({})", slots.join(", ")));
    }

    if predicates.is_empty() {
        return self.get_graph_data().await;
    }

    let where_clause = predicates.join(" AND ");
    let sql = format!(
        "WITH filtered_nodes AS ( \
        SELECT id, name, type, properties FROM graph_node n WHERE {where_clause} \
        ) \
        SELECT 'node' AS kind, fn.id, fn.name, fn.type, fn.properties, \
        NULL::text AS source_id, NULL::text AS target_id, \
        NULL::text AS relationship_name, NULL::jsonb AS edge_props \
        FROM filtered_nodes fn \
        UNION ALL \
        SELECT 'edge', NULL, NULL, NULL, NULL, \
        e.source_id, e.target_id, e.relationship_name, e.properties \
        FROM graph_edge e \
        WHERE e.source_id IN (SELECT id FROM filtered_nodes) \
        AND e.target_id IN (SELECT id FROM filtered_nodes)"
    );

    let rows = self
        .db
        .query_all(Statement::from_sql_and_values(
            DatabaseBackend::Postgres,
            &sql,
            binds,
        ))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let mut nodes = Vec::new();
    let mut edges = Vec::new();
    for row in &rows {
        let kind: String = row.try_get("", "kind").unwrap_or_default();
        // Anything not tagged 'node' is handled as an edge row.
        if kind != "node" {
            edges.push(Self::parse_edge_row_cols(
                row,
                "source_id",
                "target_id",
                "relationship_name",
                "edge_props",
            )?);
            continue;
        }
        let data = Self::parse_node_row(row)?;
        let id = data
            .get("id")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        nodes.push((id, data));
    }
    Ok((nodes, edges))
}
/// Returns the subgraph around the "primary" nodes of type `node_type` whose
/// name is in `node_names`: the primary nodes, their selected neighbors, and
/// every edge with both endpoints in that combined set.
///
/// With `node_name_filter_operator == "OR"` a neighbor qualifies when it shares
/// an edge with any primary node. For any other operator value a neighbor
/// qualifies only when the number of distinct primary nodes it is connected to
/// equals `node_names.len()`.
///
/// An empty `node_names` yields an empty result without querying.
///
/// # Errors
/// `GraphDBError::QueryError` when the query fails or a row cannot be parsed.
async fn get_nodeset_subgraph(
    &self,
    node_type: &str,
    node_names: &[String],
    node_name_filter_operator: &str,
) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
    if node_names.is_empty() {
        return Ok((Vec::new(), Vec::new()));
    }
    let match_any = node_name_filter_operator == "OR";

    // Parameter layout: $1 = node type, $2.. = names, then (non-OR only) the name count.
    let names_in = (2..node_names.len() + 2)
        .map(|slot| format!("${slot}"))
        .collect::<Vec<String>>()
        .join(", ");

    let neighbor_cte = if match_any {
        String::from(
            "neighbor_ids AS ( \
            SELECT DISTINCT CASE \
            WHEN e.source_id IN (SELECT id FROM primary_nodes) \
            THEN e.target_id ELSE e.source_id \
            END AS id \
            FROM graph_edge e \
            WHERE e.source_id IN (SELECT id FROM primary_nodes) \
            OR e.target_id IN (SELECT id FROM primary_nodes) \
            )",
        )
    } else {
        let primary_count_param = format!("${}", node_names.len() + 2);
        format!(
            "neighbor_ids AS ( \
            SELECT nbr_id AS id FROM ( \
            SELECT CASE \
            WHEN e.source_id IN (SELECT id FROM primary_nodes) \
            THEN e.target_id ELSE e.source_id \
            END AS nbr_id, \
            CASE \
            WHEN e.source_id IN (SELECT id FROM primary_nodes) \
            THEN e.source_id ELSE e.target_id \
            END AS primary_id \
            FROM graph_edge e \
            WHERE e.source_id IN (SELECT id FROM primary_nodes) \
            OR e.target_id IN (SELECT id FROM primary_nodes) \
            ) sub \
            GROUP BY nbr_id \
            HAVING COUNT(DISTINCT primary_id) = {primary_count_param} \
            )"
        )
    };

    let sql = format!(
        "WITH primary_nodes AS ( \
        SELECT DISTINCT id FROM graph_node WHERE type = $1 AND name IN ({names_in}) \
        ), \
        {neighbor_cte}, \
        all_ids AS ( \
        SELECT id FROM primary_nodes UNION SELECT id FROM neighbor_ids \
        ) \
        SELECT 'node' AS kind, n.id, n.name, n.type, n.properties, \
        NULL::text AS source_id, NULL::text AS target_id, \
        NULL::text AS relationship_name, NULL::jsonb AS edge_props \
        FROM graph_node n WHERE n.id IN (SELECT id FROM all_ids) \
        UNION ALL \
        SELECT 'edge', NULL, NULL, NULL, NULL, \
        e.source_id, e.target_id, e.relationship_name, e.properties \
        FROM graph_edge e \
        WHERE e.source_id IN (SELECT id FROM all_ids) \
        AND e.target_id IN (SELECT id FROM all_ids)"
    );

    let mut binds: Vec<sea_orm::Value> = Vec::with_capacity(node_names.len() + 2);
    binds.push(node_type.into());
    for name in node_names {
        binds.push(name.clone().into());
    }
    if !match_any {
        binds.push((node_names.len() as i64).into());
    }

    let rows = self
        .db
        .query_all(Statement::from_sql_and_values(
            DatabaseBackend::Postgres,
            &sql,
            binds,
        ))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;

    let mut nodes = Vec::new();
    let mut edges = Vec::new();
    for row in &rows {
        let kind: String = row.try_get("", "kind").unwrap_or_default();
        // Anything not tagged 'node' is handled as an edge row.
        if kind != "node" {
            edges.push(Self::parse_edge_row_cols(
                row,
                "source_id",
                "target_id",
                "relationship_name",
                "edge_props",
            )?);
            continue;
        }
        let data = Self::parse_node_row(row)?;
        let id = data
            .get("id")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        nodes.push((id, data));
    }
    Ok((nodes, edges))
}
/// Returns the nodes whose id is in `node_ids`, plus the edges whose source
/// and target are both in `node_ids`.
///
/// The ids are bound as a single `text[]` parameter and matched with
/// `= ANY(...)`. Compared with one placeholder per id, this keeps the SQL text
/// constant for every call and avoids exceeding PostgreSQL's bind-parameter
/// limit when `node_ids` is large.
///
/// An empty `node_ids` yields an empty result without querying.
///
/// # Errors
/// `GraphDBError::QueryError` when either query fails or a row cannot be parsed.
async fn get_id_filtered_graph_data(
    &self,
    node_ids: &[String],
) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
    const NODE_SQL: &str = "SELECT id, name, type, properties FROM graph_node \
        WHERE id = ANY($1::text[])";
    const EDGE_SQL: &str =
        "SELECT source_id, target_id, relationship_name, properties FROM graph_edge \
        WHERE source_id = ANY($1::text[]) AND target_id = ANY($1::text[])";

    if node_ids.is_empty() {
        return Ok((vec![], vec![]));
    }

    // One array-valued bind parameter shared by both statements.
    let ids: sea_orm::Value = node_ids.to_vec().into();

    let node_rows = self
        .db
        .query_all(Statement::from_sql_and_values(
            DatabaseBackend::Postgres,
            NODE_SQL,
            [ids.clone()],
        ))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;
    let mut nodes = Vec::with_capacity(node_rows.len());
    for row in &node_rows {
        let data = Self::parse_node_row(row)?;
        let id = data
            .get("id")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();
        nodes.push((id, data));
    }

    let edge_rows = self
        .db
        .query_all(Statement::from_sql_and_values(
            DatabaseBackend::Postgres,
            EDGE_SQL,
            [ids],
        ))
        .await
        .map_err(|e| GraphDBError::QueryError(e.to_string()))?;
    let mut edges = Vec::with_capacity(edge_rows.len());
    for row in &edge_rows {
        edges.push(Self::parse_edge_row(row)?);
    }

    Ok((nodes, edges))
}
}
/// Schema migrations for the `graph_node` / `graph_edge` tables.
mod migrator {
    use sea_orm_migration::prelude::*;

    /// Migration runner exposing the graph schema migrations.
    pub struct Migrator;

    #[async_trait::async_trait]
    impl MigratorTrait for Migrator {
        /// Lists the migrations in application order.
        fn migrations() -> Vec<Box<dyn MigrationTrait>> {
            let create_tables: Box<dyn MigrationTrait> = Box::new(CreateGraphTables);
            vec![create_tables]
        }
    }

    /// Creates the node and edge tables together with their indexes.
    struct CreateGraphTables;

    impl MigrationName for CreateGraphTables {
        fn name(&self) -> &str {
            "m20250101_000001_create_graph_tables"
        }
    }

    #[async_trait::async_trait]
    impl MigrationTrait for CreateGraphTables {
        /// Runs the DDL statements in order, stopping at the first failure.
        /// Every statement uses `IF NOT EXISTS`.
        async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
            // Order matters: `graph_edge` references `graph_node`, and the
            // indexes need their tables.
            const UP_STATEMENTS: [&str; 5] = [
                "CREATE TABLE IF NOT EXISTS graph_node ( \
                id VARCHAR PRIMARY KEY, \
                name VARCHAR NOT NULL DEFAULT '', \
                type VARCHAR NOT NULL DEFAULT '', \
                properties JSONB NOT NULL DEFAULT '{}', \
                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), \
                updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() \
                )",
                "CREATE INDEX IF NOT EXISTS idx_graph_node_type ON graph_node(type)",
                "CREATE TABLE IF NOT EXISTS graph_edge ( \
                source_id VARCHAR NOT NULL REFERENCES graph_node(id) ON DELETE CASCADE, \
                target_id VARCHAR NOT NULL REFERENCES graph_node(id) ON DELETE CASCADE, \
                relationship_name VARCHAR NOT NULL, \
                properties JSONB NOT NULL DEFAULT '{}', \
                created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), \
                updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), \
                PRIMARY KEY (source_id, target_id, relationship_name) \
                )",
                "CREATE INDEX IF NOT EXISTS idx_graph_edge_source_cover \
                ON graph_edge(source_id) INCLUDE (target_id, relationship_name)",
                "CREATE INDEX IF NOT EXISTS idx_graph_edge_target_cover \
                ON graph_edge(target_id) INCLUDE (source_id, relationship_name)",
            ];

            let conn = manager.get_connection();
            for ddl in UP_STATEMENTS {
                conn.execute_unprepared(ddl).await?;
            }
            Ok(())
        }

        /// Drops the edge table first, then the node table it references.
        async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
            const DOWN_STATEMENTS: [&str; 2] = [
                "DROP TABLE IF EXISTS graph_edge",
                "DROP TABLE IF EXISTS graph_node",
            ];

            let conn = manager.get_connection();
            for ddl in DOWN_STATEMENTS {
                conn.execute_unprepared(ddl).await?;
            }
            Ok(())
        }
    }
}