Skip to main content

contextdb_graph/
mem.rs

1use crate::store::GraphStore;
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use std::collections::{HashSet, VecDeque};
5use std::sync::Arc;
6
7const MAX_VISITED: usize = 100_000;
8type AdjPair = (NodeId, NodeId);
9
10pub struct MemGraphExecutor<S: WriteSetApplicator> {
11    store: Arc<GraphStore>,
12    tx_mgr: Arc<TxManager<S>>,
13    dag_edge_types: parking_lot::RwLock<HashSet<String>>,
14}
15
16impl<S: WriteSetApplicator> MemGraphExecutor<S> {
17    pub fn new(store: Arc<GraphStore>, tx_mgr: Arc<TxManager<S>>) -> Self {
18        Self {
19            store,
20            tx_mgr,
21            dag_edge_types: parking_lot::RwLock::new(HashSet::new()),
22        }
23    }
24
25    pub fn register_dag_edge_types(&self, types: &[String]) {
26        let mut set = self.dag_edge_types.write();
27        for t in types {
28            set.insert(t.clone());
29        }
30    }
31
32    pub fn edge_count(&self, source: NodeId, edge_type: &str, snapshot: SnapshotId) -> usize {
33        let fwd = self.store.forward_adj.read();
34        fwd.get(&source)
35            .map(|entries| {
36                entries
37                    .iter()
38                    .filter(|entry| entry.edge_type == edge_type && entry.visible_at(snapshot))
39                    .count()
40            })
41            .unwrap_or(0)
42    }
43
44    fn bfs_with_write_set(
45        &self,
46        tx: TxId,
47        start: NodeId,
48        goal: NodeId,
49        edge_type: &str,
50    ) -> Result<bool> {
51        let mut visited = HashSet::new();
52        let mut queue = VecDeque::new();
53        visited.insert(start);
54        queue.push_back(start);
55
56        let (ws_inserts, ws_deletes): (Vec<AdjPair>, HashSet<AdjPair>) =
57            self.tx_mgr.with_write_set(tx, |ws| {
58                let inserts = ws
59                    .adj_inserts
60                    .iter()
61                    .filter(|e| e.edge_type == edge_type)
62                    .map(|e| (e.source, e.target))
63                    .collect();
64                let deletes = ws
65                    .adj_deletes
66                    .iter()
67                    .filter(|(_, et, _, _)| et == edge_type)
68                    .map(|(s, _, t, _)| (*s, *t))
69                    .collect();
70                (inserts, deletes)
71            })?;
72
73        while let Some(current) = queue.pop_front() {
74            {
75                let fwd = self.store.forward_adj.read();
76                if let Some(entries) = fwd.get(&current) {
77                    for e in entries {
78                        if e.edge_type != edge_type || e.deleted_tx.is_some() {
79                            continue;
80                        }
81                        if ws_deletes.contains(&(e.source, e.target)) {
82                            continue;
83                        }
84                        if e.target == goal {
85                            return Ok(true);
86                        }
87                        if visited.insert(e.target) {
88                            queue.push_back(e.target);
89                        }
90                    }
91                }
92            }
93
94            for (src, tgt) in &ws_inserts {
95                if *src == current {
96                    if *tgt == goal {
97                        return Ok(true);
98                    }
99                    if visited.insert(*tgt) {
100                        queue.push_back(*tgt);
101                    }
102                }
103            }
104        }
105
106        Ok(false)
107    }
108}
109
110impl<S: WriteSetApplicator> GraphExecutor for MemGraphExecutor<S> {
111    fn bfs(
112        &self,
113        start: NodeId,
114        edge_types: Option<&[EdgeType]>,
115        direction: Direction,
116        min_depth: u32,
117        max_depth: u32,
118        snapshot: SnapshotId,
119    ) -> Result<TraversalResult> {
120        let mut visited = HashSet::new();
121        visited.insert(start);
122
123        type BfsEntry = (NodeId, u32, Vec<(NodeId, EdgeType)>);
124        let mut queue: VecDeque<BfsEntry> = VecDeque::new();
125        queue.push_back((start, 0, vec![]));
126
127        let mut result_nodes = Vec::new();
128
129        while let Some((current, depth, path)) = queue.pop_front() {
130            if depth > 0 && depth >= min_depth {
131                result_nodes.push(TraversalNode {
132                    id: current,
133                    depth,
134                    path: path.clone(),
135                });
136            }
137
138            if depth >= max_depth {
139                continue;
140            }
141
142            let neighbors = self.neighbors(current, edge_types, direction, snapshot)?;
143            for (neighbor_id, edge_type, _) in neighbors {
144                if visited.contains(&neighbor_id) {
145                    continue;
146                }
147                visited.insert(neighbor_id);
148
149                if visited.len() > MAX_VISITED {
150                    return Err(Error::BfsVisitedExceeded(MAX_VISITED));
151                }
152
153                let mut new_path = path.clone();
154                new_path.push((current, edge_type));
155                queue.push_back((neighbor_id, depth + 1, new_path));
156            }
157        }
158
159        Ok(TraversalResult {
160            nodes: result_nodes,
161        })
162    }
163
164    fn neighbors(
165        &self,
166        node: NodeId,
167        edge_types: Option<&[EdgeType]>,
168        direction: Direction,
169        snapshot: SnapshotId,
170    ) -> Result<Vec<(NodeId, EdgeType, std::collections::HashMap<String, Value>)>> {
171        let mut results = Vec::new();
172
173        if matches!(direction, Direction::Outgoing | Direction::Both) {
174            let fwd = self.store.forward_adj.read();
175            if let Some(entries) = fwd.get(&node) {
176                for e in entries {
177                    if !e.visible_at(snapshot) {
178                        continue;
179                    }
180                    if let Some(types) = edge_types
181                        && !types.contains(&e.edge_type)
182                    {
183                        continue;
184                    }
185                    results.push((e.target, e.edge_type.clone(), e.properties.clone()));
186                }
187            }
188        }
189
190        if matches!(direction, Direction::Incoming | Direction::Both) {
191            let rev = self.store.reverse_adj.read();
192            if let Some(entries) = rev.get(&node) {
193                for e in entries {
194                    if !e.visible_at(snapshot) {
195                        continue;
196                    }
197                    if let Some(types) = edge_types
198                        && !types.contains(&e.edge_type)
199                    {
200                        continue;
201                    }
202                    results.push((e.source, e.edge_type.clone(), e.properties.clone()));
203                }
204            }
205        }
206
207        Ok(results)
208    }
209
210    fn insert_edge(
211        &self,
212        tx: TxId,
213        source: NodeId,
214        target: NodeId,
215        edge_type: EdgeType,
216        properties: std::collections::HashMap<String, Value>,
217    ) -> Result<bool> {
218        let deleted_in_ws = self.tx_mgr.with_write_set(tx, |ws| {
219            ws.adj_deletes
220                .iter()
221                .any(|(s, et, t, _)| *s == source && *t == target && et == &edge_type)
222        })?;
223
224        {
225            let fwd = self.store.forward_adj.read();
226            if let Some(entries) = fwd.get(&source) {
227                let live_in_committed = entries.iter().any(|e| {
228                    e.target == target && e.edge_type == edge_type && e.deleted_tx.is_none()
229                });
230                if live_in_committed && !deleted_in_ws {
231                    return Ok(false);
232                }
233            }
234        }
235
236        let duplicate_in_ws = self.tx_mgr.with_write_set(tx, |ws| {
237            let inserted = ws
238                .adj_inserts
239                .iter()
240                .any(|e| e.source == source && e.target == target && e.edge_type == edge_type);
241            inserted && !deleted_in_ws
242        })?;
243        if duplicate_in_ws {
244            return Ok(false);
245        }
246
247        if self.dag_edge_types.read().contains(&edge_type) {
248            if source == target {
249                return Err(Error::CycleDetected {
250                    edge_type: edge_type.clone(),
251                    source_node: source,
252                    target_node: target,
253                });
254            }
255
256            if self.bfs_with_write_set(tx, target, source, &edge_type)? {
257                return Err(Error::CycleDetected {
258                    edge_type: edge_type.clone(),
259                    source_node: source,
260                    target_node: target,
261                });
262            }
263        }
264
265        let entry = AdjEntry {
266            source,
267            target,
268            edge_type,
269            properties,
270            created_tx: tx,
271            deleted_tx: None,
272            lsn: 0,
273        };
274
275        self.tx_mgr.with_write_set(tx, |ws| {
276            ws.adj_inserts.push(entry);
277        })?;
278
279        Ok(true)
280    }
281
282    fn delete_edge(&self, tx: TxId, source: NodeId, target: NodeId, edge_type: &str) -> Result<()> {
283        self.tx_mgr.with_write_set(tx, |ws| {
284            ws.adj_deletes
285                .push((source, edge_type.to_string(), target, tx));
286        })?;
287        Ok(())
288    }
289}