1use crate::clause::Clause;
2#[allow(unused_imports)]
14use crate::prelude::*;
15
16#[derive(Debug, Clone)]
21pub struct VariableIncidenceGraph {
22 num_vars: usize,
24 adjacency: Vec<HashMap<usize, f64>>,
26 total_weight: f64,
28}
29
30impl VariableIncidenceGraph {
31 pub fn from_clauses(num_vars: usize, clauses: &[Clause]) -> Self {
33 let mut adjacency = vec![HashMap::new(); num_vars];
34 let mut total_weight = 0.0;
35
36 for clause in clauses {
37 let vars: Vec<usize> = clause.lits.iter().map(|lit| lit.var().index()).collect();
38
39 for i in 0..vars.len() {
41 for j in (i + 1)..vars.len() {
42 let (u, v) = (vars[i], vars[j]);
43 if u < num_vars && v < num_vars {
44 *adjacency[u].entry(v).or_insert(0.0) += 1.0;
45 *adjacency[v].entry(u).or_insert(0.0) += 1.0;
46 total_weight += 2.0;
47 }
48 }
49 }
50 }
51
52 Self {
53 num_vars,
54 adjacency,
55 total_weight,
56 }
57 }
58
59 pub fn num_vars(&self) -> usize {
61 self.num_vars
62 }
63
64 pub fn neighbors(&self, var: usize) -> &HashMap<usize, f64> {
66 &self.adjacency[var]
67 }
68
69 pub fn degree(&self, var: usize) -> f64 {
71 self.adjacency[var].values().sum()
72 }
73
74 pub fn total_weight(&self) -> f64 {
76 self.total_weight
77 }
78}
79
/// Partition of variables into communities.
///
/// Each variable index maps to a community id. A fresh value places every variable
/// in its own community; `renumber_communities` compacts ids to
/// `0..num_communities` in order of first appearance.
#[derive(Debug, Clone)]
pub struct Communities {
    /// Community id of each variable, indexed by variable.
    assignment: Vec<usize>,
    /// Count of community ids; set on construction and by `renumber_communities`.
    num_communities: usize,
    /// Modularity stored for this partition (0.0 until a detector writes it).
    modularity: f64,
}

impl Communities {
    /// Singleton partition: variable `v` belongs to community `v`.
    pub fn new(num_vars: usize) -> Self {
        let identity: Vec<usize> = (0..num_vars).collect();
        Self {
            num_communities: identity.len(),
            assignment: identity,
            modularity: 0.0,
        }
    }

    /// Community id of `var`.
    ///
    /// # Panics
    /// Panics if `var` is out of range.
    pub fn community(&self, var: usize) -> usize {
        self.assignment[var]
    }

    /// Number of communities recorded for this partition.
    pub fn num_communities(&self) -> usize {
        self.num_communities
    }

    /// Modularity stored for this partition.
    pub fn modularity(&self) -> f64 {
        self.modularity
    }

    /// Puts `var` into `community` without touching `num_communities`.
    fn assign(&mut self, var: usize, community: usize) {
        let slot = &mut self.assignment[var];
        *slot = community;
    }

    /// Rewrites ids to `0..k`, numbered by first appearance in variable order,
    /// and sets `num_communities` to `k`.
    fn renumber_communities(&mut self) {
        let mut remap = HashMap::new();
        for id in self.assignment.iter_mut() {
            // The map's size before insertion is the next unused compact id.
            let fresh = remap.len();
            *id = *remap.entry(*id).or_insert(fresh);
        }
        self.num_communities = remap.len();
    }

