Skip to main content

grust_memory/
lib.rs

1use std::{
2    collections::BTreeMap,
3    sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
9#[derive(Clone, Debug, Default)]
10pub struct MemoryGraphStore {
11    inner: Arc<RwLock<MemoryGraph>>,
12}
13
14#[derive(Clone, Debug, Default)]
15struct MemoryGraph {
16    nodes: BTreeMap<NodeId, Node>,
17    edges: BTreeMap<(NodeId, Label, NodeId), Edge>,
18    schema: Option<GraphSchema>,
19}
20
21impl MemoryGraphStore {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    pub fn graph(&self) -> Graph {
27        let inner = self.inner.read().expect("memory graph lock poisoned");
28        Graph {
29            nodes: inner.nodes.values().cloned().collect(),
30            edges: inner.edges.values().cloned().collect(),
31        }
32    }
33}
34
35#[async_trait]
36impl GraphStore for MemoryGraphStore {
37    async fn apply_schema(&self, schema: &GraphSchema) -> Result<()> {
38        let mut inner = self.inner.write().expect("memory graph lock poisoned");
39        inner.schema = Some(schema.clone());
40        Ok(())
41    }
42
43    async fn put_node(&self, node: &Node) -> Result<NodeId> {
44        let mut inner = self.inner.write().expect("memory graph lock poisoned");
45        if let Some(schema) = &inner.schema {
46            schema.validate_node(node)?;
47        }
48        inner.nodes.insert(node.id.clone(), node.clone());
49        Ok(node.id.clone())
50    }
51
52    async fn put_edge(&self, edge: &Edge) -> Result<Option<EdgeId>> {
53        let mut inner = self.inner.write().expect("memory graph lock poisoned");
54        if let Some(schema) = &inner.schema {
55            let mut graph = Graph {
56                nodes: inner.nodes.values().cloned().collect(),
57                edges: inner.edges.values().cloned().collect(),
58            };
59            graph.edges.push(edge.clone());
60            schema.validate_edge(edge, &graph)?;
61        }
62        inner.edges.insert(
63            (edge.from.clone(), edge.label.clone(), edge.to.clone()),
64            edge.clone(),
65        );
66        Ok(edge.id.clone())
67    }
68
69    async fn put_graph(&self, graph: &Graph) -> Result<LoadReport> {
70        let mut inner = self.inner.write().expect("memory graph lock poisoned");
71        if let Some(schema) = &inner.schema {
72            schema.validate_graph(graph)?;
73        }
74        let mut report = LoadReport::default();
75        for node in &graph.nodes {
76            inner.nodes.insert(node.id.clone(), node.clone());
77            report.nodes += 1;
78        }
79        for edge in &graph.edges {
80            inner.edges.insert(
81                (edge.from.clone(), edge.label.clone(), edge.to.clone()),
82                edge.clone(),
83            );
84            report.edges += 1;
85        }
86        Ok(report)
87    }
88
89    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
90        let inner = self.inner.read().expect("memory graph lock poisoned");
91        Ok(inner.nodes.get(id).cloned())
92    }
93
94    async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
95        let inner = self.inner.read().expect("memory graph lock poisoned");
96        Ok(inner
97            .edges
98            .values()
99            .filter(|edge| {
100                query.from.as_ref().is_none_or(|from| from == &edge.from)
101                    && query.to.as_ref().is_none_or(|to| to == &edge.to)
102                    && query
103                        .label
104                        .as_ref()
105                        .is_none_or(|label| label == &edge.label)
106            })
107            .cloned()
108            .collect())
109    }
110
111    async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
112        let inner = self.inner.read().expect("memory graph lock poisoned");
113        let mut current = match traversal.start {
114            Start::Node(id) => inner
115                .nodes
116                .get(&id)
117                .cloned()
118                .into_iter()
119                .collect::<Vec<_>>(),
120            Start::NodesByLabel(label) => inner
121                .nodes
122                .values()
123                .filter(|node| node.label == label)
124                .cloned()
125                .collect(),
126            Start::NodesByProperty { label, key, value } => inner
127                .nodes
128                .values()
129                .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
130                .cloned()
131                .collect(),
132        };
133
134        for step in traversal.steps {
135            let mut next = Vec::new();
136            for node in &current {
137                for edge in inner.edges.values() {
138                    let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
139                    let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
140                        && edge.from == node.id;
141                    let in_matches = matches!(step.direction, Direction::In | Direction::Both)
142                        && edge.to == node.id;
143
144                    if !label_matches || (!out_matches && !in_matches) {
145                        continue;
146                    }
147
148                    let target_id = if out_matches { &edge.to } else { &edge.from };
149                    if let Some(target) = inner.nodes.get(target_id) {
150                        if step
151                            .node
152                            .as_ref()
153                            .is_none_or(|label| label == &target.label)
154                        {
155                            next.push(target.clone());
156                        }
157                    }
158                }
159            }
160            current = next;
161        }
162
163        if let Some(limit) = traversal.limit {
164            current.truncate(limit as usize);
165        }
166        Ok(current)
167    }
168}
169
// Unit tests for this crate, loaded from `tests.rs` (or `tests/mod.rs`) beside this file;
// compiled only when building with `cfg(test)`.
#[cfg(test)]
mod tests;