1use arrow::record_batch::RecordBatch;
2use arrow::array::{StringArray, UInt32Array};
3use arrow::datatypes::{DataType, Field, Schema};
4use std::sync::Arc;
5use std::collections::{HashMap, HashSet};
6use crate::algorithms::{GraphAlgorithm, AlgorithmParams};
7use crate::graph::ArrowGraph;
8use crate::error::{GraphError, Result};
9
/// Stateless handle for the Louvain community detection algorithm.
///
/// The type has no fields; its behaviour is provided through its `GraphAlgorithm`
/// implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct LouvainCommunityDetection;
11
12impl GraphAlgorithm for LouvainCommunityDetection {
13 fn execute(&self, _graph: &ArrowGraph, _params: &AlgorithmParams) -> Result<RecordBatch> {
14 todo!("Implement Louvain community detection algorithm")
15 }
16
17 fn name(&self) -> &'static str {
18 "louvain"
19 }
20
21 fn description(&self) -> &'static str {
22 "Louvain community detection algorithm"
23 }
24}
25
/// Stateless handle for the Leiden community detection algorithm.
///
/// The type has no fields; the algorithm lives in its inherent methods and its
/// `GraphAlgorithm` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct LeidenCommunityDetection;
27
28impl LeidenCommunityDetection {
29 fn leiden_algorithm(
31 &self,
32 graph: &ArrowGraph,
33 resolution: f64,
34 max_iterations: usize,
35 _seed: Option<u64>,
36 ) -> Result<HashMap<String, u32>> {
37 let node_ids: Vec<String> = graph.node_ids().cloned().collect();
38 let node_count = node_ids.len();
39
40 if node_count == 0 {
41 return Ok(HashMap::new());
42 }
43
44 let mut communities: HashMap<String, u32> = HashMap::new();
46 for (i, node_id) in node_ids.iter().enumerate() {
47 communities.insert(node_id.clone(), i as u32);
48 }
49
50 let mut iteration = 0;
51 let mut improved = true;
52
53 while improved && iteration < max_iterations {
54 improved = false;
55 iteration += 1;
56
57 if node_count <= 10 || iteration >= 1 {
59 break;
60 }
61
62 let mut local_moves = true;
64 while local_moves {
65 local_moves = false;
66
67 for node_id in &node_ids {
68 let current_community = *communities.get(node_id).unwrap();
69 let best_community = self.find_best_community(
70 node_id,
71 graph,
72 &communities,
73 resolution,
74 )?;
75
76 if best_community != current_community {
77 communities.insert(node_id.clone(), best_community);
78 local_moves = true;
79 improved = true;
80 }
81 }
82 }
83
84 let refined_communities = self.refine_communities(
86 graph,
87 &communities,
88 resolution,
89 )?;
90
91 if refined_communities != communities {
92 communities = refined_communities;
93 improved = true;
94 }
95
96 }
100
101 self.renumber_communities(communities)
103 }
104
105 fn find_best_community(
106 &self,
107 node_id: &str,
108 graph: &ArrowGraph,
109 communities: &HashMap<String, u32>,
110 resolution: f64,
111 ) -> Result<u32> {
112 let current_community = *communities.get(node_id).unwrap();
113 let mut best_community = current_community;
114 let mut best_gain = 0.0;
115
116 let mut neighbor_communities = HashSet::new();
118 neighbor_communities.insert(current_community);
119
120 if let Some(neighbors) = graph.neighbors(node_id) {
121 for neighbor in neighbors {
122 if let Some(&neighbor_community) = communities.get(neighbor) {
123 neighbor_communities.insert(neighbor_community);
124 }
125 }
126 }
127
128 for &community in &neighbor_communities {
130 let gain = self.calculate_modularity_gain(
131 node_id,
132 community,
133 graph,
134 communities,
135 resolution,
136 )?;
137
138 if gain > best_gain {
139 best_gain = gain;
140 best_community = community;
141 }
142 }
143
144 Ok(best_community)
145 }
146
147 fn calculate_modularity_gain(
148 &self,
149 node_id: &str,
150 target_community: u32,
151 graph: &ArrowGraph,
152 communities: &HashMap<String, u32>,
153 resolution: f64,
154 ) -> Result<f64> {
155 let current_community = *communities.get(node_id).unwrap();
156
157 if target_community == current_community {
158 return Ok(0.0);
159 }
160
161 let node_degree = graph.neighbors(node_id)
163 .map(|neighbors| neighbors.len() as f64)
164 .unwrap_or(0.0);
165
166 if node_degree == 0.0 {
167 return Ok(0.0);
168 }
169
170 let total_edges = graph.edge_count() as f64;
171 if total_edges == 0.0 {
172 return Ok(0.0);
173 }
174
175 let mut connections_to_target = 0.0;
177 if let Some(neighbors) = graph.neighbors(node_id) {
178 for neighbor in neighbors {
179 if let Some(&neighbor_community) = communities.get(neighbor) {
180 if neighbor_community == target_community {
181 let weight = graph.indexes.edge_weights
183 .get(&(node_id.to_string(), neighbor.to_string()))
184 .copied()
185 .unwrap_or(1.0);
186 connections_to_target += weight;
187 }
188 }
189 }
190 }
191
192 let target_community_degree = self.calculate_community_degree(
194 target_community,
195 graph,
196 communities,
197 )?;
198
199 let gain = (connections_to_target / total_edges) -
201 resolution * (node_degree * target_community_degree) / (2.0 * total_edges * total_edges);
202
203 Ok(gain)
204 }
205
206 fn calculate_community_degree(
207 &self,
208 community: u32,
209 graph: &ArrowGraph,
210 communities: &HashMap<String, u32>,
211 ) -> Result<f64> {
212 let mut degree = 0.0;
213
214 for (node_id, &node_community) in communities {
215 if node_community == community {
216 degree += graph.neighbors(node_id)
217 .map(|neighbors| neighbors.len() as f64)
218 .unwrap_or(0.0);
219 }
220 }
221
222 Ok(degree)
223 }
224
225 fn refine_communities(
226 &self,
227 graph: &ArrowGraph,
228 communities: &HashMap<String, u32>,
229 resolution: f64,
230 ) -> Result<HashMap<String, u32>> {
231 let mut refined_communities = communities.clone();
232
233 let mut community_nodes: HashMap<u32, Vec<String>> = HashMap::new();
235 for (node_id, &community) in communities {
236 community_nodes.entry(community)
237 .or_default()
238 .push(node_id.clone());
239 }
240
241 for (community_id, nodes) in community_nodes {
243 if nodes.len() <= 1 {
244 continue;
245 }
246
247 let subcommunities = self.split_community(
248 &nodes,
249 graph,
250 resolution,
251 )?;
252
253 if subcommunities.len() > 1 {
255 let mut next_community_id = refined_communities.values().max().unwrap_or(&0) + 1;
256
257 for (i, subcom_nodes) in subcommunities.into_iter().enumerate() {
258 let target_community = if i == 0 {
259 community_id } else {
261 let id = next_community_id;
262 next_community_id += 1;
263 id
264 };
265
266 for node_id in subcom_nodes {
267 refined_communities.insert(node_id, target_community);
268 }
269 }
270 }
271 }
272
273 Ok(refined_communities)
274 }
275
276 fn split_community(
277 &self,
278 nodes: &[String],
279 graph: &ArrowGraph,
280 _resolution: f64,
281 ) -> Result<Vec<Vec<String>>> {
282 if nodes.len() <= 2 {
283 return Ok(vec![nodes.to_vec()]);
284 }
285
286 let mut visited = HashSet::new();
288 let mut subcommunities = Vec::new();
289
290 for node in nodes {
291 if visited.contains(node) {
292 continue;
293 }
294
295 let mut subcom = Vec::new();
296 let mut stack = vec![node.clone()];
297
298 while let Some(current) = stack.pop() {
299 if visited.contains(¤t) {
300 continue;
301 }
302
303 visited.insert(current.clone());
304 subcom.push(current.clone());
305
306 if let Some(neighbors) = graph.neighbors(¤t) {
308 for neighbor in neighbors {
309 if nodes.contains(neighbor) && !visited.contains(neighbor) {
310 stack.push(neighbor.clone());
311 }
312 }
313 }
314 }
315
316 if !subcom.is_empty() {
317 subcommunities.push(subcom);
318 }
319 }
320
321 Ok(subcommunities)
322 }
323
324 fn renumber_communities(
325 &self,
326 communities: HashMap<String, u32>,
327 ) -> Result<HashMap<String, u32>> {
328 let mut community_mapping = HashMap::new();
329 let mut next_id = 0u32;
330 let mut renumbered = HashMap::new();
331
332 for (node_id, &community) in &communities {
333 let new_community = *community_mapping.entry(community)
334 .or_insert_with(|| {
335 let id = next_id;
336 next_id += 1;
337 id
338 });
339
340 renumbered.insert(node_id.clone(), new_community);
341 }
342
343 Ok(renumbered)
344 }
345}
346
347impl GraphAlgorithm for LeidenCommunityDetection {
348 fn execute(&self, graph: &ArrowGraph, params: &AlgorithmParams) -> Result<RecordBatch> {
349 let resolution: f64 = params.get("resolution").unwrap_or(1.0);
350 let max_iterations: usize = params.get("max_iterations").unwrap_or(10);
351 let seed: Option<u64> = params.get("seed");
352
353 if resolution <= 0.0 {
355 return Err(GraphError::invalid_parameter(
356 "resolution must be greater than 0.0"
357 ));
358 }
359
360 if max_iterations == 0 {
361 return Err(GraphError::invalid_parameter(
362 "max_iterations must be greater than 0"
363 ));
364 }
365
366 let communities = self.leiden_algorithm(graph, resolution, max_iterations, seed)?;
367
368 if communities.is_empty() {
369 let schema = Arc::new(Schema::new(vec![
371 Field::new("node_id", DataType::Utf8, false),
372 Field::new("community_id", DataType::UInt32, false),
373 ]));
374
375 return RecordBatch::try_new(
376 schema,
377 vec![
378 Arc::new(StringArray::from(Vec::<String>::new())),
379 Arc::new(UInt32Array::from(Vec::<u32>::new())),
380 ],
381 ).map_err(GraphError::from);
382 }
383
384 let mut sorted_nodes: Vec<(&String, &u32)> = communities.iter().collect();
386 sorted_nodes.sort_by_key(|(_, &community_id)| community_id);
387
388 let node_ids: Vec<String> = sorted_nodes.iter().map(|(node, _)| (*node).clone()).collect();
389 let community_ids: Vec<u32> = sorted_nodes.iter().map(|(_, &comm)| comm).collect();
390
391 let schema = Arc::new(Schema::new(vec![
392 Field::new("node_id", DataType::Utf8, false),
393 Field::new("community_id", DataType::UInt32, false),
394 ]));
395
396 RecordBatch::try_new(
397 schema,
398 vec![
399 Arc::new(StringArray::from(node_ids)),
400 Arc::new(UInt32Array::from(community_ids)),
401 ],
402 ).map_err(GraphError::from)
403 }
404
405 fn name(&self) -> &'static str {
406 "leiden"
407 }
408
409 fn description(&self) -> &'static str {
410 "Leiden community detection algorithm with refinement phase"
411 }
412}