1use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18
19use super::{AdjacencyGraph, CommunityResult};
20use crate::error::{ClusteringError, Result};
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct StochasticBlockModelConfig {
29 pub num_blocks: Option<usize>,
31 pub k_range: (usize, usize),
33 pub max_iterations: usize,
35 pub convergence_threshold: f64,
37 pub degree_corrected: bool,
39 pub seed: u64,
41}
42
43impl Default for StochasticBlockModelConfig {
44 fn default() -> Self {
45 Self {
46 num_blocks: None,
47 k_range: (2, 8),
48 max_iterations: 100,
49 convergence_threshold: 1e-6,
50 degree_corrected: false,
51 seed: 42,
52 }
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SBMResult {
59 pub community: CommunityResult,
61 pub block_matrix: Vec<f64>,
63 pub k: usize,
65 pub log_likelihood: f64,
67 pub icl_score: f64,
69 pub degree_corrections: Option<Vec<f64>>,
71}
72
/// Minimal xorshift64 pseudo-random generator used for reproducible initialisation
/// and graph sampling. The tuple field is the current 64-bit state.
struct Xorshift64(u64);

impl Xorshift64 {
    /// Creates a generator from `seed`. A zero seed is replaced by 1 because the
    /// all-zero state maps to itself under the xorshift recurrence.
    fn new(seed: u64) -> Self {
        Self(seed.max(1))
    }

    /// Advances the state with the (13, 7, 17) shift triple and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let s = self.0;
        let s = s ^ (s << 13);
        let s = s ^ (s >> 7);
        let s = s ^ (s << 17);
        self.0 = s;
        s
    }

    /// Returns a uniform sample in `[0, 1)` built from the top 53 bits of `next_u64`.
    fn next_f64(&mut self) -> f64 {
        const SCALE: f64 = (1u64 << 53) as f64;
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 / SCALE
    }
}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct StochasticBlockModel {
104 pub config: StochasticBlockModelConfig,
106}
107
108impl StochasticBlockModel {
109 pub fn new(config: StochasticBlockModelConfig) -> Self {
111 Self { config }
112 }
113
114 pub fn fit(&self, graph: &AdjacencyGraph) -> Result<SBMResult> {
120 let n = graph.n_nodes;
121 if n == 0 {
122 return Err(ClusteringError::InvalidInput(
123 "Graph has zero nodes".to_string(),
124 ));
125 }
126
127 if let Some(k) = self.config.num_blocks {
128 if k == 0 || k > n {
129 return Err(ClusteringError::InvalidInput(format!(
130 "num_blocks ({}) must be in [1, {}]",
131 k, n
132 )));
133 }
134 self.fit_k(graph, k)
135 } else {
136 let k_min = self.config.k_range.0.max(1);
137 let k_max = self.config.k_range.1.min(n);
138 if k_min > k_max {
139 return Err(ClusteringError::InvalidInput("Invalid k_range".to_string()));
140 }
141
142 let mut best: Option<SBMResult> = None;
143 for k in k_min..=k_max {
144 let result = self.fit_k(graph, k)?;
145 let is_better = best
146 .as_ref()
147 .map(|b| result.icl_score > b.icl_score)
148 .unwrap_or(true);
149 if is_better {
150 best = Some(result);
151 }
152 }
153
154 best.ok_or_else(|| ClusteringError::ComputationError("No valid K found".to_string()))
155 }
156 }
157
158 fn fit_k(&self, graph: &AdjacencyGraph, k: usize) -> Result<SBMResult> {
160 let n = graph.n_nodes;
161 let mut rng = Xorshift64::new(self.config.seed.wrapping_add(k as u64));
162
163 let mut adj_matrix = vec![0.0_f64; n * n];
165 for i in 0..n {
166 for &(j, w) in &graph.adjacency[i] {
167 adj_matrix[i * n + j] = w;
168 }
169 }
170
171 let mut tau = vec![0.0_f64; n * k];
174 for i in 0..n {
175 let assigned = (i * k / n) % k; for r in 0..k {
177 tau[i * k + r] = if r == assigned {
178 0.8
179 } else {
180 0.2 / ((k - 1).max(1) as f64)
181 };
182 }
183 let noise_sum: f64 = (0..k).map(|_| rng.next_f64() * 0.1).sum();
185 for r in 0..k {
186 tau[i * k + r] += rng.next_f64() * 0.1;
187 }
188 let row_sum: f64 = (0..k).map(|r| tau[i * k + r]).sum();
190 if row_sum > 0.0 {
191 for r in 0..k {
192 tau[i * k + r] /= row_sum;
193 }
194 }
195 let _ = noise_sum; }
197
198 let mut b_matrix = vec![0.0_f64; k * k];
200 let mut theta = vec![1.0_f64; n];
202
203 let mut prev_ll = f64::NEG_INFINITY;
204
205 for _iter in 0..self.config.max_iterations {
206 self.m_step(graph, &adj_matrix, &tau, &mut b_matrix, &mut theta, n, k);
208
209 self.e_step(graph, &adj_matrix, &b_matrix, &theta, &mut tau, n, k);
211
212 let ll = self.log_likelihood(&adj_matrix, &b_matrix, &theta, &tau, n, k);
214
215 if (ll - prev_ll).abs() < self.config.convergence_threshold {
216 break;
217 }
218 prev_ll = ll;
219 }
220
221 let mut labels = vec![0usize; n];
223 for i in 0..n {
224 let mut best_r = 0;
225 let mut best_val = f64::NEG_INFINITY;
226 for r in 0..k {
227 if tau[i * k + r] > best_val {
228 best_val = tau[i * k + r];
229 best_r = r;
230 }
231 }
232 labels[i] = best_r;
233 }
234
235 let mut mapping: HashMap<usize, usize> = HashMap::new();
237 let mut next_id = 0usize;
238 for lbl in &labels {
239 if !mapping.contains_key(lbl) {
240 mapping.insert(*lbl, next_id);
241 next_id += 1;
242 }
243 }
244 let compacted: Vec<usize> = labels
245 .iter()
246 .map(|l| mapping.get(l).copied().unwrap_or(0))
247 .collect();
248 let num_communities = next_id;
249
250 let ll = self.log_likelihood(&adj_matrix, &b_matrix, &theta, &tau, n, k);
251 let icl = self.compute_icl(ll, &compacted, n, k);
252 let quality = graph.modularity(&compacted);
253
254 let degree_corrections = if self.config.degree_corrected {
255 Some(theta)
256 } else {
257 None
258 };
259
260 Ok(SBMResult {
261 community: CommunityResult {
262 labels: compacted,
263 num_communities,
264 quality_score: Some(quality),
265 },
266 block_matrix: b_matrix,
267 k,
268 log_likelihood: ll,
269 icl_score: icl,
270 degree_corrections,
271 })
272 }
273
274 fn m_step(
276 &self,
277 _graph: &AdjacencyGraph,
278 adj_matrix: &[f64],
279 tau: &[f64],
280 b_matrix: &mut [f64],
281 theta: &mut [f64],
282 n: usize,
283 k: usize,
284 ) {
285 for r in 0..k {
287 for s in 0..k {
288 let mut numerator = 0.0;
289 let mut denominator = 0.0;
290 for i in 0..n {
291 let tau_ir = tau[i * k + r];
292 if tau_ir < 1e-15 {
293 continue;
294 }
295 for j in 0..n {
296 if i == j {
297 continue;
298 }
299 let tau_js = tau[j * k + s];
300 if tau_js < 1e-15 {
301 continue;
302 }
303 numerator += tau_ir * adj_matrix[i * n + j] * tau_js;
304 denominator += tau_ir * tau_js;
305 }
306 }
307 let val = if denominator > 1e-15 {
309 numerator / denominator
310 } else {
311 0.5
312 };
313 b_matrix[r * k + s] = val.clamp(1e-10, 1.0 - 1e-10);
314 }
315 }
316
317 if self.config.degree_corrected {
319 for i in 0..n {
321 let actual_deg: f64 = (0..n)
322 .filter(|&j| j != i)
323 .map(|j| adj_matrix[i * n + j])
324 .sum();
325
326 let mut expected = 0.0;
327 for j in 0..n {
328 if j == i {
329 continue;
330 }
331 for r in 0..k {
332 for s in 0..k {
333 expected += tau[i * k + r] * tau[j * k + s] * b_matrix[r * k + s];
334 }
335 }
336 }
337 theta[i] = if expected > 1e-15 {
338 (actual_deg / expected).max(1e-10)
339 } else {
340 1.0
341 };
342 }
343 }
344 }
345
346 fn e_step(
348 &self,
349 _graph: &AdjacencyGraph,
350 adj_matrix: &[f64],
351 b_matrix: &[f64],
352 theta: &[f64],
353 tau: &mut [f64],
354 n: usize,
355 k: usize,
356 ) {
357 let mut pi = vec![0.0_f64; k];
359 for i in 0..n {
360 for r in 0..k {
361 pi[r] += tau[i * k + r];
362 }
363 }
364 let pi_sum: f64 = pi.iter().sum();
365 if pi_sum > 0.0 {
366 for r in 0..k {
367 pi[r] = (pi[r] / pi_sum).max(1e-10);
368 }
369 }
370
371 for i in 0..n {
372 let mut log_probs = vec![0.0_f64; k];
373 for r in 0..k {
374 log_probs[r] = pi[r].ln();
375
376 for j in 0..n {
377 if j == i {
378 continue;
379 }
380 for s in 0..k {
383 let tau_js = tau[j * k + s];
384 if tau_js < 1e-15 {
385 continue;
386 }
387
388 let mut p_rs = b_matrix[r * k + s];
389 if self.config.degree_corrected {
390 p_rs *= theta[i] * theta[j];
391 }
392 p_rs = p_rs.clamp(1e-15, 1.0 - 1e-15);
393
394 let a_ij = adj_matrix[i * n + j];
395 if a_ij > 0.0 {
396 log_probs[r] += tau_js * (a_ij * p_rs.ln());
397 } else {
398 log_probs[r] += tau_js * ((1.0 - p_rs).ln());
399 }
400 }
401 }
402 }
403
404 let max_lp = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
406 let mut sum_exp = 0.0;
407 for r in 0..k {
408 log_probs[r] = (log_probs[r] - max_lp).exp();
409 sum_exp += log_probs[r];
410 }
411 if sum_exp > 0.0 {
412 for r in 0..k {
413 tau[i * k + r] = (log_probs[r] / sum_exp).max(1e-15);
414 }
415 }
416 }
417 }
418
419 fn log_likelihood(
421 &self,
422 adj_matrix: &[f64],
423 b_matrix: &[f64],
424 theta: &[f64],
425 tau: &[f64],
426 n: usize,
427 k: usize,
428 ) -> f64 {
429 let mut ll = 0.0;
430 for i in 0..n {
431 for j in (i + 1)..n {
432 let a_ij = adj_matrix[i * n + j];
433 for r in 0..k {
434 let tau_ir = tau[i * k + r];
435 if tau_ir < 1e-15 {
436 continue;
437 }
438 for s in 0..k {
439 let tau_js = tau[j * k + s];
440 if tau_js < 1e-15 {
441 continue;
442 }
443 let mut p = b_matrix[r * k + s];
444 if self.config.degree_corrected {
445 p *= theta[i] * theta[j];
446 }
447 p = p.clamp(1e-15, 1.0 - 1e-15);
448
449 if a_ij > 0.0 {
450 ll += tau_ir * tau_js * a_ij * p.ln();
451 } else {
452 ll += tau_ir * tau_js * (1.0 - p).ln();
453 }
454 }
455 }
456 }
457 }
458 ll
459 }
460
461 fn compute_icl(&self, ll: f64, labels: &[usize], n: usize, k: usize) -> f64 {
466 let n_f = n as f64;
467 let k_f = k as f64;
468 let n_b_params = k_f * (k_f + 1.0) / 2.0;
470 let n_pairs = n_f * (n_f - 1.0) / 2.0;
472 let penalty =
473 n_b_params * n_pairs.max(1.0).ln() / 2.0 + (k_f - 1.0) * n_f.max(1.0).ln() / 2.0;
474
475 let mut block_sizes = vec![0usize; k];
479 for &l in labels {
480 if l < k {
481 block_sizes[l] += 1;
482 }
483 }
484 let entropy_correction: f64 = block_sizes
485 .iter()
486 .filter(|&&s| s > 0)
487 .map(|&s| {
488 let p = s as f64 / n_f;
489 -(s as f64) * p.ln()
490 })
491 .sum();
492
493 ll - penalty - entropy_correction
494 }
495
496 pub fn predict(
500 &self,
501 graph: &AdjacencyGraph,
502 b_matrix: &[f64],
503 k: usize,
504 ) -> Result<Vec<usize>> {
505 let n = graph.n_nodes;
506 if n == 0 {
507 return Err(ClusteringError::InvalidInput(
508 "Graph has zero nodes".to_string(),
509 ));
510 }
511 if b_matrix.len() != k * k {
512 return Err(ClusteringError::InvalidInput(
513 "B matrix size mismatch".to_string(),
514 ));
515 }
516
517 let mut adj_matrix = vec![0.0_f64; n * n];
519 for i in 0..n {
520 for &(j, w) in &graph.adjacency[i] {
521 adj_matrix[i * n + j] = w;
522 }
523 }
524
525 let uniform = 1.0 / k as f64;
527 let mut tau = vec![uniform; n * k];
528 let theta = vec![1.0_f64; n];
529
530 for _iter in 0..self.config.max_iterations {
531 self.e_step(graph, &adj_matrix, b_matrix, &theta, &mut tau, n, k);
532 }
533
534 let mut labels = vec![0usize; n];
536 for i in 0..n {
537 let mut best_r = 0;
538 let mut best_val = f64::NEG_INFINITY;
539 for r in 0..k {
540 if tau[i * k + r] > best_val {
541 best_val = tau[i * k + r];
542 best_r = r;
543 }
544 }
545 labels[i] = best_r;
546 }
547
548 Ok(labels)
549 }
550
551 pub fn generate(
558 n: usize,
559 k: usize,
560 b_matrix: &[f64],
561 block_sizes: &[usize],
562 seed: u64,
563 ) -> Result<(AdjacencyGraph, Vec<usize>)> {
564 if b_matrix.len() != k * k {
565 return Err(ClusteringError::InvalidInput(
566 "B matrix size must be k*k".to_string(),
567 ));
568 }
569 if block_sizes.len() != k {
570 return Err(ClusteringError::InvalidInput(
571 "block_sizes length must equal k".to_string(),
572 ));
573 }
574 let total: usize = block_sizes.iter().sum();
575 if total != n {
576 return Err(ClusteringError::InvalidInput(format!(
577 "block_sizes sum ({}) must equal n ({})",
578 total, n
579 )));
580 }
581
582 let mut rng = Xorshift64::new(seed);
583
584 let mut labels = Vec::with_capacity(n);
586 for (block, &size) in block_sizes.iter().enumerate() {
587 for _ in 0..size {
588 labels.push(block);
589 }
590 }
591
592 let mut graph = AdjacencyGraph::new(n);
594 for i in 0..n {
595 for j in (i + 1)..n {
596 let r = labels[i];
597 let s = labels[j];
598 let p = b_matrix[r * k + s];
599 if rng.next_f64() < p {
600 let _ = graph.add_edge(i, j, 1.0);
601 }
602 }
603 }
604
605 Ok((graph, labels))
606 }
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    /// A fixed-K fit on a strongly assortative planted graph should recover the planted
    /// two-block partition, up to relabelling.
    #[test]
    fn test_sbm_generate_and_fit() {
        let (n, k) = (20, 2);
        let planted = [0.8, 0.05, 0.05, 0.8];
        let (graph, true_labels) =
            StochasticBlockModel::generate(n, k, &planted, &[10, 10], 123)
                .expect("generate should succeed");

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(2),
            max_iterations: 50,
            seed: 42,
            ..Default::default()
        });
        let result = sbm.fit(&graph).expect("fit should succeed");

        assert_eq!(result.community.num_communities, 2);
        assert_eq!(result.community.labels.len(), n);

        let accuracy = compute_accuracy(&true_labels, &result.community.labels, k);
        assert!(accuracy >= 0.7, "Accuracy {} is too low", accuracy);
    }

    /// Degree-corrected fits expose one strictly positive correction factor per node.
    #[test]
    fn test_sbm_degree_corrected() {
        let (n, k) = (20, 2);
        let planted = [0.7, 0.1, 0.1, 0.7];
        let (graph, _) = StochasticBlockModel::generate(n, k, &planted, &[10, 10], 456)
            .expect("generate should succeed");

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(2),
            degree_corrected: true,
            max_iterations: 30,
            seed: 789,
            ..Default::default()
        });
        let result = sbm.fit(&graph).expect("fit should succeed");

        assert!(result.degree_corrections.is_some());
        let corrections = result
            .degree_corrections
            .as_ref()
            .expect("should have degree corrections");
        assert_eq!(corrections.len(), n);
        assert!(corrections.iter().all(|&d| d > 0.0));
    }

    /// ICL-based selection on a clear two-block graph should not pick a large K.
    #[test]
    fn test_sbm_model_selection() {
        let (n, k) = (30, 2);
        let planted = [0.9, 0.05, 0.05, 0.9];
        let (graph, _) = StochasticBlockModel::generate(n, k, &planted, &[15, 15], 111)
            .expect("generate should succeed");

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: None,
            k_range: (2, 5),
            max_iterations: 30,
            seed: 222,
            ..Default::default()
        });
        let result = sbm.fit(&graph).expect("fit should succeed");

        assert!(
            (2..=3).contains(&result.k),
            "Selected K={} seems wrong",
            result.k
        );
    }

    /// Predicting with the true block matrix should beat chance on a planted graph.
    #[test]
    fn test_sbm_predict() {
        let (n, k) = (20, 2);
        let planted = [0.8, 0.05, 0.05, 0.8];
        let (graph, true_labels) =
            StochasticBlockModel::generate(n, k, &planted, &[10, 10], 333)
                .expect("generate should succeed");

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            max_iterations: 30,
            seed: 444,
            ..Default::default()
        });
        let predicted = sbm
            .predict(&graph, &planted, k)
            .expect("predict should succeed");

        assert_eq!(predicted.len(), n);
        let accuracy = compute_accuracy(&true_labels, &predicted, k);
        assert!(accuracy >= 0.6, "Predict accuracy {} is too low", accuracy);
    }

    /// Block sizes that do not sum to `n` are rejected.
    #[test]
    fn test_sbm_generate_invalid() {
        let outcome =
            StochasticBlockModel::generate(10, 2, &[0.5, 0.1, 0.1, 0.5], &[4, 4], 0);
        assert!(outcome.is_err());
    }

    /// Fitting a graph without nodes is an error.
    #[test]
    fn test_sbm_empty_graph() {
        let empty = AdjacencyGraph::new(0);
        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(2),
            ..Default::default()
        });
        assert!(sbm.fit(&empty).is_err());
    }

    /// A complete graph fitted with one block yields a single community.
    #[test]
    fn test_sbm_single_block() {
        let n = 10;
        let mut complete = AdjacencyGraph::new(n);
        for (i, j) in (0..n).flat_map(|i| (i + 1..n).map(move |j| (i, j))) {
            let _ = complete.add_edge(i, j, 1.0);
        }

        let sbm = StochasticBlockModel::new(StochasticBlockModelConfig {
            num_blocks: Some(1),
            max_iterations: 20,
            seed: 555,
            ..Default::default()
        });
        let result = sbm.fit(&complete).expect("fit should succeed");
        assert_eq!(result.community.num_communities, 1);
    }

    /// Fraction of nodes labelled correctly under the best relabelling of `pred_labels`.
    /// Predicted labels `>= k` are compared unmapped. An empty input scores 1.0.
    fn compute_accuracy(true_labels: &[usize], pred_labels: &[usize], k: usize) -> f64 {
        let n = true_labels.len();
        if n == 0 {
            return 1.0;
        }

        let best_correct = generate_permutations(k)
            .iter()
            .map(|perm| {
                true_labels
                    .iter()
                    .zip(pred_labels)
                    .filter(|&(&truth, &pred)| perm.get(pred).copied().unwrap_or(pred) == truth)
                    .count()
            })
            .max()
            .unwrap_or(0);
        best_correct as f64 / n as f64
    }

    /// All permutations of `0..k`. For `k == 0`, the result is the single empty
    /// permutation.
    fn generate_permutations(k: usize) -> Vec<Vec<usize>> {
        // Start from the empty arrangement and insert each next symbol at every position.
        (0..k).fold(vec![Vec::<usize>::new()], |acc: Vec<Vec<usize>>, symbol| {
            acc.iter()
                .flat_map(|perm| {
                    (0..=perm.len()).map(move |pos| {
                        let mut extended = perm.clone();
                        extended.insert(pos, symbol);
                        extended
                    })
                })
                .collect::<Vec<_>>()
        })
    }
}