1use std::collections::{HashSet, VecDeque};
7
8use manifoldb_core::{EdgeId, EdgeType, EntityId};
9use manifoldb_storage::Transaction;
10
11use super::{Direction, ExpandResult, TraversalFilter, TraversalNode};
12use crate::index::AdjacencyIndex;
13use crate::store::{EdgeStore, GraphResult};
14
15#[derive(Debug, Clone)]
17pub struct TraversalConfig {
18 pub direction: Direction,
20 pub min_depth: usize,
22 pub max_depth: Option<usize>,
24 pub filter: TraversalFilter,
26 pub include_start: bool,
28}
29
30impl Default for TraversalConfig {
31 fn default() -> Self {
32 Self {
33 direction: Direction::Outgoing,
34 min_depth: 1,
35 max_depth: None,
36 filter: TraversalFilter::new(),
37 include_start: false,
38 }
39 }
40}
41
42impl TraversalConfig {
43 pub fn new(direction: Direction) -> Self {
45 Self { direction, ..Default::default() }
46 }
47
48 pub const fn with_min_depth(mut self, min_depth: usize) -> Self {
50 self.min_depth = min_depth;
51 self
52 }
53
54 pub const fn with_max_depth(mut self, max_depth: usize) -> Self {
56 self.max_depth = Some(max_depth);
57 self
58 }
59
60 pub fn with_edge_type(mut self, edge_type: impl Into<EdgeType>) -> Self {
62 self.filter = self.filter.with_edge_type(edge_type);
63 self
64 }
65
66 pub fn with_limit(mut self, limit: usize) -> Self {
68 self.filter = self.filter.with_limit(limit);
69 self
70 }
71
72 pub const fn include_start(mut self) -> Self {
74 self.include_start = true;
75 self.min_depth = 0;
76 self
77 }
78}
79
80struct TraversalState {
82 queue: VecDeque<(EntityId, usize)>,
84 visited: HashSet<EntityId>,
86 count: usize,
88}
89
90impl TraversalState {
91 fn new(start: EntityId) -> Self {
92 let mut visited = HashSet::new();
93 visited.insert(start);
94
95 let mut queue = VecDeque::new();
96 queue.push_back((start, 0));
97
98 Self { queue, visited, count: 0 }
99 }
100}
101
102pub struct TraversalIterator<'a, T: Transaction> {
122 tx: &'a T,
123 config: TraversalConfig,
124 state: TraversalState,
125 pending: VecDeque<TraversalNode>,
127 yielded_start: bool,
129 start: EntityId,
131}
132
133impl<'a, T: Transaction> TraversalIterator<'a, T> {
134 pub fn new(tx: &'a T, start: EntityId, config: TraversalConfig) -> Self {
136 Self {
137 tx,
138 config,
139 state: TraversalState::new(start),
140 pending: VecDeque::new(),
141 yielded_start: false,
142 start,
143 }
144 }
145
146 fn expand_current(&mut self) -> GraphResult<()> {
148 let Some((current, depth)) = self.state.queue.pop_front() else {
149 return Ok(());
150 };
151
152 let should_expand = self.config.max_depth.map_or(true, |max| depth < max);
154 if !should_expand {
155 return Ok(());
156 }
157
158 let neighbors = self.get_neighbors(current)?;
160
161 for (neighbor, _edge_id) in neighbors {
162 if self.state.visited.contains(&neighbor) {
163 continue;
164 }
165
166 self.state.visited.insert(neighbor);
167 let next_depth = depth + 1;
168
169 self.state.queue.push_back((neighbor, next_depth));
171
172 if next_depth >= self.config.min_depth
174 && self.config.filter.should_include_node(neighbor)
175 {
176 self.pending.push_back(TraversalNode::new(neighbor, next_depth));
177 }
178 }
179
180 Ok(())
181 }
182
183 fn get_neighbors(&self, node: EntityId) -> GraphResult<Vec<(EntityId, EdgeId)>> {
184 let mut neighbors = Vec::new();
185
186 if self.config.direction.includes_outgoing() {
187 self.add_outgoing_neighbors(node, &mut neighbors)?;
188 }
189
190 if self.config.direction.includes_incoming() {
191 self.add_incoming_neighbors(node, &mut neighbors)?;
192 }
193
194 Ok(neighbors)
195 }
196
197 fn add_outgoing_neighbors(
198 &self,
199 node: EntityId,
200 neighbors: &mut Vec<(EntityId, EdgeId)>,
201 ) -> GraphResult<()> {
202 match &self.config.filter.edge_types {
203 Some(types) => {
204 for edge_type in types {
205 AdjacencyIndex::for_each_outgoing_by_type(
206 self.tx,
207 node,
208 edge_type,
209 |edge_id| {
210 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
211 neighbors.push((edge.target, edge_id));
212 }
213 Ok(true)
214 },
215 )?;
216 }
217 }
218 None => {
219 AdjacencyIndex::for_each_outgoing(self.tx, node, |edge_id| {
220 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
221 neighbors.push((edge.target, edge_id));
222 }
223 Ok(true)
224 })?;
225 }
226 }
227 Ok(())
228 }
229
230 fn add_incoming_neighbors(
231 &self,
232 node: EntityId,
233 neighbors: &mut Vec<(EntityId, EdgeId)>,
234 ) -> GraphResult<()> {
235 match &self.config.filter.edge_types {
236 Some(types) => {
237 for edge_type in types {
238 AdjacencyIndex::for_each_incoming_by_type(
239 self.tx,
240 node,
241 edge_type,
242 |edge_id| {
243 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
244 neighbors.push((edge.source, edge_id));
245 }
246 Ok(true)
247 },
248 )?;
249 }
250 }
251 None => {
252 AdjacencyIndex::for_each_incoming(self.tx, node, |edge_id| {
253 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
254 neighbors.push((edge.source, edge_id));
255 }
256 Ok(true)
257 })?;
258 }
259 }
260 Ok(())
261 }
262
263 fn try_next(&mut self) -> GraphResult<Option<TraversalNode>> {
265 if self.config.filter.is_at_limit(self.state.count) {
267 return Ok(None);
268 }
269
270 if !self.yielded_start && self.config.include_start {
272 self.yielded_start = true;
273 self.state.count += 1;
274 return Ok(Some(TraversalNode::new(self.start, 0)));
275 }
276 self.yielded_start = true;
277
278 if let Some(node) = self.pending.pop_front() {
280 self.state.count += 1;
281 return Ok(Some(node));
282 }
283
284 while !self.state.queue.is_empty() {
286 self.expand_current()?;
287
288 if let Some(node) = self.pending.pop_front() {
289 self.state.count += 1;
290 return Ok(Some(node));
291 }
292 }
293
294 Ok(None)
295 }
296
297 pub fn collect_all(mut self) -> GraphResult<Vec<TraversalNode>> {
299 let mut results = Vec::new();
300 while let Some(node) = self.try_next()? {
301 results.push(node);
302 }
303 Ok(results)
304 }
305
306 pub fn collect_ids(mut self) -> GraphResult<Vec<EntityId>> {
308 let mut results = Vec::new();
309 while let Some(node) = self.try_next()? {
310 results.push(node.id);
311 }
312 Ok(results)
313 }
314
315 pub fn count_all(mut self) -> GraphResult<usize> {
317 let mut count = 0;
318 while self.try_next()?.is_some() {
319 count += 1;
320 }
321 Ok(count)
322 }
323
324 pub fn take(mut self, n: usize) -> GraphResult<Vec<TraversalNode>> {
326 let mut results = Vec::with_capacity(n);
327 for _ in 0..n {
328 match self.try_next()? {
329 Some(node) => results.push(node),
330 None => break,
331 }
332 }
333 Ok(results)
334 }
335
336 pub fn has_next(&mut self) -> GraphResult<bool> {
338 if !self.pending.is_empty() {
339 return Ok(true);
340 }
341
342 if self.config.filter.is_at_limit(self.state.count) {
343 return Ok(false);
344 }
345
346 while !self.state.queue.is_empty() && self.pending.is_empty() {
348 self.expand_current()?;
349 }
350
351 Ok(!self.pending.is_empty())
352 }
353}
354
355pub struct TraversalIteratorAdapter<'a, T: Transaction> {
357 inner: TraversalIterator<'a, T>,
358 errored: bool,
359}
360
361impl<'a, T: Transaction> TraversalIteratorAdapter<'a, T> {
362 pub fn new(tx: &'a T, start: EntityId, config: TraversalConfig) -> Self {
364 Self { inner: TraversalIterator::new(tx, start, config), errored: false }
365 }
366}
367
368impl<T: Transaction> Iterator for TraversalIteratorAdapter<'_, T> {
369 type Item = GraphResult<TraversalNode>;
370
371 fn next(&mut self) -> Option<Self::Item> {
372 if self.errored {
373 return None;
374 }
375
376 match self.inner.try_next() {
377 Ok(Some(node)) => Some(Ok(node)),
378 Ok(None) => None,
379 Err(e) => {
380 self.errored = true;
381 Some(Err(e))
382 }
383 }
384 }
385}
386
387pub struct NeighborIterator<'a, T: Transaction> {
389 tx: &'a T,
390 node: EntityId,
391 direction: Direction,
392 filter: Option<EdgeType>,
393 pending: VecDeque<ExpandResult>,
395 loaded: bool,
397}
398
399impl<'a, T: Transaction> NeighborIterator<'a, T> {
400 pub const fn new(tx: &'a T, node: EntityId, direction: Direction) -> Self {
402 Self { tx, node, direction, filter: None, pending: VecDeque::new(), loaded: false }
403 }
404
405 pub fn with_edge_type(mut self, edge_type: impl Into<EdgeType>) -> Self {
407 self.filter = Some(edge_type.into());
408 self
409 }
410
411 fn load_neighbors(&mut self) -> GraphResult<()> {
412 if self.loaded {
413 return Ok(());
414 }
415 self.loaded = true;
416
417 if self.direction.includes_outgoing() {
418 self.load_outgoing()?;
419 }
420
421 if self.direction.includes_incoming() {
422 self.load_incoming()?;
423 }
424
425 Ok(())
426 }
427
428 fn load_outgoing(&mut self) -> GraphResult<()> {
429 match &self.filter {
430 Some(et) => {
431 AdjacencyIndex::for_each_outgoing_by_type(self.tx, self.node, et, |edge_id| {
432 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
433 self.pending.push_back(ExpandResult::new(
434 edge.target,
435 edge_id,
436 Direction::Outgoing,
437 ));
438 }
439 Ok(true)
440 })?;
441 }
442 None => {
443 AdjacencyIndex::for_each_outgoing(self.tx, self.node, |edge_id| {
444 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
445 self.pending.push_back(ExpandResult::new(
446 edge.target,
447 edge_id,
448 Direction::Outgoing,
449 ));
450 }
451 Ok(true)
452 })?;
453 }
454 }
455 Ok(())
456 }
457
458 fn load_incoming(&mut self) -> GraphResult<()> {
459 match &self.filter {
460 Some(et) => {
461 AdjacencyIndex::for_each_incoming_by_type(self.tx, self.node, et, |edge_id| {
462 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
463 self.pending.push_back(ExpandResult::new(
464 edge.source,
465 edge_id,
466 Direction::Incoming,
467 ));
468 }
469 Ok(true)
470 })?;
471 }
472 None => {
473 AdjacencyIndex::for_each_incoming(self.tx, self.node, |edge_id| {
474 if let Some(edge) = EdgeStore::get(self.tx, edge_id)? {
475 self.pending.push_back(ExpandResult::new(
476 edge.source,
477 edge_id,
478 Direction::Incoming,
479 ));
480 }
481 Ok(true)
482 })?;
483 }
484 }
485 Ok(())
486 }
487
488 fn try_next(&mut self) -> GraphResult<Option<ExpandResult>> {
490 self.load_neighbors()?;
491 Ok(self.pending.pop_front())
492 }
493
494 pub fn collect_all(mut self) -> GraphResult<Vec<ExpandResult>> {
496 self.load_neighbors()?;
497 Ok(self.pending.into_iter().collect())
498 }
499}
500
501pub struct NeighborIteratorAdapter<'a, T: Transaction> {
503 inner: NeighborIterator<'a, T>,
504 errored: bool,
505}
506
507impl<'a, T: Transaction> NeighborIteratorAdapter<'a, T> {
508 pub fn new(tx: &'a T, node: EntityId, direction: Direction) -> Self {
510 Self { inner: NeighborIterator::new(tx, node, direction), errored: false }
511 }
512
513 pub fn with_edge_type(mut self, edge_type: impl Into<EdgeType>) -> Self {
515 self.inner = self.inner.with_edge_type(edge_type);
516 self
517 }
518}
519
520impl<T: Transaction> Iterator for NeighborIteratorAdapter<'_, T> {
521 type Item = GraphResult<ExpandResult>;
522
523 fn next(&mut self) -> Option<Self::Item> {
524 if self.errored {
525 return None;
526 }
527
528 match self.inner.try_next() {
529 Ok(Some(result)) => Some(Ok(result)),
530 Ok(None) => None,
531 Err(e) => {
532 self.errored = true;
533 Some(Err(e))
534 }
535 }
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn traversal_config_builder() {
        let config = TraversalConfig::new(Direction::Both)
            .with_min_depth(2)
            .with_max_depth(5)
            .with_edge_type("FRIEND")
            .with_limit(100);

        assert_eq!(config.direction, Direction::Both);
        assert_eq!(config.min_depth, 2);
        assert_eq!(config.max_depth, Some(5));
        assert_eq!(config.filter.limit, Some(100));
    }

    #[test]
    fn traversal_config_include_start() {
        let config = TraversalConfig::new(Direction::Outgoing).include_start();

        assert!(config.include_start);
        assert_eq!(config.min_depth, 0);
    }

    /// Pins the documented defaults: outgoing, min depth 1, unbounded, start excluded.
    #[test]
    fn traversal_config_default() {
        let config = TraversalConfig::default();

        assert_eq!(config.direction, Direction::Outgoing);
        assert_eq!(config.min_depth, 1);
        assert_eq!(config.max_depth, None);
        assert!(!config.include_start);
    }

    /// `new` only overrides the direction; everything else stays at its default.
    #[test]
    fn traversal_config_new_keeps_defaults() {
        let config = TraversalConfig::new(Direction::Incoming);

        assert_eq!(config.direction, Direction::Incoming);
        assert_eq!(config.min_depth, 1);
        assert_eq!(config.max_depth, None);
        assert!(!config.include_start);
    }

    #[test]
    fn traversal_state_initialization() {
        let start = EntityId::new(1);
        let state = TraversalState::new(start);

        assert!(state.visited.contains(&start));
        assert_eq!(state.visited.len(), 1);
        assert_eq!(state.queue.len(), 1);
        assert!(state.queue.front() == Some(&(start, 0)));
        assert_eq!(state.count, 0);
    }
}