1use std::collections::HashMap;
34
35use manifoldb_core::EntityId;
36use manifoldb_storage::Transaction;
37
38use crate::index::AdjacencyIndex;
39use crate::store::{EdgeStore, GraphError, GraphResult, NodeStore};
40
41use super::pagerank::DEFAULT_MAX_GRAPH_NODES;
42
43#[derive(Debug, Clone)]
45pub struct ConnectedComponentsConfig {
46 pub max_graph_nodes: Option<usize>,
50}
51
52impl Default for ConnectedComponentsConfig {
53 fn default() -> Self {
54 Self { max_graph_nodes: Some(DEFAULT_MAX_GRAPH_NODES) }
55 }
56}
57
58impl ConnectedComponentsConfig {
59 pub fn new() -> Self {
61 Self::default()
62 }
63
64 pub const fn with_max_graph_nodes(mut self, limit: Option<usize>) -> Self {
74 self.max_graph_nodes = limit;
75 self
76 }
77}
78
79#[derive(Debug, Clone)]
84pub struct ComponentResult {
85 pub assignments: HashMap<EntityId, usize>,
88
89 pub num_components: usize,
91}
92
93impl ComponentResult {
94 pub fn component(&self, node: EntityId) -> Option<usize> {
96 self.assignments.get(&node).copied()
97 }
98
99 pub fn nodes_in_component(&self, component_id: usize) -> Vec<EntityId> {
101 self.assignments.iter().filter(|(_, &c)| c == component_id).map(|(&node, _)| node).collect()
102 }
103
104 pub fn component_sizes(&self) -> HashMap<usize, usize> {
106 let mut sizes: HashMap<usize, usize> = HashMap::new();
107 for &component in self.assignments.values() {
108 *sizes.entry(component).or_insert(0) += 1;
109 }
110 sizes
111 }
112
113 pub fn components_by_size(&self) -> Vec<(usize, usize)> {
115 let mut sizes: Vec<_> = self.component_sizes().into_iter().collect();
116 sizes.sort_by(|a, b| b.1.cmp(&a.1));
117 sizes
118 }
119
120 pub fn largest_component(&self) -> Option<(usize, usize)> {
122 self.components_by_size().into_iter().next()
123 }
124
125 pub fn smallest_component(&self) -> Option<(usize, usize)> {
127 self.components_by_size().into_iter().last()
128 }
129
130 pub fn same_component(&self, node1: EntityId, node2: EntityId) -> bool {
132 match (self.component(node1), self.component(node2)) {
133 (Some(c1), Some(c2)) => c1 == c2,
134 _ => false,
135 }
136 }
137
138 pub fn component_size(&self, component_id: usize) -> usize {
140 self.component_sizes().get(&component_id).copied().unwrap_or(0)
141 }
142}
143
/// Disjoint-set forest over the indices `0..n`, using path compression in
/// `find` and union by rank in `union`.
struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// Rank of each element, consulted only for roots when merging two sets.
    rank: Vec<usize>,
}

