use crate::DType;
use numr::error::Result;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::graph::traits::types::{ComponentResult, GraphData};
use super::helpers::extract_csr_arrays;
/// Labels the strongly connected components (SCCs) of a directed graph using
/// Tarjan's algorithm.
///
/// The adjacency is read through `extract_csr_arrays`. Every column index `v`
/// stored in row `u` is treated as a directed edge `u -> v`; edge values are
/// ignored, and `_client` is not used.
///
/// The depth-first search is iterative (explicit frame stack), so very deep
/// graphs such as long paths cannot overflow the thread's call stack.
///
/// # Returns
///
/// A [`ComponentResult`] containing:
/// - `labels`: an `n`-element tensor, on the device of the CSR values, that maps
///   each node to its component id in `0..num_components`.
/// - `num_components`: the number of SCCs.
///
/// Ids follow Tarjan's completion order. If component `A` can reach a different
/// component `B`, then `B` receives the smaller id.
///
/// # Errors
///
/// Propagates any error returned by `extract_csr_arrays`.
///
/// # Panics
///
/// - If the row-pointer array has fewer than `n + 1` entries.
/// - If a column index is `>= n`.
/// - Via `unreachable!()` if the adjacency is not a CSR tensor.
pub fn tarjan_impl<R, C>(_client: &C, graph: &GraphData<R>) -> Result<ComponentResult<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    let (row_ptrs, col_indices, _values, n) = extract_csr_arrays(graph)?;
    let device = match &graph.adjacency {
        numr::sparse::SparseTensor::Csr(csr) => csr.values().device().clone(),
        // NOTE(review): assumes extract_csr_arrays never succeeds for a non-CSR
        // adjacency. Confirm this, otherwise this arm panics.
        _ => unreachable!(),
    };

    // Normalise the CSR structure to `usize` once, so the traversal can index it
    // directly instead of materialising a per-node `Vec` adjacency list.
    // For n == 0 no row pointer is read, matching the previous behaviour.
    let offsets: Vec<usize> = if n == 0 {
        Vec::new()
    } else {
        (0..=n).map(|i| row_ptrs[i] as usize).collect()
    };
    let targets: Vec<usize> = col_indices.iter().map(|&c| c as usize).collect();

    let (labels, num_components) = tarjan_scc_labels(n, &offsets, &targets);

    let labels_tensor = Tensor::<R>::from_slice(&labels, &[n], &device);
    Ok(ComponentResult {
        labels: labels_tensor,
        num_components,
    })
}

/// Runs Tarjan's SCC algorithm on a CSR graph without recursion and returns
/// `(labels, num_components)`. Labels are component ids stored as `i64`.
///
/// `row_ptrs` must have at least `n + 1` entries when `n > 0`. Row `u`'s
/// targets are `col_indices[row_ptrs[u]..row_ptrs[u + 1]]`, clamped to
/// `col_indices.len()`.
fn tarjan_scc_labels(n: usize, row_ptrs: &[usize], col_indices: &[usize]) -> (Vec<i64>, usize) {
    const UNVISITED: usize = usize::MAX;

    // Exclusive end of row `u`'s edges. The end is clamped to the column
    // array length, matching the earlier `take(end).skip(start)` iteration.
    let edge_end = |u: usize| row_ptrs[u + 1].min(col_indices.len());

    let mut next_index = 0usize;
    let mut index = vec![UNVISITED; n];
    let mut lowlink = vec![0usize; n];
    let mut on_stack = vec![false; n];
    // Every node is assigned a label before return; -1 is only a placeholder.
    let mut labels = vec![-1i64; n];
    let mut num_components = 0usize;

    // Tarjan's stack of nodes whose component has not been emitted yet.
    let mut scc_stack: Vec<usize> = Vec::new();
    // Explicit DFS call stack: (node, position of the next outgoing edge to scan).
    let mut call_stack: Vec<(usize, usize)> = Vec::new();

    for root in 0..n {
        if index[root] != UNVISITED {
            continue;
        }

        // Discover the root.
        index[root] = next_index;
        lowlink[root] = next_index;
        next_index += 1;
        scc_stack.push(root);
        on_stack[root] = true;
        call_stack.push((root, row_ptrs[root]));

        while let Some(frame) = call_stack.last_mut() {
            let u = frame.0;

            if frame.1 < edge_end(u) {
                let v = col_indices[frame.1];
                frame.1 += 1;

                if index[v] == UNVISITED {
                    // Tree edge: "recurse" by pushing a new frame for v.
                    index[v] = next_index;
                    lowlink[v] = next_index;
                    next_index += 1;
                    scc_stack.push(v);
                    on_stack[v] = true;
                    call_stack.push((v, row_ptrs[v]));
                } else if on_stack[v] {
                    // Edge into the component currently being built.
                    lowlink[u] = lowlink[u].min(index[v]);
                }
                continue;
            }

            // All of u's edges are scanned: "return" from u.
            call_stack.pop();

            if lowlink[u] == index[u] {
                // u is the root of an SCC. Pop its members off Tarjan's stack.
                let comp_id = num_components as i64;
                while let Some(w) = scc_stack.pop() {
                    on_stack[w] = false;
                    labels[w] = comp_id;
                    if w == u {
                        break;
                    }
                }
                num_components += 1;
            }

            // Propagate u's lowlink to its DFS parent, as the recursive version
            // did after the child call returned.
            if let Some(&(parent, _)) = call_stack.last() {
                lowlink[parent] = lowlink[parent].min(lowlink[u]);
            }
        }
    }

    (labels, num_components)
}