1use std::{
2 collections::BTreeMap,
3 sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
9#[derive(Clone, Debug, Default)]
10pub struct MemoryGraphStore {
11 inner: Arc<RwLock<MemoryGraph>>,
12}
13
14#[derive(Clone, Debug, Default)]
15struct MemoryGraph {
16 nodes: BTreeMap<NodeId, Node>,
17 edges: BTreeMap<(NodeId, Label, NodeId), Edge>,
18}
19
20impl MemoryGraphStore {
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 pub fn graph(&self) -> Graph {
26 let inner = self.inner.read().expect("memory graph lock poisoned");
27 Graph {
28 nodes: inner.nodes.values().cloned().collect(),
29 edges: inner.edges.values().cloned().collect(),
30 }
31 }
32}
33
34#[async_trait]
35impl GraphStore for MemoryGraphStore {
36 async fn put_node(&self, node: &Node) -> Result<NodeId> {
37 let mut inner = self.inner.write().expect("memory graph lock poisoned");
38 inner.nodes.insert(node.id.clone(), node.clone());
39 Ok(node.id.clone())
40 }
41
42 async fn put_edge(&self, edge: &Edge) -> Result<Option<EdgeId>> {
43 let mut inner = self.inner.write().expect("memory graph lock poisoned");
44 inner.edges.insert(
45 (edge.from.clone(), edge.label.clone(), edge.to.clone()),
46 edge.clone(),
47 );
48 Ok(edge.id.clone())
49 }
50
51 async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
52 let inner = self.inner.read().expect("memory graph lock poisoned");
53 Ok(inner.nodes.get(id).cloned())
54 }
55
56 async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
57 let inner = self.inner.read().expect("memory graph lock poisoned");
58 Ok(inner
59 .edges
60 .values()
61 .filter(|edge| {
62 query.from.as_ref().is_none_or(|from| from == &edge.from)
63 && query.to.as_ref().is_none_or(|to| to == &edge.to)
64 && query
65 .label
66 .as_ref()
67 .is_none_or(|label| label == &edge.label)
68 })
69 .cloned()
70 .collect())
71 }
72
73 async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
74 let inner = self.inner.read().expect("memory graph lock poisoned");
75 let mut current = match traversal.start {
76 Start::Node(id) => inner
77 .nodes
78 .get(&id)
79 .cloned()
80 .into_iter()
81 .collect::<Vec<_>>(),
82 Start::NodesByLabel(label) => inner
83 .nodes
84 .values()
85 .filter(|node| node.label == label)
86 .cloned()
87 .collect(),
88 Start::NodesByProperty { label, key, value } => inner
89 .nodes
90 .values()
91 .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
92 .cloned()
93 .collect(),
94 };
95
96 for step in traversal.steps {
97 let mut next = Vec::new();
98 for node in ¤t {
99 for edge in inner.edges.values() {
100 let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
101 let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
102 && edge.from == node.id;
103 let in_matches = matches!(step.direction, Direction::In | Direction::Both)
104 && edge.to == node.id;
105
106 if !label_matches || (!out_matches && !in_matches) {
107 continue;
108 }
109
110 let target_id = if out_matches { &edge.to } else { &edge.from };
111 if let Some(target) = inner.nodes.get(target_id) {
112 if step
113 .node
114 .as_ref()
115 .is_none_or(|label| label == &target.label)
116 {
117 next.push(target.clone());
118 }
119 }
120 }
121 }
122 current = next;
123 }
124
125 if let Some(limit) = traversal.limit {
126 current.truncate(limit as usize);
127 }
128 Ok(current)
129 }
130}
131
// Unit tests live in a separate `tests` submodule file and are compiled only for test builds.
#[cfg(test)]
mod tests;