impl UnionFind {
    /// Creates `n` singleton sets, one per index.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0; n];
        Self { parent, rank }
    }

    /// Returns the representative of the set containing `x`.
    ///
    /// Every element on the path from `x` to the root is re-pointed directly
    /// at the root. Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }

        // Second pass: compress the path just walked.
        let mut cursor = x;
        while cursor != root {
            let next = self.parent[cursor];
            self.parent[cursor] = root;
            cursor = next;
        }

        root
    }

    /// Merges the sets containing `x` and `y`; a no-op when they already share
    /// a root.
    fn union(&mut self, x: usize, y: usize) {
        let root_x = self.find(x);
        let root_y = self.find(y);
        if root_x == root_y {
            return;
        }

        let (rank_x, rank_y) = (self.rank[root_x], self.rank[root_y]);
        if rank_x < rank_y {
            // The lower-ranked tree hangs below the higher-ranked one.
            self.parent[root_x] = root_y;
        } else {
            self.parent[root_y] = root_x;
            if rank_x == rank_y {
                // Equal ranks: the surviving root gets one level taller.
                self.rank[root_x] += 1;
            }
        }
    }
}
188
189pub struct ConnectedComponents;
191
192impl ConnectedComponents {
193 pub fn weakly_connected<T: Transaction>(
211 tx: &T,
212 config: &ConnectedComponentsConfig,
213 ) -> GraphResult<ComponentResult> {
214 if let Some(limit) = config.max_graph_nodes {
216 let node_count = NodeStore::count(tx)?;
217 if node_count > limit {
218 return Err(GraphError::GraphTooLarge { node_count, limit });
219 }
220 }
221
222 let mut nodes: Vec<EntityId> = Vec::new();
224 NodeStore::for_each(tx, |entity| {
225 nodes.push(entity.id);
226 true
227 })?;
228
229 let n = nodes.len();
230 if n == 0 {
231 return Ok(ComponentResult { assignments: HashMap::new(), num_components: 0 });
232 }
233
234 let node_index: HashMap<EntityId, usize> =
236 nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
237
238 let mut uf = UnionFind::new(n);
240
241 for (i, &node) in nodes.iter().enumerate() {
243 let outgoing = AdjacencyIndex::get_outgoing_edge_ids(tx, node)?;
245 for edge_id in outgoing {
246 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
247 if let Some(&j) = node_index.get(&edge.target) {
248 uf.union(i, j);
249 }
250 }
251 }
252
253 let incoming = AdjacencyIndex::get_incoming_edge_ids(tx, node)?;
255 for edge_id in incoming {
256 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
257 if let Some(&j) = node_index.get(&edge.source) {
258 uf.union(i, j);
259 }
260 }
261 }
262 }
263
264 let mut root_to_component: HashMap<usize, usize> = HashMap::new();
266 let mut next_component = 0usize;
267
268 let mut assignments: HashMap<EntityId, usize> = HashMap::with_capacity(n);
269 for (i, &node) in nodes.iter().enumerate() {
270 let root = uf.find(i);
271 let component = *root_to_component.entry(root).or_insert_with(|| {
272 let c = next_component;
273 next_component += 1;
274 c
275 });
276 assignments.insert(node, component);
277 }
278
279 Ok(ComponentResult { assignments, num_components: next_component })
280 }
281
282 pub fn strongly_connected<T: Transaction>(
299 tx: &T,
300 config: &ConnectedComponentsConfig,
301 ) -> GraphResult<ComponentResult> {
302 if let Some(limit) = config.max_graph_nodes {
304 let node_count = NodeStore::count(tx)?;
305 if node_count > limit {
306 return Err(GraphError::GraphTooLarge { node_count, limit });
307 }
308 }
309
310 let mut nodes: Vec<EntityId> = Vec::new();
312 NodeStore::for_each(tx, |entity| {
313 nodes.push(entity.id);
314 true
315 })?;
316
317 let n = nodes.len();
318 if n == 0 {
319 return Ok(ComponentResult { assignments: HashMap::new(), num_components: 0 });
320 }
321
322 let node_index: HashMap<EntityId, usize> =
324 nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
325
326 let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
328 for (i, &node) in nodes.iter().enumerate() {
329 let outgoing = AdjacencyIndex::get_outgoing_edge_ids(tx, node)?;
330 for edge_id in outgoing {
331 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
332 if let Some(&j) = node_index.get(&edge.target) {
333 adjacency[i].push(j);
334 }
335 }
336 }
337 }
338
339 let mut state = TarjanState::new(n);
341
342 for i in 0..n {
343 if state.index[i].is_none() {
344 tarjan_dfs(i, &adjacency, &mut state);
345 }
346 }
347
348 let mut assignments: HashMap<EntityId, usize> = HashMap::with_capacity(n);
350 for (i, &node) in nodes.iter().enumerate() {
351 if let Some(component) = state.component[i] {
352 assignments.insert(node, component);
353 }
354 }
355
356 Ok(ComponentResult { assignments, num_components: state.num_components })
357 }
358
359 pub fn weakly_connected_for_nodes<T: Transaction>(
369 tx: &T,
370 nodes: &[EntityId],
371 config: &ConnectedComponentsConfig,
372 ) -> GraphResult<ComponentResult> {
373 let n = nodes.len();
374 if n == 0 {
375 return Ok(ComponentResult { assignments: HashMap::new(), num_components: 0 });
376 }
377
378 if let Some(limit) = config.max_graph_nodes {
380 if n > limit {
381 return Err(GraphError::GraphTooLarge { node_count: n, limit });
382 }
383 }
384
385 let node_set: std::collections::HashSet<EntityId> = nodes.iter().copied().collect();
387 let node_index: HashMap<EntityId, usize> =
388 nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
389
390 let mut uf = UnionFind::new(n);
392
393 for (i, &node) in nodes.iter().enumerate() {
395 let outgoing = AdjacencyIndex::get_outgoing_edge_ids(tx, node)?;
397 for edge_id in outgoing {
398 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
399 if node_set.contains(&edge.target) {
400 if let Some(&j) = node_index.get(&edge.target) {
401 uf.union(i, j);
402 }
403 }
404 }
405 }
406
407 let incoming = AdjacencyIndex::get_incoming_edge_ids(tx, node)?;
409 for edge_id in incoming {
410 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
411 if node_set.contains(&edge.source) {
412 if let Some(&j) = node_index.get(&edge.source) {
413 uf.union(i, j);
414 }
415 }
416 }
417 }
418 }
419
420 let mut root_to_component: HashMap<usize, usize> = HashMap::new();
422 let mut next_component = 0usize;
423
424 let mut assignments: HashMap<EntityId, usize> = HashMap::with_capacity(n);
425 for (i, &node) in nodes.iter().enumerate() {
426 let root = uf.find(i);
427 let component = *root_to_component.entry(root).or_insert_with(|| {
428 let c = next_component;
429 next_component += 1;
430 c
431 });
432 assignments.insert(node, component);
433 }
434
435 Ok(ComponentResult { assignments, num_components: next_component })
436 }
437
438 pub fn strongly_connected_for_nodes<T: Transaction>(
448 tx: &T,
449 nodes: &[EntityId],
450 config: &ConnectedComponentsConfig,
451 ) -> GraphResult<ComponentResult> {
452 let n = nodes.len();
453 if n == 0 {
454 return Ok(ComponentResult { assignments: HashMap::new(), num_components: 0 });
455 }
456
457 if let Some(limit) = config.max_graph_nodes {
459 if n > limit {
460 return Err(GraphError::GraphTooLarge { node_count: n, limit });
461 }
462 }
463
464 let node_set: std::collections::HashSet<EntityId> = nodes.iter().copied().collect();
466 let node_index: HashMap<EntityId, usize> =
467 nodes.iter().enumerate().map(|(i, &id)| (id, i)).collect();
468
469 let mut adjacency: Vec<Vec<usize>> = vec![Vec::new(); n];
471 for (i, &node) in nodes.iter().enumerate() {
472 let outgoing = AdjacencyIndex::get_outgoing_edge_ids(tx, node)?;
473 for edge_id in outgoing {
474 if let Some(edge) = EdgeStore::get(tx, edge_id)? {
475 if node_set.contains(&edge.target) {
476 if let Some(&j) = node_index.get(&edge.target) {
477 adjacency[i].push(j);
478 }
479 }
480 }
481 }
482 }
483
484 let mut state = TarjanState::new(n);
486
487 for i in 0..n {
488 if state.index[i].is_none() {
489 tarjan_dfs(i, &adjacency, &mut state);
490 }
491 }
492
493 let mut assignments: HashMap<EntityId, usize> = HashMap::with_capacity(n);
495 for (i, &node) in nodes.iter().enumerate() {
496 if let Some(component) = state.component[i] {
497 assignments.insert(node, component);
498 }
499 }
500
501 Ok(ComponentResult { assignments, num_components: state.num_components })
502 }
503}
504
/// Mutable bookkeeping shared by all `tarjan_dfs` traversals over one graph.
///
/// Every per-node vector is indexed by the node's position in the adjacency
/// list handed to `tarjan_dfs`.
struct TarjanState {
    /// Discovery index of each node; `None` until the node is first visited.
    index: Vec<Option<usize>>,
    /// Smallest discovery index known to be reachable from each node.
    lowlink: Vec<usize>,
    /// Whether each node currently sits on `stack`.
    on_stack: Vec<bool>,
    /// Visited nodes that have not been assigned to a component yet.
    stack: Vec<usize>,
    /// Component id of each node; `None` until its component is closed.
    component: Vec<Option<usize>>,
    /// Discovery index that the next newly visited node will receive.
    current_index: usize,
    /// Number of components closed so far; also the next component id.
    num_components: usize,
}

