use crate::DType;
use numr::error::{Error, Result};
use numr::ops::ScalarOps;
use numr::runtime::{Runtime, RuntimeClient};
use numr::sparse::{SparseOps, SparseTensor};
use numr::tensor::Tensor;
use crate::graph::traits::types::GraphData;
/// Compute the normalized degree centrality of every node in `graph`.
///
/// Each node's degree is obtained by multiplying the adjacency matrix with a
/// vector of ones (`client.spmv(adjacency, ones)`). This presumably yields the
/// row sums of the stored adjacency values. If those values are edge weights,
/// the result is a weighted degree rather than an edge count.
/// NOTE(review): confirm this is intended for weighted graphs.
///
/// Every degree is then scaled by `1 / (n - 1)`, where `n = graph.num_nodes`.
///
/// For graphs with zero or one node, the normalizer would be zero (or underflow
/// for `n == 0`). In that case a tensor of `n` zeros is returned instead.
///
/// # Errors
///
/// Returns [`Error::InvalidArgument`] if `graph.adjacency` is not stored in CSR
/// format. Also propagates any error returned by `spmv` or `mul_scalar`.
pub fn degree_centrality_impl<R, C>(client: &C, graph: &GraphData<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + SparseOps<R> + ScalarOps<R>,
{
    let n = graph.num_nodes;

    // Only CSR adjacency is accepted. Every tensor created below is placed on
    // the same device as the adjacency values.
    let device = match &graph.adjacency {
        SparseTensor::Csr(csr) => csr.values().device().clone(),
        _ => {
            return Err(Error::InvalidArgument {
                arg: "graph",
                reason: "Expected CSR format".to_string(),
            });
        }
    };

    // `n - 1` would be zero (n == 1) or underflow (n == 0), so return zeros.
    if n <= 1 {
        return Ok(Tensor::<R>::from_slice(&vec![0.0f64; n], &[n], &device));
    }

    let ones = Tensor::<R>::from_slice(&vec![1.0f64; n], &[n], &device);
    let degrees = client.spmv(&graph.adjacency, &ones)?;
    let scale = 1.0 / (n - 1) as f64;
    client.mul_scalar(&degrees, scale)
}