arrow_graph/algorithms/
components.rs

1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, UInt32Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use std::sync::Arc;
5use std::collections::HashMap;
6use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
7use crate::graph::ArrowGraph;
8use crate::error::{GraphError, Result};
9
/// Union-Find (disjoint-set forest) over string node identifiers.
///
/// Supports near-constant-time merging of sets and lookup of a set's
/// representative ("root") by combining union by rank with path compression.
#[derive(Debug)]
struct UnionFind {
    /// Parent pointer of every registered node; a root is its own parent.
    parent: HashMap<String, String>,
    /// Rank (upper bound on tree height) recorded per node; only the value
    /// stored for a root is consulted when merging.
    rank: HashMap<String, usize>,
    /// Number of nodes in each set, keyed by the set's current root.
    component_sizes: HashMap<String, usize>,
}

impl UnionFind {
    /// Creates an empty structure with no registered nodes.
    fn new() -> Self {
        UnionFind {
            parent: HashMap::new(),
            rank: HashMap::new(),
            component_sizes: HashMap::new(),
        }
    }

    /// Registers `node` as a singleton set.
    ///
    /// Calling this for a node that is already registered is a no-op, so an
    /// existing set is never reset.
    fn make_set(&mut self, node: String) {
        if self.parent.contains_key(&node) {
            return;
        }
        self.rank.insert(node.clone(), 0);
        self.component_sizes.insert(node.clone(), 1);
        // The last use moves `node` in as the value, saving one allocation.
        self.parent.insert(node.clone(), node);
    }

    /// Returns the root of the set containing `node`, or `None` when `node`
    /// was never registered with [`UnionFind::make_set`].
    ///
    /// Every node on the path from `node` to the root is re-pointed directly
    /// at the root (path compression).
    fn find(&mut self, node: &str) -> Option<String> {
        // Pass 1: follow parent pointers until reaching a self-parented node.
        let mut root = self.parent.get(node)?.clone();
        loop {
            let next = self.parent.get(&root)?;
            if *next == root {
                break;
            }
            root = next.clone();
        }

        // Pass 2: walk the same path again, pointing each node at the root.
        // `insert` hands back the previous parent, which is the next hop.
        let mut current = node.to_string();
        while current != root {
            match self.parent.insert(current, root.clone()) {
                Some(previous_parent) => current = previous_parent,
                None => break,
            }
        }

        Some(root)
    }

    /// Merges the sets containing `node1` and `node2`.
    ///
    /// Returns `true` when two distinct sets were merged. Returns `false`
    /// when either node is unregistered or both already share a root.
    fn union(&mut self, node1: &str, node2: &str) -> bool {
        let root1 = match self.find(node1) {
            Some(root) => root,
            None => return false,
        };
        let root2 = match self.find(node2) {
            Some(root) => root,
            None => return false,
        };

        if root1 == root2 {
            return false; // Already in the same component
        }

        // Union by rank: the root with the larger rank survives. On a tie
        // `root1` survives and its rank grows by one.
        let rank1 = self.rank.get(&root1).copied().unwrap_or(0);
        let rank2 = self.rank.get(&root2).copied().unwrap_or(0);

        let (new_root, old_root) = if rank1 < rank2 {
            (root2, root1)
        } else {
            if rank1 == rank2 {
                self.rank.insert(root1.clone(), rank1 + 1);
            }
            (root1, root2)
        };

        // Fold the absorbed set's size into the surviving root's entry.
        let absorbed = self.component_sizes.remove(&old_root).unwrap_or(0);
        *self.component_sizes.entry(new_root.clone()).or_insert(0) += absorbed;

        // Re-parent the absorbed root.
        self.parent.insert(old_root, new_root);

        true
    }

    /// Groups every registered node under the root of its set.
    ///
    /// The order of the returned map's entries, and of the nodes inside each
    /// `Vec`, follows `HashMap` iteration order and is therefore unspecified.
    fn get_components(&mut self) -> HashMap<String, Vec<String>> {
        let mut components: HashMap<String, Vec<String>> = HashMap::new();

        // Snapshot the keys first: `find` needs `&mut self` for compression.
        let nodes: Vec<String> = self.parent.keys().cloned().collect();
        for node in nodes {
            if let Some(root) = self.find(&node) {
                components.entry(root).or_default().push(node);
            }
        }

        components
    }