    /// Groups variables by community id; group `c` lists its variables in ascending order.
    ///
    /// Variables whose id is `>= num_communities` are omitted.
    pub fn get_communities(&self) -> Vec<Vec<usize>> {
        let mut groups: Vec<Vec<usize>> = (0..self.num_communities).map(|_| Vec::new()).collect();
        for (var, &id) in self.assignment.iter().enumerate() {
            if let Some(group) = groups.get_mut(id) {
                group.push(var);
            }
        }
        groups
    }
}
150
151pub struct LouvainDetector {
160 max_iterations: usize,
162 min_improvement: f64,
164}
165
166impl Default for LouvainDetector {
167 fn default() -> Self {
168 Self {
169 max_iterations: 10,
170 min_improvement: 1e-6,
171 }
172 }
173}
174
175impl LouvainDetector {
176 pub fn new(max_iterations: usize, min_improvement: f64) -> Self {
178 Self {
179 max_iterations,
180 min_improvement,
181 }
182 }
183
184 pub fn detect(&self, vig: &VariableIncidenceGraph) -> Communities {
186 let mut communities = Communities::new(vig.num_vars());
187
188 for _ in 0..self.max_iterations {
190 let old_modularity = communities.modularity;
191 let improved = self.local_moving(vig, &mut communities);
192
193 if !improved || (communities.modularity - old_modularity) < self.min_improvement {
194 break;
195 }
196 }
197
198 communities.renumber_communities();
199 communities
200 }
201
202 fn local_moving(&self, vig: &VariableIncidenceGraph, communities: &mut Communities) -> bool {
206 let mut improved = false;
207 let m2 = vig.total_weight();
208
209 let mut comm_degrees: Vec<f64> = vec![0.0; vig.num_vars()];
211 for var in 0..vig.num_vars() {
212 let comm = communities.community(var);
213 comm_degrees[comm] += vig.degree(var);
214 }
215
216 for var in 0..vig.num_vars() {
218 let current_comm = communities.community(var);
219 let var_degree = vig.degree(var);
220
221 let mut best_comm = current_comm;
223 let mut best_gain = 0.0;
224
225 let mut neighbor_comms: HashMap<usize, f64> = HashMap::new();
227 for (&neighbor, &weight) in vig.neighbors(var).iter() {
228 let neighbor_comm = communities.community(neighbor);
229 *neighbor_comms.entry(neighbor_comm).or_insert(0.0) += weight;
230 }
231
232 for (&comm, &edge_weight) in neighbor_comms.iter() {
233 if comm == current_comm {
234 continue;
235 }
236
237 let sigma_tot = comm_degrees[comm];
239 let k_i = var_degree;
240 let k_i_in = edge_weight;
241
242 let gain = (k_i_in / m2) - (sigma_tot * k_i / (m2 * m2));
243
244 if gain > best_gain {
245 best_gain = gain;
246 best_comm = comm;
247 }
248 }
249
250 if best_comm != current_comm && best_gain > 0.0 {
252 comm_degrees[current_comm] -= var_degree;
253 comm_degrees[best_comm] += var_degree;
254 communities.assign(var, best_comm);
255 improved = true;
256 }
257 }
258
259 communities.modularity = self.calculate_modularity(vig, communities);
261 improved
262 }
263
264 fn calculate_modularity(&self, vig: &VariableIncidenceGraph, communities: &Communities) -> f64 {
269 let m2 = vig.total_weight();
270 if m2 == 0.0 {
271 return 0.0;
272 }
273
274 let mut modularity = 0.0;
275
276 for i in 0..vig.num_vars() {
277 let comm_i = communities.community(i);
278 let deg_i = vig.degree(i);
279
280 for (&j, &weight) in vig.neighbors(i).iter() {
281 let comm_j = communities.community(j);
282
283 if comm_i == comm_j {
284 let deg_j = vig.degree(j);
285 modularity += weight - (deg_i * deg_j / m2);
286 }
287 }
288 }
289
290 modularity / m2
291 }
292}
293
294pub struct CommunityOrdering {
298 communities: Communities,
300 ordering: Vec<usize>,
302}
303
304impl CommunityOrdering {
305 pub fn new(communities: Communities) -> Self {
307 let mut ordering = Vec::with_capacity(communities.assignment.len());
308
309 let comm_groups = communities.get_communities();
311
312 for group in comm_groups {
313 ordering.extend(group);
314 }
315
316 Self {
317 communities,
318 ordering,
319 }
320 }
321
322 pub fn ordering(&self) -> &[usize] {
324 &self.ordering
325 }
326
327 pub fn community(&self, var: usize) -> usize {
329 self.communities.community(var)
330 }
331
332 pub fn same_community(&self, var: usize) -> Vec<usize> {
334 let target_comm = self.communities.community(var);
335 (0..self.communities.assignment.len())
336 .filter(|&v| self.communities.community(v) == target_comm)
337 .collect()
338 }
339}
340
341#[derive(Debug, Clone, Default)]
343pub struct CommunityStats {
344 pub num_communities: usize,
346 pub modularity: f64,
348 pub avg_community_size: f64,
350 pub max_community_size: usize,
352 pub min_community_size: usize,
354 pub num_vars: usize,
356}
357
358impl CommunityStats {
359 pub fn from_communities(communities: &Communities) -> Self {
361 let groups = communities.get_communities();
362 let sizes: Vec<usize> = groups.iter().map(|g| g.len()).collect();
363
364 let avg_size = if !sizes.is_empty() {
365 sizes.iter().sum::<usize>() as f64 / sizes.len() as f64
366 } else {
367 0.0
368 };
369
370 Self {
371 num_communities: communities.num_communities(),
372 modularity: communities.modularity(),
373 avg_community_size: avg_size,
374 max_community_size: sizes.iter().copied().max().unwrap_or(0),
375 min_community_size: sizes.iter().copied().min().unwrap_or(0),
376 num_vars: communities.assignment.len(),
377 }
378 }
379
380 pub fn display(&self) -> String {
382 format!(
383 "Community Detection Statistics:\n\
384 - Variables: {}\n\
385 - Communities: {}\n\
386 - Modularity: {:.4}\n\
387 - Avg Size: {:.2}\n\
388 - Size Range: [{}, {}]",
389 self.num_vars,
390 self.num_communities,
391 self.modularity,
392 self.avg_community_size,
393 self.min_community_size,
394 self.max_community_size
395 )
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::literal::{Lit, Var};

    /// Literal over variable `var`: positive when `sign` is true, negative otherwise.
    fn make_lit(var: usize, sign: bool) -> Lit {
        let v = Var::new(var as u32);
        if sign {
            Lit::pos(v)
        } else {
            Lit::neg(v)
        }
    }

    /// Original clause made of negative literals over `vars`, in the given order.
    fn neg_clause(vars: &[usize]) -> Clause {
        Clause::original(vars.iter().map(|&v| make_lit(v, false)).collect::<Vec<Lit>>())
    }

    /// Runs the default Louvain detector on the VIG of `clauses`.
    fn detect_default(num_vars: usize, clauses: &[Clause]) -> Communities {
        let vig = VariableIncidenceGraph::from_clauses(num_vars, clauses);
        LouvainDetector::default().detect(&vig)
    }

    #[test]
    fn test_vig_creation() {
        let clauses = [neg_clause(&[0, 1]), neg_clause(&[1, 2]), neg_clause(&[0, 2])];
        let vig = VariableIncidenceGraph::from_clauses(3, &clauses);

        assert_eq!(vig.num_vars(), 3);
        for var in 0..3 {
            assert!(vig.degree(var) > 0.0);
        }
    }

    #[test]
    fn test_vig_edges() {
        let vig = VariableIncidenceGraph::from_clauses(2, &[neg_clause(&[0, 1])]);

        assert_eq!(vig.neighbors(0).get(&1), Some(&1.0));
        assert_eq!(vig.neighbors(1).get(&0), Some(&1.0));
    }

    #[test]
    fn test_communities_creation() {
        let communities = Communities::new(5);

        assert_eq!(communities.num_communities(), 5);
        assert_eq!(communities.community(0), 0);
        assert_eq!(communities.community(4), 4);
    }

    #[test]
    fn test_communities_renumber() {
        let mut communities = Communities::new(5);
        for (var, &id) in [10, 10, 20, 20, 30].iter().enumerate() {
            communities.assign(var, id);
        }

        communities.renumber_communities();

        assert_eq!(communities.num_communities(), 3);
        assert_eq!(communities.community(0), communities.community(1));
        assert_eq!(communities.community(2), communities.community(3));
    }

    #[test]
    fn test_louvain_simple() {
        // A path 0-1-2-3 plus a separate edge 4-5.
        let clauses = [
            neg_clause(&[0, 1]),
            neg_clause(&[1, 2]),
            neg_clause(&[2, 3]),
            neg_clause(&[4, 5]),
        ];
        let communities = detect_default(6, &clauses);

        assert!(communities.num_communities() <= 3);
        assert_ne!(communities.community(0), communities.community(4));
    }

    #[test]
    fn test_community_ordering() {
        let clauses = [neg_clause(&[0, 1]), neg_clause(&[2, 3])];
        let ordering = CommunityOrdering::new(detect_default(4, &clauses));

        assert_eq!(ordering.ordering().len(), 4);
    }

    #[test]
    fn test_same_community() {
        let clauses = [neg_clause(&[0, 1]), neg_clause(&[2, 3])];
        let ordering = CommunityOrdering::new(detect_default(4, &clauses));

        assert!(ordering.same_community(0).contains(&0));
    }

    #[test]
    fn test_community_stats() {
        let mut communities = Communities::new(10);
        // Variables 0..5 go to community 0, variables 5..10 to community 1.
        for var in 0..10 {
            communities.assign(var, var / 5);
        }
        communities.renumber_communities();

        let stats = CommunityStats::from_communities(&communities);

        assert_eq!(stats.num_communities, 2);
        assert_eq!(stats.num_vars, 10);
        assert_eq!(stats.avg_community_size, 5.0);
        assert_eq!(stats.min_community_size, 5);
        assert_eq!(stats.max_community_size, 5);
    }

    #[test]
    fn test_modularity_calculation() {
        let clauses = [neg_clause(&[0, 1]), neg_clause(&[2, 3])];
        let communities = detect_default(4, &clauses);

        assert!(communities.modularity() >= 0.0);
    }
}