1use std::{
2 collections::BTreeMap,
3 sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
9#[derive(Clone, Debug, Default)]
10pub struct MemoryGraphStore {
11 inner: Arc<RwLock<MemoryGraph>>,
12}
13
14#[derive(Clone, Debug, Default)]
15struct MemoryGraph {
16 nodes: BTreeMap<NodeId, Node>,
17 edges: BTreeMap<(NodeId, Label, NodeId), Edge>,
18 schema: Option<GraphSchema>,
19}
20
21impl MemoryGraphStore {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn graph(&self) -> Graph {
27 let inner = self.inner.read().expect("memory graph lock poisoned");
28 Graph {
29 nodes: inner.nodes.values().cloned().collect(),
30 edges: inner.edges.values().cloned().collect(),
31 }
32 }
33}
34
35#[async_trait]
36impl GraphStore for MemoryGraphStore {
37 async fn apply_schema(&self, schema: &GraphSchema) -> Result<()> {
38 let mut inner = self.inner.write().expect("memory graph lock poisoned");
39 inner.schema = Some(schema.clone());
40 Ok(())
41 }
42
43 async fn put_node(&self, node: &Node) -> Result<PutOutcome> {
44 let mut inner = self.inner.write().expect("memory graph lock poisoned");
45 if let Some(schema) = &inner.schema {
46 schema.validate_node(node)?;
47 }
48 let previous = inner.nodes.insert(node.id.clone(), node.clone());
49 Ok(match previous {
50 Some(_) => PutOutcome::Updated,
51 None => PutOutcome::Inserted,
52 })
53 }
54
55 async fn put_edge(&self, edge: &Edge) -> Result<PutOutcome> {
56 let mut inner = self.inner.write().expect("memory graph lock poisoned");
57 if let Some(schema) = &inner.schema {
58 schema.validate_edge_with(edge, |id| inner.nodes.get(id).map(|node| &node.label))?;
59 }
60 let previous = inner.edges.insert(
61 (edge.from.clone(), edge.label.clone(), edge.to.clone()),
62 edge.clone(),
63 );
64 Ok(match previous {
65 Some(_) => PutOutcome::Updated,
66 None => PutOutcome::Inserted,
67 })
68 }
69
70 async fn put_graph(&self, graph: &Graph) -> Result<LoadReport> {
71 let mut inner = self.inner.write().expect("memory graph lock poisoned");
72 if let Some(schema) = &inner.schema {
73 schema.validate_graph(graph)?;
74 }
75 let mut report = LoadReport::default();
76 for node in &graph.nodes {
77 inner.nodes.insert(node.id.clone(), node.clone());
78 report.nodes += 1;
79 }
80 for edge in &graph.edges {
81 inner.edges.insert(
82 (edge.from.clone(), edge.label.clone(), edge.to.clone()),
83 edge.clone(),
84 );
85 report.edges += 1;
86 }
87 Ok(report)
88 }
89
90 async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
91 let inner = self.inner.read().expect("memory graph lock poisoned");
92 Ok(inner.nodes.get(id).cloned())
93 }
94
95 async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
96 let inner = self.inner.read().expect("memory graph lock poisoned");
97 Ok(inner
98 .edges
99 .values()
100 .filter(|edge| {
101 query.from.as_ref().is_none_or(|from| from == &edge.from)
102 && query.to.as_ref().is_none_or(|to| to == &edge.to)
103 && query
104 .label
105 .as_ref()
106 .is_none_or(|label| label == &edge.label)
107 })
108 .cloned()
109 .collect())
110 }
111
112 async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
113 let inner = self.inner.read().expect("memory graph lock poisoned");
114 let mut current = match traversal.start {
115 Start::Node(id) => inner
116 .nodes
117 .get(&id)
118 .cloned()
119 .into_iter()
120 .collect::<Vec<_>>(),
121 Start::NodesByLabel(label) => inner
122 .nodes
123 .values()
124 .filter(|node| node.label == label)
125 .cloned()
126 .collect(),
127 Start::NodesByProperty { label, key, value } => inner
128 .nodes
129 .values()
130 .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
131 .cloned()
132 .collect(),
133 };
134
135 for step in traversal.steps {
136 let mut next = Vec::new();
137 for node in ¤t {
138 for edge in inner.edges.values() {
139 let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
140 let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
141 && edge.from == node.id;
142 let in_matches = matches!(step.direction, Direction::In | Direction::Both)
143 && edge.to == node.id;
144
145 if !label_matches || (!out_matches && !in_matches) {
146 continue;
147 }
148
149 let target_id = if out_matches { &edge.to } else { &edge.from };
150 if let Some(target) = inner.nodes.get(target_id)
151 && step
152 .node
153 .as_ref()
154 .is_none_or(|label| label == &target.label)
155 {
156 next.push(target.clone());
157 }
158 }
159 }
160 current = next;
161 }
162
163 if let Some(limit) = traversal.limit {
164 current.truncate(limit as usize);
165 }
166 Ok(current)
167 }
168}
169
170#[async_trait]
171impl GraphMutationStore for MemoryGraphStore {
172 async fn delete_node(&self, id: &NodeId) -> Result<()> {
173 let mut inner = self.inner.write().expect("memory graph lock poisoned");
174 inner.nodes.remove(id);
175 inner
176 .edges
177 .retain(|(from, _, to), _| from != id && to != id);
178 Ok(())
179 }
180
181 async fn delete_edge(&self, from: &NodeId, label: &Label, to: &NodeId) -> Result<()> {
182 let mut inner = self.inner.write().expect("memory graph lock poisoned");
183 inner
184 .edges
185 .remove(&(from.clone(), label.clone(), to.clone()));
186 Ok(())
187 }
188}
189
// Unit tests are compiled only for test builds (`cfg(test)`) and live in the
// external `tests` submodule file.
#[cfg(test)]
mod tests;