Skip to main content

grust_memory/
lib.rs

1use std::{
2    collections::BTreeMap,
3    sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
9#[derive(Clone, Debug, Default)]
10pub struct MemoryGraphStore {
11    inner: Arc<RwLock<MemoryGraph>>,
12}
13
14#[derive(Clone, Debug, Default)]
15struct MemoryGraph {
16    nodes: BTreeMap<NodeId, Node>,
17    edges: BTreeMap<(NodeId, Label, NodeId), Edge>,
18}
19
20impl MemoryGraphStore {
21    pub fn new() -> Self {
22        Self::default()
23    }
24
25    pub fn graph(&self) -> Graph {
26        let inner = self.inner.read().expect("memory graph lock poisoned");
27        Graph {
28            nodes: inner.nodes.values().cloned().collect(),
29            edges: inner.edges.values().cloned().collect(),
30        }
31    }
32}
33
34#[async_trait]
35impl GraphStore for MemoryGraphStore {
36    async fn put_node(&self, node: &Node) -> Result<NodeId> {
37        let mut inner = self.inner.write().expect("memory graph lock poisoned");
38        inner.nodes.insert(node.id.clone(), node.clone());
39        Ok(node.id.clone())
40    }
41
42    async fn put_edge(&self, edge: &Edge) -> Result<Option<EdgeId>> {
43        let mut inner = self.inner.write().expect("memory graph lock poisoned");
44        inner.edges.insert(
45            (edge.from.clone(), edge.label.clone(), edge.to.clone()),
46            edge.clone(),
47        );
48        Ok(edge.id.clone())
49    }
50
51    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
52        let inner = self.inner.read().expect("memory graph lock poisoned");
53        Ok(inner.nodes.get(id).cloned())
54    }
55
56    async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
57        let inner = self.inner.read().expect("memory graph lock poisoned");
58        Ok(inner
59            .edges
60            .values()
61            .filter(|edge| {
62                query.from.as_ref().is_none_or(|from| from == &edge.from)
63                    && query.to.as_ref().is_none_or(|to| to == &edge.to)
64                    && query
65                        .label
66                        .as_ref()
67                        .is_none_or(|label| label == &edge.label)
68            })
69            .cloned()
70            .collect())
71    }
72
73    async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
74        let inner = self.inner.read().expect("memory graph lock poisoned");
75        let mut current = match traversal.start {
76            Start::Node(id) => inner
77                .nodes
78                .get(&id)
79                .cloned()
80                .into_iter()
81                .collect::<Vec<_>>(),
82            Start::NodesByLabel(label) => inner
83                .nodes
84                .values()
85                .filter(|node| node.label == label)
86                .cloned()
87                .collect(),
88            Start::NodesByProperty { label, key, value } => inner
89                .nodes
90                .values()
91                .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
92                .cloned()
93                .collect(),
94        };
95
96        for step in traversal.steps {
97            let mut next = Vec::new();
98            for node in &current {
99                for edge in inner.edges.values() {
100                    let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
101                    let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
102                        && edge.from == node.id;
103                    let in_matches = matches!(step.direction, Direction::In | Direction::Both)
104                        && edge.to == node.id;
105
106                    if !label_matches || (!out_matches && !in_matches) {
107                        continue;
108                    }
109
110                    let target_id = if out_matches { &edge.to } else { &edge.from };
111                    if let Some(target) = inner.nodes.get(target_id) {
112                        if step
113                            .node
114                            .as_ref()
115                            .is_none_or(|label| label == &target.label)
116                        {
117                            next.push(target.clone());
118                        }
119                    }
120                }
121            }
122            current = next;
123        }
124
125        if let Some(limit) = traversal.limit {
126            current.truncate(limit as usize);
127        }
128        Ok(current)
129    }
130}
131
/// Unit tests for this crate, compiled only for `cfg(test)` builds.
#[cfg(test)]
mod tests;