use async_trait::async_trait;
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use crate::{EdgeData, GraphDBResult, GraphNode, NodeData};
pub type EdgeKey = (String, String, String);
#[async_trait]
pub trait GraphDBTrait: Send + Sync {
async fn initialize(&self) -> GraphDBResult<()>;
async fn is_empty(&self) -> GraphDBResult<bool>;
async fn query(
&self,
query: &str,
params: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
) -> GraphDBResult<Vec<Vec<serde_json::Value>>>;
async fn delete_graph(&self) -> GraphDBResult<()>;
async fn has_node(&self, node_id: &str) -> GraphDBResult<bool>;
async fn add_node_raw(&self, node: Value) -> GraphDBResult<()>;
async fn add_nodes_raw(&self, nodes: Vec<Value>) -> GraphDBResult<()>;
async fn delete_node(&self, node_id: &str) -> GraphDBResult<()>;
async fn delete_nodes(&self, node_ids: &[String]) -> GraphDBResult<()>;
async fn get_node(&self, node_id: &str) -> GraphDBResult<Option<NodeData>>;
async fn get_nodes(&self, node_ids: &[String]) -> GraphDBResult<Vec<NodeData>>;
async fn has_edge(
&self,
source_id: &str,
target_id: &str,
relationship_name: &str,
) -> GraphDBResult<bool>;
async fn has_edges(&self, edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>>;
async fn add_edge(
&self,
source_id: &str,
target_id: &str,
relationship_name: &str,
properties: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
) -> GraphDBResult<()>;
async fn add_edges(&self, edges: &[EdgeData]) -> GraphDBResult<()>;
async fn get_edges(&self, node_id: &str) -> GraphDBResult<Vec<EdgeData>>;
async fn get_neighbors(&self, node_id: &str) -> GraphDBResult<Vec<NodeData>>;
async fn get_connections(
&self,
node_id: &str,
) -> GraphDBResult<
Vec<(
NodeData,
HashMap<Cow<'static, str>, serde_json::Value>,
NodeData,
)>,
>;
async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)>;
async fn get_graph_metrics(
&self,
include_optional: bool,
) -> GraphDBResult<HashMap<Cow<'static, str>, serde_json::Value>>;
async fn get_filtered_graph_data(
&self,
attribute_filters: &HashMap<Cow<'static, str>, Vec<serde_json::Value>>,
) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)>;
async fn get_nodeset_subgraph(
&self,
node_type: &str,
node_names: &[String],
node_name_filter_operator: &str,
) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)>;
async fn get_degree_one_nodes(&self, node_type: &str) -> GraphDBResult<Vec<crate::GraphNode>> {
let (nodes, edges) = self.get_graph_data().await?;
let mut degree: HashMap<String, usize> = HashMap::new();
for (src, tgt, _, _) in &edges {
*degree.entry(src.clone()).or_default() += 1;
*degree.entry(tgt.clone()).or_default() += 1;
}
Ok(nodes
.into_iter()
.filter(|(id, props)| {
let type_matches = props
.get("type")
.and_then(|v| v.as_str())
.is_some_and(|t| t == node_type);
let deg = degree.get(id).copied().unwrap_or(0);
type_matches && deg == 1
})
.collect())
}
async fn get_all_relationship_names(&self) -> GraphDBResult<HashSet<String>> {
let (_, edges) = self.get_graph_data().await?;
Ok(edges.into_iter().map(|(_, _, rel, _)| rel).collect())
}
async fn get_zero_degree_edge_type_nodes(&self) -> GraphDBResult<Vec<crate::GraphNode>> {
let (nodes, edges) = self.get_graph_data().await?;
let active_rel_names: HashSet<&str> =
edges.iter().map(|(_, _, rel, _)| rel.as_str()).collect();
let mut degree: HashMap<String, usize> = HashMap::new();
for (src, tgt, _, _) in &edges {
*degree.entry(src.clone()).or_default() += 1;
*degree.entry(tgt.clone()).or_default() += 1;
}
Ok(nodes
.into_iter()
.filter(|(id, props)| {
let is_edge_type = props
.get("type")
.and_then(|v| v.as_str())
.is_some_and(|t| t == "EdgeType");
if !is_edge_type {
return false;
}
let deg = degree.get(id).copied().unwrap_or(0);
if deg > 0 {
return false;
}
let rel_name = props
.get("relationship_name")
.and_then(|v| v.as_str())
.unwrap_or("");
!active_rel_names.contains(rel_name)
})
.collect())
}
async fn update_node_property(
&self,
node_id: &str,
key: &str,
value: serde_json::Value,
) -> GraphDBResult<()> {
let node = self
.get_node(node_id)
.await?
.ok_or_else(|| crate::GraphDBError::NodeError(format!("Node not found: {node_id}")))?;
let edges = self.get_edges(node_id).await.unwrap_or_default();
let mut props = serde_json::Map::new();
for (k, v) in node {
props.insert(k.into_owned(), v);
}
props.insert(key.to_string(), value);
self.delete_node(node_id).await?;
self.add_node_raw(Value::Object(props)).await?;
if !edges.is_empty() {
self.add_edges(&edges).await?;
}
Ok(())
}
async fn update_edge_property(
&self,
source_id: &str,
target_id: &str,
relationship_name: &str,
key: &str,
value: serde_json::Value,
) -> GraphDBResult<()> {
let _ = (source_id, target_id, relationship_name, key, value);
tracing::warn!(
"update_edge_property not implemented for this backend; \
edge {source_id} -> {target_id} ({relationship_name}) property {key} not updated"
);
Ok(())
}
async fn get_node_feedback_weights(
&self,
node_ids: &[String],
) -> GraphDBResult<HashMap<String, f64>> {
let mut out = HashMap::with_capacity(node_ids.len());
for id in node_ids {
if let Some(node) = self.get_node(id).await?
&& let Some(v) = node.get("feedback_weight").and_then(|v| v.as_f64())
{
out.insert(id.clone(), v);
}
}
Ok(out)
}
async fn set_node_feedback_weights(
&self,
updates: &HashMap<String, f64>,
) -> GraphDBResult<HashMap<String, bool>> {
let mut out = HashMap::with_capacity(updates.len());
for (id, w) in updates {
let ok = self
.update_node_property(id, "feedback_weight", serde_json::json!(w))
.await
.is_ok();
out.insert(id.clone(), ok);
}
Ok(out)
}
async fn get_edge_feedback_weights(
&self,
edge_keys: &[EdgeKey],
) -> GraphDBResult<HashMap<EdgeKey, f64>> {
if !edge_keys.is_empty() {
tracing::warn!(
"get_edge_feedback_weights not implemented for this backend; \
returning empty map for {} edge(s)",
edge_keys.len()
);
}
Ok(HashMap::new())
}
async fn set_edge_feedback_weights(
&self,
updates: &HashMap<EdgeKey, f64>,
) -> GraphDBResult<HashMap<EdgeKey, bool>> {
let mut out = HashMap::with_capacity(updates.len());
for (key, w) in updates {
let ok = self
.update_edge_property(
&key.0,
&key.1,
&key.2,
"feedback_weight",
serde_json::json!(w),
)
.await
.is_ok();
out.insert(key.clone(), ok);
}
Ok(out)
}
async fn get_id_filtered_graph_data(
&self,
node_ids: &[String],
) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
if node_ids.is_empty() {
return Ok((vec![], vec![]));
}
let (all_nodes, all_edges) = self.get_graph_data().await?;
let id_set: std::collections::HashSet<&str> = node_ids.iter().map(String::as_str).collect();
let filtered_nodes: Vec<GraphNode> = all_nodes
.into_iter()
.filter(|(id, _)| id_set.contains(id.as_str()))
.collect();
let filtered_edges: Vec<EdgeData> = all_edges
.into_iter()
.filter(|(src, tgt, _, _)| {
id_set.contains(src.as_str()) && id_set.contains(tgt.as_str())
})
.collect();
Ok((filtered_nodes, filtered_edges))
}
}
/// Typed helpers layered on [`GraphDBTrait`]'s raw JSON insertion methods.
///
/// The blanket impl that follows implements this for every `GraphDBTrait`
/// type, including unsized ones such as trait objects, so it never needs a
/// manual implementation.
#[async_trait]
pub trait GraphDBTraitExt: GraphDBTrait {
    /// Serializes `node` with `serde_json` and passes the result to
    /// [`GraphDBTrait::add_node_raw`].
    ///
    /// # Errors
    /// `GraphDBError::QueryError` when serialization fails; otherwise whatever
    /// `add_node_raw` returns.
    async fn add_node<T: Serialize + Sync>(&self, node: &T) -> GraphDBResult<()> {
        let json = match serde_json::to_value(node) {
            Ok(json) => json,
            Err(err) => {
                return Err(crate::GraphDBError::QueryError(format!(
                    "Failed to serialize node: {err}"
                )));
            }
        };
        self.add_node_raw(json).await
    }

    /// Serializes every element of `nodes` and inserts them all with a single
    /// [`GraphDBTrait::add_nodes_raw`] call.
    ///
    /// Serialization stops at the first failure, in which case nothing is
    /// inserted.
    ///
    /// # Errors
    /// `GraphDBError::QueryError` for the first element that fails to
    /// serialize; otherwise whatever `add_nodes_raw` returns.
    async fn add_nodes<T: Serialize + Sync>(&self, nodes: &[&T]) -> GraphDBResult<()> {
        let mut batch: Vec<Value> = Vec::with_capacity(nodes.len());
        for node in nodes {
            let json = serde_json::to_value(node).map_err(|err| {
                crate::GraphDBError::QueryError(format!("Failed to serialize nodes: {err}"))
            })?;
            batch.push(json);
        }
        self.add_nodes_raw(batch).await
    }
}

impl<G> GraphDBTraitExt for G where G: GraphDBTrait + ?Sized {}