Skip to main content

nova_boot_graphdb/
neo4j.rs

1use crate::{builders::sanitize_symbol, error::GraphDbError, traits::GraphStore, types::*};
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use std::collections::{HashSet, VecDeque};
5
6pub struct Neo4jGraphStore {
7    pub uri: String,
8    pub user: String,
9    pub password: String,
10    pub database: String,
11    client: reqwest::Client,
12}
13
14impl Neo4jGraphStore {
15    pub fn new(
16        uri: impl Into<String>,
17        user: impl Into<String>,
18        password: impl Into<String>,
19    ) -> Self {
20        Self {
21            uri: uri.into(),
22            user: user.into(),
23            password: password.into(),
24            database: "neo4j".to_string(),
25            client: reqwest::Client::new(),
26        }
27    }
28
29    async fn run_cypher(
30        &self,
31        statement: &str,
32        parameters: JsonValue,
33    ) -> Result<JsonValue, GraphDbError> {
34        let endpoint = format!(
35            "{}/db/{}/tx/commit",
36            self.uri.trim_end_matches('/'),
37            self.database
38        );
39
40        let payload = serde_json::json!({
41            "statements": [
42                {
43                    "statement": statement,
44                    "parameters": parameters
45                }
46            ]
47        });
48
49        let resp = self
50            .client
51            .post(endpoint)
52            .basic_auth(&self.user, Some(&self.password))
53            .json(&payload)
54            .send()
55            .await
56            .map_err(|e| GraphDbError::Backend(e.to_string()))?;
57
58        let status = resp.status();
59        let json: JsonValue = resp
60            .json()
61            .await
62            .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
63
64        if !status.is_success() {
65            return Err(GraphDbError::Backend(format!(
66                "neo4j http status {}: {}",
67                status, json
68            )));
69        }
70
71        let errors = json
72            .get("errors")
73            .and_then(JsonValue::as_array)
74            .cloned()
75            .unwrap_or_default();
76
77        if !errors.is_empty() {
78            return Err(GraphDbError::Backend(format!(
79                "neo4j query error: {errors:?}"
80            )));
81        }
82
83        Ok(json)
84    }
85
86    fn first_row(response: &JsonValue) -> Option<JsonValue> {
87        response
88            .get("results")
89            .and_then(JsonValue::as_array)
90            .and_then(|results| results.first())
91            .and_then(|result| result.get("data"))
92            .and_then(JsonValue::as_array)
93            .and_then(|data| data.first())
94            .and_then(|entry| entry.get("row"))
95            .and_then(JsonValue::as_array)
96            .and_then(|row| row.first())
97            .cloned()
98    }
99}
100
101#[async_trait]
102impl GraphStore for Neo4jGraphStore {
103    async fn execute(&self, query: GraphQuery) -> Result<JsonValue, GraphDbError> {
104        let cypher = match query {
105            GraphQuery::Cypher(q) => q,
106            GraphQuery::GraphQl(_) => {
107                return Err(GraphDbError::InvalidInput(
108                    "Neo4j adapter accepts Cypher queries only".to_string(),
109                ));
110            }
111        };
112        self.run_cypher(&cypher, serde_json::json!({})).await
113    }
114
115    async fn upsert_node(&self, node: GraphNode) -> Result<(), GraphDbError> {
116        let labels = if node.labels.is_empty() {
117            "Node".to_string()
118        } else {
119            node.labels
120                .iter()
121                .map(|l| sanitize_symbol(l))
122                .collect::<Vec<_>>()
123                .join(":")
124        };
125
126        let cypher = format!("MERGE (n:{labels} {{id: $id}}) SET n += $props RETURN n.id");
127        let params = serde_json::json!({
128            "id": node.id,
129            "props": node.properties
130        });
131        self.run_cypher(&cypher, params).await.map(|_| ())
132    }
133
134    async fn upsert_edge(&self, edge: GraphEdge) -> Result<(), GraphDbError> {
135        let rel_type = sanitize_symbol(&edge.rel_type);
136        let cypher = format!(
137            "MATCH (a {{id: $from}}), (b {{id: $to}}) MERGE (a)-[r:{rel_type} {{id: $id}}]->(b) SET r += $props RETURN r.id"
138        );
139        let params = serde_json::json!({
140            "id": edge.id,
141            "from": edge.from,
142            "to": edge.to,
143            "props": edge.properties
144        });
145        self.run_cypher(&cypher, params).await.map(|_| ())
146    }
147
148    async fn get_node(&self, node_id: &str) -> Result<Option<GraphNode>, GraphDbError> {
149        let cypher =
150            "MATCH (n {id: $id}) RETURN {id: n.id, labels: labels(n), properties: properties(n)}";
151        let response = self
152            .run_cypher(cypher, serde_json::json!({ "id": node_id }))
153            .await?;
154
155        match Self::first_row(&response) {
156            Some(value) => serde_json::from_value::<GraphNode>(value)
157                .map(Some)
158                .map_err(|e| GraphDbError::Serialization(e.to_string())),
159            None => Ok(None),
160        }
161    }
162
163    async fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, GraphDbError> {
164        let cypher = "MATCH (a {id: $id})-[]->(b) RETURN {id: b.id, labels: labels(b), properties: properties(b)}";
165        let response = self
166            .run_cypher(cypher, serde_json::json!({ "id": node_id }))
167            .await?;
168
169        let rows = response
170            .get("results")
171            .and_then(JsonValue::as_array)
172            .and_then(|results| results.first())
173            .and_then(|result| result.get("data"))
174            .and_then(JsonValue::as_array)
175            .cloned()
176            .unwrap_or_default();
177
178        rows.into_iter()
179            .filter_map(|entry| {
180                entry
181                    .get("row")
182                    .and_then(JsonValue::as_array)
183                    .and_then(|row| row.first())
184                    .cloned()
185            })
186            .map(|value| {
187                serde_json::from_value::<GraphNode>(value)
188                    .map_err(|e| GraphDbError::Serialization(e.to_string()))
189            })
190            .collect::<Result<Vec<_>, _>>()
191    }
192
193    async fn traverse(&self, start: &str, max_depth: usize) -> Result<GraphSubgraph, GraphDbError> {
194        let mut visited = HashSet::new();
195        let mut q = VecDeque::from([(start.to_string(), 0usize)]);
196        let mut nodes = Vec::new();
197        let mut edges = Vec::new();
198
199        while let Some((current, depth)) = q.pop_front() {
200            if !visited.insert(current.clone()) {
201                continue;
202            }
203
204            if let Some(node) = self.get_node(&current).await? {
205                nodes.push(node.clone());
206            }
207
208            if depth >= max_depth {
209                continue;
210            }
211
212            let cypher = "MATCH (a {id: $id})-[r]->(b) RETURN {id: r.id, from: a.id, to: b.id, rel_type: type(r), properties: properties(r)}";
213            let response = self
214                .run_cypher(cypher, serde_json::json!({ "id": current }))
215                .await?;
216
217            let rows = response
218                .get("results")
219                .and_then(JsonValue::as_array)
220                .and_then(|results| results.first())
221                .and_then(|result| result.get("data"))
222                .and_then(JsonValue::as_array)
223                .cloned()
224                .unwrap_or_default();
225
226            for entry in rows {
227                if let Some(value) = entry
228                    .get("row")
229                    .and_then(JsonValue::as_array)
230                    .and_then(|row| row.first())
231                    .cloned()
232                {
233                    let edge: GraphEdge = serde_json::from_value(value)
234                        .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
235                    q.push_back((edge.to.clone(), depth + 1));
236                    edges.push(edge);
237                }
238            }
239        }
240
241        Ok(GraphSubgraph { nodes, edges })
242    }
243}