Skip to main content

grust_memory/
lib.rs

1use std::{
2    collections::BTreeMap,
3    sync::{Arc, RwLock},
4};
5
6use async_trait::async_trait;
7use grust_core::prelude::*;
8
9#[derive(Clone, Debug, Default)]
10pub struct MemoryGraphStore {
11    inner: Arc<RwLock<MemoryGraph>>,
12}
13
14#[derive(Clone, Debug, Default)]
15struct MemoryGraph {
16    nodes: BTreeMap<NodeId, Node>,
17    edges: BTreeMap<MemoryEdgeKey, Edge>,
18    schema: Option<GraphSchema>,
19}
20
21#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
22struct MemoryEdgeKey {
23    from: NodeId,
24    label: Label,
25    to: NodeId,
26    id: Option<EdgeId>,
27}
28
29impl MemoryEdgeKey {
30    fn new(from: NodeId, label: Label, to: NodeId, id: Option<EdgeId>) -> Self {
31        Self {
32            from,
33            label,
34            to,
35            id,
36        }
37    }
38
39    fn from_edge(edge: &Edge) -> Self {
40        Self::new(
41            edge.from.clone(),
42            edge.label.clone(),
43            edge.to.clone(),
44            edge.id.clone(),
45        )
46    }
47}
48
49impl MemoryGraphStore {
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    pub fn graph(&self) -> Graph {
55        let inner = self.inner.read().expect("memory graph lock poisoned");
56        Graph {
57            nodes: inner.nodes.values().cloned().collect(),
58            edges: inner.edges.values().cloned().collect(),
59        }
60    }
61
62    fn node_matches(
63        node: &Node,
64        label: Option<&Label>,
65        props: &Props,
66        predicates: &[GraphPropertyPredicate],
67    ) -> bool {
68        label.is_none_or(|label| &node.label == label)
69            && props.iter().all(|(key, value)| {
70                if key == "id" {
71                    value.as_str().is_some_and(|id| node.id.as_str() == id)
72                } else {
73                    node.props.get(key) == Some(value)
74                }
75            })
76            && predicates
77                .iter()
78                .all(|predicate| predicate.matches(node.props.get(&predicate.key)))
79    }
80
81    fn matching_node_ids(
82        inner: &MemoryGraph,
83        label: Option<&Label>,
84        props: &Props,
85        predicates: &[GraphPropertyPredicate],
86    ) -> Vec<NodeId> {
87        inner
88            .nodes
89            .values()
90            .filter(|node| Self::node_matches(node, label, props, predicates))
91            .map(|node| node.id.clone())
92            .collect()
93    }
94
95    fn relationship_matches(
96        inner: &MemoryGraph,
97        edge: &Edge,
98        relationship: &GraphRelationshipMatch,
99    ) -> bool {
100        if edge.label != relationship.label {
101            return false;
102        }
103        if relationship
104            .id
105            .as_ref()
106            .is_some_and(|id| edge.id.as_ref() != Some(id))
107        {
108            return false;
109        }
110        if !relationship
111            .props
112            .iter()
113            .all(|(key, value)| edge.props.get(key) == Some(value))
114        {
115            return false;
116        }
117        if !relationship
118            .predicates
119            .iter()
120            .all(|predicate| predicate.matches(edge.props.get(&predicate.key)))
121        {
122            return false;
123        }
124        let Some(from) = inner.nodes.get(&edge.from) else {
125            return false;
126        };
127        let Some(to) = inner.nodes.get(&edge.to) else {
128            return false;
129        };
130        Self::node_matches(
131            from,
132            relationship.from.label.as_ref(),
133            &relationship.from.props,
134            &relationship.from.predicates,
135        ) && Self::node_matches(
136            to,
137            relationship.to.label.as_ref(),
138            &relationship.to.props,
139            &relationship.to.predicates,
140        )
141    }
142
143    fn matching_edges(inner: &MemoryGraph, relationship: &GraphRelationshipMatch) -> Vec<Edge> {
144        inner
145            .edges
146            .values()
147            .filter(|edge| Self::relationship_matches(inner, edge, relationship))
148            .cloned()
149            .collect()
150    }
151
152    fn graph_snapshot(inner: &MemoryGraph) -> Graph {
153        Graph {
154            nodes: inner.nodes.values().cloned().collect(),
155            edges: inner.edges.values().cloned().collect(),
156        }
157    }
158
159    fn graph_snapshot_with_node(inner: &MemoryGraph, node: &Node) -> Graph {
160        let mut graph = Self::graph_snapshot(inner);
161        if let Some(existing) = graph
162            .nodes
163            .iter_mut()
164            .find(|existing| existing.id == node.id)
165        {
166            *existing = node.clone();
167        } else {
168            graph.nodes.push(node.clone());
169        }
170        graph
171    }
172
173    fn graph_snapshot_with_edge(inner: &MemoryGraph, edge: &Edge) -> Graph {
174        let mut graph = Self::graph_snapshot(inner);
175        let key = MemoryEdgeKey::from_edge(edge);
176        if let Some(existing) = graph
177            .edges
178            .iter_mut()
179            .find(|existing| MemoryEdgeKey::from_edge(existing) == key)
180        {
181            *existing = edge.clone();
182        } else {
183            graph.edges.push(edge.clone());
184        }
185        graph
186    }
187
188    fn graph_snapshot_with_graph(inner: &MemoryGraph, input: &Graph) -> Graph {
189        let mut nodes = inner.nodes.clone();
190        let mut edges = inner.edges.clone();
191        for node in &input.nodes {
192            nodes.insert(node.id.clone(), node.clone());
193        }
194        for edge in &input.edges {
195            edges.insert(MemoryEdgeKey::from_edge(edge), edge.clone());
196        }
197        Graph {
198            nodes: nodes.into_values().collect(),
199            edges: edges.into_values().collect(),
200        }
201    }
202}
203
204#[async_trait]
205impl GraphStore for MemoryGraphStore {
206    async fn apply_schema(&self, schema: &GraphSchema) -> Result<()> {
207        let mut inner = self.inner.write().expect("memory graph lock poisoned");
208        schema.validate_graph(&Self::graph_snapshot(&inner))?;
209        inner.schema = Some(schema.clone());
210        Ok(())
211    }
212
213    fn constraint_capability(&self, constraint: &GraphConstraint) -> GraphConstraintCapability {
214        match constraint {
215            GraphConstraint::NodePropertyRequired { .. }
216            | GraphConstraint::EdgePropertyRequired { .. }
217            | GraphConstraint::NodePropertyUnique { .. }
218            | GraphConstraint::EdgePropertyUnique { .. } => {
219                GraphConstraintCapability::ValidateBeforeWrite
220            }
221        }
222    }
223
224    async fn put_node(&self, node: &Node) -> Result<PutOutcome> {
225        let mut inner = self.inner.write().expect("memory graph lock poisoned");
226        if let Some(schema) = &inner.schema {
227            schema.validate_graph(&Self::graph_snapshot_with_node(&inner, node))?;
228        }
229        let previous = inner.nodes.insert(node.id.clone(), node.clone());
230        Ok(match previous {
231            Some(_) => PutOutcome::Updated,
232            None => PutOutcome::Inserted,
233        })
234    }
235
236    async fn put_edge(&self, edge: &Edge) -> Result<PutOutcome> {
237        let mut inner = self.inner.write().expect("memory graph lock poisoned");
238        if let Some(schema) = &inner.schema {
239            schema.validate_graph(&Self::graph_snapshot_with_edge(&inner, edge))?;
240        }
241        let previous = inner
242            .edges
243            .insert(MemoryEdgeKey::from_edge(edge), edge.clone());
244        Ok(match previous {
245            Some(_) => PutOutcome::Updated,
246            None => PutOutcome::Inserted,
247        })
248    }
249
250    async fn put_graph(&self, graph: &Graph) -> Result<LoadReport> {
251        let mut inner = self.inner.write().expect("memory graph lock poisoned");
252        if let Some(schema) = &inner.schema {
253            schema.validate_graph(&Self::graph_snapshot_with_graph(&inner, graph))?;
254        }
255        let mut report = LoadReport::default();
256        for node in &graph.nodes {
257            inner.nodes.insert(node.id.clone(), node.clone());
258            report.nodes += 1;
259        }
260        for edge in &graph.edges {
261            inner
262                .edges
263                .insert(MemoryEdgeKey::from_edge(edge), edge.clone());
264            report.edges += 1;
265        }
266        Ok(report)
267    }
268
269    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
270        let inner = self.inner.read().expect("memory graph lock poisoned");
271        Ok(inner.nodes.get(id).cloned())
272    }
273
274    async fn get_nodes(&self, ids: &[NodeId]) -> Result<Vec<Node>> {
275        let inner = self.inner.read().expect("memory graph lock poisoned");
276        Ok(ids
277            .iter()
278            .filter_map(|id| inner.nodes.get(id).cloned())
279            .collect())
280    }
281
282    async fn get_edges(&self, query: EdgeQuery) -> Result<Vec<Edge>> {
283        let inner = self.inner.read().expect("memory graph lock poisoned");
284        Ok(inner
285            .edges
286            .values()
287            .filter(|edge| {
288                query.from.as_ref().is_none_or(|from| from == &edge.from)
289                    && query.to.as_ref().is_none_or(|to| to == &edge.to)
290                    && query
291                        .label
292                        .as_ref()
293                        .is_none_or(|label| label == &edge.label)
294            })
295            .cloned()
296            .collect())
297    }
298
299    async fn traverse(&self, traversal: Traversal) -> Result<Vec<Node>> {
300        let inner = self.inner.read().expect("memory graph lock poisoned");
301        let mut current = match traversal.start {
302            Start::Node(id) => inner
303                .nodes
304                .get(&id)
305                .cloned()
306                .into_iter()
307                .collect::<Vec<_>>(),
308            Start::NodesByLabel(label) => inner
309                .nodes
310                .values()
311                .filter(|node| node.label == label)
312                .cloned()
313                .collect(),
314            Start::NodesByProperty { label, key, value } => inner
315                .nodes
316                .values()
317                .filter(|node| node.label == label && node.props.get(&key) == Some(&value))
318                .cloned()
319                .collect(),
320        };
321
322        for step in traversal.steps {
323            let mut next = Vec::new();
324            for node in &current {
325                for edge in inner.edges.values() {
326                    let label_matches = step.edge.as_ref().is_none_or(|label| label == &edge.label);
327                    let out_matches = matches!(step.direction, Direction::Out | Direction::Both)
328                        && edge.from == node.id;
329                    let in_matches = matches!(step.direction, Direction::In | Direction::Both)
330                        && edge.to == node.id;
331
332                    if !label_matches || (!out_matches && !in_matches) {
333                        continue;
334                    }
335
336                    let target_id = if out_matches { &edge.to } else { &edge.from };
337                    if let Some(target) = inner.nodes.get(target_id)
338                        && step
339                            .node
340                            .as_ref()
341                            .is_none_or(|label| label == &target.label)
342                    {
343                        next.push(target.clone());
344                    }
345                }
346            }
347            current = next;
348        }
349
350        if let Some(limit) = traversal.limit {
351            current.truncate(limit as usize);
352        }
353        Ok(current)
354    }
355}
356
357#[async_trait]
358impl GraphMutationStore for MemoryGraphStore {
359    async fn delete_node(&self, id: &NodeId) -> Result<()> {
360        let mut inner = self.inner.write().expect("memory graph lock poisoned");
361        inner.nodes.remove(id);
362        inner
363            .edges
364            .retain(|key, _| key.from != *id && key.to != *id);
365        Ok(())
366    }
367
368    async fn delete_edge(&self, from: &NodeId, label: &Label, to: &NodeId) -> Result<()> {
369        let mut inner = self.inner.write().expect("memory graph lock poisoned");
370        inner
371            .edges
372            .retain(|key, _| key.from != *from || key.label != *label || key.to != *to);
373        Ok(())
374    }
375}
376
377#[async_trait]
378impl CypherMutationExecutor for MemoryGraphStore {
379    async fn execute_cypher_mutation_plan(
380        &self,
381        plan: &GraphMutationPlan,
382    ) -> Result<GraphMutationReport> {
383        let mut report = plan.report();
384        for operation in &plan.operations {
385            match operation {
386                GraphMutationPlanOp::PatchMatchingNodes {
387                    label,
388                    props,
389                    predicates,
390                    patch,
391                    ..
392                } => {
393                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
394                    let ids = Self::matching_node_ids(&inner, label.as_ref(), props, predicates);
395                    report.matched_rows += ids.len();
396                    report.node_patches += ids.len();
397                    report.changed_nodes += ids.len();
398
399                    let mut patched = Vec::with_capacity(ids.len());
400                    for id in &ids {
401                        if let Some(node) = inner.nodes.get(id) {
402                            let mut node = node.clone();
403                            for (key, value) in patch {
404                                node.props.insert(key.clone(), value.clone());
405                            }
406                            if let Some(schema) = &inner.schema {
407                                schema.validate_node(&node)?;
408                            }
409                            patched.push(node);
410                        }
411                    }
412                    for node in patched {
413                        inner.nodes.insert(node.id.clone(), node);
414                    }
415                }
416                GraphMutationPlanOp::UpdateMatchingNodeProperty {
417                    label,
418                    props,
419                    predicates,
420                    target_key,
421                    source_key,
422                    op,
423                    operand,
424                    ..
425                } => {
426                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
427                    let ids = Self::matching_node_ids(&inner, label.as_ref(), props, predicates);
428                    report.matched_rows += ids.len();
429                    report.node_patches += ids.len();
430                    report.changed_nodes += ids.len();
431
432                    let mut updated = Vec::with_capacity(ids.len());
433                    for id in &ids {
434                        if let Some(node) = inner.nodes.get(id) {
435                            let mut node = node.clone();
436                            let current = node.props.get(source_key).ok_or_else(|| {
437                                GrustError::CypherExecution(format!(
438                                    "numeric expression source property '{source_key}' is missing"
439                                ))
440                            })?;
441                            let value = evaluate_numeric_update(current, *op, operand)?;
442                            node.props.insert(target_key.clone(), value);
443                            if let Some(schema) = &inner.schema {
444                                schema.validate_node(&node)?;
445                            }
446                            updated.push(node);
447                        }
448                    }
449                    for node in updated {
450                        inner.nodes.insert(node.id.clone(), node);
451                    }
452                }
453                GraphMutationPlanOp::RemoveMatchingNodeProps {
454                    label,
455                    props,
456                    predicates,
457                    keys,
458                    ..
459                } => {
460                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
461                    let ids = Self::matching_node_ids(&inner, label.as_ref(), props, predicates);
462                    report.matched_rows += ids.len();
463                    report.node_property_removes += ids.len();
464                    report.changed_nodes += ids.len();
465
466                    let mut updated = Vec::with_capacity(ids.len());
467                    for id in &ids {
468                        if let Some(node) = inner.nodes.get(id) {
469                            let mut node = node.clone();
470                            for key in keys {
471                                node.props.remove(key);
472                            }
473                            if let Some(schema) = &inner.schema {
474                                schema.validate_node(&node)?;
475                            }
476                            updated.push(node);
477                        }
478                    }
479                    for node in updated {
480                        inner.nodes.insert(node.id.clone(), node);
481                    }
482                }
483                GraphMutationPlanOp::DeleteMatchingNodes {
484                    label,
485                    props,
486                    predicates,
487                    ..
488                } => {
489                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
490                    let ids = Self::matching_node_ids(&inner, label.as_ref(), props, predicates);
491                    let incident_edges = inner
492                        .edges
493                        .keys()
494                        .filter(|key| ids.iter().any(|id| id == &key.from || id == &key.to))
495                        .count();
496
497                    report.matched_rows += ids.len();
498                    report.node_deletes += ids.len();
499                    report.changed_nodes += ids.len();
500                    report.edge_deletes += incident_edges;
501                    report.changed_edges += incident_edges;
502
503                    for id in &ids {
504                        inner.nodes.remove(id);
505                    }
506                    inner
507                        .edges
508                        .retain(|key, _| !ids.iter().any(|id| id == &key.from || id == &key.to));
509                }
510                GraphMutationPlanOp::PatchMatchingEdges {
511                    relationship,
512                    patch,
513                    ..
514                } => {
515                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
516                    let edges = Self::matching_edges(&inner, relationship);
517                    report.matched_rows += edges.len();
518                    report.edge_patches += edges.len();
519                    report.changed_edges += edges.len();
520
521                    let mut patched = Vec::with_capacity(edges.len());
522                    for mut edge in edges {
523                        for (key, value) in patch {
524                            edge.props.insert(key.clone(), value.clone());
525                        }
526                        if let Some(schema) = &inner.schema {
527                            schema.validate_edge_with(&edge, |id| {
528                                inner.nodes.get(id).map(|node| &node.label)
529                            })?;
530                        }
531                        patched.push(edge);
532                    }
533                    for edge in patched {
534                        inner.edges.insert(MemoryEdgeKey::from_edge(&edge), edge);
535                    }
536                }
537                GraphMutationPlanOp::UpdateMatchingEdgeProperty {
538                    relationship,
539                    target_key,
540                    source_key,
541                    op,
542                    operand,
543                    ..
544                } => {
545                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
546                    let edges = Self::matching_edges(&inner, relationship);
547                    report.matched_rows += edges.len();
548                    report.edge_patches += edges.len();
549                    report.changed_edges += edges.len();
550
551                    let mut updated = Vec::with_capacity(edges.len());
552                    for mut edge in edges {
553                        let current = edge.props.get(source_key).ok_or_else(|| {
554                            GrustError::CypherExecution(format!(
555                                "numeric expression source property '{source_key}' is missing"
556                            ))
557                        })?;
558                        let value = evaluate_numeric_update(current, *op, operand)?;
559                        edge.props.insert(target_key.clone(), value);
560                        if let Some(schema) = &inner.schema {
561                            schema.validate_edge_with(&edge, |id| {
562                                inner.nodes.get(id).map(|node| &node.label)
563                            })?;
564                        }
565                        updated.push(edge);
566                    }
567                    for edge in updated {
568                        inner.edges.insert(MemoryEdgeKey::from_edge(&edge), edge);
569                    }
570                }
571                GraphMutationPlanOp::RemoveMatchingEdgeProps {
572                    relationship, keys, ..
573                } => {
574                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
575                    let edges = Self::matching_edges(&inner, relationship);
576                    report.matched_rows += edges.len();
577                    report.edge_property_removes += edges.len();
578                    report.changed_edges += edges.len();
579
580                    let mut updated = Vec::with_capacity(edges.len());
581                    for mut edge in edges {
582                        for key in keys {
583                            edge.props.remove(key);
584                        }
585                        if let Some(schema) = &inner.schema {
586                            schema.validate_edge_with(&edge, |id| {
587                                inner.nodes.get(id).map(|node| &node.label)
588                            })?;
589                        }
590                        updated.push(edge);
591                    }
592                    for edge in updated {
593                        inner.edges.insert(MemoryEdgeKey::from_edge(&edge), edge);
594                    }
595                }
596                GraphMutationPlanOp::DeleteMatchingEdges { relationship, .. } => {
597                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
598                    let edges = Self::matching_edges(&inner, relationship);
599                    report.matched_rows += edges.len();
600                    report.edge_deletes += edges.len();
601                    report.changed_edges += edges.len();
602                    for edge in edges {
603                        inner.edges.remove(&MemoryEdgeKey::from_edge(&edge));
604                    }
605                }
606                GraphMutationPlanOp::UpsertEdgesFromNodeMatches {
607                    kind,
608                    from,
609                    to,
610                    label,
611                    props,
612                    edge_id_policy,
613                    ..
614                } => {
615                    let mut inner = self.inner.write().expect("memory graph lock poisoned");
616                    let from_ids = Self::matching_node_ids(
617                        &inner,
618                        from.label.as_ref(),
619                        &from.props,
620                        &from.predicates,
621                    );
622                    let to_ids = Self::matching_node_ids(
623                        &inner,
624                        to.label.as_ref(),
625                        &to.props,
626                        &to.predicates,
627                    );
628                    let matched_rows = from_ids.len().saturating_mul(to_ids.len());
629                    report.matched_rows += matched_rows;
630                    report.edge_upserts += matched_rows;
631                    report.changed_edges += matched_rows;
632                    let explicit_edge_id = explicit_edge_id_from_props(props)?;
633                    if explicit_edge_id.is_some() && matched_rows > 1 {
634                        return Err(GrustError::CypherUnsupportedCardinality(
635                            "row-producing MATCH ... CREATE/MERGE with an explicit relationship id must produce exactly one edge".to_string(),
636                        ));
637                    }
638
639                    let mut edges = Vec::with_capacity(matched_rows);
640                    for from_id in &from_ids {
641                        for to_id in &to_ids {
642                            let mut edge = Edge::new(
643                                label.clone(),
644                                from_id.clone(),
645                                to_id.clone(),
646                                props.clone(),
647                            );
648                            if let Some(id) = explicit_edge_id.clone() {
649                                edge = edge.with_id(id);
650                            } else if row_edge_id_policy_generates(*kind, *edge_id_policy) {
651                                edge = edge
652                                    .with_id(generated_row_edge_id(from_id, label, to_id, props));
653                            }
654                            if let Some(schema) = &inner.schema {
655                                schema.validate_edge_with(&edge, |id| {
656                                    inner.nodes.get(id).map(|node| &node.label)
657                                })?;
658                            }
659                            edges.push(edge);
660                        }
661                    }
662                    for edge in edges {
663                        let previous = inner.edges.insert(MemoryEdgeKey::from_edge(&edge), edge);
664                        if previous.is_some() {
665                            report.edge_updates += 1;
666                        } else {
667                            report.edge_inserts += 1;
668                        }
669                    }
670                }
671                GraphMutationPlanOp::UpsertNode { node, .. } => {
672                    classify_node_upsert(self.put_node(node).await?, &mut report);
673                }
674                GraphMutationPlanOp::UpsertEdge { edge, .. } => {
675                    classify_edge_upsert(self.put_edge(edge).await?, &mut report);
676                }
677                _ => {
678                    let mutation = GraphMutation::from(operation.clone());
679                    self.apply_mutations(std::slice::from_ref(&mutation))
680                        .await?;
681                }
682            }
683        }
684        Ok(report)
685    }
686}
687
688fn explicit_edge_id_from_props(props: &Props) -> Result<Option<String>> {
689    match props.get("id") {
690        Some(Value::String(id)) => Ok(Some(id.clone())),
691        Some(_) => Err(GrustError::CypherSyntax(
692            "relationship id property must be a string literal".to_string(),
693        )),
694        None => Ok(None),
695    }
696}
697
698fn row_edge_id_policy_generates(kind: GraphMutationPlanKind, policy: GraphRowEdgeIdPolicy) -> bool {
699    matches!(
700        (kind, policy),
701        (
702            GraphMutationPlanKind::Create,
703            GraphRowEdgeIdPolicy::GenerateForCreate
704                | GraphRowEdgeIdPolicy::GenerateForCreateAndMerge
705        ) | (
706            GraphMutationPlanKind::Merge,
707            GraphRowEdgeIdPolicy::GenerateForCreateAndMerge
708        )
709    )
710}
711
712#[cfg(test)]
713mod tests;