use std::collections::VecDeque;
use super::params::AlgoParams;
use super::result::AlgoResultBatch;
use crate::engine::graph::algo::GraphAlgorithm;
use crate::engine::graph::csr::CsrIndex;
/// Computes the diameter and radius of the graph in `csr`.
///
/// `params.mode` selects the strategy after being upper-cased:
/// `"EXACT"` runs one BFS per node, while any other value (or no value at
/// all) uses the double-sweep approximation. A graph with no nodes yields an
/// empty batch; otherwise one diameter/radius entry is pushed.
pub fn run(csr: &CsrIndex, params: &AlgoParams) -> AlgoResultBatch {
    let node_count = csr.node_count();
    let mut batch = AlgoResultBatch::new(GraphAlgorithm::Diameter);
    if node_count == 0 {
        return batch;
    }

    let requested = params.mode.as_deref().unwrap_or("APPROXIMATE");
    let (diameter, radius) = match requested.to_uppercase().as_str() {
        "EXACT" => compute_exact(csr, node_count),
        // Unknown modes deliberately fall back to the cheap estimate.
        _ => compute_approximate(csr, node_count),
    };

    batch.push_diameter(diameter, radius);
    batch
}
/// Exact `(diameter, radius)` computed from one BFS per node.
///
/// Each node's eccentricity is measured only over the nodes its BFS reaches.
/// The diameter is the largest eccentricity seen. The radius is the smallest
/// *non-zero* eccentricity, so nodes that reach nothing else do not drag the
/// radius down to 0. If every eccentricity is 0 the result is `(0, 0)`.
fn compute_exact(csr: &CsrIndex, n: usize) -> (i64, i64) {
    let (smallest, largest) = (0..n)
        .map(|node| bfs_eccentricity(csr, node as u32, n))
        .filter(|&ecc| ecc > 0)
        .fold((i64::MAX, 0i64), |(lo, hi), ecc| (lo.min(ecc), hi.max(ecc)));

    // `i64::MAX` survives only when no node had a positive eccentricity.
    let radius = if smallest == i64::MAX { 0 } else { smallest };
    (largest, radius)
}
/// Estimates `(diameter, radius)` with the double-sweep BFS heuristic.
///
/// For every connected component (following `undirected_neighbors`), a BFS
/// from one member finds a far node `f`, and a second BFS from `f` yields the
/// eccentricity of `f`. That value is a lower bound on the component's
/// diameter and is exact on trees. The largest value over all components is
/// returned, matching `compute_exact`, which also maximises across
/// components; sweeping only a single component could miss a larger one
/// elsewhere in the graph. Nodes with neither in- nor out-edges are skipped.
///
/// The radius is reported as `ceil(diameter / 2)`.
///
/// Each component is swept twice, and the BFS buffers are shared across
/// sweeps rather than reallocated per component, so total work stays linear
/// in the graph size (assuming `undirected_neighbors` costs O(degree)).
fn compute_approximate(csr: &CsrIndex, n: usize) -> (i64, i64) {
    let mut scratch = BfsScratch::new(n);
    // Marks nodes already swept as part of some earlier component.
    let mut covered = vec![false; n];
    let mut best = 0u32;

    for seed in 0..n {
        let s = seed as u32;
        if covered[seed] || (csr.out_degree(s) == 0 && csr.in_degree(s) == 0) {
            continue;
        }
        let (far, _) = scratch.sweep(csr, s);
        // The first sweep reached exactly this seed's component.
        for &v in &scratch.touched {
            covered[v as usize] = true;
        }
        let (_, component_estimate) = scratch.sweep(csr, far);
        best = best.max(component_estimate);
    }

    let diameter = i64::from(best);
    let radius = (diameter + 1) / 2;
    (diameter, radius)
}

/// Eccentricity of `source`: the greatest BFS distance from it to any node it
/// reaches via `undirected_neighbors`. Returns 0 when it reaches no other node.
fn bfs_eccentricity(csr: &CsrIndex, source: u32, n: usize) -> i64 {
    let (_, max_dist) = bfs_farthest(csr, source, n);
    i64::from(max_dist)
}

/// BFS from `source` over a graph of `n` nodes. Returns the first node reached
/// at the greatest distance together with that distance, or `(source, 0)` when
/// `source` reaches no other node.
fn bfs_farthest(csr: &CsrIndex, source: u32, n: usize) -> (u32, u32) {
    BfsScratch::new(n).sweep(csr, source)
}

