use pounce_common::types::{Index, Number};
/// Sparsity pattern of a single equality-constraint row.
///
/// `cols` lists the column indices of the row's structural nonzeros. Only the
/// pattern is stored; numerical values are not needed by the structural
/// checks in this module.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EqRow {
    /// Column indices of the structural nonzeros in this row.
    pub cols: Vec<Index>,
}
/// Outcome of [`licq_check`], the structural screen on the equality rows.
///
/// The check looks only at sparsity patterns, so `Full` means "structurally
/// full row rank". It does not rule out numerical rank deficiency.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LicqVerdict {
    /// Every equality row can be assigned a distinct column, or there are no
    /// equality rows at all.
    Full,
    /// The row at this index (position in the input slice) has no entries.
    EmptyRow(Index),
    /// There are more equality rows (`m_eq`) than columns (`n`), so the rows
    /// cannot all be independent.
    OverDetermined { m_eq: Index, n: Index },
    /// The maximum row/column matching has this size, which is smaller than
    /// the number of equality rows.
    StructuralRank(Index),
}
/// Structural LICQ screen for the equality-constraint Jacobian.
///
/// `rows` holds the sparsity pattern of each equality row and `n` is the
/// number of columns. The checks run in this order, and the first one that
/// fires determines the result:
///
/// 1. no rows at all → [`LicqVerdict::Full`];
/// 2. more rows than columns → [`LicqVerdict::OverDetermined`];
/// 3. the first row with no entries → [`LicqVerdict::EmptyRow`] carrying its
///    position in `rows`;
/// 4. a maximum row/column matching. If it covers every row the result is
///    [`LicqVerdict::Full`]; otherwise it is [`LicqVerdict::StructuralRank`]
///    with the matching size.
///
/// Only sparsity patterns are inspected, so a structurally full result does
/// not exclude numerical rank deficiency.
pub fn licq_check(rows: &[EqRow], n: Index) -> LicqVerdict {
    let m_eq = rows.len() as Index;

    if m_eq == 0 {
        return LicqVerdict::Full;
    }
    if m_eq > n {
        return LicqVerdict::OverDetermined { m_eq, n };
    }
    if let Some(empty_at) = rows.iter().position(|r| r.cols.is_empty()) {
        return LicqVerdict::EmptyRow(empty_at as Index);
    }

    match bipartite_matching_rank(rows, n as usize) {
        matched if matched == rows.len() => LicqVerdict::Full,
        matched => LicqVerdict::StructuralRank(matched as Index),
    }
}
/// Size of a maximum matching between the rows and the columns `0..n`.
///
/// This is the structural rank of the row set: the largest number of rows
/// that can each be assigned a distinct column from their own pattern.
/// Column indices `>= n` are ignored.
///
/// The algorithm is Kuhn's augmenting-path method, with one search per row in
/// input order. The depth-first search is iterative, using an explicit stack,
/// so long augmenting paths cannot overflow the call stack. Per-search
/// "visited" state lives in one stamp array instead of a freshly allocated
/// `n`-length buffer per row, which previously made every call O(m·n).
fn bipartite_matching_rank(rows: &[EqRow], n: usize) -> usize {
    // match_col[c] = row currently matched to column c, if any.
    let mut match_col: Vec<Option<usize>> = vec![None; n];
    // visit_stamp[c] == stamp  <=>  column c was already tried in the current search.
    let mut visit_stamp: Vec<usize> = vec![0; n];
    let mut count = 0;
    for root in 0..rows.len() {
        // Stamps start at 1 so the initial zeros never read as "visited".
        if try_augment(root, rows, &mut match_col, &mut visit_stamp, root + 1) {
            count += 1;
        }
    }
    count
}

