use crate::error::Result;
/// Level-set schedule for a sparse triangular solve.
///
/// Each row (or column, for the CSC builders) is assigned a level equal to one
/// more than the deepest level among the rows it depends on. Rows that share a
/// level therefore do not depend on each other, and levels are meant to be
/// processed in increasing order.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LevelSchedule {
    /// Level of every row, indexed by row.
    pub level_of_row: Vec<usize>,
    /// Rows grouped by level, indexed by level; the builders in this module
    /// list rows in ascending order within each level.
    pub rows_per_level: Vec<Vec<usize>>,
    /// Number of levels (the builders set this to `rows_per_level.len()`).
    pub num_levels: usize,
    /// Number of rows in the largest level, or 0 for an empty schedule.
    pub max_parallelism: usize,
}
/// Builds the level schedule for forward substitution on a lower-triangular
/// matrix stored in CSR form (`row_ptrs` / `col_indices`).
///
/// Row `i` depends on every row `j < i` that appears as a column index in row
/// `i`. Its level is one more than the deepest such dependency, or 0 when it has
/// none. Diagonal and above-diagonal entries are ignored.
///
/// `row_ptrs` is indexed up to position `n`, so it must hold at least `n + 1`
/// offsets into `col_indices`; otherwise indexing panics. Currently this
/// function always returns `Ok`.
pub fn compute_levels_lower(
    n: usize,
    row_ptrs: &[i64],
    col_indices: &[i64],
) -> Result<LevelSchedule> {
    // Rows are visited in ascending order, so when row `row` is handled every
    // dependency `dep < row` already has its level pushed onto `levels`.
    let level_of_row = (0..n).fold(Vec::with_capacity(n), |mut levels: Vec<usize>, row| {
        let (lo, hi) = (row_ptrs[row] as usize, row_ptrs[row + 1] as usize);
        let level = (lo..hi)
            .map(|k| col_indices[k] as usize)
            .filter(|&dep| dep < row)
            .map(|dep| levels[dep] + 1)
            .max()
            .unwrap_or(0);
        levels.push(level);
        levels
    });

    let num_levels = level_of_row.iter().max().map_or(0, |&deepest| deepest + 1);
    let mut rows_per_level: Vec<Vec<usize>> = vec![Vec::new(); num_levels];
    level_of_row
        .iter()
        .enumerate()
        .for_each(|(row, &level)| rows_per_level[level].push(row));
    let max_parallelism = rows_per_level.iter().map(Vec::len).max().unwrap_or(0);

    Ok(LevelSchedule {
        num_levels,
        max_parallelism,
        level_of_row,
        rows_per_level,
    })
}
/// Builds the level schedule for backward substitution on an upper-triangular
/// matrix stored in CSR form (`row_ptrs` / `col_indices`).
///
/// Row `i` depends on every row `j > i` that appears as a column index in row
/// `i`. Its level is one more than the deepest such dependency, or 0 when it has
/// none. Diagonal and below-diagonal entries are ignored.
///
/// `row_ptrs` is indexed up to position `n`, so it must hold at least `n + 1`
/// offsets into `col_indices`; a column index `>= n` above the diagonal makes
/// the level lookup panic. Currently this function always returns `Ok`.
pub fn compute_levels_upper(
    n: usize,
    row_ptrs: &[i64],
    col_indices: &[i64],
) -> Result<LevelSchedule> {
    let mut level_of_row = vec![0usize; n];
    // Sweep bottom-up so every dependency `dep > row` is already final.
    for row in (0..n).rev() {
        let (lo, hi) = (row_ptrs[row] as usize, row_ptrs[row + 1] as usize);
        let level = (lo..hi)
            .map(|k| col_indices[k] as usize)
            .filter(|&dep| dep > row)
            .map(|dep| level_of_row[dep] + 1)
            .max()
            .unwrap_or(0);
        level_of_row[row] = level;
    }

    let num_levels = match level_of_row.iter().max() {
        Some(&deepest) => deepest + 1,
        None => 0,
    };
    let mut rows_per_level: Vec<Vec<usize>> = vec![Vec::new(); num_levels];
    for (row, level) in level_of_row.iter().copied().enumerate() {
        rows_per_level[level].push(row);
    }
    let max_parallelism = rows_per_level
        .iter()
        .map(|bucket| bucket.len())
        .max()
        .unwrap_or(0);

    Ok(LevelSchedule {
        level_of_row,
        rows_per_level,
        num_levels,
        max_parallelism,
    })
}
/// Level schedule for an ILU-style CSR sparsity pattern.
///
/// Delegates to [`compute_levels_lower`], which only treats strictly-lower
/// entries (column < row) as dependencies. A pattern that stores the L and U
/// parts together therefore yields the schedule for the lower (forward) solve.
pub fn compute_levels_ilu(
    n: usize,
    row_ptrs: &[i64],
    col_indices: &[i64],
) -> Result<LevelSchedule> {
    let lower_schedule = compute_levels_lower(n, row_ptrs, col_indices)?;
    Ok(lower_schedule)
}
/// Builds the level schedule for a column-oriented forward substitution on a
/// lower-triangular matrix stored in CSC form (`col_ptrs` / `row_indices`).
///
/// Column `j` depends on every column `k < j` that has a nonzero in row `j`
/// (entry `(j, k)`). Its level is one more than the deepest such column, or 0
/// when it has none. Entries on or above the diagonal are ignored.
///
/// The levels are propagated in a single left-to-right sweep. When column `j`
/// is visited its level is already final, because all of its dependencies are
/// earlier columns. It is then pushed to every strictly-lower row stored in the
/// column. This avoids materializing a row-wise transpose of the pattern.
///
/// The returned [`LevelSchedule`] stores per-column levels in `level_of_row`
/// and groups column indices in `rows_per_level`. Currently this function
/// always returns `Ok`.
///
/// # Panics
///
/// Panics if `col_ptrs` has fewer than `n + 1` entries, if a column's pointer
/// range is decreasing or runs past `row_indices`, or if a row index is
/// negative or `>= n`.
pub fn compute_levels_csc_lower(
    n: usize,
    col_ptrs: &[i64],
    row_indices: &[i64],
) -> Result<LevelSchedule> {
    let mut level_of_col = vec![0usize; n];
    for j in 0..n {
        let start = col_ptrs[j] as usize;
        let end = col_ptrs[j + 1] as usize;
        // Final: every contribution to column j comes from a column k < j.
        let next_level = level_of_col[j] + 1;
        for &raw_row in &row_indices[start..end] {
            // A negative index wraps to a huge usize and is caught here too.
            let row = raw_row as usize;
            assert!(
                row < n,
                "compute_levels_csc_lower: row index {} in column {} is out of range for n = {}",
                raw_row,
                j,
                n
            );
            // Entry (row, j) below the diagonal: column `row` needs x_j first.
            if row > j {
                level_of_col[row] = level_of_col[row].max(next_level);
            }
        }
    }

    let num_levels = level_of_col.iter().max().map(|&x| x + 1).unwrap_or(0);
    let mut rows_per_level: Vec<Vec<usize>> = vec![Vec::new(); num_levels];
    for (col, &level) in level_of_col.iter().enumerate() {
        rows_per_level[level].push(col);
    }
    let max_parallelism = rows_per_level.iter().map(|v| v.len()).max().unwrap_or(0);
    Ok(LevelSchedule {
        level_of_row: level_of_col,
        rows_per_level,
        num_levels,
        max_parallelism,
    })
}
/// Builds the level schedule for a column-oriented backward substitution on an
/// upper-triangular matrix stored in CSC form (`col_ptrs` / `row_indices`).
///
/// Column `j` depends on every column `k > j` that has a nonzero in row `j`
/// (entry `(j, k)`). Its level is one more than the deepest such column, or 0
/// when it has none. Entries on or below the diagonal are ignored.
///
/// The levels are propagated in a single right-to-left sweep. When column `j`
/// is visited its level is already final, because all of its dependencies are
/// later columns. It is then pushed to every strictly-upper row stored in the
/// column. This avoids materializing a row-wise transpose of the pattern.
///
/// The returned [`LevelSchedule`] stores per-column levels in `level_of_row`
/// and groups column indices in `rows_per_level`. Currently this function
/// always returns `Ok`.
///
/// # Panics
///
/// Panics if `col_ptrs` has fewer than `n + 1` entries, if a column's pointer
/// range is decreasing or runs past `row_indices`, or if a row index is
/// negative or `>= n`.
pub fn compute_levels_csc_upper(
    n: usize,
    col_ptrs: &[i64],
    row_indices: &[i64],
) -> Result<LevelSchedule> {
    let mut level_of_col = vec![0usize; n];
    for j in (0..n).rev() {
        let start = col_ptrs[j] as usize;
        let end = col_ptrs[j + 1] as usize;
        // Final: every contribution to column j comes from a column k > j.
        let next_level = level_of_col[j] + 1;
        for &raw_row in &row_indices[start..end] {
            // Checked explicitly: rows > j are never indexed below, so an
            // out-of-range row would otherwise be silently ignored.
            let row = raw_row as usize;
            assert!(
                row < n,
                "compute_levels_csc_upper: row index {} in column {} is out of range for n = {}",
                raw_row,
                j,
                n
            );
            // Entry (row, j) above the diagonal: column `row` needs x_j first.
            if row < j {
                level_of_col[row] = level_of_col[row].max(next_level);
            }
        }
    }

    let num_levels = level_of_col.iter().max().map(|&x| x + 1).unwrap_or(0);
    let mut rows_per_level: Vec<Vec<usize>> = vec![Vec::new(); num_levels];
    for (col, &level) in level_of_col.iter().enumerate() {
        rows_per_level[level].push(col);
    }
    let max_parallelism = rows_per_level.iter().map(|v| v.len()).max().unwrap_or(0);
    Ok(LevelSchedule {
        level_of_row: level_of_col,
        rows_per_level,
        num_levels,
        max_parallelism,
    })
}
/// Flattens a [`LevelSchedule`] into two `i32` arrays in a CSR-like layout.
///
/// Returns `(level_ptrs, level_rows)`. The rows of level `l` are
/// `level_rows[level_ptrs[l] as usize..level_ptrs[l + 1] as usize]`, in the
/// order they appear in `rows_per_level[l]`. `level_ptrs` starts at 0 and has
/// `rows_per_level.len() + 1` entries.
///
/// NOTE(review): `i32` presumably matches what the consumer of these arrays
/// expects — confirm against the caller.
///
/// # Panics
///
/// Panics if any row index, or the total number of scheduled rows, does not
/// fit in `i32`, rather than silently truncating it.
pub fn flatten_levels(schedule: &LevelSchedule) -> (Vec<i32>, Vec<i32>) {
    let to_i32 = |value: usize| -> i32 {
        i32::try_from(value).expect("level schedule index does not fit in i32")
    };

    let mut level_ptrs = Vec::with_capacity(schedule.rows_per_level.len() + 1);
    let mut level_rows = Vec::with_capacity(schedule.level_of_row.len());
    level_ptrs.push(0i32);
    for level_row_list in &schedule.rows_per_level {
        level_rows.extend(level_row_list.iter().map(|&row| to_i32(row)));
        // Offset one past the last row of this level.
        level_ptrs.push(to_i32(level_rows.len()));
    }
    (level_ptrs, level_rows)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_levels_lower_diagonal() {
        let row_ptrs = vec![0i64, 1, 2, 3];
        let col_indices = vec![0i64, 1, 2];
        let schedule = compute_levels_lower(3, &row_ptrs, &col_indices).unwrap();
        assert_eq!(schedule.num_levels, 1);
        assert_eq!(schedule.level_of_row, vec![0, 0, 0]);
        assert_eq!(schedule.max_parallelism, 3);
    }

    #[test]
    fn test_levels_lower_tridiagonal() {
        let row_ptrs = vec![0i64, 1, 3, 5];
        let col_indices = vec![0i64, 0, 1, 1, 2];
        let schedule = compute_levels_lower(3, &row_ptrs, &col_indices).unwrap();
        assert_eq!(schedule.num_levels, 3);
        assert_eq!(schedule.level_of_row, vec![0, 1, 2]);
        assert_eq!(schedule.max_parallelism, 1);
    }

    #[test]
    fn test_levels_lower_with_parallelism() {
        let row_ptrs = vec![0i64, 1, 2, 4, 6];
        let col_indices = vec![0i64, 1, 0, 2, 1, 3];
        let schedule = compute_levels_lower(4, &row_ptrs, &col_indices).unwrap();
        assert_eq!(schedule.num_levels, 2);
        assert_eq!(schedule.level_of_row, vec![0, 0, 1, 1]);
        assert_eq!(schedule.max_parallelism, 2);
        assert_eq!(schedule.rows_per_level[0], vec![0, 1]);
        assert_eq!(schedule.rows_per_level[1], vec![2, 3]);
    }

    #[test]
    fn test_levels_lower_empty() {
        let schedule = compute_levels_lower(0, &[0i64], &[]).unwrap();
        assert_eq!(schedule.num_levels, 0);
        assert_eq!(schedule.max_parallelism, 0);
        assert!(schedule.level_of_row.is_empty());
        assert!(schedule.rows_per_level.is_empty());
        let (level_ptrs, level_rows) = flatten_levels(&schedule);
        assert_eq!(level_ptrs, vec![0]);
        assert!(level_rows.is_empty());
    }

    #[test]
    fn test_levels_upper() {
        let row_ptrs = vec![0i64, 2, 4, 5];
        let col_indices = vec![0i64, 1, 1, 2, 2];
        let schedule = compute_levels_upper(3, &row_ptrs, &col_indices).unwrap();
        assert_eq!(schedule.num_levels, 3);
        assert_eq!(schedule.level_of_row, vec![2, 1, 0]);
    }

    #[test]
    fn test_levels_ilu_ignores_upper_entries() {
        // Tridiagonal pattern holding both the L and U parts.
        let row_ptrs = vec![0i64, 2, 5, 7];
        let col_indices = vec![0i64, 1, 0, 1, 2, 1, 2];
        let schedule = compute_levels_ilu(3, &row_ptrs, &col_indices).unwrap();
        assert_eq!(schedule.level_of_row, vec![0, 1, 2]);
        assert_eq!(schedule.num_levels, 3);
    }

    #[test]
    fn test_levels_csc_lower_matches_csr_lower() {
        // Same lower-triangular matrix with entries (2,0) and (3,1),
        // stored once in CSR and once in CSC.
        let csr = compute_levels_lower(4, &[0i64, 1, 2, 4, 6], &[0i64, 1, 0, 2, 1, 3]).unwrap();
        let csc =
            compute_levels_csc_lower(4, &[0i64, 2, 4, 5, 6], &[0i64, 2, 1, 3, 2, 3]).unwrap();
        assert_eq!(csc.level_of_row, vec![0, 0, 1, 1]);
        assert_eq!(csc.level_of_row, csr.level_of_row);
        assert_eq!(csc.rows_per_level, csr.rows_per_level);
        assert_eq!(csc.max_parallelism, 2);
    }

    #[test]
    fn test_levels_csc_lower_chain() {
        // Lower bidiagonal: column j holds rows j and j + 1.
        let col_ptrs = vec![0i64, 2, 4, 5];
        let row_indices = vec![0i64, 1, 1, 2, 2];
        let schedule = compute_levels_csc_lower(3, &col_ptrs, &row_indices).unwrap();
        assert_eq!(schedule.level_of_row, vec![0, 1, 2]);
        assert_eq!(schedule.num_levels, 3);
        assert_eq!(schedule.max_parallelism, 1);
    }

    #[test]
    fn test_levels_csc_upper() {
        // Upper bidiagonal: column 0 = {0}, column 1 = {0,1}, column 2 = {1,2}.
        let col_ptrs = vec![0i64, 1, 3, 5];
        let row_indices = vec![0i64, 0, 1, 1, 2];
        let schedule = compute_levels_csc_upper(3, &col_ptrs, &row_indices).unwrap();
        assert_eq!(schedule.level_of_row, vec![2, 1, 0]);
        assert_eq!(schedule.num_levels, 3);
        assert_eq!(schedule.rows_per_level, vec![vec![2], vec![1], vec![0]]);
    }

    #[test]
    fn test_flatten_levels() {
        let schedule = LevelSchedule {
            level_of_row: vec![0, 0, 1, 1],
            rows_per_level: vec![vec![0, 1], vec![2, 3]],
            num_levels: 2,
            max_parallelism: 2,
        };
        let (level_ptrs, level_rows) = flatten_levels(&schedule);
        assert_eq!(level_ptrs, vec![0, 2, 4]);
        assert_eq!(level_rows, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_flatten_computed_schedule() {
        let schedule =
            compute_levels_upper(4, &[0i64, 2, 4, 5, 6], &[0i64, 2, 1, 3, 2, 3]).unwrap();
        let (level_ptrs, level_rows) = flatten_levels(&schedule);
        assert_eq!(level_ptrs, vec![0, 2, 4]);
        assert_eq!(level_rows, vec![2, 3, 0, 1]);
    }
}