Skip to main content

grust_memory/
lib.rs

1use std::{
2    collections::BTreeMap,
3    sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
9#[derive(Clone, Debug, Default)]
10pub struct MemoryGraphStore {
11    inner: Arc<RwLock<MemoryGraph>>,
12}
13
14#[derive(Clone, Debug, Default)]
15struct MemoryGraph {
16    nodes: BTreeMap<NodeId, Node>,
17    edges: BTreeMap<(NodeId, Label, NodeId), Edge>,
18    schema: Option<GraphSchema>,
19}
20
21impl MemoryGraphStore {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn graph(&self) -> Graph {
27        let inner = self.inner.read().expect("memory graph lock poisoned");
28        Graph {
29            nodes: inner.nodes.values().cloned().collect(),
30            edges: inner.edges.values().cloned().collect(),
31        }
32    }
33}
34
35#[async_trait]
36impl GraphStore for MemoryGraphStore {
37    async fn apply_schema(&self, schema: &GraphSchema) -> Result<()> {
38        let mut inner = self.inner.write().expect("memory graph lock poisoned");
39        inner.schema = Some(schema.clone());
40        Ok(())
41    }
42
43    async fn put_node(&self, node: &Node) -> Result<PutOutcome> {
44        let mut inner = self.inner.write().expect("memory graph lock poisoned");
45        if let Some(schema) = &inner.schema {
46            schema.validate_node(node)?;
47        }
48        let previous = inner.nodes.insert(node.id.clone(), node.clone());
49        Ok(match previous {
50            Some(_) => PutOutcome::Updated,
51            None => PutOutcome::Inserted,
52        })
53    }
54
55    async fn put_edge(&self, edge: &Edge) -> Result<PutOutcome> {
56        let mut inner = self.inner.write().expect("memory graph lock poisoned");
57        if let Some(schema) = &inner.schema {
58            schema.validate_edge_with(edge, |id| inner.nodes.get(id).map(|node| &node.label))?;
59        }
60        let previous = inner.edges.insert(
61            (edge.from.clone(), edge.label.clone(), edge.to.clone()),
62            edge.clone(),
63        );
64        Ok(match previous {
65            Some(_) => PutOutcome::Updated,
66            None => PutOutcome::Inserted,
67        })
68    }
69
70    async fn put_graph(&self, graph: &Graph) -> Result<LoadReport> {
71        let mut inner = self.inner.write().expect("memory graph lock poisoned");
72        if let Some(schema) = &inner.schema {
73            schema.validate_graph(graph)?;
74        }
75        let mut report = LoadReport::default();
76        for node in &graph.nodes {
77            inner.nodes.insert(node.id.clone(), node.clone());
78            report.nodes += 1;
79        }
80        for edge in &graph.edges {
81            inner.edges.insert(
82                (edge.from.clone(), edge.label.clone(), edge.to.clone()),
83                edge.clone(),
84            );
85            report.edges += 1;
86        }
87        Ok(report)
88    }
89
90    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
91        let inner = self.inner.read().expect("memory graph lock poisoned");
92        Ok(inner.nodes.get(id).cloned())
93    }
94
95    async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
96        let inner = self.inner.read().expect("memory graph lock poisoned");
97        Ok(inner
98            .edges
99            .values()
100            .filter(|edge| {
101                query.from.as_ref().is_none_or(|from| from == &edge.from)
102                    && query.to.as_ref().is_none_or(|to| to == &edge.to)
103                    && query
104                        .label
105                        .as_ref()
106                        .is_none_or(|label| label == &edge.label)
107            })
108            .cloned()
109            .collect())
110    }
111
112    async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
113        let inner = self.inner.read().expect("memory graph lock poisoned");
114        let mut current = match traversal.start {
115            Start::Node(id) => inner
116                .nodes
117                .get(&id)
118                .cloned()
119                .into_iter()
120                .collect::<Vec<_>>(),
121            Start::NodesByLabel(label) => inner
122                .nodes
123                .values()
124                .filter(|node| node.label == label)
125                .cloned()
126                .collect(),
127            Start::NodesByProperty { label, key, value } => inner
128                .nodes
129                .values()
130                .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
131                .cloned()
132                .collect(),
133        };
134
135        for step in traversal.steps {
136            let mut next = Vec::new();
137            for node in &current {
138                for edge in inner.edges.values() {
139                    let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
140                    let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
141                        && edge.from == node.id;
142                    let in_matches = matches!(step.direction, Direction::In | Direction::Both)
143                        && edge.to == node.id;
144
145                    if !label_matches || (!out_matches && !in_matches) {
146                        continue;
147                    }
148
149                    let target_id = if out_matches { &edge.to } else { &edge.from };
150                    if let Some(target) = inner.nodes.get(target_id)
151                        && step
152                            .node
153                            .as_ref()
154                            .is_none_or(|label| label == &target.label)
155                    {
156                        next.push(target.clone());
157                    }
158                }
159            }
160            current = next;
161        }
162
163        if let Some(limit) = traversal.limit {
164            current.truncate(limit as usize);
165        }
166        Ok(current)
167    }
168}
169
170#[async_trait]
171impl GraphMutationStore for MemoryGraphStore {
172    async fn delete_node(&self, id: &NodeId) -> Result<()> {
173        let mut inner = self.inner.write().expect("memory graph lock poisoned");
174        inner.nodes.remove(id);
175        inner
176            .edges
177            .retain(|(from, _, to), _| from != id && to != id);
178        Ok(())
179    }
180
181    async fn delete_edge(&self, from: &NodeId, label: &Label, to: &NodeId) -> Result<()> {
182        let mut inner = self.inner.write().expect("memory graph lock poisoned");
183        inner
184            .edges
185            .remove(&(from.clone(), label.clone(), to.clone()));
186        Ok(())
187    }
188}
189
190#[cfg(test)]
191mod tests;