1use crate::store::GraphStore;
2use contextdb_core::*;
3use contextdb_tx::{TxManager, WriteSetApplicator};
4use std::collections::{HashSet, VecDeque};
5use std::sync::Arc;
6
/// Upper bound on the number of distinct nodes a single `bfs` traversal may
/// mark as visited; once the visited set grows past it, the traversal aborts
/// with `Error::BfsVisitedExceeded`.
const MAX_VISITED: usize = 100_000;
/// A directed `(source, target)` node pair, used when collecting the edges
/// that a transaction's write set inserts or deletes.
type AdjPair = (NodeId, NodeId);
9
10pub struct MemGraphExecutor<S: WriteSetApplicator> {
11 store: Arc<GraphStore>,
12 tx_mgr: Arc<TxManager<S>>,
13 dag_edge_types: parking_lot::RwLock<HashSet<String>>,
14}
15
16impl<S: WriteSetApplicator> MemGraphExecutor<S> {
17 pub fn new(store: Arc<GraphStore>, tx_mgr: Arc<TxManager<S>>) -> Self {
18 Self {
19 store,
20 tx_mgr,
21 dag_edge_types: parking_lot::RwLock::new(HashSet::new()),
22 }
23 }
24
25 pub fn register_dag_edge_types(&self, types: &[String]) {
26 let mut set = self.dag_edge_types.write();
27 for t in types {
28 set.insert(t.clone());
29 }
30 }
31
32 pub fn edge_count(&self, source: NodeId, edge_type: &str, snapshot: SnapshotId) -> usize {
33 let fwd = self.store.forward_adj.read();
34 fwd.get(&source)
35 .map(|entries| {
36 entries
37 .iter()
38 .filter(|entry| entry.edge_type == edge_type && entry.visible_at(snapshot))
39 .count()
40 })
41 .unwrap_or(0)
42 }
43
44 fn bfs_with_write_set(
45 &self,
46 tx: TxId,
47 start: NodeId,
48 goal: NodeId,
49 edge_type: &str,
50 ) -> Result<bool> {
51 let mut visited = HashSet::new();
52 let mut queue = VecDeque::new();
53 visited.insert(start);
54 queue.push_back(start);
55
56 let (ws_inserts, ws_deletes): (Vec<AdjPair>, HashSet<AdjPair>) =
57 self.tx_mgr.with_write_set(tx, |ws| {
58 let inserts = ws
59 .adj_inserts
60 .iter()
61 .filter(|e| e.edge_type == edge_type)
62 .map(|e| (e.source, e.target))
63 .collect();
64 let deletes = ws
65 .adj_deletes
66 .iter()
67 .filter(|(_, et, _, _)| et == edge_type)
68 .map(|(s, _, t, _)| (*s, *t))
69 .collect();
70 (inserts, deletes)
71 })?;
72
73 while let Some(current) = queue.pop_front() {
74 {
75 let fwd = self.store.forward_adj.read();
76 if let Some(entries) = fwd.get(¤t) {
77 for e in entries {
78 if e.edge_type != edge_type || e.deleted_tx.is_some() {
79 continue;
80 }
81 if ws_deletes.contains(&(e.source, e.target)) {
82 continue;
83 }
84 if e.target == goal {
85 return Ok(true);
86 }
87 if visited.insert(e.target) {
88 queue.push_back(e.target);
89 }
90 }
91 }
92 }
93
94 for (src, tgt) in &ws_inserts {
95 if *src == current {
96 if *tgt == goal {
97 return Ok(true);
98 }
99 if visited.insert(*tgt) {
100 queue.push_back(*tgt);
101 }
102 }
103 }
104 }
105
106 Ok(false)
107 }
108}
109
110impl<S: WriteSetApplicator> GraphExecutor for MemGraphExecutor<S> {
111 fn bfs(
112 &self,
113 start: NodeId,
114 edge_types: Option<&[EdgeType]>,
115 direction: Direction,
116 min_depth: u32,
117 max_depth: u32,
118 snapshot: SnapshotId,
119 ) -> Result<TraversalResult> {
120 let mut visited = HashSet::new();
121 visited.insert(start);
122
123 type BfsEntry = (NodeId, u32, Vec<(NodeId, EdgeType)>);
124 let mut queue: VecDeque<BfsEntry> = VecDeque::new();
125 queue.push_back((start, 0, vec![]));
126
127 let mut result_nodes = Vec::new();
128
129 while let Some((current, depth, path)) = queue.pop_front() {
130 if depth > 0 && depth >= min_depth {
131 result_nodes.push(TraversalNode {
132 id: current,
133 depth,
134 path: path.clone(),
135 });
136 }
137
138 if depth >= max_depth {
139 continue;
140 }
141
142 let neighbors = self.neighbors(current, edge_types, direction, snapshot)?;
143 for (neighbor_id, edge_type, _) in neighbors {
144 if visited.contains(&neighbor_id) {
145 continue;
146 }
147 visited.insert(neighbor_id);
148
149 if visited.len() > MAX_VISITED {
150 return Err(Error::BfsVisitedExceeded(MAX_VISITED));
151 }
152
153 let mut new_path = path.clone();
154 new_path.push((current, edge_type));
155 queue.push_back((neighbor_id, depth + 1, new_path));
156 }
157 }
158
159 Ok(TraversalResult {
160 nodes: result_nodes,
161 })
162 }
163
164 fn neighbors(
165 &self,
166 node: NodeId,
167 edge_types: Option<&[EdgeType]>,
168 direction: Direction,
169 snapshot: SnapshotId,
170 ) -> Result<Vec<(NodeId, EdgeType, std::collections::HashMap<String, Value>)>> {
171 let mut results = Vec::new();
172
173 if matches!(direction, Direction::Outgoing | Direction::Both) {
174 let fwd = self.store.forward_adj.read();
175 if let Some(entries) = fwd.get(&node) {
176 for e in entries {
177 if !e.visible_at(snapshot) {
178 continue;
179 }
180 if let Some(types) = edge_types
181 && !types.contains(&e.edge_type)
182 {
183 continue;
184 }
185 results.push((e.target, e.edge_type.clone(), e.properties.clone()));
186 }
187 }
188 }
189
190 if matches!(direction, Direction::Incoming | Direction::Both) {
191 let rev = self.store.reverse_adj.read();
192 if let Some(entries) = rev.get(&node) {
193 for e in entries {
194 if !e.visible_at(snapshot) {
195 continue;
196 }
197 if let Some(types) = edge_types
198 && !types.contains(&e.edge_type)
199 {
200 continue;
201 }
202 results.push((e.source, e.edge_type.clone(), e.properties.clone()));
203 }
204 }
205 }
206
207 Ok(results)
208 }
209
210 fn insert_edge(
211 &self,
212 tx: TxId,
213 source: NodeId,
214 target: NodeId,
215 edge_type: EdgeType,
216 properties: std::collections::HashMap<String, Value>,
217 ) -> Result<bool> {
218 let deleted_in_ws = self.tx_mgr.with_write_set(tx, |ws| {
219 ws.adj_deletes
220 .iter()
221 .any(|(s, et, t, _)| *s == source && *t == target && et == &edge_type)
222 })?;
223
224 {
225 let fwd = self.store.forward_adj.read();
226 if let Some(entries) = fwd.get(&source) {
227 let live_in_committed = entries.iter().any(|e| {
228 e.target == target && e.edge_type == edge_type && e.deleted_tx.is_none()
229 });
230 if live_in_committed && !deleted_in_ws {
231 return Ok(false);
232 }
233 }
234 }
235
236 let duplicate_in_ws = self.tx_mgr.with_write_set(tx, |ws| {
237 let inserted = ws
238 .adj_inserts
239 .iter()
240 .any(|e| e.source == source && e.target == target && e.edge_type == edge_type);
241 inserted && !deleted_in_ws
242 })?;
243 if duplicate_in_ws {
244 return Ok(false);
245 }
246
247 if self.dag_edge_types.read().contains(&edge_type) {
248 if source == target {
249 return Err(Error::CycleDetected {
250 edge_type: edge_type.clone(),
251 source_node: source,
252 target_node: target,
253 });
254 }
255
256 if self.bfs_with_write_set(tx, target, source, &edge_type)? {
257 return Err(Error::CycleDetected {
258 edge_type: edge_type.clone(),
259 source_node: source,
260 target_node: target,
261 });
262 }
263 }
264
265 let entry = AdjEntry {
266 source,
267 target,
268 edge_type,
269 properties,
270 created_tx: tx,
271 deleted_tx: None,
272 lsn: 0,
273 };
274
275 self.tx_mgr.with_write_set(tx, |ws| {
276 ws.adj_inserts.push(entry);
277 })?;
278
279 Ok(true)
280 }
281
282 fn delete_edge(&self, tx: TxId, source: NodeId, target: NodeId, edge_type: &str) -> Result<()> {
283 self.tx_mgr.with_write_set(tx, |ws| {
284 ws.adj_deletes
285 .push((source, edge_type.to_string(), target, tx));
286 })?;
287 Ok(())
288 }
289}