Skip to main content

grust_memory/
lib.rs

1use std::{
2    collections::BTreeMap,
3    sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
9#[derive(Clone, Debug, Default)]
10pub struct MemoryGraphStore {
11    inner: Arc<RwLock<MemoryGraph>>,
12}
13
14#[derive(Clone, Debug, Default)]
15struct MemoryGraph {
16    nodes: BTreeMap<NodeId, Node>,
17    edges: BTreeMap<(NodeId, Label, NodeId), Edge>,
18    schema: Option<GraphSchema>,
19}
20
21impl MemoryGraphStore {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn graph(&self) -> Graph {
27        let inner = self.inner.read().expect("memory graph lock poisoned");
28        Graph {
29            nodes: inner.nodes.values().cloned().collect(),
30            edges: inner.edges.values().cloned().collect(),
31        }
32    }
33}
34
35#[async_trait]
36impl GraphStore for MemoryGraphStore {
37    async fn apply_schema(&self, schema: &GraphSchema) -> Result<()> {
38        let mut inner = self.inner.write().expect("memory graph lock poisoned");
39        inner.schema = Some(schema.clone());
40        Ok(())
41    }
42
43    async fn put_node(&self, node: &Node) -> Result<PutOutcome> {
44        let mut inner = self.inner.write().expect("memory graph lock poisoned");
45        if let Some(schema) = &inner.schema {
46            schema.validate_node(node)?;
47        }
48        let previous = inner.nodes.insert(node.id.clone(), node.clone());
49        Ok(match previous {
50            Some(_) => PutOutcome::Updated,
51            None => PutOutcome::Inserted,
52        })
53    }
54
55    async fn put_edge(&self, edge: &Edge) -> Result<PutOutcome> {
56        let mut inner = self.inner.write().expect("memory graph lock poisoned");
57        if let Some(schema) = &inner.schema {
58            schema.validate_edge_with(edge, |id| inner.nodes.get(id).map(|node| &node.label))?;
59        }
60        let previous = inner.edges.insert(
61            (edge.from.clone(), edge.label.clone(), edge.to.clone()),
62            edge.clone(),
63        );
64        Ok(match previous {
65            Some(_) => PutOutcome::Updated,
66            None => PutOutcome::Inserted,
67        })
68    }
69
70    async fn put_graph(&self, graph: &Graph) -> Result<LoadReport> {
71        let mut inner = self.inner.write().expect("memory graph lock poisoned");
72        if let Some(schema) = &inner.schema {
73            schema.validate_graph(graph)?;
74        }
75        let mut report = LoadReport::default();
76        for node in &graph.nodes {
77            inner.nodes.insert(node.id.clone(), node.clone());
78            report.nodes += 1;
79        }
80        for edge in &graph.edges {
81            inner.edges.insert(
82                (edge.from.clone(), edge.label.clone(), edge.to.clone()),
83                edge.clone(),
84            );
85            report.edges += 1;
86        }
87        Ok(report)
88    }
89
90    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
91        let inner = self.inner.read().expect("memory graph lock poisoned");
92        Ok(inner.nodes.get(id).cloned())
93    }
94
95    async fn get_nodes(&self, ids: &[NodeId]) -> Result<Vec<Node>> {
96        let inner = self.inner.read().expect("memory graph lock poisoned");
97        Ok(ids
98            .iter()
99            .filter_map(|id| inner.nodes.get(id).cloned())
100            .collect())
101    }
102
103    async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
104        let inner = self.inner.read().expect("memory graph lock poisoned");
105        Ok(inner
106            .edges
107            .values()
108            .filter(|edge| {
109                query.from.as_ref().is_none_or(|from| from == &edge.from)
110                    && query.to.as_ref().is_none_or(|to| to == &edge.to)
111                    && query
112                        .label
113                        .as_ref()
114                        .is_none_or(|label| label == &edge.label)
115            })
116            .cloned()
117            .collect())
118    }
119
120    async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
121        let inner = self.inner.read().expect("memory graph lock poisoned");
122        let mut current = match traversal.start {
123            Start::Node(id) => inner
124                .nodes
125                .get(&id)
126                .cloned()
127                .into_iter()
128                .collect::<Vec<_>>(),
129            Start::NodesByLabel(label) => inner
130                .nodes
131                .values()
132                .filter(|node| node.label == label)
133                .cloned()
134                .collect(),
135            Start::NodesByProperty { label, key, value } => inner
136                .nodes
137                .values()
138                .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
139                .cloned()
140                .collect(),
141        };
142
143        for step in traversal.steps {
144            let mut next = Vec::new();
145            for node in &current {
146                for edge in inner.edges.values() {
147                    let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
148                    let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
149                        && edge.from == node.id;
150                    let in_matches = matches!(step.direction, Direction::In | Direction::Both)
151                        && edge.to == node.id;
152
153                    if !label_matches || (!out_matches && !in_matches) {
154                        continue;
155                    }
156
157                    let target_id = if out_matches { &edge.to } else { &edge.from };
158                    if let Some(target) = inner.nodes.get(target_id)
159                        && step
160                            .node
161                            .as_ref()
162                            .is_none_or(|label| label == &target.label)
163                    {
164                        next.push(target.clone());
165                    }
166                }
167            }
168            current = next;
169        }
170
171        if let Some(limit) = traversal.limit {
172            current.truncate(limit as usize);
173        }
174        Ok(current)
175    }
176}
177
178#[async_trait]
179impl GraphMutationStore for MemoryGraphStore {
180    async fn delete_node(&self, id: &NodeId) -> Result<()> {
181        let mut inner = self.inner.write().expect("memory graph lock poisoned");
182        inner.nodes.remove(id);
183        inner
184            .edges
185            .retain(|(from, _, to), _| from != id && to != id);
186        Ok(())
187    }
188
189    async fn delete_edge(&self, from: &NodeId, label: &Label, to: &NodeId) -> Result<()> {
190        let mut inner = self.inner.write().expect("memory graph lock poisoned");
191        inner
192            .edges
193            .remove(&(from.clone(), label.clone(), to.clone()));
194        Ok(())
195    }
196}
197
198#[cfg(test)]
199mod tests;