1use crate::{builders::sanitize_symbol, error::GraphDbError, traits::GraphStore, types::*};
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use std::collections::{HashMap, HashSet, VecDeque};
5use tokio::sync::Mutex;
6
7pub(crate) fn surreal_result_rows(json: &JsonValue) -> Vec<JsonValue> {
8 json.as_array()
9 .and_then(|stmts| stmts.first())
10 .and_then(|stmt| stmt.get("result"))
11 .and_then(JsonValue::as_array)
12 .cloned()
13 .unwrap_or_default()
14}
15
16fn parse_surreal_record_id(value: &JsonValue) -> Option<String> {
17 if let Some(id) = value.as_str() {
18 return id
19 .split(':')
20 .nth(1)
21 .map(ToString::to_string)
22 .or_else(|| Some(id.to_string()));
23 }
24
25 let obj = value.as_object()?;
26 if let Some(inner_id) = obj.get("id") {
27 return parse_surreal_record_id(inner_id);
28 }
29
30 obj.get("tb").and_then(JsonValue::as_str).and_then(|tb| {
31 obj.get("id")
32 .and_then(JsonValue::as_str)
33 .map(|id| format!("{tb}:{id}"))
34 })
35}
36
37pub(crate) fn surreal_value_to_node(value: &JsonValue) -> Option<GraphNode> {
38 if let Some(id) = value.as_str() {
39 let parsed_id = parse_surreal_record_id(value)?;
40 return Some(GraphNode {
41 id: parsed_id,
42 labels: vec![id.split(':').next().unwrap_or("node").to_string()],
43 properties: HashMap::new(),
44 });
45 }
46
47 let obj = value.as_object()?;
48
49 let raw_id = obj.get("id")?;
50 let id = parse_surreal_record_id(raw_id)?;
51
52 let labels = raw_id
53 .as_object()
54 .and_then(|m| m.get("tb"))
55 .and_then(JsonValue::as_str)
56 .map(|tb| vec![tb.to_string()])
57 .unwrap_or_else(|| vec!["node".to_string()]);
58
59 let mut properties = obj
60 .get("properties")
61 .and_then(JsonValue::as_object)
62 .cloned()
63 .map(|m| m.into_iter().collect::<HashMap<_, _>>())
64 .unwrap_or_default();
65
66 for (k, v) in obj {
67 if k != "id" && k != "properties" && !k.starts_with('_') {
68 properties.entry(k.clone()).or_insert_with(|| v.clone());
69 }
70 }
71
72 Some(GraphNode {
73 id,
74 labels,
75 properties,
76 })
77}
78
79pub struct SurrealGraphStore {
80 pub endpoint: String,
81 pub namespace: String,
82 pub database: String,
83 client: reqwest::Client,
84 username: Option<String>,
85 password: Option<String>,
86 token: Mutex<Option<String>>,
87}
88
89impl SurrealGraphStore {
90 pub fn new(
91 endpoint: impl Into<String>,
92 namespace: impl Into<String>,
93 database: impl Into<String>,
94 ) -> Self {
95 Self {
96 endpoint: endpoint.into(),
97 namespace: namespace.into(),
98 database: database.into(),
99 client: reqwest::Client::new(),
100 username: None,
101 password: None,
102 token: Mutex::new(None),
103 }
104 }
105
106 pub fn new_with_auth(
107 endpoint: impl Into<String>,
108 namespace: impl Into<String>,
109 database: impl Into<String>,
110 username: impl Into<String>,
111 password: impl Into<String>,
112 ) -> Self {
113 Self {
114 endpoint: endpoint.into(),
115 namespace: namespace.into(),
116 database: database.into(),
117 client: reqwest::Client::new(),
118 username: Some(username.into()),
119 password: Some(password.into()),
120 token: Mutex::new(None),
121 }
122 }
123
124 async fn auth_token(&self) -> Result<Option<String>, GraphDbError> {
125 let Some(username) = &self.username else {
126 return Ok(None);
127 };
128 let Some(password) = &self.password else {
129 return Ok(None);
130 };
131
132 let mut guard = self.token.lock().await;
133 if let Some(token) = guard.as_ref() {
134 return Ok(Some(token.clone()));
135 }
136
137 let endpoint = format!("{}/signin", self.endpoint.trim_end_matches('/'));
138 let payload = serde_json::json!({
139 "user": username,
140 "pass": password,
141 });
142
143 let resp = self
144 .client
145 .post(endpoint)
146 .header("Accept", "application/json")
147 .json(&payload)
148 .send()
149 .await
150 .map_err(|e| GraphDbError::Backend(e.to_string()))?;
151
152 let status = resp.status();
153 let json: JsonValue = resp
154 .json()
155 .await
156 .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
157
158 if !status.is_success() {
159 return Err(GraphDbError::Backend(format!(
160 "surrealdb signin http status {}: {}",
161 status, json
162 )));
163 }
164
165 let token = json
166 .get("token")
167 .and_then(JsonValue::as_str)
168 .or_else(|| json.get("result").and_then(JsonValue::as_str))
169 .or_else(|| {
170 json.get("result")
171 .and_then(JsonValue::as_object)
172 .and_then(|obj| obj.get("token"))
173 .and_then(JsonValue::as_str)
174 })
175 .ok_or_else(|| {
176 GraphDbError::Backend(format!("surrealdb signin response missing token: {}", json))
177 })?
178 .to_string();
179
180 *guard = Some(token.clone());
181 Ok(Some(token))
182 }
183
184 async fn run_sql(&self, sql: &str) -> Result<JsonValue, GraphDbError> {
185 let endpoint = format!("{}/sql", self.endpoint.trim_end_matches('/'));
186
187 let mut request = self
188 .client
189 .post(endpoint)
190 .header("surreal-ns", &self.namespace)
191 .header("surreal-db", &self.database)
192 .header("Accept", "application/json");
193
194 if let Some(token) = self.auth_token().await? {
195 request = request.header("Authorization", format!("Bearer {token}"));
196 }
197
198 let resp = request
199 .body(sql.to_string())
200 .send()
201 .await
202 .map_err(|e| GraphDbError::Backend(e.to_string()))?;
203
204 let status = resp.status();
205 let json: JsonValue = resp
206 .json()
207 .await
208 .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
209
210 if !status.is_success() {
211 return Err(GraphDbError::Backend(format!(
212 "surrealdb http status {}: {}",
213 status, json
214 )));
215 }
216
217 Ok(json)
218 }
219}
220
221#[async_trait]
222impl GraphStore for SurrealGraphStore {
223 async fn execute(&self, query: GraphQuery) -> Result<JsonValue, GraphDbError> {
224 let sql = match query {
225 GraphQuery::Cypher(q) => q,
226 GraphQuery::GraphQl(q) => q,
227 };
228 self.run_sql(&sql).await
229 }
230
231 async fn upsert_node(&self, node: GraphNode) -> Result<(), GraphDbError> {
232 let table = node
233 .labels
234 .first()
235 .map(|v| v.to_ascii_lowercase())
236 .unwrap_or_else(|| "node".to_string());
237 let properties = serde_json::to_string(&node.properties)
238 .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
239 let sql = format!(
240 "UPSERT {table}:{} SET id = '{}', properties = {};",
241 node.id, node.id, properties
242 );
243 self.run_sql(&sql).await.map(|_| ())
244 }
245
246 async fn upsert_edge(&self, edge: GraphEdge) -> Result<(), GraphDbError> {
247 let rel = sanitize_symbol(&edge.rel_type).to_ascii_lowercase();
248 let props = serde_json::to_string(&edge.properties)
249 .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
250 let sql = format!(
251 "RELATE node:{}->{rel}->node:{} SET id = '{}', properties = {};",
252 edge.from, edge.to, edge.id, props
253 );
254 self.run_sql(&sql).await.map(|_| ())
255 }
256
257 async fn get_node(&self, node_id: &str) -> Result<Option<GraphNode>, GraphDbError> {
258 let sql = format!("SELECT * FROM node:{};", node_id);
259 let json = self.run_sql(&sql).await?;
260 let rows = surreal_result_rows(&json);
261 let Some(first) = rows.first() else {
262 return Ok(None);
263 };
264
265 Ok(surreal_value_to_node(first).or_else(|| {
266 Some(GraphNode {
267 id: node_id.to_string(),
268 labels: vec!["node".to_string()],
269 properties: HashMap::new(),
270 })
271 }))
272 }
273
274 async fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, GraphDbError> {
275 let mut out = Vec::new();
276
277 let sql = format!("SELECT ->?->node AS neighbors FROM node:{};", node_id);
278 let json = self.run_sql(&sql).await?;
279 let rows = surreal_result_rows(&json);
280
281 for row in rows {
282 if let Some(neighbors) = row.get("neighbors").and_then(JsonValue::as_array) {
283 for item in neighbors {
284 if let Some(node) = surreal_value_to_node(item) {
285 out.push(node);
286 }
287 }
288 continue;
289 }
290
291 if let Some(node) = surreal_value_to_node(&row) {
292 out.push(node);
293 }
294 }
295
296 Ok(out)
297 }
298
299 async fn traverse(&self, start: &str, max_depth: usize) -> Result<GraphSubgraph, GraphDbError> {
300 let mut visited = HashSet::new();
301 let mut q = VecDeque::from([(start.to_string(), 0usize)]);
302 let mut nodes = Vec::new();
303 let mut edges = Vec::new();
304 let mut edge_ids = HashSet::new();
305
306 while let Some((current, depth)) = q.pop_front() {
307 if !visited.insert(current.clone()) {
308 continue;
309 }
310 if let Some(node) = self.get_node(¤t).await? {
311 nodes.push(node);
312 }
313 if depth >= max_depth {
314 continue;
315 }
316
317 for n in self.neighbors(¤t).await? {
318 let synthetic_edge_id = format!("{}->{}", current, n.id);
319 if edge_ids.insert(synthetic_edge_id.clone()) {
320 edges.push(GraphEdge {
321 id: synthetic_edge_id,
322 from: current.clone(),
323 to: n.id.clone(),
324 rel_type: "RELATED".to_string(),
325 properties: HashMap::new(),
326 });
327 }
328 q.push_back((n.id, depth + 1));
329 }
330 }
331
332 Ok(GraphSubgraph { nodes, edges })
333 }
334}