Skip to main content

nova_boot_graphdb/
surreal.rs

1use crate::{builders::sanitize_symbol, error::GraphDbError, traits::GraphStore, types::*};
2use async_trait::async_trait;
3use serde_json::Value as JsonValue;
4use std::collections::{HashMap, HashSet, VecDeque};
5use tokio::sync::Mutex;
6
7pub(crate) fn surreal_result_rows(json: &JsonValue) -> Vec<JsonValue> {
8    json.as_array()
9        .and_then(|stmts| stmts.first())
10        .and_then(|stmt| stmt.get("result"))
11        .and_then(JsonValue::as_array)
12        .cloned()
13        .unwrap_or_default()
14}
15
16fn parse_surreal_record_id(value: &JsonValue) -> Option<String> {
17    if let Some(id) = value.as_str() {
18        return id
19            .split(':')
20            .nth(1)
21            .map(ToString::to_string)
22            .or_else(|| Some(id.to_string()));
23    }
24
25    let obj = value.as_object()?;
26    if let Some(inner_id) = obj.get("id") {
27        return parse_surreal_record_id(inner_id);
28    }
29
30    obj.get("tb").and_then(JsonValue::as_str).and_then(|tb| {
31        obj.get("id")
32            .and_then(JsonValue::as_str)
33            .map(|id| format!("{tb}:{id}"))
34    })
35}
36
37pub(crate) fn surreal_value_to_node(value: &JsonValue) -> Option<GraphNode> {
38    if let Some(id) = value.as_str() {
39        let parsed_id = parse_surreal_record_id(value)?;
40        return Some(GraphNode {
41            id: parsed_id,
42            labels: vec![id.split(':').next().unwrap_or("node").to_string()],
43            properties: HashMap::new(),
44        });
45    }
46
47    let obj = value.as_object()?;
48
49    let raw_id = obj.get("id")?;
50    let id = parse_surreal_record_id(raw_id)?;
51
52    let labels = raw_id
53        .as_object()
54        .and_then(|m| m.get("tb"))
55        .and_then(JsonValue::as_str)
56        .map(|tb| vec![tb.to_string()])
57        .unwrap_or_else(|| vec!["node".to_string()]);
58
59    let mut properties = obj
60        .get("properties")
61        .and_then(JsonValue::as_object)
62        .cloned()
63        .map(|m| m.into_iter().collect::<HashMap<_, _>>())
64        .unwrap_or_default();
65
66    for (k, v) in obj {
67        if k != "id" && k != "properties" && !k.starts_with('_') {
68            properties.entry(k.clone()).or_insert_with(|| v.clone());
69        }
70    }
71
72    Some(GraphNode {
73        id,
74        labels,
75        properties,
76    })
77}
78
79pub struct SurrealGraphStore {
80    pub endpoint: String,
81    pub namespace: String,
82    pub database: String,
83    client: reqwest::Client,
84    username: Option<String>,
85    password: Option<String>,
86    token: Mutex<Option<String>>,
87}
88
89impl SurrealGraphStore {
90    pub fn new(
91        endpoint: impl Into<String>,
92        namespace: impl Into<String>,
93        database: impl Into<String>,
94    ) -> Self {
95        Self {
96            endpoint: endpoint.into(),
97            namespace: namespace.into(),
98            database: database.into(),
99            client: reqwest::Client::new(),
100            username: None,
101            password: None,
102            token: Mutex::new(None),
103        }
104    }
105
106    pub fn new_with_auth(
107        endpoint: impl Into<String>,
108        namespace: impl Into<String>,
109        database: impl Into<String>,
110        username: impl Into<String>,
111        password: impl Into<String>,
112    ) -> Self {
113        Self {
114            endpoint: endpoint.into(),
115            namespace: namespace.into(),
116            database: database.into(),
117            client: reqwest::Client::new(),
118            username: Some(username.into()),
119            password: Some(password.into()),
120            token: Mutex::new(None),
121        }
122    }
123
124    async fn auth_token(&self) -> Result<Option<String>, GraphDbError> {
125        let Some(username) = &self.username else {
126            return Ok(None);
127        };
128        let Some(password) = &self.password else {
129            return Ok(None);
130        };
131
132        let mut guard = self.token.lock().await;
133        if let Some(token) = guard.as_ref() {
134            return Ok(Some(token.clone()));
135        }
136
137        let endpoint = format!("{}/signin", self.endpoint.trim_end_matches('/'));
138        let payload = serde_json::json!({
139            "user": username,
140            "pass": password,
141        });
142
143        let resp = self
144            .client
145            .post(endpoint)
146            .header("Accept", "application/json")
147            .json(&payload)
148            .send()
149            .await
150            .map_err(|e| GraphDbError::Backend(e.to_string()))?;
151
152        let status = resp.status();
153        let json: JsonValue = resp
154            .json()
155            .await
156            .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
157
158        if !status.is_success() {
159            return Err(GraphDbError::Backend(format!(
160                "surrealdb signin http status {}: {}",
161                status, json
162            )));
163        }
164
165        let token = json
166            .get("token")
167            .and_then(JsonValue::as_str)
168            .or_else(|| json.get("result").and_then(JsonValue::as_str))
169            .or_else(|| {
170                json.get("result")
171                    .and_then(JsonValue::as_object)
172                    .and_then(|obj| obj.get("token"))
173                    .and_then(JsonValue::as_str)
174            })
175            .ok_or_else(|| {
176                GraphDbError::Backend(format!("surrealdb signin response missing token: {}", json))
177            })?
178            .to_string();
179
180        *guard = Some(token.clone());
181        Ok(Some(token))
182    }
183
184    async fn run_sql(&self, sql: &str) -> Result<JsonValue, GraphDbError> {
185        let endpoint = format!("{}/sql", self.endpoint.trim_end_matches('/'));
186
187        let mut request = self
188            .client
189            .post(endpoint)
190            .header("surreal-ns", &self.namespace)
191            .header("surreal-db", &self.database)
192            .header("Accept", "application/json");
193
194        if let Some(token) = self.auth_token().await? {
195            request = request.header("Authorization", format!("Bearer {token}"));
196        }
197
198        let resp = request
199            .body(sql.to_string())
200            .send()
201            .await
202            .map_err(|e| GraphDbError::Backend(e.to_string()))?;
203
204        let status = resp.status();
205        let json: JsonValue = resp
206            .json()
207            .await
208            .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
209
210        if !status.is_success() {
211            return Err(GraphDbError::Backend(format!(
212                "surrealdb http status {}: {}",
213                status, json
214            )));
215        }
216
217        Ok(json)
218    }
219}
220
221#[async_trait]
222impl GraphStore for SurrealGraphStore {
223    async fn execute(&self, query: GraphQuery) -> Result<JsonValue, GraphDbError> {
224        let sql = match query {
225            GraphQuery::Cypher(q) => q,
226            GraphQuery::GraphQl(q) => q,
227        };
228        self.run_sql(&sql).await
229    }
230
231    async fn upsert_node(&self, node: GraphNode) -> Result<(), GraphDbError> {
232        let table = node
233            .labels
234            .first()
235            .map(|v| v.to_ascii_lowercase())
236            .unwrap_or_else(|| "node".to_string());
237        let properties = serde_json::to_string(&node.properties)
238            .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
239        let sql = format!(
240            "UPSERT {table}:{} SET id = '{}', properties = {};",
241            node.id, node.id, properties
242        );
243        self.run_sql(&sql).await.map(|_| ())
244    }
245
246    async fn upsert_edge(&self, edge: GraphEdge) -> Result<(), GraphDbError> {
247        let rel = sanitize_symbol(&edge.rel_type).to_ascii_lowercase();
248        let props = serde_json::to_string(&edge.properties)
249            .map_err(|e| GraphDbError::Serialization(e.to_string()))?;
250        let sql = format!(
251            "RELATE node:{}->{rel}->node:{} SET id = '{}', properties = {};",
252            edge.from, edge.to, edge.id, props
253        );
254        self.run_sql(&sql).await.map(|_| ())
255    }
256
257    async fn get_node(&self, node_id: &str) -> Result<Option<GraphNode>, GraphDbError> {
258        let sql = format!("SELECT * FROM node:{};", node_id);
259        let json = self.run_sql(&sql).await?;
260        let rows = surreal_result_rows(&json);
261        let Some(first) = rows.first() else {
262            return Ok(None);
263        };
264
265        Ok(surreal_value_to_node(first).or_else(|| {
266            Some(GraphNode {
267                id: node_id.to_string(),
268                labels: vec!["node".to_string()],
269                properties: HashMap::new(),
270            })
271        }))
272    }
273
274    async fn neighbors(&self, node_id: &str) -> Result<Vec<GraphNode>, GraphDbError> {
275        let mut out = Vec::new();
276
277        let sql = format!("SELECT ->?->node AS neighbors FROM node:{};", node_id);
278        let json = self.run_sql(&sql).await?;
279        let rows = surreal_result_rows(&json);
280
281        for row in rows {
282            if let Some(neighbors) = row.get("neighbors").and_then(JsonValue::as_array) {
283                for item in neighbors {
284                    if let Some(node) = surreal_value_to_node(item) {
285                        out.push(node);
286                    }
287                }
288                continue;
289            }
290
291            if let Some(node) = surreal_value_to_node(&row) {
292                out.push(node);
293            }
294        }
295
296        Ok(out)
297    }
298
299    async fn traverse(&self, start: &str, max_depth: usize) -> Result<GraphSubgraph, GraphDbError> {
300        let mut visited = HashSet::new();
301        let mut q = VecDeque::from([(start.to_string(), 0usize)]);
302        let mut nodes = Vec::new();
303        let mut edges = Vec::new();
304        let mut edge_ids = HashSet::new();
305
306        while let Some((current, depth)) = q.pop_front() {
307            if !visited.insert(current.clone()) {
308                continue;
309            }
310            if let Some(node) = self.get_node(&current).await? {
311                nodes.push(node);
312            }
313            if depth >= max_depth {
314                continue;
315            }
316
317            for n in self.neighbors(&current).await? {
318                let synthetic_edge_id = format!("{}->{}", current, n.id);
319                if edge_ids.insert(synthetic_edge_id.clone()) {
320                    edges.push(GraphEdge {
321                        id: synthetic_edge_id,
322                        from: current.clone(),
323                        to: n.id.clone(),
324                        rel_type: "RELATED".to_string(),
325                        properties: HashMap::new(),
326                    });
327                }
328                q.push_back((n.id, depth + 1));
329            }
330        }
331
332        Ok(GraphSubgraph { nodes, edges })
333    }
334}