1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, UInt32Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use std::sync::Arc;
5use std::collections::HashMap;
6use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
7use crate::graph::ArrowGraph;
8use crate::error::{GraphError, Result};
9
/// Disjoint-set forest over string node names, using union by rank and
/// path compression.
#[derive(Debug)]
struct UnionFind {
    /// Parent pointer of every registered node; a root is its own parent.
    parent: HashMap<String, String>,
    /// Rank (tree-height bound) per node; only consulted for roots.
    rank: HashMap<String, usize>,
    /// Number of members of each set, stored under the set's current root.
    component_sizes: HashMap<String, usize>,
}

impl UnionFind {
    /// Creates an empty forest with no registered nodes.
    fn new() -> Self {
        UnionFind {
            parent: HashMap::new(),
            rank: HashMap::new(),
            component_sizes: HashMap::new(),
        }
    }

    /// Registers `node` as a singleton set. Does nothing if the node is
    /// already known, so existing unions are never reset.
    fn make_set(&mut self, node: String) {
        // The entry API checks and inserts with a single lookup.
        if let std::collections::hash_map::Entry::Vacant(slot) = self.parent.entry(node) {
            let name = slot.key().clone();
            self.rank.insert(name.clone(), 0);
            self.component_sizes.insert(name.clone(), 1);
            slot.insert(name);
        }
    }

    /// Returns the root of the set containing `node`, or `None` when the node
    /// was never registered with [`UnionFind::make_set`].
    ///
    /// Every node visited on the way up is re-pointed directly at the root
    /// (path compression).
    fn find(&mut self, node: &str) -> Option<String> {
        // Pass 1: climb parent pointers until a self-parented node is reached.
        let mut top: &str = node;
        loop {
            let up = self.parent.get(top)?;
            if up.as_str() == top {
                break;
            }
            top = up.as_str();
        }
        let root = top.to_owned();

        // Pass 2: walk the same path again, pointing each node at the root.
        if root != node {
            let mut cursor = node.to_owned();
            while cursor != root {
                let slot = self.parent.get_mut(cursor.as_str())?;
                cursor = std::mem::replace(slot, root.clone());
            }
        }

        Some(root)
    }

    /// Merges the sets containing `node1` and `node2`.
    ///
    /// Returns `true` when two distinct sets were merged, and `false` when
    /// either node is unknown or both already share a root.
    fn union(&mut self, node1: &str, node2: &str) -> bool {
        let (root1, root2) = match (self.find(node1), self.find(node2)) {
            (Some(first), Some(second)) => (first, second),
            _ => return false,
        };

        if root1 == root2 {
            return false;
        }

        let rank1 = self.rank.get(&root1).copied().unwrap_or(0);
        let rank2 = self.rank.get(&root2).copied().unwrap_or(0);

        // Union by rank: the higher-ranked root absorbs the other; on a tie
        // the first root wins and its rank grows by one.
        let (new_root, old_root) = match rank1.cmp(&rank2) {
            std::cmp::Ordering::Greater => (root1, root2),
            std::cmp::Ordering::Less => (root2, root1),
            std::cmp::Ordering::Equal => {
                self.rank.insert(root1.clone(), rank1 + 1);
                (root1, root2)
            }
        };

        // Move the absorbed set's size onto the surviving root.
        let absorbed = self.component_sizes.remove(&old_root).unwrap_or(0);
        *self.component_sizes.entry(new_root.clone()).or_insert(0) += absorbed;

        self.parent.insert(old_root, new_root);

        true
    }

    /// Groups every registered node under the root of its set.
    ///
    /// The map is keyed by root name; the order of members inside each
    /// `Vec` follows hash-map iteration order and is therefore unspecified.
    fn get_components(&mut self) -> HashMap<String, Vec<String>> {
        let mut components: HashMap<String, Vec<String>> = HashMap::new();

        // Snapshot the names first: `find` needs `&mut self` for compression.
        let nodes: Vec<String> = self.parent.keys().cloned().collect();
        for node in nodes {
            if let Some(root) = self.find(&node) {
                components.entry(root).or_default().push(node);
            }
        }

        components
    }

