use crate::DType;
use numr::error::{Error, Result};
use numr::ops::{BinaryOps, ReduceOps, ScalarOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::sparse::{SparseOps, SparseTensor};
use numr::tensor::Tensor;
use crate::graph::traits::types::{EigCentralityOptions, GraphData};
/// Computes eigenvector centrality by power iteration on the graph's adjacency matrix.
///
/// Starting from a uniform vector with unit L2 norm, each step computes
/// `(A + I)·x = A·x + x` and renormalizes it to unit L2 norm. Iteration stops
/// once the L2 distance between successive iterates drops below `options.tol`,
/// or after `options.max_iter` steps.
///
/// Adding the identity leaves the eigenvectors of `A` unchanged but shifts every
/// eigenvalue by one. As a result, the dominant eigenvalue `λ` no longer ties in
/// magnitude with `-λ` on bipartite graphs (stars, trees, even cycles). On those
/// graphs, plain power iteration on `A` oscillates between two vectors and never
/// converges.
///
/// # Arguments
/// * `client` - runtime client providing the sparse and dense tensor operations.
/// * `graph` - graph whose `adjacency` must be stored in CSR format.
/// * `options` - iteration limit (`max_iter`) and convergence tolerance (`tol`).
///
/// # Returns
/// A tensor of shape `[graph.num_nodes]` holding the element-wise absolute value
/// of the final iterate. Returns an empty `[0]` tensor when the graph has no nodes.
///
/// # Errors
/// Returns [`Error::InvalidArgument`] if the adjacency is not in CSR format.
/// Also propagates any error raised by the client's tensor operations.
pub fn eigenvector_centrality_impl<R, C>(
    client: &C,
    graph: &GraphData<R>,
    options: &EigCentralityOptions,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + SparseOps<R> + BinaryOps<R> + ReduceOps<R> + ScalarOps<R> + UnaryOps<R>,
{
    // Below this L2 norm an iterate is treated as the zero vector, which cannot
    // be normalized.
    const ZERO_NORM_EPS: f64 = 1e-15;

    // Resolve the device once. Both the empty and non-empty paths need it, and
    // both reject non-CSR adjacency in the same way.
    let device = match &graph.adjacency {
        SparseTensor::Csr(csr) => csr.values().device().clone(),
        _ => {
            return Err(Error::InvalidArgument {
                arg: "graph",
                reason: "Expected CSR format".to_string(),
            });
        }
    };

    let n = graph.num_nodes;
    if n == 0 {
        return Ok(Tensor::<R>::from_slice(&[] as &[f64], &[0], &device));
    }

    // Uniform start vector with unit L2 norm.
    let init_val = 1.0 / (n as f64).sqrt();
    let mut x = Tensor::<R>::from_slice(&vec![init_val; n], &[n], &device);

    for _ in 0..options.max_iter {
        // Shifted power step: (A + I)·x = A·x + x (see doc comment for why).
        let ax = client.spmv(&graph.adjacency, &x)?;
        let x_new = client.add(&ax, &x)?;

        let norm = l2_norm(client, &x_new)?;
        if norm < ZERO_NORM_EPS {
            // Degenerate iterate. Fall through so the result still receives the
            // same abs() post-processing as the normal exit.
            break;
        }

        let x_normalized = client.mul_scalar(&x_new, 1.0 / norm)?;
        let delta = client.sub(&x_normalized, &x)?;
        let delta_norm = l2_norm(client, &delta)?;
        x = x_normalized;
        if delta_norm < options.tol {
            break;
        }
    }

    // Eigenvectors are only defined up to sign, so report magnitudes and keep
    // the scores non-negative.
    client.abs(&x)
}

/// Returns the Euclidean (L2) norm of a 1-D tensor as a host `f64`.
fn l2_norm<R, C>(client: &C, v: &Tensor<R>) -> Result<f64>
where
    R: Runtime<DType = DType>,
    C: BinaryOps<R> + ReduceOps<R>,
{
    let squared = client.mul(v, v)?;
    let total = client.sum(&squared, &[0], false)?;
    let sum_sq: f64 = total.to_vec()[0];
    Ok(sum_sq.sqrt())
}