impl TarjanState {
    /// Creates the state for a graph of `n` nodes with nothing visited yet.
    fn new(n: usize) -> Self {
        let index = vec![None; n];
        let component = vec![None; n];
        Self {
            index,
            lowlink: vec![0; n],
            on_stack: vec![false; n],
            stack: Vec::new(),
            component,
            current_index: 0,
            num_components: 0,
        }
    }
}
536
537fn tarjan_dfs(start: usize, adjacency: &[Vec<usize>], state: &mut TarjanState) {
539 let mut work_stack: Vec<(usize, usize, u8)> = vec![(start, 0, 0)];
544
545 while let Some((v, neighbor_idx, phase)) = work_stack.pop() {
546 match phase {
547 0 => {
548 state.index[v] = Some(state.current_index);
550 state.lowlink[v] = state.current_index;
551 state.current_index += 1;
552 state.on_stack[v] = true;
553 state.stack.push(v);
554
555 work_stack.push((v, 0, 1));
557 }
558 1 => {
559 if neighbor_idx < adjacency[v].len() {
561 let w = adjacency[v][neighbor_idx];
562
563 if state.index[w].is_none() {
564 work_stack.push((v, neighbor_idx + 1, 2)); work_stack.push((w, 0, 0)); } else if state.on_stack[w] {
568 if let Some(w_index) = state.index[w] {
571 state.lowlink[v] = state.lowlink[v].min(w_index);
572 }
573 work_stack.push((v, neighbor_idx + 1, 1)); } else {
575 work_stack.push((v, neighbor_idx + 1, 1));
577 }
578 } else {
579 if let Some(v_index) = state.index[v] {
582 if state.lowlink[v] == v_index {
583 let component_id = state.num_components;
585 state.num_components += 1;
586
587 while let Some(w) = state.stack.pop() {
589 state.on_stack[w] = false;
590 state.component[w] = Some(component_id);
591 if w == v {
592 break;
593 }
594 }
595 }
596 }
597 }
598 }
599 2 => {
600 let w = adjacency[v][neighbor_idx - 1];
603 state.lowlink[v] = state.lowlink[v].min(state.lowlink[w]);
604
605 work_stack.push((v, neighbor_idx, 1));
607 }
608 _ => unreachable!(),
609 }
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the shared default node limit.
    #[test]
    fn config_defaults() {
        let ConnectedComponentsConfig { max_graph_nodes } = ConnectedComponentsConfig::default();
        assert_eq!(max_graph_nodes, Some(DEFAULT_MAX_GRAPH_NODES));
    }

    /// The builder stores whatever limit it is given, including "no limit".
    #[test]
    fn config_builder() {
        for &limit in &[Some(1000), None] {
            let config = ConnectedComponentsConfig::new().with_max_graph_nodes(limit);
            assert_eq!(config.max_graph_nodes, limit);
        }
    }

    /// Every query on a result without assignments reports "nothing there".
    #[test]
    fn result_empty() {
        let empty = ComponentResult { assignments: HashMap::new(), num_components: 0 };

        assert_eq!(empty.component(EntityId::new(1)), None);
        assert_eq!(empty.nodes_in_component(0).len(), 0);
        assert_eq!(empty.component_sizes().len(), 0);
        assert_eq!(empty.largest_component(), None);
        assert_eq!(empty.smallest_component(), None);
    }

    /// Exercises every accessor on a result with two components: nodes 1 and 2
    /// in component 0, nodes 3, 4 and 5 in component 1.
    #[test]
    fn result_operations() {
        let assignments: HashMap<EntityId, usize> = [(1, 0), (2, 0), (3, 1), (4, 1), (5, 1)]
            .iter()
            .map(|&(raw, component)| (EntityId::new(raw), component))
            .collect();
        let result = ComponentResult { assignments, num_components: 2 };

        // Lookup of individual nodes.
        assert_eq!(result.component(EntityId::new(1)), Some(0));
        assert_eq!(result.component(EntityId::new(3)), Some(1));
        assert_eq!(result.component(EntityId::new(99)), None);

        // Membership lists.
        assert_eq!(result.nodes_in_component(0).len(), 2);
        assert_eq!(result.nodes_in_component(1).len(), 3);

        // Size table.
        let sizes = result.component_sizes();
        assert_eq!(sizes.get(&0).copied(), Some(2));
        assert_eq!(sizes.get(&1).copied(), Some(3));

        // Extremes.
        assert_eq!(result.largest_component(), Some((1, 3)));
        assert_eq!(result.smallest_component(), Some((0, 2)));

        // Pairwise membership.
        assert!(result.same_component(EntityId::new(1), EntityId::new(2)));
        assert!(result.same_component(EntityId::new(3), EntityId::new(4)));
        assert!(!result.same_component(EntityId::new(1), EntityId::new(3)));

        // Single-component size lookups.
        assert_eq!(result.component_size(0), 2);
        assert_eq!(result.component_size(1), 3);
        assert_eq!(result.component_size(99), 0);
    }

    /// Sets start out disjoint and merge step by step.
    #[test]
    fn union_find_basic() {
        let mut uf = UnionFind::new(5);

        let (zero, one) = (uf.find(0), uf.find(1));
        assert_ne!(zero, one);

        uf.union(0, 1);
        let (zero, one) = (uf.find(0), uf.find(1));
        assert_eq!(zero, one);

        uf.union(2, 3);
        let (two, three) = (uf.find(2), uf.find(3));
        assert_eq!(two, three);

        // The two pairs are still separate sets.
        let (zero, two) = (uf.find(0), uf.find(2));
        assert_ne!(zero, two);

        // Linking one member of each pair merges all four elements.
        uf.union(1, 3);
        let zero = uf.find(0);
        assert_eq!(zero, uf.find(2));
        assert_eq!(zero, uf.find(3));
    }

    /// Linking neighbours along a chain leaves a single set.
    #[test]
    fn union_find_chain() {
        let mut uf = UnionFind::new(10);
        (0..9).for_each(|i| uf.union(i, i + 1));

        let root = uf.find(0);
        assert!((1..10).all(|i| uf.find(i) == root));
    }
}