use crate::DType;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use numr::error::Result;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::graph::traits::types::{GraphData, ShortestPathResult};
use super::helpers::{extract_csr_arrays, validate_node};
/// Single-source shortest paths from `source` using Dijkstra's algorithm on the
/// graph's CSR adjacency arrays.
///
/// Uses a binary min-heap with lazy deletion. A node is pushed again every time
/// its tentative distance strictly improves. Heap entries whose distance is larger
/// than the node's current best are discarded when popped.
///
/// # Returns
/// A [`ShortestPathResult`] with two length-`n` tensors, both created on the device
/// of the CSR values:
/// - `distances`: the best distance found from `source` (`f64::INFINITY` if never
///   reached).
/// - `predecessors`: the node that last improved each entry (`-1` where no
///   improvement was ever recorded, e.g. the source and unreachable nodes).
///
/// # Errors
/// Propagates any error from `validate_node` (checking `source`) and from
/// `extract_csr_arrays`.
///
/// # Panics
/// Hits `unreachable!()` if `graph.adjacency` is not the `Csr` variant. Presumably
/// `extract_csr_arrays` already returned an error in that case (TODO confirm).
///
/// NOTE(review): heap keys are `f64::to_bits`, which sorts like the numeric value
/// only for non-negative floats. Edge weights are assumed to be non-negative —
/// confirm with callers.
pub fn dijkstra_impl<R, C>(
    _client: &C,
    graph: &GraphData<R>,
    source: usize,
) -> Result<ShortestPathResult<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    validate_node(source, graph.num_nodes, "dijkstra source")?;
    let (row_ptrs, col_indices, values, n) = extract_csr_arrays(graph)?;

    // Result tensors are placed on the same device as the CSR values.
    let device = if let numr::sparse::SparseTensor::Csr(csr) = &graph.adjacency {
        csr.values().device().clone()
    } else {
        unreachable!()
    };

    let mut best = vec![f64::INFINITY; n];
    let mut parent = vec![-1i64; n];
    best[source] = 0.0;

    // `Reverse` turns the max-heap into a min-heap. Ties on distance fall back to
    // the smaller node index.
    let mut frontier: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::new();
    frontier.push(Reverse((0.0_f64.to_bits(), source)));

    while let Some(Reverse((key, node))) = frontier.pop() {
        // Stale entry: a shorter path to `node` was already settled.
        if f64::from_bits(key) > best[node] {
            continue;
        }

        let lo = row_ptrs[node] as usize;
        let hi = row_ptrs[node + 1] as usize;
        let edges = (lo..hi).map(|e| (col_indices[e] as usize, values[e]));
        for (target, weight) in edges {
            // Re-read best[node] for each edge: a self-loop may have lowered it.
            let candidate = best[node] + weight;
            if candidate < best[target] {
                best[target] = candidate;
                parent[target] = node as i64;
                frontier.push(Reverse((candidate.to_bits(), target)));
            }
        }
    }

    Ok(ShortestPathResult {
        distances: Tensor::<R>::from_slice(&best, &[n], &device),
        predecessors: Tensor::<R>::from_slice(&parent, &[n], &device),
    })
}