use crate::DType;
use numr::error::Result;
use numr::ops::{BinaryOps, CompareOps, ConditionalOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::sparse::SparseTensor;
use numr::tensor::Tensor;
use crate::graph::traits::types::{AllPairsResult, GraphData};
use super::helpers::extract_csr_arrays;
/// Computes all-pairs shortest paths with the Floyd–Warshall algorithm.
///
/// The CSR adjacency of `graph` is densified into an `n × n` distance matrix. The matrix
/// is then relaxed once per intermediate node `k` using whole-matrix tensor operations,
/// `d = min(d, d[:, k] + d[k, :])`. That is `n` rounds of `O(n²)` element-wise work.
///
/// # Returns
///
/// An [`AllPairsResult`] with two `n × n` matrices:
/// - `distances[i][j]` (f64) is the cheapest path cost from `i` to `j`. It is `0.0` on
///   the diagonal and `f64::INFINITY` when `j` is unreachable from `i`.
/// - `predecessors[i][j]` (i64) is the node preceding `j` on that cheapest path. It is
///   `-1` when `j` is unreachable from `i`, and also on the diagonal unless a negative
///   cycle passes through `i`.
///
/// Negative edge weights are accepted, but negative cycles are not detected. If one
/// exists, the diagonal entries of `distances` for the nodes on it become negative.
///
/// # Errors
///
/// Propagates any error from CSR extraction or from the tensor operations on `client`.
///
/// # Panics
///
/// Panics if `graph.adjacency` is not CSR. This presumably cannot happen once
/// `extract_csr_arrays` has succeeded — TODO confirm.
pub fn floyd_warshall_impl<R, C>(client: &C, graph: &GraphData<R>) -> Result<AllPairsResult<R>>
where
    R: Runtime<DType = DType>,
    C: BinaryOps<R> + CompareOps<R> + ConditionalOps<R> + RuntimeClient<R>,
{
    let n = graph.num_nodes;
    let (row_ptrs, col_indices, values, _) = extract_csr_arrays(graph)?;
    let device = match &graph.adjacency {
        SparseTensor::Csr(csr) => csr.values().device().clone(),
        _ => unreachable!("floyd_warshall_impl: adjacency must be CSR"),
    };

    // Dense initial state: 0 on the diagonal, the cheapest direct edge elsewhere,
    // and INFINITY / -1 for pairs with no direct edge.
    let mut distances = vec![f64::INFINITY; n * n];
    let mut predecessors = vec![-1i64; n * n];
    for i in 0..n {
        let row = i * n;
        distances[row + i] = 0.0;
        let start = row_ptrs[i] as usize;
        let end = row_ptrs[i + 1] as usize;
        for (&col, &weight) in col_indices[start..end].iter().zip(&values[start..end]) {
            let j = col as usize;
            // Keep only the cheapest edge. This stops a positive self-loop from
            // overwriting the zero diagonal, and it coalesces duplicate (i, j)
            // entries in a non-canonical CSR.
            if weight < distances[row + j] {
                distances[row + j] = weight;
                predecessors[row + j] = i as i64;
            }
        }
    }

    let mut d = Tensor::<R>::from_slice(&distances, &[n, n], &device);
    let mut pred = Tensor::<R>::from_slice(&predecessors, &[n, n], &device);

    for k in 0..n {
        // Column k ([n, 1]) plus row k ([1, n]) broadcasts to the [n, n] cost of
        // routing every (i, j) pair through k.
        let d_ik = d.narrow(1, k, 1)?;
        let d_kj = d.narrow(0, k, 1)?;
        let through_k = client.add(&d_ik, &d_kj)?;

        // The mask must be taken against the pre-update distances; computing it
        // before `d` is reassigned avoids cloning `d`.
        let improved = client.lt(&through_k, &d)?;

        // Where routing via k is strictly better, j's predecessor becomes its
        // predecessor on the k -> j path (row k of `pred`, broadcast over i).
        let pred_kj = pred.narrow(0, k, 1)?;
        pred = client.where_cond(&improved, &pred_kj, &pred)?;
        d = client.minimum(&d, &through_k)?;
    }

    // `d` and `pred` already live on `device`. Return them directly instead of
    // copying them to the host and back.
    Ok(AllPairsResult {
        distances: d,
        predecessors: pred,
    })
}