    /// Returns the number of disjoint sets currently tracked.
    // No caller is visible in this part of the file; the allow keeps the
    // helper available without tripping the unused-code lint.
    #[allow(dead_code)]
    fn component_count(&mut self) -> usize {
        self.get_components().len()
    }
}
111
112pub struct WeaklyConnectedComponents;
113
114impl WeaklyConnectedComponents {
115    /// Find weakly connected components using Union-Find
116    fn compute_components(&self, graph: &ArrowGraph) -> Result<HashMap<String, u32>> {
117        let mut uf = UnionFind::new();
118        
119        // Initialize all nodes
120        for node_id in graph.node_ids() {
121            uf.make_set(node_id.clone());
122        }
123        
124        // Union nodes connected by edges (treat as undirected)
125        for node_id in graph.node_ids() {
126            if let Some(neighbors) = graph.neighbors(node_id) {
127                for neighbor in neighbors {
128                    uf.union(node_id, neighbor);
129                }
130            }
131        }
132        
133        // Assign component IDs
134        let components = uf.get_components();
135        let mut node_to_component: HashMap<String, u32> = HashMap::new();
136        
137        for (component_id, (_root, nodes)) in components.into_iter().enumerate() {
138            for node in nodes {
139                node_to_component.insert(node, component_id as u32);
140            }
141        }
142        
143        Ok(node_to_component)
144    }
145}
146
147impl GraphAlgorithm for WeaklyConnectedComponents {
148    fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
149        let component_map = self.compute_components(graph)?;
150        
151        if component_map.is_empty() {
152            // Return empty result with proper schema
153            let schema = Arc::new(Schema::new(vec![
154                Field::new("node_id", DataType::Utf8, false),
155                Field::new("component_id", DataType::UInt32, false),
156            ]));
157            
158            return RecordBatch::try_new(
159                schema,
160                vec![
161                    Arc::new(StringArray::from(Vec::<String>::new())),
162                    Arc::new(UInt32Array::from(Vec::<u32>::new())),
163                ],
164            ).map_err(GraphError::from);
165        }
166        
167        // Sort by component ID for consistent output
168        let mut sorted_nodes: Vec<(&String, &u32)> = component_map.iter().collect();
169        sorted_nodes.sort_by_key(|(_, &component_id)| component_id);
170        
171        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
172        let component_ids: Vec<u32> = sorted_nodes.iter().map(|(_, &comp)| comp).collect();
173        
174        let schema = Arc::new(Schema::new(vec![
175            Field::new("node_id", DataType::Utf8, false),
176            Field::new("component_id", DataType::UInt32, false),
177        ]));
178        
179        RecordBatch::try_new(
180            schema,
181            vec![
182                Arc::new(StringArray::from(node_ids)),
183                Arc::new(UInt32Array::from(component_ids)),
184            ],
185        ).map_err(GraphError::from)
186    }
187    
188    fn name(&self) -> &'static str {
189        "weakly_connected_components"
190    }
191    
192    fn description(&self) -> &'static str {
193        "Find weakly connected components using Union-Find with path compression"
194    }
195}
196
197pub struct StronglyConnectedComponents;
198
199impl StronglyConnectedComponents {
200    /// Tarjan's algorithm for strongly connected components
201    fn tarjan_scc(&self, graph: &ArrowGraph) -> Result<HashMap<String, u32>> {
202        let mut index_counter = 0;
203        let mut stack = Vec::new();
204        let mut indices: HashMap<String, usize> = HashMap::new();
205        let mut lowlinks: HashMap<String, usize> = HashMap::new();
206        let mut on_stack: HashMap<String, bool> = HashMap::new();
207        let mut components: Vec<Vec<String>> = Vec::new();
208        
209        // Initialize
210        for node_id in graph.node_ids() {
211            on_stack.insert(node_id.clone(), false);
212        }
213        
214        // Run DFS from each unvisited node
215        for node_id in graph.node_ids() {
216            if !indices.contains_key(node_id) {
217                self.tarjan_strongconnect(
218                    node_id,
219                    graph,
220                    &mut index_counter,
221                    &mut stack,
222                    &mut indices,
223                    &mut lowlinks,
224                    &mut on_stack,
225                    &mut components,
226                )?;
227            }
228        }
229        
230        // Create component mapping
231        let mut node_to_component: HashMap<String, u32> = HashMap::new();
232        for (comp_id, component) in components.into_iter().enumerate() {
233            for node in component {
234                node_to_component.insert(node, comp_id as u32);
235            }
236        }
237        
238        Ok(node_to_component)
239    }
240    
241    fn tarjan_strongconnect(
242        &self,
243        node: &str,
244        graph: &ArrowGraph,
245        index_counter: &mut usize,
246        stack: &mut Vec<String>,
247        indices: &mut HashMap<String, usize>,
248        lowlinks: &mut HashMap<String, usize>,
249        on_stack: &mut HashMap<String, bool>,
250        components: &mut Vec<Vec<String>>,
251    ) -> Result<()> {
252        // Set the depth index for this node
253        indices.insert(node.to_string(), *index_counter);
254        lowlinks.insert(node.to_string(), *index_counter);
255        *index_counter += 1;
256        
257        // Push node onto stack
258        stack.push(node.to_string());
259        on_stack.insert(node.to_string(), true);
260        
261        // Consider successors
262        if let Some(neighbors) = graph.neighbors(node) {
263            for neighbor in neighbors {
264                if !indices.contains_key(neighbor) {
265                    // Successor has not yet been visited; recurse on it
266                    self.tarjan_strongconnect(
267                        neighbor,
268                        graph,
269                        index_counter,
270                        stack,
271                        indices,
272                        lowlinks,
273                        on_stack,
274                        components,
275                    )?;
276                    
277                    let neighbor_lowlink = *lowlinks.get(neighbor).unwrap_or(&0);
278                    let current_lowlink = *lowlinks.get(node).unwrap_or(&0);
279                    lowlinks.insert(node.to_string(), current_lowlink.min(neighbor_lowlink));
280                } else if *on_stack.get(neighbor).unwrap_or(&false) {
281                    // Successor is in stack and hence in the current SCC
282                    let neighbor_index = *indices.get(neighbor).unwrap_or(&0);
283                    let current_lowlink = *lowlinks.get(node).unwrap_or(&0);
284                    lowlinks.insert(node.to_string(), current_lowlink.min(neighbor_index));
285                }
286            }
287        }
288        
289        // If node is a root node, pop the stack and create an SCC
290        let node_index = *indices.get(node).unwrap_or(&0);
291        let node_lowlink = *lowlinks.get(node).unwrap_or(&0);
292        
293        if node_lowlink == node_index {
294            let mut component = Vec::new();
295            loop {
296                if let Some(w) = stack.pop() {
297                    on_stack.insert(w.clone(), false);
298                    component.push(w.clone());
299                    if w == node {
300                        break;
301                    }
302                } else {
303                    break;
304                }
305            }
306            components.push(component);
307        }
308        
309        Ok(())
310    }
311}
312
313impl GraphAlgorithm for StronglyConnectedComponents {
314    fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
315        let component_map = self.tarjan_scc(graph)?;
316        
317        if component_map.is_empty() {
318            // Return empty result with proper schema
319            let schema = Arc::new(Schema::new(vec![
320                Field::new("node_id", DataType::Utf8, false),
321                Field::new("component_id", DataType::UInt32, false),
322            ]));
323            
324            return RecordBatch::try_new(
325                schema,
326                vec![
327                    Arc::new(StringArray::from(Vec::<String>::new())),
328                    Arc::new(UInt32Array::from(Vec::<u32>::new())),
329                ],
330            ).map_err(GraphError::from);
331        }
332        
333        // Sort by component ID for consistent output
334        let mut sorted_nodes: Vec<(&String, &u32)> = component_map.iter().collect();
335        sorted_nodes.sort_by_key(|(_, &component_id)| component_id);
336        
337        let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
338        let component_ids: Vec<u32> = sorted_nodes.iter().map(|(_, &comp)| comp).collect();
339        
340        let schema = Arc::new(Schema::new(vec![
341            Field::new("node_id", DataType::Utf8, false),
342            Field::new("component_id", DataType::UInt32, false),
343        ]));
344        
345        RecordBatch::try_new(
346            schema,
347            vec![
348                Arc::new(StringArray::from(node_ids)),
349                Arc::new(UInt32Array::from(component_ids)),
350            ],
351        ).map_err(GraphError::from)
352    }
353    
354    fn name(&self) -> &'static str {
355        "strongly_connected_components"
356    }
357    
358    fn description(&self) -> &'static str {
359        "Find strongly connected components using Tarjan's algorithm"
360    }
361}