1use async_trait::async_trait;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use graphmind::graph::GraphStore;
11use graphmind::query::{QueryEngine, RecordBatch, Value};
12
13use crate::client::GraphmindClient;
14use crate::error::{GraphmindError, GraphmindResult};
15use crate::models::{QueryResult, SdkEdge, SdkNode, ServerStatus, StorageStats};
16
/// In-process [`GraphmindClient`] that runs Cypher through a local
/// [`QueryEngine`] directly against a [`GraphStore`] kept in memory.
pub struct EmbeddedClient {
    /// Graph store behind an async read/write lock. The `Arc` can be shared
    /// (see `with_store` / `store()`) so other code observes the same graph.
    pub(crate) store: Arc<RwLock<GraphStore>>,
    /// Engine used to execute Cypher against `store`.
    engine: QueryEngine,
}
25
26impl EmbeddedClient {
27 pub fn new() -> Self {
29 Self {
30 store: Arc::new(RwLock::new(GraphStore::new())),
31 engine: QueryEngine::new(),
32 }
33 }
34
35 pub fn with_store(store: Arc<RwLock<GraphStore>>) -> Self {
37 Self {
38 store,
39 engine: QueryEngine::new(),
40 }
41 }
42
43 pub fn store(&self) -> &Arc<RwLock<GraphStore>> {
45 &self.store
46 }
47
48 pub async fn store_read(&self) -> tokio::sync::RwLockReadGuard<'_, GraphStore> {
52 self.store.read().await
53 }
54
55 pub async fn store_write(&self) -> tokio::sync::RwLockWriteGuard<'_, GraphStore> {
59 self.store.write().await
60 }
61
62 pub fn nlq_pipeline(
64 &self,
65 config: graphmind::persistence::tenant::NLQConfig,
66 ) -> Result<graphmind::NLQPipeline, graphmind::NLQError> {
67 graphmind::NLQPipeline::new(config)
68 }
69
70 pub fn agent_runtime(
72 &self,
73 config: graphmind::persistence::tenant::AgentConfig,
74 ) -> graphmind::agent::AgentRuntime {
75 graphmind::agent::AgentRuntime::new(config)
76 }
77
78 pub fn persistence_manager(
80 &self,
81 base_path: impl AsRef<std::path::Path>,
82 ) -> Result<graphmind::PersistenceManager, graphmind::PersistenceError> {
83 graphmind::PersistenceManager::new(base_path)
84 }
85
86 pub fn cache_stats(&self) -> &graphmind::query::CacheStats {
88 self.engine.cache_stats()
89 }
90
91 pub async fn export_snapshot(
93 &self,
94 _tenant: &str,
95 path: &std::path::Path,
96 ) -> Result<graphmind::snapshot::format::ExportStats, Box<dyn std::error::Error>> {
97 let store_guard = self.store.read().await;
98 let file = std::fs::File::create(path)?;
99 let writer = std::io::BufWriter::new(file);
100 let stats = graphmind::snapshot::export_tenant(&store_guard, writer)?;
101 Ok(stats)
102 }
103
104 pub async fn import_snapshot(
106 &self,
107 _tenant: &str,
108 path: &std::path::Path,
109 ) -> Result<graphmind::snapshot::format::ImportStats, Box<dyn std::error::Error>> {
110 let mut store_guard = self.store.write().await;
111 let file = std::fs::File::open(path)?;
112 let reader = std::io::BufReader::new(file);
113 let stats = graphmind::snapshot::import_tenant(&mut store_guard, reader)?;
114 Ok(stats)
115 }
116}
117
118impl Default for EmbeddedClient {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124fn record_batch_to_query_result(batch: &RecordBatch, store: &GraphStore) -> QueryResult {
126 let mut nodes_map: HashMap<String, SdkNode> = HashMap::new();
127 let mut edges_map: HashMap<String, SdkEdge> = HashMap::new();
128 let mut records = Vec::new();
129
130 for record in &batch.records {
131 let mut row = Vec::new();
132 for col in &batch.columns {
133 let val = match record.get(col) {
134 Some(v) => v,
135 None => {
136 row.push(serde_json::Value::Null);
137 continue;
138 }
139 };
140
141 match val {
142 Value::Node(id, node) => {
143 let mut properties = serde_json::Map::new();
144 for (k, v) in &node.properties {
145 properties.insert(k.clone(), v.to_json());
146 }
147 let id_str = id.as_u64().to_string();
148 let labels: Vec<String> =
149 node.labels.iter().map(|l| l.as_str().to_string()).collect();
150
151 let node_json = serde_json::json!({
152 "id": id_str,
153 "labels": labels,
154 "properties": properties,
155 });
156
157 nodes_map.entry(id_str.clone()).or_insert_with(|| SdkNode {
158 id: id_str,
159 labels,
160 properties: properties.into_iter().collect(),
161 });
162
163 row.push(node_json);
164 }
165 Value::NodeRef(id) => {
166 let id_str = id.as_u64().to_string();
167 let (labels, properties, node_json) = if let Some(node) = store.get_node(*id) {
169 let mut props = serde_json::Map::new();
170 for (k, v) in &node.properties {
171 props.insert(k.clone(), v.to_json());
172 }
173 let lbls: Vec<String> =
174 node.labels.iter().map(|l| l.as_str().to_string()).collect();
175 let json = serde_json::json!({
176 "id": id_str,
177 "labels": lbls,
178 "properties": props,
179 });
180 (lbls, props.into_iter().collect(), json)
181 } else {
182 let json =
183 serde_json::json!({ "id": id_str, "labels": [], "properties": {} });
184 (vec![], HashMap::new(), json)
185 };
186
187 nodes_map.entry(id_str.clone()).or_insert_with(|| SdkNode {
188 id: id_str,
189 labels,
190 properties,
191 });
192
193 row.push(node_json);
194 }
195 Value::Edge(id, edge) => {
196 let mut properties = serde_json::Map::new();
197 for (k, v) in &edge.properties {
198 properties.insert(k.clone(), v.to_json());
199 }
200 let id_str = id.as_u64().to_string();
201 let edge_json = serde_json::json!({
202 "id": id_str,
203 "source": edge.source.as_u64().to_string(),
204 "target": edge.target.as_u64().to_string(),
205 "type": edge.edge_type.as_str(),
206 "properties": properties,
207 });
208
209 edges_map.entry(id_str.clone()).or_insert_with(|| SdkEdge {
210 id: id_str,
211 source: edge.source.as_u64().to_string(),
212 target: edge.target.as_u64().to_string(),
213 edge_type: edge.edge_type.as_str().to_string(),
214 properties: properties.into_iter().collect(),
215 });
216
217 row.push(edge_json);
218 }
219 Value::EdgeRef(id, src, tgt, et) => {
220 let id_str = id.as_u64().to_string();
221 let edge_json = serde_json::json!({
222 "id": id_str,
223 "source": src.as_u64().to_string(),
224 "target": tgt.as_u64().to_string(),
225 "type": et.as_str(),
226 "properties": {},
227 });
228
229 edges_map.entry(id_str.clone()).or_insert_with(|| SdkEdge {
230 id: id_str,
231 source: src.as_u64().to_string(),
232 target: tgt.as_u64().to_string(),
233 edge_type: et.as_str().to_string(),
234 properties: HashMap::new(),
235 });
236
237 row.push(edge_json);
238 }
239 Value::Property(p) => {
240 row.push(p.to_json());
241 }
242 Value::Path {
243 nodes: path_nodes,
244 edges: path_edges,
245 } => {
246 row.push(serde_json::json!({
247 "nodes": path_nodes.iter().map(|n| n.as_u64().to_string()).collect::<Vec<_>>(),
248 "edges": path_edges.iter().map(|e| e.as_u64().to_string()).collect::<Vec<_>>(),
249 "length": path_edges.len(),
250 }));
251 }
252 Value::Null => {
253 row.push(serde_json::Value::Null);
254 }
255 }
256 }
257 records.push(row);
258 }
259
260 QueryResult {
261 nodes: nodes_map.into_values().collect(),
262 edges: edges_map.into_values().collect(),
263 columns: batch.columns.clone(),
264 records,
265 }
266}
267
/// Heuristically decides whether `cypher` may modify the graph, and must
/// therefore run through `execute_mut` under the store's write lock.
///
/// The query is scanned for the clauses `CREATE`, `DELETE`, `DETACH`, `SET`,
/// `MERGE`, `REMOVE` and `CALL` (procedures may write). They are matched as
/// whole words, ASCII case-insensitively, anywhere in the query and with any
/// surrounding whitespace or punctuation (e.g. `MATCH (n)\nCREATE(m)`).
///
/// Ignored during the scan:
/// - text inside `'…'`, `"…"` and `` `…` `` quotes;
/// - words directly preceded by `.` or `:`, such as property keys (`n.set`)
///   and labels (`:Create`).
///
/// The check leans towards reporting a write. A false positive sends a read
/// through the write path, as the previous heuristic already did for every
/// `CALL`. A false negative would send a mutating query to the read-only path.
fn is_write_query(cypher: &str) -> bool {
    const WRITE_KEYWORDS: [&str; 7] = [
        "CREATE", "DELETE", "DETACH", "SET", "MERGE", "REMOVE", "CALL",
    ];
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    // Scanning bytes is UTF-8 safe here: every delimiter we look for is
    // ASCII, and UTF-8 continuation bytes never equal an ASCII byte.
    let bytes = cypher.as_bytes();
    let mut quote: Option<u8> = None;
    let mut i = 0;

    while i < bytes.len() {
        let b = bytes[i];
        if let Some(q) = quote {
            if b == b'\\' {
                // Skip the escaped character so `\'` does not end the string.
                i += 2;
            } else {
                if b == q {
                    quote = None;
                }
                i += 1;
            }
            continue;
        }

        if matches!(b, b'\'' | b'"' | b'`') {
            quote = Some(b);
            i += 1;
        } else if is_ident(b) {
            let start = i;
            while i < bytes.len() && is_ident(bytes[i]) {
                i += 1;
            }
            let qualified = start > 0 && matches!(bytes[start - 1], b'.' | b':');
            let word = &bytes[start..i];
            if !qualified
                && WRITE_KEYWORDS
                    .iter()
                    .any(|kw| word.eq_ignore_ascii_case(kw.as_bytes()))
            {
                return true;
            }
        } else {
            i += 1;
        }
    }
    false
}
283
284#[async_trait]
285impl GraphmindClient for EmbeddedClient {
286 async fn query(&self, graph: &str, cypher: &str) -> GraphmindResult<QueryResult> {
287 if is_write_query(cypher) {
288 let mut store_guard = self.store.write().await;
289 let batch = self
290 .engine
291 .execute_mut(cypher, &mut *store_guard, graph)
292 .map_err(|e| GraphmindError::QueryError(e.to_string()))?;
293 Ok(record_batch_to_query_result(&batch, &*store_guard))
294 } else {
295 let store_guard = self.store.read().await;
296 let batch = self
297 .engine
298 .execute(cypher, &*store_guard)
299 .map_err(|e| GraphmindError::QueryError(e.to_string()))?;
300 Ok(record_batch_to_query_result(&batch, &*store_guard))
301 }
302 }
303
304 async fn query_readonly(&self, _graph: &str, cypher: &str) -> GraphmindResult<QueryResult> {
305 let store_guard = self.store.read().await;
306 let batch = self
307 .engine
308 .execute(cypher, &*store_guard)
309 .map_err(|e| GraphmindError::QueryError(e.to_string()))?;
310 Ok(record_batch_to_query_result(&batch, &*store_guard))
311 }
312
313 async fn delete_graph(&self, _graph: &str) -> GraphmindResult<()> {
314 let mut store_guard = self.store.write().await;
315 store_guard.clear();
316 Ok(())
317 }
318
319 async fn list_graphs(&self) -> GraphmindResult<Vec<String>> {
320 Ok(vec!["default".to_string()])
321 }
322
323 async fn status(&self) -> GraphmindResult<ServerStatus> {
324 let store_guard = self.store.read().await;
325 Ok(ServerStatus {
326 status: "healthy".to_string(),
327 version: graphmind::VERSION.to_string(),
328 storage: StorageStats {
329 nodes: store_guard.node_count() as u64,
330 edges: store_guard.edge_count() as u64,
331 },
332 })
333 }
334
335 async fn ping(&self) -> GraphmindResult<String> {
336 Ok("PONG".to_string())
337 }
338
339 async fn schema(&self, _graph: &str) -> GraphmindResult<String> {
340 let store_guard = self.store.read().await;
341 let mut lines = Vec::new();
342 lines.push(format!("Nodes: {}", store_guard.node_count()));
343 lines.push(format!("Edges: {}", store_guard.edge_count()));
344
345 let mut label_counts: HashMap<String, usize> = HashMap::new();
347 for node in store_guard.all_nodes() {
348 for label in &node.labels {
349 *label_counts.entry(label.as_str().to_string()).or_insert(0) += 1;
350 }
351 }
352 if !label_counts.is_empty() {
353 lines.push("Node labels:".to_string());
354 for (label, count) in &label_counts {
355 lines.push(format!(" :{} ({})", label, count));
356 }
357 }
358
359 let mut edge_type_counts: HashMap<String, usize> = HashMap::new();
361 for edge in store_guard.all_edges() {
362 *edge_type_counts
363 .entry(edge.edge_type.as_str().to_string())
364 .or_insert(0) += 1;
365 }
366 if !edge_type_counts.is_empty() {
367 lines.push("Edge types:".to_string());
368 for (et, count) in &edge_type_counts {
369 lines.push(format!(" :{} ({})", et, count));
370 }
371 }
372
373 Ok(lines.join("\n"))
374 }
375
376 async fn explain(&self, _graph: &str, cypher: &str) -> GraphmindResult<QueryResult> {
377 let prefixed = if cypher.trim().to_uppercase().starts_with("EXPLAIN") {
378 cypher.to_string()
379 } else {
380 format!("EXPLAIN {}", cypher)
381 };
382 let store_guard = self.store.read().await;
383 let batch = self
384 .engine
385 .execute(&prefixed, &*store_guard)
386 .map_err(|e| GraphmindError::QueryError(e.to_string()))?;
387 Ok(record_batch_to_query_result(&batch, &*store_guard))
388 }
389
390 async fn profile(&self, _graph: &str, cypher: &str) -> GraphmindResult<QueryResult> {
391 let prefixed = if cypher.trim().to_uppercase().starts_with("PROFILE") {
392 cypher.to_string()
393 } else {
394 format!("PROFILE {}", cypher)
395 };
396 let store_guard = self.store.read().await;
397 let batch = self
398 .engine
399 .execute(&prefixed, &*store_guard)
400 .map_err(|e| GraphmindError::QueryError(e.to_string()))?;
401 Ok(record_batch_to_query_result(&batch, &*store_guard))
402 }
403}
404
// Unit tests for `EmbeddedClient`, driven through the `GraphmindClient`
// trait and the inherent accessors, plus the `is_write_query` heuristic.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh client answers ping without touching the store.
    #[tokio::test]
    async fn test_embedded_ping() {
        let client = EmbeddedClient::new();
        let result = client.ping().await.unwrap();
        assert_eq!(result, "PONG");
    }

    // A fresh client reports healthy with an empty store.
    #[tokio::test]
    async fn test_embedded_status() {
        let client = EmbeddedClient::new();
        let status = client.status().await.unwrap();
        assert_eq!(status.status, "healthy");
        assert_eq!(status.storage.nodes, 0);
    }

    #[tokio::test]
    async fn test_embedded_create_and_query() {
        let client = EmbeddedClient::new();

        // Two writes via `query`, which should take the write path.
        client
            .query("default", r#"CREATE (n:Person {name: "Alice", age: 30})"#)
            .await
            .unwrap();
        client
            .query("default", r#"CREATE (n:Person {name: "Bob", age: 25})"#)
            .await
            .unwrap();

        // Read both back through the read-only path.
        let result = client
            .query_readonly("default", "MATCH (n:Person) RETURN n.name, n.age")
            .await
            .unwrap();
        assert_eq!(result.columns.len(), 2);
        assert_eq!(result.records.len(), 2);

        // The store itself reflects the writes.
        let status = client.status().await.unwrap();
        assert_eq!(status.storage.nodes, 2);
    }

    // `delete_graph` empties the store.
    #[tokio::test]
    async fn test_embedded_delete_graph() {
        let client = EmbeddedClient::new();

        client
            .query("default", r#"CREATE (n:Person {name: "Alice"})"#)
            .await
            .unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.storage.nodes, 1);

        client.delete_graph("default").await.unwrap();

        let status = client.status().await.unwrap();
        assert_eq!(status.storage.nodes, 0);
    }

    // The embedded client exposes exactly one graph.
    #[tokio::test]
    async fn test_embedded_list_graphs() {
        let client = EmbeddedClient::new();
        let graphs = client.list_graphs().await.unwrap();
        assert_eq!(graphs, vec!["default"]);
    }

    // A created relationship can be matched as a pattern.
    #[tokio::test]
    async fn test_embedded_query_with_edges() {
        let client = EmbeddedClient::new();

        client
            .query(
                "default",
                r#"CREATE (a:Person {name: "Alice"})-[:KNOWS]->(b:Person {name: "Bob"})"#,
            )
            .await
            .unwrap();

        let result = client
            .query_readonly(
                "default",
                "MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a.name, b.name",
            )
            .await
            .unwrap();

        assert_eq!(result.records.len(), 1);
    }

    // `with_store` queries a store populated directly through the GraphStore API.
    #[tokio::test]
    async fn test_embedded_with_existing_store() {
        let mut store = GraphStore::new();
        let alice = store.create_node("Person");
        if let Some(node) = store.get_node_mut(alice) {
            node.set_property("name", "Alice");
        }

        let store = Arc::new(RwLock::new(store));
        let client = EmbeddedClient::with_store(store);

        let result = client
            .query_readonly("default", "MATCH (n:Person) RETURN n.name")
            .await
            .unwrap();
        assert_eq!(result.records.len(), 1);
    }

    // `Default` builds a usable client (synchronous: no store access needed).
    #[test]
    fn test_embedded_default() {
        let client = EmbeddedClient::default();
        let store = client.store();
        assert!(Arc::strong_count(store) >= 1);
    }

    // A read guard from `store_read` sees writes made through `query`.
    #[tokio::test]
    async fn test_embedded_store_read() {
        let client = EmbeddedClient::new();
        client
            .query("default", r#"CREATE (n:Person {name: "Alice"})"#)
            .await
            .unwrap();

        let guard = client.store_read().await;
        assert_eq!(guard.node_count(), 1);
    }

    // Writes made through a `store_write` guard are visible to later queries.
    #[tokio::test]
    async fn test_embedded_store_write() {
        let client = EmbeddedClient::new();
        {
            // Scope the guard so the write lock is released before querying.
            let mut guard = client.store_write().await;
            let id = guard.create_node("Person");
            if let Some(node) = guard.get_node_mut(id) {
                node.set_property("name", "DirectWrite");
            }
        }

        let result = client
            .query_readonly("default", "MATCH (n:Person) RETURN n.name")
            .await
            .unwrap();
        assert_eq!(result.records.len(), 1);
    }

    // A new engine starts with no cache hits.
    #[tokio::test]
    async fn test_embedded_cache_stats() {
        let client = EmbeddedClient::new();
        let stats = client.cache_stats();
        assert_eq!(stats.hits(), 0);
    }

    // Repeating a read query is accounted for in the cache statistics.
    #[tokio::test]
    async fn test_embedded_cache_stats_after_queries() {
        let client = EmbeddedClient::new();
        client
            .query("default", r#"CREATE (n:Person {name: "Alice"})"#)
            .await
            .unwrap();
        client
            .query_readonly("default", "MATCH (n:Person) RETURN n.name")
            .await
            .unwrap();
        client
            .query_readonly("default", "MATCH (n:Person) RETURN n.name")
            .await
            .unwrap();

        let stats = client.cache_stats();
        assert!(stats.hits() + stats.misses() >= 2);
    }

    // Engine errors on the read path surface as `Err`.
    #[tokio::test]
    async fn test_embedded_query_readonly_error() {
        let client = EmbeddedClient::new();
        let result = client.query_readonly("default", "INVALID SYNTAX !!!").await;
        assert!(result.is_err());
    }

    // Engine errors on the write path surface as `Err`.
    #[tokio::test]
    async fn test_embedded_query_write_error() {
        let client = EmbeddedClient::new();
        let result = client.query("default", "CREATE INVALID").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_embedded_version_in_status() {
        let client = EmbeddedClient::new();
        let status = client.status().await.unwrap();
        assert!(!status.version.is_empty());
    }

    // Returning a whole node populates `QueryResult::nodes`.
    #[tokio::test]
    async fn test_embedded_query_returns_nodes() {
        let client = EmbeddedClient::new();
        client
            .query("default", r#"CREATE (n:Person {name: "Alice", age: 30})"#)
            .await
            .unwrap();

        let result = client
            .query_readonly("default", "MATCH (n:Person) RETURN n")
            .await
            .unwrap();
        assert_eq!(result.records.len(), 1);
        assert!(!result.nodes.is_empty());
        let node = &result.nodes[0];
        assert!(node.labels.contains(&"Person".to_string()));
    }

    // Returning a relationship populates `QueryResult::edges`.
    #[tokio::test]
    async fn test_embedded_query_returns_edges() {
        let client = EmbeddedClient::new();
        client
            .query(
                "default",
                r#"CREATE (a:Person {name: "Alice"})-[:KNOWS {since: 2020}]->(b:Person {name: "Bob"})"#,
            )
            .await
            .unwrap();

        let result = client
            .query_readonly("default", "MATCH (a)-[r:KNOWS]->(b) RETURN r")
            .await
            .unwrap();
        assert_eq!(result.records.len(), 1);
        assert!(!result.edges.is_empty());
        let edge = &result.edges[0];
        assert_eq!(edge.edge_type, "KNOWS");
    }

    #[tokio::test]
    async fn test_embedded_query_returns_null() {
        let client = EmbeddedClient::new();
        client
            .query("default", r#"CREATE (n:Person {name: "Alice"})"#)
            .await
            .unwrap();

        // A property the node does not have should come back as JSON null.
        let result = client
            .query_readonly("default", "MATCH (n:Person) RETURN n.missing")
            .await
            .unwrap();
        assert_eq!(result.records.len(), 1);
        assert_eq!(result.records[0][0], serde_json::Value::Null);
    }

    #[tokio::test]
    async fn test_embedded_multiple_writes_and_reads() {
        let client = EmbeddedClient::new();

        for i in 0..5 {
            client
                .query("default", &format!(r#"CREATE (n:Item {{id: {}}})"#, i))
                .await
                .unwrap();
        }

        let result = client
            .query_readonly("default", "MATCH (n:Item) RETURN n.id")
            .await
            .unwrap();
        assert_eq!(result.records.len(), 5);
    }

    #[tokio::test]
    async fn test_embedded_delete_graph_and_recreate() {
        let client = EmbeddedClient::new();

        client
            .query("default", r#"CREATE (n:Person {name: "Alice"})"#)
            .await
            .unwrap();
        assert_eq!(client.status().await.unwrap().storage.nodes, 1);

        client.delete_graph("default").await.unwrap();
        assert_eq!(client.status().await.unwrap().storage.nodes, 0);

        // The store stays usable after being cleared.
        client
            .query("default", r#"CREATE (n:Person {name: "Bob"})"#)
            .await
            .unwrap();
        assert_eq!(client.status().await.unwrap().storage.nodes, 1);
    }

    // Writes through the client are visible via the caller's own `Arc` clone.
    #[tokio::test]
    async fn test_embedded_with_store_shares_state() {
        let store = Arc::new(RwLock::new(GraphStore::new()));
        let client = EmbeddedClient::with_store(Arc::clone(&store));

        client
            .query("default", r#"CREATE (n:Person {name: "Alice"})"#)
            .await
            .unwrap();

        let guard = store.read().await;
        assert_eq!(guard.node_count(), 1);
    }

    // Leading and embedded write clauses are detected; pure reads are not.
    #[test]
    fn test_is_write_query_variants() {
        assert!(is_write_query("CREATE (n:Person)"));
        assert!(is_write_query("DELETE n"));
        assert!(is_write_query("SET n.name = 'x'"));
        assert!(is_write_query("MERGE (n:Person)"));
        assert!(is_write_query("CALL db.something()"));
        assert!(is_write_query("MATCH (n) CREATE (m)"));
        assert!(is_write_query("MATCH (n) DELETE n"));
        assert!(is_write_query("MATCH (n) SET n.x = 1"));
        assert!(is_write_query("MATCH (n) MERGE (m)"));
        assert!(is_write_query("MATCH (n) CALL db.x()"));

        assert!(!is_write_query("MATCH (n) RETURN n"));
        assert!(!is_write_query("MATCH (n:Person) RETURN n.name"));
        assert!(!is_write_query("RETURN 1 + 2"));
    }

    // Properties of several value types round-trip through a read query.
    #[tokio::test]
    async fn test_embedded_query_property_values() {
        let client = EmbeddedClient::new();
        client
            .query(
                "default",
                r#"CREATE (n:Person {name: "Alice", age: 30, score: 95.5, active: true})"#,
            )
            .await
            .unwrap();

        let result = client
            .query_readonly(
                "default",
                "MATCH (n:Person) RETURN n.name, n.age, n.score, n.active",
            )
            .await
            .unwrap();
        assert_eq!(result.records.len(), 1);
        assert_eq!(result.columns.len(), 4);
    }

    // `store()` hands out the shared `Arc`, which can be cloned.
    #[tokio::test]
    async fn test_embedded_store_accessor() {
        let client = EmbeddedClient::new();
        let store_ref = client.store();
        let _cloned = Arc::clone(store_ref);
        assert!(Arc::strong_count(store_ref) >= 2);
    }
}