/// Reusable buffers for repeated BFS sweeps over the same graph.
struct BfsScratch {
    /// Distance from the current sweep's source; `u32::MAX` means unreached.
    /// Restored to all-`u32::MAX` at the end of every sweep.
    dist: Vec<u32>,
    /// Nodes reached by the most recent sweep, in BFS order (source first).
    touched: Vec<u32>,
    /// BFS frontier; always drained by the time a sweep returns.
    queue: VecDeque<u32>,
}

impl BfsScratch {
    fn new(n: usize) -> Self {
        Self {
            dist: vec![u32::MAX; n],
            touched: Vec::new(),
            queue: VecDeque::new(),
        }
    }

    /// BFS from `source` following `undirected_neighbors`. Returns the first
    /// node reached at the greatest distance and that distance. Afterwards,
    /// `touched` lists every node reached, and `dist` is reset so the next
    /// sweep can reuse it.
    fn sweep(&mut self, csr: &CsrIndex, source: u32) -> (u32, u32) {
        self.touched.clear();
        self.dist[source as usize] = 0;
        self.touched.push(source);
        self.queue.push_back(source);

        let mut farthest = source;
        let mut max_dist = 0u32;
        while let Some(v) = self.queue.pop_front() {
            let next = self.dist[v as usize] + 1;
            for neighbor in undirected_neighbors(csr, v) {
                let slot = &mut self.dist[neighbor as usize];
                if *slot == u32::MAX {
                    *slot = next;
                    // Strict `>` keeps the first node found at a new depth.
                    if next > max_dist {
                        max_dist = next;
                        farthest = neighbor;
                    }
                    self.touched.push(neighbor);
                    self.queue.push_back(neighbor);
                }
            }
        }

        // Reset only what this sweep wrote: O(component), not O(n).
        for &v in &self.touched {
            self.dist[v as usize] = u32::MAX;
        }
        (farthest, max_dist)
    }
}
use super::util::undirected_neighbors;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a compacted graph whose edges all carry the label `"L"`.
    fn graph(edges: &[(&str, &str)]) -> CsrIndex {
        let mut csr = CsrIndex::new();
        for (src, dst) in edges {
            csr.add_edge(src, "L", dst);
        }
        csr.compact();
        csr
    }

    fn exact() -> AlgoParams {
        AlgoParams {
            mode: Some("EXACT".into()),
            ..Default::default()
        }
    }

    /// Runs the algorithm and decodes the first result row.
    fn first_row(csr: &CsrIndex, params: &AlgoParams) -> serde_json::Value {
        let json = run(csr, params).to_json().unwrap();
        let rows: Vec<serde_json::Value> = serde_json::from_slice(&json).unwrap();
        rows.into_iter().next().expect("diameter result has no rows")
    }

    fn diameter_and_radius(row: &serde_json::Value) -> (i64, i64) {
        (
            row["diameter"].as_i64().unwrap(),
            row["radius"].as_i64().unwrap(),
        )
    }

    #[test]
    fn diameter_path() {
        let csr = graph(&[("a", "b"), ("b", "c"), ("c", "d")]);
        let row = first_row(&csr, &exact());
        assert_eq!(diameter_and_radius(&row), (3, 2));
    }

    #[test]
    fn diameter_triangle() {
        // Every vertex is adjacent to the other two, so exact mode must
        // report a diameter and radius of exactly 1.
        let csr = graph(&[("a", "b"), ("b", "c"), ("c", "a")]);
        let row = first_row(&csr, &exact());
        assert_eq!(diameter_and_radius(&row), (1, 1));
    }

    #[test]
    fn diameter_approximate() {
        let csr = graph(&[("a", "b"), ("b", "c"), ("c", "d"), ("d", "e")]);
        let row = first_row(&csr, &AlgoParams::default());
        // The double sweep is exact on a path; radius is ceil(4 / 2).
        assert_eq!(diameter_and_radius(&row), (4, 2));
    }

    #[test]
    fn diameter_empty() {
        let csr = CsrIndex::new();
        assert!(run(&csr, &AlgoParams::default()).is_empty());
    }

    #[test]
    fn diameter_single_node() {
        let mut csr = CsrIndex::new();
        csr.add_node("solo");
        csr.compact();
        let row = first_row(&csr, &exact());
        assert_eq!(diameter_and_radius(&row), (0, 0));
    }
}