    /// Number of disjoint sets currently in the forest.
    // Kept for callers/tests outside the main algorithm path.
    #[allow(dead_code)]
    fn component_count(&mut self) -> usize {
        // Each set has exactly one self-parented node (its root), so counting
        // roots avoids building the member lists.
        self.parent
            .iter()
            .filter(|(node, up)| node == up)
            .count()
    }
}
111
112pub struct WeaklyConnectedComponents;
113
114impl WeaklyConnectedComponents {
115 fn compute_components(&self, graph: &ArrowGraph) -> Result<HashMap<String, u32>> {
117 let mut uf = UnionFind::new();
118
119 for node_id in graph.node_ids() {
121 uf.make_set(node_id.clone());
122 }
123
124 for node_id in graph.node_ids() {
126 if let Some(neighbors) = graph.neighbors(node_id) {
127 for neighbor in neighbors {
128 uf.union(node_id, neighbor);
129 }
130 }
131 }
132
133 let components = uf.get_components();
135 let mut node_to_component: HashMap<String, u32> = HashMap::new();
136
137 for (component_id, (_root, nodes)) in components.into_iter().enumerate() {
138 for node in nodes {
139 node_to_component.insert(node, component_id as u32);
140 }
141 }
142
143 Ok(node_to_component)
144 }
145}
146
147impl GraphAlgorithm for WeaklyConnectedComponents {
148 fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
149 let component_map = self.compute_components(graph)?;
150
151 if component_map.is_empty() {
152 let schema = Arc::new(Schema::new(vec![
154 Field::new("node_id", DataType::Utf8, false),
155 Field::new("component_id", DataType::UInt32, false),
156 ]));
157
158 return RecordBatch::try_new(
159 schema,
160 vec![
161 Arc::new(StringArray::from(Vec::<String>::new())),
162 Arc::new(UInt32Array::from(Vec::<u32>::new())),
163 ],
164 ).map_err(GraphError::from);
165 }
166
167 let mut sorted_nodes: Vec<(&String, &u32)> = component_map.iter().collect();
169 sorted_nodes.sort_by_key(|(_, &component_id)| component_id);
170
171 let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
172 let component_ids: Vec<u32> = sorted_nodes.iter().map(|(_, &comp)| comp).collect();
173
174 let schema = Arc::new(Schema::new(vec![
175 Field::new("node_id", DataType::Utf8, false),
176 Field::new("component_id", DataType::UInt32, false),
177 ]));
178
179 RecordBatch::try_new(
180 schema,
181 vec![
182 Arc::new(StringArray::from(node_ids)),
183 Arc::new(UInt32Array::from(component_ids)),
184 ],
185 ).map_err(GraphError::from)
186 }
187
188 fn name(&self) -> &'static str {
189 "weakly_connected_components"
190 }
191
192 fn description(&self) -> &'static str {
193 "Find weakly connected components using Union-Find with path compression"
194 }
195}
196
197pub struct StronglyConnectedComponents;
198
199impl StronglyConnectedComponents {
200 fn tarjan_scc(&self, graph: &ArrowGraph) -> Result<HashMap<String, u32>> {
202 let mut index_counter = 0;
203 let mut stack = Vec::new();
204 let mut indices: HashMap<String, usize> = HashMap::new();
205 let mut lowlinks: HashMap<String, usize> = HashMap::new();
206 let mut on_stack: HashMap<String, bool> = HashMap::new();
207 let mut components: Vec<Vec<String>> = Vec::new();
208
209 for node_id in graph.node_ids() {
211 on_stack.insert(node_id.clone(), false);
212 }
213
214 for node_id in graph.node_ids() {
216 if !indices.contains_key(node_id) {
217 self.tarjan_strongconnect(
218 node_id,
219 graph,
220 &mut index_counter,
221 &mut stack,
222 &mut indices,
223 &mut lowlinks,
224 &mut on_stack,
225 &mut components,
226 )?;
227 }
228 }
229
230 let mut node_to_component: HashMap<String, u32> = HashMap::new();
232 for (comp_id, component) in components.into_iter().enumerate() {
233 for node in component {
234 node_to_component.insert(node, comp_id as u32);
235 }
236 }
237
238 Ok(node_to_component)
239 }
240
241 fn tarjan_strongconnect(
242 &self,
243 node: &str,
244 graph: &ArrowGraph,
245 index_counter: &mut usize,
246 stack: &mut Vec<String>,
247 indices: &mut HashMap<String, usize>,
248 lowlinks: &mut HashMap<String, usize>,
249 on_stack: &mut HashMap<String, bool>,
250 components: &mut Vec<Vec<String>>,
251 ) -> Result<()> {
252 indices.insert(node.to_string(), *index_counter);
254 lowlinks.insert(node.to_string(), *index_counter);
255 *index_counter += 1;
256
257 stack.push(node.to_string());
259 on_stack.insert(node.to_string(), true);
260
261 if let Some(neighbors) = graph.neighbors(node) {
263 for neighbor in neighbors {
264 if !indices.contains_key(neighbor) {
265 self.tarjan_strongconnect(
267 neighbor,
268 graph,
269 index_counter,
270 stack,
271 indices,
272 lowlinks,
273 on_stack,
274 components,
275 )?;
276
277 let neighbor_lowlink = *lowlinks.get(neighbor).unwrap_or(&0);
278 let current_lowlink = *lowlinks.get(node).unwrap_or(&0);
279 lowlinks.insert(node.to_string(), current_lowlink.min(neighbor_lowlink));
280 } else if *on_stack.get(neighbor).unwrap_or(&false) {
281 let neighbor_index = *indices.get(neighbor).unwrap_or(&0);
283 let current_lowlink = *lowlinks.get(node).unwrap_or(&0);
284 lowlinks.insert(node.to_string(), current_lowlink.min(neighbor_index));
285 }
286 }
287 }
288
289 let node_index = *indices.get(node).unwrap_or(&0);
291 let node_lowlink = *lowlinks.get(node).unwrap_or(&0);
292
293 if node_lowlink == node_index {
294 let mut component = Vec::new();
295 loop {
296 if let Some(w) = stack.pop() {
297 on_stack.insert(w.clone(), false);
298 component.push(w.clone());
299 if w == node {
300 break;
301 }
302 } else {
303 break;
304 }
305 }
306 components.push(component);
307 }
308
309 Ok(())
310 }
311}
312
313impl GraphAlgorithm for StronglyConnectedComponents {
314 fn execute(&self, graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
315 let component_map = self.tarjan_scc(graph)?;
316
317 if component_map.is_empty() {
318 let schema = Arc::new(Schema::new(vec![
320 Field::new("node_id", DataType::Utf8, false),
321 Field::new("component_id", DataType::UInt32, false),
322 ]));
323
324 return RecordBatch::try_new(
325 schema,
326 vec![
327 Arc::new(StringArray::from(Vec::<String>::new())),
328 Arc::new(UInt32Array::from(Vec::<u32>::new())),
329 ],
330 ).map_err(GraphError::from);
331 }
332
333 let mut sorted_nodes: Vec<(&String, &u32)> = component_map.iter().collect();
335 sorted_nodes.sort_by_key(|(_, &component_id)| component_id);
336
337 let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
338 let component_ids: Vec<u32> = sorted_nodes.iter().map(|(_, &comp)| comp).collect();
339
340 let schema = Arc::new(Schema::new(vec![
341 Field::new("node_id", DataType::Utf8, false),
342 Field::new("component_id", DataType::UInt32, false),
343 ]));
344
345 RecordBatch::try_new(
346 schema,
347 vec![
348 Arc::new(StringArray::from(node_ids)),
349 Arc::new(UInt32Array::from(component_ids)),
350 ],
351 ).map_err(GraphError::from)
352 }
353
354 fn name(&self) -> &'static str {
355 "strongly_connected_components"
356 }
357
358 fn description(&self) -> &'static str {
359 "Find strongly connected components using Tarjan's algorithm"
360 }
361}