/// Searches for an augmenting path starting at row `root`.
///
/// On success, it flips the matching along that path so that `root` becomes
/// matched, and returns `true`. Columns are tried depth-first in each row's
/// listed order, the same order a recursive formulation would use. A column is
/// explored at most once per search, tracked by `visit_stamp[c] == stamp`.
/// Columns outside `0..visit_stamp.len()` are skipped.
fn try_augment(
    root: usize,
    rows: &[EqRow],
    match_col: &mut [Option<usize>],
    visit_stamp: &mut [usize],
    stamp: usize,
) -> bool {
    // Each frame: (row index, position of the next column to try in that row).
    let mut stack: Vec<(usize, usize)> = vec![(root, 0)];
    // path[k] is the column that leads from the row in stack[k] to the row in
    // stack[k + 1]. Invariant at the loop head: path.len() == stack.len() - 1.
    let mut path: Vec<usize> = Vec::new();

    while let Some(&(row, mut pos)) = stack.last() {
        let cols = &rows[row].cols;
        let mut descend_to: Option<(usize, usize)> = None;

        while pos < cols.len() {
            let c = cols[pos] as usize;
            pos += 1;
            if c >= visit_stamp.len() || visit_stamp[c] == stamp {
                continue;
            }
            visit_stamp[c] = stamp;
            if let Some(owner) = match_col[c] {
                // Column taken: try to re-route its current owner.
                descend_to = Some((c, owner));
                break;
            }
            // Free column reached. Shift every row on the path onto the
            // column that follows it.
            path.push(c);
            for (&(r, _), &col) in stack.iter().zip(path.iter()) {
                match_col[col] = Some(r);
            }
            return true;
        }

        match descend_to {
            Some((c, owner)) => {
                // Remember where this row's scan stopped, so it resumes there
                // if the descent fails.
                if let Some(top) = stack.last_mut() {
                    top.1 = pos;
                }
                path.push(c);
                stack.push((owner, 0));
            }
            None => {
                // Row exhausted: back up, and drop the column that led into it.
                stack.pop();
                path.pop();
            }
        }
    }
    false
}
/// Builds the sparsity patterns of selected Jacobian rows from coordinate
/// `(row, col, value)` triples.
///
/// * `eq_row_indices`: the Jacobian rows to extract, in output order.
///   Repeated indices produce repeated rows.
/// * `triples`: Jacobian entries. Entries whose value compares equal to `0.0`
///   are skipped. Repeated `(row, col)` pairs collapse to a single column.
/// * `inner_m`: number of Jacobian rows. Triples with a row index `>= inner_m`
///   are ignored.
///
/// Returns one [`EqRow`] per entry of `eq_row_indices`, with columns sorted
/// ascending and deduplicated. An index in `eq_row_indices` that is
/// `>= inner_m` yields an empty row instead of panicking, so the problem
/// surfaces through [`licq_check`] (as [`LicqVerdict::EmptyRow`] unless an
/// earlier check fires).
pub fn eq_rows_from_triples(
    eq_row_indices: &[usize],
    triples: &[(Index, Index, Number)],
    inner_m: usize,
) -> Vec<EqRow> {
    use std::collections::BTreeSet;

    // Only rows that will be reported need a column set. This skips entries
    // belonging to every other Jacobian row.
    let mut wanted = vec![false; inner_m];
    for &i in eq_row_indices {
        if let Some(flag) = wanted.get_mut(i) {
            *flag = true;
        }
    }

    let mut by_row: Vec<BTreeSet<Index>> = vec![BTreeSet::new(); inner_m];
    for &(i, j, v) in triples {
        if v == 0.0 {
            continue;
        }
        let i = i as usize;
        if wanted.get(i).copied().unwrap_or(false) {
            by_row[i].insert(j);
        }
    }

    eq_row_indices
        .iter()
        .map(|&i| EqRow {
            cols: by_row
                .get(i)
                .map(|set| set.iter().copied().collect::<Vec<_>>())
                .unwrap_or_default(),
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an `EqRow` from a column list.
    fn row(cols: &[Index]) -> EqRow {
        EqRow {
            cols: cols.to_vec(),
        }
    }

    #[test]
    fn no_equality_rows_is_full_rank() {
        assert_eq!(licq_check(&[], 5), LicqVerdict::Full);
    }

    #[test]
    fn over_determined_caught() {
        let rows = vec![row(&[0]), row(&[0]), row(&[0])];
        assert_eq!(
            licq_check(&rows, 2),
            LicqVerdict::OverDetermined { m_eq: 3, n: 2 }
        );
    }

    #[test]
    fn over_determined_checked_before_empty_rows() {
        let rows = vec![row(&[]), row(&[]), row(&[])];
        assert_eq!(
            licq_check(&rows, 2),
            LicqVerdict::OverDetermined { m_eq: 3, n: 2 }
        );
    }

    #[test]
    fn empty_row_caught() {
        let rows = vec![row(&[0]), row(&[])];
        assert_eq!(licq_check(&rows, 5), LicqVerdict::EmptyRow(1));
    }

    #[test]
    fn duplicate_singletons_dropped_by_matching() {
        let rows = vec![row(&[0]), row(&[0])];
        assert_eq!(licq_check(&rows, 5), LicqVerdict::StructuralRank(1));
    }

    #[test]
    fn distinct_singletons_full_rank() {
        let rows = vec![row(&[0]), row(&[1]), row(&[2])];
        assert_eq!(licq_check(&rows, 5), LicqVerdict::Full);
    }

    #[test]
    fn matching_via_augmenting_path() {
        let rows = vec![row(&[0, 1]), row(&[0])];
        assert_eq!(licq_check(&rows, 2), LicqVerdict::Full);
    }

    #[test]
    fn matching_via_multi_level_augmenting_path() {
        // Row 2 must displace row 0, which in turn must displace row 1 … which fails,
        // so row 2 falls back to its own free column.
        let rows = vec![row(&[0, 1]), row(&[0]), row(&[1, 2])];
        assert_eq!(licq_check(&rows, 3), LicqVerdict::Full);
    }

    #[test]
    fn out_of_range_columns_are_ignored() {
        let rows = vec![row(&[7]), row(&[1])];
        assert_eq!(licq_check(&rows, 2), LicqVerdict::StructuralRank(1));
    }

    #[test]
    fn eq_rows_from_triples_builds_sorted_patterns() {
        let triples: Vec<(Index, Index, Number)> = vec![
            (0, 3, 1.0),
            (0, 1, 2.0),
            (0, 3, -1.0),
            (1, 2, 0.0),
            (2, 0, 5.0),
            (9, 0, 1.0),
        ];
        let rows = eq_rows_from_triples(&[0, 2, 1], &triples, 3);
        let cols: Vec<Vec<Index>> = rows.into_iter().map(|r| r.cols).collect();
        assert_eq!(cols, vec![vec![1, 3], vec![0], vec![]]);
    }
}