1use std::{
2 collections::BTreeMap,
3 sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
/// In-memory graph store implementing [`GraphStore`] and [`GraphMutationStore`].
///
/// All state sits behind an `Arc<RwLock<..>>`, so cloning a `MemoryGraphStore`
/// yields another handle to the *same* graph rather than an independent copy.
#[derive(Clone, Debug, Default)]
pub struct MemoryGraphStore {
    // Shared graph state; every clone of the store points at this same lock.
    inner: Arc<RwLock<MemoryGraph>>,
}
13
/// Graph state guarded by the lock inside [`MemoryGraphStore`].
#[derive(Clone, Debug, Default)]
struct MemoryGraph {
    // Nodes keyed by id; `BTreeMap` keeps iteration order deterministic.
    nodes: BTreeMap<NodeId, Node>,
    // Edges keyed by `(from, label, to)`: at most one edge exists per triple,
    // and putting the same triple again replaces the stored edge.
    edges: BTreeMap<(NodeId, Label, NodeId), Edge>,
    // Schema used to validate writes; `None` means writes are not validated.
    schema: Option<GraphSchema>,
}
20
21impl MemoryGraphStore {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn graph(&self) -> Graph {
27 let inner = self.inner.read().expect("memory graph lock poisoned");
28 Graph {
29 nodes: inner.nodes.values().cloned().collect(),
30 edges: inner.edges.values().cloned().collect(),
31 }
32 }
33}
34
35#[async_trait]
36impl GraphStore for MemoryGraphStore {
37 async fn apply_schema(&self, schema: &GraphSchema) -> Result<()> {
38 let mut inner = self.inner.write().expect("memory graph lock poisoned");
39 inner.schema = Some(schema.clone());
40 Ok(())
41 }
42
43 async fn put_node(&self, node: &Node) -> Result<PutOutcome> {
44 let mut inner = self.inner.write().expect("memory graph lock poisoned");
45 if let Some(schema) = &inner.schema {
46 schema.validate_node(node)?;
47 }
48 let previous = inner.nodes.insert(node.id.clone(), node.clone());
49 Ok(match previous {
50 Some(_) => PutOutcome::Updated,
51 None => PutOutcome::Inserted,
52 })
53 }
54
55 async fn put_edge(&self, edge: &Edge) -> Result<PutOutcome> {
56 let mut inner = self.inner.write().expect("memory graph lock poisoned");
57 if let Some(schema) = &inner.schema {
58 schema.validate_edge_with(edge, |id| inner.nodes.get(id).map(|node| &node.label))?;
59 }
60 let previous = inner.edges.insert(
61 (edge.from.clone(), edge.label.clone(), edge.to.clone()),
62 edge.clone(),
63 );
64 Ok(match previous {
65 Some(_) => PutOutcome::Updated,
66 None => PutOutcome::Inserted,
67 })
68 }
69
70 async fn put_graph(&self, graph: &Graph) -> Result<LoadReport> {
71 let mut inner = self.inner.write().expect("memory graph lock poisoned");
72 if let Some(schema) = &inner.schema {
73 schema.validate_graph(graph)?;
74 }
75 let mut report = LoadReport::default();
76 for node in &graph.nodes {
77 inner.nodes.insert(node.id.clone(), node.clone());
78 report.nodes += 1;
79 }
80 for edge in &graph.edges {
81 inner.edges.insert(
82 (edge.from.clone(), edge.label.clone(), edge.to.clone()),
83 edge.clone(),
84 );
85 report.edges += 1;
86 }
87 Ok(report)
88 }
89
90 async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
91 let inner = self.inner.read().expect("memory graph lock poisoned");
92 Ok(inner.nodes.get(id).cloned())
93 }
94
95 async fn get_nodes(&self, ids: &[NodeId]) -> Result<Vec<Node>> {
96 let inner = self.inner.read().expect("memory graph lock poisoned");
97 Ok(ids
98 .iter()
99 .filter_map(|id| inner.nodes.get(id).cloned())
100 .collect())
101 }
102
103 async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
104 let inner = self.inner.read().expect("memory graph lock poisoned");
105 Ok(inner
106 .edges
107 .values()
108 .filter(|edge| {
109 query.from.as_ref().is_none_or(|from| from == &edge.from)
110 && query.to.as_ref().is_none_or(|to| to == &edge.to)
111 && query
112 .label
113 .as_ref()
114 .is_none_or(|label| label == &edge.label)
115 })
116 .cloned()
117 .collect())
118 }
119
120 async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
121 let inner = self.inner.read().expect("memory graph lock poisoned");
122 let mut current = match traversal.start {
123 Start::Node(id) => inner
124 .nodes
125 .get(&id)
126 .cloned()
127 .into_iter()
128 .collect::<Vec<_>>(),
129 Start::NodesByLabel(label) => inner
130 .nodes
131 .values()
132 .filter(|node| node.label == label)
133 .cloned()
134 .collect(),
135 Start::NodesByProperty { label, key, value } => inner
136 .nodes
137 .values()
138 .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
139 .cloned()
140 .collect(),
141 };
142
143 for step in traversal.steps {
144 let mut next = Vec::new();
145 for node in ¤t {
146 for edge in inner.edges.values() {
147 let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
148 let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
149 && edge.from == node.id;
150 let in_matches = matches!(step.direction, Direction::In | Direction::Both)
151 && edge.to == node.id;
152
153 if !label_matches || (!out_matches && !in_matches) {
154 continue;
155 }
156
157 let target_id = if out_matches { &edge.to } else { &edge.from };
158 if let Some(target) = inner.nodes.get(target_id)
159 && step
160 .node
161 .as_ref()
162 .is_none_or(|label| label == &target.label)
163 {
164 next.push(target.clone());
165 }
166 }
167 }
168 current = next;
169 }
170
171 if let Some(limit) = traversal.limit {
172 current.truncate(limit as usize);
173 }
174 Ok(current)
175 }
176}
177
178#[async_trait]
179impl GraphMutationStore for MemoryGraphStore {
180 async fn delete_node(&self, id: &NodeId) -> Result<()> {
181 let mut inner = self.inner.write().expect("memory graph lock poisoned");
182 inner.nodes.remove(id);
183 inner
184 .edges
185 .retain(|(from, _, to), _| from != id && to != id);
186 Ok(())
187 }
188
189 async fn delete_edge(&self, from: &NodeId, label: &Label, to: &NodeId) -> Result<()> {
190 let mut inner = self.inner.write().expect("memory graph lock poisoned");
191 inner
192 .edges
193 .remove(&(from.clone(), label.clone(), to.clone()));
194 Ok(())
195 }
196}
197
// Unit tests live in a separate `tests` module file and are compiled only
// for test builds.
#[cfg(test)]
mod tests;