use crate::DType;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use numr::error::Result;
use numr::runtime::{Runtime, RuntimeClient};
use numr::sparse::SparseTensor;
use numr::tensor::Tensor;
use crate::graph::traits::types::{AllPairsResult, GraphData};
use super::helpers::extract_csr_arrays;
/// All-pairs shortest paths using Johnson's algorithm.
///
/// The algorithm runs in three phases:
/// 1. Bellman-Ford from a virtual source that has a zero-weight edge to every
///    node computes a potential `h[v]` for each node.
/// 2. Every edge `(u, v, w)` is reweighted to `w + h[u] - h[v]`. This is
///    non-negative when the graph has no negative cycle.
/// 3. Dijkstra runs from every node on the reweighted graph. Each distance is
///    mapped back with `d(s, v) = d'(s, v) - h[s] + h[v]`.
///
/// # Returns
/// An [`AllPairsResult`] whose `distances` and `predecessors` are `n x n`
/// tensors created on the device of the adjacency values. Entry `[s, v]`
/// holds the shortest distance from `s` to `v`, or `f64::INFINITY` when `v`
/// is unreachable. The matching predecessor entry holds the node preceding
/// `v` on that path, or `-1` for `v == s` and for unreachable nodes.
///
/// `_client` is not used by this implementation.
///
/// # Errors
/// Propagates any error returned by `extract_csr_arrays`.
///
/// NOTE(review): negative cycles are not detected. If the graph contains
/// one, the returned distances are not meaningful. Consider returning an
/// error once a suitable error variant is chosen.
pub fn johnson_impl<R, C>(_client: &C, graph: &GraphData<R>) -> Result<AllPairsResult<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    let n = graph.num_nodes;
    let (row_ptrs, col_indices, values, _) = extract_csr_arrays(graph)?;
    let device = match &graph.adjacency {
        SparseTensor::Csr(csr) => csr.values().device().clone(),
        // NOTE(review): assumes extract_csr_arrays has already rejected
        // non-CSR adjacency; confirm against that helper.
        _ => unreachable!(),
    };

    // Phase 1: potentials.
    // Starting every h[v] at 0.0 is exactly the state after the first
    // Bellman-Ford round from a virtual source with a zero-weight edge to
    // every node. As a result, every node gets a finite potential, not only
    // the nodes reachable from node 0. The augmented graph has n + 1 nodes,
    // so the remaining n - 1 rounds suffice. The range `1..n` is empty for
    // n <= 1, which avoids the `n - 1` underflow on an empty graph.
    let mut h = vec![0.0_f64; n];
    for _ in 1..n {
        let mut updated = false;
        for u in 0..n {
            let start = row_ptrs[u] as usize;
            let end = row_ptrs[u + 1] as usize;
            for i in start..end {
                let v = col_indices[i] as usize;
                let candidate = h[u] + values[i];
                if candidate < h[v] {
                    h[v] = candidate;
                    updated = true;
                }
            }
        }
        if !updated {
            break;
        }
    }

    // Phase 2: reweight edges.
    // In exact arithmetic w + h[u] - h[v] >= 0. Clamping keeps
    // floating-point rounding from producing tiny negative edges. Such edges
    // would break the bit-pattern heap ordering below and could turn a
    // zero-weight cycle into a slightly negative one.
    let mut new_weights = values.clone();
    for u in 0..n {
        let start = row_ptrs[u] as usize;
        let end = row_ptrs[u + 1] as usize;
        for i in start..end {
            let v = col_indices[i] as usize;
            new_weights[i] = (new_weights[i] + h[u] - h[v]).max(0.0);
        }
    }

    // Phase 3: Dijkstra from every source, writing directly into output rows.
    let mut all_distances = vec![f64::INFINITY; n * n];
    let mut all_predecessors = vec![-1i64; n * n];
    // Keys are `f64::to_bits` of non-negative distances. For non-negative
    // floats, the unsigned bit patterns order the same way as the values.
    // The heap is drained by every run, so one allocation serves all sources.
    let mut heap: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::new();
    for source in 0..n {
        let row = source * n..(source + 1) * n;
        let dist = &mut all_distances[row.clone()];
        let pred = &mut all_predecessors[row];

        dist[source] = 0.0;
        heap.push(Reverse((0.0_f64.to_bits(), source)));
        while let Some(Reverse((d_bits, u))) = heap.pop() {
            // Skip stale entries superseded by a shorter distance.
            if f64::from_bits(d_bits) > dist[u] {
                continue;
            }
            let start = row_ptrs[u] as usize;
            let end = row_ptrs[u + 1] as usize;
            for i in start..end {
                let v = col_indices[i] as usize;
                let new_dist = dist[u] + new_weights[i];
                if new_dist < dist[v] {
                    dist[v] = new_dist;
                    pred[v] = u as i64;
                    heap.push(Reverse((new_dist.to_bits(), v)));
                }
            }
        }

        // Undo the reweighting. Unreachable nodes stay at INFINITY.
        for (v, d) in dist.iter_mut().enumerate() {
            if !d.is_infinite() {
                *d = *d + h[v] - h[source];
            }
        }
    }

    let distances = Tensor::<R>::from_slice(&all_distances, &[n, n], &device);
    let predecessors = Tensor::<R>::from_slice(&all_predecessors, &[n, n], &device);
    Ok(AllPairsResult {
        distances,
        predecessors,
    })
}