#![allow(clippy::needless_pass_by_value)]
use rustc_hash::FxHashMap;
use std::{fmt, hash::Hash};
use crate::{
IndexedDomain, IndexedValue, ToIndex, bitset::BitSet, pointer::PointerFamily, set::IndexSet,
};
/// A map from row keys `R` to [`IndexSet`]s of column values `C`.
///
/// Rows are stored only once written; readers such as `row` and `row_set` treat a missing row
/// as empty.
pub struct IndexMatrix<'a, R, C, S, P>
where
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Materialized rows, keyed by row value.
    pub(crate) matrix: FxHashMap<R, IndexSet<'a, C, S, P>>,
    /// An empty set over `col_domain`, cloned whenever a new row must be created.
    empty_set: IndexSet<'a, C, S, P>,
    /// Handle to the domain that column values are indexed into.
    col_domain: P::Pointer<IndexedDomain<C>>,
}
impl<'a, R, C, S, P> IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Creates an empty matrix whose column values are drawn from `col_domain`.
    ///
    /// The domain pointer is cloned, so the matrix keeps its own handle to it.
    pub fn new(col_domain: &P::Pointer<IndexedDomain<C>>) -> Self {
        IndexMatrix {
            matrix: FxHashMap::default(),
            empty_set: IndexSet::new(col_domain),
            col_domain: col_domain.clone(),
        }
    }

    /// Returns a mutable reference to `row`'s column set, first inserting a copy of the
    /// empty set if the row is not yet materialized.
    pub(crate) fn ensure_row(&mut self, row: R) -> &mut IndexSet<'a, C, S, P> {
        // The closure only reads `self.empty_set`, which does not overlap the mutable
        // borrow of `self.matrix` held by the entry.
        self.matrix
            .entry(row)
            .or_insert_with(|| self.empty_set.clone())
    }

    /// Adds `col` to `row`, materializing the row if needed.
    ///
    /// `col` may be anything convertible to a column index via [`ToIndex`]. Returns the result
    /// of `IndexSet::insert` (presumably `true` when `col` was not already present).
    pub fn insert<M>(&mut self, row: R, col: impl ToIndex<C, M>) -> bool {
        let col = col.to_index(&self.col_domain);
        self.ensure_row(row).insert(col)
    }

    /// Unions the set `from` into row `into`, materializing `into` if needed.
    ///
    /// Returns the result of `IndexSet::union_changed`, i.e. whether `into` changed.
    pub fn union_into_row(&mut self, into: R, from: &IndexSet<'a, C, S, P>) -> bool {
        self.ensure_row(into).union_changed(from)
    }

    /// Unions row `from` into row `to`, returning whether `to` changed.
    ///
    /// Both rows are materialized (created empty if absent) even if nothing changes.
    /// Unioning a row into itself is a no-op that returns `false`.
    pub fn union_rows(&mut self, from: R, to: R) -> bool {
        if from == to {
            return false;
        }
        self.ensure_row(from.clone());
        self.ensure_row(to.clone());
        // `get_disjoint_mut` yields two simultaneous `&mut` borrows without `unsafe`. It panics
        // if the keys alias, which the `from == to` guard above excludes for any `Eq` consistent
        // with `Hash`. A faulty user `Eq` can therefore at worst panic here, never cause UB
        // (unlike `get_disjoint_unchecked_mut`).
        let [Some(src), Some(dst)] = self.matrix.get_disjoint_mut([&from, &to]) else {
            unreachable!("both rows were materialized by `ensure_row` above")
        };
        dst.union_changed(src)
    }

    /// Iterates over the column values in `row`; yields nothing if the row is absent.
    pub fn row(&self, row: &R) -> impl Iterator<Item = &C> {
        self.matrix.get(row).into_iter().flat_map(IndexSet::iter)
    }

    /// Iterates over every materialized row and its column set, in hash-map order.
    pub fn rows(&self) -> impl ExactSizeIterator<Item = (&R, &IndexSet<'a, C, S, P>)> {
        self.matrix.iter()
    }

    /// Returns `row`'s column set, or a shared empty set if the row is absent.
    pub fn row_set(&self, row: &R) -> &IndexSet<'a, C, S, P> {
        self.matrix.get(row).unwrap_or(&self.empty_set)
    }

    /// Removes `row` from the matrix.
    ///
    /// Afterwards the row reads as empty through `row`/`row_set`, and it no longer appears
    /// in `rows`.
    pub fn clear_row(&mut self, row: &R) {
        self.matrix.remove(row);
    }

    /// Returns the handle to the domain that column values are indexed into.
    pub fn col_domain(&self) -> &P::Pointer<IndexedDomain<C>> {
        &self.col_domain
    }
}
impl<'a, R, C, S, P> IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone + 'a,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Builds the transposed matrix.
    ///
    /// Every column index of `self` becomes a row key of the result. Every row key `R` becomes
    /// a column value of type `T`, converted through `row_domain` via [`ToIndex`].
    pub fn transpose<T, M>(
        &self,
        row_domain: &P::Pointer<IndexedDomain<T>>,
    ) -> IndexMatrix<'a, C::Index, T, S, P>
    where
        T: IndexedValue + 'a,
        R: ToIndex<T, M>,
    {
        let mut transposed = IndexMatrix::new(row_domain);
        // Flatten into (column index, row key) pairs, i.e. the entries with axes swapped.
        let flipped = self
            .matrix
            .iter()
            .flat_map(|(key, set)| set.indices().map(move |idx| (idx, key)));
        for (new_row, new_col) in flipped {
            transposed.insert(new_row, new_col.clone());
        }
        transposed
    }
}
impl<'a, R, C, S, P> PartialEq for IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Matrices are equal when their row maps are equal.
    ///
    /// The cached empty set and the column-domain handle are not compared. Because the row
    /// maps' key sets are compared, an absent row and a present-but-empty row differ.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.matrix, &other.matrix);
        lhs.eq(rhs)
    }
}
// Marker impl: declares the row-map-based equality from `PartialEq` to be total.
impl<'a, R: PartialEq + Eq + Hash + Clone, C: IndexedValue + 'a, S: BitSet, P: PointerFamily<'a>>
    Eq for IndexMatrix<'a, R, C, S, P>
{
}
impl<'a, R, C, S, P> Clone for IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone,
    C: IndexedValue + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    fn clone(&self) -> Self {
        Self {
            matrix: self.matrix.clone(),
            empty_set: self.empty_set.clone(),
            col_domain: self.col_domain.clone(),
        }
    }

    /// Overwrites `self` with a copy of `source`, reusing the sets of rows both already share.
    ///
    /// Rows missing from `source` are removed, not merely emptied, so that afterwards
    /// `self == source`. This matches `Clone::clone_from`'s contract of being equivalent to
    /// `*self = source.clone()`.
    fn clone_from(&mut self, source: &Self) {
        self.col_domain.clone_from(&source.col_domain);
        self.empty_set.clone_from(&source.empty_set);
        // Stale rows kept as empty sets would make `self != source`, because `PartialEq`
        // compares the row maps including their key sets.
        self.matrix.retain(|row, _| source.matrix.contains_key(row));
        for (row, cols) in &source.matrix {
            match self.matrix.get_mut(row) {
                // Reuse the existing allocation; no key clone needed.
                Some(existing) => existing.clone_from(cols),
                None => {
                    self.matrix.insert(row.clone(), cols.clone());
                }
            }
        }
    }
}
impl<'a, R, C, S, P> fmt::Debug for IndexMatrix<'a, R, C, S, P>
where
    R: PartialEq + Eq + Hash + Clone + fmt::Debug,
    C: IndexedValue + fmt::Debug + 'a,
    S: BitSet,
    P: PointerFamily<'a>,
{
    /// Renders the matrix as a map from each materialized row to its column set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_map();
        for (key, set) in &self.matrix {
            out.entry(key, set);
        }
        out.finish()
    }
}
#[cfg(test)]
mod test {
    use crate::{IndexedDomain, test_utils::TestIndexMatrix};
    use std::rc::Rc;

    /// Shorthand for turning a string literal into an owned `String`.
    fn mk(s: &str) -> String {
        String::from(s)
    }

    /// Covers insertion, per-row iteration, and unioning one row into another.
    #[test]
    fn test_indexmatrix() {
        let letters = ["a", "b", "c"].map(mk);
        let col_domain = Rc::new(IndexedDomain::from_iter(letters));
        let mut mtx = TestIndexMatrix::new(&col_domain);

        mtx.insert(0, mk("b"));
        mtx.insert(1, mk("c"));

        let first: Vec<_> = mtx.row(&0).collect();
        assert_eq!(first, vec!["b"]);
        let second: Vec<_> = mtx.row(&1).collect();
        assert_eq!(second, vec!["c"]);

        let changed = mtx.union_rows(0, 1);
        assert!(changed);
        let merged: Vec<_> = mtx.row(&1).collect();
        assert_eq!(merged, vec!["b", "c"]);
    }
}